aboutsummaryrefslogtreecommitdiffstats
path: root/src/plugins/lua/trie.lua
blob: d0d77e7960ecb7f553793d2fd808a71f2bd9b4b0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
-- Trie is an rspamd module for defining and operating on suffix tries

local tries = {}
local rspamd_logger = require "rspamd_logger"
local rspamd_trie = require "rspamd_trie"

--- Split a string on a literal delimiter.
-- @tparam string str text to split
-- @tparam string delim delimiter, matched literally (not as a Lua pattern)
-- @tparam[opt] number maxNb if >= 1, return at most this many fields and
--   discard any text after the maxNb-th field; nil or < 1 means no limit
-- @treturn table sequence of fields; `{ str }` when `delim` is empty or
--   does not occur in `str`
local function split(str, delim, maxNb)
	-- Eliminate bad cases: an empty delimiter would match at every position,
	-- and a string without the delimiter is a single field.
	if delim == '' or string.find(str, delim, 1, true) == nil then
		return { str }
	end
	if maxNb == nil or maxNb < 1 then
		maxNb = 0    -- No limit
	end
	-- Escape punctuation so magic characters such as '.', '%' or '-' in the
	-- delimiter are matched literally inside the gmatch pattern below.
	local escaped = string.gsub(delim, '%p', '%%%0')
	local result = {}
	local pat = "(.-)" .. escaped .. "()"
	local nb = 0
	local lastPos
	for part, pos in string.gmatch(str, pat) do
		nb = nb + 1
		result[nb] = part
		lastPos = pos
		if nb == maxNb then break end
	end
	-- Handle the last field (the text after the final delimiter)
	if nb ~= maxNb then
		result[nb + 1] = string.sub(str, lastPos)
	end
	return result
end

--- Build a trie from one `SYMBOL:source` rule and append it to `tries`.
-- `source` is first tried as a file path (one pattern per line); if the file
-- cannot be opened, it is treated as a comma-separated list of patterns.
-- @tparam table params rule fields, as produced by `split(rule, ':')`
local function add_trie(params)
	local symbol = params[1]
	-- split() cuts on every ':', so a source containing ':' (for example a
	-- pattern like "http://...") arrives in several fields; join them back
	-- rather than silently dropping everything after the second ':'.
	local source = nil
	if #params >= 2 then
		source = table.concat(params, ':', 2)
	end
	if symbol == nil or symbol == '' or source == nil or source == '' then
		rspamd_logger.err('trie: cannot parse rule, expected SYMBOL:patterns')
		return
	end

	-- Collect the patterns, keeping the original id numbering:
	-- 0-based for file lines, 1-based for inline patterns.
	local patterns
	local id_offset
	local file = io.open(source)
	if file then
		patterns = {}
		for line in file:lines() do
			patterns[#patterns + 1] = line
		end
		file:close()
		id_offset = -1
	else
		patterns = split(source, ',')
		id_offset = 0
	end

	local trie = rspamd_trie.create(true)
	for idx, pattern in ipairs(patterns) do
		-- Skip empty entries (blank lines, stray commas): they carry no text
		-- to match. NOTE(review): confirm how rspamd_trie treats ''.
		if pattern ~= '' then
			trie:add_pattern(pattern, idx + id_offset)
		end
	end

	-- The virtual symbol is registered only when the config exposes
	-- get_api_version.
	if type(rspamd_config.get_api_version) ~= 'nil' then
		rspamd_config:register_virtual_symbol(symbol, 1.0)
	end
	table.insert(tries, { trie = trie, symbol = symbol })
end

--- Per-message callback: insert the symbol of every rule whose trie matches.
-- A trie matches when `search_task` finds a pattern in the task, or when
-- `search_text` finds one in the text of any URL returned by
-- `task:get_urls()`. Kept global because it is registered by name.
-- @param task rspamd task object
function check_trie(task)
	local urls -- fetched lazily, at most once per task

	-- Returns true when the given rule's trie matches the task or a URL.
	local function trie_matches(rule)
		if rule['trie']:search_task(task) then
			return true
		end
		-- Search inside urls
		if urls == nil then
			urls = task:get_urls() or {}
		end
		for _, url in ipairs(urls) do
			if rule['trie']:search_text(url:get_text()) then
				return true
			end
		end
		return false
	end

	-- Evaluate every rule: each has its own symbol, so one match must not
	-- prevent the remaining rules from being checked.
	for _, rule in ipairs(tries) do
		if trie_matches(rule) then
			task:insert_result(rule['symbol'], 1)
		end
	end
end

-- Registration
-- Declare the 'rule' option only on configs that expose API version >= 1.
if type(rspamd_config.get_api_version) ~= 'nil' then
	if rspamd_config:get_api_version() >= 1 then
		rspamd_config:register_module_option('trie', 'rule', 'string')
	end
end

local opts = rspamd_config:get_all_opt('trie')
if opts then
	-- opts['rule'] is handled either as a table of rule strings or as a
	-- single rule string; each rule has the form SYMBOL:source.
	local strrules = opts['rule']
	if type(strrules) == 'table' then
		for _, value in ipairs(strrules) do
			add_trie(split(value, ':'))
		end
	elseif type(strrules) == 'string' then
		add_trie(split(strrules, ':'))
	end
	-- table.maxn() returns 0 for an empty table and 0 is truthy in Lua, so
	-- the previous check always registered the callback. Register it only
	-- when at least one rule was actually loaded.
	if #tries > 0 then
		if type(rspamd_config.get_api_version) ~= 'nil' then
			rspamd_config:register_callback_symbol('TRIE', 1.0, 'check_trie')
		else
			rspamd_config:register_symbol('TRIE', 1.0, 'check_trie')
		end
	end
end