|
|
@@ -28,6 +28,7 @@ local fann_symbol_spam = 'FANNR_SPAM' |
|
|
|
local fann_symbol_ham = 'FANNR_HAM' |
|
|
|
local rspamd_redis = require "lua_redis" |
|
|
|
local fun = require "fun" |
|
|
|
local meta_functions = require "meta_functions" |
|
|
|
|
|
|
|
local module_log_id = 0x200 |
|
|
|
-- Module vars |
|
|
@@ -286,26 +287,6 @@ local function load_scripts(cfg, ev_base, on_load_cb) |
|
|
|
) |
|
|
|
end |
|
|
|
|
|
|
|
-- Builds a dense input vector from the symbols matched for a message.
-- `syms` and `scores` are walked pairwise (via fun.zip); each score is passed
-- through rspamd_util.tanh and stored at index `symbol + 1`.
-- Returns a table with one entry for every index from 1 to
-- rspamd_config:get_symbols_count(); indices without a stored value hold 0.
local function symbols_to_fann_vector(syms, scores)
  local total = rspamd_config:get_symbols_count()

  -- Sparse map: (symbol + 1) -> squashed score
  local squashed = {}
  fun.each(function(sym, sym_score)
    squashed[sym + 1] = rspamd_util.tanh(sym_score)
  end, fun.zip(syms, scores))

  -- Expand the sparse map into a dense vector, defaulting missing slots to 0
  local vector = {}
  for idx = 1, total do
    vector[idx] = squashed[idx] or 0
  end

  return vector
end
|
|
|
|
|
|
|
local function gen_fann_prefix(id) |
|
|
|
if id then |
|
|
|
return fann_prefix .. rspamd_config:get_symbols_cksum():hex() .. id,id |
|
|
@@ -345,13 +326,10 @@ local function fann_scores_filter(task) |
|
|
|
end |
|
|
|
|
|
|
|
if fanns[id].fann then |
|
|
|
local symbols,scores = task:get_symbols_numeric() |
|
|
|
local fann_data = symbols_to_fann_vector(symbols, scores) |
|
|
|
local mt = rspamd_gen_metatokens(task) |
|
|
|
|
|
|
|
for _,tok in ipairs(mt) do |
|
|
|
table.insert(fann_data, tok) |
|
|
|
end |
|
|
|
local fann_data = task:get_symbols_tokens() |
|
|
|
local mt = meta_functions.rspamd_gen_metatokens(task) |
|
|
|
-- Add filtered meta tokens |
|
|
|
fun.each(function(e) table.insert(fann_data, e) end, mt) |
|
|
|
|
|
|
|
local out = fanns[id].fann:test(fann_data) |
|
|
|
local symscore = string.format('%.3f', out[1]) |
|
|
@@ -445,7 +423,7 @@ local function load_or_invalidate_fann(data, id, ev_base) |
|
|
|
end |
|
|
|
end |
|
|
|
|
|
|
|
local function fann_train_callback(score, required_score, results, _, id, opts, extra, ev_base) |
|
|
|
local function fann_train_callback(task, score, required_score, id, opts) |
|
|
|
local fname,suffix = gen_fann_prefix(id) |
|
|
|
|
|
|
|
local learn_spam, learn_ham |
|
|
@@ -473,13 +451,11 @@ local function fann_train_callback(score, required_score, results, _, id, opts, |
|
|
|
|
|
|
|
local function can_train_cb(err, data) |
|
|
|
if not err and tonumber(data) > 0 then |
|
|
|
local learn_data = symbols_to_fann_vector( |
|
|
|
fun.map(function(r) return r[1] end, results), |
|
|
|
fun.map(function(r) return r[2] end, results) |
|
|
|
) |
|
|
|
local fann_data = task:get_symbols_tokens() |
|
|
|
local mt = meta_functions.rspamd_gen_metatokens(task) |
|
|
|
-- Add filtered meta tokens |
|
|
|
fun.each(function(e) table.insert(learn_data, e) end, extra) |
|
|
|
local str = rspamd_util.zstd_compress(table.concat(learn_data, ';')) |
|
|
|
fun.each(function(e) table.insert(fann_data, e) end, mt) |
|
|
|
local str = rspamd_util.zstd_compress(table.concat(fann_data, ';')) |
|
|
|
|
|
|
|
rspamd_redis.redis_make_request_taskless(ev_base, |
|
|
|
rspamd_config, |
|
|
@@ -500,14 +476,14 @@ local function fann_train_callback(score, required_score, results, _, id, opts, |
|
|
|
end |
|
|
|
end |
|
|
|
|
|
|
|
rspamd_redis.redis_make_request_taskless(ev_base, |
|
|
|
rspamd_config, |
|
|
|
rspamd_redis.rspamd_redis_make_request(task, |
|
|
|
redis_params, |
|
|
|
nil, |
|
|
|
true, -- is write |
|
|
|
can_train_cb, --callback |
|
|
|
'EVALSHA', -- command |
|
|
|
{redis_can_train_sha, '4', gen_fann_prefix(nil), suffix, k, tostring(max_trains)} -- arguments |
|
|
|
{redis_can_train_sha, '4', gen_fann_prefix(nil), |
|
|
|
suffix, k, tostring(max_trains)} -- arguments |
|
|
|
) |
|
|
|
end |
|
|
|
end |
|
|
@@ -857,6 +833,18 @@ local function check_fanns(_, ev_base) |
|
|
|
return watch_interval |
|
|
|
end |
|
|
|
|
|
|
|
-- Postfilter callback: submits the current task as a training sample.
-- Reads the task's metric score pair and passes it to fann_train_callback as
-- (score, required_score), together with the id of the network to train and
-- the 'train' section of the module options.
local function ann_push_vector(task)
  local scores = task:get_metric_score()

  -- When networks are split by settings, the id is the task's settings id.
  -- Otherwise use the fixed id '0'. Previously both branches passed
  -- tostring(sid), which made the `use_settings` check a no-op.
  local id = '0'
  if use_settings then
    id = tostring(task:get_settings_id())
  end

  -- NOTE(review): in the visible source `opts` is declared with `local` only
  -- after this function; if there is no earlier declaration, this reference
  -- resolves to a global instead of that local -- confirm it is in scope here.
  fann_train_callback(task, scores[1], scores[2], id, opts['train'])
end
|
|
|
|
|
|
|
-- Initialization part |
|
|
|
|
|
|
|
local opts = rspamd_config:get_all_opt("fann_redis") |
|
|
@@ -892,7 +880,7 @@ else |
|
|
|
local id = rspamd_config:register_symbol({ |
|
|
|
name = fann_symbol_spam, |
|
|
|
type = 'postfilter', |
|
|
|
priority = 5, |
|
|
|
priority = 6, |
|
|
|
callback = fann_scores_filter |
|
|
|
}) |
|
|
|
rspamd_config:set_metric_symbol({ |
|
|
@@ -907,45 +895,24 @@ else |
|
|
|
parent = id |
|
|
|
}) |
|
|
|
if opts['train'] then |
|
|
|
rspamd_config:add_on_load(function(cfg) |
|
|
|
if opts['train']['max_train'] then |
|
|
|
max_trains = opts['train']['max_train'] |
|
|
|
end |
|
|
|
if opts['train']['max_epoch'] then |
|
|
|
max_epoch = opts['train']['max_epoch'] |
|
|
|
end |
|
|
|
if opts['train']['max_usages'] then |
|
|
|
max_usages = opts['train']['max_usages'] |
|
|
|
end |
|
|
|
if opts['train']['mse'] then |
|
|
|
mse = opts['train']['mse'] |
|
|
|
end |
|
|
|
local ret = cfg:register_worker_script("log_helper", |
|
|
|
function(score, req_score, results, cf, _id, extra, ev_base) |
|
|
|
-- fun.map (snd x) (fun.filter (fst x == module_id) extra) |
|
|
|
local extra_fann = fun.map(function(e) return e[2] end, |
|
|
|
fun.filter(function(e) return e[1] == module_log_id end, extra)) |
|
|
|
if use_settings then |
|
|
|
fann_train_callback(score, req_score, results, cf, |
|
|
|
tostring(_id), opts['train'], extra_fann, ev_base) |
|
|
|
else |
|
|
|
fann_train_callback(score, req_score, results, cf, '0', |
|
|
|
opts['train'], extra_fann, ev_base) |
|
|
|
end |
|
|
|
end) |
|
|
|
|
|
|
|
if not ret then |
|
|
|
rspamd_logger.errx(cfg, 'cannot find worker "log_helper"') |
|
|
|
end |
|
|
|
end) |
|
|
|
-- This is needed to pass extra tokens from worker to log_helper |
|
|
|
rspamd_plugins["fann_redis"] = { |
|
|
|
log_callback = function(task) |
|
|
|
return fun.totable(fun.map( |
|
|
|
function(tok) return {module_log_id, tok} end, |
|
|
|
rspamd_gen_metatokens(task))) |
|
|
|
end |
|
|
|
} |
|
|
|
if opts['train']['max_train'] then |
|
|
|
max_trains = opts['train']['max_train'] |
|
|
|
end |
|
|
|
if opts['train']['max_epoch'] then |
|
|
|
max_epoch = opts['train']['max_epoch'] |
|
|
|
end |
|
|
|
if opts['train']['max_usages'] then |
|
|
|
max_usages = opts['train']['max_usages'] |
|
|
|
end |
|
|
|
if opts['train']['mse'] then |
|
|
|
mse = opts['train']['mse'] |
|
|
|
end |
|
|
|
rspamd_config:register_symbol({ |
|
|
|
name = 'FANN_VECTOR_PUSH', |
|
|
|
type = 'postfilter', |
|
|
|
priority = 5, |
|
|
|
callback = ann_push_vector |
|
|
|
}) |
|
|
|
end |
|
|
|
-- Add training scripts |
|
|
|
rspamd_config:add_on_load(function(cfg, ev_base, worker) |