/* * Copyright (c) 2009-2012, Vsevolod Stakhov * All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * THIS SOFTWARE IS PROVIDED BY AUTHOR ''AS IS'' AND ANY * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL AUTHOR BE LIABLE FOR ANY * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ /* * Bayesian classifier */ #include "classifiers.h" #include "rspamd.h" #include "filter.h" #include "cfg_file.h" #include "stat_internal.h" #include "math.h" #define msg_err_bayes(...) rspamd_default_log_function (G_LOG_LEVEL_CRITICAL, \ "bayes", task->task_pool->tag.uid, \ G_STRFUNC, \ __VA_ARGS__) #define msg_warn_bayes(...) rspamd_default_log_function (G_LOG_LEVEL_WARNING, \ "bayes", task->task_pool->tag.uid, \ G_STRFUNC, \ __VA_ARGS__) #define msg_info_bayes(...) rspamd_default_log_function (G_LOG_LEVEL_INFO, \ "bayes", task->task_pool->tag.uid, \ G_STRFUNC, \ __VA_ARGS__) #define msg_debug_bayes(...) rspamd_default_log_function (G_LOG_LEVEL_DEBUG, \ "bayes", task->task_pool->tag.uid, \ G_STRFUNC, \ __VA_ARGS__) static inline GQuark bayes_error_quark (void) { return g_quark_from_static_string ("bayes-error"); } /** * Returns probability of chisquare > value with specified number of freedom * degrees * @param value value to test * @param freedom_deg number of degrees of freedom * @return */ static gdouble inv_chi_square (struct rspamd_task *task, gdouble value, gint freedom_deg) { double prob, sum, m; gint i; errno = 0; m = -value; prob = exp (value); if (errno == ERANGE) { msg_err_bayes ("exp overflow"); return 0; } sum = prob; for (i = 1; i < freedom_deg; i++) { prob *= m / (gdouble)i; msg_debug_bayes ("prob: %.6f", prob); sum += prob; } return MIN (1.0, sum); } struct bayes_task_closure { struct rspamd_classifier_runtime *rt; struct rspamd_task *task; }; /* * Mathematically we use pow(complexity, complexity), where complexity is the * window index */ static const double feature_weight[] = { 0, 1, 4, 27, 256, 3125, 46656, 823543 }; #define PROB_COMBINE(prob, cnt, weight, assumed) (((weight) * (assumed) + (cnt) * (prob)) / ((weight) + (cnt))) /* * In this callback we calculate local probabilities for tokens */ static gboolean bayes_classify_callback (gpointer key, gpointer value, gpointer data) { rspamd_token_t *node = value; struct bayes_task_closure *cl = data; struct rspamd_classifier_runtime *rt; guint i; struct rspamd_token_result *res; guint64 spam_count = 0, ham_count = 0, total_count = 0; struct rspamd_task *task; double spam_prob, spam_freq, ham_freq, bayes_spam_prob, bayes_ham_prob, ham_prob, fw, w, norm_sum, norm_sub; rt = cl->rt; task = cl->task; for (i = rt->start_pos; i < rt->end_pos; i++) { res = &g_array_index (node->results, struct rspamd_token_result, i); if (res->value > 0) { if (res->st_runtime->st->is_spam) { spam_count += res->value; } else { ham_count += res->value; } total_count += res->value; res->st_runtime->total_hits += res->value; } } /* Probability for this token */ if (total_count > 0) { spam_freq = ((double)spam_count / MAX (1., (double)rt->total_spam)); ham_freq = ((double)ham_count / MAX (1., (double)rt->total_ham)); spam_prob = spam_freq / (spam_freq + ham_freq); ham_prob = ham_freq / (spam_freq + ham_freq); fw = feature_weight[node->window_idx % G_N_ELEMENTS (feature_weight)]; norm_sum = (spam_freq + ham_freq) * (spam_freq + ham_freq); norm_sub = (spam_freq - ham_freq) * (spam_freq - ham_freq); w = (norm_sub) / (norm_sum) * (fw * total_count) / (4.0 * (1.0 + fw * total_count)); bayes_spam_prob = PROB_COMBINE (spam_prob, total_count, w, 0.5); norm_sub = (ham_freq - spam_freq) * (ham_freq - spam_freq); w = (norm_sub) / (norm_sum) * (fw * total_count) / (4.0 * (1.0 + fw * total_count)); bayes_ham_prob = PROB_COMBINE (ham_prob, total_count, w, 0.5); rt->spam_prob += log (bayes_spam_prob); rt->ham_prob += log (bayes_ham_prob); res->cl_runtime->processed_tokens ++; msg_debug_bayes ("token: weight: %f, total_count: %L, " "spam_count: %L, ham_count: %L," "spam_prob: %.3f, ham_prob: %.3f, " "bayes_spam_prob: %.3f, bayes_ham_prob: %.3f, " "current spam prob: %.3f, current ham prob: %.3f", fw, total_count, spam_count, ham_count, spam_prob, ham_prob, bayes_spam_prob, bayes_ham_prob, rt->spam_prob, rt->ham_prob); } return FALSE; } /* * A(x - 0.5)^4 + B(x - 0.5)^3 + C(x - 0.5)^2 + D(x - 0.5) * A = 32, * B = -6 * C = -7 * D = 3 * y = 32(x - 0.5)^4 - 6(x - 0.5)^3 - 7(x - 0.5)^2 + 3(x - 0.5) */ static gdouble bayes_normalize_prob (gdouble x) { const gdouble a = 32, b = -6, c = -7, d = 3; gdouble xx, x2, x3, x4; xx = x - 0.5; x2 = xx * xx; x3 = x2 * xx; x4 = x3 * xx; return a*x4 + b*x3 + c*x2 + d*xx; } struct classifier_ctx * bayes_init (rspamd_mempool_t *pool, struct rspamd_classifier_config *cfg) { struct classifier_ctx *ctx = rspamd_mempool_alloc (pool, sizeof (struct classifier_ctx)); ctx->pool = pool; ctx->cfg = cfg; ctx->debug = FALSE; return ctx; } gboolean bayes_classify (struct classifier_ctx * ctx, GTree *input, struct rspamd_classifier_runtime *rt, struct rspamd_task *task) { double final_prob, h, s; guint maxhits = 0; struct rspamd_statfile_runtime *st, *selected_st = NULL; GList *cur; char *sumbuf; struct bayes_task_closure cl; g_assert (ctx != NULL); g_assert (input != NULL); g_assert (rt != NULL); g_assert (rt->end_pos > rt->start_pos); if (rt->stage == RSPAMD_STAT_STAGE_PRE) { cl.rt = rt; cl.task = task; g_tree_foreach (input, bayes_classify_callback, &cl); } else { h = 1 - inv_chi_square (task, rt->spam_prob, rt->processed_tokens); s = 1 - inv_chi_square (task, rt->ham_prob, rt->processed_tokens); if (isfinite (s) && isfinite (h)) { final_prob = (s + 1.0 - h) / 2.; msg_debug_bayes ("<%s> got ham prob %.2f -> %.2f and spam prob %.2f -> %.2f," " %L tokens processed of %ud total tokens", task->message_id, rt->ham_prob, h, rt->spam_prob, s, rt->processed_tokens, g_tree_nnodes (input)); } else { /* * We have some overflow, hence we need to check which class * is NaN */ if (isfinite (h)) { final_prob = 1.0; msg_debug_bayes ("<%s> spam class is overflowed, as we have no" " ham samples", task->message_id); } else if (isfinite (s)){ final_prob = 0.0; msg_debug_bayes ("<%s> ham class is overflowed, as we have no" " spam samples", task->message_id); } else { final_prob = 0.5; msg_warn_bayes ("<%s> spam and ham classes are both overflowed", task->message_id); } } if (rt->processed_tokens > 0 && fabs (final_prob - 0.5) > 0.05) { sumbuf = rspamd_mempool_alloc (task->task_pool, 32); cur = g_list_first (rt->st_runtime); while (cur) { st = (struct rspamd_statfile_runtime *)cur->data; if ((final_prob < 0.5 && !st->st->is_spam) || (final_prob > 0.5 && st->st->is_spam)) { if (st->total_hits > maxhits) { maxhits = st->total_hits; selected_st = st; } } cur = g_list_next (cur); } if (selected_st == NULL) { msg_err_bayes ( "unexpected classifier error: cannot select desired statfile, " "prob: %.4f", final_prob); } else { /* Correctly scale HAM */ if (final_prob < 0.5) { final_prob = 1.0 - final_prob; } rspamd_snprintf (sumbuf, 32, "%.2f%%", final_prob * 100.); final_prob = bayes_normalize_prob (final_prob); cur = g_list_prepend (NULL, sumbuf); rspamd_task_insert_result (task, selected_st->st->symbol, final_prob, cur); } } } return TRUE; } static gboolean bayes_learn_spam_callback (gpointer key, gpointer value, gpointer data) { rspamd_token_t *node = value; struct rspamd_token_result *res; struct rspamd_classifier_runtime *rt = (struct rspamd_classifier_runtime *)data; guint i; for (i = rt->start_pos; i < rt->end_pos; i++) { res = &g_array_index (node->results, struct rspamd_token_result, i); if (res->st_runtime) { if (res->st_runtime->st->is_spam) { res->value ++; } else if (res->value > 0) { /* Unlearning */ res->value --; } } } return FALSE; } static gboolean bayes_learn_ham_callback (gpointer key, gpointer value, gpointer data) { rspamd_token_t *node = value; struct rspamd_token_result *res; struct rspamd_classifier_runtime *rt = (struct rspamd_classifier_runtime *)data; guint i; for (i = rt->start_pos; i < rt->end_pos; i++) { res = &g_array_index (node->results, struct rspamd_token_result, i); if (res->st_runtime) { if (!res->st_runtime->st->is_spam) { res->value ++; } else if (res->value > 0) { res->value --; } } } return FALSE; } gboolean bayes_learn_spam (struct classifier_ctx * ctx, GTree *input, struct rspamd_classifier_runtime *rt, struct rspamd_task *task, gboolean is_spam, GError **err) { g_assert (ctx != NULL); g_assert (input != NULL); g_assert (rt != NULL); g_assert (rt->end_pos > rt->start_pos); if (is_spam) { g_tree_foreach (input, bayes_learn_spam_callback, rt); } else { g_tree_foreach (input, bayes_learn_ham_callback, rt); } return TRUE; }