diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2016-01-05 18:02:47 +0000 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2016-01-05 18:02:47 +0000 |
commit | 8be715956820c4b14d3a3d7f41e50e4e69b7e679 (patch) | |
tree | 934d45c3e84b5e4c98982dee360267bd74e693a2 | |
parent | a53bd05ff9eac93f444b7bf8dd1607e16301ef5b (diff) | |
download | rspamd-8be715956820c4b14d3a3d7f41e50e4e69b7e679.tar.gz rspamd-8be715956820c4b14d3a3d7f41e50e4e69b7e679.zip |
Fix bayes classifier for the new architecture
-rw-r--r-- | src/libstat/classifiers/bayes.c | 299 | ||||
-rw-r--r-- | src/libstat/classifiers/classifiers.h | 33 | ||||
-rw-r--r-- | src/libstat/stat_config.c | 1 | ||||
-rw-r--r-- | src/libstat/stat_internal.h | 26 | ||||
-rw-r--r-- | src/libstat/stat_process.c | 2 |
5 files changed, 157 insertions, 204 deletions
diff --git a/src/libstat/classifiers/bayes.c b/src/libstat/classifiers/bayes.c index a271a424a..0915933f1 100644 --- a/src/libstat/classifiers/bayes.c +++ b/src/libstat/classifiers/bayes.c @@ -90,7 +90,10 @@ inv_chi_square (struct rspamd_task *task, gdouble value, gint freedom_deg) } struct bayes_task_closure { - struct rspamd_classifier_runtime *rt; + double ham_prob; + double spam_prob; + guint64 processed_tokens; + guint64 total_hits; struct rspamd_task *task; }; @@ -104,44 +107,46 @@ static const double feature_weight[] = { 0, 1, 4, 27, 256, 3125, 46656, 823543 } /* * In this callback we calculate local probabilities for tokens */ -static gboolean -bayes_classify_callback (gpointer key, gpointer value, gpointer data) +static void +bayes_classify_token (struct rspamd_classifier *ctx, + rspamd_token_t *tok, struct bayes_task_closure *cl) { - rspamd_token_t *node = value; - struct bayes_task_closure *cl = data; - struct rspamd_classifier_runtime *rt; guint i; - struct rspamd_token_result *res; + gint id; guint64 spam_count = 0, ham_count = 0, total_count = 0; + struct rspamd_statfile *st; struct rspamd_task *task; double spam_prob, spam_freq, ham_freq, bayes_spam_prob, bayes_ham_prob, - ham_prob, fw, w, norm_sum, norm_sub; + ham_prob, fw, w, norm_sum, norm_sub, val; - rt = cl->rt; task = cl->task; - for (i = rt->start_pos; i < rt->end_pos; i++) { - res = &g_array_index (node->results, struct rspamd_token_result, i); + for (i = 0; i < ctx->statfiles_ids->len; i++) { + id = g_array_index (ctx->statfiles_ids, gint, i); + st = g_ptr_array_index (ctx->ctx->statfiles, id); + g_assert (st != NULL); + val = tok->values[id]; - if (res->value > 0) { - if (res->st_runtime->st->is_spam) { - spam_count += res->value; + if (val > 0) { + if (st->stcf->is_spam) { + spam_count += val; } else { - ham_count += res->value; + ham_count += val; } - total_count += res->value; - res->st_runtime->total_hits += res->value; + + total_count += val; + cl->total_hits += val; } } /* Probability for this token */ if (total_count > 0) { - spam_freq = ((double)spam_count / MAX (1., (double)rt->total_spam)); - ham_freq = ((double)ham_count / MAX (1., (double)rt->total_ham)); + spam_freq = ((double)spam_count / MAX (1., (double) ctx->spam_learns)); + ham_freq = ((double)ham_count / MAX (1., (double)ctx->ham_learns)); spam_prob = spam_freq / (spam_freq + ham_freq); ham_prob = ham_freq / (spam_freq + ham_freq); - fw = feature_weight[node->window_idx % G_N_ELEMENTS (feature_weight)]; + fw = feature_weight[tok->window_idx % G_N_ELEMENTS (feature_weight)]; norm_sum = (spam_freq + ham_freq) * (spam_freq + ham_freq); norm_sub = (spam_freq - ham_freq) * (spam_freq - ham_freq); w = (norm_sub) / (norm_sum) * @@ -151,9 +156,9 @@ bayes_classify_callback (gpointer key, gpointer value, gpointer data) w = (norm_sub) / (norm_sum) * (fw * total_count) / (4.0 * (1.0 + fw * total_count)); bayes_ham_prob = PROB_COMBINE (ham_prob, total_count, w, 0.5); - rt->spam_prob += log (bayes_spam_prob); - rt->ham_prob += log (bayes_ham_prob); - res->cl_runtime->processed_tokens ++; + cl->spam_prob += log (bayes_spam_prob); + cl->ham_prob += log (bayes_ham_prob); + cl->processed_tokens ++; msg_debug_bayes ("token: weight: %f, total_count: %L, " "spam_count: %L, ham_count: %L," @@ -163,10 +168,8 @@ bayes_classify_callback (gpointer key, gpointer value, gpointer data) fw, total_count, spam_count, ham_count, spam_prob, ham_prob, bayes_spam_prob, bayes_ham_prob, - rt->spam_prob, rt->ham_prob); + cl->spam_prob, cl->ham_prob); } - - return FALSE; } /* @@ -198,176 +201,146 @@ bayes_init (rspamd_mempool_t *pool, struct rspamd_classifier *cl) gboolean bayes_classify (struct rspamd_classifier * ctx, - GTree *input, - struct rspamd_classifier_runtime *rt, - struct rspamd_task *task) + GPtrArray *tokens, + struct rspamd_task *task) { double final_prob, h, s; - guint maxhits = 0; - struct rspamd_statfile_runtime *st, *selected_st = NULL; - GList *cur; char *sumbuf; + struct rspamd_statfile *st = NULL; struct bayes_task_closure cl; + rspamd_token_t *tok; + guint i; + gint id; + GList *cur; g_assert (ctx != NULL); - g_assert (input != NULL); - g_assert (rt != NULL); - g_assert (rt->end_pos > rt->start_pos); - - if (rt->stage == RSPAMD_STAT_STAGE_PRE) { - cl.rt = rt; - cl.task = task; - g_tree_foreach (input, bayes_classify_callback, &cl); + g_assert (tokens != NULL); + + memset (&cl, 0, sizeof (cl)); + cl.task = task; + + for (i = 0; i < tokens->len; i ++) { + tok = g_ptr_array_index (tokens, i); + + bayes_classify_token (ctx, tok, &cl); + } + + h = 1 - inv_chi_square (task, cl.spam_prob, cl.processed_tokens); + s = 1 - inv_chi_square (task, cl.ham_prob, cl.processed_tokens); + + if (isfinite (s) && isfinite (h)) { + final_prob = (s + 1.0 - h) / 2.; + msg_debug_bayes ( + "<%s> got ham prob %.2f -> %.2f and spam prob %.2f -> %.2f," + " %L tokens processed of %ud total tokens", + task->message_id, + cl.ham_prob, + h, + cl.spam_prob, + s, + cl.processed_tokens, + tokens->len); } else { - h = 1 - inv_chi_square (task, rt->spam_prob, rt->processed_tokens); - s = 1 - inv_chi_square (task, rt->ham_prob, rt->processed_tokens); - - if (isfinite (s) && isfinite (h)) { - final_prob = (s + 1.0 - h) / 2.; - msg_debug_bayes ("<%s> got ham prob %.2f -> %.2f and spam prob %.2f -> %.2f," - " %L tokens processed of %ud total tokens", - task->message_id, rt->ham_prob, h, rt->spam_prob, s, - rt->processed_tokens, g_tree_nnodes (input)); + /* + * We have some overflow, hence we need to check which class + * is NaN + */ + if (isfinite (h)) { + final_prob = 1.0; + msg_debug_bayes ("<%s> spam class is overflowed, as we have no" + " ham samples", task->message_id); + } + else if (isfinite (s)) { + final_prob = 0.0; + msg_debug_bayes ("<%s> ham class is overflowed, as we have no" + " spam samples", task->message_id); } else { - /* - * We have some overflow, hence we need to check which class - * is NaN - */ - if (isfinite (h)) { - final_prob = 1.0; - msg_debug_bayes ("<%s> spam class is overflowed, as we have no" - " ham samples", task->message_id); - } - else if (isfinite (s)){ - final_prob = 0.0; - msg_debug_bayes ("<%s> ham class is overflowed, as we have no" - " spam samples", task->message_id); - } - else { - final_prob = 0.5; - msg_warn_bayes ("<%s> spam and ham classes are both overflowed", - task->message_id); - } + final_prob = 0.5; + msg_warn_bayes ("<%s> spam and ham classes are both overflowed", + task->message_id); } + } - if (rt->processed_tokens > 0 && fabs (final_prob - 0.5) > 0.05) { - - sumbuf = rspamd_mempool_alloc (task->task_pool, 32); - cur = g_list_first (rt->st_runtime); + if (cl.processed_tokens > 0 && fabs (final_prob - 0.5) > 0.05) { - while (cur) { - st = (struct rspamd_statfile_runtime *)cur->data; + sumbuf = rspamd_mempool_alloc (task->task_pool, 32); - if ((final_prob < 0.5 && !st->st->is_spam) || - (final_prob > 0.5 && st->st->is_spam)) { - if (st->total_hits > maxhits) { - maxhits = st->total_hits; - selected_st = st; - } - } + /* Now we can have exactly one HAM and exactly one SPAM statfiles per classifier */ + for (i = 0; i < ctx->statfiles_ids->len; i++) { + id = g_array_index (ctx->statfiles_ids, gint, i); + st = g_ptr_array_index (ctx->ctx->statfiles, id); - cur = g_list_next (cur); + if (final_prob > 0.5 && st->stcf->is_spam) { + break; } - - if (selected_st == NULL) { - msg_err_bayes ( - "unexpected classifier error: cannot select desired statfile, " - "prob: %.4f", final_prob); + else if (final_prob < 0.5 && !st->stcf->is_spam) { + break; } - else { - /* Correctly scale HAM */ - if (final_prob < 0.5) { - final_prob = 1.0 - final_prob; - } - - rspamd_snprintf (sumbuf, 32, "%.2f%%", final_prob * 100.); - final_prob = bayes_normalize_prob (final_prob); + } - cur = g_list_prepend (NULL, sumbuf); - rspamd_task_insert_result (task, - selected_st->st->symbol, - final_prob, - cur); - } + /* Correctly scale HAM */ + if (final_prob < 0.5) { + final_prob = 1.0 - final_prob; } + + rspamd_snprintf (sumbuf, 32, "%.2f%%", final_prob * 100.); + final_prob = bayes_normalize_prob (final_prob); + g_assert (st != NULL); + cur = g_list_prepend (NULL, sumbuf); + rspamd_task_insert_result (task, + st->stcf->symbol, + final_prob, + cur); } return TRUE; } -static gboolean -bayes_learn_spam_callback (gpointer key, gpointer value, gpointer data) +gboolean +bayes_learn_spam (struct rspamd_classifier * ctx, + GPtrArray *tokens, + struct rspamd_task *task, + gboolean is_spam, + GError **err) { - rspamd_token_t *node = value; - struct rspamd_token_result *res; - struct rspamd_classifier_runtime *rt = (struct rspamd_classifier_runtime *)data; - guint i; - + guint i, j; + gint id; + struct rspamd_statfile *st; + rspamd_token_t *tok; - for (i = rt->start_pos; i < rt->end_pos; i++) { - res = &g_array_index (node->results, struct rspamd_token_result, i); - - if (res->st_runtime) { - if (res->st_runtime->st->is_spam) { - res->value ++; - } - else if (res->value > 0) { - /* Unlearning */ - res->value --; - } - } - } - - return FALSE; -} - -static gboolean -bayes_learn_ham_callback (gpointer key, gpointer value, gpointer data) -{ - rspamd_token_t *node = value; - struct rspamd_token_result *res; - struct rspamd_classifier_runtime *rt = (struct rspamd_classifier_runtime *)data; - guint i; + g_assert (ctx != NULL); + g_assert (tokens != NULL); + for (i = 0; i < tokens->len; i++) { + tok = g_ptr_array_index (tokens, i); - for (i = rt->start_pos; i < rt->end_pos; i++) { - res = &g_array_index (node->results, struct rspamd_token_result, i); + for (j = 0; j < ctx->statfiles_ids->len; j++) { + id = g_array_index (ctx->statfiles_ids, gint, j); + st = g_ptr_array_index (ctx->ctx->statfiles, id); + g_assert (st != NULL); - if (res->st_runtime) { - if (!res->st_runtime->st->is_spam) { - res->value ++; + if (is_spam) { + if (st->stcf->is_spam) { + tok->values[id]++; + } + else if (tok->values[id] > 0) { + /* Unlearning */ + tok->values[id]--; + } } - else if (res->value > 0) { - res->value --; + else { + if (!st->stcf->is_spam) { + tok->values[id]++; + } + else if (tok->values[id] > 0) { + /* Unlearning */ + tok->values[id]--; + } } } } - return FALSE; -} - -gboolean -bayes_learn_spam (struct rspamd_classifier * ctx, - GTree *input, - struct rspamd_classifier_runtime *rt, - struct rspamd_task *task, - gboolean is_spam, - GError **err) -{ - g_assert (ctx != NULL); - g_assert (input != NULL); - g_assert (rt != NULL); - g_assert (rt->end_pos > rt->start_pos); - - if (is_spam) { - g_tree_foreach (input, bayes_learn_spam_callback, rt); - } - else { - g_tree_foreach (input, bayes_learn_ham_callback, rt); - } - - return TRUE; } diff --git a/src/libstat/classifiers/classifiers.h b/src/libstat/classifiers/classifiers.h index 62abb0052..52b9a89f7 100644 --- a/src/libstat/classifiers/classifiers.h +++ b/src/libstat/classifiers/classifiers.h @@ -12,34 +12,31 @@ struct rspamd_task; struct rspamd_classifier; struct token_node_s; -struct rspamd_classifier_runtime; struct rspamd_stat_classifier { char *name; void (*init_func)(rspamd_mempool_t *pool, - struct rspamd_classifier *cl); + struct rspamd_classifier *cl); gboolean (*classify_func)(struct rspamd_classifier * ctx, - GTree *input, struct rspamd_classifier_runtime *rt, - struct rspamd_task *task); + GPtrArray *tokens, + struct rspamd_task *task); gboolean (*learn_spam_func)(struct rspamd_classifier * ctx, - GTree *input, struct rspamd_classifier_runtime *rt, - struct rspamd_task *task, gboolean is_spam, - GError **err); + GPtrArray *input, + struct rspamd_task *task, gboolean is_spam, + GError **err); }; /* Bayes algorithm */ void bayes_init (rspamd_mempool_t *pool, - struct rspamd_classifier *); -gboolean bayes_classify (struct rspamd_classifier * ctx, - GTree *input, - struct rspamd_classifier_runtime *rt, - struct rspamd_task *task); -gboolean bayes_learn_spam (struct rspamd_classifier * ctx, - GTree *input, - struct rspamd_classifier_runtime *rt, - struct rspamd_task *task, - gboolean is_spam, - GError **err); + struct rspamd_classifier *); +gboolean bayes_classify (struct rspamd_classifier *ctx, + GPtrArray *tokens, + struct rspamd_task *task); +gboolean bayes_learn_spam (struct rspamd_classifier *ctx, + GPtrArray *tokens, + struct rspamd_task *task, + gboolean is_spam, + GError **err); #endif /* diff --git a/src/libstat/stat_config.c b/src/libstat/stat_config.c index 647079367..1cf19d412 100644 --- a/src/libstat/stat_config.c +++ b/src/libstat/stat_config.c @@ -133,6 +133,7 @@ rspamd_stat_init (struct rspamd_config *cfg) cl = g_slice_alloc0 (sizeof (*cl)); cl->cfg = clf; + cl->ctx = stat_ctx; cl->statfiles_ids = g_array_new (FALSE, FALSE, sizeof (gint)); /* Init classifier cache */ diff --git a/src/libstat/stat_internal.h b/src/libstat/stat_internal.h index 09fd87fb6..edced84de 100644 --- a/src/libstat/stat_internal.h +++ b/src/libstat/stat_internal.h @@ -30,11 +30,6 @@ #include "backends/backends.h" #include "learn_cache/learn_cache.h" -enum stat_process_stage { - RSPAMD_STAT_STAGE_PRE = 0, - RSPAMD_STAT_STAGE_POST -}; - struct rspamd_statfile_runtime { struct rspamd_statfile_config *st; gpointer backend_runtime; @@ -42,29 +37,14 @@ struct rspamd_statfile_runtime { guint64 total_hits; }; -struct rspamd_classifier_runtime { - struct rspamd_classifier_config *clcf; - struct classifier_ctx *clctx; - struct rspamd_stat_classifier *cl; - struct rspamd_stat_backend *backend; - struct rspamd_tokenizer_runtime *tok; - double ham_prob; - double spam_prob; - enum stat_process_stage stage; - guint64 total_spam; - guint64 total_ham; - guint64 processed_tokens; - GList *st_runtime; - guint start_pos; - guint end_pos; - gboolean skipped; -}; - /* Common classifier structure */ struct rspamd_classifier { + struct rspamd_stat_ctx *ctx; struct rspamd_stat_cache *cache; gpointer cachecf; GArray *statfiles_ids; + gulong spam_learns; + gulong ham_learns; struct rspamd_classifier_config *cfg; }; diff --git a/src/libstat/stat_process.c b/src/libstat/stat_process.c index 8bdf394b1..1506f4d48 100644 --- a/src/libstat/stat_process.c +++ b/src/libstat/stat_process.c @@ -37,6 +37,7 @@ static const gint similarity_treshold = 80; +#if 0 struct preprocess_cb_data { struct rspamd_task *task; GList *classifier_runtimes; @@ -910,3 +911,4 @@ rspamd_stat_result_t rspamd_stat_statistics (struct rspamd_task *task, return RSPAMD_STAT_PROCESS_OK; } +#endif |