diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2016-10-15 13:34:22 +0100 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2016-10-15 13:34:42 +0100 |
commit | c5e504950a444472e7772f2cc4a18ee9f81b6bd8 (patch) | |
tree | df9e61fefb19bf727103ccc46c0ea36dcfffe0c3 /src/plugins/lua/fann_scores.lua | |
parent | a903f7be21746c3a803d0a47dad4c0716584eaa7 (diff) | |
download | rspamd-c5e504950a444472e7772f2cc4a18ee9f81b6bd8.tar.gz rspamd-c5e504950a444472e7772f2cc4a18ee9f81b6bd8.zip |
[Feature] Use more layers for fann and another normalization
Diffstat (limited to 'src/plugins/lua/fann_scores.lua')
-rw-r--r-- | src/plugins/lua/fann_scores.lua | 18 |
1 files changed, 13 insertions, 5 deletions
diff --git a/src/plugins/lua/fann_scores.lua b/src/plugins/lua/fann_scores.lua index c1c3d80c0..9647fd3d3 100644 --- a/src/plugins/lua/fann_scores.lua +++ b/src/plugins/lua/fann_scores.lua @@ -335,6 +335,13 @@ local function check_fann(id) ' is found in the cache', data[id].fann:get_inputs(), n) data[id].fann = nil end + local layers = data[id].fann:get_layers() + + if not layers or #layers ~= 5 then + rspamd_logger.infox(rspamd_config, 'fann has incorrect number of layers: %s', + #layers) + data[id].fann = nil + end end local fname = gen_fann_file(id) @@ -373,14 +380,15 @@ local function fann_scores_filter(task) end local out = data[id].fann:test(fann_data) - local result = rspamd_util.tanh(2 * (out[1] - 0.5)) local symscore = string.format('%.3f', out[1]) rspamd_logger.infox(task, 'fann score: %s', symscore) - if result > 0 then + if out[1] > 0 then + local result = rspamd_util.normalize_prob(out[1] / 2.0, 0) task:insert_result(fann_symbol_spam, result, symscore, id) else - task:insert_result(fann_symbol_ham, -(result), symscore, id) + local result = rspamd_util.normalize_prob((-out[1]) / 2.0, 0) + task:insert_result(fann_symbol_ham, result, symscore, id) end else if load_fann(id) then @@ -390,7 +398,7 @@ local function fann_scores_filter(task) end local function create_train_fann(n, id) - data[id].fann_train = rspamd_fann.create(3, n, n / 2, 1) + data[id].fann_train = rspamd_fann.create(5, n, n, n / 2, n / 4, 1) data[id].ntrains = 0 data[id].epoch = 0 end @@ -480,7 +488,7 @@ local function fann_train_callback(score, required_score, results, cf, id, opts, if learn_spam then data[id].fann_train:train(learn_data, {1.0}) else - data[id].fann_train:train(learn_data, {0.0}) + data[id].fann_train:train(learn_data, {-1.0}) end data[id].ntrains = data[id].ntrains + 1 |