You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

neural.lua 46KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483
  1. --[[
  2. Copyright (c) 2016, Vsevolod Stakhov <vsevolod@highsecure.ru>
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ]]--
  -- Skip the rest of the module when the `confighelp` global is set
  if confighelp then return end
  16. local rspamd_logger = require "rspamd_logger"
  17. local rspamd_util = require "rspamd_util"
  18. local rspamd_kann = require "rspamd_kann"
  19. local lua_redis = require "lua_redis"
  20. local lua_util = require "lua_util"
  21. local fun = require "fun"
  22. local lua_settings = require "lua_settings"
  23. local meta_functions = require "lua_meta"
  24. local ts = require("tableshape").types
  25. local lua_verdict = require "lua_verdict"
  26. local N = "neural"
  -- Module vars

  -- Per-rule defaults; the `train` sub-table groups the training parameters
  local default_options = {
    train = {
      max_trains = 1000,
      max_epoch = 1000,
      max_usages = 10,
      -- Torch style
      max_iterations = 25,
      mse = 0.001,
      autotrain = true,
      train_prob = 1.0,
      learn_threads = 1,
      learning_rate = 0.01,
      -- What difference is allowed between classes
      -- (1:1 proportion means 0 bias)
      classes_bias = 0.0,
    },
    watch_interval = 60.0,
    lock_expire = 600,
    learning_spawned = false,
    -- 2 days
    ann_expire = 60 * 60 * 24 * 2,
    symbol_spam = 'NEURAL_SPAM',
    symbol_ham = 'NEURAL_HAM',
  }
  -- Shape of an ANN profile record: all fields are mandatory except `distance`
  local redis_profile_schema = ts.shape({
    digest = ts.string,
    symbols = ts.array_of(ts.string),
    version = ts.number,
    redis_key = ts.string,
    distance = ts.number:is_optional(),
  })
  -- Rule structure:
  -- * static config fields (see `default_options`)
  -- * prefix - name or defined prefix
  -- * settings - table of settings indexed by settings id, -1 is used when
  --   no settings are defined
  --
  -- Each element of rule settings describes one specific settings id:
  -- * symbols - static symbols profile (defined by config or extracted
  --   from symcache)
  -- * name - name of settings id
  -- * digest - digest of all symbols
  -- * ann - dynamic ANN configuration loaded from Redis
  -- * train - train data for ANN (e.g. the currently trained ANN)
  --
  -- The settings ANN table is loaded from Redis and represents the dynamic
  -- profile for an ANN. Some elements are stored directly in Redis, the ANN
  -- itself is, in turn, loaded dynamically:
  -- * version - version of ANN loaded from redis
  -- * redis_key - name of ANN key in Redis
  -- * symbols - symbols in THIS PARTICULAR ANN (might be different from
  --   set.symbols)
  -- * distance - distance between set.symbols and set.ann.symbols
  -- * ann - kann object
  local settings = {
    rules = {},
    -- Neural network default prefix
    prefix = 'rn',
    -- Maximum number of NN profiles stored
    max_profiles = 3,
  }
  -- Prefer the "neural" section, fall back to the legacy "fann_redis" one
  local module_config = rspamd_config:get_all_opt("neural")
    or rspamd_config:get_all_opt("fann_redis")
  -- Lua script that checks if we can store a new training vector
  -- Uses the following keys:
  -- key1 - ann key
  -- key2 - class to learn: 'spam' or 'ham'
  -- key3 - maximum trains
  -- key4 - sampling coin (as Redis scripts do not allow math.random calls)
  -- key5 - classes bias
  -- Returns a pair {count, reason} where `count` is a number rendered as string:
  --   count >= 0 - learning is allowed, `count` vectors are already stored
  --                for the requested class
  --   count < 0  - learning is not allowed (locked, limit reached, sampled
  --                out or bad input), `reason` explains why
  local redis_lua_script_can_store_train_vec = [[
    local prefix = KEYS[1]
    local class = KEYS[2]
    local lim = tonumber(KEYS[3])
    local coin = tonumber(KEYS[4])
    local classes_bias = tonumber(KEYS[5])

    -- Reject malformed requests explicitly instead of treating them as ham
    if not (lim and coin and classes_bias) or (class ~= 'spam' and class ~= 'ham') then
      return {tostring(-1),'bad input'}
    end

    local locked = redis.call('HGET', prefix, 'lock')
    if locked then return {tostring(-1),'locked by another process till: ' .. locked} end

    local nspam = 0
    local nham = 0
    local ret = redis.call('LLEN', prefix .. '_spam')
    if ret then nspam = tonumber(ret) end
    ret = redis.call('LLEN', prefix .. '_ham')
    if ret then nham = tonumber(ret) end

    -- `mine` counts the class being learned, `other` counts the opposite one
    local mine, other = nham, nspam
    if class == 'spam' then mine, other = nspam, nham end

    if mine > lim then
      -- Enough learns
      return {tostring(-(mine)),'too many ' .. class .. ' samples'}
    end

    if mine > other then
      -- Apply sampling: the more this class dominates, the likelier the skip
      local skip_rate = 1.0 - other / (mine + 1)
      if coin < skip_rate - classes_bias then
        return {tostring(-(mine)),'sampled out with probability ' .. tostring(skip_rate - classes_bias)}
      end
    end

    return {tostring(mine),'can learn'}
  ]]
  local redis_can_store_train_vec_id = nil
  -- Lua script to invalidate ANNs by rank
  -- Uses the following keys
  -- key1 - prefix for keys (sorted set of JSON encoded profiles)
  -- key2 - number of elements to leave
  -- Returns the list of removed profile records (empty if nothing was removed)
  local redis_lua_script_maybe_invalidate = [[
    local card = redis.call('ZCARD', KEYS[1])
    local lim = tonumber(KEYS[2])
    if card > lim then
      local to_delete = redis.call('ZRANGE', KEYS[1], 0, card - lim - 1)
      for _,k in ipairs(to_delete) do
        -- A malformed profile record must not abort the whole cleanup:
        -- it is still removed from the set below, just without key deletion
        local ok, tb = pcall(cjson.decode, k)
        if ok and type(tb) == 'table' and type(tb.redis_key) == 'string' then
          redis.call('DEL', tb.redis_key)
          -- Also train vectors
          redis.call('DEL', tb.redis_key .. '_spam')
          redis.call('DEL', tb.redis_key .. '_ham')
        end
      end
      redis.call('ZREMRANGEBYRANK', KEYS[1], 0, card - lim - 1)
      return to_delete
    else
      return {}
    end
  ]]
  local redis_maybe_invalidate_id = nil
  -- Lua script to acquire the training lock for an ANN in redis
  -- Uses the following keys
  -- key1 - prefix for keys
  -- key2 - current time
  -- key3 - key expire
  -- key4 - hostname
  -- Returns 1 when the lock has been taken by the caller, or a pair
  -- {lock_time, hostname} when a still valid lock is held by somebody else
  local redis_lua_script_maybe_lock = [[
    local locked = redis.call('HGET', KEYS[1], 'lock')
    local now = tonumber(KEYS[2])
    if locked then
      locked = tonumber(locked)
      local expire = tonumber(KEYS[3])
      -- NOTE(review): a lock stamped at `now` or in the future is treated
      -- as free here; confirm this is intended
      if now > locked and (now - locked) < expire then
        -- The hostname field might be missing, never return a nil element
        return {tostring(locked), redis.call('HGET', KEYS[1], 'hostname') or 'unknown'}
      end
    end
    redis.call('HSET', KEYS[1], 'lock', tostring(now))
    redis.call('HSET', KEYS[1], 'hostname', KEYS[4])
    return 1
  ]]
  local redis_maybe_lock_id = nil
  177. -- Lua script to save a trained ANN in redis and release the training locks
  178. -- Uses the following keys
  179. -- key1 - prefix for ANN (hash that receives the `ann` field and the expire)
  180. -- key2 - prefix for profile (sorted set, the profile is scored by key6)
  181. -- key3 - compressed ANN
  182. -- key4 - profile as JSON
  183. -- key5 - expire in seconds (applied to key1)
  184. -- key6 - current time
  185. -- key7 - old key (its `lock` field is removed together with the one of key1)
  -- The `_spam` and `_ham` lists of key1 are deleted; the script returns 1
  186. local redis_lua_script_save_unlock = [[
  187. local now = tonumber(KEYS[6])
  188. redis.call('ZADD', KEYS[2], now, KEYS[4])
  189. redis.call('HSET', KEYS[1], 'ann', KEYS[3])
  190. redis.call('DEL', KEYS[1] .. '_spam')
  191. redis.call('DEL', KEYS[1] .. '_ham')
  192. redis.call('HDEL', KEYS[1], 'lock')
  193. redis.call('HDEL', KEYS[7], 'lock')
  194. redis.call('EXPIRE', KEYS[1], tonumber(KEYS[5]))
  195. return 1
  196. ]]
  197. local redis_save_unlock_id = nil
  local redis_params

  -- Registers all Redis scripts of this module via `lua_redis.add_redis_script`
  -- and remembers the returned ids in the module level `redis_*_id` variables
  local function load_scripts(params)
    local register = lua_redis.add_redis_script

    redis_can_store_train_vec_id = register(redis_lua_script_can_store_train_vec, params)
    redis_maybe_invalidate_id = register(redis_lua_script_maybe_invalidate, params)
    redis_maybe_lock_id = register(redis_lua_script_maybe_lock, params)
    redis_save_unlock_id = register(redis_lua_script_save_unlock, params)
  end
  -- Builds the ANN input vector for a task: metatokens come first, followed by
  -- one slot per symbol of `profile.symbols`.
  -- The all-zero template is created on first use and cached in `profile.zeros`.
  local function result_to_vector(task, profile)
    if not profile.zeros then
      local template = {}
      local nmeta = meta_functions.count_metatokens()

      for idx = 1, nmeta do
        template[idx] = 0.0
      end
      -- One extra slot per symbol of the profile
      for _ in ipairs(profile.symbols) do
        table.insert(template, 0.0)
      end

      profile.zeros = template
    end

    local vec = lua_util.shallowcopy(profile.zeros)
    local metatokens = meta_functions.rspamd_gen_metatokens(task)

    for idx, value in ipairs(metatokens) do
      vec[idx] = value
    end

    -- Symbol slots are filled by the task itself, after the metatokens
    task:process_ann_tokens(profile.symbols, vec, #metatokens, 0.1)

    return vec
  end
  -- Used to generate new ANN key for specific profile: the key joins the
  -- module prefix, the rule prefix, the settings name, the first 8 characters
  -- of the symbols digest and the version with underscores
  local function new_ann_key(rule, set, version)
    local short_digest = string.sub(set.digest, 1, 8)

    return ('%s_%s_%s_%s_%s'):format(settings.prefix, rule.prefix, set.name,
      short_digest, tostring(version))
  end
  -- Extract settings element for a specific settings id (-1 is used when the
  -- task has no settings id). Returns nil when nothing is defined for that id.
  local function get_rule_settings(task, rule)
    local settings_id = task:get_settings_id()
    if not settings_id then
      settings_id = -1
    end

    local resolved = rule.settings[settings_id]
    if resolved then
      -- A numeric entry is a reference to another settings element: follow it
      while type(resolved) == 'number' do
        resolved = rule.settings[resolved]
      end

      return resolved
    end

    return nil
  end
  -- Generate redis prefix for specific rule and specific settings.
  -- The prefix embeds `meta_functions.version`.
  local function redis_ann_prefix(rule, settings_name)
    local meta_version = meta_functions.version

    return ('%s_%s_%d_%s'):format(settings.prefix, rule.prefix, meta_version,
      settings_name)
  end
  -- Creates an ANN profile for `set` and stores it in Redis.
  -- The profile is returned immediately; the ZADD into `set.prefix` (scored by
  -- the current time) completes asynchronously and its outcome is only logged.
  local function new_ann_profile(task, rule, set, version)
    local ucl = require "ucl"

    local profile = {
      symbols = set.symbols,
      redis_key = new_ann_key(rule, set, version),
      version = version,
      digest = set.digest,
      -- Since we are using our own profile
      distance = 0,
    }
    local serialized = ucl.to_format(profile, 'json-compact', true)

    local function on_stored(err, _)
      if not err then
        rspamd_logger.infox(task, 'created new ANN profile for %s:%s, data stored at prefix %s',
          rule.prefix, set.name, profile.redis_key)
        return
      end

      rspamd_logger.errx(task, 'cannot store ANN profile for %s:%s at %s : %s',
        rule.prefix, set.name, profile.redis_key, err)
    end

    local zadd_args = {set.prefix, tostring(rspamd_util.get_time()), serialized}

    lua_redis.redis_make_request(task,
      rule.redis,
      nil,
      true, -- is write
      on_stored, -- callback
      'ZADD', -- command
      zadd_args
    )

    return profile
  end
  -- ANN filter function, used to insert scores based on the existing symbols.
  -- For every rule with a loaded ANN the task vector is fed to the network:
  -- a positive output inserts the spam symbol, anything else inserts the ham
  -- symbol with the negated output as weight.
  local function ann_scores_filter(task)
    for _, rule in pairs(settings.rules) do
      local sid = task:get_settings_id() or -1
      local set = get_rule_settings(task, rule)

      if not set then
        lua_util.debugm(N, task, 'no ann defined in %s for settings id %s',
          rule.prefix, sid)
      elseif not set.ann then
        lua_util.debugm(N, task, 'no ann loaded for %s:%s',
          rule.prefix, set.name)
      else
        local profile = set.ann
        local ann = profile.ann

        if ann then
          local vec = result_to_vector(task, profile)
          local out = ann:apply1(vec)
          local score = out[1]
          local symscore = string.format('%.3f', score)

          lua_util.debugm(N, task, '%s:%s:%s ann score: %s',
            rule.prefix, set.name, set.ann.version, symscore)

          if score > 0 then
            task:insert_result(rule.symbol_spam, score, symscore)
          else
            task:insert_result(rule.symbol_ham, -(score), symscore)
          end
        end
      end
    end
  end
  -- Builds a kann network for `n` inputs: input -> relu -> dense layer of
  -- floor((n + 1) / 2) units -> tanh -> single-output MSE cost layer.
  -- `nlayers` is accepted but ignored so far when using kann.
  local function create_ann(n, nlayers)
    local kann = rspamd_kann
    local hidden_size = math.floor((n + 1) / 2)

    local net = kann.transform.relu(kann.layer.input(n))
    local hidden = kann.layer.dense(net, hidden_size)
    net = kann.transform.tanh(hidden)
    net = kann.layer.cost(net, 1, kann.cost.mse)

    return kann.new.kann(net)
  end
  -- Decides whether the scanned task should become a training sample and, if
  -- so, pushes its vector to Redis.
  -- * rule - neural rule (uses `rule.train`, `rule.redis`, `rule.prefix`)
  -- * task - scanned task
  -- * verdict - verdict string, compared against 'spam', 'junk' and 'ham'
  -- * score - task score, compared against `spam_score`/`ham_score` if set
  -- * set - rule settings element for this task
  -- With `autotrain` the class is derived from score thresholds or the verdict,
  -- otherwise from the `ANN-Train` request header. The vector is stored only
  -- when the Redis "can store" script allows that.
  local function ann_push_task_result(rule, task, verdict, score, set)
    local train_opts = rule.train
    local learn_spam, learn_ham
    local skip_reason = 'unknown'

    if train_opts.autotrain then
      if train_opts.spam_score then
        learn_spam = score >= train_opts.spam_score

        if not learn_spam then
          skip_reason = string.format('score < spam_score: %f < %f',
            score, train_opts.spam_score)
        end
      else
        learn_spam = verdict == 'spam' or verdict == 'junk'

        if not learn_spam then
          skip_reason = string.format('verdict: %s',
            verdict)
        end
      end

      if train_opts.ham_score then
        learn_ham = score <= train_opts.ham_score

        if not learn_ham then
          skip_reason = string.format('score > ham_score: %f > %f',
            score, train_opts.ham_score)
        end
      else
        learn_ham = verdict == 'ham'

        if not learn_ham then
          skip_reason = string.format('verdict: %s',
            verdict)
        end
      end
    else
      -- Train by request header
      local hdr = task:get_request_header('ANN-Train')

      if hdr then
        local class = hdr:lower()

        if class == 'spam' then
          learn_spam = true
        elseif class == 'ham' then
          learn_ham = true
        else
          skip_reason = string.format('unknown ANN-Train header value: %s',
            tostring(hdr))
        end
      else
        skip_reason = 'no explicit header'
      end
    end

    if learn_spam or learn_ham then
      local learn_type
      if learn_spam then learn_type = 'spam' else learn_type = 'ham' end

      -- Receives the {count, reason} pair from the "can store" Redis script
      local function can_train_cb(err, data)
        if not err and type(data) == 'table' then
          local nsamples, reason = tonumber(data[1]), data[2]

          if not nsamples then
            -- Malformed reply: do not compare nil with a number below
            rspamd_logger.errx(task, 'cannot check if we can train %s:%s : invalid reply from Redis script: %s',
              rule.prefix, set.name, data)
            return
          end

          if nsamples >= 0 then
            local coin = math.random()

            if coin < 1.0 - train_opts.train_prob then
              rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin)
              return
            end

            local vec = result_to_vector(task, set)
            -- Vectors are stored as zstd compressed, ';' separated numbers
            local str = rspamd_util.zstd_compress(table.concat(vec, ';'))
            local target_key = set.ann.redis_key .. '_' .. learn_type

            local function learn_vec_cb(_err)
              if _err then
                rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s',
                  rule.prefix, set.name, _err)
              else
                lua_util.debugm(N, task,
                  "add train data for ANN rule " ..
                  "%s:%s, save %s vector of %s elts in %s key; %s bytes compressed",
                  rule.prefix, set.name, learn_type, #vec, target_key, #str)
              end
            end

            lua_redis.redis_make_request(task,
              rule.redis,
              nil,
              true, -- is write
              learn_vec_cb, --callback
              'LPUSH', -- command
              { target_key, str } -- arguments
            )
          else
            -- Negative result returned
            rspamd_logger.infox(task, "cannot learn %s ANN %s:%s; redis_key: %s: %s (%s vectors stored)",
              learn_type, rule.prefix, set.name, set.ann.redis_key, reason, -nsamples)
          end
        else
          if err then
            rspamd_logger.errx(task, 'cannot check if we can train %s:%s : %s',
              rule.prefix, set.name, err)
          else
            rspamd_logger.errx(task, 'cannot check if we can train %s:%s : type of Redis key %s is %s, expected table; ' ..
              'please remove this key from Redis manually if you perform upgrade from the previous version',
              rule.prefix, set.name, set.ann.redis_key, type(data))
          end
        end
      end

      -- Check if we can learn
      if set.can_store_vectors then
        if not set.ann then
          -- Need to create or load a profile corresponding to the current configuration
          set.ann = new_ann_profile(task, rule, set, 0)
          lua_util.debugm(N, task,
            'requested new profile for %s, set.ann is missing',
            set.name)
        end

        lua_redis.exec_redis_script(redis_can_store_train_vec_id,
          {task = task, is_write = true},
          can_train_cb,
          {
            set.ann.redis_key,
            learn_type,
            tostring(train_opts.max_trains),
            tostring(math.random()),
            tostring(train_opts.classes_bias)
          })
      else
        lua_util.debugm(N, task,
          'do not push data: train condition not satisfied; reason: not checked existing ANNs')
      end
    else
      lua_util.debugm(N, task,
        'do not push data to key %s: train condition not satisfied; reason: %s',
        (set.ann or {}).redis_key,
        skip_reason)
    end
  end
  --- Offline training logic

  -- Closure generator for unlock function: the returned Redis callback only
  -- logs whether removing the lock of `ann_key` succeeded
  local function gen_unlock_cb(rule, set, ann_key)
    local function on_unlock(err)
      if not err then
        lua_util.debugm(N, rspamd_config, 'unlocked ANN %s:%s at %s',
          rule.prefix, set.name, ann_key)
        return
      end

      rspamd_logger.errx(rspamd_config, 'cannot unlock ANN %s:%s at %s from redis: %s',
        rule.prefix, set.name, ann_key, err)
    end

    return on_unlock
  end
  -- This function is intended to extend lock for ANN during training.
  -- It registers a periodic that increments the `lock` field of `ann_key` by 30
  -- every 30 seconds for as long as `set.learning_spawned` is true; once the
  -- flag is false the periodic stops rescheduling itself.
  local function register_lock_extender(rule, set, ev_base, ann_key)
    local function on_lock_extended(_err, _)
      if _err then
        rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s',
          ann_key, _err)
      else
        rspamd_logger.infox(rspamd_config, 'extend lock for ANN %s for 30 seconds',
          ann_key)
      end
    end

    local function extend_lock()
      if not set.learning_spawned then
        lua_util.debugm(N, rspamd_config, "stop lock extension as learning_spawned is false")
        return false -- do not plan any more updates
      end

      lua_redis.redis_make_request_taskless(ev_base,
        rspamd_config,
        rule.redis,
        nil,
        true, -- is write
        on_lock_extended, -- callback
        'HINCRBY', -- command
        {ann_key, 'lock', '30'}
      )

      return true
    end

    rspamd_config:add_periodic(ev_base, 30.0, extend_lock)
  end
  -- This function receives training vectors, checks them, spawn learning and saves ANN in Redis
  -- * worker, ev_base - worker and event base used to spawn the child process
  --   and to talk to Redis
  -- * rule, set - neural rule and its settings element being trained
  -- * ann_key - Redis key the vectors were read from (it holds the lock)
  -- * ham_vec, spam_vec - lists of training vectors (tables of numbers)
  -- Training runs in a child process; on success the ANN is stored under a new
  -- versioned key and both keys are unlocked by the save script.
  local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_vec)
    -- Releases the training lock held on `ann_key`
    local function release_lock()
      lua_redis.redis_make_request_taskless(ev_base,
        rspamd_config,
        rule.redis,
        nil,
        true, -- is write
        gen_unlock_cb(rule, set, ann_key), --callback
        'HDEL', -- command
        {ann_key, 'lock'}
      )
    end

    -- Check training data sanity
    if #ham_vec + #spam_vec < rule.train.max_trains / 2 then
      -- Not enough data: report it and give the lock back instead of raising
      -- an error inside of a Redis callback (which would keep the lock held)
      -- TODO: add invalidation
      rspamd_logger.errx(rspamd_config,
        'cannot train ANN %s:%s: not enough training vectors: %s ham, %s spam',
        rule.prefix, set.name, #ham_vec, #spam_vec)
      release_lock()
      return
    end

    -- Now we need to join inputs and create the appropriate test vectors
    local n = #set.symbols +
        meta_functions.rspamd_count_metatokens()

    -- Now we can train ann
    local train_ann = create_ann(n, 3)
    local inputs, outputs = {}, {}

    -- Used to show sparsed vectors in a convenient format (for debugging only)
    local function debug_vec(t)
      local ret = {}
      for i,v in ipairs(t) do
        if v ~= 0 then
          ret[#ret + 1] = string.format('%d=%.2f', i, v)
        end
      end

      return ret
    end

    -- Make training set by joining vectors
    -- KANN automatically shuffles those samples
    -- 1.0 is used for spam and -1.0 is used for ham
    -- It implies that output layer can express that (e.g. tanh output)
    for _,e in ipairs(spam_vec) do
      inputs[#inputs + 1] = e
      outputs[#outputs + 1] = {1.0}
    end
    for _,e in ipairs(ham_vec) do
      inputs[#inputs + 1] = e
      outputs[#outputs + 1] = {-1.0}
    end

    -- Called in child process
    -- Returns the serialised ANN or nil if NaN has been seen in the train cost
    local function train()
      local log_thresh = rule.train.max_iterations / 10
      local seen_nan = false

      local function train_cb(iter, train_cost, value_cost)
        if (iter * (rule.train.max_iterations / log_thresh)) % (rule.train.max_iterations) == 0 then
          -- NaN is the only value that is not equal to itself
          if train_cost ~= train_cost and not seen_nan then
            -- We have nan :( try to log lots of stuff to dig into a problem
            seen_nan = true
            rspamd_logger.errx(rspamd_config, 'ANN %s:%s: train error: observed nan in error cost!; value cost = %s',
              rule.prefix, set.name,
              value_cost)
            for i,e in ipairs(inputs) do
              lua_util.debugm(N, rspamd_config, 'train vector %s -> %s',
                debug_vec(e), outputs[i][1])
            end
          end

          rspamd_logger.infox(rspamd_config,
            "ANN %s:%s: learned from %s redis key in %s iterations, error: %s, value cost: %s",
            rule.prefix, set.name,
            ann_key,
            iter,
            train_cost,
            value_cost)
        end
      end

      train_ann:train1(inputs, outputs, {
        lr = rule.train.learning_rate,
        max_epoch = rule.train.max_iterations,
        cb = train_cb,
      })

      if not seen_nan then
        local out = train_ann:save()
        return out
      else
        return nil
      end
    end

    -- Invoked when the save script finishes; on success that script has
    -- already removed the locks, so unlock manually only on error
    local function redis_save_cb(err)
      if err then
        rspamd_logger.errx(rspamd_config, 'cannot save ANN %s:%s to redis key %s: %s',
          rule.prefix, set.name, ann_key, err)
        release_lock()
      else
        rspamd_logger.infox(rspamd_config, 'saved ANN %s:%s to redis: %s',
          rule.prefix, set.name, set.ann.redis_key)
      end
    end

    -- Invoked in the parent process when the child finishes
    local function ann_trained(err, data)
      set.learning_spawned = false

      -- NOTE(review): `data` is presumably missing when `train` returned nil
      -- (NaN case) - confirm against worker:spawn_process
      if err or not data then
        rspamd_logger.errx(rspamd_config, 'cannot train ANN %s:%s : %s',
          rule.prefix, set.name, err or 'no ANN data returned by the training process')
        release_lock()
        return
      end

      local ann_data = rspamd_util.zstd_compress(data)

      if not set.ann then
        set.ann = {
          symbols = set.symbols,
          distance = 0,
          digest = set.digest,
          redis_key = ann_key,
        }
      end

      -- Deserialise ANN from the child process
      local loaded_ann = rspamd_kann.load(data)
      local version = (set.ann.version or 0) + 1

      set.ann.version = version
      set.ann.ann = loaded_ann
      set.ann.symbols = set.symbols
      set.ann.redis_key = new_ann_key(rule, set, version)

      local profile = {
        symbols = set.symbols,
        digest = set.digest,
        redis_key = set.ann.redis_key,
        version = version
      }

      local ucl = require "ucl"
      local profile_serialized = ucl.to_format(profile, 'json-compact', true)

      rspamd_logger.infox(rspamd_config,
        'trained ANN %s:%s, %s bytes; redis key: %s (old key %s)',
        rule.prefix, set.name, #data, set.ann.redis_key, ann_key)

      lua_redis.exec_redis_script(redis_save_unlock_id,
        {ev_base = ev_base, is_write = true},
        redis_save_cb,
        {profile.redis_key,
         redis_ann_prefix(rule, set.name),
         ann_data,
         profile_serialized,
         tostring(rule.ann_expire),
         tostring(os.time()),
         ann_key, -- old key to unlock...
        })
    end

    -- Spawn learn and register lock extension
    set.learning_spawned = true

    worker:spawn_process{
      func = train,
      on_complete = ann_trained,
      proctitle = string.format("ANN train for %s/%s", rule.prefix, set.name),
    }

    register_lock_extender(rule, set, ev_base, ann_key)
  end
  -- Utility to extract and split saved training vectors to a table of tables:
  -- every element of `data` is zstd decompressed, split on ';' and each field
  -- is converted with `tonumber`
  local function process_training_vectors(data)
    local function parse_vector(tok)
      local _, raw = rspamd_util.zstd_decompress(tok)
      local fields = lua_util.str_split(tostring(raw), ';')

      return fun.totable(fun.map(tonumber, fields))
    end

    return fun.totable(fun.map(parse_vector, data))
  end
  -- This function does the following:
  -- * Tries to lock ANN
  -- * Loads spam and ham vectors
  -- * Spawn learning process
  -- Nothing is done when `set.learning_spawned` is already set. Whenever
  -- loading of the vectors fails the lock is released again.
  local function do_train_ann(worker, ev_base, rule, set, ann_key)
    local spam_elts = {}
    local ham_elts = {}

    -- Releases the training lock held on `ann_key`
    local function release_lock()
      lua_redis.redis_make_request_taskless(ev_base,
        rspamd_config,
        rule.redis,
        nil,
        true, -- is write
        gen_unlock_cb(rule, set, ann_key), --callback
        'HDEL', -- command
        {ann_key, 'lock'}
      )
    end

    -- Describes why a vectors reply has been rejected
    local function reply_error(err, data)
      return err or string.format('unexpected reply type: %s', type(data))
    end

    -- Ham vectors received
    local function redis_ham_cb(err, data)
      if err or type(data) ~= 'table' then
        rspamd_logger.errx(rspamd_config, 'cannot get ham tokens for ANN %s from redis: %s',
          ann_key, reply_error(err, data))
        -- Unlock on error
        release_lock()
      else
        -- Decompress and convert to numbers each training vector
        ham_elts = process_training_vectors(data)
        spawn_train(worker, ev_base, rule, set, ann_key, ham_elts, spam_elts)
      end
    end

    -- Spam vectors received
    local function redis_spam_cb(err, data)
      if err or type(data) ~= 'table' then
        rspamd_logger.errx(rspamd_config, 'cannot get spam tokens for ANN %s from redis: %s',
          ann_key, reply_error(err, data))
        -- Unlock ANN on error
        release_lock()
      else
        -- Decompress and convert to numbers each training vector
        spam_elts = process_training_vectors(data)
        -- Now get ham vectors...
        lua_redis.redis_make_request_taskless(ev_base,
          rspamd_config,
          rule.redis,
          nil,
          false, -- is write
          redis_ham_cb, --callback
          'LRANGE', -- command
          {ann_key .. '_ham', '0', '-1'}
        )
      end
    end

    -- Reply of the lock script
    local function redis_lock_cb(err, data)
      if err then
        rspamd_logger.errx(rspamd_config, 'cannot call lock script for ANN %s from redis: %s',
          ann_key, err)
      elseif type(data) == 'number' and data == 1 then
        -- ANN is locked, so we can extract SPAM and HAM vectors and spawn learning
        lua_redis.redis_make_request_taskless(ev_base,
          rspamd_config,
          rule.redis,
          nil,
          false, -- is write
          redis_spam_cb, --callback
          'LRANGE', -- command
          {ann_key .. '_spam', '0', '-1'}
        )

        rspamd_logger.infox(rspamd_config, 'lock ANN %s:%s (key name %s) for learning',
          rule.prefix, set.name, ann_key)
      elseif type(data) == 'table' then
        local lock_tm = tonumber(data[1])
        rspamd_logger.infox(rspamd_config, 'do not learn ANN %s:%s (key name %s), ' ..
          'locked by another host %s at %s', rule.prefix, set.name, ann_key,
          data[2], os.date('%c', lock_tm))
      else
        -- Neither `1` nor a {lock_time, hostname} pair: do not index it
        rspamd_logger.errx(rspamd_config, 'unexpected reply from lock script for ANN %s: %s',
          ann_key, type(data))
      end
    end

    -- Check if we are already learning this network
    if set.learning_spawned then
      rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, already learning another ANN',
        ann_key)
      return
    end

    -- Call Redis script that tries to acquire a lock
    -- This script returns either 1 (the lock is acquired) or a pair
    -- {'lock_time', 'hostname'} when ANN is locked by another host
    -- (or a process, meh)
    lua_redis.exec_redis_script(redis_maybe_lock_id,
      {ev_base = ev_base, is_write = true},
      redis_lock_cb,
      {
        ann_key,
        tostring(os.time()),
        tostring(rule.watch_interval * 2),
        rspamd_util.get_hostname()
      })
  end
  767. -- This function loads new ann from Redis
  768. -- This is based on `profile` attribute.
  769. -- ANN is loaded from `profile.redis_key`
  770. -- Rank of `profile` key is also increased, unfortunately, it means that we need to
  771. -- serialize profile one more time and set its rank to the current time
  772. -- set.ann fields are set according to Redis data received
  773. local function load_new_ann(rule, ev_base, set, profile, min_diff)
  774. local ann_key = profile.redis_key
  775. local function data_cb(err, data)
  776. if err then
  777. rspamd_logger.errx(rspamd_config, 'cannot get ANN data from key: %s; %s',
  778. ann_key, err)
  779. else
  780. if type(data) == 'string' then
  781. local _err,ann_data = rspamd_util.zstd_decompress(data)
  782. local ann
  783. if _err or not ann_data then
  784. rspamd_logger.errx(rspamd_config, 'cannot decompress ANN for %s from Redis key %s: %s',
  785. rule.prefix .. ':' .. set.name, ann_key, _err)
  786. return
  787. else
  788. ann = rspamd_kann.load(ann_data)
  789. if ann then
  790. set.ann = {
  791. digest = profile.digest,
  792. version = profile.version,
  793. symbols = profile.symbols,
  794. distance = min_diff,
  795. redis_key = profile.redis_key
  796. }
  797. local ucl = require "ucl"
  798. local profile_serialized = ucl.to_format(profile, 'json-compact', true)
  799. set.ann.ann = ann -- To avoid serialization
  800. local function rank_cb(_, _)
  801. -- TODO: maybe add some logging
  802. end
  803. -- Also update rank for the loaded ANN to avoid removal
  804. lua_redis.redis_make_request_taskless(ev_base,
  805. rspamd_config,
  806. rule.redis,
  807. nil,
  808. true, -- is write
  809. rank_cb, --callback
  810. 'ZADD', -- command
  811. {set.prefix, tostring(rspamd_util.get_time()), profile_serialized}
  812. )
  813. rspamd_logger.infox(rspamd_config, 'loaded ANN for %s:%s from %s; %s bytes compressed; version=%s',
  814. rule.prefix, set.name, ann_key, #ann_data, profile.version)
  815. else
  816. rspamd_logger.errx(rspamd_config, 'cannot deserialize ANN for %s:%s from Redis key %s',
  817. rule.prefix, set.name, ann_key)
  818. end
  819. end
  820. else
  821. lua_util.debugm(N, rspamd_config, 'no ANN for %s:%s in Redis key %s',
  822. rule.prefix, set.name, ann_key)
  823. end
  824. end
  825. end
  826. lua_redis.redis_make_request_taskless(ev_base,
  827. rspamd_config,
  828. rule.redis,
  829. nil,
  830. false, -- is write
  831. data_cb, --callback
  832. 'HGET', -- command
  833. {ann_key, 'ann'} -- arguments
  834. )
  835. end
  836. -- Used to check an element in Redis serialized as JSON
  837. -- for some specific rule + some specific setting
  838. -- This function tries to load more fresh or more specific ANNs in lieu of
  839. -- the existing ones.
  840. -- Use this function to load ANNs as `callback` parameter for `check_anns` function
  841. local function process_existing_ann(_, ev_base, rule, set, profiles)
  842. local my_symbols = set.symbols
  843. local min_diff = math.huge
  844. local sel_elt
  845. for _,elt in fun.iter(profiles) do
  846. if elt and elt.symbols then
  847. local dist = lua_util.distance_sorted(elt.symbols, my_symbols)
  848. -- Check distance
  849. if dist < #my_symbols * .3 then
  850. if dist < min_diff then
  851. min_diff = dist
  852. sel_elt = elt
  853. end
  854. end
  855. end
  856. end
  857. if sel_elt then
  858. -- We can load element from ANN
  859. if set.ann then
  860. -- We have an existing ANN, probably the same...
  861. if set.ann.digest == sel_elt.digest then
  862. -- Same ANN, check version
  863. if set.ann.version < sel_elt.version then
  864. -- Load new ann
  865. rspamd_logger.infox(rspamd_config, 'ann %s is changed, ' ..
  866. 'our version = %s, remote version = %s',
  867. rule.prefix .. ':' .. set.name,
  868. set.ann.version,
  869. sel_elt.version)
  870. load_new_ann(rule, ev_base, set, sel_elt, min_diff)
  871. else
  872. lua_util.debugm(N, rspamd_config, 'ann %s is not changed, ' ..
  873. 'our version = %s, remote version = %s',
  874. rule.prefix .. ':' .. set.name,
  875. set.ann.version,
  876. sel_elt.version)
  877. end
  878. else
  879. -- We have some different ANN, so we need to compare distance
  880. if set.ann.distance > min_diff then
  881. -- Load more specific ANN
  882. rspamd_logger.infox(rspamd_config, 'more specific ann is available for %s, ' ..
  883. 'our distance = %s, remote distance = %s',
  884. rule.prefix .. ':' .. set.name,
  885. set.ann.distance,
  886. min_diff)
  887. load_new_ann(rule, ev_base, set, sel_elt, min_diff)
  888. else
  889. lua_util.debugm(N, rspamd_config, 'ann %s is not changed or less specific, ' ..
  890. 'our distance = %s, remote distance = %s',
  891. rule.prefix .. ':' .. set.name,
  892. set.ann.distance,
  893. min_diff)
  894. end
  895. end
  896. else
  897. -- We have no ANN, load new one
  898. load_new_ann(rule, ev_base, set, sel_elt, min_diff)
  899. end
  900. end
  901. end
  902. -- This function checks all profiles and selects if we can train our
  903. -- ANN. By our we mean that it has exactly the same symbols in profile.
  904. -- Use this function to train ANN as `callback` parameter for `check_anns` function
  905. local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
  906. local my_symbols = set.symbols
  907. local sel_elt
  908. local lens = {
  909. spam = 0,
  910. ham = 0,
  911. }
  912. for _,elt in fun.iter(profiles) do
  913. if elt and elt.symbols then
  914. local dist = lua_util.distance_sorted(elt.symbols, my_symbols)
  915. -- Check distance
  916. if dist == 0 then
  917. sel_elt = elt
  918. break
  919. end
  920. end
  921. end
  922. if sel_elt then
  923. -- We have our ANN and that's train vectors, check if we can learn
  924. local ann_key = sel_elt.redis_key
  925. lua_util.debugm(N, rspamd_config, "check if ANN %s needs to be trained",
  926. ann_key)
  927. -- Create continuation closure
  928. local redis_len_cb_gen = function(cont_cb, what, is_final)
  929. return function(err, data)
  930. if err then
  931. rspamd_logger.errx(rspamd_config,
  932. 'cannot get ANN %s trains %s from redis: %s', what, ann_key, err)
  933. elseif data and type(data) == 'number' or type(data) == 'string' then
  934. local ntrains = tonumber(data) or 0
  935. lens[what] = ntrains
  936. if is_final then
  937. -- Ensure that we have the following:
  938. -- one class has reached max_trains
  939. -- other class(es) are at least as full as classes_bias
  940. -- e.g. if classes_bias = 0.25 and we have 10 max_trains then
  941. -- one class must have 10 or more trains whilst another should have
  942. -- at least (10 * (1 - 0.25)) = 8 trains
  943. local max_len = math.max(lua_util.unpack(lua_util.values(lens)))
  944. local len_bias_check_pred = function(_, l)
  945. return l >= rule.train.max_trains * (1.0 - rule.train.classes_bias)
  946. end
  947. if max_len >= rule.train.max_trains and fun.all(len_bias_check_pred, lens) then
  948. rspamd_logger.debugm(N, rspamd_config,
  949. 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
  950. ann_key, lens, rule.train.max_trains, what)
  951. cont_cb()
  952. else
  953. rspamd_logger.debugm(N, rspamd_config,
  954. 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
  955. ann_key, what, lens, rule.train.max_trains)
  956. end
  957. else
  958. rspamd_logger.debugm(N, rspamd_config,
  959. 'checked %s vectors in ANN %s: %s vectors; %s required, need to check other class vectors',
  960. what, ann_key, ntrains, rule.train.max_trains)
  961. cont_cb()
  962. end
  963. end
  964. end
  965. end
  966. local function initiate_train()
  967. rspamd_logger.infox(rspamd_config,
  968. 'need to learn ANN %s after %s required learn vectors',
  969. ann_key, lens)
  970. do_train_ann(worker, ev_base, rule, set, ann_key)
  971. end
  972. -- Spam vector is OK, check ham vector length
  973. local function check_ham_len()
  974. lua_redis.redis_make_request_taskless(ev_base,
  975. rspamd_config,
  976. rule.redis,
  977. nil,
  978. false, -- is write
  979. redis_len_cb_gen(initiate_train, 'ham', true), --callback
  980. 'LLEN', -- command
  981. {ann_key .. '_ham'}
  982. )
  983. end
  984. lua_redis.redis_make_request_taskless(ev_base,
  985. rspamd_config,
  986. rule.redis,
  987. nil,
  988. false, -- is write
  989. redis_len_cb_gen(check_ham_len, 'spam', false), --callback
  990. 'LLEN', -- command
  991. {ann_key .. '_spam'}
  992. )
  993. end
  994. end
  995. -- Used to deserialise ANN element from a list
  996. local function load_ann_profile(element)
  997. local ucl = require "ucl"
  998. local parser = ucl.parser()
  999. local res,ucl_err = parser:parse_string(element)
  1000. if not res then
  1001. rspamd_logger.warnx(rspamd_config, 'cannot parse ANN from redis: %s',
  1002. ucl_err)
  1003. return nil
  1004. else
  1005. local profile = parser:get_object()
  1006. local checked,schema_err = redis_profile_schema:transform(profile)
  1007. if not checked then
  1008. rspamd_logger.errx(rspamd_config, "cannot parse profile schema: %s", schema_err)
  1009. return nil
  1010. end
  1011. return checked
  1012. end
  1013. end
  1014. -- Function to check or load ANNs from Redis
  1015. local function check_anns(worker, cfg, ev_base, rule, process_callback, what)
  1016. for _,set in pairs(rule.settings) do
  1017. local function members_cb(err, data)
  1018. if err then
  1019. rspamd_logger.errx(cfg, 'cannot get ANNs list from redis: %s',
  1020. err)
  1021. set.can_store_vectors = true
  1022. elseif type(data) == 'table' then
  1023. lua_util.debugm(N, cfg, '%s: process element %s:%s',
  1024. what, rule.prefix, set.name)
  1025. process_callback(worker, ev_base, rule, set, fun.map(load_ann_profile, data))
  1026. set.can_store_vectors = true
  1027. end
  1028. end
  1029. if type(set) == 'table' then
  1030. -- Extract all profiles for some specific settings id
  1031. -- Get the last `max_profiles` recently used
  1032. -- Select the most appropriate to our profile but it should not differ by more
  1033. -- than 30% of symbols
  1034. lua_redis.redis_make_request_taskless(ev_base,
  1035. cfg,
  1036. rule.redis,
  1037. nil,
  1038. false, -- is write
  1039. members_cb, --callback
  1040. 'ZREVRANGE', -- command
  1041. {set.prefix, '0', tostring(settings.max_profiles)} -- arguments
  1042. )
  1043. end
  1044. end -- Cycle over all settings
  1045. return rule.watch_interval
  1046. end
  1047. -- Function to clean up old ANNs
  1048. local function cleanup_anns(rule, cfg, ev_base)
  1049. for _,set in pairs(rule.settings) do
  1050. local function invalidate_cb(err, data)
  1051. if err then
  1052. rspamd_logger.errx(cfg, 'cannot exec invalidate script in redis: %s',
  1053. err)
  1054. elseif type(data) == 'table' then
  1055. for _,expired in ipairs(data) do
  1056. local profile = load_ann_profile(expired)
  1057. rspamd_logger.infox(cfg, 'invalidated ANN for %s; redis key: %s; version=%s',
  1058. rule.prefix .. ':' .. set.name,
  1059. profile.redis_key,
  1060. profile.version)
  1061. end
  1062. end
  1063. end
  1064. if type(set) == 'table' then
  1065. lua_redis.exec_redis_script(redis_maybe_invalidate_id,
  1066. {ev_base = ev_base, is_write = true},
  1067. invalidate_cb,
  1068. {set.prefix, tostring(settings.max_profiles)})
  1069. end
  1070. end
  1071. end
  1072. local function ann_push_vector(task)
  1073. if task:has_flag('skip') then
  1074. lua_util.debugm(N, task, 'do not push data for skipped task')
  1075. return
  1076. end
  1077. if not settings.allow_local and lua_util.is_rspamc_or_controller(task) then
  1078. lua_util.debugm(N, task, 'do not push data for manual scan')
  1079. return
  1080. end
  1081. local verdict,score = lua_verdict.get_specific_verdict(N, task)
  1082. if verdict == 'passthrough' then
  1083. lua_util.debugm(N, task, 'ignore task as its verdict is %s(%s)',
  1084. verdict, score)
  1085. return
  1086. end
  1087. if score ~= score then
  1088. lua_util.debugm(N, task, 'ignore task as its score is nan (%s verdict)',
  1089. verdict)
  1090. return
  1091. end
  1092. for _,rule in pairs(settings.rules) do
  1093. local set = get_rule_settings(task, rule)
  1094. if set then
  1095. ann_push_task_result(rule, task, verdict, score, set)
  1096. else
  1097. lua_util.debugm(N, task, 'settings not found in rule %s', rule.prefix)
  1098. end
  1099. end
  1100. end
  1101. -- This function is used to adjust profiles and allowed setting ids for each rule
  1102. -- It must be called when all settings are already registered (e.g. at post-init for config)
  1103. local function process_rules_settings()
  1104. local function process_settings_elt(rule, selt)
  1105. local profile = rule.profile[selt.name]
  1106. if profile then
  1107. -- Use static user defined profile
  1108. -- Ensure that we have an array...
  1109. lua_util.debugm(N, rspamd_config, "use static profile for %s (%s): %s",
  1110. rule.prefix, selt.name, profile)
  1111. if not profile[1] then profile = lua_util.keys(profile) end
  1112. selt.symbols = profile
  1113. else
  1114. lua_util.debugm(N, rspamd_config, "use dynamic cfg based profile for %s (%s)",
  1115. rule.prefix, selt.name)
  1116. end
  1117. local function filter_symbols_predicate(sname)
  1118. local fl = rspamd_config:get_symbol_flags(sname)
  1119. if fl then
  1120. fl = lua_util.list_to_hash(fl)
  1121. return not (fl.nostat or fl.idempotent or fl.skip)
  1122. end
  1123. return false
  1124. end
  1125. -- Generic stuff
  1126. table.sort(fun.totable(fun.filter(filter_symbols_predicate, selt.symbols)))
  1127. selt.digest = lua_util.table_digest(selt.symbols)
  1128. selt.prefix = redis_ann_prefix(rule, selt.name)
  1129. lua_redis.register_prefix(selt.prefix, N,
  1130. string.format('NN prefix for rule "%s"; settings id "%s"',
  1131. rule.prefix, selt.name), {
  1132. persistent = true,
  1133. type = 'zlist',
  1134. })
  1135. -- Versions
  1136. lua_redis.register_prefix(selt.prefix .. '_\\d+', N,
  1137. string.format('NN storage for rule "%s"; settings id "%s"',
  1138. rule.prefix, selt.name), {
  1139. persistent = true,
  1140. type = 'hash',
  1141. })
  1142. lua_redis.register_prefix(selt.prefix .. '_\\d+_spam', N,
  1143. string.format('NN learning set (spam) for rule "%s"; settings id "%s"',
  1144. rule.prefix, selt.name), {
  1145. persistent = true,
  1146. type = 'list',
  1147. })
  1148. lua_redis.register_prefix(selt.prefix .. '_\\d+_ham', N,
  1149. string.format('NN learning set (spam) for rule "%s"; settings id "%s"',
  1150. rule.prefix, selt.name), {
  1151. persistent = true,
  1152. type = 'list',
  1153. })
  1154. end
  1155. for k,rule in pairs(settings.rules) do
  1156. if not rule.allowed_settings then
  1157. rule.allowed_settings = {}
  1158. elseif rule.allowed_settings == 'all' then
  1159. -- Extract all settings ids
  1160. rule.allowed_settings = lua_util.keys(lua_settings.all_settings())
  1161. end
  1162. -- Convert to a map <setting_id> -> true
  1163. rule.allowed_settings = lua_util.list_to_hash(rule.allowed_settings)
  1164. -- Check if we can work without settings
  1165. if k == 'default' or type(rule.default) ~= 'boolean' then
  1166. rule.default = true
  1167. end
  1168. rule.settings = {}
  1169. if rule.default then
  1170. local default_settings = {
  1171. symbols = lua_settings.default_symbols(),
  1172. name = 'default'
  1173. }
  1174. process_settings_elt(rule, default_settings)
  1175. rule.settings[-1] = default_settings -- Magic constant, but OK as settings are positive int32
  1176. end
  1177. -- Now, for each allowed settings, we store sorted symbols + digest
  1178. -- We set table rule.settings[id] -> { name = name, symbols = symbols, digest = digest }
  1179. for s,_ in pairs(rule.allowed_settings) do
  1180. -- Here, we have a name, set of symbols and
  1181. local settings_id = s
  1182. if type(settings_id) ~= 'number' then
  1183. settings_id = lua_settings.numeric_settings_id(s)
  1184. end
  1185. local selt = lua_settings.settings_by_id(settings_id)
  1186. local nelt = {
  1187. symbols = selt.symbols, -- Already sorted
  1188. name = selt.name
  1189. }
  1190. process_settings_elt(rule, nelt)
  1191. for id,ex in pairs(rule.settings) do
  1192. if type(ex) == 'table' then
  1193. if nelt and lua_util.distance_sorted(ex.symbols, nelt.symbols) == 0 then
  1194. -- Equal symbols, add reference
  1195. lua_util.debugm(N, rspamd_config,
  1196. 'added reference from settings id %s to %s; same symbols',
  1197. nelt.name, ex.name)
  1198. rule.settings[settings_id] = id
  1199. nelt = nil
  1200. end
  1201. end
  1202. end
  1203. if nelt then
  1204. rule.settings[settings_id] = nelt
  1205. lua_util.debugm(N, rspamd_config, 'added new settings id %s(%s) to %s',
  1206. nelt.name, settings_id, rule.prefix)
  1207. end
  1208. end
  1209. end
  1210. end
  1211. redis_params = lua_redis.parse_redis_server('neural')
  1212. if not redis_params then
  1213. redis_params = lua_redis.parse_redis_server('fann_redis')
  1214. end
  1215. -- Initialization part
  1216. if not (module_config and type(module_config) == 'table') or not redis_params then
  1217. rspamd_logger.infox(rspamd_config, 'Module is unconfigured')
  1218. lua_util.disable_module(N, "redis")
  1219. return
  1220. end
  1221. local rules = module_config['rules']
  1222. if not rules then
  1223. -- Use legacy configuration
  1224. rules = {}
  1225. rules['default'] = module_config
  1226. end
  1227. local id = rspamd_config:register_symbol({
  1228. name = 'NEURAL_CHECK',
  1229. type = 'postfilter,callback',
  1230. flags = 'nostat',
  1231. priority = 6,
  1232. callback = ann_scores_filter
  1233. })
  1234. settings = lua_util.override_defaults(settings, module_config)
  1235. settings.rules = {} -- Reset unless validated further in the cycle
  1236. -- Check all rules
  1237. for k,r in pairs(rules) do
  1238. local rule_elt = lua_util.override_defaults(default_options, r)
  1239. rule_elt['redis'] = redis_params
  1240. rule_elt['anns'] = {} -- Store ANNs here
  1241. if not rule_elt.prefix then
  1242. rule_elt.prefix = k
  1243. end
  1244. if not rule_elt.name then
  1245. rule_elt.name = k
  1246. end
  1247. if rule_elt.train.max_train then
  1248. rule_elt.train.max_trains = rule_elt.train.max_train
  1249. end
  1250. if not rule_elt.profile then rule_elt.profile = {} end
  1251. rspamd_logger.infox(rspamd_config, "register ann rule %s", k)
  1252. settings.rules[k] = rule_elt
  1253. rspamd_config:set_metric_symbol({
  1254. name = rule_elt.symbol_spam,
  1255. score = 0.0,
  1256. description = 'Neural network SPAM',
  1257. group = 'neural'
  1258. })
  1259. rspamd_config:register_symbol({
  1260. name = rule_elt.symbol_spam,
  1261. type = 'virtual',
  1262. flags = 'nostat',
  1263. parent = id
  1264. })
  1265. rspamd_config:set_metric_symbol({
  1266. name = rule_elt.symbol_ham,
  1267. score = -0.0,
  1268. description = 'Neural network HAM',
  1269. group = 'neural'
  1270. })
  1271. rspamd_config:register_symbol({
  1272. name = rule_elt.symbol_ham,
  1273. type = 'virtual',
  1274. flags = 'nostat',
  1275. parent = id
  1276. })
  1277. end
  1278. rspamd_config:register_symbol({
  1279. name = 'NEURAL_LEARN',
  1280. type = 'idempotent,callback',
  1281. flags = 'nostat,explicit_disable',
  1282. priority = 5,
  1283. callback = ann_push_vector
  1284. })
  1285. -- Add training scripts
  1286. for _,rule in pairs(settings.rules) do
  1287. load_scripts(rule.redis)
  1288. -- We also need to deal with settings
  1289. rspamd_config:add_post_init(process_rules_settings)
  1290. -- This function will check ANNs in Redis when a worker is loaded
  1291. rspamd_config:add_on_load(function(cfg, ev_base, worker)
  1292. if worker:is_scanner() then
  1293. rspamd_config:add_periodic(ev_base, 0.0,
  1294. function(_, _)
  1295. return check_anns(worker, cfg, ev_base, rule, process_existing_ann,
  1296. 'try_load_ann')
  1297. end)
  1298. end
  1299. if worker:is_primary_controller() then
  1300. -- We also want to train neural nets when they have enough data
  1301. rspamd_config:add_periodic(ev_base, 0.0,
  1302. function(_, _)
  1303. -- Clean old ANNs
  1304. cleanup_anns(rule, cfg, ev_base)
  1305. return check_anns(worker, cfg, ev_base, rule, maybe_train_existing_ann,
  1306. 'try_train_ann')
  1307. end)
  1308. end
  1309. end)
  1310. end