You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

clustering.lua 9.0KB


  1. --[[
  2. Copyright (c) 2018, Vsevolod Stakhov
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ]]--
  -- Skip module initialisation entirely when the `confighelp` global is set
  -- (presumably rspamd is only collecting configuration help in that mode)
  if confighelp then return end
  16. -- Plugin for finding patterns in email flows
  17. local N = 'clustering'
  18. local rspamd_logger = require "rspamd_logger"
  19. local lua_util = require "lua_util"
  20. local lua_verdict = require "lua_verdict"
  21. local lua_redis = require "lua_redis"
  22. local lua_selectors = require "lua_selectors"
  23. local ts = require("tableshape").types
  24. local redis_params
  local rules = {} -- Rules placement

  -- Defaults merged into every configured rule before schema validation
  local default_rule = {
    max_elts = 100, -- Maximum elements in a cluster
    expire = 3600, -- Expire for a bucket when limit is not reached
    expire_overflow = 36000, -- Expire for a bucket when limit is reached
    spam_mult = 1.0, -- Increase on spam hit
    junk_mult = 0.5, -- Increase on junk
    ham_mult = -0.1, -- Increase on ham
    size_mult = 0.01, -- Reaches 1.0 on `max_elts`
    score_mult = 0.1, -- Multiplier applied to the cluster's accumulated score
    allow_local = false, -- Also learn from rspamc/controller scans
  }

  -- Schema of a single rule. Durations may be given as strings and are
  -- converted with lua_util.parse_time_interval. Every option read from a rule
  -- must be declared here, because the shape rejects undeclared fields.
  local rule_schema = ts.shape{
    max_elts = ts.number + ts.string / tonumber,
    expire = ts.number + ts.string / lua_util.parse_time_interval,
    expire_overflow = ts.number + ts.string / lua_util.parse_time_interval,
    spam_mult = ts.number,
    junk_mult = ts.number,
    ham_mult = ts.number,
    size_mult = ts.number,
    score_mult = ts.number,
    source_selector = ts.string,
    cluster_selector = ts.string,
    symbol = ts.string:is_optional(),
    prefix = ts.string:is_optional(),
    -- Read by clusterting_idempotent_cb; previously missing from the schema,
    -- so setting it made the whole rule invalid
    allow_local = ts.boolean:is_optional(),
  }
  -- Redis scripts

  -- Queries a cluster's data
  -- Arguments (passed as KEYS):
  -- 1. Source selector (string)
  -- 2. Cluster selector (string)
  -- Returns: {cur_elts, total_score, element_score}; both scores are returned
  -- as strings (presumably because Redis turns Lua number replies into integers).
  -- NOTE: cur_elts is HLEN of the bucket and therefore also counts the '__s'
  -- bookkeeping field.
  local query_cluster_script = [[
  local sz = tonumber(redis.call('HLEN', KEYS[1])) or 0

  if sz == 0 then
    -- Missing bucket: HLEN reports 0 (never nil) for a nonexistent key
    return {0, '0', '0'}
  end

  local total_score = tonumber(redis.call('HGET', KEYS[1], '__s')) or 0
  -- HGET yields false for a missing field; tonumber(false) is nil -> 0
  local score = tonumber(redis.call('HGET', KEYS[1], KEYS[2])) or 0

  return {sz, tostring(total_score), tostring(score)}
  ]]
  local query_cluster_id
  -- Updates cluster's data
  -- Arguments (passed as KEYS):
  -- 1. Source selector (string)
  -- 2. Cluster selector (string)
  -- 3. Score (number)
  -- 4. Max elements in the bucket (number)
  -- 5. Expire (number)
  -- 6. Expire overflow (number)
  -- Returns: nothing
  -- The element field accumulates |score| and the bucket-wide '__s' field
  -- accumulates the signed score. When the bucket already holds more than
  -- KEYS[4] fields (HLEN, which includes '__s'), new elements are rejected,
  -- known elements are still updated, and the overflow expiry KEYS[6] is used;
  -- otherwise the normal expiry KEYS[5] is used.
  local update_cluster_script = [[
  local sz = tonumber(redis.call('HLEN', KEYS[1])) or 0
  local lim = tonumber(KEYS[4])

  if sz > lim then
    -- Bucket is full: only update elements that already exist
    if redis.call('HEXISTS', KEYS[1], KEYS[2]) == 1 then
      redis.call('HINCRBYFLOAT', KEYS[1], KEYS[2], math.abs(KEYS[3]))
    end
    redis.call('HINCRBYFLOAT', KEYS[1], '__s', KEYS[3])
    redis.call('EXPIRE', KEYS[1], KEYS[6])
  else
    -- Also creates the bucket when it does not exist yet (HLEN == 0)
    redis.call('HINCRBYFLOAT', KEYS[1], KEYS[2], math.abs(KEYS[3]))
    redis.call('HINCRBYFLOAT', KEYS[1], '__s', KEYS[3])
    redis.call('EXPIRE', KEYS[1], KEYS[5])
  end
  ]]
  local update_cluster_id
  -- Callbacks and logic

  --- Filter-phase callback for one clustering rule.
  -- Evaluates both selectors, queries the statistics stored under the source
  -- key with query_cluster_script and inserts `rule.symbol` when the combined
  -- score exceeds 0.1.
  -- @param task rspamd task being scanned
  -- @param rule processed rule (selectors are compiled closures, see init part)
  local function clusterting_filter_cb(task, rule)
    local source_selector = rule.source_selector(task)
    local cluster_selector

    if source_selector then
      cluster_selector = rule.cluster_selector(task)
    end

    if not cluster_selector or not source_selector then
      rspamd_logger.debugm(N, task, 'skip rule %s, selectors: source="%s", cluster="%s"',
          rule.name, source_selector, cluster_selector)
      return
    end

    -- Converts the cluster statistics into a symbol score capped at 1.0
    local function combine_scores(cur_elts, total_score, element_score)
      if cur_elts <= 0 then
        -- Nothing has been learned for this source yet. The else branch below
        -- would compute 0/0 (NaN) here, and math.min(1.0, NaN) can yield 1.0.
        rspamd_logger.debugm(N, task,
            'skip rule %s, selectors: source="%s", cluster="%s"; empty cluster',
            rule.name, source_selector, cluster_selector)
        return
      end

      local size_score = cur_elts * rule.size_mult
      local cluster_score = total_score * rule.score_mult
      local final_score

      -- Elements accumulate |score| (see update_cluster_script), so
      -- element_score > 0 means the element has been learned in this cluster
      if element_score > 0 then
        final_score = math.min(1.0, size_score + cluster_score)
      else
        -- Unknown element: subtract the average cluster score from the size score
        final_score = math.min(1.0, size_score - cluster_score / cur_elts)
      end

      rspamd_logger.debugm(N, task,
          'processed rule %s, selectors: source="%s", cluster="%s"; data: %s elts, %s score, %s elt score',
          rule.name, source_selector, cluster_selector, cur_elts, total_score, element_score)

      if final_score > 0.1 then
        task:insert_result(rule.symbol, final_score, {source_selector,
                                                      tostring(size_score),
                                                      tostring(cluster_score)})
      end
    end

    local function redis_get_cb(err, data)
      if data then
        if type(data) == 'table' then
          local cur_elts, total_score, element_score =
              tonumber(data[1]), tonumber(data[2]), tonumber(data[3])

          if cur_elts and total_score and element_score then
            combine_scores(cur_elts, total_score, element_score)
          else
            -- Do not feed nil into the arithmetic in combine_scores
            rspamd_logger.errx(task, 'invalid data while getting clustering keys %s: %s',
                source_selector, data)
          end
        else
          rspamd_logger.errx(task, 'invalid type while getting clustering keys %s: %s',
              source_selector, type(data))
        end
      elseif err then
        rspamd_logger.errx(task, 'got error while getting clustering keys %s: %s',
            source_selector, err)
      else
        rspamd_logger.errx(task, 'got error while getting clustering keys %s: %s',
            source_selector, "unknown error")
      end
    end

    lua_redis.exec_redis_script(query_cluster_id,
        {task = task, is_write = false, key = source_selector},
        redis_get_cb,
        {source_selector, cluster_selector})
  end
  --- Idempotent-phase callback for one clustering rule.
  -- Maps the task verdict (ham/spam/junk) to the rule's multiplier and stores
  -- it in Redis with update_cluster_script. Tasks flagged 'skip' are ignored,
  -- and so are rspamc/controller scans unless `rule.allow_local` is set.
  -- @param task rspamd task being processed
  -- @param rule processed rule (selectors are compiled closures, see init part)
  local function clusterting_idempotent_cb(task, rule)
    if task:has_flag('skip') then return end
    if not rule.allow_local and lua_util.is_rspamc_or_controller(task) then return end

    local verdict = lua_verdict.get_specific_verdict(N, task)
    local score

    if verdict == 'ham' then
      score = rule.ham_mult
    elseif verdict == 'spam' then
      score = rule.spam_mult
    elseif verdict == 'junk' then
      score = rule.junk_mult
    else
      rspamd_logger.debugm(N, task, 'skip rule %s, verdict=%s',
          rule.name, verdict)
      return
    end

    local source_selector = rule.source_selector(task)
    local cluster_selector

    if source_selector then
      cluster_selector = rule.cluster_selector(task)
    end

    if not cluster_selector or not source_selector then
      rspamd_logger.debugm(N, task, 'skip rule %s, selectors: source="%s", cluster="%s"',
          rule.name, source_selector, cluster_selector)
      return
    end

    local function redis_set_cb(err, _)
      if err then
        rspamd_logger.errx(task, 'got error while setting clustering keys %s: %s',
            source_selector, err)
      else
        -- Four placeholders: rule, source key, cluster element, learned score
        rspamd_logger.debugm(N, task, 'set clustering key for %s: %s{%s} = %s',
            rule.name, source_selector, cluster_selector, score)
      end
    end

    lua_redis.exec_redis_script(update_cluster_id,
        {task = task, is_write = true, key = source_selector},
        redis_set_cb,
        {
          source_selector,
          cluster_selector,
          tostring(score),
          tostring(rule.max_elts),
          tostring(rule.expire),
          tostring(rule.expire_overflow)
        }
    )
  end
  -- Init part
  redis_params = lua_redis.parse_redis_server('clustering')
  local opts = rspamd_config:get_all_opt("clustering")

  -- Initialization part: a configuration table and a Redis server are both
  -- required, otherwise the module is disabled
  if not (opts and type(opts) == 'table') then
    lua_util.disable_module(N, "config")
    return
  end

  if not redis_params then
    lua_util.disable_module(N, "redis")
    return
  end

  local rules_opts = opts['rules']

  if rules_opts ~= nil and type(rules_opts) ~= 'table' then
    -- pairs() below would raise on a non-table value
    rspamd_logger.errx(rspamd_config, 'invalid clustering rules: table expected, got %s',
        type(rules_opts))
  elseif rules_opts then
    for k,v in pairs(rules_opts) do
      local raw_rule = lua_util.override_defaults(default_rule, v)
      local rule,err = rule_schema:transform(raw_rule)

      if not rule then
        rspamd_logger.errx(rspamd_config, 'invalid clustering rule %s: %s',
            k, err)
      else
        if not rule.symbol then rule.symbol = k end
        -- NOTE(review): prefix is not applied to any Redis key in this file;
        -- the scripts are keyed by the source selector value directly
        if not rule.prefix then rule.prefix = k .. "_" end

        -- Keep the textual selectors for diagnostics before compiling them
        local source_str, cluster_str = rule.source_selector, rule.cluster_selector
        rule.source_selector = lua_selectors.create_selector_closure(rspamd_config,
            source_str, '')
        rule.cluster_selector = lua_selectors.create_selector_closure(rspamd_config,
            cluster_str, '')

        if rule.source_selector and rule.cluster_selector then
          rule.name = k
          table.insert(rules, rule)
        else
          -- Previously such rules were dropped without any message
          rspamd_logger.errx(rspamd_config,
              'invalid selectors in clustering rule %s: source="%s", cluster="%s"',
              k, source_str, cluster_str)
        end
      end
    end
  end

  if #rules > 0 then
    query_cluster_id = lua_redis.add_redis_script(query_cluster_script, redis_params)
    update_cluster_id = lua_redis.add_redis_script(update_cluster_script, redis_params)

    -- Binds a rule to a callback so rspamd can call it with the task only
    local function callback_gen(f, rule)
      return function(task) return f(task, rule) end
    end

    for _,rule in ipairs(rules) do
      rspamd_config:register_symbol{
        name = rule.symbol,
        type = 'normal',
        callback = callback_gen(clusterting_filter_cb, rule),
      }
      -- Records the verdict-based score (see clusterting_idempotent_cb)
      rspamd_config:register_symbol{
        name = rule.symbol .. '_STORE',
        type = 'idempotent',
        callback = callback_gen(clusterting_idempotent_cb, rule),
      }
    end
  else
    lua_util.disable_module(N, "config")
  end