aboutsummaryrefslogtreecommitdiffstats
path: root/src/plugins/lua/neural.lua
diff options
context:
space:
mode:
Diffstat (limited to 'src/plugins/lua/neural.lua')
-rw-r--r--src/plugins/lua/neural.lua80
1 files changed, 39 insertions, 41 deletions
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index 04b732472..9d0bbb446 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -65,10 +65,6 @@ local settings = {
rules = {}
}
--- ANNs indexed by settings id
-local anns = {
-}
-
local opts = rspamd_config:get_all_opt("neural")
if not opts then
-- Legacy
@@ -278,7 +274,7 @@ local function ann_scores_filter(task)
id = id .. r
end
- if anns[id] and anns[id].ann then
+ if rule.anns[id] and rule.anns[id].ann then
local ann_data = task:get_symbols_tokens()
local mt = meta_functions.rspamd_gen_metatokens(task)
-- Add filtered meta tokens
@@ -286,15 +282,15 @@ local function ann_scores_filter(task)
local score
if use_torch then
- local out = anns[id].ann:forward(torch.Tensor(ann_data))
+ local out = rule.anns[id].ann:forward(torch.Tensor(ann_data))
score = out[1]
else
- local out = anns[id].ann:test(ann_data)
+ local out = rule.anns[id].ann:test(ann_data)
score = out[1]
end
local symscore = string.format('%.3f', score)
- rspamd_logger.infox(task, 'ann score: %s', symscore)
+ rspamd_logger.infox(task, '%s ann score: %s', rule.name, symscore)
if score > 0 then
local result = score
@@ -339,28 +335,29 @@ end
local function create_train_ann(rule, n, id)
local prefix = gen_ann_prefix(rule, id)
- if not anns[id] then
- anns[id] = {}
+ if not rule.anns[id] then
+ rule.anns[id] = {}
end
-- Fix that for flexibe layers number
- if anns[id].ann then
- if not is_ann_valid(rule, prefix, anns[id].ann) then
- anns[id].ann_train = create_ann(n, rule.nlayers)
- anns[id].ann = nil
+ if rule.anns[id].ann then
+ if not is_ann_valid(rule, prefix, rule.anns[id].ann) then
+ rule.anns[id].ann_train = create_ann(n, rule.nlayers)
+ rule.anns[id].ann = nil
rspamd_logger.infox(rspamd_config, 'invalidate existing ANN, create train ANN %s', prefix)
- elseif rule.train.max_usages > 0 and anns[id].version % rule.train.max_usages == 0 then
+ elseif rule.train.max_usages > 0 and
+ rule.anns[id].version % rule.train.max_usages == 0 then
-- Forget last ann
rspamd_logger.infox(rspamd_config, 'recreate ANN %s, version %s', prefix,
- anns[id].version)
- anns[id].ann_train = create_ann(n, rule.nlayers)
+ rule.anns[id].version)
+ rule.anns[id].ann_train = create_ann(n, rule.nlayers)
else
- anns[id].ann_train = anns[id].ann
+ rule.anns[id].ann_train = rule.anns[id].ann
rspamd_logger.infox(rspamd_config, 'reuse ANN for training %s', prefix)
end
else
- anns[id].ann_train = create_ann(n, rule.nlayers)
+ rule.anns[id].ann_train = create_ann(n, rule.nlayers)
rspamd_logger.infox(rspamd_config, 'create train ANN %s', prefix)
- anns[id].version = 0
+ rule.anns[id].version = 0
end
end
@@ -388,18 +385,18 @@ local function load_or_invalidate_ann(rule, data, id, ev_base)
end
if is_ann_valid(rule, prefix, ann) then
- if not anns[id] then anns[id] = {} end
- anns[id].ann = ann
+ if not rule.anns[id] then rule.anns[id] = {} end
+ rule.anns[id].ann = ann
rspamd_logger.infox(rspamd_config, 'loaded ANN %s version %s from redis',
prefix, ver)
- anns[id].version = tonumber(ver)
+ rule.anns[id].version = tonumber(ver)
else
local function redis_invalidate_cb(_err, _data)
if _err then
rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', prefix, _err)
elseif type(_data) == 'string' then
rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', prefix, _err)
- anns[id].version = 0
+ rule.anns[id].version = 0
end
end
-- Invalidate ANN
@@ -553,15 +550,15 @@ local function train_ann(rule, _, ev_base, elt, worker)
local ann_data
if use_torch then
local f = torch.MemoryFile()
- f:writeObject(anns[elt].ann_train)
+ f:writeObject(rule.anns[elt].ann_train)
ann_data = rspamd_util.zstd_compress(f:storage():string())
else
- ann_data = rspamd_util.zstd_compress(anns[elt].ann_train:data())
+ ann_data = rspamd_util.zstd_compress(rule.anns[elt].ann_train:data())
end
- anns[elt].version = anns[elt].version + 1
- anns[elt].ann = anns[elt].ann_train
- anns[elt].ann_train = nil
+ rule.anns[elt].version = rule.anns[elt].version + 1
+ rule.anns[elt].ann = rule.anns[elt].ann_train
+ rule.anns[elt].ann_train = nil
lua_redis.exec_redis_script(redis_save_unlock_id,
{ev_base = ev_base, is_write = true},
redis_save_cb,
@@ -589,11 +586,11 @@ local function train_ann(rule, _, ev_base, elt, worker)
local ann_data
local f = torch.MemoryFile(torch.CharStorage():string(tostring(data)))
ann_data = rspamd_util.zstd_compress(f:storage():string())
- anns[elt].ann_train = f:readObject()
+ rule.anns[elt].ann_train = f:readObject()
- anns[elt].version = anns[elt].version + 1
- anns[elt].ann = anns[elt].ann_train
- anns[elt].ann_train = nil
+ rule.anns[elt].version = rule.anns[elt].version + 1
+ rule.anns[elt].ann = rule.anns[elt].ann_train
+ rule.anns[elt].ann_train = nil
lua_redis.exec_redis_script(redis_save_unlock_id,
{ev_base = ev_base, is_write = true},
redis_save_cb,
@@ -629,7 +626,7 @@ local function train_ann(rule, _, ev_base, elt, worker)
end
-- Now we can train ann
- if not anns[elt] or not anns[elt].ann_train then
+ if not rule.anns[elt] or not rule.anns[elt].ann_train then
-- Create ann if it does not exist
create_train_ann(rule, n, elt)
end
@@ -641,7 +638,7 @@ local function train_ann(rule, _, ev_base, elt, worker)
rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', prefix, _err)
elseif type(_data) == 'string' then
rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', prefix, _err)
- anns[elt].version = 0
+ rule.anns[elt].version = 0
end
end
-- Invalidate ANN
@@ -668,7 +665,7 @@ local function train_ann(rule, _, ev_base, elt, worker)
torch.setnumthreads(rule.train.learn_threads)
end
local criterion = nn.MSECriterion()
- local trainer = nn.StochasticGradient(anns[elt].ann_train,
+ local trainer = nn.StochasticGradient(rule.anns[elt].ann_train,
criterion)
trainer.learning_rate = rule.train.learning_rate
trainer.verbose = false
@@ -680,7 +677,7 @@ local function train_ann(rule, _, ev_base, elt, worker)
trainer:train(dataset)
local out = torch.MemoryFile()
- out:writeObject(anns[elt].ann_train)
+ out:writeObject(rule.anns[elt].ann_train)
local st = out:storage():string()
return st
end
@@ -701,7 +698,7 @@ local function train_ann(rule, _, ev_base, elt, worker)
end, fun.zip(fun.filter(filt, spam_elts), fun.filter(filt, ham_elts)))
rule.learning_spawned = true
rspamd_logger.infox(rspamd_config, 'start learning ANN %s', prefix)
- anns[elt].ann_train:train_threaded(inputs, outputs, ann_trained,
+ rule.anns[elt].ann_train:train_threaded(inputs, outputs, ann_trained,
ev_base, {
max_epochs = rule.train.max_epoch,
desired_mse = rule.train.mse
@@ -880,9 +877,9 @@ local function check_anns(rule, _, ev_base)
end
local local_ver = 0
- if anns[elt] then
- if anns[elt].version then
- local_ver = anns[elt].version
+ if rule.anns[elt] then
+ if rule.anns[elt].version then
+ local_ver = rule.anns[elt].version
end
end
lua_redis.exec_redis_script(redis_maybe_load_id,
@@ -963,6 +960,7 @@ else
for k,r in pairs(rules) do
local def_rules = lua_util.override_defaults(default_options, r)
def_rules['redis'] = redis_params
+ def_rules['anns'] = {} -- Store ANNs here
if not def_rules.prefix then
def_rules.prefix = k