]> source.dussan.org Git - rspamd.git/commitdiff
[Rework] Add extract training data function to fann_redis
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 5 Nov 2016 18:26:09 +0000 (18:26 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 5 Nov 2016 18:26:09 +0000 (18:26 +0000)
src/plugins/lua/fann_redis.lua

index e81af4762cf1f0975d4382481d28063420e6da06..f55454bf643dae2ecc0ace53b34d8ffece4a7610 100644 (file)
@@ -98,6 +98,16 @@ local redis_lua_script_maybe_invalidate = [[
 ]]
 local redis_maybe_invalidate_sha = nil
 
+-- Lua script to invalidate ANN from redis
+-- Uses the following keys
+-- key1 - prefix for keys
+local redis_lua_script_maybe_lock = [[
+  local locked = redis.call('GET', KEYS[1] .. '_locked')
+  if locked then return false end
+  return 1
+]]
+local redis_maybe_lock_sha = nil
+
 local redis_params
 redis_params = rspamd_parse_redis_server('fann_redis')
 
@@ -341,7 +351,106 @@ local function fann_train_callback(score, required_score, results, cf, id, opts,
 end
 
 local function train_fann(cfg, ev_base, elt)
+  local spam_elts = {}
+  local ham_elts = {}
+
+  local function redis_unlock_cb(err, data)
+    if err then
+      rspamd_logger.errx(rspamd_config, 'cannot unlock ANN %s from redis: %s',
+        fann_prefix .. elt, err)
+    end
+  end
+
+  local function redis_ham_cb(err, data)
+    if err or type(data) ~= 'table' then
+      rspamd_logger.errx(rspamd_config, 'cannot get ham tokens for ANN %s from redis: %s',
+        fann_prefix .. elt, err)
+      redis_make_request(ev_base,
+        rspamd_config,
+        nil,
+        false, -- is write
+        redis_unlock_cb, --callback
+        'DEL', -- command
+        {fann_prefix .. elt .. '_lock'}
+      )
+    else
+      -- Decompress and convert to numbers each training vector
+      ham_elts = map(function(i, elt)
+        local str = tostring(rspamd_util.zstd_decompress(elt))
+        return map(tonumber, rspamd_str_split(str, ';'))
+      end, data)
+
+      -- Now we need to join inputs and create the appropriate test vectors
+      local inputs = {}
+      local outputs = {}
+
+      each(function(i, elt)
+        table.insert(inputs, totable(elt))
+        table.insert(outputs, 1.0)
+      end, spam_elts)
+      each(function(i, elt)
+        table.insert(inputs, totable(elt))
+        table.insert(outputs, -1.0)
+      end, spam_elts)
+
+      -- Now we can train fann
+
+    end
+  end
 
+  local function redis_spam_cb(err, data)
+    if err or type(data) ~= 'table' then
+      rspamd_logger.errx(rspamd_config, 'cannot get spam tokens for ANN %s from redis: %s',
+        fann_prefix .. elt, err)
+      redis_make_request(ev_base,
+        rspamd_config,
+        nil,
+        false, -- is write
+        redis_unlock_cb, --callback
+        'DEL', -- command
+        {fann_prefix .. elt .. '_lock'}
+      )
+    else
+      -- Decompress and convert to numbers each training vector
+      spam_elts = map(function(i, elt)
+        local str = tostring(rspamd_util.zstd_decompress(elt))
+        return map(tonumber, rspamd_str_split(str, ';'))
+      end, data)
+      redis_make_request(ev_base,
+        rspamd_config,
+        nil,
+        false, -- is write
+        redis_ham_cb, --callback
+        'LRANGE', -- command
+        {fann_prefix .. elt .. '_ham', '0', '-1'}
+      )
+    end
+  end
+
+  local function redis_lock_cb(err, data)
+    if err then
+      rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s',
+        fann_prefix .. elt, err)
+    elseif type(data) == 'number' then
+      -- Can train ANN
+      redis_make_request(ev_base,
+        rspamd_config,
+        nil,
+        false, -- is write
+        redis_spam_cb, --callback
+        'LRANGE', -- command
+        {fann_prefix .. elt .. '_spam', '0', '-1'}
+      )
+    end
+  end
+  redis_make_request(ev_base,
+    rspamd_config,
+    nil,
+    true, -- is write
+    redis_lock_cb, --callback
+    'EVALSHA', -- command
+    {redis_maybe_lock_sha, '1', fann_prefix .. elt}
+  )
 end
 
 local function maybe_train_fanns(cfg, ev_base)
@@ -575,6 +684,22 @@ else
       {'LOAD', redis_lua_script_maybe_invalidate} -- arguments
     )
 
+    local function maybe_lock_sha_cb(err, data)
+      if err or not data or type(data) ~= 'string' then
+        rspamd_logger.errx(cfg, 'cannot save redis lock script: %s', err)
+      else
+        redis_maybe_lock_sha = tostring(data)
+      end
+    end
+    redis_make_request(ev_base,
+      rspamd_config,
+      nil,
+      true, -- is write
+      maybe_lock_sha_cb, --callback
+      'SCRIPT', -- command
+      {'LOAD', redis_lua_script_maybe_lock} -- arguments
+    )
+
     if worker:get_name() == 'normal' then
       -- We also want to train neural nets when they have enough data
       rspamd_config:add_periodic(ev_base, 0.0,