From: Vsevolod Stakhov Date: Tue, 12 Jul 2011 16:46:55 +0000 (+0400) Subject: * First commit to implement multi-statfile filter system with new learning mechanizm... X-Git-Tag: 0.4.0~31 X-Git-Url: https://source.dussan.org/?a=commitdiff_plain;h=ff4871310ff5b269dcd02ea300cf78092860e1d4;p=rspamd.git * First commit to implement multi-statfile filter system with new learning mechanizm (untested yet) --- diff --git a/src/cfg_file.h b/src/cfg_file.h index 7e2dfd413..d15923639 100644 --- a/src/cfg_file.h +++ b/src/cfg_file.h @@ -422,14 +422,17 @@ gboolean get_config_checksum (struct config_file *cfg); void unescape_quotes (gchar *line); GList* parse_comma_list (memory_pool_t *pool, gchar *line); -struct classifier_config* check_classifier_cfg (struct config_file *cfg, struct classifier_config *c); +struct classifier_config* check_classifier_conf (struct config_file *cfg, struct classifier_config *c); struct worker_conf* check_worker_conf (struct config_file *cfg, struct worker_conf *c); struct metric* check_metric_conf (struct config_file *cfg, struct metric *c); +struct statfile* check_statfile_conf (struct config_file *cfg, struct statfile *c); gboolean parse_normalizer (struct config_file *cfg, struct statfile *st, const gchar *line); gboolean read_xml_config (struct config_file *cfg, const gchar *filename); gboolean check_modules_config (struct config_file *cfg); void insert_classifier_symbols (struct config_file *cfg); +struct classifier_config* find_classifier_conf (struct config_file *cfg, const gchar *name); + #endif /* ifdef CFG_FILE_H */ /* * vi:ts=4 diff --git a/src/cfg_utils.c b/src/cfg_utils.c index 720d931ef..6bd16d620 100644 --- a/src/cfg_utils.c +++ b/src/cfg_utils.c @@ -751,7 +751,7 @@ parse_comma_list (memory_pool_t * pool, gchar *line) } struct classifier_config * -check_classifier_cfg (struct config_file *cfg, struct classifier_config *c) +check_classifier_conf (struct config_file *cfg, struct classifier_config *c) { if (c == NULL) { c = memory_pool_alloc0 (cfg->cfg_pool, sizeof (struct classifier_config)); @@ -764,6 +764,20 @@ check_classifier_cfg (struct config_file *cfg, struct classifier_config *c) return c; } +struct statfile* +check_statfile_conf (struct config_file *cfg, struct statfile *c) +{ + if (c == NULL) { + c = memory_pool_alloc0 (cfg->cfg_pool, sizeof (struct statfile)); + } + if (c->opts == NULL) { + c->opts = g_hash_table_new (g_str_hash, g_str_equal); + memory_pool_add_destructor (cfg->cfg_pool, (pool_destruct_func) g_hash_table_destroy, c->opts); + } + + return c; +} + struct metric * check_metric_conf (struct config_file *cfg, struct metric *c) { @@ -1006,6 +1020,30 @@ insert_classifier_symbols (struct config_file *cfg) g_hash_table_foreach (cfg->classifiers_symbols, symbols_classifiers_callback, cfg); } +struct classifier_config* +find_classifier_conf (struct config_file *cfg, const gchar *name) +{ + GList *cur; + struct classifier_config *cf; + + if (name == NULL) { + return NULL; + } + + cur = cfg->classifiers; + while (cur) { + cf = cur->data; + + if (g_ascii_strcasecmp (cf->classifier->name, name) == 0) { + return cf; + } + + cur = g_list_next (cur); + } + + return NULL; +} + /* * vi:ts=4 */ diff --git a/src/cfg_xml.c b/src/cfg_xml.c index 09f6574a0..6953edb3a 100644 --- a/src/cfg_xml.c +++ b/src/cfg_xml.c @@ -470,6 +470,12 @@ static struct xml_parser_rule grammar[] = { G_STRUCT_OFFSET (struct statfile, size), NULL }, + { + "spam", + xml_handle_boolean, + G_STRUCT_OFFSET (struct statfile, is_spam), + NULL + }, { "normalizer", handle_statfile_normalizer, @@ -496,7 +502,11 @@ static struct xml_parser_rule grammar[] = { }, NULL_ATTR }, - NULL_DEF_ATTR + { + handle_statfile_opt, + 0, + NULL + } }, { XML_SECTION_MODULE, { NULL_ATTR @@ -1017,9 +1027,9 @@ static void set_lua_globals (struct config_file *cfg, lua_State *L) { struct config_file **pcfg; + /* First check for global variable 'config' */ lua_getglobal (L, "config"); - if (lua_isnil (L, -1)) { /* Assign global table to set up attributes */ lua_newtable (L); @@ -1038,13 +1048,19 @@ set_lua_globals (struct config_file *cfg, lua_State *L) lua_setglobal (L, "composites"); } + lua_getglobal (L, "classifiers"); + if (lua_isnil (L, -1)) { + lua_newtable (L); + lua_setglobal (L, "classifiers"); + } + pcfg = lua_newuserdata (L, sizeof (struct config_file *)); lua_setclass (L, "rspamd{config}", -1); *pcfg = cfg; lua_setglobal (L, "rspamd_config"); /* Clear stack from globals */ - lua_pop (L, 3); + lua_pop (L, 4); } /* Handle lua tag */ @@ -1402,6 +1418,27 @@ handle_statfile_binlog_master (struct config_file *cfg, struct rspamd_xml_userda return TRUE; } +gboolean +handle_statfile_opt (struct config_file *cfg, struct rspamd_xml_userdata *ctx, const gchar *tag, GHashTable *attrs, gchar *data, gpointer user_data, gpointer dest_struct, gint offset) +{ + struct statfile *st = ctx->section_pointer; + const gchar *name; + + if (g_ascii_strcasecmp (tag, "option") == 0 || g_ascii_strcasecmp (tag, "param") == 0) { + if (attrs == NULL || (name = g_hash_table_lookup (attrs, "name")) == NULL) { + msg_err ("worker param tag must have \"name\" attribute"); + return FALSE; + } + } + else { + name = memory_pool_strdup (cfg->cfg_pool, tag); + } + + g_hash_table_insert (st->opts, (char *)name, memory_pool_strdup (cfg->cfg_pool, data)); + + return TRUE; +} + /* Common handlers */ gboolean xml_handle_string (struct config_file *cfg, struct rspamd_xml_userdata *ctx, GHashTable *attrs, gchar *data, gpointer user_data, gpointer dest_struct, gint offset) @@ -1617,7 +1654,7 @@ rspamd_xml_start_element (GMarkupParseContext *context, const gchar *element_nam if (extract_attr ("type", attribute_names, attribute_values, &res)) { ud->state = XML_READ_CLASSIFIER; /* Create object */ - ccf = check_classifier_cfg (ud->cfg, NULL); + ccf = check_classifier_conf (ud->cfg, NULL); if ((ccf->classifier = get_classifier (res)) == NULL) { *error = g_error_new (xml_error_quark (), XML_INVALID_ATTR, "invalid classifier type: %s", res); ud->state = XML_ERROR; @@ -1665,7 +1702,7 @@ rspamd_xml_start_element (GMarkupParseContext *context, const gchar *element_nam /* Now section pointer is statfile and parent pointer is classifier */ ud->parent_pointer = ud->section_pointer; - ud->section_pointer = memory_pool_alloc0 (ud->cfg->cfg_pool, sizeof (struct statfile)); + ud->section_pointer = check_statfile_conf (ud->cfg, NULL); } else { rspamd_strlcpy (ud->section_name, element_name, sizeof (ud->section_name)); diff --git a/src/cfg_xml.h b/src/cfg_xml.h index b48f3d96c..659f35c19 100644 --- a/src/cfg_xml.h +++ b/src/cfg_xml.h @@ -154,6 +154,7 @@ gboolean handle_statfile_normalizer (struct config_file *cfg, struct rspamd_xml_ gboolean handle_statfile_binlog (struct config_file *cfg, struct rspamd_xml_userdata *ctx, GHashTable *attrs, gchar *data, gpointer user_data, gpointer dest_struct, gint offset); gboolean handle_statfile_binlog_rotate (struct config_file *cfg, struct rspamd_xml_userdata *ctx, GHashTable *attrs, gchar *data, gpointer user_data, gpointer dest_struct, gint offset); gboolean handle_statfile_binlog_master (struct config_file *cfg, struct rspamd_xml_userdata *ctx, GHashTable *attrs, gchar *data, gpointer user_data, gpointer dest_struct, gint offset); +gboolean handle_statfile_opt (struct config_file *cfg, struct rspamd_xml_userdata *ctx, const gchar *tag, GHashTable *attrs, gchar *data, gpointer user_data, gpointer dest_struct, gint offset); /* Register new module option */ void register_module_opt (const gchar *mname, const gchar *optname, enum module_opt_type type); diff --git a/src/classifiers/bayes.c b/src/classifiers/bayes.c index 7363df522..44e9323a2 100644 --- a/src/classifiers/bayes.c +++ b/src/classifiers/bayes.c @@ -30,9 +30,8 @@ #include "../main.h" #include "../filter.h" #include "../cfg_file.h" -#ifdef WITH_LUA +#include "../binlog.h" #include "../lua/lua_common.h" -#endif #define LOCAL_PROB_DENOM 16.0 @@ -194,15 +193,22 @@ bayes_classify (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, } } - data.statfiles_num = g_list_length (ctx->cfg->statfiles); + cur = call_classifier_pre_callbacks (ctx->cfg, task, FALSE, FALSE); + if (cur) { + memory_pool_add_destructor (task->task_pool, (pool_destruct_func)g_list_free, cur); + } + else { + cur = ctx->cfg->statfiles; + } + + data.statfiles_num = g_list_length (cur); data.statfiles = g_new0 (struct bayes_statfile_data, data.statfiles_num); data.pool = pool; data.now = time (NULL); data.ctx = ctx; - cur = ctx->cfg->statfiles; while (cur) { - /* Select statfile to learn */ + /* Select statfile to classify */ st = cur->data; if ((file = statfile_pool_is_open (pool, st->path)) == NULL) { if ((file = statfile_pool_open (pool, st->path, st->size, FALSE)) == NULL) { @@ -344,6 +350,70 @@ bayes_learn (struct classifier_ctx* ctx, statfile_pool_t *pool, const char *symb return TRUE; } +gboolean +bayes_learn_spam (struct classifier_ctx* ctx, statfile_pool_t *pool, + GTree *input, struct worker_task *task, gboolean is_spam, GError **err) +{ + struct bayes_callback_data data; + gchar *value; + gint nodes, minnodes; + struct statfile *st; + stat_file_t *file; + GList *cur; + + g_assert (pool != NULL); + g_assert (ctx != NULL); + + if (ctx->cfg->opts && (value = g_hash_table_lookup (ctx->cfg->opts, "min_tokens")) != NULL) { + minnodes = strtol (value, NULL, 10); + nodes = g_tree_nnodes (input); + if (nodes > FEATURE_WINDOW_SIZE) { + nodes = nodes / FEATURE_WINDOW_SIZE + FEATURE_WINDOW_SIZE; + } + if (nodes < minnodes) { + return FALSE; + } + } + + cur = call_classifier_pre_callbacks (ctx->cfg, task, FALSE, FALSE); + if (cur) { + memory_pool_add_destructor (task->task_pool, (pool_destruct_func)g_list_free, cur); + } + else { + cur = ctx->cfg->statfiles; + } + + data.pool = pool; + data.now = time (NULL); + data.ctx = ctx; + + while (cur) { + /* Select statfiles to learn */ + st = cur->data; + if (st->is_spam != is_spam) { + cur = g_list_next (cur); + continue; + } + if ((file = statfile_pool_is_open (pool, st->path)) == NULL) { + if ((file = statfile_pool_open (pool, st->path, st->size, FALSE)) == NULL) { + msg_warn ("cannot open %s", st->path); + cur = g_list_next (cur); + continue; + } + } + data.file = file; + statfile_pool_lock_file (pool, data.file); + g_tree_foreach (input, bayes_learn_callback, &data); + statfile_inc_revision (file); + statfile_pool_unlock_file (pool, data.file); + maybe_write_binlog (ctx->cfg, st, file, input); + + cur = g_list_next (cur); + } + + return TRUE; +} + GList * bayes_weights (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task) { diff --git a/src/classifiers/classifiers.c b/src/classifiers/classifiers.c index 6b0554e1b..5e2b9ea88 100644 --- a/src/classifiers/classifiers.c +++ b/src/classifiers/classifiers.c @@ -35,6 +35,7 @@ struct classifier classifiers[] = { .init_func = winnow_init, .classify_func = winnow_classify, .learn_func = winnow_learn, + .learn_spam_func = winnow_learn_spam, .weights_func = winnow_weights }, { @@ -42,6 +43,7 @@ struct classifier classifiers[] = { .init_func = bayes_init, .classify_func = bayes_classify, .learn_func = bayes_learn, + .learn_spam_func = bayes_learn_spam, .weights_func = bayes_weights } }; diff --git a/src/classifiers/classifiers.h b/src/classifiers/classifiers.h index 601db0205..78ceb196e 100644 --- a/src/classifiers/classifiers.h +++ b/src/classifiers/classifiers.h @@ -32,6 +32,8 @@ struct classifier { gboolean (*learn_func)(struct classifier_ctx* ctx, statfile_pool_t *pool, const char *symbol, GTree *input, gboolean in_class, double *sum, double multiplier, GError **err); + gboolean (*learn_spam_func)(struct classifier_ctx* ctx, statfile_pool_t *pool, + GTree *input, struct worker_task *task, gboolean is_spam, GError **err); GList* (*weights_func)(struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task); }; @@ -43,6 +45,8 @@ struct classifier_ctx* winnow_init (memory_pool_t *pool, struct classifier_confi gboolean winnow_classify (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task); gboolean winnow_learn (struct classifier_ctx* ctx, statfile_pool_t *pool, const char *symbol, GTree *input, gboolean in_class, double *sum, double multiplier, GError **err); +gboolean winnow_learn_spam (struct classifier_ctx* ctx, statfile_pool_t *pool, + GTree *input, struct worker_task *task, gboolean is_spam, GError **err); GList *winnow_weights (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task); /* Bayes algorithm */ @@ -50,6 +54,8 @@ struct classifier_ctx* bayes_init (memory_pool_t *pool, struct classifier_config gboolean bayes_classify (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task); gboolean bayes_learn (struct classifier_ctx* ctx, statfile_pool_t *pool, const char *symbol, GTree *input, gboolean in_class, double *sum, double multiplier, GError **err); +gboolean bayes_learn_spam (struct classifier_ctx* ctx, statfile_pool_t *pool, + GTree *input, struct worker_task *task, gboolean is_spam, GError **err); GList *bayes_weights (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task); /* Array of all defined classifiers */ extern struct classifier classifiers[]; diff --git a/src/classifiers/winnow.c b/src/classifiers/winnow.c index 2e8b98423..b123ce3e5 100644 --- a/src/classifiers/winnow.c +++ b/src/classifiers/winnow.c @@ -223,19 +223,14 @@ winnow_classify (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * inp } } - if (ctx->cfg->pre_callbacks) { -#ifdef WITH_LUA - cur = call_classifier_pre_callbacks (ctx->cfg, task); - if (cur) { - memory_pool_add_destructor (task->task_pool, (pool_destruct_func)g_list_free, cur); - } -#else - cur = ctx->cfg->statfiles; -#endif - } - else { - cur = ctx->cfg->statfiles; - } + cur = call_classifier_pre_callbacks (ctx->cfg, task, FALSE, FALSE); + if (cur) { + memory_pool_add_destructor (task->task_pool, (pool_destruct_func)g_list_free, cur); + } + else { + cur = ctx->cfg->statfiles; + } + while (cur) { st = cur->data; data.sum = 0; @@ -597,3 +592,15 @@ end: } return TRUE; } + +gboolean +winnow_learn_spam (struct classifier_ctx* ctx, statfile_pool_t *pool, + GTree *input, struct worker_task *task, gboolean is_spam, GError **err) +{ + g_set_error (err, + winnow_error_quark(), /* error domain */ + 1, /* error code */ + "learn spam is not supported for winnow" + ); + return FALSE; +} diff --git a/src/controller.c b/src/controller.c index e7e4c10d8..ce33c53a2 100644 --- a/src/controller.c +++ b/src/controller.c @@ -35,6 +35,7 @@ #include "classifiers/classifiers.h" #include "binlog.h" #include "statfile_sync.h" +#include "lua/lua_common.h" #define END "END" CRLF @@ -49,6 +50,8 @@ enum command_type { COMMAND_SHUTDOWN, COMMAND_UPTIME, COMMAND_LEARN, + COMMAND_LEARN_SPAM, + COMMAND_LEARN_HAM, COMMAND_HELP, COMMAND_COUNTERS, COMMAND_SYNC, @@ -84,7 +87,9 @@ static struct controller_command commands[] = { {"weights", FALSE, COMMAND_WEIGHTS}, {"help", FALSE, COMMAND_HELP}, {"counters", FALSE, COMMAND_COUNTERS}, - {"sync", FALSE, COMMAND_SYNC} + {"sync", FALSE, COMMAND_SYNC}, + {"learn_spam", TRUE, COMMAND_LEARN_SPAM}, + {"learn_ham", TRUE, COMMAND_LEARN_HAM} }; static GList *custom_commands = NULL; @@ -519,7 +524,7 @@ process_command (struct controller_command *cmd, gchar **cmd_args, struct contro } } break; - case COMMAND_LEARN: + case COMMAND_LEARN_SPAM: if (check_auth (cmd, session)) { arg = *cmd_args; if (!arg || *arg == '\0') { @@ -532,13 +537,105 @@ process_command (struct controller_command *cmd, gchar **cmd_args, struct contro } arg = *(cmd_args + 1); if (arg == NULL || *arg == '\0') { - msg_debug ("no statfile size specified in learn command"); + msg_debug ("no message size specified in learn command"); + r = rspamd_snprintf (out_buf, sizeof (out_buf), "learn command requires at least two arguments: symbol and message size" CRLF); + if (! rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE)) { + return FALSE; + } + return TRUE; + } + size = strtoul (arg, &err_str, 10); + if (err_str && *err_str != '\0') { + msg_debug ("message size is invalid: %s", arg); + r = rspamd_snprintf (out_buf, sizeof (out_buf), "learn size is invalid" CRLF); + if (! rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE)) { + return FALSE; + } + return TRUE; + } + cl = find_classifier_conf (session->cfg, *cmd_args); + if (cl == NULL) { + r = rspamd_snprintf (out_buf, sizeof (out_buf), "classifier %s is not defined" CRLF, *cmd_args); + if (! rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE)) { + return FALSE; + } + return TRUE; + + } + session->learn_classifier = cl; + + /* By default learn positive */ + session->in_class = TRUE; + rspamd_set_dispatcher_policy (session->dispatcher, BUFFER_CHARACTER, size); + session->state = STATE_LEARN_SPAM; + } + break; + case COMMAND_LEARN_HAM: + if (check_auth (cmd, session)) { + arg = *cmd_args; + if (!arg || *arg == '\0') { + msg_debug ("no statfile specified in learn command"); r = rspamd_snprintf (out_buf, sizeof (out_buf), "learn command requires at least two arguments: stat filename and its size" CRLF); if (! rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE)) { return FALSE; } return TRUE; } + arg = *(cmd_args + 1); + if (arg == NULL || *arg == '\0') { + msg_debug ("no message size specified in learn command"); + r = rspamd_snprintf (out_buf, sizeof (out_buf), "learn command requires at least two arguments: symbol and message size" CRLF); + if (! rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE)) { + return FALSE; + } + return TRUE; + } + size = strtoul (arg, &err_str, 10); + if (err_str && *err_str != '\0') { + msg_debug ("message size is invalid: %s", arg); + r = rspamd_snprintf (out_buf, sizeof (out_buf), "learn size is invalid" CRLF); + if (! rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE)) { + return FALSE; + } + return TRUE; + } + cl = find_classifier_conf (session->cfg, *cmd_args); + if (cl == NULL) { + r = rspamd_snprintf (out_buf, sizeof (out_buf), "classifier %s is not defined" CRLF, *cmd_args); + if (! rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE)) { + return FALSE; + } + return TRUE; + + } + session->learn_classifier = cl; + + /* By default learn positive */ + session->in_class = FALSE; + rspamd_set_dispatcher_policy (session->dispatcher, BUFFER_CHARACTER, size); + session->state = STATE_LEARN_SPAM; + } + break; + case COMMAND_LEARN: + if (check_auth (cmd, session)) { + arg = *cmd_args; + if (!arg || *arg == '\0') { + msg_debug ("no statfile specified in learn command"); + r = rspamd_snprintf (out_buf, sizeof (out_buf), "learn command requires at least two arguments: stat filename and its size" CRLF); + if (! rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE)) { + return FALSE; + } + return TRUE; + } + arg = *(cmd_args + 1); + if (arg == NULL || *arg == '\0') { + msg_debug ("no message size specified in learn command"); + r = rspamd_snprintf (out_buf, sizeof (out_buf), "learn command requires at least two arguments: symbol and message size" CRLF); + if (! rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE)) { + return FALSE; + } + return TRUE; + } size = strtoul (arg, &err_str, 10); if (err_str && *err_str != '\0') { msg_debug ("message size is invalid: %s", arg); @@ -817,6 +914,65 @@ controller_read_socket (f_str_t * in, void *arg) } session->state = STATE_REPLY; break; + case STATE_LEARN_SPAM_PRE: + session->learn_buf = in; + task = construct_task (session->worker); + + task->msg = memory_pool_alloc (task->task_pool, sizeof (f_str_t)); + task->msg->begin = in->begin; + task->msg->len = in->len; + + + r = process_message (task); + if (r == -1) { + msg_warn ("processing of message failed"); + free_task (task, FALSE); + session->state = STATE_REPLY; + r = rspamd_snprintf (out_buf, sizeof (out_buf), "cannot process message" CRLF); + if (! rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE)) { + return FALSE; + } + return FALSE; + } + + r = process_filters (task); + if (r == -1) { + session->state = STATE_REPLY; + r = rspamd_snprintf (out_buf, sizeof (out_buf), "cannot process message" CRLF); + free_task (task, FALSE); + if (! rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE)) { + return FALSE; + } + } + else if (r == 0) { + session->state = STATE_LEARN; + task->dispatcher = session->dispatcher; + session->learn_task = task; + rspamd_dispatcher_pause (session->dispatcher); + } + else { + lua_call_post_filters (task); + session->state = STATE_REPLY; + + if (! learn_task_spam (session->learn_classifier, task, session->in_class, &err)) { + if (err) { + i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn failed, learn classifier error: %s" CRLF END, err->message); + g_error_free (err); + } + else { + i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn failed, unknown learn classifier error" CRLF END); + } + } + else { + i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn ok" CRLF END); + } + + free_task (task, FALSE); + if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) { + return FALSE; + } + } + break; case STATE_WEIGHTS: session->learn_buf = in; task = construct_task (session->worker); @@ -926,12 +1082,37 @@ static gboolean controller_write_socket (void *arg) { struct controller_session *session = (struct controller_session *)arg; + gint i; + gchar out_buf[64]; + GError *err = NULL; if (session->state == STATE_QUIT) { /* Free buffers */ destroy_session (session->s); return FALSE; } + else if (session->state == STATE_LEARN) { + /* Perform actual learn here */ + if (! learn_task_spam (session->learn_classifier, session->learn_task, session->in_class, &err)) { + if (err) { + i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn failed, learn classifier error: %s" CRLF END, err->message); + g_error_free (err); + } + else { + i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn failed, unknown learn classifier error" CRLF END); + } + } + else { + i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn ok" CRLF END); + } + learn_task_spam (session->learn_classifier, session->learn_task, session->in_class, &err); + session->learn_task->dispatcher = NULL; + free_task (session->learn_task, FALSE); + session->state = STATE_REPLY; + if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) { + return FALSE; + } + } else if (session->state == STATE_REPLY) { session->state = STATE_COMMAND; rspamd_set_dispatcher_policy (session->dispatcher, BUFFER_LINE, BUFSIZ); diff --git a/src/filter.c b/src/filter.c index fea91125f..66f233115 100644 --- a/src/filter.c +++ b/src/filter.c @@ -944,6 +944,89 @@ learn_task (const gchar *statfile, struct worker_task *task, GError **err) return TRUE; } +gboolean +learn_task_spam (struct classifier_config *cl, struct worker_task *task, gboolean is_spam, GError **err) +{ + GList *cur, *ex; + struct classifier_ctx *cls_ctx; + f_str_t c; + GTree *tokens = NULL; + struct mime_text_part *part; + gboolean is_utf = FALSE, is_twopart = FALSE; + + cur = g_list_first (task->text_parts); + if (cur != NULL && cur->next != NULL && cur->next->next == NULL) { + is_twopart = TRUE; + } + + /* Get tokens from each element */ + while (cur) { + part = cur->data; + /* Skip empty parts */ + if (part->is_empty) { + cur = g_list_next (cur); + continue; + } + c.begin = part->content->data; + c.len = part->content->len; + is_utf = part->is_utf; + ex = part->urls_offset; + if (is_twopart && cur->next == NULL) { + /* Compare part's content */ + if (fuzzy_compare_parts (cur->data, cur->prev->data) >= COMMON_PART_FACTOR) { + msg_info ("message <%s> has two common text parts, ignore the last one", task->message_id); + break; + } + } + /* Get tokens */ + if (!cl->tokenizer->tokenize_func ( + cl->tokenizer, task->task_pool, + &c, &tokens, FALSE, is_utf, ex)) { + g_set_error (err, filter_error_quark(), 2, "Cannot tokenize message"); + return FALSE; + } + cur = g_list_next (cur); + } + + /* Handle messages without text */ + if (tokens == NULL) { + g_set_error (err, filter_error_quark(), 3, "Cannot tokenize message, no text data"); + msg_info ("learn failed for message <%s>, no tokens to extract", task->message_id); + return FALSE; + } + + /* Take care of subject */ + tokenize_subject (task, &tokens); + + /* Init classifier */ + cls_ctx = cl->classifier->init_func ( + task->task_pool, cl); + /* Learn */ + if (!cl->classifier->learn_spam_func ( + cls_ctx, task->worker->srv->statfile_pool, + tokens, task, is_spam, err)) { + if (*err) { + msg_info ("learn failed for message <%s>, learn error: %s", task->message_id, (*err)->message); + return FALSE; + } + else { + g_set_error (err, filter_error_quark(), 4, "Learn failed, unknown learn classifier error"); + msg_info ("learn failed for message <%s>, unknown learn error", task->message_id); + return FALSE; + } + } + /* Increase statistics */ + task->worker->srv->stat->messages_learned++; + + msg_info ("learn success for message <%s>", + task->message_id); + statfile_pool_plan_invalidate (task->worker->srv->statfile_pool, + DEFAULT_STATFILE_INVALIDATE_TIME, + DEFAULT_STATFILE_INVALIDATE_JITTER); + + return TRUE; +} + /* * vi:ts=4 */ diff --git a/src/filter.h b/src/filter.h index cfcab7cb5..930594170 100644 --- a/src/filter.h +++ b/src/filter.h @@ -11,6 +11,7 @@ struct worker_task; struct rspamd_settings; +struct classifier_config; typedef double (*metric_cons_func)(struct worker_task *task, const gchar *metric_name, const gchar *func_name); typedef void (*filter_func)(struct worker_task *task); @@ -136,6 +137,15 @@ double factor_consolidation_func (struct worker_task *task, const gchar *metric_ */ gboolean learn_task (const gchar *statfile, struct worker_task *task, GError **err); +/* + * Learn specified statfile with message in a task + * @param statfile symbol of statfile + * @param task worker's task object + * @param err pointer to GError + * @return true if learn succeed + */ +gboolean learn_task_spam (struct classifier_config *cl, struct worker_task *task, gboolean is_spam, GError **err); + gboolean check_action_str (const gchar *data, gint *result); const gchar *str_action_metric (enum rspamd_metric_action action); gint check_metric_action (double score, double required_score, struct metric *metric); diff --git a/src/html.c b/src/html.c index 306e1e700..9a6541610 100644 --- a/src/html.c +++ b/src/html.c @@ -655,29 +655,6 @@ decode_entitles (gchar *s, guint * len) } } -/* - * Find the first occurrence of find in s, ignore case. - */ -static gchar * -html_strncasestr (const gchar *s, const gchar *find, gsize len) -{ - gchar c, sc; - gsize mlen; - - if ((c = *find++) != 0) { - c = g_ascii_tolower (c); - mlen = strlen (find); - do { - do { - if ((sc = *s++) == 0 || len -- == 0) - return (NULL); - } while (g_ascii_tolower (sc) != c); - } while (g_ascii_strncasecmp (s, find, mlen) != 0); - s--; - } - return ((gchar *)s); -} - static void check_phishing (struct worker_task *task, struct uri *href_url, const gchar *url_text, gsize remain, tag_id_t id) { @@ -803,11 +780,11 @@ parse_tag_url (struct worker_task *task, struct mime_text_part *part, tag_id_t i /* For A tags search for href= and for IMG tags search for src= */ if (id == Tag_A) { - c = html_strncasestr (tag_text, "href=", tag_len); + c = rspamd_strncasestr (tag_text, "href=", tag_len); len = sizeof ("href=") - 1; } else if (id == Tag_IMG) { - c = html_strncasestr (tag_text, "src=", tag_len); + c = rspamd_strncasestr (tag_text, "src=", tag_len); len = sizeof ("src=") - 1; } diff --git a/src/lua/lua_classifier.c b/src/lua/lua_classifier.c index f7ce173a7..511767680 100644 --- a/src/lua/lua_classifier.c +++ b/src/lua/lua_classifier.c @@ -41,7 +41,18 @@ static const struct luaL_reg classifierlib_m[] = { }; +LUA_FUNCTION_DEF (statfile, get_symbol); +LUA_FUNCTION_DEF (statfile, get_path); +LUA_FUNCTION_DEF (statfile, get_size); +LUA_FUNCTION_DEF (statfile, is_spam); +LUA_FUNCTION_DEF (statfile, get_param); + static const struct luaL_reg statfilelib_m[] = { + LUA_INTERFACE_DEF (statfile, get_symbol), + LUA_INTERFACE_DEF (statfile, get_path), + LUA_INTERFACE_DEF (statfile, get_size), + LUA_INTERFACE_DEF (statfile, is_spam), + LUA_INTERFACE_DEF (statfile, get_param), {"__tostring", lua_class_tostring}, {NULL, NULL} }; @@ -64,50 +75,81 @@ lua_check_classifier (lua_State * L) return *((struct classifier_config **)ud); } -/* Return list of statfiles that should be checked for this message */ -GList * -call_classifier_pre_callbacks (struct classifier_config *ccf, struct worker_task *task) +static GList * +call_classifier_pre_callback (struct classifier_config *ccf, struct worker_task *task, + lua_State *L, gboolean is_learn, gboolean is_spam) { - GList *res = NULL, *cur; - struct classifier_callback_data *cd; struct classifier_config **pccf; struct worker_task **ptask; struct statfile *st; gint i, len; + GList *res = NULL; - /* Go throught all callbacks and call them, appending results to list */ - cur = g_list_first (ccf->pre_callbacks); - while (cur) { - cd = cur->data; - lua_getglobal (cd->L, cd->name); + pccf = lua_newuserdata (L, sizeof (struct classifier_config *)); + lua_setclass (L, "rspamd{classifier}", -1); + *pccf = ccf; - pccf = lua_newuserdata (cd->L, sizeof (struct classifier_config *)); - lua_setclass (cd->L, "rspamd{classifier}", -1); - *pccf = ccf; + ptask = lua_newuserdata (L, sizeof (struct worker_task *)); + lua_setclass (L, "rspamd{task}", -1); + *ptask = task; - ptask = lua_newuserdata (cd->L, sizeof (struct worker_task *)); - lua_setclass (cd->L, "rspamd{task}", -1); - *ptask = task; + lua_pushboolean (L, is_learn); + lua_pushboolean (L, is_spam); - if (lua_pcall (cd->L, 2, 1, 0) != 0) { - msg_warn ("error running function %s: %s", cd->name, lua_tostring (cd->L, -1)); - } - else { - if (lua_istable (cd->L, 1)) { - len = lua_objlen (cd->L, 1); - for (i = 1; i <= len; i ++) { - lua_rawgeti (cd->L, 1, i); - st = lua_check_statfile (cd->L); - if (st) { - res = g_list_prepend (res, st); - } + if (lua_pcall (L, 4, 1, 0) != 0) { + msg_warn ("error running pre classifier callback %s", lua_tostring (L, -1)); + } + else { + if (lua_istable (L, 1)) { + len = lua_objlen (L, 1); + for (i = 1; i <= len; i ++) { + lua_rawgeti (L, 1, i); + st = lua_check_statfile (L); + if (st) { + res = g_list_prepend (res, st); } } } + } + + return res; +} + +/* Return list of statfiles that should be checked for this message */ +GList * +call_classifier_pre_callbacks (struct classifier_config *ccf, struct worker_task *task, + gboolean is_learn, gboolean is_spam) +{ + GList *res = NULL, *cur; + struct classifier_callback_data *cd; + lua_State *L; + + + /* Go throught all callbacks and call them, appending results to list */ + cur = g_list_first (ccf->pre_callbacks); + while (cur) { + cd = cur->data; + lua_getglobal (cd->L, cd->name); + + res = g_list_concat (res, call_classifier_pre_callback (ccf, task, cd->L, is_learn, is_spam)); cur = g_list_next (cur); } + if (res == NULL) { + L = task->cfg->lua_state; + /* Check function from global table 'classifiers' */ + lua_getglobal (L, "classifiers"); + if (lua_istable (L, -1)) { + lua_pushstring (L, ccf->classifier->name); + lua_gettable (L, -2); + /* Function is now on top */ + if (lua_isfunction (L, 1)) { + res = call_classifier_pre_callback (ccf, task, L, is_learn, is_spam); + } + } + } + return res; } @@ -226,6 +268,86 @@ lua_classifier_get_statfiles (lua_State *L) } /* Statfile functions */ +static gint +lua_statfile_get_symbol (lua_State *L) +{ + struct statfile *st = lua_check_statfile (L); + + if (st != NULL) { + lua_pushstring (L, st->symbol); + } + else { + lua_pushnil (L); + } + + return 1; +} + +static gint +lua_statfile_get_path (lua_State *L) +{ + struct statfile *st = lua_check_statfile (L); + + if (st != NULL) { + lua_pushstring (L, st->path); + } + else { + lua_pushnil (L); + } + + return 1; +} + +static gint +lua_statfile_get_size (lua_State *L) +{ + struct statfile *st = lua_check_statfile (L); + + if (st != NULL) { + lua_pushinteger (L, st->size); + } + else { + lua_pushnil (L); + } + + return 1; +} + +static gint +lua_statfile_is_spam (lua_State *L) +{ + struct statfile *st = lua_check_statfile (L); + + if (st != NULL) { + lua_pushboolean (L, st->is_spam); + } + else { + lua_pushnil (L); + } + + return 1; +} + +static gint +lua_statfile_get_param (lua_State *L) +{ + struct statfile *st = lua_check_statfile (L); + const gchar *param, *value; + + param = luaL_checkstring (L, 2); + + if (st != NULL && param != NULL) { + value = g_hash_table_lookup (st->opts, param); + if (param != NULL) { + lua_pushstring (L, value); + return 1; + } + } + lua_pushnil (L); + + return 1; +} + static struct statfile * lua_check_statfile (lua_State * L) { diff --git a/src/lua/lua_common.h b/src/lua/lua_common.h index 4427dcae9..3f3fc7a1d 100644 --- a/src/lua/lua_common.h +++ b/src/lua/lua_common.h @@ -51,7 +51,7 @@ void lua_call_post_filters (struct worker_task *task); void add_luabuf (const gchar *line); /* Classify functions */ -GList *call_classifier_pre_callbacks (struct classifier_config *ccf, struct worker_task *task); +GList *call_classifier_pre_callbacks (struct classifier_config *ccf, struct worker_task *task, gboolean is_learn, gboolean is_spam); double call_classifier_post_callbacks (struct classifier_config *ccf, struct worker_task *task, double in); double lua_normalizer_func (struct config_file *cfg, long double score, void *params); diff --git a/src/main.h b/src/main.h index 9a3335d0a..d8761617f 100644 --- a/src/main.h +++ b/src/main.h @@ -64,6 +64,7 @@ struct classifier_config; struct mime_part; struct rspamd_view; struct rspamd_dns_resolver; +struct worker_task; /** * Server statistics @@ -138,6 +139,8 @@ struct controller_session { enum { STATE_COMMAND, STATE_LEARN, + STATE_LEARN_SPAM_PRE, + STATE_LEARN_SPAM, STATE_REPLY, STATE_QUIT, STATE_OTHER, @@ -162,6 +165,7 @@ struct controller_session { f_str_t *in); /**< other command handler to execute at the end of processing */ void *other_data; /**< and its data */ struct rspamd_async_session* s; /**< async session object */ + struct worker_task *learn_task; }; typedef void (*controller_func_t)(gchar **args, struct controller_session *session); diff --git a/src/util.c b/src/util.c index 6d8cb09e0..16ed43786 100644 --- a/src/util.c +++ b/src/util.c @@ -1363,6 +1363,29 @@ escape_braces_addr_fstr (memory_pool_t *pool, f_str_t *in) return res; } +/* + * Find the first occurrence of find in s, ignore case. + */ +gchar * +rspamd_strncasestr (const gchar *s, const gchar *find, gint len) +{ + gchar c, sc; + gsize mlen; + + if ((c = *find++) != 0) { + c = g_ascii_tolower (c); + mlen = strlen (find); + do { + do { + if ((sc = *s++) == 0 || len -- == 0) + return (NULL); + } while (g_ascii_tolower (sc) != c); + } while (g_ascii_strncasecmp (s, find, mlen) != 0); + s--; + } + return ((gchar *)s); +} + /* * vi:ts=4 */ diff --git a/src/util.h b/src/util.h index 73d96c759..b1d4ad2e9 100644 --- a/src/util.h +++ b/src/util.h @@ -158,4 +158,7 @@ void free_task (struct worker_task *task, gboolean is_soft); void free_task_hard (gpointer ud); void free_task_soft (gpointer ud); +/* Find string find in string s ignoring case */ +gchar* rspamd_strncasestr (const gchar *s, const gchar *find, gint len); + #endif