local rspamd_redis = require "lua_redis"
local fun = require "fun"
local meta_functions = require "meta_functions"
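+-- Optional torch backend: loaded only when rspamd is built with torch support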
+local use_torch = false
+local torch
+local nn
+
+if rspamd_config:has_torch() then
+ use_torch = true
+ torch = require "torch"
+ nn = require "nn"
+end
-- Module vars
local default_options = {
local cksum = rspamd_config:get_symbols_cksum():hex()
-- We also need to count metatokens:
local n = meta_functions.rspamd_count_metatokens()
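+ -- Prefix torch ANN keys with 't' so they are not confused with libfann data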
+ local tprefix = ''
+ if use_torch then
+ tprefix = 't'
+ end
if id then
- return string.format('%s%s%d%s', rule.prefix, cksum, n, id),
+ return string.format('%s%s%s%d%s', tprefix, rule.prefix, cksum, n, id),
rule.prefix .. id
else
- return string.format('%s%s%d', rule.prefix, cksum, n), nil
+ return string.format('%s%s%s%d', tprefix, rule.prefix, cksum, n), nil
end
end
local n = rspamd_config:get_symbols_count() +
meta_functions.rspamd_count_metatokens()
- if n ~= ann:get_inputs() then
- rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of inputs: %s, %s symbols' ..
- ' is found in the cache', prefix, ann:get_inputs(), n)
- return false
- end
- local layers = ann:get_layers()
+ if torch then
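+ -- Validate a torch nn.Sequential model: module count and first-layer input size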
+ local nlayers = #ann
+ if nlayers ~= rule.nlayers then
+ rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of layers: %s',
+ prefix, nlayers)
+ return false
+ end
- if not layers or #layers ~= rule.nlayers then
- rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of layers: %s',
- prefix, #layers)
- return false
- end
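+ -- Expected input count, taken from the first layer of the network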
+ local inp = ann:get(1):nElement()
+ if n ~= inp then
+ rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of inputs: %s, while %s symbols' ..
+ ' are in the cache', prefix, inp, n)
+ return false
+ end
+ else
+ if n ~= ann:get_inputs() then
+ rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of inputs: %s, while %s symbols' ..
+ ' are in the cache', prefix, ann:get_inputs(), n)
+ return false
+ end
+ local layers = ann:get_layers()
- return true
+ if not layers or #layers ~= rule.nlayers then
+ rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of layers: %s',
+ prefix, #layers)
+ return false
+ end
+ end
+
+ return true
end
end
-- Add filtered meta tokens
fun.each(function(e) table.insert(fann_data, e) end, mt)
- local out = fanns[id].fann:test(fann_data)
- local symscore = string.format('%.3f', out[1])
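+ -- Torch models are evaluated with forward() on a tensor; libfann models with test()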
+ local score
+ if torch then
+ local out = fanns[id].fann:forward(torch.Tensor(fann_data))
+ score = out[1]
+ else
+ local out = fanns[id].fann:test(fann_data)
+ score = out[1]
+ end
+
+ local symscore = string.format('%.3f', score)
rspamd_logger.infox(task, 'fann score: %s', symscore)
- if out[1] > 0 then
- local result = rspamd_util.normalize_prob(out[1] / 2.0, 0)
+ if score > 0 then
+ local result = rspamd_util.normalize_prob(score / 2.0, 0)
task:insert_result(rule.symbol_spam, result, symscore, id)
else
- local result = rspamd_util.normalize_prob((-out[1]) / 2.0, 0)
+ local result = rspamd_util.normalize_prob((-score) / 2.0, 0)
task:insert_result(rule.symbol_ham, result, symscore, id)
end
end
end
local function create_fann(n, nlayers)
- local layers = {}
- local div = 1.0
- for _ = 1, nlayers - 1 do
- table.insert(layers, math.floor(n / div))
- div = div * 2
- end
- table.insert(layers, 1)
- return rspamd_fann.create(nlayers, layers)
+ if torch then
+ -- For now, the configured number of layers is ignored when using torch
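+ -- One hidden layer of (n + 1) / 2 neurons with PReLU activation and a single output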
+ local ann = nn.Sequential()
+ local nhidden = math.floor((n + 1) / 2)
+ ann:add(nn.Linear(n, nhidden))
+ ann:add(nn.PReLU())
+ ann:add(nn.Linear(nhidden, 1))
+
+ return ann
+ else
+ local layers = {}
+ local div = 1.0
+ for _ = 1, nlayers - 1 do
+ table.insert(layers, math.floor(n / div))
+ div = div * 2
+ end
+ table.insert(layers, 1)
+ return rspamd_fann.create(nlayers, layers)
+ end
end
local function create_train_fann(rule, n, id)
end
-- Fix this for a flexible number of layers
if fanns[id].fann then
- if n ~= fanns[id].fann:get_inputs() or --
- (fanns[id].fann_train and n ~= fanns[id].fann_train:get_inputs()) then
- rspamd_logger.infox(rspamd_config,
- 'recreate ANN %s as it has a wrong number of inputs, version %s',
- prefix,
- fanns[id].version)
-
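+ -- is_fann_valid now performs the input and layer checks for both backends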
+ if not is_fann_valid(rule, prefix, fanns[id].fann) then
fanns[id].fann_train = create_fann(n, rule.nlayers)
fanns[id].fann = nil
elseif fanns[id].version % rule.max_usages == 0 then
rspamd_logger.errx(rspamd_config, 'cannot decompress ANN %s: %s', prefix, err)
return
else
- ann = rspamd_fann.load_data(ann_data)
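+ -- Torch ANNs are stored as serialized objects; read them back through a memory file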
+ if torch then
+ ann = torch.MemoryFile(torch.CharStorage():string(tostring(ann_data))):readObject()
+ else
+ ann = rspamd_fann.load_data(ann_data)
+ end
end
if is_fann_valid(rule, prefix, ann) then
else
rspamd_logger.infox(rspamd_config, 'trained ANN %s: MSE: %s',
prefix, train_mse)
- local ann_data = rspamd_util.zstd_compress(fanns[elt].fann_train:data())
+ local ann_data
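+ -- Torch models are serialized through a memory file; libfann exports raw data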
+ if torch then
+ local f = torch.MemoryFile()
+ f:writeObject(fanns[elt].fann_train)
+ ann_data = rspamd_util.zstd_compress(f:storage():string())
+ else
+ ann_data = rspamd_util.zstd_compress(fanns[elt].fann_train:data())
+ end
+
fanns[elt].version = fanns[elt].version + 1
fanns[elt].fann = fanns[elt].fann_train
fanns[elt].fann_train = nil
end, fun.zip(fun.filter(filt, spam_elts), fun.filter(filt, ham_elts)))
-- Now we can train fann
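+ -- get_inputs() is libfann-only, so only check that a training ANN exists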
- if not fanns[elt] or not fanns[elt].fann_train
- or n ~= fanns[elt].fann_train:get_inputs() then
+ if not fanns[elt] or not fanns[elt].fann_train then
-- Create fann if it does not exist
create_train_fann(rule, n, elt)
end