Просмотр исходного кода

[Minor] Further fixes to rescore tool

tags/1.7.6
Vsevolod Stakhov 6 лет назад
Родитель
Коммит
5d0d46d815
2 измененных файла: 73 добавления и 52 удаления
  1. 40
    31
      lualib/rescore_utility.lua
  2. 33
    21
      lualib/rspamadm/rescore.lua

+ 40
- 31
lualib/rescore_utility.lua Просмотреть файл

@@ -11,7 +11,7 @@ function utility.get_all_symbols(logs, ignore_symbols)

for _, line in pairs(logs) do
line = lua_util.rspamd_str_split(line, " ")
for i=4,(#line-2) do
for i=4,(#line-1) do
line[i] = line[i]:gsub("%s+", "")
if not symbols_set[line[i]] then
symbols_set[line[i]] = true
@@ -35,16 +35,23 @@ end
function utility.read_log_file(file)

local lines = {}
local messages = {}

file = assert(io.open(file, "r"))
local fd = assert(io.open(file, "r"))
local fname = string.gsub(file, "(.*/)(.*)", "%2")

for line in file:lines() do
lines[#lines + 1] = line
end
for line in fd:lines() do
local start,stop = string.find(line, fname .. ':')

if start and stop then
table.insert(lines, string.sub(line, 1, start))
table.insert(messages, string.sub(line, stop + 1, -1))
end
end

io.close(file)
io.close(fd)

return lines
return lines,messages
end

function utility.get_all_logs(dirs)
@@ -55,26 +62,29 @@ function utility.get_all_logs(dirs)
end

local all_logs = {}
local all_messages = {}

for _,dir in ipairs(dirs) do
if dir:sub(-1, -1) == "/" then
dir = dir:sub(1, -2)
local files = rspamd_util.glob(dir .. "/*.log")
for _, file in pairs(files) do
local logs = utility.read_log_file(file)
for _, log_line in pairs(logs) do
table.insert(all_logs, log_line)
local logs,messages = utility.read_log_file(file)
for i=1,#logs do
table.insert(all_logs, logs[i])
table.insert(all_messages, messages[i])
end
end
else
local logs = utility.read_log_file(dir)
for _, log_line in pairs(logs) do
table.insert(all_logs, log_line)
local logs,messages = utility.read_log_file(dir)
for i=1,#logs do
table.insert(all_logs, logs[i])
table.insert(all_messages, messages[i])
end
end
end

return all_logs
return all_logs,all_messages
end

function utility.get_all_symbol_scores(conf, ignore_symbols)
@@ -87,7 +97,7 @@ function utility.get_all_symbol_scores(conf, ignore_symbols)
end, symbols)))
end

function utility.generate_statistics_from_logs(logs, threshold)
function utility.generate_statistics_from_logs(logs, messages, threshold)

-- Returns file_stats table and list of symbol_stats table.

@@ -120,9 +130,10 @@ function utility.generate_statistics_from_logs(logs, threshold)
local no_of_spam = 0
local no_of_ham = 0

for _, log in pairs(logs) do
for i, log in ipairs(logs) do
log = lua_util.rspamd_str_trim(log)
log = lua_util.rspamd_str_split(log, " ")
local message = messages[i]

local is_spam = (log[1] == "SPAM")
local score = tonumber(log[2])
@@ -139,40 +150,38 @@ function utility.generate_statistics_from_logs(logs, threshold)
true_positives = true_positives + 1
elseif is_spam and (score < threshold) then
false_negatives = false_negatives + 1
table.insert(all_fns, log[#log])
table.insert(all_fns, message)
elseif not is_spam and (score >= threshold) then
false_positives = false_positives + 1
table.insert(all_fps, log[#log])
table.insert(all_fps, message)
else
true_negatives = true_negatives + 1
end

for i=4, (#log-2) do
if all_symbols_stats[log[i]] == nil then
all_symbols_stats[log[i]] = {
name = log[i],
for j=4, (#log-1) do
if all_symbols_stats[log[j]] == nil then
all_symbols_stats[log[j]] = {
name = message,
no_of_hits = 0,
spam_hits = 0,
ham_hits = 0,
spam_overall = 0
}
end
local sym = log[j]

all_symbols_stats[log[i]].no_of_hits =
all_symbols_stats[log[i]].no_of_hits + 1
all_symbols_stats[sym].no_of_hits = all_symbols_stats[sym].no_of_hits + 1

if is_spam then
all_symbols_stats[log[i]].spam_hits =
all_symbols_stats[log[i]].spam_hits + 1
all_symbols_stats[sym].spam_hits = all_symbols_stats[sym].spam_hits + 1
else
all_symbols_stats[log[i]].ham_hits =
all_symbols_stats[log[i]].ham_hits + 1
all_symbols_stats[sym].ham_hits = all_symbols_stats[sym].ham_hits + 1
end

-- Find slowest message
if ((tonumber(log[#log-1]) or 0) > file_stats.slowest) then
file_stats.slowest = tonumber(log[#log-1])
file_stats.slowest_file = log[#log]
if ((tonumber(log[#log]) or 0) > file_stats.slowest) then
file_stats.slowest = tonumber(log[#log])
file_stats.slowest_file = message
end
end
end

+ 33
- 21
lualib/rspamadm/rescore.lua Просмотреть файл

@@ -188,17 +188,18 @@ local function init_weights(all_symbols, original_symbol_scores)
return weights
end

local function shuffle(logs)
local function shuffle(logs, messages)

local size = #logs
for i = size, 1, -1 do
local rand = math.random(size)
logs[i], logs[rand] = logs[rand], logs[i]
messages[i], messages[rand] = messages[rand], messages[i]
end

end

local function split_logs(logs, split_percent)
local function split_logs(logs, messages, split_percent)

if not split_percent then
split_percent = 60
@@ -208,16 +209,20 @@ local function split_logs(logs, split_percent)

local test_logs = {}
local train_logs = {}
local test_messages = {}
local train_messages = {}

for i=1,split_index do
train_logs[#train_logs + 1] = logs[i]
table.insert(train_logs, logs[i])
table.insert(train_messages, messages[i])
end

for i=split_index + 1, #logs do
test_logs[#test_logs + 1] = logs[i]
table.insert(test_logs, logs[i])
table.insert(test_messages, messages[i])
end

return train_logs, test_logs
return {train_logs,train_messages}, {test_logs,test_messages}
end

local function stitch_new_scores(all_symbols, new_scores)
@@ -291,7 +296,10 @@ local function print_score_diff(new_symbol_scores, original_symbol_scores)

end

local function calculate_fscore_from_weights(logs, all_symbols, weights, threshold)
local function calculate_fscore_from_weights(logs, messages,
all_symbols,
weights,
threshold)

local new_symbol_scores = weights:clone()

@@ -300,14 +308,15 @@ local function calculate_fscore_from_weights(logs, all_symbols, weights, thresho
logs = update_logs(logs, new_symbol_scores)

local file_stats, _, all_fps, all_fns =
rescore_utility.generate_statistics_from_logs(logs, threshold)
rescore_utility.generate_statistics_from_logs(logs, messages, threshold)

return file_stats.fscore, all_fps, all_fns
end

local function print_stats(logs, threshold)
local function print_stats(logs, messages, threshold)

local file_stats, _ = rescore_utility.generate_statistics_from_logs(logs, threshold)
local file_stats, _ = rescore_utility.generate_statistics_from_logs(logs,
messages, threshold)

local file_stat_format = [[
F-score: %.2f
@@ -519,7 +528,7 @@ local function handler(args)
end

local threshold,reject_score = get_threshold()
local logs = rescore_utility.get_all_logs(opts['log'])
local logs,messages = rescore_utility.get_all_logs(opts['log'])

if opts['ignore-symbol'] then
local function add_ignore(s)
@@ -574,7 +583,9 @@ local function handler(args)

-- Display hit frequencies
if opts['freq'] then
local _, all_symbols_stats = rescore_utility.generate_statistics_from_logs(logs, threshold)
local _, all_symbols_stats = rescore_utility.generate_statistics_from_logs(logs,
messages,
threshold)
local t = {}
for _, symbol_stats in pairs(all_symbols_stats) do table.insert(t, symbol_stats) end

@@ -607,7 +618,7 @@ local function handler(args)
end

-- Print file statistics
print_stats(logs, threshold)
print_stats(logs, messages, threshold)

-- Work out how many symbols weren't seen in the corpus
local symbols_no_hits = {}
@@ -635,13 +646,13 @@ local function handler(args)
return
end

shuffle(logs)
shuffle(logs, messages)
torch.setdefaulttensortype('torch.FloatTensor')

local train_logs, validation_logs = split_logs(logs, 70)
local cv_logs, test_logs = split_logs(validation_logs, 50)
local train_logs, validation_logs = split_logs(logs, messages,70)
local cv_logs, test_logs = split_logs(validation_logs[1], validation_logs[2], 50)

local dataset = make_dataset_from_logs(train_logs, all_symbols, reject_score)
local dataset = make_dataset_from_logs(train_logs[1], all_symbols, reject_score)

-- Start of perceptron training
local input_size = #all_symbols
@@ -675,7 +686,8 @@ local function handler(args)
initial_weights)
end

local fscore, fps, fns = calculate_fscore_from_weights(cv_logs,
local fscore, fps, fns = calculate_fscore_from_weights(cv_logs[1],
cv_logs[2],
all_symbols,
linear_module.weight[1],
threshold)
@@ -710,13 +722,13 @@ local function handler(args)

-- Pre-rescore test stats
logger.message("\n\nPre-rescore test stats\n")
test_logs = update_logs(test_logs, original_symbol_scores)
print_stats(test_logs, threshold)
test_logs[1] = update_logs(test_logs[1], original_symbol_scores)
print_stats(test_logs[1], test_logs[2], threshold)

-- Post-rescore test stats
test_logs = update_logs(test_logs, new_symbol_scores)
test_logs[1] = update_logs(test_logs[1], new_symbol_scores)
logger.message("\n\nPost-rescore test stats\n")
print_stats(test_logs, threshold)
print_stats(test_logs[1], test_logs[2], threshold)

logger.messagex('Best fscore=%s, best learning rate=%s, best weight decay=%s',
best_fscore, best_learning_rate, best_weight_decay)

Загрузка…
Отмена
Сохранить