aboutsummaryrefslogtreecommitdiffstats
path: root/src/plugins
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2017-09-16 15:56:28 +0100
committerVsevolod Stakhov <vsevolod@highsecure.ru>2017-09-16 15:56:28 +0100
commit6576616d3fe319148a29dc6d0ef0222b24384fd0 (patch)
tree85c7a5437de2cda2875ab93d8fc921429f104797 /src/plugins
parentd999f2bf3873910f5160429b42885073324f6479 (diff)
downloadrspamd-6576616d3fe319148a29dc6d0ef0222b24384fd0.tar.gz
rspamd-6576616d3fe319148a29dc6d0ef0222b24384fd0.zip
[Fix] Fix ANN checks
Diffstat (limited to 'src/plugins')
-rw-r--r--src/plugins/lua/fann_redis.lua44
1 files changed, 22 insertions, 22 deletions
diff --git a/src/plugins/lua/fann_redis.lua b/src/plugins/lua/fann_redis.lua
index 2473fb290..f7ec65d30 100644
--- a/src/plugins/lua/fann_redis.lua
+++ b/src/plugins/lua/fann_redis.lua
@@ -43,6 +43,7 @@ local default_options = {
max_trains = 1000,
max_epoch = 1000,
max_usages = 10,
+ max_iterations = 25, -- Torch style
mse = 0.001,
autotrain = true,
},
@@ -331,19 +332,7 @@ local function is_fann_valid(rule, prefix, ann)
meta_functions.rspamd_count_metatokens()
if torch then
- local nlayers = #ann
- if nlayers ~= rule.nlayers then
- rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of layers: %s',
- prefix, nlayers)
- return false
- end
-
- local inp = ann:get(1):nElement()
- if n ~= inp then
- rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of inputs: %s, %s symbols' ..
- ' is found in the cache', prefix, inp, n)
- return false
- end
+ return true
else
if n ~= ann:get_inputs() then
rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of inputs: %s, %s symbols' ..
@@ -364,12 +353,13 @@ local function is_fann_valid(rule, prefix, ann)
end
local function fann_scores_filter(task)
- for _,rule in ipairs(settings.rules) do
- local id = rule.prefix .. '0'
+
+ for _,rule in pairs(settings.rules) do
+ local id = '0'
if rule.use_settings then
local sid = task:get_settings_id()
if sid then
- id = rule.prefix .. tostring(sid)
+ id = tostring(sid)
end
end
if rule.per_user then
@@ -481,6 +471,7 @@ local function load_or_invalidate_fann(rule, data, id, ev_base)
end
if is_fann_valid(rule, prefix, ann) then
+ if not fanns[id] then fanns[id] = {} end
fanns[id].fann = ann
rspamd_logger.infox(rspamd_config, 'loaded ANN %s version %s from redis',
prefix, ver)
@@ -627,6 +618,8 @@ local function train_fann(rule, _, ev_base, elt, worker)
if string.match(err, 'NOSCRIPT') then
load_scripts(rspamd_config, ev_base, nil)
end
+ else
+ rspamd_logger.infox(rspamd_config, 'saved ANN %s, key: %s_data', elt, prefix)
end
end
@@ -666,7 +659,7 @@ local function train_fann(rule, _, ev_base, elt, worker)
true, -- is write
redis_save_cb, --callback
'EVALSHA', -- command
- {redis_save_unlock_sha, '2', prefix, ann_data, tostring(rule.ann_expire)}
+ {redis_save_unlock_sha, '3', prefix, tostring(ann_data), tostring(rule.ann_expire)}
)
end
end
@@ -686,8 +679,8 @@ local function train_fann(rule, _, ev_base, elt, worker)
{prefix .. '_locked'}
)
else
- rspamd_logger.infox(rspamd_config, 'trained ANN %s',
- prefix)
+ rspamd_logger.infox(rspamd_config, 'trained ANN %s, %s bytes',
+ prefix, #data)
local ann_data
local f = torch.MemoryFile(torch.CharStorage():string(tostring(data)))
ann_data = rspamd_util.zstd_compress(f:storage():string())
@@ -703,7 +696,7 @@ local function train_fann(rule, _, ev_base, elt, worker)
true, -- is write
redis_save_cb, --callback
'EVALSHA', -- command
- {redis_save_unlock_sha, '2', prefix, ann_data, tostring(rule.ann_expire)}
+ {redis_save_unlock_sha, '3', prefix, tostring(ann_data), tostring(rule.ann_expire)}
)
end
end
@@ -780,6 +773,8 @@ local function train_fann(rule, _, ev_base, elt, worker)
local trainer = nn.StochasticGradient(fanns[elt].fann_train,
criterion)
trainer.learning_rate = 0.01
+ trainer.verbose = false
+ trainer.maxIteration = rule.train.max_iterations
trainer.hookIteration = function(self, iteration, currentError)
rspamd_logger.infox(rspamd_config, "learned %s iterations, error: %s",
iteration, currentError)
@@ -980,18 +975,23 @@ end
local function check_fanns(rule, _, ev_base)
local function members_cb(err, data)
if err then
- rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', err)
+ rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s',
+ err)
elseif type(data) == 'table' then
fun.each(function(elt)
elt = tostring(elt)
local redis_update_cb = function(_err, _data)
if _err then
- rspamd_logger.errx(rspamd_config, 'cannot get FANN version %s from redis: %s', elt, _err)
+ rspamd_logger.errx(rspamd_config, 'cannot get FANN version %s from redis: %s',
+ elt, _err)
if string.match(_err, 'NOSCRIPT') then
load_scripts(rspamd_config, ev_base, nil)
end
elseif _data and type(_data) == 'table' then
load_or_invalidate_fann(rule, _data, elt, ev_base)
+ else
+ rspamd_logger.errx(rspamd_config, 'invalid ANN type returned from Redis %s for prefix: %s',
+ type(_data), elt)
end
end