You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

lua_squeeze_rules.lua 11KB


  1. --[[
  2. Copyright (c) 2018, Vsevolod Stakhov <vsevolod@highsecure.ru>
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ]]--
  13. local exports = {}
  14. local logger = require 'rspamd_logger'
  -- Squeezed rules part

  -- Log module name and name of the first-order meta rule symbol
  local SN = 'lua_squeeze'
  local squeeze_sym = 'LUA_SQUEEZE'

  -- squeezed_rules[order] is a plain vector of {callback, name[, data]} entries;
  -- order 1 exists from the start so unnamed rules can be appended immediately
  local squeezed_rules = { {} }
  -- Named squeezed rules, indexed by symbol name
  local squeezed_symbols = {}
  -- Forward dependencies: squeezed_deps[parent][child] = true
  local squeezed_deps = {}
  -- Reverse dependencies: squeezed_rdeps[child][parent] = true
  local squeezed_rdeps = {}
  -- Ids of the registered meta (callback) symbols, indexed by order
  local squeeze_function_ids = {}
  -- Squeezed symbol names, grouped by their metric group
  local squeezed_groups = {}
  --- Call a squeezed rule callback and pack all of its return values.
  -- Defined at file level so pcall does not need a new closure per rule per task.
  local function call_squeezed_rule(cb, task)
    return {cb(task)}
  end

  --- Insert a result for `sym` with `score`.
  -- If the first element of `opts` is a table it is used as the options list,
  -- otherwise the whole `opts` vector is passed as options.
  local function insert_squeezed_result(task, sym, score, opts)
    if type(opts[1]) == 'table' then
      task:insert_result(sym, score, opts[1])
    else
      task:insert_result(sym, score, opts)
    end
  end

  --- Interpret the values returned by a squeezed rule and insert a result.
  -- Accepted shapes:
  --   nil | false                 -> rule did not fire
  --   true [, score] [, opts...]  -> fire with score (default 1.0); score 0 = no-op
  --   score [, opts...]           -> fire with score; score 0 = no-op
  --   anything else               -> fire with 1.0, all values used as options
  -- @param task rspamd task
  -- @param sym symbol name of the rule
  -- @param ret array of values returned by the callback (modified in place)
  local function process_squeezed_result(task, sym, ret)
    local first = ret[1]
    -- A nil first value means "not fired". This is checked explicitly because
    -- `#` on {nil, ...} may be non-zero, which used to make
    -- `return nil, 'reason'` insert the symbol.
    if first == nil then
      return
    end

    local first_type = type(first)
    if first_type == 'boolean' then
      if not first then
        return
      end
      table.remove(ret, 1)
      local second = ret[1]
      if type(second) == 'number' then
        table.remove(ret, 1)
        if second ~= 0 then
          insert_squeezed_result(task, sym, second, ret)
        end
      else
        insert_squeezed_result(task, sym, 1.0, ret)
      end
    elseif first_type == 'number' then
      table.remove(ret, 1)
      if first ~= 0 then
        insert_squeezed_result(task, sym, first, ret)
      end
    else
      -- Function has returned something, so it is rule, not a plugin
      insert_squeezed_result(task, sym, 1.0, ret)
    end
  end

  --- Build the callback of the meta rule for topological `order`.
  -- The callback runs every rule in squeezed_rules[order]. It skips rules listed
  -- in the task cache entry 'squeezed_disable', catches and logs errors raised
  -- by rules, and inserts results according to each rule's return values.
  -- @param order topological order (1-based)
  -- @return function(task)
  local function gen_lua_squeeze_function(order)
    return function(task)
      local rules = squeezed_rules[order]
      if not rules then
        return
      end
      local symbols_disabled = task:cache_get('squeezed_disable')
      for _,data in ipairs(rules) do
        local sym = data[2]
        if not symbols_disabled or not symbols_disabled[sym] then
          -- Too expensive to call :(
          --logger.debugm(SN, task, 'call for: %s', sym)
          local status, ret = pcall(call_squeezed_rule, data[1], task)
          if not status then
            logger.errx(task, 'error in squeezed rule %s: %s', sym, ret)
          else
            process_squeezed_result(task, sym, ret)
          end
        else
          logger.debugm(SN, task, 'skip symbol due to settings: %s', sym)
        end
      end
    end
  end
  --- Register a Lua rule callback to be run by a squeezed meta rule.
  -- Named rules (`s` given) are only recorded here; their order and virtual
  -- symbol are assigned later in `squeeze_init`. Duplicate names are skipped
  -- with a warning. Unnamed rules are appended straight to the first-order list.
  -- Registers the first-order meta symbol on first use.
  -- @param s symbol name, or nil for an unnamed rule
  -- @param func rule callback, called as func(task)
  -- @return id of the first-order meta symbol (from rspamd_config:register_symbol)
  exports.squeeze_rule = function(s, func)
    if s then
      if not squeezed_symbols[s] then
        squeezed_symbols[s] = {
          cb = func,
          order = 0,
          sym = s,
        }
        logger.debugm(SN, rspamd_config, 'squeezed rule: %s', s)
      else
        logger.warnx(rspamd_config, 'duplicate symbol registered: %s, skip', s)
      end
    else
      -- Unconditionally add function to the squeezed rules.
      -- Number unnamed rules by their position in the first-order list so each
      -- one gets a distinct name. `#squeezed_rules` counts orders, not rules,
      -- and gave every unnamed rule the same id.
      local id = tostring(#squeezed_rules[1] + 1)
      logger.debugm(SN, rspamd_config, 'squeezed unnamed rule: %s', id)
      table.insert(squeezed_rules[1], {func, 'unnamed: ' .. id})
    end

    if not squeeze_function_ids[1] then
      squeeze_function_ids[1] = rspamd_config:register_symbol{
        type = 'callback',
        flags = 'squeezed',
        callback = gen_lua_squeeze_function(1),
        name = squeeze_sym,
        description = 'Meta rule for Lua rules that can be squeezed',
        no_squeeze = true, -- to avoid infinite recursion
      }
    end

    return squeeze_function_ids[1]
  end
  --- Record that symbol `child` depends on symbol `parent`.
  -- Stores the edge in both directions: squeezed_deps[parent][child] and
  -- squeezed_rdeps[child][parent]. A repeated forward edge is reported with a
  -- warning. squeezed_deps is consumed later by `squeeze_init`.
  -- @param child name of the dependent symbol
  -- @param parent name of the symbol `child` depends on
  -- @return true
  exports.squeeze_dependency = function(child, parent)
    logger.debugm(SN, rspamd_config, 'squeeze dep %s->%s', child, parent)

    local children = squeezed_deps[parent]
    if children == nil then
      children = {}
      squeezed_deps[parent] = children
    end

    if children[child] then
      logger.warnx(rspamd_config, 'duplicate dependency %s->%s', child, parent)
    else
      children[child] = true
    end

    local parents = squeezed_rdeps[child]
    if parents == nil then
      parents = {}
      squeezed_rdeps[child] = parents
    end
    parents[parent] = true

    return true
  end
  --- Name of the meta rule symbol for topological `order`.
  -- Order 1 uses the bare squeeze symbol name; every other order gets its
  -- number appended (e.g. LUA_SQUEEZE2).
  local function get_ordered_symbol_name(order)
    local suffix = ''
    if order ~= 1 then
      suffix = tostring(order)
    end
    return squeeze_sym .. suffix
  end
  --- Register the meta (callback) symbol for topological `order`.
  -- Stores the new symbol id in squeeze_function_ids[order]. For orders above 1
  -- it also registers a dependency on the previous order's meta symbol, so
  -- layers run one after another.
  -- @param order topological order (1-based)
  local function register_topology_symbol(order)
    local ord_sym = get_ordered_symbol_name(order)
    squeeze_function_ids[order] = rspamd_config:register_symbol{
      type = 'callback',
      flags = 'squeezed',
      callback = gen_lua_squeeze_function(order),
      name = ord_sym,
      description = 'Meta rule for Lua rules that can be squeezed, order ' .. tostring(order),
      no_squeeze = true, -- to avoid infinite recursion
    }
    -- Order 1 is the first layer: there is no order-0 meta symbol to depend on
    if order > 1 then
      local parent = get_ordered_symbol_name(order - 1)
      logger.debugm(SN, rspamd_config, 'registered new order of deps: %s->%s',
          ord_sym, parent)
      rspamd_config:register_dependency(ord_sym, parent, true)
    end
  end
  --- Assign a topological order to every named squeezed symbol.
  -- Symbols start with order 0. Roots get 1, and a squeezed symbol that depends
  -- on another squeezed symbol gets an order strictly greater than its parent's
  -- (longest path). Dependency cycles are logged and broken instead of
  -- recursing until the stack overflows.
  local function compute_squeezed_orders()
    local on_path = {} -- symbols on the current DFS path, for cycle detection

    local function visit(node, order)
      if on_path[node.sym] then
        logger.errx(rspamd_config, 'cyclic dependency in squeezed symbols detected at %s',
            node.sym)
        return
      end
      if order <= node.order then
        return
      end
      node.order = order
      logger.debugm(SN, rspamd_config, "symbol: %s, order: %s", node.sym, order)

      local children = squeezed_deps[node.sym]
      if children then
        on_path[node.sym] = true
        for dep,_ in pairs(children) do
          local child = squeezed_symbols[dep]
          if child then
            visit(child, order + 1)
          end
        end
        on_path[node.sym] = nil
      end
    end

    for _,v in pairs(squeezed_symbols) do
      if v.order == 0 then
        visit(v, 1)
      end
    end
  end

  --- Translate recorded dependencies into rspamd_config dependencies.
  -- Also registers any meta symbols needed for higher topological orders.
  local function register_squeezed_deps()
    for parent,children in pairs(squeezed_deps) do
      local ps = squeezed_symbols[parent]
      if not ps then
        -- Trivial case, external dependency
        for s,_ in pairs(children) do
          if squeezed_symbols[s] then
            -- A squeezed symbol depends on an external one: make the
            -- first-order meta rule (and, transitively, higher orders) wait for it
            logger.debugm(SN, rspamd_config, 'register external squeezed dependency on %s',
                parent)
            rspamd_config:register_dependency(squeeze_sym, parent, true)
          else
            -- Generic rspamd symbols dependency
            logger.debugm(SN, rspamd_config, 'register external dependency %s -> %s',
                s, parent)
            rspamd_config:register_dependency(s, parent, true)
          end
        end
      else
        -- Not so trivial case
        for cld,_ in pairs(children) do
          local cs = squeezed_symbols[cld]
          if cs then
            -- Cross dependency
            logger.debugm(SN, rspamd_config, 'cross dependency in squeezed symbols %s->%s',
                cld, parent)
            local order = cs.order
            if not squeeze_function_ids[order] then
              -- Need to register new callback symbol to handle deps
              for i = 1, order do
                if not squeeze_function_ids[i] then
                  register_topology_symbol(i)
                end
              end
            end
          else
            -- External symbol depends on a squeezed one
            local parent_symbol = get_ordered_symbol_name(ps.order)
            rspamd_config:register_dependency(cld, parent_symbol, true)
            logger.debugm(SN, rspamd_config, 'register squeezed dependency for external symbol %s->%s',
                cld, parent_symbol)
          end
        end
      end
    end
  end

  --- Register a virtual symbol for every named squeezed rule.
  -- Attaches each virtual symbol to the meta rule of its order, copies metric
  -- group/score/description, and appends the rule to squeezed_rules[order].
  local function register_squeezed_symbols()
    for k,v in pairs(squeezed_symbols) do
      local parent_symbol = get_ordered_symbol_name(v.order)
      logger.debugm(SN, rspamd_config, 'added squeezed rule: %s (%s): %s',
          k, parent_symbol, v)
      rspamd_config:register_symbol{
        type = 'virtual',
        name = k,
        flags = 'squeezed',
        parent = squeeze_function_ids[v.order],
        no_squeeze = true, -- to avoid infinite recursion
      }

      local metric_sym = rspamd_config:get_metric_symbol(k)
      if metric_sym then
        v.group = metric_sym.group
        v.score = metric_sym.score
        v.description = metric_sym.description
        -- A nil group cannot be used as a table key
        if v.group ~= nil then
          if not squeezed_groups[v.group] then
            logger.debugm(SN, rspamd_config, 'added squeezed group: %s', v.group)
            squeezed_groups[v.group] = {}
          end
          table.insert(squeezed_groups[v.group], k)
        end
      else
        logger.debugm(SN, rspamd_config, 'no metric symbol found for %s, maybe bug', k)
      end

      if not squeezed_rules[v.order] then
        squeezed_rules[v.order] = {}
      end
      table.insert(squeezed_rules[v.order], {v.cb,k,v})
    end
  end

  --- Finalize squeezed rules.
  -- Called once, after all squeeze_rule / squeeze_dependency calls.
  exports.squeeze_init = function()
    -- Do topological sorting
    compute_squeezed_orders()
    register_squeezed_deps()
    -- We have now all deps being registered, so we can register virtual symbols
    -- and create squeezed rules
    register_squeezed_symbols()
  end
  --- Turn per-task settings into the set of disabled squeezed symbols.
  -- Reads `symbols_enabled`, `groups_enabled`, `symbols_disabled` and
  -- `groups_disabled` from `settings`, or from `settings.default` when present.
  -- Any `*_enabled` list disables every other squeezed symbol. Explicit enables
  -- take precedence over disables. If at least one of these keys is present,
  -- the resulting set is stored in the task cache as 'squeezed_disable', which
  -- the squeeze meta rule callbacks read.
  -- @param task rspamd task
  -- @param settings settings table
  exports.handle_settings = function(task, settings)
    if settings.default then
      settings = settings.default
    end

    local disabled = {}
    local enabled = {}
    local have_settings = false

    -- Disable every squeezed symbol not explicitly enabled so far
    local function disable_not_enabled()
      for name in pairs(squeezed_symbols) do
        if not enabled[name] then
          disabled[name] = true
        end
      end
    end

    local sym_on = settings.symbols_enabled
    if sym_on then
      disable_not_enabled()
      have_settings = true
      for _, name in ipairs(sym_on) do
        if squeezed_symbols[name] then
          logger.debugm(SN, task, 'enable symbol %s as it is in `symbols_enabled`', name)
          enabled[name] = true
          disabled[name] = nil
        end
      end
    end

    local grp_on = settings.groups_enabled
    if grp_on then
      disable_not_enabled()
      have_settings = true
      for _, group in ipairs(grp_on) do
        local members = squeezed_groups[group]
        if members then
          for _, name in ipairs(members) do
            logger.debugm(SN, task, 'enable symbol %s as it is in `groups_enabled`', name)
            enabled[name] = true
            disabled[name] = nil
          end
        end
      end
    end

    local sym_off = settings.symbols_disabled
    if sym_off then
      have_settings = true
      for _, name in ipairs(sym_off) do
        if not enabled[name] then
          disabled[name] = true
          logger.debugm(SN, task, 'disable symbol %s as it is in `symbols_disabled`', name)
        end
      end
    end

    local grp_off = settings.groups_disabled
    if grp_off then
      have_settings = true
      for _, group in ipairs(grp_off) do
        local members = squeezed_groups[group]
        if members then
          for _, name in ipairs(members) do
            if not enabled[name] then
              logger.debugm(SN, task, 'disable symbol %s as it is in `groups_disabled`', name)
              disabled[name] = true
            end
          end
        end
      end
    end

    if have_settings then
      task:cache_set('squeezed_disable', disabled)
    end
  end
  315. return exports