Du kan inte välja fler än 25 ämnen Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

lua_squeeze_rules.lua 12KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422
  1. --[[
  2. Copyright (c) 2018, Vsevolod Stakhov <vsevolod@highsecure.ru>
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ]]--
  -- Module table; the public squeeze API below is attached to it.
  local exports = {}
  local logger = require 'rspamd_logger'
  local lua_util = require 'lua_util'
  -- Squeezed rules part
  -- squeezed_rules[order] is a list of {callback, name, symbol_record} entries
  -- that the meta callback generated for `order` executes in sequence
  local squeezed_rules = {{}} -- plain vector of all rules squeezed
  -- symbol name -> record {cb, order, sym, flags, ...} created by squeeze_rule
  local squeezed_symbols = {} -- indexed by name of symbol
  -- squeezed_deps[parent][child] = true
  local squeezed_deps = {} -- squeezed deps
  -- squeezed_rdeps[child][parent] = true
  local squeezed_rdeps = {} -- squeezed reverse deps
  -- Module name passed to lua_util.debugm for debug logging
  local SN = 'lua_squeeze'
  -- Name of the meta symbol for order 1; higher orders append the order number
  local squeeze_sym = 'LUA_SQUEEZE'
  -- order -> id returned by rspamd_config:register_symbol for the meta callback
  local squeeze_function_ids = {}
  -- group name -> list of squeezed symbol names belonging to that group
  local squeezed_groups = {}
  -- Name given to the most recent squeeze_rule call (nil for unnamed rules)
  local last_rule
  -- virtual symbol name -> name of the squeezed rule it was attached to
  local virtual_symbols = {}
  -- Normalises the outcome of `pcall(rule, task)` without allocating a closure
  -- per call: returns `false, err` on failure or `true, {results...}` on success.
  local function collect_rule_results(status, ...)
    if not status then
      return false, (...)
    end
    return true, {...}
  end

  -- Inserts `sym` with `score` into the task. When the first remaining return
  -- value is a table it is used as the options, otherwise the whole list of
  -- remaining values is passed as options.
  local function insert_squeezed_result(task, sym, score, rest)
    if type(rest[1]) == 'table' then
      task:insert_result(sym, score, rest[1])
    else
      task:insert_result(sym, score, rest)
    end
  end

  -- Interprets the (non-empty) list of values returned by a squeezed rule:
  --   false, ...           -> rule did not match, nothing is inserted
  --   true, number, ...    -> insert with that score unless it is 0
  --   true, ...            -> insert with score 1.0
  --   number, ...          -> insert with that score unless it is 0
  --   anything else, ...   -> insert with score 1.0, all values are options
  -- `ret` is consumed (leading values are removed in place).
  local function process_rule_results(task, sym, ret)
    local first = ret[1]
    local first_type = type(first)

    if first_type == 'boolean' then
      if not first then
        return
      end
      table.remove(ret, 1)
      local second = ret[1]
      if type(second) == 'number' then
        table.remove(ret, 1)
        if second ~= 0 then
          insert_squeezed_result(task, sym, second, ret)
        end
      else
        insert_squeezed_result(task, sym, 1.0, ret)
      end
    elseif first_type == 'number' then
      table.remove(ret, 1)
      if first ~= 0 then
        insert_squeezed_result(task, sym, first, ret)
      end
    else
      insert_squeezed_result(task, sym, 1.0, ret)
    end
  end

  -- Builds the task callback that runs every squeezed rule of the given order.
  -- @param order index into `squeezed_rules`
  -- @return function(task) suitable as a symbol callback
  local function gen_lua_squeeze_function(order)
    return function(task)
      -- Per-task set of disabled symbol names (stored by handle_settings)
      local symbols_disabled = task:cache_get('squeezed_disable')
      local mime_task = task:has_flag('mime')

      -- The rules table is looked up on every call because squeeze_init fills
      -- it after this callback has been generated; a missing table means there
      -- is nothing to run for this order.
      for _, data in ipairs(squeezed_rules[order] or {}) do
        local cb, sym, rule = data[1], data[2], data[3]
        local disable = false

        if symbols_disabled and symbols_disabled[sym] then
          disable = true
        end
        -- Rules flagged as `mime` are skipped for tasks without the mime flag
        if rule and rule.flags.mime and not mime_task then
          disable = true
        end

        if disable then
          lua_util.debugm(SN, task, 'skip symbol due to settings: %s', sym)
        else
          lua_util.debugm(SN, task, 'call for: %s', sym)
          local status, ret = collect_rule_results(pcall(cb, task))
          if not status then
            logger.errx(task, 'error in squeezed rule %s: %s', sym, ret)
          elseif #ret ~= 0 then
            -- Function has returned something, so it is rule, not a plugin
            process_rule_results(task, sym, ret)
          end
        end
      end
    end
  end
  -- Registers a rule callback to be executed by the squeezed meta symbol.
  -- @param s symbol name, or nil/false for an unnamed rule
  -- @param func rule callback, invoked as func(task)
  -- @param flags optional table of flags stored with a named symbol
  -- @return id of the order-1 meta callback symbol
  exports.squeeze_rule = function(s, func, flags)
    if s then
      if not squeezed_symbols[s] then
        -- order = 0 means "not placed yet"; squeeze_init assigns the real order
        squeezed_symbols[s] = {
          cb = func,
          order = 0,
          sym = s,
          flags = flags or {}
        }
        lua_util.debugm(SN, rspamd_config, 'squeezed rule: %s', s)
      else
        logger.warnx(rspamd_config, 'duplicate symbol registered: %s, skip', s)
      end
    else
      -- Unconditionally add function to the squeezed rules.
      -- The id is derived from the number of rules already stored in the first
      -- order bucket, so that every unnamed rule gets a distinct label
      -- (`#squeezed_rules` is the number of buckets, not of rules).
      local unnamed = squeezed_rules[1]
      local id = tostring(#unnamed + 1)
      lua_util.debugm(SN, rspamd_config, 'squeezed unnamed rule: %s', id)
      -- Unnamed rules have no symbol record, hence no third element
      unnamed[#unnamed + 1] = {func, 'unnamed: ' .. id}
    end

    -- Lazily register the order-1 meta callback on first use
    if not squeeze_function_ids[1] then
      squeeze_function_ids[1] = rspamd_config:register_symbol{
        type = 'callback',
        flags = 'squeezed',
        callback = gen_lua_squeeze_function(1),
        name = squeeze_sym,
        description = 'Meta rule for Lua rules that can be squeezed',
        no_squeeze = true, -- to avoid infinite recursion
      }
    end

    -- Remembered so that squeeze_virtual can attach virtual symbols to it
    last_rule = s

    return squeeze_function_ids[1]
  end
  132. -- TODO: poor approach, we register all virtual symbols with the previous real squeezed symbol
  -- Attaches virtual symbol `name` to the most recently squeezed named rule.
  -- @param id parent symbol id; must be the id of the order-1 meta callback
  -- @param name name of the virtual symbol
  -- @return `id` when the virtual symbol was recorded, -1 otherwise
  function exports.squeeze_virtual(id, name)
    local base_id = squeeze_function_ids[1]

    -- Reject when nothing is squeezed yet, when the parent is some other
    -- symbol, or when there is no named rule to attach to
    if not base_id or id ~= base_id or not last_rule then
      return -1
    end

    virtual_symbols[name] = last_rule
    lua_util.debugm(SN, rspamd_config, 'add virtual symbol %s -> %s',
        name, last_rule)

    return id
  end
  -- Records the dependency edge `child` -> `parent` in both the forward
  -- (parent -> children) and the reverse (child -> parents) maps.
  -- A repeated edge is reported with a warning.
  -- @return always true
  function exports.squeeze_dependency(child, parent)
    lua_util.debugm(SN, rspamd_config, 'squeeze dep %s->%s', child, parent)

    local forward = squeezed_deps[parent]
    if forward == nil then
      forward = {}
      squeezed_deps[parent] = forward
    end

    if forward[child] then
      logger.warnx(rspamd_config, 'duplicate dependency %s->%s', child, parent)
    else
      forward[child] = true
    end

    local reverse = squeezed_rdeps[child]
    if reverse == nil then
      reverse = {}
      squeezed_rdeps[child] = reverse
    end
    -- Entries are only ever `true`, so storing it again is a no-op
    reverse[parent] = true

    return true
  end
  -- Returns the name of the meta symbol for `order`: the bare base name for
  -- order 1, the base name with the order number appended otherwise.
  local function get_ordered_symbol_name(order)
    local suffix = ''
    if order ~= 1 then
      suffix = tostring(order)
    end
    return squeeze_sym .. suffix
  end
  -- Registers the meta callback symbol for `order` and chains it after the
  -- meta symbol of the previous order.
  -- @param order order (topological level) to register a callback for
  local function register_topology_symbol(order)
    local ord_sym = get_ordered_symbol_name(order)

    squeeze_function_ids[order] = rspamd_config:register_symbol{
      type = 'callback',
      flags = 'squeezed',
      callback = gen_lua_squeeze_function(order),
      name = ord_sym,
      description = 'Meta rule for Lua rules that can be squeezed, order ' .. tostring(order),
      no_squeeze = true, -- to avoid infinite recursion
    }

    -- The first order has no predecessor: without this guard a dependency on
    -- a symbol named with order 0 would be registered, and no such symbol is
    -- ever created by this module.
    if order > 1 then
      local parent = get_ordered_symbol_name(order - 1)
      lua_util.debugm(SN, rspamd_config, 'registered new order of deps: %s->%s',
          ord_sym, parent)
      rspamd_config:register_dependency(ord_sym, parent, true)
    end
  end
  -- Finalises squeezing after all rules and dependencies have been collected:
  --   1. assigns an order to every squeezed symbol from the dependency graph,
  --   2. registers the dependencies and the extra per-order meta callbacks,
  --   3. registers every squeezed symbol as a virtual symbol of its order's
  --      meta callback and fills the per-order rule lists and group index.
  exports.squeeze_init = function()
    -- Do topological sorting.
    -- `on_path` holds the symbols of the current recursion path; meeting one
    -- of them again means the dependencies form a cycle, which would otherwise
    -- recurse without bound because the order grows on every step.
    local on_path = {}

    local function visit(node, order)
      if on_path[node.sym] then
        logger.errx(rspamd_config,
            'cyclic dependency detected for squeezed symbol %s, ignore this dependency',
            node.sym)
        return
      end

      if order <= node.order then
        -- Already placed at this order or a later one
        return
      end

      node.order = order
      lua_util.debugm(SN, rspamd_config, "symbol: %s, order: %s", node.sym, order)

      local dependents = squeezed_deps[node.sym]
      if dependents then
        on_path[node.sym] = true
        for dep in pairs(dependents) do
          local dep_node = squeezed_symbols[dep]
          if dep_node then
            visit(dep_node, order + 1)
          end
        end
        on_path[node.sym] = nil
      end
    end

    for _, v in pairs(squeezed_symbols) do
      if v.order == 0 then
        visit(v, 1)
      end
    end

    for dep_parent, children in pairs(squeezed_deps) do
      -- A parent that is not squeezed itself may be a virtual symbol of a
      -- squeezed rule; resolve it to that rule when known
      local parent = dep_parent
      if not squeezed_symbols[parent] then
        parent = virtual_symbols[parent] or parent
      end

      local ps = squeezed_symbols[parent]

      if not ps then
        -- Trivial case, external dependency
        for s in pairs(children) do
          if squeezed_symbols[s] then
            -- Squeezed symbol depends on an external one
            lua_util.debugm(SN, rspamd_config, 'register external squeezed dependency on %s',
                parent)
            rspamd_config:register_dependency(squeeze_sym, parent, true)
          else
            -- Generic rspamd symbols dependency
            lua_util.debugm(SN, rspamd_config, 'register external dependency %s -> %s',
                s, parent)
            rspamd_config:register_dependency(s, parent, true)
          end
        end
      else
        -- Not so trivial case
        for cld in pairs(children) do
          local child_sym = squeezed_symbols[cld]
          if child_sym then
            -- Cross dependency
            lua_util.debugm(SN, rspamd_config, 'cross dependency in squeezed symbols %s->%s',
                cld, parent)
            local order = child_sym.order
            if not squeeze_function_ids[order] then
              -- Need to register new callback symbol to handle deps
              for i = 1, order do
                if not squeeze_function_ids[i] then
                  register_topology_symbol(i)
                end
              end
            end
          else
            -- External symbol depends on a squeezed one
            local parent_symbol = get_ordered_symbol_name(ps.order)
            rspamd_config:register_dependency(cld, parent_symbol, true)
            lua_util.debugm(SN, rspamd_config, 'register squeezed dependency for external symbol %s->%s',
                cld, parent_symbol)
          end
        end
      end
    end

    -- Appends `sym` to the index of `group`, creating the index on first use
    local function add_to_group(group, sym)
      if not squeezed_groups[group] then
        lua_util.debugm(SN, rspamd_config, 'added squeezed group: %s', group)
        squeezed_groups[group] = {}
      end
      table.insert(squeezed_groups[group], sym)
    end

    -- We have now all deps being registered, so we can register virtual symbols
    -- and create squeezed rules
    for k, v in pairs(squeezed_symbols) do
      local parent_symbol = get_ordered_symbol_name(v.order)
      lua_util.debugm(SN, rspamd_config, 'added squeezed rule: %s (%s): %s',
          k, parent_symbol, v)
      rspamd_config:register_symbol{
        type = 'virtual',
        name = k,
        flags = 'squeezed',
        parent = squeeze_function_ids[v.order],
        no_squeeze = true, -- to avoid infinite recursion
      }

      -- Copy the metric attributes onto the symbol record and index its groups
      local metric_sym = rspamd_config:get_metric_symbol(k)
      if metric_sym then
        v.group = metric_sym.group
        v.groups = metric_sym.groups
        v.score = metric_sym.score
        v.description = metric_sym.description

        if v.group then
          add_to_group(v.group, k)
        end
        if v.groups then
          for _, gr in ipairs(v.groups) do
            add_to_group(gr, k)
          end
        end
      else
        lua_util.debugm(SN, rspamd_config, 'no metric symbol found for %s, maybe bug', k)
      end

      if not squeezed_rules[v.order] then
        squeezed_rules[v.order] = {}
      end
      table.insert(squeezed_rules[v.order], {v.cb, k, v})
    end
  end
  -- Translates the enable/disable lists of a settings object into the set of
  -- squeezed symbols to skip for this task and stores that set in the task
  -- cache under 'squeezed_disable'.
  -- @param task task the settings apply to
  -- @param settings settings table; when it has a `default` field that
  --   sub-table is used. Recognised fields: symbols_enabled, groups_enabled,
  --   symbols_disabled, groups_disabled (each a list of names)
  exports.handle_settings = function(task, settings)
    local symbols_disabled = {}
    local symbols_enabled = {}
    local found = false
    local disabled = false

    if settings.default then settings = settings.default end

    -- Marks every squeezed symbol as disabled, except those already enabled
    -- explicitly and those flagged `explicit_disable`
    local function disable_all()
      for k, sym in pairs(squeezed_symbols) do
        if not symbols_enabled[k] and not (sym.flags and sym.flags.explicit_disable) then
          symbols_disabled[k] = true
        end
      end
    end

    if settings.symbols_enabled then
      -- A whitelist: everything is disabled first, listed symbols are re-enabled
      disable_all()
      found = true
      disabled = true
      for _, s in ipairs(settings.symbols_enabled) do
        if squeezed_symbols[s] then
          lua_util.debugm(SN, task, 'enable symbol %s as it is in `symbols_enabled`', s)
          symbols_enabled[s] = true
          symbols_disabled[s] = nil
        end
      end
    end

    if settings.groups_enabled then
      -- Do not disable everything twice, that would undo `symbols_enabled`
      -- only for symbols not yet recorded as enabled
      if not disabled then
        disable_all()
      end
      found = true
      for _, gr in ipairs(settings.groups_enabled) do
        if squeezed_groups[gr] then
          for _, sym in ipairs(squeezed_groups[gr]) do
            lua_util.debugm(SN, task, 'enable symbol %s as it is in `groups_enabled`', sym)
            symbols_enabled[sym] = true
            symbols_disabled[sym] = nil
          end
        end
      end
    end

    if settings.symbols_disabled then
      found = true
      for _, s in ipairs(settings.symbols_disabled) do
        lua_util.debugm(SN, task, 'try disable symbol %s as it is in `symbols_disabled`', s)
        -- Explicit enabling wins over disabling
        if not symbols_enabled[s] then
          symbols_disabled[s] = true
          lua_util.debugm(SN, task, 'disable symbol %s as it is in `symbols_disabled`', s)
        end
      end
    end

    if settings.groups_disabled then
      found = true
      for _, gr in ipairs(settings.groups_disabled) do
        -- The format string used to carry a second `%s` with no matching argument
        lua_util.debugm(SN, task, 'try disable group %s as it is in `groups_disabled`', gr)
        if squeezed_groups[gr] then
          for _, sym in ipairs(squeezed_groups[gr]) do
            if not symbols_enabled[sym] then
              lua_util.debugm(SN, task, 'disable symbol %s as it is in `groups_disabled`', sym)
              symbols_disabled[sym] = true
            end
          end
        end
      end
    end

    -- Only touch the task cache when the settings mention symbols or groups
    if found then
      task:cache_set('squeezed_disable', symbols_disabled)
    end
  end
  364. return exports