]> source.dussan.org Git - rspamd.git/commitdiff
[Project] Neural: Add storing vectors part
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 6 Jul 2019 14:49:08 +0000 (15:49 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 6 Jul 2019 14:49:08 +0000 (15:49 +0100)
src/plugins/lua/neural.lua

index f40778a7b1c26e7ace836bf10f4d4b43091860db..c2ffb3e15f9dc7e9658c105508d9c971dba7b3cf 100644 (file)
@@ -121,23 +121,6 @@ local redis_lua_script_can_train = [[
 ]]
 local redis_can_train_id = nil
 
--- Lua script to load ANN from redis
--- Uses the following keys
--- key1 - prefix for keys
--- key2 - local version
--- returns nil or bulk string if new ANN can be loaded
-local redis_lua_script_maybe_load = [[
-  local ver = 0
-  local ret = redis.call('GET', KEYS[1] .. '_version')
-  if ret then ver = tonumber(ret) end
-  if ver > tonumber(KEYS[2]) then
-    return {redis.call('GET', KEYS[1] .. '_data'), ret}
-  end
-
-  return tonumber(ret) or 0
-]]
-local redis_maybe_load_id = nil
-
 -- Lua script to invalidate ANNs by rank
 -- Uses the following keys
 -- key1 - prefix for keys
@@ -148,10 +131,10 @@ local redis_lua_script_maybe_invalidate = [[
     local to_delete = redis.call('ZRANGE', KEYS[1], 0, (-(tonumber(KEYS[2] - 1)))
     for _,k in ipairs(to_delete) do
       local tb = cjson.decode(k)
-      redis.call('DEL', tb.ann_key)
+      redis.call('DEL', tb.redis_key)
       -- Also train vectors
-      redis.call('DEL', tb.ann_key .. '_spam')
-      redis.call('DEL', tb.ann_key .. '_ham')
+      redis.call('DEL', tb.redis_key .. '_spam')
+      redis.call('DEL', tb.redis_key .. '_ham')
     end
     redis.call('ZREMRANGEBYRANK', KEYS[1], 0, (-(tonumber(KEYS[2] - 1)))
     return to_delete
@@ -217,8 +200,6 @@ local redis_params
 local function load_scripts(params)
   redis_can_train_id = lua_redis.add_redis_script(redis_lua_script_can_train,
     params)
-  redis_maybe_load_id = lua_redis.add_redis_script(redis_lua_script_maybe_load,
-    params)
   redis_maybe_invalidate_id = lua_redis.add_redis_script(redis_lua_script_maybe_invalidate,
     params)
   redis_locked_invalidate_id = lua_redis.add_redis_script(redis_lua_script_locked_invalidate,
@@ -254,6 +235,55 @@ local function result_to_vector(task, profile)
   return vec
 end
 
+-- Used to generate new ANN key for specific profile
+local function new_ann_key(rule, set)
+  local ann_key = string.format('%s_%s_%s_%s_nn', settings.prefix,
+      rule.prefix, set.name, set.digest:sub(1, 8))
+
+  return ann_key
+end
+
+-- Creates and stores ANN profile in Redis
+local function new_ann_profile(task, rule, set)
+  local ann_key = new_ann_key(rule, set)
+
+
+  local profile = {
+    symbols = set.symbols,
+    redis_key = ann_key,
+    version = 0,
+    digest = set.digest,
+    distance = 0 -- Since we are using our own profile
+  }
+
+  local ucl = require "ucl"
+  local profile_serialized = ucl.to_format(profile, 'json-compact')
+
+  local function add_cb(err, _)
+    if err then
+      rspamd_logger.errx(task, 'cannot check if we can train %s:%s : %s',
+          rule.prefix, set.name, err)
+    else
+      rspamd_logger.infox(task, 'created new ANN profile for %s:%s, data stored at prefix %s',
+          rule.prefix, set.name, profile.redis_key)
+    end
+  end
+
+  lua_redis.redis_make_request_taskless(ev_base,
+      rspamd_config,
+      rule.redis,
+      nil,
+      true, -- is write
+      add_cb, --callback
+      'ZADD', -- command
+      {set.prefix, profile_serialized, tostring(rspamd_util.get_time())}
+  )
+
+  return profile
+end
+
+
+-- ANN filter function, used to insert scores based on the existing symbols
 local function ann_scores_filter(task)
 
   for _,rule in pairs(settings.rules) do
@@ -389,7 +419,8 @@ local function ann_train_callback(rule, task, score, required_score, set)
         )
       else
         if err then
-          rspamd_logger.errx(task, 'cannot check if we can train %s: %s', fname, err)
+          rspamd_logger.errx(task, 'cannot check if we can train %s:%s : %s',
+              rule.prefix, set.name, err)
         elseif tonumber(data) < 0 then
           rspamd_logger.infox(task, "cannot learn ANN %s:%s: too many %s samples: %s",
             rule.prefix, set.name, learn_type, -tonumber(data))
@@ -399,6 +430,7 @@ local function ann_train_callback(rule, task, score, required_score, set)
 
     if not set.ann then
       -- Need to create or load a profile corresponding to the current configuration
+      set.ann = new_ann_profile(task, rule, set)
     end
     -- Check if we can learn
     lua_redis.exec_redis_script(redis_can_train_id,
@@ -704,12 +736,12 @@ end
 
 -- This function loads new ann from Redis
 -- This is based on `profile` attribute.
--- ANN is loaded from `profile.ann_key`
+-- ANN is loaded from `profile.redis_key`
 -- Rank of `profile` key is also increased, unfortunately, it means that we need to
 -- serialize profile one more time and set its rank to the current time
 -- set.ann fields are set according to Redis data received
 local function load_new_ann(rule, ev_base, set, profile, min_diff)
-  local ann_key = profile.ann_key
+  local ann_key = profile.redis_key
 
   local function data_cb(err, data)
     if err then
@@ -732,9 +764,25 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
             version = profile.version,
             symbols = profile.symbols,
             distance = min_diff,
-            redis_key = profile.ann_key
+            redis_key = profile.redis_key
           }
 
+          local ucl = require "ucl"
+          local profile_serialized = ucl.to_format(profile, 'json-compact')
+
+          local function rank_cb(_, _)
+            -- TODO: maybe add some logging
+          end
+          -- Also update rank for the loaded ANN to avoid removal
+          lua_redis.redis_make_request_taskless(ev_base,
+              rspamd_config,
+              rule.redis,
+              nil,
+              true, -- is write
+              rank_cb, --callback
+              'ZADD', -- command
+              {set.prefix, profile_serialized, tostring(rspamd_util.get_time())}
+          )
           rspamd_logger.infox(rspamd_config, 'loaded ANN for %s from %s; %s bytes compressed; version=%s',
               rule.prefix .. ':' .. set.name, ann_key, #ann_data, profile.version)
         else
@@ -881,7 +929,7 @@ local function cleanup_anns(rule, cfg, ev_base)
           local profile = load_ann_profile(expired)
           rspamd_logger.infox(cfg, 'invalidated ANN for %s; redis key: %s; version=%s',
               rule.prefix .. ':' .. set.name,
-              profile.ann_key,
+              profile.redis_key,
               profile.version)
         end
       end
@@ -941,7 +989,7 @@ local function process_rules_settings()
 
     lua_redis.register_prefix(selt.prefix, N,
         string.format('NN prefix for rule "%s"; settings id "%s"',
-            rule.prefix, selt.name))
+            rule.prefix, selt.name), {persistent = true})
   end
 
   for _,rule in pairs(opts.rules) do
@@ -1063,7 +1111,7 @@ rspamd_config:register_symbol({
 })
 
 -- Add training scripts
-for k,rule in pairs(settings.rules) do
+for _,rule in pairs(settings.rules) do
   load_scripts(rule.redis)
   -- We also need to deal with settings
   rspamd_config:add_post_init(process_rules_settings)