struct classifier_ctx *ctx;
stat_file_t *file;
double sum;
+ double multiplier;
int count;
int in_class;
time_t now;
token_node_t *node = key;
struct winnow_callback_data *cd = data;
double v, c;
-
+
c = (cd->in_class) ? WINNOW_PROMOTION : WINNOW_DEMOTION;
+ c *= cd->multiplier;
/* Consider that not found blocks have value 1 */
v = statfile_pool_get_block (cd->pool, cd->file, node->h1, node->h2, cd->now);
}
}
+/*
+ * Compute the winnow weight of the input tokens against every statfile
+ * configured for this classifier.
+ *
+ * ctx   - classifier context; its cfg->statfiles list is iterated
+ * pool  - statfile pool used to open and lock the statfiles
+ * input - tree of tokens, walked with classify_callback for each statfile
+ * task  - task whose memory pool owns the returned data
+ *
+ * Returns a list of struct classify_weight (one per statfile that could be
+ * opened), or NULL if there are none. Entries are prepended, so the list is in
+ * reverse order of ctx->cfg->statfiles. The weight structures are allocated
+ * from task->task_pool and the list itself is released by a destructor
+ * registered on that pool, so the caller must not free the result.
+ */
+GList *
+winnow_weights (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * input, struct worker_task *task)
+{
+	struct winnow_callback_data data;
+	GList *cur, *resl = NULL;
+	struct statfile *st;
+	struct classify_weight *w;
+
+	g_assert (pool != NULL);
+	g_assert (ctx != NULL);
+
+	data.pool = pool;
+	data.now = time (NULL);
+	data.ctx = ctx;
+
+	for (cur = ctx->cfg->statfiles; cur != NULL; cur = g_list_next (cur)) {
+		st = cur->data;
+
+		/*
+		 * Reset the accumulators for each statfile: they are presumably
+		 * updated by classify_callback, and without the reset the weight of
+		 * a statfile would also include the totals of all previous ones.
+		 */
+		data.sum = 0;
+		data.count = 0;
+
+		if ((data.file = statfile_pool_is_open (pool, st->path)) == NULL) {
+			if ((data.file = statfile_pool_open (pool, st->path, st->size, FALSE)) == NULL) {
+				msg_warn ("cannot open %s, skip it", st->path);
+				continue;
+			}
+		}
+
+		/* data.file is always valid here, hold the lock only for the walk */
+		statfile_pool_lock_file (pool, data.file);
+		g_tree_foreach (input, classify_callback, &data);
+		statfile_pool_unlock_file (pool, data.file);
+
+		w = memory_pool_alloc (task->task_pool, sizeof (struct classify_weight));
+		w->name = st->symbol;
+		/* Average over the counted tokens; report 0 when nothing was counted */
+		w->weight = (data.count != 0) ? data.sum / data.count : 0;
+		resl = g_list_prepend (resl, w);
+	}
+
+	if (resl != NULL) {
+		/* Tie the lifetime of the list nodes to the task pool */
+		memory_pool_add_destructor (task->task_pool, (pool_destruct_func)g_list_free, resl);
+	}
+
+	return resl;
+}
+
+
+
+
void
-winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, stat_file_t *file, GTree * input, int in_class, double *sum)
+winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, stat_file_t *file, GTree * input, int in_class, double *sum, double multiplier)
{
struct winnow_callback_data data = {
.file = NULL,
.sum = 0,
.count = 0,
+ .multiplier = multiplier
};
g_assert (pool != NULL);
COMMAND_LEARN,
COMMAND_HELP,
COMMAND_COUNTERS,
- COMMAND_SYNC
+ COMMAND_SYNC,
+ COMMAND_WEIGHTS
};
struct controller_command {
{"shutdown", TRUE, COMMAND_SHUTDOWN},
{"uptime", FALSE, COMMAND_UPTIME},
{"learn", TRUE, COMMAND_LEARN},
+ {"weights", FALSE, COMMAND_WEIGHTS},
{"help", FALSE, COMMAND_HELP},
{"counters", FALSE, COMMAND_COUNTERS},
{"sync", FALSE, COMMAND_SYNC}
memory_pool_stat (&mem_st);
r = snprintf (out_buf, sizeof (out_buf), "Messages scanned: %u" CRLF, session->worker->srv->stat->messages_scanned);
+ r += snprintf (out_buf + r, sizeof (out_buf) - r, "Messages treated as spam: %u" CRLF, session->worker->srv->stat->messages_spam);
+ r += snprintf (out_buf + r, sizeof (out_buf) - r, "Messages treated as ham: %u" CRLF, session->worker->srv->stat->messages_ham);
r += snprintf (out_buf + r, sizeof (out_buf) - r, "Messages learned: %u" CRLF, session->worker->srv->stat->messages_learned);
r += snprintf (out_buf + r, sizeof (out_buf) - r, "Connections count: %u" CRLF, session->worker->srv->stat->connections_count);
r += snprintf (out_buf + r, sizeof (out_buf) - r, "Control connections count: %u" CRLF, session->worker->srv->stat->control_connections_count);
}
size = strtoul (arg, &err_str, 10);
if (err_str && *err_str != '\0') {
- msg_debug ("statfile size is invalid: %s", arg);
+ msg_debug ("message size is invalid: %s", arg);
r = snprintf (out_buf, sizeof (out_buf), "learn size is invalid" CRLF);
rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE);
return;
/* By default learn positive */
session->in_class = 1;
+ session->learn_multiplier = 1.;
/* Get all arguments */
while (*cmd_args++) {
arg = *cmd_args;
case 'n':
session->in_class = 0;
break;
+ case 'm':
+ arg = *(cmd_args + 1);
+ if (!arg || *arg == '\0') {
+ r = snprintf (out_buf, sizeof (out_buf), "recipient is not defined" CRLF);
+ rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE);
+ return;
+ }
+ session->learn_multiplier = strtod (arg, NULL);
+ break;
default:
r = snprintf (out_buf, sizeof (out_buf), "tokenizer is not defined" CRLF);
rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE);
session->state = STATE_LEARN;
}
break;
+
+ case COMMAND_WEIGHTS:
+ arg = *cmd_args;
+ if (!arg || *arg == '\0') {
+ msg_debug ("no statfile specified in weights command");
+ r = snprintf (out_buf, sizeof (out_buf), "weights command requires two arguments: statfile and message size" CRLF);
+ rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE);
+ return;
+ }
+ arg = *(cmd_args + 1);
+ if (arg == NULL || *arg == '\0') {
+ msg_debug ("no message size specified in weights command");
+ r = snprintf (out_buf, sizeof (out_buf), "weights command requires two arguments: statfile and message size" CRLF);
+ rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE);
+ return;
+ }
+ size = strtoul (arg, &err_str, 10);
+ if (err_str && *err_str != '\0') {
+ msg_debug ("message size is invalid: %s", arg);
+ r = snprintf (out_buf, sizeof (out_buf), "message size is invalid" CRLF);
+ rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE);
+ return;
+ }
+
+ cl = g_hash_table_lookup (session->cfg->classifiers_symbols, *cmd_args);
+ if (cl == NULL) {
+ r = snprintf (out_buf, sizeof (out_buf), "statfile %s is not defined" CRLF, *cmd_args);
+ rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE);
+ return;
+
+ }
+ session->learn_classifier = cl;
+
+ rspamd_set_dispatcher_policy (session->dispatcher, BUFFER_CHARACTER, size);
+ session->state = STATE_WEIGHTS;
+ break;
case COMMAND_SYNC:
if (!process_sync_command (session, cmd_args)) {
r = snprintf (out_buf, sizeof (out_buf), "FAIL" CRLF);
r = snprintf (out_buf, sizeof (out_buf),
"Rspamd CLI commands (* - privilleged command):" CRLF
" help - this help message" CRLF
- "(*) learn <statfile> <size> [-r recipient], [-f from] [-n] - learn message to specified statfile" CRLF
+ "(*) learn <statfile> <size> [-r recipient] [-m multiplier] [-f from] [-n] - learn message to specified statfile" CRLF
" quit - quit CLI session" CRLF
"(*) reload - reload rspamd" CRLF
"(*) shutdown - shutdown rspamd" CRLF " stat - show different rspamd stat" CRLF " counters - show rspamd counters" CRLF " uptime - rspamd uptime" CRLF);
if (session->state == STATE_COMMAND) {
session->state = STATE_REPLY;
}
- if (session->state != STATE_LEARN && session->state != STATE_OTHER) {
+ if (session->state != STATE_LEARN && session->state != STATE_WEIGHTS && session->state != STATE_OTHER) {
if (!rspamd_dispatcher_write (session->dispatcher, END, sizeof (END) - 1, FALSE, TRUE)) {
return FALSE;
}
/* XXX: remove this awful legacy */
session->learn_classifier->classifier->learn_func (cls_ctx, session->worker->srv->statfile_pool,
- statfile, tokens, session->in_class, &sum);
+ statfile, tokens, session->in_class, &sum,
+ session->learn_multiplier);
session->worker->srv->stat->messages_learned++;
maybe_write_binlog (session->learn_classifier, st, statfile, tokens);
+ if (st->normalizer != NULL) {
+ sum = st->normalizer (sum, st->normalizer_data);
+ }
+
+ free_task (task, FALSE);
+ i = snprintf (out_buf, sizeof (out_buf), "learn ok, sum weight: %.2f" CRLF, sum);
+ if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) {
+ return FALSE;
+ }
+
+ session->state = STATE_REPLY;
+ break;
+ case STATE_WEIGHTS:
+ session->learn_buf = in;
+ task = construct_task (session->worker);
+
+ task->msg = memory_pool_alloc (task->task_pool, sizeof (f_str_t));
+ task->msg->begin = in->begin;
+ task->msg->len = in->len;
+
+ r = process_message (task);
+ if (r == -1) {
+ msg_warn ("processing of message failed");
+ free_task (task, FALSE);
+ session->state = STATE_REPLY;
+ r = snprintf (out_buf, sizeof (out_buf), "cannot process message" CRLF);
+ rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE);
+ return FALSE;
+ }
+
+ cur = g_list_first (task->text_parts);
+ while (cur) {
+ part = cur->data;
+ if (part->is_empty) {
+ cur = g_list_next (cur);
+ continue;
+ }
+
+ c.begin = part->content->data;
+ c.len = part->content->len;
+ if (!session->learn_classifier->tokenizer->tokenize_func (session->learn_classifier->tokenizer, session->session_pool, &c, &tokens)) {
+ i = snprintf (out_buf, sizeof (out_buf), "weights fail, tokenizer error" CRLF);
+ free_task (task, FALSE);
+ if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) {
+ return FALSE;
+ }
+ session->state = STATE_REPLY;
+ return TRUE;
+ }
+ cur = g_list_next (cur);
+ }
+ /* Handle messages without text */
+ if (tokens == NULL) {
+ i = snprintf (out_buf, sizeof (out_buf), "weights fail, no tokens can be extracted (no text data)" CRLF);
+ free_task (task, FALSE);
+ if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) {
+ return FALSE;
+ }
+ session->state = STATE_REPLY;
+ return TRUE;
+ }
+
+
+ /* Init classifier */
+ cls_ctx = session->learn_classifier->classifier->init_func (session->session_pool, session->learn_classifier);
+
+ cur = session->learn_classifier->classifier->weights_func (cls_ctx, session->worker->srv->statfile_pool,
+ tokens, task);
+ i = 0;
+ struct classify_weight *w;
+
+ while (cur) {
+ w = cur->data;
+ i += snprintf (out_buf + i, sizeof (out_buf) - i, "%s: %.2f" CRLF, w->name, w->weight);
+ cur = g_list_next (cur);
+ }
+ if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) {
+ return FALSE;
+ }
+
free_task (task, FALSE);
i = snprintf (out_buf, sizeof (out_buf), "learn ok, sum weight: %.2f" CRLF, sum);
if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) {