author    Vsevolod Stakhov <vsevolod@highsecure.ru>  2020-12-17 13:15:02 +0000
committer GitHub <noreply@github.com>                2020-12-17 13:15:02 +0000
commit    c3bbc67337285414516173f778f8e5ab0841b1f6 (patch)
tree      bcc38704958dc6df79bc6207c93e9f12ea011ba9 /src
parent    1710451544a6e4e37d7865c088782f99d8082360 (diff)
parent    960b608d352e8c820b0725d898d78959ca59ee7d (diff)
Merge pull request #3570 from fatalbanana/nn_training
[Feature] Add controller endpoint for training neural
Diffstat (limited to 'src')
-rw-r--r--  src/plugins/lua/neural.lua | 788
1 file changed, 39 insertions(+), 749 deletions(-)
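Most of the 749 deleted lines are not removed functionality: the shared pieces (key naming, settings resolution, vectorisation, training-vector gating, train spawning and the Redis scripts) move into a common lualib module that this plugin and, presumably, the new controller endpoint can require. A minimal sketch of the consumer side, using only names that appear in this diff (vector_for itself is a made-up helper name):

local neural_common = require "plugins/neural"
local settings = neural_common.settings

-- Resolve the settings element for the current task and build its
-- feature vector through the shared module
local function vector_for(task, rule)
  local set = neural_common.get_rule_settings(task, rule)
  if set then
    return neural_common.result_to_vector(task, set)
  end
end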
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index 5eab75d76..3d1c387a5 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -19,22 +19,21 @@ if confighelp then
return
end
-local rspamd_logger = require "rspamd_logger"
-local rspamd_util = require "rspamd_util"
-local rspamd_kann = require "rspamd_kann"
-local rspamd_text = require "rspamd_text"
+local fun = require "fun"
local lua_redis = require "lua_redis"
local lua_util = require "lua_util"
+local lua_verdict = require "lua_verdict"
+local neural_common = require "plugins/neural"
+local rspamd_kann = require "rspamd_kann"
+local rspamd_logger = require "rspamd_logger"
local rspamd_tensor = require "rspamd_tensor"
-local fun = require "fun"
-local lua_settings = require "lua_settings"
-local meta_functions = require "lua_meta"
+local rspamd_text = require "rspamd_text"
+local rspamd_util = require "rspamd_util"
local ts = require("tableshape").types
-local lua_verdict = require "lua_verdict"
+
local N = "neural"
--- Used in prefix to avoid loading the wrong ANN
-local plugin_ver = '2'
+local settings = neural_common.settings
-- Module vars
local default_options = {
@@ -52,7 +51,7 @@ local default_options = {
classes_bias = 0.0, -- balanced mode: what difference is allowed between classes (1:1 proportion means 0 bias)
spam_skip_prob = 0.0, -- proportional mode: spam skip probability (0-1)
ham_skip_prob = 0.0, -- proportional mode: ham skip probability
- store_pool_only = false, -- store tokens in mempool variable only (disables autotrain);
+ store_pool_only = false, -- store tokens in cache only (disables autotrain);
-- neural_vec_mpack stores the training vector in messagepack; neural_profile_digest stores the profile digest
},
watch_interval = 60.0,
@@ -77,207 +76,9 @@ local redis_profile_schema = ts.shape{
local has_blas = rspamd_tensor.has_blas()
local text_cookie = rspamd_text.cookie
--- Rule structure:
--- * static config fields (see `default_options`)
--- * prefix - name or defined prefix
--- * settings - table of settings indexed by settings id, -1 is used when no settings defined
-
--- Rule settings element defines elements for specific settings id:
--- * symbols - static symbols profile (defined by config or extracted from symcache)
--- * name - name of settings id
--- * digest - digest of all symbols
--- * ann - dynamic ANN configuration loaded from Redis
--- * train - train data for ANN (e.g. the currently trained ANN)
-
--- Settings ANN table is loaded from Redis and represents dynamic profile for ANN
--- Some elements are stored directly in Redis; the ANN itself is, in turn, loaded dynamically
--- * version - version of ANN loaded from redis
--- * redis_key - name of ANN key in Redis
--- * symbols - symbols in THIS PARTICULAR ANN (might be different from set.symbols)
--- * distance - distance between set.symbols and set.ann.symbols
--- * ann - kann object
-
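The removed comments above amount to a nested structure that can be sketched as a Lua table; the field names follow the comments, every value below is illustrative only:

-- Illustrative values only; the shape follows the comments above
local example_rule = {
  prefix = 'default',                 -- name or defined prefix
  settings = {
    [-1] = {                          -- -1 when no settings are defined
      name = 'default',               -- name of settings id
      symbols = {'SYM_A', 'SYM_B'},   -- static symbols profile
      digest = 'a1b2c3d4',            -- digest of all symbols
      train = {},                     -- train data for the ANN being trained
      ann = {                         -- dynamic profile loaded from Redis
        version = 2,
        redis_key = 'rn_default_default_a1b2c3d4_2',
        symbols = {'SYM_A', 'SYM_B'}, -- symbols in THIS PARTICULAR ANN
        distance = 0,                 -- distance between set.symbols and ann.symbols
        ann = nil,                    -- the kann object, once loaded
      },
    },
  },
}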
-local settings = {
- rules = {},
- prefix = 'rn', -- Neural network default prefix
- max_profiles = 3, -- Maximum number of NN profiles stored
-}
-
-local module_config = rspamd_config:get_all_opt("neural")
-if not module_config then
- -- Legacy
- module_config = rspamd_config:get_all_opt("fann_redis")
-end
-
-
--- Lua script that checks if we can store a new training vector
--- Uses the following keys:
--- key1 - ann key
--- returns nspam,nham (or a 'hostname:lock_time' string if locked)
-local redis_lua_script_vectors_len = [[
- local prefix = KEYS[1]
- local locked = redis.call('HGET', prefix, 'lock')
- if locked then
- local host = redis.call('HGET', prefix, 'hostname') or 'unknown'
- return string.format('%s:%s', host, locked)
- end
- local nspam = 0
- local nham = 0
-
- local ret = redis.call('LLEN', prefix .. '_spam')
- if ret then nspam = tonumber(ret) end
- ret = redis.call('LLEN', prefix .. '_ham')
- if ret then nham = tonumber(ret) end
-
- return {nspam,nham}
-]]
-local redis_lua_script_vectors_len_id = nil
-
--- Lua script to invalidate ANNs by rank
--- Uses the following keys
--- key1 - prefix for keys
--- key2 - number of elements to leave
-local redis_lua_script_maybe_invalidate = [[
- local card = redis.call('ZCARD', KEYS[1])
- local lim = tonumber(KEYS[2])
- if card > lim then
- local to_delete = redis.call('ZRANGE', KEYS[1], 0, card - lim - 1)
- for _,k in ipairs(to_delete) do
- local tb = cjson.decode(k)
- redis.call('DEL', tb.redis_key)
- -- Also delete training vectors
- redis.call('DEL', tb.redis_key .. '_spam')
- redis.call('DEL', tb.redis_key .. '_ham')
- end
- redis.call('ZREMRANGEBYRANK', KEYS[1], 0, card - lim - 1)
- return to_delete
- else
- return {}
- end
-]]
-local redis_maybe_invalidate_id = nil
-
--- Lua script to lock ANN in redis for training
--- Uses the following keys
--- key1 - prefix for keys
--- key2 - current time
--- key3 - key expire
--- key4 - hostname
-local redis_lua_script_maybe_lock = [[
- local locked = redis.call('HGET', KEYS[1], 'lock')
- local now = tonumber(KEYS[2])
- if locked then
- locked = tonumber(locked)
- local expire = tonumber(KEYS[3])
- if now > locked and (now - locked) < expire then
- return {tostring(locked), redis.call('HGET', KEYS[1], 'hostname') or 'unknown'}
- end
- end
- redis.call('HSET', KEYS[1], 'lock', tostring(now))
- redis.call('HSET', KEYS[1], 'hostname', KEYS[4])
- return 1
-]]
-local redis_maybe_lock_id = nil
-
--- Lua script to save and unlock ANN in redis
--- Uses the following keys
--- key1 - prefix for ANN
--- key2 - prefix for profile
--- key3 - compressed ANN
--- key4 - profile as JSON
--- key5 - expire in seconds
--- key6 - current time
--- key7 - old key
--- key8 - optional PCA
-local redis_lua_script_save_unlock = [[
- local now = tonumber(KEYS[6])
- redis.call('ZADD', KEYS[2], now, KEYS[4])
- redis.call('HSET', KEYS[1], 'ann', KEYS[3])
- redis.call('DEL', KEYS[1] .. '_spam')
- redis.call('DEL', KEYS[1] .. '_ham')
- redis.call('HDEL', KEYS[1], 'lock')
- redis.call('HDEL', KEYS[7], 'lock')
- redis.call('EXPIRE', KEYS[1], tonumber(KEYS[5]))
- if KEYS[8] then
- redis.call('HSET', KEYS[1], 'pca', KEYS[8])
- end
- return 1
-]]
-local redis_save_unlock_id = nil
-
-local redis_params
-
-local function load_scripts(params)
- redis_lua_script_vectors_len_id = lua_redis.add_redis_script(redis_lua_script_vectors_len,
- params)
- redis_maybe_invalidate_id = lua_redis.add_redis_script(redis_lua_script_maybe_invalidate,
- params)
- redis_maybe_lock_id = lua_redis.add_redis_script(redis_lua_script_maybe_lock,
- params)
- redis_save_unlock_id = lua_redis.add_redis_script(redis_lua_script_save_unlock,
- params)
-end
-
-local function result_to_vector(task, profile)
- if not profile.zeros then
- -- Fill zeros vector
- local zeros = {}
- for i=1,meta_functions.count_metatokens() do
- zeros[i] = 0.0
- end
- for _,_ in ipairs(profile.symbols) do
- zeros[#zeros + 1] = 0.0
- end
- profile.zeros = zeros
- end
-
- local vec = lua_util.shallowcopy(profile.zeros)
- local mt = meta_functions.rspamd_gen_metatokens(task)
-
- for i,v in ipairs(mt) do
- vec[i] = v
- end
-
- task:process_ann_tokens(profile.symbols, vec, #mt, 0.1)
-
- return vec
-end
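The vector built here has a fixed layout: metatokens first, then one slot per profile symbol, all pre-initialised to zero. As a worked example with assumed sizes, 20 metatokens and 150 profile symbols give a 170-entry vector; the loop over mt fills indices 1..20 and task:process_ann_tokens writes the symbol scores into the tail, using #mt as the offset.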
-
--- Used to generate new ANN key for specific profile
-local function new_ann_key(rule, set, version)
- local ann_key = string.format('%s_%s_%s_%s_%s', settings.prefix,
- rule.prefix, set.name, set.digest:sub(1, 8), tostring(version))
-
- return ann_key
-end
-
--- Extract settings element for a specific settings id
-local function get_rule_settings(task, rule)
- local sid = task:get_settings_id() or -1
-
- local set = rule.settings[sid]
-
- if not set then return nil end
-
- while type(set) == 'number' do
- -- Reference to another settings!
- set = rule.settings[set]
- end
-
- return set
-end
-
--- Generate redis prefix for specific rule and specific settings
-local function redis_ann_prefix(rule, settings_name)
- -- We also need to count metatokens:
- local n = meta_functions.version
- return string.format('%s%d_%s_%d_%s',
- settings.prefix, plugin_ver, rule.prefix, n, settings_name)
-end
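Taken together with new_ann_key above, the two helpers produce keys along these lines (the metatokens version and the digest are assumed example values):

-- redis_ann_prefix(rule, 'default') --> 'rn2_default_3_default'
--   (settings.prefix .. plugin_ver .. '_' .. rule.prefix .. '_' .. n .. '_' .. settings_name)
-- new_ann_key(rule, set, 7)         --> 'rn_default_default_a1b2c3d4_7'
--   (settings.prefix, rule.prefix, set.name, set.digest:sub(1, 8), version)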
-
-- Creates and stores ANN profile in Redis
local function new_ann_profile(task, rule, set, version)
- local ann_key = new_ann_key(rule, set, version)
+ local ann_key = neural_common.new_ann_key(rule, set, version, settings)
local profile = {
symbols = set.symbols,
@@ -321,7 +122,7 @@ local function ann_scores_filter(task)
local ann
local profile
- local set = get_rule_settings(task, rule)
+ local set = neural_common.get_rule_settings(task, rule)
if set then
if set.ann then
ann = set.ann.ann
@@ -336,7 +137,7 @@ local function ann_scores_filter(task)
end
if ann then
- local vec = result_to_vector(task, profile)
+ local vec = neural_common.result_to_vector(task, profile)
local score
local out = ann:apply1(vec, set.ann.pca)
@@ -357,112 +158,12 @@ local function ann_scores_filter(task)
end
end
-local function create_ann(n, nlayers, rule)
- -- We ignore number of layers so far when using kann
- local nhidden = math.floor(n * (rule.hidden_layer_mult or 1.0) + 1.0)
- local t = rspamd_kann.layer.input(n)
- t = rspamd_kann.transform.relu(t)
- t = rspamd_kann.layer.dense(t, nhidden);
- t = rspamd_kann.layer.cost(t, 1, rspamd_kann.cost.ceb_neg)
- return rspamd_kann.new.kann(t)
-end
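Despite the nlayers argument, the topology is fixed at one hidden layer: input(n), relu, dense(nhidden), then a single-output cost layer with ceb_neg. A worked example: n = 170 inputs with the default hidden_layer_mult of 1.0 gives nhidden = floor(170 * 1.0 + 1.0) = 171.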
-
-local function can_push_train_vector(rule, task, learn_type, nspam, nham)
- local train_opts = rule.train
- local coin = math.random()
-
- if train_opts.train_prob and coin < 1.0 - train_opts.train_prob then
- rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin)
- return false
- end
-
- if train_opts.learn_mode == 'balanced' then
- -- Keep balanced training set based on number of spam and ham samples
- if learn_type == 'spam' then
- if nspam <= train_opts.max_trains then
- if nspam > nham then
- -- Apply sampling
- local skip_rate = 1.0 - nham / (nspam + 1)
- if coin < skip_rate - train_opts.classes_bias then
- rspamd_logger.infox(task,
- 'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored',
- learn_type,
- skip_rate - train_opts.classes_bias,
- nspam, nham)
- return false
- end
- end
- return true
- else -- Enough learns
- rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many spam samples: %s',
- learn_type,
- nspam)
- end
- else
- if nham <= train_opts.max_trains then
- if nham > nspam then
- -- Apply sampling
- local skip_rate = 1.0 - nspam / (nham + 1)
- if coin < skip_rate - train_opts.classes_bias then
- rspamd_logger.infox(task,
- 'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored',
- learn_type,
- skip_rate - train_opts.classes_bias,
- nspam, nham)
- return false
- end
- end
- return true
- else
- rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many ham samples: %s', learn_type,
- nham)
- end
- end
- else
- -- Probabilistic learn mode: we just skip the learn if we already have enough samples or
- -- if our coin drop is less than desired probability
- if learn_type == 'spam' then
- if nspam <= train_opts.max_trains then
- if train_opts.spam_skip_prob then
- if coin <= train_opts.spam_skip_prob then
- rspamd_logger.infox(task, 'skip %s sample probabilistically; probability %s (%s skip chance)', learn_type,
- coin, train_opts.spam_skip_prob)
- return false
- end
-
- return true
- end
- else
- rspamd_logger.infox(task, 'skip %s sample; too many spam samples: %s (%s limit)', learn_type,
- nspam, train_opts.max_trains)
- end
- else
- if nham <= train_opts.max_trains then
- if train_opts.ham_skip_prob then
- if coin <= train_opts.ham_skip_prob then
- rspamd_logger.infox(task, 'skip %s sample probabilistically; probability %s (%s skip chance)', learn_type,
- coin, train_opts.ham_skip_prob)
- return false
- end
-
- return true
- end
- else
- rspamd_logger.infox(task, 'skip %s sample; too many ham samples: %s (%s limit)', learn_type,
- nham, train_opts.max_trains)
- end
- end
- end
-
- return false
-end
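A worked example of the balanced branch: with nspam = 200 and nham = 100 vectors already stored (and nspam still under max_trains), a new spam sample gets skip_rate = 1.0 - 100/201, about 0.5, so with classes_bias = 0.0 it is skipped roughly half the time, steering the two training lists back towards a 1:1 proportion.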
-
local function ann_push_task_result(rule, task, verdict, score, set)
local train_opts = rule.train
local learn_spam, learn_ham
local skip_reason = 'unknown'
- if train_opts.autotrain then
+ if not train_opts.store_pool_only and train_opts.autotrain then
if train_opts.spam_score then
learn_spam = score >= train_opts.spam_score
@@ -510,10 +211,10 @@ local function ann_push_task_result(rule, task, verdict, score, set)
learn_ham = false
learn_spam = false
- -- Explicitly store tokens in a mempool variable
- local vec = result_to_vector(task, set)
- task:get_mempool():set_variable('neural_vec_mpack', ucl.to_format(vec, 'msgpack'))
- task:get_mempool():set_variable('neural_profile_digest', set.digest)
+ -- Explicitly store tokens in cache
+ local vec = neural_common.result_to_vector(task, set)
+ task:cache_set('neural_vec_mpack', ucl.to_format(vec, 'msgpack'))
+ task:cache_set('neural_profile_digest', set.digest)
skip_reason = 'store_pool_only has been set'
end
end
@@ -527,8 +228,8 @@ local function ann_push_task_result(rule, task, verdict, score, set)
if not err and type(data) == 'table' then
local nspam,nham = data[1],data[2]
- if can_push_train_vector(rule, task, learn_type, nspam, nham) then
- local vec = result_to_vector(task, set)
+ if neural_common.can_push_train_vector(rule, task, learn_type, nspam, nham) then
+ local vec = neural_common.result_to_vector(task, set)
local str = rspamd_util.zstd_compress(table.concat(vec, ';'))
local target_key = set.ann.redis_key .. '_' .. learn_type
@@ -585,7 +286,7 @@ local function ann_push_task_result(rule, task, verdict, score, set)
set.name)
end
- lua_redis.exec_redis_script(redis_lua_script_vectors_len_id,
+ lua_redis.exec_redis_script(neural_common.redis_script_id.vectors_len,
{task = task, is_write = false},
vectors_len_cb,
{
@@ -605,284 +306,6 @@ end
--- Offline training logic
--- Closure generator for unlock function
-local function gen_unlock_cb(rule, set, ann_key)
- return function (err)
- if err then
- rspamd_logger.errx(rspamd_config, 'cannot unlock ANN %s:%s at %s from redis: %s',
- rule.prefix, set.name, ann_key, err)
- else
- lua_util.debugm(N, rspamd_config, 'unlocked ANN %s:%s at %s',
- rule.prefix, set.name, ann_key)
- end
- end
-end
-
--- This function is intended to extend lock for ANN during training
--- It registers a periodic event that bumps the lock value every 30 seconds
--- for as long as `set.learning_spawned` remains `true`
-local function register_lock_extender(rule, set, ev_base, ann_key)
- rspamd_config:add_periodic(ev_base, 30.0,
- function()
- local function redis_lock_extend_cb(_err, _)
- if _err then
- rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s',
- ann_key, _err)
- else
- rspamd_logger.infox(rspamd_config, 'extend lock for ANN %s for 30 seconds',
- ann_key)
- end
- end
-
- if set.learning_spawned then
- lua_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
- rule.redis,
- nil,
- true, -- is write
- redis_lock_extend_cb, --callback
- 'HINCRBY', -- command
- {ann_key, 'lock', '30'}
- )
- else
- lua_util.debugm(N, rspamd_config, "stop lock extension as learning_spawned is false")
- return false -- do not plan any more updates
- end
-
- return true
- end
- )
-end
-
--- This function takes all inputs, computes the PCA transformation and returns
--- the resulting PCA projection matrix as an rspamd_tensor
-local function learn_pca(inputs, max_inputs)
- local scatter_matrix = rspamd_tensor.scatter_matrix(rspamd_tensor.fromtable(inputs))
- local eigenvals = scatter_matrix:eigen()
- -- scatter matrix is now filled with eigenvectors
- lua_util.debugm(N, 'eigenvalues: %s', eigenvals)
- local w = rspamd_tensor.new(2, max_inputs, #scatter_matrix[1])
- for i=1,max_inputs do
- w[i] = scatter_matrix[#scatter_matrix - i + 1]
- end
-
- lua_util.debugm(N, 'pca matrix: %s', w)
-
- return w
-end
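In effect this keeps the max_inputs eigenvectors with the largest eigenvalues (assuming eigen() orders them ascending, which the end-first row indexing suggests) as a max_inputs x d projection matrix, d being the full feature dimension. The matrix is later handed to train1 via its pca option and to apply1 at scoring time, both visible elsewhere in this diff.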
-
--- Fills ANN data for a specific settings element
-local function fill_set_ann(set, ann_key)
- if not set.ann then
- set.ann = {
- symbols = set.symbols,
- distance = 0,
- digest = set.digest,
- redis_key = ann_key,
- version = 0,
- }
- end
-end
-
--- This function receives training vectors, checks them, spawn learning and saves ANN in Redis
-local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_vec)
- -- Check training data sanity
- -- Now we need to join inputs and create the appropriate test vectors
- local n = #set.symbols +
- meta_functions.rspamd_count_metatokens()
-
- -- Now we can train ann
- local train_ann = create_ann(rule.max_inputs or n, 3, rule)
-
- if #ham_vec + #spam_vec < rule.train.max_trains / 2 then
- -- Invalidate ANN as it is definitely invalid
- -- TODO: add invalidation
- assert(false)
- else
- local inputs, outputs = {}, {}
-
- -- Used to show sparse vectors in a convenient format (for debugging only)
- local function debug_vec(t)
- local ret = {}
- for i,v in ipairs(t) do
- if v ~= 0 then
- ret[#ret + 1] = string.format('%d=%.2f', i, v)
- end
- end
-
- return ret
- end
-
- -- Make training set by joining vectors
- -- KANN automatically shuffles those samples
- -- 1.0 is used for spam and -1.0 is used for ham
- -- This implies that the output layer can express that range (e.g. a tanh output)
- for _,e in ipairs(spam_vec) do
- inputs[#inputs + 1] = e
- outputs[#outputs + 1] = {1.0}
- --rspamd_logger.debugm(N, rspamd_config, 'spam vector: %s', debug_vec(e))
- end
- for _,e in ipairs(ham_vec) do
- inputs[#inputs + 1] = e
- outputs[#outputs + 1] = {-1.0}
- --rspamd_logger.debugm(N, rspamd_config, 'ham vector: %s', debug_vec(e))
- end
-
- -- Called in child process
- local function train()
- local log_thresh = rule.train.max_iterations / 10
- local seen_nan = false
-
- local function train_cb(iter, train_cost, value_cost)
- if (iter * (rule.train.max_iterations / log_thresh)) % (rule.train.max_iterations) == 0 then
- if train_cost ~= train_cost and not seen_nan then
- -- We have nan :( try to log lots of stuff to dig into the problem
- seen_nan = true
- rspamd_logger.errx(rspamd_config, 'ANN %s:%s: train error: observed nan in error cost!; value cost = %s',
- rule.prefix, set.name,
- value_cost)
- for i,e in ipairs(inputs) do
- lua_util.debugm(N, rspamd_config, 'train vector %s -> %s',
- debug_vec(e), outputs[i][1])
- end
- end
-
- rspamd_logger.infox(rspamd_config,
- "ANN %s:%s: learned from %s redis key in %s iterations, error: %s, value cost: %s",
- rule.prefix, set.name,
- ann_key,
- iter,
- train_cost,
- value_cost)
- end
- end
-
- lua_util.debugm(N, rspamd_config, "subprocess to learn ANN %s:%s has been started",
- rule.prefix, set.name)
-
- local ret,err = pcall(train_ann.train1, train_ann,
- inputs, outputs, {
- lr = rule.train.learning_rate,
- max_epoch = rule.train.max_iterations,
- cb = train_cb,
- pca = (set.ann or {}).pca
- })
-
- if not ret then
- rspamd_logger.errx(rspamd_config, "cannot train ann %s:%s: %s",
- rule.prefix, set.name, err)
-
- return nil
- end
-
- if not seen_nan then
- local out = train_ann:save()
- return out
- else
- return nil
- end
- end
-
- set.learning_spawned = true
-
- local function redis_save_cb(err)
- if err then
- rspamd_logger.errx(rspamd_config, 'cannot save ANN %s:%s to redis key %s: %s',
- rule.prefix, set.name, ann_key, err)
- lua_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
- rule.redis,
- nil,
- false, -- is write
- gen_unlock_cb(rule, set, ann_key), --callback
- 'HDEL', -- command
- {ann_key, 'lock'}
- )
- else
- rspamd_logger.infox(rspamd_config, 'saved ANN %s:%s to redis: %s',
- rule.prefix, set.name, set.ann.redis_key)
- end
- end
-
- local function ann_trained(err, data)
- set.learning_spawned = false
- if err then
- rspamd_logger.errx(rspamd_config, 'cannot train ANN %s:%s : %s',
- rule.prefix, set.name, err)
- lua_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
- rule.redis,
- nil,
- true, -- is write
- gen_unlock_cb(rule, set, ann_key), --callback
- 'HDEL', -- command
- {ann_key, 'lock'}
- )
- else
- local ann_data = rspamd_util.zstd_compress(data)
- local pca_data
-
- fill_set_ann(set, ann_key)
- if set.ann.pca then
- pca_data = rspamd_util.zstd_compress(set.ann.pca:save())
- end
- -- Deserialise ANN from the child process
- ann_trained = rspamd_kann.load(data)
- local version = (set.ann.version or 0) + 1
- set.ann.version = version
- set.ann.ann = ann_trained
- set.ann.symbols = set.symbols
- set.ann.redis_key = new_ann_key(rule, set, version)
-
- local profile = {
- symbols = set.symbols,
- digest = set.digest,
- redis_key = set.ann.redis_key,
- version = version
- }
-
- local ucl = require "ucl"
- local profile_serialized = ucl.to_format(profile, 'json-compact', true)
-
- rspamd_logger.infox(rspamd_config,
- 'trained ANN %s:%s, %s bytes (%s compressed); %s rows in pca (%sb compressed); redis key: %s (old key %s)',
- rule.prefix, set.name,
- #data, #ann_data,
- #(set.ann.pca or {}), #(pca_data or {}),
- set.ann.redis_key, ann_key)
-
- lua_redis.exec_redis_script(redis_save_unlock_id,
- {ev_base = ev_base, is_write = true},
- redis_save_cb,
- {profile.redis_key,
- redis_ann_prefix(rule, set.name),
- ann_data,
- profile_serialized,
- tostring(rule.ann_expire),
- tostring(os.time()),
- ann_key, -- old key to unlock...
- pca_data
- })
- end
- end
-
- if rule.max_inputs then
- fill_set_ann(set, ann_key)
- -- Train PCA in the main process, presumably it is not that long
- set.ann.pca = learn_pca(inputs, rule.max_inputs)
- end
-
- worker:spawn_process{
- func = train,
- on_complete = ann_trained,
- proctitle = string.format("ANN train for %s/%s", rule.prefix, set.name),
- }
- end
- -- Spawn learn and register lock extension
- set.learning_spawned = true
- register_lock_extender(rule, set, ev_base, ann_key)
-end
-
-- Utility to extract and split saved training vectors to a table of tables
local function process_training_vectors(data)
return fun.totable(fun.map(function(tok)
@@ -909,14 +332,16 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key)
rule.redis,
nil,
true, -- is write
- gen_unlock_cb(rule, set, ann_key), --callback
+ neural_common.gen_unlock_cb(rule, set, ann_key), --callback
'HDEL', -- command
{ann_key, 'lock'}
)
else
-- Decompress and convert to numbers each training vector
ham_elts = process_training_vectors(data)
- spawn_train(worker, ev_base, rule, set, ann_key, ham_elts, spam_elts)
+ neural_common.spawn_train({worker = worker, ev_base = ev_base,
+ rule = rule, set = set, ann_key = ann_key, ham_vec = ham_elts,
+ spam_vec = spam_elts})
end
end
@@ -931,7 +356,7 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key)
rule.redis,
nil,
true, -- is write
- gen_unlock_cb(rule, set, ann_key), --callback
+ neural_common.gen_unlock_cb(rule, set, ann_key), --callback
'HDEL', -- command
{ann_key, 'lock'}
)
@@ -987,7 +412,7 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key)
-- Call Redis script that tries to acquire a lock
-- This script returns either a boolean or a pair {'lock_time', 'hostname'} when
-- ANN is locked by another host (or a process, meh)
- lua_redis.exec_redis_script(redis_maybe_lock_id,
+ lua_redis.exec_redis_script(neural_common.redis_script_id.maybe_lock,
{ev_base = ev_base, is_write = true},
redis_lock_cb,
{
@@ -1376,7 +801,7 @@ local function cleanup_anns(rule, cfg, ev_base)
end
if type(set) == 'table' then
- lua_redis.exec_redis_script(redis_maybe_invalidate_id,
+ lua_redis.exec_redis_script(neural_common.redis_script_id.maybe_invalidate,
{ev_base = ev_base, is_write = true},
invalidate_cb,
{set.prefix, tostring(settings.max_profiles)})
@@ -1411,7 +836,7 @@ local function ann_push_vector(task)
end
for _,rule in pairs(settings.rules) do
- local set = get_rule_settings(task, rule)
+ local set = neural_common.get_rule_settings(task, rule)
if set then
ann_push_task_result(rule, task, verdict, score, set)
@@ -1423,155 +848,20 @@ local function ann_push_vector(task)
end
--- This function is used to adjust profiles and allowed setting ids for each rule
--- It must be called when all settings are already registered (e.g. at post-init for config)
-local function process_rules_settings()
- local function process_settings_elt(rule, selt)
- local profile = rule.profile[selt.name]
- if profile then
- -- Use static user defined profile
- -- Ensure that we have an array...
- lua_util.debugm(N, rspamd_config, "use static profile for %s (%s): %s",
- rule.prefix, selt.name, profile)
- if not profile[1] then profile = lua_util.keys(profile) end
- selt.symbols = profile
- else
- lua_util.debugm(N, rspamd_config, "use dynamic cfg based profile for %s (%s)",
- rule.prefix, selt.name)
- end
-
- local function filter_symbols_predicate(sname)
- if settings.blacklisted_symbols and settings.blacklisted_symbols[sname] then
- return false
- end
- local fl = rspamd_config:get_symbol_flags(sname)
- if fl then
- fl = lua_util.list_to_hash(fl)
-
- return not (fl.nostat or fl.idempotent or fl.skip or fl.composite)
- end
-
- return false
- end
-
- -- Generic stuff
- table.sort(fun.totable(fun.filter(filter_symbols_predicate, selt.symbols)))
-
- selt.digest = lua_util.table_digest(selt.symbols)
- selt.prefix = redis_ann_prefix(rule, selt.name)
-
- lua_redis.register_prefix(selt.prefix, N,
- string.format('NN prefix for rule "%s"; settings id "%s"',
- rule.prefix, selt.name), {
- persistent = true,
- type = 'zlist',
- })
- -- Versions
- lua_redis.register_prefix(selt.prefix .. '_\\d+', N,
- string.format('NN storage for rule "%s"; settings id "%s"',
- rule.prefix, selt.name), {
- persistent = true,
- type = 'hash',
- })
- lua_redis.register_prefix(selt.prefix .. '_\\d+_spam', N,
- string.format('NN learning set (spam) for rule "%s"; settings id "%s"',
- rule.prefix, selt.name), {
- persistent = true,
- type = 'list',
- })
- lua_redis.register_prefix(selt.prefix .. '_\\d+_ham', N,
- string.format('NN learning set (ham) for rule "%s"; settings id "%s"',
- rule.prefix, selt.name), {
- persistent = true,
- type = 'list',
- })
- end
-
- for k,rule in pairs(settings.rules) do
- if not rule.allowed_settings then
- rule.allowed_settings = {}
- elseif rule.allowed_settings == 'all' then
- -- Extract all settings ids
- rule.allowed_settings = lua_util.keys(lua_settings.all_settings())
- end
-
- -- Convert to a map <setting_id> -> true
- rule.allowed_settings = lua_util.list_to_hash(rule.allowed_settings)
-
- -- Check if we can work without settings
- if k == 'default' or type(rule.default) ~= 'boolean' then
- rule.default = true
- end
-
- rule.settings = {}
-
- if rule.default then
- local default_settings = {
- symbols = lua_settings.default_symbols(),
- name = 'default'
- }
-
- process_settings_elt(rule, default_settings)
- rule.settings[-1] = default_settings -- Magic constant, but OK as settings are positive int32
- end
-
- -- Now, for each allowed settings, we store sorted symbols + digest
- -- We set table rule.settings[id] -> { name = name, symbols = symbols, digest = digest }
- for s,_ in pairs(rule.allowed_settings) do
- -- Here, we have a name and a set of symbols
- local settings_id = s
- if type(settings_id) ~= 'number' then
- settings_id = lua_settings.numeric_settings_id(s)
- end
- local selt = lua_settings.settings_by_id(settings_id)
-
- local nelt = {
- symbols = selt.symbols, -- Already sorted
- name = selt.name
- }
-
- process_settings_elt(rule, nelt)
- for id,ex in pairs(rule.settings) do
- if type(ex) == 'table' then
- if nelt and lua_util.distance_sorted(ex.symbols, nelt.symbols) == 0 then
- -- Equal symbols, add reference
- lua_util.debugm(N, rspamd_config,
- 'added reference from settings id %s to %s; same symbols',
- nelt.name, ex.name)
- rule.settings[settings_id] = id
- nelt = nil
- end
- end
- end
-
- if nelt then
- rule.settings[settings_id] = nelt
- lua_util.debugm(N, rspamd_config, 'added new settings id %s(%s) to %s',
- nelt.name, settings_id, rule.prefix)
- end
- end
- end
-end
-
-redis_params = lua_redis.parse_redis_server('neural')
-
-if not redis_params then
- redis_params = lua_redis.parse_redis_server('fann_redis')
-end
-
-- Initialization part
-if not (module_config and type(module_config) == 'table') or not redis_params then
+if not (neural_common.module_config and type(neural_common.module_config) == 'table')
+ or not neural_common.redis_params then
rspamd_logger.infox(rspamd_config, 'Module is unconfigured')
lua_util.disable_module(N, "redis")
return
end
-local rules = module_config['rules']
+local rules = neural_common.module_config['rules']
if not rules then
-- Use legacy configuration
rules = {}
- rules['default'] = module_config
+ rules['default'] = neural_common.module_config
end
local id = rspamd_config:register_symbol({
@@ -1582,8 +872,7 @@ local id = rspamd_config:register_symbol({
callback = ann_scores_filter
})
-settings = lua_util.override_defaults(settings, module_config)
-settings.rules = {} -- Reset unless validated further in the cycle
+neural_common.settings.rules = {} -- Reset unless validated further in the cycle
if settings.blacklisted_symbols and settings.blacklisted_symbols[1] then
-- Transform to hash for simplicity
@@ -1593,7 +882,7 @@ end
-- Check all rules
for k,r in pairs(rules) do
local rule_elt = lua_util.override_defaults(default_options, r)
- rule_elt['redis'] = redis_params
+ rule_elt['redis'] = neural_common.redis_params
rule_elt['anns'] = {} -- Store ANNs here
if not rule_elt.prefix then
@@ -1651,11 +940,12 @@ rspamd_config:register_symbol({
callback = ann_push_vector
})
+-- We also need to deal with settings
+rspamd_config:add_post_init(neural_common.process_rules_settings)
+
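Note the side fix in this hunk: the post-init used to be registered once per rule inside the loop below, so it could be scheduled repeatedly; it is now registered exactly once, leaving only the per-rule script loading in the loop.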
-- Add training scripts
for _,rule in pairs(settings.rules) do
- load_scripts(rule.redis)
- -- We also need to deal with settings
- rspamd_config:add_post_init(process_rules_settings)
+ neural_common.load_scripts(rule.redis)
-- This function will check ANNs in Redis when a worker is loaded
rspamd_config:add_on_load(function(cfg, ev_base, worker)
if worker:is_scanner() then