aboutsummaryrefslogtreecommitdiffstats
path: root/src/plugins/lua/fann_redis.lua
diff options
context:
space:
mode:
Diffstat (limited to 'src/plugins/lua/fann_redis.lua')
-rw-r--r--src/plugins/lua/fann_redis.lua34
1 files changed, 24 insertions, 10 deletions
diff --git a/src/plugins/lua/fann_redis.lua b/src/plugins/lua/fann_redis.lua
index 361d82303..ad8a0f79a 100644
--- a/src/plugins/lua/fann_redis.lua
+++ b/src/plugins/lua/fann_redis.lua
@@ -105,9 +105,16 @@ local redis_maybe_invalidate_sha = nil
-- Lua script to invalidate ANN from redis
-- Uses the following keys
-- key1 - prefix for keys
+-- key2 - current time
+-- key3 - key expire
local redis_lua_script_maybe_lock = [[
local locked = redis.call('GET', KEYS[1] .. '_locked')
- if locked then return false end
+ if locked then
+ if tonumber(KEYS[2]) < tonumber(locked) then
+ return false
+ end
+ end
+ redis.call('SET', KEYS[1] .. '_locked', tostring(tonumber(KEYS[2]) + tonumber(KEYS[3])))
return 1
]]
local redis_maybe_lock_sha = nil
@@ -136,6 +143,8 @@ local use_settings = false
local watch_interval = 60.0
local mse = 0.0001
local nlayers = 4
+local lock_expire = 600
+local learning_spawned = false
local function redis_make_request(ev_base, cfg, key, is_write, callback, command, args)
if not ev_base or not redis_params or not callback or not command then
@@ -285,13 +294,14 @@ local function load_or_invalidate_fann(data, id, ev_base)
local ann
if err or not ann_data then
- rspamd_logger.errx('cannot decompress ann: %s', err)
+ rspamd_logger.errx(rspamd_config, 'cannot decompress ann: %s', err)
else
ann = rspamd_fann.load_data(ann_data)
end
if is_fann_valid(ann) then
fanns[id].fann = ann
+ rspamd_logger.infox(rspamd_config, 'loaded ann %s from redis', id)
else
local function redis_invalidate_cb(err, data)
if err then
@@ -387,7 +397,7 @@ local function train_fann(cfg, ev_base, elt)
end
end
- local function redis_save_unlock_sha(err, data)
+ local function redis_save_cb(err, data)
if err then
rspamd_logger.errx(rspamd_config, 'cannot save ANN %s to redis: %s',
fann_prefix .. elt, err)
@@ -409,7 +419,8 @@ local function train_fann(cfg, ev_base, elt)
else
rspamd_logger.infox(rspamd_config, 'trained ANN %s: MSE: %s',
fann_prefix .. elt, train_mse)
- local ann_data = rspamd_util.zstd_compress(data[elt].fann:data())
+ learning_spawned = false
+ local ann_data = rspamd_util.zstd_compress(fanns[elt].fann_train:data())
fanns[elt].version = fanns[elt].version + 1
fanns[elt].fann = fanns[elt].fann_train
fanns[elt].fann_train = nil
@@ -448,13 +459,11 @@ local function train_fann(cfg, ev_base, elt)
local outputs = {}
each(function(sample)
- table.insert(inputs, totable(sample))
+ table.insert(inputs, totable(sample[1]))
table.insert(outputs, {1.0})
- end, spam_elts)
- each(function(sample)
- table.insert(inputs, totable(sample))
+ table.insert(inputs, totable(sample[2]))
table.insert(outputs, {-1.0})
- end, ham_elts)
+ end, zip(spam_elts, ham_elts))
-- Now we can train fann
local n = rspamd_config:get_symbols_count() + rspamd_count_metatokens()
@@ -463,6 +472,7 @@ local function train_fann(cfg, ev_base, elt)
create_train_fann(n, elt)
end
+ learning_spawned = true
fanns[elt].fann_train:train_threaded(inputs, outputs, ann_trained, ev_base,
{max_epochs = max_epoch, desired_mse = mse})
end
@@ -513,13 +523,15 @@ local function train_fann(cfg, ev_base, elt)
)
end
end
+ if learning_spawned then return end
redis_make_request(ev_base,
rspamd_config,
nil,
true, -- is write
redis_lock_cb, --callback
'EVALSHA', -- command
- {redis_maybe_lock_sha, '1', fann_prefix .. elt}
+ {redis_maybe_lock_sha, '3', fann_prefix .. elt, tostring(os.time()),
+ tostring(lock_expire)}
)
end
@@ -535,6 +547,8 @@ local function maybe_train_fanns(cfg, ev_base)
rspamd_logger.errx(rspamd_config, 'cannot get FANN trains %s from redis: %s', elt, err)
elseif data and type(data) == 'number' or type(data) == 'string' then
if tonumber(data) and tonumber(data) > max_trains then
+ rspamd_logger.infox(rspamd_config, 'need to learn ANN %s after %s learn vectors (%s required)',
+ tonumber(data), max_trains)
train_fann(cfg, ev_base, elt)
end
end