Browse Source

[Minor] Neural: Moar fixes

tags/2.0
Vsevolod Stakhov 5 years ago
parent
commit
06664d00f8
1 changed files with 159 additions and 95 deletions
  1. 159
    95
      src/plugins/lua/neural.lua

+ 159
- 95
src/plugins/lua/neural.lua View File

@@ -104,14 +104,14 @@ local redis_lua_script_can_store_train_vec = [[
if ret then nham = tonumber(ret) end

if KEYS[2] == 'spam' then
if nham <= lim and nham + 1 >= nspam then
return tostring(nspam + 1)
if nspam <= lim then
return tostring(nspam)
else
return tostring(-(nspam))
end
else
if nspam <= lim and nspam + 1 >= nham then
return tostring(nham + 1)
if nham <= lim then
return tostring(nham)
else
return tostring(-(nham))
end
@@ -127,8 +127,9 @@ local redis_can_store_train_vec_id = nil
-- key2 - number of elements to leave
local redis_lua_script_maybe_invalidate = [[
local card = redis.call('ZCARD', KEYS[1])
if card > tonumber(KEYS[2]) then
local to_delete = redis.call('ZRANGE', KEYS[1], 0, (-(tonumber(KEYS[2] - 1))))
local lim = tonumber(KEYS[2])
if card > lim then
local to_delete = redis.call('ZRANGE', KEYS[1], 0, card - lim - 1)
for _,k in ipairs(to_delete) do
local tb = cjson.decode(k)
redis.call('DEL', tb.redis_key)
@@ -136,7 +137,7 @@ local redis_lua_script_maybe_invalidate = [[
redis.call('DEL', tb.redis_key .. '_spam')
redis.call('DEL', tb.redis_key .. '_ham')
end
redis.call('ZREMRANGEBYRANK', KEYS[1], 0, (-(tonumber(KEYS[2] - 1))))
redis.call('ZREMRANGEBYRANK', KEYS[1], 0, card - lim - 1)
return to_delete
else
return {}
@@ -152,17 +153,17 @@ local redis_maybe_invalidate_id = nil
-- key4 - hostname
local redis_lua_script_maybe_lock = [[
local locked = redis.call('HGET', KEYS[1], 'lock')
local now = tonumber(KEYS[2])
if locked then
locked = tonumber(locked)
now = tonumber(KEYS[2])
expire = tonumber(KEYS[3])
local expire = tonumber(KEYS[3])
if now > locked and (now - locked) < expire then
return {tostring(locked), redis.call('HGET', KEYS[1], 'hostname')}
end
end
redis.call('HSET', KEYS[1], 'lock', tostring(now))
redis.call('HSET', KEYS[1], 'hostname', KEYS[4])
return true
return 1
]]
local redis_maybe_lock_id = nil

@@ -178,6 +179,8 @@ local redis_lua_script_save_unlock = [[
local now = tonumber(KEYS[6])
redis.call('ZADD', KEYS[2], now, KEYS[4])
redis.call('HSET', KEYS[1], 'ann', KEYS[3])
redis.call('DEL', KEYS[1] .. '_spam')
edis.call('DEL', KEYS[1] .. '_ham')
redis.call('HDEL', KEYS[1], 'lock')
redis.call('EXPIRE', KEYS[1], tonumber(KEYS[5]))
return 1
@@ -267,7 +270,7 @@ local function new_ann_profile(task, rule, set, version)
}

local ucl = require "ucl"
local profile_serialized = ucl.to_format(profile, 'json-compact')
local profile_serialized = ucl.to_format(profile, 'json-compact', true)

local function add_cb(err, _)
if err then
@@ -322,7 +325,8 @@ local function ann_scores_filter(task)
score = out[1]

local symscore = string.format('%.3f', score)
rspamd_logger.infox(task, '%s ann score: %s', rule.name, symscore)
lua_util.debugm(N, task, '%s:%s ann score: %s',
rule.prefix, set.name, symscore)

if score > 0 then
local result = score
@@ -348,26 +352,44 @@ end

local function ann_push_task_result(rule, task, verdict, score, set)
local train_opts = rule.train


local learn_spam, learn_ham
local skip_reason = 'unknown'

if train_opts.autotrain then
if verdict == 'passthrough' or verdict == 'uncertain' then
if verdict == 'passthrough' then
lua_util.debugm(N, task, 'ignore task as its verdict is %s(%s)',
verdict, score)
end

if train_opts['spam_score'] then
learn_spam = score >= train_opts['spam_score']
if train_opts.spam_score then
learn_spam = score >= train_opts.spam_score

if not learn_spam then
skip_reason = string.format('score < spam_score: %f < %f',
score, train_opts.spam_score)
end
else
learn_spam = verdict == 'spam' or verdict == 'junk'

if not learn_spam then
skip_reason = string.format('verdict: %s',
verdict)
end
end

if train_opts['ham_score'] then
learn_ham = score <= train_opts['ham_score']
if train_opts.ham_score then
learn_ham = score <= train_opts.ham_score
if not learn_ham then
skip_reason = string.format('score > ham_score: %f < %f',
score, train_opts.ham_score)
end
else
learn_ham = verdict == 'ham'

if not learn_ham then
skip_reason = string.format('verdict: %s',
verdict)
end
end
else
-- Train by request header
@@ -378,6 +400,8 @@ local function ann_push_task_result(rule, task, verdict, score, set)
learn_spam = true
elseif hdr:lower() == 'ham' then
learn_ham = true
else
skip_reason = string.format('no explicit header')
end
end
end
@@ -387,18 +411,8 @@ local function ann_push_task_result(rule, task, verdict, score, set)
local learn_type
if learn_spam then learn_type = 'spam' else learn_type = 'ham' end

local function learn_vec_cb(err)
if err then
rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s',
rule.prefix, set.name, err)
else
rspamd_logger.infox(task, "trained ANN rule %s:%s, save %s vector",
rule.prefix, set.name, learn_type)
end
end

local function can_train_cb(err, data)
if not err and tonumber(data) > 0 then
if not err and tonumber(data) >= 0 then
local coin = math.random()
if coin < 1.0 - train_opts.train_prob then
rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin)
@@ -408,6 +422,17 @@ local function ann_push_task_result(rule, task, verdict, score, set)

local str = rspamd_util.zstd_compress(table.concat(vec, ';'))

local function learn_vec_cb(_err)
if _err then
rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s',
rule.prefix, set.name, _err)
else
lua_util.debugm(N, task,
"add train data for ANN rule %s:%s, save %s vector of %s elts; %s bytes compressed",
rule.prefix, set.name, learn_type, #vec, #str)
end
end

lua_redis.redis_make_request(task,
rule.redis,
nil,
@@ -422,7 +447,7 @@ local function ann_push_task_result(rule, task, verdict, score, set)
rule.prefix, set.name, err)
elseif tonumber(data) < 0 then
rspamd_logger.infox(task, "cannot learn ANN %s:%s: too many %s samples: %s",
rule.prefix, set.name, learn_type, -tonumber(data))
rule.prefix, set.name, learn_type, -tonumber(data))
end
end
end
@@ -436,6 +461,9 @@ local function ann_push_task_result(rule, task, verdict, score, set)
{task = task, is_write = true},
can_train_cb,
{ set.ann.redis_key, learn_type, tostring(train_opts.max_trains)})
else
lua_util.debugm(N, task, 'do not push data: train condition not satisfied; reason: %s',
skip_reason)
end
end

@@ -481,6 +509,7 @@ local function register_lock_extender(rule, set, ev_base, ann_key)
{ann_key, 'lock', '30'}
)
else
lua_util.debugm(N, rspamd_config, "stop lock extension as learning_spawned is false")
return false -- do not plan any more updates
end

@@ -537,7 +566,7 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
return out
end

rule.learning_spawned = true
set.learning_spawned = true

local function redis_save_cb(err)
if err then
@@ -559,7 +588,7 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
end

local function ann_trained(err, data)
rule.learning_spawned = false
set.learning_spawned = false
if err then
rspamd_logger.errx(rspamd_config, 'cannot train ANN %s:%s : %s',
rule.prefix, set.name, err)
@@ -598,7 +627,7 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
}

local ucl = require "ucl"
local profile_serialized = ucl.to_format(profile, 'json-compact')
local profile_serialized = ucl.to_format(profile, 'json-compact', true)

lua_redis.exec_redis_script(redis_save_unlock_id,
{ev_base = ev_base, is_write = true},
@@ -695,7 +724,7 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key)
if err then
rspamd_logger.errx(rspamd_config, 'cannot call lock script for ANN %s from redis: %s',
ann_key, err)
elseif type(data) == 'boolean' and data then
elseif type(data) == 'number' and data == 1 then
-- ANN is locked, so we can extract SPAM and HAM vectors and spawn learning
lua_redis.redis_make_request_taskless(ev_base,
rspamd_config,
@@ -752,47 +781,52 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
rspamd_logger.errx(rspamd_config, 'cannot get ANN data from key: %s; %s',
ann_key, err)
else
local _err,ann_data = rspamd_util.zstd_decompress(data[1])
local ann

if _err or not ann_data then
rspamd_logger.errx(rspamd_config, 'cannot decompress ANN for %s from Redis key %s: %s',
rule.prefix .. ':' .. set.name, ann_key, _err)
return
else
ann = rspamd_kann.load(ann_data)

if ann then
set.ann = {
ann = ann,
version = profile.version,
symbols = profile.symbols,
distance = min_diff,
redis_key = profile.redis_key
}
if type(data) == 'string' then
local _err,ann_data = rspamd_util.zstd_decompress(data)
local ann

local ucl = require "ucl"
local profile_serialized = ucl.to_format(profile, 'json-compact')

local function rank_cb(_, _)
-- TODO: maybe add some logging
end
-- Also update rank for the loaded ANN to avoid removal
lua_redis.redis_make_request_taskless(ev_base,
rspamd_config,
rule.redis,
nil,
true, -- is write
rank_cb, --callback
'ZADD', -- command
{set.prefix, tostring(rspamd_util.get_time()), profile_serialized}
)
rspamd_logger.infox(rspamd_config, 'loaded ANN for %s from %s; %s bytes compressed; version=%s',
rule.prefix .. ':' .. set.name, ann_key, #ann_data, profile.version)
if _err or not ann_data then
rspamd_logger.errx(rspamd_config, 'cannot decompress ANN for %s from Redis key %s: %s',
rule.prefix .. ':' .. set.name, ann_key, _err)
return
else
rspamd_logger.errx(rspamd_config, 'cannot deserialize ANN for %s from Redis key %s',
rule.prefix .. ':' .. set.name, ann_key)
ann = rspamd_kann.load(ann_data)

if ann then
set.ann = {
ann = ann,
version = profile.version,
symbols = profile.symbols,
distance = min_diff,
redis_key = profile.redis_key
}

local ucl = require "ucl"
local profile_serialized = ucl.to_format(profile, 'json-compact', true)

local function rank_cb(_, _)
-- TODO: maybe add some logging
end
-- Also update rank for the loaded ANN to avoid removal
lua_redis.redis_make_request_taskless(ev_base,
rspamd_config,
rule.redis,
nil,
true, -- is write
rank_cb, --callback
'ZADD', -- command
{set.prefix, tostring(rspamd_util.get_time()), profile_serialized}
)
rspamd_logger.infox(rspamd_config, 'loaded ANN for %s:%s from %s; %s bytes compressed; version=%s',
rule.prefix, set.name, ann_key, #ann_data, profile.version)
else
rspamd_logger.errx(rspamd_config, 'cannot deserialize ANN for %s:%s from Redis key %s',
rule.prefix, set.name, ann_key)
end
end
else
lua_util.debugm(N, rspamd_config, 'no ANN for %s:%s in Redis key %s',
rule.prefix, set.name, ann_key)
end
end
end
@@ -803,8 +837,7 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
false, -- is write
data_cb, --callback
'HGET', -- command
{ann_key, 'ann'}, -- arguments
{opaque_data = true}
{ann_key, 'ann'} -- arguments
)
end

@@ -900,23 +933,46 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
-- We have our ANN and that's train vectors, check if we can learn
local ann_key = sel_elt.redis_key

lua_util.debugm(N, rspamd_config, "check ANN %s", ann_key)
local redis_len_cb = function(err, data)
if err then
rspamd_logger.errx(rspamd_config,
'cannot get FANN trains %s from redis: %s', ann_key, err)
elseif data and type(data) == 'number' or type(data) == 'string' then
if tonumber(data) and tonumber(data) >= rule.train.max_trains then
rspamd_logger.infox(rspamd_config,
'need to learn ANN %s after %s learn vectors (%s required)',
ann_key, tonumber(data), rule.train.max_trains)
do_train_ann(worker, ev_base, rule, set, ann_key)
else
rspamd_logger.debugm(N, rspamd_config,
'no need to learn ANN %s %s learn vectors (%s required)',
ann_key, tonumber(data), rule.train.max_trains)
lua_util.debugm(N, rspamd_config, "check if ANN %s needs to be trained",
ann_key)

-- Create continuation closure
local redis_len_cb_gen = function(cont_cb)
return function(err, data)
if err then
rspamd_logger.errx(rspamd_config,
'cannot get ANN trains %s from redis: %s', ann_key, err)
elseif data and type(data) == 'number' or type(data) == 'string' then
if tonumber(data) and tonumber(data) >= rule.train.max_trains then
cont_cb()
else
rspamd_logger.debugm(N, rspamd_config,
'no need to learn ANN %s %s learn vectors (%s required)',
ann_key, tonumber(data), rule.train.max_trains)
end
end
end

end

local function initiate_train()
rspamd_logger.infox(rspamd_config,
'need to learn ANN %s after %s learn vectors (%s required)',
ann_key, tonumber(data), rule.train.max_trains)
do_train_ann(worker, ev_base, rule, set, ann_key)
end

-- Spam vector is OK, check ham vector length
local function check_ham_len()
lua_redis.redis_make_request_taskless(ev_base,
rspamd_config,
rule.redis,
nil,
false, -- is write
redis_len_cb_gen(initiate_train), --callback
'LLEN', -- command
{ann_key .. '_ham'}
)
end

lua_redis.redis_make_request_taskless(ev_base,
@@ -924,7 +980,7 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
rule.redis,
nil,
false, -- is write
redis_len_cb, --callback
redis_len_cb_gen(check_ham_len), --callback
'LLEN', -- command
{ann_key .. '_spam'}
)
@@ -1005,14 +1061,22 @@ local function cleanup_anns(rule, cfg, ev_base)
end

local function ann_push_vector(task)
if task:has_flag('skip') then return end
if not settings.allow_local and lua_util.is_rspamc_or_controller(task) then return end
if task:has_flag('skip') then
lua_util.debugm(N, task, 'do not push data for skipped task')
return
end
if not settings.allow_local and lua_util.is_rspamc_or_controller(task) then
lua_util.debugm(N, task, 'do not push data for manual scan')
return
end
local verdict,score = lua_util.get_task_verdict(task)
for _,rule in pairs(settings.rules) do
local set = get_rule_settings(task, rule)

if set then
ann_push_task_result(rule, task, verdict, score, set)
else
lua_util.debugm(N, task, 'settings not found in rule %s', rule.prefix)
end

end
@@ -1064,7 +1128,7 @@ local function process_rules_settings()

if rule.default then
local default_settings = {
symbols = lua_util.keys(lua_settings.default_symbols()),
symbols = lua_settings.default_symbols(),
name = 'default'
}

@@ -1099,7 +1163,7 @@ local function process_rules_settings()

if nelt then
rule.settings[s] = nelt
lua_util.debugm(N, rspamd_config, 'added new settings id %s to %s; same symbols',
lua_util.debugm(N, rspamd_config, 'added new settings id %s to %s',
nelt.name, rule.prefix)
end
end

Loading…
Cancel
Save