]> source.dussan.org Git - rspamd.git/commitdiff
Rework adding symbols from lua.
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Sun, 17 Aug 2014 17:14:31 +0000 (18:14 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Sun, 17 Aug 2014 17:14:31 +0000 (18:14 +0100)
It is now possible to use something like:

rspamd_config.SYMBOL = function(task) ... end

or even

rspamd_config.SYMBOL =  {
     callback = function(task) ... end,
weight = '1.0', --optional
priority = '0', --optional
type = 'callback' --optional
}

src/libserver/symbols_cache.c
src/libserver/symbols_cache.h
src/lua/lua_config.c

index 185b1e1058182d0b2f9b58baa208da7cb7c81664..823af39f53e7057a0860ce00cc69dc1ce460dd3b 100644 (file)
@@ -281,13 +281,7 @@ create_cache_file (struct symbols_cache *cache,
        return mmap_cache_file (cache, fd, pool);
 }
 
-enum rspamd_symbol_type {
-       SYMBOL_TYPE_NORMAL,
-       SYMBOL_TYPE_VIRTUAL,
-       SYMBOL_TYPE_CALLBACK
-};
-
-static void
+void
 register_symbol_common (struct symbols_cache **cache,
        const gchar *name,
        double weight,
index 36128460d53e98a1fc5439005151f340e54b7882..fd9da2ef73b864c8cbd82793da7f1431c48eb924 100644 (file)
@@ -49,6 +49,11 @@ struct cache_item {
        gdouble metric_weight;
 };
 
+enum rspamd_symbol_type {
+       SYMBOL_TYPE_NORMAL,
+       SYMBOL_TYPE_VIRTUAL,
+       SYMBOL_TYPE_CALLBACK
+};
 
 struct symbols_cache {
        /* Normal cache items */
@@ -146,6 +151,25 @@ void register_dynamic_symbol (rspamd_mempool_t *pool,
        gpointer user_data,
        GList *networks);
 
+/**
+ * Generic function to register a symbol
+ * @param cache
+ * @param name
+ * @param weight
+ * @param priority
+ * @param func
+ * @param user_data
+ * @param type
+ */
+void
+register_symbol_common (struct symbols_cache **cache,
+       const gchar *name,
+       double weight,
+       gint priority,
+       symbol_func_t func,
+       gpointer user_data,
+       enum rspamd_symbol_type type);
+
 /**
  * Call function for cached symbol using saved callback
  * @param task task object
index d8ef67f76e1db9ea902a37cec9b6cbb22cbe0f5d..c5d1a7039ae2512193b04d43b1fadef89327ad82 100644 (file)
@@ -51,6 +51,7 @@ LUA_FUNCTION_DEF (config, register_post_filter);
 LUA_FUNCTION_DEF (config, register_module_option);
 LUA_FUNCTION_DEF (config, get_api_version);
 LUA_FUNCTION_DEF (config, get_key);
+LUA_FUNCTION_DEF (config, newindex);
 
 static const struct luaL_reg configlib_m[] = {
        LUA_INTERFACE_DEF (config, get_module_opt),
@@ -73,6 +74,7 @@ static const struct luaL_reg configlib_m[] = {
        LUA_INTERFACE_DEF (config, get_api_version),
        LUA_INTERFACE_DEF (config, get_key),
        {"__tostring", rspamd_lua_class_tostring},
+       {"__newindex", lua_config_newindex},
        {NULL, NULL}
 };
 
@@ -583,7 +585,7 @@ static gint
 lua_config_get_key (lua_State *L)
 {
        struct rspamd_config *cfg = lua_check_config (L);
-       const char *name;
+       const gchar *name;
        size_t namelen;
        const ucl_object_t *val;
 
@@ -604,115 +606,6 @@ lua_config_get_key (lua_State *L)
        return 1;
 }
 
-struct lua_map_callback_data {
-       lua_State *L;
-       gint ref;
-       GString *data;
-};
-
-static gchar *
-lua_map_read (rspamd_mempool_t *pool, gchar *chunk, gint len,
-       struct map_cb_data *data)
-{
-       struct lua_map_callback_data *cbdata, *old;
-
-       if (data->cur_data == NULL) {
-               cbdata = g_slice_alloc (sizeof (*cbdata));
-               old = (struct lua_map_callback_data *)data->prev_data;
-               cbdata->L = old->L;
-               cbdata->ref = old->ref;
-       }
-       else {
-               cbdata = (struct lua_map_callback_data *)data->cur_data;
-       }
-
-       if (cbdata->data == NULL) {
-               cbdata->data = g_string_new_len (chunk, len);
-       }
-       else {
-               g_string_append_len (cbdata->data, chunk, len);
-       }
-
-       return NULL;
-}
-
-void
-lua_map_fin (rspamd_mempool_t * pool, struct map_cb_data *data)
-{
-       struct lua_map_callback_data *cbdata, *old;
-
-       if (data->prev_data) {
-               /* Cleanup old data */
-               old = (struct lua_map_callback_data *)data->prev_data;
-               if (old->data) {
-                       g_string_free (old->data, TRUE);
-               }
-               g_slice_free1 (sizeof (*old), old);
-       }
-
-       if (data->cur_data) {
-               cbdata = (struct lua_map_callback_data *)data->cur_data;
-       }
-       else {
-               msg_err ("no data read for map");
-               return;
-       }
-
-       if (cbdata->data != NULL && cbdata->data->len != 0) {
-               lua_rawgeti (cbdata->L, LUA_REGISTRYINDEX, cbdata->ref);
-               lua_pushlstring (cbdata->L, cbdata->data->str, cbdata->data->len);
-
-               if (lua_pcall (cbdata->L, 1, 0, 0) != 0) {
-                       msg_info ("call to %s failed: %s", "local function",
-                               lua_tostring (cbdata->L, -1));
-               }
-       }
-}
-
-static gint
-lua_config_add_map (lua_State *L)
-{
-       struct rspamd_config *cfg = lua_check_config (L);
-       const gchar *map_line, *description;
-       struct lua_map_callback_data *cbdata, **pcbdata;
-
-       if (cfg) {
-               map_line = luaL_checkstring (L, 2);
-               description = lua_tostring (L, 3);
-
-               if (lua_type (L, 4) == LUA_TFUNCTION) {
-                       cbdata = g_slice_alloc (sizeof (*cbdata));
-                       cbdata->L = L;
-                       cbdata->data = NULL;
-                       lua_pushvalue (L, 4);
-                       /* Get a reference */
-                       cbdata->ref = luaL_ref (L, LUA_REGISTRYINDEX);
-                       pcbdata = rspamd_mempool_alloc (cfg->cfg_pool, sizeof (cbdata));
-                       *pcbdata = cbdata;
-                       if (!add_map (cfg, map_line, description, lua_map_read, lua_map_fin,
-                               (void **)pcbdata)) {
-                               msg_warn ("invalid hash map %s", map_line);
-                               lua_pushboolean (L, false);
-                       }
-                       else {
-                               lua_pushboolean (L, true);
-                       }
-               }
-               else {
-                       msg_warn ("invalid callback argument for map %s", map_line);
-                       lua_pushboolean (L, false);
-               }
-       }
-       else {
-               lua_pushboolean (L, false);
-       }
-
-       return 1;
-}
-
-/*** Metric functions ***/
-
-
 static void
 lua_metric_symbol_callback (struct rspamd_task *task, gpointer ud)
 {
@@ -735,44 +628,62 @@ lua_metric_symbol_callback (struct rspamd_task *task, gpointer ud)
        }
 }
 
+static void
+rspamd_register_symbol_fromlua (lua_State *L,
+               struct rspamd_config *cfg,
+               const gchar *name,
+               gint ref,
+               gdouble weight,
+               gint priority,
+               enum rspamd_symbol_type type)
+{
+       struct lua_callback_data *cd;
+
+       cd = rspamd_mempool_alloc (cfg->cfg_pool,
+               sizeof (struct lua_callback_data));
+       cd->cb_is_ref = TRUE;
+       cd->callback.ref = ref;
+
+       register_symbol_common (&cfg->cache,
+                                       name,
+                                       weight,
+                                       priority,
+                                       lua_metric_symbol_callback,
+                                       cd,
+                                       type);
+       rspamd_mempool_add_destructor (cfg->cfg_pool,
+               (rspamd_mempool_destruct_t)lua_destroy_cfg_symbol,
+               cd);
+}
+
 static gint
 lua_config_register_symbol (lua_State * L)
 {
        struct rspamd_config *cfg = lua_check_config (L);
        gchar *name;
        double weight;
-       struct lua_callback_data *cd;
 
        if (cfg) {
                name = rspamd_mempool_strdup (cfg->cfg_pool, luaL_checkstring (L, 2));
                weight = luaL_checknumber (L, 3);
-               cd =
-                       rspamd_mempool_alloc (cfg->cfg_pool,
-                               sizeof (struct lua_callback_data));
+
                if (lua_type (L, 4) == LUA_TSTRING) {
-                       cd->callback.name = rspamd_mempool_strdup (cfg->cfg_pool,
-                                       luaL_checkstring (L, 4));
-                       cd->cb_is_ref = FALSE;
+                       lua_getglobal (L, luaL_checkstring (L, 4));
                }
                else {
                        lua_pushvalue (L, 4);
-                       /* Get a reference */
-                       cd->callback.ref = luaL_ref (L, LUA_REGISTRYINDEX);
-                       cd->cb_is_ref = TRUE;
                }
                if (name) {
-                       cd->symbol = name;
-                       cd->L = L;
-                       register_symbol (&cfg->cache,
-                               name,
-                               weight,
-                               lua_metric_symbol_callback,
-                               cd);
+                       rspamd_register_symbol_fromlua (L,
+                                       cfg,
+                                       name,
+                                       luaL_ref (L, LUA_REGISTRYINDEX),
+                                       weight,
+                                       0,
+                                       SYMBOL_TYPE_NORMAL);
                }
-               rspamd_mempool_add_destructor (cfg->cfg_pool,
-                       (rspamd_mempool_destruct_t)lua_destroy_cfg_symbol,
-                       cd);
        }
+
        return 0;
 }
 
@@ -780,8 +691,7 @@ static gint
 lua_config_register_symbols (lua_State *L)
 {
        struct rspamd_config *cfg = lua_check_config (L);
-       struct lua_callback_data *cd;
-       gint i, top;
+       gint i, top, idx;
        gchar *sym;
        gdouble weight = 1.0;
 
@@ -790,20 +700,14 @@ lua_config_register_symbols (lua_State *L)
                return 0;
        }
        if (cfg) {
-               cd =
-                       rspamd_mempool_alloc (cfg->cfg_pool,
-                               sizeof (struct lua_callback_data));
                if (lua_type (L, 2) == LUA_TSTRING) {
-                       cd->callback.name = rspamd_mempool_strdup (cfg->cfg_pool,
-                                       luaL_checkstring (L, 2));
-                       cd->cb_is_ref = FALSE;
+                       lua_getglobal (L, luaL_checkstring (L, 2));
                }
                else {
                        lua_pushvalue (L, 2);
-                       /* Get a reference */
-                       cd->callback.ref = luaL_ref (L, LUA_REGISTRYINDEX);
-                       cd->cb_is_ref = TRUE;
                }
+               idx = luaL_ref (L, LUA_REGISTRYINDEX);
+
                if (lua_type (L, 3) == LUA_TNUMBER) {
                        weight = lua_tonumber (L, 3);
                        top = 4;
@@ -812,13 +716,13 @@ lua_config_register_symbols (lua_State *L)
                        top = 3;
                }
                sym = rspamd_mempool_strdup (cfg->cfg_pool, luaL_checkstring (L, top));
-               cd->symbol = sym;
-               cd->L = L;
-               register_symbol (&cfg->cache,
-                       sym,
-                       weight,
-                       lua_metric_symbol_callback,
-                       cd);
+               rspamd_register_symbol_fromlua (L,
+                               cfg,
+                               sym,
+                               idx,
+                               weight,
+                               0,
+                               SYMBOL_TYPE_NORMAL);
                for (i = top; i < lua_gettop (L); i++) {
                        sym =
                                rspamd_mempool_strdup (cfg->cfg_pool, luaL_checkstring (L,
@@ -853,38 +757,28 @@ lua_config_register_callback_symbol (lua_State * L)
        struct rspamd_config *cfg = lua_check_config (L);
        gchar *name;
        double weight;
-       struct lua_callback_data *cd;
 
        if (cfg) {
                name = rspamd_mempool_strdup (cfg->cfg_pool, luaL_checkstring (L, 2));
                weight = luaL_checknumber (L, 3);
-               cd =
-                       rspamd_mempool_alloc (cfg->cfg_pool,
-                               sizeof (struct lua_callback_data));
+
                if (lua_type (L, 4) == LUA_TSTRING) {
-                       cd->callback.name = rspamd_mempool_strdup (cfg->cfg_pool,
-                                       luaL_checkstring (L, 4));
-                       cd->cb_is_ref = FALSE;
+                       lua_getglobal (L, luaL_checkstring (L, 4));
                }
                else {
                        lua_pushvalue (L, 4);
-                       /* Get a reference */
-                       cd->callback.ref = luaL_ref (L, LUA_REGISTRYINDEX);
-                       cd->cb_is_ref = TRUE;
                }
                if (name) {
-                       cd->symbol = name;
-                       cd->L = L;
-                       register_callback_symbol (&cfg->cache,
-                               name,
-                               weight,
-                               lua_metric_symbol_callback,
-                               cd);
+                       rspamd_register_symbol_fromlua (L,
+                                       cfg,
+                                       name,
+                                       luaL_ref (L, LUA_REGISTRYINDEX),
+                                       weight,
+                                       0,
+                                       SYMBOL_TYPE_CALLBACK);
                }
-               rspamd_mempool_add_destructor (cfg->cfg_pool,
-                       (rspamd_mempool_destruct_t)lua_destroy_cfg_symbol,
-                       cd);
        }
+
        return 0;
 }
 
@@ -895,45 +789,231 @@ lua_config_register_callback_symbol_priority (lua_State * L)
        gchar *name;
        double weight;
        gint priority;
-       struct lua_callback_data *cd;
 
        if (cfg) {
                name = rspamd_mempool_strdup (cfg->cfg_pool, luaL_checkstring (L, 2));
                weight = luaL_checknumber (L, 3);
                priority = luaL_checknumber (L, 4);
-               cd =
-                       rspamd_mempool_alloc (cfg->cfg_pool,
-                               sizeof (struct lua_callback_data));
+
                if (lua_type (L, 5) == LUA_TSTRING) {
-                       cd->callback.name = rspamd_mempool_strdup (cfg->cfg_pool,
-                                       luaL_checkstring (L, 5));
-                       cd->cb_is_ref = FALSE;
+                       lua_getglobal (L, luaL_checkstring (L, 5));
                }
                else {
                        lua_pushvalue (L, 5);
-                       /* Get a reference */
-                       cd->callback.ref = luaL_ref (L, LUA_REGISTRYINDEX);
-                       cd->cb_is_ref = TRUE;
                }
-
                if (name) {
-                       cd->L = L;
-                       cd->symbol = name;
-                       register_callback_symbol_priority (&cfg->cache,
-                               name,
-                               weight,
-                               priority,
-                               lua_metric_symbol_callback,
-                               cd);
+                       rspamd_register_symbol_fromlua (L,
+                                       cfg,
+                                       name,
+                                       luaL_ref (L, LUA_REGISTRYINDEX),
+                                       weight,
+                                       priority,
+                                       SYMBOL_TYPE_CALLBACK);
                }
-               rspamd_mempool_add_destructor (cfg->cfg_pool,
-                       (rspamd_mempool_destruct_t)lua_destroy_cfg_symbol,
-                       cd);
+       }
+
+       return 0;
+}
+
+
+static gint
+lua_config_newindex (lua_State *L)
+{
+       struct rspamd_config *cfg = lua_check_config (L);
+       const gchar *name;
+
+       name = luaL_checkstring (L, 2);
+
+       if (name != NULL && lua_gettop (L) > 2) {
+               if (lua_type (L, 3) == LUA_TFUNCTION) {
+                       /* Normal symbol from just a function */
+                       lua_pushvalue (L, 3);
+                       rspamd_register_symbol_fromlua (L,
+                                       cfg,
+                                       name,
+                                       luaL_ref (L, LUA_REGISTRYINDEX),
+                                       1.0,
+                                       0,
+                                       SYMBOL_TYPE_NORMAL);
+               }
+               else if (lua_type (L, 3) == LUA_TTABLE) {
+                       gint type = SYMBOL_TYPE_NORMAL, priority = 0, idx;
+                       gdouble weight = 1.0;
+                       const char *type_str;
+
+                       /*
+                        * Table can have the following attributes:
+                        * "callback" - should be a callback function
+                        * "weight" - optional weight
+                        * "priority" - optional priority
+                        * "type" - optional type (normal, virtual, callback)
+                        */
+                       lua_pushstring (L, "callback");
+                       lua_gettable (L, -2);
+
+                       if (lua_type (L, -1) != LUA_TFUNCTION) {
+                               lua_pop (L, 1);
+                               msg_info ("cannot find callback definition for %s", name);
+                               return 0;
+                       }
+                       idx = luaL_ref (L, LUA_REGISTRYINDEX);
+
+                       /* Optional fields */
+                       lua_pushstring (L, "weight");
+                       lua_gettable (L, -2);
 
+                       if (lua_type (L, -1) == LUA_TNUMBER) {
+                               weight = lua_tonumber (L, -1);
+                       }
+                       lua_pop (L, 1);
+
+                       lua_pushstring (L, "priority");
+                       lua_gettable (L, -2);
+
+                       if (lua_type (L, -1) == LUA_TNUMBER) {
+                               priority = lua_tonumber (L, -1);
+                       }
+                       lua_pop (L, 1);
+
+                       lua_pushstring (L, "type");
+                       lua_gettable (L, -2);
+
+                       if (lua_type (L, -1) == LUA_TSTRING) {
+                               type_str = lua_tostring (L, -1);
+                               if (strcmp (type_str, "normal") == 0) {
+                                       type = SYMBOL_TYPE_NORMAL;
+                               }
+                               else if (strcmp (type_str, "virtual") == 0) {
+                                       type = SYMBOL_TYPE_VIRTUAL;
+                               }
+                               else if (strcmp (type_str, "callback") == 0) {
+                                       type = SYMBOL_TYPE_CALLBACK;
+                               }
+                               else {
+                                       msg_info ("unknown type: %s", type_str);
+                               }
+
+                       }
+                       lua_pop (L, 1);
+
+                       rspamd_register_symbol_fromlua (L,
+                                       cfg,
+                                       name,
+                                       idx,
+                                       weight,
+                                       priority,
+                                       type);
+               }
        }
+
        return 0;
 }
 
+struct lua_map_callback_data {
+       lua_State *L;
+       gint ref;
+       GString *data;
+};
+
+static gchar *
+lua_map_read (rspamd_mempool_t *pool, gchar *chunk, gint len,
+       struct map_cb_data *data)
+{
+       struct lua_map_callback_data *cbdata, *old;
+
+       if (data->cur_data == NULL) {
+               cbdata = g_slice_alloc (sizeof (*cbdata));
+               old = (struct lua_map_callback_data *)data->prev_data;
+               cbdata->L = old->L;
+               cbdata->ref = old->ref;
+       }
+       else {
+               cbdata = (struct lua_map_callback_data *)data->cur_data;
+       }
+
+       if (cbdata->data == NULL) {
+               cbdata->data = g_string_new_len (chunk, len);
+       }
+       else {
+               g_string_append_len (cbdata->data, chunk, len);
+       }
+
+       return NULL;
+}
+
+void
+lua_map_fin (rspamd_mempool_t * pool, struct map_cb_data *data)
+{
+       struct lua_map_callback_data *cbdata, *old;
+
+       if (data->prev_data) {
+               /* Cleanup old data */
+               old = (struct lua_map_callback_data *)data->prev_data;
+               if (old->data) {
+                       g_string_free (old->data, TRUE);
+               }
+               g_slice_free1 (sizeof (*old), old);
+       }
+
+       if (data->cur_data) {
+               cbdata = (struct lua_map_callback_data *)data->cur_data;
+       }
+       else {
+               msg_err ("no data read for map");
+               return;
+       }
+
+       if (cbdata->data != NULL && cbdata->data->len != 0) {
+               lua_rawgeti (cbdata->L, LUA_REGISTRYINDEX, cbdata->ref);
+               lua_pushlstring (cbdata->L, cbdata->data->str, cbdata->data->len);
+
+               if (lua_pcall (cbdata->L, 1, 0, 0) != 0) {
+                       msg_info ("call to %s failed: %s", "local function",
+                               lua_tostring (cbdata->L, -1));
+               }
+       }
+}
+
+static gint
+lua_config_add_map (lua_State *L)
+{
+       struct rspamd_config *cfg = lua_check_config (L);
+       const gchar *map_line, *description;
+       struct lua_map_callback_data *cbdata, **pcbdata;
+
+       if (cfg) {
+               map_line = luaL_checkstring (L, 2);
+               description = lua_tostring (L, 3);
+
+               if (lua_type (L, 4) == LUA_TFUNCTION) {
+                       cbdata = g_slice_alloc (sizeof (*cbdata));
+                       cbdata->L = L;
+                       cbdata->data = NULL;
+                       lua_pushvalue (L, 4);
+                       /* Get a reference */
+                       cbdata->ref = luaL_ref (L, LUA_REGISTRYINDEX);
+                       pcbdata = rspamd_mempool_alloc (cfg->cfg_pool, sizeof (cbdata));
+                       *pcbdata = cbdata;
+                       if (!add_map (cfg, map_line, description, lua_map_read, lua_map_fin,
+                               (void **)pcbdata)) {
+                               msg_warn ("invalid hash map %s", map_line);
+                               lua_pushboolean (L, false);
+                       }
+                       else {
+                               lua_pushboolean (L, true);
+                       }
+               }
+               else {
+                       msg_warn ("invalid callback argument for map %s", map_line);
+                       lua_pushboolean (L, false);
+               }
+       }
+       else {
+               lua_pushboolean (L, false);
+       }
+
+       return 1;
+}
 
 /* Radix and hash table functions */
 static gint