local lua_redis = require "lua_redis"
local lua_util = require "lua_util"
local fun = require "fun"
+local lua_settings = require "lua_settings"
local meta_functions = require "lua_meta"
local N = "neural"
learn_threads = 1,
learning_rate = 0.01,
},
- use_settings = false,
- per_user = false,
watch_interval = 60.0,
- nlayers = 4,
lock_expire = 600,
learning_spawned = false,
ann_expire = 60 * 60 * 24 * 2, -- 2 days
}
local settings = {
- rules = {}
+ rules = {},
+ prefix = 'rn', -- Neural network default prefix
+ max_profiles = 3, -- Maximum number of NN profiles stored
}
local opts = rspamd_config:get_all_opt("neural")
]]
local redis_maybe_load_id = nil
--- Lua script to invalidate ANN from redis
+-- Lua script to invalidate ANNs by rank
-- Uses the following keys
-- key1 - prefix for keys
+-- key2 - number of elements to leave
local redis_lua_script_maybe_invalidate = [[
- local locked = redis.call('GET', KEYS[1] .. '_locked')
- if locked then return false end
- redis.call('SET', KEYS[1] .. '_locked', '1')
- redis.call('SET', KEYS[1] .. '_version', '0')
- redis.call('DEL', KEYS[1] .. '_spam')
- redis.call('DEL', KEYS[1] .. '_ham')
- redis.call('DEL', KEYS[1] .. '_data')
- redis.call('DEL', KEYS[1] .. '_locked')
- redis.call('DEL', KEYS[1] .. '_hostname')
- return 1
+ local card = redis.call('ZCARD', KEYS[1])
+ if card > tonumber(KEYS[2]) then
+ local to_delete = redis.call('ZRANGE', KEYS[1], 0, (-(tonumber(KEYS[2] - 1)))
+ for _,k in ipairs(to_delete) do
+ local tb = cjson.decode(k)
+ redis.call('DEL', tb.ann_key)
+ end
+ redis.call('ZREMRANGEBYRANK', KEYS[1], 0, (-(tonumber(KEYS[2] - 1)))
+ return to_delete
+ else
+ return {}
+ end
]]
local redis_maybe_invalidate_id = nil
params)
end
-local function gen_ann_prefix(rule, id)
- local cksum = rspamd_config:get_symbols_cksum():hex()
- -- We also need to count metatokens:
- local n = meta_functions.rspamd_count_metatokens()
- local tprefix = 'k'
- if id then
- return string.format('%s%s%s%d%s', tprefix, rule.prefix, cksum, n, id), id
- else
- return string.format('%s%s%s%d', tprefix, rule.prefix, cksum, n), nil
- end
-end
-
-local function is_ann_valid(rule, prefix, ann)
- if ann then
- return true
- end
-end
local function ann_scores_filter(task)
return rule.watch_interval
end
-local function check_anns(rule, _, ev_base)
- local function members_cb(err, data)
- if err then
- rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s',
- err)
- elseif type(data) == 'table' then
- fun.each(function(elt)
- elt = tostring(elt)
- local redis_update_cb = function(_err, _data)
- if _err then
- rspamd_logger.errx(rspamd_config, 'cannot get FANN version %s from redis: %s',
- elt, _err)
- elseif _data and type(_data) == 'table' then
- load_or_invalidate_ann(rule, _data, elt, ev_base)
- else
- if type(_data) ~= 'number' then
- rspamd_logger.errx(rspamd_config, 'invalid ANN type returned from Redis: %s; prefix: %s',
- type(_data), elt)
- end
- end
+-- This function loads new ann from Redis
+-- This is based on `profile` attribute.
+-- ANN is loaded from `profile.ann_key`
+-- Rank of `profile` key is also increased, unfortunately, it means that we need to
+-- serialize profile one more time and set its rank to the current time
+-- set.ann fields are set according to Redis data received
+local function load_new_ann(rule, ev_base, set, profile, min_diff)
+
+end
+
+-- Used to check an element in Redis serialized as JSON
+-- for some specific rule + some specific setting
+-- This function tries to load more fresh or more specific ANNs in lieu of
+-- the existing ones.
+local function process_existing_ann(rule, ev_base, set, profiles)
+ local my_symbols = set.symbols
+ local min_diff = math.huge
+ local sel_elt
+
+ for _,elt in fun.iter(profiles) do
+ if elt and elt.symbols then
+ local dist = lua_util.distance_sorted(elt.symbols, my_symbols)
+
+ if dist < #my_symbols * .3 then
+ if dist < min_diff then
+ min_diff = dist
+ sel_elt = elt
end
+ end
+ end
+ end
- local local_ver = 0
- if rule.anns[elt] then
- if rule.anns[elt].version then
- local_ver = rule.anns[elt].version
- end
+ if sel_elt then
+ -- We can load element from ANN
+ if set.ann then
+ -- We have an existing ANN, probably the same...
+ if set.ann.digest == sel_elt.digest then
+ -- Same ANN, check version
+ if set.ann.version < sel_elt.version then
+ -- Load new ann
+ rspamd_logger.infox(rspamd_config, 'ann %s is changed,' ..
+ 'our version = %s, remote version = %s',
+ rule.prefix .. ':' .. set.name,
+ set.ann.version,
+ sel_elt.version)
+ load_new_ann(rule, ev_base, set, sel_elt, min_diff)
+ else
+ lua_util.debugm(N, rspamd_config, 'ann %s is not changed,' ..
+ 'our version = %s, remote version = %s',
+ rule.prefix .. ':' .. set.name,
+ set.ann.version,
+ sel_elt.version)
end
- lua_redis.exec_redis_script(redis_maybe_load_id,
- {ev_base = ev_base, is_write = false},
- redis_update_cb,
- {gen_ann_prefix(rule, elt), tostring(local_ver)})
- end,
- data)
+ else
+ -- We have some different ANN, so we need to compare distance
+ if set.ann.distance > min_diff then
+ -- Load more specific ANN
+ rspamd_logger.infox(rspamd_config, 'more specific ann is available for %s,' ..
+ 'our distance = %s, remote distance = %s',
+ rule.prefix .. ':' .. set.name,
+ set.ann.distance,
+ min_diff)
+ load_new_ann(rule, ev_base, set, sel_elt, min_diff)
+ else
+ lua_util.debugm(N, rspamd_config, 'ann %s is not changed or less specific,' ..
+ 'our distance = %s, remote distance = %s',
+ rule.prefix .. ':' .. set.name,
+ set.ann.distance,
+ min_diff)
+ end
+ end
+ else
+ -- We have no ANN, load new one
+ load_new_ann(rule, ev_base, set, sel_elt, min_diff)
end
end
+end
- -- First we need to get all anns stored in our Redis
- lua_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
- rule.redis,
- nil,
- false, -- is write
- members_cb, --callback
- 'SMEMBERS', -- command
- {gen_ann_prefix(rule, nil)} -- arguments
- )
+-- Used to deserialise ANN element from a list
+local function load_ann_profile(element)
+ local ucl = require "ucl"
+
+ local parser = ucl.parser()
+ local res,ucl_err = parser:parse_string(element)
+ if not res then
+ rspamd_logger.warnx(rspamd_config, 'cannot parse ANN from redis: %s',
+ ucl_err)
+ return nil
+ else
+ return parser:get_object()
+ end
+end
+
+-- Function to check or load ANNs from Redis
+local function check_anns(rule, cfg, ev_base)
+ for _,set in pairs(rule.settings) do
+ local function members_cb(err, data)
+ if err then
+ rspamd_logger.errx(cfg, 'cannot get ANNs list from redis: %s',
+ err)
+ elseif type(data) == 'table' then
+ process_existing_ann(rule, ev_base, set, fun.map(load_ann_profile, data))
+ end
+ end
+
+ -- Extract all profiles for some specific settings id
+ -- Get the last `max_profiles` recently used
+ -- Select the most appropriate to our profile but it should not differ by more
+ -- than 30% of symbols
+ lua_redis.redis_make_request_taskless(ev_base,
+ cfg,
+ rule.redis,
+ nil,
+ false, -- is write
+ members_cb, --callback
+ 'ZREVRANGE', -- command
+ {set.prefix, '0', tostring(settings.max_profiles)} -- arguments
+ )
+ end -- Cycle over all settings
return rule.watch_interval
end
+-- Function to clean up old ANNs
+local function cleanup_anns(rule, cfg, ev_base)
+ for _,set in pairs(rule.settings) do
+ local function invalidate_cb(err, data)
+ if err then
+ rspamd_logger.errx(cfg, 'cannot exec invalidate script in redis: %s',
+ err)
+ elseif type(data) == 'table' then
+ for _,expired in ipairs(data) do
+ local profile = load_ann_profile(expired)
+ rspamd_logger.infox(cfg, 'invalidated ANN for %s; redis key: %s; version=%s',
+ rule.prefix .. ':' .. set.name,
+ profile.ann_key,
+ profile.version)
+ end
+ end
+ end
+ lua_redis.exec_redis_script(redis_maybe_invalidate_id,
+ {ev_base = ev_base, is_write = true},
+ invalidate_cb,
+ {set.prefix, tostring(settings.max_profiles)})
+ end
+end
+
local function ann_push_vector(task)
if task:has_flag('skip') then return end
if not settings.allow_local and lua_util.is_rspamc_or_controller(task) then return end
end
end
+
+-- Generate redis prefix for specific rule and specific settings
+local function redis_ann_prefix(rule, settings_name)
+ -- We also need to count metatokens:
+ local n = meta_functions.version
+ return string.format('%s_%s_%d_%s',
+ settings.prefix, rule.prefix, n, settings_name)
+end
+
+-- This function is used to adjust profiles and allowed setting ids for each rule
+-- It must be called when all settings are already registered (e.g. at post-init for config)
+local function process_rules_settings()
+ local function process_settings_elt(rule, selt)
+ local profile = rule.profile[selt.name]
+ if profile then
+ -- Use static user defined profile
+ -- Ensure that we have an array...
+ lua_util.debugm(N, rspamd_config, "use static profile for %s (%s)",
+ rule.prefix, selt.name)
+ if not profile[1] then profile = lua_util.keys(profile) end
+ selt.symbols = profile
+ else
+ lua_util.debugm(N, rspamd_config, "use dynamic cfg based profile for %s (%s)",
+ rule.prefix, selt.name)
+ end
+
+ -- Generic stuff
+ table.sort(selt.symbols)
+ selt.digest = lua_util.table_digest(selt.symbols)
+ selt.prefix = redis_ann_prefix(rule, selt.name)
+
+ lua_redis.register_prefix(selt.prefix, N,
+ string.format('NN prefix for rule "%s"; settings id "%s"',
+ rule.prefix, selt.name))
+ end
+
+ for _,rule in pairs(opts.rules) do
+ if not rule.allowed_settings then
+ -- Extract all settings ids
+ rule.allowed_settings = lua_util.keys(lua_settings.all_settings)
+ end
+
+ -- Convert to a map <setting_id> -> true
+ rule.allowed_settings = lua_util.list_to_hash(rule.allowed_settings)
+
+ -- Check if we can work without settings
+ if type(rule.default) ~= 'boolean' then
+ rule.default = true
+ end
+
+ rule.settings = {}
+
+ if rule.default then
+ local default_settings = {
+ symbols = lua_util.keys(lua_settings.default_symbols),
+ name = 'default'
+ }
+
+ process_settings_elt(rule, default_settings)
+ rule.settings[-1] = default_settings -- Magic constant, but OK as settings are positive int32
+ end
+
+ -- Now, for each allowed settings, we store sorted symbols + digest
+ -- We set table rule.settings[id] -> { name = name, symbols = symbols, digest = digest }
+ for s,_ in pairs(rule.allowed_settings) do
+ -- Here, we have a name, set of symbols and
+ local selt = lua_settings.settings_by_id(s)
+ rule.settings[s] = {
+ symbols = selt.symbols, -- Already sorted
+ name = selt.name
+ }
+
+ process_settings_elt(rule, rule.settings[s])
+ end
+ end
+end
+
redis_params = lua_redis.parse_redis_server('neural')
if not redis_params then
if not rules then
-- Use legacy configuration
rules = {}
- rules['RFANN'] = opts
+ rules['default'] = opts
end
local id = rspamd_config:register_symbol({
priority = 6,
callback = ann_scores_filter
})
+
for k,r in pairs(rules) do
local def_rules = lua_util.override_defaults(default_options, r)
def_rules['redis'] = redis_params
if def_rules.train.max_train then
def_rules.train.max_trains = def_rules.train.max_train
end
+
rspamd_logger.infox(rspamd_config, "register ann rule %s", k)
settings.rules[k] = def_rules
rspamd_config:set_metric_symbol({
})
-- Add training scripts
-for _,rule in pairs(settings.rules) do
+for k,rule in pairs(settings.rules) do
load_scripts(rule.redis)
+ -- We also need to deal with settings
+ rspamd_config:add_post_init(process_rules_settings)
+ -- This function will check ANNs in Redis when a worker is loaded
rspamd_config:add_on_load(function(cfg, ev_base, worker)
rspamd_config:add_periodic(ev_base, 0.0,
function(_, _)
-- We also want to train neural nets when they have enough data
rspamd_config:add_periodic(ev_base, 0.0,
function(_, _)
+ cleanup_anns(rule, cfg, ev_base)
return maybe_train_anns(rule, cfg, ev_base, worker)
end)
end