rules = {}
}
--- ANNs indexed by settings id
-local anns = {
-}
-
local opts = rspamd_config:get_all_opt("neural")
if not opts then
-- Legacy
id = id .. r
end
- if anns[id] and anns[id].ann then
+ if rule.anns[id] and rule.anns[id].ann then
local ann_data = task:get_symbols_tokens()
local mt = meta_functions.rspamd_gen_metatokens(task)
-- Add filtered meta tokens
local score
if use_torch then
- local out = anns[id].ann:forward(torch.Tensor(ann_data))
+ local out = rule.anns[id].ann:forward(torch.Tensor(ann_data))
score = out[1]
else
- local out = anns[id].ann:test(ann_data)
+ local out = rule.anns[id].ann:test(ann_data)
score = out[1]
end
--- Prepare the network that will be trained for one settings id.
-- Makes sure `rule.anns[id]` exists and fills its `ann_train` slot, either
-- with a freshly created ANN or with the ANN that is already stored there.
-- @param rule rule table; reads `rule.anns`, `rule.nlayers` and `rule.train.max_usages`
-- @param n forwarded to `create_ann` as its first argument
--          (presumably the number of inputs -- confirm against `create_ann`)
-- @param id key of the entry inside `rule.anns`
local function create_train_ann(rule, n, id)
  local prefix = gen_ann_prefix(rule, id)

  if not rule.anns[id] then
    rule.anns[id] = {}
  end
  local entry = rule.anns[id]

  -- TODO: fix that for a flexible number of layers
  if not entry.ann then
    -- No ANN stored for this id: build a new one and start counting from 0
    entry.ann_train = create_ann(n, rule.nlayers)
    rspamd_logger.infox(rspamd_config, 'create train ANN %s', prefix)
    entry.version = 0
    return
  end

  if not is_ann_valid(rule, prefix, entry.ann) then
    -- Stored ANN is rejected by is_ann_valid: drop it and train a new one
    entry.ann_train = create_ann(n, rule.nlayers)
    entry.ann = nil
    rspamd_logger.infox(rspamd_config, 'invalidate existing ANN, create train ANN %s', prefix)
    return
  end

  local max_usages = rule.train.max_usages
  if max_usages > 0 and entry.version % max_usages == 0 then
    -- Forget the last ANN once its version is a multiple of max_usages
    rspamd_logger.infox(rspamd_config, 'recreate ANN %s, version %s', prefix,
      entry.version)
    entry.ann_train = create_ann(n, rule.nlayers)
  else
    -- Keep training the ANN that is currently stored
    entry.ann_train = entry.ann
    rspamd_logger.infox(rspamd_config, 'reuse ANN for training %s', prefix)
  end
end
end
if is_ann_valid(rule, prefix, ann) then
- if not anns[id] then anns[id] = {} end
- anns[id].ann = ann
+ if not rule.anns[id] then rule.anns[id] = {} end
+ rule.anns[id].ann = ann
rspamd_logger.infox(rspamd_config, 'loaded ANN %s version %s from redis',
prefix, ver)
- anns[id].version = tonumber(ver)
+ rule.anns[id].version = tonumber(ver)
else
--- Callback for the Redis request that invalidates a stored ANN.
-- Uses `prefix`, `rule` and `id` from the enclosing scope.
-- @param _err error value reported by Redis; when truthy it is logged and
--             nothing else is done
-- @param _data reply payload; only a string reply is treated as success
local function redis_invalidate_cb(_err, _data)
  if _err then
    rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', prefix, _err)
  elseif type(_data) == 'string' then
    -- Report the reply itself: `_err` is always falsy in this branch, so
    -- logging it would only ever print an empty value
    rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', prefix, _data)
    -- The entry for `id` may not have been created yet, so guard the access
    -- instead of indexing a nil value inside an asynchronous callback
    local entry = rule.anns[id]
    if entry then
      entry.version = 0
    end
  end
end
-- Invalidate ANN
local ann_data
if use_torch then
local f = torch.MemoryFile()
- f:writeObject(anns[elt].ann_train)
+ f:writeObject(rule.anns[elt].ann_train)
ann_data = rspamd_util.zstd_compress(f:storage():string())
else
- ann_data = rspamd_util.zstd_compress(anns[elt].ann_train:data())
+ ann_data = rspamd_util.zstd_compress(rule.anns[elt].ann_train:data())
end
- anns[elt].version = anns[elt].version + 1
- anns[elt].ann = anns[elt].ann_train
- anns[elt].ann_train = nil
+ rule.anns[elt].version = rule.anns[elt].version + 1
+ rule.anns[elt].ann = rule.anns[elt].ann_train
+ rule.anns[elt].ann_train = nil
lua_redis.exec_redis_script(redis_save_unlock_id,
{ev_base = ev_base, is_write = true},
redis_save_cb,
local ann_data
local f = torch.MemoryFile(torch.CharStorage():string(tostring(data)))
ann_data = rspamd_util.zstd_compress(f:storage():string())
- anns[elt].ann_train = f:readObject()
+ rule.anns[elt].ann_train = f:readObject()
- anns[elt].version = anns[elt].version + 1
- anns[elt].ann = anns[elt].ann_train
- anns[elt].ann_train = nil
+ rule.anns[elt].version = rule.anns[elt].version + 1
+ rule.anns[elt].ann = rule.anns[elt].ann_train
+ rule.anns[elt].ann_train = nil
lua_redis.exec_redis_script(redis_save_unlock_id,
{ev_base = ev_base, is_write = true},
redis_save_cb,
end
-- Now we can train ann
- if not anns[elt] or not anns[elt].ann_train then
+ if not rule.anns[elt] or not rule.anns[elt].ann_train then
-- Create ann if it does not exist
create_train_ann(rule, n, elt)
end
rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', prefix, _err)
elseif type(_data) == 'string' then
rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', prefix, _err)
- anns[elt].version = 0
+ rule.anns[elt].version = 0
end
end
-- Invalidate ANN
torch.setnumthreads(rule.train.learn_threads)
end
local criterion = nn.MSECriterion()
- local trainer = nn.StochasticGradient(anns[elt].ann_train,
+ local trainer = nn.StochasticGradient(rule.anns[elt].ann_train,
criterion)
trainer.learning_rate = rule.train.learning_rate
trainer.verbose = false
trainer:train(dataset)
local out = torch.MemoryFile()
- out:writeObject(anns[elt].ann_train)
+ out:writeObject(rule.anns[elt].ann_train)
local st = out:storage():string()
return st
end
end, fun.zip(fun.filter(filt, spam_elts), fun.filter(filt, ham_elts)))
rule.learning_spawned = true
rspamd_logger.infox(rspamd_config, 'start learning ANN %s', prefix)
- anns[elt].ann_train:train_threaded(inputs, outputs, ann_trained,
+ rule.anns[elt].ann_train:train_threaded(inputs, outputs, ann_trained,
ev_base, {
max_epochs = rule.train.max_epoch,
desired_mse = rule.train.mse
end
local local_ver = 0
- if anns[elt] then
- if anns[elt].version then
- local_ver = anns[elt].version
+ if rule.anns[elt] then
+ if rule.anns[elt].version then
+ local_ver = rule.anns[elt].version
end
end
lua_redis.exec_redis_script(redis_maybe_load_id,
for k,r in pairs(rules) do
local def_rules = lua_util.override_defaults(default_options, r)
def_rules['redis'] = redis_params
+ def_rules['anns'] = {} -- Store ANNs here
if not def_rules.prefix then
def_rules.prefix = k