@@ -380,6 +380,7 @@ struct rspamd_config { | |||
gchar * checksum; /**< real checksum of config file */ | |||
gchar * dump_checksum; /**< dump checksum of config file */ | |||
gpointer lua_state; /**< pointer to lua state */ | |||
gpointer lua_thread_pool; /**< pointer to lua thread (coroutine) pool */ | |||
gchar * rrd_file; /**< rrd file to store statistics */ | |||
gchar * history_file; /**< file to save rolling history */ |
@@ -20,6 +20,7 @@ | |||
#include "uthash_strcase.h" | |||
#include "filter.h" | |||
#include "lua/lua_common.h" | |||
#include "lua/lua_thread_pool.h" | |||
#include "map.h" | |||
#include "map_helpers.h" | |||
#include "map_private.h" | |||
@@ -175,6 +176,7 @@ rspamd_config_new (enum rspamd_config_init_flags flags) | |||
if (!(flags & RSPAMD_CONFIG_INIT_SKIP_LUA)) { | |||
cfg->lua_state = rspamd_lua_init (); | |||
cfg->own_lua_state = TRUE; | |||
cfg->lua_thread_pool = lua_thread_pool_new (cfg->lua_state); | |||
} | |||
cfg->cache = rspamd_symbols_cache_new (cfg); | |||
@@ -259,6 +261,7 @@ rspamd_config_free (struct rspamd_config *cfg) | |||
g_ptr_array_free (cfg->c_modules, TRUE); | |||
if (cfg->lua_state && cfg->own_lua_state) { | |||
lua_thread_pool_free (cfg->lua_thread_pool); | |||
lua_close (cfg->lua_state); | |||
} | |||
REF_RELEASE (cfg->libs_ctx); |
@@ -25,6 +25,7 @@ SET(LUASRC ${CMAKE_CURRENT_SOURCE_DIR}/lua_common.c | |||
${CMAKE_CURRENT_SOURCE_DIR}/lua_fann.c | |||
${CMAKE_CURRENT_SOURCE_DIR}/lua_sqlite3.c | |||
${CMAKE_CURRENT_SOURCE_DIR}/lua_cryptobox.c | |||
${CMAKE_CURRENT_SOURCE_DIR}/lua_map.c) | |||
${CMAKE_CURRENT_SOURCE_DIR}/lua_map.c | |||
${CMAKE_CURRENT_SOURCE_DIR}/lua_thread_pool.c) | |||
SET(RSPAMD_LUA ${LUASRC} PARENT_SCOPE) |
@@ -19,6 +19,7 @@ | |||
#include "libserver/composites.h" | |||
#include "libmime/lang_detection.h" | |||
#include "lua/lua_map.h" | |||
#include "lua/lua_thread_pool.h" | |||
#include "utlist.h" | |||
#include <math.h> | |||
@@ -1036,6 +1037,8 @@ struct lua_callback_data { | |||
gint ref; | |||
} callback; | |||
gboolean cb_is_ref; | |||
gpointer thread_entry; | |||
gint stack_level; | |||
gint order; | |||
}; | |||
@@ -1184,43 +1187,69 @@ lua_watcher_callback (gpointer session_data, gpointer ud) | |||
lua_settop (L, err_idx - 1); | |||
} | |||
static void | |||
lua_metric_symbol_callback_return (struct rspamd_task *task, gpointer ud, gint ret); | |||
static void | |||
lua_metric_symbol_callback (struct rspamd_task *task, gpointer ud) | |||
{ | |||
struct lua_callback_data *cd = ud; | |||
struct rspamd_task **ptask; | |||
gint level = lua_gettop (cd->L), nresults, err_idx, ret; | |||
lua_State *L = cd->L; | |||
GString *tb; | |||
struct rspamd_symbol_result *s; | |||
gint ret; | |||
lua_pushcfunction (L, &rspamd_lua_traceback); | |||
err_idx = lua_gettop (L); | |||
struct thread_entry *thread_entry = lua_thread_pool_get(task->cfg->lua_thread_pool); | |||
cd->thread_entry = thread_entry; | |||
level ++; | |||
lua_State *thread = thread_entry->lua_state; | |||
cd->stack_level = lua_gettop (cd->L); | |||
if (cd->cb_is_ref) { | |||
lua_rawgeti (L, LUA_REGISTRYINDEX, cd->callback.ref); | |||
lua_rawgeti (thread, LUA_REGISTRYINDEX, cd->callback.ref); | |||
} | |||
else { | |||
lua_getglobal (L, cd->callback.name); | |||
lua_getglobal (thread, cd->callback.name); | |||
} | |||
ptask = lua_newuserdata (L, sizeof (struct rspamd_task *)); | |||
rspamd_lua_setclass (L, "rspamd{task}", -1); | |||
ptask = lua_newuserdata (thread, sizeof (struct rspamd_task *)); | |||
rspamd_lua_setclass (thread, "rspamd{task}", -1); | |||
*ptask = task; | |||
if ((ret = lua_pcall (L, 1, LUA_MULTRET, err_idx)) != 0) { | |||
tb = lua_touserdata (L, -1); | |||
ret = lua_resume (thread, 1); | |||
if (ret == LUA_YIELD) { | |||
msg_err_task ("LUA_YIELD"); | |||
} else { | |||
lua_metric_symbol_callback_return (task, ud, ret); | |||
} | |||
} | |||
static void | |||
lua_metric_symbol_callback_return (struct rspamd_task *task, gpointer ud, gint ret) | |||
{ | |||
GString *tb; | |||
struct lua_callback_data *cd = ud; | |||
int nresults; | |||
struct rspamd_symbol_result *s; | |||
struct thread_entry *thread_entry = cd->thread_entry; | |||
lua_State *thread = thread_entry->lua_state; | |||
if (ret != 0) { | |||
lua_pushcfunction (thread, rspamd_lua_traceback); | |||
lua_call (thread, 0, 1); | |||
tb = lua_touserdata (thread, -1); | |||
msg_err_task ("call to (%s) failed (%d): %v", cd->symbol, ret, tb); | |||
if (tb) { | |||
g_string_free (tb, TRUE); | |||
lua_pop (L, 1); | |||
lua_pop (thread, 1); | |||
} | |||
assert (lua_gettop (thread) >= cd->stack_level); | |||
// maybe there is a way to recover here. For now, just remove foulty thread | |||
lua_thread_pool_terminate_entry (task->cfg->lua_thread_pool, cd->thread_entry); | |||
} | |||
else { | |||
nresults = lua_gettop (L) - level; | |||
nresults = lua_gettop (thread) - cd->stack_level; | |||
if (nresults >= 1) { | |||
/* Function returned boolean, so maybe we need to insert result? */ | |||
@@ -1230,16 +1259,16 @@ lua_metric_symbol_callback (struct rspamd_task *task, gpointer ud) | |||
gint type; | |||
struct lua_watcher_data *wd; | |||
type = lua_type (cd->L, level + 1); | |||
type = lua_type (thread, cd->stack_level + 1); | |||
if (type == LUA_TBOOLEAN) { | |||
res = lua_toboolean (L, level + 1); | |||
res = lua_toboolean (thread, cd->stack_level + 1); | |||
} | |||
else if (type == LUA_TFUNCTION) { | |||
/* Function returned a closure that should be watched for */ | |||
wd = rspamd_mempool_alloc (task->task_pool, sizeof (*wd)); | |||
lua_pushvalue (cd->L, level + 1); | |||
wd->cb_ref = luaL_ref (L, LUA_REGISTRYINDEX); | |||
lua_pushvalue (thread /*cd->L*/, cd->stack_level + 1); | |||
wd->cb_ref = luaL_ref (thread, LUA_REGISTRYINDEX); | |||
wd->cbd = cd; | |||
rspamd_session_watcher_push_callback (task->s, | |||
rspamd_session_get_watcher (task->s), | |||
@@ -1252,14 +1281,14 @@ lua_metric_symbol_callback (struct rspamd_task *task, gpointer ud) | |||
rspamd_session_get_watcher (task->s)); | |||
} | |||
else { | |||
res = lua_tonumber (L, level + 1); | |||
res = lua_tonumber (thread, cd->stack_level + 1); | |||
} | |||
if (res) { | |||
gint first_opt = 2; | |||
if (lua_type (L, level + 2) == LUA_TNUMBER) { | |||
flag = lua_tonumber (L, level + 2); | |||
if (lua_type (thread, cd->stack_level + 2) == LUA_TNUMBER) { | |||
flag = lua_tonumber (thread, cd->stack_level + 2); | |||
/* Shift opt index */ | |||
first_opt = 3; | |||
} | |||
@@ -1270,35 +1299,40 @@ lua_metric_symbol_callback (struct rspamd_task *task, gpointer ud) | |||
s = rspamd_task_insert_result (task, cd->symbol, flag, NULL); | |||
if (s) { | |||
guint last_pos = lua_gettop (L); | |||
guint last_pos = lua_gettop (thread); | |||
for (i = level + first_opt; i <= last_pos; i++) { | |||
if (lua_type (L, i) == LUA_TSTRING) { | |||
const char *opt = lua_tostring (L, i); | |||
for (i = cd->stack_level + first_opt; i <= last_pos; i++) { | |||
if (lua_type (thread, i) == LUA_TSTRING) { | |||
const char *opt = lua_tostring (thread, i); | |||
rspamd_task_add_result_option (task, s, opt); | |||
} | |||
else if (lua_type (L, i) == LUA_TTABLE) { | |||
lua_pushvalue (L, i); | |||
else if (lua_type (thread, i) == LUA_TTABLE) { | |||
lua_pushvalue (thread, i); | |||
for (lua_pushnil (L); lua_next (L, -2); lua_pop (L, 1)) { | |||
const char *opt = lua_tostring (L, -1); | |||
for (lua_pushnil (thread); lua_next (thread, -2); lua_pop (thread, 1)) { | |||
const char *opt = lua_tostring (thread, -1); | |||
rspamd_task_add_result_option (task, s, opt); | |||
} | |||
lua_pop (L, 1); | |||
lua_pop (thread, 1); | |||
} | |||
} | |||
} | |||
} | |||
lua_pop (L, nresults); | |||
lua_pop (thread, nresults); | |||
} | |||
assert (lua_gettop (thread) == cd->stack_level); /* we properly cleaned up the stack */ | |||
lua_thread_pool_return(task->cfg->lua_thread_pool, cd->thread_entry); | |||
} | |||
lua_pop (L, 1); /* Error function */ | |||
cd->thread_entry = NULL; | |||
cd->stack_level = 0; | |||
} | |||
static gint |
@@ -0,0 +1,125 @@ | |||
#include "config.h" | |||
#include <assert.h> | |||
#include "lua_common.h" | |||
#include "lua_thread_pool.h" | |||
struct lua_thread_pool { | |||
GQueue *available_items; | |||
lua_State *L; | |||
gint max_items; | |||
struct thread_entry *running_entry; | |||
}; | |||
static struct thread_entry * | |||
thread_entry_new (lua_State * L) | |||
{ | |||
struct thread_entry *ent; | |||
ent = g_malloc (sizeof *ent); | |||
ent->lua_state = lua_newthread (L); | |||
ent->thread_index = luaL_ref (L, LUA_REGISTRYINDEX); | |||
return ent; | |||
} | |||
static void | |||
thread_entry_free (lua_State * L, struct thread_entry *ent) | |||
{ | |||
luaL_unref (L, LUA_REGISTRYINDEX, ent->thread_index); | |||
g_free (ent); | |||
} | |||
struct lua_thread_pool * | |||
lua_thread_pool_new (lua_State * L) | |||
{ | |||
struct lua_thread_pool * pool = g_new0 (struct lua_thread_pool, 1); | |||
pool->L = L; | |||
pool->max_items = 100; | |||
pool->available_items = g_queue_new (); | |||
int i; | |||
struct thread_entry *ent; | |||
for (i = 0; i < pool->max_items; i ++) { | |||
ent = thread_entry_new (pool->L); | |||
g_queue_push_head (pool->available_items, ent); | |||
} | |||
return pool; | |||
} | |||
void | |||
lua_thread_pool_free (struct lua_thread_pool *pool) | |||
{ | |||
struct thread_entry *ent = NULL; | |||
while (!g_queue_is_empty (pool->available_items)) { | |||
ent = g_queue_pop_head (pool->available_items); | |||
thread_entry_free (pool->L, ent); | |||
} | |||
g_queue_free (pool->available_items); | |||
g_free (pool); | |||
} | |||
struct thread_entry * | |||
lua_thread_pool_get(struct lua_thread_pool *pool) | |||
{ | |||
gpointer cur; | |||
struct thread_entry *ent = NULL; | |||
cur = g_queue_pop_head (pool->available_items); | |||
if (cur) { | |||
ent = cur; | |||
} | |||
else { | |||
ent = thread_entry_new (pool->L); | |||
} | |||
return ent; | |||
} | |||
void | |||
lua_thread_pool_return(struct lua_thread_pool *pool, struct thread_entry *thread_entry) | |||
{ | |||
assert (lua_status (thread_entry->lua_state) == 0); // we can't return a running/yielded stack into the pool | |||
if (pool->running_entry == thread_entry) { | |||
pool->running_entry = NULL; | |||
} | |||
if (g_queue_get_length (pool->available_items) <= pool->max_items) { | |||
g_queue_push_head (pool->available_items, thread_entry); | |||
} | |||
else { | |||
thread_entry_free (pool->L, thread_entry); | |||
} | |||
} | |||
void | |||
lua_thread_pool_terminate_entry(struct lua_thread_pool *pool, struct thread_entry *thread_entry) | |||
{ | |||
struct thread_entry *ent = NULL; | |||
if (pool->running_entry == thread_entry) { | |||
pool->running_entry = NULL; | |||
} | |||
// we should only terminate failed threads | |||
assert (lua_status (thread_entry->lua_state) != 0 && lua_status (thread_entry->lua_state) != LUA_YIELD); | |||
thread_entry_free (pool->L, thread_entry); | |||
if (g_queue_get_length (pool->available_items) <= pool->max_items) { | |||
ent = thread_entry_new (pool->L); | |||
g_queue_push_head (pool->available_items, ent); | |||
} | |||
} | |||
struct thread_entry * | |||
lua_thread_pool_get_running_entry(struct lua_thread_pool *pool) | |||
{ | |||
return pool->running_entry; | |||
} | |||
void | |||
lua_thread_pool_set_running_entry(struct lua_thread_pool *pool, struct thread_entry *thread_entry) | |||
{ | |||
pool->running_entry = thread_entry; | |||
} |
@@ -0,0 +1,35 @@ | |||
#ifndef LUA_THREAD_POOL_H_ | |||
#define LUA_THREAD_POOL_H_ | |||
#include <lua.h> | |||
struct thread_entry { | |||
lua_State *lua_state; | |||
gint thread_index; | |||
}; | |||
struct thread_pool; | |||
struct lua_thread_pool * | |||
lua_thread_pool_new (lua_State * L); | |||
void | |||
lua_thread_pool_free (struct lua_thread_pool *pool); | |||
struct thread_entry * | |||
lua_thread_pool_get(struct lua_thread_pool *pool); | |||
void | |||
lua_thread_pool_return(struct lua_thread_pool *pool, struct thread_entry *thread_entry); | |||
void | |||
lua_thread_pool_terminate_entry(struct lua_thread_pool *pool, struct thread_entry *thread_entry); | |||
struct thread_entry * | |||
lua_thread_pool_get_running_entry(struct lua_thread_pool *pool); | |||
void | |||
lua_thread_pool_set_running_entry(struct lua_thread_pool *pool, struct thread_entry *thread_entry); | |||
#endif /* LUA_THREAD_POOL_H_ */ | |||