Browse Source

[Minor] Neural: Moar fixes

tags/2.0
Vsevolod Stakhov 5 years ago
parent
commit
06664d00f8
1 changed files with 159 additions and 95 deletions
  1. 159
    95
      src/plugins/lua/neural.lua

+ 159
- 95
src/plugins/lua/neural.lua View File

if ret then nham = tonumber(ret) end if ret then nham = tonumber(ret) end


if KEYS[2] == 'spam' then if KEYS[2] == 'spam' then
if nham <= lim and nham + 1 >= nspam then
return tostring(nspam + 1)
if nspam <= lim then
return tostring(nspam)
else else
return tostring(-(nspam)) return tostring(-(nspam))
end end
else else
if nspam <= lim and nspam + 1 >= nham then
return tostring(nham + 1)
if nham <= lim then
return tostring(nham)
else else
return tostring(-(nham)) return tostring(-(nham))
end end
-- key2 - number of elements to leave -- key2 - number of elements to leave
local redis_lua_script_maybe_invalidate = [[ local redis_lua_script_maybe_invalidate = [[
local card = redis.call('ZCARD', KEYS[1]) local card = redis.call('ZCARD', KEYS[1])
if card > tonumber(KEYS[2]) then
local to_delete = redis.call('ZRANGE', KEYS[1], 0, (-(tonumber(KEYS[2] - 1))))
local lim = tonumber(KEYS[2])
if card > lim then
local to_delete = redis.call('ZRANGE', KEYS[1], 0, card - lim - 1)
for _,k in ipairs(to_delete) do for _,k in ipairs(to_delete) do
local tb = cjson.decode(k) local tb = cjson.decode(k)
redis.call('DEL', tb.redis_key) redis.call('DEL', tb.redis_key)
redis.call('DEL', tb.redis_key .. '_spam') redis.call('DEL', tb.redis_key .. '_spam')
redis.call('DEL', tb.redis_key .. '_ham') redis.call('DEL', tb.redis_key .. '_ham')
end end
redis.call('ZREMRANGEBYRANK', KEYS[1], 0, (-(tonumber(KEYS[2] - 1))))
redis.call('ZREMRANGEBYRANK', KEYS[1], 0, card - lim - 1)
return to_delete return to_delete
else else
return {} return {}
-- key4 - hostname -- key4 - hostname
local redis_lua_script_maybe_lock = [[ local redis_lua_script_maybe_lock = [[
local locked = redis.call('HGET', KEYS[1], 'lock') local locked = redis.call('HGET', KEYS[1], 'lock')
local now = tonumber(KEYS[2])
if locked then if locked then
locked = tonumber(locked) locked = tonumber(locked)
now = tonumber(KEYS[2])
expire = tonumber(KEYS[3])
local expire = tonumber(KEYS[3])
if now > locked and (now - locked) < expire then if now > locked and (now - locked) < expire then
return {tostring(locked), redis.call('HGET', KEYS[1], 'hostname')} return {tostring(locked), redis.call('HGET', KEYS[1], 'hostname')}
end end
end end
redis.call('HSET', KEYS[1], 'lock', tostring(now)) redis.call('HSET', KEYS[1], 'lock', tostring(now))
redis.call('HSET', KEYS[1], 'hostname', KEYS[4]) redis.call('HSET', KEYS[1], 'hostname', KEYS[4])
return true
return 1
]] ]]
local redis_maybe_lock_id = nil local redis_maybe_lock_id = nil


local now = tonumber(KEYS[6]) local now = tonumber(KEYS[6])
redis.call('ZADD', KEYS[2], now, KEYS[4]) redis.call('ZADD', KEYS[2], now, KEYS[4])
redis.call('HSET', KEYS[1], 'ann', KEYS[3]) redis.call('HSET', KEYS[1], 'ann', KEYS[3])
redis.call('DEL', KEYS[1] .. '_spam')
redis.call('DEL', KEYS[1] .. '_ham')
redis.call('HDEL', KEYS[1], 'lock') redis.call('HDEL', KEYS[1], 'lock')
redis.call('EXPIRE', KEYS[1], tonumber(KEYS[5])) redis.call('EXPIRE', KEYS[1], tonumber(KEYS[5]))
return 1 return 1
} }


local ucl = require "ucl" local ucl = require "ucl"
local profile_serialized = ucl.to_format(profile, 'json-compact')
local profile_serialized = ucl.to_format(profile, 'json-compact', true)


local function add_cb(err, _) local function add_cb(err, _)
if err then if err then
score = out[1] score = out[1]


local symscore = string.format('%.3f', score) local symscore = string.format('%.3f', score)
rspamd_logger.infox(task, '%s ann score: %s', rule.name, symscore)
lua_util.debugm(N, task, '%s:%s ann score: %s',
rule.prefix, set.name, symscore)


if score > 0 then if score > 0 then
local result = score local result = score


local function ann_push_task_result(rule, task, verdict, score, set) local function ann_push_task_result(rule, task, verdict, score, set)
local train_opts = rule.train local train_opts = rule.train


local learn_spam, learn_ham local learn_spam, learn_ham
local skip_reason = 'unknown'


if train_opts.autotrain then if train_opts.autotrain then
if verdict == 'passthrough' or verdict == 'uncertain' then
if verdict == 'passthrough' then
lua_util.debugm(N, task, 'ignore task as its verdict is %s(%s)', lua_util.debugm(N, task, 'ignore task as its verdict is %s(%s)',
verdict, score) verdict, score)
end end


if train_opts['spam_score'] then
learn_spam = score >= train_opts['spam_score']
if train_opts.spam_score then
learn_spam = score >= train_opts.spam_score

if not learn_spam then
skip_reason = string.format('score < spam_score: %f < %f',
score, train_opts.spam_score)
end
else else
learn_spam = verdict == 'spam' or verdict == 'junk' learn_spam = verdict == 'spam' or verdict == 'junk'

if not learn_spam then
skip_reason = string.format('verdict: %s',
verdict)
end
end end


if train_opts['ham_score'] then
learn_ham = score <= train_opts['ham_score']
if train_opts.ham_score then
learn_ham = score <= train_opts.ham_score
if not learn_ham then
  -- This branch runs only when `score <= train_opts.ham_score` is false, so
  -- the printed comparison must use `>` to agree with the message text.
  skip_reason = string.format('score > ham_score: %f > %f',
    score, train_opts.ham_score)
end
else else
learn_ham = verdict == 'ham' learn_ham = verdict == 'ham'

if not learn_ham then
skip_reason = string.format('verdict: %s',
verdict)
end
end end
else else
-- Train by request header -- Train by request header
learn_spam = true learn_spam = true
elseif hdr:lower() == 'ham' then elseif hdr:lower() == 'ham' then
learn_ham = true learn_ham = true
else
skip_reason = string.format('no explicit header')
end end
end end
end end
local learn_type local learn_type
if learn_spam then learn_type = 'spam' else learn_type = 'ham' end if learn_spam then learn_type = 'spam' else learn_type = 'ham' end


-- Reports the outcome of storing a train vector: an error line when `err` is
-- set, otherwise an info line naming the rule prefix, settings name and
-- learn type.
local function learn_vec_cb(err)
  local prefix, set_name = rule.prefix, set.name

  if not err then
    rspamd_logger.infox(task, "trained ANN rule %s:%s, save %s vector",
      prefix, set_name, learn_type)
    return
  end

  rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s',
    prefix, set_name, err)
end

local function can_train_cb(err, data) local function can_train_cb(err, data)
if not err and tonumber(data) > 0 then
if not err and tonumber(data) >= 0 then
local coin = math.random() local coin = math.random()
if coin < 1.0 - train_opts.train_prob then if coin < 1.0 - train_opts.train_prob then
rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin) rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin)


local str = rspamd_util.zstd_compress(table.concat(vec, ';')) local str = rspamd_util.zstd_compress(table.concat(vec, ';'))


-- Reports the outcome of storing a train vector. On failure the error is
-- logged; on success a debug line records the learn type, the number of
-- vector elements and the size of the compressed payload.
local function learn_vec_cb(_err)
  local prefix, set_name = rule.prefix, set.name

  if not _err then
    lua_util.debugm(N, task,
      "add train data for ANN rule %s:%s, save %s vector of %s elts; %s bytes compressed",
      prefix, set_name, learn_type, #vec, #str)
    return
  end

  rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s',
    prefix, set_name, _err)
end

lua_redis.redis_make_request(task, lua_redis.redis_make_request(task,
rule.redis, rule.redis,
nil, nil,
rule.prefix, set.name, err) rule.prefix, set.name, err)
elseif tonumber(data) < 0 then elseif tonumber(data) < 0 then
rspamd_logger.infox(task, "cannot learn ANN %s:%s: too many %s samples: %s", rspamd_logger.infox(task, "cannot learn ANN %s:%s: too many %s samples: %s",
rule.prefix, set.name, learn_type, -tonumber(data))
rule.prefix, set.name, learn_type, -tonumber(data))
end end
end end
end end
{task = task, is_write = true}, {task = task, is_write = true},
can_train_cb, can_train_cb,
{ set.ann.redis_key, learn_type, tostring(train_opts.max_trains)}) { set.ann.redis_key, learn_type, tostring(train_opts.max_trains)})
else
lua_util.debugm(N, task, 'do not push data: train condition not satisfied; reason: %s',
skip_reason)
end end
end end


{ann_key, 'lock', '30'} {ann_key, 'lock', '30'}
) )
else else
lua_util.debugm(N, rspamd_config, "stop lock extension as learning_spawned is false")
return false -- do not plan any more updates return false -- do not plan any more updates
end end


return out return out
end end


rule.learning_spawned = true
set.learning_spawned = true


local function redis_save_cb(err) local function redis_save_cb(err)
if err then if err then
end end


local function ann_trained(err, data) local function ann_trained(err, data)
rule.learning_spawned = false
set.learning_spawned = false
if err then if err then
rspamd_logger.errx(rspamd_config, 'cannot train ANN %s:%s : %s', rspamd_logger.errx(rspamd_config, 'cannot train ANN %s:%s : %s',
rule.prefix, set.name, err) rule.prefix, set.name, err)
} }


local ucl = require "ucl" local ucl = require "ucl"
local profile_serialized = ucl.to_format(profile, 'json-compact')
local profile_serialized = ucl.to_format(profile, 'json-compact', true)


lua_redis.exec_redis_script(redis_save_unlock_id, lua_redis.exec_redis_script(redis_save_unlock_id,
{ev_base = ev_base, is_write = true}, {ev_base = ev_base, is_write = true},
if err then if err then
rspamd_logger.errx(rspamd_config, 'cannot call lock script for ANN %s from redis: %s', rspamd_logger.errx(rspamd_config, 'cannot call lock script for ANN %s from redis: %s',
ann_key, err) ann_key, err)
elseif type(data) == 'boolean' and data then
elseif type(data) == 'number' and data == 1 then
-- ANN is locked, so we can extract SPAM and HAM vectors and spawn learning -- ANN is locked, so we can extract SPAM and HAM vectors and spawn learning
lua_redis.redis_make_request_taskless(ev_base, lua_redis.redis_make_request_taskless(ev_base,
rspamd_config, rspamd_config,
rspamd_logger.errx(rspamd_config, 'cannot get ANN data from key: %s; %s', rspamd_logger.errx(rspamd_config, 'cannot get ANN data from key: %s; %s',
ann_key, err) ann_key, err)
else else
local _err,ann_data = rspamd_util.zstd_decompress(data[1])
local ann

if _err or not ann_data then
rspamd_logger.errx(rspamd_config, 'cannot decompress ANN for %s from Redis key %s: %s',
rule.prefix .. ':' .. set.name, ann_key, _err)
return
else
ann = rspamd_kann.load(ann_data)

if ann then
set.ann = {
ann = ann,
version = profile.version,
symbols = profile.symbols,
distance = min_diff,
redis_key = profile.redis_key
}
if type(data) == 'string' then
local _err,ann_data = rspamd_util.zstd_decompress(data)
local ann


local ucl = require "ucl"
local profile_serialized = ucl.to_format(profile, 'json-compact')

-- No-op callback handed to the ZADD request below; both arguments are
-- intentionally ignored.
local function rank_cb(_redis_err, _reply)
  -- TODO: maybe add some logging
end
-- Also update rank for the loaded ANN to avoid removal
lua_redis.redis_make_request_taskless(ev_base,
rspamd_config,
rule.redis,
nil,
true, -- is write
rank_cb, --callback
'ZADD', -- command
{set.prefix, tostring(rspamd_util.get_time()), profile_serialized}
)
rspamd_logger.infox(rspamd_config, 'loaded ANN for %s from %s; %s bytes compressed; version=%s',
rule.prefix .. ':' .. set.name, ann_key, #ann_data, profile.version)
if _err or not ann_data then
rspamd_logger.errx(rspamd_config, 'cannot decompress ANN for %s from Redis key %s: %s',
rule.prefix .. ':' .. set.name, ann_key, _err)
return
else else
rspamd_logger.errx(rspamd_config, 'cannot deserialize ANN for %s from Redis key %s',
rule.prefix .. ':' .. set.name, ann_key)
ann = rspamd_kann.load(ann_data)

if ann then
set.ann = {
ann = ann,
version = profile.version,
symbols = profile.symbols,
distance = min_diff,
redis_key = profile.redis_key
}

local ucl = require "ucl"
local profile_serialized = ucl.to_format(profile, 'json-compact', true)

-- No-op callback handed to the ZADD request below; both arguments are
-- intentionally ignored.
local function rank_cb(_redis_err, _reply)
  -- TODO: maybe add some logging
end
-- Also update rank for the loaded ANN to avoid removal
lua_redis.redis_make_request_taskless(ev_base,
rspamd_config,
rule.redis,
nil,
true, -- is write
rank_cb, --callback
'ZADD', -- command
{set.prefix, tostring(rspamd_util.get_time()), profile_serialized}
)
rspamd_logger.infox(rspamd_config, 'loaded ANN for %s:%s from %s; %s bytes compressed; version=%s',
rule.prefix, set.name, ann_key, #ann_data, profile.version)
else
rspamd_logger.errx(rspamd_config, 'cannot deserialize ANN for %s:%s from Redis key %s',
rule.prefix, set.name, ann_key)
end
end end
else
lua_util.debugm(N, rspamd_config, 'no ANN for %s:%s in Redis key %s',
rule.prefix, set.name, ann_key)
end end
end end
end end
false, -- is write false, -- is write
data_cb, --callback data_cb, --callback
'HGET', -- command 'HGET', -- command
{ann_key, 'ann'}, -- arguments
{opaque_data = true}
{ann_key, 'ann'} -- arguments
) )
end end


-- We have our ANN and that's train vectors, check if we can learn -- We have our ANN and that's train vectors, check if we can learn
local ann_key = sel_elt.redis_key local ann_key = sel_elt.redis_key


lua_util.debugm(N, rspamd_config, "check ANN %s", ann_key)
local redis_len_cb = function(err, data)
if err then
rspamd_logger.errx(rspamd_config,
'cannot get FANN trains %s from redis: %s', ann_key, err)
elseif data and type(data) == 'number' or type(data) == 'string' then
if tonumber(data) and tonumber(data) >= rule.train.max_trains then
rspamd_logger.infox(rspamd_config,
'need to learn ANN %s after %s learn vectors (%s required)',
ann_key, tonumber(data), rule.train.max_trains)
do_train_ann(worker, ev_base, rule, set, ann_key)
else
rspamd_logger.debugm(N, rspamd_config,
'no need to learn ANN %s %s learn vectors (%s required)',
ann_key, tonumber(data), rule.train.max_trains)
lua_util.debugm(N, rspamd_config, "check if ANN %s needs to be trained",
ann_key)

-- Create continuation closure
local redis_len_cb_gen = function(cont_cb)
return function(err, data)
if err then
rspamd_logger.errx(rspamd_config,
'cannot get ANN trains %s from redis: %s', ann_key, err)
elseif data and type(data) == 'number' or type(data) == 'string' then
if tonumber(data) and tonumber(data) >= rule.train.max_trains then
cont_cb()
else
rspamd_logger.debugm(N, rspamd_config,
'no need to learn ANN %s %s learn vectors (%s required)',
ann_key, tonumber(data), rule.train.max_trains)
end
end end
end end

end

-- Logs that the ANN stored under `ann_key` needs learning and then starts
-- training through do_train_ann. Takes no arguments and returns nothing.
-- NOTE(review): `data` is neither a parameter nor a local of this function.
-- Unless an enclosing scope outside this view defines it, tonumber(data)
-- is nil and the logged vector count is meaningless -- confirm, and consider
-- passing the count in from the length callback.
local function initiate_train()
rspamd_logger.infox(rspamd_config,
'need to learn ANN %s after %s learn vectors (%s required)',
ann_key, tonumber(data), rule.train.max_trains)
do_train_ann(worker, ev_base, rule, set, ann_key)
end

-- Spam vector is OK, check ham vector length
local function check_ham_len()
lua_redis.redis_make_request_taskless(ev_base,
rspamd_config,
rule.redis,
nil,
false, -- is write
redis_len_cb_gen(initiate_train), --callback
'LLEN', -- command
{ann_key .. '_ham'}
)
end end


lua_redis.redis_make_request_taskless(ev_base, lua_redis.redis_make_request_taskless(ev_base,
rule.redis, rule.redis,
nil, nil,
false, -- is write false, -- is write
redis_len_cb, --callback
redis_len_cb_gen(check_ham_len), --callback
'LLEN', -- command 'LLEN', -- command
{ann_key .. '_spam'} {ann_key .. '_spam'}
) )
end end


local function ann_push_vector(task) local function ann_push_vector(task)
if task:has_flag('skip') then return end
if not settings.allow_local and lua_util.is_rspamc_or_controller(task) then return end
if task:has_flag('skip') then
lua_util.debugm(N, task, 'do not push data for skipped task')
return
end
if not settings.allow_local and lua_util.is_rspamc_or_controller(task) then
lua_util.debugm(N, task, 'do not push data for manual scan')
return
end
local verdict,score = lua_util.get_task_verdict(task) local verdict,score = lua_util.get_task_verdict(task)
for _,rule in pairs(settings.rules) do for _,rule in pairs(settings.rules) do
local set = get_rule_settings(task, rule) local set = get_rule_settings(task, rule)


if set then if set then
ann_push_task_result(rule, task, verdict, score, set) ann_push_task_result(rule, task, verdict, score, set)
else
lua_util.debugm(N, task, 'settings not found in rule %s', rule.prefix)
end end


end end


if rule.default then if rule.default then
local default_settings = { local default_settings = {
symbols = lua_util.keys(lua_settings.default_symbols()),
symbols = lua_settings.default_symbols(),
name = 'default' name = 'default'
} }




if nelt then if nelt then
rule.settings[s] = nelt rule.settings[s] = nelt
lua_util.debugm(N, rspamd_config, 'added new settings id %s to %s; same symbols',
lua_util.debugm(N, rspamd_config, 'added new settings id %s to %s',
nelt.name, rule.prefix) nelt.name, rule.prefix)
end end
end end

Loading…
Cancel
Save