From d7f6e6a5166a683f347e14b48255f8dd5ceee3d9 Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Sat, 16 Dec 2017 15:40:37 +0000 Subject: [Feature] Add framework to manage Redis scripts --- lualib/lua_redis.lua | 157 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 157 insertions(+) (limited to 'lualib/lua_redis.lua') diff --git a/lualib/lua_redis.lua b/lualib/lua_redis.lua index 25f0078ba..d88d17489 100644 --- a/lualib/lua_redis.lua +++ b/lualib/lua_redis.lua @@ -606,4 +606,161 @@ end exports.rspamd_redis_make_request_taskless = redis_make_request_taskless exports.redis_make_request_taskless = redis_make_request_taskless +local redis_scripts = { +} + +local function load_redis_script(script, cfg, ev_base, _) + local function merge_tables(t1, t2) + for k,v in pairs(t2) do t1[k] = v end + end + + local function set_loaded() + if script.sha then + script.loaded = true + end + + local wait_table = {} + for _,s in ipairs(script.waitq) do + table.insert(wait_table, s) + end + + script.waitq = {} + + for _,s in ipairs(wait_table) do + s(script) + end + end + local servers = {} + + if script.redis_params.read_servers then + merge_tables(servers, script.redis_params.read_servers:all_upstreams()) + end + if script.redis_params.write_servers then + merge_tables(servers, script.redis_params.write_servers:all_upstreams()) + end + + -- Call load script on each server, set loaded flag + script.in_flight = #servers + for _,s in ipairs(servers) do + local function script_cb(err, data) + if err then + s:fail() + else + s:ok() + script.sha = data -- We assume that sha is the same on all servers + end + script.in_flight = script.in_flight - 1 + + if script.in_flight == 0 then + set_loaded(script) + end + end + + local rspamd_redis = require "rspamd_redis" + + local options = { + ev_base = ev_base, + config = cfg, + callback = script_cb, + host = s:get_addr(), + timeout = script.redis_params['timeout'], + cmd = 'SCRIPT', + args = {'LOAD', script.script} + } + + if script.redis_params['password'] then + options['password'] = script.redis_params['password'] + end + + if script.redis_params['db'] then + options['dbname'] = script.redis_params['db'] + end + + local ret = rspamd_redis.make_request(options) + if not ret then + logger.errx('cannot execute redis request to load script') + script.in_flight = script.in_flight - 1 + end + end + + if script.in_flight == 0 then + set_loaded(script) + end +end + +local function add_redis_script(script, redis_params) + local new_script = { + loaded = false, + redis_params = redis_params, + script = script, + waitq = {}, -- callbacks pending for script being loaded + id = #redis_scripts + 1 + } + + -- Register on load function + rspamd_config:add_on_load(function(cfg, ev_base, worker) + load_redis_script(new_script, cfg, ev_base, worker) + end) + + table.insert(redis_scripts, new_script) + + return #redis_scripts +end +exports.add_redis_script = add_redis_script + +local function exec_redis_script(id, params, callback, args) + if not redis_scripts[id] then + return false + end + + local script = redis_scripts[id] + + local function do_call() + local function redis_cb(err, data) + if not err then + callback(err, data) + elseif err == 'NOSCRIPT' then + -- Schedule restart + table.insert(script.waitq, do_call) + if script.in_flight ~= 0 then + -- Reload scripts if this has not been initiated yet + if params.task then + load_redis_script(script, rspamd_config, + params.task:get_ev_base(), nil) + else + load_redis_script(script, rspamd_config, + params.ev_base, nil) + end + end + else + callback(err, data) + end + end + + if params.task then + if not rspamd_redis_make_request(params.task, script.redis_params, + params.key, params.is_write, redis_cb, 'EVALSHA', args) then + callback('Cannot make redis request', nil) + end + else + if not redis_make_request_taskless(params.ev_base, rspamd_config, + script.redis_params, + params.key, params.is_write, redis_cb, 'EVALSHA', args) then + callback('Cannot make redis request', nil) + end + end + end + + if not script.loaded then + do_call() + else + -- Delayed until scripts are loaded + table.insert(script.waitq, do_call) + end + + return true +end + +exports.exec_redis_script = exec_redis_script + return exports -- cgit v1.2.3