diff options
Diffstat (limited to 'lualib/lua_redis.lua')
-rw-r--r-- | lualib/lua_redis.lua | 224 |
1 files changed, 223 insertions, 1 deletions
diff --git a/lualib/lua_redis.lua b/lualib/lua_redis.lua index 22ab675fa..e77f0069f 100644 --- a/lualib/lua_redis.lua +++ b/lualib/lua_redis.lua @@ -32,6 +32,9 @@ local common_schema = ts.shape { prefix = ts.string:is_optional(), password = ts.string:is_optional(), expand_keys = ts.boolean:is_optional(), + sentinels = (ts.string + ts.array_of(ts.string)):is_optional(), + sentinel_watch_time = (ts.number + ts.string / lutil.parse_time_interval):is_optional(), + sentinel_masters_pattern = ts.string:is_optional(), } local config_schema = @@ -48,11 +51,203 @@ local config_schema = exports.config_schema = config_schema + +local function redis_query_sentinel(ev_base, params, initialised) + local function flatten_redis_table(tbl) + local res = {} + for i=1,#tbl,2 do + res[tbl[i]] = tbl[i + 1] + end + + return res + end + -- Coroutines syntax + local rspamd_redis = require "rspamd_redis" + local addr = params.sentinels:get_upstream_round_robin() + + local is_ok, connection = rspamd_redis.connect_sync({ + host = addr:get_addr(), + timeout = params.timeout, + config = rspamd_config, + ev_base = ev_base, + }) + + if not is_ok then + logger.errx(rspamd_config, 'cannot connect sentinel at address: %s', + tostring(addr:get_addr())) + addr:fail() + + return + end + + -- Get masters list + connection:add_cmd('SENTINEL', {'masters'}) + + local ok,result = connection:exec() + + if ok and result and type(result) == 'table' then + local masters = {} + for _,m in ipairs(result) do + local master = flatten_redis_table(m) + + if params.sentinel_masters_pattern then + if master.name:match(params.sentinel_masters_pattern) then + lutil.debugm(N, 'found master %s with ip %s and port %s', + master.name, master.ip, master.port) + masters[master.name] = master + else + lutil.debugm(N, 'skip master %s with ip %s and port %s, pattern %s', + master.name, master.ip, master.port, params.sentinel_masters_pattern) + end + else + lutil.debugm(N, 'found master %s with ip %s and port %s', + master.name, master.ip, master.port) + masters[master.name] = master + end + end + + -- For each master we need to get a list of slaves + for k,v in pairs(masters) do + v.slaves = {} + local slave_result + + connection:add_cmd('SENTINEL', {'slaves', k}) + ok,slave_result = connection:exec() + + if ok then + for _,s in ipairs(slave_result) do + local slave = flatten_redis_table(s) + lutil.debugm(N, rspamd_config, + 'found slave form master %s with ip %s and port %s', + v.name, slave.ip, slave.port) + v.slaves[#v.slaves + 1] = slave + end + end + end + + -- We now form new strings for masters and slaves + local read_servers_tbl, write_servers_tbl = {}, {} + + for _,master in pairs(masters) do + write_servers_tbl[#write_servers_tbl + 1] = string.format( + '%s:%s', master.ip, master.port + ) + read_servers_tbl[#read_servers_tbl + 1] = string.format( + '%s:%s', master.ip, master.port + ) + + for _,slave in ipairs(master.slaves) do + read_servers_tbl[#read_servers_tbl + 1] = string.format( + '%s:%s', slave.ip, slave.port + ) + end + end + + local read_servers_str = table.concat(read_servers_tbl, ',') + local write_servers_str = table.concat(read_servers_tbl, ',') + + lutil.debugm(N, rspamd_config, + 'new servers list: %s read; %s write', read_servers_str, write_servers_str) + + if read_servers_str ~= params.read_servers_str then + local upstream_list = require "rspamd_upstream_list" + + local read_upstreams = upstream_list.create(rspamd_config, + read_servers_str, 6379) + + if read_upstreams then + logger.infox(rspamd_config, 'sentinel %s: replace read servers with new list: %s', + addr:get_addr():to_string(true), read_servers_str) + params.read_servers = read_upstreams + params.read_servers_str = read_servers_str + end + end + + if write_servers_str ~= params.write_servers_str then + local upstream_list = require "rspamd_upstream_list" + + local write_upstreams = upstream_list.create(rspamd_config, + write_servers_str, 6379) + + if write_upstreams then + logger.infox(rspamd_config, 'sentinel %s: replace write servers with new list: %s', + addr:get_addr():to_string(true), write_servers_str) + params.write_servers = write_upstreams + params.write_servers_str = write_servers_str + end + end + + addr:ok() + else + logger.errx('cannot get data from Redis Sentinel %s: %s', + addr:get_addr():to_string(true), result) + addr:fail() + end + +end + +local function add_redis_sentinels(params) + local upstream_list = require "rspamd_upstream_list" + + local upstreams_sentinels = upstream_list.create(rspamd_config, + params.sentinels, 5000) + + if not upstreams_sentinels then + logger.errx(rspamd_config, 'cannot load redis sentinels string: %s', + params.sentinels) + + return + end + + params.sentinels = upstreams_sentinels + + if not params.sentinel_watch_time then + params.sentinel_watch_time = 60 -- Each minute + end + + rspamd_config:add_on_load(function(cfg, ev_base, worker) + local initialised = false + if worker:is_scanner() then + rspamd_config:add_periodic(ev_base, 0.0, function() + redis_query_sentinel(ev_base, params, initialised) + initialised = true + + return params.sentinel_watch_time + end, false) + end + end) +end + +local cached_results = {} + +local function calculate_redis_hash(params) + local cr = require "rspamd_cryptobox_hash" + + local h = cr.create() + + local function rec_hash(k, v) + if type(v) == 'string' then + h:update(k) + h:update(v) + elseif type(v) == 'number' then + h:update(k) + h:update(tostring(v)) + elseif type(v) == 'table' then + for kk,vv in pairs(v) do + rec_hash(kk, vv) + end + end + end + + rec_hash(params) + + return h:base32() +end + --[[[ -- @module lua_redis -- This module contains helper functions for working with Redis --]] - local function try_load_redis_servers(options, rspamd_config, result) local default_port = 6379 local default_timeout = 1.0 @@ -71,6 +266,8 @@ local function try_load_redis_servers(options, rspamd_config, result) upstreams_read = upstream_list.create(options['read_servers'], default_port) end + + result.read_servers_str = options['read_servers'] elseif options['servers'] then if rspamd_config then upstreams_read = upstream_list.create(rspamd_config, @@ -78,6 +275,8 @@ local function try_load_redis_servers(options, rspamd_config, result) else upstreams_read = upstream_list.create(options['servers'], default_port) end + + result.read_servers_str = options['servers'] read_only = false elseif options['server'] then if rspamd_config then @@ -86,6 +285,8 @@ local function try_load_redis_servers(options, rspamd_config, result) else upstreams_read = upstream_list.create(options['server'], default_port) end + + result.read_servers_str = options['server'] read_only = false end @@ -98,9 +299,11 @@ local function try_load_redis_servers(options, rspamd_config, result) upstreams_write = upstream_list.create(options['write_servers'], default_port) end + result.write_servers_str = options['write_servers'] read_only = false elseif not read_only then upstreams_write = upstreams_read + result.write_servers_str = result.read_servers_str end end @@ -144,10 +347,29 @@ local function try_load_redis_servers(options, rspamd_config, result) if upstreams_read then result.read_servers = upstreams_read + if upstreams_write then result.write_servers = upstreams_write end + local h = calculate_redis_hash(result) + + if cached_results[h] then + for k,v in pairs(cached_results[h]) do + result[k] = v + end + lutil.debugm(N, 'reused redis server: %s', result) + return true + end + + result.hash = h + cached_results[h] = result + + if not result.read_only and options.sentinels then + result.sentinels = options.sentinels + add_redis_sentinels(result) + end + lutil.debugm(N, 'loaded redis server: %s', result) return true end |