From 3038d9585f3c34fbd4fb65179d064fca2e445783 Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Sat, 6 Jul 2019 13:39:46 +0100 Subject: [Project] Neural: Rework train data pushing part --- src/plugins/lua/neural.lua | 91 ++++++++++++++++++++-------------------------- 1 file changed, 39 insertions(+), 52 deletions(-) (limited to 'src') diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua index fdb138321..f40778a7b 100644 --- a/src/plugins/lua/neural.lua +++ b/src/plugins/lua/neural.lua @@ -65,7 +65,7 @@ local default_options = { -- Settings ANN table is loaded from Redis and represents dynamic profile for ANN -- Some elements are directly stored in Redis, ANN is, in turn loaded dynamically -- * version - version of ANN loaded from redis --- * ann_key - name of ANN key in Redis +-- * redis_key - name of ANN key in Redis -- * symbols - symbols in THIS PARTICULAR ANN (might be different from set.symbols) -- * distance - distance between set.symbols and set.ann.symbols -- * ann - kann object @@ -85,31 +85,25 @@ end -- Lua script to train a row -- Uses the following keys: --- key1 - prefix for fann --- key2 - fann suffix (settings id) --- key3 - spam or ham --- key4 - maximum trains +-- key1 - ann key +-- key2 - spam or ham +-- key3 - maximum trains -- returns 1 or 0: 1 - allow learn, 0 - not allow learn local redis_lua_script_can_train = [[ - local prefix = KEYS[1] .. KEYS[2] - local locked = redis.call('GET', prefix .. '_locked') + local prefix = KEYS[1] + local locked = redis.call('HGET', prefix, 'lock') if locked then return 0 end local nspam = 0 local nham = 0 - local lim = tonumber(KEYS[4]) + local lim = tonumber(KEYS[3]) lim = lim + lim * 0.1 - local exists = redis.call('SISMEMBER', KEYS[1], KEYS[2]) - if not exists or tonumber(exists) == 0 then - redis.call('SADD', KEYS[1], KEYS[2]) - end - local ret = redis.call('LLEN', prefix .. '_spam') if ret then nspam = tonumber(ret) end ret = redis.call('LLEN', prefix .. '_ham') if ret then nham = tonumber(ret) end - if KEYS[3] == 'spam' then + if KEYS[2] == 'spam' then if nham <= lim and nham + 1 >= nspam then return tostring(nspam + 1) else @@ -155,6 +149,9 @@ local redis_lua_script_maybe_invalidate = [[ for _,k in ipairs(to_delete) do local tb = cjson.decode(k) redis.call('DEL', tb.ann_key) + -- Also train vectors + redis.call('DEL', tb.ann_key .. '_spam') + redis.call('DEL', tb.ann_key .. '_ham') end redis.call('ZREMRANGEBYRANK', KEYS[1], 0, (-(tonumber(KEYS[2] - 1))) return to_delete @@ -328,9 +325,8 @@ local function create_ann(n, nlayers) end -local function ann_train_callback(rule, task, score, required_score, id) - local train_opts = rule['train'] - local fname,suffix = gen_ann_prefix(rule, id) +local function ann_train_callback(rule, task, score, required_score, set) + local train_opts = rule.train local learn_spam, learn_ham @@ -360,16 +356,15 @@ local function ann_train_callback(rule, task, score, required_score, id) if learn_spam or learn_ham then - local k - local vec_len = 0 - if learn_spam then k = 'spam' else k = 'ham' end + local learn_type + if learn_spam then learn_type = 'spam' else learn_type = 'ham' end local function learn_vec_cb(err) if err then rspamd_logger.errx(task, 'cannot store train vector for %s: %s', fname, err) else rspamd_logger.infox(task, "trained ANN rule %s, save %s vector, %s bytes", - rule['name'], k, vec_len) + rule['name'], learn_type, vec_len) end end @@ -380,42 +375,36 @@ local function ann_train_callback(rule, task, score, required_score, id) rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin) return end - local ann_data = task:get_symbols_tokens() - local mt = meta_functions.rspamd_gen_metatokens(task) - -- Add filtered meta tokens - fun.each(function(e) table.insert(ann_data, e) end, mt) - -- Check NaNs in train data - if fun.all(function(e) return e == e end, ann_data) then - local str = rspamd_util.zstd_compress(table.concat(ann_data, ';')) - vec_len = #str - - lua_redis.redis_make_request(task, + local vec = result_to_vector(task, set) + + local str = rspamd_util.zstd_compress(table.concat(vec, ';')) + + lua_redis.redis_make_request(task, rule.redis, nil, true, -- is write learn_vec_cb, --callback 'LPUSH', -- command - {fname .. '_' .. k, str} -- arguments - ) - else - rspamd_logger.errx(task, "do not store learn vector as it contains %s NaN values", - fun.length(fun.filter(function(e) return e ~= e end, ann_data))) - end - + { set.ann.redis_prefix .. '_' .. learn_type, str} -- arguments + ) else if err then rspamd_logger.errx(task, 'cannot check if we can train %s: %s', fname, err) elseif tonumber(data) < 0 then - rspamd_logger.infox(task, "cannot learn ANN %s: too many %s samples: %s", - fname, k, -tonumber(data)) + rspamd_logger.infox(task, "cannot learn ANN %s:%s: too many %s samples: %s", + rule.prefix, set.name, learn_type, -tonumber(data)) end end end + if not set.ann then + -- Need to create or load a profile corresponding to the current configuration + end + -- Check if we can learn lua_redis.exec_redis_script(redis_can_train_id, - {task = task, is_write = true}, - can_train_cb, - {gen_ann_prefix(rule, nil), suffix, k, tostring(train_opts.max_trains)}) + {task = task, is_write = true}, + can_train_cb, + { set.ann.redis_key, learn_type, tostring(train_opts.max_trains)}) end end @@ -742,7 +731,8 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff) ann = ann, version = profile.version, symbols = profile.symbols, - distance = min_diff + distance = min_diff, + redis_key = profile.ann_key } rspamd_logger.infox(rspamd_config, 'loaded ANN for %s from %s; %s bytes compressed; version=%s', @@ -909,15 +899,12 @@ local function ann_push_vector(task) if not settings.allow_local and lua_util.is_rspamc_or_controller(task) then return end local scores = task:get_metric_score() for _,rule in pairs(settings.rules) do - local sid = "0" - if rule.use_settings then - sid = tostring(task:get_settings_id()) - end - if rule.per_user then - local r = task:get_principal_recipient() - sid = sid .. r + local sid = task:get_settings_id() or -1 + + if rule.settings[sid] then + ann_train_callback(rule, task, scores[1], scores[2], rule.settings[sid]) end - ann_train_callback(rule, task, scores[1], scores[2], sid) + end end -- cgit v1.2.3