-- Module vars
local fann
-local symbols
-local nsymbols = 0
+local fann_train
+local fann_file
+local ntrains = 0
+local max_trains = 1000
+local epoch = 0
+local max_epoch = 100
+local fann_mtime = 0
local opts = rspamd_config:get_all_opt("fann_scores")
-local function fann_scores_filter(task)
- local fann_input = {}
+local function load_fann()
+ local err,st = rspamd_util.stat(fann_file)
+
+ if err then
+ return false
+ end
+
+ fann = rspamd_fann.load(fann_file)
- for sym,idx in pairs(symbols) do
- if task:has_symbol(sym) then
- fann_input[idx + 1] = 1
+ if fann then
+ local n = rspamd_config:get_symbols_count()
+
+ if n ~= fann:get_inputs() then
+ rspamd_logger.infox(rspamd_config, 'fann has incorrect number of inputs: %s, %s symbols' ..
+ ' is found in the cache', fann:get_inputs(), n)
+ fann = nil
else
- fann_input[idx + 1] = 0
+ rspamd_logger.infox(rspamd_config, 'loaded fann from %s', fann_file)
+ return true
+ end
+ end
+
+ return false
+end
+
+local function check_fann()
+ local n = rspamd_config:get_symbols_count()
+
+ if fann then
+ local n = rspamd_config:get_symbols_count()
+
+ if n ~= fann:get_inputs() then
+ rspamd_logger.infox(rspamd_config, 'fann has incorrect number of inputs: %s, %s symbols' ..
+ ' is found in the cache', fann:get_inputs(), n)
+ fann = nil
end
end
- local out = fann:test(nsymbols, fann_input)
- local result = rspamd_util.tanh(2 * (out[1] - 0.5))
- local symscore = string.format('%.3f', out[1])
- rspamd_logger.infox(task, 'fann score: %s', symscore)
+ local err,st = rspamd_util.stat(fann_file)
- task:insert_result(fann_symbol, result, symscore)
+ if not err then
+ local mtime = st['mtime']
+
+ if mtime > fann_mtime then
+ fann_mtime = mtime
+ fann = nil
+ end
+ end
end
-if not rspamd_fann.is_enabled() then
- rspamd_logger.errx(rspamd_config, 'fann is not compiled in rspamd, this ' ..
- 'module is eventually disabled')
-else
- if not opts['fann_file'] or not opts['symbols_file'] then
- rspamd_logger.errx(rspamd_config, 'fann_scores module requires ' ..
- '`fann_file` and `symbols_file` to be specified')
- else
- fann = rspamd_fann.load(opts['fann_file'])
+local function fann_scores_filter(task)
+ check_fann()
- if not fann then
- rspamd_logger.errx(rspamd_config, 'cannot load fann from %s',
- opts['fann_file'])
- return
+ if fann then
+ local fann_input = {}
+
+ for sym,idx in pairs(symbols) do
+ if task:has_symbol(sym) then
+ fann_input[idx + 1] = 1
+ else
+ fann_input[idx + 1] = 0
+ end
end
- -- Parse symbols
- local parser = ucl.parser()
- local res, err = parser:parse_file(opts['symbols_file'])
+
+ local out = fann:test(nsymbols, fann_input)
+ local result = rspamd_util.tanh(2 * (out[1] - 0.5))
+ local symscore = string.format('%.3f', out[1])
+ rspamd_logger.infox(task, 'fann score: %s', symscore)
+
+ task:insert_result(fann_symbol, result, symscore)
+ else
+ if load_fann() then
+ fann_scores_filter(task)
+ end
+ end
+end
+
+local function create_train_fann(n)
+ fann_train = rspamd_fann.create(3, n, n / 2, 1)
+ ntrains = 0
+ epoch = 0
+end
+
+local function fann_train(score, required_score,results, cf, opts)
+ local n = cf:get_symbols_count()
+
+ if not fann_train then
+ create_train_fann(n)
+ end
+
+ if fann_train:get_inputs() ~= n then
+ rspamd_logger.infox(cf, 'fann has incorrect number of inputs: %s, %s symbols' ..
+ ' is found in the cache', fann_train:get_inputs(), n)
+ create_train_fann(n)
+ end
+
+ if ntrains > max_trains then
+ -- Store fann on disk
+ res = fann_train:save(fann_file)
+
if not res then
- rspamd_logger.errx(rspamd_config, 'cannot load symbols from %s: %s',
- opts['symbols_file'], err)
- return
+ rspamd_logger.errx(cf, 'cannot save fann in %s', fann_file)
+ else
+ ntrains = 0
+ epoch = epoch + 1
end
+ end
+
+ if epoch > max_epoch then
+ -- Re-create fann
+ rspamd_logger.infox(cf, 'create new fann in %s after %s epoches', fann_file,
+ max_epoch)
+ create_train_fann(n)
+ end
+
+ local learn_spam, learn_ham = false, false
+ if opts['spam_score'] then
+ learn_spam = score >= opts['spam_score']
+ else
+ learn_spam = score >= required_score
+ end
+ if opts['ham_score'] then
+ learn_ham = score <= opts['ham_score']
+ else
+ learn_ham = score < 0
+ end
+
+ if learn_spam or learn_ham then
+ local learn_data = {}
+ local matched_symbols = {}
- symbols = parser:get_object()
+ for _,sym in ipairs(results) do
+ matched_symbols[sym[1] + 1] = 1
+ end
- -- Check sanity
- for _,s in pairs(symbols) do nsymbols = nsymbols + 1 end
- if fann:get_inputs() ~= nsymbols then
- rspamd_logger.errx(rspamd_config, 'fann number of inputs: %s is not equal' ..
- ' to symbols count: %s',
- fann:get_inputs(), nsymbols)
- return
+ for i=1,(n + 1) do
+ if matched_symbols[i] then
+ learn_data[i] = 1
+ else
+ learn_data[i] = 0
+ end
end
- if fann:get_outputs() ~= 1 then
- rspamd_logger.errx(rspamd_config, 'fann nuber of outputs is invalid: %s',
- fann:get_outputs())
- return
+ if learn_spam then
+ fann_train:train(learn_data, 1.0)
+ else
+ fann_train:train(learn_data, 0.0)
end
+ trains = trains + 1
+ end
+end
+
+if not rspamd_fann.is_enabled() then
+ rspamd_logger.errx(rspamd_config, 'fann is not compiled in rspamd, this ' ..
+ 'module is eventually disabled')
+else
+ if not opts['fann_file'] then
+ rspamd_logger.errx(rspamd_config, 'fann_scores module requires ' ..
+ '`fann_file` to be specified')
+ else
+ fann_file = opts['fann_file']
rspamd_config:set_metric_symbol(fann_symbol, 3.0, 'Experimental FANN adjustment')
rspamd_config:register_post_filter(fann_scores_filter)
+
+ if opts['train'] then
+ rspamd_config:add_on_load(function(cfg)
+ if opts['train']['max_train'] then
+ max_trains = opts['train']['max_train']
+ end
+ if opts['train']['max_epoch'] then
+ max_trains = opts['train']['max_epoch']
+ end
+ cfg:register_worker_script("log_helper", function(score, req_score, results, cf)
+ fann_train(score, req_score, results, cf, opts['train'])
+ end)
+ end)
+ end
end
end