]> source.dussan.org Git - rspamd.git/commitdiff
[Feature] Rework fann plugin to be a normal post-filter
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 29 Jul 2017 14:23:39 +0000 (15:23 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 29 Jul 2017 14:23:39 +0000 (15:23 +0100)
src/plugins/lua/fann_redis.lua

index 2c3de9ddcda06c4a13f136a70df7436973da5223..dbb4955effda6b36f4b251fc3b8602adcb72afdd 100644 (file)
@@ -28,6 +28,7 @@ local fann_symbol_spam = 'FANNR_SPAM'
 local fann_symbol_ham = 'FANNR_HAM'
 local rspamd_redis = require "lua_redis"
 local fun = require "fun"
+local meta_functions = require "meta_functions"
 
 local module_log_id = 0x200
 -- Module vars
@@ -286,26 +287,6 @@ local function load_scripts(cfg, ev_base, on_load_cb)
   )
 end
 
-local function symbols_to_fann_vector(syms, scores)
-  local learn_data = {}
-  local matched_symbols = {}
-  local n = rspamd_config:get_symbols_count()
-
-  fun.each(function(s, score)
-     matched_symbols[s + 1] = rspamd_util.tanh(score)
-  end, fun.zip(syms, scores))
-
-  for i=1,n do
-    if matched_symbols[i] then
-      learn_data[i] = matched_symbols[i]
-    else
-      learn_data[i] = 0
-    end
-  end
-
-  return learn_data
-end
-
 local function gen_fann_prefix(id)
   if id then
     return fann_prefix .. rspamd_config:get_symbols_cksum():hex() .. id,id
@@ -345,13 +326,10 @@ local function fann_scores_filter(task)
   end
 
   if fanns[id].fann then
-    local symbols,scores = task:get_symbols_numeric()
-    local fann_data = symbols_to_fann_vector(symbols, scores)
-    local mt = rspamd_gen_metatokens(task)
-
-    for _,tok in ipairs(mt) do
-      table.insert(fann_data, tok)
-    end
+    local fann_data = task:get_symbols_tokens()
+    local mt = meta_functions.rspamd_gen_metatokens(task)
+    -- Add filtered meta tokens
+    fun.each(function(e) table.insert(fann_data, e) end, mt)
 
     local out = fanns[id].fann:test(fann_data)
     local symscore = string.format('%.3f', out[1])
@@ -445,7 +423,7 @@ local function load_or_invalidate_fann(data, id, ev_base)
   end
 end
 
-local function fann_train_callback(score, required_score, results, _, id, opts, extra, ev_base)
+local function fann_train_callback(task, score, required_score, id, opts)
   local fname,suffix = gen_fann_prefix(id)
 
   local learn_spam, learn_ham
@@ -473,13 +451,11 @@ local function fann_train_callback(score, required_score, results, _, id, opts,
 
     local function can_train_cb(err, data)
       if not err and tonumber(data) > 0 then
-        local learn_data = symbols_to_fann_vector(
-          fun.map(function(r) return r[1] end, results),
-          fun.map(function(r) return r[2] end, results)
-        )
+        local fann_data = task:get_symbols_tokens()
+        local mt = meta_functions.rspamd_gen_metatokens(task)
         -- Add filtered meta tokens
-        fun.each(function(e) table.insert(learn_data, e) end, extra)
-        local str = rspamd_util.zstd_compress(table.concat(learn_data, ';'))
+        fun.each(function(e) table.insert(fann_data, e) end, mt)
+        local str = rspamd_util.zstd_compress(table.concat(fann_data, ';'))
 
         rspamd_redis.redis_make_request_taskless(ev_base,
           rspamd_config,
@@ -500,14 +476,14 @@ local function fann_train_callback(score, required_score, results, _, id, opts,
       end
     end
 
-    rspamd_redis.redis_make_request_taskless(ev_base,
-      rspamd_config,
+    rspamd_redis.rspamd_redis_make_request(task,
       redis_params,
       nil,
       true, -- is write
       can_train_cb, --callback
       'EVALSHA', -- command
-      {redis_can_train_sha, '4', gen_fann_prefix(nil), suffix, k, tostring(max_trains)} -- arguments
+      {redis_can_train_sha, '4', gen_fann_prefix(nil),
+        suffix, k, tostring(max_trains)} -- arguments
     )
   end
 end
@@ -857,6 +833,18 @@ local function check_fanns(_, ev_base)
   return watch_interval
 end
 
+local function ann_push_vector(task)
+  local scores = task:get_metric_score()
+  local sid = task:get_settings_id()
+  if use_settings then
+    fann_train_callback(task, scores[1], scores[2],
+      tostring(sid), opts['train'])
+  else
+    fann_train_callback(task, scores[1], scores[2],
+      tostring(sid), opts['train'])
+  end
+end
+
 -- Initialization part
 
 local opts = rspamd_config:get_all_opt("fann_redis")
@@ -892,7 +880,7 @@ else
   local id = rspamd_config:register_symbol({
     name = fann_symbol_spam,
     type = 'postfilter',
-    priority = 5,
+    priority = 6,
     callback = fann_scores_filter
   })
   rspamd_config:set_metric_symbol({
@@ -907,45 +895,24 @@ else
     parent = id
   })
   if opts['train'] then
-    rspamd_config:add_on_load(function(cfg)
-      if opts['train']['max_train'] then
-        max_trains = opts['train']['max_train']
-      end
-      if opts['train']['max_epoch'] then
-        max_epoch = opts['train']['max_epoch']
-      end
-      if opts['train']['max_usages'] then
-        max_usages = opts['train']['max_usages']
-      end
-      if opts['train']['mse'] then
-        mse = opts['train']['mse']
-      end
-      local ret = cfg:register_worker_script("log_helper",
-        function(score, req_score, results, cf, _id, extra, ev_base)
-          -- fun.map (snd x) (fun.filter (fst x == module_id) extra)
-          local extra_fann = fun.map(function(e) return e[2] end,
-            fun.filter(function(e) return e[1] == module_log_id end, extra))
-          if use_settings then
-            fann_train_callback(score, req_score, results, cf,
-              tostring(_id), opts['train'], extra_fann, ev_base)
-          else
-            fann_train_callback(score, req_score, results, cf, '0',
-              opts['train'], extra_fann, ev_base)
-          end
-        end)
-
-      if not ret then
-        rspamd_logger.errx(cfg, 'cannot find worker "log_helper"')
-      end
-    end)
-    -- This is needed to pass extra tokens from worker to log_helper
-    rspamd_plugins["fann_redis"] = {
-      log_callback = function(task)
-        return fun.totable(fun.map(
-          function(tok) return {module_log_id, tok} end,
-          rspamd_gen_metatokens(task)))
-      end
-    }
+    if opts['train']['max_train'] then
+      max_trains = opts['train']['max_train']
+    end
+    if opts['train']['max_epoch'] then
+      max_epoch = opts['train']['max_epoch']
+    end
+    if opts['train']['max_usages'] then
+      max_usages = opts['train']['max_usages']
+    end
+    if opts['train']['mse'] then
+      mse = opts['train']['mse']
+    end
+    rspamd_config:register_symbol({
+      name = 'FANN_VECTOR_PUSH',
+      type = 'postfilter',
+      priority = 5,
+      callback = ann_push_vector
+    })
   end
   -- Add training scripts
   rspamd_config:add_on_load(function(cfg, ev_base, worker)