You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

lua_redis.lua 47KB


  1. --[[
  2. Copyright (c) 2017, Vsevolod Stakhov <vsevolod@highsecure.ru>
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ]]--
  13. local logger = require "rspamd_logger"
  14. local lutil = require "lua_util"
  15. local rspamd_util = require "rspamd_util"
  16. local ts = require("tableshape").types
  17. local exports = {}
  18. local E = {}
  19. local N = "lua_redis"
  20. local common_schema = ts.shape {
  21. timeout = (ts.number + ts.string / lutil.parse_time_interval):is_optional(),
  22. db = ts.string:is_optional(),
  23. database = ts.string:is_optional(),
  24. dbname = ts.string:is_optional(),
  25. prefix = ts.string:is_optional(),
  26. password = ts.string:is_optional(),
  27. expand_keys = ts.boolean:is_optional(),
  28. sentinels = (ts.string + ts.array_of(ts.string)):is_optional(),
  29. sentinel_watch_time = (ts.number + ts.string / lutil.parse_time_interval):is_optional(),
  30. sentinel_masters_pattern = ts.string:is_optional(),
  31. sentinel_master_maxerrors = (ts.number + ts.string / tonumber):is_optional(),
  32. }
  33. local config_schema =
  34. ts.shape({
  35. read_servers = ts.string + ts.array_of(ts.string),
  36. write_servers = ts.string + ts.array_of(ts.string),
  37. }, {extra_opts = common_schema}) +
  38. ts.shape({
  39. servers = ts.string + ts.array_of(ts.string),
  40. }, {extra_opts = common_schema}) +
  41. ts.shape({
  42. server = ts.string + ts.array_of(ts.string),
  43. }, {extra_opts = common_schema})
  44. exports.config_schema = config_schema
-- Queries one Redis Sentinel, picked round-robin from `params.sentinels`.
-- It asks for the masters first and then, for every accepted master, for its
-- slaves. Once all slave sub-requests have completed, `process_masters`
-- rebuilds the read/write server strings. If a string differs from the one
-- stored in `params`, the matching `params.read_servers` /
-- `params.write_servers` upstream list is replaced.
-- All requests go through callbacks: this function only issues them.
--
-- ev_base - event base passed to the sentinel requests
-- params - redis params table. `params.sentinels` must already be an
--   upstream list (see add_redis_sentinels). `timeout`,
--   `sentinel_masters_pattern`, `sentinel_master_maxerrors` and the
--   `*_servers_str` fields are read from it; the servers are written back.
-- initialised - not read by this function
local function redis_query_sentinel(ev_base, params, initialised)
  -- Converts a flat reply {k1, v1, k2, v2, ...} into a {k1 = v1, ...} map
  local function flatten_redis_table(tbl)
    local res = {}
    for i=1,#tbl,2 do
      res[tbl[i]] = tbl[i + 1]
    end
    return res
  end
  -- NOTE(review): the original comment here said "Coroutines syntax", but the
  -- requests below use explicit callbacks
  local rspamd_redis = require "rspamd_redis"
  local sentinels = params.sentinels
  local addr = sentinels:get_upstream_round_robin()
  local host = addr:get_addr()
  -- Accepted masters keyed by master name; filled by masters_cb
  local masters = {}
  -- Function that is called to process masters data. It is assigned only
  -- after the masters request has been issued, so this relies on the
  -- callbacks not running before this function returns.
  local process_masters
  -- Handles the `SENTINEL masters` reply
  local function masters_cb(err, result)
    if not err and result and type(result) == 'table' then
      -- Number of `SENTINEL slaves` requests still waiting for a reply.
      -- If none is dispatched (no accepted masters, or every request failed
      -- to start), process_masters is never called.
      local pending_subrequests = 0
      for _,m in ipairs(result) do
        local master = flatten_redis_table(m)
        -- Wrap IPv6-addresses in brackets
        if (master.ip:match(":")) then
          master.ip = "["..master.ip.."]"
        end
        -- When a pattern is configured, keep only masters whose name matches it
        if params.sentinel_masters_pattern then
          if master.name:match(params.sentinel_masters_pattern) then
            lutil.debugm(N, 'found master %s with ip %s and port %s',
                master.name, master.ip, master.port)
            masters[master.name] = master
          else
            lutil.debugm(N, 'skip master %s with ip %s and port %s, pattern %s',
                master.name, master.ip, master.port, params.sentinel_masters_pattern)
          end
        else
          lutil.debugm(N, 'found master %s with ip %s and port %s',
              master.name, master.ip, master.port)
          masters[master.name] = master
        end
      end
      -- For each master we need to get a list of slaves
      for k,v in pairs(masters) do
        v.slaves = {}
        -- Handles the `SENTINEL slaves <master>` reply for master `v`
        local function slaves_cb(slave_err, slave_result)
          if not slave_err and type(slave_result) == 'table' then
            for _,s in ipairs(slave_result) do
              local slave = flatten_redis_table(s)
              lutil.debugm(N, rspamd_config,
                  'found slave for master %s with ip %s and port %s',
                  v.name, slave.ip, slave.port)
              -- Wrap IPv6-addresses in brackets
              if (slave.ip:match(":")) then
                slave.ip = "["..slave.ip.."]"
              end
              v.slaves[#v.slaves + 1] = slave
            end
          else
            logger.errx('cannot get slaves data from Redis Sentinel %s: %s',
                host:to_string(true), slave_err)
            addr:fail()
          end
          -- Failed replies also count as completed, so the masters are still
          -- processed (this master then has no slaves)
          pending_subrequests = pending_subrequests - 1
          if pending_subrequests == 0 then
            -- Finalize masters and slaves
            process_masters()
          end
        end
        -- Slaves are requested from the same sentinel that listed the masters
        local ret = rspamd_redis.make_request({
          host = addr:get_addr(),
          timeout = params.timeout,
          config = rspamd_config,
          ev_base = ev_base,
          cmd = 'SENTINEL',
          args = {'slaves', k},
          no_pool = true,
          callback = slaves_cb
        })
        if not ret then
          logger.errx(rspamd_config, 'cannot connect sentinel when query slaves at address: %s',
              host:to_string(true))
          addr:fail()
        else
          pending_subrequests = pending_subrequests + 1
        end
      end
      addr:ok()
    else
      logger.errx('cannot get masters data from Redis Sentinel %s: %s',
          host:to_string(true), err)
      addr:fail()
    end
  end
  local ret = rspamd_redis.make_request({
    host = addr:get_addr(),
    timeout = params.timeout,
    config = rspamd_config,
    ev_base = ev_base,
    cmd = 'SENTINEL',
    args = {'masters'},
    no_pool = true,
    callback = masters_cb,
  })
  if not ret then
    logger.errx(rspamd_config, 'cannot connect sentinel at address: %s',
        host:to_string(true))
    addr:fail()
  end
  -- Rebuilds server lists from `masters`. Masters go to both the read and
  -- write lists; slaves go to the read list only, and only when
  -- 'master-link-status' is 'ok'.
  process_masters = function()
    -- We now form new strings for masters and slaves
    local read_servers_tbl, write_servers_tbl = {}, {}
    for _,master in pairs(masters) do
      write_servers_tbl[#write_servers_tbl + 1] = string.format(
          '%s:%s', master.ip, master.port
      )
      read_servers_tbl[#read_servers_tbl + 1] = string.format(
          '%s:%s', master.ip, master.port
      )
      for _,slave in ipairs(master.slaves) do
        if slave['master-link-status'] == 'ok' then
          read_servers_tbl[#read_servers_tbl + 1] = string.format(
              '%s:%s', slave.ip, slave.port
          )
        end
      end
    end
    -- Sorting makes the comparison with the previous lists independent of
    -- the order produced by pairs()
    table.sort(read_servers_tbl)
    table.sort(write_servers_tbl)
    local read_servers_str = table.concat(read_servers_tbl, ',')
    local write_servers_str = table.concat(write_servers_tbl, ',')
    lutil.debugm(N, rspamd_config,
        'new servers list: %s read; %s write',
        read_servers_str,
        write_servers_str)
    -- Replace the read upstreams only when the list actually changed and the
    -- new list could be parsed
    if read_servers_str ~= params.read_servers_str then
      local upstream_list = require "rspamd_upstream_list"
      local read_upstreams = upstream_list.create(rspamd_config,
          read_servers_str, 6379)
      if read_upstreams then
        logger.infox(rspamd_config, 'sentinel %s: replace read servers with new list: %s',
            host:to_string(true), read_servers_str)
        params.read_servers = read_upstreams
        params.read_servers_str = read_servers_str
      end
    end
    if write_servers_str ~= params.write_servers_str then
      local upstream_list = require "rspamd_upstream_list"
      local write_upstreams = upstream_list.create(rspamd_config,
          write_servers_str, 6379)
      if write_upstreams then
        logger.infox(rspamd_config, 'sentinel %s: replace write servers with new list: %s',
            host:to_string(true), write_servers_str)
        params.write_servers = write_upstreams
        params.write_servers_str = write_servers_str
        -- Watch the new write upstreams: once the failure count exceeds
        -- sentinel_master_maxerrors, query the sentinels again (at most once
        -- per watcher)
        local queried = false
        -- NOTE(review): the log line reports `host` (the sentinel's address),
        -- not the failing master's; the first parameter `up` is unused.
        -- Confirm the add_watcher callback argument order before changing it.
        local function monitor_failures(up, _, count)
          if count > params.sentinel_master_maxerrors and not queried then
            logger.infox(rspamd_config, 'sentinel: master with address %s, caused %s failures, try to query sentinel',
                host:to_string(true), count)
            queried = true -- Avoid multiple checks caused by this monitor
            redis_query_sentinel(ev_base, params, true)
          end
        end
        write_upstreams:add_watcher('failure', monitor_failures)
      end
    end
  end
end
  211. local function add_redis_sentinels(params)
  212. local upstream_list = require "rspamd_upstream_list"
  213. local upstreams_sentinels = upstream_list.create(rspamd_config,
  214. params.sentinels, 5000)
  215. if not upstreams_sentinels then
  216. logger.errx(rspamd_config, 'cannot load redis sentinels string: %s',
  217. params.sentinels)
  218. return
  219. end
  220. params.sentinels = upstreams_sentinels
  221. if not params.sentinel_watch_time then
  222. params.sentinel_watch_time = 60 -- Each minute
  223. end
  224. if not params.sentinel_master_maxerrors then
  225. params.sentinel_master_maxerrors = 2 -- Maximum number of errors before rechecking
  226. end
  227. rspamd_config:add_on_load(function(_, ev_base, worker)
  228. local initialised = false
  229. if worker:is_scanner() or worker:get_type() == 'fuzzy' then
  230. rspamd_config:add_periodic(ev_base, 0.0, function()
  231. redis_query_sentinel(ev_base, params, initialised)
  232. initialised = true
  233. return params.sentinel_watch_time
  234. end, false)
  235. end
  236. end)
  237. end
  238. local cached_results = {}
  239. local function calculate_redis_hash(params)
  240. local cr = require "rspamd_cryptobox_hash"
  241. local h = cr.create()
  242. local function rec_hash(k, v)
  243. if type(v) == 'string' then
  244. h:update(k)
  245. h:update(v)
  246. elseif type(v) == 'number' then
  247. h:update(k)
  248. h:update(tostring(v))
  249. elseif type(v) == 'table' then
  250. for kk,vv in pairs(v) do
  251. rec_hash(kk, vv)
  252. end
  253. end
  254. end
  255. rec_hash('top', params)
  256. return h:base32()
  257. end
  258. local function process_redis_opts(options, redis_params)
  259. local default_timeout = 1.0
  260. local default_expand_keys = false
  261. if not redis_params['timeout'] or redis_params['timeout'] == default_timeout then
  262. if options['timeout'] then
  263. redis_params['timeout'] = tonumber(options['timeout'])
  264. else
  265. redis_params['timeout'] = default_timeout
  266. end
  267. end
  268. if options['prefix'] and not redis_params['prefix'] then
  269. redis_params['prefix'] = options['prefix']
  270. end
  271. if type(options['expand_keys']) == 'boolean' then
  272. redis_params['expand_keys'] = options['expand_keys']
  273. else
  274. redis_params['expand_keys'] = default_expand_keys
  275. end
  276. if not redis_params['db'] then
  277. if options['db'] then
  278. redis_params['db'] = tostring(options['db'])
  279. elseif options['dbname'] then
  280. redis_params['db'] = tostring(options['dbname'])
  281. elseif options['database'] then
  282. redis_params['db'] = tostring(options['database'])
  283. end
  284. end
  285. if options['password'] and not redis_params['password'] then
  286. redis_params['password'] = options['password']
  287. end
  288. if not redis_params.sentinels and options.sentinels then
  289. redis_params.sentinels = options.sentinels
  290. end
  291. if options['sentinel_masters_pattern'] and not redis_params['sentinel_masters_pattern'] then
  292. redis_params['sentinel_masters_pattern'] = options['sentinel_masters_pattern']
  293. end
  294. end
  295. local function enrich_defaults(rspamd_config, module, redis_params)
  296. if rspamd_config then
  297. local opts = rspamd_config:get_all_opt('redis')
  298. if opts then
  299. if module then
  300. if opts[module] then
  301. process_redis_opts(opts[module], redis_params)
  302. end
  303. end
  304. process_redis_opts(opts, redis_params)
  305. end
  306. end
  307. end
  308. local function maybe_return_cached(redis_params)
  309. local h = calculate_redis_hash(redis_params)
  310. if cached_results[h] then
  311. lutil.debugm(N, 'reused redis server: %s', redis_params)
  312. return cached_results[h]
  313. end
  314. redis_params.hash = h
  315. cached_results[h] = redis_params
  316. if not redis_params.read_only and redis_params.sentinels then
  317. add_redis_sentinels(redis_params)
  318. end
  319. lutil.debugm(N, 'loaded new redis server: %s', redis_params)
  320. return redis_params
  321. end
  322. --[[[
  323. -- @module lua_redis
  324. -- This module contains helper functions for working with Redis
  325. --]]
  326. local function process_redis_options(options, rspamd_config, result)
  327. local default_port = 6379
  328. local upstream_list = require "rspamd_upstream_list"
  329. local read_only = true
  330. -- Try to get read servers:
  331. local upstreams_read, upstreams_write
  332. if options['read_servers'] then
  333. if rspamd_config then
  334. upstreams_read = upstream_list.create(rspamd_config,
  335. options['read_servers'], default_port)
  336. else
  337. upstreams_read = upstream_list.create(options['read_servers'],
  338. default_port)
  339. end
  340. result.read_servers_str = options['read_servers']
  341. elseif options['servers'] then
  342. if rspamd_config then
  343. upstreams_read = upstream_list.create(rspamd_config,
  344. options['servers'], default_port)
  345. else
  346. upstreams_read = upstream_list.create(options['servers'], default_port)
  347. end
  348. result.read_servers_str = options['servers']
  349. read_only = false
  350. elseif options['server'] then
  351. if rspamd_config then
  352. upstreams_read = upstream_list.create(rspamd_config,
  353. options['server'], default_port)
  354. else
  355. upstreams_read = upstream_list.create(options['server'], default_port)
  356. end
  357. result.read_servers_str = options['server']
  358. read_only = false
  359. end
  360. if upstreams_read then
  361. if options['write_servers'] then
  362. if rspamd_config then
  363. upstreams_write = upstream_list.create(rspamd_config,
  364. options['write_servers'], default_port)
  365. else
  366. upstreams_write = upstream_list.create(options['write_servers'],
  367. default_port)
  368. end
  369. result.write_servers_str = options['write_servers']
  370. read_only = false
  371. elseif not read_only then
  372. upstreams_write = upstreams_read
  373. result.write_servers_str = result.read_servers_str
  374. end
  375. end
  376. -- Store options
  377. process_redis_opts(options, result)
  378. if read_only and not upstreams_write then
  379. result.read_only = true
  380. elseif upstreams_write then
  381. result.read_only = false
  382. end
  383. if upstreams_read then
  384. result.read_servers = upstreams_read
  385. if upstreams_write then
  386. result.write_servers = upstreams_write
  387. end
  388. return true
  389. end
  390. lutil.debugm(N, rspamd_config,
  391. 'cannot load redis server from obj: %s, processed to %s',
  392. options, result)
  393. return false
  394. end
  395. --[[[
  396. @function try_load_redis_servers(options, rspamd_config, no_fallback)
  397. Tries to load redis servers from the specified `options` object.
  398. Returns `redis_params` table or nil in case of failure
  399. --]]
  400. exports.try_load_redis_servers = function(options, rspamd_config, no_fallback, module_name)
  401. local result = {}
  402. if process_redis_options(options, rspamd_config, result) then
  403. if not no_fallback then
  404. enrich_defaults(rspamd_config, module_name, result)
  405. end
  406. return maybe_return_cached(result)
  407. end
  408. end
  409. -- This function parses redis server definition using either
  410. -- specific server string for this module or global
  411. -- redis section
  412. local function rspamd_parse_redis_server(module_name, module_opts, no_fallback)
  413. local result = {}
  414. -- Try local options
  415. local opts
  416. lutil.debugm(N, rspamd_config, 'try load redis config for: %s', module_name)
  417. if not module_opts then
  418. opts = rspamd_config:get_all_opt(module_name)
  419. else
  420. opts = module_opts
  421. end
  422. if opts then
  423. local ret
  424. if opts.redis then
  425. ret = process_redis_options(opts.redis, rspamd_config, result)
  426. if ret then
  427. if not no_fallback then
  428. enrich_defaults(rspamd_config, module_name, result)
  429. end
  430. return maybe_return_cached(result)
  431. end
  432. end
  433. ret = process_redis_options(opts, rspamd_config, result)
  434. if ret then
  435. if not no_fallback then
  436. enrich_defaults(rspamd_config, module_name, result)
  437. end
  438. return maybe_return_cached(result)
  439. end
  440. end
  441. if no_fallback then
  442. logger.infox(rspamd_config, "cannot find Redis definitions for %s and fallback is disabled",
  443. module_name)
  444. return nil
  445. end
  446. -- Try global options
  447. opts = rspamd_config:get_all_opt('redis')
  448. if opts then
  449. local ret
  450. if opts[module_name] then
  451. ret = process_redis_options(opts[module_name], rspamd_config, result)
  452. if ret then
  453. return maybe_return_cached(result)
  454. end
  455. else
  456. ret = process_redis_options(opts, rspamd_config, result)
  457. -- Exclude disabled
  458. if opts['disabled_modules'] then
  459. for _,v in ipairs(opts['disabled_modules']) do
  460. if v == module_name then
  461. logger.infox(rspamd_config, "NOT using default redis server for module %s: it is disabled",
  462. module_name)
  463. return nil
  464. end
  465. end
  466. end
  467. if ret then
  468. logger.infox(rspamd_config, "use default Redis settings for %s",
  469. module_name)
  470. return maybe_return_cached(result)
  471. end
  472. end
  473. end
  474. if result.read_servers then
  475. return maybe_return_cached(result)
  476. end
  477. return nil
  478. end
  479. --[[[
  480. -- @function lua_redis.parse_redis_server(module_name, module_opts, no_fallback)
  481. -- Extracts Redis server settings from configuration
  482. -- @param {string} module_name name of module to get settings for
  483. -- @param {table} module_opts settings for module or `nil` to fetch them from configuration
  484. -- @param {boolean} no_fallback should be `true` if global settings must not be used
  485. -- @return {table} redis server settings
  486. -- @example
  487. -- local rconfig = lua_redis.parse_redis_server('my_module')
  488. -- -- rconfig contains upstream_list objects in ['write_servers'] and ['read_servers']
  489. -- -- ['timeout'] contains timeout in seconds
  490. -- -- ['expand_keys'] if true tells that redis key expansion is enabled
  491. --]]
  492. exports.rspamd_parse_redis_server = rspamd_parse_redis_server
  493. exports.parse_redis_server = rspamd_parse_redis_server
  494. local process_cmd = {
  495. bitop = function(args)
  496. local idx_l = {}
  497. for i = 2, #args do
  498. table.insert(idx_l, i)
  499. end
  500. return idx_l
  501. end,
  502. blpop = function(args)
  503. local idx_l = {}
  504. for i = 1, #args -1 do
  505. table.insert(idx_l, i)
  506. end
  507. return idx_l
  508. end,
  509. eval = function(args)
  510. local idx_l = {}
  511. local numkeys = args[2]
  512. if numkeys and tonumber(numkeys) >= 1 then
  513. for i = 3, numkeys + 2 do
  514. table.insert(idx_l, i)
  515. end
  516. end
  517. return idx_l
  518. end,
  519. set = function(args)
  520. return {1}
  521. end,
  522. mget = function(args)
  523. local idx_l = {}
  524. for i = 1, #args do
  525. table.insert(idx_l, i)
  526. end
  527. return idx_l
  528. end,
  529. mset = function(args)
  530. local idx_l = {}
  531. for i = 1, #args, 2 do
  532. table.insert(idx_l, i)
  533. end
  534. return idx_l
  535. end,
  536. sdiffstore = function(args)
  537. local idx_l = {}
  538. for i = 2, #args do
  539. table.insert(idx_l, i)
  540. end
  541. return idx_l
  542. end,
  543. smove = function(args)
  544. return {1, 2}
  545. end,
  546. script = function() end
  547. }
  548. process_cmd.append = process_cmd.set
  549. process_cmd.auth = process_cmd.script
  550. process_cmd.bgrewriteaof = process_cmd.script
  551. process_cmd.bgsave = process_cmd.script
  552. process_cmd.bitcount = process_cmd.set
  553. process_cmd.bitfield = process_cmd.set
  554. process_cmd.bitpos = process_cmd.set
  555. process_cmd.brpop = process_cmd.blpop
  556. process_cmd.brpoplpush = process_cmd.blpop
  557. process_cmd.client = process_cmd.script
  558. process_cmd.cluster = process_cmd.script
  559. process_cmd.command = process_cmd.script
  560. process_cmd.config = process_cmd.script
  561. process_cmd.dbsize = process_cmd.script
  562. process_cmd.debug = process_cmd.script
  563. process_cmd.decr = process_cmd.set
  564. process_cmd.decrby = process_cmd.set
  565. process_cmd.del = process_cmd.mget
  566. process_cmd.discard = process_cmd.script
  567. process_cmd.dump = process_cmd.set
  568. process_cmd.echo = process_cmd.script
  569. process_cmd.evalsha = process_cmd.eval
  570. process_cmd.exec = process_cmd.script
  571. process_cmd.exists = process_cmd.mget
  572. process_cmd.expire = process_cmd.set
  573. process_cmd.expireat = process_cmd.set
  574. process_cmd.flushall = process_cmd.script
  575. process_cmd.flushdb = process_cmd.script
  576. process_cmd.geoadd = process_cmd.set
  577. process_cmd.geohash = process_cmd.set
  578. process_cmd.geopos = process_cmd.set
  579. process_cmd.geodist = process_cmd.set
  580. process_cmd.georadius = process_cmd.set
  581. process_cmd.georadiusbymember = process_cmd.set
  582. process_cmd.get = process_cmd.set
  583. process_cmd.getbit = process_cmd.set
  584. process_cmd.getrange = process_cmd.set
  585. process_cmd.getset = process_cmd.set
  586. process_cmd.hdel = process_cmd.set
  587. process_cmd.hexists = process_cmd.set
  588. process_cmd.hget = process_cmd.set
  589. process_cmd.hgetall = process_cmd.set
  590. process_cmd.hincrby = process_cmd.set
  591. process_cmd.hincrbyfloat = process_cmd.set
  592. process_cmd.hkeys = process_cmd.set
  593. process_cmd.hlen = process_cmd.set
  594. process_cmd.hmget = process_cmd.set
  595. process_cmd.hmset = process_cmd.set
  596. process_cmd.hscan = process_cmd.set
  597. process_cmd.hset = process_cmd.set
  598. process_cmd.hsetnx = process_cmd.set
  599. process_cmd.hstrlen = process_cmd.set
  600. process_cmd.hvals = process_cmd.set
  601. process_cmd.incr = process_cmd.set
  602. process_cmd.incrby = process_cmd.set
  603. process_cmd.incrbyfloat = process_cmd.set
  604. process_cmd.info = process_cmd.script
  605. process_cmd.keys = process_cmd.script
  606. process_cmd.lastsave = process_cmd.script
  607. process_cmd.lindex = process_cmd.set
  608. process_cmd.linsert = process_cmd.set
  609. process_cmd.llen = process_cmd.set
  610. process_cmd.lpop = process_cmd.set
  611. process_cmd.lpush = process_cmd.set
  612. process_cmd.lpushx = process_cmd.set
  613. process_cmd.lrange = process_cmd.set
  614. process_cmd.lrem = process_cmd.set
  615. process_cmd.lset = process_cmd.set
  616. process_cmd.ltrim = process_cmd.set
  617. process_cmd.migrate = process_cmd.script
  618. process_cmd.monitor = process_cmd.script
  619. process_cmd.move = process_cmd.set
  620. process_cmd.msetnx = process_cmd.mset
  621. process_cmd.multi = process_cmd.script
  622. process_cmd.object = process_cmd.script
  623. process_cmd.persist = process_cmd.set
  624. process_cmd.pexpire = process_cmd.set
  625. process_cmd.pexpireat = process_cmd.set
  626. process_cmd.pfadd = process_cmd.set
  627. process_cmd.pfcount = process_cmd.set
  628. process_cmd.pfmerge = process_cmd.mget
  629. process_cmd.ping = process_cmd.script
  630. process_cmd.psetex = process_cmd.set
  631. process_cmd.psubscribe = process_cmd.script
  632. process_cmd.pubsub = process_cmd.script
  633. process_cmd.pttl = process_cmd.set
  634. process_cmd.publish = process_cmd.script
  635. process_cmd.punsubscribe = process_cmd.script
  636. process_cmd.quit = process_cmd.script
  637. process_cmd.randomkey = process_cmd.script
  638. process_cmd.readonly = process_cmd.script
  639. process_cmd.readwrite = process_cmd.script
  640. process_cmd.rename = process_cmd.mget
  641. process_cmd.renamenx = process_cmd.mget
  642. process_cmd.restore = process_cmd.set
  643. process_cmd.role = process_cmd.script
  644. process_cmd.rpop = process_cmd.set
  645. process_cmd.rpoplpush = process_cmd.mget
  646. process_cmd.rpush = process_cmd.set
  647. process_cmd.rpushx = process_cmd.set
  648. process_cmd.sadd = process_cmd.set
  649. process_cmd.save = process_cmd.script
  650. process_cmd.scard = process_cmd.set
  651. process_cmd.sdiff = process_cmd.mget
  652. process_cmd.select = process_cmd.script
  653. process_cmd.setbit = process_cmd.set
  654. process_cmd.setex = process_cmd.set
  655. process_cmd.setnx = process_cmd.set
  656. process_cmd.sinterstore = process_cmd.sdiff
  657. process_cmd.sismember = process_cmd.set
  658. process_cmd.slaveof = process_cmd.script
  659. process_cmd.slowlog = process_cmd.script
  660. process_cmd.smembers = process_cmd.script
  661. process_cmd.sort = process_cmd.set
  662. process_cmd.spop = process_cmd.set
  663. process_cmd.srandmember = process_cmd.set
  664. process_cmd.srem = process_cmd.set
  665. process_cmd.strlen = process_cmd.set
  666. process_cmd.subscribe = process_cmd.script
  667. process_cmd.sunion = process_cmd.mget
  668. process_cmd.sunionstore = process_cmd.mget
  669. process_cmd.swapdb = process_cmd.script
  670. process_cmd.sync = process_cmd.script
  671. process_cmd.time = process_cmd.script
  672. process_cmd.touch = process_cmd.mget
  673. process_cmd.ttl = process_cmd.set
  674. process_cmd.type = process_cmd.set
  675. process_cmd.unsubscribe = process_cmd.script
  676. process_cmd.unlink = process_cmd.mget
  677. process_cmd.unwatch = process_cmd.script
  678. process_cmd.wait = process_cmd.script
  679. process_cmd.watch = process_cmd.mget
  680. process_cmd.zadd = process_cmd.set
  681. process_cmd.zcard = process_cmd.set
  682. process_cmd.zcount = process_cmd.set
  683. process_cmd.zincrby = process_cmd.set
  684. process_cmd.zinterstore = process_cmd.eval
  685. process_cmd.zlexcount = process_cmd.set
  686. process_cmd.zrange = process_cmd.set
  687. process_cmd.zrangebylex = process_cmd.set
  688. process_cmd.zrank = process_cmd.set
  689. process_cmd.zrem = process_cmd.set
  690. process_cmd.zrembylex = process_cmd.set
  691. process_cmd.zrembyrank = process_cmd.set
  692. process_cmd.zrembyscore = process_cmd.set
  693. process_cmd.zrevrange = process_cmd.set
  694. process_cmd.zrevrangebyscore = process_cmd.set
  695. process_cmd.zrevrank = process_cmd.set
  696. process_cmd.zscore = process_cmd.set
  697. process_cmd.zunionstore = process_cmd.eval
  698. process_cmd.scan = process_cmd.script
  699. process_cmd.sscan = process_cmd.set
  700. process_cmd.hscan = process_cmd.set
  701. process_cmd.zscan = process_cmd.set
  702. local function get_key_indexes(cmd, args)
  703. local idx_l = {}
  704. cmd = string.lower(cmd)
  705. if process_cmd[cmd] then
  706. idx_l = process_cmd[cmd](args)
  707. else
  708. logger.warnx(rspamd_config, "Don't know how to extract keys for %s Redis command", cmd)
  709. end
  710. return idx_l
  711. end
  712. local gen_meta = {
  713. principal_recipient = function(task)
  714. return task:get_principal_recipient()
  715. end,
  716. principal_recipient_domain = function(task)
  717. local p = task:get_principal_recipient()
  718. if not p then return end
  719. return string.match(p, '.*@(.*)')
  720. end,
  721. ip = function(task)
  722. local i = task:get_ip()
  723. if i and i:is_valid() then return i:to_string() end
  724. end,
  725. from = function(task)
  726. return ((task:get_from('smtp') or E)[1] or E)['addr']
  727. end,
  728. from_domain = function(task)
  729. return ((task:get_from('smtp') or E)[1] or E)['domain']
  730. end,
  731. from_domain_or_helo_domain = function(task)
  732. local d = ((task:get_from('smtp') or E)[1] or E)['domain']
  733. if d and #d > 0 then return d end
  734. return task:get_helo()
  735. end,
  736. mime_from = function(task)
  737. return ((task:get_from('mime') or E)[1] or E)['addr']
  738. end,
  739. mime_from_domain = function(task)
  740. return ((task:get_from('mime') or E)[1] or E)['domain']
  741. end,
  742. }
  743. local function gen_get_esld(f)
  744. return function(task)
  745. local d = f(task)
  746. if not d then return end
  747. return rspamd_util.get_tld(d)
  748. end
  749. end
  750. gen_meta.smtp_from = gen_meta.from
  751. gen_meta.smtp_from_domain = gen_meta.from_domain
  752. gen_meta.smtp_from_domain_or_helo_domain = gen_meta.from_domain_or_helo_domain
  753. gen_meta.esld_principal_recipient_domain = gen_get_esld(gen_meta.principal_recipient_domain)
  754. gen_meta.esld_from_domain = gen_get_esld(gen_meta.from_domain)
  755. gen_meta.esld_smtp_from_domain = gen_meta.esld_from_domain
  756. gen_meta.esld_mime_from_domain = gen_get_esld(gen_meta.mime_from_domain)
  757. gen_meta.esld_from_domain_or_helo_domain = gen_get_esld(gen_meta.from_domain_or_helo_domain)
  758. gen_meta.esld_smtp_from_domain_or_helo_domain = gen_meta.esld_from_domain_or_helo_domain
  759. local function get_key_expansion_metadata(task)
  760. local md_mt = {
  761. __index = function(self, k)
  762. k = string.lower(k)
  763. local v = rawget(self, k)
  764. if v then
  765. return v
  766. end
  767. if gen_meta[k] then
  768. v = gen_meta[k](task)
  769. rawset(self, k, v)
  770. end
  771. return v
  772. end,
  773. }
  774. local lazy_meta = {}
  775. setmetatable(lazy_meta, md_mt)
  776. return lazy_meta
  777. end
  778. -- Performs async call to redis hiding all complexity inside function
  779. -- task - rspamd_task
  780. -- redis_params - valid params returned by rspamd_parse_redis_server
  781. -- key - key to select upstream or nil to select round-robin/master-slave
  782. -- is_write - true if need to write to redis server
  783. -- callback - function to be called upon request is completed
  784. -- command - redis command
  785. -- args - table of arguments
  786. -- extra_opts - table of optional request arguments
  -- Returns three values: the result of rspamd_redis.make_request (false on
  -- invalid arguments or when no upstream can be selected), the connection
  -- object, and the selected upstream (nil when none could be selected)
  local function rspamd_redis_make_request(task, redis_params, key, is_write,
      callback, command, args, extra_opts)
    if not task or not redis_params or not callback or not command then
      return false,nil,nil
    end

    local rspamd_redis = require "rspamd_redis"
    local addr

    -- Report the outcome to the upstream first, then to the caller
    local function rspamd_redis_make_request_cb(err, data)
      if err then
        addr:fail()
      else
        addr:ok()
      end
      callback(err, data, addr)
    end

    -- Keyed requests select an upstream by hash; keyless writes use
    -- master/slave selection and keyless reads use round-robin
    if key then
      if is_write then
        addr = redis_params['write_servers']:get_upstream_by_hash(key)
      else
        addr = redis_params['read_servers']:get_upstream_by_hash(key)
      end
    else
      if is_write then
        addr = redis_params['write_servers']:get_upstream_master_slave(key)
      else
        addr = redis_params['read_servers']:get_upstream_round_robin(key)
      end
    end

    if not addr then
      logger.errx(task, 'cannot select server to make redis request')
      -- Everything below needs an upstream: bail out instead of indexing nil
      return false,nil,nil
    end

    if redis_params['expand_keys'] then
      -- Expand the key arguments located by get_key_indexes() using the
      -- task expansion metadata; args is modified in place
      local m = get_key_expansion_metadata(task)
      local indexes = get_key_indexes(command, args)
      for _, i in ipairs(indexes) do
        args[i] = lutil.template(args[i], m)
      end
    end

    local ip_addr = addr:get_addr()
    local options = {
      task = task,
      callback = rspamd_redis_make_request_cb,
      host = ip_addr,
      timeout = redis_params['timeout'],
      cmd = command,
      args = args
    }

    -- Caller options override the defaults above; password and db from
    -- redis_params are applied afterwards and take precedence over them
    if extra_opts then
      for k,v in pairs(extra_opts) do
        options[k] = v
      end
    end

    if redis_params['password'] then
      options['password'] = redis_params['password']
    end

    if redis_params['db'] then
      options['dbname'] = redis_params['db']
    end

    lutil.debugm(N, task, 'perform request to redis server' ..
        ' (host=%s, timeout=%s): cmd: %s', ip_addr,
        options.timeout, options.cmd)

    local ret,conn = rspamd_redis.make_request(options)
    if not ret then
      addr:fail()
      logger.warnx(task, "cannot make redis request to: %s", tostring(ip_addr))
    end

    return ret,conn,addr
  end
  --[[[
  -- @function lua_redis.redis_make_request(task, redis_params, key, is_write, callback, command, args[, extra_opts])
  -- Sends a request to Redis
  -- @param {rspamd_task} task task object
  -- @param {table} redis_params redis configuration in format returned by lua_redis.parse_redis_server()
  -- @param {string} key key to use for sharding
  -- @param {boolean} is_write should be `true` if we are performing a write operation
  -- @param {function} callback callback function (first parameter is error if applicable, second is a 2D array (table), third is the selected upstream)
  -- @param {string} command Redis command to run
  -- @param {table} args Numerically indexed table containing arguments for command
  -- @param {table} extra_opts optional table of additional request options (password and db from redis_params still take precedence)
  -- @return {boolean,connection,upstream} request result, connection object and the selected upstream
  --]]
  exports.rspamd_redis_make_request = rspamd_redis_make_request
  exports.redis_make_request = rspamd_redis_make_request
  -- Task-less variant of rspamd_redis_make_request: the request is bound to
  -- `ev_base` and the config `cfg` (also used as the log object).
  -- Returns the request result, the connection object and the selected
  -- upstream (false,nil,nil on invalid arguments or when no upstream exists)
  local function redis_make_request_taskless(ev_base, cfg, redis_params, key,
      is_write, callback, command, args, extra_opts)
    if not ev_base or not redis_params or not callback or not command then
      return false,nil,nil
    end

    local rspamd_redis = require "rspamd_redis"
    local addr

    -- Report the outcome to the upstream first, then to the caller
    local function rspamd_redis_make_request_cb(err, data)
      if err then
        addr:fail()
      else
        addr:ok()
      end
      callback(err, data, addr)
    end

    -- Keyed requests select an upstream by hash; keyless writes use
    -- master/slave selection and keyless reads use round-robin
    if key then
      if is_write then
        addr = redis_params['write_servers']:get_upstream_by_hash(key)
      else
        addr = redis_params['read_servers']:get_upstream_by_hash(key)
      end
    else
      if is_write then
        addr = redis_params['write_servers']:get_upstream_master_slave(key)
      else
        addr = redis_params['read_servers']:get_upstream_round_robin(key)
      end
    end

    if not addr then
      logger.errx(cfg, 'cannot select server to make redis request')
      -- Everything below needs an upstream: bail out instead of indexing nil
      return false,nil,nil
    end

    local options = {
      ev_base = ev_base,
      config = cfg,
      callback = rspamd_redis_make_request_cb,
      host = addr:get_addr(),
      timeout = redis_params['timeout'],
      cmd = command,
      args = args
    }

    -- Caller options override the defaults above; password and db from
    -- redis_params are applied afterwards and take precedence over them
    if extra_opts then
      for k,v in pairs(extra_opts) do
        options[k] = v
      end
    end

    if redis_params['password'] then
      options['password'] = redis_params['password']
    end

    if redis_params['db'] then
      options['dbname'] = redis_params['db']
    end

    -- The host is handed to the logger as is (like the task-based variant)
    -- rather than through a method call that may not exist on the object
    lutil.debugm(N, cfg, 'perform taskless request to redis server' ..
        ' (host=%s, timeout=%s): cmd: %s', options.host,
        options.timeout, options.cmd)

    local ret,conn = rspamd_redis.make_request(options)
    if not ret then
      logger.errx(cfg, 'cannot execute redis request')
      addr:fail()
    end

    return ret,conn,addr
  end
  --[[[
  -- @function lua_redis.redis_make_request_taskless(ev_base, cfg, redis_params, key, is_write, callback, command, args[, extra_opts])
  -- Sends a request to Redis in context where `task` is not available for some specific use-cases
  -- Identical to redis_make_request() except that the first two parameters are an `event base` object and a config object (`cfg`)
  -- @return {boolean,connection,upstream} request result, connection object and the selected upstream
  --]]
  exports.rspamd_redis_make_request_taskless = redis_make_request_taskless
  exports.redis_make_request_taskless = redis_make_request_taskless
  -- Registry of scripts added through add_redis_script(); the id of a script
  -- is its index in this array
  local redis_scripts = {}
  -- Finishes a loading round for `script`: marks it as loaded when a sha has
  -- been obtained, then runs every pending waiter with the loaded flag.
  local function script_set_loaded(script)
    if script.sha then
      script.loaded = true
    end

    -- Detach the queue before running the waiters, so that waiters queued
    -- while draining are kept for the next round instead of running now
    local pending = script.waitq
    script.waitq = {}

    for i = 1, #pending do
      pending[i](script.loaded)
    end
  end
  -- Builds one SCRIPT LOAD option table per distinct upstream found in
  -- script.redis_params.read_servers and write_servers, and sets
  -- script.in_flight to the number of requests to be issued. Callers add the
  -- task/config, event base and callback before sending them.
  local function prepare_redis_call(script)
    local params = script.redis_params
    local servers = {}
    local seen = {}

    -- The upstream lists are arrays: append them instead of merging by key,
    -- which let write servers overwrite read servers sitting at the same
    -- index (the script was then never loaded on those read servers).
    -- Upstreams present in both lists are loaded only once.
    local function add_upstreams(upstream_list)
      if not upstream_list then
        return
      end
      for _, s in ipairs(upstream_list:all_upstreams()) do
        local ip = s:get_addr()
        local id = ip and ip:to_string(true)
        if not id then
          servers[#servers + 1] = s
        elseif not seen[id] then
          seen[id] = true
          servers[#servers + 1] = s
        end
      end
    end

    add_upstreams(params.read_servers)
    add_upstreams(params.write_servers)

    local options = {}
    for _, s in ipairs(servers) do
      local cur_opts = {
        host = s:get_addr(),
        timeout = params['timeout'],
        cmd = 'SCRIPT',
        args = {'LOAD', script.script },
        upstream = s
      }

      if params['password'] then
        cur_opts['password'] = params['password']
      end

      if params['db'] then
        cur_opts['dbname'] = params['db']
      end

      options[#options + 1] = cur_opts
    end

    -- Decremented by the loaders as each request completes
    script.in_flight = #options
    return options
  end
  -- Uploads the script (SCRIPT LOAD) to every upstream prepared by
  -- prepare_redis_call(), within the context of `task`. Each finished or
  -- failed-to-start request decrements script.in_flight; when it reaches
  -- zero, script_set_loaded() runs the queued waiters.
  local function load_script_task(script, task)
    local rspamd_redis = require "rspamd_redis"
    local opts = prepare_redis_call(script)

    if #opts == 0 then
      -- No upstreams: nothing would ever call back, release the waiters now
      script_set_loaded(script)
      return
    end

    for _,opt in ipairs(opts) do
      opt.task = task
      opt.callback = function(err, data)
        if err then
          logger.errx(task, 'cannot upload script to %s: %s; registered from: %s:%s',
              opt.upstream:get_addr():to_string(true),
              err, script.caller.short_src, script.caller.currentline)
          opt.upstream:fail()
          script.fatal_error = err
        else
          opt.upstream:ok()
          logger.infox(task,
              "uploaded redis script to %s with id %s, sha: %s",
              opt.upstream:get_addr():to_string(true),
              script.id, data)
          script.sha = data -- We assume that sha is the same on all servers
          -- A successful upload clears an error left by an earlier attempt
          -- (load_script_taskless does the same); exec_redis_script refuses
          -- to run a script while fatal_error is set
          script.fatal_error = nil
        end

        script.in_flight = script.in_flight - 1
        if script.in_flight == 0 then
          script_set_loaded(script)
        end
      end

      local ret = rspamd_redis.make_request(opt)
      if not ret then
        logger.errx(task, 'cannot execute redis request to load script on %s',
            opt.upstream:get_addr())
        script.in_flight = script.in_flight - 1
        opt.upstream:fail()
      end

      if script.in_flight == 0 then
        script_set_loaded(script)
      end
    end
  end
  -- Uploads the script (SCRIPT LOAD) to every upstream prepared by
  -- prepare_redis_call() outside of a task: requests are bound to the config
  -- `cfg` (also the log object) and the event base `ev_base`. Each finished
  -- or failed-to-start request decrements script.in_flight; when it reaches
  -- zero, script_set_loaded() runs the queued waiters.
  local function load_script_taskless(script, cfg, ev_base)
    local rspamd_redis = require "rspamd_redis"
    local opts = prepare_redis_call(script)

    if #opts == 0 then
      -- No upstreams: nothing would ever call back, release the waiters now
      script_set_loaded(script)
      return
    end

    for _,opt in ipairs(opts) do
      opt.config = cfg
      opt.ev_base = ev_base
      opt.callback = function(err, data)
        if err then
          logger.errx(cfg, 'cannot upload script to %s: %s; registered from: %s:%s',
              opt.upstream:get_addr():to_string(true),
              err, script.caller.short_src, script.caller.currentline)
          opt.upstream:fail()
          script.fatal_error = err
        else
          opt.upstream:ok()
          logger.infox(cfg,
              "uploaded redis script to %s with id %s, sha: %s",
              opt.upstream:get_addr():to_string(true), script.id, data)
          script.sha = data -- We assume that sha is the same on all servers
          script.fatal_error = nil
        end

        script.in_flight = script.in_flight - 1
        if script.in_flight == 0 then
          script_set_loaded(script)
        end
      end

      local ret = rspamd_redis.make_request(opt)
      if not ret then
        logger.errx(cfg, 'cannot execute redis request to load script on %s',
            opt.upstream:get_addr())
        script.in_flight = script.in_flight - 1
        opt.upstream:fail()
      end

      if script.in_flight == 0 then
        script_set_loaded(script)
      end
    end
  end
  -- Starts a task-less load of `script`; scripts registered without redis
  -- parameters are skipped. The fourth argument (worker) is ignored.
  local function load_redis_script(script, cfg, ev_base, _)
    if not script.redis_params then
      return
    end

    load_script_taskless(script, cfg, ev_base)
  end
  -- Registers the Lua source `script` to be uploaded to the servers
  -- described by `redis_params`. Loading starts from an on_load hook and is
  -- retried periodically until a sha is obtained. Returns the script id to
  -- pass to exec_redis_script().
  local function add_redis_script(script, redis_params)
    local new_script = {
      caller = debug.getinfo(2), -- registration site, reported in upload errors
      loaded = false,
      redis_params = redis_params,
      script = script,
      waitq = {}, -- callbacks pending for script being loaded
      id = #redis_scripts + 1,
    }

    local function on_load(cfg, ev_base, worker)
      local attempts = 0

      -- Returns the delay before the next attempt, or false to stop retrying
      local function try_load()
        if new_script.sha then
          return false
        end
        load_redis_script(new_script, cfg, ev_base, worker)
        attempts = attempts + 1
        return 1.0 * attempts -- linear back-off: 1, 2, 3, ... seconds
      end

      rspamd_config:add_periodic(ev_base, 0.0, try_load, false)
    end

    -- Register on load function
    rspamd_config:add_on_load(on_load)

    table.insert(redis_scripts, new_script)
    return #redis_scripts
  end
  exports.add_redis_script = add_redis_script
  --[[[
  -- @function lua_redis.exec_redis_script(id, params, callback, keys[, args])
  -- Runs a script registered with add_redis_script() using EVALSHA, loading
  -- (or reloading after NOSCRIPT) the script on demand.
  -- @param {number} id value returned by add_redis_script()
  -- @param {table} params `task` (task context) or `ev_base` (task-less), plus optional `key` (sharding) and `is_write`
  -- @param {function} callback function(err, data)
  -- @param {table} keys array of redis keys passed to the script
  -- @param {table} args optional array of extra script arguments
  -- @return {boolean} false if no script has this id, true otherwise (errors are then reported through the callback)
  --]]
  local function exec_redis_script(id, params, callback, keys, args)
    local script = redis_scripts[id]
    if not script then
      logger.errx("cannot find registered script with id %s", id)
      return false
    end

    if script.fatal_error then
      callback(script.fatal_error, nil)
      return true
    end

    if not script.redis_params then
      callback('no redis servers defined', nil)
      return true
    end

    -- Starts a (re)load of the script in the caller's context
    local function reload_script()
      if params.task then
        load_script_task(script, params.task)
      else
        load_script_taskless(script, rspamd_config, params.ev_base)
      end
    end

    -- EVALSHA arguments: sha, number of keys, keys..., extra args...
    -- Rebuilt for every attempt so that a retry uses the current sha
    local function build_evalsha_args()
      local script_keys = keys or {}
      local redis_args = {script.sha, tostring(#script_keys)}
      for _, k in ipairs(script_keys) do
        redis_args[#redis_args + 1] = k
      end
      if type(args) == 'table' then
        for _, a in ipairs(args) do
          redis_args[#redis_args + 1] = a
        end
      end
      return redis_args
    end

    -- Issues EVALSHA; can_reload allows one reload and retry on NOSCRIPT
    local function do_call(can_reload)
      if not script.sha then
        -- Never send EVALSHA without a sha
        callback('NOSCRIPT', nil)
        return
      end

      local function redis_cb(err, data)
        if not err then
          callback(err, data)
        elseif string.match(err, 'NOSCRIPT') then
          -- The server has lost the script: mark it as not loaded so that
          -- concurrent callers wait for a reload instead of issuing EVALSHA
          -- with an empty sha
          script.sha = nil
          script.loaded = false
          if can_reload then
            -- Retry once after the reload without allowing another reload,
            -- so a script that keeps vanishing cannot cause an endless loop
            table.insert(script.waitq, function(loaded)
              if loaded then
                do_call(false)
              else
                callback(err, data)
              end
            end)
            if script.in_flight == 0 then
              -- Reload scripts if this has not been initiated yet
              reload_script()
            end
          else
            callback(err, data)
          end
        else
          callback(err, data)
        end
      end

      local redis_args = build_evalsha_args()
      local ok
      if params.task then
        ok = rspamd_redis_make_request(params.task, script.redis_params,
            params.key, params.is_write, redis_cb, 'EVALSHA', redis_args)
      else
        ok = redis_make_request_taskless(params.ev_base, rspamd_config,
            script.redis_params,
            params.key, params.is_write, redis_cb, 'EVALSHA', redis_args)
      end

      if not ok then
        callback('Cannot make redis request', nil)
      end
    end

    if script.loaded then
      do_call(true)
    elseif params.task then
      -- Delayed until scripts are loaded
      -- TODO: fix taskfull requests
      table.insert(script.waitq, function()
        if script.loaded then
          do_call(false)
        else
          callback('NOSCRIPT', nil)
        end
      end)
      load_script_task(script, params.task)
    else
      -- Delayed until scripts are loaded; never call EVALSHA if loading failed
      table.insert(script.waitq, function(loaded)
        if loaded then
          do_call(true)
        else
          callback('NOSCRIPT', nil)
        end
      end)
      if script.in_flight == 0 then
        -- No load in progress (e.g. an earlier reload failed): start one,
        -- otherwise the waiter above would never be run
        load_script_taskless(script, rspamd_config, params.ev_base)
      end
    end

    return true
  end
  exports.exec_redis_script = exec_redis_script
  -- Connects to a redis upstream with the synchronous (coroutine-based) API
  -- and queues AUTH/SELECT when a password/db is configured.
  -- cfg, ev_base and the session fall back to rspamd_config, rspamadm_ev_base
  -- and redis_params.session/rspamadm_session respectively.
  -- Returns: result, connection object, selected upstream.
  local function redis_connect_sync(redis_params, is_write, key, cfg, ev_base)
    if not redis_params then
      return false,nil
    end

    local rspamd_redis = require "rspamd_redis"
    local addr

    -- Keyed connections select an upstream by hash; keyless writes use
    -- master/slave selection and keyless reads use round-robin
    if key then
      if is_write then
        addr = redis_params['write_servers']:get_upstream_by_hash(key)
      else
        addr = redis_params['read_servers']:get_upstream_by_hash(key)
      end
    else
      if is_write then
        addr = redis_params['write_servers']:get_upstream_master_slave(key)
      else
        addr = redis_params['read_servers']:get_upstream_round_robin(key)
      end
    end

    if not addr then
      logger.errx(cfg, 'cannot select server to make redis request')
      -- Everything below needs an upstream: bail out instead of indexing nil
      return false,nil,nil
    end

    local options = {
      host = addr:get_addr(),
      timeout = redis_params['timeout'],
      config = cfg or rspamd_config,
      ev_base = ev_base or rspamadm_ev_base,
      session = redis_params.session or rspamadm_session
    }

    -- Every redis_params entry is copied as is and overrides the values above
    for k,v in pairs(redis_params) do
      options[k] = v
    end

    if not options.config then
      logger.errx('config is not set')
      return false,nil,addr
    end

    if not options.ev_base then
      logger.errx('ev_base is not set')
      return false,nil,addr
    end

    if not options.session then
      logger.errx('session is not set')
      return false,nil,addr
    end

    local ret,conn = rspamd_redis.connect_sync(options)
    if not ret then
      logger.errx('cannot execute redis request: %s', conn)
      addr:fail()
      return false,nil,addr
    end

    if conn then
      if redis_params['password'] then
        conn:add_cmd('AUTH', {redis_params['password']})
      end

      if redis_params['db'] then
        conn:add_cmd('SELECT', {tostring(redis_params['db'])})
      elseif redis_params['dbname'] then
        conn:add_cmd('SELECT', {tostring(redis_params['dbname'])})
      end
    end

    return ret,conn,addr
  end
  exports.redis_connect_sync = redis_connect_sync
  --[[[
  -- @function lua_redis.request(redis_params, attrs, req)
  -- Sends a request to Redis synchronously with coroutines or asynchronously using
  -- a callback (modern API)
  -- @param redis_params a table of redis server parameters
  -- @param attrs a table of redis request attributes (e.g. task, or ev_base + cfg + session)
  -- @param req a command string, or a table of a command followed by its arguments (the table is not modified)
  -- @return {result,data/connection,address} boolean result, connection object in case of async request and results if using coroutines, redis server address
  --]]
  exports.request = function(redis_params, attrs, req)
    if not attrs or not redis_params or not req then
      logger.errx('invalid arguments for redis request')
      return false,nil,nil
    end

    if not (attrs.task or (attrs.config and attrs.ev_base)) then
      logger.errx('invalid attributes for redis request')
      return false,nil,nil
    end

    local opts = lutil.shallowcopy(attrs)
    local log_obj = opts.task or opts.config
    local addr

    if opts.callback then
      -- Wrap the callback: report the outcome to the upstream first
      local callback = opts.callback
      opts.callback = function(err, data)
        if err then
          addr:fail()
        else
          addr:ok()
        end
        callback(err, data, addr)
      end
    end

    local rspamd_redis = require "rspamd_redis"
    local is_write = opts.is_write

    -- Keyed requests select an upstream by hash; keyless writes use
    -- master/slave selection and keyless reads use round-robin
    if opts.key then
      if is_write then
        addr = redis_params['write_servers']:get_upstream_by_hash(opts.key)
      else
        addr = redis_params['read_servers']:get_upstream_by_hash(opts.key)
      end
    else
      if is_write then
        addr = redis_params['write_servers']:get_upstream_master_slave(opts.key)
      else
        addr = redis_params['read_servers']:get_upstream_round_robin(opts.key)
      end
    end

    if not addr then
      logger.errx(log_obj, 'cannot select server to make redis request')
      -- Everything below needs an upstream: bail out instead of indexing nil
      return false,nil,nil
    end

    opts.host = addr:get_addr()
    opts.timeout = redis_params.timeout

    if type(req) == 'string' then
      opts.cmd = req
    else
      -- Split into command and arguments without touching the caller's
      -- table, which may be reused for other requests
      local args = {}
      for i = 2, #req do
        args[#args + 1] = req[i]
      end
      opts.cmd = req[1]
      opts.args = args
    end

    if redis_params.password then
      opts.password = redis_params.password
    end

    if redis_params.db then
      opts.dbname = redis_params.db
    end

    lutil.debugm(N, log_obj, 'perform generic request to redis server' ..
        ' (host=%s, timeout=%s): cmd: %s, arguments: %s', addr,
        opts.timeout, opts.cmd, opts.args)

    if opts.callback then
      local ret,conn = rspamd_redis.make_request(opts)
      if not ret then
        logger.errx(log_obj, 'cannot execute redis request')
        addr:fail()
      end
      return ret,conn,addr
    end

    -- Coroutines version
    local ret,conn = rspamd_redis.connect_sync(opts)
    if not ret then
      logger.errx(log_obj, 'cannot execute redis request')
      addr:fail()
      return false,nil,addr
    end

    conn:add_cmd(opts.cmd, opts.args)
    -- Single queued command: exec() yields its status and reply; append the
    -- upstream as documented
    local ok, data = conn:exec()
    return ok, data, addr
  end
  --[[[
  -- @function lua_redis.connect(redis_params, attrs)
  -- Connects to Redis synchronously with coroutines or asynchronously using a callback (modern API)
  -- @param redis_params a table of redis server parameters
  -- @param attrs a table of redis request attributes (e.g. task, or ev_base + cfg + session)
  -- @return {result,connection,address} boolean result, connection object, redis server address
  --]]
  exports.connect = function(redis_params, attrs)
    if not attrs or not redis_params then
      logger.errx('invalid arguments for redis connect')
      return false,nil,nil
    end

    if not (attrs.task or (attrs.config and attrs.ev_base)) then
      logger.errx('invalid attributes for redis connect')
      return false,nil,nil
    end

    local opts = lutil.shallowcopy(attrs)
    local log_obj = opts.task or opts.config
    local addr

    if opts.callback then
      -- Wrap the callback: report the outcome to the upstream first
      local callback = opts.callback
      opts.callback = function(err, data)
        if err then
          addr:fail()
        else
          addr:ok()
        end
        callback(err, data, addr)
      end
    end

    local rspamd_redis = require "rspamd_redis"
    local is_write = opts.is_write

    -- Keyed connections select an upstream by hash; keyless writes use
    -- master/slave selection and keyless reads use round-robin
    if opts.key then
      if is_write then
        addr = redis_params['write_servers']:get_upstream_by_hash(opts.key)
      else
        addr = redis_params['read_servers']:get_upstream_by_hash(opts.key)
      end
    else
      if is_write then
        addr = redis_params['write_servers']:get_upstream_master_slave(opts.key)
      else
        addr = redis_params['read_servers']:get_upstream_round_robin(opts.key)
      end
    end

    if not addr then
      logger.errx(log_obj, 'cannot select server to make redis connect')
      -- Everything below needs an upstream: bail out instead of indexing nil
      return false,nil,nil
    end

    opts.host = addr:get_addr()
    opts.timeout = redis_params.timeout

    if redis_params.password then
      opts.password = redis_params.password
    end

    if redis_params.db then
      opts.dbname = redis_params.db
    end

    if opts.callback then
      local ret,conn = rspamd_redis.connect(opts)
      if not ret then
        logger.errx(log_obj, 'cannot execute redis connect')
        addr:fail()
      end
      return ret,conn,addr
    end

    -- Coroutines version
    local ret,conn = rspamd_redis.connect_sync(opts)
    if not ret then
      logger.errx(log_obj, 'cannot execute redis connect')
      addr:fail()
      return false,nil,addr
    end

    return true,conn,addr
  end
  -- Documentation registry of redis key prefixes: prefix -> attribute table
  local redis_prefixes = {}

  --[[[
  -- @function lua_redis.register_prefix(prefix, module, description[, optional])
  -- Register new redis prefix for documentation purposes
  -- @param {string} prefix string prefix
  -- @param {string} module module name
  -- @param {string} description prefix description
  -- @param {table} optional optional kv pairs (e.g. pattern), copied into the entry; they may override module/description
  --]]
  local function register_prefix(prefix, module, description, optional)
    local entry = { module = module, description = description }

    if type(optional) == 'table' then
      for key, value in pairs(optional) do
        entry[key] = value
      end
    end

    redis_prefixes[prefix] = entry
  end
  exports.register_prefix = register_prefix
  --[[[
  -- @function lua_redis.prefixes([mname])
  -- Returns prefixes for specific module (or all prefixes). Returns a table prefix -> table
  --]]
  exports.prefixes = function(mname)
    if not mname then
      return redis_prefixes
    end

    -- Build a prefix -> data map like the unfiltered result; collecting the
    -- filtered map with fun.totable() kept only the prefixes as an array
    local res = {}
    for prefix, data in pairs(redis_prefixes) do
      if data.module == mname then
        res[prefix] = data
      end
    end

    return res
  end

  return exports