diff options
Diffstat (limited to 'src/controller.c')
-rw-r--r-- | src/controller.c | 141 |
1 file changed, 136 insertions, 5 deletions
diff --git a/src/controller.c b/src/controller.c index d76c35db3..3e4dc3685 100644 --- a/src/controller.c +++ b/src/controller.c @@ -50,7 +50,8 @@ enum command_type { COMMAND_LEARN, COMMAND_HELP, COMMAND_COUNTERS, - COMMAND_SYNC + COMMAND_SYNC, + COMMAND_WEIGHTS }; struct controller_command { @@ -74,6 +75,7 @@ static struct controller_command commands[] = { {"shutdown", TRUE, COMMAND_SHUTDOWN}, {"uptime", FALSE, COMMAND_UPTIME}, {"learn", TRUE, COMMAND_LEARN}, + {"weights", FALSE, COMMAND_WEIGHTS}, {"help", FALSE, COMMAND_HELP}, {"counters", FALSE, COMMAND_COUNTERS}, {"sync", FALSE, COMMAND_SYNC} @@ -336,6 +338,8 @@ process_stat_command (struct controller_session *session) memory_pool_stat (&mem_st); r = snprintf (out_buf, sizeof (out_buf), "Messages scanned: %u" CRLF, session->worker->srv->stat->messages_scanned); + r += snprintf (out_buf + r, sizeof (out_buf) - r, "Messages treated as spam: %u" CRLF, session->worker->srv->stat->messages_spam); + r += snprintf (out_buf + r, sizeof (out_buf) - r, "Messages treated as ham: %u" CRLF, session->worker->srv->stat->messages_ham); r += snprintf (out_buf + r, sizeof (out_buf) - r, "Messages learned: %u" CRLF, session->worker->srv->stat->messages_learned); r += snprintf (out_buf + r, sizeof (out_buf) - r, "Connections count: %u" CRLF, session->worker->srv->stat->connections_count); r += snprintf (out_buf + r, sizeof (out_buf) - r, "Control connections count: %u" CRLF, session->worker->srv->stat->control_connections_count); @@ -477,7 +481,7 @@ process_command (struct controller_command *cmd, char **cmd_args, struct control } size = strtoul (arg, &err_str, 10); if (err_str && *err_str != '\0') { - msg_debug ("statfile size is invalid: %s", arg); + msg_debug ("message size is invalid: %s", arg); r = snprintf (out_buf, sizeof (out_buf), "learn size is invalid" CRLF); rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE); return; @@ -495,6 +499,7 @@ process_command (struct controller_command *cmd, char 
**cmd_args, struct control /* By default learn positive */ session->in_class = 1; + session->learn_multiplier = 1.; /* Get all arguments */ while (*cmd_args++) { arg = *cmd_args; @@ -521,6 +526,15 @@ process_command (struct controller_command *cmd, char **cmd_args, struct control case 'n': session->in_class = 0; break; + case 'm': + arg = *(cmd_args + 1); + if (!arg || *arg == '\0') { + r = snprintf (out_buf, sizeof (out_buf), "recipient is not defined" CRLF); + rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE); + return; + } + session->learn_multiplier = strtod (arg, NULL); + break; default: r = snprintf (out_buf, sizeof (out_buf), "tokenizer is not defined" CRLF); rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE); @@ -532,6 +546,42 @@ process_command (struct controller_command *cmd, char **cmd_args, struct control session->state = STATE_LEARN; } break; + + case COMMAND_WEIGHTS: + arg = *cmd_args; + if (!arg || *arg == '\0') { + msg_debug ("no statfile specified in weights command"); + r = snprintf (out_buf, sizeof (out_buf), "weights command requires two arguments: statfile and message size" CRLF); + rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE); + return; + } + arg = *(cmd_args + 1); + if (arg == NULL || *arg == '\0') { + msg_debug ("no message size specified in weights command"); + r = snprintf (out_buf, sizeof (out_buf), "weights command requires two arguments: statfile and message size" CRLF); + rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE); + return; + } + size = strtoul (arg, &err_str, 10); + if (err_str && *err_str != '\0') { + msg_debug ("message size is invalid: %s", arg); + r = snprintf (out_buf, sizeof (out_buf), "message size is invalid" CRLF); + rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE); + return; + } + + cl = g_hash_table_lookup (session->cfg->classifiers_symbols, *cmd_args); + if (cl == NULL) { + r = snprintf (out_buf, 
sizeof (out_buf), "statfile %s is not defined" CRLF, *cmd_args); + rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE); + return; + + } + session->learn_classifier = cl; + + rspamd_set_dispatcher_policy (session->dispatcher, BUFFER_CHARACTER, size); + session->state = STATE_WEIGHTS; + break; case COMMAND_SYNC: if (!process_sync_command (session, cmd_args)) { r = snprintf (out_buf, sizeof (out_buf), "FAIL" CRLF); @@ -543,7 +593,7 @@ process_command (struct controller_command *cmd, char **cmd_args, struct control r = snprintf (out_buf, sizeof (out_buf), "Rspamd CLI commands (* - privilleged command):" CRLF " help - this help message" CRLF - "(*) learn <statfile> <size> [-r recipient], [-f from] [-n] - learn message to specified statfile" CRLF + "(*) learn <statfile> <size> [-r recipient] [-m multiplier] [-f from] [-n] - learn message to specified statfile" CRLF " quit - quit CLI session" CRLF "(*) reload - reload rspamd" CRLF "(*) shutdown - shutdown rspamd" CRLF " stat - show different rspamd stat" CRLF " counters - show rspamd counters" CRLF " uptime - rspamd uptime" CRLF); @@ -627,7 +677,7 @@ controller_read_socket (f_str_t * in, void *arg) if (session->state == STATE_COMMAND) { session->state = STATE_REPLY; } - if (session->state != STATE_LEARN && session->state != STATE_OTHER) { + if (session->state != STATE_LEARN && session->state != STATE_WEIGHTS && session->state != STATE_OTHER) { if (!rspamd_dispatcher_write (session->dispatcher, END, sizeof (END) - 1, FALSE, TRUE)) { return FALSE; } @@ -714,11 +764,92 @@ controller_read_socket (f_str_t * in, void *arg) /* XXX: remove this awful legacy */ session->learn_classifier->classifier->learn_func (cls_ctx, session->worker->srv->statfile_pool, - statfile, tokens, session->in_class, &sum); + statfile, tokens, session->in_class, &sum, + session->learn_multiplier); session->worker->srv->stat->messages_learned++; maybe_write_binlog (session->learn_classifier, st, statfile, tokens); + if (st->normalizer 
!= NULL) { + sum = st->normalizer (sum, st->normalizer_data); + } + + free_task (task, FALSE); + i = snprintf (out_buf, sizeof (out_buf), "learn ok, sum weight: %.2f" CRLF, sum); + if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) { + return FALSE; + } + + session->state = STATE_REPLY; + break; + case STATE_WEIGHTS: + session->learn_buf = in; + task = construct_task (session->worker); + + task->msg = memory_pool_alloc (task->task_pool, sizeof (f_str_t)); + task->msg->begin = in->begin; + task->msg->len = in->len; + + r = process_message (task); + if (r == -1) { + msg_warn ("processing of message failed"); + free_task (task, FALSE); + session->state = STATE_REPLY; + r = snprintf (out_buf, sizeof (out_buf), "cannot process message" CRLF); + rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE); + return FALSE; + } + + cur = g_list_first (task->text_parts); + while (cur) { + part = cur->data; + if (part->is_empty) { + cur = g_list_next (cur); + continue; + } + + c.begin = part->content->data; + c.len = part->content->len; + if (!session->learn_classifier->tokenizer->tokenize_func (session->learn_classifier->tokenizer, session->session_pool, &c, &tokens)) { + i = snprintf (out_buf, sizeof (out_buf), "weights fail, tokenizer error" CRLF); + free_task (task, FALSE); + if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) { + return FALSE; + } + session->state = STATE_REPLY; + return TRUE; + } + cur = g_list_next (cur); + } + /* Handle messages without text */ + if (tokens == NULL) { + i = snprintf (out_buf, sizeof (out_buf), "weights fail, no tokens can be extracted (no text data)" CRLF); + free_task (task, FALSE); + if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) { + return FALSE; + } + session->state = STATE_REPLY; + return TRUE; + } + + + /* Init classifier */ + cls_ctx = session->learn_classifier->classifier->init_func (session->session_pool, 
session->learn_classifier); + + cur = session->learn_classifier->classifier->weights_func (cls_ctx, session->worker->srv->statfile_pool, + tokens, task); + i = 0; + struct classify_weight *w; + + while (cur) { + w = cur->data; + i += snprintf (out_buf + i, sizeof (out_buf) - i, "%s: %.2f" CRLF, w->name, w->weight); + cur = g_list_next (cur); + } + if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) { + return FALSE; + } + free_task (task, FALSE); i = snprintf (out_buf, sizeof (out_buf), "learn ok, sum weight: %.2f" CRLF, sum); if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) { |