Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

sqlite3_cache.c 7.3KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. /*
  2. * Copyright 2024 Vsevolod Stakhov
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "config.h"
  17. #include "learn_cache.h"
  18. #include "rspamd.h"
  19. #include "stat_api.h"
  20. #include "stat_internal.h"
  21. #include "cryptobox.h"
  22. #include "ucl.h"
  23. #include "fstring.h"
  24. #include "message.h"
  25. #include "libutil/sqlite_utils.h"
  26. static const char *create_tables_sql =
  27. ""
  28. "CREATE TABLE IF NOT EXISTS learns("
  29. "id INTEGER PRIMARY KEY,"
  30. "flag INTEGER NOT NULL,"
  31. "digest TEXT NOT NULL);"
  32. "CREATE UNIQUE INDEX IF NOT EXISTS d ON learns(digest);"
  33. "";
  34. #define SQLITE_CACHE_PATH RSPAMD_DBDIR "/learn_cache.sqlite"
  35. enum rspamd_stat_sqlite3_stmt_idx {
  36. RSPAMD_STAT_CACHE_TRANSACTION_START_IM = 0,
  37. RSPAMD_STAT_CACHE_TRANSACTION_START_DEF,
  38. RSPAMD_STAT_CACHE_TRANSACTION_COMMIT,
  39. RSPAMD_STAT_CACHE_TRANSACTION_ROLLBACK,
  40. RSPAMD_STAT_CACHE_GET_LEARN,
  41. RSPAMD_STAT_CACHE_ADD_LEARN,
  42. RSPAMD_STAT_CACHE_UPDATE_LEARN,
  43. RSPAMD_STAT_CACHE_MAX
  44. };
  45. static struct rspamd_sqlite3_prstmt prepared_stmts[RSPAMD_STAT_CACHE_MAX] =
  46. {
  47. {.idx = RSPAMD_STAT_CACHE_TRANSACTION_START_IM,
  48. .sql = "BEGIN IMMEDIATE TRANSACTION;",
  49. .args = "",
  50. .stmt = NULL,
  51. .result = SQLITE_DONE,
  52. .ret = ""},
  53. {.idx = RSPAMD_STAT_CACHE_TRANSACTION_START_DEF,
  54. .sql = "BEGIN DEFERRED TRANSACTION;",
  55. .args = "",
  56. .stmt = NULL,
  57. .result = SQLITE_DONE,
  58. .ret = ""},
  59. {.idx = RSPAMD_STAT_CACHE_TRANSACTION_COMMIT,
  60. .sql = "COMMIT;",
  61. .args = "",
  62. .stmt = NULL,
  63. .result = SQLITE_DONE,
  64. .ret = ""},
  65. {.idx = RSPAMD_STAT_CACHE_TRANSACTION_ROLLBACK,
  66. .sql = "ROLLBACK;",
  67. .args = "",
  68. .stmt = NULL,
  69. .result = SQLITE_DONE,
  70. .ret = ""},
  71. {.idx = RSPAMD_STAT_CACHE_GET_LEARN,
  72. .sql = "SELECT flag FROM learns WHERE digest=?1",
  73. .args = "V",
  74. .stmt = NULL,
  75. .result = SQLITE_ROW,
  76. .ret = "I"},
  77. {.idx = RSPAMD_STAT_CACHE_ADD_LEARN,
  78. .sql = "INSERT INTO learns(digest, flag) VALUES (?1, ?2);",
  79. .args = "VI",
  80. .stmt = NULL,
  81. .result = SQLITE_DONE,
  82. .ret = ""},
  83. {.idx = RSPAMD_STAT_CACHE_UPDATE_LEARN,
  84. .sql = "UPDATE learns SET flag=?1 WHERE digest=?2;",
  85. .args = "IV",
  86. .stmt = NULL,
  87. .result = SQLITE_DONE,
  88. .ret = ""}};
  89. struct rspamd_stat_sqlite3_ctx {
  90. sqlite3 *db;
  91. GArray *prstmt;
  92. };
  93. gpointer
  94. rspamd_stat_cache_sqlite3_init(struct rspamd_stat_ctx *ctx,
  95. struct rspamd_config *cfg,
  96. struct rspamd_statfile *st,
  97. const ucl_object_t *cf)
  98. {
  99. struct rspamd_stat_sqlite3_ctx *new = NULL;
  100. const ucl_object_t *elt;
  101. char dbpath[PATH_MAX];
  102. const char *path = SQLITE_CACHE_PATH;
  103. sqlite3 *sqlite;
  104. GError *err = NULL;
  105. if (cf) {
  106. elt = ucl_object_lookup_any(cf, "path", "file", NULL);
  107. if (elt != NULL) {
  108. path = ucl_object_tostring(elt);
  109. }
  110. }
  111. rspamd_snprintf(dbpath, sizeof(dbpath), "%s", path);
  112. sqlite = rspamd_sqlite3_open_or_create(cfg->cfg_pool,
  113. dbpath, create_tables_sql, 0, &err);
  114. if (sqlite == NULL) {
  115. msg_err("cannot open sqlite3 cache: %e", err);
  116. g_error_free(err);
  117. err = NULL;
  118. }
  119. else {
  120. new = g_malloc0(sizeof(*new));
  121. new->db = sqlite;
  122. new->prstmt = rspamd_sqlite3_init_prstmt(sqlite, prepared_stmts,
  123. RSPAMD_STAT_CACHE_MAX, &err);
  124. if (new->prstmt == NULL) {
  125. msg_err("cannot open sqlite3 cache: %e", err);
  126. g_error_free(err);
  127. err = NULL;
  128. sqlite3_close(sqlite);
  129. g_free(new);
  130. new = NULL;
  131. }
  132. }
  133. return new;
  134. }
  135. gpointer
  136. rspamd_stat_cache_sqlite3_runtime(struct rspamd_task *task,
  137. gpointer ctx, gboolean learn)
  138. {
  139. /* No need of runtime for this type of classifier */
  140. return ctx;
  141. }
  142. int rspamd_stat_cache_sqlite3_check(struct rspamd_task *task,
  143. gboolean is_spam,
  144. gpointer runtime)
  145. {
  146. struct rspamd_stat_sqlite3_ctx *ctx = runtime;
  147. rspamd_cryptobox_hash_state_t st;
  148. rspamd_token_t *tok;
  149. unsigned char *out;
  150. char *user = NULL;
  151. unsigned int i;
  152. int rc;
  153. int64_t flag;
  154. if (task->tokens == NULL || task->tokens->len == 0) {
  155. return RSPAMD_LEARN_IGNORE;
  156. }
  157. if (ctx != NULL && ctx->db != NULL) {
  158. out = rspamd_mempool_alloc(task->task_pool, rspamd_cryptobox_HASHBYTES);
  159. rspamd_cryptobox_hash_init(&st, NULL, 0);
  160. user = rspamd_mempool_get_variable(task->task_pool, "stat_user");
  161. /* Use dedicated hash space for per users cache */
  162. if (user != NULL) {
  163. rspamd_cryptobox_hash_update(&st, user, strlen(user));
  164. }
  165. for (i = 0; i < task->tokens->len; i++) {
  166. tok = g_ptr_array_index(task->tokens, i);
  167. rspamd_cryptobox_hash_update(&st, (unsigned char *) &tok->data,
  168. sizeof(tok->data));
  169. }
  170. rspamd_cryptobox_hash_final(&st, out);
  171. rspamd_sqlite3_run_prstmt(task->task_pool, ctx->db, ctx->prstmt,
  172. RSPAMD_STAT_CACHE_TRANSACTION_START_DEF);
  173. rc = rspamd_sqlite3_run_prstmt(task->task_pool, ctx->db, ctx->prstmt,
  174. RSPAMD_STAT_CACHE_GET_LEARN, (int64_t) rspamd_cryptobox_HASHBYTES,
  175. out, &flag);
  176. rspamd_sqlite3_run_prstmt(task->task_pool, ctx->db, ctx->prstmt,
  177. RSPAMD_STAT_CACHE_TRANSACTION_COMMIT);
  178. /* Save hash into variables */
  179. rspamd_mempool_set_variable(task->task_pool, "words_hash", out, NULL);
  180. if (rc == SQLITE_OK) {
  181. /* We have some existing record in the table */
  182. if (!!flag == !!is_spam) {
  183. /* Already learned */
  184. msg_warn_task("already seen stat hash: %*bs",
  185. rspamd_cryptobox_HASHBYTES, out);
  186. return RSPAMD_LEARN_IGNORE;
  187. }
  188. else {
  189. /* Need to relearn */
  190. return RSPAMD_LEARN_UNLEARN;
  191. }
  192. }
  193. }
  194. return RSPAMD_LEARN_OK;
  195. }
  196. int rspamd_stat_cache_sqlite3_learn(struct rspamd_task *task,
  197. gboolean is_spam,
  198. gpointer runtime)
  199. {
  200. struct rspamd_stat_sqlite3_ctx *ctx = runtime;
  201. gboolean unlearn = !!(task->flags & RSPAMD_TASK_FLAG_UNLEARN);
  202. unsigned char *h;
  203. int64_t flag;
  204. h = rspamd_mempool_get_variable(task->task_pool, "words_hash");
  205. if (h == NULL) {
  206. return RSPAMD_LEARN_IGNORE;
  207. }
  208. flag = !!is_spam ? 1 : 0;
  209. if (!unlearn) {
  210. /* Insert result new id */
  211. rspamd_sqlite3_run_prstmt(task->task_pool, ctx->db, ctx->prstmt,
  212. RSPAMD_STAT_CACHE_TRANSACTION_START_IM);
  213. rspamd_sqlite3_run_prstmt(task->task_pool, ctx->db, ctx->prstmt,
  214. RSPAMD_STAT_CACHE_ADD_LEARN,
  215. (int64_t) rspamd_cryptobox_HASHBYTES, h, flag);
  216. rspamd_sqlite3_run_prstmt(task->task_pool, ctx->db, ctx->prstmt,
  217. RSPAMD_STAT_CACHE_TRANSACTION_COMMIT);
  218. }
  219. else {
  220. rspamd_sqlite3_run_prstmt(task->task_pool, ctx->db, ctx->prstmt,
  221. RSPAMD_STAT_CACHE_TRANSACTION_START_IM);
  222. rspamd_sqlite3_run_prstmt(task->task_pool, ctx->db, ctx->prstmt,
  223. RSPAMD_STAT_CACHE_UPDATE_LEARN,
  224. flag,
  225. (int64_t) rspamd_cryptobox_HASHBYTES, h);
  226. rspamd_sqlite3_run_prstmt(task->task_pool, ctx->db, ctx->prstmt,
  227. RSPAMD_STAT_CACHE_TRANSACTION_COMMIT);
  228. }
  229. rspamd_sqlite3_sync(ctx->db, NULL, NULL);
  230. return RSPAMD_LEARN_OK;
  231. }
  232. void rspamd_stat_cache_sqlite3_close(gpointer c)
  233. {
  234. struct rspamd_stat_sqlite3_ctx *ctx = (struct rspamd_stat_sqlite3_ctx *) c;
  235. if (ctx != NULL) {
  236. rspamd_sqlite3_close_prstmt(ctx->db, ctx->prstmt);
  237. sqlite3_close(ctx->db);
  238. g_free(ctx);
  239. }
  240. }