]> source.dussan.org Git - rspamd.git/commitdiff
Update bayes.
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 26 Jan 2015 14:07:19 +0000 (14:07 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 26 Jan 2015 14:07:19 +0000 (14:07 +0000)
src/libstat/classifiers/bayes.c
src/libstat/classifiers/classifiers.h

index 6e068b79d430c245093cf177149e677659694ac3..54d696721e8b91ca41c9b013468e0945a6327cec 100644 (file)
@@ -39,63 +39,6 @@ bayes_error_quark (void)
        return g_quark_from_static_string ("bayes-error");
 }
 
-struct bayes_statfile_data {
-       guint64 hits;
-       guint64 total_hits;
-       double value;
-       struct rspamd_statfile_config *st;
-};
-
-struct bayes_callback_data {
-       struct classifier_ctx *ctx;
-       gboolean in_class;
-       time_t now;
-       struct bayes_statfile_data *statfiles;
-       guint32 statfiles_num;
-       guint64 total_spam;
-       guint64 total_ham;
-       guint64 processed_tokens;
-       gsize max_tokens;
-       double spam_probability;
-       double ham_probability;
-};
-
-static gboolean
-bayes_learn_callback (gpointer key, gpointer value, gpointer data)
-{
-       rspamd_token_t *node = key;
-       struct bayes_callback_data *cd = data;
-       gint c;
-       guint64 v;
-
-       c = (cd->in_class) ? 1 : -1;
-
-       /* Consider that not found blocks have value 1 */
-       /* XXX implement getting */
-       if (v == 0 && c > 0) {
-               /* XXX: add token to the backend */
-               cd->processed_tokens++;
-       }
-       else if (v != 0) {
-               if (G_LIKELY (c > 0)) {
-                       v++;
-               }
-               else if (c < 0) {
-                       if (v != 0) {
-                               v--;
-                       }
-               }
-               /* XXX: Implement setting */
-               cd->processed_tokens++;
-       }
-
-       if (cd->max_tokens != 0 && cd->processed_tokens > cd->max_tokens) {
-               /* Stop learning on max tokens */
-               return TRUE;
-       }
-       return FALSE;
-}
-
 /**
  * Returns probability of chisquare > value with specified number of freedom
  * degrees
@@ -142,45 +85,35 @@ inv_chi_square (gdouble value, gint freedom_deg)
 static gboolean
 bayes_classify_callback (gpointer key, gpointer value, gpointer data)
 {
-
-       rspamd_token_t *node = key;
-       struct bayes_callback_data *cd = data;
+       rspamd_token_t *node = value;
+       struct rspamd_classifier_runtime *rt = (struct rspamd_classifier_runtime *)data;
        guint i;
-       struct bayes_statfile_data *cur;
+       struct rspamd_token_result *res;
        guint64 spam_count = 0, ham_count = 0, total_count = 0;
        double spam_prob, spam_freq, ham_freq, bayes_spam_prob;
 
-       for (i = 0; i < cd->statfiles_num; i++) {
-               cur = &cd->statfiles[i];
-               /*
-                * XXX: Implement getting
-                */
-               if (cur->value > 0) {
-                       cur->total_hits += cur->value;
-                       if (cur->st->is_spam) {
-                               spam_count += cur->value;
+       for (i = rt->start_pos; i < rt->end_pos; i++) {
+               res = &g_array_index (node->results, struct rspamd_token_result, i);
+
+               if (res->value > 0) {
+                       if (res->st_runtime->st->is_spam) {
+                               spam_count += res->value;
                        }
                        else {
-                               ham_count += cur->value;
+                               ham_count += res->value;
                        }
-                       total_count += cur->value;
+                       total_count += res->value;
                }
        }
 
        /* Probability for this token */
        if (total_count > 0) {
-               spam_freq = ((double)spam_count / MAX (1., (double)cd->total_spam));
-               ham_freq = ((double)ham_count / MAX (1., (double)cd->total_ham));
+               spam_freq = ((double)spam_count / MAX (1., (double)rt->total_spam));
+               ham_freq = ((double)ham_count / MAX (1., (double)rt->total_ham));
                spam_prob = spam_freq / (spam_freq + ham_freq);
                bayes_spam_prob = (0.5 + spam_prob * total_count) / (1. + total_count);
-               cd->spam_probability += log (bayes_spam_prob);
-               cd->ham_probability += log (1. - bayes_spam_prob);
-               cd->processed_tokens++;
-       }
-
-       if (cd->max_tokens != 0 && cd->processed_tokens > cd->max_tokens) {
-               /* Stop classifying on max tokens */
-               return TRUE;
+               rt->spam_prob += log (bayes_spam_prob);
+               rt->ham_prob += log (1. - bayes_spam_prob);
        }
 
        return FALSE;
@@ -202,91 +135,53 @@ bayes_init (rspamd_mempool_t *pool, struct rspamd_classifier_config *cfg)
 gboolean
 bayes_classify (struct classifier_ctx * ctx,
        GTree *input,
+       struct rspamd_classifier_runtime *rt,
        struct rspamd_task *task)
 {
-       struct bayes_callback_data data;
-       gchar *value;
-       gint nodes, i = 0, selected_st = -1, cnt;
-       gint minnodes;
-       guint64 maxhits = 0, rev;
        double final_prob, h, s;
-       struct rspamd_statfile_config *st;
+       guint maxhits = 0;
+       struct rspamd_statfile_runtime *st, *selected_st = NULL;
        GList *cur;
        char *sumbuf;
 
        g_assert (ctx != NULL);
+       g_assert (input != NULL);
+       g_assert (rt != NULL);
+       g_assert (rt->end_pos > rt->start_pos);
 
-       if (ctx->cfg->opts &&
-               (value = g_hash_table_lookup (ctx->cfg->opts, "min_tokens")) != NULL) {
-               minnodes = strtol (value, NULL, 10);
-               nodes = g_tree_nnodes (input);
+       g_tree_foreach (input, bayes_classify_callback, &rt);
 
-               if (nodes < minnodes) {
-                       return FALSE;
-               }
-       }
-
-       cur = ctx->cfg->statfiles;
-#if 0
-       cur = rspamd_lua_call_cls_pre_callbacks (ctx->cfg, task, FALSE, FALSE, L);
-       if (cur) {
-               rspamd_mempool_add_destructor (task->task_pool,
-                       (rspamd_mempool_destruct_t)g_list_free, cur);
-       }
-       else {
-               cur = ctx->cfg->statfiles;
-       }
-#endif
-
-
-       data.statfiles_num = g_list_length (cur);
-       data.statfiles = g_new0 (struct bayes_statfile_data, data.statfiles_num);
-       data.now = time (NULL);
-       data.ctx = ctx;
-
-       data.processed_tokens = 0;
-       data.spam_probability = 0;
-       data.ham_probability = 0;
-       data.total_ham = 0;
-       data.total_spam = 0;
-       if (ctx->cfg->opts &&
-               (value = g_hash_table_lookup (ctx->cfg->opts, "max_tokens")) != NULL) {
-               minnodes = rspamd_config_parse_limit (value, -1);
-               data.max_tokens = minnodes;
-       }
-       else {
-               data.max_tokens = 0;
-       }
-
-       cnt = i;
-
-       g_tree_foreach (input, bayes_classify_callback, &data);
-
-       if (data.processed_tokens == 0 || data.spam_probability == 0) {
+       if (rt->spam_prob == 0) {
                final_prob = 0;
        }
        else {
-               h = 1 - inv_chi_square (-2. * data.spam_probability,
-                               2 * data.processed_tokens);
-               s = 1 - inv_chi_square (-2. * data.ham_probability,
-                               2 * data.processed_tokens);
+               h = 1 - inv_chi_square (-2. * rt->spam_prob,
+                               2 * rt->processed_tokens);
+               s = 1 - inv_chi_square (-2. * rt->ham_prob,
+                               2 * rt->processed_tokens);
                final_prob = (s + 1 - h) / 2.;
        }
 
-       if (data.processed_tokens > 0 && fabs (final_prob - 0.5) > 0.05) {
+       if (rt->processed_tokens > 0 && fabs (final_prob - 0.5) > 0.05) {
 
                sumbuf = rspamd_mempool_alloc (task->task_pool, 32);
-               for (i = 0; i < cnt; i++) {
-                       if ((final_prob > 0.5 && !data.statfiles[i].st->is_spam) ||
-                               (final_prob < 0.5 && data.statfiles[i].st->is_spam)) {
-                               continue;
-                       }
-                       if (data.statfiles[i].total_hits > maxhits) {
-                               maxhits = data.statfiles[i].total_hits;
-                               selected_st = i;
+               cur = g_list_first (rt->st_runtime);
+
+               while (cur) {
+                       st = (struct rspamd_statfile_runtime *)cur->data;
+
+                       if ((final_prob < 0.5 && !st->st->is_spam) ||
+                               (final_prob > 0.5 && st->st->is_spam)) {
+                               if (st->total_hits > maxhits) {
+                                       maxhits = st->total_hits;
+                                       selected_st = st;
+                               }
                        }
+
+                       cur = g_list_next (cur);
                }
-               if (selected_st == -1) {
+
+               if (selected_st == NULL) {
                        msg_err (
                                "unexpected classifier error: cannot select desired statfile");
                }
@@ -298,65 +193,75 @@ bayes_classify (struct classifier_ctx * ctx,
                        rspamd_snprintf (sumbuf, 32, "%.2f%%", final_prob * 100.);
                        cur = g_list_prepend (NULL, sumbuf);
                        rspamd_task_insert_result (task,
-                               data.statfiles[selected_st].st->symbol,
+                               selected_st->st->symbol,
                                final_prob,
                                cur);
                }
        }
 
-       g_free (data.statfiles);
-
        return TRUE;
 }
 
+static gboolean
+bayes_learn_spam_callback (gpointer key, gpointer value, gpointer data)
+{
+       rspamd_token_t *node = value;
+       struct rspamd_token_result *res;
+       struct rspamd_classifier_runtime *rt = (struct rspamd_classifier_runtime *)data;
+       guint i;
+
+
+       for (i = rt->start_pos; i < rt->end_pos; i++) {
+               res = &g_array_index (node->results, struct rspamd_token_result, i);
+
+               if (res->st_runtime->st->is_spam) {
+                       res->value ++;
+               }
+       }
+
+       return FALSE;
+}
+
+static gboolean
+bayes_learn_ham_callback (gpointer key, gpointer value, gpointer data)
+{
+       rspamd_token_t *node = value;
+       struct rspamd_token_result *res;
+       struct rspamd_classifier_runtime *rt = (struct rspamd_classifier_runtime *)data;
+       guint i;
+
+
+       for (i = rt->start_pos; i < rt->end_pos; i++) {
+               res = &g_array_index (node->results, struct rspamd_token_result, i);
+
+               if (!res->st_runtime->st->is_spam) {
+                       res->value ++;
+               }
+       }
+
+       return FALSE;
+}
+
 gboolean
 bayes_learn_spam (struct classifier_ctx * ctx,
        GTree *input,
+       struct rspamd_classifier_runtime *rt,
        struct rspamd_task *task,
        gboolean is_spam,
        GError **err)
 {
-       struct bayes_callback_data data;
-       gchar *value;
-       gint nodes;
-       gint minnodes;
-       struct rspamd_statfile_config *st;
-       GList *cur;
-       gboolean skip_labels;
-
        g_assert (ctx != NULL);
+       g_assert (input != NULL);
+       g_assert (rt != NULL);
+       g_assert (rt->end_pos > rt->start_pos);
 
-       if (ctx->cfg->opts &&
-               (value = g_hash_table_lookup (ctx->cfg->opts, "min_tokens")) != NULL) {
-               minnodes = strtol (value, NULL, 10);
-               nodes = g_tree_nnodes (input);
-
-               if (nodes < minnodes) {
-                       g_set_error (err,
-                               bayes_error_quark (),           /* error domain */
-                               1,                              /* error code */
-                               "message contains too few tokens: %d, while min is %d",
-                               nodes, (int)minnodes);
-                       return FALSE;
-               }
-       }
-
-       data.now = time (NULL);
-       data.ctx = ctx;
-       data.in_class = TRUE;
-
-       data.processed_tokens = 0;
-       if (ctx->cfg->opts &&
-               (value = g_hash_table_lookup (ctx->cfg->opts, "max_tokens")) != NULL) {
-               minnodes = rspamd_config_parse_limit (value, -1);
-               data.max_tokens = minnodes;
+       if (is_spam) {
+               g_tree_foreach (input, bayes_learn_spam_callback, rt);
        }
        else {
-               data.max_tokens = 0;
+               g_tree_foreach (input, bayes_learn_ham_callback, rt);
        }
 
-       g_tree_foreach (input, bayes_learn_callback, &data);
-
 
        return TRUE;
 }
index e2bf57f81f30d443da20c5712723ef06cfd20963..9a30039df767e47486e3654d45adb734b7deef3e 100644 (file)
@@ -18,14 +18,19 @@ struct classifier_ctx {
        struct rspamd_classifier_config *cfg;
 };
 
+struct token_node_s;
+struct rspamd_classifier_runtime;
+
 struct rspamd_stat_classifier {
        char *name;
        struct classifier_ctx * (*init_func)(rspamd_mempool_t *pool,
                struct rspamd_classifier_config *cf);
        gboolean (*classify_func)(struct classifier_ctx * ctx,
-               GTree *input, struct rspamd_task *task);
+               GTree *input, struct rspamd_classifier_runtime *rt,
+               struct rspamd_task *task);
        gboolean (*learn_spam_func)(struct classifier_ctx * ctx,
-               GTree *input, struct rspamd_task *task, gboolean is_spam,
+               GTree *input, struct rspamd_classifier_runtime *rt,
+               struct rspamd_task *task, gboolean is_spam,
                GError **err);
 };
 
@@ -34,9 +39,11 @@ struct classifier_ctx * bayes_init (rspamd_mempool_t *pool,
        struct rspamd_classifier_config *cf);
 gboolean bayes_classify (struct classifier_ctx * ctx,
        GTree *input,
+       struct rspamd_classifier_runtime *rt,
        struct rspamd_task *task);
 gboolean bayes_learn_spam (struct classifier_ctx * ctx,
        GTree *input,
+       struct rspamd_classifier_runtime *rt,
        struct rspamd_task *task,
        gboolean is_spam,
        GError **err);