From: Vsevolod Stakhov Date: Wed, 16 Nov 2016 11:16:42 +0000 (+0000) Subject: [Fix] Forget old ANN when max_usages is reached to avoid overtrain X-Git-Tag: 1.4.0~51 X-Git-Url: https://source.dussan.org/?a=commitdiff_plain;h=5bec51d7ef9a774b8aecec8f1e48a1b01babfbd3;p=rspamd.git [Fix] Forget old ANN when max_usages is reached to avoid overtrain --- diff --git a/src/plugins/lua/fann_redis.lua b/src/plugins/lua/fann_redis.lua index cff8981de..324b7ba29 100644 --- a/src/plugins/lua/fann_redis.lua +++ b/src/plugins/lua/fann_redis.lua @@ -79,7 +79,9 @@ local redis_lua_script_maybe_load = [[ local ver = 0 local ret = redis.call('GET', KEYS[1] .. '_version') if ret then ver = tonumber(ret) end - if ver > tonumber(KEYS[2]) then return redis.call('GET', KEYS[1] .. '_data') end + if ver > tonumber(KEYS[2]) then + return {redis.call('GET', KEYS[1] .. '_data'), ret} + end return false ]] @@ -137,7 +139,8 @@ redis_params = rspamd_parse_redis_server('fann_redis') local fann_prefix = 'RFANN' local max_trains = 1000 -local max_epoch = 100 +local max_epoch = 1000 +local max_usages = 10 local use_settings = false local watch_interval = 60.0 local mse = 0.0001 @@ -280,8 +283,16 @@ local function create_train_fann(n, id) end if fanns[id].fann then - fanns[id].fann_train = fanns[id].fann - fanns[id].fann = nil + if fanns[id].version % max_usages == 0 then + -- Forget last fann + rspamd_logger.infox(rspamd_config, 'recreate ANN %s, version %s', id, + fanns[id].version) + fanns[id].fann_train = rspamd_fann.create(nlayers, n, n / 2, n / 4, 1) + fanns[id].fann = nil + else + fanns[id].fann_train = fanns[id].fann + fanns[id].fann = nil + end else fanns[id].fann_train = rspamd_fann.create(nlayers, n, n / 2, n / 4, 1) fanns[id].version = 0 @@ -289,24 +300,34 @@ local function create_train_fann(n, id) end local function load_or_invalidate_fann(data, id, ev_base) - local err,ann_data = rspamd_util.zstd_decompress(data) + local ver = data[2] + if not ver or not tonumber(ver) then + rspamd_logger.errx(rspamd_config, 'cannot get version for ann: %s', id) + return + end + + local err,ann_data = rspamd_util.zstd_decompress(data[1]) local ann if err or not ann_data then - rspamd_logger.errx(rspamd_config, 'cannot decompress ann: %s', err) + rspamd_logger.errx(rspamd_config, 'cannot decompress ann %s: %s', id, err) + return else ann = rspamd_fann.load_data(ann_data) end if is_fann_valid(ann) then fanns[id].fann = ann - rspamd_logger.infox(rspamd_config, 'loaded ann %s from redis', id) + rspamd_logger.infox(rspamd_config, 'loaded ann %s version %s from redis', + id, ver) + fanns[id].version = tonumber(ver) else local function redis_invalidate_cb(_err, _data) if _err then rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', id, _err) elseif type(_data) == 'string' then rspamd_logger.info(rspamd_config, 'invalidated ANN %s from redis: %s', id, _err) + fanns[id].version = 0 end end -- Invalidate ANN @@ -701,6 +722,12 @@ else if opts['train']['max_epoch'] then max_epoch = opts['train']['max_epoch'] end + if opts['train']['max_usages'] then + max_usages = opts['train']['max_usages'] + end + if opts['train']['mse'] then + mse = opts['train']['mse'] + end local ret = cfg:register_worker_script("log_helper", function(score, req_score, results, cf, _id, extra, ev_base) -- fun.map (snd x) (fun.filter (fst x == module_id) extra)