summary | refs | log | tree | commit | diff | stats
path: root/src/plugins/lua
diff options
context:
space:
mode:
author: Vsevolod Stakhov <vsevolod@highsecure.ru> 2016-11-05 18:26:09 +0000
committer: Vsevolod Stakhov <vsevolod@highsecure.ru> 2016-11-05 18:26:09 +0000
commit: a442c7f57f20fee40767c5d63bf98ceaabb8f183 (patch)
tree: 2294678746d09c2a1dcbc6466e2eb48c19549c3b /src/plugins/lua
parent: d29a3dc66b6a7ed24d2eef0e87f8b8d54701e086 (diff)
download: rspamd-a442c7f57f20fee40767c5d63bf98ceaabb8f183.tar.gz
download: rspamd-a442c7f57f20fee40767c5d63bf98ceaabb8f183.zip
[Rework] Add extract training data function to fann_redis
Diffstat (limited to 'src/plugins/lua')
-rw-r--r-- src/plugins/lua/fann_redis.lua 125
1 files changed, 125 insertions, 0 deletions
diff --git a/src/plugins/lua/fann_redis.lua b/src/plugins/lua/fann_redis.lua
index e81af4762..f55454bf6 100644
--- a/src/plugins/lua/fann_redis.lua
+++ b/src/plugins/lua/fann_redis.lua
@@ -98,6 +98,16 @@ local redis_lua_script_maybe_invalidate = [[
]]
local redis_maybe_invalidate_sha = nil
+-- Lua script that checks whether an ANN is locked in redis: it returns false
+-- when `<key1>_locked` is set and 1 otherwise. NOTE(review): the script only
+-- reads the lock and never sets it - confirm. key1 - prefix for keys
+local redis_lua_script_maybe_lock = [[
+ local locked = redis.call('GET', KEYS[1] .. '_locked')
+ if locked then return false end
+ return 1
+]]
+local redis_maybe_lock_sha = nil
+
local redis_params
redis_params = rspamd_parse_redis_server('fann_redis')
@@ -341,7 +351,106 @@ local function fann_train_callback(score, required_score, results, cf, id, opts,
end
local function train_fann(cfg, ev_base, elt)
+  -- Fetch the stored spam and ham training vectors of ANN `elt` from redis and
+  -- build the training inputs/outputs. Requests are chained through callbacks:
+  -- lock script -> LRANGE <key>_spam -> LRANGE <key>_ham.
+  -- cfg is unused here, ev_base is handed to every redis request and elt is
+  -- appended to fann_prefix to form the redis key prefix.
+  local key_prefix = fann_prefix .. elt
+  local spam_elts = {}
+  local ham_elts = {}
+
+  local function redis_unlock_cb(err, data)
+    if err then
+      rspamd_logger.errx(rspamd_config, 'cannot unlock ANN %s from redis: %s',
+        key_prefix, err)
+    end
+  end
+
+  -- Drop the lock key after a failed fetch. The key suffix has to be the one
+  -- the lock script reads ('_locked') and DEL is sent as a write request.
+  local function unlock_ann()
+    redis_make_request(ev_base, rspamd_config, nil,
+      true, -- is write
+      redis_unlock_cb, 'DEL', {key_prefix .. '_locked'})
+  end
+
+  -- Decompress every stored vector and convert its ';'-separated fields to
+  -- numbers. NOTE(review): assumes map/each come from luafun, which hands
+  -- array callbacks the value only, and that zstd_decompress returns the
+  -- payload as its first result - confirm both.
+  local function decode_vectors(data)
+    return map(function(tok)
+      local str = tostring(rspamd_util.zstd_decompress(tok))
+      return map(tonumber, rspamd_str_split(str, ';'))
+    end, data)
+  end
+
+  local function redis_ham_cb(err, data)
+    if err or type(data) ~= 'table' then
+      rspamd_logger.errx(rspamd_config, 'cannot get ham tokens for ANN %s from redis: %s',
+        key_prefix, err)
+      unlock_ann()
+    else
+      ham_elts = decode_vectors(data)
+
+      -- Now we need to join inputs and create the appropriate test vectors
+      local inputs = {}
+      local outputs = {}
+
+      -- Spam vectors get the target 1.0, ham vectors get -1.0
+      each(function(vec)
+        table.insert(inputs, totable(vec))
+        table.insert(outputs, 1.0)
+      end, spam_elts)
+      each(function(vec)
+        table.insert(inputs, totable(vec))
+        table.insert(outputs, -1.0)
+      end, ham_elts)
+
+      -- Now we can train fann
+
+    end
+  end
+
+  local function redis_spam_cb(err, data)
+    if err or type(data) ~= 'table' then
+      rspamd_logger.errx(rspamd_config, 'cannot get spam tokens for ANN %s from redis: %s',
+        key_prefix, err)
+      unlock_ann()
+    else
+      spam_elts = decode_vectors(data)
+      -- Spam vectors are ready, now fetch the ham ones
+      redis_make_request(ev_base, rspamd_config, nil,
+        false, -- is write
+        redis_ham_cb, 'LRANGE', {key_prefix .. '_ham', '0', '-1'})
+    end
+  end
+
+  local function redis_lock_cb(err, data)
+    if err then
+      rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s',
+        key_prefix, err)
+    elseif type(data) == 'number' then
+      -- Can train ANN
+      redis_make_request(ev_base, rspamd_config, nil,
+        false, -- is write
+        redis_spam_cb, 'LRANGE', {key_prefix .. '_spam', '0', '-1'})
+    end
+  end
+
+  if not redis_maybe_lock_sha then
+    -- The SCRIPT LOAD callback has not stored the digest, EVALSHA cannot run
+    rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s: lock script is not loaded',
+      key_prefix)
+    return
+  end
+
+  -- Run the lock script first; the rest of the chain starts in redis_lock_cb
+  redis_make_request(ev_base, rspamd_config, nil,
+    true, -- is write
+    redis_lock_cb, 'EVALSHA', {redis_maybe_lock_sha, '1', key_prefix})
end
local function maybe_train_fanns(cfg, ev_base)
@@ -575,6 +684,22 @@ else
{'LOAD', redis_lua_script_maybe_invalidate} -- arguments
)
+ -- Upload the lock script to redis; on success the string reply is kept in
+ -- redis_maybe_lock_sha, any other outcome is logged as an error.
+ local function maybe_lock_sha_cb(err, data)
+   if not err and type(data) == 'string' then
+     redis_maybe_lock_sha = tostring(data)
+   else
+     rspamd_logger.errx(cfg, 'cannot save redis lock script: %s', err)
+   end
+ end
+ redis_make_request(ev_base, rspamd_config,
+   nil,
+   true, -- is write
+   maybe_lock_sha_cb, -- callback
+   'SCRIPT', -- command
+   {'LOAD', redis_lua_script_maybe_lock}) -- arguments
+
if worker:get_name() == 'normal' then
-- We also want to train neural nets when they have enough data
rspamd_config:add_periodic(ev_base, 0.0,