You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

bayes_expiry.lua 14KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498
  1. --[[
  2. Copyright (c) 2017, Andrew Lewis <nerf@judo.za.org>
  3. Copyright (c) 2017, Vsevolod Stakhov <vsevolod@highsecure.ru>
  4. Licensed under the Apache License, Version 2.0 (the "License");
  5. you may not use this file except in compliance with the License.
  6. You may obtain a copy of the License at
  7. http://www.apache.org/licenses/LICENSE-2.0
  8. Unless required by applicable law or agreed to in writing, software
  9. distributed under the License is distributed on an "AS IS" BASIS,
  10. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  11. See the License for the specific language governing permissions and
  12. limitations under the License.
  13. ]] --
  14. if confighelp then
  15. return
  16. end
  17. local N = 'bayes_expiry'
  18. local E = {}
  19. local logger = require "rspamd_logger"
  20. local rspamd_util = require "rspamd_util"
  21. local lutil = require "lua_util"
  22. local lredis = require "lua_redis"
  23. local settings = {
  24. interval = 60, -- one iteration step per minute
  25. count = 1000, -- check up to 1000 keys on each iteration
  26. epsilon_common = 0.01, -- eliminate common if spam to ham rate is equal to this epsilon
  27. common_ttl = 10 * 86400, -- TTL of discriminated common elements
  28. significant_factor = 3.0 / 4.0, -- which tokens should we update
  29. classifiers = {},
  30. cluster_nodes = 0,
  31. }
  32. local template = {}
  33. local function check_redis_classifier(cls, cfg)
  34. -- Skip old classifiers
  35. if cls.new_schema then
  36. local symbol_spam, symbol_ham
  37. local expiry = (cls.expiry or cls.expire)
  38. if type(expiry) == 'table' then
  39. expiry = expiry[1]
  40. end
  41. -- Load symbols from statfiles
  42. local function check_statfile_table(tbl, def_sym)
  43. local symbol = tbl.symbol or def_sym
  44. local spam
  45. if tbl.spam then
  46. spam = tbl.spam
  47. else
  48. if string.match(symbol:upper(), 'SPAM') then
  49. spam = true
  50. else
  51. spam = false
  52. end
  53. end
  54. if spam then
  55. symbol_spam = symbol
  56. else
  57. symbol_ham = symbol
  58. end
  59. end
  60. local statfiles = cls.statfile
  61. if statfiles[1] then
  62. for _,stf in ipairs(statfiles) do
  63. if not stf.symbol then
  64. for k,v in pairs(stf) do
  65. check_statfile_table(v, k)
  66. end
  67. else
  68. check_statfile_table(stf, 'undefined')
  69. end
  70. end
  71. else
  72. for stn,stf in pairs(statfiles) do
  73. check_statfile_table(stf, stn)
  74. end
  75. end
  76. if not symbol_spam or not symbol_ham or type(expiry) ~= 'number' then
  77. logger.debugm(N, rspamd_config,
  78. 'disable expiry for classifier %s: no expiry %s',
  79. symbol_spam, cls)
  80. return
  81. end
  82. -- Now try to load redis_params if needed
  83. local redis_params
  84. redis_params = lredis.try_load_redis_servers(cls, rspamd_config, false, 'bayes')
  85. if not redis_params then
  86. redis_params = lredis.try_load_redis_servers(cfg[N] or E, rspamd_config, false, 'bayes')
  87. if not redis_params then
  88. redis_params = lredis.try_load_redis_servers(cfg[N] or E, rspamd_config, true)
  89. if not redis_params then
  90. return false
  91. end
  92. end
  93. end
  94. if redis_params['read_only'] then
  95. logger.infox(rspamd_config, 'disable expiry for classifier %s: read only redis configuration',
  96. symbol_spam)
  97. return
  98. end
  99. logger.debugm(N, rspamd_config, "enabled expiry for %s/%s -> %s expiry",
  100. symbol_spam, symbol_ham, expiry)
  101. table.insert(settings.classifiers, {
  102. symbol_spam = symbol_spam,
  103. symbol_ham = symbol_ham,
  104. redis_params = redis_params,
  105. expiry = expiry
  106. })
  107. end
  108. end
  109. -- Check classifiers and try find the appropriate ones
  110. local obj = rspamd_config:get_ucl()
  111. local classifier = obj.classifier
  112. if classifier then
  113. if classifier[1] then
  114. for _,cls in ipairs(classifier) do
  115. if cls.bayes then cls = cls.bayes end
  116. if cls.backend and cls.backend == 'redis' then
  117. check_redis_classifier(cls, obj)
  118. end
  119. end
  120. else
  121. if classifier.bayes then
  122. classifier = classifier.bayes
  123. if classifier[1] then
  124. for _,cls in ipairs(classifier) do
  125. if cls.backend and cls.backend == 'redis' then
  126. check_redis_classifier(cls, obj)
  127. end
  128. end
  129. else
  130. if classifier.backend and classifier.backend == 'redis' then
  131. check_redis_classifier(classifier, obj)
  132. end
  133. end
  134. end
  135. end
  136. end
  137. local opts = rspamd_config:get_all_opt(N)
  138. if opts then
  139. for k,v in pairs(opts) do
  140. settings[k] = v
  141. end
  142. end
  143. -- In clustered setup, we need to increase interval of expiration
  144. -- according to number of nodes in a cluster
  145. if settings.cluster_nodes == 0 then
  146. local neighbours = obj.neighbours or {}
  147. local n_neighbours = 0
  148. for _,_ in pairs(neighbours) do n_neighbours = n_neighbours + 1 end
  149. settings.cluster_nodes = n_neighbours
  150. end
  151. -- Fill template
  152. template.count = settings.count
  153. template.threshold = settings.threshold
  154. template.common_ttl = settings.common_ttl
  155. template.epsilon_common = settings.epsilon_common
  156. template.significant_factor = settings.significant_factor
  157. template.expire_step = settings.interval
  158. template.hostname = rspamd_util.get_hostname()
  159. for k,v in pairs(template) do
  160. template[k] = tostring(v)
  161. end
  162. -- Arguments:
  163. -- [1] = symbol pattern
  164. -- [2] = expire value
  165. -- [3] = cursor
  166. -- returns {cursor for the next step, step number, step statistic counters, cycle statistic counters, tokens occurrences distribution}
  167. local expiry_script = [[
  168. local unpack_function = table.unpack or unpack
  169. local hash2list = function (hash)
  170. local res = {}
  171. for k, v in pairs(hash) do
  172. table.insert(res, k)
  173. table.insert(res, v)
  174. end
  175. return res
  176. end
  177. local function merge_list(table, list)
  178. local k
  179. for i, v in ipairs(list) do
  180. if i % 2 == 1 then
  181. k = v
  182. else
  183. table[k] = v
  184. end
  185. end
  186. end
  187. local expire = math.floor(KEYS[2])
  188. local pattern_sha1 = redis.sha1hex(KEYS[1])
  189. local lock_key = pattern_sha1 .. '_lock' -- Check locking
  190. local lock = redis.call('GET', lock_key)
  191. if lock then
  192. if lock ~= '${hostname}' then
  193. return 'locked by ' .. lock
  194. end
  195. end
  196. redis.replicate_commands()
  197. redis.call('SETEX', lock_key, ${expire_step}, '${hostname}')
  198. local cursor_key = pattern_sha1 .. '_cursor'
  199. local cursor = tonumber(redis.call('GET', cursor_key) or 0)
  200. local step = 1
  201. local step_key = pattern_sha1 .. '_step'
  202. if cursor > 0 then
  203. step = redis.call('GET', step_key)
  204. step = step and (tonumber(step) + 1) or 1
  205. end
  206. local ret = redis.call('SCAN', cursor, 'MATCH', KEYS[1], 'COUNT', '${count}')
  207. local next_cursor = ret[1]
  208. local keys = ret[2]
  209. local tokens = {}
  210. -- Tokens occurrences distribution counters
  211. local occur = {
  212. ham = {},
  213. spam = {},
  214. total = {}
  215. }
  216. -- Expiry step statistics counters
  217. local nelts, extended, discriminated, sum, sum_squares, common, significant,
  218. infrequent, infrequent_ttls_set, insignificant, insignificant_ttls_set =
  219. 0,0,0,0,0,0,0,0,0,0,0
  220. for _,key in ipairs(keys) do
  221. local t = redis.call('TYPE', key)["ok"]
  222. if t == 'hash' then
  223. local values = redis.call('HMGET', key, 'H', 'S')
  224. local ham = tonumber(values[1]) or 0
  225. local spam = tonumber(values[2]) or 0
  226. local ttl = redis.call('TTL', key)
  227. tokens[key] = {
  228. ham,
  229. spam,
  230. ttl
  231. }
  232. local total = spam + ham
  233. sum = sum + total
  234. sum_squares = sum_squares + total * total
  235. nelts = nelts + 1
  236. for k,v in pairs({['ham']=ham, ['spam']=spam, ['total']=total}) do
  237. if tonumber(v) > 19 then v = 20 end
  238. occur[k][v] = occur[k][v] and occur[k][v] + 1 or 1
  239. end
  240. end
  241. end
  242. local mean, stddev = 0, 0
  243. if nelts > 0 then
  244. mean = sum / nelts
  245. stddev = math.sqrt(sum_squares / nelts - mean * mean)
  246. end
  247. for key,token in pairs(tokens) do
  248. local ham, spam, ttl = token[1], token[2], tonumber(token[3])
  249. local threshold = mean
  250. local total = spam + ham
  251. local function set_ttl()
  252. if expire < 0 then
  253. if ttl ~= -1 then
  254. redis.call('PERSIST', key)
  255. return 1
  256. end
  257. elseif ttl == -1 or ttl > expire then
  258. redis.call('EXPIRE', key, expire)
  259. return 1
  260. end
  261. return 0
  262. end
  263. if total == 0 or math.abs(ham - spam) <= total * ${epsilon_common} then
  264. common = common + 1
  265. if ttl > ${common_ttl} then
  266. discriminated = discriminated + 1
  267. redis.call('EXPIRE', key, ${common_ttl})
  268. end
  269. elseif total >= threshold and total > 0 then
  270. if ham / total > ${significant_factor} or spam / total > ${significant_factor} then
  271. significant = significant + 1
  272. if ttl ~= -1 then
  273. redis.call('PERSIST', key)
  274. extended = extended + 1
  275. end
  276. else
  277. insignificant = insignificant + 1
  278. insignificant_ttls_set = insignificant_ttls_set + set_ttl()
  279. end
  280. else
  281. infrequent = infrequent + 1
  282. infrequent_ttls_set = infrequent_ttls_set + set_ttl()
  283. end
  284. end
  285. -- Expiry cycle statistics counters
  286. local c = {nelts = 0, extended = 0, discriminated = 0, sum = 0, sum_squares = 0,
  287. common = 0, significant = 0, infrequent = 0, infrequent_ttls_set = 0, insignificant = 0, insignificant_ttls_set = 0}
  288. local counters_key = pattern_sha1 .. '_counters'
  289. if cursor ~= 0 then
  290. merge_list(c, redis.call('HGETALL', counters_key))
  291. end
  292. c.nelts = c.nelts + nelts
  293. c.extended = c.extended + extended
  294. c.discriminated = c.discriminated + discriminated
  295. c.sum = c.sum + sum
  296. c.sum_squares = c.sum_squares + sum_squares
  297. c.common = c.common + common
  298. c.significant = c.significant + significant
  299. c.infrequent = c.infrequent + infrequent
  300. c.infrequent_ttls_set = c.infrequent_ttls_set + infrequent_ttls_set
  301. c.insignificant = c.insignificant + insignificant
  302. c.insignificant_ttls_set = c.insignificant_ttls_set + insignificant_ttls_set
  303. redis.call('HMSET', counters_key, unpack_function(hash2list(c)))
  304. redis.call('SET', cursor_key, tostring(next_cursor))
  305. redis.call('SET', step_key, tostring(step))
  306. redis.call('DEL', lock_key)
  307. local occ_distr = {}
  308. for _,cl in pairs({'ham', 'spam', 'total'}) do
  309. local occur_key = pattern_sha1 .. '_occurrence_' .. cl
  310. if cursor ~= 0 then
  311. local n
  312. for i,v in ipairs(redis.call('HGETALL', occur_key)) do
  313. if i % 2 == 1 then
  314. n = tonumber(v)
  315. else
  316. occur[cl][n] = occur[cl][n] and occur[cl][n] + v or v
  317. end
  318. end
  319. local str = ''
  320. if occur[cl][0] ~= nil then
  321. str = '0:' .. occur[cl][0] .. ','
  322. end
  323. for k,v in ipairs(occur[cl]) do
  324. if k == 20 then k = '>19' end
  325. str = str .. k .. ':' .. v .. ','
  326. end
  327. table.insert(occ_distr, str)
  328. else
  329. redis.call('DEL', occur_key)
  330. end
  331. if next(occur[cl]) ~= nil then
  332. redis.call('HMSET', occur_key, unpack_function(hash2list(occur[cl])))
  333. end
  334. end
  335. return {
  336. next_cursor, step,
  337. {nelts, extended, discriminated, mean, stddev, common, significant, infrequent,
  338. infrequent_ttls_set, insignificant, insignificant_ttls_set},
  339. {c.nelts, c.extended, c.discriminated, c.sum, c.sum_squares, c.common,
  340. c.significant, c.infrequent, c.infrequent_ttls_set, c.insignificant, c.insignificant_ttls_set},
  341. occ_distr
  342. }
  343. ]]
  344. local function expire_step(cls, ev_base, worker)
  345. local function redis_step_cb(err, args)
  346. if err then
  347. logger.errx(rspamd_config, 'cannot perform expiry step: %s', err)
  348. elseif type(args) == 'table' then
  349. local cur = tonumber(args[1])
  350. local step = args[2]
  351. local data = args[3]
  352. local c_data = args[4]
  353. local occ_distr = args[5]
  354. local function log_stat(cycle)
  355. local infrequent_action = (cls.expiry < 0) and 'made persistent' or 'ttls set'
  356. local c_mean, c_stddev = 0, 0
  357. if cycle and c_data[1] ~= 0 then
  358. c_mean = c_data[4] / c_data[1]
  359. c_stddev = math.floor(.5 + math.sqrt(c_data[5] / c_data[1] - c_mean * c_mean))
  360. c_mean = math.floor(.5 + c_mean)
  361. end
  362. local d = cycle and {
  363. 'cycle in ' .. step .. ' steps', c_data[1],
  364. c_data[7], c_data[2], 'made persistent',
  365. c_data[10], c_data[11], infrequent_action,
  366. c_data[6], c_data[3],
  367. c_data[8], c_data[9], infrequent_action,
  368. c_mean,
  369. c_stddev
  370. } or {
  371. 'step ' .. step, data[1],
  372. data[7], data[2], 'made persistent',
  373. data[10], data[11], infrequent_action,
  374. data[6], data[3],
  375. data[8], data[9], infrequent_action,
  376. data[4],
  377. data[5]
  378. }
  379. logger.infox(rspamd_config,
  380. 'finished expiry %s: %s items checked, %s significant (%s %s), ' ..
  381. '%s insignificant (%s %s), %s common (%s discriminated), ' ..
  382. '%s infrequent (%s %s), %s mean, %s std',
  383. lutil.unpack(d))
  384. if cycle then
  385. for i,cl in ipairs({'in ham', 'in spam', 'total'}) do
  386. logger.infox(rspamd_config, 'tokens occurrences, %s: {%s}', cl, occ_distr[i])
  387. end
  388. end
  389. end
  390. log_stat(false)
  391. if cur == 0 then
  392. log_stat(true)
  393. end
  394. elseif type(args) == 'string' then
  395. logger.infox(rspamd_config, 'skip expiry step: %s', args)
  396. end
  397. end
  398. lredis.exec_redis_script(cls.script,
  399. {ev_base = ev_base, is_write = true},
  400. redis_step_cb,
  401. {'RS*_*', cls.expiry}
  402. )
  403. end
  404. rspamd_config:add_on_load(function (_, ev_base, worker)
  405. -- Exit unless we're the first 'controller' worker
  406. if not worker:is_primary_controller() then return end
  407. local unique_redis_params = {}
  408. -- Push redis script to all unique redis servers
  409. for _,cls in ipairs(settings.classifiers) do
  410. if not unique_redis_params[cls.redis_params.hash] then
  411. unique_redis_params[cls.redis_params.hash] = cls.redis_params
  412. end
  413. end
  414. for h,rp in pairs(unique_redis_params) do
  415. local script_id = lredis.add_redis_script(lutil.template(expiry_script,
  416. template), rp)
  417. for _,cls in ipairs(settings.classifiers) do
  418. if cls.redis_params.hash == h then
  419. cls.script = script_id
  420. end
  421. end
  422. end
  423. -- Expire tokens at regular intervals
  424. for _,cls in ipairs(settings.classifiers) do
  425. rspamd_config:add_periodic(ev_base,
  426. settings['interval'],
  427. function ()
  428. expire_step(cls, ev_base, worker)
  429. return true
  430. end, true)
  431. end
  432. end)