/*
 * Copyright (c) 2009, Rambler media
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY Rambler media ''AS IS'' AND ANY
 * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL Rambler BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

/*
 * Bayesian classifier
 */

#include "classifiers.h"
#include "tokenizers/tokenizers.h"
#include "main.h"
#include "filter.h"
#include "cfg_file.h"
#include "binlog.h"
#include "lua/lua_common.h"

/* Damping denominator used when computing the per-token local probability */
#define LOCAL_PROB_DENOM 16.0

/*
 * Error domain used for every GError reported by this classifier.
 */
static inline GQuark
bayes_error_quark (void)
{
	return g_quark_from_static_string ("bayes-error");
}

/*
 * Per-statfile state kept while classifying a single message.
 */
struct bayes_statfile_data {
	guint64 hits;					/* block value of the last token that had a non-zero value */
	guint64 total_hits;				/* number of tokens that had a non-zero value in this statfile */
	double local_probability;		/* probability computed for the current token */
	double post_probability;		/* running, renormalized probability over all tokens seen so far */
	double corr;					/* correction factor derived from statfile revisions (set by bayes_classify) */
	double value;					/* raw block value read for the current token */
	struct statfile *st;			/* configuration entry of the statfile */
	stat_file_t *file;				/* opened statfile */
};

/*
 * Context handed to the GTree traversal callbacks.
 */
struct bayes_callback_data {
	statfile_pool_t *pool;					/* statfile pool used for block access */
	struct classifier_ctx *ctx;				/* classifier context (debug flag is read from it) */
	gboolean in_class;						/* learning only: TRUE increments blocks, FALSE decrements them */
	time_t now;								/* timestamp passed to block get/set calls */
	stat_file_t *file;						/* learning only: statfile being updated */
	struct bayes_statfile_data *statfiles;	/* classification only: per-statfile state */
	guint32 statfiles_num;					/* classification only: number of valid entries in statfiles */
	guint64 learned_tokens;					/* number of tokens processed so far */
	gsize max_tokens;						/* stop after this many tokens were exceeded, 0 means no limit */
};

/*
 * Check the optional "min_tokens" classifier option against the input tree.
 *
 * @param ctx classifier context holding the configuration
 * @param input tree of tokens
 * @param nodes output: effective token count (0 if the option is not set)
 * @param minnodes output: configured minimum (0 if the option is not set)
 * @return TRUE if the input may be processed, FALSE if it has too few tokens
 */
static gboolean
bayes_has_enough_tokens (struct classifier_ctx *ctx, GTree *input, gint *nodes, gint *minnodes)
{
	gchar *value;

	*nodes = 0;
	*minnodes = 0;

	if (ctx->cfg->opts == NULL ||
			(value = g_hash_table_lookup (ctx->cfg->opts, "min_tokens")) == NULL) {
		/* No limit configured */
		return TRUE;
	}

	*minnodes = strtol (value, NULL, 10);
	*nodes = g_tree_nnodes (input);
	if (*nodes > FEATURE_WINDOW_SIZE) {
		/* Large inputs are scaled down by the window size before comparing */
		*nodes = *nodes / FEATURE_WINDOW_SIZE + FEATURE_WINDOW_SIZE;
	}

	return *nodes >= *minnodes;
}

/*
 * Read the optional "max_tokens" classifier option.
 *
 * @param ctx classifier context holding the configuration
 * @return configured limit or 0 if the option is not set
 */
static gsize
bayes_get_max_tokens (struct classifier_ctx *ctx)
{
	gchar *value;
	gint limit;

	if (ctx->cfg->opts == NULL ||
			(value = g_hash_table_lookup (ctx->cfg->opts, "max_tokens")) == NULL) {
		return 0;
	}

	/* The parsed limit is narrowed through gint, as all callers always did */
	limit = parse_limit (value, -1);

	return limit;
}

/*
 * Open every statfile of the list for classification and fill data->statfiles
 * with the ones that could be opened. Statfiles that cannot be opened are
 * skipped and data->statfiles_num is decreased accordingly.
 *
 * @param pool statfile pool
 * @param cur list of struct statfile entries
 * @param data callback data whose statfiles array holds at least
 *             data->statfiles_num elements
 * @return number of statfiles stored in data->statfiles
 */
static gint
bayes_open_statfiles (statfile_pool_t *pool, GList *cur, struct bayes_callback_data *data)
{
	struct statfile *st;
	stat_file_t *file;
	gint i = 0;

	for (; cur != NULL; cur = g_list_next (cur)) {
		st = cur->data;
		if ((file = statfile_pool_is_open (pool, st->path)) == NULL &&
				(file = statfile_pool_open (pool, st->path, st->size, FALSE)) == NULL) {
			msg_warn ("cannot open %s", st->path);
			data->statfiles_num --;
			continue;
		}
		data->statfiles[i].file = file;
		data->statfiles[i].st = st;
		data->statfiles[i].post_probability = 0.5;
		data->statfiles[i].local_probability = 0.5;
		i ++;
	}

	return i;
}

/*
 * Get an opened statfile for learning, creating it on disk if it cannot be
 * opened.
 *
 * @param pool statfile pool
 * @param st statfile configuration entry
 * @param err error object filled if the statfile cannot be created or opened
 * @return opened statfile or NULL on failure (err is set in this case)
 */
static stat_file_t *
bayes_open_or_create_statfile (statfile_pool_t *pool, struct statfile *st, GError **err)
{
	stat_file_t *file;

	if ((file = statfile_pool_is_open (pool, st->path)) != NULL) {
		return file;
	}
	if ((file = statfile_pool_open (pool, st->path, st->size, FALSE)) != NULL) {
		return file;
	}

	msg_warn ("cannot open %s", st->path);
	if (statfile_pool_create (pool, st->path, st->size) == -1) {
		msg_err ("cannot create statfile %s", st->path);
		g_set_error (err,
					bayes_error_quark(),		/* error domain */
					1,							/* error code */
					"cannot create statfile: %s",
					st->path);
		return NULL;
	}
	if ((file = statfile_pool_open (pool, st->path, st->size, FALSE)) == NULL) {
		g_set_error (err,
					bayes_error_quark(),		/* error domain */
					1,							/* error code */
					"cannot open statfile %s after creation",
					st->path);
		msg_err ("cannot open statfile %s after creation", st->path);
		return NULL;
	}

	return file;
}

/*
 * GTree traversal callback that learns a single token into cd->file.
 *
 * @param key token node (token_node_t *)
 * @param value unused
 * @param data struct bayes_callback_data *
 * @return TRUE to stop the traversal (token limit exceeded), FALSE to continue
 */
static gboolean
bayes_learn_callback (gpointer key, gpointer value, gpointer data)
{
	token_node_t *node = key;
	struct bayes_callback_data *cd = data;
	gint c;
	guint64 v;

	c = (cd->in_class) ? 1 : -1;

	v = statfile_pool_get_block (cd->pool, cd->file, node->h1, node->h2, cd->now);
	if (v == 0 && c > 0) {
		/* Block has no value yet: the first positive hit stores 1 */
		statfile_pool_set_block (cd->pool, cd->file, node->h1, node->h2, cd->now, c);
		cd->learned_tokens ++;
	}
	else if (v != 0) {
		if (G_LIKELY (c > 0)) {
			v ++;
		}
		else {
			/* v is known to be non-zero here, so this cannot wrap around */
			v --;
		}
		statfile_pool_set_block (cd->pool, cd->file, node->h1, node->h2, cd->now, v);
		cd->learned_tokens ++;
	}
	/* A zero block is left untouched when unlearning */

	if (cd->max_tokens != 0 && cd->learned_tokens > cd->max_tokens) {
		/* Stop learning on max tokens */
		return TRUE;
	}

	return FALSE;
}

/*
 * GTree traversal callback that updates probabilities of all statfiles for a
 * single token: a local probability is computed for each statfile from the
 * token's block values and then folded into the running post probability,
 * which is renormalized so that probabilities of all statfiles sum to 1.
 *
 * @param key token node (token_node_t *)
 * @param value unused
 * @param data struct bayes_callback_data *
 * @return TRUE to stop the traversal (token limit exceeded), FALSE to continue
 */
static gboolean
bayes_classify_callback (gpointer key, gpointer value, gpointer data)
{
	token_node_t *node = key;
	struct bayes_callback_data *cd = data;
	double renorm = 0;
	guint i;
	double local_hits = 0;
	struct bayes_statfile_data *cur;

	/* Read the token value from each statfile and sum hits over all of them */
	for (i = 0; i < cd->statfiles_num; i ++) {
		cur = &cd->statfiles[i];
		cur->value = statfile_pool_get_block (cd->pool, cur->file, node->h1, node->h2, cd->now);
		if (cur->value > 0) {
			cur->total_hits ++;
			cur->hits = cur->value;
			local_hits += cur->value;
		}
	}

	/* Local probability: 0.5 shifted by own hits versus hits in other statfiles */
	for (i = 0; i < cd->statfiles_num; i ++) {
		cur = &cd->statfiles[i];
		cur->local_probability = 0.5 + (cur->value - (local_hits - cur->value)) /
				(LOCAL_PROB_DENOM * (1.0 + local_hits));
		renorm += cur->post_probability * cur->local_probability;
	}

	/* Fold local probability into post probability, clamping away from zero */
	for (i = 0; i < cd->statfiles_num; i ++) {
		cur = &cd->statfiles[i];
		cur->post_probability = (cur->post_probability * cur->local_probability) / renorm;
		if (cur->post_probability < G_MINDOUBLE * 100) {
			cur->post_probability = G_MINDOUBLE * 100;
		}
	}

	renorm = 0;
	for (i = 0; i < cd->statfiles_num; i ++) {
		cur = &cd->statfiles[i];
		renorm += cur->post_probability;
	}

	/* Renormalize to form sum of probabilities equal to 1 */
	for (i = 0; i < cd->statfiles_num; i ++) {
		cur = &cd->statfiles[i];
		cur->post_probability /= renorm;
		if (cur->post_probability < G_MINDOUBLE * 10) {
			cur->post_probability = G_MINDOUBLE * 100;
		}
		if (cd->ctx->debug) {
			/* NOTE(review): "probability" prints the raw block value, not local_probability */
			msg_info ("token: %s, statfile: %s, probability: %.4f, post_probability: %.4f",
					node->extra, cur->st->symbol, cur->value, cur->post_probability);
		}
	}

	cd->learned_tokens ++;
	if (cd->max_tokens != 0 && cd->learned_tokens > cd->max_tokens) {
		/* Stop classifying on max tokens */
		return TRUE;
	}

	return FALSE;
}

/*
 * Create a classifier context allocated from the memory pool.
 *
 * @param pool memory pool the context is allocated from
 * @param cfg classifier configuration
 * @return new context with debugging turned off
 */
struct classifier_ctx*
bayes_init (memory_pool_t *pool, struct classifier_config *cfg)
{
	struct classifier_ctx *ctx = memory_pool_alloc (pool, sizeof (struct classifier_ctx));

	ctx->pool = pool;
	ctx->cfg = cfg;
	ctx->debug = FALSE;

	return ctx;
}

/*
 * Classify tokens and insert the symbol of the most probable statfile into
 * the task result if its probability is greater than 0.5.
 *
 * @param ctx classifier context
 * @param pool statfile pool
 * @param input tree of tokens
 * @param task worker task that receives the result
 * @return FALSE if input has fewer tokens than "min_tokens", TRUE otherwise
 */
gboolean
bayes_classify (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task)
{
	struct bayes_callback_data data;
	gint nodes, minnodes, i, cnt, best_num = 0;
	guint64 rev = 0, total_learns = 0;
	double best = 0;
	GList *cur;
	char *sumbuf;

	g_assert (pool != NULL);
	g_assert (ctx != NULL);

	if (!bayes_has_enough_tokens (ctx, input, &nodes, &minnodes)) {
		return FALSE;
	}

	/* Pre-callbacks may select statfiles, otherwise use all configured ones */
	cur = call_classifier_pre_callbacks (ctx->cfg, task, FALSE, FALSE);
	if (cur) {
		memory_pool_add_destructor (task->task_pool, (pool_destruct_func)g_list_free, cur);
	}
	else {
		cur = ctx->cfg->statfiles;
	}

	/* Zero everything so that no field is ever read uninitialized */
	memset (&data, 0, sizeof (data));
	data.statfiles_num = g_list_length (cur);
	data.statfiles = g_new0 (struct bayes_statfile_data, data.statfiles_num);
	data.pool = pool;
	data.now = time (NULL);
	data.ctx = ctx;
	data.learned_tokens = 0;
	data.max_tokens = bayes_get_max_tokens (ctx);

	cnt = bayes_open_statfiles (pool, cur, &data);

	for (i = 0; i < cnt; i ++) {
		statfile_get_revision (data.statfiles[i].file, &rev, NULL);
		total_learns += rev;
	}

	/* Calculate correction factor */
	for (i = 0; i < cnt; i ++) {
		statfile_get_revision (data.statfiles[i].file, &rev, NULL);
		/* Nothing learned at all: avoid dividing by zero */
		data.statfiles[i].corr = (total_learns != 0) ?
				((double)rev / cnt) / (double)total_learns : 0;
	}

	g_tree_foreach (input, bayes_classify_callback, &data);

	for (i = 0; i < cnt; i ++) {
		debug_task ("got probability for symbol %s: %.2f", data.statfiles[i].st->symbol,
				data.statfiles[i].post_probability);
		if (data.statfiles[i].post_probability > best) {
			best = data.statfiles[i].post_probability;
			best_num = i;
		}
	}

	if (best > 0.5) {
		/* The probability is also attached to the symbol as a text option */
		sumbuf = memory_pool_alloc (task->task_pool, 32);
		rspamd_snprintf (sumbuf, 32, "%.2f", best);
		cur = g_list_prepend (NULL, sumbuf);
		insert_result (task, data.statfiles[best_num].st->symbol, best, cur);
	}

	g_free (data.statfiles);

	return TRUE;
}

/*
 * Learn tokens into the statfile that is configured for the specified symbol.
 *
 * @param ctx classifier context
 * @param pool statfile pool
 * @param symbol symbol of the statfile to learn
 * @param input tree of tokens
 * @param in_class TRUE to increment token values, FALSE to decrement them
 * @param sum optional output (may be NULL): number of learned tokens, or 0 if
 *            the message has too few tokens
 * @param multiplier unused
 * @param err error object filled on failure
 * @return TRUE on success, FALSE on failure (err is set in this case)
 */
gboolean
bayes_learn (struct classifier_ctx* ctx, statfile_pool_t *pool, const char *symbol, GTree *input,
		gboolean in_class, double *sum, double multiplier, GError **err)
{
	struct bayes_callback_data data;
	gint nodes, minnodes;
	struct statfile *st, *sel_st = NULL;
	stat_file_t *to_learn;
	GList *cur;

	g_assert (pool != NULL);
	g_assert (ctx != NULL);

	if (!bayes_has_enough_tokens (ctx, input, &nodes, &minnodes)) {
		msg_info ("do not learn message as it has too few tokens: %d, while %d min", nodes, minnodes);
		/* sum is optional, so it must not be dereferenced blindly */
		if (sum != NULL) {
			*sum = 0;
		}
		g_set_error (err,
					bayes_error_quark(),		/* error domain */
					1,							/* error code */
					"message contains too few tokens: %d, while min is %d",
					nodes, (int)minnodes);
		return FALSE;
	}

	memset (&data, 0, sizeof (data));
	data.pool = pool;
	data.in_class = in_class;
	data.now = time (NULL);
	data.ctx = ctx;
	data.learned_tokens = 0;
	data.max_tokens = bayes_get_max_tokens (ctx);

	/* Select statfile to learn */
	for (cur = ctx->cfg->statfiles; cur != NULL; cur = g_list_next (cur)) {
		st = cur->data;
		if (strcmp (st->symbol, symbol) == 0) {
			sel_st = st;
			break;
		}
	}
	if (sel_st == NULL) {
		g_set_error (err,
					bayes_error_quark(),		/* error domain */
					1,							/* error code */
					"cannot find statfile for symbol: %s",
					symbol);
		return FALSE;
	}

	if ((to_learn = bayes_open_or_create_statfile (pool, sel_st, err)) == NULL) {
		return FALSE;
	}
	data.file = to_learn;

	/* The file stays locked while all tokens are written and revision is bumped */
	statfile_pool_lock_file (pool, data.file);
	g_tree_foreach (input, bayes_learn_callback, &data);
	statfile_inc_revision (to_learn);
	statfile_pool_unlock_file (pool, data.file);

	if (sum != NULL) {
		*sum = data.learned_tokens;
	}

	return TRUE;
}

/*
 * Learn tokens into every statfile whose spam flag matches is_spam.
 * Statfiles that cannot be opened are created first.
 *
 * @param ctx classifier context
 * @param pool statfile pool
 * @param input tree of tokens
 * @param task worker task (used for pre-callbacks)
 * @param is_spam class of statfiles to learn
 * @param err error object filled on failure
 * @return TRUE on success, FALSE on failure (err is set in this case)
 */
gboolean
bayes_learn_spam (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input,
		struct worker_task *task, gboolean is_spam, GError **err)
{
	struct bayes_callback_data data;
	gint nodes, minnodes;
	struct statfile *st;
	stat_file_t *file;
	GList *cur;

	g_assert (pool != NULL);
	g_assert (ctx != NULL);

	if (!bayes_has_enough_tokens (ctx, input, &nodes, &minnodes)) {
		g_set_error (err,
					bayes_error_quark(),		/* error domain */
					1,							/* error code */
					"message contains too few tokens: %d, while min is %d",
					nodes, (int)minnodes);
		return FALSE;
	}

	/* Pre-callbacks may select statfiles, otherwise use all configured ones */
	cur = call_classifier_pre_callbacks (ctx->cfg, task, FALSE, FALSE);
	if (cur) {
		memory_pool_add_destructor (task->task_pool, (pool_destruct_func)g_list_free, cur);
	}
	else {
		cur = ctx->cfg->statfiles;
	}

	memset (&data, 0, sizeof (data));
	data.pool = pool;
	data.now = time (NULL);
	data.ctx = ctx;
	/*
	 * Only statfiles of the matching class are visited, so tokens are always
	 * added. The learn callback reads this flag and it must not be left
	 * uninitialized.
	 */
	data.in_class = TRUE;
	data.learned_tokens = 0;
	data.max_tokens = bayes_get_max_tokens (ctx);

	for (; cur != NULL; cur = g_list_next (cur)) {
		/* Select statfiles to learn */
		st = cur->data;
		if (st->is_spam != is_spam) {
			continue;
		}

		/* A freshly created statfile is learned like any other one */
		if ((file = bayes_open_or_create_statfile (pool, st, err)) == NULL) {
			return FALSE;
		}
		data.file = file;

		statfile_pool_lock_file (pool, data.file);
		g_tree_foreach (input, bayes_learn_callback, &data);
		statfile_inc_revision (file);
		statfile_pool_unlock_file (pool, data.file);
		maybe_write_binlog (ctx->cfg, st, file, input);
	}

	return TRUE;
}

/*
 * Calculate probabilities of tokens for all configured statfiles.
 *
 * @param ctx classifier context
 * @param pool statfile pool
 * @param input tree of tokens
 * @param task worker task whose memory pool owns the result
 * @return list of struct classify_weight (freed together with the task pool),
 *         or NULL if input has too few tokens or no statfile could be opened
 */
GList *
bayes_weights (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task)
{
	struct bayes_callback_data data;
	gint nodes, minnodes, i, cnt;
	struct classify_weight *w;
	GList *resl = NULL;

	g_assert (pool != NULL);
	g_assert (ctx != NULL);

	if (!bayes_has_enough_tokens (ctx, input, &nodes, &minnodes)) {
		return NULL;
	}

	memset (&data, 0, sizeof (data));
	data.statfiles_num = g_list_length (ctx->cfg->statfiles);
	data.statfiles = g_new0 (struct bayes_statfile_data, data.statfiles_num);
	data.pool = pool;
	data.now = time (NULL);
	data.ctx = ctx;
	/*
	 * The classify callback counts tokens and checks the limit, so both
	 * fields must be initialized before the traversal.
	 */
	data.learned_tokens = 0;
	data.max_tokens = bayes_get_max_tokens (ctx);

	cnt = bayes_open_statfiles (pool, ctx->cfg->statfiles, &data);

	g_tree_foreach (input, bayes_classify_callback, &data);

	for (i = 0; i < cnt; i ++) {
		w = memory_pool_alloc0 (task->task_pool, sizeof (struct classify_weight));
		w->name = data.statfiles[i].st->symbol;
		w->weight = data.statfiles[i].post_probability;
		resl = g_list_prepend (resl, w);
	}

	g_free (data.statfiles);

	if (resl != NULL) {
		/* List cells are released together with the task pool */
		memory_pool_add_destructor (task->task_pool, (pool_destruct_func)g_list_free, resl);
	}

	return resl;
}