source.dussan.org Git - rspamd.git/commitdiff
* Add post filters to lua API - filters that would be called after all message's...
author Vsevolod Stakhov <vsevolod@rambler-co.ru>
Wed, 25 Aug 2010 15:38:18 +0000 (19:38 +0400)
committer Vsevolod Stakhov <vsevolod@rambler-co.ru>
Wed, 25 Aug 2010 15:38:18 +0000 (19:38 +0400)
* Add ability to check for specified symbol in task results from lua
* Add ability to check for metric's results from lua
* Add ability to learn specified statfile from lua

src/cfg_file.h
src/filter.c
src/lua/lua_common.h
src/lua/lua_config.c
src/lua/lua_task.c
src/main.h
src/protocol.c
src/protocol.h
src/worker.c

index b2ca611500d1efca41e4e9bd11506fb8f5fd9a0d..d770c237e8c9d704aa994e9c5f4f876afe3998e1 100644 (file)
@@ -289,6 +289,7 @@ struct config_file {
        GHashTable *classifiers_symbols;                /**< hashtable indexed by symbol name of classifiers    */
        GHashTable* cfg_params;                                                 /**< all cfg params indexed by its name in this structure */
        GList *views;                                                                   /**< views                                                                                              */
+       GList *post_filters;                                                    /**< list of post-processing lua filters                                */
        GHashTable* domain_settings;                    /**< settings per-domains                               */
        GHashTable* user_settings;                      /**< settings per-user                                  */
        gchar* domain_settings_str;                                             /**< string representation of settings                                  */
index 90566ded90ba6692e48000ea68e4c2d0f2f16101..7d1de3d208bd18b8472bd1f7544b721ef3dd04b9 100644 (file)
@@ -243,6 +243,9 @@ continue_process_filters (struct worker_task *task)
 end:
        /* Process all statfiles */
        process_statfiles (task);
+       /* Call post filters */
+       lua_call_post_filters (task);
+       task->state = WRITE_REPLY;
        /* XXX: ugly direct call */
        if (task->fin_callback) {
                task->fin_callback (task->fin_arg);
@@ -293,6 +296,7 @@ process_filters (struct worker_task *task)
                        else if (!task->pass_all_filters && 
                                                metric->action == METRIC_ACTION_REJECT && 
                                                check_metric_is_spam (task, metric)) {
+                               task->state = WRITE_REPLY;
                                return 1;
                        }
                        cur = g_list_next (cur);
@@ -463,17 +467,10 @@ make_composites (struct worker_task *task)
        g_hash_table_foreach (task->results, composites_metric_callback, task);
 }
 
-
-struct statfile_callback_data {
-       GHashTable                     *tokens;
-       struct worker_task             *task;
-};
-
 static void
 classifiers_callback (gpointer value, void *arg)
 {
-       struct statfile_callback_data  *data = (struct statfile_callback_data *)arg;
-       struct worker_task             *task = data->task;
+       struct worker_task             *task = arg;
        struct classifier_config       *cl = value;
        struct classifier_ctx          *ctx;
        struct mime_text_part          *text_part;
@@ -494,7 +491,7 @@ classifiers_callback (gpointer value, void *arg)
        }
        ctx = cl->classifier->init_func (task->task_pool, cl);
 
-       if ((tokens = g_hash_table_lookup (data->tokens, cl->tokenizer)) == NULL) {
+       if ((tokens = g_hash_table_lookup (task->tokens, cl->tokenizer)) == NULL) {
                while (cur != NULL) {
                        if (header) {
                                c.len = strlen (cur->data);
@@ -522,7 +519,7 @@ classifiers_callback (gpointer value, void *arg)
                        }
                        cur = g_list_next (cur);
                }
-               g_hash_table_insert (data->tokens, cl->tokenizer, tokens);
+               g_hash_table_insert (task->tokens, cl->tokenizer, tokens);
        }
 
        if (tokens == NULL) {
@@ -549,20 +546,20 @@ classifiers_callback (gpointer value, void *arg)
 void
 process_statfiles (struct worker_task *task)
 {
-       struct statfile_callback_data   cd;
-       
+
        if (task->is_skipped) {
                return;
        }
-       cd.task = task;
-       cd.tokens = g_hash_table_new (g_direct_hash, g_direct_equal);
 
-       g_list_foreach (task->cfg->classifiers, classifiers_callback, &cd);
-       g_hash_table_destroy (cd.tokens);
+       if (task->tokens == NULL) {
+               task->tokens = g_hash_table_new (g_direct_hash, g_direct_equal);
+               memory_pool_add_destructor (task->task_pool, (pool_destruct_func)g_hash_table_destroy, task->tokens);
+       }
+
+       g_list_foreach (task->cfg->classifiers, classifiers_callback, task);
 
        /* Process results */
        make_composites (task);
-       task->state = WRITE_REPLY;
 }
 
 static void
index d1f5f9eb4fbc044b7518f6701a2f2debf1e50c1e..18fa6481d04bbd1d839480bf3272f10405147b8e 100644 (file)
@@ -39,6 +39,7 @@ int lua_call_filter (const char *function, struct worker_task *task);
 int lua_call_chain_filter (const char *function, struct worker_task *task, int *marks, unsigned int number);
 double lua_consolidation_func (struct worker_task *task, const char *metric_name, const char *function_name);
 gboolean lua_call_expression_func (const char *function, struct worker_task *task, GList *args, gboolean *res);
+void lua_call_post_filters (struct worker_task *task);
 void add_luabuf (const char *line);
 
 /* Classify functions */
index 5e0148ca7620eea69fdac5b08a4a1499b2cad869..7ae7d3e0ef36e66b4a44682e5e0cfad772b7ade3 100644 (file)
@@ -37,6 +37,7 @@ LUA_FUNCTION_DEF (config, add_radix_map);
 LUA_FUNCTION_DEF (config, add_hash_map);
 LUA_FUNCTION_DEF (config, get_classifier);
 LUA_FUNCTION_DEF (config, register_symbol);
+LUA_FUNCTION_DEF (config, register_post_filter);
 
 static const struct luaL_reg    configlib_m[] = {
        LUA_INTERFACE_DEF (config, get_module_opt),
@@ -46,6 +47,7 @@ static const struct luaL_reg    configlib_m[] = {
        LUA_INTERFACE_DEF (config, add_hash_map),
        LUA_INTERFACE_DEF (config, get_classifier),
        LUA_INTERFACE_DEF (config, register_symbol),
+       LUA_INTERFACE_DEF (config, register_post_filter),
        {"__tostring", lua_class_tostring},
        {NULL, NULL}
 };
@@ -291,7 +293,49 @@ lua_config_register_function (lua_State *L)
                        register_expression_function (name, lua_config_function_callback, cd);
                }
        }
-       return 0;
+       return 1;
+}
+
+/**
+ * Run every post filter registered in the task's config.
+ *
+ * Each entry of task->cfg->post_filters names a global Lua function. It is
+ * called with a single argument: the task wrapped in "rspamd{task}" userdata.
+ * An error raised by one filter is logged and does not stop the others.
+ *
+ * @param task task that is handed to the post filters
+ */
+void
+lua_call_post_filters (struct worker_task *task)
+{
+       struct lua_callback_data       *cd;
+       struct worker_task            **ptask;
+       GList                          *cur;
+
+       for (cur = task->cfg->post_filters; cur != NULL; cur = g_list_next (cur)) {
+               cd = cur->data;
+               lua_getglobal (cd->L, cd->name);
+               ptask = lua_newuserdata (cd->L, sizeof (struct worker_task *));
+               lua_setclass (cd->L, "rspamd{task}", -1);
+               *ptask = task;
+
+               /* No results are requested, so a successful call leaves the stack balanced */
+               if (lua_pcall (cd->L, 1, 0, 0) != 0) {
+                       msg_warn ("error running function %s: %s", cd->name, lua_tostring (cd->L, -1));
+                       /* Drop the error message, otherwise the stack grows with every failure */
+                       lua_pop (cd->L, 1);
+               }
+       }
+}
+
+/**
+ * Config method "register_post_filter": remember the name of a global Lua
+ * function that lua_call_post_filters () will later call.
+ *
+ * Lua arguments: the config object and the function name (string).
+ * New entries are prepended, so cfg->post_filters is kept in reverse
+ * registration order.
+ *
+ * @param L lua state
+ * @return 0, no values are returned to Lua
+ */
+static int
+lua_config_register_post_filter (lua_State *L)
+{
+       struct config_file             *cfg = lua_check_config (L);
+       const char                     *callback;
+       struct lua_callback_data       *cd;
+
+       if (cfg) {
+               /* luaL_checkstring raises a Lua error by itself if argument 2 is not a string */
+               callback = luaL_checkstring (L, 2);
+               cd = g_malloc (sizeof (*cd));
+               cd->name = g_strdup (callback);
+               cd->L = L;
+               cfg->post_filters = g_list_prepend (cfg->post_filters, cd);
+       }
+
+       /* Nothing has been pushed, so report zero results instead of leaking an argument */
+       return 0;
 }
 
 static int
index 32c7b41ef07b0b77fc7cbd991576c3bf9499afa5..b18cfc21f035544c0bb731a22974fc2e74444a90 100644 (file)
 #include "lua_common.h"
 #include "../message.h"
 #include "../expressions.h"
+#include "../protocol.h"
+#include "../filter.h"
 #include "../dns.h"
+#include "../util.h"
 #include "../images.h"
+#include "../cfg_file.h"
+#include "../statfile.h"
+#include "../tokenizers/tokenizers.h"
+#include "../classifiers/classifiers.h"
+#include "../binlog.h"
+#include "../statfile_sync.h"
+
+extern stat_file_t* get_statfile_by_symbol (statfile_pool_t *pool, struct classifier_config *ccf,
+               const char *symbol, struct statfile **st, gboolean try_create);
 
 /* Task methods */
 LUA_FUNCTION_DEF (task, get_message);
@@ -47,6 +59,10 @@ LUA_FUNCTION_DEF (task, get_from_ip_num);
 LUA_FUNCTION_DEF (task, get_client_ip_num);
 LUA_FUNCTION_DEF (task, get_helo);
 LUA_FUNCTION_DEF (task, get_images);
+LUA_FUNCTION_DEF (task, get_symbol);
+LUA_FUNCTION_DEF (task, get_metric_score);
+LUA_FUNCTION_DEF (task, get_metric_action);
+LUA_FUNCTION_DEF (task, learn_statfile);
 
 static const struct luaL_reg    tasklib_m[] = {
        LUA_INTERFACE_DEF (task, get_message),
@@ -66,6 +82,10 @@ static const struct luaL_reg    tasklib_m[] = {
        LUA_INTERFACE_DEF (task, get_client_ip_num),
        LUA_INTERFACE_DEF (task, get_helo),
        LUA_INTERFACE_DEF (task, get_images),
+       LUA_INTERFACE_DEF (task, get_symbol),
+       LUA_INTERFACE_DEF (task, get_metric_score),
+       LUA_INTERFACE_DEF (task, get_metric_action),
+       LUA_INTERFACE_DEF (task, learn_statfile),
        {"__tostring", lua_class_tostring},
        {NULL, NULL}
 };
@@ -580,6 +600,184 @@ lua_task_get_images (lua_State *L)
        return 1;
 }
 
+/**
+ * Push a table describing the result of one symbol inside one metric.
+ *
+ * The table has the fields "metric" (the metric's name) and "score". When
+ * the symbol carries options it also has "options", an array of strings.
+ *
+ * @param L lua state to push to
+ * @param task task whose results are examined
+ * @param metric metric in which the symbol is looked up
+ * @param symbol name of the symbol
+ * @return TRUE if the table was pushed; FALSE, with the stack untouched, if
+ *         the task has no result for the metric or the symbol is not in it
+ */
+G_INLINE_FUNC gboolean
+lua_push_symbol_result (lua_State *L, struct worker_task *task, struct metric *metric, const char *symbol)
+{
+       struct metric_result           *metric_res;
+       struct symbol                  *s;
+       int                             j;
+       GList                          *opt;
+
+       metric_res = g_hash_table_lookup (task->results, metric->name);
+       if (metric_res == NULL) {
+               return FALSE;
+       }
+
+       s = g_hash_table_lookup (metric_res->symbols, symbol);
+       if (s == NULL) {
+               return FALSE;
+       }
+
+       lua_newtable (L);
+       lua_pushstring (L, "metric");
+       lua_pushstring (L, metric->name);
+       lua_settable (L, -3);
+       lua_pushstring (L, "score");
+       lua_pushnumber (L, s->score);
+       lua_settable (L, -3);
+
+       if (s->options) {
+               lua_pushstring (L, "options");
+               lua_newtable (L);
+               /* Lua arrays are 1-based: starting at 1 lets # and ipairs see every option */
+               j = 1;
+               for (opt = s->options; opt != NULL; opt = g_list_next (opt)) {
+                       lua_pushstring (L, opt->data);
+                       lua_rawseti (L, -2, j++);
+               }
+               lua_settable (L, -3);
+       }
+
+       return TRUE;
+}
+
+/**
+ * Task method "get_symbol": look up a symbol in the task's results.
+ *
+ * Lua arguments: the task object and the symbol name (string).
+ * The metrics listed for the symbol in cfg->metrics_symbols are examined;
+ * if there is no such list, the default metric is used instead.
+ *
+ * Returns to Lua an array with one entry per metric in which the symbol was
+ * found (each entry is the table built by lua_push_symbol_result ()), or nil
+ * if it was found nowhere.
+ *
+ * @param L lua state
+ * @return 1, the number of values returned to Lua
+ */
+static int
+lua_task_get_symbol (lua_State *L)
+{
+       struct worker_task             *task = lua_check_task (L);
+       const char                     *symbol;
+       struct metric                  *metric = NULL;
+       GList                          *cur, *metric_list;
+       gboolean                        found = FALSE;
+       int                             i = 1;
+
+       symbol = luaL_checkstring (L, 2);
+
+       if (task && symbol) {
+               /*
+                * The enclosing array must be on the stack before any symbol table is
+                * pushed, so that lua_rawseti stores the symbol table into the array
+                */
+               lua_newtable (L);
+
+               metric_list = g_hash_table_lookup (task->cfg->metrics_symbols, symbol);
+               if (metric_list == NULL) {
+                       metric = task->cfg->default_metric;
+                       if (metric != NULL && lua_push_symbol_result (L, task, metric, symbol)) {
+                               lua_rawseti (L, -2, i++);
+                               found = TRUE;
+                       }
+               }
+               else {
+                       for (cur = metric_list; cur != NULL; cur = g_list_next (cur)) {
+                               metric = cur->data;
+                               if (lua_push_symbol_result (L, task, metric, symbol)) {
+                                       lua_rawseti (L, -2, i++);
+                                       found = TRUE;
+                               }
+                       }
+               }
+       }
+
+       if (!found) {
+               /* nil ends up on top of the stack and becomes the only returned value */
+               lua_pushnil (L);
+       }
+       return 1;
+}
+
+/**
+ * Task method "learn_statfile": learn the statfile of the given symbol
+ * using the tokens already stored in the task.
+ *
+ * Lua arguments: the task object and the statfile symbol (string).
+ * The classifier is found through cfg->classifiers_symbols. The tokens are
+ * taken from task->tokens under the classifier's tokenizer, so they must
+ * have been produced earlier for this task.
+ *
+ * Returns to Lua true if learn_func and maybe_write_binlog were invoked,
+ * and false on every failure path. The results of those two calls are not
+ * inspected.
+ *
+ * @param L lua state
+ * @return 1, the number of values returned to Lua
+ */
+static int
+lua_task_learn_statfile (lua_State *L)
+{
+       struct worker_task             *task = lua_check_task (L);
+       const char                     *symbol;
+       struct classifier_config       *cl;
+       GTree                          *tokens;
+       struct statfile                *st;
+       stat_file_t                    *statfile;
+       struct classifier_ctx          *ctx;
+
+       symbol = luaL_checkstring (L, 2);
+
+       if (task == NULL || symbol == NULL) {
+               /* One result is always promised to Lua, so push it on this path as well */
+               lua_pushboolean (L, FALSE);
+               return 1;
+       }
+
+       cl = g_hash_table_lookup (task->cfg->classifiers_symbols, symbol);
+       if (cl == NULL) {
+               msg_warn ("classifier for symbol %s is not found", symbol);
+               lua_pushboolean (L, FALSE);
+               return 1;
+       }
+
+       ctx = cl->classifier->init_func (task->task_pool, cl);
+
+       /* task->tokens may still be NULL here; do not pass a NULL table to the lookup */
+       if (task->tokens == NULL ||
+                       (tokens = g_hash_table_lookup (task->tokens, cl->tokenizer)) == NULL) {
+               msg_warn ("no tokens found learn failed!");
+               lua_pushboolean (L, FALSE);
+               return 1;
+       }
+
+       statfile = get_statfile_by_symbol (task->worker->srv->statfile_pool, ctx->cfg,
+                                                       symbol, &st, TRUE);
+       if (statfile == NULL) {
+               msg_warn ("opening statfile failed!");
+               lua_pushboolean (L, FALSE);
+               return 1;
+       }
+
+       cl->classifier->learn_func (ctx, task->worker->srv->statfile_pool, symbol, tokens, TRUE, NULL, 1., NULL);
+       maybe_write_binlog (ctx->cfg, st, statfile, tokens);
+       lua_pushboolean (L, TRUE);
+
+       return 1;
+}
+
+/**
+ * Task method "get_metric_score": report the scores of a metric.
+ *
+ * Lua arguments: the task object and the metric name (string).
+ * Returns to Lua the array { score, required_score, reject_score } taken
+ * from the task's result for that metric, or nil if the task has no result
+ * for it. Nothing is returned if the task or the name is missing.
+ *
+ * @param L lua state
+ * @return number of values returned to Lua (1 or 0)
+ */
+static int
+lua_task_get_metric_score (lua_State *L)
+{
+       struct worker_task             *task = lua_check_task (L);
+       const char                     *name = luaL_checkstring (L, 2);
+       struct metric_result           *res;
+
+       if (task == NULL || name == NULL) {
+               return 0;
+       }
+
+       res = g_hash_table_lookup (task->results, name);
+       if (res == NULL) {
+               lua_pushnil (L);
+               return 1;
+       }
+
+       /* Positions 1..3: current score, required score, reject score */
+       lua_newtable (L);
+       lua_pushnumber (L, res->score);
+       lua_rawseti (L, -2, 1);
+       lua_pushnumber (L, res->metric->required_score);
+       lua_rawseti (L, -2, 2);
+       lua_pushnumber (L, res->metric->reject_score);
+       lua_rawseti (L, -2, 3);
+
+       return 1;
+}
+
+/**
+ * Task method "get_metric_action": report the action of a metric as a string.
+ *
+ * Lua arguments: the task object and the metric name (string).
+ * The action is computed by check_metric_action () from the metric result's
+ * score and the metric's required score, then converted to text with
+ * str_action_metric (). nil is returned if the task has no result for the
+ * metric. Nothing is returned if the task or the name is missing.
+ *
+ * @param L lua state
+ * @return number of values returned to Lua (1 or 0)
+ */
+static int
+lua_task_get_metric_action (lua_State *L)
+{
+       struct worker_task             *task = lua_check_task (L);
+       const char                     *name = luaL_checkstring (L, 2);
+       struct metric_result           *res;
+       enum rspamd_metric_action       act;
+
+       if (task == NULL || name == NULL) {
+               return 0;
+       }
+
+       res = g_hash_table_lookup (task->results, name);
+       if (res == NULL) {
+               lua_pushnil (L);
+               return 1;
+       }
+
+       act = check_metric_action (res->score, res->metric->required_score, res->metric);
+       lua_pushstring (L, str_action_metric (act));
+
+       return 1;
+}
 
 /**** Textpart implementation *****/
 
index b54d3ed8a2b548172d9e625f4e47c9e9bc2dba1e..e26f3fbd0846385fdae1b54351226f993a547ffd 100644 (file)
@@ -205,6 +205,8 @@ struct worker_task {
        GList *images;                                                                                          /**< list of images                                                                     */
        GHashTable *results;                                                                            /**< hash table of metric_result indexed by 
                                                                                                                                 *    metric's name                                                                     */
+       GHashTable *tokens;                                                                                     /**< hash table of tokens indexed by tokenizer
+                                                                                                                                *    pointer                                                                           */
        GList *messages;                                                                                        /**< list of messages that would be reported            */
        GHashTable *re_cache;                                                                           /**< cache for matched or not matched regexps           */
        struct config_file *cfg;                                                                        /**< pointer to config object                                           */
index 63b1bee5b5c27b3c058d3660f2f2449bfd88029b..cece0bab70f7c8fae47fcf01665e268d3bb2b8fa 100644 (file)
@@ -620,7 +620,7 @@ show_metric_symbols (struct metric_result *metric_res, struct metric_callback_da
        return TRUE;
 }
 
-G_INLINE_FUNC const char *
+const char *
 str_action_metric (enum rspamd_metric_action action)
 {
        switch (action) {
@@ -641,7 +641,7 @@ str_action_metric (enum rspamd_metric_action action)
        return "unknown action";
 }
 
-G_INLINE_FUNC gint
+gint
 check_metric_action (double score, double required_score, struct metric *metric)
 {
        GList                          *cur;
index c216259f61c89e758fc95f0f5d181f7790a37e31..affcccd5c3adb28ebe7447c03e3d3910134b0289 100644 (file)
@@ -27,6 +27,8 @@
 #define SPAMD_ERROR "EX_ERROR"
 
 struct worker_task;
+enum rspamd_metric_action;
+struct metric;
 
 enum rspamd_protocol {
        SPAMC_PROTO,
@@ -75,4 +77,7 @@ gboolean write_reply (struct worker_task *task) G_GNUC_WARN_UNUSED_RESULT;
  */
 void register_protocol_command (const char *name, protocol_reply_func func);
 
+const char *str_action_metric (enum rspamd_metric_action action);
+gint check_metric_action (double score, double required_score, struct metric *metric);
+
 #endif
index cbd23449226b407dca57867b23c764354adcf397..4276094581aebbbc1cc6f776ea6c1aacbeb0396b 100644 (file)
@@ -38,7 +38,7 @@
 #include "map.h"
 #include "dns.h"
 
-#include "evdns/evdns.h"
+#include "lua/lua_common.h"
 
 #ifndef WITHOUT_PERL
 #   include <EXTERN.h>                 /* from the Perl distribution     */
@@ -359,6 +359,7 @@ read_socket (f_str_t * in, void *arg)
                }
                else {
                        process_statfiles (task);
+                       lua_call_post_filters (task);
                        return write_socket (task);
                }
                break;
@@ -666,7 +667,6 @@ start_worker (struct rspamd_worker *worker)
        worker->srv->pid = getpid ();
 
        event_init ();
-       evdns_init ();
 
        init_signals (&signals, sig_handler);
        sigprocmask (SIG_UNBLOCK, &signals.sa_mask, NULL);