aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--lualib/lua_squeeze_rules.lua139
-rw-r--r--src/libserver/cfg_utils.c21
-rw-r--r--src/lua/lua_config.c48
3 files changed, 152 insertions, 56 deletions
diff --git a/lualib/lua_squeeze_rules.lua b/lualib/lua_squeeze_rules.lua
index d740e2ee0..02da274a7 100644
--- a/lualib/lua_squeeze_rules.lua
+++ b/lualib/lua_squeeze_rules.lua
@@ -22,28 +22,28 @@ local squeezed_rules = {} -- plain vector of all rules squeezed
local squeezed_symbols = {} -- indexed by name of symbol
local squeezed_deps = {} -- squeezed deps
local SN = 'lua_squeeze'
-local squeeze_function_id
+local squeeze_sym = 'LUA_SQUEEZE'
+local squeeze_function_ids = {}
local function lua_squeeze_function(task)
- if not squeezed_symbols then
- for k,v in pairs(squeezed_symbols) do
- if not squeezed_exceptions[k] then
- logger.debugm(SN, task, 'added squeezed rule: %s', k)
- table.insert(squeezed_rules, v)
- else
- logger.debugm(SN, task, 'skipped squeezed rule: %s', k)
- end
- end
+ for _,data in ipairs(squeezed_rules) do
+ local ret = {data[1](task)}
- squeezed_symbols = nil
- end
-
- for _,func in ipairs(squeezed_rules) do
- local ret = func(task)
-
- if ret then
+ if #ret ~= 0 then
+ local first = ret[1]
+ local sym = data[2]
-- Function has returned something, so it is rule, not a plugin
- logger.errx(task, 'hui: %s', ret)
+ if type(first) == 'boolean' then
+ if first then
+ table.remove(ret, 1)
+ task:insert_result(sym, 1.0, ret)
+ end
+ elseif type(first) == 'number' then
+ table.remove(ret, 1)
+ task:insert_result(sym, first, ret)
+ else
+ task:insert_result(sym, 1.0, ret)
+ end
end
end
end
@@ -61,37 +61,114 @@ exports.squeeze_rule = function(s, func)
end
else
-- Unconditionally add function to the squeezed rules
- logger.debugm(SN, rspamd_config, 'squeezed unnamed rule: %s', #squeezed_rules)
- table.insert(squeezed_rules, func)
+ local id = tostring(#squeezed_rules)
+ logger.debugm(SN, rspamd_config, 'squeezed unnamed rule: %s', id)
+ table.insert(squeezed_rules, {func, 'unnamed: ' .. id})
end
- if not squeeze_function_id then
- squeeze_function_id = rspamd_config:register_symbol{
+ if not squeeze_function_ids[1] then
+ squeeze_function_ids[1] = rspamd_config:register_symbol{
type = 'callback',
callback = lua_squeeze_function,
- name = 'LUA_SQUEEZE',
+ name = squeeze_sym,
description = 'Meta rule for Lua rules that can be squeezed',
no_squeeze = true, -- to avoid infinite recursion
}
end
- return squeeze_function_id
+ return squeeze_function_ids[1]
end
-exports.squeeze_dependency = function(from, to)
- logger.debugm(SN, rspamd_config, 'squeeze dep %s->%s', to, from)
+exports.squeeze_dependency = function(child, parent)
+ logger.debugm(SN, rspamd_config, 'squeeze dep %s->%s', child, parent)
- if not squeezed_deps[to] then
- squeezed_deps[to] = {}
+ if not squeezed_deps[parent] then
+ squeezed_deps[parent] = {}
end
- if not squeezed_symbols[to][from] then
- squeezed_symbols[to][from] = true
+ if not squeezed_deps[parent][child] then
+ squeezed_deps[parent][child] = true
else
- logger.warnx('duplicate dependency %s->%s', to, from)
+ logger.warnx(rspamd_config, 'duplicate dependency %s->%s', child, parent)
end
return true
end
+local function get_ordered_symbol_name(order)
+ if order == 0 then
+ return squeeze_sym
+ end
+
+ return squeeze_sym .. tostring(order)
+end
+
+local function register_topology_symbol(order)
+ local ord_sym = get_ordered_symbol_name(order)
+
+ squeeze_function_ids[order + 1] = rspamd_config:register_symbol{
+ type = 'callback',
+ callback = lua_squeeze_function,
+ name = ord_sym,
+ description = 'Meta rule for Lua rules that can be squeezed, order ' .. tostring(order),
+ no_squeeze = true, -- to avoid infinite recursion
+ }
+
+ local parent = get_ordered_symbol_name(order - 1)
+ logger.debugm(SN, rspamd_config, 'registered new order of deps: %s->%s',
+ ord_sym, parent)
+ rspamd_config:register_dependency(ord_sym, parent, true)
+end
+
+exports.squeeze_init = function()
+ local max_topology_order = 0
+
+ for parent,children in pairs(squeezed_deps) do
+ if not squeezed_symbols[parent] then
+ -- Trivial case, external dependnency
+ logger.debugm(SN, rspamd_config, 'register external squeezed dependency on %s',
+ parent)
+ rspamd_config:register_dependency(squeeze_sym, parent, true)
+ else
+ -- Not so trivial case
+ local ps = squeezed_symbols[parent]
+
+ for cld,_ in pairs(children) do
+ if squeezed_symbols[cld] then
+ -- Cross dependency
+ logger.debugm(SN, rspamd_config, 'cross dependency in squeezed symbols %s->%s',
+ cld, parent)
+ local order = math.max(ps.order + 1, squeezed_symbols[cld].order)
+ squeezed_symbols[cld].order = order
+ if order > max_topology_order then
+ -- Need to register new callback symbol to handle deps
+ register_topology_symbol(order)
+ max_topology_order = order
+ end
+ else
+ -- External symbol depends on a squeezed one
+ local parent_symbol = get_ordered_symbol_name(ps.order)
+ rspamd_config:register_dependency(cld, parent_symbol, true)
+ logger.debugm(SN, rspamd_config, 'register squeezed dependency for external symbol %s->%s',
+ cld, parent_symbol)
+ end
+ end
+ end
+ end
+
+ -- We have now all deps being registered, so we can register virtual symbols
+ -- and create squeezed rules
+ for k,v in pairs(squeezed_symbols) do
+ local parent_symbol = get_ordered_symbol_name(v.order)
+ logger.debugm(SN, rspamd_config, 'added squeezed rule: %s (%s)', k, parent_symbol)
+ rspamd_config:register_symbol{
+ type = 'virtual',
+ name = k,
+ parent = squeeze_function_ids[v.order + 1],
+ no_squeeze = true, -- to avoid infinite recursion
+ }
+ table.insert(squeezed_rules, {v.cb,k})
+ end
+end
+
return exports \ No newline at end of file
diff --git a/src/libserver/cfg_utils.c b/src/libserver/cfg_utils.c
index c042c5eb2..10503dc21 100644
--- a/src/libserver/cfg_utils.c
+++ b/src/libserver/cfg_utils.c
@@ -753,6 +753,27 @@ rspamd_config_post_load (struct rspamd_config *cfg,
}
if (opts & RSPAMD_CONFIG_INIT_SYMCACHE) {
+ lua_State *L = cfg->lua_state;
+ int err_idx;
+
+ /* Process squeezed Lua rules */
+ lua_pushcfunction (L, &rspamd_lua_traceback);
+ err_idx = lua_gettop (L);
+
+ if (rspamd_lua_require_function (cfg->lua_state, "lua_squeeze_rules",
+ "squeeze_init")) {
+ if (lua_pcall (L, 0, 0, err_idx) != 0) {
+ GString *tb = lua_touserdata (L, -1);
+ msg_err_config ("call to squeeze_init script failed: %v", tb);
+
+ if (tb) {
+ g_string_free (tb, TRUE);
+ }
+ }
+ }
+
+ lua_settop (L, err_idx - 1);
+
/* Init config cache */
rspamd_symbols_cache_init (cfg->cache);
diff --git a/src/lua/lua_config.c b/src/lua/lua_config.c
index d6a4ad550..6abbf7d0e 100644
--- a/src/lua/lua_config.c
+++ b/src/lua/lua_config.c
@@ -1818,24 +1818,21 @@ lua_config_register_callback_symbol_priority (lua_State * L)
static gboolean
rspamd_lua_squeeze_dependency (lua_State *L, struct rspamd_config *cfg,
- const gchar *name,
- const gchar *from)
+ const gchar *child,
+ const gchar *parent)
{
gint err_idx;
gboolean ret = FALSE;
+ g_assert (parent != NULL);
+ g_assert (child != NULL);
+
lua_pushcfunction (L, &rspamd_lua_traceback);
err_idx = lua_gettop (L);
if (rspamd_lua_require_function (L, "lua_squeeze_rules", "squeeze_dependency")) {
- lua_pushstring (L, name);
-
- if (from) {
- lua_pushstring (L, from);
- }
- else {
- lua_pushnil (L);
- }
+ lua_pushstring (L, child);
+ lua_pushstring (L, parent);
if (lua_pcall (L, 2, 1, err_idx) != 0) {
GString *tb = lua_touserdata (L, -1);
@@ -1859,8 +1856,8 @@ static gint
lua_config_register_dependency (lua_State * L)
{
struct rspamd_config *cfg = lua_check_config (L, 1);
- const gchar *name = NULL, *from = NULL;
- gint id;
+ const gchar *parent = NULL, *child = NULL;
+ gint child_id;
gboolean skip_squeeze = FALSE;
if (cfg == NULL) {
@@ -1869,37 +1866,38 @@ lua_config_register_dependency (lua_State * L)
}
if (lua_type (L, 2) == LUA_TNUMBER) {
- id = luaL_checknumber (L, 2);
- name = luaL_checkstring (L, 3);
+ child_id = luaL_checknumber (L, 2);
+ parent = luaL_checkstring (L, 3);
if (lua_isboolean (L, 4)) {
skip_squeeze = lua_toboolean (L, 4);
}
msg_warn_config ("calling for obsolete method to register deps for symbol %d->%s",
- id, name);
+ child_id, parent);
- if (id > 0 && name != NULL) {
+ if (child_id > 0 && parent != NULL) {
- if (skip_squeeze || !rspamd_lua_squeeze_dependency (L, cfg, name,
- rspamd_symbols_cache_symbol_by_id (cfg->cache, id))) {
- rspamd_symbols_cache_add_dependency (cfg->cache, id, name);
+ if (skip_squeeze || !rspamd_lua_squeeze_dependency (L, cfg,
+ rspamd_symbols_cache_symbol_by_id (cfg->cache, child_id),
+ parent)) {
+ rspamd_symbols_cache_add_dependency (cfg->cache, child_id, parent);
}
}
}
else {
- from = luaL_checkstring (L,2);
- name = luaL_checkstring (L, 3);
+ child = luaL_checkstring (L,2);
+ parent = luaL_checkstring (L, 3);
if (lua_isboolean (L, 4)) {
skip_squeeze = lua_toboolean (L, 4);
}
- if (from != NULL && name != NULL) {
+ if (child != NULL && child != NULL) {
- if (skip_squeeze || !rspamd_lua_squeeze_dependency (L, cfg, name, from)) {
- rspamd_symbols_cache_add_delayed_dependency (cfg->cache, from,
- name);
+ if (skip_squeeze || !rspamd_lua_squeeze_dependency (L, cfg, child, parent)) {
+ rspamd_symbols_cache_add_delayed_dependency (cfg->cache, child,
+ parent);
}
}