diff options
author | Vsevolod Stakhov <vsevolod@rambler-co.ru> | 2011-12-14 20:41:34 +0300 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@rambler-co.ru> | 2011-12-14 20:41:34 +0300 |
commit | a9b60057092bcb44934ff515bbf034a6142024a4 (patch) | |
tree | 2f347ef43d85ef78af1bd899c62f4d6a4c70fe1d | |
parent | 78f53ec59aed754795c0d9195a05839adc306ea5 (diff) | |
download | rspamd-a9b60057092bcb44934ff515bbf034a6142024a4.tar.gz rspamd-a9b60057092bcb44934ff515bbf034a6142024a4.zip |
* Now it is possible to specify local functions to every callback of rspamd lua API,
that will allow such things as passing different variables via lua closures
mechanic.
Use config pool for configuration allocation in lua API to avoid leaks on config reload.
-rw-r--r-- | src/lua/lua_common.h | 22 | ||||
-rw-r--r-- | src/lua/lua_config.c | 166 | ||||
-rw-r--r-- | src/lua/lua_task.c | 64 |
3 files changed, 200 insertions, 52 deletions
diff --git a/src/lua/lua_common.h b/src/lua/lua_common.h index 6ae8421fb..c1891a6a7 100644 --- a/src/lua/lua_common.h +++ b/src/lua/lua_common.h @@ -16,13 +16,33 @@ extern const luaL_reg null_reg[]; -#define RSPAMD_LUA_API_VERSION 8 +#define RSPAMD_LUA_API_VERSION 9 /* Common utility functions */ + +/** + * Create and register new class + */ void lua_newclass (lua_State *L, const gchar *classname, const struct luaL_reg *func); + +/** + * Set class name for object at @param objidx position + */ void lua_setclass (lua_State *L, const gchar *classname, gint objidx); + +/** + * Set index of table to value (like t['index'] = value) + */ void lua_set_table_index (lua_State *L, const gchar *index, const gchar *value); + +/** + * Convert classname to string + */ gint lua_class_tostring (lua_State *L); + +/** + * Open libraries functions + */ gint luaopen_message (lua_State *L); gint luaopen_task (lua_State *L); gint luaopen_config (lua_State *L); diff --git a/src/lua/lua_config.c b/src/lua/lua_config.c index 1c3017b6b..cd1287a18 100644 --- a/src/lua/lua_config.c +++ b/src/lua/lua_config.c @@ -285,10 +285,29 @@ lua_config_get_classifier (lua_State * L) } struct lua_callback_data { - const gchar *name; - lua_State *L; + union { + gchar *name; + gint ref; + } callback; + gboolean cb_is_ref; + lua_State *L; + gchar *symbol; }; +/* + * Unref symbol if it is local reference + */ +static void +lua_destroy_cfg_symbol (gpointer ud) +{ + struct lua_callback_data *cd = ud; + + /* Unref callback */ + if (cd->cb_is_ref) { + luaL_unref (cd->L, LUA_REGISTRYINDEX, cd->callback.ref); + } +} + static gboolean lua_config_function_callback (struct worker_task *task, GList *args, void *user_data) { @@ -299,7 +318,12 @@ lua_config_function_callback (struct worker_task *task, GList *args, void *user_ GList *cur; gboolean res = FALSE; - lua_getglobal (cd->L, cd->name); + if (cd->cb_is_ref) { + lua_rawgeti (cd->L, LUA_REGISTRYINDEX, cd->callback.ref); + } + else { + lua_getglobal (cd->L, cd->callback.name); + } ptask = lua_newuserdata (cd->L, sizeof (struct worker_task *)); lua_setclass (cd->L, "rspamd{task}", -1); *ptask = task; @@ -313,7 +337,9 @@ lua_config_function_callback (struct worker_task *task, GList *args, void *user_ } if (lua_pcall (cd->L, i, 1, 0) != 0) { - msg_warn ("error running function %s: %s", cd->name, lua_tostring (cd->L, -1)); + msg_info ("error processing symbol %s: call to %s failed: %s", cd->symbol, + cd->cb_is_ref ? "local function" : + cd->callback.name, lua_tostring (cd->L, -1)); } else { if (lua_isboolean (cd->L, 1)) { @@ -329,19 +355,29 @@ static gint lua_config_register_function (lua_State *L) { struct config_file *cfg = lua_check_config (L); - const gchar *name, *callback; + gchar *name; struct lua_callback_data *cd; if (cfg) { - name = g_strdup (luaL_checkstring (L, 2)); - - callback = luaL_checkstring (L, 3); + name = memory_pool_strdup (cfg->cfg_pool, luaL_checkstring (L, 2)); + cd = memory_pool_alloc (cfg->cfg_pool, sizeof (struct lua_callback_data)); + + if (lua_type (L, 3) == LUA_TSTRING) { + cd->callback.name = memory_pool_strdup (cfg->cfg_pool, luaL_checkstring (L, 3)); + cd->cb_is_ref = FALSE; + } + else { + lua_pushvalue (L, 3); + /* Get a reference */ + cd->callback.ref = luaL_ref (L, LUA_REGISTRYINDEX); + cd->cb_is_ref = TRUE; + } if (name) { - cd = g_malloc (sizeof (struct lua_callback_data)); - cd->name = g_strdup (callback); cd->L = L; + cd->symbol = name; register_expression_function (name, lua_config_function_callback, cd); } + memory_pool_add_destructor (cfg->cfg_pool, (pool_destruct_func)lua_destroy_cfg_symbol, cd); } return 1; } @@ -406,13 +442,19 @@ lua_call_post_filters (struct worker_task *task) cur = task->cfg->post_filters; while (cur) { cd = cur->data; - lua_getglobal (cd->L, cd->name); + if (cd->cb_is_ref) { + lua_rawgeti (cd->L, LUA_REGISTRYINDEX, cd->callback.ref); + } + else { + lua_getglobal (cd->L, cd->callback.name); + } ptask = lua_newuserdata (cd->L, sizeof (struct worker_task *)); lua_setclass (cd->L, "rspamd{task}", -1); *ptask = task; if (lua_pcall (cd->L, 1, 0, 0) != 0) { - msg_warn ("error running function %s: %s", cd->name, lua_tostring (cd->L, -1)); + msg_info ("call to %s failed: %s", cd->cb_is_ref ? "local function" : + cd->callback.name, lua_tostring (cd->L, -1)); } cur = g_list_next (cur); } @@ -422,18 +464,23 @@ static gint lua_config_register_post_filter (lua_State *L) { struct config_file *cfg = lua_check_config (L); - const gchar *callback; struct lua_callback_data *cd; if (cfg) { - - callback = luaL_checkstring (L, 2); - if (callback) { - cd = g_malloc (sizeof (struct lua_callback_data)); - cd->name = g_strdup (callback); - cd->L = L; - cfg->post_filters = g_list_prepend (cfg->post_filters, cd); + cd = memory_pool_alloc (cfg->cfg_pool, sizeof (struct lua_callback_data)); + if (lua_type (L, 2) == LUA_TSTRING) { + cd->callback.name = memory_pool_strdup (cfg->cfg_pool, luaL_checkstring (L, 2)); + cd->cb_is_ref = FALSE; + } + else { + lua_pushvalue (L, 2); + /* Get a reference */ + cd->callback.ref = luaL_ref (L, LUA_REGISTRYINDEX); + cd->cb_is_ref = TRUE; } + cd->L = L; + cfg->post_filters = g_list_prepend (cfg->post_filters, cd); + memory_pool_add_destructor (cfg->cfg_pool, (pool_destruct_func)lua_destroy_cfg_symbol, cd); } return 1; } @@ -447,12 +494,11 @@ lua_config_add_radix_map (lua_State *L) if (cfg) { map_line = luaL_checkstring (L, 2); - r = g_malloc (sizeof (radix_tree_t *)); + r = memory_pool_alloc (cfg->cfg_pool, sizeof (radix_tree_t *)); *r = radix_tree_create (); if (!add_map (map_line, read_radix_list, fin_radix_list, (void **)r)) { msg_warn ("invalid radix map %s", map_line); radix_tree_free (*r); - g_free (r); lua_pushnil (L); return 1; } @@ -477,15 +523,15 @@ lua_config_add_hash_map (lua_State *L) if (cfg) { map_line = luaL_checkstring (L, 2); - r = g_malloc (sizeof (GHashTable *)); + r = memory_pool_alloc (cfg->cfg_pool, sizeof (GHashTable *)); *r = g_hash_table_new (rspamd_strcase_hash, rspamd_strcase_equal); if (!add_map (map_line, read_host_list, fin_host_list, (void **)r)) { msg_warn ("invalid hash map %s", map_line); g_hash_table_destroy (*r); - g_free (r); lua_pushnil (L); return 1; } + memory_pool_add_destructor (cfg->cfg_pool, (pool_destruct_func)g_hash_table_destroy, *r); ud = lua_newuserdata (L, sizeof (GHashTable *)); *ud = r; lua_setclass (L, "rspamd{hash_table}", -1); @@ -507,15 +553,15 @@ lua_config_add_kv_map (lua_State *L) if (cfg) { map_line = luaL_checkstring (L, 2); - r = g_malloc (sizeof (GHashTable *)); + r = memory_pool_alloc (cfg->cfg_pool, sizeof (GHashTable *)); *r = g_hash_table_new (rspamd_strcase_hash, rspamd_strcase_equal); if (!add_map (map_line, read_kv_list, fin_kv_list, (void **)r)) { msg_warn ("invalid hash map %s", map_line); g_hash_table_destroy (*r); - g_free (r); lua_pushnil (L); return 1; } + memory_pool_add_destructor (cfg->cfg_pool, (pool_destruct_func)g_hash_table_destroy, *r); ud = lua_newuserdata (L, sizeof (GHashTable *)); *ud = r; lua_setclass (L, "rspamd{hash_table}", -1); @@ -536,13 +582,20 @@ lua_metric_symbol_callback (struct worker_task *task, gpointer ud) { struct lua_callback_data *cd = ud; struct worker_task **ptask; - lua_getglobal (cd->L, cd->name); + + if (cd->cb_is_ref) { + lua_rawgeti (cd->L, LUA_REGISTRYINDEX, cd->callback.ref); + } + else { + lua_getglobal (cd->L, cd->callback.name); + } ptask = lua_newuserdata (cd->L, sizeof (struct worker_task *)); lua_setclass (cd->L, "rspamd{task}", -1); *ptask = task; if (lua_pcall (cd->L, 1, 0, 0) != 0) { - msg_warn ("error running function %s: %s", cd->name, lua_tostring (cd->L, -1)); + msg_info ("call to %s failed: %s", cd->cb_is_ref ? "local function" : + cd->callback.name, lua_tostring (cd->L, -1)); } } @@ -550,20 +603,30 @@ static gint lua_config_register_symbol (lua_State * L) { struct config_file *cfg = lua_check_config (L); - const gchar *name, *callback; + gchar *name; double weight; struct lua_callback_data *cd; if (cfg) { name = memory_pool_strdup (cfg->cfg_pool, luaL_checkstring (L, 2)); weight = luaL_checknumber (L, 3); - callback = luaL_checkstring (L, 4); + cd = memory_pool_alloc (cfg->cfg_pool, sizeof (struct lua_callback_data)); + if (lua_type (L, 4) == LUA_TSTRING) { + cd->callback.name = memory_pool_strdup (cfg->cfg_pool, luaL_checkstring (L, 4)); + cd->cb_is_ref = FALSE; + } + else { + lua_pushvalue (L, 4); + /* Get a reference */ + cd->callback.ref = luaL_ref (L, LUA_REGISTRYINDEX); + cd->cb_is_ref = TRUE; + } if (name) { - cd = g_malloc (sizeof (struct lua_callback_data)); - cd->name = g_strdup (callback); + cd->symbol = name; cd->L = L; register_symbol (&cfg->cache, name, weight, lua_metric_symbol_callback, cd); } + memory_pool_add_destructor (cfg->cfg_pool, (pool_destruct_func)lua_destroy_cfg_symbol, cd); } return 1; } @@ -572,7 +635,7 @@ static gint lua_config_register_virtual_symbol (lua_State * L) { struct config_file *cfg = lua_check_config (L); - const gchar *name; + gchar *name; double weight; if (cfg) { @@ -589,20 +652,30 @@ static gint lua_config_register_callback_symbol (lua_State * L) { struct config_file *cfg = lua_check_config (L); - const gchar *name, *callback; + gchar *name; double weight; struct lua_callback_data *cd; if (cfg) { name = memory_pool_strdup (cfg->cfg_pool, luaL_checkstring (L, 2)); weight = luaL_checknumber (L, 3); - callback = luaL_checkstring (L, 4); + cd = memory_pool_alloc (cfg->cfg_pool, sizeof (struct lua_callback_data)); + if (lua_type (L, 4) == LUA_TSTRING) { + cd->callback.name = memory_pool_strdup (cfg->cfg_pool, luaL_checkstring (L, 4)); + cd->cb_is_ref = FALSE; + } + else { + lua_pushvalue (L, 4); + /* Get a reference */ + cd->callback.ref = luaL_ref (L, LUA_REGISTRYINDEX); + cd->cb_is_ref = TRUE; + } if (name) { - cd = g_malloc (sizeof (struct lua_callback_data)); - cd->name = g_strdup (callback); + cd->symbol = name; cd->L = L; register_callback_symbol (&cfg->cache, name, weight, lua_metric_symbol_callback, cd); } + memory_pool_add_destructor (cfg->cfg_pool, (pool_destruct_func)lua_destroy_cfg_symbol, cd); } return 1; } @@ -611,7 +684,7 @@ static gint lua_config_register_callback_symbol_priority (lua_State * L) { struct config_file *cfg = lua_check_config (L); - const gchar *name, *callback; + gchar *name; double weight; gint priority; struct lua_callback_data *cd; @@ -620,14 +693,25 @@ lua_config_register_callback_symbol_priority (lua_State * L) name = memory_pool_strdup (cfg->cfg_pool, luaL_checkstring (L, 2)); weight = luaL_checknumber (L, 3); priority = luaL_checknumber (L, 4); - callback = luaL_checkstring (L, 5); + cd = memory_pool_alloc (cfg->cfg_pool, sizeof (struct lua_callback_data)); + if (lua_type (L, 5) == LUA_TSTRING) { + cd->callback.name = memory_pool_strdup (cfg->cfg_pool, luaL_checkstring (L, 5)); + cd->cb_is_ref = FALSE; + } + else { + lua_pushvalue (L, 5); + /* Get a reference */ + cd->callback.ref = luaL_ref (L, LUA_REGISTRYINDEX); + cd->cb_is_ref = TRUE; + } if (name) { - cd = g_malloc (sizeof (struct lua_callback_data)); - cd->name = g_strdup (callback); cd->L = L; + cd->symbol = name; register_callback_symbol_priority (&cfg->cache, name, weight, priority, lua_metric_symbol_callback, cd); } + memory_pool_add_destructor (cfg->cfg_pool, (pool_destruct_func)lua_destroy_cfg_symbol, cd); + } return 1; } diff --git a/src/lua/lua_task.c b/src/lua/lua_task.c index 19de0f7c8..7cca35c3d 100644 --- a/src/lua/lua_task.c +++ b/src/lua/lua_task.c @@ -444,7 +444,11 @@ lua_task_get_received_headers (lua_State * L) struct lua_dns_callback_data { lua_State *L; struct worker_task *task; - const gchar *callback; + union { + const gchar *cbname; + gint ref; + } callback; + gboolean cb_is_ref; const gchar *to_resolve; gint cbtype; union { @@ -464,7 +468,12 @@ lua_dns_callback (struct rspamd_dns_reply *reply, gpointer arg) union rspamd_reply_element *elt; GList *cur; - lua_getglobal (cd->L, cd->callback); + if (cd->cb_is_ref) { + lua_rawgeti (cd->L, LUA_REGISTRYINDEX, cd->callback.ref); + } + else { + lua_getglobal (cd->L, cd->callback.cbname); + } ptask = lua_newuserdata (cd->L, sizeof (struct worker_task *)); lua_setclass (cd->L, "rspamd{task}", -1); @@ -536,7 +545,13 @@ lua_dns_callback (struct rspamd_dns_reply *reply, gpointer arg) } if (lua_pcall (cd->L, 5, 0, 0) != 0) { - msg_info ("call to %s failed: %s", cd->callback, lua_tostring (cd->L, -1)); + msg_info ("call to %s failed: %s", cd->cb_is_ref ? "local function" : + cd->callback.cbname, lua_tostring (cd->L, -1)); + } + + /* Unref function */ + if (cd->cb_is_ref) { + luaL_unref (cd->L, LUA_REGISTRYINDEX, cd->callback.ref); } } @@ -551,7 +566,18 @@ lua_task_resolve_dns_a (lua_State * L) cd->task = task; cd->L = L; cd->to_resolve = memory_pool_strdup (task->task_pool, luaL_checkstring (L, 2)); - cd->callback = memory_pool_strdup (task->task_pool, luaL_checkstring (L, 3)); + + /* Check what type we have */ + if (lua_type (L, 3) == LUA_TSTRING) { + cd->cb_is_ref = FALSE; + cd->callback.cbname = memory_pool_strdup (task->task_pool, luaL_checkstring (L, 3)); + } + else { + lua_pushvalue (L, 3); + cd->cb_is_ref = TRUE; + cd->callback.ref = luaL_ref (L, LUA_REGISTRYINDEX); + } + cd->cbtype = lua_type (L, 4); if (cd->cbtype != LUA_TNONE && cd->cbtype != LUA_TNIL) { switch (cd->cbtype) { @@ -565,13 +591,13 @@ lua_task_resolve_dns_a (lua_State * L) cd->cbdata.string = memory_pool_strdup (task->task_pool, lua_tostring (L, 4)); break; default: - msg_warn ("cannot handle type %s as callback data", lua_typename (L, cd->cbtype)); + msg_warn ("cannot handle type %s as callback data, try using closures", lua_typename (L, cd->cbtype)); cd->cbtype = LUA_TNONE; break; } } - if (!cd->to_resolve || !cd->callback) { + if (!cd->to_resolve) { msg_info ("invalid parameters passed to function"); return 0; } @@ -593,7 +619,16 @@ lua_task_resolve_dns_txt (lua_State * L) cd->task = task; cd->L = L; cd->to_resolve = memory_pool_strdup (task->task_pool, luaL_checkstring (L, 2)); - cd->callback = memory_pool_strdup (task->task_pool, luaL_checkstring (L, 3)); + /* Check what type we have */ + if (lua_type (L, 3) == LUA_TSTRING) { + cd->cb_is_ref = FALSE; + cd->callback.cbname = memory_pool_strdup (task->task_pool, luaL_checkstring (L, 3)); + } + else { + lua_pushvalue (L, 3); + cd->cb_is_ref = TRUE; + cd->callback.ref = luaL_ref (L, LUA_REGISTRYINDEX); + } cd->cbtype = lua_type (L, 4); if (cd->cbtype != LUA_TNONE && cd->cbtype != LUA_TNIL) { switch (cd->cbtype) { @@ -612,7 +647,7 @@ lua_task_resolve_dns_txt (lua_State * L) break; } } - if (!cd->to_resolve || !cd->callback) { + if (!cd->to_resolve) { msg_info ("invalid parameters passed to function"); return 0; } @@ -635,7 +670,16 @@ lua_task_resolve_dns_ptr (lua_State * L) cd->task = task; cd->L = L; cd->to_resolve = memory_pool_strdup (task->task_pool, luaL_checkstring (L, 2)); - cd->callback = memory_pool_strdup (task->task_pool, luaL_checkstring (L, 3)); + /* Check what type we have */ + if (lua_type (L, 3) == LUA_TSTRING) { + cd->cb_is_ref = FALSE; + cd->callback.cbname = memory_pool_strdup (task->task_pool, luaL_checkstring (L, 3)); + } + else { + lua_pushvalue (L, 3); + cd->cb_is_ref = TRUE; + cd->callback.ref = luaL_ref (L, LUA_REGISTRYINDEX); + } cd->cbtype = lua_type (L, 4); if (cd->cbtype != LUA_TNONE && cd->cbtype != LUA_TNIL) { switch (cd->cbtype) { @@ -655,7 +699,7 @@ lua_task_resolve_dns_ptr (lua_State * L) } } ina = memory_pool_alloc (task->task_pool, sizeof (struct in_addr)); - if (!cd->to_resolve || !cd->callback || !inet_aton (cd->to_resolve, ina)) { + if (!cd->to_resolve || !inet_aton (cd->to_resolve, ina)) { msg_info ("invalid parameters passed to function"); return 0; } |