From 131c74bd2c3419c6ca6dcdd259376f489f305205 Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Fri, 1 Mar 2019 09:59:53 +0000 Subject: [PATCH] [Rework] Enable explicit coroutines symbols --- src/libserver/rspamd_symcache.h | 1 + src/lua/lua_config.c | 208 +++++++++++++++++++++++++++++--- 2 files changed, 193 insertions(+), 16 deletions(-) diff --git a/src/libserver/rspamd_symcache.h b/src/libserver/rspamd_symcache.h index 69eac1f01..a038d6a9d 100644 --- a/src/libserver/rspamd_symcache.h +++ b/src/libserver/rspamd_symcache.h @@ -50,6 +50,7 @@ enum rspamd_symbol_type { SYMBOL_TYPE_MIME_ONLY = (1 << 15), /* Symbol is mime only */ SYMBOL_TYPE_EXPLICIT_DISABLE = (1 << 16), /* Symbol should be disabled explicitly only */ SYMBOL_TYPE_IGNORE_PASSTHROUGH = (1 << 17), /* Symbol ignores passthrough result */ + SYMBOL_TYPE_USE_CORO = (1 << 18), /* Symbol uses lua coroutines */ }; /** diff --git a/src/lua/lua_config.c b/src/lua/lua_config.c index b5e668e71..9d7f3341c 100644 --- a/src/lua/lua_config.c +++ b/src/lua/lua_config.c @@ -1222,6 +1222,142 @@ lua_metric_symbol_callback (struct rspamd_task *task, g_assert (lua_gettop (L) == level - 1); } +static void lua_metric_symbol_callback_return (struct thread_entry *thread_entry, + int ret); + +static void lua_metric_symbol_callback_error (struct thread_entry *thread_entry, + int ret, + const char *msg); + +static void +lua_metric_symbol_callback_coro (struct rspamd_task *task, + struct rspamd_symcache_item *item, + gpointer ud) +{ + struct lua_callback_data *cd = ud; + struct rspamd_task **ptask; + struct thread_entry *thread_entry; + + rspamd_symcache_item_async_inc (task, item, "lua coro symbol"); + thread_entry = lua_thread_pool_get_for_task (task); + + g_assert(thread_entry->cd == NULL); + thread_entry->cd = cd; + + lua_State *thread = thread_entry->lua_state; + cd->stack_level = lua_gettop (thread); + cd->item = item; + + if (cd->cb_is_ref) { + lua_rawgeti (thread, LUA_REGISTRYINDEX, cd->callback.ref); + } + else { + lua_getglobal (thread, cd->callback.name); + } + + ptask = lua_newuserdata (thread, sizeof (struct rspamd_task *)); + rspamd_lua_setclass (thread, "rspamd{task}", -1); + *ptask = task; + + thread_entry->finish_callback = lua_metric_symbol_callback_return; + thread_entry->error_callback = lua_metric_symbol_callback_error; + + lua_thread_call (thread_entry, 1); +} + +static void +lua_metric_symbol_callback_error (struct thread_entry *thread_entry, + int ret, + const char *msg) +{ + struct lua_callback_data *cd = thread_entry->cd; + struct rspamd_task *task = thread_entry->task; + msg_err_task ("call to coroutine (%s) failed (%d): %s", cd->symbol, ret, msg); + + rspamd_symcache_item_async_dec_check (task, cd->item, "lua coro symbol"); +} + +static void +lua_metric_symbol_callback_return (struct thread_entry *thread_entry, int ret) +{ + struct lua_callback_data *cd = thread_entry->cd; + struct rspamd_task *task = thread_entry->task; + int nresults; + struct rspamd_symbol_result *s; + + (void)ret; + + lua_State *L = thread_entry->lua_state; + + nresults = lua_gettop (L) - cd->stack_level; + + if (nresults >= 1) { + /* Function returned boolean, so maybe we need to insert result? */ + gint res = 0; + gint i; + gdouble flag = 1.0; + gint type; + + type = lua_type (L, cd->stack_level + 1); + + if (type == LUA_TBOOLEAN) { + res = lua_toboolean (L, cd->stack_level + 1); + } + else if (type == LUA_TFUNCTION) { + g_assert_not_reached (); + } + else { + res = lua_tonumber (L, cd->stack_level + 1); + } + + if (res) { + gint first_opt = 2; + + if (lua_type (L, cd->stack_level + 2) == LUA_TNUMBER) { + flag = lua_tonumber (L, cd->stack_level + 2); + /* Shift opt index */ + first_opt = 3; + } + else { + flag = res; + } + + s = rspamd_task_insert_result (task, cd->symbol, flag, NULL); + + if (s) { + guint last_pos = lua_gettop (L); + + for (i = cd->stack_level + first_opt; i <= last_pos; i++) { + if (lua_type (L, i) == LUA_TSTRING) { + const char *opt = lua_tostring (L, i); + + rspamd_task_add_result_option (task, s, opt); + } + else if (lua_type (L, i) == LUA_TTABLE) { + lua_pushvalue (L, i); + + for (lua_pushnil (L); lua_next (L, -2); lua_pop (L, 1)) { + const char *opt = lua_tostring (L, -1); + + rspamd_task_add_result_option (task, s, opt); + } + + lua_pop (L, 1); + } + } + } + + } + + lua_pop (L, nresults); + } + + g_assert (lua_gettop (L) == cd->stack_level); /* we properly cleaned up the stack */ + + cd->stack_level = 0; + rspamd_symcache_item_async_dec_check (task, cd->item, "lua coro symbol"); +} + static gint rspamd_register_symbol_fromlua (lua_State *L, struct rspamd_config *cfg, @@ -1255,6 +1391,10 @@ rspamd_register_symbol_fromlua (lua_State *L, } if (ref != -1) { + if (type & SYMBOL_TYPE_USE_CORO) { + /* Coroutines are incompatible with squeezing */ + no_squeeze = TRUE; + } /* * We call for routine called lua_squeeze_rules.squeeze_rule if it exists */ @@ -1322,15 +1462,27 @@ rspamd_register_symbol_fromlua (lua_State *L, cd->symbol = rspamd_mempool_strdup (cfg->cfg_pool, name); } - ret = rspamd_symcache_add_symbol (cfg->cache, - name, - priority, - lua_metric_symbol_callback, - cd, - type, - parent); + if (type & SYMBOL_TYPE_USE_CORO) { + ret = rspamd_symcache_add_symbol (cfg->cache, + name, + priority, + lua_metric_symbol_callback_coro, + cd, + type, + parent); + } + else { + ret = rspamd_symcache_add_symbol (cfg->cache, + name, + priority, + lua_metric_symbol_callback, + cd, + type, + parent); + } + rspamd_mempool_add_destructor (cfg->cfg_pool, - (rspamd_mempool_destruct_t)lua_destroy_cfg_symbol, + (rspamd_mempool_destruct_t) lua_destroy_cfg_symbol, cd); } } @@ -1346,13 +1498,24 @@ rspamd_register_symbol_fromlua (lua_State *L, cd->symbol = rspamd_mempool_strdup (cfg->cfg_pool, name); } - ret = rspamd_symcache_add_symbol (cfg->cache, - name, - priority, - lua_metric_symbol_callback, - cd, - type, - parent); + if (type & SYMBOL_TYPE_USE_CORO) { + ret = rspamd_symcache_add_symbol (cfg->cache, + name, + priority, + lua_metric_symbol_callback_coro, + cd, + type, + parent); + } + else { + ret = rspamd_symcache_add_symbol (cfg->cache, + name, + priority, + lua_metric_symbol_callback, + cd, + type, + parent); + } rspamd_mempool_add_destructor (cfg->cfg_pool, (rspamd_mempool_destruct_t)lua_destroy_cfg_symbol, cd); @@ -1529,6 +1692,9 @@ lua_parse_symbol_flags (const gchar *str) if (strstr (str, "explicit_disable") != NULL) { ret |= SYMBOL_TYPE_EXPLICIT_DISABLE; } + if (strstr (str, "coro") != NULL) { + ret |= SYMBOL_TYPE_USE_CORO; + } } return ret; @@ -2423,7 +2589,7 @@ lua_config_newindex (lua_State *L) LUA_TRACE_POINT; struct rspamd_config *cfg = lua_check_config (L, 1); const gchar *name; - gint id, nshots; + gint id, nshots, flags = 0; gboolean optional = FALSE, no_squeeze = FALSE; name = luaL_checkstring (L, 2); @@ -2458,6 +2624,7 @@ lua_config_newindex (lua_State *L) * "weight" - optional weight * "priority" - optional priority * "type" - optional type (normal, virtual, callback) + * "flags" - optional flags * -- Metric options * "score" - optional default score (overridden by metric) * "group" - optional default group @@ -2510,6 +2677,15 @@ lua_config_newindex (lua_State *L) } lua_pop (L, 1); + lua_pushstring (L, "flags"); + lua_gettable (L, -2); + + if (lua_type (L, -1) == LUA_TSTRING) { + type_str = lua_tostring (L, -1); + type |= lua_parse_symbol_flags (type_str); + } + lua_pop (L, 1); + lua_pushstring (L, "condition"); lua_gettable (L, -2); -- 2.39.5