]> source.dussan.org Git - rspamd.git/commitdiff
[Feature] Add controller endpoint for training neural 3570/head
authorAndrew Lewis <nerf@judo.za.org>
Thu, 17 Dec 2020 09:28:09 +0000 (11:28 +0200)
committerAndrew Lewis <nerf@judo.za.org>
Thu, 17 Dec 2020 09:28:09 +0000 (11:28 +0200)
 - Move neural functions to library
 - Parameterise spawn_train
 - neural plugin: Fix store_pool_only when autotrain is true
 - neural plugin: Use cache_set instead of mempool
 - Add test

lualib/plugins/neural.lua [new file with mode: 0644]
rules/controller/init.lua
rules/controller/neural.lua [new file with mode: 0644]
src/plugins/lua/neural.lua
test/functional/cases/330_neural.robot [deleted file]
test/functional/cases/330_neural/001_autotrain.robot [new file with mode: 0644]
test/functional/cases/330_neural/002_manualtrain.robot [new file with mode: 0644]
test/functional/configs/neural_noauto.conf [new file with mode: 0644]
test/functional/lib/rspamd.robot
test/functional/lua/neural.lua
test/functional/util/nn_unpack.lua [new file with mode: 0644]

diff --git a/lualib/plugins/neural.lua b/lualib/plugins/neural.lua
new file mode 100644 (file)
index 0000000..4d4c44b
--- /dev/null
@@ -0,0 +1,779 @@
+--[[
+Copyright (c) 2020, Vsevolod Stakhov <vsevolod@highsecure.ru>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local fun = require "fun"
+local lua_redis = require "lua_redis"
+local lua_settings = require "lua_settings"
+local lua_util = require "lua_util"
+local meta_functions = require "lua_meta"
+local rspamd_kann = require "rspamd_kann"
+local rspamd_logger = require "rspamd_logger"
+local rspamd_tensor = require "rspamd_tensor"
+local rspamd_util = require "rspamd_util"
+
+local N = 'neural'
+
+-- Used in prefix to avoid wrong ANN to be loaded
+local plugin_ver = '2'
+
+-- Module vars
+local default_options = {
+  train = {
+    max_trains = 1000,
+    max_epoch = 1000,
+    max_usages = 10,
+    max_iterations = 25, -- Torch style
+    mse = 0.001,
+    autotrain = true,
+    train_prob = 1.0,
+    learn_threads = 1,
+    learn_mode = 'balanced', -- Possible values: balanced, proportional
+    learning_rate = 0.01,
+    classes_bias = 0.0, -- balanced mode: what difference is allowed between classes (1:1 proportion means 0 bias)
+    spam_skip_prob = 0.0, -- proportional mode: spam skip probability (0-1)
+    ham_skip_prob = 0.0, -- proportional mode: ham skip probability
+    store_pool_only = false, -- store tokens in cache only (disables autotrain);
+    -- neural_vec_mpack stores vector of training data in messagepack neural_profile_digest stores profile digest
+  },
+  watch_interval = 60.0,
+  lock_expire = 600,
+  learning_spawned = false,
+  ann_expire = 60 * 60 * 24 * 2, -- 2 days
+  hidden_layer_mult = 1.5, -- number of neurons in the hidden layer
+  symbol_spam = 'NEURAL_SPAM',
+  symbol_ham = 'NEURAL_HAM',
+  max_inputs = nil, -- when PCA is used
+  blacklisted_symbols = {}, -- list of symbols skipped in neural processing
+}
+
+-- Rule structure:
+-- * static config fields (see `default_options`)
+-- * prefix - name or defined prefix
+-- * settings - table of settings indexed by settings id, -1 is used when no settings defined
+
+-- Rule settings element defines elements for specific settings id:
+-- * symbols - static symbols profile (defined by config or extracted from symcache)
+-- * name - name of settings id
+-- * digest - digest of all symbols
+-- * ann - dynamic ANN configuration loaded from Redis
+-- * train - train data for ANN (e.g. the currently trained ANN)
+
+-- Settings ANN table is loaded from Redis and represents dynamic profile for ANN
+-- Some elements are directly stored in Redis, ANN is, in turn loaded dynamically
+-- * version - version of ANN loaded from redis
+-- * redis_key - name of ANN key in Redis
+-- * symbols - symbols in THIS PARTICULAR ANN (might be different from set.symbols)
+-- * distance - distance between set.symbols and set.ann.symbols
+-- * ann - kann object
+
+local settings = {
+  rules = {},
+  prefix = 'rn', -- Neural network default prefix
+  max_profiles = 3, -- Maximum number of NN profiles stored
+}
+
+-- Get module & Redis configuration
+local module_config = rspamd_config:get_all_opt(N)
+settings = lua_util.override_defaults(settings, module_config)
+local redis_params = lua_redis.parse_redis_server('neural')
+
+-- Lua script that checks if we can store a new training vector
+-- Uses the following keys:
+-- key1 - ann key
+-- returns nspam,nham (or nil if locked)
+local redis_lua_script_vectors_len = [[
+  local prefix = KEYS[1]
+  local locked = redis.call('HGET', prefix, 'lock')
+  if locked then
+    local host = redis.call('HGET', prefix, 'hostname') or 'unknown'
+    return string.format('%s:%s', host, locked)
+  end
+  local nspam = 0
+  local nham = 0
+
+  local ret = redis.call('LLEN', prefix .. '_spam')
+  if ret then nspam = tonumber(ret) end
+  ret = redis.call('LLEN', prefix .. '_ham')
+  if ret then nham = tonumber(ret) end
+
+  return {nspam,nham}
+]]
+
+-- Lua script to invalidate ANNs by rank
+-- Uses the following keys
+-- key1 - prefix for keys
+-- key2 - number of elements to leave
+local redis_lua_script_maybe_invalidate = [[
+  local card = redis.call('ZCARD', KEYS[1])
+  local lim = tonumber(KEYS[2])
+  if card > lim then
+    local to_delete = redis.call('ZRANGE', KEYS[1], 0, card - lim - 1)
+    for _,k in ipairs(to_delete) do
+      local tb = cjson.decode(k)
+      redis.call('DEL', tb.redis_key)
+      -- Also train vectors
+      redis.call('DEL', tb.redis_key .. '_spam')
+      redis.call('DEL', tb.redis_key .. '_ham')
+    end
+    redis.call('ZREMRANGEBYRANK', KEYS[1], 0, card - lim - 1)
+    return to_delete
+  else
+    return {}
+  end
+]]
+
+-- Lua script to invalidate ANN from redis
+-- Uses the following keys
+-- key1 - prefix for keys
+-- key2 - current time
+-- key3 - key expire
+-- key4 - hostname
+local redis_lua_script_maybe_lock = [[
+  local locked = redis.call('HGET', KEYS[1], 'lock')
+  local now = tonumber(KEYS[2])
+  if locked then
+    locked = tonumber(locked)
+    local expire = tonumber(KEYS[3])
+    if now > locked and (now - locked) < expire then
+      return {tostring(locked), redis.call('HGET', KEYS[1], 'hostname') or 'unknown'}
+    end
+  end
+  redis.call('HSET', KEYS[1], 'lock', tostring(now))
+  redis.call('HSET', KEYS[1], 'hostname', KEYS[4])
+  return 1
+]]
+
+-- Lua script to save and unlock ANN in redis
+-- Uses the following keys
+-- key1 - prefix for ANN
+-- key2 - prefix for profile
+-- key3 - compressed ANN
+-- key4 - profile as JSON
+-- key5 - expire in seconds
+-- key6 - current time
+-- key7 - old key
+-- key8 - optional PCA
+local redis_lua_script_save_unlock = [[
+  local now = tonumber(KEYS[6])
+  redis.call('ZADD', KEYS[2], now, KEYS[4])
+  redis.call('HSET', KEYS[1], 'ann', KEYS[3])
+  redis.call('DEL', KEYS[1] .. '_spam')
+  redis.call('DEL', KEYS[1] .. '_ham')
+  redis.call('HDEL', KEYS[1], 'lock')
+  redis.call('HDEL', KEYS[7], 'lock')
+  redis.call('EXPIRE', KEYS[1], tonumber(KEYS[5]))
+  if KEYS[8] then
+    redis.call('HSET', KEYS[1], 'pca', KEYS[8])
+  end
+  return 1
+]]
+
+local redis_script_id = {}
+
+local function load_scripts()
+  redis_script_id.vectors_len = lua_redis.add_redis_script(redis_lua_script_vectors_len,
+    redis_params)
+  redis_script_id.maybe_invalidate = lua_redis.add_redis_script(redis_lua_script_maybe_invalidate,
+    redis_params)
+  redis_script_id.maybe_lock = lua_redis.add_redis_script(redis_lua_script_maybe_lock,
+    redis_params)
+  redis_script_id.save_unlock = lua_redis.add_redis_script(redis_lua_script_save_unlock,
+    redis_params)
+end
+
+local function create_ann(n, nlayers, rule)
+    -- We ignore number of layers so far when using kann
+  local nhidden = math.floor(n * (rule.hidden_layer_mult or 1.0) + 1.0)
+  local t = rspamd_kann.layer.input(n)
+  t = rspamd_kann.transform.relu(t)
+  t = rspamd_kann.layer.dense(t, nhidden);
+  t = rspamd_kann.layer.cost(t, 1, rspamd_kann.cost.ceb_neg)
+  return rspamd_kann.new.kann(t)
+end
+
+-- Fills ANN data for a specific settings element
+local function fill_set_ann(set, ann_key)
+  if not set.ann then
+    set.ann = {
+      symbols = set.symbols,
+      distance = 0,
+      digest = set.digest,
+      redis_key = ann_key,
+      version = 0,
+    }
+  end
+end
+
+-- This function takes all inputs, applies PCA transformation and returns the final
+-- PCA matrix as rspamd_tensor
+local function learn_pca(inputs, max_inputs)
+  local scatter_matrix = rspamd_tensor.scatter_matrix(rspamd_tensor.fromtable(inputs))
+  local eigenvals = scatter_matrix:eigen()
+  -- scatter matrix is not filled with eigenvectors
+  lua_util.debugm(N, 'eigenvalues: %s', eigenvals)
+  local w = rspamd_tensor.new(2, max_inputs, #scatter_matrix[1])
+  for i=1,max_inputs do
+    w[i] = scatter_matrix[#scatter_matrix - i + 1]
+  end
+
+  lua_util.debugm(N, 'pca matrix: %s', w)
+
+  return w
+end
+
+-- This function is intended to extend lock for ANN during training
+-- It registers periodic that increases locked key each 30 seconds unless
+-- `set.learning_spawned` is set to `true`
+local function register_lock_extender(rule, set, ev_base, ann_key)
+  rspamd_config:add_periodic(ev_base, 30.0,
+      function()
+        local function redis_lock_extend_cb(_err, _)
+          if _err then
+            rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s',
+                ann_key, _err)
+          else
+            rspamd_logger.infox(rspamd_config, 'extend lock for ANN %s for 30 seconds',
+                ann_key)
+          end
+        end
+
+        if set.learning_spawned then
+          lua_redis.redis_make_request_taskless(ev_base,
+              rspamd_config,
+              rule.redis,
+              nil,
+              true, -- is write
+              redis_lock_extend_cb, --callback
+              'HINCRBY', -- command
+              {ann_key, 'lock', '30'}
+          )
+        else
+          lua_util.debugm(N, rspamd_config, "stop lock extension as learning_spawned is false")
+          return false -- do not plan any more updates
+        end
+
+        return true
+      end
+  )
+end
+
+local function can_push_train_vector(rule, task, learn_type, nspam, nham)
+  local train_opts = rule.train
+  local coin = math.random()
+
+  if train_opts.train_prob and coin < 1.0 - train_opts.train_prob then
+    rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin)
+    return false
+  end
+
+  if train_opts.learn_mode == 'balanced' then
+    -- Keep balanced training set based on number of spam and ham samples
+    if learn_type == 'spam' then
+      if nspam <= train_opts.max_trains then
+        if nspam > nham then
+          -- Apply sampling
+          local skip_rate = 1.0 - nham / (nspam + 1)
+          if coin < skip_rate - train_opts.classes_bias then
+            rspamd_logger.infox(task,
+                'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored',
+                learn_type,
+                skip_rate - train_opts.classes_bias,
+                nspam, nham)
+            return false
+          end
+        end
+        return true
+      else -- Enough learns
+        rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many spam samples: %s',
+            learn_type,
+            nspam)
+      end
+    else
+      if nham <= train_opts.max_trains then
+        if nham > nspam then
+          -- Apply sampling
+          local skip_rate = 1.0 - nspam / (nham + 1)
+          if coin < skip_rate - train_opts.classes_bias then
+            rspamd_logger.infox(task,
+                'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored',
+                learn_type,
+                skip_rate - train_opts.classes_bias,
+                nspam, nham)
+            return false
+          end
+        end
+        return true
+      else
+        rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many ham samples: %s', learn_type,
+            nham)
+      end
+    end
+  else
+    -- Probabilistic learn mode, we just skip learn if we already have enough samples or
+    -- if our coin drop is less than desired probability
+    if learn_type == 'spam' then
+      if nspam <= train_opts.max_trains then
+        if train_opts.spam_skip_prob then
+          if coin <= train_opts.spam_skip_prob then
+            rspamd_logger.infox(task, 'skip %s sample probabilisticaly; probability %s (%s skip chance)', learn_type,
+                coin, train_opts.spam_skip_prob)
+            return false
+          end
+
+          return true
+        end
+      else
+        rspamd_logger.infox(task, 'skip %s sample; too many spam samples: %s (%s limit)', learn_type,
+            nspam, train_opts.max_trains)
+      end
+    else
+      if nham <= train_opts.max_trains then
+        if train_opts.ham_skip_prob then
+          if coin <= train_opts.ham_skip_prob then
+            rspamd_logger.infox(task, 'skip %s sample probabilisticaly; probability %s (%s skip chance)', learn_type,
+                coin, train_opts.ham_skip_prob)
+            return false
+          end
+
+          return true
+        end
+      else
+        rspamd_logger.infox(task, 'skip %s sample; too many ham samples: %s (%s limit)', learn_type,
+            nham, train_opts.max_trains)
+      end
+    end
+  end
+
+  return false
+end
+
+-- Closure generator for unlock function
+local function gen_unlock_cb(rule, set, ann_key)
+  return function (err)
+    if err then
+      rspamd_logger.errx(rspamd_config, 'cannot unlock ANN %s:%s at %s from redis: %s',
+          rule.prefix, set.name, ann_key, err)
+    else
+      lua_util.debugm(N, rspamd_config, 'unlocked ANN %s:%s at %s',
+          rule.prefix, set.name, ann_key)
+    end
+  end
+end
+
+-- Used to generate new ANN key for specific profile
+local function new_ann_key(rule, set, version)
+  local ann_key = string.format('%s_%s_%s_%s_%s', settings.prefix,
+      rule.prefix, set.name, set.digest:sub(1, 8), tostring(version))
+
+  return ann_key
+end
+
+local function redis_ann_prefix(rule, settings_name)
+  -- We also need to count metatokens:
+  local n = meta_functions.version
+  return string.format('%s%d_%s_%d_%s',
+    settings.prefix, plugin_ver, rule.prefix, n, settings_name)
+end
+
+-- This function receives training vectors, checks them, spawn learning and saves ANN in Redis
+local function spawn_train(params)
+  -- Check training data sanity
+  -- Now we need to join inputs and create the appropriate test vectors
+  local n = #params.set.symbols +
+      meta_functions.rspamd_count_metatokens()
+
+  -- Now we can train ann
+  local train_ann = create_ann(params.rule.max_inputs or n, 3, params.rule)
+
+  if #params.ham_vec + #params.spam_vec < params.rule.train.max_trains / 2 then
+    -- Invalidate ANN as it is definitely invalid
+    -- TODO: add invalidation
+    assert(false)
+  else
+    local inputs, outputs = {}, {}
+
+    -- Used to show sparsed vectors in a convenient format (for debugging only)
+    local function debug_vec(t)
+      local ret = {}
+      for i,v in ipairs(t) do
+        if v ~= 0 then
+          ret[#ret + 1] = string.format('%d=%.2f', i, v)
+        end
+      end
+
+      return ret
+    end
+
+    -- Make training set by joining vectors
+    -- KANN automatically shuffles those samples
+    -- 1.0 is used for spam and -1.0 is used for ham
+    -- It implies that output layer can express that (e.g. tanh output)
+    for _,e in ipairs(params.spam_vec) do
+      inputs[#inputs + 1] = e
+      outputs[#outputs + 1] = {1.0}
+      --rspamd_logger.debugm(N, rspamd_config, 'spam vector: %s', debug_vec(e))
+    end
+    for _,e in ipairs(params.ham_vec) do
+      inputs[#inputs + 1] = e
+      outputs[#outputs + 1] = {-1.0}
+      --rspamd_logger.debugm(N, rspamd_config, 'ham vector: %s', debug_vec(e))
+    end
+
+    -- Called in child process
+    local function train()
+      local log_thresh = params.rule.train.max_iterations / 10
+      local seen_nan = false
+
+      local function train_cb(iter, train_cost, value_cost)
+        if (iter * (params.rule.train.max_iterations / log_thresh)) % (params.rule.train.max_iterations) == 0 then
+          if train_cost ~= train_cost and not seen_nan then
+            -- We have nan :( try to log lot's of stuff to dig into a problem
+            seen_nan = true
+            rspamd_logger.errx(rspamd_config, 'ANN %s:%s: train error: observed nan in error cost!; value cost = %s',
+                params.rule.prefix, params.set.name,
+                value_cost)
+            for i,e in ipairs(inputs) do
+              lua_util.debugm(N, rspamd_config, 'train vector %s -> %s',
+                  debug_vec(e), outputs[i][1])
+            end
+          end
+
+          rspamd_logger.infox(rspamd_config,
+              "ANN %s:%s: learned from %s redis key in %s iterations, error: %s, value cost: %s",
+              params.rule.prefix, params.set.name,
+              params.ann_key,
+              iter,
+              train_cost,
+              value_cost)
+        end
+      end
+
+      lua_util.debugm(N, rspamd_config, "subprocess to learn ANN %s:%s has been started",
+          params.rule.prefix, params.set.name)
+
+      local ret,err = pcall(train_ann.train1, train_ann,
+          inputs, outputs, {
+            lr = params.rule.train.learning_rate,
+            max_epoch = params.rule.train.max_iterations,
+            cb = train_cb,
+            pca = (params.set.ann or {}).pca
+          })
+
+      if not ret then
+        rspamd_logger.errx(rspamd_config, "cannot train ann %s:%s: %s",
+            params.rule.prefix, params.set.name, err)
+
+        return nil
+      end
+
+      if not seen_nan then
+        local out = train_ann:save()
+        return out
+      else
+        return nil
+      end
+    end
+
+    params.set.learning_spawned = true
+
+    local function redis_save_cb(err)
+      if err then
+        rspamd_logger.errx(rspamd_config, 'cannot save ANN %s:%s to redis key %s: %s',
+            params.rule.prefix, params.set.name, params.ann_key, err)
+        lua_redis.redis_make_request_taskless(params.ev_base,
+            rspamd_config,
+            params.rule.redis,
+            nil,
+            false, -- is write
+            gen_unlock_cb(params.rule, params.set, params.ann_key), --callback
+            'HDEL', -- command
+            {params.ann_key, 'lock'}
+        )
+      else
+        rspamd_logger.infox(rspamd_config, 'saved ANN %s:%s to redis: %s',
+            params.rule.prefix, params.set.name, params.set.ann.redis_key)
+      end
+    end
+
+    local function ann_trained(err, data)
+      params.set.learning_spawned = false
+      if err then
+        rspamd_logger.errx(rspamd_config, 'cannot train ANN %s:%s : %s',
+            params.rule.prefix, params.set.name, err)
+        lua_redis.redis_make_request_taskless(params.ev_base,
+            rspamd_config,
+            params.rule.redis,
+            nil,
+            true, -- is write
+            gen_unlock_cb(params.rule, params.set, params.ann_key), --callback
+            'HDEL', -- command
+            {params.ann_key, 'lock'}
+        )
+      else
+        local ann_data = rspamd_util.zstd_compress(data)
+        local pca_data
+
+        fill_set_ann(params.set, params.ann_key)
+        if params.set.ann.pca then
+          pca_data = rspamd_util.zstd_compress(params.set.ann.pca:save())
+        end
+        -- Deserialise ANN from the child process
+        ann_trained = rspamd_kann.load(data)
+        local version = (params.set.ann.version or 0) + 1
+        params.set.ann.version = version
+        params.set.ann.ann = ann_trained
+        params.set.ann.symbols = params.set.symbols
+        params.set.ann.redis_key = new_ann_key(params.rule, params.set, version)
+
+        local profile = {
+          symbols = params.set.symbols,
+          digest = params.set.digest,
+          redis_key = params.set.ann.redis_key,
+          version = version
+        }
+
+        local ucl = require "ucl"
+        local profile_serialized = ucl.to_format(profile, 'json-compact', true)
+
+        rspamd_logger.infox(rspamd_config,
+            'trained ANN %s:%s, %s bytes (%s compressed); %s rows in pca (%sb compressed); redis key: %s (old key %s)',
+            params.rule.prefix, params.set.name,
+            #data, #ann_data,
+            #(params.set.ann.pca or {}), #(pca_data or {}),
+            params.set.ann.redis_key, params.ann_key)
+
+        lua_redis.exec_redis_script(redis_script_id.save_unlock,
+            {ev_base = params.ev_base, is_write = true},
+            redis_save_cb,
+            {profile.redis_key,
+             redis_ann_prefix(params.rule, params.set.name),
+             ann_data,
+             profile_serialized,
+             tostring(params.rule.ann_expire),
+             tostring(os.time()),
+             params.ann_key, -- old key to unlock...
+             pca_data
+            })
+      end
+    end
+
+    if params.rule.max_inputs then
+      fill_set_ann(params.set, params.ann_key)
+      -- Train PCA in the main process, presumably it is not that long
+      params.set.ann.pca = learn_pca(inputs, params.rule.max_inputs)
+    end
+
+    params.worker:spawn_process{
+      func = train,
+      on_complete = ann_trained,
+      proctitle = string.format("ANN train for %s/%s", params.rule.prefix, params.set.name),
+    }
+    -- Spawn learn and register lock extension
+    params.set.learning_spawned = true
+    register_lock_extender(params.rule, params.set, params.ev_base, params.ann_key)
+    return
+
+  end
+end
+
+-- This function is used to adjust profiles and allowed setting ids for each rule
+-- It must be called when all settings are already registered (e.g. at post-init for config)
+local function process_rules_settings()
+  local function process_settings_elt(rule, selt)
+    local profile = rule.profile[selt.name]
+    if profile then
+      -- Use static user defined profile
+      -- Ensure that we have an array...
+      lua_util.debugm(N, rspamd_config, "use static profile for %s (%s): %s",
+          rule.prefix, selt.name, profile)
+      if not profile[1] then profile = lua_util.keys(profile) end
+      selt.symbols = profile
+    else
+      lua_util.debugm(N, rspamd_config, "use dynamic cfg based profile for %s (%s)",
+          rule.prefix, selt.name)
+    end
+
+    local function filter_symbols_predicate(sname)
+      if settings.blacklisted_symbols and settings.blacklisted_symbols[sname] then
+        return false
+      end
+      local fl = rspamd_config:get_symbol_flags(sname)
+      if fl then
+        fl = lua_util.list_to_hash(fl)
+
+        return not (fl.nostat or fl.idempotent or fl.skip or fl.composite)
+      end
+
+      return false
+    end
+
+    -- Generic stuff
+    table.sort(fun.totable(fun.filter(filter_symbols_predicate, selt.symbols)))
+
+    selt.digest = lua_util.table_digest(selt.symbols)
+    selt.prefix = redis_ann_prefix(rule, selt.name)
+
+    lua_redis.register_prefix(selt.prefix, N,
+        string.format('NN prefix for rule "%s"; settings id "%s"',
+            rule.prefix, selt.name), {
+          persistent = true,
+          type = 'zlist',
+        })
+    -- Versions
+    lua_redis.register_prefix(selt.prefix .. '_\\d+', N,
+        string.format('NN storage for rule "%s"; settings id "%s"',
+            rule.prefix, selt.name), {
+          persistent = true,
+          type = 'hash',
+        })
+    lua_redis.register_prefix(selt.prefix .. '_\\d+_spam', N,
+        string.format('NN learning set (spam) for rule "%s"; settings id "%s"',
+            rule.prefix, selt.name), {
+          persistent = true,
+          type = 'list',
+        })
+    lua_redis.register_prefix(selt.prefix .. '_\\d+_ham', N,
+        string.format('NN learning set (spam) for rule "%s"; settings id "%s"',
+            rule.prefix, selt.name), {
+          persistent = true,
+          type = 'list',
+        })
+  end
+
+  for k,rule in pairs(settings.rules) do
+    if not rule.allowed_settings then
+      rule.allowed_settings = {}
+    elseif rule.allowed_settings == 'all' then
+      -- Extract all settings ids
+      rule.allowed_settings = lua_util.keys(lua_settings.all_settings())
+    end
+
+    -- Convert to a map <setting_id> -> true
+    rule.allowed_settings = lua_util.list_to_hash(rule.allowed_settings)
+
+    -- Check if we can work without settings
+    if k == 'default' or type(rule.default) ~= 'boolean' then
+      rule.default = true
+    end
+
+    rule.settings = {}
+
+    if rule.default then
+      local default_settings = {
+        symbols = lua_settings.default_symbols(),
+        name = 'default'
+      }
+
+      process_settings_elt(rule, default_settings)
+      rule.settings[-1] = default_settings -- Magic constant, but OK as settings are positive int32
+    end
+
+    -- Now, for each allowed settings, we store sorted symbols + digest
+    -- We set table rule.settings[id] -> { name = name, symbols = symbols, digest = digest }
+    for s,_ in pairs(rule.allowed_settings) do
+      -- Here, we have a name, set of symbols and
+      local settings_id = s
+      if type(settings_id) ~= 'number' then
+        settings_id = lua_settings.numeric_settings_id(s)
+      end
+      local selt = lua_settings.settings_by_id(settings_id)
+
+      local nelt = {
+        symbols = selt.symbols, -- Already sorted
+        name = selt.name
+      }
+
+      process_settings_elt(rule, nelt)
+      for id,ex in pairs(rule.settings) do
+        if type(ex) == 'table' then
+          if nelt and lua_util.distance_sorted(ex.symbols, nelt.symbols) == 0 then
+            -- Equal symbols, add reference
+            lua_util.debugm(N, rspamd_config,
+                'added reference from settings id %s to %s; same symbols',
+                nelt.name, ex.name)
+            rule.settings[settings_id] = id
+            nelt = nil
+          end
+        end
+      end
+
+      if nelt then
+        rule.settings[settings_id] = nelt
+        lua_util.debugm(N, rspamd_config, 'added new settings id %s(%s) to %s',
+            nelt.name, settings_id, rule.prefix)
+      end
+    end
+  end
+end
+
+-- Extract settings element for a specific settings id
+local function get_rule_settings(task, rule)
+  local sid = task:get_settings_id() or -1
+  local set = rule.settings[sid]
+
+  if not set then return nil end
+
+  while type(set) == 'number' do
+    -- Reference to another settings!
+    set = rule.settings[set]
+  end
+
+  return set
+end
+
+local function result_to_vector(task, profile)
+  if not profile.zeros then
+    -- Fill zeros vector
+    local zeros = {}
+    for i=1,meta_functions.count_metatokens() do
+      zeros[i] = 0.0
+    end
+    for _,_ in ipairs(profile.symbols) do
+      zeros[#zeros + 1] = 0.0
+    end
+    profile.zeros = zeros
+  end
+
+  local vec = lua_util.shallowcopy(profile.zeros)
+  local mt = meta_functions.rspamd_gen_metatokens(task)
+
+  for i,v in ipairs(mt) do
+    vec[i] = v
+  end
+
+  task:process_ann_tokens(profile.symbols, vec, #mt, 0.1)
+
+  return vec
+end
+
+return {
+  can_push_train_vector = can_push_train_vector,
+  create_ann = create_ann,
+  default_options = default_options,
+  gen_unlock_cb = gen_unlock_cb,
+  get_rule_settings = get_rule_settings,
+  load_scripts = load_scripts,
+  module_config = module_config,
+  new_ann_key = new_ann_key,
+  plugin_ver = plugin_ver,
+  process_rules_settings = process_rules_settings,
+  redis_ann_prefix = redis_ann_prefix,
+  redis_params = redis_params,
+  redis_script_id = redis_script_id,
+  result_to_vector = result_to_vector,
+  settings = settings,
+  spawn_train = spawn_train,
+}
index e5204da633c8c90c7d20331fec4e1779f42ffe6d..136081ddcc5cbb5dd2ce787499b33519f9dd6cd3 100644 (file)
@@ -25,8 +25,9 @@ local rspamd_logger = require "rspamd_logger"
 -- Define default controller paths, could be overridden in local.d/controller.lua
 
 local controller_plugin_paths = {
+  maps = dofile(local_rules .. "/controller/maps.lua"),
+  neural = dofile(local_rules .. "/controller/neural.lua"),
   selectors = dofile(local_rules .. "/controller/selectors.lua"),
-  maps = dofile(local_rules .. "/controller/maps.lua")
 }
 
 if rspamd_util.file_exists(local_conf .. '/controller.lua') then
@@ -62,4 +63,4 @@ for plug,paths in pairs(controller_plugin_paths) do
         plug, path, type(attrs))
     end
   end
-end
\ No newline at end of file
+end
diff --git a/rules/controller/neural.lua b/rules/controller/neural.lua
new file mode 100644 (file)
index 0000000..3207e00
--- /dev/null
@@ -0,0 +1,72 @@
+--[[
+Copyright (c) 2020, Vsevolod Stakhov <vsevolod@highsecure.ru>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local neural_common = require "plugins/neural"
+local ts = require("tableshape").types
+local ucl = require "ucl"
+
+local E = {}
+
+-- Controller neural plugin
+
+local learn_request_schema = ts.shape{
+  ham_vec = ts.array_of(ts.array_of(ts.number)),
+  rule = ts.string:is_optional(),
+  spam_vec = ts.array_of(ts.array_of(ts.number)),
+}
+
+local function handle_learn(task, conn)
+  local parser = ucl.parser()
+  local ok, err = parser:parse_text(task:get_rawbody())
+  if not ok then
+    conn:send_error(400, err)
+    return
+  end
+  local req_params = parser:get_object()
+
+  ok, err = learn_request_schema:transform(req_params)
+  if not ok then
+    conn:send_error(400, err)
+    return
+  end
+
+  local rule_name = req_params.rule or 'default'
+  local rule = neural_common.settings.rules[rule_name]
+  local set = neural_common.get_rule_settings(task, rule)
+  local version = ((set.ann or E).version or 0) + 1
+
+  neural_common.spawn_train{
+    ev_base = task:get_ev_base(),
+    ann_key = neural_common.new_ann_key(rule, set, version),
+    set = set,
+    rule = rule,
+    ham_vec = req_params.ham_vec,
+    spam_vec = req_params.spam_vec,
+    worker = task:get_worker(),
+  }
+
+  conn:send_string('{"success" : true}')
+end
+
+rspamd_config:add_post_init(neural_common.process_rules_settings)
+
+return {
+  learn = {
+    handler = handle_learn,
+    enable = true,
+    need_task = true,
+  },
+}
index 5eab75d7677b17ca6317005a390fcbcd089ef3e4..3d1c387a5045db8d45c7bd0f98ef7dbaa45554b9 100644 (file)
@@ -19,22 +19,21 @@ if confighelp then
   return
 end
 
-local rspamd_logger = require "rspamd_logger"
-local rspamd_util = require "rspamd_util"
-local rspamd_kann = require "rspamd_kann"
-local rspamd_text = require "rspamd_text"
+local fun = require "fun"
 local lua_redis = require "lua_redis"
 local lua_util = require "lua_util"
+local lua_verdict = require "lua_verdict"
+local neural_common = require "plugins/neural"
+local rspamd_kann = require "rspamd_kann"
+local rspamd_logger = require "rspamd_logger"
 local rspamd_tensor = require "rspamd_tensor"
-local fun = require "fun"
-local lua_settings = require "lua_settings"
-local meta_functions = require "lua_meta"
+local rspamd_text = require "rspamd_text"
+local rspamd_util = require "rspamd_util"
 local ts = require("tableshape").types
-local lua_verdict = require "lua_verdict"
+
 local N = "neural"
 
--- Used in prefix to avoid wrong ANN to be loaded
-local plugin_ver = '2'
+local settings = neural_common.settings
 
 -- Module vars
 local default_options = {
@@ -52,7 +51,7 @@ local default_options = {
     classes_bias = 0.0, -- balanced mode: what difference is allowed between classes (1:1 proportion means 0 bias)
     spam_skip_prob = 0.0, -- proportional mode: spam skip probability (0-1)
     ham_skip_prob = 0.0, -- proportional mode: ham skip probability
-    store_pool_only = false, -- store tokens in mempool variable only (disables autotrain);
+    store_pool_only = false, -- store tokens in cache only (disables autotrain);
     -- neural_vec_mpack stores vector of training data in messagepack neural_profile_digest stores profile digest
   },
   watch_interval = 60.0,
@@ -77,207 +76,9 @@ local redis_profile_schema = ts.shape{
 local has_blas = rspamd_tensor.has_blas()
 local text_cookie = rspamd_text.cookie
 
--- Rule structure:
--- * static config fields (see `default_options`)
--- * prefix - name or defined prefix
--- * settings - table of settings indexed by settings id, -1 is used when no settings defined
-
--- Rule settings element defines elements for specific settings id:
--- * symbols - static symbols profile (defined by config or extracted from symcache)
--- * name - name of settings id
--- * digest - digest of all symbols
--- * ann - dynamic ANN configuration loaded from Redis
--- * train - train data for ANN (e.g. the currently trained ANN)
-
--- Settings ANN table is loaded from Redis and represents dynamic profile for ANN
--- Some elements are directly stored in Redis, ANN is, in turn loaded dynamically
--- * version - version of ANN loaded from redis
--- * redis_key - name of ANN key in Redis
--- * symbols - symbols in THIS PARTICULAR ANN (might be different from set.symbols)
--- * distance - distance between set.symbols and set.ann.symbols
--- * ann - kann object
-
-local settings = {
-  rules = {},
-  prefix = 'rn', -- Neural network default prefix
-  max_profiles = 3, -- Maximum number of NN profiles stored
-}
-
-local module_config = rspamd_config:get_all_opt("neural")
-if not module_config then
-  -- Legacy
-  module_config = rspamd_config:get_all_opt("fann_redis")
-end
-
-
--- Lua script that checks if we can store a new training vector
--- Uses the following keys:
--- key1 - ann key
--- returns nspam,nham (or nil if locked)
-local redis_lua_script_vectors_len = [[
-  local prefix = KEYS[1]
-  local locked = redis.call('HGET', prefix, 'lock')
-  if locked then
-    local host = redis.call('HGET', prefix, 'hostname') or 'unknown'
-    return string.format('%s:%s', host, locked)
-  end
-  local nspam = 0
-  local nham = 0
-
-  local ret = redis.call('LLEN', prefix .. '_spam')
-  if ret then nspam = tonumber(ret) end
-  ret = redis.call('LLEN', prefix .. '_ham')
-  if ret then nham = tonumber(ret) end
-
-  return {nspam,nham}
-]]
-local redis_lua_script_vectors_len_id = nil
-
--- Lua script to invalidate ANNs by rank
--- Uses the following keys
--- key1 - prefix for keys
--- key2 - number of elements to leave
-local redis_lua_script_maybe_invalidate = [[
-  local card = redis.call('ZCARD', KEYS[1])
-  local lim = tonumber(KEYS[2])
-  if card > lim then
-    local to_delete = redis.call('ZRANGE', KEYS[1], 0, card - lim - 1)
-    for _,k in ipairs(to_delete) do
-      local tb = cjson.decode(k)
-      redis.call('DEL', tb.redis_key)
-      -- Also train vectors
-      redis.call('DEL', tb.redis_key .. '_spam')
-      redis.call('DEL', tb.redis_key .. '_ham')
-    end
-    redis.call('ZREMRANGEBYRANK', KEYS[1], 0, card - lim - 1)
-    return to_delete
-  else
-    return {}
-  end
-]]
-local redis_maybe_invalidate_id = nil
-
--- Lua script to invalidate ANN from redis
--- Uses the following keys
--- key1 - prefix for keys
--- key2 - current time
--- key3 - key expire
--- key4 - hostname
-local redis_lua_script_maybe_lock = [[
-  local locked = redis.call('HGET', KEYS[1], 'lock')
-  local now = tonumber(KEYS[2])
-  if locked then
-    locked = tonumber(locked)
-    local expire = tonumber(KEYS[3])
-    if now > locked and (now - locked) < expire then
-      return {tostring(locked), redis.call('HGET', KEYS[1], 'hostname') or 'unknown'}
-    end
-  end
-  redis.call('HSET', KEYS[1], 'lock', tostring(now))
-  redis.call('HSET', KEYS[1], 'hostname', KEYS[4])
-  return 1
-]]
-local redis_maybe_lock_id = nil
-
--- Lua script to save and unlock ANN in redis
--- Uses the following keys
--- key1 - prefix for ANN
--- key2 - prefix for profile
--- key3 - compressed ANN
--- key4 - profile as JSON
--- key5 - expire in seconds
--- key6 - current time
--- key7 - old key
--- key8 - optional PCA
-local redis_lua_script_save_unlock = [[
-  local now = tonumber(KEYS[6])
-  redis.call('ZADD', KEYS[2], now, KEYS[4])
-  redis.call('HSET', KEYS[1], 'ann', KEYS[3])
-  redis.call('DEL', KEYS[1] .. '_spam')
-  redis.call('DEL', KEYS[1] .. '_ham')
-  redis.call('HDEL', KEYS[1], 'lock')
-  redis.call('HDEL', KEYS[7], 'lock')
-  redis.call('EXPIRE', KEYS[1], tonumber(KEYS[5]))
-  if KEYS[8] then
-    redis.call('HSET', KEYS[1], 'pca', KEYS[8])
-  end
-  return 1
-]]
-local redis_save_unlock_id = nil
-
-local redis_params
-
-local function load_scripts(params)
-  redis_lua_script_vectors_len_id = lua_redis.add_redis_script(redis_lua_script_vectors_len,
-    params)
-  redis_maybe_invalidate_id = lua_redis.add_redis_script(redis_lua_script_maybe_invalidate,
-    params)
-  redis_maybe_lock_id = lua_redis.add_redis_script(redis_lua_script_maybe_lock,
-    params)
-  redis_save_unlock_id = lua_redis.add_redis_script(redis_lua_script_save_unlock,
-    params)
-end
-
-local function result_to_vector(task, profile)
-  if not profile.zeros then
-    -- Fill zeros vector
-    local zeros = {}
-    for i=1,meta_functions.count_metatokens() do
-      zeros[i] = 0.0
-    end
-    for _,_ in ipairs(profile.symbols) do
-      zeros[#zeros + 1] = 0.0
-    end
-    profile.zeros = zeros
-  end
-
-  local vec = lua_util.shallowcopy(profile.zeros)
-  local mt = meta_functions.rspamd_gen_metatokens(task)
-
-  for i,v in ipairs(mt) do
-    vec[i] = v
-  end
-
-  task:process_ann_tokens(profile.symbols, vec, #mt, 0.1)
-
-  return vec
-end
-
--- Used to generate new ANN key for specific profile
-local function new_ann_key(rule, set, version)
-  local ann_key = string.format('%s_%s_%s_%s_%s', settings.prefix,
-      rule.prefix, set.name, set.digest:sub(1, 8), tostring(version))
-
-  return ann_key
-end
-
--- Extract settings element for a specific settings id
-local function get_rule_settings(task, rule)
-  local sid = task:get_settings_id() or -1
-
-  local set = rule.settings[sid]
-
-  if not set then return nil end
-
-  while type(set) == 'number' do
-    -- Reference to another settings!
-    set = rule.settings[set]
-  end
-
-  return set
-end
-
--- Generate redis prefix for specific rule and specific settings
-local function redis_ann_prefix(rule, settings_name)
-  -- We also need to count metatokens:
-  local n = meta_functions.version
-  return string.format('%s%d_%s_%d_%s',
-    settings.prefix, plugin_ver, rule.prefix, n, settings_name)
-end
-
 -- Creates and stores ANN profile in Redis
 local function new_ann_profile(task, rule, set, version)
-  local ann_key = new_ann_key(rule, set, version)
+  local ann_key = neural_common.new_ann_key(rule, set, version, settings)
 
   local profile = {
     symbols = set.symbols,
@@ -321,7 +122,7 @@ local function ann_scores_filter(task)
     local ann
     local profile
 
-    local set = get_rule_settings(task, rule)
+    local set = neural_common.get_rule_settings(task, rule)
     if set then
       if set.ann then
         ann = set.ann.ann
@@ -336,7 +137,7 @@ local function ann_scores_filter(task)
     end
 
     if ann then
-      local vec = result_to_vector(task, profile)
+      local vec = neural_common.result_to_vector(task, profile)
 
       local score
       local out = ann:apply1(vec, set.ann.pca)
@@ -357,112 +158,12 @@ local function ann_scores_filter(task)
   end
 end
 
-local function create_ann(n, nlayers, rule)
-    -- We ignore number of layers so far when using kann
-  local nhidden = math.floor(n * (rule.hidden_layer_mult or 1.0) + 1.0)
-  local t = rspamd_kann.layer.input(n)
-  t = rspamd_kann.transform.relu(t)
-  t = rspamd_kann.layer.dense(t, nhidden);
-  t = rspamd_kann.layer.cost(t, 1, rspamd_kann.cost.ceb_neg)
-  return rspamd_kann.new.kann(t)
-end
-
-local function can_push_train_vector(rule, task, learn_type, nspam, nham)
-  local train_opts = rule.train
-  local coin = math.random()
-
-  if train_opts.train_prob and coin < 1.0 - train_opts.train_prob then
-    rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin)
-    return false
-  end
-
-  if train_opts.learn_mode == 'balanced' then
-    -- Keep balanced training set based on number of spam and ham samples
-    if learn_type == 'spam' then
-      if nspam <= train_opts.max_trains then
-        if nspam > nham then
-          -- Apply sampling
-          local skip_rate = 1.0 - nham / (nspam + 1)
-          if coin < skip_rate - train_opts.classes_bias then
-            rspamd_logger.infox(task,
-                'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored',
-                learn_type,
-                skip_rate - train_opts.classes_bias,
-                nspam, nham)
-            return false
-          end
-        end
-        return true
-      else -- Enough learns
-        rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many spam samples: %s',
-            learn_type,
-            nspam)
-      end
-    else
-      if nham <= train_opts.max_trains then
-        if nham > nspam then
-          -- Apply sampling
-          local skip_rate = 1.0 - nspam / (nham + 1)
-          if coin < skip_rate - train_opts.classes_bias then
-            rspamd_logger.infox(task,
-                'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored',
-                learn_type,
-                skip_rate - train_opts.classes_bias,
-                nspam, nham)
-            return false
-          end
-        end
-        return true
-      else
-        rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many ham samples: %s', learn_type,
-            nham)
-      end
-    end
-  else
-    -- Probabilistic learn mode, we just skip learn if we already have enough samples or
-    -- if our coin drop is less than desired probability
-    if learn_type == 'spam' then
-      if nspam <= train_opts.max_trains then
-        if train_opts.spam_skip_prob then
-          if coin <= train_opts.spam_skip_prob then
-            rspamd_logger.infox(task, 'skip %s sample probabilisticaly; probability %s (%s skip chance)', learn_type,
-                coin, train_opts.spam_skip_prob)
-            return false
-          end
-
-          return true
-        end
-      else
-        rspamd_logger.infox(task, 'skip %s sample; too many spam samples: %s (%s limit)', learn_type,
-            nspam, train_opts.max_trains)
-      end
-    else
-      if nham <= train_opts.max_trains then
-        if train_opts.ham_skip_prob then
-          if coin <= train_opts.ham_skip_prob then
-            rspamd_logger.infox(task, 'skip %s sample probabilisticaly; probability %s (%s skip chance)', learn_type,
-                coin, train_opts.ham_skip_prob)
-            return false
-          end
-
-          return true
-        end
-      else
-        rspamd_logger.infox(task, 'skip %s sample; too many ham samples: %s (%s limit)', learn_type,
-            nham, train_opts.max_trains)
-      end
-    end
-  end
-
-  return false
-end
-
 local function ann_push_task_result(rule, task, verdict, score, set)
   local train_opts = rule.train
   local learn_spam, learn_ham
   local skip_reason = 'unknown'
 
-  if train_opts.autotrain then
+  if not train_opts.store_pool_only and train_opts.autotrain then
     if train_opts.spam_score then
       learn_spam = score >= train_opts.spam_score
 
@@ -510,10 +211,10 @@ local function ann_push_task_result(rule, task, verdict, score, set)
       learn_ham = false
       learn_spam = false
 
-      -- Explicitly store tokens in a mempool variable
-      local vec = result_to_vector(task, set)
-      task:get_mempool():set_variable('neural_vec_mpack', ucl.to_format(vec, 'msgpack'))
-      task:get_mempool():set_variable('neural_profile_digest', set.digest)
+      -- Explicitly store tokens in cache
+      local vec = neural_common.result_to_vector(task, set)
+      task:cache_set('neural_vec_mpack', ucl.to_format(vec, 'msgpack'))
+      task:cache_set('neural_profile_digest', set.digest)
       skip_reason = 'store_pool_only has been set'
     end
   end
@@ -527,8 +228,8 @@ local function ann_push_task_result(rule, task, verdict, score, set)
       if not err and type(data) == 'table' then
         local nspam,nham = data[1],data[2]
 
-        if can_push_train_vector(rule, task, learn_type, nspam, nham) then
-          local vec = result_to_vector(task, set)
+        if neural_common.can_push_train_vector(rule, task, learn_type, nspam, nham) then
+          local vec = neural_common.result_to_vector(task, set)
 
           local str = rspamd_util.zstd_compress(table.concat(vec, ';'))
           local target_key = set.ann.redis_key .. '_' .. learn_type
@@ -585,7 +286,7 @@ local function ann_push_task_result(rule, task, verdict, score, set)
             set.name)
       end
 
-      lua_redis.exec_redis_script(redis_lua_script_vectors_len_id,
+      lua_redis.exec_redis_script(neural_common.redis_script_id.vectors_len,
           {task = task, is_write = false},
           vectors_len_cb,
           {
@@ -605,284 +306,6 @@ end
 
 --- Offline training logic
 
--- Closure generator for unlock function
-local function gen_unlock_cb(rule, set, ann_key)
-  return function (err)
-    if err then
-      rspamd_logger.errx(rspamd_config, 'cannot unlock ANN %s:%s at %s from redis: %s',
-          rule.prefix, set.name, ann_key, err)
-    else
-      lua_util.debugm(N, rspamd_config, 'unlocked ANN %s:%s at %s',
-          rule.prefix, set.name, ann_key)
-    end
-  end
-end
-
--- This function is intended to extend lock for ANN during training
--- It registers periodic that increases locked key each 30 seconds unless
--- `set.learning_spawned` is set to `true`
-local function register_lock_extender(rule, set, ev_base, ann_key)
-  rspamd_config:add_periodic(ev_base, 30.0,
-      function()
-        local function redis_lock_extend_cb(_err, _)
-          if _err then
-            rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s',
-                ann_key, _err)
-          else
-            rspamd_logger.infox(rspamd_config, 'extend lock for ANN %s for 30 seconds',
-                ann_key)
-          end
-        end
-
-        if set.learning_spawned then
-          lua_redis.redis_make_request_taskless(ev_base,
-              rspamd_config,
-              rule.redis,
-              nil,
-              true, -- is write
-              redis_lock_extend_cb, --callback
-              'HINCRBY', -- command
-              {ann_key, 'lock', '30'}
-          )
-        else
-          lua_util.debugm(N, rspamd_config, "stop lock extension as learning_spawned is false")
-          return false -- do not plan any more updates
-        end
-
-        return true
-      end
-  )
-end
-
--- This function takes all inputs, applies PCA transformation and returns the final
--- PCA matrix as rspamd_tensor
-local function learn_pca(inputs, max_inputs)
-  local scatter_matrix = rspamd_tensor.scatter_matrix(rspamd_tensor.fromtable(inputs))
-  local eigenvals = scatter_matrix:eigen()
-  -- scatter matrix is not filled with eigenvectors
-  lua_util.debugm(N, 'eigenvalues: %s', eigenvals)
-  local w = rspamd_tensor.new(2, max_inputs, #scatter_matrix[1])
-  for i=1,max_inputs do
-    w[i] = scatter_matrix[#scatter_matrix - i + 1]
-  end
-
-  lua_util.debugm(N, 'pca matrix: %s', w)
-
-  return w
-end
-
--- Fills ANN data for a specific settings element
-local function fill_set_ann(set, ann_key)
-  if not set.ann then
-    set.ann = {
-      symbols = set.symbols,
-      distance = 0,
-      digest = set.digest,
-      redis_key = ann_key,
-      version = 0,
-    }
-  end
-end
-
--- This function receives training vectors, checks them, spawn learning and saves ANN in Redis
-local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_vec)
-  -- Check training data sanity
-  -- Now we need to join inputs and create the appropriate test vectors
-  local n = #set.symbols +
-      meta_functions.rspamd_count_metatokens()
-
-  -- Now we can train ann
-  local train_ann = create_ann(rule.max_inputs or n, 3, rule)
-
-  if #ham_vec + #spam_vec < rule.train.max_trains / 2 then
-    -- Invalidate ANN as it is definitely invalid
-    -- TODO: add invalidation
-    assert(false)
-  else
-    local inputs, outputs = {}, {}
-
-    -- Used to show sparsed vectors in a convenient format (for debugging only)
-    local function debug_vec(t)
-      local ret = {}
-      for i,v in ipairs(t) do
-        if v ~= 0 then
-          ret[#ret + 1] = string.format('%d=%.2f', i, v)
-        end
-      end
-
-      return ret
-    end
-
-    -- Make training set by joining vectors
-    -- KANN automatically shuffles those samples
-    -- 1.0 is used for spam and -1.0 is used for ham
-    -- It implies that output layer can express that (e.g. tanh output)
-    for _,e in ipairs(spam_vec) do
-      inputs[#inputs + 1] = e
-      outputs[#outputs + 1] = {1.0}
-      --rspamd_logger.debugm(N, rspamd_config, 'spam vector: %s', debug_vec(e))
-    end
-    for _,e in ipairs(ham_vec) do
-      inputs[#inputs + 1] = e
-      outputs[#outputs + 1] = {-1.0}
-      --rspamd_logger.debugm(N, rspamd_config, 'ham vector: %s', debug_vec(e))
-    end
-
-    -- Called in child process
-    local function train()
-      local log_thresh = rule.train.max_iterations / 10
-      local seen_nan = false
-
-      local function train_cb(iter, train_cost, value_cost)
-        if (iter * (rule.train.max_iterations / log_thresh)) % (rule.train.max_iterations) == 0 then
-          if train_cost ~= train_cost and not seen_nan then
-            -- We have nan :( try to log lot's of stuff to dig into a problem
-            seen_nan = true
-            rspamd_logger.errx(rspamd_config, 'ANN %s:%s: train error: observed nan in error cost!; value cost = %s',
-                rule.prefix, set.name,
-                value_cost)
-            for i,e in ipairs(inputs) do
-              lua_util.debugm(N, rspamd_config, 'train vector %s -> %s',
-                  debug_vec(e), outputs[i][1])
-            end
-          end
-
-          rspamd_logger.infox(rspamd_config,
-              "ANN %s:%s: learned from %s redis key in %s iterations, error: %s, value cost: %s",
-              rule.prefix, set.name,
-              ann_key,
-              iter,
-              train_cost,
-              value_cost)
-        end
-      end
-
-      lua_util.debugm(N, rspamd_config, "subprocess to learn ANN %s:%s has been started",
-          rule.prefix, set.name)
-
-      local ret,err = pcall(train_ann.train1, train_ann,
-          inputs, outputs, {
-            lr = rule.train.learning_rate,
-            max_epoch = rule.train.max_iterations,
-            cb = train_cb,
-            pca = (set.ann or {}).pca
-          })
-
-      if not ret then
-        rspamd_logger.errx(rspamd_config, "cannot train ann %s:%s: %s",
-            rule.prefix, set.name, err)
-
-        return nil
-      end
-
-      if not seen_nan then
-        local out = train_ann:save()
-        return out
-      else
-        return nil
-      end
-    end
-
-    set.learning_spawned = true
-
-    local function redis_save_cb(err)
-      if err then
-        rspamd_logger.errx(rspamd_config, 'cannot save ANN %s:%s to redis key %s: %s',
-            rule.prefix, set.name, ann_key, err)
-        lua_redis.redis_make_request_taskless(ev_base,
-            rspamd_config,
-            rule.redis,
-            nil,
-            false, -- is write
-            gen_unlock_cb(rule, set, ann_key), --callback
-            'HDEL', -- command
-            {ann_key, 'lock'}
-        )
-      else
-        rspamd_logger.infox(rspamd_config, 'saved ANN %s:%s to redis: %s',
-            rule.prefix, set.name, set.ann.redis_key)
-      end
-    end
-
-    local function ann_trained(err, data)
-      set.learning_spawned = false
-      if err then
-        rspamd_logger.errx(rspamd_config, 'cannot train ANN %s:%s : %s',
-            rule.prefix, set.name, err)
-        lua_redis.redis_make_request_taskless(ev_base,
-            rspamd_config,
-            rule.redis,
-            nil,
-            true, -- is write
-            gen_unlock_cb(rule, set, ann_key), --callback
-            'HDEL', -- command
-            {ann_key, 'lock'}
-        )
-      else
-        local ann_data = rspamd_util.zstd_compress(data)
-        local pca_data
-
-        fill_set_ann(set, ann_key)
-        if set.ann.pca then
-          pca_data = rspamd_util.zstd_compress(set.ann.pca:save())
-        end
-        -- Deserialise ANN from the child process
-        ann_trained = rspamd_kann.load(data)
-        local version = (set.ann.version or 0) + 1
-        set.ann.version = version
-        set.ann.ann = ann_trained
-        set.ann.symbols = set.symbols
-        set.ann.redis_key = new_ann_key(rule, set, version)
-
-        local profile = {
-          symbols = set.symbols,
-          digest = set.digest,
-          redis_key = set.ann.redis_key,
-          version = version
-        }
-
-        local ucl = require "ucl"
-        local profile_serialized = ucl.to_format(profile, 'json-compact', true)
-
-        rspamd_logger.infox(rspamd_config,
-            'trained ANN %s:%s, %s bytes (%s compressed); %s rows in pca (%sb compressed); redis key: %s (old key %s)',
-            rule.prefix, set.name,
-            #data, #ann_data,
-            #(set.ann.pca or {}), #(pca_data or {}),
-            set.ann.redis_key, ann_key)
-
-        lua_redis.exec_redis_script(redis_save_unlock_id,
-            {ev_base = ev_base, is_write = true},
-            redis_save_cb,
-            {profile.redis_key,
-             redis_ann_prefix(rule, set.name),
-             ann_data,
-             profile_serialized,
-             tostring(rule.ann_expire),
-             tostring(os.time()),
-             ann_key, -- old key to unlock...
-             pca_data
-            })
-      end
-    end
-
-    if rule.max_inputs then
-      fill_set_ann(set, ann_key)
-      -- Train PCA in the main process, presumably it is not that long
-      set.ann.pca = learn_pca(inputs, rule.max_inputs)
-    end
-
-    worker:spawn_process{
-      func = train,
-      on_complete = ann_trained,
-      proctitle = string.format("ANN train for %s/%s", rule.prefix, set.name),
-    }
-  end
-  -- Spawn learn and register lock extension
-  set.learning_spawned = true
-  register_lock_extender(rule, set, ev_base, ann_key)
-end
-
 -- Utility to extract and split saved training vectors to a table of tables
 local function process_training_vectors(data)
   return fun.totable(fun.map(function(tok)
@@ -909,14 +332,16 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key)
         rule.redis,
         nil,
         true, -- is write
-          gen_unlock_cb(rule, set, ann_key), --callback
+          neural_common.gen_unlock_cb(rule, set, ann_key), --callback
         'HDEL', -- command
         {ann_key, 'lock'}
       )
     else
       -- Decompress and convert to numbers each training vector
       ham_elts = process_training_vectors(data)
-      spawn_train(worker, ev_base, rule, set, ann_key, ham_elts, spam_elts)
+      neural_common.spawn_train({worker = worker, ev_base = ev_base,
+          rule = rule, set = set, ann_key = ann_key, ham_vec = ham_elts,
+          spam_vec = spam_elts})
     end
   end
 
@@ -931,7 +356,7 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key)
         rule.redis,
         nil,
         true, -- is write
-          gen_unlock_cb(rule, set, ann_key), --callback
+          neural_common.gen_unlock_cb(rule, set, ann_key), --callback
         'HDEL', -- command
         {ann_key, 'lock'}
       )
@@ -987,7 +412,7 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key)
   -- Call Redis script that tries to acquire a lock
   -- This script returns either a boolean or a pair {'lock_time', 'hostname'} when
   -- ANN is locked by another host (or a process, meh)
-  lua_redis.exec_redis_script(redis_maybe_lock_id,
+  lua_redis.exec_redis_script(neural_common.redis_script_id.maybe_lock,
     {ev_base = ev_base, is_write = true},
     redis_lock_cb,
       {
@@ -1376,7 +801,7 @@ local function cleanup_anns(rule, cfg, ev_base)
     end
 
     if type(set) == 'table' then
-      lua_redis.exec_redis_script(redis_maybe_invalidate_id,
+      lua_redis.exec_redis_script(neural_common.redis_script_id.maybe_invalidate,
           {ev_base = ev_base, is_write = true},
           invalidate_cb,
           {set.prefix, tostring(settings.max_profiles)})
@@ -1411,7 +836,7 @@ local function ann_push_vector(task)
   end
 
   for _,rule in pairs(settings.rules) do
-    local set = get_rule_settings(task, rule)
+    local set = neural_common.get_rule_settings(task, rule)
 
     if set then
       ann_push_task_result(rule, task, verdict, score, set)
@@ -1423,155 +848,20 @@ local function ann_push_vector(task)
 end
 
 
--- This function is used to adjust profiles and allowed setting ids for each rule
--- It must be called when all settings are already registered (e.g. at post-init for config)
-local function process_rules_settings()
-  local function process_settings_elt(rule, selt)
-    local profile = rule.profile[selt.name]
-    if profile then
-      -- Use static user defined profile
-      -- Ensure that we have an array...
-      lua_util.debugm(N, rspamd_config, "use static profile for %s (%s): %s",
-          rule.prefix, selt.name, profile)
-      if not profile[1] then profile = lua_util.keys(profile) end
-      selt.symbols = profile
-    else
-      lua_util.debugm(N, rspamd_config, "use dynamic cfg based profile for %s (%s)",
-          rule.prefix, selt.name)
-    end
-
-    local function filter_symbols_predicate(sname)
-      if settings.blacklisted_symbols and settings.blacklisted_symbols[sname] then
-        return false
-      end
-      local fl = rspamd_config:get_symbol_flags(sname)
-      if fl then
-        fl = lua_util.list_to_hash(fl)
-
-        return not (fl.nostat or fl.idempotent or fl.skip or fl.composite)
-      end
-
-      return false
-    end
-
-    -- Generic stuff
-    table.sort(fun.totable(fun.filter(filter_symbols_predicate, selt.symbols)))
-
-    selt.digest = lua_util.table_digest(selt.symbols)
-    selt.prefix = redis_ann_prefix(rule, selt.name)
-
-    lua_redis.register_prefix(selt.prefix, N,
-        string.format('NN prefix for rule "%s"; settings id "%s"',
-            rule.prefix, selt.name), {
-          persistent = true,
-          type = 'zlist',
-        })
-    -- Versions
-    lua_redis.register_prefix(selt.prefix .. '_\\d+', N,
-        string.format('NN storage for rule "%s"; settings id "%s"',
-            rule.prefix, selt.name), {
-          persistent = true,
-          type = 'hash',
-        })
-    lua_redis.register_prefix(selt.prefix .. '_\\d+_spam', N,
-        string.format('NN learning set (spam) for rule "%s"; settings id "%s"',
-            rule.prefix, selt.name), {
-          persistent = true,
-          type = 'list',
-        })
-    lua_redis.register_prefix(selt.prefix .. '_\\d+_ham', N,
-        string.format('NN learning set (spam) for rule "%s"; settings id "%s"',
-            rule.prefix, selt.name), {
-          persistent = true,
-          type = 'list',
-        })
-  end
-
-  for k,rule in pairs(settings.rules) do
-    if not rule.allowed_settings then
-      rule.allowed_settings = {}
-    elseif rule.allowed_settings == 'all' then
-      -- Extract all settings ids
-      rule.allowed_settings = lua_util.keys(lua_settings.all_settings())
-    end
-
-    -- Convert to a map <setting_id> -> true
-    rule.allowed_settings = lua_util.list_to_hash(rule.allowed_settings)
-
-    -- Check if we can work without settings
-    if k == 'default' or type(rule.default) ~= 'boolean' then
-      rule.default = true
-    end
-
-    rule.settings = {}
-
-    if rule.default then
-      local default_settings = {
-        symbols = lua_settings.default_symbols(),
-        name = 'default'
-      }
-
-      process_settings_elt(rule, default_settings)
-      rule.settings[-1] = default_settings -- Magic constant, but OK as settings are positive int32
-    end
-
-    -- Now, for each allowed settings, we store sorted symbols + digest
-    -- We set table rule.settings[id] -> { name = name, symbols = symbols, digest = digest }
-    for s,_ in pairs(rule.allowed_settings) do
-      -- Here, we have a name, set of symbols and
-      local settings_id = s
-      if type(settings_id) ~= 'number' then
-        settings_id = lua_settings.numeric_settings_id(s)
-      end
-      local selt = lua_settings.settings_by_id(settings_id)
-
-      local nelt = {
-        symbols = selt.symbols, -- Already sorted
-        name = selt.name
-      }
-
-      process_settings_elt(rule, nelt)
-      for id,ex in pairs(rule.settings) do
-        if type(ex) == 'table' then
-          if nelt and lua_util.distance_sorted(ex.symbols, nelt.symbols) == 0 then
-            -- Equal symbols, add reference
-            lua_util.debugm(N, rspamd_config,
-                'added reference from settings id %s to %s; same symbols',
-                nelt.name, ex.name)
-            rule.settings[settings_id] = id
-            nelt = nil
-          end
-        end
-      end
-
-      if nelt then
-        rule.settings[settings_id] = nelt
-        lua_util.debugm(N, rspamd_config, 'added new settings id %s(%s) to %s',
-            nelt.name, settings_id, rule.prefix)
-      end
-    end
-  end
-end
-
-redis_params = lua_redis.parse_redis_server('neural')
-
-if not redis_params then
-  redis_params = lua_redis.parse_redis_server('fann_redis')
-end
-
 -- Initialization part
-if not (module_config and type(module_config) == 'table') or not redis_params then
+if not (neural_common.module_config and type(neural_common.module_config) == 'table')
+    or not neural_common.redis_params then
   rspamd_logger.infox(rspamd_config, 'Module is unconfigured')
   lua_util.disable_module(N, "redis")
   return
 end
 
-local rules = module_config['rules']
+local rules = neural_common.module_config['rules']
 
 if not rules then
   -- Use legacy configuration
   rules = {}
-  rules['default'] = module_config
+  rules['default'] = neural_common.module_config
 end
 
 local id = rspamd_config:register_symbol({
@@ -1582,8 +872,7 @@ local id = rspamd_config:register_symbol({
   callback = ann_scores_filter
 })
 
-settings = lua_util.override_defaults(settings, module_config)
-settings.rules = {} -- Reset unless validated further in the cycle
+neural_common.settings.rules = {} -- Reset unless validated further in the cycle
 
 if settings.blacklisted_symbols and settings.blacklisted_symbols[1] then
   -- Transform to hash for simplicity
@@ -1593,7 +882,7 @@ end
 -- Check all rules
 for k,r in pairs(rules) do
   local rule_elt = lua_util.override_defaults(default_options, r)
-  rule_elt['redis'] = redis_params
+  rule_elt['redis'] = neural_common.redis_params
   rule_elt['anns'] = {} -- Store ANNs here
 
   if not rule_elt.prefix then
@@ -1651,11 +940,12 @@ rspamd_config:register_symbol({
   callback = ann_push_vector
 })
 
+-- We also need to deal with settings
+rspamd_config:add_post_init(neural_common.process_rules_settings)
+
 -- Add training scripts
 for _,rule in pairs(settings.rules) do
-  load_scripts(rule.redis)
-  -- We also need to deal with settings
-  rspamd_config:add_post_init(process_rules_settings)
+  neural_common.load_scripts(rule.redis)
   -- This function will check ANNs in Redis when a worker is loaded
   rspamd_config:add_on_load(function(cfg, ev_base, worker)
     if worker:is_scanner() then
diff --git a/test/functional/cases/330_neural.robot b/test/functional/cases/330_neural.robot
deleted file mode 100644 (file)
index 8ce3428..0000000
+++ /dev/null
@@ -1,74 +0,0 @@
-*** Settings ***
-Suite Setup      Neural Setup
-Suite Teardown   Neural Teardown
-Library         Process
-Library         ${TESTDIR}/lib/rspamd.py
-Resource        ${TESTDIR}/lib/rspamd.robot
-Variables       ${TESTDIR}/lib/vars.py
-
-*** Variables ***
-${URL_TLD}      ${TESTDIR}/../lua/unit/test_tld.dat
-${CONFIG}       ${TESTDIR}/configs/neural.conf
-${MESSAGE}      ${TESTDIR}/messages/spam_message.eml
-${REDIS_SCOPE}  Suite
-${RSPAMD_SCOPE}  Suite
-
-*** Test Cases ***
-Train
-  Sleep  2s  Wait for redis mess
-  FOR    ${INDEX}    IN RANGE    0    10
-    Scan File  ${MESSAGE}  Settings={symbols_enabled = ["SPAM_SYMBOL"]}
-    Expect Symbol  SPAM_SYMBOL
-    Scan File  ${MESSAGE}  Settings={symbols_enabled = ["HAM_SYMBOL"]}
-    Expect Symbol  HAM_SYMBOL
-  END
-
-Check Neural HAM
-  Sleep  2s  Wait for neural to be loaded
-  Scan File  ${MESSAGE}  Settings={symbols_enabled = ["HAM_SYMBOL"];groups_enabled=["neural"];symbols_disabled = ["NEURAL_LEARN"]}
-  Expect Symbol  NEURAL_HAM_SHORT
-  Do Not Expect Symbol  NEURAL_SPAM_SHORT
-  Expect Symbol  NEURAL_HAM_SHORT_PCA
-  Do Not Expect Symbol  NEURAL_SPAM_SHORT_PCA
-
-Check Neural SPAM
-  Scan File  ${MESSAGE}  Settings={symbols_enabled = ["SPAM_SYMBOL"];groups_enabled=["neural"];symbols_disabled = ["NEURAL_LEARN"]}
-  Expect Symbol  NEURAL_SPAM_SHORT
-  Do Not Expect Symbol  NEURAL_HAM_SHORT
-  Expect Symbol  NEURAL_SPAM_SHORT_PCA
-  Do Not Expect Symbol  NEURAL_HAM_SHORT_PCA
-
-
-Train INVERSE
-  FOR    ${INDEX}    IN RANGE    0    10
-    Scan File  ${MESSAGE}  Settings={symbols_enabled = ["SPAM_SYMBOL"]; SPAM_SYMBOL = -5;}
-    Expect Symbol  SPAM_SYMBOL
-    Scan File  ${MESSAGE}  Settings={symbols_enabled = ["HAM_SYMBOL"]; HAM_SYMBOL = 5;}
-    Expect Symbol  HAM_SYMBOL
-  END
-
-Check Neural HAM INVERSE
-  Sleep  2s  Wait for neural to be loaded
-  Scan File  ${MESSAGE}  Settings={symbols_enabled = ["HAM_SYMBOL"];groups_enabled=["neural"]}
-  Expect Symbol  NEURAL_SPAM_SHORT
-  Expect Symbol  NEURAL_SPAM_SHORT_PCA
-  Do Not Expect Symbol  NEURAL_HAM_SHORT
-  Do Not Expect Symbol  NEURAL_HAM_SHORT_PCA
-
-Check Neural SPAM INVERSE
-  Scan File  ${MESSAGE}  Settings={symbols_enabled = ["SPAM_SYMBOL"];groups_enabled=["neural"]}
-  Expect Symbol  NEURAL_HAM_SHORT
-  Expect Symbol  NEURAL_HAM_SHORT_PCA
-  Do Not Expect Symbol  NEURAL_SPAM_SHORT
-  Do Not Expect Symbol  NEURAL_SPAM_SHORT_PCA
-
-*** Keywords ***
-Neural Setup
-  ${TMPDIR} =    Make Temporary Directory
-  Set Suite Variable        ${TMPDIR}
-  Run Redis
-  Generic Setup
-
-Neural Teardown
-  Shutdown Process With Children  ${REDIS_PID}
-  Normal Teardown
diff --git a/test/functional/cases/330_neural/001_autotrain.robot b/test/functional/cases/330_neural/001_autotrain.robot
new file mode 100644 (file)
index 0000000..8ce3428
--- /dev/null
@@ -0,0 +1,74 @@
+*** Settings ***
+Suite Setup      Neural Setup
+Suite Teardown   Neural Teardown
+Library         Process
+Library         ${TESTDIR}/lib/rspamd.py
+Resource        ${TESTDIR}/lib/rspamd.robot
+Variables       ${TESTDIR}/lib/vars.py
+
+*** Variables ***
+${URL_TLD}      ${TESTDIR}/../lua/unit/test_tld.dat
+${CONFIG}       ${TESTDIR}/configs/neural.conf
+${MESSAGE}      ${TESTDIR}/messages/spam_message.eml
+${REDIS_SCOPE}  Suite
+${RSPAMD_SCOPE}  Suite
+
+*** Test Cases ***
+Train
+  Sleep  2s  Wait for redis mess
+  FOR    ${INDEX}    IN RANGE    0    10
+    Scan File  ${MESSAGE}  Settings={symbols_enabled = ["SPAM_SYMBOL"]}
+    Expect Symbol  SPAM_SYMBOL
+    Scan File  ${MESSAGE}  Settings={symbols_enabled = ["HAM_SYMBOL"]}
+    Expect Symbol  HAM_SYMBOL
+  END
+
+Check Neural HAM
+  Sleep  2s  Wait for neural to be loaded
+  Scan File  ${MESSAGE}  Settings={symbols_enabled = ["HAM_SYMBOL"];groups_enabled=["neural"];symbols_disabled = ["NEURAL_LEARN"]}
+  Expect Symbol  NEURAL_HAM_SHORT
+  Do Not Expect Symbol  NEURAL_SPAM_SHORT
+  Expect Symbol  NEURAL_HAM_SHORT_PCA
+  Do Not Expect Symbol  NEURAL_SPAM_SHORT_PCA
+
+Check Neural SPAM
+  Scan File  ${MESSAGE}  Settings={symbols_enabled = ["SPAM_SYMBOL"];groups_enabled=["neural"];symbols_disabled = ["NEURAL_LEARN"]}
+  Expect Symbol  NEURAL_SPAM_SHORT
+  Do Not Expect Symbol  NEURAL_HAM_SHORT
+  Expect Symbol  NEURAL_SPAM_SHORT_PCA
+  Do Not Expect Symbol  NEURAL_HAM_SHORT_PCA
+
+
+Train INVERSE
+  FOR    ${INDEX}    IN RANGE    0    10
+    Scan File  ${MESSAGE}  Settings={symbols_enabled = ["SPAM_SYMBOL"]; SPAM_SYMBOL = -5;}
+    Expect Symbol  SPAM_SYMBOL
+    Scan File  ${MESSAGE}  Settings={symbols_enabled = ["HAM_SYMBOL"]; HAM_SYMBOL = 5;}
+    Expect Symbol  HAM_SYMBOL
+  END
+
+Check Neural HAM INVERSE
+  Sleep  2s  Wait for neural to be loaded
+  Scan File  ${MESSAGE}  Settings={symbols_enabled = ["HAM_SYMBOL"];groups_enabled=["neural"]}
+  Expect Symbol  NEURAL_SPAM_SHORT
+  Expect Symbol  NEURAL_SPAM_SHORT_PCA
+  Do Not Expect Symbol  NEURAL_HAM_SHORT
+  Do Not Expect Symbol  NEURAL_HAM_SHORT_PCA
+
+Check Neural SPAM INVERSE
+  Scan File  ${MESSAGE}  Settings={symbols_enabled = ["SPAM_SYMBOL"];groups_enabled=["neural"]}
+  Expect Symbol  NEURAL_HAM_SHORT
+  Expect Symbol  NEURAL_HAM_SHORT_PCA
+  Do Not Expect Symbol  NEURAL_SPAM_SHORT
+  Do Not Expect Symbol  NEURAL_SPAM_SHORT_PCA
+
+*** Keywords ***
+Neural Setup
+  ${TMPDIR} =    Make Temporary Directory
+  Set Suite Variable        ${TMPDIR}
+  Run Redis
+  Generic Setup
+
+Neural Teardown
+  Shutdown Process With Children  ${REDIS_PID}
+  Normal Teardown
diff --git a/test/functional/cases/330_neural/002_manualtrain.robot b/test/functional/cases/330_neural/002_manualtrain.robot
new file mode 100644 (file)
index 0000000..22a4212
--- /dev/null
@@ -0,0 +1,75 @@
+*** Settings ***
+Suite Setup      Neural Setup
+Suite Teardown   Neural Teardown
+Library         Process
+Library         ${TESTDIR}/lib/rspamd.py
+Resource        ${TESTDIR}/lib/rspamd.robot
+Variables       ${TESTDIR}/lib/vars.py
+
+*** Variables ***
+${URL_TLD}      ${TESTDIR}/../lua/unit/test_tld.dat
+${CONFIG}       ${TESTDIR}/configs/neural_noauto.conf
+${MESSAGE}      ${TESTDIR}/messages/spam_message.eml
+${REDIS_SCOPE}  Suite
+${RSPAMD_SCOPE}  Suite
+
+*** Test Cases ***
+Collect training vectors & train manually
+  Scan File  ${MESSAGE}  Settings={symbols_enabled = ["SPAM_SYMBOL","SAVE_NN_ROW"]}
+  Expect Symbol  SPAM_SYMBOL
+  # Save neural inputs for later
+  ${SPAM_ROW} =  Get File  ${SCAN_RESULT}[symbols][SAVE_NN_ROW][options][0]
+  Remove File  ${SCAN_RESULT}[symbols][SAVE_NN_ROW][options][0]
+  Scan File  ${MESSAGE}  Settings={symbols_enabled = ["HAM_SYMBOL","SAVE_NN_ROW"]}
+  Expect Symbol  HAM_SYMBOL
+  # Save neural inputs for later
+  ${HAM_ROW} =  Get File  ${SCAN_RESULT}[symbols][SAVE_NN_ROW][options][0]
+  Remove File  ${SCAN_RESULT}[symbols][SAVE_NN_ROW][options][0]
+  ${HAM_ROW} =  Run  ${RSPAMADM} lua -a ${HAM_ROW} ${TESTDIR}/util/nn_unpack.lua
+  ${HAM_ROW} =  Evaluate  json.loads("${HAM_ROW}")
+  ${SPAM_ROW} =  Run  ${RSPAMADM} lua -a ${SPAM_ROW} ${TESTDIR}/util/nn_unpack.lua
+  ${SPAM_ROW} =  Evaluate  json.loads("${SPAM_ROW}")
+  ${HAM_VEC} =  Evaluate  [${HAM_ROW}] * 10
+  ${SPAM_VEC} =  Evaluate  [${SPAM_ROW}] * 10
+  ${json1} =  Evaluate  json.dumps({"spam_vec": ${SPAM_VEC}, "ham_vec": ${HAM_VEC}, "rule": "SHORT"})
+  # Save variables for use in inverse training
+  Set Suite Variable  ${HAM_VEC}
+  Set Suite Variable  ${SPAM_VEC}
+  HTTP  POST  ${LOCAL_ADDR}  ${PORT_CONTROLLER}  /plugins/neural/learn  ${json1}
+  Sleep  2s  Wait for neural to be loaded
+
+Check Neural HAM
+  Scan File  ${MESSAGE}  Settings={symbols_enabled = ["HAM_SYMBOL"];groups_enabled=["neural"];symbols_disabled = ["NEURAL_LEARN"]}
+  Do Not Expect Symbol  NEURAL_SPAM_SHORT
+  Expect Symbol  NEURAL_HAM_SHORT
+
+Check Neural SPAM
+  Scan File  ${MESSAGE}  Settings={symbols_enabled = ["SPAM_SYMBOL"];groups_enabled=["neural"];symbols_disabled = ["NEURAL_LEARN"]}
+  Do Not Expect Symbol  NEURAL_HAM_SHORT
+  Expect Symbol  NEURAL_SPAM_SHORT
+
+Train inverse
+  ${json2} =  Evaluate  json.dumps({"spam_vec": ${HAM_VEC}, "ham_vec": ${SPAM_VEC}, "rule": "SHORT"})
+  HTTP  POST  ${LOCAL_ADDR}  ${PORT_CONTROLLER}  /plugins/neural/learn  ${json2}
+  Sleep  2s  Wait for neural to be loaded
+
+Check Neural HAM - inverse
+  Scan File  ${MESSAGE}  Settings={symbols_enabled = ["HAM_SYMBOL"];groups_enabled=["neural"];symbols_disabled = ["NEURAL_LEARN"]}
+  Do Not Expect Symbol  NEURAL_HAM_SHORT
+  Expect Symbol  NEURAL_SPAM_SHORT
+
+Check Neural SPAM - inverse
+  Scan File  ${MESSAGE}  Settings={symbols_enabled = ["SPAM_SYMBOL"];groups_enabled=["neural"];symbols_disabled = ["NEURAL_LEARN"]}
+  Do Not Expect Symbol  NEURAL_SPAM_SHORT
+  Expect Symbol  NEURAL_HAM_SHORT
+
+*** Keywords ***
+Neural Setup
+  ${TMPDIR} =    Make Temporary Directory
+  Set Suite Variable        ${TMPDIR}
+  Run Redis
+  Generic Setup
+
+Neural Teardown
+  Shutdown Process With Children  ${REDIS_PID}
+  Normal Teardown
diff --git a/test/functional/configs/neural_noauto.conf b/test/functional/configs/neural_noauto.conf
new file mode 100644 (file)
index 0000000..55f0a42
--- /dev/null
@@ -0,0 +1,85 @@
+options = {
+  url_tld = "${URL_TLD}"
+  pidfile = "${TMPDIR}/rspamd.pid"
+  lua_path = "${INSTALLROOT}/share/rspamd/lib/?.lua"
+  filters = [];
+  explicit_modules = ["settings"];
+}
+
+logging = {
+  type = "file",
+  level = "debug"
+  filename = "${TMPDIR}/rspamd.log"
+  log_usec = true;
+}
+metric = {
+  name = "default",
+  actions = {
+    reject = 100500,
+    add_header = 50500,
+  }
+  unknown_weight = 1
+}
+worker {
+  type = normal
+  bind_socket = ${LOCAL_ADDR}:${PORT_NORMAL}
+  count = 1
+  task_timeout = 10s;
+}
+worker {
+  type = controller
+  bind_socket = ${LOCAL_ADDR}:${PORT_CONTROLLER}
+  count = 1
+  secure_ip = ["127.0.0.1", "::1"];
+  stats_path = "${TMPDIR}/stats.ucl"
+}
+
+modules {
+  path = "${TESTDIR}/../../src/plugins/lua/"
+}
+
+lua = "${TESTDIR}/lua/test_coverage.lua";
+
+neural {
+  rules {
+      SHORT {
+          train {
+              learning_rate = 0.001;
+              max_usages = 2;
+              spam_score = 1;
+              ham_score = -1;
+              max_trains = 10;
+              max_iterations = 250;
+              store_pool_only = true;
+          }
+          symbol_spam = "NEURAL_SPAM_SHORT";
+          symbol_ham = "NEURAL_HAM_SHORT";
+          ann_expire = 86400;
+          watch_interval = 0.5;
+      }
+      SHORT_PCA {
+          train {
+              learning_rate = 0.001;
+              max_usages = 2;
+              spam_score = 1;
+              ham_score = -1;
+              max_trains = 10;
+              max_iterations = 250;
+              store_pool_only = true;
+          }
+          symbol_spam = "NEURAL_SPAM_SHORT_PCA";
+          symbol_ham = "NEURAL_HAM_SHORT_PCA";
+          ann_expire = 86400;
+          watch_interval = 0.5;
+          max_inputs = 2;
+      }
+  }
+  allow_local = true;
+
+}
+redis {
+  servers = "${REDIS_ADDR}:${REDIS_PORT}";
+  expand_keys = true;
+}
+
+lua = "${TESTDIR}/lua/neural.lua";
index 53d4e70f93a22a34228cd15dec5520fd6cb6ef5c..0b6cc6f3853db525ce5ad23656bbbf11124deb49 100644 (file)
@@ -209,6 +209,7 @@ Run Rspamd
   ...  ELSE  Make Temporary Directory
   Set Directory Ownership  ${tmpdir}  ${RSPAMD_USER}  ${RSPAMD_GROUP}
   ${template} =  Get File  ${CONFIG}
+  # TODO: stop using this; we have Lupa now
   FOR  ${i}  IN  @{vargs}
     ${newvalue} =  Replace Variables  ${${i}}
     Set To Dictionary  ${d}  ${i}=${newvalue}
@@ -218,7 +219,8 @@ Run Rspamd
   Log  ${config}
   Create File  ${tmpdir}/rspamd.conf  ${config}
   ${result} =  Run Process  ${RSPAMD}  -u  ${RSPAMD_USER}  -g  ${RSPAMD_GROUP}
-  ...  -c  ${tmpdir}/rspamd.conf  env:TMPDIR=${tmpdir}  env:DBDIR=${tmpdir}  env:LD_LIBRARY_PATH=${TESTDIR}/../../contrib/aho-corasick  stdout=DEVNULL  stderr=DEVNULL
+  ...  -c  ${tmpdir}/rspamd.conf  env:TMPDIR=${tmpdir}  env:DBDIR=${tmpdir}  env:LD_LIBRARY_PATH=${TESTDIR}/../../contrib/aho-corasick
+  ...  env:RSPAMD_INSTALLROOT=${INSTALLROOT}  stdout=DEVNULL  stderr=DEVNULL
   Run Keyword If  ${result.rc} != 0  Log  ${result.stderr}
   Should Be Equal As Integers  ${result.rc}  0
   Wait Until Keyword Succeeds  10x  1 sec  Check Pidfile  ${tmpdir}/rspamd.pid  timeout=0.5s
index 70857d429a8725896e736434600463a14c569988..ccdad1b68a3df727901f2690ee1d8b659d0e531d 100644 (file)
@@ -1,3 +1,5 @@
+local logger = require "rspamd_logger"
+
 rspamd_config:register_symbol({
   name = 'SPAM_SYMBOL',
   score = 5.0,
@@ -21,4 +23,39 @@ rspamd_config:register_symbol({
   callback = function()
     return true, 'Fires always'
   end
-})
\ No newline at end of file
+})
+
+rspamd_config.SAVE_NN_ROW = {
+  callback = function(task)
+    local fname = os.tmpname()
+    task:cache_set('nn_row_tmpfile', fname)
+    return true, 1.0, fname
+  end
+}
+
+rspamd_config.SAVE_NN_ROW_IDEMPOTENT = {
+  callback = function(task)
+    local function tohex(str)
+      return (str:gsub('.', function (c)
+        return string.format('%02X', string.byte(c))
+      end))
+    end
+    local fname = task:cache_get('nn_row_tmpfile')
+    if not fname then
+      return
+    end
+    local f, err = io.open(fname, 'w')
+    if not f then
+      logger.errx(task, err)
+      return
+    end
+    f:write(tohex(task:cache_get('neural_vec_mpack') or ''))
+    f:close()
+    return
+  end,
+  type = 'idempotent',
+  flags = 'explicit_disable',
+  priority = 10,
+}
+
+dofile(rspamd_env.INSTALLROOT .. "/share/rspamd/rules/controller/init.lua")
diff --git a/test/functional/util/nn_unpack.lua b/test/functional/util/nn_unpack.lua
new file mode 100644 (file)
index 0000000..fee98d5
--- /dev/null
@@ -0,0 +1,16 @@
+local ucl = require "ucl"
+
+local function unhex(str)
+  return (str:gsub('..', function (cc)
+    return string.char(tonumber(cc, 16))
+  end))
+end
+
+local parser = ucl.parser()
+local ok, err = parser:parse_string(unhex(arg[1]), 'msgpack')
+if not ok then
+  io.stderr:write(err)
+  os.exit(1)
+end
+
+print(ucl.to_format(parser:get_object(), 'json-compact'))