]> source.dussan.org Git - rspamd.git/commitdiff
[Project] coroutine threaded model for API calls: thread pool
authorMikhail Galanin <mgalanin@mimecast.com>
Wed, 8 Aug 2018 08:01:49 +0000 (09:01 +0100)
committerMikhail Galanin <mgalanin@mimecast.com>
Wed, 8 Aug 2018 08:06:27 +0000 (09:06 +0100)
src/libserver/cfg_file.h
src/libserver/cfg_utils.c
src/lua/CMakeLists.txt
src/lua/lua_config.c
src/lua/lua_thread_pool.c [new file with mode: 0644]
src/lua/lua_thread_pool.h [new file with mode: 0644]

index 75b404530fd74c583a2c1fae9ef55e321c730754..c583766c44ef928e1dd520df91007c0c109844c7 100644 (file)
@@ -380,6 +380,7 @@ struct rspamd_config {
        gchar * checksum;                               /**< real checksum of config file                                               */
        gchar * dump_checksum;                          /**< dump checksum of config file                                               */
        gpointer lua_state;                             /**< pointer to lua state                                                               */
+       gpointer lua_thread_pool;                       /**< pointer to lua thread (coroutine) pool                             */
 
        gchar * rrd_file;                               /**< rrd file to store statistics                                               */
        gchar * history_file;                           /**< file to save rolling history                                               */
index b7b9dfdee3bd990a9251930de76fd7c6462daa87..016556912253f9697bef0d4bf5b41d2396e15794 100644 (file)
@@ -20,6 +20,7 @@
 #include "uthash_strcase.h"
 #include "filter.h"
 #include "lua/lua_common.h"
+#include "lua/lua_thread_pool.h"
 #include "map.h"
 #include "map_helpers.h"
 #include "map_private.h"
@@ -175,6 +176,7 @@ rspamd_config_new (enum rspamd_config_init_flags flags)
        if (!(flags & RSPAMD_CONFIG_INIT_SKIP_LUA)) {
                cfg->lua_state = rspamd_lua_init ();
                cfg->own_lua_state = TRUE;
+               cfg->lua_thread_pool = lua_thread_pool_new (cfg->lua_state);
        }
 
        cfg->cache = rspamd_symbols_cache_new (cfg);
@@ -259,6 +261,7 @@ rspamd_config_free (struct rspamd_config *cfg)
        g_ptr_array_free (cfg->c_modules, TRUE);
 
        if (cfg->lua_state && cfg->own_lua_state) {
+               lua_thread_pool_free (cfg->lua_thread_pool);
                lua_close (cfg->lua_state);
        }
        REF_RELEASE (cfg->libs_ctx);
index 9c561c0e43c6e72b488f3c9620660c7f543390b4..ffc4b27ca86b469a25ace19f10a1090dd9198ff3 100644 (file)
@@ -25,6 +25,7 @@ SET(LUASRC                      ${CMAKE_CURRENT_SOURCE_DIR}/lua_common.c
                                          ${CMAKE_CURRENT_SOURCE_DIR}/lua_fann.c
                                          ${CMAKE_CURRENT_SOURCE_DIR}/lua_sqlite3.c
                                          ${CMAKE_CURRENT_SOURCE_DIR}/lua_cryptobox.c
-                                         ${CMAKE_CURRENT_SOURCE_DIR}/lua_map.c)
+                                         ${CMAKE_CURRENT_SOURCE_DIR}/lua_map.c
+                                         ${CMAKE_CURRENT_SOURCE_DIR}/lua_thread_pool.c)
 
 SET(RSPAMD_LUA ${LUASRC} PARENT_SCOPE)
\ No newline at end of file
index 2093cbe018121079ad9a2c75bf0ab873d9a531a8..2ee7d072d9ff2722b9ea681382627e522ee6c7ec 100644 (file)
@@ -19,6 +19,7 @@
 #include "libserver/composites.h"
 #include "libmime/lang_detection.h"
 #include "lua/lua_map.h"
+#include "lua/lua_thread_pool.h"
 #include "utlist.h"
 #include <math.h>
 
@@ -1036,6 +1037,8 @@ struct lua_callback_data {
                gint ref;
        } callback;
        gboolean cb_is_ref;
+       gpointer thread_entry;
+       gint stack_level;
        gint order;
 };
 
@@ -1184,43 +1187,69 @@ lua_watcher_callback (gpointer session_data, gpointer ud)
        lua_settop (L, err_idx - 1);
 }
 
+static void
+lua_metric_symbol_callback_return (struct rspamd_task *task, gpointer ud, gint ret);
+
 static void
 lua_metric_symbol_callback (struct rspamd_task *task, gpointer ud)
 {
        struct lua_callback_data *cd = ud;
        struct rspamd_task **ptask;
-       gint level = lua_gettop (cd->L), nresults, err_idx, ret;
-       lua_State *L = cd->L;
-       GString *tb;
-       struct rspamd_symbol_result *s;
+       gint ret;
 
-       lua_pushcfunction (L, &rspamd_lua_traceback);
-       err_idx = lua_gettop (L);
+       struct thread_entry *thread_entry = lua_thread_pool_get(task->cfg->lua_thread_pool);
+       cd->thread_entry = thread_entry;
 
-       level ++;
+       lua_State *thread = thread_entry->lua_state;
+       cd->stack_level = lua_gettop (cd->L);
 
        if (cd->cb_is_ref) {
-               lua_rawgeti (L, LUA_REGISTRYINDEX, cd->callback.ref);
+               lua_rawgeti (thread, LUA_REGISTRYINDEX, cd->callback.ref);
        }
        else {
-               lua_getglobal (L, cd->callback.name);
+               lua_getglobal (thread, cd->callback.name);
        }
 
-       ptask = lua_newuserdata (L, sizeof (struct rspamd_task *));
-       rspamd_lua_setclass (L, "rspamd{task}", -1);
+       ptask = lua_newuserdata (thread, sizeof (struct rspamd_task *));
+       rspamd_lua_setclass (thread, "rspamd{task}", -1);
        *ptask = task;
 
-       if ((ret = lua_pcall (L, 1, LUA_MULTRET, err_idx)) != 0) {
-               tb = lua_touserdata (L, -1);
+       ret = lua_resume (thread, 1);
+
+       if (ret == LUA_YIELD) {
+               msg_err_task ("LUA_YIELD");
+       } else {
+               lua_metric_symbol_callback_return (task, ud, ret);
+       }
+}
+
+static void
+lua_metric_symbol_callback_return (struct rspamd_task *task, gpointer ud, gint ret)
+{
+       GString *tb;
+       struct lua_callback_data *cd = ud;
+       int nresults;
+       struct rspamd_symbol_result *s;
+       struct thread_entry *thread_entry = cd->thread_entry;
+       lua_State *thread = thread_entry->lua_state;
+
+       if (ret != 0) {
+               lua_pushcfunction (thread, rspamd_lua_traceback);
+               lua_call (thread, 0, 1);
+
+               tb = lua_touserdata (thread, -1);
                msg_err_task ("call to (%s) failed (%d): %v", cd->symbol, ret, tb);
 
                if (tb) {
                        g_string_free (tb, TRUE);
-                       lua_pop (L, 1);
+                       lua_pop (thread, 1);
                }
+               assert (lua_gettop (thread) >= cd->stack_level);
+               // maybe there is a way to recover here. For now, just remove foulty thread
+               lua_thread_pool_terminate_entry (task->cfg->lua_thread_pool, cd->thread_entry);
        }
        else {
-               nresults = lua_gettop (L) - level;
+               nresults = lua_gettop (thread) - cd->stack_level;
 
                if (nresults >= 1) {
                        /* Function returned boolean, so maybe we need to insert result? */
@@ -1230,16 +1259,16 @@ lua_metric_symbol_callback (struct rspamd_task *task, gpointer ud)
                        gint type;
                        struct lua_watcher_data *wd;
 
-                       type = lua_type (cd->L, level + 1);
+                       type = lua_type (thread, cd->stack_level + 1);
 
                        if (type == LUA_TBOOLEAN) {
-                               res = lua_toboolean (L, level + 1);
+                               res = lua_toboolean (thread, cd->stack_level + 1);
                        }
                        else if (type == LUA_TFUNCTION) {
                                /* Function returned a closure that should be watched for */
                                wd = rspamd_mempool_alloc (task->task_pool, sizeof (*wd));
-                               lua_pushvalue (cd->L, level + 1);
-                               wd->cb_ref = luaL_ref (L, LUA_REGISTRYINDEX);
+                               lua_pushvalue (thread /*cd->L*/, cd->stack_level + 1);
+                               wd->cb_ref = luaL_ref (thread, LUA_REGISTRYINDEX);
                                wd->cbd = cd;
                                rspamd_session_watcher_push_callback (task->s,
                                                rspamd_session_get_watcher (task->s),
@@ -1252,14 +1281,14 @@ lua_metric_symbol_callback (struct rspamd_task *task, gpointer ud)
                                                rspamd_session_get_watcher (task->s));
                        }
                        else {
-                               res = lua_tonumber (L, level + 1);
+                               res = lua_tonumber (thread, cd->stack_level + 1);
                        }
 
                        if (res) {
                                gint first_opt = 2;
 
-                               if (lua_type (L, level + 2) == LUA_TNUMBER) {
-                                       flag = lua_tonumber (L, level + 2);
+                               if (lua_type (thread, cd->stack_level + 2) == LUA_TNUMBER) {
+                                       flag = lua_tonumber (thread, cd->stack_level + 2);
                                        /* Shift opt index */
                                        first_opt = 3;
                                }
@@ -1270,35 +1299,40 @@ lua_metric_symbol_callback (struct rspamd_task *task, gpointer ud)
                                s = rspamd_task_insert_result (task, cd->symbol, flag, NULL);
 
                                if (s) {
-                                       guint last_pos = lua_gettop (L);
+                                       guint last_pos = lua_gettop (thread);
 
-                                       for (i = level + first_opt; i <= last_pos; i++) {
-                                               if (lua_type (L, i) == LUA_TSTRING) {
-                                                       const char *opt = lua_tostring (L, i);
+                                       for (i = cd->stack_level + first_opt; i <= last_pos; i++) {
+                                               if (lua_type (thread, i) == LUA_TSTRING) {
+                                                       const char *opt = lua_tostring (thread, i);
 
                                                        rspamd_task_add_result_option (task, s, opt);
                                                }
-                                               else if (lua_type (L, i) == LUA_TTABLE) {
-                                                       lua_pushvalue (L, i);
+                                               else if (lua_type (thread, i) == LUA_TTABLE) {
+                                                       lua_pushvalue (thread, i);
 
-                                                       for (lua_pushnil (L); lua_next (L, -2); lua_pop (L, 1)) {
-                                                               const char *opt = lua_tostring (L, -1);
+                                                       for (lua_pushnil (thread); lua_next (thread, -2); lua_pop (thread, 1)) {
+                                                               const char *opt = lua_tostring (thread, -1);
 
                                                                rspamd_task_add_result_option (task, s, opt);
                                                        }
 
-                                                       lua_pop (L, 1);
+                                                       lua_pop (thread, 1);
                                                }
                                        }
                                }
 
                        }
 
-                       lua_pop (L, nresults);
+                       lua_pop (thread, nresults);
                }
+
+               assert (lua_gettop (thread) == cd->stack_level); /* we properly cleaned up the stack */
+
+               lua_thread_pool_return(task->cfg->lua_thread_pool, cd->thread_entry);
        }
 
-       lua_pop (L, 1); /* Error function */
+       cd->thread_entry = NULL;
+       cd->stack_level = 0;
 }
 
 static gint
diff --git a/src/lua/lua_thread_pool.c b/src/lua/lua_thread_pool.c
new file mode 100644 (file)
index 0000000..6effbe7
--- /dev/null
@@ -0,0 +1,125 @@
+#include "config.h"
+
+#include <assert.h>
+
+#include "lua_common.h"
+#include "lua_thread_pool.h"
+
+struct lua_thread_pool {
+       GQueue *available_items;
+       lua_State *L;
+       gint max_items;
+       struct thread_entry *running_entry;
+};
+
+static struct thread_entry *
+thread_entry_new (lua_State * L)
+{
+       struct thread_entry *ent;
+       ent = g_malloc (sizeof *ent);
+       ent->lua_state = lua_newthread (L);
+       ent->thread_index = luaL_ref (L, LUA_REGISTRYINDEX);
+       return ent;
+}
+
+static void
+thread_entry_free (lua_State * L, struct thread_entry *ent)
+{
+       luaL_unref (L, LUA_REGISTRYINDEX, ent->thread_index);
+       g_free (ent);
+}
+
+struct lua_thread_pool *
+lua_thread_pool_new (lua_State * L)
+{
+       struct lua_thread_pool * pool = g_new0 (struct lua_thread_pool, 1);
+
+       pool->L = L;
+       pool->max_items = 100;
+
+       pool->available_items = g_queue_new ();
+       int i;
+
+       struct thread_entry *ent;
+       for (i = 0; i < pool->max_items; i ++) {
+               ent = thread_entry_new (pool->L);
+               g_queue_push_head (pool->available_items, ent);
+       }
+
+       return pool;
+}
+
+void
+lua_thread_pool_free (struct lua_thread_pool *pool)
+{
+       struct thread_entry *ent = NULL;
+       while (!g_queue_is_empty (pool->available_items)) {
+               ent = g_queue_pop_head (pool->available_items);
+               thread_entry_free (pool->L, ent);
+       }
+       g_queue_free (pool->available_items);
+       g_free (pool);
+}
+
+struct thread_entry *
+lua_thread_pool_get(struct lua_thread_pool *pool)
+{
+       gpointer cur;
+       struct thread_entry *ent = NULL;
+
+       cur = g_queue_pop_head (pool->available_items);
+
+       if (cur) {
+               ent = cur;
+       }
+       else {
+               ent = thread_entry_new (pool->L);
+       }
+       return ent;
+}
+
+void
+lua_thread_pool_return(struct lua_thread_pool *pool, struct thread_entry *thread_entry)
+{
+       assert (lua_status (thread_entry->lua_state) == 0); // we can't return a running/yielded stack into the pool
+       if (pool->running_entry == thread_entry) {
+               pool->running_entry = NULL;
+       }
+       if (g_queue_get_length (pool->available_items) <= pool->max_items) {
+               g_queue_push_head (pool->available_items, thread_entry);
+       }
+       else {
+               thread_entry_free (pool->L, thread_entry);
+       }
+}
+
+void
+lua_thread_pool_terminate_entry(struct lua_thread_pool *pool, struct thread_entry *thread_entry)
+{
+       struct thread_entry *ent = NULL;
+
+       if (pool->running_entry == thread_entry) {
+               pool->running_entry = NULL;
+       }
+
+       // we should only terminate failed threads
+       assert (lua_status (thread_entry->lua_state) != 0 && lua_status (thread_entry->lua_state) != LUA_YIELD);
+       thread_entry_free (pool->L, thread_entry);
+
+       if (g_queue_get_length (pool->available_items) <= pool->max_items) {
+               ent = thread_entry_new (pool->L);
+               g_queue_push_head (pool->available_items, ent);
+       }
+}
+
+struct thread_entry *
+lua_thread_pool_get_running_entry(struct lua_thread_pool *pool)
+{
+       return pool->running_entry;
+}
+
+void
+lua_thread_pool_set_running_entry(struct lua_thread_pool *pool, struct thread_entry *thread_entry)
+{
+       pool->running_entry = thread_entry;
+}
diff --git a/src/lua/lua_thread_pool.h b/src/lua/lua_thread_pool.h
new file mode 100644 (file)
index 0000000..01643df
--- /dev/null
@@ -0,0 +1,35 @@
+#ifndef LUA_THREAD_POOL_H_
+#define LUA_THREAD_POOL_H_
+
+#include <lua.h>
+
+struct thread_entry {
+       lua_State *lua_state;
+       gint thread_index;
+};
+
+struct thread_pool;
+
+struct lua_thread_pool *
+lua_thread_pool_new (lua_State * L);
+
+void
+lua_thread_pool_free (struct lua_thread_pool *pool);
+
+struct thread_entry *
+lua_thread_pool_get(struct lua_thread_pool *pool);
+
+void
+lua_thread_pool_return(struct lua_thread_pool *pool, struct thread_entry *thread_entry);
+
+void
+lua_thread_pool_terminate_entry(struct lua_thread_pool *pool, struct thread_entry *thread_entry);
+
+struct thread_entry *
+lua_thread_pool_get_running_entry(struct lua_thread_pool *pool);
+
+void
+lua_thread_pool_set_running_entry(struct lua_thread_pool *pool, struct thread_entry *thread_entry);
+
+#endif /* LUA_THREAD_POOL_H_ */
+