summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/cfg_file.h1
-rw-r--r--src/lua/lua_common.c18
-rw-r--r--src/lua/lua_common.h3
-rw-r--r--src/lua/lua_config.c57
-rw-r--r--src/lua/lua_task.c25
-rw-r--r--src/main.h6
-rw-r--r--src/protocol.c71
-rw-r--r--src/worker.c95
8 files changed, 230 insertions, 46 deletions
diff --git a/src/cfg_file.h b/src/cfg_file.h
index 203558e9a..c4e820997 100644
--- a/src/cfg_file.h
+++ b/src/cfg_file.h
@@ -312,6 +312,7 @@ struct config_file {
GHashTable *classifiers_symbols; /**< hashtable indexed by symbol name of classifiers */
GHashTable* cfg_params; /**< all cfg params indexed by its name in this structure */
GList *views; /**< views */
+ GList *pre_filters; /**< list of pre-processing lua filters */
GList *post_filters; /**< list of post-processing lua filters */
GHashTable* domain_settings; /**< settings per-domains */
GHashTable* user_settings; /**< settings per-user */
diff --git a/src/lua/lua_common.c b/src/lua/lua_common.c
index 8fbff979c..4a62a7d02 100644
--- a/src/lua/lua_common.c
+++ b/src/lua/lua_common.c
@@ -219,6 +219,22 @@ luaopen_logger (lua_State * L)
return 1;
}
+static void
+lua_add_actions_global (lua_State *L)
+{
+ gint i;
+
+ lua_newtable (L);
+
+ for (i = METRIC_ACTION_REJECT; i <= METRIC_ACTION_NOACTION; i ++) {
+ lua_pushstring (L, str_action_metric (i));
+ lua_pushinteger (L, i);
+ lua_settable (L, -3);
+ }
+ /* Set global table */
+ lua_setglobal (L, "rspamd_actions");
+}
+
void
init_lua (struct config_file *cfg)
{
@@ -253,6 +269,8 @@ init_lua (struct config_file *cfg)
(void)luaopen_http (L);
(void)luaopen_redis (L);
(void)luaopen_upstream (L);
+ (void)lua_add_actions_global (L);
+
cfg->lua_state = L;
memory_pool_add_destructor (cfg->cfg_pool, (pool_destruct_func)lua_close, L);
diff --git a/src/lua/lua_common.h b/src/lua/lua_common.h
index adb87135e..c138a35a2 100644
--- a/src/lua/lua_common.h
+++ b/src/lua/lua_common.h
@@ -17,7 +17,7 @@
extern const luaL_reg null_reg[];
extern GMutex *lua_mtx;
-#define RSPAMD_LUA_API_VERSION 10
+#define RSPAMD_LUA_API_VERSION 11
/* Common utility functions */
@@ -71,6 +71,7 @@ gint lua_call_chain_filter (const gchar *function, struct worker_task *task, gin
double lua_consolidation_func (struct worker_task *task, const gchar *metric_name, const gchar *function_name);
gboolean lua_call_expression_func (const gchar *module, const gchar *symbol, struct worker_task *task, GList *args, gboolean *res);
void lua_call_post_filters (struct worker_task *task);
+void lua_call_pre_filters (struct worker_task *task);
void add_luabuf (const gchar *line);
/* Classify functions */
diff --git a/src/lua/lua_config.c b/src/lua/lua_config.c
index bdb4ce056..0e63f6421 100644
--- a/src/lua/lua_config.c
+++ b/src/lua/lua_config.c
@@ -44,6 +44,7 @@ LUA_FUNCTION_DEF (config, register_symbol);
LUA_FUNCTION_DEF (config, register_virtual_symbol);
LUA_FUNCTION_DEF (config, register_callback_symbol);
LUA_FUNCTION_DEF (config, register_callback_symbol_priority);
+LUA_FUNCTION_DEF (config, register_pre_filter);
LUA_FUNCTION_DEF (config, register_post_filter);
LUA_FUNCTION_DEF (config, register_module_option);
LUA_FUNCTION_DEF (config, get_api_version);
@@ -61,6 +62,7 @@ static const struct luaL_reg configlib_m[] = {
LUA_INTERFACE_DEF (config, register_callback_symbol),
LUA_INTERFACE_DEF (config, register_callback_symbol_priority),
LUA_INTERFACE_DEF (config, register_module_option),
+ LUA_INTERFACE_DEF (config, register_pre_filter),
LUA_INTERFACE_DEF (config, register_post_filter),
LUA_INTERFACE_DEF (config, get_api_version),
{"__tostring", lua_class_tostring},
@@ -489,6 +491,61 @@ lua_config_register_post_filter (lua_State *L)
return 1;
}
+void
+lua_call_pre_filters (struct worker_task *task)
+{
+ struct lua_callback_data *cd;
+ struct worker_task **ptask;
+ GList *cur;
+
+ g_mutex_lock (lua_mtx);
+ cur = task->cfg->pre_filters;
+ while (cur) {
+ cd = cur->data;
+ if (cd->cb_is_ref) {
+ lua_rawgeti (cd->L, LUA_REGISTRYINDEX, cd->callback.ref);
+ }
+ else {
+ lua_getglobal (cd->L, cd->callback.name);
+ }
+ ptask = lua_newuserdata (cd->L, sizeof (struct worker_task *));
+ lua_setclass (cd->L, "rspamd{task}", -1);
+ *ptask = task;
+
+ if (lua_pcall (cd->L, 1, 0, 0) != 0) {
+ msg_info ("call to %s failed: %s", cd->cb_is_ref ? "local function" :
+ cd->callback.name, lua_tostring (cd->L, -1));
+ }
+ cur = g_list_next (cur);
+ }
+ g_mutex_unlock (lua_mtx);
+}
+
+static gint
+lua_config_register_pre_filter (lua_State *L)
+{
+ struct config_file *cfg = lua_check_config (L);
+ struct lua_callback_data *cd;
+
+ if (cfg) {
+ cd = memory_pool_alloc (cfg->cfg_pool, sizeof (struct lua_callback_data));
+ if (lua_type (L, 2) == LUA_TSTRING) {
+ cd->callback.name = memory_pool_strdup (cfg->cfg_pool, luaL_checkstring (L, 2));
+ cd->cb_is_ref = FALSE;
+ }
+ else {
+ lua_pushvalue (L, 2);
+ /* Get a reference */
+ cd->callback.ref = luaL_ref (L, LUA_REGISTRYINDEX);
+ cd->cb_is_ref = TRUE;
+ }
+ cd->L = L;
+ cfg->pre_filters = g_list_prepend (cfg->pre_filters, cd);
+ memory_pool_add_destructor (cfg->cfg_pool, (pool_destruct_func)lua_destroy_cfg_symbol, cd);
+ }
+ return 1;
+}
+
static gint
lua_config_add_radix_map (lua_State *L)
{
diff --git a/src/lua/lua_task.c b/src/lua/lua_task.c
index d28d897f6..d205016e0 100644
--- a/src/lua/lua_task.c
+++ b/src/lua/lua_task.c
@@ -45,6 +45,7 @@ extern stat_file_t* get_statfile_by_symbol (statfile_pool_t *pool, struct classi
/* Task methods */
LUA_FUNCTION_DEF (task, get_message);
LUA_FUNCTION_DEF (task, insert_result);
+LUA_FUNCTION_DEF (task, set_pre_result);
LUA_FUNCTION_DEF (task, get_urls);
LUA_FUNCTION_DEF (task, get_emails);
LUA_FUNCTION_DEF (task, get_text_parts);
@@ -75,6 +76,7 @@ LUA_FUNCTION_DEF (task, learn_statfile);
static const struct luaL_reg tasklib_m[] = {
LUA_INTERFACE_DEF (task, get_message),
LUA_INTERFACE_DEF (task, insert_result),
+ LUA_INTERFACE_DEF (task, set_pre_result),
LUA_INTERFACE_DEF (task, get_urls),
LUA_INTERFACE_DEF (task, get_emails),
LUA_INTERFACE_DEF (task, get_text_parts),
@@ -233,6 +235,29 @@ lua_task_insert_result (lua_State * L)
return 1;
}
+static gint
+lua_task_set_pre_result (lua_State * L)
+{
+ struct worker_task *task = lua_check_task (L);
+ gchar *action_str;
+ guint action;
+
+ if (task != NULL) {
+ action = luaL_checkinteger (L, 2);
+ if (action < task->pre_result.action) {
+ task->pre_result.action = action;
+ if (lua_gettop (L) >= 3) {
+ action_str = memory_pool_strdup (task->task_pool, luaL_checkstring (L, 3));
+ task->pre_result.str = action_str;
+ }
+ else {
+ task->pre_result.str = NULL;
+ }
+ }
+ }
+ return 1;
+}
+
struct lua_tree_cb_data {
lua_State *L;
int i;
diff --git a/src/main.h b/src/main.h
index 2331dbdb5..22b2f7584 100644
--- a/src/main.h
+++ b/src/main.h
@@ -179,6 +179,7 @@ struct worker_task {
READ_MESSAGE,
WRITE_REPLY,
WRITE_ERROR,
+ WAIT_PRE_FILTER,
WAIT_FILTER,
WAIT_POST_FILTER,
CLOSING_CONNECTION,
@@ -248,6 +249,11 @@ struct worker_task {
struct rspamd_dns_resolver *resolver; /**< DNS resolver */
struct event_base *ev_base; /**< Event base */
+
+ struct {
+ enum rspamd_metric_action action; /**< Action of pre filters */
+ gchar *str; /**< String describing action */
+ } pre_result; /**< Result of pre-filters */
};
/**
diff --git a/src/protocol.c b/src/protocol.c
index b4dcd7d13..40bd01e0c 100644
--- a/src/protocol.c
+++ b/src/protocol.c
@@ -988,32 +988,48 @@ print_metric_data_rspamc (struct worker_task *task, gchar *outbuf, gsize size,
{
gint r = 0;
gboolean is_spam = FALSE;
+ gchar *local_act;
if (metric_res == NULL) {
+ /* This is case when we got reject result from pre filters */
if (task->proto == SPAMC_PROTO) {
r = rspamd_snprintf (outbuf, size,
"Spam: False ; 0.00 / %.2f" CRLF, ms);
}
else {
+ local_act = "False";
+ msg_info ("action: %s", str_action_metric (task->pre_result.action));
+ if (task->pre_result.action <= METRIC_ACTION_SOFT_REJECT) {
+ local_act = "True";
+ }
+
if (task->proto_ver >= 11) {
- if (!task->is_skipped) {
- r = rspamd_snprintf (outbuf, size,
- "Metric: default; False; 0.00 / %.2f / %.2f" CRLF, ms,
- rs);
- }
- else {
- r = rspamd_snprintf (outbuf, size,
- "Metric: default; Skip; 0.00 / %.2f / %.2f" CRLF, ms,
- rs);
+ if (task->is_skipped) {
+ local_act = "Skip";
}
+ r = rspamd_snprintf (outbuf, size,
+ "Metric: default; %s; 0.00 / %.2f / %.2f" CRLF, local_act, ms,
+ rs);
}
else {
r = rspamd_snprintf (outbuf, size,
- "Metric: default; False; 0.00 / %.2f" CRLF, ms);
+ "Metric: default; %s; 0.00 / %.2f" CRLF, local_act, ms);
}
- r += rspamd_snprintf (outbuf + r, size - r,
+
+ if (task->pre_result.action == METRIC_ACTION_NOACTION) {
+ r += rspamd_snprintf (outbuf + r, size - r,
"Action: %s" CRLF, str_action_metric (
METRIC_ACTION_NOACTION));
+ }
+ else {
+ r += rspamd_snprintf (outbuf + r, size - r,
+ "Action: %s" CRLF, str_action_metric (
+ task->pre_result.action));
+ if (task->pre_result.str != NULL) {
+ r += rspamd_snprintf (outbuf + r, size - r,
+ "Message: %s" CRLF, task->pre_result.str);
+ }
+ }
}
}
else {
@@ -1067,29 +1083,38 @@ print_metric_data_json (struct worker_task *task, gchar *outbuf, gsize size,
enum rspamd_metric_action action)
{
gint r = 0;
+ gchar *local_act;
-
+ if (task->pre_result.action == METRIC_ACTION_NOACTION) {
+ local_act = "False";
+ }
+ else if (task->pre_result.action <= METRIC_ACTION_SOFT_REJECT) {
+ local_act = "True";
+ }
if (metric_res == NULL) {
- r = rspamd_snprintf (outbuf, size,
- " {" CRLF " \"name\": \"default\"," CRLF
- " \"is_spam\": false," CRLF
- " \"is_skipped\": %s," CRLF
- " \"score\": 0.00," CRLF
- " \"required_score\": %.2f," CRLF
- " \"reject_score\": %.2f," CRLF
- " \"action\": \"%s\"," CRLF,
- task->is_skipped ? "true" : "false", ms, rs,
- str_action_metric (METRIC_ACTION_NOACTION));
+ /* This is case when we got reject result from pre filters */
+ r = rspamd_snprintf (outbuf, size,
+ " {" CRLF " \"name\": \"default\"," CRLF
+ " \"is_spam\": %s," CRLF
+ " \"is_skipped\": %s," CRLF
+ " \"score\": 0.00," CRLF
+ " \"required_score\": %.2f," CRLF
+ " \"reject_score\": %.2f," CRLF
+ " \"action\": \"%s\"," CRLF,
+ local_act,
+ task->is_skipped ? "true" : "false", ms, rs,
+ str_action_metric (task->pre_result.action));
}
else {
r = rspamd_snprintf (outbuf, size,
- " {" CRLF " \"name\": \"default\"," CRLF
+ " {" CRLF " \"name\": \"%s\"," CRLF
" \"is_spam\": %s," CRLF
" \"is_skipped\": %s," CRLF
" \"score\": %.2f," CRLF
" \"required_score\": %.2f," CRLF
" \"reject_score\": %.2f," CRLF
" \"action\": \"%s\"," CRLF,
+ metric_res->metric->name,
metric_res->score >= ms ? "true" : "false",
metric_res->score,
task->is_skipped ? "true" : "false", ms, rs,
diff --git a/src/worker.c b/src/worker.c
index d56c2d924..dad842ce3 100644
--- a/src/worker.c
+++ b/src/worker.c
@@ -247,6 +247,7 @@ construct_task (struct rspamd_worker *worker)
new_task->urls);
new_task->sock = -1;
new_task->is_mime = TRUE;
+ new_task->pre_result.action = METRIC_ACTION_NOACTION;
return new_task;
}
@@ -497,22 +498,31 @@ read_socket (f_str_t * in, void *arg)
return write_socket (task);
}
else {
- r = process_filters (task);
- if (r == -1) {
- task->last_error = "Filter processing error";
- task->error_code = RSPAMD_FILTER_ERROR;
- task->state = WRITE_ERROR;
- return write_socket (task);
- }
- /* Add task to classify to classify pool */
- if (ctx->classify_pool) {
- register_async_thread (task->s);
- g_thread_pool_push (ctx->classify_pool, task, &err);
- if (err != NULL) {
- msg_err ("cannot pull task to the pool: %s", err->message);
- remove_async_thread (task->s);
+ if (task->cfg->pre_filters == NULL) {
+ r = process_filters (task);
+ if (r == -1) {
+ task->last_error = "Filter processing error";
+ task->error_code = RSPAMD_FILTER_ERROR;
+ task->state = WRITE_ERROR;
+ return write_socket (task);
+ }
+ /* Add task to classify to classify pool */
+ if (ctx->classify_pool) {
+ register_async_thread (task->s);
+ g_thread_pool_push (ctx->classify_pool, task, &err);
+ if (err != NULL) {
+ msg_err ("cannot pull task to the pool: %s", err->message);
+ remove_async_thread (task->s);
+ }
}
}
+ else {
+ lua_call_pre_filters (task);
+ /* We want fin_task after pre filters are processed */
+ task->s->wanna_die = TRUE;
+ task->state = WAIT_PRE_FILTER;
+ check_session_pending (task->s);
+ }
}
break;
case WRITE_REPLY:
@@ -521,6 +531,7 @@ read_socket (f_str_t * in, void *arg)
break;
case WAIT_FILTER:
case WAIT_POST_FILTER:
+ case WAIT_PRE_FILTER:
msg_info ("ignoring trailing garbadge of size %z", in->len);
break;
default:
@@ -539,6 +550,8 @@ write_socket (void *arg)
{
struct worker_task *task = (struct worker_task *) arg;
struct rspamd_worker_ctx *ctx;
+ GError *err = NULL;
+ gint r;
ctx = task->worker->ctx;
@@ -578,6 +591,25 @@ write_socket (void *arg)
case WAIT_POST_FILTER:
/* Do nothing here */
break;
+ case WAIT_PRE_FILTER:
+ task->state = WAIT_FILTER;
+ r = process_filters (task);
+ if (r == -1) {
+ task->last_error = "Filter processing error";
+ task->error_code = RSPAMD_FILTER_ERROR;
+ task->state = WRITE_ERROR;
+ return write_socket (task);
+ }
+ /* Add task to classify to classify pool */
+ if (ctx->classify_pool) {
+ register_async_thread (task->s);
+ g_thread_pool_push (ctx->classify_pool, task, &err);
+ if (err != NULL) {
+ msg_err ("cannot pull task to the pool: %s", err->message);
+ remove_async_thread (task->s);
+ }
+ }
+ break;
default:
msg_info ("abnormally closing connection at state: %d", task->state);
if (ctx->is_custom) {
@@ -616,11 +648,12 @@ err_socket (GError * err, void *arg)
static gboolean
fin_task (void *arg)
{
- struct worker_task *task = (struct worker_task *) arg;
- struct rspamd_worker_ctx *ctx;
+ struct worker_task *task = (struct worker_task *) arg;
+ struct rspamd_worker_ctx *ctx;
+
ctx = task->worker->ctx;
- if (task->state != WAIT_POST_FILTER) {
+ if (task->state != WAIT_POST_FILTER && task->state != WAIT_PRE_FILTER) {
/* Process all statfiles */
if (ctx->classify_pool == NULL) {
/* Non-threaded version */
@@ -639,13 +672,31 @@ fin_task (void *arg)
}
- /* Check if we have all events finished */
- task->state = WRITE_REPLY;
- if (task->fin_callback) {
- task->fin_callback (task->fin_arg);
+ if (task->state != WAIT_PRE_FILTER) {
+ /* Check if we have all events finished */
+ task->state = WRITE_REPLY;
+ if (task->fin_callback) {
+ task->fin_callback (task->fin_arg);
+ }
+ else {
+ rspamd_dispatcher_restore (task->dispatcher);
+ }
}
else {
- rspamd_dispatcher_restore (task->dispatcher);
+ if (task->pre_result.action != METRIC_ACTION_NOACTION) {
+ /* Write result based on pre filters */
+ task->state = WRITE_REPLY;
+ if (task->fin_callback) {
+ task->fin_callback (task->fin_arg);
+ }
+ else {
+ rspamd_dispatcher_restore (task->dispatcher);
+ }
+ }
+ else {
+ /* Check normal filters in write callback */
+ rspamd_dispatcher_restore (task->dispatcher);
+ }
}
return TRUE;