]> source.dussan.org Git - rspamd.git/commitdiff
Rewrite trie plugin.
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 7 Apr 2015 16:22:40 +0000 (17:22 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 7 Apr 2015 16:22:40 +0000 (17:22 +0100)
src/plugins/lua/trie.lua

index a66ca0877767154fdfe841631897c6706d7e27d5..81c89fde49934e199be7d99ba23ea20528f80649 100644 (file)
@@ -26,114 +26,132 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 -- Trie is rspamd module designed to define and operate with suffix trie
 
-local tries = {}
 local rspamd_logger = require "rspamd_logger"
 local rspamd_trie = require "rspamd_trie"
+local _ = require "fun"
 
-local function split(str, delim, maxNb)
-       -- Eliminate bad cases...
-       if string.find(str, delim) == nil then
-               return { str }
-       end
-       if maxNb == nil or maxNb < 1 then
-               maxNb = 0    -- No limit
-       end
-       local result = {}
-       local pat = "(.-)" .. delim .. "()"
-       local nb = 0
-       local lastPos
-       for part, pos in string.gmatch(str, pat) do
-               nb = nb + 1
-               result[nb] = part
-               lastPos = pos
-               if nb == maxNb then break end
-       end
-       -- Handle the last field
-       if nb ~= maxNb then
-               result[nb + 1] = string.sub(str, lastPos)
-       end
-       return result
+local mime_trie
+local raw_trie
+
+-- here we store all patterns as text
+local mime_patterns = {}
+local raw_patterns = {}
+
+-- here we store params for each pattern, so for each i = 1..n patterns[i] 
+-- should have corresponding params[i]
+local mime_params = {}
+local raw_params = {}
+
+local function tries_callback(task)
+  
+  local matched = {}
+
+  local function gen_trie_cb(raw)
+    local patterns = mime_patterns
+    local params = mime_params
+    if raw then
+      patterns = raw_patterns
+      params = raw_params
+    end
+    
+    return function (idx, pos)
+      local param = params[idx]
+      local pattern = patterns[idx]
+      
+      rspamd_logger.debugx("<%1> matched pattern %2 at pos %3",
+        task:get_message_id(), pattern, pos)
+      
+      if params['multi'] or not matched[pattern] then
+        task:insert_result(params['symbol'], 1.0)
+        if not params['multi'] then
+          matched[pattern] = true
+        end
+      end
+    end
+  end
+  
+  if mime_trie then
+    mime_trie:search_mime(task, gen_trie_cb(false))
+  end
+  if raw_trie then
+    raw_trie:search_rawmsg(task, gen_trie_cb(true))
+  end
 end
 
-local function add_trie(params)
-       local symbol = params[1]
-       
-       file = io.open(params[2])
-       if file then
-               local trie = {}
-               trie['trie'] = rspamd_trie.create(true)
-               num = 0
-               for line in file:lines() do
-                       trie['trie']:add_pattern(line, num)
-                       num = num + 1
-               end
-               
-               if type(rspamd_config.get_api_version) ~= 'nil' then
-                       rspamd_config:register_virtual_symbol(symbol, 1.0)
-               end
-               file:close()
-               trie['symbol'] = symbol
-               table.insert(tries, trie)
-       else
-               local patterns = split(params[2], ',')
-               local trie = {}
-               trie['trie'] = rspamd_trie.create(true)
-               for num,pattern in ipairs(patterns) do
-                       trie['trie']:add_pattern(pattern, num)
-               end
-               if type(rspamd_config.get_api_version) ~= 'nil' then
-                       rspamd_config:register_virtual_symbol(symbol, 1.0)
-               end
-               trie['symbol'] = symbol
-               table.insert(tries, trie)
-       end
+local function process_single_pattern(pat, symbol, cf)
+  if pat then
+    if cf['raw'] then
+      table.insert(raw_patterns, pat)
+      table.insert(raw_params, {symbol=symbol, multi=multi})
+    else
+      table.insert(mime_patterns, pat)
+      table.insert(mime_params, {symbol=symbol, multi=multi})
+    end
+  end
 end
 
-function check_trie(task)
-       for _,trie in ipairs(tries) do
-               if trie['trie']:search_task(task) then
-                       task:insert_result(trie['symbol'], 1)
-                       return
-               end
-               -- Search inside urls
-               urls = task:get_urls()
-               if urls then
-                       for _,url in ipairs(urls) do
-                               if trie['trie']:search_text(url:get_text()) then
-                                       task:insert_result(trie['symbol'], 1)
-                                       return
-                               end
-                       end
-               end
-       end
+local function process_trie_file(symbol, cf)
+  file = io.open(cf['file'])
+  
+  if not file then
+    rspamd_logger.errx('Cannot open trie file %1', cf['file'])
+  else
+    if cf['binary'] then
+      rspamd_logger.errx('binary trie patterns are not implemented yet: %1', 
+        cf['file'])
+    else
+      local multi = false
+      if cf['multi'] then multi = true end
+      
+      for line in file:lines() do
+        local pat = string.match(line, '^([^#].*[^%s])%s*$')
+        process_single_pattern(pat, symbol, cf)
+      end
+    end
+  end
 end
 
--- Registration
-if type(rspamd_config.get_api_version) ~= 'nil' then
-       if rspamd_config:get_api_version() >= 1 then
-               rspamd_config:register_module_option('trie', 'rule', 'string')
-       end
+local function process_trie_conf(symbol, cf)
+  local raw = false
+  
+  if type(cf) ~= 'table' then
+    rspamd_logger.errx('invalid value for symbol %1: "%2", expected table', 
+      symbol, cf)
+    return
+  end
+  
+  if cf['raw'] then raw = true end
+  
+  if cf['file'] then
+    process_trie_file(symbol, cf)
+  elseif cf['patterns'] then
+    _.each(function(pat)
+      process_single_pattern(pat, symbol, cf)
+    end, cf['patterns'])
+  end
+  
+  rspamd_config:register_virtual_symbol(symbol, 1.0)
 end
 
-local opts =  rspamd_config:get_all_opt('trie')
+local opts =  rspamd_config:get_key("trie")
 if opts then
-       local strrules = opts['rule']
-       if strrules then
-               if type(strrules) == 'table' then 
-                       for _,value in ipairs(strrules) do
-                               local params = split(value, ':')
-                               add_trie(params)
-                       end
-               elseif type(strrules) == 'string' then
-                       local params = split(strrules, ':')
-                       add_trie (params)
-               end
-       end
-       if table.maxn(tries) then
-               if type(rspamd_config.get_api_version) ~= 'nil' then
-                       rspamd_config:register_callback_symbol('TRIE', 1.0, 'check_trie')
-               else
-                       rspamd_config:register_symbol('TRIE', 1.0, 'check_trie')
-               end
+  for sym, opt in pairs(opts) do
+     process_trie_conf(sym, opt)
+  end
+  
+  if #raw_patterns > 0 then
+    raw_trie = rspamd_trie.create(raw_patterns)
+    rspamd_logger.infox('registered raw search trie from %1 patterns', #raw_patterns)
        end
+
+  if #mime_patterns > 0 then
+    mime_trie = rspamd_trie.create(mime_patterns)
+    rspamd_logger.infox('registered mime search trie from %1 patterns', #mime_patterns)
+  end
+
+  if mime_trie or raw_trie then
+    rspamd_config:register_callback_symbol('TRIE', 1.0, tries_callback)
+  else
+    rspamd_logger.err('no tries defined')
+  end
 end