diff options
-rw-r--r-- | src/plugins/lua/neural.lua | 363 |
1 files changed, 116 insertions, 247 deletions
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua index 032859d18..193b07614 100644 --- a/src/plugins/lua/neural.lua +++ b/src/plugins/lua/neural.lua @@ -26,9 +26,6 @@ local lua_redis = require "lua_redis" local lua_util = require "lua_util" local fun = require "fun" local meta_functions = require "lua_meta" -local use_torch = false -local torch -local nn local N = "neural" -- Module vars @@ -216,10 +213,7 @@ local function gen_ann_prefix(rule, id) local cksum = rspamd_config:get_symbols_cksum():hex() -- We also need to count metatokens: local n = meta_functions.rspamd_count_metatokens() - local tprefix = '' - if use_torch then - tprefix = 't'; - end + local tprefix = 'k' if id then return string.format('%s%s%s%d%s', tprefix, rule.prefix, cksum, n, id), id else @@ -229,27 +223,7 @@ end local function is_ann_valid(rule, prefix, ann) if ann then - local n = rspamd_config:get_symbols_count() + - meta_functions.rspamd_count_metatokens() - - if use_torch then - return true - else - if n ~= ann:get_inputs() then - rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of inputs: %s, %s symbols' .. - ' is found in the cache', prefix, ann:get_inputs(), n) - return false - end - local layers = ann:get_layers() - - if not layers or #layers ~= rule.nlayers then - rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of layers: %s', - prefix, #layers) - return false - end - - return true - end + return true end end @@ -275,28 +249,17 @@ local function ann_scores_filter(task) fun.each(function(e) table.insert(ann_data, e) end, mt) local score - if use_torch then - local out = rule.anns[id].ann:forward(torch.Tensor(ann_data)) - score = out[1] - else - local out = rule.anns[id].ann:test(ann_data) - score = out[1] - end + local out = rule.anns[id].ann:apply1(ann_data) + score = out[1] local symscore = string.format('%.3f', score) rspamd_logger.infox(task, '%s ann score: %s', rule.name, symscore) if score > 0 then local result = score - if not use_torch then - result = rspamd_util.normalize_prob(score / 2.0, 0) - end task:insert_result(rule.symbol_spam, result, symscore, id) else local result = -(score) - if not use_torch then - result = rspamd_util.normalize_prob(-(score) / 2.0, 0) - end task:insert_result(rule.symbol_ham, result, symscore, id) end end @@ -304,20 +267,13 @@ local function ann_scores_filter(task) end local function create_ann(n, nlayers) - if use_torch then - -- We ignore number of layers so far when using torch - local ann = nn.Sequential() - local nhidden = math.floor((n + 1) / 2) - ann:add(nn.NaN(nn.Identity())) - ann:add(nn.Linear(n, nhidden)) - ann:add(nn.PReLU()) - ann:add(nn.Linear(nhidden, 1)) - ann:add(nn.Tanh()) - - return ann - else - assert(false) - end + -- We ignore number of layers so far when using kann + local nhidden = math.floor((n + 1) / 2) + local t = rspamd_kann.layer.input(n) + t = rspamd_kann.transform.relu(t) + t = rspamd_kann.transform.tanh(rspamd_kann.layer.dense(t, nhidden)); + t = rspamd_kann.layer.cost(t, 1, rspamd_kann.cost.mse) + return rspamd_kann.new.kann(t) end local function create_train_ann(rule, n, id) @@ -364,11 +320,7 @@ local function load_or_invalidate_ann(rule, data, id, ev_base) rspamd_logger.errx(rspamd_config, 'cannot decompress ANN %s: %s', prefix, err) return else - if use_torch then - ann = torch.MemoryFile(torch.CharStorage():string(tostring(ann_data))):readObject() - else - assert(false) - end + ann = rspamd_kann.load(ann_data) end if is_ann_valid(rule, prefix, ann) then @@ -533,47 +485,9 @@ local function train_ann(rule, _, ev_base, elt, worker) ) else rspamd_logger.infox(rspamd_config, 'trained ANN %s: MSE: %s', - prefix, train_mse) - local ann_data - if use_torch then - local f = torch.MemoryFile() - f:writeObject(rule.anns[elt].ann_train) - ann_data = rspamd_util.zstd_compress(f:storage():string()) - else - ann_data = rspamd_util.zstd_compress(rule.anns[elt].ann_train:data()) - end - - rule.anns[elt].version = rule.anns[elt].version + 1 - rule.anns[elt].ann = rule.anns[elt].ann_train - rule.anns[elt].ann_train = nil - lua_redis.exec_redis_script(redis_save_unlock_id, - {ev_base = ev_base, is_write = true}, - redis_save_cb, - {prefix, tostring(ann_data), tostring(rule.ann_expire)}) - end - end - - local function ann_trained_torch(err, data) - rule.learning_spawned = false - if err then - rspamd_logger.errx(rspamd_config, 'cannot train ANN %s: %s', - prefix, err) - lua_redis.redis_make_request_taskless(ev_base, - rspamd_config, - rule.redis, - nil, - true, -- is write - redis_unlock_cb, --callback - 'DEL', -- command - {prefix .. '_locked'} - ) - else - rspamd_logger.infox(rspamd_config, 'trained ANN %s, %s bytes', - prefix, #data) - local ann_data - local f = torch.MemoryFile(torch.CharStorage():string(tostring(data))) - ann_data = rspamd_util.zstd_compress(f:storage():string()) - rule.anns[elt].ann_train = f:readObject() + prefix, train_mse) + local f = rule.anns[elt].ann_train:save() + local ann_data = rspamd_util.zstd_compress(f) rule.anns[elt].version = rule.anns[elt].version + 1 rule.anns[elt].ann = rule.anns[elt].ann_train @@ -608,12 +522,6 @@ local function train_ann(rule, _, ev_base, elt, worker) -- Now we need to join inputs and create the appropriate test vectors local n = rspamd_config:get_symbols_count() + meta_functions.rspamd_count_metatokens() - local filt = function(elts) - -- Basic sanity checks: vector has good length + there are no - -- 'bad' values such as NaNs or infinities in its elements - return #elts == n and - not fun.any(function(e) return e ~= e or e == math.huge or e == -math.huge end, elts) - end -- Now we can train ann if not rule.anns[elt] or not rule.anns[elt].ann_train then @@ -638,67 +546,44 @@ local function train_ann(rule, _, ev_base, elt, worker) redis_invalidate_cb, {prefix}) else - if use_torch then - -- For torch we do not need to mix samples as they would be flushed - local dataset = {} - fun.each(function(s) - table.insert(dataset, {torch.Tensor(s), torch.Tensor({1.0})}) - end, fun.filter(filt, spam_elts)) - fun.each(function(s) - table.insert(dataset, {torch.Tensor(s), torch.Tensor({-1.0})}) - end, fun.filter(filt, ham_elts)) - -- Needed for torch - dataset.size = function() return #dataset end - - local function train_torch() - if rule.train.learn_threads then - torch.setnumthreads(rule.train.learn_threads) - end - local criterion = nn.MSECriterion() - local trainer = nn.StochasticGradient(rule.anns[elt].ann_train, - criterion) - trainer.learning_rate = rule.train.learning_rate - trainer.verbose = false - trainer.maxIteration = rule.train.max_iterations - trainer.hookIteration = function(_, iteration, currentError) - rspamd_logger.infox(rspamd_config, "learned %s iterations, error: %s", - iteration, currentError) - end - trainer.logger = function(s) - rspamd_logger.infox(rspamd_config, 'training: %s', s) - end - trainer:train(dataset) - local out = torch.MemoryFile() - out:writeObject(rule.anns[elt].ann_train) - local st = out:storage():string() - return st + local inputs, outputs = {}, {} + + for _,e in ipairs(spam_elts) do + if e == e then + inputs[#inputs + 1] = e + outputs[#outputs + 1] = 1.0 + end + end + for _,e in ipairs(ham_elts) do + if e == e then + inputs[#inputs + 1] = e + outputs[#outputs + 1] = 0.0 end + end - rule.learning_spawned = true - worker:spawn_process{ - func = train_torch, - on_complete = ann_trained_torch, - } - else - local inputs = {} - local outputs = {} - - fun.each(function(spam_sample, ham_sample) - table.insert(inputs, spam_sample) - table.insert(outputs, {1.0}) - table.insert(inputs, ham_sample) - table.insert(outputs, {-1.0}) - end, fun.zip(fun.filter(filt, spam_elts), fun.filter(filt, ham_elts))) - rule.learning_spawned = true - rspamd_logger.infox(rspamd_config, 'start learning ANN %s', prefix) - rule.anns[elt].ann_train:train_threaded(inputs, outputs, ann_trained, - ev_base, { - max_epochs = rule.train.max_epoch, - desired_mse = rule.train.mse - }) + local function train() + rule.anns[elt].ann_train:train1(inputs, outputs, { + lr = rule.train.learning_rate, + max_epoch = rule.train.max_iterations, + cb = function(iter, train_cost, _) + if math.floor(iter / rule.train.max_iterations * 10) % 10 == 0 then + rspamd_logger.infox(rspamd_config, "learned %s iterations, error: %s", + iter, train_cost) + end + end + }) + + local out = rule.anns[elt].ann_train:save() + return tostring(out) end + rule.learning_spawned = true + + worker:spawn_process{ + func = train, + on_complete = ann_trained, + } end end end @@ -929,99 +814,83 @@ if not (opts and type(opts) == 'table') or not redis_params then return end -if not use_torch then - rspamd_logger.errx(rspamd_config, 'neural networks support is not compiled in rspamd, this ' .. - 'module is eventually disabled') - lua_util.disable_module(N, "fail") - return -else - local rules = opts['rules'] - - if not rules then - -- Use legacy configuration - rules = {} - rules['RFANN'] = opts - end +local rules = opts['rules'] - if opts.disable_torch then - use_torch = false - else - torch = require "torch" - nn = require "nn" +if not rules then + -- Use legacy configuration + rules = {} + rules['RFANN'] = opts +end - torch.setnumthreads(1) +local id = rspamd_config:register_symbol({ + name = 'NEURAL_CHECK', + type = 'postfilter,nostat', + priority = 6, + callback = ann_scores_filter +}) +for k,r in pairs(rules) do + local def_rules = lua_util.override_defaults(default_options, r) + def_rules['redis'] = redis_params + def_rules['anns'] = {} -- Store ANNs here + + if not def_rules.prefix then + def_rules.prefix = k end - - local id = rspamd_config:register_symbol({ - name = 'NEURAL_CHECK', - type = 'postfilter,nostat', - priority = 6, - callback = ann_scores_filter - }) - for k,r in pairs(rules) do - local def_rules = lua_util.override_defaults(default_options, r) - def_rules['redis'] = redis_params - def_rules['anns'] = {} -- Store ANNs here - - if not def_rules.prefix then - def_rules.prefix = k - end - if not def_rules.name then - def_rules.name = k - end - if def_rules.train.max_train then - def_rules.train.max_trains = def_rules.train.max_train - end - rspamd_logger.infox(rspamd_config, "register ann rule %s", k) - settings.rules[k] = def_rules - rspamd_config:set_metric_symbol({ - name = def_rules.symbol_spam, - score = 0.0, - description = 'Neural network SPAM', - group = 'neural' - }) - rspamd_config:register_symbol({ - name = def_rules.symbol_spam, - type = 'virtual,nostat', - parent = id - }) - - rspamd_config:set_metric_symbol({ - name = def_rules.symbol_ham, - score = -0.0, - description = 'Neural network HAM', - group = 'neural' - }) - rspamd_config:register_symbol({ - name = def_rules.symbol_ham, - type = 'virtual,nostat', - parent = id - }) + if not def_rules.name then + def_rules.name = k + end + if def_rules.train.max_train then + def_rules.train.max_trains = def_rules.train.max_train end + rspamd_logger.infox(rspamd_config, "register ann rule %s", k) + settings.rules[k] = def_rules + rspamd_config:set_metric_symbol({ + name = def_rules.symbol_spam, + score = 0.0, + description = 'Neural network SPAM', + group = 'neural' + }) + rspamd_config:register_symbol({ + name = def_rules.symbol_spam, + type = 'virtual,nostat', + parent = id + }) + rspamd_config:set_metric_symbol({ + name = def_rules.symbol_ham, + score = -0.0, + description = 'Neural network HAM', + group = 'neural' + }) rspamd_config:register_symbol({ - name = 'NEURAL_LEARN', - type = 'idempotent,nostat', - priority = 5, - callback = ann_push_vector + name = def_rules.symbol_ham, + type = 'virtual,nostat', + parent = id }) +end - -- Add training scripts - for _,rule in pairs(settings.rules) do - load_scripts(rule.redis) - rspamd_config:add_on_load(function(cfg, ev_base, worker) +rspamd_config:register_symbol({ + name = 'NEURAL_LEARN', + type = 'idempotent,nostat', + priority = 5, + callback = ann_push_vector +}) + +-- Add training scripts +for _,rule in pairs(settings.rules) do + load_scripts(rule.redis) + rspamd_config:add_on_load(function(cfg, ev_base, worker) + rspamd_config:add_periodic(ev_base, 0.0, + function(_, _) + return check_anns(rule, cfg, ev_base) + end) + + if worker:is_primary_controller() then + -- We also want to train neural nets when they have enough data rspamd_config:add_periodic(ev_base, 0.0, function(_, _) - return check_anns(rule, cfg, ev_base) + return maybe_train_anns(rule, cfg, ev_base, worker) end) - - if worker:is_primary_controller() then - -- We also want to train neural nets when they have enough data - rspamd_config:add_periodic(ev_base, 0.0, - function(_, _) - return maybe_train_anns(rule, cfg, ev_base, worker) - end) - end - end) - end + end + end) end |