]> source.dussan.org Git - rspamd.git/commitdiff
[Project] Neural: Start new NN profiles implementation
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 5 Jul 2019 17:46:30 +0000 (18:46 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 5 Jul 2019 17:46:30 +0000 (18:46 +0100)
src/plugins/lua/multimap.lua
src/plugins/lua/neural.lua

index 5db8d4680a82233f3485929ff66bbd1f42caacaa..9c4861e42d483ed339d8474369326452e339ad8a 100644 (file)
@@ -265,7 +265,7 @@ local function apply_addr_filter(task, filter, input, rule)
     end
   else
     -- regexp case
-  if not rule['re_filter'] then
+    if not rule['re_filter'] then
       local type,pat = string.match(filter, '(regexp:)(.+)')
       if type and pat then
         rule['re_filter'] = regexp.create(pat)
index dbd420257e5beb6ccf16466f290b82292acc9c38..ff53249c59ac2a998875af5b1447383ec3ec06d5 100644 (file)
@@ -25,6 +25,7 @@ local rspamd_kann = require "rspamd_kann"
 local lua_redis = require "lua_redis"
 local lua_util = require "lua_util"
 local fun = require "fun"
+local lua_settings = require "lua_settings"
 local meta_functions = require "lua_meta"
 local N = "neural"
 
@@ -41,10 +42,7 @@ local default_options = {
     learn_threads = 1,
     learning_rate = 0.01,
   },
-  use_settings = false,
-  per_user = false,
   watch_interval = 60.0,
-  nlayers = 4,
   lock_expire = 600,
   learning_spawned = false,
   ann_expire = 60 * 60 * 24 * 2, -- 2 days
@@ -53,7 +51,9 @@ local default_options = {
 }
 
 local settings = {
-  rules = {}
+  rules = {},
+  prefix = 'rn', -- Neural network default prefix
+  max_profiles = 3, -- Maximum number of NN profiles stored
 }
 
 local opts = rspamd_config:get_all_opt("neural")
@@ -124,20 +124,23 @@ local redis_lua_script_maybe_load = [[
 ]]
 local redis_maybe_load_id = nil
 
--- Lua script to invalidate ANN from redis
+-- Lua script to invalidate ANNs by rank
 -- Uses the following keys
 -- key1 - prefix for keys
+-- key2 - number of elements to leave
 local redis_lua_script_maybe_invalidate = [[
-  local locked = redis.call('GET', KEYS[1] .. '_locked')
-  if locked then return false end
-  redis.call('SET', KEYS[1] .. '_locked', '1')
-  redis.call('SET', KEYS[1] .. '_version', '0')
-  redis.call('DEL', KEYS[1] .. '_spam')
-  redis.call('DEL', KEYS[1] .. '_ham')
-  redis.call('DEL', KEYS[1] .. '_data')
-  redis.call('DEL', KEYS[1] .. '_locked')
-  redis.call('DEL', KEYS[1] .. '_hostname')
-  return 1
+  local card = redis.call('ZCARD', KEYS[1])
+  if card > tonumber(KEYS[2]) then
+    local to_delete = redis.call('ZRANGE', KEYS[1], 0, (-(tonumber(KEYS[2] - 1)))
+    for _,k in ipairs(to_delete) do
+      local tb = cjson.decode(k)
+      redis.call('DEL', tb.ann_key)
+    end
+    redis.call('ZREMRANGEBYRANK', KEYS[1], 0, (-(tonumber(KEYS[2] - 1)))
+    return to_delete
+  else
+    return {}
+  end
 ]]
 local redis_maybe_invalidate_id = nil
 
@@ -209,23 +212,6 @@ local function load_scripts(params)
     params)
 end
 
-local function gen_ann_prefix(rule, id)
-  local cksum = rspamd_config:get_symbols_cksum():hex()
-  -- We also need to count metatokens:
-  local n = meta_functions.rspamd_count_metatokens()
-  local tprefix = 'k'
-  if id then
-    return string.format('%s%s%s%d%s', tprefix, rule.prefix, cksum, n, id), id
-  else
-    return string.format('%s%s%s%d', tprefix, rule.prefix, cksum, n), nil
-  end
-end
-
-local function is_ann_valid(rule, prefix, ann)
-  if ann then
-    return true
-  end
-end
 
 local function ann_scores_filter(task)
 
@@ -732,57 +718,153 @@ local function maybe_train_anns(rule, cfg, ev_base, worker)
   return rule.watch_interval
 end
 
-local function check_anns(rule, _, ev_base)
-  local function members_cb(err, data)
-    if err then
-      rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s',
-        err)
-    elseif type(data) == 'table' then
-      fun.each(function(elt)
-        elt = tostring(elt)
-        local redis_update_cb = function(_err, _data)
-          if _err then
-            rspamd_logger.errx(rspamd_config, 'cannot get FANN version %s from redis: %s',
-              elt, _err)
-          elseif _data and type(_data) == 'table' then
-            load_or_invalidate_ann(rule, _data, elt, ev_base)
-          else
-            if type(_data) ~= 'number' then
-              rspamd_logger.errx(rspamd_config, 'invalid ANN type returned from Redis: %s; prefix: %s',
-                type(_data), elt)
-            end
-          end
+-- This function loads new ann from Redis
+-- This is based on `profile` attribute.
+-- ANN is loaded from `profile.ann_key`
+-- Rank of `profile` key is also increased, unfortunately, it means that we need to
+-- serialize profile one more time and set its rank to the current time
+-- set.ann fields are set according to Redis data received
+local function load_new_ann(rule, ev_base, set, profile, min_diff)
+
+end
+
+-- Used to check an element in Redis serialized as JSON
+-- for some specific rule + some specific setting
+-- This function tries to load more fresh or more specific ANNs in lieu of
+-- the existing ones.
+local function process_existing_ann(rule, ev_base, set, profiles)
+  local my_symbols = set.symbols
+  local min_diff = math.huge
+  local sel_elt
+
+  for _,elt in fun.iter(profiles) do
+    if elt and elt.symbols then
+      local dist = lua_util.distance_sorted(elt.symbols, my_symbols)
+
+      if dist < #my_symbols * .3 then
+        if dist < min_diff then
+          min_diff = dist
+          sel_elt = elt
         end
+      end
+    end
+  end
 
-        local local_ver = 0
-        if rule.anns[elt] then
-          if rule.anns[elt].version then
-            local_ver = rule.anns[elt].version
-          end
+  if sel_elt then
+    -- We can load element from ANN
+    if set.ann then
+      -- We have an existing ANN, probably the same...
+      if set.ann.digest == sel_elt.digest then
+        -- Same ANN, check version
+        if set.ann.version < sel_elt.version then
+          -- Load new ann
+          rspamd_logger.infox(rspamd_config, 'ann %s is changed,' ..
+              'our version = %s, remote version = %s',
+              rule.prefix .. ':' .. set.name,
+              set.ann.version,
+              sel_elt.version)
+          load_new_ann(rule, ev_base, set, sel_elt, min_diff)
+        else
+          lua_util.debugm(N, rspamd_config, 'ann %s is not changed,' ..
+              'our version = %s, remote version = %s',
+              rule.prefix .. ':' .. set.name,
+              set.ann.version,
+              sel_elt.version)
         end
-        lua_redis.exec_redis_script(redis_maybe_load_id,
-          {ev_base = ev_base, is_write = false},
-          redis_update_cb,
-          {gen_ann_prefix(rule, elt), tostring(local_ver)})
-      end,
-      data)
+      else
+        -- We have some different ANN, so we need to compare distance
+        if set.ann.distance > min_diff then
+          -- Load more specific ANN
+          rspamd_logger.infox(rspamd_config, 'more specific ann is available for %s,' ..
+              'our distance = %s, remote distance = %s',
+              rule.prefix .. ':' .. set.name,
+              set.ann.distance,
+              min_diff)
+          load_new_ann(rule, ev_base, set, sel_elt, min_diff)
+        else
+          lua_util.debugm(N, rspamd_config, 'ann %s is not changed or less specific,' ..
+              'our distance = %s, remote distance = %s',
+              rule.prefix .. ':' .. set.name,
+              set.ann.distance,
+              min_diff)
+        end
+      end
+    else
+      -- We have no ANN, load new one
+      load_new_ann(rule, ev_base, set, sel_elt, min_diff)
     end
   end
+end
 
-  -- First we need to get all anns stored in our Redis
-  lua_redis.redis_make_request_taskless(ev_base,
-    rspamd_config,
-    rule.redis,
-    nil,
-    false, -- is write
-    members_cb, --callback
-    'SMEMBERS', -- command
-    {gen_ann_prefix(rule, nil)} -- arguments
-  )
+-- Used to deserialise ANN element from a list
+local function load_ann_profile(element)
+  local ucl = require "ucl"
+
+  local parser = ucl.parser()
+  local res,ucl_err = parser:parse_string(element)
+  if not res then
+    rspamd_logger.warnx(rspamd_config, 'cannot parse ANN from redis: %s',
+        ucl_err)
+    return nil
+  else
+    return parser:get_object()
+  end
+end
+
+-- Function to check or load ANNs from Redis
+local function check_anns(rule, cfg, ev_base)
+  for _,set in pairs(rule.settings) do
+    local function members_cb(err, data)
+      if err then
+        rspamd_logger.errx(cfg, 'cannot get ANNs list from redis: %s',
+            err)
+      elseif type(data) == 'table' then
+        process_existing_ann(rule, ev_base, set, fun.map(load_ann_profile, data))
+      end
+    end
+
+    -- Extract all profiles for some specific settings id
+    -- Get the last `max_profiles` recently used
+    -- Select the most appropriate to our profile but it should not differ by more
+    -- than 30% of symbols
+    lua_redis.redis_make_request_taskless(ev_base,
+        cfg,
+        rule.redis,
+        nil,
+        false, -- is write
+        members_cb, --callback
+        'ZREVRANGE', -- command
+        {set.prefix, '0', tostring(settings.max_profiles)} -- arguments
+    )
+  end -- Cycle over all settings
 
   return rule.watch_interval
 end
 
+-- Function to clean up old ANNs
+local function cleanup_anns(rule, cfg, ev_base)
+  for _,set in pairs(rule.settings) do
+    local function invalidate_cb(err, data)
+      if err then
+        rspamd_logger.errx(cfg, 'cannot exec invalidate script in redis: %s',
+            err)
+      elseif type(data) == 'table' then
+        for _,expired in ipairs(data) do
+          local profile = load_ann_profile(expired)
+          rspamd_logger.infox(cfg, 'invalidated ANN for %s; redis key: %s; version=%s',
+              rule.prefix .. ':' .. set.name,
+              profile.ann_key,
+              profile.version)
+        end
+      end
+    end
+    lua_redis.exec_redis_script(redis_maybe_invalidate_id,
+        {ev_base = ev_base, is_write = true},
+        invalidate_cb,
+        {set.prefix, tostring(settings.max_profiles)})
+  end
+end
+
 local function ann_push_vector(task)
   if task:has_flag('skip') then return end
   if not settings.allow_local and lua_util.is_rspamc_or_controller(task) then return end
@@ -800,6 +882,83 @@ local function ann_push_vector(task)
   end
 end
 
+
+-- Generate redis prefix for specific rule and specific settings
+local function redis_ann_prefix(rule, settings_name)
+  -- We also need to count metatokens:
+  local n = meta_functions.version
+  return string.format('%s_%s_%d_%s',
+      settings.prefix, rule.prefix, n, settings_name)
+end
+
+-- This function is used to adjust profiles and allowed setting ids for each rule
+-- It must be called when all settings are already registered (e.g. at post-init for config)
+local function process_rules_settings()
+  local function process_settings_elt(rule, selt)
+    local profile = rule.profile[selt.name]
+    if profile then
+      -- Use static user defined profile
+      -- Ensure that we have an array...
+      lua_util.debugm(N, rspamd_config, "use static profile for %s (%s)",
+          rule.prefix, selt.name)
+      if not profile[1] then profile = lua_util.keys(profile) end
+      selt.symbols = profile
+    else
+      lua_util.debugm(N, rspamd_config, "use dynamic cfg based profile for %s (%s)",
+          rule.prefix, selt.name)
+    end
+
+    -- Generic stuff
+    table.sort(selt.symbols)
+    selt.digest = lua_util.table_digest(selt.symbols)
+    selt.prefix = redis_ann_prefix(rule, selt.name)
+
+    lua_redis.register_prefix(selt.prefix, N,
+        string.format('NN prefix for rule "%s"; settings id "%s"',
+            rule.prefix, selt.name))
+  end
+
+  for _,rule in pairs(opts.rules) do
+    if not rule.allowed_settings then
+      -- Extract all settings ids
+      rule.allowed_settings = lua_util.keys(lua_settings.all_settings)
+    end
+
+    -- Convert to a map <setting_id> -> true
+    rule.allowed_settings = lua_util.list_to_hash(rule.allowed_settings)
+
+    -- Check if we can work without settings
+    if type(rule.default) ~= 'boolean' then
+      rule.default = true
+    end
+
+    rule.settings = {}
+
+    if rule.default then
+      local default_settings = {
+        symbols = lua_util.keys(lua_settings.default_symbols),
+        name = 'default'
+      }
+
+      process_settings_elt(rule, default_settings)
+      rule.settings[-1] = default_settings -- Magic constant, but OK as settings are positive int32
+    end
+
+    -- Now, for each allowed settings, we store sorted symbols + digest
+    -- We set table rule.settings[id] -> { name = name, symbols = symbols, digest = digest }
+    for s,_ in pairs(rule.allowed_settings) do
+      -- Here, we have a name, set of symbols and
+      local selt = lua_settings.settings_by_id(s)
+      rule.settings[s] = {
+        symbols = selt.symbols, -- Already sorted
+        name = selt.name
+      }
+
+      process_settings_elt(rule, rule.settings[s])
+    end
+  end
+end
+
 redis_params = lua_redis.parse_redis_server('neural')
 
 if not redis_params then
@@ -818,7 +977,7 @@ local rules = opts['rules']
 if not rules then
   -- Use legacy configuration
   rules = {}
-  rules['RFANN'] = opts
+  rules['default'] = opts
 end
 
 local id = rspamd_config:register_symbol({
@@ -827,6 +986,7 @@ local id = rspamd_config:register_symbol({
   priority = 6,
   callback = ann_scores_filter
 })
+
 for k,r in pairs(rules) do
   local def_rules = lua_util.override_defaults(default_options, r)
   def_rules['redis'] = redis_params
@@ -841,6 +1001,7 @@ for k,r in pairs(rules) do
   if def_rules.train.max_train then
     def_rules.train.max_trains = def_rules.train.max_train
   end
+
   rspamd_logger.infox(rspamd_config, "register ann rule %s", k)
   settings.rules[k] = def_rules
   rspamd_config:set_metric_symbol({
@@ -876,8 +1037,11 @@ rspamd_config:register_symbol({
 })
 
 -- Add training scripts
-for _,rule in pairs(settings.rules) do
+for k,rule in pairs(settings.rules) do
   load_scripts(rule.redis)
+  -- We also need to deal with settings
+  rspamd_config:add_post_init(process_rules_settings)
+  -- This function will check ANNs in Redis when a worker is loaded
   rspamd_config:add_on_load(function(cfg, ev_base, worker)
     rspamd_config:add_periodic(ev_base, 0.0,
         function(_, _)
@@ -888,6 +1052,7 @@ for _,rule in pairs(settings.rules) do
       -- We also want to train neural nets when they have enough data
       rspamd_config:add_periodic(ev_base, 0.0,
           function(_, _)
+            cleanup_anns(rule, cfg, ev_base)
             return maybe_train_anns(rule, cfg, ev_base, worker)
           end)
     end