Você não pode selecionar mais de 25 tópicos Os tópicos devem começar com uma letra ou um número, podem incluir traços ('-') e podem ter até 35 caracteres.

neural.lua 29KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884
  1. --[[
  2. Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ]]--
  13. local fun = require "fun"
  14. local lua_redis = require "lua_redis"
  15. local lua_settings = require "lua_settings"
  16. local lua_util = require "lua_util"
  17. local meta_functions = require "lua_meta"
  18. local rspamd_kann = require "rspamd_kann"
  19. local rspamd_logger = require "rspamd_logger"
  20. local rspamd_tensor = require "rspamd_tensor"
  21. local rspamd_util = require "rspamd_util"
  22. local ucl = require "ucl"
  local N = 'neural'
  -- Plugin version; it is embedded into the Redis key prefix (see redis_ann_prefix) so that
  -- ANNs stored by an incompatible version of this code are not loaded
  local plugin_ver = '2'

  -- Default per-rule options (keys sorted alphabetically); `train` holds training settings
  local default_options = {
    ann_expire = 60 * 60 * 24 * 2, -- 2 days
    blacklisted_symbols = {}, -- list of symbols skipped in neural processing
    flat_threshold_curve = false, -- use binary classification 0/1 when threshold is reached
    ham_score_threshold = nil, -- neural score threshold for ham (must be 0..1 or nil to disable)
    hidden_layer_mult = 1.5, -- hidden layer size multiplier: floor(n_inputs * mult + 1) neurons
    learning_spawned = false,
    lock_expire = 600,
    max_inputs = nil, -- when PCA is used
    -- Use ROC to find the best possible thresholds for ham and spam; explicitly configured
    -- spam_score_threshold / ham_score_threshold take precedence over ROC thresholds
    roc_enabled = false,
    roc_misclassification_cost = 0.5, -- cost of misclassifying a spam message (must be 0..1)
    spam_score_threshold = nil, -- neural score threshold for spam (must be 0..1 or nil to disable)
    symbol_ham = 'NEURAL_HAM',
    symbol_spam = 'NEURAL_SPAM',
    train = {
      autotrain = true,
      classes_bias = 0.0, -- balanced mode: allowed difference between classes (1:1 proportion means 0 bias)
      ham_skip_prob = 0.0, -- proportional mode: ham skip probability
      learn_mode = 'balanced', -- possible values: balanced, proportional
      learn_threads = 1,
      learning_rate = 0.01,
      max_epoch = 1000,
      max_iterations = 25, -- Torch style
      max_trains = 1000,
      max_usages = 10,
      mse = 0.001,
      spam_skip_prob = 0.0, -- proportional mode: spam skip probability (0-1)
      -- Store tokens in cache only (disables autotrain); per the original note, the
      -- `neural_vec_mpack` cache entry holds the training vector (messagepack) and
      -- `neural_profile_digest` holds the profile digest
      store_pool_only = false,
      train_prob = 1.0,
    },
    watch_interval = 60.0,
  }
  --[[
  Rule structure:
    * static config fields (see `default_options`)
    * prefix - name or defined prefix
    * settings - table of settings indexed by settings id, -1 is used when no settings are defined
  A rule settings element describes one specific settings id:
    * symbols - static symbols profile (defined by config or extracted from symcache)
    * name - name of the settings id
    * digest - digest of all symbols
    * ann - dynamic ANN configuration loaded from Redis
    * train - train data for ANN (e.g. the currently trained ANN)
  The settings ANN table is loaded from Redis and represents the dynamic profile of an ANN.
  Some elements are stored directly in Redis, while the ANN itself is loaded dynamically:
    * version - version of ANN loaded from redis
    * redis_key - name of ANN key in Redis
    * symbols - symbols in THIS PARTICULAR ANN (might be different from set.symbols)
    * distance - distance between set.symbols and set.ann.symbols
    * ann - kann object
  ]]

  -- File names of the Redis Lua scripts used by this module (loaded by load_scripts())
  local redis_lua_script_vectors_len = "neural_train_size.lua"
  local redis_lua_script_maybe_invalidate = "neural_maybe_invalidate.lua"
  local redis_lua_script_maybe_lock = "neural_maybe_lock.lua"
  local redis_lua_script_save_unlock = "neural_save_unlock.lua"
  -- Loaded script ids keyed by short script name; filled by load_scripts()
  local redis_script_id = {}

  -- Module-wide settings, merged with the `neural` module configuration via
  -- lua_util.override_defaults
  local settings = {
    max_profiles = 3, -- Maximum number of NN profiles stored
    prefix = 'rn', -- Neural network default prefix
    rules = {},
  }
  local module_config = rspamd_config:get_all_opt(N)
  settings = lua_util.override_defaults(settings, module_config)
  -- Redis servers configuration for this module
  local redis_params = lua_redis.parse_redis_server('neural')
  -- Loads every Redis Lua script used by this module and stores the resulting ids in
  -- `redis_script_id`, keyed by short script name. Scripts are loaded in a fixed order.
  local function load_scripts()
    local scripts = {
      { 'vectors_len', redis_lua_script_vectors_len },
      { 'maybe_invalidate', redis_lua_script_maybe_invalidate },
      { 'maybe_lock', redis_lua_script_maybe_lock },
      { 'save_unlock', redis_lua_script_save_unlock },
    }
    for _, entry in ipairs(scripts) do
      local field, filename = entry[1], entry[2]
      redis_script_id[field] = lua_redis.load_redis_script_from_file(filename, redis_params)
    end
  end
  --- Builds a KANN network: input(n) -> ReLU -> dense(nhidden) -> cost layer with one
  -- output and the `ceb_neg` cost.
  -- @param n number of inputs
  -- @param nlayers ignored so far (kept for interface compatibility)
  -- @param rule rule config; `rule.hidden_layer_mult` (1.0 when unset) scales the hidden
  --   layer to floor(n * mult + 1) neurons
  -- @return kann object
  local function create_ann(n, nlayers, rule)
    local mult = rule.hidden_layer_mult or 1.0
    local hidden_neurons = math.floor(n * mult + 1.0)
    local kl = rspamd_kann.layer
    local input = kl.input(n)
    local activated = rspamd_kann.transform.relu(input)
    local hidden = kl.dense(activated, hidden_neurons)
    local output = kl.cost(hidden, 1, rspamd_kann.cost.ceb_neg)
    return rspamd_kann.new.kann(output)
  end
  -- Creates `set.ann` for a settings element unless it already exists. The new table
  -- starts at version 0 with zero distance, reuses the element's symbols and digest and
  -- records `ann_key` as its Redis key. An existing `set.ann` is left untouched.
  local function fill_set_ann(set, ann_key)
    if set.ann then
      return
    end
    set.ann = {
      digest = set.digest,
      distance = 0,
      redis_key = ann_key,
      symbols = set.symbols,
      version = 0,
    }
  end
  --- Learns a PCA projection matrix from the training inputs.
  -- Builds the scatter matrix of `inputs`, calls `eigen()` on it and copies its last
  -- `max_inputs` rows, in reverse order, into a new 2-D tensor.
  -- NOTE(review): per the original comment, eigen() leaves the eigenvectors in the
  -- scatter matrix rows; the reversed tail presumably holds the largest eigenvalues — confirm.
  -- @param inputs list of input vectors
  -- @param max_inputs number of rows (components) to keep
  -- @return rspamd_tensor of size max_inputs x (number of scatter matrix columns)
  local function learn_pca(inputs, max_inputs)
    local input_tensor = rspamd_tensor.fromtable(inputs)
    local scatter = rspamd_tensor.scatter_matrix(input_tensor)
    local eigenvalues = scatter:eigen()
    lua_util.debugm(N, 'eigenvalues: %s', eigenvalues)
    local ncols = #scatter[1]
    local nrows = #scatter
    local projection = rspamd_tensor.new(2, max_inputs, ncols)
    for row = 1, max_inputs do
      projection[row] = scatter[nrows - row + 1]
    end
    lua_util.debugm(N, 'pca matrix: %s', projection)
    return projection
  end
  --- Computes the optimal score threshold for the given samples using ROC analysis.
  -- The returned threshold minimises
  --   (1 - p) * alpha * FPR + p * beta * FNR
  -- where p is the fraction of spam samples, alpha is the cost of a false positive and
  -- beta is the cost of a false negative. For a candidate threshold, ham scored strictly
  -- above it counts as a false positive and spam scored at or below it as a false negative.
  -- @param ann network providing `ann:apply1(vector, pca)`; the first element of the
  --   returned table is used as the score
  -- @param inputs list of input vectors
  -- @param outputs list of target tables, one per input: `outputs[i][1] <= 0` marks ham
  --   (spawn_train labels ham with -1.0; 0 is accepted as well), anything else marks spam
  -- @param alpha cost of a false positive
  -- @param beta cost of a false negative
  -- @param pca optional PCA matrix forwarded to `apply1` (defaults to `ann.pca`)
  -- @return threshold score, or nil when there are no samples
  local function get_roc_thresholds(ann, inputs, outputs, alpha, beta, pca)
    local n_samples = #inputs
    assert(n_samples == #outputs)
    if n_samples == 0 then
      return nil
    end
    pca = pca or ann.pca

    -- Score every sample, then order samples by ascending score
    local scores, order = {}, {}
    for i = 1, n_samples do
      scores[i] = ann:apply1(inputs[i], pca)[1]
      order[i] = i
    end
    table.sort(order, function(a, b) return scores[a] < scores[b] end)

    local sorted_scores, is_ham = {}, {}
    local n_ham = 0
    for pos, idx in ipairs(order) do
      sorted_scores[pos] = scores[idx]
      -- Ham may be labelled -1.0 (ceb_neg targets) or 0; comparing with `== 0` alone
      -- would count every -1.0 ham sample as spam
      is_ham[pos] = outputs[idx][1] <= 0
      if is_ham[pos] then
        n_ham = n_ham + 1
      end
    end
    local n_spam = n_samples - n_ham
    local p = n_spam / n_samples

    -- Sweep candidate thresholds over the sorted scores; at position i the threshold is
    -- sorted_scores[i]
    local ham_upto, spam_upto = 0, 0
    local min_cost, min_cost_idx = math.huge, 0
    for i = 1, n_samples do
      if is_ham[i] then
        ham_upto = ham_upto + 1
      else
        spam_upto = spam_upto + 1
      end
      local fpr, fnr = 0, 0
      if n_ham ~= 0 then
        fpr = (n_ham - ham_upto) / n_ham -- ham above the threshold
      end
      if n_spam ~= 0 then
        fnr = spam_upto / n_spam -- spam at or below the threshold
      end
      local cost = ((1 - p) * alpha * fpr) + (p * beta * fnr)
      -- `>=` keeps the highest threshold among equal costs
      if min_cost >= cost then
        min_cost = cost
        min_cost_idx = i
      end
    end
    return sorted_scores[min_cost_idx]
  end
  --- Keeps the training lock of an ANN alive while learning is in progress.
  -- Registers a periodic callback (every 30 seconds) that runs `HINCRBY <ann_key> lock 30`
  -- while `set.learning_spawned` is true; once it becomes false the periodic returns
  -- false and no further updates are scheduled.
  -- @param rule rule config (uses `rule.redis`)
  -- @param set settings element being trained
  -- @param ev_base event base for the periodic and the Redis request
  -- @param ann_key Redis key of the ANN hash holding the lock
  local function register_lock_extender(rule, set, ev_base, ann_key)
    local function on_lock_extended(err, _)
      if not err then
        rspamd_logger.infox(rspamd_config, 'extend lock for ANN %s for 30 seconds',
            ann_key)
        return
      end
      rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s',
          ann_key, err)
    end

    local function extend_lock()
      if not set.learning_spawned then
        lua_util.debugm(N, rspamd_config, "stop lock extension as learning_spawned is false")
        return false -- do not plan any more updates
      end
      lua_redis.redis_make_request_taskless(ev_base,
          rspamd_config,
          rule.redis,
          nil,
          true, -- is write
          on_lock_extended, --callback
          'HINCRBY', -- command
          { ann_key, 'lock', '30' }
      )
      return true
    end

    rspamd_config:add_periodic(ev_base, 30.0, extend_lock)
  end
  --- Decides whether a training vector of class `learn_type` may be stored.
  -- A first random draw implements `train_opts.train_prob`. A second, independent draw is
  -- then used by the learn mode:
  --  * 'balanced': the class count must not exceed `max_trains`; when this class already
  --    has more samples than the other one, the sample is skipped with probability
  --    `1 - n_other / (n_this + 1) - classes_bias`;
  --  * any other mode (proportional): the class count must not exceed `max_trains` and the
  --    sample is skipped with probability `spam_skip_prob` / `ham_skip_prob` (0 when unset).
  -- @param rule rule config (reads `rule.train`)
  -- @param task task object (used for logging)
  -- @param learn_type 'spam'; any other value is treated as ham
  -- @param nspam number of stored spam vectors
  -- @param nham number of stored ham vectors
  -- @return true if the vector may be pushed, false otherwise
  local function can_push_train_vector(rule, task, learn_type, nspam, nham)
    local train_opts = rule.train
    local coin = math.random()
    if train_opts.train_prob and coin < 1.0 - train_opts.train_prob then
      rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin)
      return false
    end
    -- Draw again: only coins >= 1 - train_prob reach this point, so reusing `coin`
    -- would bias (and correlate) the sampling decisions below
    coin = math.random()
    local is_spam = learn_type == 'spam'

    if train_opts.learn_mode == 'balanced' then
      -- Keep balanced training set based on number of spam and ham samples
      local classes_bias = train_opts.classes_bias or 0.0
      local n_own, n_other, limit_msg
      if is_spam then
        n_own, n_other = nspam, nham
        limit_msg = 'skip %s sample to keep spam/ham balance; too many spam samples: %s'
      else
        n_own, n_other = nham, nspam
        limit_msg = 'skip %s sample to keep spam/ham balance; too many ham samples: %s'
      end
      if n_own > train_opts.max_trains then
        rspamd_logger.infox(task, limit_msg, learn_type, n_own)
        return false
      end
      if n_own > n_other then
        -- Apply sampling
        local skip_rate = 1.0 - n_other / (n_own + 1)
        if coin < skip_rate - classes_bias then
          rspamd_logger.infox(task,
              'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored',
              learn_type,
              skip_rate - classes_bias,
              nspam, nham)
          return false
        end
      end
      return true
    end

    -- Probabilistic (proportional) learn mode: skip if we already have enough samples or
    -- if the coin drop is below the class skip probability
    local nsamples, skip_prob, limit_msg
    if is_spam then
      nsamples = nspam
      skip_prob = train_opts.spam_skip_prob or 0.0
      limit_msg = 'skip %s sample; too many spam samples: %s (%s limit)'
    else
      nsamples = nham
      skip_prob = train_opts.ham_skip_prob or 0.0
      limit_msg = 'skip %s sample; too many ham samples: %s (%s limit)'
    end
    if nsamples > train_opts.max_trains then
      rspamd_logger.infox(task, limit_msg, learn_type, nsamples, train_opts.max_trains)
      return false
    end
    -- Strict comparison: math.random() is in [0, 1), so a probability of 0 never skips
    -- and a probability of 1 always skips
    if coin < skip_prob then
      rspamd_logger.infox(task, 'skip %s sample probabilistically; probability %s (%s skip chance)',
          learn_type, coin, skip_prob)
      return false
    end
    return true
  end
  --- Builds the Redis callback used after deleting the training lock of an ANN.
  -- On error it logs an error message; otherwise it emits a debug message.
  -- @param rule rule config (uses `rule.prefix`)
  -- @param set settings element (uses `set.name`)
  -- @param ann_key Redis key of the ANN hash
  -- @return function(err)
  local function gen_unlock_cb(rule, set, ann_key)
    local function on_unlocked(err)
      if not err then
        lua_util.debugm(N, rspamd_config, 'unlocked ANN %s:%s at %s',
            rule.prefix, set.name, ann_key)
        return
      end
      rspamd_logger.errx(rspamd_config, 'cannot unlock ANN %s:%s at %s from redis: %s',
          rule.prefix, set.name, ann_key, err)
    end
    return on_unlocked
  end
  -- Builds the Redis key for a new ANN version of a profile:
  -- <settings.prefix>_<rule.prefix>_<set.name>_<first 8 chars of set.digest>_<version>
  local function new_ann_key(rule, set, version)
    local short_digest = set.digest:sub(1, 8)
    local version_str = tostring(version)
    return string.format('%s_%s_%s_%s_%s', settings.prefix,
        rule.prefix, set.name, short_digest, version_str)
  end
  -- Builds the Redis prefix for the ANN profiles of a rule and settings name:
  -- <settings.prefix><plugin_ver>_<rule.prefix>_<metatokens version>_<settings_name>.
  -- The metatokens version is part of the prefix, presumably so that ANNs trained with a
  -- different metatokens layout are not reused.
  local function redis_ann_prefix(rule, settings_name)
    local metatokens_version = meta_functions.version
    return string.format('%s%d_%s_%d_%s', settings.prefix, plugin_ver,
        rule.prefix, metatokens_version, settings_name)
  end
  -- Releases the training lock of `params.ann_key` by deleting the `lock` field of that
  -- Redis hash; the outcome is logged by the callback built with gen_unlock_cb
  local function release_train_lock(params)
    lua_redis.redis_make_request_taskless(params.ev_base,
        rspamd_config,
        params.rule.redis,
        nil,
        true, -- is write: HDEL modifies the hash
        gen_unlock_cb(params.rule, params.set, params.ann_key), --callback
        'HDEL', -- command
        { params.ann_key, 'lock' }
    )
  end

  --- Checks training vectors, trains an ANN in a child process and saves it in Redis.
  -- Fields of `params` used here: `rule`, `set` (settings element), `ann_key` (Redis key
  -- holding the training lock), `ev_base`, `worker` (spawns the training process),
  -- `spam_vec` and `ham_vec` (lists of input vectors).
  -- When there are fewer than `rule.train.max_trains / 2` vectors in total, or when
  -- training, result parsing or saving fails, the lock at `ann_key` is released.
  local function spawn_train(params)
    local rule, set = params.rule, params.set
    -- Number of inputs: one per profile symbol plus the metatokens
    local n = #set.symbols +
        meta_functions.rspamd_count_metatokens()

    -- Check training data sanity
    if #params.ham_vec + #params.spam_vec < rule.train.max_trains / 2 then
      -- TODO: invalidate the ANN as it is definitely invalid
      rspamd_logger.errx(rspamd_config,
          'cannot train ANN %s:%s: not enough training vectors (%s ham, %s spam)',
          rule.prefix, set.name, #params.ham_vec, #params.spam_vec)
      release_train_lock(params)
      return
    end

    local train_ann = create_ann(rule.max_inputs or n, 3, rule)
    local inputs, outputs = {}, {}

    -- Used to show parsed vectors in a convenient format (for debugging only)
    local function debug_vec(t)
      local ret = {}
      for i, v in ipairs(t) do
        if v ~= 0 then
          ret[#ret + 1] = string.format('%d=%.2f', i, v)
        end
      end
      return ret
    end

    -- Make training set by joining vectors (KANN shuffles the samples itself).
    -- 1.0 is used for spam and -1.0 for ham, which the output layer must be able to
    -- express (the cost layer built by create_ann uses ceb_neg)
    for _, e in ipairs(params.spam_vec) do
      inputs[#inputs + 1] = e
      outputs[#outputs + 1] = { 1.0 }
    end
    for _, e in ipairs(params.ham_vec) do
      inputs[#inputs + 1] = e
      outputs[#outputs + 1] = { -1.0 }
    end

    -- Called in the child process; returns the msgpack-encoded result or nil on failure
    local function train()
      local log_thresh = rule.train.max_iterations / 10
      local seen_nan = false

      local function train_cb(iter, train_cost, value_cost)
        if (iter * (rule.train.max_iterations / log_thresh)) % (rule.train.max_iterations) == 0 then
          if train_cost ~= train_cost and not seen_nan then
            -- We have nan :( try to log lots of stuff to dig into the problem
            seen_nan = true
            rspamd_logger.errx(rspamd_config, 'ANN %s:%s: train error: observed nan in error cost!; value cost = %s',
                rule.prefix, set.name,
                value_cost)
            for i, e in ipairs(inputs) do
              lua_util.debugm(N, rspamd_config, 'train vector %s -> %s',
                  debug_vec(e), outputs[i][1])
            end
          end
          rspamd_logger.infox(rspamd_config,
              "ANN %s:%s: learned from %s redis key in %s iterations, error: %s, value cost: %s",
              rule.prefix, set.name,
              params.ann_key,
              iter,
              train_cost,
              value_cost)
        end
      end

      lua_util.debugm(N, rspamd_config, "subprocess to learn ANN %s:%s has been started",
          rule.prefix, set.name)

      local pca
      if rule.max_inputs then
        -- Train PCA in the same process, presumably it is not that long
        lua_util.debugm(N, rspamd_config, "start PCA train for ANN %s:%s",
            rule.prefix, set.name)
        pca = learn_pca(inputs, rule.max_inputs)
      end

      lua_util.debugm(N, rspamd_config, "start neural train for ANN %s:%s",
          rule.prefix, set.name)
      local ret, err = pcall(train_ann.train1, train_ann,
          inputs, outputs, {
            lr = rule.train.learning_rate,
            max_epoch = rule.train.max_iterations,
            cb = train_cb,
            pca = pca
          })
      if not ret then
        rspamd_logger.errx(rspamd_config, "cannot train ann %s:%s: %s",
            rule.prefix, set.name, err)
        return nil
      end
      lua_util.debugm(N, rspamd_config, "finished neural train for ANN %s:%s",
          rule.prefix, set.name)

      local roc_thresholds = {}
      if rule.roc_enabled then
        local cost = rule.roc_misclassification_cost
        -- Pass the PCA matrix explicitly: the freshly trained ANN does not carry it
        local spam_threshold = get_roc_thresholds(train_ann, inputs, outputs,
            1 - cost, cost, pca)
        local ham_threshold = get_roc_thresholds(train_ann, inputs, outputs,
            cost, 1 - cost, pca)
        roc_thresholds = { spam_threshold, ham_threshold }
        rspamd_logger.messagex("ROC thresholds: (spam_threshold: %s, ham_threshold: %s)",
            roc_thresholds[1], roc_thresholds[2])
      end

      if seen_nan then
        return nil
      end
      -- Convert to strings as ucl cannot handle rspamd_text properly
      local pca_data
      if pca then
        pca_data = tostring(pca:save())
      end
      local out = {
        ann_data = tostring(train_ann:save()),
        pca_data = pca_data,
        roc_thresholds = roc_thresholds,
      }
      local final_data = ucl.to_format(out, 'msgpack')
      lua_util.debugm(N, rspamd_config, "subprocess for ANN %s:%s returned %s bytes",
          rule.prefix, set.name, #final_data)
      return final_data
    end

    local function redis_save_cb(err)
      if err then
        rspamd_logger.errx(rspamd_config, 'cannot save ANN %s:%s to redis key %s: %s',
            rule.prefix, set.name, params.ann_key, err)
        release_train_lock(params)
      else
        rspamd_logger.infox(rspamd_config, 'saved ANN %s:%s to redis: %s',
            rule.prefix, set.name, set.ann.redis_key)
      end
    end

    -- Called in the main process when the training child finishes
    local function ann_trained(err, data)
      set.learning_spawned = false
      -- The child returns nil when training failed or a NaN cost was observed
      if err or not data then
        rspamd_logger.errx(rspamd_config, 'cannot train ANN %s:%s : %s',
            rule.prefix, set.name, err or 'training process returned no data')
        release_train_lock(params)
        return
      end

      local parser = ucl.parser()
      local ok, parse_err = parser:parse_text(data, 'msgpack')
      if not ok then
        rspamd_logger.errx(rspamd_config, 'cannot parse trained ANN %s:%s : %s',
            rule.prefix, set.name, parse_err)
        release_train_lock(params)
        return
      end
      local parsed = parser:get_object()
      local ann_data = rspamd_util.zstd_compress(parsed.ann_data)
      local pca_data = parsed.pca_data
      local roc_thresholds = parsed.roc_thresholds

      fill_set_ann(set, params.ann_key)
      if pca_data then
        set.ann.pca = rspamd_tensor.load(pca_data)
        pca_data = rspamd_util.zstd_compress(pca_data)
      end
      if roc_thresholds then
        set.ann.roc_thresholds = roc_thresholds
      end

      -- Deserialise ANN from the child process (a separate local, so this callback
      -- is not overwritten by the kann object)
      local loaded_ann = rspamd_kann.load(parsed.ann_data)
      local version = (set.ann.version or 0) + 1
      set.ann.version = version
      set.ann.ann = loaded_ann
      set.ann.symbols = set.symbols
      set.ann.redis_key = new_ann_key(rule, set, version)

      local profile = {
        symbols = set.symbols,
        digest = set.digest,
        redis_key = set.ann.redis_key,
        version = version
      }
      local profile_serialized = ucl.to_format(profile, 'json-compact', true)
      local roc_thresholds_serialized = ucl.to_format(roc_thresholds, 'json-compact', true)

      rspamd_logger.infox(rspamd_config,
          'trained ANN %s:%s, %s bytes (%s compressed); %s rows in pca (%sb compressed); redis key: %s (old key %s)',
          rule.prefix, set.name,
          #data, #ann_data,
          #(set.ann.pca or {}), #(pca_data or {}),
          set.ann.redis_key, params.ann_key)

      lua_redis.exec_redis_script(redis_script_id.save_unlock,
          { ev_base = params.ev_base, is_write = true },
          redis_save_cb,
          { profile.redis_key,
            redis_ann_prefix(rule, set.name),
            ann_data,
            profile_serialized,
            tostring(rule.ann_expire),
            tostring(os.time()),
            params.ann_key, -- old key to unlock...
            roc_thresholds_serialized,
            pca_data,
          })
    end

    if rule.max_inputs then
      fill_set_ann(set, params.ann_key)
    end

    -- Mark learning as spawned before spawning, so the lock extender keeps the lock alive;
    -- ann_trained resets it when the child finishes
    set.learning_spawned = true
    params.worker:spawn_process{
      func = train,
      on_complete = ann_trained,
      proctitle = string.format("ANN train for %s/%s", rule.prefix, set.name),
    }
    register_lock_extender(rule, set, params.ev_base, params.ann_key)
  end
  --- Adjusts profiles and allowed settings ids for each rule.
  -- Must be called when all settings are already registered (e.g. at post-init for config).
  -- For every rule it fills `rule.settings`: key -1 holds the default profile (when
  -- `rule.default` is true); other keys are numeric settings ids mapped either to a
  -- profile table { name, symbols, digest, prefix } or to the id of an existing entry
  -- with identical symbols.
  local function process_rules_settings()
    -- Normalises a blacklist given as a list ({'SYM'}) or as a hash ({SYM = true}) into a hash
    local function blacklist_hash(bl)
      if type(bl) ~= 'table' then
        return {}
      end
      if bl[1] ~= nil then
        return lua_util.list_to_hash(bl)
      end
      return bl
    end

    -- Fills symbols, digest and prefix of a settings element and registers its Redis prefixes
    local function process_settings_elt(rule, selt)
      local profile = rule.profile and rule.profile[selt.name]
      if profile then
        -- Use static user defined profile
        lua_util.debugm(N, rspamd_config, "use static profile for %s (%s): %s",
            rule.prefix, selt.name, profile)
        -- Ensure that we have an array...
        if not profile[1] then
          profile = lua_util.keys(profile)
        end
        selt.symbols = profile
      else
        lua_util.debugm(N, rspamd_config, "use dynamic cfg based profile for %s (%s)",
            rule.prefix, selt.name)
        -- Do filtering merely if we are using a dynamic profile; both the module-wide and
        -- the rule-level blacklists are honoured, in list or hash form
        local global_bl = blacklist_hash(settings.blacklisted_symbols)
        local rule_bl = blacklist_hash(rule.blacklisted_symbols)
        local function filter_symbols_predicate(sname)
          if global_bl[sname] or rule_bl[sname] then
            return false
          end
          local fl = rspamd_config:get_symbol_flags(sname)
          if fl then
            fl = lua_util.list_to_hash(fl)
            return not (fl.nostat or fl.idempotent or fl.skip or fl.composite)
          end
          return false
        end
        selt.symbols = fun.totable(fun.filter(filter_symbols_predicate, selt.symbols))
      end

      table.sort(selt.symbols)
      selt.digest = lua_util.table_digest(selt.symbols)
      selt.prefix = redis_ann_prefix(rule, selt.name)

      rspamd_logger.messagex(rspamd_config,
          'use NN prefix for rule %s; settings id "%s"; symbols digest: "%s"',
          selt.prefix, selt.name, selt.digest)
      lua_redis.register_prefix(selt.prefix, N,
          string.format('NN prefix for rule "%s"; settings id "%s"',
              selt.prefix, selt.name), {
            persistent = true,
            type = 'zlist',
          })
      -- Versions
      lua_redis.register_prefix(selt.prefix .. '_\\d+', N,
          string.format('NN storage for rule "%s"; settings id "%s"',
              selt.prefix, selt.name), {
            persistent = true,
            type = 'hash',
          })
      lua_redis.register_prefix(selt.prefix .. '_\\d+_spam_set', N,
          string.format('NN learning set (spam) for rule "%s"; settings id "%s"',
              selt.prefix, selt.name), {
            persistent = true,
            type = 'set',
          })
      lua_redis.register_prefix(selt.prefix .. '_\\d+_ham_set', N,
          string.format('NN learning set (ham) for rule "%s"; settings id "%s"',
              selt.prefix, selt.name), {
            persistent = true,
            type = 'set',
          })
    end

    for k, rule in pairs(settings.rules) do
      if not rule.allowed_settings then
        rule.allowed_settings = {}
      elseif rule.allowed_settings == 'all' then
        -- Extract all settings ids
        rule.allowed_settings = lua_util.keys(lua_settings.all_settings())
      end
      -- Convert to a map <setting_id> -> true
      rule.allowed_settings = lua_util.list_to_hash(rule.allowed_settings)
      -- Check if we can work without settings
      if k == 'default' or type(rule.default) ~= 'boolean' then
        rule.default = true
      end

      rule.settings = {}
      if rule.default then
        local default_settings = {
          symbols = lua_settings.default_symbols(),
          name = 'default'
        }
        process_settings_elt(rule, default_settings)
        rule.settings[-1] = default_settings -- Magic constant, but OK as settings are positive int32
      end

      -- Now, for each allowed settings, we store sorted symbols + digest
      -- We set table rule.settings[id] -> { name = name, symbols = symbols, digest = digest }
      for s, _ in pairs(rule.allowed_settings) do
        local settings_id = s
        if type(settings_id) ~= 'number' then
          settings_id = lua_settings.numeric_settings_id(s)
        end
        local selt = lua_settings.settings_by_id(settings_id)
        if not selt then
          rspamd_logger.errx(rspamd_config, 'cannot find settings id %s (%s) for rule %s; skip it',
              s, settings_id, rule.prefix)
        else
          local nelt = {
            symbols = selt.symbols, -- Already sorted
            name = selt.name
          }
          process_settings_elt(rule, nelt)

          -- Look for an existing element with the same symbols first: rule.settings must
          -- not gain new keys while it is being traversed with pairs()
          local same_id, same_elt
          for id, ex in pairs(rule.settings) do
            if type(ex) == 'table' and lua_util.distance_sorted(ex.symbols, nelt.symbols) == 0 then
              same_id, same_elt = id, ex
              break
            end
          end

          if same_id ~= nil then
            -- Equal symbols, add reference
            lua_util.debugm(N, rspamd_config,
                'added reference from settings id %s to %s; same symbols',
                nelt.name, same_elt.name)
            rule.settings[settings_id] = same_id
          else
            rule.settings[settings_id] = nelt
            lua_util.debugm(N, rspamd_config, 'added new settings id %s(%s) to %s',
                nelt.name, settings_id, rule.prefix)
          end
        end
      end
    end
  end
  --- Returns the settings element of `rule` matching the task's settings id.
  -- Tasks without a settings id use key -1. Numeric entries are references to another
  -- settings id with the same symbols and are followed until a non-number is reached.
  -- @return settings element or nil if none is registered
  local function get_rule_settings(task, rule)
    local settings_id = task:get_settings_id() or -1
    local elt = rule.settings[settings_id]
    if not elt then
      return nil
    end
    -- Reference to another settings!
    while type(elt) == 'number' do
      elt = rule.settings[elt]
    end
    return elt
  end
  --- Converts task results into an ANN input vector for `profile`.
  -- The vector starts with the metatokens, followed by one slot per profile symbol;
  -- symbol slots are filled by `task:process_ann_tokens`. A zero vector of the right
  -- length is cached in `profile.zeros` on first use.
  -- @return list of numbers
  local function result_to_vector(task, profile)
    if not profile.zeros then
      -- Fill zeros vector: metatokens first, then one slot per symbol
      local zeros = {}
      local n_meta = meta_functions.count_metatokens()
      for idx = 1, n_meta do
        zeros[idx] = 0.0
      end
      for _ in ipairs(profile.symbols) do
        zeros[#zeros + 1] = 0.0
      end
      profile.zeros = zeros
    end

    local vec = lua_util.shallowcopy(profile.zeros)
    local metatokens = meta_functions.rspamd_gen_metatokens(task)
    for idx, value in ipairs(metatokens) do
      vec[idx] = value
    end
    task:process_ann_tokens(profile.symbols, vec, #metatokens, 0.1)
    return vec
  end
  -- Public interface of this module
  local exports = {
    -- functions
    can_push_train_vector = can_push_train_vector,
    create_ann = create_ann,
    gen_unlock_cb = gen_unlock_cb,
    get_rule_settings = get_rule_settings,
    load_scripts = load_scripts,
    new_ann_key = new_ann_key,
    process_rules_settings = process_rules_settings,
    redis_ann_prefix = redis_ann_prefix,
    result_to_vector = result_to_vector,
    spawn_train = spawn_train,
    -- data
    default_options = default_options,
    module_config = module_config,
    plugin_ver = plugin_ver,
    redis_params = redis_params,
    redis_script_id = redis_script_id,
    settings = settings,
  }
  return exports