aboutsummaryrefslogtreecommitdiffstats
path: root/src/plugins/lua/neural.lua
diff options
context:
space:
mode:
Diffstat (limited to 'src/plugins/lua/neural.lua')
-rw-r--r--src/plugins/lua/neural.lua31
1 files changed, 15 insertions, 16 deletions
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index 5f9b1f9d2..dbd420257 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -469,26 +469,25 @@ local function train_ann(rule, _, ev_base, elt, worker)
end
end
- local function ann_trained(errcode, errmsg, train_mse)
+ local function ann_trained(err, data)
rule.learning_spawned = false
- if errcode ~= 0 then
+ if err then
rspamd_logger.errx(rspamd_config, 'cannot train ANN %s: %s',
- prefix, errmsg)
+ prefix, err)
lua_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
- rule.redis,
- nil,
- true, -- is write
- redis_unlock_cb, --callback
- 'DEL', -- command
- {prefix .. '_locked'}
+ rspamd_config,
+ rule.redis,
+ nil,
+ true, -- is write
+ redis_unlock_cb, --callback
+ 'DEL', -- command
+ {prefix .. '_locked'}
)
else
- rspamd_logger.infox(rspamd_config, 'trained ANN %s: MSE: %s',
- prefix, train_mse)
- local f = rule.anns[elt].ann_train:save()
- local ann_data = rspamd_util.zstd_compress(f)
-
+ rspamd_logger.infox(rspamd_config, 'trained ANN %s, %s bytes',
+ prefix, #data)
+ local ann_data = rspamd_util.zstd_compress(data)
+ rule.anns[elt].ann_train = rspamd_kann.load(data)
rule.anns[elt].version = rule.anns[elt].version + 1
rule.anns[elt].ann = rule.anns[elt].ann_train
rule.anns[elt].ann_train = nil
@@ -575,7 +574,7 @@ local function train_ann(rule, _, ev_base, elt, worker)
})
local out = rule.anns[elt].ann_train:save()
- return tostring(out)
+ return out
end
rule.learning_spawned = true