]> source.dussan.org Git - rspamd.git/commitdiff
* Add weights command for getting weights of each message by each statfile
authorcebka@lenovo-laptop <cebka@lenovo-laptop>
Mon, 1 Mar 2010 15:37:06 +0000 (18:37 +0300)
committercebka@lenovo-laptop <cebka@lenovo-laptop>
Mon, 1 Mar 2010 15:37:06 +0000 (18:37 +0300)
* Add ability to specify multiplier when learning
* Add statistics about spam and ham messages

rspamc.pl.in
src/classifiers/classifiers.c
src/classifiers/classifiers.h
src/classifiers/winnow.c
src/controller.c
src/filter.c
src/main.h
src/protocol.c

index 4dddc7e3bd52332984ad718888c2a8114f997d8b..19bcab3f7822339c8306cf48c1cf300d7c7769b5 100755 (executable)
@@ -211,7 +211,7 @@ sub do_control_command {
         if (do_ctrl_auth ($sock)) {
             my $len = length ($input);
             print "Sending $len bytes...\n";
-            syswrite $sock, "learn $cfg{'statfile'} $len" . $CRLF;
+            syswrite $sock, "learn $cfg{'statfile'} $len -w $cfg{weight}" . $CRLF;
             syswrite $sock, $input . $CRLF;
             if (defined (my $reply = <$sock>)) {
                 if ($reply =~ /^learn ok, sum weight: ([0-9.]+)/) {
@@ -226,6 +226,18 @@ sub do_control_command {
             print "Authentication failed\n";
         }
     }
+    if ($cfg{'command'} =~ /^weights$/i) {
+        die "statfile is not specified to weights command" if !$cfg{'statfile'};
+        
+        
+               my $len = length ($input);
+               print "Sending $len bytes...\n";
+               syswrite $sock, "weights $cfg{'statfile'} $len" . $CRLF;
+               syswrite $sock, $input . $CRLF;
+               if (defined (my $reply = <$sock>)) {
+                       print $_;
+               }
+    }
     elsif ($cfg{'command'} =~ /(reload|shutdown)/i) {
         if (do_ctrl_auth ($sock)) {
             syswrite $sock, $cfg{'command'} . $CRLF;
@@ -540,7 +552,7 @@ else {
     die "unknown command $cmd";
 }
 
-if ($cmd =~ /SYMBOLS|SCAN|PROCESS|CHECK|REPORT_IFSPAM|REPORT|URLS|EMAILS|LEARN|FUZZY_ADD|FUZZY_DEL/i) {
+if ($cmd =~ /SYMBOLS|SCAN|PROCESS|CHECK|REPORT_IFSPAM|REPORT|URLS|EMAILS|LEARN|FUZZY_ADD|FUZZY_DEL|WEIGHTS/i) {
        $cfg{'require_input'} = 1;
 }
 
index 566cf2b75cbaf5d2951440b15564efbc7c9e13f7..219576870d460b1c22f75186b811436b88bbb9c6 100644 (file)
 #include "classifiers.h"
 
 struct classifier               classifiers[] = {
-       {
+               {
                        .name = "winnow",
                        .init_func = winnow_init,
                        .classify_func = winnow_classify,
                        .learn_func = winnow_learn,
-               },
+                       .weights_func = winnow_weights
+               }
 };
 
 struct classifier              *
index 12787f04999ed4b160595a215c7ba84546b51e35..de937bc3f36fbc648ca26482a0e619287afdcb74 100644 (file)
@@ -14,13 +14,20 @@ struct classifier_ctx {
        GHashTable *results;
        struct classifier_config *cfg;
 };
+
+struct classify_weight {
+       const char *name;
+       double weight;
+};
+
 /* Common classifier structure */
 struct classifier {
        char *name;
        struct classifier_ctx* (*init_func)(memory_pool_t *pool, struct classifier_config *cf);
        void (*classify_func)(struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task);
        void (*learn_func)(struct classifier_ctx* ctx, statfile_pool_t *pool, 
-                                                       stat_file_t *file, GTree *input, gboolean in_class, double *sum);
+                                                       stat_file_t *file, GTree *input, gboolean in_class, double *sum, double multiplier);
+       GList* (*weights_func)(struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task);
 };
 
 /* Get classifier structure by name or return NULL if this name is not found */
@@ -29,7 +36,10 @@ struct classifier* get_classifier (char *name);
 /* Winnow algorithm */
 struct classifier_ctx* winnow_init (memory_pool_t *pool, struct classifier_config *cf);
 void winnow_classify (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task);
-void winnow_learn (struct classifier_ctx* ctx, statfile_pool_t *pool, stat_file_t *file, GTree *input, gboolean in_class, double *sum);
+void winnow_learn (struct classifier_ctx* ctx, statfile_pool_t *pool, stat_file_t *file, GTree *input, 
+                               gboolean in_class, double *sum, double multiplier);
+GList *winnow_weights (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task);
+
 
 /* Array of all defined classifiers */
 extern struct classifier classifiers[];
index e103dd50de59d291e64c5c5eac8b164b1b8395df..af370bd3843bda0f156d5d413b82cd81f85750fb 100644 (file)
@@ -42,6 +42,7 @@ struct winnow_callback_data {
        struct classifier_ctx          *ctx;
        stat_file_t                    *file;
        double                          sum;
+       double                          multiplier;
        int                             count;
        int                             in_class;
        time_t                          now;
@@ -77,8 +78,9 @@ learn_callback (gpointer key, gpointer value, gpointer data)
        token_node_t                   *node = key;
        struct winnow_callback_data    *cd = data;
        double                           v, c;
-
+       
        c = (cd->in_class) ? WINNOW_PROMOTION : WINNOW_DEMOTION;
+       c *= cd->multiplier;
 
        /* Consider that not found blocks have value 1 */
        v = statfile_pool_get_block (cd->pool, cd->file, node->h1, node->h2, cd->now);
@@ -195,13 +197,74 @@ winnow_classify (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * inp
        }
 }
 
+GList *
+winnow_weights (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * input, struct worker_task *task)
+{
+       struct winnow_callback_data     data;
+       double                          res = 0.;
+       GList                          *cur, *resl = NULL;
+       struct statfile                *st;
+       struct classify_weight         *w;
+
+       g_assert (pool != NULL);
+       g_assert (ctx != NULL);
+
+       data.pool = pool;
+       data.sum = 0;
+       data.count = 0;
+       data.now = time (NULL);
+       data.ctx = ctx;
+    
+       cur = ctx->cfg->statfiles;
+       while (cur) {
+               st = cur->data;
+               if ((data.file = statfile_pool_is_open (pool, st->path)) == NULL) {
+                       if ((data.file = statfile_pool_open (pool, st->path, st->size, FALSE)) == NULL) {
+                               msg_warn ("cannot open %s, skip it", st->path);
+                               cur = g_list_next (cur);
+                               continue;
+                       }
+               }
+
+               if (data.file != NULL) {
+                       statfile_pool_lock_file (pool, data.file);
+                       g_tree_foreach (input, classify_callback, &data);
+                       statfile_pool_unlock_file (pool, data.file);
+               }
+
+               w = memory_pool_alloc (task->task_pool, sizeof (struct classify_weight));
+               if (data.count != 0) {
+                       res = data.sum / data.count;
+                       w->name = st->symbol;
+                       w->weight = res;
+                       resl = g_list_prepend (resl, w);
+               }
+               else {
+                       res = 0;
+                       w->name = st->symbol;
+                       w->weight = res;
+                       resl = g_list_prepend (resl, w);
+               }
+               cur = g_list_next (cur);
+       }
+       
+       if (resl != NULL) {
+               memory_pool_add_destructor (task->task_pool, (pool_destruct_func)g_list_free, resl);
+       }
+
+       return resl;
+
+}
+
+
 void
-winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, stat_file_t *file, GTree * input, int in_class, double *sum)
+winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, stat_file_t *file, GTree * input, int in_class, double *sum, double multiplier)
 {
        struct winnow_callback_data     data = {
                .file = NULL,
                .sum = 0,
                .count = 0,
+               .multiplier = multiplier
        };
 
        g_assert (pool != NULL);
index d76c35db334f40a995ef78c7b02e694fae3ed130..3e4dc36858c80387d459a667395ddf5b14a96e01 100644 (file)
@@ -50,7 +50,8 @@ enum command_type {
        COMMAND_LEARN,
        COMMAND_HELP,
        COMMAND_COUNTERS,
-       COMMAND_SYNC
+       COMMAND_SYNC,
+       COMMAND_WEIGHTS
 };
 
 struct controller_command {
@@ -74,6 +75,7 @@ static struct controller_command commands[] = {
        {"shutdown", TRUE, COMMAND_SHUTDOWN},
        {"uptime", FALSE, COMMAND_UPTIME},
        {"learn", TRUE, COMMAND_LEARN},
+       {"weights", FALSE, COMMAND_WEIGHTS},
        {"help", FALSE, COMMAND_HELP},
        {"counters", FALSE, COMMAND_COUNTERS},
        {"sync", FALSE, COMMAND_SYNC}
@@ -336,6 +338,8 @@ process_stat_command (struct controller_session *session)
 
        memory_pool_stat (&mem_st);
        r = snprintf (out_buf, sizeof (out_buf), "Messages scanned: %u" CRLF, session->worker->srv->stat->messages_scanned);
+       r += snprintf (out_buf + r, sizeof (out_buf) - r, "Messages treated as spam: %u" CRLF, session->worker->srv->stat->messages_spam);
+       r += snprintf (out_buf + r, sizeof (out_buf) - r, "Messages treated as ham: %u" CRLF, session->worker->srv->stat->messages_ham);
        r += snprintf (out_buf + r, sizeof (out_buf) - r, "Messages learned: %u" CRLF, session->worker->srv->stat->messages_learned);
        r += snprintf (out_buf + r, sizeof (out_buf) - r, "Connections count: %u" CRLF, session->worker->srv->stat->connections_count);
        r += snprintf (out_buf + r, sizeof (out_buf) - r, "Control connections count: %u" CRLF, session->worker->srv->stat->control_connections_count);
@@ -477,7 +481,7 @@ process_command (struct controller_command *cmd, char **cmd_args, struct control
                        }
                        size = strtoul (arg, &err_str, 10);
                        if (err_str && *err_str != '\0') {
-                               msg_debug ("statfile size is invalid: %s", arg);
+                               msg_debug ("message size is invalid: %s", arg);
                                r = snprintf (out_buf, sizeof (out_buf), "learn size is invalid" CRLF);
                                rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE);
                                return;
@@ -495,6 +499,7 @@ process_command (struct controller_command *cmd, char **cmd_args, struct control
 
                        /* By default learn positive */
                        session->in_class = 1;
+                       session->learn_multiplier = 1.;
                        /* Get all arguments */
                        while (*cmd_args++) {
                                arg = *cmd_args;
@@ -521,6 +526,15 @@ process_command (struct controller_command *cmd, char **cmd_args, struct control
                                        case 'n':
                                                session->in_class = 0;
                                                break;
+                                       case 'm':
+                                               arg = *(cmd_args + 1);
+                                               if (!arg || *arg == '\0') {
+                                                       r = snprintf (out_buf, sizeof (out_buf), "recipient is not defined" CRLF);
+                                                       rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE);
+                                                       return;
+                                               }
+                                               session->learn_multiplier = strtod (arg, NULL);
+                                               break;
                                        default:
                                                r = snprintf (out_buf, sizeof (out_buf), "tokenizer is not defined" CRLF);
                                                rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE);
@@ -532,6 +546,42 @@ process_command (struct controller_command *cmd, char **cmd_args, struct control
                        session->state = STATE_LEARN;
                }
                break;
+
+       case COMMAND_WEIGHTS:
+               arg = *cmd_args;
+               if (!arg || *arg == '\0') {
+                       msg_debug ("no statfile specified in weights command");
+                       r = snprintf (out_buf, sizeof (out_buf), "weights command requires two arguments: statfile and message size" CRLF);
+                       rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE);
+                       return;
+               }
+               arg = *(cmd_args + 1);
+               if (arg == NULL || *arg == '\0') {
+                       msg_debug ("no message size specified in weights command");
+                       r = snprintf (out_buf, sizeof (out_buf), "weights command requires two arguments: statfile and message size" CRLF);
+                       rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE);
+                       return;
+               }
+               size = strtoul (arg, &err_str, 10);
+               if (err_str && *err_str != '\0') {
+                       msg_debug ("message size is invalid: %s", arg);
+                       r = snprintf (out_buf, sizeof (out_buf), "message size is invalid" CRLF);
+                       rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE);
+                       return;
+               }
+
+               cl = g_hash_table_lookup (session->cfg->classifiers_symbols, *cmd_args);
+               if (cl == NULL) {
+                       r = snprintf (out_buf, sizeof (out_buf), "statfile %s is not defined" CRLF, *cmd_args);
+                       rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE);
+                       return;
+
+               }
+               session->learn_classifier = cl;
+
+               rspamd_set_dispatcher_policy (session->dispatcher, BUFFER_CHARACTER, size);
+               session->state = STATE_WEIGHTS;
+               break;
        case COMMAND_SYNC:
                if (!process_sync_command (session, cmd_args)) {
                        r = snprintf (out_buf, sizeof (out_buf), "FAIL" CRLF);
@@ -543,7 +593,7 @@ process_command (struct controller_command *cmd, char **cmd_args, struct control
                r = snprintf (out_buf, sizeof (out_buf),
                        "Rspamd CLI commands (* - privilleged command):" CRLF
                        "    help - this help message" CRLF
-                       "(*) learn <statfile> <size> [-r recipient], [-f from] [-n] - learn message to specified statfile" CRLF
+                       "(*) learn <statfile> <size> [-r recipient] [-m multiplier] [-f from] [-n] - learn message to specified statfile" CRLF
                        "    quit - quit CLI session" CRLF
                        "(*) reload - reload rspamd" CRLF
                        "(*) shutdown - shutdown rspamd" CRLF "    stat - show different rspamd stat" CRLF "    counters - show rspamd counters" CRLF "    uptime - rspamd uptime" CRLF);
@@ -627,7 +677,7 @@ controller_read_socket (f_str_t * in, void *arg)
                if (session->state == STATE_COMMAND) {
                        session->state = STATE_REPLY;
                }
-               if (session->state != STATE_LEARN && session->state != STATE_OTHER) {
+               if (session->state != STATE_LEARN && session->state != STATE_WEIGHTS && session->state != STATE_OTHER) {
                        if (!rspamd_dispatcher_write (session->dispatcher, END, sizeof (END) - 1, FALSE, TRUE)) {
                                return FALSE;
                        }
@@ -714,11 +764,92 @@ controller_read_socket (f_str_t * in, void *arg)
                
                /* XXX: remove this awful legacy */
                session->learn_classifier->classifier->learn_func (cls_ctx, session->worker->srv->statfile_pool, 
-                                                                                                                               statfile, tokens, session->in_class, &sum);
+                                                                                                                               statfile, tokens, session->in_class, &sum,
+                                                                                                                               session->learn_multiplier);
                session->worker->srv->stat->messages_learned++;
 
                maybe_write_binlog (session->learn_classifier, st, statfile, tokens);
+               if (st->normalizer != NULL) {
+                       sum = st->normalizer (sum, st->normalizer_data);
+               }
+               
+               free_task (task, FALSE);
+               i = snprintf (out_buf, sizeof (out_buf), "learn ok, sum weight: %.2f" CRLF, sum);
+               if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) {
+                       return FALSE;
+               }
+
+               session->state = STATE_REPLY;
+               break;
+       case STATE_WEIGHTS:
+               session->learn_buf = in;
+               task = construct_task (session->worker);
+
+               task->msg = memory_pool_alloc (task->task_pool, sizeof (f_str_t));
+               task->msg->begin = in->begin;
+               task->msg->len = in->len;
+
+               r = process_message (task);
+               if (r == -1) {
+                       msg_warn ("processing of message failed");
+                       free_task (task, FALSE);
+                       session->state = STATE_REPLY;
+                       r = snprintf (out_buf, sizeof (out_buf), "cannot process message" CRLF);
+                       rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE);
+                       return FALSE;
+               }
+
+               cur = g_list_first (task->text_parts);
+               while (cur) {
+                       part = cur->data;
+                       if (part->is_empty) {
+                               cur = g_list_next (cur);
+                               continue;
+                       }
+
+                       c.begin = part->content->data;
+                       c.len = part->content->len;
+                       if (!session->learn_classifier->tokenizer->tokenize_func (session->learn_classifier->tokenizer, session->session_pool, &c, &tokens)) {
+                               i = snprintf (out_buf, sizeof (out_buf), "weights fail, tokenizer error" CRLF);
+                               free_task (task, FALSE);
+                               if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) {
+                                       return FALSE;
+                               }
+                               session->state = STATE_REPLY;
+                               return TRUE;
+                       }
+                       cur = g_list_next (cur);
+               }
                
+               /* Handle messages without text */
+               if (tokens == NULL) {
+                       i = snprintf (out_buf, sizeof (out_buf), "weights fail, no tokens can be extracted (no text data)" CRLF);
+                       free_task (task, FALSE);
+                       if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) {
+                               return FALSE;
+                       }
+                       session->state = STATE_REPLY;
+                       return TRUE;
+               }
+       
+               
+               /* Init classifier */
+               cls_ctx = session->learn_classifier->classifier->init_func (session->session_pool, session->learn_classifier);
+               
+               cur = session->learn_classifier->classifier->weights_func (cls_ctx, session->worker->srv->statfile_pool, 
+                                                                                                                                       tokens, task);
+               i = 0;
+               struct classify_weight *w;
+
+               while (cur) {
+                       w = cur->data;
+                       i += snprintf (out_buf + i, sizeof (out_buf) - i, "%s: %.2f" CRLF, w->name, w->weight);
+                       cur = g_list_next (cur);
+               }
+               if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) {
+                       return FALSE;
+               }
+
                free_task (task, FALSE);
                i = snprintf (out_buf, sizeof (out_buf), "learn ok, sum weight: %.2f" CRLF, sum);
                if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) {
index 9ad1362f07075fba9309a37f0c42dd49a310b54c..a300820da9323d3196b11f61624951e2cfdefa56 100644 (file)
@@ -489,7 +489,7 @@ process_autolearn (struct statfile *st, struct worker_task *task, GTree * tokens
                                return;
                        }
 
-                       classifier->learn_func (ctx, task->worker->srv->statfile_pool, statfile, tokens, TRUE, NULL);
+                       classifier->learn_func (ctx, task->worker->srv->statfile_pool, statfile, tokens, TRUE, NULL, 1.);
                        maybe_write_binlog (ctx->cfg, st, statfile, tokens);
                }
        }
index cffb215d5a4e9021213585f3a2b5e0db102291b0..0fffe15186b1570c0ccb6cd1a07b2f60fd44888b 100644 (file)
@@ -47,7 +47,7 @@ enum process_type {
        TYPE_WORKER,
        TYPE_CONTROLLER,
        TYPE_LMTP,
-       TYPE_FUZZY,
+       TYPE_FUZZY
 };
 
 
@@ -131,6 +131,7 @@ struct controller_session {
                STATE_QUIT,
                STATE_OTHER,
                STATE_WAIT,
+               STATE_WEIGHTS
        } state;                                                                                                        /**< current session state                                                      */
        int sock;                                                                                                       /**< socket descriptor                                                          */
        /* Access to authorized commands */
@@ -141,6 +142,7 @@ struct controller_session {
        char *learn_from;                                                                                       /**< from address for learning                                          */
        struct classifier_config *learn_classifier;
        char *learn_symbol;                                                                                     /**< symbol to train                                                            */
+       double learn_multiplier;                                                                        /**< multiplier for learning                                            */
        rspamd_io_dispatcher_t *dispatcher;                                                     /**< IO dispatcher object                                                       */
        f_str_t *learn_buf;                                                                                     /**< learn input                                                                        */
        GList *parts;                                                                                           /**< extracted mime parts                                                       */
@@ -165,7 +167,7 @@ struct worker_task {
                WRITE_REPLY,
                WRITE_ERROR,
                WAIT_FILTER,
-               CLOSING_CONNECTION,
+               CLOSING_CONNECTION
        } state;                                                                                                        /**< current session state                                                      */
        size_t content_length;                                                                          /**< length of user's input                                                     */
        enum rspamd_protocol proto;                                                                     /**< protocol (rspamc or spamc)                                         */
index e076408078e71348c30a300a528049cd3acf7e1b..409ef525471875d76704261c132ff2fb42fe1671 100644 (file)
@@ -721,6 +721,13 @@ write_check_reply (struct worker_task *task)
        msg_info ("%s", logbuf);
        rspamd_dispatcher_write (task->dispatcher, CRLF, sizeof (CRLF) - 1, FALSE, TRUE);
 
+       if (default_score >= default_required_score) {
+               task->worker->srv->stat->messages_ham ++;
+       }
+       else {
+               task->worker->srv->stat->messages_ham ++;
+       }
+
        return 0;
 }
 
@@ -778,6 +785,13 @@ write_process_reply (struct worker_task *task)
        rspamd_dispatcher_write (task->dispatcher, outbuf, r, TRUE, FALSE);
        rspamd_dispatcher_write (task->dispatcher, outmsg, strlen (outmsg), FALSE, TRUE);
 
+       if (default_score >= default_required_score) {
+               task->worker->srv->stat->messages_spam ++;
+       }
+       else {
+               task->worker->srv->stat->messages_ham ++;
+       }
+
        memory_pool_add_destructor (task->task_pool, (pool_destruct_func) g_free, outmsg);
 
        return 0;