You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

lua_squeeze_rules.lua 10KB


  1. --[[
  2. Copyright (c) 2018, Vsevolod Stakhov <vsevolod@highsecure.ru>
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ]]--
  -- Table returned by this module
  local exports = {}
  local logger = require('rspamd_logger')

  -- Names used for debug logging and for the registered meta symbols
  local SN = 'lua_squeeze' -- module name passed to logger.debugm
  local squeeze_sym = 'LUA_SQUEEZE' -- name of the first-order meta symbol

  -- Squeezed rules part
  local squeezed_rules = { {} } -- squeezed_rules[order] is a plain vector of rules
  local squeezed_symbols = {} -- indexed by name of symbol
  local squeezed_deps = {} -- parent name -> set of child names
  local squeezed_rdeps = {} -- child name -> set of parent names
  local squeeze_function_ids = {} -- order -> id returned by register_symbol
  local squeezed_groups = {} -- group name -> list of symbol names
  --- Build the callback of the meta symbol that runs every squeezed rule of
  -- a single order.
  -- @param order index into `squeezed_rules`
  -- @return function(task) calling each rule and inserting its result
  local function gen_lua_squeeze_function(order)
    -- Insert a result for `sym`: when the first remaining return value is a
    -- table it is passed on its own, otherwise the whole remaining vector is.
    local function emit(task, sym, score, rest)
      local opts = rest
      if type(rest[1]) == 'table' then
        opts = rest[1]
      end
      task:insert_result(sym, score, opts)
    end

    return function(task)
      -- Set of symbol names stored under this key by `handle_settings`
      local disabled = task:cache_get('squeezed_disable')

      for _, rule in ipairs(squeezed_rules[order]) do
        local cb, sym = rule[1], rule[2]

        if disabled and disabled[sym] then
          logger.debugm(SN, task, 'skip symbol due to settings: %s', sym)
        else
          local ret = { cb(task) }

          -- Function has returned something, so it is rule, not a plugin
          if #ret ~= 0 then
            local first = ret[1]
            local kind = type(first)

            if kind == 'boolean' then
              -- `false` means no match; `true` may be followed by a score
              if first then
                table.remove(ret, 1)
                local second = ret[1]
                if type(second) ~= 'number' then
                  emit(task, sym, 1.0, ret)
                else
                  table.remove(ret, 1)
                  if second ~= 0 then
                    emit(task, sym, second, ret)
                  end
                end
              end
            elseif kind == 'number' then
              -- Leading number is the score; zero means no match
              table.remove(ret, 1)
              if first ~= 0 then
                emit(task, sym, first, ret)
              end
            else
              -- Anything else counts as a match with the default score
              emit(task, sym, 1.0, ret)
            end
          end
        end
      end
    end
  end
  --- Register a rule callback to be run inside the squeezed meta symbol.
  -- Named rules are only recorded here and attached to their final order by
  -- `squeeze_init`; unnamed rules are appended to the first-order vector
  -- immediately.
  -- @param s symbol name, or nil/false for an unnamed rule
  -- @param func callback invoked with the task
  -- @return id of the first-order meta symbol as returned by register_symbol
  exports.squeeze_rule = function(s, func)
    if s then
      if not squeezed_symbols[s] then
        squeezed_symbols[s] = {
          cb = func,
          order = 0, -- 0 marks "not ordered yet"
          sym = s,
        }
        logger.debugm(SN, rspamd_config, 'squeezed rule: %s', s)
      else
        logger.warnx(rspamd_config, 'duplicate symbol registered: %s, skip', s)
      end
    else
      -- Unconditionally add function to the squeezed rules.
      -- The id is the position in the first-order vector. `#squeezed_rules`
      -- counts order buckets and does not change while unnamed rules are
      -- added, so it labelled every unnamed rule with the same id.
      local first_order = squeezed_rules[1]
      local id = tostring(#first_order + 1)
      logger.debugm(SN, rspamd_config, 'squeezed unnamed rule: %s', id)
      first_order[#first_order + 1] = {func, 'unnamed: ' .. id}
    end

    -- Lazily register the first-order meta symbol on first use
    if not squeeze_function_ids[1] then
      squeeze_function_ids[1] = rspamd_config:register_symbol{
        type = 'callback',
        flags = 'squeezed',
        callback = gen_lua_squeeze_function(1),
        name = squeeze_sym,
        description = 'Meta rule for Lua rules that can be squeezed',
        no_squeeze = true, -- to avoid infinite recursion
      }
    end

    return squeeze_function_ids[1]
  end
  --- Record a dependency between two symbols for later processing by
  -- `squeeze_init`.
  -- @param child name stored in the set of `parent` in `squeezed_deps`
  -- @param parent name stored in the set of `child` in `squeezed_rdeps`
  -- @return always true
  exports.squeeze_dependency = function(child, parent)
    logger.debugm(SN, rspamd_config, 'squeeze dep %s->%s', child, parent)

    -- Forward map: parent -> set of children
    local forward = squeezed_deps[parent]
    if not forward then
      forward = {}
      squeezed_deps[parent] = forward
    end

    if forward[child] then
      logger.warnx(rspamd_config, 'duplicate dependency %s->%s', child, parent)
    else
      forward[child] = true
    end

    -- Reverse map: child -> set of parents
    local reverse = squeezed_rdeps[child]
    if not reverse then
      reverse = {}
      squeezed_rdeps[child] = reverse
    end
    reverse[parent] = reverse[parent] or true

    return true
  end
  --- Name of the meta symbol that runs squeezed rules of the given order.
  -- Order 1 uses the bare `squeeze_sym`; any other order appends the order
  -- converted with `tostring`.
  -- @param order order value
  -- @return symbol name string
  local function get_ordered_symbol_name(order)
    if order ~= 1 then
      return squeeze_sym .. tostring(order)
    end

    return squeeze_sym
  end
  --- Register the callback meta symbol for `order` and chain it to the meta
  -- symbol of the previous order.
  -- The resulting id is stored in `squeeze_function_ids[order]`.
  -- @param order order of the meta symbol to register
  local function register_topology_symbol(order)
    local ord_sym = get_ordered_symbol_name(order)

    squeeze_function_ids[order] = rspamd_config:register_symbol{
      type = 'callback',
      flags = 'squeezed',
      callback = gen_lua_squeeze_function(order),
      name = ord_sym,
      description = 'Meta rule for Lua rules that can be squeezed, order ' .. tostring(order),
      no_squeeze = true, -- to avoid infinite recursion
    }

    -- Order 1 has no previous order: chaining it would register a dependency
    -- on the name for order 0, which this module never registers.
    if order > 1 then
      local parent = get_ordered_symbol_name(order - 1)
      logger.debugm(SN, rspamd_config, 'registered new order of deps: %s->%s',
          ord_sym, parent)
      rspamd_config:register_dependency(ord_sym, parent, true)
    end
  end
  --- Finalise squeezed rules once all rules and dependencies are recorded.
  -- Steps:
  --  1. assign an order to every named squeezed symbol from `squeezed_deps`;
  --  2. register per-order meta symbols and the dependencies that involve
  --     symbols outside of the squeezed set;
  --  3. register every squeezed symbol as a virtual child of the meta symbol
  --     of its order and append it to `squeezed_rules[order]`.
  exports.squeeze_init = function()
    -- Names of the symbols on the current visit path, used to detect cycles
    local in_path = {}

    -- Raise `node` to `order` if that is higher than its current order and
    -- push the symbols recorded under it in `squeezed_deps` one step further.
    local function visit(node, order)
      if order <= node.order then
        return
      end

      if in_path[node.sym] then
        -- Reaching a symbol that is still being visited means a dependency
        -- cycle; without this check the recursion would never terminate
        -- because the order grows on every step.
        logger.errx(rspamd_config,
            'cyclic dependency detected for squeezed symbol %s, ignore it',
            node.sym)
        return
      end

      node.order = order
      logger.debugm(SN, rspamd_config, "symbol: %s, order: %s", node.sym, order)

      local dependents = squeezed_deps[node.sym]
      if dependents then
        in_path[node.sym] = true
        for dep,_ in pairs(dependents) do
          if squeezed_symbols[dep] then
            visit(squeezed_symbols[dep], order + 1)
          end
        end
        in_path[node.sym] = nil
      end
    end

    -- Do topological sorting
    for _,v in pairs(squeezed_symbols) do
      if v.order == 0 then
        visit(v, 1)
      end
    end

    for parent,children in pairs(squeezed_deps) do
      local ps = squeezed_symbols[parent]

      if not ps then
        -- Trivial case, external dependency (parent is not squeezed)
        for s,_ in pairs(children) do
          if squeezed_symbols[s] then
            -- Child is squeezed: attach the dependency to the first-order
            -- meta symbol
            logger.debugm(SN, rspamd_config, 'register external squeezed dependency on %s',
                parent)
            rspamd_config:register_dependency(squeeze_sym, parent, true)
          else
            -- Generic rspamd symbols dependency
            logger.debugm(SN, rspamd_config, 'register external dependency %s -> %s',
                s, parent)
            rspamd_config:register_dependency(s, parent, true)
          end
        end
      else
        -- Not so trivial case (parent is squeezed)
        for cld,_ in pairs(children) do
          if squeezed_symbols[cld] then
            -- Cross dependency
            logger.debugm(SN, rspamd_config, 'cross dependency in squeezed symbols %s->%s',
                cld, parent)
            local order = squeezed_symbols[cld].order
            if not squeeze_function_ids[order] then
              -- Need to register new callback symbol to handle deps
              for i = 1, order do
                if not squeeze_function_ids[i] then
                  register_topology_symbol(i)
                end
              end
            end
          else
            -- External symbol depends on a squeezed one
            local parent_symbol = get_ordered_symbol_name(ps.order)
            rspamd_config:register_dependency(cld, parent_symbol, true)
            logger.debugm(SN, rspamd_config, 'register squeezed dependency for external symbol %s->%s',
                cld, parent_symbol)
          end
        end
      end
    end

    -- We have now all deps being registered, so we can register virtual symbols
    -- and create squeezed rules
    for k,v in pairs(squeezed_symbols) do
      local parent_symbol = get_ordered_symbol_name(v.order)
      logger.debugm(SN, rspamd_config, 'added squeezed rule: %s (%s)', k, parent_symbol)
      rspamd_config:register_symbol{
        type = 'virtual',
        name = k,
        flags = 'squeezed',
        parent = squeeze_function_ids[v.order],
        no_squeeze = true, -- to avoid infinite recursion
      }

      local metric_sym = rspamd_config:get_metric_symbol(k)

      if metric_sym then
        v.group = metric_sym.group
        v.score = metric_sym.score
        v.description = metric_sym.description

        -- A metric symbol without a group must not be indexed: assigning
        -- to a nil table key raises an error in Lua.
        if v.group then
          if not squeezed_groups[v.group] then
            logger.debugm(SN, rspamd_config, 'added squeezed group: %s', v.group)
            squeezed_groups[v.group] = {}
          end

          table.insert(squeezed_groups[v.group], k)
        end
      else
        logger.debugm(SN, rspamd_config, 'no metric symbol found for %s, maybe bug', k)
      end

      if not squeezed_rules[v.order] then
        squeezed_rules[v.order] = {}
      end

      table.insert(squeezed_rules[v.order], {v.cb,k,v})
    end
  end
  --- Translate per-task settings into the set of squeezed symbols to skip.
  -- The resulting set is stored in the task cache under 'squeezed_disable',
  -- the key read by the squeezed callbacks, and only when at least one of
  -- the four recognised settings keys is present.
  -- @param task task object providing `cache_set`
  -- @param settings settings table; its `default` field is used when present
  exports.handle_settings = function(task, settings)
    local enabled, disabled = {}, {}
    local touched = false

    if settings.default then
      settings = settings.default
    end

    -- Mark every squeezed symbol that is not explicitly enabled as disabled
    local function disable_all()
      for name in pairs(squeezed_symbols) do
        if not enabled[name] then
          disabled[name] = true
        end
      end
    end

    -- Whitelist by symbol name: everything else becomes disabled
    local syms_on = settings.symbols_enabled
    if syms_on then
      disable_all()
      touched = true

      for _, s in ipairs(syms_on) do
        if squeezed_symbols[s] then
          logger.debugm(SN, task, 'enable symbol %s as it is in `symbols_enabled`', s)
          enabled[s] = true
          disabled[s] = nil
        end
      end
    end

    -- Whitelist by group: everything else becomes disabled
    local groups_on = settings.groups_enabled
    if groups_on then
      disable_all()
      touched = true

      for _, gr in ipairs(groups_on) do
        local members = squeezed_groups[gr]
        if members then
          for _, sym in ipairs(members) do
            logger.debugm(SN, task, 'enable symbol %s as it is in `groups_enabled`', sym)
            enabled[sym] = true
            disabled[sym] = nil
          end
        end
      end
    end

    -- Blacklist by symbol name; explicit enables take precedence
    local syms_off = settings.symbols_disabled
    if syms_off then
      touched = true

      for _, s in ipairs(syms_off) do
        if not enabled[s] then
          disabled[s] = true
          logger.debugm(SN, task, 'disable symbol %s as it is in `symbols_disabled`', s)
        end
      end
    end

    -- Blacklist by group; explicit enables take precedence
    local groups_off = settings.groups_disabled
    if groups_off then
      touched = true

      for _, gr in ipairs(groups_off) do
        local members = squeezed_groups[gr]
        if members then
          for _, sym in ipairs(members) do
            if not enabled[sym] then
              logger.debugm(SN, task, 'disable symbol %s as it is in `groups_disabled`', sym)
              disabled[sym] = true
            end
          end
        end
      end
    end

    if touched then
      task:cache_set('squeezed_disable', disabled)
    end
  end
  300. end
  301. if found then
  302. task:cache_set('squeezed_disable', symbols_disabled)
  303. end
  304. end
  305. return exports