summaryrefslogtreecommitdiffstats
path: root/src/controller.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/controller.c')
-rw-r--r--src/controller.c141
1 files changed, 136 insertions, 5 deletions
diff --git a/src/controller.c b/src/controller.c
index d76c35db3..3e4dc3685 100644
--- a/src/controller.c
+++ b/src/controller.c
@@ -50,7 +50,8 @@ enum command_type {
COMMAND_LEARN,
COMMAND_HELP,
COMMAND_COUNTERS,
- COMMAND_SYNC
+ COMMAND_SYNC,
+ COMMAND_WEIGHTS
};
struct controller_command {
@@ -74,6 +75,7 @@ static struct controller_command commands[] = {
{"shutdown", TRUE, COMMAND_SHUTDOWN},
{"uptime", FALSE, COMMAND_UPTIME},
{"learn", TRUE, COMMAND_LEARN},
+ {"weights", FALSE, COMMAND_WEIGHTS},
{"help", FALSE, COMMAND_HELP},
{"counters", FALSE, COMMAND_COUNTERS},
{"sync", FALSE, COMMAND_SYNC}
@@ -336,6 +338,8 @@ process_stat_command (struct controller_session *session)
memory_pool_stat (&mem_st);
r = snprintf (out_buf, sizeof (out_buf), "Messages scanned: %u" CRLF, session->worker->srv->stat->messages_scanned);
+ r += snprintf (out_buf + r, sizeof (out_buf) - r, "Messages treated as spam: %u" CRLF, session->worker->srv->stat->messages_spam);
+ r += snprintf (out_buf + r, sizeof (out_buf) - r, "Messages treated as ham: %u" CRLF, session->worker->srv->stat->messages_ham);
r += snprintf (out_buf + r, sizeof (out_buf) - r, "Messages learned: %u" CRLF, session->worker->srv->stat->messages_learned);
r += snprintf (out_buf + r, sizeof (out_buf) - r, "Connections count: %u" CRLF, session->worker->srv->stat->connections_count);
r += snprintf (out_buf + r, sizeof (out_buf) - r, "Control connections count: %u" CRLF, session->worker->srv->stat->control_connections_count);
@@ -477,7 +481,7 @@ process_command (struct controller_command *cmd, char **cmd_args, struct control
}
size = strtoul (arg, &err_str, 10);
if (err_str && *err_str != '\0') {
- msg_debug ("statfile size is invalid: %s", arg);
+ msg_debug ("message size is invalid: %s", arg);
r = snprintf (out_buf, sizeof (out_buf), "learn size is invalid" CRLF);
rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE);
return;
@@ -495,6 +499,7 @@ process_command (struct controller_command *cmd, char **cmd_args, struct control
/* By default learn positive */
session->in_class = 1;
+ session->learn_multiplier = 1.;
/* Get all arguments */
while (*cmd_args++) {
arg = *cmd_args;
@@ -521,6 +526,15 @@ process_command (struct controller_command *cmd, char **cmd_args, struct control
case 'n':
session->in_class = 0;
break;
+ case 'm':
+ arg = *(cmd_args + 1);
+ if (!arg || *arg == '\0') {
+ r = snprintf (out_buf, sizeof (out_buf), "recipient is not defined" CRLF);
+ rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE);
+ return;
+ }
+ session->learn_multiplier = strtod (arg, NULL);
+ break;
default:
r = snprintf (out_buf, sizeof (out_buf), "tokenizer is not defined" CRLF);
rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE);
@@ -532,6 +546,42 @@ process_command (struct controller_command *cmd, char **cmd_args, struct control
session->state = STATE_LEARN;
}
break;
+
+ case COMMAND_WEIGHTS:
+ arg = *cmd_args;
+ if (!arg || *arg == '\0') {
+ msg_debug ("no statfile specified in weights command");
+ r = snprintf (out_buf, sizeof (out_buf), "weights command requires two arguments: statfile and message size" CRLF);
+ rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE);
+ return;
+ }
+ arg = *(cmd_args + 1);
+ if (arg == NULL || *arg == '\0') {
+ msg_debug ("no message size specified in weights command");
+ r = snprintf (out_buf, sizeof (out_buf), "weights command requires two arguments: statfile and message size" CRLF);
+ rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE);
+ return;
+ }
+ size = strtoul (arg, &err_str, 10);
+ if (err_str && *err_str != '\0') {
+ msg_debug ("message size is invalid: %s", arg);
+ r = snprintf (out_buf, sizeof (out_buf), "message size is invalid" CRLF);
+ rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE);
+ return;
+ }
+
+ cl = g_hash_table_lookup (session->cfg->classifiers_symbols, *cmd_args);
+ if (cl == NULL) {
+ r = snprintf (out_buf, sizeof (out_buf), "statfile %s is not defined" CRLF, *cmd_args);
+ rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE);
+ return;
+
+ }
+ session->learn_classifier = cl;
+
+ rspamd_set_dispatcher_policy (session->dispatcher, BUFFER_CHARACTER, size);
+ session->state = STATE_WEIGHTS;
+ break;
case COMMAND_SYNC:
if (!process_sync_command (session, cmd_args)) {
r = snprintf (out_buf, sizeof (out_buf), "FAIL" CRLF);
@@ -543,7 +593,7 @@ process_command (struct controller_command *cmd, char **cmd_args, struct control
r = snprintf (out_buf, sizeof (out_buf),
"Rspamd CLI commands (* - privilleged command):" CRLF
" help - this help message" CRLF
- "(*) learn <statfile> <size> [-r recipient], [-f from] [-n] - learn message to specified statfile" CRLF
+ "(*) learn <statfile> <size> [-r recipient] [-m multiplier] [-f from] [-n] - learn message to specified statfile" CRLF
" quit - quit CLI session" CRLF
"(*) reload - reload rspamd" CRLF
"(*) shutdown - shutdown rspamd" CRLF " stat - show different rspamd stat" CRLF " counters - show rspamd counters" CRLF " uptime - rspamd uptime" CRLF);
@@ -627,7 +677,7 @@ controller_read_socket (f_str_t * in, void *arg)
if (session->state == STATE_COMMAND) {
session->state = STATE_REPLY;
}
- if (session->state != STATE_LEARN && session->state != STATE_OTHER) {
+ if (session->state != STATE_LEARN && session->state != STATE_WEIGHTS && session->state != STATE_OTHER) {
if (!rspamd_dispatcher_write (session->dispatcher, END, sizeof (END) - 1, FALSE, TRUE)) {
return FALSE;
}
@@ -714,11 +764,92 @@ controller_read_socket (f_str_t * in, void *arg)
/* XXX: remove this awful legacy */
session->learn_classifier->classifier->learn_func (cls_ctx, session->worker->srv->statfile_pool,
- statfile, tokens, session->in_class, &sum);
+ statfile, tokens, session->in_class, &sum,
+ session->learn_multiplier);
session->worker->srv->stat->messages_learned++;
maybe_write_binlog (session->learn_classifier, st, statfile, tokens);
+ if (st->normalizer != NULL) {
+ sum = st->normalizer (sum, st->normalizer_data);
+ }
+
+ free_task (task, FALSE);
+ i = snprintf (out_buf, sizeof (out_buf), "learn ok, sum weight: %.2f" CRLF, sum);
+ if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) {
+ return FALSE;
+ }
+
+ session->state = STATE_REPLY;
+ break;
+ case STATE_WEIGHTS:
+ session->learn_buf = in;
+ task = construct_task (session->worker);
+
+ task->msg = memory_pool_alloc (task->task_pool, sizeof (f_str_t));
+ task->msg->begin = in->begin;
+ task->msg->len = in->len;
+
+ r = process_message (task);
+ if (r == -1) {
+ msg_warn ("processing of message failed");
+ free_task (task, FALSE);
+ session->state = STATE_REPLY;
+ r = snprintf (out_buf, sizeof (out_buf), "cannot process message" CRLF);
+ rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE);
+ return FALSE;
+ }
+
+ cur = g_list_first (task->text_parts);
+ while (cur) {
+ part = cur->data;
+ if (part->is_empty) {
+ cur = g_list_next (cur);
+ continue;
+ }
+
+ c.begin = part->content->data;
+ c.len = part->content->len;
+ if (!session->learn_classifier->tokenizer->tokenize_func (session->learn_classifier->tokenizer, session->session_pool, &c, &tokens)) {
+ i = snprintf (out_buf, sizeof (out_buf), "weights fail, tokenizer error" CRLF);
+ free_task (task, FALSE);
+ if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) {
+ return FALSE;
+ }
+ session->state = STATE_REPLY;
+ return TRUE;
+ }
+ cur = g_list_next (cur);
+ }
+ /* Handle messages without text */
+ if (tokens == NULL) {
+ i = snprintf (out_buf, sizeof (out_buf), "weights fail, no tokens can be extracted (no text data)" CRLF);
+ free_task (task, FALSE);
+ if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) {
+ return FALSE;
+ }
+ session->state = STATE_REPLY;
+ return TRUE;
+ }
+
+
+ /* Init classifier */
+ cls_ctx = session->learn_classifier->classifier->init_func (session->session_pool, session->learn_classifier);
+
+ cur = session->learn_classifier->classifier->weights_func (cls_ctx, session->worker->srv->statfile_pool,
+ tokens, task);
+ i = 0;
+ struct classify_weight *w;
+
+ while (cur) {
+ w = cur->data;
+ i += snprintf (out_buf + i, sizeof (out_buf) - i, "%s: %.2f" CRLF, w->name, w->weight);
+ cur = g_list_next (cur);
+ }
+ if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) {
+ return FALSE;
+ }
+
free_task (task, FALSE);
i = snprintf (out_buf, sizeof (out_buf), "learn ok, sum weight: %.2f" CRLF, sum);
if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) {