summaryrefslogtreecommitdiffstats
path: root/lualib/plugins/neural.lua
diff options
context:
space:
mode:
authorAndrew Lewis <nerf@judo.za.org>2020-12-17 11:28:09 +0200
committerAndrew Lewis <nerf@judo.za.org>2020-12-17 11:28:09 +0200
commit960b608d352e8c820b0725d898d78959ca59ee7d (patch)
tree9d9f192e1c3161a804e94e1aed1c0a63b77929c0 /lualib/plugins/neural.lua
parent5ce6a2d97ff655651e4bba7737b834d866b94c94 (diff)
downloadrspamd-960b608d352e8c820b0725d898d78959ca59ee7d.tar.gz
rspamd-960b608d352e8c820b0725d898d78959ca59ee7d.zip
[Feature] Add controller endpoint for training neural
- Move neural functions to library - Parameterise spawn_train - neural plugin: Fix store_pool_only when autotrain is true - neural plugin: Use cache_set instead of mempool - Add test
Diffstat (limited to 'lualib/plugins/neural.lua')
-rw-r--r--lualib/plugins/neural.lua779
1 files changed, 779 insertions, 0 deletions
diff --git a/lualib/plugins/neural.lua b/lualib/plugins/neural.lua
new file mode 100644
index 000000000..4d4c44b5d
--- /dev/null
+++ b/lualib/plugins/neural.lua
@@ -0,0 +1,779 @@
+--[[
+Copyright (c) 2020, Vsevolod Stakhov <vsevolod@highsecure.ru>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local fun = require "fun"
+local lua_redis = require "lua_redis"
+local lua_settings = require "lua_settings"
+local lua_util = require "lua_util"
+local meta_functions = require "lua_meta"
+local rspamd_kann = require "rspamd_kann"
+local rspamd_logger = require "rspamd_logger"
+local rspamd_tensor = require "rspamd_tensor"
+local rspamd_util = require "rspamd_util"
+
+local N = 'neural'
+
+-- Used in prefix to avoid wrong ANN to be loaded
+local plugin_ver = '2'
+
+-- Module vars
+local default_options = {
+ train = {
+ max_trains = 1000,
+ max_epoch = 1000,
+ max_usages = 10,
+ max_iterations = 25, -- Torch style
+ mse = 0.001,
+ autotrain = true,
+ train_prob = 1.0,
+ learn_threads = 1,
+ learn_mode = 'balanced', -- Possible values: balanced, proportional
+ learning_rate = 0.01,
+ classes_bias = 0.0, -- balanced mode: what difference is allowed between classes (1:1 proportion means 0 bias)
+ spam_skip_prob = 0.0, -- proportional mode: spam skip probability (0-1)
+ ham_skip_prob = 0.0, -- proportional mode: ham skip probability
+ store_pool_only = false, -- store tokens in cache only (disables autotrain);
+ -- neural_vec_mpack stores vector of training data in messagepack neural_profile_digest stores profile digest
+ },
+ watch_interval = 60.0,
+ lock_expire = 600,
+ learning_spawned = false,
+ ann_expire = 60 * 60 * 24 * 2, -- 2 days
+ hidden_layer_mult = 1.5, -- number of neurons in the hidden layer
+ symbol_spam = 'NEURAL_SPAM',
+ symbol_ham = 'NEURAL_HAM',
+ max_inputs = nil, -- when PCA is used
+ blacklisted_symbols = {}, -- list of symbols skipped in neural processing
+}
+
+-- Rule structure:
+-- * static config fields (see `default_options`)
+-- * prefix - name or defined prefix
+-- * settings - table of settings indexed by settings id, -1 is used when no settings defined
+
+-- Rule settings element defines elements for specific settings id:
+-- * symbols - static symbols profile (defined by config or extracted from symcache)
+-- * name - name of settings id
+-- * digest - digest of all symbols
+-- * ann - dynamic ANN configuration loaded from Redis
+-- * train - train data for ANN (e.g. the currently trained ANN)
+
+-- Settings ANN table is loaded from Redis and represents dynamic profile for ANN
+-- Some elements are directly stored in Redis, ANN is, in turn loaded dynamically
+-- * version - version of ANN loaded from redis
+-- * redis_key - name of ANN key in Redis
+-- * symbols - symbols in THIS PARTICULAR ANN (might be different from set.symbols)
+-- * distance - distance between set.symbols and set.ann.symbols
+-- * ann - kann object
+
+local settings = {
+ rules = {},
+ prefix = 'rn', -- Neural network default prefix
+ max_profiles = 3, -- Maximum number of NN profiles stored
+}
+
+-- Get module & Redis configuration
+local module_config = rspamd_config:get_all_opt(N)
+settings = lua_util.override_defaults(settings, module_config)
+local redis_params = lua_redis.parse_redis_server('neural')
+
+-- Lua script that checks if we can store a new training vector
+-- Uses the following keys:
+-- key1 - ann key
+-- returns nspam,nham (or nil if locked)
+local redis_lua_script_vectors_len = [[
+ local prefix = KEYS[1]
+ local locked = redis.call('HGET', prefix, 'lock')
+ if locked then
+ local host = redis.call('HGET', prefix, 'hostname') or 'unknown'
+ return string.format('%s:%s', host, locked)
+ end
+ local nspam = 0
+ local nham = 0
+
+ local ret = redis.call('LLEN', prefix .. '_spam')
+ if ret then nspam = tonumber(ret) end
+ ret = redis.call('LLEN', prefix .. '_ham')
+ if ret then nham = tonumber(ret) end
+
+ return {nspam,nham}
+]]
+
+-- Lua script to invalidate ANNs by rank
+-- Uses the following keys
+-- key1 - prefix for keys
+-- key2 - number of elements to leave
+local redis_lua_script_maybe_invalidate = [[
+ local card = redis.call('ZCARD', KEYS[1])
+ local lim = tonumber(KEYS[2])
+ if card > lim then
+ local to_delete = redis.call('ZRANGE', KEYS[1], 0, card - lim - 1)
+ for _,k in ipairs(to_delete) do
+ local tb = cjson.decode(k)
+ redis.call('DEL', tb.redis_key)
+ -- Also train vectors
+ redis.call('DEL', tb.redis_key .. '_spam')
+ redis.call('DEL', tb.redis_key .. '_ham')
+ end
+ redis.call('ZREMRANGEBYRANK', KEYS[1], 0, card - lim - 1)
+ return to_delete
+ else
+ return {}
+ end
+]]
+
+-- Lua script to invalidate ANN from redis
+-- Uses the following keys
+-- key1 - prefix for keys
+-- key2 - current time
+-- key3 - key expire
+-- key4 - hostname
+local redis_lua_script_maybe_lock = [[
+ local locked = redis.call('HGET', KEYS[1], 'lock')
+ local now = tonumber(KEYS[2])
+ if locked then
+ locked = tonumber(locked)
+ local expire = tonumber(KEYS[3])
+ if now > locked and (now - locked) < expire then
+ return {tostring(locked), redis.call('HGET', KEYS[1], 'hostname') or 'unknown'}
+ end
+ end
+ redis.call('HSET', KEYS[1], 'lock', tostring(now))
+ redis.call('HSET', KEYS[1], 'hostname', KEYS[4])
+ return 1
+]]
+
+-- Lua script to save and unlock ANN in redis
+-- Uses the following keys
+-- key1 - prefix for ANN
+-- key2 - prefix for profile
+-- key3 - compressed ANN
+-- key4 - profile as JSON
+-- key5 - expire in seconds
+-- key6 - current time
+-- key7 - old key
+-- key8 - optional PCA
+local redis_lua_script_save_unlock = [[
+ local now = tonumber(KEYS[6])
+ redis.call('ZADD', KEYS[2], now, KEYS[4])
+ redis.call('HSET', KEYS[1], 'ann', KEYS[3])
+ redis.call('DEL', KEYS[1] .. '_spam')
+ redis.call('DEL', KEYS[1] .. '_ham')
+ redis.call('HDEL', KEYS[1], 'lock')
+ redis.call('HDEL', KEYS[7], 'lock')
+ redis.call('EXPIRE', KEYS[1], tonumber(KEYS[5]))
+ if KEYS[8] then
+ redis.call('HSET', KEYS[1], 'pca', KEYS[8])
+ end
+ return 1
+]]
+
+local redis_script_id = {}
+
+local function load_scripts()
+ redis_script_id.vectors_len = lua_redis.add_redis_script(redis_lua_script_vectors_len,
+ redis_params)
+ redis_script_id.maybe_invalidate = lua_redis.add_redis_script(redis_lua_script_maybe_invalidate,
+ redis_params)
+ redis_script_id.maybe_lock = lua_redis.add_redis_script(redis_lua_script_maybe_lock,
+ redis_params)
+ redis_script_id.save_unlock = lua_redis.add_redis_script(redis_lua_script_save_unlock,
+ redis_params)
+end
+
+local function create_ann(n, nlayers, rule)
+ -- We ignore number of layers so far when using kann
+ local nhidden = math.floor(n * (rule.hidden_layer_mult or 1.0) + 1.0)
+ local t = rspamd_kann.layer.input(n)
+ t = rspamd_kann.transform.relu(t)
+ t = rspamd_kann.layer.dense(t, nhidden);
+ t = rspamd_kann.layer.cost(t, 1, rspamd_kann.cost.ceb_neg)
+ return rspamd_kann.new.kann(t)
+end
+
+-- Fills ANN data for a specific settings element
+local function fill_set_ann(set, ann_key)
+ if not set.ann then
+ set.ann = {
+ symbols = set.symbols,
+ distance = 0,
+ digest = set.digest,
+ redis_key = ann_key,
+ version = 0,
+ }
+ end
+end
+
+-- This function takes all inputs, applies PCA transformation and returns the final
+-- PCA matrix as rspamd_tensor
+local function learn_pca(inputs, max_inputs)
+ local scatter_matrix = rspamd_tensor.scatter_matrix(rspamd_tensor.fromtable(inputs))
+ local eigenvals = scatter_matrix:eigen()
+ -- scatter matrix is not filled with eigenvectors
+ lua_util.debugm(N, 'eigenvalues: %s', eigenvals)
+ local w = rspamd_tensor.new(2, max_inputs, #scatter_matrix[1])
+ for i=1,max_inputs do
+ w[i] = scatter_matrix[#scatter_matrix - i + 1]
+ end
+
+ lua_util.debugm(N, 'pca matrix: %s', w)
+
+ return w
+end
+
+-- This function is intended to extend lock for ANN during training
+-- It registers periodic that increases locked key each 30 seconds unless
+-- `set.learning_spawned` is set to `true`
+local function register_lock_extender(rule, set, ev_base, ann_key)
+ rspamd_config:add_periodic(ev_base, 30.0,
+ function()
+ local function redis_lock_extend_cb(_err, _)
+ if _err then
+ rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s',
+ ann_key, _err)
+ else
+ rspamd_logger.infox(rspamd_config, 'extend lock for ANN %s for 30 seconds',
+ ann_key)
+ end
+ end
+
+ if set.learning_spawned then
+ lua_redis.redis_make_request_taskless(ev_base,
+ rspamd_config,
+ rule.redis,
+ nil,
+ true, -- is write
+ redis_lock_extend_cb, --callback
+ 'HINCRBY', -- command
+ {ann_key, 'lock', '30'}
+ )
+ else
+ lua_util.debugm(N, rspamd_config, "stop lock extension as learning_spawned is false")
+ return false -- do not plan any more updates
+ end
+
+ return true
+ end
+ )
+end
+
+local function can_push_train_vector(rule, task, learn_type, nspam, nham)
+ local train_opts = rule.train
+ local coin = math.random()
+
+ if train_opts.train_prob and coin < 1.0 - train_opts.train_prob then
+ rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin)
+ return false
+ end
+
+ if train_opts.learn_mode == 'balanced' then
+ -- Keep balanced training set based on number of spam and ham samples
+ if learn_type == 'spam' then
+ if nspam <= train_opts.max_trains then
+ if nspam > nham then
+ -- Apply sampling
+ local skip_rate = 1.0 - nham / (nspam + 1)
+ if coin < skip_rate - train_opts.classes_bias then
+ rspamd_logger.infox(task,
+ 'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored',
+ learn_type,
+ skip_rate - train_opts.classes_bias,
+ nspam, nham)
+ return false
+ end
+ end
+ return true
+ else -- Enough learns
+ rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many spam samples: %s',
+ learn_type,
+ nspam)
+ end
+ else
+ if nham <= train_opts.max_trains then
+ if nham > nspam then
+ -- Apply sampling
+ local skip_rate = 1.0 - nspam / (nham + 1)
+ if coin < skip_rate - train_opts.classes_bias then
+ rspamd_logger.infox(task,
+ 'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored',
+ learn_type,
+ skip_rate - train_opts.classes_bias,
+ nspam, nham)
+ return false
+ end
+ end
+ return true
+ else
+ rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many ham samples: %s', learn_type,
+ nham)
+ end
+ end
+ else
+ -- Probabilistic learn mode, we just skip learn if we already have enough samples or
+ -- if our coin drop is less than desired probability
+ if learn_type == 'spam' then
+ if nspam <= train_opts.max_trains then
+ if train_opts.spam_skip_prob then
+ if coin <= train_opts.spam_skip_prob then
+ rspamd_logger.infox(task, 'skip %s sample probabilisticaly; probability %s (%s skip chance)', learn_type,
+ coin, train_opts.spam_skip_prob)
+ return false
+ end
+
+ return true
+ end
+ else
+ rspamd_logger.infox(task, 'skip %s sample; too many spam samples: %s (%s limit)', learn_type,
+ nspam, train_opts.max_trains)
+ end
+ else
+ if nham <= train_opts.max_trains then
+ if train_opts.ham_skip_prob then
+ if coin <= train_opts.ham_skip_prob then
+ rspamd_logger.infox(task, 'skip %s sample probabilisticaly; probability %s (%s skip chance)', learn_type,
+ coin, train_opts.ham_skip_prob)
+ return false
+ end
+
+ return true
+ end
+ else
+ rspamd_logger.infox(task, 'skip %s sample; too many ham samples: %s (%s limit)', learn_type,
+ nham, train_opts.max_trains)
+ end
+ end
+ end
+
+ return false
+end
+
+-- Closure generator for unlock function
+local function gen_unlock_cb(rule, set, ann_key)
+ return function (err)
+ if err then
+ rspamd_logger.errx(rspamd_config, 'cannot unlock ANN %s:%s at %s from redis: %s',
+ rule.prefix, set.name, ann_key, err)
+ else
+ lua_util.debugm(N, rspamd_config, 'unlocked ANN %s:%s at %s',
+ rule.prefix, set.name, ann_key)
+ end
+ end
+end
+
+-- Used to generate new ANN key for specific profile
+local function new_ann_key(rule, set, version)
+ local ann_key = string.format('%s_%s_%s_%s_%s', settings.prefix,
+ rule.prefix, set.name, set.digest:sub(1, 8), tostring(version))
+
+ return ann_key
+end
+
+local function redis_ann_prefix(rule, settings_name)
+ -- We also need to count metatokens:
+ local n = meta_functions.version
+ return string.format('%s%d_%s_%d_%s',
+ settings.prefix, plugin_ver, rule.prefix, n, settings_name)
+end
+
+-- This function receives training vectors, checks them, spawn learning and saves ANN in Redis
+local function spawn_train(params)
+ -- Check training data sanity
+ -- Now we need to join inputs and create the appropriate test vectors
+ local n = #params.set.symbols +
+ meta_functions.rspamd_count_metatokens()
+
+ -- Now we can train ann
+ local train_ann = create_ann(params.rule.max_inputs or n, 3, params.rule)
+
+ if #params.ham_vec + #params.spam_vec < params.rule.train.max_trains / 2 then
+ -- Invalidate ANN as it is definitely invalid
+ -- TODO: add invalidation
+ assert(false)
+ else
+ local inputs, outputs = {}, {}
+
+ -- Used to show sparsed vectors in a convenient format (for debugging only)
+ local function debug_vec(t)
+ local ret = {}
+ for i,v in ipairs(t) do
+ if v ~= 0 then
+ ret[#ret + 1] = string.format('%d=%.2f', i, v)
+ end
+ end
+
+ return ret
+ end
+
+ -- Make training set by joining vectors
+ -- KANN automatically shuffles those samples
+ -- 1.0 is used for spam and -1.0 is used for ham
+ -- It implies that output layer can express that (e.g. tanh output)
+ for _,e in ipairs(params.spam_vec) do
+ inputs[#inputs + 1] = e
+ outputs[#outputs + 1] = {1.0}
+ --rspamd_logger.debugm(N, rspamd_config, 'spam vector: %s', debug_vec(e))
+ end
+ for _,e in ipairs(params.ham_vec) do
+ inputs[#inputs + 1] = e
+ outputs[#outputs + 1] = {-1.0}
+ --rspamd_logger.debugm(N, rspamd_config, 'ham vector: %s', debug_vec(e))
+ end
+
+ -- Called in child process
+ local function train()
+ local log_thresh = params.rule.train.max_iterations / 10
+ local seen_nan = false
+
+ local function train_cb(iter, train_cost, value_cost)
+ if (iter * (params.rule.train.max_iterations / log_thresh)) % (params.rule.train.max_iterations) == 0 then
+ if train_cost ~= train_cost and not seen_nan then
+ -- We have nan :( try to log lot's of stuff to dig into a problem
+ seen_nan = true
+ rspamd_logger.errx(rspamd_config, 'ANN %s:%s: train error: observed nan in error cost!; value cost = %s',
+ params.rule.prefix, params.set.name,
+ value_cost)
+ for i,e in ipairs(inputs) do
+ lua_util.debugm(N, rspamd_config, 'train vector %s -> %s',
+ debug_vec(e), outputs[i][1])
+ end
+ end
+
+ rspamd_logger.infox(rspamd_config,
+ "ANN %s:%s: learned from %s redis key in %s iterations, error: %s, value cost: %s",
+ params.rule.prefix, params.set.name,
+ params.ann_key,
+ iter,
+ train_cost,
+ value_cost)
+ end
+ end
+
+ lua_util.debugm(N, rspamd_config, "subprocess to learn ANN %s:%s has been started",
+ params.rule.prefix, params.set.name)
+
+ local ret,err = pcall(train_ann.train1, train_ann,
+ inputs, outputs, {
+ lr = params.rule.train.learning_rate,
+ max_epoch = params.rule.train.max_iterations,
+ cb = train_cb,
+ pca = (params.set.ann or {}).pca
+ })
+
+ if not ret then
+ rspamd_logger.errx(rspamd_config, "cannot train ann %s:%s: %s",
+ params.rule.prefix, params.set.name, err)
+
+ return nil
+ end
+
+ if not seen_nan then
+ local out = train_ann:save()
+ return out
+ else
+ return nil
+ end
+ end
+
+ params.set.learning_spawned = true
+
+ local function redis_save_cb(err)
+ if err then
+ rspamd_logger.errx(rspamd_config, 'cannot save ANN %s:%s to redis key %s: %s',
+ params.rule.prefix, params.set.name, params.ann_key, err)
+ lua_redis.redis_make_request_taskless(params.ev_base,
+ rspamd_config,
+ params.rule.redis,
+ nil,
+ false, -- is write
+ gen_unlock_cb(params.rule, params.set, params.ann_key), --callback
+ 'HDEL', -- command
+ {params.ann_key, 'lock'}
+ )
+ else
+ rspamd_logger.infox(rspamd_config, 'saved ANN %s:%s to redis: %s',
+ params.rule.prefix, params.set.name, params.set.ann.redis_key)
+ end
+ end
+
+ local function ann_trained(err, data)
+ params.set.learning_spawned = false
+ if err then
+ rspamd_logger.errx(rspamd_config, 'cannot train ANN %s:%s : %s',
+ params.rule.prefix, params.set.name, err)
+ lua_redis.redis_make_request_taskless(params.ev_base,
+ rspamd_config,
+ params.rule.redis,
+ nil,
+ true, -- is write
+ gen_unlock_cb(params.rule, params.set, params.ann_key), --callback
+ 'HDEL', -- command
+ {params.ann_key, 'lock'}
+ )
+ else
+ local ann_data = rspamd_util.zstd_compress(data)
+ local pca_data
+
+ fill_set_ann(params.set, params.ann_key)
+ if params.set.ann.pca then
+ pca_data = rspamd_util.zstd_compress(params.set.ann.pca:save())
+ end
+ -- Deserialise ANN from the child process
+ ann_trained = rspamd_kann.load(data)
+ local version = (params.set.ann.version or 0) + 1
+ params.set.ann.version = version
+ params.set.ann.ann = ann_trained
+ params.set.ann.symbols = params.set.symbols
+ params.set.ann.redis_key = new_ann_key(params.rule, params.set, version)
+
+ local profile = {
+ symbols = params.set.symbols,
+ digest = params.set.digest,
+ redis_key = params.set.ann.redis_key,
+ version = version
+ }
+
+ local ucl = require "ucl"
+ local profile_serialized = ucl.to_format(profile, 'json-compact', true)
+
+ rspamd_logger.infox(rspamd_config,
+ 'trained ANN %s:%s, %s bytes (%s compressed); %s rows in pca (%sb compressed); redis key: %s (old key %s)',
+ params.rule.prefix, params.set.name,
+ #data, #ann_data,
+ #(params.set.ann.pca or {}), #(pca_data or {}),
+ params.set.ann.redis_key, params.ann_key)
+
+ lua_redis.exec_redis_script(redis_script_id.save_unlock,
+ {ev_base = params.ev_base, is_write = true},
+ redis_save_cb,
+ {profile.redis_key,
+ redis_ann_prefix(params.rule, params.set.name),
+ ann_data,
+ profile_serialized,
+ tostring(params.rule.ann_expire),
+ tostring(os.time()),
+ params.ann_key, -- old key to unlock...
+ pca_data
+ })
+ end
+ end
+
+ if params.rule.max_inputs then
+ fill_set_ann(params.set, params.ann_key)
+ -- Train PCA in the main process, presumably it is not that long
+ params.set.ann.pca = learn_pca(inputs, params.rule.max_inputs)
+ end
+
+ params.worker:spawn_process{
+ func = train,
+ on_complete = ann_trained,
+ proctitle = string.format("ANN train for %s/%s", params.rule.prefix, params.set.name),
+ }
+ -- Spawn learn and register lock extension
+ params.set.learning_spawned = true
+ register_lock_extender(params.rule, params.set, params.ev_base, params.ann_key)
+ return
+
+ end
+end
+
+-- This function is used to adjust profiles and allowed setting ids for each rule
+-- It must be called when all settings are already registered (e.g. at post-init for config)
+local function process_rules_settings()
+ local function process_settings_elt(rule, selt)
+ local profile = rule.profile[selt.name]
+ if profile then
+ -- Use static user defined profile
+ -- Ensure that we have an array...
+ lua_util.debugm(N, rspamd_config, "use static profile for %s (%s): %s",
+ rule.prefix, selt.name, profile)
+ if not profile[1] then profile = lua_util.keys(profile) end
+ selt.symbols = profile
+ else
+ lua_util.debugm(N, rspamd_config, "use dynamic cfg based profile for %s (%s)",
+ rule.prefix, selt.name)
+ end
+
+ local function filter_symbols_predicate(sname)
+ if settings.blacklisted_symbols and settings.blacklisted_symbols[sname] then
+ return false
+ end
+ local fl = rspamd_config:get_symbol_flags(sname)
+ if fl then
+ fl = lua_util.list_to_hash(fl)
+
+ return not (fl.nostat or fl.idempotent or fl.skip or fl.composite)
+ end
+
+ return false
+ end
+
+ -- Generic stuff
+ table.sort(fun.totable(fun.filter(filter_symbols_predicate, selt.symbols)))
+
+ selt.digest = lua_util.table_digest(selt.symbols)
+ selt.prefix = redis_ann_prefix(rule, selt.name)
+
+ lua_redis.register_prefix(selt.prefix, N,
+ string.format('NN prefix for rule "%s"; settings id "%s"',
+ rule.prefix, selt.name), {
+ persistent = true,
+ type = 'zlist',
+ })
+ -- Versions
+ lua_redis.register_prefix(selt.prefix .. '_\\d+', N,
+ string.format('NN storage for rule "%s"; settings id "%s"',
+ rule.prefix, selt.name), {
+ persistent = true,
+ type = 'hash',
+ })
+ lua_redis.register_prefix(selt.prefix .. '_\\d+_spam', N,
+ string.format('NN learning set (spam) for rule "%s"; settings id "%s"',
+ rule.prefix, selt.name), {
+ persistent = true,
+ type = 'list',
+ })
+ lua_redis.register_prefix(selt.prefix .. '_\\d+_ham', N,
+ string.format('NN learning set (spam) for rule "%s"; settings id "%s"',
+ rule.prefix, selt.name), {
+ persistent = true,
+ type = 'list',
+ })
+ end
+
+ for k,rule in pairs(settings.rules) do
+ if not rule.allowed_settings then
+ rule.allowed_settings = {}
+ elseif rule.allowed_settings == 'all' then
+ -- Extract all settings ids
+ rule.allowed_settings = lua_util.keys(lua_settings.all_settings())
+ end
+
+ -- Convert to a map <setting_id> -> true
+ rule.allowed_settings = lua_util.list_to_hash(rule.allowed_settings)
+
+ -- Check if we can work without settings
+ if k == 'default' or type(rule.default) ~= 'boolean' then
+ rule.default = true
+ end
+
+ rule.settings = {}
+
+ if rule.default then
+ local default_settings = {
+ symbols = lua_settings.default_symbols(),
+ name = 'default'
+ }
+
+ process_settings_elt(rule, default_settings)
+ rule.settings[-1] = default_settings -- Magic constant, but OK as settings are positive int32
+ end
+
+ -- Now, for each allowed settings, we store sorted symbols + digest
+ -- We set table rule.settings[id] -> { name = name, symbols = symbols, digest = digest }
+ for s,_ in pairs(rule.allowed_settings) do
+ -- Here, we have a name, set of symbols and
+ local settings_id = s
+ if type(settings_id) ~= 'number' then
+ settings_id = lua_settings.numeric_settings_id(s)
+ end
+ local selt = lua_settings.settings_by_id(settings_id)
+
+ local nelt = {
+ symbols = selt.symbols, -- Already sorted
+ name = selt.name
+ }
+
+ process_settings_elt(rule, nelt)
+ for id,ex in pairs(rule.settings) do
+ if type(ex) == 'table' then
+ if nelt and lua_util.distance_sorted(ex.symbols, nelt.symbols) == 0 then
+ -- Equal symbols, add reference
+ lua_util.debugm(N, rspamd_config,
+ 'added reference from settings id %s to %s; same symbols',
+ nelt.name, ex.name)
+ rule.settings[settings_id] = id
+ nelt = nil
+ end
+ end
+ end
+
+ if nelt then
+ rule.settings[settings_id] = nelt
+ lua_util.debugm(N, rspamd_config, 'added new settings id %s(%s) to %s',
+ nelt.name, settings_id, rule.prefix)
+ end
+ end
+ end
+end
+
+-- Extract settings element for a specific settings id
+local function get_rule_settings(task, rule)
+ local sid = task:get_settings_id() or -1
+ local set = rule.settings[sid]
+
+ if not set then return nil end
+
+ while type(set) == 'number' do
+ -- Reference to another settings!
+ set = rule.settings[set]
+ end
+
+ return set
+end
+
+local function result_to_vector(task, profile)
+ if not profile.zeros then
+ -- Fill zeros vector
+ local zeros = {}
+ for i=1,meta_functions.count_metatokens() do
+ zeros[i] = 0.0
+ end
+ for _,_ in ipairs(profile.symbols) do
+ zeros[#zeros + 1] = 0.0
+ end
+ profile.zeros = zeros
+ end
+
+ local vec = lua_util.shallowcopy(profile.zeros)
+ local mt = meta_functions.rspamd_gen_metatokens(task)
+
+ for i,v in ipairs(mt) do
+ vec[i] = v
+ end
+
+ task:process_ann_tokens(profile.symbols, vec, #mt, 0.1)
+
+ return vec
+end
+
+return {
+ can_push_train_vector = can_push_train_vector,
+ create_ann = create_ann,
+ default_options = default_options,
+ gen_unlock_cb = gen_unlock_cb,
+ get_rule_settings = get_rule_settings,
+ load_scripts = load_scripts,
+ module_config = module_config,
+ new_ann_key = new_ann_key,
+ plugin_ver = plugin_ver,
+ process_rules_settings = process_rules_settings,
+ redis_ann_prefix = redis_ann_prefix,
+ redis_params = redis_params,
+ redis_script_id = redis_script_id,
+ result_to_vector = result_to_vector,
+ settings = settings,
+ spawn_train = spawn_train,
+}