--- Lua script to train a row
+-- Lua script that checks if we can store a new training vector
-- Uses the following keys:
-- key1 - ann key
-- key2 - spam or ham
-- key3 - maximum trains
-- returns 1 or 0: 1 - allow learn, 0 - not allow learn
-local redis_lua_script_can_train = [[
+local redis_lua_script_can_store_train_vec = [[
local prefix = KEYS[1]
local locked = redis.call('HGET', prefix, 'lock')
if locked then return 0 end
return tostring(0)
-local redis_can_train_id = nil
+local redis_can_store_train_vec_id = nil
-- Lua script to invalidate ANNs by rank
-- Uses the following keys
local redis_maybe_invalidate_id = nil
--- Lua script to invalidate ANN from redis
--- Uses the following keys
--- key1 - prefix for keys
-local redis_lua_script_locked_invalidate = [[
- redis.call('SET', KEYS[1] .. '_version', '0')
- redis.call('DEL', KEYS[1] .. '_spam')
- redis.call('DEL', KEYS[1] .. '_ham')
- redis.call('DEL', KEYS[1] .. '_data')
- redis.call('DEL', KEYS[1] .. '_locked')
- redis.call('DEL', KEYS[1] .. '_hostname')
- return 1
-local redis_locked_invalidate_id = nil
-- Lua script to lock ANN for training in redis
-- Uses the following keys
-- key1 - prefix for keys
-- key2 - current time
-- key3 - key expire
-- key4 - hostname
local redis_lua_script_maybe_lock = [[
- local locked = redis.call('GET', KEYS[1] .. '_locked')
+ -- Returns {lock_time, hostname} when a non-expired lock is already held,
+ -- otherwise stores the current time and hostname in the hash and returns true
+ local locked = redis.call('HGET', KEYS[1], 'lock')
+ -- Must be `local`: Redis scripts may not create globals, and `now` is also
+ -- needed below when no lock exists yet
+ local now = tonumber(KEYS[2])
if locked then
- if tonumber(KEYS[2]) < tonumber(locked) then
- return false
+ locked = tonumber(locked)
+ local expire = tonumber(KEYS[3])
+ if now > locked and (now - locked) < expire then
+ return {tostring(locked), redis.call('HGET', KEYS[1], 'hostname')}
- redis.call('SET', KEYS[1] .. '_locked', tostring(tonumber(KEYS[2]) + tonumber(KEYS[3])))
- redis.call('SET', KEYS[1] .. '_hostname', KEYS[4])
- return 1
+ redis.call('HSET', KEYS[1], 'lock', tostring(now))
+ redis.call('HSET', KEYS[1], 'hostname', KEYS[4])
+ return true
local redis_maybe_lock_id = nil
local redis_params
local function load_scripts(params)
- redis_can_train_id = lua_redis.add_redis_script(redis_lua_script_can_train,
+ redis_can_store_train_vec_id = lua_redis.add_redis_script(redis_lua_script_can_store_train_vec,
redis_maybe_invalidate_id = lua_redis.add_redis_script(redis_lua_script_maybe_invalidate,
- redis_locked_invalidate_id = lua_redis.add_redis_script(redis_lua_script_locked_invalidate,
- params)
redis_maybe_lock_id = lua_redis.add_redis_script(redis_lua_script_maybe_lock,
redis_save_unlock_id = lua_redis.add_redis_script(redis_lua_script_save_unlock,
-- Creates and stores ANN profile in Redis
-local function new_ann_profile(task, rule, set)
+local function new_ann_profile(task, rule, set, version)
local ann_key = new_ann_key(rule, set)
local profile = {
symbols = set.symbols,
redis_key = ann_key,
- version = 0,
+ version = version or 0,
digest = set.digest,
distance = 0 -- Since we are using our own profile
- lua_redis.redis_make_request_taskless(ev_base,
+ lua_redis.redis_make_request(task,
if score > 0 then
local result = score
- task:insert_result(rule.symbol_spam, result, symscore, id)
+ task:insert_result(rule.symbol_spam, result, symscore)
local result = -(score)
- task:insert_result(rule.symbol_ham, result, symscore, id)
+ task:insert_result(rule.symbol_ham, result, symscore)
local function learn_vec_cb(err)
if err then
- rspamd_logger.errx(task, 'cannot store train vector for %s: %s', fname, err)
+ rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s',
+ rule.prefix, set.name, err)
- rspamd_logger.infox(task, "trained ANN rule %s, save %s vector, %s bytes",
- rule['name'], learn_type, vec_len)
+ rspamd_logger.infox(task, "trained ANN rule %s:%s, save %s vector",
+ rule.prefix, set.name, learn_type)
set.ann = new_ann_profile(task, rule, set)
-- Check if we can learn
- lua_redis.exec_redis_script(redis_can_train_id,
+ lua_redis.exec_redis_script(redis_can_store_train_vec_id,
{task = task, is_write = true},
{ set.ann.redis_key, learn_type, tostring(train_opts.max_trains)})
-local function train_ann(rule, _, ev_base, elt, worker)
- local spam_elts = {}
- local ham_elts = {}
- elt = tostring(elt)
- local prefix = gen_ann_prefix(rule, elt)
+--- Offline training logic
- local function redis_unlock_cb(err)
+-- Closure generator for unlock function
+local function gen_unlock_cb(rule, set, ann_key)
+ return function (err)
if err then
- rspamd_logger.errx(rspamd_config, 'cannot unlock ANN %s from redis: %s',
- prefix, err)
+ rspamd_logger.errx(rspamd_config, 'cannot unlock ANN %s:%s at %s from redis: %s',
+ rule.prefix, set.name, ann_key, err)
+ else
+ lua_util.debugm(N, rspamd_config, 'unlocked ANN %s:%s at %s',
+ rule.prefix, set.name, ann_key)
- local function redis_save_cb(err)
- if err then
- rspamd_logger.errx(rspamd_config, 'cannot save ANN %s to redis: %s',
- prefix, err)
- lua_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
- rule.redis,
- nil,
- false, -- is write
- redis_unlock_cb, --callback
- 'DEL', -- command
- {prefix .. '_locked'}
- )
- else
- rspamd_logger.infox(rspamd_config, 'saved ANN %s, key: %s_data', elt, prefix)
+-- Extends the training lock of an ANN: registers a periodic (every 30.0 seconds) on `ev_base`
+-- that sends HINCRBY `ann_key` lock 30 through `rule.redis` for as long as `set.learning_spawned`
+-- is truthy; once the flag is falsy the periodic returns false and sends no request
+local function register_lock_extender(rule, set, ev_base, ann_key)
+ rspamd_config:add_periodic(ev_base, 30.0,
+ function()
+ local function redis_lock_extend_cb(_err, _) -- only logs the outcome of the HINCRBY request
+ if _err then
+ rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s',
+ ann_key, _err)
+ else
+ rspamd_logger.infox(rspamd_config, 'extend lock for ANN %s for 30 seconds',
+ ann_key)
+ end
+ end
+ if set.learning_spawned then -- training still marked as running: bump the lock value
+ lua_redis.redis_make_request_taskless(ev_base,
+ rspamd_config,
+ rule.redis,
+ nil,
+ true, -- is write
+ redis_lock_extend_cb, --callback
+ 'HINCRBY', -- command
+ {ann_key, 'lock', '30'}
+ )
+ else
+ return false -- do not plan any more updates
+ end
+ return true -- presumably keeps the periodic scheduled (counterpart of `return false` above)
+ end
+ )
+-- Trains a new ANN from the received training vectors and saves it in Redis.
+-- Joins `spam_vec` (target 1.0) and `ham_vec` (target -1.0) into one training set, runs the
+-- training through `worker:spawn_process`, then stores the compressed result under `ann_key`
+-- with the save-and-unlock script; the `lock` field is removed with HDEL when training or saving fails.
+-- `set.learning_spawned` is raised here and cleared again once the child process completes.
+local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_vec)
+ -- Check training data sanity
+ -- Now we need to join inputs and create the appropriate test vectors
+ local n = #set.symbols +
+ meta_functions.rspamd_count_metatokens()
+ -- Now we can train ann
+ local train_ann = create_ann(n, 3)
+ if #ham_vec + #spam_vec < rule.train.max_trains / 2 then
+ -- Invalidate ANN as it is definitely invalid
+ -- TODO: add invalidation
+ assert(false)
+ else
+ local inputs, outputs = {}, {}
+ -- Make training set by joining vectors
+ -- KANN automatically shuffles those samples
+ -- 1.0 is used for spam and -1.0 is used for ham
+ -- It implies that output layer can express that (e.g. tanh output)
+ for _,e in ipairs(spam_vec) do
+ inputs[#inputs + 1] = e
+ outputs[#outputs + 1] = {1.0}
+ end
+ for _,e in ipairs(ham_vec) do
+ inputs[#inputs + 1] = e
+ outputs[#outputs + 1] = {-1.0}
- end
- local function ann_trained(err, data)
- rule.learning_spawned = false
- if err then
- rspamd_logger.errx(rspamd_config, 'cannot train ANN %s: %s',
- prefix, err)
- lua_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
- rule.redis,
- nil,
- true, -- is write
- redis_unlock_cb, --callback
- 'DEL', -- command
- {prefix .. '_locked'}
- )
- else
- rspamd_logger.infox(rspamd_config, 'trained ANN %s, %s bytes',
- prefix, #data)
- local ann_data = rspamd_util.zstd_compress(data)
- rule.anns[elt].ann_train = rspamd_kann.load(data)
- rule.anns[elt].version = rule.anns[elt].version + 1
- rule.anns[elt].ann = rule.anns[elt].ann_train
- rule.anns[elt].ann_train = nil
- lua_redis.exec_redis_script(redis_save_unlock_id,
- {ev_base = ev_base, is_write = true},
- redis_save_cb,
- {prefix, tostring(ann_data), tostring(rule.ann_expire)})
+ -- Called in child process
+ local function train()
+ train_ann:train1(inputs, outputs, {
+ lr = rule.train.learning_rate,
+ max_epoch = rule.train.max_iterations,
+ cb = function(iter, train_cost, _)
+ if math.floor(iter / rule.train.max_iterations * 10) % 10 == 0 then
+ rspamd_logger.infox(rspamd_config, "ANN %s:%s: learned %s iterations, error: %s",
+ rule.prefix, set.name,
+ iter, train_cost)
+ end
+ end
+ })
+ local out = train_ann:save()
+ return out
+ rule.learning_spawned = true
+ local function redis_save_cb(err)
+ if err then
+ rspamd_logger.errx(rspamd_config, 'cannot save ANN %s:%s to redis key %s: %s',
+ rule.prefix, set.name, ann_key, err)
+ -- HDEL modifies the hash, so it has to be issued as a write request
+ lua_redis.redis_make_request_taskless(ev_base,
+ rspamd_config,
+ rule.redis,
+ nil,
+ true, -- is write
+ gen_unlock_cb(rule, set, ann_key), --callback
+ 'HDEL', -- command
+ {ann_key, 'lock'}
+ )
+ else
+ rspamd_logger.infox(rspamd_config, 'saved ANN %s:%s to redis: %s',
+ rule.prefix, set.name, ann_key)
+ end
+ end
+ local function ann_trained(err, data)
+ rule.learning_spawned = false
+ -- Clear the per-set flag as well: it is the one raised below and polled by the
+ -- lock extender, so leaving it set would keep extending the lock forever
+ set.learning_spawned = false
+ if err then
+ rspamd_logger.errx(rspamd_config, 'cannot train ANN %s:%s : %s',
+ rule.prefix, set.name, err)
+ lua_redis.redis_make_request_taskless(ev_base,
+ rspamd_config,
+ rule.redis,
+ nil,
+ true, -- is write
+ gen_unlock_cb(rule, set, ann_key), --callback
+ 'HDEL', -- command
+ {ann_key, 'lock'}
+ )
+ else
+ rspamd_logger.infox(rspamd_config, 'trained ANN %s:%s, %s bytes',
+ rule.prefix, set.name, #data)
+ local ann_data = rspamd_util.zstd_compress(data)
+ if not set.ann then
+ set.ann = {
+ symbols = set.symbols,
+ distance = 0,
+ digest = set.digest,
+ redis_key = ann_key,
+ }
+ end
+ -- Deserialise ANN from the child process
+ -- Kept in its own local so that the `ann_trained` callback name is not overwritten
+ local loaded_ann = rspamd_kann.load(data)
+ set.ann.version = (set.ann.version or 0) + 1
+ set.ann.ann = loaded_ann
+ lua_redis.exec_redis_script(redis_save_unlock_id,
+ {ev_base = ev_base, is_write = true},
+ redis_save_cb,
+ {ann_key, tostring(ann_data), tostring(rule.ann_expire)})
+ end
+ end
+ worker:spawn_process{
+ func = train,
+ on_complete = ann_trained,
+ }
+ -- Spawn learn and register lock extension
+ set.learning_spawned = true
+ register_lock_extender(rule, set, ev_base, ann_key)
+-- Converts saved training vectors to a table of tables: each element of `data` is decompressed, split on ';' and every piece is passed through tonumber
+local function process_training_vectors(data)
+ return fun.totable(fun.map(function(tok)
+ local _,str = rspamd_util.zstd_decompress(tok) -- first returned value is ignored (presumably an error indicator, confirm)
+ return fun.totable(fun.map(tonumber, lua_util.str_split(tostring(str), ';'))) -- one table of numbers per stored vector
+ end, data))
+-- This function does the following:
+-- * Tries to lock ANN
+-- * Loads spam and ham vectors
+-- * Spawn learning process
+local function do_train_ann(worker, ev_base, rule, set, ann_key)
+ local spam_elts = {}
+ local ham_elts = {}
local function redis_ham_cb(err, data)
if err or type(data) ~= 'table' then
rspamd_logger.errx(rspamd_config, 'cannot get ham tokens for ANN %s from redis: %s',
- prefix, err)
+ ann_key, err)
+ -- Unlock on error
true, -- is write
- redis_unlock_cb, --callback
- 'DEL', -- command
- {prefix .. '_locked'}
+ gen_unlock_cb(rule, set, ann_key), --callback
+ 'HDEL', -- command
+ {ann_key, 'lock'}
-- Decompress and convert to numbers each training vector
- ham_elts = fun.totable(fun.map(function(tok)
- local _,str = rspamd_util.zstd_decompress(tok)
- return fun.totable(fun.map(tonumber, rspamd_str_split(tostring(str), ';')))
- end, data))
- -- Now we need to join inputs and create the appropriate test vectors
- local n = rspamd_config:get_symbols_count() +
- meta_functions.rspamd_count_metatokens()
- -- Now we can train ann
- if not rule.anns[elt] or not rule.anns[elt].ann_train then
- -- Create ann if it does not exist
- create_train_ann(rule, n, elt)
- end
- if #spam_elts + #ham_elts < rule.train.max_trains / 2 then
- -- Invalidate ANN as it is definitely invalid
- local function redis_invalidate_cb(_err, _data)
- if _err then
- rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', prefix, _err)
- elseif type(_data) == 'string' then
- rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', prefix, _err)
- rule.anns[elt].version = 0
- end
- end
- -- Invalidate ANN
- rspamd_logger.infox(rspamd_config, 'invalidate ANN %s: training data is invalid', prefix)
- lua_redis.exec_redis_script(redis_locked_invalidate_id,
- {ev_base = ev_base, is_write = true},
- redis_invalidate_cb,
- {prefix})
- else
- local inputs, outputs = {}, {}
- for _,e in ipairs(spam_elts) do
- if e == e then
- inputs[#inputs + 1] = e
- outputs[#outputs + 1] = {1.0}
- end
- end
- for _,e in ipairs(ham_elts) do
- if e == e then
- inputs[#inputs + 1] = e
- outputs[#outputs + 1] = {0.0}
- end
- end
- local function train()
- rule.anns[elt].ann_train:train1(inputs, outputs, {
- lr = rule.train.learning_rate,
- max_epoch = rule.train.max_iterations,
- cb = function(iter, train_cost, _)
- if math.floor(iter / rule.train.max_iterations * 10) % 10 == 0 then
- rspamd_logger.infox(rspamd_config, "learned %s iterations, error: %s",
- iter, train_cost)
- end
- end
- })
- local out = rule.anns[elt].ann_train:save()
- return out
- end
- rule.learning_spawned = true
- worker:spawn_process{
- func = train,
- on_complete = ann_trained,
- }
- end
+ ham_elts = process_training_vectors(data)
+ spawn_train(worker, ev_base, rule, set, ann_key, ham_elts, spam_elts)
+ -- Spam vectors received
local function redis_spam_cb(err, data)
if err or type(data) ~= 'table' then
rspamd_logger.errx(rspamd_config, 'cannot get spam tokens for ANN %s from redis: %s',
- prefix, err)
+ ann_key, err)
+ -- Unlock ANN on error
true, -- is write
- redis_unlock_cb, --callback
- 'DEL', -- command
- {prefix .. '_locked'}
+ gen_unlock_cb(rule, set, ann_key), --callback
+ 'HDEL', -- command
+ {ann_key, 'lock'}
-- Decompress and convert to numbers each training vector
- spam_elts = fun.totable(fun.map(function(tok)
- local _,str = rspamd_util.zstd_decompress(tok)
- return fun.totable(fun.map(tonumber, rspamd_str_split(tostring(str), ';')))
- end, data))
+ spam_elts = process_training_vectors(data)
+ -- Now get ham vectors...
false, -- is write
redis_ham_cb, --callback
'LRANGE', -- command
- {prefix .. '_ham', '0', '-1'}
+ {ann_key .. '_ham', '0', '-1'}
local function redis_lock_cb(err, data)
if err then
- rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s',
- prefix, err)
- elseif type(data) == 'number' then
- -- Can train ANN
+ rspamd_logger.errx(rspamd_config, 'cannot call lock script for ANN %s from redis: %s',
+ ann_key, err)
+ elseif type(data) == 'boolean' and data then
+ -- ANN is locked, so we can extract SPAM and HAM vectors and spawn learning
false, -- is write
redis_spam_cb, --callback
'LRANGE', -- command
- {prefix .. '_spam', '0', '-1'}
+ {ann_key .. '_spam', '0', '-1'}
- rspamd_config:add_periodic(ev_base, 30.0,
- function(_, _)
- local function redis_lock_extend_cb(_err, _)
- if _err then
- rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s',
- prefix, _err)
- else
- rspamd_logger.infox(rspamd_config, 'extend lock for ANN %s for 30 seconds',
- prefix)
- end
- end
- if rule.learning_spawned then
- lua_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
- rule.redis,
- nil,
- true, -- is write
- redis_lock_extend_cb, --callback
- 'INCRBY', -- command
- {prefix .. '_locked', '30'}
- )
- else
- return false -- do not plan any more updates
- end
- return true
- end
- )
- rspamd_logger.infox(rspamd_config, 'lock ANN %s for learning', prefix)
+ rspamd_logger.infox(rspamd_config, 'lock ANN %s:%s (key name %s) for learning',
+ rule.prefix, set.name, ann_key)
- rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, locked by another process', prefix)
+ local lock_tm = tonumber(data[1])
+ rspamd_logger.infox(rspamd_config, 'do not learn ANN %s:%s (key name %s), ' ..
+ 'locked by another host %s at %s', rule.prefix, set.name, ann_key,
+ data[2], os.date('%c', lock_tm))
- if rule.learning_spawned then
- rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, already learning another ANN', prefix)
+ -- Check if we are already learning this network
+ if set.learning_spawned then
+ rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, already learning another ANN',
+ ann_key)
+ -- Call Redis script that tries to acquire a lock
+ -- This script returns either a boolean or a pair {'lock_time', 'hostname'} when
+ -- ANN is locked by another host (or a process, meh)
{ev_base = ev_base, is_write = true},
- {prefix, tostring(os.time()), tostring(rule.lock_expire), rspamd_util.get_hostname()})
-local function maybe_train_anns(rule, cfg, ev_base, worker)
- local function members_cb(err, data)
- if err then
- rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', err)
- elseif type(data) == 'table' then
- fun.each(function(elt)
- elt = tostring(elt)
- local prefix = gen_ann_prefix(rule, elt)
- rspamd_logger.infox(cfg, "check ANN %s", prefix)
- local redis_len_cb = function(_err, _data)
- if _err then
- rspamd_logger.errx(rspamd_config,
- 'cannot get FANN trains %s from redis: %s', prefix, _err)
- elseif _data and type(_data) == 'number' or type(_data) == 'string' then
- if tonumber(_data) and tonumber(_data) >= rule.train.max_trains then
- rspamd_logger.infox(rspamd_config,
- 'need to learn ANN %s after %s learn vectors (%s required)',
- prefix, tonumber(_data), rule.train.max_trains)
- train_ann(rule, cfg, ev_base, elt, worker)
- else
- rspamd_logger.infox(rspamd_config,
- 'no need to learn ANN %s %s learn vectors (%s required)',
- prefix, tonumber(_data), rule.train.max_trains)
- end
- end
- end
- lua_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
- rule.redis,
- nil,
- false, -- is write
- redis_len_cb, --callback
- 'LLEN', -- command
- {prefix .. '_spam'}
- )
- end,
- data)
- end
- end
- -- First we need to get all anns stored in our Redis
- lua_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
- rule.redis,
- nil,
- false, -- is write
- members_cb, --callback
- 'SMEMBERS', -- command
- {gen_ann_prefix(rule, nil)} -- arguments
- )
- return rule.watch_interval
+ {
+ ann_key,
+ tostring(os.time()),
+ tostring(rule.watch_interval * 2),
+ rspamd_util.get_hostname()
+ })
-- This function loads new ann from Redis
-- for some specific rule + some specific setting
-- This function tries to load more fresh or more specific ANNs in lieu of
-- the existing ones.
-local function process_existing_ann(rule, ev_base, set, profiles)
+-- Use this function to load ANNs as `callback` parameter for `check_anns` function
+local function process_existing_ann(_, ev_base, rule, set, profiles)
local my_symbols = set.symbols
local min_diff = math.huge
local sel_elt
+-- Training callback for `check_anns`: picks the first profile whose symbols are at distance 0
+-- from `set.symbols` (lua_util.distance_sorted), i.e. a profile with exactly the same symbols,
+-- then asks Redis for the length of its `_spam` list and calls `do_train_ann` once it reaches `rule.train.max_trains`
+local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
+ local my_symbols = set.symbols
+ local sel_elt
+ for _,elt in fun.iter(profiles) do
+ if elt and elt.symbols then
+ local dist = lua_util.distance_sorted(elt.symbols, my_symbols)
+ -- Check distance
+ if dist == 0 then
+ sel_elt = elt
+ break -- the first exact match is used
+ end
+ end
+ end
+ if sel_elt then
+ -- We have a matching ANN profile, check its stored training vectors to decide whether to learn
+ local ann_key = sel_elt.redis_key
+ lua_util.debugm(N, rspamd_config, "check ANN %s", ann_key)
+ local redis_len_cb = function(err, data)
+ if err then
+ rspamd_logger.errx(rspamd_config,
+ 'cannot get FANN trains %s from redis: %s', ann_key, err)
+ elseif data and type(data) == 'number' or type(data) == 'string' then -- accept the length as a number or as a string
+ if tonumber(data) and tonumber(data) >= rule.train.max_trains then
+ rspamd_logger.infox(rspamd_config,
+ 'need to learn ANN %s after %s learn vectors (%s required)',
+ ann_key, tonumber(data), rule.train.max_trains)
+ do_train_ann(worker, ev_base, rule, set, ann_key)
+ else
+ rspamd_logger.debugm(N, rspamd_config,
+ 'no need to learn ANN %s %s learn vectors (%s required)',
+ ann_key, tonumber(data), rule.train.max_trains)
+ end
+ end
+ end
+ lua_redis.redis_make_request_taskless(ev_base,
+ rspamd_config,
+ rule.redis,
+ nil,
+ false, -- is write
+ redis_len_cb, --callback
+ 'LLEN', -- command
+ {ann_key .. '_spam'} -- only the spam list length is queried here
+ )
+ end
-- Used to deserialise ANN element from a list
local function load_ann_profile(element)
local ucl = require "ucl"
-- Function to check or load ANNs from Redis
-local function check_anns(rule, cfg, ev_base)
+local function check_anns(worker, rule, cfg, ev_base, process_callback)
for _,set in pairs(rule.settings) do
local function members_cb(err, data)
if err then
rspamd_logger.errx(cfg, 'cannot get ANNs list from redis: %s',
elseif type(data) == 'table' then
- process_existing_ann(rule, ev_base, set, fun.map(load_ann_profile, data))
+ process_callback(worker, ev_base, rule, set, fun.map(load_ann_profile, data))
rspamd_config:add_on_load(function(cfg, ev_base, worker)
rspamd_config:add_periodic(ev_base, 0.0,
function(_, _)
- return check_anns(rule, cfg, ev_base)
+ -- Argument order follows `check_anns(worker, rule, cfg, ev_base, process_callback)`;
+ -- `process_existing_ann` makes this periodic load ANNs
+ return check_anns(worker, rule, cfg, ev_base, process_existing_ann)
if worker:is_primary_controller() then
function(_, _)
-- Clean old ANNs
cleanup_anns(rule, cfg, ev_base)
- return maybe_train_anns(rule, cfg, ev_base, worker)
+ -- Same signature as above; `maybe_train_existing_ann` makes this periodic train ANNs
+ return check_anns(worker, rule, cfg, ev_base, maybe_train_existing_ann)