]> source.dussan.org Git - rspamd.git/commitdiff
[Project] Neural: Implement scoring
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 6 Jul 2019 11:41:20 +0000 (12:41 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 6 Jul 2019 11:41:20 +0000 (12:41 +0100)
src/plugins/lua/neural.lua

index cca6f647cf5466ba5c0a33b1ab4e63086f1b940e..fdb138321f97eae90f00e29ba67be4ea0df496eb 100644 (file)
@@ -50,6 +50,26 @@ local default_options = {
   symbol_ham = 'NEURAL_HAM',
 }
 
+-- Rule structure:
+-- * static config fields (see `default_options`)
+-- * prefix - name or defined prefix
+-- * settings - table of settings indexed by settings id, -1 is used when no settings defined
+
+-- Rule settings element defines elements for specific settings id:
+-- * symbols - static symbols profile (defined by config or extracted from symcache)
+-- * name - name of settings id
+-- * digest - digest of all symbols
+-- * ann - dynamic ANN configuration loaded from Redis
+-- * train - train data for ANN (e.g. the currently trained ANN)
+
+-- Settings ANN table is loaded from Redis and represents dynamic profile for ANN
+-- Some elements are directly stored in Redis, ANN is, in turn loaded dynamically
+-- * version - version of ANN loaded from redis
+-- * ann_key - name of ANN key in Redis
+-- * symbols - symbols in THIS PARTICULAR ANN (might be different from set.symbols)
+-- * distance - distance between set.symbols and set.ann.symbols
+-- * ann - kann object
+
 local settings = {
   rules = {},
   prefix = 'rn', -- Neural network default prefix
@@ -212,30 +232,75 @@ local function load_scripts(params)
     params)
 end
 
+local function result_to_vector(task, profile)
+  if not profile.zeros then
+    -- Fill zeros vector
+    local zeros = {}
+    for i=1,meta_functions.count_metatokens() do
+      zeros[i] = 0.0
+    end
+    for _,_ in ipairs(profile.symbols) do
+      zeros[#zeros + 1] = 0.0
+    end
+    profile.zeros = zeros
+  end
+
+  local vec = lua_util.shallowcopy(profile.zeros)
+  local mt = meta_functions.rspamd_gen_metatokens(task)
+
+  for i,v in ipairs(mt) do
+    vec[i] = v
+  end
+
+  task:process_ann_tokens(profile.symbols, vec, #mt)
+
+  return vec
+end
 
 local function ann_scores_filter(task)
 
   for _,rule in pairs(settings.rules) do
-    local id = '0'
-    if rule.use_settings then
-     local sid = task:get_settings_id()
-     if sid then
-      id = tostring(sid)
-     end
-    end
-    if rule.per_user then
-      local r = task:get_principal_recipient()
-      id = id .. r
+    local sid = task:get_settings_id()
+    local ann
+    local profile
+
+    if sid then
+      if rule.settings[sid] then
+        local set = rule.settings[sid]
+
+        if set.ann then
+          ann = set.ann.ann
+          profile = set.ann
+        else
+          lua_util.debugm(N, task, 'no ann loaded for %s:%s',
+              rule.prefix, set.name)
+        end
+      else
+        lua_util.debugm(N, task, 'no ann defined in %s for settings id %s',
+            rule.prefix, sid)
+      end
+    else
+      if rule.settings[-1] then
+        local set = rule.settings[-1]
+
+        if set.ann then
+          ann = set.ann.ann
+          profile = set.ann
+        else
+          lua_util.debugm(N, task, 'no ann loaded for %s:%s',
+              rule.prefix, set.name)
+        end
+      else
+        lua_util.debugm(N, task, 'no default ann for rule %s',
+            rule.prefix)
+      end
     end
 
-    if rule.anns[id] and rule.anns[id].ann then
-      local ann_data = task:get_symbols_tokens()
-      local mt = meta_functions.rspamd_gen_metatokens(task)
-      -- Add filtered meta tokens
-      fun.each(function(e) table.insert(ann_data, e) end, mt)
+    if ann then
+      local vec = result_to_vector(task, profile)
 
       local score
-      local out = rule.anns[id].ann:apply1(ann_data)
+      local out = ann:apply1(vec)
       score = out[1]
 
       local symscore = string.format('%.3f', score)
@@ -262,76 +327,6 @@ local function create_ann(n, nlayers)
   return rspamd_kann.new.kann(t)
 end
 
-local function create_train_ann(rule, n, id)
-  local prefix = gen_ann_prefix(rule, id)
-  if not rule.anns[id] then
-    rule.anns[id] = {}
-  end
-  -- Fix that for flexibe layers number
-  if rule.anns[id].ann then
-    if not is_ann_valid(rule, prefix, rule.anns[id].ann) then
-      rule.anns[id].ann_train = create_ann(n, rule.nlayers)
-      rule.anns[id].ann = nil
-      rspamd_logger.infox(rspamd_config, 'invalidate existing ANN, create train ANN %s', prefix)
-    elseif rule.train.max_usages > 0 and
-        rule.anns[id].version % rule.train.max_usages == 0 then
-      -- Forget last ann
-      rspamd_logger.infox(rspamd_config, 'recreate ANN %s, version %s', prefix,
-          rule.anns[id].version)
-      rule.anns[id].ann_train = create_ann(n, rule.nlayers)
-    else
-      rule.anns[id].ann_train = rule.anns[id].ann
-      rspamd_logger.infox(rspamd_config, 'reuse ANN for training %s', prefix)
-    end
-  else
-    rule.anns[id].ann_train = create_ann(n, rule.nlayers)
-    rspamd_logger.infox(rspamd_config, 'create train ANN %s', prefix)
-    rule.anns[id].version = 0
-  end
-end
-
-local function load_or_invalidate_ann(rule, data, id, ev_base)
-  local ver = data[2]
-  local prefix = gen_ann_prefix(rule, id)
-
-  if not ver or not tonumber(ver) then
-    rspamd_logger.errx(rspamd_config, 'cannot get version for ANN: %s', prefix)
-    return
-  end
-
-  local err,ann_data = rspamd_util.zstd_decompress(data[1])
-  local ann
-
-  if err or not ann_data then
-    rspamd_logger.errx(rspamd_config, 'cannot decompress ANN %s: %s', prefix, err)
-    return
-  else
-    ann = rspamd_kann.load(ann_data)
-  end
-
-  if is_ann_valid(rule, prefix, ann) then
-    if not rule.anns[id] then rule.anns[id] = {} end
-    rule.anns[id].ann = ann
-    rspamd_logger.infox(rspamd_config, 'loaded ANN %s version %s from redis',
-      prefix, ver)
-    rule.anns[id].version = tonumber(ver)
-  else
-    local function redis_invalidate_cb(_err, _data)
-      if _err then
-        rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', prefix, _err)
-      elseif type(_data) == 'string' then
-        rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', prefix, _err)
-        rule.anns[id].version = 0
-      end
-    end
-    -- Invalidate ANN
-    rspamd_logger.infox(rspamd_config, 'invalidate ANN %s', prefix)
-    lua_redis.exec_redis_script(redis_maybe_invalidate_id,
-      {ev_base = ev_base, is_write = true},
-      redis_invalidate_cb,
-      {prefix})
-  end
-end
 
 local function ann_train_callback(rule, task, score, required_score, id)
   local train_opts = rule['train']
@@ -901,6 +896,7 @@ local function cleanup_anns(rule, cfg, ev_base)
         end
       end
     end
+
     lua_redis.exec_redis_script(redis_maybe_invalidate_id,
         {ev_base = ev_base, is_write = true},
         invalidate_cb,
@@ -1095,6 +1091,7 @@ for k,rule in pairs(settings.rules) do
       -- We also want to train neural nets when they have enough data
       rspamd_config:add_periodic(ev_base, 0.0,
           function(_, _)
+            -- Clean old ANNs
             cleanup_anns(rule, cfg, ev_base)
             return maybe_train_anns(rule, cfg, ev_base, worker)
           end)