From: Vsevolod Stakhov Date: Mon, 26 Jan 2015 22:20:29 +0000 (+0000) Subject: Fixing learning. X-Git-Tag: 0.9.0~825 X-Git-Url: https://source.dussan.org/?a=commitdiff_plain;h=f117de8a3c06894d5c10d39abf72611c31bf31a3;p=rspamd.git Fixing learning. --- diff --git a/src/libstat/backends/backends.h b/src/libstat/backends/backends.h index 6a1557a8d..d174756df 100644 --- a/src/libstat/backends/backends.h +++ b/src/libstat/backends/backends.h @@ -42,18 +42,24 @@ struct token_node_s; struct rspamd_stat_backend { const char *name; gpointer (*init)(struct rspamd_stat_ctx *ctx, struct rspamd_config *cfg); - gpointer (*runtime)(struct rspamd_statfile_config *stcf, gpointer ctx); + gpointer (*runtime)(struct rspamd_statfile_config *stcf, gboolean learn, gpointer ctx); gboolean (*process_token)(struct token_node_s *tok, struct rspamd_token_result *res, gpointer ctx); + gboolean (*learn_token)(struct token_node_s *tok, + struct rspamd_token_result *res, gpointer ctx); gulong (*total_learns)(struct rspamd_statfile_runtime *runtime, gpointer ctx); gpointer ctx; }; gpointer rspamd_mmaped_file_init(struct rspamd_stat_ctx *ctx, struct rspamd_config *cfg); -gpointer rspamd_mmaped_file_runtime (struct rspamd_statfile_config *stcf, gpointer ctx); +gpointer rspamd_mmaped_file_runtime (struct rspamd_statfile_config *stcf, + gboolean learn, gpointer ctx); gboolean rspamd_mmaped_file_process_token (struct token_node_s *tok, struct rspamd_token_result *res, gpointer ctx); +gboolean rspamd_mmaped_file_learn_token (struct token_node_s *tok, + struct rspamd_token_result *res, + gpointer ctx); gulong rspamd_mmaped_file_total_learns (struct rspamd_statfile_runtime *runtime, gpointer ctx); diff --git a/src/libstat/backends/mmaped_file.c b/src/libstat/backends/mmaped_file.c index 8da56ad87..927965586 100644 --- a/src/libstat/backends/mmaped_file.c +++ b/src/libstat/backends/mmaped_file.c @@ -251,7 +251,6 @@ rspamd_mmaped_file_set_block (rspamd_mmaped_file_ctx * pool, rspamd_mmaped_file_t * file, guint32 h1, guint32 h2, - time_t now, double value) { rspamd_mmaped_file_set_block_common (pool, file, h1, h2, value); @@ -856,15 +855,42 @@ rspamd_mmaped_file_init (struct rspamd_stat_ctx *ctx, struct rspamd_config *cfg) } gpointer -rspamd_mmaped_file_runtime (struct rspamd_statfile_config *stcf, gpointer p) +rspamd_mmaped_file_runtime (struct rspamd_statfile_config *stcf, gboolean learn, + gpointer p) { rspamd_mmaped_file_ctx *ctx = (rspamd_mmaped_file_ctx *)p; rspamd_mmaped_file_t *mf; + const ucl_object_t *filenameo, *sizeo; + const gchar *filename; + gsize size; g_assert (ctx != NULL); mf = rspamd_mmaped_file_is_open (ctx, stcf); + if (mf == NULL && learn) { + /* Create file here */ + + filenameo = ucl_object_find_key (stcf->opts, "filename"); + if (filenameo == NULL || ucl_object_type (filenameo) != UCL_STRING) { + msg_err ("statfile %s has no filename defined", stcf->symbol); + return NULL; + } + + filename = ucl_object_tostring (filenameo); + + sizeo = ucl_object_find_key (stcf->opts, "size"); + if (sizeo == NULL || ucl_object_type (sizeo) != UCL_INT) { + msg_err ("statfile %s has no size defined", stcf->symbol); + return NULL; + } + + size = ucl_object_toint (sizeo); + rspamd_mmaped_file_create (ctx, filename, size, stcf); + + mf = rspamd_mmaped_file_open (ctx, filename, size, stcf); + } + return (gpointer)mf; } @@ -902,6 +928,40 @@ rspamd_mmaped_file_process_token (rspamd_token_t *tok, return FALSE; } +gboolean +rspamd_mmaped_file_learn_token (rspamd_token_t *tok, + struct rspamd_token_result *res, + gpointer p) +{ + rspamd_mmaped_file_ctx *ctx = (rspamd_mmaped_file_ctx *)p; + rspamd_mmaped_file_t *mf; + guint32 h1, h2; + + g_assert (res != NULL); + g_assert (p != NULL); + g_assert (res->st_runtime != NULL); + g_assert (tok != NULL); + g_assert (tok->datalen >= sizeof (guint32) * 2); + + mf = (rspamd_mmaped_file_t *)res->st_runtime->backend_runtime; + + if (mf == NULL) { + /* Statfile is does not exist, so all values are zero */ + res->value = 0.0; + return FALSE; + } + + memcpy (&h1, tok->data, sizeof (h1)); + memcpy (&h2, tok->data + sizeof (h1), sizeof (h2)); + rspamd_mmaped_file_set_block (ctx, mf, h1, h2, res->value); + + if (res->value > 0.0) { + return TRUE; + } + + return FALSE; +} + gulong rspamd_mmaped_file_total_learns (struct rspamd_statfile_runtime *runtime, gpointer ctx) diff --git a/src/libstat/stat_process.c b/src/libstat/stat_process.c index e00c0aaa8..17c8059a9 100644 --- a/src/libstat/stat_process.c +++ b/src/libstat/stat_process.c @@ -135,7 +135,7 @@ preprocess_init_stat_token (gpointer k, gpointer v, gpointer d) static GList* rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx, struct rspamd_task *task, struct rspamd_tokenizer_runtime *tklist, - lua_State *L, GError **err) + lua_State *L, gboolean learn, GError **err) { struct rspamd_classifier_config *clcf; struct rspamd_statfile_config *stcf; @@ -193,7 +193,7 @@ rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx, continue; } - backend_runtime = bk->runtime (stcf, bk->ctx); + backend_runtime = bk->runtime (stcf, learn, bk->ctx); st_runtime = rspamd_mempool_alloc0 (task->task_pool, sizeof (*st_runtime)); @@ -243,8 +243,6 @@ rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx, cbdata.results_count = result_size; cbdata.classifier_runtimes = cl_runtimes; cbdata.task = task; - - /* Allocate token results */ cbdata.tok = cl_runtime->tok; g_tree_foreach (cl_runtime->tok->tokens, preprocess_init_stat_token, &cbdata); @@ -346,7 +344,7 @@ rspamd_stat_classify (struct rspamd_task *task, lua_State *L, GError **err) } /* Initialize classifiers and statfiles runtime */ - if ((cl_runtimes = rspamd_stat_preprocess (st_ctx, task, tklist, L, err)) + if ((cl_runtimes = rspamd_stat_preprocess (st_ctx, task, tklist, L, FALSE, err)) == NULL) { return FALSE; } @@ -371,6 +369,65 @@ rspamd_stat_classify (struct rspamd_task *task, lua_State *L, GError **err) return ret; } +static gboolean +rspamd_stat_learn_token (gpointer k, gpointer v, gpointer d) +{ + rspamd_token_t *t = (rspamd_token_t *)v; + struct preprocess_cb_data *cbdata = (struct preprocess_cb_data *)d; + struct rspamd_statfile_runtime *st_runtime; + struct rspamd_classifier_runtime *cl_runtime; + struct rspamd_token_result *res; + GList *cur, *curst; + gint i = 0; + + cur = g_list_first (cbdata->classifier_runtimes); + + while (cur) { + cl_runtime = (struct rspamd_classifier_runtime *)cur->data; + + if (cl_runtime->clcf->min_tokens > 0 && + (guint32)g_tree_nnodes (cbdata->tok->tokens) < cl_runtime->clcf->min_tokens) { + /* Skip this classifier */ + msg_debug ("<%s> contains less tokens than required for %s classifier: " + "%ud < %ud", cbdata->task->message_id, cl_runtime->clcf->name, + g_tree_nnodes (cbdata->tok->tokens), + cl_runtime->clcf->min_tokens); + cur = g_list_next (cur); + continue; + } + + res = &g_array_index (t->results, struct rspamd_token_result, i); + + curst = res->cl_runtime->st_runtime; + + while (curst) { + st_runtime = (struct rspamd_statfile_runtime *)curst->data; + + if (st_runtime->backend->learn_token (t, res, + st_runtime->backend->ctx)) { + cl_runtime->processed_tokens ++; + + if (cl_runtime->clcf->max_tokens > 0 && + cl_runtime->processed_tokens > cl_runtime->clcf->max_tokens) { + msg_debug ("<%s> contains more tokens than allowed for %s classifier: " + "%ud > %ud", cbdata->task, cl_runtime->clcf->name, + cl_runtime->processed_tokens, + cl_runtime->clcf->max_tokens); + + return TRUE; + } + } + + i ++; + curst = g_list_next (curst); + } + cur = g_list_next (cur); + } + + + return FALSE; +} + gboolean rspamd_stat_learn (struct rspamd_task *task, gboolean spam, lua_State *L, GError **err) @@ -381,6 +438,7 @@ rspamd_stat_learn (struct rspamd_task *task, gboolean spam, lua_State *L, struct rspamd_tokenizer_runtime *tklist = NULL, *tok; struct rspamd_classifier_runtime *cl_run; struct classifier_ctx *cl_ctx; + struct preprocess_cb_data cbdata; GList *cl_runtimes; GList *cur; gboolean ret = FALSE; @@ -416,7 +474,7 @@ rspamd_stat_learn (struct rspamd_task *task, gboolean spam, lua_State *L, } /* Initialize classifiers and statfiles runtime */ - if ((cl_runtimes = rspamd_stat_preprocess (st_ctx, task, tklist, L, err)) + if ((cl_runtimes = rspamd_stat_preprocess (st_ctx, task, tklist, L, TRUE, err)) == NULL) { return FALSE; } @@ -430,8 +488,21 @@ rspamd_stat_learn (struct rspamd_task *task, gboolean spam, lua_State *L, cl_ctx = cl_run->cl->init_func (task->task_pool, cl_run->clcf); if (cl_ctx != NULL) { - ret |= cl_run->cl->learn_spam_func (cl_ctx, cl_run->tok->tokens, - cl_run, task, spam, err); + if (cl_run->cl->learn_spam_func (cl_ctx, cl_run->tok->tokens, + cl_run, task, spam, err)) { + ret = TRUE; + + cbdata.classifier_runtimes = cur; + cbdata.task = task; + cbdata.tok = cl_run->tok; + g_tree_foreach (cl_run->tok->tokens, rspamd_stat_learn_token, + &cbdata); + + } + else { + return FALSE; + } + } }