diff options
Diffstat (limited to 'src/plugins/lua/neural.lua')
-rw-r--r-- | src/plugins/lua/neural.lua | 254 |
1 files changed, 159 insertions, 95 deletions
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua index 7b6c2fa5f..0375d57cd 100644 --- a/src/plugins/lua/neural.lua +++ b/src/plugins/lua/neural.lua @@ -104,14 +104,14 @@ local redis_lua_script_can_store_train_vec = [[ if ret then nham = tonumber(ret) end if KEYS[2] == 'spam' then - if nham <= lim and nham + 1 >= nspam then - return tostring(nspam + 1) + if nspam <= lim then + return tostring(nspam) else return tostring(-(nspam)) end else - if nspam <= lim and nspam + 1 >= nham then - return tostring(nham + 1) + if nham <= lim then + return tostring(nham) else return tostring(-(nham)) end @@ -127,8 +127,9 @@ local redis_can_store_train_vec_id = nil -- key2 - number of elements to leave local redis_lua_script_maybe_invalidate = [[ local card = redis.call('ZCARD', KEYS[1]) - if card > tonumber(KEYS[2]) then - local to_delete = redis.call('ZRANGE', KEYS[1], 0, (-(tonumber(KEYS[2] - 1)))) + local lim = tonumber(KEYS[2]) + if card > lim then + local to_delete = redis.call('ZRANGE', KEYS[1], 0, card - lim - 1) for _,k in ipairs(to_delete) do local tb = cjson.decode(k) redis.call('DEL', tb.redis_key) @@ -136,7 +137,7 @@ local redis_lua_script_maybe_invalidate = [[ redis.call('DEL', tb.redis_key .. '_spam') redis.call('DEL', tb.redis_key .. '_ham') end - redis.call('ZREMRANGEBYRANK', KEYS[1], 0, (-(tonumber(KEYS[2] - 1)))) + redis.call('ZREMRANGEBYRANK', KEYS[1], 0, card - lim - 1) return to_delete else return {} @@ -152,17 +153,17 @@ local redis_maybe_invalidate_id = nil -- key4 - hostname local redis_lua_script_maybe_lock = [[ local locked = redis.call('HGET', KEYS[1], 'lock') + local now = tonumber(KEYS[2]) if locked then locked = tonumber(locked) - now = tonumber(KEYS[2]) - expire = tonumber(KEYS[3]) + local expire = tonumber(KEYS[3]) if now > locked and (now - locked) < expire then return {tostring(locked), redis.call('HGET', KEYS[1], 'hostname')} end end redis.call('HSET', KEYS[1], 'lock', tostring(now)) redis.call('HSET', KEYS[1], 'hostname', KEYS[4]) - return true + return 1 ]] local redis_maybe_lock_id = nil @@ -178,6 +179,8 @@ local redis_lua_script_save_unlock = [[ local now = tonumber(KEYS[6]) redis.call('ZADD', KEYS[2], now, KEYS[4]) redis.call('HSET', KEYS[1], 'ann', KEYS[3]) + redis.call('DEL', KEYS[1] .. '_spam') + edis.call('DEL', KEYS[1] .. '_ham') redis.call('HDEL', KEYS[1], 'lock') redis.call('EXPIRE', KEYS[1], tonumber(KEYS[5])) return 1 @@ -267,7 +270,7 @@ local function new_ann_profile(task, rule, set, version) } local ucl = require "ucl" - local profile_serialized = ucl.to_format(profile, 'json-compact') + local profile_serialized = ucl.to_format(profile, 'json-compact', true) local function add_cb(err, _) if err then @@ -322,7 +325,8 @@ local function ann_scores_filter(task) score = out[1] local symscore = string.format('%.3f', score) - rspamd_logger.infox(task, '%s ann score: %s', rule.name, symscore) + lua_util.debugm(N, task, '%s:%s ann score: %s', + rule.prefix, set.name, symscore) if score > 0 then local result = score @@ -348,26 +352,44 @@ end local function ann_push_task_result(rule, task, verdict, score, set) local train_opts = rule.train - - local learn_spam, learn_ham + local skip_reason = 'unknown' if train_opts.autotrain then - if verdict == 'passthrough' or verdict == 'uncertain' then + if verdict == 'passthrough' then lua_util.debugm(N, task, 'ignore task as its verdict is %s(%s)', verdict, score) end - if train_opts['spam_score'] then - learn_spam = score >= train_opts['spam_score'] + if train_opts.spam_score then + learn_spam = score >= train_opts.spam_score + + if not learn_spam then + skip_reason = string.format('score < spam_score: %f < %f', + score, train_opts.spam_score) + end else learn_spam = verdict == 'spam' or verdict == 'junk' + + if not learn_spam then + skip_reason = string.format('verdict: %s', + verdict) + end end - if train_opts['ham_score'] then - learn_ham = score <= train_opts['ham_score'] + if train_opts.ham_score then + learn_ham = score <= train_opts.ham_score + if not learn_ham then + skip_reason = string.format('score > ham_score: %f < %f', + score, train_opts.ham_score) + end else learn_ham = verdict == 'ham' + + if not learn_ham then + skip_reason = string.format('verdict: %s', + verdict) + end end else -- Train by request header @@ -378,6 +400,8 @@ local function ann_push_task_result(rule, task, verdict, score, set) learn_spam = true elseif hdr:lower() == 'ham' then learn_ham = true + else + skip_reason = string.format('no explicit header') end end end @@ -387,18 +411,8 @@ local function ann_push_task_result(rule, task, verdict, score, set) local learn_type if learn_spam then learn_type = 'spam' else learn_type = 'ham' end - local function learn_vec_cb(err) - if err then - rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s', - rule.prefix, set.name, err) - else - rspamd_logger.infox(task, "trained ANN rule %s:%s, save %s vector", - rule.prefix, set.name, learn_type) - end - end - local function can_train_cb(err, data) - if not err and tonumber(data) > 0 then + if not err and tonumber(data) >= 0 then local coin = math.random() if coin < 1.0 - train_opts.train_prob then rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin) @@ -408,6 +422,17 @@ local function ann_push_task_result(rule, task, verdict, score, set) local str = rspamd_util.zstd_compress(table.concat(vec, ';')) + local function learn_vec_cb(_err) + if _err then + rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s', + rule.prefix, set.name, _err) + else + lua_util.debugm(N, task, + "add train data for ANN rule %s:%s, save %s vector of %s elts; %s bytes compressed", + rule.prefix, set.name, learn_type, #vec, #str) + end + end + lua_redis.redis_make_request(task, rule.redis, nil, @@ -422,7 +447,7 @@ local function ann_push_task_result(rule, task, verdict, score, set) rule.prefix, set.name, err) elseif tonumber(data) < 0 then rspamd_logger.infox(task, "cannot learn ANN %s:%s: too many %s samples: %s", - rule.prefix, set.name, learn_type, -tonumber(data)) + rule.prefix, set.name, learn_type, -tonumber(data)) end end end @@ -436,6 +461,9 @@ local function ann_push_task_result(rule, task, verdict, score, set) {task = task, is_write = true}, can_train_cb, { set.ann.redis_key, learn_type, tostring(train_opts.max_trains)}) + else + lua_util.debugm(N, task, 'do not push data: train condition not satisfied; reason: %s', + skip_reason) end end @@ -481,6 +509,7 @@ local function register_lock_extender(rule, set, ev_base, ann_key) {ann_key, 'lock', '30'} ) else + lua_util.debugm(N, rspamd_config, "stop lock extension as learning_spawned is false") return false -- do not plan any more updates end @@ -537,7 +566,7 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve return out end - rule.learning_spawned = true + set.learning_spawned = true local function redis_save_cb(err) if err then @@ -559,7 +588,7 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve end local function ann_trained(err, data) - rule.learning_spawned = false + set.learning_spawned = false if err then rspamd_logger.errx(rspamd_config, 'cannot train ANN %s:%s : %s', rule.prefix, set.name, err) @@ -598,7 +627,7 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve } local ucl = require "ucl" - local profile_serialized = ucl.to_format(profile, 'json-compact') + local profile_serialized = ucl.to_format(profile, 'json-compact', true) lua_redis.exec_redis_script(redis_save_unlock_id, {ev_base = ev_base, is_write = true}, @@ -695,7 +724,7 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key) if err then rspamd_logger.errx(rspamd_config, 'cannot call lock script for ANN %s from redis: %s', ann_key, err) - elseif type(data) == 'boolean' and data then + elseif type(data) == 'number' and data == 1 then -- ANN is locked, so we can extract SPAM and HAM vectors and spawn learning lua_redis.redis_make_request_taskless(ev_base, rspamd_config, @@ -752,47 +781,52 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff) rspamd_logger.errx(rspamd_config, 'cannot get ANN data from key: %s; %s', ann_key, err) else - local _err,ann_data = rspamd_util.zstd_decompress(data[1]) - local ann - - if _err or not ann_data then - rspamd_logger.errx(rspamd_config, 'cannot decompress ANN for %s from Redis key %s: %s', - rule.prefix .. ':' .. set.name, ann_key, _err) - return - else - ann = rspamd_kann.load(ann_data) - - if ann then - set.ann = { - ann = ann, - version = profile.version, - symbols = profile.symbols, - distance = min_diff, - redis_key = profile.redis_key - } + if type(data) == 'string' then + local _err,ann_data = rspamd_util.zstd_decompress(data) + local ann - local ucl = require "ucl" - local profile_serialized = ucl.to_format(profile, 'json-compact') - - local function rank_cb(_, _) - -- TODO: maybe add some logging - end - -- Also update rank for the loaded ANN to avoid removal - lua_redis.redis_make_request_taskless(ev_base, - rspamd_config, - rule.redis, - nil, - true, -- is write - rank_cb, --callback - 'ZADD', -- command - {set.prefix, tostring(rspamd_util.get_time()), profile_serialized} - ) - rspamd_logger.infox(rspamd_config, 'loaded ANN for %s from %s; %s bytes compressed; version=%s', - rule.prefix .. ':' .. set.name, ann_key, #ann_data, profile.version) + if _err or not ann_data then + rspamd_logger.errx(rspamd_config, 'cannot decompress ANN for %s from Redis key %s: %s', + rule.prefix .. ':' .. set.name, ann_key, _err) + return else - rspamd_logger.errx(rspamd_config, 'cannot deserialize ANN for %s from Redis key %s', - rule.prefix .. ':' .. set.name, ann_key) + ann = rspamd_kann.load(ann_data) + + if ann then + set.ann = { + ann = ann, + version = profile.version, + symbols = profile.symbols, + distance = min_diff, + redis_key = profile.redis_key + } + + local ucl = require "ucl" + local profile_serialized = ucl.to_format(profile, 'json-compact', true) + + local function rank_cb(_, _) + -- TODO: maybe add some logging + end + -- Also update rank for the loaded ANN to avoid removal + lua_redis.redis_make_request_taskless(ev_base, + rspamd_config, + rule.redis, + nil, + true, -- is write + rank_cb, --callback + 'ZADD', -- command + {set.prefix, tostring(rspamd_util.get_time()), profile_serialized} + ) + rspamd_logger.infox(rspamd_config, 'loaded ANN for %s:%s from %s; %s bytes compressed; version=%s', + rule.prefix, set.name, ann_key, #ann_data, profile.version) + else + rspamd_logger.errx(rspamd_config, 'cannot deserialize ANN for %s:%s from Redis key %s', + rule.prefix, set.name, ann_key) + end end + else + lua_util.debugm(N, rspamd_config, 'no ANN for %s:%s in Redis key %s', + rule.prefix, set.name, ann_key) end end end @@ -803,8 +837,7 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff) false, -- is write data_cb, --callback 'HGET', -- command - {ann_key, 'ann'}, -- arguments - {opaque_data = true} + {ann_key, 'ann'} -- arguments ) end @@ -900,23 +933,46 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles) -- We have our ANN and that's train vectors, check if we can learn local ann_key = sel_elt.redis_key - lua_util.debugm(N, rspamd_config, "check ANN %s", ann_key) - local redis_len_cb = function(err, data) - if err then - rspamd_logger.errx(rspamd_config, - 'cannot get FANN trains %s from redis: %s', ann_key, err) - elseif data and type(data) == 'number' or type(data) == 'string' then - if tonumber(data) and tonumber(data) >= rule.train.max_trains then - rspamd_logger.infox(rspamd_config, - 'need to learn ANN %s after %s learn vectors (%s required)', - ann_key, tonumber(data), rule.train.max_trains) - do_train_ann(worker, ev_base, rule, set, ann_key) - else - rspamd_logger.debugm(N, rspamd_config, - 'no need to learn ANN %s %s learn vectors (%s required)', - ann_key, tonumber(data), rule.train.max_trains) + lua_util.debugm(N, rspamd_config, "check if ANN %s needs to be trained", + ann_key) + + -- Create continuation closure + local redis_len_cb_gen = function(cont_cb) + return function(err, data) + if err then + rspamd_logger.errx(rspamd_config, + 'cannot get ANN trains %s from redis: %s', ann_key, err) + elseif data and type(data) == 'number' or type(data) == 'string' then + if tonumber(data) and tonumber(data) >= rule.train.max_trains then + cont_cb() + else + rspamd_logger.debugm(N, rspamd_config, + 'no need to learn ANN %s %s learn vectors (%s required)', + ann_key, tonumber(data), rule.train.max_trains) + end end end + + end + + local function initiate_train() + rspamd_logger.infox(rspamd_config, + 'need to learn ANN %s after %s learn vectors (%s required)', + ann_key, tonumber(data), rule.train.max_trains) + do_train_ann(worker, ev_base, rule, set, ann_key) + end + + -- Spam vector is OK, check ham vector length + local function check_ham_len() + lua_redis.redis_make_request_taskless(ev_base, + rspamd_config, + rule.redis, + nil, + false, -- is write + redis_len_cb_gen(initiate_train), --callback + 'LLEN', -- command + {ann_key .. '_ham'} + ) end lua_redis.redis_make_request_taskless(ev_base, @@ -924,7 +980,7 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles) rule.redis, nil, false, -- is write - redis_len_cb, --callback + redis_len_cb_gen(check_ham_len), --callback 'LLEN', -- command {ann_key .. '_spam'} ) @@ -1005,14 +1061,22 @@ local function cleanup_anns(rule, cfg, ev_base) end local function ann_push_vector(task) - if task:has_flag('skip') then return end - if not settings.allow_local and lua_util.is_rspamc_or_controller(task) then return end + if task:has_flag('skip') then + lua_util.debugm(N, task, 'do not push data for skipped task') + return + end + if not settings.allow_local and lua_util.is_rspamc_or_controller(task) then + lua_util.debugm(N, task, 'do not push data for manual scan') + return + end local verdict,score = lua_util.get_task_verdict(task) for _,rule in pairs(settings.rules) do local set = get_rule_settings(task, rule) if set then ann_push_task_result(rule, task, verdict, score, set) + else + lua_util.debugm(N, task, 'settings not found in rule %s', rule.prefix) end end @@ -1064,7 +1128,7 @@ local function process_rules_settings() if rule.default then local default_settings = { - symbols = lua_util.keys(lua_settings.default_symbols()), + symbols = lua_settings.default_symbols(), name = 'default' } @@ -1099,7 +1163,7 @@ local function process_rules_settings() if nelt then rule.settings[s] = nelt - lua_util.debugm(N, rspamd_config, 'added new settings id %s to %s; same symbols', + lua_util.debugm(N, rspamd_config, 'added new settings id %s to %s', nelt.name, rule.prefix) end end |