aboutsummaryrefslogtreecommitdiffstats
path: root/src/plugins/lua/fann_scores.lua
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2016-04-27 14:07:33 +0100
committerVsevolod Stakhov <vsevolod@highsecure.ru>2016-04-27 14:07:33 +0100
commit4d5404b49c8deadf15690995a9bfa642b44c7f7f (patch)
tree329c53d7a91da159ebceb30bba0fbd4d129128f0 /src/plugins/lua/fann_scores.lua
parent3ec4cb2dd98abf2053e54d54054a025683836b10 (diff)
downloadrspamd-4d5404b49c8deadf15690995a9bfa642b44c7f7f.tar.gz
rspamd-4d5404b49c8deadf15690995a9bfa642b44c7f7f.zip
[Feature] Rework fann module to understand settings
Diffstat (limited to 'src/plugins/lua/fann_scores.lua')
-rw-r--r--src/plugins/lua/fann_scores.lua142
1 files changed, 86 insertions, 56 deletions
diff --git a/src/plugins/lua/fann_scores.lua b/src/plugins/lua/fann_scores.lua
index b20338fb8..ea82974a1 100644
--- a/src/plugins/lua/fann_scores.lua
+++ b/src/plugins/lua/fann_scores.lua
@@ -25,14 +25,18 @@ require "fun" ()
local ucl = require "ucl"
-- Module vars
-local fann = nil
-local fann_train = nil
+-- ANNs indexed by settings id
+local data = {
+ ['0'] = {
+ fann_mtime = 0,
+ ntrains = 0,
+ epoch = 0,
+ }
+}
local fann_file
-local ntrains = 0
local max_trains = 1000
-local epoch = 0
local max_epoch = 100
-local fann_mtime = 0
+local use_settings = false
local opts = rspamd_config:get_all_opt("fann_scores")
local function symbols_to_fann_vector(syms)
@@ -55,125 +59,144 @@ local function symbols_to_fann_vector(syms)
return learn_data
end
-local function load_fann()
- local err,st = rspamd_util.stat(fann_file)
+local function gen_fann_file(id)
+ if use_settings then
+ return fann_file .. id
+ else
+ return fann_file
+ end
+end
+
+local function load_fann(id)
+ local fname = gen_fann_file(id)
+ local err,st = rspamd_util.stat(fname)
if err then
return false
end
- fann = rspamd_fann.load(fann_file)
+ data[id].fann = rspamd_fann.load(fname)
- if fann then
+ if data[id].fann then
local n = rspamd_config:get_symbols_count()
- if n ~= fann:get_inputs() then
+ if n ~= data[id].fann:get_inputs() then
rspamd_logger.infox(rspamd_config, 'fann has incorrect number of inputs: %s, %s symbols' ..
- ' is found in the cache; removing', fann:get_inputs(), n)
- fann = nil
+ ' is found in the cache; removing', data[id].fann:get_inputs(), n)
+ data[id].fann = nil
- local ret,err = rspamd_util.unlink(fann_file)
+ local ret,err = rspamd_util.unlink(fname)
if not ret then
rspamd_logger.errx(rspamd_config, 'cannot remove invalid fann from %s: %s',
- fann_file, err)
+ fname, err)
end
else
- rspamd_logger.infox(rspamd_config, 'loaded fann from %s', fann_file)
+ rspamd_logger.infox(rspamd_config, 'loaded fann from %s', fname)
return true
end
else
- rspamd_logger.infox(rspamd_config, 'fann is invalid: "%s"; removing', fann_file)
- local ret,err = rspamd_util.unlink(fann_file)
+ rspamd_logger.infox(rspamd_config, 'fann is invalid: "%s"; removing', fname)
+ local ret,err = rspamd_util.unlink(fname)
if not ret then
rspamd_logger.errx(rspamd_config, 'cannot remove invalid fann from %s: %s',
- fann_file, err)
+ fname, err)
end
end
return false
end
-local function check_fann()
- if fann then
+local function check_fann(id)
+ if data[id].fann then
local n = rspamd_config:get_symbols_count()
- if n ~= fann:get_inputs() then
+ if n ~= data[id].fann:get_inputs() then
rspamd_logger.infox(rspamd_config, 'fann has incorrect number of inputs: %s, %s symbols' ..
- ' is found in the cache', fann:get_inputs(), n)
- fann = nil
+ ' is found in the cache', data[id].fann:get_inputs(), n)
+ data[id].fann = nil
end
end
- local err,st = rspamd_util.stat(fann_file)
+ local fname = gen_fann_file(id)
+ local err,st = rspamd_util.stat(fname)
if not err then
local mtime = st['mtime']
- if mtime > fann_mtime then
+ if mtime > data[id].fann_mtime then
rspamd_logger.infox(rspamd_config, 'have more fresh version of fann ' ..
- 'file: %s -> %s, need to reload %s', fann_mtime, mtime, fann_file)
- fann_mtime = mtime
- fann = nil
+ 'file: %s -> %s, need to reload %s', data[id].fann_mtime, mtime, fname)
+ data[id].fann_mtime = mtime
+ data[id].fann = nil
end
end
end
local function fann_scores_filter(task)
- check_fann()
+ local id = '0'
+ if use_settings then
+ local sid = task:get_settings_id()
+ if sid then
+ id = tostring(sid)
+ end
+ end
+
+ check_fann(id)
- if fann then
+ if data[id].fann then
local symbols = task:get_symbols_numeric()
local fann_data = symbols_to_fann_vector(symbols)
- local out = fann:test(fann_data)
+ local out = data[id].fann:test(fann_data)
local result = rspamd_util.tanh(2 * (out[1] - 0.5))
local symscore = string.format('%.3f', out[1])
rspamd_logger.infox(task, 'fann score: %s', symscore)
- task:insert_result(fann_symbol, result, symscore)
+ task:insert_result(fann_symbol, result, symscore, id)
else
- if load_fann() then
+ if load_fann(id) then
fann_scores_filter(task)
end
end
end
-local function create_train_fann(n)
- fann_train = rspamd_fann.create(3, n, n / 2, 1)
- ntrains = 0
- epoch = 0
+local function create_train_fann(n, id)
+ data[id].fann_train = rspamd_fann.create(3, n, n / 2, 1)
+ data[id].ntrains = 0
+ data[id].epoch = 0
end
-local function fann_train_callback(score, required_score,results, cf, opts)
+local function fann_train_callback(score, required_score,results, cf, id, opts)
local n = cf:get_symbols_count()
+ local fname = gen_fann_file(id)
- if not fann_train then
- create_train_fann(n)
+ if not data[id].fann_train then
+ create_train_fann(n, id)
end
- if fann_train:get_inputs() ~= n then
+ if data[id].fann_train:get_inputs() ~= n then
rspamd_logger.infox(cf, 'fann has incorrect number of inputs: %s, %s symbols' ..
- ' is found in the cache', fann_train:get_inputs(), n)
- create_train_fann(n)
+ ' is found in the cache', data[id].fann_train:get_inputs(), n)
+ create_train_fann(n, id)
end
- if ntrains > max_trains then
+ if data[id].ntrains > max_trains then
-- Store fann on disk
- res = fann_train:save(fann_file)
+ local res = data[id].fann_train:save(fname)
if not res then
- rspamd_logger.errx(cf, 'cannot save fann in %s', fann_file)
+ rspamd_logger.errx(cf, 'cannot save fann in %s', fname)
else
- ntrains = 0
- epoch = epoch + 1
+ data[id].ntrains = 0
+ data[id].epoch = data[id].epoch + 1
end
end
- if epoch > max_epoch then
+ if data[id].epoch > max_epoch then
-- Re-create fann
- rspamd_logger.infox(cf, 'create new fann in %s after %s epoches', fann_file,
+ rspamd_logger.infox(cf, 'create new fann in %s after %s epoches', fname,
max_epoch)
- create_train_fann(n)
+ create_train_fann(n, id)
end
local learn_spam, learn_ham = false, false
@@ -194,12 +217,12 @@ local function fann_train_callback(score, required_score,results, cf, opts)
)
if learn_spam then
- fann_train:train(learn_data, {1.0})
+ data[id].fann_train:train(learn_data, {1.0})
else
- fann_train:train(learn_data, {0.0})
+ data[id].fann_train:train(learn_data, {0.0})
end
- ntrains = ntrains + 1
+ data[id].ntrains = data[id].ntrains + 1
end
end
@@ -212,6 +235,7 @@ else
'`fann_file` to be specified')
else
fann_file = opts['fann_file']
+ use_settings = opts['use_settings']
rspamd_config:set_metric_symbol(fann_symbol, 3.0, 'Experimental FANN adjustment')
rspamd_config:register_post_filter(fann_scores_filter)
@@ -223,8 +247,14 @@ else
if opts['train']['max_epoch'] then
max_trains = opts['train']['max_epoch']
end
- cfg:register_worker_script("log_helper", function(score, req_score, results, cf)
- fann_train_callback(score, req_score, results, cf, opts['train'])
+ cfg:register_worker_script("log_helper",
+ function(score, req_score, results, cf, id)
+ if use_settings then
+ fann_train_callback(score, req_score, results, cf,
+ tostring(id), opts['train'])
+ else
+ fann_train_callback(score, req_score, results, cf, '0', opts['train'])
+ end
end)
end)
end