aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2016-11-16 11:16:42 +0000
committerVsevolod Stakhov <vsevolod@highsecure.ru>2016-11-16 11:16:42 +0000
commit5bec51d7ef9a774b8aecec8f1e48a1b01babfbd3 (patch)
tree0eac95391f7b418c5e512a528852140b9a64dc02 /src
parenta5eef8b08ca4ea69cc3c5bb718102fe24f20a61b (diff)
downloadrspamd-5bec51d7ef9a774b8aecec8f1e48a1b01babfbd3.tar.gz
rspamd-5bec51d7ef9a774b8aecec8f1e48a1b01babfbd3.zip
[Fix] Forget old ANN when max_usages is reached to avoid overtrain
Diffstat (limited to 'src')
-rw-r--r--src/plugins/lua/fann_redis.lua41
1 files changed, 34 insertions, 7 deletions
diff --git a/src/plugins/lua/fann_redis.lua b/src/plugins/lua/fann_redis.lua
index cff8981de..324b7ba29 100644
--- a/src/plugins/lua/fann_redis.lua
+++ b/src/plugins/lua/fann_redis.lua
@@ -79,7 +79,9 @@ local redis_lua_script_maybe_load = [[
local ver = 0
local ret = redis.call('GET', KEYS[1] .. '_version')
if ret then ver = tonumber(ret) end
- if ver > tonumber(KEYS[2]) then return redis.call('GET', KEYS[1] .. '_data') end
+ if ver > tonumber(KEYS[2]) then
+ return {redis.call('GET', KEYS[1] .. '_data'), ret}
+ end
return false
]]
@@ -137,7 +139,8 @@ redis_params = rspamd_parse_redis_server('fann_redis')
local fann_prefix = 'RFANN'
local max_trains = 1000
-local max_epoch = 100
+local max_epoch = 1000
+local max_usages = 10
local use_settings = false
local watch_interval = 60.0
local mse = 0.0001
@@ -280,8 +283,16 @@ local function create_train_fann(n, id)
end
if fanns[id].fann then
- fanns[id].fann_train = fanns[id].fann
- fanns[id].fann = nil
+ if fanns[id].version % max_usages == 0 then
+ -- Forget last fann
+ rspamd_logger.infox(rspamd_config, 'recreate ANN %s, version %s', id,
+ fanns[id].version)
+ fanns[id].fann_train = rspamd_fann.create(nlayers, n, n / 2, n / 4, 1)
+ fanns[id].fann = nil
+ else
+ fanns[id].fann_train = fanns[id].fann
+ fanns[id].fann = nil
+ end
else
fanns[id].fann_train = rspamd_fann.create(nlayers, n, n / 2, n / 4, 1)
fanns[id].version = 0
@@ -289,24 +300,34 @@ local function create_train_fann(n, id)
end
local function load_or_invalidate_fann(data, id, ev_base)
- local err,ann_data = rspamd_util.zstd_decompress(data)
+ local ver = data[2]
+ if not ver or not tonumber(ver) then
+ rspamd_logger.errx(rspamd_config, 'cannot get version for ann: %s', id)
+ return
+ end
+
+ local err,ann_data = rspamd_util.zstd_decompress(data[1])
local ann
if err or not ann_data then
- rspamd_logger.errx(rspamd_config, 'cannot decompress ann: %s', err)
+ rspamd_logger.errx(rspamd_config, 'cannot decompress ann %s: %s', id, err)
+ return
else
ann = rspamd_fann.load_data(ann_data)
end
if is_fann_valid(ann) then
fanns[id].fann = ann
- rspamd_logger.infox(rspamd_config, 'loaded ann %s from redis', id)
+ rspamd_logger.infox(rspamd_config, 'loaded ann %s version %s from redis',
+ id, ver)
+ fanns[id].version = tonumber(ver)
else
local function redis_invalidate_cb(_err, _data)
if _err then
rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', id, _err)
elseif type(_data) == 'string' then
rspamd_logger.info(rspamd_config, 'invalidated ANN %s from redis: %s', id, _err)
+ fanns[id].version = 0
end
end
-- Invalidate ANN
@@ -701,6 +722,12 @@ else
if opts['train']['max_epoch'] then
max_epoch = opts['train']['max_epoch']
end
+ if opts['train']['max_usages'] then
+ max_usages = opts['train']['max_usages']
+ end
+ if opts['train']['mse'] then
+ mse = opts['train']['mse']
+ end
local ret = cfg:register_worker_script("log_helper",
function(score, req_score, results, cf, _id, extra, ev_base)
-- fun.map (snd x) (fun.filter (fst x == module_id) extra)