aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/plugins/lua/neural.lua363
1 files changed, 116 insertions, 247 deletions
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index 032859d18..193b07614 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -26,9 +26,6 @@ local lua_redis = require "lua_redis"
local lua_util = require "lua_util"
local fun = require "fun"
local meta_functions = require "lua_meta"
-local use_torch = false
-local torch
-local nn
local N = "neural"
-- Module vars
@@ -216,10 +213,7 @@ local function gen_ann_prefix(rule, id)
local cksum = rspamd_config:get_symbols_cksum():hex()
-- We also need to count metatokens:
local n = meta_functions.rspamd_count_metatokens()
- local tprefix = ''
- if use_torch then
- tprefix = 't';
- end
+ local tprefix = 'k'
if id then
return string.format('%s%s%s%d%s', tprefix, rule.prefix, cksum, n, id), id
else
@@ -229,27 +223,7 @@ end
local function is_ann_valid(rule, prefix, ann)
if ann then
- local n = rspamd_config:get_symbols_count() +
- meta_functions.rspamd_count_metatokens()
-
- if use_torch then
- return true
- else
- if n ~= ann:get_inputs() then
- rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of inputs: %s, %s symbols' ..
- ' is found in the cache', prefix, ann:get_inputs(), n)
- return false
- end
- local layers = ann:get_layers()
-
- if not layers or #layers ~= rule.nlayers then
- rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of layers: %s',
- prefix, #layers)
- return false
- end
-
- return true
- end
+ return true
end
end
@@ -275,28 +249,17 @@ local function ann_scores_filter(task)
fun.each(function(e) table.insert(ann_data, e) end, mt)
local score
- if use_torch then
- local out = rule.anns[id].ann:forward(torch.Tensor(ann_data))
- score = out[1]
- else
- local out = rule.anns[id].ann:test(ann_data)
- score = out[1]
- end
+ local out = rule.anns[id].ann:apply1(ann_data)
+ score = out[1]
local symscore = string.format('%.3f', score)
rspamd_logger.infox(task, '%s ann score: %s', rule.name, symscore)
if score > 0 then
local result = score
- if not use_torch then
- result = rspamd_util.normalize_prob(score / 2.0, 0)
- end
task:insert_result(rule.symbol_spam, result, symscore, id)
else
local result = -(score)
- if not use_torch then
- result = rspamd_util.normalize_prob(-(score) / 2.0, 0)
- end
task:insert_result(rule.symbol_ham, result, symscore, id)
end
end
@@ -304,20 +267,13 @@ local function ann_scores_filter(task)
end
local function create_ann(n, nlayers)
- if use_torch then
- -- We ignore number of layers so far when using torch
- local ann = nn.Sequential()
- local nhidden = math.floor((n + 1) / 2)
- ann:add(nn.NaN(nn.Identity()))
- ann:add(nn.Linear(n, nhidden))
- ann:add(nn.PReLU())
- ann:add(nn.Linear(nhidden, 1))
- ann:add(nn.Tanh())
-
- return ann
- else
- assert(false)
- end
+ -- We ignore number of layers so far when using kann
+ local nhidden = math.floor((n + 1) / 2)
+ local t = rspamd_kann.layer.input(n)
+ t = rspamd_kann.transform.relu(t)
+ t = rspamd_kann.transform.tanh(rspamd_kann.layer.dense(t, nhidden));
+ t = rspamd_kann.layer.cost(t, 1, rspamd_kann.cost.mse)
+ return rspamd_kann.new.kann(t)
end
local function create_train_ann(rule, n, id)
@@ -364,11 +320,7 @@ local function load_or_invalidate_ann(rule, data, id, ev_base)
rspamd_logger.errx(rspamd_config, 'cannot decompress ANN %s: %s', prefix, err)
return
else
- if use_torch then
- ann = torch.MemoryFile(torch.CharStorage():string(tostring(ann_data))):readObject()
- else
- assert(false)
- end
+ ann = rspamd_kann.load(ann_data)
end
if is_ann_valid(rule, prefix, ann) then
@@ -533,47 +485,9 @@ local function train_ann(rule, _, ev_base, elt, worker)
)
else
rspamd_logger.infox(rspamd_config, 'trained ANN %s: MSE: %s',
- prefix, train_mse)
- local ann_data
- if use_torch then
- local f = torch.MemoryFile()
- f:writeObject(rule.anns[elt].ann_train)
- ann_data = rspamd_util.zstd_compress(f:storage():string())
- else
- ann_data = rspamd_util.zstd_compress(rule.anns[elt].ann_train:data())
- end
-
- rule.anns[elt].version = rule.anns[elt].version + 1
- rule.anns[elt].ann = rule.anns[elt].ann_train
- rule.anns[elt].ann_train = nil
- lua_redis.exec_redis_script(redis_save_unlock_id,
- {ev_base = ev_base, is_write = true},
- redis_save_cb,
- {prefix, tostring(ann_data), tostring(rule.ann_expire)})
- end
- end
-
- local function ann_trained_torch(err, data)
- rule.learning_spawned = false
- if err then
- rspamd_logger.errx(rspamd_config, 'cannot train ANN %s: %s',
- prefix, err)
- lua_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
- rule.redis,
- nil,
- true, -- is write
- redis_unlock_cb, --callback
- 'DEL', -- command
- {prefix .. '_locked'}
- )
- else
- rspamd_logger.infox(rspamd_config, 'trained ANN %s, %s bytes',
- prefix, #data)
- local ann_data
- local f = torch.MemoryFile(torch.CharStorage():string(tostring(data)))
- ann_data = rspamd_util.zstd_compress(f:storage():string())
- rule.anns[elt].ann_train = f:readObject()
+ prefix, train_mse)
+ local f = rule.anns[elt].ann_train:save()
+ local ann_data = rspamd_util.zstd_compress(f)
rule.anns[elt].version = rule.anns[elt].version + 1
rule.anns[elt].ann = rule.anns[elt].ann_train
@@ -608,12 +522,6 @@ local function train_ann(rule, _, ev_base, elt, worker)
-- Now we need to join inputs and create the appropriate test vectors
local n = rspamd_config:get_symbols_count() +
meta_functions.rspamd_count_metatokens()
- local filt = function(elts)
- -- Basic sanity checks: vector has good length + there are no
- -- 'bad' values such as NaNs or infinities in its elements
- return #elts == n and
- not fun.any(function(e) return e ~= e or e == math.huge or e == -math.huge end, elts)
- end
-- Now we can train ann
if not rule.anns[elt] or not rule.anns[elt].ann_train then
@@ -638,67 +546,44 @@ local function train_ann(rule, _, ev_base, elt, worker)
redis_invalidate_cb,
{prefix})
else
- if use_torch then
- -- For torch we do not need to mix samples as they would be flushed
- local dataset = {}
- fun.each(function(s)
- table.insert(dataset, {torch.Tensor(s), torch.Tensor({1.0})})
- end, fun.filter(filt, spam_elts))
- fun.each(function(s)
- table.insert(dataset, {torch.Tensor(s), torch.Tensor({-1.0})})
- end, fun.filter(filt, ham_elts))
- -- Needed for torch
- dataset.size = function() return #dataset end
-
- local function train_torch()
- if rule.train.learn_threads then
- torch.setnumthreads(rule.train.learn_threads)
- end
- local criterion = nn.MSECriterion()
- local trainer = nn.StochasticGradient(rule.anns[elt].ann_train,
- criterion)
- trainer.learning_rate = rule.train.learning_rate
- trainer.verbose = false
- trainer.maxIteration = rule.train.max_iterations
- trainer.hookIteration = function(_, iteration, currentError)
- rspamd_logger.infox(rspamd_config, "learned %s iterations, error: %s",
- iteration, currentError)
- end
- trainer.logger = function(s)
- rspamd_logger.infox(rspamd_config, 'training: %s', s)
- end
- trainer:train(dataset)
- local out = torch.MemoryFile()
- out:writeObject(rule.anns[elt].ann_train)
- local st = out:storage():string()
- return st
+ local inputs, outputs = {}, {}
+
+ for _,e in ipairs(spam_elts) do
+ if e == e then
+ inputs[#inputs + 1] = e
+ outputs[#outputs + 1] = 1.0
+ end
+ end
+ for _,e in ipairs(ham_elts) do
+ if e == e then
+ inputs[#inputs + 1] = e
+ outputs[#outputs + 1] = 0.0
end
+ end
- rule.learning_spawned = true
- worker:spawn_process{
- func = train_torch,
- on_complete = ann_trained_torch,
- }
- else
- local inputs = {}
- local outputs = {}
-
- fun.each(function(spam_sample, ham_sample)
- table.insert(inputs, spam_sample)
- table.insert(outputs, {1.0})
- table.insert(inputs, ham_sample)
- table.insert(outputs, {-1.0})
- end, fun.zip(fun.filter(filt, spam_elts), fun.filter(filt, ham_elts)))
- rule.learning_spawned = true
- rspamd_logger.infox(rspamd_config, 'start learning ANN %s', prefix)
- rule.anns[elt].ann_train:train_threaded(inputs, outputs, ann_trained,
- ev_base, {
- max_epochs = rule.train.max_epoch,
- desired_mse = rule.train.mse
- })
+ local function train()
+ rule.anns[elt].ann_train:train1(inputs, outputs, {
+ lr = rule.train.learning_rate,
+ max_epoch = rule.train.max_iterations,
+ cb = function(iter, train_cost, _)
+ if math.floor(iter / rule.train.max_iterations * 10) % 10 == 0 then
+ rspamd_logger.infox(rspamd_config, "learned %s iterations, error: %s",
+ iter, train_cost)
+ end
+ end
+ })
+
+ local out = rule.anns[elt].ann_train:save()
+ return tostring(out)
end
+ rule.learning_spawned = true
+
+ worker:spawn_process{
+ func = train,
+ on_complete = ann_trained,
+ }
end
end
end
@@ -929,99 +814,83 @@ if not (opts and type(opts) == 'table') or not redis_params then
return
end
-if not use_torch then
- rspamd_logger.errx(rspamd_config, 'neural networks support is not compiled in rspamd, this ' ..
- 'module is eventually disabled')
- lua_util.disable_module(N, "fail")
- return
-else
- local rules = opts['rules']
-
- if not rules then
- -- Use legacy configuration
- rules = {}
- rules['RFANN'] = opts
- end
+local rules = opts['rules']
- if opts.disable_torch then
- use_torch = false
- else
- torch = require "torch"
- nn = require "nn"
+if not rules then
+ -- Use legacy configuration
+ rules = {}
+ rules['RFANN'] = opts
+end
- torch.setnumthreads(1)
+local id = rspamd_config:register_symbol({
+ name = 'NEURAL_CHECK',
+ type = 'postfilter,nostat',
+ priority = 6,
+ callback = ann_scores_filter
+})
+for k,r in pairs(rules) do
+ local def_rules = lua_util.override_defaults(default_options, r)
+ def_rules['redis'] = redis_params
+ def_rules['anns'] = {} -- Store ANNs here
+
+ if not def_rules.prefix then
+ def_rules.prefix = k
end
-
- local id = rspamd_config:register_symbol({
- name = 'NEURAL_CHECK',
- type = 'postfilter,nostat',
- priority = 6,
- callback = ann_scores_filter
- })
- for k,r in pairs(rules) do
- local def_rules = lua_util.override_defaults(default_options, r)
- def_rules['redis'] = redis_params
- def_rules['anns'] = {} -- Store ANNs here
-
- if not def_rules.prefix then
- def_rules.prefix = k
- end
- if not def_rules.name then
- def_rules.name = k
- end
- if def_rules.train.max_train then
- def_rules.train.max_trains = def_rules.train.max_train
- end
- rspamd_logger.infox(rspamd_config, "register ann rule %s", k)
- settings.rules[k] = def_rules
- rspamd_config:set_metric_symbol({
- name = def_rules.symbol_spam,
- score = 0.0,
- description = 'Neural network SPAM',
- group = 'neural'
- })
- rspamd_config:register_symbol({
- name = def_rules.symbol_spam,
- type = 'virtual,nostat',
- parent = id
- })
-
- rspamd_config:set_metric_symbol({
- name = def_rules.symbol_ham,
- score = -0.0,
- description = 'Neural network HAM',
- group = 'neural'
- })
- rspamd_config:register_symbol({
- name = def_rules.symbol_ham,
- type = 'virtual,nostat',
- parent = id
- })
+ if not def_rules.name then
+ def_rules.name = k
+ end
+ if def_rules.train.max_train then
+ def_rules.train.max_trains = def_rules.train.max_train
end
+ rspamd_logger.infox(rspamd_config, "register ann rule %s", k)
+ settings.rules[k] = def_rules
+ rspamd_config:set_metric_symbol({
+ name = def_rules.symbol_spam,
+ score = 0.0,
+ description = 'Neural network SPAM',
+ group = 'neural'
+ })
+ rspamd_config:register_symbol({
+ name = def_rules.symbol_spam,
+ type = 'virtual,nostat',
+ parent = id
+ })
+ rspamd_config:set_metric_symbol({
+ name = def_rules.symbol_ham,
+ score = -0.0,
+ description = 'Neural network HAM',
+ group = 'neural'
+ })
rspamd_config:register_symbol({
- name = 'NEURAL_LEARN',
- type = 'idempotent,nostat',
- priority = 5,
- callback = ann_push_vector
+ name = def_rules.symbol_ham,
+ type = 'virtual,nostat',
+ parent = id
})
+end
- -- Add training scripts
- for _,rule in pairs(settings.rules) do
- load_scripts(rule.redis)
- rspamd_config:add_on_load(function(cfg, ev_base, worker)
+rspamd_config:register_symbol({
+ name = 'NEURAL_LEARN',
+ type = 'idempotent,nostat',
+ priority = 5,
+ callback = ann_push_vector
+})
+
+-- Add training scripts
+for _,rule in pairs(settings.rules) do
+ load_scripts(rule.redis)
+ rspamd_config:add_on_load(function(cfg, ev_base, worker)
+ rspamd_config:add_periodic(ev_base, 0.0,
+ function(_, _)
+ return check_anns(rule, cfg, ev_base)
+ end)
+
+ if worker:is_primary_controller() then
+ -- We also want to train neural nets when they have enough data
rspamd_config:add_periodic(ev_base, 0.0,
function(_, _)
- return check_anns(rule, cfg, ev_base)
+ return maybe_train_anns(rule, cfg, ev_base, worker)
end)
-
- if worker:is_primary_controller() then
- -- We also want to train neural nets when they have enough data
- rspamd_config:add_periodic(ev_base, 0.0,
- function(_, _)
- return maybe_train_anns(rule, cfg, ev_base, worker)
- end)
- end
- end)
- end
+ end
+ end)
end