]> source.dussan.org Git - rspamd.git/commitdiff
[Rework] Restore old fann_scores, move common parts
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 5 Nov 2016 12:03:20 +0000 (12:03 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 5 Nov 2016 12:03:20 +0000 (12:03 +0000)
src/lua/global_functions.lua
src/plugins/lua/fann_classifier.lua
src/plugins/lua/fann_redis.lua [new file with mode: 0644]
src/plugins/lua/fann_scores.lua

index 0eb461496d929242c7880f046a9657c207dc32c9..b8b840043c6d7936838f3fe0e5f83a2c37674b28 100644 (file)
@@ -164,3 +164,220 @@ function rspamd_str_split(s, sep)
   local p = lpeg.Ct(elem * (sep * elem)^0)   -- make a table capture
   return lpeg.match(p, s)
 end
+
+-- Metafunctions
+local function meta_size_function(task)
+  local sizes = {
+    100,
+    200,
+    500,
+    1000,
+    2000,
+    4000,
+    10000,
+    20000,
+    30000,
+    100000,
+    200000,
+    400000,
+    800000,
+    1000000,
+    2000000,
+    8000000,
+  }
+
+  local size = task:get_size()
+  for i = 1,#sizes do
+    if sizes[i] >= size then
+      return {i / #sizes}
+    end
+  end
+
+  return {0}
+end
+
+local function meta_images_function(task)
+  local images = task:get_images()
+  local ntotal = 0
+  local njpg = 0
+  local npng = 0
+  local nlarge = 0
+  local nsmall = 0
+
+  if images then
+    for _,img in ipairs(images) do
+      if img:get_type() == 'png' then
+        npng = npng + 1
+      elseif img:get_type() == 'jpeg' then
+        njpg = njpg + 1
+      end
+
+      local w = img:get_width()
+      local h = img:get_height()
+
+      if w > 0 and h > 0 then
+        if w + h > 256 then
+          nlarge = nlarge + 1
+        else
+          nsmall = nsmall + 1
+        end
+      end
+
+      ntotal = ntotal + 1
+    end
+  end
+  if ntotal > 0 then
+    njpg = njpg / ntotal
+    npng = npng / ntotal
+    nlarge = nlarge / ntotal
+    nsmall = nsmall / ntotal
+  end
+  return {ntotal,njpg,npng,nlarge,nsmall}
+end
+
+local function meta_nparts_function(task)
+  local nattachments = 0
+  local ntextparts = 0
+  local totalparts = 1
+
+  local tp = task:get_text_parts()
+  if tp then
+    ntextparts = #tp
+  end
+
+  local parts = task:get_parts()
+
+  if parts then
+    for _,p in ipairs(parts) do
+      if p:get_filename() then
+        nattachments = nattachments + 1
+      end
+      totalparts = totalparts + 1
+    end
+  end
+
+  return {ntextparts/totalparts, nattachments/totalparts}
+end
+
+local function meta_encoding_function(task)
+  local nutf = 0
+  local nother = 0
+
+  local tp = task:get_text_parts()
+  if tp then
+    for _,p in ipairs(tp) do
+      if p:is_utf() then
+        nutf = nutf + 1
+      else
+        nother = nother + 1
+      end
+    end
+  end
+
+  return {nutf, nother}
+end
+
+local function meta_recipients_function(task)
+  local nmime = 0
+  local nsmtp = 0
+
+  if task:has_recipients('mime') then
+    nmime = #(task:get_recipients('mime'))
+  end
+  if task:has_recipients('smtp') then
+    nsmtp = #(task:get_recipients('smtp'))
+  end
+
+  if nmime > 0 then nmime = 1.0 / nmime end
+  if nsmtp > 0 then nsmtp = 1.0 / nsmtp end
+
+  return {nmime,nsmtp}
+end
+
+local function meta_received_function(task)
+  local ret = 0
+  local rh = task:get_received_headers()
+
+  if rh and #rh > 0 then
+    ret = 1 / #rh
+  end
+
+  return {ret}
+end
+
+local function meta_urls_function(task)
+  if task:has_urls() then
+    return {1.0 / #(task:get_urls())}
+  end
+
+  return {0}
+end
+
+local function meta_attachments_function(task)
+end
+
+local metafunctions = {
+  {
+    cb = meta_size_function,
+    ninputs = 1,
+  },
+  {
+    cb = meta_images_function,
+    ninputs = 5,
+    -- 1 - number of images,
+    -- 2 - number of png images,
+    -- 3 - number of jpeg images
+    -- 4 - number of large images (> 128 x 128)
+    -- 5 - number of small images (< 128 x 128)
+  },
+  {
+    cb = meta_nparts_function,
+    ninputs = 2,
+    -- 1 - number of text parts
+    -- 2 - number of attachments
+  },
+  {
+    cb = meta_encoding_function,
+    ninputs = 2,
+    -- 1 - number of utf parts
+    -- 2 - number of non-utf parts
+  },
+  {
+    cb = meta_recipients_function,
+    ninputs = 2,
+    -- 1 - number of mime rcpt
+    -- 2 - number of smtp rcpt
+  },
+  {
+    cb = meta_received_function,
+    ninputs = 1,
+  },
+  {
+    cb = meta_urls_function,
+    ninputs = 1,
+  },
+}
+
+function rspamd_gen_metatokens(task)
+  local ipairs = ipairs
+  local metatokens = {}
+  for _,mt in ipairs(metafunctions) do
+    local ct = mt.cb(task)
+
+    for _,tok in ipairs(ct) do
+      table.insert(metatokens, tok)
+    end
+  end
+
+  return metatokens
+end
+
+function rspamd_count_metatokens()
+  local ipairs = ipairs
+  local total = 0
+  for _,mt in ipairs(metafunctions) do
+    total = total + mt.ninputs
+  end
+
+  return total
+end
\ No newline at end of file
index af7acece819027f7e5f676dc7e47f0a7fb8eb6f2..9c35d0bfa24de92798edfdc661dfe826d2385fc2 100644 (file)
@@ -149,7 +149,7 @@ local function tokens_to_vector(tokens)
 end
 
 local function add_metatokens(task, vec)
-    local mt = gen_metatokens(task)
+    local mt = rspamd_gen_metatokens(task)
     for _,tok in ipairs(mt) do
       table.insert(vec, tok)
     end
@@ -157,7 +157,7 @@ end
 
 local function create_fann()
   local layers = {}
-  local mt_size = count_metatokens()
+  local mt_size = rspamd_count_metatokens()
   local neurons = classifier_config.neurons + mt_size
 
   for i = 1,classifier_config.layers - 1 do
diff --git a/src/plugins/lua/fann_redis.lua b/src/plugins/lua/fann_redis.lua
new file mode 100644 (file)
index 0000000..e81af47
--- /dev/null
@@ -0,0 +1,586 @@
+--[[
+Copyright (c) 2016, Vsevolod Stakhov <vsevolod@highsecure.ru>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+-- This plugin is a concept of FANN scores adjustment
+-- NOT FOR PRODUCTION USE so far
+
+local rspamd_logger = require "rspamd_logger"
+local rspamd_fann = require "rspamd_fann"
+local rspamd_util = require "rspamd_util"
+local fann_symbol_spam = 'FANN_SPAM'
+local fann_symbol_ham = 'FANN_HAM'
+require "fun" ()
+local ucl = require "ucl"
+
+local module_log_id = 0x100
+-- Module vars
+-- ANNs indexed by settings id
+local data = {
+  ['0'] = {
+    fann_mtime = 0,
+    ntrains = 0,
+    epoch = 0,
+  }
+}
+
+
+-- Lua script to train a row
+-- Uses the following keys:
+-- key1 - prefix for keys
+-- key2 - max count of learns
+-- key3 - spam or ham
+-- returns 1 or 0: 1 - allow learn, 0 - not allow learn
+local redis_lua_script_can_train = [[
+  local locked = redis.call('GET', KEYS[1] .. '_locked')
+  if locked then return 0 end
+  local nspam = 0
+  local nham = 0
+
+  local ret = redis.call('LLEN', KEYS[1] .. '_spam')
+  if ret then nspam = tonumber(ret) end
+  ret = redis.call('LLEN', KEYS[1] .. '_ham')
+  if ret then nham = tonumber(ret) end
+
+  if KEYS[3] == 'spam' then
+    if nham + 1 >= nspam then return tostring(nspam + 1) end
+  else
+    if nspam + 1 >= nham then return tostring(nham + 1) end
+  end
+
+  return tostring(0)
+]]
+local redis_can_train_sha = nil
+
+-- Lua script to load ANN from redis
+-- Uses the following keys
+-- key1 - prefix for keys
+-- key2 - local version
+-- returns nil or bulk string if new ANN can be loaded
+local redis_lua_script_maybe_load = [[
+  local locked = redis.call('GET', KEYS[1] .. '_locked')
+  if locked then return false end
+
+  local ver = 0
+  local ret = redis.call('GET', KEYS[1] .. '_version')
+  if ret then ver = tonumber(ret) end
+  if ver > KEYS[2] then return redis.call('GET', KEYS[1] .. '_ann') end
+
+  return false
+]]
+local redis_maybe_load_sha = nil
+
+-- Lua script to invalidate ANN from redis
+-- Uses the following keys
+-- key1 - prefix for keys
+local redis_lua_script_maybe_invalidate = [[
+  local locked = redis.call('GET', KEYS[1] .. '_locked')
+  if locked then return false end
+  redis.call('SET', KEYS[1] .. '_locked', '1')
+  redis.call('SET', KEYS[1] .. '_version', '0')
+  redis.call('DEL', KEYS[1] .. '_spam')
+  redis.call('DEL', KEYS[1] .. '_ham')
+  redis.call('DEL', KEYS[1] .. '_data')
+  redis.call('DEL', KEYS[1] .. '_locked')
+  return 1
+]]
+local redis_maybe_invalidate_sha = nil
+
+local redis_params
+redis_params = rspamd_parse_redis_server('fann_redis')
+
+local fann_prefix = 'RFANN'
+local max_trains = 1000
+local max_epoch = 100
+local use_settings = false
+local watch_interval = 60.0
+
+local function redis_make_request(ev_base, cfg, key, is_write, callback, command, args)
+  if not ev_base or not redis_params or not callback or not command then
+    return false,nil,nil
+  end
+
+  local addr
+  local rspamd_redis = require "rspamd_redis"
+
+  if key then
+    if is_write then
+      addr = redis_params['write_servers']:get_upstream_by_hash(key)
+    else
+      addr = redis_params['read_servers']:get_upstream_by_hash(key)
+    end
+  else
+    if is_write then
+      addr = redis_params['write_servers']:get_upstream_master_slave(key)
+    else
+      addr = redis_params['read_servers']:get_upstream_round_robin(key)
+    end
+  end
+
+  if not addr then
+    logger.errx(task, 'cannot select server to make redis request')
+  end
+
+  local options = {
+    ev_base = ev_base,
+    config = cfg,
+    callback = callback,
+    host = addr:get_addr(),
+    timeout = redis_params['timeout'],
+    cmd = command,
+    args = args
+  }
+
+  if redis_params['password'] then
+    options['password'] = redis_params['password']
+  end
+
+  if redis_params['db'] then
+    options['dbname'] = redis_params['db']
+  end
+
+  local ret,conn = rspamd_redis.make_request(options)
+  if not ret then
+    rspamd_logger.errx('cannot execute redis request')
+  end
+  return ret,conn,addr
+end
+
+local function symbols_to_fann_vector(syms, scores)
+  local learn_data = {}
+  local matched_symbols = {}
+  local n = rspamd_config:get_symbols_count()
+
+  each(function(s, score)
+     matched_symbols[s + 1] = rspamd_util.tanh(score)
+  end, zip(syms, scores))
+
+  for i=1,n do
+    if matched_symbols[i] then
+      learn_data[i] = matched_symbols[i]
+    else
+      learn_data[i] = 0
+    end
+  end
+
+  return learn_data
+end
+
+local function gen_fann_prefix(id)
+  if use_settings then
+    return fann_prefix .. id
+  else
+    return fann_prefix
+  end
+end
+
+local function is_fann_valid(ann)
+  if ann then
+    local n = rspamd_config:get_symbols_count() + rspamd_count_metatokens()
+
+    if n ~= ann:get_inputs() then
+      rspamd_logger.infox(rspamd_config, 'fann has incorrect number of inputs: %s, %s symbols' ..
+      ' is found in the cache', ann:get_inputs(), n)
+      return false
+    end
+    local layers = ann:get_layers()
+
+    if not layers or #layers ~= 5 then
+      rspamd_logger.infox(rspamd_config, 'fann has incorrect number of layers: %s',
+        #layers)
+      return false
+    end
+
+    return true
+  end
+end
+
+local function fann_scores_filter(task)
+  local id = '0'
+  if use_settings then
+   local sid = task:get_settings_id()
+   if sid then
+    id = tostring(sid)
+   end
+  end
+
+  if data[id].fann then
+    local symbols,scores = task:get_symbols_numeric()
+    local fann_data = symbols_to_fann_vector(symbols, scores)
+    local mt = rspamd_gen_metatokens(task)
+
+    for _,tok in ipairs(mt) do
+      table.insert(fann_data, tok)
+    end
+
+    local out = data[id].fann:test(fann_data)
+    local symscore = string.format('%.3f', out[1])
+    rspamd_logger.infox(task, 'fann score: %s', symscore)
+
+    if out[1] > 0 then
+      local result = rspamd_util.normalize_prob(out[1] / 2.0, 0)
+      task:insert_result(fann_symbol_spam, result, symscore, id)
+    else
+      local result = rspamd_util.normalize_prob((-out[1]) / 2.0, 0)
+      task:insert_result(fann_symbol_ham, result, symscore, id)
+    end
+  end
+end
+
+local function create_train_fann(n, id)
+  data[id].fann_train = rspamd_fann.create(5, n, n, n / 2, n / 4, 1)
+  data[id].ntrains = 0
+  data[id].epoch = 0
+end
+
+local function load_or_invalidate_fann(data, id, ev_base)
+  local err,ann_data = rspamd_util.zstd_decompress(data)
+  local ann
+
+  if err or not ann_data then
+    rspamd_logger.errx('cannot decompress ann: %s', err)
+  else
+    ann = rspamd_fann.load_data(ann_data)
+  end
+
+  if is_fann_valid(ann) then
+    data[id].fann = ann
+  else
+    local function redis_invalidate_cb(err, data)
+      if err then
+        rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', id, err)
+      elseif type(data) == 'string' then
+        rspamd_logger.info(rspamd_config, 'invalidated ANN %s from redis: %s', id, err)
+      end
+    end
+    -- Invalidate ANN
+    rspamd_logger.infox('invalidate ANN %s')
+    redis_make_request(ev_base,
+      rspamd_config,
+      nil,
+      true, -- is write
+      redis_invalidate_cb, --callback
+      'EVALSHA', -- command
+      {redis_maybe_invalidate_sha, 1, fann_prefix .. id}
+    )
+  end
+end
+
+local function fann_train_callback(score, required_score, results, cf, id, opts, extra, ev_base)
+  local fname = gen_fann_prefix(id)
+
+  local learn_spam, learn_ham = false, false
+
+  if opts['spam_score'] then
+    learn_spam = score >= opts['spam_score']
+  else
+    learn_spam = score >= required_score
+  end
+  if opts['ham_score'] then
+    learn_ham = score <= opts['ham_score']
+  else
+    learn_ham = score < 0
+  end
+
+  if learn_spam or learn_ham then
+    local k
+    if learn_spam then k = 'spam' else k = 'ham' end
+
+    local function learn_vec_cb(err, data)
+      if err then
+        rspamd_logger.errx(rspamd_config, 'cannot store train vector: %s', err)
+      end
+    end
+
+    local function can_train_cb(err, data)
+      if not err and tonumber(data) > 0 then
+        local learn_data = symbols_to_fann_vector(
+          map(function(r) return r[1] end, results),
+          map(function(r) return r[2] end, results)
+        )
+        -- Add filtered meta tokens
+        each(function(e) table.insert(learn_data, e) end, extra)
+        local str = rspamd_util.zstd_compress(table.concat(learn_data, ';'))
+
+        redis_make_request(ev_base,
+          rspamd_config,
+          nil,
+          true, -- is write
+          learn_vec_cb, --callback
+          'LPUSH', -- command
+          {fname .. '_' .. k, str} -- arguments
+        )
+      else
+        if err then
+          rspamd_logger.errx(rspamd_config, 'cannot check if we can train: %s', err)
+        end
+      end
+    end
+
+    redis_make_request(ev_base,
+      rspamd_config,
+      nil,
+      false, -- is write
+      can_train_cb, --callback
+      'EVALSHA', -- command
+      {redis_can_train_sha, '3', fname, tostring(max_trains), k} -- arguments
+    )
+  end
+end
+
+local function train_fann(cfg, ev_base, elt)
+
+end
+
+local function maybe_train_fanns(cfg, ev_base)
+  local function members_cb(err, data)
+    if err then
+      rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', err)
+    elseif type(data) == 'table' then
+      each(function(i, elt)
+        local redis_len_cb = function(err, data)
+          if err then
+            rspamd_logger.errx(rspamd_config, 'cannot get FANN trains %s from redis: %s', elt, err)
+          elseif data and type(data) == 'number' or type(data) == 'string' then
+            if tonumber(data) and tonumber(data) > max_trains then
+              train_fann(cfg, ev_base, elt)
+            end
+          end
+        end
+
+        local local_ver = 0
+        local numelt = tonumber(elt)
+        if data[numelt] then
+          if data[numelt].version then
+            local_ver = data[numelt].version
+          end
+        end
+        redis_make_request(ev_base,
+          rspamd_config,
+          nil,
+          false, -- is write
+          redis_len_cb, --callback
+          'LLEN', -- command
+          {fann_prefix .. elt .. '_spam'}
+        )
+      end,
+      data)
+    end
+  end
+
+  if not redis_maybe_load_sha then
+    -- Plan new event early
+    return 1.0
+  end
+  -- First we need to get all fanns stored in our Redis
+  redis_make_request(ev_base,
+    rspamd_config,
+    nil,
+    false, -- is write
+    members_cb, --callback
+    'SMEMBERS', -- command
+    {fann_prefix} -- arguments
+  )
+
+  return watch_interval
+end
+
+local function check_fanns(cfg, ev_base)
+  local function members_cb(err, data)
+    if err then
+      rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', err)
+    elseif type(data) == 'table' then
+      each(function(i, elt)
+        local redis_update_cb = function(err, data)
+          if err then
+            rspamd_logger.errx(rspamd_config, 'cannot get FANN version %s from redis: %s', elt, err)
+          elseif data and type(data) == 'string' then
+            load_or_invalidate_fann(data, elt, ev_base)
+          end
+        end
+
+        local local_ver = 0
+        local numelt = tonumber(elt)
+        if data[numelt] then
+          if data[numelt].version then
+            local_ver = data[numelt].version
+          end
+        end
+        redis_make_request(ev_base,
+          rspamd_config,
+          nil,
+          false, -- is write
+          redis_update_cb, --callback
+          'EVALSHA', -- command
+          {redis_maybe_load_sha, 2, fann_prefix .. elt, tostring(local_ver)}
+        )
+      end,
+      data)
+    end
+  end
+
+  if not redis_maybe_load_sha then
+    -- Plan new event early
+    return 1.0
+  end
+  -- First we need to get all fanns stored in our Redis
+  redis_make_request(ev_base,
+    rspamd_config,
+    nil,
+    false, -- is write
+    members_cb, --callback
+    'SMEMBERS', -- command
+    {fann_prefix} -- arguments
+  )
+
+  return watch_interval
+end
+
+-- Initialization part
+
+local opts = rspamd_config:get_all_opt("fann_redis")
+if not (opts and type(opts) == 'table') or not redis_params then
+  rspamd_logger.infox(rspamd_config, 'Module is unconfigured')
+  return
+end
+
+if not rspamd_fann.is_enabled() then
+  rspamd_logger.errx(rspamd_config, 'fann is not compiled in rspamd, this ' ..
+    'module is eventually disabled')
+  return
+else
+  use_settings = opts['use_settings']
+  rspamd_config:set_metric_symbol({
+    name = fann_symbol_spam,
+    score = 3.0,
+    description = 'Neural network SPAM',
+    group = 'fann'
+  })
+  local id = rspamd_config:register_symbol({
+    name = fann_symbol_spam,
+    type = 'postfilter',
+    priority = 5,
+    callback = fann_scores_filter
+  })
+  rspamd_config:set_metric_symbol({
+    name = fann_symbol_ham,
+    score = -2.0,
+    description = 'Neural network HAM',
+    group = 'fann'
+  })
+  rspamd_config:register_symbol({
+    name = fann_symbol_ham,
+    type = 'virtual',
+    parent = id
+  })
+  if opts['train'] then
+    rspamd_config:add_on_load(function(cfg)
+      if opts['train']['max_train'] then
+        max_trains = opts['train']['max_train']
+      end
+      if opts['train']['max_epoch'] then
+        max_epoch = opts['train']['max_epoch']
+      end
+      local ret = cfg:register_worker_script("log_helper",
+        function(score, req_score, results, cf, id, extra, ev_base)
+          -- map (snd x) (filter (fst x == module_id) extra)
+          local extra_fann = map(function(e) return e[2] end,
+            filter(function(e) return e[1] == module_log_id end, extra))
+          if use_settings then
+            fann_train_callback(score, req_score, results, cf,
+              tostring(id), opts['train'], extra_fann, ev_base)
+          else
+            fann_train_callback(score, req_score, results, cf, '0',
+              opts['train'], extra_fann, ev_base)
+          end
+        end)
+
+      if not ret then
+        rspamd_logger.errx(cfg, 'cannot find worker "log_helper"')
+      end
+    end)
+    -- This is needed to pass extra tokens from worker to log_helper
+    rspamd_plugins["fann_score"] = {
+      log_callback = function(task)
+        return totable(map(
+          function(tok) return {module_log_id, tok} end,
+          rspamd_gen_metatokens(task)))
+      end
+    }
+  end
+  -- Add training scripts
+  rspamd_config:add_on_load(function(cfg, ev_base, worker)
+    local function can_train_sha_cb(err, data)
+      if err or not data or type(data) ~= 'string' then
+        rspamd_logger.errx(cfg, 'cannot save redis train script: %s', err)
+      else
+        redis_can_train_sha = tostring(data)
+      end
+    end
+    redis_make_request(ev_base,
+      rspamd_config,
+      nil,
+      true, -- is write
+      can_train_sha_cb, --callback
+      'SCRIPT', -- command
+      {'LOAD', redis_lua_script_can_train} -- arguments
+    )
+
+    local function maybe_load_sha_cb(err, data)
+      if err or not data or type(data) ~= 'string' then
+        rspamd_logger.errx(cfg, 'cannot save redis load script: %s', err)
+      else
+        redis_maybe_load_sha = tostring(data)
+
+        rspamd_config:add_periodic(ev_base, 0.0,
+          function(cfg, ev_base)
+            return check_fanns(cfg, ev_base)
+          end)
+      end
+    end
+    redis_make_request(ev_base,
+      rspamd_config,
+      nil,
+      true, -- is write
+      maybe_load_sha_cb, --callback
+      'SCRIPT', -- command
+      {'LOAD', redis_lua_script_maybe_load} -- arguments
+    )
+
+    local function maybe_invalidate_sha_cb(err, data)
+      if err or not data or type(data) ~= 'string' then
+        rspamd_logger.errx(cfg, 'cannot save redis invalidate script: %s', err)
+      else
+        redis_maybe_invalidate_sha = tostring(data)
+      end
+    end
+    redis_make_request(ev_base,
+      rspamd_config,
+      nil,
+      true, -- is write
+      maybe_invalidate_sha_cb, --callback
+      'SCRIPT', -- command
+      {'LOAD', redis_lua_script_maybe_invalidate} -- arguments
+    )
+
+    if worker:get_name() == 'normal' then
+      -- We also want to train neural nets when they have enough data
+      rspamd_config:add_periodic(ev_base, 0.0,
+        function(cfg, ev_base)
+          return maybe_train_fanns(cfg, ev_base)
+        end)
+    end
+  end)
+end
index 7bc55117dcd4ad849c5e81dd24b48b0092581226..f533a34d892f9f514fba9c73adf07be9b969fb6a 100644 (file)
@@ -36,389 +36,123 @@ local data = {
   }
 }
 
-
--- Lua script to train a row
--- Uses the following keys:
--- key1 - prefix for keys
--- key2 - max count of learns
--- key3 - spam or ham
--- returns 1 or 0: 1 - allow learn, 0 - not allow learn
-local redis_lua_script_can_train = [[
-  local locked = redis.call('GET', KEYS[1] .. '_locked')
-  if locked then return 0 end
-  local nspam = 0
-  local nham = 0
-
-  local ret = redis.call('LLEN', KEYS[1] .. '_spam')
-  if ret then nspam = tonumber(ret) end
-  ret = redis.call('LLEN', KEYS[1] .. '_ham')
-  if ret then nham = tonumber(ret) end
-
-  if KEYS[3] == 'spam' then
-    if nham + 1 >= nspam then return tostring(nspam + 1) end
-  else
-    if nspam + 1 >= nham then return tostring(nham + 1) end
-  end
-
-  return tostring(0)
-]]
-local redis_can_train_sha = nil
-
--- Lua script to load ANN from redis
--- Uses the following keys
--- key1 - prefix for keys
--- key2 - local version
--- returns nil or bulk string if new ANN can be loaded
-local redis_lua_script_maybe_load = [[
-  local locked = redis.call('GET', KEYS[1] .. '_locked')
-  if locked then return false end
-
-  local ver = 0
-  local ret = redis.call('GET', KEYS[1] .. '_version')
-  if ret then ver = tonumber(ret) end
-  if ver > KEYS[2] then return redis.call('GET', KEYS[1] .. '_ann') end
-
-  return false
-]]
-local redis_maybe_load_sha = nil
-
--- Lua script to invalidate ANN from redis
--- Uses the following keys
--- key1 - prefix for keys
-local redis_lua_script_maybe_invalidate = [[
-  local locked = redis.call('GET', KEYS[1] .. '_locked')
-  if locked then return false end
-  redis.call('SET', KEYS[1] .. '_locked', '1')
-  redis.call('SET', KEYS[1] .. '_version', '0')
-  redis.call('DEL', KEYS[1] .. '_spam')
-  redis.call('DEL', KEYS[1] .. '_ham')
-  redis.call('DEL', KEYS[1] .. '_data')
-  redis.call('DEL', KEYS[1] .. '_locked')
-  return 1
-]]
-local redis_maybe_invalidate_sha = nil
-
-local redis_params
-redis_params = rspamd_parse_redis_server('fann_scores')
-
-local fann_prefix = 'RFANN'
+local fann_file
 local max_trains = 1000
 local max_epoch = 100
 local use_settings = false
-local watch_interval = 60.0
 
-local function redis_make_request(ev_base, cfg, key, is_write, callback, command, args)
-  if not ev_base or not redis_params or not callback or not command then
-    return false,nil,nil
-  end
+local function symbols_to_fann_vector(syms, scores)
+  local learn_data = {}
+  local matched_symbols = {}
+  local n = rspamd_config:get_symbols_count()
 
-  local addr
-  local rspamd_redis = require "rspamd_redis"
+  each(function(s, score)
+     matched_symbols[s + 1] = rspamd_util.tanh(score)
+  end, zip(syms, scores))
 
-  if key then
-    if is_write then
-      addr = redis_params['write_servers']:get_upstream_by_hash(key)
-    else
-      addr = redis_params['read_servers']:get_upstream_by_hash(key)
-    end
-  else
-    if is_write then
-      addr = redis_params['write_servers']:get_upstream_master_slave(key)
+  for i=1,n do
+    if matched_symbols[i] then
+      learn_data[i] = matched_symbols[i]
     else
-      addr = redis_params['read_servers']:get_upstream_round_robin(key)
+      learn_data[i] = 0
     end
   end
 
-  if not addr then
-    logger.errx(task, 'cannot select server to make redis request')
-  end
-
-  local options = {
-    ev_base = ev_base,
-    config = cfg,
-    callback = callback,
-    host = addr:get_addr(),
-    timeout = redis_params['timeout'],
-    cmd = command,
-    args = args
-  }
-
-  if redis_params['password'] then
-    options['password'] = redis_params['password']
-  end
-
-  if redis_params['db'] then
-    options['dbname'] = redis_params['db']
-  end
-
-  local ret,conn = rspamd_redis.make_request(options)
-  if not ret then
-    rspamd_logger.errx('cannot execute redis request')
-  end
-  return ret,conn,addr
+  return learn_data
 end
 
--- Metafunctions
-local function fann_size_function(task)
-  local sizes = {
-    100,
-    200,
-    500,
-    1000,
-    2000,
-    4000,
-    10000,
-    20000,
-    30000,
-    100000,
-    200000,
-    400000,
-    800000,
-    1000000,
-    2000000,
-    8000000,
-  }
-
-  local size = task:get_size()
-  for i = 1,#sizes do
-    if sizes[i] >= size then
-      return {i / #sizes}
-    end
+local function gen_fann_file(id)
+  if use_settings then
+    return fann_file .. id
+  else
+    return fann_file
   end
-
-  return {0}
 end
 
-local function fann_images_function(task)
-  local images = task:get_images()
-  local ntotal = 0
-  local njpg = 0
-  local npng = 0
-  local nlarge = 0
-  local nsmall = 0
-
-  if images then
-    for _,img in ipairs(images) do
-      if img:get_type() == 'png' then
-        npng = npng + 1
-      elseif img:get_type() == 'jpeg' then
-        njpg = njpg + 1
-      end
-
-      local w = img:get_width()
-      local h = img:get_height()
-
-      if w > 0 and h > 0 then
-        if w + h > 256 then
-          nlarge = nlarge + 1
-        else
-          nsmall = nsmall + 1
-        end
-      end
+local function load_fann(id)
+  local fname = gen_fann_file(id)
+  local err,st = rspamd_util.stat(fname)
 
-      ntotal = ntotal + 1
-    end
+  if err then
+    return false
   end
-  if ntotal > 0 then
-    njpg = njpg / ntotal
-    npng = npng / ntotal
-    nlarge = nlarge / ntotal
-    nsmall = nsmall / ntotal
-  end
-  return {ntotal,njpg,npng,nlarge,nsmall}
-end
 
-local function fann_nparts_function(task)
-  local nattachments = 0
-  local ntextparts = 0
-  local totalparts = 1
+  local fd = rspamd_util.lock_file(fname)
+  data[id].fann = rspamd_fann.load(fname)
+  rspamd_util.unlock_file(fd) -- closes fd
 
-  local tp = task:get_text_parts()
-  if tp then
-    ntextparts = #tp
-  end
+  if data[id].fann then
+    local n = rspamd_config:get_symbols_count() + rspamd_count_metatokens()
 
-  local parts = task:get_parts()
+    if n ~= data[id].fann:get_inputs() then
+      rspamd_logger.infox(rspamd_config, 'fann has incorrect number of inputs: %s, %s symbols' ..
+      ' is found in the cache; removing', data[id].fann:get_inputs(), n)
+      data[id].fann = nil
 
-  if parts then
-    for _,p in ipairs(parts) do
-      if p:get_filename() then
-        nattachments = nattachments + 1
+      local ret,err = rspamd_util.unlink(fname)
+      if not ret then
+        rspamd_logger.errx(rspamd_config, 'cannot remove invalid fann from %s: %s',
+          fname, err)
       end
-      totalparts = totalparts + 1
-    end
-  end
-
-  return {ntextparts/totalparts, nattachments/totalparts}
-end
-
-local function fann_encoding_function(task)
-  local nutf = 0
-  local nother = 0
-
-  local tp = task:get_text_parts()
-  if tp then
-    for _,p in ipairs(tp) do
-      if p:is_utf() then
-        nutf = nutf + 1
+    else
+      local layers = data[id].fann:get_layers()
+
+      if not layers or #layers ~= 5 then
+        rspamd_logger.infox(rspamd_config, 'fann has incorrect number of layers: %s, removing',
+          #layers)
+        data[id].fann = nil
+        local ret,err = rspamd_util.unlink(fname)
+        if not ret then
+          rspamd_logger.errx(rspamd_config, 'cannot remove invalid fann from %s: %s',
+            fname, err)
+        end
       else
-        nother = nother + 1
+        rspamd_logger.infox(rspamd_config, 'loaded fann from %s', fname)
+        return true
       end
     end
-  end
-
-  return {nutf, nother}
-end
-
-local function fann_recipients_function(task)
-  local nmime = 0
-  local nsmtp = 0
-
-  if task:has_recipients('mime') then
-    nmime = #(task:get_recipients('mime'))
-  end
-  if task:has_recipients('smtp') then
-    nsmtp = #(task:get_recipients('smtp'))
-  end
-
-  if nmime > 0 then nmime = 1.0 / nmime end
-  if nsmtp > 0 then nsmtp = 1.0 / nsmtp end
-
-  return {nmime,nsmtp}
-end
-
-local function fann_received_function(task)
-  local ret = 0
-  local rh = task:get_received_headers()
-
-  if rh and #rh > 0 then
-    ret = 1 / #rh
-  end
-
-  return {ret}
-end
-
-local function fann_urls_function(task)
-  if task:has_urls() then
-    return {1.0 / #(task:get_urls())}
-  end
-
-  return {0}
-end
-
-local function fann_attachments_function(task)
-end
-
-local metafunctions = {
-  {
-    cb = fann_size_function,
-    ninputs = 1,
-  },
-  {
-    cb = fann_images_function,
-    ninputs = 5,
-    -- 1 - number of images,
-    -- 2 - number of png images,
-    -- 3 - number of jpeg images
-    -- 4 - number of large images (> 128 x 128)
-    -- 5 - number of small images (< 128 x 128)
-  },
-  {
-    cb = fann_nparts_function,
-    ninputs = 2,
-    -- 1 - number of text parts
-    -- 2 - number of attachments
-  },
-  {
-    cb = fann_encoding_function,
-    ninputs = 2,
-    -- 1 - number of utf parts
-    -- 2 - number of non-utf parts
-  },
-  {
-    cb = fann_recipients_function,
-    ninputs = 2,
-    -- 1 - number of mime rcpt
-    -- 2 - number of smtp rcpt
-  },
-  {
-    cb = fann_received_function,
-    ninputs = 1,
-  },
-  {
-    cb = fann_urls_function,
-    ninputs = 1,
-  },
-}
-
-local function gen_metatokens(task)
-  local metatokens = {}
-  for _,mt in ipairs(metafunctions) do
-    local ct = mt.cb(task)
-
-    for _,tok in ipairs(ct) do
-      table.insert(metatokens, tok)
-    end
-  end
-
-  return metatokens
-end
-
-local function count_metatokens()
-  local total = 0
-  for _,mt in ipairs(metafunctions) do
-    total = total + mt.ninputs
-  end
-
-  return total
-end
-
-local function symbols_to_fann_vector(syms, scores)
-  local learn_data = {}
-  local matched_symbols = {}
-  local n = rspamd_config:get_symbols_count()
-
-  each(function(s, score)
-     matched_symbols[s + 1] = rspamd_util.tanh(score)
-  end, zip(syms, scores))
-
-  for i=1,n do
-    if matched_symbols[i] then
-      learn_data[i] = matched_symbols[i]
-    else
-      learn_data[i] = 0
+  else
+    rspamd_logger.infox(rspamd_config, 'fann is invalid: "%s"; removing', fname)
+    local ret,err = rspamd_util.unlink(fname)
+    if not ret then
+      rspamd_logger.errx(rspamd_config, 'cannot remove invalid fann from %s: %s',
+        fname, err)
     end
   end
 
-  return learn_data
-end
-
-local function gen_fann_prefix(id)
-  if use_settings then
-    return fann_prefix .. id
-  else
-    return fann_prefix
-  end
+  return false
 end
 
-local function is_fann_valid(ann)
-  if ann then
-    local n = rspamd_config:get_symbols_count() + count_metatokens()
+local function check_fann(id)
+  if data[id].fann then
+    local n = rspamd_config:get_symbols_count() + rspamd_count_metatokens()
 
-    if n ~= ann:get_inputs() then
+    if n ~= data[id].fann:get_inputs() then
       rspamd_logger.infox(rspamd_config, 'fann has incorrect number of inputs: %s, %s symbols' ..
-      ' is found in the cache', ann:get_inputs(), n)
-      return false
+      ' is found in the cache', data[id].fann:get_inputs(), n)
+      data[id].fann = nil
     end
-    local layers = ann:get_layers()
+    local layers = data[id].fann:get_layers()
 
     if not layers or #layers ~= 5 then
       rspamd_logger.infox(rspamd_config, 'fann has incorrect number of layers: %s',
         #layers)
-      return false
+      data[id].fann = nil
     end
+  end
+
+  local fname = gen_fann_file(id)
+  local err,st = rspamd_util.stat(fname)
+
+  if not err then
+    local mtime = st['mtime']
 
-    return true
+    if mtime > data[id].fann_mtime then
+      rspamd_logger.infox(rspamd_config, 'have more fresh version of fann ' ..
+        'file: %s -> %s, need to reload %s', data[id].fann_mtime, mtime, fname)
+      data[id].fann_mtime = mtime
+      data[id].fann = nil
+    end
   end
 end
 
@@ -431,10 +165,12 @@ local function fann_scores_filter(task)
    end
   end
 
+  check_fann(id)
+
   if data[id].fann then
     local symbols,scores = task:get_symbols_numeric()
     local fann_data = symbols_to_fann_vector(symbols, scores)
-    local mt = gen_metatokens(task)
+    local mt = rspamd_gen_metatokens(task)
 
     for _,tok in ipairs(mt) do
       table.insert(fann_data, tok)
@@ -451,6 +187,10 @@ local function fann_scores_filter(task)
       local result = rspamd_util.normalize_prob((-out[1]) / 2.0, 0)
       task:insert_result(fann_symbol_ham, result, symscore, id)
     end
+  else
+    if load_fann(id) then
+      fann_scores_filter(task)
+    end
   end
 end
 
@@ -460,44 +200,69 @@ local function create_train_fann(n, id)
   data[id].epoch = 0
 end
 
-local function load_or_invalidate_fann(data, id, ev_base)
-  local err,ann_data = rspamd_util.zstd_decompress(data)
-  local ann
+local function fann_train_callback(score, required_score, results, cf, id, opts, extra)
+  local n = cf:get_symbols_count() + rspamd_count_metatokens()
+  local fname = gen_fann_file(id)
 
-  if err or not ann_data then
-    rspamd_logger.errx('cannot decompress ann: %s', err)
-  else
-    ann = rspamd_fann.load_data(ann_data)
+  if not data[id].fann_train then
+    create_train_fann(n, id)
+  end
+
+  if data[id].fann_train:get_inputs() ~= n then
+    rspamd_logger.infox(cf, 'fann has incorrect number of inputs: %s, %s symbols' ..
+      ' is found in the cache', data[id].fann_train:get_inputs(), n)
+    create_train_fann(n, id)
   end
 
-  if is_fann_valid(ann) then
-    data[id].fann = ann
+  if data[id].ntrains > max_trains then
+    -- Store fann on disk
+    local res = false
+
+    local err,st = rspamd_util.stat(fname)
+    if err then
+      local fd,err = rspamd_util.create_file(fname)
+      if not fd then
+        rspamd_logger.errx(cf, 'cannot save fann in %s: %s', fname, err)
+      else
+        rspamd_util.lock_file(fname, fd)
+        res = data[id].fann_train:save(fname)
+        rspamd_util.unlock_file(fd) -- Closes fd as well
+      end
+    else
+      local fd = rspamd_util.lock_file(fname)
+      res = data[id].fann_train:save(fname)
+      rspamd_util.unlock_file(fd) -- Closes fd as well
+    end
+
+    if not res then
+      rspamd_logger.errx(cf, 'cannot save fann in %s', fname)
+    else
+      data[id].exist = true
+      data[id].ntrains = 0
+      data[id].epoch = data[id].epoch + 1
+    end
   else
-    local function redis_invalidate_cb(err, data)
+    if not data[id].checked then
+      data[id].checked = true
+      local err,st = rspamd_util.stat(fname)
       if err then
-        rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', id, err)
-      elseif type(data) == 'string' then
-        rspamd_logger.info(rspamd_config, 'invalidated ANN %s from redis: %s', id, err)
+        data[id].exist = false
       end
     end
-    -- Invalidate ANN
-    rspamd_logger.infox('invalidate ANN %s')
-    redis_make_request(ev_base,
-      rspamd_config,
-      nil,
-      true, -- is write
-      redis_invalidate_cb, --callback
-      'EVALSHA', -- command
-      {redis_maybe_invalidate_sha, 1, fann_prefix .. id}
-    )
+    if not data[id].exist then
+      rspamd_logger.infox(cf, 'not enough trains for fann %s, %s left', fname,
+        max_trains - data[id].ntrains)
+    end
   end
-end
 
-local function fann_train_callback(score, required_score, results, cf, id, opts, extra, ev_base)
-  local fname = gen_fann_prefix(id)
+  if data[id].epoch > max_epoch then
+    -- Re-create fann
+    rspamd_logger.infox(cf, 'create new fann in %s after %s epoches', fname,
+      max_epoch)
+    create_train_fann(n, id)
+  end
 
   local learn_spam, learn_ham = false, false
-
   if opts['spam_score'] then
     learn_spam = score >= opts['spam_score']
   else
@@ -510,157 +275,21 @@ local function fann_train_callback(score, required_score, results, cf, id, opts,
   end
 
   if learn_spam or learn_ham then
-    local k
-    if learn_spam then k = 'spam' else k = 'ham' end
-
-    local function learn_vec_cb(err, data)
-      if err then
-        rspamd_logger.errx(rspamd_config, 'cannot store train vector: %s', err)
-      end
-    end
-
-    local function can_train_cb(err, data)
-      if not err and tonumber(data) > 0 then
-        local learn_data = symbols_to_fann_vector(
-          map(function(r) return r[1] end, results),
-          map(function(r) return r[2] end, results)
-        )
-        -- Add filtered meta tokens
-        each(function(e) table.insert(learn_data, e) end, extra)
-        local str = rspamd_util.zstd_compress(table.concat(learn_data, ';'))
-
-        redis_make_request(ev_base,
-          rspamd_config,
-          nil,
-          true, -- is write
-          learn_vec_cb, --callback
-          'LPUSH', -- command
-          {fname .. '_' .. k, str} -- arguments
-        )
-      else
-        if err then
-          rspamd_logger.errx(rspamd_config, 'cannot check if we can train: %s', err)
-        end
-      end
-    end
-
-    redis_make_request(ev_base,
-      rspamd_config,
-      nil,
-      false, -- is write
-      can_train_cb, --callback
-      'EVALSHA', -- command
-      {redis_can_train_sha, '3', fname, tostring(max_trains), k} -- arguments
+    local learn_data = symbols_to_fann_vector(
+      map(function(r) return r[1] end, results),
+      map(function(r) return r[2] end, results)
     )
-  end
-end
-
-local function train_fann(cfg, ev_base, elt)
+    -- Add filtered meta tokens
+    each(function(e) table.insert(learn_data, e) end, extra)
 
-end
-
-local function maybe_train_fanns(cfg, ev_base)
-  local function members_cb(err, data)
-    if err then
-      rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', err)
-    elseif type(data) == 'table' then
-      each(function(i, elt)
-        local redis_len_cb = function(err, data)
-          if err then
-            rspamd_logger.errx(rspamd_config, 'cannot get FANN trains %s from redis: %s', elt, err)
-          elseif data and type(data) == 'number' or type(data) == 'string' then
-            if tonumber(data) and tonumber(data) > max_trains then
-              train_fann(cfg, ev_base, elt)
-            end
-          end
-        end
-
-        local local_ver = 0
-        local numelt = tonumber(elt)
-        if data[numelt] then
-          if data[numelt].version then
-            local_ver = data[numelt].version
-          end
-        end
-        redis_make_request(ev_base,
-          rspamd_config,
-          nil,
-          false, -- is write
-          redis_len_cb, --callback
-          'LLEN', -- command
-          {fann_prefix .. elt .. '_spam'}
-        )
-      end,
-      data)
-    end
-  end
-
-  if not redis_maybe_load_sha then
-    -- Plan new event early
-    return 1.0
-  end
-  -- First we need to get all fanns stored in our Redis
-  redis_make_request(ev_base,
-    rspamd_config,
-    nil,
-    false, -- is write
-    members_cb, --callback
-    'SMEMBERS', -- command
-    {fann_prefix} -- arguments
-  )
-
-  return watch_interval
-end
-
-local function check_fanns(cfg, ev_base)
-  local function members_cb(err, data)
-    if err then
-      rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', err)
-    elseif type(data) == 'table' then
-      each(function(i, elt)
-        local redis_update_cb = function(err, data)
-          if err then
-            rspamd_logger.errx(rspamd_config, 'cannot get FANN version %s from redis: %s', elt, err)
-          elseif data and type(data) == 'string' then
-            load_or_invalidate_fann(data, elt, ev_base)
-          end
-        end
-
-        local local_ver = 0
-        local numelt = tonumber(elt)
-        if data[numelt] then
-          if data[numelt].version then
-            local_ver = data[numelt].version
-          end
-        end
-        redis_make_request(ev_base,
-          rspamd_config,
-          nil,
-          false, -- is write
-          redis_update_cb, --callback
-          'EVALSHA', -- command
-          {redis_maybe_load_sha, 2, fann_prefix .. elt, tostring(local_ver)}
-        )
-      end,
-      data)
+    if learn_spam then
+      data[id].fann_train:train(learn_data, {1.0})
+    else
+      data[id].fann_train:train(learn_data, {-1.0})
     end
-  end
 
-  if not redis_maybe_load_sha then
-    -- Plan new event early
-    return 1.0
+    data[id].ntrains = data[id].ntrains + 1
   end
-  -- First we need to get all fanns stored in our Redis
-  redis_make_request(ev_base,
-    rspamd_config,
-    nil,
-    false, -- is write
-    members_cb, --callback
-    'SMEMBERS', -- command
-    {fann_prefix} -- arguments
-  )
-
-  return watch_interval
 end
 
 -- Initialization part
@@ -674,128 +303,71 @@ end
 if not rspamd_fann.is_enabled() then
   rspamd_logger.errx(rspamd_config, 'fann is not compiled in rspamd, this ' ..
     'module is eventually disabled')
+
   return
 else
-  use_settings = opts['use_settings']
-  rspamd_config:set_metric_symbol({
-    name = fann_symbol_spam,
-    score = 3.0,
-    description = 'Neural network SPAM',
-    group = 'fann'
-  })
-  local id = rspamd_config:register_symbol({
-    name = fann_symbol_spam,
-    type = 'postfilter',
-    priority = 5,
-    callback = fann_scores_filter
-  })
-  rspamd_config:set_metric_symbol({
-    name = fann_symbol_ham,
-    score = -2.0,
-    description = 'Neural network HAM',
-    group = 'fann'
-  })
-  rspamd_config:register_symbol({
-    name = fann_symbol_ham,
-    type = 'virtual',
-    parent = id
-  })
-  if opts['train'] then
-    rspamd_config:add_on_load(function(cfg)
-      if opts['train']['max_train'] then
-        max_trains = opts['train']['max_train']
-      end
-      if opts['train']['max_epoch'] then
-        max_epoch = opts['train']['max_epoch']
-      end
-      local ret = cfg:register_worker_script("log_helper",
-        function(score, req_score, results, cf, id, extra, ev_base)
-          -- map (snd x) (filter (fst x == module_id) extra)
-          local extra_fann = map(function(e) return e[2] end,
-            filter(function(e) return e[1] == module_log_id end, extra))
-          if use_settings then
-            fann_train_callback(score, req_score, results, cf,
-              tostring(id), opts['train'], extra_fann, ev_base)
-          else
-            fann_train_callback(score, req_score, results, cf, '0',
-              opts['train'], extra_fann, ev_base)
-          end
+  if not opts['fann_file'] then
+    rspamd_logger.warnx(rspamd_config, 'fann_scores module requires ' ..
+      '`fann_file` to be specified')
+  else
+    fann_file = opts['fann_file']
+    use_settings = opts['use_settings']
+    rspamd_config:set_metric_symbol({
+      name = fann_symbol_spam,
+      score = 3.0,
+      description = 'Neural network SPAM',
+      group = 'fann'
+    })
+    local id = rspamd_config:register_symbol({
+      name = fann_symbol_spam,
+      type = 'postfilter',
+      priority = 5,
+      callback = fann_scores_filter
+    })
+    rspamd_config:set_metric_symbol({
+      name = fann_symbol_ham,
+      score = -2.0,
+      description = 'Neural network HAM',
+      group = 'fann'
+    })
+    rspamd_config:register_symbol({
+      name = fann_symbol_ham,
+      type = 'virtual',
+      parent = id
+    })
+    if opts['train'] then
+      rspamd_config:add_on_load(function(cfg)
+        if opts['train']['max_train'] then
+          max_trains = opts['train']['max_train']
+        end
+        if opts['train']['max_epoch'] then
+          max_epoch = opts['train']['max_epoch']
+        end
+        local ret = cfg:register_worker_script("log_helper",
+          function(score, req_score, results, cf, id, extra)
+            -- map (snd x) (filter (fst x == module_id) extra)
+            local extra_fann = map(function(e) return e[2] end,
+              filter(function(e) return e[1] == module_log_id end, extra))
+            if use_settings then
+              fann_train_callback(score, req_score, results, cf,
+                tostring(id), opts['train'], extra_fann)
+            else
+              fann_train_callback(score, req_score, results, cf, '0',
+                opts['train'], extra_fann)
+            end
         end)
 
-      if not ret then
-        rspamd_logger.errx(cfg, 'cannot find worker "log_helper"')
-      end
-    end)
-    -- This is needed to pass extra tokens from worker to log_helper
-    rspamd_plugins["fann_score"] = {
-      log_callback = function(task)
-        return totable(map(
-          function(tok) return {module_log_id, tok} end,
-          gen_metatokens(task)))
-      end
-    }
-  end
-  -- Add training scripts
-  rspamd_config:add_on_load(function(cfg, ev_base, worker)
-    local function can_train_sha_cb(err, data)
-      if err or not data or type(data) ~= 'string' then
-        rspamd_logger.errx(cfg, 'cannot save redis train script: %s', err)
-      else
-        redis_can_train_sha = tostring(data)
-      end
-    end
-    redis_make_request(ev_base,
-      rspamd_config,
-      nil,
-      true, -- is write
-      can_train_sha_cb, --callback
-      'SCRIPT', -- command
-      {'LOAD', redis_lua_script_can_train} -- arguments
-    )
-
-    local function maybe_load_sha_cb(err, data)
-      if err or not data or type(data) ~= 'string' then
-        rspamd_logger.errx(cfg, 'cannot save redis load script: %s', err)
-      else
-        redis_maybe_load_sha = tostring(data)
-
-        rspamd_config:add_periodic(ev_base, 0.0,
-          function(cfg, ev_base)
-            return check_fanns(cfg, ev_base)
-          end)
-      end
-    end
-    redis_make_request(ev_base,
-      rspamd_config,
-      nil,
-      true, -- is write
-      maybe_load_sha_cb, --callback
-      'SCRIPT', -- command
-      {'LOAD', redis_lua_script_maybe_load} -- arguments
-    )
-
-    local function maybe_invalidate_sha_cb(err, data)
-      if err or not data or type(data) ~= 'string' then
-        rspamd_logger.errx(cfg, 'cannot save redis invalidate script: %s', err)
-      else
-        redis_maybe_invalidate_sha = tostring(data)
-      end
-    end
-    redis_make_request(ev_base,
-      rspamd_config,
-      nil,
-      true, -- is write
-      maybe_invalidate_sha_cb, --callback
-      'SCRIPT', -- command
-      {'LOAD', redis_lua_script_maybe_invalidate} -- arguments
-    )
-
-    if worker:get_name() == 'normal' then
-      -- We also want to train neural nets when they have enough data
-      rspamd_config:add_periodic(ev_base, 0.0,
-        function(cfg, ev_base)
-          return maybe_train_fanns(cfg, ev_base)
-        end)
+        if not ret then
+          rspamd_logger.errx(cfg, 'cannot find worker "log_helper"')
+        end
+      end)
+      rspamd_plugins["fann_score"] = {
+        log_callback = function(task)
+          return totable(map(
+            function(tok) return {module_log_id, tok} end,
+            rspamd_gen_metatokens(task)))
+        end
+      }
     end
-  end)
+  end
 end