summaryrefslogtreecommitdiffstats
path: root/lualib/lua_redis.lua
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2017-12-16 15:40:37 +0000
committerVsevolod Stakhov <vsevolod@highsecure.ru>2017-12-16 15:40:37 +0000
commitd7f6e6a5166a683f347e14b48255f8dd5ceee3d9 (patch)
treeaf81dc482ce71e4ec2f6cca227af9bb8eb0b3e20 /lualib/lua_redis.lua
parent9d99bef1243389742c3071801192e16fb90d8995 (diff)
downloadrspamd-d7f6e6a5166a683f347e14b48255f8dd5ceee3d9.tar.gz
rspamd-d7f6e6a5166a683f347e14b48255f8dd5ceee3d9.zip
[Feature] Add framework to manage Redis scripts
Diffstat (limited to 'lualib/lua_redis.lua')
-rw-r--r--lualib/lua_redis.lua157
1 files changed, 157 insertions, 0 deletions
diff --git a/lualib/lua_redis.lua b/lualib/lua_redis.lua
index 25f0078ba..d88d17489 100644
--- a/lualib/lua_redis.lua
+++ b/lualib/lua_redis.lua
@@ -606,4 +606,161 @@ end
exports.rspamd_redis_make_request_taskless = redis_make_request_taskless
exports.redis_make_request_taskless = redis_make_request_taskless
+local redis_scripts = {
+}
+
+local function load_redis_script(script, cfg, ev_base, _)
+ local function merge_tables(t1, t2)
+ for k,v in pairs(t2) do t1[k] = v end
+ end
+
+ local function set_loaded()
+ if script.sha then
+ script.loaded = true
+ end
+
+ local wait_table = {}
+ for _,s in ipairs(script.waitq) do
+ table.insert(wait_table, s)
+ end
+
+ script.waitq = {}
+
+ for _,s in ipairs(wait_table) do
+ s(script)
+ end
+ end
+ local servers = {}
+
+ if script.redis_params.read_servers then
+ merge_tables(servers, script.redis_params.read_servers:all_upstreams())
+ end
+ if script.redis_params.write_servers then
+ merge_tables(servers, script.redis_params.write_servers:all_upstreams())
+ end
+
+ -- Call load script on each server, set loaded flag
+ script.in_flight = #servers
+ for _,s in ipairs(servers) do
+ local function script_cb(err, data)
+ if err then
+ s:fail()
+ else
+ s:ok()
+ script.sha = data -- We assume that sha is the same on all servers
+ end
+ script.in_flight = script.in_flight - 1
+
+ if script.in_flight == 0 then
+ set_loaded(script)
+ end
+ end
+
+ local rspamd_redis = require "rspamd_redis"
+
+ local options = {
+ ev_base = ev_base,
+ config = cfg,
+ callback = script_cb,
+ host = s:get_addr(),
+ timeout = script.redis_params['timeout'],
+ cmd = 'SCRIPT',
+ args = {'LOAD', script.script}
+ }
+
+ if script.redis_params['password'] then
+ options['password'] = script.redis_params['password']
+ end
+
+ if script.redis_params['db'] then
+ options['dbname'] = script.redis_params['db']
+ end
+
+ local ret = rspamd_redis.make_request(options)
+ if not ret then
+ logger.errx('cannot execute redis request to load script')
+ script.in_flight = script.in_flight - 1
+ end
+ end
+
+ if script.in_flight == 0 then
+ set_loaded(script)
+ end
+end
+
+local function add_redis_script(script, redis_params)
+ local new_script = {
+ loaded = false,
+ redis_params = redis_params,
+ script = script,
+ waitq = {}, -- callbacks pending for script being loaded
+ id = #redis_scripts + 1
+ }
+
+ -- Register on load function
+ rspamd_config:add_on_load(function(cfg, ev_base, worker)
+ load_redis_script(new_script, cfg, ev_base, worker)
+ end)
+
+ table.insert(redis_scripts, new_script)
+
+ return #redis_scripts
+end
+exports.add_redis_script = add_redis_script
+
+local function exec_redis_script(id, params, callback, args)
+ if not redis_scripts[id] then
+ return false
+ end
+
+ local script = redis_scripts[id]
+
+ local function do_call()
+ local function redis_cb(err, data)
+ if not err then
+ callback(err, data)
+ elseif err == 'NOSCRIPT' then
+ -- Schedule restart
+ table.insert(script.waitq, do_call)
+ if script.in_flight ~= 0 then
+ -- Reload scripts if this has not been initiated yet
+ if params.task then
+ load_redis_script(script, rspamd_config,
+ params.task:get_ev_base(), nil)
+ else
+ load_redis_script(script, rspamd_config,
+ params.ev_base, nil)
+ end
+ end
+ else
+ callback(err, data)
+ end
+ end
+
+ if params.task then
+ if not rspamd_redis_make_request(params.task, script.redis_params,
+ params.key, params.is_write, redis_cb, 'EVALSHA', args) then
+ callback('Cannot make redis request', nil)
+ end
+ else
+ if not redis_make_request_taskless(params.ev_base, rspamd_config,
+ script.redis_params,
+ params.key, params.is_write, redis_cb, 'EVALSHA', args) then
+ callback('Cannot make redis request', nil)
+ end
+ end
+ end
+
+ if not script.loaded then
+ do_call()
+ else
+ -- Delayed until scripts are loaded
+ table.insert(script.waitq, do_call)
+ end
+
+ return true
+end
+
+exports.exec_redis_script = exec_redis_script
+
return exports