diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2019-07-05 18:46:30 +0100 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2019-07-05 18:46:30 +0100 |
commit | 178cb973ce7de1738fecdf5ca3e6ebf43bc51020 (patch) | |
tree | 45dc9df1daf82a5c054ded12499a8401f20ce4ff /src/plugins | |
parent | 6ba7b6d32f6bde52b22c4e52381384495b52154a (diff) | |
download | rspamd-178cb973ce7de1738fecdf5ca3e6ebf43bc51020.tar.gz rspamd-178cb973ce7de1738fecdf5ca3e6ebf43bc51020.zip |
[Project] Neural: Start new NN profiles implementation
Diffstat (limited to 'src/plugins')
-rw-r--r-- | src/plugins/lua/multimap.lua | 2 | ||||
-rw-r--r-- | src/plugins/lua/neural.lua | 315 |
2 files changed, 241 insertions, 76 deletions
diff --git a/src/plugins/lua/multimap.lua b/src/plugins/lua/multimap.lua index 5db8d4680..9c4861e42 100644 --- a/src/plugins/lua/multimap.lua +++ b/src/plugins/lua/multimap.lua @@ -265,7 +265,7 @@ local function apply_addr_filter(task, filter, input, rule) end else -- regexp case - if not rule['re_filter'] then + if not rule['re_filter'] then local type,pat = string.match(filter, '(regexp:)(.+)') if type and pat then rule['re_filter'] = regexp.create(pat) diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua index dbd420257..ff53249c5 100644 --- a/src/plugins/lua/neural.lua +++ b/src/plugins/lua/neural.lua @@ -25,6 +25,7 @@ local rspamd_kann = require "rspamd_kann" local lua_redis = require "lua_redis" local lua_util = require "lua_util" local fun = require "fun" +local lua_settings = require "lua_settings" local meta_functions = require "lua_meta" local N = "neural" @@ -41,10 +42,7 @@ local default_options = { learn_threads = 1, learning_rate = 0.01, }, - use_settings = false, - per_user = false, watch_interval = 60.0, - nlayers = 4, lock_expire = 600, learning_spawned = false, ann_expire = 60 * 60 * 24 * 2, -- 2 days @@ -53,7 +51,9 @@ local default_options = { } local settings = { - rules = {} + rules = {}, + prefix = 'rn', -- Neural network default prefix + max_profiles = 3, -- Maximum number of NN profiles stored } local opts = rspamd_config:get_all_opt("neural") @@ -124,20 +124,23 @@ local redis_lua_script_maybe_load = [[ ]] local redis_maybe_load_id = nil --- Lua script to invalidate ANN from redis +-- Lua script to invalidate ANNs by rank -- Uses the following keys -- key1 - prefix for keys +-- key2 - number of elements to leave local redis_lua_script_maybe_invalidate = [[ - local locked = redis.call('GET', KEYS[1] .. '_locked') - if locked then return false end - redis.call('SET', KEYS[1] .. '_locked', '1') - redis.call('SET', KEYS[1] .. '_version', '0') - redis.call('DEL', KEYS[1] .. '_spam') - redis.call('DEL', KEYS[1] .. '_ham') - redis.call('DEL', KEYS[1] .. '_data') - redis.call('DEL', KEYS[1] .. '_locked') - redis.call('DEL', KEYS[1] .. '_hostname') - return 1 + local card = redis.call('ZCARD', KEYS[1]) + if card > tonumber(KEYS[2]) then + local to_delete = redis.call('ZRANGE', KEYS[1], 0, (-(tonumber(KEYS[2] - 1))) + for _,k in ipairs(to_delete) do + local tb = cjson.decode(k) + redis.call('DEL', tb.ann_key) + end + redis.call('ZREMRANGEBYRANK', KEYS[1], 0, (-(tonumber(KEYS[2] - 1))) + return to_delete + else + return {} + end ]] local redis_maybe_invalidate_id = nil @@ -209,23 +212,6 @@ local function load_scripts(params) params) end -local function gen_ann_prefix(rule, id) - local cksum = rspamd_config:get_symbols_cksum():hex() - -- We also need to count metatokens: - local n = meta_functions.rspamd_count_metatokens() - local tprefix = 'k' - if id then - return string.format('%s%s%s%d%s', tprefix, rule.prefix, cksum, n, id), id - else - return string.format('%s%s%s%d', tprefix, rule.prefix, cksum, n), nil - end -end - -local function is_ann_valid(rule, prefix, ann) - if ann then - return true - end -end local function ann_scores_filter(task) @@ -732,57 +718,153 @@ local function maybe_train_anns(rule, cfg, ev_base, worker) return rule.watch_interval end -local function check_anns(rule, _, ev_base) - local function members_cb(err, data) - if err then - rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', - err) - elseif type(data) == 'table' then - fun.each(function(elt) - elt = tostring(elt) - local redis_update_cb = function(_err, _data) - if _err then - rspamd_logger.errx(rspamd_config, 'cannot get FANN version %s from redis: %s', - elt, _err) - elseif _data and type(_data) == 'table' then - load_or_invalidate_ann(rule, _data, elt, ev_base) - else - if type(_data) ~= 'number' then - rspamd_logger.errx(rspamd_config, 'invalid ANN type returned from Redis: %s; prefix: %s', - type(_data), elt) - end - end +-- This function loads new ann from Redis +-- This is based on `profile` attribute. +-- ANN is loaded from `profile.ann_key` +-- Rank of `profile` key is also increased, unfortunately, it means that we need to +-- serialize profile one more time and set its rank to the current time +-- set.ann fields are set according to Redis data received +local function load_new_ann(rule, ev_base, set, profile, min_diff) + +end + +-- Used to check an element in Redis serialized as JSON +-- for some specific rule + some specific setting +-- This function tries to load more fresh or more specific ANNs in lieu of +-- the existing ones. +local function process_existing_ann(rule, ev_base, set, profiles) + local my_symbols = set.symbols + local min_diff = math.huge + local sel_elt + + for _,elt in fun.iter(profiles) do + if elt and elt.symbols then + local dist = lua_util.distance_sorted(elt.symbols, my_symbols) + + if dist < #my_symbols * .3 then + if dist < min_diff then + min_diff = dist + sel_elt = elt end + end + end + end - local local_ver = 0 - if rule.anns[elt] then - if rule.anns[elt].version then - local_ver = rule.anns[elt].version - end + if sel_elt then + -- We can load element from ANN + if set.ann then + -- We have an existing ANN, probably the same... + if set.ann.digest == sel_elt.digest then + -- Same ANN, check version + if set.ann.version < sel_elt.version then + -- Load new ann + rspamd_logger.infox(rspamd_config, 'ann %s is changed,' .. + 'our version = %s, remote version = %s', + rule.prefix .. ':' .. set.name, + set.ann.version, + sel_elt.version) + load_new_ann(rule, ev_base, set, sel_elt, min_diff) + else + lua_util.debugm(N, rspamd_config, 'ann %s is not changed,' .. + 'our version = %s, remote version = %s', + rule.prefix .. ':' .. set.name, + set.ann.version, + sel_elt.version) end - lua_redis.exec_redis_script(redis_maybe_load_id, - {ev_base = ev_base, is_write = false}, - redis_update_cb, - {gen_ann_prefix(rule, elt), tostring(local_ver)}) - end, - data) + else + -- We have some different ANN, so we need to compare distance + if set.ann.distance > min_diff then + -- Load more specific ANN + rspamd_logger.infox(rspamd_config, 'more specific ann is available for %s,' .. + 'our distance = %s, remote distance = %s', + rule.prefix .. ':' .. set.name, + set.ann.distance, + min_diff) + load_new_ann(rule, ev_base, set, sel_elt, min_diff) + else + lua_util.debugm(N, rspamd_config, 'ann %s is not changed or less specific,' .. + 'our distance = %s, remote distance = %s', + rule.prefix .. ':' .. set.name, + set.ann.distance, + min_diff) + end + end + else + -- We have no ANN, load new one + load_new_ann(rule, ev_base, set, sel_elt, min_diff) end end +end - -- First we need to get all anns stored in our Redis - lua_redis.redis_make_request_taskless(ev_base, - rspamd_config, - rule.redis, - nil, - false, -- is write - members_cb, --callback - 'SMEMBERS', -- command - {gen_ann_prefix(rule, nil)} -- arguments - ) +-- Used to deserialise ANN element from a list +local function load_ann_profile(element) + local ucl = require "ucl" + + local parser = ucl.parser() + local res,ucl_err = parser:parse_string(element) + if not res then + rspamd_logger.warnx(rspamd_config, 'cannot parse ANN from redis: %s', + ucl_err) + return nil + else + return parser:get_object() + end +end + +-- Function to check or load ANNs from Redis +local function check_anns(rule, cfg, ev_base) + for _,set in pairs(rule.settings) do + local function members_cb(err, data) + if err then + rspamd_logger.errx(cfg, 'cannot get ANNs list from redis: %s', + err) + elseif type(data) == 'table' then + process_existing_ann(rule, ev_base, set, fun.map(load_ann_profile, data)) + end + end + + -- Extract all profiles for some specific settings id + -- Get the last `max_profiles` recently used + -- Select the most appropriate to our profile but it should not differ by more + -- than 30% of symbols + lua_redis.redis_make_request_taskless(ev_base, + cfg, + rule.redis, + nil, + false, -- is write + members_cb, --callback + 'ZREVRANGE', -- command + {set.prefix, '0', tostring(settings.max_profiles)} -- arguments + ) + end -- Cycle over all settings return rule.watch_interval end +-- Function to clean up old ANNs +local function cleanup_anns(rule, cfg, ev_base) + for _,set in pairs(rule.settings) do + local function invalidate_cb(err, data) + if err then + rspamd_logger.errx(cfg, 'cannot exec invalidate script in redis: %s', + err) + elseif type(data) == 'table' then + for _,expired in ipairs(data) do + local profile = load_ann_profile(expired) + rspamd_logger.infox(cfg, 'invalidated ANN for %s; redis key: %s; version=%s', + rule.prefix .. ':' .. set.name, + profile.ann_key, + profile.version) + end + end + end + lua_redis.exec_redis_script(redis_maybe_invalidate_id, + {ev_base = ev_base, is_write = true}, + invalidate_cb, + {set.prefix, tostring(settings.max_profiles)}) + end +end + local function ann_push_vector(task) if task:has_flag('skip') then return end if not settings.allow_local and lua_util.is_rspamc_or_controller(task) then return end @@ -800,6 +882,83 @@ local function ann_push_vector(task) end end + +-- Generate redis prefix for specific rule and specific settings +local function redis_ann_prefix(rule, settings_name) + -- We also need to count metatokens: + local n = meta_functions.version + return string.format('%s_%s_%d_%s', + settings.prefix, rule.prefix, n, settings_name) +end + +-- This function is used to adjust profiles and allowed setting ids for each rule +-- It must be called when all settings are already registered (e.g. at post-init for config) +local function process_rules_settings() + local function process_settings_elt(rule, selt) + local profile = rule.profile[selt.name] + if profile then + -- Use static user defined profile + -- Ensure that we have an array... + lua_util.debugm(N, rspamd_config, "use static profile for %s (%s)", + rule.prefix, selt.name) + if not profile[1] then profile = lua_util.keys(profile) end + selt.symbols = profile + else + lua_util.debugm(N, rspamd_config, "use dynamic cfg based profile for %s (%s)", + rule.prefix, selt.name) + end + + -- Generic stuff + table.sort(selt.symbols) + selt.digest = lua_util.table_digest(selt.symbols) + selt.prefix = redis_ann_prefix(rule, selt.name) + + lua_redis.register_prefix(selt.prefix, N, + string.format('NN prefix for rule "%s"; settings id "%s"', + rule.prefix, selt.name)) + end + + for _,rule in pairs(opts.rules) do + if not rule.allowed_settings then + -- Extract all settings ids + rule.allowed_settings = lua_util.keys(lua_settings.all_settings) + end + + -- Convert to a map <setting_id> -> true + rule.allowed_settings = lua_util.list_to_hash(rule.allowed_settings) + + -- Check if we can work without settings + if type(rule.default) ~= 'boolean' then + rule.default = true + end + + rule.settings = {} + + if rule.default then + local default_settings = { + symbols = lua_util.keys(lua_settings.default_symbols), + name = 'default' + } + + process_settings_elt(rule, default_settings) + rule.settings[-1] = default_settings -- Magic constant, but OK as settings are positive int32 + end + + -- Now, for each allowed settings, we store sorted symbols + digest + -- We set table rule.settings[id] -> { name = name, symbols = symbols, digest = digest } + for s,_ in pairs(rule.allowed_settings) do + -- Here, we have a name, set of symbols and + local selt = lua_settings.settings_by_id(s) + rule.settings[s] = { + symbols = selt.symbols, -- Already sorted + name = selt.name + } + + process_settings_elt(rule, rule.settings[s]) + end + end +end + redis_params = lua_redis.parse_redis_server('neural') if not redis_params then @@ -818,7 +977,7 @@ local rules = opts['rules'] if not rules then -- Use legacy configuration rules = {} - rules['RFANN'] = opts + rules['default'] = opts end local id = rspamd_config:register_symbol({ @@ -827,6 +986,7 @@ local id = rspamd_config:register_symbol({ priority = 6, callback = ann_scores_filter }) + for k,r in pairs(rules) do local def_rules = lua_util.override_defaults(default_options, r) def_rules['redis'] = redis_params @@ -841,6 +1001,7 @@ for k,r in pairs(rules) do if def_rules.train.max_train then def_rules.train.max_trains = def_rules.train.max_train end + rspamd_logger.infox(rspamd_config, "register ann rule %s", k) settings.rules[k] = def_rules rspamd_config:set_metric_symbol({ @@ -876,8 +1037,11 @@ rspamd_config:register_symbol({ }) -- Add training scripts -for _,rule in pairs(settings.rules) do +for k,rule in pairs(settings.rules) do load_scripts(rule.redis) + -- We also need to deal with settings + rspamd_config:add_post_init(process_rules_settings) + -- This function will check ANNs in Redis when a worker is loaded rspamd_config:add_on_load(function(cfg, ev_base, worker) rspamd_config:add_periodic(ev_base, 0.0, function(_, _) @@ -888,6 +1052,7 @@ for _,rule in pairs(settings.rules) do -- We also want to train neural nets when they have enough data rspamd_config:add_periodic(ev_base, 0.0, function(_, _) + cleanup_anns(rule, cfg, ev_base) return maybe_train_anns(rule, cfg, ev_base, worker) end) end |