diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2013-05-23 16:15:46 +0100 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2013-05-23 16:15:46 +0100 |
commit | cac53229174befe479e48b7e0d5cb1d81c46c223 (patch) | |
tree | a23093c86e5281f0ec8be02dc17678dacdacb26b /src/classifiers | |
parent | b63a2f1532d41ceb011fe8badbe7a7f1f23ec8a5 (diff) | |
download | rspamd-cac53229174befe479e48b7e0d5cb1d81c46c223.tar.gz rspamd-cac53229174befe479e48b7e0d5cb1d81c46c223.zip |
New chi2square based bayes normalizer.
Diffstat (limited to 'src/classifiers')
-rw-r--r-- | src/classifiers/bayes.c | 270 |
1 files changed, 110 insertions, 160 deletions
diff --git a/src/classifiers/bayes.c b/src/classifiers/bayes.c index 0e68f5d72..f3ad36558 100644 --- a/src/classifiers/bayes.c +++ b/src/classifiers/bayes.c @@ -44,10 +44,7 @@ bayes_error_quark (void) struct bayes_statfile_data { guint64 hits; guint64 total_hits; - double local_probability; - double post_probability; - double corr; - double value; + double value; struct statfile *st; stat_file_t *file; }; @@ -60,8 +57,11 @@ struct bayes_callback_data { stat_file_t *file; struct bayes_statfile_data *statfiles; guint32 statfiles_num; - guint64 learned_tokens; + guint64 total_spam; + guint64 total_ham; + guint64 processed_tokens; gsize max_tokens; + double spam_probability; }; static gboolean @@ -78,7 +78,7 @@ bayes_learn_callback (gpointer key, gpointer value, gpointer data) v = statfile_pool_get_block (cd->pool, cd->file, node->h1, node->h2, cd->now); if (v == 0 && c > 0) { statfile_pool_set_block (cd->pool, cd->file, node->h1, node->h2, cd->now, c); - cd->learned_tokens ++; + cd->processed_tokens ++; } else if (v != 0) { if (G_LIKELY (c > 0)) { @@ -90,16 +90,50 @@ bayes_learn_callback (gpointer key, gpointer value, gpointer data) } } statfile_pool_set_block (cd->pool, cd->file, node->h1, node->h2, cd->now, v); - cd->learned_tokens ++; + cd->processed_tokens ++; } - if (cd->max_tokens != 0 && cd->learned_tokens > cd->max_tokens) { + if (cd->max_tokens != 0 && cd->processed_tokens > cd->max_tokens) { /* Stop learning on max tokens */ return TRUE; } return FALSE; } +/** + * Returns probability of chisquare > value with specified number of freedom + * degrees + * @param value value to test + * @param freedom_deg number of degrees of freedom + * @return + */ +static gdouble +inv_chi_square (gdouble value, gint freedom_deg) +{ + gdouble prob, sum; + gint i; + + if ((freedom_deg & 1) != 0) { + msg_err ("non-odd freedom degrees count: %d", freedom_deg); + return 0; + } + + value /= 2.; + errno = 0; + prob = exp (-value); + if (errno == ERANGE) { + msg_err ("exp overflow"); + return 0; + } + sum = prob; + for (i = 1; i < freedom_deg / 2; i ++) { + prob *= value / (gdouble)i; + sum += prob; + } + + return MIN (1.0, sum); +} + /* * In this callback we calculate local probabilities for tokens */ @@ -107,57 +141,39 @@ static gboolean bayes_classify_callback (gpointer key, gpointer value, gpointer data) { - token_node_t *node = key; + token_node_t *node = key; struct bayes_callback_data *cd = data; - double renorm = 0; guint i; - double local_hits = 0; struct bayes_statfile_data *cur; + guint64 spam_count = 0, ham_count = 0, total_count = 0; + double spam_prob, spam_freq, ham_freq, bayes_spam_prob; for (i = 0; i < cd->statfiles_num; i ++) { cur = &cd->statfiles[i]; cur->value = statfile_pool_get_block (cd->pool, cur->file, node->h1, node->h2, cd->now); if (cur->value > 0) { - cur->total_hits ++; - cur->hits = cur->value; - local_hits += cur->value; + cur->total_hits += cur->value; + if (cur->st->is_spam) { + spam_count ++; + } + else { + ham_count ++; + } + total_count ++; } } - for (i = 0; i < cd->statfiles_num; i ++) { - cur = &cd->statfiles[i]; - cur->local_probability = 0.5 + (cur->value - (local_hits - cur->value)) / - (LOCAL_PROB_DENOM * (1.0 + local_hits)); - renorm += cur->post_probability * cur->local_probability; - } - - for (i = 0; i < cd->statfiles_num; i ++) { - cur = &cd->statfiles[i]; - cur->post_probability = (cur->post_probability * cur->local_probability) / renorm; - if (cur->post_probability < G_MINDOUBLE * 100) { - cur->post_probability = G_MINDOUBLE * 100; - } - } - renorm = 0; - for (i = 0; i < cd->statfiles_num; i ++) { - cur = &cd->statfiles[i]; - renorm += cur->post_probability; - } - /* Renormalize to form sum of probabilities equal to 1 */ - for (i = 0; i < cd->statfiles_num; i ++) { - cur = &cd->statfiles[i]; - cur->post_probability /= renorm; - if (cur->post_probability < G_MINDOUBLE * 10) { - cur->post_probability = G_MINDOUBLE * 100; - } - if (cd->ctx->debug) { - msg_info ("token: %s, statfile: %s, probability: %.4f, post_probability: %.4f", - node->extra, cur->st->symbol, cur->value, cur->post_probability); - } + /* Probability for this token */ + if (total_count > 0) { + spam_freq = ((double)spam_count / MAX (1., (double)cd->total_spam)); + ham_freq = ((double)ham_count / MAX (1., (double)cd->total_ham)); + spam_prob = spam_freq / (spam_freq + ham_freq); + bayes_spam_prob = (0.5 + spam_prob * total_count) / (double)total_count; + cd->spam_probability += log (bayes_spam_prob); + cd->processed_tokens ++; } - cd->learned_tokens ++; - if (cd->max_tokens != 0 && cd->learned_tokens > cd->max_tokens) { + if (cd->max_tokens != 0 && cd->processed_tokens > cd->max_tokens) { /* Stop classifying on max tokens */ return TRUE; } @@ -181,15 +197,15 @@ gboolean bayes_classify (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task, lua_State *L) { struct bayes_callback_data data; - gchar *value; - gint nodes, i = 0, cnt, best_num_spam = 0, best_num_ham = 0; - gint minnodes; - guint64 rev, total_learns = 0; - double best_spam = 0., best_ham = 0., total_spam = 0., total_ham = 0.; + gchar *value; + gint nodes, i = 0, selected_st = -1, cnt; + gint minnodes; + guint64 maxhits = 0; + double final_prob; struct statfile *st; - stat_file_t *file; - GList *cur; - char *sumbuf; + stat_file_t *file; + GList *cur; + char *sumbuf; g_assert (pool != NULL); g_assert (ctx != NULL); @@ -219,7 +235,10 @@ bayes_classify (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, data.now = time (NULL); data.ctx = ctx; - data.learned_tokens = 0; + data.processed_tokens = 0; + data.spam_probability = 0; + data.total_ham = 0; + data.total_spam = 0; if (ctx->cfg->opts && (value = g_hash_table_lookup (ctx->cfg->opts, "max_tokens")) != NULL) { minnodes = parse_limit (value, -1); data.max_tokens = minnodes; @@ -241,10 +260,12 @@ bayes_classify (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, } data.statfiles[i].file = file; data.statfiles[i].st = st; - data.statfiles[i].post_probability = 0.5; - data.statfiles[i].local_probability = 0.5; - statfile_get_revision (file, &rev, NULL); - total_learns += rev; + if (st->is_spam) { + data.total_spam += statfile_get_used_blocks (file); + } + else { + data.total_ham += statfile_get_used_blocks (file); + } cur = g_list_next (cur); i ++; @@ -252,46 +273,39 @@ bayes_classify (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, cnt = i; - /* Calculate correction factor */ - for (i = 0; i < cnt; i ++) { - statfile_get_revision (data.statfiles[i].file, &rev, NULL); - data.statfiles[i].corr = ((double)rev / cnt) / (double)total_learns; - } - g_tree_foreach (input, bayes_classify_callback, &data); - for (i = 0; i < cnt; i ++) { - debug_task ("got probability for symbol %s: %.2f", data.statfiles[i].st->symbol, data.statfiles[i].post_probability); - - if (data.statfiles[i].st->is_spam) { - total_spam += data.statfiles[i].post_probability; - if (data.statfiles[i].post_probability > best_spam) { - best_spam = data.statfiles[i].post_probability; - best_num_spam = i; - } - } - else { - total_ham += data.statfiles[i].post_probability; - if (data.statfiles[i].post_probability > best_ham) { - best_ham = data.statfiles[i].post_probability; - best_num_ham = i; - } - } + if (data.processed_tokens == 0 || data.spam_probability == 0) { + final_prob = 0; + } + else { + final_prob = inv_chi_square (-2. * data.spam_probability, 2 * data.processed_tokens); } - - if (total_ham > 0.5 || total_spam > 0.5) { + if (final_prob > 0 && fabs (final_prob - 0.5) > 0.0001) { sumbuf = memory_pool_alloc (task->task_pool, 32); - if (total_ham > total_spam) { - rspamd_snprintf (sumbuf, 32, "%.2f", total_ham); - cur = g_list_prepend (NULL, sumbuf); - insert_result (task, data.statfiles[best_num_ham].st->symbol, total_ham, cur); + for (i = 0; i < cnt; i ++) { + if ((final_prob > 0.5 && !data.statfiles[i].st->is_spam) || + (final_prob < 0.5 && data.statfiles[i].st->is_spam)) { + continue; + } + if (data.statfiles[i].total_hits > maxhits) { + maxhits = data.statfiles[i].total_hits; + selected_st = i; + } + } + if (selected_st == -1) { + msg_err ("unexpected classifier error: cannot select desired statfile"); } else { - rspamd_snprintf (sumbuf, 32, "%.2f", total_spam); + /* Calculate ham probability correctly */ + if (final_prob < 0.5) { + final_prob = 1. - final_prob; + } + rspamd_snprintf (sumbuf, 32, "%.2f", final_prob); cur = g_list_prepend (NULL, sumbuf); - insert_result (task, data.statfiles[best_num_spam].st->symbol, total_spam, cur); + insert_result (task, data.statfiles[selected_st].st->symbol, final_prob, cur); } } @@ -337,8 +351,8 @@ bayes_learn (struct classifier_ctx* ctx, statfile_pool_t *pool, const char *symb data.in_class = in_class; data.now = time (NULL); data.ctx = ctx; - data.learned_tokens = 0; - data.learned_tokens = 0; + data.processed_tokens = 0; + data.processed_tokens = 0; if (ctx->cfg->opts && (value = g_hash_table_lookup (ctx->cfg->opts, "max_tokens")) != NULL) { minnodes = parse_limit (value, -1); data.max_tokens = minnodes; @@ -394,7 +408,7 @@ bayes_learn (struct classifier_ctx* ctx, statfile_pool_t *pool, const char *symb statfile_pool_unlock_file (pool, data.file); if (sum != NULL) { - *sum = data.learned_tokens; + *sum = data.processed_tokens; } return TRUE; @@ -447,7 +461,7 @@ bayes_learn_spam (struct classifier_ctx* ctx, statfile_pool_t *pool, data.now = time (NULL); data.ctx = ctx; - data.learned_tokens = 0; + data.processed_tokens = 0; if (ctx->cfg->opts && (value = g_hash_table_lookup (ctx->cfg->opts, "max_tokens")) != NULL) { minnodes = parse_limit (value, -1); data.max_tokens = minnodes; @@ -503,70 +517,6 @@ bayes_learn_spam (struct classifier_ctx* ctx, statfile_pool_t *pool, GList * bayes_weights (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task) { - struct bayes_callback_data data; - char *value; - int nodes, minnodes, i, cnt; - struct classify_weight *w; - struct statfile *st; - stat_file_t *file; - GList *cur, *resl = NULL; - - g_assert (pool != NULL); - g_assert (ctx != NULL); - - if (ctx->cfg->opts && (value = g_hash_table_lookup (ctx->cfg->opts, "min_tokens")) != NULL) { - minnodes = strtol (value, NULL, 10); - nodes = g_tree_nnodes (input); - if (nodes > FEATURE_WINDOW_SIZE) { - nodes = nodes / FEATURE_WINDOW_SIZE + FEATURE_WINDOW_SIZE; - } - if (nodes < minnodes) { - return NULL; - } - } - - data.statfiles_num = g_list_length (ctx->cfg->statfiles); - data.statfiles = g_new0 (struct bayes_statfile_data, data.statfiles_num); - data.pool = pool; - data.now = time (NULL); - data.ctx = ctx; - - cur = ctx->cfg->statfiles; - i = 0; - while (cur) { - /* Select statfile to learn */ - st = cur->data; - if ((file = statfile_pool_is_open (pool, st->path)) == NULL) { - if ((file = statfile_pool_open (pool, st->path, st->size, FALSE)) == NULL) { - msg_warn ("cannot open %s", st->path); - cur = g_list_next (cur); - data.statfiles_num --; - continue; - } - } - data.statfiles[i].file = file; - data.statfiles[i].st = st; - data.statfiles[i].post_probability = 0.5; - data.statfiles[i].local_probability = 0.5; - i ++; - cur = g_list_next (cur); - } - cnt = i; - - g_tree_foreach (input, bayes_classify_callback, &data); - - for (i = 0; i < cnt; i ++) { - w = memory_pool_alloc0 (task->task_pool, sizeof (struct classify_weight)); - w->name = data.statfiles[i].st->symbol; - w->weight = data.statfiles[i].post_probability; - resl = g_list_prepend (resl, w); - } - - g_free (data.statfiles); - - if (resl != NULL) { - memory_pool_add_destructor (task->task_pool, (pool_destruct_func)g_list_free, resl); - } - - return resl; + /* This function is unimplemented with new normalizer */ + return NULL; } |