diff options
Diffstat (limited to 'src/classifiers')
-rw-r--r-- | src/classifiers/bayes.c | 383 | ||||
-rw-r--r-- | src/classifiers/classifiers.c | 7 | ||||
-rw-r--r-- | src/classifiers/classifiers.h | 10 | ||||
-rw-r--r-- | src/classifiers/winnow.c | 20 |
4 files changed, 409 insertions, 11 deletions
diff --git a/src/classifiers/bayes.c b/src/classifiers/bayes.c new file mode 100644 index 000000000..1ce36e0bd --- /dev/null +++ b/src/classifiers/bayes.c @@ -0,0 +1,383 @@ +/* + * Copyright (c) 2009, Rambler media + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY Rambler media ''AS IS'' AND ANY + * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL Rambler BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +/* + * Bayesian classifier + */ +#include "classifiers.h" +#include "../tokenizers/tokenizers.h" +#include "../main.h" +#include "../filter.h" +#include "../cfg_file.h" +#ifdef WITH_LUA +#include "../lua/lua_common.h" +#endif + +#define LOCAL_PROB_DENOM 16.0 + +G_INLINE_FUNC GQuark +bayes_error_quark (void) +{ + return g_quark_from_static_string ("bayes-error"); +} + +struct bayes_statfile_data { + double hits; + double total_hits; + double local_probability; + double post_probability; + double value; + struct statfile *st; + stat_file_t *file; +}; + +struct bayes_callback_data { + statfile_pool_t *pool; + struct classifier_ctx *ctx; + gboolean in_class; + time_t now; + stat_file_t *file; + struct bayes_statfile_data *statfiles; + guint32 statfiles_num; +}; + +static gboolean +bayes_learn_callback (gpointer key, gpointer value, gpointer data) +{ + token_node_t *node = key; + struct bayes_callback_data *cd = data; + double v, c; + + c = (cd->in_class) ? 1 : -1; + + /* Consider that not found blocks have value 1 */ + v = statfile_pool_get_block (cd->pool, cd->file, node->h1, node->h2, cd->now); + if (fabs (v) < ALPHA && c > 0) { + statfile_pool_set_block (cd->pool, cd->file, node->h1, node->h2, cd->now, c); + } + else { + if (G_LIKELY (c > 0 && c < G_MAXDOUBLE)) { + v += c; + } + else if (c < 0){ + if (v > -c) { + v -= c; + } + else { + v = 0; + } + } + statfile_pool_set_block (cd->pool, cd->file, node->h1, node->h2, cd->now, v); + } + + return FALSE; +} + +/* + * In this callback we calculate local probabilities for tokens + */ +static gboolean +bayes_classify_callback (gpointer key, gpointer value, gpointer data) +{ + + token_node_t *node = key; + struct bayes_callback_data *cd = data; + double local_hits = 0, renorm = 0; + int i; + struct bayes_statfile_data *cur; + + for (i = 0; i < cd->statfiles_num; i ++) { + cur = &cd->statfiles[i]; + cur->value = statfile_pool_get_block (cd->pool, cur->file, node->h1, node->h2, cd->now); + if (cur->value > ALPHA) { + cur->total_hits += cur->value; + cur->hits = cur->value; + local_hits += cur->value; + } + else { + cur->value = 0; + } + } + for (i = 0; i < cd->statfiles_num; i ++) { + cur = &cd->statfiles[i]; + cur->local_probability = 0.5 + (cur->value - (local_hits - cur->value)) / (LOCAL_PROB_DENOM * (local_hits + 1.0)); + renorm += cur->post_probability * cur->local_probability; + } + + for (i = 0; i < cd->statfiles_num; i ++) { + cur = &cd->statfiles[i]; + cur->post_probability = (cur->post_probability * cur->local_probability) / renorm; + if (cur->post_probability < G_MINDOUBLE * 100) { + cur->post_probability = G_MINDOUBLE * 100; + } + } + renorm = 0; + for (i = 0; i < cd->statfiles_num; i ++) { + cur = &cd->statfiles[i]; + renorm += cur->post_probability; + } + /* Renormalize to form sum of probabilities equal to 1 */ + for (i = 0; i < cd->statfiles_num; i ++) { + cur = &cd->statfiles[i]; + cur->post_probability /= renorm; + if (cur->post_probability < G_MINDOUBLE * 10) { + cur->post_probability = G_MINDOUBLE * 100; + } + } + + return FALSE; +} + +struct classifier_ctx* +bayes_init (memory_pool_t *pool, struct classifier_config *cfg) +{ + struct classifier_ctx *ctx = memory_pool_alloc (pool, sizeof (struct classifier_ctx)); + + ctx->pool = pool; + ctx->cfg = cfg; + + return ctx; +} + +gboolean +bayes_classify (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task) +{ + struct bayes_callback_data data; + char *value; + int nodes, minnodes, i, cnt, best_num = 0; + double best = 0; + struct statfile *st; + stat_file_t *file; + GList *cur; + char *sumbuf; + + g_assert (pool != NULL); + g_assert (ctx != NULL); + + if (ctx->cfg->opts && (value = g_hash_table_lookup (ctx->cfg->opts, "min_tokens")) != NULL) { + minnodes = strtol (value, NULL, 10); + nodes = g_tree_nnodes (input) / FEATURE_WINDOW_SIZE; + if (nodes < minnodes) { + return FALSE; + } + } + + data.statfiles_num = g_list_length (ctx->cfg->statfiles); + data.statfiles = g_new0 (struct bayes_statfile_data, data.statfiles_num); + data.pool = pool; + data.now = time (NULL); + data.ctx = ctx; + + cur = ctx->cfg->statfiles; + i = 0; + while (cur) { + /* Select statfile to learn */ + st = cur->data; + if ((file = statfile_pool_is_open (pool, st->path)) == NULL) { + if ((file = statfile_pool_open (pool, st->path, st->size, FALSE)) == NULL) { + msg_warn ("cannot open %s", st->path); + cur = g_list_next (cur); + data.statfiles_num --; + continue; + } + } + data.statfiles[i].file = file; + data.statfiles[i].st = st; + data.statfiles[i].post_probability = 0.5; + data.statfiles[i].local_probability = 0.5; + i ++; + cur = g_list_next (cur); + } + cnt = i; + + g_tree_foreach (input, bayes_classify_callback, &data); + + for (i = 0; i < cnt; i ++) { + debug_task ("got probability for symbol %s: %.2f", data.statfiles[i].st->symbol, data.statfiles[i].post_probability); + if (data.statfiles[i].post_probability > best) { + best = data.statfiles[i].post_probability; + best_num = i; + } + } + + if (best > 0.5) { + sumbuf = memory_pool_alloc (task->task_pool, 32); + rspamd_snprintf (sumbuf, 32, "%.2f", best); + cur = g_list_prepend (NULL, sumbuf); + insert_result (task, data.statfiles[best_num].st->symbol, best, cur); + } + + g_free (data.statfiles); + + return TRUE; +} + +gboolean +bayes_learn (struct classifier_ctx* ctx, statfile_pool_t *pool, const char *symbol, GTree *input, + gboolean in_class, double *sum, double multiplier, GError **err) +{ + struct bayes_callback_data data; + char *value; + int nodes, minnodes; + struct statfile *st, *sel_st = NULL; + stat_file_t *to_learn; + GList *cur; + + g_assert (pool != NULL); + g_assert (ctx != NULL); + + if (ctx->cfg->opts && (value = g_hash_table_lookup (ctx->cfg->opts, "min_tokens")) != NULL) { + minnodes = strtol (value, NULL, 10); + nodes = g_tree_nnodes (input) / FEATURE_WINDOW_SIZE; + if (nodes < minnodes) { + msg_info ("do not learn message as it has too few tokens: %d, while %d min", nodes, minnodes); + *sum = 0; + g_set_error (err, + bayes_error_quark(), /* error domain */ + 1, /* error code */ + "message contains too few tokens: %d, while min is %d", + nodes, minnodes); + return FALSE; + } + } + + data.pool = pool; + data.in_class = in_class; + data.now = time (NULL); + data.ctx = ctx; + cur = ctx->cfg->statfiles; + while (cur) { + /* Select statfile to learn */ + st = cur->data; + if (strcmp (st->symbol, symbol) == 0) { + sel_st = st; + break; + } + cur = g_list_next (cur); + } + if (sel_st == NULL) { + g_set_error (err, + bayes_error_quark(), /* error domain */ + 1, /* error code */ + "cannot find statfile for symbol: %s", + symbol); + } + if ((to_learn = statfile_pool_is_open (pool, sel_st->path)) == NULL) { + if ((to_learn = statfile_pool_open (pool, sel_st->path, sel_st->size, FALSE)) == NULL) { + msg_warn ("cannot open %s", sel_st->path); + if (statfile_pool_create (pool, sel_st->path, sel_st->size) == -1) { + msg_err ("cannot create statfile %s", sel_st->path); + g_set_error (err, + bayes_error_quark(), /* error domain */ + 1, /* error code */ + "cannot create statfile: %s", + sel_st->path); + return FALSE; + } + if ((to_learn = statfile_pool_open (pool, sel_st->path, sel_st->size, FALSE)) == NULL) { + g_set_error (err, + bayes_error_quark(), /* error domain */ + 1, /* error code */ + "cannot open statfile %s after creation", + sel_st->path); + msg_err ("cannot open statfile %s after creation", sel_st->path); + return FALSE; + } + } + } + data.file = to_learn; + statfile_pool_lock_file (pool, data.file); + g_tree_foreach (input, bayes_learn_callback, &data); + statfile_pool_unlock_file (pool, data.file); + + return TRUE; +} + +GList * +bayes_weights (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task) +{ + struct bayes_callback_data data; + char *value; + int nodes, minnodes, i, cnt; + struct classify_weight *w; + struct statfile *st; + stat_file_t *file; + GList *cur, *resl = NULL; + + g_assert (pool != NULL); + g_assert (ctx != NULL); + + if (ctx->cfg->opts && (value = g_hash_table_lookup (ctx->cfg->opts, "min_tokens")) != NULL) { + minnodes = strtol (value, NULL, 10); + nodes = g_tree_nnodes (input) / FEATURE_WINDOW_SIZE; + if (nodes < minnodes) { + return NULL; + } + } + + data.statfiles_num = g_list_length (ctx->cfg->statfiles); + data.statfiles = g_new0 (struct bayes_statfile_data, data.statfiles_num); + data.pool = pool; + data.now = time (NULL); + data.ctx = ctx; + + cur = ctx->cfg->statfiles; + i = 0; + while (cur) { + /* Select statfile to learn */ + st = cur->data; + if ((file = statfile_pool_is_open (pool, st->path)) == NULL) { + if ((file = statfile_pool_open (pool, st->path, st->size, FALSE)) == NULL) { + msg_warn ("cannot open %s", st->path); + cur = g_list_next (cur); + data.statfiles_num --; + continue; + } + } + data.statfiles[i].file = file; + data.statfiles[i].st = st; + data.statfiles[i].post_probability = 0.5; + data.statfiles[i].local_probability = 0.5; + i ++; + cur = g_list_next (cur); + } + cnt = i; + + g_tree_foreach (input, bayes_classify_callback, &data); + + for (i = 0; i < cnt; i ++) { + w = memory_pool_alloc0 (task->task_pool, sizeof (struct classify_weight)); + w->name = data.statfiles[i].st->symbol; + w->weight = data.statfiles[i].post_probability; + resl = g_list_prepend (resl, w); + } + + g_free (data.statfiles); + + if (resl != NULL) { + memory_pool_add_destructor (task->task_pool, (pool_destruct_func)g_list_free, resl); + } + + return resl; +} diff --git a/src/classifiers/classifiers.c b/src/classifiers/classifiers.c index 219576870..6b0554e1b 100644 --- a/src/classifiers/classifiers.c +++ b/src/classifiers/classifiers.c @@ -36,6 +36,13 @@ struct classifier classifiers[] = { .classify_func = winnow_classify, .learn_func = winnow_learn, .weights_func = winnow_weights + }, + { + .name = "bayes", + .init_func = bayes_init, + .classify_func = bayes_classify, + .learn_func = bayes_learn, + .weights_func = bayes_weights } }; diff --git a/src/classifiers/classifiers.h b/src/classifiers/classifiers.h index f69c1284c..0e6df173a 100644 --- a/src/classifiers/classifiers.h +++ b/src/classifiers/classifiers.h @@ -6,6 +6,9 @@ #include "../statfile.h" #include "../tokenizers/tokenizers.h" +/* Consider this value as 0 */ +#define ALPHA 0.0001 + struct classifier_config; struct worker_task; @@ -41,7 +44,12 @@ gboolean winnow_learn (struct classifier_ctx* ctx, statfile_pool_t *pool, const gboolean in_class, double *sum, double multiplier, GError **err); GList *winnow_weights (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task); - +/* Bayes algorithm */ +struct classifier_ctx* bayes_init (memory_pool_t *pool, struct classifier_config *cf); +gboolean bayes_classify (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task); +gboolean bayes_learn (struct classifier_ctx* ctx, statfile_pool_t *pool, const char *symbol, GTree *input, + gboolean in_class, double *sum, double multiplier, GError **err); +GList *bayes_weights (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task); /* Array of all defined classifiers */ extern struct classifier classifiers[]; diff --git a/src/classifiers/winnow.c b/src/classifiers/winnow.c index 704f65b0a..f8c104a52 100644 --- a/src/classifiers/winnow.c +++ b/src/classifiers/winnow.c @@ -42,14 +42,14 @@ #define MAX_WEIGHT G_MAXDOUBLE / 2. -#define ALPHA 0.01 + #define MAX_LEARN_ITERATIONS 100 G_INLINE_FUNC GQuark winnow_error_quark (void) { - return g_quark_from_static_string ("winnow-error-quark"); + return g_quark_from_static_string ("winnow-error"); } struct winnow_callback_data { @@ -73,7 +73,7 @@ static const double max_common_weight = MAX_WEIGHT * WINNOW_DEMOTION; static gboolean -classify_callback (gpointer key, gpointer value, gpointer data) +winnow_classify_callback (gpointer key, gpointer value, gpointer data) { token_node_t *node = key; struct winnow_callback_data *cd = data; @@ -95,7 +95,7 @@ classify_callback (gpointer key, gpointer value, gpointer data) } static gboolean -learn_callback (gpointer key, gpointer value, gpointer data) +winnow_learn_callback (gpointer key, gpointer value, gpointer data) { token_node_t *node = key; struct winnow_callback_data *cd = data; @@ -247,7 +247,7 @@ winnow_classify (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * inp } if (data.file != NULL) { - g_tree_foreach (input, classify_callback, &data); + g_tree_foreach (input, winnow_classify_callback, &data); } if (data.count != 0) { @@ -320,7 +320,7 @@ winnow_weights (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * inpu } if (data.file != NULL) { - g_tree_foreach (input, classify_callback, &data); + g_tree_foreach (input, winnow_classify_callback, &data); } w = memory_pool_alloc0 (task->task_pool, sizeof (struct classify_weight)); @@ -407,7 +407,7 @@ winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, const char *sym st->path); return FALSE; } - if (statfile_pool_open (pool, st->path, st->size, FALSE)) { + if (statfile_pool_open (pool, st->path, st->size, FALSE) == NULL) { g_set_error (err, winnow_error_quark(), /* error domain */ 1, /* error code */ @@ -438,7 +438,7 @@ winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, const char *sym data.sum = 0; data.count = 0; data.new_blocks = 0; - g_tree_foreach (input, classify_callback, &data); + g_tree_foreach (input, winnow_classify_callback, &data); if (data.count > 0) { max = data.sum / (double)data.count; } @@ -462,7 +462,7 @@ winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, const char *sym st->path); return FALSE; } - g_tree_foreach (input, classify_callback, &data); + g_tree_foreach (input, winnow_classify_callback, &data); if (data.count != 0) { res = data.sum / data.count; } @@ -513,7 +513,7 @@ winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, const char *sym } statfile_pool_lock_file (pool, data.file); - g_tree_foreach (input, learn_callback, &data); + g_tree_foreach (input, winnow_learn_callback, &data); statfile_pool_unlock_file (pool, data.file); if (data.count != 0) { res = data.sum / data.count; |