Browse Source

* Add ability to learn rspamd via worker (without password)

tags/0.3.11
Vsevolod Stakhov 13 years ago
parent
commit
e414be4059
7 changed files with 235 additions and 106 deletions
  1. 2
    71
      src/controller.c
  2. 109
    0
      src/filter.c
  3. 9
    0
      src/filter.h
  4. 16
    13
      src/main.h
  5. 57
    5
      src/protocol.c
  6. 2
    0
      src/protocol.h
  7. 40
    17
      src/worker.c

+ 2
- 71
src/controller.c View File

@@ -723,8 +723,6 @@ controller_read_socket (f_str_t * in, void *arg)
{
struct controller_session *session = (struct controller_session *)arg;
struct classifier_ctx *cls_ctx;
stat_file_t *statfile;
struct statfile *st;
gint len, i, r;
gchar *s, **params, *cmd, out_buf[128];
struct worker_task *task;
@@ -733,7 +731,6 @@ controller_read_socket (f_str_t * in, void *arg)
GTree *tokens = NULL;
GError *err = NULL;
f_str_t c;
double sum;

switch (session->state) {
case STATE_COMMAND:
@@ -799,74 +796,14 @@ controller_read_socket (f_str_t * in, void *arg)
}
return FALSE;
}
if ((s = g_hash_table_lookup (session->learn_classifier->opts, "header")) != NULL) {
cur = message_get_header (task->task_pool, task->message, s, FALSE);
if (cur) {
memory_pool_add_destructor (task->task_pool, (pool_destruct_func)g_list_free, cur);
}
}
else {
cur = g_list_first (task->text_parts);
}
while (cur) {
if (s != NULL) {
c.len = strlen (cur->data);
c.begin = cur->data;
}
else {
part = cur->data;
if (part->is_empty) {
cur = g_list_next (cur);
continue;
}
c.begin = part->content->data;
c.len = part->content->len;
}
if (!session->learn_classifier->tokenizer->tokenize_func (session->learn_classifier->tokenizer, session->session_pool, &c, &tokens)) {
i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn failed, tokenizer error" CRLF);
free_task (task, FALSE);
if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) {
return FALSE;
}
session->state = STATE_REPLY;
return TRUE;
}
cur = g_list_next (cur);
}
/* Handle messages without text */
if (tokens == NULL) {
i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn failed, no tokens can be extracted (no text data)" CRLF END);
msg_info ("learn failed for message <%s>, no tokens to extract", task->message_id);
free_task (task, FALSE);
if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) {
return FALSE;
}
session->state = STATE_REPLY;
return TRUE;
}

/* Take care of subject */
tokenize_subject (task, &tokens);

/* Init classifier */
cls_ctx = session->learn_classifier->classifier->init_func (session->session_pool, session->learn_classifier);
/* Get or create statfile */
statfile = get_statfile_by_symbol (session->worker->srv->statfile_pool, session->learn_classifier,
session->learn_symbol, &st, TRUE);

if (statfile == NULL ||
! session->learn_classifier->classifier->learn_func (cls_ctx, session->worker->srv->statfile_pool,
session->learn_symbol, tokens, session->in_class, &sum,
session->learn_multiplier, &err)) {
if (!learn_task (session->learn_symbol, task, &err)) {
if (err) {
i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn failed, learn classifier error: %s" CRLF END, err->message);
msg_info ("learn failed for message <%s>, learn error: %s", task->message_id, err->message);
g_error_free (err);
}
else {
i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn failed, unknown learn classifier error" CRLF END);
msg_info ("learn failed for message <%s>, unknown learn error", task->message_id);
}
free_task (task, FALSE);
if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) {
@@ -875,18 +812,12 @@ controller_read_socket (f_str_t * in, void *arg)
session->state = STATE_REPLY;
return TRUE;
}
session->worker->srv->stat->messages_learned++;

maybe_write_binlog (session->learn_classifier, st, statfile, tokens);
msg_info ("learn success for message <%s>, for statfile: %s, sum weight: %.2f",
task->message_id, session->learn_symbol, sum);
statfile_pool_plan_invalidate (session->worker->srv->statfile_pool, DEFAULT_STATFILE_INVALIDATE_TIME, DEFAULT_STATFILE_INVALIDATE_JITTER);
free_task (task, FALSE);
i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn ok, sum weight: %.2f" CRLF END, sum);
i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn ok" CRLF END);
if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) {
return FALSE;
}

session->state = STATE_REPLY;
break;
case STATE_WEIGHTS:

+ 109
- 0
src/filter.c View File

@@ -43,6 +43,12 @@
# include "lua/lua_common.h"
#endif

static inline GQuark
filter_error_quark (void)
{
return g_quark_from_static_string ("g-filter-error-quark");
}

static void
insert_metric_result (struct worker_task *task, struct metric *metric, const gchar *symbol,
double flag, GList * opts, gboolean single)
@@ -799,6 +805,109 @@ check_metric_action (double score, double required_score, struct metric *metric)
}
}

/*
 * Learn the statfile identified by a symbol with the message held in a task.
 *
 * The classifier owning the statfile is looked up by symbol; the message
 * (or, if the classifier has a 'header' option, only that header) is
 * tokenized, subject tokens are added and the classifier's learn function
 * is invoked. The learn is always performed as a positive one (in_class is
 * TRUE) with a multiplier of 1.0.
 *
 * @param statfile symbol of the statfile to learn
 * @param task worker's task object with a parsed message
 * @param err location for a GError describing the failure, may be NULL
 * @return TRUE if learning succeeded, FALSE otherwise
 */
gboolean
learn_task (const gchar *statfile, struct worker_task *task, GError **err)
{
	GList *cur;
	struct classifier_config *cl;
	struct classifier_ctx *cls_ctx;
	gchar *s;
	f_str_t c;
	GTree *tokens = NULL;
	struct statfile *st;
	stat_file_t *stf;
	gdouble sum = 0;
	struct mime_text_part *part;
	GError *learn_err = NULL;

	/* A NULL symbol can neither be hashed nor printed safely */
	if (statfile == NULL) {
		g_set_error (err, filter_error_quark(), 1, "Statfile is not specified");
		return FALSE;
	}

	/* Load classifier by symbol */
	cl = g_hash_table_lookup (task->cfg->classifiers_symbols, statfile);
	if (cl == NULL) {
		g_set_error (err, filter_error_quark(), 1, "Statfile %s is not configured in any classifier", statfile);
		return FALSE;
	}

	/* If classifier has 'header' option just classify header of this type */
	if ((s = g_hash_table_lookup (cl->opts, "header")) != NULL) {
		cur = message_get_header (task->task_pool, task->message, s, FALSE);
		if (cur) {
			/* The header list is released together with the task pool */
			memory_pool_add_destructor (task->task_pool, (pool_destruct_func)g_list_free, cur);
		}
	}
	else {
		/* Classify message otherwise */
		cur = g_list_first (task->text_parts);
	}

	/* Get tokens from each element */
	while (cur) {
		if (s != NULL) {
			/* Header mode: list elements are header value strings */
			c.len = strlen (cur->data);
			c.begin = cur->data;
		}
		else {
			/* Message mode: list elements are text parts */
			part = cur->data;
			/* Skip empty parts */
			if (part->is_empty) {
				cur = g_list_next (cur);
				continue;
			}
			c.begin = part->content->data;
			c.len = part->content->len;
		}
		/* Get tokens */
		if (!cl->tokenizer->tokenize_func (
				cl->tokenizer, task->task_pool,
				&c, &tokens)) {
			g_set_error (err, filter_error_quark(), 2, "Cannot tokenize message");
			return FALSE;
		}
		cur = g_list_next (cur);
	}

	/* Handle messages without text */
	if (tokens == NULL) {
		g_set_error (err, filter_error_quark(), 3, "Cannot tokenize message, no text data");
		msg_info ("learn failed for message <%s>, no tokens to extract", task->message_id);
		return FALSE;
	}

	/* Take care of subject */
	tokenize_subject (task, &tokens);

	/* Init classifier */
	cls_ctx = cl->classifier->init_func (
			task->task_pool, cl);
	/* Get or create statfile */
	stf = get_statfile_by_symbol (task->worker->srv->statfile_pool,
			cl, statfile, &st, TRUE);

	/*
	 * Learn. The classifier error is collected in a local GError so that it
	 * can be inspected and logged even when the caller passed err == NULL.
	 */
	if (stf == NULL || !cl->classifier->learn_func (
			cls_ctx, task->worker->srv->statfile_pool,
			statfile, tokens, TRUE, &sum,
			1.0, &learn_err)) {
		if (learn_err != NULL) {
			msg_info ("learn failed for message <%s>, learn error: %s", task->message_id, learn_err->message);
			/* Hands ownership to the caller, or frees it if err is NULL */
			g_propagate_error (err, learn_err);
		}
		else {
			g_set_error (err, filter_error_quark(), 4, "Learn failed, unknown learn classifier error");
			msg_info ("learn failed for message <%s>, unknown learn error", task->message_id);
		}
		return FALSE;
	}
	/* Increase statistics */
	task->worker->srv->stat->messages_learned++;

	maybe_write_binlog (cl, st, stf, tokens);
	msg_info ("learn success for message <%s>, for statfile: %s, sum weight: %.2f",
			task->message_id, statfile, sum);
	statfile_pool_plan_invalidate (task->worker->srv->statfile_pool,
			DEFAULT_STATFILE_INVALIDATE_TIME,
			DEFAULT_STATFILE_INVALIDATE_JITTER);

	return TRUE;
}

/*
* vi:ts=4

+ 9
- 0
src/filter.h View File

@@ -123,6 +123,15 @@ void make_composites (struct worker_task *task);
*/
double factor_consolidation_func (struct worker_task *task, const gchar *metric_name, const gchar *unused);

/*
 * Learn the specified statfile with the message in a task
 * @param statfile symbol of the statfile
 * @param task worker's task object
 * @param err pointer to GError
 * @return TRUE if learning succeeded
 */
gboolean learn_task (const gchar *statfile, struct worker_task *task, GError **err);

gboolean check_action_str (const gchar *data, gint *result);
const gchar *str_action_metric (enum rspamd_metric_action action);
gint check_metric_action (double score, double required_score, struct metric *metric);

+ 16
- 13
src/main.h View File

@@ -95,7 +95,7 @@ struct rspamd_main {

memory_pool_t *server_pool; /**< server's memory pool */
statfile_pool_t *statfile_pool; /**< shared statfiles pool */
GHashTable *workers; /**< workers pool indexed by pid */
GHashTable *workers; /**< workers pool indexed by pid */
};

struct counter_data {
@@ -117,9 +117,9 @@ struct save_point {
* Union that would be used for storing sockaddrs
*/
union sa_union {
struct sockaddr_storage ss;
struct sockaddr_in s4;
struct sockaddr_in6 s6;
struct sockaddr_storage ss;
struct sockaddr_in s4;
struct sockaddr_in6 s6;
};

/**
@@ -151,9 +151,9 @@ struct controller_session {
GList *parts; /**< extracted mime parts */
gint in_class; /**< positive or negative learn */
void (*other_handler)(struct controller_session *session,
f_str_t *in); /**< other command handler to execute at the end of processing */
f_str_t *in); /**< other command handler to execute at the end of processing */
void *other_data; /**< and its data */
struct rspamd_async_session* s; /**< async session object */
struct rspamd_async_session* s; /**< async session object */
};

typedef void (*controller_func_t)(gchar **args, struct controller_session *session);
@@ -178,10 +178,12 @@ struct worker_task {
enum rspamd_command cmd; /**< command */
struct custom_command *custom_cmd; /**< custom command if any */
gint sock; /**< socket descriptor */
gboolean is_mime; /**< if this task is mime task */
gboolean is_json; /**< output is JSON */
gboolean is_http; /**< output is HTTP */
gboolean is_skipped; /**< whether message was skipped by configuration */
gboolean is_mime; /**< if this task is mime task */
gboolean is_json; /**< output is JSON */
gboolean is_http; /**< output is HTTP */
gboolean allow_learn; /**< allow learning */
gboolean is_skipped; /**< whether message was skipped by configuration */

gchar *helo; /**< helo header value */
gchar *from; /**< from header value */
gchar *queue_id; /**< queue id if specified */
@@ -193,9 +195,10 @@ struct worker_task {
gchar *deliver_to; /**< address to deliver */
gchar *user; /**< user to deliver */
gchar *subject; /**< subject (for non-mime) */
gchar *statfile; /**< statfile for learning */
f_str_t *msg; /**< message buffer */
rspamd_io_dispatcher_t *dispatcher; /**< IO dispatcher object */
struct rspamd_async_session* s; /**< async session object */
struct rspamd_async_session* s; /**< async session object */
gint parts_count; /**< mime parts count */
GMimeMessage *message; /**< message, parsed with GMime */
GMimeObject *parser_parent_part; /**< current parent part */
@@ -209,9 +212,9 @@ struct worker_task {
GList *images; /**< list of images */
GList *raw_headers_list; /**< list of raw headers */
GHashTable *results; /**< hash table of metric_result indexed by
* metric's name */
* metric's name */
GHashTable *tokens; /**< hash table of tokens indexed by tokenizer
* pointer */
* pointer */
GList *messages; /**< list of messages that would be reported */
GHashTable *re_cache; /**< cache for matched or not matched regexps */
struct config_file *cfg; /**< pointer to config object */

+ 57
- 5
src/protocol.c View File

@@ -63,6 +63,11 @@
*/
#define MSG_CMD_PROCESS "process"

/*
* Learn specified statfile using message
*/
#define MSG_CMD_LEARN "learn"

/*
* spamassassin greeting:
*/
@@ -81,6 +86,7 @@
#define NRCPT_HEADER "Recipient-Number"
#define RCPT_HEADER "Rcpt"
#define SUBJECT_HEADER "Subject"
#define STATFILE_HEADER "Statfile"
#define QUEUE_ID_HEADER "Queue-ID"
#define ERROR_HEADER "Error"
#define USER_HEADER "User"
@@ -198,6 +204,22 @@ parse_check_command (struct worker_task *task, gchar *token)
return FALSE;
}
break;
case 'l':
case 'L':
if (g_ascii_strcasecmp (token + 1, MSG_CMD_LEARN + 1) == 0) {
if (task->allow_learn) {
task->cmd = CMD_LEARN;
}
else {
msg_info ("learning is disabled");
return FALSE;
}
}
else {
debug_task ("bad command: %s", token);
return FALSE;
}
break;
default:
cur = custom_commands;
while (cur) {
@@ -306,8 +328,8 @@ parse_http_command (struct worker_task *task, f_str_t * line)
}
else {
/* Copy command */
cmd = memory_pool_alloc (task->task_pool, p - c);
rspamd_strlcpy (cmd, c, p - c);
cmd = memory_pool_alloc (task->task_pool, p - c + 1);
rspamd_strlcpy (cmd, c, p - c + 1);
/* Skip the first '/' */
if (*cmd == '/') {
cmd ++;
@@ -379,8 +401,22 @@ parse_header (struct worker_task *task, f_str_t * line)
}
else {
if (task->content_length > 0) {
rspamd_set_dispatcher_policy (task->dispatcher, BUFFER_CHARACTER, task->content_length);
task->state = READ_MESSAGE;
if (task->cmd == CMD_LEARN) {
if (task->statfile != NULL) {
rspamd_set_dispatcher_policy (task->dispatcher, BUFFER_CHARACTER, task->content_length);
task->state = READ_MESSAGE;
}
else {
task->last_error = "Unknown statfile";
task->error_code = RSPAMD_STATFILE_ERROR;
task->state = WRITE_ERROR;
return FALSE;
}
}
else {
rspamd_set_dispatcher_policy (task->dispatcher, BUFFER_CHARACTER, task->content_length);
task->state = READ_MESSAGE;
}
}
else {
task->last_error = "Unknown content length";
@@ -528,6 +564,9 @@ parse_header (struct worker_task *task, f_str_t * line)
if (g_ascii_strncasecmp (headern, SUBJECT_HEADER, sizeof (SUBJECT_HEADER) - 1) == 0) {
task->subject = memory_pool_fstrdup (task->task_pool, line);
}
else if (g_ascii_strncasecmp (headern, STATFILE_HEADER, sizeof (STATFILE_HEADER) - 1) == 0) {
task->statfile = memory_pool_fstrdup (task->task_pool, line);
}
else {
return FALSE;
}
@@ -1433,7 +1472,7 @@ write_reply (struct worker_task *task)
/* Write error message and error code to reply */
if (task->is_http) {
r = rspamd_snprintf (outbuf, sizeof (outbuf), "HTTP/1.0 400 Bad request" CRLF
"Connection: close" CRLF CRLF);
"Connection: close" CRLF CRLF "Error: %d - %s" CRLF, task->error_code, task->last_error);
}
else {
if (task->proto == SPAMC_PROTO) {
@@ -1471,6 +1510,19 @@ write_reply (struct worker_task *task)
(task->proto == SPAMC_PROTO) ? SPAMD_REPLY_BANNER : RSPAMD_REPLY_BANNER, rspamc_proto_str (task->proto_ver));
return rspamd_dispatcher_write (task->dispatcher, outbuf, r, FALSE, FALSE);
break;
case CMD_LEARN:
if (task->is_http) {
r = rspamd_snprintf (outbuf, sizeof (outbuf), "HTTP/1.0 200 Ok" CRLF
"Connection: close" CRLF CRLF "%s" CRLF, task->last_error);
}
else {
r = rspamd_snprintf (outbuf, sizeof (outbuf), "%s/%s 0 LEARN" CRLF CRLF "%s" CRLF,
(task->proto == SPAMC_PROTO) ? SPAMD_REPLY_BANNER : RSPAMD_REPLY_BANNER,
rspamc_proto_str (task->proto_ver),
task->last_error);
}
return rspamd_dispatcher_write (task->dispatcher, outbuf, r, FALSE, FALSE);
break;
case CMD_OTHER:
return task->custom_cmd->func (task);
}

+ 2
- 0
src/protocol.h View File

@@ -13,6 +13,7 @@
#define RSPAMD_NETWORK_ERROR 2
#define RSPAMD_PROTOCOL_ERROR 3
#define RSPAMD_LENGTH_ERROR 4
#define RSPAMD_STATFILE_ERROR 5

#define RSPAMC_PROTO_1_0 "1.0"
#define RSPAMC_PROTO_1_1 "1.1"
@@ -44,6 +45,7 @@ enum rspamd_command {
CMD_SKIP,
CMD_PING,
CMD_PROCESS,
CMD_LEARN,
CMD_OTHER,
};


+ 40
- 17
src/worker.c View File

@@ -89,6 +89,8 @@ struct rspamd_worker_ctx {
gboolean is_http;
/* JSON output */
gboolean is_json;
/* Allow learning through worker */
gboolean allow_learn;
GList *custom_filters;
/* DNS resolver */
struct rspamd_dns_resolver *resolver;
@@ -318,6 +320,7 @@ read_socket (f_str_t * in, void *arg)
struct worker_task *task = (struct worker_task *) arg;
struct rspamd_worker_ctx *ctx;
ssize_t r;
GError *err = NULL;

ctx = task->worker->ctx;
switch (task->state) {
@@ -332,8 +335,10 @@ read_socket (f_str_t * in, void *arg)
}
else {
if (!read_rspamd_input_line (task, in)) {
task->last_error = "Read error";
task->error_code = RSPAMD_NETWORK_ERROR;
if (!task->last_error) {
task->last_error = "Read error";
task->error_code = RSPAMD_NETWORK_ERROR;
}
task->state = WRITE_ERROR;
}
}
@@ -359,22 +364,38 @@ read_socket (f_str_t * in, void *arg)
task->state = WRITE_REPLY;
return write_socket (task);
}
r = process_filters (task);
if (r == -1) {
task->last_error = "Filter processing error";
task->error_code = RSPAMD_FILTER_ERROR;
task->state = WRITE_ERROR;
else if (task->cmd == CMD_LEARN) {
if (!learn_task (task->statfile, task, &err)) {
task->last_error = memory_pool_strdup (task->task_pool, err->message);
task->error_code = err->code;
g_error_free (err);
task->state = WRITE_ERROR;
}
else {
task->last_error = "learn ok";
task->error_code = 0;
task->state = WRITE_REPLY;
}
return write_socket (task);
}
else if (r == 0) {
task->state = WAIT_FILTER;
rspamd_dispatcher_pause (task->dispatcher);
}
else {
process_statfiles (task);
lua_call_post_filters (task);
task->state = WRITE_REPLY;
return write_socket (task);
r = process_filters (task);
if (r == -1) {
task->last_error = "Filter processing error";
task->error_code = RSPAMD_FILTER_ERROR;
task->state = WRITE_ERROR;
return write_socket (task);
}
else if (r == 0) {
task->state = WAIT_FILTER;
rspamd_dispatcher_pause (task->dispatcher);
}
else {
process_statfiles (task);
lua_call_post_filters (task);
task->state = WRITE_REPLY;
return write_socket (task);
}
}
break;
case WRITE_REPLY:
@@ -515,9 +536,8 @@ construct_task (struct rspamd_worker *worker)
{
struct worker_task *new_task;

new_task = g_malloc (sizeof (struct worker_task));
new_task = g_malloc0 (sizeof (struct worker_task));

bzero (new_task, sizeof (struct worker_task));
new_task->worker = worker;
new_task->state = READ_COMMAND;
new_task->cfg = worker->srv->cfg;
@@ -605,10 +625,12 @@ accept_socket (gint fd, short what, void *arg)
sizeof (struct in_addr));
}

/* Copy some variables */
new_task->sock = nfd;
new_task->is_mime = ctx->is_mime;
new_task->is_json = ctx->is_json;
new_task->is_http = ctx->is_http;
new_task->allow_learn = ctx->allow_learn;

worker->srv->stat->connections_count++;
new_task->resolver = ctx->resolver;
@@ -750,6 +772,7 @@ init_worker (void)
register_worker_opt (TYPE_WORKER, "mime", xml_handle_boolean, ctx, G_STRUCT_OFFSET (struct rspamd_worker_ctx, is_mime));
register_worker_opt (TYPE_WORKER, "http", xml_handle_boolean, ctx, G_STRUCT_OFFSET (struct rspamd_worker_ctx, is_http));
register_worker_opt (TYPE_WORKER, "json", xml_handle_boolean, ctx, G_STRUCT_OFFSET (struct rspamd_worker_ctx, is_json));
register_worker_opt (TYPE_WORKER, "allow_learn", xml_handle_boolean, ctx, G_STRUCT_OFFSET (struct rspamd_worker_ctx, allow_learn));
register_worker_opt (TYPE_WORKER, "timeout", xml_handle_seconds, ctx, G_STRUCT_OFFSET (struct rspamd_worker_ctx, timeout));

return ctx;

Loading…
Cancel
Save