From 48a1be2cd19f795b2bd26cc061cdc8655e098248 Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Tue, 5 Apr 2016 17:26:43 +0100 Subject: [PATCH] [Feature] Implement preliminary code for fann autolearn --- src/plugins/lua/fann_scores.lua | 209 +++++++++++++++++++++++++------- 1 file changed, 165 insertions(+), 44 deletions(-) diff --git a/src/plugins/lua/fann_scores.lua b/src/plugins/lua/fann_scores.lua index 66a5e1879..f430d9150 100644 --- a/src/plugins/lua/fann_scores.lua +++ b/src/plugins/lua/fann_scores.lua @@ -25,71 +25,192 @@ local ucl = require "ucl" -- Module vars local fann -local symbols -local nsymbols = 0 +local fann_train +local fann_file +local ntrains = 0 +local max_trains = 1000 +local epoch = 0 +local max_epoch = 100 +local fann_mtime = 0 local opts = rspamd_config:get_all_opt("fann_scores") -local function fann_scores_filter(task) - local fann_input = {} +local function load_fann() + local err,st = rspamd_util.stat(fann_file) + + if err then + return false + end + + fann = rspamd_fann.load(fann_file) - for sym,idx in pairs(symbols) do - if task:has_symbol(sym) then - fann_input[idx + 1] = 1 + if fann then + local n = rspamd_config:get_symbols_count() + + if n ~= fann:get_inputs() then + rspamd_logger.infox(rspamd_config, 'fann has incorrect number of inputs: %s, %s symbols' .. + ' is found in the cache', fann:get_inputs(), n) + fann = nil else - fann_input[idx + 1] = 0 + rspamd_logger.infox(rspamd_config, 'loaded fann from %s', fann_file) + return true + end + end + + return false +end + +local function check_fann() + local n = rspamd_config:get_symbols_count() + + if fann then + local n = rspamd_config:get_symbols_count() + + if n ~= fann:get_inputs() then + rspamd_logger.infox(rspamd_config, 'fann has incorrect number of inputs: %s, %s symbols' .. + ' is found in the cache', fann:get_inputs(), n) + fann = nil end end - local out = fann:test(nsymbols, fann_input) - local result = rspamd_util.tanh(2 * (out[1] - 0.5)) - local symscore = string.format('%.3f', out[1]) - rspamd_logger.infox(task, 'fann score: %s', symscore) + local err,st = rspamd_util.stat(fann_file) - task:insert_result(fann_symbol, result, symscore) + if not err then + local mtime = st['mtime'] + + if mtime > fann_mtime then + fann_mtime = mtime + fann = nil + end + end end -if not rspamd_fann.is_enabled() then - rspamd_logger.errx(rspamd_config, 'fann is not compiled in rspamd, this ' .. - 'module is eventually disabled') -else - if not opts['fann_file'] or not opts['symbols_file'] then - rspamd_logger.errx(rspamd_config, 'fann_scores module requires ' .. - '`fann_file` and `symbols_file` to be specified') - else - fann = rspamd_fann.load(opts['fann_file']) +local function fann_scores_filter(task) + check_fann() - if not fann then - rspamd_logger.errx(rspamd_config, 'cannot load fann from %s', - opts['fann_file']) - return + if fann then + local fann_input = {} + + for sym,idx in pairs(symbols) do + if task:has_symbol(sym) then + fann_input[idx + 1] = 1 + else + fann_input[idx + 1] = 0 + end end - -- Parse symbols - local parser = ucl.parser() - local res, err = parser:parse_file(opts['symbols_file']) + + local out = fann:test(nsymbols, fann_input) + local result = rspamd_util.tanh(2 * (out[1] - 0.5)) + local symscore = string.format('%.3f', out[1]) + rspamd_logger.infox(task, 'fann score: %s', symscore) + + task:insert_result(fann_symbol, result, symscore) + else + if load_fann() then + fann_scores_filter(task) + end + end +end + +local function create_train_fann(n) + fann_train = rspamd_fann.create(3, n, n / 2, 1) + ntrains = 0 + epoch = 0 +end + +local function fann_train(score, required_score,results, cf, opts) + local n = cf:get_symbols_count() + + if not fann_train then + create_train_fann(n) + end + + if fann_train:get_inputs() ~= n then + rspamd_logger.infox(cf, 'fann has incorrect number of inputs: %s, %s symbols' .. + ' is found in the cache', fann_train:get_inputs(), n) + create_train_fann(n) + end + + if ntrains > max_trains then + -- Store fann on disk + res = fann_train:save(fann_file) + if not res then - rspamd_logger.errx(rspamd_config, 'cannot load symbols from %s: %s', - opts['symbols_file'], err) - return + rspamd_logger.errx(cf, 'cannot save fann in %s', fann_file) + else + ntrains = 0 + epoch = epoch + 1 end + end + + if epoch > max_epoch then + -- Re-create fann + rspamd_logger.infox(cf, 'create new fann in %s after %s epoches', fann_file, + max_epoch) + create_train_fann(n) + end + + local learn_spam, learn_ham = false, false + if opts['spam_score'] then + learn_spam = score >= opts['spam_score'] + else + learn_spam = score >= required_score + end + if opts['ham_score'] then + learn_ham = score <= opts['ham_score'] + else + learn_ham = score < 0 + end + + if learn_spam or learn_ham then + local learn_data = {} + local matched_symbols = {} - symbols = parser:get_object() + for _,sym in ipairs(results) do + matched_symbols[sym[1] + 1] = 1 + end - -- Check sanity - for _,s in pairs(symbols) do nsymbols = nsymbols + 1 end - if fann:get_inputs() ~= nsymbols then - rspamd_logger.errx(rspamd_config, 'fann number of inputs: %s is not equal' .. - ' to symbols count: %s', - fann:get_inputs(), nsymbols) - return + for i=1,(n + 1) do + if matched_symbols[i] then + learn_data[i] = 1 + else + learn_data[i] = 0 + end end - if fann:get_outputs() ~= 1 then - rspamd_logger.errx(rspamd_config, 'fann nuber of outputs is invalid: %s', - fann:get_outputs()) - return + if learn_spam then + fann_train:train(learn_data, 1.0) + else + fann_train:train(learn_data, 0.0) end + trains = trains + 1 + end +end + +if not rspamd_fann.is_enabled() then + rspamd_logger.errx(rspamd_config, 'fann is not compiled in rspamd, this ' .. + 'module is eventually disabled') +else + if not opts['fann_file'] then + rspamd_logger.errx(rspamd_config, 'fann_scores module requires ' .. + '`fann_file` to be specified') + else + fann_file = opts['fann_file'] rspamd_config:set_metric_symbol(fann_symbol, 3.0, 'Experimental FANN adjustment') rspamd_config:register_post_filter(fann_scores_filter) + + if opts['train'] then + rspamd_config:add_on_load(function(cfg) + if opts['train']['max_train'] then + max_trains = opts['train']['max_train'] + end + if opts['train']['max_epoch'] then + max_trains = opts['train']['max_epoch'] + end + cfg:register_worker_script("log_helper", function(score, req_score, results, cf) + fann_train(score, req_score, results, cf, opts['train']) + end) + end) + end end end -- 2.39.5