]> source.dussan.org Git - rspamd.git/commitdiff
[Minor] Apply library functions in plugins
authorAndrew Lewis <nerf@judo.za.org>
Thu, 8 Mar 2018 13:16:24 +0000 (15:16 +0200)
committerAndrew Lewis <nerf@judo.za.org>
Thu, 8 Mar 2018 13:16:44 +0000 (15:16 +0200)
lualib/lua_util.lua
src/plugins/lua/force_actions.lua
src/plugins/lua/ratelimit.lua
src/plugins/lua/url_reputation.lua
src/plugins/lua/url_tags.lua

index 88925aeecdffd14668cc38c27a8d36eff7d5a5cc..d41d79fea4b490639e2eedae610923b265b4aff9 100644 (file)
@@ -291,6 +291,35 @@ end
 
 exports.check_experimental = check_experimental
 
+--[[[
+-- @function lua_util.list_to_hash(list)
+-- Converts numerically-indexed table to table indexed by values
+-- @param {table} list numerically-indexed table or string, which is treated as a one-element list
+-- @return {table} table indexed by values
+-- @example
+-- local h = lua_util.list_to_hash({"a", "b"})
+-- -- h contains {a = true, b = true}
+--]]
+local function list_to_hash(list)
+  if type(list) == 'table' then
+    if list[1] then
+      local h = {}
+      for _, e in ipairs(list) do
+        h[e] = true
+      end
+      return h
+    else
+      return list
+    end
+  elseif type(list) == 'string' then
+    local h = {}
+    h[list] = true
+    return h
+  end
+end
+
+exports.list_to_hash = list_to_hash
+
 --[[[
 -- @function lua_util.parse_time_interval(str)
 -- Parses human readable time interval
index a733a14257586166838c27f8e105893c8ce6b67c..1d99ce52bb07df53dafd23908c2b4532c509a572 100644 (file)
@@ -25,6 +25,7 @@ local E = {}
 local N = 'force_actions'
 
 local fun = require "fun"
+local lua_util = require "lua_util"
 local rspamd_cryptobox_hash = require "rspamd_cryptobox_hash"
 local rspamd_expression = require "rspamd_expression"
 local rspamd_logger = require "rspamd_logger"
@@ -89,24 +90,6 @@ local function gen_cb(expr, act, pool, message, subject, raction, honor, limit)
 
 end
 
-local function list_to_hash(list)
-  if type(list) == 'table' then
-    if list[1] then
-      local h = {}
-      for _, e in ipairs(list) do
-        h[e] = true
-      end
-      return h
-    else
-      return list
-    end
-  elseif type(list) == 'string' then
-    local h = {}
-    h[list] = true
-    return h
-  end
-end
-
 local function configure_module()
   local opts = rspamd_config:get_all_opt(N)
   if not opts then
@@ -153,8 +136,8 @@ local function configure_module()
         local subject = sett.subject
         local message = sett.message
         local lim = sett.limit or 0
-        local raction = list_to_hash(sett.require_action)
-        local honor = list_to_hash(sett.honor_action)
+        local raction = lua_util.list_to_hash(sett.require_action)
+        local honor = lua_util.list_to_hash(sett.honor_action)
         local cb, atoms = gen_cb(expr, action, rspamd_config:get_mempool(),
           message, subject, raction, honor, lim)
         if cb and atoms then
index d18b79bfe861be7133e023aafe4d473e20d24438..324454f4d5cc13af3a5b04aa197e4301d7281f61 100644 (file)
@@ -55,7 +55,7 @@ local lua_util = require "lua_util"
 
 local user_keywords = {'user'}
 
-local redis_script_sha
+local redis_script_id
 local redis_script = [[local bucket
 local limited = false
 local buckets = {}
@@ -160,29 +160,13 @@ end
 return results]]
 
 local function load_scripts(cfg, ev_base)
-  local function rl_script_cb(err, data)
-    if err then
-      rspamd_logger.errx(cfg, 'Script loading failed: ' .. err)
-    elseif type(data) == 'string' then
-      redis_script_sha = data
-    end
-  end
   local script
   if ratelimit_symbol then
     script = redis_script_symbol
   else
     script = redis_script
   end
-  lua_redis.redis_make_request_taskless(
-    ev_base,
-    cfg,
-    redis_params,
-    nil, -- key
-    true, -- is write
-    rl_script_cb, --callback
-    'SCRIPT', -- command
-    {'LOAD', script}
-  )
+  redis_script_id = lua_redis.add_redis_script(script, redis_params)
 end
 
 local limit_parser
@@ -410,9 +394,9 @@ local function process_buckets(task, buckets)
   end
   local redis_cb = rl_redis_cb
   if ratelimit_symbol then redis_cb = rl_symbol_redis_cb end
-  local args = {redis_script_sha, #buckets}
+  local kwargs, args = {}, {}
   for _, bucket in ipairs(buckets) do
-    table.insert(args, bucket[2])
+    table.insert(kwargs, bucket[2])
   end
   for _, bucket in ipairs(buckets) do
     if use_ip_score then
@@ -449,14 +433,7 @@ local function process_buckets(task, buckets)
   end
   table.insert(args, rspamd_util.get_time())
   table.insert(args, task:get_queue_id() or task:get_uid())
-  local ret = rspamd_redis_make_request(task,
-    redis_params, -- connect params
-    nil, -- hash key
-    true, -- is write
-    redis_cb, --callback
-    'evalsha', -- command
-    args -- arguments
-  )
+  local ret = lua_redis.exec_redis_script(redis_script_id, {task = task, is_write = true}, redis_cb, kwargs, args)
   if not ret then
     rspamd_logger.errx(task, 'got error connecting to redis')
   end
index c3856f3b66125ffb4706ab48fe096daf00e17d7a..e7d35697dc92929ea14e87ad56c720693c6fd0a6 100644 (file)
@@ -24,7 +24,7 @@ end
 local E = {}
 local N = 'url_reputation'
 
-local whitelist, redis_params, redis_incr_script_sha
+local whitelist, redis_params, redis_incr_script_id
 local settings = {
   expire = 86400, -- 1 day
   key_prefix = 'Ur.',
@@ -74,21 +74,7 @@ end
 
 -- Function to load the script
 local function load_scripts(cfg, ev_base)
-  local function redis_incr_script_cb(err, data)
-    if err then
-      rspamd_logger.errx(cfg, 'Increment script loading failed: ' .. err)
-    else
-      redis_incr_script_sha = tostring(data)
-    end
-  end
-  rspamd_redis.redis_make_request_taskless(ev_base,
-    rspamd_config,
-    nil,
-    true, -- is write
-    redis_incr_script_cb, --callback
-    'SCRIPT', -- command
-    {'LOAD', redis_incr_script}
-  )
+  redis_incr_script_id = rspamd_redis.add_redis_script(redis_incr_script, redis_params)
 end
 
 -- Calculates URL reputation
@@ -175,8 +161,6 @@ local function url_reputation_check(task)
       if which then
         -- Update reputation for guilty domain only
         rk = {
-          redis_incr_script_sha,
-          2,
           settings.key_prefix .. which .. '_total',
           settings.key_prefix .. which .. '_' .. scale[reputation],
         }
@@ -248,7 +232,7 @@ local function url_reputation_check(task)
           end
         end
 
-        rk = {redis_incr_script_sha, 0}
+        rk = {}
         local added = 0
         if most_relevant then
           tlds = {most_relevant}
@@ -264,16 +248,11 @@ local function url_reputation_check(task)
           added = added + 1
         end
       end
-      if rk[3] then
-        rk[2] = (#rk - 2)
-        local ret = rspamd_redis_make_request(task,
-          redis_params,
-          rk[3],
-          true, -- is write
-          redis_incr_cb, --callback
-          'EVALSHA', -- command
-          rk
-        )
+      if rk[2] then
+        local ret = rspamd_redis.exec_redis_script(redis_incr_script_id,
+          {task = task, is_write = true},
+          redis_incr_cb,
+          rk)
         if not ret then
           rspamd_logger.errx(task, 'couldnt schedule increment')
         end
index e64aa926f31db50d3199c3744d5bfdbdbc588e84..c0f7ffa744ae7a06d20d55dc7d7ef89a9fc502bb 100644 (file)
@@ -23,7 +23,7 @@ end
 
 local N = 'url_tags'
 
-local redis_params, redis_set_script_sha
+local redis_params, redis_set_script_id
 local settings = {
   -- lifetime for tags
   expire = 3600, -- 1 hour
@@ -36,60 +36,9 @@ local settings = {
 local rspamd_logger = require "rspamd_logger"
 local rspamd_util = require "rspamd_util"
 local lua_util = require "lua_util"
+local lua_redis = require "lua_redis"
 local ucl = require "ucl"
 
--- This function is used for taskless redis requests (to load scripts)
-local function redis_make_request(ev_base, cfg, key, is_write, callback, command, args)
-  if not ev_base or not redis_params or not callback or not command then
-    return false,nil,nil
-  end
-
-  local addr
-  local rspamd_redis = require "rspamd_redis"
-
-  if key then
-    if is_write then
-      addr = redis_params['write_servers']:get_upstream_by_hash(key)
-    else
-      addr = redis_params['read_servers']:get_upstream_by_hash(key)
-    end
-  else
-    if is_write then
-      addr = redis_params['write_servers']:get_upstream_master_slave(key)
-    else
-      addr = redis_params['read_servers']:get_upstream_round_robin(key)
-    end
-  end
-
-  if not addr then
-    rspamd_logger.errx(cfg, 'cannot select server to make redis request')
-  end
-
-  local options = {
-    ev_base = ev_base,
-    config = cfg,
-    callback = callback,
-    host = addr:get_addr(),
-    timeout = redis_params['timeout'],
-    cmd = command,
-    args = args
-  }
-
-  if redis_params['password'] then
-    options['password'] = redis_params['password']
-  end
-
-  if redis_params['db'] then
-    options['dbname'] = redis_params['db']
-  end
-
-  local ret,conn = rspamd_redis.make_request(options)
-  if not ret then
-    rspamd_logger.errx('cannot execute redis request')
-  end
-  return ret,conn,addr
-end
-
 -- Tags are stored in format: [timestamp]|[tag1],[timestamp]|[tag2]
 local redis_set_script_head = 'local expiry = '
 local redis_set_script_tail = [[
@@ -136,41 +85,17 @@ end
 
 -- Function to load the script
 local function load_scripts(cfg, ev_base)
-  local function redis_set_script_cb(err, data)
-    if err then
-      rspamd_logger.errx(cfg, 'Set script loading failed: ' .. err)
-    else
-      redis_set_script_sha = tostring(data)
-    end
-  end
   local set_script =
     redis_set_script_head ..
     settings.expire ..
     '\n' ..
     redis_set_script_tail
-  redis_make_request(ev_base,
-    rspamd_config,
-    nil,
-    true, -- is write
-    redis_set_script_cb, --callback
-    'SCRIPT', -- command
-    {'LOAD', set_script}
-  )
+  redis_set_script_id = lua_redis.add_redis_script(set_script, redis_params)
 end
 
 -- Saves tags to redis
 local function tags_save(task)
 
-  -- Handle errors (reloads script if necessary)
-  local function redis_set_cb(err)
-    if err then
-      rspamd_logger.errx(task, 'Redis error: %s', err)
-      if string.match(err, 'NOSCRIPT') then
-        load_scripts(rspamd_config, task:get_ev_base())
-      end
-    end
-  end
-
   local tags = {}
   -- Figure out what tags are present for each TLD
   for _, url in ipairs(task:get_urls(false)) do
@@ -251,26 +176,13 @@ local function tags_save(task)
     end
     table.insert(redis_args, table.concat(tmp4, '/'))
   end
-
-  local redis_final = {redis_set_script_sha}
-  table.insert(redis_final, #redis_keys)
-  for _, k in ipairs(redis_keys) do
-    table.insert(redis_final, k)
-  end
-  for _, a in ipairs(redis_args) do
-    table.insert(redis_final, a)
-  end
-  table.insert(redis_final, rspamd_util.get_time())
+  table.insert(redis_args, rspamd_util.get_time())
 
   -- Send query to redis
-  rspamd_redis_make_request(task,
-    redis_params,
-    nil,
-    true, -- is write
-    redis_set_cb, --callback
-    'EVALSHA', -- command
-    redis_final
-  )
+  lua_redis.exec_redis_script(
+    redis_set_script_id,
+    {task = task, is_write = true},
+    function() end, redis_keys, redis_args)
 end
 
 local function tags_restore(task)
@@ -362,26 +274,7 @@ end
 for k, v in pairs(opts) do
   settings[k] = v
 end
-local function list_to_hash(list)
-  if type(list) == 'table' then
-    if list[1] then
-      local h = {}
-      for _, e in ipairs(list) do
-        h[e] = true
-      end
-      return h
-    else
-      return list
-    end
-  elseif type(list) == 'string' then
-    local h = {}
-    h[list] = true
-    return h
-  else
-    return {}
-  end
-end
-settings.ignore_tags = list_to_hash(settings.ignore_tags)
+settings.ignore_tags = lua_util.list_to_hash(settings.ignore_tags)
 
 rspamd_config:add_on_load(function(cfg, ev_base, worker)
   load_scripts(cfg, ev_base)