|
|
@@ -43,12 +43,12 @@ local default_options = { |
|
|
|
max_trains = 1000, |
|
|
|
max_epoch = 1000, |
|
|
|
max_usages = 10, |
|
|
|
use_settings = false, |
|
|
|
per_user = false, |
|
|
|
watch_interval = 60.0, |
|
|
|
mse = 0.001, |
|
|
|
autotrain = true, |
|
|
|
}, |
|
|
|
use_settings = false, |
|
|
|
per_user = false, |
|
|
|
watch_interval = 60.0, |
|
|
|
nlayers = 4, |
|
|
|
lock_expire = 600, |
|
|
|
learning_spawned = false, |
|
|
@@ -58,8 +58,7 @@ local default_options = { |
|
|
|
} |
|
|
|
|
|
|
|
local settings = { |
|
|
|
rules = { |
|
|
|
} |
|
|
|
rules = {} |
|
|
|
} |
|
|
|
|
|
|
|
-- ANNs indexed by settings id |
|
|
@@ -96,9 +95,17 @@ local redis_lua_script_can_train = [[ |
|
|
|
if ret then nham = tonumber(ret) end |
|
|
|
|
|
|
|
if KEYS[3] == 'spam' then |
|
|
|
if nham <= lim and nham + 1 >= nspam then return tostring(nspam + 1) end |
|
|
|
if nham <= lim and nham + 1 >= nspam then |
|
|
|
return tostring(nspam + 1) |
|
|
|
else |
|
|
|
return tostring(-(nham + 1)) |
|
|
|
end |
|
|
|
else |
|
|
|
if nspam <= lim and nspam + 1 >= nham then return tostring(nham + 1) end |
|
|
|
if nspam <= lim and nspam + 1 >= nham then |
|
|
|
return tostring(nham + 1) |
|
|
|
else |
|
|
|
return tostring(-(nspam + 1)) |
|
|
|
end |
|
|
|
end |
|
|
|
|
|
|
|
return tostring(0) |
|
|
@@ -312,8 +319,7 @@ local function gen_fann_prefix(rule, id) |
|
|
|
tprefix = 't'; |
|
|
|
end |
|
|
|
if id then |
|
|
|
return string.format('%s%s%s%d%s', tprefix, rule.prefix, cksum, n, id), |
|
|
|
rule.prefix .. id |
|
|
|
return string.format('%s%s%s%d%s', tprefix, rule.prefix, cksum, n, id), id |
|
|
|
else |
|
|
|
return string.format('%s%s%s%d', tprefix, rule.prefix, cksum, n), nil |
|
|
|
end |
|
|
@@ -433,7 +439,7 @@ local function create_train_fann(rule, n, id) |
|
|
|
if not is_fann_valid(rule, prefix, fanns[id].fann) then |
|
|
|
fanns[id].fann_train = create_fann(n, rule.nlayers) |
|
|
|
fanns[id].fann = nil |
|
|
|
elseif fanns[id].version % rule.max_usages == 0 then |
|
|
|
elseif fanns[id].version % rule.train.max_usages == 0 then |
|
|
|
-- Forget last fann |
|
|
|
rspamd_logger.infox(rspamd_config, 'recreate ANN %s, version %s', prefix, |
|
|
|
fanns[id].version) |
|
|
@@ -507,7 +513,7 @@ local function fann_train_callback(rule, task, score, required_score, id) |
|
|
|
|
|
|
|
local learn_spam, learn_ham |
|
|
|
|
|
|
|
if rule.autotrain then |
|
|
|
if train_opts.autotrain then |
|
|
|
if train_opts['spam_score'] then |
|
|
|
learn_spam = score >= train_opts['spam_score'] |
|
|
|
else |
|
|
@@ -531,13 +537,18 @@ local function fann_train_callback(rule, task, score, required_score, id) |
|
|
|
end |
|
|
|
end |
|
|
|
|
|
|
|
|
|
|
|
if learn_spam or learn_ham then |
|
|
|
local k |
|
|
|
local vec_len = 0 |
|
|
|
if learn_spam then k = 'spam' else k = 'ham' end |
|
|
|
|
|
|
|
local function learn_vec_cb(err) |
|
|
|
if err then |
|
|
|
rspamd_logger.errx(rspamd_config, 'cannot store train vector for %s: %s', fname, err) |
|
|
|
else |
|
|
|
rspamd_logger.infox(task, "trained ANN rule %s, save %s vector, %s bytes", |
|
|
|
rule['name'], k, vec_len) |
|
|
|
end |
|
|
|
end |
|
|
|
|
|
|
@@ -548,6 +559,7 @@ local function fann_train_callback(rule, task, score, required_score, id) |
|
|
|
-- Add filtered meta tokens |
|
|
|
fun.each(function(e) table.insert(fann_data, e) end, mt) |
|
|
|
local str = rspamd_util.zstd_compress(table.concat(fann_data, ';')) |
|
|
|
vec_len = #str |
|
|
|
|
|
|
|
rspamd_redis.redis_make_request(task, |
|
|
|
rule.redis, |
|
|
@@ -559,10 +571,13 @@ local function fann_train_callback(rule, task, score, required_score, id) |
|
|
|
) |
|
|
|
else |
|
|
|
if err then |
|
|
|
rspamd_logger.errx(rspamd_config, 'cannot check if we can train %s: %s', fname, err) |
|
|
|
rspamd_logger.errx(task, 'cannot check if we can train %s: %s', fname, err) |
|
|
|
if string.match(err, 'NOSCRIPT') then |
|
|
|
load_scripts(rspamd_config, task:get_ev_base(), nil) |
|
|
|
end |
|
|
|
elseif tonumber(data) < 0 then |
|
|
|
rspamd_logger.infox(task, "cannot learn ANN %s: too many %s samples: %s", |
|
|
|
fname, k, -tonumber(data)) |
|
|
|
end |
|
|
|
end |
|
|
|
end |
|
|
@@ -574,7 +589,7 @@ local function fann_train_callback(rule, task, score, required_score, id) |
|
|
|
can_train_cb, --callback |
|
|
|
'EVALSHA', -- command |
|
|
|
{redis_can_train_sha, '4', gen_fann_prefix(rule, nil), |
|
|
|
suffix, k, tostring(rule.max_trains)} -- arguments |
|
|
|
suffix, k, tostring(train_opts.max_trains)} -- arguments |
|
|
|
) |
|
|
|
end |
|
|
|
end |
|
|
@@ -722,7 +737,7 @@ local function train_fann(rule, _, ev_base, elt, worker) |
|
|
|
create_train_fann(rule, n, elt) |
|
|
|
end |
|
|
|
|
|
|
|
if #spam_elts + #ham_elts < rule.max_trains / 2 then |
|
|
|
if #spam_elts + #ham_elts < rule.train.max_trains / 2 then |
|
|
|
-- Invalidate ANN as it is definitely invalid |
|
|
|
local function redis_invalidate_cb(_err, _data) |
|
|
|
if _err then |
|
|
@@ -912,10 +927,10 @@ local function maybe_train_fanns(rule, cfg, ev_base, worker) |
|
|
|
rspamd_logger.errx(rspamd_config, |
|
|
|
'cannot get FANN trains %s from redis: %s', prefix, _err) |
|
|
|
elseif _data and type(_data) == 'number' or type(_data) == 'string' then |
|
|
|
if tonumber(_data) and tonumber(_data) >= rule.max_trains then |
|
|
|
if tonumber(_data) and tonumber(_data) >= rule.train.max_trains then |
|
|
|
rspamd_logger.infox(rspamd_config, |
|
|
|
'need to learn ANN %s after %s learn vectors (%s required)', |
|
|
|
prefix, tonumber(_data), rule.max_trains) |
|
|
|
prefix, tonumber(_data), rule.train.max_trains) |
|
|
|
train_fann(rule, cfg, ev_base, elt, worker) |
|
|
|
end |
|
|
|
end |
|
|
@@ -1011,8 +1026,7 @@ end |
|
|
|
|
|
|
|
local function ann_push_vector(task) |
|
|
|
local scores = task:get_metric_score() |
|
|
|
|
|
|
|
for _,rule in ipairs(settings.rules) do |
|
|
|
for _,rule in pairs(settings.rules) do |
|
|
|
local sid = "0" |
|
|
|
if rule.use_settings then |
|
|
|
sid = tostring(task:get_settings_id()) |
|
|
@@ -1053,32 +1067,64 @@ else |
|
|
|
callback = fann_scores_filter |
|
|
|
}) |
|
|
|
|
|
|
|
--- Recursively copy a value.
-- Non-table values are returned unchanged. For a table, every key and every
-- value is deep-copied into a fresh table, and the copy receives a deep copy
-- of the original's metatable.
-- @param orig value to copy
-- @param seen (optional) map from already visited source tables to their
--   copies; callers normally omit it. It lets self-referential tables
--   terminate and keeps tables referenced more than once shared in the copy.
-- @return the copy
local function deepcopy(orig, seen)
  if type(orig) ~= 'table' then
    -- number, string, boolean, etc
    return orig
  end

  seen = seen or {}
  local known = seen[orig]
  if known ~= nil then
    return known
  end

  local copy = {}
  -- Register before descending so a reference back to `orig` resolves to `copy`
  seen[orig] = copy

  for orig_key, orig_value in next, orig, nil do
    copy[deepcopy(orig_key, seen)] = deepcopy(orig_value, seen)
  end
  setmetatable(copy, deepcopy(getmetatable(orig), seen))

  return copy
end
|
|
|
--- Merge user supplied options into a table of defaults, in place.
-- Every key of `override` is applied to `def`:
--   * when both the default and the override are tables, they are merged
--     recursively so that unspecified nested defaults survive;
--   * otherwise the override value replaces (or adds) the entry.
-- Iterating over `override` rather than `def` means that options with no
-- default are still copied, and that an explicit `false` is honoured instead
-- of being mistaken for "not set".
-- @param def table of defaults; modified in place
-- @param override table of user supplied values
local function override_defaults(def, override)
  for k, v in pairs(override) do
    local cur = def[k]
    if type(cur) == 'table' and type(v) == 'table' then
      override_defaults(cur, v)
    else
      -- Scalar override, new key, or a type mismatch: the override wins
      def[k] = v
    end
  end
end
|
|
|
for k,r in pairs(rules) do |
|
|
|
rules[k] = default_options |
|
|
|
rules[k]['redis'] = redis_params |
|
|
|
local cur = rules[k] |
|
|
|
local def_rules = deepcopy(default_options) |
|
|
|
def_rules['redis'] = redis_params |
|
|
|
-- Override defaults |
|
|
|
for sk,v in pairs(r) do |
|
|
|
cur[sk] = v |
|
|
|
override_defaults(def_rules, r) |
|
|
|
|
|
|
|
if not def_rules.prefix then |
|
|
|
def_rules.prefix = k |
|
|
|
end |
|
|
|
if not cur.prefix then |
|
|
|
cur.prefix = k |
|
|
|
if not def_rules.name then |
|
|
|
def_rules.name = k |
|
|
|
end |
|
|
|
rspamd_logger.infox(rspamd_config, "register ann rule %s", k) |
|
|
|
settings.rules[k] = def_rules |
|
|
|
rspamd_config:set_metric_symbol({ |
|
|
|
name = cur.symbol_spam, |
|
|
|
name = def_rules.symbol_spam, |
|
|
|
score = 3.0, |
|
|
|
description = 'Neural network SPAM', |
|
|
|
group = 'fann' |
|
|
|
}) |
|
|
|
|
|
|
|
rspamd_config:set_metric_symbol({ |
|
|
|
name = cur.symbol_ham, |
|
|
|
name = def_rules.symbol_ham, |
|
|
|
score = -2.0, |
|
|
|
description = 'Neural network HAM', |
|
|
|
group = 'fann' |
|
|
|
}) |
|
|
|
rspamd_config:register_symbol({ |
|
|
|
name = cur.symbol_ham, |
|
|
|
name = def_rules.symbol_ham, |
|
|
|
type = 'virtual,nostat', |
|
|
|
parent = id |
|
|
|
}) |
|
|
@@ -1086,13 +1132,11 @@ else |
|
|
|
|
|
|
|
rspamd_config:register_symbol({ |
|
|
|
name = 'FANN_VECTOR_PUSH', |
|
|
|
type = 'postfilter,nostat', |
|
|
|
type = 'idempotent,nostat', |
|
|
|
priority = 5, |
|
|
|
callback = ann_push_vector |
|
|
|
}) |
|
|
|
|
|
|
|
settings.rules = rules |
|
|
|
|
|
|
|
-- Add training scripts |
|
|
|
for _,rule in pairs(settings.rules) do |
|
|
|
rspamd_config:add_on_load(function(cfg, ev_base, worker) |