You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

lua_redis.lua 51KB


  1. --[[
  2. Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ]]--
  13. local logger = require "rspamd_logger"
  14. local lutil = require "lua_util"
  15. local rspamd_util = require "rspamd_util"
  16. local ts = require("tableshape").types
  local exports = {}
  -- Shared empty table, used as a nil-safe fallback when chaining lookups
  -- such as `(x or E)[1]`
  local E = {}
  -- Tag used for this module's debug logging
  local N = "lua_redis"

  -- Database selector: accepted as a string or as a number (a config line such
  -- as `db = 1;` may arrive as a number). Numbers are converted to strings,
  -- which is also the form process_redis_opts stores (via tostring()).
  local db_number_schema = (ts.string + ts.number / tostring):is_optional():describe("Database number")

  -- Options shared by every Redis configuration layout
  local common_schema = {
    timeout = (ts.number + ts.string / lutil.parse_time_interval):is_optional():describe("Connection timeout"),
    db = db_number_schema,
    database = db_number_schema,
    dbname = db_number_schema,
    prefix = ts.string:is_optional():describe("Key prefix"),
    username = ts.string:is_optional():describe("Username"),
    password = ts.string:is_optional():describe("Password"),
    expand_keys = ts.boolean:is_optional():describe("Expand keys"),
    sentinels = (ts.string + ts.array_of(ts.string)):is_optional():describe("Sentinel servers"),
    sentinel_watch_time = (ts.number + ts.string / lutil.parse_time_interval):is_optional():describe("Sentinel watch time"),
    sentinel_masters_pattern = ts.string:is_optional():describe("Sentinel masters pattern"),
    sentinel_master_maxerrors = (ts.number + ts.string / tonumber):is_optional():describe("Sentinel master max errors"),
    sentinel_username = ts.string:is_optional():describe("Sentinel username"),
    sentinel_password = ts.string:is_optional():describe("Sentinel password"),
  }

  -- A server list: a single string or an array of strings
  local server_list_type = ts.string + ts.array_of(ts.string)

  -- Merges the layout-specific server `fields` with the common options
  local function with_common(fields)
    return lutil.table_merge(fields, common_schema)
  end

  -- Only dedicated read servers
  local read_schema = with_common({ read_servers = server_list_type })
  -- Only dedicated write servers
  local write_schema = with_common({ write_servers = server_list_type })
  -- Separate read and write server lists
  local rw_schema = with_common({
    read_servers = server_list_type,
    write_servers = server_list_type,
  })
  -- One list used for both reads and writes
  local servers_schema = with_common({ servers = server_list_type })
  -- Legacy single `server` attribute
  local server_schema = with_common({ server = server_list_type })
  --- Extends a module's own option schema so that it also accepts the Redis
  -- connection options in any of the supported server layouts.
  -- @param external table of tableshape field types specific to the caller
  -- @return a `ts.one_of` type matching either the bare module options or the
  --         module options combined with one of the Redis layouts
  local enrich_schema = function(external)
    -- Without any Redis-specific parameters
    local variants = { ts.shape(external) }
    -- read_servers / write_servers / both / servers / legacy server
    local layouts = { read_schema, write_schema, rw_schema, servers_schema, server_schema }
    for _, layout in ipairs(layouts) do
      variants[#variants + 1] = ts.shape(lutil.table_merge(layout, external))
    end
    return ts.one_of(variants)
  end
  exports.enrich_schema = enrich_schema
  --- Asks one Redis Sentinel (chosen round-robin from `params.sentinels`) for the
  -- current masters and their replicas, then replaces `params.read_servers` /
  -- `params.write_servers` whenever the resulting address lists differ from the
  -- ones stored in `params.read_servers_str` / `params.write_servers_str`.
  -- @param ev_base event base passed to the asynchronous Sentinel requests
  -- @param params Redis params; `params.sentinels` must already be an upstream list
  -- @param initialised not used by this function (kept for existing callers)
  local function redis_query_sentinel(ev_base, params, initialised)
    local rspamd_redis = require "rspamd_redis"
    local upstream_list = require "rspamd_upstream_list"

    -- Converts a flat Redis reply {k1, v1, k2, v2, ...} into {k1 = v1, k2 = v2, ...}
    local function flatten_redis_table(tbl)
      local res = {}
      for i = 1, #tbl, 2 do
        res[tbl[i]] = tbl[i + 1]
      end
      return res
    end

    -- Wraps IPv6 addresses in brackets so that 'ip:port' strings stay parseable
    local function wrap_ipv6(ip)
      if ip:match(":") then
        return "[" .. ip .. "]"
      end
      return ip
    end

    local sentinels = params.sentinels
    local addr = sentinels:get_upstream_round_robin()
    local host = addr:get_addr()
    local masters = {}

    -- Builds the new server strings from `masters` (master addresses for writes;
    -- masters plus replicas whose link status is 'ok' for reads) and swaps in
    -- new upstream lists for every string that changed
    local function process_masters()
      local read_servers_tbl, write_servers_tbl = {}, {}
      for _, master in pairs(masters) do
        local master_addr = string.format('%s:%s', master.ip, master.port)
        write_servers_tbl[#write_servers_tbl + 1] = master_addr
        read_servers_tbl[#read_servers_tbl + 1] = master_addr
        for _, slave in ipairs(master.slaves) do
          if slave['master-link-status'] == 'ok' then
            read_servers_tbl[#read_servers_tbl + 1] = string.format('%s:%s', slave.ip, slave.port)
          end
        end
      end
      -- Sorting makes the strings comparable with the previously stored ones
      table.sort(read_servers_tbl)
      table.sort(write_servers_tbl)
      local read_servers_str = table.concat(read_servers_tbl, ',')
      local write_servers_str = table.concat(write_servers_tbl, ',')
      lutil.debugm(N, rspamd_config,
          'new servers list: %s read; %s write',
          read_servers_str,
          write_servers_str)

      if read_servers_str ~= params.read_servers_str then
        local read_upstreams = upstream_list.create(rspamd_config, read_servers_str, 6379)
        if read_upstreams then
          logger.infox(rspamd_config, 'sentinel %s: replace read servers with new list: %s',
              host:to_string(true), read_servers_str)
          params.read_servers = read_upstreams
          params.read_servers_str = read_servers_str
        end
      end

      if write_servers_str ~= params.write_servers_str then
        local write_upstreams = upstream_list.create(rspamd_config, write_servers_str, 6379)
        if write_upstreams then
          logger.infox(rspamd_config, 'sentinel %s: replace write servers with new list: %s',
              host:to_string(true), write_servers_str)
          params.write_servers = write_upstreams
          params.write_servers_str = write_servers_str
          local queried = false
          -- Re-queries the sentinel once a master accumulates too many failures
          local function monitor_failures(up, _, count)
            -- Tolerate a limit given as a string; 2 matches add_redis_sentinels' default
            local max_errors = tonumber(params.sentinel_master_maxerrors) or 2
            if count > max_errors and not queried then
              -- Report the failing master itself rather than the sentinel host
              local up_addr = up and up:get_addr()
              local failed = up_addr and up_addr:to_string(true) or host:to_string(true)
              logger.infox(rspamd_config, 'sentinel: master with address %s, caused %s failures, try to query sentinel',
                  failed, count)
              queried = true -- Avoid multiple checks caused by this monitor
              redis_query_sentinel(ev_base, params, true)
            end
          end
          write_upstreams:add_watcher('failure', monitor_failures)
        end
      end
    end

    local function masters_cb(err, result)
      if err or type(result) ~= 'table' then
        logger.errx('cannot get masters data from Redis Sentinel %s: %s',
            host:to_string(true), err)
        addr:fail()
        return
      end

      for _, m in ipairs(result) do
        local master = flatten_redis_table(m)
        master.ip = wrap_ipv6(master.ip)
        local pattern = params.sentinel_masters_pattern
        if pattern and not master.name:match(pattern) then
          lutil.debugm(N, 'skip master %s with ip %s and port %s, pattern %s',
              master.name, master.ip, master.port, pattern)
        else
          lutil.debugm(N, 'found master %s with ip %s and port %s',
              master.name, master.ip, master.port)
          masters[master.name] = master
        end
      end

      -- For each master we need to get a list of replicas; process_masters runs
      -- when the last outstanding reply arrives
      local pending_subrequests = 0
      for name, master in pairs(masters) do
        master.slaves = {}
        local function slaves_cb(slave_err, slave_result)
          if not slave_err and type(slave_result) == 'table' then
            for _, s in ipairs(slave_result) do
              local slave = flatten_redis_table(s)
              lutil.debugm(N, rspamd_config,
                  'found slave for master %s with ip %s and port %s',
                  master.name, slave.ip, slave.port)
              slave.ip = wrap_ipv6(slave.ip)
              master.slaves[#master.slaves + 1] = slave
            end
          else
            logger.errx('cannot get slaves data from Redis Sentinel %s: %s',
                host:to_string(true), slave_err)
            addr:fail()
          end
          pending_subrequests = pending_subrequests - 1
          if pending_subrequests == 0 then
            process_masters()
          end
        end
        local ret = rspamd_redis.make_request {
          host = host,
          timeout = params.timeout,
          username = params.sentinel_username,
          password = params.sentinel_password,
          config = rspamd_config,
          ev_base = ev_base,
          cmd = 'SENTINEL',
          args = { 'slaves', name },
          no_pool = true,
          callback = slaves_cb
        }
        if not ret then
          logger.errx(rspamd_config, 'cannot connect sentinel when query slaves at address: %s',
              host:to_string(true))
          addr:fail()
        else
          pending_subrequests = pending_subrequests + 1
        end
      end
      addr:ok()

      -- If no replica query could be started, no slaves_cb will ever run; finalize
      -- with the masters alone instead of silently dropping them
      if pending_subrequests == 0 and next(masters) ~= nil then
        process_masters()
      end
    end

    local ret = rspamd_redis.make_request {
      host = host,
      timeout = params.timeout,
      config = rspamd_config,
      ev_base = ev_base,
      username = params.sentinel_username,
      password = params.sentinel_password,
      cmd = 'SENTINEL',
      args = { 'masters' },
      no_pool = true,
      callback = masters_cb,
    }
    if not ret then
      logger.errx(rspamd_config, 'cannot connect sentinel at address: %s',
          host:to_string(true))
      addr:fail()
    end
  end
  --- Turns the configured sentinel list into an upstream list and schedules
  -- periodic sentinel queries in scanner and fuzzy workers.
  -- @param params Redis params whose `sentinels` field holds the configured
  --        sentinel servers; it is replaced by the upstream list on success
  local function add_redis_sentinels(params)
    local upstream_list = require "rspamd_upstream_list"
    local sentinel_upstreams = upstream_list.create(rspamd_config,
        params.sentinels, 5000)

    if not sentinel_upstreams then
      logger.errx(rspamd_config, 'cannot load redis sentinels string: %s',
          params.sentinels)
      return
    end

    params.sentinels = sentinel_upstreams
    -- Defaults: query every minute; recheck after 2 master errors
    params.sentinel_watch_time = params.sentinel_watch_time or 60
    params.sentinel_master_maxerrors = params.sentinel_master_maxerrors or 2

    rspamd_config:add_on_load(function(_, ev_base, worker)
      local initialised = false
      local needs_sentinel = worker:is_scanner() or worker:get_type() == 'fuzzy'
      if not needs_sentinel then
        return
      end
      -- The returned value is used as the delay until the next run
      rspamd_config:add_periodic(ev_base, 0.0, function()
        redis_query_sentinel(ev_base, params, initialised)
        initialised = true
        return params.sentinel_watch_time
      end, false)
    end)
  end
  -- Params tables already handed out, keyed by calculate_redis_hash() of their
  -- settings, so identical configurations share a single params table
  local cached_results = {}

  --- Computes a digest over every string, number and boolean value in `params`
  -- (recursing into nested tables) to recognise identical Redis configurations.
  -- Keys are visited in sorted order (by tostring), so the digest does not depend
  -- on `pairs` iteration order. Each value is bound to its full key path and type.
  -- @param params table of Redis settings
  -- @return base32 digest string
  local function calculate_redis_hash(params)
    local cr = require "rspamd_cryptobox_hash"
    local h = cr.create()

    -- Length-prefix every chunk so that adjacent strings cannot run together
    local function feed(s)
      h:update(string.format('%d:%s', #s, s))
    end

    local function rec_hash(path, v)
      local vt = type(v)
      if vt == 'string' or vt == 'number' or vt == 'boolean' then
        feed(path)
        feed(vt)
        feed(tostring(v))
      elseif vt == 'table' then
        local keys = {}
        for k in pairs(v) do
          keys[#keys + 1] = k
        end
        table.sort(keys, function(a, b)
          return tostring(a) < tostring(b)
        end)
        for _, k in ipairs(keys) do
          rec_hash(path .. '.' .. tostring(k), v[k])
        end
      end
      -- Other types (e.g. userdata such as upstream lists) are not hashed
    end

    rec_hash('top', params)
    return h:base32()
  end
  --- Copies connection options from one configuration section (`options`) into
  -- `redis_params`. Callers in this file invoke it for the most specific section
  -- first and then for broader ones. Most fields are therefore filled only while
  -- they are still unset; timeout and expand_keys may still be replaced while
  -- they hold their default value.
  -- @param options configuration section to read from
  -- @param redis_params params table that is updated in place
  local function process_redis_opts(options, redis_params)
    local default_timeout = 1.0
    local default_expand_keys = false

    -- Converts a number, numeric string or time-interval string (parsed by
    -- lutil.parse_time_interval) to seconds; nil when it cannot be parsed
    local function to_seconds(value)
      local seconds = tonumber(value)
      if not seconds and type(value) == 'string' then
        seconds = lutil.parse_time_interval(value)
      end
      return seconds
    end

    if not redis_params['timeout'] or redis_params['timeout'] == default_timeout then
      if options['timeout'] then
        redis_params['timeout'] = to_seconds(options['timeout']) or default_timeout
      else
        redis_params['timeout'] = default_timeout
      end
    end

    if options['prefix'] and not redis_params['prefix'] then
      redis_params['prefix'] = options['prefix']
    end

    -- An explicit boolean wins only while nothing but the default is set, and a
    -- section without the option never resets a value chosen earlier
    if type(options['expand_keys']) == 'boolean' then
      if redis_params['expand_keys'] == nil or redis_params['expand_keys'] == default_expand_keys then
        redis_params['expand_keys'] = options['expand_keys']
      end
    elseif redis_params['expand_keys'] == nil then
      redis_params['expand_keys'] = default_expand_keys
    end

    if not redis_params['db'] then
      local db = options['db'] or options['dbname'] or options['database']
      if db then
        redis_params['db'] = tostring(db)
      end
    end

    if options['username'] and not redis_params['username'] then
      redis_params['username'] = options['username']
    end

    if options['password'] and not redis_params['password'] then
      redis_params['password'] = options['password']
    end

    if not redis_params.sentinels and options.sentinels then
      redis_params.sentinels = options.sentinels
    end

    if options['sentinel_masters_pattern'] and not redis_params['sentinel_masters_pattern'] then
      redis_params['sentinel_masters_pattern'] = options['sentinel_masters_pattern']
    end

    -- Sentinel settings read by add_redis_sentinels / redis_query_sentinel;
    -- normalised to numbers because they are used in arithmetic comparisons
    if options['sentinel_watch_time'] and not redis_params['sentinel_watch_time'] then
      redis_params['sentinel_watch_time'] = to_seconds(options['sentinel_watch_time'])
    end

    if options['sentinel_master_maxerrors'] and not redis_params['sentinel_master_maxerrors'] then
      redis_params['sentinel_master_maxerrors'] = tonumber(options['sentinel_master_maxerrors'])
    end

    if options['sentinel_username'] and not redis_params['sentinel_username'] then
      redis_params['sentinel_username'] = options['sentinel_username']
    end

    if options['sentinel_password'] and not redis_params['sentinel_password'] then
      redis_params['sentinel_password'] = options['sentinel_password']
    end
  end
  --- Fills `redis_params` from the global `redis` configuration section: first
  -- from its `module` subsection (when present), then from the section itself.
  -- @param rspamd_config configuration object, may be nil (then nothing happens)
  -- @param module module name used to look up a subsection, may be nil
  -- @param redis_params params table updated in place by process_redis_opts
  local function enrich_defaults(rspamd_config, module, redis_params)
    if not rspamd_config then
      return
    end

    local opts = rspamd_config:get_all_opt('redis')
    if not opts then
      return
    end

    local module_opts = module and opts[module]
    if module_opts then
      process_redis_opts(module_opts, redis_params)
    end
    process_redis_opts(opts, redis_params)
  end
  --- Returns an already-registered params table with identical settings, or
  -- registers `redis_params` (tagging it with its hash) and returns it.
  -- Sentinel monitoring is set up only for newly registered, writable params.
  -- @param redis_params freshly parsed Redis params
  -- @return the params table to use
  local function maybe_return_cached(redis_params)
    local h = calculate_redis_hash(redis_params)
    local cached = cached_results[h]
    if cached then
      lutil.debugm(N, 'reused redis server: %s', redis_params)
      return cached
    end

    redis_params.hash = h
    cached_results[h] = redis_params

    local wants_sentinels = redis_params.sentinels and not redis_params.read_only
    if wants_sentinels then
      add_redis_sentinels(redis_params)
    end

    lutil.debugm(N, 'loaded new redis server: %s', redis_params)
    return redis_params
  end
  347. --[[[
  348. -- @module lua_redis
  349. -- This module contains helper functions for working with Redis
  350. --]]
  --- Builds read/write upstream lists from one configuration section.
  -- Server keys, in order of preference: `read_servers` (read-only unless
  -- `write_servers` is also present), then `servers`, then the legacy `server`
  -- (both used for reads and writes). The remaining options are copied with
  -- process_redis_opts, which runs even when no servers are found.
  -- @param options configuration section
  -- @param rspamd_config configuration object, may be nil
  -- @param result params table filled in place
  -- @return true when read servers were loaded, false otherwise
  local function process_redis_options(options, rspamd_config, result)
    local default_port = 6379
    local upstream_list = require "rspamd_upstream_list"

    -- Creates an upstream list, passing the config object only when available
    local function make_upstreams(definition)
      if rspamd_config then
        return upstream_list.create(rspamd_config, definition, default_port)
      end
      return upstream_list.create(definition, default_port)
    end

    local read_only = true
    local upstreams_read, upstreams_write

    local read_def = options['read_servers']
    if read_def then
      upstreams_read = make_upstreams(read_def)
      result.read_servers_str = read_def
    else
      read_def = options['servers'] or options['server']
      if read_def then
        upstreams_read = make_upstreams(read_def)
        result.read_servers_str = read_def
        read_only = false
      end
    end

    if upstreams_read then
      local write_def = options['write_servers']
      if write_def then
        upstreams_write = make_upstreams(write_def)
        result.write_servers_str = write_def
        read_only = false
      elseif not read_only then
        -- `servers` / `server` serve writes as well
        upstreams_write = upstreams_read
        result.write_servers_str = result.read_servers_str
      end
    end

    -- Store options
    process_redis_opts(options, result)

    if read_only and not upstreams_write then
      result.read_only = true
    elseif upstreams_write then
      result.read_only = false
    end

    if not upstreams_read then
      lutil.debugm(N, rspamd_config,
          'cannot load redis server from obj: %s, processed to %s',
          options, result)
      return false
    end

    result.read_servers = upstreams_read
    if upstreams_write then
      result.write_servers = upstreams_write
    end
    return true
  end
  --[[[
  @function try_load_redis_servers(options, rspamd_config, no_fallback, module_name)
  Tries to load redis servers from the specified `options` object.
  Unless `no_fallback` is set, missing settings are then filled in from the
  global `redis` section (and its `module_name` subsection, when given).
  Returns `redis_params` table or nil in case of failure
  --]]
  exports.try_load_redis_servers = function(options, rspamd_config, no_fallback, module_name)
    local result = {}
    if not process_redis_options(options, rspamd_config, result) then
      return
    end

    if not no_fallback then
      enrich_defaults(rspamd_config, module_name, result)
    end
    return maybe_return_cached(result)
  end
  -- Resolves the Redis settings for a module: first from the module's own
  -- options (their `redis` subsection, then the options themselves), then -
  -- unless fallback is disabled - from the global `redis` section.
  local function rspamd_parse_redis_server(module_name, module_opts, no_fallback)
    -- Shared by every attempt below: fields copied in by an unsuccessful
    -- attempt remain visible to the later ones
    local result = {}

    -- Tries to load servers from `section`; on success optionally fills in
    -- defaults from the global `redis` section and returns the (cached) params
    local function load_from(section)
      if not process_redis_options(section, rspamd_config, result) then
        return nil
      end
      if not no_fallback then
        enrich_defaults(rspamd_config, module_name, result)
      end
      return maybe_return_cached(result)
    end

    lutil.debugm(N, rspamd_config, 'try load redis config for: %s', module_name)
    local opts = module_opts or rspamd_config:get_all_opt(module_name)

    if opts then
      local params = opts.redis and load_from(opts.redis)
      if params then
        return params
      end
      params = load_from(opts)
      if params then
        return params
      end
    end

    if no_fallback then
      logger.infox(rspamd_config, "cannot find Redis definitions for %s and fallback is disabled",
          module_name)
      return nil
    end

    -- Try global options
    local global_opts = rspamd_config:get_all_opt('redis')
    if global_opts then
      local module_section = global_opts[module_name]
      if module_section then
        -- A per-module subsection is used as-is, without enrichment
        if process_redis_options(module_section, rspamd_config, result) then
          return maybe_return_cached(result)
        end
      else
        local loaded = process_redis_options(global_opts, rspamd_config, result)
        -- Exclude disabled
        local disabled = global_opts['disabled_modules']
        if disabled then
          for _, name in ipairs(disabled) do
            if name == module_name then
              logger.infox(rspamd_config, "NOT using default redis server for module %s: it is disabled",
                  module_name)
              return nil
            end
          end
        end
        if loaded then
          logger.infox(rspamd_config, "use default Redis settings for %s",
              module_name)
          return maybe_return_cached(result)
        end
      end
    end

    -- NOTE(review): every successful load above already returned, so this looks
    -- unreachable in practice; kept as a safety net
    if result.read_servers then
      return maybe_return_cached(result)
    end
    return nil
  end

  --[[[
  -- @function lua_redis.parse_redis_server(module_name, module_opts, no_fallback)
  -- Extracts Redis server settings from configuration
  -- @param {string} module_name name of module to get settings for
  -- @param {table} module_opts settings for module or `nil` to fetch them from configuration
  -- @param {boolean} no_fallback should be `true` if global settings must not be used
  -- @return {table} redis server settings, or nil when none could be loaded
  -- @example
  -- local rconfig = lua_redis.parse_redis_server('my_module')
  -- -- rconfig contains upstream_list objects in ['write_servers'] and ['read_servers']
  -- -- ['timeout'] contains timeout in seconds
  -- -- ['expand_keys'] if true tells that redis key expansion is enabled
  --]]
  exports.rspamd_parse_redis_server = rspamd_parse_redis_server
  exports.parse_redis_server = rspamd_parse_redis_server
  -- Returns {first, first + 1, ..., last}; empty when last < first
  local function key_range(first, last)
    local idx_l = {}
    for i = first, last do
      idx_l[#idx_l + 1] = i
    end
    return idx_l
  end

  -- Appends the key indexes of a `CMD x numkeys key [key ...] ...` command
  -- (keys start at index 3; never past the end of args) to `idx_l`
  local function append_numkeys_keys(args, idx_l)
    local numkeys = tonumber(args[2])
    if numkeys and numkeys >= 1 then
      for i = 3, math.min(numkeys + 2, #args) do
        idx_l[#idx_l + 1] = i
      end
    end
    return idx_l
  end

  --- Per-command extractors: each takes the argument list of a Redis command
  -- (command name excluded) and returns the indexes of the arguments that are
  -- keys. Every extractor returns a table, possibly empty.
  local process_cmd = {
    -- BITOP operation destkey key [key ...]
    bitop = function(args)
      return key_range(2, #args)
    end,
    -- BLPOP key [key ...] timeout
    blpop = function(args)
      return key_range(1, #args - 1)
    end,
    -- EVAL script numkeys key [key ...] arg [arg ...]
    eval = function(args)
      return append_numkeys_keys(args, {})
    end,
    -- ZINTERSTORE destination numkeys key [key ...] [WEIGHTS ...] [AGGREGATE ...]
    zinterstore = function(args)
      if #args == 0 then
        return {}
      end
      return append_numkeys_keys(args, { 1 })
    end,
    -- CMD key ...: only the first argument is a key
    set = function(args)
      return key_range(1, math.min(1, #args))
    end,
    -- CMD key [key ...]: every argument is a key
    mget = function(args)
      return key_range(1, #args)
    end,
    -- MSET key value [key value ...]
    mset = function(args)
      local idx_l = {}
      for i = 1, #args, 2 do
        idx_l[#idx_l + 1] = i
      end
      return idx_l
    end,
    -- SMOVE source destination member
    smove = function(args)
      return key_range(1, math.min(2, #args))
    end,
    -- Commands without key arguments
    script = function()
      return {}
    end,
  }

  -- Commands whose only key is the first argument
  local single_key_cmds = {
    'append', 'bitcount', 'bitfield', 'bitpos', 'decr', 'decrby', 'dump',
    'expire', 'expireat', 'geoadd', 'geodist', 'geohash', 'geopos',
    'georadius', 'georadiusbymember', 'get', 'getbit', 'getrange', 'getset',
    'hdel', 'hexists', 'hget', 'hgetall', 'hincrby', 'hincrbyfloat', 'hkeys',
    'hlen', 'hmget', 'hmset', 'hscan', 'hset', 'hsetnx', 'hstrlen', 'hvals',
    'incr', 'incrby', 'incrbyfloat', 'lindex', 'linsert', 'llen', 'lpop',
    'lpush', 'lpushx', 'lrange', 'lrem', 'lset', 'ltrim', 'move', 'persist',
    'pexpire', 'pexpireat', 'pfadd', 'psetex', 'pttl', 'restore', 'rpop',
    'rpush', 'rpushx', 'sadd', 'scard', 'setbit', 'setex', 'setnx',
    'setrange', 'sismember', 'smembers', 'sort', 'spop', 'srandmember', 'srem',
    'sscan', 'strlen', 'ttl', 'type', 'zadd', 'zcard', 'zcount', 'zincrby',
    'zlexcount', 'zrange', 'zrangebylex', 'zrangebyscore', 'zrank', 'zrem',
    'zremrangebylex', 'zremrangebyrank', 'zremrangebyscore',
    -- historical (non-Redis) names kept for compatibility
    'zrembylex', 'zrembyrank', 'zrembyscore',
    'zrevrange', 'zrevrangebylex', 'zrevrangebyscore', 'zrevrank', 'zscan',
    'zscore',
  }
  -- Commands whose arguments are all keys
  local all_keys_cmds = {
    'del', 'exists', 'pfcount', 'pfmerge', 'rename', 'renamenx', 'rpoplpush',
    'sdiff', 'sdiffstore', 'sinter', 'sinterstore', 'sunion', 'sunionstore',
    'touch', 'unlink', 'watch',
  }
  -- Commands that take no keys
  local no_keys_cmds = {
    'auth', 'bgrewriteaof', 'bgsave', 'client', 'cluster', 'command', 'config',
    'dbsize', 'debug', 'discard', 'echo', 'exec', 'flushall', 'flushdb', 'info',
    'keys', 'lastsave', 'migrate', 'monitor', 'multi', 'object', 'ping',
    'psubscribe', 'publish', 'pubsub', 'punsubscribe', 'quit', 'randomkey',
    'readonly', 'readwrite', 'role', 'save', 'scan', 'select', 'slaveof',
    'slowlog', 'subscribe', 'swapdb', 'sync', 'time', 'unsubscribe', 'unwatch',
    'wait',
  }

  for _, name in ipairs(single_key_cmds) do
    process_cmd[name] = process_cmd.set
  end
  for _, name in ipairs(all_keys_cmds) do
    process_cmd[name] = process_cmd.mget
  end
  for _, name in ipairs(no_keys_cmds) do
    process_cmd[name] = process_cmd.script
  end
  process_cmd.brpop = process_cmd.blpop
  process_cmd.brpoplpush = process_cmd.blpop
  process_cmd.evalsha = process_cmd.eval
  process_cmd.msetnx = process_cmd.mset
  process_cmd.zunionstore = process_cmd.zinterstore
  --- Returns the indexes (into `args`) of the key arguments of Redis command
  -- `cmd`, so that those arguments can be template-expanded.
  -- Always returns a table (possibly empty) so callers can iterate it directly.
  -- @param cmd command name, any case
  -- @param args argument list (may be nil)
  local function get_key_indexes(cmd, args)
    cmd = string.lower(cmd)
    local extractor = process_cmd[cmd]
    if not extractor then
      logger.warnx(rspamd_config, "Don't know how to extract keys for %s Redis command", cmd)
      return {}
    end
    -- Guard against extractors that return nothing for keyless commands
    return extractor(args or {}) or {}
  end
  -- Returns the first sender of the given kind ('smtp' or 'mime') reported by
  -- task:get_from(), or the shared empty table E when there is none
  local function first_sender(task, source)
    local senders = task:get_from(source) or E
    return senders[1] or E
  end

  --- Generators for the metadata values usable during Redis key expansion.
  -- Each one takes a task and returns the value, or nil when unavailable.
  local gen_meta = {}

  function gen_meta.principal_recipient(task)
    return task:get_principal_recipient()
  end

  function gen_meta.principal_recipient_domain(task)
    local rcpt = task:get_principal_recipient()
    if rcpt then
      return string.match(rcpt, '.*@(.*)')
    end
  end

  function gen_meta.ip(task)
    local addr = task:get_ip()
    if addr and addr:is_valid() then
      return addr:to_string()
    end
  end

  function gen_meta.from(task)
    return first_sender(task, 'smtp')['addr']
  end

  function gen_meta.from_domain(task)
    return first_sender(task, 'smtp')['domain']
  end

  -- Envelope sender domain, or the HELO name when the domain is missing/empty
  function gen_meta.from_domain_or_helo_domain(task)
    local domain = first_sender(task, 'smtp')['domain']
    if not domain or #domain == 0 then
      return task:get_helo()
    end
    return domain
  end

  function gen_meta.mime_from(task)
    return first_sender(task, 'mime')['addr']
  end

  function gen_meta.mime_from_domain(task)
    return first_sender(task, 'mime')['domain']
  end

  -- Wraps generator `f` so that it yields rspamd_util.get_tld() of its value
  local function gen_get_esld(f)
    return function(task)
      local domain = f(task)
      if domain then
        return rspamd_util.get_tld(domain)
      end
    end
  end

  -- Aliases
  gen_meta.smtp_from = gen_meta.from
  gen_meta.smtp_from_domain = gen_meta.from_domain
  gen_meta.smtp_from_domain_or_helo_domain = gen_meta.from_domain_or_helo_domain
  gen_meta.esld_principal_recipient_domain = gen_get_esld(gen_meta.principal_recipient_domain)
  gen_meta.esld_from_domain = gen_get_esld(gen_meta.from_domain)
  gen_meta.esld_smtp_from_domain = gen_meta.esld_from_domain
  gen_meta.esld_mime_from_domain = gen_get_esld(gen_meta.mime_from_domain)
  gen_meta.esld_from_domain_or_helo_domain = gen_get_esld(gen_meta.from_domain_or_helo_domain)
  gen_meta.esld_smtp_from_domain_or_helo_domain = gen_meta.esld_from_domain_or_helo_domain
  --- Creates a lazily-populated metadata table for key expansion.
  -- Lookups are case-insensitive. A missing entry is computed by the matching
  -- gen_meta generator on first access and memoised in the table.
  -- @param task task passed to the generators
  -- @return metadata table
  local function get_key_expansion_metadata(task)
    return setmetatable({}, {
      __index = function(self, key)
        local lkey = string.lower(key)
        local known = rawget(self, lkey)
        if known then
          return known
        end
        local generator = gen_meta[lkey]
        if not generator then
          return known
        end
        local value = generator(task)
        rawset(self, lkey, value)
        return value
      end,
    })
  end
  812. -- Performs async call to redis hiding all complexity inside function
  813. -- task - rspamd_task
  814. -- redis_params - valid params returned by rspamd_parse_redis_server
  815. -- key - key to select upstream or nil to select round-robin/master-slave
  816. -- is_write - true if need to write to redis server
  817. -- callback - function to be called upon request is completed
  818. -- command - redis command
  819. -- args - table of arguments
  820. -- extra_opts - table of optional request arguments
  821. local function rspamd_redis_make_request(task, redis_params, key, is_write,
  822. callback, command, args, extra_opts)
  823. local addr
  824. local function rspamd_redis_make_request_cb(err, data)
  825. if err then
  826. addr:fail()
  827. else
  828. addr:ok()
  829. end
  830. if callback then
  831. callback(err, data, addr)
  832. end
  833. end
  834. if not task or not redis_params or not command then
  835. return false, nil, nil
  836. end
  837. local rspamd_redis = require "rspamd_redis"
  838. if key then
  839. if is_write then
  840. addr = redis_params['write_servers']:get_upstream_by_hash(key)
  841. else
  842. addr = redis_params['read_servers']:get_upstream_by_hash(key)
  843. end
  844. else
  845. if is_write then
  846. addr = redis_params['write_servers']:get_upstream_master_slave(key)
  847. else
  848. addr = redis_params['read_servers']:get_upstream_round_robin(key)
  849. end
  850. end
  851. if not addr then
  852. logger.errx(task, 'cannot select server to make redis request')
  853. end
  854. if redis_params['expand_keys'] then
  855. local m = get_key_expansion_metadata(task)
  856. local indexes = get_key_indexes(command, args)
  857. for _, i in ipairs(indexes) do
  858. args[i] = lutil.template(args[i], m)
  859. end
  860. end
  861. local ip_addr = addr:get_addr()
  862. local options = {
  863. task = task,
  864. callback = rspamd_redis_make_request_cb,
  865. host = ip_addr,
  866. timeout = redis_params['timeout'],
  867. cmd = command,
  868. args = args
  869. }
  870. if extra_opts then
  871. for k, v in pairs(extra_opts) do
  872. options[k] = v
  873. end
  874. end
  875. if redis_params['username'] then
  876. options['username'] = redis_params['username']
  877. end
  878. if redis_params['password'] then
  879. options['password'] = redis_params['password']
  880. end
  881. if redis_params['db'] then
  882. options['dbname'] = redis_params['db']
  883. end
  884. lutil.debugm(N, task, 'perform request to redis server' ..
  885. ' (host=%s, timeout=%s): cmd: %s', ip_addr,
  886. options.timeout, options.cmd)
  887. local ret, conn = rspamd_redis.make_request(options)
  888. if not ret then
  889. addr:fail()
  890. logger.warnx(task, "cannot make redis request to: %s", tostring(ip_addr))
  891. end
  892. return ret, conn, addr
  893. end
  894. --[[[
  895. -- @function lua_redis.redis_make_request(task, redis_params, key, is_write, callback, command, args)
  896. -- Sends a request to Redis
  897. -- @param {rspamd_task} task task object
  898. -- @param {table} redis_params redis configuration in format returned by lua_redis.parse_redis_server()
  899. -- @param {string} key key to use for sharding
  900. -- @param {boolean} is_write should be `true` if we are performing a write operating
  901. -- @param {function} callback callback function (first parameter is error if applicable, second is a 2D array (table))
  902. -- @param {string} command Redis command to run
  903. -- @param {table} args Numerically indexed table containing arguments for command
  904. --]]
  905. exports.rspamd_redis_make_request = rspamd_redis_make_request
  906. exports.redis_make_request = rspamd_redis_make_request
  907. local function redis_make_request_taskless(ev_base, cfg, redis_params, key,
  908. is_write, callback, command, args, extra_opts)
  909. if not ev_base or not redis_params or not command then
  910. return false, nil, nil
  911. end
  912. local addr
  913. local function rspamd_redis_make_request_cb(err, data)
  914. if err then
  915. addr:fail()
  916. else
  917. addr:ok()
  918. end
  919. if callback then
  920. callback(err, data, addr)
  921. end
  922. end
  923. local rspamd_redis = require "rspamd_redis"
  924. if key then
  925. if is_write then
  926. addr = redis_params['write_servers']:get_upstream_by_hash(key)
  927. else
  928. addr = redis_params['read_servers']:get_upstream_by_hash(key)
  929. end
  930. else
  931. if is_write then
  932. addr = redis_params['write_servers']:get_upstream_master_slave(key)
  933. else
  934. addr = redis_params['read_servers']:get_upstream_round_robin(key)
  935. end
  936. end
  937. if not addr then
  938. logger.errx(cfg, 'cannot select server to make redis request')
  939. end
  940. local options = {
  941. ev_base = ev_base,
  942. config = cfg,
  943. callback = rspamd_redis_make_request_cb,
  944. host = addr:get_addr(),
  945. timeout = redis_params['timeout'],
  946. cmd = command,
  947. args = args
  948. }
  949. if extra_opts then
  950. for k, v in pairs(extra_opts) do
  951. options[k] = v
  952. end
  953. end
  954. if redis_params['username'] then
  955. options['username'] = redis_params['username']
  956. end
  957. if redis_params['password'] then
  958. options['password'] = redis_params['password']
  959. end
  960. if redis_params['db'] then
  961. options['dbname'] = redis_params['db']
  962. end
  963. lutil.debugm(N, cfg, 'perform taskless request to redis server' ..
  964. ' (host=%s, timeout=%s): cmd: %s', options.host:tostring(true),
  965. options.timeout, options.cmd)
  966. local ret, conn = rspamd_redis.make_request(options)
  967. if not ret then
  968. logger.errx('cannot execute redis request')
  969. addr:fail()
  970. end
  971. return ret, conn, addr
  972. end
  973. --[[[
  974. -- @function lua_redis.redis_make_request_taskless(ev_base, redis_params, key, is_write, callback, command, args)
  975. -- Sends a request to Redis in context where `task` is not available for some specific use-cases
  976. -- Identical to redis_make_request() except in that first parameter is an `event base` object
  977. --]]
  978. exports.rspamd_redis_make_request_taskless = redis_make_request_taskless
  979. exports.redis_make_request_taskless = redis_make_request_taskless
  980. local redis_scripts = {
  981. }
  982. local function script_set_loaded(script)
  983. if script.sha then
  984. script.loaded = true
  985. end
  986. local wait_table = {}
  987. for _, s in ipairs(script.waitq) do
  988. table.insert(wait_table, s)
  989. end
  990. script.waitq = {}
  991. for _, s in ipairs(wait_table) do
  992. s(script.loaded)
  993. end
  994. end
  995. local function prepare_redis_call(script)
  996. local servers = {}
  997. local options = {}
  998. if script.redis_params.read_servers then
  999. servers = lutil.table_merge(servers, script.redis_params.read_servers:all_upstreams())
  1000. end
  1001. if script.redis_params.write_servers then
  1002. servers = lutil.table_merge(servers, script.redis_params.write_servers:all_upstreams())
  1003. end
  1004. -- Call load script on each server, set loaded flag
  1005. script.in_flight = #servers
  1006. for _, s in ipairs(servers) do
  1007. local cur_opts = {
  1008. host = s:get_addr(),
  1009. timeout = script.redis_params['timeout'],
  1010. cmd = 'SCRIPT',
  1011. args = { 'LOAD', script.script },
  1012. upstream = s
  1013. }
  1014. if script.redis_params['username'] then
  1015. cur_opts['username'] = script.redis_params['username']
  1016. end
  1017. if script.redis_params['password'] then
  1018. cur_opts['password'] = script.redis_params['password']
  1019. end
  1020. if script.redis_params['db'] then
  1021. cur_opts['dbname'] = script.redis_params['db']
  1022. end
  1023. table.insert(options, cur_opts)
  1024. end
  1025. return options
  1026. end
  1027. local function load_script_task(script, task, is_write)
  1028. local rspamd_redis = require "rspamd_redis"
  1029. local opts = prepare_redis_call(script)
  1030. for _, opt in ipairs(opts) do
  1031. opt.task = task
  1032. opt.is_write = is_write
  1033. opt.callback = function(err, data)
  1034. if err then
  1035. logger.errx(task, 'cannot upload script to %s: %s; registered from: %s:%s',
  1036. opt.upstream:get_addr():to_string(true),
  1037. err, script.caller.short_src, script.caller.currentline)
  1038. opt.upstream:fail()
  1039. script.fatal_error = err
  1040. else
  1041. opt.upstream:ok()
  1042. logger.infox(task,
  1043. "uploaded redis script to %s with id %s, sha: %s",
  1044. opt.upstream:get_addr():to_string(true),
  1045. script.id, data)
  1046. script.sha = data -- We assume that sha is the same on all servers
  1047. end
  1048. script.in_flight = script.in_flight - 1
  1049. if script.in_flight == 0 then
  1050. script_set_loaded(script)
  1051. end
  1052. end
  1053. local ret = rspamd_redis.make_request(opt)
  1054. if not ret then
  1055. logger.errx('cannot execute redis request to load script on %s',
  1056. opt.upstream:get_addr())
  1057. script.in_flight = script.in_flight - 1
  1058. opt.upstream:fail()
  1059. end
  1060. if script.in_flight == 0 then
  1061. script_set_loaded(script)
  1062. end
  1063. end
  1064. end
  1065. local function load_script_taskless(script, cfg, ev_base, is_write)
  1066. local rspamd_redis = require "rspamd_redis"
  1067. local opts = prepare_redis_call(script)
  1068. for _, opt in ipairs(opts) do
  1069. opt.config = cfg
  1070. opt.ev_base = ev_base
  1071. opt.is_write = is_write
  1072. opt.callback = function(err, data)
  1073. if err then
  1074. logger.errx(cfg, 'cannot upload script to %s: %s; registered from: %s:%s',
  1075. opt.upstream:get_addr():to_string(true),
  1076. err, script.caller.short_src, script.caller.currentline)
  1077. opt.upstream:fail()
  1078. script.fatal_error = err
  1079. else
  1080. opt.upstream:ok()
  1081. logger.infox(cfg,
  1082. "uploaded redis script to %s with id %s, sha: %s",
  1083. opt.upstream:get_addr():to_string(true), script.id, data)
  1084. script.sha = data -- We assume that sha is the same on all servers
  1085. script.fatal_error = nil
  1086. end
  1087. script.in_flight = script.in_flight - 1
  1088. if script.in_flight == 0 then
  1089. script_set_loaded(script)
  1090. end
  1091. end
  1092. local ret = rspamd_redis.make_request(opt)
  1093. if not ret then
  1094. logger.errx('cannot execute redis request to load script on %s',
  1095. opt.upstream:get_addr())
  1096. script.in_flight = script.in_flight - 1
  1097. opt.upstream:fail()
  1098. end
  1099. if script.in_flight == 0 then
  1100. script_set_loaded(script)
  1101. end
  1102. end
  1103. end
  1104. local function load_redis_script(script, cfg, ev_base, _)
  1105. if script.redis_params then
  1106. load_script_taskless(script, cfg, ev_base)
  1107. end
  1108. end
  1109. local function add_redis_script(script, redis_params, caller_level)
  1110. if not caller_level then
  1111. caller_level = 2
  1112. end
  1113. local caller = debug.getinfo(caller_level)
  1114. local new_script = {
  1115. caller = caller,
  1116. loaded = false,
  1117. redis_params = redis_params,
  1118. script = script,
  1119. waitq = {}, -- callbacks pending for script being loaded
  1120. id = #redis_scripts + 1
  1121. }
  1122. -- Register on load function
  1123. rspamd_config:add_on_load(function(cfg, ev_base, worker)
  1124. local mult = 0.0
  1125. rspamd_config:add_periodic(ev_base, 0.0, function()
  1126. if not new_script.sha then
  1127. load_redis_script(new_script, cfg, ev_base, worker)
  1128. mult = mult + 1
  1129. return 1.0 * mult -- Check one more time in one second
  1130. end
  1131. return false
  1132. end, false)
  1133. end)
  1134. table.insert(redis_scripts, new_script)
  1135. return #redis_scripts
  1136. end
  1137. exports.add_redis_script = add_redis_script
  1138. -- Loads a Redis script from a file, strips comments, and passes the content to
  1139. -- `add_redis_script` function.
  1140. --
  1141. -- @param filename The name of the file containing the Redis script.
  1142. -- @param redis_params The Redis parameters to use for this script.
  1143. -- @return The ID of the newly added Redis script.
  1144. --
  1145. local function load_redis_script_from_file(filename, redis_params, dir)
  1146. local lua_util = require "lua_util"
  1147. local rspamd_logger = require "rspamd_logger"
  1148. if not dir then
  1149. dir = rspamd_paths.LUALIBDIR
  1150. end
  1151. if filename:sub(1, 1) ~= package.config:sub(1, 1) then
  1152. -- Relative path
  1153. filename = lua_util.join_path(dir, "redis_scripts", filename)
  1154. end
  1155. -- Read file contents
  1156. local file = io.open(filename, "r")
  1157. if not file then
  1158. rspamd_logger.errx("failed to open Redis script file: %s", filename)
  1159. return nil
  1160. end
  1161. local script = file:read("*all")
  1162. if not script then
  1163. rspamd_logger.errx("failed to load Redis script file: %s", filename)
  1164. return nil
  1165. end
  1166. file:close()
  1167. script = lua_util.strip_lua_comments(script)
  1168. return add_redis_script(script, redis_params, 3)
  1169. end
  1170. exports.load_redis_script_from_file = load_redis_script_from_file
  1171. local function exec_redis_script(id, params, callback, keys, args)
  1172. local redis_args = {}
  1173. if not redis_scripts[id] then
  1174. logger.errx("cannot find registered script with id %s", id)
  1175. return false
  1176. end
  1177. local script = redis_scripts[id]
  1178. if script.fatal_error then
  1179. callback(script.fatal_error, nil)
  1180. return true
  1181. end
  1182. if not script.redis_params then
  1183. callback('no redis servers defined', nil)
  1184. return true
  1185. end
  1186. local function do_call(can_reload)
  1187. local function redis_cb(err, data)
  1188. if not err then
  1189. callback(err, data)
  1190. elseif string.match(err, 'NOSCRIPT') then
  1191. -- Schedule restart
  1192. script.sha = nil
  1193. if can_reload then
  1194. table.insert(script.waitq, do_call)
  1195. if script.in_flight == 0 then
  1196. -- Reload scripts if this has not been initiated yet
  1197. if params.task then
  1198. load_script_task(script, params.task)
  1199. else
  1200. load_script_taskless(script, rspamd_config, params.ev_base)
  1201. end
  1202. end
  1203. else
  1204. callback(err, data)
  1205. end
  1206. else
  1207. callback(err, data)
  1208. end
  1209. end
  1210. if #redis_args == 0 then
  1211. table.insert(redis_args, script.sha)
  1212. table.insert(redis_args, tostring(#keys))
  1213. for _, k in ipairs(keys) do
  1214. table.insert(redis_args, k)
  1215. end
  1216. if type(args) == 'table' then
  1217. for _, a in ipairs(args) do
  1218. table.insert(redis_args, a)
  1219. end
  1220. end
  1221. end
  1222. if params.task then
  1223. if not rspamd_redis_make_request(params.task, script.redis_params,
  1224. params.key, params.is_write, redis_cb, 'EVALSHA', redis_args) then
  1225. callback('Cannot make redis request', nil)
  1226. end
  1227. else
  1228. if not redis_make_request_taskless(params.ev_base, rspamd_config,
  1229. script.redis_params,
  1230. params.key, params.is_write, redis_cb, 'EVALSHA', redis_args) then
  1231. callback('Cannot make redis request', nil)
  1232. end
  1233. end
  1234. end
  1235. if script.loaded then
  1236. do_call(true)
  1237. else
  1238. -- Delayed until scripts are loaded
  1239. if not params.task then
  1240. table.insert(script.waitq, do_call)
  1241. else
  1242. -- TODO: fix taskfull requests
  1243. table.insert(script.waitq, function()
  1244. if script.loaded then
  1245. do_call(false)
  1246. else
  1247. callback('NOSCRIPT', nil)
  1248. end
  1249. end)
  1250. load_script_task(script, params.task, params.is_write)
  1251. end
  1252. end
  1253. return true
  1254. end
  1255. exports.exec_redis_script = exec_redis_script
  1256. local function redis_connect_sync(redis_params, is_write, key, cfg, ev_base)
  1257. if not redis_params then
  1258. return false, nil
  1259. end
  1260. local rspamd_redis = require "rspamd_redis"
  1261. local addr
  1262. if key then
  1263. if is_write then
  1264. addr = redis_params['write_servers']:get_upstream_by_hash(key)
  1265. else
  1266. addr = redis_params['read_servers']:get_upstream_by_hash(key)
  1267. end
  1268. else
  1269. if is_write then
  1270. addr = redis_params['write_servers']:get_upstream_master_slave(key)
  1271. else
  1272. addr = redis_params['read_servers']:get_upstream_round_robin(key)
  1273. end
  1274. end
  1275. if not addr then
  1276. logger.errx(cfg, 'cannot select server to make redis request')
  1277. end
  1278. local options = {
  1279. host = addr:get_addr(),
  1280. timeout = redis_params['timeout'],
  1281. config = cfg or rspamd_config,
  1282. ev_base = ev_base or rspamadm_ev_base,
  1283. session = redis_params.session or rspamadm_session
  1284. }
  1285. for k, v in pairs(redis_params) do
  1286. options[k] = v
  1287. end
  1288. if not options.config then
  1289. logger.errx('config is not set')
  1290. return false, nil, addr
  1291. end
  1292. if not options.ev_base then
  1293. logger.errx('ev_base is not set')
  1294. return false, nil, addr
  1295. end
  1296. if not options.session then
  1297. logger.errx('session is not set')
  1298. return false, nil, addr
  1299. end
  1300. local ret, conn = rspamd_redis.connect_sync(options)
  1301. if not ret then
  1302. logger.errx('cannot create redis connection: %s', conn)
  1303. addr:fail()
  1304. return false, nil, addr
  1305. end
  1306. if conn then
  1307. local need_exec = false
  1308. if redis_params['username'] then
  1309. if redis_params['password'] then
  1310. conn:add_cmd('AUTH', { redis_params['username'], redis_params['password'] })
  1311. need_exec = true
  1312. else
  1313. logger.warnx('Redis requires a password when username is supplied')
  1314. return false, nil, addr
  1315. end
  1316. elseif redis_params['password'] then
  1317. conn:add_cmd('AUTH', { redis_params['password'] })
  1318. need_exec = true
  1319. end
  1320. if redis_params['db'] then
  1321. conn:add_cmd('SELECT', { tostring(redis_params['db']) })
  1322. need_exec = true
  1323. elseif redis_params['dbname'] then
  1324. conn:add_cmd('SELECT', { tostring(redis_params['dbname']) })
  1325. need_exec = true
  1326. end
  1327. if need_exec then
  1328. local exec_ret, res = conn:exec()
  1329. if not exec_ret then
  1330. logger.errx('cannot prepare redis connection (authentication or db selection failure): %s',
  1331. res)
  1332. addr:fail()
  1333. return false, nil, addr
  1334. end
  1335. end
  1336. end
  1337. return ret, conn, addr
  1338. end
  1339. exports.redis_connect_sync = redis_connect_sync
  1340. --[[[
  1341. -- @function lua_redis.request(redis_params, attrs, req)
  1342. -- Sends a request to Redis synchronously with coroutines or asynchronously using
  1343. -- a callback (modern API)
  1344. -- @param redis_params a table of redis server parameters
  1345. -- @param attrs a table of redis request attributes (e.g. task, or ev_base + cfg + session)
  1346. -- @param req a table of request: a command + command options
  1347. -- @return {result,data/connection,address} boolean result, connection object in case of async request and results if using coroutines, redis server address
  1348. --]]
  1349. exports.request = function(redis_params, attrs, req)
  1350. local lua_util = require "lua_util"
  1351. if not attrs or not redis_params or not req then
  1352. logger.errx('invalid arguments for redis request')
  1353. return false, nil, nil
  1354. end
  1355. if not (attrs.task or (attrs.config and attrs.ev_base)) then
  1356. logger.errx('invalid attributes for redis request')
  1357. return false, nil, nil
  1358. end
  1359. local opts = lua_util.shallowcopy(attrs)
  1360. local log_obj = opts.task or opts.config
  1361. local addr
  1362. if opts.callback then
  1363. -- Wrap callback
  1364. local callback = opts.callback
  1365. local function rspamd_redis_make_request_cb(err, data)
  1366. if err then
  1367. addr:fail()
  1368. else
  1369. addr:ok()
  1370. end
  1371. callback(err, data, addr)
  1372. end
  1373. opts.callback = rspamd_redis_make_request_cb
  1374. end
  1375. local rspamd_redis = require "rspamd_redis"
  1376. local is_write = opts.is_write
  1377. if opts.key then
  1378. if is_write then
  1379. addr = redis_params['write_servers']:get_upstream_by_hash(attrs.key)
  1380. else
  1381. addr = redis_params['read_servers']:get_upstream_by_hash(attrs.key)
  1382. end
  1383. else
  1384. if is_write then
  1385. addr = redis_params['write_servers']:get_upstream_master_slave(attrs.key)
  1386. else
  1387. addr = redis_params['read_servers']:get_upstream_round_robin(attrs.key)
  1388. end
  1389. end
  1390. if not addr then
  1391. logger.errx(log_obj, 'cannot select server to make redis request')
  1392. end
  1393. opts.host = addr:get_addr()
  1394. opts.timeout = redis_params.timeout
  1395. if type(req) == 'string' then
  1396. opts.cmd = req
  1397. else
  1398. -- XXX: modifies the input table
  1399. opts.cmd = table.remove(req, 1);
  1400. opts.args = req
  1401. end
  1402. if redis_params.username then
  1403. opts.username = redis_params.username
  1404. end
  1405. if redis_params.password then
  1406. opts.password = redis_params.password
  1407. end
  1408. if redis_params.db then
  1409. opts.dbname = redis_params.db
  1410. end
  1411. lutil.debugm(N, 'perform generic request to redis server' ..
  1412. ' (host=%s, timeout=%s): cmd: %s, arguments: %s', addr,
  1413. opts.timeout, opts.cmd, opts.args)
  1414. if opts.callback then
  1415. local ret, conn = rspamd_redis.make_request(opts)
  1416. if not ret then
  1417. logger.errx(log_obj, 'cannot execute redis request')
  1418. addr:fail()
  1419. end
  1420. return ret, conn, addr
  1421. else
  1422. -- Coroutines version
  1423. local ret, conn = rspamd_redis.connect_sync(opts)
  1424. if not ret then
  1425. logger.errx(log_obj, 'cannot execute redis request')
  1426. addr:fail()
  1427. else
  1428. conn:add_cmd(opts.cmd, opts.args)
  1429. return conn:exec()
  1430. end
  1431. return false, nil, addr
  1432. end
  1433. end
  1434. --[[[
  1435. -- @function lua_redis.connect(redis_params, attrs)
  1436. -- Connects to Redis synchronously with coroutines or asynchronously using a callback (modern API)
  1437. -- @param redis_params a table of redis server parameters
  1438. -- @param attrs a table of redis request attributes (e.g. task, or ev_base + cfg + session)
  1439. -- @return {result,connection,address} boolean result, connection object, redis server address
  1440. --]]
  1441. exports.connect = function(redis_params, attrs)
  1442. local lua_util = require "lua_util"
  1443. if not attrs or not redis_params then
  1444. logger.errx('invalid arguments for redis connect')
  1445. return false, nil, nil
  1446. end
  1447. if not (attrs.task or (attrs.config and attrs.ev_base)) then
  1448. logger.errx('invalid attributes for redis connect')
  1449. return false, nil, nil
  1450. end
  1451. local opts = lua_util.shallowcopy(attrs)
  1452. local log_obj = opts.task or opts.config
  1453. local addr
  1454. if opts.callback then
  1455. -- Wrap callback
  1456. local callback = opts.callback
  1457. local function rspamd_redis_make_request_cb(err, data)
  1458. if err then
  1459. addr:fail()
  1460. else
  1461. addr:ok()
  1462. end
  1463. callback(err, data, addr)
  1464. end
  1465. opts.callback = rspamd_redis_make_request_cb
  1466. end
  1467. local rspamd_redis = require "rspamd_redis"
  1468. local is_write = opts.is_write
  1469. if opts.key then
  1470. if is_write then
  1471. addr = redis_params['write_servers']:get_upstream_by_hash(attrs.key)
  1472. else
  1473. addr = redis_params['read_servers']:get_upstream_by_hash(attrs.key)
  1474. end
  1475. else
  1476. if is_write then
  1477. addr = redis_params['write_servers']:get_upstream_master_slave(attrs.key)
  1478. else
  1479. addr = redis_params['read_servers']:get_upstream_round_robin(attrs.key)
  1480. end
  1481. end
  1482. if not addr then
  1483. logger.errx(log_obj, 'cannot select server to make redis connect')
  1484. end
  1485. opts.host = addr:get_addr()
  1486. opts.timeout = redis_params.timeout
  1487. if redis_params.username then
  1488. opts.username = redis_params.username
  1489. end
  1490. if redis_params.password then
  1491. opts.password = redis_params.password
  1492. end
  1493. if redis_params.db then
  1494. opts.dbname = redis_params.db
  1495. end
  1496. if opts.callback then
  1497. local ret, conn = rspamd_redis.connect(opts)
  1498. if not ret then
  1499. logger.errx(log_obj, 'cannot execute redis connect')
  1500. addr:fail()
  1501. end
  1502. return ret, conn, addr
  1503. else
  1504. -- Coroutines version
  1505. local ret, conn = rspamd_redis.connect_sync(opts)
  1506. if not ret then
  1507. logger.errx(log_obj, 'cannot execute redis connect')
  1508. addr:fail()
  1509. else
  1510. return true, conn, addr
  1511. end
  1512. return false, nil, addr
  1513. end
  1514. end
  1515. local redis_prefixes = {}
  1516. --[[[
  1517. -- @function lua_redis.register_prefix(prefix, module, description[, optional])
  1518. -- Register new redis prefix for documentation purposes
  1519. -- @param {string} prefix string prefix
  1520. -- @param {string} module module name
  1521. -- @param {string} description prefix description
  1522. -- @param {table} optional optional kv pairs (e.g. pattern)
  1523. --]]
  1524. local function register_prefix(prefix, module, description, optional)
  1525. local pr = {
  1526. module = module,
  1527. description = description
  1528. }
  1529. if optional and type(optional) == 'table' then
  1530. for k, v in pairs(optional) do
  1531. pr[k] = v
  1532. end
  1533. end
  1534. redis_prefixes[prefix] = pr
  1535. end
  1536. exports.register_prefix = register_prefix
  1537. --[[[
  1538. -- @function lua_redis.prefixes([mname])
  1539. -- Returns prefixes for specific module (or all prefixes). Returns a table prefix -> table
  1540. --]]
  1541. exports.prefixes = function(mname)
  1542. if not mname then
  1543. return redis_prefixes
  1544. else
  1545. local fun = require "fun"
  1546. return fun.totable(fun.filter(function(_, data)
  1547. return data.module == mname
  1548. end,
  1549. redis_prefixes))
  1550. end
  1551. end
  1552. return exports