]> source.dussan.org Git - rspamd.git/commitdiff
[Feature] New bayes expiry plugin
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 24 Feb 2018 19:09:20 +0000 (19:09 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 24 Feb 2018 19:09:20 +0000 (19:09 +0000)
src/plugins/lua/bayes_expiry.lua

index d922f3f55839003872cb8d7eb3a03651e7273cea..af955465dfed92a1869284b49f9d98073e0ecc37 100644 (file)
@@ -20,280 +20,232 @@ if confighelp then
 end
 
 local N = 'bayes_expiry'
+local E = {}
 local logger = require "rspamd_logger"
-local mempool = require "rspamd_mempool"
-local util = require "rspamd_util"
 local lutil = require "lua_util"
 local lredis = require "lua_redis"
 
-local pool = mempool.create()
 local settings = {
-  interval = 604800,
-  statefile = string.format('%s/%s', rspamd_paths['DBDIR'], 'bayes_expired'),
-  variables = {
-    ot_bayes_ttl = 31536000, -- one year
-    ot_min_age = 7776000, -- 90 days
-    ot_min_count = 5,
-  },
-  symbols = {},
-  timeout = 60,
+  interval = 60, -- one iteration step per minute
+  count = 1000, -- check up to 1000 keys on each iteration
+  threshold = 10, -- require at least 10 occurrences to increase expire
+  epsilon_common = 0.01, -- eliminate common if spam to ham rate is equal to this epsilon
+  common_ttl_divisor = 100, -- how should we discriminate common elements
+  significant_factor = 3.0 / 4.0, -- which tokens should we update
+  classifiers = {},
 }
 
-local VAR_NAME = 'bayes_expired'
-local EXPIRE_SCRIPT_TMPL = [[local result = {}
-local OT_BAYES_TTL = ${ot_bayes_ttl}
-local OT_MIN_AGE = ${ot_min_age}
-local OT_MIN_COUNT = ${ot_min_count}
-local symbol = ARGV[1]
-local prefixes = redis.call('SMEMBERS', symbol .. '_keys')
-for _, pfx in ipairs(prefixes) do
-  local res = redis.call('SCAN', '0', 'MATCH', pfx .. '_*')
-  local cursor, data = res[1], res[2]
-  while data do
-    local key_name = table.remove(data)
-    if key_name then
-      local h, s = redis.call('HMGET', key_name, 'H', 'S')
-      if (h or s) then
-        if not s then s = 0 else s = tonumber(s) end
-        if not h then h = 0 else h = tonumber(h) end
-        if s < OT_MIN_COUNT and h < OT_MIN_COUNT then
-          local ttl = redis.call('TTL', key_name)
-          if ttl > 0 then
-            local age = OT_BAYES_TTL - ttl
-            if age > OT_MIN_AGE then
-              table.insert(result, key_name)
-            end
-          end
+local template = {
+
+}
+
+local function check_redis_classifier(cls, cfg)
+  -- Skip old classifiers
+  if cls.new_schema then
+    local symbol_spam, symbol_ham
+    local expiry = (cls.expiry or cls.expire)
+    -- Load symbols from statfiles
+    local statfiles = cls.statfile
+    for _,stf in ipairs(statfiles) do
+      local symbol = stf.symbol or 'undefined'
+
+      local spam
+      if stf.spam then
+        spam = stf.spam
+      else
+        if string.match(symbol:upper(), 'SPAM') then
+          spam = true
+        else
+          spam = false
         end
       end
-    else
-      if cursor == "0" then
-        data = nil
+
+      if spam then
+        symbol_spam = symbol
       else
-        local res = redis.call('SCAN', tostring(cursor), 'MATCH', pfx .. '_*')
-        cursor, data = res[1], res[2]
+        symbol_ham = symbol
+      end
+    end
+
+    if not symbol_spam or not symbol_ham or not expiry then
+      return
+    end
+    -- Now try to load redis_params if needed
+
+    local redis_params = {}
+    if not lredis.try_load_redis_servers(cls, rspamd_config, redis_params) then
+      if not lredis.try_load_redis_servers(cfg[N] or E, rspamd_config, redis_params) then
+        if not lredis.try_load_redis_servers(cfg['redis'] or E, rspamd_config, redis_params) then
+          return false
+        end
       end
     end
+
+    table.insert(settings.classifiers, {
+      symbol_spam = symbol_spam,
+      symbol_ham = symbol_ham,
+      redis_params = redis_params,
+      expiry = expiry
+    })
   end
 end
-return table.concat(result, string.char(31))]]
 
-local function configure_bayes_expiry()
-  local opts = rspamd_config:get_all_opt(N)
-  if not type(opts) == 'table' then return false end
-  for k, v in pairs(opts) do
-    settings[k] = v
+-- Check classifiers and try find the appropriate ones
+local obj = rspamd_config:get_ucl()
+
+local classifier = obj.classifier
+
+if classifier then
+  if classifier[1] then
+    for _,cls in ipairs(classifier) do
+      if cls.bayes then cls = cls.bayes end
+      if cls.backend and cls.backend == 'redis' then
+        check_redis_classifier(cls, obj)
+      end
+    end
+  else
+    if classifier.bayes then
+
+      classifier = classifier.bayes
+      if classifier[1] then
+        for _,cls in ipairs(classifier) do
+          if cls.backend and cls.backend == 'redis' then
+            check_redis_classifier(cls, obj)
+          end
+        end
+      else
+        if classifier.backend and classifier.backend == 'redis' then
+          check_redis_classifier(classifier, obj)
+        end
+      end
+    end
   end
-  if not settings.symbols[1] then
-    logger.warn('No symbols configured, not enabling expiry')
-    return false
+end
+
+
+local opts = rspamd_config:get_all_opt(N)
+
+if opts then
+  for k,v in pairs(opts) do
+    settings[k] = v
   end
-  return true
 end
 
-if not configure_bayes_expiry() then
-  lutil.disable_module(N, 'config')
-  return
+-- Fill template
+template.count = settings.count
+template.threshold = settings.threshold
+template.common_ttl_divisor = settings.common_ttl_divisor
+template.epsilon_common = settings.epsilon_common
+template.significant_factor = settings.significant_factor
+
+for k,v in pairs(template) do
+  template[k] = tostring(v)
 end
 
-local function get_redis_params(ev_base, symbol)
-  local redis_params
-  local copts = rspamd_config:get_all_opt('classifier')
-  if not type(copts) == 'table' then
-    logger.errx(ev_base, "Couldn't get classifier configuration")
-    return
-  end
-  if type(copts.backend) == 'table' then
-    redis_params = lredis.rspamd_parse_redis_server(nil, copts.backend, true)
+-- Arguments:
+-- [1] = symbol pattern
+-- [2] = expire value
+-- [3] = cursor
+-- returns new cursor
+local expiry_script = [[
+  local ret = redis.call('SCAN', KEYS[3], 'MATCH', KEYS[1], 'COUNT', '${count}')
+  local next = ret[1]
+  local keys = ret[2]
+  local nelts = 0
+  local extended = 0
+  local discriminated = 0
+
+  for _,key in ipairs(keys) do
+    local values = redis.call('HMGET', key, 'H', 'S')
+    local ham = tonumber(values[1]) or 0
+    local spam = tonumber(values[2]) or 0
+
+    if ham > ${threshold} or spam > ${threshold} then
+      local total = ham + spam
+
+      if total > 0 then
+        if ham / total > ${significant_factor} or spam / total > ${significant_factor} then
+          redis.replicate_commands()
+          redis.call('EXPIRE', key, KEYS[2])
+          extended = extended + 1
+        elseif math.abs(ham - spam) <= total * ${epsilon_common} then
+          local ttl = redis.call('TTL', key)
+          redis.replicate_commands()
+          redis.call('EXPIRE', key, tonumber(ttl) / ${common_ttl_divisor})
+          discriminated = discriminated + 1
+        end
+      end
+    end
+    nelts = nelts + 1
   end
-  if redis_params then return redis_params end
-  if type(copts.statfile) == 'table' then
-    for _, stf in ipairs(copts.statfile) do
-      if stf.name == symbol then
-        redis_params = lredis.rspamd_parse_redis_server(nil, copts.backend, true)
+
+  return {next, nelts, extended, discriminated}
+]]
+
+local cur = 0
+
+local function expire_step(cls, ev_base, worker)
+
+  local function redis_step_cb(err, data)
+    if err then
+      logger.errx(rspamd_config, 'cannot perform expiry step: %s', err)
+    elseif type(data) == 'table' then
+      local next,nelts,extended,discriminated = tonumber(data[1]), tonumber(data[2]),
+        tonumber(data[3]),tonumber(data[4])
+
+      if next ~= 0 then
+        logger.infox(rspamd_config, 'executed expiry step for bayes: %s items checked, %s extended, %s discriminated',
+            nelts, extended, discriminated)
+      else
+        logger.infox(rspamd_config, 'executed final expiry step for bayes: %s items checked, %s extended, %s discriminated',
+            nelts, extended, discriminated)
       end
+
+      cur = next
     end
   end
-  if redis_params then return redis_params end
-  redis_params = lredis.rspamd_parse_redis_server(nil, copts, false)
-  redis_params.timeout = settings.timeout
-  return redis_params
+  lredis.exec_redis_script(cls.script,
+      {ev_base = ev_base, is_write = true},
+      redis_step_cb,
+      {'RS*_*', cls.expiry, cur}
+  )
 end
 
 rspamd_config:add_on_load(function (_, ev_base, worker)
-  local processed_symbols, expire_script_sha
   -- Exit unless we're the first 'controller' worker
   if not (worker:get_name() == 'controller' and worker:get_index() == 0) then return end
-  -- Persist mempool variable to statefile on shutdown
-  rspamd_config:register_finish_script(function ()
-    local stamp = pool:get_variable(VAR_NAME, 'double')
-    if not stamp then
-      logger.warnx(ev_base, 'No last bayes expiry to persist to disk')
-      return
-    end
-    local f, err = io.open(settings['statefile'], 'w')
-    if err then
-      logger.errx(ev_base, 'Unable to write statefile to disk: %s', err)
-      return
-    end
-    if f then
-      f:write(pool:get_variable(VAR_NAME, 'double'))
-      f:close()
-    end
-  end)
-  local expire_symbol
-  local function load_scripts(redis_params, cont, p1, p2)
-    local function load_script_cb(err, data)
-      if err then
-        logger.errx(ev_base, 'Error loading script: %s', err)
-      else
-        if type(data) == 'string' then
-          expire_script_sha = data
-          logger.debugm(N, ev_base, 'expire_script_sha: %s', expire_script_sha)
-          if type(cont) == 'function' then
-            cont(p1, p2)
-          end
-        end
+
+  local unique_redis_params = {}
+  -- Push redis script to all unique redis servers
+  for _,cls in ipairs(settings.classifiers) do
+    local seen = false
+    for _,rp in ipairs(unique_redis_params) do
+      if lutil.table_cmp(rp, cls.redis_params) then
+        seen = true
       end
     end
-    local scripttxt = lutil.template(EXPIRE_SCRIPT_TMPL, settings.variables)
-    local ret = lredis.redis_make_request_taskless(ev_base,
-      rspamd_config,
-      redis_params,
-      nil,
-      true, -- is write
-      load_script_cb, --callback
-      'SCRIPT', -- command
-      {'LOAD', scripttxt}
-    )
-    if not ret then
-      logger.errx(ev_base, 'Error loading script')
-    end
-  end
-  local function continue_expire()
-    for _, symbol in ipairs(settings.symbols) do
-      if not processed_symbols[symbol] then
-        local redis_params = get_redis_params(ev_base, symbol)
-        if not redis_params then
-          processed_symbols[symbol] = true
-          logger.errx(ev_base, "Couldn't get redis params")
-        else
-          load_scripts(redis_params, expire_symbol, redis_params, symbol)
-          break
-        end
-      end
+
+    if not seen then
+      table.insert(unique_redis_params, cls.redis_params)
     end
   end
-  expire_symbol = function(redis_params, symbol)
-    local function del_keys_cb(err, data)
-      if err then
-        logger.errx(ev_base, 'Redis request failed: %s', err)
-      end
-      processed_symbols[symbol] = true
-      continue_expire()
-    end
-    local function get_keys_cb(err, data)
-      if err then
-        logger.errx(ev_base, 'Redis request failed: %s', err)
-        processed_symbols[symbol] = true
-        continue_expire()
-      else
-        if type(data) == 'string' then
-          if data == "" then
-            data = {}
-          else
-            data = lutil.rspamd_str_split(data, string.char(31))
-          end
-        end
-        if type(data) == 'table' then
-          if not data[1] then
-            logger.warnx(ev_base, 'No keys to delete: %s', symbol)
-            processed_symbols[symbol] = true
-            continue_expire()
-          else
-            local ret = lredis.redis_make_request_taskless(ev_base,
-              rspamd_config,
-              redis_params,
-              nil,
-              true, -- is write
-              del_keys_cb, --callback
-              'DEL', -- command
-              data
-            )
-            if not ret then
-              logger.errx(ev_base, 'Redis request failed')
-              processed_symbols[symbol] = true
-              continue_expire()
-            end
-          end
-        else
-          logger.warnx(ev_base, 'No keys to delete: %s', symbol)
-          processed_symbols[symbol] = true
-          continue_expire()
-        end
+
+  for _,rp in ipairs(unique_redis_params) do
+    local script_id = lredis.add_redis_script(lutil.template(expiry_script,
+        template), rp)
+
+    for _,cls in ipairs(settings.classifiers) do
+      if lutil.table_cmp(rp, cls.redis_params) then
+        cls.script = script_id
       end
     end
-    local ret = lredis.redis_make_request_taskless(ev_base,
-      rspamd_config,
-      redis_params,
-      nil,
-      false, -- is write
-      get_keys_cb, --callback
-      'EVALSHA', -- command
-      {expire_script_sha, 0, symbol}
-    )
-    if not ret then
-      logger.errx(ev_base, 'Redis request failed')
-      processed_symbols[symbol] = true
-      continue_expire()
-    end
-  end
-  local function begin_expire(time)
-    local stamp = time or util.get_time()
-    pool:set_variable(VAR_NAME, stamp)
-    processed_symbols = {}
-    continue_expire()
   end
+
   -- Expire tokens at regular intervals
-  local function schedule_regular_expiry()
+  for _,cls in ipairs(settings.classifiers) do
     rspamd_config:add_periodic(ev_base, settings['interval'], function ()
-      begin_expire()
+      expire_step(cls, ev_base, worker)
       return true
     end)
   end
-  -- Expire tokens and reschedule expiry
-  local function schedule_intermediate_expiry(when)
-    rspamd_config:add_periodic(ev_base, when, function ()
-      begin_expire()
-      schedule_regular_expiry()
-      return false
-    end)
-  end
-  -- Try read statefile on startup
-  local stamp
-  local f, err = io.open(settings['statefile'], 'r')
-  if err then
-    logger.warnx(ev_base, 'Failed to open statefile: %s', err)
-  end
-  if f then
-    io.input(f)
-    stamp = tonumber(io.read())
-    pool:set_variable(VAR_NAME, stamp)
-  end
-  local time = util.get_time()
-  if not stamp then
-    logger.debugm(N, ev_base, 'No state found - expiring stats immediately')
-    begin_expire(time)
-    schedule_regular_expiry()
-    return
-  end
-  local delta = stamp - time + settings['interval']
-  if delta <= 0 then
-    logger.debugm(N, ev_base, 'Last expiry is too old - expiring stats immediately')
-    begin_expire(time)
-    schedule_regular_expiry()
-    return
-  end
-  logger.debugm(N, ev_base, 'Scheduling next expiry in %s seconds', delta)
-  schedule_intermediate_expiry(delta)
 end)