summaryrefslogtreecommitdiffstats
path: root/lualib
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2018-03-16 10:56:09 +0000
committerVsevolod Stakhov <vsevolod@highsecure.ru>2018-03-16 10:56:09 +0000
commit48c010166b1a3d13a6cbfc3ad653e1d953d12a5f (patch)
treea09dcdb4c76302490defa96fbb0d2cd747a67209 /lualib
parent6d8158e23ad11a5a002753aed3382b5e1dbc17bf (diff)
downloadrspamd-48c010166b1a3d13a6cbfc3ad653e1d953d12a5f.tar.gz
rspamd-48c010166b1a3d13a6cbfc3ad653e1d953d12a5f.zip
[Fix] Fix dependencies in lua squeeze
Diffstat (limited to 'lualib')
-rw-r--r--lualib/lua_squeeze_rules.lua98
1 files changed, 69 insertions, 29 deletions
diff --git a/lualib/lua_squeeze_rules.lua b/lualib/lua_squeeze_rules.lua
index 02da274a7..d04352a68 100644
--- a/lualib/lua_squeeze_rules.lua
+++ b/lualib/lua_squeeze_rules.lua
@@ -18,31 +18,34 @@ local exports = {}
local logger = require 'rspamd_logger'
-- Squeezed rules part
-local squeezed_rules = {} -- plain vector of all rules squeezed
+local squeezed_rules = {{}} -- plain vector of all rules squeezed
local squeezed_symbols = {} -- indexed by name of symbol
local squeezed_deps = {} -- squeezed deps
+local squeezed_rdeps = {} -- squeezed reverse deps
local SN = 'lua_squeeze'
local squeeze_sym = 'LUA_SQUEEZE'
local squeeze_function_ids = {}
-local function lua_squeeze_function(task)
- for _,data in ipairs(squeezed_rules) do
- local ret = {data[1](task)}
-
- if #ret ~= 0 then
- local first = ret[1]
- local sym = data[2]
- -- Function has returned something, so it is rule, not a plugin
- if type(first) == 'boolean' then
- if first then
+local function gen_lua_squeeze_function(order)
+ return function(task)
+ for _,data in ipairs(squeezed_rules[order]) do
+ local ret = {data[1](task)}
+
+ if #ret ~= 0 then
+ local first = ret[1]
+ local sym = data[2]
+ -- Function has returned something, so it is rule, not a plugin
+ if type(first) == 'boolean' then
+ if first then
+ table.remove(ret, 1)
+ task:insert_result(sym, 1.0, ret)
+ end
+ elseif type(first) == 'number' then
table.remove(ret, 1)
+ task:insert_result(sym, first, ret)
+ else
task:insert_result(sym, 1.0, ret)
end
- elseif type(first) == 'number' then
- table.remove(ret, 1)
- task:insert_result(sym, first, ret)
- else
- task:insert_result(sym, 1.0, ret)
end
end
end
@@ -54,6 +57,7 @@ exports.squeeze_rule = function(s, func)
squeezed_symbols[s] = {
cb = func,
order = 0,
+ sym = s,
}
logger.debugm(SN, rspamd_config, 'squeezed rule: %s', s)
else
@@ -63,13 +67,13 @@ exports.squeeze_rule = function(s, func)
-- Unconditionally add function to the squeezed rules
local id = tostring(#squeezed_rules)
logger.debugm(SN, rspamd_config, 'squeezed unnamed rule: %s', id)
- table.insert(squeezed_rules, {func, 'unnamed: ' .. id})
+ table.insert(squeezed_rules[1], {func, 'unnamed: ' .. id})
end
if not squeeze_function_ids[1] then
squeeze_function_ids[1] = rspamd_config:register_symbol{
type = 'callback',
- callback = lua_squeeze_function,
+ callback = gen_lua_squeeze_function(1),
name = squeeze_sym,
description = 'Meta rule for Lua rules that can be squeezed',
no_squeeze = true, -- to avoid infinite recursion
@@ -92,11 +96,19 @@ exports.squeeze_dependency = function(child, parent)
logger.warnx(rspamd_config, 'duplicate dependency %s->%s', child, parent)
end
+ if not squeezed_rdeps[child] then
+ squeezed_rdeps[child] = {}
+ end
+
+ if not squeezed_rdeps[child][parent] then
+ squeezed_rdeps[child][parent] = true
+ end
+
return true
end
local function get_ordered_symbol_name(order)
- if order == 0 then
+ if order == 1 then
return squeeze_sym
end
@@ -106,9 +118,9 @@ end
local function register_topology_symbol(order)
local ord_sym = get_ordered_symbol_name(order)
- squeeze_function_ids[order + 1] = rspamd_config:register_symbol{
+ squeeze_function_ids[order] = rspamd_config:register_symbol{
type = 'callback',
- callback = lua_squeeze_function,
+ callback = gen_lua_squeeze_function(order),
name = ord_sym,
description = 'Meta rule for Lua rules that can be squeezed, order ' .. tostring(order),
no_squeeze = true, -- to avoid infinite recursion
@@ -121,7 +133,30 @@ local function register_topology_symbol(order)
end
exports.squeeze_init = function()
- local max_topology_order = 0
+ -- Do topological sorting
+ for s,v in pairs(squeezed_symbols) do
+ local function visit(node, order)
+
+ if order > node.order then
+ node.order = order
+ logger.debugm(SN, rspamd_config, "symbol: %s, order: %s", node.sym, order)
+ else
+ return
+ end
+
+ if squeezed_deps[node.sym] then
+ for dep,_ in pairs(squeezed_deps[node.sym]) do
+ if squeezed_symbols[dep] then
+ visit(squeezed_symbols[dep], order + 1)
+ end
+ end
+ end
+ end
+
+ if v.order == 0 then
+ visit(v, 1)
+ end
+ end
for parent,children in pairs(squeezed_deps) do
if not squeezed_symbols[parent] then
@@ -138,12 +173,14 @@ exports.squeeze_init = function()
-- Cross dependency
logger.debugm(SN, rspamd_config, 'cross dependency in squeezed symbols %s->%s',
cld, parent)
- local order = math.max(ps.order + 1, squeezed_symbols[cld].order)
- squeezed_symbols[cld].order = order
- if order > max_topology_order then
+ local order = squeezed_symbols[cld].order
+ if not squeeze_function_ids[order] then
-- Need to register new callback symbol to handle deps
- register_topology_symbol(order)
- max_topology_order = order
+ for i = 1, order do
+ if not squeeze_function_ids[i] then
+ register_topology_symbol(i)
+ end
+ end
end
else
-- External symbol depends on a squeezed one
@@ -164,10 +201,13 @@ exports.squeeze_init = function()
rspamd_config:register_symbol{
type = 'virtual',
name = k,
- parent = squeeze_function_ids[v.order + 1],
+ parent = squeeze_function_ids[v.order],
no_squeeze = true, -- to avoid infinite recursion
}
- table.insert(squeezed_rules, {v.cb,k})
+ if not squeezed_rules[v.order] then
+ squeezed_rules[v.order] = {}
+ end
+ table.insert(squeezed_rules[v.order], {v.cb,k})
end
end