summary refs log tree commit diff stats
path: root/lualib/rspamadm
diff options
context:
space:
mode:
Diffstat (limited to 'lualib/rspamadm')
-rw-r--r-- lualib/rspamadm/rescore.lua 31
1 files changed, 10 insertions, 21 deletions
diff --git a/lualib/rspamadm/rescore.lua b/lualib/rspamadm/rescore.lua
index dfa73f2d5..9b4d3a4ce 100644
--- a/lualib/rspamadm/rescore.lua
+++ b/lualib/rspamadm/rescore.lua
@@ -14,10 +14,6 @@ See the License for the specific language governing permissions and
limitations under the License.
]]--
-if not rspamd_config:has_torch() then
- return
-end
-
local lua_util = require "lua_util"
local ucl = require "ucl"
local logger = require "rspamd_logger"
@@ -26,9 +22,6 @@ local rspamd_util = require "rspamd_util"
local argparse = require "argparse"
local rescore_utility = require "rescore_utility"
--- Load these lazily
-local torch
-local nn
local opts
local ignore_symbols = {
@@ -137,17 +130,17 @@ parser:option "--l2"
local function make_dataset_from_logs(logs, all_symbols, spam_score)
-- Returns a list of {input, output} for torch SGD train
- local dataset = {}
+ local inputs = {}
+ local outputs = {}
for _, log in pairs(logs) do
- local input = torch.Tensor(#all_symbols)
- local output = torch.Tensor(1)
+
log = lua_util.rspamd_str_split(log, " ")
if log[1] == "SPAM" then
- output[1] = 1
+ outputs[#outputs+1] = 1
else
- output[1] = 0
+ outputs[#outputs+1] = 0
end
local symbols_set = {}
@@ -158,23 +151,19 @@ local function make_dataset_from_logs(logs, all_symbols, spam_score)
end
end
+ local input_vec = {}
for index, symbol in pairs(all_symbols) do
if symbols_set[symbol] then
- input[index] = 1
+ input_vec[index] = 1
else
- input[index] = 0
+ input_vec[index] = 0
end
end
- dataset[#dataset + 1] = {input, output}
-
- end
-
- function dataset:size()
- return #dataset
+ inputs[#inputs + 1] = input_vec
end
- return dataset
+ return inputs,outputs
end
local function init_weights(all_symbols, original_symbol_scores)