You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

clustering.lua 9.2KB


  1. --[[
  2. Copyright (c) 2018, Vsevolod Stakhov
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ]]--
  -- Stop loading this plugin as soon as the `confighelp` global is set
  if confighelp then return end
  16. -- Plugin for finding patterns in email flows
  17. local N = 'clustering'
  18. local rspamd_logger = require "rspamd_logger"
  19. local lua_util = require "lua_util"
  20. local lua_verdict = require "lua_verdict"
  21. local lua_redis = require "lua_redis"
  22. local lua_selectors = require "lua_selectors"
  23. local ts = require("tableshape").types
  24. local redis_params
  -- Compiled rules; filled by the init section at the bottom of the file
  local rules = {}

  -- Per-rule defaults, merged with every configured rule before validation
  local default_rule = {
    -- Bucket sizing and lifetime
    max_elts = 100,          -- maximum number of elements in a cluster
    expire = 3600,           -- bucket expiry while the limit is not reached
    expire_overflow = 36000, -- bucket expiry once the limit is reached

    -- Score increments applied per verdict
    spam_mult = 1.0,  -- added on a spam hit
    junk_mult = 0.5,  -- added on junk
    ham_mult = -0.1,  -- added on ham (negative, i.e. lowers the total)

    -- Weights used when combining the stored data into a symbol score
    size_mult = 0.01, -- reaches 1.0 on `max_elts`
    score_mult = 0.1,
  }
  -- Validation schema for a single rule (applied after defaults are merged).
  -- Numeric limits may be given as strings: `max_elts` goes through `tonumber`,
  -- the expiry values through `lua_util.parse_time_interval`.
  local rule_schema = ts.shape {
    max_elts = ts.number + ts.string / tonumber,
    expire = ts.number + ts.string / lua_util.parse_time_interval,
    expire_overflow = ts.number + ts.string / lua_util.parse_time_interval,
    spam_mult = ts.number,
    junk_mult = ts.number,
    ham_mult = ts.number,
    size_mult = ts.number,
    score_mult = ts.number,
    source_selector = ts.string,
    cluster_selector = ts.string,
    symbol = ts.string:is_optional(),
    prefix = ts.string:is_optional(),
    -- Read by the idempotent callback as `rule.allow_local`. It has to be
    -- declared here: a tableshape `shape` is exact unless opened, so a rule
    -- carrying an undeclared field fails validation.
    allow_local = ts.boolean:is_optional(),
  }
  -- Redis scripts

  -- Read-only lookup of a cluster bucket.
  -- KEYS:
  --   [1] source selector value (the hash that holds the bucket)
  --   [2] cluster selector value (a field inside that hash)
  -- Reply: {cur_elts, total_score, element_score}; both scores are sent back
  -- as strings, the bucket total is stored in the `__s` field
  local query_cluster_script = [[
  local sz = redis.call('HLEN', KEYS[1])

  if not sz or not tonumber(sz) then
    -- New bucket, will update on idempotent phase
    return {0, '0', '0'}
  end

  local total_score = redis.call('HGET', KEYS[1], '__s')
  total_score = tonumber(total_score) or 0
  local score = redis.call('HGET', KEYS[1], KEYS[2])
  if not score or not tonumber(score) then
    return {sz, tostring(total_score), '0'}
  end
  return {sz, tostring(total_score), tostring(score)}
  ]]
  -- Handle of the registered query script (assigned during initialisation)
  local query_cluster_id
  -- Updates cluster's data
  -- Arguments (all passed as KEYS):
  -- 1. Source selector (string) - the hash that holds the bucket
  -- 2. Cluster selector (string) - the element (hash field) to update
  -- 3. Score (number) - signed increment for the bucket total `__s`;
  --    its absolute value is added to the element itself
  -- 4. Max buckets (number)
  -- 5. Expire (number) - applied while the limit is not reached
  -- 6. Expire overflow (number) - applied once the limit is reached
  -- Returns: nothing
  local update_cluster_script = [[
  local sz = redis.call('HLEN', KEYS[1])

  if not sz or not tonumber(sz) then
    -- Create bucket
    redis.call('HSET', KEYS[1], KEYS[2], math.abs(KEYS[3]))
    redis.call('HSET', KEYS[1], '__s', KEYS[3])
    redis.call('EXPIRE', KEYS[1], KEYS[5])
    return
  end

  sz = tonumber(sz)
  local lim = tonumber(KEYS[4])

  if sz > lim then
    -- Bucket is full: do not add new elements, only bump existing ones.
    -- The membership test must be explicit; reading an undeclared global
    -- is an error inside a Redis script.
    if redis.call('HEXISTS', KEYS[1], KEYS[2]) == 1 then
      redis.call('HINCRBYFLOAT', KEYS[1], KEYS[2], math.abs(KEYS[3]))
    end
    redis.call('HINCRBYFLOAT', KEYS[1], '__s', KEYS[3])
    redis.call('EXPIRE', KEYS[1], KEYS[6])
  else
    redis.call('HINCRBYFLOAT', KEYS[1], KEYS[2], math.abs(KEYS[3]))
    redis.call('HINCRBYFLOAT', KEYS[1], '__s', KEYS[3])
    redis.call('EXPIRE', KEYS[1], KEYS[5])
  end
  ]]
  -- Handle of the registered update script (assigned during initialisation)
  local update_cluster_id
  104. -- Callbacks and logic
  --- Filter callback: queries the stored cluster data for this message and
  -- inserts `rule.symbol` when the combined score is above 0.1.
  -- @param task task being processed
  -- @param rule compiled rule (selector closures, multipliers, symbol, name)
  local function clusterting_filter_cb(task, rule)
    local source_selector = rule.source_selector(task)
    local cluster_selector

    -- The cluster selector is evaluated only if the source one matched
    if source_selector then
      cluster_selector = rule.cluster_selector(task)
    end

    if not cluster_selector or not source_selector then
      rspamd_logger.debugm(N, task, 'skip rule %s, selectors: source="%s", cluster="%s"',
          rule.name, source_selector, cluster_selector)
      return
    end

    -- Turns the three numbers returned by the query script into a symbol score
    local function combine_scores(cur_elts, total_score, element_score)
      local final_score
      local size_score = cur_elts * rule.size_mult
      local cluster_score = total_score * rule.score_mult

      if element_score > 0 then
        -- We have seen this element mostly in junk/spam
        final_score = math.min(1.0, size_score + cluster_score)
      elseif cur_elts > 0 then
        -- We have seen this element mostly in ham, so subtract the average
        -- cluster score from the size score
        final_score = math.min(1.0, size_score - cluster_score / cur_elts)
      else
        -- Empty bucket: there is nothing to average over. Without this
        -- branch the division above would be 0/0 (NaN).
        final_score = 0
      end

      rspamd_logger.debugm(N, task,
          'processed rule %s, selectors: source="%s", cluster="%s"; data: %s elts, %s score, %s elt score',
          rule.name, source_selector, cluster_selector, cur_elts, total_score, element_score)

      if final_score > 0.1 then
        task:insert_result(rule.symbol, final_score, { source_selector,
                                                       tostring(size_score),
                                                       tostring(cluster_score) })
      end
    end

    local function redis_get_cb(err, data)
      if data then
        if type(data) == 'table' then
          local cur_elts = tonumber(data[1])
          local total_score = tonumber(data[2])
          local element_score = tonumber(data[3])

          if cur_elts and total_score and element_score then
            combine_scores(cur_elts, total_score, element_score)
          else
            -- Do arithmetic only on a well-formed reply; a nil here would
            -- otherwise raise inside the Redis callback
            rspamd_logger.errx(task, 'invalid reply while getting clustering keys %s: %s',
                source_selector, data)
          end
        else
          rspamd_logger.errx(task, 'invalid type while getting clustering keys %s: %s',
              source_selector, type(data))
        end
      elseif err then
        rspamd_logger.errx(task, 'got error while getting clustering keys %s: %s',
            source_selector, err)
      else
        rspamd_logger.errx(task, 'got error while getting clustering keys %s: %s',
            source_selector, "unknown error")
      end
    end

    lua_redis.exec_redis_script(query_cluster_id,
        { task = task, is_write = false, key = source_selector },
        redis_get_cb,
        { source_selector, cluster_selector })
  end
  --- Idempotent callback: stores the outcome of the scan in the cluster
  -- bucket, using the per-verdict multiplier of the rule as the increment.
  -- @param task task being processed
  -- @param rule compiled rule (selector closures, multipliers, limits, name)
  local function clusterting_idempotent_cb(task, rule)
    if task:has_flag('skip') then
      return
    end

    -- Scans coming from rspamc/controller are ignored unless the rule
    -- sets `allow_local`
    if not rule.allow_local and lua_util.is_rspamc_or_controller(task) then
      return
    end

    local verdict = lua_verdict.get_specific_verdict(N, task)
    local score

    if verdict == 'ham' then
      score = rule.ham_mult
    elseif verdict == 'spam' then
      score = rule.spam_mult
    elseif verdict == 'junk' then
      score = rule.junk_mult
    else
      -- Any other verdict is not learned
      rspamd_logger.debugm(N, task, 'skip rule %s, verdict=%s',
          rule.name, verdict)
      return
    end

    local source_selector = rule.source_selector(task)
    local cluster_selector

    if source_selector then
      cluster_selector = rule.cluster_selector(task)
    end

    if not cluster_selector or not source_selector then
      rspamd_logger.debugm(N, task, 'skip rule %s, selectors: source="%s", cluster="%s"',
          rule.name, source_selector, cluster_selector)
      return
    end

    local function redis_set_cb(err, _)
      if err then
        -- This is the write path, so report it as a failed set
        rspamd_logger.errx(task, 'got error while setting clustering keys %s: %s',
            source_selector, err)
      else
        -- One argument per placeholder: rule, bucket, element, increment
        rspamd_logger.debugm(N, task, 'set clustering key for %s: %s{%s} = %s',
            rule.name, source_selector, cluster_selector, score)
      end
    end

    -- Argument order must match the KEYS layout of the update script
    lua_redis.exec_redis_script(update_cluster_id,
        { task = task, is_write = true, key = source_selector },
        redis_set_cb,
        {
          source_selector,
          cluster_selector,
          tostring(score),
          tostring(rule.max_elts),
          tostring(rule.expire),
          tostring(rule.expire_overflow)
        }
    )
  end
  -- Initialisation: read the configuration, compile rules, register symbols
  redis_params = lua_redis.parse_redis_server('clustering')
  local opts = rspamd_config:get_all_opt("clustering")

  -- The module needs a configuration table ...
  if type(opts) ~= 'table' then
    lua_util.disable_module(N, "config")
    return
  end

  -- ... and a usable Redis configuration
  if not redis_params then
    lua_util.disable_module(N, "redis")
    return
  end

  local configured_rules = opts['rules']

  if configured_rules then
    for rule_name, rule_conf in pairs(configured_rules) do
      local merged = lua_util.override_defaults(default_rule, rule_conf)
      local rule, err = rule_schema:transform(merged)

      if rule then
        -- Symbol and prefix fall back to values derived from the rule name
        rule.symbol = rule.symbol or rule_name
        rule.prefix = rule.prefix or (rule_name .. "_")

        -- Replace the selector strings with their compiled closures
        rule.source_selector = lua_selectors.create_selector_closure(rspamd_config,
            rule.source_selector, '')
        rule.cluster_selector = lua_selectors.create_selector_closure(rspamd_config,
            rule.cluster_selector, '')

        -- Keep the rule only if both selectors were created
        if rule.source_selector and rule.cluster_selector then
          rule.name = rule_name
          rules[#rules + 1] = rule
        end
      else
        rspamd_logger.errx(rspamd_config, 'invalid clustering rule %s: %s',
            rule_name, err)
      end
    end
  end

  if configured_rules and #rules > 0 then
    query_cluster_id = lua_redis.add_redis_script(query_cluster_script, redis_params)
    update_cluster_id = lua_redis.add_redis_script(update_cluster_script, redis_params)

    for i = 1, #rules do
      -- A fresh local per iteration, so each closure is bound to its own rule
      local rule = rules[i]

      rspamd_config:register_symbol({
        name = rule.symbol,
        type = 'normal',
        callback = function(task)
          return clusterting_filter_cb(task, rule)
        end,
      })

      rspamd_config:register_symbol({
        name = rule.symbol .. '_STORE',
        type = 'idempotent',
        flags = 'empty,explicit_disable,ignore_passthrough',
        callback = function(task)
          return clusterting_idempotent_cb(task, rule)
        end,
        augmentations = { string.format("timeout=%f", redis_params.timeout or 0.0) }
      })
    end
  else
    -- Either no `rules` section or none of the rules could be compiled
    lua_util.disable_module(N, "config")
  end