diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/plugins/lua/neural.lua | 169 |
1 files changed, 83 insertions, 86 deletions
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua index cca6f647c..fdb138321 100644 --- a/src/plugins/lua/neural.lua +++ b/src/plugins/lua/neural.lua @@ -50,6 +50,26 @@ local default_options = { symbol_ham = 'NEURAL_HAM', } +-- Rule structure: +-- * static config fields (see `default_options`) +-- * prefix - name or defined prefix +-- * settings - table of settings indexed by settings id, -1 is used when no settings defined + +-- Rule settings element defines elements for specific settings id: +-- * symbols - static symbols profile (defined by config or extracted from symcache) +-- * name - name of settings id +-- * digest - digest of all symbols +-- * ann - dynamic ANN configuration loaded from Redis +-- * train - train data for ANN (e.g. the currently trained ANN) + +-- Settings ANN table is loaded from Redis and represents dynamic profile for ANN +-- Some elements are directly stored in Redis, ANN is, in turn loaded dynamically +-- * version - version of ANN loaded from redis +-- * ann_key - name of ANN key in Redis +-- * symbols - symbols in THIS PARTICULAR ANN (might be different from set.symbols) +-- * distance - distance between set.symbols and set.ann.symbols +-- * ann - kann object + local settings = { rules = {}, prefix = 'rn', -- Neural network default prefix @@ -212,30 +232,75 @@ local function load_scripts(params) params) end +local function result_to_vector(task, profile) + if not profile.zeros then + -- Fill zeros vector + local zeros = {} + for i=1,meta_functions.count_metatokens() do + zeros[i] = 0.0 + end + for _,_ in ipairs(profile.symbols) do + zeros[#zeros + 1] = 0.0 + end + profile.zeros = zeros + end + + local vec = lua_util.shallowcopy(profile.zeros) + local mt = meta_functions.rspamd_gen_metatokens(task) + + for i,v in ipairs(mt) do + vec[i] = v + end + + task:process_ann_tokens(profile.symbols, vec, #mt) + + return vec +end local function ann_scores_filter(task) for _,rule in pairs(settings.rules) do - local id = '0' - if rule.use_settings then - local sid = task:get_settings_id() - if sid then - id = tostring(sid) - end - end - if rule.per_user then - local r = task:get_principal_recipient() - id = id .. r + local sid = task:get_settings_id() + local ann + local profile + + if sid then + if rule.settings[sid] then + local set = rule.settings[sid] + + if set.ann then + ann = set.ann.ann + profile = set.ann + else + lua_util.debugm(N, task, 'no ann loaded for %s:%s', + rule.prefix, set.name) + end + else + lua_util.debugm(N, task, 'no ann defined in %s for settings id %s', + rule.prefix, sid) + end + else + if rule.settings[-1] then + local set = rule.settings[-1] + + if set.ann then + ann = set.ann.ann + profile = set.ann + else + lua_util.debugm(N, task, 'no ann loaded for %s:%s', + rule.prefix, set.name) + end + else + lua_util.debugm(N, task, 'no default ann for rule %s', + rule.prefix) + end end - if rule.anns[id] and rule.anns[id].ann then - local ann_data = task:get_symbols_tokens() - local mt = meta_functions.rspamd_gen_metatokens(task) - -- Add filtered meta tokens - fun.each(function(e) table.insert(ann_data, e) end, mt) + if ann then + local vec = result_to_vector(task, profile) local score - local out = rule.anns[id].ann:apply1(ann_data) + local out = ann:apply1(vec) score = out[1] local symscore = string.format('%.3f', score) @@ -262,76 +327,6 @@ local function create_ann(n, nlayers) return rspamd_kann.new.kann(t) end -local function create_train_ann(rule, n, id) - local prefix = gen_ann_prefix(rule, id) - if not rule.anns[id] then - rule.anns[id] = {} - end - -- Fix that for flexibe layers number - if rule.anns[id].ann then - if not is_ann_valid(rule, prefix, rule.anns[id].ann) then - rule.anns[id].ann_train = create_ann(n, rule.nlayers) - rule.anns[id].ann = nil - rspamd_logger.infox(rspamd_config, 'invalidate existing ANN, create train ANN %s', prefix) - elseif rule.train.max_usages > 0 and - rule.anns[id].version % rule.train.max_usages == 0 then - -- Forget last ann - rspamd_logger.infox(rspamd_config, 'recreate ANN %s, version %s', prefix, - rule.anns[id].version) - rule.anns[id].ann_train = create_ann(n, rule.nlayers) - else - rule.anns[id].ann_train = rule.anns[id].ann - rspamd_logger.infox(rspamd_config, 'reuse ANN for training %s', prefix) - end - else - rule.anns[id].ann_train = create_ann(n, rule.nlayers) - rspamd_logger.infox(rspamd_config, 'create train ANN %s', prefix) - rule.anns[id].version = 0 - end -end - -local function load_or_invalidate_ann(rule, data, id, ev_base) - local ver = data[2] - local prefix = gen_ann_prefix(rule, id) - - if not ver or not tonumber(ver) then - rspamd_logger.errx(rspamd_config, 'cannot get version for ANN: %s', prefix) - return - end - - local err,ann_data = rspamd_util.zstd_decompress(data[1]) - local ann - - if err or not ann_data then - rspamd_logger.errx(rspamd_config, 'cannot decompress ANN %s: %s', prefix, err) - return - else - ann = rspamd_kann.load(ann_data) - end - - if is_ann_valid(rule, prefix, ann) then - if not rule.anns[id] then rule.anns[id] = {} end - rule.anns[id].ann = ann - rspamd_logger.infox(rspamd_config, 'loaded ANN %s version %s from redis', - prefix, ver) - rule.anns[id].version = tonumber(ver) - else - local function redis_invalidate_cb(_err, _data) - if _err then - rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', prefix, _err) - elseif type(_data) == 'string' then - rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', prefix, _err) - rule.anns[id].version = 0 - end - end - -- Invalidate ANN - rspamd_logger.infox(rspamd_config, 'invalidate ANN %s', prefix) - lua_redis.exec_redis_script(redis_maybe_invalidate_id, - {ev_base = ev_base, is_write = true}, - redis_invalidate_cb, - {prefix}) - end -end local function ann_train_callback(rule, task, score, required_score, id) local train_opts = rule['train'] @@ -901,6 +896,7 @@ local function cleanup_anns(rule, cfg, ev_base) end end end + lua_redis.exec_redis_script(redis_maybe_invalidate_id, {ev_base = ev_base, is_write = true}, invalidate_cb, @@ -1095,6 +1091,7 @@ for k,rule in pairs(settings.rules) do -- We also want to train neural nets when they have enough data rspamd_config:add_periodic(ev_base, 0.0, function(_, _) + -- Clean old ANNs cleanup_anns(rule, cfg, ev_base) return maybe_train_anns(rule, cfg, ev_base, worker) end) |