]> source.dussan.org Git - rspamd.git/commitdiff
[Feature] Rework and abstract lua maps API
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 5 Mar 2016 14:44:08 +0000 (14:44 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 5 Mar 2016 14:44:08 +0000 (14:44 +0000)
- Now all maps share the same lua object table.
- Remove bad destructions code since objects are reallocated during maps
  operations
- Fix and unify various parts of maps management
- Pass map object to lua callbacks

src/lua/lua_common.c
src/lua/lua_common.h
src/lua/lua_config.c

index 960d3457d21dfb8960691ada6273f5d3fae4bc0a..1d78aa73f7572a53c09081740fe71316ece20180 100644 (file)
@@ -211,8 +211,7 @@ rspamd_lua_init ()
        luaopen_logger (L);
        luaopen_mempool (L);
        luaopen_config (L);
-       luaopen_radix (L);
-       luaopen_hash_table (L);
+       luaopen_map (L);
        luaopen_trie (L);
        luaopen_task (L);
        luaopen_textpart (L);
index b18b494a3e6e777ac444c81f5b2d29993c72bd63..8629d6366e243310b5865604a54313f601d8b2cd 100644 (file)
@@ -196,8 +196,7 @@ void rspamd_lua_add_preload (lua_State *L, const gchar *name, lua_CFunction func
 void luaopen_task (lua_State *L);
 void luaopen_config (lua_State *L);
 void luaopen_metric (lua_State *L);
-void luaopen_radix (lua_State *L);
-void luaopen_hash_table (lua_State *L);
+void luaopen_map (lua_State *L);
 void luaopen_trie (lua_State * L);
 void luaopen_textpart (lua_State *L);
 void luaopen_mimepart (lua_State *L);
index dd16710932822f001038e5564a475b5b59f53086..868fa50876c67e1b3717c9f5a77d6074dd7ca98a 100644 (file)
@@ -416,21 +416,29 @@ static const struct luaL_reg configlib_m[] = {
        {NULL, NULL}
 };
 
+enum rspamd_lua_map_type {
+       RSPAMD_LUA_MAP_RADIX = 0,
+       RSPAMD_LUA_MAP_SET,
+       RSPAMD_LUA_MAP_HASH,
+       RSPAMD_LUA_MAP_CALLBACK
+};
 
-/* Radix tree */
-LUA_FUNCTION_DEF (radix, get_key);
+struct rspamd_lua_map {
+       struct rspamd_map *map;
+       enum rspamd_lua_map_type type;
 
-static const struct luaL_reg radixlib_m[] = {
-       LUA_INTERFACE_DEF (radix, get_key),
-       {"__tostring", rspamd_lua_class_tostring},
-       {NULL, NULL}
+       union {
+               radix_compressed_t *radix;
+               GHashTable *hash;
+               gint cbref;
+       } data;
 };
 
-/* Hash table */
-LUA_FUNCTION_DEF (hash_table, get_key);
+/* Radix tree */
+LUA_FUNCTION_DEF (map, get_key);
 
-static const struct luaL_reg hashlib_m[] = {
-       LUA_INTERFACE_DEF (hash_table, get_key),
+static const struct luaL_reg maplib_m[] = {
+       LUA_INTERFACE_DEF (map, get_key),
        {"__tostring", rspamd_lua_class_tostring},
        {NULL, NULL}
 };
@@ -443,20 +451,12 @@ lua_check_config (lua_State * L, gint pos)
        return ud ? *((struct rspamd_config **)ud) : NULL;
 }
 
-static radix_compressed_t *
-lua_check_radix (lua_State * L)
+static struct rspamd_lua_map  *
+lua_check_map (lua_State * L)
 {
-       void *ud = luaL_checkudata (L, 1, "rspamd{radix}");
-       luaL_argcheck (L, ud != NULL, 1, "'radix' expected");
-       return ud ? **((radix_compressed_t ***)ud) : NULL;
-}
-
-static GHashTable *
-lua_check_hash_table (lua_State * L)
-{
-       void *ud = luaL_checkudata (L, 1, "rspamd{hash_table}");
-       luaL_argcheck (L, ud != NULL, 1, "'hash_table' expected");
-       return ud ? **((GHashTable ***)ud) : NULL;
+       void *ud = luaL_checkudata (L, 1, "rspamd{map}");
+       luaL_argcheck (L, ud != NULL, 1, "'map' expected");
+       return ud ? *((struct rspamd_lua_map **)ud) : NULL;
 }
 
 /*** Config functions ***/
@@ -766,30 +766,33 @@ lua_config_add_radix_map (lua_State *L)
 {
        struct rspamd_config *cfg = lua_check_config (L, 1);
        const gchar *map_line, *description;
-       radix_compressed_t **r, ***ud;
+       struct rspamd_lua_map *map, **pmap;
 
        if (cfg) {
                map_line = luaL_checkstring (L, 2);
                description = lua_tostring (L, 3);
-               r = rspamd_mempool_alloc (cfg->cfg_pool, sizeof (radix_compressed_t *));
-               *r = radix_create_compressed ();
+               map = rspamd_mempool_alloc0 (cfg->cfg_pool, sizeof (*map));
+               map->data.radix = radix_create_compressed ();
+               map->type = RSPAMD_LUA_MAP_RADIX;
 
-               if (!rspamd_map_add (cfg, map_line, description, rspamd_radix_read,
-                               rspamd_radix_fin, (void **)r)) {
+               if (!rspamd_map_add (cfg, map_line, description,
+                               rspamd_radix_read,
+                               rspamd_radix_fin,
+                               (void **)&map->data.radix)) {
                        msg_warn_config ("invalid radix map %s", map_line);
-                       radix_destroy_compressed (*r);
+                       radix_destroy_compressed (map->data.radix);
                        lua_pushnil (L);
                        return 1;
                }
 
-               ud = lua_newuserdata (L, sizeof (radix_compressed_t **));
-               *ud = r;
-               rspamd_lua_setclass (L, "rspamd{radix}", -1);
-
-               return 1;
+               pmap = lua_newuserdata (L, sizeof (void *));
+               *pmap = map;
+               rspamd_lua_setclass (L, "rspamd{map}", -1);
+       }
+       else {
+               return luaL_error (L, "invalid arguments");
        }
 
-       lua_pushnil (L);
        return 1;
 
 }
@@ -800,11 +803,10 @@ lua_config_radix_from_config (lua_State *L)
        struct rspamd_config *cfg = lua_check_config (L, 1);
        const gchar *mname, *optname;
        const ucl_object_t *obj;
-       radix_compressed_t **r, ***ud;
+       struct rspamd_lua_map *map, **pmap;
 
        if (!cfg) {
-               lua_pushnil (L);
-               return 1;
+               return luaL_error (L, "invalid arguments");
        }
 
        mname = luaL_checkstring (L, 2);
@@ -813,24 +815,26 @@ lua_config_radix_from_config (lua_State *L)
        if (mname && optname) {
                obj = rspamd_config_get_module_opt (cfg, mname, optname);
                if (obj) {
-                       r = rspamd_mempool_alloc (cfg->cfg_pool, sizeof (radix_compressed_t *));
-                       *r = radix_create_compressed ();
-                       radix_add_generic_iplist (ucl_obj_tostring (obj), r);
-                       ud = lua_newuserdata (L, sizeof (radix_compressed_t **));
-                       *ud = r;
-                       rspamd_lua_setclass (L, "rspamd{radix}", -1);
-                       return 1;
+                       map = rspamd_mempool_alloc0 (cfg->cfg_pool, sizeof (*map));
+                       map->data.radix = radix_create_compressed ();
+                       map->type = RSPAMD_LUA_MAP_RADIX;
+                       map->data.radix = radix_create_compressed ();
+                       radix_add_generic_iplist (ucl_obj_tostring (obj), &map->data.radix);
+                       pmap = lua_newuserdata (L, sizeof (void *));
+                       *pmap = map;
+                       rspamd_lua_setclass (L, "rspamd{map}", -1);
                } else {
                        msg_warn_config ("Couldnt find config option [%s][%s]", mname,
                                        optname);
                        lua_pushnil (L);
-                       return 1;
                }
-       } else {
-               msg_warn_config ("Couldnt find config option");
-               lua_pushnil (L);
-               return 1;
+
        }
+       else {
+               return luaL_error (L, "invalid arguments");
+       }
+
+       return 1;
 }
 
 static gint
@@ -838,35 +842,34 @@ lua_config_add_hash_map (lua_State *L)
 {
        struct rspamd_config *cfg = lua_check_config (L, 1);
        const gchar *map_line, *description;
-       GHashTable **r, ***ud;
+       struct rspamd_lua_map *map, **pmap;
 
        if (cfg) {
                map_line = luaL_checkstring (L, 2);
                description = lua_tostring (L, 3);
-               r = rspamd_mempool_alloc (cfg->cfg_pool, sizeof (GHashTable *));
-               *r = g_hash_table_new (rspamd_strcase_hash, rspamd_strcase_equal);
+               map = rspamd_mempool_alloc0 (cfg->cfg_pool, sizeof (*map));
+               map->data.hash = g_hash_table_new (rspamd_strcase_hash,
+                               rspamd_strcase_equal);
+               map->type = RSPAMD_LUA_MAP_SET;
 
                if (!rspamd_map_add (cfg, map_line, description,
                                rspamd_hosts_read,
                                rspamd_hosts_fin,
-                               (void **)r)) {
-                       msg_warn ("invalid hash map %s", map_line);
-                       g_hash_table_destroy (*r);
+                               (void **)&map->data.hash)) {
+                       msg_warn_config ("invalid set map %s", map_line);
+                       g_hash_table_destroy (map->data.hash);
                        lua_pushnil (L);
                        return 1;
                }
 
-               rspamd_mempool_add_destructor (cfg->cfg_pool,
-                       (rspamd_mempool_destruct_t)g_hash_table_destroy,
-                       *r);
-               ud = lua_newuserdata (L, sizeof (GHashTable **));
-               *ud = r;
-               rspamd_lua_setclass (L, "rspamd{hash_table}", -1);
-
-               return 1;
+               pmap = lua_newuserdata (L, sizeof (void *));
+               *pmap = map;
+               rspamd_lua_setclass (L, "rspamd{map}", -1);
+       }
+       else {
+               return luaL_error (L, "invalid arguments");
        }
 
-       lua_pushnil (L);
        return 1;
 
 }
@@ -876,37 +879,35 @@ lua_config_add_kv_map (lua_State *L)
 {
        struct rspamd_config *cfg = lua_check_config (L, 1);
        const gchar *map_line, *description;
-       GHashTable **r, ***ud;
+       struct rspamd_lua_map *map, **pmap;
 
        if (cfg) {
                map_line = luaL_checkstring (L, 2);
                description = lua_tostring (L, 3);
-               r = rspamd_mempool_alloc (cfg->cfg_pool, sizeof (GHashTable *));
-               *r = g_hash_table_new (rspamd_strcase_hash, rspamd_strcase_equal);
+               map = rspamd_mempool_alloc0 (cfg->cfg_pool, sizeof (*map));
+               map->data.hash = g_hash_table_new (rspamd_strcase_hash,
+                               rspamd_strcase_equal);
+               map->type = RSPAMD_LUA_MAP_HASH;
 
                if (!rspamd_map_add (cfg, map_line, description,
                                rspamd_kv_list_read,
                                rspamd_kv_list_fin,
-                               (void **)r)) {
+                               (void **)&map->data.hash)) {
                        msg_warn_config ("invalid hash map %s", map_line);
-                       g_hash_table_destroy (*r);
+                       g_hash_table_destroy (map->data.hash);
                        lua_pushnil (L);
                        return 1;
                }
 
-               rspamd_mempool_add_destructor (cfg->cfg_pool,
-                       (rspamd_mempool_destruct_t)g_hash_table_destroy,
-                       *r);
-               ud = lua_newuserdata (L, sizeof (GHashTable **));
-               *ud = r;
-               rspamd_lua_setclass (L, "rspamd{hash_table}", -1);
-
-               return 1;
+               pmap = lua_newuserdata (L, sizeof (void *));
+               *pmap = map;
+               rspamd_lua_setclass (L, "rspamd{map}", -1);
+       }
+       else {
+               return luaL_error (L, "invalid arguments");
        }
 
-       lua_pushnil (L);
        return 1;
-
 }
 
 static gint
@@ -928,7 +929,7 @@ lua_config_get_key (lua_State *L)
                }
        }
        else {
-               lua_pushnil (L);
+               return luaL_error (L, "invalid arguments");
        }
 
        return 1;
@@ -1786,6 +1787,7 @@ struct lua_map_callback_data {
        lua_State *L;
        gint ref;
        GString *data;
+       struct rspamd_lua_map *lua_map;
 };
 
 static gchar *
@@ -1799,6 +1801,7 @@ lua_map_read (rspamd_mempool_t *pool, gchar *chunk, gint len,
                old = (struct lua_map_callback_data *)data->prev_data;
                cbdata->L = old->L;
                cbdata->ref = old->ref;
+               cbdata->lua_map = old->lua_map;
                data->cur_data = cbdata;
        }
        else {
@@ -1819,6 +1822,7 @@ void
 lua_map_fin (rspamd_mempool_t * pool, struct map_cb_data *data)
 {
        struct lua_map_callback_data *cbdata, *old;
+       struct rspamd_lua_map **pmap;
 
        if (data->prev_data) {
                /* Cleanup old data */
@@ -1841,8 +1845,11 @@ lua_map_fin (rspamd_mempool_t * pool, struct map_cb_data *data)
        if (cbdata->data != NULL && cbdata->data->len != 0) {
                lua_rawgeti (cbdata->L, LUA_REGISTRYINDEX, cbdata->ref);
                lua_pushlstring (cbdata->L, cbdata->data->str, cbdata->data->len);
+               pmap = lua_newuserdata (cbdata->L, sizeof (void *));
+               *pmap = cbdata->lua_map;
+               rspamd_lua_setclass (cbdata->L, "rspamd{map}", -1);
 
-               if (lua_pcall (cbdata->L, 1, 0, 0) != 0) {
+               if (lua_pcall (cbdata->L, -1, 0, 0) != 0) {
                        msg_info_pool ("call to %s failed: %s", "local function",
                                lua_tostring (cbdata->L, -1));
                        lua_pop (cbdata->L, 1);
@@ -1856,6 +1863,7 @@ lua_config_add_map (lua_State *L)
        struct rspamd_config *cfg = lua_check_config (L, 1);
        const gchar *map_line, *description;
        struct lua_map_callback_data *cbdata, **pcbdata;
+       struct rspamd_lua_map *map;
        int cbidx;
 
        if (cfg) {
@@ -1877,8 +1885,13 @@ lua_config_add_map (lua_State *L)
                        lua_pushvalue (L, cbidx);
                        /* Get a reference */
                        cbdata->ref = luaL_ref (L, LUA_REGISTRYINDEX);
+                       map = rspamd_mempool_alloc (cfg->cfg_pool, sizeof (*map));
+                       map->type = RSPAMD_LUA_MAP_CALLBACK;
+                       map->data.cbref = cbdata->ref;
+                       cbdata->lua_map = map;
                        pcbdata = rspamd_mempool_alloc (cfg->cfg_pool, sizeof (cbdata));
                        *pcbdata = cbdata;
+
                        if (!rspamd_map_add (cfg, map_line, description, lua_map_read, lua_map_fin,
                                (void **)pcbdata)) {
                                msg_warn_config ("invalid hash map %s", map_line);
@@ -1894,7 +1907,7 @@ lua_config_add_map (lua_State *L)
                }
        }
        else {
-               lua_pushboolean (L, false);
+               return luaL_error (L, "invalid arguments");
        }
 
        return 1;
@@ -1902,95 +1915,92 @@ lua_config_add_map (lua_State *L)
 
 /* Radix and hash table functions */
 static gint
-lua_radix_get_key (lua_State * L)
+lua_map_get_key (lua_State * L)
 {
-       radix_compressed_t *radix = lua_check_radix (L);
+       struct rspamd_lua_map *map = lua_check_map (L);
+       radix_compressed_t *radix;
        struct rspamd_lua_ip *addr = NULL;
+       const gchar *key, *value = NULL;
        gpointer ud;
        guint32 key_num = 0;
        gboolean ret = FALSE;
 
-       if (radix) {
-               if (lua_type (L, 2) == LUA_TNUMBER) {
-                       key_num = luaL_checknumber (L, 2);
-                       key_num = htonl (key_num);
-               }
-               else if (lua_type (L, 2) == LUA_TUSERDATA) {
-                       ud = luaL_checkudata (L, 2, "rspamd{ip}");
-                       if (ud != NULL) {
-                               addr = *((struct rspamd_lua_ip **)ud);
-                               if (addr->addr == NULL) {
-                                       msg_err ("rspamd{ip} is not valid");
-                                       addr = NULL;
-                               }
+       if (map) {
+               if (map->type == RSPAMD_LUA_MAP_RADIX) {
+                       radix = map->data.radix;
+
+                       if (lua_type (L, 2) == LUA_TNUMBER) {
+                               key_num = luaL_checknumber (L, 2);
+                               key_num = htonl (key_num);
                        }
-                       else {
-                               msg_err ("invalid userdata type provided, rspamd{ip} expected");
+                       else if (lua_type (L, 2) == LUA_TUSERDATA) {
+                               ud = luaL_checkudata (L, 2, "rspamd{ip}");
+                               if (ud != NULL) {
+                                       addr = *((struct rspamd_lua_ip **)ud);
+                                       if (addr->addr == NULL) {
+                                               msg_err ("rspamd{ip} is not valid");
+                                               addr = NULL;
+                                       }
+                               }
+                               else {
+                                       msg_err ("invalid userdata type provided, rspamd{ip} expected");
+                               }
                        }
-               }
 
-               if (addr != NULL) {
-                       if (radix_find_compressed_addr (radix, addr->addr)
-                                       !=  RADIX_NO_VALUE) {
-                               ret = TRUE;
+                       if (addr != NULL) {
+                               if (radix_find_compressed_addr (radix, addr->addr)
+                                               !=  RADIX_NO_VALUE) {
+                                       ret = TRUE;
+                               }
                        }
-               }
-               else if (key_num != 0) {
-                       if (radix_find_compressed (radix, (guint8 *)&key_num, sizeof (key_num))
-                               != RADIX_NO_VALUE) {
-                               ret = TRUE;
+                       else if (key_num != 0) {
+                               if (radix_find_compressed (radix, (guint8 *)&key_num, sizeof (key_num))
+                                               != RADIX_NO_VALUE) {
+                                       ret = TRUE;
+                               }
                        }
                }
-       }
+               else if (map->type == RSPAMD_LUA_MAP_SET) {
+                       key = lua_tostring (L, 2);
 
-       lua_pushboolean (L, ret);
-       return 1;
-}
-
-static gint
-lua_hash_table_get_key (lua_State * L)
-{
-       GHashTable *tbl = lua_check_hash_table (L);
-       const gchar *key, *value;
+                       if (key) {
+                               ret = g_hash_table_lookup (map->data.hash, key) != NULL;
+                       }
+               }
+               else {
+                       /* key-value map */
+                       key = lua_tostring (L, 2);
 
-       if (tbl) {
-               key = luaL_checkstring (L, 2);
+                       if (key) {
+                               value = g_hash_table_lookup (map->data.hash, key);
 
-               if ((value = g_hash_table_lookup (tbl, key)) != NULL) {
-                       lua_pushstring (L, value);
-                       return 1;
+                               if (value) {
+                                       lua_pushstring (L, value);
+                                       return 1;
+                               }
+                       }
                }
        }
+       else {
+               return luaL_error (L, "invalid arguments");
+       }
 
-       lua_pushnil (L);
+       lua_pushboolean (L, ret);
        return 1;
 }
 
-/* Trie functions */
-
-/* Init functions */
-
 void
 luaopen_config (lua_State * L)
 {
        rspamd_lua_new_class (L, "rspamd{config}", configlib_m);
 
-       lua_pop (L, 1);                      /* remove metatable from stack */
-}
-
-void
-luaopen_radix (lua_State * L)
-{
-       rspamd_lua_new_class (L, "rspamd{radix}", radixlib_m);
-
-       lua_pop (L, 1);                      /* remove metatable from stack */
+       lua_pop (L, 1);
 }
 
 void
-luaopen_hash_table (lua_State * L)
+luaopen_map (lua_State * L)
 {
-       rspamd_lua_new_class (L, "rspamd{hash_table}", hashlib_m);
-       luaL_register (L, "rspamd_hash_table", null_reg);
+       rspamd_lua_new_class (L, "rspamd{map}", maplib_m);
 
-       lua_pop (L, 1);                      /* remove metatable from stack */
+       lua_pop (L, 1);
 }