From e417edfa97e80946b3fadbd5b7da3420b287d76a Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Thu, 15 Mar 2018 18:00:35 +0000 Subject: [PATCH] [Feature] Add lua rules squeezing --- lualib/lua_squeeze_rules.lua | 139 +++++++++++++++++++++++++++-------- src/libserver/cfg_utils.c | 21 ++++++ src/lua/lua_config.c | 48 ++++++------ 3 files changed, 152 insertions(+), 56 deletions(-) diff --git a/lualib/lua_squeeze_rules.lua b/lualib/lua_squeeze_rules.lua index d740e2ee0..02da274a7 100644 --- a/lualib/lua_squeeze_rules.lua +++ b/lualib/lua_squeeze_rules.lua @@ -22,28 +22,28 @@ local squeezed_rules = {} -- plain vector of all rules squeezed local squeezed_symbols = {} -- indexed by name of symbol local squeezed_deps = {} -- squeezed deps local SN = 'lua_squeeze' -local squeeze_function_id +local squeeze_sym = 'LUA_SQUEEZE' +local squeeze_function_ids = {} local function lua_squeeze_function(task) - if not squeezed_symbols then - for k,v in pairs(squeezed_symbols) do - if not squeezed_exceptions[k] then - logger.debugm(SN, task, 'added squeezed rule: %s', k) - table.insert(squeezed_rules, v) - else - logger.debugm(SN, task, 'skipped squeezed rule: %s', k) - end - end + for _,data in ipairs(squeezed_rules) do + local ret = {data[1](task)} - squeezed_symbols = nil - end - - for _,func in ipairs(squeezed_rules) do - local ret = func(task) - - if ret then + if #ret ~= 0 then + local first = ret[1] + local sym = data[2] -- Function has returned something, so it is rule, not a plugin - logger.errx(task, 'hui: %s', ret) + if type(first) == 'boolean' then + if first then + table.remove(ret, 1) + task:insert_result(sym, 1.0, ret) + end + elseif type(first) == 'number' then + table.remove(ret, 1) + task:insert_result(sym, first, ret) + else + task:insert_result(sym, 1.0, ret) + end end end end @@ -61,37 +61,114 @@ exports.squeeze_rule = function(s, func) end else -- Unconditionally add function to the squeezed rules - logger.debugm(SN, rspamd_config, 'squeezed unnamed rule: %s', #squeezed_rules) - table.insert(squeezed_rules, func) + local id = tostring(#squeezed_rules) + logger.debugm(SN, rspamd_config, 'squeezed unnamed rule: %s', id) + table.insert(squeezed_rules, {func, 'unnamed: ' .. id}) end - if not squeeze_function_id then - squeeze_function_id = rspamd_config:register_symbol{ + if not squeeze_function_ids[1] then + squeeze_function_ids[1] = rspamd_config:register_symbol{ type = 'callback', callback = lua_squeeze_function, - name = 'LUA_SQUEEZE', + name = squeeze_sym, description = 'Meta rule for Lua rules that can be squeezed', no_squeeze = true, -- to avoid infinite recursion } end - return squeeze_function_id + return squeeze_function_ids[1] end -exports.squeeze_dependency = function(from, to) - logger.debugm(SN, rspamd_config, 'squeeze dep %s->%s', to, from) +exports.squeeze_dependency = function(child, parent) + logger.debugm(SN, rspamd_config, 'squeeze dep %s->%s', child, parent) - if not squeezed_deps[to] then - squeezed_deps[to] = {} + if not squeezed_deps[parent] then + squeezed_deps[parent] = {} end - if not squeezed_symbols[to][from] then - squeezed_symbols[to][from] = true + if not squeezed_deps[parent][child] then + squeezed_deps[parent][child] = true else - logger.warnx('duplicate dependency %s->%s', to, from) + logger.warnx(rspamd_config, 'duplicate dependency %s->%s', child, parent) end return true end +local function get_ordered_symbol_name(order) + if order == 0 then + return squeeze_sym + end + + return squeeze_sym .. tostring(order) +end + +local function register_topology_symbol(order) + local ord_sym = get_ordered_symbol_name(order) + + squeeze_function_ids[order + 1] = rspamd_config:register_symbol{ + type = 'callback', + callback = lua_squeeze_function, + name = ord_sym, + description = 'Meta rule for Lua rules that can be squeezed, order ' .. tostring(order), + no_squeeze = true, -- to avoid infinite recursion + } + + local parent = get_ordered_symbol_name(order - 1) + logger.debugm(SN, rspamd_config, 'registered new order of deps: %s->%s', + ord_sym, parent) + rspamd_config:register_dependency(ord_sym, parent, true) +end + +exports.squeeze_init = function() + local max_topology_order = 0 + + for parent,children in pairs(squeezed_deps) do + if not squeezed_symbols[parent] then + -- Trivial case, external dependnency + logger.debugm(SN, rspamd_config, 'register external squeezed dependency on %s', + parent) + rspamd_config:register_dependency(squeeze_sym, parent, true) + else + -- Not so trivial case + local ps = squeezed_symbols[parent] + + for cld,_ in pairs(children) do + if squeezed_symbols[cld] then + -- Cross dependency + logger.debugm(SN, rspamd_config, 'cross dependency in squeezed symbols %s->%s', + cld, parent) + local order = math.max(ps.order + 1, squeezed_symbols[cld].order) + squeezed_symbols[cld].order = order + if order > max_topology_order then + -- Need to register new callback symbol to handle deps + register_topology_symbol(order) + max_topology_order = order + end + else + -- External symbol depends on a squeezed one + local parent_symbol = get_ordered_symbol_name(ps.order) + rspamd_config:register_dependency(cld, parent_symbol, true) + logger.debugm(SN, rspamd_config, 'register squeezed dependency for external symbol %s->%s', + cld, parent_symbol) + end + end + end + end + + -- We have now all deps being registered, so we can register virtual symbols + -- and create squeezed rules + for k,v in pairs(squeezed_symbols) do + local parent_symbol = get_ordered_symbol_name(v.order) + logger.debugm(SN, rspamd_config, 'added squeezed rule: %s (%s)', k, parent_symbol) + rspamd_config:register_symbol{ + type = 'virtual', + name = k, + parent = squeeze_function_ids[v.order + 1], + no_squeeze = true, -- to avoid infinite recursion + } + table.insert(squeezed_rules, {v.cb,k}) + end +end + return exports \ No newline at end of file diff --git a/src/libserver/cfg_utils.c b/src/libserver/cfg_utils.c index c042c5eb2..10503dc21 100644 --- a/src/libserver/cfg_utils.c +++ b/src/libserver/cfg_utils.c @@ -753,6 +753,27 @@ rspamd_config_post_load (struct rspamd_config *cfg, } if (opts & RSPAMD_CONFIG_INIT_SYMCACHE) { + lua_State *L = cfg->lua_state; + int err_idx; + + /* Process squeezed Lua rules */ + lua_pushcfunction (L, &rspamd_lua_traceback); + err_idx = lua_gettop (L); + + if (rspamd_lua_require_function (cfg->lua_state, "lua_squeeze_rules", + "squeeze_init")) { + if (lua_pcall (L, 0, 0, err_idx) != 0) { + GString *tb = lua_touserdata (L, -1); + msg_err_config ("call to squeeze_init script failed: %v", tb); + + if (tb) { + g_string_free (tb, TRUE); + } + } + } + + lua_settop (L, err_idx - 1); + /* Init config cache */ rspamd_symbols_cache_init (cfg->cache); diff --git a/src/lua/lua_config.c b/src/lua/lua_config.c index d6a4ad550..6abbf7d0e 100644 --- a/src/lua/lua_config.c +++ b/src/lua/lua_config.c @@ -1818,24 +1818,21 @@ lua_config_register_callback_symbol_priority (lua_State * L) static gboolean rspamd_lua_squeeze_dependency (lua_State *L, struct rspamd_config *cfg, - const gchar *name, - const gchar *from) + const gchar *child, + const gchar *parent) { gint err_idx; gboolean ret = FALSE; + g_assert (parent != NULL); + g_assert (child != NULL); + lua_pushcfunction (L, &rspamd_lua_traceback); err_idx = lua_gettop (L); if (rspamd_lua_require_function (L, "lua_squeeze_rules", "squeeze_dependency")) { - lua_pushstring (L, name); - - if (from) { - lua_pushstring (L, from); - } - else { - lua_pushnil (L); - } + lua_pushstring (L, child); + lua_pushstring (L, parent); if (lua_pcall (L, 2, 1, err_idx) != 0) { GString *tb = lua_touserdata (L, -1); @@ -1859,8 +1856,8 @@ static gint lua_config_register_dependency (lua_State * L) { struct rspamd_config *cfg = lua_check_config (L, 1); - const gchar *name = NULL, *from = NULL; - gint id; + const gchar *parent = NULL, *child = NULL; + gint child_id; gboolean skip_squeeze = FALSE; if (cfg == NULL) { @@ -1869,37 +1866,38 @@ lua_config_register_dependency (lua_State * L) } if (lua_type (L, 2) == LUA_TNUMBER) { - id = luaL_checknumber (L, 2); - name = luaL_checkstring (L, 3); + child_id = luaL_checknumber (L, 2); + parent = luaL_checkstring (L, 3); if (lua_isboolean (L, 4)) { skip_squeeze = lua_toboolean (L, 4); } msg_warn_config ("calling for obsolete method to register deps for symbol %d->%s", - id, name); + child_id, parent); - if (id > 0 && name != NULL) { + if (child_id > 0 && parent != NULL) { - if (skip_squeeze || !rspamd_lua_squeeze_dependency (L, cfg, name, - rspamd_symbols_cache_symbol_by_id (cfg->cache, id))) { - rspamd_symbols_cache_add_dependency (cfg->cache, id, name); + if (skip_squeeze || !rspamd_lua_squeeze_dependency (L, cfg, + rspamd_symbols_cache_symbol_by_id (cfg->cache, child_id), + parent)) { + rspamd_symbols_cache_add_dependency (cfg->cache, child_id, parent); } } } else { - from = luaL_checkstring (L,2); - name = luaL_checkstring (L, 3); + child = luaL_checkstring (L,2); + parent = luaL_checkstring (L, 3); if (lua_isboolean (L, 4)) { skip_squeeze = lua_toboolean (L, 4); } - if (from != NULL && name != NULL) { + if (child != NULL && child != NULL) { - if (skip_squeeze || !rspamd_lua_squeeze_dependency (L, cfg, name, from)) { - rspamd_symbols_cache_add_delayed_dependency (cfg->cache, from, - name); + if (skip_squeeze || !rspamd_lua_squeeze_dependency (L, cfg, child, parent)) { + rspamd_symbols_cache_add_delayed_dependency (cfg->cache, child, + parent); } } -- 2.39.5