diff options
-rw-r--r-- | src/lua/lua_common.h | 22 | ||||
-rw-r--r-- | src/lua/lua_config.c | 166 | ||||
-rw-r--r-- | src/lua/lua_task.c | 64 |
3 files changed, 200 insertions, 52 deletions
diff --git a/src/lua/lua_common.h b/src/lua/lua_common.h index 6ae8421fb..c1891a6a7 100644 --- a/src/lua/lua_common.h +++ b/src/lua/lua_common.h @@ -16,13 +16,33 @@ extern const luaL_reg null_reg[]; -#define RSPAMD_LUA_API_VERSION 8 +#define RSPAMD_LUA_API_VERSION 9 /* Common utility functions */ + +/** + * Create and register new class + */ void lua_newclass (lua_State *L, const gchar *classname, const struct luaL_reg *func); + +/** + * Set class name for object at @param objidx position + */ void lua_setclass (lua_State *L, const gchar *classname, gint objidx); + +/** + * Set index of table to value (like t['index'] = value) + */ void lua_set_table_index (lua_State *L, const gchar *index, const gchar *value); + +/** + * Convert classname to string + */ gint lua_class_tostring (lua_State *L); + +/** + * Open libraries functions + */ gint luaopen_message (lua_State *L); gint luaopen_task (lua_State *L); gint luaopen_config (lua_State *L); diff --git a/src/lua/lua_config.c b/src/lua/lua_config.c index 1c3017b6b..cd1287a18 100644 --- a/src/lua/lua_config.c +++ b/src/lua/lua_config.c @@ -285,10 +285,29 @@ lua_config_get_classifier (lua_State * L) } struct lua_callback_data { - const gchar *name; - lua_State *L; + union { + gchar *name; + gint ref; + } callback; + gboolean cb_is_ref; + lua_State *L; + gchar *symbol; }; +/* + * Unref symbol if it is local reference + */ +static void +lua_destroy_cfg_symbol (gpointer ud) +{ + struct lua_callback_data *cd = ud; + + /* Unref callback */ + if (cd->cb_is_ref) { + luaL_unref (cd->L, LUA_REGISTRYINDEX, cd->callback.ref); + } +} + static gboolean lua_config_function_callback (struct worker_task *task, GList *args, void *user_data) { @@ -299,7 +318,12 @@ lua_config_function_callback (struct worker_task *task, GList *args, void *user_ GList *cur; gboolean res = FALSE; - lua_getglobal (cd->L, cd->name); + if (cd->cb_is_ref) { + lua_rawgeti (cd->L, LUA_REGISTRYINDEX, cd->callback.ref); + } + else { + lua_getglobal (cd->L, cd->callback.name); + } ptask = lua_newuserdata (cd->L, sizeof (struct worker_task *)); lua_setclass (cd->L, "rspamd{task}", -1); *ptask = task; @@ -313,7 +337,9 @@ lua_config_function_callback (struct worker_task *task, GList *args, void *user_ } if (lua_pcall (cd->L, i, 1, 0) != 0) { - msg_warn ("error running function %s: %s", cd->name, lua_tostring (cd->L, -1)); + msg_info ("error processing symbol %s: call to %s failed: %s", cd->symbol, + cd->cb_is_ref ? "local function" : + cd->callback.name, lua_tostring (cd->L, -1)); } else { if (lua_isboolean (cd->L, 1)) { @@ -329,19 +355,29 @@ static gint lua_config_register_function (lua_State *L) { struct config_file *cfg = lua_check_config (L); - const gchar *name, *callback; + gchar *name; struct lua_callback_data *cd; if (cfg) { - name = g_strdup (luaL_checkstring (L, 2)); - - callback = luaL_checkstring (L, 3); + name = memory_pool_strdup (cfg->cfg_pool, luaL_checkstring (L, 2)); + cd = memory_pool_alloc (cfg->cfg_pool, sizeof (struct lua_callback_data)); + + if (lua_type (L, 3) == LUA_TSTRING) { + cd->callback.name = memory_pool_strdup (cfg->cfg_pool, luaL_checkstring (L, 3)); + cd->cb_is_ref = FALSE; + } + else { + lua_pushvalue (L, 3); + /* Get a reference */ + cd->callback.ref = luaL_ref (L, LUA_REGISTRYINDEX); + cd->cb_is_ref = TRUE; + } if (name) { - cd = g_malloc (sizeof (struct lua_callback_data)); - cd->name = g_strdup (callback); cd->L = L; + cd->symbol = name; register_expression_function (name, lua_config_function_callback, cd); } + memory_pool_add_destructor (cfg->cfg_pool, (pool_destruct_func)lua_destroy_cfg_symbol, cd); } return 1; } @@ -406,13 +442,19 @@ lua_call_post_filters (struct worker_task *task) cur = task->cfg->post_filters; while (cur) { cd = cur->data; - lua_getglobal (cd->L, cd->name); + if (cd->cb_is_ref) { + lua_rawgeti (cd->L, LUA_REGISTRYINDEX, cd->callback.ref); + } + else { + lua_getglobal (cd->L, cd->callback.name); + } ptask = lua_newuserdata (cd->L, sizeof (struct worker_task *)); lua_setclass (cd->L, "rspamd{task}", -1); *ptask = task; if (lua_pcall (cd->L, 1, 0, 0) != 0) { - msg_warn ("error running function %s: %s", cd->name, lua_tostring (cd->L, -1)); + msg_info ("call to %s failed: %s", cd->cb_is_ref ? "local function" : + cd->callback.name, lua_tostring (cd->L, -1)); } cur = g_list_next (cur); } @@ -422,18 +464,23 @@ static gint lua_config_register_post_filter (lua_State *L) { struct config_file *cfg = lua_check_config (L); - const gchar *callback; struct lua_callback_data *cd; if (cfg) { - - callback = luaL_checkstring (L, 2); - if (callback) { - cd = g_malloc (sizeof (struct lua_callback_data)); - cd->name = g_strdup (callback); - cd->L = L; - cfg->post_filters = g_list_prepend (cfg->post_filters, cd); + cd = memory_pool_alloc (cfg->cfg_pool, sizeof (struct lua_callback_data)); + if (lua_type (L, 2) == LUA_TSTRING) { + cd->callback.name = memory_pool_strdup (cfg->cfg_pool, luaL_checkstring (L, 2)); + cd->cb_is_ref = FALSE; + } + else { + lua_pushvalue (L, 2); + /* Get a reference */ + cd->callback.ref = luaL_ref (L, LUA_REGISTRYINDEX); + cd->cb_is_ref = TRUE; } + cd->L = L; + cfg->post_filters = g_list_prepend (cfg->post_filters, cd); + memory_pool_add_destructor (cfg->cfg_pool, (pool_destruct_func)lua_destroy_cfg_symbol, cd); } return 1; } @@ -447,12 +494,11 @@ lua_config_add_radix_map (lua_State *L) if (cfg) { map_line = luaL_checkstring (L, 2); - r = g_malloc (sizeof (radix_tree_t *)); + r = memory_pool_alloc (cfg->cfg_pool, sizeof (radix_tree_t *)); *r = radix_tree_create (); if (!add_map (map_line, read_radix_list, fin_radix_list, (void **)r)) { msg_warn ("invalid radix map %s", map_line); radix_tree_free (*r); - g_free (r); lua_pushnil (L); return 1; } @@ -477,15 +523,15 @@ lua_config_add_hash_map (lua_State *L) if (cfg) { map_line = luaL_checkstring (L, 2); - r = g_malloc (sizeof (GHashTable *)); + r = memory_pool_alloc (cfg->cfg_pool, sizeof (GHashTable *)); *r = g_hash_table_new (rspamd_strcase_hash, rspamd_strcase_equal); if (!add_map (map_line, read_host_list, fin_host_list, (void **)r)) { msg_warn ("invalid hash map %s", map_line); g_hash_table_destroy (*r); - g_free (r); lua_pushnil (L); return 1; } + memory_pool_add_destructor (cfg->cfg_pool, (pool_destruct_func)g_hash_table_destroy, *r); ud = lua_newuserdata (L, sizeof (GHashTable *)); *ud = r; lua_setclass (L, "rspamd{hash_table}", -1); @@ -507,15 +553,15 @@ lua_config_add_kv_map (lua_State *L) if (cfg) { map_line = luaL_checkstring (L, 2); - r = g_malloc (sizeof (GHashTable *)); + r = memory_pool_alloc (cfg->cfg_pool, sizeof (GHashTable *)); *r = g_hash_table_new (rspamd_strcase_hash, rspamd_strcase_equal); if (!add_map (map_line, read_kv_list, fin_kv_list, (void **)r)) { msg_warn ("invalid hash map %s", map_line); g_hash_table_destroy (*r); - g_free (r); lua_pushnil (L); return 1; } + memory_pool_add_destructor (cfg->cfg_pool, (pool_destruct_func)g_hash_table_destroy, *r); ud = lua_newuserdata (L, sizeof (GHashTable *)); *ud = r; lua_setclass (L, "rspamd{hash_table}", -1); @@ -536,13 +582,20 @@ lua_metric_symbol_callback (struct worker_task *task, gpointer ud) { struct lua_callback_data *cd = ud; struct worker_task **ptask; - lua_getglobal (cd->L, cd->name); + + if (cd->cb_is_ref) { + lua_rawgeti (cd->L, LUA_REGISTRYINDEX, cd->callback.ref); + } + else { + lua_getglobal (cd->L, cd->callback.name); + } ptask = lua_newuserdata (cd->L, sizeof (struct worker_task *)); lua_setclass (cd->L, "rspamd{task}", -1); *ptask = task; if (lua_pcall (cd->L, 1, 0, 0) != 0) { - msg_warn ("error running function %s: %s", cd->name, lua_tostring (cd->L, -1)); + msg_info ("call to %s failed: %s", cd->cb_is_ref ? "local function" : + cd->callback.name, lua_tostring (cd->L, -1)); } } @@ -550,20 +603,30 @@ static gint lua_config_register_symbol (lua_State * L) { struct config_file *cfg = lua_check_config (L); - const gchar *name, *callback; + gchar *name; double weight; struct lua_callback_data *cd; if (cfg) { name = memory_pool_strdup (cfg->cfg_pool, luaL_checkstring (L, 2)); weight = luaL_checknumber (L, 3); - callback = luaL_checkstring (L, 4); + cd = memory_pool_alloc (cfg->cfg_pool, sizeof (struct lua_callback_data)); + if (lua_type (L, 4) == LUA_TSTRING) { + cd->callback.name = memory_pool_strdup (cfg->cfg_pool, luaL_checkstring (L, 4)); + cd->cb_is_ref = FALSE; + } + else { + lua_pushvalue (L, 4); + /* Get a reference */ + cd->callback.ref = luaL_ref (L, LUA_REGISTRYINDEX); + cd->cb_is_ref = TRUE; + } if (name) { - cd = g_malloc (sizeof (struct lua_callback_data)); - cd->name = g_strdup (callback); + cd->symbol = name; cd->L = L; register_symbol (&cfg->cache, name, weight, lua_metric_symbol_callback, cd); } + memory_pool_add_destructor (cfg->cfg_pool, (pool_destruct_func)lua_destroy_cfg_symbol, cd); } return 1; } @@ -572,7 +635,7 @@ static gint lua_config_register_virtual_symbol (lua_State * L) { struct config_file *cfg = lua_check_config (L); - const gchar *name; + gchar *name; double weight; if (cfg) { @@ -589,20 +652,30 @@ static gint lua_config_register_callback_symbol (lua_State * L) { struct config_file *cfg = lua_check_config (L); - const gchar *name, *callback; + gchar *name; double weight; struct lua_callback_data *cd; if (cfg) { name = memory_pool_strdup (cfg->cfg_pool, luaL_checkstring (L, 2)); weight = luaL_checknumber (L, 3); - callback = luaL_checkstring (L, 4); + cd = memory_pool_alloc (cfg->cfg_pool, sizeof (struct lua_callback_data)); + if (lua_type (L, 4) == LUA_TSTRING) { + cd->callback.name = memory_pool_strdup (cfg->cfg_pool, luaL_checkstring (L, 4)); + cd->cb_is_ref = FALSE; + } + else { + lua_pushvalue (L, 4); + /* Get a reference */ + cd->callback.ref = luaL_ref (L, LUA_REGISTRYINDEX); + cd->cb_is_ref = TRUE; + } if (name) { - cd = g_malloc (sizeof (struct lua_callback_data)); - cd->name = g_strdup (callback); + cd->symbol = name; cd->L = L; register_callback_symbol (&cfg->cache, name, weight, lua_metric_symbol_callback, cd); } + memory_pool_add_destructor (cfg->cfg_pool, (pool_destruct_func)lua_destroy_cfg_symbol, cd); } return 1; } @@ -611,7 +684,7 @@ static gint lua_config_register_callback_symbol_priority (lua_State * L) { struct config_file *cfg = lua_check_config (L); - const gchar *name, *callback; + gchar *name; double weight; gint priority; struct lua_callback_data *cd; @@ -620,14 +693,25 @@ lua_config_register_callback_symbol_priority (lua_State * L) name = memory_pool_strdup (cfg->cfg_pool, luaL_checkstring (L, 2)); weight = luaL_checknumber (L, 3); priority = luaL_checknumber (L, 4); - callback = luaL_checkstring (L, 5); + cd = memory_pool_alloc (cfg->cfg_pool, sizeof (struct lua_callback_data)); + if (lua_type (L, 5) == LUA_TSTRING) { + cd->callback.name = memory_pool_strdup (cfg->cfg_pool, luaL_checkstring (L, 5)); + cd->cb_is_ref = FALSE; + } + else { + lua_pushvalue (L, 5); + /* Get a reference */ + cd->callback.ref = luaL_ref (L, LUA_REGISTRYINDEX); + cd->cb_is_ref = TRUE; + } if (name) { - cd = g_malloc (sizeof (struct lua_callback_data)); - cd->name = g_strdup (callback); cd->L = L; + cd->symbol = name; register_callback_symbol_priority (&cfg->cache, name, weight, priority, lua_metric_symbol_callback, cd); } + memory_pool_add_destructor (cfg->cfg_pool, (pool_destruct_func)lua_destroy_cfg_symbol, cd); + } return 1; } diff --git a/src/lua/lua_task.c b/src/lua/lua_task.c index 19de0f7c8..7cca35c3d 100644 --- a/src/lua/lua_task.c +++ b/src/lua/lua_task.c @@ -444,7 +444,11 @@ lua_task_get_received_headers (lua_State * L) struct lua_dns_callback_data { lua_State *L; struct worker_task *task; - const gchar *callback; + union { + const gchar *cbname; + gint ref; + } callback; + gboolean cb_is_ref; const gchar *to_resolve; gint cbtype; union { @@ -464,7 +468,12 @@ lua_dns_callback (struct rspamd_dns_reply *reply, gpointer arg) union rspamd_reply_element *elt; GList *cur; - lua_getglobal (cd->L, cd->callback); + if (cd->cb_is_ref) { + lua_rawgeti (cd->L, LUA_REGISTRYINDEX, cd->callback.ref); + } + else { + lua_getglobal (cd->L, cd->callback.cbname); + } ptask = lua_newuserdata (cd->L, sizeof (struct worker_task *)); lua_setclass (cd->L, "rspamd{task}", -1); @@ -536,7 +545,13 @@ lua_dns_callback (struct rspamd_dns_reply *reply, gpointer arg) } if (lua_pcall (cd->L, 5, 0, 0) != 0) { - msg_info ("call to %s failed: %s", cd->callback, lua_tostring (cd->L, -1)); + msg_info ("call to %s failed: %s", cd->cb_is_ref ? "local function" : + cd->callback.cbname, lua_tostring (cd->L, -1)); + } + + /* Unref function */ + if (cd->cb_is_ref) { + luaL_unref (cd->L, LUA_REGISTRYINDEX, cd->callback.ref); } } @@ -551,7 +566,18 @@ lua_task_resolve_dns_a (lua_State * L) cd->task = task; cd->L = L; cd->to_resolve = memory_pool_strdup (task->task_pool, luaL_checkstring (L, 2)); - cd->callback = memory_pool_strdup (task->task_pool, luaL_checkstring (L, 3)); + + /* Check what type we have */ + if (lua_type (L, 3) == LUA_TSTRING) { + cd->cb_is_ref = FALSE; + cd->callback.cbname = memory_pool_strdup (task->task_pool, luaL_checkstring (L, 3)); + } + else { + lua_pushvalue (L, 3); + cd->cb_is_ref = TRUE; + cd->callback.ref = luaL_ref (L, LUA_REGISTRYINDEX); + } + cd->cbtype = lua_type (L, 4); if (cd->cbtype != LUA_TNONE && cd->cbtype != LUA_TNIL) { switch (cd->cbtype) { @@ -565,13 +591,13 @@ lua_task_resolve_dns_a (lua_State * L) cd->cbdata.string = memory_pool_strdup (task->task_pool, lua_tostring (L, 4)); break; default: - msg_warn ("cannot handle type %s as callback data", lua_typename (L, cd->cbtype)); + msg_warn ("cannot handle type %s as callback data, try using closures", lua_typename (L, cd->cbtype)); cd->cbtype = LUA_TNONE; break; } } - if (!cd->to_resolve || !cd->callback) { + if (!cd->to_resolve) { msg_info ("invalid parameters passed to function"); return 0; } @@ -593,7 +619,16 @@ lua_task_resolve_dns_txt (lua_State * L) cd->task = task; cd->L = L; cd->to_resolve = memory_pool_strdup (task->task_pool, luaL_checkstring (L, 2)); - cd->callback = memory_pool_strdup (task->task_pool, luaL_checkstring (L, 3)); + /* Check what type we have */ + if (lua_type (L, 3) == LUA_TSTRING) { + cd->cb_is_ref = FALSE; + cd->callback.cbname = memory_pool_strdup (task->task_pool, luaL_checkstring (L, 3)); + } + else { + lua_pushvalue (L, 3); + cd->cb_is_ref = TRUE; + cd->callback.ref = luaL_ref (L, LUA_REGISTRYINDEX); + } cd->cbtype = lua_type (L, 4); if (cd->cbtype != LUA_TNONE && cd->cbtype != LUA_TNIL) { switch (cd->cbtype) { @@ -612,7 +647,7 @@ lua_task_resolve_dns_txt (lua_State * L) break; } } - if (!cd->to_resolve || !cd->callback) { + if (!cd->to_resolve) { msg_info ("invalid parameters passed to function"); return 0; } @@ -635,7 +670,16 @@ lua_task_resolve_dns_ptr (lua_State * L) cd->task = task; cd->L = L; cd->to_resolve = memory_pool_strdup (task->task_pool, luaL_checkstring (L, 2)); - cd->callback = memory_pool_strdup (task->task_pool, luaL_checkstring (L, 3)); + /* Check what type we have */ + if (lua_type (L, 3) == LUA_TSTRING) { + cd->cb_is_ref = FALSE; + cd->callback.cbname = memory_pool_strdup (task->task_pool, luaL_checkstring (L, 3)); + } + else { + lua_pushvalue (L, 3); + cd->cb_is_ref = TRUE; + cd->callback.ref = luaL_ref (L, LUA_REGISTRYINDEX); + } cd->cbtype = lua_type (L, 4); if (cd->cbtype != LUA_TNONE && cd->cbtype != LUA_TNIL) { switch (cd->cbtype) { @@ -655,7 +699,7 @@ lua_task_resolve_dns_ptr (lua_State * L) } } ina = memory_pool_alloc (task->task_pool, sizeof (struct in_addr)); - if (!cd->to_resolve || !cd->callback || !inet_aton (cd->to_resolve, ina)) { + if (!cd->to_resolve || !inet_aton (cd->to_resolve, ina)) { msg_info ("invalid parameters passed to function"); return 0; } |