limitations under the License.
]]--
--- This plugin is a concept of FANN scores adjustment
--- NOT FOR PRODUCTION USE so far
if confighelp then
return
local rspamd_logger = require "rspamd_logger"
local rspamd_fann = require "rspamd_fann"
local rspamd_util = require "rspamd_util"
-local rspamd_redis = require "lua_redis"
+local lua_redis = require "lua_redis"
local lua_util = require "lua_util"
local fun = require "fun"
local meta_functions = require "lua_meta"
local use_torch = false
local torch
local nn
-local N = "fann_redis"
+local N = "neural"
if rspamd_config:has_torch() then
use_torch = true
}
-- ANNs indexed by settings id
-local fanns = {
+local anns = {
}
-local opts = rspamd_config:get_all_opt("fann_redis")
+local opts = rspamd_config:get_all_opt("neural")
+if not opts then
+ -- Legacy
+ opts = rspamd_config:get_all_opt("fann_redis")
+end
-- Lua script to train a row
local redis_params
local function load_scripts(params)
- redis_can_train_id = rspamd_redis.add_redis_script(redis_lua_script_can_train,
+ redis_can_train_id = lua_redis.add_redis_script(redis_lua_script_can_train,
params)
- redis_maybe_load_id = rspamd_redis.add_redis_script(redis_lua_script_maybe_load,
+ redis_maybe_load_id = lua_redis.add_redis_script(redis_lua_script_maybe_load,
params)
- redis_maybe_invalidate_id = rspamd_redis.add_redis_script(redis_lua_script_maybe_invalidate,
+ redis_maybe_invalidate_id = lua_redis.add_redis_script(redis_lua_script_maybe_invalidate,
params)
- redis_locked_invalidate_id = rspamd_redis.add_redis_script(redis_lua_script_locked_invalidate,
+ redis_locked_invalidate_id = lua_redis.add_redis_script(redis_lua_script_locked_invalidate,
params)
- redis_maybe_lock_id = rspamd_redis.add_redis_script(redis_lua_script_maybe_lock,
+ redis_maybe_lock_id = lua_redis.add_redis_script(redis_lua_script_maybe_lock,
params)
- redis_save_unlock_id = rspamd_redis.add_redis_script(redis_lua_script_save_unlock,
+ redis_save_unlock_id = lua_redis.add_redis_script(redis_lua_script_save_unlock,
params)
end
-local function gen_fann_prefix(rule, id)
+local function gen_ann_prefix(rule, id)
local cksum = rspamd_config:get_symbols_cksum():hex()
-- We also need to count metatokens:
local n = meta_functions.rspamd_count_metatokens()
end
end
-local function is_fann_valid(rule, prefix, ann)
+local function is_ann_valid(rule, prefix, ann)
if ann then
local n = rspamd_config:get_symbols_count() +
meta_functions.rspamd_count_metatokens()
end
end
-local function fann_scores_filter(task)
+local function ann_scores_filter(task)
for _,rule in pairs(settings.rules) do
local id = '0'
id = id .. r
end
- if fanns[id] and fanns[id].fann then
- local fann_data = task:get_symbols_tokens()
+ if anns[id] and anns[id].ann then
+ local ann_data = task:get_symbols_tokens()
local mt = meta_functions.rspamd_gen_metatokens(task)
-- Add filtered meta tokens
- fun.each(function(e) table.insert(fann_data, e) end, mt)
+ fun.each(function(e) table.insert(ann_data, e) end, mt)
local score
if use_torch then
- local out = fanns[id].fann:forward(torch.Tensor(fann_data))
+ local out = anns[id].ann:forward(torch.Tensor(ann_data))
score = out[1]
else
- local out = fanns[id].fann:test(fann_data)
+ local out = anns[id].ann:test(ann_data)
score = out[1]
end
local symscore = string.format('%.3f', score)
- rspamd_logger.infox(task, 'fann score: %s', symscore)
+ rspamd_logger.infox(task, 'ann score: %s', symscore)
if score > 0 then
local result = score
end
end
-local function create_fann(n, nlayers)
+local function create_ann(n, nlayers)
if use_torch then
-- We ignore number of layers so far when using torch
local ann = nn.Sequential()
end
end
-local function create_train_fann(rule, n, id)
- local prefix = gen_fann_prefix(rule, id)
- if not fanns[id] then
- fanns[id] = {}
+local function create_train_ann(rule, n, id)
+ local prefix = gen_ann_prefix(rule, id)
+ if not anns[id] then
+ anns[id] = {}
end
-- Fix that for flexibe layers number
- if fanns[id].fann then
- if not is_fann_valid(rule, prefix, fanns[id].fann) then
- fanns[id].fann_train = create_fann(n, rule.nlayers)
- fanns[id].fann = nil
+ if anns[id].ann then
+ if not is_ann_valid(rule, prefix, anns[id].ann) then
+ anns[id].ann_train = create_ann(n, rule.nlayers)
+ anns[id].ann = nil
rspamd_logger.infox(rspamd_config, 'invalidate existing ANN, create train ANN %s', prefix)
- elseif rule.train.max_usages > 0 and fanns[id].version % rule.train.max_usages == 0 then
- -- Forget last fann
+ elseif rule.train.max_usages > 0 and anns[id].version % rule.train.max_usages == 0 then
+ -- Forget last ann
rspamd_logger.infox(rspamd_config, 'recreate ANN %s, version %s', prefix,
- fanns[id].version)
- fanns[id].fann_train = create_fann(n, rule.nlayers)
+ anns[id].version)
+ anns[id].ann_train = create_ann(n, rule.nlayers)
else
- fanns[id].fann_train = fanns[id].fann
+ anns[id].ann_train = anns[id].ann
rspamd_logger.infox(rspamd_config, 'reuse ANN for training %s', prefix)
end
else
- fanns[id].fann_train = create_fann(n, rule.nlayers)
+ anns[id].ann_train = create_ann(n, rule.nlayers)
rspamd_logger.infox(rspamd_config, 'create train ANN %s', prefix)
- fanns[id].version = 0
+ anns[id].version = 0
end
end
-local function load_or_invalidate_fann(rule, data, id, ev_base)
+local function load_or_invalidate_ann(rule, data, id, ev_base)
local ver = data[2]
- local prefix = gen_fann_prefix(rule, id)
+ local prefix = gen_ann_prefix(rule, id)
if not ver or not tonumber(ver) then
rspamd_logger.errx(rspamd_config, 'cannot get version for ANN: %s', prefix)
end
end
- if is_fann_valid(rule, prefix, ann) then
- if not fanns[id] then fanns[id] = {} end
- fanns[id].fann = ann
+ if is_ann_valid(rule, prefix, ann) then
+ if not anns[id] then anns[id] = {} end
+ anns[id].ann = ann
rspamd_logger.infox(rspamd_config, 'loaded ANN %s version %s from redis',
prefix, ver)
- fanns[id].version = tonumber(ver)
+ anns[id].version = tonumber(ver)
else
local function redis_invalidate_cb(_err, _data)
if _err then
rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', prefix, _err)
elseif type(_data) == 'string' then
rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', prefix, _err)
- fanns[id].version = 0
+ anns[id].version = 0
end
end
-- Invalidate ANN
rspamd_logger.infox(rspamd_config, 'invalidate ANN %s', prefix)
- rspamd_redis.exec_redis_script(redis_maybe_invalidate_id,
+ lua_redis.exec_redis_script(redis_maybe_invalidate_id,
{ev_base = ev_base, is_write = true},
redis_invalidate_cb,
{prefix})
end
end
-local function fann_train_callback(rule, task, score, required_score, id)
+local function ann_train_callback(rule, task, score, required_score, id)
local train_opts = rule['train']
- local fname,suffix = gen_fann_prefix(rule, id)
+ local fname,suffix = gen_ann_prefix(rule, id)
local learn_spam, learn_ham
rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin)
return
end
- local fann_data = task:get_symbols_tokens()
+ local ann_data = task:get_symbols_tokens()
local mt = meta_functions.rspamd_gen_metatokens(task)
-- Add filtered meta tokens
- fun.each(function(e) table.insert(fann_data, e) end, mt)
+ fun.each(function(e) table.insert(ann_data, e) end, mt)
-- Check NaNs in train data
- if fun.all(function(e) return e == e end, fann_data) then
- local str = rspamd_util.zstd_compress(table.concat(fann_data, ';'))
+ if fun.all(function(e) return e == e end, ann_data) then
+ local str = rspamd_util.zstd_compress(table.concat(ann_data, ';'))
vec_len = #str
- rspamd_redis.redis_make_request(task,
+ lua_redis.redis_make_request(task,
rule.redis,
nil,
true, -- is write
)
else
rspamd_logger.errx(task, "do not store learn vector as it contains %s NaN values",
- fun.length(fun.filter(function(e) return e ~= e end, fann_data)))
+ fun.length(fun.filter(function(e) return e ~= e end, ann_data)))
end
else
end
end
- rspamd_redis.exec_redis_script(redis_can_train_id,
+ lua_redis.exec_redis_script(redis_can_train_id,
{task = task, is_write = true},
can_train_cb,
- {gen_fann_prefix(rule, nil), suffix, k, tostring(train_opts.max_trains)})
+ {gen_ann_prefix(rule, nil), suffix, k, tostring(train_opts.max_trains)})
end
end
-local function train_fann(rule, _, ev_base, elt, worker)
+local function train_ann(rule, _, ev_base, elt, worker)
local spam_elts = {}
local ham_elts = {}
elt = tostring(elt)
- local prefix = gen_fann_prefix(rule, elt)
+ local prefix = gen_ann_prefix(rule, elt)
local function redis_unlock_cb(err)
if err then
if err then
rspamd_logger.errx(rspamd_config, 'cannot save ANN %s to redis: %s',
prefix, err)
- rspamd_redis.redis_make_request_taskless(ev_base,
+ lua_redis.redis_make_request_taskless(ev_base,
rspamd_config,
rule.redis,
nil,
if errcode ~= 0 then
rspamd_logger.errx(rspamd_config, 'cannot train ANN %s: %s',
prefix, errmsg)
- rspamd_redis.redis_make_request_taskless(ev_base,
+ lua_redis.redis_make_request_taskless(ev_base,
rspamd_config,
rule.redis,
nil,
local ann_data
if use_torch then
local f = torch.MemoryFile()
- f:writeObject(fanns[elt].fann_train)
+ f:writeObject(anns[elt].ann_train)
ann_data = rspamd_util.zstd_compress(f:storage():string())
else
- ann_data = rspamd_util.zstd_compress(fanns[elt].fann_train:data())
+ ann_data = rspamd_util.zstd_compress(anns[elt].ann_train:data())
end
- fanns[elt].version = fanns[elt].version + 1
- fanns[elt].fann = fanns[elt].fann_train
- fanns[elt].fann_train = nil
- rspamd_redis.exec_redis_script(redis_save_unlock_id,
+ anns[elt].version = anns[elt].version + 1
+ anns[elt].ann = anns[elt].ann_train
+ anns[elt].ann_train = nil
+ lua_redis.exec_redis_script(redis_save_unlock_id,
{ev_base = ev_base, is_write = true},
redis_save_cb,
{prefix, tostring(ann_data), tostring(rule.ann_expire)})
if err then
rspamd_logger.errx(rspamd_config, 'cannot train ANN %s: %s',
prefix, err)
- rspamd_redis.redis_make_request_taskless(ev_base,
+ lua_redis.redis_make_request_taskless(ev_base,
rspamd_config,
rule.redis,
nil,
local ann_data
local f = torch.MemoryFile(torch.CharStorage():string(tostring(data)))
ann_data = rspamd_util.zstd_compress(f:storage():string())
- fanns[elt].fann_train = f:readObject()
+ anns[elt].ann_train = f:readObject()
- fanns[elt].version = fanns[elt].version + 1
- fanns[elt].fann = fanns[elt].fann_train
- fanns[elt].fann_train = nil
- rspamd_redis.exec_redis_script(redis_save_unlock_id,
+ anns[elt].version = anns[elt].version + 1
+ anns[elt].ann = anns[elt].ann_train
+ anns[elt].ann_train = nil
+ lua_redis.exec_redis_script(redis_save_unlock_id,
{ev_base = ev_base, is_write = true},
redis_save_cb,
{prefix, tostring(ann_data), tostring(rule.ann_expire)})
if err or type(data) ~= 'table' then
rspamd_logger.errx(rspamd_config, 'cannot get ham tokens for ANN %s from redis: %s',
prefix, err)
- rspamd_redis.redis_make_request_taskless(ev_base,
+ lua_redis.redis_make_request_taskless(ev_base,
rspamd_config,
rule.redis,
nil,
return #elts == n
end
- -- Now we can train fann
- if not fanns[elt] or not fanns[elt].fann_train then
- -- Create fann if it does not exist
- create_train_fann(rule, n, elt)
+ -- Now we can train ann
+ if not anns[elt] or not anns[elt].ann_train then
+ -- Create ann if it does not exist
+ create_train_ann(rule, n, elt)
end
if #spam_elts + #ham_elts < rule.train.max_trains / 2 then
rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', prefix, _err)
elseif type(_data) == 'string' then
rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', prefix, _err)
- fanns[elt].version = 0
+ anns[elt].version = 0
end
end
-- Invalidate ANN
rspamd_logger.infox(rspamd_config, 'invalidate ANN %s: training data is invalid', prefix)
- rspamd_redis.exec_redis_script(redis_locked_invalidate_id,
+ lua_redis.exec_redis_script(redis_locked_invalidate_id,
{ev_base = ev_base, is_write = true},
redis_invalidate_cb,
{prefix})
torch.setnumthreads(rule.train.learn_threads)
end
local criterion = nn.MSECriterion()
- local trainer = nn.StochasticGradient(fanns[elt].fann_train,
+ local trainer = nn.StochasticGradient(anns[elt].ann_train,
criterion)
trainer.learning_rate = 0.01
trainer.verbose = false
trainer:train(dataset)
local out = torch.MemoryFile()
- out:writeObject(fanns[elt].fann_train)
+ out:writeObject(anns[elt].ann_train)
local st = out:storage():string()
return st
end
end, fun.zip(fun.filter(filt, spam_elts), fun.filter(filt, ham_elts)))
rule.learning_spawned = true
rspamd_logger.infox(rspamd_config, 'start learning ANN %s', prefix)
- fanns[elt].fann_train:train_threaded(inputs, outputs, ann_trained,
+ anns[elt].ann_train:train_threaded(inputs, outputs, ann_trained,
ev_base, {
max_epochs = rule.train.max_epoch,
desired_mse = rule.train.mse
if err or type(data) ~= 'table' then
rspamd_logger.errx(rspamd_config, 'cannot get spam tokens for ANN %s from redis: %s',
prefix, err)
- rspamd_redis.redis_make_request_taskless(ev_base,
+ lua_redis.redis_make_request_taskless(ev_base,
rspamd_config,
rule.redis,
nil,
local _,str = rspamd_util.zstd_decompress(tok)
return fun.totable(fun.map(tonumber, rspamd_str_split(tostring(str), ';')))
end, data))
- rspamd_redis.redis_make_request_taskless(ev_base,
+ lua_redis.redis_make_request_taskless(ev_base,
rspamd_config,
rule.redis,
nil,
prefix, err)
elseif type(data) == 'number' then
-- Can train ANN
- rspamd_redis.redis_make_request_taskless(ev_base,
+ lua_redis.redis_make_request_taskless(ev_base,
rspamd_config,
rule.redis,
nil,
end
end
if rule.learning_spawned then
- rspamd_redis.redis_make_request_taskless(ev_base,
+ lua_redis.redis_make_request_taskless(ev_base,
rspamd_config,
rule.redis,
nil,
rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, already learning another ANN', prefix)
return
end
- rspamd_redis.exec_redis_script(redis_maybe_lock_id,
+ lua_redis.exec_redis_script(redis_maybe_lock_id,
{ev_base = ev_base, is_write = true},
redis_lock_cb,
{prefix, tostring(os.time()), tostring(rule.lock_expire), rspamd_util.get_hostname()})
end
-local function maybe_train_fanns(rule, cfg, ev_base, worker)
+local function maybe_train_anns(rule, cfg, ev_base, worker)
local function members_cb(err, data)
if err then
rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', err)
elseif type(data) == 'table' then
fun.each(function(elt)
elt = tostring(elt)
- local prefix = gen_fann_prefix(rule, elt)
+ local prefix = gen_ann_prefix(rule, elt)
rspamd_logger.infox(cfg, "check ANN %s", prefix)
local redis_len_cb = function(_err, _data)
if _err then
rspamd_logger.infox(rspamd_config,
'need to learn ANN %s after %s learn vectors (%s required)',
prefix, tonumber(_data), rule.train.max_trains)
- train_fann(rule, cfg, ev_base, elt, worker)
+ train_ann(rule, cfg, ev_base, elt, worker)
else
rspamd_logger.infox(rspamd_config,
'no need to learn ANN %s %s learn vectors (%s required)',
end
end
- rspamd_redis.redis_make_request_taskless(ev_base,
+ lua_redis.redis_make_request_taskless(ev_base,
rspamd_config,
rule.redis,
nil,
end
end
- -- First we need to get all fanns stored in our Redis
- rspamd_redis.redis_make_request_taskless(ev_base,
+ -- First we need to get all anns stored in our Redis
+ lua_redis.redis_make_request_taskless(ev_base,
rspamd_config,
rule.redis,
nil,
false, -- is write
members_cb, --callback
'SMEMBERS', -- command
- {gen_fann_prefix(rule, nil)} -- arguments
+ {gen_ann_prefix(rule, nil)} -- arguments
)
return rule.watch_interval
end
-local function check_fanns(rule, _, ev_base)
+local function check_anns(rule, _, ev_base)
local function members_cb(err, data)
if err then
rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s',
rspamd_logger.errx(rspamd_config, 'cannot get FANN version %s from redis: %s',
elt, _err)
elseif _data and type(_data) == 'table' then
- load_or_invalidate_fann(rule, _data, elt, ev_base)
+ load_or_invalidate_ann(rule, _data, elt, ev_base)
else
if type(_data) ~= 'number' then
rspamd_logger.errx(rspamd_config, 'invalid ANN type returned from Redis: %s; prefix: %s',
end
local local_ver = 0
- if fanns[elt] then
- if fanns[elt].version then
- local_ver = fanns[elt].version
+ if anns[elt] then
+ if anns[elt].version then
+ local_ver = anns[elt].version
end
end
- rspamd_redis.exec_redis_script(redis_maybe_load_id,
+ lua_redis.exec_redis_script(redis_maybe_load_id,
{ev_base = ev_base, is_write = false},
redis_update_cb,
- {gen_fann_prefix(rule, elt), tostring(local_ver)})
+ {gen_ann_prefix(rule, elt), tostring(local_ver)})
end,
data)
end
end
- -- First we need to get all fanns stored in our Redis
- rspamd_redis.redis_make_request_taskless(ev_base,
+ -- First we need to get all anns stored in our Redis
+ lua_redis.redis_make_request_taskless(ev_base,
rspamd_config,
rule.redis,
nil,
false, -- is write
members_cb, --callback
'SMEMBERS', -- command
- {gen_fann_prefix(rule, nil)} -- arguments
+ {gen_ann_prefix(rule, nil)} -- arguments
)
return rule.watch_interval
local r = task:get_principal_recipient()
sid = sid .. r
end
- fann_train_callback(rule, task, scores[1], scores[2], sid)
+ ann_train_callback(rule, task, scores[1], scores[2], sid)
end
end
-redis_params = rspamd_parse_redis_server('fann_redis')
+redis_params = lua_redis.parse_redis_server('neural')
+
+if not redis_params then
+ redis_params = lua_redis.parse_redis_server('fann_redis')
+end
-- Initialization part
if not (opts and type(opts) == 'table') or not redis_params then
return
end
-if not rspamd_fann.is_enabled() then
- rspamd_logger.errx(rspamd_config, 'fann is not compiled in rspamd, this ' ..
+if not rspamd_fann.is_enabled() and not use_torch then
+ rspamd_logger.errx(rspamd_config, 'neural networks support is not compiled in rspamd, this ' ..
'module is eventually disabled')
lua_util.disable_module(N, "fail")
return
name = 'FANN_CHECK',
type = 'postfilter,nostat',
priority = 6,
- callback = fann_scores_filter
+ callback = ann_scores_filter
})
local function deepcopy(orig)
name = def_rules.symbol_spam,
score = 3.0,
description = 'Neural network SPAM',
- group = 'fann'
+ group = 'neural'
})
rspamd_config:register_symbol({
name = def_rules.symbol_spam,
name = def_rules.symbol_ham,
score = -2.0,
description = 'Neural network HAM',
- group = 'fann'
+ group = 'neural'
})
rspamd_config:register_symbol({
name = def_rules.symbol_ham,
for _,rule in pairs(settings.rules) do
load_scripts(rule.redis)
rspamd_config:add_on_load(function(cfg, ev_base, worker)
- check_fanns(rule, cfg, ev_base)
+ check_anns(rule, cfg, ev_base)
if worker:is_primary_controller() then
-- We also want to train neural nets when they have enough data
rspamd_config:add_periodic(ev_base, 0.0,
function(_, _)
- return maybe_train_fanns(rule, cfg, ev_base, worker)
+ return maybe_train_anns(rule, cfg, ev_base, worker)
end)
end
end)