From 8e399cdba1bba1da8c1de2b8a22efe719aa30cde Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Mon, 8 Oct 2012 21:21:53 +0400 Subject: * Use murmur hash for all hashes as it is more efficient and provides more uniform distribution than glib's default one. * Fix probability renormalization while using advanced classification. --- src/classifiers/bayes.c | 38 +++++++++++++++++++++++++++++--------- 1 file changed, 29 insertions(+), 9 deletions(-) (limited to 'src/classifiers') diff --git a/src/classifiers/bayes.c b/src/classifiers/bayes.c index a80bbe0ba..0e68f5d72 100644 --- a/src/classifiers/bayes.c +++ b/src/classifiers/bayes.c @@ -182,10 +182,10 @@ bayes_classify (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, { struct bayes_callback_data data; gchar *value; - gint nodes, i = 0, cnt, best_num = 0; + gint nodes, i = 0, cnt, best_num_spam = 0, best_num_ham = 0; gint minnodes; guint64 rev, total_learns = 0; - double best = 0; + double best_spam = 0., best_ham = 0., total_spam = 0., total_ham = 0.; struct statfile *st; stat_file_t *file; GList *cur; @@ -262,17 +262,37 @@ bayes_classify (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, for (i = 0; i < cnt; i ++) { debug_task ("got probability for symbol %s: %.2f", data.statfiles[i].st->symbol, data.statfiles[i].post_probability); - if (data.statfiles[i].post_probability > best) { - best = data.statfiles[i].post_probability; - best_num = i; + + if (data.statfiles[i].st->is_spam) { + total_spam += data.statfiles[i].post_probability; + if (data.statfiles[i].post_probability > best_spam) { + best_spam = data.statfiles[i].post_probability; + best_num_spam = i; + } + } + else { + total_ham += data.statfiles[i].post_probability; + if (data.statfiles[i].post_probability > best_ham) { + best_ham = data.statfiles[i].post_probability; + best_num_ham = i; + } } } - if (best > 0.5) { + + if (total_ham > 0.5 || total_spam > 0.5) { + sumbuf = memory_pool_alloc (task->task_pool, 32); - 
rspamd_snprintf (sumbuf, 32, "%.2f", best); - cur = g_list_prepend (NULL, sumbuf); - insert_result (task, data.statfiles[best_num].st->symbol, best, cur); + if (total_ham > total_spam) { + rspamd_snprintf (sumbuf, 32, "%.2f", total_ham); + cur = g_list_prepend (NULL, sumbuf); + insert_result (task, data.statfiles[best_num_ham].st->symbol, total_ham, cur); + } + else { + rspamd_snprintf (sumbuf, 32, "%.2f", total_spam); + cur = g_list_prepend (NULL, sumbuf); + insert_result (task, data.statfiles[best_num_spam].st->symbol, total_spam, cur); + } } g_free (data.statfiles); -- cgit v1.2.3