]> source.dussan.org Git - rspamd.git/commitdiff
[Minor] Try to improve tokens expiration logic
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 12 Mar 2018 13:13:02 +0000 (13:13 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 12 Mar 2018 13:13:02 +0000 (13:13 +0000)
src/plugins/lua/bayes_expiry.lua

index 725a163c760f613e823af791abb2d5824f058a20..c8ff4d0ba8acf5bbf4470aa8061fa54e1bcc3ff2 100644 (file)
@@ -28,7 +28,6 @@ local lredis = require "lua_redis"
 local settings = {
   interval = 60, -- one iteration step per minute
   count = 1000, -- check up to 1000 keys on each iteration
-  threshold = 10, -- require at least 10 occurrences to increase expire
   epsilon_common = 0.01, -- eliminate common if spam to ham rate is equal to this epsilon
   common_ttl_divisor = 10, -- how should we discriminate common elements
   significant_factor = 3.0 / 4.0, -- which tokens should we update
@@ -154,48 +153,69 @@ local expiry_script = [[
   local nelts = 0
   local extended = 0
   local discriminated = 0
+  local tokens = {}
+  local sum, sum_squares = 0, 0
 
   for _,key in ipairs(keys) do
     local values = redis.call('HMGET', key, 'H', 'S')
     local ham = tonumber(values[1]) or 0
     local spam = tonumber(values[2]) or 0
+    local ttl = redis.call('TTL', key)
+    tokens[key] = {
+      ham,
+      spam,
+      ttl
+    }
+    local total = spam + ham
+    sum = sum + total
+    sum_squares = sum_squares + total * total
+    nelts = nelts + 1
+  end
+  redis.replicate_commands()
+
+  local mean = sum / nelts
+  local stddev = math.sqrt(sum_squares / nelts - mean * mean)
+
+  for key,token in pairs(tokens) do
+    local ham, spam, ttl = token[1], token[2], token[3]
+    local threshold = mean
+    local total = spam + ham
 
-    if ham > ${threshold} or spam > ${threshold} then
-      local total = ham + spam
+    if total >= threshold and total > 0 then
       if ham / total > ${significant_factor} or spam / total > ${significant_factor} then
-        redis.replicate_commands()
         redis.call('EXPIRE', key, math.floor(KEYS[2]))
         extended = extended + 1
-      elseif math.abs(ham - spam) <= total * ${epsilon_common} then
-        local ttl = redis.call('TTL', key)
-        redis.replicate_commands()
-        redis.call('EXPIRE', key, math.floor(tonumber(ttl) / ${common_ttl_divisor}))
-        discriminated = discriminated + 1
       end
     end
-    nelts = nelts + 1
+    if total == 0 or math.abs(ham - spam) <= total * ${epsilon_common} then
+      discriminated = discriminated + 1
+      redis.call('EXPIRE', key, math.floor(tonumber(ttl) / ${common_ttl_divisor}))
+    end
   end
 
-  return {next, nelts, extended, discriminated}
+  return {next, nelts, extended, discriminated, mean, stddev}
 ]]
 
 local cur = 0
 
 local function expire_step(cls, ev_base, worker)
-
   local function redis_step_cb(err, data)
     if err then
       logger.errx(rspamd_config, 'cannot perform expiry step: %s', err)
     elseif type(data) == 'table' then
-      local next,nelts,extended,discriminated = tonumber(data[1]), tonumber(data[2]),
-        tonumber(data[3]),tonumber(data[4])
+      local next,nelts,extended,discriminated,mean,stddev = tonumber(data[1]),
+        tonumber(data[2]),
+        tonumber(data[3]),
+        tonumber(data[4]),
+        tonumber(data[5]),
+        tonumber(data[6])
 
       if next ~= 0 then
-        logger.infox(rspamd_config, 'executed expiry step for bayes: %s items checked, %s extended, %s discriminated',
-            nelts, extended, discriminated)
+        logger.infox(rspamd_config, 'executed expiry step for bayes: %s items checked, %s extended, %s discriminated, %s mean, %s std',
+            nelts, extended, discriminated, mean, stddev)
       else
-        logger.infox(rspamd_config, 'executed final expiry step for bayes: %s items checked, %s extended, %s discriminated',
-            nelts, extended, discriminated)
+        logger.infox(rspamd_config, 'executed final expiry step for bayes: %s items checked, %s extended, %s discriminated, %s mean, %s std',
+            nelts, extended, discriminated, mean, stddev)
       end
 
       cur = next