aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/plugins/lua/neural.lua20
1 files changed, 14 insertions, 6 deletions
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index 225c9895b..d7410b225 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -114,7 +114,10 @@ end
local redis_lua_script_vectors_len = [[
local prefix = KEYS[1]
local locked = redis.call('HGET', prefix, 'lock')
- if locked then return false end
+ if locked then
+ local host = redis.call('HGET', prefix, 'hostname')
+ return string.format('%s:%s', hostname, locked)
+ end
local nspam = 0
local nham = 0
@@ -547,10 +550,10 @@ local function ann_push_task_result(rule, task, verdict, score, set)
if err then
rspamd_logger.errx(task, 'cannot check if we can train %s:%s : %s',
rule.prefix, set.name, err)
- elseif type(data) == 'userdata' then
+ elseif type(data) == 'string' then
-- nil return value
- rspamd_logger.infox(task, "cannot learn %s ANN %s:%s; redis_key: %s: locked for learning",
- learn_type, rule.prefix, set.name, set.ann.redis_key)
+ rspamd_logger.infox(task, "cannot learn %s ANN %s:%s; redis_key: %s: locked for learning: %s",
+ learn_type, rule.prefix, set.name, set.ann.redis_key, data)
else
rspamd_logger.errx(task, 'cannot check if we can train %s:%s : type of Redis key %s is %s, expected table' ..
'please remove this key from Redis manually if you perform upgrade from the previous version',
@@ -647,6 +650,7 @@ local function fill_scatter(inputs)
inputs = rspamd_tensor.fromtable(inputs):transpose()
local meanv = inputs:mean()
+ lua_util.debugm(N, 'means: %s', meanv)
for i=1,nsamples do
local col = rspamd_tensor.new(1, #inputs)
@@ -662,6 +666,8 @@ local function fill_scatter(inputs)
end
end
+ lua_util.debugm(N, 'scatter matrix: %s', scatter_matrix)
+
return scatter_matrix
end
@@ -1004,7 +1010,7 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key)
{
ann_key,
tostring(os.time()),
- tostring(rule.watch_interval * 2),
+ tostring(math.max(10.0, rule.watch_interval * 2)),
rspamd_util.get_hostname()
})
end
@@ -1062,7 +1068,7 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
{set.prefix, tostring(rspamd_util.get_time()), profile_serialized}
)
rspamd_logger.infox(rspamd_config, 'loaded ANN for %s:%s from %s; %s bytes compressed; version=%s',
- rule.prefix, set.name, ann_key, #ann_data, profile.version)
+ rule.prefix, set.name, ann_key, #data[1], profile.version)
else
rspamd_logger.errx(rspamd_config, 'cannot unpack/deserialise ANN for %s:%s from Redis key %s',
rule.prefix, set.name, ann_key)
@@ -1079,6 +1085,8 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
if rule.max_inputs then
-- We can use PCA
set.ann.pca = rspamd_tensor.load(pca_data)
+ rspamd_logger.infox(rspamd_config, 'loaded PCA for ANN for %s:%s from %s; %s bytes compressed; version=%s',
+ rule.prefix, set.name, ann_key, #data[2], profile.version)
else
-- no need in pca, why is it there?
rspamd_logger.warnx(rspamd_config, 'extra PCA for ANN for %s:%s from Redis key %s: no max inputs defined',