aboutsummaryrefslogtreecommitdiffstats
path: root/src/plugins/lua/neural.lua
diff options
context:
space:
mode:
Diffstat (limited to 'src/plugins/lua/neural.lua')
-rw-r--r--src/plugins/lua/neural.lua169
1 files changed, 83 insertions, 86 deletions
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index cca6f647c..fdb138321 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -50,6 +50,26 @@ local default_options = {
symbol_ham = 'NEURAL_HAM',
}
+-- Rule structure:
+-- * static config fields (see `default_options`)
+-- * prefix - name or defined prefix
+-- * settings - table of settings indexed by settings id, -1 is used when no settings defined
+
+-- Rule settings element defines elements for specific settings id:
+-- * symbols - static symbols profile (defined by config or extracted from symcache)
+-- * name - name of settings id
+-- * digest - digest of all symbols
+-- * ann - dynamic ANN configuration loaded from Redis
+-- * train - train data for ANN (e.g. the currently trained ANN)
+
+-- Settings ANN table is loaded from Redis and represents dynamic profile for ANN
+-- Some elements are directly stored in Redis, ANN is, in turn loaded dynamically
+-- * version - version of ANN loaded from redis
+-- * ann_key - name of ANN key in Redis
+-- * symbols - symbols in THIS PARTICULAR ANN (might be different from set.symbols)
+-- * distance - distance between set.symbols and set.ann.symbols
+-- * ann - kann object
+
local settings = {
rules = {},
prefix = 'rn', -- Neural network default prefix
@@ -212,30 +232,75 @@ local function load_scripts(params)
params)
end
+local function result_to_vector(task, profile)
+ if not profile.zeros then
+ -- Fill zeros vector
+ local zeros = {}
+ for i=1,meta_functions.count_metatokens() do
+ zeros[i] = 0.0
+ end
+ for _,_ in ipairs(profile.symbols) do
+ zeros[#zeros + 1] = 0.0
+ end
+ profile.zeros = zeros
+ end
+
+ local vec = lua_util.shallowcopy(profile.zeros)
+ local mt = meta_functions.rspamd_gen_metatokens(task)
+
+ for i,v in ipairs(mt) do
+ vec[i] = v
+ end
+
+ task:process_ann_tokens(profile.symbols, vec, #mt)
+
+ return vec
+end
local function ann_scores_filter(task)
for _,rule in pairs(settings.rules) do
- local id = '0'
- if rule.use_settings then
- local sid = task:get_settings_id()
- if sid then
- id = tostring(sid)
- end
- end
- if rule.per_user then
- local r = task:get_principal_recipient()
- id = id .. r
+ local sid = task:get_settings_id()
+ local ann
+ local profile
+
+ if sid then
+ if rule.settings[sid] then
+ local set = rule.settings[sid]
+
+ if set.ann then
+ ann = set.ann.ann
+ profile = set.ann
+ else
+ lua_util.debugm(N, task, 'no ann loaded for %s:%s',
+ rule.prefix, set.name)
+ end
+ else
+ lua_util.debugm(N, task, 'no ann defined in %s for settings id %s',
+ rule.prefix, sid)
+ end
+ else
+ if rule.settings[-1] then
+ local set = rule.settings[-1]
+
+ if set.ann then
+ ann = set.ann.ann
+ profile = set.ann
+ else
+ lua_util.debugm(N, task, 'no ann loaded for %s:%s',
+ rule.prefix, set.name)
+ end
+ else
+ lua_util.debugm(N, task, 'no default ann for rule %s',
+ rule.prefix)
+ end
end
- if rule.anns[id] and rule.anns[id].ann then
- local ann_data = task:get_symbols_tokens()
- local mt = meta_functions.rspamd_gen_metatokens(task)
- -- Add filtered meta tokens
- fun.each(function(e) table.insert(ann_data, e) end, mt)
+ if ann then
+ local vec = result_to_vector(task, profile)
local score
- local out = rule.anns[id].ann:apply1(ann_data)
+ local out = ann:apply1(vec)
score = out[1]
local symscore = string.format('%.3f', score)
@@ -262,76 +327,6 @@ local function create_ann(n, nlayers)
return rspamd_kann.new.kann(t)
end
-local function create_train_ann(rule, n, id)
- local prefix = gen_ann_prefix(rule, id)
- if not rule.anns[id] then
- rule.anns[id] = {}
- end
- -- Fix that for flexibe layers number
- if rule.anns[id].ann then
- if not is_ann_valid(rule, prefix, rule.anns[id].ann) then
- rule.anns[id].ann_train = create_ann(n, rule.nlayers)
- rule.anns[id].ann = nil
- rspamd_logger.infox(rspamd_config, 'invalidate existing ANN, create train ANN %s', prefix)
- elseif rule.train.max_usages > 0 and
- rule.anns[id].version % rule.train.max_usages == 0 then
- -- Forget last ann
- rspamd_logger.infox(rspamd_config, 'recreate ANN %s, version %s', prefix,
- rule.anns[id].version)
- rule.anns[id].ann_train = create_ann(n, rule.nlayers)
- else
- rule.anns[id].ann_train = rule.anns[id].ann
- rspamd_logger.infox(rspamd_config, 'reuse ANN for training %s', prefix)
- end
- else
- rule.anns[id].ann_train = create_ann(n, rule.nlayers)
- rspamd_logger.infox(rspamd_config, 'create train ANN %s', prefix)
- rule.anns[id].version = 0
- end
-end
-
-local function load_or_invalidate_ann(rule, data, id, ev_base)
- local ver = data[2]
- local prefix = gen_ann_prefix(rule, id)
-
- if not ver or not tonumber(ver) then
- rspamd_logger.errx(rspamd_config, 'cannot get version for ANN: %s', prefix)
- return
- end
-
- local err,ann_data = rspamd_util.zstd_decompress(data[1])
- local ann
-
- if err or not ann_data then
- rspamd_logger.errx(rspamd_config, 'cannot decompress ANN %s: %s', prefix, err)
- return
- else
- ann = rspamd_kann.load(ann_data)
- end
-
- if is_ann_valid(rule, prefix, ann) then
- if not rule.anns[id] then rule.anns[id] = {} end
- rule.anns[id].ann = ann
- rspamd_logger.infox(rspamd_config, 'loaded ANN %s version %s from redis',
- prefix, ver)
- rule.anns[id].version = tonumber(ver)
- else
- local function redis_invalidate_cb(_err, _data)
- if _err then
- rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', prefix, _err)
- elseif type(_data) == 'string' then
- rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', prefix, _err)
- rule.anns[id].version = 0
- end
- end
- -- Invalidate ANN
- rspamd_logger.infox(rspamd_config, 'invalidate ANN %s', prefix)
- lua_redis.exec_redis_script(redis_maybe_invalidate_id,
- {ev_base = ev_base, is_write = true},
- redis_invalidate_cb,
- {prefix})
- end
-end
local function ann_train_callback(rule, task, score, required_score, id)
local train_opts = rule['train']
@@ -901,6 +896,7 @@ local function cleanup_anns(rule, cfg, ev_base)
end
end
end
+
lua_redis.exec_redis_script(redis_maybe_invalidate_id,
{ev_base = ev_base, is_write = true},
invalidate_cb,
@@ -1095,6 +1091,7 @@ for k,rule in pairs(settings.rules) do
-- We also want to train neural nets when they have enough data
rspamd_config:add_periodic(ev_base, 0.0,
function(_, _)
+ -- Clean old ANNs
cleanup_anns(rule, cfg, ev_base)
return maybe_train_anns(rule, cfg, ev_base, worker)
end)