From: Andrew Lewis
Date: Thu, 17 Dec 2020 09:28:09 +0000 (+0200)
Subject: [Feature] Add controller endpoint for training neural
X-Git-Tag: 2.7~65^2
X-Git-Url: https://source.dussan.org/?a=commitdiff_plain;h=refs%2Fpull%2F3570%2Fhead;p=rspamd.git

[Feature] Add controller endpoint for training neural

- Move neural functions to library
- Parameterise spawn_train
- neural plugin: Fix store_pool_only when autotrain is true
- neural plugin: Use cache_set instead of mempool
- Add test
---

diff --git a/lualib/plugins/neural.lua b/lualib/plugins/neural.lua
new file mode 100644
index 000000000..4d4c44b5d
--- /dev/null
+++ b/lualib/plugins/neural.lua
@@ -0,0 +1,779 @@
+--[[
+Copyright (c) 2020, Vsevolod Stakhov
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local fun = require "fun"
+local lua_redis = require "lua_redis"
+local lua_settings = require "lua_settings"
+local lua_util = require "lua_util"
+local meta_functions = require "lua_meta"
+local rspamd_kann = require "rspamd_kann"
+local rspamd_logger = require "rspamd_logger"
+local rspamd_tensor = require "rspamd_tensor"
+local rspamd_util = require "rspamd_util"
+
+local N = 'neural'
+
+-- Used in the prefix to avoid loading the wrong ANN
+local plugin_ver = '2'
+
+-- Module vars
+local default_options = {
+ train = {
+ max_trains = 1000,
+ max_epoch = 1000,
+ max_usages = 10,
+ max_iterations = 25, -- Torch style
+ mse = 0.001,
+ autotrain = true,
+ train_prob = 1.0,
+ learn_threads = 1,
+ learn_mode = 'balanced', -- Possible values: balanced, proportional
+ learning_rate = 0.01,
+ classes_bias = 0.0, -- balanced mode: what difference is allowed between classes (1:1 proportion means 0 bias)
+ spam_skip_prob = 0.0, -- proportional mode: spam skip probability (0-1)
+ ham_skip_prob = 0.0, -- proportional mode: ham skip probability
+ store_pool_only = false, -- store tokens in cache only (disables autotrain);
+ -- neural_vec_mpack stores the vector of training data as messagepack; neural_profile_digest stores the profile digest
+ },
+ watch_interval = 60.0,
+ lock_expire = 600,
+ learning_spawned = false,
+ ann_expire = 60 * 60 * 24 * 2, -- 2 days
+ hidden_layer_mult = 1.5, -- multiplier for the number of neurons in the hidden layer
+ symbol_spam = 'NEURAL_SPAM',
+ symbol_ham = 'NEURAL_HAM',
+ max_inputs = nil, -- when PCA is used
+ blacklisted_symbols = {}, -- list of symbols skipped in neural processing
+}
+
+-- Rule structure:
+-- * static config fields (see `default_options`)
+-- * prefix - name or defined prefix
+-- * settings - table of settings indexed by settings id, -1 is used when no settings defined
+
+-- Rule settings element defines elements for specific settings id:
+-- * symbols - static symbols profile (defined by config or extracted from symcache)
+-- * name - name of settings id
+-- * digest - digest of all symbols
+-- * ann - dynamic ANN configuration loaded from Redis
+-- * train - train data for ANN (e.g. 
the currently trained ANN) + +-- Settings ANN table is loaded from Redis and represents dynamic profile for ANN +-- Some elements are directly stored in Redis, ANN is, in turn loaded dynamically +-- * version - version of ANN loaded from redis +-- * redis_key - name of ANN key in Redis +-- * symbols - symbols in THIS PARTICULAR ANN (might be different from set.symbols) +-- * distance - distance between set.symbols and set.ann.symbols +-- * ann - kann object + +local settings = { + rules = {}, + prefix = 'rn', -- Neural network default prefix + max_profiles = 3, -- Maximum number of NN profiles stored +} + +-- Get module & Redis configuration +local module_config = rspamd_config:get_all_opt(N) +settings = lua_util.override_defaults(settings, module_config) +local redis_params = lua_redis.parse_redis_server('neural') + +-- Lua script that checks if we can store a new training vector +-- Uses the following keys: +-- key1 - ann key +-- returns nspam,nham (or nil if locked) +local redis_lua_script_vectors_len = [[ + local prefix = KEYS[1] + local locked = redis.call('HGET', prefix, 'lock') + if locked then + local host = redis.call('HGET', prefix, 'hostname') or 'unknown' + return string.format('%s:%s', host, locked) + end + local nspam = 0 + local nham = 0 + + local ret = redis.call('LLEN', prefix .. '_spam') + if ret then nspam = tonumber(ret) end + ret = redis.call('LLEN', prefix .. '_ham') + if ret then nham = tonumber(ret) end + + return {nspam,nham} +]] + +-- Lua script to invalidate ANNs by rank +-- Uses the following keys +-- key1 - prefix for keys +-- key2 - number of elements to leave +local redis_lua_script_maybe_invalidate = [[ + local card = redis.call('ZCARD', KEYS[1]) + local lim = tonumber(KEYS[2]) + if card > lim then + local to_delete = redis.call('ZRANGE', KEYS[1], 0, card - lim - 1) + for _,k in ipairs(to_delete) do + local tb = cjson.decode(k) + redis.call('DEL', tb.redis_key) + -- Also train vectors + redis.call('DEL', tb.redis_key .. '_spam') + redis.call('DEL', tb.redis_key .. '_ham') + end + redis.call('ZREMRANGEBYRANK', KEYS[1], 0, card - lim - 1) + return to_delete + else + return {} + end +]] + +-- Lua script to invalidate ANN from redis +-- Uses the following keys +-- key1 - prefix for keys +-- key2 - current time +-- key3 - key expire +-- key4 - hostname +local redis_lua_script_maybe_lock = [[ + local locked = redis.call('HGET', KEYS[1], 'lock') + local now = tonumber(KEYS[2]) + if locked then + locked = tonumber(locked) + local expire = tonumber(KEYS[3]) + if now > locked and (now - locked) < expire then + return {tostring(locked), redis.call('HGET', KEYS[1], 'hostname') or 'unknown'} + end + end + redis.call('HSET', KEYS[1], 'lock', tostring(now)) + redis.call('HSET', KEYS[1], 'hostname', KEYS[4]) + return 1 +]] + +-- Lua script to save and unlock ANN in redis +-- Uses the following keys +-- key1 - prefix for ANN +-- key2 - prefix for profile +-- key3 - compressed ANN +-- key4 - profile as JSON +-- key5 - expire in seconds +-- key6 - current time +-- key7 - old key +-- key8 - optional PCA +local redis_lua_script_save_unlock = [[ + local now = tonumber(KEYS[6]) + redis.call('ZADD', KEYS[2], now, KEYS[4]) + redis.call('HSET', KEYS[1], 'ann', KEYS[3]) + redis.call('DEL', KEYS[1] .. '_spam') + redis.call('DEL', KEYS[1] .. 
'_ham')
+ redis.call('HDEL', KEYS[1], 'lock')
+ redis.call('HDEL', KEYS[7], 'lock')
+ redis.call('EXPIRE', KEYS[1], tonumber(KEYS[5]))
+ if KEYS[8] then
+ redis.call('HSET', KEYS[1], 'pca', KEYS[8])
+ end
+ return 1
+]]
+
+local redis_script_id = {}
+
+local function load_scripts()
+ redis_script_id.vectors_len = lua_redis.add_redis_script(redis_lua_script_vectors_len,
+ redis_params)
+ redis_script_id.maybe_invalidate = lua_redis.add_redis_script(redis_lua_script_maybe_invalidate,
+ redis_params)
+ redis_script_id.maybe_lock = lua_redis.add_redis_script(redis_lua_script_maybe_lock,
+ redis_params)
+ redis_script_id.save_unlock = lua_redis.add_redis_script(redis_lua_script_save_unlock,
+ redis_params)
+end
+
+local function create_ann(n, nlayers, rule)
+ -- We ignore the number of layers so far when using kann
+ local nhidden = math.floor(n * (rule.hidden_layer_mult or 1.0) + 1.0)
+ local t = rspamd_kann.layer.input(n)
+ t = rspamd_kann.transform.relu(t)
+ t = rspamd_kann.layer.dense(t, nhidden);
+ t = rspamd_kann.layer.cost(t, 1, rspamd_kann.cost.ceb_neg)
+ return rspamd_kann.new.kann(t)
+end
+
+-- Fills ANN data for a specific settings element
+local function fill_set_ann(set, ann_key)
+ if not set.ann then
+ set.ann = {
+ symbols = set.symbols,
+ distance = 0,
+ digest = set.digest,
+ redis_key = ann_key,
+ version = 0,
+ }
+ end
+end
+
+-- This function takes all inputs, applies PCA transformation and returns the final
+-- PCA matrix as rspamd_tensor
+local function learn_pca(inputs, max_inputs)
+ local scatter_matrix = rspamd_tensor.scatter_matrix(rspamd_tensor.fromtable(inputs))
+ local eigenvals = scatter_matrix:eigen()
+ -- the scatter matrix is now filled with eigenvectors
+ lua_util.debugm(N, 'eigenvalues: %s', eigenvals)
+ local w = rspamd_tensor.new(2, max_inputs, #scatter_matrix[1])
+ for i=1,max_inputs do
+ w[i] = scatter_matrix[#scatter_matrix - i + 1]
+ end
+
+ lua_util.debugm(N, 'pca matrix: %s', w)
+
+ return w
+end
+
+-- This function is intended to extend the lock for the ANN during training
+-- It registers a periodic callback that increments the lock key every 30 seconds
+-- for as long as `set.learning_spawned` remains `true`
+local function register_lock_extender(rule, set, ev_base, ann_key)
+ rspamd_config:add_periodic(ev_base, 30.0,
+ function()
+ local function redis_lock_extend_cb(_err, _)
+ if _err then
+ rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s',
+ ann_key, _err)
+ else
+ rspamd_logger.infox(rspamd_config, 'extend lock for ANN %s for 30 seconds',
+ ann_key)
+ end
+ end
+
+ if set.learning_spawned then
+ lua_redis.redis_make_request_taskless(ev_base,
+ rspamd_config,
+ rule.redis,
+ nil,
+ true, -- is write
+ redis_lock_extend_cb, --callback
+ 'HINCRBY', -- command
+ {ann_key, 'lock', '30'}
+ )
+ else
+ lua_util.debugm(N, rspamd_config, "stop lock extension as learning_spawned is false")
+ return false -- do not plan any more updates
+ end
+
+ return true
+ end
+ )
+end
+
+local function can_push_train_vector(rule, task, learn_type, nspam, nham)
+ local train_opts = rule.train
+ local coin = math.random()
+
+ if train_opts.train_prob and coin < 1.0 - train_opts.train_prob then
+ rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin)
+ return false
+ end
+
+ if train_opts.learn_mode == 'balanced' then
+ -- Keep balanced training set based on number of spam and ham samples
+ if learn_type == 'spam' then
+ if nspam <= train_opts.max_trains then
+ if nspam > nham then
+ -- Apply sampling
+ local skip_rate = 1.0 - nham / (nspam + 1)
+ if coin < skip_rate - 
train_opts.classes_bias then
+ rspamd_logger.infox(task,
+ 'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored',
+ learn_type,
+ skip_rate - train_opts.classes_bias,
+ nspam, nham)
+ return false
+ end
+ end
+ return true
+ else -- Enough learns
+ rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many spam samples: %s',
+ learn_type,
+ nspam)
+ end
+ else
+ if nham <= train_opts.max_trains then
+ if nham > nspam then
+ -- Apply sampling
+ local skip_rate = 1.0 - nspam / (nham + 1)
+ if coin < skip_rate - train_opts.classes_bias then
+ rspamd_logger.infox(task,
+ 'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored',
+ learn_type,
+ skip_rate - train_opts.classes_bias,
+ nspam, nham)
+ return false
+ end
+ end
+ return true
+ else
+ rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many ham samples: %s', learn_type,
+ nham)
+ end
+ end
+ else
+ -- Probabilistic learn mode: we just skip learning if we already have enough samples or
+ -- if our coin drop is less than the desired probability
+ if learn_type == 'spam' then
+ if nspam <= train_opts.max_trains then
+ if train_opts.spam_skip_prob then
+ if coin <= train_opts.spam_skip_prob then
+ rspamd_logger.infox(task, 'skip %s sample probabilistically; probability %s (%s skip chance)', learn_type,
+ coin, train_opts.spam_skip_prob)
+ return false
+ end
+
+ return true
+ end
+ else
+ rspamd_logger.infox(task, 'skip %s sample; too many spam samples: %s (%s limit)', learn_type,
+ nspam, train_opts.max_trains)
+ end
+ else
+ if nham <= train_opts.max_trains then
+ if train_opts.ham_skip_prob then
+ if coin <= train_opts.ham_skip_prob then
+ rspamd_logger.infox(task, 'skip %s sample probabilistically; probability %s (%s skip chance)', learn_type,
+ coin, train_opts.ham_skip_prob)
+ return false
+ end
+
+ return true
+ end
+ else
+ rspamd_logger.infox(task, 'skip %s sample; too many ham samples: %s (%s limit)', learn_type,
+ nham, train_opts.max_trains)
+ end
+ end
+ end
+
+ return false
+end
+
+-- Closure generator for unlock function
+local function gen_unlock_cb(rule, set, ann_key)
+ return function (err)
+ if err then
+ rspamd_logger.errx(rspamd_config, 'cannot unlock ANN %s:%s at %s from redis: %s',
+ rule.prefix, set.name, ann_key, err)
+ else
+ lua_util.debugm(N, rspamd_config, 'unlocked ANN %s:%s at %s',
+ rule.prefix, set.name, ann_key)
+ end
+ end
+end
+
+-- Used to generate a new ANN key for a specific profile
+local function new_ann_key(rule, set, version)
+ local ann_key = string.format('%s_%s_%s_%s_%s', settings.prefix,
+ rule.prefix, set.name, set.digest:sub(1, 8), tostring(version))
+
+ return ann_key
+end
+
+local function redis_ann_prefix(rule, settings_name)
+ -- We also need to count metatokens:
+ local n = meta_functions.version
+ return string.format('%s%d_%s_%d_%s',
+ settings.prefix, plugin_ver, rule.prefix, n, settings_name)
+end
+
+-- This function receives training vectors, checks them, spawns learning and saves the ANN in Redis
+local function spawn_train(params)
+ -- Check training data sanity
+ -- Now we need to join inputs and create the appropriate test vectors
+ local n = #params.set.symbols +
+ meta_functions.rspamd_count_metatokens()
+
+ -- Now we can train ann
+ local train_ann = create_ann(params.rule.max_inputs or n, 3, params.rule)
+
+ if #params.ham_vec + #params.spam_vec < params.rule.train.max_trains / 2 then
+ -- Invalidate ANN as it is definitely invalid
+ -- TODO: add invalidation
+ assert(false)
+ 
else
+ local inputs, outputs = {}, {}
+
+ -- Used to show sparse vectors in a convenient format (for debugging only)
+ local function debug_vec(t)
+ local ret = {}
+ for i,v in ipairs(t) do
+ if v ~= 0 then
+ ret[#ret + 1] = string.format('%d=%.2f', i, v)
+ end
+ end
+
+ return ret
+ end
+
+ -- Make training set by joining vectors
+ -- KANN automatically shuffles those samples
+ -- 1.0 is used for spam and -1.0 is used for ham
+ -- It implies that output layer can express that (e.g. tanh output)
+ for _,e in ipairs(params.spam_vec) do
+ inputs[#inputs + 1] = e
+ outputs[#outputs + 1] = {1.0}
+ --rspamd_logger.debugm(N, rspamd_config, 'spam vector: %s', debug_vec(e))
+ end
+ for _,e in ipairs(params.ham_vec) do
+ inputs[#inputs + 1] = e
+ outputs[#outputs + 1] = {-1.0}
+ --rspamd_logger.debugm(N, rspamd_config, 'ham vector: %s', debug_vec(e))
+ end
+
+ -- Called in child process
+ local function train()
+ local log_thresh = params.rule.train.max_iterations / 10
+ local seen_nan = false
+
+ local function train_cb(iter, train_cost, value_cost)
+ if (iter * (params.rule.train.max_iterations / log_thresh)) % (params.rule.train.max_iterations) == 0 then
+ if train_cost ~= train_cost and not seen_nan then
+ -- We have nan :( try to log lots of stuff to dig into the problem
+ seen_nan = true
+ rspamd_logger.errx(rspamd_config, 'ANN %s:%s: train error: observed nan in error cost!; value cost = %s',
+ params.rule.prefix, params.set.name,
+ value_cost)
+ for i,e in ipairs(inputs) do
+ lua_util.debugm(N, rspamd_config, 'train vector %s -> %s',
+ debug_vec(e), outputs[i][1])
+ end
+ end
+
+ rspamd_logger.infox(rspamd_config,
+ "ANN %s:%s: learned from %s redis key in %s iterations, error: %s, value cost: %s",
+ params.rule.prefix, params.set.name,
+ params.ann_key,
+ iter,
+ train_cost,
+ value_cost)
+ end
+ end
+
+ lua_util.debugm(N, rspamd_config, "subprocess to learn ANN %s:%s has been started",
+ params.rule.prefix, params.set.name)
+
+ local ret,err = pcall(train_ann.train1, train_ann,
+ inputs, outputs, {
+ lr = params.rule.train.learning_rate,
+ max_epoch = params.rule.train.max_iterations,
+ cb = train_cb,
+ pca = (params.set.ann or {}).pca
+ })
+
+ if not ret then
+ rspamd_logger.errx(rspamd_config, "cannot train ann %s:%s: %s",
+ params.rule.prefix, params.set.name, err)
+
+ return nil
+ end
+
+ if not seen_nan then
+ local out = train_ann:save()
+ return out
+ else
+ return nil
+ end
+ end
+
+ params.set.learning_spawned = true
+
+ local function redis_save_cb(err)
+ if err then
+ rspamd_logger.errx(rspamd_config, 'cannot save ANN %s:%s to redis key %s: %s',
+ params.rule.prefix, params.set.name, params.ann_key, err)
+ lua_redis.redis_make_request_taskless(params.ev_base,
+ rspamd_config,
+ params.rule.redis,
+ nil,
+ false, -- is write
+ gen_unlock_cb(params.rule, params.set, params.ann_key), --callback
+ 'HDEL', -- command
+ {params.ann_key, 'lock'}
+ )
+ else
+ rspamd_logger.infox(rspamd_config, 'saved ANN %s:%s to redis: %s',
+ params.rule.prefix, params.set.name, params.set.ann.redis_key)
+ end
+ end
+
+ local function ann_trained(err, data)
+ params.set.learning_spawned = false
+ if err then
+ rspamd_logger.errx(rspamd_config, 'cannot train ANN %s:%s : %s',
+ params.rule.prefix, params.set.name, err)
+ lua_redis.redis_make_request_taskless(params.ev_base,
+ rspamd_config,
+ params.rule.redis,
+ nil,
+ true, -- is write
+ gen_unlock_cb(params.rule, params.set, params.ann_key), --callback
+ 'HDEL', -- command
+ {params.ann_key, 'lock'}
+ )
+ else
+ local ann_data = 
rspamd_util.zstd_compress(data) + local pca_data + + fill_set_ann(params.set, params.ann_key) + if params.set.ann.pca then + pca_data = rspamd_util.zstd_compress(params.set.ann.pca:save()) + end + -- Deserialise ANN from the child process + ann_trained = rspamd_kann.load(data) + local version = (params.set.ann.version or 0) + 1 + params.set.ann.version = version + params.set.ann.ann = ann_trained + params.set.ann.symbols = params.set.symbols + params.set.ann.redis_key = new_ann_key(params.rule, params.set, version) + + local profile = { + symbols = params.set.symbols, + digest = params.set.digest, + redis_key = params.set.ann.redis_key, + version = version + } + + local ucl = require "ucl" + local profile_serialized = ucl.to_format(profile, 'json-compact', true) + + rspamd_logger.infox(rspamd_config, + 'trained ANN %s:%s, %s bytes (%s compressed); %s rows in pca (%sb compressed); redis key: %s (old key %s)', + params.rule.prefix, params.set.name, + #data, #ann_data, + #(params.set.ann.pca or {}), #(pca_data or {}), + params.set.ann.redis_key, params.ann_key) + + lua_redis.exec_redis_script(redis_script_id.save_unlock, + {ev_base = params.ev_base, is_write = true}, + redis_save_cb, + {profile.redis_key, + redis_ann_prefix(params.rule, params.set.name), + ann_data, + profile_serialized, + tostring(params.rule.ann_expire), + tostring(os.time()), + params.ann_key, -- old key to unlock... + pca_data + }) + end + end + + if params.rule.max_inputs then + fill_set_ann(params.set, params.ann_key) + -- Train PCA in the main process, presumably it is not that long + params.set.ann.pca = learn_pca(inputs, params.rule.max_inputs) + end + + params.worker:spawn_process{ + func = train, + on_complete = ann_trained, + proctitle = string.format("ANN train for %s/%s", params.rule.prefix, params.set.name), + } + -- Spawn learn and register lock extension + params.set.learning_spawned = true + register_lock_extender(params.rule, params.set, params.ev_base, params.ann_key) + return + + end +end + +-- This function is used to adjust profiles and allowed setting ids for each rule +-- It must be called when all settings are already registered (e.g. at post-init for config) +local function process_rules_settings() + local function process_settings_elt(rule, selt) + local profile = rule.profile[selt.name] + if profile then + -- Use static user defined profile + -- Ensure that we have an array... + lua_util.debugm(N, rspamd_config, "use static profile for %s (%s): %s", + rule.prefix, selt.name, profile) + if not profile[1] then profile = lua_util.keys(profile) end + selt.symbols = profile + else + lua_util.debugm(N, rspamd_config, "use dynamic cfg based profile for %s (%s)", + rule.prefix, selt.name) + end + + local function filter_symbols_predicate(sname) + if settings.blacklisted_symbols and settings.blacklisted_symbols[sname] then + return false + end + local fl = rspamd_config:get_symbol_flags(sname) + if fl then + fl = lua_util.list_to_hash(fl) + + return not (fl.nostat or fl.idempotent or fl.skip or fl.composite) + end + + return false + end + + -- Generic stuff + table.sort(fun.totable(fun.filter(filter_symbols_predicate, selt.symbols))) + + selt.digest = lua_util.table_digest(selt.symbols) + selt.prefix = redis_ann_prefix(rule, selt.name) + + lua_redis.register_prefix(selt.prefix, N, + string.format('NN prefix for rule "%s"; settings id "%s"', + rule.prefix, selt.name), { + persistent = true, + type = 'zlist', + }) + -- Versions + lua_redis.register_prefix(selt.prefix .. 
'_\\d+', N,
+ string.format('NN storage for rule "%s"; settings id "%s"',
+ rule.prefix, selt.name), {
+ persistent = true,
+ type = 'hash',
+ })
+ lua_redis.register_prefix(selt.prefix .. '_\\d+_spam', N,
+ string.format('NN learning set (spam) for rule "%s"; settings id "%s"',
+ rule.prefix, selt.name), {
+ persistent = true,
+ type = 'list',
+ })
+ lua_redis.register_prefix(selt.prefix .. '_\\d+_ham', N,
+ string.format('NN learning set (ham) for rule "%s"; settings id "%s"',
+ rule.prefix, selt.name), {
+ persistent = true,
+ type = 'list',
+ })
+ end
+
+ for k,rule in pairs(settings.rules) do
+ if not rule.allowed_settings then
+ rule.allowed_settings = {}
+ elseif rule.allowed_settings == 'all' then
+ -- Extract all settings ids
+ rule.allowed_settings = lua_util.keys(lua_settings.all_settings())
+ end
+
+ -- Convert to a map -> true
+ rule.allowed_settings = lua_util.list_to_hash(rule.allowed_settings)
+
+ -- Check if we can work without settings
+ if k == 'default' or type(rule.default) ~= 'boolean' then
+ rule.default = true
+ end
+
+ rule.settings = {}
+
+ if rule.default then
+ local default_settings = {
+ symbols = lua_settings.default_symbols(),
+ name = 'default'
+ }
+
+ process_settings_elt(rule, default_settings)
+ rule.settings[-1] = default_settings -- Magic constant, but OK as settings are positive int32
+ end
+
+ -- Now, for each allowed settings, we store sorted symbols + digest
+ -- We set table rule.settings[id] -> { name = name, symbols = symbols, digest = digest }
+ for s,_ in pairs(rule.allowed_settings) do
+ -- Here, we have a name and a set of symbols
+ local settings_id = s
+ if type(settings_id) ~= 'number' then
+ settings_id = lua_settings.numeric_settings_id(s)
+ end
+ local selt = lua_settings.settings_by_id(settings_id)
+
+ local nelt = {
+ symbols = selt.symbols, -- Already sorted
+ name = selt.name
+ }
+
+ process_settings_elt(rule, nelt)
+ for id,ex in pairs(rule.settings) do
+ if type(ex) == 'table' then
+ if nelt and lua_util.distance_sorted(ex.symbols, nelt.symbols) == 0 then
+ -- Equal symbols, add reference
+ lua_util.debugm(N, rspamd_config,
+ 'added reference from settings id %s to %s; same symbols',
+ nelt.name, ex.name)
+ rule.settings[settings_id] = id
+ nelt = nil
+ end
+ end
+ end
+
+ if nelt then
+ rule.settings[settings_id] = nelt
+ lua_util.debugm(N, rspamd_config, 'added new settings id %s(%s) to %s',
+ nelt.name, settings_id, rule.prefix)
+ end
+ end
+ end
+end
+
+-- Extract settings element for a specific settings id
+local function get_rule_settings(task, rule)
+ local sid = task:get_settings_id() or -1
+ local set = rule.settings[sid]
+
+ if not set then return nil end
+
+ while type(set) == 'number' do
+ -- Reference to another settings! 
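+ -- (a numeric value references another settings id with identical symbols,
+ -- e.g. -1 for the default set; follow it until we reach a concrete table)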
+ set = rule.settings[set] + end + + return set +end + +local function result_to_vector(task, profile) + if not profile.zeros then + -- Fill zeros vector + local zeros = {} + for i=1,meta_functions.count_metatokens() do + zeros[i] = 0.0 + end + for _,_ in ipairs(profile.symbols) do + zeros[#zeros + 1] = 0.0 + end + profile.zeros = zeros + end + + local vec = lua_util.shallowcopy(profile.zeros) + local mt = meta_functions.rspamd_gen_metatokens(task) + + for i,v in ipairs(mt) do + vec[i] = v + end + + task:process_ann_tokens(profile.symbols, vec, #mt, 0.1) + + return vec +end + +return { + can_push_train_vector = can_push_train_vector, + create_ann = create_ann, + default_options = default_options, + gen_unlock_cb = gen_unlock_cb, + get_rule_settings = get_rule_settings, + load_scripts = load_scripts, + module_config = module_config, + new_ann_key = new_ann_key, + plugin_ver = plugin_ver, + process_rules_settings = process_rules_settings, + redis_ann_prefix = redis_ann_prefix, + redis_params = redis_params, + redis_script_id = redis_script_id, + result_to_vector = result_to_vector, + settings = settings, + spawn_train = spawn_train, +} diff --git a/rules/controller/init.lua b/rules/controller/init.lua index e5204da63..136081ddc 100644 --- a/rules/controller/init.lua +++ b/rules/controller/init.lua @@ -25,8 +25,9 @@ local rspamd_logger = require "rspamd_logger" -- Define default controller paths, could be overridden in local.d/controller.lua local controller_plugin_paths = { + maps = dofile(local_rules .. "/controller/maps.lua"), + neural = dofile(local_rules .. "/controller/neural.lua"), selectors = dofile(local_rules .. "/controller/selectors.lua"), - maps = dofile(local_rules .. "/controller/maps.lua") } if rspamd_util.file_exists(local_conf .. '/controller.lua') then @@ -62,4 +63,4 @@ for plug,paths in pairs(controller_plugin_paths) do plug, path, type(attrs)) end end -end \ No newline at end of file +end diff --git a/rules/controller/neural.lua b/rules/controller/neural.lua new file mode 100644 index 000000000..3207e008c --- /dev/null +++ b/rules/controller/neural.lua @@ -0,0 +1,72 @@ +--[[ +Copyright (c) 2020, Vsevolod Stakhov + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+]]-- + +local neural_common = require "plugins/neural" +local ts = require("tableshape").types +local ucl = require "ucl" + +local E = {} + +-- Controller neural plugin + +local learn_request_schema = ts.shape{ + ham_vec = ts.array_of(ts.array_of(ts.number)), + rule = ts.string:is_optional(), + spam_vec = ts.array_of(ts.array_of(ts.number)), +} + +local function handle_learn(task, conn) + local parser = ucl.parser() + local ok, err = parser:parse_text(task:get_rawbody()) + if not ok then + conn:send_error(400, err) + return + end + local req_params = parser:get_object() + + ok, err = learn_request_schema:transform(req_params) + if not ok then + conn:send_error(400, err) + return + end + + local rule_name = req_params.rule or 'default' + local rule = neural_common.settings.rules[rule_name] + local set = neural_common.get_rule_settings(task, rule) + local version = ((set.ann or E).version or 0) + 1 + + neural_common.spawn_train{ + ev_base = task:get_ev_base(), + ann_key = neural_common.new_ann_key(rule, set, version), + set = set, + rule = rule, + ham_vec = req_params.ham_vec, + spam_vec = req_params.spam_vec, + worker = task:get_worker(), + } + + conn:send_string('{"success" : true}') +end + +rspamd_config:add_post_init(neural_common.process_rules_settings) + +return { + learn = { + handler = handle_learn, + enable = true, + need_task = true, + }, +} diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua index 5eab75d76..3d1c387a5 100644 --- a/src/plugins/lua/neural.lua +++ b/src/plugins/lua/neural.lua @@ -19,22 +19,21 @@ if confighelp then return end -local rspamd_logger = require "rspamd_logger" -local rspamd_util = require "rspamd_util" -local rspamd_kann = require "rspamd_kann" -local rspamd_text = require "rspamd_text" +local fun = require "fun" local lua_redis = require "lua_redis" local lua_util = require "lua_util" +local lua_verdict = require "lua_verdict" +local neural_common = require "plugins/neural" +local rspamd_kann = require "rspamd_kann" +local rspamd_logger = require "rspamd_logger" local rspamd_tensor = require "rspamd_tensor" -local fun = require "fun" -local lua_settings = require "lua_settings" -local meta_functions = require "lua_meta" +local rspamd_text = require "rspamd_text" +local rspamd_util = require "rspamd_util" local ts = require("tableshape").types -local lua_verdict = require "lua_verdict" + local N = "neural" --- Used in prefix to avoid wrong ANN to be loaded -local plugin_ver = '2' +local settings = neural_common.settings -- Module vars local default_options = { @@ -52,7 +51,7 @@ local default_options = { classes_bias = 0.0, -- balanced mode: what difference is allowed between classes (1:1 proportion means 0 bias) spam_skip_prob = 0.0, -- proportional mode: spam skip probability (0-1) ham_skip_prob = 0.0, -- proportional mode: ham skip probability - store_pool_only = false, -- store tokens in mempool variable only (disables autotrain); + store_pool_only = false, -- store tokens in cache only (disables autotrain); -- neural_vec_mpack stores vector of training data in messagepack neural_profile_digest stores profile digest }, watch_interval = 60.0, @@ -77,207 +76,9 @@ local redis_profile_schema = ts.shape{ local has_blas = rspamd_tensor.has_blas() local text_cookie = rspamd_text.cookie --- Rule structure: --- * static config fields (see `default_options`) --- * prefix - name or defined prefix --- * settings - table of settings indexed by settings id, -1 is used when no settings defined - --- Rule settings element defines elements for specific 
settings id: --- * symbols - static symbols profile (defined by config or extracted from symcache) --- * name - name of settings id --- * digest - digest of all symbols --- * ann - dynamic ANN configuration loaded from Redis --- * train - train data for ANN (e.g. the currently trained ANN) - --- Settings ANN table is loaded from Redis and represents dynamic profile for ANN --- Some elements are directly stored in Redis, ANN is, in turn loaded dynamically --- * version - version of ANN loaded from redis --- * redis_key - name of ANN key in Redis --- * symbols - symbols in THIS PARTICULAR ANN (might be different from set.symbols) --- * distance - distance between set.symbols and set.ann.symbols --- * ann - kann object - -local settings = { - rules = {}, - prefix = 'rn', -- Neural network default prefix - max_profiles = 3, -- Maximum number of NN profiles stored -} - -local module_config = rspamd_config:get_all_opt("neural") -if not module_config then - -- Legacy - module_config = rspamd_config:get_all_opt("fann_redis") -end - - --- Lua script that checks if we can store a new training vector --- Uses the following keys: --- key1 - ann key --- returns nspam,nham (or nil if locked) -local redis_lua_script_vectors_len = [[ - local prefix = KEYS[1] - local locked = redis.call('HGET', prefix, 'lock') - if locked then - local host = redis.call('HGET', prefix, 'hostname') or 'unknown' - return string.format('%s:%s', host, locked) - end - local nspam = 0 - local nham = 0 - - local ret = redis.call('LLEN', prefix .. '_spam') - if ret then nspam = tonumber(ret) end - ret = redis.call('LLEN', prefix .. '_ham') - if ret then nham = tonumber(ret) end - - return {nspam,nham} -]] -local redis_lua_script_vectors_len_id = nil - --- Lua script to invalidate ANNs by rank --- Uses the following keys --- key1 - prefix for keys --- key2 - number of elements to leave -local redis_lua_script_maybe_invalidate = [[ - local card = redis.call('ZCARD', KEYS[1]) - local lim = tonumber(KEYS[2]) - if card > lim then - local to_delete = redis.call('ZRANGE', KEYS[1], 0, card - lim - 1) - for _,k in ipairs(to_delete) do - local tb = cjson.decode(k) - redis.call('DEL', tb.redis_key) - -- Also train vectors - redis.call('DEL', tb.redis_key .. '_spam') - redis.call('DEL', tb.redis_key .. 
'_ham') - end - redis.call('ZREMRANGEBYRANK', KEYS[1], 0, card - lim - 1) - return to_delete - else - return {} - end -]] -local redis_maybe_invalidate_id = nil - --- Lua script to invalidate ANN from redis --- Uses the following keys --- key1 - prefix for keys --- key2 - current time --- key3 - key expire --- key4 - hostname -local redis_lua_script_maybe_lock = [[ - local locked = redis.call('HGET', KEYS[1], 'lock') - local now = tonumber(KEYS[2]) - if locked then - locked = tonumber(locked) - local expire = tonumber(KEYS[3]) - if now > locked and (now - locked) < expire then - return {tostring(locked), redis.call('HGET', KEYS[1], 'hostname') or 'unknown'} - end - end - redis.call('HSET', KEYS[1], 'lock', tostring(now)) - redis.call('HSET', KEYS[1], 'hostname', KEYS[4]) - return 1 -]] -local redis_maybe_lock_id = nil - --- Lua script to save and unlock ANN in redis --- Uses the following keys --- key1 - prefix for ANN --- key2 - prefix for profile --- key3 - compressed ANN --- key4 - profile as JSON --- key5 - expire in seconds --- key6 - current time --- key7 - old key --- key8 - optional PCA -local redis_lua_script_save_unlock = [[ - local now = tonumber(KEYS[6]) - redis.call('ZADD', KEYS[2], now, KEYS[4]) - redis.call('HSET', KEYS[1], 'ann', KEYS[3]) - redis.call('DEL', KEYS[1] .. '_spam') - redis.call('DEL', KEYS[1] .. '_ham') - redis.call('HDEL', KEYS[1], 'lock') - redis.call('HDEL', KEYS[7], 'lock') - redis.call('EXPIRE', KEYS[1], tonumber(KEYS[5])) - if KEYS[8] then - redis.call('HSET', KEYS[1], 'pca', KEYS[8]) - end - return 1 -]] -local redis_save_unlock_id = nil - -local redis_params - -local function load_scripts(params) - redis_lua_script_vectors_len_id = lua_redis.add_redis_script(redis_lua_script_vectors_len, - params) - redis_maybe_invalidate_id = lua_redis.add_redis_script(redis_lua_script_maybe_invalidate, - params) - redis_maybe_lock_id = lua_redis.add_redis_script(redis_lua_script_maybe_lock, - params) - redis_save_unlock_id = lua_redis.add_redis_script(redis_lua_script_save_unlock, - params) -end - -local function result_to_vector(task, profile) - if not profile.zeros then - -- Fill zeros vector - local zeros = {} - for i=1,meta_functions.count_metatokens() do - zeros[i] = 0.0 - end - for _,_ in ipairs(profile.symbols) do - zeros[#zeros + 1] = 0.0 - end - profile.zeros = zeros - end - - local vec = lua_util.shallowcopy(profile.zeros) - local mt = meta_functions.rspamd_gen_metatokens(task) - - for i,v in ipairs(mt) do - vec[i] = v - end - - task:process_ann_tokens(profile.symbols, vec, #mt, 0.1) - - return vec -end - --- Used to generate new ANN key for specific profile -local function new_ann_key(rule, set, version) - local ann_key = string.format('%s_%s_%s_%s_%s', settings.prefix, - rule.prefix, set.name, set.digest:sub(1, 8), tostring(version)) - - return ann_key -end - --- Extract settings element for a specific settings id -local function get_rule_settings(task, rule) - local sid = task:get_settings_id() or -1 - - local set = rule.settings[sid] - - if not set then return nil end - - while type(set) == 'number' do - -- Reference to another settings! 
- set = rule.settings[set] - end - - return set -end - --- Generate redis prefix for specific rule and specific settings -local function redis_ann_prefix(rule, settings_name) - -- We also need to count metatokens: - local n = meta_functions.version - return string.format('%s%d_%s_%d_%s', - settings.prefix, plugin_ver, rule.prefix, n, settings_name) -end - -- Creates and stores ANN profile in Redis local function new_ann_profile(task, rule, set, version) - local ann_key = new_ann_key(rule, set, version) + local ann_key = neural_common.new_ann_key(rule, set, version, settings) local profile = { symbols = set.symbols, @@ -321,7 +122,7 @@ local function ann_scores_filter(task) local ann local profile - local set = get_rule_settings(task, rule) + local set = neural_common.get_rule_settings(task, rule) if set then if set.ann then ann = set.ann.ann @@ -336,7 +137,7 @@ local function ann_scores_filter(task) end if ann then - local vec = result_to_vector(task, profile) + local vec = neural_common.result_to_vector(task, profile) local score local out = ann:apply1(vec, set.ann.pca) @@ -357,112 +158,12 @@ local function ann_scores_filter(task) end end -local function create_ann(n, nlayers, rule) - -- We ignore number of layers so far when using kann - local nhidden = math.floor(n * (rule.hidden_layer_mult or 1.0) + 1.0) - local t = rspamd_kann.layer.input(n) - t = rspamd_kann.transform.relu(t) - t = rspamd_kann.layer.dense(t, nhidden); - t = rspamd_kann.layer.cost(t, 1, rspamd_kann.cost.ceb_neg) - return rspamd_kann.new.kann(t) -end - -local function can_push_train_vector(rule, task, learn_type, nspam, nham) - local train_opts = rule.train - local coin = math.random() - - if train_opts.train_prob and coin < 1.0 - train_opts.train_prob then - rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin) - return false - end - - if train_opts.learn_mode == 'balanced' then - -- Keep balanced training set based on number of spam and ham samples - if learn_type == 'spam' then - if nspam <= train_opts.max_trains then - if nspam > nham then - -- Apply sampling - local skip_rate = 1.0 - nham / (nspam + 1) - if coin < skip_rate - train_opts.classes_bias then - rspamd_logger.infox(task, - 'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored', - learn_type, - skip_rate - train_opts.classes_bias, - nspam, nham) - return false - end - end - return true - else -- Enough learns - rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many spam samples: %s', - learn_type, - nspam) - end - else - if nham <= train_opts.max_trains then - if nham > nspam then - -- Apply sampling - local skip_rate = 1.0 - nspam / (nham + 1) - if coin < skip_rate - train_opts.classes_bias then - rspamd_logger.infox(task, - 'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored', - learn_type, - skip_rate - train_opts.classes_bias, - nspam, nham) - return false - end - end - return true - else - rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many ham samples: %s', learn_type, - nham) - end - end - else - -- Probabilistic learn mode, we just skip learn if we already have enough samples or - -- if our coin drop is less than desired probability - if learn_type == 'spam' then - if nspam <= train_opts.max_trains then - if train_opts.spam_skip_prob then - if coin <= train_opts.spam_skip_prob then - rspamd_logger.infox(task, 'skip %s sample probabilisticaly; probability %s (%s skip chance)', learn_type, - coin, 
train_opts.spam_skip_prob) - return false - end - - return true - end - else - rspamd_logger.infox(task, 'skip %s sample; too many spam samples: %s (%s limit)', learn_type, - nspam, train_opts.max_trains) - end - else - if nham <= train_opts.max_trains then - if train_opts.ham_skip_prob then - if coin <= train_opts.ham_skip_prob then - rspamd_logger.infox(task, 'skip %s sample probabilisticaly; probability %s (%s skip chance)', learn_type, - coin, train_opts.ham_skip_prob) - return false - end - - return true - end - else - rspamd_logger.infox(task, 'skip %s sample; too many ham samples: %s (%s limit)', learn_type, - nham, train_opts.max_trains) - end - end - end - - return false -end - local function ann_push_task_result(rule, task, verdict, score, set) local train_opts = rule.train local learn_spam, learn_ham local skip_reason = 'unknown' - if train_opts.autotrain then + if not train_opts.store_pool_only and train_opts.autotrain then if train_opts.spam_score then learn_spam = score >= train_opts.spam_score @@ -510,10 +211,10 @@ local function ann_push_task_result(rule, task, verdict, score, set) learn_ham = false learn_spam = false - -- Explicitly store tokens in a mempool variable - local vec = result_to_vector(task, set) - task:get_mempool():set_variable('neural_vec_mpack', ucl.to_format(vec, 'msgpack')) - task:get_mempool():set_variable('neural_profile_digest', set.digest) + -- Explicitly store tokens in cache + local vec = neural_common.result_to_vector(task, set) + task:cache_set('neural_vec_mpack', ucl.to_format(vec, 'msgpack')) + task:cache_set('neural_profile_digest', set.digest) skip_reason = 'store_pool_only has been set' end end @@ -527,8 +228,8 @@ local function ann_push_task_result(rule, task, verdict, score, set) if not err and type(data) == 'table' then local nspam,nham = data[1],data[2] - if can_push_train_vector(rule, task, learn_type, nspam, nham) then - local vec = result_to_vector(task, set) + if neural_common.can_push_train_vector(rule, task, learn_type, nspam, nham) then + local vec = neural_common.result_to_vector(task, set) local str = rspamd_util.zstd_compress(table.concat(vec, ';')) local target_key = set.ann.redis_key .. '_' .. 
learn_type @@ -585,7 +286,7 @@ local function ann_push_task_result(rule, task, verdict, score, set) set.name) end - lua_redis.exec_redis_script(redis_lua_script_vectors_len_id, + lua_redis.exec_redis_script(neural_common.redis_script_id.vectors_len, {task = task, is_write = false}, vectors_len_cb, { @@ -605,284 +306,6 @@ end --- Offline training logic --- Closure generator for unlock function -local function gen_unlock_cb(rule, set, ann_key) - return function (err) - if err then - rspamd_logger.errx(rspamd_config, 'cannot unlock ANN %s:%s at %s from redis: %s', - rule.prefix, set.name, ann_key, err) - else - lua_util.debugm(N, rspamd_config, 'unlocked ANN %s:%s at %s', - rule.prefix, set.name, ann_key) - end - end -end - --- This function is intended to extend lock for ANN during training --- It registers periodic that increases locked key each 30 seconds unless --- `set.learning_spawned` is set to `true` -local function register_lock_extender(rule, set, ev_base, ann_key) - rspamd_config:add_periodic(ev_base, 30.0, - function() - local function redis_lock_extend_cb(_err, _) - if _err then - rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s', - ann_key, _err) - else - rspamd_logger.infox(rspamd_config, 'extend lock for ANN %s for 30 seconds', - ann_key) - end - end - - if set.learning_spawned then - lua_redis.redis_make_request_taskless(ev_base, - rspamd_config, - rule.redis, - nil, - true, -- is write - redis_lock_extend_cb, --callback - 'HINCRBY', -- command - {ann_key, 'lock', '30'} - ) - else - lua_util.debugm(N, rspamd_config, "stop lock extension as learning_spawned is false") - return false -- do not plan any more updates - end - - return true - end - ) -end - --- This function takes all inputs, applies PCA transformation and returns the final --- PCA matrix as rspamd_tensor -local function learn_pca(inputs, max_inputs) - local scatter_matrix = rspamd_tensor.scatter_matrix(rspamd_tensor.fromtable(inputs)) - local eigenvals = scatter_matrix:eigen() - -- scatter matrix is not filled with eigenvectors - lua_util.debugm(N, 'eigenvalues: %s', eigenvals) - local w = rspamd_tensor.new(2, max_inputs, #scatter_matrix[1]) - for i=1,max_inputs do - w[i] = scatter_matrix[#scatter_matrix - i + 1] - end - - lua_util.debugm(N, 'pca matrix: %s', w) - - return w -end - --- Fills ANN data for a specific settings element -local function fill_set_ann(set, ann_key) - if not set.ann then - set.ann = { - symbols = set.symbols, - distance = 0, - digest = set.digest, - redis_key = ann_key, - version = 0, - } - end -end - --- This function receives training vectors, checks them, spawn learning and saves ANN in Redis -local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_vec) - -- Check training data sanity - -- Now we need to join inputs and create the appropriate test vectors - local n = #set.symbols + - meta_functions.rspamd_count_metatokens() - - -- Now we can train ann - local train_ann = create_ann(rule.max_inputs or n, 3, rule) - - if #ham_vec + #spam_vec < rule.train.max_trains / 2 then - -- Invalidate ANN as it is definitely invalid - -- TODO: add invalidation - assert(false) - else - local inputs, outputs = {}, {} - - -- Used to show sparsed vectors in a convenient format (for debugging only) - local function debug_vec(t) - local ret = {} - for i,v in ipairs(t) do - if v ~= 0 then - ret[#ret + 1] = string.format('%d=%.2f', i, v) - end - end - - return ret - end - - -- Make training set by joining vectors - -- KANN automatically shuffles those samples - -- 
1.0 is used for spam and -1.0 is used for ham - -- It implies that output layer can express that (e.g. tanh output) - for _,e in ipairs(spam_vec) do - inputs[#inputs + 1] = e - outputs[#outputs + 1] = {1.0} - --rspamd_logger.debugm(N, rspamd_config, 'spam vector: %s', debug_vec(e)) - end - for _,e in ipairs(ham_vec) do - inputs[#inputs + 1] = e - outputs[#outputs + 1] = {-1.0} - --rspamd_logger.debugm(N, rspamd_config, 'ham vector: %s', debug_vec(e)) - end - - -- Called in child process - local function train() - local log_thresh = rule.train.max_iterations / 10 - local seen_nan = false - - local function train_cb(iter, train_cost, value_cost) - if (iter * (rule.train.max_iterations / log_thresh)) % (rule.train.max_iterations) == 0 then - if train_cost ~= train_cost and not seen_nan then - -- We have nan :( try to log lot's of stuff to dig into a problem - seen_nan = true - rspamd_logger.errx(rspamd_config, 'ANN %s:%s: train error: observed nan in error cost!; value cost = %s', - rule.prefix, set.name, - value_cost) - for i,e in ipairs(inputs) do - lua_util.debugm(N, rspamd_config, 'train vector %s -> %s', - debug_vec(e), outputs[i][1]) - end - end - - rspamd_logger.infox(rspamd_config, - "ANN %s:%s: learned from %s redis key in %s iterations, error: %s, value cost: %s", - rule.prefix, set.name, - ann_key, - iter, - train_cost, - value_cost) - end - end - - lua_util.debugm(N, rspamd_config, "subprocess to learn ANN %s:%s has been started", - rule.prefix, set.name) - - local ret,err = pcall(train_ann.train1, train_ann, - inputs, outputs, { - lr = rule.train.learning_rate, - max_epoch = rule.train.max_iterations, - cb = train_cb, - pca = (set.ann or {}).pca - }) - - if not ret then - rspamd_logger.errx(rspamd_config, "cannot train ann %s:%s: %s", - rule.prefix, set.name, err) - - return nil - end - - if not seen_nan then - local out = train_ann:save() - return out - else - return nil - end - end - - set.learning_spawned = true - - local function redis_save_cb(err) - if err then - rspamd_logger.errx(rspamd_config, 'cannot save ANN %s:%s to redis key %s: %s', - rule.prefix, set.name, ann_key, err) - lua_redis.redis_make_request_taskless(ev_base, - rspamd_config, - rule.redis, - nil, - false, -- is write - gen_unlock_cb(rule, set, ann_key), --callback - 'HDEL', -- command - {ann_key, 'lock'} - ) - else - rspamd_logger.infox(rspamd_config, 'saved ANN %s:%s to redis: %s', - rule.prefix, set.name, set.ann.redis_key) - end - end - - local function ann_trained(err, data) - set.learning_spawned = false - if err then - rspamd_logger.errx(rspamd_config, 'cannot train ANN %s:%s : %s', - rule.prefix, set.name, err) - lua_redis.redis_make_request_taskless(ev_base, - rspamd_config, - rule.redis, - nil, - true, -- is write - gen_unlock_cb(rule, set, ann_key), --callback - 'HDEL', -- command - {ann_key, 'lock'} - ) - else - local ann_data = rspamd_util.zstd_compress(data) - local pca_data - - fill_set_ann(set, ann_key) - if set.ann.pca then - pca_data = rspamd_util.zstd_compress(set.ann.pca:save()) - end - -- Deserialise ANN from the child process - ann_trained = rspamd_kann.load(data) - local version = (set.ann.version or 0) + 1 - set.ann.version = version - set.ann.ann = ann_trained - set.ann.symbols = set.symbols - set.ann.redis_key = new_ann_key(rule, set, version) - - local profile = { - symbols = set.symbols, - digest = set.digest, - redis_key = set.ann.redis_key, - version = version - } - - local ucl = require "ucl" - local profile_serialized = ucl.to_format(profile, 'json-compact', true) - - 
rspamd_logger.infox(rspamd_config, - 'trained ANN %s:%s, %s bytes (%s compressed); %s rows in pca (%sb compressed); redis key: %s (old key %s)', - rule.prefix, set.name, - #data, #ann_data, - #(set.ann.pca or {}), #(pca_data or {}), - set.ann.redis_key, ann_key) - - lua_redis.exec_redis_script(redis_save_unlock_id, - {ev_base = ev_base, is_write = true}, - redis_save_cb, - {profile.redis_key, - redis_ann_prefix(rule, set.name), - ann_data, - profile_serialized, - tostring(rule.ann_expire), - tostring(os.time()), - ann_key, -- old key to unlock... - pca_data - }) - end - end - - if rule.max_inputs then - fill_set_ann(set, ann_key) - -- Train PCA in the main process, presumably it is not that long - set.ann.pca = learn_pca(inputs, rule.max_inputs) - end - - worker:spawn_process{ - func = train, - on_complete = ann_trained, - proctitle = string.format("ANN train for %s/%s", rule.prefix, set.name), - } - end - -- Spawn learn and register lock extension - set.learning_spawned = true - register_lock_extender(rule, set, ev_base, ann_key) -end - -- Utility to extract and split saved training vectors to a table of tables local function process_training_vectors(data) return fun.totable(fun.map(function(tok) @@ -909,14 +332,16 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key) rule.redis, nil, true, -- is write - gen_unlock_cb(rule, set, ann_key), --callback + neural_common.gen_unlock_cb(rule, set, ann_key), --callback 'HDEL', -- command {ann_key, 'lock'} ) else -- Decompress and convert to numbers each training vector ham_elts = process_training_vectors(data) - spawn_train(worker, ev_base, rule, set, ann_key, ham_elts, spam_elts) + neural_common.spawn_train({worker = worker, ev_base = ev_base, + rule = rule, set = set, ann_key = ann_key, ham_vec = ham_elts, + spam_vec = spam_elts}) end end @@ -931,7 +356,7 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key) rule.redis, nil, true, -- is write - gen_unlock_cb(rule, set, ann_key), --callback + neural_common.gen_unlock_cb(rule, set, ann_key), --callback 'HDEL', -- command {ann_key, 'lock'} ) @@ -987,7 +412,7 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key) -- Call Redis script that tries to acquire a lock -- This script returns either a boolean or a pair {'lock_time', 'hostname'} when -- ANN is locked by another host (or a process, meh) - lua_redis.exec_redis_script(redis_maybe_lock_id, + lua_redis.exec_redis_script(neural_common.redis_script_id.maybe_lock, {ev_base = ev_base, is_write = true}, redis_lock_cb, { @@ -1376,7 +801,7 @@ local function cleanup_anns(rule, cfg, ev_base) end if type(set) == 'table' then - lua_redis.exec_redis_script(redis_maybe_invalidate_id, + lua_redis.exec_redis_script(neural_common.redis_script_id.maybe_invalidate, {ev_base = ev_base, is_write = true}, invalidate_cb, {set.prefix, tostring(settings.max_profiles)}) @@ -1411,7 +836,7 @@ local function ann_push_vector(task) end for _,rule in pairs(settings.rules) do - local set = get_rule_settings(task, rule) + local set = neural_common.get_rule_settings(task, rule) if set then ann_push_task_result(rule, task, verdict, score, set) @@ -1423,155 +848,20 @@ local function ann_push_vector(task) end --- This function is used to adjust profiles and allowed setting ids for each rule --- It must be called when all settings are already registered (e.g. 
at post-init for config) -local function process_rules_settings() - local function process_settings_elt(rule, selt) - local profile = rule.profile[selt.name] - if profile then - -- Use static user defined profile - -- Ensure that we have an array... - lua_util.debugm(N, rspamd_config, "use static profile for %s (%s): %s", - rule.prefix, selt.name, profile) - if not profile[1] then profile = lua_util.keys(profile) end - selt.symbols = profile - else - lua_util.debugm(N, rspamd_config, "use dynamic cfg based profile for %s (%s)", - rule.prefix, selt.name) - end - - local function filter_symbols_predicate(sname) - if settings.blacklisted_symbols and settings.blacklisted_symbols[sname] then - return false - end - local fl = rspamd_config:get_symbol_flags(sname) - if fl then - fl = lua_util.list_to_hash(fl) - - return not (fl.nostat or fl.idempotent or fl.skip or fl.composite) - end - - return false - end - - -- Generic stuff - table.sort(fun.totable(fun.filter(filter_symbols_predicate, selt.symbols))) - - selt.digest = lua_util.table_digest(selt.symbols) - selt.prefix = redis_ann_prefix(rule, selt.name) - - lua_redis.register_prefix(selt.prefix, N, - string.format('NN prefix for rule "%s"; settings id "%s"', - rule.prefix, selt.name), { - persistent = true, - type = 'zlist', - }) - -- Versions - lua_redis.register_prefix(selt.prefix .. '_\\d+', N, - string.format('NN storage for rule "%s"; settings id "%s"', - rule.prefix, selt.name), { - persistent = true, - type = 'hash', - }) - lua_redis.register_prefix(selt.prefix .. '_\\d+_spam', N, - string.format('NN learning set (spam) for rule "%s"; settings id "%s"', - rule.prefix, selt.name), { - persistent = true, - type = 'list', - }) - lua_redis.register_prefix(selt.prefix .. '_\\d+_ham', N, - string.format('NN learning set (spam) for rule "%s"; settings id "%s"', - rule.prefix, selt.name), { - persistent = true, - type = 'list', - }) - end - - for k,rule in pairs(settings.rules) do - if not rule.allowed_settings then - rule.allowed_settings = {} - elseif rule.allowed_settings == 'all' then - -- Extract all settings ids - rule.allowed_settings = lua_util.keys(lua_settings.all_settings()) - end - - -- Convert to a map -> true - rule.allowed_settings = lua_util.list_to_hash(rule.allowed_settings) - - -- Check if we can work without settings - if k == 'default' or type(rule.default) ~= 'boolean' then - rule.default = true - end - - rule.settings = {} - - if rule.default then - local default_settings = { - symbols = lua_settings.default_symbols(), - name = 'default' - } - - process_settings_elt(rule, default_settings) - rule.settings[-1] = default_settings -- Magic constant, but OK as settings are positive int32 - end - - -- Now, for each allowed settings, we store sorted symbols + digest - -- We set table rule.settings[id] -> { name = name, symbols = symbols, digest = digest } - for s,_ in pairs(rule.allowed_settings) do - -- Here, we have a name, set of symbols and - local settings_id = s - if type(settings_id) ~= 'number' then - settings_id = lua_settings.numeric_settings_id(s) - end - local selt = lua_settings.settings_by_id(settings_id) - - local nelt = { - symbols = selt.symbols, -- Already sorted - name = selt.name - } - - process_settings_elt(rule, nelt) - for id,ex in pairs(rule.settings) do - if type(ex) == 'table' then - if nelt and lua_util.distance_sorted(ex.symbols, nelt.symbols) == 0 then - -- Equal symbols, add reference - lua_util.debugm(N, rspamd_config, - 'added reference from settings id %s to %s; same symbols', - nelt.name, 
ex.name) - rule.settings[settings_id] = id - nelt = nil - end - end - end - - if nelt then - rule.settings[settings_id] = nelt - lua_util.debugm(N, rspamd_config, 'added new settings id %s(%s) to %s', - nelt.name, settings_id, rule.prefix) - end - end - end -end - -redis_params = lua_redis.parse_redis_server('neural') - -if not redis_params then - redis_params = lua_redis.parse_redis_server('fann_redis') -end - -- Initialization part -if not (module_config and type(module_config) == 'table') or not redis_params then +if not (neural_common.module_config and type(neural_common.module_config) == 'table') + or not neural_common.redis_params then rspamd_logger.infox(rspamd_config, 'Module is unconfigured') lua_util.disable_module(N, "redis") return end -local rules = module_config['rules'] +local rules = neural_common.module_config['rules'] if not rules then -- Use legacy configuration rules = {} - rules['default'] = module_config + rules['default'] = neural_common.module_config end local id = rspamd_config:register_symbol({ @@ -1582,8 +872,7 @@ local id = rspamd_config:register_symbol({ callback = ann_scores_filter }) -settings = lua_util.override_defaults(settings, module_config) -settings.rules = {} -- Reset unless validated further in the cycle +neural_common.settings.rules = {} -- Reset unless validated further in the cycle if settings.blacklisted_symbols and settings.blacklisted_symbols[1] then -- Transform to hash for simplicity @@ -1593,7 +882,7 @@ end -- Check all rules for k,r in pairs(rules) do local rule_elt = lua_util.override_defaults(default_options, r) - rule_elt['redis'] = redis_params + rule_elt['redis'] = neural_common.redis_params rule_elt['anns'] = {} -- Store ANNs here if not rule_elt.prefix then @@ -1651,11 +940,12 @@ rspamd_config:register_symbol({ callback = ann_push_vector }) +-- We also need to deal with settings +rspamd_config:add_post_init(neural_common.process_rules_settings) + -- Add training scripts for _,rule in pairs(settings.rules) do - load_scripts(rule.redis) - -- We also need to deal with settings - rspamd_config:add_post_init(process_rules_settings) + neural_common.load_scripts(rule.redis) -- This function will check ANNs in Redis when a worker is loaded rspamd_config:add_on_load(function(cfg, ev_base, worker) if worker:is_scanner() then diff --git a/test/functional/cases/330_neural.robot b/test/functional/cases/330_neural.robot deleted file mode 100644 index 8ce342838..000000000 --- a/test/functional/cases/330_neural.robot +++ /dev/null @@ -1,74 +0,0 @@ -*** Settings *** -Suite Setup Neural Setup -Suite Teardown Neural Teardown -Library Process -Library ${TESTDIR}/lib/rspamd.py -Resource ${TESTDIR}/lib/rspamd.robot -Variables ${TESTDIR}/lib/vars.py - -*** Variables *** -${URL_TLD} ${TESTDIR}/../lua/unit/test_tld.dat -${CONFIG} ${TESTDIR}/configs/neural.conf -${MESSAGE} ${TESTDIR}/messages/spam_message.eml -${REDIS_SCOPE} Suite -${RSPAMD_SCOPE} Suite - -*** Test Cases *** -Train - Sleep 2s Wait for redis mess - FOR ${INDEX} IN RANGE 0 10 - Scan File ${MESSAGE} Settings={symbols_enabled = ["SPAM_SYMBOL"]} - Expect Symbol SPAM_SYMBOL - Scan File ${MESSAGE} Settings={symbols_enabled = ["HAM_SYMBOL"]} - Expect Symbol HAM_SYMBOL - END - -Check Neural HAM - Sleep 2s Wait for neural to be loaded - Scan File ${MESSAGE} Settings={symbols_enabled = ["HAM_SYMBOL"];groups_enabled=["neural"];symbols_disabled = ["NEURAL_LEARN"]} - Expect Symbol NEURAL_HAM_SHORT - Do Not Expect Symbol NEURAL_SPAM_SHORT - Expect Symbol NEURAL_HAM_SHORT_PCA - Do Not Expect Symbol 
NEURAL_SPAM_SHORT_PCA - -Check Neural SPAM - Scan File ${MESSAGE} Settings={symbols_enabled = ["SPAM_SYMBOL"];groups_enabled=["neural"];symbols_disabled = ["NEURAL_LEARN"]} - Expect Symbol NEURAL_SPAM_SHORT - Do Not Expect Symbol NEURAL_HAM_SHORT - Expect Symbol NEURAL_SPAM_SHORT_PCA - Do Not Expect Symbol NEURAL_HAM_SHORT_PCA - - -Train INVERSE - FOR ${INDEX} IN RANGE 0 10 - Scan File ${MESSAGE} Settings={symbols_enabled = ["SPAM_SYMBOL"]; SPAM_SYMBOL = -5;} - Expect Symbol SPAM_SYMBOL - Scan File ${MESSAGE} Settings={symbols_enabled = ["HAM_SYMBOL"]; HAM_SYMBOL = 5;} - Expect Symbol HAM_SYMBOL - END - -Check Neural HAM INVERSE - Sleep 2s Wait for neural to be loaded - Scan File ${MESSAGE} Settings={symbols_enabled = ["HAM_SYMBOL"];groups_enabled=["neural"]} - Expect Symbol NEURAL_SPAM_SHORT - Expect Symbol NEURAL_SPAM_SHORT_PCA - Do Not Expect Symbol NEURAL_HAM_SHORT - Do Not Expect Symbol NEURAL_HAM_SHORT_PCA - -Check Neural SPAM INVERSE - Scan File ${MESSAGE} Settings={symbols_enabled = ["SPAM_SYMBOL"];groups_enabled=["neural"]} - Expect Symbol NEURAL_HAM_SHORT - Expect Symbol NEURAL_HAM_SHORT_PCA - Do Not Expect Symbol NEURAL_SPAM_SHORT - Do Not Expect Symbol NEURAL_SPAM_SHORT_PCA - -*** Keywords *** -Neural Setup - ${TMPDIR} = Make Temporary Directory - Set Suite Variable ${TMPDIR} - Run Redis - Generic Setup - -Neural Teardown - Shutdown Process With Children ${REDIS_PID} - Normal Teardown diff --git a/test/functional/cases/330_neural/001_autotrain.robot b/test/functional/cases/330_neural/001_autotrain.robot new file mode 100644 index 000000000..8ce342838 --- /dev/null +++ b/test/functional/cases/330_neural/001_autotrain.robot @@ -0,0 +1,74 @@ +*** Settings *** +Suite Setup Neural Setup +Suite Teardown Neural Teardown +Library Process +Library ${TESTDIR}/lib/rspamd.py +Resource ${TESTDIR}/lib/rspamd.robot +Variables ${TESTDIR}/lib/vars.py + +*** Variables *** +${URL_TLD} ${TESTDIR}/../lua/unit/test_tld.dat +${CONFIG} ${TESTDIR}/configs/neural.conf +${MESSAGE} ${TESTDIR}/messages/spam_message.eml +${REDIS_SCOPE} Suite +${RSPAMD_SCOPE} Suite + +*** Test Cases *** +Train + Sleep 2s Wait for redis mess + FOR ${INDEX} IN RANGE 0 10 + Scan File ${MESSAGE} Settings={symbols_enabled = ["SPAM_SYMBOL"]} + Expect Symbol SPAM_SYMBOL + Scan File ${MESSAGE} Settings={symbols_enabled = ["HAM_SYMBOL"]} + Expect Symbol HAM_SYMBOL + END + +Check Neural HAM + Sleep 2s Wait for neural to be loaded + Scan File ${MESSAGE} Settings={symbols_enabled = ["HAM_SYMBOL"];groups_enabled=["neural"];symbols_disabled = ["NEURAL_LEARN"]} + Expect Symbol NEURAL_HAM_SHORT + Do Not Expect Symbol NEURAL_SPAM_SHORT + Expect Symbol NEURAL_HAM_SHORT_PCA + Do Not Expect Symbol NEURAL_SPAM_SHORT_PCA + +Check Neural SPAM + Scan File ${MESSAGE} Settings={symbols_enabled = ["SPAM_SYMBOL"];groups_enabled=["neural"];symbols_disabled = ["NEURAL_LEARN"]} + Expect Symbol NEURAL_SPAM_SHORT + Do Not Expect Symbol NEURAL_HAM_SHORT + Expect Symbol NEURAL_SPAM_SHORT_PCA + Do Not Expect Symbol NEURAL_HAM_SHORT_PCA + + +Train INVERSE + FOR ${INDEX} IN RANGE 0 10 + Scan File ${MESSAGE} Settings={symbols_enabled = ["SPAM_SYMBOL"]; SPAM_SYMBOL = -5;} + Expect Symbol SPAM_SYMBOL + Scan File ${MESSAGE} Settings={symbols_enabled = ["HAM_SYMBOL"]; HAM_SYMBOL = 5;} + Expect Symbol HAM_SYMBOL + END + +Check Neural HAM INVERSE + Sleep 2s Wait for neural to be loaded + Scan File ${MESSAGE} Settings={symbols_enabled = ["HAM_SYMBOL"];groups_enabled=["neural"]} + Expect Symbol NEURAL_SPAM_SHORT + Expect Symbol NEURAL_SPAM_SHORT_PCA + Do Not Expect 
Symbol NEURAL_HAM_SHORT + Do Not Expect Symbol NEURAL_HAM_SHORT_PCA + +Check Neural SPAM INVERSE + Scan File ${MESSAGE} Settings={symbols_enabled = ["SPAM_SYMBOL"];groups_enabled=["neural"]} + Expect Symbol NEURAL_HAM_SHORT + Expect Symbol NEURAL_HAM_SHORT_PCA + Do Not Expect Symbol NEURAL_SPAM_SHORT + Do Not Expect Symbol NEURAL_SPAM_SHORT_PCA + +*** Keywords *** +Neural Setup + ${TMPDIR} = Make Temporary Directory + Set Suite Variable ${TMPDIR} + Run Redis + Generic Setup + +Neural Teardown + Shutdown Process With Children ${REDIS_PID} + Normal Teardown diff --git a/test/functional/cases/330_neural/002_manualtrain.robot b/test/functional/cases/330_neural/002_manualtrain.robot new file mode 100644 index 000000000..22a42120e --- /dev/null +++ b/test/functional/cases/330_neural/002_manualtrain.robot @@ -0,0 +1,75 @@ +*** Settings *** +Suite Setup Neural Setup +Suite Teardown Neural Teardown +Library Process +Library ${TESTDIR}/lib/rspamd.py +Resource ${TESTDIR}/lib/rspamd.robot +Variables ${TESTDIR}/lib/vars.py + +*** Variables *** +${URL_TLD} ${TESTDIR}/../lua/unit/test_tld.dat +${CONFIG} ${TESTDIR}/configs/neural_noauto.conf +${MESSAGE} ${TESTDIR}/messages/spam_message.eml +${REDIS_SCOPE} Suite +${RSPAMD_SCOPE} Suite + +*** Test Cases *** +Collect training vectors & train manually + Scan File ${MESSAGE} Settings={symbols_enabled = ["SPAM_SYMBOL","SAVE_NN_ROW"]} + Expect Symbol SPAM_SYMBOL + # Save neural inputs for later + ${SPAM_ROW} = Get File ${SCAN_RESULT}[symbols][SAVE_NN_ROW][options][0] + Remove File ${SCAN_RESULT}[symbols][SAVE_NN_ROW][options][0] + Scan File ${MESSAGE} Settings={symbols_enabled = ["HAM_SYMBOL","SAVE_NN_ROW"]} + Expect Symbol HAM_SYMBOL + # Save neural inputs for later + ${HAM_ROW} = Get File ${SCAN_RESULT}[symbols][SAVE_NN_ROW][options][0] + Remove File ${SCAN_RESULT}[symbols][SAVE_NN_ROW][options][0] + ${HAM_ROW} = Run ${RSPAMADM} lua -a ${HAM_ROW} ${TESTDIR}/util/nn_unpack.lua + ${HAM_ROW} = Evaluate json.loads("${HAM_ROW}") + ${SPAM_ROW} = Run ${RSPAMADM} lua -a ${SPAM_ROW} ${TESTDIR}/util/nn_unpack.lua + ${SPAM_ROW} = Evaluate json.loads("${SPAM_ROW}") + ${HAM_VEC} = Evaluate [${HAM_ROW}] * 10 + ${SPAM_VEC} = Evaluate [${SPAM_ROW}] * 10 + ${json1} = Evaluate json.dumps({"spam_vec": ${SPAM_VEC}, "ham_vec": ${HAM_VEC}, "rule": "SHORT"}) + # Save variables for use in inverse training + Set Suite Variable ${HAM_VEC} + Set Suite Variable ${SPAM_VEC} + HTTP POST ${LOCAL_ADDR} ${PORT_CONTROLLER} /plugins/neural/learn ${json1} + Sleep 2s Wait for neural to be loaded + +Check Neural HAM + Scan File ${MESSAGE} Settings={symbols_enabled = ["HAM_SYMBOL"];groups_enabled=["neural"];symbols_disabled = ["NEURAL_LEARN"]} + Do Not Expect Symbol NEURAL_SPAM_SHORT + Expect Symbol NEURAL_HAM_SHORT + +Check Neural SPAM + Scan File ${MESSAGE} Settings={symbols_enabled = ["SPAM_SYMBOL"];groups_enabled=["neural"];symbols_disabled = ["NEURAL_LEARN"]} + Do Not Expect Symbol NEURAL_HAM_SHORT + Expect Symbol NEURAL_SPAM_SHORT + +Train inverse + ${json2} = Evaluate json.dumps({"spam_vec": ${HAM_VEC}, "ham_vec": ${SPAM_VEC}, "rule": "SHORT"}) + HTTP POST ${LOCAL_ADDR} ${PORT_CONTROLLER} /plugins/neural/learn ${json2} + Sleep 2s Wait for neural to be loaded + +Check Neural HAM - inverse + Scan File ${MESSAGE} Settings={symbols_enabled = ["HAM_SYMBOL"];groups_enabled=["neural"];symbols_disabled = ["NEURAL_LEARN"]} + Do Not Expect Symbol NEURAL_HAM_SHORT + Expect Symbol NEURAL_SPAM_SHORT + +Check Neural SPAM - inverse + Scan File ${MESSAGE} Settings={symbols_enabled = 
["SPAM_SYMBOL"];groups_enabled=["neural"];symbols_disabled = ["NEURAL_LEARN"]} + Do Not Expect Symbol NEURAL_SPAM_SHORT + Expect Symbol NEURAL_HAM_SHORT + +*** Keywords *** +Neural Setup + ${TMPDIR} = Make Temporary Directory + Set Suite Variable ${TMPDIR} + Run Redis + Generic Setup + +Neural Teardown + Shutdown Process With Children ${REDIS_PID} + Normal Teardown diff --git a/test/functional/configs/neural_noauto.conf b/test/functional/configs/neural_noauto.conf new file mode 100644 index 000000000..55f0a4283 --- /dev/null +++ b/test/functional/configs/neural_noauto.conf @@ -0,0 +1,85 @@ +options = { + url_tld = "${URL_TLD}" + pidfile = "${TMPDIR}/rspamd.pid" + lua_path = "${INSTALLROOT}/share/rspamd/lib/?.lua" + filters = []; + explicit_modules = ["settings"]; +} + +logging = { + type = "file", + level = "debug" + filename = "${TMPDIR}/rspamd.log" + log_usec = true; +} +metric = { + name = "default", + actions = { + reject = 100500, + add_header = 50500, + } + unknown_weight = 1 +} +worker { + type = normal + bind_socket = ${LOCAL_ADDR}:${PORT_NORMAL} + count = 1 + task_timeout = 10s; +} +worker { + type = controller + bind_socket = ${LOCAL_ADDR}:${PORT_CONTROLLER} + count = 1 + secure_ip = ["127.0.0.1", "::1"]; + stats_path = "${TMPDIR}/stats.ucl" +} + +modules { + path = "${TESTDIR}/../../src/plugins/lua/" +} + +lua = "${TESTDIR}/lua/test_coverage.lua"; + +neural { + rules { + SHORT { + train { + learning_rate = 0.001; + max_usages = 2; + spam_score = 1; + ham_score = -1; + max_trains = 10; + max_iterations = 250; + store_pool_only = true; + } + symbol_spam = "NEURAL_SPAM_SHORT"; + symbol_ham = "NEURAL_HAM_SHORT"; + ann_expire = 86400; + watch_interval = 0.5; + } + SHORT_PCA { + train { + learning_rate = 0.001; + max_usages = 2; + spam_score = 1; + ham_score = -1; + max_trains = 10; + max_iterations = 250; + store_pool_only = true; + } + symbol_spam = "NEURAL_SPAM_SHORT_PCA"; + symbol_ham = "NEURAL_HAM_SHORT_PCA"; + ann_expire = 86400; + watch_interval = 0.5; + max_inputs = 2; + } + } + allow_local = true; + +} +redis { + servers = "${REDIS_ADDR}:${REDIS_PORT}"; + expand_keys = true; +} + +lua = "${TESTDIR}/lua/neural.lua"; diff --git a/test/functional/lib/rspamd.robot b/test/functional/lib/rspamd.robot index 53d4e70f9..0b6cc6f38 100644 --- a/test/functional/lib/rspamd.robot +++ b/test/functional/lib/rspamd.robot @@ -209,6 +209,7 @@ Run Rspamd ... ELSE Make Temporary Directory Set Directory Ownership ${tmpdir} ${RSPAMD_USER} ${RSPAMD_GROUP} ${template} = Get File ${CONFIG} + # TODO: stop using this; we have Lupa now FOR ${i} IN @{vargs} ${newvalue} = Replace Variables ${${i}} Set To Dictionary ${d} ${i}=${newvalue} @@ -218,7 +219,8 @@ Run Rspamd Log ${config} Create File ${tmpdir}/rspamd.conf ${config} ${result} = Run Process ${RSPAMD} -u ${RSPAMD_USER} -g ${RSPAMD_GROUP} - ... -c ${tmpdir}/rspamd.conf env:TMPDIR=${tmpdir} env:DBDIR=${tmpdir} env:LD_LIBRARY_PATH=${TESTDIR}/../../contrib/aho-corasick stdout=DEVNULL stderr=DEVNULL + ... -c ${tmpdir}/rspamd.conf env:TMPDIR=${tmpdir} env:DBDIR=${tmpdir} env:LD_LIBRARY_PATH=${TESTDIR}/../../contrib/aho-corasick + ... 
env:RSPAMD_INSTALLROOT=${INSTALLROOT} stdout=DEVNULL stderr=DEVNULL Run Keyword If ${result.rc} != 0 Log ${result.stderr} Should Be Equal As Integers ${result.rc} 0 Wait Until Keyword Succeeds 10x 1 sec Check Pidfile ${tmpdir}/rspamd.pid timeout=0.5s diff --git a/test/functional/lua/neural.lua b/test/functional/lua/neural.lua index 70857d429..ccdad1b68 100644 --- a/test/functional/lua/neural.lua +++ b/test/functional/lua/neural.lua @@ -1,3 +1,5 @@ +local logger = require "rspamd_logger" + rspamd_config:register_symbol({ name = 'SPAM_SYMBOL', score = 5.0, @@ -21,4 +23,39 @@ rspamd_config:register_symbol({ callback = function() return true, 'Fires always' end -}) \ No newline at end of file +}) + +rspamd_config.SAVE_NN_ROW = { + callback = function(task) + local fname = os.tmpname() + task:cache_set('nn_row_tmpfile', fname) + return true, 1.0, fname + end +} + +rspamd_config.SAVE_NN_ROW_IDEMPOTENT = { + callback = function(task) + local function tohex(str) + return (str:gsub('.', function (c) + return string.format('%02X', string.byte(c)) + end)) + end + local fname = task:cache_get('nn_row_tmpfile') + if not fname then + return + end + local f, err = io.open(fname, 'w') + if not f then + logger.errx(task, err) + return + end + f:write(tohex(task:cache_get('neural_vec_mpack') or '')) + f:close() + return + end, + type = 'idempotent', + flags = 'explicit_disable', + priority = 10, +} + +dofile(rspamd_env.INSTALLROOT .. "/share/rspamd/rules/controller/init.lua") diff --git a/test/functional/util/nn_unpack.lua b/test/functional/util/nn_unpack.lua new file mode 100644 index 000000000..fee98d5a0 --- /dev/null +++ b/test/functional/util/nn_unpack.lua @@ -0,0 +1,16 @@ +local ucl = require "ucl" + +local function unhex(str) + return (str:gsub('..', function (cc) + return string.char(tonumber(cc, 16)) + end)) +end + +local parser = ucl.parser() +local ok, err = parser:parse_string(unhex(arg[1]), 'msgpack') +if not ok then + io.stderr:write(err) + os.exit(1) +end + +print(ucl.to_format(parser:get_object(), 'json-compact'))
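
The SAVE_NN_ROW / SAVE_NN_ROW_IDEMPOTENT pair above illustrates the per-task cache handoff that replaces the old mempool storage: an early symbol stashes a value with task:cache_set(), and an idempotent symbol, which runs after scoring, reads it back with task:cache_get(). A minimal sketch of the same pattern (symbol names and the cached value are illustrative, not taken from the tests):

  local logger = require "rspamd_logger"

  rspamd_config.STASH_VALUE = {
    callback = function(task)
      -- Store a per-task value; later stages of this task can read it back
      task:cache_set('example_key', 'example value')
      return true
    end,
  }

  rspamd_config.READ_VALUE = {
    type = 'idempotent', -- idempotent symbols run last, after scoring
    flags = 'explicit_disable',
    callback = function(task)
      local v = task:cache_get('example_key')
      if v then
        logger.infox(task, 'cached value: %s', v)
      end
    end,
  }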
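
nn_unpack.lua expects its single argument to be a hex-encoded messagepack blob, matching what SAVE_NN_ROW_IDEMPOTENT writes out, and prints it as compact JSON. A sketch of producing compatible input by hand, assuming the build's ucl binding accepts 'msgpack' as an output format for to_format:

  local ucl = require "ucl"

  local function tohex(str)
    return (str:gsub('.', function(c)
      return string.format('%02X', string.byte(c))
    end))
  end

  -- Encode a made-up training vector as messagepack, then hex it
  print(tohex(ucl.to_format({ 0.5, 0.25, 1.0 }, 'msgpack')))

The printed hex can then be decoded the same way the tests do it:

  rspamadm lua -a <hex> test/functional/util/nn_unpack.lua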
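
002_manualtrain.robot drives the new controller endpoint by POSTing a JSON document with spam_vec, ham_vec and rule fields to /plugins/neural/learn. A minimal sketch of building such a payload outside Robot Framework (the vector contents and the curl invocation are illustrative only):

  local ucl = require "ucl"

  -- One inner table per training sample; each row holds the symbol inputs
  local payload = ucl.to_format({
    rule = 'SHORT', -- must match a configured neural rule name
    spam_vec = { { 1.0, 0.0 }, { 1.0, 0.0 } },
    ham_vec = { { 0.0, 1.0 }, { 0.0, 1.0 } },
  }, 'json')

  print(payload)
  -- then, e.g.:
  --   curl --data "$payload" http://127.0.0.1:<controller port>/plugins/neural/learn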