summaryrefslogtreecommitdiffstats
path: root/lualib/lua_redis.lua
diff options
context:
space:
mode:
Diffstat (limited to 'lualib/lua_redis.lua')
-rw-r--r--lualib/lua_redis.lua174
1 files changed, 123 insertions, 51 deletions
diff --git a/lualib/lua_redis.lua b/lualib/lua_redis.lua
index d88d17489..eace60fa8 100644
--- a/lualib/lua_redis.lua
+++ b/lualib/lua_redis.lua
@@ -609,28 +609,30 @@ exports.redis_make_request_taskless = redis_make_request_taskless
local redis_scripts = {
}
-local function load_redis_script(script, cfg, ev_base, _)
- local function merge_tables(t1, t2)
- for k,v in pairs(t2) do t1[k] = v end
+local function script_set_loaded(script)
+ if script.sha then
+ script.loaded = true
end
- local function set_loaded()
- if script.sha then
- script.loaded = true
- end
+ local wait_table = {}
+ for _,s in ipairs(script.waitq) do
+ table.insert(wait_table, s)
+ end
- local wait_table = {}
- for _,s in ipairs(script.waitq) do
- table.insert(wait_table, s)
- end
+ script.waitq = {}
- script.waitq = {}
+ for _,s in ipairs(wait_table) do
+ s(script.loaded)
+ end
+end
- for _,s in ipairs(wait_table) do
- s(script)
- end
+local function prepare_redis_call(script)
+ local function merge_tables(t1, t2)
+ for k,v in pairs(t2) do t1[k] = v end
end
+
local servers = {}
+ local options = {}
if script.redis_params.read_servers then
merge_tables(servers, script.redis_params.read_servers:all_upstreams())
@@ -642,52 +644,104 @@ local function load_redis_script(script, cfg, ev_base, _)
-- Call load script on each server, set loaded flag
script.in_flight = #servers
for _,s in ipairs(servers) do
- local function script_cb(err, data)
+ local cur_opts = {
+ host = s:get_addr(),
+ timeout = script.redis_params['timeout'],
+ cmd = 'SCRIPT',
+ args = {'LOAD', script.script },
+ upstream = s
+ }
+
+ if script.redis_params['password'] then
+ cur_opts['password'] = script.redis_params['password']
+ end
+
+ if script.redis_params['db'] then
+ cur_opts['dbname'] = script.redis_params['db']
+ end
+
+ table.insert(options, cur_opts)
+ end
+
+ return options
+end
+
+local function load_script_task(script, task)
+ local rspamd_redis = require "rspamd_redis"
+ local opts = prepare_redis_call(script)
+
+ for _,opt in ipairs(opts) do
+ opt.task = task
+ opt.callback = function(err, data)
if err then
- s:fail()
+ opt.upstream:fail()
else
- s:ok()
+ opt.upstream:ok()
+ logger.infox(task,
+ "loaded redis script with id %s, sha: %s", script.id, data)
script.sha = data -- We assume that sha is the same on all servers
end
script.in_flight = script.in_flight - 1
if script.in_flight == 0 then
- set_loaded(script)
+ script_set_loaded(script)
end
end
- local rspamd_redis = require "rspamd_redis"
+ local ret = rspamd_redis.make_request(opt)
- local options = {
- ev_base = ev_base,
- config = cfg,
- callback = script_cb,
- host = s:get_addr(),
- timeout = script.redis_params['timeout'],
- cmd = 'SCRIPT',
- args = {'LOAD', script.script}
- }
+ if not ret then
+ logger.errx('cannot execute redis request to load script')
+ script.in_flight = script.in_flight - 1
+ opt.upstream:fail()
+ end
- if script.redis_params['password'] then
- options['password'] = script.redis_params['password']
+ if script.in_flight == 0 then
+ script_set_loaded(script)
end
+ end
+end
- if script.redis_params['db'] then
- options['dbname'] = script.redis_params['db']
+local function load_script_taskless(script, cfg, ev_base)
+ local rspamd_redis = require "rspamd_redis"
+ local opts = prepare_redis_call(script)
+
+ for _,opt in ipairs(opts) do
+ opt.config = cfg
+ opt.ev_base = ev_base
+ opt.callback = function(err, data)
+ if err then
+ opt.upstream:fail()
+ else
+ opt.upstream:ok()
+ logger.infox(cfg,
+ "loaded redis script with id %s, sha: %s", script.id, data)
+ script.sha = data -- We assume that sha is the same on all servers
+ end
+ script.in_flight = script.in_flight - 1
+
+ if script.in_flight == 0 then
+ script_set_loaded(script)
+ end
end
+ local ret = rspamd_redis.make_request(opt)
- local ret = rspamd_redis.make_request(options)
if not ret then
logger.errx('cannot execute redis request to load script')
script.in_flight = script.in_flight - 1
+ opt.upstream:fail()
end
- end
- if script.in_flight == 0 then
- set_loaded(script)
+ if script.in_flight == 0 then
+ script_set_loaded(script)
+ end
end
end
+local function load_redis_script(script, cfg, ev_base, _)
+ load_script_taskless(script, cfg, ev_base)
+end
+
local function add_redis_script(script, redis_params)
local new_script = {
loaded = false,
@@ -709,34 +763,47 @@ end
exports.add_redis_script = add_redis_script
local function exec_redis_script(id, params, callback, args)
+ local logger = require "rspamd_logger"
+ local args_modified = false
+
if not redis_scripts[id] then
+ logger.errx("cannot find registered script with id %s", id)
return false
end
local script = redis_scripts[id]
- local function do_call()
+ local function do_call(can_reload)
local function redis_cb(err, data)
if not err then
callback(err, data)
- elseif err == 'NOSCRIPT' then
+ elseif string.match(err, 'NOSCRIPT') then
-- Schedule restart
- table.insert(script.waitq, do_call)
- if script.in_flight ~= 0 then
- -- Reload scripts if this has not been initiated yet
- if params.task then
- load_redis_script(script, rspamd_config,
- params.task:get_ev_base(), nil)
- else
- load_redis_script(script, rspamd_config,
- params.ev_base, nil)
+ script.sha = nil
+ if can_reload then
+ table.insert(script.waitq, do_call)
+ if script.in_flight == 0 then
+ -- Reload scripts if this has not been initiated yet
+ if params.task then
+ load_script_task(script, params.task)
+ else
+ load_script_taskless(script, rspamd_config, params.ev_base)
+ end
end
+ else
+ callback(err, data)
end
else
callback(err, data)
end
end
+ if not args_modified then
+ table.insert(args, 1, tostring(#args))
+ table.insert(args, 1, script.sha)
+ args_modified = true
+ end
+
if params.task then
if not rspamd_redis_make_request(params.task, script.redis_params,
params.key, params.is_write, redis_cb, 'EVALSHA', args) then
@@ -751,11 +818,16 @@ local function exec_redis_script(id, params, callback, args)
end
end
- if not script.loaded then
- do_call()
+ if script.loaded then
+ do_call(true)
else
-- Delayed until scripts are loaded
- table.insert(script.waitq, do_call)
+ if not params.task then
+ table.insert(script.waitq, do_call)
+ else
+ -- TODO: fix taskfull requests
+ callback('NOSCRIPT', nil)
+ end
end
return true