You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

lua_stat.lua 22KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869
  1. --[[
  2. Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ]]--
  13. --[[[
  14. -- @module lua_stat
  15. -- This module contains helper functions for supporting statistics
  16. --]]
  17. local logger = require "rspamd_logger"
  18. local sqlite3 = require "rspamd_sqlite3"
  19. local util = require "rspamd_util"
  20. local lua_redis = require "lua_redis"
  21. local lua_util = require "lua_util"
  22. local exports = {}
  23. local N = "stat_tools" -- luacheck: ignore (maybe unused)
  --- Performs a synchronous conversion of the Bayes redis schema.
  --
  -- Old schema:
  --   keys are named <symbol>[<user>] and the elements are placed within a hash:
  --   BAYES_SPAM -> {<id1>: <num_hits>, <id2>: <num_hits> ...}
  -- New (more extensible) schema:
  --   keys are named RS[<user>]_<id> -> {'H': <ham_hits>, 'S': <spam_hits>}
  --   so we can expire individual records, measure most popular elements by
  --   zranges, add new fields, such as tokens etc.
  --
  -- @param redis_params how do we connect to a redis server
  -- @param symbol_spam name of the spam symbol (e.g. 'BAYES_SPAM')
  -- @param symbol_ham name of the ham symbol (e.g. 'BAYES_HAM')
  -- @param expire expiration passed to EXPIRE for every converted record;
  --        no EXPIRE is issued when it is nil, non-numeric or not positive
  -- @return true when all conversion scripts succeeded, false otherwise
  --         (the failure is logged)
  local function convert_bayes_schema(redis_params, symbol_spam, symbol_ham, expire)
    local res, conn = lua_redis.redis_connect_sync(redis_params, true)

    if not res then
      logger.errx("cannot connect to redis server")
      return false
    end

    -- KEYS[1]: key to check (e.g. 'BAYES_SPAM')
    -- KEYS[2]: hash key ('S' or 'H')
    -- KEYS[3]: expire
    --
    -- The HSCAN loop is a repeat/until on purpose: the reply that carries
    -- cursor "0" still contains elements, so every reply must be processed
    -- before the cursor is tested (small hashes are returned in one reply).
    local lua_script = [[
  redis.replicate_commands()
  local keys = redis.call('SMEMBERS', KEYS[1]..'_keys')
  local nconverted = 0
  local expire = tonumber(KEYS[3])
  for _,k in ipairs(keys) do
    local neutral_prefix = string.gsub(k, KEYS[1], 'RS')
    local cursor = "0"
    repeat
      local reply = redis.call('HSCAN', k, cursor)
      cursor = reply[1]
      local real_key
      for i,v in ipairs(reply[2]) do
        if i % 2 ~= 0 then
          real_key = v
        else
          local nkey = string.format('%s_%s', neutral_prefix, real_key)
          redis.call('HSET', nkey, KEYS[2], v)
          if expire and expire > 0 then
            redis.call('EXPIRE', nkey, KEYS[3])
          end
          nconverted = nconverted + 1
        end
      end
    until cursor == "0"
  end
  return nconverted
  ]]

    conn:add_cmd('EVAL', { lua_script, '3', symbol_spam, 'S', tostring(expire) })
    local ret
    ret, res = conn:exec()

    if not ret then
      logger.errx('error converting symbol %s: %s', symbol_spam, res)
      return false
    else
      logger.messagex('converted %s elements from symbol %s', res, symbol_spam)
    end

    conn:add_cmd('EVAL', { lua_script, '3', symbol_ham, 'H', tostring(expire) })
    ret, res = conn:exec()

    if not ret then
      logger.errx('error converting symbol %s: %s', symbol_ham, res)
      return false
    else
      logger.messagex('converted %s elements from symbol %s', res, symbol_ham)
    end

    -- We can now convert metadata: set + learned + version
    -- KEYS[1]: key to check (e.g. 'BAYES_SPAM')
    -- KEYS[2]: learn key (e.g. 'learns_spam' or 'learns_ham')
    lua_script = [[
  local keys = redis.call('SMEMBERS', KEYS[1]..'_keys')
  for _,k in ipairs(keys) do
    local learns = redis.call('HGET', k, 'learns') or 0
    local neutral_prefix = string.gsub(k, KEYS[1], 'RS')
    redis.call('HSET', neutral_prefix, KEYS[2], learns)
    redis.call('SADD', KEYS[1]..'_keys', neutral_prefix)
    redis.call('SREM', KEYS[1]..'_keys', k)
    redis.call('DEL', KEYS[1])
    redis.call('SET', k ..'_version', '2')
  end
  ]]

    conn:add_cmd('EVAL', { lua_script, '2', symbol_spam, 'learns_spam' })
    ret, res = conn:exec()

    if not ret then
      logger.errx('error converting metadata for symbol %s: %s', symbol_spam, res)
      return false
    end

    conn:add_cmd('EVAL', { lua_script, '2', symbol_ham, 'learns_ham' })
    ret, res = conn:exec()

    if not ret then
      -- The format string now has a placeholder for the error, as for spam above
      logger.errx('error converting metadata for symbol %s: %s', symbol_ham, res)
      return false
    end

    return true
  end
  116. exports.convert_bayes_schema = convert_bayes_schema
  --- Migrates sqlite3 statistics to redis; accepts both ham and spam databases.
  -- Tokens are stored using the new schema:
  -- RS[<user>]_<token> -> { H = ham count, S = spam count }
  --
  -- @param redis_params how do we connect to a redis server
  -- @param sqlite_db_spam name for sqlite database with spam tokens
  -- @param sqlite_db_ham name for sqlite database with ham tokens
  -- @param symbol_spam name for symbol representing spam, e.g. BAYES_SPAM
  -- @param symbol_ham name for symbol representing ham, e.g. BAYES_HAM
  -- @param learn_cache_db optional name for sqlite database with the learn cache
  -- @param expire optional expiration set (via EXPIRE) on every token key;
  --        skipped when nil or 0
  -- @param reset_previous if true, then the old database is flushed (slow)
  -- @return true on success, false on failure (the reason is logged)
  local function convert_sqlite_to_redis(redis_params,
                                         sqlite_db_spam, sqlite_db_ham, symbol_spam, symbol_ham,
                                         learn_cache_db, expire, reset_previous)
    local nusers = 0
    local lim = 1000 -- Update each 1000 tokens
    local users_map = {}
    local converted = 0

    local db_spam = sqlite3.open(sqlite_db_spam)
    if not db_spam then
      logger.errx('Cannot open source db: %s', sqlite_db_spam)
      return false
    end
    local db_ham = sqlite3.open(sqlite_db_ham)
    if not db_ham then
      logger.errx('Cannot open source db: %s', sqlite_db_ham)
      return false
    end

    local res, conn = lua_redis.redis_connect_sync(redis_params, true)

    if not res then
      logger.errx("cannot connect to redis server")
      return false
    end

    if reset_previous then
      -- Do a more complicated cleanup
      -- execute a lua script that cleans up data.
      -- Keys are deleted one by one: redis.call() accepts only strings and
      -- numbers as arguments, so the table returned by KEYS cannot be passed
      -- to DEL directly (and DEL with no keys at all is an error as well).
      local script = [[
  local members = redis.call('SMEMBERS', KEYS[1]..'_keys')

  for _,prefix in ipairs(members) do
    local keys = redis.call('KEYS', prefix..'*')
    for _,key in ipairs(keys) do
      redis.call('DEL', key)
    end
  end
  ]]
      -- Common keys
      for _, sym in ipairs({ symbol_spam, symbol_ham }) do
        logger.messagex('Cleaning up old data for %s', sym)
        conn:add_cmd('EVAL', { script, '1', sym })
        conn:exec()
        conn:add_cmd('DEL', { sym .. "_version" })
        conn:add_cmd('DEL', { sym .. "_keys" })
        conn:exec()
      end

      if learn_cache_db then
        -- Cleanup learned_cache
        logger.messagex('Cleaning up old data learned cache')
        conn:add_cmd('DEL', { "learned_ids" })
        conn:exec()
      end
    end

    -- Converts a single sqlite database (spam or ham) and stores its
    -- learns counters; returns a false value on failure
    local function convert_db(db, is_spam)
      -- Map users and languages
      local what = 'ham'
      if is_spam then
        what = 'spam'
      end

      local learns = {}
      db:sql('BEGIN;')
      -- Fill users mapping
      for row in db:rows('SELECT * FROM users;') do
        if row.id == '0' then
          users_map[row.id] = ''
        else
          users_map[row.id] = row.name
        end
        learns[row.id] = row.learns
        nusers = nusers + 1
      end

      -- Workaround for old databases
      for row in db:rows('SELECT * FROM languages') do
        if learns['0'] then
          learns['0'] = learns['0'] + row.learns
        else
          learns['0'] = row.learns
        end
      end

      local function send_batch(tokens, prefix)
        -- We use the new schema: RS[user]_token -> H=ham count
        --                                          S=spam count
        local hash_key = 'H'
        if is_spam then
          hash_key = 'S'
        end
        for _, tok in ipairs(tokens) do
          -- tok schema:
          -- tok[1] = token_id (uint64 represented as a string)
          -- tok[2] = token value (number)
          -- tok[3] = user_map[user_id] or ''
          local rkey = string.format('%s%s_%s', prefix, tok[3], tok[1])
          conn:add_cmd('HINCRBYFLOAT', { rkey, hash_key, tostring(tok[2]) })

          if expire and expire ~= 0 then
            conn:add_cmd('EXPIRE', { rkey, tostring(expire) })
          end
        end

        return conn:exec()
      end

      -- Fill tokens, sending data to redis each `lim` records
      local ntokens = db:query('SELECT count(*) as c FROM tokens')['c']
      local tokens = {}
      local num = 0
      local total = 0

      for row in db:rows('SELECT token,value,user FROM tokens;') do
        local user = ''
        if row.user ~= 0 and users_map[row.user] then
          user = users_map[row.user]
        end

        table.insert(tokens, { row.token, row.value, user })
        num = num + 1
        total = total + 1
        if num > lim then
          -- TODO: we use the default 'RS' prefix, it can be false in case of
          -- classifiers with labels
          local ret, err_str = send_batch(tokens, 'RS')
          if not ret then
            -- Formatted (not concatenated), so a nil error cannot raise here
            logger.errx('Cannot send tokens to the redis server: %s', err_str)
            db:sql('COMMIT;')
            return false
          end

          num = 0
          tokens = {}
        end

        io.write(string.format('Processed batch %s: %s/%s\r', what, total, ntokens))
      end
      -- Last batch
      if #tokens > 0 then
        local ret, err_str = send_batch(tokens, 'RS')
        if not ret then
          logger.errx('Cannot send tokens to the redis server: %s', err_str)
          db:sql('COMMIT;')
          return false
        end

        io.write(string.format('Processed batch %s: %s/%s\r', what, total, ntokens))
      end
      io.write('\n')

      converted = converted + total

      -- Close DB
      db:sql('COMMIT;')
      local symbol = symbol_ham
      local learns_elt = "learns_ham"

      if is_spam then
        symbol = symbol_spam
        learns_elt = "learns_spam"
      end

      for id, learned in pairs(learns) do
        -- learns['0'] may come from the `languages` table only (old
        -- databases), in which case there is no mapped user for it:
        -- treat it as the default (empty) user
        local user = users_map[id] or ''
        if not conn:add_cmd('HSET', { 'RS' .. user, learns_elt, learned }) then
          logger.errx('Cannot update learns for user: ' .. user)
          return false
        end
        if not conn:add_cmd('SADD', { symbol .. '_keys', 'RS' .. user }) then
          logger.errx('Cannot update learns for user: ' .. user)
          return false
        end
      end
      -- Set version
      conn:add_cmd('SET', { symbol .. '_version', '2' })
      return conn:exec()
    end

    logger.messagex('Convert spam tokens')
    if not convert_db(db_spam, true) then
      return false
    end

    logger.messagex('Convert ham tokens')
    if not convert_db(db_ham, false) then
      return false
    end

    if learn_cache_db then
      logger.messagex('Convert learned ids from %s', learn_cache_db)
      local db = sqlite3.open(learn_cache_db)
      local ret = true
      local total = 0

      if not db then
        logger.errx('Cannot open cache database: ' .. learn_cache_db)
        return false
      end

      db:sql('BEGIN;')

      for row in db:rows('SELECT * FROM learns;') do
        local is_spam
        local digest = tostring(util.encode_base32(row.digest))

        -- flag '0' is stored as '-1', any other flag as '1'
        if row.flag == '0' then
          is_spam = '-1'
        else
          is_spam = '1'
        end

        if not conn:add_cmd('HSET', { 'learned_ids', digest, is_spam }) then
          logger.errx('Cannot add hash: ' .. digest)
          ret = false
        else
          total = total + 1
        end
      end
      db:sql('COMMIT;')

      if ret then
        conn:exec()
      end

      if ret then
        logger.messagex('Converted %s cached items from sqlite3 learned cache to redis',
            total)
      else
        logger.errx('Error occurred during sending data to redis')
      end
    end

    logger.messagex('Migrated %s tokens for %s users for symbols (%s, %s)',
        converted, nusers, symbol_spam, symbol_ham)
    return true
  end
  331. exports.convert_sqlite_to_redis = convert_sqlite_to_redis
  --- Loads sqlite3 based classifiers and outputs data in form of array of objects:
  -- [
  --  {
  --    symbol_spam = XXX
  --    symbol_ham = YYY
  --    db_spam = XXX.sqlite
  --    db_ham = YYY.sqlite
  --    learn_cache = ZZZ.sqlite
  --    per_user = true/false
  --    label = str
  --  }
  -- ]
  -- Only classifiers with `backend == 'sqlite3'` that define both a spam and a
  -- ham statfile (symbol + path) are included in the output.
  --
  -- @param cfg configuration table; its `classifier` field is inspected
  -- @return array of classifier descriptions (possibly empty)
  local function load_sqlite_config(cfg)
    local result = {}

    local function parse_classifier(cls)
      local tbl = {}
      if cls.cache then
        local cache = cls.cache
        if cache.type == 'sqlite3' and (cache.file or cache.path) then
          tbl.learn_cache = (cache.file or cache.path)
        end
      end

      if cls.per_user then
        tbl.per_user = cls.per_user
      end

      if cls.label then
        tbl.label = cls.label
      end

      -- A classifier without statfiles simply yields nothing
      local statfiles = cls.statfile or {}
      for _, stf in ipairs(statfiles) do
        local path = (stf.file or stf.path or stf.db or stf.dbname)
        local symbol = stf.symbol or 'undefined'

        if not path then
          logger.errx('no path defined for statfile %s', symbol)
        else
          -- An explicit `spam` setting always wins (including `spam = false`);
          -- the class is guessed from the symbol name only when it is absent
          local spam = stf.spam
          if spam == nil then
            spam = string.match(symbol:upper(), 'SPAM') ~= nil
          end

          if spam then
            tbl.symbol_spam = symbol
            tbl.db_spam = path
          else
            tbl.symbol_ham = symbol
            tbl.db_ham = path
          end
        end
      end

      if tbl.symbol_spam and tbl.symbol_ham and tbl.db_ham and tbl.db_spam then
        table.insert(result, tbl)
      end
    end

    local classifier = cfg.classifier

    if classifier then
      if classifier[1] then
        -- Array of classifiers, each one optionally wrapped into `bayes`
        for _, cls in ipairs(classifier) do
          if cls.bayes then
            cls = cls.bayes
          end
          if cls.backend and cls.backend == 'sqlite3' then
            parse_classifier(cls)
          end
        end
      else
        if classifier.bayes then
          classifier = classifier.bayes
          if classifier[1] then
            -- classifier.bayes is an array of classifiers
            for _, cls in ipairs(classifier) do
              if cls.backend and cls.backend == 'sqlite3' then
                parse_classifier(cls)
              end
            end
          else
            -- classifier.bayes is a single classifier
            if classifier.backend and classifier.backend == 'sqlite3' then
              parse_classifier(classifier)
            end
          end
        end
      end
    end

    return result
  end
  420. exports.load_sqlite_config = load_sqlite_config
  --- Suggests a Redis based classifier configuration that corresponds to an
  -- existing sqlite classifier.
  --
  -- @param sqlite_classifier table with `symbol_spam` and `symbol_ham` fields
  -- @param expire optional value copied to the `expire` field of the result
  -- @return table of the form { classifier = { bayes = { ... } } }
  local function redis_classifier_from_sqlite(sqlite_classifier, expire)
    -- One statfile per class, keyed by the symbol name
    local statfiles = {}
    statfiles[sqlite_classifier.symbol_spam] = { spam = true }
    statfiles[sqlite_classifier.symbol_ham] = { spam = false }

    local bayes = {
      new_schema = true,
      backend = 'redis',
      cache = { backend = 'redis' },
      statfile = statfiles,
    }

    if expire then
      bayes.expire = expire
    end

    return { classifier = { bayes = bayes } }
  end
  444. exports.redis_classifier_from_sqlite = redis_classifier_from_sqlite
  --- Reads the statistics configuration and returns a preprocessed table.
  -- Settings are taken from the `options` section (or from its `statistics`
  -- subsection when present) and merged over the defaults below using
  -- lua_util.override_defaults.
  --
  -- @param cfg configuration object providing `get_all_opt`
  -- @return resulting config with an extra `classify_headers_parsed` field:
  --         a map from a short header name to the list of headers sharing it
  local function process_stat_config(cfg)
    local section = cfg:get_all_opt('options') or {}

    -- Check if we have a dedicated section for statistics
    if section.statistics then
      section = section.statistics
    end

    -- Default
    local defaults = {
      classify_headers = {
        "User-Agent",
        "X-Mailer",
        "Content-Type",
        "X-MimeOLE",
        "Organization",
        "Organisation"
      },
      classify_images = true,
      classify_mime_info = true,
      classify_urls = true,
      classify_meta = true,
      classify_max_tlds = 10,
    }

    local res_config = lua_util.override_defaults(defaults, section)

    -- Derives the short name used to group a header:
    --   "Word-Other" -> initials joined by a dash
    --   "X-Something" -> 'x' + first three letters of the rest, lowercased
    --   anything else -> first three letters, lowercased
    local function short_header_name(header)
      local first, second = header:match("^([A-Z])[^%-]+%-([A-Z]).*$")
      if first and second then
        return string.format('%s-%s', first, second)
      end

      local x_rest = header:match("^X%-([A-Z].*)$")
      if x_rest then
        return string.format('x%s', x_rest:sub(1, 3):lower())
      end

      return string.format('%s', header:sub(1, 3):lower())
    end

    -- Postprocess classify_headers
    local grouped = {}
    for _, header in ipairs(res_config.classify_headers) do
      local hname = short_header_name(header)
      local bucket = grouped[hname]

      if not bucket then
        bucket = {}
        grouped[hname] = bucket
      end

      bucket[#bucket + 1] = header
    end

    res_config.classify_headers_parsed = grouped

    return res_config
  end
  --- Adds tokens describing the MIME structure of a message.
  -- For every part: its size (floor of the natural logarithm of the length)
  -- and attachment name; for text parts also the language and the charset.
  -- Finally a single `#m:` token summarising the overall parts layout is added.
  --
  -- @param task task object
  -- @param res tokens array to append to
  -- @param i index of the first free slot in `res`
  -- @return index of the next free slot in `res`
  local function get_mime_stat_tokens(task, res, i)
    local has_multipart, has_plain, has_html = false, false, false
    local plain_is_empty, html_is_empty = false, false
    local single_line_text = false

    -- Stores a token, logs it and advances the index; `logged` overrides
    -- the value passed to the debug message (the stored token by default)
    local function add_token(tok, msg, logged)
      rawset(res, i, tok)
      lua_util.debugm("bayes", task, msg, logged or res[i])
      i = i + 1
    end

    for _, part in ipairs(task:get_parts() or {}) do
      local fname = part:get_filename()
      local sz = part:get_length()

      if sz > 0 then
        add_token(string.format("#ps:%d", math.floor(math.log(sz))),
            "part size: %s")
      end

      if fname then
        add_token("#f:" .. fname, "added attachment: #f:%s", fname)
      end

      if part:is_text() then
        local tp = part:get_text()
        local is_empty

        if tp:is_html() then
          has_html = true
          is_empty = tp:get_length() == 0
          if is_empty then
            html_is_empty = true
          end
        else
          has_plain = true
          is_empty = tp:get_length() == 0
          if is_empty then
            plain_is_empty = true
          end
        end

        if tp:get_lines_count() < 2 then
          single_line_text = true
        end

        add_token("#lang:" .. (tp:get_language() or 'unk'), "added language: %s")
        add_token("#cs:" .. (tp:get_charset() or 'unk'), "added charset: %s")
      elseif part:is_multipart() then
        has_multipart = true
      end
    end

    -- Create a special token depending on parts structure
    local st_tok
    if has_html and has_plain then
      if has_multipart then
        st_tok = '#mpth'
      else
        st_tok = "#unk"
      end
    elseif has_html then
      st_tok = "#ho"
    elseif has_plain then
      st_tok = "#to"
    else
      st_tok = "#unk"
    end

    -- Optional suffixes: one-line text, empty plain part, empty html part
    local spec_tok = (single_line_text and "#ot" or "")
        .. (plain_is_empty and "#ep" or "")
        .. (html_is_empty and "#eh" or "")

    add_token(string.format("#m:%s%s", st_tok, spec_tok), "added mime token: %s")

    return i
  end
  --- Adds tokens built from the configured headers and from the display name
  -- of the MIME From.
  --
  -- @param task task object
  -- @param cf statistics config; `cf.classify_headers_parsed` maps a short
  --        name to the list of header names sharing it
  -- @param res tokens array to append to
  -- @param i index of the first free slot in `res`
  -- @return index of the next free slot in `res`
  local function get_headers_stat_tokens(task, cf, res, i)
    --[[
    -- As discussed with Alexander Moisseev, this feature can skew statistics
    -- especially when learning is separated from scanning, so learning
    -- has a different set of tokens where this token can have too high weight
    local hdrs_cksum = task:get_mempool():get_variable("headers_hash")
    if hdrs_cksum then
      rawset(res, i, string.format("#hh:%s", hdrs_cksum:sub(1, 7)))
      lua_util.debugm("bayes", task, "added hdrs hash token: %s",
          res[i])
      i = i + 1
    end
    ]]--

    for short_name, header_names in pairs(cf.classify_headers_parsed) do
      for _, header_name in ipairs(header_names) do
        local value = task:get_header(header_name)

        if value then
          local tok = string.format("#h:%s:%s", short_name, value)
          rawset(res, i, tok)
          lua_util.debugm("bayes", task, "added hdrs token: %s", res[i])
          i = i + 1
        end
      end
    end

    -- Only the first MIME From address is considered
    local mime_from = task:get_from('mime')
    local sender = mime_from and mime_from[1]

    if sender and sender.name then
      local tok = string.format("#F:%s", sender.name)
      rawset(res, i, tok)
      lua_util.debugm("bayes", task, "added from name token: %s", res[i])
      i = i + 1
    end

    return i
  end
  --- Adds meta tokens: the day of week and hour of the message date and a
  -- summary of the SPF/DKIM authentication results.
  -- Authentication data is taken from the DKIM_TRACE / R_SPF_ALLOW symbols
  -- when DKIM_TRACE is present, and from the Authentication-Results header
  -- otherwise.
  --
  -- @param task task object
  -- @param res tokens array to append to
  -- @param i index of the first free slot in `res`
  -- @return index of the next free slot in `res`
  local function get_meta_stat_tokens(task, res, i)
    local msg_date = task:get_date { format = 'message', gmt = true }
    local day_and_hour = os.date('%u:%H', msg_date)

    rawset(res, i, string.format("#dt:%s", day_and_hour))
    lua_util.debugm("bayes", task, "added day_of_week token: %s", res[i])
    i = i + 1

    local policies = {}

    -- Authentication results
    if task:has_symbol('DKIM_TRACE') then
      -- Autolearn or scan
      if task:has_symbol('R_SPF_ALLOW') then
        policies[#policies + 1] = 's=pass'
      end

      local dkim_opts = task:get_symbol('DKIM_TRACE')[1]['options']

      if dkim_opts then
        for _, opt in ipairs(dkim_opts) do
          -- The last character of an option is the check result, the part
          -- before the two trailing characters is the domain; only
          -- options ending with '+' are recorded
          if string.sub(opt, -1) == '+' then
            local domain = string.sub(opt, 1, -3)
            policies[#policies + 1] = string.format('d=%s:%s', "pass", domain)
          end
        end
      end
    else
      -- Offline learn
      local aur = task:get_header('Authentication-Results')

      if aur then
        local spf = aur:match('spf=([a-z]+)')
        local dkim, dkim_domain = aur:match('dkim=([a-z]+) header.d=([a-z.%-]+)')

        if spf then
          policies[#policies + 1] = 's=' .. spf
        end

        if dkim and dkim_domain then
          policies[#policies + 1] = string.format('d=%s:%s', dkim, dkim_domain)
        end
      end
    end

    if #policies > 0 then
      rawset(res, i, string.format("#aur:%s", table.concat(policies, ',')))
      lua_util.debugm("bayes", task, "added policies token: %s", res[i])
      i = i + 1
    end

    --[[
    -- Disabled.
    -- 1. Depending on the source the message has a different set of Received
    -- headers as the receiving MTA adds another Received header.
    -- 2. The usefulness of the Received tokens is questionable.
    local rh = task:get_received_headers()
    if rh and #rh > 0 then
      local lim = math.min(5, #rh)
      for j =1,lim do
        local rcvd = rh[j]
        local ip = rcvd.real_ip
        if ip and ip:is_valid() and ip:get_version() == 4 then
          local masked = ip:apply_mask(24)

          rawset(res, i, string.format("#rcv:%s:%s", tostring(masked),
              rcvd.proto))
          lua_util.debugm("bayes", task, "added received token: %s",
              res[i])
          i = i + 1
        end
      end
    end
    ]]--

    return i
  end
  --- Produces the array of extra statistical tokens for a task.
  -- Each group of tokens (images, MIME structure, headers, urls, meta) is
  -- added only when the corresponding `classify_*` setting is enabled.
  --
  -- @param task task object
  -- @param cf preprocessed statistics config
  -- @return array of string tokens
  local function get_stat_tokens(task, cf)
    local res = {}
    local i = 1

    -- Appends a token that needs no debug message of its own
    local function push(tok)
      rawset(res, i, tok)
      i = i + 1
    end

    if cf.classify_images then
      for _, img in ipairs(task:get_images() or {}) do
        push("image")
        push(tostring(img:get_height()))
        push(tostring(img:get_width()))
        push(tostring(img:get_type()))

        local fname = img:get_filename()
        if fname then
          push(tostring(img:get_filename()))
        end

        lua_util.debugm("bayes", task, "added image: %s", fname)
      end
    end

    if cf.classify_mime_info then
      i = get_mime_stat_tokens(task, res, i)
    end

    if cf.classify_headers and #cf.classify_headers > 0 then
      i = get_headers_stat_tokens(task, cf, res, i)
    end

    if cf.classify_urls then
      local urls = lua_util.extract_specific_urls { task = task, limit = 5, esld_limit = 1 }

      if urls then
        for _, u in ipairs(urls) do
          local tok = string.format("#u:%s", u:get_tld())
          rawset(res, i, tok)
          lua_util.debugm("bayes", task, "added url token: %s", res[i])
          i = i + 1
        end
      end
    end

    if cf.classify_meta then
      i = get_meta_stat_tokens(task, res, i)
    end

    return res
  end
  --- Creates a tokens generator bound to the statistics configuration.
  -- The configuration is processed once; the returned function can then be
  -- called for every task.
  --
  -- @param cfg configuration object passed to process_stat_config
  -- @return function(task) returning the array of statistical tokens
  function exports.gen_stat_tokens(cfg)
    local stat_config = process_stat_config(cfg)

    local function generate(task)
      return get_stat_tokens(task, stat_config)
    end

    return generate
  end
  726. return exports