You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

bayes_expiry.lua 14KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503
  1. --[[
  2. Copyright (c) 2017, Andrew Lewis <nerf@judo.za.org>
  3. Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
  4. Licensed under the Apache License, Version 2.0 (the "License");
  5. you may not use this file except in compliance with the License.
  6. You may obtain a copy of the License at
  7. http://www.apache.org/licenses/LICENSE-2.0
  8. Unless required by applicable law or agreed to in writing, software
  9. distributed under the License is distributed on an "AS IS" BASIS,
  10. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  11. See the License for the specific language governing permissions and
  12. limitations under the License.
  13. ]] --
  -- Stop loading this file as soon as the `confighelp` global is set.
  if confighelp then return end

  -- N: name used for this module's config section and debug log channel.
  -- E: shared empty table used as a fallback for a missing config section.
  local N = 'bayes_expiry'
  local E = {}

  local logger = require("rspamd_logger")
  local rspamd_util = require("rspamd_util")
  local lutil = require("lua_util")
  local lredis = require("lua_redis")

  -- Defaults; any key may be overridden by the module's config section.
  local settings = {
    -- one iteration step per minute
    interval = 60,
    -- check up to 1000 keys on each iteration
    count = 1000,
    -- eliminate common if spam to ham rate is equal to this epsilon
    epsilon_common = 0.01,
    -- TTL of discriminated common elements
    common_ttl = 10 * 86400,
    -- which tokens should we update
    significant_factor = 3.0 / 4.0,
    -- filled by check_redis_classifier() with eligible classifiers
    classifiers = {},
    -- 0 means "count the configured neighbours"
    cluster_nodes = 0,
  }

  -- Values substituted into the `${...}` placeholders of the Redis script.
  local template = {}
  --- Inspect one Redis-backed classifier definition and, when it is eligible,
  -- append an entry to `settings.classifiers` so that it gets expired
  -- periodically.
  --
  -- A classifier is eligible when it uses the new schema, defines both a spam
  -- and a ham statfile symbol, has a numeric `expiry`/`expire` value and its
  -- Redis configuration is not read-only.
  --
  -- @param cls classifier configuration table (reads `new_schema`, `expiry`,
  --            `expire`, `statfile` and whatever the Redis loader needs)
  -- @param cfg whole configuration table; `cfg[N]` is used as a fallback
  --            source of Redis settings
  -- @return `false` when no Redis servers could be resolved, nothing otherwise
  local function check_redis_classifier(cls, cfg)
    -- Skip old classifiers
    if not cls.new_schema then
      return
    end

    local symbol_spam, symbol_ham

    local expiry = (cls.expiry or cls.expire)
    if type(expiry) == 'table' then
      expiry = expiry[1]
    end

    -- Load symbols from statfiles
    local function check_statfile_table(tbl, def_sym)
      local symbol = tbl.symbol or def_sym

      local spam
      if tbl.spam ~= nil then
        -- An explicit flag always wins, including an explicit `spam = false`
        -- on a statfile whose symbol happens to contain "SPAM".
        spam = tbl.spam
      else
        -- No flag given: guess the class from the symbol name
        spam = string.match(symbol:upper(), 'SPAM') ~= nil
      end

      if spam then
        symbol_spam = symbol
      else
        symbol_ham = symbol
      end
    end

    local statfiles = cls.statfile
    if type(statfiles) ~= 'table' then
      -- Nothing to derive the symbols from; skip instead of raising an
      -- "attempt to index a nil value" error.
      logger.debugm(N, rspamd_config,
          'disable expiry for classifier: no statfiles defined')
      return
    end

    if statfiles[1] then
      -- Array form: each element is either a statfile itself or a table of
      -- named statfiles
      for _, stf in ipairs(statfiles) do
        if not stf.symbol then
          for k, v in pairs(stf) do
            check_statfile_table(v, k)
          end
        else
          check_statfile_table(stf, 'undefined')
        end
      end
    else
      -- Map form: the key is the default symbol name
      for stn, stf in pairs(statfiles) do
        check_statfile_table(stf, stn)
      end
    end

    if not symbol_spam or not symbol_ham or type(expiry) ~= 'number' then
      logger.debugm(N, rspamd_config,
          'disable expiry for classifier %s: no expiry %s',
          symbol_spam, cls)
      return
    end

    -- Now try to load redis_params if needed: first from the classifier,
    -- then from this module's own section, and finally from this module's
    -- section with the extra flag set and no 'bayes' argument.
    local fallback = cfg[N] or E
    local redis_params = lredis.try_load_redis_servers(cls, rspamd_config, false, 'bayes')
        or lredis.try_load_redis_servers(fallback, rspamd_config, false, 'bayes')
        or lredis.try_load_redis_servers(fallback, rspamd_config, true)
    if not redis_params then
      return false
    end

    if redis_params['read_only'] then
      logger.infox(rspamd_config, 'disable expiry for classifier %s: read only redis configuration',
          symbol_spam)
      return
    end

    logger.debugm(N, rspamd_config, "enabled expiry for %s/%s -> %s expiry",
        symbol_spam, symbol_ham, expiry)

    table.insert(settings.classifiers, {
      symbol_spam = symbol_spam,
      symbol_ham = symbol_ham,
      redis_params = redis_params,
      expiry = expiry
    })
  end
  -- Walk the `classifier` section of the configuration and hand every
  -- classifier whose backend is 'redis' to check_redis_classifier().
  local obj = rspamd_config:get_ucl()
  local classifier = obj.classifier

  if classifier and classifier[1] then
    -- List of classifiers; an element may wrap its definition in `bayes`
    for _, entry in ipairs(classifier) do
      local cls = entry.bayes or entry
      if cls.backend == 'redis' then
        check_redis_classifier(cls, obj)
      end
    end
  elseif classifier and classifier.bayes then
    -- Single `bayes` key holding either one definition or a list of them
    classifier = classifier.bayes
    local candidates = classifier[1] and classifier or { classifier }
    for _, cls in ipairs(candidates) do
      if cls.backend == 'redis' then
        check_redis_classifier(cls, obj)
      end
    end
  end
  -- Overlay the defaults with whatever the module's config section provides
  local opts = rspamd_config:get_all_opt(N)
  if opts then
    for name, value in pairs(opts) do
      settings[name] = value
    end
  end

  -- When `cluster_nodes` is left at 0, set it to the number of entries in the
  -- `neighbours` section of the configuration.
  -- NOTE(review): `settings.cluster_nodes` is not read anywhere else in the
  -- visible source, so the interval is not actually scaled by it here.
  if settings.cluster_nodes == 0 then
    local node_count = 0
    for _ in pairs(obj.neighbours or {}) do
      node_count = node_count + 1
    end
    settings.cluster_nodes = node_count
  end

  -- Fill template: placeholder name -> settings key it is taken from
  do
    local template_sources = {
      count = 'count',
      threshold = 'threshold',
      common_ttl = 'common_ttl',
      epsilon_common = 'epsilon_common',
      significant_factor = 'significant_factor',
      expire_step = 'interval',
    }
    for placeholder, settings_key in pairs(template_sources) do
      template[placeholder] = settings[settings_key]
    end
    template.hostname = rspamd_util.get_hostname()

    -- Every placeholder value is stored as a string
    for placeholder, value in pairs(template) do
      template[placeholder] = tostring(value)
    end
  end
  -- Redis-side script performing one expiry step. The `${...}` placeholders
  -- are substituted from `template` before the script is registered.
  -- Arguments (read from KEYS):
  -- [1] = symbol pattern
  -- [2] = expire value
  -- The scan cursor is NOT an argument: it is kept in Redis between steps
  -- under the key `<sha1 of pattern>_cursor`.
  -- returns the string 'locked by <host>' when another host holds the lock, otherwise
  -- {cursor for the next step, step number, step statistic counters, cycle statistic counters, tokens occurrences distribution}
  local expiry_script = [[
    local unpack_function = table.unpack or unpack

    -- Flatten {k = v, ...} into {k, v, ...} as expected by HMSET
    local hash2list = function (hash)
      local res = {}
      for k, v in pairs(hash) do
        table.insert(res, k)
        table.insert(res, v)
      end
      return res
    end

    -- Merge a flat {k, v, ...} list (HGETALL reply) into tbl
    local function merge_list(tbl, list)
      local k
      for i, v in ipairs(list) do
        if i % 2 == 1 then
          k = v
        else
          tbl[k] = v
        end
      end
    end

    local expire = math.floor(KEYS[2])
    local pattern_sha1 = redis.sha1hex(KEYS[1])

    local lock_key = pattern_sha1 .. '_lock' -- Check locking
    local lock = redis.call('GET', lock_key)

    if lock then
      if lock ~= '${hostname}' then
        return 'locked by ' .. lock
      end
    end

    redis.replicate_commands()
    redis.call('SETEX', lock_key, ${expire_step}, '${hostname}')

    local cursor_key = pattern_sha1 .. '_cursor'
    local cursor = tonumber(redis.call('GET', cursor_key) or 0)

    local step = 1
    local step_key = pattern_sha1 .. '_step'
    if cursor > 0 then
      step = redis.call('GET', step_key)
      step = step and (tonumber(step) + 1) or 1
    end

    local ret = redis.call('SCAN', cursor, 'MATCH', KEYS[1], 'COUNT', '${count}')
    local next_cursor = ret[1]
    local keys = ret[2]
    local tokens = {}

    -- Tokens occurrences distribution counters
    local occur = {
      ham = {},
      spam = {},
      total = {}
    }

    -- Expiry step statistics counters
    local nelts, extended, discriminated, sum, sum_squares, common, significant,
      infrequent, infrequent_ttls_set, insignificant, insignificant_ttls_set =
      0,0,0,0,0,0,0,0,0,0,0

    for _,key in ipairs(keys) do
      local t = redis.call('TYPE', key)["ok"]
      if t == 'hash' then
        local values = redis.call('HMGET', key, 'H', 'S')
        local ham = tonumber(values[1]) or 0
        local spam = tonumber(values[2]) or 0
        local ttl = redis.call('TTL', key)
        tokens[key] = {
          ham,
          spam,
          ttl
        }
        local total = spam + ham
        sum = sum + total
        sum_squares = sum_squares + total * total
        nelts = nelts + 1

        -- Counts above 19 all go to bucket 20
        for k,v in pairs({['ham']=ham, ['spam']=spam, ['total']=total}) do
          if tonumber(v) > 19 then v = 20 end
          occur[k][v] = occur[k][v] and occur[k][v] + 1 or 1
        end
      end
    end

    local mean, stddev = 0, 0

    if nelts > 0 then
      mean = sum / nelts
      -- Clamp at zero: rounding may make the computed variance slightly
      -- negative, which would turn the square root into NaN
      stddev = math.sqrt(math.max(0, sum_squares / nelts - mean * mean))
    end

    for key,token in pairs(tokens) do
      local ham, spam, ttl = token[1], token[2], tonumber(token[3])
      local threshold = mean
      local total = spam + ham

      -- Apply the configured expire to the key; returns 1 when the TTL
      -- was changed and 0 otherwise
      local function set_ttl()
        if expire < 0 then
          if ttl ~= -1 then
            redis.call('PERSIST', key)
            return 1
          end
        elseif ttl == -1 or ttl > expire then
          redis.call('EXPIRE', key, expire)
          return 1
        end
        return 0
      end

      if total == 0 or math.abs(ham - spam) <= total * ${epsilon_common} then
        common = common + 1
        if ttl > ${common_ttl} then
          discriminated = discriminated + 1
          redis.call('EXPIRE', key, ${common_ttl})
        end
      elseif total >= threshold and total > 0 then
        if ham / total > ${significant_factor} or spam / total > ${significant_factor} then
          significant = significant + 1
          if ttl ~= -1 then
            redis.call('PERSIST', key)
            extended = extended + 1
          end
        else
          insignificant = insignificant + 1
          insignificant_ttls_set = insignificant_ttls_set + set_ttl()
        end
      else
        infrequent = infrequent + 1
        infrequent_ttls_set = infrequent_ttls_set + set_ttl()
      end
    end

    -- Expiry cycle statistics counters
    local c = {nelts = 0, extended = 0, discriminated = 0, sum = 0, sum_squares = 0,
      common = 0, significant = 0, infrequent = 0, infrequent_ttls_set = 0, insignificant = 0, insignificant_ttls_set = 0}

    local counters_key = pattern_sha1 .. '_counters'

    if cursor ~= 0 then
      merge_list(c, redis.call('HGETALL', counters_key))
    end

    c.nelts = c.nelts + nelts
    c.extended = c.extended + extended
    c.discriminated = c.discriminated + discriminated
    c.sum = c.sum + sum
    c.sum_squares = c.sum_squares + sum_squares
    c.common = c.common + common
    c.significant = c.significant + significant
    c.infrequent = c.infrequent + infrequent
    c.infrequent_ttls_set = c.infrequent_ttls_set + infrequent_ttls_set
    c.insignificant = c.insignificant + insignificant
    c.insignificant_ttls_set = c.insignificant_ttls_set + insignificant_ttls_set

    redis.call('HMSET', counters_key, unpack_function(hash2list(c)))
    redis.call('SET', cursor_key, tostring(next_cursor))
    redis.call('SET', step_key, tostring(step))
    redis.call('DEL', lock_key)

    local occ_distr = {}
    -- ipairs keeps the ham, spam, total order of the returned list
    for _,cl in ipairs({'ham', 'spam', 'total'}) do
      local occur_key = pattern_sha1 .. '_occurrence_' .. cl

      if cursor ~= 0 then
        -- Continue a cycle: add the counters accumulated by previous steps
        local n
        for i,v in ipairs(redis.call('HGETALL', occur_key)) do
          if i % 2 == 1 then
            n = tonumber(v)
          else
            occur[cl][n] = occur[cl][n] and occur[cl][n] + v or v
          end
        end
      else
        -- First step of a cycle: drop the counters of the previous cycle
        redis.call('DEL', occur_key)
      end

      -- Build the distribution string on every step (a cycle may complete
      -- in a single step) and walk all buckets explicitly: the table is
      -- sparse, so ipairs would stop at the first missing count
      local parts = {}
      for k = 0, 20 do
        local v = occur[cl][k]
        if v ~= nil then
          local label = (k == 20) and '>19' or k
          parts[#parts + 1] = label .. ':' .. v .. ','
        end
      end
      table.insert(occ_distr, table.concat(parts))

      if next(occur[cl]) ~= nil then
        redis.call('HMSET', occur_key, unpack_function(hash2list(occur[cl])))
      end
    end

    return {
      next_cursor, step,
      {nelts, extended, discriminated, mean, stddev, common, significant, infrequent,
       infrequent_ttls_set, insignificant, insignificant_ttls_set},
      {c.nelts, c.extended, c.discriminated, c.sum, c.sum_squares, c.common,
       c.significant, c.infrequent, c.infrequent_ttls_set, c.insignificant, c.insignificant_ttls_set},
      occ_distr
    }
  ]]
  --- Run one expiry step for a classifier by executing its Redis script and
  -- logging the statistics the script returns.
  --
  -- @param cls classifier entry from `settings.classifiers`; reads `script`
  --            (registered script id) and `expiry`
  -- @param ev_base event base passed through to the Redis call
  -- @param worker unused here; kept for the caller's signature
  local function expire_step(cls, ev_base, worker)
    local function redis_step_cb(err, args)
      if err then
        logger.errx(rspamd_config, 'cannot perform expiry step: %s', err)
      elseif type(args) == 'table' then
        -- Reply layout: {next cursor, step number, step counters,
        -- cycle counters, occurrences distribution}
        local cur = tonumber(args[1])
        local step = args[2]
        local data = args[3]
        local c_data = args[4]
        -- Tolerate a reply without the distribution list
        local occ_distr = args[5] or {}

        -- Log either the counters of this step (cycle == false) or the
        -- counters accumulated over the whole cycle (cycle == true)
        local function log_stat(cycle)
          local infrequent_action = (cls.expiry < 0) and 'made persistent' or 'ttls set'

          local c_mean, c_stddev = 0, 0
          if cycle and c_data[1] ~= 0 then
            c_mean = c_data[4] / c_data[1]
            -- Clamp at zero: rounding may make the computed variance
            -- slightly negative, and sqrt of that would log NaN
            local variance = math.max(0, c_data[5] / c_data[1] - c_mean * c_mean)
            c_stddev = math.floor(.5 + math.sqrt(variance))
            c_mean = math.floor(.5 + c_mean)
          end

          local d = cycle and {
            'cycle in ' .. step .. ' steps', c_data[1],
            c_data[7], c_data[2], 'made persistent',
            c_data[10], c_data[11], infrequent_action,
            c_data[6], c_data[3],
            c_data[8], c_data[9], infrequent_action,
            c_mean,
            c_stddev
          } or {
            'step ' .. step, data[1],
            data[7], data[2], 'made persistent',
            data[10], data[11], infrequent_action,
            data[6], data[3],
            data[8], data[9], infrequent_action,
            data[4],
            data[5]
          }
          logger.infox(rspamd_config,
              'finished expiry %s: %s items checked, %s significant (%s %s), ' ..
                  '%s insignificant (%s %s), %s common (%s discriminated), ' ..
                  '%s infrequent (%s %s), %s mean, %s std',
              lutil.unpack(d))
          if cycle then
            for i, cl in ipairs({ 'in ham', 'in spam', 'total' }) do
              -- Log an empty distribution rather than nil when the script
              -- returned no entry for this class
              logger.infox(rspamd_config, 'tokens occurrences, %s: {%s}', cl, occ_distr[i] or '')
            end
          end
        end

        log_stat(false)
        -- A zero cursor means the scan wrapped around: the cycle is complete
        if cur == 0 then
          log_stat(true)
        end
      elseif type(args) == 'string' then
        -- The script returns a plain string when it skips the step
        logger.infox(rspamd_config, 'skip expiry step: %s', args)
      end
    end

    lredis.exec_redis_script(cls.script,
        { ev_base = ev_base, is_write = true },
        redis_step_cb,
        { 'RS*_*', cls.expiry }
    )
  end
  -- On load: register the expiry script on every distinct Redis configuration
  -- and schedule a periodic expiry step for each eligible classifier.
  rspamd_config:add_on_load(function(_, ev_base, worker)
    -- Exit unless we're the first 'controller' worker
    if not worker:is_primary_controller() then
      return
    end

    local classifiers = settings.classifiers

    -- Collect one redis_params table per distinct `hash` (first one wins)
    local params_by_hash = {}
    for _, cls in ipairs(classifiers) do
      local hash = cls.redis_params.hash
      params_by_hash[hash] = params_by_hash[hash] or cls.redis_params
    end

    -- Push redis script to all unique redis servers and remember the
    -- resulting script id on every classifier using that configuration
    for hash, redis_params in pairs(params_by_hash) do
      local rendered = lutil.template(expiry_script, template)
      local script_id = lredis.add_redis_script(rendered, redis_params)

      for _, cls in ipairs(classifiers) do
        if cls.redis_params.hash == hash then
          cls.script = script_id
        end
      end
    end

    -- Expire tokens at regular intervals
    for _, cls in ipairs(classifiers) do
      local function run_step()
        expire_step(cls, ev_base, worker)
        -- the periodic callback always returns true
        return true
      end

      rspamd_config:add_periodic(ev_base, settings['interval'], run_step, true)
    end
  end)