aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorMikhail Galanin <mgalanin@mimecast.com>2018-08-14 14:36:05 +0100
committerMikhail Galanin <mgalanin@mimecast.com>2018-08-14 14:36:05 +0100
commit9926cc68e2c143be8a05b91e28ba8830abfea04a (patch)
treeee8997d9488c9dd67a5a77789377cbd517275d56 /src
parentf9d4b50321057009489dbc673b108e6433f4ae38 (diff)
downloadrspamd-9926cc68e2c143be8a05b91e28ba8830abfea04a.tar.gz
rspamd-9926cc68e2c143be8a05b91e28ba8830abfea04a.zip
[Minor] Use coroutine model in dns resolver API
Diffstat (limited to 'src')
-rw-r--r--src/lua/lua_common.h21
-rw-r--r--src/lua/lua_config.c28
-rw-r--r--src/lua/lua_dns.c72
-rw-r--r--src/lua/lua_thread_pool.c5
-rw-r--r--src/lua/lua_thread_pool.h1
5 files changed, 106 insertions, 21 deletions
diff --git a/src/lua/lua_common.h b/src/lua/lua_common.h
index 45baa5549..149642597 100644
--- a/src/lua/lua_common.h
+++ b/src/lua/lua_common.h
@@ -408,6 +408,27 @@ void rspamd_lua_add_ref_dtor (lua_State *L, rspamd_mempool_t *pool,
gboolean rspamd_lua_require_function (lua_State *L, const gchar *modname,
const gchar *funcname);
+struct thread_entry;
+/**
+ * Yields thread. should be only called in return statement
+ * @param thread_entry
+ * @param nresults
+ * @return
+ */
+gint
+lua_yield_thread (struct thread_entry *thread_entry, gint nresults);
+
+/**
+ *
+ * @param pool
+ * @param thread_entry
+ * @param narg
+ * @return
+ */
+void
+lua_resume_thread (struct rspamd_task *task, struct thread_entry *thread_entry, gint narg);
+
+
/* Paths defs */
#define RSPAMD_CONFDIR_INDEX "CONFDIR"
#define RSPAMD_RUNDIR_INDEX "RUNDIR"
diff --git a/src/lua/lua_config.c b/src/lua/lua_config.c
index 3a38d437b..31deed5c5 100644
--- a/src/lua/lua_config.c
+++ b/src/lua/lua_config.c
@@ -1206,6 +1206,9 @@ lua_metric_symbol_callback (struct rspamd_task *task, gpointer ud)
struct thread_entry *thread_entry = lua_thread_pool_get (task->cfg->lua_thread_pool);
cd->thread_entry = thread_entry;
+ g_assert(thread_entry->cd == NULL);
+ thread_entry->cd = cd;
+
lua_State *thread = thread_entry->lua_state;
cd->stack_level = lua_gettop (cd->L);
@@ -1232,6 +1235,29 @@ lua_metric_symbol_callback (struct rspamd_task *task, gpointer ud)
}
}
+gint
+lua_yield_thread (struct thread_entry *thread_entry, gint nresults)
+{
+ g_assert (thread_entry->cd != NULL);
+
+ return lua_yield (thread_entry->lua_state, nresults);
+}
+
+void
+lua_resume_thread (struct rspamd_task *task, struct thread_entry *thread_entry, gint narg)
+{
+ g_assert (thread_entry->cd != NULL);
+
+ gint ret;
+
+ lua_thread_pool_set_running_entry (task->cfg->lua_thread_pool, thread_entry);
+ ret = lua_resume (thread_entry->lua_state, narg);
+
+ if (ret != LUA_YIELD) {
+ lua_metric_symbol_callback_return (task, thread_entry->cd, ret);
+ }
+}
+
static void
lua_metric_symbol_callback_return (struct rspamd_task *task, gpointer ud, gint ret)
{
@@ -1337,7 +1363,7 @@ lua_metric_symbol_callback_return (struct rspamd_task *task, gpointer ud, gint r
g_assert (lua_gettop (thread) == cd->stack_level); /* we properly cleaned up the stack */
- lua_thread_pool_return(task->cfg->lua_thread_pool, cd->thread_entry);
+ lua_thread_pool_return (task->cfg->lua_thread_pool, cd->thread_entry);
}
cd->thread_entry = NULL;
diff --git a/src/lua/lua_dns.c b/src/lua/lua_dns.c
index f6ba88f2e..8baa3e615 100644
--- a/src/lua/lua_dns.c
+++ b/src/lua/lua_dns.c
@@ -14,6 +14,7 @@
* limitations under the License.
*/
#include "lua_common.h"
+#include "lua_thread_pool.h"
#include "utlist.h"
@@ -78,6 +79,8 @@ lua_check_dns_resolver (lua_State * L)
struct lua_dns_cbdata {
lua_State *L;
+ struct thread_entry *thread;
+ struct rspamd_task *task;
struct rspamd_dns_resolver *resolver;
gint cbref;
const gchar *to_resolve;
@@ -140,12 +143,15 @@ lua_dns_callback (struct rdns_reply *reply, gpointer arg)
struct rdns_reply_entry *elt;
rspamd_inet_addr_t *addr;
- lua_rawgeti (cd->L, LUA_REGISTRYINDEX, cd->cbref);
- presolver = lua_newuserdata (cd->L, sizeof (gpointer));
- rspamd_lua_setclass (cd->L, "rspamd{resolver}", -1);
+ if (cd->cbref != -1) {
+ lua_rawgeti (cd->L, LUA_REGISTRYINDEX, cd->cbref);
- *presolver = cd->resolver;
- lua_pushstring (cd->L, cd->to_resolve);
+ presolver = lua_newuserdata (cd->L, sizeof (gpointer));
+ rspamd_lua_setclass (cd->L, "rspamd{resolver}", -1);
+
+ *presolver = cd->resolver;
+ lua_pushstring (cd->L, cd->to_resolve);
+ }
/*
* XXX: rework to handle different request types
@@ -227,22 +233,41 @@ lua_dns_callback (struct rdns_reply *reply, gpointer arg)
lua_pushstring (cd->L, rdns_strerror (reply->code));
}
- if (cd->user_str != NULL) {
- lua_pushstring (cd->L, cd->user_str);
- }
- else {
- lua_pushnil (cd->L);
+ if (cd->cbref != -1) {
+ if (cd->user_str != NULL) {
+ lua_pushstring (cd->L, cd->user_str);
+ }
+ else {
+ lua_pushnil (cd->L);
+ }
}
lua_pushboolean (cd->L, reply->authenticated);
- if (lua_pcall (cd->L, 6, 0, 0) != 0) {
- msg_info ("call to dns callback failed: %s", lua_tostring (cd->L, -1));
- lua_pop (cd->L, 1);
- }
+ if (cd->cbref != -1) {
+ /*
+ * 1 - resolver
+ * 2 - to_resolve
+ * 3 - entries | nil
+ * 4 - error | nil
+ * 5 - user_str
+ * 6 - reply->authenticated
+ */
+ if (lua_pcall (cd->L, 6, 0, 0) != 0) {
+ msg_info ("call to dns callback failed: %s", lua_tostring (cd->L, -1));
+ lua_pop (cd->L, 1);
+ }
- /* Unref function */
- luaL_unref (cd->L, LUA_REGISTRYINDEX, cd->cbref);
+ /* Unref function */
+ luaL_unref (cd->L, LUA_REGISTRYINDEX, cd->cbref);
+ } else {
+ /*
+ * 1 - entries
+ * 2 - error | nil
+ * 3 - reply->authenticated
+ */
+ lua_resume_thread (cd->task, cd->thread, 3);
+ }
if (cd->s) {
rspamd_session_watcher_pop (cd->s, cd->w);
@@ -306,7 +331,7 @@ lua_dns_resolver_resolve_common (lua_State *L,
/* Check arguments */
if (!rspamd_lua_parse_table_arguments (L, first, &err,
- "session=U{session};mempool=U{mempool};*name=S;*callback=F;"
+ "session=U{session};mempool=U{mempool};*name=S;callback=F;"
"option=S;task=U{task};forced=B",
&session, &pool, &to_resolve, &cbref, &user_str, &task, &forced)) {
@@ -325,7 +350,7 @@ lua_dns_resolver_resolve_common (lua_State *L,
session = task->s;
}
- if (pool != NULL && to_resolve != NULL && cbref != -1) {
+ if (pool != NULL && to_resolve != NULL) {
cbdata = rspamd_mempool_alloc0 (pool, sizeof (struct lua_dns_cbdata));
cbdata->L = L;
cbdata->resolver = resolver;
@@ -375,6 +400,9 @@ lua_dns_resolver_resolve_common (lua_State *L,
}
}
else {
+ cbdata->thread = lua_thread_pool_get_running_entry (task->cfg->lua_thread_pool);
+ cbdata->task = task;
+
if (forced) {
ret = make_dns_request_task_forced (task,
lua_dns_callback,
@@ -391,10 +419,16 @@ lua_dns_resolver_resolve_common (lua_State *L,
}
if (ret) {
- lua_pushboolean (L, TRUE);
cbdata->s = session;
cbdata->w = rspamd_session_get_watcher (session);
rspamd_session_watcher_push (session);
+ if (cbdata->cbref != -1) {
+ /* callback was set up */
+ lua_pushboolean (L, TRUE);
+ } else {
+ /* this is coroutine-based call */
+ return lua_yield_thread (cbdata->thread, 0);
+ }
}
else {
lua_pushnil (L);
diff --git a/src/lua/lua_thread_pool.c b/src/lua/lua_thread_pool.c
index 07364f270..865b41fce 100644
--- a/src/lua/lua_thread_pool.c
+++ b/src/lua/lua_thread_pool.c
@@ -14,7 +14,7 @@ static struct thread_entry *
thread_entry_new (lua_State * L)
{
struct thread_entry *ent;
- ent = g_malloc (sizeof *ent);
+ ent = g_new0(struct thread_entry, 1);
ent->lua_state = lua_newthread (L);
ent->thread_index = luaL_ref (L, LUA_REGISTRYINDEX);
@@ -75,6 +75,8 @@ lua_thread_pool_get(struct lua_thread_pool *pool)
ent = thread_entry_new (pool->L);
}
+ pool->running_entry = ent;
+
return ent;
}
@@ -88,6 +90,7 @@ lua_thread_pool_return(struct lua_thread_pool *pool, struct thread_entry *thread
}
if (g_queue_get_length (pool->available_items) <= pool->max_items) {
+ thread_entry->cd = NULL;
g_queue_push_head (pool->available_items, thread_entry);
}
else {
diff --git a/src/lua/lua_thread_pool.h b/src/lua/lua_thread_pool.h
index 33a0da879..64708c488 100644
--- a/src/lua/lua_thread_pool.h
+++ b/src/lua/lua_thread_pool.h
@@ -6,6 +6,7 @@
struct thread_entry {
lua_State *lua_state;
gint thread_index;
+ gpointer cd;
};
struct thread_pool;