-- ANNs indexed by settings id
local data = {
['0'] = {
- fann_mtime = 0,
- ntrains = 0,
- epoch = 0,
+ version = 0,
}
}
]]
local redis_maybe_lock_sha = nil
+-- Lua script to save and unlock ANN in redis
+-- Uses the following keys
+-- key1 - prefix for keys
+-- key2 - compressed ANN
+local redis_lua_script_save_unlock = [[
+ redis.call('INCRBY', KEYS[1] .. '_version', '1')
+ redis.call('DEL', KEYS[1] .. '_spam')
+ redis.call('DEL', KEYS[1] .. '_ham')
+ redis.call('SET', KEYS[1] .. '_data', KEYS[2])
+ redis.call('DEL', KEYS[1] .. '_locked')
+ return 1
+]]
+local redis_save_unlock_sha = nil
+
local redis_params
redis_params = rspamd_parse_redis_server('fann_redis')
local max_epoch = 100
local use_settings = false
local watch_interval = 60.0
+local mse = 0.0001
local function redis_make_request(ev_base, cfg, key, is_write, callback, command, args)
if not ev_base or not redis_params or not callback or not command then
local function create_train_fann(n, id)
data[id].fann_train = rspamd_fann.create(5, n, n, n / 2, n / 4, 1)
- data[id].ntrains = 0
- data[id].epoch = 0
+ data[id].version = 0
end
local function load_or_invalidate_fann(data, id, ev_base)
end
end
+ local function redis_save_unlock_sha(err, data)
+ if err then
+ rspamd_logger.errx(rspamd_config, 'cannot save ANN %s to redis: %s',
+ fann_prefix .. elt, err)
+ end
+ end
+
+ local function ann_trained(errcode, errmsg, train_mse)
+ if errcode ~= 0 then
+ rspamd_logger.errx(rspamd_config, 'cannot train ANN %s: %s',
+ fann_prefix .. elt, errmsg)
+ redis_make_request(ev_base,
+ rspamd_config,
+ nil,
+ false, -- is write
+ redis_unlock_cb, --callback
+ 'DEL', -- command
+ {fann_prefix .. elt .. '_lock'}
+ )
+ else
+ rspamd_logger.infox(rspamd_config, 'trained ANN %s: MSE: %s',
+ fann_prefix .. elt, train_mse)
+ local ann_data = rspamd_util.zstd_compress(data[elt].fann:data())
+ data[elt].version = data[elt].version + 1
+ redis_make_request(ev_base,
+ rspamd_config,
+ nil,
+ true, -- is write
+ redis_save_cb, --callback
+ 'EVALSHA', -- command
+ {redis_save_unlock_sha, '2', fann_prefix .. elt, ann_data}
+ )
+ end
+ end
+
local function redis_ham_cb(err, data)
if err or type(data) ~= 'table' then
rspamd_logger.errx(rspamd_config, 'cannot get ham tokens for ANN %s from redis: %s',
)
else
-- Decompress and convert to numbers each training vector
- ham_elts = map(function(i, elt)
- local str = tostring(rspamd_util.zstd_decompress(elt))
+ ham_elts = map(function(i, tok)
+ local str = tostring(rspamd_util.zstd_decompress(tok))
return map(tonumber, rspamd_str_split(str, ';'))
end, data)
local inputs = {}
local outputs = {}
- each(function(i, elt)
- table.insert(inputs, totable(elt))
+ each(function(i, sample)
+ table.insert(inputs, totable(sample))
table.insert(outputs, 1.0)
end, spam_elts)
- each(function(i, elt)
- table.insert(inputs, totable(elt))
+ each(function(i, sample)
+ table.insert(inputs, totable(sample))
table.insert(outputs, -1.0)
end, spam_elts)
-- Now we can train fann
+ local n = rspamd_config:get_symbols_count() + rspamd_count_metatokens()
+ if not data[elt].fann then
+ -- Create fann if it does not exist
+ create_train_fann(n, elt)
+ end
+ data[elt].fann:train_threaded(inputs, outputs, ann_trained, ev_base,
+ {max_epochs = max_epoch, desired_mse = mse})
end
end
)
else
-- Decompress and convert to numbers each training vector
- spam_elts = map(function(i, elt)
- local str = tostring(rspamd_util.zstd_decompress(elt))
+ spam_elts = map(function(i, tok)
+ local str = tostring(rspamd_util.zstd_decompress(tok))
return map(tonumber, rspamd_str_split(str, ';'))
end, data)
redis_make_request(ev_base,
{'LOAD', redis_lua_script_maybe_lock} -- arguments
)
+ local function save_unlock_sha_cb(err, data)
+ if err or not data or type(data) ~= 'string' then
+ rspamd_logger.errx(cfg, 'cannot save redis save script: %s', err)
+ else
+ redis_save_unlock_sha = tostring(data)
+ end
+ end
+ redis_make_request(ev_base,
+ rspamd_config,
+ nil,
+ true, -- is write
+ save_unlock_sha_cb, --callback
+ 'SCRIPT', -- command
+ {'LOAD', redis_lua_script_save_unlock} -- arguments
+ )
+
if worker:get_name() == 'normal' then
-- We also want to train neural nets when they have enough data
rspamd_config:add_periodic(ev_base, 0.0,