From f81ead4f644e6cb0798c9def14e7fead718f89bc Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Tue, 12 Jul 2016 16:32:50 +0100 Subject: [PATCH] [Rework] Rework pre and postfilters system --- src/libserver/cfg_file.h | 2 - src/libserver/symbols_cache.c | 145 +++++++++++++-- src/libserver/symbols_cache.h | 4 +- src/libserver/task.c | 19 +- src/lua/lua_common.h | 2 - src/lua/lua_config.c | 323 ++++++++++++---------------------- 6 files changed, 258 insertions(+), 237 deletions(-) diff --git a/src/libserver/cfg_file.h b/src/libserver/cfg_file.h index 86aad5117..9e946bfca 100644 --- a/src/libserver/cfg_file.h +++ b/src/libserver/cfg_file.h @@ -343,8 +343,6 @@ struct rspamd_config { GList *statfiles; /**< list of all statfiles in config file order */ GHashTable *classifiers_symbols; /**< hashtable indexed by symbol name of classifiers */ GHashTable * cfg_params; /**< all cfg params indexed by its name in this structure */ - GList *pre_filters; /**< list of pre-processing lua filters */ - GList *post_filters; /**< list of post-processing lua filters */ gchar *dynamic_conf; /**< path to dynamic configuration */ ucl_object_t *current_dynamic_conf; /**< currently loaded dynamic configuration */ GHashTable * domain_settings; /**< settings per-domains */ diff --git a/src/libserver/symbols_cache.c b/src/libserver/symbols_cache.c index 8182dbf5e..11ecc3249 100644 --- a/src/libserver/symbols_cache.c +++ b/src/libserver/symbols_cache.c @@ -62,6 +62,9 @@ struct symbols_cache { GHashTable *items_by_symbol; struct symbols_cache_order *items_by_order; GPtrArray *items_by_id; + GPtrArray *prefilters; + GPtrArray *postfilters; + GPtrArray *composites; GList *delayed_deps; GList *delayed_conditions; rspamd_mempool_t *static_pool; @@ -128,7 +131,16 @@ struct delayed_cache_condition { struct cache_savepoint { guchar *processed_bits; - guint pass; + enum { + RSPAMD_CACHE_PASS_INIT = 0, + RSPAMD_CACHE_PASS_PREFILTERS, + RSPAMD_CACHE_PASS_WAIT_PREFILTERS, + RSPAMD_CACHE_PASS_FILTERS, + 
RSPAMD_CACHE_PASS_WAIT_FILTERS, + RSPAMD_CACHE_PASS_POSTFILTERS, + RSPAMD_CACHE_PASS_WAIT_POSTFILTERS, + RSPAMD_CACHE_PASS_DONE, + } pass; guint version; struct metric_result *rs; gdouble lim; @@ -278,13 +290,16 @@ rspamd_symbols_cache_resort (struct symbols_cache *cache) { struct symbols_cache_order *ord; guint i; - gpointer p; + struct cache_item *it; ord = rspamd_symbols_cache_order_new (cache->used_items); for (i = 0; i < cache->used_items; i ++) { - p = g_ptr_array_index (cache->items_by_id, i); - g_ptr_array_add (ord->d, p); + it = g_ptr_array_index (cache->items_by_id, i); + + if (!(it->type & (SYMBOL_TYPE_PREFILTER|SYMBOL_TYPE_POSTFILTER|SYMBOL_TYPE_COMPOSITE))) { + g_ptr_array_add (ord->d, it); + } } g_ptr_array_sort_with_data (ord->d, cache_logic_cmp, cache); @@ -386,6 +401,9 @@ rspamd_symbols_cache_post_init (struct symbols_cache *cache) } } } + + g_ptr_array_sort_with_data (cache->prefilters, cache_logic_cmp, cache); + g_ptr_array_sort_with_data (cache->postfilters, cache_logic_cmp, cache); } static gboolean @@ -660,6 +678,16 @@ rspamd_symbols_cache_add_symbol (struct symbols_cache *cache, g_hash_table_insert (cache->items_by_symbol, item->symbol, item); } + if (item->type & SYMBOL_TYPE_PREFILTER) { + g_ptr_array_add (cache->prefilters, item); + } + else if (item->type & SYMBOL_TYPE_POSTFILTER) { + g_ptr_array_add (cache->postfilters, item); + } + else if (item->type & SYMBOL_TYPE_COMPOSITE) { + g_ptr_array_add (cache->composites, item); + } + return item->id; } @@ -677,6 +705,13 @@ rspamd_symbols_cache_add_condition (struct symbols_cache *cache, gint id, item = g_ptr_array_index (cache->items_by_id, id); + if (item->type & (SYMBOL_TYPE_POSTFILTER|SYMBOL_TYPE_PREFILTER)) { + msg_err_cache ("conditions are not supported for prefilters and " + "postfilters %s", item->symbol); + + return FALSE; + } + if (item->condition_cb != -1) { /* We already have a condition, so we need to remove old cbref first */ msg_warn_cache ("rewriting condition for symbol %s", 
item->symbol); @@ -767,6 +802,9 @@ rspamd_symbols_cache_destroy (struct symbols_cache *cache) g_hash_table_destroy (cache->items_by_symbol); rspamd_mempool_delete (cache->static_pool); g_ptr_array_free (cache->items_by_id, TRUE); + g_ptr_array_free (cache->prefilters, TRUE); + g_ptr_array_free (cache->postfilters, TRUE); + g_ptr_array_free (cache->composites, TRUE); REF_RELEASE (cache->items_by_order); g_slice_free1 (sizeof (*cache), cache); } @@ -783,6 +821,9 @@ rspamd_symbols_cache_new (struct rspamd_config *cfg) cache->items_by_symbol = g_hash_table_new (rspamd_str_hash, rspamd_str_equal); cache->items_by_id = g_ptr_array_new (); + cache->prefilters = g_ptr_array_new (); + cache->postfilters = g_ptr_array_new (); + cache->composites = g_ptr_array_new (); cache->mtx = rspamd_mempool_get_mutex (cache->static_pool); cache->reload_time = CACHE_RELOAD_TIME; cache->total_freq = 1; @@ -1256,14 +1297,18 @@ rspamd_symbols_cache_make_checkpoint (struct rspamd_task *task, struct symbols_cache *cache) { struct cache_savepoint *checkpoint; + guint nitems; + + nitems = cache->items_by_id->len - cache->postfilters->len - + cache->prefilters->len - cache->composites->len; - if (cache->items_by_id->len != cache->items_by_order->d->len) { + if (nitems != cache->items_by_order->d->len) { /* * Cache has been modified, need to resort it */ msg_info_cache ("symbols cache has been modified since last check:" " old items: %ud, new items: %ud", - cache->items_by_order->d->len, cache->items_by_id->len); + cache->items_by_order->d->len, nitems); rspamd_symbols_cache_resort (cache); } @@ -1280,6 +1325,7 @@ rspamd_symbols_cache_make_checkpoint (struct rspamd_task *task, rspamd_symbols_cache_order_unref, checkpoint->order); rspamd_mempool_add_destructor (task->task_pool, rspamd_ptr_array_free_hard, checkpoint->waitq); + checkpoint->pass = RSPAMD_CACHE_PASS_INIT; task->checkpoint = checkpoint; rspamd_create_metric_result (task, DEFAULT_METRIC); @@ -1388,6 +1434,7 @@ 
rspamd_symbols_cache_process_symbols (struct rspamd_task * task, struct cache_savepoint *checkpoint; gint i; gdouble total_microseconds = 0; + gboolean all_done; const gdouble max_microseconds = 3e5; guint start_events_pending; @@ -1410,8 +1457,40 @@ rspamd_symbols_cache_process_symbols (struct rspamd_task * task, msg_debug_task ("symbols processing stage at pass: %d", checkpoint->pass); start_events_pending = rspamd_session_events_pending (task->s); - if (checkpoint->pass == 0) { + switch (checkpoint->pass) { + case RSPAMD_CACHE_PASS_INIT: + /* Check for prefilters */ + for (i = 0; i < (gint)cache->prefilters->len; i ++) { + item = g_ptr_array_index (cache->prefilters, i); + if (!isset (checkpoint->processed_bits, item->id * 2)) { + rspamd_symbols_cache_check_symbol (task, cache, item, + checkpoint, &total_microseconds); + } + } + checkpoint->pass = RSPAMD_CACHE_PASS_WAIT_PREFILTERS; + break; + + case RSPAMD_CACHE_PASS_PREFILTERS: + case RSPAMD_CACHE_PASS_WAIT_PREFILTERS: + all_done = TRUE; + + for (i = 0; i < (gint)cache->prefilters->len; i ++) { + item = g_ptr_array_index (cache->prefilters, i); + + if (!isset (checkpoint->processed_bits, item->id * 2 + 1)) { + all_done = FALSE; + break; + } + } + + if (all_done) { + checkpoint->pass = RSPAMD_CACHE_PASS_FILTERS; + + return rspamd_symbols_cache_process_symbols (task, cache); + } + break; + case RSPAMD_CACHE_PASS_FILTERS: /* * On the first pass we check symbols that do not have dependencies * If we figure out symbol that has no dependencies satisfied, then @@ -1470,9 +1549,10 @@ rspamd_symbols_cache_process_symbols (struct rspamd_task * task, } } - checkpoint->pass ++; - } - else { + checkpoint->pass = RSPAMD_CACHE_PASS_WAIT_FILTERS; + break; + + case RSPAMD_CACHE_PASS_WAIT_FILTERS: /* We just go through the blocked symbols and check if they are ready */ for (i = 0; i < (gint)checkpoint->waitq->len; i ++) { item = g_ptr_array_index (checkpoint->waitq, i); @@ -1502,9 +1582,52 @@ rspamd_symbols_cache_process_symbols 
(struct rspamd_task * task, } } } + + if (checkpoint->waitq->len == 0) { + checkpoint->pass = RSPAMD_CACHE_PASS_POSTFILTERS; + + return rspamd_symbols_cache_process_symbols (task, cache); + } + break; + + case RSPAMD_CACHE_PASS_POSTFILTERS: + /* Check for postfilters */ + for (i = 0; i < (gint)cache->postfilters->len; i ++) { + item = g_ptr_array_index (cache->postfilters, i); + + if (!isset (checkpoint->processed_bits, item->id * 2)) { + rspamd_symbols_cache_check_symbol (task, cache, item, + checkpoint, &total_microseconds); + } + } + checkpoint->pass = RSPAMD_CACHE_PASS_WAIT_POSTFILTERS; + break; + + case RSPAMD_CACHE_PASS_WAIT_POSTFILTERS: + all_done = TRUE; + + for (i = 0; i < (gint)cache->postfilters->len; i ++) { + item = g_ptr_array_index (cache->postfilters, i); + + if (!isset (checkpoint->processed_bits, item->id * 2 + 1)) { + all_done = FALSE; + break; + } + } + + if (all_done) { + checkpoint->pass = RSPAMD_CACHE_PASS_DONE; + + return TRUE; + } + break; + + case RSPAMD_CACHE_PASS_DONE: + return TRUE; + break; } - return TRUE; + return FALSE; } struct counters_cbdata { diff --git a/src/libserver/symbols_cache.h b/src/libserver/symbols_cache.h index 56008d2c1..2c7738ed2 100644 --- a/src/libserver/symbols_cache.h +++ b/src/libserver/symbols_cache.h @@ -36,7 +36,9 @@ enum rspamd_symbol_type { SYMBOL_TYPE_COMPOSITE = (1 << 5), SYMBOL_TYPE_CLASSIFIER = (1 << 6), SYMBOL_TYPE_FINE = (1 << 7), - SYMBOL_TYPE_EMPTY = (1 << 8) /* Allow execution on empty tasks */ + SYMBOL_TYPE_EMPTY = (1 << 8), /* Allow execution on empty tasks */ + SYMBOL_TYPE_PREFILTER = (1 << 9), + SYMBOL_TYPE_POSTFILTER = (1 << 10), }; /** diff --git a/src/libserver/task.c b/src/libserver/task.c index 6e0493771..7b4970943 100644 --- a/src/libserver/task.c +++ b/src/libserver/task.c @@ -538,13 +538,6 @@ rspamd_task_select_processing_stage (struct rspamd_task *task, guint stages) return RSPAMD_TASK_STAGE_DONE; } -static gboolean -rspamd_process_filters (struct rspamd_task *task) -{ - /* Process 
metrics symbols */ - return rspamd_symbols_cache_process_symbols (task, task->cfg->cache); -} - gboolean rspamd_task_process (struct rspamd_task *task, guint stages) { @@ -581,13 +574,11 @@ rspamd_task_process (struct rspamd_task *task, guint stages) break; case RSPAMD_TASK_STAGE_PRE_FILTERS: - rspamd_lua_call_pre_filters (task); + rspamd_symbols_cache_process_symbols (task, task->cfg->cache); break; case RSPAMD_TASK_STAGE_FILTERS: - if (!rspamd_process_filters (task)) { - ret = FALSE; - } + rspamd_symbols_cache_process_symbols (task, task->cfg->cache); break; case RSPAMD_TASK_STAGE_CLASSIFIERS: @@ -607,7 +598,8 @@ rspamd_task_process (struct rspamd_task *task, guint stages) break; case RSPAMD_TASK_STAGE_POST_FILTERS: - rspamd_lua_call_post_filters (task); + rspamd_symbols_cache_process_symbols (task, task->cfg->cache); + if ((task->flags & RSPAMD_TASK_FLAG_LEARN_AUTO) && !RSPAMD_TASK_IS_EMPTY (task) && !(task->flags & (RSPAMD_TASK_FLAG_LEARN_SPAM|RSPAMD_TASK_FLAG_LEARN_HAM))) { @@ -674,9 +666,6 @@ rspamd_task_process (struct rspamd_task *task, guint stages) msg_debug_task ("completed stage %d", st); task->processed_stages |= st; - /* Reset checkpoint */ - task->checkpoint = NULL; - /* Tail recursion */ return rspamd_task_process (task, stages); } diff --git a/src/lua/lua_common.h b/src/lua/lua_common.h index ba389d7a6..2ec235485 100644 --- a/src/lua/lua_common.h +++ b/src/lua/lua_common.h @@ -256,8 +256,6 @@ void luaopen_fann (lua_State *L); void luaopen_sqlite3 (lua_State *L); void luaopen_cryptobox (lua_State *L); -void rspamd_lua_call_post_filters (struct rspamd_task *task); -void rspamd_lua_call_pre_filters (struct rspamd_task *task); void rspamd_lua_dostring (const gchar *line); /* Classify functions */ diff --git a/src/lua/lua_config.c b/src/lua/lua_config.c index c3834759c..d16c645a6 100644 --- a/src/lua/lua_config.c +++ b/src/lua/lua_config.c @@ -664,212 +664,6 @@ rspamd_compare_order_func (gconstpointer a, gconstpointer b) return cb2->order - cb1->order; 
} -void -rspamd_lua_call_post_filters (struct rspamd_task *task) -{ - struct lua_callback_data *cd; - struct rspamd_task **ptask; - GList *cur; - lua_State *L = task->cfg->lua_state; - gint err_idx; - GString *tb; - - if (task->checkpoint == NULL) { - task->checkpoint = GUINT_TO_POINTER (0x1); - } - else { - /* Do not process if done */ - return; - } - - cur = task->cfg->post_filters; - while (cur) { - lua_pushcfunction (L, &rspamd_lua_traceback); - err_idx = lua_gettop (L); - - cd = cur->data; - if (cd->cb_is_ref) { - lua_rawgeti (cd->L, LUA_REGISTRYINDEX, cd->callback.ref); - } - else { - lua_getglobal (cd->L, cd->callback.name); - } - ptask = lua_newuserdata (cd->L, sizeof (struct rspamd_task *)); - rspamd_lua_setclass (cd->L, "rspamd{task}", -1); - *ptask = task; - - if (lua_pcall (cd->L, 1, 0, err_idx) != 0) { - tb = lua_touserdata (L, -1); - msg_err_task ("call to post-filter failed: %v", tb); - g_string_free (tb, TRUE); - lua_pop (L, 1); - } - - lua_pop (L, 1); /* Error function */ - - cur = g_list_next (cur); - } -} - -static gint -lua_config_register_post_filter (lua_State *L) -{ - struct rspamd_config *cfg = lua_check_config (L, 1); - struct lua_callback_data *cd; - gint order = 0; - - if (cfg) { - cd = - rspamd_mempool_alloc (cfg->cfg_pool, - sizeof (struct lua_callback_data)); - cd->magic = rspamd_lua_callback_magic; - - if (lua_type (L, 3) == LUA_TNUMBER) { - order = lua_tonumber (L, 3); - } - - if (lua_type (L, 2) == LUA_TSTRING) { - cd->callback.name = rspamd_mempool_strdup (cfg->cfg_pool, - luaL_checkstring (L, 2)); - cd->cb_is_ref = FALSE; - } - else { - lua_pushvalue (L, 2); - /* Get a reference */ - cd->callback.ref = luaL_ref (L, LUA_REGISTRYINDEX); - cd->cb_is_ref = TRUE; - } - - cd->L = L; - cd->order = order; - cfg->post_filters = g_list_insert_sorted (cfg->post_filters, cd, - rspamd_compare_order_func); - rspamd_mempool_add_destructor (cfg->cfg_pool, - (rspamd_mempool_destruct_t)lua_destroy_cfg_symbol, - cd); - } - - return 0; -} - -void 
-rspamd_lua_call_pre_filters (struct rspamd_task *task) -{ - struct lua_callback_data *cd; - struct rspamd_task **ptask; - GList *cur; - lua_State *L = task->cfg->lua_state; - gint err_idx; - GString *tb; - - if (task->checkpoint == NULL) { - task->checkpoint = GUINT_TO_POINTER (0x1); - } - else { - /* Do not process if done */ - return; - } - - cur = task->cfg->pre_filters; - while (cur) { - lua_pushcfunction (L, &rspamd_lua_traceback); - err_idx = lua_gettop (L); - - cd = cur->data; - if (cd->cb_is_ref) { - lua_rawgeti (cd->L, LUA_REGISTRYINDEX, cd->callback.ref); - } - else { - lua_getglobal (cd->L, cd->callback.name); - } - ptask = lua_newuserdata (cd->L, sizeof (struct rspamd_task *)); - rspamd_lua_setclass (cd->L, "rspamd{task}", -1); - *ptask = task; - - if (lua_pcall (cd->L, 1, 0, err_idx) != 0) { - tb = lua_touserdata (L, -1); - msg_err_task ("call to pre-filter failed: %v", tb); - g_string_free (tb, TRUE); - lua_pop (L, 1); - } - - lua_pop (L, 1); /* Error function */ - - if (task->pre_result.action != METRIC_ACTION_MAX) { - /* Stop processing on reaching some pre-result */ - break; - } - - cur = g_list_next (cur); - } -} - -static gint -lua_config_register_pre_filter (lua_State *L) -{ - struct rspamd_config *cfg = lua_check_config (L, 1); - struct lua_callback_data *cd; - gint order = 0; - - if (cfg) { - cd = - rspamd_mempool_alloc (cfg->cfg_pool, - sizeof (struct lua_callback_data)); - cd->magic = rspamd_lua_callback_magic; - - if (lua_type (L, 3) == LUA_TNUMBER) { - order = lua_tonumber (L, 3); - } - - if (lua_type (L, 2) == LUA_TSTRING) { - cd->callback.name = rspamd_mempool_strdup (cfg->cfg_pool, - luaL_checkstring (L, 2)); - cd->cb_is_ref = FALSE; - } - else { - lua_pushvalue (L, 2); - /* Get a reference */ - cd->callback.ref = luaL_ref (L, LUA_REGISTRYINDEX); - cd->cb_is_ref = TRUE; - } - - cd->L = L; - cd->order = order; - cfg->pre_filters = g_list_insert_sorted (cfg->pre_filters, cd, - rspamd_compare_order_func); - rspamd_mempool_add_destructor 
(cfg->cfg_pool, - (rspamd_mempool_destruct_t)lua_destroy_cfg_symbol, - cd); - } - - return 0; -} - -static gint -lua_config_get_key (lua_State *L) -{ - struct rspamd_config *cfg = lua_check_config (L, 1); - const gchar *name; - size_t namelen; - const ucl_object_t *val; - - name = luaL_checklstring(L, 2, &namelen); - if (name && cfg) { - val = ucl_object_lookup_len(cfg->rcl_obj, name, namelen); - if (val != NULL) { - ucl_object_push_lua (L, val, val->type != UCL_ARRAY); - } - else { - lua_pushnil (L); - } - } - else { - return luaL_error (L, "invalid arguments"); - } - - return 1; -} - static void lua_metric_symbol_callback (struct rspamd_task *task, gpointer ud) { @@ -1003,6 +797,117 @@ rspamd_register_symbol_fromlua (lua_State *L, return ret; } +static gint +lua_config_register_post_filter (lua_State *L) +{ + struct rspamd_config *cfg = lua_check_config (L, 1); + gint order = 0, cbref, ret; + + if (cfg) { + if (lua_type (L, 3) == LUA_TNUMBER) { + order = lua_tonumber (L, 3); + } + + if (lua_type (L, 2) == LUA_TFUNCTION) { + lua_pushvalue (L, 2); + /* Get a reference */ + cbref = luaL_ref (L, LUA_REGISTRYINDEX); + } + else { + return luaL_error (L, "invalid type for callback: %s", + lua_typename (L, lua_type (L, 2))); + } + + msg_warn_config ("register_post_filter function is deprecated, " + "use register_symbol instead"); + + ret = rspamd_register_symbol_fromlua (L, + cfg, + NULL, + cbref, + 1.0, + order, + SYMBOL_TYPE_POSTFILTER|SYMBOL_TYPE_CALLBACK, + -1, + FALSE); + + lua_pushboolean (L, ret); + } + else { + return luaL_error (L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_config_register_pre_filter (lua_State *L) +{ + struct rspamd_config *cfg = lua_check_config (L, 1); + gint order = 0, cbref, ret; + + if (cfg) { + if (lua_type (L, 3) == LUA_TNUMBER) { + order = lua_tonumber (L, 3); + } + + if (lua_type (L, 2) == LUA_TFUNCTION) { + lua_pushvalue (L, 2); + /* Get a reference */ + cbref = luaL_ref (L, LUA_REGISTRYINDEX); + } + else { + 
return luaL_error (L, "invalid type for callback: %s", + lua_typename (L, lua_type (L, 2))); + } + + msg_warn_config ("register_pre_filter function is deprecated, " + "use register_symbol instead"); + + ret = rspamd_register_symbol_fromlua (L, + cfg, + NULL, + cbref, + 1.0, + order, + SYMBOL_TYPE_PREFILTER|SYMBOL_TYPE_CALLBACK, + -1, + FALSE); + + lua_pushboolean (L, ret); + } + else { + return luaL_error (L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_config_get_key (lua_State *L) +{ + struct rspamd_config *cfg = lua_check_config (L, 1); + const gchar *name; + size_t namelen; + const ucl_object_t *val; + + name = luaL_checklstring(L, 2, &namelen); + if (name && cfg) { + val = ucl_object_lookup_len(cfg->rcl_obj, name, namelen); + if (val != NULL) { + ucl_object_push_lua (L, val, val->type != UCL_ARRAY); + } + else { + lua_pushnil (L); + } + } + else { + return luaL_error (L, "invalid arguments"); + } + + return 1; +} + static gint lua_parse_symbol_type (const gchar *str) { @@ -1018,6 +923,12 @@ lua_parse_symbol_type (const gchar *str) else if (strcmp (str, "normal") == 0) { ret = SYMBOL_TYPE_NORMAL; } + else if (strcmp (str, "prefilter") == 0) { + ret = SYMBOL_TYPE_PREFILTER; + } + else if (strcmp (str, "postfilter") == 0) { + ret = SYMBOL_TYPE_POSTFILTER; + } else { msg_warn ("bad type: %s", str); } -- 2.39.5