summaryrefslogtreecommitdiffstats
path: root/src/plugins/lua/fann_scores.lua
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2016-10-15 13:34:22 +0100
committerVsevolod Stakhov <vsevolod@highsecure.ru>2016-10-15 13:34:42 +0100
commitc5e504950a444472e7772f2cc4a18ee9f81b6bd8 (patch)
treedf9e61fefb19bf727103ccc46c0ea36dcfffe0c3 /src/plugins/lua/fann_scores.lua
parenta903f7be21746c3a803d0a47dad4c0716584eaa7 (diff)
downloadrspamd-c5e504950a444472e7772f2cc4a18ee9f81b6bd8.tar.gz
rspamd-c5e504950a444472e7772f2cc4a18ee9f81b6bd8.zip
[Feature] Use more layers for fann and another normalization
Diffstat (limited to 'src/plugins/lua/fann_scores.lua')
-rw-r--r--src/plugins/lua/fann_scores.lua18
1 files changed, 13 insertions, 5 deletions
diff --git a/src/plugins/lua/fann_scores.lua b/src/plugins/lua/fann_scores.lua
index c1c3d80c0..9647fd3d3 100644
--- a/src/plugins/lua/fann_scores.lua
+++ b/src/plugins/lua/fann_scores.lua
@@ -335,6 +335,13 @@ local function check_fann(id)
' is found in the cache', data[id].fann:get_inputs(), n)
data[id].fann = nil
end
+ local layers = data[id].fann:get_layers()
+
+ if not layers or #layers ~= 5 then
+ rspamd_logger.infox(rspamd_config, 'fann has incorrect number of layers: %s',
+ #layers)
+ data[id].fann = nil
+ end
end
local fname = gen_fann_file(id)
@@ -373,14 +380,15 @@ local function fann_scores_filter(task)
end
local out = data[id].fann:test(fann_data)
- local result = rspamd_util.tanh(2 * (out[1] - 0.5))
local symscore = string.format('%.3f', out[1])
rspamd_logger.infox(task, 'fann score: %s', symscore)
- if result > 0 then
+ if out[1] > 0 then
+ local result = rspamd_util.normalize_prob(out[1] / 2.0, 0)
task:insert_result(fann_symbol_spam, result, symscore, id)
else
- task:insert_result(fann_symbol_ham, -(result), symscore, id)
+ local result = rspamd_util.normalize_prob((-out[1]) / 2.0, 0)
+ task:insert_result(fann_symbol_ham, result, symscore, id)
end
else
if load_fann(id) then
@@ -390,7 +398,7 @@ local function fann_scores_filter(task)
end
local function create_train_fann(n, id)
- data[id].fann_train = rspamd_fann.create(3, n, n / 2, 1)
+ data[id].fann_train = rspamd_fann.create(5, n, n, n / 2, n / 4, 1)
data[id].ntrains = 0
data[id].epoch = 0
end
@@ -480,7 +488,7 @@ local function fann_train_callback(score, required_score, results, cf, id, opts,
if learn_spam then
data[id].fann_train:train(learn_data, {1.0})
else
- data[id].fann_train:train(learn_data, {0.0})
+ data[id].fann_train:train(learn_data, {-1.0})
end
data[id].ntrains = data[id].ntrains + 1