]> source.dussan.org Git - rspamd.git/commitdiff
[Feature] Bayes expiry plugin 1670/head
authorAndrew Lewis <nerf@judo.za.org>
Sat, 10 Jun 2017 13:45:29 +0000 (15:45 +0200)
committerAndrew Lewis <nerf@judo.za.org>
Sat, 10 Jun 2017 13:45:29 +0000 (15:45 +0200)
lualib/lua_redis.lua
lualib/lua_util.lua
src/plugins/lua/arc.lua
src/plugins/lua/bayes_expiry.lua [new file with mode: 0644]
src/plugins/lua/dkim_signing.lua
src/plugins/lua/metadata_exporter.lua

index 42a0aacefc6aef76687075951f9908158cd48a5a..0dc5872fe4b90897e7f4107fb170cf0ef0cd43b2 100644 (file)
@@ -5,7 +5,7 @@ local exports = {}
 -- This function parses redis server definition using either
 -- specific server string for this module or global
 -- redis section
-local function rspamd_parse_redis_server(module_name)
+local function rspamd_parse_redis_server(module_name, module_opts, no_fallback)
 
   local result = {}
   local default_port = 6379
@@ -71,7 +71,12 @@ local function rspamd_parse_redis_server(module_name)
   end
 
   -- Try local options
-  local opts = rspamd_config:get_all_opt(module_name)
+  local opts
+  if not module_opts then
+    opts = rspamd_config:get_all_opt(module_name)
+  else
+    opts = module_opts
+  end
   local ret = false
 
   if opts then
@@ -82,6 +87,8 @@ local function rspamd_parse_redis_server(module_name)
     return result
   end
 
+  if no_fallback then return nil end
+
   -- Try global options
   opts = rspamd_config:get_all_opt('redis')
 
index 0a824dca14942aa23e48e11715bf738f090b429c..1f53d51edf4148e4b8e815460de2b7d051c3038c 100644 (file)
@@ -30,4 +30,14 @@ exports.round = function(num, numDecimalPlaces)
   return math.floor(num * mult) / mult
 end
 
+exports.template = function(tmpl, keys)
+  local var_lit = lpeg.P { lpeg.R("az") + lpeg.R("AZ") + lpeg.R("09") + "_" }
+  local var = lpeg.P { (lpeg.P("$") / "") * ((var_lit^1) / keys) }
+  local var_braced = lpeg.P { (lpeg.P("${") / "") * ((var_lit^1) / keys) * (lpeg.P("}") / "") }
+
+  local template_grammar = lpeg.Cs((var + var_braced + 1)^0)
+
+  return lpeg.match(template_grammar, tmpl)
+end
+
 return exports
index fc1c65769059a5cad7e1df3e508b9d411e755b3c..a29417d748fd94be4e8cc3b4a1a8cdfc740d9526 100644 (file)
@@ -335,19 +335,6 @@ rspamd_config:register_symbol({
 rspamd_config:register_dependency(id, symbols['spf_allow_symbol'])
 rspamd_config:register_dependency(id, symbols['dkim_allow_symbol'])
 
--- Signatures part
-local function simple_template(tmpl, keys)
-  local lpeg = require "lpeg"
-
-  local var_lit = lpeg.P { lpeg.R("az") + lpeg.R("AZ") + lpeg.R("09") + "_" }
-  local var = lpeg.P { (lpeg.P("$") / "") * ((var_lit^1) / keys) }
-  local var_braced = lpeg.P { (lpeg.P("${") / "") * ((var_lit^1) / keys) * (lpeg.P("}") / "") }
-
-  local template_grammar = lpeg.Cs((var + var_braced + 1)^0)
-
-  return lpeg.match(template_grammar, tmpl)
-end
-
 local function arc_sign_seal(task, params, header)
   local arc_sigs = task:cache_get('arc-sigs')
   local arc_seals = task:cache_get('arc-seals')
@@ -514,7 +501,7 @@ local function arc_signing_cb(task)
     end
   else
     if (p.key and p.selector) then
-      p.key = simple_template(p.key, {domain = p.domain, selector = p.selector})
+      p.key = lua_util.template(p.key, {domain = p.domain, selector = p.selector})
       local dret, hdr = dkim_sign(task, p)
       if dret then
         return arc_sign_seal(task, p, hdr)
diff --git a/src/plugins/lua/bayes_expiry.lua b/src/plugins/lua/bayes_expiry.lua
new file mode 100644 (file)
index 0000000..d7df264
--- /dev/null
@@ -0,0 +1,296 @@
+--[[
+Copyright (c) 2017, Andrew Lewis <nerf@judo.za.org>
+Copyright (c) 2017, Vsevolod Stakhov <vsevolod@highsecure.ru>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]] --
+
+if confighelp then
+  return
+end
+
+local N = 'bayes_expiry'
+local logger = require "rspamd_logger"
+local mempool = require "rspamd_mempool"
+local util = require "rspamd_util"
+local lutil = require "lua_util"
+local lredis = require "lua_redis"
+
+local pool = mempool.create()
+local settings = {
+  interval = 604800,
+  statefile = string.format('%s/%s', rspamd_paths['DBDIR'], 'bayes_expired'),
+  variables = {
+    ot_bayes_ttl = 31536000, -- one year
+    ot_min_age = 7776000, -- 90 days
+    ot_min_count = 5,
+  },
+  symbols = {},
+  timeout = 60,
+}
+
+local VAR_NAME = 'bayes_expired'
+local EXPIRE_SCRIPT_TMPL = [[local result = {}
+local OT_BAYES_TTL = ${ot_bayes_ttl}
+local OT_MIN_AGE = ${ot_min_age}
+local OT_MIN_COUNT = ${ot_min_count}
+local symbol = ARGV[1]
+local prefixes = redis.call('SMEMBERS', symbol .. '_keys')
+for _, pfx in ipairs(prefixes) do
+  local res = redis.call('SCAN', '0', 'MATCH', pfx .. '_*')
+  local cursor, data = res[1], res[2]
+  while data do
+    local key_name = table.remove(data)
+    if key_name then
+      local h, s = redis.call('HMGET', key_name, 'H', 'S')
+      if (h or s) then
+        if not s then s = 0 else s = tonumber(s) end
+        if not h then h = 0 else h = tonumber(h) end
+        if s < OT_MIN_COUNT and h < OT_MIN_COUNT then
+          local ttl = redis.call('TTL', key_name)
+          if ttl > 0 then
+            local age = OT_BAYES_TTL - ttl
+            if age > OT_MIN_AGE then
+              table.insert(result, key_name)
+            end
+          end
+        end
+      end
+    else
+      if cursor == "0" then
+        data = nil
+      else
+        local res = redis.call('SCAN', tostring(cursor), 'MATCH', pfx .. '_*')
+        cursor, data = res[1], res[2]
+      end
+    end
+  end
+end
+return table.concat(result, string.char(31))]]
+
+local function configure_bayes_expiry()
+  local opts = rspamd_config:get_all_opt(N)
+  if not type(opts) == 'table' then return false end
+  for k, v in pairs(opts) do
+    settings[k] = v
+  end
+  if not settings.symbols[1] then
+    logger.warn('No symbols configured, not enabling expiry')
+    return false
+  end
+  return true
+end
+
+if not configure_bayes_expiry() then return end
+
+local function get_redis_params(ev_base, symbol)
+  local redis_params
+  local copts = rspamd_config:get_all_opt('classifier')
+  if not type(copts) == 'table' then
+    logger.errx(ev_base, "Couldn't get classifier configuration")
+    return
+  end
+  if type(copts.backend) == 'table' then
+    redis_params = lredis.rspamd_parse_redis_server(nil, copts.backend, true)
+  end
+  if redis_params then return redis_params end
+  if type(copts.statfile) == 'table' then
+    for _, stf in ipairs(copts.statfile) do
+      if stf.name == symbol then
+        redis_params = lredis.rspamd_parse_redis_server(nil, copts.backend, true)
+      end
+    end
+  end
+  if redis_params then return redis_params end
+  redis_params = lredis.rspamd_parse_redis_server(nil, copts, false)
+  redis_params.timeout = settings.timeout
+  return redis_params
+end
+
+rspamd_config:add_on_load(function (_, ev_base, worker)
+  local processed_symbols, expire_script_sha
+  -- Exit unless we're the first 'normal' worker
+  if not (worker:get_name() == 'normal' and worker:get_index() == 0) then return end
+  -- Persist mempool variable to statefile on shutdown
+  rspamd_config:register_finish_script(function ()
+    local stamp = pool:get_variable(VAR_NAME, 'double')
+    if not stamp then
+      logger.warnx(ev_base, 'No last bayes expiry to persist to disk')
+      return
+    end
+    local f, err = io.open(settings['statefile'], 'w')
+    if err then
+      logger.errx(ev_base, 'Unable to write statefile to disk: %s', err)
+      return
+    end
+    if f then
+      f:write(pool:get_variable(VAR_NAME, 'double'))
+      f:close()
+    end
+  end)
+  local expire_symbol
+  local function load_scripts(redis_params, cont, p1, p2)
+    local function load_script_cb(err, data)
+      if err then
+        logger.errx(ev_base, 'Error loading script: %s', err)
+      else
+        if type(data) == 'string' then
+          expire_script_sha = data
+          logger.debugm(N, ev_base, 'expire_script_sha: %s', expire_script_sha)
+          if type(cont) == 'function' then
+            cont(p1, p2)
+          end
+        end
+      end
+    end
+    local scripttxt = lutil.template(EXPIRE_SCRIPT_TMPL, settings.variables)
+    local ret = lredis.redis_make_request_taskless(ev_base,
+      rspamd_config,
+      redis_params,
+      nil,
+      true, -- is write
+      load_script_cb, --callback
+      'SCRIPT', -- command
+      {'LOAD', scripttxt}
+    )
+    if not ret then
+      logger.errx(ev_base, 'Error loading script')
+    end
+  end
+  local function continue_expire()
+    for _, symbol in ipairs(settings.symbols) do
+      if not processed_symbols[symbol] then
+        local redis_params = get_redis_params(ev_base, symbol)
+        if not redis_params then
+          processed_symbols[symbol] = true
+          logger.errx(ev_base, "Couldn't get redis params")
+        else
+          load_scripts(redis_params, expire_symbol, redis_params, symbol)
+          break
+        end
+      end
+    end
+  end
+  expire_symbol = function(redis_params, symbol)
+    local function del_keys_cb(err, data)
+      if err then
+        logger.errx(ev_base, 'Redis request failed: %s', err)
+      end
+      processed_symbols[symbol] = true
+      continue_expire()
+    end
+    local function get_keys_cb(err, data)
+      if err then
+        logger.errx(ev_base, 'Redis request failed: %s', err)
+        processed_symbols[symbol] = true
+        continue_expire()
+      else
+        if type(data) == 'string' then
+          if data == "" then
+            data = {}
+          else
+            data = lutil.rspamd_str_split(data, string.char(31))
+          end
+        end
+        if type(data) == 'table' then
+          if not data[1] then
+            logger.warnx(ev_base, 'No keys to delete: %s', symbol)
+            processed_symbols[symbol] = true
+            continue_expire()
+          else
+            local ret = lredis.redis_make_request_taskless(ev_base,
+              rspamd_config,
+              redis_params,
+              nil,
+              true, -- is write
+              del_keys_cb, --callback
+              'DEL', -- command
+              data
+            )
+            if not ret then
+              logger.errx(ev_base, 'Redis request failed')
+              processed_symbols[symbol] = true
+              continue_expire()
+            end
+          end
+        else
+          logger.warnx(ev_base, 'No keys to delete: %s', symbol)
+          processed_symbols[symbol] = true
+          continue_expire()
+        end
+      end
+    end
+    local ret = lredis.redis_make_request_taskless(ev_base,
+      rspamd_config,
+      redis_params,
+      nil,
+      false, -- is write
+      get_keys_cb, --callback
+      'EVALSHA', -- command
+      {expire_script_sha, 0, symbol}
+    )
+    if not ret then
+      logger.errx(ev_base, 'Redis request failed')
+      processed_symbols[symbol] = true
+      continue_expire()
+    end
+  end
+  local function begin_expire(time)
+    local stamp = time or util.get_time()
+    pool:set_variable(VAR_NAME, stamp)
+    processed_symbols = {}
+    continue_expire()
+  end
+  -- Expire tokens at regular intervals
+  local function schedule_regular_expiry()
+    rspamd_config:add_periodic(ev_base, settings['interval'], function ()
+      begin_expire()
+      return true
+    end)
+  end
+  -- Expire tokens and reschedule expiry
+  local function schedule_intermediate_expiry(when)
+    rspamd_config:add_periodic(ev_base, when, function ()
+      begin_expire()
+      schedule_regular_expiry()
+      return false
+    end)
+  end
+  -- Try read statefile on startup
+  local stamp
+  local f, err = io.open(settings['statefile'], 'r')
+  if err then
+    logger.warnx(ev_base, 'Failed to open statefile: %s', err)
+  end
+  if f then
+    io.input(f)
+    stamp = tonumber(io.read())
+    pool:set_variable(VAR_NAME, stamp)
+  end
+  local time = util.get_time()
+  if not stamp then
+    logger.debugm(N, ev_base, 'No state found - expiring stats immediately')
+    begin_expire(time)
+    schedule_regular_expiry()
+    return
+  end
+  local delta = stamp - time + settings['interval']
+  if delta <= 0 then
+    logger.debugm(N, ev_base, 'Last expiry is too old - expiring stats immediately')
+    begin_expire(time)
+    schedule_regular_expiry()
+    return
+  end
+  logger.debugm(N, ev_base, 'Scheduling next expiry in %s seconds', delta)
+  schedule_intermediate_expiry(delta)
+end)
index 78505bf927a36bb25dc0fcfa5d85d623edc385ec..edb4db2f624937ee329a47cb0556880d091d5c63 100644 (file)
@@ -15,6 +15,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ]]--
 
+local lutil = require "lua_util"
 local rspamd_logger = require "rspamd_logger"
 local dkim_sign_tools = require "dkim_sign_tools"
 
@@ -46,18 +47,6 @@ local N = 'dkim_signing'
 local redis_params
 local sign_func = rspamd_plugins.dkim.sign
 
-local function simple_template(tmpl, keys)
-  local lpeg = require "lpeg"
-
-  local var_lit = lpeg.P { lpeg.R("az") + lpeg.R("AZ") + lpeg.R("09") + "_" }
-  local var = lpeg.P { (lpeg.P("$") / "") * ((var_lit^1) / keys) }
-  local var_braced = lpeg.P { (lpeg.P("${") / "") * ((var_lit^1) / keys) * (lpeg.P("}") / "") }
-
-  local template_grammar = lpeg.Cs((var + var_braced + 1)^0)
-
-  return lpeg.match(template_grammar, tmpl)
-end
-
 local function dkim_signing_cb(task)
   local ret,p = dkim_sign_tools.prepare_dkim_signing(N, task, settings)
 
@@ -123,7 +112,7 @@ local function dkim_signing_cb(task)
     end
   else
     if (p.key and p.selector) then
-      p.key = simple_template(p.key, {domain = p.domain, selector = p.selector})
+      p.key = lutil.template(p.key, {domain = p.domain, selector = p.selector})
       local sret, _ = sign_func(task, p)
       return sret
     else
index f0b3f175d77254e3b00a75236837f07fd181d939..2268c86f5bd8f3c39aedafa1f3b43b4687bb6540 100644 (file)
@@ -22,6 +22,7 @@ end
 -- A plugin that pushes metadata (or whole messages) to external services
 
 local redis_params
+local lutil = require "lua_util"
 local rspamd_http = require "rspamd_http"
 local rspamd_tcp = require "rspamd_tcp"
 local rspamd_util = require "rspamd_util"
@@ -136,18 +137,6 @@ local function get_general_metadata(task, flatten, no_content)
   return r
 end
 
-local function simple_template(tmpl, keys)
-  local lpeg = require "lpeg"
-
-  local var_lit = lpeg.P { lpeg.R("az") + lpeg.R("AZ") + lpeg.R("09") + "_" }
-  local var = lpeg.P { (lpeg.P("$") / "") * ((var_lit^1) / keys) }
-  local var_braced = lpeg.P { (lpeg.P("${") / "") * ((var_lit^1) / keys) * (lpeg.P("}") / "") }
-
-  local template_grammar = lpeg.Cs((var + var_braced + 1)^0)
-
-  return lpeg.match(template_grammar, tmpl)
-end
-
 local formatters = {
   default = function(task)
     return task:get_content()
@@ -158,7 +147,7 @@ local formatters = {
     meta.mail_to = rule.mail_to or settings.mail_to
     meta.our_message_id = rspamd_util.random_hex(12) .. '@rspamd'
     meta.date = rspamd_util.time_to_string(rspamd_util.get_time())
-    return simple_template(rule.email_template or settings.email_template, meta)
+    return lutil.template(rule.email_template or settings.email_template, meta)
   end,
   json = function(task)
     return ucl.to_format(get_general_metadata(task), 'json-compact')