diff options
Diffstat (limited to 'lualib/rspamadm')
-rw-r--r-- | lualib/rspamadm/rescore.lua | 31 |
1 files changed, 10 insertions, 21 deletions
diff --git a/lualib/rspamadm/rescore.lua b/lualib/rspamadm/rescore.lua index dfa73f2d5..9b4d3a4ce 100644 --- a/lualib/rspamadm/rescore.lua +++ b/lualib/rspamadm/rescore.lua @@ -14,10 +14,6 @@ See the License for the specific language governing permissions and limitations under the License. ]]-- -if not rspamd_config:has_torch() then - return -end - local lua_util = require "lua_util" local ucl = require "ucl" local logger = require "rspamd_logger" @@ -26,9 +22,6 @@ local rspamd_util = require "rspamd_util" local argparse = require "argparse" local rescore_utility = require "rescore_utility" --- Load these lazily -local torch -local nn local opts local ignore_symbols = { @@ -137,17 +130,17 @@ parser:option "--l2" local function make_dataset_from_logs(logs, all_symbols, spam_score) -- Returns a list of {input, output} for torch SGD train - local dataset = {} + local inputs = {} + local outputs = {} for _, log in pairs(logs) do - local input = torch.Tensor(#all_symbols) - local output = torch.Tensor(1) + log = lua_util.rspamd_str_split(log, " ") if log[1] == "SPAM" then - output[1] = 1 + outputs[#outputs+1] = 1 else - output[1] = 0 + outputs[#outputs+1] = 0 end local symbols_set = {} @@ -158,23 +151,19 @@ local function make_dataset_from_logs(logs, all_symbols, spam_score) end end + local input_vec = {} for index, symbol in pairs(all_symbols) do if symbols_set[symbol] then - input[index] = 1 + input_vec[index] = 1 else - input[index] = 0 + input_vec[index] = 0 end end - dataset[#dataset + 1] = {input, output} - - end - - function dataset:size() - return #dataset + inputs[#inputs + 1] = input_vec end - return dataset + return inputs,outputs end local function init_weights(all_symbols, original_symbol_scores) |