]> source.dussan.org Git - rspamd.git/commitdiff
[Rework] Add preliminary train tests
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 4 Nov 2016 15:56:16 +0000 (15:56 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 4 Nov 2016 15:56:16 +0000 (15:56 +0000)
src/plugins/lua/fann_scores.lua

index 3c46cda2f1a1951337f65283e8db6d02fb8cacd6..7bc55117dcd4ad849c5e81dd24b48b0092581226 100644 (file)
@@ -520,7 +520,6 @@ local function fann_train_callback(score, required_score, results, cf, id, opts,
     end
 
     local function can_train_cb(err, data)
-      rspamd_logger.errx('data: %s, err: %s', data, err)
       if not err and tonumber(data) > 0 then
         local learn_data = symbols_to_fann_vector(
           map(function(r) return r[1] end, results),
@@ -556,6 +555,63 @@ local function fann_train_callback(score, required_score, results, cf, id, opts,
   end
 end
 
+local function train_fann(cfg, ev_base, elt)
+
+end
+
+local function maybe_train_fanns(cfg, ev_base)
+  local function members_cb(err, data)
+    if err then
+      rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', err)
+    elseif type(data) == 'table' then
+      each(function(i, elt)
+        local redis_len_cb = function(err, data)
+          if err then
+            rspamd_logger.errx(rspamd_config, 'cannot get FANN trains %s from redis: %s', elt, err)
+          elseif data and type(data) == 'number' or type(data) == 'string' then
+            if tonumber(data) and tonumber(data) > max_trains then
+              train_fann(cfg, ev_base, elt)
+            end
+          end
+        end
+
+        local local_ver = 0
+        local numelt = tonumber(elt)
+        if data[numelt] then
+          if data[numelt].version then
+            local_ver = data[numelt].version
+          end
+        end
+        redis_make_request(ev_base,
+          rspamd_config,
+          nil,
+          false, -- is write
+          redis_len_cb, --callback
+          'LLEN', -- command
+          {fann_prefix .. elt .. '_spam'}
+        )
+      end,
+      data)
+    end
+  end
+
+  if not redis_maybe_load_sha then
+    -- Plan new event early
+    return 1.0
+  end
+  -- First we need to get all fanns stored in our Redis
+  redis_make_request(ev_base,
+    rspamd_config,
+    nil,
+    false, -- is write
+    members_cb, --callback
+    'SMEMBERS', -- command
+    {fann_prefix} -- arguments
+  )
+
+  return watch_interval
+end
+
 local function check_fanns(cfg, ev_base)
   local function members_cb(err, data)
     if err then
@@ -680,7 +736,7 @@ else
     }
   end
   -- Add training scripts
-  rspamd_config:add_on_load(function(cfg, ev_base)
+  rspamd_config:add_on_load(function(cfg, ev_base, worker)
     local function can_train_sha_cb(err, data)
       if err or not data or type(data) ~= 'string' then
         rspamd_logger.errx(cfg, 'cannot save redis train script: %s', err)
@@ -733,5 +789,13 @@ else
       'SCRIPT', -- command
       {'LOAD', redis_lua_script_maybe_invalidate} -- arguments
     )
+
+    if worker:get_name() == 'normal' then
+      -- We also want to train neural nets when they have enough data
+      rspamd_config:add_periodic(ev_base, 0.0,
+        function(cfg, ev_base)
+          return maybe_train_fanns(cfg, ev_base)
+        end)
+    end
   end)
 end