diff options
-rw-r--r-- | src/libstat/backends/backends.h | 3 | ||||
-rw-r--r-- | src/libstat/backends/mmaped_file.c | 40 | ||||
-rw-r--r-- | src/libstat/classifiers/bayes.c | 7 | ||||
-rw-r--r-- | src/libstat/stat_config.c | 1 | ||||
-rw-r--r-- | src/libstat/stat_process.c | 36 |
5 files changed, 72 insertions, 15 deletions
diff --git a/src/libstat/backends/backends.h b/src/libstat/backends/backends.h index c7c4210fb..f8a2af72c 100644 --- a/src/libstat/backends/backends.h +++ b/src/libstat/backends/backends.h @@ -49,6 +49,7 @@ struct rspamd_stat_backend { struct rspamd_token_result *res, gpointer ctx); gulong (*total_learns)(struct rspamd_statfile_runtime *runtime, gpointer ctx); gulong (*inc_learns)(struct rspamd_statfile_runtime *runtime, gpointer ctx); + gulong (*dec_learns)(struct rspamd_statfile_runtime *runtime, gpointer ctx); ucl_object_t* (*get_stat)(struct rspamd_statfile_runtime *runtime, gpointer ctx); gpointer ctx; }; @@ -66,6 +67,8 @@ gulong rspamd_mmaped_file_total_learns (struct rspamd_statfile_runtime *runtime, gpointer ctx); gulong rspamd_mmaped_file_inc_learns (struct rspamd_statfile_runtime *runtime, gpointer ctx); +gulong rspamd_mmaped_file_dec_learns (struct rspamd_statfile_runtime *runtime, + gpointer ctx); ucl_object_t * rspamd_mmaped_file_get_stat (struct rspamd_statfile_runtime *runtime, gpointer ctx); diff --git a/src/libstat/backends/mmaped_file.c b/src/libstat/backends/mmaped_file.c index 0fb386f61..02ea17c28 100644 --- a/src/libstat/backends/mmaped_file.c +++ b/src/libstat/backends/mmaped_file.c @@ -291,6 +291,23 @@ rspamd_mmaped_file_inc_revision (rspamd_mmaped_file_t *file) } gboolean +rspamd_mmaped_file_dec_revision (rspamd_mmaped_file_t *file) +{ + struct stat_file_header *header; + + if (file == NULL || file->map == NULL) { + return FALSE; + } + + header = (struct stat_file_header *)file->map; + + header->revision--; + + return TRUE; +} + + +gboolean rspamd_mmaped_file_get_revision (rspamd_mmaped_file_t *file, guint64 *rev, time_t *time) { struct stat_file_header *header; @@ -939,11 +956,7 @@ rspamd_mmaped_file_learn_token (rspamd_token_t *tok, memcpy (&h2, tok->data + sizeof (h1), sizeof (h2)); rspamd_mmaped_file_set_block (ctx, mf, h1, h2, res->value); - if (res->value > 0.0) { - return TRUE; - } - - return FALSE; + return TRUE; } gulong @@ -977,6 +990,23 @@ rspamd_mmaped_file_inc_learns (struct rspamd_statfile_runtime *runtime, return rev; } +gulong +rspamd_mmaped_file_dec_learns (struct rspamd_statfile_runtime *runtime, + gpointer ctx) +{ + rspamd_mmaped_file_t *mf = (rspamd_mmaped_file_t *)runtime; + guint64 rev = 0; + time_t t; + + if (mf != NULL) { + rspamd_mmaped_file_dec_revision (mf); + rspamd_mmaped_file_get_revision (mf, &rev, &t); + } + + return rev; +} + + ucl_object_t * rspamd_mmaped_file_get_stat (struct rspamd_statfile_runtime *runtime, gpointer ctx) diff --git a/src/libstat/classifiers/bayes.c b/src/libstat/classifiers/bayes.c index be6c6f545..7932ceb9e 100644 --- a/src/libstat/classifiers/bayes.c +++ b/src/libstat/classifiers/bayes.c @@ -221,6 +221,10 @@ bayes_learn_spam_callback (gpointer key, gpointer value, gpointer data) if (res->st_runtime->st->is_spam) { res->value ++; } + else if (res->value > 0) { + /* Unlearning */ + res->value --; + } } return FALSE; @@ -241,6 +245,9 @@ bayes_learn_ham_callback (gpointer key, gpointer value, gpointer data) if (!res->st_runtime->st->is_spam) { res->value ++; } + else if (res->value > 0) { + res->value --; + } } return FALSE; diff --git a/src/libstat/stat_config.c b/src/libstat/stat_config.c index b8ad6ec30..17b5c54f5 100644 --- a/src/libstat/stat_config.c +++ b/src/libstat/stat_config.c @@ -53,6 +53,7 @@ static struct rspamd_stat_backend stat_backends[] = { .learn_token = rspamd_mmaped_file_learn_token, .total_learns = rspamd_mmaped_file_total_learns, .inc_learns = rspamd_mmaped_file_inc_learns, + .dec_learns = rspamd_mmaped_file_dec_learns, .get_stat = rspamd_mmaped_file_get_stat } }; diff --git a/src/libstat/stat_process.c b/src/libstat/stat_process.c index 8b2a19429..1ce439c51 100644 --- a/src/libstat/stat_process.c +++ b/src/libstat/stat_process.c @@ -29,12 +29,17 @@ #include "lua/lua_common.h" #include <utlist.h> +#define RSPAMD_CLASSIFY_OP 0 +#define RSPAMD_LEARN_OP 1 +#define RSPAMD_UNLEARN_OP 2 + struct preprocess_cb_data { struct rspamd_task *task; GList *classifier_runtimes; struct rspamd_tokenizer_runtime *tok; guint results_count; gboolean unlearn; + gboolean spam; }; static struct rspamd_tokenizer_runtime * @@ -135,7 +140,7 @@ preprocess_init_stat_token (gpointer k, gpointer v, gpointer d) static GList* rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx, struct rspamd_task *task, struct rspamd_tokenizer_runtime *tklist, - lua_State *L, gboolean learn, gboolean spam, GError **err) + lua_State *L, gint op, gboolean spam, GError **err) { struct rspamd_classifier_config *clcf; struct rspamd_statfile_config *stcf; @@ -186,7 +191,7 @@ rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx, stcf = (struct rspamd_statfile_config *)curst->data; /* On learning skip statfiles that do not belong to class */ - if (learn && (spam != stcf->is_spam)) { + if (op == RSPAMD_LEARN_OP && (spam != stcf->is_spam)) { curst = g_list_next (curst); continue; } @@ -199,7 +204,8 @@ rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx, continue; } - backend_runtime = bk->runtime (stcf, learn, bk->ctx); + backend_runtime = bk->runtime (stcf, op != RSPAMD_CLASSIFY_OP, + bk->ctx); st_runtime = rspamd_mempool_alloc0 (task->task_pool, sizeof (*st_runtime)); @@ -354,7 +360,7 @@ rspamd_stat_classify (struct rspamd_task *task, lua_State *L, GError **err) /* Initialize classifiers and statfiles runtime */ if ((cl_runtimes = rspamd_stat_preprocess (st_ctx, task, tklist, L, - FALSE, FALSE, err)) == NULL) { + RSPAMD_CLASSIFY_OP, FALSE, err)) == NULL) { return RSPAMD_STAT_PROCESS_ERROR; } @@ -407,11 +413,12 @@ rspamd_stat_learn_token (gpointer k, gpointer v, gpointer d) continue; } - res = &g_array_index (t->results, struct rspamd_token_result, i); - curst = res->cl_runtime->st_runtime; + + curst = cl_runtime->st_runtime; while (curst) { + res = &g_array_index (t->results, struct rspamd_token_result, i); st_runtime = (struct rspamd_statfile_runtime *)curst->data; if (st_runtime->backend->learn_token (t, res, @@ -432,6 +439,7 @@ rspamd_stat_learn_token (gpointer k, gpointer v, gpointer d) i ++; curst = g_list_next (curst); } + cur = g_list_next (cur); } @@ -507,7 +515,7 @@ rspamd_stat_learn (struct rspamd_task *task, gboolean spam, lua_State *L, /* Initialize classifiers and statfiles runtime */ if ((cl_runtimes = rspamd_stat_preprocess (st_ctx, task, tklist, L, - TRUE, spam, err)) == NULL) { + unlearn ? RSPAMD_UNLEARN_OP : RSPAMD_LEARN_OP, spam, err)) == NULL) { return RSPAMD_STAT_PROCESS_ERROR; } @@ -530,6 +538,7 @@ rspamd_stat_learn (struct rspamd_task *task, gboolean spam, lua_State *L, cbdata.task = task; cbdata.tok = cl_run->tok; cbdata.unlearn = unlearn; + cbdata.spam = spam; g_tree_foreach (cl_run->tok->tokens, rspamd_stat_learn_token, &cbdata); @@ -538,11 +547,18 @@ rspamd_stat_learn (struct rspamd_task *task, gboolean spam, lua_State *L, while (curst) { st_run = (struct rspamd_statfile_runtime *)curst->data; - nrev = st_run->backend->inc_learns (st_run->backend_runtime, + if (unlearn && spam != st_run->st->is_spam) { + nrev = st_run->backend->dec_learns (st_run->backend_runtime, + st_run->backend->ctx); + msg_debug ("unlearned %s, new revision: %ul", + st_run->st->symbol, nrev); + } + else { + nrev = st_run->backend->inc_learns (st_run->backend_runtime, st_run->backend->ctx); - - msg_debug ("learned %s, new revision: %ul", + msg_debug ("learned %s, new revision: %ul", st_run->st->symbol, nrev); + } curst = g_list_next (curst); } |