end
end
- local function ann_trained(errcode, errmsg, train_mse)
+ local function ann_trained(err, data)
rule.learning_spawned = false
- if errcode ~= 0 then
+ if err then
rspamd_logger.errx(rspamd_config, 'cannot train ANN %s: %s',
- prefix, errmsg)
+ prefix, err)
lua_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
- rule.redis,
- nil,
- true, -- is write
- redis_unlock_cb, --callback
- 'DEL', -- command
- {prefix .. '_locked'}
+ rspamd_config,
+ rule.redis,
+ nil,
+ true, -- is write
+ redis_unlock_cb, --callback
+ 'DEL', -- command
+ {prefix .. '_locked'}
)
else
- rspamd_logger.infox(rspamd_config, 'trained ANN %s: MSE: %s',
- prefix, train_mse)
- local f = rule.anns[elt].ann_train:save()
- local ann_data = rspamd_util.zstd_compress(f)
-
+ rspamd_logger.infox(rspamd_config, 'trained ANN %s, %s bytes',
+ prefix, #data)
+ local ann_data = rspamd_util.zstd_compress(data)
+ rule.anns[elt].ann_train = rspamd_kann.load(data)
rule.anns[elt].version = rule.anns[elt].version + 1
rule.anns[elt].ann = rule.anns[elt].ann_train
rule.anns[elt].ann_train = nil
})
local out = rule.anns[elt].ann_train:save()
- return tostring(out)
+ return out
end
rule.learning_spawned = true