local rspamd_util = require "rspamd_util"
local fann_symbol_spam = 'FANNR_SPAM'
local fann_symbol_ham = 'FANNR_HAM'
-require "fun" ()
-local ucl = require "ucl"
+local fun = require "fun"
local module_log_id = 0x200
-- Module vars
end
if not addr then
- logger.errx(task, 'cannot select server to make redis request')
+ rspamd_logger.errx(cfg, 'cannot select server to make redis request')
end
local options = {
local matched_symbols = {}
local n = rspamd_config:get_symbols_count()
- each(function(s, score)
+ fun.each(function(s, score)
matched_symbols[s + 1] = rspamd_util.tanh(score)
- end, zip(syms, scores))
+ end, fun.zip(syms, scores))
for i=1,n do
if matched_symbols[i] then
fanns[id].fann = ann
rspamd_logger.infox(rspamd_config, 'loaded ann %s from redis', id)
else
- local function redis_invalidate_cb(err, data)
- if err then
- rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', id, err)
- elseif type(data) == 'string' then
- rspamd_logger.info(rspamd_config, 'invalidated ANN %s from redis: %s', id, err)
+ -- Reply handler for the ANN invalidation request: logs either the redis
+ -- error or, on a string reply, the reply itself.
+ local function redis_invalidate_cb(_err, _data)
+ if _err then
+ rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', id, _err)
+ elseif type(_data) == 'string' then
+ -- _err is nil on this branch, so log the reply payload; infox is the
+ -- variant that takes (object, format, args...).
+ rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', id, _data)
end
end
-- Invalidate ANN
end
end
-local function fann_train_callback(score, required_score, results, cf, id, opts, extra, ev_base)
+local function fann_train_callback(score, required_score, results, _, id, opts, extra, ev_base)
local fname,suffix = gen_fann_prefix(id)
- local learn_spam, learn_ham = false, false
+ local learn_spam, learn_ham
if opts['spam_score'] then
learn_spam = score >= opts['spam_score']
local k
if learn_spam then k = 'spam' else k = 'ham' end
- local function learn_vec_cb(err, data)
+ local function learn_vec_cb(err)
if err then
rspamd_logger.errx(rspamd_config, 'cannot store train vector: %s', err)
end
local function can_train_cb(err, data)
if not err and tonumber(data) > 0 then
local learn_data = symbols_to_fann_vector(
- map(function(r) return r[1] end, results),
- map(function(r) return r[2] end, results)
+ fun.map(function(r) return r[1] end, results),
+ fun.map(function(r) return r[2] end, results)
)
-- Add filtered meta tokens
- each(function(e) table.insert(learn_data, e) end, extra)
+ fun.each(function(e) table.insert(learn_data, e) end, extra)
local str = rspamd_util.zstd_compress(table.concat(learn_data, ';'))
redis_make_request(ev_base,
end
end
-local function train_fann(cfg, ev_base, elt)
+local function train_fann(_, ev_base, elt)
local spam_elts = {}
local ham_elts = {}
elt = tostring(elt)
- local function redis_unlock_cb(err, data)
+ local function redis_unlock_cb(err)
if err then
rspamd_logger.errx(rspamd_config, 'cannot unlock ANN %s from redis: %s',
fann_prefix .. elt, err)
end
end
- local function redis_save_cb(err, data)
+ local function redis_save_cb(err)
if err then
rspamd_logger.errx(rspamd_config, 'cannot save ANN %s to redis: %s',
fann_prefix .. elt, err)
)
else
-- Decompress and convert to numbers each training vector
- ham_elts = map(function(tok)
+ ham_elts = fun.map(function(tok)
local _,str = rspamd_util.zstd_decompress(tok)
- return map(tonumber, rspamd_str_split(tostring(str), ';'))
+ return fun.map(tonumber, rspamd_str_split(tostring(str), ';'))
end, data)
-- Now we need to join inputs and create the appropriate test vectors
local inputs = {}
local outputs = {}
- each(function(sample)
- table.insert(inputs, totable(sample[1]))
+ fun.each(function(sample)
+ table.insert(inputs, fun.totable(sample[1]))
table.insert(outputs, {1.0})
- table.insert(inputs, totable(sample[2]))
+ table.insert(inputs, fun.totable(sample[2]))
table.insert(outputs, {-1.0})
- end, zip(spam_elts, ham_elts))
+ end, fun.zip(spam_elts, ham_elts))
-- Now we can train fann
local n = rspamd_config:get_symbols_count() + rspamd_count_metatokens()
)
else
-- Decompress and convert to numbers each training vector
- spam_elts = map(function(tok)
+ spam_elts = fun.map(function(tok)
local _,str = rspamd_util.zstd_decompress(tok)
- return map(tonumber, rspamd_str_split(tostring(str), ';'))
+ return fun.map(tonumber, rspamd_str_split(tostring(str), ';'))
end, data)
redis_make_request(ev_base,
rspamd_config,
if err then
rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', err)
elseif type(data) == 'table' then
- each(function(elt)
+ fun.each(function(elt)
elt = tostring(elt)
- local redis_len_cb = function(err, data)
- if err then
- rspamd_logger.errx(rspamd_config, 'cannot get FANN trains %s from redis: %s', elt, err)
- elseif data and type(data) == 'number' or type(data) == 'string' then
- if tonumber(data) and tonumber(data) > max_trains then
+ -- Reply handler for the train-vector count of ANN `elt`: starts training
+ -- once the numeric reply exceeds max_trains.
+ local redis_len_cb = function(_err, _data)
+ if _err then
+ rspamd_logger.errx(rspamd_config, 'cannot get FANN trains %s from redis: %s', elt, _err)
+ elseif _data and type(_data) == 'number' or type(_data) == 'string' then
+ if tonumber(_data) and tonumber(_data) > max_trains then
rspamd_logger.infox(rspamd_config, 'need to learn ANN %s after %s learn vectors (%s required)',
- tonumber(data), max_trains)
+ elt, tonumber(_data), max_trains)
train_fann(cfg, ev_base, elt)
end
end
end
- local local_ver = 0
- local numelt = tonumber(elt)
- if fanns[numelt] then
- if fanns[numelt].version then
- local_ver = fanns[numelt].version
- end
- end
redis_make_request(ev_base,
rspamd_config,
nil,
return watch_interval
end
-local function check_fanns(cfg, ev_base)
+local function check_fanns(_, ev_base)
local function members_cb(err, data)
if err then
rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', err)
elseif type(data) == 'table' then
- each(function(elt)
+ fun.each(function(elt)
elt = tostring(elt)
- local redis_update_cb = function(err, data)
- if err then
- rspamd_logger.errx(rspamd_config, 'cannot get FANN version %s from redis: %s', elt, err)
- elseif data and type(data) == 'string' then
- load_or_invalidate_fann(data, elt, ev_base)
+ local redis_update_cb = function(_err, _data)
+ if _err then
+ rspamd_logger.errx(rspamd_config, 'cannot get FANN version %s from redis: %s', elt, _err)
+ elseif _data and type(_data) == 'string' then
+ load_or_invalidate_fann(_data, elt, ev_base)
end
end
max_epoch = opts['train']['max_epoch']
end
local ret = cfg:register_worker_script("log_helper",
- function(score, req_score, results, cf, id, extra, ev_base)
- -- map (snd x) (filter (fst x == module_id) extra)
- local extra_fann = map(function(e) return e[2] end,
- filter(function(e) return e[1] == module_log_id end, extra))
+ function(score, req_score, results, cf, _id, extra, ev_base)
+ -- fun.map (snd x) (fun.filter (fst x == module_id) extra)
+ local extra_fann = fun.map(function(e) return e[2] end,
+ fun.filter(function(e) return e[1] == module_log_id end, extra))
if use_settings then
fann_train_callback(score, req_score, results, cf,
- tostring(id), opts['train'], extra_fann, ev_base)
+ tostring(_id), opts['train'], extra_fann, ev_base)
else
fann_train_callback(score, req_score, results, cf, '0',
opts['train'], extra_fann, ev_base)
-- This is needed to pass extra tokens from worker to log_helper
rspamd_plugins["fann_redis"] = {
log_callback = function(task)
- return totable(map(
+ return fun.totable(fun.map(
function(tok) return {module_log_id, tok} end,
rspamd_gen_metatokens(task)))
end
redis_maybe_load_sha = tostring(data)
rspamd_config:add_periodic(ev_base, 0.0,
- function(cfg, ev_base)
- return check_fanns(cfg, ev_base)
+ function(_cfg, _ev_base)
+ return check_fanns(_cfg, _ev_base)
end)
end
end
if worker:get_name() == 'normal' then
-- We also want to train neural nets when they have enough data
rspamd_config:add_periodic(ev_base, 0.0,
- function(cfg, ev_base)
- return maybe_train_fanns(cfg, ev_base)
+ function(_cfg, _ev_base)
+ return maybe_train_fanns(_cfg, _ev_base)
end)
end
end)