aboutsummaryrefslogtreecommitdiffstats
path: root/src/libstat/classifiers/bayes.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/libstat/classifiers/bayes.c')
-rw-r--r--src/libstat/classifiers/bayes.c597
1 files changed, 597 insertions, 0 deletions
diff --git a/src/libstat/classifiers/bayes.c b/src/libstat/classifiers/bayes.c
new file mode 100644
index 000000000..34169697e
--- /dev/null
+++ b/src/libstat/classifiers/bayes.c
@@ -0,0 +1,597 @@
+/*
+ * Copyright (c) 2009-2012, Vsevolod Stakhov
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright
+ * notice, this list of conditions and the following disclaimer in the
+ * documentation and/or other materials provided with the distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY AUTHOR ''AS IS'' AND ANY
+ * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL AUTHOR BE LIABLE FOR ANY
+ * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+/*
+ * Bayesian classifier
+ */
+#include "classifiers.h"
+#include "tokenizers.h"
+#include "main.h"
+#include "filter.h"
+#include "cfg_file.h"
+#include "binlog.h"
+#include "lua/lua_common.h"
+
+/*
+ * NOTE(review): this constant is not referenced by any code visible in this
+ * file; presumably a leftover from an earlier per-token probability formula.
+ * Confirm there are no other users before removing it.
+ */
+#define LOCAL_PROB_DENOM 16.0
+
+/**
+ * Error domain attached to every GError reported by this classifier.
+ * @return quark obtained for the static string "bayes-error"
+ */
+static inline GQuark
+bayes_error_quark (void)
+{
+    GQuark domain = g_quark_from_static_string ("bayes-error");
+
+    return domain;
+}
+
+/*
+ * Per-statfile state kept while classifying one message.
+ */
+struct bayes_statfile_data {
+ /* NOTE(review): never read or written by the code visible in this file */
+ guint64 hits;
+ /* Running sum of the block values found in this statfile for all tokens */
+ guint64 total_hits;
+ /* Block value fetched for the token currently being examined */
+ double value;
+ /* Configuration of the statfile (is_spam and symbol are read from it) */
+ struct rspamd_statfile_config *st;
+ /* Handle returned by statfile_pool_is_open () / statfile_pool_open () */
+ stat_file_t *file;
+};
+
+/*
+ * Context handed to the GTree traversal callbacks; the learn and classify
+ * callbacks each use only a subset of the fields.
+ */
+struct bayes_callback_data {
+ /* Statfile pool passed to the block get/set helpers */
+ statfile_pool_t *pool;
+ /* Classifier context; stored by the callers, not read by the callbacks */
+ struct classifier_ctx *ctx;
+ /* Learning only: TRUE increments token values, FALSE decrements them */
+ gboolean in_class;
+ /* Timestamp (time (NULL)) passed to the block get/set helpers */
+ time_t now;
+ /* Learning only: statfile being updated */
+ stat_file_t *file;
+ /* Classification only: array of opened statfiles and its length */
+ struct bayes_statfile_data *statfiles;
+ guint32 statfiles_num;
+ /* Classification only: summed revisions of the spam / ham statfiles */
+ guint64 total_spam;
+ guint64 total_ham;
+ /* Number of tokens that contributed so far */
+ guint64 processed_tokens;
+ /* Upper bound for processed_tokens; 0 disables the limit */
+ gsize max_tokens;
+ /* Classification only: accumulated log (p) and log (1 - p) over tokens */
+ double spam_probability;
+ double ham_probability;
+};
+
+/**
+ * GTree traversal callback that applies one token to the statfile being
+ * learned (cd->file).
+ *
+ * A block that reads back as 0 is treated as absent: it is created with
+ * value 1 when learning in-class and left untouched otherwise. A non-zero
+ * block is incremented (in-class) or decremented (out-of-class).
+ *
+ * @param key token_node_t carrying the h1/h2 hashes of the token
+ * @param value unused
+ * @param data struct bayes_callback_data describing the learn operation
+ * @return TRUE to stop the traversal once max_tokens tokens were applied,
+ * FALSE to continue
+ */
+static gboolean
+bayes_learn_callback (gpointer key, gpointer value, gpointer data)
+{
+    token_node_t *node = key;
+    struct bayes_callback_data *cd = data;
+    guint64 v;
+
+    (void)value;
+
+    v = statfile_pool_get_block (cd->pool, cd->file, node->h1, node->h2,
+            cd->now);
+
+    if (v != 0) {
+        /* Known token: move its counter in the requested direction */
+        if (G_LIKELY (cd->in_class)) {
+            v++;
+        }
+        else {
+            v--;
+        }
+        statfile_pool_set_block (cd->pool,
+            cd->file,
+            node->h1,
+            node->h2,
+            cd->now,
+            v);
+        cd->processed_tokens++;
+    }
+    else if (cd->in_class) {
+        /* Unknown token: only learning creates it, unlearning skips it */
+        statfile_pool_set_block (cd->pool,
+            cd->file,
+            node->h1,
+            node->h2,
+            cd->now,
+            1);
+        cd->processed_tokens++;
+    }
+
+    /*
+     * Stop as soon as the limit is reached; the former '>' comparison let
+     * max_tokens + 1 tokens through.
+     */
+    if (cd->max_tokens != 0 && cd->processed_tokens >= cd->max_tokens) {
+        /* Stop learning on max tokens */
+        return TRUE;
+    }
+    return FALSE;
+}
+
+/**
+ * Returns probability of chisquare > value with specified number of freedom
+ * degrees.
+ *
+ * For an even number of degrees 2k the function evaluates
+ * exp (-m) * sum (m^i / i!, i = 0 .. k - 1) with m = value / 2, building each
+ * term from the previous one.
+ *
+ * @param value value to test
+ * @param freedom_deg number of degrees of freedom, must be even
+ * @return the probability, clamped to at most 1.0; 0 when freedom_deg is odd
+ * or when evaluating the exponent sets errno to ERANGE
+ */
+static gdouble
+inv_chi_square (gdouble value, gint freedom_deg)
+{
+    long double prob, sum;
+    gint i, terms;
+
+    if ((freedom_deg & 1) != 0) {
+        /* The closed form used below exists only for even degrees */
+        msg_err ("odd freedom degrees count: %d", freedom_deg);
+        return 0;
+    }
+
+    value /= 2.;
+    /* errno is cleared so that ERANGE can be attributed to the call below */
+    errno = 0;
+#ifdef HAVE_EXPL
+    prob = expl (-value);
+#elif defined(HAVE_EXP2L)
+    prob = exp2l (-value * log2 (M_E));
+#else
+    prob = exp (-value);
+#endif
+    if (errno == ERANGE) {
+        msg_err ("exp overflow");
+        return 0;
+    }
+
+    /* prob holds the i-th series term, sum the partial sum */
+    sum = prob;
+    terms = freedom_deg / 2;
+    for (i = 1; i < terms; i++) {
+        prob *= value / (gdouble)i;
+        sum += prob;
+    }
+
+    return MIN (1.0, sum);
+}
+
+/**
+ * GTree traversal callback that folds one token into the running spam/ham
+ * log-probabilities.
+ *
+ * The token is looked up in every opened statfile; its counts are split into
+ * spam and ham by the statfile's is_spam flag. Tokens found in no statfile
+ * contribute nothing and are not counted as processed.
+ *
+ * @param key token_node_t carrying the h1/h2 hashes of the token
+ * @param value unused
+ * @param data struct bayes_callback_data describing the classification
+ * @return TRUE to stop the traversal once max_tokens tokens contributed,
+ * FALSE to continue
+ */
+static gboolean
+bayes_classify_callback (gpointer key, gpointer value, gpointer data)
+{
+    token_node_t *node = key;
+    struct bayes_callback_data *cd = data;
+    struct bayes_statfile_data *cur;
+    guint64 spam_count = 0, ham_count = 0, total_count = 0;
+    double spam_prob, spam_freq, ham_freq, bayes_spam_prob;
+    guint i;
+
+    (void)value;
+
+    for (i = 0; i < cd->statfiles_num; i++) {
+        cur = &cd->statfiles[i];
+        cur->value = statfile_pool_get_block (cd->pool,
+                cur->file,
+                node->h1,
+                node->h2,
+                cd->now);
+        if (cur->value > 0) {
+            /* total_hits is used later to pick the reported statfile */
+            cur->total_hits += cur->value;
+            if (cur->st->is_spam) {
+                spam_count += cur->value;
+            }
+            else {
+                ham_count += cur->value;
+            }
+            total_count += cur->value;
+        }
+    }
+
+    /* Probability for this token */
+    if (total_count > 0) {
+        /* Frequencies are relative to the summed statfile revisions */
+        spam_freq = ((double)spam_count / MAX (1., (double)cd->total_spam));
+        ham_freq = ((double)ham_count / MAX (1., (double)cd->total_ham));
+        spam_prob = spam_freq / (spam_freq + ham_freq);
+        /*
+         * Blend the raw ratio with a neutral 0.5 weighted as a single
+         * observation, so the result never reaches exactly 0 or 1.
+         */
+        bayes_spam_prob = (0.5 + spam_prob * total_count) / (1. + total_count);
+        cd->spam_probability += log (bayes_spam_prob);
+        cd->ham_probability += log (1. - bayes_spam_prob);
+        cd->processed_tokens++;
+    }
+
+    /*
+     * Stop as soon as the limit is reached; the former '>' comparison let
+     * max_tokens + 1 tokens through.
+     */
+    if (cd->max_tokens != 0 && cd->processed_tokens >= cd->max_tokens) {
+        /* Stop classifying on max tokens */
+        return TRUE;
+    }
+
+    return FALSE;
+}
+
+/**
+ * Create a classifier context for the bayes classifier.
+ * @param pool memory pool the context is allocated from and remembers
+ * @param cfg classifier configuration stored in the context
+ * @return the new context with debug set to FALSE
+ */
+struct classifier_ctx *
+bayes_init (rspamd_mempool_t *pool, struct rspamd_classifier_config *cfg)
+{
+    struct classifier_ctx *new_ctx;
+
+    new_ctx = rspamd_mempool_alloc (pool, sizeof (*new_ctx));
+
+    new_ctx->debug = FALSE;
+    new_ctx->cfg = cfg;
+    new_ctx->pool = pool;
+
+    return new_ctx;
+}
+
+/**
+ * Classify a tokenized message and, when the verdict is far enough from 0.5,
+ * insert the symbol of the best matching statfile into the task result.
+ *
+ * @param ctx classifier context (configuration and options)
+ * @param pool statfile pool used to open and read statfiles
+ * @param input tree of token_node_t to classify
+ * @param task task that receives the result and owns temporary allocations
+ * @param L Lua state passed to the classifier pre-callbacks
+ * @return FALSE when the message has fewer tokens than "min_tokens",
+ * TRUE otherwise (whether or not a symbol was inserted)
+ */
+gboolean
+bayes_classify (struct classifier_ctx * ctx,
+    statfile_pool_t *pool,
+    GTree *input,
+    struct rspamd_task *task,
+    lua_State *L)
+{
+    struct bayes_callback_data cbdata;
+    struct rspamd_statfile_config *stcf;
+    stat_file_t *opened;
+    GList *it, *opts_list;
+    gchar *opt = NULL;
+    char *pct_buf;
+    gint token_cnt, min_tokens, limit;
+    gint idx, used = 0, best = -1;
+    guint64 best_hits = 0, revision;
+    double prob, ham_conf, spam_conf;
+
+    g_assert (pool != NULL);
+    g_assert (ctx != NULL);
+
+    /* Refuse messages that are too short to be classified reliably */
+    if (ctx->cfg->opts != NULL) {
+        opt = g_hash_table_lookup (ctx->cfg->opts, "min_tokens");
+    }
+    if (opt != NULL) {
+        min_tokens = strtol (opt, NULL, 10);
+        token_cnt = g_tree_nnodes (input);
+        if (token_cnt > FEATURE_WINDOW_SIZE) {
+            token_cnt = token_cnt / FEATURE_WINDOW_SIZE + FEATURE_WINDOW_SIZE;
+        }
+        if (token_cnt < min_tokens) {
+            return FALSE;
+        }
+    }
+
+    /* Pre-callbacks may narrow the statfile list; otherwise use them all */
+    it = rspamd_lua_call_cls_pre_callbacks (ctx->cfg, task, FALSE, FALSE, L);
+    if (it == NULL) {
+        it = ctx->cfg->statfiles;
+    }
+    else {
+        rspamd_mempool_add_destructor (task->task_pool,
+            (rspamd_mempool_destruct_t)g_list_free, it);
+    }
+
+    cbdata.statfiles_num = g_list_length (it);
+    cbdata.statfiles = g_new0 (struct bayes_statfile_data,
+            cbdata.statfiles_num);
+    cbdata.pool = pool;
+    cbdata.now = time (NULL);
+    cbdata.ctx = ctx;
+
+    cbdata.processed_tokens = 0;
+    cbdata.spam_probability = 0;
+    cbdata.ham_probability = 0;
+    cbdata.total_ham = 0;
+    cbdata.total_spam = 0;
+
+    opt = NULL;
+    if (ctx->cfg->opts != NULL) {
+        opt = g_hash_table_lookup (ctx->cfg->opts, "max_tokens");
+    }
+    if (opt == NULL) {
+        cbdata.max_tokens = 0;
+    }
+    else {
+        limit = rspamd_config_parse_limit (opt, -1);
+        cbdata.max_tokens = limit;
+    }
+
+    /* Open every statfile and sum the revisions per class */
+    for (; it != NULL; it = g_list_next (it)) {
+        stcf = it->data;
+        opened = statfile_pool_is_open (pool, stcf->path);
+        if (opened == NULL) {
+            opened = statfile_pool_open (pool, stcf->path, stcf->size, FALSE);
+        }
+        if (opened == NULL) {
+            /* Unusable statfile: drop it from the set seen by the callback */
+            msg_warn ("cannot open %s", stcf->path);
+            cbdata.statfiles_num--;
+            continue;
+        }
+
+        cbdata.statfiles[used].file = opened;
+        cbdata.statfiles[used].st = stcf;
+        statfile_get_revision (opened, &revision, NULL);
+        if (stcf->is_spam) {
+            cbdata.total_spam += revision;
+        }
+        else {
+            cbdata.total_ham += revision;
+        }
+        used++;
+    }
+
+    g_tree_foreach (input, bayes_classify_callback, &cbdata);
+
+    /* Combine the accumulated log-probabilities into a single verdict */
+    if (cbdata.processed_tokens != 0 && cbdata.spam_probability != 0) {
+        ham_conf = 1 - inv_chi_square (-2. * cbdata.spam_probability,
+                2 * cbdata.processed_tokens);
+        spam_conf = 1 - inv_chi_square (-2. * cbdata.ham_probability,
+                2 * cbdata.processed_tokens);
+        prob = (spam_conf + 1 - ham_conf) / 2.;
+    }
+    else {
+        prob = 0;
+    }
+
+    if (cbdata.processed_tokens > 0 && fabs (prob - 0.5) > 0.05) {
+
+        pct_buf = rspamd_mempool_alloc (task->task_pool, 32);
+
+        /* Among statfiles of the winning class take the one with most hits */
+        for (idx = 0; idx < used; idx++) {
+            if (prob > 0.5 && !cbdata.statfiles[idx].st->is_spam) {
+                continue;
+            }
+            if (prob < 0.5 && cbdata.statfiles[idx].st->is_spam) {
+                continue;
+            }
+            if (cbdata.statfiles[idx].total_hits > best_hits) {
+                best_hits = cbdata.statfiles[idx].total_hits;
+                best = idx;
+            }
+        }
+
+        if (best != -1) {
+            /* Calculate ham probability correctly */
+            if (prob < 0.5) {
+                prob = 1. - prob;
+            }
+            rspamd_snprintf (pct_buf, 32, "%.2f%%", prob * 100.);
+            opts_list = g_list_prepend (NULL, pct_buf);
+            rspamd_task_insert_result (task,
+                cbdata.statfiles[best].st->symbol,
+                prob,
+                opts_list);
+        }
+        else {
+            msg_err (
+                "unexpected classifier error: cannot select desired statfile");
+        }
+    }
+
+    g_free (cbdata.statfiles);
+
+    return TRUE;
+}
+
+/**
+ * Learn (in_class) or unlearn (!in_class) the tokens of one message into the
+ * statfile whose symbol equals `symbol`, creating the statfile if it cannot
+ * be opened.
+ *
+ * @param ctx classifier context (configuration and options)
+ * @param pool statfile pool used to open, create and lock the statfile
+ * @param symbol symbol of the statfile to update
+ * @param input tree of token_node_t to learn
+ * @param in_class TRUE to increment token values, FALSE to decrement them
+ * @param sum optional output (may be NULL): number of tokens applied, or 0
+ * when the message was rejected for having too few tokens
+ * @param multiplier unused
+ * @param err set on failure
+ * @return TRUE on success; FALSE with err set when the message has fewer
+ * tokens than "min_tokens", no statfile matches the symbol, or the statfile
+ * cannot be created or opened
+ */
+gboolean
+bayes_learn (struct classifier_ctx * ctx,
+    statfile_pool_t *pool,
+    const char *symbol,
+    GTree *input,
+    gboolean in_class,
+    double *sum,
+    double multiplier,
+    GError **err)
+{
+    struct bayes_callback_data data;
+    gchar *value;
+    gint nodes;
+    gint minnodes;
+    gint limit;
+    struct rspamd_statfile_config *st, *sel_st = NULL;
+    stat_file_t *to_learn;
+    GList *cur;
+
+    (void)multiplier;
+
+    g_assert (pool != NULL);
+    g_assert (ctx != NULL);
+
+    if (ctx->cfg->opts &&
+        (value = g_hash_table_lookup (ctx->cfg->opts, "min_tokens")) != NULL) {
+        minnodes = strtol (value, NULL, 10);
+        nodes = g_tree_nnodes (input);
+        if (nodes > FEATURE_WINDOW_SIZE) {
+            nodes = nodes / FEATURE_WINDOW_SIZE + FEATURE_WINDOW_SIZE;
+        }
+        if (nodes < minnodes) {
+            msg_info (
+                "do not learn message as it has too few tokens: %d, while %d min",
+                nodes,
+                minnodes);
+            /* sum is optional here exactly as on the success path */
+            if (sum != NULL) {
+                *sum = 0;
+            }
+            g_set_error (err,
+                bayes_error_quark (),           /* error domain */
+                1,                              /* error code */
+                "message contains too few tokens: %d, while min is %d",
+                nodes, (int)minnodes);
+            return FALSE;
+        }
+    }
+
+    data.pool = pool;
+    data.in_class = in_class;
+    data.now = time (NULL);
+    data.ctx = ctx;
+    data.processed_tokens = 0;
+    if (ctx->cfg->opts &&
+        (value = g_hash_table_lookup (ctx->cfg->opts, "max_tokens")) != NULL) {
+        limit = rspamd_config_parse_limit (value, -1);
+        data.max_tokens = limit;
+    }
+    else {
+        data.max_tokens = 0;
+    }
+
+    /* Select statfile to learn: first configured one with a matching symbol */
+    for (cur = ctx->cfg->statfiles; cur != NULL; cur = g_list_next (cur)) {
+        st = cur->data;
+        if (strcmp (st->symbol, symbol) == 0) {
+            sel_st = st;
+            break;
+        }
+    }
+    if (sel_st == NULL) {
+        g_set_error (err,
+            bayes_error_quark (),           /* error domain */
+            1,                              /* error code */
+            "cannot find statfile for symbol: %s",
+            symbol);
+        return FALSE;
+    }
+
+    /* Reuse an already opened statfile, else open it, else create it */
+    to_learn = statfile_pool_is_open (pool, sel_st->path);
+    if (to_learn == NULL) {
+        to_learn = statfile_pool_open (pool, sel_st->path, sel_st->size,
+                FALSE);
+    }
+    if (to_learn == NULL) {
+        msg_warn ("cannot open %s", sel_st->path);
+        if (statfile_pool_create (pool, sel_st->path, sel_st->size) == -1) {
+            msg_err ("cannot create statfile %s", sel_st->path);
+            g_set_error (err,
+                bayes_error_quark (),           /* error domain */
+                1,                              /* error code */
+                "cannot create statfile: %s",
+                sel_st->path);
+            return FALSE;
+        }
+        to_learn = statfile_pool_open (pool, sel_st->path, sel_st->size,
+                FALSE);
+        if (to_learn == NULL) {
+            g_set_error (err,
+                bayes_error_quark (),           /* error domain */
+                1,                              /* error code */
+                "cannot open statfile %s after creation",
+                sel_st->path);
+            msg_err ("cannot open statfile %s after creation",
+                sel_st->path);
+            return FALSE;
+        }
+    }
+
+    /* All token updates and the revision bump happen under the file lock */
+    data.file = to_learn;
+    statfile_pool_lock_file (pool, data.file);
+    g_tree_foreach (input, bayes_learn_callback, &data);
+    statfile_inc_revision (to_learn);
+    statfile_pool_unlock_file (pool, data.file);
+
+    if (sum != NULL) {
+        *sum = data.processed_tokens;
+    }
+
+    return TRUE;
+}
+
+/**
+ * Learn the tokens of one message into every statfile of the requested class
+ * (spam or ham), creating statfiles that cannot be opened.
+ *
+ * The statfile list comes from the Lua pre-callbacks; when they return
+ * nothing, all configured statfiles are considered and those carrying a
+ * label are skipped.
+ *
+ * @param ctx classifier context (configuration and options)
+ * @param pool statfile pool used to open, create and lock statfiles
+ * @param input tree of token_node_t to learn
+ * @param task task passed to the pre-callbacks; owns the callback list
+ * @param is_spam class to learn: statfiles with a different is_spam are skipped
+ * @param L Lua state passed to the pre-callbacks
+ * @param err set on failure
+ * @return TRUE on success; FALSE with err set when the message has fewer
+ * tokens than "min_tokens" or a statfile cannot be created or opened
+ * (statfiles processed before the failure stay updated)
+ */
+gboolean
+bayes_learn_spam (struct classifier_ctx * ctx,
+    statfile_pool_t *pool,
+    GTree *input,
+    struct rspamd_task *task,
+    gboolean is_spam,
+    lua_State *L,
+    GError **err)
+{
+    struct bayes_callback_data data;
+    gchar *value;
+    gint nodes;
+    gint minnodes;
+    gint limit;
+    struct rspamd_statfile_config *st;
+    stat_file_t *file;
+    GList *cur;
+    gboolean skip_labels;
+
+    g_assert (pool != NULL);
+    g_assert (ctx != NULL);
+
+    if (ctx->cfg->opts &&
+        (value = g_hash_table_lookup (ctx->cfg->opts, "min_tokens")) != NULL) {
+        minnodes = strtol (value, NULL, 10);
+        nodes = g_tree_nnodes (input);
+        if (nodes > FEATURE_WINDOW_SIZE) {
+            nodes = nodes / FEATURE_WINDOW_SIZE + FEATURE_WINDOW_SIZE;
+        }
+        if (nodes < minnodes) {
+            g_set_error (err,
+                bayes_error_quark (),           /* error domain */
+                1,                              /* error code */
+                "message contains too few tokens: %d, while min is %d",
+                nodes, (int)minnodes);
+            return FALSE;
+        }
+    }
+
+    cur = rspamd_lua_call_cls_pre_callbacks (ctx->cfg, task, TRUE, is_spam, L);
+    if (cur) {
+        skip_labels = FALSE;
+        rspamd_mempool_add_destructor (task->task_pool,
+            (rspamd_mempool_destruct_t)g_list_free, cur);
+    }
+    else {
+        /* Do not try to learn specific statfiles if pre callback returned nil */
+        skip_labels = TRUE;
+        cur = ctx->cfg->statfiles;
+    }
+
+    data.pool = pool;
+    data.now = time (NULL);
+    data.ctx = ctx;
+    data.in_class = TRUE;
+
+    data.processed_tokens = 0;
+    if (ctx->cfg->opts &&
+        (value = g_hash_table_lookup (ctx->cfg->opts, "max_tokens")) != NULL) {
+        limit = rspamd_config_parse_limit (value, -1);
+        data.max_tokens = limit;
+    }
+    else {
+        data.max_tokens = 0;
+    }
+
+    for (; cur != NULL; cur = g_list_next (cur)) {
+        /* Select statfiles to learn */
+        st = cur->data;
+        if (st->is_spam != is_spam || (skip_labels && st->label)) {
+            continue;
+        }
+
+        /* Reuse an already opened statfile, else open it, else create it */
+        file = statfile_pool_is_open (pool, st->path);
+        if (file == NULL) {
+            file = statfile_pool_open (pool, st->path, st->size, FALSE);
+        }
+        if (file == NULL) {
+            msg_warn ("cannot open %s", st->path);
+            if (statfile_pool_create (pool, st->path, st->size) == -1) {
+                msg_err ("cannot create statfile %s", st->path);
+                g_set_error (err,
+                    bayes_error_quark (),           /* error domain */
+                    1,                              /* error code */
+                    "cannot create statfile: %s",
+                    st->path);
+                return FALSE;
+            }
+            file = statfile_pool_open (pool, st->path, st->size, FALSE);
+            if (file == NULL) {
+                g_set_error (err,
+                    bayes_error_quark (),           /* error domain */
+                    1,                              /* error code */
+                    "cannot open statfile %s after creation",
+                    st->path);
+                msg_err ("cannot open statfile %s after creation",
+                    st->path);
+                return FALSE;
+            }
+        }
+
+        data.file = file;
+        /*
+         * The token limit applies per statfile: without this reset, once the
+         * first statfile reached max_tokens every following statfile would
+         * stop after its first applied token.
+         */
+        data.processed_tokens = 0;
+        statfile_pool_lock_file (pool, data.file);
+        g_tree_foreach (input, bayes_learn_callback, &data);
+        statfile_inc_revision (file);
+        statfile_pool_unlock_file (pool, data.file);
+        maybe_write_binlog (ctx->cfg, st, file, input);
+        msg_info ("increase revision for %s", st->path);
+    }
+
+    return TRUE;
+}
+
+/**
+ * Per-statfile weights for a message; currently a stub.
+ * @param ctx unused
+ * @param pool unused
+ * @param input unused
+ * @param task unused
+ * @return always NULL
+ */
+GList *
+bayes_weights (struct classifier_ctx * ctx,
+    statfile_pool_t *pool,
+    GTree *input,
+    struct rspamd_task *task)
+{
+    GList *weights = NULL;
+
+    /* This function is unimplemented with new normalizer */
+    return weights;
+}