|
|
@@ -188,17 +188,18 @@ local function init_weights(all_symbols, original_symbol_scores) |
|
|
|
return weights |
|
|
|
end |
|
|
|
|
|
|
|
local function shuffle(logs) |
|
|
|
local function shuffle(logs, messages) |
|
|
|
|
|
|
|
local size = #logs |
|
|
|
for i = size, 1, -1 do |
|
|
|
local rand = math.random(size) |
|
|
|
logs[i], logs[rand] = logs[rand], logs[i] |
|
|
|
messages[i], messages[rand] = messages[rand], messages[i] |
|
|
|
end |
|
|
|
|
|
|
|
end |
|
|
|
|
|
|
|
local function split_logs(logs, split_percent) |
|
|
|
local function split_logs(logs, messages, split_percent) |
|
|
|
|
|
|
|
if not split_percent then |
|
|
|
split_percent = 60 |
|
|
@@ -208,16 +209,20 @@ local function split_logs(logs, split_percent) |
|
|
|
|
|
|
|
local test_logs = {} |
|
|
|
local train_logs = {} |
|
|
|
local test_messages = {} |
|
|
|
local train_messages = {} |
|
|
|
|
|
|
|
for i=1,split_index do |
|
|
|
train_logs[#train_logs + 1] = logs[i] |
|
|
|
table.insert(train_logs, logs[i]) |
|
|
|
table.insert(train_messages, messages[i]) |
|
|
|
end |
|
|
|
|
|
|
|
for i=split_index + 1, #logs do |
|
|
|
test_logs[#test_logs + 1] = logs[i] |
|
|
|
table.insert(test_logs, logs[i]) |
|
|
|
table.insert(test_messages, messages[i]) |
|
|
|
end |
|
|
|
|
|
|
|
return train_logs, test_logs |
|
|
|
return {train_logs,train_messages}, {test_logs,test_messages} |
|
|
|
end |
|
|
|
|
|
|
|
local function stitch_new_scores(all_symbols, new_scores) |
|
|
@@ -291,7 +296,10 @@ local function print_score_diff(new_symbol_scores, original_symbol_scores) |
|
|
|
|
|
|
|
end |
|
|
|
|
|
|
|
local function calculate_fscore_from_weights(logs, all_symbols, weights, threshold) |
|
|
|
local function calculate_fscore_from_weights(logs, messages, |
|
|
|
all_symbols, |
|
|
|
weights, |
|
|
|
threshold) |
|
|
|
|
|
|
|
local new_symbol_scores = weights:clone() |
|
|
|
|
|
|
@@ -300,14 +308,15 @@ local function calculate_fscore_from_weights(logs, all_symbols, weights, thresho |
|
|
|
logs = update_logs(logs, new_symbol_scores) |
|
|
|
|
|
|
|
local file_stats, _, all_fps, all_fns = |
|
|
|
rescore_utility.generate_statistics_from_logs(logs, threshold) |
|
|
|
rescore_utility.generate_statistics_from_logs(logs, messages, threshold) |
|
|
|
|
|
|
|
return file_stats.fscore, all_fps, all_fns |
|
|
|
end |
|
|
|
|
|
|
|
local function print_stats(logs, threshold) |
|
|
|
local function print_stats(logs, messages, threshold) |
|
|
|
|
|
|
|
local file_stats, _ = rescore_utility.generate_statistics_from_logs(logs, threshold) |
|
|
|
local file_stats, _ = rescore_utility.generate_statistics_from_logs(logs, |
|
|
|
messages, threshold) |
|
|
|
|
|
|
|
local file_stat_format = [[ |
|
|
|
F-score: %.2f |
|
|
@@ -519,7 +528,7 @@ local function handler(args) |
|
|
|
end |
|
|
|
|
|
|
|
local threshold,reject_score = get_threshold() |
|
|
|
local logs = rescore_utility.get_all_logs(opts['log']) |
|
|
|
local logs,messages = rescore_utility.get_all_logs(opts['log']) |
|
|
|
|
|
|
|
if opts['ignore-symbol'] then |
|
|
|
local function add_ignore(s) |
|
|
@@ -574,7 +583,9 @@ local function handler(args) |
|
|
|
|
|
|
|
-- Display hit frequencies |
|
|
|
if opts['freq'] then |
|
|
|
local _, all_symbols_stats = rescore_utility.generate_statistics_from_logs(logs, threshold) |
|
|
|
local _, all_symbols_stats = rescore_utility.generate_statistics_from_logs(logs, |
|
|
|
messages, |
|
|
|
threshold) |
|
|
|
local t = {} |
|
|
|
for _, symbol_stats in pairs(all_symbols_stats) do table.insert(t, symbol_stats) end |
|
|
|
|
|
|
@@ -607,7 +618,7 @@ local function handler(args) |
|
|
|
end |
|
|
|
|
|
|
|
-- Print file statistics |
|
|
|
print_stats(logs, threshold) |
|
|
|
print_stats(logs, messages, threshold) |
|
|
|
|
|
|
|
-- Work out how many symbols weren't seen in the corpus |
|
|
|
local symbols_no_hits = {} |
|
|
@@ -635,13 +646,13 @@ local function handler(args) |
|
|
|
return |
|
|
|
end |
|
|
|
|
|
|
|
shuffle(logs) |
|
|
|
shuffle(logs, messages) |
|
|
|
torch.setdefaulttensortype('torch.FloatTensor') |
|
|
|
|
|
|
|
local train_logs, validation_logs = split_logs(logs, 70) |
|
|
|
local cv_logs, test_logs = split_logs(validation_logs, 50) |
|
|
|
local train_logs, validation_logs = split_logs(logs, messages,70) |
|
|
|
local cv_logs, test_logs = split_logs(validation_logs[1], validation_logs[2], 50) |
|
|
|
|
|
|
|
local dataset = make_dataset_from_logs(train_logs, all_symbols, reject_score) |
|
|
|
local dataset = make_dataset_from_logs(train_logs[1], all_symbols, reject_score) |
|
|
|
|
|
|
|
-- Start of perceptron training |
|
|
|
local input_size = #all_symbols |
|
|
@@ -675,7 +686,8 @@ local function handler(args) |
|
|
|
initial_weights) |
|
|
|
end |
|
|
|
|
|
|
|
local fscore, fps, fns = calculate_fscore_from_weights(cv_logs, |
|
|
|
local fscore, fps, fns = calculate_fscore_from_weights(cv_logs[1], |
|
|
|
cv_logs[2], |
|
|
|
all_symbols, |
|
|
|
linear_module.weight[1], |
|
|
|
threshold) |
|
|
@@ -710,13 +722,13 @@ local function handler(args) |
|
|
|
|
|
|
|
-- Pre-rescore test stats |
|
|
|
logger.message("\n\nPre-rescore test stats\n") |
|
|
|
test_logs = update_logs(test_logs, original_symbol_scores) |
|
|
|
print_stats(test_logs, threshold) |
|
|
|
test_logs[1] = update_logs(test_logs[1], original_symbol_scores) |
|
|
|
print_stats(test_logs[1], test_logs[2], threshold) |
|
|
|
|
|
|
|
-- Post-rescore test stats |
|
|
|
test_logs = update_logs(test_logs, new_symbol_scores) |
|
|
|
test_logs[1] = update_logs(test_logs[1], new_symbol_scores) |
|
|
|
logger.message("\n\nPost-rescore test stats\n") |
|
|
|
print_stats(test_logs, threshold) |
|
|
|
print_stats(test_logs[1], test_logs[2], threshold) |
|
|
|
|
|
|
|
logger.messagex('Best fscore=%s, best learning rate=%s, best weight decay=%s', |
|
|
|
best_fscore, best_learning_rate, best_weight_decay) |