From 9926cc68e2c143be8a05b91e28ba8830abfea04a Mon Sep 17 00:00:00 2001 From: Mikhail Galanin Date: Tue, 14 Aug 2018 14:36:05 +0100 Subject: [PATCH] [Minor] Use coroutine model in dns resolver API --- src/lua/lua_common.h | 21 ++++++++++++ src/lua/lua_config.c | 28 ++++++++++++++- src/lua/lua_dns.c | 72 ++++++++++++++++++++++++++++----------- src/lua/lua_thread_pool.c | 5 ++- src/lua/lua_thread_pool.h | 1 + 5 files changed, 106 insertions(+), 21 deletions(-) diff --git a/src/lua/lua_common.h b/src/lua/lua_common.h index 45baa5549..149642597 100644 --- a/src/lua/lua_common.h +++ b/src/lua/lua_common.h @@ -408,6 +408,27 @@ void rspamd_lua_add_ref_dtor (lua_State *L, rspamd_mempool_t *pool, gboolean rspamd_lua_require_function (lua_State *L, const gchar *modname, const gchar *funcname); +struct thread_entry; +/** + * Yields thread. should be only called in return statement + * @param thread_entry + * @param nresults + * @return + */ +gint +lua_yield_thread (struct thread_entry *thread_entry, gint nresults); + +/** + * + * @param pool + * @param thread_entry + * @param narg + * @return + */ +void +lua_resume_thread (struct rspamd_task *task, struct thread_entry *thread_entry, gint narg); + + /* Paths defs */ #define RSPAMD_CONFDIR_INDEX "CONFDIR" #define RSPAMD_RUNDIR_INDEX "RUNDIR" diff --git a/src/lua/lua_config.c b/src/lua/lua_config.c index 3a38d437b..31deed5c5 100644 --- a/src/lua/lua_config.c +++ b/src/lua/lua_config.c @@ -1206,6 +1206,9 @@ lua_metric_symbol_callback (struct rspamd_task *task, gpointer ud) struct thread_entry *thread_entry = lua_thread_pool_get (task->cfg->lua_thread_pool); cd->thread_entry = thread_entry; + g_assert(thread_entry->cd == NULL); + thread_entry->cd = cd; + lua_State *thread = thread_entry->lua_state; cd->stack_level = lua_gettop (cd->L); @@ -1232,6 +1235,29 @@ lua_metric_symbol_callback (struct rspamd_task *task, gpointer ud) } } +gint +lua_yield_thread (struct thread_entry *thread_entry, gint nresults) +{ + g_assert (thread_entry->cd != NULL); + + return lua_yield (thread_entry->lua_state, nresults); +} + +void +lua_resume_thread (struct rspamd_task *task, struct thread_entry *thread_entry, gint narg) +{ + g_assert (thread_entry->cd != NULL); + + gint ret; + + lua_thread_pool_set_running_entry (task->cfg->lua_thread_pool, thread_entry); + ret = lua_resume (thread_entry->lua_state, narg); + + if (ret != LUA_YIELD) { + lua_metric_symbol_callback_return (task, thread_entry->cd, ret); + } +} + static void lua_metric_symbol_callback_return (struct rspamd_task *task, gpointer ud, gint ret) { @@ -1337,7 +1363,7 @@ lua_metric_symbol_callback_return (struct rspamd_task *task, gpointer ud, gint r g_assert (lua_gettop (thread) == cd->stack_level); /* we properly cleaned up the stack */ - lua_thread_pool_return(task->cfg->lua_thread_pool, cd->thread_entry); + lua_thread_pool_return (task->cfg->lua_thread_pool, cd->thread_entry); } cd->thread_entry = NULL; diff --git a/src/lua/lua_dns.c b/src/lua/lua_dns.c index f6ba88f2e..8baa3e615 100644 --- a/src/lua/lua_dns.c +++ b/src/lua/lua_dns.c @@ -14,6 +14,7 @@ * limitations under the License. */ #include "lua_common.h" +#include "lua_thread_pool.h" #include "utlist.h" @@ -78,6 +79,8 @@ lua_check_dns_resolver (lua_State * L) struct lua_dns_cbdata { lua_State *L; + struct thread_entry *thread; + struct rspamd_task *task; struct rspamd_dns_resolver *resolver; gint cbref; const gchar *to_resolve; @@ -140,12 +143,15 @@ lua_dns_callback (struct rdns_reply *reply, gpointer arg) struct rdns_reply_entry *elt; rspamd_inet_addr_t *addr; - lua_rawgeti (cd->L, LUA_REGISTRYINDEX, cd->cbref); - presolver = lua_newuserdata (cd->L, sizeof (gpointer)); - rspamd_lua_setclass (cd->L, "rspamd{resolver}", -1); + if (cd->cbref != -1) { + lua_rawgeti (cd->L, LUA_REGISTRYINDEX, cd->cbref); - *presolver = cd->resolver; - lua_pushstring (cd->L, cd->to_resolve); + presolver = lua_newuserdata (cd->L, sizeof (gpointer)); + rspamd_lua_setclass (cd->L, "rspamd{resolver}", -1); + + *presolver = cd->resolver; + lua_pushstring (cd->L, cd->to_resolve); + } /* * XXX: rework to handle different request types @@ -227,22 +233,41 @@ lua_dns_callback (struct rdns_reply *reply, gpointer arg) lua_pushstring (cd->L, rdns_strerror (reply->code)); } - if (cd->user_str != NULL) { - lua_pushstring (cd->L, cd->user_str); - } - else { - lua_pushnil (cd->L); + if (cd->cbref != -1) { + if (cd->user_str != NULL) { + lua_pushstring (cd->L, cd->user_str); + } + else { + lua_pushnil (cd->L); + } } lua_pushboolean (cd->L, reply->authenticated); - if (lua_pcall (cd->L, 6, 0, 0) != 0) { - msg_info ("call to dns callback failed: %s", lua_tostring (cd->L, -1)); - lua_pop (cd->L, 1); - } + if (cd->cbref != -1) { + /* + * 1 - resolver + * 2 - to_resolve + * 3 - entries | nil + * 4 - error | nil + * 5 - user_str + * 6 - reply->authenticated + */ + if (lua_pcall (cd->L, 6, 0, 0) != 0) { + msg_info ("call to dns callback failed: %s", lua_tostring (cd->L, -1)); + lua_pop (cd->L, 1); + } - /* Unref function */ - luaL_unref (cd->L, LUA_REGISTRYINDEX, cd->cbref); + /* Unref function */ + luaL_unref (cd->L, LUA_REGISTRYINDEX, cd->cbref); + } else { + /* + * 1 - entries + * 2 - error | nil + * 3 - reply->authenticated + */ + lua_resume_thread (cd->task, cd->thread, 3); + } if (cd->s) { rspamd_session_watcher_pop (cd->s, cd->w); @@ -306,7 +331,7 @@ lua_dns_resolver_resolve_common (lua_State *L, /* Check arguments */ if (!rspamd_lua_parse_table_arguments (L, first, &err, - "session=U{session};mempool=U{mempool};*name=S;*callback=F;" + "session=U{session};mempool=U{mempool};*name=S;callback=F;" "option=S;task=U{task};forced=B", &session, &pool, &to_resolve, &cbref, &user_str, &task, &forced)) { @@ -325,7 +350,7 @@ lua_dns_resolver_resolve_common (lua_State *L, session = task->s; } - if (pool != NULL && to_resolve != NULL && cbref != -1) { + if (pool != NULL && to_resolve != NULL) { cbdata = rspamd_mempool_alloc0 (pool, sizeof (struct lua_dns_cbdata)); cbdata->L = L; cbdata->resolver = resolver; @@ -375,6 +400,9 @@ lua_dns_resolver_resolve_common (lua_State *L, } } else { + cbdata->thread = lua_thread_pool_get_running_entry (task->cfg->lua_thread_pool); + cbdata->task = task; + if (forced) { ret = make_dns_request_task_forced (task, lua_dns_callback, @@ -391,10 +419,16 @@ lua_dns_resolver_resolve_common (lua_State *L, } if (ret) { - lua_pushboolean (L, TRUE); cbdata->s = session; cbdata->w = rspamd_session_get_watcher (session); rspamd_session_watcher_push (session); + if (cbdata->cbref != -1) { + /* callback was set up */ + lua_pushboolean (L, TRUE); + } else { + /* this is coroutine-based call */ + return lua_yield_thread (cbdata->thread, 0); + } } else { lua_pushnil (L); diff --git a/src/lua/lua_thread_pool.c b/src/lua/lua_thread_pool.c index 07364f270..865b41fce 100644 --- a/src/lua/lua_thread_pool.c +++ b/src/lua/lua_thread_pool.c @@ -14,7 +14,7 @@ static struct thread_entry * thread_entry_new (lua_State * L) { struct thread_entry *ent; - ent = g_malloc (sizeof *ent); + ent = g_new0(struct thread_entry, 1); ent->lua_state = lua_newthread (L); ent->thread_index = luaL_ref (L, LUA_REGISTRYINDEX); @@ -75,6 +75,8 @@ lua_thread_pool_get(struct lua_thread_pool *pool) ent = thread_entry_new (pool->L); } + pool->running_entry = ent; + return ent; } @@ -88,6 +90,7 @@ lua_thread_pool_return(struct lua_thread_pool *pool, struct thread_entry *thread } if (g_queue_get_length (pool->available_items) <= pool->max_items) { + thread_entry->cd = NULL; g_queue_push_head (pool->available_items, thread_entry); } else { diff --git a/src/lua/lua_thread_pool.h b/src/lua/lua_thread_pool.h index 33a0da879..64708c488 100644 --- a/src/lua/lua_thread_pool.h +++ b/src/lua/lua_thread_pool.h @@ -6,6 +6,7 @@ struct thread_entry { lua_State *lua_state; gint thread_index; + gpointer cd; }; struct thread_pool; -- 2.39.5