diff options
Diffstat (limited to 'src/libstat/stat_process.c')
-rw-r--r-- | src/libstat/stat_process.c | 36 |
1 files changed, 26 insertions, 10 deletions
diff --git a/src/libstat/stat_process.c b/src/libstat/stat_process.c index 8b2a19429..1ce439c51 100644 --- a/src/libstat/stat_process.c +++ b/src/libstat/stat_process.c @@ -29,12 +29,17 @@ #include "lua/lua_common.h" #include <utlist.h> +#define RSPAMD_CLASSIFY_OP 0 +#define RSPAMD_LEARN_OP 1 +#define RSPAMD_UNLEARN_OP 2 + struct preprocess_cb_data { struct rspamd_task *task; GList *classifier_runtimes; struct rspamd_tokenizer_runtime *tok; guint results_count; gboolean unlearn; + gboolean spam; }; static struct rspamd_tokenizer_runtime * @@ -135,7 +140,7 @@ preprocess_init_stat_token (gpointer k, gpointer v, gpointer d) static GList* rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx, struct rspamd_task *task, struct rspamd_tokenizer_runtime *tklist, - lua_State *L, gboolean learn, gboolean spam, GError **err) + lua_State *L, gint op, gboolean spam, GError **err) { struct rspamd_classifier_config *clcf; struct rspamd_statfile_config *stcf; @@ -186,7 +191,7 @@ rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx, stcf = (struct rspamd_statfile_config *)curst->data; /* On learning skip statfiles that do not belong to class */ - if (learn && (spam != stcf->is_spam)) { + if (op == RSPAMD_LEARN_OP && (spam != stcf->is_spam)) { curst = g_list_next (curst); continue; } @@ -199,7 +204,8 @@ rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx, continue; } - backend_runtime = bk->runtime (stcf, learn, bk->ctx); + backend_runtime = bk->runtime (stcf, op != RSPAMD_CLASSIFY_OP, + bk->ctx); st_runtime = rspamd_mempool_alloc0 (task->task_pool, sizeof (*st_runtime)); @@ -354,7 +360,7 @@ rspamd_stat_classify (struct rspamd_task *task, lua_State *L, GError **err) /* Initialize classifiers and statfiles runtime */ if ((cl_runtimes = rspamd_stat_preprocess (st_ctx, task, tklist, L, - FALSE, FALSE, err)) == NULL) { + RSPAMD_CLASSIFY_OP, FALSE, err)) == NULL) { return RSPAMD_STAT_PROCESS_ERROR; } @@ -407,11 +413,12 @@ rspamd_stat_learn_token (gpointer k, gpointer v, gpointer d) continue; } - res = &g_array_index (t->results, struct rspamd_token_result, i); - curst = res->cl_runtime->st_runtime; + + curst = cl_runtime->st_runtime; while (curst) { + res = &g_array_index (t->results, struct rspamd_token_result, i); st_runtime = (struct rspamd_statfile_runtime *)curst->data; if (st_runtime->backend->learn_token (t, res, @@ -432,6 +439,7 @@ rspamd_stat_learn_token (gpointer k, gpointer v, gpointer d) i ++; curst = g_list_next (curst); } + cur = g_list_next (cur); } @@ -507,7 +515,7 @@ rspamd_stat_learn (struct rspamd_task *task, gboolean spam, lua_State *L, /* Initialize classifiers and statfiles runtime */ if ((cl_runtimes = rspamd_stat_preprocess (st_ctx, task, tklist, L, - TRUE, spam, err)) == NULL) { + unlearn ? RSPAMD_UNLEARN_OP : RSPAMD_LEARN_OP, spam, err)) == NULL) { return RSPAMD_STAT_PROCESS_ERROR; } @@ -530,6 +538,7 @@ rspamd_stat_learn (struct rspamd_task *task, gboolean spam, lua_State *L, cbdata.task = task; cbdata.tok = cl_run->tok; cbdata.unlearn = unlearn; + cbdata.spam = spam; g_tree_foreach (cl_run->tok->tokens, rspamd_stat_learn_token, &cbdata); @@ -538,11 +547,18 @@ rspamd_stat_learn (struct rspamd_task *task, gboolean spam, lua_State *L, while (curst) { st_run = (struct rspamd_statfile_runtime *)curst->data; - nrev = st_run->backend->inc_learns (st_run->backend_runtime, + if (unlearn && spam != st_run->st->is_spam) { + nrev = st_run->backend->dec_learns (st_run->backend_runtime, + st_run->backend->ctx); + msg_debug ("unlearned %s, new revision: %ul", + st_run->st->symbol, nrev); + } + else { + nrev = st_run->backend->inc_learns (st_run->backend_runtime, st_run->backend->ctx); - - msg_debug ("learned %s, new revision: %ul", + msg_debug ("learned %s, new revision: %ul", st_run->st->symbol, nrev); + } curst = g_list_next (curst); } |