Browse Source

[Project] Neural: Further PCA fixes

tags/2.6
Vsevolod Stakhov 3 years ago
parent
commit
f3d8644e63
1 changed file with 14 additions and 6 deletions
  1. 14
    6
      src/plugins/lua/neural.lua

+ 14
- 6
src/plugins/lua/neural.lua View File

@@ -114,7 +114,10 @@ end
local redis_lua_script_vectors_len = [[
local prefix = KEYS[1]
local locked = redis.call('HGET', prefix, 'lock')
if locked then return false end
if locked then
local host = redis.call('HGET', prefix, 'hostname')
return string.format('%s:%s', hostname, locked)
end
local nspam = 0
local nham = 0

@@ -547,10 +550,10 @@ local function ann_push_task_result(rule, task, verdict, score, set)
if err then
rspamd_logger.errx(task, 'cannot check if we can train %s:%s : %s',
rule.prefix, set.name, err)
elseif type(data) == 'userdata' then
elseif type(data) == 'string' then
-- nil return value
rspamd_logger.infox(task, "cannot learn %s ANN %s:%s; redis_key: %s: locked for learning",
learn_type, rule.prefix, set.name, set.ann.redis_key)
rspamd_logger.infox(task, "cannot learn %s ANN %s:%s; redis_key: %s: locked for learning: %s",
learn_type, rule.prefix, set.name, set.ann.redis_key, data)
else
rspamd_logger.errx(task, 'cannot check if we can train %s:%s : type of Redis key %s is %s, expected table' ..
'please remove this key from Redis manually if you perform upgrade from the previous version',
@@ -647,6 +650,7 @@ local function fill_scatter(inputs)
inputs = rspamd_tensor.fromtable(inputs):transpose()

local meanv = inputs:mean()
lua_util.debugm(N, 'means: %s', meanv)

for i=1,nsamples do
local col = rspamd_tensor.new(1, #inputs)
@@ -662,6 +666,8 @@ local function fill_scatter(inputs)
end
end

lua_util.debugm(N, 'scatter matrix: %s', scatter_matrix)

return scatter_matrix
end

@@ -1004,7 +1010,7 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key)
{
ann_key,
tostring(os.time()),
tostring(rule.watch_interval * 2),
tostring(math.max(10.0, rule.watch_interval * 2)),
rspamd_util.get_hostname()
})
end
@@ -1062,7 +1068,7 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
{set.prefix, tostring(rspamd_util.get_time()), profile_serialized}
)
rspamd_logger.infox(rspamd_config, 'loaded ANN for %s:%s from %s; %s bytes compressed; version=%s',
rule.prefix, set.name, ann_key, #ann_data, profile.version)
rule.prefix, set.name, ann_key, #data[1], profile.version)
else
rspamd_logger.errx(rspamd_config, 'cannot unpack/deserialise ANN for %s:%s from Redis key %s',
rule.prefix, set.name, ann_key)
@@ -1079,6 +1085,8 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
if rule.max_inputs then
-- We can use PCA
set.ann.pca = rspamd_tensor.load(pca_data)
rspamd_logger.infox(rspamd_config, 'loaded PCA for ANN for %s:%s from %s; %s bytes compressed; version=%s',
rule.prefix, set.name, ann_key, #data[2], profile.version)
else
-- no need in pca, why is it there?
rspamd_logger.warnx(rspamd_config, 'extra PCA for ANN for %s:%s from Redis key %s: no max inputs defined',

Loading…
Cancel
Save