aboutsummaryrefslogtreecommitdiffstats
path: root/src/classifiers
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2013-05-23 16:15:46 +0100
committerVsevolod Stakhov <vsevolod@highsecure.ru>2013-05-23 16:15:46 +0100
commitcac53229174befe479e48b7e0d5cb1d81c46c223 (patch)
treea23093c86e5281f0ec8be02dc17678dacdacb26b /src/classifiers
parentb63a2f1532d41ceb011fe8badbe7a7f1f23ec8a5 (diff)
downloadrspamd-cac53229174befe479e48b7e0d5cb1d81c46c223.tar.gz
rspamd-cac53229174befe479e48b7e0d5cb1d81c46c223.zip
New chi2square based bayes normalizer.
Diffstat (limited to 'src/classifiers')
-rw-r--r--src/classifiers/bayes.c270
1 files changed, 110 insertions, 160 deletions
diff --git a/src/classifiers/bayes.c b/src/classifiers/bayes.c
index 0e68f5d72..f3ad36558 100644
--- a/src/classifiers/bayes.c
+++ b/src/classifiers/bayes.c
@@ -44,10 +44,7 @@ bayes_error_quark (void)
struct bayes_statfile_data {
guint64 hits;
guint64 total_hits;
- double local_probability;
- double post_probability;
- double corr;
- double value;
+ double value;
struct statfile *st;
stat_file_t *file;
};
@@ -60,8 +57,11 @@ struct bayes_callback_data {
stat_file_t *file;
struct bayes_statfile_data *statfiles;
guint32 statfiles_num;
- guint64 learned_tokens;
+ guint64 total_spam;
+ guint64 total_ham;
+ guint64 processed_tokens;
gsize max_tokens;
+ double spam_probability;
};
static gboolean
@@ -78,7 +78,7 @@ bayes_learn_callback (gpointer key, gpointer value, gpointer data)
v = statfile_pool_get_block (cd->pool, cd->file, node->h1, node->h2, cd->now);
if (v == 0 && c > 0) {
statfile_pool_set_block (cd->pool, cd->file, node->h1, node->h2, cd->now, c);
- cd->learned_tokens ++;
+ cd->processed_tokens ++;
}
else if (v != 0) {
if (G_LIKELY (c > 0)) {
@@ -90,16 +90,50 @@ bayes_learn_callback (gpointer key, gpointer value, gpointer data)
}
}
statfile_pool_set_block (cd->pool, cd->file, node->h1, node->h2, cd->now, v);
- cd->learned_tokens ++;
+ cd->processed_tokens ++;
}
- if (cd->max_tokens != 0 && cd->learned_tokens > cd->max_tokens) {
+ if (cd->max_tokens != 0 && cd->processed_tokens > cd->max_tokens) {
/* Stop learning on max tokens */
return TRUE;
}
return FALSE;
}
+/**
+ * Returns the probability that a chi-square distributed variable exceeds
+ * `value`, for an even number of degrees of freedom.
+ *
+ * With m = value / 2 the result is computed as
+ *   exp (-m) * (1 + m/1! + m^2/2! + ... + m^(freedom_deg/2 - 1)/(freedom_deg/2 - 1)!)
+ * and is capped to 1.0.
+ *
+ * @param value value to test
+ * @param freedom_deg number of degrees of freedom, must be even
+ * @return probability capped to 1.0; 0 if freedom_deg is odd or if exp ()
+ * underflows (very large value); 1.0 if exp () overflows (very negative value)
+ */
+static gdouble
+inv_chi_square (gdouble value, gint freedom_deg)
+{
+	gdouble prob, sum;
+	gint i;
+
+	/* The series below is only valid for an even number of degrees */
+	if ((freedom_deg & 1) != 0) {
+		msg_err ("non-even freedom degrees count: %d", freedom_deg);
+		return 0;
+	}
+
+	value /= 2.;
+	errno = 0;
+	prob = exp (-value);
+	if (errno == ERANGE) {
+		if (value < 0) {
+			/*
+			 * exp () of a large positive argument: the tail probability for
+			 * a negative chi-square value is 1, which is also the cap used
+			 * for the regular result
+			 */
+			msg_err ("exp overflow");
+			return 1.0;
+		}
+		/* exp () of a large negative argument: the probability vanishes */
+		msg_err ("exp underflow");
+		return 0;
+	}
+	sum = prob;
+	/* Each step derives the next series term from the previous one */
+	for (i = 1; i < freedom_deg / 2; i ++) {
+		prob *= value / (gdouble)i;
+		sum += prob;
+	}
+
+	return MIN (1.0, sum);
+}
+
/*
* In this callback we calculate local probabilities for tokens
*/
@@ -107,57 +141,39 @@ static gboolean
bayes_classify_callback (gpointer key, gpointer value, gpointer data)
{
- token_node_t *node = key;
+ token_node_t *node = key;
struct bayes_callback_data *cd = data;
- double renorm = 0;
guint i;
- double local_hits = 0;
struct bayes_statfile_data *cur;
+ guint64 spam_count = 0, ham_count = 0, total_count = 0;
+ double spam_prob, spam_freq, ham_freq, bayes_spam_prob;
for (i = 0; i < cd->statfiles_num; i ++) {
cur = &cd->statfiles[i];
cur->value = statfile_pool_get_block (cd->pool, cur->file, node->h1, node->h2, cd->now);
if (cur->value > 0) {
- cur->total_hits ++;
- cur->hits = cur->value;
- local_hits += cur->value;
+ cur->total_hits += cur->value;
+ if (cur->st->is_spam) {
+ spam_count ++;
+ }
+ else {
+ ham_count ++;
+ }
+ total_count ++;
}
}
- for (i = 0; i < cd->statfiles_num; i ++) {
- cur = &cd->statfiles[i];
- cur->local_probability = 0.5 + (cur->value - (local_hits - cur->value)) /
- (LOCAL_PROB_DENOM * (1.0 + local_hits));
- renorm += cur->post_probability * cur->local_probability;
- }
-
- for (i = 0; i < cd->statfiles_num; i ++) {
- cur = &cd->statfiles[i];
- cur->post_probability = (cur->post_probability * cur->local_probability) / renorm;
- if (cur->post_probability < G_MINDOUBLE * 100) {
- cur->post_probability = G_MINDOUBLE * 100;
- }
- }
- renorm = 0;
- for (i = 0; i < cd->statfiles_num; i ++) {
- cur = &cd->statfiles[i];
- renorm += cur->post_probability;
- }
- /* Renormalize to form sum of probabilities equal to 1 */
- for (i = 0; i < cd->statfiles_num; i ++) {
- cur = &cd->statfiles[i];
- cur->post_probability /= renorm;
- if (cur->post_probability < G_MINDOUBLE * 10) {
- cur->post_probability = G_MINDOUBLE * 100;
- }
- if (cd->ctx->debug) {
- msg_info ("token: %s, statfile: %s, probability: %.4f, post_probability: %.4f",
- node->extra, cur->st->symbol, cur->value, cur->post_probability);
- }
+ /* Probability for this token */
+ if (total_count > 0) {
+ spam_freq = ((double)spam_count / MAX (1., (double)cd->total_spam));
+ ham_freq = ((double)ham_count / MAX (1., (double)cd->total_ham));
+ spam_prob = spam_freq / (spam_freq + ham_freq);
+ bayes_spam_prob = (0.5 + spam_prob * total_count) / (double)total_count;
+ cd->spam_probability += log (bayes_spam_prob);
+ cd->processed_tokens ++;
}
- cd->learned_tokens ++;
- if (cd->max_tokens != 0 && cd->learned_tokens > cd->max_tokens) {
+ if (cd->max_tokens != 0 && cd->processed_tokens > cd->max_tokens) {
/* Stop classifying on max tokens */
return TRUE;
}
@@ -181,15 +197,15 @@ gboolean
bayes_classify (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task, lua_State *L)
{
struct bayes_callback_data data;
- gchar *value;
- gint nodes, i = 0, cnt, best_num_spam = 0, best_num_ham = 0;
- gint minnodes;
- guint64 rev, total_learns = 0;
- double best_spam = 0., best_ham = 0., total_spam = 0., total_ham = 0.;
+ gchar *value;
+ gint nodes, i = 0, selected_st = -1, cnt;
+ gint minnodes;
+ guint64 maxhits = 0;
+ double final_prob;
struct statfile *st;
- stat_file_t *file;
- GList *cur;
- char *sumbuf;
+ stat_file_t *file;
+ GList *cur;
+ char *sumbuf;
g_assert (pool != NULL);
g_assert (ctx != NULL);
@@ -219,7 +235,10 @@ bayes_classify (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input,
data.now = time (NULL);
data.ctx = ctx;
- data.learned_tokens = 0;
+ data.processed_tokens = 0;
+ data.spam_probability = 0;
+ data.total_ham = 0;
+ data.total_spam = 0;
if (ctx->cfg->opts && (value = g_hash_table_lookup (ctx->cfg->opts, "max_tokens")) != NULL) {
minnodes = parse_limit (value, -1);
data.max_tokens = minnodes;
@@ -241,10 +260,12 @@ bayes_classify (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input,
}
data.statfiles[i].file = file;
data.statfiles[i].st = st;
- data.statfiles[i].post_probability = 0.5;
- data.statfiles[i].local_probability = 0.5;
- statfile_get_revision (file, &rev, NULL);
- total_learns += rev;
+ if (st->is_spam) {
+ data.total_spam += statfile_get_used_blocks (file);
+ }
+ else {
+ data.total_ham += statfile_get_used_blocks (file);
+ }
cur = g_list_next (cur);
i ++;
@@ -252,46 +273,39 @@ bayes_classify (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input,
cnt = i;
- /* Calculate correction factor */
- for (i = 0; i < cnt; i ++) {
- statfile_get_revision (data.statfiles[i].file, &rev, NULL);
- data.statfiles[i].corr = ((double)rev / cnt) / (double)total_learns;
- }
-
g_tree_foreach (input, bayes_classify_callback, &data);
- for (i = 0; i < cnt; i ++) {
- debug_task ("got probability for symbol %s: %.2f", data.statfiles[i].st->symbol, data.statfiles[i].post_probability);
-
- if (data.statfiles[i].st->is_spam) {
- total_spam += data.statfiles[i].post_probability;
- if (data.statfiles[i].post_probability > best_spam) {
- best_spam = data.statfiles[i].post_probability;
- best_num_spam = i;
- }
- }
- else {
- total_ham += data.statfiles[i].post_probability;
- if (data.statfiles[i].post_probability > best_ham) {
- best_ham = data.statfiles[i].post_probability;
- best_num_ham = i;
- }
- }
+ if (data.processed_tokens == 0 || data.spam_probability == 0) {
+ final_prob = 0;
+ }
+ else {
+ final_prob = inv_chi_square (-2. * data.spam_probability, 2 * data.processed_tokens);
}
-
- if (total_ham > 0.5 || total_spam > 0.5) {
+ if (final_prob > 0 && fabs (final_prob - 0.5) > 0.0001) {
sumbuf = memory_pool_alloc (task->task_pool, 32);
- if (total_ham > total_spam) {
- rspamd_snprintf (sumbuf, 32, "%.2f", total_ham);
- cur = g_list_prepend (NULL, sumbuf);
- insert_result (task, data.statfiles[best_num_ham].st->symbol, total_ham, cur);
+ for (i = 0; i < cnt; i ++) {
+ if ((final_prob > 0.5 && !data.statfiles[i].st->is_spam) ||
+ (final_prob < 0.5 && data.statfiles[i].st->is_spam)) {
+ continue;
+ }
+ if (data.statfiles[i].total_hits > maxhits) {
+ maxhits = data.statfiles[i].total_hits;
+ selected_st = i;
+ }
+ }
+ if (selected_st == -1) {
+ msg_err ("unexpected classifier error: cannot select desired statfile");
}
else {
- rspamd_snprintf (sumbuf, 32, "%.2f", total_spam);
+ /* Calculate ham probability correctly */
+ if (final_prob < 0.5) {
+ final_prob = 1. - final_prob;
+ }
+ rspamd_snprintf (sumbuf, 32, "%.2f", final_prob);
cur = g_list_prepend (NULL, sumbuf);
- insert_result (task, data.statfiles[best_num_spam].st->symbol, total_spam, cur);
+ insert_result (task, data.statfiles[selected_st].st->symbol, final_prob, cur);
}
}
@@ -337,8 +351,8 @@ bayes_learn (struct classifier_ctx* ctx, statfile_pool_t *pool, const char *symb
data.in_class = in_class;
data.now = time (NULL);
data.ctx = ctx;
- data.learned_tokens = 0;
- data.learned_tokens = 0;
+ data.processed_tokens = 0;
+ data.processed_tokens = 0;
if (ctx->cfg->opts && (value = g_hash_table_lookup (ctx->cfg->opts, "max_tokens")) != NULL) {
minnodes = parse_limit (value, -1);
data.max_tokens = minnodes;
@@ -394,7 +408,7 @@ bayes_learn (struct classifier_ctx* ctx, statfile_pool_t *pool, const char *symb
statfile_pool_unlock_file (pool, data.file);
if (sum != NULL) {
- *sum = data.learned_tokens;
+ *sum = data.processed_tokens;
}
return TRUE;
@@ -447,7 +461,7 @@ bayes_learn_spam (struct classifier_ctx* ctx, statfile_pool_t *pool,
data.now = time (NULL);
data.ctx = ctx;
- data.learned_tokens = 0;
+ data.processed_tokens = 0;
if (ctx->cfg->opts && (value = g_hash_table_lookup (ctx->cfg->opts, "max_tokens")) != NULL) {
minnodes = parse_limit (value, -1);
data.max_tokens = minnodes;
@@ -503,70 +517,6 @@ bayes_learn_spam (struct classifier_ctx* ctx, statfile_pool_t *pool,
GList *
bayes_weights (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task)
{
- struct bayes_callback_data data;
- char *value;
- int nodes, minnodes, i, cnt;
- struct classify_weight *w;
- struct statfile *st;
- stat_file_t *file;
- GList *cur, *resl = NULL;
-
- g_assert (pool != NULL);
- g_assert (ctx != NULL);
-
- if (ctx->cfg->opts && (value = g_hash_table_lookup (ctx->cfg->opts, "min_tokens")) != NULL) {
- minnodes = strtol (value, NULL, 10);
- nodes = g_tree_nnodes (input);
- if (nodes > FEATURE_WINDOW_SIZE) {
- nodes = nodes / FEATURE_WINDOW_SIZE + FEATURE_WINDOW_SIZE;
- }
- if (nodes < minnodes) {
- return NULL;
- }
- }
-
- data.statfiles_num = g_list_length (ctx->cfg->statfiles);
- data.statfiles = g_new0 (struct bayes_statfile_data, data.statfiles_num);
- data.pool = pool;
- data.now = time (NULL);
- data.ctx = ctx;
-
- cur = ctx->cfg->statfiles;
- i = 0;
- while (cur) {
- /* Select statfile to learn */
- st = cur->data;
- if ((file = statfile_pool_is_open (pool, st->path)) == NULL) {
- if ((file = statfile_pool_open (pool, st->path, st->size, FALSE)) == NULL) {
- msg_warn ("cannot open %s", st->path);
- cur = g_list_next (cur);
- data.statfiles_num --;
- continue;
- }
- }
- data.statfiles[i].file = file;
- data.statfiles[i].st = st;
- data.statfiles[i].post_probability = 0.5;
- data.statfiles[i].local_probability = 0.5;
- i ++;
- cur = g_list_next (cur);
- }
- cnt = i;
-
- g_tree_foreach (input, bayes_classify_callback, &data);
-
- for (i = 0; i < cnt; i ++) {
- w = memory_pool_alloc0 (task->task_pool, sizeof (struct classify_weight));
- w->name = data.statfiles[i].st->symbol;
- w->weight = data.statfiles[i].post_probability;
- resl = g_list_prepend (resl, w);
- }
-
- g_free (data.statfiles);
-
- if (resl != NULL) {
- memory_pool_add_destructor (task->task_pool, (pool_destruct_func)g_list_free, resl);
- }
-
- return resl;
+ /* This function is not implemented for the new chi-square based normalizer */
+ return NULL;
}