diff options
Diffstat (limited to 'src/lua/lua_task.c')
-rw-r--r-- | src/lua/lua_task.c | 64 |
1 files changed, 54 insertions, 10 deletions
diff --git a/src/lua/lua_task.c b/src/lua/lua_task.c index 19de0f7c8..7cca35c3d 100644 --- a/src/lua/lua_task.c +++ b/src/lua/lua_task.c @@ -444,7 +444,11 @@ lua_task_get_received_headers (lua_State * L) struct lua_dns_callback_data { lua_State *L; struct worker_task *task; - const gchar *callback; + union { + const gchar *cbname; + gint ref; + } callback; + gboolean cb_is_ref; const gchar *to_resolve; gint cbtype; union { @@ -464,7 +468,12 @@ lua_dns_callback (struct rspamd_dns_reply *reply, gpointer arg) union rspamd_reply_element *elt; GList *cur; - lua_getglobal (cd->L, cd->callback); + if (cd->cb_is_ref) { + lua_rawgeti (cd->L, LUA_REGISTRYINDEX, cd->callback.ref); + } + else { + lua_getglobal (cd->L, cd->callback.cbname); + } ptask = lua_newuserdata (cd->L, sizeof (struct worker_task *)); lua_setclass (cd->L, "rspamd{task}", -1); @@ -536,7 +545,13 @@ lua_dns_callback (struct rspamd_dns_reply *reply, gpointer arg) } if (lua_pcall (cd->L, 5, 0, 0) != 0) { - msg_info ("call to %s failed: %s", cd->callback, lua_tostring (cd->L, -1)); + msg_info ("call to %s failed: %s", cd->cb_is_ref ? "local function" : + cd->callback.cbname, lua_tostring (cd->L, -1)); + } + + /* Unref function */ + if (cd->cb_is_ref) { + luaL_unref (cd->L, LUA_REGISTRYINDEX, cd->callback.ref); } } @@ -551,7 +566,18 @@ lua_task_resolve_dns_a (lua_State * L) cd->task = task; cd->L = L; cd->to_resolve = memory_pool_strdup (task->task_pool, luaL_checkstring (L, 2)); - cd->callback = memory_pool_strdup (task->task_pool, luaL_checkstring (L, 3)); + + /* Check what type we have */ + if (lua_type (L, 3) == LUA_TSTRING) { + cd->cb_is_ref = FALSE; + cd->callback.cbname = memory_pool_strdup (task->task_pool, luaL_checkstring (L, 3)); + } + else { + lua_pushvalue (L, 3); + cd->cb_is_ref = TRUE; + cd->callback.ref = luaL_ref (L, LUA_REGISTRYINDEX); + } + cd->cbtype = lua_type (L, 4); if (cd->cbtype != LUA_TNONE && cd->cbtype != LUA_TNIL) { switch (cd->cbtype) { @@ -565,13 +591,13 @@ lua_task_resolve_dns_a (lua_State * L) cd->cbdata.string = memory_pool_strdup (task->task_pool, lua_tostring (L, 4)); break; default: - msg_warn ("cannot handle type %s as callback data", lua_typename (L, cd->cbtype)); + msg_warn ("cannot handle type %s as callback data, try using closures", lua_typename (L, cd->cbtype)); cd->cbtype = LUA_TNONE; break; } } - if (!cd->to_resolve || !cd->callback) { + if (!cd->to_resolve) { msg_info ("invalid parameters passed to function"); return 0; } @@ -593,7 +619,16 @@ lua_task_resolve_dns_txt (lua_State * L) cd->task = task; cd->L = L; cd->to_resolve = memory_pool_strdup (task->task_pool, luaL_checkstring (L, 2)); - cd->callback = memory_pool_strdup (task->task_pool, luaL_checkstring (L, 3)); + /* Check what type we have */ + if (lua_type (L, 3) == LUA_TSTRING) { + cd->cb_is_ref = FALSE; + cd->callback.cbname = memory_pool_strdup (task->task_pool, luaL_checkstring (L, 3)); + } + else { + lua_pushvalue (L, 3); + cd->cb_is_ref = TRUE; + cd->callback.ref = luaL_ref (L, LUA_REGISTRYINDEX); + } cd->cbtype = lua_type (L, 4); if (cd->cbtype != LUA_TNONE && cd->cbtype != LUA_TNIL) { switch (cd->cbtype) { @@ -612,7 +647,7 @@ lua_task_resolve_dns_txt (lua_State * L) break; } } - if (!cd->to_resolve || !cd->callback) { + if (!cd->to_resolve) { msg_info ("invalid parameters passed to function"); return 0; } @@ -635,7 +670,16 @@ lua_task_resolve_dns_ptr (lua_State * L) cd->task = task; cd->L = L; cd->to_resolve = memory_pool_strdup (task->task_pool, luaL_checkstring (L, 2)); - cd->callback = memory_pool_strdup (task->task_pool, luaL_checkstring (L, 3)); + /* Check what type we have */ + if (lua_type (L, 3) == LUA_TSTRING) { + cd->cb_is_ref = FALSE; + cd->callback.cbname = memory_pool_strdup (task->task_pool, luaL_checkstring (L, 3)); + } + else { + lua_pushvalue (L, 3); + cd->cb_is_ref = TRUE; + cd->callback.ref = luaL_ref (L, LUA_REGISTRYINDEX); + } cd->cbtype = lua_type (L, 4); if (cd->cbtype != LUA_TNONE && cd->cbtype != LUA_TNIL) { switch (cd->cbtype) { @@ -655,7 +699,7 @@ lua_task_resolve_dns_ptr (lua_State * L) } } ina = memory_pool_alloc (task->task_pool, sizeof (struct in_addr)); - if (!cd->to_resolve || !cd->callback || !inet_aton (cd->to_resolve, ina)) { + if (!cd->to_resolve || !inet_aton (cd->to_resolve, ina)) { msg_info ("invalid parameters passed to function"); return 0; } |