]> source.dussan.org Git - rspamd.git/commitdiff
[Project] Add ANN load function
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 6 Jul 2019 08:39:48 +0000 (09:39 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 6 Jul 2019 08:39:48 +0000 (09:39 +0100)
src/plugins/lua/neural.lua

index ff53249c59ac2a998875af5b1447383ec3ec06d5..cca6f647cf5466ba5c0a33b1ab4e63086f1b940e 100644 (file)
@@ -725,7 +725,50 @@ end
 -- serialize profile one more time and set its rank to the current time
 -- set.ann fields are set according to Redis data received
 local function load_new_ann(rule, ev_base, set, profile, min_diff)
+  local ann_key = profile.ann_key
 
+  local function data_cb(err, data)
+    if err then
+      rspamd_logger.errx(rspamd_config, 'cannot get ANN data from key: %s; %s',
+          ann_key, err)
+    else
+      local _err,ann_data = rspamd_util.zstd_decompress(data[1])
+      local ann
+
+      if _err or not ann_data then
+        rspamd_logger.errx(rspamd_config, 'cannot decompress ANN for %s from Redis key %s: %s',
+            rule.prefix .. ':' .. set.name, ann_key, _err)
+        return
+      else
+        ann = rspamd_kann.load(ann_data)
+
+        if ann then
+          set.ann = {
+            ann = ann,
+            version = profile.version,
+            symbols = profile.symbols,
+            distance = min_diff
+          }
+
+          rspamd_logger.infox(rspamd_config, 'loaded ANN for %s from %s; %s bytes compressed; version=%s',
+              rule.prefix .. ':' .. set.name, ann_key, #ann_data, profile.version)
+        else
+          rspamd_logger.errx(rspamd_config, 'cannot deserialize ANN for %s from Redis key %s',
+              rule.prefix .. ':' .. set.name, ann_key)
+        end
+      end
+    end
+  end
+  lua_redis.redis_make_request_taskless(ev_base,
+      rspamd_config,
+      rule.redis,
+      nil,
+      false, -- is write
+      data_cb, --callback
+      'HGET', -- command
+      {ann_key, 'ann'}, -- arguments
+      {opaque_data = true}
+  )
 end
 
 -- Used to check an element in Redis serialized as JSON
@@ -740,7 +783,7 @@ local function process_existing_ann(rule, ev_base, set, profiles)
   for _,elt in fun.iter(profiles) do
     if elt and elt.symbols then
       local dist = lua_util.distance_sorted(elt.symbols, my_symbols)
-
+      -- Check distance
       if dist < #my_symbols * .3 then
         if dist < min_diff then
           min_diff = dist