]> source.dussan.org Git - rspamd.git/commitdiff
[Fix] Core: Fix squeezed dependencies handling for virtual symbols
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 22 Mar 2019 20:06:36 +0000 (20:06 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 22 Mar 2019 20:06:36 +0000 (20:06 +0000)
lualib/lua_squeeze_rules.lua
src/libserver/rspamd_symcache.c
src/libserver/rspamd_symcache.h
src/lua/lua_config.c

index 0062504bc8a24125b5c17304d361ced43f3bc51d..abe4abdbd6cae4abe5fc5b4c9fcbf2431139cebb 100644 (file)
@@ -27,6 +27,8 @@ local SN = 'lua_squeeze'
 local squeeze_sym = 'LUA_SQUEEZE'
 local squeeze_function_ids = {}
 local squeezed_groups = {}
+local last_rule
+local virtual_symbols = {}
 
 local function gen_lua_squeeze_function(order)
   return function(task)
@@ -142,9 +144,22 @@ exports.squeeze_rule = function(s, func, flags)
     }
   end
 
+  last_rule = s
+
   return squeeze_function_ids[1]
 end
 
+-- TODO: poor approach, we register all virtual symbols with the previous real squeezed symbol
+exports.squeeze_virtual = function(id, name)
+  if squeeze_function_ids[1] and id == squeeze_function_ids[1] and last_rule then
+    virtual_symbols[name] = last_rule
+
+    return id
+  end
+
+  return -1
+end
+
 exports.squeeze_dependency = function(child, parent)
   lua_util.debugm(SN, rspamd_config, 'squeeze dep %s->%s', child, parent)
 
@@ -222,6 +237,13 @@ exports.squeeze_init = function()
   end
 
   for parent,children in pairs(squeezed_deps) do
+    if not squeezed_symbols[parent] then
+      local real_parent = virtual_symbols[parent]
+      if real_parent then
+        parent = real_parent
+      end
+    end
+
     if not squeezed_symbols[parent] then
       -- Trivial case, external dependnency
 
index 64e451f7a0882265a31ba3072b3d4c2ffd860ae0..c36a7e1d7ce3df2a3a77b5771406eb3990d9698a 100644 (file)
@@ -304,6 +304,33 @@ rspamd_symcache_find_filter (struct rspamd_symcache *cache,
        return NULL;
 }
 
+const gchar *
+rspamd_symcache_get_parent (struct rspamd_symcache *cache,
+                                                                                const gchar *symbol)
+{
+       struct rspamd_symcache_item *item;
+
+       g_assert (cache != NULL);
+
+       if (symbol == NULL) {
+               return NULL;
+       }
+
+       item = g_hash_table_lookup (cache->items_by_symbol, symbol);
+
+       if (item != NULL) {
+
+               if (item->is_virtual) {
+                       item = g_ptr_array_index (cache->items_by_id,
+                                       item->specific.virtual.parent);
+               }
+
+               return item->symbol;
+       }
+
+       return NULL;
+}
+
 static gint
 postfilters_cmp (const void *p1, const void *p2, gpointer ud)
 {
index a038d6a9d016199b0e23ef379f330d078229aa71..0228e4f5a9ce7541333e20dcc85e2796b5f02bce 100644 (file)
@@ -250,6 +250,15 @@ void rspamd_symcache_enable_symbol_perm (struct rspamd_symcache *cache,
 struct rspamd_abstract_callback_data* rspamd_symcache_get_cbdata (
                struct rspamd_symcache *cache, const gchar *symbol);
 
+/**
+ * Returns symbol's parent name (or symbol name itself)
+ * @param cache
+ * @param symbol
+ * @return
+ */
+const gchar *rspamd_symcache_get_parent (struct rspamd_symcache *cache,
+               const gchar *symbol);
+
 /**
  * Adds flags to a symbol
  * @param cache
index 91c648e6fd1506004a733cdacb5a169678cb1df4..9fea7eb5170e9d3d64c06783b988618147112f15 100644 (file)
@@ -437,6 +437,13 @@ LUA_FUNCTION_DEF (config, enable_symbol);
  */
 LUA_FUNCTION_DEF (config, disable_symbol);
 
+/***
+ * @method rspamd_config:get_symbol_parent(symbol)
+ * Returns a parent symbol for specific symbol (or symbol itself if top level)
+ * @param {string} symbol symbol's name
+ */
+LUA_FUNCTION_DEF (config, get_symbol_parent);
+
 /***
  * @method rspamd_config:__newindex(name, callback)
  * This metamethod is called if new indicies are added to the `rspamd_config` object.
@@ -813,6 +820,7 @@ static const struct luaL_reg configlib_m[] = {
        LUA_INTERFACE_DEF (config, get_symbol_callback),
        LUA_INTERFACE_DEF (config, set_symbol_callback),
        LUA_INTERFACE_DEF (config, get_symbol_stat),
+       LUA_INTERFACE_DEF (config, get_symbol_parent),
        LUA_INTERFACE_DEF (config, register_finish_script),
        LUA_INTERFACE_DEF (config, register_monitored),
        LUA_INTERFACE_DEF (config, add_doc),
@@ -1363,50 +1371,48 @@ lua_metric_symbol_callback_return (struct thread_entry *thread_entry, int ret)
 }
 
 static gint
-rspamd_register_symbol_fromlua (lua_State *L,
-               struct rspamd_config *cfg,
-               const gchar *name,
-               gint ref,
-               gdouble weight,
-               gint priority,
-               enum rspamd_symbol_type type,
-               gint parent,
-               gboolean optional,
-               gboolean no_squeeze)
+rspamd_lua_squeeze_rule (lua_State *L,
+                                                struct rspamd_config *cfg,
+                                                const gchar *name,
+                                                gint cbref,
+                                                enum rspamd_symbol_type type,
+                                                gint parent)
 {
-       struct lua_callback_data *cd;
        gint ret = -1, err_idx;
 
-       if (priority == 0 && weight < 0) {
-               priority = 1;
-       }
+       lua_pushcfunction (L, &rspamd_lua_traceback);
+       err_idx = lua_gettop (L);
 
-       if ((ret = rspamd_symcache_find_symbol (cfg->cache, name)) != -1) {
-               if (optional) {
-                       msg_debug_config ("duplicate symbol: %s, skip registering", name);
+       if (type & SYMBOL_TYPE_VIRTUAL) {
+               if (rspamd_lua_require_function (L, "lua_squeeze_rules", "squeeze_virtual")) {
+                       lua_pushnumber (L, parent);
+                       if (name) {
+                               lua_pushstring (L, name);
+                       }
+                       else {
+                               lua_pushnil (L);
+                       }
 
-                       return ret;
+                       /* Now call for squeeze function */
+                       if (lua_pcall (L, 2, 1, err_idx) != 0) {
+                               GString *tb = lua_touserdata (L, -1);
+                               msg_err_config ("call to squeeze_virtual failed: %v", tb);
+
+                               if (tb) {
+                                       g_string_free (tb, TRUE);
+                               }
+                       }
+
+                       ret = lua_tonumber (L, -1);
                }
                else {
-                       msg_err_config ("duplicate symbol: %s, skip registering", name);
-
-                       return -1;
+                       msg_err_config ("lua_squeeze is absent or bad (missing squeeze_virtual),"
+                                                       " your Rspamd installation"
+                                                       " is likely corrupted!");
                }
        }
-
-       if (ref != -1) {
-               if (type & SYMBOL_TYPE_USE_CORO) {
-                       /* Coroutines are incompatible with squeezing */
-                       no_squeeze = TRUE;
-               }
-               /*
-                * We call for routine called lua_squeeze_rules.squeeze_rule if it exists
-                */
-               lua_pushcfunction (L, &rspamd_lua_traceback);
-               err_idx = lua_gettop (L);
-
-               if (!no_squeeze && (type & (SYMBOL_TYPE_CALLBACK|SYMBOL_TYPE_NORMAL)) &&
-                               rspamd_lua_require_function (L, "lua_squeeze_rules", "squeeze_rule")) {
+       else {
+               if (rspamd_lua_require_function (L, "lua_squeeze_rules", "squeeze_rule")) {
                        if (name) {
                                lua_pushstring (L, name);
                        }
@@ -1415,7 +1421,7 @@ rspamd_register_symbol_fromlua (lua_State *L,
                        }
 
                        /* Push function reference */
-                       lua_rawgeti (L, LUA_REGISTRYINDEX, ref);
+                       lua_rawgeti (L, LUA_REGISTRYINDEX, cbref);
 
                        /* Flags */
                        lua_createtable (L, 0, 0);
@@ -1452,45 +1458,62 @@ rspamd_register_symbol_fromlua (lua_State *L,
                        }
 
                        ret = lua_tonumber (L, -1);
+               }
+               else {
+                       msg_err_config ("lua_squeeze is absent or bad (missing squeeze_rule),"
+                                                       " your Rspamd installation"
+                                                       " is likely corrupted!");
+               }
+       }
 
-                       if (ret == -1) {
-                               /* Do direct registration */
-                               cd = rspamd_mempool_alloc0 (cfg->cfg_pool,
-                                               sizeof (struct lua_callback_data));
-                               cd->magic = rspamd_lua_callback_magic;
-                               cd->cb_is_ref = TRUE;
-                               cd->callback.ref = ref;
-                               cd->L = L;
+       /* Cleanup lua stack */
+       lua_settop (L, err_idx - 1);
 
-                               if (name) {
-                                       cd->symbol = rspamd_mempool_strdup (cfg->cfg_pool, name);
-                               }
+       return ret;
+}
 
-                               if (type & SYMBOL_TYPE_USE_CORO) {
-                                       ret = rspamd_symcache_add_symbol (cfg->cache,
-                                                       name,
-                                                       priority,
-                                                       lua_metric_symbol_callback_coro,
-                                                       cd,
-                                                       type,
-                                                       parent);
-                               }
-                               else {
-                                       ret = rspamd_symcache_add_symbol (cfg->cache,
-                                                       name,
-                                                       priority,
-                                                       lua_metric_symbol_callback,
-                                                       cd,
-                                                       type,
-                                                       parent);
-                               }
+static gint
+rspamd_register_symbol_fromlua (lua_State *L,
+               struct rspamd_config *cfg,
+               const gchar *name,
+               gint ref,
+               gdouble weight,
+               gint priority,
+               enum rspamd_symbol_type type,
+               gint parent,
+               gboolean optional,
+               gboolean no_squeeze)
+{
+       struct lua_callback_data *cd;
+       gint ret = -1;
 
-                               rspamd_mempool_add_destructor (cfg->cfg_pool,
-                                               (rspamd_mempool_destruct_t) lua_destroy_cfg_symbol,
-                                               cd);
-                       }
+       if (priority == 0 && weight < 0) {
+               priority = 1;
+       }
+
+       if ((ret = rspamd_symcache_find_symbol (cfg->cache, name)) != -1) {
+               if (optional) {
+                       msg_debug_config ("duplicate symbol: %s, skip registering", name);
+
+                       return ret;
                }
                else {
+                       msg_err_config ("duplicate symbol: %s, skip registering", name);
+
+                       return -1;
+               }
+       }
+
+       if (ref != -1) {
+               if (type & SYMBOL_TYPE_USE_CORO) {
+                       /* Coroutines are incompatible with squeezing */
+                       no_squeeze = TRUE;
+               }
+               /*
+                * We call for routine called lua_squeeze_rules.squeeze_rule if it exists
+                */
+               if (no_squeeze || (ret = rspamd_lua_squeeze_rule (L, cfg, name, ref,
+                               type, parent)) == -1) {
                        cd = rspamd_mempool_alloc0 (cfg->cfg_pool,
                                        sizeof (struct lua_callback_data));
                        cd->magic = rspamd_lua_callback_magic;
@@ -1524,11 +1547,13 @@ rspamd_register_symbol_fromlua (lua_State *L,
                                        (rspamd_mempool_destruct_t)lua_destroy_cfg_symbol,
                                        cd);
                }
-
-               /* Cleanup lua stack */
-               lua_settop (L, err_idx - 1);
        }
        else {
+               if (!no_squeeze) {
+                       rspamd_lua_squeeze_rule (L, cfg, name, ref,
+                                       type, parent);
+               }
+               /* Not a squeezed symbol */
                ret = rspamd_symcache_add_symbol (cfg->cache,
                                name,
                                priority,
@@ -2197,6 +2222,9 @@ rspamd_lua_squeeze_dependency (lua_State *L, struct rspamd_config *cfg,
                        ret = lua_toboolean (L, -1);
                }
        }
+       else {
+               msg_err_config ("cannot get lua_squeeze_rules.squeeze_dependency function");
+       }
 
        lua_settop (L, err_idx - 1);
 
@@ -3411,6 +3439,29 @@ lua_config_get_symbol_stat (lua_State *L)
        return 1;
 }
 
+static gint
+lua_config_get_symbol_parent (lua_State *L)
+{
+       LUA_TRACE_POINT;
+       struct rspamd_config *cfg = lua_check_config (L, 1);
+       const gchar *sym = luaL_checkstring (L, 2), *parent;
+
+       if (cfg != NULL && sym != NULL) {
+               parent = rspamd_symcache_get_parent (cfg->cache, sym);
+
+               if (parent) {
+                       lua_pushstring (L, parent);
+               }
+               else {
+                       lua_pushnil (L);
+               }
+       }
+       else {
+               return luaL_error (L, "invalid arguments");
+       }
+
+       return 1;
+}
 
 static gint
 lua_config_register_finish_script (lua_State *L)