diff options
-rw-r--r-- | src/libstat/classifiers/bayes.c | 277 | ||||
-rw-r--r-- | src/libstat/classifiers/classifiers.h | 11 |
2 files changed, 100 insertions, 188 deletions
diff --git a/src/libstat/classifiers/bayes.c b/src/libstat/classifiers/bayes.c index 6e068b79d..54d696721 100644 --- a/src/libstat/classifiers/bayes.c +++ b/src/libstat/classifiers/bayes.c @@ -39,63 +39,6 @@ bayes_error_quark (void) return g_quark_from_static_string ("bayes-error"); } -struct bayes_statfile_data { - guint64 hits; - guint64 total_hits; - double value; - struct rspamd_statfile_config *st; -}; - -struct bayes_callback_data { - struct classifier_ctx *ctx; - gboolean in_class; - time_t now; - struct bayes_statfile_data *statfiles; - guint32 statfiles_num; - guint64 total_spam; - guint64 total_ham; - guint64 processed_tokens; - gsize max_tokens; - double spam_probability; - double ham_probability; -}; - -static gboolean -bayes_learn_callback (gpointer key, gpointer value, gpointer data) -{ - rspamd_token_t *node = key; - struct bayes_callback_data *cd = data; - gint c; - guint64 v; - - c = (cd->in_class) ? 1 : -1; - - /* Consider that not found blocks have value 1 */ - /* XXX implement getting */ - if (v == 0 && c > 0) { - /* XXX: add token to the backend */ - cd->processed_tokens++; - } - else if (v != 0) { - if (G_LIKELY (c > 0)) { - v++; - } - else if (c < 0) { - if (v != 0) { - v--; - } - } - /* XXX: Implement setting */ - cd->processed_tokens++; - } - - if (cd->max_tokens != 0 && cd->processed_tokens > cd->max_tokens) { - /* Stop learning on max tokens */ - return TRUE; - } - return FALSE; -} - /** * Returns probability of chisquare > value with specified number of freedom * degrees @@ -142,45 +85,35 @@ inv_chi_square (gdouble value, gint freedom_deg) static gboolean bayes_classify_callback (gpointer key, gpointer value, gpointer data) { - - rspamd_token_t *node = key; - struct bayes_callback_data *cd = data; + rspamd_token_t *node = value; + struct rspamd_classifier_runtime *rt = (struct rspamd_classifier_runtime *)data; guint i; - struct bayes_statfile_data *cur; + struct rspamd_token_result *res; guint64 spam_count = 0, ham_count = 0, total_count = 0; double spam_prob, spam_freq, ham_freq, bayes_spam_prob; - for (i = 0; i < cd->statfiles_num; i++) { - cur = &cd->statfiles[i]; - /* - * XXX: Implement getting - */ - if (cur->value > 0) { - cur->total_hits += cur->value; - if (cur->st->is_spam) { - spam_count += cur->value; + for (i = rt->start_pos; i < rt->end_pos; i++) { + res = &g_array_index (node->results, struct rspamd_token_result, i); + + if (res->value > 0) { + if (res->st_runtime->st->is_spam) { + spam_count += res->value; } else { - ham_count += cur->value; + ham_count += res->value; } - total_count += cur->value; + total_count += res->value; } } /* Probability for this token */ if (total_count > 0) { - spam_freq = ((double)spam_count / MAX (1., (double)cd->total_spam)); - ham_freq = ((double)ham_count / MAX (1., (double)cd->total_ham)); + spam_freq = ((double)spam_count / MAX (1., (double)rt->total_spam)); + ham_freq = ((double)ham_count / MAX (1., (double)rt->total_ham)); spam_prob = spam_freq / (spam_freq + ham_freq); bayes_spam_prob = (0.5 + spam_prob * total_count) / (1. + total_count); - cd->spam_probability += log (bayes_spam_prob); - cd->ham_probability += log (1. - bayes_spam_prob); - cd->processed_tokens++; - } - - if (cd->max_tokens != 0 && cd->processed_tokens > cd->max_tokens) { - /* Stop classifying on max tokens */ - return TRUE; + rt->spam_prob += log (bayes_spam_prob); + rt->ham_prob += log (1. - bayes_spam_prob); } return FALSE; @@ -202,91 +135,53 @@ bayes_init (rspamd_mempool_t *pool, struct rspamd_classifier_config *cfg) gboolean bayes_classify (struct classifier_ctx * ctx, GTree *input, + struct rspamd_classifier_runtime *rt, struct rspamd_task *task) { - struct bayes_callback_data data; - gchar *value; - gint nodes, i = 0, selected_st = -1, cnt; - gint minnodes; - guint64 maxhits = 0, rev; double final_prob, h, s; - struct rspamd_statfile_config *st; + guint maxhits = 0; + struct rspamd_statfile_runtime *st, *selected_st = NULL; GList *cur; char *sumbuf; g_assert (ctx != NULL); + g_assert (input != NULL); + g_assert (rt != NULL); + g_assert (rt->end_pos > rt->start_pos); - if (ctx->cfg->opts && - (value = g_hash_table_lookup (ctx->cfg->opts, "min_tokens")) != NULL) { - minnodes = strtol (value, NULL, 10); - nodes = g_tree_nnodes (input); + g_tree_foreach (input, bayes_classify_callback, &rt); - if (nodes < minnodes) { - return FALSE; - } - } - - cur = ctx->cfg->statfiles; -#if 0 - cur = rspamd_lua_call_cls_pre_callbacks (ctx->cfg, task, FALSE, FALSE, L); - if (cur) { - rspamd_mempool_add_destructor (task->task_pool, - (rspamd_mempool_destruct_t)g_list_free, cur); - } - else { - cur = ctx->cfg->statfiles; - } -#endif - - - data.statfiles_num = g_list_length (cur); - data.statfiles = g_new0 (struct bayes_statfile_data, data.statfiles_num); - data.now = time (NULL); - data.ctx = ctx; - - data.processed_tokens = 0; - data.spam_probability = 0; - data.ham_probability = 0; - data.total_ham = 0; - data.total_spam = 0; - if (ctx->cfg->opts && - (value = g_hash_table_lookup (ctx->cfg->opts, "max_tokens")) != NULL) { - minnodes = rspamd_config_parse_limit (value, -1); - data.max_tokens = minnodes; - } - else { - data.max_tokens = 0; - } - - cnt = i; - - g_tree_foreach (input, bayes_classify_callback, &data); - - if (data.processed_tokens == 0 || data.spam_probability == 0) { + if (rt->spam_prob == 0) { final_prob = 0; } else { - h = 1 - inv_chi_square (-2. * data.spam_probability, - 2 * data.processed_tokens); - s = 1 - inv_chi_square (-2. * data.ham_probability, - 2 * data.processed_tokens); + h = 1 - inv_chi_square (-2. * rt->spam_prob, + 2 * rt->processed_tokens); + s = 1 - inv_chi_square (-2. * rt->ham_prob, + 2 * rt->processed_tokens); final_prob = (s + 1 - h) / 2.; } - if (data.processed_tokens > 0 && fabs (final_prob - 0.5) > 0.05) { + if (rt->processed_tokens > 0 && fabs (final_prob - 0.5) > 0.05) { sumbuf = rspamd_mempool_alloc (task->task_pool, 32); - for (i = 0; i < cnt; i++) { - if ((final_prob > 0.5 && !data.statfiles[i].st->is_spam) || - (final_prob < 0.5 && data.statfiles[i].st->is_spam)) { - continue; - } - if (data.statfiles[i].total_hits > maxhits) { - maxhits = data.statfiles[i].total_hits; - selected_st = i; + cur = g_list_first (rt->st_runtime); + + while (cur) { + st = (struct rspamd_statfile_runtime *)cur->data; + + if ((final_prob < 0.5 && !st->st->is_spam) || + (final_prob > 0.5 && st->st->is_spam)) { + if (st->total_hits > maxhits) { + maxhits = st->total_hits; + selected_st = st; + } } + + cur = g_list_next (cur); } - if (selected_st == -1) { + + if (selected_st == NULL) { msg_err ( "unexpected classifier error: cannot select desired statfile"); } @@ -298,65 +193,75 @@ bayes_classify (struct classifier_ctx * ctx, rspamd_snprintf (sumbuf, 32, "%.2f%%", final_prob * 100.); cur = g_list_prepend (NULL, sumbuf); rspamd_task_insert_result (task, - data.statfiles[selected_st].st->symbol, + selected_st->st->symbol, final_prob, cur); } } - g_free (data.statfiles); - return TRUE; } +static gboolean +bayes_learn_spam_callback (gpointer key, gpointer value, gpointer data) +{ + rspamd_token_t *node = value; + struct rspamd_token_result *res; + struct rspamd_classifier_runtime *rt = (struct rspamd_classifier_runtime *)data; + guint i; + + + for (i = rt->start_pos; i < rt->end_pos; i++) { + res = &g_array_index (node->results, struct rspamd_token_result, i); + + if (res->st_runtime->st->is_spam) { + res->value ++; + } + } + + return FALSE; +} + +static gboolean +bayes_learn_ham_callback (gpointer key, gpointer value, gpointer data) +{ + rspamd_token_t *node = value; + struct rspamd_token_result *res; + struct rspamd_classifier_runtime *rt = (struct rspamd_classifier_runtime *)data; + guint i; + + + for (i = rt->start_pos; i < rt->end_pos; i++) { + res = &g_array_index (node->results, struct rspamd_token_result, i); + + if (!res->st_runtime->st->is_spam) { + res->value ++; + } + } + + return FALSE; +} + gboolean bayes_learn_spam (struct classifier_ctx * ctx, GTree *input, + struct rspamd_classifier_runtime *rt, struct rspamd_task *task, gboolean is_spam, GError **err) { - struct bayes_callback_data data; - gchar *value; - gint nodes; - gint minnodes; - struct rspamd_statfile_config *st; - GList *cur; - gboolean skip_labels; - g_assert (ctx != NULL); + g_assert (input != NULL); + g_assert (rt != NULL); + g_assert (rt->end_pos > rt->start_pos); - if (ctx->cfg->opts && - (value = g_hash_table_lookup (ctx->cfg->opts, "min_tokens")) != NULL) { - minnodes = strtol (value, NULL, 10); - nodes = g_tree_nnodes (input); - - if (nodes < minnodes) { - g_set_error (err, - bayes_error_quark (), /* error domain */ - 1, /* error code */ - "message contains too few tokens: %d, while min is %d", - nodes, (int)minnodes); - return FALSE; - } - } - - data.now = time (NULL); - data.ctx = ctx; - data.in_class = TRUE; - - data.processed_tokens = 0; - if (ctx->cfg->opts && - (value = g_hash_table_lookup (ctx->cfg->opts, "max_tokens")) != NULL) { - minnodes = rspamd_config_parse_limit (value, -1); - data.max_tokens = minnodes; + if (is_spam) { + g_tree_foreach (input, bayes_learn_spam_callback, rt); } else { - data.max_tokens = 0; + g_tree_foreach (input, bayes_learn_ham_callback, rt); } - g_tree_foreach (input, bayes_learn_callback, &data); - return TRUE; } diff --git a/src/libstat/classifiers/classifiers.h b/src/libstat/classifiers/classifiers.h index e2bf57f81..9a30039df 100644 --- a/src/libstat/classifiers/classifiers.h +++ b/src/libstat/classifiers/classifiers.h @@ -18,14 +18,19 @@ struct classifier_ctx { struct rspamd_classifier_config *cfg; }; +struct token_node_s; +struct rspamd_classifier_runtime; + struct rspamd_stat_classifier { char *name; struct classifier_ctx * (*init_func)(rspamd_mempool_t *pool, struct rspamd_classifier_config *cf); gboolean (*classify_func)(struct classifier_ctx * ctx, - GTree *input, struct rspamd_task *task); + GTree *input, struct rspamd_classifier_runtime *rt, + struct rspamd_task *task); gboolean (*learn_spam_func)(struct classifier_ctx * ctx, - GTree *input, struct rspamd_task *task, gboolean is_spam, + GTree *input, struct rspamd_classifier_runtime *rt, + struct rspamd_task *task, gboolean is_spam, GError **err); }; @@ -34,9 +39,11 @@ struct classifier_ctx * bayes_init (rspamd_mempool_t *pool, struct rspamd_classifier_config *cf); gboolean bayes_classify (struct classifier_ctx * ctx, GTree *input, + struct rspamd_classifier_runtime *rt, struct rspamd_task *task); gboolean bayes_learn_spam (struct classifier_ctx * ctx, GTree *input, + struct rspamd_classifier_runtime *rt, struct rspamd_task *task, gboolean is_spam, GError **err); |