diff options
-rw-r--r-- | lualib/lua_util.lua | 29 | ||||
-rw-r--r-- | src/plugins/lua/force_actions.lua | 23 | ||||
-rw-r--r-- | src/plugins/lua/ratelimit.lua | 33 | ||||
-rw-r--r-- | src/plugins/lua/url_reputation.lua | 37 | ||||
-rw-r--r-- | src/plugins/lua/url_tags.lua | 125 |
5 files changed, 54 insertions, 193 deletions
diff --git a/lualib/lua_util.lua b/lualib/lua_util.lua index 88925aeec..d41d79fea 100644 --- a/lualib/lua_util.lua +++ b/lualib/lua_util.lua @@ -292,6 +292,35 @@ end exports.check_experimental = check_experimental --[[[ +-- @function lua_util.list_to_hash(list) +-- Converts numerically-indexed table to table indexed by values +-- @param {table} list numerically-indexed table or string, which is treated as a one-element list +-- @return {table} table indexed by values +-- @example +-- local h = lua_util.list_to_hash({"a", "b"}) +-- -- h contains {a = true, b = true} +--]] +local function list_to_hash(list) + if type(list) == 'table' then + if list[1] then + local h = {} + for _, e in ipairs(list) do + h[e] = true + end + return h + else + return list + end + elseif type(list) == 'string' then + local h = {} + h[list] = true + return h + end +end + +exports.list_to_hash = list_to_hash + +--[[[ -- @function lua_util.parse_time_interval(str) -- Parses human readable time interval -- Accepts 's' for seconds, 'm' for minutes, 'h' for hours, 'd' for days, diff --git a/src/plugins/lua/force_actions.lua b/src/plugins/lua/force_actions.lua index a733a1425..1d99ce52b 100644 --- a/src/plugins/lua/force_actions.lua +++ b/src/plugins/lua/force_actions.lua @@ -25,6 +25,7 @@ local E = {} local N = 'force_actions' local fun = require "fun" +local lua_util = require "lua_util" local rspamd_cryptobox_hash = require "rspamd_cryptobox_hash" local rspamd_expression = require "rspamd_expression" local rspamd_logger = require "rspamd_logger" @@ -89,24 +90,6 @@ local function gen_cb(expr, act, pool, message, subject, raction, honor, limit) end -local function list_to_hash(list) - if type(list) == 'table' then - if list[1] then - local h = {} - for _, e in ipairs(list) do - h[e] = true - end - return h - else - return list - end - elseif type(list) == 'string' then - local h = {} - h[list] = true - return h - end -end - local function configure_module() local opts = rspamd_config:get_all_opt(N) if not opts then @@ -153,8 +136,8 @@ local function configure_module() local subject = sett.subject local message = sett.message local lim = sett.limit or 0 - local raction = list_to_hash(sett.require_action) - local honor = list_to_hash(sett.honor_action) + local raction = lua_util.list_to_hash(sett.require_action) + local honor = lua_util.list_to_hash(sett.honor_action) local cb, atoms = gen_cb(expr, action, rspamd_config:get_mempool(), message, subject, raction, honor, lim) if cb and atoms then diff --git a/src/plugins/lua/ratelimit.lua b/src/plugins/lua/ratelimit.lua index d18b79bfe..324454f4d 100644 --- a/src/plugins/lua/ratelimit.lua +++ b/src/plugins/lua/ratelimit.lua @@ -55,7 +55,7 @@ local lua_util = require "lua_util" local user_keywords = {'user'} -local redis_script_sha +local redis_script_id local redis_script = [[local bucket local limited = false local buckets = {} @@ -160,29 +160,13 @@ end return results]] local function load_scripts(cfg, ev_base) - local function rl_script_cb(err, data) - if err then - rspamd_logger.errx(cfg, 'Script loading failed: ' .. err) - elseif type(data) == 'string' then - redis_script_sha = data - end - end local script if ratelimit_symbol then script = redis_script_symbol else script = redis_script end - lua_redis.redis_make_request_taskless( - ev_base, - cfg, - redis_params, - nil, -- key - true, -- is write - rl_script_cb, --callback - 'SCRIPT', -- command - {'LOAD', script} - ) + redis_script_id = lua_redis.add_redis_script(script, redis_params) end local limit_parser @@ -410,9 +394,9 @@ local function process_buckets(task, buckets) end local redis_cb = rl_redis_cb if ratelimit_symbol then redis_cb = rl_symbol_redis_cb end - local args = {redis_script_sha, #buckets} + local kwargs, args = {}, {} for _, bucket in ipairs(buckets) do - table.insert(args, bucket[2]) + table.insert(kwargs, bucket[2]) end for _, bucket in ipairs(buckets) do if use_ip_score then @@ -449,14 +433,7 @@ local function process_buckets(task, buckets) end table.insert(args, rspamd_util.get_time()) table.insert(args, task:get_queue_id() or task:get_uid()) - local ret = rspamd_redis_make_request(task, - redis_params, -- connect params - nil, -- hash key - true, -- is write - redis_cb, --callback - 'evalsha', -- command - args -- arguments - ) + local ret = lua_redis.exec_redis_script(redis_script_id, {task = task, is_write = true}, redis_cb, kwargs, args) if not ret then rspamd_logger.errx(task, 'got error connecting to redis') end diff --git a/src/plugins/lua/url_reputation.lua b/src/plugins/lua/url_reputation.lua index c3856f3b6..e7d35697d 100644 --- a/src/plugins/lua/url_reputation.lua +++ b/src/plugins/lua/url_reputation.lua @@ -24,7 +24,7 @@ end local E = {} local N = 'url_reputation' -local whitelist, redis_params, redis_incr_script_sha +local whitelist, redis_params, redis_incr_script_id local settings = { expire = 86400, -- 1 day key_prefix = 'Ur.', @@ -74,21 +74,7 @@ end -- Function to load the script local function load_scripts(cfg, ev_base) - local function redis_incr_script_cb(err, data) - if err then - rspamd_logger.errx(cfg, 'Increment script loading failed: ' .. err) - else - redis_incr_script_sha = tostring(data) - end - end - rspamd_redis.redis_make_request_taskless(ev_base, - rspamd_config, - nil, - true, -- is write - redis_incr_script_cb, --callback - 'SCRIPT', -- command - {'LOAD', redis_incr_script} - ) + redis_incr_script_id = rspamd_redis.add_redis_script(redis_incr_script, redis_params) end -- Calculates URL reputation @@ -175,8 +161,6 @@ local function url_reputation_check(task) if which then -- Update reputation for guilty domain only rk = { - redis_incr_script_sha, - 2, settings.key_prefix .. which .. '_total', settings.key_prefix .. which .. '_' .. scale[reputation], } @@ -248,7 +232,7 @@ local function url_reputation_check(task) end end - rk = {redis_incr_script_sha, 0} + rk = {} local added = 0 if most_relevant then tlds = {most_relevant} @@ -264,16 +248,11 @@ local function url_reputation_check(task) added = added + 1 end end - if rk[3] then - rk[2] = (#rk - 2) - local ret = rspamd_redis_make_request(task, - redis_params, - rk[3], - true, -- is write - redis_incr_cb, --callback - 'EVALSHA', -- command - rk - ) + if rk[2] then + local ret = rspamd_redis.exec_redis_script(redis_incr_script_id, + {task = task, is_write = true}, + redis_incr_cb, + rk) if not ret then rspamd_logger.errx(task, 'couldnt schedule increment') end diff --git a/src/plugins/lua/url_tags.lua b/src/plugins/lua/url_tags.lua index e64aa926f..c0f7ffa74 100644 --- a/src/plugins/lua/url_tags.lua +++ b/src/plugins/lua/url_tags.lua @@ -23,7 +23,7 @@ end local N = 'url_tags' -local redis_params, redis_set_script_sha +local redis_params, redis_set_script_id local settings = { -- lifetime for tags expire = 3600, -- 1 hour @@ -36,60 +36,9 @@ local settings = { local rspamd_logger = require "rspamd_logger" local rspamd_util = require "rspamd_util" local lua_util = require "lua_util" +local lua_redis = require "lua_redis" local ucl = require "ucl" --- This function is used for taskless redis requests (to load scripts) -local function redis_make_request(ev_base, cfg, key, is_write, callback, command, args) - if not ev_base or not redis_params or not callback or not command then - return false,nil,nil - end - - local addr - local rspamd_redis = require "rspamd_redis" - - if key then - if is_write then - addr = redis_params['write_servers']:get_upstream_by_hash(key) - else - addr = redis_params['read_servers']:get_upstream_by_hash(key) - end - else - if is_write then - addr = redis_params['write_servers']:get_upstream_master_slave(key) - else - addr = redis_params['read_servers']:get_upstream_round_robin(key) - end - end - - if not addr then - rspamd_logger.errx(cfg, 'cannot select server to make redis request') - end - - local options = { - ev_base = ev_base, - config = cfg, - callback = callback, - host = addr:get_addr(), - timeout = redis_params['timeout'], - cmd = command, - args = args - } - - if redis_params['password'] then - options['password'] = redis_params['password'] - end - - if redis_params['db'] then - options['dbname'] = redis_params['db'] - end - - local ret,conn = rspamd_redis.make_request(options) - if not ret then - rspamd_logger.errx('cannot execute redis request') - end - return ret,conn,addr -end - -- Tags are stored in format: [timestamp]|[tag1],[timestamp]|[tag2] local redis_set_script_head = 'local expiry = ' local redis_set_script_tail = [[ @@ -136,41 +85,17 @@ end -- Function to load the script local function load_scripts(cfg, ev_base) - local function redis_set_script_cb(err, data) - if err then - rspamd_logger.errx(cfg, 'Set script loading failed: ' .. err) - else - redis_set_script_sha = tostring(data) - end - end local set_script = redis_set_script_head .. settings.expire .. '\n' .. redis_set_script_tail - redis_make_request(ev_base, - rspamd_config, - nil, - true, -- is write - redis_set_script_cb, --callback - 'SCRIPT', -- command - {'LOAD', set_script} - ) + redis_set_script_id = lua_redis.add_redis_script(set_script, redis_params) end -- Saves tags to redis local function tags_save(task) - -- Handle errors (reloads script if necessary) - local function redis_set_cb(err) - if err then - rspamd_logger.errx(task, 'Redis error: %s', err) - if string.match(err, 'NOSCRIPT') then - load_scripts(rspamd_config, task:get_ev_base()) - end - end - end - local tags = {} -- Figure out what tags are present for each TLD for _, url in ipairs(task:get_urls(false)) do @@ -251,26 +176,13 @@ local function tags_save(task) end table.insert(redis_args, table.concat(tmp4, '/')) end - - local redis_final = {redis_set_script_sha} - table.insert(redis_final, #redis_keys) - for _, k in ipairs(redis_keys) do - table.insert(redis_final, k) - end - for _, a in ipairs(redis_args) do - table.insert(redis_final, a) - end - table.insert(redis_final, rspamd_util.get_time()) + table.insert(redis_args, rspamd_util.get_time()) -- Send query to redis - rspamd_redis_make_request(task, - redis_params, - nil, - true, -- is write - redis_set_cb, --callback - 'EVALSHA', -- command - redis_final - ) + lua_redis.exec_redis_script( + redis_set_script_id, + {task = task, is_write = true}, + function() end, redis_keys, redis_args) end local function tags_restore(task) @@ -362,26 +274,7 @@ end for k, v in pairs(opts) do settings[k] = v end -local function list_to_hash(list) - if type(list) == 'table' then - if list[1] then - local h = {} - for _, e in ipairs(list) do - h[e] = true - end - return h - else - return list - end - elseif type(list) == 'string' then - local h = {} - h[list] = true - return h - else - return {} - end -end -settings.ignore_tags = list_to_hash(settings.ignore_tags) +settings.ignore_tags = lua_util.list_to_hash(settings.ignore_tags) rspamd_config:add_on_load(function(cfg, ev_base, worker) load_scripts(cfg, ev_base) |