You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

neural.lua 29KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892
  1. --[[
  2. Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ]]--
  13. local fun = require "fun"
  14. local lua_redis = require "lua_redis"
  15. local lua_settings = require "lua_settings"
  16. local lua_util = require "lua_util"
  17. local meta_functions = require "lua_meta"
  18. local rspamd_kann = require "rspamd_kann"
  19. local rspamd_logger = require "rspamd_logger"
  20. local rspamd_tensor = require "rspamd_tensor"
  21. local rspamd_util = require "rspamd_util"
  22. local ucl = require "ucl"
  23. local N = 'neural'
  24. -- Used in prefix to avoid wrong ANN to be loaded
  25. local plugin_ver = '2'
  26. -- Module vars
  27. local default_options = {
  28. train = {
  29. max_trains = 1000,
  30. max_epoch = 1000,
  31. max_usages = 10,
  32. max_iterations = 25, -- Torch style
  33. mse = 0.001,
  34. autotrain = true,
  35. train_prob = 1.0,
  36. learn_threads = 1,
  37. learn_mode = 'balanced', -- Possible values: balanced, proportional
  38. learning_rate = 0.01,
  39. classes_bias = 0.0, -- balanced mode: what difference is allowed between classes (1:1 proportion means 0 bias)
  40. spam_skip_prob = 0.0, -- proportional mode: spam skip probability (0-1)
  41. ham_skip_prob = 0.0, -- proportional mode: ham skip probability
  42. store_pool_only = false, -- store tokens in cache only (disables autotrain);
  43. -- neural_vec_mpack stores vector of training data in messagepack neural_profile_digest stores profile digest
  44. },
  45. watch_interval = 60.0,
  46. lock_expire = 600,
  47. learning_spawned = false,
  48. ann_expire = 60 * 60 * 24 * 2, -- 2 days
  49. hidden_layer_mult = 1.5, -- number of neurons in the hidden layer
  50. roc_enabled = false, -- Use ROC to find the best possible thresholds for ham and spam. If spam_score_threshold or ham_score_threshold is defined, it takes precedence over ROC thresholds.
  51. roc_misclassification_cost = 0.5, -- Cost of misclassifying a spam message (must be 0..1).
  52. spam_score_threshold = nil, -- neural score threshold for spam (must be 0..1 or nil to disable)
  53. ham_score_threshold = nil, -- neural score threshold for ham (must be 0..1 or nil to disable)
  54. flat_threshold_curve = false, -- use binary classification 0/1 when threshold is reached
  55. symbol_spam = 'NEURAL_SPAM',
  56. symbol_ham = 'NEURAL_HAM',
  57. max_inputs = nil, -- when PCA is used
  58. blacklisted_symbols = {}, -- list of symbols skipped in neural processing
  59. }
  60. -- Rule structure:
  61. -- * static config fields (see `default_options`)
  62. -- * prefix - name or defined prefix
  63. -- * settings - table of settings indexed by settings id, -1 is used when no settings defined
  64. -- Rule settings element defines elements for specific settings id:
  65. -- * symbols - static symbols profile (defined by config or extracted from symcache)
  66. -- * name - name of settings id
  67. -- * digest - digest of all symbols
  68. -- * ann - dynamic ANN configuration loaded from Redis
  69. -- * train - train data for ANN (e.g. the currently trained ANN)
  70. -- Settings ANN table is loaded from Redis and represents dynamic profile for ANN
  71. -- Some elements are directly stored in Redis, ANN is, in turn loaded dynamically
  72. -- * version - version of ANN loaded from redis
  73. -- * redis_key - name of ANN key in Redis
  74. -- * symbols - symbols in THIS PARTICULAR ANN (might be different from set.symbols)
  75. -- * distance - distance between set.symbols and set.ann.symbols
  76. -- * ann - kann object
  77. local settings = {
  78. rules = {},
  79. prefix = 'rn', -- Neural network default prefix
  80. max_profiles = 3, -- Maximum number of NN profiles stored
  81. }
  82. -- Get module & Redis configuration
  83. local module_config = rspamd_config:get_all_opt(N)
  84. settings = lua_util.override_defaults(settings, module_config)
  85. local redis_params = lua_redis.parse_redis_server('neural')
  86. local redis_lua_script_vectors_len = "neural_train_size.lua"
  87. local redis_lua_script_maybe_invalidate = "neural_maybe_invalidate.lua"
  88. local redis_lua_script_maybe_lock = "neural_maybe_lock.lua"
  89. local redis_lua_script_save_unlock = "neural_save_unlock.lua"
  90. local redis_script_id = {}
  91. local function load_scripts()
  92. redis_script_id.vectors_len = lua_redis.load_redis_script_from_file(redis_lua_script_vectors_len,
  93. redis_params)
  94. redis_script_id.maybe_invalidate = lua_redis.load_redis_script_from_file(redis_lua_script_maybe_invalidate,
  95. redis_params)
  96. redis_script_id.maybe_lock = lua_redis.load_redis_script_from_file(redis_lua_script_maybe_lock,
  97. redis_params)
  98. redis_script_id.save_unlock = lua_redis.load_redis_script_from_file(redis_lua_script_save_unlock,
  99. redis_params)
  100. end
  101. local function create_ann(n, nlayers, rule)
  102. -- We ignore number of layers so far when using kann
  103. local nhidden = math.floor(n * (rule.hidden_layer_mult or 1.0) + 1.0)
  104. local t = rspamd_kann.layer.input(n)
  105. t = rspamd_kann.transform.relu(t)
  106. t = rspamd_kann.layer.dense(t, nhidden);
  107. t = rspamd_kann.layer.cost(t, 1, rspamd_kann.cost.ceb_neg)
  108. return rspamd_kann.new.kann(t)
  109. end
  110. -- Fills ANN data for a specific settings element
  111. local function fill_set_ann(set, ann_key)
  112. if not set.ann then
  113. set.ann = {
  114. symbols = set.symbols,
  115. distance = 0,
  116. digest = set.digest,
  117. redis_key = ann_key,
  118. version = 0,
  119. }
  120. end
  121. end
  122. -- This function takes all inputs, applies PCA transformation and returns the final
  123. -- PCA matrix as rspamd_tensor
  124. local function learn_pca(inputs, max_inputs)
  125. local scatter_matrix = rspamd_tensor.scatter_matrix(rspamd_tensor.fromtable(inputs))
  126. local eigenvals = scatter_matrix:eigen()
  127. -- scatter matrix is not filled with eigenvectors
  128. lua_util.debugm(N, 'eigenvalues: %s', eigenvals)
  129. local w = rspamd_tensor.new(2, max_inputs, #scatter_matrix[1])
  130. for i = 1, max_inputs do
  131. w[i] = scatter_matrix[#scatter_matrix - i + 1]
  132. end
  133. lua_util.debugm(N, 'pca matrix: %s', w)
  134. return w
  135. end
  136. -- This function computes optimal threshold using ROC for the given set of inputs.
  137. -- Returns a threshold that minimizes:
  138. -- alpha * (false_positive_rate) + beta * (false_negative_rate)
  139. -- Where alpha is cost of false positive result
  140. -- beta is cost of false negative result
  141. local function get_roc_thresholds(ann, inputs, outputs, alpha, beta)
  142. -- Sorts list x and list y based on the values in list x.
  143. local sort_relative = function(x, y)
  144. local r = {}
  145. assert(#x == #y)
  146. local n = #x
  147. local a = {}
  148. local b = {}
  149. for i = 1, n do
  150. r[i] = i
  151. end
  152. local cmp = function(p, q)
  153. return p < q
  154. end
  155. table.sort(r, function(p, q)
  156. return cmp(x[p], x[q])
  157. end)
  158. for i = 1, n do
  159. a[i] = x[r[i]]
  160. b[i] = y[r[i]]
  161. end
  162. return a, b
  163. end
  164. local function get_scores(nn, input_vectors)
  165. local scores = {}
  166. for i = 1, #inputs do
  167. local score = nn:apply1(input_vectors[i], nn.pca)[1]
  168. scores[#scores + 1] = score
  169. end
  170. return scores
  171. end
  172. local fpr = {}
  173. local fnr = {}
  174. local scores = get_scores(ann, inputs)
  175. scores, outputs = sort_relative(scores, outputs)
  176. local n_samples = #outputs
  177. local n_spam = 0
  178. local n_ham = 0
  179. local ham_count_ahead = {}
  180. local spam_count_ahead = {}
  181. local ham_count_behind = {}
  182. local spam_count_behind = {}
  183. ham_count_ahead[n_samples + 1] = 0
  184. spam_count_ahead[n_samples + 1] = 0
  185. for i = n_samples, 1, -1 do
  186. if outputs[i][1] == 0 then
  187. n_ham = n_ham + 1
  188. ham_count_ahead[i] = 1
  189. spam_count_ahead[i] = 0
  190. else
  191. n_spam = n_spam + 1
  192. ham_count_ahead[i] = 0
  193. spam_count_ahead[i] = 1
  194. end
  195. ham_count_ahead[i] = ham_count_ahead[i] + ham_count_ahead[i + 1]
  196. spam_count_ahead[i] = spam_count_ahead[i] + spam_count_ahead[i + 1]
  197. end
  198. for i = 1, n_samples do
  199. if outputs[i][1] == 0 then
  200. ham_count_behind[i] = 1
  201. spam_count_behind[i] = 0
  202. else
  203. ham_count_behind[i] = 0
  204. spam_count_behind[i] = 1
  205. end
  206. if i ~= 1 then
  207. ham_count_behind[i] = ham_count_behind[i] + ham_count_behind[i - 1]
  208. spam_count_behind[i] = spam_count_behind[i] + spam_count_behind[i - 1]
  209. end
  210. end
  211. for i = 1, n_samples do
  212. fpr[i] = 0
  213. fnr[i] = 0
  214. if (ham_count_ahead[i + 1] + ham_count_behind[i]) ~= 0 then
  215. fpr[i] = ham_count_ahead[i + 1] / (ham_count_ahead[i + 1] + ham_count_behind[i])
  216. end
  217. if (spam_count_behind[i] + spam_count_ahead[i + 1]) ~= 0 then
  218. fnr[i] = spam_count_behind[i] / (spam_count_behind[i] + spam_count_ahead[i + 1])
  219. end
  220. end
  221. local p = n_spam / (n_spam + n_ham)
  222. local cost = {}
  223. local min_cost_idx = 0
  224. local min_cost = math.huge
  225. for i = 1, n_samples do
  226. cost[i] = ((1 - p) * alpha * fpr[i]) + (p * beta * fnr[i])
  227. if min_cost >= cost[i] then
  228. min_cost = cost[i]
  229. min_cost_idx = i
  230. end
  231. end
  232. return scores[min_cost_idx]
  233. end
  234. -- This function is intended to extend lock for ANN during training
  235. -- It registers periodic that increases locked key each 30 seconds unless
  236. -- `set.learning_spawned` is set to `true`
  237. local function register_lock_extender(rule, set, ev_base, ann_key)
  238. rspamd_config:add_periodic(ev_base, 30.0,
  239. function()
  240. local function redis_lock_extend_cb(err, _)
  241. if err then
  242. rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s',
  243. ann_key, err)
  244. else
  245. rspamd_logger.infox(rspamd_config, 'extend lock for ANN %s for 30 seconds',
  246. ann_key)
  247. end
  248. end
  249. if set.learning_spawned then
  250. lua_redis.redis_make_request_taskless(ev_base,
  251. rspamd_config,
  252. rule.redis,
  253. nil,
  254. true, -- is write
  255. redis_lock_extend_cb, --callback
  256. 'HINCRBY', -- command
  257. { ann_key, 'lock', '30' }
  258. )
  259. else
  260. lua_util.debugm(N, rspamd_config, "stop lock extension as learning_spawned is false")
  261. return false -- do not plan any more updates
  262. end
  263. return true
  264. end
  265. )
  266. end
  267. local function can_push_train_vector(rule, task, learn_type, nspam, nham)
  268. local train_opts = rule.train
  269. local coin = math.random()
  270. if train_opts.train_prob and coin < 1.0 - train_opts.train_prob then
  271. rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin)
  272. return false
  273. end
  274. if train_opts.learn_mode == 'balanced' then
  275. -- Keep balanced training set based on number of spam and ham samples
  276. if learn_type == 'spam' then
  277. if nspam <= train_opts.max_trains then
  278. if nspam > nham then
  279. -- Apply sampling
  280. local skip_rate = 1.0 - nham / (nspam + 1)
  281. if coin < skip_rate - train_opts.classes_bias then
  282. rspamd_logger.infox(task,
  283. 'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored',
  284. learn_type,
  285. skip_rate - train_opts.classes_bias,
  286. nspam, nham)
  287. return false
  288. end
  289. end
  290. return true
  291. else
  292. -- Enough learns
  293. rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many spam samples: %s',
  294. learn_type,
  295. nspam)
  296. end
  297. else
  298. if nham <= train_opts.max_trains then
  299. if nham > nspam then
  300. -- Apply sampling
  301. local skip_rate = 1.0 - nspam / (nham + 1)
  302. if coin < skip_rate - train_opts.classes_bias then
  303. rspamd_logger.infox(task,
  304. 'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored',
  305. learn_type,
  306. skip_rate - train_opts.classes_bias,
  307. nspam, nham)
  308. return false
  309. end
  310. end
  311. return true
  312. else
  313. rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many ham samples: %s', learn_type,
  314. nham)
  315. end
  316. end
  317. else
  318. -- Probabilistic learn mode, we just skip learn if we already have enough samples or
  319. -- if our coin drop is less than desired probability
  320. if learn_type == 'spam' then
  321. if nspam <= train_opts.max_trains then
  322. if train_opts.spam_skip_prob then
  323. if coin <= train_opts.spam_skip_prob then
  324. rspamd_logger.infox(task, 'skip %s sample probabilistically; probability %s (%s skip chance)', learn_type,
  325. coin, train_opts.spam_skip_prob)
  326. return false
  327. end
  328. return true
  329. end
  330. else
  331. rspamd_logger.infox(task, 'skip %s sample; too many spam samples: %s (%s limit)', learn_type,
  332. nspam, train_opts.max_trains)
  333. end
  334. else
  335. if nham <= train_opts.max_trains then
  336. if train_opts.ham_skip_prob then
  337. if coin <= train_opts.ham_skip_prob then
  338. rspamd_logger.infox(task, 'skip %s sample probabilistically; probability %s (%s skip chance)', learn_type,
  339. coin, train_opts.ham_skip_prob)
  340. return false
  341. end
  342. return true
  343. end
  344. else
  345. rspamd_logger.infox(task, 'skip %s sample; too many ham samples: %s (%s limit)', learn_type,
  346. nham, train_opts.max_trains)
  347. end
  348. end
  349. end
  350. return false
  351. end
  352. -- Closure generator for unlock function
  353. local function gen_unlock_cb(rule, set, ann_key)
  354. return function(err)
  355. if err then
  356. rspamd_logger.errx(rspamd_config, 'cannot unlock ANN %s:%s at %s from redis: %s',
  357. rule.prefix, set.name, ann_key, err)
  358. else
  359. lua_util.debugm(N, rspamd_config, 'unlocked ANN %s:%s at %s',
  360. rule.prefix, set.name, ann_key)
  361. end
  362. end
  363. end
  364. -- Used to generate new ANN key for specific profile
  365. local function new_ann_key(rule, set, version)
  366. local ann_key = string.format('%s_%s_%s_%s_%s', settings.prefix,
  367. rule.prefix, set.name, set.digest:sub(1, 8), tostring(version))
  368. return ann_key
  369. end
  370. local function redis_ann_prefix(rule, settings_name)
  371. -- We also need to count metatokens:
  372. local n = meta_functions.version
  373. return string.format('%s%d_%s_%d_%s',
  374. settings.prefix, plugin_ver, rule.prefix, n, settings_name)
  375. end
  376. -- This function receives training vectors, checks them, spawn learning and saves ANN in Redis
  377. local function spawn_train(params)
  378. -- Check training data sanity
  379. -- Now we need to join inputs and create the appropriate test vectors
  380. local n = #params.set.symbols +
  381. meta_functions.rspamd_count_metatokens()
  382. -- Now we can train ann
  383. local train_ann = create_ann(params.rule.max_inputs or n, 3, params.rule)
  384. if #params.ham_vec + #params.spam_vec < params.rule.train.max_trains / 2 then
  385. -- Invalidate ANN as it is definitely invalid
  386. -- TODO: add invalidation
  387. assert(false)
  388. else
  389. local inputs, outputs = {}, {}
  390. -- Used to show parsed vectors in a convenient format (for debugging only)
  391. local function debug_vec(t)
  392. local ret = {}
  393. for i, v in ipairs(t) do
  394. if v ~= 0 then
  395. ret[#ret + 1] = string.format('%d=%.2f', i, v)
  396. end
  397. end
  398. return ret
  399. end
  400. -- Make training set by joining vectors
  401. -- KANN automatically shuffles those samples
  402. -- 1.0 is used for spam and -1.0 is used for ham
  403. -- It implies that output layer can express that (e.g. tanh output)
  404. for _, e in ipairs(params.spam_vec) do
  405. inputs[#inputs + 1] = e
  406. outputs[#outputs + 1] = { 1.0 }
  407. --rspamd_logger.debugm(N, rspamd_config, 'spam vector: %s', debug_vec(e))
  408. end
  409. for _, e in ipairs(params.ham_vec) do
  410. inputs[#inputs + 1] = e
  411. outputs[#outputs + 1] = { -1.0 }
  412. --rspamd_logger.debugm(N, rspamd_config, 'ham vector: %s', debug_vec(e))
  413. end
  414. -- Called in child process
  415. local function train()
  416. local log_thresh = params.rule.train.max_iterations / 10
  417. local seen_nan = false
  418. local function train_cb(iter, train_cost, value_cost)
  419. if (iter * (params.rule.train.max_iterations / log_thresh)) % (params.rule.train.max_iterations) == 0 then
  420. if train_cost ~= train_cost and not seen_nan then
  421. -- We have nan :( try to log lot's of stuff to dig into a problem
  422. seen_nan = true
  423. rspamd_logger.errx(rspamd_config, 'ANN %s:%s: train error: observed nan in error cost!; value cost = %s',
  424. params.rule.prefix, params.set.name,
  425. value_cost)
  426. for i, e in ipairs(inputs) do
  427. lua_util.debugm(N, rspamd_config, 'train vector %s -> %s',
  428. debug_vec(e), outputs[i][1])
  429. end
  430. end
  431. rspamd_logger.infox(rspamd_config,
  432. "ANN %s:%s: learned from %s redis key in %s iterations, error: %s, value cost: %s",
  433. params.rule.prefix, params.set.name,
  434. params.ann_key,
  435. iter,
  436. train_cost,
  437. value_cost)
  438. end
  439. end
  440. lua_util.debugm(N, rspamd_config, "subprocess to learn ANN %s:%s has been started",
  441. params.rule.prefix, params.set.name)
  442. local pca
  443. if params.rule.max_inputs then
  444. -- Train PCA in the main process, presumably it is not that long
  445. lua_util.debugm(N, rspamd_config, "start PCA train for ANN %s:%s",
  446. params.rule.prefix, params.set.name)
  447. pca = learn_pca(inputs, params.rule.max_inputs)
  448. end
  449. lua_util.debugm(N, rspamd_config, "start neural train for ANN %s:%s",
  450. params.rule.prefix, params.set.name)
  451. local ret, err = pcall(train_ann.train1, train_ann,
  452. inputs, outputs, {
  453. lr = params.rule.train.learning_rate,
  454. max_epoch = params.rule.train.max_iterations,
  455. cb = train_cb,
  456. pca = pca
  457. })
  458. if not ret then
  459. rspamd_logger.errx(rspamd_config, "cannot train ann %s:%s: %s",
  460. params.rule.prefix, params.set.name, err)
  461. return nil
  462. else
  463. lua_util.debugm(N, rspamd_config, "finished neural train for ANN %s:%s",
  464. params.rule.prefix, params.set.name)
  465. end
  466. local roc_thresholds = {}
  467. if params.rule.roc_enabled then
  468. local spam_threshold = get_roc_thresholds(train_ann,
  469. inputs,
  470. outputs,
  471. 1 - params.rule.roc_misclassification_cost,
  472. params.rule.roc_misclassification_cost)
  473. local ham_threshold = get_roc_thresholds(train_ann,
  474. inputs,
  475. outputs,
  476. params.rule.roc_misclassification_cost,
  477. 1 - params.rule.roc_misclassification_cost)
  478. roc_thresholds = { spam_threshold, ham_threshold }
  479. rspamd_logger.messagex("ROC thresholds: (spam_threshold: %s, ham_threshold: %s)",
  480. roc_thresholds[1], roc_thresholds[2])
  481. end
  482. if not seen_nan then
  483. -- Convert to strings as ucl cannot rspamd_text properly
  484. local pca_data
  485. if pca then
  486. pca_data = tostring(pca:save())
  487. end
  488. local out = {
  489. ann_data = tostring(train_ann:save()),
  490. pca_data = pca_data,
  491. roc_thresholds = roc_thresholds,
  492. }
  493. local final_data = ucl.to_format(out, 'msgpack')
  494. lua_util.debugm(N, rspamd_config, "subprocess for ANN %s:%s returned %s bytes",
  495. params.rule.prefix, params.set.name, #final_data)
  496. return final_data
  497. else
  498. return nil
  499. end
  500. end
  501. params.set.learning_spawned = true
  502. local function redis_save_cb(err)
  503. if err then
  504. rspamd_logger.errx(rspamd_config, 'cannot save ANN %s:%s to redis key %s: %s',
  505. params.rule.prefix, params.set.name, params.ann_key, err)
  506. lua_redis.redis_make_request_taskless(params.ev_base,
  507. rspamd_config,
  508. params.rule.redis,
  509. nil,
  510. false, -- is write
  511. gen_unlock_cb(params.rule, params.set, params.ann_key), --callback
  512. 'HDEL', -- command
  513. { params.ann_key, 'lock' }
  514. )
  515. else
  516. rspamd_logger.infox(rspamd_config, 'saved ANN %s:%s to redis: %s',
  517. params.rule.prefix, params.set.name, params.set.ann.redis_key)
  518. end
  519. end
  520. local function ann_trained(err, data)
  521. params.set.learning_spawned = false
  522. if err then
  523. rspamd_logger.errx(rspamd_config, 'cannot train ANN %s:%s : %s',
  524. params.rule.prefix, params.set.name, err)
  525. lua_redis.redis_make_request_taskless(params.ev_base,
  526. rspamd_config,
  527. params.rule.redis,
  528. nil,
  529. true, -- is write
  530. gen_unlock_cb(params.rule, params.set, params.ann_key), --callback
  531. 'HDEL', -- command
  532. { params.ann_key, 'lock' }
  533. )
  534. else
  535. local parser = ucl.parser()
  536. local ok, parse_err = parser:parse_text(data, 'msgpack')
  537. assert(ok, parse_err)
  538. local parsed = parser:get_object()
  539. local ann_data = rspamd_util.zstd_compress(parsed.ann_data)
  540. local pca_data = parsed.pca_data
  541. local roc_thresholds = parsed.roc_thresholds
  542. fill_set_ann(params.set, params.ann_key)
  543. if pca_data then
  544. params.set.ann.pca = rspamd_tensor.load(pca_data)
  545. pca_data = rspamd_util.zstd_compress(pca_data)
  546. end
  547. if roc_thresholds then
  548. params.set.ann.roc_thresholds = roc_thresholds
  549. end
  550. -- Deserialise ANN from the child process
  551. ann_trained = rspamd_kann.load(parsed.ann_data)
  552. local version = (params.set.ann.version or 0) + 1
  553. params.set.ann.version = version
  554. params.set.ann.ann = ann_trained
  555. params.set.ann.symbols = params.set.symbols
  556. params.set.ann.redis_key = new_ann_key(params.rule, params.set, version)
  557. local profile = {
  558. symbols = params.set.symbols,
  559. digest = params.set.digest,
  560. redis_key = params.set.ann.redis_key,
  561. version = version
  562. }
  563. local profile_serialized = ucl.to_format(profile, 'json-compact', true)
  564. local roc_thresholds_serialized = ucl.to_format(roc_thresholds, 'json-compact', true)
  565. rspamd_logger.infox(rspamd_config,
  566. 'trained ANN %s:%s, %s bytes (%s compressed); %s rows in pca (%sb compressed); redis key: %s (old key %s)',
  567. params.rule.prefix, params.set.name,
  568. #data, #ann_data,
  569. #(params.set.ann.pca or {}), #(pca_data or {}),
  570. params.set.ann.redis_key, params.ann_key)
  571. lua_redis.exec_redis_script(redis_script_id.save_unlock,
  572. { ev_base = params.ev_base, is_write = true },
  573. redis_save_cb,
  574. { profile.redis_key,
  575. redis_ann_prefix(params.rule, params.set.name),
  576. ann_data,
  577. profile_serialized,
  578. tostring(params.rule.ann_expire),
  579. tostring(os.time()),
  580. params.ann_key, -- old key to unlock...
  581. roc_thresholds_serialized,
  582. pca_data,
  583. })
  584. end
  585. end
  586. if params.rule.max_inputs then
  587. fill_set_ann(params.set, params.ann_key)
  588. end
  589. params.worker:spawn_process {
  590. func = train,
  591. on_complete = ann_trained,
  592. proctitle = string.format("ANN train for %s/%s", params.rule.prefix, params.set.name),
  593. }
  594. -- Spawn learn and register lock extension
  595. params.set.learning_spawned = true
  596. register_lock_extender(params.rule, params.set, params.ev_base, params.ann_key)
  597. return
  598. end
  599. end
  600. -- This function is used to adjust profiles and allowed setting ids for each rule
  601. -- It must be called when all settings are already registered (e.g. at post-init for config)
  602. local function process_rules_settings()
  603. local function process_settings_elt(rule, selt)
  604. local profile = rule.profile[selt.name]
  605. if profile then
  606. -- Use static user defined profile
  607. -- Ensure that we have an array...
  608. lua_util.debugm(N, rspamd_config, "use static profile for %s (%s): %s",
  609. rule.prefix, selt.name, profile)
  610. if not profile[1] then
  611. profile = lua_util.keys(profile)
  612. end
  613. selt.symbols = profile
  614. else
  615. lua_util.debugm(N, rspamd_config, "use dynamic cfg based profile for %s (%s)",
  616. rule.prefix, selt.name)
  617. end
  618. local function filter_symbols_predicate(sname)
  619. if settings.blacklisted_symbols and settings.blacklisted_symbols[sname] then
  620. return false
  621. end
  622. local fl = rspamd_config:get_symbol_flags(sname)
  623. if fl then
  624. fl = lua_util.list_to_hash(fl)
  625. return not (fl.nostat or fl.idempotent or fl.skip or fl.composite)
  626. end
  627. return false
  628. end
  629. -- Generic stuff
  630. if not profile then
  631. -- Do filtering merely if we are using a dynamic profile
  632. selt.symbols = fun.totable(fun.filter(filter_symbols_predicate, selt.symbols))
  633. end
  634. table.sort(selt.symbols)
  635. selt.digest = lua_util.table_digest(selt.symbols)
  636. selt.prefix = redis_ann_prefix(rule, selt.name)
  637. rspamd_logger.messagex(rspamd_config,
  638. 'use NN prefix for rule %s; settings id "%s"; symbols digest: "%s"',
  639. selt.prefix, selt.name, selt.digest)
  640. lua_redis.register_prefix(selt.prefix, N,
  641. string.format('NN prefix for rule "%s"; settings id "%s"',
  642. selt.prefix, selt.name), {
  643. persistent = true,
  644. type = 'zlist',
  645. })
  646. -- Versions
  647. lua_redis.register_prefix(selt.prefix .. '_\\d+', N,
  648. string.format('NN storage for rule "%s"; settings id "%s"',
  649. selt.prefix, selt.name), {
  650. persistent = true,
  651. type = 'hash',
  652. })
  653. lua_redis.register_prefix(selt.prefix .. '_\\d+_spam_set', N,
  654. string.format('NN learning set (spam) for rule "%s"; settings id "%s"',
  655. selt.prefix, selt.name), {
  656. persistent = true,
  657. type = 'set',
  658. })
  659. lua_redis.register_prefix(selt.prefix .. '_\\d+_ham_set', N,
  660. string.format('NN learning set (spam) for rule "%s"; settings id "%s"',
  661. rule.prefix, selt.name), {
  662. persistent = true,
  663. type = 'set',
  664. })
  665. end
  666. for k, rule in pairs(settings.rules) do
  667. if not rule.allowed_settings then
  668. rule.allowed_settings = {}
  669. elseif rule.allowed_settings == 'all' then
  670. -- Extract all settings ids
  671. rule.allowed_settings = lua_util.keys(lua_settings.all_settings())
  672. end
  673. -- Convert to a map <setting_id> -> true
  674. rule.allowed_settings = lua_util.list_to_hash(rule.allowed_settings)
  675. -- Check if we can work without settings
  676. if k == 'default' or type(rule.default) ~= 'boolean' then
  677. rule.default = true
  678. end
  679. rule.settings = {}
  680. if rule.default then
  681. local default_settings = {
  682. symbols = lua_settings.default_symbols(),
  683. name = 'default'
  684. }
  685. process_settings_elt(rule, default_settings)
  686. rule.settings[-1] = default_settings -- Magic constant, but OK as settings are positive int32
  687. end
  688. -- Now, for each allowed settings, we store sorted symbols + digest
  689. -- We set table rule.settings[id] -> { name = name, symbols = symbols, digest = digest }
  690. for s, _ in pairs(rule.allowed_settings) do
  691. -- Here, we have a name, set of symbols and
  692. local settings_id = s
  693. if type(settings_id) ~= 'number' then
  694. settings_id = lua_settings.numeric_settings_id(s)
  695. end
  696. local selt = lua_settings.settings_by_id(settings_id)
  697. local nelt = {
  698. symbols = selt.symbols, -- Already sorted
  699. name = selt.name
  700. }
  701. process_settings_elt(rule, nelt)
  702. for id, ex in pairs(rule.settings) do
  703. if type(ex) == 'table' then
  704. if nelt and lua_util.distance_sorted(ex.symbols, nelt.symbols) == 0 then
  705. -- Equal symbols, add reference
  706. lua_util.debugm(N, rspamd_config,
  707. 'added reference from settings id %s to %s; same symbols',
  708. nelt.name, ex.name)
  709. rule.settings[settings_id] = id
  710. nelt = nil
  711. end
  712. end
  713. end
  714. if nelt then
  715. rule.settings[settings_id] = nelt
  716. lua_util.debugm(N, rspamd_config, 'added new settings id %s(%s) to %s',
  717. nelt.name, settings_id, rule.prefix)
  718. end
  719. end
  720. end
  721. end
  722. -- Extract settings element for a specific settings id
  723. local function get_rule_settings(task, rule)
  724. local sid = task:get_settings_id() or -1
  725. local set = rule.settings[sid]
  726. if not set then
  727. return nil
  728. end
  729. while type(set) == 'number' do
  730. -- Reference to another settings!
  731. set = rule.settings[set]
  732. end
  733. return set
  734. end
  735. local function result_to_vector(task, profile)
  736. if not profile.zeros then
  737. -- Fill zeros vector
  738. local zeros = {}
  739. for i = 1, meta_functions.count_metatokens() do
  740. zeros[i] = 0.0
  741. end
  742. for _, _ in ipairs(profile.symbols) do
  743. zeros[#zeros + 1] = 0.0
  744. end
  745. profile.zeros = zeros
  746. end
  747. local vec = lua_util.shallowcopy(profile.zeros)
  748. local mt = meta_functions.rspamd_gen_metatokens(task)
  749. for i, v in ipairs(mt) do
  750. vec[i] = v
  751. end
  752. task:process_ann_tokens(profile.symbols, vec, #mt, 0.1)
  753. return vec
  754. end
  755. return {
  756. can_push_train_vector = can_push_train_vector,
  757. create_ann = create_ann,
  758. default_options = default_options,
  759. gen_unlock_cb = gen_unlock_cb,
  760. get_rule_settings = get_rule_settings,
  761. load_scripts = load_scripts,
  762. module_config = module_config,
  763. new_ann_key = new_ann_key,
  764. plugin_ver = plugin_ver,
  765. process_rules_settings = process_rules_settings,
  766. redis_ann_prefix = redis_ann_prefix,
  767. redis_params = redis_params,
  768. redis_script_id = redis_script_id,
  769. result_to_vector = result_to_vector,
  770. settings = settings,
  771. spawn_train = spawn_train,
  772. }