]> source.dussan.org Git - rspamd.git/commitdiff
added corpus_test, rescore commands 1946/head
authorPragadeesh C <cpragadeesh@gmail.com>
Thu, 1 Jun 2017 23:07:28 +0000 (16:07 -0700)
committerPragadeesh C <cpragadeesh@gmail.com>
Thu, 7 Dec 2017 16:22:58 +0000 (21:52 +0530)
lualib/rspamadm/corpus_test.lua [new file with mode: 0644]
lualib/rspamadm/rescore.lua [new file with mode: 0644]
lualib/rspamadm/rescore_utility.lua [new file with mode: 0644]
src/rspamadm/CMakeLists.txt
src/rspamadm/commands.c
src/rspamadm/corpus_test.c [new file with mode: 0644]
src/rspamadm/rescore.c [new file with mode: 0644]

diff --git a/lualib/rspamadm/corpus_test.lua b/lualib/rspamadm/corpus_test.lua
new file mode 100644 (file)
index 0000000..b29fa56
--- /dev/null
@@ -0,0 +1,126 @@
+local ucl = require "ucl"
+local lua_util = require "lua_util"
+
+local HAM = "HAM"
+local SPAM = "SPAM"
+
+local function scan_email(n_parellel, path)
+
+    local rspamc_command = string.format("rspamc -j --compact -n %s %s", n_parellel, path)
+    local result = assert(io.popen(rspamc_command))
+    result = result:read("*all")
+    return result
+end   
+
+local function write_results(results, file)
+
+    local f = io.open(file, 'w')
+
+    for _, result in pairs(results) do
+        local log_line = string.format("%s %.2f %s", result.type, result.score, result.action)
+
+        for _, sym in pairs(result.symbols) do
+            log_line = log_line .. " " .. sym
+        end
+
+        log_line = log_line .. "\r\n"
+
+        f:write(log_line)
+    end
+
+    f:close()
+end
+
+local function encoded_json_to_log(result)
+   -- Returns table containing score, action, list of symbols
+
+    local filtered_result = {}
+    local parser = ucl.parser()
+
+    local is_good, err = parser:parse_string(result)
+
+    if not is_good then
+        print(err)
+        os.exit()
+    end
+
+    result = parser:get_object()
+
+    filtered_result.score = result.score
+    local action = result.action:gsub("%s+", "_")
+    filtered_result.action = action
+
+    filtered_result.symbols = {}
+
+    for sym, _ in pairs(result.symbols) do
+        table.insert(filtered_result.symbols, sym)
+    end
+
+    return filtered_result   
+end
+
+local function scan_results_to_logs(results, actual_email_type)
+
+    local logs = {}
+
+    results = lua_util.rspamd_str_split(results, "\n")
+
+    if results[#results] == "" then
+        results[#results] = nil
+    end
+
+    for _, result in pairs(results) do      
+        result = encoded_json_to_log(result)
+        result['type'] = actual_email_type
+        table.insert(logs, result)
+    end
+
+    return logs
+end
+
+return function (_, res)
+
+    local ham_directory = res['ham_directory']
+    local spam_directory = res['spam_directory']
+    local connections = res["connections"]
+    local output = res["output_location"]
+
+    local results = {}
+
+    local start_time = os.time()
+    local no_of_ham = 0
+    local no_of_spam = 0
+
+    if ham_directory then
+        io.write("Scanning ham corpus...\n")
+        local ham_results = scan_email(connections, ham_directory)
+        ham_results = scan_results_to_logs(ham_results, HAM)
+
+        no_of_ham = #ham_results
+
+        for _, result in pairs(ham_results) do
+            table.insert(results, result)
+        end
+    end
+
+    if spam_directory then
+        io.write("Scanning spam corpus...\n")
+        local spam_results = scan_email(connections, spam_directory)
+        spam_results = scan_results_to_logs(spam_results, SPAM)
+
+        no_of_spam = #spam_results
+
+        for _, result in pairs(spam_results) do
+            table.insert(results, result)
+        end
+    end
+
+    io.write(string.format("Writing results to %s\n", output))
+    write_results(results, output)
+
+    io.write("\nStats: \n")
+    io.write(string.format("Elapsed time: %ds\n", os.time() - start_time))
+    io.write(string.format("No of ham: %d\n", no_of_ham))
+    io.write(string.format("No of spam: %d\n", no_of_spam))
+
+end
\ No newline at end of file
diff --git a/lualib/rspamadm/rescore.lua b/lualib/rspamadm/rescore.lua
new file mode 100644 (file)
index 0000000..538122f
--- /dev/null
@@ -0,0 +1,298 @@
+local torch = require "torch"
+local nn = require "nn"
+local lua_util = require "lua_util"
+local ucl = require "ucl"
+
+local rescore_utility = require "rspamadm/rescore_utility"
+
+local function make_dataset_from_logs(logs, all_symbols)
+       -- Returns a list of {input, output} for torch SGD train
+
+       local dataset = {}
+
+       for _, log in pairs(logs) do
+               local input = torch.Tensor(#all_symbols)
+               local output = torch.Tensor(1)
+               log = lua_util.rspamd_str_split(log, " ")
+
+               if log[1] == "SPAM" then
+                       output[1] = 1
+               else
+                       output[1] = 0
+               end
+
+               local symbols_set = {}
+
+               for i=4,#log do
+                       symbols_set[log[i]] = true
+               end
+
+               for index, symbol in pairs(all_symbols) do
+                       if symbols_set[symbol] then
+                               input[index] = 1
+                       else
+                               input[index] = 0
+                       end     
+               end
+
+               dataset[#dataset + 1] = {input, output}
+
+       end
+
+       function dataset:size()
+               return #dataset
+       end
+
+       return dataset
+end
+
+local function init_weights(all_symbols, original_symbol_scores)
+
+       local weights = torch.Tensor(#all_symbols)
+
+       local mean = 0
+
+       for i, symbol in pairs(all_symbols) do
+               local score = original_symbol_scores[symbol]
+               if not score then score = 0 end
+               weights[i] = score
+               mean = mean + score
+       end
+
+       return weights   
+end
+
+local function shuffle(logs)
+
+       local size = #logs
+       for i = size, 1, -1 do
+               local rand = math.random(size)
+               logs[i], logs[rand] = logs[rand], logs[i]
+       end
+
+end
+
+local function split_logs(logs, split_percent)
+
+       if not split_percent then
+               split_percent = 60
+       end
+
+       local split_index = math.floor(#logs * split_percent / 100)
+
+       local test_logs = {}
+       local train_logs = {}
+
+       for i=1,split_index do
+               train_logs[#train_logs + 1] = logs[i]
+       end   
+
+       for i=split_index + 1, #logs do
+               test_logs[#test_logs + 1] = logs[i]
+       end
+
+       return train_logs, test_logs
+end
+
+local function stitch_new_scores(all_symbols, new_scores)
+
+       local new_symbol_scores = {}
+
+       for idx, symbol in pairs(all_symbols) do
+               new_symbol_scores[symbol] = new_scores[idx]
+       end
+
+       return new_symbol_scores
+end
+
+
+local function update_logs(logs, symbol_scores)
+
+       for i, log in ipairs(logs) do
+
+               log = lua_util.rspamd_str_split(log, " ")
+
+               local score = 0
+
+               for j=4,#log do
+                       log[j] = log[j]:gsub("%s+", "")
+                       score = score + (symbol_scores[log[j    ]] or 0)
+               end
+
+               log[2] = rescore_utility.round(score, 2)
+
+               logs[i] = table.concat(log, " ")
+       end
+
+       return logs
+end
+
+local function write_scores(new_symbol_scores, file_path)
+
+       local file = assert(io.open(file_path, "w"))
+
+       local new_scores_ucl = ucl.to_format(new_symbol_scores, "ucl")
+
+       file:write(new_scores_ucl)
+
+       file:close()
+end
+
+local function print_score_diff(new_symbol_scores, original_symbol_scores)
+
+       print(string.format("%-35s %-10s %-10s", "SYMBOL", "OLD_SCORE", "NEW_SCORE"))
+
+       for symbol, new_score in pairs(new_symbol_scores) do
+       print(string.format("%-35s %-10s %-10s",
+                 symbol,
+                 original_symbol_scores[symbol] or 0,
+                 rescore_utility.round(new_score, 2)))
+       end
+
+       print "\nClass changes \n"
+       for symbol, new_score in pairs(new_symbol_scores) do
+               if original_symbol_scores[symbol] ~= nil then
+                       if (original_symbol_scores[symbol] > 0 and new_score < 0) or
+                               (original_symbol_scores[symbol] < 0 and new_score > 0) then
+                               print(string.format("%-35s %-10s %-10s",
+                                               symbol,
+                                               original_symbol_scores[symbol] or 0,
+                                               rescore_utility.round(new_score, 2)))
+                       end
+               end
+       end
+
+end
+
+local function calculate_fscore_from_weights(logs, all_symbols, weights, bias, threshold)
+
+       local new_symbol_scores = weights:clone()
+
+       new_symbol_scores = stitch_new_scores(all_symbols, new_symbol_scores)
+
+       logs = update_logs(logs, new_symbol_scores)
+
+       local file_stats, _ = rescore_utility.generate_statistics_from_logs(logs, threshold)
+
+       return file_stats.fscore
+end
+
+local function print_stats(logs, threshold)
+
+       local file_stats, _ = rescore_utility.generate_statistics_from_logs(logs, threshold)
+
+       local file_stat_format = [[
+F-score: %.2f
+False positive rate: %.2f %%
+False negative rate: %.2f %%
+Overall accuracy: %.2f %%
+]]
+
+       io.write("\nStatistics at threshold: " .. threshold .. "\n")
+
+       io.write(string.format(file_stat_format,
+                       file_stats.fscore,
+                       file_stats.false_positive_rate,
+                       file_stats.false_negative_rate,
+                       file_stats.overall_accuracy))
+
+end
+
+return function (_, res)
+
+       local logs = rescore_utility.get_all_logs(res["logdir"])
+       local all_symbols = rescore_utility.get_all_symbols(logs)
+       local original_symbol_scores = rescore_utility.get_all_symbol_scores()
+
+       shuffle(logs)
+
+       local train_logs, validation_logs = split_logs(logs, 70)
+       local cv_logs, test_logs = split_logs(validation_logs, 50)
+
+       local dataset = make_dataset_from_logs(train_logs, all_symbols)
+
+       local learning_rates = {0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 2.5, 5, 7.5, 10}
+       local penalty_weights = {0, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 3, 5, 10, 15, 20, 25, 50, 75, 100}
+
+       -- Start of perceptron training
+
+       local input_size = #all_symbols
+       local linear_module = nn.Linear(input_size, 1)
+
+       local perceptron = nn.Sequential()
+       perceptron:add(linear_module)
+
+       local activation = nn.Sigmoid()
+
+       perceptron:add(activation)
+
+       local criterion = nn.MSECriterion()
+       criterion.sizeAverage = false
+
+       local best_fscore = -math.huge
+       local best_weights = linear_module.weight[1]:clone()
+
+       local trainer = nn.StochasticGradient(perceptron, criterion)
+       trainer.maxIteration = res["iters"]
+       trainer.verbose = false
+
+       trainer.hookIteration = function(self, iteration, error)
+
+               if iteration == trainer.maxIteration then
+
+                       local fscore = calculate_fscore_from_weights(cv_logs,
+                                       all_symbols,
+                                       linear_module.weight[1],
+                                       linear_module.bias[1],
+                                       res["threshold"])
+
+                       print("Cross-validation fscore: " .. fscore)
+
+                       if best_fscore < fscore then
+                               best_fscore = fscore
+                               best_weights = linear_module.weight[1]:clone()
+                       end
+               end
+       end
+
+       for _, learning_rate in pairs(learning_rates) do
+               for _, weight in pairs(penalty_weights) do
+
+                       trainer.weightDecay = weight
+                       print("Learning with learning_rate: " .. learning_rate 
+                               .. " | l2_weight: " .. weight)
+
+                       linear_module.weight[1] = init_weights(all_symbols, original_symbol_scores)
+
+                       trainer.learningRate = learning_rate
+                       trainer:train(dataset)
+
+                       print()
+               end   
+       end
+
+       -- End perceptron training
+
+       local new_symbol_scores = best_weights
+
+       new_symbol_scores = stitch_new_scores(all_symbols, new_symbol_scores)
+
+       if res["output"] then
+               write_scores(new_symbol_scores, res["output"])
+       end
+
+       if res["diff"] then
+               print_score_diff(new_symbol_scores, original_symbol_scores)
+       end
+
+
+       -- Pre-rescore test stats
+       print("\n\nPre-rescore test stats\n")
+       test_logs = update_logs(test_logs, original_symbol_scores)
+       print_stats(test_logs, res['threshold'])
+
+       -- Post-rescore test stats
+       test_logs = update_logs(test_logs, new_symbol_scores)
+       print("\n\nPost-rescore test stats\n")
+       print_stats(test_logs, res['threshold'])
+end
\ No newline at end of file
diff --git a/lualib/rspamadm/rescore_utility.lua b/lualib/rspamadm/rescore_utility.lua
new file mode 100644 (file)
index 0000000..4c6504e
--- /dev/null
@@ -0,0 +1,214 @@
+local ucl = require "ucl"
+local lua_util = require "lua_util"
+local rspamd_util = require "rspamd_util"
+
+local utility = {}
+
+function utility.round(num, places)
+   return string.format("%." .. (places or 0) .. "f", num)
+end
+
+function utility.get_all_symbols(logs)
+   -- Returns a list of all symbols
+
+       local symbols_set = {}
+
+       for _, line in pairs(logs) do
+               line = lua_util.rspamd_str_split(line, " ")
+               for i=4,#line do
+                       line[i] = line[i]:gsub("%s+", "")
+                       if not symbols_set[line[i]] then
+                               symbols_set[line[i]] = true
+                       end
+               end
+       end
+
+       local all_symbols = {}
+
+       for symbol, _ in pairs(symbols_set) do
+               all_symbols[#all_symbols + 1] = symbol
+       end
+
+       table.sort(all_symbols)
+
+       return all_symbols
+end
+
+function utility.read_log_file(file)
+
+       local lines = {}
+
+       file = assert(io.open(file, "r"))
+
+       for line in file:lines() do
+               lines[#lines + 1] = line
+       end
+
+       io.close(file)
+
+       return lines
+end
+
+function utility.get_all_logs(dir_path)
+   -- Reads all log files in the directory and returns a list of logs.
+
+       if dir_path:sub(#dir_path, #dir_path) == "/" then
+               dir_path = dir_path:sub(1, #dir_path -1)
+       end
+
+       local files = rspamd_util.glob(dir_path .. "/*")
+       local all_logs = {}
+
+       for _, file in pairs(files) do
+               local logs = utility.read_log_file(file)
+               for _, log_line in pairs(logs) do
+                       all_logs[#all_logs + 1] = log_line
+               end      
+       end
+
+       return all_logs
+end
+
+function utility.get_all_symbol_scores()
+
+       local output = assert(io.popen("rspamc counters -j --compact"))
+       output = output:read("*all")
+
+       local parser = ucl.parser()
+       local result, err = parser:parse_string(output)
+
+       if not result then
+               print(err)
+               os.exit()
+       end
+
+       output = parser:get_object()
+
+       local symbol_scores = {}
+
+       for _, symbol_info in pairs(output) do
+               symbol_scores[symbol_info.symbol] = symbol_info.weight
+       end
+
+       return symbol_scores
+end
+
+function utility.generate_statistics_from_logs(logs, threshold)
+
+   -- Returns file_stats table and list of symbol_stats table.
+
+       local file_stats = {
+               no_of_emails = 0,
+               no_of_spam = 0,
+               no_of_ham = 0,
+               spam_percent = 0,
+               ham_percent = 0,
+               true_positives = 0,
+               true_negatives = 0,
+               false_negative_rate = 0,
+               false_positive_rate = 0,
+               overall_accuracy = 0,
+               fscore = 0
+       }
+
+       local all_symbols_stats = {}
+
+       local false_positives = 0
+       local false_negatives = 0
+       local true_positives = 0
+       local true_negatives = 0
+       local no_of_emails = 0
+       local no_of_spam = 0
+       local no_of_ham = 0
+
+       for _, log in pairs(logs) do
+               log = lua_util.rspamd_str_trim(log)
+               log = lua_util.rspamd_str_split(log, " ")
+
+               local is_spam = (log[1] == "SPAM")
+               local score = tonumber(log[2])
+
+               no_of_emails = no_of_emails + 1
+
+               if is_spam then
+                       no_of_spam = no_of_spam + 1
+               else
+                       no_of_ham = no_of_ham + 1       
+               end
+
+               if is_spam and (score >= threshold) then
+                       true_positives = true_positives + 1
+               elseif is_spam and (score < threshold) then
+                       false_negatives = false_negatives + 1
+               elseif not is_spam and (score >= threshold) then
+                       false_positives = false_positives + 1
+               else
+                       true_negatives = true_negatives + 1
+               end
+
+               for i=4, #log do   
+                       if all_symbols_stats[log[i]] == nil then
+                               all_symbols_stats[log[i]] = {
+                                       name = log[i],
+                                       no_of_hits = 0,
+                                       spam_hits = 0,
+                                       ham_hits = 0,
+                                       spam_overall = 0
+                               }
+                       end
+
+                       all_symbols_stats[log[i]].no_of_hits =
+                       all_symbols_stats[log[i]].no_of_hits + 1
+
+                       if is_spam then
+                               all_symbols_stats[log[i]].spam_hits =
+                               all_symbols_stats[log[i]].spam_hits + 1
+                       else
+                               all_symbols_stats[log[i]].ham_hits =
+                               all_symbols_stats[log[i]].ham_hits + 1
+                       end
+               end
+       end
+
+       -- Calculating file stats
+
+       file_stats.no_of_ham = no_of_ham
+       file_stats.no_of_spam = no_of_spam
+       file_stats.no_of_emails = no_of_emails
+       file_stats.true_positives = true_positives
+       file_stats.true_negatives = true_negatives
+
+       if no_of_emails > 0 then
+               file_stats.spam_percent = no_of_spam * 100 / no_of_emails
+               file_stats.ham_percent = no_of_ham * 100 / no_of_emails
+               file_stats.overall_accuracy = (true_positives + true_negatives) * 100 /
+               no_of_emails
+       end
+
+       if no_of_ham > 0 then
+               file_stats.false_positive_rate = false_positives * 100 / no_of_ham
+       end
+
+       if no_of_spam > 0 then
+               file_stats.false_negative_rate = false_negatives * 100 / no_of_spam
+       end
+
+       file_stats.fscore = 2 * true_positives / (2
+                                               * true_positives
+                                               + false_positives
+                                               + false_negatives)
+
+       -- Calculating symbol stats
+
+       for _, symbol_stats in pairs(all_symbols_stats) do
+               symbol_stats.spam_percent = symbol_stats.spam_hits * 100 / no_of_spam
+               symbol_stats.ham_percent = symbol_stats.ham_hits * 100 / no_of_ham
+               symbol_stats.overall = symbol_stats.no_of_hits * 100 / no_of_emails
+               symbol_stats.spam_overall = symbol_stats.spam_percent /
+               (symbol_stats.spam_percent + symbol_stats.ham_percent)
+       end
+
+       return file_stats, all_symbols_stats
+end
+
+return utility
index 7dfaad69167c6c0ac9346d52392791a16d82d076..fb3f252290df54cf09b079d67b5187ed06377ff2 100644 (file)
@@ -10,10 +10,12 @@ SET(RSPAMADMSRC rspamadm.c
         control.c
         confighelp.c
         configwizard.c
+        corpus_test.c
         stat_convert.c
         signtool.c
         lua_repl.c
         dkim_keygen.c
+        rescore.c
         ${CMAKE_BINARY_DIR}/src/workers.c
         ${CMAKE_BINARY_DIR}/src/modules.c
         ${CMAKE_SOURCE_DIR}/src/controller.c
index 1eaa45c81ecee06a434b831c71653ac085eee467..410306fe3f360861db2f8e50114f2373cc8dd6dd 100644 (file)
@@ -29,6 +29,8 @@ extern struct rspamadm_command signtool_command;
 extern struct rspamadm_command lua_command;
 extern struct rspamadm_command dkim_keygen_command;
 extern struct rspamadm_command configwizard_command;
+extern struct rspamadm_command corpus_test_command;
+extern struct rspamadm_command rescore_command;
 
 const struct rspamadm_command *commands[] = {
        &help_command,
@@ -46,6 +48,8 @@ const struct rspamadm_command *commands[] = {
        &lua_command,
        &dkim_keygen_command,
        &configwizard_command,
+       &corpus_test_command,
+       &rescore_command,
        NULL
 };
 
diff --git a/src/rspamadm/corpus_test.c b/src/rspamadm/corpus_test.c
new file mode 100644 (file)
index 0000000..62aecb1
--- /dev/null
@@ -0,0 +1,121 @@
+/*-
+ * Copyright 2017 Pragadeesh C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "rspamadm.h"
+#include "config.h"
+#include "lua/lua_common.h"
+
+static gchar *ham_directory = NULL;
+static gchar *spam_directory = NULL;
+static gchar *output_location = "results.log";
+static gint connections = 10;
+
+static void rspamadm_corpus_test (gint argc, gchar **argv);
+static const char *rspamadm_corpus_test_help (gboolean full_help);
+
+struct rspamadm_command corpus_test_command = {
+       .name = "corpus_test",
+       .flags = 0,
+       .help = rspamadm_corpus_test_help,
+       .run = rspamadm_corpus_test
+};
+
+// TODO add -nparellel and -o options
+static GOptionEntry entries[] = {
+                       {"ham", 'a', 0, G_OPTION_ARG_FILENAME, &ham_directory,
+                                                                                               "Ham directory", NULL},
+                       {"spam", 's', 0, G_OPTION_ARG_FILENAME, &spam_directory,
+                                                                                               "Spam directory", NULL},
+                       {"output", 'o', 0, G_OPTION_ARG_FILENAME, &output_location,
+                                                                                               "Log output location", NULL},
+                       {"connections", 'n', 0, G_OPTION_ARG_INT, &connections,
+                                               "Number of parellel connections [Default: 10]", NULL},
+                       {NULL,  0,      0, G_OPTION_ARG_NONE, NULL, NULL, NULL}
+};
+
+static const char *
+rspamadm_corpus_test_help (gboolean full_help) 
+{
+       const char *help_str;
+
+       if (full_help) {
+               help_str = "Create logs files from email corpus\n\n"
+                                       "Usage: rspamadm corpus_test [-a <ham_directory>]"
+                                                                                       " [-s <spam_directory>]\n"
+                                       "Where option are:\n\n"
+                                       "-a: path to ham directory\n"
+                                       "-s: path to spam directory\n"
+                                       "-n: maximum parellel connections\n"
+                                       "-o: log output file\n";
+
+       }
+
+       else {
+               help_str = "Create logs files from email corpus";
+       }
+
+       return help_str;
+}
+
+static void
+rspamadm_corpus_test (gint argc, gchar **argv) 
+{
+       GOptionContext *context;
+       GError *error = NULL;
+       lua_State *L;
+       ucl_object_t *obj;
+
+       context = g_option_context_new (
+                               "corpus_test - Create logs files from email corpus");
+
+       g_option_context_set_summary (context, 
+                       "Summary:\n Rspamd administration utility version "
+                                               RVERSION
+                                               "\n Release id: "
+                                               RID);
+       
+       g_option_context_add_main_entries (context, entries, NULL);
+       g_option_context_set_ignore_unknown_options (context, TRUE);
+       
+       if (!g_option_context_parse (context, &argc, &argv, &error)) {
+               rspamd_fprintf (stderr, "option parsing failed: %s\n", error->message);
+               g_error_free (error);
+               exit(1);
+       }
+
+       L = rspamd_lua_init ();
+       rspamd_lua_set_path(L, NULL, NULL);
+
+
+       obj = ucl_object_typed_new (UCL_OBJECT);
+       ucl_object_insert_key (obj, ucl_object_fromstring (ham_directory),
+                                                                                       "ham_directory", 0, false);
+       ucl_object_insert_key (obj, ucl_object_fromstring (spam_directory),
+                                                                                       "spam_directory", 0, false);
+       ucl_object_insert_key (obj, ucl_object_fromstring (output_location),
+                                                                                       "output_location", 0, false);
+       ucl_object_insert_key (obj, ucl_object_fromint (connections),
+                                                                                       "connections", 0, false);
+
+       rspamadm_execute_lua_ucl_subr (L,
+                                               argc,
+                                               argv,
+                                               obj,
+                                               "corpus_test");
+
+       lua_close (L);
+       ucl_object_unref (obj);
+}
diff --git a/src/rspamadm/rescore.c b/src/rspamadm/rescore.c
new file mode 100644 (file)
index 0000000..de5ace2
--- /dev/null
@@ -0,0 +1,141 @@
+/*-
+ * Copyright 2017 Pragadeesh C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "rspamadm.h"
+#include "config.h"
+#include "lua/lua_common.h"
+
+#if !defined(WITH_TORCH) || !defined(WITH_LUAJIT)
+    #define HAS_TORCH false
+#else
+    #define HAS_TORCH true
+#endif
+
+static gchar *logdir = NULL; 
+static gchar *output = "new.scores";
+static gdouble threshold = 15; // Spam threshold
+static gboolean score_diff = false;  // Print score diff flag
+static gint64 iters = 500; // Perceptron max iterations
+
+static void rspamadm_rescore (gint argc, gchar **argv);
+static const char *rspamadm_rescore_help (gboolean full_help);
+
+struct rspamadm_command rescore_command = {
+    .name = "rescore",
+    .flags = 0,
+    .help = rspamadm_rescore_help,
+    .run = rspamadm_rescore
+};
+
+static GOptionEntry entries[] = {
+            {"logdir", 'l', 0, G_OPTION_ARG_FILENAME, &logdir,
+                                        "Logs directory", NULL},
+            {"output", 'o', 0, G_OPTION_ARG_FILENAME, &output,
+                                    "Scores output locaiton", NULL},
+            {"diff", 'd', 0, G_OPTION_ARG_NONE, &score_diff,
+                                    "Print score diff", NULL},
+            {"iters", 'i', 0, G_OPTION_ARG_INT64, &iters,
+                "Max iterations for perceptron [Default: 500]", NULL},  
+            {NULL,  0,  0, G_OPTION_ARG_NONE, NULL, NULL, NULL}
+};
+
+static const char *
+rspamadm_rescore_help (gboolean full_help)
+{
+
+    const char *help_str;
+
+    if (full_help) {
+        help_str = "Estimate optimal symbol weights from log files\n\n"
+                    "Usage: rspamadm rescore -l <log_directory>\n"
+                    "Where options are:\n\n"
+                    "-l: path to logs directory\n"
+                    "-o: Scores output file location\n"
+                    "-d: Print scores diff\n"
+                    "-i: Max iterations for perceptron\n";
+    }
+
+    else {
+        help_str = "Estimate optimal symbol weights from log files";
+    }
+
+    return help_str;
+}
+
+static void 
+rspamadm_rescore (gint argc, gchar **argv)
+{
+
+    GOptionContext *context;
+    GError *error = NULL;
+    lua_State *L;
+    ucl_object_t *obj;
+
+    context = g_option_context_new (
+        "rescore - Estimate optimal symbol weights from log files");
+
+    g_option_context_set_summary (context,
+        "Summary:\n Rspamd administration utility version "
+                        RVERSION
+                        "\n Release id: "
+                        RID);
+
+    g_option_context_add_main_entries (context, entries, NULL);
+    g_option_context_set_ignore_unknown_options (context, TRUE);
+
+    if (!g_option_context_parse (context, &argc, &argv, &error)) {
+        rspamd_fprintf (stderr, "option parsing failed: %s\n", error->message);
+        g_error_free (error);
+        exit(1);
+    }
+
+    if (!HAS_TORCH) {
+        rspamd_fprintf (stderr, "Torch is not enabled. "
+            "Use -DENABLE_TORCH=ON option while running cmake.\n");
+        exit (1);
+    }
+
+    if (logdir == NULL) {
+        rspamd_fprintf (stderr, "Please specify log directory.\n");
+        exit (1);
+    }
+
+    L = rspamd_lua_init ();
+
+    rspamd_lua_set_path(L, NULL, NULL);
+
+    obj = ucl_object_typed_new (UCL_OBJECT);
+
+    ucl_object_insert_key (obj, ucl_object_fromstring (logdir),
+                                            "logdir", 0, false);
+    ucl_object_insert_key (obj, ucl_object_fromstring (output),
+                                            "output", 0, false);
+    ucl_object_insert_key (obj, ucl_object_fromdouble (threshold),
+                                            "threshold", 0, false);
+    ucl_object_insert_key (obj, ucl_object_fromint (iters),
+                                            "iters", 0, false);
+    ucl_object_insert_key (obj, ucl_object_frombool (score_diff),
+                                            "diff", 0, false);
+
+    rspamadm_execute_lua_ucl_subr (L,
+                        argc,
+                        argv,
+                        obj,
+                        "rescore");
+
+    lua_close (L);
+    ucl_object_unref (obj);
+}
\ No newline at end of file