]> source.dussan.org Git - rspamd.git/commitdiff
* Make autolearn working
authorVsevolod Stakhov <vsevolod@rambler-co.ru>
Thu, 9 Jul 2009 16:45:11 +0000 (20:45 +0400)
committerVsevolod Stakhov <vsevolod@rambler-co.ru>
Thu, 9 Jul 2009 16:45:11 +0000 (20:45 +0400)
src/filter.c
src/filter.h
src/fstring.c
src/statfile.c
src/tokenizers/osb.c
src/tokenizers/tokenizers.c
src/util.c
src/view.c

index daa9b0e29f69da7ab6d27959fe3522b8955201c3..1c45f0886bddac1a815e3fee8ab7196ca0ee4628 100644 (file)
@@ -62,6 +62,7 @@ insert_result (struct worker_task *task, const char *metric_name, const char *sy
                /* Create new metric chain */
                metric_res = memory_pool_alloc (task->task_pool, sizeof (struct metric_result));
                metric_res->symbols = g_hash_table_new (g_str_hash, g_str_equal);
+               metric_res->checked = FALSE;
                memory_pool_add_destructor (task->task_pool, (pool_destruct_func)g_hash_table_destroy, metric_res->symbols);
                metric_res->metric = metric;
                g_hash_table_insert (task->results, (gpointer)metric_name, metric_res);
@@ -214,11 +215,19 @@ call_filter_by_name (struct worker_task *task, const char *name, enum filter_typ
 }
 
 static void
-metric_process_callback (gpointer key, gpointer value, void *data)
+metric_process_callback_common (gpointer key, gpointer value, void *data, gboolean is_forced)
 {
        struct worker_task *task = (struct worker_task *)data;
        struct metric_result *metric_res = (struct metric_result *)value;
        
+       if (metric_res->checked && !is_forced) {
+               /* Already checked */
+               return;
+       }
+       
+       /* Set flag */
+       metric_res->checked = TRUE;
+
        if (metric_res->metric->func != NULL) {
                metric_res->score = metric_res->metric->func (task, metric_res->metric->name, metric_res->metric->func_name);
        }
@@ -229,6 +238,18 @@ metric_process_callback (gpointer key, gpointer value, void *data)
                                        metric_res->score, metric_res->metric->name);
 }
 
+static void
+metric_process_callback_normal (gpointer key, gpointer value, void *data)
+{
+       metric_process_callback_common (key, value, data, FALSE);
+}
+
+static void
+metric_process_callback_forced (gpointer key, gpointer value, void *data)
+{
+       metric_process_callback_common (key, value, data, TRUE);
+}
+
 static int
 continue_process_filters (struct worker_task *task)
 {
@@ -359,7 +380,7 @@ process_filters (struct worker_task *task)
        }
 
        /* Process all metrics */
-       g_hash_table_foreach (task->results, metric_process_callback, task);
+       g_hash_table_foreach (task->results, metric_process_callback_forced, task);
        return 1;
 }
 
@@ -443,6 +464,75 @@ composites_foreach_callback (gpointer key, gpointer value, void *data)
        return;
 }
 
+static gboolean
+check_autolearn (struct statfile_autolearn_params *params, struct worker_task *task)
+{      
+       const char *metric_name = DEFAULT_METRIC;
+       struct metric_result *metric_res;
+       GList *cur;
+
+       if (params->metric != NULL) {
+               metric_name = params->metric;
+       }
+
+       /* First check threshold */
+       metric_res = g_hash_table_lookup (task->results, metric_name);
+       if (metric_res == NULL) {
+               if (params->symbols == NULL && params->threshold_max > 0) {
+                       /* For ham messages */
+                       return TRUE;
+               }
+               msg_debug ("check_autolearn: metric %s has no results", metric_name);
+               return FALSE;
+       }
+       else {
+               /* Process score of metric */
+               metric_process_callback_normal ((void *)metric_name, metric_res, task);
+               if ((params->threshold_min != 0 && metric_res->score > params->threshold_min) || 
+                       (params->threshold_max != 0 && metric_res->score < params->threshold_max)) {
+                       /* Now check for specific symbols */
+                       if (params->symbols) {
+                               cur = params->symbols;
+                               while (cur) {
+                                       if (g_hash_table_lookup (metric_res->symbols, cur->data) == NULL) {
+                                               return FALSE;
+                                       }
+                                       cur = g_list_next (cur);
+                               }
+                       }
+                       /* Now allow processing of actual autolearn */
+                       return TRUE;
+               }
+       }
+
+       return FALSE;
+}
+
+static void
+process_autolearn (struct statfile *st, struct worker_task *task, GTree *tokens, 
+                                       struct classifier *classifier, char *filename, struct classifier_ctx* ctx)
+{
+       if (check_autolearn (st->autolearn, task)) {
+               if (tokens) {
+                       msg_info ("process_autolearn: message with id <%s> autolearned statfile '%s'", task->message_id, filename);
+                       /* Check opened */
+                       if (! statfile_pool_is_open (task->worker->srv->statfile_pool, filename)) {
+                               /* Try open */
+                               if (statfile_pool_open (task->worker->srv->statfile_pool, filename) == NULL) {
+                                       /* Try create */
+                                       if (statfile_pool_create (task->worker->srv->statfile_pool, 
+                                                                       filename, st->size / sizeof (struct stat_file_block)) == -1) {
+                                               msg_info ("process_autolearn: error while creating statfile %s", filename);
+                                               return;
+                                       }
+                               }
+                       }
+
+                       classifier->learn_func (ctx, task->worker->srv->statfile_pool, filename, tokens, 1);
+               }
+       }
+}
+
 static void
 composites_metric_callback (gpointer key, gpointer value, void *data) 
 {
@@ -498,7 +588,7 @@ statfiles_callback (gpointer key, gpointer value, void *arg)
                filename = resolve_stat_filename (task->task_pool, st->pattern, task->from, "");
        }
        
-       if (statfile_pool_open (task->worker->srv->statfile_pool, filename) == NULL) {
+       if (statfile_pool_open (task->worker->srv->statfile_pool, filename) == NULL && !check_autolearn (st->autolearn, task)) {
                return;
        }
        
@@ -513,6 +603,7 @@ statfiles_callback (gpointer key, gpointer value, void *arg)
                                msg_info ("statfiles_callback: cannot tokenize input");
                                return;
                        }
+                       cur = g_list_next (cur);
                }
                g_hash_table_insert (data->tokens, st->tokenizer, tokens);
        }
@@ -533,6 +624,10 @@ statfiles_callback (gpointer key, gpointer value, void *arg)
        
        classifier->classify_func (res_data->ctx, task->worker->srv->statfile_pool, filename, tokens, st->weight);
 
+       if (st->autolearn) {
+               /* Process autolearn */
+               process_autolearn (st, task, tokens, classifier, filename, res_data->ctx);
+       }
 }
 
 static void
@@ -548,7 +643,6 @@ statfiles_results_callback (gpointer key, gpointer value, void *arg)
        filename = classifier->result_file_func (res->ctx, w);
        insert_result (task, res->metric->name, classifier->name, *w, NULL);
        msg_debug ("statfiles_results_callback: got total weight %.2f for metric %s", *w, res->metric->name);
-
 }
 
 
@@ -566,7 +660,8 @@ process_statfiles (struct worker_task *task)
        
        g_hash_table_destroy (cd.tokens);
        g_hash_table_destroy (cd.classifiers);
-       g_hash_table_foreach (task->results, metric_process_callback, task);
+       /* Process results */
+       g_hash_table_foreach (task->results, metric_process_callback_forced, task);
 
        task->state = WRITE_REPLY;
 }
index c460ec3177ac2e2d7f749e16efbc966f8cd34818..e0c989f85a66a4f04fd07be76ab3e1915cfef1af 100644 (file)
@@ -49,6 +49,7 @@ struct metric_result {
        struct metric *metric;                                                  /**< pointer to metric structure                        */
        double score;                                                                   /**< total score                                                        */
        GHashTable *symbols;                                                    /**< symbols of metric                                          */
+       gboolean checked;                                                               /**< whether metric result is consolidated  */
 };
 
 /**
index 935c8bdcc9b67c99a95c4e9e31a7c89b9b12d941..00ca4ed12fe06c18d64e606a21fa056f800dd30b 100644 (file)
@@ -306,19 +306,20 @@ fstrhash (f_str_t *str)
        size_t i;
        uint32_t hval;
        uint32_t tmp;
+       char *c = str->begin;
 
        if (str == NULL) {
                return 0;
        }
        hval = str->len;
 
-       for     (i = 0; i < str->len; i++) {
+       for     (i = 0; i < str->len; i++, c++) {
                /* 
                 * xor in the current byte against each byte of hval
                 * (which alone gaurantees that every bit of input will have
                 * an effect on the output)
                 */
-               tmp = *(str->begin + i) & 0xFF;
+               tmp = *c & 0xFF;
                tmp = tmp | (tmp << 8) | (tmp << 16) | (tmp << 24);
                hval ^= tmp;
 
index ac0c3bfaa81093046de26156f89848b1a5df4ed4..4a52008ed60e5938e99d520213a4aac99727692b 100644 (file)
@@ -129,9 +129,8 @@ statfile_pool_open (statfile_pool_t *pool, char *filename)
        struct stat st;
        stat_file_t *new_file;  
        
-       if (statfile_pool_is_open (pool, filename) != NULL) {
-               msg_info ("statfile_pool_open: file %s is already opened", filename);
-               return NULL;
+       if ((new_file = statfile_pool_is_open (pool, filename)) != NULL) {
+               return new_file;
        }
 
        if (pool->opened >= STATFILES_MAX - 1) {
@@ -400,9 +399,10 @@ statfile_pool_set_block (statfile_pool_t *pool, stat_file_t *file, uint32_t h1,
 stat_file_t *
 statfile_pool_is_open (statfile_pool_t *pool, char *filename)
 {
-       static stat_file_t f;
+       static stat_file_t f, *ret;
        f.filename = filename;
-       return bsearch (&f, pool->files, pool->opened, sizeof (stat_file_t), cmpstatfile);
+       ret = bsearch (&f, pool->files, pool->opened, sizeof (stat_file_t), cmpstatfile);
+       return ret;
 }
 
 uint32_t
index 32d6b902ad13cad1af8d0763a9dd4d1e182cf2ba..d2a1fe22f5f44f023f237b2ae673761fb41f67d2 100644 (file)
@@ -29,6 +29,8 @@
 #include <sys/types.h>
 #include "tokenizers.h"
 
+/* Minimum length of token */
+#define MIN_LEN 4
 
 extern const int primes[];
 
@@ -36,7 +38,7 @@ int
 osb_tokenize_text (struct tokenizer *tokenizer, memory_pool_t *pool, f_str_t *input, GTree **tree)
 {
        token_node_t *new = NULL;
-       f_str_t token = { NULL, 0, 0 };
+       f_str_t token = { NULL, 0, 0 }, *res;
        uint32_t hashpipe[FEATURE_WINDOW_SIZE], h1, h2;
        int i;
 
@@ -52,7 +54,11 @@ osb_tokenize_text (struct tokenizer *tokenizer, memory_pool_t *pool, f_str_t *in
 
        msg_debug ("osb_tokenize_text: got input length: %zd", input->len);
 
-       while (tokenizer->get_next_word (input, &token)) {
+       while ((res = tokenizer->get_next_word (input, &token)) != NULL) {
+               /* Skip small words */
+               if (token.len < MIN_LEN) {
+                       continue;
+               }
                /* Shift hashpipe */
                for (i = FEATURE_WINDOW_SIZE - 1; i > 0; i --) {
                        hashpipe[i] = hashpipe[i - 1];
index 4527e699ccf9749ab772ceb9d4e8bfbc2e255798..7db1af12cdad903719f74e4d0f9d64d15ef1585c 100644 (file)
@@ -78,12 +78,11 @@ f_str_t *
 get_next_word (f_str_t *buf, f_str_t *token)
 {
        size_t remain;
-       char *pos;
+       unsigned char *pos;
        
        if (buf == NULL) {
                return NULL;
        }
-
        if (token->begin == NULL) {
                token->begin = buf->begin;
        }
@@ -95,15 +94,14 @@ get_next_word (f_str_t *buf, f_str_t *token)
        if (remain <= 0) {
                return NULL;
        }
-
        pos = token->begin;
        /* Skip non graph symbols */
-       while (remain > 0 && !g_ascii_isgraph (*pos)) {
+       while (remain > 0 && (!g_ascii_isgraph (*pos) && *pos < 127)) {
                token->begin ++;
                pos ++;
                remain --;
        }
-       while (remain > 0 && g_ascii_isgraph (*pos)) {
+       while (remain > 0 && (g_ascii_isgraph (*pos) || *pos > 127)) {
                token->len ++;
                pos ++;
                remain --;
index 76c9c31a867a5e77c3c3d37f0ee5e092dba84f08..9dab02da76074128ad8fce8ae920d05cbf71d5d4 100644 (file)
@@ -1052,7 +1052,7 @@ maybe_parse_host_list (memory_pool_t *pool, GHashTable *tbl, const char *filenam
 gint
 rspamd_strcase_equal (gconstpointer v, gconstpointer v2)
 {
-       return g_ascii_strcasecmp ((const char *) v, (const char *) v2) == 0;
+       return g_ascii_strcasecmp ((const char *) v, (const char *) v2);
 }
 
 
index 0bd534b322f7fdfbb4c6a798602b533eee84ccc4..0a03d4304e2fee16171a23852051ffb30799ccf4 100644 (file)
@@ -120,7 +120,7 @@ find_view_by_ip (GList *views, struct worker_task *task)
        cur = views;
        while (cur) {
                v = cur->data;
-               if (radix32tree_find (v->ip_tree, task->from_addr.s_addr) != RADIX_NO_VALUE) {
+               if (radix32tree_find (v->ip_tree, ntohl (task->from_addr.s_addr)) != RADIX_NO_VALUE) {
                        return v;
                }
                cur = g_list_next (cur);