diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2016-04-27 14:07:33 +0100 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2016-04-27 14:07:33 +0100 |
commit | 4d5404b49c8deadf15690995a9bfa642b44c7f7f (patch) | |
tree | 329c53d7a91da159ebceb30bba0fbd4d129128f0 /src/plugins/lua/fann_scores.lua | |
parent | 3ec4cb2dd98abf2053e54d54054a025683836b10 (diff) | |
download | rspamd-4d5404b49c8deadf15690995a9bfa642b44c7f7f.tar.gz rspamd-4d5404b49c8deadf15690995a9bfa642b44c7f7f.zip |
[Feature] Rework fann module to understand settings
Diffstat (limited to 'src/plugins/lua/fann_scores.lua')
-rw-r--r-- | src/plugins/lua/fann_scores.lua | 142 |
1 files changed, 86 insertions, 56 deletions
diff --git a/src/plugins/lua/fann_scores.lua b/src/plugins/lua/fann_scores.lua index b20338fb8..ea82974a1 100644 --- a/src/plugins/lua/fann_scores.lua +++ b/src/plugins/lua/fann_scores.lua @@ -25,14 +25,18 @@ require "fun" () local ucl = require "ucl" -- Module vars -local fann = nil -local fann_train = nil +-- ANNs indexed by settings id +local data = { + ['0'] = { + fann_mtime = 0, + ntrains = 0, + epoch = 0, + } +} local fann_file -local ntrains = 0 local max_trains = 1000 -local epoch = 0 local max_epoch = 100 -local fann_mtime = 0 +local use_settings = false local opts = rspamd_config:get_all_opt("fann_scores") local function symbols_to_fann_vector(syms) @@ -55,125 +59,144 @@ local function symbols_to_fann_vector(syms) return learn_data end -local function load_fann() - local err,st = rspamd_util.stat(fann_file) +local function gen_fann_file(id) + if use_settings then + return fann_file .. id + else + return fann_file + end +end + +local function load_fann(id) + local fname = gen_fann_file(id) + local err,st = rspamd_util.stat(fname) if err then return false end - fann = rspamd_fann.load(fann_file) + data[id].fann = rspamd_fann.load(fname) - if fann then + if data[id].fann then local n = rspamd_config:get_symbols_count() - if n ~= fann:get_inputs() then + if n ~= data[id].fann:get_inputs() then rspamd_logger.infox(rspamd_config, 'fann has incorrect number of inputs: %s, %s symbols' .. - ' is found in the cache; removing', fann:get_inputs(), n) - fann = nil + ' is found in the cache; removing', data[id].fann:get_inputs(), n) + data[id].fann = nil - local ret,err = rspamd_util.unlink(fann_file) + local ret,err = rspamd_util.unlink(fname) if not ret then rspamd_logger.errx(rspamd_config, 'cannot remove invalid fann from %s: %s', - fann_file, err) + fname, err) end else - rspamd_logger.infox(rspamd_config, 'loaded fann from %s', fann_file) + rspamd_logger.infox(rspamd_config, 'loaded fann from %s', fname) return true end else - rspamd_logger.infox(rspamd_config, 'fann is invalid: "%s"; removing', fann_file) - local ret,err = rspamd_util.unlink(fann_file) + rspamd_logger.infox(rspamd_config, 'fann is invalid: "%s"; removing', fname) + local ret,err = rspamd_util.unlink(fname) if not ret then rspamd_logger.errx(rspamd_config, 'cannot remove invalid fann from %s: %s', - fann_file, err) + fname, err) end end return false end -local function check_fann() - if fann then +local function check_fann(id) + if data[id].fann then local n = rspamd_config:get_symbols_count() - if n ~= fann:get_inputs() then + if n ~= data[id].fann:get_inputs() then rspamd_logger.infox(rspamd_config, 'fann has incorrect number of inputs: %s, %s symbols' .. - ' is found in the cache', fann:get_inputs(), n) - fann = nil + ' is found in the cache', data[id].fann:get_inputs(), n) + data[id].fann = nil end end - local err,st = rspamd_util.stat(fann_file) + local fname = gen_fann_file(id) + local err,st = rspamd_util.stat(fname) if not err then local mtime = st['mtime'] - if mtime > fann_mtime then + if mtime > data[id].fann_mtime then rspamd_logger.infox(rspamd_config, 'have more fresh version of fann ' .. - 'file: %s -> %s, need to reload %s', fann_mtime, mtime, fann_file) - fann_mtime = mtime - fann = nil + 'file: %s -> %s, need to reload %s', data[id].fann_mtime, mtime, fname) + data[id].fann_mtime = mtime + data[id].fann = nil end end end local function fann_scores_filter(task) - check_fann() + local id = '0' + if use_settings then + local sid = task:get_settings_id() + if sid then + id = tostring(sid) + end + end + + check_fann(id) - if fann then + if data[id].fann then local symbols = task:get_symbols_numeric() local fann_data = symbols_to_fann_vector(symbols) - local out = fann:test(fann_data) + local out = data[id].fann:test(fann_data) local result = rspamd_util.tanh(2 * (out[1] - 0.5)) local symscore = string.format('%.3f', out[1]) rspamd_logger.infox(task, 'fann score: %s', symscore) - task:insert_result(fann_symbol, result, symscore) + task:insert_result(fann_symbol, result, symscore, id) else - if load_fann() then + if load_fann(id) then fann_scores_filter(task) end end end -local function create_train_fann(n) - fann_train = rspamd_fann.create(3, n, n / 2, 1) - ntrains = 0 - epoch = 0 +local function create_train_fann(n, id) + data[id].fann_train = rspamd_fann.create(3, n, n / 2, 1) + data[id].ntrains = 0 + data[id].epoch = 0 end -local function fann_train_callback(score, required_score,results, cf, opts) +local function fann_train_callback(score, required_score,results, cf, id, opts) local n = cf:get_symbols_count() + local fname = gen_fann_file(id) - if not fann_train then - create_train_fann(n) + if not data[id].fann_train then + create_train_fann(n, id) end - if fann_train:get_inputs() ~= n then + if data[id].fann_train:get_inputs() ~= n then rspamd_logger.infox(cf, 'fann has incorrect number of inputs: %s, %s symbols' .. - ' is found in the cache', fann_train:get_inputs(), n) - create_train_fann(n) + ' is found in the cache', data[id].fann_train:get_inputs(), n) + create_train_fann(n, id) end - if ntrains > max_trains then + if data[id].ntrains > max_trains then -- Store fann on disk - res = fann_train:save(fann_file) + local res = data[id].fann_train:save(fname) if not res then - rspamd_logger.errx(cf, 'cannot save fann in %s', fann_file) + rspamd_logger.errx(cf, 'cannot save fann in %s', fname) else - ntrains = 0 - epoch = epoch + 1 + data[id].ntrains = 0 + data[id].epoch = data[id].epoch + 1 end end - if epoch > max_epoch then + if data[id].epoch > max_epoch then -- Re-create fann - rspamd_logger.infox(cf, 'create new fann in %s after %s epoches', fann_file, + rspamd_logger.infox(cf, 'create new fann in %s after %s epoches', fname, max_epoch) - create_train_fann(n) + create_train_fann(n, id) end local learn_spam, learn_ham = false, false @@ -194,12 +217,12 @@ local function fann_train_callback(score, required_score,results, cf, opts) ) if learn_spam then - fann_train:train(learn_data, {1.0}) + data[id].fann_train:train(learn_data, {1.0}) else - fann_train:train(learn_data, {0.0}) + data[id].fann_train:train(learn_data, {0.0}) end - ntrains = ntrains + 1 + data[id].ntrains = data[id].ntrains + 1 end end @@ -212,6 +235,7 @@ else '`fann_file` to be specified') else fann_file = opts['fann_file'] + use_settings = opts['use_settings'] rspamd_config:set_metric_symbol(fann_symbol, 3.0, 'Experimental FANN adjustment') rspamd_config:register_post_filter(fann_scores_filter) @@ -223,8 +247,14 @@ else if opts['train']['max_epoch'] then max_trains = opts['train']['max_epoch'] end - cfg:register_worker_script("log_helper", function(score, req_score, results, cf) - fann_train_callback(score, req_score, results, cf, opts['train']) + cfg:register_worker_script("log_helper", + function(score, req_score, results, cf, id) + if use_settings then + fann_train_callback(score, req_score, results, cf, + tostring(id), opts['train']) + else + fann_train_callback(score, req_score, results, cf, '0', opts['train']) + end end) end) end |