aboutsummaryrefslogtreecommitdiffstats
path: root/lualib/lua_redis.lua
diff options
context:
space:
mode:
Diffstat (limited to 'lualib/lua_redis.lua')
-rw-r--r--lualib/lua_redis.lua224
1 files changed, 223 insertions, 1 deletions
diff --git a/lualib/lua_redis.lua b/lualib/lua_redis.lua
index 22ab675fa..e77f0069f 100644
--- a/lualib/lua_redis.lua
+++ b/lualib/lua_redis.lua
@@ -32,6 +32,9 @@ local common_schema = ts.shape {
prefix = ts.string:is_optional(),
password = ts.string:is_optional(),
expand_keys = ts.boolean:is_optional(),
+ sentinels = (ts.string + ts.array_of(ts.string)):is_optional(),
+ sentinel_watch_time = (ts.number + ts.string / lutil.parse_time_interval):is_optional(),
+ sentinel_masters_pattern = ts.string:is_optional(),
}
local config_schema =
@@ -48,11 +51,203 @@ local config_schema =
exports.config_schema = config_schema
+
+local function redis_query_sentinel(ev_base, params, initialised)
+ local function flatten_redis_table(tbl)
+ local res = {}
+ for i=1,#tbl,2 do
+ res[tbl[i]] = tbl[i + 1]
+ end
+
+ return res
+ end
+ -- Coroutines syntax
+ local rspamd_redis = require "rspamd_redis"
+ local addr = params.sentinels:get_upstream_round_robin()
+
+ local is_ok, connection = rspamd_redis.connect_sync({
+ host = addr:get_addr(),
+ timeout = params.timeout,
+ config = rspamd_config,
+ ev_base = ev_base,
+ })
+
+ if not is_ok then
+ logger.errx(rspamd_config, 'cannot connect sentinel at address: %s',
+ tostring(addr:get_addr()))
+ addr:fail()
+
+ return
+ end
+
+ -- Get masters list
+ connection:add_cmd('SENTINEL', {'masters'})
+
+ local ok,result = connection:exec()
+
+ if ok and result and type(result) == 'table' then
+ local masters = {}
+ for _,m in ipairs(result) do
+ local master = flatten_redis_table(m)
+
+ if params.sentinel_masters_pattern then
+ if master.name:match(params.sentinel_masters_pattern) then
+ lutil.debugm(N, 'found master %s with ip %s and port %s',
+ master.name, master.ip, master.port)
+ masters[master.name] = master
+ else
+ lutil.debugm(N, 'skip master %s with ip %s and port %s, pattern %s',
+ master.name, master.ip, master.port, params.sentinel_masters_pattern)
+ end
+ else
+ lutil.debugm(N, 'found master %s with ip %s and port %s',
+ master.name, master.ip, master.port)
+ masters[master.name] = master
+ end
+ end
+
+ -- For each master we need to get a list of slaves
+ for k,v in pairs(masters) do
+ v.slaves = {}
+ local slave_result
+
+ connection:add_cmd('SENTINEL', {'slaves', k})
+ ok,slave_result = connection:exec()
+
+ if ok then
+ for _,s in ipairs(slave_result) do
+ local slave = flatten_redis_table(s)
+ lutil.debugm(N, rspamd_config,
+ 'found slave form master %s with ip %s and port %s',
+ v.name, slave.ip, slave.port)
+ v.slaves[#v.slaves + 1] = slave
+ end
+ end
+ end
+
+ -- We now form new strings for masters and slaves
+ local read_servers_tbl, write_servers_tbl = {}, {}
+
+ for _,master in pairs(masters) do
+ write_servers_tbl[#write_servers_tbl + 1] = string.format(
+ '%s:%s', master.ip, master.port
+ )
+ read_servers_tbl[#read_servers_tbl + 1] = string.format(
+ '%s:%s', master.ip, master.port
+ )
+
+ for _,slave in ipairs(master.slaves) do
+ read_servers_tbl[#read_servers_tbl + 1] = string.format(
+ '%s:%s', slave.ip, slave.port
+ )
+ end
+ end
+
+ local read_servers_str = table.concat(read_servers_tbl, ',')
+ local write_servers_str = table.concat(read_servers_tbl, ',')
+
+ lutil.debugm(N, rspamd_config,
+ 'new servers list: %s read; %s write', read_servers_str, write_servers_str)
+
+ if read_servers_str ~= params.read_servers_str then
+ local upstream_list = require "rspamd_upstream_list"
+
+ local read_upstreams = upstream_list.create(rspamd_config,
+ read_servers_str, 6379)
+
+ if read_upstreams then
+ logger.infox(rspamd_config, 'sentinel %s: replace read servers with new list: %s',
+ addr:get_addr():to_string(true), read_servers_str)
+ params.read_servers = read_upstreams
+ params.read_servers_str = read_servers_str
+ end
+ end
+
+ if write_servers_str ~= params.write_servers_str then
+ local upstream_list = require "rspamd_upstream_list"
+
+ local write_upstreams = upstream_list.create(rspamd_config,
+ write_servers_str, 6379)
+
+ if write_upstreams then
+ logger.infox(rspamd_config, 'sentinel %s: replace write servers with new list: %s',
+ addr:get_addr():to_string(true), write_servers_str)
+ params.write_servers = write_upstreams
+ params.write_servers_str = write_servers_str
+ end
+ end
+
+ addr:ok()
+ else
+ logger.errx('cannot get data from Redis Sentinel %s: %s',
+ addr:get_addr():to_string(true), result)
+ addr:fail()
+ end
+
+end
+
+local function add_redis_sentinels(params)
+ local upstream_list = require "rspamd_upstream_list"
+
+ local upstreams_sentinels = upstream_list.create(rspamd_config,
+ params.sentinels, 5000)
+
+ if not upstreams_sentinels then
+ logger.errx(rspamd_config, 'cannot load redis sentinels string: %s',
+ params.sentinels)
+
+ return
+ end
+
+ params.sentinels = upstreams_sentinels
+
+ if not params.sentinel_watch_time then
+ params.sentinel_watch_time = 60 -- Each minute
+ end
+
+ rspamd_config:add_on_load(function(cfg, ev_base, worker)
+ local initialised = false
+ if worker:is_scanner() then
+ rspamd_config:add_periodic(ev_base, 0.0, function()
+ redis_query_sentinel(ev_base, params, initialised)
+ initialised = true
+
+ return params.sentinel_watch_time
+ end, false)
+ end
+ end)
+end
+
+local cached_results = {}
+
+local function calculate_redis_hash(params)
+ local cr = require "rspamd_cryptobox_hash"
+
+ local h = cr.create()
+
+ local function rec_hash(k, v)
+ if type(v) == 'string' then
+ h:update(k)
+ h:update(v)
+ elseif type(v) == 'number' then
+ h:update(k)
+ h:update(tostring(v))
+ elseif type(v) == 'table' then
+ for kk,vv in pairs(v) do
+ rec_hash(kk, vv)
+ end
+ end
+ end
+
+ rec_hash(params)
+
+ return h:base32()
+end
+
--[[[
-- @module lua_redis
-- This module contains helper functions for working with Redis
--]]
-
local function try_load_redis_servers(options, rspamd_config, result)
local default_port = 6379
local default_timeout = 1.0
@@ -71,6 +266,8 @@ local function try_load_redis_servers(options, rspamd_config, result)
upstreams_read = upstream_list.create(options['read_servers'],
default_port)
end
+
+ result.read_servers_str = options['read_servers']
elseif options['servers'] then
if rspamd_config then
upstreams_read = upstream_list.create(rspamd_config,
@@ -78,6 +275,8 @@ local function try_load_redis_servers(options, rspamd_config, result)
else
upstreams_read = upstream_list.create(options['servers'], default_port)
end
+
+ result.read_servers_str = options['servers']
read_only = false
elseif options['server'] then
if rspamd_config then
@@ -86,6 +285,8 @@ local function try_load_redis_servers(options, rspamd_config, result)
else
upstreams_read = upstream_list.create(options['server'], default_port)
end
+
+ result.read_servers_str = options['server']
read_only = false
end
@@ -98,9 +299,11 @@ local function try_load_redis_servers(options, rspamd_config, result)
upstreams_write = upstream_list.create(options['write_servers'],
default_port)
end
+ result.write_servers_str = options['write_servers']
read_only = false
elseif not read_only then
upstreams_write = upstreams_read
+ result.write_servers_str = result.read_servers_str
end
end
@@ -144,10 +347,29 @@ local function try_load_redis_servers(options, rspamd_config, result)
if upstreams_read then
result.read_servers = upstreams_read
+
if upstreams_write then
result.write_servers = upstreams_write
end
+ local h = calculate_redis_hash(result)
+
+ if cached_results[h] then
+ for k,v in pairs(cached_results[h]) do
+ result[k] = v
+ end
+ lutil.debugm(N, 'reused redis server: %s', result)
+ return true
+ end
+
+ result.hash = h
+ cached_results[h] = result
+
+ if not result.read_only and options.sentinels then
+ result.sentinels = options.sentinels
+ add_redis_sentinels(result)
+ end
+
lutil.debugm(N, 'loaded redis server: %s', result)
return true
end