]> source.dussan.org Git - rspamd.git/commitdiff
* Add bayesian classifier (initial version)
authorVsevolod Stakhov <vsevolod@rambler-co.ru>
Fri, 13 Aug 2010 14:50:29 +0000 (18:50 +0400)
committerVsevolod Stakhov <vsevolod@rambler-co.ru>
Fri, 13 Aug 2010 14:50:29 +0000 (18:50 +0400)
CMakeLists.txt
config.h.in
src/classifiers/bayes.c [new file with mode: 0644]
src/classifiers/classifiers.c
src/classifiers/classifiers.h
src/classifiers/winnow.c
src/statfile.c

index 2eac6ce598714b1a097276a471b6c15380b704d6..af324b67b3b6ec663eae103891eb49c2f9a61d48 100644 (file)
@@ -334,6 +334,7 @@ CHECK_INCLUDE_FILES(netdb.h  HAVE_NETDB_H)
 CHECK_INCLUDE_FILES(syslog.h HAVE_SYSLOG_H)
 CHECK_INCLUDE_FILES(locale.h HAVE_LOCALE_H)
 CHECK_INCLUDE_FILES(libgen.h HAVE_LIBGEN_H)
+CHECK_INCLUDE_FILES(search.h HAVE_SEARCH_H)
 CHECK_INCLUDE_FILES(pwd.h HAVE_PWD_H)
 CHECK_INCLUDE_FILES(grp.h HAVE_GRP_H)
 CHECK_INCLUDE_FILES(glob.h HAVE_GLOB_H)
@@ -498,6 +499,7 @@ SET(TOKENIZERSSRC  src/tokenizers/tokenizers.c
                                src/tokenizers/osb.c)
 
 SET(CLASSIFIERSSRC src/classifiers/classifiers.c
+                src/classifiers/bayes.c
                                src/classifiers/winnow.c)
 
 SET(PLUGINSSRC src/plugins/surbl.c
index 15b6295072d8af77a8dbeaf5b324f5ce6a205649..645310a5e4a565d7ac823102392093085adf9df1 100644 (file)
@@ -44,6 +44,8 @@
 
 #cmakedefine HAVE_LIBGEN_H       1
 
+#cmakedefine HAVE_SEARCH_H       1
+
 #cmakedefine HAVE_LOCALE_H       1
 
 #cmakedefine HAVE_GRP_H          1
 #define HAVE_DIRNAME 1
 #endif
 
+#ifdef HAVE_SEARCH_H
+#include <search.h>
+#endif
+
 #ifdef HAVE_LOCALE_H
 #include <locale.h>
 #define HAVE_SETLOCALE 1
diff --git a/src/classifiers/bayes.c b/src/classifiers/bayes.c
new file mode 100644 (file)
index 0000000..1ce36e0
--- /dev/null
@@ -0,0 +1,383 @@
+/*
+ * Copyright (c) 2009, Rambler media
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *     * Redistributions of source code must retain the above copyright
+ *       notice, this list of conditions and the following disclaimer.
+ *     * Redistributions in binary form must reproduce the above copyright
+ *       notice, this list of conditions and the following disclaimer in the
+ *       documentation and/or other materials provided with the distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY Rambler media ''AS IS'' AND ANY
+ * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL Rambler BE LIABLE FOR ANY
+ * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+/*
+ * Bayesian classifier
+ */
+#include "classifiers.h"
+#include "../tokenizers/tokenizers.h"
+#include "../main.h"
+#include "../filter.h"
+#include "../cfg_file.h"
+#ifdef WITH_LUA
+#include "../lua/lua_common.h"
+#endif
+
+#define LOCAL_PROB_DENOM 16.0
+
+/*
+ * Error domain attached to every GError reported by the bayes classifier.
+ */
+G_INLINE_FUNC GQuark
+bayes_error_quark (void)
+{
+       GQuark                          domain = g_quark_from_static_string ("bayes-error");
+
+       return domain;
+}
+
+/*
+ * Per-statfile state kept while a message is being classified.
+ * One element exists for every statfile that could be opened.
+ */
+struct bayes_statfile_data {
+       /* Value of the most recent token whose value exceeded ALPHA */
+       double                          hits;
+       /* Sum of all token values that exceeded ALPHA */
+       double                          total_hits;
+       /* Probability derived from the current token only (starts at 0.5) */
+       double                          local_probability;
+       /* Running posterior probability, renormalized after each token (starts at 0.5) */
+       double                          post_probability;
+       /* Value read for the current token; reset to 0 when not above ALPHA */
+       double                          value;
+       /* Configuration entry of the statfile (its symbol, path and size are used) */
+       struct statfile                *st;
+       /* Opened statfile the token values are read from */
+       stat_file_t                    *file;
+};
+
+/*
+ * Context handed to the GTree traversal callbacks of this classifier.
+ * bayes_learn fills pool/ctx/in_class/now/file; bayes_classify and
+ * bayes_weights fill pool/ctx/now/statfiles/statfiles_num.
+ */
+struct bayes_callback_data {
+       /* Statfile pool passed to statfile_pool_get_block/statfile_pool_set_block */
+       statfile_pool_t                *pool;
+       /* Classifier context; stored by the callers, not read by the callbacks in this file */
+       struct classifier_ctx          *ctx;
+       /* Learning direction: TRUE adds tokens to the class, FALSE removes them */
+       gboolean                        in_class;
+       /* Timestamp taken once by the caller and passed to every block access */
+       time_t                          now;
+       /* Statfile being modified by the learn callback */
+       stat_file_t                    *file;
+       /* Array of per-statfile classification state */
+       struct bayes_statfile_data     *statfiles;
+       /* Number of elements of `statfiles` the classify callback iterates over */
+       guint32                         statfiles_num;
+};
+
+/*
+ * GTree traversal callback used by bayes_learn: updates the counter of a
+ * single token in the statfile that is being learned.
+ *
+ * key   - token_node_t carrying the token hashes (h1, h2)
+ * value - unused
+ * data  - struct bayes_callback_data (pool, file, now and in_class are read)
+ *
+ * Always returns FALSE, so the traversal is never stopped by this callback.
+ */
+static                          gboolean
+bayes_learn_callback (gpointer key, gpointer value, gpointer data)
+{
+       token_node_t                   *node = key;
+       struct bayes_callback_data     *cd = data;
+       double                          v, c;
+
+       /* +1 when the message is learned into the class, -1 when it is unlearned */
+       c = (cd->in_class) ? 1 : -1;
+
+       v = statfile_pool_get_block (cd->pool, cd->file, node->h1, node->h2, cd->now);
+       if (fabs (v) < ALPHA && c > 0) {
+               /* A block value below ALPHA is treated as zero: start the counter at 1 */
+               statfile_pool_set_block (cd->pool, cd->file, node->h1, node->h2, cd->now, c);
+       }
+       else {
+               if (G_LIKELY (c > 0)) {
+                       v += c;
+               }
+               else if (v > -c) {
+                       /* c is negative here, so adding it decrements the counter */
+                       v += c;
+               }
+               else {
+                       /* Never let the counter drop below zero */
+                       v = 0;
+               }
+               statfile_pool_set_block (cd->pool, cd->file, node->h1, node->h2, cd->now, v);
+       }
+
+       return FALSE;
+}
+
+/*
+ * GTree traversal callback used by bayes_classify and bayes_weights.
+ * For one token it reads the token value from every statfile, derives the
+ * local probability of each statfile and folds it into the running
+ * posterior probability, which is renormalized afterwards.
+ * Returns FALSE unconditionally.
+ */
+static gboolean
+bayes_classify_callback (gpointer key, gpointer value, gpointer data)
+{
+       token_node_t                   *node = key;
+       struct bayes_callback_data     *cd = data;
+       struct bayes_statfile_data     *sf;
+       double                          token_hits = 0, norm = 0, total = 0;
+       guint32                         n;
+
+       /* Read the value of this token from each statfile */
+       for (n = 0; n < cd->statfiles_num; n ++) {
+               sf = &cd->statfiles[n];
+               sf->value = statfile_pool_get_block (cd->pool, sf->file, node->h1, node->h2, cd->now);
+               if (!(sf->value > ALPHA)) {
+                       /* Values not above ALPHA count as zero */
+                       sf->value = 0;
+                       continue;
+               }
+               sf->total_hits += sf->value;
+               sf->hits = sf->value;
+               token_hits += sf->value;
+       }
+
+       /* Local probability of each statfile and the normalization factor */
+       for (n = 0; n < cd->statfiles_num; n ++) {
+               sf = &cd->statfiles[n];
+               sf->local_probability = 0.5 + (sf->value - (token_hits - sf->value)) / (LOCAL_PROB_DENOM * (token_hits + 1.0));
+               norm += sf->post_probability * sf->local_probability;
+       }
+
+       /* Update posteriors, keeping them above a tiny positive floor */
+       for (n = 0; n < cd->statfiles_num; n ++) {
+               sf = &cd->statfiles[n];
+               sf->post_probability = (sf->post_probability * sf->local_probability) / norm;
+               if (sf->post_probability < G_MINDOUBLE * 100) {
+                       sf->post_probability = G_MINDOUBLE * 100;
+               }
+       }
+
+       for (n = 0; n < cd->statfiles_num; n ++) {
+               sf = &cd->statfiles[n];
+               total += sf->post_probability;
+       }
+
+       /* Renormalize to form sum of probabilities equal to 1 */
+       for (n = 0; n < cd->statfiles_num; n ++) {
+               sf = &cd->statfiles[n];
+               sf->post_probability /= total;
+               if (sf->post_probability < G_MINDOUBLE * 10) {
+                       sf->post_probability = G_MINDOUBLE * 100;
+               }
+       }
+
+       return FALSE;
+}
+
+/*
+ * Create the context of a bayes classifier.
+ * The context is allocated from `pool` and stores both the pool and the
+ * classifier configuration `cfg`.
+ */
+struct classifier_ctx*
+bayes_init (memory_pool_t *pool, struct classifier_config *cfg)
+{
+       struct classifier_ctx          *new_ctx;
+
+       new_ctx = memory_pool_alloc (pool, sizeof (*new_ctx));
+       new_ctx->cfg = cfg;
+       new_ctx->pool = pool;
+
+       return new_ctx;
+}
+
+/*
+ * Classify the token tree `input` against every statfile configured for the
+ * classifier. When the highest posterior probability exceeds 0.5 the symbol
+ * of the corresponding statfile is inserted into the results of `task`.
+ *
+ * Returns FALSE when the "min_tokens" option is set and the number of tree
+ * nodes divided by FEATURE_WINDOW_SIZE is below it; TRUE otherwise.
+ */
+gboolean
+bayes_classify (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task)
+{
+       struct bayes_callback_data      cbd;
+       struct bayes_statfile_data     *slot;
+       struct statfile                *stcf;
+       stat_file_t                    *opened;
+       GList                          *it, *opts;
+       char                           *min_tokens, *probbuf;
+       int                             have, need, idx, used = 0, winner = 0;
+       double                          top = 0;
+
+       g_assert (pool != NULL);
+       g_assert (ctx != NULL);
+
+       if (ctx->cfg->opts) {
+               min_tokens = g_hash_table_lookup (ctx->cfg->opts, "min_tokens");
+               if (min_tokens != NULL) {
+                       need = strtol (min_tokens, NULL, 10);
+                       have = g_tree_nnodes (input) / FEATURE_WINDOW_SIZE;
+                       if (have < need) {
+                               return FALSE;
+                       }
+               }
+       }
+
+       cbd.statfiles_num = g_list_length (ctx->cfg->statfiles);
+       cbd.statfiles = g_new0 (struct bayes_statfile_data, cbd.statfiles_num);
+       cbd.pool = pool;
+       cbd.now = time (NULL);
+       cbd.ctx = ctx;
+
+       /* Collect every statfile that is already open or can be opened now */
+       for (it = ctx->cfg->statfiles; it != NULL; it = g_list_next (it)) {
+               stcf = it->data;
+               opened = statfile_pool_is_open (pool, stcf->path);
+               if (opened == NULL) {
+                       opened = statfile_pool_open (pool, stcf->path, stcf->size, FALSE);
+               }
+               if (opened == NULL) {
+                       /* Unusable statfile: shrink the count seen by the callback */
+                       msg_warn ("cannot open %s", stcf->path);
+                       cbd.statfiles_num --;
+                       continue;
+               }
+               slot = &cbd.statfiles[used ++];
+               slot->file = opened;
+               slot->st = stcf;
+               slot->post_probability = 0.5;
+               slot->local_probability = 0.5;
+       }
+
+       g_tree_foreach (input, bayes_classify_callback, &cbd);
+
+       /* Pick the statfile with the highest posterior probability */
+       for (idx = 0; idx < used; idx ++) {
+               slot = &cbd.statfiles[idx];
+               debug_task ("got probability for symbol %s: %.2f", slot->st->symbol, slot->post_probability);
+               if (slot->post_probability > top) {
+                       top = slot->post_probability;
+                       winner = idx;
+               }
+       }
+
+       if (top > 0.5) {
+               probbuf = memory_pool_alloc (task->task_pool, 32);
+               rspamd_snprintf (probbuf, 32, "%.2f", top);
+               opts = g_list_prepend (NULL, probbuf);
+               insert_result (task, cbd.statfiles[winner].st->symbol, top, opts);
+       }
+
+       g_free (cbd.statfiles);
+
+       return TRUE;
+}
+
+/*
+ * Learn (in_class is TRUE) or unlearn (in_class is FALSE) the tokens of
+ * `input` in the statfile whose configured symbol equals `symbol`.
+ * The statfile is created when it cannot be opened.
+ *
+ * Returns TRUE when the tokens were processed. Returns FALSE and fills `err`
+ * when the message has too few tokens (in that case `*sum` is also set to 0),
+ * when no statfile is configured for `symbol`, or when the statfile can be
+ * neither created nor opened. `multiplier` is not used by this function.
+ */
+gboolean
+bayes_learn (struct classifier_ctx* ctx, statfile_pool_t *pool, const char *symbol, GTree *input,
+                               gboolean in_class, double *sum, double multiplier, GError **err)
+{
+       struct bayes_callback_data      data;
+       char                           *value;
+       int                             nodes, minnodes;
+       struct statfile                *st, *sel_st = NULL;
+       stat_file_t                    *to_learn;
+       GList                          *cur;
+
+       g_assert (pool != NULL);
+       g_assert (ctx != NULL);
+
+       if (ctx->cfg->opts && (value = g_hash_table_lookup (ctx->cfg->opts, "min_tokens")) != NULL) {
+               minnodes = strtol (value, NULL, 10);
+               nodes = g_tree_nnodes (input) / FEATURE_WINDOW_SIZE;
+               if (nodes < minnodes) {
+                       msg_info ("do not learn message as it has too few tokens: %d, while %d min", nodes, minnodes);
+                       *sum = 0;
+                       g_set_error (err,
+                          bayes_error_quark(),         /* error domain */
+                          1,                                           /* error code */
+                          "message contains too few tokens: %d, while min is %d",
+                          nodes, minnodes);
+                       return FALSE;
+               }
+       }
+
+       data.pool = pool;
+       data.in_class = in_class;
+       data.now = time (NULL);
+       data.ctx = ctx;
+       cur = ctx->cfg->statfiles;
+       while (cur) {
+               /* Select statfile to learn */
+               st = cur->data;
+               if (strcmp (st->symbol, symbol) == 0) {
+                       sel_st = st;
+                       break;
+               }
+               cur = g_list_next (cur);
+       }
+       if (sel_st == NULL) {
+               g_set_error (err,
+                               bayes_error_quark(),            /* error domain */
+                               1,                                      /* error code */
+                               "cannot find statfile for symbol: %s",
+                               symbol);
+               /* Nothing to learn into: bail out before sel_st is dereferenced */
+               return FALSE;
+       }
+       if ((to_learn = statfile_pool_is_open (pool, sel_st->path)) == NULL) {
+               if ((to_learn = statfile_pool_open (pool, sel_st->path, sel_st->size, FALSE)) == NULL) {
+                       /* The statfile may simply not exist yet: create it and retry */
+                       msg_warn ("cannot open %s", sel_st->path);
+                       if (statfile_pool_create (pool, sel_st->path, sel_st->size) == -1) {
+                               msg_err ("cannot create statfile %s", sel_st->path);
+                               g_set_error (err,
+                                               bayes_error_quark(),            /* error domain */
+                                               1,                                      /* error code */
+                                               "cannot create statfile: %s",
+                                               sel_st->path);
+                               return FALSE;
+                       }
+                       if ((to_learn = statfile_pool_open (pool, sel_st->path, sel_st->size, FALSE)) == NULL) {
+                               g_set_error (err,
+                                               bayes_error_quark(),            /* error domain */
+                                               1,                                      /* error code */
+                                               "cannot open statfile %s after creation",
+                                               sel_st->path);
+                               msg_err ("cannot open statfile %s after creation", sel_st->path);
+                               return FALSE;
+                       }
+               }
+       }
+       data.file = to_learn;
+       /* Hold the statfile lock for the whole token traversal */
+       statfile_pool_lock_file (pool, data.file);
+       g_tree_foreach (input, bayes_learn_callback, &data);
+       statfile_pool_unlock_file (pool, data.file);
+
+       return TRUE;
+}
+
+/*
+ * Compute the posterior probability of every usable statfile for the token
+ * tree `input` and return them as a list of struct classify_weight.
+ *
+ * The weights are allocated from task->task_pool, and the list itself is
+ * registered with that pool so that g_list_free is applied to it.
+ * Returns NULL when the "min_tokens" option is set and the number of tree
+ * nodes divided by FEATURE_WINDOW_SIZE is below it, or when no statfile
+ * could be used.
+ */
+GList *
+bayes_weights (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task)
+{
+       struct bayes_callback_data      cbd;
+       struct bayes_statfile_data     *slot;
+       struct classify_weight         *entry;
+       struct statfile                *stcf;
+       stat_file_t                    *opened;
+       GList                          *it, *weights = NULL;
+       char                           *min_tokens;
+       int                             have, need, idx, used = 0;
+
+       g_assert (pool != NULL);
+       g_assert (ctx != NULL);
+
+       if (ctx->cfg->opts) {
+               min_tokens = g_hash_table_lookup (ctx->cfg->opts, "min_tokens");
+               if (min_tokens != NULL) {
+                       need = strtol (min_tokens, NULL, 10);
+                       have = g_tree_nnodes (input) / FEATURE_WINDOW_SIZE;
+                       if (have < need) {
+                               return NULL;
+                       }
+               }
+       }
+
+       cbd.statfiles_num = g_list_length (ctx->cfg->statfiles);
+       cbd.statfiles = g_new0 (struct bayes_statfile_data, cbd.statfiles_num);
+       cbd.pool = pool;
+       cbd.now = time (NULL);
+       cbd.ctx = ctx;
+
+       /* Collect every statfile that is already open or can be opened now */
+       for (it = ctx->cfg->statfiles; it != NULL; it = g_list_next (it)) {
+               stcf = it->data;
+               opened = statfile_pool_is_open (pool, stcf->path);
+               if (opened == NULL) {
+                       opened = statfile_pool_open (pool, stcf->path, stcf->size, FALSE);
+               }
+               if (opened == NULL) {
+                       /* Unusable statfile: shrink the count seen by the callback */
+                       msg_warn ("cannot open %s", stcf->path);
+                       cbd.statfiles_num --;
+                       continue;
+               }
+               slot = &cbd.statfiles[used ++];
+               slot->file = opened;
+               slot->st = stcf;
+               slot->post_probability = 0.5;
+               slot->local_probability = 0.5;
+       }
+
+       g_tree_foreach (input, bayes_classify_callback, &cbd);
+
+       /* One weight per usable statfile; prepending reverses the order */
+       for (idx = 0; idx < used; idx ++) {
+               slot = &cbd.statfiles[idx];
+               entry = memory_pool_alloc0 (task->task_pool, sizeof (struct classify_weight));
+               entry->name = slot->st->symbol;
+               entry->weight = slot->post_probability;
+               weights = g_list_prepend (weights, entry);
+       }
+
+       g_free (cbd.statfiles);
+
+       if (weights != NULL) {
+               memory_pool_add_destructor (task->task_pool, (pool_destruct_func)g_list_free, weights);
+       }
+
+       return weights;
+}
index 219576870d460b1c22f75186b811436b88bbb9c6..6b0554e1b0cd3db3dd382db6a12d7f5464a6c3be 100644 (file)
@@ -36,6 +36,13 @@ struct classifier               classifiers[] = {
                        .classify_func = winnow_classify,
                        .learn_func = winnow_learn,
                        .weights_func = winnow_weights
+               },
+               {
+                       .name = "bayes",
+                       .init_func = bayes_init,
+                       .classify_func = bayes_classify,
+                       .learn_func = bayes_learn,
+                       .weights_func = bayes_weights
                }
 };
 
index f69c1284ce7d33592ccf1bacd755cdb3c489b165..0e6df173a9843e6b78541d08e9ca2a24b4df9f15 100644 (file)
@@ -6,6 +6,9 @@
 #include "../statfile.h"
 #include "../tokenizers/tokenizers.h"
 
+/* Consider this value as 0 */
+#define ALPHA 0.0001
+
 struct classifier_config;
 struct worker_task;
 
@@ -41,7 +44,12 @@ gboolean winnow_learn (struct classifier_ctx* ctx, statfile_pool_t *pool, const
                                gboolean in_class, double *sum, double multiplier, GError **err);
 GList *winnow_weights (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task);
 
-
+/* Bayes algorithm */
+struct classifier_ctx* bayes_init (memory_pool_t *pool, struct classifier_config *cf);
+gboolean bayes_classify (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task);
+gboolean bayes_learn (struct classifier_ctx* ctx, statfile_pool_t *pool, const char *symbol, GTree *input,
+                               gboolean in_class, double *sum, double multiplier, GError **err);
+GList *bayes_weights (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task);
 /* Array of all defined classifiers */
 extern struct classifier classifiers[];
 
index 704f65b0a62a54da93678a6a692d12818c5a0e7c..f8c104a5208ffb054c6dc6ea90779bdd747ba414 100644 (file)
 
 #define MAX_WEIGHT G_MAXDOUBLE / 2.
 
-#define ALPHA 0.01
+
 
 #define MAX_LEARN_ITERATIONS 100
 
 G_INLINE_FUNC GQuark
 winnow_error_quark (void)
 {
-       return g_quark_from_static_string ("winnow-error-quark");
+       return g_quark_from_static_string ("winnow-error");
 }
 
 struct winnow_callback_data {
@@ -73,7 +73,7 @@ static const double max_common_weight = MAX_WEIGHT * WINNOW_DEMOTION;
 
 
 static                          gboolean
-classify_callback (gpointer key, gpointer value, gpointer data)
+winnow_classify_callback (gpointer key, gpointer value, gpointer data)
 {
        token_node_t                   *node = key;
        struct winnow_callback_data    *cd = data;
@@ -95,7 +95,7 @@ classify_callback (gpointer key, gpointer value, gpointer data)
 }
 
 static                          gboolean
-learn_callback (gpointer key, gpointer value, gpointer data)
+winnow_learn_callback (gpointer key, gpointer value, gpointer data)
 {
        token_node_t                   *node = key;
        struct winnow_callback_data    *cd = data;
@@ -247,7 +247,7 @@ winnow_classify (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * inp
                }
 
                if (data.file != NULL) {
-                       g_tree_foreach (input, classify_callback, &data);
+                       g_tree_foreach (input, winnow_classify_callback, &data);
                }
 
                if (data.count != 0) {
@@ -320,7 +320,7 @@ winnow_weights (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * inpu
                }
 
                if (data.file != NULL) {
-                       g_tree_foreach (input, classify_callback, &data);
+                       g_tree_foreach (input, winnow_classify_callback, &data);
                }
 
                w = memory_pool_alloc0 (task->task_pool, sizeof (struct classify_weight));
@@ -407,7 +407,7 @@ winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, const char *sym
                                                                st->path);
                                                return FALSE;
                                        }
-                                       if (statfile_pool_open (pool, st->path, st->size, FALSE)) {
+                                       if (statfile_pool_open (pool, st->path, st->size, FALSE) == NULL) {
                                                g_set_error (err,
                                                                winnow_error_quark(),           /* error domain */
                                                                1,                                      /* error code */
@@ -438,7 +438,7 @@ winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, const char *sym
                data.sum = 0;
                data.count = 0;
                data.new_blocks = 0;
-               g_tree_foreach (input, classify_callback, &data);
+               g_tree_foreach (input, winnow_classify_callback, &data);
                if (data.count > 0) {
                        max = data.sum / (double)data.count;
                }
@@ -462,7 +462,7 @@ winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, const char *sym
                                                st->path);
                                return FALSE;
                        }
-                       g_tree_foreach (input, classify_callback, &data);
+                       g_tree_foreach (input, winnow_classify_callback, &data);
                        if (data.count != 0) {
                                res = data.sum / data.count;
                        }
@@ -513,7 +513,7 @@ winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, const char *sym
                        }
 
                        statfile_pool_lock_file (pool, data.file);
-                       g_tree_foreach (input, learn_callback, &data);
+                       g_tree_foreach (input, winnow_learn_callback, &data);
                        statfile_pool_unlock_file (pool, data.file);
                        if (data.count != 0) {
                                res = data.sum / data.count;
index 662a70e7411cdf6ca8b81e9b40a9e0aec5ce84bc..3a0ff51713f658616ec518b8525e18960b154be2 100644 (file)
@@ -361,8 +361,6 @@ statfile_pool_open (statfile_pool_t * pool, char *filename, size_t size, gboolea
        new_file->access_time = new_file->open_time;
        new_file->lock = memory_pool_get_mutex (pool->pool);
 
-       /* Keep sorted */
-       qsort (pool->files, pool->opened, sizeof (stat_file_t), cmpstatfile);
        memory_pool_unlock_mutex (pool->lock);
 
        return statfile_pool_is_open (pool, filename);
@@ -392,11 +390,6 @@ statfile_pool_close (statfile_pool_t * pool, stat_file_t * file, gboolean keep_s
        pool->occupied -= file->len;
        pool->opened--;
 
-       if (keep_sorted) {
-               memmove (pos, &pool->files[pool->opened], sizeof (stat_file_t));
-               /* Keep sorted */
-               qsort (pool->files, pool->opened, sizeof (stat_file_t), cmpstatfile);
-       }
        memory_pool_unlock_mutex (pool->lock);
 
        return 0;
@@ -639,7 +632,7 @@ statfile_pool_is_open (statfile_pool_t * pool, char *filename)
 {
        static stat_file_t              f, *ret;
        g_strlcpy (f.filename, filename, sizeof (f.filename));
-       ret = bsearch (&f, pool->files, pool->opened, sizeof (stat_file_t), cmpstatfile);
+       ret = lfind (&f, pool->files, (size_t *)&pool->opened, sizeof (stat_file_t), cmpstatfile);
        return ret;
 }