|
|
@@ -50,6 +50,26 @@ local default_options = { |
|
|
|
symbol_ham = 'NEURAL_HAM', |
|
|
|
} |
|
|
|
|
|
|
|
-- Rule structure: |
|
|
|
-- * static config fields (see `default_options`) |
|
|
|
-- * prefix - name or defined prefix |
|
|
|
-- * settings - table of settings indexed by settings id, -1 is used when no settings defined |
|
|
|
|
|
|
|
-- Rule settings element defines elements for specific settings id: |
|
|
|
-- * symbols - static symbols profile (defined by config or extracted from symcache) |
|
|
|
-- * name - name of settings id |
|
|
|
-- * digest - digest of all symbols |
|
|
|
-- * ann - dynamic ANN configuration loaded from Redis |
|
|
|
-- * train - train data for ANN (e.g. the currently trained ANN) |
|
|
|
|
|
|
|
-- Settings ANN table is loaded from Redis and represents dynamic profile for ANN |
|
|
|
-- Some elements are directly stored in Redis, ANN is, in turn loaded dynamically |
|
|
|
-- * version - version of ANN loaded from redis |
|
|
|
-- * ann_key - name of ANN key in Redis |
|
|
|
-- * symbols - symbols in THIS PARTICULAR ANN (might be different from set.symbols) |
|
|
|
-- * distance - distance between set.symbols and set.ann.symbols |
|
|
|
-- * ann - kann object |
|
|
|
|
|
|
|
local settings = { |
|
|
|
rules = {}, |
|
|
|
prefix = 'rn', -- Neural network default prefix |
|
|
@@ -212,30 +232,75 @@ local function load_scripts(params) |
|
|
|
params) |
|
|
|
end |
|
|
|
|
|
|
|
local function result_to_vector(task, profile) |
|
|
|
if not profile.zeros then |
|
|
|
-- Fill zeros vector |
|
|
|
local zeros = {} |
|
|
|
for i=1,meta_functions.count_metatokens() do |
|
|
|
zeros[i] = 0.0 |
|
|
|
end |
|
|
|
for _,_ in ipairs(profile.symbols) do |
|
|
|
zeros[#zeros + 1] = 0.0 |
|
|
|
end |
|
|
|
profile.zeros = zeros |
|
|
|
end |
|
|
|
|
|
|
|
local vec = lua_util.shallowcopy(profile.zeros) |
|
|
|
local mt = meta_functions.rspamd_gen_metatokens(task) |
|
|
|
|
|
|
|
for i,v in ipairs(mt) do |
|
|
|
vec[i] = v |
|
|
|
end |
|
|
|
|
|
|
|
task:process_ann_tokens(profile.symbols, vec, #mt) |
|
|
|
|
|
|
|
return vec |
|
|
|
end |
|
|
|
|
|
|
|
local function ann_scores_filter(task) |
|
|
|
|
|
|
|
for _,rule in pairs(settings.rules) do |
|
|
|
local id = '0' |
|
|
|
if rule.use_settings then |
|
|
|
local sid = task:get_settings_id() |
|
|
|
if sid then |
|
|
|
id = tostring(sid) |
|
|
|
end |
|
|
|
end |
|
|
|
if rule.per_user then |
|
|
|
local r = task:get_principal_recipient() |
|
|
|
id = id .. r |
|
|
|
local sid = task:get_settings_id() |
|
|
|
local ann |
|
|
|
local profile |
|
|
|
|
|
|
|
if sid then |
|
|
|
if rule.settings[sid] then |
|
|
|
local set = rule.settings[sid] |
|
|
|
|
|
|
|
if set.ann then |
|
|
|
ann = set.ann.ann |
|
|
|
profile = set.ann |
|
|
|
else |
|
|
|
lua_util.debugm(N, task, 'no ann loaded for %s:%s', |
|
|
|
rule.prefix, set.name) |
|
|
|
end |
|
|
|
else |
|
|
|
lua_util.debugm(N, task, 'no ann defined in %s for settings id %s', |
|
|
|
rule.prefix, sid) |
|
|
|
end |
|
|
|
else |
|
|
|
if rule.settings[-1] then |
|
|
|
local set = rule.settings[-1] |
|
|
|
|
|
|
|
if set.ann then |
|
|
|
ann = set.ann.ann |
|
|
|
profile = set.ann |
|
|
|
else |
|
|
|
lua_util.debugm(N, task, 'no ann loaded for %s:%s', |
|
|
|
rule.prefix, set.name) |
|
|
|
end |
|
|
|
else |
|
|
|
lua_util.debugm(N, task, 'no default ann for rule %s', |
|
|
|
rule.prefix) |
|
|
|
end |
|
|
|
end |
|
|
|
|
|
|
|
if rule.anns[id] and rule.anns[id].ann then |
|
|
|
local ann_data = task:get_symbols_tokens() |
|
|
|
local mt = meta_functions.rspamd_gen_metatokens(task) |
|
|
|
-- Add filtered meta tokens |
|
|
|
fun.each(function(e) table.insert(ann_data, e) end, mt) |
|
|
|
if ann then |
|
|
|
local vec = result_to_vector(task, profile) |
|
|
|
|
|
|
|
local score |
|
|
|
local out = rule.anns[id].ann:apply1(ann_data) |
|
|
|
local out = ann:apply1(vec) |
|
|
|
score = out[1] |
|
|
|
|
|
|
|
local symscore = string.format('%.3f', score) |
|
|
@@ -262,76 +327,6 @@ local function create_ann(n, nlayers) |
|
|
|
return rspamd_kann.new.kann(t) |
|
|
|
end |
|
|
|
|
|
|
|
local function create_train_ann(rule, n, id) |
|
|
|
local prefix = gen_ann_prefix(rule, id) |
|
|
|
if not rule.anns[id] then |
|
|
|
rule.anns[id] = {} |
|
|
|
end |
|
|
|
-- Fix that for flexibe layers number |
|
|
|
if rule.anns[id].ann then |
|
|
|
if not is_ann_valid(rule, prefix, rule.anns[id].ann) then |
|
|
|
rule.anns[id].ann_train = create_ann(n, rule.nlayers) |
|
|
|
rule.anns[id].ann = nil |
|
|
|
rspamd_logger.infox(rspamd_config, 'invalidate existing ANN, create train ANN %s', prefix) |
|
|
|
elseif rule.train.max_usages > 0 and |
|
|
|
rule.anns[id].version % rule.train.max_usages == 0 then |
|
|
|
-- Forget last ann |
|
|
|
rspamd_logger.infox(rspamd_config, 'recreate ANN %s, version %s', prefix, |
|
|
|
rule.anns[id].version) |
|
|
|
rule.anns[id].ann_train = create_ann(n, rule.nlayers) |
|
|
|
else |
|
|
|
rule.anns[id].ann_train = rule.anns[id].ann |
|
|
|
rspamd_logger.infox(rspamd_config, 'reuse ANN for training %s', prefix) |
|
|
|
end |
|
|
|
else |
|
|
|
rule.anns[id].ann_train = create_ann(n, rule.nlayers) |
|
|
|
rspamd_logger.infox(rspamd_config, 'create train ANN %s', prefix) |
|
|
|
rule.anns[id].version = 0 |
|
|
|
end |
|
|
|
end |
|
|
|
|
|
|
|
local function load_or_invalidate_ann(rule, data, id, ev_base) |
|
|
|
local ver = data[2] |
|
|
|
local prefix = gen_ann_prefix(rule, id) |
|
|
|
|
|
|
|
if not ver or not tonumber(ver) then |
|
|
|
rspamd_logger.errx(rspamd_config, 'cannot get version for ANN: %s', prefix) |
|
|
|
return |
|
|
|
end |
|
|
|
|
|
|
|
local err,ann_data = rspamd_util.zstd_decompress(data[1]) |
|
|
|
local ann |
|
|
|
|
|
|
|
if err or not ann_data then |
|
|
|
rspamd_logger.errx(rspamd_config, 'cannot decompress ANN %s: %s', prefix, err) |
|
|
|
return |
|
|
|
else |
|
|
|
ann = rspamd_kann.load(ann_data) |
|
|
|
end |
|
|
|
|
|
|
|
if is_ann_valid(rule, prefix, ann) then |
|
|
|
if not rule.anns[id] then rule.anns[id] = {} end |
|
|
|
rule.anns[id].ann = ann |
|
|
|
rspamd_logger.infox(rspamd_config, 'loaded ANN %s version %s from redis', |
|
|
|
prefix, ver) |
|
|
|
rule.anns[id].version = tonumber(ver) |
|
|
|
else |
|
|
|
local function redis_invalidate_cb(_err, _data) |
|
|
|
if _err then |
|
|
|
rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', prefix, _err) |
|
|
|
elseif type(_data) == 'string' then |
|
|
|
rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', prefix, _err) |
|
|
|
rule.anns[id].version = 0 |
|
|
|
end |
|
|
|
end |
|
|
|
-- Invalidate ANN |
|
|
|
rspamd_logger.infox(rspamd_config, 'invalidate ANN %s', prefix) |
|
|
|
lua_redis.exec_redis_script(redis_maybe_invalidate_id, |
|
|
|
{ev_base = ev_base, is_write = true}, |
|
|
|
redis_invalidate_cb, |
|
|
|
{prefix}) |
|
|
|
end |
|
|
|
end |
|
|
|
|
|
|
|
local function ann_train_callback(rule, task, score, required_score, id) |
|
|
|
local train_opts = rule['train'] |
|
|
@@ -901,6 +896,7 @@ local function cleanup_anns(rule, cfg, ev_base) |
|
|
|
end |
|
|
|
end |
|
|
|
end |
|
|
|
|
|
|
|
lua_redis.exec_redis_script(redis_maybe_invalidate_id, |
|
|
|
{ev_base = ev_base, is_write = true}, |
|
|
|
invalidate_cb, |
|
|
@@ -1095,6 +1091,7 @@ for k,rule in pairs(settings.rules) do |
|
|
|
-- We also want to train neural nets when they have enough data |
|
|
|
rspamd_config:add_periodic(ev_base, 0.0, |
|
|
|
function(_, _) |
|
|
|
-- Clean old ANNs |
|
|
|
cleanup_anns(rule, cfg, ev_base) |
|
|
|
return maybe_train_anns(rule, cfg, ev_base, worker) |
|
|
|
end) |