]> source.dussan.org Git - rspamd.git/commitdiff
Fix bayes classifier for the new architecture
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 5 Jan 2016 18:02:47 +0000 (18:02 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 5 Jan 2016 18:02:47 +0000 (18:02 +0000)
src/libstat/classifiers/bayes.c
src/libstat/classifiers/classifiers.h
src/libstat/stat_config.c
src/libstat/stat_internal.h
src/libstat/stat_process.c

index a271a424a0784eacaa65b5a36572a62fb0b6540c..0915933f1485b34fc112b68d6b199e77157567c6 100644 (file)
@@ -90,7 +90,10 @@ inv_chi_square (struct rspamd_task *task, gdouble value, gint freedom_deg)
 }
 
 struct bayes_task_closure {
-       struct rspamd_classifier_runtime *rt;
+       double ham_prob;
+       double spam_prob;
+       guint64 processed_tokens;
+       guint64 total_hits;
        struct rspamd_task *task;
 };
 
@@ -104,44 +107,46 @@ static const double feature_weight[] = { 0, 1, 4, 27, 256, 3125, 46656, 823543 }
 /*
  * In this callback we calculate local probabilities for tokens
  */
-static gboolean
-bayes_classify_callback (gpointer key, gpointer value, gpointer data)
+static void
+bayes_classify_token (struct rspamd_classifier *ctx,
+               rspamd_token_t *tok, struct bayes_task_closure *cl)
 {
-       rspamd_token_t *node = value;
-       struct bayes_task_closure *cl = data;
-       struct rspamd_classifier_runtime *rt;
        guint i;
-       struct rspamd_token_result *res;
+       gint id;
        guint64 spam_count = 0, ham_count = 0, total_count = 0;
+       struct rspamd_statfile *st;
        struct rspamd_task *task;
        double spam_prob, spam_freq, ham_freq, bayes_spam_prob, bayes_ham_prob,
-               ham_prob, fw, w, norm_sum, norm_sub;
+               ham_prob, fw, w, norm_sum, norm_sub, val;
 
-       rt = cl->rt;
        task = cl->task;
 
-       for (i = rt->start_pos; i < rt->end_pos; i++) {
-               res = &g_array_index (node->results, struct rspamd_token_result, i);
+       for (i = 0; i < ctx->statfiles_ids->len; i++) {
+               id = g_array_index (ctx->statfiles_ids, gint, i);
+               st = g_ptr_array_index (ctx->ctx->statfiles, id);
+               g_assert (st != NULL);
+               val = tok->values[id];
 
-               if (res->value > 0) {
-                       if (res->st_runtime->st->is_spam) {
-                               spam_count += res->value;
+               if (val > 0) {
+                       if (st->stcf->is_spam) {
+                               spam_count += val;
                        }
                        else {
-                               ham_count += res->value;
+                               ham_count += val;
                        }
-                       total_count += res->value;
-                       res->st_runtime->total_hits += res->value;
+
+                       total_count += val;
+                       cl->total_hits += val;
                }
        }
 
        /* Probability for this token */
        if (total_count > 0) {
-               spam_freq = ((double)spam_count / MAX (1., (double)rt->total_spam));
-               ham_freq = ((double)ham_count / MAX (1., (double)rt->total_ham));
+               spam_freq = ((double)spam_count / MAX (1., (double) ctx->spam_learns));
+               ham_freq = ((double)ham_count / MAX (1., (double)ctx->ham_learns));
                spam_prob = spam_freq / (spam_freq + ham_freq);
                ham_prob = ham_freq / (spam_freq + ham_freq);
-               fw = feature_weight[node->window_idx % G_N_ELEMENTS (feature_weight)];
+               fw = feature_weight[tok->window_idx % G_N_ELEMENTS (feature_weight)];
                norm_sum = (spam_freq + ham_freq) * (spam_freq + ham_freq);
                norm_sub = (spam_freq - ham_freq) * (spam_freq - ham_freq);
                w = (norm_sub) / (norm_sum) *
@@ -151,9 +156,9 @@ bayes_classify_callback (gpointer key, gpointer value, gpointer data)
                w = (norm_sub) / (norm_sum) *
                                (fw * total_count) / (4.0 * (1.0 + fw * total_count));
                bayes_ham_prob = PROB_COMBINE (ham_prob, total_count, w, 0.5);
-               rt->spam_prob += log (bayes_spam_prob);
-               rt->ham_prob += log (bayes_ham_prob);
-               res->cl_runtime->processed_tokens ++;
+               cl->spam_prob += log (bayes_spam_prob);
+               cl->ham_prob += log (bayes_ham_prob);
+               cl->processed_tokens ++;
 
                msg_debug_bayes ("token: weight: %f, total_count: %L, "
                                "spam_count: %L, ham_count: %L,"
@@ -163,10 +168,8 @@ bayes_classify_callback (gpointer key, gpointer value, gpointer data)
                                fw, total_count, spam_count, ham_count,
                                spam_prob, ham_prob,
                                bayes_spam_prob, bayes_ham_prob,
-                               rt->spam_prob, rt->ham_prob);
+                               cl->spam_prob, cl->ham_prob);
        }
-
-       return FALSE;
 }
 
 /*
@@ -198,176 +201,146 @@ bayes_init (rspamd_mempool_t *pool, struct rspamd_classifier *cl)
 
 gboolean
 bayes_classify (struct rspamd_classifier * ctx,
-       GTree *input,
-       struct rspamd_classifier_runtime *rt,
-       struct rspamd_task *task)
+               GPtrArray *tokens,
+               struct rspamd_task *task)
 {
        double final_prob, h, s;
-       guint maxhits = 0;
-       struct rspamd_statfile_runtime *st, *selected_st = NULL;
-       GList *cur;
        char *sumbuf;
+       struct rspamd_statfile *st = NULL;
        struct bayes_task_closure cl;
+       rspamd_token_t *tok;
+       guint i;
+       gint id;
+       GList *cur;
 
        g_assert (ctx != NULL);
-       g_assert (input != NULL);
-       g_assert (rt != NULL);
-       g_assert (rt->end_pos > rt->start_pos);
-
-       if (rt->stage == RSPAMD_STAT_STAGE_PRE) {
-               cl.rt = rt;
-               cl.task = task;
-               g_tree_foreach (input, bayes_classify_callback, &cl);
+       g_assert (tokens != NULL);
+
+       memset (&cl, 0, sizeof (cl));
+       cl.task = task;
+
+       for (i = 0; i < tokens->len; i ++) {
+               tok = g_ptr_array_index (tokens, i);
+
+               bayes_classify_token (ctx, tok, &cl);
+       }
+
+       h = 1 - inv_chi_square (task, cl.spam_prob, cl.processed_tokens);
+       s = 1 - inv_chi_square (task, cl.ham_prob, cl.processed_tokens);
+
+       if (isfinite (s) && isfinite (h)) {
+               final_prob = (s + 1.0 - h) / 2.;
+               msg_debug_bayes (
+                               "<%s> got ham prob %.2f -> %.2f and spam prob %.2f -> %.2f,"
+                                               " %L tokens processed of %ud total tokens",
+                               task->message_id,
+                               cl.ham_prob,
+                               h,
+                               cl.spam_prob,
+                               s,
+                               cl.processed_tokens,
+                               tokens->len);
        }
        else {
-               h = 1 - inv_chi_square (task, rt->spam_prob, rt->processed_tokens);
-               s = 1 - inv_chi_square (task, rt->ham_prob, rt->processed_tokens);
-
-               if (isfinite (s) && isfinite (h)) {
-                       final_prob = (s + 1.0 - h) / 2.;
-                       msg_debug_bayes ("<%s> got ham prob %.2f -> %.2f and spam prob %.2f -> %.2f,"
-                                       " %L tokens processed of %ud total tokens",
-                                       task->message_id, rt->ham_prob, h, rt->spam_prob, s,
-                                       rt->processed_tokens, g_tree_nnodes (input));
+               /*
+                * We have some overflow, hence we need to check which class
+                * is NaN
+                */
+               if (isfinite (h)) {
+                       final_prob = 1.0;
+                       msg_debug_bayes ("<%s> spam class is overflowed, as we have no"
+                                       " ham samples", task->message_id);
+               }
+               else if (isfinite (s)) {
+                       final_prob = 0.0;
+                       msg_debug_bayes ("<%s> ham class is overflowed, as we have no"
+                                       " spam samples", task->message_id);
                }
                else {
-                       /*
-                        * We have some overflow, hence we need to check which class
-                        * is NaN
-                        */
-                       if (isfinite (h)) {
-                               final_prob = 1.0;
-                               msg_debug_bayes ("<%s> spam class is overflowed, as we have no"
-                                               " ham samples", task->message_id);
-                       }
-                       else if (isfinite (s)){
-                               final_prob = 0.0;
-                               msg_debug_bayes ("<%s> ham class is overflowed, as we have no"
-                                               " spam samples", task->message_id);
-                       }
-                       else {
-                               final_prob = 0.5;
-                               msg_warn_bayes ("<%s> spam and ham classes are both overflowed",
-                                               task->message_id);
-                       }
+                       final_prob = 0.5;
+                       msg_warn_bayes ("<%s> spam and ham classes are both overflowed",
+                                       task->message_id);
                }
+       }
 
-               if (rt->processed_tokens > 0 && fabs (final_prob - 0.5) > 0.05) {
-
-                       sumbuf = rspamd_mempool_alloc (task->task_pool, 32);
-                       cur = g_list_first (rt->st_runtime);
+       if (cl.processed_tokens > 0 && fabs (final_prob - 0.5) > 0.05) {
 
-                       while (cur) {
-                               st = (struct rspamd_statfile_runtime *)cur->data;
+               sumbuf = rspamd_mempool_alloc (task->task_pool, 32);
 
-                               if ((final_prob < 0.5 && !st->st->is_spam) ||
-                                               (final_prob > 0.5 && st->st->is_spam)) {
-                                       if (st->total_hits > maxhits) {
-                                               maxhits = st->total_hits;
-                                               selected_st = st;
-                                       }
-                               }
+               /* Now we can have exactly one HAM and exactly one SPAM statfiles per classifier */
+               for (i = 0; i < ctx->statfiles_ids->len; i++) {
+                       id = g_array_index (ctx->statfiles_ids, gint, i);
+                       st = g_ptr_array_index (ctx->ctx->statfiles, id);
 
-                               cur = g_list_next (cur);
+                       if (final_prob > 0.5 && st->stcf->is_spam) {
+                               break;
                        }
-
-                       if (selected_st == NULL) {
-                               msg_err_bayes (
-                                       "unexpected classifier error: cannot select desired statfile, "
-                                       "prob: %.4f", final_prob);
+                       else if (final_prob < 0.5 && !st->stcf->is_spam) {
+                               break;
                        }
-                       else {
-                               /* Correctly scale HAM */
-                               if (final_prob < 0.5) {
-                                       final_prob = 1.0 - final_prob;
-                               }
-
-                               rspamd_snprintf (sumbuf, 32, "%.2f%%", final_prob * 100.);
-                               final_prob = bayes_normalize_prob (final_prob);
+               }
 
-                               cur = g_list_prepend (NULL, sumbuf);
-                               rspamd_task_insert_result (task,
-                                               selected_st->st->symbol,
-                                               final_prob,
-                                               cur);
-                       }
+               /* Correctly scale HAM */
+               if (final_prob < 0.5) {
+                       final_prob = 1.0 - final_prob;
                }
+
+               rspamd_snprintf (sumbuf, 32, "%.2f%%", final_prob * 100.);
+               final_prob = bayes_normalize_prob (final_prob);
+               g_assert (st != NULL);
+               cur = g_list_prepend (NULL, sumbuf);
+               rspamd_task_insert_result (task,
+                               st->stcf->symbol,
+                               final_prob,
+                               cur);
        }
 
        return TRUE;
 }
 
-static gboolean
-bayes_learn_spam_callback (gpointer key, gpointer value, gpointer data)
+gboolean
+bayes_learn_spam (struct rspamd_classifier * ctx,
+               GPtrArray *tokens,
+               struct rspamd_task *task,
+               gboolean is_spam,
+               GError **err)
 {
-       rspamd_token_t *node = value;
-       struct rspamd_token_result *res;
-       struct rspamd_classifier_runtime *rt = (struct rspamd_classifier_runtime *)data;
-       guint i;
-
+       guint i, j;
+       gint id;
+       struct rspamd_statfile *st;
+       rspamd_token_t *tok;
 
-       for (i = rt->start_pos; i < rt->end_pos; i++) {
-               res = &g_array_index (node->results, struct rspamd_token_result, i);
-
-               if (res->st_runtime) {
-                       if (res->st_runtime->st->is_spam) {
-                               res->value ++;
-                       }
-                       else if (res->value > 0) {
-                               /* Unlearning */
-                               res->value --;
-                       }
-               }
-       }
-
-       return FALSE;
-}
-
-static gboolean
-bayes_learn_ham_callback (gpointer key, gpointer value, gpointer data)
-{
-       rspamd_token_t *node = value;
-       struct rspamd_token_result *res;
-       struct rspamd_classifier_runtime *rt = (struct rspamd_classifier_runtime *)data;
-       guint i;
+       g_assert (ctx != NULL);
+       g_assert (tokens != NULL);
 
+       for (i = 0; i < tokens->len; i++) {
+               tok = g_ptr_array_index (tokens, i);
 
-       for (i = rt->start_pos; i < rt->end_pos; i++) {
-               res = &g_array_index (node->results, struct rspamd_token_result, i);
+               for (j = 0; j < ctx->statfiles_ids->len; j++) {
+                       id = g_array_index (ctx->statfiles_ids, gint, j);
+                       st = g_ptr_array_index (ctx->ctx->statfiles, id);
+                       g_assert (st != NULL);
 
-               if (res->st_runtime) {
-                       if (!res->st_runtime->st->is_spam) {
-                               res->value ++;
+                       if (is_spam) {
+                               if (st->stcf->is_spam) {
+                                       tok->values[id]++;
+                               }
+                               else if (tok->values[id] > 0) {
+                                       /* Unlearning */
+                                       tok->values[id]--;
+                               }
                        }
-                       else if (res->value > 0) {
-                               res->value --;
+                       else {
+                               if (!st->stcf->is_spam) {
+                                       tok->values[id]++;
+                               }
+                               else if (tok->values[id] > 0) {
+                                       /* Unlearning */
+                                       tok->values[id]--;
+                               }
                        }
                }
        }
 
-       return FALSE;
-}
-
-gboolean
-bayes_learn_spam (struct rspamd_classifier * ctx,
-       GTree *input,
-       struct rspamd_classifier_runtime *rt,
-       struct rspamd_task *task,
-       gboolean is_spam,
-       GError **err)
-{
-       g_assert (ctx != NULL);
-       g_assert (input != NULL);
-       g_assert (rt != NULL);
-       g_assert (rt->end_pos > rt->start_pos);
-
-       if (is_spam) {
-               g_tree_foreach (input, bayes_learn_spam_callback, rt);
-       }
-       else {
-               g_tree_foreach (input, bayes_learn_ham_callback, rt);
-       }
-
-
        return TRUE;
 }
index 62abb00521e6169228409ca7736a6fdda58dd6da..52b9a89f7b88e6f8ba28c2f523f7c1d289d8dd6f 100644 (file)
@@ -12,34 +12,31 @@ struct rspamd_task;
 struct rspamd_classifier;
 
 struct token_node_s;
-struct rspamd_classifier_runtime;
 
 struct rspamd_stat_classifier {
        char *name;
        void (*init_func)(rspamd_mempool_t *pool,
-               struct rspamd_classifier *cl);
+                       struct rspamd_classifier *cl);
        gboolean (*classify_func)(struct rspamd_classifier * ctx,
-               GTree *input, struct rspamd_classifier_runtime *rt,
-               struct rspamd_task *task);
+                       GPtrArray *tokens,
+                       struct rspamd_task *task);
        gboolean (*learn_spam_func)(struct rspamd_classifier * ctx,
-               GTree *input, struct rspamd_classifier_runtime *rt,
-               struct rspamd_task *task, gboolean is_spam,
-               GError **err);
+                       GPtrArray *input,
+                       struct rspamd_task *task, gboolean is_spam,
+                       GError **err);
 };
 
 /* Bayes algorithm */
 void bayes_init (rspamd_mempool_t *pool,
-       struct rspamd_classifier *);
-gboolean bayes_classify (struct rspamd_classifier * ctx,
-       GTree *input,
-       struct rspamd_classifier_runtime *rt,
-       struct rspamd_task *task);
-gboolean bayes_learn_spam (struct rspamd_classifier * ctx,
-       GTree *input,
-       struct rspamd_classifier_runtime *rt,
-       struct rspamd_task *task,
-       gboolean is_spam,
-       GError **err);
+               struct rspamd_classifier *);
+gboolean bayes_classify (struct rspamd_classifier *ctx,
+               GPtrArray *tokens,
+               struct rspamd_task *task);
+gboolean bayes_learn_spam (struct rspamd_classifier *ctx,
+               GPtrArray *tokens,
+               struct rspamd_task *task,
+               gboolean is_spam,
+               GError **err);
 
 #endif
 /*
index 6470793674eb1e96f3a81a56b4ce17c340bd5c0d..1cf19d412550d7f97ed3c661b4303d29598a50d9 100644 (file)
@@ -133,6 +133,7 @@ rspamd_stat_init (struct rspamd_config *cfg)
 
                cl = g_slice_alloc0 (sizeof (*cl));
                cl->cfg = clf;
+               cl->ctx = stat_ctx;
                cl->statfiles_ids = g_array_new (FALSE, FALSE, sizeof (gint));
 
                /* Init classifier cache */
index 09fd87fb64baad769837181f6c6238836095f1c4..edced84deda7000102ac435744c25a8d44f56c19 100644 (file)
 #include "backends/backends.h"
 #include "learn_cache/learn_cache.h"
 
-enum stat_process_stage {
-       RSPAMD_STAT_STAGE_PRE = 0,
-       RSPAMD_STAT_STAGE_POST
-};
-
 struct rspamd_statfile_runtime {
        struct rspamd_statfile_config *st;
        gpointer backend_runtime;
@@ -42,29 +37,14 @@ struct rspamd_statfile_runtime {
        guint64 total_hits;
 };
 
-struct rspamd_classifier_runtime {
-       struct rspamd_classifier_config *clcf;
-       struct classifier_ctx *clctx;
-       struct rspamd_stat_classifier *cl;
-       struct rspamd_stat_backend *backend;
-       struct rspamd_tokenizer_runtime *tok;
-       double ham_prob;
-       double spam_prob;
-       enum stat_process_stage stage;
-       guint64 total_spam;
-       guint64 total_ham;
-       guint64 processed_tokens;
-       GList *st_runtime;
-       guint start_pos;
-       guint end_pos;
-       gboolean skipped;
-};
-
 /* Common classifier structure */
 struct rspamd_classifier {
+       struct rspamd_stat_ctx *ctx;
        struct rspamd_stat_cache *cache;
        gpointer cachecf;
        GArray *statfiles_ids;
+       gulong spam_learns;
+       gulong ham_learns;
        struct rspamd_classifier_config *cfg;
 };
 
index 8bdf394b1b387139b41f3a02cd0339e0c35d9019..1506f4d489557d3b6c914520d15711977be572ce 100644 (file)
@@ -37,6 +37,7 @@
 
 static const gint similarity_treshold = 80;
 
+#if 0
 struct preprocess_cb_data {
        struct rspamd_task *task;
        GList *classifier_runtimes;
@@ -910,3 +911,4 @@ rspamd_stat_result_t rspamd_stat_statistics (struct rspamd_task *task,
 
        return RSPAMD_STAT_PROCESS_OK;
 }
+#endif