summaryrefslogtreecommitdiffstats
path: root/src/lua
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@rambler-co.ru>2010-08-25 19:38:18 +0400
committerVsevolod Stakhov <vsevolod@rambler-co.ru>2010-08-25 19:38:18 +0400
commit671bbfa9cc85a625df33d6384a3179ce076765b9 (patch)
treeac77b0439bc08ce896b614e1a95641878d35d807 /src/lua
parent331f6807e9ef813755f8ec197cc24915c458a684 (diff)
downloadrspamd-671bbfa9cc85a625df33d6384a3179ce076765b9.tar.gz
rspamd-671bbfa9cc85a625df33d6384a3179ce076765b9.zip
* Add post filters to lua API - filters that would be called after all message's processing
* Add ability to check for specified symbol in task results from lua * Add ability to check for metric's results from lua * Add ability to learn specified statfile form lua
Diffstat (limited to 'src/lua')
-rw-r--r--src/lua/lua_common.h1
-rw-r--r--src/lua/lua_config.c46
-rw-r--r--src/lua/lua_task.c198
3 files changed, 244 insertions, 1 deletions
diff --git a/src/lua/lua_common.h b/src/lua/lua_common.h
index d1f5f9eb4..18fa6481d 100644
--- a/src/lua/lua_common.h
+++ b/src/lua/lua_common.h
@@ -39,6 +39,7 @@ int lua_call_filter (const char *function, struct worker_task *task);
int lua_call_chain_filter (const char *function, struct worker_task *task, int *marks, unsigned int number);
double lua_consolidation_func (struct worker_task *task, const char *metric_name, const char *function_name);
gboolean lua_call_expression_func (const char *function, struct worker_task *task, GList *args, gboolean *res);
+void lua_call_post_filters (struct worker_task *task);
void add_luabuf (const char *line);
/* Classify functions */
diff --git a/src/lua/lua_config.c b/src/lua/lua_config.c
index 5e0148ca7..7ae7d3e0e 100644
--- a/src/lua/lua_config.c
+++ b/src/lua/lua_config.c
@@ -37,6 +37,7 @@ LUA_FUNCTION_DEF (config, add_radix_map);
LUA_FUNCTION_DEF (config, add_hash_map);
LUA_FUNCTION_DEF (config, get_classifier);
LUA_FUNCTION_DEF (config, register_symbol);
+LUA_FUNCTION_DEF (config, register_post_filter);
static const struct luaL_reg configlib_m[] = {
LUA_INTERFACE_DEF (config, get_module_opt),
@@ -46,6 +47,7 @@ static const struct luaL_reg configlib_m[] = {
LUA_INTERFACE_DEF (config, add_hash_map),
LUA_INTERFACE_DEF (config, get_classifier),
LUA_INTERFACE_DEF (config, register_symbol),
+ LUA_INTERFACE_DEF (config, register_post_filter),
{"__tostring", lua_class_tostring},
{NULL, NULL}
};
@@ -291,7 +293,49 @@ lua_config_register_function (lua_State *L)
register_expression_function (name, lua_config_function_callback, cd);
}
}
- return 0;
+ return 1;
+}
+
+void
+lua_call_post_filters (struct worker_task *task)
+{
+ struct lua_callback_data *cd;
+ struct worker_task **ptask;
+ GList *cur;
+
+ cur = task->cfg->post_filters;
+ while (cur) {
+ cd = cur->data;
+ lua_getglobal (cd->L, cd->name);
+ ptask = lua_newuserdata (cd->L, sizeof (struct worker_task *));
+ lua_setclass (cd->L, "rspamd{task}", -1);
+ *ptask = task;
+
+ if (lua_pcall (cd->L, 1, 1, 0) != 0) {
+ msg_warn ("error running function %s: %s", cd->name, lua_tostring (cd->L, -1));
+ }
+ cur = g_list_next (cur);
+ }
+}
+
+static int
+lua_config_register_post_filter (lua_State *L)
+{
+ struct config_file *cfg = lua_check_config (L);
+ const char *callback;
+ struct lua_callback_data *cd;
+
+ if (cfg) {
+
+ callback = luaL_checkstring (L, 2);
+ if (callback) {
+ cd = g_malloc (sizeof (struct lua_callback_data));
+ cd->name = g_strdup (callback);
+ cd->L = L;
+ cfg->post_filters = g_list_prepend (cfg->post_filters, cd);
+ }
+ }
+ return 1;
}
static int
diff --git a/src/lua/lua_task.c b/src/lua/lua_task.c
index 32c7b41ef..b18cfc21f 100644
--- a/src/lua/lua_task.c
+++ b/src/lua/lua_task.c
@@ -26,8 +26,20 @@
#include "lua_common.h"
#include "../message.h"
#include "../expressions.h"
+#include "../protocol.h"
+#include "../filter.h"
#include "../dns.h"
+#include "../util.h"
#include "../images.h"
+#include "../cfg_file.h"
+#include "../statfile.h"
+#include "../tokenizers/tokenizers.h"
+#include "../classifiers/classifiers.h"
+#include "../binlog.h"
+#include "../statfile_sync.h"
+
+extern stat_file_t* get_statfile_by_symbol (statfile_pool_t *pool, struct classifier_config *ccf,
+ const char *symbol, struct statfile **st, gboolean try_create);
/* Task methods */
LUA_FUNCTION_DEF (task, get_message);
@@ -47,6 +59,10 @@ LUA_FUNCTION_DEF (task, get_from_ip_num);
LUA_FUNCTION_DEF (task, get_client_ip_num);
LUA_FUNCTION_DEF (task, get_helo);
LUA_FUNCTION_DEF (task, get_images);
+LUA_FUNCTION_DEF (task, get_symbol);
+LUA_FUNCTION_DEF (task, get_metric_score);
+LUA_FUNCTION_DEF (task, get_metric_action);
+LUA_FUNCTION_DEF (task, learn_statfile);
static const struct luaL_reg tasklib_m[] = {
LUA_INTERFACE_DEF (task, get_message),
@@ -66,6 +82,10 @@ static const struct luaL_reg tasklib_m[] = {
LUA_INTERFACE_DEF (task, get_client_ip_num),
LUA_INTERFACE_DEF (task, get_helo),
LUA_INTERFACE_DEF (task, get_images),
+ LUA_INTERFACE_DEF (task, get_symbol),
+ LUA_INTERFACE_DEF (task, get_metric_score),
+ LUA_INTERFACE_DEF (task, get_metric_action),
+ LUA_INTERFACE_DEF (task, learn_statfile),
{"__tostring", lua_class_tostring},
{NULL, NULL}
};
@@ -580,6 +600,184 @@ lua_task_get_images (lua_State *L)
return 1;
}
+G_INLINE_FUNC gboolean
+lua_push_symbol_result (lua_State *L, struct worker_task *task, struct metric *metric, const char *symbol)
+{
+ struct metric_result *metric_res;
+ struct symbol *s;
+ int j;
+ GList *opt;
+
+ metric_res = g_hash_table_lookup (task->results, metric->name);
+ if (metric_res) {
+ if ((s = g_hash_table_lookup (metric_res->symbols, symbol)) != NULL) {
+ j = 0;
+ lua_newtable (L);
+ lua_pushstring (L, "metric");
+ lua_pushstring (L, metric->name);
+ lua_settable (L, -3);
+ lua_pushstring (L, "score");
+ lua_pushnumber (L, s->score);
+ lua_settable (L, -3);
+ if (s->options) {
+ opt = s->options;
+ lua_pushstring (L, "options");
+ lua_newtable (L);
+ while (opt) {
+ lua_pushstring (L, opt->data);
+ lua_rawseti (L, -2, j++);
+ opt = g_list_next (opt);
+ }
+ lua_settable (L, -3);
+ }
+
+ return TRUE;
+ }
+ }
+
+ return FALSE;
+}
+
+static int
+lua_task_get_symbol (lua_State *L)
+{
+ struct worker_task *task = lua_check_task (L);
+ const char *symbol;
+ struct metric *metric;
+ GList *cur = NULL, *metric_list;
+ gboolean found = FALSE;
+ int i = 1;
+
+ symbol = luaL_checkstring (L, 2);
+
+ if (task && symbol) {
+ metric_list = g_hash_table_lookup (task->cfg->metrics_symbols, symbol);
+ if (metric_list) {
+ lua_newtable (L);
+ cur = metric_list;
+ }
+ else {
+ metric = task->cfg->default_metric;
+ }
+
+ if (!cur && metric) {
+ if ((found = lua_push_symbol_result (L, task, metric, symbol))) {
+ lua_newtable (L);
+ lua_rawseti (L, -2, i++);
+ }
+ }
+ else {
+ while (cur) {
+ metric = cur->data;
+ if (lua_push_symbol_result (L, task, metric, symbol)) {
+ lua_rawseti (L, -2, i++);
+ found = TRUE;
+ }
+ cur = g_list_next (cur);
+ }
+ }
+ }
+
+ if (!found) {
+ lua_pushnil (L);
+ }
+ return 1;
+}
+
+static int
+lua_task_learn_statfile (lua_State *L)
+{
+ struct worker_task *task = lua_check_task (L);
+ const char *symbol;
+ struct classifier_config *cl;
+ GTree *tokens;
+ struct statfile *st;
+ stat_file_t *statfile;
+ struct classifier_ctx *ctx;
+
+ symbol = luaL_checkstring (L, 2);
+
+ if (task && symbol) {
+ cl = g_hash_table_lookup (task->cfg->classifiers_symbols, symbol);
+ if (cl == NULL) {
+ msg_warn ("classifier for symbol %s is not found", symbol);
+ lua_pushboolean (L, FALSE);
+ return 1;
+ }
+ ctx = cl->classifier->init_func (task->task_pool, cl);
+ if ((tokens = g_hash_table_lookup (task->tokens, cl->tokenizer)) == NULL) {
+ msg_warn ("no tokens found learn failed!");
+ lua_pushboolean (L, FALSE);
+ return 1;
+ }
+ statfile = get_statfile_by_symbol (task->worker->srv->statfile_pool, ctx->cfg,
+ symbol, &st, TRUE);
+
+ if (statfile == NULL) {
+ msg_warn ("opening statfile failed!");
+ lua_pushboolean (L, FALSE);
+ return 1;
+ }
+
+ cl->classifier->learn_func (ctx, task->worker->srv->statfile_pool, symbol, tokens, TRUE, NULL, 1., NULL);
+ maybe_write_binlog (ctx->cfg, st, statfile, tokens);
+ lua_pushboolean (L, TRUE);
+ }
+
+ return 1;
+}
+
+static int
+lua_task_get_metric_score (lua_State *L)
+{
+ struct worker_task *task = lua_check_task (L);
+ const char *metric_name;
+ struct metric_result *metric_res;
+
+ metric_name = luaL_checkstring (L, 2);
+
+ if (task && metric_name) {
+ if ((metric_res = g_hash_table_lookup (task->results, metric_name)) != NULL) {
+ lua_newtable (L);
+ lua_pushnumber (L, metric_res->score);
+ lua_rawseti (L, -2, 1);
+ lua_pushnumber (L, metric_res->metric->required_score);
+ lua_rawseti (L, -2, 2);
+ lua_pushnumber (L, metric_res->metric->reject_score);
+ lua_rawseti (L, -2, 3);
+ }
+ else {
+ lua_pushnil (L);
+ }
+ return 1;
+ }
+
+ return 0;
+}
+
+static int
+lua_task_get_metric_action (lua_State *L)
+{
+ struct worker_task *task = lua_check_task (L);
+ const char *metric_name;
+ struct metric_result *metric_res;
+ enum rspamd_metric_action action;
+
+ metric_name = luaL_checkstring (L, 2);
+
+ if (task && metric_name) {
+ if ((metric_res = g_hash_table_lookup (task->results, metric_name)) != NULL) {
+ action = check_metric_action (metric_res->score, metric_res->metric->required_score, metric_res->metric);
+ lua_pushstring (L, str_action_metric (action));
+ }
+ else {
+ lua_pushnil (L);
+ }
+ return 1;
+ }
+
+ return 0;
+}
/**** Textpart implementation *****/