You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

neural.lua 32KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975
  1. --[[
  2. Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ]]--
  13. local fun = require "fun"
  14. local lua_redis = require "lua_redis"
  15. local lua_settings = require "lua_settings"
  16. local lua_util = require "lua_util"
  17. local meta_functions = require "lua_meta"
  18. local rspamd_kann = require "rspamd_kann"
  19. local rspamd_logger = require "rspamd_logger"
  20. local rspamd_tensor = require "rspamd_tensor"
  21. local rspamd_util = require "rspamd_util"
  22. local ucl = require "ucl"
  -- Module name: used as the logging/debug tag and in Redis prefix registration
  local N = 'neural'
  -- Plugin version; it is embedded into Redis key prefixes so that an ANN
  -- stored by an incompatible plugin version is never loaded
  local plugin_ver = '2'
  -- Module defaults
  local default_options = {
    train = {
      -- Maximum number of stored training vectors per class (spam/ham);
      -- training also requires at least max_trains / 2 vectors in total
      max_trains = 1000,
      max_epoch = 1000,
      max_usages = 10,
      -- Number of training epochs handed to KANN (Torch style)
      max_iterations = 25,
      mse = 0.001,
      autotrain = true,
      -- Probability to consider a training sample at all (1.0 = always)
      train_prob = 1.0,
      learn_threads = 1,
      -- Sampling mode: 'balanced' or 'proportional'
      learn_mode = 'balanced',
      learning_rate = 0.01,
      -- balanced mode: allowed difference between classes (0 means a 1:1 proportion)
      classes_bias = 0.0,
      -- proportional mode: probability (0-1) to skip a spam sample
      spam_skip_prob = 0.0,
      -- proportional mode: probability (0-1) to skip a ham sample
      ham_skip_prob = 0.0,
      -- Store training tokens in cache only (disables autotrain); in this mode
      -- `neural_vec_mpack` holds the training vector as messagepack and
      -- `neural_profile_digest` holds the profile digest
      store_pool_only = false,
    },
    watch_interval = 60.0,
    lock_expire = 600,
    learning_spawned = false,
    -- Lifetime of a stored ANN key: 2 days, in seconds
    ann_expire = 2 * 24 * 60 * 60,
    -- Hidden layer size is floor(number of inputs * hidden_layer_mult + 1)
    hidden_layer_mult = 1.5,
    -- Use ROC to find the best possible thresholds for ham and spam;
    -- spam_score_threshold / ham_score_threshold take precedence when defined
    roc_enabled = false,
    -- Cost of misclassifying a spam message (must be 0..1)
    roc_misclassification_cost = 0.5,
    -- Neural score thresholds for spam / ham (0..1, or nil to disable)
    spam_score_threshold = nil,
    ham_score_threshold = nil,
    -- Use binary 0/1 classification once the threshold is reached
    flat_threshold_curve = false,
    symbol_spam = 'NEURAL_SPAM',
    symbol_ham = 'NEURAL_HAM',
    -- Number of PCA components; when set, inputs are reduced to this size with PCA
    max_inputs = nil,
    -- Symbols skipped in neural processing
    blacklisted_symbols = {},
  }
  -- Layout of a rule:
  -- * static config fields (see `default_options`)
  -- * prefix - rule name or the explicitly configured prefix
  -- * settings - settings elements indexed by settings id; -1 is used when no
  --   settings are defined
  -- A settings element describes one settings id:
  -- * symbols - static symbols profile (defined by config or extracted from symcache)
  -- * name - name of the settings id
  -- * digest - digest of all symbols
  -- * ann - dynamic ANN configuration loaded from Redis
  -- * train - training data for the ANN (e.g. the ANN currently being trained)
  -- The element's `ann` table is the dynamic profile loaded from Redis; some of
  -- its fields are stored in Redis directly, the network itself is loaded dynamically:
  -- * version - version of the ANN loaded from Redis
  -- * redis_key - name of the ANN key in Redis
  -- * symbols - symbols of THIS PARTICULAR ANN (might differ from set.symbols)
  -- * distance - distance between set.symbols and set.ann.symbols
  -- * ann - kann object
  local settings = {
    rules = {},
    prefix = 'rn', -- default prefix of neural network keys
    max_profiles = 3, -- maximum number of stored NN profiles
  }

  -- Merge the module configuration over the defaults above, then parse Redis servers
  local module_config = rspamd_config:get_all_opt(N)
  settings = lua_util.override_defaults(settings, module_config)
  local redis_params = lua_redis.parse_redis_server('neural')
  -- Lua script that checks if we can store a new training vector
  -- Uses the following keys:
  -- key1 - ann key (hash that may hold `lock` and `hostname` fields; the
  --        training sets live in `<key1>_spam_set` and `<key1>_ham_set`)
  -- Returns {nspam, nham} (cardinalities of the two training sets), or, when the
  -- ANN is locked, the string '<hostname>:<lock value>' ('unknown' is used if no
  -- hostname is stored)
  local redis_lua_script_vectors_len = [[
  local prefix = KEYS[1]
  local locked = redis.call('HGET', prefix, 'lock')
  if locked then
  local host = redis.call('HGET', prefix, 'hostname') or 'unknown'
  return string.format('%s:%s', host, locked)
  end
  local nspam = 0
  local nham = 0
  local ret = redis.call('SCARD', prefix .. '_spam_set')
  if ret then nspam = tonumber(ret) end
  ret = redis.call('SCARD', prefix .. '_ham_set')
  if ret then nham = tonumber(ret) end
  return {nspam,nham}
  ]]
  -- Lua script to invalidate old ANN profiles by rank
  -- Uses the following keys
  -- key1 - sorted set of serialized (JSON) ANN profiles of one settings prefix
  -- key2 - number of elements (highest-scored profiles) to leave
  -- For every profile beyond the limit (lowest scores first) it deletes the ANN
  -- hash named by the profile's `redis_key` together with its `_spam_set` and
  -- `_ham_set` training sets, then removes those profiles from key1.
  -- Returns the removed profile strings, or an empty table if nothing was removed.
  local redis_lua_script_maybe_invalidate = [[
  local card = redis.call('ZCARD', KEYS[1])
  local lim = tonumber(KEYS[2])
  if card > lim then
  local to_delete = redis.call('ZRANGE', KEYS[1], 0, card - lim - 1)
  if to_delete then
  for _,k in ipairs(to_delete) do
  local tb = cjson.decode(k)
  if type(tb) == 'table' and type(tb.redis_key) == 'string' then
  redis.call('DEL', tb.redis_key)
  -- Also train vectors
  redis.call('DEL', tb.redis_key .. '_spam_set')
  redis.call('DEL', tb.redis_key .. '_ham_set')
  end
  end
  end
  redis.call('ZREMRANGEBYRANK', KEYS[1], 0, card - lim - 1)
  return to_delete
  else
  return {}
  end
  ]]
  -- Lua script that tries to acquire the training lock of an ANN
  -- Uses the following keys
  -- key1 - ANN key (hash holding the `lock` and `hostname` fields)
  -- key2 - current time
  -- key3 - lock expiration (same unit as key2)
  -- key4 - hostname of the caller
  -- Returns 1 when the lock was acquired (`lock` := key2, `hostname` := key4),
  -- otherwise {lock value, holder hostname or 'unknown'}.
  -- A lock counts as held while its timestamp is less than `expire` away from
  -- `now` in either direction: the lock extender bumps the stored value with
  -- HINCRBY and hosts' clocks may differ, so a timestamp equal to or ahead of
  -- `now` belongs to a live lock and must not be taken over.
  -- A lock value that is not a number is treated as stale.
  local redis_lua_script_maybe_lock = [[
    local locked = redis.call('HGET', KEYS[1], 'lock')
    local now = tonumber(KEYS[2])
    if locked then
      locked = tonumber(locked)
      local expire = tonumber(KEYS[3])
      if locked and math.abs(now - locked) < expire then
        return {tostring(locked), redis.call('HGET', KEYS[1], 'hostname') or 'unknown'}
      end
    end
    redis.call('HSET', KEYS[1], 'lock', tostring(now))
    redis.call('HSET', KEYS[1], 'hostname', KEYS[4])
    return 1
  ]]
  -- Lua script to save a trained ANN and unlock it in redis
  -- Uses the following keys
  -- key1 - Redis key of the new ANN (hash)
  -- key2 - profiles sorted set of the settings prefix (profile is added with
  --        the current time as score)
  -- key3 - compressed ANN (stored in key1 field `ann`)
  -- key4 - profile as JSON
  -- key5 - expire in seconds (applied to key1)
  -- key6 - current time
  -- key7 - old key (its `lock` field is removed)
  -- key8 - ROC thresholds (stored in key1 field `roc_thresholds`)
  -- key9 - optional compressed PCA (stored in key1 field `pca` when present)
  -- NOTE(review): the `_spam_set` / `_ham_set` training sets deleted here are
  -- the ones of key1 (the new key), not of key7 — confirm this is intended.
  local redis_lua_script_save_unlock = [[
  local now = tonumber(KEYS[6])
  redis.call('ZADD', KEYS[2], now, KEYS[4])
  redis.call('HSET', KEYS[1], 'ann', KEYS[3])
  redis.call('DEL', KEYS[1] .. '_spam_set')
  redis.call('DEL', KEYS[1] .. '_ham_set')
  redis.call('HDEL', KEYS[1], 'lock')
  redis.call('HDEL', KEYS[7], 'lock')
  redis.call('EXPIRE', KEYS[1], tonumber(KEYS[5]))
  redis.call('HSET', KEYS[1], 'roc_thresholds', KEYS[8])
  if KEYS[9] then
  redis.call('HSET', KEYS[1], 'pca', KEYS[9])
  end
  return 1
  ]]
  -- Identifiers returned by lua_redis.add_redis_script, keyed by script name
  local redis_script_id = {}
  -- Registers every Redis script of this module (in a fixed order) with lua_redis
  local function load_scripts()
    local registry = {
      { 'vectors_len', redis_lua_script_vectors_len },
      { 'maybe_invalidate', redis_lua_script_maybe_invalidate },
      { 'maybe_lock', redis_lua_script_maybe_lock },
      { 'save_unlock', redis_lua_script_save_unlock },
    }
    for _, entry in ipairs(registry) do
      local script_name, script_body = entry[1], entry[2]
      redis_script_id[script_name] = lua_redis.add_redis_script(script_body, redis_params)
    end
  end
  -- Builds the KANN network used by the module:
  -- `n` inputs -> ReLU -> dense layer of floor(n * hidden_layer_mult + 1) neurons
  -- -> cost layer with a single output using `ceb_neg` (the training code feeds
  -- 1.0 / -1.0 labels). `nlayers` is accepted but currently ignored.
  local function create_ann(n, nlayers, rule)
    local mult = rule.hidden_layer_mult or 1.0
    local hidden_size = math.floor(n * mult + 1.0)
    local layer = rspamd_kann.layer
    local net = rspamd_kann.transform.relu(layer.input(n))
    net = layer.dense(net, hidden_size)
    net = layer.cost(net, 1, rspamd_kann.cost.ceb_neg)
    return rspamd_kann.new.kann(net)
  end
  -- Lazily attaches an ANN descriptor to a settings element.
  -- If `set.ann` is missing it is created from the element's own symbols and
  -- digest, with version 0, distance 0 and `ann_key` as its Redis key.
  -- An already present `set.ann` is left untouched.
  local function fill_set_ann(set, ann_key)
    if set.ann then
      return
    end
    set.ann = {
      symbols = set.symbols,
      distance = 0,
      digest = set.digest,
      redis_key = ann_key,
      version = 0,
    }
  end
  -- Learns a PCA projection from the training `inputs` (a table of input vectors)
  -- and returns it as an rspamd_tensor matrix with `max_inputs` rows.
  local function learn_pca(inputs, max_inputs)
    local scatter = rspamd_tensor.scatter_matrix(rspamd_tensor.fromtable(inputs))
    -- eigen() returns the eigenvalues; presumably it also overwrites `scatter`
    -- in place with eigenvectors (one per row, ascending eigenvalue order), which
    -- is what the row selection below relies on — TODO confirm in rspamd_tensor docs
    local eigenvals = scatter:eigen()
    lua_util.debugm(N, 'eigenvalues: %s', eigenvals)
    local nrows = #scatter
    local projection = rspamd_tensor.new(2, max_inputs, #scatter[1])
    -- Copy the last `max_inputs` rows, last row first
    for row = 1, max_inputs do
      projection[row] = scatter[nrows - row + 1]
    end
    lua_util.debugm(N, 'pca matrix: %s', projection)
    return projection
  end
  -- Computes the score threshold that minimises the expected misclassification
  -- cost on the given samples:
  --   (1 - p) * alpha * FPR + p * beta * FNR
  -- where p is the share of spam samples, alpha is the cost of a false positive
  -- (ham scored above the threshold) and beta the cost of a false negative (spam
  -- scored at or below it). Candidate thresholds are the samples' own scores.
  -- ann          - network with an `apply1(vector, pca)` method returning {score}
  -- inputs       - list of input vectors
  -- outputs      - list of {label}; a label <= 0 means ham (spawn_train uses -1.0
  --                for ham and 1.0 for spam; 0/1 labels work as well)
  -- alpha, beta  - costs of false positives / false negatives
  -- pca          - optional PCA matrix applied before scoring; must be the one
  --                used for training (falls back to `ann.pca`)
  -- Returns the chosen threshold, or nil when there are no samples.
  local function get_roc_thresholds(ann, inputs, outputs, alpha, beta, pca)
    -- Sorts list x and list y based on the values in list x (returns new lists)
    local function sort_relative(x, y)
      assert(#x == #y)
      local n = #x
      local order = {}
      for i = 1, n do
        order[i] = i
      end
      table.sort(order, function(p, q) return x[p] < x[q] end)
      local sorted_x, sorted_y = {}, {}
      for i = 1, n do
        sorted_x[i] = x[order[i]]
        sorted_y[i] = y[order[i]]
      end
      return sorted_x, sorted_y
    end

    local function get_scores(nn, input_vectors)
      local nn_pca = pca or nn.pca
      local scores = {}
      for i = 1, #input_vectors do
        scores[i] = nn:apply1(input_vectors[i], nn_pca)[1]
      end
      return scores
    end

    local scores, labels = sort_relative(get_scores(ann, inputs), outputs)
    local n_samples = #labels
    if n_samples == 0 then
      return nil
    end

    -- ham_ahead[i] / spam_ahead[i]: number of ham / spam samples at sorted
    -- positions i..n_samples
    local ham_ahead, spam_ahead = {}, {}
    ham_ahead[n_samples + 1] = 0
    spam_ahead[n_samples + 1] = 0
    for i = n_samples, 1, -1 do
      if labels[i][1] <= 0 then
        ham_ahead[i] = ham_ahead[i + 1] + 1
        spam_ahead[i] = spam_ahead[i + 1]
      else
        ham_ahead[i] = ham_ahead[i + 1]
        spam_ahead[i] = spam_ahead[i + 1] + 1
      end
    end
    local n_ham, n_spam = ham_ahead[1], spam_ahead[1]
    local p = n_spam / n_samples

    -- With threshold scores[i], samples after position i count as spam:
    -- FPR = ham after i / all ham, FNR = spam at positions 1..i / all spam
    local min_cost_idx, min_cost = 0, math.huge
    for i = 1, n_samples do
      local fpr, fnr = 0, 0
      if n_ham ~= 0 then
        fpr = ham_ahead[i + 1] / n_ham
      end
      if n_spam ~= 0 then
        fnr = (n_spam - spam_ahead[i + 1]) / n_spam
      end
      local cost = ((1 - p) * alpha * fpr) + (p * beta * fnr)
      -- `>=` keeps the last (highest) threshold among equal-cost candidates
      if min_cost >= cost then
        min_cost = cost
        min_cost_idx = i
      end
    end

    return scores[min_cost_idx]
  end
  -- Keeps the training lock on `ann_key` alive while a training is running.
  -- Registers a periodic (every 30 seconds) that adds 30 to the `lock` field of
  -- the ANN hash with HINCRBY as long as `set.learning_spawned` is true; once it
  -- becomes false the periodic returns false and is not scheduled again.
  local function register_lock_extender(rule, set, ev_base, ann_key)
    local function on_extended(err, _)
      if not err then
        rspamd_logger.infox(rspamd_config, 'extend lock for ANN %s for 30 seconds',
            ann_key)
        return
      end
      rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s',
          ann_key, err)
    end

    local function extend_lock()
      if not set.learning_spawned then
        lua_util.debugm(N, rspamd_config, "stop lock extension as learning_spawned is false")
        return false -- do not plan any more updates
      end
      lua_redis.redis_make_request_taskless(ev_base,
          rspamd_config,
          rule.redis,
          nil,
          true, -- is write
          on_extended, -- callback
          'HINCRBY', -- command
          {ann_key, 'lock', '30'}
      )
      return true
    end

    rspamd_config:add_periodic(ev_base, 30.0, extend_lock)
  end
  -- Decides whether a training vector of class `learn_type` ('spam'; any other
  -- value is treated as ham) may be stored, given `nspam` / `nham` vectors that
  -- are already stored. Returns true to store the vector, false to skip it.
  -- 1. Global sampling: skip with probability 1 - train.train_prob.
  -- 2. 'balanced' mode: refuse once the class exceeds train.max_trains; if the
  --    class outnumbers the other one, skip with probability
  --    1 - n_other / (n_same + 1) - train.classes_bias.
  -- 3. Otherwise (proportional mode): refuse once the class exceeds
  --    train.max_trains, then skip with train.spam_skip_prob / train.ham_skip_prob
  --    (a missing value means "never skip").
  -- Every random decision uses its own draw: reusing one number correlates the
  -- stages (a draw that passed a train_prob < 1 check is biased upwards, so the
  -- later skip checks would rarely or never fire).
  local function can_push_train_vector(rule, task, learn_type, nspam, nham)
    local train_opts = rule.train
    local coin = math.random()
    if train_opts.train_prob and coin < 1.0 - train_opts.train_prob then
      rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin)
      return false
    end

    local is_spam = learn_type == 'spam'
    -- Stored vectors of the class being learned and of the opposite class
    local n_same, n_other = nham, nspam
    if is_spam then
      n_same, n_other = nspam, nham
    end

    if train_opts.learn_mode == 'balanced' then
      -- Keep balanced training set based on number of spam and ham samples
      if n_same > train_opts.max_trains then
        local msg = is_spam
            and 'skip %s sample to keep spam/ham balance; too many spam samples: %s'
            or 'skip %s sample to keep spam/ham balance; too many ham samples: %s'
        rspamd_logger.infox(task, msg, learn_type, n_same)
        return false
      end
      if n_same > n_other then
        -- Undersample the majority class
        local skip_prob = 1.0 - n_other / (n_same + 1) - (train_opts.classes_bias or 0.0)
        if math.random() < skip_prob then
          rspamd_logger.infox(task,
              'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored',
              learn_type, skip_prob, nspam, nham)
          return false
        end
      end
      return true
    end

    -- Proportional mode: skip if we already have enough samples or the coin says so
    if n_same > train_opts.max_trains then
      local msg = is_spam
          and 'skip %s sample; too many spam samples: %s (%s limit)'
          or 'skip %s sample; too many ham samples: %s (%s limit)'
      rspamd_logger.infox(task, msg, learn_type, n_same, train_opts.max_trains)
      return false
    end
    local skip_prob
    if is_spam then
      skip_prob = train_opts.spam_skip_prob
    else
      skip_prob = train_opts.ham_skip_prob
    end
    local skip_coin = math.random()
    if skip_prob and skip_coin < skip_prob then
      rspamd_logger.infox(task, 'skip %s sample probabilistically; probability %s (%s skip chance)',
          learn_type, skip_coin, skip_prob)
      return false
    end
    return true
  end
  -- Builds the Redis callback used after the HDEL that releases an ANN lock.
  -- The callback only logs: an error when the request failed, a debug message otherwise.
  local function gen_unlock_cb(rule, set, ann_key)
    return function(err)
      if not err then
        lua_util.debugm(N, rspamd_config, 'unlocked ANN %s:%s at %s',
            rule.prefix, set.name, ann_key)
        return
      end
      rspamd_logger.errx(rspamd_config, 'cannot unlock ANN %s:%s at %s from redis: %s',
          rule.prefix, set.name, ann_key, err)
    end
  end
  -- Builds the Redis key of a given ANN version for a settings element:
  -- <module prefix>_<rule prefix>_<settings name>_<first 8 chars of digest>_<version>
  local function new_ann_key(rule, set, version)
    local short_digest = set.digest:sub(1, 8)
    return string.format('%s_%s_%s_%s_%s',
        settings.prefix, rule.prefix, set.name, short_digest, tostring(version))
  end
  -- Builds the Redis prefix of a rule for one settings name:
  -- <module prefix><plugin version>_<rule prefix>_<metatokens version>_<settings name>
  -- Both versions are part of the prefix so an ANN built for different inputs
  -- (another plugin or metatokens version) is not picked up.
  local function redis_ann_prefix(rule, settings_name)
    local meta_version = meta_functions.version -- metatokens are part of the inputs
    return string.format('%s%d_%s_%d_%s',
        settings.prefix, plugin_ver, rule.prefix, meta_version, settings_name)
  end
  -- Trains a new ANN for `params.set` in a child process and saves it in Redis.
  -- `params` fields used here:
  -- * rule, set - rule and settings element the ANN belongs to
  -- * ann_key - Redis key whose `lock` field guards this training; it is released
  --   on every failure path and passed to the save script as the old key to unlock
  -- * spam_vec, ham_vec - training vectors
  -- * ev_base, worker - event base for Redis requests and worker used to spawn the child
  -- On success the zstd-compressed ANN, the optional PCA matrix and the ROC
  -- thresholds are stored under a new versioned key (see new_ann_key) and
  -- `params.set.ann` is updated in place.
  local function spawn_train(params)
    -- Releases the training lock held on `params.ann_key`. HDEL modifies data,
    -- so the request must be sent as a write.
    local function unlock()
      lua_redis.redis_make_request_taskless(params.ev_base,
          rspamd_config,
          params.rule.redis,
          nil,
          true, -- is write
          gen_unlock_cb(params.rule, params.set, params.ann_key), --callback
          'HDEL', -- command
          {params.ann_key, 'lock'}
      )
    end

    -- Check training data sanity
    local min_vectors = params.rule.train.max_trains / 2
    if #params.ham_vec + #params.spam_vec < min_vectors then
      -- An ANN trained on so few vectors would be invalid
      -- TODO: add invalidation
      rspamd_logger.errx(rspamd_config,
          'cannot train ANN %s:%s: too few training vectors: %s spam and %s ham (%s required)',
          params.rule.prefix, params.set.name,
          #params.spam_vec, #params.ham_vec, min_vectors)
      unlock()
      return
    end

    -- Number of ANN inputs: profile symbols plus metatokens (max_inputs with PCA)
    local n = #params.set.symbols +
        meta_functions.rspamd_count_metatokens()
    local train_ann = create_ann(params.rule.max_inputs or n, 3, params.rule)
    local inputs, outputs = {}, {}

    -- Used to show parsed vectors in a convenient format (for debugging only)
    local function debug_vec(t)
      local ret = {}
      for i, v in ipairs(t) do
        if v ~= 0 then
          ret[#ret + 1] = string.format('%d=%.2f', i, v)
        end
      end
      return ret
    end

    -- Make training set by joining vectors
    -- KANN automatically shuffles those samples
    -- 1.0 is used for spam and -1.0 is used for ham
    -- It implies that output layer can express that (e.g. tanh output)
    for _, e in ipairs(params.spam_vec) do
      inputs[#inputs + 1] = e
      outputs[#outputs + 1] = {1.0}
    end
    for _, e in ipairs(params.ham_vec) do
      inputs[#inputs + 1] = e
      outputs[#outputs + 1] = {-1.0}
    end

    -- Called in the child process; returns the msgpack encoded
    -- {ann_data, pca_data, roc_thresholds} or nil on failure
    local function train()
      local log_thresh = params.rule.train.max_iterations / 10
      local seen_nan = false
      local function train_cb(iter, train_cost, value_cost)
        if (iter * (params.rule.train.max_iterations / log_thresh)) % (params.rule.train.max_iterations) == 0 then
          if train_cost ~= train_cost and not seen_nan then
            -- We have nan :( try to log lots of stuff to dig into a problem
            seen_nan = true
            rspamd_logger.errx(rspamd_config, 'ANN %s:%s: train error: observed nan in error cost!; value cost = %s',
                params.rule.prefix, params.set.name,
                value_cost)
            for i, e in ipairs(inputs) do
              lua_util.debugm(N, rspamd_config, 'train vector %s -> %s',
                  debug_vec(e), outputs[i][1])
            end
          end
          rspamd_logger.infox(rspamd_config,
              "ANN %s:%s: learned from %s redis key in %s iterations, error: %s, value cost: %s",
              params.rule.prefix, params.set.name,
              params.ann_key,
              iter,
              train_cost,
              value_cost)
        end
      end

      lua_util.debugm(N, rspamd_config, "subprocess to learn ANN %s:%s has been started",
          params.rule.prefix, params.set.name)
      local pca
      if params.rule.max_inputs then
        -- PCA is learned here in the child as well, before the ANN itself
        lua_util.debugm(N, rspamd_config, "start PCA train for ANN %s:%s",
            params.rule.prefix, params.set.name)
        pca = learn_pca(inputs, params.rule.max_inputs)
      end

      lua_util.debugm(N, rspamd_config, "start neural train for ANN %s:%s",
          params.rule.prefix, params.set.name)
      local ret, err = pcall(train_ann.train1, train_ann,
          inputs, outputs, {
            lr = params.rule.train.learning_rate,
            max_epoch = params.rule.train.max_iterations,
            cb = train_cb,
            pca = pca
          })
      if not ret then
        rspamd_logger.errx(rspamd_config, "cannot train ann %s:%s: %s",
            params.rule.prefix, params.set.name, err)
        return nil
      end
      lua_util.debugm(N, rspamd_config, "finished neural train for ANN %s:%s",
          params.rule.prefix, params.set.name)

      local roc_thresholds = {}
      -- Scores are meaningless once the cost became NaN, and the result is dropped anyway
      if params.rule.roc_enabled and not seen_nan then
        local cost = params.rule.roc_misclassification_cost
        -- Scoring must use the same PCA projection as training
        local spam_threshold = get_roc_thresholds(train_ann,
            inputs,
            outputs,
            1 - cost,
            cost,
            pca)
        local ham_threshold = get_roc_thresholds(train_ann,
            inputs,
            outputs,
            cost,
            1 - cost,
            pca)
        roc_thresholds = {spam_threshold, ham_threshold}
        rspamd_logger.messagex(rspamd_config, "ROC thresholds: (spam_threshold: %s, ham_threshold: %s)",
            roc_thresholds[1], roc_thresholds[2])
      end

      if seen_nan then
        return nil
      end
      -- Convert to strings as ucl cannot rspamd_text properly
      local pca_data
      if pca then
        pca_data = tostring(pca:save())
      end
      local out = {
        ann_data = tostring(train_ann:save()),
        pca_data = pca_data,
        roc_thresholds = roc_thresholds,
      }
      local final_data = ucl.to_format(out, 'msgpack')
      lua_util.debugm(N, rspamd_config, "subprocess for ANN %s:%s returned %s bytes",
          params.rule.prefix, params.set.name, #final_data)
      return final_data
    end

    local function redis_save_cb(err)
      if err then
        rspamd_logger.errx(rspamd_config, 'cannot save ANN %s:%s to redis key %s: %s',
            params.rule.prefix, params.set.name, params.ann_key, err)
        unlock()
      else
        rspamd_logger.infox(rspamd_config, 'saved ANN %s:%s to redis: %s',
            params.rule.prefix, params.set.name, params.set.ann.redis_key)
      end
    end

    -- Completion callback of the child process (runs in the parent)
    local function ann_trained(err, data)
      params.set.learning_spawned = false
      if err then
        rspamd_logger.errx(rspamd_config, 'cannot train ANN %s:%s : %s',
            params.rule.prefix, params.set.name, err)
        unlock()
        return
      end

      local parser = ucl.parser()
      local ok, parse_err = parser:parse_text(data, 'msgpack')
      if not ok then
        -- Do not leave the lock dangling until it expires
        rspamd_logger.errx(rspamd_config, 'cannot parse trained ANN %s:%s : %s',
            params.rule.prefix, params.set.name, parse_err)
        unlock()
        return
      end
      local parsed = parser:get_object()
      local ann_data = rspamd_util.zstd_compress(parsed.ann_data)
      local pca_data = parsed.pca_data
      local roc_thresholds = parsed.roc_thresholds
      fill_set_ann(params.set, params.ann_key)
      if pca_data then
        params.set.ann.pca = rspamd_tensor.load(pca_data)
        pca_data = rspamd_util.zstd_compress(pca_data)
      end
      if roc_thresholds then
        params.set.ann.roc_thresholds = roc_thresholds
      end
      -- Deserialise ANN from the child process (into a fresh local: assigning
      -- to `ann_trained` would overwrite this very callback)
      local loaded_ann = rspamd_kann.load(parsed.ann_data)
      local version = (params.set.ann.version or 0) + 1
      params.set.ann.version = version
      params.set.ann.ann = loaded_ann
      params.set.ann.symbols = params.set.symbols
      params.set.ann.redis_key = new_ann_key(params.rule, params.set, version)
      local profile = {
        symbols = params.set.symbols,
        digest = params.set.digest,
        redis_key = params.set.ann.redis_key,
        version = version
      }
      local profile_serialized = ucl.to_format(profile, 'json-compact', true)
      local roc_thresholds_serialized = ucl.to_format(roc_thresholds, 'json-compact', true)
      rspamd_logger.infox(rspamd_config,
          'trained ANN %s:%s, %s bytes (%s compressed); %s rows in pca (%sb compressed); redis key: %s (old key %s)',
          params.rule.prefix, params.set.name,
          #data, #ann_data,
          #(params.set.ann.pca or {}), #(pca_data or {}),
          params.set.ann.redis_key, params.ann_key)
      lua_redis.exec_redis_script(redis_script_id.save_unlock,
          {ev_base = params.ev_base, is_write = true},
          redis_save_cb,
          {profile.redis_key,
           redis_ann_prefix(params.rule, params.set.name),
           ann_data,
           profile_serialized,
           tostring(params.rule.ann_expire),
           tostring(os.time()),
           params.ann_key, -- old key to unlock...
           roc_thresholds_serialized,
           pca_data,
          })
    end

    if params.rule.max_inputs then
      fill_set_ann(params.set, params.ann_key)
    end

    -- Set before spawning: ann_trained resets it, and the lock extender stops
    -- as soon as it is false
    params.set.learning_spawned = true
    params.worker:spawn_process{
      func = train,
      on_complete = ann_trained,
      proctitle = string.format("ANN train for %s/%s", params.rule.prefix, params.set.name),
    }
    -- Keep the lock alive while the child is training
    register_lock_extender(params.rule, params.set, params.ev_base, params.ann_key)
  end
  682. -- This function is used to adjust profiles and allowed setting ids for each rule
  683. -- It must be called when all settings are already registered (e.g. at post-init for config)
  684. local function process_rules_settings()
  685. local function process_settings_elt(rule, selt)
  686. local profile = rule.profile[selt.name]
  687. if profile then
  688. -- Use static user defined profile
  689. -- Ensure that we have an array...
  690. lua_util.debugm(N, rspamd_config, "use static profile for %s (%s): %s",
  691. rule.prefix, selt.name, profile)
  692. if not profile[1] then profile = lua_util.keys(profile) end
  693. selt.symbols = profile
  694. else
  695. lua_util.debugm(N, rspamd_config, "use dynamic cfg based profile for %s (%s)",
  696. rule.prefix, selt.name)
  697. end
  698. local function filter_symbols_predicate(sname)
  699. if settings.blacklisted_symbols and settings.blacklisted_symbols[sname] then
  700. return false
  701. end
  702. local fl = rspamd_config:get_symbol_flags(sname)
  703. if fl then
  704. fl = lua_util.list_to_hash(fl)
  705. return not (fl.nostat or fl.idempotent or fl.skip or fl.composite)
  706. end
  707. return false
  708. end
  709. -- Generic stuff
  710. if not profile then
  711. -- Do filtering merely if we are using a dynamic profile
  712. selt.symbols = fun.totable(fun.filter(filter_symbols_predicate, selt.symbols))
  713. end
  714. table.sort(selt.symbols)
  715. selt.digest = lua_util.table_digest(selt.symbols)
  716. selt.prefix = redis_ann_prefix(rule, selt.name)
  717. rspamd_logger.messagex(rspamd_config,
  718. 'use NN prefix for rule %s; settings id "%s"; symbols digest: "%s"',
  719. selt.prefix, selt.name, selt.digest)
  720. lua_redis.register_prefix(selt.prefix, N,
  721. string.format('NN prefix for rule "%s"; settings id "%s"',
  722. selt.prefix, selt.name), {
  723. persistent = true,
  724. type = 'zlist',
  725. })
  726. -- Versions
  727. lua_redis.register_prefix(selt.prefix .. '_\\d+', N,
  728. string.format('NN storage for rule "%s"; settings id "%s"',
  729. selt.prefix, selt.name), {
  730. persistent = true,
  731. type = 'hash',
  732. })
  733. lua_redis.register_prefix(selt.prefix .. '_\\d+_spam_set', N,
  734. string.format('NN learning set (spam) for rule "%s"; settings id "%s"',
  735. selt.prefix, selt.name), {
  736. persistent = true,
  737. type = 'set',
  738. })
  739. lua_redis.register_prefix(selt.prefix .. '_\\d+_ham_set', N,
  740. string.format('NN learning set (spam) for rule "%s"; settings id "%s"',
  741. rule.prefix, selt.name), {
  742. persistent = true,
  743. type = 'set',
  744. })
  745. end
  746. for k,rule in pairs(settings.rules) do
  747. if not rule.allowed_settings then
  748. rule.allowed_settings = {}
  749. elseif rule.allowed_settings == 'all' then
  750. -- Extract all settings ids
  751. rule.allowed_settings = lua_util.keys(lua_settings.all_settings())
  752. end
  753. -- Convert to a map <setting_id> -> true
  754. rule.allowed_settings = lua_util.list_to_hash(rule.allowed_settings)
  755. -- Check if we can work without settings
  756. if k == 'default' or type(rule.default) ~= 'boolean' then
  757. rule.default = true
  758. end
  759. rule.settings = {}
  760. if rule.default then
  761. local default_settings = {
  762. symbols = lua_settings.default_symbols(),
  763. name = 'default'
  764. }
  765. process_settings_elt(rule, default_settings)
  766. rule.settings[-1] = default_settings -- Magic constant, but OK as settings are positive int32
  767. end
  768. -- Now, for each allowed settings, we store sorted symbols + digest
  769. -- We set table rule.settings[id] -> { name = name, symbols = symbols, digest = digest }
  770. for s,_ in pairs(rule.allowed_settings) do
  771. -- Here, we have a name, set of symbols and
  772. local settings_id = s
  773. if type(settings_id) ~= 'number' then
  774. settings_id = lua_settings.numeric_settings_id(s)
  775. end
  776. local selt = lua_settings.settings_by_id(settings_id)
  777. local nelt = {
  778. symbols = selt.symbols, -- Already sorted
  779. name = selt.name
  780. }
  781. process_settings_elt(rule, nelt)
  782. for id,ex in pairs(rule.settings) do
  783. if type(ex) == 'table' then
  784. if nelt and lua_util.distance_sorted(ex.symbols, nelt.symbols) == 0 then
  785. -- Equal symbols, add reference
  786. lua_util.debugm(N, rspamd_config,
  787. 'added reference from settings id %s to %s; same symbols',
  788. nelt.name, ex.name)
  789. rule.settings[settings_id] = id
  790. nelt = nil
  791. end
  792. end
  793. end
  794. if nelt then
  795. rule.settings[settings_id] = nelt
  796. lua_util.debugm(N, rspamd_config, 'added new settings id %s(%s) to %s',
  797. nelt.name, settings_id, rule.prefix)
  798. end
  799. end
  800. end
  801. end
  802. -- Extract settings element for a specific settings id
  803. local function get_rule_settings(task, rule)
  804. local sid = task:get_settings_id() or -1
  805. local set = rule.settings[sid]
  806. if not set then return nil end
  807. while type(set) == 'number' do
  808. -- Reference to another settings!
  809. set = rule.settings[set]
  810. end
  811. return set
  812. end
  813. local function result_to_vector(task, profile)
  814. if not profile.zeros then
  815. -- Fill zeros vector
  816. local zeros = {}
  817. for i=1,meta_functions.count_metatokens() do
  818. zeros[i] = 0.0
  819. end
  820. for _,_ in ipairs(profile.symbols) do
  821. zeros[#zeros + 1] = 0.0
  822. end
  823. profile.zeros = zeros
  824. end
  825. local vec = lua_util.shallowcopy(profile.zeros)
  826. local mt = meta_functions.rspamd_gen_metatokens(task)
  827. for i,v in ipairs(mt) do
  828. vec[i] = v
  829. end
  830. task:process_ann_tokens(profile.symbols, vec, #mt, 0.1)
  831. return vec
  832. end
  833. return {
  834. can_push_train_vector = can_push_train_vector,
  835. create_ann = create_ann,
  836. default_options = default_options,
  837. gen_unlock_cb = gen_unlock_cb,
  838. get_rule_settings = get_rule_settings,
  839. load_scripts = load_scripts,
  840. module_config = module_config,
  841. new_ann_key = new_ann_key,
  842. plugin_ver = plugin_ver,
  843. process_rules_settings = process_rules_settings,
  844. redis_ann_prefix = redis_ann_prefix,
  845. redis_params = redis_params,
  846. redis_script_id = redis_script_id,
  847. result_to_vector = result_to_vector,
  848. settings = settings,
  849. spawn_train = spawn_train,
  850. }