--[[
Copyright (c) 2018, Vsevolod Stakhov <vsevolod@highsecure.ru>

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
]]--
-
- local exports = {}
- local logger = require 'rspamd_logger'
-
- -- Squeezed rules part
- local squeezed_rules = {{}} -- plain vector of all rules squeezed
- local squeezed_symbols = {} -- indexed by name of symbol
- local squeezed_deps = {} -- squeezed deps
- local squeezed_rdeps = {} -- squeezed reverse deps
- local SN = 'lua_squeeze'
- local squeeze_sym = 'LUA_SQUEEZE'
- local squeeze_function_ids = {}
- local squeezed_groups = {}
-
- local function gen_lua_squeeze_function(order)
- return function(task)
- local symbols_disabled = task:cache_get('squeezed_disable')
- for _,data in ipairs(squeezed_rules[order]) do
- if not symbols_disabled or not symbols_disabled[data[2]] then
- local function real_call()
- return {data[1](task)}
- end
-
- -- Too expensive to call :(
- --logger.debugm(SN, task, 'call for: %s', data[2])
- local status, ret = pcall(real_call)
-
- if not status then
- logger.errx(task, 'error in squeezed rule %s: %s', data[2], ret)
- else
- if #ret ~= 0 then
- local first = ret[1]
- local sym = data[2]
- -- Function has returned something, so it is rule, not a plugin
- if type(first) == 'boolean' then
- if first then
- table.remove(ret, 1)
-
- local second = ret[1]
-
- if type(second) == 'number' then
- table.remove(ret, 1)
- if second ~= 0 then
- if type(ret[1]) == 'table' then
- task:insert_result(sym, second, ret[1])
- else
- task:insert_result(sym, second, ret)
- end
- end
- else
- if type(ret[1]) == 'table' then
- task:insert_result(sym, 1.0, ret[1])
- else
- task:insert_result(sym, 1.0, ret)
- end
- end
- end
- elseif type(first) == 'number' then
- table.remove(ret, 1)
-
- if first ~= 0 then
- if type(ret[1]) == 'table' then
- task:insert_result(sym, first, ret[1])
- else
- task:insert_result(sym, first, ret)
- end
- end
- else
- if type(ret[1]) == 'table' then
- task:insert_result(sym, 1.0, ret[1])
- else
- task:insert_result(sym, 1.0, ret)
- end
- end
- end
- end
- else
- logger.debugm(SN, task, 'skip symbol due to settings: %s', data[2])
- end
-
-
- end
- end
- end
-
- exports.squeeze_rule = function(s, func)
- if s then
- if not squeezed_symbols[s] then
- squeezed_symbols[s] = {
- cb = func,
- order = 0,
- sym = s,
- }
- logger.debugm(SN, rspamd_config, 'squeezed rule: %s', s)
- else
- logger.warnx(rspamd_config, 'duplicate symbol registered: %s, skip', s)
- end
- else
- -- Unconditionally add function to the squeezed rules
- local id = tostring(#squeezed_rules)
- logger.debugm(SN, rspamd_config, 'squeezed unnamed rule: %s', id)
- table.insert(squeezed_rules[1], {func, 'unnamed: ' .. id})
- end
-
- if not squeeze_function_ids[1] then
- squeeze_function_ids[1] = rspamd_config:register_symbol{
- type = 'callback',
- flags = 'squeezed',
- callback = gen_lua_squeeze_function(1),
- name = squeeze_sym,
- description = 'Meta rule for Lua rules that can be squeezed',
- no_squeeze = true, -- to avoid infinite recursion
- }
- end
-
- return squeeze_function_ids[1]
- end
-
- exports.squeeze_dependency = function(child, parent)
- logger.debugm(SN, rspamd_config, 'squeeze dep %s->%s', child, parent)
-
- if not squeezed_deps[parent] then
- squeezed_deps[parent] = {}
- end
-
- if not squeezed_deps[parent][child] then
- squeezed_deps[parent][child] = true
- else
- logger.warnx(rspamd_config, 'duplicate dependency %s->%s', child, parent)
- end
-
- if not squeezed_rdeps[child] then
- squeezed_rdeps[child] = {}
- end
-
- if not squeezed_rdeps[child][parent] then
- squeezed_rdeps[child][parent] = true
- end
-
- return true
- end
-
- local function get_ordered_symbol_name(order)
- if order == 1 then
- return squeeze_sym
- end
-
- return squeeze_sym .. tostring(order)
- end
-
- local function register_topology_symbol(order)
- local ord_sym = get_ordered_symbol_name(order)
-
- squeeze_function_ids[order] = rspamd_config:register_symbol{
- type = 'callback',
- flags = 'squeezed',
- callback = gen_lua_squeeze_function(order),
- name = ord_sym,
- description = 'Meta rule for Lua rules that can be squeezed, order ' .. tostring(order),
- no_squeeze = true, -- to avoid infinite recursion
- }
-
- local parent = get_ordered_symbol_name(order - 1)
- logger.debugm(SN, rspamd_config, 'registered new order of deps: %s->%s',
- ord_sym, parent)
- rspamd_config:register_dependency(ord_sym, parent, true)
- end
-
- exports.squeeze_init = function()
- -- Do topological sorting
- for _,v in pairs(squeezed_symbols) do
- local function visit(node, order)
-
- if order > node.order then
- node.order = order
- logger.debugm(SN, rspamd_config, "symbol: %s, order: %s", node.sym, order)
- else
- return
- end
-
- if squeezed_deps[node.sym] then
- for dep,_ in pairs(squeezed_deps[node.sym]) do
- if squeezed_symbols[dep] then
- visit(squeezed_symbols[dep], order + 1)
- end
- end
- end
- end
-
- if v.order == 0 then
- visit(v, 1)
- end
- end
-
- for parent,children in pairs(squeezed_deps) do
- if not squeezed_symbols[parent] then
- -- Trivial case, external dependnency
-
- for s,_ in pairs(children) do
-
- if squeezed_symbols[s] then
- -- External dep depends on a squeezed symbol
- logger.debugm(SN, rspamd_config, 'register external squeezed dependency on %s',
- parent)
- rspamd_config:register_dependency(squeeze_sym, parent, true)
- else
- -- Generic rspamd symbols dependency
- logger.debugm(SN, rspamd_config, 'register external dependency %s -> %s',
- s, parent)
- rspamd_config:register_dependency(s, parent, true)
- end
- end
- else
- -- Not so trivial case
- local ps = squeezed_symbols[parent]
-
- for cld,_ in pairs(children) do
- if squeezed_symbols[cld] then
- -- Cross dependency
- logger.debugm(SN, rspamd_config, 'cross dependency in squeezed symbols %s->%s',
- cld, parent)
- local order = squeezed_symbols[cld].order
- if not squeeze_function_ids[order] then
- -- Need to register new callback symbol to handle deps
- for i = 1, order do
- if not squeeze_function_ids[i] then
- register_topology_symbol(i)
- end
- end
- end
- else
- -- External symbol depends on a squeezed one
- local parent_symbol = get_ordered_symbol_name(ps.order)
- rspamd_config:register_dependency(cld, parent_symbol, true)
- logger.debugm(SN, rspamd_config, 'register squeezed dependency for external symbol %s->%s',
- cld, parent_symbol)
- end
- end
- end
- end
-
- -- We have now all deps being registered, so we can register virtual symbols
- -- and create squeezed rules
- for k,v in pairs(squeezed_symbols) do
- local parent_symbol = get_ordered_symbol_name(v.order)
- logger.debugm(SN, rspamd_config, 'added squeezed rule: %s (%s): %s',
- k, parent_symbol, v)
- rspamd_config:register_symbol{
- type = 'virtual',
- name = k,
- flags = 'squeezed',
- parent = squeeze_function_ids[v.order],
- no_squeeze = true, -- to avoid infinite recursion
- }
- local metric_sym = rspamd_config:get_metric_symbol(k)
-
- if metric_sym then
- v.group = metric_sym.group
- v.score = metric_sym.score
- v.description = metric_sym.description
-
- if not squeezed_groups[v.group] then
- logger.debugm(SN, rspamd_config, 'added squeezed group: %s', v.group)
- squeezed_groups[v.group] = {}
- end
-
- table.insert(squeezed_groups[v.group], k)
- else
- logger.debugm(SN, rspamd_config, 'no metric symbol found for %s, maybe bug', k)
- end
- if not squeezed_rules[v.order] then
- squeezed_rules[v.order] = {}
- end
- table.insert(squeezed_rules[v.order], {v.cb,k,v})
- end
- end
-
- exports.handle_settings = function(task, settings)
- local symbols_disabled = {}
- local symbols_enabled = {}
- local found = false
-
- if settings.default then settings = settings.default end
-
- local function disable_all()
- for k,_ in pairs(squeezed_symbols) do
- if not symbols_enabled[k] then
- symbols_disabled[k] = true
- end
- end
- end
-
- if settings.symbols_enabled then
- disable_all()
- found = true
- for _,s in ipairs(settings.symbols_enabled) do
- if squeezed_symbols[s] then
- logger.debugm(SN, task, 'enable symbol %s as it is in `symbols_enabled`', s)
- symbols_enabled[s] = true
- symbols_disabled[s] = nil
- end
- end
- end
-
- if settings.groups_enabled then
- disable_all()
- found = true
- for _,gr in ipairs(settings.groups_enabled) do
- if squeezed_groups[gr] then
- for _,sym in ipairs(squeezed_groups[gr]) do
- logger.debugm(SN, task, 'enable symbol %s as it is in `groups_enabled`', sym)
- symbols_enabled[sym] = true
- symbols_disabled[sym] = nil
- end
- end
- end
- end
-
- if settings.symbols_disabled then
- found = true
- for _,s in ipairs(settings.symbols_disabled) do
- if not symbols_enabled[s] then
- symbols_disabled[s] = true
- logger.debugm(SN, task, 'disable symbol %s as it is in `symbols_disabled`', s)
- end
- end
- end
-
- if settings.groups_disabled then
- found = true
- for _,gr in ipairs(settings.groups_disabled) do
- if squeezed_groups[gr] then
- for _,sym in ipairs(squeezed_groups[gr]) do
- if not symbols_enabled[sym] then
- logger.debugm(SN, task, 'disable symbol %s as it is in `groups_disabled`', sym)
- symbols_disabled[sym] = true
- end
- end
- end
- end
- end
-
- if found then
- task:cache_set('squeezed_disable', symbols_disabled)
- end
- end
-
- return exports
|