From 4b8c9f7b9c9f399a50665e0aaa28fd67dff3d7eb Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Sat, 6 Jul 2019 20:52:40 +0100 Subject: [PATCH] [Project] Add training and saving ANN logic --- src/plugins/lua/neural.lua | 566 +++++++++++++++++++------------------ 1 file changed, 298 insertions(+), 268 deletions(-) diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua index c2ffb3e15..4b4cc7354 100644 --- a/src/plugins/lua/neural.lua +++ b/src/plugins/lua/neural.lua @@ -83,13 +83,13 @@ if not opts then end --- Lua script to train a row +-- Lua script that checks if we can store a new training vector -- Uses the following keys: -- key1 - ann key -- key2 - spam or ham -- key3 - maximum trains -- returns 1 or 0: 1 - allow learn, 0 - not allow learn -local redis_lua_script_can_train = [[ +local redis_lua_script_can_store_train_vec = [[ local prefix = KEYS[1] local locked = redis.call('HGET', prefix, 'lock') if locked then return 0 end @@ -119,7 +119,7 @@ local redis_lua_script_can_train = [[ return tostring(0) ]] -local redis_can_train_id = nil +local redis_can_store_train_vec_id = nil -- Lua script to invalidate ANNs by rank -- Uses the following keys @@ -144,20 +144,6 @@ local redis_lua_script_maybe_invalidate = [[ ]] local redis_maybe_invalidate_id = nil --- Lua script to invalidate ANN from redis --- Uses the following keys --- key1 - prefix for keys -local redis_lua_script_locked_invalidate = [[ - redis.call('SET', KEYS[1] .. '_version', '0') - redis.call('DEL', KEYS[1] .. '_spam') - redis.call('DEL', KEYS[1] .. '_ham') - redis.call('DEL', KEYS[1] .. '_data') - redis.call('DEL', KEYS[1] .. '_locked') - redis.call('DEL', KEYS[1] .. '_hostname') - return 1 -]] -local redis_locked_invalidate_id = nil - -- Lua script to invalidate ANN from redis -- Uses the following keys -- key1 - prefix for keys @@ -165,15 +151,18 @@ local redis_locked_invalidate_id = nil -- key3 - key expire -- key4 - hostname local redis_lua_script_maybe_lock = [[ - local locked = redis.call('GET', KEYS[1] .. '_locked') + local locked = redis.call('HGET', KEYS[1], 'lock') if locked then - if tonumber(KEYS[2]) < tonumber(locked) then - return false + locked = tonumber(locked) + now = tonumber(KEYS[2]) + expire = tonumber(KEYS[3]) + if now > locked and (now - locked) < expire then + return {tostring(locked), redis.call('HGET', KEYS[1], 'hostname')} end end - redis.call('SET', KEYS[1] .. '_locked', tostring(tonumber(KEYS[2]) + tonumber(KEYS[3]))) - redis.call('SET', KEYS[1] .. '_hostname', KEYS[4]) - return 1 + redis.call('HSET', KEYS[1], 'lock', tostring(now)) + redis.call('HSET', KEYS[1], 'hostname', KEYS[4]) + return true ]] local redis_maybe_lock_id = nil @@ -198,12 +187,10 @@ local redis_save_unlock_id = nil local redis_params local function load_scripts(params) - redis_can_train_id = lua_redis.add_redis_script(redis_lua_script_can_train, + redis_can_store_train_vec_id = lua_redis.add_redis_script(redis_lua_script_can_store_train_vec, params) redis_maybe_invalidate_id = lua_redis.add_redis_script(redis_lua_script_maybe_invalidate, params) - redis_locked_invalidate_id = lua_redis.add_redis_script(redis_lua_script_locked_invalidate, - params) redis_maybe_lock_id = lua_redis.add_redis_script(redis_lua_script_maybe_lock, params) redis_save_unlock_id = lua_redis.add_redis_script(redis_lua_script_save_unlock, @@ -244,14 +231,13 @@ local function new_ann_key(rule, set) end -- Creates and stores ANN profile in Redis -local function new_ann_profile(task, rule, set) +local function new_ann_profile(task, rule, set, version) local ann_key = new_ann_key(rule, set) - local profile = { symbols = set.symbols, redis_key = ann_key, - version = 0, + version = version or 0, digest = set.digest, distance = 0 -- Since we are using our own profile } @@ -269,7 +255,7 @@ local function new_ann_profile(task, rule, set) end end - lua_redis.redis_make_request_taskless(ev_base, + lua_redis.redis_make_request(task, rspamd_config, rule.redis, nil, @@ -335,10 +321,10 @@ local function ann_scores_filter(task) if score > 0 then local result = score - task:insert_result(rule.symbol_spam, result, symscore, id) + task:insert_result(rule.symbol_spam, result, symscore) else local result = -(score) - task:insert_result(rule.symbol_ham, result, symscore, id) + task:insert_result(rule.symbol_ham, result, symscore) end end end @@ -391,10 +377,11 @@ local function ann_train_callback(rule, task, score, required_score, set) local function learn_vec_cb(err) if err then - rspamd_logger.errx(task, 'cannot store train vector for %s: %s', fname, err) + rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s', + rule.prefix, set.name, err) else - rspamd_logger.infox(task, "trained ANN rule %s, save %s vector, %s bytes", - rule['name'], learn_type, vec_len) + rspamd_logger.infox(task, "trained ANN rule %s:%s, save %s vector", + rule.prefix, set.name, learn_type) end end @@ -433,181 +420,236 @@ local function ann_train_callback(rule, task, score, required_score, set) set.ann = new_ann_profile(task, rule, set) end -- Check if we can learn - lua_redis.exec_redis_script(redis_can_train_id, + lua_redis.exec_redis_script(redis_can_store_train_vec_id, {task = task, is_write = true}, can_train_cb, { set.ann.redis_key, learn_type, tostring(train_opts.max_trains)}) end end -local function train_ann(rule, _, ev_base, elt, worker) - local spam_elts = {} - local ham_elts = {} - elt = tostring(elt) - local prefix = gen_ann_prefix(rule, elt) +--- Offline training logic - local function redis_unlock_cb(err) +-- Closure generator for unlock function +local function gen_unlock_cb(rule, set, ann_key) + return function (err) if err then - rspamd_logger.errx(rspamd_config, 'cannot unlock ANN %s from redis: %s', - prefix, err) + rspamd_logger.errx(rspamd_config, 'cannot unlock ANN %s:%s at %s from redis: %s', + rule.prefix, set.name, ann_key, err) + else + lua_util.debugm(N, rspamd_config, 'unlocked ANN %s:%s at %s', + rule.prefix, set.name, ann_key) end end +end - local function redis_save_cb(err) - if err then - rspamd_logger.errx(rspamd_config, 'cannot save ANN %s to redis: %s', - prefix, err) - lua_redis.redis_make_request_taskless(ev_base, - rspamd_config, - rule.redis, - nil, - false, -- is write - redis_unlock_cb, --callback - 'DEL', -- command - {prefix .. '_locked'} - ) - else - rspamd_logger.infox(rspamd_config, 'saved ANN %s, key: %s_data', elt, prefix) +-- This function is intended to extend lock for ANN during training +-- It registers periodic that increases locked key each 30 seconds unless +-- `set.learning_spawned` is set to `true` +local function register_lock_extender(rule, set, ev_base, ann_key) + rspamd_config:add_periodic(ev_base, 30.0, + function() + local function redis_lock_extend_cb(_err, _) + if _err then + rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s', + ann_key, _err) + else + rspamd_logger.infox(rspamd_config, 'extend lock for ANN %s for 30 seconds', + ann_key) + end + end + + if set.learning_spawned then + lua_redis.redis_make_request_taskless(ev_base, + rspamd_config, + rule.redis, + nil, + true, -- is write + redis_lock_extend_cb, --callback + 'HINCRBY', -- command + {ann_key, 'lock', '30'} + ) + else + return false -- do not plan any more updates + end + + return true + end + ) +end + +-- This function receives training vectors, checks them, spawn learning and saves ANN in Redis +local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_vec) + -- Check training data sanity + -- Now we need to join inputs and create the appropriate test vectors + local n = #set.symbols + + meta_functions.rspamd_count_metatokens() + + -- Now we can train ann + local train_ann = create_ann(n, 3) + + if #ham_vec + #spam_vec < rule.train.max_trains / 2 then + -- Invalidate ANN as it is definitely invalid + -- TODO: add invalidation + assert(false) + else + local inputs, outputs = {}, {} + + -- Make training set by joining vectors + -- KANN automatically shuffles those samples + -- 1.0 is used for spam and -1.0 is used for ham + -- It implies that output layer can express that (e.g. tanh output) + for _,e in ipairs(spam_vec) do + inputs[#inputs + 1] = e + outputs[#outputs + 1] = {1.0} + end + for _,e in ipairs(ham_vec) do + inputs[#inputs + 1] = e + outputs[#outputs + 1] = {-1.0} end - end - local function ann_trained(err, data) - rule.learning_spawned = false - if err then - rspamd_logger.errx(rspamd_config, 'cannot train ANN %s: %s', - prefix, err) - lua_redis.redis_make_request_taskless(ev_base, - rspamd_config, - rule.redis, - nil, - true, -- is write - redis_unlock_cb, --callback - 'DEL', -- command - {prefix .. '_locked'} - ) - else - rspamd_logger.infox(rspamd_config, 'trained ANN %s, %s bytes', - prefix, #data) - local ann_data = rspamd_util.zstd_compress(data) - rule.anns[elt].ann_train = rspamd_kann.load(data) - rule.anns[elt].version = rule.anns[elt].version + 1 - rule.anns[elt].ann = rule.anns[elt].ann_train - rule.anns[elt].ann_train = nil - lua_redis.exec_redis_script(redis_save_unlock_id, - {ev_base = ev_base, is_write = true}, - redis_save_cb, - {prefix, tostring(ann_data), tostring(rule.ann_expire)}) + -- Called in child process + local function train() + train_ann:train1(inputs, outputs, { + lr = rule.train.learning_rate, + max_epoch = rule.train.max_iterations, + cb = function(iter, train_cost, _) + if math.floor(iter / rule.train.max_iterations * 10) % 10 == 0 then + rspamd_logger.infox(rspamd_config, "ANN %s:%s: learned %s iterations, error: %s", + rule.prefix, set.name, + iter, train_cost) + end + end + }) + + local out = train_ann:save() + return out end + + rule.learning_spawned = true + + local function redis_save_cb(err) + if err then + rspamd_logger.errx(rspamd_config, 'cannot save ANN %s:%s to redis key %s: %s', + rule.prefix, set.name, ann_key, err) + lua_redis.redis_make_request_taskless(ev_base, + rspamd_config, + rule.redis, + nil, + false, -- is write + gen_unlock_cb(rule, set, ann_key), --callback + 'HDEL', -- command + {ann_key, 'lock'} + ) + else + rspamd_logger.infox(rspamd_config, 'saved ANN %s:%s to redis: %s', + rule.prefix, set.name, ann_key) + end + end + + local function ann_trained(err, data) + rule.learning_spawned = false + if err then + rspamd_logger.errx(rspamd_config, 'cannot train ANN %s:%s : %s', + rule.prefix, set.name, err) + lua_redis.redis_make_request_taskless(ev_base, + rspamd_config, + rule.redis, + nil, + true, -- is write + gen_unlock_cb(rule, set, ann_key), --callback + 'HDEL', -- command + {ann_key, 'lock'} + ) + else + rspamd_logger.infox(rspamd_config, 'trained ANN %s:%s, %s bytes', + rule.prefix, set.name, #data) + local ann_data = rspamd_util.zstd_compress(data) + if not set.ann then + set.ann = { + symbols = set.symbols, + distance = 0, + digest = set.digest, + redis_key = ann_key, + } + end + -- Deserialise ANN from the child process + ann_trained = rspamd_kann.load(data) + set.ann.version = (set.ann.version or 0) + 1 + set.ann.ann = ann_trained + + lua_redis.exec_redis_script(redis_save_unlock_id, + {ev_base = ev_base, is_write = true}, + redis_save_cb, + {ann_key, tostring(ann_data), tostring(rule.ann_expire)}) + end + end + + worker:spawn_process{ + func = train, + on_complete = ann_trained, + } end + -- Spawn learn and register lock extension + set.learning_spawned = true + register_lock_extender(rule, set, ev_base, ann_key) +end + +-- Utility to extract and split saved training vectors to a table of tables +local function process_training_vectors(data) + return fun.totable(fun.map(function(tok) + local _,str = rspamd_util.zstd_decompress(tok) + return fun.totable(fun.map(tonumber, lua_util.str_split(tostring(str), ';'))) + end, data)) +end + +-- This function does the following: +-- * Tries to lock ANN +-- * Loads spam and ham vectors +-- * Spawn learning process +local function do_train_ann(worker, ev_base, rule, set, ann_key) + local spam_elts = {} + local ham_elts = {} local function redis_ham_cb(err, data) if err or type(data) ~= 'table' then rspamd_logger.errx(rspamd_config, 'cannot get ham tokens for ANN %s from redis: %s', - prefix, err) + ann_key, err) + -- Unlock on error lua_redis.redis_make_request_taskless(ev_base, rspamd_config, rule.redis, nil, true, -- is write - redis_unlock_cb, --callback - 'DEL', -- command - {prefix .. '_locked'} + gen_unlock_cb(rule, set, ann_key), --callback + 'HDEL', -- command + {ann_key, 'lock'} ) else -- Decompress and convert to numbers each training vector - ham_elts = fun.totable(fun.map(function(tok) - local _,str = rspamd_util.zstd_decompress(tok) - return fun.totable(fun.map(tonumber, rspamd_str_split(tostring(str), ';'))) - end, data)) - - -- Now we need to join inputs and create the appropriate test vectors - local n = rspamd_config:get_symbols_count() + - meta_functions.rspamd_count_metatokens() - - -- Now we can train ann - if not rule.anns[elt] or not rule.anns[elt].ann_train then - -- Create ann if it does not exist - create_train_ann(rule, n, elt) - end - - if #spam_elts + #ham_elts < rule.train.max_trains / 2 then - -- Invalidate ANN as it is definitely invalid - local function redis_invalidate_cb(_err, _data) - if _err then - rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', prefix, _err) - elseif type(_data) == 'string' then - rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', prefix, _err) - rule.anns[elt].version = 0 - end - end - -- Invalidate ANN - rspamd_logger.infox(rspamd_config, 'invalidate ANN %s: training data is invalid', prefix) - lua_redis.exec_redis_script(redis_locked_invalidate_id, - {ev_base = ev_base, is_write = true}, - redis_invalidate_cb, - {prefix}) - else - local inputs, outputs = {}, {} - - for _,e in ipairs(spam_elts) do - if e == e then - inputs[#inputs + 1] = e - outputs[#outputs + 1] = {1.0} - end - end - for _,e in ipairs(ham_elts) do - if e == e then - inputs[#inputs + 1] = e - outputs[#outputs + 1] = {0.0} - end - end - - - local function train() - rule.anns[elt].ann_train:train1(inputs, outputs, { - lr = rule.train.learning_rate, - max_epoch = rule.train.max_iterations, - cb = function(iter, train_cost, _) - if math.floor(iter / rule.train.max_iterations * 10) % 10 == 0 then - rspamd_logger.infox(rspamd_config, "learned %s iterations, error: %s", - iter, train_cost) - end - end - }) - - local out = rule.anns[elt].ann_train:save() - return out - end - - rule.learning_spawned = true - - worker:spawn_process{ - func = train, - on_complete = ann_trained, - } - end + ham_elts = process_training_vectors(data) + spawn_train(worker, ev_base, rule, set, ann_key, ham_elts, spam_elts) end end + -- Spam vectors received local function redis_spam_cb(err, data) if err or type(data) ~= 'table' then rspamd_logger.errx(rspamd_config, 'cannot get spam tokens for ANN %s from redis: %s', - prefix, err) + ann_key, err) + -- Unlock ANN on error lua_redis.redis_make_request_taskless(ev_base, rspamd_config, rule.redis, nil, true, -- is write - redis_unlock_cb, --callback - 'DEL', -- command - {prefix .. '_locked'} + gen_unlock_cb(rule, set, ann_key), --callback + 'HDEL', -- command + {ann_key, 'lock'} ) else -- Decompress and convert to numbers each training vector - spam_elts = fun.totable(fun.map(function(tok) - local _,str = rspamd_util.zstd_decompress(tok) - return fun.totable(fun.map(tonumber, rspamd_str_split(tostring(str), ';'))) - end, data)) + spam_elts = process_training_vectors(data) + -- Now get ham vectors... lua_redis.redis_make_request_taskless(ev_base, rspamd_config, rule.redis, @@ -615,17 +657,17 @@ local function train_ann(rule, _, ev_base, elt, worker) false, -- is write redis_ham_cb, --callback 'LRANGE', -- command - {prefix .. '_ham', '0', '-1'} + {ann_key .. '_ham', '0', '-1'} ) end end local function redis_lock_cb(err, data) if err then - rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s', - prefix, err) - elseif type(data) == 'number' then - -- Can train ANN + rspamd_logger.errx(rspamd_config, 'cannot call lock script for ANN %s from redis: %s', + ann_key, err) + elseif type(data) == 'boolean' and data then + -- ANN is locked, so we can extract SPAM and HAM vectors and spawn learning lua_redis.redis_make_request_taskless(ev_base, rspamd_config, rule.redis, @@ -633,105 +675,38 @@ local function train_ann(rule, _, ev_base, elt, worker) false, -- is write redis_spam_cb, --callback 'LRANGE', -- command - {prefix .. '_spam', '0', '-1'} + {ann_key .. '_spam', '0', '-1'} ) - rspamd_config:add_periodic(ev_base, 30.0, - function(_, _) - local function redis_lock_extend_cb(_err, _) - if _err then - rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s', - prefix, _err) - else - rspamd_logger.infox(rspamd_config, 'extend lock for ANN %s for 30 seconds', - prefix) - end - end - if rule.learning_spawned then - lua_redis.redis_make_request_taskless(ev_base, - rspamd_config, - rule.redis, - nil, - true, -- is write - redis_lock_extend_cb, --callback - 'INCRBY', -- command - {prefix .. '_locked', '30'} - ) - else - return false -- do not plan any more updates - end - - return true - end - ) - rspamd_logger.infox(rspamd_config, 'lock ANN %s for learning', prefix) + rspamd_logger.infox(rspamd_config, 'lock ANN %s:%s (key name %s) for learning', + rule.prefix, set.name, ann_key) else - rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, locked by another process', prefix) + local lock_tm = tonumber(data[1]) + rspamd_logger.infox(rspamd_config, 'do not learn ANN %s:%s (key name %s), ' .. + 'locked by another host %s at %s', rule.prefix, set.name, ann_key, + data[2], os.date('%c', lock_tm)) end end - if rule.learning_spawned then - rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, already learning another ANN', prefix) + + -- Check if we are already learning this network + if set.learning_spawned then + rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, already learning another ANN', + ann_key) return end + + -- Call Redis script that tries to acquire a lock + -- This script returns either a boolean or a pair {'lock_time', 'hostname'} when + -- ANN is locked by another host (or a process, meh) lua_redis.exec_redis_script(redis_maybe_lock_id, {ev_base = ev_base, is_write = true}, redis_lock_cb, - {prefix, tostring(os.time()), tostring(rule.lock_expire), rspamd_util.get_hostname()}) -end - -local function maybe_train_anns(rule, cfg, ev_base, worker) - local function members_cb(err, data) - if err then - rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', err) - elseif type(data) == 'table' then - fun.each(function(elt) - elt = tostring(elt) - local prefix = gen_ann_prefix(rule, elt) - rspamd_logger.infox(cfg, "check ANN %s", prefix) - local redis_len_cb = function(_err, _data) - if _err then - rspamd_logger.errx(rspamd_config, - 'cannot get FANN trains %s from redis: %s', prefix, _err) - elseif _data and type(_data) == 'number' or type(_data) == 'string' then - if tonumber(_data) and tonumber(_data) >= rule.train.max_trains then - rspamd_logger.infox(rspamd_config, - 'need to learn ANN %s after %s learn vectors (%s required)', - prefix, tonumber(_data), rule.train.max_trains) - train_ann(rule, cfg, ev_base, elt, worker) - else - rspamd_logger.infox(rspamd_config, - 'no need to learn ANN %s %s learn vectors (%s required)', - prefix, tonumber(_data), rule.train.max_trains) - end - end - end - - lua_redis.redis_make_request_taskless(ev_base, - rspamd_config, - rule.redis, - nil, - false, -- is write - redis_len_cb, --callback - 'LLEN', -- command - {prefix .. '_spam'} - ) - end, - data) - end - end - - -- First we need to get all anns stored in our Redis - lua_redis.redis_make_request_taskless(ev_base, - rspamd_config, - rule.redis, - nil, - false, -- is write - members_cb, --callback - 'SMEMBERS', -- command - {gen_ann_prefix(rule, nil)} -- arguments - ) - - return rule.watch_interval + { + ann_key, + tostring(os.time()), + tostring(rule.watch_interval * 2), + rspamd_util.get_hostname() + }) end -- This function loads new ann from Redis @@ -808,7 +783,8 @@ end -- for some specific rule + some specific setting -- This function tries to load more fresh or more specific ANNs in lieu of -- the existing ones. -local function process_existing_ann(rule, ev_base, set, profiles) +-- Use this function to load ANNs as `callback` parameter for `check_anns` function +local function process_existing_ann(_, ev_base, rule, set, profiles) local my_symbols = set.symbols local min_diff = math.huge local sel_elt @@ -872,6 +848,60 @@ local function process_existing_ann(rule, ev_base, set, profiles) end end + +-- This function checks all profiles and selects if we can train our +-- ANN. By our we mean that it has exactly the same symbols in profile. +-- Use this function to train ANN as `callback` parameter for `check_anns` function +local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles) + local my_symbols = set.symbols + local sel_elt + + for _,elt in fun.iter(profiles) do + if elt and elt.symbols then + local dist = lua_util.distance_sorted(elt.symbols, my_symbols) + -- Check distance + if dist == 0 then + sel_elt = elt + break + end + end + end + + if sel_elt then + -- We have our ANN and that's train vectors, check if we can learn + local ann_key = sel_elt.redis_key + + lua_util.debugm(N, rspamd_config, "check ANN %s", ann_key) + local redis_len_cb = function(err, data) + if err then + rspamd_logger.errx(rspamd_config, + 'cannot get FANN trains %s from redis: %s', ann_key, err) + elseif data and type(data) == 'number' or type(data) == 'string' then + if tonumber(data) and tonumber(data) >= rule.train.max_trains then + rspamd_logger.infox(rspamd_config, + 'need to learn ANN %s after %s learn vectors (%s required)', + ann_key, tonumber(data), rule.train.max_trains) + do_train_ann(worker, ev_base, rule, set, ann_key) + else + rspamd_logger.debugm(N, rspamd_config, + 'no need to learn ANN %s %s learn vectors (%s required)', + ann_key, tonumber(data), rule.train.max_trains) + end + end + end + + lua_redis.redis_make_request_taskless(ev_base, + rspamd_config, + rule.redis, + nil, + false, -- is write + redis_len_cb, --callback + 'LLEN', -- command + {ann_key .. '_spam'} + ) + end +end + -- Used to deserialise ANN element from a list local function load_ann_profile(element) local ucl = require "ucl" @@ -888,14 +918,14 @@ local function load_ann_profile(element) end -- Function to check or load ANNs from Redis -local function check_anns(rule, cfg, ev_base) +local function check_anns(worker, rule, cfg, ev_base, process_callback) for _,set in pairs(rule.settings) do local function members_cb(err, data) if err then rspamd_logger.errx(cfg, 'cannot get ANNs list from redis: %s', err) elseif type(data) == 'table' then - process_existing_ann(rule, ev_base, set, fun.map(load_ann_profile, data)) + process_callback(worker, ev_base, rule, set, fun.map(load_ann_profile, data)) end end @@ -1119,7 +1149,7 @@ for _,rule in pairs(settings.rules) do rspamd_config:add_on_load(function(cfg, ev_base, worker) rspamd_config:add_periodic(ev_base, 0.0, function(_, _) - return check_anns(rule, cfg, ev_base) + return check_anns(worker, cfg, ev_base, rule, process_existing_ann) end) if worker:is_primary_controller() then @@ -1128,7 +1158,7 @@ for _,rule in pairs(settings.rules) do function(_, _) -- Clean old ANNs cleanup_anns(rule, cfg, ev_base) - return maybe_train_anns(rule, cfg, ev_base, worker) + return check_anns(worker, cfg, ev_base, rule, maybe_train_existing_ann) end) end end) -- 2.39.5