--- /dev/null
+--[[
+Copyright (c) 2020, Vsevolod Stakhov <vsevolod@highsecure.ru>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local fun = require "fun"
+local lua_redis = require "lua_redis"
+local lua_settings = require "lua_settings"
+local lua_util = require "lua_util"
+local meta_functions = require "lua_meta"
+local rspamd_kann = require "rspamd_kann"
+local rspamd_logger = require "rspamd_logger"
+local rspamd_tensor = require "rspamd_tensor"
+local rspamd_util = require "rspamd_util"
+
+local N = 'neural'
+
+-- Used in prefix to avoid wrong ANN to be loaded
+local plugin_ver = '2'
+
+-- Module vars
+-- Per-rule defaults; user config is merged over these via lua_util.override_defaults.
+local default_options = {
+ train = {
+ max_trains = 1000,
+ max_epoch = 1000,
+ max_usages = 10,
+ max_iterations = 25, -- Torch style
+ mse = 0.001,
+ autotrain = true,
+ train_prob = 1.0, -- probability to use a sample for training (1.0 = always)
+ learn_threads = 1,
+ learn_mode = 'balanced', -- Possible values: balanced, proportional
+ learning_rate = 0.01,
+ classes_bias = 0.0, -- balanced mode: what difference is allowed between classes (1:1 proportion means 0 bias)
+ spam_skip_prob = 0.0, -- proportional mode: spam skip probability (0-1)
+ ham_skip_prob = 0.0, -- proportional mode: ham skip probability
+ store_pool_only = false, -- store tokens in cache only (disables autotrain);
+ -- neural_vec_mpack stores vector of training data in messagepack neural_profile_digest stores profile digest
+ },
+ watch_interval = 60.0,
+ lock_expire = 600, -- seconds after which a stale train lock may be taken over
+ learning_spawned = false,
+ ann_expire = 60 * 60 * 24 * 2, -- 2 days
+ hidden_layer_mult = 1.5, -- number of neurons in the hidden layer
+ symbol_spam = 'NEURAL_SPAM',
+ symbol_ham = 'NEURAL_HAM',
+ max_inputs = nil, -- when PCA is used
+ blacklisted_symbols = {}, -- list of symbols skipped in neural processing
+}
+
+-- Rule structure:
+-- * static config fields (see `default_options`)
+-- * prefix - name or defined prefix
+-- * settings - table of settings indexed by settings id, -1 is used when no settings defined
+
+-- Rule settings element defines elements for specific settings id:
+-- * symbols - static symbols profile (defined by config or extracted from symcache)
+-- * name - name of settings id
+-- * digest - digest of all symbols
+-- * ann - dynamic ANN configuration loaded from Redis
+-- * train - train data for ANN (e.g. the currently trained ANN)
+
+-- Settings ANN table is loaded from Redis and represents dynamic profile for ANN
+-- Some elements are directly stored in Redis, ANN is, in turn loaded dynamically
+-- * version - version of ANN loaded from redis
+-- * redis_key - name of ANN key in Redis
+-- * symbols - symbols in THIS PARTICULAR ANN (might be different from set.symbols)
+-- * distance - distance between set.symbols and set.ann.symbols
+-- * ann - kann object
+
+-- Module-wide settings; `rules` is filled from the module configuration and
+-- post-processed by process_rules_settings() at config post-init time.
+local settings = {
+ rules = {},
+ prefix = 'rn', -- Neural network default prefix
+ max_profiles = 3, -- Maximum number of NN profiles stored
+}
+
+-- Get module & Redis configuration
+local module_config = rspamd_config:get_all_opt(N)
+settings = lua_util.override_defaults(settings, module_config)
+local redis_params = lua_redis.parse_redis_server('neural')
+
+-- Lua script that checks if we can store a new training vector
+-- Uses the following keys:
+-- key1 - ann key
+-- returns nspam,nham (or nil if locked)
+-- NOTE(review): when the ANN is locked this script returns a 'host:timestamp'
+-- string, while the normal result is the {nspam, nham} array -- callers must
+-- check the reply type to distinguish the two cases.
+local redis_lua_script_vectors_len = [[
+ local prefix = KEYS[1]
+ local locked = redis.call('HGET', prefix, 'lock')
+ if locked then
+ local host = redis.call('HGET', prefix, 'hostname') or 'unknown'
+ return string.format('%s:%s', host, locked)
+ end
+ local nspam = 0
+ local nham = 0
+
+ local ret = redis.call('LLEN', prefix .. '_spam')
+ if ret then nspam = tonumber(ret) end
+ ret = redis.call('LLEN', prefix .. '_ham')
+ if ret then nham = tonumber(ret) end
+
+ return {nspam,nham}
+]]
+
+-- Lua script to invalidate ANNs by rank
+-- Uses the following keys
+-- key1 - prefix for keys
+-- key2 - number of elements to leave
+-- Drops the oldest profiles (lowest ZSET score, i.e. least recently saved)
+-- together with their ANN hashes and spam/ham train vector lists; returns the
+-- list of removed profile JSON blobs (empty table when under the limit).
+local redis_lua_script_maybe_invalidate = [[
+ local card = redis.call('ZCARD', KEYS[1])
+ local lim = tonumber(KEYS[2])
+ if card > lim then
+ local to_delete = redis.call('ZRANGE', KEYS[1], 0, card - lim - 1)
+ for _,k in ipairs(to_delete) do
+ local tb = cjson.decode(k)
+ redis.call('DEL', tb.redis_key)
+ -- Also train vectors
+ redis.call('DEL', tb.redis_key .. '_spam')
+ redis.call('DEL', tb.redis_key .. '_ham')
+ end
+ redis.call('ZREMRANGEBYRANK', KEYS[1], 0, card - lim - 1)
+ return to_delete
+ else
+ return {}
+ end
+]]
+
+-- Lua script to invalidate ANN from redis
+-- Uses the following keys
+-- key1 - prefix for keys
+-- key2 - current time
+-- key3 - key expire
+-- key4 - hostname
+-- Takes the train lock unless a non-expired lock is already held; a stale lock
+-- (older than KEYS[3] seconds) is silently overwritten.
+-- NOTE(review): KEYS[2..4] carry plain values (time, expiry, hostname) rather
+-- than key names; fine on standalone Redis, but it would break Redis Cluster
+-- key routing -- confirm cluster mode is out of scope here.
+local redis_lua_script_maybe_lock = [[
+ local locked = redis.call('HGET', KEYS[1], 'lock')
+ local now = tonumber(KEYS[2])
+ if locked then
+ locked = tonumber(locked)
+ local expire = tonumber(KEYS[3])
+ if now > locked and (now - locked) < expire then
+ return {tostring(locked), redis.call('HGET', KEYS[1], 'hostname') or 'unknown'}
+ end
+ end
+ redis.call('HSET', KEYS[1], 'lock', tostring(now))
+ redis.call('HSET', KEYS[1], 'hostname', KEYS[4])
+ return 1
+]]
+
+-- Lua script to save and unlock ANN in redis
+-- Uses the following keys
+-- key1 - prefix for ANN
+-- key2 - prefix for profile
+-- key3 - compressed ANN
+-- key4 - profile as JSON
+-- key5 - expire in seconds
+-- key6 - current time
+-- key7 - old key
+-- key8 - optional PCA
+-- Records the profile in the per-settings ZSET (scored by save time), stores the
+-- compressed ANN (and optional PCA), clears accumulated train vectors, releases
+-- the locks on both the new and the old ANN keys and sets the key expiry.
+local redis_lua_script_save_unlock = [[
+ local now = tonumber(KEYS[6])
+ redis.call('ZADD', KEYS[2], now, KEYS[4])
+ redis.call('HSET', KEYS[1], 'ann', KEYS[3])
+ redis.call('DEL', KEYS[1] .. '_spam')
+ redis.call('DEL', KEYS[1] .. '_ham')
+ redis.call('HDEL', KEYS[1], 'lock')
+ redis.call('HDEL', KEYS[7], 'lock')
+ redis.call('EXPIRE', KEYS[1], tonumber(KEYS[5]))
+ if KEYS[8] then
+ redis.call('HSET', KEYS[1], 'pca', KEYS[8])
+ end
+ return 1
+]]
+
+-- Ids of the scripts above once registered via lua_redis; filled by load_scripts()
+local redis_script_id = {}
+
+-- Registers the four Lua scripts in the lua_redis script cache and stores the
+-- returned ids in `redis_script_id` for later exec_redis_script() calls.
+local function load_scripts()
+ redis_script_id.vectors_len = lua_redis.add_redis_script(redis_lua_script_vectors_len,
+ redis_params)
+ redis_script_id.maybe_invalidate = lua_redis.add_redis_script(redis_lua_script_maybe_invalidate,
+ redis_params)
+ redis_script_id.maybe_lock = lua_redis.add_redis_script(redis_lua_script_maybe_lock,
+ redis_params)
+ redis_script_id.save_unlock = lua_redis.add_redis_script(redis_lua_script_save_unlock,
+ redis_params)
+end
+
+-- Builds a KANN network: input(n) -> ReLU -> dense(nhidden) -> single output
+-- trained with the ceb_neg cost (the training code feeds 1.0 for spam, -1.0 for
+-- ham, see spawn_train). `nlayers` is accepted for API compatibility but unused.
+local function create_ann(n, nlayers, rule)
+ -- We ignore number of layers so far when using kann
+ local nhidden = math.floor(n * (rule.hidden_layer_mult or 1.0) + 1.0)
+ local t = rspamd_kann.layer.input(n)
+ t = rspamd_kann.transform.relu(t)
+ t = rspamd_kann.layer.dense(t, nhidden);
+ t = rspamd_kann.layer.cost(t, 1, rspamd_kann.cost.ceb_neg)
+ return rspamd_kann.new.kann(t)
+end
+
+-- Fills ANN data for a specific settings element
+local function fill_set_ann(set, ann_key)
+ -- Initialise an empty dynamic ANN profile for this settings element;
+ -- an existing set.ann (e.g. one loaded from Redis) is left untouched.
+ if not set.ann then
+ set.ann = {
+ symbols = set.symbols,
+ distance = 0,
+ digest = set.digest,
+ redis_key = ann_key,
+ version = 0,
+ }
+ end
+end
+
+-- This function takes all inputs, applies PCA transformation and returns the final
+-- PCA matrix as rspamd_tensor
+local function learn_pca(inputs, max_inputs)
+ local scatter_matrix = rspamd_tensor.scatter_matrix(rspamd_tensor.fromtable(inputs))
+ local eigenvals = scatter_matrix:eigen()
+ -- eigen() appears to replace the scatter matrix rows with eigenvectors, and
+ -- the loop below takes the LAST max_inputs rows, i.e. it assumes rows are
+ -- ordered by ascending eigenvalue -- TODO(review): confirm in rspamd_tensor
+ lua_util.debugm(N, 'eigenvalues: %s', eigenvals)
+ -- 2-dimensional tensor: max_inputs rows of the original feature width
+ local w = rspamd_tensor.new(2, max_inputs, #scatter_matrix[1])
+ for i=1,max_inputs do
+ w[i] = scatter_matrix[#scatter_matrix - i + 1]
+ end
+
+ lua_util.debugm(N, 'pca matrix: %s', w)
+
+ return w
+end
+
+-- This function is intended to extend lock for ANN during training
+-- It registers periodic that increases locked key each 30 seconds unless
+-- `set.learning_spawned` is set to `true`
+local function register_lock_extender(rule, set, ev_base, ann_key)
+ -- Bumps the numeric 'lock' field by 30 every 30 seconds while training is in
+ -- flight; returning false from the periodic callback cancels further runs.
+ rspamd_config:add_periodic(ev_base, 30.0,
+ function()
+ local function redis_lock_extend_cb(_err, _)
+ if _err then
+ rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s',
+ ann_key, _err)
+ else
+ rspamd_logger.infox(rspamd_config, 'extend lock for ANN %s for 30 seconds',
+ ann_key)
+ end
+ end
+
+ if set.learning_spawned then
+ lua_redis.redis_make_request_taskless(ev_base,
+ rspamd_config,
+ rule.redis,
+ nil,
+ true, -- is write
+ redis_lock_extend_cb, --callback
+ 'HINCRBY', -- command
+ {ann_key, 'lock', '30'}
+ )
+ else
+ lua_util.debugm(N, rspamd_config, "stop lock extension as learning_spawned is false")
+ return false -- do not plan any more updates
+ end
+
+ return true
+ end
+ )
+end
+
+-- Decides whether a training vector of class `learn_type` ('spam' or 'ham') may
+-- be stored, given the current per-class counters nspam/nham and rule.train
+-- options. Returns true to push the vector, false to skip it.
+local function can_push_train_vector(rule, task, learn_type, nspam, nham)
+ local train_opts = rule.train
+ local coin = math.random()
+
+ if train_opts.train_prob and coin < 1.0 - train_opts.train_prob then
+ rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin)
+ return false
+ end
+
+ if train_opts.learn_mode == 'balanced' then
+ -- Keep balanced training set based on number of spam and ham samples
+ if learn_type == 'spam' then
+ if nspam <= train_opts.max_trains then
+ if nspam > nham then
+ -- Apply sampling
+ local skip_rate = 1.0 - nham / (nspam + 1)
+ if coin < skip_rate - train_opts.classes_bias then
+ rspamd_logger.infox(task,
+ 'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored',
+ learn_type,
+ skip_rate - train_opts.classes_bias,
+ nspam, nham)
+ return false
+ end
+ end
+ return true
+ else -- Enough learns
+ rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many spam samples: %s',
+ learn_type,
+ nspam)
+ end
+ else
+ if nham <= train_opts.max_trains then
+ if nham > nspam then
+ -- Apply sampling
+ local skip_rate = 1.0 - nspam / (nham + 1)
+ if coin < skip_rate - train_opts.classes_bias then
+ rspamd_logger.infox(task,
+ 'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored',
+ learn_type,
+ skip_rate - train_opts.classes_bias,
+ nspam, nham)
+ return false
+ end
+ end
+ return true
+ else
+ rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many ham samples: %s', learn_type,
+ nham)
+ end
+ end
+ else
+ -- Probabilistic learn mode, we just skip learn if we already have enough samples or
+ -- if our coin drop is less than desired probability
+ -- NOTE(review): if spam_skip_prob/ham_skip_prob were unset this branch would
+ -- fall through to `return false`; the defaults (0.0) keep them always present
+ if learn_type == 'spam' then
+ if nspam <= train_opts.max_trains then
+ if train_opts.spam_skip_prob then
+ if coin <= train_opts.spam_skip_prob then
+ rspamd_logger.infox(task, 'skip %s sample probabilisticaly; probability %s (%s skip chance)', learn_type,
+ coin, train_opts.spam_skip_prob)
+ return false
+ end
+
+ return true
+ end
+ else
+ rspamd_logger.infox(task, 'skip %s sample; too many spam samples: %s (%s limit)', learn_type,
+ nspam, train_opts.max_trains)
+ end
+ else
+ if nham <= train_opts.max_trains then
+ if train_opts.ham_skip_prob then
+ if coin <= train_opts.ham_skip_prob then
+ rspamd_logger.infox(task, 'skip %s sample probabilisticaly; probability %s (%s skip chance)', learn_type,
+ coin, train_opts.ham_skip_prob)
+ return false
+ end
+
+ return true
+ end
+ else
+ rspamd_logger.infox(task, 'skip %s sample; too many ham samples: %s (%s limit)', learn_type,
+ nham, train_opts.max_trains)
+ end
+ end
+ end
+
+ return false
+end
+
+-- Closure generator for unlock function
+-- Returns a Redis callback that logs the outcome of an HDEL lock release
+local function gen_unlock_cb(rule, set, ann_key)
+ return function (err)
+ if err then
+ rspamd_logger.errx(rspamd_config, 'cannot unlock ANN %s:%s at %s from redis: %s',
+ rule.prefix, set.name, ann_key, err)
+ else
+ lua_util.debugm(N, rspamd_config, 'unlocked ANN %s:%s at %s',
+ rule.prefix, set.name, ann_key)
+ end
+ end
+end
+
+-- Used to generate new ANN key for specific profile
+-- Used to generate new ANN key for specific profile
+-- Combines the global prefix, rule prefix, settings name, a truncated profile
+-- digest and the version so each profile generation gets a distinct key.
+local function new_ann_key(rule, set, version)
+ local ann_key = string.format('%s_%s_%s_%s_%s', settings.prefix,
+ rule.prefix, set.name, set.digest:sub(1, 8), tostring(version))
+
+ return ann_key
+end
+
+-- Generates the Redis prefix for a rule + settings pair. The metatokens
+-- version is baked into the prefix so ANNs trained against an older metatokens
+-- layout are naturally abandoned when lua_meta changes.
+local function redis_ann_prefix(rule, settings_name)
+ -- We also need to count metatokens:
+ local n = meta_functions.version
+ return string.format('%s%d_%s_%d_%s',
+ settings.prefix, plugin_ver, rule.prefix, n, settings_name)
+end
+
+-- This function receives training vectors, checks them, spawn learning and saves ANN in Redis
+-- Spawns a worker child process that trains an ANN from the supplied ham/spam
+-- vectors and, on success, saves the compressed network (plus optional PCA)
+-- and its profile to Redis via the save_unlock script.
+-- params: { rule, set, ann_key, ham_vec, spam_vec, ev_base, worker }
+local function spawn_train(params)
+ -- Check training data sanity
+ -- Now we need to join inputs and create the appropriate test vectors
+ local n = #params.set.symbols +
+ meta_functions.rspamd_count_metatokens()
+
+ -- Now we can train ann
+ local train_ann = create_ann(params.rule.max_inputs or n, 3, params.rule)
+
+ if #params.ham_vec + #params.spam_vec < params.rule.train.max_trains / 2 then
+ -- Invalidate ANN as it is definitely invalid
+ -- TODO: add invalidation
+ assert(false)
+ else
+ local inputs, outputs = {}, {}
+
+ -- Used to show sparsed vectors in a convenient format (for debugging only)
+ local function debug_vec(t)
+ local ret = {}
+ for i,v in ipairs(t) do
+ if v ~= 0 then
+ ret[#ret + 1] = string.format('%d=%.2f', i, v)
+ end
+ end
+
+ return ret
+ end
+
+ -- Make training set by joining vectors
+ -- KANN automatically shuffles those samples
+ -- 1.0 is used for spam and -1.0 is used for ham
+ -- It implies that output layer can express that (e.g. tanh output)
+ for _,e in ipairs(params.spam_vec) do
+ inputs[#inputs + 1] = e
+ outputs[#outputs + 1] = {1.0}
+ --rspamd_logger.debugm(N, rspamd_config, 'spam vector: %s', debug_vec(e))
+ end
+ for _,e in ipairs(params.ham_vec) do
+ inputs[#inputs + 1] = e
+ outputs[#outputs + 1] = {-1.0}
+ --rspamd_logger.debugm(N, rspamd_config, 'ham vector: %s', debug_vec(e))
+ end
+
+ -- Called in child process
+ local function train()
+ local log_thresh = params.rule.train.max_iterations / 10
+ local seen_nan = false
+
+ -- Progress callback: logs periodically and watches for NaN in the train
+ -- cost (x ~= x is the standard NaN check)
+ local function train_cb(iter, train_cost, value_cost)
+ if (iter * (params.rule.train.max_iterations / log_thresh)) % (params.rule.train.max_iterations) == 0 then
+ if train_cost ~= train_cost and not seen_nan then
+ -- We have nan :( try to log lot's of stuff to dig into a problem
+ seen_nan = true
+ rspamd_logger.errx(rspamd_config, 'ANN %s:%s: train error: observed nan in error cost!; value cost = %s',
+ params.rule.prefix, params.set.name,
+ value_cost)
+ for i,e in ipairs(inputs) do
+ lua_util.debugm(N, rspamd_config, 'train vector %s -> %s',
+ debug_vec(e), outputs[i][1])
+ end
+ end
+
+ rspamd_logger.infox(rspamd_config,
+ "ANN %s:%s: learned from %s redis key in %s iterations, error: %s, value cost: %s",
+ params.rule.prefix, params.set.name,
+ params.ann_key,
+ iter,
+ train_cost,
+ value_cost)
+ end
+ end
+
+ lua_util.debugm(N, rspamd_config, "subprocess to learn ANN %s:%s has been started",
+ params.rule.prefix, params.set.name)
+
+ local ret,err = pcall(train_ann.train1, train_ann,
+ inputs, outputs, {
+ lr = params.rule.train.learning_rate,
+ max_epoch = params.rule.train.max_iterations,
+ cb = train_cb,
+ pca = (params.set.ann or {}).pca
+ })
+
+ if not ret then
+ rspamd_logger.errx(rspamd_config, "cannot train ann %s:%s: %s",
+ params.rule.prefix, params.set.name, err)
+
+ return nil
+ end
+
+ -- Serialized ANN (or nil) is passed back to the parent via on_complete
+ if not seen_nan then
+ local out = train_ann:save()
+ return out
+ else
+ return nil
+ end
+ end
+
+ params.set.learning_spawned = true
+
+ local function redis_save_cb(err)
+ if err then
+ rspamd_logger.errx(rspamd_config, 'cannot save ANN %s:%s to redis key %s: %s',
+ params.rule.prefix, params.set.name, params.ann_key, err)
+ -- NOTE(review): HDEL is a write command, yet is_write is false here, unlike
+ -- the identical unlock request in ann_trained below -- confirm intended
+ lua_redis.redis_make_request_taskless(params.ev_base,
+ rspamd_config,
+ params.rule.redis,
+ nil,
+ false, -- is write
+ gen_unlock_cb(params.rule, params.set, params.ann_key), --callback
+ 'HDEL', -- command
+ {params.ann_key, 'lock'}
+ )
+ else
+ rspamd_logger.infox(rspamd_config, 'saved ANN %s:%s to redis: %s',
+ params.rule.prefix, params.set.name, params.set.ann.redis_key)
+ end
+ end
+
+ -- Completion callback in the parent process; `data` is the serialized ANN
+ local function ann_trained(err, data)
+ params.set.learning_spawned = false
+ if err then
+ rspamd_logger.errx(rspamd_config, 'cannot train ANN %s:%s : %s',
+ params.rule.prefix, params.set.name, err)
+ lua_redis.redis_make_request_taskless(params.ev_base,
+ rspamd_config,
+ params.rule.redis,
+ nil,
+ true, -- is write
+ gen_unlock_cb(params.rule, params.set, params.ann_key), --callback
+ 'HDEL', -- command
+ {params.ann_key, 'lock'}
+ )
+ else
+ local ann_data = rspamd_util.zstd_compress(data)
+ local pca_data
+
+ fill_set_ann(params.set, params.ann_key)
+ if params.set.ann.pca then
+ pca_data = rspamd_util.zstd_compress(params.set.ann.pca:save())
+ end
+ -- Deserialise ANN from the child process
+ -- NOTE(review): this overwrites the enclosing local `ann_trained` (this very
+ -- callback); harmless once it has fired, but a fresh local would be clearer
+ ann_trained = rspamd_kann.load(data)
+ local version = (params.set.ann.version or 0) + 1
+ params.set.ann.version = version
+ params.set.ann.ann = ann_trained
+ params.set.ann.symbols = params.set.symbols
+ params.set.ann.redis_key = new_ann_key(params.rule, params.set, version)
+
+ local profile = {
+ symbols = params.set.symbols,
+ digest = params.set.digest,
+ redis_key = params.set.ann.redis_key,
+ version = version
+ }
+
+ local ucl = require "ucl"
+ local profile_serialized = ucl.to_format(profile, 'json-compact', true)
+
+ rspamd_logger.infox(rspamd_config,
+ 'trained ANN %s:%s, %s bytes (%s compressed); %s rows in pca (%sb compressed); redis key: %s (old key %s)',
+ params.rule.prefix, params.set.name,
+ #data, #ann_data,
+ #(params.set.ann.pca or {}), #(pca_data or {}),
+ params.set.ann.redis_key, params.ann_key)
+
+ lua_redis.exec_redis_script(redis_script_id.save_unlock,
+ {ev_base = params.ev_base, is_write = true},
+ redis_save_cb,
+ {profile.redis_key,
+ redis_ann_prefix(params.rule, params.set.name),
+ ann_data,
+ profile_serialized,
+ tostring(params.rule.ann_expire),
+ tostring(os.time()),
+ params.ann_key, -- old key to unlock...
+ pca_data
+ })
+ end
+ end
+
+ if params.rule.max_inputs then
+ fill_set_ann(params.set, params.ann_key)
+ -- Train PCA in the main process, presumably it is not that long
+ params.set.ann.pca = learn_pca(inputs, params.rule.max_inputs)
+ end
+
+ params.worker:spawn_process{
+ func = train,
+ on_complete = ann_trained,
+ proctitle = string.format("ANN train for %s/%s", params.rule.prefix, params.set.name),
+ }
+ -- Spawn learn and register lock extension
+ params.set.learning_spawned = true
+ register_lock_extender(params.rule, params.set, params.ev_base, params.ann_key)
+ return
+
+ end
+end
+
+-- This function is used to adjust profiles and allowed setting ids for each rule
+-- It must be called when all settings are already registered (e.g. at post-init for config)
+local function process_rules_settings()
+ -- Fills selt.symbols/digest/prefix for one settings element and registers the
+ -- corresponding Redis prefixes for monitoring/cleanup purposes.
+ local function process_settings_elt(rule, selt)
+ -- `rule.profile` may be absent when no static profiles are configured
+ local profile = (rule.profile or {})[selt.name]
+ if profile then
+ -- Use static user defined profile
+ -- Ensure that we have an array...
+ lua_util.debugm(N, rspamd_config, "use static profile for %s (%s): %s",
+ rule.prefix, selt.name, profile)
+ if not profile[1] then profile = lua_util.keys(profile) end
+ selt.symbols = profile
+ else
+ lua_util.debugm(N, rspamd_config, "use dynamic cfg based profile for %s (%s)",
+ rule.prefix, selt.name)
+ end
+
+ -- Excludes symbols that must not be fed to the network: blacklisted ones and
+ -- those flagged nostat/idempotent/skip/composite
+ local function filter_symbols_predicate(sname)
+ if settings.blacklisted_symbols and settings.blacklisted_symbols[sname] then
+ return false
+ end
+ local fl = rspamd_config:get_symbol_flags(sname)
+ if fl then
+ fl = lua_util.list_to_hash(fl)
+
+ return not (fl.nostat or fl.idempotent or fl.skip or fl.composite)
+ end
+
+ return false
+ end
+
+ -- Generic stuff: filter the symbols and sort them in place so that the
+ -- profile digest is stable. (The previous code sorted a temporary table and
+ -- discarded the result, leaving selt.symbols unfiltered and unsorted.)
+ selt.symbols = fun.totable(fun.filter(filter_symbols_predicate, selt.symbols))
+ table.sort(selt.symbols)
+
+ selt.digest = lua_util.table_digest(selt.symbols)
+ selt.prefix = redis_ann_prefix(rule, selt.name)
+
+ lua_redis.register_prefix(selt.prefix, N,
+ string.format('NN prefix for rule "%s"; settings id "%s"',
+ rule.prefix, selt.name), {
+ persistent = true,
+ type = 'zlist',
+ })
+ -- Versions
+ lua_redis.register_prefix(selt.prefix .. '_\\d+', N,
+ string.format('NN storage for rule "%s"; settings id "%s"',
+ rule.prefix, selt.name), {
+ persistent = true,
+ type = 'hash',
+ })
+ lua_redis.register_prefix(selt.prefix .. '_\\d+_spam', N,
+ string.format('NN learning set (spam) for rule "%s"; settings id "%s"',
+ rule.prefix, selt.name), {
+ persistent = true,
+ type = 'list',
+ })
+ lua_redis.register_prefix(selt.prefix .. '_\\d+_ham', N,
+ string.format('NN learning set (ham) for rule "%s"; settings id "%s"',
+ rule.prefix, selt.name), {
+ persistent = true,
+ type = 'list',
+ })
+ end
+
+ for k,rule in pairs(settings.rules) do
+ if not rule.allowed_settings then
+ rule.allowed_settings = {}
+ elseif rule.allowed_settings == 'all' then
+ -- Extract all settings ids
+ rule.allowed_settings = lua_util.keys(lua_settings.all_settings())
+ end
+
+ -- Convert to a map <setting_id> -> true
+ rule.allowed_settings = lua_util.list_to_hash(rule.allowed_settings)
+
+ -- Check if we can work without settings
+ if k == 'default' or type(rule.default) ~= 'boolean' then
+ rule.default = true
+ end
+
+ rule.settings = {}
+
+ if rule.default then
+ local default_settings = {
+ symbols = lua_settings.default_symbols(),
+ name = 'default'
+ }
+
+ process_settings_elt(rule, default_settings)
+ rule.settings[-1] = default_settings -- Magic constant, but OK as settings are positive int32
+ end
+
+ -- Now, for each allowed settings, we store sorted symbols + digest
+ -- We set table rule.settings[id] -> { name = name, symbols = symbols, digest = digest }
+ for s,_ in pairs(rule.allowed_settings) do
+ -- Here, we have a name, set of symbols and
+ local settings_id = s
+ if type(settings_id) ~= 'number' then
+ settings_id = lua_settings.numeric_settings_id(s)
+ end
+ local selt = lua_settings.settings_by_id(settings_id)
+
+ local nelt = {
+ symbols = selt.symbols, -- Already sorted
+ name = selt.name
+ }
+
+ process_settings_elt(rule, nelt)
+ -- Deduplicate: if an existing element has exactly the same symbols, store a
+ -- numeric reference to it instead of a new table
+ for id,ex in pairs(rule.settings) do
+ if type(ex) == 'table' then
+ if nelt and lua_util.distance_sorted(ex.symbols, nelt.symbols) == 0 then
+ -- Equal symbols, add reference
+ lua_util.debugm(N, rspamd_config,
+ 'added reference from settings id %s to %s; same symbols',
+ nelt.name, ex.name)
+ rule.settings[settings_id] = id
+ nelt = nil
+ end
+ end
+ end
+
+ if nelt then
+ rule.settings[settings_id] = nelt
+ lua_util.debugm(N, rspamd_config, 'added new settings id %s(%s) to %s',
+ nelt.name, settings_id, rule.prefix)
+ end
+ end
+ end
+end
+
+-- Extract settings element for a specific settings id
+-- Extract settings element for a specific settings id
+-- Returns nil when the task's settings id has no entry for this rule.
+local function get_rule_settings(task, rule)
+ local sid = task:get_settings_id() or -1
+ local set = rule.settings[sid]
+
+ if not set then return nil end
+
+ -- Numeric entries are aliases to another settings id with identical symbols
+ -- (see process_rules_settings); chase them until the actual table is found
+ while type(set) == 'number' do
+ -- Reference to another settings!
+ set = rule.settings[set]
+ end
+
+ return set
+end
+
+-- Converts a scanned task into the numeric input vector for a given profile:
+-- metatokens first, then per-symbol scores at offset #mt.
+local function result_to_vector(task, profile)
+ -- Lazily build and cache a zero template of size (metatokens + #symbols) on
+ -- the profile, so per-task work is just a shallow copy plus filling
+ if not profile.zeros then
+ -- Fill zeros vector
+ local zeros = {}
+ for i=1,meta_functions.count_metatokens() do
+ zeros[i] = 0.0
+ end
+ for _,_ in ipairs(profile.symbols) do
+ zeros[#zeros + 1] = 0.0
+ end
+ profile.zeros = zeros
+ end
+
+ local vec = lua_util.shallowcopy(profile.zeros)
+ local mt = meta_functions.rspamd_gen_metatokens(task)
+
+ for i,v in ipairs(mt) do
+ vec[i] = v
+ end
+
+ task:process_ann_tokens(profile.symbols, vec, #mt, 0.1)
+
+ return vec
+end
+
+-- Public API shared between the neural plugin proper and its controller endpoints
+return {
+ can_push_train_vector = can_push_train_vector,
+ create_ann = create_ann,
+ default_options = default_options,
+ gen_unlock_cb = gen_unlock_cb,
+ get_rule_settings = get_rule_settings,
+ load_scripts = load_scripts,
+ module_config = module_config,
+ new_ann_key = new_ann_key,
+ plugin_ver = plugin_ver,
+ process_rules_settings = process_rules_settings,
+ redis_ann_prefix = redis_ann_prefix,
+ redis_params = redis_params,
+ redis_script_id = redis_script_id,
+ result_to_vector = result_to_vector,
+ settings = settings,
+ spawn_train = spawn_train,
+}
-- Define default controller paths, could be overridden in local.d/controller.lua
local controller_plugin_paths = {
+ maps = dofile(local_rules .. "/controller/maps.lua"),
+ neural = dofile(local_rules .. "/controller/neural.lua"),
selectors = dofile(local_rules .. "/controller/selectors.lua"),
- maps = dofile(local_rules .. "/controller/maps.lua")
}
if rspamd_util.file_exists(local_conf .. '/controller.lua') then
plug, path, type(attrs))
end
end
-end
\ No newline at end of file
+end
--- /dev/null
+--[[
+Copyright (c) 2020, Vsevolod Stakhov <vsevolod@highsecure.ru>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local neural_common = require "plugins/neural"
+local ts = require("tableshape").types
+local ucl = require "ucl"
+
+local E = {} -- shared empty table for safe `(x or E).field` indexing
+
+-- Controller neural plugin
+
+-- Request schema for the `learn` endpoint: two matrices of training vectors
+-- plus an optional rule name (defaults to 'default' in the handler).
+local learn_request_schema = ts.shape{
+ ham_vec = ts.array_of(ts.array_of(ts.number)),
+ rule = ts.string:is_optional(),
+ spam_vec = ts.array_of(ts.array_of(ts.number)),
+}
+
+-- HTTP handler for the `learn` controller endpoint.
+-- Parses a JSON body matching learn_request_schema, resolves the rule and its
+-- settings element, then spawns ANN training; replies 400 on any invalid input
+-- and {"success": true} once training has been spawned.
+local function handle_learn(task, conn)
+ local parser = ucl.parser()
+ local ok, err = parser:parse_text(task:get_rawbody())
+ if not ok then
+ conn:send_error(400, err)
+ return
+ end
+ local req_params = parser:get_object()
+
+ ok, err = learn_request_schema:transform(req_params)
+ if not ok then
+ conn:send_error(400, err)
+ return
+ end
+
+ local rule_name = req_params.rule or 'default'
+ local rule = neural_common.settings.rules[rule_name]
+ if not rule then
+ -- Avoid a nil-index error deeper in get_rule_settings/spawn_train
+ conn:send_error(400, 'unknown rule: ' .. rule_name)
+ return
+ end
+ local set = neural_common.get_rule_settings(task, rule)
+ if not set then
+ conn:send_error(400, 'no settings found for rule: ' .. rule_name)
+ return
+ end
+ local version = ((set.ann or E).version or 0) + 1
+
+ neural_common.spawn_train{
+ ev_base = task:get_ev_base(),
+ ann_key = neural_common.new_ann_key(rule, set, version),
+ set = set,
+ rule = rule,
+ ham_vec = req_params.ham_vec,
+ spam_vec = req_params.spam_vec,
+ worker = task:get_worker(),
+ }
+
+ conn:send_string('{"success" : true}')
+end
+
+-- Profiles depend on the full symbol set, so process them only after all
+-- symbols have been registered
+rspamd_config:add_post_init(neural_common.process_rules_settings)
+
+-- Controller endpoint table: POST /plugins/neural/learn
+return {
+ learn = {
+ handler = handle_learn,
+ enable = true,
+ need_task = true,
+ },
+}
return
end
-local rspamd_logger = require "rspamd_logger"
-local rspamd_util = require "rspamd_util"
-local rspamd_kann = require "rspamd_kann"
-local rspamd_text = require "rspamd_text"
+local fun = require "fun"
local lua_redis = require "lua_redis"
local lua_util = require "lua_util"
+local lua_verdict = require "lua_verdict"
+local neural_common = require "plugins/neural"
+local rspamd_kann = require "rspamd_kann"
+local rspamd_logger = require "rspamd_logger"
local rspamd_tensor = require "rspamd_tensor"
-local fun = require "fun"
-local lua_settings = require "lua_settings"
-local meta_functions = require "lua_meta"
+local rspamd_text = require "rspamd_text"
+local rspamd_util = require "rspamd_util"
local ts = require("tableshape").types
-local lua_verdict = require "lua_verdict"
+
local N = "neural"
--- Used in prefix to avoid wrong ANN to be loaded
-local plugin_ver = '2'
+local settings = neural_common.settings
-- Module vars
local default_options = {
classes_bias = 0.0, -- balanced mode: what difference is allowed between classes (1:1 proportion means 0 bias)
spam_skip_prob = 0.0, -- proportional mode: spam skip probability (0-1)
ham_skip_prob = 0.0, -- proportional mode: ham skip probability
- store_pool_only = false, -- store tokens in mempool variable only (disables autotrain);
+ store_pool_only = false, -- store tokens in cache only (disables autotrain);
-- neural_vec_mpack stores vector of training data in messagepack neural_profile_digest stores profile digest
},
watch_interval = 60.0,
local has_blas = rspamd_tensor.has_blas()
local text_cookie = rspamd_text.cookie
--- Rule structure:
--- * static config fields (see `default_options`)
--- * prefix - name or defined prefix
--- * settings - table of settings indexed by settings id, -1 is used when no settings defined
-
--- Rule settings element defines elements for specific settings id:
--- * symbols - static symbols profile (defined by config or extracted from symcache)
--- * name - name of settings id
--- * digest - digest of all symbols
--- * ann - dynamic ANN configuration loaded from Redis
--- * train - train data for ANN (e.g. the currently trained ANN)
-
--- Settings ANN table is loaded from Redis and represents dynamic profile for ANN
--- Some elements are directly stored in Redis, ANN is, in turn loaded dynamically
--- * version - version of ANN loaded from redis
--- * redis_key - name of ANN key in Redis
--- * symbols - symbols in THIS PARTICULAR ANN (might be different from set.symbols)
--- * distance - distance between set.symbols and set.ann.symbols
--- * ann - kann object
-
-local settings = {
- rules = {},
- prefix = 'rn', -- Neural network default prefix
- max_profiles = 3, -- Maximum number of NN profiles stored
-}
-
-local module_config = rspamd_config:get_all_opt("neural")
-if not module_config then
- -- Legacy
- module_config = rspamd_config:get_all_opt("fann_redis")
-end
-
-
--- Lua script that checks if we can store a new training vector
--- Uses the following keys:
--- key1 - ann key
--- returns nspam,nham (or nil if locked)
-local redis_lua_script_vectors_len = [[
- local prefix = KEYS[1]
- local locked = redis.call('HGET', prefix, 'lock')
- if locked then
- local host = redis.call('HGET', prefix, 'hostname') or 'unknown'
- return string.format('%s:%s', host, locked)
- end
- local nspam = 0
- local nham = 0
-
- local ret = redis.call('LLEN', prefix .. '_spam')
- if ret then nspam = tonumber(ret) end
- ret = redis.call('LLEN', prefix .. '_ham')
- if ret then nham = tonumber(ret) end
-
- return {nspam,nham}
-]]
-local redis_lua_script_vectors_len_id = nil
-
--- Lua script to invalidate ANNs by rank
--- Uses the following keys
--- key1 - prefix for keys
--- key2 - number of elements to leave
-local redis_lua_script_maybe_invalidate = [[
- local card = redis.call('ZCARD', KEYS[1])
- local lim = tonumber(KEYS[2])
- if card > lim then
- local to_delete = redis.call('ZRANGE', KEYS[1], 0, card - lim - 1)
- for _,k in ipairs(to_delete) do
- local tb = cjson.decode(k)
- redis.call('DEL', tb.redis_key)
- -- Also train vectors
- redis.call('DEL', tb.redis_key .. '_spam')
- redis.call('DEL', tb.redis_key .. '_ham')
- end
- redis.call('ZREMRANGEBYRANK', KEYS[1], 0, card - lim - 1)
- return to_delete
- else
- return {}
- end
-]]
-local redis_maybe_invalidate_id = nil
-
--- Lua script to invalidate ANN from redis
--- Uses the following keys
--- key1 - prefix for keys
--- key2 - current time
--- key3 - key expire
--- key4 - hostname
-local redis_lua_script_maybe_lock = [[
- local locked = redis.call('HGET', KEYS[1], 'lock')
- local now = tonumber(KEYS[2])
- if locked then
- locked = tonumber(locked)
- local expire = tonumber(KEYS[3])
- if now > locked and (now - locked) < expire then
- return {tostring(locked), redis.call('HGET', KEYS[1], 'hostname') or 'unknown'}
- end
- end
- redis.call('HSET', KEYS[1], 'lock', tostring(now))
- redis.call('HSET', KEYS[1], 'hostname', KEYS[4])
- return 1
-]]
-local redis_maybe_lock_id = nil
-
--- Lua script to save and unlock ANN in redis
--- Uses the following keys
--- key1 - prefix for ANN
--- key2 - prefix for profile
--- key3 - compressed ANN
--- key4 - profile as JSON
--- key5 - expire in seconds
--- key6 - current time
--- key7 - old key
--- key8 - optional PCA
-local redis_lua_script_save_unlock = [[
- local now = tonumber(KEYS[6])
- redis.call('ZADD', KEYS[2], now, KEYS[4])
- redis.call('HSET', KEYS[1], 'ann', KEYS[3])
- redis.call('DEL', KEYS[1] .. '_spam')
- redis.call('DEL', KEYS[1] .. '_ham')
- redis.call('HDEL', KEYS[1], 'lock')
- redis.call('HDEL', KEYS[7], 'lock')
- redis.call('EXPIRE', KEYS[1], tonumber(KEYS[5]))
- if KEYS[8] then
- redis.call('HSET', KEYS[1], 'pca', KEYS[8])
- end
- return 1
-]]
-local redis_save_unlock_id = nil
-
-local redis_params
-
-local function load_scripts(params)
- redis_lua_script_vectors_len_id = lua_redis.add_redis_script(redis_lua_script_vectors_len,
- params)
- redis_maybe_invalidate_id = lua_redis.add_redis_script(redis_lua_script_maybe_invalidate,
- params)
- redis_maybe_lock_id = lua_redis.add_redis_script(redis_lua_script_maybe_lock,
- params)
- redis_save_unlock_id = lua_redis.add_redis_script(redis_lua_script_save_unlock,
- params)
-end
-
-local function result_to_vector(task, profile)
- if not profile.zeros then
- -- Fill zeros vector
- local zeros = {}
- for i=1,meta_functions.count_metatokens() do
- zeros[i] = 0.0
- end
- for _,_ in ipairs(profile.symbols) do
- zeros[#zeros + 1] = 0.0
- end
- profile.zeros = zeros
- end
-
- local vec = lua_util.shallowcopy(profile.zeros)
- local mt = meta_functions.rspamd_gen_metatokens(task)
-
- for i,v in ipairs(mt) do
- vec[i] = v
- end
-
- task:process_ann_tokens(profile.symbols, vec, #mt, 0.1)
-
- return vec
-end
-
--- Used to generate new ANN key for specific profile
-local function new_ann_key(rule, set, version)
- local ann_key = string.format('%s_%s_%s_%s_%s', settings.prefix,
- rule.prefix, set.name, set.digest:sub(1, 8), tostring(version))
-
- return ann_key
-end
-
--- Extract settings element for a specific settings id
-local function get_rule_settings(task, rule)
- local sid = task:get_settings_id() or -1
-
- local set = rule.settings[sid]
-
- if not set then return nil end
-
- while type(set) == 'number' do
- -- Reference to another settings!
- set = rule.settings[set]
- end
-
- return set
-end
-
--- Generate redis prefix for specific rule and specific settings
-local function redis_ann_prefix(rule, settings_name)
- -- We also need to count metatokens:
- local n = meta_functions.version
- return string.format('%s%d_%s_%d_%s',
- settings.prefix, plugin_ver, rule.prefix, n, settings_name)
-end
-
-- Creates and stores ANN profile in Redis
local function new_ann_profile(task, rule, set, version)
- local ann_key = new_ann_key(rule, set, version)
+ local ann_key = neural_common.new_ann_key(rule, set, version, settings)
local profile = {
symbols = set.symbols,
local ann
local profile
- local set = get_rule_settings(task, rule)
+ local set = neural_common.get_rule_settings(task, rule)
if set then
if set.ann then
ann = set.ann.ann
end
if ann then
- local vec = result_to_vector(task, profile)
+ local vec = neural_common.result_to_vector(task, profile)
local score
local out = ann:apply1(vec, set.ann.pca)
end
end
-local function create_ann(n, nlayers, rule)
- -- We ignore number of layers so far when using kann
- local nhidden = math.floor(n * (rule.hidden_layer_mult or 1.0) + 1.0)
- local t = rspamd_kann.layer.input(n)
- t = rspamd_kann.transform.relu(t)
- t = rspamd_kann.layer.dense(t, nhidden);
- t = rspamd_kann.layer.cost(t, 1, rspamd_kann.cost.ceb_neg)
- return rspamd_kann.new.kann(t)
-end
-
-local function can_push_train_vector(rule, task, learn_type, nspam, nham)
- local train_opts = rule.train
- local coin = math.random()
-
- if train_opts.train_prob and coin < 1.0 - train_opts.train_prob then
- rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin)
- return false
- end
-
- if train_opts.learn_mode == 'balanced' then
- -- Keep balanced training set based on number of spam and ham samples
- if learn_type == 'spam' then
- if nspam <= train_opts.max_trains then
- if nspam > nham then
- -- Apply sampling
- local skip_rate = 1.0 - nham / (nspam + 1)
- if coin < skip_rate - train_opts.classes_bias then
- rspamd_logger.infox(task,
- 'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored',
- learn_type,
- skip_rate - train_opts.classes_bias,
- nspam, nham)
- return false
- end
- end
- return true
- else -- Enough learns
- rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many spam samples: %s',
- learn_type,
- nspam)
- end
- else
- if nham <= train_opts.max_trains then
- if nham > nspam then
- -- Apply sampling
- local skip_rate = 1.0 - nspam / (nham + 1)
- if coin < skip_rate - train_opts.classes_bias then
- rspamd_logger.infox(task,
- 'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored',
- learn_type,
- skip_rate - train_opts.classes_bias,
- nspam, nham)
- return false
- end
- end
- return true
- else
- rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many ham samples: %s', learn_type,
- nham)
- end
- end
- else
- -- Probabilistic learn mode, we just skip learn if we already have enough samples or
- -- if our coin drop is less than desired probability
- if learn_type == 'spam' then
- if nspam <= train_opts.max_trains then
- if train_opts.spam_skip_prob then
- if coin <= train_opts.spam_skip_prob then
- rspamd_logger.infox(task, 'skip %s sample probabilisticaly; probability %s (%s skip chance)', learn_type,
- coin, train_opts.spam_skip_prob)
- return false
- end
-
- return true
- end
- else
- rspamd_logger.infox(task, 'skip %s sample; too many spam samples: %s (%s limit)', learn_type,
- nspam, train_opts.max_trains)
- end
- else
- if nham <= train_opts.max_trains then
- if train_opts.ham_skip_prob then
- if coin <= train_opts.ham_skip_prob then
- rspamd_logger.infox(task, 'skip %s sample probabilisticaly; probability %s (%s skip chance)', learn_type,
- coin, train_opts.ham_skip_prob)
- return false
- end
-
- return true
- end
- else
- rspamd_logger.infox(task, 'skip %s sample; too many ham samples: %s (%s limit)', learn_type,
- nham, train_opts.max_trains)
- end
- end
- end
-
- return false
-end
-
local function ann_push_task_result(rule, task, verdict, score, set)
local train_opts = rule.train
local learn_spam, learn_ham
local skip_reason = 'unknown'
- if train_opts.autotrain then
+ if not train_opts.store_pool_only and train_opts.autotrain then
if train_opts.spam_score then
learn_spam = score >= train_opts.spam_score
learn_ham = false
learn_spam = false
- -- Explicitly store tokens in a mempool variable
- local vec = result_to_vector(task, set)
- task:get_mempool():set_variable('neural_vec_mpack', ucl.to_format(vec, 'msgpack'))
- task:get_mempool():set_variable('neural_profile_digest', set.digest)
+ -- Explicitly store tokens in cache
+ local vec = neural_common.result_to_vector(task, set)
+ task:cache_set('neural_vec_mpack', ucl.to_format(vec, 'msgpack'))
+ task:cache_set('neural_profile_digest', set.digest)
skip_reason = 'store_pool_only has been set'
end
end
if not err and type(data) == 'table' then
local nspam,nham = data[1],data[2]
- if can_push_train_vector(rule, task, learn_type, nspam, nham) then
- local vec = result_to_vector(task, set)
+ if neural_common.can_push_train_vector(rule, task, learn_type, nspam, nham) then
+ local vec = neural_common.result_to_vector(task, set)
local str = rspamd_util.zstd_compress(table.concat(vec, ';'))
local target_key = set.ann.redis_key .. '_' .. learn_type
set.name)
end
- lua_redis.exec_redis_script(redis_lua_script_vectors_len_id,
+ lua_redis.exec_redis_script(neural_common.redis_script_id.vectors_len,
{task = task, is_write = false},
vectors_len_cb,
{
--- Offline training logic
--- Closure generator for unlock function
-local function gen_unlock_cb(rule, set, ann_key)
- return function (err)
- if err then
- rspamd_logger.errx(rspamd_config, 'cannot unlock ANN %s:%s at %s from redis: %s',
- rule.prefix, set.name, ann_key, err)
- else
- lua_util.debugm(N, rspamd_config, 'unlocked ANN %s:%s at %s',
- rule.prefix, set.name, ann_key)
- end
- end
-end
-
--- This function is intended to extend lock for ANN during training
--- It registers periodic that increases locked key each 30 seconds unless
--- `set.learning_spawned` is set to `true`
-local function register_lock_extender(rule, set, ev_base, ann_key)
- rspamd_config:add_periodic(ev_base, 30.0,
- function()
- local function redis_lock_extend_cb(_err, _)
- if _err then
- rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s',
- ann_key, _err)
- else
- rspamd_logger.infox(rspamd_config, 'extend lock for ANN %s for 30 seconds',
- ann_key)
- end
- end
-
- if set.learning_spawned then
- lua_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
- rule.redis,
- nil,
- true, -- is write
- redis_lock_extend_cb, --callback
- 'HINCRBY', -- command
- {ann_key, 'lock', '30'}
- )
- else
- lua_util.debugm(N, rspamd_config, "stop lock extension as learning_spawned is false")
- return false -- do not plan any more updates
- end
-
- return true
- end
- )
-end
-
--- This function takes all inputs, applies PCA transformation and returns the final
--- PCA matrix as rspamd_tensor
-local function learn_pca(inputs, max_inputs)
- local scatter_matrix = rspamd_tensor.scatter_matrix(rspamd_tensor.fromtable(inputs))
- local eigenvals = scatter_matrix:eigen()
- -- scatter matrix is not filled with eigenvectors
- lua_util.debugm(N, 'eigenvalues: %s', eigenvals)
- local w = rspamd_tensor.new(2, max_inputs, #scatter_matrix[1])
- for i=1,max_inputs do
- w[i] = scatter_matrix[#scatter_matrix - i + 1]
- end
-
- lua_util.debugm(N, 'pca matrix: %s', w)
-
- return w
-end
-
--- Fills ANN data for a specific settings element
-local function fill_set_ann(set, ann_key)
- if not set.ann then
- set.ann = {
- symbols = set.symbols,
- distance = 0,
- digest = set.digest,
- redis_key = ann_key,
- version = 0,
- }
- end
-end
-
--- This function receives training vectors, checks them, spawn learning and saves ANN in Redis
-local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_vec)
- -- Check training data sanity
- -- Now we need to join inputs and create the appropriate test vectors
- local n = #set.symbols +
- meta_functions.rspamd_count_metatokens()
-
- -- Now we can train ann
- local train_ann = create_ann(rule.max_inputs or n, 3, rule)
-
- if #ham_vec + #spam_vec < rule.train.max_trains / 2 then
- -- Invalidate ANN as it is definitely invalid
- -- TODO: add invalidation
- assert(false)
- else
- local inputs, outputs = {}, {}
-
- -- Used to show sparsed vectors in a convenient format (for debugging only)
- local function debug_vec(t)
- local ret = {}
- for i,v in ipairs(t) do
- if v ~= 0 then
- ret[#ret + 1] = string.format('%d=%.2f', i, v)
- end
- end
-
- return ret
- end
-
- -- Make training set by joining vectors
- -- KANN automatically shuffles those samples
- -- 1.0 is used for spam and -1.0 is used for ham
- -- It implies that output layer can express that (e.g. tanh output)
- for _,e in ipairs(spam_vec) do
- inputs[#inputs + 1] = e
- outputs[#outputs + 1] = {1.0}
- --rspamd_logger.debugm(N, rspamd_config, 'spam vector: %s', debug_vec(e))
- end
- for _,e in ipairs(ham_vec) do
- inputs[#inputs + 1] = e
- outputs[#outputs + 1] = {-1.0}
- --rspamd_logger.debugm(N, rspamd_config, 'ham vector: %s', debug_vec(e))
- end
-
- -- Called in child process
- local function train()
- local log_thresh = rule.train.max_iterations / 10
- local seen_nan = false
-
- local function train_cb(iter, train_cost, value_cost)
- if (iter * (rule.train.max_iterations / log_thresh)) % (rule.train.max_iterations) == 0 then
- if train_cost ~= train_cost and not seen_nan then
- -- We have nan :( try to log lot's of stuff to dig into a problem
- seen_nan = true
- rspamd_logger.errx(rspamd_config, 'ANN %s:%s: train error: observed nan in error cost!; value cost = %s',
- rule.prefix, set.name,
- value_cost)
- for i,e in ipairs(inputs) do
- lua_util.debugm(N, rspamd_config, 'train vector %s -> %s',
- debug_vec(e), outputs[i][1])
- end
- end
-
- rspamd_logger.infox(rspamd_config,
- "ANN %s:%s: learned from %s redis key in %s iterations, error: %s, value cost: %s",
- rule.prefix, set.name,
- ann_key,
- iter,
- train_cost,
- value_cost)
- end
- end
-
- lua_util.debugm(N, rspamd_config, "subprocess to learn ANN %s:%s has been started",
- rule.prefix, set.name)
-
- local ret,err = pcall(train_ann.train1, train_ann,
- inputs, outputs, {
- lr = rule.train.learning_rate,
- max_epoch = rule.train.max_iterations,
- cb = train_cb,
- pca = (set.ann or {}).pca
- })
-
- if not ret then
- rspamd_logger.errx(rspamd_config, "cannot train ann %s:%s: %s",
- rule.prefix, set.name, err)
-
- return nil
- end
-
- if not seen_nan then
- local out = train_ann:save()
- return out
- else
- return nil
- end
- end
-
- set.learning_spawned = true
-
- local function redis_save_cb(err)
- if err then
- rspamd_logger.errx(rspamd_config, 'cannot save ANN %s:%s to redis key %s: %s',
- rule.prefix, set.name, ann_key, err)
- lua_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
- rule.redis,
- nil,
- false, -- is write
- gen_unlock_cb(rule, set, ann_key), --callback
- 'HDEL', -- command
- {ann_key, 'lock'}
- )
- else
- rspamd_logger.infox(rspamd_config, 'saved ANN %s:%s to redis: %s',
- rule.prefix, set.name, set.ann.redis_key)
- end
- end
-
- local function ann_trained(err, data)
- set.learning_spawned = false
- if err then
- rspamd_logger.errx(rspamd_config, 'cannot train ANN %s:%s : %s',
- rule.prefix, set.name, err)
- lua_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
- rule.redis,
- nil,
- true, -- is write
- gen_unlock_cb(rule, set, ann_key), --callback
- 'HDEL', -- command
- {ann_key, 'lock'}
- )
- else
- local ann_data = rspamd_util.zstd_compress(data)
- local pca_data
-
- fill_set_ann(set, ann_key)
- if set.ann.pca then
- pca_data = rspamd_util.zstd_compress(set.ann.pca:save())
- end
- -- Deserialise ANN from the child process
- ann_trained = rspamd_kann.load(data)
- local version = (set.ann.version or 0) + 1
- set.ann.version = version
- set.ann.ann = ann_trained
- set.ann.symbols = set.symbols
- set.ann.redis_key = new_ann_key(rule, set, version)
-
- local profile = {
- symbols = set.symbols,
- digest = set.digest,
- redis_key = set.ann.redis_key,
- version = version
- }
-
- local ucl = require "ucl"
- local profile_serialized = ucl.to_format(profile, 'json-compact', true)
-
- rspamd_logger.infox(rspamd_config,
- 'trained ANN %s:%s, %s bytes (%s compressed); %s rows in pca (%sb compressed); redis key: %s (old key %s)',
- rule.prefix, set.name,
- #data, #ann_data,
- #(set.ann.pca or {}), #(pca_data or {}),
- set.ann.redis_key, ann_key)
-
- lua_redis.exec_redis_script(redis_save_unlock_id,
- {ev_base = ev_base, is_write = true},
- redis_save_cb,
- {profile.redis_key,
- redis_ann_prefix(rule, set.name),
- ann_data,
- profile_serialized,
- tostring(rule.ann_expire),
- tostring(os.time()),
- ann_key, -- old key to unlock...
- pca_data
- })
- end
- end
-
- if rule.max_inputs then
- fill_set_ann(set, ann_key)
- -- Train PCA in the main process, presumably it is not that long
- set.ann.pca = learn_pca(inputs, rule.max_inputs)
- end
-
- worker:spawn_process{
- func = train,
- on_complete = ann_trained,
- proctitle = string.format("ANN train for %s/%s", rule.prefix, set.name),
- }
- end
- -- Spawn learn and register lock extension
- set.learning_spawned = true
- register_lock_extender(rule, set, ev_base, ann_key)
-end
-
-- Utility to extract and split saved training vectors to a table of tables
local function process_training_vectors(data)
return fun.totable(fun.map(function(tok)
rule.redis,
nil,
true, -- is write
- gen_unlock_cb(rule, set, ann_key), --callback
+ neural_common.gen_unlock_cb(rule, set, ann_key), --callback
'HDEL', -- command
{ann_key, 'lock'}
)
else
-- Decompress and convert to numbers each training vector
ham_elts = process_training_vectors(data)
- spawn_train(worker, ev_base, rule, set, ann_key, ham_elts, spam_elts)
+ neural_common.spawn_train({worker = worker, ev_base = ev_base,
+ rule = rule, set = set, ann_key = ann_key, ham_vec = ham_elts,
+ spam_vec = spam_elts})
end
end
rule.redis,
nil,
true, -- is write
- gen_unlock_cb(rule, set, ann_key), --callback
+ neural_common.gen_unlock_cb(rule, set, ann_key), --callback
'HDEL', -- command
{ann_key, 'lock'}
)
-- Call Redis script that tries to acquire a lock
-- This script returns either a boolean or a pair {'lock_time', 'hostname'} when
-- ANN is locked by another host (or a process, meh)
- lua_redis.exec_redis_script(redis_maybe_lock_id,
+ lua_redis.exec_redis_script(neural_common.redis_script_id.maybe_lock,
{ev_base = ev_base, is_write = true},
redis_lock_cb,
{
end
if type(set) == 'table' then
- lua_redis.exec_redis_script(redis_maybe_invalidate_id,
+ lua_redis.exec_redis_script(neural_common.redis_script_id.maybe_invalidate,
{ev_base = ev_base, is_write = true},
invalidate_cb,
{set.prefix, tostring(settings.max_profiles)})
end
for _,rule in pairs(settings.rules) do
- local set = get_rule_settings(task, rule)
+ local set = neural_common.get_rule_settings(task, rule)
if set then
ann_push_task_result(rule, task, verdict, score, set)
end
--- This function is used to adjust profiles and allowed setting ids for each rule
--- It must be called when all settings are already registered (e.g. at post-init for config)
-local function process_rules_settings()
- local function process_settings_elt(rule, selt)
- local profile = rule.profile[selt.name]
- if profile then
- -- Use static user defined profile
- -- Ensure that we have an array...
- lua_util.debugm(N, rspamd_config, "use static profile for %s (%s): %s",
- rule.prefix, selt.name, profile)
- if not profile[1] then profile = lua_util.keys(profile) end
- selt.symbols = profile
- else
- lua_util.debugm(N, rspamd_config, "use dynamic cfg based profile for %s (%s)",
- rule.prefix, selt.name)
- end
-
- local function filter_symbols_predicate(sname)
- if settings.blacklisted_symbols and settings.blacklisted_symbols[sname] then
- return false
- end
- local fl = rspamd_config:get_symbol_flags(sname)
- if fl then
- fl = lua_util.list_to_hash(fl)
-
- return not (fl.nostat or fl.idempotent or fl.skip or fl.composite)
- end
-
- return false
- end
-
- -- Generic stuff
- table.sort(fun.totable(fun.filter(filter_symbols_predicate, selt.symbols)))
-
- selt.digest = lua_util.table_digest(selt.symbols)
- selt.prefix = redis_ann_prefix(rule, selt.name)
-
- lua_redis.register_prefix(selt.prefix, N,
- string.format('NN prefix for rule "%s"; settings id "%s"',
- rule.prefix, selt.name), {
- persistent = true,
- type = 'zlist',
- })
- -- Versions
- lua_redis.register_prefix(selt.prefix .. '_\\d+', N,
- string.format('NN storage for rule "%s"; settings id "%s"',
- rule.prefix, selt.name), {
- persistent = true,
- type = 'hash',
- })
- lua_redis.register_prefix(selt.prefix .. '_\\d+_spam', N,
- string.format('NN learning set (spam) for rule "%s"; settings id "%s"',
- rule.prefix, selt.name), {
- persistent = true,
- type = 'list',
- })
- lua_redis.register_prefix(selt.prefix .. '_\\d+_ham', N,
- string.format('NN learning set (spam) for rule "%s"; settings id "%s"',
- rule.prefix, selt.name), {
- persistent = true,
- type = 'list',
- })
- end
-
- for k,rule in pairs(settings.rules) do
- if not rule.allowed_settings then
- rule.allowed_settings = {}
- elseif rule.allowed_settings == 'all' then
- -- Extract all settings ids
- rule.allowed_settings = lua_util.keys(lua_settings.all_settings())
- end
-
- -- Convert to a map <setting_id> -> true
- rule.allowed_settings = lua_util.list_to_hash(rule.allowed_settings)
-
- -- Check if we can work without settings
- if k == 'default' or type(rule.default) ~= 'boolean' then
- rule.default = true
- end
-
- rule.settings = {}
-
- if rule.default then
- local default_settings = {
- symbols = lua_settings.default_symbols(),
- name = 'default'
- }
-
- process_settings_elt(rule, default_settings)
- rule.settings[-1] = default_settings -- Magic constant, but OK as settings are positive int32
- end
-
- -- Now, for each allowed settings, we store sorted symbols + digest
- -- We set table rule.settings[id] -> { name = name, symbols = symbols, digest = digest }
- for s,_ in pairs(rule.allowed_settings) do
- -- Here, we have a name, set of symbols and
- local settings_id = s
- if type(settings_id) ~= 'number' then
- settings_id = lua_settings.numeric_settings_id(s)
- end
- local selt = lua_settings.settings_by_id(settings_id)
-
- local nelt = {
- symbols = selt.symbols, -- Already sorted
- name = selt.name
- }
-
- process_settings_elt(rule, nelt)
- for id,ex in pairs(rule.settings) do
- if type(ex) == 'table' then
- if nelt and lua_util.distance_sorted(ex.symbols, nelt.symbols) == 0 then
- -- Equal symbols, add reference
- lua_util.debugm(N, rspamd_config,
- 'added reference from settings id %s to %s; same symbols',
- nelt.name, ex.name)
- rule.settings[settings_id] = id
- nelt = nil
- end
- end
- end
-
- if nelt then
- rule.settings[settings_id] = nelt
- lua_util.debugm(N, rspamd_config, 'added new settings id %s(%s) to %s',
- nelt.name, settings_id, rule.prefix)
- end
- end
- end
-end
-
-redis_params = lua_redis.parse_redis_server('neural')
-
-if not redis_params then
- redis_params = lua_redis.parse_redis_server('fann_redis')
-end
-
-- Initialization part
-if not (module_config and type(module_config) == 'table') or not redis_params then
+if not (neural_common.module_config and type(neural_common.module_config) == 'table')
+ or not neural_common.redis_params then
rspamd_logger.infox(rspamd_config, 'Module is unconfigured')
lua_util.disable_module(N, "redis")
return
end
-local rules = module_config['rules']
+local rules = neural_common.module_config['rules']
if not rules then
-- Use legacy configuration
rules = {}
- rules['default'] = module_config
+ rules['default'] = neural_common.module_config
end
local id = rspamd_config:register_symbol({
callback = ann_scores_filter
})
-settings = lua_util.override_defaults(settings, module_config)
-settings.rules = {} -- Reset unless validated further in the cycle
+neural_common.settings.rules = {} -- Reset unless validated further in the cycle
if settings.blacklisted_symbols and settings.blacklisted_symbols[1] then
-- Transform to hash for simplicity
-- Check all rules
for k,r in pairs(rules) do
local rule_elt = lua_util.override_defaults(default_options, r)
- rule_elt['redis'] = redis_params
+ rule_elt['redis'] = neural_common.redis_params
rule_elt['anns'] = {} -- Store ANNs here
if not rule_elt.prefix then
callback = ann_push_vector
})
+-- We also need to deal with settings
+rspamd_config:add_post_init(neural_common.process_rules_settings)
+
-- Add training scripts
for _,rule in pairs(settings.rules) do
- load_scripts(rule.redis)
- -- We also need to deal with settings
- rspamd_config:add_post_init(process_rules_settings)
+ neural_common.load_scripts(rule.redis)
-- This function will check ANNs in Redis when a worker is loaded
rspamd_config:add_on_load(function(cfg, ev_base, worker)
if worker:is_scanner() then
+++ /dev/null
-*** Settings ***
-Suite Setup Neural Setup
-Suite Teardown Neural Teardown
-Library Process
-Library ${TESTDIR}/lib/rspamd.py
-Resource ${TESTDIR}/lib/rspamd.robot
-Variables ${TESTDIR}/lib/vars.py
-
-*** Variables ***
-${URL_TLD} ${TESTDIR}/../lua/unit/test_tld.dat
-${CONFIG} ${TESTDIR}/configs/neural.conf
-${MESSAGE} ${TESTDIR}/messages/spam_message.eml
-${REDIS_SCOPE} Suite
-${RSPAMD_SCOPE} Suite
-
-*** Test Cases ***
-Train
- Sleep 2s Wait for redis mess
- FOR ${INDEX} IN RANGE 0 10
- Scan File ${MESSAGE} Settings={symbols_enabled = ["SPAM_SYMBOL"]}
- Expect Symbol SPAM_SYMBOL
- Scan File ${MESSAGE} Settings={symbols_enabled = ["HAM_SYMBOL"]}
- Expect Symbol HAM_SYMBOL
- END
-
-Check Neural HAM
- Sleep 2s Wait for neural to be loaded
- Scan File ${MESSAGE} Settings={symbols_enabled = ["HAM_SYMBOL"];groups_enabled=["neural"];symbols_disabled = ["NEURAL_LEARN"]}
- Expect Symbol NEURAL_HAM_SHORT
- Do Not Expect Symbol NEURAL_SPAM_SHORT
- Expect Symbol NEURAL_HAM_SHORT_PCA
- Do Not Expect Symbol NEURAL_SPAM_SHORT_PCA
-
-Check Neural SPAM
- Scan File ${MESSAGE} Settings={symbols_enabled = ["SPAM_SYMBOL"];groups_enabled=["neural"];symbols_disabled = ["NEURAL_LEARN"]}
- Expect Symbol NEURAL_SPAM_SHORT
- Do Not Expect Symbol NEURAL_HAM_SHORT
- Expect Symbol NEURAL_SPAM_SHORT_PCA
- Do Not Expect Symbol NEURAL_HAM_SHORT_PCA
-
-
-Train INVERSE
- FOR ${INDEX} IN RANGE 0 10
- Scan File ${MESSAGE} Settings={symbols_enabled = ["SPAM_SYMBOL"]; SPAM_SYMBOL = -5;}
- Expect Symbol SPAM_SYMBOL
- Scan File ${MESSAGE} Settings={symbols_enabled = ["HAM_SYMBOL"]; HAM_SYMBOL = 5;}
- Expect Symbol HAM_SYMBOL
- END
-
-Check Neural HAM INVERSE
- Sleep 2s Wait for neural to be loaded
- Scan File ${MESSAGE} Settings={symbols_enabled = ["HAM_SYMBOL"];groups_enabled=["neural"]}
- Expect Symbol NEURAL_SPAM_SHORT
- Expect Symbol NEURAL_SPAM_SHORT_PCA
- Do Not Expect Symbol NEURAL_HAM_SHORT
- Do Not Expect Symbol NEURAL_HAM_SHORT_PCA
-
-Check Neural SPAM INVERSE
- Scan File ${MESSAGE} Settings={symbols_enabled = ["SPAM_SYMBOL"];groups_enabled=["neural"]}
- Expect Symbol NEURAL_HAM_SHORT
- Expect Symbol NEURAL_HAM_SHORT_PCA
- Do Not Expect Symbol NEURAL_SPAM_SHORT
- Do Not Expect Symbol NEURAL_SPAM_SHORT_PCA
-
-*** Keywords ***
-Neural Setup
- ${TMPDIR} = Make Temporary Directory
- Set Suite Variable ${TMPDIR}
- Run Redis
- Generic Setup
-
-Neural Teardown
- Shutdown Process With Children ${REDIS_PID}
- Normal Teardown
--- /dev/null
+*** Settings ***
+Suite Setup Neural Setup
+Suite Teardown Neural Teardown
+Library Process
+Library ${TESTDIR}/lib/rspamd.py
+Resource ${TESTDIR}/lib/rspamd.robot
+Variables ${TESTDIR}/lib/vars.py
+
+*** Variables ***
+${URL_TLD} ${TESTDIR}/../lua/unit/test_tld.dat
+${CONFIG} ${TESTDIR}/configs/neural.conf
+${MESSAGE} ${TESTDIR}/messages/spam_message.eml
+${REDIS_SCOPE} Suite
+${RSPAMD_SCOPE} Suite
+
+*** Test Cases ***
+Train
+  Sleep  2s  Wait for redis to start
+ FOR ${INDEX} IN RANGE 0 10
+ Scan File ${MESSAGE} Settings={symbols_enabled = ["SPAM_SYMBOL"]}
+ Expect Symbol SPAM_SYMBOL
+ Scan File ${MESSAGE} Settings={symbols_enabled = ["HAM_SYMBOL"]}
+ Expect Symbol HAM_SYMBOL
+ END
+
+Check Neural HAM
+ Sleep 2s Wait for neural to be loaded
+ Scan File ${MESSAGE} Settings={symbols_enabled = ["HAM_SYMBOL"];groups_enabled=["neural"];symbols_disabled = ["NEURAL_LEARN"]}
+ Expect Symbol NEURAL_HAM_SHORT
+ Do Not Expect Symbol NEURAL_SPAM_SHORT
+ Expect Symbol NEURAL_HAM_SHORT_PCA
+ Do Not Expect Symbol NEURAL_SPAM_SHORT_PCA
+
+Check Neural SPAM
+ Scan File ${MESSAGE} Settings={symbols_enabled = ["SPAM_SYMBOL"];groups_enabled=["neural"];symbols_disabled = ["NEURAL_LEARN"]}
+ Expect Symbol NEURAL_SPAM_SHORT
+ Do Not Expect Symbol NEURAL_HAM_SHORT
+ Expect Symbol NEURAL_SPAM_SHORT_PCA
+ Do Not Expect Symbol NEURAL_HAM_SHORT_PCA
+
+
+Train INVERSE
+ FOR ${INDEX} IN RANGE 0 10
+ Scan File ${MESSAGE} Settings={symbols_enabled = ["SPAM_SYMBOL"]; SPAM_SYMBOL = -5;}
+ Expect Symbol SPAM_SYMBOL
+ Scan File ${MESSAGE} Settings={symbols_enabled = ["HAM_SYMBOL"]; HAM_SYMBOL = 5;}
+ Expect Symbol HAM_SYMBOL
+ END
+
+Check Neural HAM INVERSE
+ Sleep 2s Wait for neural to be loaded
+ Scan File ${MESSAGE} Settings={symbols_enabled = ["HAM_SYMBOL"];groups_enabled=["neural"]}
+ Expect Symbol NEURAL_SPAM_SHORT
+ Expect Symbol NEURAL_SPAM_SHORT_PCA
+ Do Not Expect Symbol NEURAL_HAM_SHORT
+ Do Not Expect Symbol NEURAL_HAM_SHORT_PCA
+
+Check Neural SPAM INVERSE
+ Scan File ${MESSAGE} Settings={symbols_enabled = ["SPAM_SYMBOL"];groups_enabled=["neural"]}
+ Expect Symbol NEURAL_HAM_SHORT
+ Expect Symbol NEURAL_HAM_SHORT_PCA
+ Do Not Expect Symbol NEURAL_SPAM_SHORT
+ Do Not Expect Symbol NEURAL_SPAM_SHORT_PCA
+
+*** Keywords ***
+Neural Setup
+ ${TMPDIR} = Make Temporary Directory
+ Set Suite Variable ${TMPDIR}
+ Run Redis
+ Generic Setup
+
+Neural Teardown
+ Shutdown Process With Children ${REDIS_PID}
+ Normal Teardown
--- /dev/null
+*** Settings ***
+Suite Setup Neural Setup
+Suite Teardown Neural Teardown
+Library Process
+Library ${TESTDIR}/lib/rspamd.py
+Resource ${TESTDIR}/lib/rspamd.robot
+Variables ${TESTDIR}/lib/vars.py
+
+*** Variables ***
+${URL_TLD} ${TESTDIR}/../lua/unit/test_tld.dat
+${CONFIG} ${TESTDIR}/configs/neural_noauto.conf
+${MESSAGE} ${TESTDIR}/messages/spam_message.eml
+${REDIS_SCOPE} Suite
+${RSPAMD_SCOPE} Suite
+
+*** Test Cases ***
+Collect training vectors & train manually
+ Scan File ${MESSAGE} Settings={symbols_enabled = ["SPAM_SYMBOL","SAVE_NN_ROW"]}
+ Expect Symbol SPAM_SYMBOL
+ # Save neural inputs for later
+ ${SPAM_ROW} = Get File ${SCAN_RESULT}[symbols][SAVE_NN_ROW][options][0]
+ Remove File ${SCAN_RESULT}[symbols][SAVE_NN_ROW][options][0]
+ Scan File ${MESSAGE} Settings={symbols_enabled = ["HAM_SYMBOL","SAVE_NN_ROW"]}
+ Expect Symbol HAM_SYMBOL
+ # Save neural inputs for later
+ ${HAM_ROW} = Get File ${SCAN_RESULT}[symbols][SAVE_NN_ROW][options][0]
+ Remove File ${SCAN_RESULT}[symbols][SAVE_NN_ROW][options][0]
+ ${HAM_ROW} = Run ${RSPAMADM} lua -a ${HAM_ROW} ${TESTDIR}/util/nn_unpack.lua
+ ${HAM_ROW} = Evaluate json.loads("${HAM_ROW}")
+ ${SPAM_ROW} = Run ${RSPAMADM} lua -a ${SPAM_ROW} ${TESTDIR}/util/nn_unpack.lua
+ ${SPAM_ROW} = Evaluate json.loads("${SPAM_ROW}")
+ ${HAM_VEC} = Evaluate [${HAM_ROW}] * 10
+ ${SPAM_VEC} = Evaluate [${SPAM_ROW}] * 10
+ ${json1} = Evaluate json.dumps({"spam_vec": ${SPAM_VEC}, "ham_vec": ${HAM_VEC}, "rule": "SHORT"})
+ # Save variables for use in inverse training
+ Set Suite Variable ${HAM_VEC}
+ Set Suite Variable ${SPAM_VEC}
+ HTTP POST ${LOCAL_ADDR} ${PORT_CONTROLLER} /plugins/neural/learn ${json1}
+ Sleep 2s Wait for neural to be loaded
+
+Check Neural HAM
+ Scan File ${MESSAGE} Settings={symbols_enabled = ["HAM_SYMBOL"];groups_enabled=["neural"];symbols_disabled = ["NEURAL_LEARN"]}
+ Do Not Expect Symbol NEURAL_SPAM_SHORT
+ Expect Symbol NEURAL_HAM_SHORT
+
+Check Neural SPAM
+ Scan File ${MESSAGE} Settings={symbols_enabled = ["SPAM_SYMBOL"];groups_enabled=["neural"];symbols_disabled = ["NEURAL_LEARN"]}
+ Do Not Expect Symbol NEURAL_HAM_SHORT
+ Expect Symbol NEURAL_SPAM_SHORT
+
+Train inverse
+ ${json2} = Evaluate json.dumps({"spam_vec": ${HAM_VEC}, "ham_vec": ${SPAM_VEC}, "rule": "SHORT"})
+ HTTP POST ${LOCAL_ADDR} ${PORT_CONTROLLER} /plugins/neural/learn ${json2}
+ Sleep 2s Wait for neural to be loaded
+
+Check Neural HAM - inverse
+ Scan File ${MESSAGE} Settings={symbols_enabled = ["HAM_SYMBOL"];groups_enabled=["neural"];symbols_disabled = ["NEURAL_LEARN"]}
+ Do Not Expect Symbol NEURAL_HAM_SHORT
+ Expect Symbol NEURAL_SPAM_SHORT
+
+Check Neural SPAM - inverse
+ Scan File ${MESSAGE} Settings={symbols_enabled = ["SPAM_SYMBOL"];groups_enabled=["neural"];symbols_disabled = ["NEURAL_LEARN"]}
+ Do Not Expect Symbol NEURAL_SPAM_SHORT
+ Expect Symbol NEURAL_HAM_SHORT
+
+*** Keywords ***
+Neural Setup
+ ${TMPDIR} = Make Temporary Directory
+ Set Suite Variable ${TMPDIR}
+ Run Redis
+ Generic Setup
+
+Neural Teardown
+ Shutdown Process With Children ${REDIS_PID}
+ Normal Teardown
--- /dev/null
+options = {
+ url_tld = "${URL_TLD}"
+ pidfile = "${TMPDIR}/rspamd.pid"
+ lua_path = "${INSTALLROOT}/share/rspamd/lib/?.lua"
+ filters = [];
+ explicit_modules = ["settings"];
+}
+
+logging = {
+ type = "file",
+ level = "debug"
+ filename = "${TMPDIR}/rspamd.log"
+ log_usec = true;
+}
+metric = {
+ name = "default",
+ actions = {
+ reject = 100500,
+ add_header = 50500,
+ }
+ unknown_weight = 1
+}
+worker {
+ type = normal
+ bind_socket = ${LOCAL_ADDR}:${PORT_NORMAL}
+ count = 1
+ task_timeout = 10s;
+}
+worker {
+ type = controller
+ bind_socket = ${LOCAL_ADDR}:${PORT_CONTROLLER}
+ count = 1
+ secure_ip = ["127.0.0.1", "::1"];
+ stats_path = "${TMPDIR}/stats.ucl"
+}
+
+modules {
+ path = "${TESTDIR}/../../src/plugins/lua/"
+}
+
+lua = "${TESTDIR}/lua/test_coverage.lua";
+
+# Neural plugin under test: two rules sharing one Redis backend.
+# SHORT_PCA additionally enables input reduction via max_inputs.
+neural {
+  rules {
+    SHORT {
+      train {
+        learning_rate = 0.001;
+        max_usages = 2;
+        spam_score = 1;
+        ham_score = -1;
+        max_trains = 10;
+        max_iterations = 250;
+        # Only cache vectors in the task; learning is driven explicitly by
+        # the tests through the controller endpoint
+        store_pool_only = true;
+      }
+      symbol_spam = "NEURAL_SPAM_SHORT";
+      symbol_ham = "NEURAL_HAM_SHORT";
+      ann_expire = 86400;
+      # Short poll interval so freshly trained ANNs are picked up quickly
+      watch_interval = 0.5;
+    }
+    SHORT_PCA {
+      train {
+        learning_rate = 0.001;
+        max_usages = 2;
+        spam_score = 1;
+        ham_score = -1;
+        max_trains = 10;
+        max_iterations = 250;
+        store_pool_only = true;
+      }
+      symbol_spam = "NEURAL_SPAM_SHORT_PCA";
+      symbol_ham = "NEURAL_HAM_SHORT_PCA";
+      ann_expire = 86400;
+      watch_interval = 0.5;
+      # Limits ANN inputs to 2, exercising the PCA code path
+      max_inputs = 2;
+    }
+  }
+  allow_local = true;
+
+}
+redis {
+  servers = "${REDIS_ADDR}:${REDIS_PORT}";
+  # Substitute variables inside Redis key names (rspamd redis option)
+  expand_keys = true;
+}
+
+lua = "${TESTDIR}/lua/neural.lua";
... ELSE Make Temporary Directory
Set Directory Ownership ${tmpdir} ${RSPAMD_USER} ${RSPAMD_GROUP}
${template} = Get File ${CONFIG}
+ # TODO: stop using this; we have Lupa now
FOR ${i} IN @{vargs}
${newvalue} = Replace Variables ${${i}}
Set To Dictionary ${d} ${i}=${newvalue}
Log ${config}
Create File ${tmpdir}/rspamd.conf ${config}
${result} = Run Process ${RSPAMD} -u ${RSPAMD_USER} -g ${RSPAMD_GROUP}
- ... -c ${tmpdir}/rspamd.conf env:TMPDIR=${tmpdir} env:DBDIR=${tmpdir} env:LD_LIBRARY_PATH=${TESTDIR}/../../contrib/aho-corasick stdout=DEVNULL stderr=DEVNULL
+ ... -c ${tmpdir}/rspamd.conf env:TMPDIR=${tmpdir} env:DBDIR=${tmpdir} env:LD_LIBRARY_PATH=${TESTDIR}/../../contrib/aho-corasick
+ ... env:RSPAMD_INSTALLROOT=${INSTALLROOT} stdout=DEVNULL stderr=DEVNULL
Run Keyword If ${result.rc} != 0 Log ${result.stderr}
Should Be Equal As Integers ${result.rc} 0
Wait Until Keyword Succeeds 10x 1 sec Check Pidfile ${tmpdir}/rspamd.pid timeout=0.5s
+local logger = require "rspamd_logger"
+
rspamd_config:register_symbol({
name = 'SPAM_SYMBOL',
score = 5.0,
callback = function()
return true, 'Fires always'
end
-})
\ No newline at end of file
+})
+
+-- Scored test symbol: allocates a per-task temporary file and records its
+-- name in the task cache so the idempotent stage below can dump the neural
+-- training vector into it. The filename is returned as the symbol option.
+rspamd_config.SAVE_NN_ROW = {
+  callback = function(task)
+    local fname = os.tmpname()
+    task:cache_set('nn_row_tmpfile', fname)
+    return true, 1.0, fname
+  end
+}
+
+-- Idempotent stage: after all filters have run, write the hex-encoded
+-- messagepack training vector (cached by the neural plugin under
+-- 'neural_vec_mpack' when store_pool_only is on) into the temp file created
+-- by SAVE_NN_ROW. Silently does nothing if no temp file was recorded.
+rspamd_config.SAVE_NN_ROW_IDEMPOTENT = {
+  callback = function(task)
+    -- Hex-encode arbitrary bytes so the file can be read back as plain text
+    local function tohex(str)
+      return (str:gsub('.', function (c)
+        return string.format('%02X', string.byte(c))
+      end))
+    end
+    local fname = task:cache_get('nn_row_tmpfile')
+    if not fname then
+      return
+    end
+    local f, err = io.open(fname, 'w')
+    if not f then
+      logger.errx(task, err)
+      return
+    end
+    -- Empty string fallback: the vector may be absent when learning was skipped
+    f:write(tohex(task:cache_get('neural_vec_mpack') or ''))
+    f:close()
+    return
+  end,
+  type = 'idempotent',
+  flags = 'explicit_disable',
+  priority = 10,
+}
+
+dofile(rspamd_env.INSTALLROOT .. "/share/rspamd/rules/controller/init.lua")
--- /dev/null
+local ucl = require "ucl"
+
+-- Decode a hex string (pairs of hex digits, as produced by the test's tohex
+-- helper) back into raw bytes.
+local function unhex(str)
+  return (str:gsub('..', function (cc)
+    return string.char(tonumber(cc, 16))
+  end))
+end
+
+-- CLI helper: arg[1] is a hex-encoded messagepack blob (the saved neural
+-- training vector); decode it and print compact JSON for test assertions.
+-- Exits non-zero with the parse error on stderr if decoding fails.
+local parser = ucl.parser()
+local ok, err = parser:parse_string(unhex(arg[1]), 'msgpack')
+if not ok then
+  io.stderr:write(err)
+  os.exit(1)
+end
+
+print(ucl.to_format(parser:get_object(), 'json-compact'))