+++ /dev/null
-local lua_util = require "lua_util"
-local rspamd_util = require "rspamd_util"
-local fun = require "fun"
-
-local utility = {}
-
-function utility.get_all_symbols(logs, ignore_symbols)
- -- Returns a list of all symbols
-
- local symbols_set = {}
-
- for _, line in pairs(logs) do
- line = lua_util.rspamd_str_split(line, " ")
- for i = 4, (#line - 1) do
- line[i] = line[i]:gsub("%s+", "")
- if not symbols_set[line[i]] then
- symbols_set[line[i]] = true
- end
- end
- end
-
- local all_symbols = {}
-
- for symbol, _ in pairs(symbols_set) do
- if not ignore_symbols[symbol] then
- all_symbols[#all_symbols + 1] = symbol
- end
- end
-
- table.sort(all_symbols)
-
- return all_symbols
-end
-
-function utility.read_log_file(file)
-
- local lines = {}
- local messages = {}
-
- local fd = assert(io.open(file, "r"))
- local fname = string.gsub(file, "(.*/)(.*)", "%2")
-
- for line in fd:lines() do
- local start, stop = string.find(line, fname .. ':')
-
- if start and stop then
- table.insert(lines, string.sub(line, 1, start))
- table.insert(messages, string.sub(line, stop + 1, -1))
- end
- end
-
- io.close(fd)
-
- return lines, messages
-end
-
-function utility.get_all_logs(dirs)
- -- Reads all log files in the directory and returns a list of logs.
-
- if type(dirs) == 'string' then
- dirs = { dirs }
- end
-
- local all_logs = {}
- local all_messages = {}
-
- for _, dir in ipairs(dirs) do
- if dir:sub(-1, -1) == "/" then
- dir = dir:sub(1, -2)
- local files = rspamd_util.glob(dir .. "/*.log")
- for _, file in pairs(files) do
- local logs, messages = utility.read_log_file(file)
- for i = 1, #logs do
- table.insert(all_logs, logs[i])
- table.insert(all_messages, messages[i])
- end
- end
- else
- local logs, messages = utility.read_log_file(dir)
- for i = 1, #logs do
- table.insert(all_logs, logs[i])
- table.insert(all_messages, messages[i])
- end
- end
- end
-
- return all_logs, all_messages
-end
-
-function utility.get_all_symbol_scores(conf, ignore_symbols)
- local symbols = conf:get_symbols_scores()
-
- return fun.tomap(fun.map(function(name, elt)
- return name, elt['score']
- end, fun.filter(function(name, elt)
- return not ignore_symbols[name]
- end, symbols)))
-end
-
-function utility.generate_statistics_from_logs(logs, messages, threshold)
-
- -- Returns file_stats table and list of symbol_stats table.
-
- local file_stats = {
- no_of_emails = 0,
- no_of_spam = 0,
- no_of_ham = 0,
- spam_percent = 0,
- ham_percent = 0,
- true_positives = 0,
- true_negatives = 0,
- false_negative_rate = 0,
- false_positive_rate = 0,
- overall_accuracy = 0,
- fscore = 0,
- avg_scan_time = 0,
- slowest_file = nil,
- slowest = 0
- }
-
- local all_symbols_stats = {}
- local all_fps = {}
- local all_fns = {}
-
- local false_positives = 0
- local false_negatives = 0
- local true_positives = 0
- local true_negatives = 0
- local no_of_emails = 0
- local no_of_spam = 0
- local no_of_ham = 0
-
- for i, log in ipairs(logs) do
- log = lua_util.rspamd_str_trim(log)
- log = lua_util.rspamd_str_split(log, " ")
- local message = messages[i]
-
- local is_spam = (log[1] == "SPAM")
- local score = tonumber(log[2])
-
- no_of_emails = no_of_emails + 1
-
- if is_spam then
- no_of_spam = no_of_spam + 1
- else
- no_of_ham = no_of_ham + 1
- end
-
- if is_spam and (score >= threshold) then
- true_positives = true_positives + 1
- elseif is_spam and (score < threshold) then
- false_negatives = false_negatives + 1
- table.insert(all_fns, message)
- elseif not is_spam and (score >= threshold) then
- false_positives = false_positives + 1
- table.insert(all_fps, message)
- else
- true_negatives = true_negatives + 1
- end
-
- for j = 4, (#log - 1) do
- if all_symbols_stats[log[j]] == nil then
- all_symbols_stats[log[j]] = {
- name = message,
- no_of_hits = 0,
- spam_hits = 0,
- ham_hits = 0,
- spam_overall = 0
- }
- end
- local sym = log[j]
-
- all_symbols_stats[sym].no_of_hits = all_symbols_stats[sym].no_of_hits + 1
-
- if is_spam then
- all_symbols_stats[sym].spam_hits = all_symbols_stats[sym].spam_hits + 1
- else
- all_symbols_stats[sym].ham_hits = all_symbols_stats[sym].ham_hits + 1
- end
-
- -- Find slowest message
- if ((tonumber(log[#log]) or 0) > file_stats.slowest) then
- file_stats.slowest = tonumber(log[#log])
- file_stats.slowest_file = message
- end
- end
- end
-
- -- Calculating file stats
-
- file_stats.no_of_ham = no_of_ham
- file_stats.no_of_spam = no_of_spam
- file_stats.no_of_emails = no_of_emails
- file_stats.true_positives = true_positives
- file_stats.true_negatives = true_negatives
-
- if no_of_emails > 0 then
- file_stats.spam_percent = no_of_spam * 100 / no_of_emails
- file_stats.ham_percent = no_of_ham * 100 / no_of_emails
- file_stats.overall_accuracy = (true_positives + true_negatives) * 100 /
- no_of_emails
- end
-
- if no_of_ham > 0 then
- file_stats.false_positive_rate = false_positives * 100 / no_of_ham
- end
-
- if no_of_spam > 0 then
- file_stats.false_negative_rate = false_negatives * 100 / no_of_spam
- end
-
- file_stats.fscore = 2 * true_positives / (2
- * true_positives
- + false_positives
- + false_negatives)
-
- -- Calculating symbol stats
-
- for _, symbol_stats in pairs(all_symbols_stats) do
- symbol_stats.spam_percent = symbol_stats.spam_hits * 100 / no_of_spam
- symbol_stats.ham_percent = symbol_stats.ham_hits * 100 / no_of_ham
- symbol_stats.overall = symbol_stats.no_of_hits * 100 / no_of_emails
- symbol_stats.spam_overall = symbol_stats.spam_percent /
- (symbol_stats.spam_percent + symbol_stats.ham_percent)
- end
-
- return file_stats, all_symbols_stats, all_fps, all_fns
-end
-
-return utility
+++ /dev/null
---[[
-Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-]]--
-
---[[
-local lua_util = require "lua_util"
-local ucl = require "ucl"
-local logger = require "rspamd_logger"
-local rspamd_util = require "rspamd_util"
-local argparse = require "argparse"
-local rescore_utility = require "rescore_utility"
-
-
-local opts
-local ignore_symbols = {
- ['DATE_IN_PAST'] =true,
- ['DATE_IN_FUTURE'] = true,
-}
-
-local parser = argparse()
- :name "rspamadm rescore"
- :description "Estimate optimal symbol weights from log files"
- :help_description_margin(37)
-
-parser:option "-l --log"
- :description "Log file or files (from rescore)"
- :argname("<log>")
- :args "*"
-parser:option "-c --config"
- :description "Path to config file"
- :argname("<file>")
- :default(rspamd_paths["CONFDIR"] .. "/" .. "rspamd.conf")
-parser:option "-o --output"
- :description "Output file"
- :argname("<file>")
- :default("new.scores")
-parser:flag "-d --diff"
- :description "Show differences in scores"
-parser:flag "-v --verbose"
- :description "Verbose output"
-parser:flag "-z --freq"
- :description "Display hit frequencies"
-parser:option "-i --iters"
- :description "Learn iterations"
- :argname("<n>")
- :convert(tonumber)
- :default(10)
-parser:option "-b --batch"
- :description "Batch size"
- :argname("<n>")
- :convert(tonumber)
- :default(100)
-parser:option "-d --decay"
- :description "Decay rate"
- :argname("<n>")
- :convert(tonumber)
- :default(0.001)
-parser:option "-m --momentum"
- :description "Learn momentum"
- :argname("<n>")
- :convert(tonumber)
- :default(0.1)
-parser:option "-t --threads"
- :description "Number of threads to use"
- :argname("<n>")
- :convert(tonumber)
- :default(1)
-parser:option "-o --optim"
- :description "Optimisation algorithm"
- :argname("<alg>")
- :convert {
- LBFGS = "LBFGS",
- ADAM = "ADAM",
- ADAGRAD = "ADAGRAD",
- SGD = "SGD",
- NAG = "NAG"
- }
- :default "ADAM"
-parser:option "--ignore-symbol"
- :description "Ignore symbol from logs"
- :argname("<sym>")
- :args "*"
-parser:option "--penalty-weight"
- :description "Add new penalty weight to test"
- :argname("<n>")
- :convert(tonumber)
- :args "*"
-parser:option "--learning-rate"
- :description "Add new learning rate to test"
- :argname("<n>")
- :convert(tonumber)
- :args "*"
-parser:option "--spam_action"
- :description "Spam action"
- :argname("<act>")
- :default("reject")
-parser:option "--learning_rate_decay"
- :description "Learn rate decay (for some algs)"
- :argname("<n>")
- :convert(tonumber)
- :default(0.0)
-parser:option "--weight_decay"
- :description "Weight decay (for some algs)"
- :argname("<n>")
- :convert(tonumber)
- :default(0.0)
-parser:option "--l1"
- :description "L1 regularization penalty"
- :argname("<n>")
- :convert(tonumber)
- :default(0.0)
-parser:option "--l2"
- :description "L2 regularization penalty"
- :argname("<n>")
- :convert(tonumber)
- :default(0.0)
-
-local function make_dataset_from_logs(logs, all_symbols, spam_score)
-
- local inputs = {}
- local outputs = {}
-
- for _, log in pairs(logs) do
-
- log = lua_util.rspamd_str_split(log, " ")
-
- if log[1] == "SPAM" then
- outputs[#outputs+1] = 1
- else
- outputs[#outputs+1] = 0
- end
-
- local symbols_set = {}
-
- for i=4,#log do
- if not ignore_symbols[ log[i] ] then
- symbols_set[log[i] ] = true
- end
- end
-
- local input_vec = {}
- for index, symbol in pairs(all_symbols) do
- if symbols_set[symbol] then
- input_vec[index] = 1
- else
- input_vec[index] = 0
- end
- end
-
- inputs[#inputs + 1] = input_vec
- end
-
- return inputs,outputs
-end
-
-local function init_weights(all_symbols, original_symbol_scores)
-end
-
-local function shuffle(logs, messages)
-
- local size = #logs
- for i = size, 1, -1 do
- local rand = math.random(size)
- logs[i], logs[rand] = logs[rand], logs[i]
- messages[i], messages[rand] = messages[rand], messages[i]
- end
-
-end
-
-local function split_logs(logs, messages, split_percent)
-
- if not split_percent then
- split_percent = 60
- end
-
- local split_index = math.floor(#logs * split_percent / 100)
-
- local test_logs = {}
- local train_logs = {}
- local test_messages = {}
- local train_messages = {}
-
- for i=1,split_index do
- table.insert(train_logs, logs[i])
- table.insert(train_messages, messages[i])
- end
-
- for i=split_index + 1, #logs do
- table.insert(test_logs, logs[i])
- table.insert(test_messages, messages[i])
- end
-
- return {train_logs,train_messages}, {test_logs,test_messages}
-end
-
-local function stitch_new_scores(all_symbols, new_scores)
-
- local new_symbol_scores = {}
-
- for idx, symbol in pairs(all_symbols) do
- new_symbol_scores[symbol] = new_scores[idx]
- end
-
- return new_symbol_scores
-end
-
-
-local function update_logs(logs, symbol_scores)
-
- for i, log in ipairs(logs) do
-
- log = lua_util.rspamd_str_split(log, " ")
-
- local score = 0
-
- for j=4,#log do
- log[j] = log[j]:gsub("%s+", "")
- score = score + (symbol_scores[log[j] ] or 0)
- end
-
- log[2] = lua_util.round(score, 2)
-
- logs[i] = table.concat(log, " ")
- end
-
- return logs
-end
-
-local function write_scores(new_symbol_scores, file_path)
-
- local file = assert(io.open(file_path, "w"))
-
- local new_scores_ucl = ucl.to_format(new_symbol_scores, "ucl")
-
- file:write(new_scores_ucl)
-
- file:close()
-end
-
-local function print_score_diff(new_symbol_scores, original_symbol_scores)
-
- logger.message(string.format("%-35s %-10s %-10s",
- "SYMBOL", "OLD_SCORE", "NEW_SCORE"))
-
- for symbol, new_score in pairs(new_symbol_scores) do
- logger.message(string.format("%-35s %-10s %-10s",
- symbol,
- original_symbol_scores[symbol] or 0,
- lua_util.round(new_score, 2)))
- end
-
- logger.message("\nClass changes \n")
- for symbol, new_score in pairs(new_symbol_scores) do
- if original_symbol_scores[symbol] ~= nil then
- if (original_symbol_scores[symbol] > 0 and new_score < 0) or
- (original_symbol_scores[symbol] < 0 and new_score > 0) then
- logger.message(string.format("%-35s %-10s %-10s",
- symbol,
- original_symbol_scores[symbol] or 0,
- lua_util.round(new_score, 2)))
- end
- end
- end
-
-end
-
-local function calculate_fscore_from_weights(logs, messages,
- all_symbols,
- weights,
- threshold)
-
- local new_symbol_scores = weights:clone()
-
- new_symbol_scores = stitch_new_scores(all_symbols, new_symbol_scores)
-
- logs = update_logs(logs, new_symbol_scores)
-
- local file_stats, _, all_fps, all_fns =
- rescore_utility.generate_statistics_from_logs(logs, messages, threshold)
-
- return file_stats.fscore, all_fps, all_fns
-end
-
-local function print_stats(logs, messages, threshold)
-
- local file_stats, _ = rescore_utility.generate_statistics_from_logs(logs,
- messages, threshold)
-
- local file_stat_format = [=[
-F-score: %.2f
-False positive rate: %.2f %%
-False negative rate: %.2f %%
-Overall accuracy: %.2f %%
-Slowest message: %.2f (%s)
-]=]
-
- logger.message("\nStatistics at threshold: " .. threshold)
-
- logger.message(string.format(file_stat_format,
- file_stats.fscore,
- file_stats.false_positive_rate,
- file_stats.false_negative_rate,
- file_stats.overall_accuracy,
- file_stats.slowest,
- file_stats.slowest_file))
-
-end
-
--- training function
-local function train(dataset, opt, model, criterion, epoch,
- all_symbols, spam_threshold, initial_weights)
-end
-
-local learning_rates = {
- 0.01
-}
-local penalty_weights = {
- 0
-}
-
-local function get_threshold()
- local actions = rspamd_config:get_all_actions()
-
- if opts['spam-action'] then
- return (actions[opts['spam-action'] ] or 0),actions['reject']
- end
- return (actions['add header'] or actions['rewrite subject']
- or actions['reject']), actions['reject']
-end
-
-local function handler(args)
- opts = parser:parse(args)
- if not opts['log'] then
- parser:error('no log specified')
- end
-
- local _r,err = rspamd_config:load_ucl(opts['config'])
-
- if not _r then
- logger.errx('cannot parse %s: %s', opts['config'], err)
- os.exit(1)
- end
-
- _r,err = rspamd_config:parse_rcl({'logging', 'worker'})
- if not _r then
- logger.errx('cannot process %s: %s', opts['config'], err)
- os.exit(1)
- end
-
- local threshold,reject_score = get_threshold()
- local logs,messages = rescore_utility.get_all_logs(opts['log'])
-
- if opts['ignore-symbol'] then
- local function add_ignore(s)
- ignore_symbols[s] = true
- end
- if type(opts['ignore-symbol']) == 'table' then
- for _,s in ipairs(opts['ignore-symbol']) do
- add_ignore(s)
- end
- else
- add_ignore(opts['ignore-symbol'])
- end
- end
-
- if opts['learning-rate'] then
- learning_rates = {}
-
- local function add_rate(r)
- if tonumber(r) then
- table.insert(learning_rates, tonumber(r))
- end
- end
- if type(opts['learning-rate']) == 'table' then
- for _,s in ipairs(opts['learning-rate']) do
- add_rate(s)
- end
- else
- add_rate(opts['learning-rate'])
- end
- end
-
- if opts['penalty-weight'] then
- penalty_weights = {}
-
- local function add_weight(r)
- if tonumber(r) then
- table.insert(penalty_weights, tonumber(r))
- end
- end
- if type(opts['penalty-weight']) == 'table' then
- for _,s in ipairs(opts['penalty-weight']) do
- add_weight(s)
- end
- else
- add_weight(opts['penalty-weight'])
- end
- end
-
- local all_symbols = rescore_utility.get_all_symbols(logs, ignore_symbols)
- local original_symbol_scores = rescore_utility.get_all_symbol_scores(rspamd_config,
- ignore_symbols)
-
- -- Display hit frequencies
- if opts['freq'] then
- local _, all_symbols_stats = rescore_utility.generate_statistics_from_logs(logs,
- messages,
- threshold)
- local t = {}
- for _, symbol_stats in pairs(all_symbols_stats) do table.insert(t, symbol_stats) end
-
- local function compare_symbols(a, b)
- if (a.spam_overall ~= b.spam_overall) then
- return b.spam_overall < a.spam_overall
- end
- if (b.spam_hits ~= a.spam_hits) then
- return b.spam_hits < a.spam_hits
- end
- return b.ham_hits < a.ham_hits
- end
- table.sort(t, compare_symbols)
-
- logger.message(string.format("%-40s %6s %6s %6s %6s %6s %6s %6s",
- "NAME", "HITS", "HAM", "HAM%", "SPAM", "SPAM%", "S/O", "OVER%"))
- for _, symbol_stats in pairs(t) do
- logger.message(
- string.format("%-40s %6d %6d %6.2f %6d %6.2f %6.2f %6.2f",
- symbol_stats.name,
- symbol_stats.no_of_hits,
- symbol_stats.ham_hits,
- lua_util.round(symbol_stats.ham_percent,2),
- symbol_stats.spam_hits,
- lua_util.round(symbol_stats.spam_percent,2),
- lua_util.round(symbol_stats.spam_overall,2),
- lua_util.round(symbol_stats.overall, 2)
- )
- )
- end
-
- -- Print file statistics
- print_stats(logs, messages, threshold)
-
- -- Work out how many symbols weren't seen in the corpus
- local symbols_no_hits = {}
- local total_symbols = 0
- for sym in pairs(original_symbol_scores) do
- total_symbols = total_symbols + 1
- if (all_symbols_stats[sym] == nil) then
- table.insert(symbols_no_hits, sym)
- end
- end
- if (#symbols_no_hits > 0) then
- table.sort(symbols_no_hits)
- -- Calculate percentage of rules with no hits
- local nhpct = lua_util.round((#symbols_no_hits/total_symbols)*100,2)
- logger.message(
- string.format('\nFound %s (%-.2f%%) symbols out of %s with no hits in corpus:',
- #symbols_no_hits, nhpct, total_symbols
- )
- )
- for _, symbol in pairs(symbols_no_hits) do
- logger.messagex('%s', symbol)
- end
- end
-
- return
- end
-
- shuffle(logs, messages)
- local train_logs, validation_logs = split_logs(logs, messages,70)
- local cv_logs, test_logs = split_logs(validation_logs[1], validation_logs[2], 50)
-
- local dataset = make_dataset_from_logs(train_logs[1], all_symbols, reject_score)
- -- Start of perceptron training
- local input_size = #all_symbols
-
- local linear_module = nn.Linear(input_size, 1, false)
- local activation = nn.Sigmoid()
-
- local perceptron = nn.Sequential()
- perceptron:add(linear_module)
- perceptron:add(activation)
-
- local criterion = nn.MSECriterion()
- --criterion.sizeAverage = false
-
- local best_fscore = -math.huge
- local best_weights = linear_module.weight[1]:clone()
- local best_learning_rate
- local best_weight_decay
- local all_fps
- local all_fns
-
- for _,lr in ipairs(learning_rates) do
- for _,wd in ipairs(penalty_weights) do
- linear_module.weight[1] = init_weights(all_symbols, original_symbol_scores)
- local initial_weights = linear_module.weight[1]:clone()
- opts.learning_rate = lr
- opts.weight_decay = wd
- for i=1,tonumber(opts.iters) do
- train(dataset, opts, perceptron, criterion, i, all_symbols, threshold,
- initial_weights)
- end
-
- local fscore, fps, fns = calculate_fscore_from_weights(cv_logs[1],
- cv_logs[2],
- all_symbols,
- linear_module.weight[1],
- threshold)
-
- logger.messagex("Cross-validation fscore=%s, learning rate=%s, weight decay=%s",
- fscore, lr, wd)
-
- if best_fscore < fscore then
- best_learning_rate = lr
- best_weight_decay = wd
- best_fscore = fscore
- best_weights = linear_module.weight[1]:clone()
- all_fps = fps
- all_fns = fns
- end
- end
- end
-
- -- End perceptron training
-
- local new_symbol_scores = best_weights
-
- new_symbol_scores = stitch_new_scores(all_symbols, new_symbol_scores)
-
- if opts["output"] then
- write_scores(new_symbol_scores, opts["output"])
- end
-
- if opts["diff"] then
- print_score_diff(new_symbol_scores, original_symbol_scores)
- end
-
- -- Pre-rescore test stats
- logger.message("\n\nPre-rescore test stats\n")
- test_logs[1] = update_logs(test_logs[1], original_symbol_scores)
- print_stats(test_logs[1], test_logs[2], threshold)
-
- -- Post-rescore test stats
- test_logs[1] = update_logs(test_logs[1], new_symbol_scores)
- logger.message("\n\nPost-rescore test stats\n")
- print_stats(test_logs[1], test_logs[2], threshold)
-
- logger.messagex('Best fscore=%s, best learning rate=%s, best weight decay=%s',
- best_fscore, best_learning_rate, best_weight_decay)
-
- -- Show all FPs/FNs, useful for corpus checking and rule creation/modification
- if (all_fps and #all_fps > 0) then
- logger.message("\nFalse-Positives:")
- for _, fp in pairs(all_fps) do
- logger.messagex('%s', fp)
- end
- end
-
- if (all_fns and #all_fns > 0) then
- logger.message("\nFalse-Negatives:")
- for _, fn in pairs(all_fns) do
- logger.messagex('%s', fn)
- end
- end
-end
-
-
-return {
- handler = handler,
- description = parser._description,
- name = 'rescore'
-}
---]]
-
-return nil