aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2017-12-07 20:18:49 +0000
committerGitHub <noreply@github.com>2017-12-07 20:18:49 +0000
commit09f3015da643b82e24b054f1704aa6783bfc79e8 (patch)
tree96c522272ea9373433938c2dbc97ec25e3bf703f
parent1e929b744952674120545cbbd1643b6fd1910aab (diff)
parent703bd13d5bedc30ed9bbeb7180d3cd083fc0e1f4 (diff)
downloadrspamd-09f3015da643b82e24b054f1704aa6783bfc79e8.tar.gz
rspamd-09f3015da643b82e24b054f1704aa6783bfc79e8.zip
Merge pull request #1946 from cpragadeesh/rescore-filter
[Feature] added corpus_test, rescore commands
-rw-r--r--lualib/rspamadm/corpus_test.lua126
-rw-r--r--lualib/rspamadm/rescore.lua298
-rw-r--r--lualib/rspamadm/rescore_utility.lua214
-rw-r--r--src/rspamadm/CMakeLists.txt2
-rw-r--r--src/rspamadm/commands.c4
-rw-r--r--src/rspamadm/corpus_test.c121
-rw-r--r--src/rspamadm/rescore.c141
7 files changed, 906 insertions, 0 deletions
diff --git a/lualib/rspamadm/corpus_test.lua b/lualib/rspamadm/corpus_test.lua
new file mode 100644
index 000000000..b29fa5602
--- /dev/null
+++ b/lualib/rspamadm/corpus_test.lua
@@ -0,0 +1,126 @@
+local ucl = require "ucl"
+local lua_util = require "lua_util"
+
+local HAM = "HAM"
+local SPAM = "SPAM"
+
+local function scan_email(n_parellel, path)
+
+ local rspamc_command = string.format("rspamc -j --compact -n %s %s", n_parellel, path)
+ local result = assert(io.popen(rspamc_command))
+ result = result:read("*all")
+ return result
+end
+
+local function write_results(results, file)
+
+ local f = io.open(file, 'w')
+
+ for _, result in pairs(results) do
+ local log_line = string.format("%s %.2f %s", result.type, result.score, result.action)
+
+ for _, sym in pairs(result.symbols) do
+ log_line = log_line .. " " .. sym
+ end
+
+ log_line = log_line .. "\r\n"
+
+ f:write(log_line)
+ end
+
+ f:close()
+end
+
+local function encoded_json_to_log(result)
+ -- Returns table containing score, action, list of symbols
+
+ local filtered_result = {}
+ local parser = ucl.parser()
+
+ local is_good, err = parser:parse_string(result)
+
+ if not is_good then
+ print(err)
+ os.exit()
+ end
+
+ result = parser:get_object()
+
+ filtered_result.score = result.score
+ local action = result.action:gsub("%s+", "_")
+ filtered_result.action = action
+
+ filtered_result.symbols = {}
+
+ for sym, _ in pairs(result.symbols) do
+ table.insert(filtered_result.symbols, sym)
+ end
+
+ return filtered_result
+end
+
+local function scan_results_to_logs(results, actual_email_type)
+
+ local logs = {}
+
+ results = lua_util.rspamd_str_split(results, "\n")
+
+ if results[#results] == "" then
+ results[#results] = nil
+ end
+
+ for _, result in pairs(results) do
+ result = encoded_json_to_log(result)
+ result['type'] = actual_email_type
+ table.insert(logs, result)
+ end
+
+ return logs
+end
+
+return function (_, res)
+
+ local ham_directory = res['ham_directory']
+ local spam_directory = res['spam_directory']
+ local connections = res["connections"]
+ local output = res["output_location"]
+
+ local results = {}
+
+ local start_time = os.time()
+ local no_of_ham = 0
+ local no_of_spam = 0
+
+ if ham_directory then
+ io.write("Scanning ham corpus...\n")
+ local ham_results = scan_email(connections, ham_directory)
+ ham_results = scan_results_to_logs(ham_results, HAM)
+
+ no_of_ham = #ham_results
+
+ for _, result in pairs(ham_results) do
+ table.insert(results, result)
+ end
+ end
+
+ if spam_directory then
+ io.write("Scanning spam corpus...\n")
+ local spam_results = scan_email(connections, spam_directory)
+ spam_results = scan_results_to_logs(spam_results, SPAM)
+
+ no_of_spam = #spam_results
+
+ for _, result in pairs(spam_results) do
+ table.insert(results, result)
+ end
+ end
+
+ io.write(string.format("Writing results to %s\n", output))
+ write_results(results, output)
+
+ io.write("\nStats: \n")
+ io.write(string.format("Elapsed time: %ds\n", os.time() - start_time))
+ io.write(string.format("No of ham: %d\n", no_of_ham))
+ io.write(string.format("No of spam: %d\n", no_of_spam))
+
+end \ No newline at end of file
diff --git a/lualib/rspamadm/rescore.lua b/lualib/rspamadm/rescore.lua
new file mode 100644
index 000000000..538122f68
--- /dev/null
+++ b/lualib/rspamadm/rescore.lua
@@ -0,0 +1,298 @@
+local torch = require "torch"
+local nn = require "nn"
+local lua_util = require "lua_util"
+local ucl = require "ucl"
+
+local rescore_utility = require "rspamadm/rescore_utility"
+
+local function make_dataset_from_logs(logs, all_symbols)
+ -- Returns a list of {input, output} for torch SGD train
+
+ local dataset = {}
+
+ for _, log in pairs(logs) do
+ local input = torch.Tensor(#all_symbols)
+ local output = torch.Tensor(1)
+ log = lua_util.rspamd_str_split(log, " ")
+
+ if log[1] == "SPAM" then
+ output[1] = 1
+ else
+ output[1] = 0
+ end
+
+ local symbols_set = {}
+
+ for i=4,#log do
+ symbols_set[log[i]] = true
+ end
+
+ for index, symbol in pairs(all_symbols) do
+ if symbols_set[symbol] then
+ input[index] = 1
+ else
+ input[index] = 0
+ end
+ end
+
+ dataset[#dataset + 1] = {input, output}
+
+ end
+
+ function dataset:size()
+ return #dataset
+ end
+
+ return dataset
+end
+
+local function init_weights(all_symbols, original_symbol_scores)
+
+ local weights = torch.Tensor(#all_symbols)
+
+ local mean = 0
+
+ for i, symbol in pairs(all_symbols) do
+ local score = original_symbol_scores[symbol]
+ if not score then score = 0 end
+ weights[i] = score
+ mean = mean + score
+ end
+
+ return weights
+end
+
+local function shuffle(logs)
+
+ local size = #logs
+ for i = size, 1, -1 do
+ local rand = math.random(size)
+ logs[i], logs[rand] = logs[rand], logs[i]
+ end
+
+end
+
+local function split_logs(logs, split_percent)
+
+ if not split_percent then
+ split_percent = 60
+ end
+
+ local split_index = math.floor(#logs * split_percent / 100)
+
+ local test_logs = {}
+ local train_logs = {}
+
+ for i=1,split_index do
+ train_logs[#train_logs + 1] = logs[i]
+ end
+
+ for i=split_index + 1, #logs do
+ test_logs[#test_logs + 1] = logs[i]
+ end
+
+ return train_logs, test_logs
+end
+
+local function stitch_new_scores(all_symbols, new_scores)
+
+ local new_symbol_scores = {}
+
+ for idx, symbol in pairs(all_symbols) do
+ new_symbol_scores[symbol] = new_scores[idx]
+ end
+
+ return new_symbol_scores
+end
+
+
+local function update_logs(logs, symbol_scores)
+
+ for i, log in ipairs(logs) do
+
+ log = lua_util.rspamd_str_split(log, " ")
+
+ local score = 0
+
+ for j=4,#log do
+ log[j] = log[j]:gsub("%s+", "")
+ score = score + (symbol_scores[log[j ]] or 0)
+ end
+
+ log[2] = rescore_utility.round(score, 2)
+
+ logs[i] = table.concat(log, " ")
+ end
+
+ return logs
+end
+
+local function write_scores(new_symbol_scores, file_path)
+
+ local file = assert(io.open(file_path, "w"))
+
+ local new_scores_ucl = ucl.to_format(new_symbol_scores, "ucl")
+
+ file:write(new_scores_ucl)
+
+ file:close()
+end
+
+local function print_score_diff(new_symbol_scores, original_symbol_scores)
+
+ print(string.format("%-35s %-10s %-10s", "SYMBOL", "OLD_SCORE", "NEW_SCORE"))
+
+ for symbol, new_score in pairs(new_symbol_scores) do
+ print(string.format("%-35s %-10s %-10s",
+ symbol,
+ original_symbol_scores[symbol] or 0,
+ rescore_utility.round(new_score, 2)))
+ end
+
+ print "\nClass changes \n"
+ for symbol, new_score in pairs(new_symbol_scores) do
+ if original_symbol_scores[symbol] ~= nil then
+ if (original_symbol_scores[symbol] > 0 and new_score < 0) or
+ (original_symbol_scores[symbol] < 0 and new_score > 0) then
+ print(string.format("%-35s %-10s %-10s",
+ symbol,
+ original_symbol_scores[symbol] or 0,
+ rescore_utility.round(new_score, 2)))
+ end
+ end
+ end
+
+end
+
+local function calculate_fscore_from_weights(logs, all_symbols, weights, bias, threshold)
+
+ local new_symbol_scores = weights:clone()
+
+ new_symbol_scores = stitch_new_scores(all_symbols, new_symbol_scores)
+
+ logs = update_logs(logs, new_symbol_scores)
+
+ local file_stats, _ = rescore_utility.generate_statistics_from_logs(logs, threshold)
+
+ return file_stats.fscore
+end
+
+local function print_stats(logs, threshold)
+
+ local file_stats, _ = rescore_utility.generate_statistics_from_logs(logs, threshold)
+
+ local file_stat_format = [[
+F-score: %.2f
+False positive rate: %.2f %%
+False negative rate: %.2f %%
+Overall accuracy: %.2f %%
+]]
+
+ io.write("\nStatistics at threshold: " .. threshold .. "\n")
+
+ io.write(string.format(file_stat_format,
+ file_stats.fscore,
+ file_stats.false_positive_rate,
+ file_stats.false_negative_rate,
+ file_stats.overall_accuracy))
+
+end
+
+return function (_, res)
+
+ local logs = rescore_utility.get_all_logs(res["logdir"])
+ local all_symbols = rescore_utility.get_all_symbols(logs)
+ local original_symbol_scores = rescore_utility.get_all_symbol_scores()
+
+ shuffle(logs)
+
+ local train_logs, validation_logs = split_logs(logs, 70)
+ local cv_logs, test_logs = split_logs(validation_logs, 50)
+
+ local dataset = make_dataset_from_logs(train_logs, all_symbols)
+
+ local learning_rates = {0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 2.5, 5, 7.5, 10}
+ local penalty_weights = {0, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 3, 5, 10, 15, 20, 25, 50, 75, 100}
+
+ -- Start of perceptron training
+
+ local input_size = #all_symbols
+ local linear_module = nn.Linear(input_size, 1)
+
+ local perceptron = nn.Sequential()
+ perceptron:add(linear_module)
+
+ local activation = nn.Sigmoid()
+
+ perceptron:add(activation)
+
+ local criterion = nn.MSECriterion()
+ criterion.sizeAverage = false
+
+ local best_fscore = -math.huge
+ local best_weights = linear_module.weight[1]:clone()
+
+ local trainer = nn.StochasticGradient(perceptron, criterion)
+ trainer.maxIteration = res["iters"]
+ trainer.verbose = false
+
+ trainer.hookIteration = function(self, iteration, error)
+
+ if iteration == trainer.maxIteration then
+
+ local fscore = calculate_fscore_from_weights(cv_logs,
+ all_symbols,
+ linear_module.weight[1],
+ linear_module.bias[1],
+ res["threshold"])
+
+ print("Cross-validation fscore: " .. fscore)
+
+ if best_fscore < fscore then
+ best_fscore = fscore
+ best_weights = linear_module.weight[1]:clone()
+ end
+ end
+ end
+
+ for _, learning_rate in pairs(learning_rates) do
+ for _, weight in pairs(penalty_weights) do
+
+ trainer.weightDecay = weight
+ print("Learning with learning_rate: " .. learning_rate
+ .. " | l2_weight: " .. weight)
+
+ linear_module.weight[1] = init_weights(all_symbols, original_symbol_scores)
+
+ trainer.learningRate = learning_rate
+ trainer:train(dataset)
+
+ print()
+ end
+ end
+
+ -- End perceptron training
+
+ local new_symbol_scores = best_weights
+
+ new_symbol_scores = stitch_new_scores(all_symbols, new_symbol_scores)
+
+ if res["output"] then
+ write_scores(new_symbol_scores, res["output"])
+ end
+
+ if res["diff"] then
+ print_score_diff(new_symbol_scores, original_symbol_scores)
+ end
+
+
+ -- Pre-rescore test stats
+ print("\n\nPre-rescore test stats\n")
+ test_logs = update_logs(test_logs, original_symbol_scores)
+ print_stats(test_logs, res['threshold'])
+
+ -- Post-rescore test stats
+ test_logs = update_logs(test_logs, new_symbol_scores)
+ print("\n\nPost-rescore test stats\n")
+ print_stats(test_logs, res['threshold'])
+end \ No newline at end of file
diff --git a/lualib/rspamadm/rescore_utility.lua b/lualib/rspamadm/rescore_utility.lua
new file mode 100644
index 000000000..4c6504e76
--- /dev/null
+++ b/lualib/rspamadm/rescore_utility.lua
@@ -0,0 +1,214 @@
+local ucl = require "ucl"
+local lua_util = require "lua_util"
+local rspamd_util = require "rspamd_util"
+
+local utility = {}
+
+function utility.round(num, places)
+ return string.format("%." .. (places or 0) .. "f", num)
+end
+
+function utility.get_all_symbols(logs)
+ -- Returns a list of all symbols
+
+ local symbols_set = {}
+
+ for _, line in pairs(logs) do
+ line = lua_util.rspamd_str_split(line, " ")
+ for i=4,#line do
+ line[i] = line[i]:gsub("%s+", "")
+ if not symbols_set[line[i]] then
+ symbols_set[line[i]] = true
+ end
+ end
+ end
+
+ local all_symbols = {}
+
+ for symbol, _ in pairs(symbols_set) do
+ all_symbols[#all_symbols + 1] = symbol
+ end
+
+ table.sort(all_symbols)
+
+ return all_symbols
+end
+
+function utility.read_log_file(file)
+
+ local lines = {}
+
+ file = assert(io.open(file, "r"))
+
+ for line in file:lines() do
+ lines[#lines + 1] = line
+ end
+
+ io.close(file)
+
+ return lines
+end
+
+function utility.get_all_logs(dir_path)
+ -- Reads all log files in the directory and returns a list of logs.
+
+ if dir_path:sub(#dir_path, #dir_path) == "/" then
+ dir_path = dir_path:sub(1, #dir_path -1)
+ end
+
+ local files = rspamd_util.glob(dir_path .. "/*")
+ local all_logs = {}
+
+ for _, file in pairs(files) do
+ local logs = utility.read_log_file(file)
+ for _, log_line in pairs(logs) do
+ all_logs[#all_logs + 1] = log_line
+ end
+ end
+
+ return all_logs
+end
+
+function utility.get_all_symbol_scores()
+
+ local output = assert(io.popen("rspamc counters -j --compact"))
+ output = output:read("*all")
+
+ local parser = ucl.parser()
+ local result, err = parser:parse_string(output)
+
+ if not result then
+ print(err)
+ os.exit()
+ end
+
+ output = parser:get_object()
+
+ local symbol_scores = {}
+
+ for _, symbol_info in pairs(output) do
+ symbol_scores[symbol_info.symbol] = symbol_info.weight
+ end
+
+ return symbol_scores
+end
+
+function utility.generate_statistics_from_logs(logs, threshold)
+
+ -- Returns file_stats table and list of symbol_stats table.
+
+ local file_stats = {
+ no_of_emails = 0,
+ no_of_spam = 0,
+ no_of_ham = 0,
+ spam_percent = 0,
+ ham_percent = 0,
+ true_positives = 0,
+ true_negatives = 0,
+ false_negative_rate = 0,
+ false_positive_rate = 0,
+ overall_accuracy = 0,
+ fscore = 0
+ }
+
+ local all_symbols_stats = {}
+
+ local false_positives = 0
+ local false_negatives = 0
+ local true_positives = 0
+ local true_negatives = 0
+ local no_of_emails = 0
+ local no_of_spam = 0
+ local no_of_ham = 0
+
+ for _, log in pairs(logs) do
+ log = lua_util.rspamd_str_trim(log)
+ log = lua_util.rspamd_str_split(log, " ")
+
+ local is_spam = (log[1] == "SPAM")
+ local score = tonumber(log[2])
+
+ no_of_emails = no_of_emails + 1
+
+ if is_spam then
+ no_of_spam = no_of_spam + 1
+ else
+ no_of_ham = no_of_ham + 1
+ end
+
+ if is_spam and (score >= threshold) then
+ true_positives = true_positives + 1
+ elseif is_spam and (score < threshold) then
+ false_negatives = false_negatives + 1
+ elseif not is_spam and (score >= threshold) then
+ false_positives = false_positives + 1
+ else
+ true_negatives = true_negatives + 1
+ end
+
+ for i=4, #log do
+ if all_symbols_stats[log[i]] == nil then
+ all_symbols_stats[log[i]] = {
+ name = log[i],
+ no_of_hits = 0,
+ spam_hits = 0,
+ ham_hits = 0,
+ spam_overall = 0
+ }
+ end
+
+ all_symbols_stats[log[i]].no_of_hits =
+ all_symbols_stats[log[i]].no_of_hits + 1
+
+ if is_spam then
+ all_symbols_stats[log[i]].spam_hits =
+ all_symbols_stats[log[i]].spam_hits + 1
+ else
+ all_symbols_stats[log[i]].ham_hits =
+ all_symbols_stats[log[i]].ham_hits + 1
+ end
+ end
+ end
+
+ -- Calculating file stats
+
+ file_stats.no_of_ham = no_of_ham
+ file_stats.no_of_spam = no_of_spam
+ file_stats.no_of_emails = no_of_emails
+ file_stats.true_positives = true_positives
+ file_stats.true_negatives = true_negatives
+
+ if no_of_emails > 0 then
+ file_stats.spam_percent = no_of_spam * 100 / no_of_emails
+ file_stats.ham_percent = no_of_ham * 100 / no_of_emails
+ file_stats.overall_accuracy = (true_positives + true_negatives) * 100 /
+ no_of_emails
+ end
+
+ if no_of_ham > 0 then
+ file_stats.false_positive_rate = false_positives * 100 / no_of_ham
+ end
+
+ if no_of_spam > 0 then
+ file_stats.false_negative_rate = false_negatives * 100 / no_of_spam
+ end
+
+ file_stats.fscore = 2 * true_positives / (2
+ * true_positives
+ + false_positives
+ + false_negatives)
+
+ -- Calculating symbol stats
+
+ for _, symbol_stats in pairs(all_symbols_stats) do
+ symbol_stats.spam_percent = symbol_stats.spam_hits * 100 / no_of_spam
+ symbol_stats.ham_percent = symbol_stats.ham_hits * 100 / no_of_ham
+ symbol_stats.overall = symbol_stats.no_of_hits * 100 / no_of_emails
+ symbol_stats.spam_overall = symbol_stats.spam_percent /
+ (symbol_stats.spam_percent + symbol_stats.ham_percent)
+ end
+
+ return file_stats, all_symbols_stats
+end
+
+return utility
diff --git a/src/rspamadm/CMakeLists.txt b/src/rspamadm/CMakeLists.txt
index 7dfaad691..fb3f25229 100644
--- a/src/rspamadm/CMakeLists.txt
+++ b/src/rspamadm/CMakeLists.txt
@@ -10,10 +10,12 @@ SET(RSPAMADMSRC rspamadm.c
control.c
confighelp.c
configwizard.c
+ corpus_test.c
stat_convert.c
signtool.c
lua_repl.c
dkim_keygen.c
+ rescore.c
${CMAKE_BINARY_DIR}/src/workers.c
${CMAKE_BINARY_DIR}/src/modules.c
${CMAKE_SOURCE_DIR}/src/controller.c
diff --git a/src/rspamadm/commands.c b/src/rspamadm/commands.c
index 1eaa45c81..410306fe3 100644
--- a/src/rspamadm/commands.c
+++ b/src/rspamadm/commands.c
@@ -29,6 +29,8 @@ extern struct rspamadm_command signtool_command;
extern struct rspamadm_command lua_command;
extern struct rspamadm_command dkim_keygen_command;
extern struct rspamadm_command configwizard_command;
+extern struct rspamadm_command corpus_test_command;
+extern struct rspamadm_command rescore_command;
const struct rspamadm_command *commands[] = {
&help_command,
@@ -46,6 +48,8 @@ const struct rspamadm_command *commands[] = {
&lua_command,
&dkim_keygen_command,
&configwizard_command,
+ &corpus_test_command,
+ &rescore_command,
NULL
};
diff --git a/src/rspamadm/corpus_test.c b/src/rspamadm/corpus_test.c
new file mode 100644
index 000000000..62aecb148
--- /dev/null
+++ b/src/rspamadm/corpus_test.c
@@ -0,0 +1,121 @@
+/*-
+ * Copyright 2017 Pragadeesh C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "rspamadm.h"
+#include "config.h"
+#include "lua/lua_common.h"
+
+static gchar *ham_directory = NULL;
+static gchar *spam_directory = NULL;
+static gchar *output_location = "results.log";
+static gint connections = 10;
+
+static void rspamadm_corpus_test (gint argc, gchar **argv);
+static const char *rspamadm_corpus_test_help (gboolean full_help);
+
+struct rspamadm_command corpus_test_command = {
+ .name = "corpus_test",
+ .flags = 0,
+ .help = rspamadm_corpus_test_help,
+ .run = rspamadm_corpus_test
+};
+
+// TODO add -nparellel and -o options
+static GOptionEntry entries[] = {
+ {"ham", 'a', 0, G_OPTION_ARG_FILENAME, &ham_directory,
+ "Ham directory", NULL},
+ {"spam", 's', 0, G_OPTION_ARG_FILENAME, &spam_directory,
+ "Spam directory", NULL},
+ {"output", 'o', 0, G_OPTION_ARG_FILENAME, &output_location,
+ "Log output location", NULL},
+ {"connections", 'n', 0, G_OPTION_ARG_INT, &connections,
+ "Number of parellel connections [Default: 10]", NULL},
+ {NULL, 0, 0, G_OPTION_ARG_NONE, NULL, NULL, NULL}
+};
+
+static const char *
+rspamadm_corpus_test_help (gboolean full_help)
+{
+ const char *help_str;
+
+ if (full_help) {
+ help_str = "Create logs files from email corpus\n\n"
+ "Usage: rspamadm corpus_test [-a <ham_directory>]"
+ " [-s <spam_directory>]\n"
+ "Where option are:\n\n"
+ "-a: path to ham directory\n"
+ "-s: path to spam directory\n"
+ "-n: maximum parellel connections\n"
+ "-o: log output file\n";
+
+ }
+
+ else {
+ help_str = "Create logs files from email corpus";
+ }
+
+ return help_str;
+}
+
+static void
+rspamadm_corpus_test (gint argc, gchar **argv)
+{
+ GOptionContext *context;
+ GError *error = NULL;
+ lua_State *L;
+ ucl_object_t *obj;
+
+ context = g_option_context_new (
+ "corpus_test - Create logs files from email corpus");
+
+ g_option_context_set_summary (context,
+ "Summary:\n Rspamd administration utility version "
+ RVERSION
+ "\n Release id: "
+ RID);
+
+ g_option_context_add_main_entries (context, entries, NULL);
+ g_option_context_set_ignore_unknown_options (context, TRUE);
+
+ if (!g_option_context_parse (context, &argc, &argv, &error)) {
+ rspamd_fprintf (stderr, "option parsing failed: %s\n", error->message);
+ g_error_free (error);
+ exit(1);
+ }
+
+ L = rspamd_lua_init ();
+ rspamd_lua_set_path(L, NULL, NULL);
+
+
+ obj = ucl_object_typed_new (UCL_OBJECT);
+ ucl_object_insert_key (obj, ucl_object_fromstring (ham_directory),
+ "ham_directory", 0, false);
+ ucl_object_insert_key (obj, ucl_object_fromstring (spam_directory),
+ "spam_directory", 0, false);
+ ucl_object_insert_key (obj, ucl_object_fromstring (output_location),
+ "output_location", 0, false);
+ ucl_object_insert_key (obj, ucl_object_fromint (connections),
+ "connections", 0, false);
+
+ rspamadm_execute_lua_ucl_subr (L,
+ argc,
+ argv,
+ obj,
+ "corpus_test");
+
+ lua_close (L);
+ ucl_object_unref (obj);
+}
diff --git a/src/rspamadm/rescore.c b/src/rspamadm/rescore.c
new file mode 100644
index 000000000..de5ace272
--- /dev/null
+++ b/src/rspamadm/rescore.c
@@ -0,0 +1,141 @@
+/*-
+ * Copyright 2017 Pragadeesh C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "rspamadm.h"
+#include "config.h"
+#include "lua/lua_common.h"
+
+#if !defined(WITH_TORCH) || !defined(WITH_LUAJIT)
+ #define HAS_TORCH false
+#else
+ #define HAS_TORCH true
+#endif
+
+static gchar *logdir = NULL;
+static gchar *output = "new.scores";
+static gdouble threshold = 15; // Spam threshold
+static gboolean score_diff = false; // Print score diff flag
+static gint64 iters = 500; // Perceptron max iterations
+
+static void rspamadm_rescore (gint argc, gchar **argv);
+static const char *rspamadm_rescore_help (gboolean full_help);
+
+struct rspamadm_command rescore_command = {
+ .name = "rescore",
+ .flags = 0,
+ .help = rspamadm_rescore_help,
+ .run = rspamadm_rescore
+};
+
+static GOptionEntry entries[] = {
+ {"logdir", 'l', 0, G_OPTION_ARG_FILENAME, &logdir,
+ "Logs directory", NULL},
+ {"output", 'o', 0, G_OPTION_ARG_FILENAME, &output,
+ "Scores output locaiton", NULL},
+ {"diff", 'd', 0, G_OPTION_ARG_NONE, &score_diff,
+ "Print score diff", NULL},
+ {"iters", 'i', 0, G_OPTION_ARG_INT64, &iters,
+ "Max iterations for perceptron [Default: 500]", NULL},
+ {NULL, 0, 0, G_OPTION_ARG_NONE, NULL, NULL, NULL}
+};
+
+static const char *
+rspamadm_rescore_help (gboolean full_help)
+{
+
+ const char *help_str;
+
+ if (full_help) {
+ help_str = "Estimate optimal symbol weights from log files\n\n"
+ "Usage: rspamadm rescore -l <log_directory>\n"
+ "Where options are:\n\n"
+ "-l: path to logs directory\n"
+ "-o: Scores output file location\n"
+ "-d: Print scores diff\n"
+ "-i: Max iterations for perceptron\n";
+ }
+
+ else {
+ help_str = "Estimate optimal symbol weights from log files";
+ }
+
+ return help_str;
+}
+
+static void
+rspamadm_rescore (gint argc, gchar **argv)
+{
+
+ GOptionContext *context;
+ GError *error = NULL;
+ lua_State *L;
+ ucl_object_t *obj;
+
+ context = g_option_context_new (
+ "rescore - Estimate optimal symbol weights from log files");
+
+ g_option_context_set_summary (context,
+ "Summary:\n Rspamd administration utility version "
+ RVERSION
+ "\n Release id: "
+ RID);
+
+ g_option_context_add_main_entries (context, entries, NULL);
+ g_option_context_set_ignore_unknown_options (context, TRUE);
+
+ if (!g_option_context_parse (context, &argc, &argv, &error)) {
+ rspamd_fprintf (stderr, "option parsing failed: %s\n", error->message);
+ g_error_free (error);
+ exit(1);
+ }
+
+ if (!HAS_TORCH) {
+ rspamd_fprintf (stderr, "Torch is not enabled. "
+ "Use -DENABLE_TORCH=ON option while running cmake.\n");
+ exit (1);
+ }
+
+ if (logdir == NULL) {
+ rspamd_fprintf (stderr, "Please specify log directory.\n");
+ exit (1);
+ }
+
+ L = rspamd_lua_init ();
+
+ rspamd_lua_set_path(L, NULL, NULL);
+
+ obj = ucl_object_typed_new (UCL_OBJECT);
+
+ ucl_object_insert_key (obj, ucl_object_fromstring (logdir),
+ "logdir", 0, false);
+ ucl_object_insert_key (obj, ucl_object_fromstring (output),
+ "output", 0, false);
+ ucl_object_insert_key (obj, ucl_object_fromdouble (threshold),
+ "threshold", 0, false);
+ ucl_object_insert_key (obj, ucl_object_fromint (iters),
+ "iters", 0, false);
+ ucl_object_insert_key (obj, ucl_object_frombool (score_diff),
+ "diff", 0, false);
+
+ rspamadm_execute_lua_ucl_subr (L,
+ argc,
+ argv,
+ obj,
+ "rescore");
+
+ lua_close (L);
+ ucl_object_unref (obj);
+} \ No newline at end of file