]> source.dussan.org Git - rspamd.git/commitdiff
* Now it is possible to specify local functions to every callback of rspamd lua API,
authorVsevolod Stakhov <vsevolod@rambler-co.ru>
Wed, 14 Dec 2011 17:41:34 +0000 (20:41 +0300)
committerVsevolod Stakhov <vsevolod@rambler-co.ru>
Wed, 14 Dec 2011 17:41:34 +0000 (20:41 +0300)
  that will allow such things as passing different variables via lua closures
  mechanic.

Use config pool for configuration allocation in lua API to avoid leaks on config reload.

src/lua/lua_common.h
src/lua/lua_config.c
src/lua/lua_task.c

index 6ae8421fb499ede55ae7a094cabe187e4701b7b4..c1891a6a7453c57b9e191f9f2492174d2921723b 100644 (file)
 
 extern const luaL_reg null_reg[];
 
-#define RSPAMD_LUA_API_VERSION 8
+#define RSPAMD_LUA_API_VERSION 9
 
 /* Common utility functions */
+
+/**
+ * Create and register new class
+ */
 void lua_newclass (lua_State *L, const gchar *classname, const struct luaL_reg *func);
+
+/**
+ * Set class name for object at @param objidx position
+ */
 void lua_setclass (lua_State *L, const gchar *classname, gint objidx);
+
+/**
+ * Set index of table to value (like t['index'] = value)
+ */
 void lua_set_table_index (lua_State *L, const gchar *index, const gchar *value);
+
+/**
+ * Convert classname to string
+ */
 gint lua_class_tostring (lua_State *L);
+
+/**
+ * Open libraries functions
+ */
 gint luaopen_message (lua_State *L);
 gint luaopen_task (lua_State *L);
 gint luaopen_config (lua_State *L);
index 1c3017b6bce4fec9e8b4bd563517dd43b881be4d..cd1287a188d20dd514f65f9e278e7fffb9446f18 100644 (file)
@@ -285,10 +285,29 @@ lua_config_get_classifier (lua_State * L)
 }
 
 struct lua_callback_data {
-       const gchar                     *name;
-       lua_State                      *L;
+       union {
+               gchar                                           *name;
+               gint                                             ref;
+       } callback;
+       gboolean                                                 cb_is_ref;
+       lua_State                                               *L;
+       gchar                                                   *symbol;
 };
 
+/*
+ * Unref symbol if it is local reference
+ */
+static void
+lua_destroy_cfg_symbol (gpointer ud)
+{
+       struct lua_callback_data       *cd = ud;
+
+       /* Unref callback */
+       if (cd->cb_is_ref) {
+               luaL_unref (cd->L, LUA_REGISTRYINDEX, cd->callback.ref);
+       }
+}
+
 static gboolean
 lua_config_function_callback (struct worker_task *task, GList *args, void *user_data)
 {
@@ -299,7 +318,12 @@ lua_config_function_callback (struct worker_task *task, GList *args, void *user_
        GList                          *cur;
        gboolean                        res = FALSE;
 
-       lua_getglobal (cd->L, cd->name);
+       if (cd->cb_is_ref) {
+               lua_rawgeti (cd->L, LUA_REGISTRYINDEX, cd->callback.ref);
+       }
+       else {
+               lua_getglobal (cd->L, cd->callback.name);
+       }
        ptask = lua_newuserdata (cd->L, sizeof (struct worker_task *));
        lua_setclass (cd->L, "rspamd{task}", -1);
        *ptask = task;
@@ -313,7 +337,9 @@ lua_config_function_callback (struct worker_task *task, GList *args, void *user_
        }
 
        if (lua_pcall (cd->L, i, 1, 0) != 0) {
-               msg_warn ("error running function %s: %s", cd->name, lua_tostring (cd->L, -1));
+               msg_info ("error processing symbol %s: call to %s failed: %s", cd->symbol,
+                                               cd->cb_is_ref ? "local function" :
+                                               cd->callback.name, lua_tostring (cd->L, -1));
        }
        else {
                if (lua_isboolean (cd->L, 1)) {
@@ -329,19 +355,29 @@ static gint
 lua_config_register_function (lua_State *L)
 {
        struct config_file             *cfg = lua_check_config (L);
-       const gchar                     *name, *callback;
+       gchar                          *name;
        struct lua_callback_data       *cd;
        
        if (cfg) {
-               name = g_strdup (luaL_checkstring (L, 2));
-       
-               callback = luaL_checkstring (L, 3);
+               name = memory_pool_strdup (cfg->cfg_pool, luaL_checkstring (L, 2));
+               cd = memory_pool_alloc (cfg->cfg_pool, sizeof (struct lua_callback_data));
+
+               if (lua_type (L, 3) == LUA_TSTRING) {
+                       cd->callback.name = memory_pool_strdup (cfg->cfg_pool, luaL_checkstring (L, 3));
+                       cd->cb_is_ref = FALSE;
+               }
+               else {
+                       lua_pushvalue (L, 3);
+                       /* Get a reference */
+                       cd->callback.ref = luaL_ref (L, LUA_REGISTRYINDEX);
+                       cd->cb_is_ref = TRUE;
+               }
                if (name) {
-                       cd = g_malloc (sizeof (struct lua_callback_data));
-                       cd->name = g_strdup (callback);
                        cd->L = L;
+                       cd->symbol = name;
                        register_expression_function (name, lua_config_function_callback, cd);
                }
+               memory_pool_add_destructor (cfg->cfg_pool, (pool_destruct_func)lua_destroy_cfg_symbol, cd);
        }
        return 1;
 }
@@ -406,13 +442,19 @@ lua_call_post_filters (struct worker_task *task)
        cur = task->cfg->post_filters;
        while (cur) {
                cd = cur->data;
-               lua_getglobal (cd->L, cd->name);
+               if (cd->cb_is_ref) {
+                       lua_rawgeti (cd->L, LUA_REGISTRYINDEX, cd->callback.ref);
+               }
+               else {
+                       lua_getglobal (cd->L, cd->callback.name);
+               }
                ptask = lua_newuserdata (cd->L, sizeof (struct worker_task *));
                lua_setclass (cd->L, "rspamd{task}", -1);
                *ptask = task;
 
                if (lua_pcall (cd->L, 1, 0, 0) != 0) {
-                       msg_warn ("error running function %s: %s", cd->name, lua_tostring (cd->L, -1));
+                       msg_info ("call to %s failed: %s", cd->cb_is_ref ? "local function" :
+                                                       cd->callback.name, lua_tostring (cd->L, -1));
                }
                cur = g_list_next (cur);
        }
@@ -422,18 +464,23 @@ static gint
 lua_config_register_post_filter (lua_State *L)
 {
        struct config_file             *cfg = lua_check_config (L);
-       const gchar                     *callback;
        struct lua_callback_data       *cd;
 
        if (cfg) {
-
-               callback = luaL_checkstring (L, 2);
-               if (callback) {
-                       cd = g_malloc (sizeof (struct lua_callback_data));
-                       cd->name = g_strdup (callback);
-                       cd->L = L;
-                       cfg->post_filters = g_list_prepend (cfg->post_filters, cd);
+               cd = memory_pool_alloc (cfg->cfg_pool, sizeof (struct lua_callback_data));
+               if (lua_type (L, 2) == LUA_TSTRING) {
+                       cd->callback.name = memory_pool_strdup (cfg->cfg_pool, luaL_checkstring (L, 2));
+                       cd->cb_is_ref = FALSE;
+               }
+               else {
+                       lua_pushvalue (L, 2);
+                       /* Get a reference */
+                       cd->callback.ref = luaL_ref (L, LUA_REGISTRYINDEX);
+                       cd->cb_is_ref = TRUE;
                }
+               cd->L = L;
+               cfg->post_filters = g_list_prepend (cfg->post_filters, cd);
+               memory_pool_add_destructor (cfg->cfg_pool, (pool_destruct_func)lua_destroy_cfg_symbol, cd);
        }
        return 1;
 }
@@ -447,12 +494,11 @@ lua_config_add_radix_map (lua_State *L)
 
        if (cfg) {
                map_line = luaL_checkstring (L, 2);
-               r = g_malloc (sizeof (radix_tree_t *));
+               r = memory_pool_alloc (cfg->cfg_pool, sizeof (radix_tree_t *));
                *r = radix_tree_create ();
                if (!add_map (map_line, read_radix_list, fin_radix_list, (void **)r)) {
                        msg_warn ("invalid radix map %s", map_line);
                        radix_tree_free (*r);
-                       g_free (r);
                        lua_pushnil (L);
                        return 1;
                }
@@ -477,15 +523,15 @@ lua_config_add_hash_map (lua_State *L)
 
        if (cfg) {
                map_line = luaL_checkstring (L, 2);
-               r = g_malloc (sizeof (GHashTable *));
+               r = memory_pool_alloc (cfg->cfg_pool, sizeof (GHashTable *));
                *r = g_hash_table_new (rspamd_strcase_hash, rspamd_strcase_equal);
                if (!add_map (map_line, read_host_list, fin_host_list, (void **)r)) {
                        msg_warn ("invalid hash map %s", map_line);
                        g_hash_table_destroy (*r);
-                       g_free (r);
                        lua_pushnil (L);
                        return 1;
                }
+               memory_pool_add_destructor (cfg->cfg_pool, (pool_destruct_func)g_hash_table_destroy, *r);
                ud = lua_newuserdata (L, sizeof (GHashTable *));
                *ud = r;
                lua_setclass (L, "rspamd{hash_table}", -1);
@@ -507,15 +553,15 @@ lua_config_add_kv_map (lua_State *L)
 
        if (cfg) {
                map_line = luaL_checkstring (L, 2);
-               r = g_malloc (sizeof (GHashTable *));
+               r = memory_pool_alloc (cfg->cfg_pool, sizeof (GHashTable *));
                *r = g_hash_table_new (rspamd_strcase_hash, rspamd_strcase_equal);
                if (!add_map (map_line, read_kv_list, fin_kv_list, (void **)r)) {
                        msg_warn ("invalid hash map %s", map_line);
                        g_hash_table_destroy (*r);
-                       g_free (r);
                        lua_pushnil (L);
                        return 1;
                }
+               memory_pool_add_destructor (cfg->cfg_pool, (pool_destruct_func)g_hash_table_destroy, *r);
                ud = lua_newuserdata (L, sizeof (GHashTable *));
                *ud = r;
                lua_setclass (L, "rspamd{hash_table}", -1);
@@ -536,13 +582,20 @@ lua_metric_symbol_callback (struct worker_task *task, gpointer ud)
 {
        struct lua_callback_data       *cd = ud;
        struct worker_task            **ptask;
-       lua_getglobal (cd->L, cd->name);
+
+       if (cd->cb_is_ref) {
+               lua_rawgeti (cd->L, LUA_REGISTRYINDEX, cd->callback.ref);
+       }
+       else {
+               lua_getglobal (cd->L, cd->callback.name);
+       }
        ptask = lua_newuserdata (cd->L, sizeof (struct worker_task *));
        lua_setclass (cd->L, "rspamd{task}", -1);
        *ptask = task;
 
        if (lua_pcall (cd->L, 1, 0, 0) != 0) {
-               msg_warn ("error running function %s: %s", cd->name, lua_tostring (cd->L, -1));
+               msg_info ("call to %s failed: %s", cd->cb_is_ref ? "local function" :
+                                                                       cd->callback.name, lua_tostring (cd->L, -1));
        }
 }
 
@@ -550,20 +603,30 @@ static gint
 lua_config_register_symbol (lua_State * L)
 {
        struct config_file             *cfg = lua_check_config (L);
-       const gchar                     *name, *callback;
+       gchar                          *name;
        double                          weight;
        struct lua_callback_data       *cd;
 
        if (cfg) {
                name = memory_pool_strdup (cfg->cfg_pool, luaL_checkstring (L, 2));
                weight = luaL_checknumber (L, 3);
-               callback = luaL_checkstring (L, 4);
+               cd = memory_pool_alloc (cfg->cfg_pool, sizeof (struct lua_callback_data));
+               if (lua_type (L, 4) == LUA_TSTRING) {
+                       cd->callback.name = memory_pool_strdup (cfg->cfg_pool, luaL_checkstring (L, 4));
+                       cd->cb_is_ref = FALSE;
+               }
+               else {
+                       lua_pushvalue (L, 4);
+                       /* Get a reference */
+                       cd->callback.ref = luaL_ref (L, LUA_REGISTRYINDEX);
+                       cd->cb_is_ref = TRUE;
+               }
                if (name) {
-                       cd = g_malloc (sizeof (struct lua_callback_data));
-                       cd->name = g_strdup (callback);
+                       cd->symbol = name;
                        cd->L = L;
                        register_symbol (&cfg->cache, name, weight, lua_metric_symbol_callback, cd);
                }
+               memory_pool_add_destructor (cfg->cfg_pool, (pool_destruct_func)lua_destroy_cfg_symbol, cd);
        }
        return 1;
 }
@@ -572,7 +635,7 @@ static gint
 lua_config_register_virtual_symbol (lua_State * L)
 {
        struct config_file             *cfg = lua_check_config (L);
-       const gchar                     *name;
+       gchar                          *name;
        double                          weight;
 
        if (cfg) {
@@ -589,20 +652,30 @@ static gint
 lua_config_register_callback_symbol (lua_State * L)
 {
        struct config_file             *cfg = lua_check_config (L);
-       const gchar                     *name, *callback;
+       gchar                          *name;
        double                          weight;
        struct lua_callback_data       *cd;
 
        if (cfg) {
                name = memory_pool_strdup (cfg->cfg_pool, luaL_checkstring (L, 2));
                weight = luaL_checknumber (L, 3);
-               callback = luaL_checkstring (L, 4);
+               cd = memory_pool_alloc (cfg->cfg_pool, sizeof (struct lua_callback_data));
+               if (lua_type (L, 4) == LUA_TSTRING) {
+                       cd->callback.name = memory_pool_strdup (cfg->cfg_pool, luaL_checkstring (L, 4));
+                       cd->cb_is_ref = FALSE;
+               }
+               else {
+                       lua_pushvalue (L, 4);
+                       /* Get a reference */
+                       cd->callback.ref = luaL_ref (L, LUA_REGISTRYINDEX);
+                       cd->cb_is_ref = TRUE;
+               }
                if (name) {
-                       cd = g_malloc (sizeof (struct lua_callback_data));
-                       cd->name = g_strdup (callback);
+                       cd->symbol = name;
                        cd->L = L;
                        register_callback_symbol (&cfg->cache, name, weight, lua_metric_symbol_callback, cd);
                }
+               memory_pool_add_destructor (cfg->cfg_pool, (pool_destruct_func)lua_destroy_cfg_symbol, cd);
        }
        return 1;
 }
@@ -611,7 +684,7 @@ static gint
 lua_config_register_callback_symbol_priority (lua_State * L)
 {
        struct config_file             *cfg = lua_check_config (L);
-       const gchar                     *name, *callback;
+       gchar                          *name;
        double                          weight;
        gint                            priority;
        struct lua_callback_data       *cd;
@@ -620,14 +693,25 @@ lua_config_register_callback_symbol_priority (lua_State * L)
                name = memory_pool_strdup (cfg->cfg_pool, luaL_checkstring (L, 2));
                weight = luaL_checknumber (L, 3);
                priority = luaL_checknumber (L, 4);
-               callback = luaL_checkstring (L, 5);
+               cd = memory_pool_alloc (cfg->cfg_pool, sizeof (struct lua_callback_data));
+               if (lua_type (L, 5) == LUA_TSTRING) {
+                       cd->callback.name = memory_pool_strdup (cfg->cfg_pool, luaL_checkstring (L, 5));
+                       cd->cb_is_ref = FALSE;
+               }
+               else {
+                       lua_pushvalue (L, 5);
+                       /* Get a reference */
+                       cd->callback.ref = luaL_ref (L, LUA_REGISTRYINDEX);
+                       cd->cb_is_ref = TRUE;
+               }
 
                if (name) {
-                       cd = g_malloc (sizeof (struct lua_callback_data));
-                       cd->name = g_strdup (callback);
                        cd->L = L;
+                       cd->symbol = name;
                        register_callback_symbol_priority (&cfg->cache, name, weight, priority, lua_metric_symbol_callback, cd);
                }
+               memory_pool_add_destructor (cfg->cfg_pool, (pool_destruct_func)lua_destroy_cfg_symbol, cd);
+
        }
        return 1;
 }
index 19de0f7c8799c46d8873cdfd48667ea48ad79ca7..7cca35c3d79d8378b97e4025bf4b42b2681b9b10 100644 (file)
@@ -444,7 +444,11 @@ lua_task_get_received_headers (lua_State * L)
 struct lua_dns_callback_data {
        lua_State                      *L;
        struct worker_task             *task;
-       const gchar                    *callback;
+       union {
+               const gchar                *cbname;
+               gint                                            ref;
+       } callback;
+       gboolean                                                cb_is_ref;
        const gchar                    *to_resolve;
        gint                            cbtype;
        union {
@@ -464,7 +468,12 @@ lua_dns_callback (struct rspamd_dns_reply *reply, gpointer arg)
        union rspamd_reply_element     *elt;
        GList                          *cur;
 
-       lua_getglobal (cd->L, cd->callback);
+       if (cd->cb_is_ref) {
+               lua_rawgeti (cd->L, LUA_REGISTRYINDEX, cd->callback.ref);
+       }
+       else {
+               lua_getglobal (cd->L, cd->callback.cbname);
+       }
        ptask = lua_newuserdata (cd->L, sizeof (struct worker_task *));
        lua_setclass (cd->L, "rspamd{task}", -1);
 
@@ -536,7 +545,13 @@ lua_dns_callback (struct rspamd_dns_reply *reply, gpointer arg)
        }
 
        if (lua_pcall (cd->L, 5, 0, 0) != 0) {
-               msg_info ("call to %s failed: %s", cd->callback, lua_tostring (cd->L, -1));
+               msg_info ("call to %s failed: %s", cd->cb_is_ref ? "local function" :
+                               cd->callback.cbname, lua_tostring (cd->L, -1));
+       }
+
+       /* Unref function */
+       if (cd->cb_is_ref) {
+               luaL_unref (cd->L, LUA_REGISTRYINDEX, cd->callback.ref);
        }
 }
 
@@ -551,7 +566,18 @@ lua_task_resolve_dns_a (lua_State * L)
                cd->task = task;
                cd->L = L;
                cd->to_resolve = memory_pool_strdup (task->task_pool, luaL_checkstring (L, 2));
-               cd->callback = memory_pool_strdup (task->task_pool, luaL_checkstring (L, 3));
+
+               /* Check what type we have */
+               if (lua_type (L, 3) == LUA_TSTRING) {
+                       cd->cb_is_ref = FALSE;
+                       cd->callback.cbname = memory_pool_strdup (task->task_pool, luaL_checkstring (L, 3));
+               }
+               else {
+                       lua_pushvalue (L, 3);
+                       cd->cb_is_ref = TRUE;
+                       cd->callback.ref = luaL_ref (L, LUA_REGISTRYINDEX);
+               }
+
                cd->cbtype = lua_type (L, 4);
                if (cd->cbtype != LUA_TNONE && cd->cbtype != LUA_TNIL) {
                        switch (cd->cbtype) {
@@ -565,13 +591,13 @@ lua_task_resolve_dns_a (lua_State * L)
                                cd->cbdata.string = memory_pool_strdup (task->task_pool, lua_tostring (L, 4));
                                break;
                        default:
-                               msg_warn ("cannot handle type %s as callback data", lua_typename (L, cd->cbtype));
+                               msg_warn ("cannot handle type %s as callback data, try using closures", lua_typename (L, cd->cbtype));
                                cd->cbtype = LUA_TNONE;
                                break;
                        }
                }
 
-               if (!cd->to_resolve || !cd->callback) {
+               if (!cd->to_resolve) {
                        msg_info ("invalid parameters passed to function");
                        return 0;
                }
@@ -593,7 +619,16 @@ lua_task_resolve_dns_txt (lua_State * L)
                cd->task = task;
                cd->L = L;
                cd->to_resolve = memory_pool_strdup (task->task_pool, luaL_checkstring (L, 2));
-               cd->callback = memory_pool_strdup (task->task_pool, luaL_checkstring (L, 3));
+               /* Check what type we have */
+               if (lua_type (L, 3) == LUA_TSTRING) {
+                       cd->cb_is_ref = FALSE;
+                       cd->callback.cbname = memory_pool_strdup (task->task_pool, luaL_checkstring (L, 3));
+               }
+               else {
+                       lua_pushvalue (L, 3);
+                       cd->cb_is_ref = TRUE;
+                       cd->callback.ref = luaL_ref (L, LUA_REGISTRYINDEX);
+               }
                cd->cbtype = lua_type (L, 4);
                if (cd->cbtype != LUA_TNONE && cd->cbtype != LUA_TNIL) {
                        switch (cd->cbtype) {
@@ -612,7 +647,7 @@ lua_task_resolve_dns_txt (lua_State * L)
                                break;
                        }
                }
-               if (!cd->to_resolve || !cd->callback) {
+               if (!cd->to_resolve) {
                        msg_info ("invalid parameters passed to function");
                        return 0;
                }
@@ -635,7 +670,16 @@ lua_task_resolve_dns_ptr (lua_State * L)
                cd->task = task;
                cd->L = L;
                cd->to_resolve = memory_pool_strdup (task->task_pool, luaL_checkstring (L, 2));
-               cd->callback = memory_pool_strdup (task->task_pool, luaL_checkstring (L, 3));
+               /* Check what type we have */
+               if (lua_type (L, 3) == LUA_TSTRING) {
+                       cd->cb_is_ref = FALSE;
+                       cd->callback.cbname = memory_pool_strdup (task->task_pool, luaL_checkstring (L, 3));
+               }
+               else {
+                       lua_pushvalue (L, 3);
+                       cd->cb_is_ref = TRUE;
+                       cd->callback.ref = luaL_ref (L, LUA_REGISTRYINDEX);
+               }
                cd->cbtype = lua_type (L, 4);
                if (cd->cbtype != LUA_TNONE && cd->cbtype != LUA_TNIL) {
                        switch (cd->cbtype) {
@@ -655,7 +699,7 @@ lua_task_resolve_dns_ptr (lua_State * L)
                        }
                }
                ina = memory_pool_alloc (task->task_pool, sizeof (struct in_addr));
-               if (!cd->to_resolve || !cd->callback || !inet_aton (cd->to_resolve, ina)) {
+               if (!cd->to_resolve || !inet_aton (cd->to_resolve, ina)) {
                        msg_info ("invalid parameters passed to function");
                        return 0;
                }