You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

lua_stat.lua 14KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512
  1. --[[
  2. Copyright (c) 2018, Vsevolod Stakhov <vsevolod@highsecure.ru>
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ]]--
  13. local logger = require "rspamd_logger"
  14. local sqlite3 = require "rspamd_sqlite3"
  15. local util = require "rspamd_util"
  16. local lua_redis = require "lua_redis"
  17. local exports = {}
  18. local N = "stat_tools" -- luacheck: ignore (maybe unused)
  19. -- Performs synchronous conversion of redis schema
  20. local function convert_bayes_schema(redis_params, symbol_spam, symbol_ham, expire)
  21. -- Old schema is the following one:
  22. -- Keys are named <symbol>[<user>]
  23. -- Elements are placed within hash:
  24. -- BAYES_SPAM -> {<id1>: <num_hits>, <id2>: <num_hits> ...}
  25. -- In new schema it is changed to a more extensible schema:
  26. -- Keys are named RS[<user>]_<id> -> {'H': <ham_hits>, 'S': <spam_hits>}
  27. -- So we can expire individual records, measure most popular elements by zranges,
  28. -- add new fields, such as tokens etc
  29. local res,conn = lua_redis.redis_connect_sync(redis_params, true)
  30. if not res then
  31. logger.errx("cannot connect to redis server")
  32. return false
  33. end
  34. -- KEYS[1]: key to check (e.g. 'BAYES_SPAM')
  35. -- KEYS[2]: hash key ('S' or 'H')
  36. -- KEYS[3]: expire
  37. local lua_script = [[
  38. local keys = redis.call('SMEMBERS', KEYS[1]..'_keys')
  39. local nconverted = 0
  40. for _,k in ipairs(keys) do
  41. local elts = redis.call('HGETALL', k)
  42. local neutral_prefix = string.gsub(k, KEYS[1], 'RS')
  43. local real_key
  44. for i,v in ipairs(elts) do
  45. if i % 2 ~= 0 then
  46. real_key = v
  47. else
  48. local nkey = string.format('%s_%s', neutral_prefix, real_key)
  49. redis.call('HSET', nkey, KEYS[2], v)
  50. if KEYS[3] and tonumber(KEYS[3]) > 0 then
  51. redis.call('EXPIRE', nkey, KEYS[3])
  52. end
  53. nconverted = nconverted + 1
  54. end
  55. end
  56. end
  57. return nconverted
  58. ]]
  59. conn:add_cmd('EVAL', {lua_script, '3', symbol_spam, 'S', tostring(expire)})
  60. local ret
  61. ret, res = conn:exec()
  62. if not ret then
  63. logger.errx('error converting symbol %s: %s', symbol_spam, res)
  64. return false
  65. else
  66. logger.messagex('converted %s elements from symbol %s', res, symbol_spam)
  67. end
  68. conn:add_cmd('EVAL', {lua_script, '3', symbol_ham, 'H', tostring(expire)})
  69. ret, res = conn:exec()
  70. if not ret then
  71. logger.errx('error converting symbol %s: %s', symbol_ham, res)
  72. return false
  73. else
  74. logger.messagex('converted %s elements from symbol %s', res, symbol_ham)
  75. end
  76. -- We can now convert metadata: set + learned + version
  77. -- KEYS[1]: key to check (e.g. 'BAYES_SPAM')
  78. -- KEYS[2]: learn key (e.g. 'learns_spam' or 'learns_ham')
  79. lua_script = [[
  80. local keys = redis.call('SMEMBERS', KEYS[1]..'_keys')
  81. for _,k in ipairs(keys) do
  82. local learns = redis.call('HGET', k, 'learns')
  83. local neutral_prefix = string.gsub(k, KEYS[1], 'RS')
  84. redis.call('HSET', neutral_prefix, KEYS[2], learns)
  85. redis.call('SADD', KEYS[1]..'_keys', neutral_prefix)
  86. redis.call('SREM', KEYS[1]..'_keys', k)
  87. redis.call('DEL', k)
  88. redis.call('SET', KEYS[1]..'_version', '2')
  89. end
  90. ]]
  91. conn:add_cmd('EVAL', {lua_script, '2', symbol_spam, 'learns_spam'})
  92. ret = conn:exec()
  93. if not ret then
  94. logger.errx('error converting metadata for symbol %s', symbol_spam)
  95. return false
  96. end
  97. conn:add_cmd('EVAL', {lua_script, '2', symbol_ham, 'learns_ham'})
  98. ret = conn:exec()
  99. if not ret then
  100. logger.errx('error converting metadata for symbol %s', symbol_ham)
  101. return false
  102. end
  103. return true
  104. end
  105. exports.convert_bayes_schema = convert_bayes_schema
  106. -- It now accepts both ham and spam databases
  107. -- parameters:
  108. -- redis_params - how do we connect to a redis server
  109. -- sqlite_db_spam - name for sqlite database with spam tokens
  110. -- sqlite_db_ham - name for sqlite database with ham tokens
  111. -- symbol_ham - name for symbol representing spam, e.g. BAYES_SPAM
  112. -- symbol_spam - name for symbol representing ham, e.g. BAYES_HAM
  113. -- learn_cache_spam - name for sqlite database with spam learn cache
  114. -- learn_cache_ham - name for sqlite database with ham learn cache
  115. -- reset_previous - if true, then the old database is flushed (slow)
  116. local function convert_sqlite_to_redis(redis_params,
  117. sqlite_db_spam, sqlite_db_ham, symbol_spam, symbol_ham,
  118. learn_cache_db, expire, reset_previous)
  119. local nusers = 0
  120. local lim = 1000 -- Update each 1000 tokens
  121. local users_map = {}
  122. local converted = 0
  123. local db_spam = sqlite3.open(sqlite_db_spam)
  124. if not db_spam then
  125. logger.errx('Cannot open source db: %s', sqlite_db_spam)
  126. return false
  127. end
  128. local db_ham = sqlite3.open(sqlite_db_ham)
  129. if not db_ham then
  130. logger.errx('Cannot open source db: %s', sqlite_db_ham)
  131. return false
  132. end
  133. local res,conn = lua_redis.redis_connect_sync(redis_params, true)
  134. if not res then
  135. logger.errx("cannot connect to redis server")
  136. return false
  137. end
  138. if reset_previous then
  139. -- Do a more complicated cleanup
  140. -- execute a lua script that cleans up data
  141. local script = [[
  142. local members = redis.call('SMEMBERS', KEYS[1]..'_keys')
  143. for _,prefix in ipairs(members) do
  144. local keys = redis.call('KEYS', prefix..'*')
  145. redis.call('DEL', keys)
  146. end
  147. ]]
  148. -- Common keys
  149. for _,sym in ipairs({symbol_spam, symbol_ham}) do
  150. logger.messagex('Cleaning up old data for %s', sym)
  151. conn:add_cmd('EVAL', {script, '1', sym})
  152. conn:exec()
  153. conn:add_cmd('DEL', {sym .. "_version"})
  154. conn:add_cmd('DEL', {sym .. "_keys"})
  155. conn:exec()
  156. end
  157. if learn_cache_db then
  158. -- Cleanup learned_cache
  159. logger.messagex('Cleaning up old data learned cache')
  160. conn:add_cmd('DEL', {"learned_ids"})
  161. conn:exec()
  162. end
  163. end
  164. local function convert_db(db, is_spam)
  165. -- Map users and languages
  166. local what = 'ham'
  167. if is_spam then
  168. what = 'spam'
  169. end
  170. local learns = {}
  171. db:sql('BEGIN;')
  172. -- Fill users mapping
  173. for row in db:rows('SELECT * FROM users;') do
  174. if row.id == '0' then
  175. users_map[row.id] = ''
  176. else
  177. users_map[row.id] = row.name
  178. end
  179. learns[row.id] = row.learns
  180. nusers = nusers + 1
  181. end
  182. -- Workaround for old databases
  183. for row in db:rows('SELECT * FROM languages') do
  184. if learns['0'] then
  185. learns['0'] = learns['0'] + row.learns
  186. else
  187. learns['0'] = row.learns
  188. end
  189. end
  190. local function send_batch(tokens, prefix)
  191. -- We use the new schema: RS[user]_token -> H=ham count
  192. -- S=spam count
  193. local hash_key = 'H'
  194. if is_spam then
  195. hash_key = 'S'
  196. end
  197. for _,tok in ipairs(tokens) do
  198. -- tok schema:
  199. -- tok[1] = token_id (uint64 represented as a string)
  200. -- tok[2] = token value (number)
  201. -- tok[3] = user_map[user_id] or ''
  202. local rkey = string.format('%s%s_%s', prefix, tok[3], tok[1])
  203. conn:add_cmd('HINCRBYFLOAT', {rkey, hash_key, tostring(tok[2])})
  204. if expire and expire ~= 0 then
  205. conn:add_cmd('EXPIRE', {rkey, tostring(expire)})
  206. end
  207. end
  208. return conn:exec()
  209. end
  210. -- Fill tokens, sending data to redis each `lim` records
  211. local ntokens = db:query('SELECT count(*) as c FROM tokens')['c']
  212. local tokens = {}
  213. local num = 0
  214. local total = 0
  215. for row in db:rows('SELECT token,value,user FROM tokens;') do
  216. local user = ''
  217. if row.user ~= 0 and users_map[row.user] then
  218. user = users_map[row.user]
  219. end
  220. table.insert(tokens, {row.token, row.value, user})
  221. num = num + 1
  222. total = total + 1
  223. if num > lim then
  224. -- TODO: we use the default 'RS' prefix, it can be false in case of
  225. -- classifiers with labels
  226. local ret,err_str = send_batch(tokens, 'RS')
  227. if not ret then
  228. logger.errx('Cannot send tokens to the redis server: ' .. err_str)
  229. db:sql('COMMIT;')
  230. return false
  231. end
  232. num = 0
  233. tokens = {}
  234. end
  235. io.write(string.format('Processed batch %s: %s/%s\r', what, total, ntokens))
  236. end
  237. -- Last batch
  238. if #tokens > 0 then
  239. local ret,err_str = send_batch(tokens, 'RS')
  240. if not ret then
  241. logger.errx('Cannot send tokens to the redis server: ' .. err_str)
  242. db:sql('COMMIT;')
  243. return false
  244. end
  245. io.write(string.format('Processed batch %s: %s/%s\r', what, total, ntokens))
  246. end
  247. io.write('\n')
  248. converted = converted + total
  249. -- Close DB
  250. db:sql('COMMIT;')
  251. local symbol = symbol_ham
  252. local learns_elt = "learns_ham"
  253. if is_spam then
  254. symbol = symbol_spam
  255. learns_elt = "learns_spam"
  256. end
  257. for id,learned in pairs(learns) do
  258. local user = users_map[id]
  259. if not conn:add_cmd('HSET', {'RS' .. user, learns_elt, learned}) then
  260. logger.errx('Cannot update learns for user: ' .. user)
  261. return false
  262. end
  263. if not conn:add_cmd('SADD', {symbol .. '_keys', 'RS' .. user}) then
  264. logger.errx('Cannot update learns for user: ' .. user)
  265. return false
  266. end
  267. end
  268. -- Set version
  269. conn:add_cmd('SET', {symbol..'_version', '2'})
  270. return conn:exec()
  271. end
  272. logger.messagex('Convert spam tokens')
  273. if not convert_db(db_spam, true) then
  274. return false
  275. end
  276. logger.messagex('Convert ham tokens')
  277. if not convert_db(db_ham, false) then
  278. return false
  279. end
  280. if learn_cache_db then
  281. logger.messagex('Convert learned ids from %s', learn_cache_db)
  282. local db = sqlite3.open(learn_cache_db)
  283. local ret = true
  284. local total = 0
  285. if not db then
  286. logger.errx('Cannot open cache database: ' .. learn_cache_db)
  287. return false
  288. end
  289. db:sql('BEGIN;')
  290. for row in db:rows('SELECT * FROM learns;') do
  291. local is_spam
  292. local digest = tostring(util.encode_base32(row.digest))
  293. if row.flag == '0' then
  294. is_spam = '-1'
  295. else
  296. is_spam = '1'
  297. end
  298. if not conn:add_cmd('HSET', {'learned_ids', digest, is_spam}) then
  299. logger.errx('Cannot add hash: ' .. digest)
  300. ret = false
  301. else
  302. total = total + 1
  303. end
  304. end
  305. db:sql('COMMIT;')
  306. if ret then
  307. conn:exec()
  308. end
  309. if ret then
  310. logger.messagex('Converted %s cached items from sqlite3 learned cache to redis',
  311. total)
  312. else
  313. logger.errx('Error occurred during sending data to redis')
  314. end
  315. end
  316. logger.messagex('Migrated %s tokens for %s users for symbols (%s, %s)',
  317. converted, nusers, symbol_spam, symbol_ham)
  318. return true
  319. end
  320. exports.convert_sqlite_to_redis = convert_sqlite_to_redis
  321. -- Loads sqlite3 based classifiers and output data in form of array of objects:
  322. -- [
  323. -- {
  324. -- symbol_spam = XXX
  325. -- symbol_ham = YYY
  326. -- db_spam = XXX.sqlite
  327. -- db_ham = YYY.sqlite
  328. -- learn_cahe = ZZZ.sqlite
  329. -- per_user = true/false
  330. -- label = str
  331. -- }
  332. -- ]
  333. local function load_sqlite_config(cfg)
  334. local result = {}
  335. local function parse_classifier(cls)
  336. local tbl = {}
  337. if cls.cache then
  338. local cache = cls.cache
  339. if cache.type == 'sqlite3' and (cache.file or cache.path) then
  340. tbl.learn_cache = (cache.file or cache.path)
  341. end
  342. end
  343. if cls.per_user then
  344. tbl.per_user = cls.per_user
  345. end
  346. if cls.label then
  347. tbl.label = cls.label
  348. end
  349. local statfiles = cls.statfile
  350. for _,stf in ipairs(statfiles) do
  351. local path = (stf.file or stf.path or stf.db or stf.dbname)
  352. local symbol = stf.symbol or 'undefined'
  353. if not path then
  354. logger.errx('no path defined for statfile %s', symbol)
  355. else
  356. local spam
  357. if stf.spam then
  358. spam = stf.spam
  359. else
  360. if string.match(symbol:upper(), 'SPAM') then
  361. spam = true
  362. else
  363. spam = false
  364. end
  365. end
  366. if spam then
  367. tbl.symbol_spam = symbol
  368. tbl.db_spam = path
  369. else
  370. tbl.symbol_ham = symbol
  371. tbl.db_ham = path
  372. end
  373. end
  374. end
  375. if tbl.symbol_spam and tbl.symbol_ham and tbl.db_ham and tbl.db_spam then
  376. table.insert(result, tbl)
  377. end
  378. end
  379. local classifier = cfg.classifier
  380. if classifier then
  381. if classifier[1] then
  382. for _,cls in ipairs(classifier) do
  383. if cls.bayes then cls = cls.bayes end
  384. if cls.backend and cls.backend == 'sqlite3' then
  385. parse_classifier(cls)
  386. end
  387. end
  388. else
  389. if classifier.bayes then
  390. classifier = classifier.bayes
  391. if classifier[1] then
  392. for _,cls in ipairs(classifier) do
  393. if cls.backend and cls.backend == 'sqlite3' then
  394. parse_classifier(cls)
  395. end
  396. end
  397. else
  398. if classifier.backend and classifier.backend == 'sqlite3' then
  399. parse_classifier(classifier)
  400. end
  401. end
  402. end
  403. end
  404. end
  405. return result
  406. end
  407. exports.load_sqlite_config = load_sqlite_config
  408. -- A helper method that suggests a user how to configure Redis based
  409. -- classifier based on the existing sqlite classifier
  410. local function redis_classifier_from_sqlite(sqlite_classifier, expire)
  411. local result = {
  412. new_schema = true,
  413. backend = 'redis',
  414. cache = {
  415. backend = 'redis'
  416. },
  417. statfile = {
  418. [sqlite_classifier.symbol_spam] = {
  419. spam = true
  420. },
  421. [sqlite_classifier.symbol_ham] = {
  422. spam = false
  423. }
  424. }
  425. }
  426. if expire then
  427. result.expire = expire
  428. end
  429. return {classifier = {bayes = result}}
  430. end
  431. exports.redis_classifier_from_sqlite = redis_classifier_from_sqlite
  432. return exports