From 73d2cee82a5d55a239c628c21454137027a29db2 Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Wed, 15 May 2019 14:09:40 +0100 Subject: [PATCH] [Project] Reputation: Migrate to adaptive EMA model --- src/plugins/lua/reputation.lua | 371 +++++++++++++-------------------- 1 file changed, 145 insertions(+), 226 deletions(-) diff --git a/src/plugins/lua/reputation.lua b/src/plugins/lua/reputation.lua index ad05023be..f1062dbaa 100644 --- a/src/plugins/lua/reputation.lua +++ b/src/plugins/lua/reputation.lua @@ -35,12 +35,8 @@ local ts = require("tableshape").types local redis_params = nil local default_expiry = 864000 -- 10 day by default +local default_prefix = 'RR:' -- Rspamd Reputation -local keymap_schema = ts.shape{ - ['spam'] = ts.string, - ['junk'] = ts.string, - ['ham'] = ts.string, -} -- Get reputation from ham/spam/probable hits local function generic_reputation_calc(token, rule, mult) @@ -50,16 +46,11 @@ local function generic_reputation_calc(token, rule, mult) return cfg.score_calc_func(rule, token, mult) end - local ham_samples = token.h or 0 - local spam_samples = token.s or 0 - local probable_samples = token.p or 0 - local total_samples = ham_samples + spam_samples + probable_samples - - if total_samples < cfg.lower_bound then return 0 end + if token[1] < cfg.lower_bound then return 0 end - local score = (ham_samples / total_samples) * -1.0 + - (spam_samples / total_samples) + - (probable_samples / total_samples) * 0.5 + local score = fun.foldl(function(acc, v) + return acc + v + end, 0.0, fun.map(tonumber, token[2])) / #token[2] return score end @@ -79,6 +70,38 @@ local function add_symbol_score(task, rule, mult, params) end end +local function sub_symbol_score(task, rule, score) + local function sym_score(sym) + local s = task:get_symbol(sym)[1] + return s.score + end + if rule.config.split_symbols then + local spam_sym = rule.symbol .. '_SPAM' + local ham_sym = rule.symbol .. '_HAM' + + if task:has_symbol(spam_sym) then + score = score - sym_score(spam_sym) + elseif task:has_symbol(ham_sym) then + score = score - sym_score(ham_sym) + end + else + if task:has_symbol(rule.symbol) then + score = score - sym_score(rule.symbol) + end + end + + return score +end + +-- Extracts task score and subtracts score of the rule itself +local function extract_task_score(task, rule) + local _,score = lua_util.get_task_verdict(task) + + if not score then return nil end + + return sub_symbol_score(task, rule, score) +end + -- DKIM Selector functions local gr local function gen_dkim_queries(task, rule) @@ -164,28 +187,14 @@ local function dkim_reputation_filter(task, rule) end local function dkim_reputation_idempotent(task, rule) - local verdict = lua_util.get_task_verdict(task) - local token = { - } - local cfg = rule.selector.config - local need_set = false - - -- TODO: take metric score into consideration - local k = cfg.keys_map[verdict] - - if k then - token[k] = 1.0 - need_set = true - end - - if need_set then - - local requests = gen_dkim_queries(task, rule) + local requests = gen_dkim_queries(task, rule) + local sc = extract_task_score(task, rule) + if sc then for dom,res in pairs(requests) do -- tld + "." + check_result, e.g. example.com.+ - reputation for valid sigs local query = string.format('%s.%s', dom, res) - rule.backend.set_token(task, rule, query, token) + rule.backend.set_token(task, rule, query, sc) end end end @@ -212,15 +221,6 @@ end local dkim_selector = { config = { - -- keys map between actions and hash elements in bucket, - -- h is for ham, - -- s is for spam, - -- p is for probable spam - keys_map = { - ['spam'] = 's', - ['junk'] = 'p', - ['ham'] = 'h' - }, symbol = 'DKIM_SCORE', -- symbol to be inserted lower_bound = 10, -- minimum number of messages to be scored min_score = nil, @@ -270,7 +270,7 @@ local function gen_url_queries(task, rule) end local function url_reputation_filter(task, rule) - local requests = gen_url_queries(task, rule) + local requests = lua_util.extract_specific_urls(task, rule.selector.config.max_urls) local results = {} local nchecked = 0 @@ -304,47 +304,24 @@ local function url_reputation_filter(task, rule) end end - for _,tld in ipairs(requests) do - rule.backend.get_token(task, rule, tld[1], tokens_cb) + for _,u in ipairs(requests) do + rule.backend.get_token(task, rule, u:get_tld(), tokens_cb) end end local function url_reputation_idempotent(task, rule) - local verdict = lua_util.get_task_verdict(task) - local token = { - } - local cfg = rule.selector.config - local need_set = false - - -- TODO: take metric score into consideration - local k = cfg.keys_map[verdict] - - if k then - token[k] = 1.0 - need_set = true - end - - if need_set then - - local requests = gen_url_queries(task, rule) + local requests = gen_url_queries(task, rule) + local sc = extract_task_score(task, rule) + if sc then for _,tld in ipairs(requests) do - rule.backend.set_token(task, rule, tld[1], token) + rule.backend.set_token(task, rule, tld[1], sc) end end end local url_selector = { config = { - -- keys map between actions and hash elements in bucket, - -- h is for ham, - -- s is for spam, - -- p is for probable spam - keys_map = { - ['spam'] = 's', - ['junk'] = 'p', - ['ham'] = 'h' - }, symbol = 'URL_SCORE', -- symbol to be inserted lower_bound = 10, -- minimum number of messages to be scored min_score = nil, @@ -489,44 +466,20 @@ local function ip_reputation_idempotent(task, rule) return end end - - local verdict = lua_util.get_task_verdict(task) - local token = { - } - local need_set = false - - -- TODO: take metric score into consideration - local k = cfg.keys_map[verdict] - - if k then - token[k] = 1.0 - need_set = true + local sc = extract_task_score(task, rule) + if asn then + rule.backend.set_token(task, rule, cfg.asn_prefix .. asn, sc) end - - if need_set then - if asn then - rule.backend.set_token(task, rule, cfg.asn_prefix .. asn, token) - end - if country then - rule.backend.set_token(task, rule, cfg.country_prefix .. country, token) - end - - rule.backend.set_token(task, rule, cfg.ip_prefix .. tostring(ip), token) + if country then + rule.backend.set_token(task, rule, cfg.country_prefix .. country, sc) end + + rule.backend.set_token(task, rule, cfg.ip_prefix .. tostring(ip), sc) end -- Selectors are used to extract reputation tokens local ip_selector = { config = { - -- keys map between actions and hash elements in bucket, - -- h is for ham, - -- s is for spam, - -- p is for probable spam - keys_map = { - ['spam'] = 's', - ['junk'] = 'p', - ['ham'] = 'h' - }, scores = { -- how each component is evaluated ['asn'] = 0.4, ['country'] = 0.01, @@ -578,46 +531,23 @@ local function spf_reputation_filter(task, rule) end local function spf_reputation_idempotent(task, rule) - local verdict = lua_util.get_task_verdict(task) + local sc = extract_task_score(task, rule) local spf_record = task:get_mempool():get_variable('spf_record') local spf_allow = task:has_symbol('R_SPF_ALLOW') - local token = { - } - local cfg = rule.selector.config - local need_set = false - if not spf_record or not spf_allow then return end + if not spf_record or not spf_allow or not sc then return end - -- TODO: take metric score into consideration - local k = cfg.keys_map[verdict] - - if k then - token[k] = 1.0 - need_set = true - end - - if need_set then - local cr = require "rspamd_cryptobox_hash" - local hkey = cr.create(spf_record):base32():sub(1, 32) + local cr = require "rspamd_cryptobox_hash" + local hkey = cr.create(spf_record):base32():sub(1, 32) - lua_util.debugm(N, task, 'set spf record %s -> %s = %s', - spf_record, hkey, token) - rule.backend.set_token(task, rule, hkey, token) - end + lua_util.debugm(N, task, 'set spf record %s -> %s = %s', + spf_record, hkey, token) + rule.backend.set_token(task, rule, hkey, sc) end local spf_selector = { config = { - -- keys map between actions and hash elements in bucket, - -- h is for ham, - -- s is for spam, - -- p is for probable spam - keys_map = { - ['spam'] = 's', - ['junk'] = 'p', - ['ham'] = 'h' - }, symbol = 'SPF_SCORE', -- symbol to be inserted lower_bound = 10, -- minimum number of messages to be scored min_score = nil, @@ -694,32 +624,23 @@ local function generic_reputation_filter(task, rule) end local function generic_reputation_idempotent(task, rule) - local verdict = lua_util.get_task_verdict(task) + local sc = extract_task_score(task, rule) local cfg = rule.selector.config - local need_set = false - local token = {} local selector_res = cfg.selector(task) if not selector_res then return end - local k = cfg.keys_map[verdict] - - if k then - token[k] = 1.0 - need_set = true - end - - if need_set then + if sc then if type(selector_res) == 'table' then fun.each(function(e) lua_util.debugm(N, task, 'set generic selector (%s) %s = %s', - rule['symbol'], e, token) - rule.backend.set_token(task, rule, e, token) + rule['symbol'], e, sc) + rule.backend.set_token(task, rule, e, sc) end, selector_res) else lua_util.debugm(N, task, 'set generic selector (%s) %s = %s', - rule['symbol'], selector_res, token) - rule.backend.set_token(task, rule, selector_res, token) + rule['symbol'], selector_res, sc) + rule.backend.set_token(task, rule, selector_res, sc) end end end @@ -727,7 +648,6 @@ end local generic_selector = { schema = ts.shape{ - keys_map = keymap_schema, lower_bound = ts.number + ts.string / tonumber, max_score = ts.number:is_optional(), min_score = ts.number:is_optional(), @@ -738,15 +658,6 @@ local generic_selector = { whitelist = ts.string:is_optional(), }, config = { - -- keys map between actions and hash elements in bucket, - -- h is for ham, - -- s is for spam, - -- p is for probable spam - keys_map = { - ['spam'] = 's', - ['junk'] = 'p', - ['ham'] = 'h' - }, lower_bound = 10, -- minimum number of messages to be scored min_score = nil, max_score = nil, @@ -806,6 +717,10 @@ local function gen_token_key(token, rule) res = string.sub(res, 1, rule.backend.config.hashlen) end + if rule.backend.config.prefix then + res = rule.backend.config.prefix .. res + end + return res end @@ -887,71 +802,78 @@ local function reputation_redis_init(rule, cfg, ev_base, worker) return false end -- Init scripts for buckets + -- Redis script to extract data from Redis buckets + -- KEYS[1] - key to extract + -- Value returned - table of scores as a strings vector + number of scores local redis_get_script_tpl = [[ -local key = KEYS[1] .. '${name}' -local vals = redis.call('HGETALL', key) -for i=1,#vals,2 do - local k = vals[i] - local v = vals[i + 1] - if scores[k] then - scores[k] = scores[k] + tonumber(v) * ${mult} + local cnt = redis.call('HGET', KEYS[1], 'n') + local results = {} + if cnt then + {% for w in windows %} + local sc = tonumber(redis.call('HGET', KEYS[1], 'v' .. '{= w.name =}')) + table.insert(results, tostring(sc * {= w.mult =})) + {% endfor %} else - scores[k] = tonumber(v) * ${mult} - end -end -]] - local redis_script_tbl = {'local scores = {}'} - for _,bucket in ipairs(rule.backend.config.buckets) do - table.insert(redis_script_tbl, lua_util.template(redis_get_script_tpl, bucket)) - end - table.insert(redis_script_tbl, [[ - local result = {} - for k,v in pairs(scores) do - table.insert(result, k) - table.insert(result, v) - end - - return result -]]) - rule.backend.script_get = lua_redis.add_redis_script(table.concat(redis_script_tbl, '\n'), - our_redis_params) - - redis_script_tbl = {} - local redis_set_script_tpl = [[ -local key = KEYS[1] .. '${name}' -local last = tonumber(redis.call('HGET', key, 'start')) -local now = tonumber(KEYS[2]) -if not last then - last = 0 -end -local discriminate_bucket = false -if now - last > ${time} then - discriminate_bucket = true - redis.call('HSET', key, 'start', now) -end -for i=1,#ARGV,2 do - local k = ARGV[i] - local v = tonumber(ARGV[i + 1]) - - if discriminate_bucket then - local last_value = redis.call('HGET', key, k) - if last_value then - redis.call('HSET', key, k, last_value / 2.0) + {% for w in windows %} + table.insert(results, '0') + {% endfor %} + end + + return results,cnt + ]] + + local get_script = lua_util.jinja_template(redis_get_script_tpl, + {windows = rule.backend.config.buckets}) + rspamd_logger.debugm(N, rspamd_config, 'added extraction script %s', get_script) + rule.backend.script_get = lua_redis.add_redis_script(get_script, our_redis_params) + + -- Redis script to update Redis buckets + -- KEYS[1] - key to update + -- KEYS[2] - current time in milliseconds + -- KEYS[3] - message score + -- KEYS[4] - expire for a bucket + -- Value returned - table of scores as a strings vector + local redis_adaptive_emea_script_tpl = [[ + local last = redis.call('HGET', KEYS[1], 'l') + local score = tonumber(KEYS[3]) + local now = tonumber(KEYS[2]) + local scores = {} + + if last then + {% for w in windows %} + local last_value = tonumber(redis.call('HGET', KEYS[1], 'v' .. '{= w.name =}')) + local window = {= w.time =} + -- Adjust alpha + local time_diff = now - last_value + if time_diff > 0 then + time_diff = 0 end + local alpha = 1.0 - math.exp((-time_diff) / (1000 * window)) + local nscore = alpha * score + (1.0 - alpha) * last_value + table.insert(scores, tostring(nscore * {= w.mult =})) + {% endfor %} + else + {% for w in windows %} + table.insert(scores, tostring(score * {= w.mult =})) + {% endfor %} end - redis.call('HINCRBYFLOAT', key, k, v) -end -redis.call('EXPIRE', key, KEYS[3]) -redis.call('HSET', key, 'last', now) + local i = 1 + {% for w in windows %} + redis.call('HSET', KEYS[1], 'v' .. '{= w.name =}', scores[i]) + i = i + 1 + {% endfor %} + redis.call('HSET', KEYS[1], 'l', now) + redis.call('HINCRBY', KEYS[1], 'n', 1) + redis.call('EXPIRE', KEYS[1], tonumber(KEYS[4])) + + return scores ]] - for _,bucket in ipairs(rule.backend.config.buckets) do - table.insert(redis_script_tbl, lua_util.template(redis_set_script_tpl, - bucket)) - end - rule.backend.script_set = lua_redis.add_redis_script(table.concat(redis_script_tbl, '\n'), - our_redis_params) + local set_script = lua_util.jinja_template(redis_adaptive_emea_script_tpl, + {windows = rule.backend.config.buckets}) + rspamd_logger.debugm(N, rspamd_config, 'added emea update script %s', set_script) + rule.backend.script_set = lua_redis.add_redis_script(set_script, our_redis_params) return true end @@ -992,13 +914,13 @@ local function reputation_redis_get_token(task, rule, token, continuation_cb) local ret = lua_redis.exec_redis_script(rule.backend.script_get, {task = task, is_write = false}, redis_get_cb, - {token}) + {key}) if not ret then rspamd_logger.errx(task, 'cannot make redis request to check results') end end -local function reputation_redis_set_token(task, rule, token, values, continuation_cb) +local function reputation_redis_set_token(task, rule, token, sc, continuation_cb) local key = gen_token_key(token, rule) local function redis_set_cb(err, data) @@ -1015,19 +937,14 @@ local function reputation_redis_set_token(task, rule, token, values, continuatio end end - -- We start from expiry update - local args = {} - for k,v in pairs(values) do - table.insert(args, k) - table.insert(args, v) - end lua_util.debugm(N, task, 'rule %s - set values for key %s -> %s', - rule['symbol'], key, values) + rule['symbol'], key, sc) local ret = lua_redis.exec_redis_script(rule.backend.script_set, {task = task, is_write = true}, redis_set_cb, - {token, tostring(rspamd_util:get_time()), - tostring(rule.backend.config.expiry)}, args) + {key, tostring(os.time() * 1000), + tonumber(sc), + tostring(rule.backend.config.expiry)}) if not ret then rspamd_logger.errx(task, 'got error while connecting to redis') end @@ -1043,6 +960,7 @@ end local backends = { redis = { schema = ts.shape({ + prefix = ts.string, expiry = ts.number + ts.string / lua_util.parse_time_interval, buckets = ts.array_of(ts.shape{ time = ts.number + ts.string / lua_util.parse_time_interval, @@ -1052,6 +970,7 @@ local backends = { }, {extra_fields = lua_redis.config_schema}), config = { expiry = default_expiry, + prefix = default_prefix, buckets = { { time = 60 * 60, -- 2.39.5