]> source.dussan.org Git - rspamd.git/commitdiff
[Rework] Rework rescore utility
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 1 Mar 2018 16:00:39 +0000 (16:00 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 1 Mar 2018 16:00:39 +0000 (16:00 +0000)
lualib/rspamadm/rescore.lua
lualib/rspamadm/rescore_utility.lua
src/rspamadm/configwizard.c
src/rspamadm/rescore.c

index 4f6cc5075194886eedcba5f9ab85ab34cf4e467d..3129af5638b148251134b0b46dfab51759d5ed16 100644 (file)
@@ -2,297 +2,338 @@ local torch = require "torch"
 local nn = require "nn"
 local lua_util = require "lua_util"
 local ucl = require "ucl"
+local logger = require "rspamd_logger"
+local getopt = require "rspamadm/getopt"
 
 local rescore_utility = require "rspamadm/rescore_utility"
 
+local opts
+
 local function make_dataset_from_logs(logs, all_symbols)
-       -- Returns a list of {input, output} for torch SGD train
+  -- Returns a list of {input, output} for torch SGD train
 
-       local dataset = {}
+  local dataset = {}
 
-       for _, log in pairs(logs) do
-               local input = torch.Tensor(#all_symbols)
-               local output = torch.Tensor(1)
-               log = lua_util.rspamd_str_split(log, " ")
+  for _, log in pairs(logs) do
+    local input = torch.Tensor(#all_symbols)
+    local output = torch.Tensor(1)
+    log = lua_util.rspamd_str_split(log, " ")
 
-               if log[1] == "SPAM" then
-                       output[1] = 1
-               else
-                       output[1] = 0
-               end
+    if log[1] == "SPAM" then
+      output[1] = 1
+    else
+      output[1] = 0
+    end
 
-               local symbols_set = {}
+    local symbols_set = {}
 
-               for i=4,#log do
-                       symbols_set[log[i]] = true
-               end
+    for i=4,#log do
+      symbols_set[log[i]] = true
+    end
 
-               for index, symbol in pairs(all_symbols) do
-                       if symbols_set[symbol] then
-                               input[index] = 1
-                       else
-                               input[index] = 0
-                       end     
-               end
+    for index, symbol in pairs(all_symbols) do
+      if symbols_set[symbol] then
+        input[index] = 1
+      else
+        input[index] = 0
+      end
+    end
 
-               dataset[#dataset + 1] = {input, output}
+    dataset[#dataset + 1] = {input, output}
 
-       end
+  end
 
-       function dataset:size()
-               return #dataset
-       end
+  function dataset:size()
+    return #dataset
+  end
 
-       return dataset
+  return dataset
 end
 
 local function init_weights(all_symbols, original_symbol_scores)
 
-       local weights = torch.Tensor(#all_symbols)
+  local weights = torch.Tensor(#all_symbols)
 
-       local mean = 0
+  local mean = 0
 
-       for i, symbol in pairs(all_symbols) do
-               local score = original_symbol_scores[symbol]
-               if not score then score = 0 end
-               weights[i] = score
-               mean = mean + score
-       end
+  for i, symbol in pairs(all_symbols) do
+    local score = original_symbol_scores[symbol]
+    if not score then score = 0 end
+    weights[i] = score
+    mean = mean + score
+  end
 
-       return weights   
+  return weights
 end
 
 local function shuffle(logs)
 
-       local size = #logs
-       for i = size, 1, -1 do
-               local rand = math.random(size)
-               logs[i], logs[rand] = logs[rand], logs[i]
-       end
+  local size = #logs
+  for i = size, 1, -1 do
+    local rand = math.random(size)
+    logs[i], logs[rand] = logs[rand], logs[i]
+  end
 
 end
 
 local function split_logs(logs, split_percent)
 
-       if not split_percent then
-               split_percent = 60
-       end
+  if not split_percent then
+    split_percent = 60
+  end
 
-       local split_index = math.floor(#logs * split_percent / 100)
+  local split_index = math.floor(#logs * split_percent / 100)
 
-       local test_logs = {}
-       local train_logs = {}
+  local test_logs = {}
+  local train_logs = {}
 
-       for i=1,split_index do
-               train_logs[#train_logs + 1] = logs[i]
-       end   
+  for i=1,split_index do
+    train_logs[#train_logs + 1] = logs[i]
+  end
 
-       for i=split_index + 1, #logs do
-               test_logs[#test_logs + 1] = logs[i]
-       end
+  for i=split_index + 1, #logs do
+    test_logs[#test_logs + 1] = logs[i]
+  end
 
-       return train_logs, test_logs
+  return train_logs, test_logs
 end
 
 local function stitch_new_scores(all_symbols, new_scores)
 
-       local new_symbol_scores = {}
+  local new_symbol_scores = {}
 
-       for idx, symbol in pairs(all_symbols) do
-               new_symbol_scores[symbol] = new_scores[idx]
-       end
+  for idx, symbol in pairs(all_symbols) do
+    new_symbol_scores[symbol] = new_scores[idx]
+  end
 
-       return new_symbol_scores
+  return new_symbol_scores
 end
 
 
 local function update_logs(logs, symbol_scores)
 
-       for i, log in ipairs(logs) do
+  for i, log in ipairs(logs) do
 
-               log = lua_util.rspamd_str_split(log, " ")
+    log = lua_util.rspamd_str_split(log, " ")
 
-               local score = 0
+    local score = 0
 
-               for j=4,#log do
-                       log[j] = log[j]:gsub("%s+", "")
-                       score = score + (symbol_scores[log[j    ]] or 0)
-               end
+    for j=4,#log do
+      log[j] = log[j]:gsub("%s+", "")
+      score = score + (symbol_scores[log[j     ]] or 0)
+    end
 
-               log[2] = rescore_utility.round(score, 2)
+    log[2] = lua_util.round(score, 2)
 
-               logs[i] = table.concat(log, " ")
-       end
+    logs[i] = table.concat(log, " ")
+  end
 
-       return logs
+  return logs
 end
 
 local function write_scores(new_symbol_scores, file_path)
 
-       local file = assert(io.open(file_path, "w"))
+  local file = assert(io.open(file_path, "w"))
 
-       local new_scores_ucl = ucl.to_format(new_symbol_scores, "ucl")
+  local new_scores_ucl = ucl.to_format(new_symbol_scores, "ucl")
 
-       file:write(new_scores_ucl)
+  file:write(new_scores_ucl)
 
-       file:close()
+  file:close()
 end
 
 local function print_score_diff(new_symbol_scores, original_symbol_scores)
 
-       print(string.format("%-35s %-10s %-10s", "SYMBOL", "OLD_SCORE", "NEW_SCORE"))
-
-       for symbol, new_score in pairs(new_symbol_scores) do
-       print(string.format("%-35s %-10s %-10s",
-                 symbol,
-                 original_symbol_scores[symbol] or 0,
-                 rescore_utility.round(new_score, 2)))
-       end
-
-       print "\nClass changes \n"
-       for symbol, new_score in pairs(new_symbol_scores) do
-               if original_symbol_scores[symbol] ~= nil then
-                       if (original_symbol_scores[symbol] > 0 and new_score < 0) or
-                               (original_symbol_scores[symbol] < 0 and new_score > 0) then
-                               print(string.format("%-35s %-10s %-10s",
-                                               symbol,
-                                               original_symbol_scores[symbol] or 0,
-                                               rescore_utility.round(new_score, 2)))
-                       end
-               end
-       end
+  print(string.format("%-35s %-10s %-10s", "SYMBOL", "OLD_SCORE", "NEW_SCORE"))
+
+  for symbol, new_score in pairs(new_symbol_scores) do
+    print(string.format("%-35s %-10s %-10s",
+        symbol,
+        original_symbol_scores[symbol] or 0,
+        rescore_utility.round(new_score, 2)))
+  end
+
+  print "\nClass changes \n"
+  for symbol, new_score in pairs(new_symbol_scores) do
+    if original_symbol_scores[symbol] ~= nil then
+      if (original_symbol_scores[symbol] > 0 and new_score < 0) or
+          (original_symbol_scores[symbol] < 0 and new_score > 0) then
+        print(string.format("%-35s %-10s %-10s",
+            symbol,
+            original_symbol_scores[symbol] or 0,
+            rescore_utility.round(new_score, 2)))
+      end
+    end
+  end
 
 end
 
 local function calculate_fscore_from_weights(logs, all_symbols, weights, bias, threshold)
 
-       local new_symbol_scores = weights:clone()
+  local new_symbol_scores = weights:clone()
 
-       new_symbol_scores = stitch_new_scores(all_symbols, new_symbol_scores)
+  new_symbol_scores = stitch_new_scores(all_symbols, new_symbol_scores)
 
-       logs = update_logs(logs, new_symbol_scores)
+  logs = update_logs(logs, new_symbol_scores)
 
-       local file_stats, _ = rescore_utility.generate_statistics_from_logs(logs, threshold)
+  local file_stats, _ = rescore_utility.generate_statistics_from_logs(logs, threshold)
 
-       return file_stats.fscore
+  return file_stats.fscore
 end
 
 local function print_stats(logs, threshold)
 
-       local file_stats, _ = rescore_utility.generate_statistics_from_logs(logs, threshold)
+  local file_stats, _ = rescore_utility.generate_statistics_from_logs(logs, threshold)
 
-       local file_stat_format = [[
+  local file_stat_format = [[
 F-score: %.2f
 False positive rate: %.2f %%
 False negative rate: %.2f %%
 Overall accuracy: %.2f %%
 ]]
 
-       io.write("\nStatistics at threshold: " .. threshold .. "\n")
+  io.write("\nStatistics at threshold: " .. threshold .. "\n")
+
+  io.write(string.format(file_stat_format,
+      file_stats.fscore,
+      file_stats.false_positive_rate,
+      file_stats.false_negative_rate,
+      file_stats.overall_accuracy))
 
-       io.write(string.format(file_stat_format,
-                       file_stats.fscore,
-                       file_stats.false_positive_rate,
-                       file_stats.false_negative_rate,
-                       file_stats.overall_accuracy))
+end
 
+local default_opts = {
+  verbose = true,
+  iters = 10,
+  threads = 1,
+}
+
+local function override_defaults(def, override)
+  for k,v in pairs(override) do
+    if def[k] then
+      if type(v) == 'table' then
+        override_defaults(def[k], v)
+      else
+        def[k] = v
+      end
+    else
+      def[k] = v
+    end
+  end
 end
 
-return function (_, res)
+local function get_threshold(opts)
+  local actions = rspamd_config:get_all_actions()
 
-       local logs = rescore_utility.get_all_logs(res["logdir"])
-       local all_symbols = rescore_utility.get_all_symbols(logs)
-       local original_symbol_scores = rescore_utility.get_all_symbol_scores(res["timeout"])
+  if opts['spam-action'] then
+    return actions[opts['spam-action']] or 0
+  else
+    return actions['add header'] or actions['rewrite subject'] or actions['reject']
+  end
+end
 
-       shuffle(logs)
+return function (args, cfg)
+  opts = default_opts
+  override_defaults(opts, getopt.getopt(args, ''))
+  local threshold = get_threshold(opts)
+  local logs = rescore_utility.get_all_logs(cfg["logdir"])
+  local all_symbols = rescore_utility.get_all_symbols(logs)
+  local original_symbol_scores = rescore_utility.get_all_symbol_scores(rspamd_config)
 
-       local train_logs, validation_logs = split_logs(logs, 70)
-       local cv_logs, test_logs = split_logs(validation_logs, 50)
+  shuffle(logs)
+  torch.setdefaulttensortype('torch.FloatTensor')
 
-       local dataset = make_dataset_from_logs(train_logs, all_symbols)
+  local train_logs, validation_logs = split_logs(logs, 70)
+  local cv_logs, test_logs = split_logs(validation_logs, 50)
 
-       local learning_rates = {0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 2.5, 5, 7.5, 10}
-       local penalty_weights = {0, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 3, 5, 10, 15, 20, 25, 50, 75, 100}
+  local dataset = make_dataset_from_logs(train_logs, all_symbols)
 
-       -- Start of perceptron training
+  local learning_rates = {
+    0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 2.5, 5, 7.5, 10
+  }
+  local penalty_weights = {
+    0, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 3, 5, 10, 15, 20, 25, 50, 75, 100
+  }
 
-       local input_size = #all_symbols
-       local linear_module = nn.Linear(input_size, 1)
+  -- Start of perceptron training
 
-       local perceptron = nn.Sequential()
-       perceptron:add(linear_module)
+  local input_size = #all_symbols
+  torch.setnumthreads(opts['threads'])
+  local linear_module = nn.Linear(input_size, 1)
 
-       local activation = nn.Sigmoid()
+  local perceptron = nn.Sequential()
+  perceptron:add(linear_module)
 
-       perceptron:add(activation)
+  local activation = nn.Sigmoid()
 
-       local criterion = nn.MSECriterion()
-       criterion.sizeAverage = false
+  perceptron:add(activation)
 
-       local best_fscore = -math.huge
-       local best_weights = linear_module.weight[1]:clone()
+  local criterion = nn.MSECriterion()
+  criterion.sizeAverage = false
 
-       local trainer = nn.StochasticGradient(perceptron, criterion)
-       trainer.maxIteration = res["iters"]
-       trainer.verbose = false
+  local best_fscore = -math.huge
+  local best_weights = linear_module.weight[1]:clone()
 
-       trainer.hookIteration = function(self, iteration, error)
+  local trainer = nn.StochasticGradient(perceptron, criterion)
+  trainer.maxIteration = opts["iters"]
+  trainer.verbose = opts['verbose']
+  trainer.hookIteration = function(self, iteration, error)
 
-               if iteration == trainer.maxIteration then
+    if iteration == trainer.maxIteration then
 
-                       local fscore = calculate_fscore_from_weights(cv_logs,
-                                       all_symbols,
-                                       linear_module.weight[1],
-                                       linear_module.bias[1],
-                                       res["threshold"])
+      local fscore = calculate_fscore_from_weights(cv_logs,
+          all_symbols,
+          linear_module.weight[1],
+          linear_module.bias[1],
+          threshold)
 
-                       print("Cross-validation fscore: " .. fscore)
+      print("Cross-validation fscore: " .. fscore)
 
-                       if best_fscore < fscore then
-                               best_fscore = fscore
-                               best_weights = linear_module.weight[1]:clone()
-                       end
-               end
-       end
+      if best_fscore < fscore then
+        best_fscore = fscore
+        best_weights = linear_module.weight[1]:clone()
+      end
+    end
+  end
 
-       for _, learning_rate in pairs(learning_rates) do
-               for _, weight in pairs(penalty_weights) do
+  for _, learning_rate in ipairs(learning_rates) do
+    for _, weight in ipairs(penalty_weights) do
 
-                       trainer.weightDecay = weight
-                       print("Learning with learning_rate: " .. learning_rate 
-                               .. " | l2_weight: " .. weight)
+      trainer.weightDecay = weight
+      print("Learning with learning_rate: " .. learning_rate
+          .. " | l2_weight: " .. weight)
 
-                       linear_module.weight[1] = init_weights(all_symbols, original_symbol_scores)
+      linear_module.weight[1] = init_weights(all_symbols, original_symbol_scores)
 
-                       trainer.learningRate = learning_rate
-                       trainer:train(dataset)
+      trainer.learningRate = learning_rate
+      trainer:train(dataset)
 
-                       print()
-               end   
-       end
+      print()
+    end
+  end
 
-       -- End perceptron training
+  -- End perceptron training
 
-       local new_symbol_scores = best_weights
+  local new_symbol_scores = best_weights
 
-       new_symbol_scores = stitch_new_scores(all_symbols, new_symbol_scores)
+  new_symbol_scores = stitch_new_scores(all_symbols, new_symbol_scores)
 
-       if res["output"] then
-               write_scores(new_symbol_scores, res["output"])
-       end
+  if res["output"] then
+    write_scores(new_symbol_scores, res["output"])
+  end
 
-       if res["diff"] then
-               print_score_diff(new_symbol_scores, original_symbol_scores)
-       end
+  if opts["diff"] then
+    print_score_diff(new_symbol_scores, original_symbol_scores)
+  end
 
 
-       -- Pre-rescore test stats
-       print("\n\nPre-rescore test stats\n")
-       test_logs = update_logs(test_logs, original_symbol_scores)
-       print_stats(test_logs, res['threshold'])
+  -- Pre-rescore test stats
+  io.write("\n\nPre-rescore test stats\n\n")
+  test_logs = update_logs(test_logs, original_symbol_scores)
+  print_stats(test_logs, threshold)
 
-       -- Post-rescore test stats
-       test_logs = update_logs(test_logs, new_symbol_scores)
-       print("\n\nPost-rescore test stats\n")
-       print_stats(test_logs, res['threshold'])
+  -- Post-rescore test stats
+  test_logs = update_logs(test_logs, new_symbol_scores)
+  io.write("\n\nPost-rescore test stats\n\n")
+  print_stats(test_logs, threshold)
 end
\ No newline at end of file
index 2390fc5656a7d81cb142d7d7d59635876bafade9..db79bbf7b79d9e0d622bfe2d469658686d030e06 100644 (file)
-local ucl = require "ucl"
 local lua_util = require "lua_util"
 local rspamd_util = require "rspamd_util"
+local fun = require "fun"
 
 local utility = {}
 
-function utility.round(num, places)
-   return string.format("%." .. (places or 0) .. "f", num)
-end
-
 function utility.get_all_symbols(logs)
-   -- Returns a list of all symbols
+  -- Returns a list of all symbols
 
-       local symbols_set = {}
+  local symbols_set = {}
 
-       for _, line in pairs(logs) do
-               line = lua_util.rspamd_str_split(line, " ")
-               for i=4,#line do
-                       line[i] = line[i]:gsub("%s+", "")
-                       if not symbols_set[line[i]] then
-                               symbols_set[line[i]] = true
-                       end
-               end
-       end
+  for _, line in pairs(logs) do
+    line = lua_util.rspamd_str_split(line, " ")
+    for i=4,#line do
+      line[i] = line[i]:gsub("%s+", "")
+      if not symbols_set[line[i]] then
+        symbols_set[line[i]] = true
+      end
+    end
+  end
 
-       local all_symbols = {}
+  local all_symbols = {}
 
-       for symbol, _ in pairs(symbols_set) do
-               all_symbols[#all_symbols + 1] = symbol
-       end
+  for symbol, _ in pairs(symbols_set) do
+    all_symbols[#all_symbols + 1] = symbol
+  end
 
-       table.sort(all_symbols)
+  table.sort(all_symbols)
 
-       return all_symbols
+  return all_symbols
 end
 
 function utility.read_log_file(file)
 
-       local lines = {}
+  local lines = {}
 
-       file = assert(io.open(file, "r"))
+  file = assert(io.open(file, "r"))
 
-       for line in file:lines() do
-               lines[#lines + 1] = line
-       end
+  for line in file:lines() do
+    lines[#lines + 1] = line
+  end
 
-       io.close(file)
+  io.close(file)
 
-       return lines
+  return lines
 end
 
 function utility.get_all_logs(dir_path)
-   -- Reads all log files in the directory and returns a list of logs.
+  -- Reads all log files in the directory and returns a list of logs.
 
-       if dir_path:sub(#dir_path, #dir_path) == "/" then
-               dir_path = dir_path:sub(1, #dir_path -1)
-       end
+  if dir_path:sub(#dir_path, #dir_path) == "/" then
+    dir_path = dir_path:sub(1, #dir_path -1)
+  end
 
-       local files = rspamd_util.glob(dir_path .. "/*")
-       local all_logs = {}
+  local files = rspamd_util.glob(dir_path .. "/*")
+  local all_logs = {}
 
-       for _, file in pairs(files) do
-               local logs = utility.read_log_file(file)
-               for _, log_line in pairs(logs) do
-                       all_logs[#all_logs + 1] = log_line
-               end      
-       end
+  for _, file in pairs(files) do
+    local logs = utility.read_log_file(file)
+    for _, log_line in pairs(logs) do
+      all_logs[#all_logs + 1] = log_line
+    end
+  end
 
-       return all_logs
+  return all_logs
 end
 
-function utility.get_all_symbol_scores(timeout)
-
-       local output = assert(io.popen("rspamc counters -j --compact -t " .. tostring(timeout)))
-       output = output:read("*all")
-
-       local parser = ucl.parser()
-       local result, err = parser:parse_string(output)
-
-       if not result then
-               print(err)
-               os.exit()
-       end
-
-       output = parser:get_object()
-
-       local symbol_scores = {}
-
-       for _, symbol_info in pairs(output) do
-               symbol_scores[symbol_info.symbol] = symbol_info.weight
-       end
+function utility.get_all_symbol_scores(conf)
+  local counters = conf:get_symbols_counters()
 
-       return symbol_scores
+  return fun.tomap(fun.map(function(elt)
+    return elt['symbol'],elt['weight']
+  end, counters))
 end
 
 function utility.generate_statistics_from_logs(logs, threshold)
 
-   -- Returns file_stats table and list of symbol_stats table.
-
-       local file_stats = {
-               no_of_emails = 0,
-               no_of_spam = 0,
-               no_of_ham = 0,
-               spam_percent = 0,
-               ham_percent = 0,
-               true_positives = 0,
-               true_negatives = 0,
-               false_negative_rate = 0,
-               false_positive_rate = 0,
-               overall_accuracy = 0,
-               fscore = 0
-       }
-
-       local all_symbols_stats = {}
-
-       local false_positives = 0
-       local false_negatives = 0
-       local true_positives = 0
-       local true_negatives = 0
-       local no_of_emails = 0
-       local no_of_spam = 0
-       local no_of_ham = 0
-
-       for _, log in pairs(logs) do
-               log = lua_util.rspamd_str_trim(log)
-               log = lua_util.rspamd_str_split(log, " ")
-
-               local is_spam = (log[1] == "SPAM")
-               local score = tonumber(log[2])
-
-               no_of_emails = no_of_emails + 1
-
-               if is_spam then
-                       no_of_spam = no_of_spam + 1
-               else
-                       no_of_ham = no_of_ham + 1       
-               end
-
-               if is_spam and (score >= threshold) then
-                       true_positives = true_positives + 1
-               elseif is_spam and (score < threshold) then
-                       false_negatives = false_negatives + 1
-               elseif not is_spam and (score >= threshold) then
-                       false_positives = false_positives + 1
-               else
-                       true_negatives = true_negatives + 1
-               end
-
-               for i=4, #log do   
-                       if all_symbols_stats[log[i]] == nil then
-                               all_symbols_stats[log[i]] = {
-                                       name = log[i],
-                                       no_of_hits = 0,
-                                       spam_hits = 0,
-                                       ham_hits = 0,
-                                       spam_overall = 0
-                               }
-                       end
-
-                       all_symbols_stats[log[i]].no_of_hits =
-                       all_symbols_stats[log[i]].no_of_hits + 1
-
-                       if is_spam then
-                               all_symbols_stats[log[i]].spam_hits =
-                               all_symbols_stats[log[i]].spam_hits + 1
-                       else
-                               all_symbols_stats[log[i]].ham_hits =
-                               all_symbols_stats[log[i]].ham_hits + 1
-                       end
-               end
-       end
-
-       -- Calculating file stats
-
-       file_stats.no_of_ham = no_of_ham
-       file_stats.no_of_spam = no_of_spam
-       file_stats.no_of_emails = no_of_emails
-       file_stats.true_positives = true_positives
-       file_stats.true_negatives = true_negatives
-
-       if no_of_emails > 0 then
-               file_stats.spam_percent = no_of_spam * 100 / no_of_emails
-               file_stats.ham_percent = no_of_ham * 100 / no_of_emails
-               file_stats.overall_accuracy = (true_positives + true_negatives) * 100 /
-               no_of_emails
-       end
-
-       if no_of_ham > 0 then
-               file_stats.false_positive_rate = false_positives * 100 / no_of_ham
-       end
-
-       if no_of_spam > 0 then
-               file_stats.false_negative_rate = false_negatives * 100 / no_of_spam
-       end
-
-       file_stats.fscore = 2 * true_positives / (2
-                                               * true_positives
-                                               + false_positives
-                                               + false_negatives)
-
-       -- Calculating symbol stats
-
-       for _, symbol_stats in pairs(all_symbols_stats) do
-               symbol_stats.spam_percent = symbol_stats.spam_hits * 100 / no_of_spam
-               symbol_stats.ham_percent = symbol_stats.ham_hits * 100 / no_of_ham
-               symbol_stats.overall = symbol_stats.no_of_hits * 100 / no_of_emails
-               symbol_stats.spam_overall = symbol_stats.spam_percent /
-               (symbol_stats.spam_percent + symbol_stats.ham_percent)
-       end
-
-       return file_stats, all_symbols_stats
+  -- Returns file_stats table and list of symbol_stats table.
+
+  local file_stats = {
+    no_of_emails = 0,
+    no_of_spam = 0,
+    no_of_ham = 0,
+    spam_percent = 0,
+    ham_percent = 0,
+    true_positives = 0,
+    true_negatives = 0,
+    false_negative_rate = 0,
+    false_positive_rate = 0,
+    overall_accuracy = 0,
+    fscore = 0
+  }
+
+  local all_symbols_stats = {}
+
+  local false_positives = 0
+  local false_negatives = 0
+  local true_positives = 0
+  local true_negatives = 0
+  local no_of_emails = 0
+  local no_of_spam = 0
+  local no_of_ham = 0
+
+  for _, log in pairs(logs) do
+    log = lua_util.rspamd_str_trim(log)
+    log = lua_util.rspamd_str_split(log, " ")
+
+    local is_spam = (log[1] == "SPAM")
+    local score = tonumber(log[2])
+
+    no_of_emails = no_of_emails + 1
+
+    if is_spam then
+      no_of_spam = no_of_spam + 1
+    else
+      no_of_ham = no_of_ham + 1
+    end
+
+    if is_spam and (score >= threshold) then
+      true_positives = true_positives + 1
+    elseif is_spam and (score < threshold) then
+      false_negatives = false_negatives + 1
+    elseif not is_spam and (score >= threshold) then
+      false_positives = false_positives + 1
+    else
+      true_negatives = true_negatives + 1
+    end
+
+    for i=4, #log do
+      if all_symbols_stats[log[i]] == nil then
+        all_symbols_stats[log[i]] = {
+          name = log[i],
+          no_of_hits = 0,
+          spam_hits = 0,
+          ham_hits = 0,
+          spam_overall = 0
+        }
+      end
+
+      all_symbols_stats[log[i]].no_of_hits =
+      all_symbols_stats[log[i]].no_of_hits + 1
+
+      if is_spam then
+        all_symbols_stats[log[i]].spam_hits =
+        all_symbols_stats[log[i]].spam_hits + 1
+      else
+        all_symbols_stats[log[i]].ham_hits =
+        all_symbols_stats[log[i]].ham_hits + 1
+      end
+    end
+  end
+
+  -- Calculating file stats
+
+  file_stats.no_of_ham = no_of_ham
+  file_stats.no_of_spam = no_of_spam
+  file_stats.no_of_emails = no_of_emails
+  file_stats.true_positives = true_positives
+  file_stats.true_negatives = true_negatives
+
+  if no_of_emails > 0 then
+    file_stats.spam_percent = no_of_spam * 100 / no_of_emails
+    file_stats.ham_percent = no_of_ham * 100 / no_of_emails
+    file_stats.overall_accuracy = (true_positives + true_negatives) * 100 /
+        no_of_emails
+  end
+
+  if no_of_ham > 0 then
+    file_stats.false_positive_rate = false_positives * 100 / no_of_ham
+  end
+
+  if no_of_spam > 0 then
+    file_stats.false_negative_rate = false_negatives * 100 / no_of_spam
+  end
+
+  file_stats.fscore = 2 * true_positives / (2
+      * true_positives
+      + false_positives
+      + false_negatives)
+
+  -- Calculating symbol stats
+
+  for _, symbol_stats in pairs(all_symbols_stats) do
+    symbol_stats.spam_percent = symbol_stats.spam_hits * 100 / no_of_spam
+    symbol_stats.ham_percent = symbol_stats.ham_hits * 100 / no_of_ham
+    symbol_stats.overall = symbol_stats.no_of_hits * 100 / no_of_emails
+    symbol_stats.spam_overall = symbol_stats.spam_percent /
+        (symbol_stats.spam_percent + symbol_stats.ham_percent)
+  end
+
+  return file_stats, all_symbols_stats
 end
 
 return utility
index bd8ccefbdbcebda2663883365680c23a8911626e..83ba980e02fd8a77b9a906009ddaebe3c9214a4e 100644 (file)
@@ -137,7 +137,7 @@ rspamadm_configwizard (gint argc, gchar **argv)
                /* Do post-load actions */
                rspamd_lua_post_load_config (cfg);
 
-               if (!rspamd_init_filters (rspamd_main->cfg, FALSE)) {
+               if (!rspamd_init_filters (cfg, FALSE)) {
                        ret = FALSE;
                }
 
index e0c69f13b2e10b6ca27bdd23e31407801b21ca0c..c88ca0350cebbe40a3230ad06623c5a2fc3e0df6 100644 (file)
 
 static gchar *logdir = NULL;
 static gchar *output = "new.scores";
-static gdouble threshold = 15; /* Spam threshold */
 static gboolean score_diff = false;  /* Print score diff flag */
-static gint64 iters = 500; /* Perceptron max iterations */
-gdouble timeout = 60.0;
-
-/* TODO: think about adding the config file reading */
+static gchar *config = NULL;
+extern struct rspamd_main *rspamd_main;
+/* Defined in modules.c */
+extern module_t *modules[];
+extern worker_t *workers[];
 
 static void rspamadm_rescore (gint argc, gchar **argv);
 
@@ -51,13 +51,29 @@ static GOptionEntry entries[] = {
                                "Scores output locaiton",                       NULL},
                {"diff",   'd', 0, G_OPTION_ARG_NONE,     &score_diff,
                                "Print score diff",                             NULL},
-               {"iters",  'i', 0, G_OPTION_ARG_INT64,    &iters,
-                               "Max iterations for perceptron [Default: 500]", NULL},
-               {"timeout", 't', 0, G_OPTION_ARG_DOUBLE, &timeout,
-                               "Timeout for connections [Default: 60]", NULL},
+               {"config", 'c', 0, G_OPTION_ARG_STRING, &config,
+                               "Config file to use",     NULL},
                {NULL,     0,   0, G_OPTION_ARG_NONE, NULL, NULL,       NULL}
 };
 
+static void
+config_logger (rspamd_mempool_t *pool, gpointer ud)
+{
+       struct rspamd_main *rm = ud;
+
+       rm->cfg->log_type = RSPAMD_LOG_CONSOLE;
+       rm->cfg->log_level = G_LOG_LEVEL_CRITICAL;
+
+       rspamd_set_logger (rm->cfg, g_quark_try_string ("main"), &rm->logger,
+                       rm->server_pool);
+
+       if (rspamd_log_open_priv (rm->logger, rm->workers_uid, rm->workers_gid) ==
+                       -1) {
+               fprintf (stderr, "Fatal error, cannot open logfile, exiting\n");
+               exit (EXIT_FAILURE);
+       }
+}
+
 static const char *
 rspamadm_rescore_help (gboolean full_help) {
 
@@ -70,8 +86,7 @@ rspamadm_rescore_help (gboolean full_help) {
                                "-l: path to logs directory\n"
                                "-o: scores output file location\n"
                                "-d: print scores diff\n"
-                               "-i: max iterations for perceptron\n"
-                               "-t: timeout for rspamc operations (default: 60)\n";
+                               "-i: max iterations for perceptron\n";
        } else {
                help_str = "Estimate optimal symbol weights from log files";
        }
@@ -85,7 +100,10 @@ rspamadm_rescore (gint argc, gchar **argv) {
        GOptionContext *context;
        GError *error = NULL;
        lua_State *L;
-       ucl_object_t *obj;
+       struct rspamd_config *cfg = rspamd_main->cfg, **pcfg;
+       gboolean ret = TRUE;
+       worker_t **pworker;
+       const gchar *confdir;
 
        context = g_option_context_new (
                        "rescore - estimate optimal symbol weights from log files");
@@ -116,30 +134,66 @@ rspamadm_rescore (gint argc, gchar **argv) {
                exit (EXIT_FAILURE);
        }
 
-       L = rspamd_lua_init ();
-       rspamd_lua_set_path (L, NULL, ucl_vars);
-
-       obj = ucl_object_typed_new (UCL_OBJECT);
-
-       ucl_object_insert_key (obj, ucl_object_fromstring (logdir),
-                       "logdir", 0, false);
-       ucl_object_insert_key (obj, ucl_object_fromstring (output),
-                       "output", 0, false);
-       ucl_object_insert_key (obj, ucl_object_fromdouble (threshold),
-                       "threshold", 0, false);
-       ucl_object_insert_key (obj, ucl_object_fromint (iters),
-                       "iters", 0, false);
-       ucl_object_insert_key (obj, ucl_object_frombool (score_diff),
-                       "diff", 0, false);
-       ucl_object_insert_key (obj, ucl_object_fromdouble (timeout),
-                       "timeout", 0, false);
-
-       rspamadm_execute_lua_ucl_subr (L,
-                       argc,
-                       argv,
-                       obj,
-                       "rescore");
-
-       lua_close (L);
-       ucl_object_unref (obj);
+       if (config == NULL) {
+               if ((confdir = g_hash_table_lookup (ucl_vars, "CONFDIR")) == NULL) {
+                       confdir = RSPAMD_CONFDIR;
+               }
+
+               config = g_strdup_printf ("%s%c%s", confdir, G_DIR_SEPARATOR,
+                               "rspamd.conf");
+       }
+
+       pworker = &workers[0];
+       while (*pworker) {
+               /* Init string quarks */
+               (void) g_quark_from_static_string ((*pworker)->name);
+               pworker++;
+       }
+
+       cfg->cache = rspamd_symbols_cache_new (cfg);
+       cfg->compiled_modules = modules;
+       cfg->compiled_workers = workers;
+       cfg->cfg_name = config;
+
+       if (!rspamd_config_read (cfg, cfg->cfg_name, NULL,
+                       config_logger, rspamd_main, ucl_vars)) {
+               ret = FALSE;
+       }
+       else {
+               /* Do post-load actions */
+               rspamd_lua_post_load_config (cfg);
+
+               if (!rspamd_init_filters (cfg, FALSE)) {
+                       ret = FALSE;
+               }
+
+               if (ret) {
+                       ret = rspamd_config_post_load (cfg, RSPAMD_CONFIG_INIT_SYMCACHE);
+                       rspamd_symbols_cache_validate (cfg->cache,
+                                       cfg,
+                                       FALSE);
+               }
+       }
+
+       if (ret) {
+               L = cfg->lua_state;
+               rspamd_lua_set_path (L, cfg->rcl_obj, ucl_vars);
+               ucl_object_insert_key (cfg->rcl_obj, ucl_object_fromstring (cfg->cfg_name),
+                               "config_path", 0, false);
+               ucl_object_insert_key (cfg->rcl_obj, ucl_object_fromstring (logdir),
+                               "logdir", 0, false);
+               ucl_object_insert_key (cfg->rcl_obj, ucl_object_fromstring (output),
+                               "output", 0, false);
+               ucl_object_insert_key (cfg->rcl_obj, ucl_object_frombool (score_diff),
+                               "diff", 0, false);
+               pcfg = lua_newuserdata (L, sizeof (struct rspamd_config *));
+               rspamd_lua_setclass (L, "rspamd{config}", -1);
+               *pcfg = cfg;
+               lua_setglobal (L, "rspamd_config");
+               rspamadm_execute_lua_ucl_subr (L,
+                               argc,
+                               argv,
+                               cfg->rcl_obj,
+                               "rescore");
+       }
 }
\ No newline at end of file