]> source.dussan.org Git - rspamd.git/commitdiff
[Fix] Fix dependencies in lua squeeze
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 16 Mar 2018 10:56:09 +0000 (10:56 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 16 Mar 2018 10:56:09 +0000 (10:56 +0000)
lualib/lua_squeeze_rules.lua

index 02da274a70741178aedfc9d486279758fa9e651c..d04352a685353cd8ffbab440b317c4a1cb56b0bc 100644 (file)
@@ -18,31 +18,34 @@ local exports = {}
 local logger = require 'rspamd_logger'
 
 -- Squeezed rules part
-local squeezed_rules = {} -- plain vector of all rules squeezed
+local squeezed_rules = {{}} -- plain vector of all rules squeezed
 local squeezed_symbols = {} -- indexed by name of symbol
 local squeezed_deps = {} -- squeezed deps
+local squeezed_rdeps = {} -- squeezed reverse deps
 local SN = 'lua_squeeze'
 local squeeze_sym = 'LUA_SQUEEZE'
 local squeeze_function_ids = {}
 
-local function lua_squeeze_function(task)
-  for _,data in ipairs(squeezed_rules) do
-    local ret = {data[1](task)}
-
-    if #ret ~= 0 then
-      local first = ret[1]
-      local sym = data[2]
-      -- Function has returned something, so it is rule, not a plugin
-      if type(first) == 'boolean' then
-        if first then
+local function gen_lua_squeeze_function(order)
+  return function(task)
+    for _,data in ipairs(squeezed_rules[order]) do
+      local ret = {data[1](task)}
+
+      if #ret ~= 0 then
+        local first = ret[1]
+        local sym = data[2]
+        -- Function has returned something, so it is rule, not a plugin
+        if type(first) == 'boolean' then
+          if first then
+            table.remove(ret, 1)
+            task:insert_result(sym, 1.0, ret)
+          end
+        elseif type(first) == 'number' then
           table.remove(ret, 1)
+          task:insert_result(sym, first, ret)
+        else
           task:insert_result(sym, 1.0, ret)
         end
-      elseif type(first) == 'number' then
-        table.remove(ret, 1)
-        task:insert_result(sym, first, ret)
-      else
-        task:insert_result(sym, 1.0, ret)
       end
     end
   end
@@ -54,6 +57,7 @@ exports.squeeze_rule = function(s, func)
       squeezed_symbols[s] = {
         cb = func,
         order = 0,
+        sym = s,
       }
       logger.debugm(SN, rspamd_config, 'squeezed rule: %s', s)
     else
@@ -63,13 +67,13 @@ exports.squeeze_rule = function(s, func)
     -- Unconditionally add function to the squeezed rules
     local id = tostring(#squeezed_rules)
     logger.debugm(SN, rspamd_config, 'squeezed unnamed rule: %s', id)
-    table.insert(squeezed_rules, {func, 'unnamed: ' .. id})
+    table.insert(squeezed_rules[1], {func, 'unnamed: ' .. id})
   end
 
   if not squeeze_function_ids[1] then
     squeeze_function_ids[1] = rspamd_config:register_symbol{
       type = 'callback',
-      callback = lua_squeeze_function,
+      callback = gen_lua_squeeze_function(1),
       name = squeeze_sym,
       description = 'Meta rule for Lua rules that can be squeezed',
       no_squeeze = true, -- to avoid infinite recursion
@@ -92,11 +96,19 @@ exports.squeeze_dependency = function(child, parent)
     logger.warnx(rspamd_config, 'duplicate dependency %s->%s', child, parent)
   end
 
+  if not squeezed_rdeps[child] then
+    squeezed_rdeps[child] = {}
+  end
+
+  if not squeezed_rdeps[child][parent] then
+    squeezed_rdeps[child][parent] = true
+  end
+
   return true
 end
 
 local function get_ordered_symbol_name(order)
-  if order == 0 then
+  if order == 1 then
     return squeeze_sym
   end
 
@@ -106,9 +118,9 @@ end
 local function register_topology_symbol(order)
   local ord_sym = get_ordered_symbol_name(order)
 
-  squeeze_function_ids[order + 1] = rspamd_config:register_symbol{
+  squeeze_function_ids[order] = rspamd_config:register_symbol{
     type = 'callback',
-    callback = lua_squeeze_function,
+    callback = gen_lua_squeeze_function(order),
     name = ord_sym,
     description = 'Meta rule for Lua rules that can be squeezed, order ' .. tostring(order),
     no_squeeze = true, -- to avoid infinite recursion
@@ -121,7 +133,30 @@ local function register_topology_symbol(order)
 end
 
 exports.squeeze_init = function()
-  local max_topology_order = 0
+  -- Do topological sorting
+  for s,v in pairs(squeezed_symbols) do
+    local function visit(node, order)
+
+      if order > node.order then
+        node.order = order
+        logger.debugm(SN, rspamd_config, "symbol: %s, order: %s", node.sym, order)
+      else
+        return
+      end
+
+      if squeezed_deps[node.sym] then
+        for dep,_ in pairs(squeezed_deps[node.sym]) do
+          if squeezed_symbols[dep] then
+            visit(squeezed_symbols[dep], order + 1)
+          end
+        end
+      end
+    end
+
+    if v.order == 0 then
+      visit(v, 1)
+    end
+  end
 
   for parent,children in pairs(squeezed_deps) do
     if not squeezed_symbols[parent] then
@@ -138,12 +173,14 @@ exports.squeeze_init = function()
           -- Cross dependency
           logger.debugm(SN, rspamd_config, 'cross dependency in squeezed symbols %s->%s',
               cld, parent)
-          local order = math.max(ps.order + 1, squeezed_symbols[cld].order)
-          squeezed_symbols[cld].order = order
-          if order > max_topology_order then
+          local order = squeezed_symbols[cld].order
+          if not squeeze_function_ids[order] then
             -- Need to register new callback symbol to handle deps
-            register_topology_symbol(order)
-            max_topology_order = order
+            for i = 1, order do
+              if not squeeze_function_ids[i] then
+                register_topology_symbol(i)
+              end
+            end
           end
         else
           -- External symbol depends on a squeezed one
@@ -164,10 +201,13 @@ exports.squeeze_init = function()
     rspamd_config:register_symbol{
       type = 'virtual',
       name = k,
-      parent = squeeze_function_ids[v.order + 1],
+      parent = squeeze_function_ids[v.order],
       no_squeeze = true, -- to avoid infinite recursion
     }
-    table.insert(squeezed_rules, {v.cb,k})
+    if not squeezed_rules[v.order] then
+      squeezed_rules[v.order] = {}
+    end
+    table.insert(squeezed_rules[v.order], {v.cb,k})
   end
 end