]> source.dussan.org Git - rspamd.git/commitdiff
Fixing learning.
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 26 Jan 2015 22:20:29 +0000 (22:20 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 26 Jan 2015 22:20:29 +0000 (22:20 +0000)
src/libstat/backends/backends.h
src/libstat/backends/mmaped_file.c
src/libstat/stat_process.c

index 6a1557a8d779ecde1df5981c13a5dd162c52c432..d174756df0d9766c76b474df18898d6f7d878bf0 100644 (file)
@@ -42,18 +42,24 @@ struct token_node_s;
 struct rspamd_stat_backend {
        const char *name;
        gpointer (*init)(struct rspamd_stat_ctx *ctx, struct rspamd_config *cfg);
-       gpointer (*runtime)(struct rspamd_statfile_config *stcf, gpointer ctx);
+       gpointer (*runtime)(struct rspamd_statfile_config *stcf, gboolean learn, gpointer ctx);
        gboolean (*process_token)(struct token_node_s *tok,
                        struct rspamd_token_result *res, gpointer ctx);
+       gboolean (*learn_token)(struct token_node_s *tok,
+                       struct rspamd_token_result *res, gpointer ctx);
        gulong (*total_learns)(struct rspamd_statfile_runtime *runtime, gpointer ctx);
        gpointer ctx;
 };
 
 gpointer rspamd_mmaped_file_init(struct rspamd_stat_ctx *ctx, struct rspamd_config *cfg);
-gpointer rspamd_mmaped_file_runtime (struct rspamd_statfile_config *stcf, gpointer ctx);
+gpointer rspamd_mmaped_file_runtime (struct rspamd_statfile_config *stcf,
+               gboolean learn, gpointer ctx);
 gboolean rspamd_mmaped_file_process_token (struct token_node_s *tok,
                struct rspamd_token_result *res,
                gpointer ctx);
+gboolean rspamd_mmaped_file_learn_token (struct token_node_s *tok,
+               struct rspamd_token_result *res,
+               gpointer ctx);
 gulong rspamd_mmaped_file_total_learns (struct rspamd_statfile_runtime *runtime,
                gpointer ctx);
 
index 8da56ad874e1ad895436989c8b02fd01744c824a..92796558689e05483c0e7a18e3e247e200386937 100644 (file)
@@ -251,7 +251,6 @@ rspamd_mmaped_file_set_block (rspamd_mmaped_file_ctx * pool,
        rspamd_mmaped_file_t * file,
        guint32 h1,
        guint32 h2,
-       time_t now,
        double value)
 {
        rspamd_mmaped_file_set_block_common (pool, file, h1, h2, value);
@@ -856,15 +855,42 @@ rspamd_mmaped_file_init (struct rspamd_stat_ctx *ctx, struct rspamd_config *cfg)
 }
 
 gpointer
-rspamd_mmaped_file_runtime (struct rspamd_statfile_config *stcf, gpointer p)
+rspamd_mmaped_file_runtime (struct rspamd_statfile_config *stcf, gboolean learn,
+               gpointer p)
 {
        rspamd_mmaped_file_ctx *ctx = (rspamd_mmaped_file_ctx *)p;
        rspamd_mmaped_file_t *mf;
+       const ucl_object_t *filenameo, *sizeo;
+       const gchar *filename;
+       gsize size;
 
        g_assert (ctx != NULL);
 
        mf = rspamd_mmaped_file_is_open (ctx, stcf);
 
+       if (mf == NULL && learn) {
+               /* Create file here */
+
+               filenameo = ucl_object_find_key (stcf->opts, "filename");
+               if (filenameo == NULL || ucl_object_type (filenameo) != UCL_STRING) {
+                       msg_err ("statfile %s has no filename defined", stcf->symbol);
+                       return NULL;
+               }
+
+               filename = ucl_object_tostring (filenameo);
+
+               sizeo = ucl_object_find_key (stcf->opts, "size");
+               if (sizeo == NULL || ucl_object_type (sizeo) != UCL_INT) {
+                       msg_err ("statfile %s has no size defined", stcf->symbol);
+                       return NULL;
+               }
+
+               size = ucl_object_toint (sizeo);
+               rspamd_mmaped_file_create (ctx, filename, size, stcf);
+
+               mf = rspamd_mmaped_file_open (ctx, filename, size, stcf);
+       }
+
        return (gpointer)mf;
 }
 
@@ -902,6 +928,40 @@ rspamd_mmaped_file_process_token (rspamd_token_t *tok,
        return FALSE;
 }
 
+gboolean
+rspamd_mmaped_file_learn_token (rspamd_token_t *tok,
+               struct rspamd_token_result *res,
+               gpointer p)
+{
+       rspamd_mmaped_file_ctx *ctx = (rspamd_mmaped_file_ctx *)p;
+       rspamd_mmaped_file_t *mf;
+       guint32 h1, h2;
+
+       g_assert (res != NULL);
+       g_assert (p != NULL);
+       g_assert (res->st_runtime != NULL);
+       g_assert (tok != NULL);
+       g_assert (tok->datalen >= sizeof (guint32) * 2);
+
+       mf = (rspamd_mmaped_file_t *)res->st_runtime->backend_runtime;
+
+       if (mf == NULL) {
+               /* Statfile is does not exist, so all values are zero */
+               res->value = 0.0;
+               return FALSE;
+       }
+
+       memcpy (&h1, tok->data, sizeof (h1));
+       memcpy (&h2, tok->data + sizeof (h1), sizeof (h2));
+       rspamd_mmaped_file_set_block (ctx, mf, h1, h2, res->value);
+
+       if (res->value > 0.0) {
+               return TRUE;
+       }
+
+       return FALSE;
+}
+
 gulong
 rspamd_mmaped_file_total_learns (struct rspamd_statfile_runtime *runtime,
                gpointer ctx)
index e00c0aaa82341692d770e097b75a9f8b6797c5dc..17c8059a9c605f186c271e000ff95b62763da13b 100644 (file)
@@ -135,7 +135,7 @@ preprocess_init_stat_token (gpointer k, gpointer v, gpointer d)
 static GList*
 rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx,
                struct rspamd_task *task, struct rspamd_tokenizer_runtime *tklist,
-               lua_State *L, GError **err)
+               lua_State *L, gboolean learn, GError **err)
 {
        struct rspamd_classifier_config *clcf;
        struct rspamd_statfile_config *stcf;
@@ -193,7 +193,7 @@ rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx,
                                continue;
                        }
 
-                       backend_runtime = bk->runtime (stcf, bk->ctx);
+                       backend_runtime = bk->runtime (stcf, learn, bk->ctx);
 
                        st_runtime = rspamd_mempool_alloc0 (task->task_pool,
                                        sizeof (*st_runtime));
@@ -243,8 +243,6 @@ rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx,
                cbdata.results_count = result_size;
                cbdata.classifier_runtimes = cl_runtimes;
                cbdata.task = task;
-
-               /* Allocate token results */
                cbdata.tok = cl_runtime->tok;
                g_tree_foreach (cl_runtime->tok->tokens, preprocess_init_stat_token,
                                &cbdata);
@@ -346,7 +344,7 @@ rspamd_stat_classify (struct rspamd_task *task, lua_State *L, GError **err)
        }
 
        /* Initialize classifiers and statfiles runtime */
-       if ((cl_runtimes = rspamd_stat_preprocess (st_ctx, task, tklist, L, err))
+       if ((cl_runtimes = rspamd_stat_preprocess (st_ctx, task, tklist, L, FALSE, err))
                        == NULL) {
                return FALSE;
        }
@@ -371,6 +369,65 @@ rspamd_stat_classify (struct rspamd_task *task, lua_State *L, GError **err)
        return ret;
 }
 
+static gboolean
+rspamd_stat_learn_token (gpointer k, gpointer v, gpointer d)
+{
+       rspamd_token_t *t = (rspamd_token_t *)v;
+       struct preprocess_cb_data *cbdata = (struct preprocess_cb_data *)d;
+       struct rspamd_statfile_runtime *st_runtime;
+       struct rspamd_classifier_runtime *cl_runtime;
+       struct rspamd_token_result *res;
+       GList *cur, *curst;
+       gint i = 0;
+
+       cur = g_list_first (cbdata->classifier_runtimes);
+
+       while (cur) {
+               cl_runtime = (struct rspamd_classifier_runtime *)cur->data;
+
+               if (cl_runtime->clcf->min_tokens > 0 &&
+                               (guint32)g_tree_nnodes (cbdata->tok->tokens) < cl_runtime->clcf->min_tokens) {
+                       /* Skip this classifier */
+                       msg_debug ("<%s> contains less tokens than required for %s classifier: "
+                                       "%ud < %ud", cbdata->task->message_id, cl_runtime->clcf->name,
+                                       g_tree_nnodes (cbdata->tok->tokens),
+                                       cl_runtime->clcf->min_tokens);
+                       cur = g_list_next (cur);
+                       continue;
+               }
+
+               res = &g_array_index (t->results, struct rspamd_token_result, i);
+
+               curst = res->cl_runtime->st_runtime;
+
+               while (curst) {
+                       st_runtime = (struct rspamd_statfile_runtime *)curst->data;
+
+                       if (st_runtime->backend->learn_token (t, res,
+                                       st_runtime->backend->ctx)) {
+                               cl_runtime->processed_tokens ++;
+
+                               if (cl_runtime->clcf->max_tokens > 0 &&
+                                               cl_runtime->processed_tokens > cl_runtime->clcf->max_tokens) {
+                                       msg_debug ("<%s> contains more tokens than allowed for %s classifier: "
+                                                       "%ud > %ud", cbdata->task, cl_runtime->clcf->name,
+                                                       cl_runtime->processed_tokens,
+                                                       cl_runtime->clcf->max_tokens);
+
+                                       return TRUE;
+                               }
+                       }
+
+                       i ++;
+                       curst = g_list_next (curst);
+               }
+               cur = g_list_next (cur);
+       }
+
+
+       return FALSE;
+}
+
 gboolean
 rspamd_stat_learn (struct rspamd_task *task, gboolean spam, lua_State *L,
                GError **err)
@@ -381,6 +438,7 @@ rspamd_stat_learn (struct rspamd_task *task, gboolean spam, lua_State *L,
        struct rspamd_tokenizer_runtime *tklist = NULL, *tok;
        struct rspamd_classifier_runtime *cl_run;
        struct classifier_ctx *cl_ctx;
+       struct preprocess_cb_data cbdata;
        GList *cl_runtimes;
        GList *cur;
        gboolean ret = FALSE;
@@ -416,7 +474,7 @@ rspamd_stat_learn (struct rspamd_task *task, gboolean spam, lua_State *L,
        }
 
        /* Initialize classifiers and statfiles runtime */
-       if ((cl_runtimes = rspamd_stat_preprocess (st_ctx, task, tklist, L, err))
+       if ((cl_runtimes = rspamd_stat_preprocess (st_ctx, task, tklist, L, TRUE, err))
                        == NULL) {
                return FALSE;
        }
@@ -430,8 +488,21 @@ rspamd_stat_learn (struct rspamd_task *task, gboolean spam, lua_State *L,
                        cl_ctx = cl_run->cl->init_func (task->task_pool, cl_run->clcf);
 
                        if (cl_ctx != NULL) {
-                               ret |= cl_run->cl->learn_spam_func (cl_ctx, cl_run->tok->tokens,
-                                               cl_run, task, spam, err);
+                               if (cl_run->cl->learn_spam_func (cl_ctx, cl_run->tok->tokens,
+                                               cl_run, task, spam, err)) {
+                                       ret = TRUE;
+
+                                       cbdata.classifier_runtimes = cur;
+                                       cbdata.task = task;
+                                       cbdata.tok = cl_run->tok;
+                                       g_tree_foreach (cl_run->tok->tokens, rspamd_stat_learn_token,
+                                                       &cbdata);
+
+                               }
+                               else {
+                                       return FALSE;
+                               }
+
                        }
                }