aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2019-07-06 13:39:46 +0100
committerVsevolod Stakhov <vsevolod@highsecure.ru>2019-07-06 13:39:46 +0100
commit3038d9585f3c34fbd4fb65179d064fca2e445783 (patch)
treeb9e054785ea47fc3d2568df55dbb38342d1f79cb /src
parent933c82f6ed2f4450d3c0cfad7d35a9750918b74e (diff)
downloadrspamd-3038d9585f3c34fbd4fb65179d064fca2e445783.tar.gz
rspamd-3038d9585f3c34fbd4fb65179d064fca2e445783.zip
[Project] Neural: Rework train data pushing part
Diffstat (limited to 'src')
-rw-r--r--src/plugins/lua/neural.lua91
1 files changed, 39 insertions, 52 deletions
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index fdb138321..f40778a7b 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -65,7 +65,7 @@ local default_options = {
-- Settings ANN table is loaded from Redis and represents dynamic profile for ANN
-- Some elements are directly stored in Redis, ANN is, in turn loaded dynamically
-- * version - version of ANN loaded from redis
--- * ann_key - name of ANN key in Redis
+-- * redis_key - name of ANN key in Redis
-- * symbols - symbols in THIS PARTICULAR ANN (might be different from set.symbols)
-- * distance - distance between set.symbols and set.ann.symbols
-- * ann - kann object
@@ -85,31 +85,25 @@ end
-- Lua script to train a row
-- Uses the following keys:
--- key1 - prefix for fann
--- key2 - fann suffix (settings id)
--- key3 - spam or ham
--- key4 - maximum trains
+-- key1 - ann key
+-- key2 - spam or ham
+-- key3 - maximum trains
-- returns 1 or 0: 1 - allow learn, 0 - not allow learn
local redis_lua_script_can_train = [[
- local prefix = KEYS[1] .. KEYS[2]
- local locked = redis.call('GET', prefix .. '_locked')
+ local prefix = KEYS[1]
+ local locked = redis.call('HGET', prefix, 'lock')
if locked then return 0 end
local nspam = 0
local nham = 0
- local lim = tonumber(KEYS[4])
+ local lim = tonumber(KEYS[3])
lim = lim + lim * 0.1
- local exists = redis.call('SISMEMBER', KEYS[1], KEYS[2])
- if not exists or tonumber(exists) == 0 then
- redis.call('SADD', KEYS[1], KEYS[2])
- end
-
local ret = redis.call('LLEN', prefix .. '_spam')
if ret then nspam = tonumber(ret) end
ret = redis.call('LLEN', prefix .. '_ham')
if ret then nham = tonumber(ret) end
- if KEYS[3] == 'spam' then
+ if KEYS[2] == 'spam' then
if nham <= lim and nham + 1 >= nspam then
return tostring(nspam + 1)
else
@@ -155,6 +149,9 @@ local redis_lua_script_maybe_invalidate = [[
for _,k in ipairs(to_delete) do
local tb = cjson.decode(k)
redis.call('DEL', tb.ann_key)
+ -- Also train vectors
+ redis.call('DEL', tb.ann_key .. '_spam')
+ redis.call('DEL', tb.ann_key .. '_ham')
end
redis.call('ZREMRANGEBYRANK', KEYS[1], 0, (-(tonumber(KEYS[2] - 1)))
return to_delete
@@ -328,9 +325,8 @@ local function create_ann(n, nlayers)
end
-local function ann_train_callback(rule, task, score, required_score, id)
- local train_opts = rule['train']
- local fname,suffix = gen_ann_prefix(rule, id)
+local function ann_train_callback(rule, task, score, required_score, set)
+ local train_opts = rule.train
local learn_spam, learn_ham
@@ -360,16 +356,15 @@ local function ann_train_callback(rule, task, score, required_score, id)
if learn_spam or learn_ham then
- local k
- local vec_len = 0
- if learn_spam then k = 'spam' else k = 'ham' end
+ local learn_type
+ if learn_spam then learn_type = 'spam' else learn_type = 'ham' end
local function learn_vec_cb(err)
if err then
rspamd_logger.errx(task, 'cannot store train vector for %s: %s', fname, err)
else
rspamd_logger.infox(task, "trained ANN rule %s, save %s vector, %s bytes",
- rule['name'], k, vec_len)
+ rule['name'], learn_type, vec_len)
end
end
@@ -380,42 +375,36 @@ local function ann_train_callback(rule, task, score, required_score, id)
rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin)
return
end
- local ann_data = task:get_symbols_tokens()
- local mt = meta_functions.rspamd_gen_metatokens(task)
- -- Add filtered meta tokens
- fun.each(function(e) table.insert(ann_data, e) end, mt)
- -- Check NaNs in train data
- if fun.all(function(e) return e == e end, ann_data) then
- local str = rspamd_util.zstd_compress(table.concat(ann_data, ';'))
- vec_len = #str
-
- lua_redis.redis_make_request(task,
+ local vec = result_to_vector(task, set)
+
+ local str = rspamd_util.zstd_compress(table.concat(vec, ';'))
+
+ lua_redis.redis_make_request(task,
rule.redis,
nil,
true, -- is write
learn_vec_cb, --callback
'LPUSH', -- command
- {fname .. '_' .. k, str} -- arguments
- )
- else
- rspamd_logger.errx(task, "do not store learn vector as it contains %s NaN values",
- fun.length(fun.filter(function(e) return e ~= e end, ann_data)))
- end
-
+ { set.ann.redis_prefix .. '_' .. learn_type, str} -- arguments
+ )
else
if err then
rspamd_logger.errx(task, 'cannot check if we can train %s: %s', fname, err)
elseif tonumber(data) < 0 then
- rspamd_logger.infox(task, "cannot learn ANN %s: too many %s samples: %s",
- fname, k, -tonumber(data))
+ rspamd_logger.infox(task, "cannot learn ANN %s:%s: too many %s samples: %s",
+ rule.prefix, set.name, learn_type, -tonumber(data))
end
end
end
+ if not set.ann then
+ -- Need to create or load a profile corresponding to the current configuration
+ end
+ -- Check if we can learn
lua_redis.exec_redis_script(redis_can_train_id,
- {task = task, is_write = true},
- can_train_cb,
- {gen_ann_prefix(rule, nil), suffix, k, tostring(train_opts.max_trains)})
+ {task = task, is_write = true},
+ can_train_cb,
+ { set.ann.redis_key, learn_type, tostring(train_opts.max_trains)})
end
end
@@ -742,7 +731,8 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
ann = ann,
version = profile.version,
symbols = profile.symbols,
- distance = min_diff
+ distance = min_diff,
+ redis_key = profile.ann_key
}
rspamd_logger.infox(rspamd_config, 'loaded ANN for %s from %s; %s bytes compressed; version=%s',
@@ -909,15 +899,12 @@ local function ann_push_vector(task)
if not settings.allow_local and lua_util.is_rspamc_or_controller(task) then return end
local scores = task:get_metric_score()
for _,rule in pairs(settings.rules) do
- local sid = "0"
- if rule.use_settings then
- sid = tostring(task:get_settings_id())
- end
- if rule.per_user then
- local r = task:get_principal_recipient()
- sid = sid .. r
+ local sid = task:get_settings_id() or -1
+
+ if rule.settings[sid] then
+ ann_train_callback(rule, task, scores[1], scores[2], rule.settings[sid])
end
- ann_train_callback(rule, task, scores[1], scores[2], sid)
+
end
end