|
|
@@ -104,14 +104,14 @@ local redis_lua_script_can_store_train_vec = [[ |
|
|
|
if ret then nham = tonumber(ret) end |
|
|
|
|
|
|
|
if KEYS[2] == 'spam' then |
|
|
|
if nham <= lim and nham + 1 >= nspam then |
|
|
|
return tostring(nspam + 1) |
|
|
|
if nspam <= lim then |
|
|
|
return tostring(nspam) |
|
|
|
else |
|
|
|
return tostring(-(nspam)) |
|
|
|
end |
|
|
|
else |
|
|
|
if nspam <= lim and nspam + 1 >= nham then |
|
|
|
return tostring(nham + 1) |
|
|
|
if nham <= lim then |
|
|
|
return tostring(nham) |
|
|
|
else |
|
|
|
return tostring(-(nham)) |
|
|
|
end |
|
|
@@ -127,8 +127,9 @@ local redis_can_store_train_vec_id = nil |
|
|
|
-- key2 - number of elements to leave |
|
|
|
local redis_lua_script_maybe_invalidate = [[ |
|
|
|
local card = redis.call('ZCARD', KEYS[1]) |
|
|
|
if card > tonumber(KEYS[2]) then |
|
|
|
local to_delete = redis.call('ZRANGE', KEYS[1], 0, (-(tonumber(KEYS[2] - 1)))) |
|
|
|
local lim = tonumber(KEYS[2]) |
|
|
|
if card > lim then |
|
|
|
local to_delete = redis.call('ZRANGE', KEYS[1], 0, card - lim - 1) |
|
|
|
for _,k in ipairs(to_delete) do |
|
|
|
local tb = cjson.decode(k) |
|
|
|
redis.call('DEL', tb.redis_key) |
|
|
@@ -136,7 +137,7 @@ local redis_lua_script_maybe_invalidate = [[ |
|
|
|
redis.call('DEL', tb.redis_key .. '_spam') |
|
|
|
redis.call('DEL', tb.redis_key .. '_ham') |
|
|
|
end |
|
|
|
redis.call('ZREMRANGEBYRANK', KEYS[1], 0, (-(tonumber(KEYS[2] - 1)))) |
|
|
|
redis.call('ZREMRANGEBYRANK', KEYS[1], 0, card - lim - 1) |
|
|
|
return to_delete |
|
|
|
else |
|
|
|
return {} |
|
|
@@ -152,17 +153,17 @@ local redis_maybe_invalidate_id = nil |
|
|
|
-- key4 - hostname |
|
|
|
local redis_lua_script_maybe_lock = [[ |
|
|
|
local locked = redis.call('HGET', KEYS[1], 'lock') |
|
|
|
local now = tonumber(KEYS[2]) |
|
|
|
if locked then |
|
|
|
locked = tonumber(locked) |
|
|
|
now = tonumber(KEYS[2]) |
|
|
|
expire = tonumber(KEYS[3]) |
|
|
|
local expire = tonumber(KEYS[3]) |
|
|
|
if now > locked and (now - locked) < expire then |
|
|
|
return {tostring(locked), redis.call('HGET', KEYS[1], 'hostname')} |
|
|
|
end |
|
|
|
end |
|
|
|
redis.call('HSET', KEYS[1], 'lock', tostring(now)) |
|
|
|
redis.call('HSET', KEYS[1], 'hostname', KEYS[4]) |
|
|
|
return true |
|
|
|
return 1 |
|
|
|
]] |
|
|
|
local redis_maybe_lock_id = nil |
|
|
|
|
|
|
@@ -178,6 +179,8 @@ local redis_lua_script_save_unlock = [[ |
|
|
|
local now = tonumber(KEYS[6]) |
|
|
|
redis.call('ZADD', KEYS[2], now, KEYS[4]) |
|
|
|
redis.call('HSET', KEYS[1], 'ann', KEYS[3]) |
|
|
|
-- Remove the '_spam' and '_ham' companion keys of KEYS[1].
-- Both calls must go through the `redis` API table; a misspelled receiver
-- references a nonexistent global and aborts the script before the lock
-- field is removed below.
redis.call('DEL', KEYS[1] .. '_spam')
redis.call('DEL', KEYS[1] .. '_ham')
|
|
|
redis.call('HDEL', KEYS[1], 'lock') |
|
|
|
redis.call('EXPIRE', KEYS[1], tonumber(KEYS[5])) |
|
|
|
return 1 |
|
|
@@ -267,7 +270,7 @@ local function new_ann_profile(task, rule, set, version) |
|
|
|
} |
|
|
|
|
|
|
|
local ucl = require "ucl" |
|
|
|
local profile_serialized = ucl.to_format(profile, 'json-compact') |
|
|
|
local profile_serialized = ucl.to_format(profile, 'json-compact', true) |
|
|
|
|
|
|
|
local function add_cb(err, _) |
|
|
|
if err then |
|
|
@@ -322,7 +325,8 @@ local function ann_scores_filter(task) |
|
|
|
score = out[1] |
|
|
|
|
|
|
|
local symscore = string.format('%.3f', score) |
|
|
|
rspamd_logger.infox(task, '%s ann score: %s', rule.name, symscore) |
|
|
|
lua_util.debugm(N, task, '%s:%s ann score: %s', |
|
|
|
rule.prefix, set.name, symscore) |
|
|
|
|
|
|
|
if score > 0 then |
|
|
|
local result = score |
|
|
@@ -348,26 +352,44 @@ end |
|
|
|
|
|
|
|
local function ann_push_task_result(rule, task, verdict, score, set) |
|
|
|
local train_opts = rule.train |
|
|
|
|
|
|
|
|
|
|
|
local learn_spam, learn_ham |
|
|
|
local skip_reason = 'unknown' |
|
|
|
|
|
|
|
if train_opts.autotrain then |
|
|
|
if verdict == 'passthrough' or verdict == 'uncertain' then |
|
|
|
if verdict == 'passthrough' then |
|
|
|
lua_util.debugm(N, task, 'ignore task as its verdict is %s(%s)', |
|
|
|
verdict, score) |
|
|
|
end |
|
|
|
|
|
|
|
if train_opts['spam_score'] then |
|
|
|
learn_spam = score >= train_opts['spam_score'] |
|
|
|
if train_opts.spam_score then |
|
|
|
learn_spam = score >= train_opts.spam_score |
|
|
|
|
|
|
|
if not learn_spam then |
|
|
|
skip_reason = string.format('score < spam_score: %f < %f', |
|
|
|
score, train_opts.spam_score) |
|
|
|
end |
|
|
|
else |
|
|
|
learn_spam = verdict == 'spam' or verdict == 'junk' |
|
|
|
|
|
|
|
if not learn_spam then |
|
|
|
skip_reason = string.format('verdict: %s', |
|
|
|
verdict) |
|
|
|
end |
|
|
|
end |
|
|
|
|
|
|
|
if train_opts['ham_score'] then |
|
|
|
learn_ham = score <= train_opts['ham_score'] |
|
|
|
if train_opts.ham_score then |
|
|
|
learn_ham = score <= train_opts.ham_score |
|
|
|
if not learn_ham then
  -- learn_ham is `score <= ham_score`, so reaching this branch means the
  -- score is above the ham threshold; report the comparison that way.
  skip_reason = string.format('score > ham_score: %f > %f',
      score, train_opts.ham_score)
end
|
|
|
else |
|
|
|
learn_ham = verdict == 'ham' |
|
|
|
|
|
|
|
if not learn_ham then |
|
|
|
skip_reason = string.format('verdict: %s', |
|
|
|
verdict) |
|
|
|
end |
|
|
|
end |
|
|
|
else |
|
|
|
-- Train by request header |
|
|
@@ -378,6 +400,8 @@ local function ann_push_task_result(rule, task, verdict, score, set) |
|
|
|
learn_spam = true |
|
|
|
elseif hdr:lower() == 'ham' then |
|
|
|
learn_ham = true |
|
|
|
else |
|
|
|
skip_reason = string.format('no explicit header') |
|
|
|
end |
|
|
|
end |
|
|
|
end |
|
|
@@ -387,18 +411,8 @@ local function ann_push_task_result(rule, task, verdict, score, set) |
|
|
|
local learn_type |
|
|
|
if learn_spam then learn_type = 'spam' else learn_type = 'ham' end |
|
|
|
|
|
|
|
-- Redis completion callback for storing a train vector.
-- @param err error reported by the Redis request, or nil/false on success
local function learn_vec_cb(err)
  if not err then
    -- Success path: report which kind of vector has been saved
    rspamd_logger.infox(task, "trained ANN rule %s:%s, save %s vector",
        rule.prefix, set.name, learn_type)
    return
  end

  rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s',
      rule.prefix, set.name, err)
end
|
|
|
|
|
|
|
local function can_train_cb(err, data) |
|
|
|
if not err and tonumber(data) > 0 then |
|
|
|
if not err and tonumber(data) >= 0 then |
|
|
|
local coin = math.random() |
|
|
|
if coin < 1.0 - train_opts.train_prob then |
|
|
|
rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin) |
|
|
@@ -408,6 +422,17 @@ local function ann_push_task_result(rule, task, verdict, score, set) |
|
|
|
|
|
|
|
local str = rspamd_util.zstd_compress(table.concat(vec, ';')) |
|
|
|
|
|
|
|
-- Redis completion callback for storing a compressed train vector.
-- @param _err error reported by the Redis request, or nil/false on success
local function learn_vec_cb(_err)
  if not _err then
    -- Success path: debug-log the vector type, its element count and the
    -- size of the compressed payload
    lua_util.debugm(N, task,
        "add train data for ANN rule %s:%s, save %s vector of %s elts; %s bytes compressed",
        rule.prefix, set.name, learn_type, #vec, #str)
    return
  end

  rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s',
      rule.prefix, set.name, _err)
end
|
|
|
|
|
|
|
lua_redis.redis_make_request(task, |
|
|
|
rule.redis, |
|
|
|
nil, |
|
|
@@ -422,7 +447,7 @@ local function ann_push_task_result(rule, task, verdict, score, set) |
|
|
|
rule.prefix, set.name, err) |
|
|
|
elseif tonumber(data) < 0 then |
|
|
|
rspamd_logger.infox(task, "cannot learn ANN %s:%s: too many %s samples: %s", |
|
|
|
rule.prefix, set.name, learn_type, -tonumber(data)) |
|
|
|
rule.prefix, set.name, learn_type, -tonumber(data)) |
|
|
|
end |
|
|
|
end |
|
|
|
end |
|
|
@@ -436,6 +461,9 @@ local function ann_push_task_result(rule, task, verdict, score, set) |
|
|
|
{task = task, is_write = true}, |
|
|
|
can_train_cb, |
|
|
|
{ set.ann.redis_key, learn_type, tostring(train_opts.max_trains)}) |
|
|
|
else |
|
|
|
lua_util.debugm(N, task, 'do not push data: train condition not satisfied; reason: %s', |
|
|
|
skip_reason) |
|
|
|
end |
|
|
|
end |
|
|
|
|
|
|
@@ -481,6 +509,7 @@ local function register_lock_extender(rule, set, ev_base, ann_key) |
|
|
|
{ann_key, 'lock', '30'} |
|
|
|
) |
|
|
|
else |
|
|
|
lua_util.debugm(N, rspamd_config, "stop lock extension as learning_spawned is false") |
|
|
|
return false -- do not plan any more updates |
|
|
|
end |
|
|
|
|
|
|
@@ -537,7 +566,7 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve |
|
|
|
return out |
|
|
|
end |
|
|
|
|
|
|
|
rule.learning_spawned = true |
|
|
|
set.learning_spawned = true |
|
|
|
|
|
|
|
local function redis_save_cb(err) |
|
|
|
if err then |
|
|
@@ -559,7 +588,7 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve |
|
|
|
end |
|
|
|
|
|
|
|
local function ann_trained(err, data) |
|
|
|
rule.learning_spawned = false |
|
|
|
set.learning_spawned = false |
|
|
|
if err then |
|
|
|
rspamd_logger.errx(rspamd_config, 'cannot train ANN %s:%s : %s', |
|
|
|
rule.prefix, set.name, err) |
|
|
@@ -598,7 +627,7 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve |
|
|
|
} |
|
|
|
|
|
|
|
local ucl = require "ucl" |
|
|
|
local profile_serialized = ucl.to_format(profile, 'json-compact') |
|
|
|
local profile_serialized = ucl.to_format(profile, 'json-compact', true) |
|
|
|
|
|
|
|
lua_redis.exec_redis_script(redis_save_unlock_id, |
|
|
|
{ev_base = ev_base, is_write = true}, |
|
|
@@ -695,7 +724,7 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key) |
|
|
|
if err then |
|
|
|
rspamd_logger.errx(rspamd_config, 'cannot call lock script for ANN %s from redis: %s', |
|
|
|
ann_key, err) |
|
|
|
elseif type(data) == 'boolean' and data then |
|
|
|
elseif type(data) == 'number' and data == 1 then |
|
|
|
-- ANN is locked, so we can extract SPAM and HAM vectors and spawn learning |
|
|
|
lua_redis.redis_make_request_taskless(ev_base, |
|
|
|
rspamd_config, |
|
|
@@ -752,47 +781,52 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff) |
|
|
|
rspamd_logger.errx(rspamd_config, 'cannot get ANN data from key: %s; %s', |
|
|
|
ann_key, err) |
|
|
|
else |
|
|
|
local _err,ann_data = rspamd_util.zstd_decompress(data[1]) |
|
|
|
local ann |
|
|
|
|
|
|
|
if _err or not ann_data then |
|
|
|
rspamd_logger.errx(rspamd_config, 'cannot decompress ANN for %s from Redis key %s: %s', |
|
|
|
rule.prefix .. ':' .. set.name, ann_key, _err) |
|
|
|
return |
|
|
|
else |
|
|
|
ann = rspamd_kann.load(ann_data) |
|
|
|
|
|
|
|
if ann then |
|
|
|
set.ann = { |
|
|
|
ann = ann, |
|
|
|
version = profile.version, |
|
|
|
symbols = profile.symbols, |
|
|
|
distance = min_diff, |
|
|
|
redis_key = profile.redis_key |
|
|
|
} |
|
|
|
if type(data) == 'string' then |
|
|
|
local _err,ann_data = rspamd_util.zstd_decompress(data) |
|
|
|
local ann |
|
|
|
|
|
|
|
local ucl = require "ucl" |
|
|
|
local profile_serialized = ucl.to_format(profile, 'json-compact') |
|
|
|
|
|
|
|
local function rank_cb(_, _) |
|
|
|
-- TODO: maybe add some logging |
|
|
|
end |
|
|
|
-- Also update rank for the loaded ANN to avoid removal |
|
|
|
lua_redis.redis_make_request_taskless(ev_base, |
|
|
|
rspamd_config, |
|
|
|
rule.redis, |
|
|
|
nil, |
|
|
|
true, -- is write |
|
|
|
rank_cb, --callback |
|
|
|
'ZADD', -- command |
|
|
|
{set.prefix, tostring(rspamd_util.get_time()), profile_serialized} |
|
|
|
) |
|
|
|
rspamd_logger.infox(rspamd_config, 'loaded ANN for %s from %s; %s bytes compressed; version=%s', |
|
|
|
rule.prefix .. ':' .. set.name, ann_key, #ann_data, profile.version) |
|
|
|
if _err or not ann_data then |
|
|
|
rspamd_logger.errx(rspamd_config, 'cannot decompress ANN for %s from Redis key %s: %s', |
|
|
|
rule.prefix .. ':' .. set.name, ann_key, _err) |
|
|
|
return |
|
|
|
else |
|
|
|
rspamd_logger.errx(rspamd_config, 'cannot deserialize ANN for %s from Redis key %s', |
|
|
|
rule.prefix .. ':' .. set.name, ann_key) |
|
|
|
ann = rspamd_kann.load(ann_data) |
|
|
|
|
|
|
|
if ann then |
|
|
|
set.ann = { |
|
|
|
ann = ann, |
|
|
|
version = profile.version, |
|
|
|
symbols = profile.symbols, |
|
|
|
distance = min_diff, |
|
|
|
redis_key = profile.redis_key |
|
|
|
} |
|
|
|
|
|
|
|
local ucl = require "ucl" |
|
|
|
local profile_serialized = ucl.to_format(profile, 'json-compact', true) |
|
|
|
|
|
|
|
local function rank_cb(_, _) |
|
|
|
-- TODO: maybe add some logging |
|
|
|
end |
|
|
|
-- Also update rank for the loaded ANN to avoid removal |
|
|
|
lua_redis.redis_make_request_taskless(ev_base, |
|
|
|
rspamd_config, |
|
|
|
rule.redis, |
|
|
|
nil, |
|
|
|
true, -- is write |
|
|
|
rank_cb, --callback |
|
|
|
'ZADD', -- command |
|
|
|
{set.prefix, tostring(rspamd_util.get_time()), profile_serialized} |
|
|
|
) |
|
|
|
rspamd_logger.infox(rspamd_config, 'loaded ANN for %s:%s from %s; %s bytes compressed; version=%s', |
|
|
|
rule.prefix, set.name, ann_key, #ann_data, profile.version) |
|
|
|
else |
|
|
|
rspamd_logger.errx(rspamd_config, 'cannot deserialize ANN for %s:%s from Redis key %s', |
|
|
|
rule.prefix, set.name, ann_key) |
|
|
|
end |
|
|
|
end |
|
|
|
else |
|
|
|
lua_util.debugm(N, rspamd_config, 'no ANN for %s:%s in Redis key %s', |
|
|
|
rule.prefix, set.name, ann_key) |
|
|
|
end |
|
|
|
end |
|
|
|
end |
|
|
@@ -803,8 +837,7 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff) |
|
|
|
false, -- is write |
|
|
|
data_cb, --callback |
|
|
|
'HGET', -- command |
|
|
|
{ann_key, 'ann'}, -- arguments |
|
|
|
{opaque_data = true} |
|
|
|
{ann_key, 'ann'} -- arguments |
|
|
|
) |
|
|
|
end |
|
|
|
|
|
|
@@ -900,23 +933,46 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles) |
|
|
|
-- We have our ANN and that's train vectors, check if we can learn |
|
|
|
local ann_key = sel_elt.redis_key |
|
|
|
|
|
|
|
lua_util.debugm(N, rspamd_config, "check ANN %s", ann_key) |
|
|
|
local redis_len_cb = function(err, data) |
|
|
|
if err then |
|
|
|
rspamd_logger.errx(rspamd_config, |
|
|
|
'cannot get FANN trains %s from redis: %s', ann_key, err) |
|
|
|
elseif data and type(data) == 'number' or type(data) == 'string' then |
|
|
|
if tonumber(data) and tonumber(data) >= rule.train.max_trains then |
|
|
|
rspamd_logger.infox(rspamd_config, |
|
|
|
'need to learn ANN %s after %s learn vectors (%s required)', |
|
|
|
ann_key, tonumber(data), rule.train.max_trains) |
|
|
|
do_train_ann(worker, ev_base, rule, set, ann_key) |
|
|
|
else |
|
|
|
rspamd_logger.debugm(N, rspamd_config, |
|
|
|
'no need to learn ANN %s %s learn vectors (%s required)', |
|
|
|
ann_key, tonumber(data), rule.train.max_trains) |
|
|
|
lua_util.debugm(N, rspamd_config, "check if ANN %s needs to be trained", |
|
|
|
ann_key) |
|
|
|
|
|
|
|
-- Create continuation closure |
|
|
|
-- Create continuation closure.
-- Builds a Redis reply handler for an LLEN request on one of the train
-- vector lists of `ann_key`.
-- @param cont_cb continuation called with the parsed list length as its
--        only argument when that length reaches rule.train.max_trains
-- @return function(err, data) to be used as the Redis callback
local redis_len_cb_gen = function(cont_cb)
  return function(err, data)
    if err then
      rspamd_logger.errx(rspamd_config,
          'cannot get ANN trains %s from redis: %s', ann_key, err)
    elseif type(data) == 'number' or type(data) == 'string' then
      -- The reply may arrive as a number or as a numeric string
      local nvecs = tonumber(data)

      if nvecs and nvecs >= rule.train.max_trains then
        -- Hand the length over so the continuation can report it
        cont_cb(nvecs)
      else
        rspamd_logger.debugm(N, rspamd_config,
            'no need to learn ANN %s %s learn vectors (%s required)',
            ann_key, nvecs, rule.train.max_trains)
      end
    end
  end
end

-- Final continuation: log the decision and start training.
-- @param nvecs length of the list checked last, as passed by the handler
--        produced by redis_len_cb_gen (the reply callback's `data` is not
--        in scope here and must not be referenced)
local function initiate_train(nvecs)
  rspamd_logger.infox(rspamd_config,
      'need to learn ANN %s after %s learn vectors (%s required)',
      ann_key, nvecs, rule.train.max_trains)
  do_train_ann(worker, ev_base, rule, set, ann_key)
end

-- Spam vector is OK, check ham vector length.
-- Used as a continuation itself; any argument passed by the caller is
-- ignored, only the '_ham' list length decides whether training starts.
local function check_ham_len()
  lua_redis.redis_make_request_taskless(ev_base,
      rspamd_config,
      rule.redis,
      nil,
      false, -- is write
      redis_len_cb_gen(initiate_train), --callback
      'LLEN', -- command
      {ann_key .. '_ham'}
  )
end
|
|
|
|
|
|
|
lua_redis.redis_make_request_taskless(ev_base, |
|
|
@@ -924,7 +980,7 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles) |
|
|
|
rule.redis, |
|
|
|
nil, |
|
|
|
false, -- is write |
|
|
|
redis_len_cb, --callback |
|
|
|
redis_len_cb_gen(check_ham_len), --callback |
|
|
|
'LLEN', -- command |
|
|
|
{ann_key .. '_spam'} |
|
|
|
) |
|
|
@@ -1005,14 +1061,22 @@ local function cleanup_anns(rule, cfg, ev_base) |
|
|
|
end |
|
|
|
|
|
|
|
local function ann_push_vector(task) |
|
|
|
if task:has_flag('skip') then return end |
|
|
|
if not settings.allow_local and lua_util.is_rspamc_or_controller(task) then return end |
|
|
|
if task:has_flag('skip') then |
|
|
|
lua_util.debugm(N, task, 'do not push data for skipped task') |
|
|
|
return |
|
|
|
end |
|
|
|
if not settings.allow_local and lua_util.is_rspamc_or_controller(task) then |
|
|
|
lua_util.debugm(N, task, 'do not push data for manual scan') |
|
|
|
return |
|
|
|
end |
|
|
|
local verdict,score = lua_util.get_task_verdict(task) |
|
|
|
for _,rule in pairs(settings.rules) do |
|
|
|
local set = get_rule_settings(task, rule) |
|
|
|
|
|
|
|
if set then |
|
|
|
ann_push_task_result(rule, task, verdict, score, set) |
|
|
|
else |
|
|
|
lua_util.debugm(N, task, 'settings not found in rule %s', rule.prefix) |
|
|
|
end |
|
|
|
|
|
|
|
end |
|
|
@@ -1064,7 +1128,7 @@ local function process_rules_settings() |
|
|
|
|
|
|
|
if rule.default then |
|
|
|
local default_settings = { |
|
|
|
symbols = lua_util.keys(lua_settings.default_symbols()), |
|
|
|
symbols = lua_settings.default_symbols(), |
|
|
|
name = 'default' |
|
|
|
} |
|
|
|
|
|
|
@@ -1099,7 +1163,7 @@ local function process_rules_settings() |
|
|
|
|
|
|
|
if nelt then |
|
|
|
rule.settings[s] = nelt |
|
|
|
lua_util.debugm(N, rspamd_config, 'added new settings id %s to %s; same symbols', |
|
|
|
lua_util.debugm(N, rspamd_config, 'added new settings id %s to %s', |
|
|
|
nelt.name, rule.prefix) |
|
|
|
end |
|
|
|
end |