summaryrefslogtreecommitdiffstats
path: root/lualib/rspamadm/corpus_test.lua
blob: eb93d586c681635c15a289c03d5057cfe9222b89 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
local rspamd_logger = require "rspamd_logger"
local ucl = require "ucl"
local lua_util = require "lua_util"

-- Corpus type labels. They are passed to scan_results_to_logs as the actual
-- email type, stored in each result's `type` field and written as the first
-- field of every log line.
local HAM = "HAM"
local SPAM = "SPAM"

--- Run the `rspamc` command-line client and capture its standard output.
-- The command is built as `rspamc -j --compact -n <n> -t <timeout> <path>`
-- and executed through the shell via io.popen.
-- @param n_parellel value passed to rspamc's `-n` option (formatted with %s)
-- @param path final argument of the command; NOTE(review): it is interpolated
--   into the shell command verbatim, so callers must pass a shell-safe value
-- @param timeout number passed to rspamc's `-t` option (formatted with %.3f)
-- @return string containing everything the command wrote to stdout
-- Raises an error if the process cannot be started.
local function scan_email(n_parellel, path, timeout)
    local rspamc_command = string.format("rspamc -j --compact -n %s -t %.3f %s",
        n_parellel, timeout, path)
    local pipe = assert(io.popen(rspamc_command))
    local output = pipe:read("*all")
    -- Close the pipe explicitly instead of leaving the handle open until it
    -- happens to be garbage collected.
    pipe:close()
    return output
end

--- Write scan results to a log file, one line per result.
-- Each line has the form `<type> <score> <action> [<symbol> ...]`, with the
-- score printed to two decimals, and is terminated by "\r\n".
-- @param results sequence of tables with `type`, `score`, `action` and
--   `symbols` (a sequence of strings) fields
-- @param file path of the file to write; opened in 'w' mode
-- Raises an error carrying io.open's message if the file cannot be opened.
local function write_results(results, file)
    -- Fail with io.open's own message (which names the file) rather than a
    -- later "attempt to index nil" on the handle.
    local f = assert(io.open(file, 'w'))

    -- ipairs keeps the entries and their symbols in sequence order.
    for _, result in ipairs(results) do
        local fields = {
            string.format("%s %.2f %s", result.type, result.score, result.action)
        }

        for _, sym in ipairs(result.symbols) do
            fields[#fields + 1] = sym
        end

        f:write(table.concat(fields, " "), "\r\n")
    end

    f:close()
end

--- Convert one JSON-encoded scan reply into a compact log record.
-- @param result string holding a single JSON document
-- @return table with `score`, `action` (whitespace runs replaced by "_") and
--   `symbols` (list of the symbol names found in the reply), or nil when the
--   string cannot be parsed or the reply carries no `action`; in both failure
--   cases a diagnostic is written to stderr
local function encoded_json_to_log(result)
    local parser = ucl.parser()

    local is_good, err = parser:parse_string(result)

    if not is_good then
        io.stderr:write(rspamd_logger.slog("Parser error: %1\n", err))
        return nil
    end

    local reply = parser:get_object()

    -- A valid document that is not an object (e.g. a bare number) cannot be
    -- indexed for fields, so it is reported the same way as a missing action.
    if type(reply) ~= 'table' or not reply.action then
        io.stderr:write(rspamd_logger.slog("Bad JSON: %1\n", reply))
        return nil
    end

    -- A reply may have an action but no symbols table; treat that as an
    -- empty symbol list instead of failing inside pairs().
    local reply_symbols = reply.symbols
    if type(reply_symbols) ~= 'table' then
        reply_symbols = {}
    end

    local symbols = {}
    for sym in pairs(reply_symbols) do
        symbols[#symbols + 1] = sym
    end

    return {
        score = reply.score,
        -- gsub returns (string, count); the parentheses keep only the string.
        action = (reply.action:gsub("%s+", "_")),
        symbols = symbols,
    }
end

--- Turn raw scanner output into a list of log records.
-- The output is split on "\n" and each piece is handed to
-- encoded_json_to_log; pieces for which it returns nil are dropped.
-- @param results string with one JSON document per line
-- @param actual_email_type value stored in the `type` field of every record
-- @return list of log records
local function scan_results_to_logs(results, actual_email_type)
    local lines = lua_util.rspamd_str_split(results, "\n")

    -- Drop the empty trailing piece, if the split produced one.
    local last = #lines
    if lines[last] == "" then
        lines[last] = nil
    end

    local logs = {}

    for _, line in pairs(lines) do
        local entry = encoded_json_to_log(line)
        if entry then
            entry.type = actual_email_type
            logs[#logs + 1] = entry
        end
    end

    return logs
end

-- Scans one corpus directory, appends its parsed results to `results` in
-- order, and returns how many records were appended.
local function scan_corpus(label, directory, email_type, connections, timeout, results)
    io.write(string.format("Scanning %s corpus...\n", label))

    local logs = scan_results_to_logs(
        scan_email(connections, directory, timeout), email_type)

    for _, entry in ipairs(logs) do
        results[#results + 1] = entry
    end

    return #logs
end

--- Entry point: scan the ham and/or spam corpus and write a combined log.
-- Reads `ham_directory`, `spam_directory`, `connections`, `timeout` and
-- `output_location` from `res`. Each directory that is set is scanned (ham
-- first, then spam); all results are written to `output_location` and a short
-- summary (elapsed seconds, ham count, spam count) is printed to stdout.
-- @param _ unused
-- @param res table of options as described above
return function (_, res)

    local connections = res.connections
    local timeout = res.timeout
    local output = res.output_location

    local results = {}

    local start_time = os.time()
    local no_of_ham = 0
    local no_of_spam = 0

    if res.ham_directory then
        no_of_ham = scan_corpus("ham", res.ham_directory, HAM,
            connections, timeout, results)
    end

    if res.spam_directory then
        no_of_spam = scan_corpus("spam", res.spam_directory, SPAM,
            connections, timeout, results)
    end

    io.write(string.format("Writing results to %s\n", output))
    write_results(results, output)

    io.write("\nStats: \n")
    io.write(string.format("Elapsed time: %ds\n", os.time() - start_time))
    io.write(string.format("No of ham: %d\n", no_of_ham))
    io.write(string.format("No of spam: %d\n", no_of_spam))

end