diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/plugins/lua/neural.lua | 106 |
1 files changed, 77 insertions, 29 deletions
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua index f40778a7b..c2ffb3e15 100644 --- a/src/plugins/lua/neural.lua +++ b/src/plugins/lua/neural.lua @@ -121,23 +121,6 @@ local redis_lua_script_can_train = [[ ]] local redis_can_train_id = nil --- Lua script to load ANN from redis --- Uses the following keys --- key1 - prefix for keys --- key2 - local version --- returns nil or bulk string if new ANN can be loaded -local redis_lua_script_maybe_load = [[ - local ver = 0 - local ret = redis.call('GET', KEYS[1] .. '_version') - if ret then ver = tonumber(ret) end - if ver > tonumber(KEYS[2]) then - return {redis.call('GET', KEYS[1] .. '_data'), ret} - end - - return tonumber(ret) or 0 -]] -local redis_maybe_load_id = nil - -- Lua script to invalidate ANNs by rank -- Uses the following keys -- key1 - prefix for keys @@ -148,10 +131,10 @@ local redis_lua_script_maybe_invalidate = [[ local to_delete = redis.call('ZRANGE', KEYS[1], 0, (-(tonumber(KEYS[2] - 1))) for _,k in ipairs(to_delete) do local tb = cjson.decode(k) - redis.call('DEL', tb.ann_key) + redis.call('DEL', tb.redis_key) -- Also train vectors - redis.call('DEL', tb.ann_key .. '_spam') - redis.call('DEL', tb.ann_key .. '_ham') + redis.call('DEL', tb.redis_key .. '_spam') + redis.call('DEL', tb.redis_key .. '_ham') end redis.call('ZREMRANGEBYRANK', KEYS[1], 0, (-(tonumber(KEYS[2] - 1))) return to_delete @@ -217,8 +200,6 @@ local redis_params local function load_scripts(params) redis_can_train_id = lua_redis.add_redis_script(redis_lua_script_can_train, params) - redis_maybe_load_id = lua_redis.add_redis_script(redis_lua_script_maybe_load, - params) redis_maybe_invalidate_id = lua_redis.add_redis_script(redis_lua_script_maybe_invalidate, params) redis_locked_invalidate_id = lua_redis.add_redis_script(redis_lua_script_locked_invalidate, @@ -254,6 +235,55 @@ local function result_to_vector(task, profile) return vec end +-- Used to generate new ANN key for specific profile +local function new_ann_key(rule, set) + local ann_key = string.format('%s_%s_%s_%s_nn', settings.prefix, + rule.prefix, set.name, set.digest:sub(1, 8)) + + return ann_key +end + +-- Creates and stores ANN profile in Redis +local function new_ann_profile(task, rule, set) + local ann_key = new_ann_key(rule, set) + + + local profile = { + symbols = set.symbols, + redis_key = ann_key, + version = 0, + digest = set.digest, + distance = 0 -- Since we are using our own profile + } + + local ucl = require "ucl" + local profile_serialized = ucl.to_format(profile, 'json-compact') + + local function add_cb(err, _) + if err then + rspamd_logger.errx(task, 'cannot check if we can train %s:%s : %s', + rule.prefix, set.name, err) + else + rspamd_logger.infox(task, 'created new ANN profile for %s:%s, data stored at prefix %s', + rule.prefix, set.name, profile.redis_key) + end + end + + lua_redis.redis_make_request_taskless(ev_base, + rspamd_config, + rule.redis, + nil, + true, -- is write + add_cb, --callback + 'ZADD', -- command + {set.prefix, profile_serialized, tostring(rspamd_util.get_time())} + ) + + return profile +end + + +-- ANN filter function, used to insert scores based on the existing symbols local function ann_scores_filter(task) for _,rule in pairs(settings.rules) do @@ -389,7 +419,8 @@ local function ann_train_callback(rule, task, score, required_score, set) ) else if err then - rspamd_logger.errx(task, 'cannot check if we can train %s: %s', fname, err) + rspamd_logger.errx(task, 'cannot check if we can train %s:%s : %s', + rule.prefix, set.name, err) elseif tonumber(data) < 0 then rspamd_logger.infox(task, "cannot learn ANN %s:%s: too many %s samples: %s", rule.prefix, set.name, learn_type, -tonumber(data)) @@ -399,6 +430,7 @@ local function ann_train_callback(rule, task, score, required_score, set) if not set.ann then -- Need to create or load a profile corresponding to the current configuration + set.ann = new_ann_profile(task, rule, set) end -- Check if we can learn lua_redis.exec_redis_script(redis_can_train_id, @@ -704,12 +736,12 @@ end -- This function loads new ann from Redis -- This is based on `profile` attribute. --- ANN is loaded from `profile.ann_key` +-- ANN is loaded from `profile.redis_key` -- Rank of `profile` key is also increased, unfortunately, it means that we need to -- serialize profile one more time and set its rank to the current time -- set.ann fields are set according to Redis data received local function load_new_ann(rule, ev_base, set, profile, min_diff) - local ann_key = profile.ann_key + local ann_key = profile.redis_key local function data_cb(err, data) if err then @@ -732,9 +764,25 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff) version = profile.version, symbols = profile.symbols, distance = min_diff, - redis_key = profile.ann_key + redis_key = profile.redis_key } + local ucl = require "ucl" + local profile_serialized = ucl.to_format(profile, 'json-compact') + + local function rank_cb(_, _) + -- TODO: maybe add some logging + end + -- Also update rank for the loaded ANN to avoid removal + lua_redis.redis_make_request_taskless(ev_base, + rspamd_config, + rule.redis, + nil, + true, -- is write + rank_cb, --callback + 'ZADD', -- command + {set.prefix, profile_serialized, tostring(rspamd_util.get_time())} + ) rspamd_logger.infox(rspamd_config, 'loaded ANN for %s from %s; %s bytes compressed; version=%s', rule.prefix .. ':' .. set.name, ann_key, #ann_data, profile.version) else @@ -881,7 +929,7 @@ local function cleanup_anns(rule, cfg, ev_base) local profile = load_ann_profile(expired) rspamd_logger.infox(cfg, 'invalidated ANN for %s; redis key: %s; version=%s', rule.prefix .. ':' .. set.name, - profile.ann_key, + profile.redis_key, profile.version) end end @@ -941,7 +989,7 @@ local function process_rules_settings() lua_redis.register_prefix(selt.prefix, N, string.format('NN prefix for rule "%s"; settings id "%s"', - rule.prefix, selt.name)) + rule.prefix, selt.name), {persistent = true}) end for _,rule in pairs(opts.rules) do @@ -1063,7 +1111,7 @@ rspamd_config:register_symbol({ }) -- Add training scripts -for k,rule in pairs(settings.rules) do +for _,rule in pairs(settings.rules) do load_scripts(rule.redis) -- We also need to deal with settings rspamd_config:add_post_init(process_rules_settings) |