From 671bbfa9cc85a625df33d6384a3179ce076765b9 Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Wed, 25 Aug 2010 19:38:18 +0400 Subject: [PATCH] * Add post filters to lua API - filters that would be called after all message's processing * Add ability to check for specified symbol in task results from lua * Add ability to check for metric's results from lua * Add ability to learn specified statfile form lua --- src/cfg_file.h | 1 + src/filter.c | 31 +++---- src/lua/lua_common.h | 1 + src/lua/lua_config.c | 46 +++++++++- src/lua/lua_task.c | 198 +++++++++++++++++++++++++++++++++++++++++++ src/main.h | 2 + src/protocol.c | 4 +- src/protocol.h | 5 ++ src/worker.c | 4 +- 9 files changed, 270 insertions(+), 22 deletions(-) diff --git a/src/cfg_file.h b/src/cfg_file.h index b2ca61150..d770c237e 100644 --- a/src/cfg_file.h +++ b/src/cfg_file.h @@ -289,6 +289,7 @@ struct config_file { GHashTable *classifiers_symbols; /**< hashtable indexed by symbol name of classifiers */ GHashTable* cfg_params; /**< all cfg params indexed by its name in this structure */ GList *views; /**< views */ + GList *post_filters; /**< list of post-processing lua filters */ GHashTable* domain_settings; /**< settings per-domains */ GHashTable* user_settings; /**< settings per-user */ gchar* domain_settings_str; /**< string representation of settings */ diff --git a/src/filter.c b/src/filter.c index 90566ded9..7d1de3d20 100644 --- a/src/filter.c +++ b/src/filter.c @@ -243,6 +243,9 @@ continue_process_filters (struct worker_task *task) end: /* Process all statfiles */ process_statfiles (task); + /* Call post filters */ + lua_call_post_filters (task); + task->state = WRITE_REPLY; /* XXX: ugly direct call */ if (task->fin_callback) { task->fin_callback (task->fin_arg); @@ -293,6 +296,7 @@ process_filters (struct worker_task *task) else if (!task->pass_all_filters && metric->action == METRIC_ACTION_REJECT && check_metric_is_spam (task, metric)) { + task->state = WRITE_REPLY; return 1; } cur = g_list_next (cur); @@ -463,17 +467,10 @@ make_composites (struct worker_task *task) g_hash_table_foreach (task->results, composites_metric_callback, task); } - -struct statfile_callback_data { - GHashTable *tokens; - struct worker_task *task; -}; - static void classifiers_callback (gpointer value, void *arg) { - struct statfile_callback_data *data = (struct statfile_callback_data *)arg; - struct worker_task *task = data->task; + struct worker_task *task = arg; struct classifier_config *cl = value; struct classifier_ctx *ctx; struct mime_text_part *text_part; @@ -494,7 +491,7 @@ classifiers_callback (gpointer value, void *arg) } ctx = cl->classifier->init_func (task->task_pool, cl); - if ((tokens = g_hash_table_lookup (data->tokens, cl->tokenizer)) == NULL) { + if ((tokens = g_hash_table_lookup (task->tokens, cl->tokenizer)) == NULL) { while (cur != NULL) { if (header) { c.len = strlen (cur->data); @@ -522,7 +519,7 @@ classifiers_callback (gpointer value, void *arg) } cur = g_list_next (cur); } - g_hash_table_insert (data->tokens, cl->tokenizer, tokens); + g_hash_table_insert (task->tokens, cl->tokenizer, tokens); } if (tokens == NULL) { @@ -549,20 +546,20 @@ classifiers_callback (gpointer value, void *arg) void process_statfiles (struct worker_task *task) { - struct statfile_callback_data cd; - + if (task->is_skipped) { return; } - cd.task = task; - cd.tokens = g_hash_table_new (g_direct_hash, g_direct_equal); - g_list_foreach (task->cfg->classifiers, classifiers_callback, &cd); - g_hash_table_destroy (cd.tokens); + if (task->tokens == NULL) { + task->tokens = g_hash_table_new (g_direct_hash, g_direct_equal); + memory_pool_add_destructor (task->task_pool, (pool_destruct_func)g_hash_table_destroy, task->tokens); + } + + g_list_foreach (task->cfg->classifiers, classifiers_callback, task); /* Process results */ make_composites (task); - task->state = WRITE_REPLY; } static void diff --git a/src/lua/lua_common.h b/src/lua/lua_common.h index d1f5f9eb4..18fa6481d 100644 --- a/src/lua/lua_common.h +++ b/src/lua/lua_common.h @@ -39,6 +39,7 @@ int lua_call_filter (const char *function, struct worker_task *task); int lua_call_chain_filter (const char *function, struct worker_task *task, int *marks, unsigned int number); double lua_consolidation_func (struct worker_task *task, const char *metric_name, const char *function_name); gboolean lua_call_expression_func (const char *function, struct worker_task *task, GList *args, gboolean *res); +void lua_call_post_filters (struct worker_task *task); void add_luabuf (const char *line); /* Classify functions */ diff --git a/src/lua/lua_config.c b/src/lua/lua_config.c index 5e0148ca7..7ae7d3e0e 100644 --- a/src/lua/lua_config.c +++ b/src/lua/lua_config.c @@ -37,6 +37,7 @@ LUA_FUNCTION_DEF (config, add_radix_map); LUA_FUNCTION_DEF (config, add_hash_map); LUA_FUNCTION_DEF (config, get_classifier); LUA_FUNCTION_DEF (config, register_symbol); +LUA_FUNCTION_DEF (config, register_post_filter); static const struct luaL_reg configlib_m[] = { LUA_INTERFACE_DEF (config, get_module_opt), @@ -46,6 +47,7 @@ static const struct luaL_reg configlib_m[] = { LUA_INTERFACE_DEF (config, add_hash_map), LUA_INTERFACE_DEF (config, get_classifier), LUA_INTERFACE_DEF (config, register_symbol), + LUA_INTERFACE_DEF (config, register_post_filter), {"__tostring", lua_class_tostring}, {NULL, NULL} }; @@ -291,7 +293,49 @@ lua_config_register_function (lua_State *L) register_expression_function (name, lua_config_function_callback, cd); } } - return 0; + return 1; +} + +void +lua_call_post_filters (struct worker_task *task) +{ + struct lua_callback_data *cd; + struct worker_task **ptask; + GList *cur; + + cur = task->cfg->post_filters; + while (cur) { + cd = cur->data; + lua_getglobal (cd->L, cd->name); + ptask = lua_newuserdata (cd->L, sizeof (struct worker_task *)); + lua_setclass (cd->L, "rspamd{task}", -1); + *ptask = task; + + if (lua_pcall (cd->L, 1, 1, 0) != 0) { + msg_warn ("error running function %s: %s", cd->name, lua_tostring (cd->L, -1)); + } + cur = g_list_next (cur); + } +} + +static int +lua_config_register_post_filter (lua_State *L) +{ + struct config_file *cfg = lua_check_config (L); + const char *callback; + struct lua_callback_data *cd; + + if (cfg) { + + callback = luaL_checkstring (L, 2); + if (callback) { + cd = g_malloc (sizeof (struct lua_callback_data)); + cd->name = g_strdup (callback); + cd->L = L; + cfg->post_filters = g_list_prepend (cfg->post_filters, cd); + } + } + return 1; } static int diff --git a/src/lua/lua_task.c b/src/lua/lua_task.c index 32c7b41ef..b18cfc21f 100644 --- a/src/lua/lua_task.c +++ b/src/lua/lua_task.c @@ -26,8 +26,20 @@ #include "lua_common.h" #include "../message.h" #include "../expressions.h" +#include "../protocol.h" +#include "../filter.h" #include "../dns.h" +#include "../util.h" #include "../images.h" +#include "../cfg_file.h" +#include "../statfile.h" +#include "../tokenizers/tokenizers.h" +#include "../classifiers/classifiers.h" +#include "../binlog.h" +#include "../statfile_sync.h" + +extern stat_file_t* get_statfile_by_symbol (statfile_pool_t *pool, struct classifier_config *ccf, + const char *symbol, struct statfile **st, gboolean try_create); /* Task methods */ LUA_FUNCTION_DEF (task, get_message); @@ -47,6 +59,10 @@ LUA_FUNCTION_DEF (task, get_from_ip_num); LUA_FUNCTION_DEF (task, get_client_ip_num); LUA_FUNCTION_DEF (task, get_helo); LUA_FUNCTION_DEF (task, get_images); +LUA_FUNCTION_DEF (task, get_symbol); +LUA_FUNCTION_DEF (task, get_metric_score); +LUA_FUNCTION_DEF (task, get_metric_action); +LUA_FUNCTION_DEF (task, learn_statfile); static const struct luaL_reg tasklib_m[] = { LUA_INTERFACE_DEF (task, get_message), @@ -66,6 +82,10 @@ static const struct luaL_reg tasklib_m[] = { LUA_INTERFACE_DEF (task, get_client_ip_num), LUA_INTERFACE_DEF (task, get_helo), LUA_INTERFACE_DEF (task, get_images), + LUA_INTERFACE_DEF (task, get_symbol), + LUA_INTERFACE_DEF (task, get_metric_score), + LUA_INTERFACE_DEF (task, get_metric_action), + LUA_INTERFACE_DEF (task, learn_statfile), {"__tostring", lua_class_tostring}, {NULL, NULL} }; @@ -580,6 +600,184 @@ lua_task_get_images (lua_State *L) return 1; } +G_INLINE_FUNC gboolean +lua_push_symbol_result (lua_State *L, struct worker_task *task, struct metric *metric, const char *symbol) +{ + struct metric_result *metric_res; + struct symbol *s; + int j; + GList *opt; + + metric_res = g_hash_table_lookup (task->results, metric->name); + if (metric_res) { + if ((s = g_hash_table_lookup (metric_res->symbols, symbol)) != NULL) { + j = 0; + lua_newtable (L); + lua_pushstring (L, "metric"); + lua_pushstring (L, metric->name); + lua_settable (L, -3); + lua_pushstring (L, "score"); + lua_pushnumber (L, s->score); + lua_settable (L, -3); + if (s->options) { + opt = s->options; + lua_pushstring (L, "options"); + lua_newtable (L); + while (opt) { + lua_pushstring (L, opt->data); + lua_rawseti (L, -2, j++); + opt = g_list_next (opt); + } + lua_settable (L, -3); + } + + return TRUE; + } + } + + return FALSE; +} + +static int +lua_task_get_symbol (lua_State *L) +{ + struct worker_task *task = lua_check_task (L); + const char *symbol; + struct metric *metric; + GList *cur = NULL, *metric_list; + gboolean found = FALSE; + int i = 1; + + symbol = luaL_checkstring (L, 2); + + if (task && symbol) { + metric_list = g_hash_table_lookup (task->cfg->metrics_symbols, symbol); + if (metric_list) { + lua_newtable (L); + cur = metric_list; + } + else { + metric = task->cfg->default_metric; + } + + if (!cur && metric) { + if ((found = lua_push_symbol_result (L, task, metric, symbol))) { + lua_newtable (L); + lua_rawseti (L, -2, i++); + } + } + else { + while (cur) { + metric = cur->data; + if (lua_push_symbol_result (L, task, metric, symbol)) { + lua_rawseti (L, -2, i++); + found = TRUE; + } + cur = g_list_next (cur); + } + } + } + + if (!found) { + lua_pushnil (L); + } + return 1; +} + +static int +lua_task_learn_statfile (lua_State *L) +{ + struct worker_task *task = lua_check_task (L); + const char *symbol; + struct classifier_config *cl; + GTree *tokens; + struct statfile *st; + stat_file_t *statfile; + struct classifier_ctx *ctx; + + symbol = luaL_checkstring (L, 2); + + if (task && symbol) { + cl = g_hash_table_lookup (task->cfg->classifiers_symbols, symbol); + if (cl == NULL) { + msg_warn ("classifier for symbol %s is not found", symbol); + lua_pushboolean (L, FALSE); + return 1; + } + ctx = cl->classifier->init_func (task->task_pool, cl); + if ((tokens = g_hash_table_lookup (task->tokens, cl->tokenizer)) == NULL) { + msg_warn ("no tokens found learn failed!"); + lua_pushboolean (L, FALSE); + return 1; + } + statfile = get_statfile_by_symbol (task->worker->srv->statfile_pool, ctx->cfg, + symbol, &st, TRUE); + + if (statfile == NULL) { + msg_warn ("opening statfile failed!"); + lua_pushboolean (L, FALSE); + return 1; + } + + cl->classifier->learn_func (ctx, task->worker->srv->statfile_pool, symbol, tokens, TRUE, NULL, 1., NULL); + maybe_write_binlog (ctx->cfg, st, statfile, tokens); + lua_pushboolean (L, TRUE); + } + + return 1; +} + +static int +lua_task_get_metric_score (lua_State *L) +{ + struct worker_task *task = lua_check_task (L); + const char *metric_name; + struct metric_result *metric_res; + + metric_name = luaL_checkstring (L, 2); + + if (task && metric_name) { + if ((metric_res = g_hash_table_lookup (task->results, metric_name)) != NULL) { + lua_newtable (L); + lua_pushnumber (L, metric_res->score); + lua_rawseti (L, -2, 1); + lua_pushnumber (L, metric_res->metric->required_score); + lua_rawseti (L, -2, 2); + lua_pushnumber (L, metric_res->metric->reject_score); + lua_rawseti (L, -2, 3); + } + else { + lua_pushnil (L); + } + return 1; + } + + return 0; +} + +static int +lua_task_get_metric_action (lua_State *L) +{ + struct worker_task *task = lua_check_task (L); + const char *metric_name; + struct metric_result *metric_res; + enum rspamd_metric_action action; + + metric_name = luaL_checkstring (L, 2); + + if (task && metric_name) { + if ((metric_res = g_hash_table_lookup (task->results, metric_name)) != NULL) { + action = check_metric_action (metric_res->score, metric_res->metric->required_score, metric_res->metric); + lua_pushstring (L, str_action_metric (action)); + } + else { + lua_pushnil (L); + } + return 1; + } + + return 0; +} /**** Textpart implementation *****/ diff --git a/src/main.h b/src/main.h index b54d3ed8a..e26f3fbd0 100644 --- a/src/main.h +++ b/src/main.h @@ -205,6 +205,8 @@ struct worker_task { GList *images; /**< list of images */ GHashTable *results; /**< hash table of metric_result indexed by * metric's name */ + GHashTable *tokens; /**< hash table of tokens indexed by tokenizer + * pointer */ GList *messages; /**< list of messages that would be reported */ GHashTable *re_cache; /**< cache for matched or not matched regexps */ struct config_file *cfg; /**< pointer to config object */ diff --git a/src/protocol.c b/src/protocol.c index 63b1bee5b..cece0bab7 100644 --- a/src/protocol.c +++ b/src/protocol.c @@ -620,7 +620,7 @@ show_metric_symbols (struct metric_result *metric_res, struct metric_callback_da return TRUE; } -G_INLINE_FUNC const char * +const char * str_action_metric (enum rspamd_metric_action action) { switch (action) { @@ -641,7 +641,7 @@ str_action_metric (enum rspamd_metric_action action) return "unknown action"; } -G_INLINE_FUNC gint +gint check_metric_action (double score, double required_score, struct metric *metric) { GList *cur; diff --git a/src/protocol.h b/src/protocol.h index c216259f6..affcccd5c 100644 --- a/src/protocol.h +++ b/src/protocol.h @@ -27,6 +27,8 @@ #define SPAMD_ERROR "EX_ERROR" struct worker_task; +enum rspamd_metric_action; +struct metric; enum rspamd_protocol { SPAMC_PROTO, @@ -75,4 +77,7 @@ gboolean write_reply (struct worker_task *task) G_GNUC_WARN_UNUSED_RESULT; */ void register_protocol_command (const char *name, protocol_reply_func func); +const char *str_action_metric (enum rspamd_metric_action action); +gint check_metric_action (double score, double required_score, struct metric *metric); + #endif diff --git a/src/worker.c b/src/worker.c index cbd234492..427609458 100644 --- a/src/worker.c +++ b/src/worker.c @@ -38,7 +38,7 @@ #include "map.h" #include "dns.h" -#include "evdns/evdns.h" +#include "lua/lua_common.h" #ifndef WITHOUT_PERL # include /* from the Perl distribution */ @@ -359,6 +359,7 @@ read_socket (f_str_t * in, void *arg) } else { process_statfiles (task); + lua_call_post_filters (task); return write_socket (task); } break; @@ -666,7 +667,6 @@ start_worker (struct rspamd_worker *worker) worker->srv->pid = getpid (); event_init (); - evdns_init (); init_signals (&signals, sig_handler); sigprocmask (SIG_UNBLOCK, &signals.sa_mask, NULL); -- 2.39.5