diff options
-rw-r--r-- | lualib/lua_squeeze_rules.lua | 130 | ||||
-rw-r--r-- | src/plugins/lua/settings.lua | 6 |
2 files changed, 114 insertions, 22 deletions
diff --git a/lualib/lua_squeeze_rules.lua b/lualib/lua_squeeze_rules.lua index 9bcccf6f0..e310986d6 100644 --- a/lualib/lua_squeeze_rules.lua +++ b/lualib/lua_squeeze_rules.lua @@ -25,37 +25,46 @@ local squeezed_rdeps = {} -- squeezed reverse deps local SN = 'lua_squeeze' local squeeze_sym = 'LUA_SQUEEZE' local squeeze_function_ids = {} +local squeezed_groups = {} local function gen_lua_squeeze_function(order) return function(task) + local symbols_disabled = task:cache_get('squeezed_disable') for _,data in ipairs(squeezed_rules[order]) do - local ret = {data[1](task)} - - if #ret ~= 0 then - local first = ret[1] - local sym = data[2] - -- Function has returned something, so it is rule, not a plugin - if type(first) == 'boolean' then - if first then - table.remove(ret, 1) - if type(ret[1]) == 'table' then - task:insert_result(sym, 1.0, ret[1]) - else - task:insert_result(sym, 1.0, ret) + if not symbols_disabled or not symbols_disabled[data[2]] then + local ret = {data[1](task)} + + if #ret ~= 0 then + local first = ret[1] + local sym = data[2] + -- Function has returned something, so it is rule, not a plugin + if type(first) == 'boolean' then + if first then + table.remove(ret, 1) + + if type(ret[1]) == 'table' then + task:insert_result(sym, 1.0, ret[1]) + else + task:insert_result(sym, 1.0, ret) + end end - end - elseif type(first) == 'number' then - table.remove(ret, 1) + elseif type(first) == 'number' then + table.remove(ret, 1) - if first ~= 0 then - task:insert_result(sym, first, ret) + if first ~= 0 then + task:insert_result(sym, first, ret) + end + else + task:insert_result(sym, 1.0, ret) end - else - task:insert_result(sym, 1.0, ret) end + else + logger.debugm(SN, task, 'skip symbol due to settings: %s', data[2]) + end + + end - end end end @@ -224,10 +233,87 @@ exports.squeeze_init = function() parent = squeeze_function_ids[v.order], no_squeeze = true, -- to avoid infinite recursion } + local metric_sym = rspamd_config:get_metric_symbol(k) + + if metric_sym then + v.group = metric_sym.group + v.score = metric_sym.score + v.description = description + + if not squeezed_groups[v.group] then + logger.debugm(SN, rspamd_config, 'added squeezed group: %s', v.group) + squeezed_groups[v.group] = {} + end + + table.insert(squeezed_groups[v.group], v) + end if not squeezed_rules[v.order] then squeezed_rules[v.order] = {} end - table.insert(squeezed_rules[v.order], {v.cb,k}) + table.insert(squeezed_rules[v.order], {v.cb,k,v}) + end +end + +exports.handle_settings = function(task, settings) + local symbols_disabled = {} + local symbols_enabled = {} + local found = false + + if settings.default then settings = settings.default end + + if settings.symbols_enabled then + for k,v in squeezed_symbols do + if not settings.symbols_enabled[k] then + symbols_disabled[k] = true + found = true + else + symbols_enabled[k] = true + end + end + end + + if settings.groups_enabled then + for k,syms in pairs(squeezed_groups) do + if not settings.groups_enabled[k] then + for _,sym in ipairs(syms) do + if not symbols_enabled[sym] then + symbols_disabled[sym] = true + found = true + end + end + else + for _,sym in ipairs(syms) do + if symbols_disabled[sym] then + symbols_disabled[sym] = nil + end + symbols_enabled[sym] = true + end + end + end + end + + if settings.symbols_disabled then + for k,v in squeezed_symbols do + if settings.symbols_disabled[k] then + symbols_disabled[k] = true + found = true + end + end + end + + if settings.groups_disabled then + for k,syms in pairs(squeezed_groups) do + if settings.groups_disabled[k] then + for _,sym in ipairs(syms) do + symbols_disabled[sym] = true + found = true + end + end + end + end + + if found then + task:cache_set('squeezed_disable', symbols_disabled) end end diff --git a/src/plugins/lua/settings.lua b/src/plugins/lua/settings.lua index 13647c480..3ebf83f99 100644 --- a/src/plugins/lua/settings.lua +++ b/src/plugins/lua/settings.lua @@ -24,6 +24,7 @@ end local rspamd_logger = require "rspamd_logger" local rspamd_maps = require "lua_maps" +local lua_squeeze = require "lua_squeeze_rules" local redis_params local settings = {} @@ -47,6 +48,7 @@ local function check_query_settings(task) local settings_obj = parser:get_object() task:set_settings(settings_obj) task:cache_set('settings', settings_obj) + lua_squeeze.handle_settings(task, settings_obj) return true else @@ -75,6 +77,7 @@ local function check_query_settings(task) task:set_settings(nset) task:cache_set('settings', nset) + lua_squeeze.handle_settings(task, nset) return true end @@ -87,6 +90,7 @@ local function check_query_settings(task) local elt = settings_ids[id_str] if elt and elt['apply'] then task:set_settings(elt['apply']) + lua_squeeze.handle_settings(task, elt['apply']) task:cache_set('settings', elt['apply']) if elt.apply['add_headers'] or elt.apply['remove_headers'] then @@ -336,6 +340,7 @@ local function check_settings(task) task:get_message_id(), s.name) if rule['apply'] then task:set_settings(rule['apply']) + lua_squeeze.handle_settings(task, rule['apply']) task:cache_set('settings', rule['apply']) applied = true end @@ -648,6 +653,7 @@ local function gen_redis_callback(handler, id) rspamd_logger.infox(task, "<%1> apply settings according to redis rule %2", task:get_message_id(), id) task:set_settings(obj) + lua_squeeze.handle_settings(task, obj) task:cache_set('settings', obj) break end |