]> source.dussan.org Git - rspamd.git/commitdiff
[Rework] Enable explicit coroutines symbols
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 1 Mar 2019 09:59:53 +0000 (09:59 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 1 Mar 2019 09:59:53 +0000 (09:59 +0000)
src/libserver/rspamd_symcache.h
src/lua/lua_config.c

index 69eac1f01fe3298db23b6e35ed2b2865542abf25..a038d6a9d016199b0e23ef379f330d078229aa71 100644 (file)
@@ -50,6 +50,7 @@ enum rspamd_symbol_type {
        SYMBOL_TYPE_MIME_ONLY = (1 << 15), /* Symbol is mime only */
        SYMBOL_TYPE_EXPLICIT_DISABLE = (1 << 16), /* Symbol should be disabled explicitly only */
        SYMBOL_TYPE_IGNORE_PASSTHROUGH = (1 << 17), /* Symbol ignores passthrough result */
+       SYMBOL_TYPE_USE_CORO = (1 << 18), /* Symbol uses lua coroutines */
 };
 
 /**
index b5e668e717c69199ac1b8c59b2f06a88046c6cf5..9d7f3341cb06d6a5988b1dd84bad697ae50fbbe7 100644 (file)
@@ -1222,6 +1222,142 @@ lua_metric_symbol_callback (struct rspamd_task *task,
        g_assert (lua_gettop (L) == level - 1);
 }
 
+static void lua_metric_symbol_callback_return (struct thread_entry *thread_entry,
+                                                                                          int ret);
+
+static void lua_metric_symbol_callback_error (struct thread_entry *thread_entry,
+                                                                                         int ret,
+                                                                                         const char *msg);
+
+static void
+lua_metric_symbol_callback_coro (struct rspamd_task *task,
+                                                       struct rspamd_symcache_item *item,
+                                                       gpointer ud)
+{
+       struct lua_callback_data *cd = ud;
+       struct rspamd_task **ptask;
+       struct thread_entry *thread_entry;
+
+       rspamd_symcache_item_async_inc (task, item, "lua coro symbol");
+       thread_entry = lua_thread_pool_get_for_task (task);
+
+       g_assert(thread_entry->cd == NULL);
+       thread_entry->cd = cd;
+
+       lua_State *thread = thread_entry->lua_state;
+       cd->stack_level = lua_gettop (thread);
+       cd->item = item;
+
+       if (cd->cb_is_ref) {
+               lua_rawgeti (thread, LUA_REGISTRYINDEX, cd->callback.ref);
+       }
+       else {
+               lua_getglobal (thread, cd->callback.name);
+       }
+
+       ptask = lua_newuserdata (thread, sizeof (struct rspamd_task *));
+       rspamd_lua_setclass (thread, "rspamd{task}", -1);
+       *ptask = task;
+
+       thread_entry->finish_callback = lua_metric_symbol_callback_return;
+       thread_entry->error_callback = lua_metric_symbol_callback_error;
+
+       lua_thread_call (thread_entry, 1);
+}
+
+static void
+lua_metric_symbol_callback_error (struct thread_entry *thread_entry,
+                                                                 int ret,
+                                                                 const char *msg)
+{
+       struct lua_callback_data *cd = thread_entry->cd;
+       struct rspamd_task *task = thread_entry->task;
+       msg_err_task ("call to coroutine (%s) failed (%d): %s", cd->symbol, ret, msg);
+
+       rspamd_symcache_item_async_dec_check (task, cd->item, "lua coro symbol");
+}
+
+static void
+lua_metric_symbol_callback_return (struct thread_entry *thread_entry, int ret)
+{
+       struct lua_callback_data *cd = thread_entry->cd;
+       struct rspamd_task *task = thread_entry->task;
+       int nresults;
+       struct rspamd_symbol_result *s;
+
+       (void)ret;
+
+       lua_State *L = thread_entry->lua_state;
+
+       nresults = lua_gettop (L) - cd->stack_level;
+
+       if (nresults >= 1) {
+               /* Function returned boolean, so maybe we need to insert result? */
+               gint res = 0;
+               gint i;
+               gdouble flag = 1.0;
+               gint type;
+
+               type = lua_type (L, cd->stack_level + 1);
+
+               if (type == LUA_TBOOLEAN) {
+                       res = lua_toboolean (L, cd->stack_level + 1);
+               }
+               else if (type == LUA_TFUNCTION) {
+                       g_assert_not_reached ();
+               }
+               else {
+                       res = lua_tonumber (L, cd->stack_level + 1);
+               }
+
+               if (res) {
+                       gint first_opt = 2;
+
+                       if (lua_type (L, cd->stack_level + 2) == LUA_TNUMBER) {
+                               flag = lua_tonumber (L, cd->stack_level + 2);
+                               /* Shift opt index */
+                               first_opt = 3;
+                       }
+                       else {
+                               flag = res;
+                       }
+
+                       s = rspamd_task_insert_result (task, cd->symbol, flag, NULL);
+
+                       if (s) {
+                               guint last_pos = lua_gettop (L);
+
+                               for (i = cd->stack_level + first_opt; i <= last_pos; i++) {
+                                       if (lua_type (L, i) == LUA_TSTRING) {
+                                               const char *opt = lua_tostring (L, i);
+
+                                               rspamd_task_add_result_option (task, s, opt);
+                                       }
+                                       else if (lua_type (L, i) == LUA_TTABLE) {
+                                               lua_pushvalue (L, i);
+
+                                               for (lua_pushnil (L); lua_next (L, -2); lua_pop (L, 1)) {
+                                                       const char *opt = lua_tostring (L, -1);
+
+                                                       rspamd_task_add_result_option (task, s, opt);
+                                               }
+
+                                               lua_pop (L, 1);
+                                       }
+                               }
+                       }
+
+               }
+
+               lua_pop (L, nresults);
+       }
+
+       g_assert (lua_gettop (L) == cd->stack_level); /* we properly cleaned up the stack */
+
+       cd->stack_level = 0;
+       rspamd_symcache_item_async_dec_check (task, cd->item, "lua coro symbol");
+}
+
 static gint
 rspamd_register_symbol_fromlua (lua_State *L,
                struct rspamd_config *cfg,
@@ -1255,6 +1391,10 @@ rspamd_register_symbol_fromlua (lua_State *L,
        }
 
        if (ref != -1) {
+               if (type & SYMBOL_TYPE_USE_CORO) {
+                       /* Coroutines are incompatible with squeezing */
+                       no_squeeze = TRUE;
+               }
                /*
                 * We call for routine called lua_squeeze_rules.squeeze_rule if it exists
                 */
@@ -1322,15 +1462,27 @@ rspamd_register_symbol_fromlua (lua_State *L,
                                        cd->symbol = rspamd_mempool_strdup (cfg->cfg_pool, name);
                                }
 
-                               ret = rspamd_symcache_add_symbol (cfg->cache,
-                                               name,
-                                               priority,
-                                               lua_metric_symbol_callback,
-                                               cd,
-                                               type,
-                                               parent);
+                               if (type & SYMBOL_TYPE_USE_CORO) {
+                                       ret = rspamd_symcache_add_symbol (cfg->cache,
+                                                       name,
+                                                       priority,
+                                                       lua_metric_symbol_callback_coro,
+                                                       cd,
+                                                       type,
+                                                       parent);
+                               }
+                               else {
+                                       ret = rspamd_symcache_add_symbol (cfg->cache,
+                                                       name,
+                                                       priority,
+                                                       lua_metric_symbol_callback,
+                                                       cd,
+                                                       type,
+                                                       parent);
+                               }
+
                                rspamd_mempool_add_destructor (cfg->cfg_pool,
-                                               (rspamd_mempool_destruct_t)lua_destroy_cfg_symbol,
+                                               (rspamd_mempool_destruct_t) lua_destroy_cfg_symbol,
                                                cd);
                        }
                }
@@ -1346,13 +1498,24 @@ rspamd_register_symbol_fromlua (lua_State *L,
                                cd->symbol = rspamd_mempool_strdup (cfg->cfg_pool, name);
                        }
 
-                       ret = rspamd_symcache_add_symbol (cfg->cache,
-                                       name,
-                                       priority,
-                                       lua_metric_symbol_callback,
-                                       cd,
-                                       type,
-                                       parent);
+                       if (type & SYMBOL_TYPE_USE_CORO) {
+                               ret = rspamd_symcache_add_symbol (cfg->cache,
+                                               name,
+                                               priority,
+                                               lua_metric_symbol_callback_coro,
+                                               cd,
+                                               type,
+                                               parent);
+                       }
+                       else {
+                               ret = rspamd_symcache_add_symbol (cfg->cache,
+                                               name,
+                                               priority,
+                                               lua_metric_symbol_callback,
+                                               cd,
+                                               type,
+                                               parent);
+                       }
                        rspamd_mempool_add_destructor (cfg->cfg_pool,
                                        (rspamd_mempool_destruct_t)lua_destroy_cfg_symbol,
                                        cd);
@@ -1529,6 +1692,9 @@ lua_parse_symbol_flags (const gchar *str)
                if (strstr (str, "explicit_disable") != NULL) {
                        ret |= SYMBOL_TYPE_EXPLICIT_DISABLE;
                }
+               if (strstr (str, "coro") != NULL) {
+                       ret |= SYMBOL_TYPE_USE_CORO;
+               }
        }
 
        return ret;
@@ -2423,7 +2589,7 @@ lua_config_newindex (lua_State *L)
        LUA_TRACE_POINT;
        struct rspamd_config *cfg = lua_check_config (L, 1);
        const gchar *name;
-       gint id, nshots;
+       gint id, nshots, flags = 0;
        gboolean optional = FALSE, no_squeeze = FALSE;
 
        name = luaL_checkstring (L, 2);
@@ -2458,6 +2624,7 @@ lua_config_newindex (lua_State *L)
                         * "weight" - optional weight
                         * "priority" - optional priority
                         * "type" - optional type (normal, virtual, callback)
+                        * "flags" - optional flags
                         * -- Metric options
                         * "score" - optional default score (overridden by metric)
                         * "group" - optional default group
@@ -2510,6 +2677,15 @@ lua_config_newindex (lua_State *L)
                        }
                        lua_pop (L, 1);
 
+                       lua_pushstring (L, "flags");
+                       lua_gettable (L, -2);
+
+                       if (lua_type (L, -1) == LUA_TSTRING) {
+                               type_str = lua_tostring (L, -1);
+                               type |= lua_parse_symbol_flags (type_str);
+                       }
+                       lua_pop (L, 1);
+
                        lua_pushstring (L, "condition");
                        lua_gettable (L, -2);