]> source.dussan.org Git - rspamd.git/commitdiff
* First commit to implement multi-statfile filter system with new learning mechanism...
authorVsevolod Stakhov <vsevolod@rambler-co.ru>
Tue, 12 Jul 2011 16:46:55 +0000 (20:46 +0400)
committerVsevolod Stakhov <vsevolod@rambler-co.ru>
Tue, 12 Jul 2011 16:46:55 +0000 (20:46 +0400)
17 files changed:
src/cfg_file.h
src/cfg_utils.c
src/cfg_xml.c
src/cfg_xml.h
src/classifiers/bayes.c
src/classifiers/classifiers.c
src/classifiers/classifiers.h
src/classifiers/winnow.c
src/controller.c
src/filter.c
src/filter.h
src/html.c
src/lua/lua_classifier.c
src/lua/lua_common.h
src/main.h
src/util.c
src/util.h

index 7e2dfd413d46539cc795ef811f337c43cb1d0c9f..d15923639a29d35b373a66cb8d33e6cec06abae7 100644 (file)
@@ -422,14 +422,17 @@ gboolean get_config_checksum (struct config_file *cfg);
 void unescape_quotes (gchar *line);
 
 GList* parse_comma_list (memory_pool_t *pool, gchar *line);
-struct classifier_config* check_classifier_cfg (struct config_file *cfg, struct classifier_config *c);
+struct classifier_config* check_classifier_conf (struct config_file *cfg, struct classifier_config *c);
 struct worker_conf* check_worker_conf (struct config_file *cfg, struct worker_conf *c);
 struct metric* check_metric_conf (struct config_file *cfg, struct metric *c);
+struct statfile* check_statfile_conf (struct config_file *cfg, struct statfile *c);
 gboolean parse_normalizer (struct config_file *cfg, struct statfile *st, const gchar *line);
 gboolean read_xml_config (struct config_file *cfg, const gchar *filename);
 gboolean check_modules_config (struct config_file *cfg);
 void insert_classifier_symbols (struct config_file *cfg);
 
+struct classifier_config* find_classifier_conf (struct config_file *cfg, const gchar *name);
+
 #endif /* ifdef CFG_FILE_H */
 /* 
  * vi:ts=4 
index 720d931ef0c23f3f2df08445aa9f487bc9518080..6bd16d620fb0290eb0f659234ac3170f5844e350 100644 (file)
@@ -751,7 +751,7 @@ parse_comma_list (memory_pool_t * pool, gchar *line)
 }
 
 struct classifier_config       *
-check_classifier_cfg (struct config_file *cfg, struct classifier_config *c)
+check_classifier_conf (struct config_file *cfg, struct classifier_config *c)
 {
        if (c == NULL) {
                c = memory_pool_alloc0 (cfg->cfg_pool, sizeof (struct classifier_config));
@@ -764,6 +764,20 @@ check_classifier_cfg (struct config_file *cfg, struct classifier_config *c)
        return c;
 }
 
+/*
+ * Return a statfile config that is guaranteed to carry an options table.
+ * When `c` is NULL a new object is taken from cfg->cfg_pool with
+ * memory_pool_alloc0; when the object has no `opts` table one is created
+ * and g_hash_table_destroy is registered for it on the same pool.
+ */
+struct statfile*
+check_statfile_conf (struct config_file *cfg, struct statfile *c)
+{
+       struct statfile                *st = c;
+
+       if (st == NULL) {
+               st = memory_pool_alloc0 (cfg->cfg_pool, sizeof (struct statfile));
+       }
+
+       if (st->opts != NULL) {
+               /* Options table is already in place, nothing to set up */
+               return st;
+       }
+
+       st->opts = g_hash_table_new (g_str_hash, g_str_equal);
+       memory_pool_add_destructor (cfg->cfg_pool, (pool_destruct_func) g_hash_table_destroy, st->opts);
+
+       return st;
+}
+
 struct metric *
 check_metric_conf (struct config_file *cfg, struct metric *c)
 {
@@ -1006,6 +1020,30 @@ insert_classifier_symbols (struct config_file *cfg)
        g_hash_table_foreach (cfg->classifiers_symbols, symbols_classifiers_callback, cfg);
 }
 
+/*
+ * Look up a classifier configuration by the name of its classifier.
+ * The comparison is ASCII case-insensitive. Returns the first matching
+ * element of cfg->classifiers, or NULL when `name` is NULL or nothing
+ * matches.
+ */
+struct classifier_config*
+find_classifier_conf (struct config_file *cfg, const gchar *name)
+{
+       GList                          *it;
+       struct classifier_config       *candidate;
+
+       if (name == NULL) {
+               return NULL;
+       }
+
+       for (it = cfg->classifiers; it != NULL; it = g_list_next (it)) {
+               candidate = it->data;
+               if (g_ascii_strcasecmp (candidate->classifier->name, name) == 0) {
+                       return candidate;
+               }
+       }
+
+       return NULL;
+}
+
 /*
  * vi:ts=4
  */
index 09f6574a00b2f86371c21d39fbfce10430f9ec41..6953edb3a603faec9f4186d889014c0c0523c270 100644 (file)
@@ -470,6 +470,12 @@ static struct xml_parser_rule grammar[] = {
                                G_STRUCT_OFFSET (struct statfile, size),
                                NULL
                        },
+                       {
+                               "spam",
+                               xml_handle_boolean,
+                               G_STRUCT_OFFSET (struct statfile, is_spam),
+                               NULL
+                       },
                        {
                                "normalizer",
                                handle_statfile_normalizer,
@@ -496,7 +502,11 @@ static struct xml_parser_rule grammar[] = {
                        },
                        NULL_ATTR
                },
-               NULL_DEF_ATTR
+               {
+                               handle_statfile_opt,
+                               0,
+                               NULL
+               }
        },
        { XML_SECTION_MODULE, {
                        NULL_ATTR
@@ -1017,9 +1027,9 @@ static void
 set_lua_globals (struct config_file *cfg, lua_State *L)
 {
        struct config_file           **pcfg;
+
        /* First check for global variable 'config' */
        lua_getglobal (L, "config");
-
        if (lua_isnil (L, -1)) {
                /* Assign global table to set up attributes */
                lua_newtable (L);
@@ -1038,13 +1048,19 @@ set_lua_globals (struct config_file *cfg, lua_State *L)
                lua_setglobal (L, "composites");
        }
 
+       lua_getglobal (L, "classifiers");
+       if (lua_isnil (L, -1)) {
+               lua_newtable (L);
+               lua_setglobal (L, "classifiers");
+       }
+
        pcfg = lua_newuserdata (L, sizeof (struct config_file *));
        lua_setclass (L, "rspamd{config}", -1);
        *pcfg = cfg;
        lua_setglobal (L, "rspamd_config");
 
        /* Clear stack from globals */
-       lua_pop (L, 3);
+       lua_pop (L, 4);
 }
 
 /* Handle lua tag */
@@ -1402,6 +1418,27 @@ handle_statfile_binlog_master (struct config_file *cfg, struct rspamd_xml_userda
        return TRUE;
 }
 
+/*
+ * Fallback XML handler for a statfile section: stores an arbitrary option
+ * in the `opts` table of the statfile taken from ctx->section_pointer.
+ *
+ * For "option" / "param" tags (compared case-insensitively) the key is the
+ * value of the "name" attribute; for any other tag the key is the tag name.
+ * Both the key and the element text `data` are copied into cfg->cfg_pool
+ * before insertion, so the stored pointers do not depend on the lifetime of
+ * `attrs` or `tag`.
+ *
+ * Returns FALSE when an option/param tag lacks a "name" attribute,
+ * TRUE otherwise.
+ */
+gboolean
+handle_statfile_opt (struct config_file *cfg, struct rspamd_xml_userdata *ctx, const gchar *tag, GHashTable *attrs, gchar *data, gpointer user_data, gpointer dest_struct, gint offset)
+{
+       struct statfile                *st = ctx->section_pointer;
+       const gchar                    *name;
+
+       if (g_ascii_strcasecmp (tag, "option") == 0 || g_ascii_strcasecmp (tag, "param") == 0) {
+               if (attrs == NULL || (name = g_hash_table_lookup (attrs, "name")) == NULL) {
+                       /* This is the statfile handler, so report the error against the statfile tag */
+                       msg_err ("statfile param tag must have \"name\" attribute");
+                       return FALSE;
+               }
+       }
+       else {
+               name = tag;
+       }
+
+       /*
+        * NOTE(review): ownership of the strings held by `attrs` is not visible
+        * here, so the key is always duplicated into the config pool.
+        */
+       g_hash_table_insert (st->opts, memory_pool_strdup (cfg->cfg_pool, name), memory_pool_strdup (cfg->cfg_pool, data));
+
+       return TRUE;
+}
+
 /* Common handlers */
 gboolean 
 xml_handle_string (struct config_file *cfg, struct rspamd_xml_userdata *ctx, GHashTable *attrs, gchar *data, gpointer user_data, gpointer dest_struct, gint offset)
@@ -1617,7 +1654,7 @@ rspamd_xml_start_element (GMarkupParseContext *context, const gchar *element_nam
                                if (extract_attr ("type", attribute_names, attribute_values, &res)) {
                                        ud->state = XML_READ_CLASSIFIER;
                                        /* Create object */
-                                       ccf = check_classifier_cfg (ud->cfg, NULL);
+                                       ccf = check_classifier_conf (ud->cfg, NULL);
                                        if ((ccf->classifier = get_classifier (res)) == NULL) {
                                                *error = g_error_new (xml_error_quark (), XML_INVALID_ATTR, "invalid classifier type: %s", res);
                                                ud->state = XML_ERROR;
@@ -1665,7 +1702,7 @@ rspamd_xml_start_element (GMarkupParseContext *context, const gchar *element_nam
 
                                /* Now section pointer is statfile and parent pointer is classifier */
                                ud->parent_pointer = ud->section_pointer;
-                               ud->section_pointer = memory_pool_alloc0 (ud->cfg->cfg_pool, sizeof (struct statfile));
+                               ud->section_pointer = check_statfile_conf (ud->cfg, NULL);
                        }
                        else {
                                rspamd_strlcpy (ud->section_name, element_name, sizeof (ud->section_name));
index b48f3d96c80cb4d2a64c5b5c38136df8e100a435..659f35c19eb941dd1be8d92dfc3a849b846a64fa 100644 (file)
@@ -154,6 +154,7 @@ gboolean handle_statfile_normalizer (struct config_file *cfg, struct rspamd_xml_
 gboolean handle_statfile_binlog (struct config_file *cfg, struct rspamd_xml_userdata *ctx, GHashTable *attrs, gchar *data, gpointer user_data, gpointer dest_struct, gint offset);
 gboolean handle_statfile_binlog_rotate (struct config_file *cfg, struct rspamd_xml_userdata *ctx, GHashTable *attrs, gchar *data, gpointer user_data, gpointer dest_struct, gint offset);
 gboolean handle_statfile_binlog_master (struct config_file *cfg, struct rspamd_xml_userdata *ctx, GHashTable *attrs, gchar *data, gpointer user_data, gpointer dest_struct, gint offset);
+gboolean handle_statfile_opt (struct config_file *cfg, struct rspamd_xml_userdata *ctx, const gchar *tag, GHashTable *attrs, gchar *data, gpointer user_data, gpointer dest_struct, gint offset);
 
 /* Register new module option */
 void register_module_opt (const gchar *mname, const gchar *optname, enum module_opt_type type);
index 7363df52273343931aab3784b6311d1a227d1d74..44e9323a2051963e3e9091fedc649117ec2b7c89 100644 (file)
@@ -30,9 +30,8 @@
 #include "../main.h"
 #include "../filter.h"
 #include "../cfg_file.h"
-#ifdef WITH_LUA
+#include "../binlog.h"
 #include "../lua/lua_common.h"
-#endif
 
 #define LOCAL_PROB_DENOM 16.0
 
@@ -194,15 +193,22 @@ bayes_classify (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input,
                }
        }
 
-       data.statfiles_num = g_list_length (ctx->cfg->statfiles);
+       cur = call_classifier_pre_callbacks (ctx->cfg, task, FALSE, FALSE);
+       if (cur) {
+               memory_pool_add_destructor (task->task_pool, (pool_destruct_func)g_list_free, cur);
+       }
+       else {
+               cur = ctx->cfg->statfiles;
+       }
+
+       data.statfiles_num = g_list_length (cur);
        data.statfiles = g_new0 (struct bayes_statfile_data, data.statfiles_num);
        data.pool = pool;
        data.now = time (NULL);
        data.ctx = ctx;
 
-       cur = ctx->cfg->statfiles;
        while (cur) {
-               /* Select statfile to learn */
+               /* Select statfile to classify */
                st = cur->data;
                if ((file = statfile_pool_is_open (pool, st->path)) == NULL) {
                        if ((file = statfile_pool_open (pool, st->path, st->size, FALSE)) == NULL) {
@@ -344,6 +350,70 @@ bayes_learn (struct classifier_ctx* ctx, statfile_pool_t *pool, const char *symb
        return TRUE;
 }
 
+/*
+ * Learn the tokens in `input` into every statfile whose `is_spam` flag
+ * equals `is_spam`.
+ *
+ * The statfile list is the one returned by call_classifier_pre_callbacks
+ * (freed together with task->task_pool) or, when that returns NULL,
+ * ctx->cfg->statfiles. NOTE(review): the meaning of the two FALSE arguments
+ * passed to call_classifier_pre_callbacks is not visible here - confirm
+ * that they are correct for the learning path.
+ *
+ * If the classifier has a "min_tokens" option and the (window-adjusted)
+ * number of tree nodes is below it, nothing is learned.
+ *
+ * Returns TRUE when at least one statfile was learned. Returns FALSE with
+ * `err` set when the message has too few tokens or when no statfile was
+ * learned (none matched `is_spam`, or none could be opened).
+ */
+gboolean
+bayes_learn_spam (struct classifier_ctx* ctx, statfile_pool_t *pool,
+               GTree *input, struct worker_task *task, gboolean is_spam, GError **err)
+{
+       struct bayes_callback_data      data;
+       gchar                          *value;
+       gint                            nodes, minnodes;
+       guint                           learned = 0;
+       struct statfile                *st;
+       stat_file_t                    *file;
+       GList                          *cur;
+
+       g_assert (pool != NULL);
+       g_assert (ctx != NULL);
+
+       if (ctx->cfg->opts && (value = g_hash_table_lookup (ctx->cfg->opts, "min_tokens")) != NULL) {
+               minnodes = strtol (value, NULL, 10);
+               nodes = g_tree_nnodes (input);
+               if (nodes > FEATURE_WINDOW_SIZE) {
+                       nodes = nodes / FEATURE_WINDOW_SIZE + FEATURE_WINDOW_SIZE;
+               }
+               if (nodes < minnodes) {
+                       /* Tell the caller why learning was refused instead of failing silently */
+                       g_set_error (err, g_quark_from_static_string ("bayes-error"), 1,
+                                       "message contains too few tokens: %d, while min is %d", nodes, minnodes);
+                       return FALSE;
+               }
+       }
+
+       cur = call_classifier_pre_callbacks (ctx->cfg, task, FALSE, FALSE);
+       if (cur) {
+               memory_pool_add_destructor (task->task_pool, (pool_destruct_func)g_list_free, cur);
+       }
+       else {
+               cur = ctx->cfg->statfiles;
+       }
+
+       data.pool = pool;
+       data.now = time (NULL);
+       data.ctx = ctx;
+
+       while (cur) {
+               /* Select statfiles to learn */
+               st = cur->data;
+               if (st->is_spam != is_spam) {
+                       cur = g_list_next (cur);
+                       continue;
+               }
+               if ((file = statfile_pool_is_open (pool, st->path)) == NULL) {
+                       if ((file = statfile_pool_open (pool, st->path, st->size, FALSE)) == NULL) {
+                               msg_warn ("cannot open %s", st->path);
+                               cur = g_list_next (cur);
+                               continue;
+                       }
+               }
+               data.file = file;
+               /* Every tree node is handed to bayes_learn_callback while the file is locked */
+               statfile_pool_lock_file (pool, data.file);
+               g_tree_foreach (input, bayes_learn_callback, &data);
+               statfile_inc_revision (file);
+               statfile_pool_unlock_file (pool, data.file);
+               maybe_write_binlog (ctx->cfg, st, file, input);
+               learned ++;
+
+               cur = g_list_next (cur);
+       }
+
+       if (learned == 0) {
+               /* Do not report success when nothing was written anywhere */
+               g_set_error (err, g_quark_from_static_string ("bayes-error"), 2,
+                               "no %s statfile could be learned", is_spam ? "spam" : "ham");
+               return FALSE;
+       }
+
+       return TRUE;
+}
+
 GList *
 bayes_weights (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task)
 {
index 6b0554e1b0cd3db3dd382db6a12d7f5464a6c3be..5e2b9ea882ffd8ed9c433cea581be011747f6118 100644 (file)
@@ -35,6 +35,7 @@ struct classifier               classifiers[] = {
                        .init_func = winnow_init,
                        .classify_func = winnow_classify,
                        .learn_func = winnow_learn,
+                       .learn_spam_func = winnow_learn_spam,
                        .weights_func = winnow_weights
                },
                {
@@ -42,6 +43,7 @@ struct classifier               classifiers[] = {
                        .init_func = bayes_init,
                        .classify_func = bayes_classify,
                        .learn_func = bayes_learn,
+                       .learn_spam_func = bayes_learn_spam,
                        .weights_func = bayes_weights
                }
 };
index 601db0205f853d77d3ae8124247c4dc2395410fe..78ceb196ebcf90aeea933d4b509446e30a5f6fd6 100644 (file)
@@ -32,6 +32,8 @@ struct classifier {
        gboolean (*learn_func)(struct classifier_ctx* ctx, statfile_pool_t *pool,
                                                        const char *symbol, GTree *input, gboolean in_class,
                                                        double *sum, double multiplier, GError **err);
+       gboolean (*learn_spam_func)(struct classifier_ctx* ctx, statfile_pool_t *pool,
+                                                               GTree *input, struct worker_task *task, gboolean is_spam, GError **err);
        GList* (*weights_func)(struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task);
 };
 
@@ -43,6 +45,8 @@ struct classifier_ctx* winnow_init (memory_pool_t *pool, struct classifier_confi
 gboolean winnow_classify (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task);
 gboolean winnow_learn (struct classifier_ctx* ctx, statfile_pool_t *pool, const char *symbol, GTree *input,
                                gboolean in_class, double *sum, double multiplier, GError **err);
+gboolean winnow_learn_spam (struct classifier_ctx* ctx, statfile_pool_t *pool,
+                       GTree *input, struct worker_task *task, gboolean is_spam, GError **err);
 GList *winnow_weights (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task);
 
 /* Bayes algorithm */
@@ -50,6 +54,8 @@ struct classifier_ctx* bayes_init (memory_pool_t *pool, struct classifier_config
 gboolean bayes_classify (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task);
 gboolean bayes_learn (struct classifier_ctx* ctx, statfile_pool_t *pool, const char *symbol, GTree *input,
                                gboolean in_class, double *sum, double multiplier, GError **err);
+gboolean bayes_learn_spam (struct classifier_ctx* ctx, statfile_pool_t *pool,
+                       GTree *input, struct worker_task *task, gboolean is_spam, GError **err);
 GList *bayes_weights (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task);
 /* Array of all defined classifiers */
 extern struct classifier classifiers[];
index 2e8b9842347bf04aaa5f6cb9f26c848a06d96f99..b123ce3e5f14dccc1a8404748bbc29dda419c13e 100644 (file)
@@ -223,19 +223,14 @@ winnow_classify (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * inp
                }
        }
 
-    if (ctx->cfg->pre_callbacks) {
-#ifdef WITH_LUA
-        cur = call_classifier_pre_callbacks (ctx->cfg, task);
-        if (cur) {
-            memory_pool_add_destructor (task->task_pool, (pool_destruct_func)g_list_free, cur);
-        }
-#else
-           cur = ctx->cfg->statfiles;
-#endif
-    }
-    else {
-           cur = ctx->cfg->statfiles;
-    }
+       cur = call_classifier_pre_callbacks (ctx->cfg, task, FALSE, FALSE);
+       if (cur) {
+               memory_pool_add_destructor (task->task_pool, (pool_destruct_func)g_list_free, cur);
+       }
+       else {
+               cur = ctx->cfg->statfiles;
+       }
+
        while (cur) {
                st = cur->data;
                data.sum = 0;
@@ -597,3 +592,15 @@ end:
        }
        return TRUE;
 }
+
+/*
+ * Spam/ham learning entry point for the winnow classifier.
+ * Not implemented: always sets `err` (domain winnow_error_quark (), code 1)
+ * and returns FALSE. All other arguments are ignored.
+ */
+gboolean
+winnow_learn_spam (struct classifier_ctx* ctx, statfile_pool_t *pool,
+               GTree *input, struct worker_task *task, gboolean is_spam, GError **err)
+{
+       g_set_error (err, winnow_error_quark(), 1, "learn spam is not supported for winnow");
+
+       return FALSE;
+}
index e7e4c10d86b020b01ea943073ad0919d59dddfe4..ce33c53a27c53eb759e13523c4a76da005cdf1f1 100644 (file)
@@ -35,6 +35,7 @@
 #include "classifiers/classifiers.h"
 #include "binlog.h"
 #include "statfile_sync.h"
+#include "lua/lua_common.h"
 
 #define END "END" CRLF
 
@@ -49,6 +50,8 @@ enum command_type {
        COMMAND_SHUTDOWN,
        COMMAND_UPTIME,
        COMMAND_LEARN,
+       COMMAND_LEARN_SPAM,
+       COMMAND_LEARN_HAM,
        COMMAND_HELP,
        COMMAND_COUNTERS,
        COMMAND_SYNC,
@@ -84,7 +87,9 @@ static struct controller_command commands[] = {
        {"weights", FALSE, COMMAND_WEIGHTS},
        {"help", FALSE, COMMAND_HELP},
        {"counters", FALSE, COMMAND_COUNTERS},
-       {"sync", FALSE, COMMAND_SYNC}
+       {"sync", FALSE, COMMAND_SYNC},
+       {"learn_spam", TRUE, COMMAND_LEARN_SPAM},
+       {"learn_ham", TRUE, COMMAND_LEARN_HAM}
 };
 
 static GList                   *custom_commands = NULL;
@@ -519,7 +524,7 @@ process_command (struct controller_command *cmd, gchar **cmd_args, struct contro
                        }
                }
                break;
-       case COMMAND_LEARN:
+       case COMMAND_LEARN_SPAM:
                if (check_auth (cmd, session)) {
                        arg = *cmd_args;
                        if (!arg || *arg == '\0') {
@@ -532,13 +537,105 @@ process_command (struct controller_command *cmd, gchar **cmd_args, struct contro
                        }
                        arg = *(cmd_args + 1);
                        if (arg == NULL || *arg == '\0') {
-                               msg_debug ("no statfile size specified in learn command");
+                               msg_debug ("no message size specified in learn command");
+                               r = rspamd_snprintf (out_buf, sizeof (out_buf), "learn command requires at least two arguments: symbol and message size" CRLF);
+                               if (! rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE)) {
+                                       return FALSE;
+                               }
+                               return TRUE;
+                       }
+                       size = strtoul (arg, &err_str, 10);
+                       if (err_str && *err_str != '\0') {
+                               msg_debug ("message size is invalid: %s", arg);
+                               r = rspamd_snprintf (out_buf, sizeof (out_buf), "learn size is invalid" CRLF);
+                               if (! rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE)) {
+                                       return FALSE;
+                               }
+                               return TRUE;
+                       }
+                       cl = find_classifier_conf (session->cfg, *cmd_args);
+                       if (cl == NULL) {
+                               r = rspamd_snprintf (out_buf, sizeof (out_buf), "classifier %s is not defined" CRLF, *cmd_args);
+                               if (! rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE)) {
+                                       return FALSE;
+                               }
+                               return TRUE;
+
+                       }
+                       session->learn_classifier = cl;
+
+                       /* By default learn positive */
+                       session->in_class = TRUE;
+                       rspamd_set_dispatcher_policy (session->dispatcher, BUFFER_CHARACTER, size);
+                       session->state = STATE_LEARN_SPAM;
+               }
+               break;
+       case COMMAND_LEARN_HAM:
+               if (check_auth (cmd, session)) {
+                       arg = *cmd_args;
+                       if (!arg || *arg == '\0') {
+                               msg_debug ("no statfile specified in learn command");
                                r = rspamd_snprintf (out_buf, sizeof (out_buf), "learn command requires at least two arguments: stat filename and its size" CRLF);
                                if (! rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE)) {
                                        return FALSE;
                                }
                                return TRUE;
                        }
+                       arg = *(cmd_args + 1);
+                       if (arg == NULL || *arg == '\0') {
+                               msg_debug ("no message size specified in learn command");
+                               r = rspamd_snprintf (out_buf, sizeof (out_buf), "learn command requires at least two arguments: symbol and message size" CRLF);
+                               if (! rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE)) {
+                                       return FALSE;
+                               }
+                               return TRUE;
+                       }
+                       size = strtoul (arg, &err_str, 10);
+                       if (err_str && *err_str != '\0') {
+                               msg_debug ("message size is invalid: %s", arg);
+                               r = rspamd_snprintf (out_buf, sizeof (out_buf), "learn size is invalid" CRLF);
+                               if (! rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE)) {
+                                       return FALSE;
+                               }
+                               return TRUE;
+                       }
+                       cl = find_classifier_conf (session->cfg, *cmd_args);
+                       if (cl == NULL) {
+                               r = rspamd_snprintf (out_buf, sizeof (out_buf), "classifier %s is not defined" CRLF, *cmd_args);
+                               if (! rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE)) {
+                                       return FALSE;
+                               }
+                               return TRUE;
+
+                       }
+                       session->learn_classifier = cl;
+
+                       /* By default learn positive */
+                       session->in_class = FALSE;
+                       rspamd_set_dispatcher_policy (session->dispatcher, BUFFER_CHARACTER, size);
+                       session->state = STATE_LEARN_SPAM;
+               }
+               break;
+       case COMMAND_LEARN:
+               if (check_auth (cmd, session)) {
+                       arg = *cmd_args;
+                       if (!arg || *arg == '\0') {
+                               msg_debug ("no statfile specified in learn command");
+                               r = rspamd_snprintf (out_buf, sizeof (out_buf), "learn command requires at least two arguments: stat filename and its size" CRLF);
+                               if (! rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE)) {
+                                       return FALSE;
+                               }
+                               return TRUE;
+                       }
+                       arg = *(cmd_args + 1);
+                       if (arg == NULL || *arg == '\0') {
+                               msg_debug ("no message size specified in learn command");
+                               r = rspamd_snprintf (out_buf, sizeof (out_buf), "learn command requires at least two arguments: symbol and message size" CRLF);
+                               if (! rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE)) {
+                                       return FALSE;
+                               }
+                               return TRUE;
+                       }
                        size = strtoul (arg, &err_str, 10);
                        if (err_str && *err_str != '\0') {
                                msg_debug ("message size is invalid: %s", arg);
@@ -817,6 +914,65 @@ controller_read_socket (f_str_t * in, void *arg)
                }
                session->state = STATE_REPLY;
                break;
+       case STATE_LEARN_SPAM_PRE:
+               session->learn_buf = in;
+               task = construct_task (session->worker);
+
+               task->msg = memory_pool_alloc (task->task_pool, sizeof (f_str_t));
+               task->msg->begin = in->begin;
+               task->msg->len = in->len;
+
+
+               r = process_message (task);
+               if (r == -1) {
+                       msg_warn ("processing of message failed");
+                       free_task (task, FALSE);
+                       session->state = STATE_REPLY;
+                       r = rspamd_snprintf (out_buf, sizeof (out_buf), "cannot process message" CRLF);
+                       if (! rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE)) {
+                               return FALSE;
+                       }
+                       return FALSE;
+               }
+
+               r = process_filters (task);
+               if (r == -1) {
+                       session->state = STATE_REPLY;
+                       r = rspamd_snprintf (out_buf, sizeof (out_buf), "cannot process message" CRLF);
+                       free_task (task, FALSE);
+                       if (! rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE)) {
+                               return FALSE;
+                       }
+               }
+               else if (r == 0) {
+                       session->state = STATE_LEARN;
+                       task->dispatcher = session->dispatcher;
+                       session->learn_task = task;
+                       rspamd_dispatcher_pause (session->dispatcher);
+               }
+               else {
+                       lua_call_post_filters (task);
+                       session->state = STATE_REPLY;
+
+                       if (! learn_task_spam (session->learn_classifier, task, session->in_class, &err)) {
+                               if (err) {
+                                       i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn failed, learn classifier error: %s" CRLF END, err->message);
+                                       g_error_free (err);
+                               }
+                               else {
+                                       i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn failed, unknown learn classifier error" CRLF END);
+                               }
+                       }
+                       else {
+                               i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn ok" CRLF END);
+                       }
+
+                       free_task (task, FALSE);
+                       if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) {
+                               return FALSE;
+                       }
+               }
+               break;
        case STATE_WEIGHTS:
                session->learn_buf = in;
                task = construct_task (session->worker);
@@ -926,12 +1082,37 @@ static                          gboolean
 controller_write_socket (void *arg)
 {
        struct controller_session      *session = (struct controller_session *)arg;
+       gint                            i;
+       gchar                           out_buf[64];
+       GError                         *err = NULL;
 
        if (session->state == STATE_QUIT) {
                /* Free buffers */
                destroy_session (session->s);
                return FALSE;
        }
+       else if (session->state == STATE_LEARN) {
+               /* Perform actual learn here */
+               if (! learn_task_spam (session->learn_classifier, session->learn_task, session->in_class, &err)) {
+                       if (err) {
+                               i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn failed, learn classifier error: %s" CRLF END, err->message);
+                               g_error_free (err);
+                       }
+                       else {
+                               i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn failed, unknown learn classifier error" CRLF END);
+                       }
+               }
+               else {
+                       i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn ok" CRLF END);
+               }
+               learn_task_spam (session->learn_classifier, session->learn_task, session->in_class, &err);
+               session->learn_task->dispatcher = NULL;
+               free_task (session->learn_task, FALSE);
+               session->state = STATE_REPLY;
+               if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) {
+                       return FALSE;
+               }
+       }
        else if (session->state == STATE_REPLY) {
                session->state = STATE_COMMAND;
                rspamd_set_dispatcher_policy (session->dispatcher, BUFFER_LINE, BUFSIZ);
index fea91125fee695b9f1c1528f3c5035596532ea11..66f2331150cf5a5628d2eec491ffe86da4f04230 100644 (file)
@@ -944,6 +944,89 @@ learn_task (const gchar *statfile, struct worker_task *task, GError **err)
        return TRUE;
 }
 
+/*
+ * Learn classifier `cl' with the text of the message held in `task'.
+ *
+ * Tokens are gathered from every non-empty text part and from the subject,
+ * then handed to the classifier's learn_spam_func. For a message with exactly
+ * two text parts the last one is skipped when fuzzy_compare_parts () reports
+ * it as common with the first one.
+ *
+ * @param cl classifier configuration to learn
+ * @param task worker's task object
+ * @param is_spam class flag passed through to the classifier's learn_spam_func
+ * @param err pointer to GError, set on failure; may be NULL
+ * @return TRUE if learn succeeded, FALSE otherwise
+ */
+gboolean
+learn_task_spam (struct classifier_config *cl, struct worker_task *task, gboolean is_spam, GError **err)
+{
+       GList                          *cur, *ex;
+       struct classifier_ctx          *cls_ctx;
+       f_str_t                         c;
+       GTree                          *tokens = NULL;
+       struct mime_text_part          *part;
+       gboolean                        is_utf = FALSE, is_twopart = FALSE;
+
+       cur = g_list_first (task->text_parts);
+       /* Exactly two text parts: the second one may duplicate the first */
+       if (cur != NULL && cur->next != NULL && cur->next->next == NULL) {
+               is_twopart = TRUE;
+       }
+
+       /* Get tokens from each element */
+       while (cur) {
+               part = cur->data;
+               /* Skip empty parts */
+               if (part->is_empty) {
+                       cur = g_list_next (cur);
+                       continue;
+               }
+               c.begin = part->content->data;
+               c.len = part->content->len;
+               is_utf = part->is_utf;
+               ex = part->urls_offset;
+               if (is_twopart && cur->next == NULL) {
+                       /* Compare part's content; cur is the second part here, so cur->prev is valid */
+                       if (fuzzy_compare_parts (cur->data, cur->prev->data) >= COMMON_PART_FACTOR) {
+                               msg_info ("message <%s> has two common text parts, ignore the last one", task->message_id);
+                               break;
+                       }
+               }
+               /* Get tokens */
+               if (!cl->tokenizer->tokenize_func (
+                               cl->tokenizer, task->task_pool,
+                               &c, &tokens, FALSE, is_utf, ex)) {
+                       g_set_error (err, filter_error_quark(), 2, "Cannot tokenize message");
+                       return FALSE;
+               }
+               cur = g_list_next (cur);
+       }
+
+       /* Handle messages without text */
+       if (tokens == NULL) {
+               g_set_error (err, filter_error_quark(), 3, "Cannot tokenize message, no text data");
+               msg_info ("learn failed for message <%s>, no tokens to extract", task->message_id);
+               return FALSE;
+       }
+
+       /* Take care of subject */
+       tokenize_subject (task, &tokens);
+
+       /* Init classifier */
+       cls_ctx = cl->classifier->init_func (
+                       task->task_pool, cl);
+       /* Learn */
+       if (!cl->classifier->learn_spam_func (
+                       cls_ctx, task->worker->srv->statfile_pool,
+                       tokens, task, is_spam, err)) {
+               /* err itself may be NULL (caller not interested), so check it before dereferencing */
+               if (err != NULL && *err != NULL) {
+                       msg_info ("learn failed for message <%s>, learn error: %s", task->message_id, (*err)->message);
+               }
+               else {
+                       g_set_error (err, filter_error_quark(), 4, "Learn failed, unknown learn classifier error");
+                       msg_info ("learn failed for message <%s>, unknown learn error", task->message_id);
+               }
+               return FALSE;
+       }
+       /* Increase statistics */
+       task->worker->srv->stat->messages_learned++;
+
+       msg_info ("learn success for message <%s>",
+                       task->message_id);
+       statfile_pool_plan_invalidate (task->worker->srv->statfile_pool,
+                       DEFAULT_STATFILE_INVALIDATE_TIME,
+                       DEFAULT_STATFILE_INVALIDATE_JITTER);
+
+       return TRUE;
+}
+
 /* 
  * vi:ts=4 
  */
index cfcab7cb55e6883b88089e3ae4ca1754403f96c1..9305941706b8007bb859246d24d9cee459f7299d 100644 (file)
@@ -11,6 +11,7 @@
 
 struct worker_task;
 struct rspamd_settings;
+struct classifier_config;
 
 typedef double (*metric_cons_func)(struct worker_task *task, const gchar *metric_name, const gchar *func_name);
 typedef void (*filter_func)(struct worker_task *task);
@@ -136,6 +137,15 @@ double factor_consolidation_func (struct worker_task *task, const gchar *metric_
  */
 gboolean learn_task (const gchar *statfile, struct worker_task *task, GError **err);
 
+/*
+ * Learn specified classifier with message in a task
+ * @param cl classifier configuration to learn
+ * @param task worker's task object
+ * @param is_spam class of the message (presumably TRUE to learn it as spam, FALSE as ham)
+ * @param err pointer to GError
+ * @return true if learn succeed
+ */
+gboolean learn_task_spam (struct classifier_config *cl, struct worker_task *task, gboolean is_spam, GError **err);
+
 gboolean check_action_str (const gchar *data, gint *result);
 const gchar *str_action_metric (enum rspamd_metric_action action);
 gint check_metric_action (double score, double required_score, struct metric *metric);
index 306e1e70059e4b592ab58f53935ca2949a2f27f1..9a65416104ad5b8be5e9415bd92509bee47c6e02 100644 (file)
@@ -655,29 +655,6 @@ decode_entitles (gchar *s, guint * len)
        }
 }
 
-/*
- * Find the first occurrence of find in s, ignore case.
- */
-static gchar *
-html_strncasestr (const gchar *s, const gchar *find, gsize len)
-{
-       gchar                           c, sc;
-       gsize                           mlen;
-
-       if ((c = *find++) != 0) {
-               c = g_ascii_tolower (c);
-               mlen = strlen (find);
-               do {
-                       do {
-                               if ((sc = *s++) == 0 || len -- == 0)
-                                       return (NULL);
-                       } while (g_ascii_tolower (sc) != c);
-               } while (g_ascii_strncasecmp (s, find, mlen) != 0);
-               s--;
-       }
-       return ((gchar *)s);
-}
-
 static void
 check_phishing (struct worker_task *task, struct uri *href_url, const gchar *url_text, gsize remain, tag_id_t id)
 {
@@ -803,11 +780,11 @@ parse_tag_url (struct worker_task *task, struct mime_text_part *part, tag_id_t i
 
        /* For A tags search for href= and for IMG tags search for src= */
        if (id == Tag_A) {
-               c = html_strncasestr (tag_text, "href=", tag_len);
+               c = rspamd_strncasestr (tag_text, "href=", tag_len);
                len = sizeof ("href=") - 1;
        }
        else if (id == Tag_IMG) {
-               c = html_strncasestr (tag_text, "src=", tag_len);
+               c = rspamd_strncasestr (tag_text, "src=", tag_len);
                len = sizeof ("src=") - 1;
        }
 
index f7ce173a71d8b0f2d079fd440ea1f61ca9d1e167..511767680dcee1637675ba99fcc34c25d56817c6 100644 (file)
@@ -41,7 +41,18 @@ static const struct luaL_reg    classifierlib_m[] = {
 };
 
 
+/*
+ * Declarations of the statfile method handlers (presumably expanding to the
+ * lua_statfile_* functions defined later in this file -- confirm against
+ * the LUA_FUNCTION_DEF macro)
+ */
+LUA_FUNCTION_DEF (statfile, get_symbol);
+LUA_FUNCTION_DEF (statfile, get_path);
+LUA_FUNCTION_DEF (statfile, get_size);
+LUA_FUNCTION_DEF (statfile, is_spam);
+LUA_FUNCTION_DEF (statfile, get_param);
+
+/* Method table for statfile objects, terminated by the {NULL, NULL} sentinel */
 static const struct luaL_reg    statfilelib_m[] = {
+       LUA_INTERFACE_DEF (statfile, get_symbol),
+       LUA_INTERFACE_DEF (statfile, get_path),
+       LUA_INTERFACE_DEF (statfile, get_size),
+       LUA_INTERFACE_DEF (statfile, is_spam),
+       LUA_INTERFACE_DEF (statfile, get_param),
        {"__tostring", lua_class_tostring},
        {NULL, NULL}
 };
@@ -64,50 +75,81 @@ lua_check_classifier (lua_State * L)
        return *((struct classifier_config **)ud);
 }
 
-/* Return list of statfiles that should be checked for this message */
-GList *
-call_classifier_pre_callbacks (struct classifier_config *ccf, struct worker_task *task)
+/*
+ * Call a single pre-classifier callback and collect the statfiles it returns.
+ * The Lua function to call must already be on top of the stack of L; the
+ * classifier, the task and the is_learn / is_spam flags are pushed as its
+ * four arguments and one result is requested.
+ * NOTE(review): the result is inspected at absolute stack index 1, which is
+ * the callback's return value only if the stack was empty before the
+ * function was pushed -- confirm against callers.
+ * @return list of statfiles obtained through lua_check_statfile () for the
+ * entries of the returned table, NULL if the call failed or gave no table
+ */
+static GList *
+call_classifier_pre_callback (struct classifier_config *ccf, struct worker_task *task,
+               lua_State *L, gboolean is_learn, gboolean is_spam)
 {
-       GList                          *res = NULL, *cur;
-       struct classifier_callback_data *cd;
        struct classifier_config      **pccf;
        struct worker_task            **ptask;
        struct statfile                *st;
        gint                            i, len;
+       GList                          *res = NULL;
 
-       /* Go throught all callbacks and call them, appending results to list */
-       cur = g_list_first (ccf->pre_callbacks);
-       while (cur) {
-               cd = cur->data;
-               lua_getglobal (cd->L, cd->name);
+       pccf = lua_newuserdata (L, sizeof (struct classifier_config *));
+       lua_setclass (L, "rspamd{classifier}", -1);
+       *pccf = ccf;
 
-               pccf = lua_newuserdata (cd->L, sizeof (struct classifier_config *));
-               lua_setclass (cd->L, "rspamd{classifier}", -1);
-               *pccf = ccf;
+       ptask = lua_newuserdata (L, sizeof (struct worker_task *));
+       lua_setclass (L, "rspamd{task}", -1);
+       *ptask = task;
 
-               ptask = lua_newuserdata (cd->L, sizeof (struct worker_task *));
-               lua_setclass (cd->L, "rspamd{task}", -1);
-               *ptask = task;
+       lua_pushboolean (L, is_learn);
+       lua_pushboolean (L, is_spam);
 
-               if (lua_pcall (cd->L, 2, 1, 0) != 0) {
-                       msg_warn ("error running function %s: %s", cd->name, lua_tostring (cd->L, -1));
-               }
-               else {
-                       if (lua_istable (cd->L, 1)) {
-                               len = lua_objlen (cd->L, 1);
-                               for (i = 1; i <= len; i ++) {
-                                       lua_rawgeti (cd->L, 1, i);
-                                       st = lua_check_statfile (cd->L);
-                                       if (st) {
-                                               res = g_list_prepend (res, st);
-                                       }
+       if (lua_pcall (L, 4, 1, 0) != 0) {
+               msg_warn ("error running pre classifier callback %s", lua_tostring (L, -1));
+       }
+       else {
+               if (lua_istable (L, 1)) {
+                       len = lua_objlen (L, 1);
+                       for (i = 1; i <= len; i ++) {
+                               lua_rawgeti (L, 1, i);
+                               st = lua_check_statfile (L);
+                               if (st) {
+                                       res = g_list_prepend (res, st);
                                 }
                         }
                 }
+       }
+
+       return res;
+}
+
+/*
+ * Return list of statfiles that should be checked for this message.
+ *
+ * Every callback registered in ccf->pre_callbacks is called and the results
+ * are concatenated. If that yields nothing, the function stored under the
+ * classifier's name in the global Lua table 'classifiers' is tried instead.
+ * The Lua stack is restored after each call so that leftovers from one
+ * callback cannot be mistaken for the result of the next one.
+ * NOTE(review): call_classifier_pre_callback () reads its result from
+ * absolute stack index 1, so the Lua stack is assumed to be empty on entry.
+ *
+ * @param ccf classifier configuration
+ * @param task worker's task object
+ * @param is_learn passed to the callbacks as their third argument
+ * @param is_spam passed to the callbacks as their fourth argument
+ * @return list of statfiles, NULL if no callback returned any
+ */
+GList *
+call_classifier_pre_callbacks (struct classifier_config *ccf, struct worker_task *task,
+               gboolean is_learn, gboolean is_spam)
+{
+       GList                           *res = NULL, *cur;
+       struct classifier_callback_data *cd;
+       lua_State                       *L;
+       gint                             top;
+
+
        /* Go throught all callbacks and call them, appending results to list */
        cur = g_list_first (ccf->pre_callbacks);
        while (cur) {
                cd = cur->data;
+               /* Remember the stack level to drop the callback's leftovers afterwards */
+               top = lua_gettop (cd->L);
                lua_getglobal (cd->L, cd->name);
 
+               res = g_list_concat (res, call_classifier_pre_callback (ccf, task, cd->L, is_learn, is_spam));
+               lua_settop (cd->L, top);
 
                cur = g_list_next (cur);
        }
        
+       if (res == NULL) {
+               L = task->cfg->lua_state;
+               top = lua_gettop (L);
+               /* Check function from global table 'classifiers' */
+               lua_getglobal (L, "classifiers");
+               if (lua_istable (L, -1)) {
+                       lua_pushstring (L, ccf->classifier->name);
+                       lua_gettable (L, -2);
+                       /* The looked up value is now on top, just above the table */
+                       if (lua_isfunction (L, -1)) {
+                               /* Drop the table so that only the function is left for the call */
+                               lua_remove (L, -2);
+                               res = call_classifier_pre_callback (ccf, task, L, is_learn, is_spam);
+                       }
+               }
+               /* Discard whatever the lookup and the callback left on the stack */
+               lua_settop (L, top);
+       }
+
        return res;
 }
 
@@ -226,6 +268,86 @@ lua_classifier_get_statfiles (lua_State *L)
 }
 
 /* Statfile functions */
+/*
+ * Statfile method handler: push the statfile's symbol as a string, or nil
+ * when lua_check_statfile () gives no statfile. One value is always returned.
+ */
+static gint
+lua_statfile_get_symbol (lua_State *L)
+{
+       struct statfile                *sf;
+
+       sf = lua_check_statfile (L);
+       if (sf == NULL) {
+               lua_pushnil (L);
+               return 1;
+       }
+
+       lua_pushstring (L, sf->symbol);
+       return 1;
+}
+
+/*
+ * Statfile method handler: push the statfile's path as a string, or nil
+ * when lua_check_statfile () gives no statfile. One value is always returned.
+ */
+static gint
+lua_statfile_get_path (lua_State *L)
+{
+       struct statfile                *sf;
+
+       sf = lua_check_statfile (L);
+       if (sf == NULL) {
+               lua_pushnil (L);
+               return 1;
+       }
+
+       lua_pushstring (L, sf->path);
+       return 1;
+}
+
+/*
+ * Statfile method handler: push the statfile's size as an integer, or nil
+ * when lua_check_statfile () gives no statfile. One value is always returned.
+ */
+static gint
+lua_statfile_get_size (lua_State *L)
+{
+       struct statfile                *sf;
+
+       sf = lua_check_statfile (L);
+       if (sf == NULL) {
+               lua_pushnil (L);
+               return 1;
+       }
+
+       lua_pushinteger (L, sf->size);
+       return 1;
+}
+
+/*
+ * Statfile method handler: push the statfile's is_spam flag as a boolean, or
+ * nil when lua_check_statfile () gives no statfile. One value is always
+ * returned.
+ */
+static gint
+lua_statfile_is_spam (lua_State *L)
+{
+       struct statfile                *sf;
+
+       sf = lua_check_statfile (L);
+       if (sf == NULL) {
+               lua_pushnil (L);
+               return 1;
+       }
+
+       lua_pushboolean (L, sf->is_spam);
+       return 1;
+}
+
+/*
+ * Statfile method handler: look up the option named by the second argument
+ * in the statfile's opts table and push its value as a string. Pushes nil
+ * when there is no statfile or no such option. One value is always returned.
+ */
+static gint
+lua_statfile_get_param (lua_State *L)
+{
+       struct statfile                *st = lua_check_statfile (L);
+       const gchar                    *param, *value;
+
+       param = luaL_checkstring (L, 2);
+
+       if (st != NULL && param != NULL) {
+               value = g_hash_table_lookup (st->opts, param);
+               /* Test the lookup result, not the key: a missing option must yield nil */
+               if (value != NULL) {
+                       lua_pushstring (L, value);
+                       return 1;
+               }
+       }
+       lua_pushnil (L);
+
+       return 1;
+}
+
 static struct statfile      *
 lua_check_statfile (lua_State * L)
 {
index 4427dcae964d4bfc6562ff97298265be2cabec49..3f3fc7a1daf7bc5cae8e5d3ca3012e5322206d23 100644 (file)
@@ -51,7 +51,7 @@ void lua_call_post_filters (struct worker_task *task);
 void add_luabuf (const gchar *line);
 
 /* Classify functions */
-GList *call_classifier_pre_callbacks (struct classifier_config *ccf, struct worker_task *task);
+GList *call_classifier_pre_callbacks (struct classifier_config *ccf, struct worker_task *task, gboolean is_learn, gboolean is_spam);
 double call_classifier_post_callbacks (struct classifier_config *ccf, struct worker_task *task, double in);
 
 double lua_normalizer_func (struct config_file *cfg, long double score, void *params);
index 9a3335d0ab47309c65426a65a1046c7587bafd40..d8761617f62c69e641033f986e6d8ceed1f89161 100644 (file)
@@ -64,6 +64,7 @@ struct classifier_config;
 struct mime_part;
 struct rspamd_view;
 struct rspamd_dns_resolver;
+struct worker_task;
 
 /** 
  * Server statistics
@@ -138,6 +139,8 @@ struct controller_session {
        enum {
                STATE_COMMAND,
                STATE_LEARN,
+               STATE_LEARN_SPAM_PRE,
+               STATE_LEARN_SPAM,
                STATE_REPLY,
                STATE_QUIT,
                STATE_OTHER,
@@ -162,6 +165,7 @@ struct controller_session {
                        f_str_t *in);                                   /**< other command handler to execute at the end of processing */
        void *other_data;                                                                                       /**< and its data                                                                       */
        struct rspamd_async_session* s;                                                         /**< async session object                                                       */
+       struct worker_task *learn_task;
 };
 
 typedef void (*controller_func_t)(gchar **args, struct controller_session *session);
index 6d8cb09e09330c050db6ef1184ce488409971a2d..16ed43786b676d3bcfadebabda9a0f73980f2368 100644 (file)
@@ -1363,6 +1363,29 @@ escape_braces_addr_fstr (memory_pool_t *pool, f_str_t *in)
        return res;
 }
 
+/*
+ * Find the first occurrence of find in the first len characters of s,
+ * ignore case.
+ *
+ * The search stops at the terminating NUL of s or after len characters,
+ * whichever comes first, and a match is reported only if it lies entirely
+ * within that region: no character beyond s[len - 1] is read.
+ *
+ * @param s string to search in
+ * @param find NUL-terminated string to look for
+ * @param len maximum number of characters of s to examine; a value <= 0
+ *            leaves nothing to search
+ * @return pointer to the start of the match inside s, s itself if find is
+ *         empty, NULL if there is no match
+ */
+gchar *
+rspamd_strncasestr (const gchar *s, const gchar *find, gint len)
+{
+       gsize                           flen, remain;
+
+       flen = strlen (find);
+       if (flen == 0) {
+               /* An empty needle matches at the very beginning */
+               return ((gchar *)s);
+       }
+       if (len <= 0) {
+               return (NULL);
+       }
+
+       remain = (gsize)len;
+       /* A candidate position needs at least flen characters left in the region */
+       while (remain >= flen && *s != '\0') {
+               if (g_ascii_strncasecmp (s, find, flen) == 0) {
+                       return ((gchar *)s);
+               }
+               s++;
+               remain--;
+       }
+
+       return (NULL);
+}
+
 /*
  * vi:ts=4
  */
index 73d96c7595c6c84bb9cd65be14c60f1fe7cccf14..b1d4ad2e9ace5b9e0ac2edaa1a77f4d2675372ad 100644 (file)
@@ -158,4 +158,7 @@ void free_task (struct worker_task *task, gboolean is_soft);
 void free_task_hard (gpointer ud);
 void free_task_soft (gpointer ud);
 
+/*
+ * Find string find in string s ignoring case
+ * @param s string to search in
+ * @param find NUL-terminated string to look for
+ * @param len limit on how far into s the search goes
+ * @return pointer into s at the first match, or NULL if there is none
+ */
+gchar* rspamd_strncasestr (const gchar *s, const gchar *find, gint len);
+
 #endif