diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2015-01-26 22:20:29 +0000 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2015-01-26 22:20:29 +0000 |
commit | f117de8a3c06894d5c10d39abf72611c31bf31a3 (patch) | |
tree | 2ca37497102f6f23fafaa88b01e55a927c08d38c /src/libstat/stat_process.c | |
parent | 5525b96d6c9d3f4c36167f4e5abf48c485ac5a07 (diff) | |
download | rspamd-f117de8a3c06894d5c10d39abf72611c31bf31a3.tar.gz rspamd-f117de8a3c06894d5c10d39abf72611c31bf31a3.zip |
Fixing learning.
Diffstat (limited to 'src/libstat/stat_process.c')
-rw-r--r-- | src/libstat/stat_process.c | 87 |
1 files changed, 79 insertions, 8 deletions
diff --git a/src/libstat/stat_process.c b/src/libstat/stat_process.c index e00c0aaa8..17c8059a9 100644 --- a/src/libstat/stat_process.c +++ b/src/libstat/stat_process.c @@ -135,7 +135,7 @@ preprocess_init_stat_token (gpointer k, gpointer v, gpointer d) static GList* rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx, struct rspamd_task *task, struct rspamd_tokenizer_runtime *tklist, - lua_State *L, GError **err) + lua_State *L, gboolean learn, GError **err) { struct rspamd_classifier_config *clcf; struct rspamd_statfile_config *stcf; @@ -193,7 +193,7 @@ rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx, continue; } - backend_runtime = bk->runtime (stcf, bk->ctx); + backend_runtime = bk->runtime (stcf, learn, bk->ctx); st_runtime = rspamd_mempool_alloc0 (task->task_pool, sizeof (*st_runtime)); @@ -243,8 +243,6 @@ rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx, cbdata.results_count = result_size; cbdata.classifier_runtimes = cl_runtimes; cbdata.task = task; - - /* Allocate token results */ cbdata.tok = cl_runtime->tok; g_tree_foreach (cl_runtime->tok->tokens, preprocess_init_stat_token, &cbdata); @@ -346,7 +344,7 @@ rspamd_stat_classify (struct rspamd_task *task, lua_State *L, GError **err) } /* Initialize classifiers and statfiles runtime */ - if ((cl_runtimes = rspamd_stat_preprocess (st_ctx, task, tklist, L, err)) + if ((cl_runtimes = rspamd_stat_preprocess (st_ctx, task, tklist, L, FALSE, err)) == NULL) { return FALSE; } @@ -371,6 +369,65 @@ rspamd_stat_classify (struct rspamd_task *task, lua_State *L, GError **err) return ret; } +static gboolean +rspamd_stat_learn_token (gpointer k, gpointer v, gpointer d) +{ + rspamd_token_t *t = (rspamd_token_t *)v; + struct preprocess_cb_data *cbdata = (struct preprocess_cb_data *)d; + struct rspamd_statfile_runtime *st_runtime; + struct rspamd_classifier_runtime *cl_runtime; + struct rspamd_token_result *res; + GList *cur, *curst; + gint i = 0; + + cur = g_list_first (cbdata->classifier_runtimes); + + while (cur) { + cl_runtime = (struct rspamd_classifier_runtime *)cur->data; + + if (cl_runtime->clcf->min_tokens > 0 && + (guint32)g_tree_nnodes (cbdata->tok->tokens) < cl_runtime->clcf->min_tokens) { + /* Skip this classifier */ + msg_debug ("<%s> contains less tokens than required for %s classifier: " + "%ud < %ud", cbdata->task->message_id, cl_runtime->clcf->name, + g_tree_nnodes (cbdata->tok->tokens), + cl_runtime->clcf->min_tokens); + cur = g_list_next (cur); + continue; + } + + res = &g_array_index (t->results, struct rspamd_token_result, i); + + curst = res->cl_runtime->st_runtime; + + while (curst) { + st_runtime = (struct rspamd_statfile_runtime *)curst->data; + + if (st_runtime->backend->learn_token (t, res, + st_runtime->backend->ctx)) { + cl_runtime->processed_tokens ++; + + if (cl_runtime->clcf->max_tokens > 0 && + cl_runtime->processed_tokens > cl_runtime->clcf->max_tokens) { + msg_debug ("<%s> contains more tokens than allowed for %s classifier: " + "%ud > %ud", cbdata->task, cl_runtime->clcf->name, + cl_runtime->processed_tokens, + cl_runtime->clcf->max_tokens); + + return TRUE; + } + } + + i ++; + curst = g_list_next (curst); + } + cur = g_list_next (cur); + } + + + return FALSE; +} + gboolean rspamd_stat_learn (struct rspamd_task *task, gboolean spam, lua_State *L, GError **err) @@ -381,6 +438,7 @@ rspamd_stat_learn (struct rspamd_task *task, gboolean spam, lua_State *L, struct rspamd_tokenizer_runtime *tklist = NULL, *tok; struct rspamd_classifier_runtime *cl_run; struct classifier_ctx *cl_ctx; + struct preprocess_cb_data cbdata; GList *cl_runtimes; GList *cur; gboolean ret = FALSE; @@ -416,7 +474,7 @@ rspamd_stat_learn (struct rspamd_task *task, gboolean spam, lua_State *L, } /* Initialize classifiers and statfiles runtime */ - if ((cl_runtimes = rspamd_stat_preprocess (st_ctx, task, tklist, L, err)) + if ((cl_runtimes = rspamd_stat_preprocess (st_ctx, task, tklist, L, TRUE, err)) == NULL) { return FALSE; } @@ -430,8 +488,21 @@ rspamd_stat_learn (struct rspamd_task *task, gboolean spam, lua_State *L, cl_ctx = cl_run->cl->init_func (task->task_pool, cl_run->clcf); if (cl_ctx != NULL) { - ret |= cl_run->cl->learn_spam_func (cl_ctx, cl_run->tok->tokens, - cl_run, task, spam, err); + if (cl_run->cl->learn_spam_func (cl_ctx, cl_run->tok->tokens, + cl_run, task, spam, err)) { + ret = TRUE; + + cbdata.classifier_runtimes = cur; + cbdata.task = task; + cbdata.tok = cl_run->tok; + g_tree_foreach (cl_run->tok->tokens, rspamd_stat_learn_token, + &cbdata); + + } + else { + return FALSE; + } + } } |