From 85f4ad7eebf5e7778c487cf6edbd8dc97fb0f40c Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Wed, 6 Apr 2016 14:22:01 +0100 Subject: [PATCH] [Fix] Rework fann learning --- src/plugins/lua/fann_scores.lua | 67 +++++++++++++++++---------------- 1 file changed, 34 insertions(+), 33 deletions(-) diff --git a/src/plugins/lua/fann_scores.lua b/src/plugins/lua/fann_scores.lua index f430d9150..63dcb5733 100644 --- a/src/plugins/lua/fann_scores.lua +++ b/src/plugins/lua/fann_scores.lua @@ -21,11 +21,12 @@ local rspamd_logger = require "rspamd_logger" local rspamd_fann = require "rspamd_fann" local rspamd_util = require "rspamd_util" local fann_symbol = 'FANN_SCORE' +require "fun" () local ucl = require "ucl" -- Module vars -local fann -local fann_train +local fann = nil +local fann_train = nil local fann_file local ntrains = 0 local max_trains = 1000 @@ -34,6 +35,26 @@ local max_epoch = 100 local fann_mtime = 0 local opts = rspamd_config:get_all_opt("fann_scores") +local function symbols_to_fann_vector(syms) + local learn_data = {} + local matched_symbols = {} + local n = rspamd_config:get_symbols_count() + + each(function(s) + matched_symbols[s + 1] = 1 + end, syms) + + for i=1,n do + if matched_symbols[i] then + learn_data[i] = 1 + else + learn_data[i] = 0 + end + end + + return learn_data +end + local function load_fann() local err,st = rspamd_util.stat(fann_file) @@ -60,8 +81,6 @@ local function load_fann() end local function check_fann() - local n = rspamd_config:get_symbols_count() - if fann then local n = rspamd_config:get_symbols_count() @@ -88,17 +107,10 @@ local function fann_scores_filter(task) check_fann() if fann then - local fann_input = {} - - for sym,idx in pairs(symbols) do - if task:has_symbol(sym) then - fann_input[idx + 1] = 1 - else - fann_input[idx + 1] = 0 - end - end + local symbols = task:get_symbols_numeric() + local fann_data = symbols_to_fann_vector(symbols) - local out = fann:test(nsymbols, fann_input) + local out = fann:test(fann_data) local result = rspamd_util.tanh(2 * (out[1] - 0.5)) local symscore = string.format('%.3f', out[1]) rspamd_logger.infox(task, 'fann score: %s', symscore) @@ -117,7 +129,7 @@ local function create_train_fann(n) epoch = 0 end -local function fann_train(score, required_score,results, cf, opts) +local function fann_train_callback(score, required_score,results, cf, opts) local n = cf:get_symbols_count() if not fann_train then @@ -162,28 +174,17 @@ local function fann_train(score, required_score,results, cf, opts) end if learn_spam or learn_ham then - local learn_data = {} - local matched_symbols = {} - - for _,sym in ipairs(results) do - matched_symbols[sym[1] + 1] = 1 - end - - for i=1,(n + 1) do - if matched_symbols[i] then - learn_data[i] = 1 - else - learn_data[i] = 0 - end - end + local learn_data = symbols_to_fann_vector( + map(function(r) return r[1] end, results) + ) if learn_spam then - fann_train:train(learn_data, 1.0) + fann_train:train(learn_data, {1.0}) else - fann_train:train(learn_data, 0.0) + fann_train:train(learn_data, {0.0}) end - trains = trains + 1 + ntrains = ntrains + 1 end end @@ -208,7 +209,7 @@ else max_trains = opts['train']['max_epoch'] end cfg:register_worker_script("log_helper", function(score, req_score, results, cf) - fann_train(score, req_score, results, cf, opts['train']) + fann_train_callback(score, req_score, results, cf, opts['train']) end) end) end -- 2.39.5