|
|
@@ -25,6 +25,7 @@ local rspamd_kann = require "rspamd_kann" |
|
|
|
local lua_redis = require "lua_redis" |
|
|
|
local lua_util = require "lua_util" |
|
|
|
local fun = require "fun" |
|
|
|
local lua_settings = require "lua_settings" |
|
|
|
local meta_functions = require "lua_meta" |
|
|
|
local N = "neural" |
|
|
|
|
|
|
@@ -41,10 +42,7 @@ local default_options = { |
|
|
|
learn_threads = 1, |
|
|
|
learning_rate = 0.01, |
|
|
|
}, |
|
|
|
use_settings = false, |
|
|
|
per_user = false, |
|
|
|
watch_interval = 60.0, |
|
|
|
nlayers = 4, |
|
|
|
lock_expire = 600, |
|
|
|
learning_spawned = false, |
|
|
|
ann_expire = 60 * 60 * 24 * 2, -- 2 days |
|
|
@@ -53,7 +51,9 @@ local default_options = { |
|
|
|
} |
|
|
|
|
|
|
|
-- Module-wide plugin state.
--   rules        - registered neural rules, keyed by rule name
--   prefix       - leading component of every Redis key built by this module
--   max_profiles - how many ANN profiles are kept per settings id
local settings = {
  rules = {},
  prefix = 'rn', -- Neural network default prefix
  max_profiles = 3, -- Maximum number of NN profiles stored
}
|
|
|
|
|
|
|
local opts = rspamd_config:get_all_opt("neural") |
|
|
@@ -124,20 +124,23 @@ local redis_lua_script_maybe_load = [[ |
|
|
|
]] |
|
|
|
local redis_maybe_load_id = nil |
|
|
|
|
|
|
|
-- Lua script to invalidate ANNs by rank
-- Uses the following keys
-- key1 - prefix for keys (the sorted set that holds serialized profiles)
-- key2 - number of elements to leave
-- Returns the list of removed members (each one is a JSON document with an
-- `ann_key` field) or an empty list when the set does not exceed the limit.
local redis_lua_script_maybe_invalidate = [[
  local card = redis.call('ZCARD', KEYS[1])
  local lim = tonumber(KEYS[2])
  if not lim or card <= lim then
    return {}
  end
  -- Ranks are zero based and inclusive: removing 0 .. card - lim - 1
  -- leaves exactly `lim` highest ranked members in the set
  local last = card - lim - 1
  local to_delete = redis.call('ZRANGE', KEYS[1], 0, last)
  for _,k in ipairs(to_delete) do
    local tb = cjson.decode(k)
    redis.call('DEL', tb.ann_key)
  end
  redis.call('ZREMRANGEBYRANK', KEYS[1], 0, last)
  return to_delete
]]
|
|
|
local redis_maybe_invalidate_id = nil |
|
|
|
|
|
|
@@ -209,23 +212,6 @@ local function load_scripts(params) |
|
|
|
params) |
|
|
|
end |
|
|
|
|
|
|
|
-- Builds the Redis key prefix for a rule.
-- The prefix is 'k' .. rule.prefix .. <hex symbols checksum> .. <metatokens count>,
-- optionally followed by `id`.
-- Returns two values: the prefix and `id` (nil when no id was given).
local function gen_ann_prefix(rule, id)
  local symbols_hash = rspamd_config:get_symbols_cksum():hex()
  -- Metatokens are part of the input vector, so their count is part of the key
  local metatokens_count = meta_functions.rspamd_count_metatokens()
  local base = string.format('%s%s%s%d', 'k', rule.prefix, symbols_hash,
      metatokens_count)

  if not id then
    return base, nil
  end

  return string.format('%s%s', base, id), id
end
|
|
|
|
|
|
|
-- Tells whether an ANN object is usable.
-- `rule` and `prefix` are accepted but not inspected; the only check is that
-- `ann` is truthy. Returns true in that case and nothing otherwise.
local function is_ann_valid(rule, prefix, ann)
  if not ann then
    return
  end

  return true
end
|
|
|
|
|
|
|
local function ann_scores_filter(task) |
|
|
|
|
|
|
@@ -732,57 +718,153 @@ local function maybe_train_anns(rule, cfg, ev_base, worker) |
|
|
|
return rule.watch_interval |
|
|
|
end |
|
|
|
|
|
|
|
local function check_anns(rule, _, ev_base) |
|
|
|
local function members_cb(err, data) |
|
|
|
if err then |
|
|
|
rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', |
|
|
|
err) |
|
|
|
elseif type(data) == 'table' then |
|
|
|
fun.each(function(elt) |
|
|
|
elt = tostring(elt) |
|
|
|
local redis_update_cb = function(_err, _data) |
|
|
|
if _err then |
|
|
|
rspamd_logger.errx(rspamd_config, 'cannot get FANN version %s from redis: %s', |
|
|
|
elt, _err) |
|
|
|
elseif _data and type(_data) == 'table' then |
|
|
|
load_or_invalidate_ann(rule, _data, elt, ev_base) |
|
|
|
else |
|
|
|
if type(_data) ~= 'number' then |
|
|
|
rspamd_logger.errx(rspamd_config, 'invalid ANN type returned from Redis: %s; prefix: %s', |
|
|
|
type(_data), elt) |
|
|
|
end |
|
|
|
end |
|
|
|
-- This function loads new ann from Redis |
|
|
|
-- This is based on `profile` attribute. |
|
|
|
-- ANN is loaded from `profile.ann_key` |
|
|
|
-- Rank of `profile` key is also increased, unfortunately, it means that we need to |
|
|
|
-- serialize profile one more time and set its rank to the current time |
|
|
|
-- set.ann fields are set according to Redis data received |
|
|
|
local function load_new_ann(rule, ev_base, set, profile, min_diff) |
|
|
|
|
|
|
|
end |
|
|
|
|
|
|
|
-- Used to check an element in Redis serialized as JSON
-- for some specific rule + some specific setting
-- This function tries to load more fresh or more specific ANNs in lieu of
-- the existing ones.
--
-- rule     - rule table (its `prefix` is used for logging)
-- ev_base  - event base, passed through to `load_new_ann`
-- set      - settings element; reads `symbols`, `name` and the optional `ann`
-- profiles - iterable of deserialised profiles (entries may be nil)
--
-- A profile is a candidate only if its symbols differ from ours by less than
-- 30% of our symbols count; the candidate with the smallest distance wins.
local function process_existing_ann(rule, ev_base, set, profiles)
  local my_symbols = set.symbols
  local min_diff = math.huge
  local sel_elt

  for _,elt in fun.iter(profiles) do
    if elt and elt.symbols then
      local dist = lua_util.distance_sorted(elt.symbols, my_symbols)

      if dist < #my_symbols * .3 then
        if dist < min_diff then
          min_diff = dist
          sel_elt = elt
        end
      end
    end
  end

  if sel_elt then
    -- We can load element from ANN
    if set.ann then
      -- We have an existing ANN, probably the same...
      if set.ann.digest == sel_elt.digest then
        -- Same ANN, check version
        if set.ann.version < sel_elt.version then
          -- Load new ann
          rspamd_logger.infox(rspamd_config, 'ann %s is changed,' ..
              'our version = %s, remote version = %s',
              rule.prefix .. ':' .. set.name,
              set.ann.version,
              sel_elt.version)
          load_new_ann(rule, ev_base, set, sel_elt, min_diff)
        else
          lua_util.debugm(N, rspamd_config, 'ann %s is not changed,' ..
              'our version = %s, remote version = %s',
              rule.prefix .. ':' .. set.name,
              set.ann.version,
              sel_elt.version)
        end
      else
        -- We have some different ANN, so we need to compare distance
        if set.ann.distance > min_diff then
          -- Load more specific ANN
          rspamd_logger.infox(rspamd_config, 'more specific ann is available for %s,' ..
              'our distance = %s, remote distance = %s',
              rule.prefix .. ':' .. set.name,
              set.ann.distance,
              min_diff)
          load_new_ann(rule, ev_base, set, sel_elt, min_diff)
        else
          lua_util.debugm(N, rspamd_config, 'ann %s is not changed or less specific,' ..
              'our distance = %s, remote distance = %s',
              rule.prefix .. ':' .. set.name,
              set.ann.distance,
              min_diff)
        end
      end
    else
      -- We have no ANN, load new one
      load_new_ann(rule, ev_base, set, sel_elt, min_diff)
    end
  end
end
|
|
|
|
|
|
|
-- First we need to get all anns stored in our Redis |
|
|
|
lua_redis.redis_make_request_taskless(ev_base, |
|
|
|
rspamd_config, |
|
|
|
rule.redis, |
|
|
|
nil, |
|
|
|
false, -- is write |
|
|
|
members_cb, --callback |
|
|
|
'SMEMBERS', -- command |
|
|
|
{gen_ann_prefix(rule, nil)} -- arguments |
|
|
|
) |
|
|
|
-- Deserialises a single ANN profile fetched from Redis.
-- `element` is parsed with a fresh UCL parser; on success the parsed object
-- is returned, on failure a warning is logged and nil is returned.
local function load_ann_profile(element)
  local ucl = require "ucl"
  local parser = ucl.parser()
  local ok,parse_err = parser:parse_string(element)

  if ok then
    return parser:get_object()
  end

  rspamd_logger.warnx(rspamd_config, 'cannot parse ANN from redis: %s',
      parse_err)
  return nil
end
|
|
|
|
|
|
|
-- Function to check or load ANNs from Redis
--
-- For every element of `rule.settings` a ZREVRANGE request is sent for the
-- sorted set named `set.prefix`; the reply is deserialised with
-- `load_ann_profile` and handed over to `process_existing_ann`.
--
-- rule    - rule table (reads `settings`, `redis`, `watch_interval`)
-- cfg     - config object used for the request and for logging
-- ev_base - event base used for the request
-- Returns `rule.watch_interval`.
local function check_anns(rule, cfg, ev_base)
  -- ZREVRANGE indexes are zero based and inclusive on both ends, so the last
  -- index must be `max_profiles - 1` to get exactly `max_profiles` elements.
  -- It is clamped to 0 as a negative stop index would count from the tail.
  local last_idx = math.max(settings.max_profiles - 1, 0)

  for _,set in pairs(rule.settings) do
    local function members_cb(err, data)
      if err then
        rspamd_logger.errx(cfg, 'cannot get ANNs list from redis: %s',
            err)
      elseif type(data) == 'table' then
        process_existing_ann(rule, ev_base, set, fun.map(load_ann_profile, data))
      end
    end

    -- Extract all profiles for some specific settings id
    -- Get the last `max_profiles` recently used
    -- Select the most appropriate to our profile but it should not differ by more
    -- than 30% of symbols
    lua_redis.redis_make_request_taskless(ev_base,
        cfg,
        rule.redis,
        nil,
        false, -- is write
        members_cb, --callback
        'ZREVRANGE', -- command
        {set.prefix, '0', tostring(last_idx)} -- arguments
    )
  end -- Cycle over all settings

  return rule.watch_interval
end
|
|
|
|
|
|
|
-- Function to clean up old ANNs
--
-- For every element of `rule.settings` the invalidation script
-- (`redis_maybe_invalidate_id`) is executed with the keys
-- { set.prefix, tostring(settings.max_profiles) }. Each element of the
-- script reply is deserialised and reported in the log.
--
-- rule    - rule table (reads `settings` and `prefix`)
-- cfg     - config object used for logging
-- ev_base - event base used for the Redis call
local function cleanup_anns(rule, cfg, ev_base)
  for _,set in pairs(rule.settings) do
    local function invalidate_cb(err, data)
      if err then
        rspamd_logger.errx(cfg, 'cannot exec invalidate script in redis: %s',
            err)
      elseif type(data) == 'table' then
        for _,expired in ipairs(data) do
          local profile = load_ann_profile(expired)

          -- load_ann_profile returns nil (and logs on its own) when the
          -- element cannot be parsed; do not index a nil profile here
          if profile then
            rspamd_logger.infox(cfg, 'invalidated ANN for %s; redis key: %s; version=%s',
                rule.prefix .. ':' .. set.name,
                profile.ann_key,
                profile.version)
          end
        end
      end
    end

    lua_redis.exec_redis_script(redis_maybe_invalidate_id,
        {ev_base = ev_base, is_write = true},
        invalidate_cb,
        {set.prefix, tostring(settings.max_profiles)})
  end
end
|
|
|
|
|
|
|
local function ann_push_vector(task) |
|
|
|
if task:has_flag('skip') then return end |
|
|
|
if not settings.allow_local and lua_util.is_rspamc_or_controller(task) then return end |
|
|
@@ -800,6 +882,83 @@ local function ann_push_vector(task) |
|
|
|
end |
|
|
|
end |
|
|
|
|
|
|
|
|
|
|
|
-- Generate redis prefix for specific rule and specific settings
-- The result has the form
--   <settings.prefix>_<rule.prefix>_<meta_functions.version>_<settings_name>
local function redis_ann_prefix(rule, settings_name)
  -- NOTE(review): presumably the version of the metatokens set, so that keys
  -- change when the metatokens layout changes - confirm against lua_meta
  local meta_version = meta_functions.version
  local key_format = '%s_%s_%d_%s'

  return key_format:format(settings.prefix, rule.prefix, meta_version,
      settings_name)
end
|
|
|
|
|
|
|
-- This function is used to adjust profiles and allowed setting ids for each rule
-- It must be called when all settings are already registered (e.g. at post-init for config)
--
-- For every registered rule it fills `rule.settings`: a table
-- <settings id> -> { name, symbols, digest, prefix }, where the key -1 holds
-- the default (no settings) element when `rule.default` is enabled.
local function process_rules_settings()
  -- Finalises a single settings element: applies a static profile (if any),
  -- sorts symbols, computes their digest and registers the Redis prefix
  local function process_settings_elt(rule, selt)
    -- Static profiles are optional
    local profile = rule.profile and rule.profile[selt.name]
    if profile then
      -- Use static user defined profile
      -- Ensure that we have an array...
      lua_util.debugm(N, rspamd_config, "use static profile for %s (%s)",
          rule.prefix, selt.name)
      if not profile[1] then profile = lua_util.keys(profile) end
      selt.symbols = profile
    else
      lua_util.debugm(N, rspamd_config, "use dynamic cfg based profile for %s (%s)",
          rule.prefix, selt.name)
    end

    -- Generic stuff
    table.sort(selt.symbols)
    selt.digest = lua_util.table_digest(selt.symbols)
    selt.prefix = redis_ann_prefix(rule, selt.name)

    lua_redis.register_prefix(selt.prefix, N,
        string.format('NN prefix for rule "%s"; settings id "%s"',
            rule.prefix, selt.name))
  end

  -- Iterate over the registered rules, not over `opts.rules`: the latter is
  -- absent in the legacy configuration, and the registered tables are the
  -- ones whose `settings` field is read by the periodic checks
  for _,rule in pairs(settings.rules) do
    if not rule.allowed_settings then
      -- Extract all settings ids
      rule.allowed_settings = lua_util.keys(lua_settings.all_settings)
    end

    -- Convert to a map <setting_id> -> true
    rule.allowed_settings = lua_util.list_to_hash(rule.allowed_settings)

    -- Check if we can work without settings
    if type(rule.default) ~= 'boolean' then
      rule.default = true
    end

    rule.settings = {}

    if rule.default then
      local default_settings = {
        symbols = lua_util.keys(lua_settings.default_symbols),
        name = 'default'
      }

      process_settings_elt(rule, default_settings)
      rule.settings[-1] = default_settings -- Magic constant, but OK as settings are positive int32
    end

    -- Now, for each allowed settings, we store sorted symbols + digest
    -- We set table rule.settings[id] -> { name = name, symbols = symbols, digest = digest }
    for s,_ in pairs(rule.allowed_settings) do
      -- Here, we have a name, set of symbols and
      local selt = lua_settings.settings_by_id(s)

      if selt then
        rule.settings[s] = {
          symbols = selt.symbols, -- Already sorted
          name = selt.name
        }

        process_settings_elt(rule, rule.settings[s])
      else
        -- An unknown id must not break processing of the remaining settings
        rspamd_logger.warnx(rspamd_config,
            'cannot find settings with id %s for rule %s', s, rule.prefix)
      end
    end
  end
end
|
|
|
|
|
|
|
redis_params = lua_redis.parse_redis_server('neural') |
|
|
|
|
|
|
|
if not redis_params then |
|
|
@@ -818,7 +977,7 @@ local rules = opts['rules'] |
|
|
|
if not rules then |
|
|
|
-- Use legacy configuration |
|
|
|
rules = {} |
|
|
|
rules['RFANN'] = opts |
|
|
|
rules['default'] = opts |
|
|
|
end |
|
|
|
|
|
|
|
local id = rspamd_config:register_symbol({ |
|
|
@@ -827,6 +986,7 @@ local id = rspamd_config:register_symbol({ |
|
|
|
priority = 6, |
|
|
|
callback = ann_scores_filter |
|
|
|
}) |
|
|
|
|
|
|
|
for k,r in pairs(rules) do |
|
|
|
local def_rules = lua_util.override_defaults(default_options, r) |
|
|
|
def_rules['redis'] = redis_params |
|
|
@@ -841,6 +1001,7 @@ for k,r in pairs(rules) do |
|
|
|
if def_rules.train.max_train then |
|
|
|
def_rules.train.max_trains = def_rules.train.max_train |
|
|
|
end |
|
|
|
|
|
|
|
rspamd_logger.infox(rspamd_config, "register ann rule %s", k) |
|
|
|
settings.rules[k] = def_rules |
|
|
|
rspamd_config:set_metric_symbol({ |
|
|
@@ -876,8 +1037,11 @@ rspamd_config:register_symbol({ |
|
|
|
}) |
|
|
|
|
|
|
|
-- Add training scripts |
|
|
|
for _,rule in pairs(settings.rules) do |
|
|
|
for k,rule in pairs(settings.rules) do |
|
|
|
load_scripts(rule.redis) |
|
|
|
-- We also need to deal with settings |
|
|
|
rspamd_config:add_post_init(process_rules_settings) |
|
|
|
-- This function will check ANNs in Redis when a worker is loaded |
|
|
|
rspamd_config:add_on_load(function(cfg, ev_base, worker) |
|
|
|
rspamd_config:add_periodic(ev_base, 0.0, |
|
|
|
function(_, _) |
|
|
@@ -888,6 +1052,7 @@ for _,rule in pairs(settings.rules) do |
|
|
|
-- We also want to train neural nets when they have enough data |
|
|
|
rspamd_config:add_periodic(ev_base, 0.0, |
|
|
|
function(_, _) |
|
|
|
cleanup_anns(rule, cfg, ev_base) |
|
|
|
return maybe_train_anns(rule, cfg, ev_base, worker) |
|
|
|
end) |
|
|
|
end |