From: Vsevolod Stakhov Date: Fri, 16 Mar 2018 10:56:09 +0000 (+0000) Subject: [Fix] Fix dependencies in lua squeeze X-Git-Tag: 1.7.1~30 X-Git-Url: https://source.dussan.org/?a=commitdiff_plain;h=48c010166b1a3d13a6cbfc3ad653e1d953d12a5f;p=rspamd.git [Fix] Fix dependencies in lua squeeze --- diff --git a/lualib/lua_squeeze_rules.lua b/lualib/lua_squeeze_rules.lua index 02da274a7..d04352a68 100644 --- a/lualib/lua_squeeze_rules.lua +++ b/lualib/lua_squeeze_rules.lua @@ -18,31 +18,34 @@ local exports = {} local logger = require 'rspamd_logger' -- Squeezed rules part -local squeezed_rules = {} -- plain vector of all rules squeezed +local squeezed_rules = {{}} -- plain vector of all rules squeezed local squeezed_symbols = {} -- indexed by name of symbol local squeezed_deps = {} -- squeezed deps +local squeezed_rdeps = {} -- squeezed reverse deps local SN = 'lua_squeeze' local squeeze_sym = 'LUA_SQUEEZE' local squeeze_function_ids = {} -local function lua_squeeze_function(task) - for _,data in ipairs(squeezed_rules) do - local ret = {data[1](task)} - - if #ret ~= 0 then - local first = ret[1] - local sym = data[2] - -- Function has returned something, so it is rule, not a plugin - if type(first) == 'boolean' then - if first then +local function gen_lua_squeeze_function(order) + return function(task) + for _,data in ipairs(squeezed_rules[order]) do + local ret = {data[1](task)} + + if #ret ~= 0 then + local first = ret[1] + local sym = data[2] + -- Function has returned something, so it is rule, not a plugin + if type(first) == 'boolean' then + if first then + table.remove(ret, 1) + task:insert_result(sym, 1.0, ret) + end + elseif type(first) == 'number' then table.remove(ret, 1) + task:insert_result(sym, first, ret) + else task:insert_result(sym, 1.0, ret) end - elseif type(first) == 'number' then - table.remove(ret, 1) - task:insert_result(sym, first, ret) - else - task:insert_result(sym, 1.0, ret) end end end @@ -54,6 +57,7 @@ exports.squeeze_rule = function(s, func) squeezed_symbols[s] = { cb = func, order = 0, + sym = s, } logger.debugm(SN, rspamd_config, 'squeezed rule: %s', s) else @@ -63,13 +67,13 @@ exports.squeeze_rule = function(s, func) -- Unconditionally add function to the squeezed rules local id = tostring(#squeezed_rules) logger.debugm(SN, rspamd_config, 'squeezed unnamed rule: %s', id) - table.insert(squeezed_rules, {func, 'unnamed: ' .. id}) + table.insert(squeezed_rules[1], {func, 'unnamed: ' .. id}) end if not squeeze_function_ids[1] then squeeze_function_ids[1] = rspamd_config:register_symbol{ type = 'callback', - callback = lua_squeeze_function, + callback = gen_lua_squeeze_function(1), name = squeeze_sym, description = 'Meta rule for Lua rules that can be squeezed', no_squeeze = true, -- to avoid infinite recursion @@ -92,11 +96,19 @@ exports.squeeze_dependency = function(child, parent) logger.warnx(rspamd_config, 'duplicate dependency %s->%s', child, parent) end + if not squeezed_rdeps[child] then + squeezed_rdeps[child] = {} + end + + if not squeezed_rdeps[child][parent] then + squeezed_rdeps[child][parent] = true + end + return true end local function get_ordered_symbol_name(order) - if order == 0 then + if order == 1 then return squeeze_sym end @@ -106,9 +118,9 @@ end local function register_topology_symbol(order) local ord_sym = get_ordered_symbol_name(order) - squeeze_function_ids[order + 1] = rspamd_config:register_symbol{ + squeeze_function_ids[order] = rspamd_config:register_symbol{ type = 'callback', - callback = lua_squeeze_function, + callback = gen_lua_squeeze_function(order), name = ord_sym, description = 'Meta rule for Lua rules that can be squeezed, order ' .. tostring(order), no_squeeze = true, -- to avoid infinite recursion @@ -121,7 +133,30 @@ local function register_topology_symbol(order) end exports.squeeze_init = function() - local max_topology_order = 0 + -- Do topological sorting + for s,v in pairs(squeezed_symbols) do + local function visit(node, order) + + if order > node.order then + node.order = order + logger.debugm(SN, rspamd_config, "symbol: %s, order: %s", node.sym, order) + else + return + end + + if squeezed_deps[node.sym] then + for dep,_ in pairs(squeezed_deps[node.sym]) do + if squeezed_symbols[dep] then + visit(squeezed_symbols[dep], order + 1) + end + end + end + end + + if v.order == 0 then + visit(v, 1) + end + end for parent,children in pairs(squeezed_deps) do if not squeezed_symbols[parent] then @@ -138,12 +173,14 @@ exports.squeeze_init = function() -- Cross dependency logger.debugm(SN, rspamd_config, 'cross dependency in squeezed symbols %s->%s', cld, parent) - local order = math.max(ps.order + 1, squeezed_symbols[cld].order) - squeezed_symbols[cld].order = order - if order > max_topology_order then + local order = squeezed_symbols[cld].order + if not squeeze_function_ids[order] then -- Need to register new callback symbol to handle deps - register_topology_symbol(order) - max_topology_order = order + for i = 1, order do + if not squeeze_function_ids[i] then + register_topology_symbol(i) + end + end end else -- External symbol depends on a squeezed one @@ -164,10 +201,13 @@ exports.squeeze_init = function() rspamd_config:register_symbol{ type = 'virtual', name = k, - parent = squeeze_function_ids[v.order + 1], + parent = squeeze_function_ids[v.order], no_squeeze = true, -- to avoid infinite recursion } - table.insert(squeezed_rules, {v.cb,k}) + if not squeezed_rules[v.order] then + squeezed_rules[v.order] = {} + end + table.insert(squeezed_rules[v.order], {v.cb,k}) end end