#include "libserver/composites.h"
#include "libmime/lang_detection.h"
#include "lua/lua_map.h"
+#include "lua/lua_thread_pool.h"
#include "utlist.h"
#include <math.h>
gint ref;
} callback;
gboolean cb_is_ref;
+ gpointer thread_entry;
+ gint stack_level;
gint order;
};
lua_settop (L, err_idx - 1);
}
+static void
+lua_metric_symbol_callback_return (struct rspamd_task *task, gpointer ud, gint ret);
+
static void
lua_metric_symbol_callback (struct rspamd_task *task, gpointer ud)
{
struct lua_callback_data *cd = ud;
struct rspamd_task **ptask;
- gint level = lua_gettop (cd->L), nresults, err_idx, ret;
- lua_State *L = cd->L;
- GString *tb;
- struct rspamd_symbol_result *s;
+ gint ret;
- lua_pushcfunction (L, &rspamd_lua_traceback);
- err_idx = lua_gettop (L);
+ struct thread_entry *thread_entry = lua_thread_pool_get(task->cfg->lua_thread_pool);
+ cd->thread_entry = thread_entry;
- level ++;
+ lua_State *thread = thread_entry->lua_state;
+ cd->stack_level = lua_gettop (cd->L);
if (cd->cb_is_ref) {
- lua_rawgeti (L, LUA_REGISTRYINDEX, cd->callback.ref);
+ lua_rawgeti (thread, LUA_REGISTRYINDEX, cd->callback.ref);
}
else {
- lua_getglobal (L, cd->callback.name);
+ lua_getglobal (thread, cd->callback.name);
}
- ptask = lua_newuserdata (L, sizeof (struct rspamd_task *));
- rspamd_lua_setclass (L, "rspamd{task}", -1);
+ ptask = lua_newuserdata (thread, sizeof (struct rspamd_task *));
+ rspamd_lua_setclass (thread, "rspamd{task}", -1);
*ptask = task;
- if ((ret = lua_pcall (L, 1, LUA_MULTRET, err_idx)) != 0) {
- tb = lua_touserdata (L, -1);
+ ret = lua_resume (thread, 1);
+
+ if (ret == LUA_YIELD) {
+ msg_err_task ("LUA_YIELD");
+ } else {
+ lua_metric_symbol_callback_return (task, ud, ret);
+ }
+}
+
+static void
+lua_metric_symbol_callback_return (struct rspamd_task *task, gpointer ud, gint ret)
+{
+ GString *tb;
+ struct lua_callback_data *cd = ud;
+ int nresults;
+ struct rspamd_symbol_result *s;
+ struct thread_entry *thread_entry = cd->thread_entry;
+ lua_State *thread = thread_entry->lua_state;
+
+ if (ret != 0) {
+ lua_pushcfunction (thread, rspamd_lua_traceback);
+ lua_call (thread, 0, 1);
+
+ tb = lua_touserdata (thread, -1);
msg_err_task ("call to (%s) failed (%d): %v", cd->symbol, ret, tb);
if (tb) {
g_string_free (tb, TRUE);
- lua_pop (L, 1);
+ lua_pop (thread, 1);
}
+ assert (lua_gettop (thread) >= cd->stack_level);
+ // maybe there is a way to recover here. For now, just remove foulty thread
+ lua_thread_pool_terminate_entry (task->cfg->lua_thread_pool, cd->thread_entry);
}
else {
- nresults = lua_gettop (L) - level;
+ nresults = lua_gettop (thread) - cd->stack_level;
if (nresults >= 1) {
/* Function returned boolean, so maybe we need to insert result? */
gint type;
struct lua_watcher_data *wd;
- type = lua_type (cd->L, level + 1);
+ type = lua_type (thread, cd->stack_level + 1);
if (type == LUA_TBOOLEAN) {
- res = lua_toboolean (L, level + 1);
+ res = lua_toboolean (thread, cd->stack_level + 1);
}
else if (type == LUA_TFUNCTION) {
/* Function returned a closure that should be watched for */
wd = rspamd_mempool_alloc (task->task_pool, sizeof (*wd));
- lua_pushvalue (cd->L, level + 1);
- wd->cb_ref = luaL_ref (L, LUA_REGISTRYINDEX);
+ lua_pushvalue (thread /*cd->L*/, cd->stack_level + 1);
+ wd->cb_ref = luaL_ref (thread, LUA_REGISTRYINDEX);
wd->cbd = cd;
rspamd_session_watcher_push_callback (task->s,
rspamd_session_get_watcher (task->s),
rspamd_session_get_watcher (task->s));
}
else {
- res = lua_tonumber (L, level + 1);
+ res = lua_tonumber (thread, cd->stack_level + 1);
}
if (res) {
gint first_opt = 2;
- if (lua_type (L, level + 2) == LUA_TNUMBER) {
- flag = lua_tonumber (L, level + 2);
+ if (lua_type (thread, cd->stack_level + 2) == LUA_TNUMBER) {
+ flag = lua_tonumber (thread, cd->stack_level + 2);
/* Shift opt index */
first_opt = 3;
}
s = rspamd_task_insert_result (task, cd->symbol, flag, NULL);
if (s) {
- guint last_pos = lua_gettop (L);
+ guint last_pos = lua_gettop (thread);
- for (i = level + first_opt; i <= last_pos; i++) {
- if (lua_type (L, i) == LUA_TSTRING) {
- const char *opt = lua_tostring (L, i);
+ for (i = cd->stack_level + first_opt; i <= last_pos; i++) {
+ if (lua_type (thread, i) == LUA_TSTRING) {
+ const char *opt = lua_tostring (thread, i);
rspamd_task_add_result_option (task, s, opt);
}
- else if (lua_type (L, i) == LUA_TTABLE) {
- lua_pushvalue (L, i);
+ else if (lua_type (thread, i) == LUA_TTABLE) {
+ lua_pushvalue (thread, i);
- for (lua_pushnil (L); lua_next (L, -2); lua_pop (L, 1)) {
- const char *opt = lua_tostring (L, -1);
+ for (lua_pushnil (thread); lua_next (thread, -2); lua_pop (thread, 1)) {
+ const char *opt = lua_tostring (thread, -1);
rspamd_task_add_result_option (task, s, opt);
}
- lua_pop (L, 1);
+ lua_pop (thread, 1);
}
}
}
}
- lua_pop (L, nresults);
+ lua_pop (thread, nresults);
}
+
+ assert (lua_gettop (thread) == cd->stack_level); /* we properly cleaned up the stack */
+
+ lua_thread_pool_return(task->cfg->lua_thread_pool, cd->thread_entry);
}
- lua_pop (L, 1); /* Error function */
+ cd->thread_entry = NULL;
+ cd->stack_level = 0;
}
static gint
--- /dev/null
+#include "config.h"
+
+#include <assert.h>
+
+#include "lua_common.h"
+#include "lua_thread_pool.h"
+
+struct lua_thread_pool {
+ GQueue *available_items;
+ lua_State *L;
+ gint max_items;
+ struct thread_entry *running_entry;
+};
+
+static struct thread_entry *
+thread_entry_new (lua_State * L)
+{
+ struct thread_entry *ent;
+ ent = g_malloc (sizeof *ent);
+ ent->lua_state = lua_newthread (L);
+ ent->thread_index = luaL_ref (L, LUA_REGISTRYINDEX);
+ return ent;
+}
+
+static void
+thread_entry_free (lua_State * L, struct thread_entry *ent)
+{
+ luaL_unref (L, LUA_REGISTRYINDEX, ent->thread_index);
+ g_free (ent);
+}
+
+struct lua_thread_pool *
+lua_thread_pool_new (lua_State * L)
+{
+ struct lua_thread_pool * pool = g_new0 (struct lua_thread_pool, 1);
+
+ pool->L = L;
+ pool->max_items = 100;
+
+ pool->available_items = g_queue_new ();
+ int i;
+
+ struct thread_entry *ent;
+ for (i = 0; i < pool->max_items; i ++) {
+ ent = thread_entry_new (pool->L);
+ g_queue_push_head (pool->available_items, ent);
+ }
+
+ return pool;
+}
+
+void
+lua_thread_pool_free (struct lua_thread_pool *pool)
+{
+ struct thread_entry *ent = NULL;
+ while (!g_queue_is_empty (pool->available_items)) {
+ ent = g_queue_pop_head (pool->available_items);
+ thread_entry_free (pool->L, ent);
+ }
+ g_queue_free (pool->available_items);
+ g_free (pool);
+}
+
+struct thread_entry *
+lua_thread_pool_get(struct lua_thread_pool *pool)
+{
+ gpointer cur;
+ struct thread_entry *ent = NULL;
+
+ cur = g_queue_pop_head (pool->available_items);
+
+ if (cur) {
+ ent = cur;
+ }
+ else {
+ ent = thread_entry_new (pool->L);
+ }
+ return ent;
+}
+
+void
+lua_thread_pool_return(struct lua_thread_pool *pool, struct thread_entry *thread_entry)
+{
+ assert (lua_status (thread_entry->lua_state) == 0); // we can't return a running/yielded stack into the pool
+ if (pool->running_entry == thread_entry) {
+ pool->running_entry = NULL;
+ }
+ if (g_queue_get_length (pool->available_items) <= pool->max_items) {
+ g_queue_push_head (pool->available_items, thread_entry);
+ }
+ else {
+ thread_entry_free (pool->L, thread_entry);
+ }
+}
+
+void
+lua_thread_pool_terminate_entry(struct lua_thread_pool *pool, struct thread_entry *thread_entry)
+{
+ struct thread_entry *ent = NULL;
+
+ if (pool->running_entry == thread_entry) {
+ pool->running_entry = NULL;
+ }
+
+ // we should only terminate failed threads
+ assert (lua_status (thread_entry->lua_state) != 0 && lua_status (thread_entry->lua_state) != LUA_YIELD);
+ thread_entry_free (pool->L, thread_entry);
+
+ if (g_queue_get_length (pool->available_items) <= pool->max_items) {
+ ent = thread_entry_new (pool->L);
+ g_queue_push_head (pool->available_items, ent);
+ }
+}
+
+struct thread_entry *
+lua_thread_pool_get_running_entry(struct lua_thread_pool *pool)
+{
+ return pool->running_entry;
+}
+
+void
+lua_thread_pool_set_running_entry(struct lua_thread_pool *pool, struct thread_entry *thread_entry)
+{
+ pool->running_entry = thread_entry;
+}