From 703bd13d5bedc30ed9bbeb7180d3cd083fc0e1f4 Mon Sep 17 00:00:00 2001 From: Pragadeesh C Date: Thu, 1 Jun 2017 16:07:28 -0700 Subject: [PATCH] added corpus_test, rescore commands --- lualib/rspamadm/corpus_test.lua | 126 ++++++++++++ lualib/rspamadm/rescore.lua | 298 ++++++++++++++++++++++++++++ lualib/rspamadm/rescore_utility.lua | 214 ++++++++++++++++++++ src/rspamadm/CMakeLists.txt | 2 + src/rspamadm/commands.c | 4 + src/rspamadm/corpus_test.c | 121 +++++++++++ src/rspamadm/rescore.c | 141 +++++++++++++ 7 files changed, 906 insertions(+) create mode 100644 lualib/rspamadm/corpus_test.lua create mode 100644 lualib/rspamadm/rescore.lua create mode 100644 lualib/rspamadm/rescore_utility.lua create mode 100644 src/rspamadm/corpus_test.c create mode 100644 src/rspamadm/rescore.c diff --git a/lualib/rspamadm/corpus_test.lua b/lualib/rspamadm/corpus_test.lua new file mode 100644 index 000000000..b29fa5602 --- /dev/null +++ b/lualib/rspamadm/corpus_test.lua @@ -0,0 +1,126 @@ +local ucl = require "ucl" +local lua_util = require "lua_util" + +local HAM = "HAM" +local SPAM = "SPAM" + +local function scan_email(n_parellel, path) + + local rspamc_command = string.format("rspamc -j --compact -n %s %s", n_parellel, path) + local result = assert(io.popen(rspamc_command)) + result = result:read("*all") + return result +end + +local function write_results(results, file) + + local f = io.open(file, 'w') + + for _, result in pairs(results) do + local log_line = string.format("%s %.2f %s", result.type, result.score, result.action) + + for _, sym in pairs(result.symbols) do + log_line = log_line .. " " .. sym + end + + log_line = log_line .. "\r\n" + + f:write(log_line) + end + + f:close() +end + +local function encoded_json_to_log(result) + -- Returns table containing score, action, list of symbols + + local filtered_result = {} + local parser = ucl.parser() + + local is_good, err = parser:parse_string(result) + + if not is_good then + print(err) + os.exit() + end + + result = parser:get_object() + + filtered_result.score = result.score + local action = result.action:gsub("%s+", "_") + filtered_result.action = action + + filtered_result.symbols = {} + + for sym, _ in pairs(result.symbols) do + table.insert(filtered_result.symbols, sym) + end + + return filtered_result +end + +local function scan_results_to_logs(results, actual_email_type) + + local logs = {} + + results = lua_util.rspamd_str_split(results, "\n") + + if results[#results] == "" then + results[#results] = nil + end + + for _, result in pairs(results) do + result = encoded_json_to_log(result) + result['type'] = actual_email_type + table.insert(logs, result) + end + + return logs +end + +return function (_, res) + + local ham_directory = res['ham_directory'] + local spam_directory = res['spam_directory'] + local connections = res["connections"] + local output = res["output_location"] + + local results = {} + + local start_time = os.time() + local no_of_ham = 0 + local no_of_spam = 0 + + if ham_directory then + io.write("Scanning ham corpus...\n") + local ham_results = scan_email(connections, ham_directory) + ham_results = scan_results_to_logs(ham_results, HAM) + + no_of_ham = #ham_results + + for _, result in pairs(ham_results) do + table.insert(results, result) + end + end + + if spam_directory then + io.write("Scanning spam corpus...\n") + local spam_results = scan_email(connections, spam_directory) + spam_results = scan_results_to_logs(spam_results, SPAM) + + no_of_spam = #spam_results + + for _, result in pairs(spam_results) do + table.insert(results, result) + end + end + + io.write(string.format("Writing results to %s\n", output)) + write_results(results, output) + + io.write("\nStats: \n") + io.write(string.format("Elapsed time: %ds\n", os.time() - start_time)) + io.write(string.format("No of ham: %d\n", no_of_ham)) + io.write(string.format("No of spam: %d\n", no_of_spam)) + +end \ No newline at end of file diff --git a/lualib/rspamadm/rescore.lua b/lualib/rspamadm/rescore.lua new file mode 100644 index 000000000..538122f68 --- /dev/null +++ b/lualib/rspamadm/rescore.lua @@ -0,0 +1,298 @@ +local torch = require "torch" +local nn = require "nn" +local lua_util = require "lua_util" +local ucl = require "ucl" + +local rescore_utility = require "rspamadm/rescore_utility" + +local function make_dataset_from_logs(logs, all_symbols) + -- Returns a list of {input, output} for torch SGD train + + local dataset = {} + + for _, log in pairs(logs) do + local input = torch.Tensor(#all_symbols) + local output = torch.Tensor(1) + log = lua_util.rspamd_str_split(log, " ") + + if log[1] == "SPAM" then + output[1] = 1 + else + output[1] = 0 + end + + local symbols_set = {} + + for i=4,#log do + symbols_set[log[i]] = true + end + + for index, symbol in pairs(all_symbols) do + if symbols_set[symbol] then + input[index] = 1 + else + input[index] = 0 + end + end + + dataset[#dataset + 1] = {input, output} + + end + + function dataset:size() + return #dataset + end + + return dataset +end + +local function init_weights(all_symbols, original_symbol_scores) + + local weights = torch.Tensor(#all_symbols) + + local mean = 0 + + for i, symbol in pairs(all_symbols) do + local score = original_symbol_scores[symbol] + if not score then score = 0 end + weights[i] = score + mean = mean + score + end + + return weights +end + +local function shuffle(logs) + + local size = #logs + for i = size, 1, -1 do + local rand = math.random(size) + logs[i], logs[rand] = logs[rand], logs[i] + end + +end + +local function split_logs(logs, split_percent) + + if not split_percent then + split_percent = 60 + end + + local split_index = math.floor(#logs * split_percent / 100) + + local test_logs = {} + local train_logs = {} + + for i=1,split_index do + train_logs[#train_logs + 1] = logs[i] + end + + for i=split_index + 1, #logs do + test_logs[#test_logs + 1] = logs[i] + end + + return train_logs, test_logs +end + +local function stitch_new_scores(all_symbols, new_scores) + + local new_symbol_scores = {} + + for idx, symbol in pairs(all_symbols) do + new_symbol_scores[symbol] = new_scores[idx] + end + + return new_symbol_scores +end + + +local function update_logs(logs, symbol_scores) + + for i, log in ipairs(logs) do + + log = lua_util.rspamd_str_split(log, " ") + + local score = 0 + + for j=4,#log do + log[j] = log[j]:gsub("%s+", "") + score = score + (symbol_scores[log[j ]] or 0) + end + + log[2] = rescore_utility.round(score, 2) + + logs[i] = table.concat(log, " ") + end + + return logs +end + +local function write_scores(new_symbol_scores, file_path) + + local file = assert(io.open(file_path, "w")) + + local new_scores_ucl = ucl.to_format(new_symbol_scores, "ucl") + + file:write(new_scores_ucl) + + file:close() +end + +local function print_score_diff(new_symbol_scores, original_symbol_scores) + + print(string.format("%-35s %-10s %-10s", "SYMBOL", "OLD_SCORE", "NEW_SCORE")) + + for symbol, new_score in pairs(new_symbol_scores) do + print(string.format("%-35s %-10s %-10s", + symbol, + original_symbol_scores[symbol] or 0, + rescore_utility.round(new_score, 2))) + end + + print "\nClass changes \n" + for symbol, new_score in pairs(new_symbol_scores) do + if original_symbol_scores[symbol] ~= nil then + if (original_symbol_scores[symbol] > 0 and new_score < 0) or + (original_symbol_scores[symbol] < 0 and new_score > 0) then + print(string.format("%-35s %-10s %-10s", + symbol, + original_symbol_scores[symbol] or 0, + rescore_utility.round(new_score, 2))) + end + end + end + +end + +local function calculate_fscore_from_weights(logs, all_symbols, weights, bias, threshold) + + local new_symbol_scores = weights:clone() + + new_symbol_scores = stitch_new_scores(all_symbols, new_symbol_scores) + + logs = update_logs(logs, new_symbol_scores) + + local file_stats, _ = rescore_utility.generate_statistics_from_logs(logs, threshold) + + return file_stats.fscore +end + +local function print_stats(logs, threshold) + + local file_stats, _ = rescore_utility.generate_statistics_from_logs(logs, threshold) + + local file_stat_format = [[ +F-score: %.2f +False positive rate: %.2f %% +False negative rate: %.2f %% +Overall accuracy: %.2f %% +]] + + io.write("\nStatistics at threshold: " .. threshold .. "\n") + + io.write(string.format(file_stat_format, + file_stats.fscore, + file_stats.false_positive_rate, + file_stats.false_negative_rate, + file_stats.overall_accuracy)) + +end + +return function (_, res) + + local logs = rescore_utility.get_all_logs(res["logdir"]) + local all_symbols = rescore_utility.get_all_symbols(logs) + local original_symbol_scores = rescore_utility.get_all_symbol_scores() + + shuffle(logs) + + local train_logs, validation_logs = split_logs(logs, 70) + local cv_logs, test_logs = split_logs(validation_logs, 50) + + local dataset = make_dataset_from_logs(train_logs, all_symbols) + + local learning_rates = {0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 2.5, 5, 7.5, 10} + local penalty_weights = {0, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 3, 5, 10, 15, 20, 25, 50, 75, 100} + + -- Start of perceptron training + + local input_size = #all_symbols + local linear_module = nn.Linear(input_size, 1) + + local perceptron = nn.Sequential() + perceptron:add(linear_module) + + local activation = nn.Sigmoid() + + perceptron:add(activation) + + local criterion = nn.MSECriterion() + criterion.sizeAverage = false + + local best_fscore = -math.huge + local best_weights = linear_module.weight[1]:clone() + + local trainer = nn.StochasticGradient(perceptron, criterion) + trainer.maxIteration = res["iters"] + trainer.verbose = false + + trainer.hookIteration = function(self, iteration, error) + + if iteration == trainer.maxIteration then + + local fscore = calculate_fscore_from_weights(cv_logs, + all_symbols, + linear_module.weight[1], + linear_module.bias[1], + res["threshold"]) + + print("Cross-validation fscore: " .. fscore) + + if best_fscore < fscore then + best_fscore = fscore + best_weights = linear_module.weight[1]:clone() + end + end + end + + for _, learning_rate in pairs(learning_rates) do + for _, weight in pairs(penalty_weights) do + + trainer.weightDecay = weight + print("Learning with learning_rate: " .. learning_rate + .. " | l2_weight: " .. weight) + + linear_module.weight[1] = init_weights(all_symbols, original_symbol_scores) + + trainer.learningRate = learning_rate + trainer:train(dataset) + + print() + end + end + + -- End perceptron training + + local new_symbol_scores = best_weights + + new_symbol_scores = stitch_new_scores(all_symbols, new_symbol_scores) + + if res["output"] then + write_scores(new_symbol_scores, res["output"]) + end + + if res["diff"] then + print_score_diff(new_symbol_scores, original_symbol_scores) + end + + + -- Pre-rescore test stats + print("\n\nPre-rescore test stats\n") + test_logs = update_logs(test_logs, original_symbol_scores) + print_stats(test_logs, res['threshold']) + + -- Post-rescore test stats + test_logs = update_logs(test_logs, new_symbol_scores) + print("\n\nPost-rescore test stats\n") + print_stats(test_logs, res['threshold']) +end \ No newline at end of file diff --git a/lualib/rspamadm/rescore_utility.lua b/lualib/rspamadm/rescore_utility.lua new file mode 100644 index 000000000..4c6504e76 --- /dev/null +++ b/lualib/rspamadm/rescore_utility.lua @@ -0,0 +1,214 @@ +local ucl = require "ucl" +local lua_util = require "lua_util" +local rspamd_util = require "rspamd_util" + +local utility = {} + +function utility.round(num, places) + return string.format("%." .. (places or 0) .. "f", num) +end + +function utility.get_all_symbols(logs) + -- Returns a list of all symbols + + local symbols_set = {} + + for _, line in pairs(logs) do + line = lua_util.rspamd_str_split(line, " ") + for i=4,#line do + line[i] = line[i]:gsub("%s+", "") + if not symbols_set[line[i]] then + symbols_set[line[i]] = true + end + end + end + + local all_symbols = {} + + for symbol, _ in pairs(symbols_set) do + all_symbols[#all_symbols + 1] = symbol + end + + table.sort(all_symbols) + + return all_symbols +end + +function utility.read_log_file(file) + + local lines = {} + + file = assert(io.open(file, "r")) + + for line in file:lines() do + lines[#lines + 1] = line + end + + io.close(file) + + return lines +end + +function utility.get_all_logs(dir_path) + -- Reads all log files in the directory and returns a list of logs. + + if dir_path:sub(#dir_path, #dir_path) == "/" then + dir_path = dir_path:sub(1, #dir_path -1) + end + + local files = rspamd_util.glob(dir_path .. "/*") + local all_logs = {} + + for _, file in pairs(files) do + local logs = utility.read_log_file(file) + for _, log_line in pairs(logs) do + all_logs[#all_logs + 1] = log_line + end + end + + return all_logs +end + +function utility.get_all_symbol_scores() + + local output = assert(io.popen("rspamc counters -j --compact")) + output = output:read("*all") + + local parser = ucl.parser() + local result, err = parser:parse_string(output) + + if not result then + print(err) + os.exit() + end + + output = parser:get_object() + + local symbol_scores = {} + + for _, symbol_info in pairs(output) do + symbol_scores[symbol_info.symbol] = symbol_info.weight + end + + return symbol_scores +end + +function utility.generate_statistics_from_logs(logs, threshold) + + -- Returns file_stats table and list of symbol_stats table. + + local file_stats = { + no_of_emails = 0, + no_of_spam = 0, + no_of_ham = 0, + spam_percent = 0, + ham_percent = 0, + true_positives = 0, + true_negatives = 0, + false_negative_rate = 0, + false_positive_rate = 0, + overall_accuracy = 0, + fscore = 0 + } + + local all_symbols_stats = {} + + local false_positives = 0 + local false_negatives = 0 + local true_positives = 0 + local true_negatives = 0 + local no_of_emails = 0 + local no_of_spam = 0 + local no_of_ham = 0 + + for _, log in pairs(logs) do + log = lua_util.rspamd_str_trim(log) + log = lua_util.rspamd_str_split(log, " ") + + local is_spam = (log[1] == "SPAM") + local score = tonumber(log[2]) + + no_of_emails = no_of_emails + 1 + + if is_spam then + no_of_spam = no_of_spam + 1 + else + no_of_ham = no_of_ham + 1 + end + + if is_spam and (score >= threshold) then + true_positives = true_positives + 1 + elseif is_spam and (score < threshold) then + false_negatives = false_negatives + 1 + elseif not is_spam and (score >= threshold) then + false_positives = false_positives + 1 + else + true_negatives = true_negatives + 1 + end + + for i=4, #log do + if all_symbols_stats[log[i]] == nil then + all_symbols_stats[log[i]] = { + name = log[i], + no_of_hits = 0, + spam_hits = 0, + ham_hits = 0, + spam_overall = 0 + } + end + + all_symbols_stats[log[i]].no_of_hits = + all_symbols_stats[log[i]].no_of_hits + 1 + + if is_spam then + all_symbols_stats[log[i]].spam_hits = + all_symbols_stats[log[i]].spam_hits + 1 + else + all_symbols_stats[log[i]].ham_hits = + all_symbols_stats[log[i]].ham_hits + 1 + end + end + end + + -- Calculating file stats + + file_stats.no_of_ham = no_of_ham + file_stats.no_of_spam = no_of_spam + file_stats.no_of_emails = no_of_emails + file_stats.true_positives = true_positives + file_stats.true_negatives = true_negatives + + if no_of_emails > 0 then + file_stats.spam_percent = no_of_spam * 100 / no_of_emails + file_stats.ham_percent = no_of_ham * 100 / no_of_emails + file_stats.overall_accuracy = (true_positives + true_negatives) * 100 / + no_of_emails + end + + if no_of_ham > 0 then + file_stats.false_positive_rate = false_positives * 100 / no_of_ham + end + + if no_of_spam > 0 then + file_stats.false_negative_rate = false_negatives * 100 / no_of_spam + end + + file_stats.fscore = 2 * true_positives / (2 + * true_positives + + false_positives + + false_negatives) + + -- Calculating symbol stats + + for _, symbol_stats in pairs(all_symbols_stats) do + symbol_stats.spam_percent = symbol_stats.spam_hits * 100 / no_of_spam + symbol_stats.ham_percent = symbol_stats.ham_hits * 100 / no_of_ham + symbol_stats.overall = symbol_stats.no_of_hits * 100 / no_of_emails + symbol_stats.spam_overall = symbol_stats.spam_percent / + (symbol_stats.spam_percent + symbol_stats.ham_percent) + end + + return file_stats, all_symbols_stats +end + +return utility diff --git a/src/rspamadm/CMakeLists.txt b/src/rspamadm/CMakeLists.txt index 7dfaad691..fb3f25229 100644 --- a/src/rspamadm/CMakeLists.txt +++ b/src/rspamadm/CMakeLists.txt @@ -10,10 +10,12 @@ SET(RSPAMADMSRC rspamadm.c control.c confighelp.c configwizard.c + corpus_test.c stat_convert.c signtool.c lua_repl.c dkim_keygen.c + rescore.c ${CMAKE_BINARY_DIR}/src/workers.c ${CMAKE_BINARY_DIR}/src/modules.c ${CMAKE_SOURCE_DIR}/src/controller.c diff --git a/src/rspamadm/commands.c b/src/rspamadm/commands.c index 1eaa45c81..410306fe3 100644 --- a/src/rspamadm/commands.c +++ b/src/rspamadm/commands.c @@ -29,6 +29,8 @@ extern struct rspamadm_command signtool_command; extern struct rspamadm_command lua_command; extern struct rspamadm_command dkim_keygen_command; extern struct rspamadm_command configwizard_command; +extern struct rspamadm_command corpus_test_command; +extern struct rspamadm_command rescore_command; const struct rspamadm_command *commands[] = { &help_command, @@ -46,6 +48,8 @@ const struct rspamadm_command *commands[] = { &lua_command, &dkim_keygen_command, &configwizard_command, + &corpus_test_command, + &rescore_command, NULL }; diff --git a/src/rspamadm/corpus_test.c b/src/rspamadm/corpus_test.c new file mode 100644 index 000000000..62aecb148 --- /dev/null +++ b/src/rspamadm/corpus_test.c @@ -0,0 +1,121 @@ +/*- + * Copyright 2017 Pragadeesh C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "rspamadm.h" +#include "config.h" +#include "lua/lua_common.h" + +static gchar *ham_directory = NULL; +static gchar *spam_directory = NULL; +static gchar *output_location = "results.log"; +static gint connections = 10; + +static void rspamadm_corpus_test (gint argc, gchar **argv); +static const char *rspamadm_corpus_test_help (gboolean full_help); + +struct rspamadm_command corpus_test_command = { + .name = "corpus_test", + .flags = 0, + .help = rspamadm_corpus_test_help, + .run = rspamadm_corpus_test +}; + +// TODO add -nparellel and -o options +static GOptionEntry entries[] = { + {"ham", 'a', 0, G_OPTION_ARG_FILENAME, &ham_directory, + "Ham directory", NULL}, + {"spam", 's', 0, G_OPTION_ARG_FILENAME, &spam_directory, + "Spam directory", NULL}, + {"output", 'o', 0, G_OPTION_ARG_FILENAME, &output_location, + "Log output location", NULL}, + {"connections", 'n', 0, G_OPTION_ARG_INT, &connections, + "Number of parellel connections [Default: 10]", NULL}, + {NULL, 0, 0, G_OPTION_ARG_NONE, NULL, NULL, NULL} +}; + +static const char * +rspamadm_corpus_test_help (gboolean full_help) +{ + const char *help_str; + + if (full_help) { + help_str = "Create logs files from email corpus\n\n" + "Usage: rspamadm corpus_test [-a ]" + " [-s ]\n" + "Where option are:\n\n" + "-a: path to ham directory\n" + "-s: path to spam directory\n" + "-n: maximum parellel connections\n" + "-o: log output file\n"; + + } + + else { + help_str = "Create logs files from email corpus"; + } + + return help_str; +} + +static void +rspamadm_corpus_test (gint argc, gchar **argv) +{ + GOptionContext *context; + GError *error = NULL; + lua_State *L; + ucl_object_t *obj; + + context = g_option_context_new ( + "corpus_test - Create logs files from email corpus"); + + g_option_context_set_summary (context, + "Summary:\n Rspamd administration utility version " + RVERSION + "\n Release id: " + RID); + + g_option_context_add_main_entries (context, entries, NULL); + g_option_context_set_ignore_unknown_options (context, TRUE); + + if (!g_option_context_parse (context, &argc, &argv, &error)) { + rspamd_fprintf (stderr, "option parsing failed: %s\n", error->message); + g_error_free (error); + exit(1); + } + + L = rspamd_lua_init (); + rspamd_lua_set_path(L, NULL, NULL); + + + obj = ucl_object_typed_new (UCL_OBJECT); + ucl_object_insert_key (obj, ucl_object_fromstring (ham_directory), + "ham_directory", 0, false); + ucl_object_insert_key (obj, ucl_object_fromstring (spam_directory), + "spam_directory", 0, false); + ucl_object_insert_key (obj, ucl_object_fromstring (output_location), + "output_location", 0, false); + ucl_object_insert_key (obj, ucl_object_fromint (connections), + "connections", 0, false); + + rspamadm_execute_lua_ucl_subr (L, + argc, + argv, + obj, + "corpus_test"); + + lua_close (L); + ucl_object_unref (obj); +} diff --git a/src/rspamadm/rescore.c b/src/rspamadm/rescore.c new file mode 100644 index 000000000..de5ace272 --- /dev/null +++ b/src/rspamadm/rescore.c @@ -0,0 +1,141 @@ +/*- + * Copyright 2017 Pragadeesh C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "rspamadm.h" +#include "config.h" +#include "lua/lua_common.h" + +#if !defined(WITH_TORCH) || !defined(WITH_LUAJIT) + #define HAS_TORCH false +#else + #define HAS_TORCH true +#endif + +static gchar *logdir = NULL; +static gchar *output = "new.scores"; +static gdouble threshold = 15; // Spam threshold +static gboolean score_diff = false; // Print score diff flag +static gint64 iters = 500; // Perceptron max iterations + +static void rspamadm_rescore (gint argc, gchar **argv); +static const char *rspamadm_rescore_help (gboolean full_help); + +struct rspamadm_command rescore_command = { + .name = "rescore", + .flags = 0, + .help = rspamadm_rescore_help, + .run = rspamadm_rescore +}; + +static GOptionEntry entries[] = { + {"logdir", 'l', 0, G_OPTION_ARG_FILENAME, &logdir, + "Logs directory", NULL}, + {"output", 'o', 0, G_OPTION_ARG_FILENAME, &output, + "Scores output locaiton", NULL}, + {"diff", 'd', 0, G_OPTION_ARG_NONE, &score_diff, + "Print score diff", NULL}, + {"iters", 'i', 0, G_OPTION_ARG_INT64, &iters, + "Max iterations for perceptron [Default: 500]", NULL}, + {NULL, 0, 0, G_OPTION_ARG_NONE, NULL, NULL, NULL} +}; + +static const char * +rspamadm_rescore_help (gboolean full_help) +{ + + const char *help_str; + + if (full_help) { + help_str = "Estimate optimal symbol weights from log files\n\n" + "Usage: rspamadm rescore -l \n" + "Where options are:\n\n" + "-l: path to logs directory\n" + "-o: Scores output file location\n" + "-d: Print scores diff\n" + "-i: Max iterations for perceptron\n"; + } + + else { + help_str = "Estimate optimal symbol weights from log files"; + } + + return help_str; +} + +static void +rspamadm_rescore (gint argc, gchar **argv) +{ + + GOptionContext *context; + GError *error = NULL; + lua_State *L; + ucl_object_t *obj; + + context = g_option_context_new ( + "rescore - Estimate optimal symbol weights from log files"); + + g_option_context_set_summary (context, + "Summary:\n Rspamd administration utility version " + RVERSION + "\n Release id: " + RID); + + g_option_context_add_main_entries (context, entries, NULL); + g_option_context_set_ignore_unknown_options (context, TRUE); + + if (!g_option_context_parse (context, &argc, &argv, &error)) { + rspamd_fprintf (stderr, "option parsing failed: %s\n", error->message); + g_error_free (error); + exit(1); + } + + if (!HAS_TORCH) { + rspamd_fprintf (stderr, "Torch is not enabled. " + "Use -DENABLE_TORCH=ON option while running cmake.\n"); + exit (1); + } + + if (logdir == NULL) { + rspamd_fprintf (stderr, "Please specify log directory.\n"); + exit (1); + } + + L = rspamd_lua_init (); + + rspamd_lua_set_path(L, NULL, NULL); + + obj = ucl_object_typed_new (UCL_OBJECT); + + ucl_object_insert_key (obj, ucl_object_fromstring (logdir), + "logdir", 0, false); + ucl_object_insert_key (obj, ucl_object_fromstring (output), + "output", 0, false); + ucl_object_insert_key (obj, ucl_object_fromdouble (threshold), + "threshold", 0, false); + ucl_object_insert_key (obj, ucl_object_fromint (iters), + "iters", 0, false); + ucl_object_insert_key (obj, ucl_object_frombool (score_diff), + "diff", 0, false); + + rspamadm_execute_lua_ucl_subr (L, + argc, + argv, + obj, + "rescore"); + + lua_close (L); + ucl_object_unref (obj); +} \ No newline at end of file -- 2.39.5