diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2018-05-21 12:57:34 +0100 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2018-05-21 12:57:34 +0100 |
commit | daa7d496c42c32e2a4cfc18af64207e5ed910f4d (patch) | |
tree | eb56a094e25c4b60b183dd7f0bea7f0b674abaab | |
parent | f160bf3b0abf77c7f48c2ea85e7a7e0f5d678926 (diff) | |
download | rspamd-daa7d496c42c32e2a4cfc18af64207e5ed910f4d.tar.gz rspamd-daa7d496c42c32e2a4cfc18af64207e5ed910f4d.zip |
[CritFix] Fix multiple neural networks support
Issue: #2252
-rw-r--r-- | src/plugins/lua/neural.lua | 78 |
1 files changed, 38 insertions, 40 deletions
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua index 04b732472..c8a6f1173 100644 --- a/src/plugins/lua/neural.lua +++ b/src/plugins/lua/neural.lua @@ -65,10 +65,6 @@ local settings = { rules = {} } --- ANNs indexed by settings id -local anns = { -} - local opts = rspamd_config:get_all_opt("neural") if not opts then -- Legacy @@ -278,7 +274,7 @@ local function ann_scores_filter(task) id = id .. r end - if anns[id] and anns[id].ann then + if rule.anns[id] and rule.anns[id].ann then local ann_data = task:get_symbols_tokens() local mt = meta_functions.rspamd_gen_metatokens(task) -- Add filtered meta tokens @@ -286,10 +282,10 @@ local function ann_scores_filter(task) local score if use_torch then - local out = anns[id].ann:forward(torch.Tensor(ann_data)) + local out = rule.anns[id].ann:forward(torch.Tensor(ann_data)) score = out[1] else - local out = anns[id].ann:test(ann_data) + local out = rule.anns[id].ann:test(ann_data) score = out[1] end @@ -339,28 +335,29 @@ end local function create_train_ann(rule, n, id) local prefix = gen_ann_prefix(rule, id) - if not anns[id] then - anns[id] = {} + if not rule.anns[id] then + rule.anns[id] = {} end -- Fix that for flexibe layers number - if anns[id].ann then - if not is_ann_valid(rule, prefix, anns[id].ann) then - anns[id].ann_train = create_ann(n, rule.nlayers) - anns[id].ann = nil + if rule.anns[id].ann then + if not is_ann_valid(rule, prefix, rule.anns[id].ann) then + rule.anns[id].ann_train = create_ann(n, rule.nlayers) + rule.anns[id].ann = nil rspamd_logger.infox(rspamd_config, 'invalidate existing ANN, create train ANN %s', prefix) - elseif rule.train.max_usages > 0 and anns[id].version % rule.train.max_usages == 0 then + elseif rule.train.max_usages > 0 and + rule.anns[id].version % rule.train.max_usages == 0 then -- Forget last ann rspamd_logger.infox(rspamd_config, 'recreate ANN %s, version %s', prefix, - anns[id].version) - anns[id].ann_train = create_ann(n, rule.nlayers) + rule.anns[id].version) + rule.anns[id].ann_train = create_ann(n, rule.nlayers) else - anns[id].ann_train = anns[id].ann + rule.anns[id].ann_train = rule.anns[id].ann rspamd_logger.infox(rspamd_config, 'reuse ANN for training %s', prefix) end else - anns[id].ann_train = create_ann(n, rule.nlayers) + rule.anns[id].ann_train = create_ann(n, rule.nlayers) rspamd_logger.infox(rspamd_config, 'create train ANN %s', prefix) - anns[id].version = 0 + rule.anns[id].version = 0 end end @@ -388,18 +385,18 @@ local function load_or_invalidate_ann(rule, data, id, ev_base) end if is_ann_valid(rule, prefix, ann) then - if not anns[id] then anns[id] = {} end - anns[id].ann = ann + if not rule.anns[id] then rule.anns[id] = {} end + rule.anns[id].ann = ann rspamd_logger.infox(rspamd_config, 'loaded ANN %s version %s from redis', prefix, ver) - anns[id].version = tonumber(ver) + rule.anns[id].version = tonumber(ver) else local function redis_invalidate_cb(_err, _data) if _err then rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', prefix, _err) elseif type(_data) == 'string' then rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', prefix, _err) - anns[id].version = 0 + rule.anns[id].version = 0 end end -- Invalidate ANN @@ -553,15 +550,15 @@ local function train_ann(rule, _, ev_base, elt, worker) local ann_data if use_torch then local f = torch.MemoryFile() - f:writeObject(anns[elt].ann_train) + f:writeObject(rule.anns[elt].ann_train) ann_data = rspamd_util.zstd_compress(f:storage():string()) else - ann_data = rspamd_util.zstd_compress(anns[elt].ann_train:data()) + ann_data = rspamd_util.zstd_compress(rule.anns[elt].ann_train:data()) end - anns[elt].version = anns[elt].version + 1 - anns[elt].ann = anns[elt].ann_train - anns[elt].ann_train = nil + rule.anns[elt].version = rule.anns[elt].version + 1 + rule.anns[elt].ann = rule.anns[elt].ann_train + rule.anns[elt].ann_train = nil lua_redis.exec_redis_script(redis_save_unlock_id, {ev_base = ev_base, is_write = true}, redis_save_cb, @@ -589,11 +586,11 @@ local function train_ann(rule, _, ev_base, elt, worker) local ann_data local f = torch.MemoryFile(torch.CharStorage():string(tostring(data))) ann_data = rspamd_util.zstd_compress(f:storage():string()) - anns[elt].ann_train = f:readObject() + rule.anns[elt].ann_train = f:readObject() - anns[elt].version = anns[elt].version + 1 - anns[elt].ann = anns[elt].ann_train - anns[elt].ann_train = nil + rule.anns[elt].version = rule.anns[elt].version + 1 + rule.anns[elt].ann = rule.anns[elt].ann_train + rule.anns[elt].ann_train = nil lua_redis.exec_redis_script(redis_save_unlock_id, {ev_base = ev_base, is_write = true}, redis_save_cb, @@ -629,7 +626,7 @@ local function train_ann(rule, _, ev_base, elt, worker) end -- Now we can train ann - if not anns[elt] or not anns[elt].ann_train then + if not rule.anns[elt] or not rule.anns[elt].ann_train then -- Create ann if it does not exist create_train_ann(rule, n, elt) end @@ -641,7 +638,7 @@ local function train_ann(rule, _, ev_base, elt, worker) rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', prefix, _err) elseif type(_data) == 'string' then rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', prefix, _err) - anns[elt].version = 0 + rule.anns[elt].version = 0 end end -- Invalidate ANN @@ -668,7 +665,7 @@ local function train_ann(rule, _, ev_base, elt, worker) torch.setnumthreads(rule.train.learn_threads) end local criterion = nn.MSECriterion() - local trainer = nn.StochasticGradient(anns[elt].ann_train, + local trainer = nn.StochasticGradient(rule.anns[elt].ann_train, criterion) trainer.learning_rate = rule.train.learning_rate trainer.verbose = false @@ -680,7 +677,7 @@ local function train_ann(rule, _, ev_base, elt, worker) trainer:train(dataset) local out = torch.MemoryFile() - out:writeObject(anns[elt].ann_train) + out:writeObject(rule.anns[elt].ann_train) local st = out:storage():string() return st end @@ -701,7 +698,7 @@ local function train_ann(rule, _, ev_base, elt, worker) end, fun.zip(fun.filter(filt, spam_elts), fun.filter(filt, ham_elts))) rule.learning_spawned = true rspamd_logger.infox(rspamd_config, 'start learning ANN %s', prefix) - anns[elt].ann_train:train_threaded(inputs, outputs, ann_trained, + rule.anns[elt].ann_train:train_threaded(inputs, outputs, ann_trained, ev_base, { max_epochs = rule.train.max_epoch, desired_mse = rule.train.mse @@ -880,9 +877,9 @@ local function check_anns(rule, _, ev_base) end local local_ver = 0 - if anns[elt] then - if anns[elt].version then - local_ver = anns[elt].version + if rule.anns[elt] then + if rule.anns[elt].version then + local_ver = rule.anns[elt].version end end lua_redis.exec_redis_script(redis_maybe_load_id, @@ -963,6 +960,7 @@ else for k,r in pairs(rules) do local def_rules = lua_util.override_defaults(default_options, r) def_rules['redis'] = redis_params + def_rules['anns'] = {} -- Store ANNs here if not def_rules.prefix then def_rules.prefix = k |