From a442c7f57f20fee40767c5d63bf98ceaabb8f183 Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Sat, 5 Nov 2016 18:26:09 +0000 Subject: [PATCH] [Rework] Add extract training data function to fann_redis --- src/plugins/lua/fann_redis.lua | 125 +++++++++++++++++++++++++++++++++ 1 file changed, 125 insertions(+) diff --git a/src/plugins/lua/fann_redis.lua b/src/plugins/lua/fann_redis.lua index e81af4762..f55454bf6 100644 --- a/src/plugins/lua/fann_redis.lua +++ b/src/plugins/lua/fann_redis.lua @@ -98,6 +98,16 @@ local redis_lua_script_maybe_invalidate = [[ ]] local redis_maybe_invalidate_sha = nil +-- Lua script to invalidate ANN from redis +-- Uses the following keys +-- key1 - prefix for keys +local redis_lua_script_maybe_lock = [[ + local locked = redis.call('GET', KEYS[1] .. '_locked') + if locked then return false end + return 1 +]] +local redis_maybe_lock_sha = nil + local redis_params redis_params = rspamd_parse_redis_server('fann_redis') @@ -341,7 +351,106 @@ local function fann_train_callback(score, required_score, results, cf, id, opts, end local function train_fann(cfg, ev_base, elt) + local spam_elts = {} + local ham_elts = {} + + local function redis_unlock_cb(err, data) + if err then + rspamd_logger.errx(rspamd_config, 'cannot unlock ANN %s from redis: %s', + fann_prefix .. elt, err) + end + end + + local function redis_ham_cb(err, data) + if err or type(data) ~= 'table' then + rspamd_logger.errx(rspamd_config, 'cannot get ham tokens for ANN %s from redis: %s', + fann_prefix .. elt, err) + redis_make_request(ev_base, + rspamd_config, + nil, + false, -- is write + redis_unlock_cb, --callback + 'DEL', -- command + {fann_prefix .. elt .. '_lock'} + ) + else + -- Decompress and convert to numbers each training vector + ham_elts = map(function(i, elt) + local str = tostring(rspamd_util.zstd_decompress(elt)) + return map(tonumber, rspamd_str_split(str, ';')) + end, data) + + -- Now we need to join inputs and create the appropriate test vectors + local inputs = {} + local outputs = {} + + each(function(i, elt) + table.insert(inputs, totable(elt)) + table.insert(outputs, 1.0) + end, spam_elts) + each(function(i, elt) + table.insert(inputs, totable(elt)) + table.insert(outputs, -1.0) + end, spam_elts) + + -- Now we can train fann + + end + end + local function redis_spam_cb(err, data) + if err or type(data) ~= 'table' then + rspamd_logger.errx(rspamd_config, 'cannot get spam tokens for ANN %s from redis: %s', + fann_prefix .. elt, err) + redis_make_request(ev_base, + rspamd_config, + nil, + false, -- is write + redis_unlock_cb, --callback + 'DEL', -- command + {fann_prefix .. elt .. '_lock'} + ) + else + -- Decompress and convert to numbers each training vector + spam_elts = map(function(i, elt) + local str = tostring(rspamd_util.zstd_decompress(elt)) + return map(tonumber, rspamd_str_split(str, ';')) + end, data) + redis_make_request(ev_base, + rspamd_config, + nil, + false, -- is write + redis_ham_cb, --callback + 'LRANGE', -- command + {fann_prefix .. elt .. '_ham', '0', '-1'} + ) + end + end + + local function redis_lock_cb(err, data) + if err then + rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s', + fann_prefix .. elt, err) + elseif type(data) == 'number' then + -- Can train ANN + redis_make_request(ev_base, + rspamd_config, + nil, + false, -- is write + redis_spam_cb, --callback + 'LRANGE', -- command + {fann_prefix .. elt .. '_spam', '0', '-1'} + ) + end + end + redis_make_request(ev_base, + rspamd_config, + nil, + true, -- is write + redis_lock_cb, --callback + 'EVALSHA', -- command + {redis_maybe_lock_sha, '1', fann_prefix .. elt} + ) end local function maybe_train_fanns(cfg, ev_base) @@ -575,6 +684,22 @@ else {'LOAD', redis_lua_script_maybe_invalidate} -- arguments ) + local function maybe_lock_sha_cb(err, data) + if err or not data or type(data) ~= 'string' then + rspamd_logger.errx(cfg, 'cannot save redis lock script: %s', err) + else + redis_maybe_lock_sha = tostring(data) + end + end + redis_make_request(ev_base, + rspamd_config, + nil, + true, -- is write + maybe_lock_sha_cb, --callback + 'SCRIPT', -- command + {'LOAD', redis_lua_script_maybe_lock} -- arguments + ) + if worker:get_name() == 'normal' then -- We also want to train neural nets when they have enough data rspamd_config:add_periodic(ev_base, 0.0, -- 2.39.5