- Move neural functions to library - Parameterise spawn_train - neural plugin: Fix store_pool_only when autotrain is true - neural plugin: Use cache_set instead of mempool - Add testtags/2.7
@@ -0,0 +1,779 @@ | |||
--[[ | |||
Copyright (c) 2020, Vsevolod Stakhov <vsevolod@highsecure.ru> | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
]]-- | |||
local fun = require "fun" | |||
local lua_redis = require "lua_redis" | |||
local lua_settings = require "lua_settings" | |||
local lua_util = require "lua_util" | |||
local meta_functions = require "lua_meta" | |||
local rspamd_kann = require "rspamd_kann" | |||
local rspamd_logger = require "rspamd_logger" | |||
local rspamd_tensor = require "rspamd_tensor" | |||
local rspamd_util = require "rspamd_util" | |||
local N = 'neural' | |||
-- Used in prefix to avoid wrong ANN to be loaded | |||
local plugin_ver = '2' | |||
-- Module vars
-- Defaults for every neural rule; merged with user configuration
local default_options = {
  train = {
    max_trains = 1000, -- Limit of stored training vectors per class
    max_epoch = 1000,
    max_usages = 10,
    max_iterations = 25, -- Torch style
    mse = 0.001,
    autotrain = true,
    train_prob = 1.0, -- Probability of considering a sample for training
    learn_threads = 1,
    learn_mode = 'balanced', -- Possible values: balanced, proportional
    learning_rate = 0.01,
    classes_bias = 0.0, -- balanced mode: what difference is allowed between classes (1:1 proportion means 0 bias)
    spam_skip_prob = 0.0, -- proportional mode: spam skip probability (0-1)
    ham_skip_prob = 0.0, -- proportional mode: ham skip probability
    store_pool_only = false, -- store tokens in cache only (disables autotrain);
    -- when enabled, neural_vec_mpack stores the training vector in messagepack
    -- and neural_profile_digest stores the profile digest
  },
  watch_interval = 60.0,
  lock_expire = 600, -- training lock lifetime (compared against os.time in the lock script)
  learning_spawned = false,
  ann_expire = 60 * 60 * 24 * 2, -- 2 days
  hidden_layer_mult = 1.5, -- number of neurons in the hidden layer
  symbol_spam = 'NEURAL_SPAM',
  symbol_ham = 'NEURAL_HAM',
  max_inputs = nil, -- when PCA is used
  blacklisted_symbols = {}, -- list of symbols skipped in neural processing
}
-- Rule structure: | |||
-- * static config fields (see `default_options`) | |||
-- * prefix - name or defined prefix | |||
-- * settings - table of settings indexed by settings id, -1 is used when no settings defined | |||
-- Rule settings element defines elements for specific settings id: | |||
-- * symbols - static symbols profile (defined by config or extracted from symcache) | |||
-- * name - name of settings id | |||
-- * digest - digest of all symbols | |||
-- * ann - dynamic ANN configuration loaded from Redis | |||
-- * train - train data for ANN (e.g. the currently trained ANN) | |||
-- Settings ANN table is loaded from Redis and represents dynamic profile for ANN | |||
-- Some elements are directly stored in Redis, ANN is, in turn loaded dynamically | |||
-- * version - version of ANN loaded from redis | |||
-- * redis_key - name of ANN key in Redis | |||
-- * symbols - symbols in THIS PARTICULAR ANN (might be different from set.symbols) | |||
-- * distance - distance between set.symbols and set.ann.symbols | |||
-- * ann - kann object | |||
-- Global library settings; user config is merged in below and the table is
-- shared with consumers of this library via the module export
local settings = {
  rules = {},
  prefix = 'rn', -- Neural network default prefix
  max_profiles = 3, -- Maximum number of NN profiles stored
}

-- Get module & Redis configuration
local module_config = rspamd_config:get_all_opt(N)
settings = lua_util.override_defaults(settings, module_config)
local redis_params = lua_redis.parse_redis_server('neural')
-- Lua script that checks if we can store a new training vector
-- Uses the following keys:
-- key1 - ann key
-- returns {nspam, nham} counts of stored vectors, or a 'hostname:lock_time'
-- string when the ANN is currently locked by a trainer
local redis_lua_script_vectors_len = [[
local prefix = KEYS[1]
local locked = redis.call('HGET', prefix, 'lock')
if locked then
local host = redis.call('HGET', prefix, 'hostname') or 'unknown'
return string.format('%s:%s', host, locked)
end
local nspam = 0
local nham = 0
local ret = redis.call('LLEN', prefix .. '_spam')
if ret then nspam = tonumber(ret) end
ret = redis.call('LLEN', prefix .. '_ham')
if ret then nham = tonumber(ret) end
return {nspam,nham}
]]
-- Lua script to invalidate ANNs by rank
-- Uses the following keys
-- key1 - prefix for keys (zset of profiles)
-- key2 - number of elements to leave
-- Deletes the oldest profiles (and their _spam/_ham vector lists) beyond the
-- limit; returns the removed profile JSON entries, or {} when under the limit
local redis_lua_script_maybe_invalidate = [[
local card = redis.call('ZCARD', KEYS[1])
local lim = tonumber(KEYS[2])
if card > lim then
local to_delete = redis.call('ZRANGE', KEYS[1], 0, card - lim - 1)
for _,k in ipairs(to_delete) do
local tb = cjson.decode(k)
redis.call('DEL', tb.redis_key)
-- Also train vectors
redis.call('DEL', tb.redis_key .. '_spam')
redis.call('DEL', tb.redis_key .. '_ham')
end
redis.call('ZREMRANGEBYRANK', KEYS[1], 0, card - lim - 1)
return to_delete
else
return {}
end
]]
-- Lua script to try to acquire a training lock for an ANN in redis
-- (the previous header wrongly described this script as invalidation)
-- Uses the following keys
-- key1 - prefix for keys
-- key2 - current time
-- key3 - key expire
-- key4 - hostname
-- Returns {lock_time, hostname} when a fresh lock is held elsewhere;
-- otherwise (no lock, or a stale lock) it overwrites the lock and returns 1
local redis_lua_script_maybe_lock = [[
local locked = redis.call('HGET', KEYS[1], 'lock')
local now = tonumber(KEYS[2])
if locked then
locked = tonumber(locked)
local expire = tonumber(KEYS[3])
if now > locked and (now - locked) < expire then
return {tostring(locked), redis.call('HGET', KEYS[1], 'hostname') or 'unknown'}
end
end
redis.call('HSET', KEYS[1], 'lock', tostring(now))
redis.call('HSET', KEYS[1], 'hostname', KEYS[4])
return 1
]]
-- Lua script to save and unlock ANN in redis
-- Uses the following keys
-- key1 - prefix for ANN
-- key2 - prefix for profile
-- key3 - compressed ANN
-- key4 - profile as JSON
-- key5 - expire in seconds
-- key6 - current time
-- key7 - old key (its training lock is released as well)
-- key8 - optional compressed PCA matrix
-- Always returns 1; also drops the training vector lists of the new key
local redis_lua_script_save_unlock = [[
local now = tonumber(KEYS[6])
redis.call('ZADD', KEYS[2], now, KEYS[4])
redis.call('HSET', KEYS[1], 'ann', KEYS[3])
redis.call('DEL', KEYS[1] .. '_spam')
redis.call('DEL', KEYS[1] .. '_ham')
redis.call('HDEL', KEYS[1], 'lock')
redis.call('HDEL', KEYS[7], 'lock')
redis.call('EXPIRE', KEYS[1], tonumber(KEYS[5]))
if KEYS[8] then
redis.call('HSET', KEYS[1], 'pca', KEYS[8])
end
return 1
]]
-- Ids of the uploaded Redis scripts, keyed by logical script name
local redis_script_id = {}

-- Registers every module script with the Redis abstraction layer and
-- records the returned ids in `redis_script_id`
local function load_scripts()
  local sources = {
    vectors_len = redis_lua_script_vectors_len,
    maybe_invalidate = redis_lua_script_maybe_invalidate,
    maybe_lock = redis_lua_script_maybe_lock,
    save_unlock = redis_lua_script_save_unlock,
  }
  for script_name, script_body in pairs(sources) do
    redis_script_id[script_name] = lua_redis.add_redis_script(script_body, redis_params)
  end
end
-- Builds a fresh KANN network for `n` inputs and returns it.
-- `nlayers` is accepted for interface compatibility but ignored: the KANN
-- backend here always uses a single hidden layer sized by hidden_layer_mult.
local function create_ann(n, nlayers, rule)
  local hidden_neurons = math.floor(n * (rule.hidden_layer_mult or 1.0) + 1.0)
  local net = rspamd_kann.layer.input(n)
  net = rspamd_kann.transform.relu(net)
  net = rspamd_kann.layer.dense(net, hidden_neurons)
  net = rspamd_kann.layer.cost(net, 1, rspamd_kann.cost.ceb_neg)
  return rspamd_kann.new.kann(net)
end
-- Initialises the dynamic ANN profile on a settings element.
-- No-op when `set.ann` already exists: existing state is never overwritten.
local function fill_set_ann(set, ann_key)
  if set.ann then
    return
  end
  set.ann = {
    symbols = set.symbols,
    distance = 0,
    digest = set.digest,
    redis_key = ann_key,
    version = 0,
  }
end
-- This function takes all inputs, applies PCA transformation and returns the final
-- PCA matrix as rspamd_tensor
-- `inputs`: table of training vectors (rows); `max_inputs`: number of
-- principal components to keep (target dimensionality of the projection)
local function learn_pca(inputs, max_inputs)
  local scatter_matrix = rspamd_tensor.scatter_matrix(rspamd_tensor.fromtable(inputs))
  local eigenvals = scatter_matrix:eigen()
  -- scatter matrix is not filled with eigenvectors
  lua_util.debugm(N, 'eigenvalues: %s', eigenvals)
  -- Projection matrix: max_inputs rows, each as wide as one input row
  local w = rspamd_tensor.new(2, max_inputs, #scatter_matrix[1])
  for i=1,max_inputs do
    -- Rows are taken from the tail of the matrix; presumably eigen() orders
    -- components ascending so the tail carries the largest eigenvalues -- TODO confirm
    w[i] = scatter_matrix[#scatter_matrix - i + 1]
  end
  lua_util.debugm(N, 'pca matrix: %s', w)
  return w
end
-- This function is intended to extend lock for ANN during training.
-- It registers a periodic event that increases the lock key every 30 seconds
-- while `set.learning_spawned` remains true; once it becomes false the
-- periodic returns false and unregisters itself.
-- (The original comment said "unless ... set to `true`", which was inverted.)
local function register_lock_extender(rule, set, ev_base, ann_key)
  rspamd_config:add_periodic(ev_base, 30.0,
      function()
        local function redis_lock_extend_cb(_err, _)
          if _err then
            rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s',
                ann_key, _err)
          else
            rspamd_logger.infox(rspamd_config, 'extend lock for ANN %s for 30 seconds',
                ann_key)
          end
        end
        if set.learning_spawned then
          -- HINCRBY pushes the lock timestamp forward by 30 seconds
          lua_redis.redis_make_request_taskless(ev_base,
              rspamd_config,
              rule.redis,
              nil,
              true, -- is write
              redis_lock_extend_cb, --callback
              'HINCRBY', -- command
              {ann_key, 'lock', '30'}
          )
        else
          lua_util.debugm(N, rspamd_config, "stop lock extension as learning_spawned is false")
          return false -- do not plan any more updates
        end
        return true
      end
  )
end
-- Decides whether a training vector of class `learn_type` ('spam', anything
-- else treated as ham) may be stored, given the counts `nspam`/`nham` of
-- vectors already collected. Returns true to store, false to skip.
local function can_push_train_vector(rule, task, learn_type, nspam, nham)
  local train_opts = rule.train

  -- Single random draw reused for every probabilistic decision below
  local coin = math.random()

  -- Global sampling: with train_prob < 1.0 a share of all samples is dropped
  if train_opts.train_prob and coin < 1.0 - train_opts.train_prob then
    rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin)
    return false
  end

  if train_opts.learn_mode == 'balanced' then
    -- Keep balanced training set based on number of spam and ham samples
    if learn_type == 'spam' then
      if nspam <= train_opts.max_trains then
        if nspam > nham then
          -- Apply sampling
          -- Majority class is downsampled proportionally to the imbalance;
          -- classes_bias relaxes the threshold to tolerate some skew
          local skip_rate = 1.0 - nham / (nspam + 1)
          if coin < skip_rate - train_opts.classes_bias then
            rspamd_logger.infox(task,
                'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored',
                learn_type,
                skip_rate - train_opts.classes_bias,
                nspam, nham)
            return false
          end
        end
        return true
      else -- Enough learns
        rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many spam samples: %s',
            learn_type,
            nspam)
      end
    else
      -- Ham branch: mirror image of the spam branch above
      if nham <= train_opts.max_trains then
        if nham > nspam then
          -- Apply sampling
          local skip_rate = 1.0 - nspam / (nham + 1)
          if coin < skip_rate - train_opts.classes_bias then
            rspamd_logger.infox(task,
                'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored',
                learn_type,
                skip_rate - train_opts.classes_bias,
                nspam, nham)
            return false
          end
        end
        return true
      else
        rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many ham samples: %s', learn_type,
            nham)
      end
    end
  else
    -- Probabilistic learn mode, we just skip learn if we already have enough samples or
    -- if our coin drop is less than desired probability
    if learn_type == 'spam' then
      if nspam <= train_opts.max_trains then
        if train_opts.spam_skip_prob then
          if coin <= train_opts.spam_skip_prob then
            rspamd_logger.infox(task, 'skip %s sample probabilisticaly; probability %s (%s skip chance)', learn_type,
                coin, train_opts.spam_skip_prob)
            return false
          end
          return true
        end
        -- NOTE(review): with spam_skip_prob unset this falls through to the
        -- final `return false` — confirm whether that is intentional
      else
        rspamd_logger.infox(task, 'skip %s sample; too many spam samples: %s (%s limit)', learn_type,
            nspam, train_opts.max_trains)
      end
    else
      if nham <= train_opts.max_trains then
        if train_opts.ham_skip_prob then
          if coin <= train_opts.ham_skip_prob then
            rspamd_logger.infox(task, 'skip %s sample probabilisticaly; probability %s (%s skip chance)', learn_type,
                coin, train_opts.ham_skip_prob)
            return false
          end
          return true
        end
      else
        rspamd_logger.infox(task, 'skip %s sample; too many ham samples: %s (%s limit)', learn_type,
            nham, train_opts.max_trains)
      end
    end
  end

  -- Default: do not store the vector
  return false
end
-- Closure generator for unlock function: returns a Redis callback that logs
-- the outcome of an HDEL-based lock release for the given ANN key
local function gen_unlock_cb(rule, set, ann_key)
  return function(err)
    if not err then
      lua_util.debugm(N, rspamd_config, 'unlocked ANN %s:%s at %s',
          rule.prefix, set.name, ann_key)
      return
    end
    rspamd_logger.errx(rspamd_config, 'cannot unlock ANN %s:%s at %s from redis: %s',
        rule.prefix, set.name, ann_key, err)
  end
end
-- Used to generate new ANN key for specific profile.
-- Layout: <global prefix>_<rule prefix>_<settings name>_<digest[1..8]>_<version>
local function new_ann_key(rule, set, version)
  local digest_part = set.digest:sub(1, 8)
  return string.format('%s_%s_%s_%s_%s',
      settings.prefix, rule.prefix, set.name, digest_part, tostring(version))
end
-- Returns the Redis key prefix for a rule + settings-name pair.
-- Both the plugin version and the metatokens version participate, so
-- incompatible profile layouts never share keys.
local function redis_ann_prefix(rule, settings_name)
  -- We also need to count metatokens:
  local mt_version = meta_functions.version
  return string.format('%s%d_%s_%d_%s',
      settings.prefix, plugin_ver, rule.prefix, mt_version, settings_name)
end
-- This function receives training vectors, checks them, spawns learning in a
-- worker child process and saves the resulting ANN in Redis.
-- params: table with rule, set, ann_key, ham_vec, spam_vec, worker, ev_base.
-- Raises when there are not enough training vectors to proceed.
local function spawn_train(params)
  -- Check training data sanity
  -- Now we need to join inputs and create the appropriate test vectors
  local n = #params.set.symbols +
      meta_functions.rspamd_count_metatokens()

  -- Now we can train ann
  local train_ann = create_ann(params.rule.max_inputs or n, 3, params.rule)

  if #params.ham_vec + #params.spam_vec < params.rule.train.max_trains / 2 then
    -- Invalidate ANN as it is definitely invalid
    -- TODO: add invalidation
    -- Fix: bare assert(false) gave no diagnostics at all
    assert(false, string.format('cannot train ANN %s:%s: not enough training vectors (%d ham, %d spam)',
        params.rule.prefix, params.set.name, #params.ham_vec, #params.spam_vec))
  else
    local inputs, outputs = {}, {}

    -- Used to show sparsed vectors in a convenient format (for debugging only)
    local function debug_vec(t)
      local ret = {}
      for i,v in ipairs(t) do
        if v ~= 0 then
          ret[#ret + 1] = string.format('%d=%.2f', i, v)
        end
      end
      return ret
    end

    -- Make training set by joining vectors
    -- KANN automatically shuffles those samples
    -- 1.0 is used for spam and -1.0 is used for ham
    -- It implies that output layer can express that (e.g. tanh output)
    for _,e in ipairs(params.spam_vec) do
      inputs[#inputs + 1] = e
      outputs[#outputs + 1] = {1.0}
      --rspamd_logger.debugm(N, rspamd_config, 'spam vector: %s', debug_vec(e))
    end
    for _,e in ipairs(params.ham_vec) do
      inputs[#inputs + 1] = e
      outputs[#outputs + 1] = {-1.0}
      --rspamd_logger.debugm(N, rspamd_config, 'ham vector: %s', debug_vec(e))
    end

    -- Called in child process; returns the serialised ANN or nil on failure
    local function train()
      local log_thresh = params.rule.train.max_iterations / 10
      local seen_nan = false

      -- Progress callback: logs roughly 10 times over the whole training run
      local function train_cb(iter, train_cost, value_cost)
        if (iter * (params.rule.train.max_iterations / log_thresh)) % (params.rule.train.max_iterations) == 0 then
          -- NaN detection relies on NaN being the only value ~= itself
          if train_cost ~= train_cost and not seen_nan then
            -- We have nan :( try to log lot's of stuff to dig into a problem
            seen_nan = true
            rspamd_logger.errx(rspamd_config, 'ANN %s:%s: train error: observed nan in error cost!; value cost = %s',
                params.rule.prefix, params.set.name,
                value_cost)
            for i,e in ipairs(inputs) do
              lua_util.debugm(N, rspamd_config, 'train vector %s -> %s',
                  debug_vec(e), outputs[i][1])
            end
          end

          rspamd_logger.infox(rspamd_config,
              "ANN %s:%s: learned from %s redis key in %s iterations, error: %s, value cost: %s",
              params.rule.prefix, params.set.name,
              params.ann_key,
              iter,
              train_cost,
              value_cost)
        end
      end

      lua_util.debugm(N, rspamd_config, "subprocess to learn ANN %s:%s has been started",
          params.rule.prefix, params.set.name)

      local ret,err = pcall(train_ann.train1, train_ann,
          inputs, outputs, {
            lr = params.rule.train.learning_rate,
            max_epoch = params.rule.train.max_iterations,
            cb = train_cb,
            pca = (params.set.ann or {}).pca
          })

      if not ret then
        rspamd_logger.errx(rspamd_config, "cannot train ann %s:%s: %s",
            params.rule.prefix, params.set.name, err)
        return nil
      end

      if not seen_nan then
        -- Serialise the trained network to hand it back to the parent process
        return train_ann:save()
      else
        return nil
      end
    end

    -- Redis callback for the save_unlock script
    local function redis_save_cb(err)
      if err then
        rspamd_logger.errx(rspamd_config, 'cannot save ANN %s:%s to redis key %s: %s',
            params.rule.prefix, params.set.name, params.ann_key, err)
        -- Best-effort unlock so the key does not stay locked until expiry
        lua_redis.redis_make_request_taskless(params.ev_base,
            rspamd_config,
            params.rule.redis,
            nil,
            true, -- is write (fix: HDEL is a write command, was issued as read)
            gen_unlock_cb(params.rule, params.set, params.ann_key), --callback
            'HDEL', -- command
            {params.ann_key, 'lock'}
        )
      else
        rspamd_logger.infox(rspamd_config, 'saved ANN %s:%s to redis: %s',
            params.rule.prefix, params.set.name, params.set.ann.redis_key)
      end
    end

    -- Completion callback: runs in the parent once the child exits
    local function ann_trained(err, data)
      params.set.learning_spawned = false
      if err then
        rspamd_logger.errx(rspamd_config, 'cannot train ANN %s:%s : %s',
            params.rule.prefix, params.set.name, err)
        lua_redis.redis_make_request_taskless(params.ev_base,
            rspamd_config,
            params.rule.redis,
            nil,
            true, -- is write
            gen_unlock_cb(params.rule, params.set, params.ann_key), --callback
            'HDEL', -- command
            {params.ann_key, 'lock'}
        )
      else
        local ann_data = rspamd_util.zstd_compress(data)
        local pca_data

        fill_set_ann(params.set, params.ann_key)
        if params.set.ann.pca then
          pca_data = rspamd_util.zstd_compress(params.set.ann.pca:save())
        end

        -- Deserialise ANN from the child process.
        -- Fix: use a fresh local; the original assignment clobbered the
        -- `ann_trained` callback variable itself.
        local loaded_ann = rspamd_kann.load(data)

        local version = (params.set.ann.version or 0) + 1
        params.set.ann.version = version
        params.set.ann.ann = loaded_ann
        params.set.ann.symbols = params.set.symbols
        params.set.ann.redis_key = new_ann_key(params.rule, params.set, version)

        local profile = {
          symbols = params.set.symbols,
          digest = params.set.digest,
          redis_key = params.set.ann.redis_key,
          version = version
        }

        local ucl = require "ucl"
        local profile_serialized = ucl.to_format(profile, 'json-compact', true)

        rspamd_logger.infox(rspamd_config,
            'trained ANN %s:%s, %s bytes (%s compressed); %s rows in pca (%sb compressed); redis key: %s (old key %s)',
            params.rule.prefix, params.set.name,
            #data, #ann_data,
            #(params.set.ann.pca or {}), #(pca_data or {}),
            params.set.ann.redis_key, params.ann_key)

        lua_redis.exec_redis_script(redis_script_id.save_unlock,
            {ev_base = params.ev_base, is_write = true},
            redis_save_cb,
            {profile.redis_key,
             redis_ann_prefix(params.rule, params.set.name),
             ann_data,
             profile_serialized,
             tostring(params.rule.ann_expire),
             tostring(os.time()),
             params.ann_key, -- old key to unlock...
             pca_data
            })
      end
    end

    if params.rule.max_inputs then
      fill_set_ann(params.set, params.ann_key)
      -- Train PCA in the main process, presumably it is not that long
      params.set.ann.pca = learn_pca(inputs, params.rule.max_inputs)
    end

    -- Spawn learn and register lock extension
    -- (set the flag once, before spawning; it was previously set twice)
    params.set.learning_spawned = true
    params.worker:spawn_process{
      func = train,
      on_complete = ann_trained,
      proctitle = string.format("ANN train for %s/%s", params.rule.prefix, params.set.name),
    }
    register_lock_extender(params.rule, params.set, params.ev_base, params.ann_key)
    return
  end
end
-- This function is used to adjust profiles and allowed setting ids for each rule
-- It must be called when all settings are already registered (e.g. at post-init for config)
local function process_rules_settings()
  -- Prepares a single settings element: resolves the symbols profile, filters
  -- and sorts it, computes its digest and registers all Redis prefixes
  local function process_settings_elt(rule, selt)
    -- Guard against rules with no static profiles table at all
    local profile = (rule.profile or {})[selt.name]
    if profile then
      -- Use static user defined profile
      -- Ensure that we have an array...
      lua_util.debugm(N, rspamd_config, "use static profile for %s (%s): %s",
          rule.prefix, selt.name, profile)
      if not profile[1] then profile = lua_util.keys(profile) end
      selt.symbols = profile
    else
      lua_util.debugm(N, rspamd_config, "use dynamic cfg based profile for %s (%s)",
          rule.prefix, selt.name)
    end

    -- Rejects blacklisted symbols and symbols flagged as unsuitable for stats
    local function filter_symbols_predicate(sname)
      if settings.blacklisted_symbols and settings.blacklisted_symbols[sname] then
        return false
      end
      local fl = rspamd_config:get_symbol_flags(sname)
      if fl then
        fl = lua_util.list_to_hash(fl)
        return not (fl.nostat or fl.idempotent or fl.skip or fl.composite)
      end
      return false
    end

    -- Generic stuff
    -- Fix: the filtered list used to be discarded (table.sort of a temporary);
    -- store the filtered symbols back and sort them in place
    selt.symbols = fun.totable(fun.filter(filter_symbols_predicate, selt.symbols))
    table.sort(selt.symbols)

    selt.digest = lua_util.table_digest(selt.symbols)
    selt.prefix = redis_ann_prefix(rule, selt.name)

    lua_redis.register_prefix(selt.prefix, N,
        string.format('NN prefix for rule "%s"; settings id "%s"',
            rule.prefix, selt.name), {
          persistent = true,
          type = 'zlist',
        })
    -- Versions
    lua_redis.register_prefix(selt.prefix .. '_\\d+', N,
        string.format('NN storage for rule "%s"; settings id "%s"',
            rule.prefix, selt.name), {
          persistent = true,
          type = 'hash',
        })
    lua_redis.register_prefix(selt.prefix .. '_\\d+_spam', N,
        string.format('NN learning set (spam) for rule "%s"; settings id "%s"',
            rule.prefix, selt.name), {
          persistent = true,
          type = 'list',
        })
    -- Fix: description said "(spam)" for the ham learning set
    lua_redis.register_prefix(selt.prefix .. '_\\d+_ham', N,
        string.format('NN learning set (ham) for rule "%s"; settings id "%s"',
            rule.prefix, selt.name), {
          persistent = true,
          type = 'list',
        })
  end

  for k,rule in pairs(settings.rules) do
    -- `allowed_settings` may be a list of ids/names or the literal 'all'
    if not rule.allowed_settings then
      rule.allowed_settings = {}
    elseif rule.allowed_settings == 'all' then
      -- Extract all settings ids
      rule.allowed_settings = lua_util.keys(lua_settings.all_settings())
    end

    -- Convert to a map <setting_id> -> true
    rule.allowed_settings = lua_util.list_to_hash(rule.allowed_settings)

    -- Check if we can work without settings
    if k == 'default' or type(rule.default) ~= 'boolean' then
      rule.default = true
    end

    rule.settings = {}

    if rule.default then
      local default_settings = {
        symbols = lua_settings.default_symbols(),
        name = 'default'
      }
      process_settings_elt(rule, default_settings)
      rule.settings[-1] = default_settings -- Magic constant, but OK as settings are positive int32
    end

    -- Now, for each allowed settings, we store sorted symbols + digest
    -- We set table rule.settings[id] -> { name = name, symbols = symbols, digest = digest }
    for s,_ in pairs(rule.allowed_settings) do
      -- Here, we have a name, set of symbols and
      local settings_id = s
      if type(settings_id) ~= 'number' then
        settings_id = lua_settings.numeric_settings_id(s)
      end
      local selt = lua_settings.settings_by_id(settings_id)

      if not selt then
        rspamd_logger.errx(rspamd_config, 'unknown settings id %s for rule %s',
            s, rule.prefix)
      else
        local nelt = {
          symbols = selt.symbols, -- Already sorted
          name = selt.name
        }

        process_settings_elt(rule, nelt)

        -- Deduplicate: when an existing element has the same symbols, store
        -- a numeric reference instead of a copy.
        -- Fix: the original inserted into rule.settings while iterating it
        -- with pairs(), which is undefined behavior in Lua; defer the insert.
        local ref_id
        for id,ex in pairs(rule.settings) do
          if type(ex) == 'table' and
              lua_util.distance_sorted(ex.symbols, nelt.symbols) == 0 then
            -- Equal symbols, add reference
            lua_util.debugm(N, rspamd_config,
                'added reference from settings id %s to %s; same symbols',
                nelt.name, ex.name)
            ref_id = id
            break
          end
        end

        if ref_id then
          rule.settings[settings_id] = ref_id
        else
          rule.settings[settings_id] = nelt
          lua_util.debugm(N, rspamd_config, 'added new settings id %s(%s) to %s',
              nelt.name, settings_id, rule.prefix)
        end
      end
    end
  end
end
-- Extract settings element for a specific settings id.
-- Numeric entries in `rule.settings` are aliases that point at another id;
-- they are followed until a concrete table (or nil) is reached.
local function get_rule_settings(task, rule)
  local sid = task:get_settings_id() or -1
  local elt = rule.settings[sid]

  if not elt then
    return nil
  end

  while type(elt) == 'number' do
    -- Reference to another settings!
    elt = rule.settings[elt]
  end

  return elt
end
-- Converts task results into an ANN input vector for the given profile.
-- Layout: metatokens first, then one slot per profile symbol; the zero
-- template is built once and cached on the profile.
local function result_to_vector(task, profile)
  if not profile.zeros then
    -- Build and cache the zero template sized to metatokens + symbols
    local templ = {}
    local nmeta = meta_functions.count_metatokens()
    for idx = 1, nmeta do
      templ[idx] = 0.0
    end
    for _ = 1, #profile.symbols do
      templ[#templ + 1] = 0.0
    end
    profile.zeros = templ
  end

  local vec = lua_util.shallowcopy(profile.zeros)

  -- Metatokens fill the head of the vector
  local mt = meta_functions.rspamd_gen_metatokens(task)
  for idx, val in ipairs(mt) do
    vec[idx] = val
  end

  -- Symbol scores are written after the metatokens by the task helper
  task:process_ann_tokens(profile.symbols, vec, #mt, 0.1)

  return vec
end
-- Public API of the neural library, shared between the runtime plugin and
-- the controller endpoints
return {
  can_push_train_vector = can_push_train_vector,
  create_ann = create_ann,
  default_options = default_options,
  gen_unlock_cb = gen_unlock_cb,
  get_rule_settings = get_rule_settings,
  load_scripts = load_scripts,
  module_config = module_config,
  new_ann_key = new_ann_key,
  plugin_ver = plugin_ver,
  process_rules_settings = process_rules_settings,
  redis_ann_prefix = redis_ann_prefix,
  redis_params = redis_params,
  redis_script_id = redis_script_id,
  result_to_vector = result_to_vector,
  settings = settings,
  spawn_train = spawn_train,
}
@@ -25,8 +25,9 @@ local rspamd_logger = require "rspamd_logger" | |||
-- Define default controller paths, could be overridden in local.d/controller.lua | |||
local controller_plugin_paths = { | |||
maps = dofile(local_rules .. "/controller/maps.lua"), | |||
neural = dofile(local_rules .. "/controller/neural.lua"), | |||
selectors = dofile(local_rules .. "/controller/selectors.lua"), | |||
maps = dofile(local_rules .. "/controller/maps.lua") | |||
} | |||
if rspamd_util.file_exists(local_conf .. '/controller.lua') then | |||
@@ -62,4 +63,4 @@ for plug,paths in pairs(controller_plugin_paths) do | |||
plug, path, type(attrs)) | |||
end | |||
end | |||
end | |||
end |
@@ -0,0 +1,72 @@ | |||
--[[ | |||
Copyright (c) 2020, Vsevolod Stakhov <vsevolod@highsecure.ru> | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
]]-- | |||
local neural_common = require "plugins/neural" | |||
local ts = require("tableshape").types | |||
local ucl = require "ucl" | |||
local E = {} -- shared empty table, safe fallback for `(x or E).field` chains

-- Controller neural plugin
-- Schema of the learn request body: spam_vec/ham_vec are arrays of training
-- vectors (each a flat array of numbers); rule is optional and defaults to
-- 'default' in the handler
local learn_request_schema = ts.shape{
  ham_vec = ts.array_of(ts.array_of(ts.number)),
  rule = ts.string:is_optional(),
  spam_vec = ts.array_of(ts.array_of(ts.number)),
}
-- Controller endpoint: accepts a JSON body with ham/spam training vectors and
-- spawns ANN training for the selected rule.
-- Replies 400 on malformed input and 404 for an unknown rule or settings set.
local function handle_learn(task, conn)
  local parser = ucl.parser()
  local ok, err = parser:parse_text(task:get_rawbody())
  if not ok then
    conn:send_error(400, err)
    return
  end
  local req_params = parser:get_object()

  -- Validate the request shape before touching any rule state
  ok, err = learn_request_schema:transform(req_params)
  if not ok then
    conn:send_error(400, err)
    return
  end

  local rule_name = req_params.rule or 'default'
  local rule = neural_common.settings.rules[rule_name]
  if not rule then
    -- Fix: an unknown rule name previously caused a nil-index error
    conn:send_error(404, 'unknown rule: ' .. rule_name)
    return
  end

  local set = neural_common.get_rule_settings(task, rule)
  if not set then
    -- Fix: missing settings element previously caused a nil-index error
    conn:send_error(404, 'no settings found for rule: ' .. rule_name)
    return
  end

  -- Next profile version: current version + 1 (0 if no ANN exists yet)
  local version = ((set.ann or E).version or 0) + 1

  neural_common.spawn_train{
    ev_base = task:get_ev_base(),
    ann_key = neural_common.new_ann_key(rule, set, version),
    set = set,
    rule = rule,
    ham_vec = req_params.ham_vec,
    spam_vec = req_params.spam_vec,
    worker = task:get_worker(),
  }

  conn:send_string('{"success" : true}')
end
-- Profiles depend on the complete settings registry, so processing is
-- deferred to post-init (see process_rules_settings in the neural library)
rspamd_config:add_post_init(neural_common.process_rules_settings)

-- Controller handlers exported by this plugin
return {
  learn = {
    handler = handle_learn,
    enable = true,
    need_task = true, -- the handler needs a task object built from the request
  },
}
@@ -19,22 +19,21 @@ if confighelp then | |||
return | |||
end | |||
local rspamd_logger = require "rspamd_logger" | |||
local rspamd_util = require "rspamd_util" | |||
local rspamd_kann = require "rspamd_kann" | |||
local rspamd_text = require "rspamd_text" | |||
local fun = require "fun" | |||
local lua_redis = require "lua_redis" | |||
local lua_util = require "lua_util" | |||
local lua_verdict = require "lua_verdict" | |||
local neural_common = require "plugins/neural" | |||
local rspamd_kann = require "rspamd_kann" | |||
local rspamd_logger = require "rspamd_logger" | |||
local rspamd_tensor = require "rspamd_tensor" | |||
local fun = require "fun" | |||
local lua_settings = require "lua_settings" | |||
local meta_functions = require "lua_meta" | |||
local rspamd_text = require "rspamd_text" | |||
local rspamd_util = require "rspamd_util" | |||
local ts = require("tableshape").types | |||
local lua_verdict = require "lua_verdict" | |||
local N = "neural" | |||
local plugin_ver = '2' | |||
local settings = neural_common.settings | |||
-- Module vars | |||
local default_options = { | |||
@@ -52,7 +51,7 @@ local default_options = { | |||
classes_bias = 0.0, -- balanced mode: what difference is allowed between classes (1:1 proportion means 0 bias) | |||
spam_skip_prob = 0.0, -- proportional mode: spam skip probability (0-1) | |||
ham_skip_prob = 0.0, -- proportional mode: ham skip probability | |||
store_pool_only = false, -- store tokens in mempool variable only (disables autotrain); | |||
store_pool_only = false, -- store tokens in cache only (disables autotrain); | |||
-- neural_vec_mpack stores vector of training data in messagepack neural_profile_digest stores profile digest | |||
}, | |||
watch_interval = 60.0, | |||
@@ -77,207 +76,9 @@ local redis_profile_schema = ts.shape{ | |||
local has_blas = rspamd_tensor.has_blas() | |||
local text_cookie = rspamd_text.cookie | |||
local settings = { | |||
rules = {}, | |||
prefix = 'rn', -- Neural network default prefix | |||
max_profiles = 3, -- Maximum number of NN profiles stored | |||
} | |||
local module_config = rspamd_config:get_all_opt("neural") | |||
if not module_config then | |||
-- Legacy | |||
module_config = rspamd_config:get_all_opt("fann_redis") | |||
end | |||
local redis_lua_script_vectors_len = [[ | |||
local prefix = KEYS[1] | |||
local locked = redis.call('HGET', prefix, 'lock') | |||
if locked then | |||
local host = redis.call('HGET', prefix, 'hostname') or 'unknown' | |||
return string.format('%s:%s', host, locked) | |||
end | |||
local nspam = 0 | |||
local nham = 0 | |||
local ret = redis.call('LLEN', prefix .. '_spam') | |||
if ret then nspam = tonumber(ret) end | |||
ret = redis.call('LLEN', prefix .. '_ham') | |||
if ret then nham = tonumber(ret) end | |||
return {nspam,nham} | |||
]] | |||
local redis_lua_script_vectors_len_id = nil | |||
local redis_lua_script_maybe_invalidate = [[ | |||
local card = redis.call('ZCARD', KEYS[1]) | |||
local lim = tonumber(KEYS[2]) | |||
if card > lim then | |||
local to_delete = redis.call('ZRANGE', KEYS[1], 0, card - lim - 1) | |||
for _,k in ipairs(to_delete) do | |||
local tb = cjson.decode(k) | |||
redis.call('DEL', tb.redis_key) | |||
-- Also train vectors | |||
redis.call('DEL', tb.redis_key .. '_spam') | |||
redis.call('DEL', tb.redis_key .. '_ham') | |||
end | |||
redis.call('ZREMRANGEBYRANK', KEYS[1], 0, card - lim - 1) | |||
return to_delete | |||
else | |||
return {} | |||
end | |||
]] | |||
local redis_maybe_invalidate_id = nil | |||
local redis_lua_script_maybe_lock = [[ | |||
local locked = redis.call('HGET', KEYS[1], 'lock') | |||
local now = tonumber(KEYS[2]) | |||
if locked then | |||
locked = tonumber(locked) | |||
local expire = tonumber(KEYS[3]) | |||
if now > locked and (now - locked) < expire then | |||
return {tostring(locked), redis.call('HGET', KEYS[1], 'hostname') or 'unknown'} | |||
end | |||
end | |||
redis.call('HSET', KEYS[1], 'lock', tostring(now)) | |||
redis.call('HSET', KEYS[1], 'hostname', KEYS[4]) | |||
return 1 | |||
]] | |||
local redis_maybe_lock_id = nil | |||
local redis_lua_script_save_unlock = [[ | |||
local now = tonumber(KEYS[6]) | |||
redis.call('ZADD', KEYS[2], now, KEYS[4]) | |||
redis.call('HSET', KEYS[1], 'ann', KEYS[3]) | |||
redis.call('DEL', KEYS[1] .. '_spam') | |||
redis.call('DEL', KEYS[1] .. '_ham') | |||
redis.call('HDEL', KEYS[1], 'lock') | |||
redis.call('HDEL', KEYS[7], 'lock') | |||
redis.call('EXPIRE', KEYS[1], tonumber(KEYS[5])) | |||
if KEYS[8] then | |||
redis.call('HSET', KEYS[1], 'pca', KEYS[8]) | |||
end | |||
return 1 | |||
]] | |||
local redis_save_unlock_id = nil | |||
local redis_params | |||
local function load_scripts(params) | |||
redis_lua_script_vectors_len_id = lua_redis.add_redis_script(redis_lua_script_vectors_len, | |||
params) | |||
redis_maybe_invalidate_id = lua_redis.add_redis_script(redis_lua_script_maybe_invalidate, | |||
params) | |||
redis_maybe_lock_id = lua_redis.add_redis_script(redis_lua_script_maybe_lock, | |||
params) | |||
redis_save_unlock_id = lua_redis.add_redis_script(redis_lua_script_save_unlock, | |||
params) | |||
end | |||
local function result_to_vector(task, profile) | |||
if not profile.zeros then | |||
-- Fill zeros vector | |||
local zeros = {} | |||
for i=1,meta_functions.count_metatokens() do | |||
zeros[i] = 0.0 | |||
end | |||
for _,_ in ipairs(profile.symbols) do | |||
zeros[#zeros + 1] = 0.0 | |||
end | |||
profile.zeros = zeros | |||
end | |||
local vec = lua_util.shallowcopy(profile.zeros) | |||
local mt = meta_functions.rspamd_gen_metatokens(task) | |||
for i,v in ipairs(mt) do | |||
vec[i] = v | |||
end | |||
task:process_ann_tokens(profile.symbols, vec, #mt, 0.1) | |||
return vec | |||
end | |||
local function new_ann_key(rule, set, version) | |||
local ann_key = string.format('%s_%s_%s_%s_%s', settings.prefix, | |||
rule.prefix, set.name, set.digest:sub(1, 8), tostring(version)) | |||
return ann_key | |||
end | |||
local function get_rule_settings(task, rule) | |||
local sid = task:get_settings_id() or -1 | |||
local set = rule.settings[sid] | |||
if not set then return nil end | |||
while type(set) == 'number' do | |||
-- Reference to another settings! | |||
set = rule.settings[set] | |||
end | |||
return set | |||
end | |||
local function redis_ann_prefix(rule, settings_name) | |||
-- We also need to count metatokens: | |||
local n = meta_functions.version | |||
return string.format('%s%d_%s_%d_%s', | |||
settings.prefix, plugin_ver, rule.prefix, n, settings_name) | |||
end | |||
-- Creates and stores ANN profile in Redis | |||
local function new_ann_profile(task, rule, set, version) | |||
local ann_key = new_ann_key(rule, set, version) | |||
local ann_key = neural_common.new_ann_key(rule, set, version, settings) | |||
local profile = { | |||
symbols = set.symbols, | |||
@@ -321,7 +122,7 @@ local function ann_scores_filter(task) | |||
local ann | |||
local profile | |||
local set = get_rule_settings(task, rule) | |||
local set = neural_common.get_rule_settings(task, rule) | |||
if set then | |||
if set.ann then | |||
ann = set.ann.ann | |||
@@ -336,7 +137,7 @@ local function ann_scores_filter(task) | |||
end | |||
if ann then | |||
local vec = result_to_vector(task, profile) | |||
local vec = neural_common.result_to_vector(task, profile) | |||
local score | |||
local out = ann:apply1(vec, set.ann.pca) | |||
@@ -357,112 +158,12 @@ local function ann_scores_filter(task) | |||
end | |||
end | |||
local function create_ann(n, nlayers, rule) | |||
-- We ignore number of layers so far when using kann | |||
local nhidden = math.floor(n * (rule.hidden_layer_mult or 1.0) + 1.0) | |||
local t = rspamd_kann.layer.input(n) | |||
t = rspamd_kann.transform.relu(t) | |||
t = rspamd_kann.layer.dense(t, nhidden); | |||
t = rspamd_kann.layer.cost(t, 1, rspamd_kann.cost.ceb_neg) | |||
return rspamd_kann.new.kann(t) | |||
end | |||
local function can_push_train_vector(rule, task, learn_type, nspam, nham) | |||
local train_opts = rule.train | |||
local coin = math.random() | |||
if train_opts.train_prob and coin < 1.0 - train_opts.train_prob then | |||
rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin) | |||
return false | |||
end | |||
if train_opts.learn_mode == 'balanced' then | |||
-- Keep balanced training set based on number of spam and ham samples | |||
if learn_type == 'spam' then | |||
if nspam <= train_opts.max_trains then | |||
if nspam > nham then | |||
-- Apply sampling | |||
local skip_rate = 1.0 - nham / (nspam + 1) | |||
if coin < skip_rate - train_opts.classes_bias then | |||
rspamd_logger.infox(task, | |||
'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored', | |||
learn_type, | |||
skip_rate - train_opts.classes_bias, | |||
nspam, nham) | |||
return false | |||
end | |||
end | |||
return true | |||
else -- Enough learns | |||
rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many spam samples: %s', | |||
learn_type, | |||
nspam) | |||
end | |||
else | |||
if nham <= train_opts.max_trains then | |||
if nham > nspam then | |||
-- Apply sampling | |||
local skip_rate = 1.0 - nspam / (nham + 1) | |||
if coin < skip_rate - train_opts.classes_bias then | |||
rspamd_logger.infox(task, | |||
'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored', | |||
learn_type, | |||
skip_rate - train_opts.classes_bias, | |||
nspam, nham) | |||
return false | |||
end | |||
end | |||
return true | |||
else | |||
rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many ham samples: %s', learn_type, | |||
nham) | |||
end | |||
end | |||
else | |||
-- Probabilistic learn mode, we just skip learn if we already have enough samples or | |||
-- if our coin drop is less than desired probability | |||
if learn_type == 'spam' then | |||
if nspam <= train_opts.max_trains then | |||
if train_opts.spam_skip_prob then | |||
if coin <= train_opts.spam_skip_prob then | |||
rspamd_logger.infox(task, 'skip %s sample probabilisticaly; probability %s (%s skip chance)', learn_type, | |||
coin, train_opts.spam_skip_prob) | |||
return false | |||
end | |||
return true | |||
end | |||
else | |||
rspamd_logger.infox(task, 'skip %s sample; too many spam samples: %s (%s limit)', learn_type, | |||
nspam, train_opts.max_trains) | |||
end | |||
else | |||
if nham <= train_opts.max_trains then | |||
if train_opts.ham_skip_prob then | |||
if coin <= train_opts.ham_skip_prob then | |||
rspamd_logger.infox(task, 'skip %s sample probabilisticaly; probability %s (%s skip chance)', learn_type, | |||
coin, train_opts.ham_skip_prob) | |||
return false | |||
end | |||
return true | |||
end | |||
else | |||
rspamd_logger.infox(task, 'skip %s sample; too many ham samples: %s (%s limit)', learn_type, | |||
nham, train_opts.max_trains) | |||
end | |||
end | |||
end | |||
return false | |||
end | |||
local function ann_push_task_result(rule, task, verdict, score, set) | |||
local train_opts = rule.train | |||
local learn_spam, learn_ham | |||
local skip_reason = 'unknown' | |||
if train_opts.autotrain then | |||
if not train_opts.store_pool_only and train_opts.autotrain then | |||
if train_opts.spam_score then | |||
learn_spam = score >= train_opts.spam_score | |||
@@ -510,10 +211,10 @@ local function ann_push_task_result(rule, task, verdict, score, set) | |||
learn_ham = false | |||
learn_spam = false | |||
-- Explicitly store tokens in a mempool variable | |||
local vec = result_to_vector(task, set) | |||
task:get_mempool():set_variable('neural_vec_mpack', ucl.to_format(vec, 'msgpack')) | |||
task:get_mempool():set_variable('neural_profile_digest', set.digest) | |||
-- Explicitly store tokens in cache | |||
local vec = neural_common.result_to_vector(task, set) | |||
task:cache_set('neural_vec_mpack', ucl.to_format(vec, 'msgpack')) | |||
task:cache_set('neural_profile_digest', set.digest) | |||
skip_reason = 'store_pool_only has been set' | |||
end | |||
end | |||
@@ -527,8 +228,8 @@ local function ann_push_task_result(rule, task, verdict, score, set) | |||
if not err and type(data) == 'table' then | |||
local nspam,nham = data[1],data[2] | |||
if can_push_train_vector(rule, task, learn_type, nspam, nham) then | |||
local vec = result_to_vector(task, set) | |||
if neural_common.can_push_train_vector(rule, task, learn_type, nspam, nham) then | |||
local vec = neural_common.result_to_vector(task, set) | |||
local str = rspamd_util.zstd_compress(table.concat(vec, ';')) | |||
local target_key = set.ann.redis_key .. '_' .. learn_type | |||
@@ -585,7 +286,7 @@ local function ann_push_task_result(rule, task, verdict, score, set) | |||
set.name) | |||
end | |||
lua_redis.exec_redis_script(redis_lua_script_vectors_len_id, | |||
lua_redis.exec_redis_script(neural_common.redis_script_id.vectors_len, | |||
{task = task, is_write = false}, | |||
vectors_len_cb, | |||
{ | |||
@@ -605,284 +306,6 @@ end | |||
--- Offline training logic | |||
local function gen_unlock_cb(rule, set, ann_key) | |||
return function (err) | |||
if err then | |||
rspamd_logger.errx(rspamd_config, 'cannot unlock ANN %s:%s at %s from redis: %s', | |||
rule.prefix, set.name, ann_key, err) | |||
else | |||
lua_util.debugm(N, rspamd_config, 'unlocked ANN %s:%s at %s', | |||
rule.prefix, set.name, ann_key) | |||
end | |||
end | |||
end | |||
local function register_lock_extender(rule, set, ev_base, ann_key) | |||
rspamd_config:add_periodic(ev_base, 30.0, | |||
function() | |||
local function redis_lock_extend_cb(_err, _) | |||
if _err then | |||
rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s', | |||
ann_key, _err) | |||
else | |||
rspamd_logger.infox(rspamd_config, 'extend lock for ANN %s for 30 seconds', | |||
ann_key) | |||
end | |||
end | |||
if set.learning_spawned then | |||
lua_redis.redis_make_request_taskless(ev_base, | |||
rspamd_config, | |||
rule.redis, | |||
nil, | |||
true, -- is write | |||
redis_lock_extend_cb, --callback | |||
'HINCRBY', -- command | |||
{ann_key, 'lock', '30'} | |||
) | |||
else | |||
lua_util.debugm(N, rspamd_config, "stop lock extension as learning_spawned is false") | |||
return false -- do not plan any more updates | |||
end | |||
return true | |||
end | |||
) | |||
end | |||
local function learn_pca(inputs, max_inputs) | |||
local scatter_matrix = rspamd_tensor.scatter_matrix(rspamd_tensor.fromtable(inputs)) | |||
local eigenvals = scatter_matrix:eigen() | |||
-- scatter matrix is not filled with eigenvectors | |||
lua_util.debugm(N, 'eigenvalues: %s', eigenvals) | |||
local w = rspamd_tensor.new(2, max_inputs, #scatter_matrix[1]) | |||
for i=1,max_inputs do | |||
w[i] = scatter_matrix[#scatter_matrix - i + 1] | |||
end | |||
lua_util.debugm(N, 'pca matrix: %s', w) | |||
return w | |||
end | |||
local function fill_set_ann(set, ann_key) | |||
if not set.ann then | |||
set.ann = { | |||
symbols = set.symbols, | |||
distance = 0, | |||
digest = set.digest, | |||
redis_key = ann_key, | |||
version = 0, | |||
} | |||
end | |||
end | |||
local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_vec) | |||
-- Check training data sanity | |||
-- Now we need to join inputs and create the appropriate test vectors | |||
local n = #set.symbols + | |||
meta_functions.rspamd_count_metatokens() | |||
-- Now we can train ann | |||
local train_ann = create_ann(rule.max_inputs or n, 3, rule) | |||
if #ham_vec + #spam_vec < rule.train.max_trains / 2 then | |||
-- Invalidate ANN as it is definitely invalid | |||
-- TODO: add invalidation | |||
assert(false) | |||
else | |||
local inputs, outputs = {}, {} | |||
-- Used to show sparsed vectors in a convenient format (for debugging only) | |||
local function debug_vec(t) | |||
local ret = {} | |||
for i,v in ipairs(t) do | |||
if v ~= 0 then | |||
ret[#ret + 1] = string.format('%d=%.2f', i, v) | |||
end | |||
end | |||
return ret | |||
end | |||
-- Make training set by joining vectors | |||
-- KANN automatically shuffles those samples | |||
-- 1.0 is used for spam and -1.0 is used for ham | |||
-- It implies that output layer can express that (e.g. tanh output) | |||
for _,e in ipairs(spam_vec) do | |||
inputs[#inputs + 1] = e | |||
outputs[#outputs + 1] = {1.0} | |||
--rspamd_logger.debugm(N, rspamd_config, 'spam vector: %s', debug_vec(e)) | |||
end | |||
for _,e in ipairs(ham_vec) do | |||
inputs[#inputs + 1] = e | |||
outputs[#outputs + 1] = {-1.0} | |||
--rspamd_logger.debugm(N, rspamd_config, 'ham vector: %s', debug_vec(e)) | |||
end | |||
-- Called in child process | |||
local function train() | |||
local log_thresh = rule.train.max_iterations / 10 | |||
local seen_nan = false | |||
local function train_cb(iter, train_cost, value_cost) | |||
if (iter * (rule.train.max_iterations / log_thresh)) % (rule.train.max_iterations) == 0 then | |||
if train_cost ~= train_cost and not seen_nan then | |||
-- We have nan :( try to log lot's of stuff to dig into a problem | |||
seen_nan = true | |||
rspamd_logger.errx(rspamd_config, 'ANN %s:%s: train error: observed nan in error cost!; value cost = %s', | |||
rule.prefix, set.name, | |||
value_cost) | |||
for i,e in ipairs(inputs) do | |||
lua_util.debugm(N, rspamd_config, 'train vector %s -> %s', | |||
debug_vec(e), outputs[i][1]) | |||
end | |||
end | |||
rspamd_logger.infox(rspamd_config, | |||
"ANN %s:%s: learned from %s redis key in %s iterations, error: %s, value cost: %s", | |||
rule.prefix, set.name, | |||
ann_key, | |||
iter, | |||
train_cost, | |||
value_cost) | |||
end | |||
end | |||
lua_util.debugm(N, rspamd_config, "subprocess to learn ANN %s:%s has been started", | |||
rule.prefix, set.name) | |||
local ret,err = pcall(train_ann.train1, train_ann, | |||
inputs, outputs, { | |||
lr = rule.train.learning_rate, | |||
max_epoch = rule.train.max_iterations, | |||
cb = train_cb, | |||
pca = (set.ann or {}).pca | |||
}) | |||
if not ret then | |||
rspamd_logger.errx(rspamd_config, "cannot train ann %s:%s: %s", | |||
rule.prefix, set.name, err) | |||
return nil | |||
end | |||
if not seen_nan then | |||
local out = train_ann:save() | |||
return out | |||
else | |||
return nil | |||
end | |||
end | |||
set.learning_spawned = true | |||
local function redis_save_cb(err) | |||
if err then | |||
rspamd_logger.errx(rspamd_config, 'cannot save ANN %s:%s to redis key %s: %s', | |||
rule.prefix, set.name, ann_key, err) | |||
lua_redis.redis_make_request_taskless(ev_base, | |||
rspamd_config, | |||
rule.redis, | |||
nil, | |||
false, -- is write | |||
gen_unlock_cb(rule, set, ann_key), --callback | |||
'HDEL', -- command | |||
{ann_key, 'lock'} | |||
) | |||
else | |||
rspamd_logger.infox(rspamd_config, 'saved ANN %s:%s to redis: %s', | |||
rule.prefix, set.name, set.ann.redis_key) | |||
end | |||
end | |||
local function ann_trained(err, data) | |||
set.learning_spawned = false | |||
if err then | |||
rspamd_logger.errx(rspamd_config, 'cannot train ANN %s:%s : %s', | |||
rule.prefix, set.name, err) | |||
lua_redis.redis_make_request_taskless(ev_base, | |||
rspamd_config, | |||
rule.redis, | |||
nil, | |||
true, -- is write | |||
gen_unlock_cb(rule, set, ann_key), --callback | |||
'HDEL', -- command | |||
{ann_key, 'lock'} | |||
) | |||
else | |||
local ann_data = rspamd_util.zstd_compress(data) | |||
local pca_data | |||
fill_set_ann(set, ann_key) | |||
if set.ann.pca then | |||
pca_data = rspamd_util.zstd_compress(set.ann.pca:save()) | |||
end | |||
-- Deserialise ANN from the child process | |||
ann_trained = rspamd_kann.load(data) | |||
local version = (set.ann.version or 0) + 1 | |||
set.ann.version = version | |||
set.ann.ann = ann_trained | |||
set.ann.symbols = set.symbols | |||
set.ann.redis_key = new_ann_key(rule, set, version) | |||
local profile = { | |||
symbols = set.symbols, | |||
digest = set.digest, | |||
redis_key = set.ann.redis_key, | |||
version = version | |||
} | |||
local ucl = require "ucl" | |||
local profile_serialized = ucl.to_format(profile, 'json-compact', true) | |||
rspamd_logger.infox(rspamd_config, | |||
'trained ANN %s:%s, %s bytes (%s compressed); %s rows in pca (%sb compressed); redis key: %s (old key %s)', | |||
rule.prefix, set.name, | |||
#data, #ann_data, | |||
#(set.ann.pca or {}), #(pca_data or {}), | |||
set.ann.redis_key, ann_key) | |||
lua_redis.exec_redis_script(redis_save_unlock_id, | |||
{ev_base = ev_base, is_write = true}, | |||
redis_save_cb, | |||
{profile.redis_key, | |||
redis_ann_prefix(rule, set.name), | |||
ann_data, | |||
profile_serialized, | |||
tostring(rule.ann_expire), | |||
tostring(os.time()), | |||
ann_key, -- old key to unlock... | |||
pca_data | |||
}) | |||
end | |||
end | |||
if rule.max_inputs then | |||
fill_set_ann(set, ann_key) | |||
-- Train PCA in the main process, presumably it is not that long | |||
set.ann.pca = learn_pca(inputs, rule.max_inputs) | |||
end | |||
worker:spawn_process{ | |||
func = train, | |||
on_complete = ann_trained, | |||
proctitle = string.format("ANN train for %s/%s", rule.prefix, set.name), | |||
} | |||
end | |||
-- Spawn learn and register lock extension | |||
set.learning_spawned = true | |||
register_lock_extender(rule, set, ev_base, ann_key) | |||
end | |||
-- Utility to extract and split saved training vectors to a table of tables | |||
local function process_training_vectors(data) | |||
return fun.totable(fun.map(function(tok) | |||
@@ -909,14 +332,16 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key) | |||
rule.redis, | |||
nil, | |||
true, -- is write | |||
gen_unlock_cb(rule, set, ann_key), --callback | |||
neural_common.gen_unlock_cb(rule, set, ann_key), --callback | |||
'HDEL', -- command | |||
{ann_key, 'lock'} | |||
) | |||
else | |||
-- Decompress and convert to numbers each training vector | |||
ham_elts = process_training_vectors(data) | |||
spawn_train(worker, ev_base, rule, set, ann_key, ham_elts, spam_elts) | |||
neural_common.spawn_train({worker = worker, ev_base = ev_base, | |||
rule = rule, set = set, ann_key = ann_key, ham_vec = ham_elts, | |||
spam_vec = spam_elts}) | |||
end | |||
end | |||
@@ -931,7 +356,7 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key) | |||
rule.redis, | |||
nil, | |||
true, -- is write | |||
gen_unlock_cb(rule, set, ann_key), --callback | |||
neural_common.gen_unlock_cb(rule, set, ann_key), --callback | |||
'HDEL', -- command | |||
{ann_key, 'lock'} | |||
) | |||
@@ -987,7 +412,7 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key) | |||
-- Call Redis script that tries to acquire a lock | |||
-- This script returns either a boolean or a pair {'lock_time', 'hostname'} when | |||
-- ANN is locked by another host (or a process, meh) | |||
lua_redis.exec_redis_script(redis_maybe_lock_id, | |||
lua_redis.exec_redis_script(neural_common.redis_script_id.maybe_lock, | |||
{ev_base = ev_base, is_write = true}, | |||
redis_lock_cb, | |||
{ | |||
@@ -1376,7 +801,7 @@ local function cleanup_anns(rule, cfg, ev_base) | |||
end | |||
if type(set) == 'table' then | |||
lua_redis.exec_redis_script(redis_maybe_invalidate_id, | |||
lua_redis.exec_redis_script(neural_common.redis_script_id.maybe_invalidate, | |||
{ev_base = ev_base, is_write = true}, | |||
invalidate_cb, | |||
{set.prefix, tostring(settings.max_profiles)}) | |||
@@ -1411,7 +836,7 @@ local function ann_push_vector(task) | |||
end | |||
for _,rule in pairs(settings.rules) do | |||
local set = get_rule_settings(task, rule) | |||
local set = neural_common.get_rule_settings(task, rule) | |||
if set then | |||
ann_push_task_result(rule, task, verdict, score, set) | |||
@@ -1423,155 +848,20 @@ local function ann_push_vector(task) | |||
end | |||
local function process_rules_settings() | |||
local function process_settings_elt(rule, selt) | |||
local profile = rule.profile[selt.name] | |||
if profile then | |||
-- Use static user defined profile | |||
-- Ensure that we have an array... | |||
lua_util.debugm(N, rspamd_config, "use static profile for %s (%s): %s", | |||
rule.prefix, selt.name, profile) | |||
if not profile[1] then profile = lua_util.keys(profile) end | |||
selt.symbols = profile | |||
else | |||
lua_util.debugm(N, rspamd_config, "use dynamic cfg based profile for %s (%s)", | |||
rule.prefix, selt.name) | |||
end | |||
local function filter_symbols_predicate(sname) | |||
if settings.blacklisted_symbols and settings.blacklisted_symbols[sname] then | |||
return false | |||
end | |||
local fl = rspamd_config:get_symbol_flags(sname) | |||
if fl then | |||
fl = lua_util.list_to_hash(fl) | |||
return not (fl.nostat or fl.idempotent or fl.skip or fl.composite) | |||
end | |||
return false | |||
end | |||
-- Generic stuff | |||
table.sort(fun.totable(fun.filter(filter_symbols_predicate, selt.symbols))) | |||
selt.digest = lua_util.table_digest(selt.symbols) | |||
selt.prefix = redis_ann_prefix(rule, selt.name) | |||
lua_redis.register_prefix(selt.prefix, N, | |||
string.format('NN prefix for rule "%s"; settings id "%s"', | |||
rule.prefix, selt.name), { | |||
persistent = true, | |||
type = 'zlist', | |||
}) | |||
-- Versions | |||
lua_redis.register_prefix(selt.prefix .. '_\\d+', N, | |||
string.format('NN storage for rule "%s"; settings id "%s"', | |||
rule.prefix, selt.name), { | |||
persistent = true, | |||
type = 'hash', | |||
}) | |||
lua_redis.register_prefix(selt.prefix .. '_\\d+_spam', N, | |||
string.format('NN learning set (spam) for rule "%s"; settings id "%s"', | |||
rule.prefix, selt.name), { | |||
persistent = true, | |||
type = 'list', | |||
}) | |||
lua_redis.register_prefix(selt.prefix .. '_\\d+_ham', N, | |||
string.format('NN learning set (spam) for rule "%s"; settings id "%s"', | |||
rule.prefix, selt.name), { | |||
persistent = true, | |||
type = 'list', | |||
}) | |||
end | |||
for k,rule in pairs(settings.rules) do | |||
if not rule.allowed_settings then | |||
rule.allowed_settings = {} | |||
elseif rule.allowed_settings == 'all' then | |||
-- Extract all settings ids | |||
rule.allowed_settings = lua_util.keys(lua_settings.all_settings()) | |||
end | |||
-- Convert to a map <setting_id> -> true | |||
rule.allowed_settings = lua_util.list_to_hash(rule.allowed_settings) | |||
-- Check if we can work without settings | |||
if k == 'default' or type(rule.default) ~= 'boolean' then | |||
rule.default = true | |||
end | |||
rule.settings = {} | |||
if rule.default then | |||
local default_settings = { | |||
symbols = lua_settings.default_symbols(), | |||
name = 'default' | |||
} | |||
process_settings_elt(rule, default_settings) | |||
rule.settings[-1] = default_settings -- Magic constant, but OK as settings are positive int32 | |||
end | |||
-- Now, for each allowed settings, we store sorted symbols + digest | |||
-- We set table rule.settings[id] -> { name = name, symbols = symbols, digest = digest } | |||
for s,_ in pairs(rule.allowed_settings) do | |||
-- Here, we have a name, set of symbols and | |||
local settings_id = s | |||
if type(settings_id) ~= 'number' then | |||
settings_id = lua_settings.numeric_settings_id(s) | |||
end | |||
local selt = lua_settings.settings_by_id(settings_id) | |||
local nelt = { | |||
symbols = selt.symbols, -- Already sorted | |||
name = selt.name | |||
} | |||
process_settings_elt(rule, nelt) | |||
for id,ex in pairs(rule.settings) do | |||
if type(ex) == 'table' then | |||
if nelt and lua_util.distance_sorted(ex.symbols, nelt.symbols) == 0 then | |||
-- Equal symbols, add reference | |||
lua_util.debugm(N, rspamd_config, | |||
'added reference from settings id %s to %s; same symbols', | |||
nelt.name, ex.name) | |||
rule.settings[settings_id] = id | |||
nelt = nil | |||
end | |||
end | |||
end | |||
if nelt then | |||
rule.settings[settings_id] = nelt | |||
lua_util.debugm(N, rspamd_config, 'added new settings id %s(%s) to %s', | |||
nelt.name, settings_id, rule.prefix) | |||
end | |||
end | |||
end | |||
end | |||
redis_params = lua_redis.parse_redis_server('neural') | |||
if not redis_params then | |||
redis_params = lua_redis.parse_redis_server('fann_redis') | |||
end | |||
-- Initialization part | |||
if not (module_config and type(module_config) == 'table') or not redis_params then | |||
if not (neural_common.module_config and type(neural_common.module_config) == 'table') | |||
or not neural_common.redis_params then | |||
rspamd_logger.infox(rspamd_config, 'Module is unconfigured') | |||
lua_util.disable_module(N, "redis") | |||
return | |||
end | |||
local rules = module_config['rules'] | |||
local rules = neural_common.module_config['rules'] | |||
if not rules then | |||
-- Use legacy configuration | |||
rules = {} | |||
rules['default'] = module_config | |||
rules['default'] = neural_common.module_config | |||
end | |||
local id = rspamd_config:register_symbol({ | |||
@@ -1582,8 +872,7 @@ local id = rspamd_config:register_symbol({ | |||
callback = ann_scores_filter | |||
}) | |||
settings = lua_util.override_defaults(settings, module_config) | |||
settings.rules = {} -- Reset unless validated further in the cycle | |||
neural_common.settings.rules = {} -- Reset unless validated further in the cycle | |||
if settings.blacklisted_symbols and settings.blacklisted_symbols[1] then | |||
-- Transform to hash for simplicity | |||
@@ -1593,7 +882,7 @@ end | |||
-- Check all rules | |||
for k,r in pairs(rules) do | |||
local rule_elt = lua_util.override_defaults(default_options, r) | |||
rule_elt['redis'] = redis_params | |||
rule_elt['redis'] = neural_common.redis_params | |||
rule_elt['anns'] = {} -- Store ANNs here | |||
if not rule_elt.prefix then | |||
@@ -1651,11 +940,12 @@ rspamd_config:register_symbol({ | |||
callback = ann_push_vector | |||
}) | |||
-- We also need to deal with settings | |||
rspamd_config:add_post_init(neural_common.process_rules_settings) | |||
-- Add training scripts | |||
for _,rule in pairs(settings.rules) do | |||
load_scripts(rule.redis) | |||
-- We also need to deal with settings | |||
rspamd_config:add_post_init(process_rules_settings) | |||
neural_common.load_scripts(rule.redis) | |||
-- This function will check ANNs in Redis when a worker is loaded | |||
rspamd_config:add_on_load(function(cfg, ev_base, worker) | |||
if worker:is_scanner() then |
@@ -0,0 +1,75 @@ | |||
*** Settings *** | |||
Suite Setup Neural Setup | |||
Suite Teardown Neural Teardown | |||
Library Process | |||
Library ${TESTDIR}/lib/rspamd.py | |||
Resource ${TESTDIR}/lib/rspamd.robot | |||
Variables ${TESTDIR}/lib/vars.py | |||
*** Variables *** | |||
${URL_TLD} ${TESTDIR}/../lua/unit/test_tld.dat | |||
${CONFIG} ${TESTDIR}/configs/neural_noauto.conf | |||
${MESSAGE} ${TESTDIR}/messages/spam_message.eml | |||
${REDIS_SCOPE} Suite | |||
${RSPAMD_SCOPE} Suite | |||
*** Test Cases *** | |||
Collect training vectors & train manually | |||
Scan File ${MESSAGE} Settings={symbols_enabled = ["SPAM_SYMBOL","SAVE_NN_ROW"]} | |||
Expect Symbol SPAM_SYMBOL | |||
# Save neural inputs for later | |||
${SPAM_ROW} = Get File ${SCAN_RESULT}[symbols][SAVE_NN_ROW][options][0] | |||
Remove File ${SCAN_RESULT}[symbols][SAVE_NN_ROW][options][0] | |||
Scan File ${MESSAGE} Settings={symbols_enabled = ["HAM_SYMBOL","SAVE_NN_ROW"]} | |||
Expect Symbol HAM_SYMBOL | |||
# Save neural inputs for later | |||
${HAM_ROW} = Get File ${SCAN_RESULT}[symbols][SAVE_NN_ROW][options][0] | |||
Remove File ${SCAN_RESULT}[symbols][SAVE_NN_ROW][options][0] | |||
${HAM_ROW} = Run ${RSPAMADM} lua -a ${HAM_ROW} ${TESTDIR}/util/nn_unpack.lua | |||
${HAM_ROW} = Evaluate json.loads("${HAM_ROW}") | |||
${SPAM_ROW} = Run ${RSPAMADM} lua -a ${SPAM_ROW} ${TESTDIR}/util/nn_unpack.lua | |||
${SPAM_ROW} = Evaluate json.loads("${SPAM_ROW}") | |||
${HAM_VEC} = Evaluate [${HAM_ROW}] * 10 | |||
${SPAM_VEC} = Evaluate [${SPAM_ROW}] * 10 | |||
${json1} = Evaluate json.dumps({"spam_vec": ${SPAM_VEC}, "ham_vec": ${HAM_VEC}, "rule": "SHORT"}) | |||
# Save variables for use in inverse training | |||
Set Suite Variable ${HAM_VEC} | |||
Set Suite Variable ${SPAM_VEC} | |||
HTTP POST ${LOCAL_ADDR} ${PORT_CONTROLLER} /plugins/neural/learn ${json1} | |||
Sleep 2s Wait for neural to be loaded | |||
Check Neural HAM | |||
Scan File ${MESSAGE} Settings={symbols_enabled = ["HAM_SYMBOL"];groups_enabled=["neural"];symbols_disabled = ["NEURAL_LEARN"]} | |||
Do Not Expect Symbol NEURAL_SPAM_SHORT | |||
Expect Symbol NEURAL_HAM_SHORT | |||
Check Neural SPAM | |||
Scan File ${MESSAGE} Settings={symbols_enabled = ["SPAM_SYMBOL"];groups_enabled=["neural"];symbols_disabled = ["NEURAL_LEARN"]} | |||
Do Not Expect Symbol NEURAL_HAM_SHORT | |||
Expect Symbol NEURAL_SPAM_SHORT | |||
Train inverse | |||
${json2} = Evaluate json.dumps({"spam_vec": ${HAM_VEC}, "ham_vec": ${SPAM_VEC}, "rule": "SHORT"}) | |||
HTTP POST ${LOCAL_ADDR} ${PORT_CONTROLLER} /plugins/neural/learn ${json2} | |||
Sleep 2s Wait for neural to be loaded | |||
Check Neural HAM - inverse | |||
Scan File ${MESSAGE} Settings={symbols_enabled = ["HAM_SYMBOL"];groups_enabled=["neural"];symbols_disabled = ["NEURAL_LEARN"]} | |||
Do Not Expect Symbol NEURAL_HAM_SHORT | |||
Expect Symbol NEURAL_SPAM_SHORT | |||
Check Neural SPAM - inverse | |||
Scan File ${MESSAGE} Settings={symbols_enabled = ["SPAM_SYMBOL"];groups_enabled=["neural"];symbols_disabled = ["NEURAL_LEARN"]} | |||
Do Not Expect Symbol NEURAL_SPAM_SHORT | |||
Expect Symbol NEURAL_HAM_SHORT | |||
*** Keywords *** | |||
Neural Setup | |||
${TMPDIR} = Make Temporary Directory | |||
Set Suite Variable ${TMPDIR} | |||
Run Redis | |||
Generic Setup | |||
Neural Teardown | |||
Shutdown Process With Children ${REDIS_PID} | |||
Normal Teardown |
@@ -0,0 +1,85 @@ | |||
# Rspamd test configuration for the neural plugin functional tests.
# ${...} placeholders are substituted by the Robot Framework harness
# before the daemon is started.
options = {
  url_tld = "${URL_TLD}"
  pidfile = "${TMPDIR}/rspamd.pid"
  lua_path = "${INSTALLROOT}/share/rspamd/lib/?.lua"
  filters = [];
  explicit_modules = ["settings"];
}
logging = {
  type = "file",
  level = "debug"
  filename = "${TMPDIR}/rspamd.log"
  log_usec = true;
}
metric = {
  name = "default",
  actions = {
    reject = 100500,
    add_header = 50500,
  }
  unknown_weight = 1
}
# Scanning worker used by "Scan File"
worker {
  type = normal
  bind_socket = ${LOCAL_ADDR}:${PORT_NORMAL}
  count = 1
  task_timeout = 10s;
}
# Controller worker used for the /plugins/neural/learn HTTP endpoint
worker {
  type = controller
  bind_socket = ${LOCAL_ADDR}:${PORT_CONTROLLER}
  count = 1
  secure_ip = ["127.0.0.1", "::1"];
  stats_path = "${TMPDIR}/stats.ucl"
}
modules {
  path = "${TESTDIR}/../../src/plugins/lua/"
}
lua = "${TESTDIR}/lua/test_coverage.lua";
neural {
  rules {
    # Basic rule exercised by the SHORT tests
    SHORT {
      train {
        learning_rate = 0.001;
        max_usages = 2;
        spam_score = 1;
        ham_score = -1;
        max_trains = 10;
        max_iterations = 250;
        # Only store vectors in the pool; training itself is triggered
        # explicitly through the controller learn endpoint by the tests.
        store_pool_only = true;
      }
      symbol_spam = "NEURAL_SPAM_SHORT";
      symbol_ham = "NEURAL_HAM_SHORT";
      ann_expire = 86400;
      watch_interval = 0.5;
    }
    # Same rule but with input dimensionality reduction (max_inputs = 2)
    SHORT_PCA {
      train {
        learning_rate = 0.001;
        max_usages = 2;
        spam_score = 1;
        ham_score = -1;
        max_trains = 10;
        max_iterations = 250;
        store_pool_only = true;
      }
      symbol_spam = "NEURAL_SPAM_SHORT_PCA";
      symbol_ham = "NEURAL_HAM_SHORT_PCA";
      ann_expire = 86400;
      watch_interval = 0.5;
      max_inputs = 2;
    }
  }
  allow_local = true;
}
redis {
  servers = "${REDIS_ADDR}:${REDIS_PORT}";
  expand_keys = true;
}
lua = "${TESTDIR}/lua/neural.lua";
@@ -209,6 +209,7 @@ Run Rspamd | |||
... ELSE Make Temporary Directory | |||
Set Directory Ownership ${tmpdir} ${RSPAMD_USER} ${RSPAMD_GROUP} | |||
${template} = Get File ${CONFIG} | |||
# TODO: stop using this; we have Lupa now | |||
FOR ${i} IN @{vargs} | |||
${newvalue} = Replace Variables ${${i}} | |||
Set To Dictionary ${d} ${i}=${newvalue} | |||
@@ -218,7 +219,8 @@ Run Rspamd | |||
Log ${config} | |||
Create File ${tmpdir}/rspamd.conf ${config} | |||
${result} = Run Process ${RSPAMD} -u ${RSPAMD_USER} -g ${RSPAMD_GROUP} | |||
... -c ${tmpdir}/rspamd.conf env:TMPDIR=${tmpdir} env:DBDIR=${tmpdir} env:LD_LIBRARY_PATH=${TESTDIR}/../../contrib/aho-corasick stdout=DEVNULL stderr=DEVNULL | |||
... -c ${tmpdir}/rspamd.conf env:TMPDIR=${tmpdir} env:DBDIR=${tmpdir} env:LD_LIBRARY_PATH=${TESTDIR}/../../contrib/aho-corasick | |||
... env:RSPAMD_INSTALLROOT=${INSTALLROOT} stdout=DEVNULL stderr=DEVNULL | |||
Run Keyword If ${result.rc} != 0 Log ${result.stderr} | |||
Should Be Equal As Integers ${result.rc} 0 | |||
Wait Until Keyword Succeeds 10x 1 sec Check Pidfile ${tmpdir}/rspamd.pid timeout=0.5s |
@@ -1,3 +1,5 @@ | |||
local logger = require "rspamd_logger" | |||
rspamd_config:register_symbol({ | |||
name = 'SPAM_SYMBOL', | |||
score = 5.0, | |||
@@ -21,4 +23,39 @@ rspamd_config:register_symbol({ | |||
callback = function() | |||
return true, 'Fires always' | |||
end | |||
}) | |||
}) | |||
-- Test symbol: reserve a temporary file name, remember it in the task
-- cache under 'nn_row_tmpfile' (read back by SAVE_NN_ROW_IDEMPOTENT) and
-- also report it as the symbol option so the test suite can locate it.
rspamd_config.SAVE_NN_ROW = {
  callback = function(task)
    local tmpfile = os.tmpname()
    task:cache_set('nn_row_tmpfile', tmpfile)
    return true, 1.0, tmpfile
  end
}
-- Idempotent-stage test symbol: if SAVE_NN_ROW reserved a temp file for
-- this task, dump the hex-encoded msgpacked neural vector (cached by the
-- neural plugin under 'neural_vec_mpack') into that file for inspection.
rspamd_config.SAVE_NN_ROW_IDEMPOTENT = {
  callback = function(task)
    local fname = task:cache_get('nn_row_tmpfile')
    if not fname then
      return
    end

    local fh, err = io.open(fname, 'w')
    if not fh then
      logger.errx(task, err)
      return
    end

    -- Hex-encode the raw vector; parentheses around gsub's result drop
    -- the substitution count it also returns.
    local raw = task:cache_get('neural_vec_mpack') or ''
    local hex = (raw:gsub('.', function(c)
      return string.format('%02X', c:byte())
    end))

    fh:write(hex)
    fh:close()
    return
  end,
  type = 'idempotent',
  flags = 'explicit_disable',
  priority = 10,
}
-- Load rspamd's bundled controller Lua entry points from the install root
-- (presumably this is what provides the /plugins/* controller endpoints,
-- e.g. /plugins/neural/learn, used by the tests — confirm against init.lua)
dofile(rspamd_env.INSTALLROOT .. "/share/rspamd/rules/controller/init.lua")
@@ -0,0 +1,16 @@ | |||
local ucl = require "ucl"

-- Test helper: decode a hex-encoded msgpack blob given as the first CLI
-- argument and print it as compact JSON on stdout. Exits with status 1
-- and a message on stderr for bad input or a msgpack parse failure.

-- Convert a hex string such as "DEADBEEF" back into raw bytes.
-- The extra parentheses drop gsub's second return value (the count).
local function unhex(str)
  return (str:gsub('..', function (cc)
    return string.char(tonumber(cc, 16))
  end))
end

-- Validate up front instead of crashing inside unhex(): a missing
-- argument, odd length, or non-hex characters would otherwise raise an
-- opaque error from gsub/string.char.
local hex = arg and arg[1]
if type(hex) ~= 'string' or #hex % 2 ~= 0 or hex:find('%X') then
  io.stderr:write('expected a hex-encoded msgpack string as the first argument\n')
  os.exit(1)
end

local parser = ucl.parser()
local ok, err = parser:parse_string(unhex(hex), 'msgpack')
if not ok then
  io.stderr:write(err)
  os.exit(1)
end

print(ucl.to_format(parser:get_object(), 'json-compact'))