]> source.dussan.org Git - rspamd.git/commitdiff
[Project] Neural: Add PCA loading logic
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 27 Aug 2020 14:35:42 +0000 (15:35 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 27 Aug 2020 14:35:42 +0000 (15:35 +0100)
src/plugins/lua/neural.lua

index a3027662c43c1f3c93db544b430b34832abdc025..352d397d570f4033f441db4a01883130421008e9 100644 (file)
@@ -22,6 +22,7 @@ end
 local rspamd_logger = require "rspamd_logger"
 local rspamd_util = require "rspamd_util"
 local rspamd_kann = require "rspamd_kann"
+local rspamd_text = require "rspamd_text"
 local lua_redis = require "lua_redis"
 local lua_util = require "lua_util"
 local rspamd_tensor = require "rspamd_tensor"
@@ -71,6 +72,7 @@ local redis_profile_schema = ts.shape{
 }
 
 local has_blas = rspamd_tensor.has_blas()
+local text_cookie = rspamd_text.cookie
 
 -- Rule structure:
 -- * static config fields (see `default_options`)
@@ -327,7 +329,7 @@ local function ann_scores_filter(task)
       local vec = result_to_vector(task, profile)
 
       local score
-      local out = ann:apply1(vec)
+      local out = ann:apply1(vec, set.ann.pca)
       score = out[1]
 
       local symscore = string.format('%.3f', score)
@@ -940,52 +942,81 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
       rspamd_logger.errx(rspamd_config, 'cannot get ANN data from key: %s; %s',
           ann_key, err)
     else
-      if type(data) == 'string' then
-        local _err,ann_data = rspamd_util.zstd_decompress(data)
-        local ann
-
-        if _err or not ann_data then
-          rspamd_logger.errx(rspamd_config, 'cannot decompress ANN for %s from Redis key %s: %s',
-              rule.prefix .. ':' .. set.name, ann_key, _err)
-          return
+      if type(data) == 'table' then
+        if type(data[1]) == 'userdata' and data[1].cookie == text_cookie then
+          local _err,ann_data = rspamd_util.zstd_decompress(data[1])
+          local ann
+
+          if _err or not ann_data then
+            rspamd_logger.errx(rspamd_config, 'cannot decompress ANN for %s from Redis key %s: %s',
+                rule.prefix .. ':' .. set.name, ann_key, _err)
+            return
+          else
+            ann = rspamd_kann.load(ann_data)
+
+            if ann then
+              set.ann = {
+                digest = profile.digest,
+                version = profile.version,
+                symbols = profile.symbols,
+                distance = min_diff,
+                redis_key = profile.redis_key
+              }
+
+              local ucl = require "ucl"
+              local profile_serialized = ucl.to_format(profile, 'json-compact', true)
+              set.ann.ann = ann -- To avoid serialization
+
+              local function rank_cb(_, _)
+                -- TODO: maybe add some logging
+              end
+              -- Also update rank for the loaded ANN to avoid removal
+              lua_redis.redis_make_request_taskless(ev_base,
+                  rspamd_config,
+                  rule.redis,
+                  nil,
+                  true, -- is write
+                  rank_cb, --callback
+                  'ZADD', -- command
+                  {set.prefix, tostring(rspamd_util.get_time()), profile_serialized}
+              )
+              rspamd_logger.infox(rspamd_config, 'loaded ANN for %s:%s from %s; %s bytes compressed; version=%s',
+                  rule.prefix, set.name, ann_key, #ann_data, profile.version)
+            else
+              rspamd_logger.errx(rspamd_config, 'cannot unpack/deserialise ANN for %s:%s from Redis key %s',
+                  rule.prefix, set.name, ann_key)
+            end
+          end
         else
-          ann = rspamd_kann.load(ann_data)
-
-          if ann then
-            set.ann = {
-              digest = profile.digest,
-              version = profile.version,
-              symbols = profile.symbols,
-              distance = min_diff,
-              redis_key = profile.redis_key
-            }
-
-            local ucl = require "ucl"
-            local profile_serialized = ucl.to_format(profile, 'json-compact', true)
-            set.ann.ann = ann -- To avoid serialization
-
-            local function rank_cb(_, _)
-              -- TODO: maybe add some logging
+          lua_util.debugm(N, rspamd_config, 'missing ANN for %s:%s in Redis key %s',
+              rule.prefix, set.name, ann_key)
+        end
+        if set.ann.ann and type(data[2]) == 'userdata' and data[2].cookie == text_cookie then
+          -- PCA table
+          local _err,pca_data = rspamd_util.zstd_decompress(data[2])
+          if pca_data then
+            if rule.max_inputs then
+              -- We can use PCA
+              set.ann.pca = rspamd_tensor.load(pca_data)
+            else
+              -- no need in pca, why is it there?
+              rspamd_logger.warnx(rspamd_config, 'extra PCA for ANN for %s:%s from Redis key %s: no max inputs defined',
+                  rule.prefix, set.name, ann_key)
             end
-            -- Also update rank for the loaded ANN to avoid removal
-            lua_redis.redis_make_request_taskless(ev_base,
-                rspamd_config,
-                rule.redis,
-                nil,
-                true, -- is write
-                rank_cb, --callback
-                'ZADD', -- command
-                {set.prefix, tostring(rspamd_util.get_time()), profile_serialized}
-            )
-            rspamd_logger.infox(rspamd_config, 'loaded ANN for %s:%s from %s; %s bytes compressed; version=%s',
-                rule.prefix, set.name, ann_key, #ann_data, profile.version)
           else
-            rspamd_logger.errx(rspamd_config, 'cannot deserialize ANN for %s:%s from Redis key %s',
-                rule.prefix, set.name, ann_key)
+            -- pca can be missing merely if we have no max_inputs
+            if rule.max_inputs then
+              rspamd_logger.errx(rspamd_config, 'cannot unpack/deserialise ANN for %s:%s from Redis key %s: no PCA: %s',
+                  rule.prefix, set.name, ann_key, _err)
+              set.ann.ann = nil
+            else
+              -- It is okay
+              set.ann.pca = nil
+            end
           end
         end
       else
-        lua_util.debugm(N, rspamd_config, 'no ANN for %s:%s in Redis key %s',
+        lua_util.debugm(N, rspamd_config, 'no ANN key for %s:%s in Redis key %s',
             rule.prefix, set.name, ann_key)
       end
     end
@@ -996,8 +1027,9 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
       nil,
       false, -- is write
       data_cb, --callback
-      'HGET', -- command
-      {ann_key, 'ann'} -- arguments
+      'HMGET', -- command
+      {ann_key, 'ann', 'pca'}, -- arguments
+      {opaque_data = true}
   )
 end