Bläddra i källkod

[Project] Reputation: Migrate to adaptive EMA model

tags/2.0
Vsevolod Stakhov 5 år sedan
förälder
incheckning
73d2cee82a
1 ändrade filer med 145 tillägg och 226 borttagningar
  1. 145
    226
      src/plugins/lua/reputation.lua

+ 145
- 226
src/plugins/lua/reputation.lua Visa fil



local redis_params = nil local redis_params = nil
local default_expiry = 864000 -- 10 day by default local default_expiry = 864000 -- 10 day by default
local default_prefix = 'RR:' -- Rspamd Reputation


local keymap_schema = ts.shape{
['spam'] = ts.string,
['junk'] = ts.string,
['ham'] = ts.string,
}


-- Get reputation from ham/spam/probable hits -- Get reputation from ham/spam/probable hits
local function generic_reputation_calc(token, rule, mult) local function generic_reputation_calc(token, rule, mult)
return cfg.score_calc_func(rule, token, mult) return cfg.score_calc_func(rule, token, mult)
end end


local ham_samples = token.h or 0
local spam_samples = token.s or 0
local probable_samples = token.p or 0
local total_samples = ham_samples + spam_samples + probable_samples

if total_samples < cfg.lower_bound then return 0 end
if token[1] < cfg.lower_bound then return 0 end


local score = (ham_samples / total_samples) * -1.0 +
(spam_samples / total_samples) +
(probable_samples / total_samples) * 0.5
local score = fun.foldl(function(acc, v)
return acc + v
end, 0.0, fun.map(tonumber, token[2])) / #token[2]


return score return score
end end
end end
end end


local function sub_symbol_score(task, rule, score)
local function sym_score(sym)
local s = task:get_symbol(sym)[1]
return s.score
end
if rule.config.split_symbols then
local spam_sym = rule.symbol .. '_SPAM'
local ham_sym = rule.symbol .. '_HAM'

if task:has_symbol(spam_sym) then
score = score - sym_score(spam_sym)
elseif task:has_symbol(ham_sym) then
score = score - sym_score(ham_sym)
end
else
if task:has_symbol(rule.symbol) then
score = score - sym_score(rule.symbol)
end
end

return score
end

-- Extracts task score and subtracts score of the rule itself
local function extract_task_score(task, rule)
local _,score = lua_util.get_task_verdict(task)

if not score then return nil end

return sub_symbol_score(task, rule, score)
end

-- DKIM Selector functions -- DKIM Selector functions
local gr local gr
local function gen_dkim_queries(task, rule) local function gen_dkim_queries(task, rule)
end end


local function dkim_reputation_idempotent(task, rule) local function dkim_reputation_idempotent(task, rule)
local verdict = lua_util.get_task_verdict(task)
local token = {
}
local cfg = rule.selector.config
local need_set = false

-- TODO: take metric score into consideration
local k = cfg.keys_map[verdict]

if k then
token[k] = 1.0
need_set = true
end

if need_set then

local requests = gen_dkim_queries(task, rule)
local requests = gen_dkim_queries(task, rule)
local sc = extract_task_score(task, rule)


if sc then
for dom,res in pairs(requests) do for dom,res in pairs(requests) do
-- tld + "." + check_result, e.g. example.com.+ - reputation for valid sigs -- tld + "." + check_result, e.g. example.com.+ - reputation for valid sigs
local query = string.format('%s.%s', dom, res) local query = string.format('%s.%s', dom, res)
rule.backend.set_token(task, rule, query, token)
rule.backend.set_token(task, rule, query, sc)
end end
end end
end end


local dkim_selector = { local dkim_selector = {
config = { config = {
-- keys map between actions and hash elements in bucket,
-- h is for ham,
-- s is for spam,
-- p is for probable spam
keys_map = {
['spam'] = 's',
['junk'] = 'p',
['ham'] = 'h'
},
symbol = 'DKIM_SCORE', -- symbol to be inserted symbol = 'DKIM_SCORE', -- symbol to be inserted
lower_bound = 10, -- minimum number of messages to be scored lower_bound = 10, -- minimum number of messages to be scored
min_score = nil, min_score = nil,
end end


local function url_reputation_filter(task, rule) local function url_reputation_filter(task, rule)
local requests = gen_url_queries(task, rule)
local requests = lua_util.extract_specific_urls(task, rule.selector.config.max_urls)
local results = {} local results = {}
local nchecked = 0 local nchecked = 0


end end
end end


for _,tld in ipairs(requests) do
rule.backend.get_token(task, rule, tld[1], tokens_cb)
for _,u in ipairs(requests) do
rule.backend.get_token(task, rule, u:get_tld(), tokens_cb)
end end
end end


local function url_reputation_idempotent(task, rule) local function url_reputation_idempotent(task, rule)
local verdict = lua_util.get_task_verdict(task)
local token = {
}
local cfg = rule.selector.config
local need_set = false

-- TODO: take metric score into consideration
local k = cfg.keys_map[verdict]

if k then
token[k] = 1.0
need_set = true
end

if need_set then

local requests = gen_url_queries(task, rule)
local requests = gen_url_queries(task, rule)
local sc = extract_task_score(task, rule)


if sc then
for _,tld in ipairs(requests) do for _,tld in ipairs(requests) do
rule.backend.set_token(task, rule, tld[1], token)
rule.backend.set_token(task, rule, tld[1], sc)
end end
end end
end end


local url_selector = { local url_selector = {
config = { config = {
-- keys map between actions and hash elements in bucket,
-- h is for ham,
-- s is for spam,
-- p is for probable spam
keys_map = {
['spam'] = 's',
['junk'] = 'p',
['ham'] = 'h'
},
symbol = 'URL_SCORE', -- symbol to be inserted symbol = 'URL_SCORE', -- symbol to be inserted
lower_bound = 10, -- minimum number of messages to be scored lower_bound = 10, -- minimum number of messages to be scored
min_score = nil, min_score = nil,
return return
end end
end end

local verdict = lua_util.get_task_verdict(task)
local token = {
}
local need_set = false

-- TODO: take metric score into consideration
local k = cfg.keys_map[verdict]

if k then
token[k] = 1.0
need_set = true
local sc = extract_task_score(task, rule)
if asn then
rule.backend.set_token(task, rule, cfg.asn_prefix .. asn, sc)
end end

if need_set then
if asn then
rule.backend.set_token(task, rule, cfg.asn_prefix .. asn, token)
end
if country then
rule.backend.set_token(task, rule, cfg.country_prefix .. country, token)
end

rule.backend.set_token(task, rule, cfg.ip_prefix .. tostring(ip), token)
if country then
rule.backend.set_token(task, rule, cfg.country_prefix .. country, sc)
end end

rule.backend.set_token(task, rule, cfg.ip_prefix .. tostring(ip), sc)
end end


-- Selectors are used to extract reputation tokens -- Selectors are used to extract reputation tokens
local ip_selector = { local ip_selector = {
config = { config = {
-- keys map between actions and hash elements in bucket,
-- h is for ham,
-- s is for spam,
-- p is for probable spam
keys_map = {
['spam'] = 's',
['junk'] = 'p',
['ham'] = 'h'
},
scores = { -- how each component is evaluated scores = { -- how each component is evaluated
['asn'] = 0.4, ['asn'] = 0.4,
['country'] = 0.01, ['country'] = 0.01,
end end


local function spf_reputation_idempotent(task, rule) local function spf_reputation_idempotent(task, rule)
local verdict = lua_util.get_task_verdict(task)
local sc = extract_task_score(task, rule)
local spf_record = task:get_mempool():get_variable('spf_record') local spf_record = task:get_mempool():get_variable('spf_record')
local spf_allow = task:has_symbol('R_SPF_ALLOW') local spf_allow = task:has_symbol('R_SPF_ALLOW')
local token = {
}
local cfg = rule.selector.config
local need_set = false


if not spf_record or not spf_allow then return end
if not spf_record or not spf_allow or not sc then return end


-- TODO: take metric score into consideration
local k = cfg.keys_map[verdict]

if k then
token[k] = 1.0
need_set = true
end

if need_set then
local cr = require "rspamd_cryptobox_hash"
local hkey = cr.create(spf_record):base32():sub(1, 32)
local cr = require "rspamd_cryptobox_hash"
local hkey = cr.create(spf_record):base32():sub(1, 32)


lua_util.debugm(N, task, 'set spf record %s -> %s = %s',
spf_record, hkey, token)
rule.backend.set_token(task, rule, hkey, token)
end
lua_util.debugm(N, task, 'set spf record %s -> %s = %s',
spf_record, hkey, token)
rule.backend.set_token(task, rule, hkey, sc)
end end




local spf_selector = { local spf_selector = {
config = { config = {
-- keys map between actions and hash elements in bucket,
-- h is for ham,
-- s is for spam,
-- p is for probable spam
keys_map = {
['spam'] = 's',
['junk'] = 'p',
['ham'] = 'h'
},
symbol = 'SPF_SCORE', -- symbol to be inserted symbol = 'SPF_SCORE', -- symbol to be inserted
lower_bound = 10, -- minimum number of messages to be scored lower_bound = 10, -- minimum number of messages to be scored
min_score = nil, min_score = nil,
end end


local function generic_reputation_idempotent(task, rule) local function generic_reputation_idempotent(task, rule)
local verdict = lua_util.get_task_verdict(task)
local sc = extract_task_score(task, rule)
local cfg = rule.selector.config local cfg = rule.selector.config
local need_set = false
local token = {}


local selector_res = cfg.selector(task) local selector_res = cfg.selector(task)
if not selector_res then return end if not selector_res then return end


local k = cfg.keys_map[verdict]

if k then
token[k] = 1.0
need_set = true
end

if need_set then
if sc then
if type(selector_res) == 'table' then if type(selector_res) == 'table' then
fun.each(function(e) fun.each(function(e)
lua_util.debugm(N, task, 'set generic selector (%s) %s = %s', lua_util.debugm(N, task, 'set generic selector (%s) %s = %s',
rule['symbol'], e, token)
rule.backend.set_token(task, rule, e, token)
rule['symbol'], e, sc)
rule.backend.set_token(task, rule, e, sc)
end, selector_res) end, selector_res)
else else
lua_util.debugm(N, task, 'set generic selector (%s) %s = %s', lua_util.debugm(N, task, 'set generic selector (%s) %s = %s',
rule['symbol'], selector_res, token)
rule.backend.set_token(task, rule, selector_res, token)
rule['symbol'], selector_res, sc)
rule.backend.set_token(task, rule, selector_res, sc)
end end
end end
end end


local generic_selector = { local generic_selector = {
schema = ts.shape{ schema = ts.shape{
keys_map = keymap_schema,
lower_bound = ts.number + ts.string / tonumber, lower_bound = ts.number + ts.string / tonumber,
max_score = ts.number:is_optional(), max_score = ts.number:is_optional(),
min_score = ts.number:is_optional(), min_score = ts.number:is_optional(),
whitelist = ts.string:is_optional(), whitelist = ts.string:is_optional(),
}, },
config = { config = {
-- keys map between actions and hash elements in bucket,
-- h is for ham,
-- s is for spam,
-- p is for probable spam
keys_map = {
['spam'] = 's',
['junk'] = 'p',
['ham'] = 'h'
},
lower_bound = 10, -- minimum number of messages to be scored lower_bound = 10, -- minimum number of messages to be scored
min_score = nil, min_score = nil,
max_score = nil, max_score = nil,
res = string.sub(res, 1, rule.backend.config.hashlen) res = string.sub(res, 1, rule.backend.config.hashlen)
end end


if rule.backend.config.prefix then
res = rule.backend.config.prefix .. res
end

return res return res
end end


return false return false
end end
-- Init scripts for buckets -- Init scripts for buckets
-- Redis script to extract data from Redis buckets
-- KEYS[1] - key to extract
-- Value returned - table of scores as a strings vector + number of scores
local redis_get_script_tpl = [[ local redis_get_script_tpl = [[
local key = KEYS[1] .. '${name}'
local vals = redis.call('HGETALL', key)
for i=1,#vals,2 do
local k = vals[i]
local v = vals[i + 1]
if scores[k] then
scores[k] = scores[k] + tonumber(v) * ${mult}
local cnt = redis.call('HGET', KEYS[1], 'n')
local results = {}
if cnt then
{% for w in windows %}
local sc = tonumber(redis.call('HGET', KEYS[1], 'v' .. '{= w.name =}'))
table.insert(results, tostring(sc * {= w.mult =}))
{% endfor %}
else else
scores[k] = tonumber(v) * ${mult}
end
end
]]
local redis_script_tbl = {'local scores = {}'}
for _,bucket in ipairs(rule.backend.config.buckets) do
table.insert(redis_script_tbl, lua_util.template(redis_get_script_tpl, bucket))
end
table.insert(redis_script_tbl, [[
local result = {}
for k,v in pairs(scores) do
table.insert(result, k)
table.insert(result, v)
end

return result
]])
rule.backend.script_get = lua_redis.add_redis_script(table.concat(redis_script_tbl, '\n'),
our_redis_params)

redis_script_tbl = {}
local redis_set_script_tpl = [[
local key = KEYS[1] .. '${name}'
local last = tonumber(redis.call('HGET', key, 'start'))
local now = tonumber(KEYS[2])
if not last then
last = 0
end
local discriminate_bucket = false
if now - last > ${time} then
discriminate_bucket = true
redis.call('HSET', key, 'start', now)
end
for i=1,#ARGV,2 do
local k = ARGV[i]
local v = tonumber(ARGV[i + 1])

if discriminate_bucket then
local last_value = redis.call('HGET', key, k)
if last_value then
redis.call('HSET', key, k, last_value / 2.0)
{% for w in windows %}
table.insert(results, '0')
{% endfor %}
end

return results,cnt
]]

local get_script = lua_util.jinja_template(redis_get_script_tpl,
{windows = rule.backend.config.buckets})
rspamd_logger.debugm(N, rspamd_config, 'added extraction script %s', get_script)
rule.backend.script_get = lua_redis.add_redis_script(get_script, our_redis_params)

-- Redis script to update Redis buckets
-- KEYS[1] - key to update
-- KEYS[2] - current time in milliseconds
-- KEYS[3] - message score
-- KEYS[4] - expire for a bucket
-- Value returned - table of scores as a strings vector
local redis_adaptive_emea_script_tpl = [[
local last = redis.call('HGET', KEYS[1], 'l')
local score = tonumber(KEYS[3])
local now = tonumber(KEYS[2])
local scores = {}

if last then
{% for w in windows %}
local last_value = tonumber(redis.call('HGET', KEYS[1], 'v' .. '{= w.name =}'))
local window = {= w.time =}
-- Adjust alpha
local time_diff = now - last_value
if time_diff > 0 then
time_diff = 0
end end
local alpha = 1.0 - math.exp((-time_diff) / (1000 * window))
local nscore = alpha * score + (1.0 - alpha) * last_value
table.insert(scores, tostring(nscore * {= w.mult =}))
{% endfor %}
else
{% for w in windows %}
table.insert(scores, tostring(score * {= w.mult =}))
{% endfor %}
end end
redis.call('HINCRBYFLOAT', key, k, v)
end


redis.call('EXPIRE', key, KEYS[3])
redis.call('HSET', key, 'last', now)
local i = 1
{% for w in windows %}
redis.call('HSET', KEYS[1], 'v' .. '{= w.name =}', scores[i])
i = i + 1
{% endfor %}
redis.call('HSET', KEYS[1], 'l', now)
redis.call('HINCRBY', KEYS[1], 'n', 1)
redis.call('EXPIRE', KEYS[1], tonumber(KEYS[4]))

return scores
]] ]]
for _,bucket in ipairs(rule.backend.config.buckets) do
table.insert(redis_script_tbl, lua_util.template(redis_set_script_tpl,
bucket))
end


rule.backend.script_set = lua_redis.add_redis_script(table.concat(redis_script_tbl, '\n'),
our_redis_params)
local set_script = lua_util.jinja_template(redis_adaptive_emea_script_tpl,
{windows = rule.backend.config.buckets})
rspamd_logger.debugm(N, rspamd_config, 'added emea update script %s', set_script)
rule.backend.script_set = lua_redis.add_redis_script(set_script, our_redis_params)


return true return true
end end
local ret = lua_redis.exec_redis_script(rule.backend.script_get, local ret = lua_redis.exec_redis_script(rule.backend.script_get,
{task = task, is_write = false}, {task = task, is_write = false},
redis_get_cb, redis_get_cb,
{token})
{key})
if not ret then if not ret then
rspamd_logger.errx(task, 'cannot make redis request to check results') rspamd_logger.errx(task, 'cannot make redis request to check results')
end end
end end


local function reputation_redis_set_token(task, rule, token, values, continuation_cb)
local function reputation_redis_set_token(task, rule, token, sc, continuation_cb)
local key = gen_token_key(token, rule) local key = gen_token_key(token, rule)


local function redis_set_cb(err, data) local function redis_set_cb(err, data)
end end
end end


-- We start from expiry update
local args = {}
for k,v in pairs(values) do
table.insert(args, k)
table.insert(args, v)
end
lua_util.debugm(N, task, 'rule %s - set values for key %s -> %s', lua_util.debugm(N, task, 'rule %s - set values for key %s -> %s',
rule['symbol'], key, values)
rule['symbol'], key, sc)
local ret = lua_redis.exec_redis_script(rule.backend.script_set, local ret = lua_redis.exec_redis_script(rule.backend.script_set,
{task = task, is_write = true}, {task = task, is_write = true},
redis_set_cb, redis_set_cb,
{token, tostring(rspamd_util:get_time()),
tostring(rule.backend.config.expiry)}, args)
{key, tostring(os.time() * 1000),
tonumber(sc),
tostring(rule.backend.config.expiry)})
if not ret then if not ret then
rspamd_logger.errx(task, 'got error while connecting to redis') rspamd_logger.errx(task, 'got error while connecting to redis')
end end
local backends = { local backends = {
redis = { redis = {
schema = ts.shape({ schema = ts.shape({
prefix = ts.string,
expiry = ts.number + ts.string / lua_util.parse_time_interval, expiry = ts.number + ts.string / lua_util.parse_time_interval,
buckets = ts.array_of(ts.shape{ buckets = ts.array_of(ts.shape{
time = ts.number + ts.string / lua_util.parse_time_interval, time = ts.number + ts.string / lua_util.parse_time_interval,
}, {extra_fields = lua_redis.config_schema}), }, {extra_fields = lua_redis.config_schema}),
config = { config = {
expiry = default_expiry, expiry = default_expiry,
prefix = default_prefix,
buckets = { buckets = {
{ {
time = 60 * 60, time = 60 * 60,

Laddar…
Avbryt
Spara