]> source.dussan.org Git - rspamd.git/commitdiff
[Feature] Add lua rules squeezing
author Vsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 15 Mar 2018 18:00:35 +0000 (18:00 +0000)
committer Vsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 15 Mar 2018 18:00:35 +0000 (18:00 +0000)
lualib/lua_squeeze_rules.lua
src/libserver/cfg_utils.c
src/lua/lua_config.c

index d740e2ee00f3c5f49029862f0148f168fc7e23a0..02da274a70741178aedfc9d486279758fa9e651c 100644 (file)
@@ -22,28 +22,28 @@ local squeezed_rules = {} -- plain vector of all rules squeezed
 local squeezed_symbols = {} -- indexed by name of symbol
 local squeezed_deps = {} -- squeezed deps
 local SN = 'lua_squeeze'
-local squeeze_function_id
+local squeeze_sym = 'LUA_SQUEEZE'
+local squeeze_function_ids = {}
 
 local function lua_squeeze_function(task)
-  if not squeezed_symbols then
-    for k,v in pairs(squeezed_symbols) do
-      if not squeezed_exceptions[k] then
-        logger.debugm(SN, task, 'added squeezed rule: %s', k)
-        table.insert(squeezed_rules, v)
-      else
-        logger.debugm(SN, task, 'skipped squeezed rule: %s', k)
-      end
-    end
+  for _,data in ipairs(squeezed_rules) do
+    local ret = {data[1](task)}
 
-    squeezed_symbols = nil
-  end
-
-  for _,func in ipairs(squeezed_rules) do
-    local ret = func(task)
-
-    if ret then
+    if #ret ~= 0 then
+      local first = ret[1]
+      local sym = data[2]
       -- Function has returned something, so it is rule, not a plugin
-      logger.errx(task, 'hui: %s', ret)
+      if type(first) == 'boolean' then
+        if first then
+          table.remove(ret, 1)
+          task:insert_result(sym, 1.0, ret)
+        end
+      elseif type(first) == 'number' then
+        table.remove(ret, 1)
+        task:insert_result(sym, first, ret)
+      else
+        task:insert_result(sym, 1.0, ret)
+      end
     end
   end
 end
@@ -61,37 +61,114 @@ exports.squeeze_rule = function(s, func)
     end
   else
     -- Unconditionally add function to the squeezed rules
-    logger.debugm(SN, rspamd_config, 'squeezed unnamed rule: %s', #squeezed_rules)
-    table.insert(squeezed_rules, func)
+    local id = tostring(#squeezed_rules)
+    logger.debugm(SN, rspamd_config, 'squeezed unnamed rule: %s', id)
+    table.insert(squeezed_rules, {func, 'unnamed: ' .. id})
   end
 
-  if not squeeze_function_id then
-    squeeze_function_id = rspamd_config:register_symbol{
+  if not squeeze_function_ids[1] then
+    squeeze_function_ids[1] = rspamd_config:register_symbol{
       type = 'callback',
       callback = lua_squeeze_function,
-      name = 'LUA_SQUEEZE',
+      name = squeeze_sym,
       description = 'Meta rule for Lua rules that can be squeezed',
       no_squeeze = true, -- to avoid infinite recursion
     }
   end
 
-  return squeeze_function_id
+  return squeeze_function_ids[1]
 end
 
-exports.squeeze_dependency = function(from, to)
-  logger.debugm(SN, rspamd_config, 'squeeze dep %s->%s', to, from)
+exports.squeeze_dependency = function(child, parent)
+  logger.debugm(SN, rspamd_config, 'squeeze dep %s->%s', child, parent)
 
-  if not squeezed_deps[to] then
-    squeezed_deps[to] = {}
+  if not squeezed_deps[parent] then
+    squeezed_deps[parent] = {}
   end
 
-  if not squeezed_symbols[to][from] then
-    squeezed_symbols[to][from] = true
+  if not squeezed_deps[parent][child] then
+    squeezed_deps[parent][child] = true
   else
-    logger.warnx('duplicate dependency %s->%s', to, from)
+    logger.warnx(rspamd_config, 'duplicate dependency %s->%s', child, parent)
   end
 
   return true
 end
 
+local function get_ordered_symbol_name(order)
+  if order == 0 then
+    return squeeze_sym
+  end
+
+  return squeeze_sym .. tostring(order)
+end
+
+local function register_topology_symbol(order)
+  local ord_sym = get_ordered_symbol_name(order)
+
+  squeeze_function_ids[order + 1] = rspamd_config:register_symbol{
+    type = 'callback',
+    callback = lua_squeeze_function,
+    name = ord_sym,
+    description = 'Meta rule for Lua rules that can be squeezed, order ' .. tostring(order),
+    no_squeeze = true, -- to avoid infinite recursion
+  }
+
+  local parent = get_ordered_symbol_name(order - 1)
+  logger.debugm(SN, rspamd_config, 'registered new order of deps: %s->%s',
+      ord_sym, parent)
+  rspamd_config:register_dependency(ord_sym, parent, true)
+end
+
+exports.squeeze_init = function()
+  local max_topology_order = 0
+
+  for parent,children in pairs(squeezed_deps) do
+    if not squeezed_symbols[parent] then
+      -- Trivial case, external dependency
+      logger.debugm(SN, rspamd_config, 'register external squeezed dependency on %s',
+          parent)
+      rspamd_config:register_dependency(squeeze_sym, parent, true)
+    else
+      -- Not so trivial case
+      local ps = squeezed_symbols[parent]
+
+      for cld,_ in pairs(children) do
+        if squeezed_symbols[cld] then
+          -- Cross dependency
+          logger.debugm(SN, rspamd_config, 'cross dependency in squeezed symbols %s->%s',
+              cld, parent)
+          local order = math.max(ps.order + 1, squeezed_symbols[cld].order)
+          squeezed_symbols[cld].order = order
+          if order > max_topology_order then
+            -- Need to register new callback symbol to handle deps
+            register_topology_symbol(order)
+            max_topology_order = order
+          end
+        else
+          -- External symbol depends on a squeezed one
+          local parent_symbol = get_ordered_symbol_name(ps.order)
+          rspamd_config:register_dependency(cld, parent_symbol, true)
+          logger.debugm(SN, rspamd_config, 'register squeezed dependency for external symbol %s->%s',
+              cld, parent_symbol)
+        end
+      end
+    end
+  end
+
+  -- We have now all deps being registered, so we can register virtual symbols
+  -- and create squeezed rules
+  for k,v in pairs(squeezed_symbols) do
+    local parent_symbol = get_ordered_symbol_name(v.order)
+    logger.debugm(SN, rspamd_config, 'added squeezed rule: %s (%s)', k, parent_symbol)
+    rspamd_config:register_symbol{
+      type = 'virtual',
+      name = k,
+      parent = squeeze_function_ids[v.order + 1],
+      no_squeeze = true, -- to avoid infinite recursion
+    }
+    table.insert(squeezed_rules, {v.cb,k})
+  end
+end
+
 return exports
\ No newline at end of file
index c042c5eb2074f703eef28bb3b1d5ba240cfc471c..10503dc21b6cfbbd7a8a84fe738433daadd3a623 100644 (file)
@@ -753,6 +753,27 @@ rspamd_config_post_load (struct rspamd_config *cfg,
        }
 
        if (opts & RSPAMD_CONFIG_INIT_SYMCACHE) {
+               lua_State *L = cfg->lua_state;
+               int err_idx;
+
+               /* Process squeezed Lua rules */
+               lua_pushcfunction (L, &rspamd_lua_traceback);
+               err_idx = lua_gettop (L);
+
+               if (rspamd_lua_require_function (cfg->lua_state, "lua_squeeze_rules",
+                               "squeeze_init")) {
+                       if (lua_pcall (L, 0, 0, err_idx) != 0) {
+                               GString *tb = lua_touserdata (L, -1);
+                               msg_err_config ("call to squeeze_init script failed: %v", tb);
+
+                               if (tb) {
+                                       g_string_free (tb, TRUE);
+                               }
+                       }
+               }
+
+               lua_settop (L, err_idx - 1);
+
                /* Init config cache */
                rspamd_symbols_cache_init (cfg->cache);
 
index d6a4ad550a1e7b6a0aaeffaa13c6e28f91558f3b..6abbf7d0eecc121b079be53bf6121739a5b08bdf 100644 (file)
@@ -1818,24 +1818,21 @@ lua_config_register_callback_symbol_priority (lua_State * L)
 
 static gboolean
 rspamd_lua_squeeze_dependency (lua_State *L, struct rspamd_config *cfg,
-               const gchar *name,
-               const gchar *from)
+               const gchar *child,
+               const gchar *parent)
 {
        gint err_idx;
        gboolean ret = FALSE;
 
+       g_assert (parent != NULL);
+       g_assert (child != NULL);
+
        lua_pushcfunction (L, &rspamd_lua_traceback);
        err_idx = lua_gettop (L);
 
        if (rspamd_lua_require_function (L, "lua_squeeze_rules", "squeeze_dependency")) {
-               lua_pushstring (L, name);
-
-               if (from) {
-                       lua_pushstring (L, from);
-               }
-               else {
-                       lua_pushnil (L);
-               }
+               lua_pushstring (L, child);
+               lua_pushstring (L, parent);
 
                if (lua_pcall (L, 2, 1, err_idx) != 0) {
                        GString *tb = lua_touserdata (L, -1);
@@ -1859,8 +1856,8 @@ static gint
 lua_config_register_dependency (lua_State * L)
 {
        struct rspamd_config *cfg = lua_check_config (L, 1);
-       const gchar *name = NULL, *from = NULL;
-       gint id;
+       const gchar *parent = NULL, *child = NULL;
+       gint child_id;
        gboolean skip_squeeze = FALSE;
 
        if (cfg == NULL) {
@@ -1869,37 +1866,38 @@ lua_config_register_dependency (lua_State * L)
        }
 
        if (lua_type (L, 2) == LUA_TNUMBER) {
-               id = luaL_checknumber (L, 2);
-               name = luaL_checkstring (L, 3);
+               child_id = luaL_checknumber (L, 2);
+               parent = luaL_checkstring (L, 3);
 
                if (lua_isboolean (L, 4)) {
                        skip_squeeze = lua_toboolean (L, 4);
                }
 
                msg_warn_config ("calling for obsolete method to register deps for symbol %d->%s",
-                               id, name);
+                               child_id, parent);
 
-               if (id > 0 && name != NULL) {
+               if (child_id > 0 && parent != NULL) {
 
-                       if (skip_squeeze || !rspamd_lua_squeeze_dependency (L, cfg, name,
-                                       rspamd_symbols_cache_symbol_by_id (cfg->cache, id))) {
-                               rspamd_symbols_cache_add_dependency (cfg->cache, id, name);
+                       if (skip_squeeze || !rspamd_lua_squeeze_dependency (L, cfg,
+                                       rspamd_symbols_cache_symbol_by_id (cfg->cache, child_id),
+                                       parent)) {
+                               rspamd_symbols_cache_add_dependency (cfg->cache, child_id, parent);
                        }
                }
        }
        else {
-               from = luaL_checkstring (L,2);
-               name = luaL_checkstring (L, 3);
+               child = luaL_checkstring (L,2);
+               parent = luaL_checkstring (L, 3);
 
                if (lua_isboolean (L, 4)) {
                        skip_squeeze = lua_toboolean (L, 4);
                }
 
-               if (from != NULL && name != NULL) {
+               if (child != NULL && child != NULL) {
 
-                       if (skip_squeeze || !rspamd_lua_squeeze_dependency (L, cfg, name, from)) {
-                               rspamd_symbols_cache_add_delayed_dependency (cfg->cache, from,
-                                               name);
+                       if (skip_squeeze || !rspamd_lua_squeeze_dependency (L, cfg, child, parent)) {
+                               rspamd_symbols_cache_add_delayed_dependency (cfg->cache, child,
+                                               parent);
                        }
 
                }