aboutsummaryrefslogtreecommitdiffstats
path: root/src/plugins
diff options
context:
space:
mode:
author Vsevolod Stakhov <vsevolod@highsecure.ru> 2017-07-30 09:56:52 +0100
committer Vsevolod Stakhov <vsevolod@highsecure.ru> 2017-07-30 09:56:52 +0100
commit 7080268d601feea9d127b7a22362eea505d93967 (patch)
tree 0996ca902c2c8dd5ea9f91fac922c36db8d2328a /src/plugins
parent e9261b7c8ebcea93e0b8bbe6fc847d1168a25a4d (diff)
download rspamd-7080268d601feea9d127b7a22362eea505d93967.tar.gz
rspamd-7080268d601feea9d127b7a22362eea505d93967.zip
[Feature] Allow multiple fann rules
Diffstat (limited to 'src/plugins')
-rw-r--r--src/plugins/lua/fann_redis.lua348
1 files changed, 186 insertions, 162 deletions
diff --git a/src/plugins/lua/fann_redis.lua b/src/plugins/lua/fann_redis.lua
index 378000a24..f2b6c44f6 100644
--- a/src/plugins/lua/fann_redis.lua
+++ b/src/plugins/lua/fann_redis.lua
@@ -24,17 +24,36 @@ end
local rspamd_logger = require "rspamd_logger"
local rspamd_fann = require "rspamd_fann"
local rspamd_util = require "rspamd_util"
-local fann_symbol_spam = 'FANNR_SPAM'
-local fann_symbol_ham = 'FANNR_HAM'
local rspamd_redis = require "lua_redis"
local fun = require "fun"
local meta_functions = require "meta_functions"
+
-- Module vars
+local default_options = { -- per-rule defaults; each rule's own options override these key by key
+ train = {
+ max_trains = 1000,
+ max_epoch = 1000,
+ max_usages = 10,
+ mse = 0.001,
+ autotrain = true,
+ },
+ use_settings = false, -- read as rule.use_settings, so it has to live at the top level of a rule
+ watch_interval = 60.0, -- read as rule.watch_interval and returned from the periodic callbacks
+ nlayers = 4,
+ lock_expire = 600,
+ learning_spawned = false,
+ ann_expire = 60 * 60 * 24 * 2, -- 2 days
+ symbol_spam = 'FANNR_SPAM',
+ symbol_ham = 'FANNR_HAM',
+}
+
+local settings = {
+ rules = {
+ }
+}
+
-- ANNs indexed by settings id
local fanns = {
- ['0'] = {
- version = 0,
- }
}
local opts = rspamd_config:get_all_opt("fann_redis")
@@ -162,19 +181,6 @@ local redis_lua_script_save_unlock = [[
local redis_save_unlock_sha = nil
local redis_params
-redis_params = rspamd_parse_redis_server('fann_redis')
-
-local fann_prefix = 'RFANN'
-local max_trains = 1000
-local max_epoch = 1000
-local max_usages = 10
-local use_settings = false
-local watch_interval = 60.0
-local mse = 0.0001
-local nlayers = 4
-local lock_expire = 600
-local learning_spawned = false
-local ann_expire = 60 * 60 * 24 * 2 -- 2 days
local function load_scripts(cfg, ev_base, on_load_cb)
local function can_train_sha_cb(err, data)
@@ -287,15 +293,16 @@ local function load_scripts(cfg, ev_base, on_load_cb)
)
end
-local function gen_fann_prefix(rule, id)
+local function gen_fann_prefix(rule, id) -- builds the key names for one rule; returns two values
if id then
- return fann_prefix .. rspamd_config:get_symbols_cksum():hex() .. id,id
+ return rule.prefix .. rspamd_config:get_symbols_cksum():hex() .. id, -- 1st: rule prefix .. symbols checksum (hex) .. id
+ rule.prefix .. id -- 2nd: rule prefix .. id
else
- return fann_prefix .. rspamd_config:get_symbols_cksum():hex(), nil
+ return rule.prefix .. rspamd_config:get_symbols_cksum():hex(), nil -- no id: rule-wide key only, 2nd value is nil
end
end
-local function is_fann_valid(prefix, ann)
+local function is_fann_valid(rule, prefix, ann)
if ann then
local n = rspamd_config:get_symbols_count() + rspamd_count_metatokens()
@@ -306,7 +313,7 @@ local function is_fann_valid(prefix, ann)
end
local layers = ann:get_layers()
- if not layers or #layers ~= nlayers then
+ if not layers or #layers ~= rule.nlayers then
rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of layers: %s',
prefix, #layers)
return false
@@ -317,65 +324,69 @@ local function is_fann_valid(prefix, ann)
end
local function fann_scores_filter(task)
- local id = '0'
- if use_settings then
- local sid = task:get_settings_id()
- if sid then
- id = tostring(sid)
- end
- end
+ for _,rule in pairs(settings.rules) do -- rules is a name-keyed map: a generic for needs pairs()
+ local id = rule.prefix .. '0'
+ if rule.use_settings then
+ local sid = task:get_settings_id()
+ if sid then
+ id = rule.prefix .. tostring(sid)
+ end
+ end
- if fanns[id].fann then
- local fann_data = task:get_symbols_tokens()
- local mt = meta_functions.rspamd_gen_metatokens(task)
- -- Add filtered meta tokens
- fun.each(function(e) table.insert(fann_data, e) end, mt)
-
- local out = fanns[id].fann:test(fann_data)
- local symscore = string.format('%.3f', out[1])
- rspamd_logger.infox(task, 'fann score: %s', symscore)
-
- if out[1] > 0 then
- local result = rspamd_util.normalize_prob(out[1] / 2.0, 0)
- task:insert_result(fann_symbol_spam, result, symscore, id)
- else
- local result = rspamd_util.normalize_prob((-out[1]) / 2.0, 0)
- task:insert_result(fann_symbol_ham, result, symscore, id)
+ if fanns[id] and fanns[id].fann then -- fanns starts empty: skip ids that have no ANN entry yet
+ local fann_data = task:get_symbols_tokens()
+ local mt = meta_functions.rspamd_gen_metatokens(task)
+ -- Add filtered meta tokens
+ fun.each(function(e) table.insert(fann_data, e) end, mt)
+
+ local out = fanns[id].fann:test(fann_data)
+ local symscore = string.format('%.3f', out[1])
+ rspamd_logger.infox(task, 'fann score: %s', symscore)
+
+ if out[1] > 0 then
+ local result = rspamd_util.normalize_prob(out[1] / 2.0, 0)
+ task:insert_result(rule.symbol_spam, result, symscore, id)
+ else
+ local result = rspamd_util.normalize_prob((-out[1]) / 2.0, 0)
+ task:insert_result(rule.symbol_ham, result, symscore, id)
+ end
end
end
end
-local function create_train_fann(n, id)
- id = tostring(id)
- local prefix = gen_fann_prefix(id)
+local function create_train_fann(rule, n, id) -- (re)creates fanns[id].fann_train with n inputs for this rule
+ id = tostring(id) -- the caller indexes fanns by this id unchanged, so it must not be prefixed again here
+ local prefix = gen_fann_prefix(rule, id)
if not fanns[id] then
fanns[id] = {}
end
-
+ -- Fix that for flexible layers number
if fanns[id].fann then
- if n ~= fanns[id].fann:get_inputs() or
+ if n ~= fanns[id].fann:get_inputs() or --
(fanns[id].fann_train and n ~= fanns[id].fann_train:get_inputs()) then
- rspamd_logger.infox(rspamd_config, 'recreate ANN %s as it has a wrong number of inputs, version %s', prefix,
+ rspamd_logger.infox(rspamd_config,
+ 'recreate ANN %s as it has a wrong number of inputs, version %s',
+ prefix,
fanns[id].version)
- fanns[id].fann_train = rspamd_fann.create(nlayers, n, n / 2, n / 4, 1)
+ fanns[id].fann_train = rspamd_fann.create(rule.nlayers, n, n / 2, n / 4, 1)
fanns[id].fann = nil
- elseif fanns[id].version % max_usages == 0 then
+ elseif fanns[id].version % (rule.max_usages or rule.train.max_usages) == 0 then -- default lives under train
-- Forget last fann
rspamd_logger.infox(rspamd_config, 'recreate ANN %s, version %s', prefix,
fanns[id].version)
- fanns[id].fann_train = rspamd_fann.create(nlayers, n, n / 2, n / 4, 1)
+ fanns[id].fann_train = rspamd_fann.create(rule.nlayers, n, n / 2, n / 4, 1)
else
fanns[id].fann_train = fanns[id].fann
end
else
- fanns[id].fann_train = rspamd_fann.create(nlayers, n, n / 2, n / 4, 1)
+ fanns[id].fann_train = rspamd_fann.create(rule.nlayers, n, n / 2, n / 4, 1)
fanns[id].version = 0
end
end
-local function load_or_invalidate_fann(data, id, ev_base)
+local function load_or_invalidate_fann(rule, data, id, ev_base)
local ver = data[2]
- local prefix = gen_fann_prefix(id)
+ local prefix = gen_fann_prefix(rule, id)
if not ver or not tonumber(ver) then
rspamd_logger.errx(rspamd_config, 'cannot get version for ANN: %s', prefix)
@@ -392,7 +403,7 @@ local function load_or_invalidate_fann(data, id, ev_base)
ann = rspamd_fann.load_data(ann_data)
end
- if is_fann_valid(prefix, ann) then
+ if is_fann_valid(rule, prefix, ann) then
fanns[id].fann = ann
rspamd_logger.infox(rspamd_config, 'loaded ANN %s version %s from redis',
prefix, ver)
@@ -413,7 +424,7 @@ local function load_or_invalidate_fann(data, id, ev_base)
rspamd_logger.infox(rspamd_config, 'invalidate ANN %s', prefix)
rspamd_redis.redis_make_request_taskless(ev_base,
rspamd_config,
- redis_params,
+ rule.redis,
nil,
true, -- is write
redis_invalidate_cb, --callback
@@ -423,9 +434,9 @@ local function load_or_invalidate_fann(data, id, ev_base)
end
end
-local function fann_train_callback(task, score, required_score, id)
- local train_opts = opts['train']
- local fname,suffix = gen_fann_prefix(id)
+local function fann_train_callback(rule, task, score, required_score, id)
+ local train_opts = rule['train']
+ local fname,suffix = gen_fann_prefix(rule, id)
local learn_spam, learn_ham
@@ -459,7 +470,7 @@ local function fann_train_callback(task, score, required_score, id)
local str = rspamd_util.zstd_compress(table.concat(fann_data, ';'))
rspamd_redis.redis_make_request(task,
- redis_params,
+ rule.redis,
nil,
true, -- is write
learn_vec_cb, --callback
@@ -477,22 +488,22 @@ local function fann_train_callback(task, score, required_score, id)
end
rspamd_redis.rspamd_redis_make_request(task,
- redis_params,
+ rule.redis,
nil,
true, -- is write
can_train_cb, --callback
'EVALSHA', -- command
- {redis_can_train_sha, '4', gen_fann_prefix(nil),
+ {redis_can_train_sha, '4', gen_fann_prefix(rule, nil),
suffix, k, tostring(max_trains)} -- arguments
)
end
end
-local function train_fann(_, ev_base, elt)
+local function train_fann(rule, _, ev_base, elt)
local spam_elts = {}
local ham_elts = {}
elt = tostring(elt)
- local prefix = gen_fann_prefix(elt)
+ local prefix = gen_fann_prefix(rule, elt)
local function redis_unlock_cb(err)
if err then
@@ -507,7 +518,7 @@ local function train_fann(_, ev_base, elt)
prefix, err)
rspamd_redis.redis_make_request_taskless(ev_base,
rspamd_config,
- redis_params,
+ rule.redis,
nil,
false, -- is write
redis_unlock_cb, --callback
@@ -521,13 +532,13 @@ local function train_fann(_, ev_base, elt)
end
local function ann_trained(errcode, errmsg, train_mse)
- learning_spawned = false
+ rule.learning_spawned = false
if errcode ~= 0 then
rspamd_logger.errx(rspamd_config, 'cannot train ANN %s: %s',
prefix, errmsg)
rspamd_redis.redis_make_request_taskless(ev_base,
rspamd_config,
- redis_params,
+ rule.redis,
nil,
true, -- is write
redis_unlock_cb, --callback
@@ -543,7 +554,7 @@ local function train_fann(_, ev_base, elt)
fanns[elt].fann_train = nil
rspamd_redis.redis_make_request_taskless(ev_base,
rspamd_config,
- redis_params,
+ rule.redis,
nil,
true, -- is write
redis_save_cb, --callback
@@ -559,7 +570,7 @@ local function train_fann(_, ev_base, elt)
prefix, err)
rspamd_redis.redis_make_request_taskless(ev_base,
rspamd_config,
- redis_params,
+ rule.redis,
nil,
true, -- is write
redis_unlock_cb, --callback
@@ -593,7 +604,7 @@ local function train_fann(_, ev_base, elt)
if not fanns[elt] or not fanns[elt].fann_train
or n ~= fanns[elt].fann_train:get_inputs() then
-- Create fann if it does not exist
- create_train_fann(n, elt)
+ create_train_fann(rule, n, elt)
end
if #inputs < max_trains / 2 then
@@ -610,7 +621,7 @@ local function train_fann(_, ev_base, elt)
rspamd_logger.infox(rspamd_config, 'invalidate ANN %s: training data is invalid', prefix)
rspamd_redis.redis_make_request_taskless(ev_base,
rspamd_config,
- redis_params,
+ rule.redis,
nil,
true, -- is write
redis_invalidate_cb, --callback
@@ -618,10 +629,13 @@ local function train_fann(_, ev_base, elt)
{redis_locked_invalidate_sha, 1, prefix}
)
else
- learning_spawned = true
+ rule.learning_spawned = true
rspamd_logger.infox(rspamd_config, 'start learning ANN %s', prefix)
- fanns[elt].fann_train:train_threaded(inputs, outputs, ann_trained, ev_base,
- {max_epochs = max_epoch, desired_mse = mse})
+ fanns[elt].fann_train:train_threaded(inputs, outputs, ann_trained,
+ ev_base, {
+ max_epochs = rule.train.max_epoch,
+ desired_mse = rule.train.mse
+ })
end
end
end
@@ -632,7 +646,7 @@ local function train_fann(_, ev_base, elt)
prefix, err)
rspamd_redis.redis_make_request_taskless(ev_base,
rspamd_config,
- redis_params,
+ rule.redis,
nil,
true, -- is write
redis_unlock_cb, --callback
@@ -647,7 +661,7 @@ local function train_fann(_, ev_base, elt)
end, data))
rspamd_redis.redis_make_request_taskless(ev_base,
rspamd_config,
- redis_params,
+ rule.redis,
nil,
false, -- is write
redis_ham_cb, --callback
@@ -668,7 +682,7 @@ local function train_fann(_, ev_base, elt)
-- Can train ANN
rspamd_redis.redis_make_request_taskless(ev_base,
rspamd_config,
- redis_params,
+ rule.redis,
nil,
false, -- is write
redis_spam_cb, --callback
@@ -687,10 +701,10 @@ local function train_fann(_, ev_base, elt)
prefix)
end
end
- if learning_spawned then
+ if rule.learning_spawned then
rspamd_redis.redis_make_request_taskless(ev_base,
rspamd_config,
- redis_params,
+ rule.redis,
nil,
true, -- is write
redis_lock_extend_cb, --callback
@@ -709,45 +723,47 @@ local function train_fann(_, ev_base, elt)
rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, locked by another process', prefix)
end
end
- if learning_spawned then
+ if rule.learning_spawned then
rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, already learning another ANN', prefix)
return
end
rspamd_redis.redis_make_request_taskless(ev_base,
rspamd_config,
- redis_params,
+ rule.redis,
nil,
true, -- is write
redis_lock_cb, --callback
'EVALSHA', -- command
{redis_maybe_lock_sha, '4', prefix, tostring(os.time()),
- tostring(lock_expire), rspamd_util.get_hostname()}
+ tostring(rule.lock_expire), rspamd_util.get_hostname()}
)
end
-local function maybe_train_fanns(cfg, ev_base)
+local function maybe_train_fanns(rule, cfg, ev_base)
local function members_cb(err, data)
if err then
rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', err)
elseif type(data) == 'table' then
fun.each(function(elt)
elt = tostring(elt)
- local prefix = gen_fann_prefix(elt)
+ local prefix = gen_fann_prefix(rule, elt)
local redis_len_cb = function(_err, _data)
if _err then
- rspamd_logger.errx(rspamd_config, 'cannot get FANN trains %s from redis: %s', prefix, _err)
+ rspamd_logger.errx(rspamd_config,
+ 'cannot get FANN trains %s from redis: %s', prefix, _err)
elseif _data and type(_data) == 'number' or type(_data) == 'string' then
if tonumber(_data) and tonumber(_data) >= max_trains then
- rspamd_logger.infox(rspamd_config, 'need to learn ANN %s after %s learn vectors (%s required)',
+ rspamd_logger.infox(rspamd_config,
+ 'need to learn ANN %s after %s learn vectors (%s required)',
prefix, tonumber(_data), max_trains)
- train_fann(cfg, ev_base, elt)
+ train_fann(rule, cfg, ev_base, elt)
end
end
end
rspamd_redis.redis_make_request_taskless(ev_base,
rspamd_config,
- redis_params,
+ rule.redis,
nil,
false, -- is write
redis_len_cb, --callback
@@ -766,18 +782,18 @@ local function maybe_train_fanns(cfg, ev_base)
-- First we need to get all fanns stored in our Redis
rspamd_redis.redis_make_request_taskless(ev_base,
rspamd_config,
- redis_params,
+ rule.redis,
nil,
false, -- is write
members_cb, --callback
'SMEMBERS', -- command
- {gen_fann_prefix(nil)} -- arguments
+ {gen_fann_prefix(rule, nil)} -- arguments
)
- return watch_interval
+ return rule.watch_interval
end
-local function check_fanns(_, ev_base)
+local function check_fanns(rule, _, ev_base)
local function members_cb(err, data)
if err then
rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', err)
@@ -791,7 +807,7 @@ local function check_fanns(_, ev_base)
load_scripts(rspamd_config, ev_base, nil)
end
elseif _data and type(_data) == 'table' then
- load_or_invalidate_fann(_data, elt, ev_base)
+ load_or_invalidate_fann(rule, _data, elt, ev_base)
end
end
@@ -803,12 +819,12 @@ local function check_fanns(_, ev_base)
end
rspamd_redis.redis_make_request_taskless(ev_base,
rspamd_config,
- redis_params,
+ rule.redis,
nil,
false, -- is write
redis_update_cb, --callback
'EVALSHA', -- command
- {redis_maybe_load_sha, 2, gen_fann_prefix(elt), tostring(local_ver)}
+ {redis_maybe_load_sha, 2, gen_fann_prefix(rule, elt), tostring(local_ver)}
)
end,
data)
@@ -822,27 +838,32 @@ local function check_fanns(_, ev_base)
-- First we need to get all fanns stored in our Redis
rspamd_redis.redis_make_request_taskless(ev_base,
rspamd_config,
- redis_params,
+ rule.redis,
nil,
false, -- is write
members_cb, --callback
'SMEMBERS', -- command
- {gen_fann_prefix(nil)} -- arguments
+ {gen_fann_prefix(rule, nil)} -- arguments
)
- return watch_interval
+ return rule.watch_interval
end
local function ann_push_vector(task)
local scores = task:get_metric_score()
local sid = task:get_settings_id()
- if use_settings then
- fann_train_callback(task, scores[1], scores[2], tostring(sid))
- else
- fann_train_callback(task, scores[1], scores[2], "1")
+
+ for _,rule in pairs(settings.rules) do -- rules is a name-keyed map: a generic for needs pairs()
+ if rule.use_settings then
+ fann_train_callback(rule, task, scores[1], scores[2], tostring(sid))
+ else
+ fann_train_callback(rule, task, scores[1], scores[2], "0")
+ end
end
end
+redis_params = rspamd_parse_redis_server('fann_redis')
+
-- Initialization part
if not (opts and type(opts) == 'table') or not redis_params then
rspamd_logger.infox(rspamd_config, 'Module is unconfigured')
@@ -854,72 +875,75 @@ if not rspamd_fann.is_enabled() then
'module is eventually disabled')
return
else
- use_settings = opts['use_settings']
- if opts['spam_symbol'] then
- fann_symbol_spam = opts['spam_symbol']
- end
- if opts['ham_symbol'] then
- fann_symbol_ham = opts['ham_symbol']
- end
- if opts['prefix'] then
- fann_prefix = opts['prefix']
- end
- if opts['lock_expire'] then
- lock_expire = tonumber(opts['lock_expire'])
+ local rules = opts['rules']
+
+ if not rules then
+ -- Use legacy configuration
+ rules = {}
+ rules['RFANN'] = opts
end
- rspamd_config:set_metric_symbol({
- name = fann_symbol_spam,
- score = 3.0,
- description = 'Neural network SPAM',
- group = 'fann'
- })
+
local id = rspamd_config:register_symbol({
- name = fann_symbol_spam,
+ name = 'FANN_CHECK',
type = 'postfilter,nostat',
priority = 6,
callback = fann_scores_filter
})
- rspamd_config:set_metric_symbol({
- name = fann_symbol_ham,
- score = -2.0,
- description = 'Neural network HAM',
- group = 'fann'
- })
- rspamd_config:register_symbol({
- name = fann_symbol_ham,
- type = 'virtual,nostat',
- parent = id
- })
- if opts['train'] then
- if opts['train']['max_train'] then
- max_trains = opts['train']['max_train']
- end
- if opts['train']['max_epoch'] then
- max_epoch = opts['train']['max_epoch']
- end
- if opts['train']['max_usages'] then
- max_usages = opts['train']['max_usages']
+
+ for k,r in pairs(rules) do -- rules is keyed by rule name, so pairs() is required
+ local cur = { redis = redis_params } -- fresh table per rule: the defaults table must not be shared
+ for dk,dv in pairs(default_options) do cur[dk] = dv end -- shallow copy of the defaults
+ rules[k] = cur -- replacing the value of an existing key is safe during traversal
+ -- Override defaults
+ for sk,v in pairs(r) do
+ cur[sk] = v
end
- if opts['train']['mse'] then
- mse = opts['train']['mse']
+ if not cur.prefix then
+ cur.prefix = k
end
+ rspamd_config:set_metric_symbol({
+ name = cur.symbol_spam,
+ score = 3.0,
+ description = 'Neural network SPAM',
+ group = 'fann'
+ })
+
+ rspamd_config:set_metric_symbol({
+ name = cur.symbol_ham,
+ score = -2.0,
+ description = 'Neural network HAM',
+ group = 'fann'
+ })
rspamd_config:register_symbol({
- name = 'FANN_VECTOR_PUSH',
- type = 'postfilter,nostat',
- priority = 5,
- callback = ann_push_vector
+ name = cur.symbol_ham,
+ type = 'virtual,nostat',
+ parent = id
})
end
+
+ rspamd_config:register_symbol({
+ name = 'FANN_VECTOR_PUSH',
+ type = 'postfilter,nostat',
+ priority = 5,
+ callback = ann_push_vector
+ })
+
+ settings.rules = rules
+
-- Add training scripts
- rspamd_config:add_on_load(function(cfg, ev_base, worker)
- load_scripts(cfg, ev_base, check_fanns)
-
- if worker:get_name() == 'normal' then
- -- We also want to train neural nets when they have enough data
- rspamd_config:add_periodic(ev_base, 0.0,
- function(_cfg, _ev_base)
- return maybe_train_fanns(_cfg, _ev_base)
- end)
- end
- end)
+ for _,rule in pairs(settings.rules) do -- one on_load hook per rule; pairs() is required for a map
+ rspamd_config:add_on_load(function(cfg, ev_base, worker)
+ load_scripts(cfg, ev_base, function(_cfg, _ev_base) -- renamed to avoid shadowing the outer cfg/ev_base
+ check_fanns(rule, _cfg, _ev_base)
+ end)
+
+ if worker:get_name() == 'normal' then
+ -- We also want to train neural nets when they have enough data
+ rspamd_config:add_periodic(ev_base, 0.0,
+ function(_cfg, _ev_base)
+ return maybe_train_fanns(rule, _cfg, _ev_base)
+ end)
+ end
+ end)
+ end
end