]> source.dussan.org Git - rspamd.git/commitdiff
[Minor] Fix train callback
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 4 Jul 2019 11:02:10 +0000 (12:02 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 4 Jul 2019 11:02:10 +0000 (12:02 +0100)
src/plugins/lua/neural.lua

index 5f9b1f9d2af9738c274d997fecf601a9959ae029..dbd420257e5beb6ccf16466f290b82292acc9c38 100644 (file)
@@ -469,26 +469,25 @@ local function train_ann(rule, _, ev_base, elt, worker)
     end
   end
 
-  local function ann_trained(errcode, errmsg, train_mse)
+  local function ann_trained(err, data)
     rule.learning_spawned = false
-    if errcode ~= 0 then
+    if err then
       rspamd_logger.errx(rspamd_config, 'cannot train ANN %s: %s',
-        prefix, errmsg)
+          prefix, err)
       lua_redis.redis_make_request_taskless(ev_base,
-        rspamd_config,
-        rule.redis,
-        nil,
-        true, -- is write
-        redis_unlock_cb, --callback
-        'DEL', -- command
-        {prefix .. '_locked'}
+          rspamd_config,
+          rule.redis,
+          nil,
+          true, -- is write
+          redis_unlock_cb, --callback
+          'DEL', -- command
+          {prefix .. '_locked'}
       )
     else
-      rspamd_logger.infox(rspamd_config, 'trained ANN %s: MSE: %s',
-          prefix, train_mse)
-      local f = rule.anns[elt].ann_train:save()
-      local ann_data = rspamd_util.zstd_compress(f)
-
+      rspamd_logger.infox(rspamd_config, 'trained ANN %s, %s bytes',
+          prefix, #data)
+      local ann_data = rspamd_util.zstd_compress(data)
+      rule.anns[elt].ann_train = rspamd_kann.load(data)
       rule.anns[elt].version = rule.anns[elt].version + 1
       rule.anns[elt].ann = rule.anns[elt].ann_train
       rule.anns[elt].ann_train = nil
@@ -575,7 +574,7 @@ local function train_ann(rule, _, ev_base, elt, worker)
           })
 
           local out = rule.anns[elt].ann_train:save()
-          return tostring(out)
+          return out
         end
 
         rule.learning_spawned = true