diff options
-rw-r--r-- | src/libstat/learn_cache/learn_cache.h | 4 | ||||
-rw-r--r-- | src/libstat/learn_cache/sqlite3_cache.c | 125 |
2 files changed, 74 insertions, 55 deletions
diff --git a/src/libstat/learn_cache/learn_cache.h b/src/libstat/learn_cache/learn_cache.h index 05c332076..1ebe2864a 100644 --- a/src/libstat/learn_cache/learn_cache.h +++ b/src/libstat/learn_cache/learn_cache.h @@ -62,11 +62,11 @@ struct rspamd_stat_cache { const ucl_object_t *cf); \ gpointer rspamd_stat_cache_##name##_runtime (struct rspamd_task *task, \ gpointer ctx); \ - gboolean rspamd_stat_cache_##name##_check (struct rspamd_task *task, \ + gint rspamd_stat_cache_##name##_check (struct rspamd_task *task, \ gboolean is_spam, \ gpointer runtime, \ gpointer ctx); \ - gboolean rspamd_stat_cache_##name##_learn (struct rspamd_task *task, \ + gint rspamd_stat_cache_##name##_learn (struct rspamd_task *task, \ gboolean is_spam, \ gpointer runtime, \ gpointer ctx); \ diff --git a/src/libstat/learn_cache/sqlite3_cache.c b/src/libstat/learn_cache/sqlite3_cache.c index 02a6a81ab..6e040595d 100644 --- a/src/libstat/learn_cache/sqlite3_cache.c +++ b/src/libstat/learn_cache/sqlite3_cache.c @@ -132,7 +132,6 @@ rspamd_stat_cache_sqlite3_init (struct rspamd_stat_ctx *ctx, sqlite3 *sqlite; GError *err = NULL; - if (cf) { elt = ucl_object_find_key (cf, "path"); @@ -170,69 +169,30 @@ rspamd_stat_cache_sqlite3_init (struct rspamd_stat_ctx *ctx, return new; } -static rspamd_learn_t -rspamd_stat_cache_sqlite3_check (rspamd_mempool_t *pool, - const guchar *h, gsize len, gboolean is_spam, - struct rspamd_stat_sqlite3_ctx *ctx) +gpointer +rspamd_stat_cache_sqlite3_runtime (struct rspamd_task *task, + gpointer ctx) { - gint rc, ret = RSPAMD_LEARN_OK; - gint64 flag; - - rspamd_sqlite3_run_prstmt (pool, ctx->db, ctx->prstmt, - RSPAMD_STAT_CACHE_TRANSACTION_START_DEF); - rc = rspamd_sqlite3_run_prstmt (pool, ctx->db, ctx->prstmt, - RSPAMD_STAT_CACHE_GET_LEARN, (gint64)len, h, &flag); - rspamd_sqlite3_run_prstmt (pool, ctx->db, ctx->prstmt, - RSPAMD_STAT_CACHE_TRANSACTION_COMMIT); - - - if (rc == SQLITE_OK) { - /* We have some existing record in the table */ - if (!!flag == !!is_spam) { - /* Already learned */ - - ret = RSPAMD_LEARN_INGORE; - } - else { - /* Need to relearn */ - flag = !!is_spam ? 1 : 0; - - rspamd_sqlite3_run_prstmt (pool, ctx->db, ctx->prstmt, - RSPAMD_STAT_CACHE_TRANSACTION_START_IM); - rc = rspamd_sqlite3_run_prstmt (pool, ctx->db, ctx->prstmt, - RSPAMD_STAT_CACHE_UPDATE_LEARN, flag, (gint64)len, h); - rspamd_sqlite3_run_prstmt (pool, ctx->db, ctx->prstmt, - RSPAMD_STAT_CACHE_TRANSACTION_COMMIT); - - return RSPAMD_LEARN_UNLEARN; - } - } - else { - /* Insert result new id */ - flag = !!is_spam ? 1 : 0; - rspamd_sqlite3_run_prstmt (pool, ctx->db, ctx->prstmt, - RSPAMD_STAT_CACHE_TRANSACTION_START_IM); - rspamd_sqlite3_run_prstmt (pool, ctx->db, ctx->prstmt, - RSPAMD_STAT_CACHE_ADD_LEARN, (gint64)len, h, flag); - rspamd_sqlite3_run_prstmt (pool, ctx->db, ctx->prstmt, - RSPAMD_STAT_CACHE_TRANSACTION_COMMIT); - } - - return ret; + /* No need of runtime for this type of classifier */ + return NULL; } gint -rspamd_stat_cache_sqlite3_process (struct rspamd_task *task, +rspamd_stat_cache_sqlite3_check (struct rspamd_task *task, gboolean is_spam, gpointer c) { struct rspamd_stat_sqlite3_ctx *ctx = (struct rspamd_stat_sqlite3_ctx *)c; struct mime_text_part *part; rspamd_cryptobox_hash_state_t st; rspamd_ftok_t *word; - guchar out[rspamd_cryptobox_HASHBYTES]; + guchar *out; guint i, j; + gint rc; + gint64 flag; if (ctx != NULL && ctx->db != NULL) { + out = rspamd_mempool_alloc (task->task_pool, rspamd_cryptobox_HASHBYTES); + rspamd_cryptobox_hash_init (&st, NULL, 0); for (i = 0; i < task->text_parts->len; i ++) { @@ -248,8 +208,67 @@ rspamd_stat_cache_sqlite3_process (struct rspamd_task *task, rspamd_cryptobox_hash_final (&st, out); - return rspamd_stat_cache_sqlite3_check (task->task_pool, - out, sizeof (out), is_spam, ctx); + rspamd_sqlite3_run_prstmt (task->task_pool, ctx->db, ctx->prstmt, + RSPAMD_STAT_CACHE_TRANSACTION_START_DEF); + rc = rspamd_sqlite3_run_prstmt (task->task_pool, ctx->db, ctx->prstmt, + RSPAMD_STAT_CACHE_GET_LEARN, (gint64)rspamd_cryptobox_HASHBYTES, + out, &flag); + rspamd_sqlite3_run_prstmt (task->task_pool, ctx->db, ctx->prstmt, + RSPAMD_STAT_CACHE_TRANSACTION_COMMIT); + + /* Save hash into variables */ + rspamd_mempool_set_variable (task->task_pool, "words_hash", out, NULL); + + if (rc == SQLITE_OK) { + /* We have some existing record in the table */ + if (!!flag == !!is_spam) { + /* Already learned */ + return RSPAMD_LEARN_INGORE; + } + else { + /* Need to relearn */ + return RSPAMD_LEARN_UNLEARN; + } + } + else { + + } + } + + return RSPAMD_LEARN_OK; +} + +gint +rspamd_stat_cache_sqlite3_learn (struct rspamd_task *task, + gboolean is_spam, gpointer c) +{ + struct rspamd_stat_sqlite3_ctx *ctx = (struct rspamd_stat_sqlite3_ctx *)c; + gboolean unlearn = !!(task->flags & RSPAMD_TASK_FLAG_UNLEARN); + guchar *h; + gint64 flag; + + h = rspamd_mempool_get_variable (task->task_pool, "words_hash"); + g_assert (h != NULL); + + if (!unlearn) { + /* Insert result new id */ + flag = !!is_spam ? 1 : 0; + rspamd_sqlite3_run_prstmt (task->task_pool, ctx->db, ctx->prstmt, + RSPAMD_STAT_CACHE_TRANSACTION_START_IM); + rspamd_sqlite3_run_prstmt (task->task_pool, ctx->db, ctx->prstmt, + RSPAMD_STAT_CACHE_ADD_LEARN, + (gint64)rspamd_cryptobox_HASHBYTES, h, flag); + rspamd_sqlite3_run_prstmt (task->task_pool, ctx->db, ctx->prstmt, + RSPAMD_STAT_CACHE_TRANSACTION_COMMIT); + } + else { + rspamd_sqlite3_run_prstmt (task->task_pool, ctx->db, ctx->prstmt, + RSPAMD_STAT_CACHE_TRANSACTION_START_IM); + rspamd_sqlite3_run_prstmt (task->task_pool, ctx->db, ctx->prstmt, + RSPAMD_STAT_CACHE_UPDATE_LEARN, task->task_pool, + (gint64)rspamd_cryptobox_HASHBYTES, h); + rspamd_sqlite3_run_prstmt (task->task_pool, ctx->db, ctx->prstmt, + RSPAMD_STAT_CACHE_TRANSACTION_COMMIT); } return RSPAMD_LEARN_OK; |