diff options
Diffstat (limited to 'src/plugins/lua/fann_redis.lua')
-rw-r--r-- | src/plugins/lua/fann_redis.lua | 34 |
1 files changed, 24 insertions, 10 deletions
diff --git a/src/plugins/lua/fann_redis.lua b/src/plugins/lua/fann_redis.lua index 361d82303..ad8a0f79a 100644 --- a/src/plugins/lua/fann_redis.lua +++ b/src/plugins/lua/fann_redis.lua @@ -105,9 +105,16 @@ local redis_maybe_invalidate_sha = nil -- Lua script to invalidate ANN from redis -- Uses the following keys -- key1 - prefix for keys +-- key2 - current time +-- key3 - key expire local redis_lua_script_maybe_lock = [[ local locked = redis.call('GET', KEYS[1] .. '_locked') - if locked then return false end + if locked then + if tonumber(KEYS[2]) < tonumber(locked) then + return false + end + end + redis.call('SET', KEYS[1] .. '_locked', tostring(tonumber(KEYS[2]) + tonumber(KEYS[3]))) return 1 ]] local redis_maybe_lock_sha = nil @@ -136,6 +143,8 @@ local use_settings = false local watch_interval = 60.0 local mse = 0.0001 local nlayers = 4 +local lock_expire = 600 +local learning_spawned = false local function redis_make_request(ev_base, cfg, key, is_write, callback, command, args) if not ev_base or not redis_params or not callback or not command then @@ -285,13 +294,14 @@ local function load_or_invalidate_fann(data, id, ev_base) local ann if err or not ann_data then - rspamd_logger.errx('cannot decompress ann: %s', err) + rspamd_logger.errx(rspamd_config, 'cannot decompress ann: %s', err) else ann = rspamd_fann.load_data(ann_data) end if is_fann_valid(ann) then fanns[id].fann = ann + rspamd_logger.infox(rspamd_config, 'loaded ann %s from redis', id) else local function redis_invalidate_cb(err, data) if err then @@ -387,7 +397,7 @@ local function train_fann(cfg, ev_base, elt) end end - local function redis_save_unlock_sha(err, data) + local function redis_save_cb(err, data) if err then rspamd_logger.errx(rspamd_config, 'cannot save ANN %s to redis: %s', fann_prefix .. elt, err) @@ -409,7 +419,8 @@ local function train_fann(cfg, ev_base, elt) else rspamd_logger.infox(rspamd_config, 'trained ANN %s: MSE: %s', fann_prefix .. elt, train_mse) - local ann_data = rspamd_util.zstd_compress(data[elt].fann:data()) + learning_spawned = false + local ann_data = rspamd_util.zstd_compress(fanns[elt].fann_train:data()) fanns[elt].version = fanns[elt].version + 1 fanns[elt].fann = fanns[elt].fann_train fanns[elt].fann_train = nil @@ -448,13 +459,11 @@ local function train_fann(cfg, ev_base, elt) local outputs = {} each(function(sample) - table.insert(inputs, totable(sample)) + table.insert(inputs, totable(sample[1])) table.insert(outputs, {1.0}) - end, spam_elts) - each(function(sample) - table.insert(inputs, totable(sample)) + table.insert(inputs, totable(sample[2])) table.insert(outputs, {-1.0}) - end, ham_elts) + end, zip(spam_elts, ham_elts)) -- Now we can train fann local n = rspamd_config:get_symbols_count() + rspamd_count_metatokens() @@ -463,6 +472,7 @@ local function train_fann(cfg, ev_base, elt) create_train_fann(n, elt) end + learning_spawned = true fanns[elt].fann_train:train_threaded(inputs, outputs, ann_trained, ev_base, {max_epochs = max_epoch, desired_mse = mse}) end @@ -513,13 +523,15 @@ local function train_fann(cfg, ev_base, elt) ) end end + if learning_spawned then return end redis_make_request(ev_base, rspamd_config, nil, true, -- is write redis_lock_cb, --callback 'EVALSHA', -- command - {redis_maybe_lock_sha, '1', fann_prefix .. elt} + {redis_maybe_lock_sha, '3', fann_prefix .. elt, tostring(os.time()), + tostring(lock_expire)} ) end @@ -535,6 +547,8 @@ local function maybe_train_fanns(cfg, ev_base) rspamd_logger.errx(rspamd_config, 'cannot get FANN trains %s from redis: %s', elt, err) elseif data and type(data) == 'number' or type(data) == 'string' then if tonumber(data) and tonumber(data) > max_trains then + rspamd_logger.infox(rspamd_config, 'need to learn ANN %s after %s learn vectors (%s required)', + tonumber(data), max_trains) train_fann(cfg, ev_base, elt) end end |