diff options
author | Vsevolod Stakhov <vsevolod@rambler-co.ru> | 2010-08-25 19:38:18 +0400 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@rambler-co.ru> | 2010-08-25 19:38:18 +0400 |
commit | 671bbfa9cc85a625df33d6384a3179ce076765b9 (patch) | |
tree | ac77b0439bc08ce896b614e1a95641878d35d807 /src/lua | |
parent | 331f6807e9ef813755f8ec197cc24915c458a684 (diff) | |
download | rspamd-671bbfa9cc85a625df33d6384a3179ce076765b9.tar.gz rspamd-671bbfa9cc85a625df33d6384a3179ce076765b9.zip |
* Add post filters to lua API - filters that would be called after all message's processing
* Add ability to check for specified symbol in task results from lua
* Add ability to check for metric's results from lua
* Add ability to learn specified statfile form lua
Diffstat (limited to 'src/lua')
-rw-r--r-- | src/lua/lua_common.h | 1 | ||||
-rw-r--r-- | src/lua/lua_config.c | 46 | ||||
-rw-r--r-- | src/lua/lua_task.c | 198 |
3 files changed, 244 insertions, 1 deletions
diff --git a/src/lua/lua_common.h b/src/lua/lua_common.h index d1f5f9eb4..18fa6481d 100644 --- a/src/lua/lua_common.h +++ b/src/lua/lua_common.h @@ -39,6 +39,7 @@ int lua_call_filter (const char *function, struct worker_task *task); int lua_call_chain_filter (const char *function, struct worker_task *task, int *marks, unsigned int number); double lua_consolidation_func (struct worker_task *task, const char *metric_name, const char *function_name); gboolean lua_call_expression_func (const char *function, struct worker_task *task, GList *args, gboolean *res); +void lua_call_post_filters (struct worker_task *task); void add_luabuf (const char *line); /* Classify functions */ diff --git a/src/lua/lua_config.c b/src/lua/lua_config.c index 5e0148ca7..7ae7d3e0e 100644 --- a/src/lua/lua_config.c +++ b/src/lua/lua_config.c @@ -37,6 +37,7 @@ LUA_FUNCTION_DEF (config, add_radix_map); LUA_FUNCTION_DEF (config, add_hash_map); LUA_FUNCTION_DEF (config, get_classifier); LUA_FUNCTION_DEF (config, register_symbol); +LUA_FUNCTION_DEF (config, register_post_filter); static const struct luaL_reg configlib_m[] = { LUA_INTERFACE_DEF (config, get_module_opt), @@ -46,6 +47,7 @@ static const struct luaL_reg configlib_m[] = { LUA_INTERFACE_DEF (config, add_hash_map), LUA_INTERFACE_DEF (config, get_classifier), LUA_INTERFACE_DEF (config, register_symbol), + LUA_INTERFACE_DEF (config, register_post_filter), {"__tostring", lua_class_tostring}, {NULL, NULL} }; @@ -291,7 +293,49 @@ lua_config_register_function (lua_State *L) register_expression_function (name, lua_config_function_callback, cd); } } - return 0; + return 1; +} + +void +lua_call_post_filters (struct worker_task *task) +{ + struct lua_callback_data *cd; + struct worker_task **ptask; + GList *cur; + + cur = task->cfg->post_filters; + while (cur) { + cd = cur->data; + lua_getglobal (cd->L, cd->name); + ptask = lua_newuserdata (cd->L, sizeof (struct worker_task *)); + lua_setclass (cd->L, "rspamd{task}", -1); + *ptask = task; + + if (lua_pcall (cd->L, 1, 1, 0) != 0) { + msg_warn ("error running function %s: %s", cd->name, lua_tostring (cd->L, -1)); + } + cur = g_list_next (cur); + } +} + +static int +lua_config_register_post_filter (lua_State *L) +{ + struct config_file *cfg = lua_check_config (L); + const char *callback; + struct lua_callback_data *cd; + + if (cfg) { + + callback = luaL_checkstring (L, 2); + if (callback) { + cd = g_malloc (sizeof (struct lua_callback_data)); + cd->name = g_strdup (callback); + cd->L = L; + cfg->post_filters = g_list_prepend (cfg->post_filters, cd); + } + } + return 1; } static int diff --git a/src/lua/lua_task.c b/src/lua/lua_task.c index 32c7b41ef..b18cfc21f 100644 --- a/src/lua/lua_task.c +++ b/src/lua/lua_task.c @@ -26,8 +26,20 @@ #include "lua_common.h" #include "../message.h" #include "../expressions.h" +#include "../protocol.h" +#include "../filter.h" #include "../dns.h" +#include "../util.h" #include "../images.h" +#include "../cfg_file.h" +#include "../statfile.h" +#include "../tokenizers/tokenizers.h" +#include "../classifiers/classifiers.h" +#include "../binlog.h" +#include "../statfile_sync.h" + +extern stat_file_t* get_statfile_by_symbol (statfile_pool_t *pool, struct classifier_config *ccf, + const char *symbol, struct statfile **st, gboolean try_create); /* Task methods */ LUA_FUNCTION_DEF (task, get_message); @@ -47,6 +59,10 @@ LUA_FUNCTION_DEF (task, get_from_ip_num); LUA_FUNCTION_DEF (task, get_client_ip_num); LUA_FUNCTION_DEF (task, get_helo); LUA_FUNCTION_DEF (task, get_images); +LUA_FUNCTION_DEF (task, get_symbol); +LUA_FUNCTION_DEF (task, get_metric_score); +LUA_FUNCTION_DEF (task, get_metric_action); +LUA_FUNCTION_DEF (task, learn_statfile); static const struct luaL_reg tasklib_m[] = { LUA_INTERFACE_DEF (task, get_message), @@ -66,6 +82,10 @@ static const struct luaL_reg tasklib_m[] = { LUA_INTERFACE_DEF (task, get_client_ip_num), LUA_INTERFACE_DEF (task, get_helo), LUA_INTERFACE_DEF (task, get_images), + LUA_INTERFACE_DEF (task, get_symbol), + LUA_INTERFACE_DEF (task, get_metric_score), + LUA_INTERFACE_DEF (task, get_metric_action), + LUA_INTERFACE_DEF (task, learn_statfile), {"__tostring", lua_class_tostring}, {NULL, NULL} }; @@ -580,6 +600,184 @@ lua_task_get_images (lua_State *L) return 1; } +G_INLINE_FUNC gboolean +lua_push_symbol_result (lua_State *L, struct worker_task *task, struct metric *metric, const char *symbol) +{ + struct metric_result *metric_res; + struct symbol *s; + int j; + GList *opt; + + metric_res = g_hash_table_lookup (task->results, metric->name); + if (metric_res) { + if ((s = g_hash_table_lookup (metric_res->symbols, symbol)) != NULL) { + j = 0; + lua_newtable (L); + lua_pushstring (L, "metric"); + lua_pushstring (L, metric->name); + lua_settable (L, -3); + lua_pushstring (L, "score"); + lua_pushnumber (L, s->score); + lua_settable (L, -3); + if (s->options) { + opt = s->options; + lua_pushstring (L, "options"); + lua_newtable (L); + while (opt) { + lua_pushstring (L, opt->data); + lua_rawseti (L, -2, j++); + opt = g_list_next (opt); + } + lua_settable (L, -3); + } + + return TRUE; + } + } + + return FALSE; +} + +static int +lua_task_get_symbol (lua_State *L) +{ + struct worker_task *task = lua_check_task (L); + const char *symbol; + struct metric *metric; + GList *cur = NULL, *metric_list; + gboolean found = FALSE; + int i = 1; + + symbol = luaL_checkstring (L, 2); + + if (task && symbol) { + metric_list = g_hash_table_lookup (task->cfg->metrics_symbols, symbol); + if (metric_list) { + lua_newtable (L); + cur = metric_list; + } + else { + metric = task->cfg->default_metric; + } + + if (!cur && metric) { + if ((found = lua_push_symbol_result (L, task, metric, symbol))) { + lua_newtable (L); + lua_rawseti (L, -2, i++); + } + } + else { + while (cur) { + metric = cur->data; + if (lua_push_symbol_result (L, task, metric, symbol)) { + lua_rawseti (L, -2, i++); + found = TRUE; + } + cur = g_list_next (cur); + } + } + } + + if (!found) { + lua_pushnil (L); + } + return 1; +} + +static int +lua_task_learn_statfile (lua_State *L) +{ + struct worker_task *task = lua_check_task (L); + const char *symbol; + struct classifier_config *cl; + GTree *tokens; + struct statfile *st; + stat_file_t *statfile; + struct classifier_ctx *ctx; + + symbol = luaL_checkstring (L, 2); + + if (task && symbol) { + cl = g_hash_table_lookup (task->cfg->classifiers_symbols, symbol); + if (cl == NULL) { + msg_warn ("classifier for symbol %s is not found", symbol); + lua_pushboolean (L, FALSE); + return 1; + } + ctx = cl->classifier->init_func (task->task_pool, cl); + if ((tokens = g_hash_table_lookup (task->tokens, cl->tokenizer)) == NULL) { + msg_warn ("no tokens found learn failed!"); + lua_pushboolean (L, FALSE); + return 1; + } + statfile = get_statfile_by_symbol (task->worker->srv->statfile_pool, ctx->cfg, + symbol, &st, TRUE); + + if (statfile == NULL) { + msg_warn ("opening statfile failed!"); + lua_pushboolean (L, FALSE); + return 1; + } + + cl->classifier->learn_func (ctx, task->worker->srv->statfile_pool, symbol, tokens, TRUE, NULL, 1., NULL); + maybe_write_binlog (ctx->cfg, st, statfile, tokens); + lua_pushboolean (L, TRUE); + } + + return 1; +} + +static int +lua_task_get_metric_score (lua_State *L) +{ + struct worker_task *task = lua_check_task (L); + const char *metric_name; + struct metric_result *metric_res; + + metric_name = luaL_checkstring (L, 2); + + if (task && metric_name) { + if ((metric_res = g_hash_table_lookup (task->results, metric_name)) != NULL) { + lua_newtable (L); + lua_pushnumber (L, metric_res->score); + lua_rawseti (L, -2, 1); + lua_pushnumber (L, metric_res->metric->required_score); + lua_rawseti (L, -2, 2); + lua_pushnumber (L, metric_res->metric->reject_score); + lua_rawseti (L, -2, 3); + } + else { + lua_pushnil (L); + } + return 1; + } + + return 0; +} + +static int +lua_task_get_metric_action (lua_State *L) +{ + struct worker_task *task = lua_check_task (L); + const char *metric_name; + struct metric_result *metric_res; + enum rspamd_metric_action action; + + metric_name = luaL_checkstring (L, 2); + + if (task && metric_name) { + if ((metric_res = g_hash_table_lookup (task->results, metric_name)) != NULL) { + action = check_metric_action (metric_res->score, metric_res->metric->required_score, metric_res->metric); + lua_pushstring (L, str_action_metric (action)); + } + else { + lua_pushnil (L); + } + return 1; + } + + return 0; +} /**** Textpart implementation *****/ |