aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/plugins/lua/neural.lua106
1 files changed, 77 insertions, 29 deletions
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index f40778a7b..c2ffb3e15 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -121,23 +121,6 @@ local redis_lua_script_can_train = [[
]]
local redis_can_train_id = nil
--- Lua script to load ANN from redis
--- Uses the following keys
--- key1 - prefix for keys
--- key2 - local version
--- returns nil or bulk string if new ANN can be loaded
-local redis_lua_script_maybe_load = [[
- local ver = 0
- local ret = redis.call('GET', KEYS[1] .. '_version')
- if ret then ver = tonumber(ret) end
- if ver > tonumber(KEYS[2]) then
- return {redis.call('GET', KEYS[1] .. '_data'), ret}
- end
-
- return tonumber(ret) or 0
-]]
-local redis_maybe_load_id = nil
-
-- Lua script to invalidate ANNs by rank
-- Uses the following keys
-- key1 - prefix for keys
@@ -148,10 +131,10 @@ local redis_lua_script_maybe_invalidate = [[
local to_delete = redis.call('ZRANGE', KEYS[1], 0, (-(tonumber(KEYS[2] - 1)))
for _,k in ipairs(to_delete) do
local tb = cjson.decode(k)
- redis.call('DEL', tb.ann_key)
+ redis.call('DEL', tb.redis_key)
-- Also train vectors
- redis.call('DEL', tb.ann_key .. '_spam')
- redis.call('DEL', tb.ann_key .. '_ham')
+ redis.call('DEL', tb.redis_key .. '_spam')
+ redis.call('DEL', tb.redis_key .. '_ham')
end
redis.call('ZREMRANGEBYRANK', KEYS[1], 0, (-(tonumber(KEYS[2] - 1)))
return to_delete
@@ -217,8 +200,6 @@ local redis_params
local function load_scripts(params)
redis_can_train_id = lua_redis.add_redis_script(redis_lua_script_can_train,
params)
- redis_maybe_load_id = lua_redis.add_redis_script(redis_lua_script_maybe_load,
- params)
redis_maybe_invalidate_id = lua_redis.add_redis_script(redis_lua_script_maybe_invalidate,
params)
redis_locked_invalidate_id = lua_redis.add_redis_script(redis_lua_script_locked_invalidate,
@@ -254,6 +235,55 @@ local function result_to_vector(task, profile)
return vec
end
+-- Used to generate new ANN key for specific profile
+local function new_ann_key(rule, set)
+ local ann_key = string.format('%s_%s_%s_%s_nn', settings.prefix,
+ rule.prefix, set.name, set.digest:sub(1, 8))
+
+ return ann_key
+end
+
+-- Creates and stores ANN profile in Redis
+local function new_ann_profile(task, rule, set)
+ local ann_key = new_ann_key(rule, set)
+
+
+ local profile = {
+ symbols = set.symbols,
+ redis_key = ann_key,
+ version = 0,
+ digest = set.digest,
+ distance = 0 -- Since we are using our own profile
+ }
+
+ local ucl = require "ucl"
+ local profile_serialized = ucl.to_format(profile, 'json-compact')
+
+ local function add_cb(err, _)
+ if err then
+ rspamd_logger.errx(task, 'cannot check if we can train %s:%s : %s',
+ rule.prefix, set.name, err)
+ else
+ rspamd_logger.infox(task, 'created new ANN profile for %s:%s, data stored at prefix %s',
+ rule.prefix, set.name, profile.redis_key)
+ end
+ end
+
+ lua_redis.redis_make_request_taskless(ev_base,
+ rspamd_config,
+ rule.redis,
+ nil,
+ true, -- is write
+ add_cb, --callback
+ 'ZADD', -- command
+ {set.prefix, profile_serialized, tostring(rspamd_util.get_time())}
+ )
+
+ return profile
+end
+
+
+-- ANN filter function, used to insert scores based on the existing symbols
local function ann_scores_filter(task)
for _,rule in pairs(settings.rules) do
@@ -389,7 +419,8 @@ local function ann_train_callback(rule, task, score, required_score, set)
)
else
if err then
- rspamd_logger.errx(task, 'cannot check if we can train %s: %s', fname, err)
+ rspamd_logger.errx(task, 'cannot check if we can train %s:%s : %s',
+ rule.prefix, set.name, err)
elseif tonumber(data) < 0 then
rspamd_logger.infox(task, "cannot learn ANN %s:%s: too many %s samples: %s",
rule.prefix, set.name, learn_type, -tonumber(data))
@@ -399,6 +430,7 @@ local function ann_train_callback(rule, task, score, required_score, set)
if not set.ann then
-- Need to create or load a profile corresponding to the current configuration
+ set.ann = new_ann_profile(task, rule, set)
end
-- Check if we can learn
lua_redis.exec_redis_script(redis_can_train_id,
@@ -704,12 +736,12 @@ end
-- This function loads new ann from Redis
-- This is based on `profile` attribute.
--- ANN is loaded from `profile.ann_key`
+-- ANN is loaded from `profile.redis_key`
-- Rank of `profile` key is also increased, unfortunately, it means that we need to
-- serialize profile one more time and set its rank to the current time
-- set.ann fields are set according to Redis data received
local function load_new_ann(rule, ev_base, set, profile, min_diff)
- local ann_key = profile.ann_key
+ local ann_key = profile.redis_key
local function data_cb(err, data)
if err then
@@ -732,9 +764,25 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
version = profile.version,
symbols = profile.symbols,
distance = min_diff,
- redis_key = profile.ann_key
+ redis_key = profile.redis_key
}
+ local ucl = require "ucl"
+ local profile_serialized = ucl.to_format(profile, 'json-compact')
+
+ local function rank_cb(_, _)
+ -- TODO: maybe add some logging
+ end
+ -- Also update rank for the loaded ANN to avoid removal
+ lua_redis.redis_make_request_taskless(ev_base,
+ rspamd_config,
+ rule.redis,
+ nil,
+ true, -- is write
+ rank_cb, --callback
+ 'ZADD', -- command
+ {set.prefix, profile_serialized, tostring(rspamd_util.get_time())}
+ )
rspamd_logger.infox(rspamd_config, 'loaded ANN for %s from %s; %s bytes compressed; version=%s',
rule.prefix .. ':' .. set.name, ann_key, #ann_data, profile.version)
else
@@ -881,7 +929,7 @@ local function cleanup_anns(rule, cfg, ev_base)
local profile = load_ann_profile(expired)
rspamd_logger.infox(cfg, 'invalidated ANN for %s; redis key: %s; version=%s',
rule.prefix .. ':' .. set.name,
- profile.ann_key,
+ profile.redis_key,
profile.version)
end
end
@@ -941,7 +989,7 @@ local function process_rules_settings()
lua_redis.register_prefix(selt.prefix, N,
string.format('NN prefix for rule "%s"; settings id "%s"',
- rule.prefix, selt.name))
+ rule.prefix, selt.name), {persistent = true})
end
for _,rule in pairs(opts.rules) do
@@ -1063,7 +1111,7 @@ rspamd_config:register_symbol({
})
-- Add training scripts
-for k,rule in pairs(settings.rules) do
+for _,rule in pairs(settings.rules) do
load_scripts(rule.redis)
-- We also need to deal with settings
rspamd_config:add_post_init(process_rules_settings)