summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rwxr-xr-xrspamc.pl.in16
-rw-r--r--src/classifiers/classifiers.c5
-rw-r--r--src/classifiers/classifiers.h14
-rw-r--r--src/classifiers/winnow.c67
-rw-r--r--src/controller.c141
-rw-r--r--src/filter.c2
-rw-r--r--src/main.h6
-rw-r--r--src/protocol.c14
8 files changed, 249 insertions, 16 deletions
diff --git a/rspamc.pl.in b/rspamc.pl.in
index 4dddc7e3b..19bcab3f7 100755
--- a/rspamc.pl.in
+++ b/rspamc.pl.in
@@ -211,7 +211,7 @@ sub do_control_command {
if (do_ctrl_auth ($sock)) {
my $len = length ($input);
print "Sending $len bytes...\n";
- syswrite $sock, "learn $cfg{'statfile'} $len" . $CRLF;
+ syswrite $sock, "learn $cfg{'statfile'} $len -w $cfg{weight}" . $CRLF;
syswrite $sock, $input . $CRLF;
if (defined (my $reply = <$sock>)) {
if ($reply =~ /^learn ok, sum weight: ([0-9.]+)/) {
@@ -226,6 +226,18 @@ sub do_control_command {
print "Authentication failed\n";
}
}
+ if ($cfg{'command'} =~ /^weights$/i) {
+ die "statfile is not specified to weights command" if !$cfg{'statfile'};
+
+
+ my $len = length ($input);
+ print "Sending $len bytes...\n";
+ syswrite $sock, "weights $cfg{'statfile'} $len" . $CRLF;
+ syswrite $sock, $input . $CRLF;
+ if (defined (my $reply = <$sock>)) {
+ print $_;
+ }
+ }
elsif ($cfg{'command'} =~ /(reload|shutdown)/i) {
if (do_ctrl_auth ($sock)) {
syswrite $sock, $cfg{'command'} . $CRLF;
@@ -540,7 +552,7 @@ else {
die "unknown command $cmd";
}
-if ($cmd =~ /SYMBOLS|SCAN|PROCESS|CHECK|REPORT_IFSPAM|REPORT|URLS|EMAILS|LEARN|FUZZY_ADD|FUZZY_DEL/i) {
+if ($cmd =~ /SYMBOLS|SCAN|PROCESS|CHECK|REPORT_IFSPAM|REPORT|URLS|EMAILS|LEARN|FUZZY_ADD|FUZZY_DEL|WEIGHTS/i) {
$cfg{'require_input'} = 1;
}
diff --git a/src/classifiers/classifiers.c b/src/classifiers/classifiers.c
index 566cf2b75..219576870 100644
--- a/src/classifiers/classifiers.c
+++ b/src/classifiers/classifiers.c
@@ -30,12 +30,13 @@
#include "classifiers.h"
struct classifier classifiers[] = {
- {
+ {
.name = "winnow",
.init_func = winnow_init,
.classify_func = winnow_classify,
.learn_func = winnow_learn,
- },
+ .weights_func = winnow_weights
+ }
};
struct classifier *
diff --git a/src/classifiers/classifiers.h b/src/classifiers/classifiers.h
index 12787f049..de937bc3f 100644
--- a/src/classifiers/classifiers.h
+++ b/src/classifiers/classifiers.h
@@ -14,13 +14,20 @@ struct classifier_ctx {
GHashTable *results;
struct classifier_config *cfg;
};
+
+struct classify_weight {
+ const char *name;
+ double weight;
+};
+
/* Common classifier structure */
struct classifier {
char *name;
struct classifier_ctx* (*init_func)(memory_pool_t *pool, struct classifier_config *cf);
void (*classify_func)(struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task);
void (*learn_func)(struct classifier_ctx* ctx, statfile_pool_t *pool,
- stat_file_t *file, GTree *input, gboolean in_class, double *sum);
+ stat_file_t *file, GTree *input, gboolean in_class, double *sum, double multiplier);
+ GList* (*weights_func)(struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task);
};
/* Get classifier structure by name or return NULL if this name is not found */
@@ -29,7 +36,10 @@ struct classifier* get_classifier (char *name);
/* Winnow algorithm */
struct classifier_ctx* winnow_init (memory_pool_t *pool, struct classifier_config *cf);
void winnow_classify (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task);
-void winnow_learn (struct classifier_ctx* ctx, statfile_pool_t *pool, stat_file_t *file, GTree *input, gboolean in_class, double *sum);
+void winnow_learn (struct classifier_ctx* ctx, statfile_pool_t *pool, stat_file_t *file, GTree *input,
+ gboolean in_class, double *sum, double multiplier);
+GList *winnow_weights (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task);
+
/* Array of all defined classifiers */
extern struct classifier classifiers[];
diff --git a/src/classifiers/winnow.c b/src/classifiers/winnow.c
index e103dd50d..af370bd38 100644
--- a/src/classifiers/winnow.c
+++ b/src/classifiers/winnow.c
@@ -42,6 +42,7 @@ struct winnow_callback_data {
struct classifier_ctx *ctx;
stat_file_t *file;
double sum;
+ double multiplier;
int count;
int in_class;
time_t now;
@@ -77,8 +78,9 @@ learn_callback (gpointer key, gpointer value, gpointer data)
token_node_t *node = key;
struct winnow_callback_data *cd = data;
double v, c;
-
+
c = (cd->in_class) ? WINNOW_PROMOTION : WINNOW_DEMOTION;
+ c *= cd->multiplier;
/* Consider that not found blocks have value 1 */
v = statfile_pool_get_block (cd->pool, cd->file, node->h1, node->h2, cd->now);
@@ -195,13 +197,74 @@ winnow_classify (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * inp
}
}
+GList *
+winnow_weights (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * input, struct worker_task *task)
+{
+ struct winnow_callback_data data;
+ double res = 0.;
+ GList *cur, *resl = NULL;
+ struct statfile *st;
+ struct classify_weight *w;
+
+ g_assert (pool != NULL);
+ g_assert (ctx != NULL);
+
+ data.pool = pool;
+ data.sum = 0;
+ data.count = 0;
+ data.now = time (NULL);
+ data.ctx = ctx;
+
+ cur = ctx->cfg->statfiles;
+ while (cur) {
+ st = cur->data;
+ if ((data.file = statfile_pool_is_open (pool, st->path)) == NULL) {
+ if ((data.file = statfile_pool_open (pool, st->path, st->size, FALSE)) == NULL) {
+ msg_warn ("cannot open %s, skip it", st->path);
+ cur = g_list_next (cur);
+ continue;
+ }
+ }
+
+ if (data.file != NULL) {
+ statfile_pool_lock_file (pool, data.file);
+ g_tree_foreach (input, classify_callback, &data);
+ statfile_pool_unlock_file (pool, data.file);
+ }
+
+ w = memory_pool_alloc (task->task_pool, sizeof (struct classify_weight));
+ if (data.count != 0) {
+ res = data.sum / data.count;
+ w->name = st->symbol;
+ w->weight = res;
+ resl = g_list_prepend (resl, w);
+ }
+ else {
+ res = 0;
+ w->name = st->symbol;
+ w->weight = res;
+ resl = g_list_prepend (resl, w);
+ }
+ cur = g_list_next (cur);
+ }
+
+ if (resl != NULL) {
+ memory_pool_add_destructor (task->task_pool, (pool_destruct_func)g_list_free, resl);
+ }
+
+ return resl;
+
+}
+
+
void
-winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, stat_file_t *file, GTree * input, int in_class, double *sum)
+winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, stat_file_t *file, GTree * input, int in_class, double *sum, double multiplier)
{
struct winnow_callback_data data = {
.file = NULL,
.sum = 0,
.count = 0,
+ .multiplier = multiplier
};
g_assert (pool != NULL);
diff --git a/src/controller.c b/src/controller.c
index d76c35db3..3e4dc3685 100644
--- a/src/controller.c
+++ b/src/controller.c
@@ -50,7 +50,8 @@ enum command_type {
COMMAND_LEARN,
COMMAND_HELP,
COMMAND_COUNTERS,
- COMMAND_SYNC
+ COMMAND_SYNC,
+ COMMAND_WEIGHTS
};
struct controller_command {
@@ -74,6 +75,7 @@ static struct controller_command commands[] = {
{"shutdown", TRUE, COMMAND_SHUTDOWN},
{"uptime", FALSE, COMMAND_UPTIME},
{"learn", TRUE, COMMAND_LEARN},
+ {"weights", FALSE, COMMAND_WEIGHTS},
{"help", FALSE, COMMAND_HELP},
{"counters", FALSE, COMMAND_COUNTERS},
{"sync", FALSE, COMMAND_SYNC}
@@ -336,6 +338,8 @@ process_stat_command (struct controller_session *session)
memory_pool_stat (&mem_st);
r = snprintf (out_buf, sizeof (out_buf), "Messages scanned: %u" CRLF, session->worker->srv->stat->messages_scanned);
+ r += snprintf (out_buf + r, sizeof (out_buf) - r, "Messages treated as spam: %u" CRLF, session->worker->srv->stat->messages_spam);
+ r += snprintf (out_buf + r, sizeof (out_buf) - r, "Messages treated as ham: %u" CRLF, session->worker->srv->stat->messages_ham);
r += snprintf (out_buf + r, sizeof (out_buf) - r, "Messages learned: %u" CRLF, session->worker->srv->stat->messages_learned);
r += snprintf (out_buf + r, sizeof (out_buf) - r, "Connections count: %u" CRLF, session->worker->srv->stat->connections_count);
r += snprintf (out_buf + r, sizeof (out_buf) - r, "Control connections count: %u" CRLF, session->worker->srv->stat->control_connections_count);
@@ -477,7 +481,7 @@ process_command (struct controller_command *cmd, char **cmd_args, struct control
}
size = strtoul (arg, &err_str, 10);
if (err_str && *err_str != '\0') {
- msg_debug ("statfile size is invalid: %s", arg);
+ msg_debug ("message size is invalid: %s", arg);
r = snprintf (out_buf, sizeof (out_buf), "learn size is invalid" CRLF);
rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE);
return;
@@ -495,6 +499,7 @@ process_command (struct controller_command *cmd, char **cmd_args, struct control
/* By default learn positive */
session->in_class = 1;
+ session->learn_multiplier = 1.;
/* Get all arguments */
while (*cmd_args++) {
arg = *cmd_args;
@@ -521,6 +526,15 @@ process_command (struct controller_command *cmd, char **cmd_args, struct control
case 'n':
session->in_class = 0;
break;
+ case 'm':
+ arg = *(cmd_args + 1);
+ if (!arg || *arg == '\0') {
+ r = snprintf (out_buf, sizeof (out_buf), "recipient is not defined" CRLF);
+ rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE);
+ return;
+ }
+ session->learn_multiplier = strtod (arg, NULL);
+ break;
default:
r = snprintf (out_buf, sizeof (out_buf), "tokenizer is not defined" CRLF);
rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE);
@@ -532,6 +546,42 @@ process_command (struct controller_command *cmd, char **cmd_args, struct control
session->state = STATE_LEARN;
}
break;
+
+ case COMMAND_WEIGHTS:
+ arg = *cmd_args;
+ if (!arg || *arg == '\0') {
+ msg_debug ("no statfile specified in weights command");
+ r = snprintf (out_buf, sizeof (out_buf), "weights command requires two arguments: statfile and message size" CRLF);
+ rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE);
+ return;
+ }
+ arg = *(cmd_args + 1);
+ if (arg == NULL || *arg == '\0') {
+ msg_debug ("no message size specified in weights command");
+ r = snprintf (out_buf, sizeof (out_buf), "weights command requires two arguments: statfile and message size" CRLF);
+ rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE);
+ return;
+ }
+ size = strtoul (arg, &err_str, 10);
+ if (err_str && *err_str != '\0') {
+ msg_debug ("message size is invalid: %s", arg);
+ r = snprintf (out_buf, sizeof (out_buf), "message size is invalid" CRLF);
+ rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE);
+ return;
+ }
+
+ cl = g_hash_table_lookup (session->cfg->classifiers_symbols, *cmd_args);
+ if (cl == NULL) {
+ r = snprintf (out_buf, sizeof (out_buf), "statfile %s is not defined" CRLF, *cmd_args);
+ rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE);
+ return;
+
+ }
+ session->learn_classifier = cl;
+
+ rspamd_set_dispatcher_policy (session->dispatcher, BUFFER_CHARACTER, size);
+ session->state = STATE_WEIGHTS;
+ break;
case COMMAND_SYNC:
if (!process_sync_command (session, cmd_args)) {
r = snprintf (out_buf, sizeof (out_buf), "FAIL" CRLF);
@@ -543,7 +593,7 @@ process_command (struct controller_command *cmd, char **cmd_args, struct control
r = snprintf (out_buf, sizeof (out_buf),
"Rspamd CLI commands (* - privilleged command):" CRLF
" help - this help message" CRLF
- "(*) learn <statfile> <size> [-r recipient], [-f from] [-n] - learn message to specified statfile" CRLF
+ "(*) learn <statfile> <size> [-r recipient] [-m multiplier] [-f from] [-n] - learn message to specified statfile" CRLF
" quit - quit CLI session" CRLF
"(*) reload - reload rspamd" CRLF
"(*) shutdown - shutdown rspamd" CRLF " stat - show different rspamd stat" CRLF " counters - show rspamd counters" CRLF " uptime - rspamd uptime" CRLF);
@@ -627,7 +677,7 @@ controller_read_socket (f_str_t * in, void *arg)
if (session->state == STATE_COMMAND) {
session->state = STATE_REPLY;
}
- if (session->state != STATE_LEARN && session->state != STATE_OTHER) {
+ if (session->state != STATE_LEARN && session->state != STATE_WEIGHTS && session->state != STATE_OTHER) {
if (!rspamd_dispatcher_write (session->dispatcher, END, sizeof (END) - 1, FALSE, TRUE)) {
return FALSE;
}
@@ -714,11 +764,92 @@ controller_read_socket (f_str_t * in, void *arg)
/* XXX: remove this awful legacy */
session->learn_classifier->classifier->learn_func (cls_ctx, session->worker->srv->statfile_pool,
- statfile, tokens, session->in_class, &sum);
+ statfile, tokens, session->in_class, &sum,
+ session->learn_multiplier);
session->worker->srv->stat->messages_learned++;
maybe_write_binlog (session->learn_classifier, st, statfile, tokens);
+ if (st->normalizer != NULL) {
+ sum = st->normalizer (sum, st->normalizer_data);
+ }
+
+ free_task (task, FALSE);
+ i = snprintf (out_buf, sizeof (out_buf), "learn ok, sum weight: %.2f" CRLF, sum);
+ if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) {
+ return FALSE;
+ }
+
+ session->state = STATE_REPLY;
+ break;
+ case STATE_WEIGHTS:
+ session->learn_buf = in;
+ task = construct_task (session->worker);
+
+ task->msg = memory_pool_alloc (task->task_pool, sizeof (f_str_t));
+ task->msg->begin = in->begin;
+ task->msg->len = in->len;
+
+ r = process_message (task);
+ if (r == -1) {
+ msg_warn ("processing of message failed");
+ free_task (task, FALSE);
+ session->state = STATE_REPLY;
+ r = snprintf (out_buf, sizeof (out_buf), "cannot process message" CRLF);
+ rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE);
+ return FALSE;
+ }
+
+ cur = g_list_first (task->text_parts);
+ while (cur) {
+ part = cur->data;
+ if (part->is_empty) {
+ cur = g_list_next (cur);
+ continue;
+ }
+
+ c.begin = part->content->data;
+ c.len = part->content->len;
+ if (!session->learn_classifier->tokenizer->tokenize_func (session->learn_classifier->tokenizer, session->session_pool, &c, &tokens)) {
+ i = snprintf (out_buf, sizeof (out_buf), "weights fail, tokenizer error" CRLF);
+ free_task (task, FALSE);
+ if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) {
+ return FALSE;
+ }
+ session->state = STATE_REPLY;
+ return TRUE;
+ }
+ cur = g_list_next (cur);
+ }
+ /* Handle messages without text */
+ if (tokens == NULL) {
+ i = snprintf (out_buf, sizeof (out_buf), "weights fail, no tokens can be extracted (no text data)" CRLF);
+ free_task (task, FALSE);
+ if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) {
+ return FALSE;
+ }
+ session->state = STATE_REPLY;
+ return TRUE;
+ }
+
+
+ /* Init classifier */
+ cls_ctx = session->learn_classifier->classifier->init_func (session->session_pool, session->learn_classifier);
+
+ cur = session->learn_classifier->classifier->weights_func (cls_ctx, session->worker->srv->statfile_pool,
+ tokens, task);
+ i = 0;
+ struct classify_weight *w;
+
+ while (cur) {
+ w = cur->data;
+ i += snprintf (out_buf + i, sizeof (out_buf) - i, "%s: %.2f" CRLF, w->name, w->weight);
+ cur = g_list_next (cur);
+ }
+ if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) {
+ return FALSE;
+ }
+
free_task (task, FALSE);
i = snprintf (out_buf, sizeof (out_buf), "learn ok, sum weight: %.2f" CRLF, sum);
if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) {
diff --git a/src/filter.c b/src/filter.c
index 9ad1362f0..a300820da 100644
--- a/src/filter.c
+++ b/src/filter.c
@@ -489,7 +489,7 @@ process_autolearn (struct statfile *st, struct worker_task *task, GTree * tokens
return;
}
- classifier->learn_func (ctx, task->worker->srv->statfile_pool, statfile, tokens, TRUE, NULL);
+ classifier->learn_func (ctx, task->worker->srv->statfile_pool, statfile, tokens, TRUE, NULL, 1.);
maybe_write_binlog (ctx->cfg, st, statfile, tokens);
}
}
diff --git a/src/main.h b/src/main.h
index cffb215d5..0fffe1518 100644
--- a/src/main.h
+++ b/src/main.h
@@ -47,7 +47,7 @@ enum process_type {
TYPE_WORKER,
TYPE_CONTROLLER,
TYPE_LMTP,
- TYPE_FUZZY,
+ TYPE_FUZZY
};
@@ -131,6 +131,7 @@ struct controller_session {
STATE_QUIT,
STATE_OTHER,
STATE_WAIT,
+ STATE_WEIGHTS
} state; /**< current session state */
int sock; /**< socket descriptor */
/* Access to authorized commands */
@@ -141,6 +142,7 @@ struct controller_session {
char *learn_from; /**< from address for learning */
struct classifier_config *learn_classifier;
char *learn_symbol; /**< symbol to train */
+ double learn_multiplier; /**< multiplier for learning */
rspamd_io_dispatcher_t *dispatcher; /**< IO dispatcher object */
f_str_t *learn_buf; /**< learn input */
GList *parts; /**< extracted mime parts */
@@ -165,7 +167,7 @@ struct worker_task {
WRITE_REPLY,
WRITE_ERROR,
WAIT_FILTER,
- CLOSING_CONNECTION,
+ CLOSING_CONNECTION
} state; /**< current session state */
size_t content_length; /**< length of user's input */
enum rspamd_protocol proto; /**< protocol (rspamc or spamc) */
diff --git a/src/protocol.c b/src/protocol.c
index e07640807..409ef5254 100644
--- a/src/protocol.c
+++ b/src/protocol.c
@@ -721,6 +721,13 @@ write_check_reply (struct worker_task *task)
msg_info ("%s", logbuf);
rspamd_dispatcher_write (task->dispatcher, CRLF, sizeof (CRLF) - 1, FALSE, TRUE);
+ if (default_score >= default_required_score) {
+ task->worker->srv->stat->messages_ham ++;
+ }
+ else {
+ task->worker->srv->stat->messages_ham ++;
+ }
+
return 0;
}
@@ -778,6 +785,13 @@ write_process_reply (struct worker_task *task)
rspamd_dispatcher_write (task->dispatcher, outbuf, r, TRUE, FALSE);
rspamd_dispatcher_write (task->dispatcher, outmsg, strlen (outmsg), FALSE, TRUE);
+ if (default_score >= default_required_score) {
+ task->worker->srv->stat->messages_spam ++;
+ }
+ else {
+ task->worker->srv->stat->messages_ham ++;
+ }
+
memory_pool_add_destructor (task->task_pool, (pool_destruct_func) g_free, outmsg);
return 0;