Parcourir la source

[Project] Neural: Start new NN profiles implementation

tags/2.0
Vsevolod Stakhov il y a 4 ans
Parent
révision
178cb973ce
2 fichiers modifiés avec 241 ajouts et 75 suppressions
  1. 1
    1
      src/plugins/lua/multimap.lua
  2. 240
    74
      src/plugins/lua/neural.lua

+ 1
- 1
src/plugins/lua/multimap.lua Voir le fichier

@@ -265,7 +265,7 @@ local function apply_addr_filter(task, filter, input, rule)
end
else
-- regexp case
if not rule['re_filter'] then
if not rule['re_filter'] then
local type,pat = string.match(filter, '(regexp:)(.+)')
if type and pat then
rule['re_filter'] = regexp.create(pat)

+ 240
- 74
src/plugins/lua/neural.lua Voir le fichier

@@ -25,6 +25,7 @@ local rspamd_kann = require "rspamd_kann"
local lua_redis = require "lua_redis"
local lua_util = require "lua_util"
local fun = require "fun"
local lua_settings = require "lua_settings"
local meta_functions = require "lua_meta"
local N = "neural"

@@ -41,10 +42,7 @@ local default_options = {
learn_threads = 1,
learning_rate = 0.01,
},
use_settings = false,
per_user = false,
watch_interval = 60.0,
nlayers = 4,
lock_expire = 600,
learning_spawned = false,
ann_expire = 60 * 60 * 24 * 2, -- 2 days
@@ -53,7 +51,9 @@ local default_options = {
}

local settings = {
rules = {}
rules = {},
prefix = 'rn', -- Neural network default prefix
max_profiles = 3, -- Maximum number of NN profiles stored
}

local opts = rspamd_config:get_all_opt("neural")
@@ -124,20 +124,23 @@ local redis_lua_script_maybe_load = [[
]]
local redis_maybe_load_id = nil

-- Lua script to invalidate ANNs by rank
-- Uses the following keys
-- key1 - prefix for keys
-- key2 - number of elements to leave
local redis_lua_script_maybe_invalidate = [[
local locked = redis.call('GET', KEYS[1] .. '_locked')
if locked then return false end
redis.call('SET', KEYS[1] .. '_locked', '1')
redis.call('SET', KEYS[1] .. '_version', '0')
redis.call('DEL', KEYS[1] .. '_spam')
redis.call('DEL', KEYS[1] .. '_ham')
redis.call('DEL', KEYS[1] .. '_data')
redis.call('DEL', KEYS[1] .. '_locked')
redis.call('DEL', KEYS[1] .. '_hostname')
return 1
local card = redis.call('ZCARD', KEYS[1])
if card > tonumber(KEYS[2]) then
local to_delete = redis.call('ZRANGE', KEYS[1], 0, (-(tonumber(KEYS[2] - 1)))
for _,k in ipairs(to_delete) do
local tb = cjson.decode(k)
redis.call('DEL', tb.ann_key)
end
redis.call('ZREMRANGEBYRANK', KEYS[1], 0, (-(tonumber(KEYS[2] - 1)))
return to_delete
else
return {}
end
]]
local redis_maybe_invalidate_id = nil

@@ -209,23 +212,6 @@ local function load_scripts(params)
params)
end

local function gen_ann_prefix(rule, id)
local cksum = rspamd_config:get_symbols_cksum():hex()
-- We also need to count metatokens:
local n = meta_functions.rspamd_count_metatokens()
local tprefix = 'k'
if id then
return string.format('%s%s%s%d%s', tprefix, rule.prefix, cksum, n, id), id
else
return string.format('%s%s%s%d', tprefix, rule.prefix, cksum, n), nil
end
end

local function is_ann_valid(rule, prefix, ann)
if ann then
return true
end
end

local function ann_scores_filter(task)

@@ -732,57 +718,153 @@ local function maybe_train_anns(rule, cfg, ev_base, worker)
return rule.watch_interval
end

local function check_anns(rule, _, ev_base)
local function members_cb(err, data)
if err then
rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s',
err)
elseif type(data) == 'table' then
fun.each(function(elt)
elt = tostring(elt)
local redis_update_cb = function(_err, _data)
if _err then
rspamd_logger.errx(rspamd_config, 'cannot get FANN version %s from redis: %s',
elt, _err)
elseif _data and type(_data) == 'table' then
load_or_invalidate_ann(rule, _data, elt, ev_base)
else
if type(_data) ~= 'number' then
rspamd_logger.errx(rspamd_config, 'invalid ANN type returned from Redis: %s; prefix: %s',
type(_data), elt)
end
end
-- This function loads new ann from Redis
-- This is based on `profile` attribute.
-- ANN is loaded from `profile.ann_key`
-- Rank of `profile` key is also increased, unfortunately, it means that we need to
-- serialize profile one more time and set its rank to the current time
-- set.ann fields are set according to Redis data received
local function load_new_ann(rule, ev_base, set, profile, min_diff)

end

-- Used to check an element in Redis serialized as JSON
-- for some specific rule + some specific setting
-- This function tries to load more fresh or more specific ANNs in lieu of
-- the existing ones.
local function process_existing_ann(rule, ev_base, set, profiles)
local my_symbols = set.symbols
local min_diff = math.huge
local sel_elt

for _,elt in fun.iter(profiles) do
if elt and elt.symbols then
local dist = lua_util.distance_sorted(elt.symbols, my_symbols)

if dist < #my_symbols * .3 then
if dist < min_diff then
min_diff = dist
sel_elt = elt
end
end
end
end

local local_ver = 0
if rule.anns[elt] then
if rule.anns[elt].version then
local_ver = rule.anns[elt].version
end
if sel_elt then
-- We can load element from ANN
if set.ann then
-- We have an existing ANN, probably the same...
if set.ann.digest == sel_elt.digest then
-- Same ANN, check version
if set.ann.version < sel_elt.version then
-- Load new ann
rspamd_logger.infox(rspamd_config, 'ann %s is changed,' ..
'our version = %s, remote version = %s',
rule.prefix .. ':' .. set.name,
set.ann.version,
sel_elt.version)
load_new_ann(rule, ev_base, set, sel_elt, min_diff)
else
lua_util.debugm(N, rspamd_config, 'ann %s is not changed,' ..
'our version = %s, remote version = %s',
rule.prefix .. ':' .. set.name,
set.ann.version,
sel_elt.version)
end
lua_redis.exec_redis_script(redis_maybe_load_id,
{ev_base = ev_base, is_write = false},
redis_update_cb,
{gen_ann_prefix(rule, elt), tostring(local_ver)})
end,
data)
else
-- We have some different ANN, so we need to compare distance
if set.ann.distance > min_diff then
-- Load more specific ANN
rspamd_logger.infox(rspamd_config, 'more specific ann is available for %s,' ..
'our distance = %s, remote distance = %s',
rule.prefix .. ':' .. set.name,
set.ann.distance,
min_diff)
load_new_ann(rule, ev_base, set, sel_elt, min_diff)
else
lua_util.debugm(N, rspamd_config, 'ann %s is not changed or less specific,' ..
'our distance = %s, remote distance = %s',
rule.prefix .. ':' .. set.name,
set.ann.distance,
min_diff)
end
end
else
-- We have no ANN, load new one
load_new_ann(rule, ev_base, set, sel_elt, min_diff)
end
end
end

-- First we need to get all anns stored in our Redis
lua_redis.redis_make_request_taskless(ev_base,
rspamd_config,
rule.redis,
nil,
false, -- is write
members_cb, --callback
'SMEMBERS', -- command
{gen_ann_prefix(rule, nil)} -- arguments
)
-- Used to deserialise ANN element from a list
local function load_ann_profile(element)
local ucl = require "ucl"

local parser = ucl.parser()
local res,ucl_err = parser:parse_string(element)
if not res then
rspamd_logger.warnx(rspamd_config, 'cannot parse ANN from redis: %s',
ucl_err)
return nil
else
return parser:get_object()
end
end

-- Function to check or load ANNs from Redis
local function check_anns(rule, cfg, ev_base)
for _,set in pairs(rule.settings) do
local function members_cb(err, data)
if err then
rspamd_logger.errx(cfg, 'cannot get ANNs list from redis: %s',
err)
elseif type(data) == 'table' then
process_existing_ann(rule, ev_base, set, fun.map(load_ann_profile, data))
end
end

-- Extract all profiles for some specific settings id
-- Get the last `max_profiles` recently used
-- Select the most appropriate to our profile but it should not differ by more
-- than 30% of symbols
lua_redis.redis_make_request_taskless(ev_base,
cfg,
rule.redis,
nil,
false, -- is write
members_cb, --callback
'ZREVRANGE', -- command
{set.prefix, '0', tostring(settings.max_profiles)} -- arguments
)
end -- Cycle over all settings

return rule.watch_interval
end

-- Function to clean up old ANNs
local function cleanup_anns(rule, cfg, ev_base)
for _,set in pairs(rule.settings) do
local function invalidate_cb(err, data)
if err then
rspamd_logger.errx(cfg, 'cannot exec invalidate script in redis: %s',
err)
elseif type(data) == 'table' then
for _,expired in ipairs(data) do
local profile = load_ann_profile(expired)
rspamd_logger.infox(cfg, 'invalidated ANN for %s; redis key: %s; version=%s',
rule.prefix .. ':' .. set.name,
profile.ann_key,
profile.version)
end
end
end
lua_redis.exec_redis_script(redis_maybe_invalidate_id,
{ev_base = ev_base, is_write = true},
invalidate_cb,
{set.prefix, tostring(settings.max_profiles)})
end
end

local function ann_push_vector(task)
if task:has_flag('skip') then return end
if not settings.allow_local and lua_util.is_rspamc_or_controller(task) then return end
@@ -800,6 +882,83 @@ local function ann_push_vector(task)
end
end


-- Generate redis prefix for specific rule and specific settings
local function redis_ann_prefix(rule, settings_name)
-- We also need to count metatokens:
local n = meta_functions.version
return string.format('%s_%s_%d_%s',
settings.prefix, rule.prefix, n, settings_name)
end

-- This function is used to adjust profiles and allowed setting ids for each rule
-- It must be called when all settings are already registered (e.g. at post-init for config)
local function process_rules_settings()
local function process_settings_elt(rule, selt)
local profile = rule.profile[selt.name]
if profile then
-- Use static user defined profile
-- Ensure that we have an array...
lua_util.debugm(N, rspamd_config, "use static profile for %s (%s)",
rule.prefix, selt.name)
if not profile[1] then profile = lua_util.keys(profile) end
selt.symbols = profile
else
lua_util.debugm(N, rspamd_config, "use dynamic cfg based profile for %s (%s)",
rule.prefix, selt.name)
end

-- Generic stuff
table.sort(selt.symbols)
selt.digest = lua_util.table_digest(selt.symbols)
selt.prefix = redis_ann_prefix(rule, selt.name)

lua_redis.register_prefix(selt.prefix, N,
string.format('NN prefix for rule "%s"; settings id "%s"',
rule.prefix, selt.name))
end

for _,rule in pairs(opts.rules) do
if not rule.allowed_settings then
-- Extract all settings ids
rule.allowed_settings = lua_util.keys(lua_settings.all_settings)
end

-- Convert to a map <setting_id> -> true
rule.allowed_settings = lua_util.list_to_hash(rule.allowed_settings)

-- Check if we can work without settings
if type(rule.default) ~= 'boolean' then
rule.default = true
end

rule.settings = {}

if rule.default then
local default_settings = {
symbols = lua_util.keys(lua_settings.default_symbols),
name = 'default'
}

process_settings_elt(rule, default_settings)
rule.settings[-1] = default_settings -- Magic constant, but OK as settings are positive int32
end

-- Now, for each allowed settings, we store sorted symbols + digest
-- We set table rule.settings[id] -> { name = name, symbols = symbols, digest = digest }
for s,_ in pairs(rule.allowed_settings) do
-- Here, we have a name, set of symbols and
local selt = lua_settings.settings_by_id(s)
rule.settings[s] = {
symbols = selt.symbols, -- Already sorted
name = selt.name
}

process_settings_elt(rule, rule.settings[s])
end
end
end

redis_params = lua_redis.parse_redis_server('neural')

if not redis_params then
@@ -818,7 +977,7 @@ local rules = opts['rules']
if not rules then
-- Use legacy configuration
rules = {}
rules['RFANN'] = opts
rules['default'] = opts
end

local id = rspamd_config:register_symbol({
@@ -827,6 +986,7 @@ local id = rspamd_config:register_symbol({
priority = 6,
callback = ann_scores_filter
})

for k,r in pairs(rules) do
local def_rules = lua_util.override_defaults(default_options, r)
def_rules['redis'] = redis_params
@@ -841,6 +1001,7 @@ for k,r in pairs(rules) do
if def_rules.train.max_train then
def_rules.train.max_trains = def_rules.train.max_train
end

rspamd_logger.infox(rspamd_config, "register ann rule %s", k)
settings.rules[k] = def_rules
rspamd_config:set_metric_symbol({
@@ -876,8 +1037,11 @@ rspamd_config:register_symbol({
})

-- Add training scripts
for _,rule in pairs(settings.rules) do
for k,rule in pairs(settings.rules) do
load_scripts(rule.redis)
-- We also need to deal with settings
rspamd_config:add_post_init(process_rules_settings)
-- This function will check ANNs in Redis when a worker is loaded
rspamd_config:add_on_load(function(cfg, ev_base, worker)
rspamd_config:add_periodic(ev_base, 0.0,
function(_, _)
@@ -888,6 +1052,7 @@ for _,rule in pairs(settings.rules) do
-- We also want to train neural nets when they have enough data
rspamd_config:add_periodic(ev_base, 0.0,
function(_, _)
cleanup_anns(rule, cfg, ev_base)
return maybe_train_anns(rule, cfg, ev_base, worker)
end)
end

Chargement…
Annuler
Enregistrer