You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

neural.lua 33KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995
  1. --[[
  2. Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ]]--
  13. if confighelp then
  14. return
  15. end
  16. local fun = require "fun"
  17. local lua_redis = require "lua_redis"
  18. local lua_util = require "lua_util"
  19. local lua_verdict = require "lua_verdict"
  20. local neural_common = require "plugins/neural"
  21. local rspamd_kann = require "rspamd_kann"
  22. local rspamd_logger = require "rspamd_logger"
  23. local rspamd_tensor = require "rspamd_tensor"
  24. local rspamd_text = require "rspamd_text"
  25. local rspamd_util = require "rspamd_util"
  26. local ts = require("tableshape").types
  27. local N = "neural"
  28. local settings = neural_common.settings
  29. local redis_profile_schema = ts.shape{
  30. digest = ts.string,
  31. symbols = ts.array_of(ts.string),
  32. version = ts.number,
  33. redis_key = ts.string,
  34. distance = ts.number:is_optional(),
  35. }
  36. local has_blas = rspamd_tensor.has_blas()
  37. local text_cookie = rspamd_text.cookie
  38. -- Creates and stores ANN profile in Redis
  39. local function new_ann_profile(task, rule, set, version)
  40. local ann_key = neural_common.new_ann_key(rule, set, version, settings)
  41. local profile = {
  42. symbols = set.symbols,
  43. redis_key = ann_key,
  44. version = version,
  45. digest = set.digest,
  46. distance = 0 -- Since we are using our own profile
  47. }
  48. local ucl = require "ucl"
  49. local profile_serialized = ucl.to_format(profile, 'json-compact', true)
  50. local function add_cb(err, _)
  51. if err then
  52. rspamd_logger.errx(task, 'cannot store ANN profile for %s:%s at %s : %s',
  53. rule.prefix, set.name, profile.redis_key, err)
  54. else
  55. rspamd_logger.infox(task, 'created new ANN profile for %s:%s, data stored at prefix %s',
  56. rule.prefix, set.name, profile.redis_key)
  57. end
  58. end
  59. lua_redis.redis_make_request(task,
  60. rule.redis,
  61. nil,
  62. true, -- is write
  63. add_cb, --callback
  64. 'ZADD', -- command
  65. {set.prefix, tostring(rspamd_util.get_time()), profile_serialized}
  66. )
  67. return profile
  68. end
  69. -- ANN filter function, used to insert scores based on the existing symbols
  70. local function ann_scores_filter(task)
  71. for _,rule in pairs(settings.rules) do
  72. local sid = task:get_settings_id() or -1
  73. local ann
  74. local profile
  75. local set = neural_common.get_rule_settings(task, rule)
  76. if set then
  77. if set.ann then
  78. ann = set.ann.ann
  79. profile = set.ann
  80. else
  81. lua_util.debugm(N, task, 'no ann loaded for %s:%s',
  82. rule.prefix, set.name)
  83. end
  84. else
  85. lua_util.debugm(N, task, 'no ann defined in %s for settings id %s',
  86. rule.prefix, sid)
  87. end
  88. if ann then
  89. local vec = neural_common.result_to_vector(task, profile)
  90. local score
  91. local out = ann:apply1(vec, set.ann.pca)
  92. score = out[1]
  93. local symscore = string.format('%.3f', score)
  94. task:cache_set(rule.prefix .. '_neural_score', score)
  95. lua_util.debugm(N, task, '%s:%s:%s ann score: %s',
  96. rule.prefix, set.name, set.ann.version, symscore)
  97. if score > 0 then
  98. local result = score
  99. -- If spam_score_threshold is defined, override all other thresholds.
  100. local spam_threshold = 0
  101. if rule.spam_score_threshold then
  102. spam_threshold = rule.spam_score_threshold
  103. elseif rule.roc_enabled and not set.ann.roc_thresholds then
  104. spam_threshold = set.ann.roc_thresholds[1]
  105. end
  106. if result >= spam_threshold then
  107. if rule.flat_threshold_curve then
  108. task:insert_result(rule.symbol_spam, 1.0, symscore)
  109. else
  110. task:insert_result(rule.symbol_spam, result, symscore)
  111. end
  112. else
  113. lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (spam threshold)',
  114. rule.prefix, set.name, set.ann.version, symscore,
  115. spam_threshold)
  116. end
  117. else
  118. local result = -(score)
  119. -- If ham_score_threshold is defined, override all other thresholds.
  120. local ham_threshold = 0
  121. if rule.ham_score_threshold then
  122. ham_threshold = rule.ham_score_threshold
  123. elseif rule.roc_enabled and not set.ann.roc_thresholds then
  124. ham_threshold = set.ann.roc_thresholds[2]
  125. end
  126. if result >= ham_threshold then
  127. if rule.flat_threshold_curve then
  128. task:insert_result(rule.symbol_ham, 1.0, symscore)
  129. else
  130. task:insert_result(rule.symbol_ham, result, symscore)
  131. end
  132. else
  133. lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (ham threshold)',
  134. rule.prefix, set.name, set.ann.version, result,
  135. ham_threshold)
  136. end
  137. end
  138. end
  139. end
  140. end
  141. local function ann_push_task_result(rule, task, verdict, score, set)
  142. local train_opts = rule.train
  143. local learn_spam, learn_ham
  144. local skip_reason = 'unknown'
  145. if not train_opts.store_pool_only and train_opts.autotrain then
  146. if train_opts.spam_score then
  147. learn_spam = score >= train_opts.spam_score
  148. if not learn_spam then
  149. skip_reason = string.format('score < spam_score: %f < %f',
  150. score, train_opts.spam_score)
  151. end
  152. else
  153. learn_spam = verdict == 'spam' or verdict == 'junk'
  154. if not learn_spam then
  155. skip_reason = string.format('verdict: %s',
  156. verdict)
  157. end
  158. end
  159. if train_opts.ham_score then
  160. learn_ham = score <= train_opts.ham_score
  161. if not learn_ham then
  162. skip_reason = string.format('score > ham_score: %f > %f',
  163. score, train_opts.ham_score)
  164. end
  165. else
  166. learn_ham = verdict == 'ham'
  167. if not learn_ham then
  168. skip_reason = string.format('verdict: %s',
  169. verdict)
  170. end
  171. end
  172. else
  173. -- Train by request header
  174. local hdr = task:get_request_header('ANN-Train')
  175. if hdr then
  176. if hdr:lower() == 'spam' then
  177. learn_spam = true
  178. elseif hdr:lower() == 'ham' then
  179. learn_ham = true
  180. else
  181. skip_reason = 'no explicit header'
  182. end
  183. elseif train_opts.store_pool_only then
  184. local ucl = require "ucl"
  185. learn_ham = false
  186. learn_spam = false
  187. -- Explicitly store tokens in cache
  188. local vec = neural_common.result_to_vector(task, set)
  189. task:cache_set(rule.prefix .. '_neural_vec_mpack', ucl.to_format(vec, 'msgpack'))
  190. task:cache_set(rule.prefix .. '_neural_profile_digest', set.digest)
  191. skip_reason = 'store_pool_only has been set'
  192. end
  193. end
  194. if learn_spam or learn_ham then
  195. local learn_type
  196. if learn_spam then learn_type = 'spam' else learn_type = 'ham' end
  197. local function vectors_len_cb(err, data)
  198. if not err and type(data) == 'table' then
  199. local nspam,nham = data[1],data[2]
  200. if neural_common.can_push_train_vector(rule, task, learn_type, nspam, nham) then
  201. local vec = neural_common.result_to_vector(task, set)
  202. local str = rspamd_util.zstd_compress(table.concat(vec, ';'))
  203. local target_key = set.ann.redis_key .. '_' .. learn_type .. '_set'
  204. local function learn_vec_cb(redis_err)
  205. if redis_err then
  206. rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s',
  207. rule.prefix, set.name, redis_err)
  208. else
  209. lua_util.debugm(N, task,
  210. "add train data for ANN rule " ..
  211. "%s:%s, save %s vector of %s elts in %s key; %s bytes compressed",
  212. rule.prefix, set.name, learn_type, #vec, target_key, #str)
  213. end
  214. end
  215. lua_redis.redis_make_request(task,
  216. rule.redis,
  217. nil,
  218. true, -- is write
  219. learn_vec_cb, --callback
  220. 'SADD', -- command
  221. { target_key, str } -- arguments
  222. )
  223. else
  224. lua_util.debugm(N, task,
  225. "do not add %s train data for ANN rule " ..
  226. "%s:%s",
  227. learn_type, rule.prefix, set.name)
  228. end
  229. else
  230. if err then
  231. rspamd_logger.errx(task, 'cannot check if we can train %s:%s : %s',
  232. rule.prefix, set.name, err)
  233. elseif type(data) == 'string' then
  234. -- nil return value
  235. rspamd_logger.infox(task, "cannot learn %s ANN %s:%s; redis_key: %s: locked for learning: %s",
  236. learn_type, rule.prefix, set.name, set.ann.redis_key, data)
  237. else
  238. rspamd_logger.errx(task, 'cannot check if we can train %s:%s : type of Redis key %s is %s, expected table' ..
  239. 'please remove this key from Redis manually if you perform upgrade from the previous version',
  240. rule.prefix, set.name, set.ann.redis_key, type(data))
  241. end
  242. end
  243. end
  244. -- Check if we can learn
  245. if set.can_store_vectors then
  246. if not set.ann then
  247. -- Need to create or load a profile corresponding to the current configuration
  248. set.ann = new_ann_profile(task, rule, set, 0)
  249. lua_util.debugm(N, task,
  250. 'requested new profile for %s, set.ann is missing',
  251. set.name)
  252. end
  253. lua_redis.exec_redis_script(neural_common.redis_script_id.vectors_len,
  254. {task = task, is_write = false},
  255. vectors_len_cb,
  256. {
  257. set.ann.redis_key,
  258. })
  259. else
  260. lua_util.debugm(N, task,
  261. 'do not push data: train condition not satisfied; reason: not checked existing ANNs')
  262. end
  263. else
  264. lua_util.debugm(N, task,
  265. 'do not push data to key %s: train condition not satisfied; reason: %s',
  266. (set.ann or {}).redis_key,
  267. skip_reason)
  268. end
  269. end
  270. --- Offline training logic
  271. -- Utility to extract and split saved training vectors to a table of tables
  272. local function process_training_vectors(data)
  273. return fun.totable(fun.map(function(tok)
  274. local _,str = rspamd_util.zstd_decompress(tok)
  275. return fun.totable(fun.map(tonumber, lua_util.str_split(tostring(str), ';')))
  276. end, data))
  277. end
  278. -- This function does the following:
  279. -- * Tries to lock ANN
  280. -- * Loads spam and ham vectors
  281. -- * Spawn learning process
  282. local function do_train_ann(worker, ev_base, rule, set, ann_key)
  283. local spam_elts = {}
  284. local ham_elts = {}
  285. local function redis_ham_cb(err, data)
  286. if err or type(data) ~= 'table' then
  287. rspamd_logger.errx(rspamd_config, 'cannot get ham tokens for ANN %s from redis: %s',
  288. ann_key, err)
  289. -- Unlock on error
  290. lua_redis.redis_make_request_taskless(ev_base,
  291. rspamd_config,
  292. rule.redis,
  293. nil,
  294. true, -- is write
  295. neural_common.gen_unlock_cb(rule, set, ann_key), --callback
  296. 'HDEL', -- command
  297. {ann_key, 'lock'}
  298. )
  299. else
  300. -- Decompress and convert to numbers each training vector
  301. ham_elts = process_training_vectors(data)
  302. neural_common.spawn_train({worker = worker, ev_base = ev_base,
  303. rule = rule, set = set, ann_key = ann_key, ham_vec = ham_elts,
  304. spam_vec = spam_elts})
  305. end
  306. end
  307. -- Spam vectors received
  308. local function redis_spam_cb(err, data)
  309. if err or type(data) ~= 'table' then
  310. rspamd_logger.errx(rspamd_config, 'cannot get spam tokens for ANN %s from redis: %s',
  311. ann_key, err)
  312. -- Unlock ANN on error
  313. lua_redis.redis_make_request_taskless(ev_base,
  314. rspamd_config,
  315. rule.redis,
  316. nil,
  317. true, -- is write
  318. neural_common.gen_unlock_cb(rule, set, ann_key), --callback
  319. 'HDEL', -- command
  320. {ann_key, 'lock'}
  321. )
  322. else
  323. -- Decompress and convert to numbers each training vector
  324. spam_elts = process_training_vectors(data)
  325. -- Now get ham vectors...
  326. lua_redis.redis_make_request_taskless(ev_base,
  327. rspamd_config,
  328. rule.redis,
  329. nil,
  330. false, -- is write
  331. redis_ham_cb, --callback
  332. 'SMEMBERS', -- command
  333. {ann_key .. '_ham_set'}
  334. )
  335. end
  336. end
  337. local function redis_lock_cb(err, data)
  338. if err then
  339. rspamd_logger.errx(rspamd_config, 'cannot call lock script for ANN %s from redis: %s',
  340. ann_key, err)
  341. elseif type(data) == 'number' and data == 1 then
  342. -- ANN is locked, so we can extract SPAM and HAM vectors and spawn learning
  343. lua_redis.redis_make_request_taskless(ev_base,
  344. rspamd_config,
  345. rule.redis,
  346. nil,
  347. false, -- is write
  348. redis_spam_cb, --callback
  349. 'SMEMBERS', -- command
  350. {ann_key .. '_spam_set'}
  351. )
  352. rspamd_logger.infox(rspamd_config, 'lock ANN %s:%s (key name %s) for learning',
  353. rule.prefix, set.name, ann_key)
  354. else
  355. local lock_tm = tonumber(data[1])
  356. rspamd_logger.infox(rspamd_config, 'do not learn ANN %s:%s (key name %s), ' ..
  357. 'locked by another host %s at %s', rule.prefix, set.name, ann_key,
  358. data[2], os.date('%c', lock_tm))
  359. end
  360. end
  361. -- Check if we are already learning this network
  362. if set.learning_spawned then
  363. rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, already learning another ANN',
  364. ann_key)
  365. return
  366. end
  367. -- Call Redis script that tries to acquire a lock
  368. -- This script returns either a boolean or a pair {'lock_time', 'hostname'} when
  369. -- ANN is locked by another host (or a process, meh)
  370. lua_redis.exec_redis_script(neural_common.redis_script_id.maybe_lock,
  371. {ev_base = ev_base, is_write = true},
  372. redis_lock_cb,
  373. {
  374. ann_key,
  375. tostring(os.time()),
  376. tostring(math.max(10.0, rule.watch_interval * 2)),
  377. rspamd_util.get_hostname()
  378. })
  379. end
  380. -- This function loads new ann from Redis
  381. -- This is based on `profile` attribute.
  382. -- ANN is loaded from `profile.redis_key`
  383. -- Rank of `profile` key is also increased, unfortunately, it means that we need to
  384. -- serialize profile one more time and set its rank to the current time
  385. -- set.ann fields are set according to Redis data received
  386. local function load_new_ann(rule, ev_base, set, profile, min_diff)
  387. local ann_key = profile.redis_key
  388. local function data_cb(err, data)
  389. if err then
  390. rspamd_logger.errx(rspamd_config, 'cannot get ANN data from key: %s; %s',
  391. ann_key, err)
  392. else
  393. if type(data) == 'table' then
  394. if type(data[1]) == 'userdata' and data[1].cookie == text_cookie then
  395. local _err,ann_data = rspamd_util.zstd_decompress(data[1])
  396. local ann
  397. if _err or not ann_data then
  398. rspamd_logger.errx(rspamd_config, 'cannot decompress ANN for %s from Redis key %s: %s',
  399. rule.prefix .. ':' .. set.name, ann_key, _err)
  400. return
  401. else
  402. ann = rspamd_kann.load(ann_data)
  403. if ann then
  404. set.ann = {
  405. digest = profile.digest,
  406. version = profile.version,
  407. symbols = profile.symbols,
  408. distance = min_diff,
  409. redis_key = profile.redis_key
  410. }
  411. local ucl = require "ucl"
  412. local profile_serialized = ucl.to_format(profile, 'json-compact', true)
  413. set.ann.ann = ann -- To avoid serialization
  414. local function rank_cb(_, _)
  415. -- TODO: maybe add some logging
  416. end
  417. -- Also update rank for the loaded ANN to avoid removal
  418. lua_redis.redis_make_request_taskless(ev_base,
  419. rspamd_config,
  420. rule.redis,
  421. nil,
  422. true, -- is write
  423. rank_cb, --callback
  424. 'ZADD', -- command
  425. {set.prefix, tostring(rspamd_util.get_time()), profile_serialized}
  426. )
  427. rspamd_logger.infox(rspamd_config,
  428. 'loaded ANN for %s:%s from %s; %s bytes compressed; version=%s',
  429. rule.prefix, set.name, ann_key, #data[1], profile.version)
  430. else
  431. rspamd_logger.errx(rspamd_config,
  432. 'cannot unpack/deserialise ANN for %s:%s from Redis key %s',
  433. rule.prefix, set.name, ann_key)
  434. end
  435. end
  436. else
  437. lua_util.debugm(N, rspamd_config, 'missing ANN for %s:%s in Redis key %s',
  438. rule.prefix, set.name, ann_key)
  439. end
  440. if set.ann and set.ann.ann and type(data[2]) == 'userdata' and data[2].cookie == text_cookie then
  441. if rule.roc_enabled then
  442. local ucl = require "ucl"
  443. local parser = ucl.parser()
  444. local ok, parse_err = parser:parse_text(data[2])
  445. assert(ok, parse_err)
  446. local roc_thresholds = parser:get_object()
  447. set.ann.roc_thresholds = roc_thresholds
  448. rspamd_logger.infox(rspamd_config,
  449. 'loaded ROC thresholds for %s:%s; version=%s',
  450. rule.prefix, set.name, profile.version)
  451. rspamd_logger.debugx("ROC thresholds: %s", roc_thresholds)
  452. end
  453. end
  454. if set.ann and set.ann.ann and type(data[3]) == 'userdata' and data[3].cookie == text_cookie then
  455. -- PCA table
  456. local _err,pca_data = rspamd_util.zstd_decompress(data[3])
  457. if pca_data then
  458. if rule.max_inputs then
  459. -- We can use PCA
  460. set.ann.pca = rspamd_tensor.load(pca_data)
  461. rspamd_logger.infox(rspamd_config,
  462. 'loaded PCA for ANN for %s:%s from %s; %s bytes compressed; version=%s',
  463. rule.prefix, set.name, ann_key, #data[3], profile.version)
  464. else
  465. -- no need in pca, why is it there?
  466. rspamd_logger.warnx(rspamd_config,
  467. 'extra PCA for ANN for %s:%s from Redis key %s: no max inputs defined',
  468. rule.prefix, set.name, ann_key)
  469. end
  470. else
  471. -- pca can be missing merely if we have no max_inputs
  472. if rule.max_inputs then
  473. rspamd_logger.errx(rspamd_config, 'cannot unpack/deserialise ANN for %s:%s from Redis key %s: no PCA: %s',
  474. rule.prefix, set.name, ann_key, _err)
  475. set.ann.ann = nil
  476. else
  477. -- It is okay
  478. set.ann.pca = nil
  479. end
  480. end
  481. end
  482. else
  483. lua_util.debugm(N, rspamd_config, 'no ANN key for %s:%s in Redis key %s',
  484. rule.prefix, set.name, ann_key)
  485. end
  486. end
  487. end
  488. lua_redis.redis_make_request_taskless(ev_base,
  489. rspamd_config,
  490. rule.redis,
  491. nil,
  492. false, -- is write
  493. data_cb, --callback
  494. 'HMGET', -- command
  495. {ann_key, 'ann', 'roc_thresholds', 'pca'}, -- arguments
  496. {opaque_data = true}
  497. )
  498. end
  499. -- Used to check an element in Redis serialized as JSON
  500. -- for some specific rule + some specific setting
  501. -- This function tries to load more fresh or more specific ANNs in lieu of
  502. -- the existing ones.
  503. -- Use this function to load ANNs as `callback` parameter for `check_anns` function
  504. local function process_existing_ann(_, ev_base, rule, set, profiles)
  505. local my_symbols = set.symbols
  506. local min_diff = math.huge
  507. local sel_elt
  508. for _,elt in fun.iter(profiles) do
  509. if elt and elt.symbols then
  510. local dist = lua_util.distance_sorted(elt.symbols, my_symbols)
  511. -- Check distance
  512. if dist < #my_symbols * .3 then
  513. if dist < min_diff then
  514. min_diff = dist
  515. sel_elt = elt
  516. end
  517. end
  518. end
  519. end
  520. if sel_elt then
  521. -- We can load element from ANN
  522. if set.ann then
  523. -- We have an existing ANN, probably the same...
  524. if set.ann.digest == sel_elt.digest then
  525. -- Same ANN, check version
  526. if set.ann.version < sel_elt.version then
  527. -- Load new ann
  528. rspamd_logger.infox(rspamd_config, 'ann %s is changed, ' ..
  529. 'our version = %s, remote version = %s',
  530. rule.prefix .. ':' .. set.name,
  531. set.ann.version,
  532. sel_elt.version)
  533. load_new_ann(rule, ev_base, set, sel_elt, min_diff)
  534. else
  535. lua_util.debugm(N, rspamd_config, 'ann %s is not changed, ' ..
  536. 'our version = %s, remote version = %s',
  537. rule.prefix .. ':' .. set.name,
  538. set.ann.version,
  539. sel_elt.version)
  540. end
  541. else
  542. -- We have some different ANN, so we need to compare distance
  543. if set.ann.distance > min_diff then
  544. -- Load more specific ANN
  545. rspamd_logger.infox(rspamd_config, 'more specific ann is available for %s, ' ..
  546. 'our distance = %s, remote distance = %s',
  547. rule.prefix .. ':' .. set.name,
  548. set.ann.distance,
  549. min_diff)
  550. load_new_ann(rule, ev_base, set, sel_elt, min_diff)
  551. else
  552. lua_util.debugm(N, rspamd_config, 'ann %s is not changed or less specific, ' ..
  553. 'our distance = %s, remote distance = %s',
  554. rule.prefix .. ':' .. set.name,
  555. set.ann.distance,
  556. min_diff)
  557. end
  558. end
  559. else
  560. -- We have no ANN, load new one
  561. load_new_ann(rule, ev_base, set, sel_elt, min_diff)
  562. end
  563. end
  564. end
  565. -- This function checks all profiles and selects if we can train our
  566. -- ANN. By our we mean that it has exactly the same symbols in profile.
  567. -- Use this function to train ANN as `callback` parameter for `check_anns` function
  568. local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
  569. local my_symbols = set.symbols
  570. local sel_elt
  571. local lens = {
  572. spam = 0,
  573. ham = 0,
  574. }
  575. for _,elt in fun.iter(profiles) do
  576. if elt and elt.symbols then
  577. local dist = lua_util.distance_sorted(elt.symbols, my_symbols)
  578. -- Check distance
  579. if dist == 0 then
  580. sel_elt = elt
  581. break
  582. end
  583. end
  584. end
  585. if sel_elt then
  586. -- We have our ANN and that's train vectors, check if we can learn
  587. local ann_key = sel_elt.redis_key
  588. lua_util.debugm(N, rspamd_config, "check if ANN %s needs to be trained",
  589. ann_key)
  590. -- Create continuation closure
  591. local redis_len_cb_gen = function(cont_cb, what, is_final)
  592. return function(err, data)
  593. if err then
  594. rspamd_logger.errx(rspamd_config,
  595. 'cannot get ANN %s trains %s from redis: %s', what, ann_key, err)
  596. elseif data and type(data) == 'number' or type(data) == 'string' then
  597. local ntrains = tonumber(data) or 0
  598. lens[what] = ntrains
  599. if is_final then
  600. -- Ensure that we have the following:
  601. -- one class has reached max_trains
  602. -- other class(es) are at least as full as classes_bias
  603. -- e.g. if classes_bias = 0.25 and we have 10 max_trains then
  604. -- one class must have 10 or more trains whilst another should have
  605. -- at least (10 * (1 - 0.25)) = 8 trains
  606. local max_len = math.max(lua_util.unpack(lua_util.values(lens)))
  607. local min_len = math.min(lua_util.unpack(lua_util.values(lens)))
  608. if rule.train.learn_type == 'balanced' then
  609. local len_bias_check_pred = function(_, l)
  610. return l >= rule.train.max_trains * (1.0 - rule.train.classes_bias)
  611. end
  612. if max_len >= rule.train.max_trains and fun.all(len_bias_check_pred, lens) then
  613. rspamd_logger.debugm(N, rspamd_config,
  614. 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
  615. ann_key, lens, rule.train.max_trains, what)
  616. cont_cb()
  617. else
  618. rspamd_logger.debugm(N, rspamd_config,
  619. 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
  620. ann_key, what, lens, rule.train.max_trains)
  621. end
  622. else
  623. -- Probabilistic mode, just ensure that at least one vector is okay
  624. if min_len > 0 and max_len >= rule.train.max_trains then
  625. rspamd_logger.debugm(N, rspamd_config,
  626. 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
  627. ann_key, lens, rule.train.max_trains, what)
  628. cont_cb()
  629. else
  630. rspamd_logger.debugm(N, rspamd_config,
  631. 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
  632. ann_key, what, lens, rule.train.max_trains)
  633. end
  634. end
  635. else
  636. rspamd_logger.debugm(N, rspamd_config,
  637. 'checked %s vectors in ANN %s: %s vectors; %s required, need to check other class vectors',
  638. what, ann_key, ntrains, rule.train.max_trains)
  639. cont_cb()
  640. end
  641. end
  642. end
  643. end
  644. local function initiate_train()
  645. rspamd_logger.infox(rspamd_config,
  646. 'need to learn ANN %s after %s required learn vectors',
  647. ann_key, lens)
  648. do_train_ann(worker, ev_base, rule, set, ann_key)
  649. end
  650. -- Spam vector is OK, check ham vector length
  651. local function check_ham_len()
  652. lua_redis.redis_make_request_taskless(ev_base,
  653. rspamd_config,
  654. rule.redis,
  655. nil,
  656. false, -- is write
  657. redis_len_cb_gen(initiate_train, 'ham', true), --callback
  658. 'SCARD', -- command
  659. {ann_key .. '_ham_set'}
  660. )
  661. end
  662. lua_redis.redis_make_request_taskless(ev_base,
  663. rspamd_config,
  664. rule.redis,
  665. nil,
  666. false, -- is write
  667. redis_len_cb_gen(check_ham_len, 'spam', false), --callback
  668. 'SCARD', -- command
  669. {ann_key .. '_spam_set'}
  670. )
  671. end
  672. end
  673. -- Used to deserialise ANN element from a list
  674. local function load_ann_profile(element)
  675. local ucl = require "ucl"
  676. local parser = ucl.parser()
  677. local res,ucl_err = parser:parse_string(element)
  678. if not res then
  679. rspamd_logger.warnx(rspamd_config, 'cannot parse ANN from redis: %s',
  680. ucl_err)
  681. return nil
  682. else
  683. local profile = parser:get_object()
  684. local checked,schema_err = redis_profile_schema:transform(profile)
  685. if not checked then
  686. rspamd_logger.errx(rspamd_config, "cannot parse profile schema: %s", schema_err)
  687. return nil
  688. end
  689. return checked
  690. end
  691. end
  692. -- Function to check or load ANNs from Redis
  693. local function check_anns(worker, cfg, ev_base, rule, process_callback, what)
  694. for _,set in pairs(rule.settings) do
  695. local function members_cb(err, data)
  696. if err then
  697. rspamd_logger.errx(cfg, 'cannot get ANNs list from redis: %s',
  698. err)
  699. set.can_store_vectors = true
  700. elseif type(data) == 'table' then
  701. lua_util.debugm(N, cfg, '%s: process element %s:%s',
  702. what, rule.prefix, set.name)
  703. process_callback(worker, ev_base, rule, set, fun.map(load_ann_profile, data))
  704. set.can_store_vectors = true
  705. end
  706. end
  707. if type(set) == 'table' then
  708. -- Extract all profiles for some specific settings id
  709. -- Get the last `max_profiles` recently used
  710. -- Select the most appropriate to our profile but it should not differ by more
  711. -- than 30% of symbols
  712. lua_redis.redis_make_request_taskless(ev_base,
  713. cfg,
  714. rule.redis,
  715. nil,
  716. false, -- is write
  717. members_cb, --callback
  718. 'ZREVRANGE', -- command
  719. {set.prefix, '0', tostring(settings.max_profiles)} -- arguments
  720. )
  721. end
  722. end -- Cycle over all settings
  723. return rule.watch_interval
  724. end
  725. -- Function to clean up old ANNs
  726. local function cleanup_anns(rule, cfg, ev_base)
  727. for _,set in pairs(rule.settings) do
  728. local function invalidate_cb(err, data)
  729. if err then
  730. rspamd_logger.errx(cfg, 'cannot exec invalidate script in redis: %s',
  731. err)
  732. elseif type(data) == 'table' then
  733. for _,expired in ipairs(data) do
  734. local profile = load_ann_profile(expired)
  735. rspamd_logger.infox(cfg, 'invalidated ANN for %s; redis key: %s; version=%s',
  736. rule.prefix .. ':' .. set.name,
  737. profile.redis_key,
  738. profile.version)
  739. end
  740. end
  741. end
  742. if type(set) == 'table' then
  743. lua_redis.exec_redis_script(neural_common.redis_script_id.maybe_invalidate,
  744. {ev_base = ev_base, is_write = true},
  745. invalidate_cb,
  746. {set.prefix, tostring(settings.max_profiles)})
  747. end
  748. end
  749. end
  --- NEURAL_LEARN callback: hand the task's verdict and score to every rule
  -- that has settings applicable to this task.
  -- Nothing is pushed when the task has the 'skip' flag, when
  -- `settings.allow_local` is off and the task came from rspamc/controller,
  -- when the verdict is 'passthrough', or when the score is NaN.
  -- @param task rspamd task
  local function ann_push_vector(task)
    if task:has_flag('skip') then
      lua_util.debugm(N, task, 'do not push data for skipped task')
      return
    end

    if not settings.allow_local then
      -- Manual scans are only learned from when explicitly allowed
      if lua_util.is_rspamc_or_controller(task) then
        lua_util.debugm(N, task, 'do not push data for manual scan')
        return
      end
    end

    local task_verdict, task_score = lua_verdict.get_specific_verdict(N, task)

    if task_verdict == 'passthrough' then
      lua_util.debugm(N, task, 'ignore task as its verdict is %s(%s)',
          task_verdict, task_score)
      return
    end

    -- NaN is the only value that compares unequal to itself
    if task_score ~= task_score then
      lua_util.debugm(N, task, 'ignore task as its score is nan (%s verdict)',
          task_verdict)
      return
    end

    for _, nn_rule in pairs(settings.rules) do
      local rule_settings = neural_common.get_rule_settings(task, nn_rule)
      if not rule_settings then
        lua_util.debugm(N, task, 'settings not found in rule %s', nn_rule.prefix)
      else
        ann_push_task_result(nn_rule, task, task_verdict, task_score, rule_settings)
      end
    end
  end
  -- Initialization part
  -- Stop loading (and mark the module disabled) when there is no `neural`
  -- configuration table or no Redis parameters. Every rule is bound to
  -- neural_common.redis_params, so both are required. The disable reason
  -- reflects what is actually missing.
  if type(neural_common.module_config) ~= 'table' then
    rspamd_logger.infox(rspamd_config, 'Module is unconfigured')
    lua_util.disable_module(N, "config")
    return
  end
  if not neural_common.redis_params then
    rspamd_logger.infox(rspamd_config, 'Module is unconfigured')
    lua_util.disable_module(N, "redis")
    return
  end
  -- Rules normally live in the `rules` subsection. A legacy configuration
  -- without it is treated as a single rule keyed 'default' whose options
  -- are the whole module section.
  local rules = neural_common.module_config.rules
      or { default = neural_common.module_config }
  -- NEURAL_CHECK: postfilter whose callback is ann_scores_filter. Its id is
  -- used as the parent of the per-rule virtual spam/ham symbols.
  local id = rspamd_config:register_symbol({
    callback = ann_scores_filter,
    flags = 'nostat',
    name = 'NEURAL_CHECK',
    priority = lua_util.symbols_priorities.medium,
    type = 'postfilter,callback',
  })
  -- Begin with no rules; only those processed by the rule loop are added.
  neural_common.settings.rules = {}

  do
    -- A list-style `blacklisted_symbols` option is converted to a hash
    -- via lua_util.list_to_hash.
    local blacklist = settings.blacklisted_symbols
    if blacklist and blacklist[1] then
      settings.blacklisted_symbols = lua_util.list_to_hash(blacklist)
    end
  end
  -- Check all rules.
  -- Each configured rule is merged over neural_common.default_options,
  -- completed with derived defaults and stored in settings.rules under its
  -- key. Its spam and ham symbols are then registered (metric score 0) as
  -- virtual children of the NEURAL_CHECK symbol (`id`).
  for k, r in pairs(rules) do
    local rule_elt = lua_util.override_defaults(neural_common.default_options, r)
    rule_elt['redis'] = neural_common.redis_params
    rule_elt['anns'] = {} -- Store ANNs here

    -- The rule key is the fallback for both the prefix and the name
    if not rule_elt.prefix then
      rule_elt.prefix = k
    end
    if not rule_elt.name then
      rule_elt.name = k
    end

    -- `train.max_train` is accepted as an alias for `train.max_trains`
    if rule_elt.train.max_train and not rule_elt.train.max_trains then
      rule_elt.train.max_trains = rule_elt.train.max_train
    end

    if not rule_elt.profile then rule_elt.profile = {} end

    if rule_elt.max_inputs and not has_blas then
      -- max_inputs must fill the first '%s'; the rule name is reported
      -- separately (passing the name first hid the requested value)
      rspamd_logger.errx(rspamd_config,
          'cannot set max inputs to %s for rule %s as BLAS is not compiled in',
          rule_elt.max_inputs, rule_elt.name)
      rule_elt.max_inputs = nil
    end

    rspamd_logger.infox(rspamd_config, "register ann rule %s", k)
    settings.rules[k] = rule_elt

    rspamd_config:set_metric_symbol({
      name = rule_elt.symbol_spam,
      score = 0.0,
      description = 'Neural network SPAM',
      group = 'neural'
    })
    rspamd_config:register_symbol({
      name = rule_elt.symbol_spam,
      type = 'virtual',
      flags = 'nostat',
      parent = id
    })

    rspamd_config:set_metric_symbol({
      name = rule_elt.symbol_ham,
      score = -0.0,
      description = 'Neural network HAM',
      group = 'neural'
    })
    rspamd_config:register_symbol({
      name = rule_elt.symbol_ham,
      type = 'virtual',
      flags = 'nostat',
      parent = id
    })
  end
  -- NEURAL_LEARN: idempotent callback symbol that passes every task to
  -- ann_push_vector. It is flagged explicit_disable and ignore_passthrough.
  rspamd_config:register_symbol({
    callback = ann_push_vector,
    flags = 'nostat,explicit_disable,ignore_passthrough',
    type = 'idempotent,callback',
    name = 'NEURAL_LEARN',
  })
  -- Rule settings are processed once configuration loading has finished.
  rspamd_config:add_post_init(neural_common.process_rules_settings)

  -- For every rule, the training scripts are loaded now, and periodic jobs
  -- are scheduled when a worker is loaded:
  --  * scanner workers run check_anns with process_existing_ann
  --    ('try_load_ann');
  --  * the primary controller runs cleanup_anns, then check_anns with
  --    maybe_train_existing_ann ('try_train_ann').
  -- The value returned by check_anns is handed back to add_periodic.
  for _, nn_rule in pairs(settings.rules) do
    neural_common.load_scripts(nn_rule.redis)

    rspamd_config:add_on_load(function(on_load_cfg, loop, wrk)
      if wrk:is_scanner() then
        local function load_job(_, _)
          return check_anns(wrk, on_load_cfg, loop, nn_rule, process_existing_ann,
              'try_load_ann')
        end
        rspamd_config:add_periodic(loop, 0.0, load_job)
      end

      if wrk:is_primary_controller() then
        local function train_job(_, _)
          -- Drop outdated ANNs before looking for training opportunities
          cleanup_anns(nn_rule, on_load_cfg, loop)
          return check_anns(wrk, on_load_cfg, loop, nn_rule, maybe_train_existing_ann,
              'try_train_ann')
        end
        rspamd_config:add_periodic(loop, 0.0, train_job)
      end
    end)
  end