aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2016-11-05 21:48:53 +0100
committerVsevolod Stakhov <vsevolod@highsecure.ru>2016-11-05 21:48:53 +0100
commit5eb00bd1b49028f9f5591572108b4bb104e56ebb (patch)
tree55894175af7dea56acb82fc70f04228b2a159e9d
parenta442c7f57f20fee40767c5d63bf98ceaabb8f183 (diff)
downloadrspamd-5eb00bd1b49028f9f5591572108b4bb104e56ebb.tar.gz
rspamd-5eb00bd1b49028f9f5591572108b4bb104e56ebb.zip
[Rework] Add redis storage feature to fann_redis
-rw-r--r--src/plugins/lua/fann_redis.lua96
1 files changed, 83 insertions, 13 deletions
diff --git a/src/plugins/lua/fann_redis.lua b/src/plugins/lua/fann_redis.lua
index f55454bf6..aabf465ce 100644
--- a/src/plugins/lua/fann_redis.lua
+++ b/src/plugins/lua/fann_redis.lua
@@ -30,9 +30,7 @@ local module_log_id = 0x100
-- ANNs indexed by settings id
local data = {
['0'] = {
- fann_mtime = 0,
- ntrains = 0,
- epoch = 0,
+ version = 0,
}
}
@@ -108,6 +106,20 @@ local redis_lua_script_maybe_lock = [[
]]
local redis_maybe_lock_sha = nil
+-- Lua script to save and unlock ANN in redis
+-- Uses the following keys
+-- key1 - prefix for keys
+-- key2 - compressed ANN
+local redis_lua_script_save_unlock = [[
+ redis.call('INCRBY', KEYS[1] .. '_version', '1')
+ redis.call('DEL', KEYS[1] .. '_spam')
+ redis.call('DEL', KEYS[1] .. '_ham')
+ redis.call('SET', KEYS[1] .. '_data', KEYS[2])
+ redis.call('DEL', KEYS[1] .. '_locked')
+ return 1
+]]
+local redis_save_unlock_sha = nil
+
local redis_params
redis_params = rspamd_parse_redis_server('fann_redis')
@@ -116,6 +128,7 @@ local max_trains = 1000
local max_epoch = 100
local use_settings = false
local watch_interval = 60.0
+local mse = 0.0001
local function redis_make_request(ev_base, cfg, key, is_write, callback, command, args)
if not ev_base or not redis_params or not callback or not command then
@@ -251,8 +264,7 @@ end
local function create_train_fann(n, id)
data[id].fann_train = rspamd_fann.create(5, n, n, n / 2, n / 4, 1)
- data[id].ntrains = 0
- data[id].epoch = 0
+ data[id].version = 0
end
local function load_or_invalidate_fann(data, id, ev_base)
@@ -361,6 +373,41 @@ local function train_fann(cfg, ev_base, elt)
end
end
+ local function redis_save_unlock_sha(err, data)
+ if err then
+ rspamd_logger.errx(rspamd_config, 'cannot save ANN %s to redis: %s',
+ fann_prefix .. elt, err)
+ end
+ end
+
+ local function ann_trained(errcode, errmsg, train_mse)
+ if errcode ~= 0 then
+ rspamd_logger.errx(rspamd_config, 'cannot train ANN %s: %s',
+ fann_prefix .. elt, errmsg)
+ redis_make_request(ev_base,
+ rspamd_config,
+ nil,
+ false, -- is write
+ redis_unlock_cb, --callback
+ 'DEL', -- command
+ {fann_prefix .. elt .. '_lock'}
+ )
+ else
+ rspamd_logger.infox(rspamd_config, 'trained ANN %s: MSE: %s',
+ fann_prefix .. elt, train_mse)
+ local ann_data = rspamd_util.zstd_compress(data[elt].fann:data())
+ data[elt].version = data[elt].version + 1
+ redis_make_request(ev_base,
+ rspamd_config,
+ nil,
+ true, -- is write
+ redis_save_cb, --callback
+ 'EVALSHA', -- command
+ {redis_save_unlock_sha, '2', fann_prefix .. elt, ann_data}
+ )
+ end
+ end
+
local function redis_ham_cb(err, data)
if err or type(data) ~= 'table' then
rspamd_logger.errx(rspamd_config, 'cannot get ham tokens for ANN %s from redis: %s',
@@ -375,8 +422,8 @@ local function train_fann(cfg, ev_base, elt)
)
else
-- Decompress and convert to numbers each training vector
- ham_elts = map(function(i, elt)
- local str = tostring(rspamd_util.zstd_decompress(elt))
+ ham_elts = map(function(i, tok)
+ local str = tostring(rspamd_util.zstd_decompress(tok))
return map(tonumber, rspamd_str_split(str, ';'))
end, data)
@@ -384,17 +431,24 @@ local function train_fann(cfg, ev_base, elt)
local inputs = {}
local outputs = {}
- each(function(i, elt)
- table.insert(inputs, totable(elt))
+ each(function(i, sample)
+ table.insert(inputs, totable(sample))
table.insert(outputs, 1.0)
end, spam_elts)
- each(function(i, elt)
- table.insert(inputs, totable(elt))
+ each(function(i, sample)
+ table.insert(inputs, totable(sample))
table.insert(outputs, -1.0)
end, spam_elts)
-- Now we can train fann
+ local n = rspamd_config:get_symbols_count() + rspamd_count_metatokens()
+ if not data[elt].fann then
+ -- Create fann if it does not exist
+ create_train_fann(n, elt)
+ end
+ data[elt].fann:train_threaded(inputs, outputs, ann_trained, ev_base,
+ {max_epochs = max_epoch, desired_mse = mse})
end
end
@@ -412,8 +466,8 @@ local function train_fann(cfg, ev_base, elt)
)
else
-- Decompress and convert to numbers each training vector
- spam_elts = map(function(i, elt)
- local str = tostring(rspamd_util.zstd_decompress(elt))
+ spam_elts = map(function(i, tok)
+ local str = tostring(rspamd_util.zstd_decompress(tok))
return map(tonumber, rspamd_str_split(str, ';'))
end, data)
redis_make_request(ev_base,
@@ -700,6 +754,22 @@ else
{'LOAD', redis_lua_script_maybe_lock} -- arguments
)
+ local function save_unlock_sha_cb(err, data)
+ if err or not data or type(data) ~= 'string' then
+ rspamd_logger.errx(cfg, 'cannot save redis save script: %s', err)
+ else
+ redis_save_unlock_sha = tostring(data)
+ end
+ end
+ redis_make_request(ev_base,
+ rspamd_config,
+ nil,
+ true, -- is write
+ save_unlock_sha_cb, --callback
+ 'SCRIPT', -- command
+ {'LOAD', redis_lua_script_save_unlock} -- arguments
+ )
+
if worker:get_name() == 'normal' then
-- We also want to train neural nets when they have enough data
rspamd_config:add_periodic(ev_base, 0.0,