You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

lua_redis.lua 44KB


  1. --[[
  2. Copyright (c) 2017, Vsevolod Stakhov <vsevolod@highsecure.ru>
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ]]--
  13. local logger = require "rspamd_logger"
  14. local lutil = require "lua_util"
  15. local rspamd_util = require "rspamd_util"
  16. local ts = require("tableshape").types
  local exports = {}
  local E = {}
  local N = "lua_redis"

  -- Builds a fresh "seconds as a number, or a time interval string" type;
  -- strings are converted with lutil.parse_time_interval
  local function time_interval_type()
    return ts.number + ts.string / lutil.parse_time_interval
  end

  -- Builds a fresh "one server string or a list of server strings" type
  local function server_list_type()
    return ts.string + ts.array_of(ts.string)
  end

  -- Optional settings accepted next to any form of server definition
  local common_schema = ts.shape {
    timeout = time_interval_type():is_optional(),
    db = ts.string:is_optional(),
    database = ts.string:is_optional(),
    dbname = ts.string:is_optional(),
    prefix = ts.string:is_optional(),
    password = ts.string:is_optional(),
    expand_keys = ts.boolean:is_optional(),
    sentinels = server_list_type():is_optional(),
    sentinel_watch_time = time_interval_type():is_optional(),
    sentinel_masters_pattern = ts.string:is_optional(),
    sentinel_master_maxerrors = (ts.number + ts.string / tonumber):is_optional(),
  }

  -- Wraps `fields` into a shape that additionally accepts the common options
  local function with_common_opts(fields)
    return ts.shape(fields, {extra_opts = common_schema})
  end

  -- A redis definition is one of: separate read/write servers,
  -- `servers` or `server`
  local config_schema =
      with_common_opts({
        read_servers = server_list_type(),
        write_servers = server_list_type(),
      }) +
      with_common_opts({ servers = server_list_type() }) +
      with_common_opts({ server = server_list_type() })

  exports.config_schema = config_schema
  -- Queries one Redis Sentinel for the current masters and their slaves.
  -- When the resulting server lists differ from the cached strings, it replaces
  -- `params.read_servers` and `params.write_servers` with new upstream lists.
  --
  -- Masters go to both the read and the write list. A slave goes to the read
  -- list only when its `master-link-status` is 'ok'. A failure watcher on the
  -- new write list queries the sentinel again, once, when a master exceeds
  -- `params.sentinel_master_maxerrors` failures.
  --
  -- @param ev_base event base used for the synchronous sentinel connection
  -- @param params redis params table; `params.sentinels` must already be an
  --   upstream list (add_redis_sentinels converts it). Mutated in place.
  -- @param initialised not read by this function (only forwarded on re-query)
  local function redis_query_sentinel(ev_base, params, initialised)
    -- Converts a flat reply {k1, v1, k2, v2, ...} into {k1 = v1, k2 = v2, ...}
    local function flatten_redis_table(tbl)
      local res = {}
      for i = 1, #tbl, 2 do
        res[tbl[i]] = tbl[i + 1]
      end
      return res
    end

    -- Coroutines syntax
    local rspamd_redis = require "rspamd_redis"
    local addr = params.sentinels:get_upstream_round_robin()

    if not addr then
      logger.errx(rspamd_config, 'cannot select redis sentinel to query')
      return
    end

    local is_ok, connection = rspamd_redis.connect_sync({
      host = addr:get_addr(),
      timeout = params.timeout,
      config = rspamd_config,
      ev_base = ev_base,
    })

    if not is_ok then
      logger.errx(rspamd_config, 'cannot connect sentinel at address: %s',
          tostring(addr:get_addr()))
      addr:fail()
      return
    end

    -- Get masters list
    connection:add_cmd('SENTINEL', {'masters'})
    local ok, result = connection:exec()

    if not (ok and type(result) == 'table') then
      logger.errx(rspamd_config, 'cannot get data from Redis Sentinel %s: %s',
          addr:get_addr():to_string(true), result)
      addr:fail()
      return
    end

    local masters = {}
    local pattern = params.sentinel_masters_pattern
    for _, m in ipairs(result) do
      if type(m) == 'table' then
        local master = flatten_redis_table(m)
        if not pattern or master.name:match(pattern) then
          lutil.debugm(N, 'found master %s with ip %s and port %s',
              master.name, master.ip, master.port)
          masters[master.name] = master
        else
          lutil.debugm(N, 'skip master %s with ip %s and port %s, pattern %s',
              master.name, master.ip, master.port, pattern)
        end
      end
    end

    -- For each master we need to get a list of slaves
    for k, v in pairs(masters) do
      v.slaves = {}
      connection:add_cmd('SENTINEL', {'slaves', k})
      local slaves_ok, slave_result = connection:exec()

      -- A failed or malformed reply leaves the master without slaves
      if slaves_ok and type(slave_result) == 'table' then
        for _, s in ipairs(slave_result) do
          if type(s) == 'table' then
            local slave = flatten_redis_table(s)
            lutil.debugm(N, rspamd_config,
                'found slave for master %s with ip %s and port %s',
                v.name, slave.ip, slave.port)
            v.slaves[#v.slaves + 1] = slave
          end
        end
      else
        lutil.debugm(N, rspamd_config, 'cannot get slaves for master %s: %s',
            k, slave_result)
      end
    end

    -- We now form new strings for masters and slaves
    local read_servers_tbl, write_servers_tbl = {}, {}
    for _, master in pairs(masters) do
      local master_str = string.format('%s:%s', master.ip, master.port)
      write_servers_tbl[#write_servers_tbl + 1] = master_str
      read_servers_tbl[#read_servers_tbl + 1] = master_str

      for _, slave in ipairs(master.slaves) do
        if slave['master-link-status'] == 'ok' then
          read_servers_tbl[#read_servers_tbl + 1] = string.format(
              '%s:%s', slave.ip, slave.port)
        end
      end
    end

    -- Sorting makes the strings comparable with the previously stored ones
    table.sort(read_servers_tbl)
    table.sort(write_servers_tbl)

    local read_servers_str = table.concat(read_servers_tbl, ',')
    local write_servers_str = table.concat(write_servers_tbl, ',')

    lutil.debugm(N, rspamd_config,
        'new servers list: %s read; %s write',
        read_servers_str,
        write_servers_str)

    if read_servers_str ~= params.read_servers_str then
      local upstream_list = require "rspamd_upstream_list"

      local read_upstreams = upstream_list.create(rspamd_config,
          read_servers_str, 6379)

      if read_upstreams then
        logger.infox(rspamd_config, 'sentinel %s: replace read servers with new list: %s',
            addr:get_addr():to_string(true), read_servers_str)
        params.read_servers = read_upstreams
        params.read_servers_str = read_servers_str
      end
    end

    if write_servers_str ~= params.write_servers_str then
      local upstream_list = require "rspamd_upstream_list"

      local write_upstreams = upstream_list.create(rspamd_config,
          write_servers_str, 6379)

      if write_upstreams then
        logger.infox(rspamd_config, 'sentinel %s: replace write servers with new list: %s',
            addr:get_addr():to_string(true), write_servers_str)
        params.write_servers = write_upstreams
        params.write_servers_str = write_servers_str

        local queried = false

        local function monitor_failures(up, _, count)
          if count > params.sentinel_master_maxerrors and not queried then
            logger.infox(rspamd_config, 'sentinel: master with address %s, caused %s failures, try to query sentinel',
                up:get_addr():to_string(true), count)
            queried = true -- Avoid multiple checks caused by this monitor
            redis_query_sentinel(ev_base, params, true)
          end
        end

        write_upstreams:add_watcher('failure', monitor_failures)
      end
    end

    addr:ok()
  end
  -- Converts `params.sentinels` (a string or a list of strings) into an
  -- upstream list and registers periodic sentinel queries on scanner workers.
  -- It also sets numeric values for `sentinel_watch_time` (default 60) and
  -- `sentinel_master_maxerrors` (default 2). Config may supply these as
  -- strings, e.g. '1min' or '3'.
  -- @param params redis params table; mutated in place
  local function add_redis_sentinels(params)
    local upstream_list = require "rspamd_upstream_list"

    -- NOTE(review): 5000 is the port used for sentinels given without one;
    -- the stock Redis Sentinel port is 26379 - confirm this default is intended
    local upstreams_sentinels = upstream_list.create(rspamd_config,
        params.sentinels, 5000)

    if not upstreams_sentinels then
      logger.errx(rspamd_config, 'cannot load redis sentinels string: %s',
          params.sentinels)

      return
    end

    params.sentinels = upstreams_sentinels

    -- The periodic callback returns this value as the next interval, so it
    -- must be a number
    local watch_time = params.sentinel_watch_time
    if type(watch_time) == 'string' then
      watch_time = tonumber(watch_time) or lutil.parse_time_interval(watch_time)
    end
    params.sentinel_watch_time = watch_time or 60 -- Each minute

    -- Compared numerically against the upstream failure count
    params.sentinel_master_maxerrors =
        tonumber(params.sentinel_master_maxerrors) or 2 -- Maximum number of errors before rechecking

    rspamd_config:add_on_load(function(_, ev_base, worker)
      local initialised = false
      if worker:is_scanner() then
        rspamd_config:add_periodic(ev_base, 0.0, function()
          redis_query_sentinel(ev_base, params, initialised)
          initialised = true

          return params.sentinel_watch_time
        end, false)
      end
    end)
  end
  -- Registry of already parsed redis configurations, keyed by calculate_redis_hash()
  local cached_results = {}

  -- Computes a digest of the scalar contents of a redis params table, so that
  -- equal configurations can share one params object.
  --
  -- Each string, number and boolean is hashed together with its key and type.
  -- Nested tables are walked recursively in sorted key order, so the digest does
  -- not depend on `pairs` traversal order. Other values (userdata such as
  -- upstream lists, functions) are skipped.
  -- @param params table of redis parameters
  -- @return base32-encoded digest string
  local function calculate_redis_hash(params)
    local cr = require "rspamd_cryptobox_hash"
    local h = cr.create()

    -- Strict ordering for keys of mixed types: by type name, then by value
    local function key_less(a, b)
      local ta, tb = type(a), type(b)
      if ta ~= tb then
        return ta < tb
      end
      if ta == 'number' or ta == 'string' then
        return a < b
      end
      return tostring(a) < tostring(b)
    end

    -- Length-prefixed chunks keep ('ab', 'c') and ('a', 'bc') distinct
    local function feed(s)
      h:update(#s .. ':' .. s)
    end

    local function rec_hash(k, v)
      local vt = type(v)
      if vt == 'string' or vt == 'number' or vt == 'boolean' then
        feed(type(k))
        feed(tostring(k))
        feed(vt)
        feed(tostring(v))
      elseif vt == 'table' then
        -- Include the table's own key so {a = {x = 1}} and {b = {x = 1}} differ
        feed(type(k))
        feed(tostring(k))
        feed('{')
        local keys = {}
        for kk in pairs(v) do
          keys[#keys + 1] = kk
        end
        table.sort(keys, key_less)
        for _, kk in ipairs(keys) do
          rec_hash(kk, v[kk])
        end
        feed('}')
      end
    end

    rec_hash('top', params)

    return h:base32()
  end
  -- Merges redis settings from `options` into `redis_params`.
  -- Callers invoke it from the most specific section to the least specific
  -- one, so a value already present in `redis_params` is kept. `timeout` and
  -- `expand_keys` still holding their default values may be overridden by a
  -- later call.
  -- @param options table with raw redis options (a module or global section)
  -- @param redis_params table being filled; mutated in place
  local function process_redis_opts(options, redis_params)
    local default_timeout = 1.0
    local default_expand_keys = false

    if not redis_params['timeout'] or redis_params['timeout'] == default_timeout then
      local timeout = options['timeout']
      if timeout then
        -- Accept plain numbers as well as time interval strings such as '2.5s';
        -- an unparseable value falls back to the default instead of nil
        redis_params['timeout'] = tonumber(timeout) or
            (type(timeout) == 'string' and lutil.parse_time_interval(timeout)) or
            default_timeout
      else
        redis_params['timeout'] = default_timeout
      end
    end

    if options['prefix'] and not redis_params['prefix'] then
      redis_params['prefix'] = options['prefix']
    end

    -- Replace only a missing or default value, so that a more specific
    -- `expand_keys = true` is not reset by a later section that omits it
    if type(redis_params['expand_keys']) ~= 'boolean' or
        redis_params['expand_keys'] == default_expand_keys then
      if type(options['expand_keys']) == 'boolean' then
        redis_params['expand_keys'] = options['expand_keys']
      else
        redis_params['expand_keys'] = default_expand_keys
      end
    end

    if not redis_params['db'] then
      if options['db'] then
        redis_params['db'] = tostring(options['db'])
      elseif options['dbname'] then
        redis_params['db'] = tostring(options['dbname'])
      elseif options['database'] then
        redis_params['db'] = tostring(options['database'])
      end
    end

    if options['password'] and not redis_params['password'] then
      redis_params['password'] = options['password']
    end

    -- Sentinel settings (named as in the config schema); `sentinels` is the
    -- field that maybe_return_cached() checks to enable sentinel support
    for _, opt in ipairs({ 'sentinels', 'sentinel_watch_time',
                           'sentinel_masters_pattern', 'sentinel_master_maxerrors' }) do
      if options[opt] and not redis_params[opt] then
        redis_params[opt] = options[opt]
      end
    end

    -- Legacy singular option, kept for backward compatibility
    if not redis_params.sentinel and options.sentinel then
      redis_params.sentinel = options.sentinel
    end
  end
  -- Fills gaps in `redis_params` from the global `redis` config section.
  -- The `module` sub-table of that section, when present, is applied before
  -- the section itself, so module-specific values take precedence.
  -- Does nothing when no config object or no `redis` section is available.
  local function enrich_defaults(rspamd_config, module, redis_params)
    if not rspamd_config then
      return
    end

    local opts = rspamd_config:get_all_opt('redis')
    if not opts then
      return
    end

    local module_opts = module and opts[module]
    if module_opts then
      process_redis_opts(module_opts, redis_params)
    end

    process_redis_opts(opts, redis_params)
  end
  -- Deduplicates redis params tables by content digest.
  -- If an equal configuration was registered earlier, that table is returned
  -- and `redis_params` is discarded. Otherwise `redis_params` is registered,
  -- its digest is stored in `redis_params.hash`, sentinel monitoring is set up
  -- for writable configurations that define `sentinels`, and it is returned.
  local function maybe_return_cached(redis_params)
    local digest = calculate_redis_hash(redis_params)
    local known = cached_results[digest]

    if known then
      lutil.debugm(N, 'reused redis server: %s', redis_params)
      return known
    end

    redis_params.hash = digest
    cached_results[digest] = redis_params

    local wants_sentinels = redis_params.sentinels and not redis_params.read_only
    if wants_sentinels then
      add_redis_sentinels(redis_params)
    end

    lutil.debugm(N, 'loaded new redis server: %s', redis_params)
    return redis_params
  end
  278. --[[[
  279. -- @module lua_redis
  280. -- This module contains helper functions for working with Redis
  281. --]]
  -- Parses server definitions from `options` into `result`.
  -- Read servers come from `read_servers`, else `servers`, else `server`.
  -- Write servers come from `write_servers` or, for `servers`/`server`, reuse
  -- the read list; only `read_servers` without `write_servers` is read-only.
  -- The other redis settings are merged via process_redis_opts.
  -- @param options table with redis options
  -- @param rspamd_config config object or nil
  -- @param result table to fill; mutated in place
  -- @return true when a read upstream list was created, false otherwise
  local function process_redis_options(options, rspamd_config, result)
    local default_port = 6379
    local upstream_list = require "rspamd_upstream_list"

    -- Creates an upstream list, passing the config object only when present
    local function make_upstreams(definition)
      if rspamd_config then
        return upstream_list.create(rspamd_config, definition, default_port)
      end
      return upstream_list.create(definition, default_port)
    end

    local read_only = true
    local upstreams_read, upstreams_write

    -- Try to get read servers; `servers`/`server` imply writes are allowed
    local read_definition
    if options['read_servers'] then
      read_definition = options['read_servers']
    elseif options['servers'] then
      read_definition = options['servers']
      read_only = false
    elseif options['server'] then
      read_definition = options['server']
      read_only = false
    end

    if read_definition then
      upstreams_read = make_upstreams(read_definition)
      result.read_servers_str = read_definition
    end

    if upstreams_read then
      local write_definition = options['write_servers']
      if write_definition then
        upstreams_write = make_upstreams(write_definition)
        result.write_servers_str = write_definition
        read_only = false
      elseif not read_only then
        upstreams_write = upstreams_read
        result.write_servers_str = result.read_servers_str
      end
    end

    -- Store options
    process_redis_opts(options, result)

    if upstreams_write then
      result.read_only = false
    elseif read_only then
      result.read_only = true
    end

    if not upstreams_read then
      lutil.debugm(N, rspamd_config,
          'cannot load redis server from obj: %s, processed to %s',
          options, result)
      return false
    end

    result.read_servers = upstreams_read
    if upstreams_write then
      result.write_servers = upstreams_write
    end

    return true
  end
  --[[[
  @function try_load_redis_servers(options, rspamd_config, no_fallback, module_name)
  Tries to load redis servers from the specified `options` object.
  Unless `no_fallback` is set, missing settings are filled from the global
  `redis` section (its `module_name` sub-table first).
  Returns `redis_params` table (possibly a shared, previously cached one)
  or nil in case of failure
  --]]
  exports.try_load_redis_servers = function(options, rspamd_config, no_fallback, module_name)
    local result = {}

    if not process_redis_options(options, rspamd_config, result) then
      return
    end

    if not no_fallback then
      enrich_defaults(rspamd_config, module_name, result)
    end

    return maybe_return_cached(result)
  end
  -- Parses the redis server definition for a module, using either the
  -- module's own settings or the global `redis` section.
  --
  -- Lookup order:
  --   1. `module_opts`, or the module's config section: first its `redis`
  --      sub-table, then the section itself.
  --   2. Unless `no_fallback` is set, the global `redis` section: its
  --      `module_name` sub-table if present. Otherwise the global servers,
  --      unless the module is listed in `disabled_modules`.
  -- All attempts share one `result` table.
  local function rspamd_parse_redis_server(module_name, module_opts, no_fallback)
    local result = {}

    -- Loads servers from a module-level section. On success it applies the
    -- global defaults (unless no_fallback) and returns the cached params.
    local function try_module_section(section)
      if not process_redis_options(section, rspamd_config, result) then
        return nil
      end
      if not no_fallback then
        enrich_defaults(rspamd_config, module_name, result)
      end
      return maybe_return_cached(result)
    end

    -- Try local options
    lutil.debugm(N, rspamd_config, 'try load redis config for: %s', module_name)
    local opts = module_opts or rspamd_config:get_all_opt(module_name)

    if opts then
      if opts.redis then
        local loaded = try_module_section(opts.redis)
        if loaded then
          return loaded
        end
      end

      local loaded = try_module_section(opts)
      if loaded then
        return loaded
      end
    end

    if no_fallback then
      logger.infox(rspamd_config, "cannot find Redis definitions for %s and fallback is disabled",
          module_name)
      return nil
    end

    -- Try global options
    local global_opts = rspamd_config:get_all_opt('redis')
    if global_opts then
      local module_section = global_opts[module_name]
      if module_section then
        if process_redis_options(module_section, rspamd_config, result) then
          return maybe_return_cached(result)
        end
      else
        local ret = process_redis_options(global_opts, rspamd_config, result)

        -- Exclude disabled
        local disabled = global_opts['disabled_modules']
        if disabled then
          for _, name in ipairs(disabled) do
            if name == module_name then
              logger.infox(rspamd_config, "NOT using default redis server for module %s: it is disabled",
                  module_name)
              return nil
            end
          end
        end

        if ret then
          logger.infox(rspamd_config, "use default Redis settings for %s",
              module_name)
          return maybe_return_cached(result)
        end
      end
    end

    if result.read_servers then
      return maybe_return_cached(result)
    end

    return nil
  end
  --[[[
  -- @function lua_redis.parse_redis_server(module_name, module_opts, no_fallback)
  -- Extracts Redis server settings from configuration
  -- @param {string} module_name name of module to get settings for
  -- @param {table} module_opts settings for module or `nil` to fetch them from configuration
  -- @param {boolean} no_fallback should be `true` if global settings must not be used
  -- @return {table} redis server settings (or `nil` if none were found)
  -- @example
  -- local rconfig = lua_redis.parse_redis_server('my_module')
  -- -- rconfig contains upstream_list objects in ['write_servers'] and ['read_servers']
  -- -- ['timeout'] contains timeout in seconds
  -- -- ['expand_keys'] if true tells that redis key expansion is enabled
  --]]
  exports.parse_redis_server = rspamd_parse_redis_server
  -- Legacy name for the same function
  exports.rspamd_parse_redis_server = exports.parse_redis_server
  -- Per-command extractors returning the 1-based positions in `args` that hold
  -- Redis keys; key expansion is applied only at those positions.
  -- Every extractor returns a table (possibly empty), so callers can iterate
  -- the result with ipairs unconditionally.
  local process_cmd = {
    -- BITOP operation destkey key [key ...]
    bitop = function(args)
      local idx_l = {}
      for i = 2, #args do
        idx_l[#idx_l + 1] = i
      end
      return idx_l
    end,
    -- BLPOP key [key ...] timeout: every argument except the trailing timeout
    blpop = function(args)
      local idx_l = {}
      for i = 1, #args - 1 do
        idx_l[#idx_l + 1] = i
      end
      return idx_l
    end,
    -- EVAL script numkeys key [key ...] arg [arg ...]
    eval = function(args)
      local idx_l = {}
      local numkeys = tonumber(args[2])
      if numkeys and numkeys >= 1 then
        -- Never point past the arguments actually supplied
        for i = 3, math.min(numkeys + 2, #args) do
          idx_l[#idx_l + 1] = i
        end
      end
      return idx_l
    end,
    -- Commands whose only key is the first argument
    set = function(_)
      return {1}
    end,
    -- Commands where every argument is a key
    mget = function(args)
      local idx_l = {}
      for i = 1, #args do
        idx_l[#idx_l + 1] = i
      end
      return idx_l
    end,
    -- MSET key value [key value ...]
    mset = function(args)
      local idx_l = {}
      for i = 1, #args, 2 do
        idx_l[#idx_l + 1] = i
      end
      return idx_l
    end,
    -- SDIFFSTORE destination key [key ...]: the destination is a key as well
    sdiffstore = function(args)
      local idx_l = {}
      for i = 1, #args do
        idx_l[#idx_l + 1] = i
      end
      return idx_l
    end,
    -- SMOVE source destination member
    smove = function(_)
      return {1, 2}
    end,
    -- Commands without key arguments
    script = function(_)
      return {}
    end,
  }
  -- Maps every supported Redis command onto the key extractor that matches
  -- its argument layout. Commands not listed here make get_key_indexes() log
  -- a warning.

  -- ZINTERSTORE/ZUNIONSTORE destination numkeys key [key ...]: the destination
  -- plus `numkeys` source keys
  local function zstore_keys(args)
    local idx_l = {1}
    local numkeys = tonumber(args[2])
    if numkeys and numkeys >= 1 then
      for i = 3, math.min(numkeys + 2, #args) do
        idx_l[#idx_l + 1] = i
      end
    end
    return idx_l
  end

  -- Commands without key arguments
  for _, name in ipairs({
    'auth', 'bgrewriteaof', 'bgsave', 'client', 'cluster', 'command', 'config',
    'dbsize', 'debug', 'discard', 'echo', 'exec', 'flushall', 'flushdb', 'info',
    'keys', 'lastsave', 'migrate', 'monitor', 'multi', 'object', 'ping',
    'psubscribe', 'publish', 'pubsub', 'punsubscribe', 'quit', 'randomkey',
    'readonly', 'readwrite', 'role', 'save', 'scan', 'select', 'slaveof',
    'slowlog', 'subscribe', 'swapdb', 'sync', 'time', 'unsubscribe', 'unwatch',
    'wait',
  }) do
    process_cmd[name] = process_cmd.script
  end

  -- Commands whose only key is the first argument
  for _, name in ipairs({
    'append', 'bitcount', 'bitfield', 'bitpos', 'decr', 'decrby', 'dump',
    'expire', 'expireat', 'geoadd', 'geodist', 'geohash', 'geopos', 'georadius',
    'georadiusbymember', 'get', 'getbit', 'getrange', 'getset', 'hdel',
    'hexists', 'hget', 'hgetall', 'hincrby', 'hincrbyfloat', 'hkeys', 'hlen',
    'hmget', 'hmset', 'hscan', 'hset', 'hsetnx', 'hstrlen', 'hvals', 'incr',
    'incrby', 'incrbyfloat', 'lindex', 'linsert', 'llen', 'lpop', 'lpush',
    'lpushx', 'lrange', 'lrem', 'lset', 'ltrim', 'move', 'persist', 'pexpire',
    'pexpireat', 'pfadd', 'psetex', 'pttl', 'restore', 'rpop', 'rpush',
    'rpushx', 'sadd', 'scard', 'setbit', 'setex', 'setnx', 'setrange',
    'sismember', 'smembers', 'sort', 'spop', 'srandmember', 'srem', 'sscan',
    'strlen', 'ttl', 'type', 'zadd', 'zcard', 'zcount', 'zincrby', 'zlexcount',
    'zrange', 'zrangebylex', 'zrangebyscore', 'zrank', 'zrem',
    'zremrangebylex', 'zremrangebyrank', 'zremrangebyscore', 'zrevrange',
    'zrevrangebylex', 'zrevrangebyscore', 'zrevrank', 'zscan', 'zscore',
    -- Historical (non-existent) names kept for compatibility
    'zrembylex', 'zrembyrank', 'zrembyscore',
  }) do
    process_cmd[name] = process_cmd.set
  end

  -- Blocking commands: all keys followed by a timeout
  process_cmd.brpop = process_cmd.blpop
  process_cmd.brpoplpush = process_cmd.blpop

  -- Commands where every argument is a key
  for _, name in ipairs({
    'del', 'exists', 'pfcount', 'pfmerge', 'rename', 'renamenx', 'rpoplpush',
    'sdiff', 'sinter', 'sinterstore', 'sunion', 'sunionstore', 'touch',
    'unlink', 'watch',
  }) do
    process_cmd[name] = process_cmd.mget
  end

  process_cmd.msetnx = process_cmd.mset
  process_cmd.evalsha = process_cmd.eval
  process_cmd.zinterstore = zstore_keys
  process_cmd.zunionstore = zstore_keys
  -- Returns the positions of key arguments in `args` for the redis command
  -- `cmd` (case-insensitive). An unknown command is logged and yields an
  -- empty list. The result is always a table, safe to pass to ipairs.
  local function get_key_indexes(cmd, args)
    cmd = string.lower(cmd)

    local extractor = process_cmd[cmd]
    if not extractor then
      logger.warnx(rspamd_config, "Don't know how to extract keys for %s Redis command", cmd)
      return {}
    end

    -- Guard against extractors that return nothing and against missing args
    return extractor(args or {}) or {}
  end
  -- Returns `field` of the first address from task:get_from(source), or nil
  -- when there is no such address
  local function first_from_field(task, source, field)
    local addrs = task:get_from(source) or E
    return (addrs[1] or E)[field]
  end

  -- Generators for key-expansion variables; each computes one value from a task
  local gen_meta = {
    principal_recipient = function(task)
      return task:get_principal_recipient()
    end,
    principal_recipient_domain = function(task)
      local rcpt = task:get_principal_recipient()
      if rcpt then
        return string.match(rcpt, '.*@(.*)')
      end
    end,
    ip = function(task)
      local ip = task:get_ip()
      if ip and ip:is_valid() then
        return ip:to_string()
      end
    end,
    from = function(task)
      return first_from_field(task, 'smtp', 'addr')
    end,
    from_domain = function(task)
      return first_from_field(task, 'smtp', 'domain')
    end,
    from_domain_or_helo_domain = function(task)
      local domain = first_from_field(task, 'smtp', 'domain')
      if domain and #domain > 0 then
        return domain
      end
      return task:get_helo()
    end,
    mime_from = function(task)
      return first_from_field(task, 'mime', 'addr')
    end,
    mime_from_domain = function(task)
      return first_from_field(task, 'mime', 'domain')
    end,
  }
  -- Wraps domain generator `f` so its result is reduced with
  -- rspamd_util.get_tld; a missing domain yields no value
  local function gen_get_esld(f)
    return function(task)
      local domain = f(task)
      if domain then
        return rspamd_util.get_tld(domain)
      end
    end
  end
  -- `smtp_*` names are aliases of the envelope-sender generators
  gen_meta.smtp_from = gen_meta.from
  gen_meta.smtp_from_domain = gen_meta.from_domain
  gen_meta.smtp_from_domain_or_helo_domain = gen_meta.from_domain_or_helo_domain

  -- `esld_<name>` reduces the domain produced by generator `<name>`
  for _, base in ipairs({
    'principal_recipient_domain', 'from_domain', 'mime_from_domain',
    'from_domain_or_helo_domain',
  }) do
    gen_meta['esld_' .. base] = gen_get_esld(gen_meta[base])
  end

  gen_meta.esld_smtp_from_domain = gen_meta.esld_from_domain
  gen_meta.esld_smtp_from_domain_or_helo_domain = gen_meta.esld_from_domain_or_helo_domain
  -- Returns a lazily populated table of key-expansion variables for `task`.
  -- Lookups are case-insensitive: the name is lower-cased and any value already
  -- stored under it is returned. Otherwise the matching `gen_meta` generator,
  -- if any, runs and its result is stored for later lookups.
  local function get_key_expansion_metadata(task)
    local function lookup(self, name)
      local lname = string.lower(name)
      local known = rawget(self, lname)
      if known then
        return known
      end

      local generator = gen_meta[lname]
      if not generator then
        return known
      end

      local value = generator(task)
      rawset(self, lname, value)
      return value
    end

    return setmetatable({}, { __index = lookup })
  end
  -- Performs async call to redis hiding all complexity inside function
  -- task - rspamd_task
  -- redis_params - valid params returned by rspamd_parse_redis_server
  -- key - key to select upstream or nil to select round-robin/master-slave
  -- is_write - true if need to write to redis server
  -- callback - function to be called upon request is completed; it receives
  --   (err, data, addr) after the upstream has been marked ok/failed
  -- command - redis command
  -- args - table of arguments; with `expand_keys` enabled, string key
  --   arguments are templated in place
  -- extra_opts - table of optional request arguments
  -- Returns ret, conn from rspamd_redis.make_request plus the selected
  -- upstream, or false, nil, nil when the request cannot be started
  local function rspamd_redis_make_request(task, redis_params, key, is_write,
                                           callback, command, args, extra_opts)
    local addr
    local function rspamd_redis_make_request_cb(err, data)
      if err then
        addr:fail()
      else
        addr:ok()
      end
      callback(err, data, addr)
    end

    if not task or not redis_params or not callback or not command then
      return false,nil,nil
    end

    local rspamd_redis = require "rspamd_redis"

    local servers
    if is_write then
      servers = redis_params['write_servers']
    else
      servers = redis_params['read_servers']
    end

    -- e.g. a write request against read-only params
    if not servers then
      logger.errx(task, 'cannot select server to make redis request: no %s servers defined',
          is_write and 'write' or 'read')
      return false,nil,nil
    end

    if key then
      addr = servers:get_upstream_by_hash(key)
    elseif is_write then
      addr = servers:get_upstream_master_slave(key)
    else
      addr = servers:get_upstream_round_robin(key)
    end

    if not addr then
      logger.errx(task, 'cannot select server to make redis request')
      return false,nil,nil
    end

    if redis_params['expand_keys'] and args then
      local m = get_key_expansion_metadata(task)
      local indexes = get_key_indexes(command, args) or {}
      for _, i in ipairs(indexes) do
        -- Only string arguments can carry templates; skip absent positions
        if type(args[i]) == 'string' then
          args[i] = lutil.template(args[i], m)
        end
      end
    end

    local ip_addr = addr:get_addr()
    local options = {
      task = task,
      callback = rspamd_redis_make_request_cb,
      host = ip_addr,
      timeout = redis_params['timeout'],
      cmd = command,
      args = args
    }

    if extra_opts then
      for k,v in pairs(extra_opts) do
        options[k] = v
      end
    end

    if redis_params['password'] then
      options['password'] = redis_params['password']
    end

    if redis_params['db'] then
      options['dbname'] = redis_params['db']
    end

    lutil.debugm(N, task, 'perform request to redis server' ..
        ' (host=%s, timeout=%s): cmd: %s, arguments: %s', ip_addr,
        options.timeout, options.cmd, args)

    local ret,conn = rspamd_redis.make_request(options)

    if not ret then
      addr:fail()
      logger.warnx(task, "cannot make redis request to: %s", tostring(ip_addr))
    end

    return ret,conn,addr
  end
  --[[[
  -- @function lua_redis.redis_make_request(task, redis_params, key, is_write, callback, command, args, extra_opts)
  -- Sends an asynchronous request to Redis
  -- @param {rspamd_task} task task object
  -- @param {table} redis_params redis configuration in format returned by lua_redis.parse_redis_server()
  -- @param {string} key key to use for sharding (`nil` selects master-slave for writes, round-robin for reads)
  -- @param {boolean} is_write should be `true` if we are performing a write operation
  -- @param {function} callback callback function (first parameter is error if applicable, second is a 2D array (table))
  -- @param {string} command Redis command to run
  -- @param {table} args Numerically indexed table containing arguments for command
  -- @param {table} extra_opts optional table whose entries are copied into the request options
  -- @return {boolean,userdata,upstream} request status, connection and selected server
  --]]
  exports.redis_make_request = rspamd_redis_make_request
  -- Legacy name for the same function
  exports.rspamd_redis_make_request = exports.redis_make_request
  -- Variant of rspamd_redis_make_request for code that has an event base and
  -- a config object but no task. No key expansion is performed.
  -- @param ev_base event base to run the request on
  -- @param cfg config object (also used for logging)
  -- Remaining parameters are as in rspamd_redis_make_request.
  -- @return ret, conn from rspamd_redis.make_request plus the selected
  --   upstream, or false, nil, nil when the request cannot be started
  local function redis_make_request_taskless(ev_base, cfg, redis_params, key,
                                             is_write, callback, command, args, extra_opts)
    if not ev_base or not redis_params or not callback or not command then
      return false,nil,nil
    end

    local addr
    local function rspamd_redis_make_request_cb(err, data)
      if err then
        addr:fail()
      else
        addr:ok()
      end
      callback(err, data, addr)
    end

    local rspamd_redis = require "rspamd_redis"

    local servers
    if is_write then
      servers = redis_params['write_servers']
    else
      servers = redis_params['read_servers']
    end

    -- e.g. a write request against read-only params
    if not servers then
      logger.errx(cfg, 'cannot select server to make redis request: no %s servers defined',
          is_write and 'write' or 'read')
      return false,nil,nil
    end

    if key then
      addr = servers:get_upstream_by_hash(key)
    elseif is_write then
      addr = servers:get_upstream_master_slave(key)
    else
      addr = servers:get_upstream_round_robin(key)
    end

    if not addr then
      logger.errx(cfg, 'cannot select server to make redis request')
      return false,nil,nil
    end

    local options = {
      ev_base = ev_base,
      config = cfg,
      callback = rspamd_redis_make_request_cb,
      host = addr:get_addr(),
      timeout = redis_params['timeout'],
      cmd = command,
      args = args
    }

    if extra_opts then
      for k,v in pairs(extra_opts) do
        options[k] = v
      end
    end

    if redis_params['password'] then
      options['password'] = redis_params['password']
    end

    if redis_params['db'] then
      options['dbname'] = redis_params['db']
    end

    lutil.debugm(N, cfg, 'perform taskless request to redis server' ..
        ' (host=%s, timeout=%s): cmd: %s, arguments: %s', options.host,
        options.timeout, options.cmd, args)

    local ret,conn = rspamd_redis.make_request(options)

    if not ret then
      logger.errx(cfg, 'cannot execute redis request to: %s', tostring(options.host))
      addr:fail()
    end

    return ret,conn,addr
  end
  885. --[[[
  886. -- @function lua_redis.redis_make_request_taskless(ev_base, redis_params, key, is_write, callback, command, args)
  887. -- Sends a request to Redis in context where `task` is not available for some specific use-cases
  888. -- Identical to redis_make_request() except in that first parameter is an `event base` object
  889. --]]
  890. exports.rspamd_redis_make_request_taskless = redis_make_request_taskless
  891. exports.redis_make_request_taskless = redis_make_request_taskless
  892. local redis_scripts = {
  893. }
  894. local function script_set_loaded(script)
  895. if script.sha then
  896. script.loaded = true
  897. end
  898. local wait_table = {}
  899. for _,s in ipairs(script.waitq) do
  900. table.insert(wait_table, s)
  901. end
  902. script.waitq = {}
  903. for _,s in ipairs(wait_table) do
  904. s(script.loaded)
  905. end
  906. end
  907. local function prepare_redis_call(script)
  908. local function merge_tables(t1, t2)
  909. for k,v in pairs(t2) do t1[k] = v end
  910. end
  911. local servers = {}
  912. local options = {}
  913. if script.redis_params.read_servers then
  914. merge_tables(servers, script.redis_params.read_servers:all_upstreams())
  915. end
  916. if script.redis_params.write_servers then
  917. merge_tables(servers, script.redis_params.write_servers:all_upstreams())
  918. end
  919. -- Call load script on each server, set loaded flag
  920. script.in_flight = #servers
  921. for _,s in ipairs(servers) do
  922. local cur_opts = {
  923. host = s:get_addr(),
  924. timeout = script.redis_params['timeout'],
  925. cmd = 'SCRIPT',
  926. args = {'LOAD', script.script },
  927. upstream = s
  928. }
  929. if script.redis_params['password'] then
  930. cur_opts['password'] = script.redis_params['password']
  931. end
  932. if script.redis_params['db'] then
  933. cur_opts['dbname'] = script.redis_params['db']
  934. end
  935. table.insert(options, cur_opts)
  936. end
  937. return options
  938. end
  939. local function load_script_task(script, task)
  940. local rspamd_redis = require "rspamd_redis"
  941. local opts = prepare_redis_call(script)
  942. for _,opt in ipairs(opts) do
  943. opt.task = task
  944. opt.callback = function(err, data)
  945. if err then
  946. logger.errx(task, 'cannot upload script to %s: %s',
  947. opt.upstream:get_addr(), err)
  948. opt.upstream:fail()
  949. script.fatal_error = err
  950. else
  951. opt.upstream:ok()
  952. logger.infox(task,
  953. "uploaded redis script to %s with id %s, sha: %s",
  954. opt.upstream:get_addr(), script.id, data)
  955. script.sha = data -- We assume that sha is the same on all servers
  956. end
  957. script.in_flight = script.in_flight - 1
  958. if script.in_flight == 0 then
  959. script_set_loaded(script)
  960. end
  961. end
  962. local ret = rspamd_redis.make_request(opt)
  963. if not ret then
  964. logger.errx('cannot execute redis request to load script on %s',
  965. opt.upstream:get_addr())
  966. script.in_flight = script.in_flight - 1
  967. opt.upstream:fail()
  968. end
  969. if script.in_flight == 0 then
  970. script_set_loaded(script)
  971. end
  972. end
  973. end
  974. local function load_script_taskless(script, cfg, ev_base)
  975. local rspamd_redis = require "rspamd_redis"
  976. local opts = prepare_redis_call(script)
  977. for _,opt in ipairs(opts) do
  978. opt.config = cfg
  979. opt.ev_base = ev_base
  980. opt.callback = function(err, data)
  981. if err then
  982. logger.errx(cfg, 'cannot upload script to %s: %s',
  983. opt.upstream:get_addr(), err)
  984. opt.upstream:fail()
  985. script.fatal_error = err
  986. else
  987. opt.upstream:ok()
  988. logger.infox(cfg,
  989. "uploaded redis script to %s with id %s, sha: %s",
  990. opt.upstream:get_addr(), script.id, data)
  991. script.sha = data -- We assume that sha is the same on all servers
  992. script.fatal_error = nil
  993. end
  994. script.in_flight = script.in_flight - 1
  995. if script.in_flight == 0 then
  996. script_set_loaded(script)
  997. end
  998. end
  999. local ret = rspamd_redis.make_request(opt)
  1000. if not ret then
  1001. logger.errx('cannot execute redis request to load script on %s',
  1002. opt.upstream:get_addr())
  1003. script.in_flight = script.in_flight - 1
  1004. opt.upstream:fail()
  1005. end
  1006. if script.in_flight == 0 then
  1007. script_set_loaded(script)
  1008. end
  1009. end
  1010. end
  1011. local function load_redis_script(script, cfg, ev_base, _)
  1012. if script.redis_params then
  1013. load_script_taskless(script, cfg, ev_base)
  1014. end
  1015. end
  1016. local function add_redis_script(script, redis_params)
  1017. local new_script = {
  1018. loaded = false,
  1019. redis_params = redis_params,
  1020. script = script,
  1021. waitq = {}, -- callbacks pending for script being loaded
  1022. id = #redis_scripts + 1
  1023. }
  1024. -- Register on load function
  1025. rspamd_config:add_on_load(function(cfg, ev_base, worker)
  1026. local mult = 0.0
  1027. rspamd_config:add_periodic(ev_base, 0.0, function()
  1028. if not new_script.sha then
  1029. load_redis_script(new_script, cfg, ev_base, worker)
  1030. mult = mult + 1
  1031. return 1.0 * mult -- Check one more time in one second
  1032. end
  1033. return false
  1034. end, false)
  1035. end)
  1036. table.insert(redis_scripts, new_script)
  1037. return #redis_scripts
  1038. end
  1039. exports.add_redis_script = add_redis_script
  1040. local function exec_redis_script(id, params, callback, keys, args)
  1041. local redis_args = {}
  1042. if not redis_scripts[id] then
  1043. logger.errx("cannot find registered script with id %s", id)
  1044. return false
  1045. end
  1046. local script = redis_scripts[id]
  1047. if script.fatal_error then
  1048. callback(script.fatal_error, nil)
  1049. return true
  1050. end
  1051. if not script.redis_params then
  1052. callback('no redis servers defined', nil)
  1053. return true
  1054. end
  1055. local function do_call(can_reload)
  1056. local function redis_cb(err, data)
  1057. if not err then
  1058. callback(err, data)
  1059. elseif string.match(err, 'NOSCRIPT') then
  1060. -- Schedule restart
  1061. script.sha = nil
  1062. if can_reload then
  1063. table.insert(script.waitq, do_call)
  1064. if script.in_flight == 0 then
  1065. -- Reload scripts if this has not been initiated yet
  1066. if params.task then
  1067. load_script_task(script, params.task)
  1068. else
  1069. load_script_taskless(script, rspamd_config, params.ev_base)
  1070. end
  1071. end
  1072. else
  1073. callback(err, data)
  1074. end
  1075. else
  1076. callback(err, data)
  1077. end
  1078. end
  1079. if #redis_args == 0 then
  1080. table.insert(redis_args, script.sha)
  1081. table.insert(redis_args, tostring(#keys))
  1082. for _,k in ipairs(keys) do
  1083. table.insert(redis_args, k)
  1084. end
  1085. if type(args) == 'table' then
  1086. for _, a in ipairs(args) do
  1087. table.insert(redis_args, a)
  1088. end
  1089. end
  1090. end
  1091. if params.task then
  1092. if not rspamd_redis_make_request(params.task, script.redis_params,
  1093. params.key, params.is_write, redis_cb, 'EVALSHA', redis_args) then
  1094. callback('Cannot make redis request', nil)
  1095. end
  1096. else
  1097. if not redis_make_request_taskless(params.ev_base, rspamd_config,
  1098. script.redis_params,
  1099. params.key, params.is_write, redis_cb, 'EVALSHA', redis_args) then
  1100. callback('Cannot make redis request', nil)
  1101. end
  1102. end
  1103. end
  1104. if script.loaded then
  1105. do_call(true)
  1106. else
  1107. -- Delayed until scripts are loaded
  1108. if not params.task then
  1109. table.insert(script.waitq, do_call)
  1110. else
  1111. -- TODO: fix taskfull requests
  1112. callback('NOSCRIPT', nil)
  1113. end
  1114. end
  1115. return true
  1116. end
  1117. exports.exec_redis_script = exec_redis_script
  1118. local function redis_connect_sync(redis_params, is_write, key, cfg, ev_base)
  1119. if not redis_params then
  1120. return false,nil
  1121. end
  1122. local rspamd_redis = require "rspamd_redis"
  1123. local addr
  1124. if key then
  1125. if is_write then
  1126. addr = redis_params['write_servers']:get_upstream_by_hash(key)
  1127. else
  1128. addr = redis_params['read_servers']:get_upstream_by_hash(key)
  1129. end
  1130. else
  1131. if is_write then
  1132. addr = redis_params['write_servers']:get_upstream_master_slave(key)
  1133. else
  1134. addr = redis_params['read_servers']:get_upstream_round_robin(key)
  1135. end
  1136. end
  1137. if not addr then
  1138. logger.errx(cfg, 'cannot select server to make redis request')
  1139. end
  1140. local options = {
  1141. host = addr:get_addr(),
  1142. timeout = redis_params['timeout'],
  1143. config = cfg or rspamd_config,
  1144. ev_base = ev_base or rspamadm_ev_base,
  1145. session = redis_params.session or rspamadm_session
  1146. }
  1147. for k,v in pairs(redis_params) do
  1148. options[k] = v
  1149. end
  1150. if not options.config then
  1151. logger.errx('config is not set')
  1152. return false,nil,addr
  1153. end
  1154. if not options.ev_base then
  1155. logger.errx('ev_base is not set')
  1156. return false,nil,addr
  1157. end
  1158. if not options.session then
  1159. logger.errx('session is not set')
  1160. return false,nil,addr
  1161. end
  1162. local ret,conn = rspamd_redis.connect_sync(options)
  1163. if not ret then
  1164. logger.errx('cannot execute redis request: %s', conn)
  1165. addr:fail()
  1166. return false,nil,addr
  1167. end
  1168. if conn then
  1169. if redis_params['password'] then
  1170. conn:add_cmd('AUTH', {redis_params['password']})
  1171. end
  1172. if redis_params['db'] then
  1173. conn:add_cmd('SELECT', {tostring(redis_params['db'])})
  1174. elseif redis_params['dbname'] then
  1175. conn:add_cmd('SELECT', {tostring(redis_params['dbname'])})
  1176. end
  1177. end
  1178. return ret,conn,addr
  1179. end
  1180. exports.redis_connect_sync = redis_connect_sync
  1181. --[[[
  1182. -- @function lua_redis.request(redis_params, attrs, req)
  1183. -- Sends a request to Redis synchronously with coroutines or asynchronously using
  1184. -- a callback (modern API)
  1185. -- @param redis_params a table of redis server parameters
  1186. -- @param attrs a table of redis request attributes (e.g. task, or ev_base + cfg + session)
  1187. -- @param req a table of request: a command + command options
  1188. -- @return {result,data/connection,address} boolean result, connection object in case of async request and results if using coroutines, redis server address
  1189. --]]
  1190. exports.request = function(redis_params, attrs, req)
  1191. local lua_util = require "lua_util"
  1192. if not attrs or not redis_params or not req then
  1193. logger.errx('invalid arguments for redis request')
  1194. return false,nil,nil
  1195. end
  1196. if not (attrs.task or (attrs.config and attrs.ev_base)) then
  1197. logger.errx('invalid attributes for redis request')
  1198. return false,nil,nil
  1199. end
  1200. local opts = lua_util.shallowcopy(attrs)
  1201. local log_obj = opts.task or opts.config
  1202. local addr
  1203. if opts.callback then
  1204. -- Wrap callback
  1205. local callback = opts.callback
  1206. local function rspamd_redis_make_request_cb(err, data)
  1207. if err then
  1208. addr:fail()
  1209. else
  1210. addr:ok()
  1211. end
  1212. callback(err, data, addr)
  1213. end
  1214. opts.callback = rspamd_redis_make_request_cb
  1215. end
  1216. local rspamd_redis = require "rspamd_redis"
  1217. local is_write = opts.is_write
  1218. if opts.key then
  1219. if is_write then
  1220. addr = redis_params['write_servers']:get_upstream_by_hash(attrs.key)
  1221. else
  1222. addr = redis_params['read_servers']:get_upstream_by_hash(attrs.key)
  1223. end
  1224. else
  1225. if is_write then
  1226. addr = redis_params['write_servers']:get_upstream_master_slave(attrs.key)
  1227. else
  1228. addr = redis_params['read_servers']:get_upstream_round_robin(attrs.key)
  1229. end
  1230. end
  1231. if not addr then
  1232. logger.errx(log_obj, 'cannot select server to make redis request')
  1233. end
  1234. opts.host = addr:get_addr()
  1235. opts.timeout = redis_params.timeout
  1236. if type(req) == 'string' then
  1237. opts.cmd = req
  1238. else
  1239. -- XXX: modifies the input table
  1240. opts.cmd = table.remove(req, 1);
  1241. opts.args = req
  1242. end
  1243. if redis_params.password then
  1244. opts.password = redis_params.password
  1245. end
  1246. if redis_params.db then
  1247. opts.dbname = redis_params.db
  1248. end
  1249. lutil.debugm(N, 'perform generic request to redis server' ..
  1250. ' (host=%s, timeout=%s): cmd: %s, arguments: %s', addr,
  1251. opts.timeout, opts.cmd, opts.args)
  1252. if opts.callback then
  1253. local ret,conn = rspamd_redis.make_request(opts)
  1254. if not ret then
  1255. logger.errx(log_obj, 'cannot execute redis request')
  1256. addr:fail()
  1257. end
  1258. return ret,conn,addr
  1259. else
  1260. -- Coroutines version
  1261. local ret,conn = rspamd_redis.connect_sync(opts)
  1262. if not ret then
  1263. logger.errx(log_obj, 'cannot execute redis request')
  1264. addr:fail()
  1265. else
  1266. conn:add_cmd(opts.cmd, opts.args)
  1267. return conn:exec()
  1268. end
  1269. return false,nil,addr
  1270. end
  1271. end
  1272. --[[[
  1273. -- @function lua_redis.connect(redis_params, attrs)
  1274. -- Connects to Redis synchronously with coroutines or asynchronously using a callback (modern API)
  1275. -- @param redis_params a table of redis server parameters
  1276. -- @param attrs a table of redis request attributes (e.g. task, or ev_base + cfg + session)
  1277. -- @return {result,connection,address} boolean result, connection object, redis server address
  1278. --]]
  1279. exports.connect = function(redis_params, attrs)
  1280. local lua_util = require "lua_util"
  1281. if not attrs or not redis_params then
  1282. logger.errx('invalid arguments for redis connect')
  1283. return false,nil,nil
  1284. end
  1285. if not (attrs.task or (attrs.config and attrs.ev_base)) then
  1286. logger.errx('invalid attributes for redis connect')
  1287. return false,nil,nil
  1288. end
  1289. local opts = lua_util.shallowcopy(attrs)
  1290. local log_obj = opts.task or opts.config
  1291. local addr
  1292. if opts.callback then
  1293. -- Wrap callback
  1294. local callback = opts.callback
  1295. local function rspamd_redis_make_request_cb(err, data)
  1296. if err then
  1297. addr:fail()
  1298. else
  1299. addr:ok()
  1300. end
  1301. callback(err, data, addr)
  1302. end
  1303. opts.callback = rspamd_redis_make_request_cb
  1304. end
  1305. local rspamd_redis = require "rspamd_redis"
  1306. local is_write = opts.is_write
  1307. if opts.key then
  1308. if is_write then
  1309. addr = redis_params['write_servers']:get_upstream_by_hash(attrs.key)
  1310. else
  1311. addr = redis_params['read_servers']:get_upstream_by_hash(attrs.key)
  1312. end
  1313. else
  1314. if is_write then
  1315. addr = redis_params['write_servers']:get_upstream_master_slave(attrs.key)
  1316. else
  1317. addr = redis_params['read_servers']:get_upstream_round_robin(attrs.key)
  1318. end
  1319. end
  1320. if not addr then
  1321. logger.errx(log_obj, 'cannot select server to make redis connect')
  1322. end
  1323. opts.host = addr:get_addr()
  1324. opts.timeout = redis_params.timeout
  1325. if redis_params.password then
  1326. opts.password = redis_params.password
  1327. end
  1328. if redis_params.db then
  1329. opts.dbname = redis_params.db
  1330. end
  1331. if opts.callback then
  1332. local ret,conn = rspamd_redis.connect(opts)
  1333. if not ret then
  1334. logger.errx(log_obj, 'cannot execute redis connect')
  1335. addr:fail()
  1336. end
  1337. return ret,conn,addr
  1338. else
  1339. -- Coroutines version
  1340. local ret,conn = rspamd_redis.connect_sync(opts)
  1341. if not ret then
  1342. logger.errx(log_obj, 'cannot execute redis connect')
  1343. addr:fail()
  1344. else
  1345. return true,conn,addr
  1346. end
  1347. return false,nil,addr
  1348. end
  1349. end
  1350. return exports