aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/controller.c73
-rw-r--r--src/filter.c109
-rw-r--r--src/filter.h9
-rw-r--r--src/main.h29
-rw-r--r--src/protocol.c62
-rw-r--r--src/protocol.h2
-rw-r--r--src/worker.c57
7 files changed, 235 insertions, 106 deletions
diff --git a/src/controller.c b/src/controller.c
index 380c47791..9504d3b1f 100644
--- a/src/controller.c
+++ b/src/controller.c
@@ -723,8 +723,6 @@ controller_read_socket (f_str_t * in, void *arg)
{
struct controller_session *session = (struct controller_session *)arg;
struct classifier_ctx *cls_ctx;
- stat_file_t *statfile;
- struct statfile *st;
gint len, i, r;
gchar *s, **params, *cmd, out_buf[128];
struct worker_task *task;
@@ -733,7 +731,6 @@ controller_read_socket (f_str_t * in, void *arg)
GTree *tokens = NULL;
GError *err = NULL;
f_str_t c;
- double sum;
switch (session->state) {
case STATE_COMMAND:
@@ -799,74 +796,14 @@ controller_read_socket (f_str_t * in, void *arg)
}
return FALSE;
}
- if ((s = g_hash_table_lookup (session->learn_classifier->opts, "header")) != NULL) {
- cur = message_get_header (task->task_pool, task->message, s, FALSE);
- if (cur) {
- memory_pool_add_destructor (task->task_pool, (pool_destruct_func)g_list_free, cur);
- }
- }
- else {
- cur = g_list_first (task->text_parts);
- }
- while (cur) {
- if (s != NULL) {
- c.len = strlen (cur->data);
- c.begin = cur->data;
- }
- else {
- part = cur->data;
- if (part->is_empty) {
- cur = g_list_next (cur);
- continue;
- }
- c.begin = part->content->data;
- c.len = part->content->len;
- }
- if (!session->learn_classifier->tokenizer->tokenize_func (session->learn_classifier->tokenizer, session->session_pool, &c, &tokens)) {
- i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn failed, tokenizer error" CRLF);
- free_task (task, FALSE);
- if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) {
- return FALSE;
- }
- session->state = STATE_REPLY;
- return TRUE;
- }
- cur = g_list_next (cur);
- }
-
- /* Handle messages without text */
- if (tokens == NULL) {
- i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn failed, no tokens can be extracted (no text data)" CRLF END);
- msg_info ("learn failed for message <%s>, no tokens to extract", task->message_id);
- free_task (task, FALSE);
- if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) {
- return FALSE;
- }
- session->state = STATE_REPLY;
- return TRUE;
- }
- /* Take care of subject */
- tokenize_subject (task, &tokens);
-
- /* Init classifier */
- cls_ctx = session->learn_classifier->classifier->init_func (session->session_pool, session->learn_classifier);
- /* Get or create statfile */
- statfile = get_statfile_by_symbol (session->worker->srv->statfile_pool, session->learn_classifier,
- session->learn_symbol, &st, TRUE);
-
- if (statfile == NULL ||
- ! session->learn_classifier->classifier->learn_func (cls_ctx, session->worker->srv->statfile_pool,
- session->learn_symbol, tokens, session->in_class, &sum,
- session->learn_multiplier, &err)) {
+ if (!learn_task (session->learn_symbol, task, &err)) {
if (err) {
i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn failed, learn classifier error: %s" CRLF END, err->message);
- msg_info ("learn failed for message <%s>, learn error: %s", task->message_id, err->message);
g_error_free (err);
}
else {
i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn failed, unknown learn classifier error" CRLF END);
- msg_info ("learn failed for message <%s>, unknown learn error", task->message_id);
}
free_task (task, FALSE);
if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) {
@@ -875,18 +812,12 @@ controller_read_socket (f_str_t * in, void *arg)
session->state = STATE_REPLY;
return TRUE;
}
- session->worker->srv->stat->messages_learned++;
- maybe_write_binlog (session->learn_classifier, st, statfile, tokens);
- msg_info ("learn success for message <%s>, for statfile: %s, sum weight: %.2f",
- task->message_id, session->learn_symbol, sum);
- statfile_pool_plan_invalidate (session->worker->srv->statfile_pool, DEFAULT_STATFILE_INVALIDATE_TIME, DEFAULT_STATFILE_INVALIDATE_JITTER);
free_task (task, FALSE);
- i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn ok, sum weight: %.2f" CRLF END, sum);
+ i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn ok" CRLF END);
if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) {
return FALSE;
}
-
session->state = STATE_REPLY;
break;
case STATE_WEIGHTS:
diff --git a/src/filter.c b/src/filter.c
index ec7b5a5ed..df8e1a9e0 100644
--- a/src/filter.c
+++ b/src/filter.c
@@ -43,6 +43,12 @@
# include "lua/lua_common.h"
#endif
+static inline GQuark
+filter_error_quark (void)
+{
+ return g_quark_from_static_string ("g-filter-error-quark");
+}
+
static void
insert_metric_result (struct worker_task *task, struct metric *metric, const gchar *symbol,
double flag, GList * opts, gboolean single)
@@ -799,6 +805,109 @@ check_metric_action (double score, double required_score, struct metric *metric)
}
}
+gboolean
+learn_task (const gchar *statfile, struct worker_task *task, GError **err)
+{
+ GList *cur;
+ struct classifier_config *cl;
+ struct classifier_ctx *cls_ctx;
+ gchar *s;
+ f_str_t c;
+ GTree *tokens = NULL;
+ struct statfile *st;
+ stat_file_t *stf;
+ gdouble sum;
+ struct mime_text_part *part;
+
+ /* Load classifier by symbol */
+ cl = g_hash_table_lookup (task->cfg->classifiers_symbols, statfile);
+ if (cl == NULL) {
+ g_set_error (err, filter_error_quark(), 1, "Statfile %s is not configured in any classifier", statfile);
+ return FALSE;
+ }
+
+ /* If classifier has 'header' option just classify header of this type */
+ if ((s = g_hash_table_lookup (cl->opts, "header")) != NULL) {
+ cur = message_get_header (task->task_pool, task->message, s, FALSE);
+ if (cur) {
+ memory_pool_add_destructor (task->task_pool, (pool_destruct_func)g_list_free, cur);
+ }
+ }
+ else {
+ /* Classify message otherwise */
+ cur = g_list_first (task->text_parts);
+ }
+
+ /* Get tokens from each element */
+ while (cur) {
+ if (s != NULL) {
+ c.len = strlen (cur->data);
+ c.begin = cur->data;
+ }
+ else {
+ part = cur->data;
+ /* Skip empty parts */
+ if (part->is_empty) {
+ cur = g_list_next (cur);
+ continue;
+ }
+ c.begin = part->content->data;
+ c.len = part->content->len;
+ }
+ /* Get tokens */
+ if (!cl->tokenizer->tokenize_func (
+ cl->tokenizer, task->task_pool,
+ &c, &tokens)) {
+ g_set_error (err, filter_error_quark(), 2, "Cannot tokenize message");
+ return FALSE;
+ }
+ cur = g_list_next (cur);
+ }
+
+ /* Handle messages without text */
+ if (tokens == NULL) {
+ g_set_error (err, filter_error_quark(), 3, "Cannot tokenize message, no text data");
+ msg_info ("learn failed for message <%s>, no tokens to extract", task->message_id);
+ return FALSE;
+ }
+
+ /* Take care of subject */
+ tokenize_subject (task, &tokens);
+
+ /* Init classifier */
+ cls_ctx = cl->classifier->init_func (
+ task->task_pool, cl);
+ /* Get or create statfile */
+ stf = get_statfile_by_symbol (task->worker->srv->statfile_pool,
+ cl, statfile, &st, TRUE);
+
+ /* Learn */
+ if (stf== NULL || !cl->classifier->learn_func (
+ cls_ctx, task->worker->srv->statfile_pool,
+ statfile, tokens, TRUE, &sum,
+ 1.0, err)) {
+ if (*err) {
+ msg_info ("learn failed for message <%s>, learn error: %s", task->message_id, (*err)->message);
+ return FALSE;
+ }
+ else {
+ g_set_error (err, filter_error_quark(), 4, "Learn failed, unknown learn classifier error");
+ msg_info ("learn failed for message <%s>, unknown learn error", task->message_id);
+ return FALSE;
+ }
+ }
+ /* Increase statistics */
+ task->worker->srv->stat->messages_learned++;
+
+ maybe_write_binlog (cl, st, stf, tokens);
+ msg_info ("learn success for message <%s>, for statfile: %s, sum weight: %.2f",
+ task->message_id, statfile, sum);
+ statfile_pool_plan_invalidate (task->worker->srv->statfile_pool,
+ DEFAULT_STATFILE_INVALIDATE_TIME,
+ DEFAULT_STATFILE_INVALIDATE_JITTER);
+
+ return TRUE;
+}
/*
* vi:ts=4
diff --git a/src/filter.h b/src/filter.h
index cea49893b..2c3dde4fc 100644
--- a/src/filter.h
+++ b/src/filter.h
@@ -123,6 +123,15 @@ void make_composites (struct worker_task *task);
*/
double factor_consolidation_func (struct worker_task *task, const gchar *metric_name, const gchar *unused);
+/*
+ * Learn specified statfile with message in a task
+ * @param statfile symbol of statfile
+ * @param task worker's task object
+ * @param err pointer to GError
+ * @return true if learn succeed
+ */
+gboolean learn_task (const gchar *statfile, struct worker_task *task, GError **err);
+
gboolean check_action_str (const gchar *data, gint *result);
const gchar *str_action_metric (enum rspamd_metric_action action);
gint check_metric_action (double score, double required_score, struct metric *metric);
diff --git a/src/main.h b/src/main.h
index 78b3af14b..581883a6e 100644
--- a/src/main.h
+++ b/src/main.h
@@ -95,7 +95,7 @@ struct rspamd_main {
memory_pool_t *server_pool; /**< server's memory pool */
statfile_pool_t *statfile_pool; /**< shared statfiles pool */
- GHashTable *workers; /**< workers pool indexed by pid */
+ GHashTable *workers; /**< workers pool indexed by pid */
};
struct counter_data {
@@ -117,9 +117,9 @@ struct save_point {
* Union that would be used for storing sockaddrs
*/
union sa_union {
- struct sockaddr_storage ss;
- struct sockaddr_in s4;
- struct sockaddr_in6 s6;
+ struct sockaddr_storage ss;
+ struct sockaddr_in s4;
+ struct sockaddr_in6 s6;
};
/**
@@ -151,9 +151,9 @@ struct controller_session {
GList *parts; /**< extracted mime parts */
gint in_class; /**< positive or negative learn */
void (*other_handler)(struct controller_session *session,
- f_str_t *in); /**< other command handler to execute at the end of processing */
+ f_str_t *in); /**< other command handler to execute at the end of processing */
void *other_data; /**< and its data */
- struct rspamd_async_session* s; /**< async session object */
+ struct rspamd_async_session* s; /**< async session object */
};
typedef void (*controller_func_t)(gchar **args, struct controller_session *session);
@@ -178,10 +178,12 @@ struct worker_task {
enum rspamd_command cmd; /**< command */
struct custom_command *custom_cmd; /**< custom command if any */
gint sock; /**< socket descriptor */
- gboolean is_mime; /**< if this task is mime task */
- gboolean is_json; /**< output is JSON */
- gboolean is_http; /**< output is HTTP */
- gboolean is_skipped; /**< whether message was skipped by configuration */
+ gboolean is_mime; /**< if this task is mime task */
+ gboolean is_json; /**< output is JSON */
+ gboolean is_http; /**< output is HTTP */
+ gboolean allow_learn; /**< allow learning */
+ gboolean is_skipped; /**< whether message was skipped by configuration */
+
gchar *helo; /**< helo header value */
gchar *from; /**< from header value */
gchar *queue_id; /**< queue id if specified */
@@ -193,9 +195,10 @@ struct worker_task {
gchar *deliver_to; /**< address to deliver */
gchar *user; /**< user to deliver */
gchar *subject; /**< subject (for non-mime) */
+ gchar *statfile; /**< statfile for learning */
f_str_t *msg; /**< message buffer */
rspamd_io_dispatcher_t *dispatcher; /**< IO dispatcher object */
- struct rspamd_async_session* s; /**< async session object */
+ struct rspamd_async_session* s; /**< async session object */
gint parts_count; /**< mime parts count */
GMimeMessage *message; /**< message, parsed with GMime */
GMimeObject *parser_parent_part; /**< current parent part */
@@ -209,9 +212,9 @@ struct worker_task {
GList *images; /**< list of images */
GList *raw_headers_list; /**< list of raw headers */
GHashTable *results; /**< hash table of metric_result indexed by
- * metric's name */
+ * metric's name */
GHashTable *tokens; /**< hash table of tokens indexed by tokenizer
- * pointer */
+ * pointer */
GList *messages; /**< list of messages that would be reported */
GHashTable *re_cache; /**< cache for matched or not matched regexps */
struct config_file *cfg; /**< pointer to config object */
diff --git a/src/protocol.c b/src/protocol.c
index 8ffaddea1..ac8515004 100644
--- a/src/protocol.c
+++ b/src/protocol.c
@@ -64,6 +64,11 @@
#define MSG_CMD_PROCESS "process"
/*
+ * Learn specified statfile using message
+ */
+#define MSG_CMD_LEARN "learn"
+
+/*
* spamassassin greeting:
*/
#define SPAMC_GREETING "SPAMC"
@@ -81,6 +86,7 @@
#define NRCPT_HEADER "Recipient-Number"
#define RCPT_HEADER "Rcpt"
#define SUBJECT_HEADER "Subject"
+#define STATFILE_HEADER "Statfile"
#define QUEUE_ID_HEADER "Queue-ID"
#define ERROR_HEADER "Error"
#define USER_HEADER "User"
@@ -198,6 +204,22 @@ parse_check_command (struct worker_task *task, gchar *token)
return FALSE;
}
break;
+ case 'l':
+ case 'L':
+ if (g_ascii_strcasecmp (token + 1, MSG_CMD_LEARN + 1) == 0) {
+ if (task->allow_learn) {
+ task->cmd = CMD_LEARN;
+ }
+ else {
+ msg_info ("learning is disabled");
+ return FALSE;
+ }
+ }
+ else {
+ debug_task ("bad command: %s", token);
+ return FALSE;
+ }
+ break;
default:
cur = custom_commands;
while (cur) {
@@ -306,8 +328,8 @@ parse_http_command (struct worker_task *task, f_str_t * line)
}
else {
/* Copy command */
- cmd = memory_pool_alloc (task->task_pool, p - c);
- rspamd_strlcpy (cmd, c, p - c);
+ cmd = memory_pool_alloc (task->task_pool, p - c + 1);
+ rspamd_strlcpy (cmd, c, p - c + 1);
/* Skip the first '/' */
if (*cmd == '/') {
cmd ++;
@@ -379,8 +401,22 @@ parse_header (struct worker_task *task, f_str_t * line)
}
else {
if (task->content_length > 0) {
- rspamd_set_dispatcher_policy (task->dispatcher, BUFFER_CHARACTER, task->content_length);
- task->state = READ_MESSAGE;
+ if (task->cmd == CMD_LEARN) {
+ if (task->statfile != NULL) {
+ rspamd_set_dispatcher_policy (task->dispatcher, BUFFER_CHARACTER, task->content_length);
+ task->state = READ_MESSAGE;
+ }
+ else {
+ task->last_error = "Unknown statfile";
+ task->error_code = RSPAMD_STATFILE_ERROR;
+ task->state = WRITE_ERROR;
+ return FALSE;
+ }
+ }
+ else {
+ rspamd_set_dispatcher_policy (task->dispatcher, BUFFER_CHARACTER, task->content_length);
+ task->state = READ_MESSAGE;
+ }
}
else {
task->last_error = "Unknown content length";
@@ -528,6 +564,9 @@ parse_header (struct worker_task *task, f_str_t * line)
if (g_ascii_strncasecmp (headern, SUBJECT_HEADER, sizeof (SUBJECT_HEADER) - 1) == 0) {
task->subject = memory_pool_fstrdup (task->task_pool, line);
}
+ else if (g_ascii_strncasecmp (headern, STATFILE_HEADER, sizeof (STATFILE_HEADER) - 1) == 0) {
+ task->statfile = memory_pool_fstrdup (task->task_pool, line);
+ }
else {
return FALSE;
}
@@ -1433,7 +1472,7 @@ write_reply (struct worker_task *task)
/* Write error message and error code to reply */
if (task->is_http) {
r = rspamd_snprintf (outbuf, sizeof (outbuf), "HTTP/1.0 400 Bad request" CRLF
- "Connection: close" CRLF CRLF);
+ "Connection: close" CRLF CRLF "Error: %d - %s" CRLF, task->error_code, task->last_error);
}
else {
if (task->proto == SPAMC_PROTO) {
@@ -1471,6 +1510,19 @@ write_reply (struct worker_task *task)
(task->proto == SPAMC_PROTO) ? SPAMD_REPLY_BANNER : RSPAMD_REPLY_BANNER, rspamc_proto_str (task->proto_ver));
return rspamd_dispatcher_write (task->dispatcher, outbuf, r, FALSE, FALSE);
break;
+ case CMD_LEARN:
+ if (task->is_http) {
+ r = rspamd_snprintf (outbuf, sizeof (outbuf), "HTTP/1.0 200 Ok" CRLF
+ "Connection: close" CRLF CRLF "%s" CRLF, task->last_error);
+ }
+ else {
+ r = rspamd_snprintf (outbuf, sizeof (outbuf), "%s/%s 0 LEARN" CRLF CRLF "%s" CRLF,
+ (task->proto == SPAMC_PROTO) ? SPAMD_REPLY_BANNER : RSPAMD_REPLY_BANNER,
+ rspamc_proto_str (task->proto_ver),
+ task->last_error);
+ }
+ return rspamd_dispatcher_write (task->dispatcher, outbuf, r, FALSE, FALSE);
+ break;
case CMD_OTHER:
return task->custom_cmd->func (task);
}
diff --git a/src/protocol.h b/src/protocol.h
index a15530a7b..de6d0ea03 100644
--- a/src/protocol.h
+++ b/src/protocol.h
@@ -13,6 +13,7 @@
#define RSPAMD_NETWORK_ERROR 2
#define RSPAMD_PROTOCOL_ERROR 3
#define RSPAMD_LENGTH_ERROR 4
+#define RSPAMD_STATFILE_ERROR 5
#define RSPAMC_PROTO_1_0 "1.0"
#define RSPAMC_PROTO_1_1 "1.1"
@@ -44,6 +45,7 @@ enum rspamd_command {
CMD_SKIP,
CMD_PING,
CMD_PROCESS,
+ CMD_LEARN,
CMD_OTHER,
};
diff --git a/src/worker.c b/src/worker.c
index 57ad4ecf1..a9b05d64e 100644
--- a/src/worker.c
+++ b/src/worker.c
@@ -89,6 +89,8 @@ struct rspamd_worker_ctx {
gboolean is_http;
/* JSON output */
gboolean is_json;
+ /* Allow learning throught worker */
+ gboolean allow_learn;
GList *custom_filters;
/* DNS resolver */
struct rspamd_dns_resolver *resolver;
@@ -318,6 +320,7 @@ read_socket (f_str_t * in, void *arg)
struct worker_task *task = (struct worker_task *) arg;
struct rspamd_worker_ctx *ctx;
ssize_t r;
+ GError *err = NULL;
ctx = task->worker->ctx;
switch (task->state) {
@@ -332,8 +335,10 @@ read_socket (f_str_t * in, void *arg)
}
else {
if (!read_rspamd_input_line (task, in)) {
- task->last_error = "Read error";
- task->error_code = RSPAMD_NETWORK_ERROR;
+ if (!task->last_error) {
+ task->last_error = "Read error";
+ task->error_code = RSPAMD_NETWORK_ERROR;
+ }
task->state = WRITE_ERROR;
}
}
@@ -359,22 +364,38 @@ read_socket (f_str_t * in, void *arg)
task->state = WRITE_REPLY;
return write_socket (task);
}
- r = process_filters (task);
- if (r == -1) {
- task->last_error = "Filter processing error";
- task->error_code = RSPAMD_FILTER_ERROR;
- task->state = WRITE_ERROR;
+ else if (task->cmd == CMD_LEARN) {
+ if (!learn_task (task->statfile, task, &err)) {
+ task->last_error = memory_pool_strdup (task->task_pool, err->message);
+ task->error_code = err->code;
+ g_error_free (err);
+ task->state = WRITE_ERROR;
+ }
+ else {
+ task->last_error = "learn ok";
+ task->error_code = 0;
+ task->state = WRITE_REPLY;
+ }
return write_socket (task);
}
- else if (r == 0) {
- task->state = WAIT_FILTER;
- rspamd_dispatcher_pause (task->dispatcher);
- }
else {
- process_statfiles (task);
- lua_call_post_filters (task);
- task->state = WRITE_REPLY;
- return write_socket (task);
+ r = process_filters (task);
+ if (r == -1) {
+ task->last_error = "Filter processing error";
+ task->error_code = RSPAMD_FILTER_ERROR;
+ task->state = WRITE_ERROR;
+ return write_socket (task);
+ }
+ else if (r == 0) {
+ task->state = WAIT_FILTER;
+ rspamd_dispatcher_pause (task->dispatcher);
+ }
+ else {
+ process_statfiles (task);
+ lua_call_post_filters (task);
+ task->state = WRITE_REPLY;
+ return write_socket (task);
+ }
}
break;
case WRITE_REPLY:
@@ -515,9 +536,8 @@ construct_task (struct rspamd_worker *worker)
{
struct worker_task *new_task;
- new_task = g_malloc (sizeof (struct worker_task));
+ new_task = g_malloc0 (sizeof (struct worker_task));
- bzero (new_task, sizeof (struct worker_task));
new_task->worker = worker;
new_task->state = READ_COMMAND;
new_task->cfg = worker->srv->cfg;
@@ -605,10 +625,12 @@ accept_socket (gint fd, short what, void *arg)
sizeof (struct in_addr));
}
+ /* Copy some variables */
new_task->sock = nfd;
new_task->is_mime = ctx->is_mime;
new_task->is_json = ctx->is_json;
new_task->is_http = ctx->is_http;
+ new_task->allow_learn = ctx->allow_learn;
worker->srv->stat->connections_count++;
new_task->resolver = ctx->resolver;
@@ -750,6 +772,7 @@ init_worker (void)
register_worker_opt (TYPE_WORKER, "mime", xml_handle_boolean, ctx, G_STRUCT_OFFSET (struct rspamd_worker_ctx, is_mime));
register_worker_opt (TYPE_WORKER, "http", xml_handle_boolean, ctx, G_STRUCT_OFFSET (struct rspamd_worker_ctx, is_http));
register_worker_opt (TYPE_WORKER, "json", xml_handle_boolean, ctx, G_STRUCT_OFFSET (struct rspamd_worker_ctx, is_json));
+ register_worker_opt (TYPE_WORKER, "allow_learn", xml_handle_boolean, ctx, G_STRUCT_OFFSET (struct rspamd_worker_ctx, allow_learn));
register_worker_opt (TYPE_WORKER, "timeout", xml_handle_seconds, ctx, G_STRUCT_OFFSET (struct rspamd_worker_ctx, timeout));
return ctx;