
[Project] Neural: Implement scoring

tags/2.0
Vsevolod Stakhov 5 years ago
parent
commit
933c82f6ed
1 changed file with 83 additions and 86 deletions
src/plugins/lua/neural.lua  +83 −86

@@ -50,6 +50,26 @@ local default_options = {
symbol_ham = 'NEURAL_HAM',
}

-- Rule structure:
-- * static config fields (see `default_options`)
-- * prefix - name or defined prefix
-- * settings - table of settings indexed by settings id; -1 is used when no settings are defined

-- A rule settings element defines the data for a specific settings id:
-- * symbols - static symbols profile (defined by config or extracted from symcache)
-- * name - name of settings id
-- * digest - digest of all symbols
-- * ann - dynamic ANN configuration loaded from Redis
-- * train - training data for the ANN (e.g. the ANN currently being trained)

-- The settings ANN table is loaded from Redis and represents the dynamic profile for an ANN.
-- Some elements are stored directly in Redis; the ANN itself is, in turn, loaded dynamically:
-- * version - version of ANN loaded from redis
-- * ann_key - name of ANN key in Redis
-- * symbols - symbols in THIS PARTICULAR ANN (might be different from set.symbols)
-- * distance - distance between set.symbols and set.ann.symbols
-- * ann - kann object
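--
-- Illustrative sketch of the structures described above; all field values
-- below are hypothetical and not taken from a real configuration:
--
--   rule.settings[-1] = {
--     name = 'default',
--     symbols = { 'SYMBOL_A', 'SYMBOL_B' },   -- static symbols profile
--     digest = 'deadbeef',                    -- digest of all symbols
--     ann = {                                 -- dynamic part loaded from Redis
--       version = 2,
--       ann_key = 'rn_default_key',           -- hypothetical Redis key name
--       symbols = { 'SYMBOL_A', 'SYMBOL_B' }, -- symbols in THIS PARTICULAR ANN
--       distance = 0,
--       ann = kann_object,                    -- rspamd_kann network
--     },
--   }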

local settings = {
rules = {},
prefix = 'rn', -- Neural network default prefix
@@ -212,30 +232,75 @@ local function load_scripts(params)
params)
end

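-- Converts task results into an ANN input vector for the given profile:
-- metatokens come first, followed by per-symbol tokens for the symbols of
-- this particular profile. The zeros template is cached on the profile so
-- that it is built only once.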
local function result_to_vector(task, profile)
if not profile.zeros then
-- Fill zeros vector
local zeros = {}
for i=1,meta_functions.count_metatokens() do
zeros[i] = 0.0
end
for _,_ in ipairs(profile.symbols) do
zeros[#zeros + 1] = 0.0
end
profile.zeros = zeros
end

local vec = lua_util.shallowcopy(profile.zeros)
local mt = meta_functions.rspamd_gen_metatokens(task)

for i,v in ipairs(mt) do
vec[i] = v
end

task:process_ann_tokens(profile.symbols, vec, #mt)

return vec
end

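-- Computes the neural score for a task using the ANN attached to its
-- settings id (falling back to the default profile stored under id -1).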
local function ann_scores_filter(task)

for _,rule in pairs(settings.rules) do
local id = '0'
if rule.use_settings then
local sid = task:get_settings_id()
if sid then
id = tostring(sid)
end
end
if rule.per_user then
local r = task:get_principal_recipient()
id = id .. r
local sid = task:get_settings_id()
local ann
local profile

if sid then
if rule.settings[sid] then
local set = rule.settings[sid]

if set.ann then
ann = set.ann.ann
profile = set.ann
else
lua_util.debugm(N, task, 'no ann loaded for %s:%s',
rule.prefix, set.name)
end
else
lua_util.debugm(N, task, 'no ann defined in %s for settings id %s',
rule.prefix, sid)
end
else
if rule.settings[-1] then
local set = rule.settings[-1]

if set.ann then
ann = set.ann.ann
profile = set.ann
else
lua_util.debugm(N, task, 'no ann loaded for %s:%s',
rule.prefix, set.name)
end
else
lua_util.debugm(N, task, 'no default ann for rule %s',
rule.prefix)
end
end

if rule.anns[id] and rule.anns[id].ann then
local ann_data = task:get_symbols_tokens()
local mt = meta_functions.rspamd_gen_metatokens(task)
-- Add filtered meta tokens
fun.each(function(e) table.insert(ann_data, e) end, mt)
if ann then
local vec = result_to_vector(task, profile)

local score
local out = rule.anns[id].ann:apply1(ann_data)
local out = ann:apply1(vec)
score = out[1]

local symscore = string.format('%.3f', score)
@@ -262,76 +327,6 @@ local function create_ann(n, nlayers)
return rspamd_kann.new.kann(t)
end

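-- Prepares an ANN for training: reuses the currently loaded ANN when it is
-- still valid, recreates it when it is invalid or when its version is a
-- multiple of train.max_usages, and creates a fresh network when no ANN is
-- loaded yet.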
local function create_train_ann(rule, n, id)
local prefix = gen_ann_prefix(rule, id)
if not rule.anns[id] then
rule.anns[id] = {}
end
-- Fix that for a flexible number of layers
if rule.anns[id].ann then
if not is_ann_valid(rule, prefix, rule.anns[id].ann) then
rule.anns[id].ann_train = create_ann(n, rule.nlayers)
rule.anns[id].ann = nil
rspamd_logger.infox(rspamd_config, 'invalidate existing ANN, create train ANN %s', prefix)
elseif rule.train.max_usages > 0 and
rule.anns[id].version % rule.train.max_usages == 0 then
-- Forget last ann
rspamd_logger.infox(rspamd_config, 'recreate ANN %s, version %s', prefix,
rule.anns[id].version)
rule.anns[id].ann_train = create_ann(n, rule.nlayers)
else
rule.anns[id].ann_train = rule.anns[id].ann
rspamd_logger.infox(rspamd_config, 'reuse ANN for training %s', prefix)
end
else
rule.anns[id].ann_train = create_ann(n, rule.nlayers)
rspamd_logger.infox(rspamd_config, 'create train ANN %s', prefix)
rule.anns[id].version = 0
end
end

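-- Loads a zstd-compressed ANN and its version from the supplied Redis data;
-- if the decoded ANN does not pass validation for this rule, it is
-- invalidated via the redis_maybe_invalidate script instead.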
local function load_or_invalidate_ann(rule, data, id, ev_base)
local ver = data[2]
local prefix = gen_ann_prefix(rule, id)

if not ver or not tonumber(ver) then
rspamd_logger.errx(rspamd_config, 'cannot get version for ANN: %s', prefix)
return
end

local err,ann_data = rspamd_util.zstd_decompress(data[1])
local ann

if err or not ann_data then
rspamd_logger.errx(rspamd_config, 'cannot decompress ANN %s: %s', prefix, err)
return
else
ann = rspamd_kann.load(ann_data)
end

if is_ann_valid(rule, prefix, ann) then
if not rule.anns[id] then rule.anns[id] = {} end
rule.anns[id].ann = ann
rspamd_logger.infox(rspamd_config, 'loaded ANN %s version %s from redis',
prefix, ver)
rule.anns[id].version = tonumber(ver)
else
local function redis_invalidate_cb(_err, _data)
if _err then
rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', prefix, _err)
elseif type(_data) == 'string' then
rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', prefix, _err)
rule.anns[id].version = 0
end
end
-- Invalidate ANN
rspamd_logger.infox(rspamd_config, 'invalidate ANN %s', prefix)
lua_redis.exec_redis_script(redis_maybe_invalidate_id,
{ev_base = ev_base, is_write = true},
redis_invalidate_cb,
{prefix})
end
end

local function ann_train_callback(rule, task, score, required_score, id)
local train_opts = rule['train']
@@ -901,6 +896,7 @@ local function cleanup_anns(rule, cfg, ev_base)
end
end
end

lua_redis.exec_redis_script(redis_maybe_invalidate_id,
{ev_base = ev_base, is_write = true},
invalidate_cb,
@@ -1095,6 +1091,7 @@ for k,rule in pairs(settings.rules) do
-- We also want to train neural nets when they have enough data
rspamd_config:add_periodic(ev_base, 0.0,
function(_, _)
-- Clean old ANNs
cleanup_anns(rule, cfg, ev_base)
return maybe_train_anns(rule, cfg, ev_base, worker)
end)
