/*- * Copyright 2018 Mikhail Galanin * Copyright 2019 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "config.h" #include "lua_common.h" #include "lua_thread_pool.h" #include #define msg_debug_lua_threads(...) rspamd_conditional_debug_fast(NULL, NULL, \ rspamd_lua_threads_log_id, "lua_threads", NULL, \ RSPAMD_LOG_FUNC, \ __VA_ARGS__) INIT_LOG_MODULE(lua_threads) static struct thread_entry *thread_entry_new(lua_State *L); static void thread_entry_free(lua_State *L, struct thread_entry *ent); #define CFG_POOL_GET(cfg) (reinterpret_cast((cfg)->lua_thread_pool)) struct lua_thread_pool { std::vector available_items; lua_State *L; int max_items; struct thread_entry *running_entry; static const int default_max_items = 100; lua_thread_pool(lua_State *L, int max_items = default_max_items) : L(L), max_items(max_items) { running_entry = nullptr; available_items.reserve(max_items); for (auto i = 0; i < MAX(2, max_items / 10); i++) { auto *ent = thread_entry_new(L); available_items.push_back(ent); } } ~lua_thread_pool() { for (auto *ent: available_items) { thread_entry_free(L, ent); } } auto get_thread() -> struct thread_entry * { struct thread_entry *ent; if (!available_items.empty()) { ent = available_items.back(); available_items.pop_back(); } else { ent = thread_entry_new(L); } running_entry = ent; return ent; } auto return_thread(struct thread_entry *thread_entry, const char *loc) -> void { /* we can't return a running/yielded thread into the pool */ g_assert(lua_status(thread_entry->lua_state) == 0); if (running_entry == thread_entry) { running_entry = NULL; } if (available_items.size() <= max_items) { thread_entry->cd = NULL; thread_entry->finish_callback = NULL; thread_entry->error_callback = NULL; thread_entry->task = NULL; thread_entry->cfg = NULL; msg_debug_lua_threads("%s: returned thread to the threads pool %ud items", loc, available_items.size()); available_items.push_back(thread_entry); } else { msg_debug_lua_threads("%s: removed thread as thread pool has %ud items", loc, available_items.size()); thread_entry_free(L, thread_entry); } } auto terminate_thread(struct thread_entry *thread_entry, const char *loc, bool enforce) -> void { struct thread_entry *ent = NULL; if (!enforce) { /* we should only terminate failed threads */ g_assert(lua_status(thread_entry->lua_state) != 0 && lua_status(thread_entry->lua_state) != LUA_YIELD); } if (running_entry == thread_entry) { running_entry = NULL; } msg_debug_lua_threads("%s: terminated thread entry", loc); thread_entry_free(L, thread_entry); if (available_items.size() <= max_items) { ent = thread_entry_new(L); available_items.push_back(ent); } } auto get_running_entry(void) -> struct thread_entry * { return running_entry; }; auto set_running_entry(struct thread_entry *ent) -> struct thread_entry * { auto *old_entry = running_entry; running_entry = ent; return old_entry; }; }; static struct thread_entry * thread_entry_new(lua_State *L) { struct thread_entry *ent; ent = g_new0(struct thread_entry, 1); ent->lua_state = lua_newthread(L); ent->thread_index = luaL_ref(L, LUA_REGISTRYINDEX); return ent; } static void thread_entry_free(lua_State *L, struct thread_entry *ent) { luaL_unref(L, LUA_REGISTRYINDEX, ent->thread_index); g_free(ent); } struct lua_thread_pool * lua_thread_pool_new(lua_State *L) { auto *pool = new lua_thread_pool(L); return pool; } void lua_thread_pool_free(struct lua_thread_pool *pool) { delete pool; } struct thread_entry * lua_thread_pool_get_for_task(struct rspamd_task *task) { struct thread_entry *ent = CFG_POOL_GET(task->cfg)->get_thread(); ent->task = task; return ent; } struct thread_entry * lua_thread_pool_get_for_config(struct rspamd_config *cfg) { struct thread_entry *ent = CFG_POOL_GET(cfg)->get_thread(); ent->cfg = cfg; return ent; } void lua_thread_pool_return_full(struct lua_thread_pool *pool, struct thread_entry *thread_entry, const char *loc) { pool->return_thread(thread_entry, loc); } void lua_thread_pool_terminate_entry_full(struct lua_thread_pool *pool, struct thread_entry *thread_entry, const char *loc, bool enforce) { pool->terminate_thread(thread_entry, loc, enforce); } struct thread_entry * lua_thread_pool_get_running_entry_full(struct lua_thread_pool *pool, const char *loc) { msg_debug_lua_threads("%s: lua_thread_pool_get_running_entry_full", loc); return pool->get_running_entry(); } void lua_thread_pool_set_running_entry_full(struct lua_thread_pool *pool, struct thread_entry *thread_entry, const char *loc) { msg_debug_lua_threads("%s: lua_thread_pool_set_running_entry_full", loc); pool->set_running_entry(thread_entry); } static void lua_thread_pool_set_running_entry_for_thread(struct thread_entry *thread_entry, const char *loc) { struct lua_thread_pool *pool; if (thread_entry->task) { pool = CFG_POOL_GET(thread_entry->task->cfg); } else { pool = CFG_POOL_GET(thread_entry->cfg); } lua_thread_pool_set_running_entry_full(pool, thread_entry, loc); } void lua_thread_pool_prepare_callback_full(struct lua_thread_pool *pool, struct lua_callback_state *cbs, const char *loc) { msg_debug_lua_threads("%s: lua_thread_pool_prepare_callback_full", loc); cbs->thread_pool = pool; cbs->previous_thread = lua_thread_pool_get_running_entry_full(pool, loc); cbs->my_thread = pool->get_thread(); cbs->L = cbs->my_thread->lua_state; } void lua_thread_pool_restore_callback_full(struct lua_callback_state *cbs, const char *loc) { lua_thread_pool_return_full(cbs->thread_pool, cbs->my_thread, loc); lua_thread_pool_set_running_entry_full(cbs->thread_pool, cbs->previous_thread, loc); } static int lua_do_resume_full(lua_State *L, int narg, const char *loc) { #if LUA_VERSION_NUM >= 504 int nres; #endif msg_debug_lua_threads("%s: lua_do_resume_full", loc); #if LUA_VERSION_NUM < 502 return lua_resume(L, narg); #else #if LUA_VERSION_NUM >= 504 return lua_resume(L, NULL, narg, &nres); #else return lua_resume(L, NULL, narg); #endif #endif } static void lua_resume_thread_internal_full(struct thread_entry *thread_entry, int narg, const char *loc) { int ret; struct lua_thread_pool *pool; struct rspamd_task *task; msg_debug_lua_threads("%s: lua_resume_thread_internal_full", loc); ret = lua_do_resume_full(thread_entry->lua_state, narg, loc); if (ret != LUA_YIELD) { /* LUA_YIELD state should not be handled here. It should only happen when the thread initiated a asynchronous event and it will be restored as soon the event is finished */ if (thread_entry->task) { pool = CFG_POOL_GET(thread_entry->task->cfg); } else { pool = CFG_POOL_GET(thread_entry->cfg); } if (ret == 0) { if (thread_entry->finish_callback) { thread_entry->finish_callback(thread_entry, ret); } pool->return_thread(thread_entry, loc); } else { rspamd_lua_traceback(thread_entry->lua_state); if (thread_entry->error_callback) { thread_entry->error_callback(thread_entry, ret, lua_tostring(thread_entry->lua_state, -1)); } else if (thread_entry->task) { task = thread_entry->task; msg_err_task("lua call failed (%d): %s", ret, lua_tostring(thread_entry->lua_state, -1)); } else { msg_err("lua call failed (%d): %s", ret, lua_tostring(thread_entry->lua_state, -1)); } /* * Maybe there is a way to recover here. * For now, just remove faulty thread */ pool->terminate_thread(thread_entry, loc, false); } } } void lua_thread_resume_full(struct thread_entry *thread_entry, int narg, const char *loc) { /* * The only state where we can resume from is LUA_YIELD * Another acceptable status is OK (0) but in that case we should push function on stack * to start the thread from, which is happening in lua_thread_call(), not in this function. */ g_assert(lua_status(thread_entry->lua_state) == LUA_YIELD); msg_debug_lua_threads("%s: lua_thread_resume_full", loc); lua_thread_pool_set_running_entry_for_thread(thread_entry, loc); lua_resume_thread_internal_full(thread_entry, narg, loc); } void lua_thread_call_full(struct thread_entry *thread_entry, int narg, const char *loc) { g_assert(lua_status(thread_entry->lua_state) == 0); /* we can't call running/yielded thread */ g_assert(thread_entry->task != NULL || thread_entry->cfg != NULL); /* we can't call without pool */ lua_resume_thread_internal_full(thread_entry, narg, loc); } int lua_thread_yield_full(struct thread_entry *thread_entry, int nresults, const char *loc) { g_assert(lua_status(thread_entry->lua_state) == 0); msg_debug_lua_threads("%s: lua_thread_yield_full", loc); return lua_yield(thread_entry->lua_state, nresults); }