diff options
-rw-r--r-- | src/cfg_file.h | 1 | ||||
-rw-r--r-- | src/lua/lua_common.c | 18 | ||||
-rw-r--r-- | src/lua/lua_common.h | 3 | ||||
-rw-r--r-- | src/lua/lua_config.c | 57 | ||||
-rw-r--r-- | src/lua/lua_task.c | 25 | ||||
-rw-r--r-- | src/main.h | 6 | ||||
-rw-r--r-- | src/protocol.c | 71 | ||||
-rw-r--r-- | src/worker.c | 95 |
8 files changed, 230 insertions, 46 deletions
diff --git a/src/cfg_file.h b/src/cfg_file.h index 203558e9a..c4e820997 100644 --- a/src/cfg_file.h +++ b/src/cfg_file.h @@ -312,6 +312,7 @@ struct config_file { GHashTable *classifiers_symbols; /**< hashtable indexed by symbol name of classifiers */ GHashTable* cfg_params; /**< all cfg params indexed by its name in this structure */ GList *views; /**< views */ + GList *pre_filters; /**< list of pre-processing lua filters */ GList *post_filters; /**< list of post-processing lua filters */ GHashTable* domain_settings; /**< settings per-domains */ GHashTable* user_settings; /**< settings per-user */ diff --git a/src/lua/lua_common.c b/src/lua/lua_common.c index 8fbff979c..4a62a7d02 100644 --- a/src/lua/lua_common.c +++ b/src/lua/lua_common.c @@ -219,6 +219,22 @@ luaopen_logger (lua_State * L) return 1; } +static void +lua_add_actions_global (lua_State *L) +{ + gint i; + + lua_newtable (L); + + for (i = METRIC_ACTION_REJECT; i <= METRIC_ACTION_NOACTION; i ++) { + lua_pushstring (L, str_action_metric (i)); + lua_pushinteger (L, i); + lua_settable (L, -3); + } + /* Set global table */ + lua_setglobal (L, "rspamd_actions"); +} + void init_lua (struct config_file *cfg) { @@ -253,6 +269,8 @@ init_lua (struct config_file *cfg) (void)luaopen_http (L); (void)luaopen_redis (L); (void)luaopen_upstream (L); + (void)lua_add_actions_global (L); + cfg->lua_state = L; memory_pool_add_destructor (cfg->cfg_pool, (pool_destruct_func)lua_close, L); diff --git a/src/lua/lua_common.h b/src/lua/lua_common.h index adb87135e..c138a35a2 100644 --- a/src/lua/lua_common.h +++ b/src/lua/lua_common.h @@ -17,7 +17,7 @@ extern const luaL_reg null_reg[]; extern GMutex *lua_mtx; -#define RSPAMD_LUA_API_VERSION 10 +#define RSPAMD_LUA_API_VERSION 11 /* Common utility functions */ @@ -71,6 +71,7 @@ gint lua_call_chain_filter (const gchar *function, struct worker_task *task, gin double lua_consolidation_func (struct worker_task *task, const gchar *metric_name, const gchar *function_name); gboolean lua_call_expression_func (const gchar *module, const gchar *symbol, struct worker_task *task, GList *args, gboolean *res); void lua_call_post_filters (struct worker_task *task); +void lua_call_pre_filters (struct worker_task *task); void add_luabuf (const gchar *line); /* Classify functions */ diff --git a/src/lua/lua_config.c b/src/lua/lua_config.c index bdb4ce056..0e63f6421 100644 --- a/src/lua/lua_config.c +++ b/src/lua/lua_config.c @@ -44,6 +44,7 @@ LUA_FUNCTION_DEF (config, register_symbol); LUA_FUNCTION_DEF (config, register_virtual_symbol); LUA_FUNCTION_DEF (config, register_callback_symbol); LUA_FUNCTION_DEF (config, register_callback_symbol_priority); +LUA_FUNCTION_DEF (config, register_pre_filter); LUA_FUNCTION_DEF (config, register_post_filter); LUA_FUNCTION_DEF (config, register_module_option); LUA_FUNCTION_DEF (config, get_api_version); @@ -61,6 +62,7 @@ static const struct luaL_reg configlib_m[] = { LUA_INTERFACE_DEF (config, register_callback_symbol), LUA_INTERFACE_DEF (config, register_callback_symbol_priority), LUA_INTERFACE_DEF (config, register_module_option), + LUA_INTERFACE_DEF (config, register_pre_filter), LUA_INTERFACE_DEF (config, register_post_filter), LUA_INTERFACE_DEF (config, get_api_version), {"__tostring", lua_class_tostring}, @@ -489,6 +491,61 @@ lua_config_register_post_filter (lua_State *L) return 1; } +void +lua_call_pre_filters (struct worker_task *task) +{ + struct lua_callback_data *cd; + struct worker_task **ptask; + GList *cur; + + g_mutex_lock (lua_mtx); + cur = task->cfg->pre_filters; + while (cur) { + cd = cur->data; + if (cd->cb_is_ref) { + lua_rawgeti (cd->L, LUA_REGISTRYINDEX, cd->callback.ref); + } + else { + lua_getglobal (cd->L, cd->callback.name); + } + ptask = lua_newuserdata (cd->L, sizeof (struct worker_task *)); + lua_setclass (cd->L, "rspamd{task}", -1); + *ptask = task; + + if (lua_pcall (cd->L, 1, 0, 0) != 0) { + msg_info ("call to %s failed: %s", cd->cb_is_ref ? "local function" : + cd->callback.name, lua_tostring (cd->L, -1)); + } + cur = g_list_next (cur); + } + g_mutex_unlock (lua_mtx); +} + +static gint +lua_config_register_pre_filter (lua_State *L) +{ + struct config_file *cfg = lua_check_config (L); + struct lua_callback_data *cd; + + if (cfg) { + cd = memory_pool_alloc (cfg->cfg_pool, sizeof (struct lua_callback_data)); + if (lua_type (L, 2) == LUA_TSTRING) { + cd->callback.name = memory_pool_strdup (cfg->cfg_pool, luaL_checkstring (L, 2)); + cd->cb_is_ref = FALSE; + } + else { + lua_pushvalue (L, 2); + /* Get a reference */ + cd->callback.ref = luaL_ref (L, LUA_REGISTRYINDEX); + cd->cb_is_ref = TRUE; + } + cd->L = L; + cfg->pre_filters = g_list_prepend (cfg->pre_filters, cd); + memory_pool_add_destructor (cfg->cfg_pool, (pool_destruct_func)lua_destroy_cfg_symbol, cd); + } + return 1; +} + static gint lua_config_add_radix_map (lua_State *L) { diff --git a/src/lua/lua_task.c b/src/lua/lua_task.c index d28d897f6..d205016e0 100644 --- a/src/lua/lua_task.c +++ b/src/lua/lua_task.c @@ -45,6 +45,7 @@ extern stat_file_t* get_statfile_by_symbol (statfile_pool_t *pool, struct classi /* Task methods */ LUA_FUNCTION_DEF (task, get_message); LUA_FUNCTION_DEF (task, insert_result); +LUA_FUNCTION_DEF (task, set_pre_result); LUA_FUNCTION_DEF (task, get_urls); LUA_FUNCTION_DEF (task, get_emails); LUA_FUNCTION_DEF (task, get_text_parts); @@ -75,6 +76,7 @@ LUA_FUNCTION_DEF (task, learn_statfile); static const struct luaL_reg tasklib_m[] = { LUA_INTERFACE_DEF (task, get_message), LUA_INTERFACE_DEF (task, insert_result), + LUA_INTERFACE_DEF (task, set_pre_result), LUA_INTERFACE_DEF (task, get_urls), LUA_INTERFACE_DEF (task, get_emails), LUA_INTERFACE_DEF (task, get_text_parts), @@ -233,6 +235,29 @@ lua_task_insert_result (lua_State * L) return 1; } +static gint +lua_task_set_pre_result (lua_State * L) +{ + struct worker_task *task = lua_check_task (L); + gchar *action_str; + guint action; + + if (task != NULL) { + action = luaL_checkinteger (L, 2); + if (action < task->pre_result.action) { + task->pre_result.action = action; + if (lua_gettop (L) >= 3) { + action_str = memory_pool_strdup (task->task_pool, luaL_checkstring (L, 3)); + task->pre_result.str = action_str; + } + else { + task->pre_result.str = NULL; + } + } + } + return 1; +} + struct lua_tree_cb_data { lua_State *L; int i; diff --git a/src/main.h b/src/main.h index 2331dbdb5..22b2f7584 100644 --- a/src/main.h +++ b/src/main.h @@ -179,6 +179,7 @@ struct worker_task { READ_MESSAGE, WRITE_REPLY, WRITE_ERROR, + WAIT_PRE_FILTER, WAIT_FILTER, WAIT_POST_FILTER, CLOSING_CONNECTION, @@ -248,6 +249,11 @@ struct worker_task { struct rspamd_dns_resolver *resolver; /**< DNS resolver */ struct event_base *ev_base; /**< Event base */ + + struct { + enum rspamd_metric_action action; /**< Action of pre filters */ + gchar *str; /**< String describing action */ + } pre_result; /**< Result of pre-filters */ }; /** diff --git a/src/protocol.c b/src/protocol.c index b4dcd7d13..40bd01e0c 100644 --- a/src/protocol.c +++ b/src/protocol.c @@ -988,32 +988,48 @@ print_metric_data_rspamc (struct worker_task *task, gchar *outbuf, gsize size, { gint r = 0; gboolean is_spam = FALSE; + gchar *local_act; if (metric_res == NULL) { + /* This is case when we got reject result from pre filters */ if (task->proto == SPAMC_PROTO) { r = rspamd_snprintf (outbuf, size, "Spam: False ; 0.00 / %.2f" CRLF, ms); } else { + local_act = "False"; + msg_info ("action: %s", str_action_metric (task->pre_result.action)); + if (task->pre_result.action <= METRIC_ACTION_SOFT_REJECT) { + local_act = "True"; + } + if (task->proto_ver >= 11) { - if (!task->is_skipped) { - r = rspamd_snprintf (outbuf, size, - "Metric: default; False; 0.00 / %.2f / %.2f" CRLF, ms, - rs); - } - else { - r = rspamd_snprintf (outbuf, size, - "Metric: default; Skip; 0.00 / %.2f / %.2f" CRLF, ms, - rs); + if (task->is_skipped) { + local_act = "Skip"; } + r = rspamd_snprintf (outbuf, size, + "Metric: default; %s; 0.00 / %.2f / %.2f" CRLF, local_act, ms, + rs); } else { r = rspamd_snprintf (outbuf, size, - "Metric: default; False; 0.00 / %.2f" CRLF, ms); + "Metric: default; %s; 0.00 / %.2f" CRLF, local_act, ms); } - r += rspamd_snprintf (outbuf + r, size - r, + + if (task->pre_result.action == METRIC_ACTION_NOACTION) { + r += rspamd_snprintf (outbuf + r, size - r, "Action: %s" CRLF, str_action_metric ( METRIC_ACTION_NOACTION)); + } + else { + r += rspamd_snprintf (outbuf + r, size - r, + "Action: %s" CRLF, str_action_metric ( + task->pre_result.action)); + if (task->pre_result.str != NULL) { + r += rspamd_snprintf (outbuf + r, size - r, + "Message: %s" CRLF, task->pre_result.str); + } + } } } else { @@ -1067,29 +1083,38 @@ print_metric_data_json (struct worker_task *task, gchar *outbuf, gsize size, enum rspamd_metric_action action) { gint r = 0; + gchar *local_act; - + if (task->pre_result.action == METRIC_ACTION_NOACTION) { + local_act = "False"; + } + else if (task->pre_result.action <= METRIC_ACTION_SOFT_REJECT) { + local_act = "True"; + } if (metric_res == NULL) { - r = rspamd_snprintf (outbuf, size, - " {" CRLF " \"name\": \"default\"," CRLF - " \"is_spam\": false," CRLF - " \"is_skipped\": %s," CRLF - " \"score\": 0.00," CRLF - " \"required_score\": %.2f," CRLF - " \"reject_score\": %.2f," CRLF - " \"action\": \"%s\"," CRLF, - task->is_skipped ? "true" : "false", ms, rs, - str_action_metric (METRIC_ACTION_NOACTION)); + /* This is case when we got reject result from pre filters */ + r = rspamd_snprintf (outbuf, size, + " {" CRLF " \"name\": \"default\"," CRLF + " \"is_spam\": %s," CRLF + " \"is_skipped\": %s," CRLF + " \"score\": 0.00," CRLF + " \"required_score\": %.2f," CRLF + " \"reject_score\": %.2f," CRLF + " \"action\": \"%s\"," CRLF, + local_act, + task->is_skipped ? "true" : "false", ms, rs, + str_action_metric (task->pre_result.action)); } else { r = rspamd_snprintf (outbuf, size, - " {" CRLF " \"name\": \"default\"," CRLF + " {" CRLF " \"name\": \"%s\"," CRLF " \"is_spam\": %s," CRLF " \"is_skipped\": %s," CRLF " \"score\": %.2f," CRLF " \"required_score\": %.2f," CRLF " \"reject_score\": %.2f," CRLF " \"action\": \"%s\"," CRLF, + metric_res->metric->name, metric_res->score >= ms ? "true" : "false", metric_res->score, task->is_skipped ? "true" : "false", ms, rs, diff --git a/src/worker.c b/src/worker.c index d56c2d924..dad842ce3 100644 --- a/src/worker.c +++ b/src/worker.c @@ -247,6 +247,7 @@ construct_task (struct rspamd_worker *worker) new_task->urls); new_task->sock = -1; new_task->is_mime = TRUE; + new_task->pre_result.action = METRIC_ACTION_NOACTION; return new_task; } @@ -497,22 +498,31 @@ read_socket (f_str_t * in, void *arg) return write_socket (task); } else { - r = process_filters (task); - if (r == -1) { - task->last_error = "Filter processing error"; - task->error_code = RSPAMD_FILTER_ERROR; - task->state = WRITE_ERROR; - return write_socket (task); - } - /* Add task to classify to classify pool */ - if (ctx->classify_pool) { - register_async_thread (task->s); - g_thread_pool_push (ctx->classify_pool, task, &err); - if (err != NULL) { - msg_err ("cannot pull task to the pool: %s", err->message); - remove_async_thread (task->s); + if (task->cfg->pre_filters == NULL) { + r = process_filters (task); + if (r == -1) { + task->last_error = "Filter processing error"; + task->error_code = RSPAMD_FILTER_ERROR; + task->state = WRITE_ERROR; + return write_socket (task); + } + /* Add task to classify to classify pool */ + if (ctx->classify_pool) { + register_async_thread (task->s); + g_thread_pool_push (ctx->classify_pool, task, &err); + if (err != NULL) { + msg_err ("cannot pull task to the pool: %s", err->message); + remove_async_thread (task->s); + } } } + else { + lua_call_pre_filters (task); + /* We want fin_task after pre filters are processed */ + task->s->wanna_die = TRUE; + task->state = WAIT_PRE_FILTER; + check_session_pending (task->s); + } } break; case WRITE_REPLY: @@ -521,6 +531,7 @@ read_socket (f_str_t * in, void *arg) break; case WAIT_FILTER: case WAIT_POST_FILTER: + case WAIT_PRE_FILTER: msg_info ("ignoring trailing garbadge of size %z", in->len); break; default: @@ -539,6 +550,8 @@ write_socket (void *arg) { struct worker_task *task = (struct worker_task *) arg; struct rspamd_worker_ctx *ctx; + GError *err = NULL; + gint r; ctx = task->worker->ctx; @@ -578,6 +591,25 @@ write_socket (void *arg) case WAIT_POST_FILTER: /* Do nothing here */ break; + case WAIT_PRE_FILTER: + task->state = WAIT_FILTER; + r = process_filters (task); + if (r == -1) { + task->last_error = "Filter processing error"; + task->error_code = RSPAMD_FILTER_ERROR; + task->state = WRITE_ERROR; + return write_socket (task); + } + /* Add task to classify to classify pool */ + if (ctx->classify_pool) { + register_async_thread (task->s); + g_thread_pool_push (ctx->classify_pool, task, &err); + if (err != NULL) { + msg_err ("cannot pull task to the pool: %s", err->message); + remove_async_thread (task->s); + } + } + break; default: msg_info ("abnormally closing connection at state: %d", task->state); if (ctx->is_custom) { @@ -616,11 +648,12 @@ err_socket (GError * err, void *arg) static gboolean fin_task (void *arg) { - struct worker_task *task = (struct worker_task *) arg; - struct rspamd_worker_ctx *ctx; + struct worker_task *task = (struct worker_task *) arg; + struct rspamd_worker_ctx *ctx; + ctx = task->worker->ctx; - if (task->state != WAIT_POST_FILTER) { + if (task->state != WAIT_POST_FILTER && task->state != WAIT_PRE_FILTER) { /* Process all statfiles */ if (ctx->classify_pool == NULL) { /* Non-threaded version */ @@ -639,13 +672,31 @@ fin_task (void *arg) } - /* Check if we have all events finished */ - task->state = WRITE_REPLY; - if (task->fin_callback) { - task->fin_callback (task->fin_arg); + if (task->state != WAIT_PRE_FILTER) { + /* Check if we have all events finished */ + task->state = WRITE_REPLY; + if (task->fin_callback) { + task->fin_callback (task->fin_arg); + } + else { + rspamd_dispatcher_restore (task->dispatcher); + } } else { - rspamd_dispatcher_restore (task->dispatcher); + if (task->pre_result.action != METRIC_ACTION_NOACTION) { + /* Write result based on pre filters */ + task->state = WRITE_REPLY; + if (task->fin_callback) { + task->fin_callback (task->fin_arg); + } + else { + rspamd_dispatcher_restore (task->dispatcher); + } + } + else { + /* Check normal filters in write callback */ + rspamd_dispatcher_restore (task->dispatcher); + } } return TRUE; |