From 74cf00015278784d04d26b44bcf326f9493f7d62 Mon Sep 17 00:00:00 2001 From: "cebka@lenovo-laptop" Date: Mon, 1 Mar 2010 18:37:06 +0300 Subject: [PATCH] * Add weights command for getting weights of each message by each statfile * Add ability to specify multiplier when learning * Add statistics about spam and ham messages --- rspamc.pl.in | 16 +++- src/classifiers/classifiers.c | 5 +- src/classifiers/classifiers.h | 14 +++- src/classifiers/winnow.c | 67 +++++++++++++++- src/controller.c | 141 ++++++++++++++++++++++++++++++++-- src/filter.c | 2 +- src/main.h | 6 +- src/protocol.c | 14 ++++ 8 files changed, 249 insertions(+), 16 deletions(-) diff --git a/rspamc.pl.in b/rspamc.pl.in index 4dddc7e3b..19bcab3f7 100755 --- a/rspamc.pl.in +++ b/rspamc.pl.in @@ -211,7 +211,7 @@ sub do_control_command { if (do_ctrl_auth ($sock)) { my $len = length ($input); print "Sending $len bytes...\n"; - syswrite $sock, "learn $cfg{'statfile'} $len" . $CRLF; + syswrite $sock, "learn $cfg{'statfile'} $len -w $cfg{weight}" . $CRLF; syswrite $sock, $input . $CRLF; if (defined (my $reply = <$sock>)) { if ($reply =~ /^learn ok, sum weight: ([0-9.]+)/) { @@ -226,6 +226,18 @@ sub do_control_command { print "Authentication failed\n"; } } + if ($cfg{'command'} =~ /^weights$/i) { + die "statfile is not specified to weights command" if !$cfg{'statfile'}; + + + my $len = length ($input); + print "Sending $len bytes...\n"; + syswrite $sock, "weights $cfg{'statfile'} $len" . $CRLF; + syswrite $sock, $input . $CRLF; + if (defined (my $reply = <$sock>)) { + print $_; + } + } elsif ($cfg{'command'} =~ /(reload|shutdown)/i) { if (do_ctrl_auth ($sock)) { syswrite $sock, $cfg{'command'} . $CRLF; @@ -540,7 +552,7 @@ else { die "unknown command $cmd"; } -if ($cmd =~ /SYMBOLS|SCAN|PROCESS|CHECK|REPORT_IFSPAM|REPORT|URLS|EMAILS|LEARN|FUZZY_ADD|FUZZY_DEL/i) { +if ($cmd =~ /SYMBOLS|SCAN|PROCESS|CHECK|REPORT_IFSPAM|REPORT|URLS|EMAILS|LEARN|FUZZY_ADD|FUZZY_DEL|WEIGHTS/i) { $cfg{'require_input'} = 1; } diff --git a/src/classifiers/classifiers.c b/src/classifiers/classifiers.c index 566cf2b75..219576870 100644 --- a/src/classifiers/classifiers.c +++ b/src/classifiers/classifiers.c @@ -30,12 +30,13 @@ #include "classifiers.h" struct classifier classifiers[] = { - { + { .name = "winnow", .init_func = winnow_init, .classify_func = winnow_classify, .learn_func = winnow_learn, - }, + .weights_func = winnow_weights + } }; struct classifier * diff --git a/src/classifiers/classifiers.h b/src/classifiers/classifiers.h index 12787f049..de937bc3f 100644 --- a/src/classifiers/classifiers.h +++ b/src/classifiers/classifiers.h @@ -14,13 +14,20 @@ struct classifier_ctx { GHashTable *results; struct classifier_config *cfg; }; + +struct classify_weight { + const char *name; + double weight; +}; + /* Common classifier structure */ struct classifier { char *name; struct classifier_ctx* (*init_func)(memory_pool_t *pool, struct classifier_config *cf); void (*classify_func)(struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task); void (*learn_func)(struct classifier_ctx* ctx, statfile_pool_t *pool, - stat_file_t *file, GTree *input, gboolean in_class, double *sum); + stat_file_t *file, GTree *input, gboolean in_class, double *sum, double multiplier); + GList* (*weights_func)(struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task); }; /* Get classifier structure by name or return NULL if this name is not found */ @@ -29,7 +36,10 @@ struct classifier* get_classifier (char *name); /* Winnow algorithm */ struct classifier_ctx* winnow_init (memory_pool_t *pool, struct classifier_config *cf); void winnow_classify (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task); -void winnow_learn (struct classifier_ctx* ctx, statfile_pool_t *pool, stat_file_t *file, GTree *input, gboolean in_class, double *sum); +void winnow_learn (struct classifier_ctx* ctx, statfile_pool_t *pool, stat_file_t *file, GTree *input, + gboolean in_class, double *sum, double multiplier); +GList *winnow_weights (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task); + /* Array of all defined classifiers */ extern struct classifier classifiers[]; diff --git a/src/classifiers/winnow.c b/src/classifiers/winnow.c index e103dd50d..af370bd38 100644 --- a/src/classifiers/winnow.c +++ b/src/classifiers/winnow.c @@ -42,6 +42,7 @@ struct winnow_callback_data { struct classifier_ctx *ctx; stat_file_t *file; double sum; + double multiplier; int count; int in_class; time_t now; @@ -77,8 +78,9 @@ learn_callback (gpointer key, gpointer value, gpointer data) token_node_t *node = key; struct winnow_callback_data *cd = data; double v, c; - + c = (cd->in_class) ? WINNOW_PROMOTION : WINNOW_DEMOTION; + c *= cd->multiplier; /* Consider that not found blocks have value 1 */ v = statfile_pool_get_block (cd->pool, cd->file, node->h1, node->h2, cd->now); @@ -195,13 +197,74 @@ winnow_classify (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * inp } } +GList * +winnow_weights (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * input, struct worker_task *task) +{ + struct winnow_callback_data data; + double res = 0.; + GList *cur, *resl = NULL; + struct statfile *st; + struct classify_weight *w; + + g_assert (pool != NULL); + g_assert (ctx != NULL); + + data.pool = pool; + data.sum = 0; + data.count = 0; + data.now = time (NULL); + data.ctx = ctx; + + cur = ctx->cfg->statfiles; + while (cur) { + st = cur->data; + if ((data.file = statfile_pool_is_open (pool, st->path)) == NULL) { + if ((data.file = statfile_pool_open (pool, st->path, st->size, FALSE)) == NULL) { + msg_warn ("cannot open %s, skip it", st->path); + cur = g_list_next (cur); + continue; + } + } + + if (data.file != NULL) { + statfile_pool_lock_file (pool, data.file); + g_tree_foreach (input, classify_callback, &data); + statfile_pool_unlock_file (pool, data.file); + } + + w = memory_pool_alloc (task->task_pool, sizeof (struct classify_weight)); + if (data.count != 0) { + res = data.sum / data.count; + w->name = st->symbol; + w->weight = res; + resl = g_list_prepend (resl, w); + } + else { + res = 0; + w->name = st->symbol; + w->weight = res; + resl = g_list_prepend (resl, w); + } + cur = g_list_next (cur); + } + + if (resl != NULL) { + memory_pool_add_destructor (task->task_pool, (pool_destruct_func)g_list_free, resl); + } + + return resl; + +} + + void -winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, stat_file_t *file, GTree * input, int in_class, double *sum) +winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, stat_file_t *file, GTree * input, int in_class, double *sum, double multiplier) { struct winnow_callback_data data = { .file = NULL, .sum = 0, .count = 0, + .multiplier = multiplier }; g_assert (pool != NULL); diff --git a/src/controller.c b/src/controller.c index d76c35db3..3e4dc3685 100644 --- a/src/controller.c +++ b/src/controller.c @@ -50,7 +50,8 @@ enum command_type { COMMAND_LEARN, COMMAND_HELP, COMMAND_COUNTERS, - COMMAND_SYNC + COMMAND_SYNC, + COMMAND_WEIGHTS }; struct controller_command { @@ -74,6 +75,7 @@ static struct controller_command commands[] = { {"shutdown", TRUE, COMMAND_SHUTDOWN}, {"uptime", FALSE, COMMAND_UPTIME}, {"learn", TRUE, COMMAND_LEARN}, + {"weights", FALSE, COMMAND_WEIGHTS}, {"help", FALSE, COMMAND_HELP}, {"counters", FALSE, COMMAND_COUNTERS}, {"sync", FALSE, COMMAND_SYNC} @@ -336,6 +338,8 @@ process_stat_command (struct controller_session *session) memory_pool_stat (&mem_st); r = snprintf (out_buf, sizeof (out_buf), "Messages scanned: %u" CRLF, session->worker->srv->stat->messages_scanned); + r += snprintf (out_buf + r, sizeof (out_buf) - r, "Messages treated as spam: %u" CRLF, session->worker->srv->stat->messages_spam); + r += snprintf (out_buf + r, sizeof (out_buf) - r, "Messages treated as ham: %u" CRLF, session->worker->srv->stat->messages_ham); r += snprintf (out_buf + r, sizeof (out_buf) - r, "Messages learned: %u" CRLF, session->worker->srv->stat->messages_learned); r += snprintf (out_buf + r, sizeof (out_buf) - r, "Connections count: %u" CRLF, session->worker->srv->stat->connections_count); r += snprintf (out_buf + r, sizeof (out_buf) - r, "Control connections count: %u" CRLF, session->worker->srv->stat->control_connections_count); @@ -477,7 +481,7 @@ process_command (struct controller_command *cmd, char **cmd_args, struct control } size = strtoul (arg, &err_str, 10); if (err_str && *err_str != '\0') { - msg_debug ("statfile size is invalid: %s", arg); + msg_debug ("message size is invalid: %s", arg); r = snprintf (out_buf, sizeof (out_buf), "learn size is invalid" CRLF); rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE); return; @@ -495,6 +499,7 @@ process_command (struct controller_command *cmd, char **cmd_args, struct control /* By default learn positive */ session->in_class = 1; + session->learn_multiplier = 1.; /* Get all arguments */ while (*cmd_args++) { arg = *cmd_args; @@ -521,6 +526,15 @@ process_command (struct controller_command *cmd, char **cmd_args, struct control case 'n': session->in_class = 0; break; + case 'm': + arg = *(cmd_args + 1); + if (!arg || *arg == '\0') { + r = snprintf (out_buf, sizeof (out_buf), "recipient is not defined" CRLF); + rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE); + return; + } + session->learn_multiplier = strtod (arg, NULL); + break; default: r = snprintf (out_buf, sizeof (out_buf), "tokenizer is not defined" CRLF); rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE); @@ -532,6 +546,42 @@ process_command (struct controller_command *cmd, char **cmd_args, struct control session->state = STATE_LEARN; } break; + + case COMMAND_WEIGHTS: + arg = *cmd_args; + if (!arg || *arg == '\0') { + msg_debug ("no statfile specified in weights command"); + r = snprintf (out_buf, sizeof (out_buf), "weights command requires two arguments: statfile and message size" CRLF); + rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE); + return; + } + arg = *(cmd_args + 1); + if (arg == NULL || *arg == '\0') { + msg_debug ("no message size specified in weights command"); + r = snprintf (out_buf, sizeof (out_buf), "weights command requires two arguments: statfile and message size" CRLF); + rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE); + return; + } + size = strtoul (arg, &err_str, 10); + if (err_str && *err_str != '\0') { + msg_debug ("message size is invalid: %s", arg); + r = snprintf (out_buf, sizeof (out_buf), "message size is invalid" CRLF); + rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE); + return; + } + + cl = g_hash_table_lookup (session->cfg->classifiers_symbols, *cmd_args); + if (cl == NULL) { + r = snprintf (out_buf, sizeof (out_buf), "statfile %s is not defined" CRLF, *cmd_args); + rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE); + return; + + } + session->learn_classifier = cl; + + rspamd_set_dispatcher_policy (session->dispatcher, BUFFER_CHARACTER, size); + session->state = STATE_WEIGHTS; + break; case COMMAND_SYNC: if (!process_sync_command (session, cmd_args)) { r = snprintf (out_buf, sizeof (out_buf), "FAIL" CRLF); @@ -543,7 +593,7 @@ process_command (struct controller_command *cmd, char **cmd_args, struct control r = snprintf (out_buf, sizeof (out_buf), "Rspamd CLI commands (* - privilleged command):" CRLF " help - this help message" CRLF - "(*) learn [-r recipient], [-f from] [-n] - learn message to specified statfile" CRLF + "(*) learn [-r recipient] [-m multiplier] [-f from] [-n] - learn message to specified statfile" CRLF " quit - quit CLI session" CRLF "(*) reload - reload rspamd" CRLF "(*) shutdown - shutdown rspamd" CRLF " stat - show different rspamd stat" CRLF " counters - show rspamd counters" CRLF " uptime - rspamd uptime" CRLF); @@ -627,7 +677,7 @@ controller_read_socket (f_str_t * in, void *arg) if (session->state == STATE_COMMAND) { session->state = STATE_REPLY; } - if (session->state != STATE_LEARN && session->state != STATE_OTHER) { + if (session->state != STATE_LEARN && session->state != STATE_WEIGHTS && session->state != STATE_OTHER) { if (!rspamd_dispatcher_write (session->dispatcher, END, sizeof (END) - 1, FALSE, TRUE)) { return FALSE; } @@ -714,11 +764,92 @@ controller_read_socket (f_str_t * in, void *arg) /* XXX: remove this awful legacy */ session->learn_classifier->classifier->learn_func (cls_ctx, session->worker->srv->statfile_pool, - statfile, tokens, session->in_class, &sum); + statfile, tokens, session->in_class, &sum, + session->learn_multiplier); session->worker->srv->stat->messages_learned++; maybe_write_binlog (session->learn_classifier, st, statfile, tokens); + if (st->normalizer != NULL) { + sum = st->normalizer (sum, st->normalizer_data); + } + + free_task (task, FALSE); + i = snprintf (out_buf, sizeof (out_buf), "learn ok, sum weight: %.2f" CRLF, sum); + if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) { + return FALSE; + } + + session->state = STATE_REPLY; + break; + case STATE_WEIGHTS: + session->learn_buf = in; + task = construct_task (session->worker); + + task->msg = memory_pool_alloc (task->task_pool, sizeof (f_str_t)); + task->msg->begin = in->begin; + task->msg->len = in->len; + + r = process_message (task); + if (r == -1) { + msg_warn ("processing of message failed"); + free_task (task, FALSE); + session->state = STATE_REPLY; + r = snprintf (out_buf, sizeof (out_buf), "cannot process message" CRLF); + rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE); + return FALSE; + } + + cur = g_list_first (task->text_parts); + while (cur) { + part = cur->data; + if (part->is_empty) { + cur = g_list_next (cur); + continue; + } + + c.begin = part->content->data; + c.len = part->content->len; + if (!session->learn_classifier->tokenizer->tokenize_func (session->learn_classifier->tokenizer, session->session_pool, &c, &tokens)) { + i = snprintf (out_buf, sizeof (out_buf), "weights fail, tokenizer error" CRLF); + free_task (task, FALSE); + if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) { + return FALSE; + } + session->state = STATE_REPLY; + return TRUE; + } + cur = g_list_next (cur); + } + /* Handle messages without text */ + if (tokens == NULL) { + i = snprintf (out_buf, sizeof (out_buf), "weights fail, no tokens can be extracted (no text data)" CRLF); + free_task (task, FALSE); + if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) { + return FALSE; + } + session->state = STATE_REPLY; + return TRUE; + } + + + /* Init classifier */ + cls_ctx = session->learn_classifier->classifier->init_func (session->session_pool, session->learn_classifier); + + cur = session->learn_classifier->classifier->weights_func (cls_ctx, session->worker->srv->statfile_pool, + tokens, task); + i = 0; + struct classify_weight *w; + + while (cur) { + w = cur->data; + i += snprintf (out_buf + i, sizeof (out_buf) - i, "%s: %.2f" CRLF, w->name, w->weight); + cur = g_list_next (cur); + } + if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) { + return FALSE; + } + free_task (task, FALSE); i = snprintf (out_buf, sizeof (out_buf), "learn ok, sum weight: %.2f" CRLF, sum); if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) { diff --git a/src/filter.c b/src/filter.c index 9ad1362f0..a300820da 100644 --- a/src/filter.c +++ b/src/filter.c @@ -489,7 +489,7 @@ process_autolearn (struct statfile *st, struct worker_task *task, GTree * tokens return; } - classifier->learn_func (ctx, task->worker->srv->statfile_pool, statfile, tokens, TRUE, NULL); + classifier->learn_func (ctx, task->worker->srv->statfile_pool, statfile, tokens, TRUE, NULL, 1.); maybe_write_binlog (ctx->cfg, st, statfile, tokens); } } diff --git a/src/main.h b/src/main.h index cffb215d5..0fffe1518 100644 --- a/src/main.h +++ b/src/main.h @@ -47,7 +47,7 @@ enum process_type { TYPE_WORKER, TYPE_CONTROLLER, TYPE_LMTP, - TYPE_FUZZY, + TYPE_FUZZY }; @@ -131,6 +131,7 @@ struct controller_session { STATE_QUIT, STATE_OTHER, STATE_WAIT, + STATE_WEIGHTS } state; /**< current session state */ int sock; /**< socket descriptor */ /* Access to authorized commands */ @@ -141,6 +142,7 @@ struct controller_session { char *learn_from; /**< from address for learning */ struct classifier_config *learn_classifier; char *learn_symbol; /**< symbol to train */ + double learn_multiplier; /**< multiplier for learning */ rspamd_io_dispatcher_t *dispatcher; /**< IO dispatcher object */ f_str_t *learn_buf; /**< learn input */ GList *parts; /**< extracted mime parts */ @@ -165,7 +167,7 @@ struct worker_task { WRITE_REPLY, WRITE_ERROR, WAIT_FILTER, - CLOSING_CONNECTION, + CLOSING_CONNECTION } state; /**< current session state */ size_t content_length; /**< length of user's input */ enum rspamd_protocol proto; /**< protocol (rspamc or spamc) */ diff --git a/src/protocol.c b/src/protocol.c index e07640807..409ef5254 100644 --- a/src/protocol.c +++ b/src/protocol.c @@ -721,6 +721,13 @@ write_check_reply (struct worker_task *task) msg_info ("%s", logbuf); rspamd_dispatcher_write (task->dispatcher, CRLF, sizeof (CRLF) - 1, FALSE, TRUE); + if (default_score >= default_required_score) { + task->worker->srv->stat->messages_ham ++; + } + else { + task->worker->srv->stat->messages_ham ++; + } + return 0; } @@ -778,6 +785,13 @@ write_process_reply (struct worker_task *task) rspamd_dispatcher_write (task->dispatcher, outbuf, r, TRUE, FALSE); rspamd_dispatcher_write (task->dispatcher, outmsg, strlen (outmsg), FALSE, TRUE); + if (default_score >= default_required_score) { + task->worker->srv->stat->messages_spam ++; + } + else { + task->worker->srv->stat->messages_ham ++; + } + memory_pool_add_destructor (task->task_pool, (pool_destruct_func) g_free, outmsg); return 0; -- 2.39.5