diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2017-07-30 09:56:52 +0100 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2017-07-30 09:56:52 +0100 |
commit | 7080268d601feea9d127b7a22362eea505d93967 (patch) | |
tree | 0996ca902c2c8dd5ea9f91fac922c36db8d2328a /src/plugins | |
parent | e9261b7c8ebcea93e0b8bbe6fc847d1168a25a4d (diff) | |
download | rspamd-7080268d601feea9d127b7a22362eea505d93967.tar.gz rspamd-7080268d601feea9d127b7a22362eea505d93967.zip |
[Feature] Allow multiple fann rules
Diffstat (limited to 'src/plugins')
-rw-r--r-- | src/plugins/lua/fann_redis.lua | 348 |
1 files changed, 186 insertions, 162 deletions
diff --git a/src/plugins/lua/fann_redis.lua b/src/plugins/lua/fann_redis.lua index 378000a24..f2b6c44f6 100644 --- a/src/plugins/lua/fann_redis.lua +++ b/src/plugins/lua/fann_redis.lua @@ -24,17 +24,36 @@ end local rspamd_logger = require "rspamd_logger" local rspamd_fann = require "rspamd_fann" local rspamd_util = require "rspamd_util" -local fann_symbol_spam = 'FANNR_SPAM' -local fann_symbol_ham = 'FANNR_HAM' local rspamd_redis = require "lua_redis" local fun = require "fun" local meta_functions = require "meta_functions" + -- Module vars +local default_options = { + train = { + max_trains = 1000, + max_epoch = 1000, + max_usages = 10, + use_settings = false, + watch_interval = 60.0, + mse = 0.001, + autotrain = true, + }, + nlayers = 4, + lock_expire = 600, + learning_spawned = false, + ann_expire = 60 * 60 * 24 * 2, -- 2 days + symbol_spam = 'FANNR_SPAM', + symbol_ham = 'FANNR_HAM', +} + +local settings = { + rules = { + } +} + -- ANNs indexed by settings id local fanns = { - ['0'] = { - version = 0, - } } local opts = rspamd_config:get_all_opt("fann_redis") @@ -162,19 +181,6 @@ local redis_lua_script_save_unlock = [[ local redis_save_unlock_sha = nil local redis_params -redis_params = rspamd_parse_redis_server('fann_redis') - -local fann_prefix = 'RFANN' -local max_trains = 1000 -local max_epoch = 1000 -local max_usages = 10 -local use_settings = false -local watch_interval = 60.0 -local mse = 0.0001 -local nlayers = 4 -local lock_expire = 600 -local learning_spawned = false -local ann_expire = 60 * 60 * 24 * 2 -- 2 days local function load_scripts(cfg, ev_base, on_load_cb) local function can_train_sha_cb(err, data) @@ -287,15 +293,16 @@ local function load_scripts(cfg, ev_base, on_load_cb) ) end -local function gen_fann_prefix(id) +local function gen_fann_prefix(rule, id) if id then - return fann_prefix .. rspamd_config:get_symbols_cksum():hex() .. id,id + return rule.prefix .. rspamd_config:get_symbols_cksum():hex() .. id, + rule.prefix .. id else - return fann_prefix .. rspamd_config:get_symbols_cksum():hex(), nil + return rule.prefix .. rspamd_config:get_symbols_cksum():hex(), nil end end -local function is_fann_valid(prefix, ann) +local function is_fann_valid(rule, prefix, ann) if ann then local n = rspamd_config:get_symbols_count() + rspamd_count_metatokens() @@ -306,7 +313,7 @@ local function is_fann_valid(prefix, ann) end local layers = ann:get_layers() - if not layers or #layers ~= nlayers then + if not layers or #layers ~= rule.nlayers then rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of layers: %s', prefix, #layers) return false @@ -317,65 +324,69 @@ local function is_fann_valid(prefix, ann) end local function fann_scores_filter(task) - local id = '0' - if use_settings then - local sid = task:get_settings_id() - if sid then - id = tostring(sid) - end - end + for _,rule in settings.rules do + local id = rule.prefix .. '0' + if rule.use_settings then + local sid = task:get_settings_id() + if sid then + id = rule.prefix .. tostring(sid) + end + end - if fanns[id].fann then - local fann_data = task:get_symbols_tokens() - local mt = meta_functions.rspamd_gen_metatokens(task) - -- Add filtered meta tokens - fun.each(function(e) table.insert(fann_data, e) end, mt) - - local out = fanns[id].fann:test(fann_data) - local symscore = string.format('%.3f', out[1]) - rspamd_logger.infox(task, 'fann score: %s', symscore) - - if out[1] > 0 then - local result = rspamd_util.normalize_prob(out[1] / 2.0, 0) - task:insert_result(fann_symbol_spam, result, symscore, id) - else - local result = rspamd_util.normalize_prob((-out[1]) / 2.0, 0) - task:insert_result(fann_symbol_ham, result, symscore, id) + if fanns[id].fann then + local fann_data = task:get_symbols_tokens() + local mt = meta_functions.rspamd_gen_metatokens(task) + -- Add filtered meta tokens + fun.each(function(e) table.insert(fann_data, e) end, mt) + + local out = fanns[id].fann:test(fann_data) + local symscore = string.format('%.3f', out[1]) + rspamd_logger.infox(task, 'fann score: %s', symscore) + + if out[1] > 0 then + local result = rspamd_util.normalize_prob(out[1] / 2.0, 0) + task:insert_result(rule.symbol_spam, result, symscore, id) + else + local result = rspamd_util.normalize_prob((-out[1]) / 2.0, 0) + task:insert_result(rule.symbol_ham, result, symscore, id) + end end end end -local function create_train_fann(n, id) - id = tostring(id) - local prefix = gen_fann_prefix(id) +local function create_train_fann(rule, n, id) + id = rule.prefix .. tostring(id) + local prefix = gen_fann_prefix(rule, id) if not fanns[id] then fanns[id] = {} end - + -- Fix that for flexibe layers number if fanns[id].fann then - if n ~= fanns[id].fann:get_inputs() or + if n ~= fanns[id].fann:get_inputs() or -- (fanns[id].fann_train and n ~= fanns[id].fann_train:get_inputs()) then - rspamd_logger.infox(rspamd_config, 'recreate ANN %s as it has a wrong number of inputs, version %s', prefix, + rspamd_logger.infox(rspamd_config, + 'recreate ANN %s as it has a wrong number of inputs, version %s', + prefix, fanns[id].version) - fanns[id].fann_train = rspamd_fann.create(nlayers, n, n / 2, n / 4, 1) + fanns[id].fann_train = rspamd_fann.create(rule.nlayers, n, n / 2, n / 4, 1) fanns[id].fann = nil - elseif fanns[id].version % max_usages == 0 then + elseif fanns[id].version % rule.max_usages == 0 then -- Forget last fann rspamd_logger.infox(rspamd_config, 'recreate ANN %s, version %s', prefix, fanns[id].version) - fanns[id].fann_train = rspamd_fann.create(nlayers, n, n / 2, n / 4, 1) + fanns[id].fann_train = rspamd_fann.create(rule.nlayers, n, n / 2, n / 4, 1) else fanns[id].fann_train = fanns[id].fann end else - fanns[id].fann_train = rspamd_fann.create(nlayers, n, n / 2, n / 4, 1) + fanns[id].fann_train = rspamd_fann.create(rule.nlayers, n, n / 2, n / 4, 1) fanns[id].version = 0 end end -local function load_or_invalidate_fann(data, id, ev_base) +local function load_or_invalidate_fann(rule, data, id, ev_base) local ver = data[2] - local prefix = gen_fann_prefix(id) + local prefix = gen_fann_prefix(rule, id) if not ver or not tonumber(ver) then rspamd_logger.errx(rspamd_config, 'cannot get version for ANN: %s', prefix) @@ -392,7 +403,7 @@ local function load_or_invalidate_fann(data, id, ev_base) ann = rspamd_fann.load_data(ann_data) end - if is_fann_valid(prefix, ann) then + if is_fann_valid(rule, prefix, ann) then fanns[id].fann = ann rspamd_logger.infox(rspamd_config, 'loaded ANN %s version %s from redis', prefix, ver) @@ -413,7 +424,7 @@ local function load_or_invalidate_fann(data, id, ev_base) rspamd_logger.infox(rspamd_config, 'invalidate ANN %s', prefix) rspamd_redis.redis_make_request_taskless(ev_base, rspamd_config, - redis_params, + rule.redis, nil, true, -- is write redis_invalidate_cb, --callback @@ -423,9 +434,9 @@ local function load_or_invalidate_fann(data, id, ev_base) end end -local function fann_train_callback(task, score, required_score, id) - local train_opts = opts['train'] - local fname,suffix = gen_fann_prefix(id) +local function fann_train_callback(rule, task, score, required_score, id) + local train_opts = rule['train'] + local fname,suffix = gen_fann_prefix(rule, id) local learn_spam, learn_ham @@ -459,7 +470,7 @@ local function fann_train_callback(task, score, required_score, id) local str = rspamd_util.zstd_compress(table.concat(fann_data, ';')) rspamd_redis.redis_make_request(task, - redis_params, + rule.redis, nil, true, -- is write learn_vec_cb, --callback @@ -477,22 +488,22 @@ local function fann_train_callback(task, score, required_score, id) end rspamd_redis.rspamd_redis_make_request(task, - redis_params, + rule.redis, nil, true, -- is write can_train_cb, --callback 'EVALSHA', -- command - {redis_can_train_sha, '4', gen_fann_prefix(nil), + {redis_can_train_sha, '4', gen_fann_prefix(rule, nil), suffix, k, tostring(max_trains)} -- arguments ) end end -local function train_fann(_, ev_base, elt) +local function train_fann(rule, _, ev_base, elt) local spam_elts = {} local ham_elts = {} elt = tostring(elt) - local prefix = gen_fann_prefix(elt) + local prefix = gen_fann_prefix(rule, elt) local function redis_unlock_cb(err) if err then @@ -507,7 +518,7 @@ local function train_fann(_, ev_base, elt) prefix, err) rspamd_redis.redis_make_request_taskless(ev_base, rspamd_config, - redis_params, + rule.redis, nil, false, -- is write redis_unlock_cb, --callback @@ -521,13 +532,13 @@ local function train_fann(_, ev_base, elt) end local function ann_trained(errcode, errmsg, train_mse) - learning_spawned = false + rule.learning_spawned = false if errcode ~= 0 then rspamd_logger.errx(rspamd_config, 'cannot train ANN %s: %s', prefix, errmsg) rspamd_redis.redis_make_request_taskless(ev_base, rspamd_config, - redis_params, + rule.redis, nil, true, -- is write redis_unlock_cb, --callback @@ -543,7 +554,7 @@ local function train_fann(_, ev_base, elt) fanns[elt].fann_train = nil rspamd_redis.redis_make_request_taskless(ev_base, rspamd_config, - redis_params, + rule.redis, nil, true, -- is write redis_save_cb, --callback @@ -559,7 +570,7 @@ local function train_fann(_, ev_base, elt) prefix, err) rspamd_redis.redis_make_request_taskless(ev_base, rspamd_config, - redis_params, + rule.redis, nil, true, -- is write redis_unlock_cb, --callback @@ -593,7 +604,7 @@ local function train_fann(_, ev_base, elt) if not fanns[elt] or not fanns[elt].fann_train or n ~= fanns[elt].fann_train:get_inputs() then -- Create fann if it does not exist - create_train_fann(n, elt) + create_train_fann(rule, n, elt) end if #inputs < max_trains / 2 then @@ -610,7 +621,7 @@ local function train_fann(_, ev_base, elt) rspamd_logger.infox(rspamd_config, 'invalidate ANN %s: training data is invalid', prefix) rspamd_redis.redis_make_request_taskless(ev_base, rspamd_config, - redis_params, + rule.redis, nil, true, -- is write redis_invalidate_cb, --callback @@ -618,10 +629,13 @@ local function train_fann(_, ev_base, elt) {redis_locked_invalidate_sha, 1, prefix} ) else - learning_spawned = true + rule.learning_spawned = true rspamd_logger.infox(rspamd_config, 'start learning ANN %s', prefix) - fanns[elt].fann_train:train_threaded(inputs, outputs, ann_trained, ev_base, - {max_epochs = max_epoch, desired_mse = mse}) + fanns[elt].fann_train:train_threaded(inputs, outputs, ann_trained, + ev_base, { + max_epochs = rule.train.max_epoch, + desired_mse = rule.train.mse + }) end end end @@ -632,7 +646,7 @@ local function train_fann(_, ev_base, elt) prefix, err) rspamd_redis.redis_make_request_taskless(ev_base, rspamd_config, - redis_params, + rule.redis, nil, true, -- is write redis_unlock_cb, --callback @@ -647,7 +661,7 @@ local function train_fann(_, ev_base, elt) end, data)) rspamd_redis.redis_make_request_taskless(ev_base, rspamd_config, - redis_params, + rule.redis, nil, false, -- is write redis_ham_cb, --callback @@ -668,7 +682,7 @@ local function train_fann(_, ev_base, elt) -- Can train ANN rspamd_redis.redis_make_request_taskless(ev_base, rspamd_config, - redis_params, + rule.redis, nil, false, -- is write redis_spam_cb, --callback @@ -687,10 +701,10 @@ local function train_fann(_, ev_base, elt) prefix) end end - if learning_spawned then + if rule.learning_spawned then rspamd_redis.redis_make_request_taskless(ev_base, rspamd_config, - redis_params, + rule.redis, nil, true, -- is write redis_lock_extend_cb, --callback @@ -709,45 +723,47 @@ local function train_fann(_, ev_base, elt) rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, locked by another process', prefix) end end - if learning_spawned then + if rule.learning_spawned then rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, already learning another ANN', prefix) return end rspamd_redis.redis_make_request_taskless(ev_base, rspamd_config, - redis_params, + rule.redis, nil, true, -- is write redis_lock_cb, --callback 'EVALSHA', -- command {redis_maybe_lock_sha, '4', prefix, tostring(os.time()), - tostring(lock_expire), rspamd_util.get_hostname()} + tostring(rule.lock_expire), rspamd_util.get_hostname()} ) end -local function maybe_train_fanns(cfg, ev_base) +local function maybe_train_fanns(rule, cfg, ev_base) local function members_cb(err, data) if err then rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', err) elseif type(data) == 'table' then fun.each(function(elt) elt = tostring(elt) - local prefix = gen_fann_prefix(elt) + local prefix = gen_fann_prefix(rule, elt) local redis_len_cb = function(_err, _data) if _err then - rspamd_logger.errx(rspamd_config, 'cannot get FANN trains %s from redis: %s', prefix, _err) + rspamd_logger.errx(rspamd_config, + 'cannot get FANN trains %s from redis: %s', prefix, _err) elseif _data and type(_data) == 'number' or type(_data) == 'string' then if tonumber(_data) and tonumber(_data) >= max_trains then - rspamd_logger.infox(rspamd_config, 'need to learn ANN %s after %s learn vectors (%s required)', + rspamd_logger.infox(rspamd_config, + 'need to learn ANN %s after %s learn vectors (%s required)', prefix, tonumber(_data), max_trains) - train_fann(cfg, ev_base, elt) + train_fann(rule, cfg, ev_base, elt) end end end rspamd_redis.redis_make_request_taskless(ev_base, rspamd_config, - redis_params, + rule.redis, nil, false, -- is write redis_len_cb, --callback @@ -766,18 +782,18 @@ local function maybe_train_fanns(cfg, ev_base) -- First we need to get all fanns stored in our Redis rspamd_redis.redis_make_request_taskless(ev_base, rspamd_config, - redis_params, + rule.redis, nil, false, -- is write members_cb, --callback 'SMEMBERS', -- command - {gen_fann_prefix(nil)} -- arguments + {gen_fann_prefix(rule, nil)} -- arguments ) - return watch_interval + return rule.watch_interval end -local function check_fanns(_, ev_base) +local function check_fanns(rule, _, ev_base) local function members_cb(err, data) if err then rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', err) @@ -791,7 +807,7 @@ local function check_fanns(_, ev_base) load_scripts(rspamd_config, ev_base, nil) end elseif _data and type(_data) == 'table' then - load_or_invalidate_fann(_data, elt, ev_base) + load_or_invalidate_fann(rule, _data, elt, ev_base) end end @@ -803,12 +819,12 @@ local function check_fanns(_, ev_base) end rspamd_redis.redis_make_request_taskless(ev_base, rspamd_config, - redis_params, + rule.redis, nil, false, -- is write redis_update_cb, --callback 'EVALSHA', -- command - {redis_maybe_load_sha, 2, gen_fann_prefix(elt), tostring(local_ver)} + {redis_maybe_load_sha, 2, gen_fann_prefix(rule, elt), tostring(local_ver)} ) end, data) @@ -822,27 +838,32 @@ local function check_fanns(_, ev_base) -- First we need to get all fanns stored in our Redis rspamd_redis.redis_make_request_taskless(ev_base, rspamd_config, - redis_params, + rule.redis, nil, false, -- is write members_cb, --callback 'SMEMBERS', -- command - {gen_fann_prefix(nil)} -- arguments + {gen_fann_prefix(rule, nil)} -- arguments ) - return watch_interval + return rule.watch_interval end local function ann_push_vector(task) local scores = task:get_metric_score() local sid = task:get_settings_id() - if use_settings then - fann_train_callback(task, scores[1], scores[2], tostring(sid)) - else - fann_train_callback(task, scores[1], scores[2], "1") + + for _,rule in settings.rules do + if rule.use_settings then + fann_train_callback(rule, task, scores[1], scores[2], tostring(sid)) + else + fann_train_callback(rule, task, scores[1], scores[2], "0") + end end end +redis_params = rspamd_parse_redis_server('fann_redis') + -- Initialization part if not (opts and type(opts) == 'table') or not redis_params then rspamd_logger.infox(rspamd_config, 'Module is unconfigured') @@ -854,72 +875,75 @@ if not rspamd_fann.is_enabled() then 'module is eventually disabled') return else - use_settings = opts['use_settings'] - if opts['spam_symbol'] then - fann_symbol_spam = opts['spam_symbol'] - end - if opts['ham_symbol'] then - fann_symbol_ham = opts['ham_symbol'] - end - if opts['prefix'] then - fann_prefix = opts['prefix'] - end - if opts['lock_expire'] then - lock_expire = tonumber(opts['lock_expire']) + local rules = opts['rules'] + + if not rules then + -- Use legacy configuration + rules = {} + rules['RFANN'] = opts end - rspamd_config:set_metric_symbol({ - name = fann_symbol_spam, - score = 3.0, - description = 'Neural network SPAM', - group = 'fann' - }) + local id = rspamd_config:register_symbol({ - name = fann_symbol_spam, + name = 'FANN_CHECK', type = 'postfilter,nostat', priority = 6, callback = fann_scores_filter }) - rspamd_config:set_metric_symbol({ - name = fann_symbol_ham, - score = -2.0, - description = 'Neural network HAM', - group = 'fann' - }) - rspamd_config:register_symbol({ - name = fann_symbol_ham, - type = 'virtual,nostat', - parent = id - }) - if opts['train'] then - if opts['train']['max_train'] then - max_trains = opts['train']['max_train'] - end - if opts['train']['max_epoch'] then - max_epoch = opts['train']['max_epoch'] - end - if opts['train']['max_usages'] then - max_usages = opts['train']['max_usages'] + + for k,r in rules do + rules[k] = default_options + rules[k]['redis'] = redis_params + local cur = rules[k] + -- Override defaults + for sk,v in r do + cur[sk] = v end - if opts['train']['mse'] then - mse = opts['train']['mse'] + if not cur.prefix then + cur.prefix = k end + rspamd_config:set_metric_symbol({ + name = cur.symbol_spam, + score = 3.0, + description = 'Neural network SPAM', + group = 'fann' + }) + + rspamd_config:set_metric_symbol({ + name = cur.symbol_ham, + score = -2.0, + description = 'Neural network HAM', + group = 'fann' + }) rspamd_config:register_symbol({ - name = 'FANN_VECTOR_PUSH', - type = 'postfilter,nostat', - priority = 5, - callback = ann_push_vector + name = cur.symbol_ham, + type = 'virtual,nostat', + parent = id }) end + + rspamd_config:register_symbol({ + name = 'FANN_VECTOR_PUSH', + type = 'postfilter,nostat', + priority = 5, + callback = ann_push_vector + }) + + settings.rules = rules + -- Add training scripts - rspamd_config:add_on_load(function(cfg, ev_base, worker) - load_scripts(cfg, ev_base, check_fanns) - - if worker:get_name() == 'normal' then - -- We also want to train neural nets when they have enough data - rspamd_config:add_periodic(ev_base, 0.0, - function(_cfg, _ev_base) - return maybe_train_fanns(_cfg, _ev_base) - end) - end - end) + for k,rule in settings.rules do + rspamd_config:add_on_load(function(cfg, ev_base, worker) + load_scripts(cfg, ev_base, function(cfg, ev_base) + check_fanns(rule, cfg, ev_base) + end) + + if worker:get_name() == 'normal' then + -- We also want to train neural nets when they have enough data + rspamd_config:add_periodic(ev_base, 0.0, + function(_cfg, _ev_base) + return maybe_train_fanns(rule, _cfg, _ev_base) + end) + end + end) + end end |