You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

lua_redis.lua 49KB


  1. --[[
  2. Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ]]--
  13. local logger = require "rspamd_logger"
  14. local lutil = require "lua_util"
  15. local rspamd_util = require "rspamd_util"
  16. local ts = require("tableshape").types
  local exports = {}
  -- Shared empty table used as a fallback when indexing optional values
  local E = {}
  -- Module tag passed to lutil.debugm
  local N = "lua_redis"

  -- Options shared by every Redis configuration shape (see config_schema).
  -- The database selectors accept numbers as well as strings:
  -- process_redis_opts() stringifies db/dbname/database with tostring(), so
  -- numeric values (e.g. `db = 1`) are expected in configs; the `/ tostring`
  -- transform normalises them to strings when the schema is used to transform.
  local common_schema = ts.shape {
    timeout = (ts.number + ts.string / lutil.parse_time_interval):is_optional(),
    db = (ts.string + ts.number / tostring):is_optional(),
    database = (ts.string + ts.number / tostring):is_optional(),
    dbname = (ts.string + ts.number / tostring):is_optional(),
    prefix = ts.string:is_optional(),
    password = ts.string:is_optional(),
    expand_keys = ts.boolean:is_optional(),
    sentinels = (ts.string + ts.array_of(ts.string)):is_optional(),
    sentinel_watch_time = (ts.number + ts.string / lutil.parse_time_interval):is_optional(),
    sentinel_masters_pattern = ts.string:is_optional(),
    sentinel_master_maxerrors = (ts.number + ts.string / tonumber):is_optional(),
  }
  -- Accepted Redis configuration shapes. Each alternative names the server
  -- list(s) it requires; every other key must satisfy common_schema.
  -- Separate read/write server alternatives allow usage in `extra_fields`.
  local config_schema
  do
    -- A server list is either one string or an array of strings
    local function server_list()
      return ts.string + ts.array_of(ts.string)
    end

    local function with_common(fields)
      return ts.shape(fields, {extra_fields = common_schema})
    end

    config_schema = with_common({ read_servers = server_list() })
      + with_common({ write_servers = server_list() })
      + with_common({ read_servers = server_list(), write_servers = server_list() })
      + with_common({ servers = server_list() })
      + with_common({ server = server_list() })
  end
  exports.config_schema = config_schema
  --- Query Redis Sentinel and refresh the read/write server lists in `params`.
  -- Sends `SENTINEL masters` to one sentinel (chosen round-robin from
  -- `params.sentinels`). It keeps the masters whose name matches
  -- `params.sentinel_masters_pattern` (all masters when unset) and then sends
  -- `SENTINEL slaves <name>` for each of them. When no slaves request is left
  -- in flight, process_masters() rebuilds the server strings. Read servers are
  -- the masters plus replicas whose `master-link-status` is `ok`; write servers
  -- are the masters only. The upstream lists are replaced when a string differs
  -- from the one in use.
  -- @param ev_base event base used for the asynchronous requests
  -- @param params redis params; `params.sentinels` must be an upstream list
  -- @param initialised not used here; kept for interface compatibility
  local function redis_query_sentinel(ev_base, params, initialised)
    -- Sentinel replies are flat {field1, value1, field2, value2, ...} arrays
    local function flatten_redis_table(tbl)
      local res = {}
      for i = 1, #tbl, 2 do
        res[tbl[i]] = tbl[i + 1]
      end
      return res
    end

    -- Wrap IPv6 addresses in brackets so that "ip:port" stays unambiguous
    local function wrap_ipv6(ip)
      if ip:match(":") then
        return "[" .. ip .. "]"
      end
      return ip
    end

    local rspamd_redis = require "rspamd_redis"
    local sentinels = params.sentinels
    local addr = sentinels:get_upstream_round_robin()
    local host = addr:get_addr()
    local masters = {}

    -- Build new server strings from `masters` and swap the upstream lists in
    -- `params` when they differ from the current ones
    local function process_masters()
      local upstream_list = require "rspamd_upstream_list"
      local read_servers_tbl, write_servers_tbl = {}, {}

      for _, master in pairs(masters) do
        local master_addr = string.format('%s:%s', master.ip, master.port)
        write_servers_tbl[#write_servers_tbl + 1] = master_addr
        read_servers_tbl[#read_servers_tbl + 1] = master_addr

        for _, slave in ipairs(master.slaves) do
          if slave['master-link-status'] == 'ok' then
            read_servers_tbl[#read_servers_tbl + 1] = string.format(
                '%s:%s', slave.ip, slave.port)
          end
        end
      end

      -- pairs() order over `masters` is unspecified; sort so that equal
      -- topologies always produce equal strings
      table.sort(read_servers_tbl)
      table.sort(write_servers_tbl)
      local read_servers_str = table.concat(read_servers_tbl, ',')
      local write_servers_str = table.concat(write_servers_tbl, ',')

      lutil.debugm(N, rspamd_config,
          'new servers list: %s read; %s write',
          read_servers_str,
          write_servers_str)

      if read_servers_str ~= params.read_servers_str then
        local read_upstreams = upstream_list.create(rspamd_config,
            read_servers_str, 6379)
        if read_upstreams then
          logger.infox(rspamd_config, 'sentinel %s: replace read servers with new list: %s',
              host:to_string(true), read_servers_str)
          params.read_servers = read_upstreams
          params.read_servers_str = read_servers_str
        end
      end

      if write_servers_str ~= params.write_servers_str then
        local write_upstreams = upstream_list.create(rspamd_config,
            write_servers_str, 6379)
        if write_upstreams then
          logger.infox(rspamd_config, 'sentinel %s: replace write servers with new list: %s',
              host:to_string(true), write_servers_str)
          params.write_servers = write_upstreams
          params.write_servers_str = write_servers_str

          -- Re-query Sentinel (once per list) when a master keeps failing.
          -- NOTE(review): the address logged is the sentinel's, not the
          -- failing master's
          local queried = false
          local function monitor_failures(_, _, count)
            if count > params.sentinel_master_maxerrors and not queried then
              logger.infox(rspamd_config, 'sentinel: master with address %s, caused %s failures, try to query sentinel',
                  host:to_string(true), count)
              queried = true -- Avoid multiple checks caused by this monitor
              redis_query_sentinel(ev_base, params, true)
            end
          end
          write_upstreams:add_watcher('failure', monitor_failures)
        end
      end
    end

    local function masters_cb(err, result)
      if err or type(result) ~= 'table' then
        logger.errx('cannot get masters data from Redis Sentinel %s: %s',
            host:to_string(true), err)
        addr:fail()
        return
      end

      for _, m in ipairs(result) do
        local master = flatten_redis_table(m)
        if not master.ip or not master.name then
          lutil.debugm(N, rspamd_config,
              'skip malformed master entry from sentinel %s', host:to_string(true))
        else
          master.ip = wrap_ipv6(master.ip)
          local pattern = params.sentinel_masters_pattern
          if pattern and not master.name:match(pattern) then
            lutil.debugm(N, 'skip master %s with ip %s and port %s, pattern %s',
                master.name, master.ip, master.port, pattern)
          else
            lutil.debugm(N, 'found master %s with ip %s and port %s',
                master.name, master.ip, master.port)
            masters[master.name] = master
          end
        end
      end

      -- For each master we need to get a list of slaves
      local pending_subrequests = 0

      for name, master in pairs(masters) do
        master.slaves = {}

        local function slaves_cb(slave_err, slave_result)
          if not slave_err and type(slave_result) == 'table' then
            for _, s in ipairs(slave_result) do
              local slave = flatten_redis_table(s)
              lutil.debugm(N, rspamd_config,
                  'found slave for master %s with ip %s and port %s',
                  master.name, slave.ip, slave.port)
              if slave.ip then
                slave.ip = wrap_ipv6(slave.ip)
                master.slaves[#master.slaves + 1] = slave
              end
            end
          else
            logger.errx('cannot get slaves data from Redis Sentinel %s: %s',
                host:to_string(true), slave_err)
            addr:fail()
          end

          pending_subrequests = pending_subrequests - 1
          if pending_subrequests == 0 then
            -- Finalize masters and slaves
            process_masters()
          end
        end

        local ret = rspamd_redis.make_request({
          host = addr:get_addr(),
          timeout = params.timeout,
          config = rspamd_config,
          ev_base = ev_base,
          cmd = 'SENTINEL',
          args = {'slaves', name},
          no_pool = true,
          callback = slaves_cb
        })

        if ret then
          pending_subrequests = pending_subrequests + 1
        else
          logger.errx(rspamd_config, 'cannot connect sentinel when query slaves at address: %s',
              host:to_string(true))
          addr:fail()
        end
      end

      addr:ok()

      -- No slaves callback will fire (no request could be started), so
      -- finalise now with the masters alone instead of dropping the update
      if pending_subrequests == 0 and next(masters) ~= nil then
        process_masters()
      end
    end

    local ret = rspamd_redis.make_request({
      host = addr:get_addr(),
      timeout = params.timeout,
      config = rspamd_config,
      ev_base = ev_base,
      cmd = 'SENTINEL',
      args = {'masters'},
      no_pool = true,
      callback = masters_cb,
    })

    if not ret then
      logger.errx(rspamd_config, 'cannot connect sentinel at address: %s',
          host:to_string(true))
      addr:fail()
    end
  end
  --- Convert `params.sentinels` into an upstream list and schedule periodic
  -- Sentinel discovery for this params table.
  -- Defaults: `sentinel_watch_time` = 60 (once a minute) and
  -- `sentinel_master_maxerrors` = 2 (errors tolerated before rechecking).
  -- Discovery is registered only in scanner and fuzzy workers. When the
  -- sentinels string cannot be parsed, an error is logged and `params` is left
  -- untouched.
  local function add_redis_sentinels(params)
    local upstream_list = require "rspamd_upstream_list"
    local sentinel_upstreams = upstream_list.create(rspamd_config,
        params.sentinels, 5000)

    if not sentinel_upstreams then
      logger.errx(rspamd_config, 'cannot load redis sentinels string: %s',
          params.sentinels)
      return
    end

    params.sentinels = sentinel_upstreams
    params.sentinel_watch_time = params.sentinel_watch_time or 60
    params.sentinel_master_maxerrors = params.sentinel_master_maxerrors or 2

    rspamd_config:add_on_load(function(_, ev_base, worker)
      local wants_discovery = worker:is_scanner() or worker:get_type() == 'fuzzy'
      if not wants_discovery then
        return
      end

      local initialised = false
      rspamd_config:add_periodic(ev_base, 0.0, function()
        redis_query_sentinel(ev_base, params, initialised)
        initialised = true
        return params.sentinel_watch_time
      end, false)
    end)
  end
  -- Params tables already registered, keyed by calculate_redis_hash() digest
  local cached_results = {}

  --- Compute a content digest of a redis params table for de-duplication.
  -- String, number and boolean values contribute their (key, value) pair;
  -- nested tables contribute their key and are walked recursively. Keys are
  -- visited in sorted order, so the digest does not depend on pairs()
  -- iteration order. Other value types (functions, userdata) are ignored.
  -- Booleans are included so that configurations that differ only in flags
  -- such as `expand_keys` or `read_only` do not collide.
  -- @param params redis params table
  -- @return digest string produced by rspamd_cryptobox_hash base32()
  local function calculate_redis_hash(params)
    local cr = require "rspamd_cryptobox_hash"
    local h = cr.create()

    -- Deterministic key order: group by type name, then compare values
    -- (tostring for types without a natural ordering)
    local function sorted_keys(tbl)
      local keys = {}
      for k in pairs(tbl) do
        keys[#keys + 1] = k
      end
      table.sort(keys, function(a, b)
        local ta, tb = type(a), type(b)
        if ta ~= tb then
          return ta < tb
        end
        if ta == 'number' or ta == 'string' then
          return a < b
        end
        return tostring(a) < tostring(b)
      end)
      return keys
    end

    local function rec_hash(k, v)
      local vt = type(v)
      if vt == 'string' then
        h:update(tostring(k))
        h:update(v)
      elseif vt == 'number' or vt == 'boolean' then
        h:update(tostring(k))
        h:update(tostring(v))
      elseif vt == 'table' then
        h:update(tostring(k))
        for _, kk in ipairs(sorted_keys(v)) do
          rec_hash(kk, v[kk])
        end
      end
    end

    rec_hash('top', params)

    return h:base32()
  end
  --- Merge generic Redis options from `options` into `redis_params`.
  -- Called repeatedly on the same `redis_params`, from the most specific
  -- options table to the most generic one. A field already set by an earlier
  -- call is kept, with these exceptions:
  --  * timeout: also replaced while it still equals the 1.0 default;
  --  * expand_keys: an explicit boolean fills an unset or `false` value, so a
  --    later (more generic) layer may enable it but never turns off an earlier
  --    explicit `true`. When `options` lacks it, it defaults to `false` only if
  --    still unset. (Previously every call overwrote it, so the global section
  --    silently reset a module's `expand_keys = true`.)
  -- The database may be given as `db`, `dbname` or `database` and is stored as
  -- a string under `db`.
  -- @param options table of raw configuration options
  -- @param redis_params params table filled in place
  local function process_redis_opts(options, redis_params)
    local default_timeout = 1.0
    local default_expand_keys = false

    if not redis_params['timeout'] or redis_params['timeout'] == default_timeout then
      if options['timeout'] then
        redis_params['timeout'] = tonumber(options['timeout'])
      else
        redis_params['timeout'] = default_timeout
      end
    end

    if options['prefix'] and not redis_params['prefix'] then
      redis_params['prefix'] = options['prefix']
    end

    local expand_keys = options['expand_keys']
    if type(expand_keys) == 'boolean' then
      if not redis_params['expand_keys'] then
        redis_params['expand_keys'] = expand_keys
      end
    elseif redis_params['expand_keys'] == nil then
      redis_params['expand_keys'] = default_expand_keys
    end

    if not redis_params['db'] then
      if options['db'] then
        redis_params['db'] = tostring(options['db'])
      elseif options['dbname'] then
        redis_params['db'] = tostring(options['dbname'])
      elseif options['database'] then
        redis_params['db'] = tostring(options['database'])
      end
    end

    if options['password'] and not redis_params['password'] then
      redis_params['password'] = options['password']
    end

    if not redis_params.sentinels and options.sentinels then
      redis_params.sentinels = options.sentinels
    end

    if options['sentinel_masters_pattern'] and not redis_params['sentinel_masters_pattern'] then
      redis_params['sentinel_masters_pattern'] = options['sentinel_masters_pattern']
    end
  end
  --- Fill unset options in `redis_params` from the global `redis` section.
  -- The per-module subsection `redis.<module>` is applied first (when `module`
  -- is given and present), then the section itself. Does nothing when
  -- `rspamd_config` is nil or has no `redis` section.
  local function enrich_defaults(rspamd_config, module, redis_params)
    if not rspamd_config then
      return
    end

    local global_opts = rspamd_config:get_all_opt('redis')
    if not global_opts then
      return
    end

    local per_module = module and global_opts[module]
    if per_module then
      process_redis_opts(per_module, redis_params)
    end

    process_redis_opts(global_opts, redis_params)
  end
  --- De-duplicate redis params by content.
  -- Returns the already registered table when one with the same
  -- calculate_redis_hash() digest exists. Otherwise it does the following:
  -- stores the digest in `redis_params.hash`, registers the table, starts
  -- Sentinel discovery for non-read-only params that list sentinels, and
  -- returns `redis_params`.
  local function maybe_return_cached(redis_params)
    local digest = calculate_redis_hash(redis_params)
    local known = cached_results[digest]

    if known then
      lutil.debugm(N, 'reused redis server: %s', redis_params)
      return known
    end

    redis_params.hash = digest
    cached_results[digest] = redis_params

    local wants_sentinels = redis_params.sentinels and not redis_params.read_only
    if wants_sentinels then
      add_redis_sentinels(redis_params)
    end

    lutil.debugm(N, 'loaded new redis server: %s', redis_params)
    return redis_params
  end
  329. --[[[
  330. -- @module lua_redis
  331. -- This module contains helper functions for working with Redis
  332. --]]
  --- Build upstream lists from one options table into `result`.
  -- The read list comes from the first present key among `read_servers`,
  -- `servers` and `server`. `write_servers` is honoured only when a read list
  -- was created. With `servers`/`server` and no `write_servers`, the read list
  -- doubles as the write list. Only a lone `read_servers` yields a read-only
  -- configuration. Generic options are merged via process_redis_opts().
  -- @param options options table to parse
  -- @param rspamd_config config object, or nil to create upstreams without it
  -- @param result params table filled in place
  -- @return true when a read upstream list was created, false otherwise
  local function process_redis_options(options, rspamd_config, result)
    local default_port = 6379
    local upstream_list = require "rspamd_upstream_list"

    -- upstream_list.create takes the config object as an optional first argument
    local function make_upstreams(servers)
      if rspamd_config then
        return upstream_list.create(rspamd_config, servers, default_port)
      end
      return upstream_list.create(servers, default_port)
    end

    local read_only = true
    local upstreams_read, upstreams_write

    -- `read_servers` alone stays read-only; `servers`/`server` are read-write
    local read_key
    if options['read_servers'] then
      read_key = 'read_servers'
    elseif options['servers'] then
      read_key, read_only = 'servers', false
    elseif options['server'] then
      read_key, read_only = 'server', false
    end

    if read_key then
      upstreams_read = make_upstreams(options[read_key])
      result.read_servers_str = options[read_key]
    end

    if upstreams_read then
      if options['write_servers'] then
        upstreams_write = make_upstreams(options['write_servers'])
        result.write_servers_str = options['write_servers']
        read_only = false
      elseif not read_only then
        upstreams_write = upstreams_read
        result.write_servers_str = result.read_servers_str
      end
    end

    -- Store options
    process_redis_opts(options, result)

    if upstreams_write then
      result.read_only = false
    elseif read_only then
      result.read_only = true
    end

    if not upstreams_read then
      lutil.debugm(N, rspamd_config,
          'cannot load redis server from obj: %s, processed to %s',
          options, result)
      return false
    end

    result.read_servers = upstreams_read
    if upstreams_write then
      result.write_servers = upstreams_write
    end
    return true
  end
  --[[[
  @function try_load_redis_servers(options, rspamd_config, no_fallback, module_name)
  Tries to load redis servers from the specified `options` object.
  Unless `no_fallback` is set, options left unset are then filled from the
  global `redis` section (and its `module_name` subsection when given).
  Returns `redis_params` table or nil in case of failure
  --]]
  exports.try_load_redis_servers = function(options, rspamd_config, no_fallback, module_name)
    local result = {}

    if not process_redis_options(options, rspamd_config, result) then
      return
    end

    if not no_fallback then
      enrich_defaults(rspamd_config, module_name, result)
    end

    return maybe_return_cached(result)
  end
  -- Resolve the Redis server definition for a module. Sources, in order:
  --  1. the module's own options (`module_opts`, or the `module_name` config
  --     section): its `redis` subtable first, then the options table itself;
  --     on success, values left unset are enriched from the global `redis`
  --     section unless `no_fallback` is set;
  --  2. unless `no_fallback` is set, the global `redis` section: its
  --     `module_name` subsection when present, otherwise the section itself
  --     (refused when the module is listed in `disabled_modules`).
  -- Returns the (possibly shared, see maybe_return_cached) params table or nil.
  local function rspamd_parse_redis_server(module_name, module_opts, no_fallback)
    local result = {}

    -- Parse one module-level options table into `result`. On success, merge
    -- global defaults (unless disabled) and return the de-duplicated params.
    -- Returns nil otherwise.
    local function try_module_section(section)
      if not process_redis_options(section, rspamd_config, result) then
        return nil
      end
      if not no_fallback then
        enrich_defaults(rspamd_config, module_name, result)
      end
      return maybe_return_cached(result)
    end

    lutil.debugm(N, rspamd_config, 'try load redis config for: %s', module_name)
    local opts = module_opts or rspamd_config:get_all_opt(module_name)

    if opts then
      local found = (opts.redis and try_module_section(opts.redis))
          or try_module_section(opts)
      if found then
        return found
      end
    end

    if no_fallback then
      logger.infox(rspamd_config, "cannot find Redis definitions for %s and fallback is disabled",
          module_name)
      return nil
    end

    -- Fall back to the global `redis` section
    local global_opts = rspamd_config:get_all_opt('redis')
    if global_opts then
      local per_module = global_opts[module_name]

      if per_module then
        if process_redis_options(per_module, rspamd_config, result) then
          return maybe_return_cached(result)
        end
      else
        local ok = process_redis_options(global_opts, rspamd_config, result)

        -- Exclude disabled
        local disabled_list = global_opts['disabled_modules']
        if disabled_list then
          for _, disabled in ipairs(disabled_list) do
            if disabled == module_name then
              logger.infox(rspamd_config, "NOT using default redis server for module %s: it is disabled",
                  module_name)
              return nil
            end
          end
        end

        if ok then
          logger.infox(rspamd_config, "use default Redis settings for %s",
              module_name)
          return maybe_return_cached(result)
        end
      end
    end

    if result.read_servers then
      return maybe_return_cached(result)
    end

    return nil
  end
  --[[[
  -- @function lua_redis.parse_redis_server(module_name, module_opts, no_fallback)
  -- Extracts Redis server settings from configuration
  -- @param {string} module_name name of module to get settings for
  -- @param {table} module_opts settings for module or `nil` to fetch them from configuration
  -- @param {boolean} no_fallback should be `true` if global settings must not be used
  -- @return {table} redis server settings, or `nil` when no usable definition is found
  -- @example
  -- local rconfig = lua_redis.parse_redis_server('my_module')
  -- -- rconfig contains upstream_list objects in ['write_servers'] and ['read_servers']
  -- -- ['timeout'] contains timeout in seconds
  -- -- ['expand_keys'] if true tells that redis key expansion is enabled
  -- The same function is also exported as `rspamd_parse_redis_server`.
  --]]
  exports.parse_redis_server = rspamd_parse_redis_server
  exports.rspamd_parse_redis_server = rspamd_parse_redis_server
  -- Key-position extractors for Redis commands used by key expansion.
  -- Each handler receives the command arguments (without the command name)
  -- and returns an array of the argument indexes that hold keys. Every
  -- handler returns a table (possibly empty), so callers can pass the
  -- result straight to ipairs().
  local process_cmd = {
    -- BITOP operation destkey key [key ...]: every argument after the operation
    bitop = function(args)
      local idx_l = {}
      for i = 2, #args do
        idx_l[#idx_l + 1] = i
      end
      return idx_l
    end,
    -- BLPOP key [key ...] timeout: all arguments except the trailing timeout
    blpop = function(args)
      local idx_l = {}
      for i = 1, #args - 1 do
        idx_l[#idx_l + 1] = i
      end
      return idx_l
    end,
    -- EVAL script numkeys key [key ...] arg [arg ...]: numkeys keys starting
    -- at index 3. A missing/non-numeric numkeys yields no keys, and the range
    -- is clamped to the arguments actually present.
    eval = function(args)
      local idx_l = {}
      local numkeys = tonumber(args[2])
      if numkeys and numkeys >= 1 then
        for i = 3, math.min(numkeys + 2, #args) do
          idx_l[#idx_l + 1] = i
        end
      end
      return idx_l
    end,
    -- Commands whose only key is the first argument
    set = function()
      return {1}
    end,
    -- Every argument is a key (MGET key [key ...])
    mget = function(args)
      local idx_l = {}
      for i = 1, #args do
        idx_l[#idx_l + 1] = i
      end
      return idx_l
    end,
    -- MSET key value [key value ...]: keys sit at odd positions
    mset = function(args)
      local idx_l = {}
      for i = 1, #args, 2 do
        idx_l[#idx_l + 1] = i
      end
      return idx_l
    end,
    -- SDIFFSTORE destination key [key ...]: the destination is a key too
    sdiffstore = function(args)
      local idx_l = {}
      for i = 1, #args do
        idx_l[#idx_l + 1] = i
      end
      return idx_l
    end,
    -- SMOVE source destination member
    smove = function()
      return {1, 2}
    end,
    -- Commands without key arguments: an empty list, never nil
    script = function()
      return {}
    end
  }
  -- Map each supported Redis command (lower-case) to the extractor matching
  -- its key layout.
  process_cmd.append = process_cmd.set
  process_cmd.auth = process_cmd.script
  process_cmd.bgrewriteaof = process_cmd.script
  process_cmd.bgsave = process_cmd.script
  process_cmd.bitcount = process_cmd.set
  process_cmd.bitfield = process_cmd.set
  process_cmd.bitpos = process_cmd.set
  process_cmd.brpop = process_cmd.blpop
  process_cmd.brpoplpush = process_cmd.blpop
  process_cmd.client = process_cmd.script
  process_cmd.cluster = process_cmd.script
  process_cmd.command = process_cmd.script
  process_cmd.config = process_cmd.script
  process_cmd.dbsize = process_cmd.script
  process_cmd.debug = process_cmd.script
  process_cmd.decr = process_cmd.set
  process_cmd.decrby = process_cmd.set
  process_cmd.del = process_cmd.mget
  process_cmd.discard = process_cmd.script
  process_cmd.dump = process_cmd.set
  process_cmd.echo = process_cmd.script
  process_cmd.evalsha = process_cmd.eval
  process_cmd.exec = process_cmd.script
  process_cmd.exists = process_cmd.mget
  process_cmd.expire = process_cmd.set
  process_cmd.expireat = process_cmd.set
  process_cmd.flushall = process_cmd.script
  process_cmd.flushdb = process_cmd.script
  process_cmd.geoadd = process_cmd.set
  process_cmd.geohash = process_cmd.set
  process_cmd.geopos = process_cmd.set
  process_cmd.geodist = process_cmd.set
  process_cmd.georadius = process_cmd.set
  process_cmd.georadiusbymember = process_cmd.set
  process_cmd.get = process_cmd.set
  process_cmd.getbit = process_cmd.set
  process_cmd.getrange = process_cmd.set
  process_cmd.getset = process_cmd.set
  process_cmd.hdel = process_cmd.set
  process_cmd.hexists = process_cmd.set
  process_cmd.hget = process_cmd.set
  process_cmd.hgetall = process_cmd.set
  process_cmd.hincrby = process_cmd.set
  process_cmd.hincrbyfloat = process_cmd.set
  process_cmd.hkeys = process_cmd.set
  process_cmd.hlen = process_cmd.set
  process_cmd.hmget = process_cmd.set
  process_cmd.hmset = process_cmd.set
  process_cmd.hscan = process_cmd.set
  process_cmd.hset = process_cmd.set
  process_cmd.hsetnx = process_cmd.set
  process_cmd.hstrlen = process_cmd.set
  process_cmd.hvals = process_cmd.set
  process_cmd.incr = process_cmd.set
  process_cmd.incrby = process_cmd.set
  process_cmd.incrbyfloat = process_cmd.set
  process_cmd.info = process_cmd.script
  process_cmd.keys = process_cmd.script
  process_cmd.lastsave = process_cmd.script
  process_cmd.lindex = process_cmd.set
  process_cmd.linsert = process_cmd.set
  process_cmd.llen = process_cmd.set
  process_cmd.lpop = process_cmd.set
  process_cmd.lpush = process_cmd.set
  process_cmd.lpushx = process_cmd.set
  process_cmd.lrange = process_cmd.set
  process_cmd.lrem = process_cmd.set
  process_cmd.lset = process_cmd.set
  process_cmd.ltrim = process_cmd.set
  process_cmd.migrate = process_cmd.script
  process_cmd.monitor = process_cmd.script
  process_cmd.move = process_cmd.set
  process_cmd.msetnx = process_cmd.mset
  process_cmd.multi = process_cmd.script
  process_cmd.object = process_cmd.script
  process_cmd.persist = process_cmd.set
  process_cmd.pexpire = process_cmd.set
  process_cmd.pexpireat = process_cmd.set
  process_cmd.pfadd = process_cmd.set
  -- PFCOUNT key [key ...]
  process_cmd.pfcount = process_cmd.mget
  process_cmd.pfmerge = process_cmd.mget
  process_cmd.ping = process_cmd.script
  process_cmd.psetex = process_cmd.set
  process_cmd.psubscribe = process_cmd.script
  process_cmd.pubsub = process_cmd.script
  process_cmd.pttl = process_cmd.set
  process_cmd.publish = process_cmd.script
  process_cmd.punsubscribe = process_cmd.script
  process_cmd.quit = process_cmd.script
  process_cmd.randomkey = process_cmd.script
  process_cmd.readonly = process_cmd.script
  process_cmd.readwrite = process_cmd.script
  process_cmd.rename = process_cmd.mget
  process_cmd.renamenx = process_cmd.mget
  process_cmd.restore = process_cmd.set
  process_cmd.role = process_cmd.script
  process_cmd.rpop = process_cmd.set
  process_cmd.rpoplpush = process_cmd.mget
  process_cmd.rpush = process_cmd.set
  process_cmd.rpushx = process_cmd.set
  process_cmd.sadd = process_cmd.set
  process_cmd.save = process_cmd.script
  process_cmd.scard = process_cmd.set
  process_cmd.sdiff = process_cmd.mget
  process_cmd.select = process_cmd.script
  process_cmd.setbit = process_cmd.set
  process_cmd.setex = process_cmd.set
  process_cmd.setnx = process_cmd.set
  process_cmd.setrange = process_cmd.set
  process_cmd.sinter = process_cmd.mget
  process_cmd.sinterstore = process_cmd.sdiff
  process_cmd.sismember = process_cmd.set
  process_cmd.slaveof = process_cmd.script
  process_cmd.slowlog = process_cmd.script
  -- SMEMBERS key: the set name is a key
  process_cmd.smembers = process_cmd.set
  process_cmd.sort = process_cmd.set
  process_cmd.spop = process_cmd.set
  process_cmd.srandmember = process_cmd.set
  process_cmd.srem = process_cmd.set
  process_cmd.strlen = process_cmd.set
  process_cmd.subscribe = process_cmd.script
  process_cmd.sunion = process_cmd.mget
  process_cmd.sunionstore = process_cmd.mget
  process_cmd.swapdb = process_cmd.script
  process_cmd.sync = process_cmd.script
  process_cmd.time = process_cmd.script
  process_cmd.touch = process_cmd.mget
  process_cmd.ttl = process_cmd.set
  process_cmd.type = process_cmd.set
  process_cmd.unsubscribe = process_cmd.script
  process_cmd.unlink = process_cmd.mget
  process_cmd.unwatch = process_cmd.script
  process_cmd.wait = process_cmd.script
  process_cmd.watch = process_cmd.mget
  process_cmd.zadd = process_cmd.set
  process_cmd.zcard = process_cmd.set
  process_cmd.zcount = process_cmd.set
  process_cmd.zincrby = process_cmd.set
  -- ZINTERSTORE/ZUNIONSTORE destination numkeys key [key ...] [WEIGHTS ...]:
  -- the destination plus numkeys source keys starting at index 3 (clamped to
  -- the arguments actually present)
  process_cmd.zinterstore = function(args)
    local idx_l = {1}
    local numkeys = tonumber(args[2])
    if numkeys and numkeys >= 1 then
      for i = 3, math.min(numkeys + 2, #args) do
        idx_l[#idx_l + 1] = i
      end
    end
    return idx_l
  end
  process_cmd.zunionstore = process_cmd.zinterstore
  process_cmd.zlexcount = process_cmd.set
  process_cmd.zrange = process_cmd.set
  process_cmd.zrangebylex = process_cmd.set
  process_cmd.zrangebyscore = process_cmd.set
  process_cmd.zrank = process_cmd.set
  process_cmd.zrem = process_cmd.set
  process_cmd.zremrangebylex = process_cmd.set
  process_cmd.zremrangebyrank = process_cmd.set
  process_cmd.zremrangebyscore = process_cmd.set
  -- Not actual Redis command names; kept so existing lookups keep working
  process_cmd.zrembylex = process_cmd.set
  process_cmd.zrembyrank = process_cmd.set
  process_cmd.zrembyscore = process_cmd.set
  process_cmd.zrevrange = process_cmd.set
  process_cmd.zrevrangebylex = process_cmd.set
  process_cmd.zrevrangebyscore = process_cmd.set
  process_cmd.zrevrank = process_cmd.set
  process_cmd.zscore = process_cmd.set
  process_cmd.scan = process_cmd.script
  process_cmd.sscan = process_cmd.set
  process_cmd.zscan = process_cmd.set
  --- Return the argument indexes holding Redis keys for command `cmd`.
  -- The lookup is case-insensitive. Unknown commands are logged and yield an
  -- empty list, and so does a handler that returns nothing. The result is
  -- therefore always safe to iterate with ipairs().
  -- @param cmd Redis command name
  -- @param args command arguments (without the command name)
  -- @return array of argument indexes
  local function get_key_indexes(cmd, args)
    local name = string.lower(cmd)
    local handler = process_cmd[name]

    if not handler then
      logger.warnx(rspamd_config, "Don't know how to extract keys for %s Redis command", name)
      return {}
    end

    return handler(args) or {}
  end
  -- Field `field` of the first address returned by task:get_from(source),
  -- or nil when there is no such address or field
  local function first_from_field(task, source, field)
    local addrs = task:get_from(source) or E
    local first = addrs[1] or E
    return first[field]
  end

  -- Generators for key-expansion placeholders: each is called with the task
  -- and returns the placeholder value (or nothing when unavailable)
  local gen_meta = {
    principal_recipient = function(task)
      return task:get_principal_recipient()
    end,
    principal_recipient_domain = function(task)
      local rcpt = task:get_principal_recipient()
      if rcpt then
        -- Everything after the last '@'
        return string.match(rcpt, '.*@(.*)')
      end
    end,
    ip = function(task)
      local ip = task:get_ip()
      if ip and ip:is_valid() then
        return ip:to_string()
      end
    end,
    from = function(task)
      return first_from_field(task, 'smtp', 'addr')
    end,
    from_domain = function(task)
      return first_from_field(task, 'smtp', 'domain')
    end,
    from_domain_or_helo_domain = function(task)
      local domain = first_from_field(task, 'smtp', 'domain')
      if domain and #domain > 0 then
        return domain
      end
      return task:get_helo()
    end,
    mime_from = function(task)
      return first_from_field(task, 'mime', 'addr')
    end,
    mime_from_domain = function(task)
      return first_from_field(task, 'mime', 'domain')
    end,
  }
  -- Wrap the domain generator `f` so that its result is passed through
  -- rspamd_util.get_tld(); the wrapper returns nothing when `f` yields no domain
  local function gen_get_esld(f)
    return function(task)
      local domain = f(task)
      if domain then
        return rspamd_util.get_tld(domain)
      end
    end
  end
  -- Domain generators post-processed with rspamd_util.get_tld (via gen_get_esld)
  gen_meta.esld_principal_recipient_domain = gen_get_esld(gen_meta.principal_recipient_domain)
  gen_meta.esld_from_domain = gen_get_esld(gen_meta.from_domain)
  gen_meta.esld_mime_from_domain = gen_get_esld(gen_meta.mime_from_domain)
  gen_meta.esld_from_domain_or_helo_domain = gen_get_esld(gen_meta.from_domain_or_helo_domain)

  -- `smtp_`-prefixed aliases; the unprefixed generators already read the
  -- SMTP envelope sender. Targets must exist before this loop runs.
  for _, alias in ipairs({
    {'smtp_from', 'from'},
    {'smtp_from_domain', 'from_domain'},
    {'smtp_from_domain_or_helo_domain', 'from_domain_or_helo_domain'},
    {'esld_smtp_from_domain', 'esld_from_domain'},
    {'esld_smtp_from_domain_or_helo_domain', 'esld_from_domain_or_helo_domain'},
  }) do
    gen_meta[alias[1]] = gen_meta[alias[2]]
  end
  -- Build a lazily populated placeholder table for key templates. Reading a
  -- missing key lower-cases it, returns any value already stored under the
  -- lower-case name, and otherwise runs the matching gen_meta generator with
  -- `task`, caching the result in the table.
  local function get_key_expansion_metadata(task)
    local lazy_meta = {}

    setmetatable(lazy_meta, {
      __index = function(tbl, key)
        local lkey = string.lower(key)
        local value = rawget(tbl, lkey)
        if value then
          return value
        end

        local generator = gen_meta[lkey]
        if generator then
          value = generator(task)
          rawset(tbl, lkey, value)
        end
        return value
      end,
    })

    return lazy_meta
  end
  --- Perform an asynchronous Redis request on behalf of a task.
  -- @param task rspamd_task
  -- @param redis_params valid params returned by rspamd_parse_redis_server
  -- @param key key used to select an upstream by hash; nil selects round-robin
  --   (reads) or master/slave (writes)
  -- @param is_write true to use `write_servers`, otherwise `read_servers`
  -- @param callback function(err, data, addr) called when the request completes
  -- @param command Redis command
  -- @param args table of arguments. Key expansion (`expand_keys`) works on a
  --   copy, so a table shared between calls is never modified
  -- @param extra_opts optional table merged into the rspamd_redis request options
  -- @return ret, conn, addr as returned by rspamd_redis.make_request plus the
  --   selected upstream; `false, nil, nil` when a required argument is missing
  --   or no upstream can be selected
  local function rspamd_redis_make_request(task, redis_params, key, is_write,
                                           callback, command, args, extra_opts)
    local addr

    local function rspamd_redis_make_request_cb(err, data)
      if err then
        addr:fail()
      else
        addr:ok()
      end
      if callback then
        callback(err, data, addr)
      end
    end

    if not task or not redis_params or not callback or not command then
      return false, nil, nil
    end

    local rspamd_redis = require "rspamd_redis"

    local servers
    if is_write then
      servers = redis_params['write_servers']
    else
      servers = redis_params['read_servers']
    end

    if servers then
      if key then
        addr = servers:get_upstream_by_hash(key)
      elseif is_write then
        addr = servers:get_upstream_master_slave(key)
      else
        addr = servers:get_upstream_round_robin(key)
      end
    end

    if not addr then
      -- Without an upstream there is nothing to connect to (and addr:get_addr()
      -- below would fail), so report failure to the caller
      logger.errx(task, 'cannot select server to make redis request')
      return false, nil, nil
    end

    if redis_params['expand_keys'] and args then
      local indexes = get_key_indexes(command, args) or {}
      if #indexes > 0 then
        local m = get_key_expansion_metadata(task)
        -- Expand on a copy: callers may reuse one args table for many tasks
        local expanded = {}
        for k, v in pairs(args) do
          expanded[k] = v
        end
        for _, i in ipairs(indexes) do
          if args[i] ~= nil then
            expanded[i] = lutil.template(args[i], m)
          end
        end
        args = expanded
      end
    end

    local ip_addr = addr:get_addr()
    local options = {
      task = task,
      callback = rspamd_redis_make_request_cb,
      host = ip_addr,
      timeout = redis_params['timeout'],
      cmd = command,
      args = args
    }

    if extra_opts then
      for k, v in pairs(extra_opts) do
        options[k] = v
      end
    end

    if redis_params['password'] then
      options['password'] = redis_params['password']
    end

    if redis_params['db'] then
      options['dbname'] = redis_params['db']
    end

    lutil.debugm(N, task, 'perform request to redis server' ..
        ' (host=%s, timeout=%s): cmd: %s', ip_addr,
        options.timeout, options.cmd)

    local ret, conn = rspamd_redis.make_request(options)
    if not ret then
      addr:fail()
      logger.warnx(task, "cannot make redis request to: %s", tostring(ip_addr))
    end

    return ret, conn, addr
  end
  864. --[[[
  865. -- @function lua_redis.redis_make_request(task, redis_params, key, is_write, callback, command, args)
  866. -- Sends a request to Redis
  867. -- @param {rspamd_task} task task object
  868. -- @param {table} redis_params redis configuration in format returned by lua_redis.parse_redis_server()
  869. -- @param {string} key key to use for sharding
  870. -- @param {boolean} is_write should be `true` if we are performing a write operating
  871. -- @param {function} callback callback function (first parameter is error if applicable, second is a 2D array (table))
  872. -- @param {string} command Redis command to run
  873. -- @param {table} args Numerically indexed table containing arguments for command
  874. --]]
  875. exports.rspamd_redis_make_request = rspamd_redis_make_request
  876. exports.redis_make_request = rspamd_redis_make_request
  877. local function redis_make_request_taskless(ev_base, cfg, redis_params, key,
  878. is_write, callback, command, args, extra_opts)
  879. if not ev_base or not redis_params or not callback or not command then
  880. return false,nil,nil
  881. end
  882. local addr
  883. local function rspamd_redis_make_request_cb(err, data)
  884. if err then
  885. addr:fail()
  886. else
  887. addr:ok()
  888. end
  889. if callback then
  890. callback(err, data, addr)
  891. end
  892. end
  893. local rspamd_redis = require "rspamd_redis"
  894. if key then
  895. if is_write then
  896. addr = redis_params['write_servers']:get_upstream_by_hash(key)
  897. else
  898. addr = redis_params['read_servers']:get_upstream_by_hash(key)
  899. end
  900. else
  901. if is_write then
  902. addr = redis_params['write_servers']:get_upstream_master_slave(key)
  903. else
  904. addr = redis_params['read_servers']:get_upstream_round_robin(key)
  905. end
  906. end
  907. if not addr then
  908. logger.errx(cfg, 'cannot select server to make redis request')
  909. end
  910. local options = {
  911. ev_base = ev_base,
  912. config = cfg,
  913. callback = rspamd_redis_make_request_cb,
  914. host = addr:get_addr(),
  915. timeout = redis_params['timeout'],
  916. cmd = command,
  917. args = args
  918. }
  919. if extra_opts then
  920. for k,v in pairs(extra_opts) do
  921. options[k] = v
  922. end
  923. end
  924. if redis_params['password'] then
  925. options['password'] = redis_params['password']
  926. end
  927. if redis_params['db'] then
  928. options['dbname'] = redis_params['db']
  929. end
  930. lutil.debugm(N, cfg, 'perform taskless request to redis server' ..
  931. ' (host=%s, timeout=%s): cmd: %s', options.host:tostring(true),
  932. options.timeout, options.cmd)
  933. local ret,conn = rspamd_redis.make_request(options)
  934. if not ret then
  935. logger.errx('cannot execute redis request')
  936. addr:fail()
  937. end
  938. return ret,conn,addr
  939. end
  940. --[[[
  941. -- @function lua_redis.redis_make_request_taskless(ev_base, redis_params, key, is_write, callback, command, args)
  942. -- Sends a request to Redis in context where `task` is not available for some specific use-cases
  943. -- Identical to redis_make_request() except in that first parameter is an `event base` object
  944. --]]
  945. exports.rspamd_redis_make_request_taskless = redis_make_request_taskless
  946. exports.redis_make_request_taskless = redis_make_request_taskless
  947. local redis_scripts = {
  948. }
  949. local function script_set_loaded(script)
  950. if script.sha then
  951. script.loaded = true
  952. end
  953. local wait_table = {}
  954. for _,s in ipairs(script.waitq) do
  955. table.insert(wait_table, s)
  956. end
  957. script.waitq = {}
  958. for _,s in ipairs(wait_table) do
  959. s(script.loaded)
  960. end
  961. end
  962. local function prepare_redis_call(script)
  963. local function merge_tables(t1, t2)
  964. for k,v in pairs(t2) do t1[k] = v end
  965. end
  966. local servers = {}
  967. local options = {}
  968. if script.redis_params.read_servers then
  969. merge_tables(servers, script.redis_params.read_servers:all_upstreams())
  970. end
  971. if script.redis_params.write_servers then
  972. merge_tables(servers, script.redis_params.write_servers:all_upstreams())
  973. end
  974. -- Call load script on each server, set loaded flag
  975. script.in_flight = #servers
  976. for _,s in ipairs(servers) do
  977. local cur_opts = {
  978. host = s:get_addr(),
  979. timeout = script.redis_params['timeout'],
  980. cmd = 'SCRIPT',
  981. args = {'LOAD', script.script },
  982. upstream = s
  983. }
  984. if script.redis_params['password'] then
  985. cur_opts['password'] = script.redis_params['password']
  986. end
  987. if script.redis_params['db'] then
  988. cur_opts['dbname'] = script.redis_params['db']
  989. end
  990. table.insert(options, cur_opts)
  991. end
  992. return options
  993. end
  994. local function load_script_task(script, task, is_write)
  995. local rspamd_redis = require "rspamd_redis"
  996. local opts = prepare_redis_call(script)
  997. for _,opt in ipairs(opts) do
  998. opt.task = task
  999. opt.is_write = is_write
  1000. opt.callback = function(err, data)
  1001. if err then
  1002. logger.errx(task, 'cannot upload script to %s: %s; registered from: %s:%s',
  1003. opt.upstream:get_addr():to_string(true),
  1004. err, script.caller.short_src, script.caller.currentline)
  1005. opt.upstream:fail()
  1006. script.fatal_error = err
  1007. else
  1008. opt.upstream:ok()
  1009. logger.infox(task,
  1010. "uploaded redis script to %s with id %s, sha: %s",
  1011. opt.upstream:get_addr():to_string(true),
  1012. script.id, data)
  1013. script.sha = data -- We assume that sha is the same on all servers
  1014. end
  1015. script.in_flight = script.in_flight - 1
  1016. if script.in_flight == 0 then
  1017. script_set_loaded(script)
  1018. end
  1019. end
  1020. local ret = rspamd_redis.make_request(opt)
  1021. if not ret then
  1022. logger.errx('cannot execute redis request to load script on %s',
  1023. opt.upstream:get_addr())
  1024. script.in_flight = script.in_flight - 1
  1025. opt.upstream:fail()
  1026. end
  1027. if script.in_flight == 0 then
  1028. script_set_loaded(script)
  1029. end
  1030. end
  1031. end
  1032. local function load_script_taskless(script, cfg, ev_base, is_write)
  1033. local rspamd_redis = require "rspamd_redis"
  1034. local opts = prepare_redis_call(script)
  1035. for _,opt in ipairs(opts) do
  1036. opt.config = cfg
  1037. opt.ev_base = ev_base
  1038. opt.is_write = is_write
  1039. opt.callback = function(err, data)
  1040. if err then
  1041. logger.errx(cfg, 'cannot upload script to %s: %s; registered from: %s:%s',
  1042. opt.upstream:get_addr():to_string(true),
  1043. err, script.caller.short_src, script.caller.currentline)
  1044. opt.upstream:fail()
  1045. script.fatal_error = err
  1046. else
  1047. opt.upstream:ok()
  1048. logger.infox(cfg,
  1049. "uploaded redis script to %s with id %s, sha: %s",
  1050. opt.upstream:get_addr():to_string(true), script.id, data)
  1051. script.sha = data -- We assume that sha is the same on all servers
  1052. script.fatal_error = nil
  1053. end
  1054. script.in_flight = script.in_flight - 1
  1055. if script.in_flight == 0 then
  1056. script_set_loaded(script)
  1057. end
  1058. end
  1059. local ret = rspamd_redis.make_request(opt)
  1060. if not ret then
  1061. logger.errx('cannot execute redis request to load script on %s',
  1062. opt.upstream:get_addr())
  1063. script.in_flight = script.in_flight - 1
  1064. opt.upstream:fail()
  1065. end
  1066. if script.in_flight == 0 then
  1067. script_set_loaded(script)
  1068. end
  1069. end
  1070. end
  1071. local function load_redis_script(script, cfg, ev_base, _)
  1072. if script.redis_params then
  1073. load_script_taskless(script, cfg, ev_base)
  1074. end
  1075. end
  1076. local function add_redis_script(script, redis_params, caller_level)
  1077. if not caller_level then caller_level = 2 end
  1078. local caller = debug.getinfo(caller_level)
  1079. local new_script = {
  1080. caller = caller,
  1081. loaded = false,
  1082. redis_params = redis_params,
  1083. script = script,
  1084. waitq = {}, -- callbacks pending for script being loaded
  1085. id = #redis_scripts + 1
  1086. }
  1087. -- Register on load function
  1088. rspamd_config:add_on_load(function(cfg, ev_base, worker)
  1089. local mult = 0.0
  1090. rspamd_config:add_periodic(ev_base, 0.0, function()
  1091. if not new_script.sha then
  1092. load_redis_script(new_script, cfg, ev_base, worker)
  1093. mult = mult + 1
  1094. return 1.0 * mult -- Check one more time in one second
  1095. end
  1096. return false
  1097. end, false)
  1098. end)
  1099. table.insert(redis_scripts, new_script)
  1100. return #redis_scripts
  1101. end
  1102. exports.add_redis_script = add_redis_script
  1103. -- Loads a Redis script from a file, strips comments, and passes the content to
  1104. -- `add_redis_script` function.
  1105. --
  1106. -- @param filename The name of the file containing the Redis script.
  1107. -- @param redis_params The Redis parameters to use for this script.
  1108. -- @return The ID of the newly added Redis script.
  1109. --
  1110. local function load_redis_script_from_file(filename, redis_params, dir)
  1111. local lua_util = require "lua_util"
  1112. local rspamd_logger = require "rspamd_logger"
  1113. if not dir then dir = rspamd_paths.LUALIBDIR end
  1114. if filename:sub(1, 1) ~= package.config:sub(1,1) then
  1115. -- Relative path
  1116. filename = lua_util.join_path(dir, "redis_scripts", filename)
  1117. end
  1118. -- Read file contents
  1119. local file = io.open(filename, "r")
  1120. if not file then
  1121. rspamd_logger.errx("failed to open Redis script file: %s", filename)
  1122. return nil
  1123. end
  1124. local script = file:read("*all")
  1125. if not script then
  1126. rspamd_logger.errx("failed to load Redis script file: %s", filename)
  1127. return nil
  1128. end
  1129. file:close()
  1130. script = lua_util.strip_lua_comments(script)
  1131. return add_redis_script(script, redis_params, 3)
  1132. end
  1133. exports.load_redis_script_from_file = load_redis_script_from_file
  1134. local function exec_redis_script(id, params, callback, keys, args)
  1135. local redis_args = {}
  1136. if not redis_scripts[id] then
  1137. logger.errx("cannot find registered script with id %s", id)
  1138. return false
  1139. end
  1140. local script = redis_scripts[id]
  1141. if script.fatal_error then
  1142. callback(script.fatal_error, nil)
  1143. return true
  1144. end
  1145. if not script.redis_params then
  1146. callback('no redis servers defined', nil)
  1147. return true
  1148. end
  1149. local function do_call(can_reload)
  1150. local function redis_cb(err, data)
  1151. if not err then
  1152. callback(err, data)
  1153. elseif string.match(err, 'NOSCRIPT') then
  1154. -- Schedule restart
  1155. script.sha = nil
  1156. if can_reload then
  1157. table.insert(script.waitq, do_call)
  1158. if script.in_flight == 0 then
  1159. -- Reload scripts if this has not been initiated yet
  1160. if params.task then
  1161. load_script_task(script, params.task)
  1162. else
  1163. load_script_taskless(script, rspamd_config, params.ev_base)
  1164. end
  1165. end
  1166. else
  1167. callback(err, data)
  1168. end
  1169. else
  1170. callback(err, data)
  1171. end
  1172. end
  1173. if #redis_args == 0 then
  1174. table.insert(redis_args, script.sha)
  1175. table.insert(redis_args, tostring(#keys))
  1176. for _,k in ipairs(keys) do
  1177. table.insert(redis_args, k)
  1178. end
  1179. if type(args) == 'table' then
  1180. for _, a in ipairs(args) do
  1181. table.insert(redis_args, a)
  1182. end
  1183. end
  1184. end
  1185. if params.task then
  1186. if not rspamd_redis_make_request(params.task, script.redis_params,
  1187. params.key, params.is_write, redis_cb, 'EVALSHA', redis_args) then
  1188. callback('Cannot make redis request', nil)
  1189. end
  1190. else
  1191. if not redis_make_request_taskless(params.ev_base, rspamd_config,
  1192. script.redis_params,
  1193. params.key, params.is_write, redis_cb, 'EVALSHA', redis_args) then
  1194. callback('Cannot make redis request', nil)
  1195. end
  1196. end
  1197. end
  1198. if script.loaded then
  1199. do_call(true)
  1200. else
  1201. -- Delayed until scripts are loaded
  1202. if not params.task then
  1203. table.insert(script.waitq, do_call)
  1204. else
  1205. -- TODO: fix taskfull requests
  1206. table.insert(script.waitq, function()
  1207. if script.loaded then
  1208. do_call(false)
  1209. else
  1210. callback('NOSCRIPT', nil)
  1211. end
  1212. end)
  1213. load_script_task(script, params.task, params.is_write)
  1214. end
  1215. end
  1216. return true
  1217. end
  1218. exports.exec_redis_script = exec_redis_script
  1219. local function redis_connect_sync(redis_params, is_write, key, cfg, ev_base)
  1220. if not redis_params then
  1221. return false,nil
  1222. end
  1223. local rspamd_redis = require "rspamd_redis"
  1224. local addr
  1225. if key then
  1226. if is_write then
  1227. addr = redis_params['write_servers']:get_upstream_by_hash(key)
  1228. else
  1229. addr = redis_params['read_servers']:get_upstream_by_hash(key)
  1230. end
  1231. else
  1232. if is_write then
  1233. addr = redis_params['write_servers']:get_upstream_master_slave(key)
  1234. else
  1235. addr = redis_params['read_servers']:get_upstream_round_robin(key)
  1236. end
  1237. end
  1238. if not addr then
  1239. logger.errx(cfg, 'cannot select server to make redis request')
  1240. end
  1241. local options = {
  1242. host = addr:get_addr(),
  1243. timeout = redis_params['timeout'],
  1244. config = cfg or rspamd_config,
  1245. ev_base = ev_base or rspamadm_ev_base,
  1246. session = redis_params.session or rspamadm_session
  1247. }
  1248. for k,v in pairs(redis_params) do
  1249. options[k] = v
  1250. end
  1251. if not options.config then
  1252. logger.errx('config is not set')
  1253. return false,nil,addr
  1254. end
  1255. if not options.ev_base then
  1256. logger.errx('ev_base is not set')
  1257. return false,nil,addr
  1258. end
  1259. if not options.session then
  1260. logger.errx('session is not set')
  1261. return false,nil,addr
  1262. end
  1263. local ret,conn = rspamd_redis.connect_sync(options)
  1264. if not ret then
  1265. logger.errx('cannot create redis connection: %s', conn)
  1266. addr:fail()
  1267. return false,nil,addr
  1268. end
  1269. if conn then
  1270. local need_exec = false
  1271. if redis_params['password'] then
  1272. conn:add_cmd('AUTH', {redis_params['password']})
  1273. need_exec = true
  1274. end
  1275. if redis_params['db'] then
  1276. conn:add_cmd('SELECT', {tostring(redis_params['db'])})
  1277. need_exec = true
  1278. elseif redis_params['dbname'] then
  1279. conn:add_cmd('SELECT', {tostring(redis_params['dbname'])})
  1280. need_exec = true
  1281. end
  1282. if need_exec then
  1283. local exec_ret, res = conn:exec()
  1284. if not exec_ret then
  1285. logger.errx('cannot prepare redis connection (authentication or db selection failure): %s',
  1286. res)
  1287. addr:fail()
  1288. return false,nil,addr
  1289. end
  1290. end
  1291. end
  1292. return ret,conn,addr
  1293. end
  1294. exports.redis_connect_sync = redis_connect_sync
  1295. --[[[
  1296. -- @function lua_redis.request(redis_params, attrs, req)
  1297. -- Sends a request to Redis synchronously with coroutines or asynchronously using
  1298. -- a callback (modern API)
  1299. -- @param redis_params a table of redis server parameters
  1300. -- @param attrs a table of redis request attributes (e.g. task, or ev_base + cfg + session)
  1301. -- @param req a table of request: a command + command options
  1302. -- @return {result,data/connection,address} boolean result, connection object in case of async request and results if using coroutines, redis server address
  1303. --]]
  1304. exports.request = function(redis_params, attrs, req)
  1305. local lua_util = require "lua_util"
  1306. if not attrs or not redis_params or not req then
  1307. logger.errx('invalid arguments for redis request')
  1308. return false,nil,nil
  1309. end
  1310. if not (attrs.task or (attrs.config and attrs.ev_base)) then
  1311. logger.errx('invalid attributes for redis request')
  1312. return false,nil,nil
  1313. end
  1314. local opts = lua_util.shallowcopy(attrs)
  1315. local log_obj = opts.task or opts.config
  1316. local addr
  1317. if opts.callback then
  1318. -- Wrap callback
  1319. local callback = opts.callback
  1320. local function rspamd_redis_make_request_cb(err, data)
  1321. if err then
  1322. addr:fail()
  1323. else
  1324. addr:ok()
  1325. end
  1326. callback(err, data, addr)
  1327. end
  1328. opts.callback = rspamd_redis_make_request_cb
  1329. end
  1330. local rspamd_redis = require "rspamd_redis"
  1331. local is_write = opts.is_write
  1332. if opts.key then
  1333. if is_write then
  1334. addr = redis_params['write_servers']:get_upstream_by_hash(attrs.key)
  1335. else
  1336. addr = redis_params['read_servers']:get_upstream_by_hash(attrs.key)
  1337. end
  1338. else
  1339. if is_write then
  1340. addr = redis_params['write_servers']:get_upstream_master_slave(attrs.key)
  1341. else
  1342. addr = redis_params['read_servers']:get_upstream_round_robin(attrs.key)
  1343. end
  1344. end
  1345. if not addr then
  1346. logger.errx(log_obj, 'cannot select server to make redis request')
  1347. end
  1348. opts.host = addr:get_addr()
  1349. opts.timeout = redis_params.timeout
  1350. if type(req) == 'string' then
  1351. opts.cmd = req
  1352. else
  1353. -- XXX: modifies the input table
  1354. opts.cmd = table.remove(req, 1);
  1355. opts.args = req
  1356. end
  1357. if redis_params.password then
  1358. opts.password = redis_params.password
  1359. end
  1360. if redis_params.db then
  1361. opts.dbname = redis_params.db
  1362. end
  1363. lutil.debugm(N, 'perform generic request to redis server' ..
  1364. ' (host=%s, timeout=%s): cmd: %s, arguments: %s', addr,
  1365. opts.timeout, opts.cmd, opts.args)
  1366. if opts.callback then
  1367. local ret,conn = rspamd_redis.make_request(opts)
  1368. if not ret then
  1369. logger.errx(log_obj, 'cannot execute redis request')
  1370. addr:fail()
  1371. end
  1372. return ret,conn,addr
  1373. else
  1374. -- Coroutines version
  1375. local ret,conn = rspamd_redis.connect_sync(opts)
  1376. if not ret then
  1377. logger.errx(log_obj, 'cannot execute redis request')
  1378. addr:fail()
  1379. else
  1380. conn:add_cmd(opts.cmd, opts.args)
  1381. return conn:exec()
  1382. end
  1383. return false,nil,addr
  1384. end
  1385. end
  1386. --[[[
  1387. -- @function lua_redis.connect(redis_params, attrs)
  1388. -- Connects to Redis synchronously with coroutines or asynchronously using a callback (modern API)
  1389. -- @param redis_params a table of redis server parameters
  1390. -- @param attrs a table of redis request attributes (e.g. task, or ev_base + cfg + session)
  1391. -- @return {result,connection,address} boolean result, connection object, redis server address
  1392. --]]
  1393. exports.connect = function(redis_params, attrs)
  1394. local lua_util = require "lua_util"
  1395. if not attrs or not redis_params then
  1396. logger.errx('invalid arguments for redis connect')
  1397. return false,nil,nil
  1398. end
  1399. if not (attrs.task or (attrs.config and attrs.ev_base)) then
  1400. logger.errx('invalid attributes for redis connect')
  1401. return false,nil,nil
  1402. end
  1403. local opts = lua_util.shallowcopy(attrs)
  1404. local log_obj = opts.task or opts.config
  1405. local addr
  1406. if opts.callback then
  1407. -- Wrap callback
  1408. local callback = opts.callback
  1409. local function rspamd_redis_make_request_cb(err, data)
  1410. if err then
  1411. addr:fail()
  1412. else
  1413. addr:ok()
  1414. end
  1415. callback(err, data, addr)
  1416. end
  1417. opts.callback = rspamd_redis_make_request_cb
  1418. end
  1419. local rspamd_redis = require "rspamd_redis"
  1420. local is_write = opts.is_write
  1421. if opts.key then
  1422. if is_write then
  1423. addr = redis_params['write_servers']:get_upstream_by_hash(attrs.key)
  1424. else
  1425. addr = redis_params['read_servers']:get_upstream_by_hash(attrs.key)
  1426. end
  1427. else
  1428. if is_write then
  1429. addr = redis_params['write_servers']:get_upstream_master_slave(attrs.key)
  1430. else
  1431. addr = redis_params['read_servers']:get_upstream_round_robin(attrs.key)
  1432. end
  1433. end
  1434. if not addr then
  1435. logger.errx(log_obj, 'cannot select server to make redis connect')
  1436. end
  1437. opts.host = addr:get_addr()
  1438. opts.timeout = redis_params.timeout
  1439. if redis_params.password then
  1440. opts.password = redis_params.password
  1441. end
  1442. if redis_params.db then
  1443. opts.dbname = redis_params.db
  1444. end
  1445. if opts.callback then
  1446. local ret,conn = rspamd_redis.connect(opts)
  1447. if not ret then
  1448. logger.errx(log_obj, 'cannot execute redis connect')
  1449. addr:fail()
  1450. end
  1451. return ret,conn,addr
  1452. else
  1453. -- Coroutines version
  1454. local ret,conn = rspamd_redis.connect_sync(opts)
  1455. if not ret then
  1456. logger.errx(log_obj, 'cannot execute redis connect')
  1457. addr:fail()
  1458. else
  1459. return true,conn,addr
  1460. end
  1461. return false,nil,addr
  1462. end
  1463. end
  1464. local redis_prefixes = {}
  1465. --[[[
  1466. -- @function lua_redis.register_prefix(prefix, module, description[, optional])
  1467. -- Register new redis prefix for documentation purposes
  1468. -- @param {string} prefix string prefix
  1469. -- @param {string} module module name
  1470. -- @param {string} description prefix description
  1471. -- @param {table} optional optional kv pairs (e.g. pattern)
  1472. --]]
  1473. local function register_prefix(prefix, module, description, optional)
  1474. local pr = {
  1475. module = module,
  1476. description = description
  1477. }
  1478. if optional and type(optional) == 'table' then
  1479. for k,v in pairs(optional) do
  1480. pr[k] = v
  1481. end
  1482. end
  1483. redis_prefixes[prefix] = pr
  1484. end
  1485. exports.register_prefix = register_prefix
  1486. --[[[
  1487. -- @function lua_redis.prefixes([mname])
  1488. -- Returns prefixes for specific module (or all prefixes). Returns a table prefix -> table
  1489. --]]
  1490. exports.prefixes = function(mname)
  1491. if not mname then
  1492. return redis_prefixes
  1493. else
  1494. local fun = require "fun"
  1495. return fun.totable(fun.filter(function(_, data) return data.module == mname end,
  1496. redis_prefixes))
  1497. end
  1498. end
  1499. return exports