diff options
Diffstat (limited to 'src/plugins/lua/neural.lua')
-rw-r--r-- | src/plugins/lua/neural.lua | 31 |
1 files changed, 15 insertions, 16 deletions
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua index 5f9b1f9d2..dbd420257 100644 --- a/src/plugins/lua/neural.lua +++ b/src/plugins/lua/neural.lua @@ -469,26 +469,25 @@ local function train_ann(rule, _, ev_base, elt, worker) end end - local function ann_trained(errcode, errmsg, train_mse) + local function ann_trained(err, data) rule.learning_spawned = false - if errcode ~= 0 then + if err then rspamd_logger.errx(rspamd_config, 'cannot train ANN %s: %s', - prefix, errmsg) + prefix, err) lua_redis.redis_make_request_taskless(ev_base, - rspamd_config, - rule.redis, - nil, - true, -- is write - redis_unlock_cb, --callback - 'DEL', -- command - {prefix .. '_locked'} + rspamd_config, + rule.redis, + nil, + true, -- is write + redis_unlock_cb, --callback + 'DEL', -- command + {prefix .. '_locked'} ) else - rspamd_logger.infox(rspamd_config, 'trained ANN %s: MSE: %s', - prefix, train_mse) - local f = rule.anns[elt].ann_train:save() - local ann_data = rspamd_util.zstd_compress(f) - + rspamd_logger.infox(rspamd_config, 'trained ANN %s, %s bytes', + prefix, #data) + local ann_data = rspamd_util.zstd_compress(data) + rule.anns[elt].ann_train = rspamd_kann.load(data) rule.anns[elt].version = rule.anns[elt].version + 1 rule.anns[elt].ann = rule.anns[elt].ann_train rule.anns[elt].ann_train = nil @@ -575,7 +574,7 @@ local function train_ann(rule, _, ev_base, elt, worker) }) local out = rule.anns[elt].ann_train:save() - return tostring(out) + return out end rule.learning_spawned = true |