You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

sqlite3_cache.c 7.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286
  1. /*-
  2. * Copyright 2016 Vsevolod Stakhov
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "config.h"
  17. #include "learn_cache.h"
  18. #include "rspamd.h"
  19. #include "stat_api.h"
  20. #include "stat_internal.h"
  21. #include "cryptobox.h"
  22. #include "ucl.h"
  23. #include "fstring.h"
  24. #include "message.h"
  25. #include "libutil/sqlite_utils.h"
  /*
   * Schema of the learn cache: one row per learned digest together with the
   * flag it was learned with. The unique index on `digest` keeps at most one
   * row per digest.
   */
  static const char *create_tables_sql =
      "CREATE TABLE IF NOT EXISTS learns("
          "id INTEGER PRIMARY KEY,"
          "flag INTEGER NOT NULL,"
          "digest TEXT NOT NULL"
      ");"
      "CREATE UNIQUE INDEX IF NOT EXISTS d ON learns(digest);";
  /* Database file used when the config has neither "path" nor "file" */
  #define SQLITE_CACHE_PATH RSPAMD_DBDIR "/learn_cache.sqlite"

  /*
   * Positions of the cache statements inside prepared_stmts (and inside the
   * prepared-statement array built from it). Values are spelled out because
   * they are used directly as array indices.
   */
  enum rspamd_stat_sqlite3_stmt_idx {
      RSPAMD_STAT_CACHE_TRANSACTION_START_IM = 0,  /* BEGIN IMMEDIATE */
      RSPAMD_STAT_CACHE_TRANSACTION_START_DEF = 1, /* BEGIN DEFERRED */
      RSPAMD_STAT_CACHE_TRANSACTION_COMMIT = 2,
      RSPAMD_STAT_CACHE_TRANSACTION_ROLLBACK = 3,
      RSPAMD_STAT_CACHE_GET_LEARN = 4,
      RSPAMD_STAT_CACHE_ADD_LEARN = 5,
      RSPAMD_STAT_CACHE_UPDATE_LEARN = 6,
      RSPAMD_STAT_CACHE_MAX = 7 /* statement count, not a valid index */
  };
  45. static struct rspamd_sqlite3_prstmt prepared_stmts[RSPAMD_STAT_CACHE_MAX] =
  46. {
  47. {
  48. .idx = RSPAMD_STAT_CACHE_TRANSACTION_START_IM,
  49. .sql = "BEGIN IMMEDIATE TRANSACTION;",
  50. .args = "",
  51. .stmt = NULL,
  52. .result = SQLITE_DONE,
  53. .ret = ""
  54. },
  55. {
  56. .idx = RSPAMD_STAT_CACHE_TRANSACTION_START_DEF,
  57. .sql = "BEGIN DEFERRED TRANSACTION;",
  58. .args = "",
  59. .stmt = NULL,
  60. .result = SQLITE_DONE,
  61. .ret = ""
  62. },
  63. {
  64. .idx = RSPAMD_STAT_CACHE_TRANSACTION_COMMIT,
  65. .sql = "COMMIT;",
  66. .args = "",
  67. .stmt = NULL,
  68. .result = SQLITE_DONE,
  69. .ret = ""
  70. },
  71. {
  72. .idx = RSPAMD_STAT_CACHE_TRANSACTION_ROLLBACK,
  73. .sql = "ROLLBACK;",
  74. .args = "",
  75. .stmt = NULL,
  76. .result = SQLITE_DONE,
  77. .ret = ""
  78. },
  79. {
  80. .idx = RSPAMD_STAT_CACHE_GET_LEARN,
  81. .sql = "SELECT flag FROM learns WHERE digest=?1",
  82. .args = "V",
  83. .stmt = NULL,
  84. .result = SQLITE_ROW,
  85. .ret = "I"
  86. },
  87. {
  88. .idx = RSPAMD_STAT_CACHE_ADD_LEARN,
  89. .sql = "INSERT INTO learns(digest, flag) VALUES (?1, ?2);",
  90. .args = "VI",
  91. .stmt = NULL,
  92. .result = SQLITE_DONE,
  93. .ret = ""
  94. },
  95. {
  96. .idx = RSPAMD_STAT_CACHE_UPDATE_LEARN,
  97. .sql = "UPDATE learns SET flag=?1 WHERE digest=?2;",
  98. .args = "IV",
  99. .stmt = NULL,
  100. .result = SQLITE_DONE,
  101. .ret = ""
  102. }
  103. };
  104. struct rspamd_stat_sqlite3_ctx {
  105. sqlite3 *db;
  106. GArray *prstmt;
  107. };
  108. gpointer
  109. rspamd_stat_cache_sqlite3_init (struct rspamd_stat_ctx *ctx,
  110. struct rspamd_config *cfg,
  111. struct rspamd_statfile *st,
  112. const ucl_object_t *cf)
  113. {
  114. struct rspamd_stat_sqlite3_ctx *new = NULL;
  115. const ucl_object_t *elt;
  116. gchar dbpath[PATH_MAX];
  117. const gchar *path = SQLITE_CACHE_PATH;
  118. sqlite3 *sqlite;
  119. GError *err = NULL;
  120. if (cf) {
  121. elt = ucl_object_find_any_key (cf, "path", "file", NULL);
  122. if (elt != NULL) {
  123. path = ucl_object_tostring (elt);
  124. }
  125. }
  126. rspamd_snprintf (dbpath, sizeof (dbpath), "%s", path);
  127. sqlite = rspamd_sqlite3_open_or_create (cfg->cfg_pool,
  128. dbpath, create_tables_sql, &err);
  129. if (sqlite == NULL) {
  130. msg_err ("cannot open sqlite3 cache: %e", err);
  131. g_error_free (err);
  132. err = NULL;
  133. }
  134. else {
  135. new = g_slice_alloc (sizeof (*new));
  136. new->db = sqlite;
  137. new->prstmt = rspamd_sqlite3_init_prstmt (sqlite, prepared_stmts,
  138. RSPAMD_STAT_CACHE_MAX, &err);
  139. if (new->prstmt == NULL) {
  140. msg_err ("cannot open sqlite3 cache: %e", err);
  141. g_error_free (err);
  142. err = NULL;
  143. sqlite3_close (sqlite);
  144. g_slice_free1 (sizeof (*new), new);
  145. new = NULL;
  146. }
  147. }
  148. return new;
  149. }
  150. gpointer
  151. rspamd_stat_cache_sqlite3_runtime (struct rspamd_task *task,
  152. gpointer ctx, gboolean learn)
  153. {
  154. /* No need of runtime for this type of classifier */
  155. return ctx;
  156. }
  157. gint
  158. rspamd_stat_cache_sqlite3_check (struct rspamd_task *task,
  159. gboolean is_spam,
  160. gpointer runtime)
  161. {
  162. struct rspamd_stat_sqlite3_ctx *ctx = runtime;
  163. rspamd_cryptobox_hash_state_t st;
  164. rspamd_token_t *tok;
  165. guchar *out;
  166. gchar *user = NULL;
  167. guint i;
  168. gint rc;
  169. gint64 flag;
  170. if (ctx != NULL && ctx->db != NULL) {
  171. out = rspamd_mempool_alloc (task->task_pool, rspamd_cryptobox_HASHBYTES);
  172. rspamd_cryptobox_hash_init (&st, NULL, 0);
  173. user = rspamd_mempool_get_variable (task->task_pool, "stat_user");
  174. /* Use dedicated hash space for per users cache */
  175. if (user != NULL) {
  176. rspamd_cryptobox_hash_update (&st, user, strlen (user));
  177. }
  178. for (i = 0; i < task->tokens->len; i ++) {
  179. tok = g_ptr_array_index (task->tokens, i);
  180. rspamd_cryptobox_hash_update (&st, tok->data, tok->datalen);
  181. }
  182. rspamd_cryptobox_hash_final (&st, out);
  183. rspamd_sqlite3_run_prstmt (task->task_pool, ctx->db, ctx->prstmt,
  184. RSPAMD_STAT_CACHE_TRANSACTION_START_DEF);
  185. rc = rspamd_sqlite3_run_prstmt (task->task_pool, ctx->db, ctx->prstmt,
  186. RSPAMD_STAT_CACHE_GET_LEARN, (gint64)rspamd_cryptobox_HASHBYTES,
  187. out, &flag);
  188. rspamd_sqlite3_run_prstmt (task->task_pool, ctx->db, ctx->prstmt,
  189. RSPAMD_STAT_CACHE_TRANSACTION_COMMIT);
  190. /* Save hash into variables */
  191. rspamd_mempool_set_variable (task->task_pool, "words_hash", out, NULL);
  192. if (rc == SQLITE_OK) {
  193. /* We have some existing record in the table */
  194. if (!!flag == !!is_spam) {
  195. /* Already learned */
  196. return RSPAMD_LEARN_INGORE;
  197. }
  198. else {
  199. /* Need to relearn */
  200. return RSPAMD_LEARN_UNLEARN;
  201. }
  202. }
  203. else {
  204. }
  205. }
  206. return RSPAMD_LEARN_OK;
  207. }
  208. gint
  209. rspamd_stat_cache_sqlite3_learn (struct rspamd_task *task,
  210. gboolean is_spam,
  211. gpointer runtime)
  212. {
  213. struct rspamd_stat_sqlite3_ctx *ctx = runtime;
  214. gboolean unlearn = !!(task->flags & RSPAMD_TASK_FLAG_UNLEARN);
  215. guchar *h;
  216. gint64 flag;
  217. h = rspamd_mempool_get_variable (task->task_pool, "words_hash");
  218. g_assert (h != NULL);
  219. flag = !!is_spam ? 1 : 0;
  220. if (!unlearn) {
  221. /* Insert result new id */
  222. rspamd_sqlite3_run_prstmt (task->task_pool, ctx->db, ctx->prstmt,
  223. RSPAMD_STAT_CACHE_TRANSACTION_START_IM);
  224. rspamd_sqlite3_run_prstmt (task->task_pool, ctx->db, ctx->prstmt,
  225. RSPAMD_STAT_CACHE_ADD_LEARN,
  226. (gint64)rspamd_cryptobox_HASHBYTES, h, flag);
  227. rspamd_sqlite3_run_prstmt (task->task_pool, ctx->db, ctx->prstmt,
  228. RSPAMD_STAT_CACHE_TRANSACTION_COMMIT);
  229. }
  230. else {
  231. rspamd_sqlite3_run_prstmt (task->task_pool, ctx->db, ctx->prstmt,
  232. RSPAMD_STAT_CACHE_TRANSACTION_START_IM);
  233. rspamd_sqlite3_run_prstmt (task->task_pool, ctx->db, ctx->prstmt,
  234. RSPAMD_STAT_CACHE_UPDATE_LEARN,
  235. flag,
  236. (gint64)rspamd_cryptobox_HASHBYTES, h);
  237. rspamd_sqlite3_run_prstmt (task->task_pool, ctx->db, ctx->prstmt,
  238. RSPAMD_STAT_CACHE_TRANSACTION_COMMIT);
  239. }
  240. rspamd_sqlite3_sync (ctx->db, NULL, NULL);
  241. return RSPAMD_LEARN_OK;
  242. }
  243. void
  244. rspamd_stat_cache_sqlite3_close (gpointer c)
  245. {
  246. struct rspamd_stat_sqlite3_ctx *ctx = (struct rspamd_stat_sqlite3_ctx *)c;
  247. if (ctx != NULL) {
  248. rspamd_sqlite3_close_prstmt (ctx->db, ctx->prstmt);
  249. sqlite3_close (ctx->db);
  250. g_slice_free1 (sizeof (*ctx), ctx);
  251. }
  252. }