diff options
-rw-r--r-- | src/controller.c | 73 | ||||
-rw-r--r-- | src/filter.c | 109 | ||||
-rw-r--r-- | src/filter.h | 9 | ||||
-rw-r--r-- | src/main.h | 29 | ||||
-rw-r--r-- | src/protocol.c | 62 | ||||
-rw-r--r-- | src/protocol.h | 2 | ||||
-rw-r--r-- | src/worker.c | 57 |
7 files changed, 235 insertions, 106 deletions
diff --git a/src/controller.c b/src/controller.c index 380c47791..9504d3b1f 100644 --- a/src/controller.c +++ b/src/controller.c @@ -723,8 +723,6 @@ controller_read_socket (f_str_t * in, void *arg) { struct controller_session *session = (struct controller_session *)arg; struct classifier_ctx *cls_ctx; - stat_file_t *statfile; - struct statfile *st; gint len, i, r; gchar *s, **params, *cmd, out_buf[128]; struct worker_task *task; @@ -733,7 +731,6 @@ controller_read_socket (f_str_t * in, void *arg) GTree *tokens = NULL; GError *err = NULL; f_str_t c; - double sum; switch (session->state) { case STATE_COMMAND: @@ -799,74 +796,14 @@ controller_read_socket (f_str_t * in, void *arg) } return FALSE; } - if ((s = g_hash_table_lookup (session->learn_classifier->opts, "header")) != NULL) { - cur = message_get_header (task->task_pool, task->message, s, FALSE); - if (cur) { - memory_pool_add_destructor (task->task_pool, (pool_destruct_func)g_list_free, cur); - } - } - else { - cur = g_list_first (task->text_parts); - } - while (cur) { - if (s != NULL) { - c.len = strlen (cur->data); - c.begin = cur->data; - } - else { - part = cur->data; - if (part->is_empty) { - cur = g_list_next (cur); - continue; - } - c.begin = part->content->data; - c.len = part->content->len; - } - if (!session->learn_classifier->tokenizer->tokenize_func (session->learn_classifier->tokenizer, session->session_pool, &c, &tokens)) { - i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn failed, tokenizer error" CRLF); - free_task (task, FALSE); - if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) { - return FALSE; - } - session->state = STATE_REPLY; - return TRUE; - } - cur = g_list_next (cur); - } - - /* Handle messages without text */ - if (tokens == NULL) { - i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn failed, no tokens can be extracted (no text data)" CRLF END); - msg_info ("learn failed for message <%s>, no tokens to extract", task->message_id); - free_task (task, FALSE); - if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) { - return FALSE; - } - session->state = STATE_REPLY; - return TRUE; - } - /* Take care of subject */ - tokenize_subject (task, &tokens); - - /* Init classifier */ - cls_ctx = session->learn_classifier->classifier->init_func (session->session_pool, session->learn_classifier); - /* Get or create statfile */ - statfile = get_statfile_by_symbol (session->worker->srv->statfile_pool, session->learn_classifier, - session->learn_symbol, &st, TRUE); - - if (statfile == NULL || - ! session->learn_classifier->classifier->learn_func (cls_ctx, session->worker->srv->statfile_pool, - session->learn_symbol, tokens, session->in_class, &sum, - session->learn_multiplier, &err)) { + if (!learn_task (session->learn_symbol, task, &err)) { if (err) { i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn failed, learn classifier error: %s" CRLF END, err->message); - msg_info ("learn failed for message <%s>, learn error: %s", task->message_id, err->message); g_error_free (err); } else { i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn failed, unknown learn classifier error" CRLF END); - msg_info ("learn failed for message <%s>, unknown learn error", task->message_id); } free_task (task, FALSE); if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) { @@ -875,18 +812,12 @@ controller_read_socket (f_str_t * in, void *arg) session->state = STATE_REPLY; return TRUE; } - session->worker->srv->stat->messages_learned++; - maybe_write_binlog (session->learn_classifier, st, statfile, tokens); - msg_info ("learn success for message <%s>, for statfile: %s, sum weight: %.2f", - task->message_id, session->learn_symbol, sum); - statfile_pool_plan_invalidate (session->worker->srv->statfile_pool, DEFAULT_STATFILE_INVALIDATE_TIME, DEFAULT_STATFILE_INVALIDATE_JITTER); free_task (task, FALSE); - i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn ok, sum weight: %.2f" CRLF END, sum); + i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn ok" CRLF END); if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) { return FALSE; } - session->state = STATE_REPLY; break; case STATE_WEIGHTS: diff --git a/src/filter.c b/src/filter.c index ec7b5a5ed..df8e1a9e0 100644 --- a/src/filter.c +++ b/src/filter.c @@ -43,6 +43,12 @@ # include "lua/lua_common.h" #endif +static inline GQuark +filter_error_quark (void) +{ + return g_quark_from_static_string ("g-filter-error-quark"); +} + static void insert_metric_result (struct worker_task *task, struct metric *metric, const gchar *symbol, double flag, GList * opts, gboolean single) @@ -799,6 +805,109 @@ check_metric_action (double score, double required_score, struct metric *metric) } } +gboolean +learn_task (const gchar *statfile, struct worker_task *task, GError **err) +{ + GList *cur; + struct classifier_config *cl; + struct classifier_ctx *cls_ctx; + gchar *s; + f_str_t c; + GTree *tokens = NULL; + struct statfile *st; + stat_file_t *stf; + gdouble sum; + struct mime_text_part *part; + + /* Load classifier by symbol */ + cl = g_hash_table_lookup (task->cfg->classifiers_symbols, statfile); + if (cl == NULL) { + g_set_error (err, filter_error_quark(), 1, "Statfile %s is not configured in any classifier", statfile); + return FALSE; + } + + /* If classifier has 'header' option just classify header of this type */ + if ((s = g_hash_table_lookup (cl->opts, "header")) != NULL) { + cur = message_get_header (task->task_pool, task->message, s, FALSE); + if (cur) { + memory_pool_add_destructor (task->task_pool, (pool_destruct_func)g_list_free, cur); + } + } + else { + /* Classify message otherwise */ + cur = g_list_first (task->text_parts); + } + + /* Get tokens from each element */ + while (cur) { + if (s != NULL) { + c.len = strlen (cur->data); + c.begin = cur->data; + } + else { + part = cur->data; + /* Skip empty parts */ + if (part->is_empty) { + cur = g_list_next (cur); + continue; + } + c.begin = part->content->data; + c.len = part->content->len; + } + /* Get tokens */ + if (!cl->tokenizer->tokenize_func ( + cl->tokenizer, task->task_pool, + &c, &tokens)) { + g_set_error (err, filter_error_quark(), 2, "Cannot tokenize message"); + return FALSE; + } + cur = g_list_next (cur); + } + + /* Handle messages without text */ + if (tokens == NULL) { + g_set_error (err, filter_error_quark(), 3, "Cannot tokenize message, no text data"); + msg_info ("learn failed for message <%s>, no tokens to extract", task->message_id); + return FALSE; + } + + /* Take care of subject */ + tokenize_subject (task, &tokens); + + /* Init classifier */ + cls_ctx = cl->classifier->init_func ( + task->task_pool, cl); + /* Get or create statfile */ + stf = get_statfile_by_symbol (task->worker->srv->statfile_pool, + cl, statfile, &st, TRUE); + + /* Learn */ + if (stf== NULL || !cl->classifier->learn_func ( + cls_ctx, task->worker->srv->statfile_pool, + statfile, tokens, TRUE, &sum, + 1.0, err)) { + if (*err) { + msg_info ("learn failed for message <%s>, learn error: %s", task->message_id, (*err)->message); + return FALSE; + } + else { + g_set_error (err, filter_error_quark(), 4, "Learn failed, unknown learn classifier error"); + msg_info ("learn failed for message <%s>, unknown learn error", task->message_id); + return FALSE; + } + } + /* Increase statistics */ + task->worker->srv->stat->messages_learned++; + + maybe_write_binlog (cl, st, stf, tokens); + msg_info ("learn success for message <%s>, for statfile: %s, sum weight: %.2f", + task->message_id, statfile, sum); + statfile_pool_plan_invalidate (task->worker->srv->statfile_pool, + DEFAULT_STATFILE_INVALIDATE_TIME, + DEFAULT_STATFILE_INVALIDATE_JITTER); + + return TRUE; +} /* * vi:ts=4 diff --git a/src/filter.h b/src/filter.h index cea49893b..2c3dde4fc 100644 --- a/src/filter.h +++ b/src/filter.h @@ -123,6 +123,15 @@ void make_composites (struct worker_task *task); */ double factor_consolidation_func (struct worker_task *task, const gchar *metric_name, const gchar *unused); +/* + * Learn specified statfile with message in a task + * @param statfile symbol of statfile + * @param task worker's task object + * @param err pointer to GError + * @return true if learn succeed + */ +gboolean learn_task (const gchar *statfile, struct worker_task *task, GError **err); + gboolean check_action_str (const gchar *data, gint *result); const gchar *str_action_metric (enum rspamd_metric_action action); gint check_metric_action (double score, double required_score, struct metric *metric); diff --git a/src/main.h b/src/main.h index 78b3af14b..581883a6e 100644 --- a/src/main.h +++ b/src/main.h @@ -95,7 +95,7 @@ struct rspamd_main { memory_pool_t *server_pool; /**< server's memory pool */ statfile_pool_t *statfile_pool; /**< shared statfiles pool */ - GHashTable *workers; /**< workers pool indexed by pid */ + GHashTable *workers; /**< workers pool indexed by pid */ }; struct counter_data { @@ -117,9 +117,9 @@ struct save_point { * Union that would be used for storing sockaddrs */ union sa_union { - struct sockaddr_storage ss; - struct sockaddr_in s4; - struct sockaddr_in6 s6; + struct sockaddr_storage ss; + struct sockaddr_in s4; + struct sockaddr_in6 s6; }; /** @@ -151,9 +151,9 @@ struct controller_session { GList *parts; /**< extracted mime parts */ gint in_class; /**< positive or negative learn */ void (*other_handler)(struct controller_session *session, - f_str_t *in); /**< other command handler to execute at the end of processing */ + f_str_t *in); /**< other command handler to execute at the end of processing */ void *other_data; /**< and its data */ - struct rspamd_async_session* s; /**< async session object */ + struct rspamd_async_session* s; /**< async session object */ }; typedef void (*controller_func_t)(gchar **args, struct controller_session *session); @@ -178,10 +178,12 @@ struct worker_task { enum rspamd_command cmd; /**< command */ struct custom_command *custom_cmd; /**< custom command if any */ gint sock; /**< socket descriptor */ - gboolean is_mime; /**< if this task is mime task */ - gboolean is_json; /**< output is JSON */ - gboolean is_http; /**< output is HTTP */ - gboolean is_skipped; /**< whether message was skipped by configuration */ + gboolean is_mime; /**< if this task is mime task */ + gboolean is_json; /**< output is JSON */ + gboolean is_http; /**< output is HTTP */ + gboolean allow_learn; /**< allow learning */ + gboolean is_skipped; /**< whether message was skipped by configuration */ + gchar *helo; /**< helo header value */ gchar *from; /**< from header value */ gchar *queue_id; /**< queue id if specified */ @@ -193,9 +195,10 @@ struct worker_task { gchar *deliver_to; /**< address to deliver */ gchar *user; /**< user to deliver */ gchar *subject; /**< subject (for non-mime) */ + gchar *statfile; /**< statfile for learning */ f_str_t *msg; /**< message buffer */ rspamd_io_dispatcher_t *dispatcher; /**< IO dispatcher object */ - struct rspamd_async_session* s; /**< async session object */ + struct rspamd_async_session* s; /**< async session object */ gint parts_count; /**< mime parts count */ GMimeMessage *message; /**< message, parsed with GMime */ GMimeObject *parser_parent_part; /**< current parent part */ @@ -209,9 +212,9 @@ struct worker_task { GList *images; /**< list of images */ GList *raw_headers_list; /**< list of raw headers */ GHashTable *results; /**< hash table of metric_result indexed by - * metric's name */ + * metric's name */ GHashTable *tokens; /**< hash table of tokens indexed by tokenizer - * pointer */ + * pointer */ GList *messages; /**< list of messages that would be reported */ GHashTable *re_cache; /**< cache for matched or not matched regexps */ struct config_file *cfg; /**< pointer to config object */ diff --git a/src/protocol.c b/src/protocol.c index 8ffaddea1..ac8515004 100644 --- a/src/protocol.c +++ b/src/protocol.c @@ -64,6 +64,11 @@ #define MSG_CMD_PROCESS "process" /* + * Learn specified statfile using message + */ +#define MSG_CMD_LEARN "learn" + +/* * spamassassin greeting: */ #define SPAMC_GREETING "SPAMC" @@ -81,6 +86,7 @@ #define NRCPT_HEADER "Recipient-Number" #define RCPT_HEADER "Rcpt" #define SUBJECT_HEADER "Subject" +#define STATFILE_HEADER "Statfile" #define QUEUE_ID_HEADER "Queue-ID" #define ERROR_HEADER "Error" #define USER_HEADER "User" @@ -198,6 +204,22 @@ parse_check_command (struct worker_task *task, gchar *token) return FALSE; } break; + case 'l': + case 'L': + if (g_ascii_strcasecmp (token + 1, MSG_CMD_LEARN + 1) == 0) { + if (task->allow_learn) { + task->cmd = CMD_LEARN; + } + else { + msg_info ("learning is disabled"); + return FALSE; + } + } + else { + debug_task ("bad command: %s", token); + return FALSE; + } + break; default: cur = custom_commands; while (cur) { @@ -306,8 +328,8 @@ parse_http_command (struct worker_task *task, f_str_t * line) } else { /* Copy command */ - cmd = memory_pool_alloc (task->task_pool, p - c); - rspamd_strlcpy (cmd, c, p - c); + cmd = memory_pool_alloc (task->task_pool, p - c + 1); + rspamd_strlcpy (cmd, c, p - c + 1); /* Skip the first '/' */ if (*cmd == '/') { cmd ++; @@ -379,8 +401,22 @@ parse_header (struct worker_task *task, f_str_t * line) } else { if (task->content_length > 0) { - rspamd_set_dispatcher_policy (task->dispatcher, BUFFER_CHARACTER, task->content_length); - task->state = READ_MESSAGE; + if (task->cmd == CMD_LEARN) { + if (task->statfile != NULL) { + rspamd_set_dispatcher_policy (task->dispatcher, BUFFER_CHARACTER, task->content_length); + task->state = READ_MESSAGE; + } + else { + task->last_error = "Unknown statfile"; + task->error_code = RSPAMD_STATFILE_ERROR; + task->state = WRITE_ERROR; + return FALSE; + } + } + else { + rspamd_set_dispatcher_policy (task->dispatcher, BUFFER_CHARACTER, task->content_length); + task->state = READ_MESSAGE; + } } else { task->last_error = "Unknown content length"; @@ -528,6 +564,9 @@ parse_header (struct worker_task *task, f_str_t * line) if (g_ascii_strncasecmp (headern, SUBJECT_HEADER, sizeof (SUBJECT_HEADER) - 1) == 0) { task->subject = memory_pool_fstrdup (task->task_pool, line); } + else if (g_ascii_strncasecmp (headern, STATFILE_HEADER, sizeof (STATFILE_HEADER) - 1) == 0) { + task->statfile = memory_pool_fstrdup (task->task_pool, line); + } else { return FALSE; } @@ -1433,7 +1472,7 @@ write_reply (struct worker_task *task) /* Write error message and error code to reply */ if (task->is_http) { r = rspamd_snprintf (outbuf, sizeof (outbuf), "HTTP/1.0 400 Bad request" CRLF - "Connection: close" CRLF CRLF); + "Connection: close" CRLF CRLF "Error: %d - %s" CRLF, task->error_code, task->last_error); } else { if (task->proto == SPAMC_PROTO) { @@ -1471,6 +1510,19 @@ write_reply (struct worker_task *task) (task->proto == SPAMC_PROTO) ? SPAMD_REPLY_BANNER : RSPAMD_REPLY_BANNER, rspamc_proto_str (task->proto_ver)); return rspamd_dispatcher_write (task->dispatcher, outbuf, r, FALSE, FALSE); break; + case CMD_LEARN: + if (task->is_http) { + r = rspamd_snprintf (outbuf, sizeof (outbuf), "HTTP/1.0 200 Ok" CRLF + "Connection: close" CRLF CRLF "%s" CRLF, task->last_error); + } + else { + r = rspamd_snprintf (outbuf, sizeof (outbuf), "%s/%s 0 LEARN" CRLF CRLF "%s" CRLF, + (task->proto == SPAMC_PROTO) ? SPAMD_REPLY_BANNER : RSPAMD_REPLY_BANNER, + rspamc_proto_str (task->proto_ver), + task->last_error); + } + return rspamd_dispatcher_write (task->dispatcher, outbuf, r, FALSE, FALSE); + break; case CMD_OTHER: return task->custom_cmd->func (task); } diff --git a/src/protocol.h b/src/protocol.h index a15530a7b..de6d0ea03 100644 --- a/src/protocol.h +++ b/src/protocol.h @@ -13,6 +13,7 @@ #define RSPAMD_NETWORK_ERROR 2 #define RSPAMD_PROTOCOL_ERROR 3 #define RSPAMD_LENGTH_ERROR 4 +#define RSPAMD_STATFILE_ERROR 5 #define RSPAMC_PROTO_1_0 "1.0" #define RSPAMC_PROTO_1_1 "1.1" @@ -44,6 +45,7 @@ enum rspamd_command { CMD_SKIP, CMD_PING, CMD_PROCESS, + CMD_LEARN, CMD_OTHER, }; diff --git a/src/worker.c b/src/worker.c index 57ad4ecf1..a9b05d64e 100644 --- a/src/worker.c +++ b/src/worker.c @@ -89,6 +89,8 @@ struct rspamd_worker_ctx { gboolean is_http; /* JSON output */ gboolean is_json; + /* Allow learning throught worker */ + gboolean allow_learn; GList *custom_filters; /* DNS resolver */ struct rspamd_dns_resolver *resolver; @@ -318,6 +320,7 @@ read_socket (f_str_t * in, void *arg) struct worker_task *task = (struct worker_task *) arg; struct rspamd_worker_ctx *ctx; ssize_t r; + GError *err = NULL; ctx = task->worker->ctx; switch (task->state) { @@ -332,8 +335,10 @@ read_socket (f_str_t * in, void *arg) } else { if (!read_rspamd_input_line (task, in)) { - task->last_error = "Read error"; - task->error_code = RSPAMD_NETWORK_ERROR; + if (!task->last_error) { + task->last_error = "Read error"; + task->error_code = RSPAMD_NETWORK_ERROR; + } task->state = WRITE_ERROR; } } @@ -359,22 +364,38 @@ read_socket (f_str_t * in, void *arg) task->state = WRITE_REPLY; return write_socket (task); } - r = process_filters (task); - if (r == -1) { - task->last_error = "Filter processing error"; - task->error_code = RSPAMD_FILTER_ERROR; - task->state = WRITE_ERROR; + else if (task->cmd == CMD_LEARN) { + if (!learn_task (task->statfile, task, &err)) { + task->last_error = memory_pool_strdup (task->task_pool, err->message); + task->error_code = err->code; + g_error_free (err); + task->state = WRITE_ERROR; + } + else { + task->last_error = "learn ok"; + task->error_code = 0; + task->state = WRITE_REPLY; + } return write_socket (task); } - else if (r == 0) { - task->state = WAIT_FILTER; - rspamd_dispatcher_pause (task->dispatcher); - } else { - process_statfiles (task); - lua_call_post_filters (task); - task->state = WRITE_REPLY; - return write_socket (task); + r = process_filters (task); + if (r == -1) { + task->last_error = "Filter processing error"; + task->error_code = RSPAMD_FILTER_ERROR; + task->state = WRITE_ERROR; + return write_socket (task); + } + else if (r == 0) { + task->state = WAIT_FILTER; + rspamd_dispatcher_pause (task->dispatcher); + } + else { + process_statfiles (task); + lua_call_post_filters (task); + task->state = WRITE_REPLY; + return write_socket (task); + } } break; case WRITE_REPLY: @@ -515,9 +536,8 @@ construct_task (struct rspamd_worker *worker) { struct worker_task *new_task; - new_task = g_malloc (sizeof (struct worker_task)); + new_task = g_malloc0 (sizeof (struct worker_task)); - bzero (new_task, sizeof (struct worker_task)); new_task->worker = worker; new_task->state = READ_COMMAND; new_task->cfg = worker->srv->cfg; @@ -605,10 +625,12 @@ accept_socket (gint fd, short what, void *arg) sizeof (struct in_addr)); } + /* Copy some variables */ new_task->sock = nfd; new_task->is_mime = ctx->is_mime; new_task->is_json = ctx->is_json; new_task->is_http = ctx->is_http; + new_task->allow_learn = ctx->allow_learn; worker->srv->stat->connections_count++; new_task->resolver = ctx->resolver; @@ -750,6 +772,7 @@ init_worker (void) register_worker_opt (TYPE_WORKER, "mime", xml_handle_boolean, ctx, G_STRUCT_OFFSET (struct rspamd_worker_ctx, is_mime)); register_worker_opt (TYPE_WORKER, "http", xml_handle_boolean, ctx, G_STRUCT_OFFSET (struct rspamd_worker_ctx, is_http)); register_worker_opt (TYPE_WORKER, "json", xml_handle_boolean, ctx, G_STRUCT_OFFSET (struct rspamd_worker_ctx, is_json)); + register_worker_opt (TYPE_WORKER, "allow_learn", xml_handle_boolean, ctx, G_STRUCT_OFFSET (struct rspamd_worker_ctx, allow_learn)); register_worker_opt (TYPE_WORKER, "timeout", xml_handle_seconds, ctx, G_STRUCT_OFFSET (struct rspamd_worker_ctx, timeout)); return ctx; |