]]
local redis_maybe_invalidate_sha = nil
+-- Lua script to invalidate ANN from redis
+-- Uses the following keys
+-- key1 - prefix for keys
+local redis_lua_script_maybe_lock = [[
+ local locked = redis.call('GET', KEYS[1] .. '_locked')
+ if locked then return false end
+ return 1
+]]
+local redis_maybe_lock_sha = nil
+
local redis_params
redis_params = rspamd_parse_redis_server('fann_redis')
end
local function train_fann(cfg, ev_base, elt)
+ local spam_elts = {}
+ local ham_elts = {}
+
+ local function redis_unlock_cb(err, data)
+ if err then
+ rspamd_logger.errx(rspamd_config, 'cannot unlock ANN %s from redis: %s',
+ fann_prefix .. elt, err)
+ end
+ end
+
+ local function redis_ham_cb(err, data)
+ if err or type(data) ~= 'table' then
+ rspamd_logger.errx(rspamd_config, 'cannot get ham tokens for ANN %s from redis: %s',
+ fann_prefix .. elt, err)
+ redis_make_request(ev_base,
+ rspamd_config,
+ nil,
+ false, -- is write
+ redis_unlock_cb, --callback
+ 'DEL', -- command
+ {fann_prefix .. elt .. '_lock'}
+ )
+ else
+ -- Decompress and convert to numbers each training vector
+ ham_elts = map(function(i, elt)
+ local str = tostring(rspamd_util.zstd_decompress(elt))
+ return map(tonumber, rspamd_str_split(str, ';'))
+ end, data)
+
+ -- Now we need to join inputs and create the appropriate test vectors
+ local inputs = {}
+ local outputs = {}
+
+ each(function(i, elt)
+ table.insert(inputs, totable(elt))
+ table.insert(outputs, 1.0)
+ end, spam_elts)
+ each(function(i, elt)
+ table.insert(inputs, totable(elt))
+ table.insert(outputs, -1.0)
+ end, spam_elts)
+
+ -- Now we can train fann
+
+ end
+ end
+ local function redis_spam_cb(err, data)
+ if err or type(data) ~= 'table' then
+ rspamd_logger.errx(rspamd_config, 'cannot get spam tokens for ANN %s from redis: %s',
+ fann_prefix .. elt, err)
+ redis_make_request(ev_base,
+ rspamd_config,
+ nil,
+ false, -- is write
+ redis_unlock_cb, --callback
+ 'DEL', -- command
+ {fann_prefix .. elt .. '_lock'}
+ )
+ else
+ -- Decompress and convert to numbers each training vector
+ spam_elts = map(function(i, elt)
+ local str = tostring(rspamd_util.zstd_decompress(elt))
+ return map(tonumber, rspamd_str_split(str, ';'))
+ end, data)
+ redis_make_request(ev_base,
+ rspamd_config,
+ nil,
+ false, -- is write
+ redis_ham_cb, --callback
+ 'LRANGE', -- command
+ {fann_prefix .. elt .. '_ham', '0', '-1'}
+ )
+ end
+ end
+
+ local function redis_lock_cb(err, data)
+ if err then
+ rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s',
+ fann_prefix .. elt, err)
+ elseif type(data) == 'number' then
+ -- Can train ANN
+ redis_make_request(ev_base,
+ rspamd_config,
+ nil,
+ false, -- is write
+ redis_spam_cb, --callback
+ 'LRANGE', -- command
+ {fann_prefix .. elt .. '_spam', '0', '-1'}
+ )
+ end
+ end
+ redis_make_request(ev_base,
+ rspamd_config,
+ nil,
+ true, -- is write
+ redis_lock_cb, --callback
+ 'EVALSHA', -- command
+ {redis_maybe_lock_sha, '1', fann_prefix .. elt}
+ )
end
local function maybe_train_fanns(cfg, ev_base)
{'LOAD', redis_lua_script_maybe_invalidate} -- arguments
)
+ local function maybe_lock_sha_cb(err, data)
+ if err or not data or type(data) ~= 'string' then
+ rspamd_logger.errx(cfg, 'cannot save redis lock script: %s', err)
+ else
+ redis_maybe_lock_sha = tostring(data)
+ end
+ end
+ redis_make_request(ev_base,
+ rspamd_config,
+ nil,
+ true, -- is write
+ maybe_lock_sha_cb, --callback
+ 'SCRIPT', -- command
+ {'LOAD', redis_lua_script_maybe_lock} -- arguments
+ )
+
if worker:get_name() == 'normal' then
-- We also want to train neural nets when they have enough data
rspamd_config:add_periodic(ev_base, 0.0,