]> source.dussan.org Git - rspamd.git/commitdiff
[Minor] Use coroutine model in dns resolver API
authorMikhail Galanin <mgalanin@mimecast.com>
Tue, 14 Aug 2018 13:36:05 +0000 (14:36 +0100)
committerMikhail Galanin <mgalanin@mimecast.com>
Tue, 14 Aug 2018 13:36:05 +0000 (14:36 +0100)
src/lua/lua_common.h
src/lua/lua_config.c
src/lua/lua_dns.c
src/lua/lua_thread_pool.c
src/lua/lua_thread_pool.h

index 45baa55491d5525423ddbcbcfbb17204c9f86a4a..149642597f46f28c7dbac887964afd0c8afb3f53 100644 (file)
@@ -408,6 +408,27 @@ void rspamd_lua_add_ref_dtor (lua_State *L, rspamd_mempool_t *pool,
 gboolean rspamd_lua_require_function (lua_State *L, const gchar *modname,
                const gchar *funcname);
 
+struct thread_entry;
+/**
+ * Yields thread. should be only called in return statement
+ * @param thread_entry
+ * @param nresults
+ * @return
+ */
+gint
+lua_yield_thread (struct thread_entry *thread_entry, gint nresults);
+
+/**
+ *
+ * @param pool
+ * @param thread_entry
+ * @param narg
+ * @return
+ */
+void
+lua_resume_thread (struct rspamd_task *task, struct thread_entry *thread_entry, gint narg);
+
+
 /* Paths defs */
 #define RSPAMD_CONFDIR_INDEX "CONFDIR"
 #define RSPAMD_RUNDIR_INDEX "RUNDIR"
index 3a38d437b12d907158175121d402d51b65aa6aac..31deed5c52f8b8404852a48b62bd39db0011b4e9 100644 (file)
@@ -1206,6 +1206,9 @@ lua_metric_symbol_callback (struct rspamd_task *task, gpointer ud)
        struct thread_entry *thread_entry = lua_thread_pool_get (task->cfg->lua_thread_pool);
        cd->thread_entry = thread_entry;
 
+       g_assert(thread_entry->cd == NULL);
+       thread_entry->cd = cd;
+
        lua_State *thread = thread_entry->lua_state;
        cd->stack_level = lua_gettop (cd->L);
 
@@ -1232,6 +1235,29 @@ lua_metric_symbol_callback (struct rspamd_task *task, gpointer ud)
        }
 }
 
+gint
+lua_yield_thread (struct thread_entry *thread_entry, gint nresults)
+{
+    g_assert (thread_entry->cd != NULL);
+
+       return lua_yield (thread_entry->lua_state, nresults);
+}
+
+void
+lua_resume_thread (struct rspamd_task *task, struct thread_entry *thread_entry, gint narg)
+{
+    g_assert (thread_entry->cd != NULL);
+
+       gint ret;
+
+       lua_thread_pool_set_running_entry (task->cfg->lua_thread_pool, thread_entry);
+       ret = lua_resume (thread_entry->lua_state, narg);
+
+       if (ret != LUA_YIELD) {
+               lua_metric_symbol_callback_return (task, thread_entry->cd, ret);
+       }
+}
+
 static void
 lua_metric_symbol_callback_return (struct rspamd_task *task, gpointer ud, gint ret)
 {
@@ -1337,7 +1363,7 @@ lua_metric_symbol_callback_return (struct rspamd_task *task, gpointer ud, gint r
 
                g_assert (lua_gettop (thread) == cd->stack_level); /* we properly cleaned up the stack */
 
-               lua_thread_pool_return(task->cfg->lua_thread_pool, cd->thread_entry);
+               lua_thread_pool_return (task->cfg->lua_thread_pool, cd->thread_entry);
        }
 
        cd->thread_entry = NULL;
index f6ba88f2ef033e3816595032180438406e610876..8baa3e615309cb5f0bd6599e8a5d9cac454a1106 100644 (file)
@@ -14,6 +14,7 @@
  * limitations under the License.
  */
 #include "lua_common.h"
+#include "lua_thread_pool.h"
 #include "utlist.h"
 
 
@@ -78,6 +79,8 @@ lua_check_dns_resolver (lua_State * L)
 
 struct lua_dns_cbdata {
        lua_State *L;
+       struct thread_entry *thread;
+       struct rspamd_task *task;
        struct rspamd_dns_resolver *resolver;
        gint cbref;
        const gchar *to_resolve;
@@ -140,12 +143,15 @@ lua_dns_callback (struct rdns_reply *reply, gpointer arg)
        struct rdns_reply_entry *elt;
        rspamd_inet_addr_t *addr;
 
-       lua_rawgeti (cd->L, LUA_REGISTRYINDEX, cd->cbref);
-       presolver = lua_newuserdata (cd->L, sizeof (gpointer));
-       rspamd_lua_setclass (cd->L, "rspamd{resolver}", -1);
+       if (cd->cbref != -1) {
+               lua_rawgeti (cd->L, LUA_REGISTRYINDEX, cd->cbref);
 
-       *presolver = cd->resolver;
-       lua_pushstring (cd->L, cd->to_resolve);
+               presolver = lua_newuserdata (cd->L, sizeof (gpointer));
+               rspamd_lua_setclass (cd->L, "rspamd{resolver}", -1);
+
+               *presolver = cd->resolver;
+               lua_pushstring (cd->L, cd->to_resolve);
+       }
 
        /*
         * XXX: rework to handle different request types
@@ -227,22 +233,41 @@ lua_dns_callback (struct rdns_reply *reply, gpointer arg)
                lua_pushstring (cd->L, rdns_strerror (reply->code));
        }
 
-       if (cd->user_str != NULL) {
-               lua_pushstring (cd->L, cd->user_str);
-       }
-       else {
-               lua_pushnil (cd->L);
+       if (cd->cbref != -1) {
+               if (cd->user_str != NULL) {
+                       lua_pushstring (cd->L, cd->user_str);
+               }
+               else {
+                       lua_pushnil (cd->L);
+               }
        }
 
        lua_pushboolean (cd->L, reply->authenticated);
 
-       if (lua_pcall (cd->L, 6, 0, 0) != 0) {
-               msg_info ("call to dns callback failed: %s", lua_tostring (cd->L, -1));
-               lua_pop (cd->L, 1);
-       }
+       if (cd->cbref != -1) {
+               /*
+                * 1 - resolver
+                * 2 - to_resolve
+                * 3 - entries | nil
+                * 4 - error | nil
+                * 5 - user_str
+                * 6 - reply->authenticated
+                */
+               if (lua_pcall (cd->L, 6, 0, 0) != 0) {
+                       msg_info ("call to dns callback failed: %s", lua_tostring (cd->L, -1));
+                       lua_pop (cd->L, 1);
+               }
 
-       /* Unref function */
-       luaL_unref (cd->L, LUA_REGISTRYINDEX, cd->cbref);
+               /* Unref function */
+               luaL_unref (cd->L, LUA_REGISTRYINDEX, cd->cbref);
+       } else {
+               /*
+                * 1 - entries
+                * 2 - error | nil
+                * 3 - reply->authenticated
+                */
+               lua_resume_thread (cd->task, cd->thread, 3);
+       }
 
        if (cd->s) {
                rspamd_session_watcher_pop (cd->s, cd->w);
@@ -306,7 +331,7 @@ lua_dns_resolver_resolve_common (lua_State *L,
 
        /* Check arguments */
        if (!rspamd_lua_parse_table_arguments (L, first, &err,
-                       "session=U{session};mempool=U{mempool};*name=S;*callback=F;"
+                       "session=U{session};mempool=U{mempool};*name=S;callback=F;"
                        "option=S;task=U{task};forced=B",
                        &session, &pool, &to_resolve, &cbref, &user_str, &task, &forced)) {
 
@@ -325,7 +350,7 @@ lua_dns_resolver_resolve_common (lua_State *L,
                session = task->s;
        }
 
-       if (pool != NULL && to_resolve != NULL && cbref != -1) {
+       if (pool != NULL && to_resolve != NULL) {
                cbdata = rspamd_mempool_alloc0 (pool, sizeof (struct lua_dns_cbdata));
                cbdata->L = L;
                cbdata->resolver = resolver;
@@ -375,6 +400,9 @@ lua_dns_resolver_resolve_common (lua_State *L,
                        }
                }
                else {
+                       cbdata->thread = lua_thread_pool_get_running_entry (task->cfg->lua_thread_pool);
+                       cbdata->task = task;
+
                        if (forced) {
                                ret = make_dns_request_task_forced (task,
                                                lua_dns_callback,
@@ -391,10 +419,16 @@ lua_dns_resolver_resolve_common (lua_State *L,
                        }
 
                        if (ret) {
-                               lua_pushboolean (L, TRUE);
                                cbdata->s = session;
                                cbdata->w = rspamd_session_get_watcher (session);
                                rspamd_session_watcher_push (session);
+                               if (cbdata->cbref != -1) {
+                                       /* callback was set up */
+                                       lua_pushboolean (L, TRUE);
+                               } else {
+                                       /* this is coroutine-based call */
+                                       return lua_yield_thread (cbdata->thread, 0);
+                               }
                        }
                        else {
                                lua_pushnil (L);
index 07364f27010465f3b45d78c9468374bc697df0f4..865b41fce8492ec9ff7432157a8dae1771cab1e3 100644 (file)
@@ -14,7 +14,7 @@ static struct thread_entry *
 thread_entry_new (lua_State * L)
 {
        struct thread_entry *ent;
-       ent = g_malloc (sizeof *ent);
+       ent = g_new0(struct thread_entry, 1);
        ent->lua_state = lua_newthread (L);
        ent->thread_index = luaL_ref (L, LUA_REGISTRYINDEX);
 
@@ -75,6 +75,8 @@ lua_thread_pool_get(struct lua_thread_pool *pool)
                ent = thread_entry_new (pool->L);
        }
 
+       pool->running_entry = ent;
+
        return ent;
 }
 
@@ -88,6 +90,7 @@ lua_thread_pool_return(struct lua_thread_pool *pool, struct thread_entry *thread
        }
 
        if (g_queue_get_length (pool->available_items) <= pool->max_items) {
+               thread_entry->cd = NULL;
                g_queue_push_head (pool->available_items, thread_entry);
        }
        else {
index 33a0da879b5f3e4be7947d9dd7f02415d16115f4..64708c4884dff3d1fd0d52bed8af191d58383fe9 100644 (file)
@@ -6,6 +6,7 @@
 struct thread_entry {
        lua_State *lua_state;
        gint thread_index;
+       gpointer cd;
 };
 
 struct thread_pool;