You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

lua_stat.lua 14KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514
  1. --[[
  2. Copyright (c) 2018, Vsevolod Stakhov <vsevolod@highsecure.ru>
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ]]--
  13. local logger = require "rspamd_logger"
  14. local sqlite3 = require "rspamd_sqlite3"
  15. local util = require "rspamd_util"
  16. local lua_redis = require "lua_redis"
  17. local exports = {}
  18. local N = "stat_tools" -- luacheck: ignore (maybe unused)
  -- Performs synchronous conversion of redis schema
  --
  -- Old schema is the following one:
  -- Keys are named <symbol>[<user>]
  -- Elements are placed within hash:
  -- BAYES_SPAM -> {<id1>: <num_hits>, <id2>: <num_hits> ...}
  -- In new schema it is changed to a more extensible schema:
  -- Keys are named RS[<user>]_<id> -> {'H': <ham_hits>, 'S': <spam_hits>}
  -- So we can expire individual records, measure most popular elements by zranges,
  -- add new fields, such as tokens etc
  --
  -- parameters:
  -- redis_params - how do we connect to a redis server
  -- symbol_spam - name for symbol representing spam, e.g. BAYES_SPAM
  -- symbol_ham - name for symbol representing ham, e.g. BAYES_HAM
  -- expire - optional expiration set on each converted key (nil or 0 disables it)
  -- returns true if all conversion steps have succeeded and false otherwise
  local function convert_bayes_schema(redis_params, symbol_spam, symbol_ham, expire)
    local res, conn = lua_redis.redis_connect_sync(redis_params, true)

    if not res then
      logger.errx("cannot connect to redis server")
      return false
    end

    -- KEYS[1]: key to check (e.g. 'BAYES_SPAM')
    -- KEYS[2]: hash key ('S' or 'H')
    -- KEYS[3]: expire
    --
    -- HSCAN must be repeated until the *returned* cursor is "0" and every
    -- reply (including the one that carries cursor "0") contains elements
    -- that have to be processed, hence the repeat/until loop.
    local tokens_script = [[
  redis.replicate_commands()
  local keys = redis.call('SMEMBERS', KEYS[1]..'_keys')
  local nconverted = 0
  local expire = tonumber(KEYS[3]) or 0
  for _,k in ipairs(keys) do
    local neutral_prefix = string.gsub(k, KEYS[1], 'RS')
    local cursor = "0"
    repeat
      local reply = redis.call('HSCAN', k, cursor)
      cursor = reply[1]
      local real_key
      for i,v in ipairs(reply[2]) do
        if i % 2 ~= 0 then
          real_key = v
        else
          local nkey = string.format('%s_%s', neutral_prefix, real_key)
          redis.call('HSET', nkey, KEYS[2], v)
          if expire > 0 then
            redis.call('EXPIRE', nkey, expire)
          end
          nconverted = nconverted + 1
        end
      end
    until cursor == "0"
  end
  return nconverted
  ]]

    -- We can now convert metadata: set + learned + version
    -- KEYS[1]: key to check (e.g. 'BAYES_SPAM')
    -- KEYS[2]: learn key (e.g. 'learns_spam' or 'learns_ham')
    --
    -- A missing 'learns' field is returned as `false` inside a redis script
    -- and must not be passed to HSET.
    local metadata_script = [[
  local keys = redis.call('SMEMBERS', KEYS[1]..'_keys')
  for _,k in ipairs(keys) do
    local learns = redis.call('HGET', k, 'learns')
    local neutral_prefix = string.gsub(k, KEYS[1], 'RS')
    if learns then
      redis.call('HSET', neutral_prefix, KEYS[2], learns)
    end
    redis.call('SADD', KEYS[1]..'_keys', neutral_prefix)
    redis.call('SREM', KEYS[1]..'_keys', k)
    redis.call('DEL', k)
    redis.call('SET', KEYS[1]..'_version', '2')
  end
  ]]

    -- Converts tokens of a single symbol, returns false on failure
    local function convert_tokens(symbol, hash_key)
      -- `expire or 0` prevents sending the literal string "nil" to the script
      conn:add_cmd('EVAL', {tokens_script, '3', symbol, hash_key,
          tostring(expire or 0)})
      local ret, reply = conn:exec()

      if not ret then
        logger.errx('error converting symbol %s: %s', symbol, reply)
        return false
      end

      logger.messagex('converted %s elements from symbol %s', reply, symbol)
      return true
    end

    -- Converts metadata of a single symbol, returns false on failure
    local function convert_metadata(symbol, learns_key)
      conn:add_cmd('EVAL', {metadata_script, '2', symbol, learns_key})

      if not conn:exec() then
        logger.errx('error converting metadata for symbol %s', symbol)
        return false
      end

      return true
    end

    -- The order is significant: metadata conversion removes the old keys,
    -- so all tokens must be converted before it
    if not convert_tokens(symbol_spam, 'S') then
      return false
    end
    if not convert_tokens(symbol_ham, 'H') then
      return false
    end
    if not convert_metadata(symbol_spam, 'learns_spam') then
      return false
    end
    if not convert_metadata(symbol_ham, 'learns_ham') then
      return false
    end

    return true
  end
  111. exports.convert_bayes_schema = convert_bayes_schema
  -- Converts sqlite3 based statistics (tokens, learns counters and, optionally,
  -- the learn cache) to redis using the `RS<user>_<token>` schema.
  -- It now accepts both ham and spam databases
  -- parameters:
  -- redis_params - how do we connect to a redis server
  -- sqlite_db_spam - name for sqlite database with spam tokens
  -- sqlite_db_ham - name for sqlite database with ham tokens
  -- symbol_spam - name for symbol representing spam, e.g. BAYES_SPAM
  -- symbol_ham - name for symbol representing ham, e.g. BAYES_HAM
  -- learn_cache_db - optional name for sqlite database with the learn cache
  -- expire - optional expiration set on each token key (nil or 0 disables it)
  -- reset_previous - if true, then the old database is flushed (slow)
  -- returns false if databases cannot be opened or tokens cannot be sent,
  -- true otherwise (learn cache errors are logged but are not fatal)
  local function convert_sqlite_to_redis(redis_params,
      sqlite_db_spam, sqlite_db_ham, symbol_spam, symbol_ham,
      learn_cache_db, expire, reset_previous)
    local nusers = 0
    local lim = 1000 -- Update each 1000 tokens
    local users_map = {}
    local converted = 0

    local db_spam = sqlite3.open(sqlite_db_spam)
    if not db_spam then
      logger.errx('Cannot open source db: %s', sqlite_db_spam)
      return false
    end

    local db_ham = sqlite3.open(sqlite_db_ham)
    if not db_ham then
      logger.errx('Cannot open source db: %s', sqlite_db_ham)
      return false
    end

    local res, conn = lua_redis.redis_connect_sync(redis_params, true)
    if not res then
      logger.errx("cannot connect to redis server")
      return false
    end

    if reset_previous then
      -- Do a more complicated cleanup
      -- execute a lua script that cleans up data
      -- Keys are removed one by one: `redis.call` accepts only strings and
      -- numbers as arguments, so a table of keys cannot be passed to DEL
      local script = [[
  local members = redis.call('SMEMBERS', KEYS[1]..'_keys')
  for _,prefix in ipairs(members) do
    local keys = redis.call('KEYS', prefix..'*')
    for _,key in ipairs(keys) do
      redis.call('DEL', key)
    end
  end
  ]]
      -- Common keys
      for _,sym in ipairs({symbol_spam, symbol_ham}) do
        logger.messagex('Cleaning up old data for %s', sym)
        conn:add_cmd('EVAL', {script, '1', sym})
        conn:exec()
        conn:add_cmd('DEL', {sym .. "_version"})
        conn:add_cmd('DEL', {sym .. "_keys"})
        conn:exec()
      end

      if learn_cache_db then
        -- Cleanup learned_cache
        logger.messagex('Cleaning up old data learned cache')
        conn:add_cmd('DEL', {"learned_ids"})
        conn:exec()
      end
    end

    -- Converts a single sqlite database (spam or ham) to redis
    local function convert_db(db, is_spam)
      -- Map users and languages
      local what = 'ham'
      -- We use the new schema: RS[user]_token -> H=ham count
      -- S=spam count
      local hash_key = 'H'
      if is_spam then
        what = 'spam'
        hash_key = 'S'
      end

      local learns = {}
      db:sql('BEGIN;')

      -- Fill users mapping
      for row in db:rows('SELECT * FROM users;') do
        if row.id == '0' then
          users_map[row.id] = ''
        else
          users_map[row.id] = row.name
        end
        learns[row.id] = row.learns
        nusers = nusers + 1
      end

      -- Workaround for old databases
      for row in db:rows('SELECT * FROM languages') do
        if learns['0'] then
          learns['0'] = learns['0'] + row.learns
        else
          learns['0'] = row.learns
        end
      end

      local function send_batch(tokens, prefix)
        for _,tok in ipairs(tokens) do
          -- tok schema:
          -- tok[1] = token_id (uint64 represented as a string)
          -- tok[2] = token value (number)
          -- tok[3] = user_map[user_id] or ''
          local rkey = string.format('%s%s_%s', prefix, tok[3], tok[1])
          conn:add_cmd('HINCRBYFLOAT', {rkey, hash_key, tostring(tok[2])})

          if expire and expire ~= 0 then
            conn:add_cmd('EXPIRE', {rkey, tostring(expire)})
          end
        end

        return conn:exec()
      end

      -- Fill tokens, sending data to redis each `lim` records
      local ntokens = db:query('SELECT count(*) as c FROM tokens')['c']
      local tokens = {}
      local num = 0
      local total = 0

      for row in db:rows('SELECT token,value,user FROM tokens;') do
        local user = ''
        if row.user ~= 0 and users_map[row.user] then
          user = users_map[row.user]
        end

        table.insert(tokens, {row.token, row.value, user})
        num = num + 1
        total = total + 1

        if num >= lim then
          -- TODO: we use the default 'RS' prefix, it can be false in case of
          -- classifiers with labels
          local ret, err_str = send_batch(tokens, 'RS')
          if not ret then
            -- err_str is passed as an argument as it might be nil
            logger.errx('Cannot send tokens to the redis server: %s', err_str)
            db:sql('COMMIT;')
            return false
          end

          num = 0
          tokens = {}
        end

        io.write(string.format('Processed batch %s: %s/%s\r', what, total, ntokens))
      end

      -- Last batch
      if #tokens > 0 then
        local ret, err_str = send_batch(tokens, 'RS')
        if not ret then
          logger.errx('Cannot send tokens to the redis server: %s', err_str)
          db:sql('COMMIT;')
          return false
        end

        io.write(string.format('Processed batch %s: %s/%s\r', what, total, ntokens))
      end

      io.write('\n')
      converted = converted + total

      -- Close DB
      db:sql('COMMIT;')

      local symbol = symbol_ham
      local learns_elt = "learns_ham"
      if is_spam then
        symbol = symbol_spam
        learns_elt = "learns_spam"
      end

      for id,learned in pairs(learns) do
        -- learns['0'] might come from the `languages` table only, in this
        -- case there is no mapping for this id: use the default (empty) user
        local user = users_map[id] or ''

        if not conn:add_cmd('HSET', {'RS' .. user, learns_elt, learned}) then
          logger.errx('Cannot update learns for user: ' .. user)
          return false
        end
        if not conn:add_cmd('SADD', {symbol .. '_keys', 'RS' .. user}) then
          logger.errx('Cannot update learns for user: ' .. user)
          return false
        end
      end

      -- Set version
      conn:add_cmd('SET', {symbol..'_version', '2'})
      return conn:exec()
    end

    logger.messagex('Convert spam tokens')
    if not convert_db(db_spam, true) then
      return false
    end

    logger.messagex('Convert ham tokens')
    if not convert_db(db_ham, false) then
      return false
    end

    if learn_cache_db then
      logger.messagex('Convert learned ids from %s', learn_cache_db)
      local db = sqlite3.open(learn_cache_db)
      local ret = true
      local total = 0

      if not db then
        logger.errx('Cannot open cache database: ' .. learn_cache_db)
        return false
      end

      db:sql('BEGIN;')

      for row in db:rows('SELECT * FROM learns;') do
        local is_spam
        local digest = tostring(util.encode_base32(row.digest))

        if row.flag == '0' then
          is_spam = '-1'
        else
          is_spam = '1'
        end

        if not conn:add_cmd('HSET', {'learned_ids', digest, is_spam}) then
          logger.errx('Cannot add hash: ' .. digest)
          ret = false
        else
          total = total + 1
        end
      end

      db:sql('COMMIT;')

      -- Nothing is queued for an empty cache, so there is nothing to execute
      if ret and total > 0 then
        local exec_ok, exec_err = conn:exec()
        if not exec_ok then
          logger.errx('Cannot send learned ids to the redis server: %s', exec_err)
          ret = false
        end
      end

      if ret then
        logger.messagex('Converted %s cached items from sqlite3 learned cache to redis',
            total)
      else
        logger.errx('Error occurred during sending data to redis')
      end
    end

    logger.messagex('Migrated %s tokens for %s users for symbols (%s, %s)',
        converted, nusers, symbol_spam, symbol_ham)
    return true
  end
  326. exports.convert_sqlite_to_redis = convert_sqlite_to_redis
  -- Loads sqlite3 based classifiers and output data in form of array of objects:
  -- [
  -- {
  -- symbol_spam = XXX
  -- symbol_ham = YYY
  -- db_spam = XXX.sqlite
  -- db_ham = YYY.sqlite
  -- learn_cache = ZZZ.sqlite
  -- per_user = true/false
  -- label = str
  -- }
  -- ]
  -- parameters:
  -- cfg - configuration table, its `classifier` field is inspected
  -- Only classifiers with `backend = 'sqlite3'` that define both spam and ham
  -- statfiles (with paths) are returned.
  local function load_sqlite_config(cfg)
    local result = {}

    local function parse_classifier(cls)
      local tbl = {}
      local cache = cls.cache

      if cache and cache.type == 'sqlite3' and (cache.file or cache.path) then
        tbl.learn_cache = (cache.file or cache.path)
      end

      if cls.per_user then
        tbl.per_user = cls.per_user
      end

      if cls.label then
        tbl.label = cls.label
      end

      -- A classifier without statfiles is just skipped (nothing to convert)
      for _,stf in ipairs(cls.statfile or {}) do
        local path = (stf.file or stf.path or stf.db or stf.dbname)
        local symbol = stf.symbol or 'undefined'

        if not path then
          logger.errx('no path defined for statfile %s', symbol)
        else
          -- An explicit `spam` flag (including `spam = false`) always wins,
          -- the symbol name is used as a heuristic only when it is absent
          local spam = stf.spam
          if spam == nil then
            spam = string.match(symbol:upper(), 'SPAM') ~= nil
          end

          if spam then
            tbl.symbol_spam = symbol
            tbl.db_spam = path
          else
            tbl.symbol_ham = symbol
            tbl.db_ham = path
          end
        end
      end

      if tbl.symbol_spam and tbl.symbol_ham and tbl.db_ham and tbl.db_spam then
        table.insert(result, tbl)
      end
    end

    -- Parses a classifier only if it is backed by sqlite3
    local function parse_if_sqlite(cls)
      if cls.backend and cls.backend == 'sqlite3' then
        parse_classifier(cls)
      end
    end

    local classifier = cfg.classifier

    if classifier then
      if classifier[1] then
        -- Array of classifiers, each of them can be wrapped in `bayes`
        for _,cls in ipairs(classifier) do
          if cls.bayes then cls = cls.bayes end
          parse_if_sqlite(cls)
        end
      elseif classifier.bayes then
        classifier = classifier.bayes

        if classifier[1] then
          -- classifier "bayes" { ... } defined multiple times
          for _,cls in ipairs(classifier) do
            parse_if_sqlite(cls)
          end
        else
          parse_if_sqlite(classifier)
        end
      end
    end

    return result
  end
  413. exports.load_sqlite_config = load_sqlite_config
  -- Suggests a Redis based classifier configuration that corresponds to an
  -- existing sqlite classifier.
  -- parameters:
  -- sqlite_classifier - table with `symbol_spam` and `symbol_ham` fields
  -- expire - optional value stored as `expire` in the resulting classifier
  -- returns a table of the form {classifier = {bayes = {...}}}
  local function redis_classifier_from_sqlite(sqlite_classifier, expire)
    -- Statfiles are keyed by their symbols
    local statfiles = {}
    statfiles[sqlite_classifier.symbol_spam] = { spam = true }
    statfiles[sqlite_classifier.symbol_ham] = { spam = false }

    local bayes = {
      new_schema = true,
      backend = 'redis',
      cache = { backend = 'redis' },
      statfile = statfiles,
    }

    -- `expire` is added only when a truthy value is given
    if expire then
      bayes.expire = expire
    end

    return { classifier = { bayes = bayes } }
  end
  437. exports.redis_classifier_from_sqlite = redis_classifier_from_sqlite
  438. return exports