-- rspamd/lualib/rspamadm/rescore.lua
--
-- 589 lines
-- 16 KiB
-- Lua

--[[
Copyright (c) 2018, Vsevolod Stakhov <vsevolod@highsecure.ru>
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
]]--
--[[
local lua_util = require "lua_util"
local ucl = require "ucl"
local logger = require "rspamd_logger"
local rspamd_util = require "rspamd_util"
local argparse = require "argparse"
local rescore_utility = require "rescore_utility"
-- Parsed command-line options; filled in by handler() via parser:parse().
local opts

-- Symbols that are never used as training features, whatever the logs say.
-- More can be added at run time through --ignore-symbol.
local ignore_symbols = {
  DATE_IN_PAST = true,
  DATE_IN_FUTURE = true,
}
-- Command-line interface of `rspamadm rescore`.
-- argparse stores each option under its first long name with dashes turned
-- into underscores (e.g. --ignore-symbol is read back as opts.ignore_symbol).
local parser = argparse()
    :name "rspamadm rescore"
    :description "Estimate optimal symbol weights from log files"
    :help_description_margin(37)
parser:option "-l --log"
    :description "Log file or files (from rescore)"
    :argname("<log>")
    :args "*"
parser:option "-c --config"
    :description "Path to config file"
    :argname("<file>")
    :default(rspamd_paths["CONFDIR"] .. "/" .. "rspamd.conf")
parser:option "-o --output"
    :description "Output file"
    :argname("<file>")
    :default("new.scores")
parser:flag "-d --diff"
    :description "Show differences in scores"
parser:flag "-v --verbose"
    :description "Verbose output"
parser:flag "-z --freq"
    :description "Display hit frequencies"
parser:option "-i --iters"
    :description "Learn iterations"
    :argname("<n>")
    :convert(tonumber)
    :default(10)
parser:option "-b --batch"
    :description "Batch size"
    :argname("<n>")
    :convert(tonumber)
    :default(100)
-- Long-only: "-d" is already the short alias of --diff above; declaring it a
-- second time here made the short flag ambiguous.
parser:option "--decay"
    :description "Decay rate"
    :argname("<n>")
    :convert(tonumber)
    :default(0.001)
parser:option "-m --momentum"
    :description "Learn momentum"
    :argname("<n>")
    :convert(tonumber)
    :default(0.1)
parser:option "-t --threads"
    :description "Number of threads to use"
    :argname("<n>")
    :convert(tonumber)
    :default(1)
-- Long-only: "-o" is already the short alias of --output above.
parser:option "--optim"
    :description "Optimisation algorithm"
    :argname("<alg>")
    :convert {
      LBFGS = "LBFGS",
      ADAM = "ADAM",
      ADAGRAD = "ADAGRAD",
      SGD = "SGD",
      NAG = "NAG"
    }
    :default "ADAM"
parser:option "--ignore-symbol"
    :description "Ignore symbol from logs"
    :argname("<sym>")
    :args "*"
parser:option "--penalty-weight"
    :description "Add new penalty weight to test"
    :argname("<n>")
    :convert(tonumber)
    :args "*"
parser:option "--learning-rate"
    :description "Add new learning rate to test"
    :argname("<n>")
    :convert(tonumber)
    :args "*"
parser:option "--spam_action"
    :description "Spam action"
    :argname("<act>")
    :default("reject")
parser:option "--learning_rate_decay"
    :description "Learn rate decay (for some algs)"
    :argname("<n>")
    :convert(tonumber)
    :default(0.0)
parser:option "--weight_decay"
    :description "Weight decay (for some algs)"
    :argname("<n>")
    :convert(tonumber)
    :default(0.0)
parser:option "--l1"
    :description "L1 regularization penalty"
    :argname("<n>")
    :convert(tonumber)
    :default(0.0)
parser:option "--l2"
    :description "L2 regularization penalty"
    :argname("<n>")
    :convert(tonumber)
    :default(0.0)
--- Build a training dataset from rescore log lines.
-- Each line is split on spaces. A first field of "SPAM" gives target 1, any
-- other value gives 0. Fields from the 4th onward are symbol names;
-- input_vec[i] is 1 when all_symbols[i] fired on that line (and is not in
-- ignore_symbols), otherwise 0.
-- @param logs list of log lines
-- @param all_symbols list of symbol names defining the feature order
-- @param spam_score currently unused; kept for interface compatibility
-- @return inputs (list of 0/1 feature vectors), outputs (list of 0/1 targets)
local function make_dataset_from_logs(logs, all_symbols, spam_score)
  local inputs = {}
  local outputs = {}

  for _, line in pairs(logs) do
    local fields = lua_util.rspamd_str_split(line, " ")

    outputs[#outputs + 1] = (fields[1] == "SPAM") and 1 or 0

    local fired = {}
    for i = 4, #fields do
      -- Strip stray whitespace (a trailing "\r", a tab, ...) the same way
      -- update_logs() does, so the last symbol on a line still matches.
      local symbol = fields[i]:gsub("%s+", "")
      if not ignore_symbols[symbol] then
        fired[symbol] = true
      end
    end

    local input_vec = {}
    for index, symbol in pairs(all_symbols) do
      input_vec[index] = fired[symbol] and 1 or 0
    end

    inputs[#inputs + 1] = input_vec
  end

  return inputs, outputs
end
--- Initial perceptron weights taken from the configured symbol scores.
-- Element i is the configured score of all_symbols[i], or 0 for symbols that
-- have no configured score. The previous body was empty and returned nil,
-- which handler() then stored into the model and called :clone() on.
-- NOTE(review): relies on the torch `torch` global being available, like the
-- `nn` globals handler() uses — TODO confirm.
-- @param all_symbols list of symbol names (feature order)
-- @param original_symbol_scores map of symbol name -> configured score
-- @return 1-D tensor of length #all_symbols
local function init_weights(all_symbols, original_symbol_scores)
  local weights = {}
  for i = 1, #all_symbols do
    local symbol = all_symbols[i]
    weights[i] = original_symbol_scores[symbol] or 0
  end
  return torch.Tensor(weights)
end
--- Shuffle two parallel lists in place with the same random permutation.
-- Fisher-Yates: position i is swapped with a uniformly chosen position in
-- 1..i. The previous version drew from 1..#logs at every step, which biases
-- the result towards some permutations.
-- @param logs list to shuffle
-- @param messages list kept aligned with `logs` (same swaps applied)
local function shuffle(logs, messages)
  for i = #logs, 2, -1 do
    local j = math.random(i)
    logs[i], logs[j] = logs[j], logs[i]
    messages[i], messages[j] = messages[j], messages[i]
  end
end
--- Split parallel log/message lists into a leading and a trailing part.
-- The first floor(#logs * split_percent / 100) entries go to the first part
-- and the rest to the second part; relative order is kept.
-- @param logs list of log lines
-- @param messages list of message identifiers parallel to `logs`
-- @param split_percent share of entries for the first part (default 60)
-- @return {train_logs, train_messages}, {test_logs, test_messages}
local function split_logs(logs, messages, split_percent)
  local percent = split_percent or 60
  local boundary = math.floor(#logs * percent / 100)

  local head = { {}, {} }
  local tail = { {}, {} }

  for pos = 1, #logs do
    local bucket = (pos <= boundary) and head or tail
    local bucket_logs, bucket_msgs = bucket[1], bucket[2]
    bucket_logs[#bucket_logs + 1] = logs[pos]
    bucket_msgs[#bucket_msgs + 1] = messages[pos]
  end

  return head, tail
end
--- Turn a positional weight vector into a symbol -> score map.
-- all_symbols[k] receives new_scores[k].
-- @param all_symbols list of symbol names (feature order)
-- @param new_scores indexable weights in the same order
-- @return table mapping symbol name to its new score
local function stitch_new_scores(all_symbols, new_scores)
  local by_name = {}
  for position, name in pairs(all_symbols) do
    by_name[name] = new_scores[position]
  end
  return by_name
end
--- Recompute the score field of every log line from a symbol score map.
-- Each line is split on spaces; fields 4.. are symbol names (whitespace
-- stripped). Their scores are summed (unknown symbols count as 0), rounded to
-- two decimals and written into field 2. The list is updated in place.
-- @param logs list of log lines (modified in place)
-- @param symbol_scores map of symbol name -> score
-- @return the same `logs` table
local function update_logs(logs, symbol_scores)
  for line_no, line in ipairs(logs) do
    local fields = lua_util.rspamd_str_split(line, " ")
    local total = 0

    for k = 4, #fields do
      local name = fields[k]:gsub("%s+", "")
      fields[k] = name
      total = total + (symbol_scores[name] or 0)
    end

    fields[2] = lua_util.round(total, 2)
    logs[line_no] = table.concat(fields, " ")
  end

  return logs
end
--- Serialise a symbol -> score map as UCL and write it to `file_path`.
-- The map is serialised before the file is opened, so a serialisation error
-- no longer truncates an existing scores file. A failed write is raised
-- instead of being ignored.
-- @param new_symbol_scores map of symbol name -> score
-- @param file_path destination path (overwritten)
local function write_scores(new_symbol_scores, file_path)
  local new_scores_ucl = ucl.to_format(new_symbol_scores, "ucl")

  local file = assert(io.open(file_path, "w"))
  local ok, err = file:write(new_scores_ucl)
  file:close()
  assert(ok, err)
end
--- Log old vs new score for every symbol, then list the symbols whose score
-- changed sign (positive <-> negative). Symbols with no original score are
-- shown with old score 0 and are never reported as sign changes.
-- @param new_symbol_scores map of symbol name -> new score
-- @param original_symbol_scores map of symbol name -> configured score
local function print_score_diff(new_symbol_scores, original_symbol_scores)
  local row_fmt = "%-35s %-10s %-10s"

  local function emit_row(symbol, new_score)
    logger.message(string.format(row_fmt,
        symbol,
        original_symbol_scores[symbol] or 0,
        lua_util.round(new_score, 2)))
  end

  logger.message(string.format(row_fmt, "SYMBOL", "OLD_SCORE", "NEW_SCORE"))
  for symbol, new_score in pairs(new_symbol_scores) do
    emit_row(symbol, new_score)
  end

  logger.message("\nClass changes \n")
  for symbol, new_score in pairs(new_symbol_scores) do
    local old_score = original_symbol_scores[symbol]
    if old_score ~= nil then
      local flipped = (old_score > 0 and new_score < 0)
          or (old_score < 0 and new_score > 0)
      if flipped then
        emit_row(symbol, new_score)
      end
    end
  end
end
--- Evaluate a candidate weight vector on a corpus.
-- The weights are copied and mapped to symbol names. update_logs() then
-- rewrites the score field of each log line in `logs`, and the rescored logs
-- are passed to rescore_utility.generate_statistics_from_logs.
-- @return F-score, list of false positives, list of false negatives
local function calculate_fscore_from_weights(logs, messages,
    all_symbols,
    weights,
    threshold)
  local symbol_scores = stitch_new_scores(all_symbols, weights:clone())
  local rescored_logs = update_logs(logs, symbol_scores)

  local stats, _, false_positives, false_negatives =
      rescore_utility.generate_statistics_from_logs(rescored_logs, messages, threshold)

  return stats.fscore, false_positives, false_negatives
end
--- Log F-score, error rates, accuracy and the slowest message for a corpus.
-- Statistics come from rescore_utility.generate_statistics_from_logs. The
-- report template previously contained a stray blame timestamp line, which
-- was printed as part of every report.
-- @param logs list of log lines
-- @param messages list of message identifiers parallel to `logs`
-- @param threshold score threshold passed through to the statistics helper
local function print_stats(logs, messages, threshold)
  local file_stats = rescore_utility.generate_statistics_from_logs(logs,
      messages, threshold)

  -- One report line per statistic; the template must stay at column 0
  -- because leading whitespace would become part of the output.
  local file_stat_format = [=[
F-score: %.2f
False positive rate: %.2f %%
False negative rate: %.2f %%
Overall accuracy: %.2f %%
Slowest message: %.2f (%s)
]=]

  logger.message("\nStatistics at threshold: " .. threshold)

  logger.message(string.format(file_stat_format,
      file_stats.fscore,
      file_stats.false_positive_rate,
      file_stats.false_negative_rate,
      file_stats.overall_accuracy,
      file_stats.slowest,
      file_stats.slowest_file))
end
-- Training function, intended to run one training epoch over `dataset`.
-- handler() calls it opts.iters times per (learning rate, penalty weight)
-- pair, passing: the dataset from make_dataset_from_logs(), the parsed
-- options (with learning_rate and weight_decay set for the current pair), the
-- nn model, the criterion, the 1-based epoch number, the feature symbol list,
-- the spam threshold and a clone of the initial weights.
-- NOTE(review): the body is empty in this copy, so calling it is a no-op and
-- the model weights are never updated — TODO restore the implementation.
local function train(dataset, opt, model, criterion, epoch,
all_symbols, spam_threshold, initial_weights)
end
-- Default grid of learning rates tried by handler(); handler() may replace
-- it with values from the command line.
local learning_rates = { 0.01 }

-- Default grid of penalty (weight decay) values tried by handler(); handler()
-- may replace it with values from the command line.
local penalty_weights = { 0 }
--- Determine the score thresholds used for training and evaluation.
-- If a spam action is selected (option --spam_action), its configured score
-- is the threshold, or 0 if the action is not configured. Otherwise the first
-- configured of "add header", "rewrite subject" and "reject" is used.
-- argparse stores --spam_action under the key 'spam_action'. The old lookup
-- of 'spam-action' never matched, so the selected action was ignored; that
-- key is still accepted for compatibility. With the parser's default of
-- "reject", the reject score is now the default threshold.
-- @return threshold score, reject score
local function get_threshold()
  local actions = rspamd_config:get_all_actions()

  local spam_action = opts['spam_action'] or opts['spam-action']
  if spam_action then
    local action_score = actions[spam_action]
    return (action_score or 0), actions['reject']
  end

  return (actions['add header'] or actions['rewrite subject']
      or actions['reject']), actions['reject']
end
--- Normalise an argparse option value into a list.
-- `value` may be a single item or a list (options declared with :args "*"
-- yield lists). Each item is passed through `convert` when one is given;
-- items that convert to nil are dropped.
-- @param value option value (scalar or list)
-- @param convert optional conversion function such as tonumber
-- @return list of (converted) values
local function collect_option_values(value, convert)
  local items = (type(value) == 'table') and value or { value }
  local result = {}
  for _, item in ipairs(items) do
    local converted = item
    if convert then
      converted = convert(item)
    end
    if converted ~= nil then
      result[#result + 1] = converted
    end
  end
  return result
end

--- The --freq report: per-symbol hit statistics, overall corpus statistics
-- and the configured symbols that never fired in the corpus.
-- @param logs list of log lines
-- @param messages list of message identifiers parallel to `logs`
-- @param threshold score threshold passed to the statistics helpers
-- @param original_symbol_scores map of configured symbol -> score
local function print_hit_frequencies(logs, messages, threshold, original_symbol_scores)
  local _, all_symbols_stats = rescore_utility.generate_statistics_from_logs(logs,
      messages, threshold)

  local sorted_stats = {}
  for _, symbol_stats in pairs(all_symbols_stats) do
    sorted_stats[#sorted_stats + 1] = symbol_stats
  end
  -- Descending by spam_overall, then spam_hits, then ham_hits.
  table.sort(sorted_stats, function(a, b)
    if a.spam_overall ~= b.spam_overall then
      return b.spam_overall < a.spam_overall
    end
    if b.spam_hits ~= a.spam_hits then
      return b.spam_hits < a.spam_hits
    end
    return b.ham_hits < a.ham_hits
  end)

  logger.message(string.format("%-40s %6s %6s %6s %6s %6s %6s %6s",
      "NAME", "HITS", "HAM", "HAM%", "SPAM", "SPAM%", "S/O", "OVER%"))
  -- ipairs, not pairs: rows must come out in the sorted order.
  for _, symbol_stats in ipairs(sorted_stats) do
    logger.message(
      string.format("%-40s %6d %6d %6.2f %6d %6.2f %6.2f %6.2f",
        symbol_stats.name,
        symbol_stats.no_of_hits,
        symbol_stats.ham_hits,
        lua_util.round(symbol_stats.ham_percent, 2),
        symbol_stats.spam_hits,
        lua_util.round(symbol_stats.spam_percent, 2),
        lua_util.round(symbol_stats.spam_overall, 2),
        lua_util.round(symbol_stats.overall, 2)
      )
    )
  end

  -- Print file statistics
  print_stats(logs, messages, threshold)

  -- Work out how many configured symbols weren't seen in the corpus
  local symbols_no_hits = {}
  local total_symbols = 0
  for sym in pairs(original_symbol_scores) do
    total_symbols = total_symbols + 1
    if all_symbols_stats[sym] == nil then
      symbols_no_hits[#symbols_no_hits + 1] = sym
    end
  end

  if #symbols_no_hits > 0 then
    table.sort(symbols_no_hits)
    -- Percentage of configured rules with no hits
    local nhpct = lua_util.round((#symbols_no_hits / total_symbols) * 100, 2)
    logger.message(
      string.format('\nFound %s (%-.2f%%) symbols out of %s with no hits in corpus:',
        #symbols_no_hits, nhpct, total_symbols
      )
    )
    for _, symbol in ipairs(symbols_no_hits) do
      logger.messagex('%s', symbol)
    end
  end
end

--- Entry point of `rspamadm rescore`.
-- Parses the command line and loads the rspamd configuration and the logs.
-- With --freq it only prints hit frequencies. Otherwise it runs a grid search
-- over learning rates and penalty weights, keeps the weights with the best
-- cross-validation F-score, writes/diffs them, and reports test statistics.
-- @param args command-line arguments handed to argparse
local function handler(args)
  opts = parser:parse(args)
  if not opts['log'] then
    parser:error('no log specified')
  end

  local _r, err = rspamd_config:load_ucl(opts['config'])
  if not _r then
    logger.errx('cannot parse %s: %s', opts['config'], err)
    os.exit(1)
  end
  _r, err = rspamd_config:parse_rcl({'logging', 'worker'})
  if not _r then
    logger.errx('cannot process %s: %s', opts['config'], err)
    os.exit(1)
  end

  local threshold, reject_score = get_threshold()
  local logs, messages = rescore_utility.get_all_logs(opts['log'])

  -- argparse stores --foo-bar under opts.foo_bar. The hyphenated keys read
  -- previously were never set, so these options were silently ignored; both
  -- spellings are accepted now.
  local ignore_opt = opts.ignore_symbol or opts['ignore-symbol']
  if ignore_opt then
    for _, sym in ipairs(collect_option_values(ignore_opt)) do
      ignore_symbols[sym] = true
    end
  end

  local rate_opt = opts.learning_rate or opts['learning-rate']
  if rate_opt then
    learning_rates = collect_option_values(rate_opt, tonumber)
  end

  local penalty_opt = opts.penalty_weight or opts['penalty-weight']
  if penalty_opt then
    penalty_weights = collect_option_values(penalty_opt, tonumber)
  end

  local all_symbols = rescore_utility.get_all_symbols(logs, ignore_symbols)
  local original_symbol_scores = rescore_utility.get_all_symbol_scores(rspamd_config,
      ignore_symbols)

  -- Display hit frequencies only
  if opts['freq'] then
    print_hit_frequencies(logs, messages, threshold, original_symbol_scores)
    return
  end

  -- 70% of the shuffled corpus trains; the other 30% is halved into a
  -- cross-validation set and a test set.
  shuffle(logs, messages)
  local train_logs, validation_logs = split_logs(logs, messages, 70)
  local cv_logs, test_logs = split_logs(validation_logs[1], validation_logs[2], 50)

  local dataset = make_dataset_from_logs(train_logs[1], all_symbols, reject_score)

  -- Start of perceptron training: a linear layer followed by a sigmoid
  local input_size = #all_symbols
  local linear_module = nn.Linear(input_size, 1, false)
  local activation = nn.Sigmoid()

  local perceptron = nn.Sequential()
  perceptron:add(linear_module)
  perceptron:add(activation)

  local criterion = nn.MSECriterion()
  --criterion.sizeAverage = false

  local best_fscore = -math.huge
  local best_weights = linear_module.weight[1]:clone()
  local best_learning_rate
  local best_weight_decay
  local all_fps
  local all_fns

  -- Grid search: for every (learning rate, penalty weight) pair, restart from
  -- the configured scores, train, and keep the weights with the best
  -- cross-validation F-score.
  for _, lr in ipairs(learning_rates) do
    for _, wd in ipairs(penalty_weights) do
      linear_module.weight[1] = init_weights(all_symbols, original_symbol_scores)
      local initial_weights = linear_module.weight[1]:clone()
      opts.learning_rate = lr
      opts.weight_decay = wd
      for i = 1, tonumber(opts.iters) do
        train(dataset, opts, perceptron, criterion, i, all_symbols, threshold,
            initial_weights)
      end

      local fscore, fps, fns = calculate_fscore_from_weights(cv_logs[1],
          cv_logs[2],
          all_symbols,
          linear_module.weight[1],
          threshold)

      logger.messagex("Cross-validation fscore=%s, learning rate=%s, weight decay=%s",
          fscore, lr, wd)

      if best_fscore < fscore then
        best_learning_rate = lr
        best_weight_decay = wd
        best_fscore = fscore
        best_weights = linear_module.weight[1]:clone()
        all_fps = fps
        all_fns = fns
      end
    end
  end

  -- End perceptron training

  local new_symbol_scores = stitch_new_scores(all_symbols, best_weights)

  if opts["output"] then
    write_scores(new_symbol_scores, opts["output"])
  end

  if opts["diff"] then
    print_score_diff(new_symbol_scores, original_symbol_scores)
  end

  -- Pre-rescore test stats
  logger.message("\n\nPre-rescore test stats\n")
  test_logs[1] = update_logs(test_logs[1], original_symbol_scores)
  print_stats(test_logs[1], test_logs[2], threshold)

  -- Post-rescore test stats
  test_logs[1] = update_logs(test_logs[1], new_symbol_scores)
  logger.message("\n\nPost-rescore test stats\n")
  print_stats(test_logs[1], test_logs[2], threshold)
  logger.messagex('Best fscore=%s, best learning rate=%s, best weight decay=%s',
      best_fscore, best_learning_rate, best_weight_decay)

  -- Show all FPs/FNs, useful for corpus checking and rule creation/modification
  if all_fps and #all_fps > 0 then
    logger.message("\nFalse-Positives:")
    for _, fp in ipairs(all_fps) do
      logger.messagex('%s', fp)
    end
  end
  if all_fns and #all_fns > 0 then
    logger.message("\nFalse-Negatives:")
    for _, fn in ipairs(all_fns) do
      logger.messagex('%s', fn)
    end
  end
end
-- Command descriptor consumed by rspamadm: entry point, help text, name.
local rescore_command = {
  handler = handler,
  description = parser._description,
  name = 'rescore',
}
return rescore_command
--]]
-- The rescore command defined above is disabled: its whole implementation
-- sits inside a long comment, so loading this module yields nil rather than a
-- command table.
-- NOTE(review): presumably disabled because the implementation depends on
-- the torch `nn` globals (nn.Linear, nn.Sigmoid, nn.MSECriterion) — TODO
-- confirm.
return nil