aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/cfg_file.h1
-rw-r--r--src/filter.c31
-rw-r--r--src/lua/lua_common.h1
-rw-r--r--src/lua/lua_config.c46
-rw-r--r--src/lua/lua_task.c198
-rw-r--r--src/main.h2
-rw-r--r--src/protocol.c4
-rw-r--r--src/protocol.h5
-rw-r--r--src/worker.c4
9 files changed, 270 insertions, 22 deletions
diff --git a/src/cfg_file.h b/src/cfg_file.h
index b2ca61150..d770c237e 100644
--- a/src/cfg_file.h
+++ b/src/cfg_file.h
@@ -289,6 +289,7 @@ struct config_file {
GHashTable *classifiers_symbols; /**< hashtable indexed by symbol name of classifiers */
GHashTable* cfg_params; /**< all cfg params indexed by its name in this structure */
GList *views; /**< views */
+ GList *post_filters; /**< list of post-processing lua filters */
GHashTable* domain_settings; /**< settings per-domains */
GHashTable* user_settings; /**< settings per-user */
gchar* domain_settings_str; /**< string representation of settings */
diff --git a/src/filter.c b/src/filter.c
index 90566ded9..7d1de3d20 100644
--- a/src/filter.c
+++ b/src/filter.c
@@ -243,6 +243,9 @@ continue_process_filters (struct worker_task *task)
end:
/* Process all statfiles */
process_statfiles (task);
+ /* Call post filters */
+ lua_call_post_filters (task);
+ task->state = WRITE_REPLY;
/* XXX: ugly direct call */
if (task->fin_callback) {
task->fin_callback (task->fin_arg);
@@ -293,6 +296,7 @@ process_filters (struct worker_task *task)
else if (!task->pass_all_filters &&
metric->action == METRIC_ACTION_REJECT &&
check_metric_is_spam (task, metric)) {
+ task->state = WRITE_REPLY;
return 1;
}
cur = g_list_next (cur);
@@ -463,17 +467,10 @@ make_composites (struct worker_task *task)
g_hash_table_foreach (task->results, composites_metric_callback, task);
}
-
-struct statfile_callback_data {
- GHashTable *tokens;
- struct worker_task *task;
-};
-
static void
classifiers_callback (gpointer value, void *arg)
{
- struct statfile_callback_data *data = (struct statfile_callback_data *)arg;
- struct worker_task *task = data->task;
+ struct worker_task *task = arg;
struct classifier_config *cl = value;
struct classifier_ctx *ctx;
struct mime_text_part *text_part;
@@ -494,7 +491,7 @@ classifiers_callback (gpointer value, void *arg)
}
ctx = cl->classifier->init_func (task->task_pool, cl);
- if ((tokens = g_hash_table_lookup (data->tokens, cl->tokenizer)) == NULL) {
+ if ((tokens = g_hash_table_lookup (task->tokens, cl->tokenizer)) == NULL) {
while (cur != NULL) {
if (header) {
c.len = strlen (cur->data);
@@ -522,7 +519,7 @@ classifiers_callback (gpointer value, void *arg)
}
cur = g_list_next (cur);
}
- g_hash_table_insert (data->tokens, cl->tokenizer, tokens);
+ g_hash_table_insert (task->tokens, cl->tokenizer, tokens);
}
if (tokens == NULL) {
@@ -549,20 +546,20 @@ classifiers_callback (gpointer value, void *arg)
void
process_statfiles (struct worker_task *task)
{
- struct statfile_callback_data cd;
-
+
if (task->is_skipped) {
return;
}
- cd.task = task;
- cd.tokens = g_hash_table_new (g_direct_hash, g_direct_equal);
- g_list_foreach (task->cfg->classifiers, classifiers_callback, &cd);
- g_hash_table_destroy (cd.tokens);
+ if (task->tokens == NULL) {
+ task->tokens = g_hash_table_new (g_direct_hash, g_direct_equal);
+ memory_pool_add_destructor (task->task_pool, (pool_destruct_func)g_hash_table_destroy, task->tokens);
+ }
+
+ g_list_foreach (task->cfg->classifiers, classifiers_callback, task);
/* Process results */
make_composites (task);
- task->state = WRITE_REPLY;
}
static void
diff --git a/src/lua/lua_common.h b/src/lua/lua_common.h
index d1f5f9eb4..18fa6481d 100644
--- a/src/lua/lua_common.h
+++ b/src/lua/lua_common.h
@@ -39,6 +39,7 @@ int lua_call_filter (const char *function, struct worker_task *task);
int lua_call_chain_filter (const char *function, struct worker_task *task, int *marks, unsigned int number);
double lua_consolidation_func (struct worker_task *task, const char *metric_name, const char *function_name);
gboolean lua_call_expression_func (const char *function, struct worker_task *task, GList *args, gboolean *res);
+void lua_call_post_filters (struct worker_task *task);
void add_luabuf (const char *line);
/* Classify functions */
diff --git a/src/lua/lua_config.c b/src/lua/lua_config.c
index 5e0148ca7..7ae7d3e0e 100644
--- a/src/lua/lua_config.c
+++ b/src/lua/lua_config.c
@@ -37,6 +37,7 @@ LUA_FUNCTION_DEF (config, add_radix_map);
LUA_FUNCTION_DEF (config, add_hash_map);
LUA_FUNCTION_DEF (config, get_classifier);
LUA_FUNCTION_DEF (config, register_symbol);
+LUA_FUNCTION_DEF (config, register_post_filter);
static const struct luaL_reg configlib_m[] = {
LUA_INTERFACE_DEF (config, get_module_opt),
@@ -46,6 +47,7 @@ static const struct luaL_reg configlib_m[] = {
LUA_INTERFACE_DEF (config, add_hash_map),
LUA_INTERFACE_DEF (config, get_classifier),
LUA_INTERFACE_DEF (config, register_symbol),
+ LUA_INTERFACE_DEF (config, register_post_filter),
{"__tostring", lua_class_tostring},
{NULL, NULL}
};
@@ -291,7 +293,49 @@ lua_config_register_function (lua_State *L)
register_expression_function (name, lua_config_function_callback, cd);
}
}
- return 0;
+ return 1;
+}
+
+void
+lua_call_post_filters (struct worker_task *task)
+{
+ struct lua_callback_data *cd;
+ struct worker_task **ptask;
+ GList *cur;
+
+ cur = task->cfg->post_filters;
+ while (cur) {
+ cd = cur->data;
+ lua_getglobal (cd->L, cd->name);
+ ptask = lua_newuserdata (cd->L, sizeof (struct worker_task *));
+ lua_setclass (cd->L, "rspamd{task}", -1);
+ *ptask = task;
+
+ if (lua_pcall (cd->L, 1, 1, 0) != 0) {
+ msg_warn ("error running function %s: %s", cd->name, lua_tostring (cd->L, -1));
+ }
+ cur = g_list_next (cur);
+ }
+}
+
+static int
+lua_config_register_post_filter (lua_State *L)
+{
+ struct config_file *cfg = lua_check_config (L);
+ const char *callback;
+ struct lua_callback_data *cd;
+
+ if (cfg) {
+
+ callback = luaL_checkstring (L, 2);
+ if (callback) {
+ cd = g_malloc (sizeof (struct lua_callback_data));
+ cd->name = g_strdup (callback);
+ cd->L = L;
+ cfg->post_filters = g_list_prepend (cfg->post_filters, cd);
+ }
+ }
+ return 1;
}
static int
diff --git a/src/lua/lua_task.c b/src/lua/lua_task.c
index 32c7b41ef..b18cfc21f 100644
--- a/src/lua/lua_task.c
+++ b/src/lua/lua_task.c
@@ -26,8 +26,20 @@
#include "lua_common.h"
#include "../message.h"
#include "../expressions.h"
+#include "../protocol.h"
+#include "../filter.h"
#include "../dns.h"
+#include "../util.h"
#include "../images.h"
+#include "../cfg_file.h"
+#include "../statfile.h"
+#include "../tokenizers/tokenizers.h"
+#include "../classifiers/classifiers.h"
+#include "../binlog.h"
+#include "../statfile_sync.h"
+
+extern stat_file_t* get_statfile_by_symbol (statfile_pool_t *pool, struct classifier_config *ccf,
+ const char *symbol, struct statfile **st, gboolean try_create);
/* Task methods */
LUA_FUNCTION_DEF (task, get_message);
@@ -47,6 +59,10 @@ LUA_FUNCTION_DEF (task, get_from_ip_num);
LUA_FUNCTION_DEF (task, get_client_ip_num);
LUA_FUNCTION_DEF (task, get_helo);
LUA_FUNCTION_DEF (task, get_images);
+LUA_FUNCTION_DEF (task, get_symbol);
+LUA_FUNCTION_DEF (task, get_metric_score);
+LUA_FUNCTION_DEF (task, get_metric_action);
+LUA_FUNCTION_DEF (task, learn_statfile);
static const struct luaL_reg tasklib_m[] = {
LUA_INTERFACE_DEF (task, get_message),
@@ -66,6 +82,10 @@ static const struct luaL_reg tasklib_m[] = {
LUA_INTERFACE_DEF (task, get_client_ip_num),
LUA_INTERFACE_DEF (task, get_helo),
LUA_INTERFACE_DEF (task, get_images),
+ LUA_INTERFACE_DEF (task, get_symbol),
+ LUA_INTERFACE_DEF (task, get_metric_score),
+ LUA_INTERFACE_DEF (task, get_metric_action),
+ LUA_INTERFACE_DEF (task, learn_statfile),
{"__tostring", lua_class_tostring},
{NULL, NULL}
};
@@ -580,6 +600,184 @@ lua_task_get_images (lua_State *L)
return 1;
}
+G_INLINE_FUNC gboolean
+lua_push_symbol_result (lua_State *L, struct worker_task *task, struct metric *metric, const char *symbol)
+{
+ struct metric_result *metric_res;
+ struct symbol *s;
+ int j;
+ GList *opt;
+
+ metric_res = g_hash_table_lookup (task->results, metric->name);
+ if (metric_res) {
+ if ((s = g_hash_table_lookup (metric_res->symbols, symbol)) != NULL) {
+ j = 0;
+ lua_newtable (L);
+ lua_pushstring (L, "metric");
+ lua_pushstring (L, metric->name);
+ lua_settable (L, -3);
+ lua_pushstring (L, "score");
+ lua_pushnumber (L, s->score);
+ lua_settable (L, -3);
+ if (s->options) {
+ opt = s->options;
+ lua_pushstring (L, "options");
+ lua_newtable (L);
+ while (opt) {
+ lua_pushstring (L, opt->data);
+ lua_rawseti (L, -2, j++);
+ opt = g_list_next (opt);
+ }
+ lua_settable (L, -3);
+ }
+
+ return TRUE;
+ }
+ }
+
+ return FALSE;
+}
+
+static int
+lua_task_get_symbol (lua_State *L)
+{
+ struct worker_task *task = lua_check_task (L);
+ const char *symbol;
+ struct metric *metric;
+ GList *cur = NULL, *metric_list;
+ gboolean found = FALSE;
+ int i = 1;
+
+ symbol = luaL_checkstring (L, 2);
+
+ if (task && symbol) {
+ metric_list = g_hash_table_lookup (task->cfg->metrics_symbols, symbol);
+ if (metric_list) {
+ lua_newtable (L);
+ cur = metric_list;
+ }
+ else {
+ metric = task->cfg->default_metric;
+ }
+
+ if (!cur && metric) {
+ if ((found = lua_push_symbol_result (L, task, metric, symbol))) {
+ lua_newtable (L);
+ lua_rawseti (L, -2, i++);
+ }
+ }
+ else {
+ while (cur) {
+ metric = cur->data;
+ if (lua_push_symbol_result (L, task, metric, symbol)) {
+ lua_rawseti (L, -2, i++);
+ found = TRUE;
+ }
+ cur = g_list_next (cur);
+ }
+ }
+ }
+
+ if (!found) {
+ lua_pushnil (L);
+ }
+ return 1;
+}
+
+static int
+lua_task_learn_statfile (lua_State *L)
+{
+ struct worker_task *task = lua_check_task (L);
+ const char *symbol;
+ struct classifier_config *cl;
+ GTree *tokens;
+ struct statfile *st;
+ stat_file_t *statfile;
+ struct classifier_ctx *ctx;
+
+ symbol = luaL_checkstring (L, 2);
+
+ if (task && symbol) {
+ cl = g_hash_table_lookup (task->cfg->classifiers_symbols, symbol);
+ if (cl == NULL) {
+ msg_warn ("classifier for symbol %s is not found", symbol);
+ lua_pushboolean (L, FALSE);
+ return 1;
+ }
+ ctx = cl->classifier->init_func (task->task_pool, cl);
+ if ((tokens = g_hash_table_lookup (task->tokens, cl->tokenizer)) == NULL) {
+ msg_warn ("no tokens found learn failed!");
+ lua_pushboolean (L, FALSE);
+ return 1;
+ }
+ statfile = get_statfile_by_symbol (task->worker->srv->statfile_pool, ctx->cfg,
+ symbol, &st, TRUE);
+
+ if (statfile == NULL) {
+ msg_warn ("opening statfile failed!");
+ lua_pushboolean (L, FALSE);
+ return 1;
+ }
+
+ cl->classifier->learn_func (ctx, task->worker->srv->statfile_pool, symbol, tokens, TRUE, NULL, 1., NULL);
+ maybe_write_binlog (ctx->cfg, st, statfile, tokens);
+ lua_pushboolean (L, TRUE);
+ }
+
+ return 1;
+}
+
+static int
+lua_task_get_metric_score (lua_State *L)
+{
+ struct worker_task *task = lua_check_task (L);
+ const char *metric_name;
+ struct metric_result *metric_res;
+
+ metric_name = luaL_checkstring (L, 2);
+
+ if (task && metric_name) {
+ if ((metric_res = g_hash_table_lookup (task->results, metric_name)) != NULL) {
+ lua_newtable (L);
+ lua_pushnumber (L, metric_res->score);
+ lua_rawseti (L, -2, 1);
+ lua_pushnumber (L, metric_res->metric->required_score);
+ lua_rawseti (L, -2, 2);
+ lua_pushnumber (L, metric_res->metric->reject_score);
+ lua_rawseti (L, -2, 3);
+ }
+ else {
+ lua_pushnil (L);
+ }
+ return 1;
+ }
+
+ return 0;
+}
+
+static int
+lua_task_get_metric_action (lua_State *L)
+{
+ struct worker_task *task = lua_check_task (L);
+ const char *metric_name;
+ struct metric_result *metric_res;
+ enum rspamd_metric_action action;
+
+ metric_name = luaL_checkstring (L, 2);
+
+ if (task && metric_name) {
+ if ((metric_res = g_hash_table_lookup (task->results, metric_name)) != NULL) {
+ action = check_metric_action (metric_res->score, metric_res->metric->required_score, metric_res->metric);
+ lua_pushstring (L, str_action_metric (action));
+ }
+ else {
+ lua_pushnil (L);
+ }
+ return 1;
+ }
+
+ return 0;
+}
/**** Textpart implementation *****/
diff --git a/src/main.h b/src/main.h
index b54d3ed8a..e26f3fbd0 100644
--- a/src/main.h
+++ b/src/main.h
@@ -205,6 +205,8 @@ struct worker_task {
GList *images; /**< list of images */
GHashTable *results; /**< hash table of metric_result indexed by
* metric's name */
+ GHashTable *tokens; /**< hash table of tokens indexed by tokenizer
+ * pointer */
GList *messages; /**< list of messages that would be reported */
GHashTable *re_cache; /**< cache for matched or not matched regexps */
struct config_file *cfg; /**< pointer to config object */
diff --git a/src/protocol.c b/src/protocol.c
index 63b1bee5b..cece0bab7 100644
--- a/src/protocol.c
+++ b/src/protocol.c
@@ -620,7 +620,7 @@ show_metric_symbols (struct metric_result *metric_res, struct metric_callback_da
return TRUE;
}
-G_INLINE_FUNC const char *
+const char *
str_action_metric (enum rspamd_metric_action action)
{
switch (action) {
@@ -641,7 +641,7 @@ str_action_metric (enum rspamd_metric_action action)
return "unknown action";
}
-G_INLINE_FUNC gint
+gint
check_metric_action (double score, double required_score, struct metric *metric)
{
GList *cur;
diff --git a/src/protocol.h b/src/protocol.h
index c216259f6..affcccd5c 100644
--- a/src/protocol.h
+++ b/src/protocol.h
@@ -27,6 +27,8 @@
#define SPAMD_ERROR "EX_ERROR"
struct worker_task;
+enum rspamd_metric_action;
+struct metric;
enum rspamd_protocol {
SPAMC_PROTO,
@@ -75,4 +77,7 @@ gboolean write_reply (struct worker_task *task) G_GNUC_WARN_UNUSED_RESULT;
*/
void register_protocol_command (const char *name, protocol_reply_func func);
+const char *str_action_metric (enum rspamd_metric_action action);
+gint check_metric_action (double score, double required_score, struct metric *metric);
+
#endif
diff --git a/src/worker.c b/src/worker.c
index cbd234492..427609458 100644
--- a/src/worker.c
+++ b/src/worker.c
@@ -38,7 +38,7 @@
#include "map.h"
#include "dns.h"
-#include "evdns/evdns.h"
+#include "lua/lua_common.h"
#ifndef WITHOUT_PERL
# include <EXTERN.h> /* from the Perl distribution */
@@ -359,6 +359,7 @@ read_socket (f_str_t * in, void *arg)
}
else {
process_statfiles (task);
+ lua_call_post_filters (task);
return write_socket (task);
}
break;
@@ -666,7 +667,6 @@ start_worker (struct rspamd_worker *worker)
worker->srv->pid = getpid ();
event_init ();
- evdns_init ();
init_signals (&signals, sig_handler);
sigprocmask (SIG_UNBLOCK, &signals.sa_mask, NULL);