aboutsummaryrefslogtreecommitdiffstats
path: root/src/plugins/lua/trie.lua
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2015-04-07 17:22:40 +0100
committerVsevolod Stakhov <vsevolod@highsecure.ru>2015-04-07 17:22:40 +0100
commitfacb4c5d4aaac450609f7f845a6d42b7518008b6 (patch)
treecbf306f92cb8899a664fbc646ef110470e07c7ee /src/plugins/lua/trie.lua
parent125f1805175b7760930f404578642066d0981eb5 (diff)
downloadrspamd-facb4c5d4aaac450609f7f845a6d42b7518008b6.tar.gz
rspamd-facb4c5d4aaac450609f7f845a6d42b7518008b6.zip
Rewrite trie plugin.
Diffstat (limited to 'src/plugins/lua/trie.lua')
-rw-r--r--src/plugins/lua/trie.lua212
1 files changed, 115 insertions, 97 deletions
diff --git a/src/plugins/lua/trie.lua b/src/plugins/lua/trie.lua
index a66ca0877..81c89fde4 100644
--- a/src/plugins/lua/trie.lua
+++ b/src/plugins/lua/trie.lua
@@ -26,114 +26,132 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-- Trie is rspamd module designed to define and operate with suffix trie
-local tries = {}
local rspamd_logger = require "rspamd_logger"
local rspamd_trie = require "rspamd_trie"
+local _ = require "fun"
-local function split(str, delim, maxNb)
- -- Eliminate bad cases...
- if string.find(str, delim) == nil then
- return { str }
- end
- if maxNb == nil or maxNb < 1 then
- maxNb = 0 -- No limit
- end
- local result = {}
- local pat = "(.-)" .. delim .. "()"
- local nb = 0
- local lastPos
- for part, pos in string.gmatch(str, pat) do
- nb = nb + 1
- result[nb] = part
- lastPos = pos
- if nb == maxNb then break end
- end
- -- Handle the last field
- if nb ~= maxNb then
- result[nb + 1] = string.sub(str, lastPos)
- end
- return result
+local mime_trie
+local raw_trie
+
+-- here we store all patterns as text
+local mime_patterns = {}
+local raw_patterns = {}
+
+-- here we store params for each pattern, so for each i = 1..n patterns[i]
+-- should have corresponding params[i]
+local mime_params = {}
+local raw_params = {}
+
+local function tries_callback(task)
+
+ local matched = {}
+
+ local function gen_trie_cb(raw)
+ local patterns = mime_patterns
+ local params = mime_params
+ if raw then
+ patterns = raw_patterns
+ params = raw_params
+ end
+
+ return function (idx, pos)
+ local param = params[idx]
+ local pattern = patterns[idx]
+
+ rspamd_logger.debugx("<%1> matched pattern %2 at pos %3",
+ task:get_message_id(), pattern, pos)
+
+ if params['multi'] or not matched[pattern] then
+ task:insert_result(params['symbol'], 1.0)
+ if not params['multi'] then
+ matched[pattern] = true
+ end
+ end
+ end
+ end
+
+ if mime_trie then
+ mime_trie:search_mime(task, gen_trie_cb(false))
+ end
+ if raw_trie then
+ raw_trie:search_rawmsg(task, gen_trie_cb(true))
+ end
end
-local function add_trie(params)
- local symbol = params[1]
-
- file = io.open(params[2])
- if file then
- local trie = {}
- trie['trie'] = rspamd_trie.create(true)
- num = 0
- for line in file:lines() do
- trie['trie']:add_pattern(line, num)
- num = num + 1
- end
-
- if type(rspamd_config.get_api_version) ~= 'nil' then
- rspamd_config:register_virtual_symbol(symbol, 1.0)
- end
- file:close()
- trie['symbol'] = symbol
- table.insert(tries, trie)
- else
- local patterns = split(params[2], ',')
- local trie = {}
- trie['trie'] = rspamd_trie.create(true)
- for num,pattern in ipairs(patterns) do
- trie['trie']:add_pattern(pattern, num)
- end
- if type(rspamd_config.get_api_version) ~= 'nil' then
- rspamd_config:register_virtual_symbol(symbol, 1.0)
- end
- trie['symbol'] = symbol
- table.insert(tries, trie)
- end
+local function process_single_pattern(pat, symbol, cf)
+ if pat then
+ if cf['raw'] then
+ table.insert(raw_patterns, pat)
+ table.insert(raw_params, {symbol=symbol, multi=multi})
+ else
+ table.insert(mime_patterns, pat)
+ table.insert(mime_params, {symbol=symbol, multi=multi})
+ end
+ end
end
-function check_trie(task)
- for _,trie in ipairs(tries) do
- if trie['trie']:search_task(task) then
- task:insert_result(trie['symbol'], 1)
- return
- end
- -- Search inside urls
- urls = task:get_urls()
- if urls then
- for _,url in ipairs(urls) do
- if trie['trie']:search_text(url:get_text()) then
- task:insert_result(trie['symbol'], 1)
- return
- end
- end
- end
- end
+local function process_trie_file(symbol, cf)
+ file = io.open(cf['file'])
+
+ if not file then
+ rspamd_logger.errx('Cannot open trie file %1', cf['file'])
+ else
+ if cf['binary'] then
+ rspamd_logger.errx('binary trie patterns are not implemented yet: %1',
+ cf['file'])
+ else
+ local multi = false
+ if cf['multi'] then multi = true end
+
+ for line in file:lines() do
+ local pat = string.match(line, '^([^#].*[^%s])%s*$')
+ process_single_pattern(pat, symbol, cf)
+ end
+ end
+ end
end
--- Registration
-if type(rspamd_config.get_api_version) ~= 'nil' then
- if rspamd_config:get_api_version() >= 1 then
- rspamd_config:register_module_option('trie', 'rule', 'string')
- end
+local function process_trie_conf(symbol, cf)
+ local raw = false
+
+ if type(cf) ~= 'table' then
+ rspamd_logger.errx('invalid value for symbol %1: "%2", expected table',
+ symbol, cf)
+ return
+ end
+
+ if cf['raw'] then raw = true end
+
+ if cf['file'] then
+ process_trie_file(symbol, cf)
+ elseif cf['patterns'] then
+ _.each(function(pat)
+ process_single_pattern(pat, symbol, cf)
+ end, cf['patterns'])
+ end
+
+ rspamd_config:register_virtual_symbol(symbol, 1.0)
end
-local opts = rspamd_config:get_all_opt('trie')
+local opts = rspamd_config:get_key("trie")
if opts then
- local strrules = opts['rule']
- if strrules then
- if type(strrules) == 'table' then
- for _,value in ipairs(strrules) do
- local params = split(value, ':')
- add_trie(params)
- end
- elseif type(strrules) == 'string' then
- local params = split(strrules, ':')
- add_trie (params)
- end
- end
- if table.maxn(tries) then
- if type(rspamd_config.get_api_version) ~= 'nil' then
- rspamd_config:register_callback_symbol('TRIE', 1.0, 'check_trie')
- else
- rspamd_config:register_symbol('TRIE', 1.0, 'check_trie')
- end
+ for sym, opt in pairs(opts) do
+ process_trie_conf(sym, opt)
+ end
+
+ if #raw_patterns > 0 then
+ raw_trie = rspamd_trie.create(raw_patterns)
+ rspamd_logger.infox('registered raw search trie from %1 patterns', #raw_patterns)
end
+
+ if #mime_patterns > 0 then
+ mime_trie = rspamd_trie.create(mime_patterns)
+ rspamd_logger.infox('registered mime search trie from %1 patterns', #mime_patterns)
+ end
+
+ if mime_trie or raw_trie then
+ rspamd_config:register_callback_symbol('TRIE', 1.0, tries_callback)
+ else
+ rspamd_logger.err('no tries defined')
+ end
end