* Add correcting factor to statistics.
author     Vsevolod Stakhov <vsevolod@rambler-co.ru>
           Tue, 28 Jun 2011 15:07:26 +0000 (19:07 +0400)
committer  Vsevolod Stakhov <vsevolod@rambler-co.ru>
           Tue, 28 Jun 2011 15:07:26 +0000 (19:07 +0400)
Learning now increments the version of a statfile.
Avoid learning and classifying similar text parts if a message has two text parts.
Several fixes to statistics.

src/classifiers/bayes.c
src/filter.c
src/statfile.c
src/statfile.h
src/tokenizers/osb.c
src/tokenizers/tokenizers.c
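
For orientation, the core of the change is a per-statfile correction factor computed in bayes_classify () from statfile revisions (a statfile's revision now counts how many times it has been learned). A minimal sketch, using the identifiers from the diff below and a hypothetical helper name calc_corr_sketch; this is an illustration, not code from the commit:

        /* Sketch: derive each statfile's correction factor from its share of learns.
         * Revisions are read via statfile_get_revision (), which after this commit
         * accepts NULL for outputs the caller does not need. Assumes at least one
         * learn has happened (total_learns > 0), as the original code does. */
        static void
        calc_corr_sketch (struct bayes_statfile_data *statfiles, gint cnt)
        {
                guint64                         rev, total_learns = 0;
                gint                            i;

                for (i = 0; i < cnt; i ++) {
                        statfile_get_revision (statfiles[i].file, &rev, NULL);
                        total_learns += rev;
                }
                for (i = 0; i < cnt; i ++) {
                        statfile_get_revision (statfiles[i].file, &rev, NULL);
                        /* corr = (revision / number of statfiles) / total learns;
                         * every raw token count of this statfile is later multiplied by it */
                        statfiles[i].corr = ((double) rev / cnt) / (double) total_learns;
                }
        }

In bayes_classify_callback () the scaled counts then feed the usual local probability, 0.5 + (value - (local_hits - value)) / (LOCAL_PROB_DENOM * (1.0 + local_hits)).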

diff --git a/src/classifiers/bayes.c b/src/classifiers/bayes.c
index af79e0eaa11c2dc94f858596b5dead3d87ec4818..7363df52273343931aab3784b6311d1a227d1d74 100644 (file)
@@ -47,7 +47,8 @@ struct bayes_statfile_data {
        guint64                         total_hits;
        double                          local_probability;
        double                          post_probability;
-       guint64                         value;
+       double                          corr;
+       double                          value;
        struct statfile                *st;
        stat_file_t                    *file;
 };
@@ -60,6 +61,7 @@ struct bayes_callback_data {
        stat_file_t                    *file;
        struct bayes_statfile_data     *statfiles;
        guint32                         statfiles_num;
+       guint64                          learned_tokens;
 };
 
 static                          gboolean
@@ -67,7 +69,8 @@ bayes_learn_callback (gpointer key, gpointer value, gpointer data)
 {
        token_node_t                   *node = key;
        struct bayes_callback_data     *cd = data;
-       gint                            v, c;
+       gint                            c;
+       guint64                         v;
 
        c = (cd->in_class) ? 1 : -1;
 
@@ -75,8 +78,9 @@ bayes_learn_callback (gpointer key, gpointer value, gpointer data)
        v = statfile_pool_get_block (cd->pool, cd->file, node->h1, node->h2, cd->now);
        if (v == 0 && c > 0) {
                statfile_pool_set_block (cd->pool, cd->file, node->h1, node->h2, cd->now, c);
+               cd->learned_tokens ++;
        }
-       else {
+       else if (v != 0) {
                if (G_LIKELY (c > 0)) {
                        v ++;
                }
@@ -86,6 +90,7 @@ bayes_learn_callback (gpointer key, gpointer value, gpointer data)
                        }
                }
                statfile_pool_set_block (cd->pool, cd->file, node->h1, node->h2, cd->now, v);
+               cd->learned_tokens ++;
        }
 
        return FALSE;
@@ -102,24 +107,21 @@ bayes_classify_callback (gpointer key, gpointer value, gpointer data)
        struct bayes_callback_data     *cd = data;
        double                          renorm = 0;
        gint                            i;
-       guint64                         local_hits = 0;
+       double                          local_hits = 0;
        struct bayes_statfile_data     *cur;
 
        for (i = 0; i < cd->statfiles_num; i ++) {
                cur = &cd->statfiles[i];
-               cur->value = statfile_pool_get_block (cd->pool, cur->file, node->h1, node->h2, cd->now);
+               cur->value = statfile_pool_get_block (cd->pool, cur->file, node->h1, node->h2, cd->now) * cur->corr;
                if (cur->value > 0) {
-                       cur->total_hits += cur->value;
+                       cur->total_hits ++;
                        cur->hits = cur->value;
                        local_hits += cur->value;
                }
-               else {
-                       cur->value = 0;
-               }
        }
        for (i = 0; i < cd->statfiles_num; i ++) {
                cur = &cd->statfiles[i];
-               cur->local_probability = 0.5 + ((double)cur->value - ((double)local_hits - cur->value)) /
+               cur->local_probability = 0.5 + (cur->value - (local_hits - cur->value)) /
                                (LOCAL_PROB_DENOM * (1.0 + local_hits));
                renorm += cur->post_probability * cur->local_probability;
        }
@@ -145,7 +147,7 @@ bayes_classify_callback (gpointer key, gpointer value, gpointer data)
                        cur->post_probability = G_MINDOUBLE * 100;
                }
                if (cd->ctx->debug) {
-                       msg_info ("token: %s, statfile: %s, probability: %uL, post_probability: %.4f",
+                       msg_info ("token: %s, statfile: %s, probability: %.4f, post_probability: %.4f",
                                        node->extra, cur->st->symbol, cur->value, cur->post_probability);
                }
        }
@@ -169,8 +171,9 @@ gboolean
 bayes_classify (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task)
 {
        struct bayes_callback_data      data;
-       char                           *value;
-       int                             nodes, minnodes, i, cnt, best_num = 0;
+       gchar                          *value;
+       gint                            nodes, minnodes, i = 0, cnt, best_num = 0;
+       guint64                         rev, total_learns = 0;
        double                          best = 0;
        struct statfile                *st;
        stat_file_t                    *file;
@@ -198,7 +201,6 @@ bayes_classify (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input,
        data.ctx = ctx;
 
        cur = ctx->cfg->statfiles;
-       i = 0;
        while (cur) {
                /* Select statfile to learn */
                st = cur->data;
@@ -214,11 +216,21 @@ bayes_classify (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input,
                data.statfiles[i].st = st;
                data.statfiles[i].post_probability = 0.5;
                data.statfiles[i].local_probability = 0.5;
-               i ++;
+               statfile_get_revision (file, &rev, NULL);
+               total_learns += rev;
+
                cur = g_list_next (cur);
+               i ++;
        }
+
        cnt = i;
 
+       /* Calculate correction factor */
+       for (i = 0; i < cnt; i ++) {
+               statfile_get_revision (data.statfiles[i].file, &rev, NULL);
+               data.statfiles[i].corr = ((double)rev / cnt) / (double)total_learns;
+       }
+
        g_tree_foreach (input, bayes_classify_callback, &data);
 
        for (i = 0; i < cnt; i ++) {
@@ -277,6 +289,7 @@ bayes_learn (struct classifier_ctx* ctx, statfile_pool_t *pool, const char *symb
        data.in_class = in_class;
        data.now = time (NULL);
        data.ctx = ctx;
+       data.learned_tokens = 0;
        cur = ctx->cfg->statfiles;
        while (cur) {
                /* Select statfile to learn */
@@ -321,8 +334,13 @@ bayes_learn (struct classifier_ctx* ctx, statfile_pool_t *pool, const char *symb
        data.file = to_learn;
        statfile_pool_lock_file (pool, data.file);
        g_tree_foreach (input, bayes_learn_callback, &data);
+       statfile_inc_revision (to_learn);
        statfile_pool_unlock_file (pool, data.file);
 
+       if (sum != NULL) {
+               *sum = data.learned_tokens;
+       }
+
        return TRUE;
 }
 
diff --git a/src/filter.c b/src/filter.c
index 797b4f6fe4f55ca97090233aaaa237f0c6969712..b625e4c722e6d22f47863023b491e74448b3af07 100644 (file)
@@ -40,6 +40,8 @@
 #   include "lua/lua_common.h"
 #endif
 
+#define COMMON_PART_FACTOR 80
+
 static inline                   GQuark
 filter_error_quark (void)
 {
@@ -593,7 +595,8 @@ classifiers_callback (gpointer value, void *arg)
        GTree                          *tokens = NULL;
        GList                          *cur;
        f_str_t                         c;
-       gchar                           *header = NULL;
+       gchar                          *header = NULL;
+       gboolean                        is_twopart = FALSE;
        
        if ((header = g_hash_table_lookup (cl->opts, "header")) != NULL) {
                cur = message_get_header (task->task_pool, task->message, header, FALSE);
@@ -603,6 +606,9 @@ classifiers_callback (gpointer value, void *arg)
        }
        else {
                cur = g_list_first (task->text_parts);
+               if (cur != NULL && cur->next != NULL && cur->next->next == NULL) {
+                       is_twopart = TRUE;
+               }
        }
        ctx = cl->classifier->init_func (task->task_pool, cl);
 
@@ -624,10 +630,18 @@ classifiers_callback (gpointer value, void *arg)
                                        cur = g_list_next (cur);
                                        continue;
                                }
+                               if (is_twopart && cur->next == NULL) {
+                                       /* Compare part's content */
+                                       if (fuzzy_compare_parts (cur->data, cur->prev->data) >= COMMON_PART_FACTOR) {
+                                               msg_info ("message <%s> has two common text parts, ignore the last one", task->message_id);
+                                               break;
+                                       }
+                               }
                                c.begin = text_part->content->data;
                                c.len = text_part->content->len;
                                /* Tree would be freed at task pool freeing */
-                               if (!cl->tokenizer->tokenize_func (cl->tokenizer, task->task_pool, &c, &tokens, FALSE, text_part->is_utf, text_part->urls_offset)) {
+                               if (!cl->tokenizer->tokenize_func (cl->tokenizer, task->task_pool, &c, &tokens,
+                                               FALSE, text_part->is_utf, text_part->urls_offset)) {
                                        msg_info ("cannot tokenize input");
                                        return;
                                }
@@ -815,7 +829,7 @@ learn_task (const gchar *statfile, struct worker_task *task, GError **err)
        stat_file_t                    *stf;
        gdouble                         sum;
        struct mime_text_part          *part;
-       gboolean                        is_utf = FALSE;
+       gboolean                        is_utf = FALSE, is_twopart = FALSE;
 
        /* Load classifier by symbol */
        cl = g_hash_table_lookup (task->cfg->classifiers_symbols, statfile);
@@ -834,6 +848,9 @@ learn_task (const gchar *statfile, struct worker_task *task, GError **err)
        else {
                /* Classify message otherwise */
                cur = g_list_first (task->text_parts);
+               if (cur != NULL && cur->next != NULL && cur->next->next == NULL) {
+                       is_twopart = TRUE;
+               }
        }
 
        /* Get tokens from each element */
@@ -854,6 +871,13 @@ learn_task (const gchar *statfile, struct worker_task *task, GError **err)
                        c.len = part->content->len;
                        is_utf = part->is_utf;
                        ex = part->urls_offset;
+                       if (is_twopart && cur->next == NULL) {
+                               /* Compare part's content */
+                               if (fuzzy_compare_parts (cur->data, cur->prev->data) >= COMMON_PART_FACTOR) {
+                                       msg_info ("message <%s> has two common text parts, ignore the last one", task->message_id);
+                                       break;
+                               }
+                       }
                }
                /* Get tokens */
                if (!cl->tokenizer->tokenize_func (
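
Both classifiers_callback () and learn_task () now apply the same guard; in condensed form (identifiers from the hunks above; the 80 in COMMON_PART_FACTOR is presumably the similarity percentage returned by fuzzy_compare_parts () above which two parts are treated as duplicates):

        /* Only when the message has exactly two text parts: if the second part is
         * essentially the same text as the first, do not tokenize it again. */
        if (is_twopart && cur->next == NULL &&
                        fuzzy_compare_parts (cur->data, cur->prev->data) >= COMMON_PART_FACTOR) {
                msg_info ("message <%s> has two common text parts, ignore the last one", task->message_id);
                break;
        }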
diff --git a/src/statfile.c b/src/statfile.c
index 0359c0c4d2d2e6982bcf20494a4ad23463cedb62..65da1503355fc54b392281db94e498c4fcb4daac 100644 (file)
@@ -789,6 +789,22 @@ statfile_set_revision (stat_file_t *file, guint64 rev, time_t time)
 }
 
 gboolean 
+statfile_inc_revision (stat_file_t *file)
+{
+       struct stat_file_header        *header;
+
+       if (file == NULL || file->map == NULL) {
+               return FALSE;
+       }
+
+       header = (struct stat_file_header *)file->map;
+
+       header->revision ++;
+
+       return TRUE;
+}
+
+gboolean
 statfile_get_revision (stat_file_t *file, guint64 *rev, time_t *time)
 {
        struct stat_file_header        *header;
@@ -799,8 +815,12 @@ statfile_get_revision (stat_file_t *file, guint64 *rev, time_t *time)
        
        header = (struct stat_file_header *)file->map;
 
-       *rev = header->revision;
-       *time = header->rev_time;
+       if (rev != NULL) {
+               *rev = header->revision;
+       }
+       if (time != NULL) {
+               *time = header->rev_time;
+       }
 
        return TRUE;
 }
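
The revision accessors form a small pair after this change: statfile_inc_revision () bumps the counter on every successful learn, and statfile_get_revision () now tolerates NULL for either output pointer. A minimal caller sketch (assumes an already-opened stat_file_t *file; not code from the commit):

        guint64                         learns = 0;

        /* Only the revision counter is needed here, so the time output is passed
         * as NULL, which is allowed after this commit. */
        if (statfile_get_revision (file, &learns, NULL)) {
                /* record a successful learn, as bayes_learn () now does */
                statfile_inc_revision (file);
        }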
diff --git a/src/statfile.h b/src/statfile.h
index 5714641966c1b45540243d5738dbbc7f9d6f962d..e76aa08971c148ddcdbb7d82de5c03679a4c7e02 100644 (file)
@@ -224,6 +224,14 @@ guint32 statfile_get_section_by_name (const gchar *name);
  */
 gboolean statfile_set_revision (stat_file_t *file, guint64 rev, time_t time);
 
+/**
+ * Increment statfile revision and revision time
+ * @param filename name of statfile
+ * @param time time of revision
+ * @return TRUE if revision was set
+ */
+gboolean statfile_inc_revision (stat_file_t *file);
+
 /**
  * Set statfile revision and revision time
  * @param filename name of statfile
diff --git a/src/tokenizers/osb.c b/src/tokenizers/osb.c
index bc57255cb50a70e59e493204777ab9bc589ec57b..790069d6a74984a809ca175472fb07e0c6811f1c 100644 (file)
@@ -41,7 +41,7 @@ osb_tokenize_text (struct tokenizer *tokenizer, memory_pool_t * pool, f_str_t *
        token_node_t                   *new = NULL;
        f_str_t                         token = { NULL, 0, 0 };
        guint32                         hashpipe[FEATURE_WINDOW_SIZE], h1, h2;
-       gint                            i, k = 0, l;
+       gint                            i, l;
        gchar                          *res;
 
        if (*tree == NULL) {
@@ -49,6 +49,8 @@ osb_tokenize_text (struct tokenizer *tokenizer, memory_pool_t * pool, f_str_t *
                memory_pool_add_destructor (pool, (pool_destruct_func) g_tree_destroy, *tree);
        }
 
+       memset (hashpipe, 0xfe, FEATURE_WINDOW_SIZE * sizeof (hashpipe[0]));
+
        while ((res = tokenizer->get_next_word (input, &token, &exceptions)) != NULL) {
                /* Skip small words */
                if (is_utf) {
@@ -68,23 +70,20 @@ osb_tokenize_text (struct tokenizer *tokenizer, memory_pool_t * pool, f_str_t *
                }
                hashpipe[0] = fstrhash_lowercase (&token, is_utf);
 
-               if (k > FEATURE_WINDOW_SIZE) {
-                       for (i = 1; i < FEATURE_WINDOW_SIZE; i++) {
-                               h1 = hashpipe[0] * primes[0] + hashpipe[i] * primes[i << 1];
-                               h2 = hashpipe[0] * primes[1] + hashpipe[i] * primes[(i << 1) - 1];
-                               new = memory_pool_alloc0 (pool, sizeof (token_node_t));
-                               new->h1 = h1;
-                               new->h2 = h2;
-                               if (save_token) {
-                                       new->extra = (uintptr_t)memory_pool_fstrdup (pool, &token);
-                               }
+               for (i = 1; i < FEATURE_WINDOW_SIZE; i++) {
+                       h1 = hashpipe[0] * primes[0] + hashpipe[i] * primes[i << 1];
+                       h2 = hashpipe[0] * primes[1] + hashpipe[i] * primes[(i << 1) - 1];
+                       new = memory_pool_alloc0 (pool, sizeof (token_node_t));
+                       new->h1 = h1;
+                       new->h2 = h2;
+                       if (save_token) {
+                               new->extra = (uintptr_t)memory_pool_fstrdup (pool, &token);
+                       }
 
-                               if (g_tree_lookup (*tree, new) == NULL) {
-                                       g_tree_insert (*tree, new, new);
-                               }
+                       if (g_tree_lookup (*tree, new) == NULL) {
+                               g_tree_insert (*tree, new, new);
                        }
                }
-               k ++;
                token.begin = res;
        }
 
diff --git a/src/tokenizers/tokenizers.c b/src/tokenizers/tokenizers.c
index 16dc763ed03240e3c6ac62bbd9b140005f556835..d5a820d1b32f784ce3d3f72232686312be282415 100644 (file)
@@ -138,7 +138,7 @@ get_next_word (f_str_t * buf, f_str_t * token, GList **exceptions)
        token->len = 0;
 
        remain = buf->len - (token->begin - buf->begin);
-       if (remain <= 0) {
+       if (remain == 0) {
                return NULL;
        }
        pos = token->begin - buf->begin;