Pārlūkot izejas kodu

[Minor] Use coroutine model in dns resolver API

tags/1.8.0
Mikhail Galanin pirms 5 gadiem
vecāks
revīzija
9926cc68e2
5 mainītis faili ar 106 papildinājumiem un 21 dzēšanām
  1. 21
    0
      src/lua/lua_common.h
  2. 27
    1
      src/lua/lua_config.c
  3. 53
    19
      src/lua/lua_dns.c
  4. 4
    1
      src/lua/lua_thread_pool.c
  5. 1
    0
      src/lua/lua_thread_pool.h

+ 21
- 0
src/lua/lua_common.h Parādīt failu

@@ -408,6 +408,27 @@ void rspamd_lua_add_ref_dtor (lua_State *L, rspamd_mempool_t *pool,
gboolean rspamd_lua_require_function (lua_State *L, const gchar *modname,
const gchar *funcname);

struct thread_entry;
/**
* Yields thread. should be only called in return statement
* @param thread_entry
* @param nresults
* @return
*/
gint
lua_yield_thread (struct thread_entry *thread_entry, gint nresults);

/**
*
* @param pool
* @param thread_entry
* @param narg
* @return
*/
void
lua_resume_thread (struct rspamd_task *task, struct thread_entry *thread_entry, gint narg);


/* Paths defs */
#define RSPAMD_CONFDIR_INDEX "CONFDIR"
#define RSPAMD_RUNDIR_INDEX "RUNDIR"

+ 27
- 1
src/lua/lua_config.c Parādīt failu

@@ -1206,6 +1206,9 @@ lua_metric_symbol_callback (struct rspamd_task *task, gpointer ud)
struct thread_entry *thread_entry = lua_thread_pool_get (task->cfg->lua_thread_pool);
cd->thread_entry = thread_entry;

g_assert(thread_entry->cd == NULL);
thread_entry->cd = cd;

lua_State *thread = thread_entry->lua_state;
cd->stack_level = lua_gettop (cd->L);

@@ -1232,6 +1235,29 @@ lua_metric_symbol_callback (struct rspamd_task *task, gpointer ud)
}
}

gint
lua_yield_thread (struct thread_entry *thread_entry, gint nresults)
{
g_assert (thread_entry->cd != NULL);

return lua_yield (thread_entry->lua_state, nresults);
}

void
lua_resume_thread (struct rspamd_task *task, struct thread_entry *thread_entry, gint narg)
{
g_assert (thread_entry->cd != NULL);

gint ret;

lua_thread_pool_set_running_entry (task->cfg->lua_thread_pool, thread_entry);
ret = lua_resume (thread_entry->lua_state, narg);

if (ret != LUA_YIELD) {
lua_metric_symbol_callback_return (task, thread_entry->cd, ret);
}
}

static void
lua_metric_symbol_callback_return (struct rspamd_task *task, gpointer ud, gint ret)
{
@@ -1337,7 +1363,7 @@ lua_metric_symbol_callback_return (struct rspamd_task *task, gpointer ud, gint r

g_assert (lua_gettop (thread) == cd->stack_level); /* we properly cleaned up the stack */

lua_thread_pool_return(task->cfg->lua_thread_pool, cd->thread_entry);
lua_thread_pool_return (task->cfg->lua_thread_pool, cd->thread_entry);
}

cd->thread_entry = NULL;

+ 53
- 19
src/lua/lua_dns.c Parādīt failu

@@ -14,6 +14,7 @@
* limitations under the License.
*/
#include "lua_common.h"
#include "lua_thread_pool.h"
#include "utlist.h"


@@ -78,6 +79,8 @@ lua_check_dns_resolver (lua_State * L)

struct lua_dns_cbdata {
lua_State *L;
struct thread_entry *thread;
struct rspamd_task *task;
struct rspamd_dns_resolver *resolver;
gint cbref;
const gchar *to_resolve;
@@ -140,12 +143,15 @@ lua_dns_callback (struct rdns_reply *reply, gpointer arg)
struct rdns_reply_entry *elt;
rspamd_inet_addr_t *addr;

lua_rawgeti (cd->L, LUA_REGISTRYINDEX, cd->cbref);
presolver = lua_newuserdata (cd->L, sizeof (gpointer));
rspamd_lua_setclass (cd->L, "rspamd{resolver}", -1);
if (cd->cbref != -1) {
lua_rawgeti (cd->L, LUA_REGISTRYINDEX, cd->cbref);

*presolver = cd->resolver;
lua_pushstring (cd->L, cd->to_resolve);
presolver = lua_newuserdata (cd->L, sizeof (gpointer));
rspamd_lua_setclass (cd->L, "rspamd{resolver}", -1);

*presolver = cd->resolver;
lua_pushstring (cd->L, cd->to_resolve);
}

/*
* XXX: rework to handle different request types
@@ -227,22 +233,41 @@ lua_dns_callback (struct rdns_reply *reply, gpointer arg)
lua_pushstring (cd->L, rdns_strerror (reply->code));
}

if (cd->user_str != NULL) {
lua_pushstring (cd->L, cd->user_str);
}
else {
lua_pushnil (cd->L);
if (cd->cbref != -1) {
if (cd->user_str != NULL) {
lua_pushstring (cd->L, cd->user_str);
}
else {
lua_pushnil (cd->L);
}
}

lua_pushboolean (cd->L, reply->authenticated);

if (lua_pcall (cd->L, 6, 0, 0) != 0) {
msg_info ("call to dns callback failed: %s", lua_tostring (cd->L, -1));
lua_pop (cd->L, 1);
}
if (cd->cbref != -1) {
/*
* 1 - resolver
* 2 - to_resolve
* 3 - entries | nil
* 4 - error | nil
* 5 - user_str
* 6 - reply->authenticated
*/
if (lua_pcall (cd->L, 6, 0, 0) != 0) {
msg_info ("call to dns callback failed: %s", lua_tostring (cd->L, -1));
lua_pop (cd->L, 1);
}

/* Unref function */
luaL_unref (cd->L, LUA_REGISTRYINDEX, cd->cbref);
/* Unref function */
luaL_unref (cd->L, LUA_REGISTRYINDEX, cd->cbref);
} else {
/*
* 1 - entries
* 2 - error | nil
* 3 - reply->authenticated
*/
lua_resume_thread (cd->task, cd->thread, 3);
}

if (cd->s) {
rspamd_session_watcher_pop (cd->s, cd->w);
@@ -306,7 +331,7 @@ lua_dns_resolver_resolve_common (lua_State *L,

/* Check arguments */
if (!rspamd_lua_parse_table_arguments (L, first, &err,
"session=U{session};mempool=U{mempool};*name=S;*callback=F;"
"session=U{session};mempool=U{mempool};*name=S;callback=F;"
"option=S;task=U{task};forced=B",
&session, &pool, &to_resolve, &cbref, &user_str, &task, &forced)) {

@@ -325,7 +350,7 @@ lua_dns_resolver_resolve_common (lua_State *L,
session = task->s;
}

if (pool != NULL && to_resolve != NULL && cbref != -1) {
if (pool != NULL && to_resolve != NULL) {
cbdata = rspamd_mempool_alloc0 (pool, sizeof (struct lua_dns_cbdata));
cbdata->L = L;
cbdata->resolver = resolver;
@@ -375,6 +400,9 @@ lua_dns_resolver_resolve_common (lua_State *L,
}
}
else {
cbdata->thread = lua_thread_pool_get_running_entry (task->cfg->lua_thread_pool);
cbdata->task = task;

if (forced) {
ret = make_dns_request_task_forced (task,
lua_dns_callback,
@@ -391,10 +419,16 @@ lua_dns_resolver_resolve_common (lua_State *L,
}

if (ret) {
lua_pushboolean (L, TRUE);
cbdata->s = session;
cbdata->w = rspamd_session_get_watcher (session);
rspamd_session_watcher_push (session);
if (cbdata->cbref != -1) {
/* callback was set up */
lua_pushboolean (L, TRUE);
} else {
/* this is coroutine-based call */
return lua_yield_thread (cbdata->thread, 0);
}
}
else {
lua_pushnil (L);

+ 4
- 1
src/lua/lua_thread_pool.c Parādīt failu

@@ -14,7 +14,7 @@ static struct thread_entry *
thread_entry_new (lua_State * L)
{
struct thread_entry *ent;
ent = g_malloc (sizeof *ent);
ent = g_new0(struct thread_entry, 1);
ent->lua_state = lua_newthread (L);
ent->thread_index = luaL_ref (L, LUA_REGISTRYINDEX);

@@ -75,6 +75,8 @@ lua_thread_pool_get(struct lua_thread_pool *pool)
ent = thread_entry_new (pool->L);
}

pool->running_entry = ent;

return ent;
}

@@ -88,6 +90,7 @@ lua_thread_pool_return(struct lua_thread_pool *pool, struct thread_entry *thread
}

if (g_queue_get_length (pool->available_items) <= pool->max_items) {
thread_entry->cd = NULL;
g_queue_push_head (pool->available_items, thread_entry);
}
else {

+ 1
- 0
src/lua/lua_thread_pool.h Parādīt failu

@@ -6,6 +6,7 @@
struct thread_entry {
lua_State *lua_state;
gint thread_index;
gpointer cd;
};

struct thread_pool;

Notiek ielāde…
Atcelt
Saglabāt