You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

lua_redis.lua 52KB


  1. --[[
  2. Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ]]--
  13. local logger = require "rspamd_logger"
  14. local lutil = require "lua_util"
  15. local rspamd_util = require "rspamd_util"
  16. local ts = require("tableshape").types
  local exports = {}
  local E = {}
  local N = "lua_redis"

  -- Builds a fresh `string | {string}` type, used by every server-list option
  local function servers_list_type()
    return ts.string + ts.array_of(ts.string)
  end

  -- Builds a fresh `number | string` type whose string form is converted
  -- with lutil.parse_time_interval
  local function interval_type()
    return ts.number + ts.string / lutil.parse_time_interval
  end

  -- Options accepted by every Redis definition, whatever the server layout
  local common_schema = {
    timeout = interval_type():is_optional():describe("Connection timeout"),
    db = ts.string:is_optional():describe("Database number"),
    database = ts.string:is_optional():describe("Database number"),
    dbname = ts.string:is_optional():describe("Database number"),
    prefix = ts.string:is_optional():describe("Key prefix"),
    username = ts.string:is_optional():describe("Username"),
    password = ts.string:is_optional():describe("Password"),
    expand_keys = ts.boolean:is_optional():describe("Expand keys"),
    sentinels = servers_list_type():is_optional():describe("Sentinel servers"),
    sentinel_watch_time = interval_type():is_optional():describe("Sentinel watch time"),
    sentinel_masters_pattern = ts.string:is_optional():describe("Sentinel masters pattern"),
    sentinel_master_maxerrors = (ts.number + ts.string / tonumber):is_optional():describe("Sentinel master max errors"),
    sentinel_username = ts.string:is_optional():describe("Sentinel username"),
    sentinel_password = ts.string:is_optional():describe("Sentinel password"),
  }

  -- Combine a layout-specific set of server options with the common ones
  local function with_common(extra)
    return lutil.table_merge(extra, common_schema)
  end

  -- One schema per supported way of declaring servers
  local read_schema = with_common({ read_servers = servers_list_type() })
  local write_schema = with_common({ write_servers = servers_list_type() })
  local rw_schema = with_common({
    read_servers = servers_list_type(),
    write_servers = servers_list_type(),
  })
  local servers_schema = with_common({ servers = servers_list_type() })
  local server_schema = with_common({ server = servers_list_type() })
  --- Wrap an external options schema so it also accepts Redis definitions.
  -- The result is a `ts.one_of` over six shapes: the bare external schema
  -- (no specific redis parameters), then the external schema merged with
  -- the read-only, write-only, read+write, `servers` and legacy `server`
  -- layouts, in that order.
  -- @param external table of tableshape types describing the caller's own options
  -- @return tableshape `one_of` type
  local enrich_schema = function(external)
    local variants = { ts.shape(external) }
    local layouts = { read_schema, write_schema, rw_schema, servers_schema, server_schema }
    for _, layout in ipairs(layouts) do
      variants[#variants + 1] = ts.shape(lutil.table_merge(layout, external))
    end
    return ts.one_of(variants)
  end
  exports.enrich_schema = enrich_schema
  --- Query Redis Sentinel for the masters/slaves topology and refresh the
  -- read/write upstream lists stored in `params`.
  --
  -- Flow: `SENTINEL masters`, then `SENTINEL slaves <name>` for every master
  -- accepted by `sentinel_masters_pattern`. Once every issued slaves request
  -- has replied, `process_masters` builds sorted `host:port` lists. It swaps
  -- them into `params` when they differ from `read_servers_str` /
  -- `write_servers_str`. A failure watcher on a new write list re-runs this
  -- query after more than `sentinel_master_maxerrors` failures.
  -- @param ev_base event base used for the requests
  -- @param params redis params; `params.sentinels` must be an upstream list
  -- @param initialised not used by this function
  local function redis_query_sentinel(ev_base, params, initialised)
    -- Turn a flat reply { k1, v1, k2, v2, ... } into { [k1] = v1, ... }
    local function pairs_to_map(flat)
      local map = {}
      local idx, n = 1, #flat
      while idx <= n do
        map[flat[idx]] = flat[idx + 1]
        idx = idx + 2
      end
      return map
    end

    -- Wrap IPv6-addresses in brackets so that a port can be appended
    local function bracket_ipv6(ip)
      if ip:match(":") then
        return "[" .. ip .. "]"
      end
      return ip
    end

    local function host_port(srv)
      return string.format('%s:%s', srv.ip, srv.port)
    end

    local rspamd_redis = require "rspamd_redis"
    local sentinel_up = params.sentinels:get_upstream_round_robin()
    local sentinel_addr = sentinel_up:get_addr()
    local masters = {}
    local process_masters -- assigned below; called once all slave replies are in

    -- Send a SENTINEL subcommand to the selected sentinel
    local function sentinel_request(args, cb)
      return rspamd_redis.make_request {
        host = sentinel_up:get_addr(),
        timeout = params.timeout,
        username = params.sentinel_username,
        password = params.sentinel_password,
        config = rspamd_config,
        ev_base = ev_base,
        cmd = 'SENTINEL',
        args = args,
        no_pool = true,
        callback = cb,
      }
    end

    -- Remember a master unless sentinel_masters_pattern rejects its name
    local function accept_master(master)
      local pattern = params.sentinel_masters_pattern
      if pattern and not master.name:match(pattern) then
        lutil.debugm(N, 'skip master %s with ip %s and port %s, pattern %s',
            master.name, master.ip, master.port, pattern)
        return
      end
      lutil.debugm(N, 'found master %s with ip %s and port %s',
          master.name, master.ip, master.port)
      masters[master.name] = master
    end

    local function masters_cb(err, result)
      if err or not result or type(result) ~= 'table' then
        logger.errx('cannot get masters data from Redis Sentinel %s: %s',
            sentinel_addr:to_string(true), err)
        sentinel_up:fail()
        return
      end

      for _, raw in ipairs(result) do
        local master = pairs_to_map(raw)
        master.ip = bracket_ipv6(master.ip)
        accept_master(master)
      end

      -- For each master we need to get a list of slaves
      local pending = 0
      for name, master in pairs(masters) do
        master.slaves = {}

        local function slaves_cb(slave_err, slave_result)
          if slave_err or type(slave_result) ~= 'table' then
            logger.errx('cannot get slaves data from Redis Sentinel %s: %s',
                sentinel_addr:to_string(true), slave_err)
            sentinel_up:fail()
          else
            for _, raw in ipairs(slave_result) do
              local slave = pairs_to_map(raw)
              lutil.debugm(N, rspamd_config,
                  'found slave for master %s with ip %s and port %s',
                  master.name, slave.ip, slave.port)
              slave.ip = bracket_ipv6(slave.ip)
              master.slaves[#master.slaves + 1] = slave
            end
          end
          pending = pending - 1
          if pending == 0 then
            -- Finalize masters and slaves
            process_masters()
          end
        end

        if sentinel_request({ 'slaves', name }, slaves_cb) then
          pending = pending + 1
        else
          logger.errx(rspamd_config, 'cannot connect sentinel when query slaves at address: %s',
              sentinel_addr:to_string(true))
          sentinel_up:fail()
        end
      end
      sentinel_up:ok()
    end

    if not sentinel_request({ 'masters' }, masters_cb) then
      logger.errx(rspamd_config, 'cannot connect sentinel at address: %s',
          sentinel_addr:to_string(true))
      sentinel_up:fail()
    end

    -- Rebuild the server strings from `masters` and replace the upstream
    -- lists in `params` when they changed
    process_masters = function()
      local read_list, write_list = {}, {}
      for _, master in pairs(masters) do
        local master_hp = host_port(master)
        write_list[#write_list + 1] = master_hp
        read_list[#read_list + 1] = master_hp
        -- Only slaves whose master-link-status is 'ok' are used for reads
        for _, slave in ipairs(master.slaves) do
          if slave['master-link-status'] == 'ok' then
            read_list[#read_list + 1] = host_port(slave)
          end
        end
      end
      -- Sorted so the joined strings can be compared with the current ones
      table.sort(read_list)
      table.sort(write_list)
      local read_str = table.concat(read_list, ',')
      local write_str = table.concat(write_list, ',')
      lutil.debugm(N, rspamd_config,
          'new servers list: %s read; %s write',
          read_str,
          write_str)

      if read_str ~= params.read_servers_str then
        local upstream_list = require "rspamd_upstream_list"
        local read_upstreams = upstream_list.create(rspamd_config, read_str, 6379)
        if read_upstreams then
          logger.infox(rspamd_config, 'sentinel %s: replace read servers with new list: %s',
              sentinel_addr:to_string(true), read_str)
          params.read_servers = read_upstreams
          params.read_servers_str = read_str
        end
      end

      if write_str ~= params.write_servers_str then
        local upstream_list = require "rspamd_upstream_list"
        local write_upstreams = upstream_list.create(rspamd_config, write_str, 6379)
        if write_upstreams then
          logger.infox(rspamd_config, 'sentinel %s: replace write servers with new list: %s',
              sentinel_addr:to_string(true), write_str)
          params.write_servers = write_upstreams
          params.write_servers_str = write_str
          -- Avoid multiple checks caused by this monitor
          local queried = false
          write_upstreams:add_watcher('failure', function(_, _, count)
            if count > params.sentinel_master_maxerrors and not queried then
              logger.infox(rspamd_config, 'sentinel: master with address %s, caused %s failures, try to query sentinel',
                  sentinel_addr:to_string(true), count)
              queried = true
              redis_query_sentinel(ev_base, params, true)
            end
          end)
        end
      end
    end
  end
  --- Replace `params.sentinels` with an upstream list and schedule periodic
  -- sentinel queries on scanner and fuzzy workers.
  -- Unset `sentinel_watch_time` defaults to 60 (each minute) and unset
  -- `sentinel_master_maxerrors` defaults to 2.
  -- @param params redis params table (modified in place)
  local function add_redis_sentinels(params)
    local upstream_list = require "rspamd_upstream_list"
    local sentinel_ups = upstream_list.create(rspamd_config, params.sentinels, 5000)

    if not sentinel_ups then
      logger.errx(rspamd_config, 'cannot load redis sentinels string: %s',
          params.sentinels)
      return
    end

    params.sentinels = sentinel_ups
    params.sentinel_watch_time = params.sentinel_watch_time or 60
    -- Maximum number of errors before rechecking
    params.sentinel_master_maxerrors = params.sentinel_master_maxerrors or 2

    rspamd_config:add_on_load(function(_, ev_base, worker)
      local wants_sentinel = worker:is_scanner() or worker:get_type() == 'fuzzy'
      if not wants_sentinel then
        return
      end
      local initialised = false
      rspamd_config:add_periodic(ev_base, 0.0, function()
        redis_query_sentinel(ev_base, params, initialised)
        initialised = true
        -- The returned value is used as the interval until the next run
        return params.sentinel_watch_time
      end, false)
    end)
  end
  -- Redis params tables already built, keyed by calculate_redis_hash()
  local cached_results = {}

  --- Compute the cache key for a redis params table.
  -- Strings, numbers and booleans are hashed together with their full key
  -- path and type. Tables are walked recursively in sorted key order. Every
  -- component is length-prefixed, so equal settings produce equal keys
  -- regardless of insertion order, and distinct settings do not alias.
  -- Other value types (e.g. upstream list userdata) are ignored.
  -- @param params table of redis parameters
  -- @return base32 digest string
  local function calculate_redis_hash(params)
    local cr = require "rspamd_cryptobox_hash"
    local h = cr.create()
    local in_progress = {} -- guards against reference cycles

    -- Total order for heterogeneous keys: group by type, then by value
    local function key_less(a, b)
      local ta, tb = type(a), type(b)
      if ta ~= tb then
        return ta < tb
      end
      if ta == 'string' or ta == 'number' then
        return a < b
      end
      return tostring(a) < tostring(b)
    end

    local function rec_hash(path, v)
      local tv = type(v)
      if tv == 'string' or tv == 'number' or tv == 'boolean' then
        -- Booleans matter: e.g. `expand_keys` must not be shared between
        -- otherwise identical definitions
        local sv = tostring(v)
        h:update(string.format('%d:%s%s%d:%s', #path, path, tv:sub(1, 1), #sv, sv))
      elseif tv == 'table' and not in_progress[v] then
        in_progress[v] = true
        local keys = {}
        for kk in pairs(v) do
          keys[#keys + 1] = kk
        end
        table.sort(keys, key_less)
        for _, kk in ipairs(keys) do
          local ks = tostring(kk)
          rec_hash(string.format('%s/%s%d:%s', path, type(kk):sub(1, 1), #ks, ks), v[kk])
        end
        in_progress[v] = nil
      end
    end

    rec_hash('top', params)
    return h:base32()
  end
  --- Merge generic redis options from `options` into `redis_params`.
  -- Values already present in `redis_params` win, so when this is called
  -- first with module options and then with global ones, the module settings
  -- take priority. `timeout` and `expand_keys` may still be replaced while
  -- they hold their default value.
  -- @param options table with user supplied options
  -- @param redis_params table being filled (modified in place)
  local function process_redis_opts(options, redis_params)
    local default_timeout = 1.0
    local default_expand_keys = false

    -- Accept numbers, numeric strings and time intervals such as "5s"
    -- (the schema allows the latter for timeouts)
    local function to_seconds(v)
      local n = tonumber(v)
      if n == nil and type(v) == 'string' then
        n = lutil.parse_time_interval(v)
      end
      return n
    end

    -- Copy options[name] into redis_params[name] unless the latter is set
    local function inherit(name, convert)
      local v = options[name]
      if v and not redis_params[name] then
        if convert then
          v = convert(v)
        end
        if v ~= nil then
          redis_params[name] = v
        end
      end
    end

    if not redis_params['timeout'] or redis_params['timeout'] == default_timeout then
      if options['timeout'] then
        redis_params['timeout'] = to_seconds(options['timeout'])
      else
        redis_params['timeout'] = default_timeout
      end
    end

    inherit('prefix')

    -- Like timeout: only replace a value that is unset or still the default.
    -- Otherwise the global pass would reset a module's explicit
    -- `expand_keys = true` back to false.
    local cur_expand = redis_params['expand_keys']
    if cur_expand == nil or cur_expand == default_expand_keys then
      if type(options['expand_keys']) == 'boolean' then
        redis_params['expand_keys'] = options['expand_keys']
      else
        redis_params['expand_keys'] = default_expand_keys
      end
    end

    if not redis_params['db'] then
      local db = options['db'] or options['dbname'] or options['database']
      if db then
        redis_params['db'] = tostring(db)
      end
    end

    inherit('username')
    inherit('password')
    inherit('sentinels')
    inherit('sentinel_masters_pattern')
    -- Declared in the schema and used by the sentinel code; previously these
    -- were never copied and therefore silently ignored
    inherit('sentinel_username')
    inherit('sentinel_password')
    inherit('sentinel_watch_time', to_seconds)
    inherit('sentinel_master_maxerrors', tonumber)
  end
  --- Merge settings from the global `redis` section into `redis_params`.
  -- process_redis_opts is applied first with the `module` subsection (when
  -- both `module` and that subsection exist), then with the section itself.
  -- @param rspamd_config config object; nothing happens when it is nil
  -- @param module module name or nil
  -- @param redis_params table to enrich (modified in place)
  local function enrich_defaults(rspamd_config, module, redis_params)
    if not rspamd_config then
      return
    end
    local opts = rspamd_config:get_all_opt('redis')
    if not opts then
      return
    end
    local module_opts = module and opts[module]
    if module_opts then
      process_redis_opts(module_opts, redis_params)
    end
    process_redis_opts(opts, redis_params)
  end
  --- Return a previously registered params table equal to `redis_params`,
  -- or register `redis_params` itself.
  -- On first registration the table gets its `hash` field set. Sentinel
  -- handling is set up for it when `sentinels` is present and the definition
  -- is not read-only.
  -- @param redis_params freshly parsed redis params
  -- @return the cached params table or `redis_params`
  local function maybe_return_cached(redis_params)
    local h = calculate_redis_hash(redis_params)
    local known = cached_results[h]
    if known then
      lutil.debugm(N, 'reused redis server: %s', redis_params)
      return known
    end

    redis_params.hash = h
    cached_results[h] = redis_params
    if redis_params.sentinels and not redis_params.read_only then
      add_redis_sentinels(redis_params)
    end
    lutil.debugm(N, 'loaded new redis server: %s', redis_params)
    return redis_params
  end
  --[[[
  -- @module lua_redis
  -- This module contains helper functions for working with Redis
  --]]

  --- Build upstream lists from one redis definition and store them in `result`.
  -- Read servers come from the first present of `read_servers`, `servers` or
  -- `server`. When `write_servers` is given it is used for writes. Otherwise
  -- a `servers`/`server` list also serves writes, while a lone
  -- `read_servers` definition produces a read-only result.
  -- Generic options are always merged via process_redis_opts.
  -- @param options table holding the redis definition
  -- @param rspamd_config config object or nil
  -- @param result table to fill (modified in place, even on failure)
  -- @return true when read upstreams could be created, false otherwise
  local function process_redis_options(options, rspamd_config, result)
    local default_port = 6379
    local upstream_list = require "rspamd_upstream_list"

    -- Create an upstream list, passing the config object only if we have one
    local function make_upstreams(def)
      if rspamd_config then
        return upstream_list.create(rspamd_config, def, default_port)
      end
      return upstream_list.create(def, default_port)
    end

    local read_only = true
    local upstreams_read, upstreams_write

    -- Try to get read servers
    local read_def
    if options['read_servers'] then
      read_def = options['read_servers']
    elseif options['servers'] then
      read_def = options['servers']
      read_only = false
    elseif options['server'] then
      read_def = options['server']
      read_only = false
    end
    if read_def then
      upstreams_read = make_upstreams(read_def)
      result.read_servers_str = read_def
    end

    if upstreams_read then
      local write_def = options['write_servers']
      if write_def then
        upstreams_write = make_upstreams(write_def)
        result.write_servers_str = write_def
        read_only = false
      elseif not read_only then
        -- `servers`/`server` lists are used for both reads and writes
        upstreams_write = upstreams_read
        result.write_servers_str = result.read_servers_str
      end
    end

    -- Store options
    process_redis_opts(options, result)

    if upstreams_write then
      result.read_only = false
    elseif read_only then
      result.read_only = true
    end

    if not upstreams_read then
      lutil.debugm(N, rspamd_config,
          'cannot load redis server from obj: %s, processed to %s',
          options, result)
      return false
    end

    result.read_servers = upstreams_read
    if upstreams_write then
      result.write_servers = upstreams_write
    end
    return true
  end
  --[[[
  @function try_load_redis_servers(options, rspamd_config, no_fallback, module_name)
  Tries to load redis servers from the specified `options` object.
  Unless `no_fallback` is set, generic settings are then merged from the
  global `redis` section (and its `module_name` subsection when present).
  Returns `redis_params` table or nil in case of failure
  --]]
  exports.try_load_redis_servers = function(options, rspamd_config, no_fallback, module_name)
    local result = {}
    if not process_redis_options(options, rspamd_config, result) then
      return
    end
    if not no_fallback then
      enrich_defaults(rspamd_config, module_name, result)
    end
    return maybe_return_cached(result)
  end
  -- This function parses redis server definition using either
  -- specific server string for this module or global
  -- redis section.
  -- Lookup order:
  --   1. the `redis` subtable of the module options;
  --   2. the module options themselves;
  --   3. unless `no_fallback` is set, the `module_name` subsection of the
  --      global `redis` section;
  --   4. otherwise the global `redis` section itself, honouring its
  --      `disabled_modules` list.
  local function rspamd_parse_redis_server(module_name, module_opts, no_fallback)
    local result = {}

    -- Try to build `result` from `src`; on success optionally merge global
    -- defaults and return the (possibly shared) params, otherwise nil
    local function attempt(src, with_defaults)
      if not process_redis_options(src, rspamd_config, result) then
        return nil
      end
      if with_defaults then
        enrich_defaults(rspamd_config, module_name, result)
      end
      return maybe_return_cached(result)
    end

    -- Try local options
    lutil.debugm(N, rspamd_config, 'try load redis config for: %s', module_name)
    local opts = module_opts or rspamd_config:get_all_opt(module_name)

    if opts then
      local found
      if opts.redis then
        found = attempt(opts.redis, not no_fallback)
        if found then
          return found
        end
      end
      found = attempt(opts, not no_fallback)
      if found then
        return found
      end
    end

    if no_fallback then
      logger.infox(rspamd_config, "cannot find Redis definitions for %s and fallback is disabled",
          module_name)
      return nil
    end

    -- Try global options
    local global_opts = rspamd_config:get_all_opt('redis')
    if global_opts then
      local module_section = global_opts[module_name]
      if module_section then
        local found = attempt(module_section, false)
        if found then
          return found
        end
      else
        local ok = process_redis_options(global_opts, rspamd_config, result)
        -- Exclude disabled
        local disabled = global_opts['disabled_modules']
        if disabled then
          for _, name in ipairs(disabled) do
            if name == module_name then
              logger.infox(rspamd_config, "NOT using default redis server for module %s: it is disabled",
                  module_name)
              return nil
            end
          end
        end
        if ok then
          logger.infox(rspamd_config, "use default Redis settings for %s",
              module_name)
          return maybe_return_cached(result)
        end
      end
    end

    if result.read_servers then
      return maybe_return_cached(result)
    end
    return nil
  end
  --[[[
  -- @function lua_redis.parse_redis_server(module_name, module_opts, no_fallback)
  -- Extracts Redis server settings from configuration
  -- @param {string} module_name name of module to get settings for
  -- @param {table} module_opts settings for module or `nil` to fetch them from configuration
  -- @param {boolean} no_fallback should be `true` if global settings must not be used
  -- @return {table} redis server settings
  -- @example
  -- local rconfig = lua_redis.parse_redis_server('my_module')
  -- -- rconfig contains upstream_list objects in ['write_servers'] and ['read_servers']
  -- -- ['timeout'] contains timeout in seconds
  -- -- ['expand_keys'] if true tells that redis key expansion is enabled
  --]]
  exports.rspamd_parse_redis_server = rspamd_parse_redis_server
  exports.parse_redis_server = rspamd_parse_redis_server
  -- Key extractors for Redis commands, used when `expand_keys` is enabled.
  -- Each entry maps a lower-case command name to a function. The function
  -- receives the command arguments (without the command name) and returns a
  -- table of the argument indexes that are keys. Every extractor returns a
  -- table, possibly empty, so the result can always be iterated with ipairs.
  local process_cmd = {
    -- BITOP operation destkey key [key ...]: everything after the operation
    bitop = function(args)
      local idx_l = {}
      for i = 2, #args do
        idx_l[#idx_l + 1] = i
      end
      return idx_l
    end,
    -- BLPOP key [key ...] timeout: everything but the trailing timeout
    blpop = function(args)
      local idx_l = {}
      for i = 1, #args - 1 do
        idx_l[#idx_l + 1] = i
      end
      return idx_l
    end,
    -- EVAL script numkeys key [key ...] arg [arg ...]
    eval = function(args)
      local idx_l = {}
      local numkeys = tonumber(args[2])
      if numkeys and numkeys >= 1 then
        -- Clamp to the supplied arguments: a bogus numkeys must not yield
        -- indexes of missing arguments
        for i = 3, math.min(numkeys + 2, #args) do
          idx_l[#idx_l + 1] = i
        end
      end
      return idx_l
    end,
    -- Single key in the first argument
    set = function()
      return { 1 }
    end,
    -- Every argument is a key
    mget = function(args)
      local idx_l = {}
      for i = 1, #args do
        idx_l[i] = i
      end
      return idx_l
    end,
    -- MSET key value [key value ...]: odd positions
    mset = function(args)
      local idx_l = {}
      for i = 1, #args, 2 do
        idx_l[#idx_l + 1] = i
      end
      return idx_l
    end,
    -- SMOVE source destination member
    smove = function()
      return { 1, 2 }
    end,
    -- Commands without keys; an empty table (not nil) keeps callers'
    -- ipairs() loops valid
    script = function()
      return {}
    end,
    -- ZINTERSTORE destination numkeys key [key ...] [WEIGHTS ...] [AGGREGATE ...]
    zinterstore = function(args)
      local idx_l = { 1 }
      local numkeys = tonumber(args[2])
      if numkeys and numkeys >= 1 then
        for i = 3, math.min(numkeys + 2, #args) do
          idx_l[#idx_l + 1] = i
        end
      end
      return idx_l
    end,
  }

  -- Make every command in `names` share the key layout of `target`
  local function alias_cmds(target, names)
    local extractor = process_cmd[target]
    for _, name in ipairs(names) do
      process_cmd[name] = extractor
    end
  end

  alias_cmds('script', {
    'auth', 'bgrewriteaof', 'bgsave', 'client', 'cluster', 'command', 'config',
    'dbsize', 'debug', 'discard', 'echo', 'exec', 'flushall', 'flushdb', 'info',
    'keys', 'lastsave', 'migrate', 'monitor', 'multi', 'object', 'ping',
    'psubscribe', 'pubsub', 'publish', 'punsubscribe', 'quit', 'randomkey',
    'readonly', 'readwrite', 'role', 'save', 'scan', 'select', 'slaveof',
    'slowlog', 'subscribe', 'swapdb', 'sync', 'time', 'unsubscribe', 'unwatch',
    'wait',
  })
  alias_cmds('set', {
    'append', 'bitcount', 'bitfield', 'bitpos', 'decr', 'decrby', 'dump',
    'expire', 'expireat', 'geoadd', 'geodist', 'geohash', 'geopos', 'georadius',
    'georadiusbymember', 'get', 'getbit', 'getrange', 'getset', 'hdel',
    'hexists', 'hget', 'hgetall', 'hincrby', 'hincrbyfloat', 'hkeys', 'hlen',
    'hmget', 'hmset', 'hscan', 'hset', 'hsetnx', 'hstrlen', 'hvals', 'incr',
    'incrby', 'incrbyfloat', 'lindex', 'linsert', 'llen', 'lpop', 'lpush',
    'lpushx', 'lrange', 'lrem', 'lset', 'ltrim', 'move', 'persist', 'pexpire',
    'pexpireat', 'pfadd', 'psetex', 'pttl', 'restore', 'rpop', 'rpush',
    'rpushx', 'sadd', 'scard', 'setbit', 'setex', 'setnx', 'sismember',
    'smembers', 'sort', 'spop', 'srandmember', 'srem', 'sscan', 'strlen', 'ttl',
    'type', 'zadd', 'zcard', 'zcount', 'zincrby', 'zlexcount', 'zrange',
    'zrangebylex', 'zrangebyscore', 'zrank', 'zrem', 'zremrangebylex',
    'zremrangebyrank', 'zremrangebyscore', 'zrevrange', 'zrevrangebylex',
    'zrevrangebyscore', 'zrevrank', 'zscan', 'zscore',
    -- legacy (misspelled) names kept for backward compatibility
    'zrembylex', 'zrembyrank', 'zrembyscore',
  })
  alias_cmds('blpop', { 'brpop', 'brpoplpush' })
  alias_cmds('eval', { 'evalsha' })
  alias_cmds('mget', {
    'del', 'exists', 'pfcount', 'pfmerge', 'rename', 'renamenx', 'rpoplpush',
    'sdiff', 'sdiffstore', 'sinter', 'sinterstore', 'sunion', 'sunionstore',
    'touch', 'unlink', 'watch',
  })
  alias_cmds('mset', { 'msetnx' })
  alias_cmds('zinterstore', { 'zunionstore' })
  --- Return the indexes of `args` that hold keys for Redis command `cmd`.
  -- @param cmd command name (any case)
  -- @param args command arguments
  -- @return table of indexes; always a table, empty when the command has no
  --   keys or is unknown (a warning is logged in the latter case)
  local function get_key_indexes(cmd, args)
    local lcmd = string.lower(cmd)
    local extractor = process_cmd[lcmd]
    if not extractor then
      logger.warnx(rspamd_config, "Don't know how to extract keys for %s Redis command", lcmd)
      return {}
    end
    -- An extractor may return nothing (keyless commands); callers iterate
    -- the result with ipairs, which fails on nil
    return extractor(args) or {}
  end
  -- Read `field` from the first sender of the given kind ('smtp' or 'mime');
  -- nil when there is no such sender or field
  local function sender_field(task, kind, field)
    local senders = task:get_from(kind) or E
    local first = senders[1] or E
    return first[field]
  end

  -- Generators for variables available in templated Redis keys. Each takes
  -- the task and returns the value, or nothing when it is unavailable.
  local gen_meta = {
    principal_recipient = function(task)
      return task:get_principal_recipient()
    end,
    principal_recipient_domain = function(task)
      local rcpt = task:get_principal_recipient()
      if rcpt then
        -- Everything after the last '@'
        return string.match(rcpt, '.*@(.*)')
      end
    end,
    ip = function(task)
      local addr = task:get_ip()
      if addr and addr:is_valid() then
        return addr:to_string()
      end
    end,
    from = function(task)
      return sender_field(task, 'smtp', 'addr')
    end,
    from_domain = function(task)
      return sender_field(task, 'smtp', 'domain')
    end,
    from_domain_or_helo_domain = function(task)
      local domain = sender_field(task, 'smtp', 'domain')
      if domain and #domain > 0 then
        return domain
      end
      return task:get_helo()
    end,
    mime_from = function(task)
      return sender_field(task, 'mime', 'addr')
    end,
    mime_from_domain = function(task)
      return sender_field(task, 'mime', 'domain')
    end,
  }
  --- Wrap a domain getter so its result is passed through
  -- `rspamd_util.get_tld` (used for the `esld_*` variables).
  -- @param f function(task) returning a domain or nil
  -- @return function(task) returning get_tld's result, or nothing without a domain
  local function gen_get_esld(f)
    return function(task)
      local domain = f(task)
      if domain then
        return rspamd_util.get_tld(domain)
      end
    end
  end

  -- Register aliases and eSLD variants in this exact order; later entries
  -- refer to earlier ones. A true third element requests get_tld wrapping.
  do
    local derived = {
      { 'smtp_from', 'from' },
      { 'smtp_from_domain', 'from_domain' },
      { 'smtp_from_domain_or_helo_domain', 'from_domain_or_helo_domain' },
      { 'esld_principal_recipient_domain', 'principal_recipient_domain', true },
      { 'esld_from_domain', 'from_domain', true },
      { 'esld_smtp_from_domain', 'esld_from_domain' },
      { 'esld_mime_from_domain', 'mime_from_domain', true },
      { 'esld_from_domain_or_helo_domain', 'from_domain_or_helo_domain', true },
      { 'esld_smtp_from_domain_or_helo_domain', 'esld_from_domain_or_helo_domain' },
    }
    for _, spec in ipairs(derived) do
      local name, source, wrap = spec[1], spec[2], spec[3]
      local getter = gen_meta[source]
      if wrap then
        getter = gen_get_esld(getter)
      end
      gen_meta[name] = getter
    end
  end
  --- Create a lazily populated table of key-expansion variables for `task`.
  -- Lookups are case-insensitive. A variable is computed by its gen_meta
  -- generator on first access and stored (raw, lower-case key) for reuse.
  -- A nil result is not stored, so it is recomputed on the next lookup.
  -- @param task task passed to the generators
  -- @return table with an __index metamethod
  local function get_key_expansion_metadata(task)
    local lazy_meta = {}
    return setmetatable(lazy_meta, {
      __index = function(cache, name)
        local key = string.lower(name)
        local found = rawget(cache, key)
        if found then
          return found
        end
        local generator = gen_meta[key]
        if generator then
          found = generator(task)
          rawset(cache, key, found)
        end
        return found
      end,
    })
  end
  -- Performs an async call to redis hiding all complexity inside the function
  -- task - rspamd_task
  -- redis_params - valid params returned by rspamd_parse_redis_server
  -- key - key to select upstream (by hash) or nil to select round-robin
  --   (reads) / master-slave (writes)
  -- is_write - true if need to write to redis server
  -- callback - function(err, data, upstream) called once the request completes
  -- command - redis command
  -- args - table of arguments; when redis_params.expand_keys is set the key
  --   arguments are expanded as templates in place
  -- extra_opts - table of optional request arguments, copied over the
  --   generated request options
  -- Returns ret, conn, upstream. ret is false when arguments are missing,
  -- no upstream can be selected, or the request cannot be started.
  local function rspamd_redis_make_request(task, redis_params, key, is_write,
                                           callback, command, args, extra_opts)
    if not task or not redis_params or not command then
      return false, nil, nil
    end

    local addr
    -- Feeds the outcome back into the upstream state before the reply is
    -- handed to the caller
    local function rspamd_redis_make_request_cb(err, data)
      if err then
        addr:fail()
      else
        addr:ok()
      end
      if callback then
        callback(err, data, addr)
      end
    end

    local rspamd_redis = require "rspamd_redis"
    if key then
      if is_write then
        addr = redis_params['write_servers']:get_upstream_by_hash(key)
      else
        addr = redis_params['read_servers']:get_upstream_by_hash(key)
      end
    else
      if is_write then
        addr = redis_params['write_servers']:get_upstream_master_slave(key)
      else
        addr = redis_params['read_servers']:get_upstream_round_robin(key)
      end
    end

    if not addr then
      -- Nowhere to send the request: report failure instead of
      -- dereferencing a nil upstream below
      logger.errx(task, 'cannot select server to make redis request')
      return false, nil, nil
    end

    if redis_params['expand_keys'] then
      local m = get_key_expansion_metadata(task)
      local indexes = get_key_indexes(command, args)
      for _, i in ipairs(indexes) do
        args[i] = lutil.template(args[i], m)
      end
    end

    local ip_addr = addr:get_addr()
    local options = {
      task = task,
      callback = rspamd_redis_make_request_cb,
      host = ip_addr,
      timeout = redis_params['timeout'],
      cmd = command,
      args = args
    }

    if extra_opts then
      for k, v in pairs(extra_opts) do
        options[k] = v
      end
    end

    if redis_params['username'] then
      options['username'] = redis_params['username']
    end
    if redis_params['password'] then
      options['password'] = redis_params['password']
    end
    if redis_params['db'] then
      options['dbname'] = redis_params['db']
    end

    lutil.debugm(N, task, 'perform request to redis server' ..
        ' (host=%s, timeout=%s): cmd: %s', ip_addr,
        options.timeout, options.cmd)

    local ret, conn = rspamd_redis.make_request(options)
    if not ret then
      addr:fail()
      logger.warnx(task, "cannot make redis request to: %s", tostring(ip_addr))
    end

    return ret, conn, addr
  end
  --[[[
  -- @function lua_redis.redis_make_request(task, redis_params, key, is_write, callback, command, args[, extra_opts])
  -- Sends a request to Redis
  -- @param {rspamd_task} task task object
  -- @param {table} redis_params redis configuration in format returned by lua_redis.parse_redis_server()
  -- @param {string} key key to use for sharding (nil: round-robin for reads, master/slave for writes)
  -- @param {boolean} is_write should be `true` if we are performing a write operation
  -- @param {function} callback callback function (first parameter is error if applicable, second is the reply data, third is the upstream used)
  -- @param {string} command Redis command to run
  -- @param {table} args Numerically indexed table containing arguments for command
  -- @param {table} extra_opts optional table of extra request options, copied over the generated ones
  -- @return {boolean,connection,upstream} request status, connection object and the selected upstream
  --]]
  exports.rspamd_redis_make_request = rspamd_redis_make_request
  exports.redis_make_request = rspamd_redis_make_request
  -- Taskless variant of rspamd_redis_make_request: the request is bound to
  -- `ev_base` and `cfg` instead of a task.
  -- Returns ret, conn, upstream. ret is false when arguments are missing,
  -- no upstream can be selected, or the request cannot be started.
  local function redis_make_request_taskless(ev_base, cfg, redis_params, key,
                                             is_write, callback, command, args, extra_opts)
    if not ev_base or not redis_params or not command then
      return false, nil, nil
    end

    local addr
    -- Feeds the outcome back into the upstream state before the reply is
    -- handed to the caller
    local function rspamd_redis_make_request_cb(err, data)
      if err then
        addr:fail()
      else
        addr:ok()
      end
      if callback then
        callback(err, data, addr)
      end
    end

    local rspamd_redis = require "rspamd_redis"
    if key then
      if is_write then
        addr = redis_params['write_servers']:get_upstream_by_hash(key)
      else
        addr = redis_params['read_servers']:get_upstream_by_hash(key)
      end
    else
      if is_write then
        addr = redis_params['write_servers']:get_upstream_master_slave(key)
      else
        addr = redis_params['read_servers']:get_upstream_round_robin(key)
      end
    end

    if not addr then
      -- Nowhere to send the request: report failure instead of
      -- dereferencing a nil upstream below
      logger.errx(cfg, 'cannot select server to make redis request')
      return false, nil, nil
    end

    local options = {
      ev_base = ev_base,
      config = cfg,
      callback = rspamd_redis_make_request_cb,
      host = addr:get_addr(),
      timeout = redis_params['timeout'],
      cmd = command,
      args = args
    }
    if extra_opts then
      for k, v in pairs(extra_opts) do
        options[k] = v
      end
    end

    if redis_params['username'] then
      options['username'] = redis_params['username']
    end
    if redis_params['password'] then
      options['password'] = redis_params['password']
    end
    if redis_params['db'] then
      options['dbname'] = redis_params['db']
    end

    lutil.debugm(N, cfg, 'perform taskless request to redis server' ..
        ' (host=%s, timeout=%s): cmd: %s', options.host:tostring(true),
        options.timeout, options.cmd)
    local ret, conn = rspamd_redis.make_request(options)
    if not ret then
      logger.errx(cfg, 'cannot execute redis request')
      addr:fail()
    end

    return ret, conn, addr
  end
  --[[[
  -- @function lua_redis.redis_make_request_taskless(ev_base, cfg, redis_params, key, is_write, callback, command, args[, extra_opts])
  -- Sends a request to Redis in context where `task` is not available for some specific use-cases
  -- Identical to redis_make_request() except that, instead of a task, the first two parameters
  -- are an `event base` object and the config object (passed to the request as `config`)
  --]]
  exports.rspamd_redis_make_request_taskless = redis_make_request_taskless
  exports.redis_make_request_taskless = redis_make_request_taskless
  -- Registered scripts; a script id is its index in this array
  local redis_scripts = {}

  -- Completes a load round for `script`. It is marked loaded when a SHA has
  -- been obtained. The wait queue is then drained, and each pending callback
  -- is invoked with the current `loaded` flag. The queue is replaced before
  -- any callback runs, so callbacks queued during the drain wait for the
  -- next round.
  local function script_set_loaded(script)
    if script.sha then
      script.loaded = true
    end
    local pending = {}
    for i, cb in ipairs(script.waitq) do
      pending[i] = cb
    end
    script.waitq = {}
    for i = 1, #pending do
      pending[i](script.loaded)
    end
  end
  -- Builds one `SCRIPT LOAD` request option table per upstream. Read servers
  -- come first, then write servers, merged with lutil.table_merge. Sets
  -- script.in_flight to the number of upstreams that will be contacted.
  -- Username, password and db from script.redis_params are copied into every
  -- request when present.
  local function prepare_redis_call(script)
    local params = script.redis_params
    local servers = {}
    for _, group in ipairs({ 'read_servers', 'write_servers' }) do
      if params[group] then
        servers = lutil.table_merge(servers, params[group]:all_upstreams())
      end
    end

    -- Call load script on each server, set loaded flag
    script.in_flight = #servers

    local requests = {}
    for _, upstream in ipairs(servers) do
      local req = {
        host = upstream:get_addr(),
        timeout = params['timeout'],
        cmd = 'SCRIPT',
        args = { 'LOAD', script.script },
        upstream = upstream,
      }
      if params['username'] then
        req['username'] = params['username']
      end
      if params['password'] then
        req['password'] = params['password']
      end
      if params['db'] then
        req['dbname'] = params['db']
      end
      requests[#requests + 1] = req
    end

    return requests
  end
  -- Uploads `script` to every configured Redis upstream, bound to `task`.
  -- A successful reply stores the returned SHA in script.sha and clears
  -- script.fatal_error. An error reply is recorded in script.fatal_error.
  -- When every request has completed, or failed to start, script_set_loaded()
  -- drains the wait queue.
  local function load_script_task(script, task, is_write)
    local rspamd_redis = require "rspamd_redis"
    local opts = prepare_redis_call(script)

    for _, opt in ipairs(opts) do
      opt.task = task
      opt.is_write = is_write
      opt.callback = function(err, data)
        if err then
          logger.errx(task, 'cannot upload script to %s: %s; registered from: %s:%s',
              opt.upstream:get_addr():to_string(true),
              err, script.caller.short_src, script.caller.currentline)
          opt.upstream:fail()
          script.fatal_error = err
        else
          opt.upstream:ok()
          logger.infox(task,
              "uploaded redis script to %s %s %s, sha: %s",
              opt.upstream:get_addr():to_string(true),
              script.filename and "from file" or "with id", script.filename or script.id, data)
          script.sha = data -- We assume that sha is the same on all servers
          -- A successful upload supersedes an earlier failure; otherwise
          -- exec_redis_script would keep reporting the stale error
          script.fatal_error = nil
        end
        script.in_flight = script.in_flight - 1
        if script.in_flight == 0 then
          script_set_loaded(script)
        end
      end

      local ret = rspamd_redis.make_request(opt)
      if not ret then
        logger.errx(task, 'cannot execute redis request to load script on %s',
            opt.upstream:get_addr())
        script.in_flight = script.in_flight - 1
        opt.upstream:fail()
        -- Started requests complete the round from their callback; only a
        -- request that failed to start must check for completion here
        if script.in_flight == 0 then
          script_set_loaded(script)
        end
      end
    end
  end
  -- Uploads `script` to every configured Redis upstream without a task,
  -- using `cfg` and `ev_base`.
  -- A successful reply stores the returned SHA in script.sha and clears
  -- script.fatal_error. An error reply is recorded in script.fatal_error.
  -- When every request has completed, or failed to start, script_set_loaded()
  -- drains the wait queue.
  local function load_script_taskless(script, cfg, ev_base, is_write)
    local rspamd_redis = require "rspamd_redis"
    local opts = prepare_redis_call(script)

    for _, opt in ipairs(opts) do
      opt.config = cfg
      opt.ev_base = ev_base
      opt.is_write = is_write
      opt.callback = function(err, data)
        if err then
          logger.errx(cfg, 'cannot upload script to %s: %s; registered from: %s:%s, filename: %s',
              opt.upstream:get_addr():to_string(true),
              err, script.caller.short_src, script.caller.currentline, script.filename)
          opt.upstream:fail()
          script.fatal_error = err
        else
          opt.upstream:ok()
          logger.infox(cfg,
              "uploaded redis script to %s %s %s, sha: %s",
              opt.upstream:get_addr():to_string(true),
              script.filename and "from file" or "with id", script.filename or script.id,
              data)
          script.sha = data -- We assume that sha is the same on all servers
          script.fatal_error = nil
        end
        script.in_flight = script.in_flight - 1
        if script.in_flight == 0 then
          script_set_loaded(script)
        end
      end

      local ret = rspamd_redis.make_request(opt)
      if not ret then
        logger.errx(cfg, 'cannot execute redis request to load script on %s',
            opt.upstream:get_addr())
        script.in_flight = script.in_flight - 1
        opt.upstream:fail()
        -- Started requests complete the round from their callback; only a
        -- request that failed to start must check for completion here
        if script.in_flight == 0 then
          script_set_loaded(script)
        end
      end
    end
  end
  -- Starts a taskless upload of `script`. Scripts registered without redis
  -- params are skipped. The fourth argument (worker) is accepted but unused.
  local function load_redis_script(script, cfg, ev_base, _)
    if not script.redis_params then
      return
    end
    load_script_taskless(script, cfg, ev_base)
  end
  -- Registers a Redis Lua script and returns its numeric id (for use with
  -- exec_redis_script()).
  -- script - script source text
  -- redis_params - parsed redis configuration the script is uploaded to
  -- caller_level - stack level passed to debug.getinfo to record the
  --   registration site for log messages (defaults to 2)
  -- maybe_filename - optional source file name, used in log messages
  -- An on-load hook starts a periodic task that retries the upload until a
  -- SHA is known. The periodic callback returns 1, 2, 3, ... as its next
  -- delay, and false once the SHA is set.
  local function add_redis_script(script, redis_params, caller_level, maybe_filename)
    local level = caller_level or 2
    -- debug.getinfo must stay in this function's own body: the level is
    -- counted from here
    local caller = debug.getinfo(level) or debug.getinfo(level - 1) or E

    local entry = {
      caller = caller,
      loaded = false,
      redis_params = redis_params,
      script = script,
      waitq = {}, -- callbacks pending for script being loaded
      id = #redis_scripts + 1,
      filename = maybe_filename,
    }

    -- Register on load function
    rspamd_config:add_on_load(function(cfg, ev_base, worker)
      local attempts = 0.0
      rspamd_config:add_periodic(ev_base, 0.0, function()
        if entry.sha then
          return false
        end
        load_redis_script(entry, cfg, ev_base, worker)
        attempts = attempts + 1
        return 1.0 * attempts -- Check one more time in one second
      end, false)
    end)

    table.insert(redis_scripts, entry)
    return #redis_scripts
  end
  exports.add_redis_script = add_redis_script
  -- Loads a Redis script from a file, strips comments, and passes the content to
  -- `add_redis_script` function.
  --
  -- @param filename The name of the file containing the Redis script. Names that
  --   do not start with the directory separator are resolved as
  --   `<dir>/redis_scripts/<filename>`.
  -- @param redis_params The Redis parameters to use for this script.
  -- @param dir Optional base directory for relative names (defaults to
  --   rspamd_paths.LUALIBDIR).
  -- @return The ID of the newly added Redis script, or nil when the file cannot
  --   be opened or read.
  --
  local function load_redis_script_from_file(filename, redis_params, dir)
    local lua_util = require "lua_util"
    local rspamd_logger = require "rspamd_logger"
    if not dir then
      dir = rspamd_paths.LUALIBDIR
    end
    local path = filename
    if filename:sub(1, 1) ~= package.config:sub(1, 1) then
      -- Relative path
      path = lua_util.join_path(dir, "redis_scripts", filename)
    end
    -- Read file contents
    local file = io.open(path, "r")
    if not file then
      rspamd_logger.errx("failed to open Redis script file: %s", path)
      return nil
    end
    local script = file:read("*all")
    -- Close the handle before inspecting the result, so that it is released
    -- on the read-error path too
    file:close()
    if not script then
      rspamd_logger.errx("failed to load Redis script file: %s", path)
      return nil
    end
    script = lua_util.strip_lua_comments(script)
    return add_redis_script(script, redis_params, 3, filename)
  end
  exports.load_redis_script_from_file = load_redis_script_from_file
  -- Executes a script registered with add_redis_script() via EVALSHA.
  -- id - script id returned by add_redis_script()
  -- params - request attributes: `task` (task bound request) or `ev_base`
  --   (taskless request using rspamd_config), plus optional `key` used to
  --   select the upstream and `is_write`
  -- callback - function(err, data) receiving the result
  -- keys - array of Redis keys for the script (nil means no keys)
  -- args - optional array of additional script arguments
  -- Returns false if no script is registered under `id`, true otherwise. In
  -- the latter case every outcome, including errors, is reported through
  -- `callback`. Calls made before the script is loaded are queued until the
  -- upload finishes. A NOSCRIPT reply triggers a reload, and the call is
  -- retried once the reload completes.
  local function exec_redis_script(id, params, callback, keys, args)
    local redis_args = {}

    if not redis_scripts[id] then
      logger.errx("cannot find registered script with id %s", id)
      return false
    end

    local script = redis_scripts[id]

    if script.fatal_error then
      callback(script.fatal_error, nil)
      return true
    end

    if not script.redis_params then
      callback('no redis servers defined', nil)
      return true
    end

    keys = keys or {}

    local function do_call(can_reload)
      local function redis_cb(err, data)
        if not err then
          callback(err, data)
        elseif string.match(err, 'NOSCRIPT') then
          -- The server lost the script. Forget the SHA and mark the script as
          -- not loaded, so concurrent callers wait for the reload instead of
          -- sending EVALSHA without a SHA.
          script.sha = nil
          script.loaded = false
          if can_reload then
            -- Schedule restart
            table.insert(script.waitq, do_call)
            if script.in_flight == 0 then
              -- Reload scripts if this has not been initiated yet
              if params.task then
                load_script_task(script, params.task, params.is_write)
              else
                load_script_taskless(script, rspamd_config, params.ev_base, params.is_write)
              end
            end
          else
            callback(err, data)
          end
        else
          callback(err, data)
        end
      end

      if #redis_args == 0 then
        if not script.sha then
          -- The upload did not produce a SHA, so there is nothing to call;
          -- sending EVALSHA now would shift every argument by one
          callback('NOSCRIPT', nil)
          return
        end
        table.insert(redis_args, script.sha)
        table.insert(redis_args, tostring(#keys))
        for _, k in ipairs(keys) do
          table.insert(redis_args, k)
        end

        if type(args) == 'table' then
          for _, a in ipairs(args) do
            table.insert(redis_args, a)
          end
        end
      end

      if params.task then
        if not rspamd_redis_make_request(params.task, script.redis_params,
            params.key, params.is_write, redis_cb, 'EVALSHA', redis_args) then
          callback('Cannot make redis request', nil)
        end
      else
        if not redis_make_request_taskless(params.ev_base, rspamd_config,
            script.redis_params,
            params.key, params.is_write, redis_cb, 'EVALSHA', redis_args) then
          callback('Cannot make redis request', nil)
        end
      end
    end

    if script.loaded then
      do_call(true)
    else
      -- Delayed until scripts are loaded
      if not params.task then
        table.insert(script.waitq, do_call)
      else
        -- TODO: fix taskfull requests
        table.insert(script.waitq, function()
          if script.loaded then
            do_call(false)
          else
            callback('NOSCRIPT', nil)
          end
        end)
        load_script_task(script, params.task, params.is_write)
      end
    end

    return true
  end
  -- Public API: lua_redis.exec_redis_script(id, params, callback, keys, args)
  exports['exec_redis_script'] = exec_redis_script
  -- Opens a synchronous (coroutine based) Redis connection.
  -- redis_params - parsed redis configuration; all of its fields are also
  --   copied into the connection options
  -- is_write - pick a write upstream instead of a read one
  -- key - optional key to pick the upstream by hash (otherwise master/slave
  --   for writes, round-robin for reads)
  -- cfg, ev_base - config and event base. They fall back to the rspamd_config
  --   and rspamadm_ev_base globals. The session comes from
  --   redis_params.session or the rspamadm_session global.
  -- Returns ok, connection, upstream. When credentials or a database are
  -- configured, AUTH/SELECT are executed before the connection is returned.
  local function redis_connect_sync(redis_params, is_write, key, cfg, ev_base)
    if not redis_params then
      return false, nil
    end

    local rspamd_redis = require "rspamd_redis"
    local addr

    if key then
      if is_write then
        addr = redis_params['write_servers']:get_upstream_by_hash(key)
      else
        addr = redis_params['read_servers']:get_upstream_by_hash(key)
      end
    else
      if is_write then
        addr = redis_params['write_servers']:get_upstream_master_slave(key)
      else
        addr = redis_params['read_servers']:get_upstream_round_robin(key)
      end
    end

    if not addr then
      -- Nowhere to connect: report failure instead of dereferencing a nil
      -- upstream below
      logger.errx(cfg, 'cannot select server to make redis request')
      return false, nil, nil
    end

    local options = {
      host = addr:get_addr(),
      timeout = redis_params['timeout'],
      config = cfg or rspamd_config,
      ev_base = ev_base or rspamadm_ev_base,
      session = redis_params.session or rspamadm_session
    }

    for k, v in pairs(redis_params) do
      options[k] = v
    end

    if not options.config then
      logger.errx('config is not set')
      return false, nil, addr
    end

    if not options.ev_base then
      logger.errx('ev_base is not set')
      return false, nil, addr
    end

    if not options.session then
      logger.errx('session is not set')
      return false, nil, addr
    end

    local ret, conn = rspamd_redis.connect_sync(options)
    if not ret then
      logger.errx('cannot create redis connection: %s', conn)
      addr:fail()

      return false, nil, addr
    end

    if conn then
      local need_exec = false
      if redis_params['username'] then
        if redis_params['password'] then
          conn:add_cmd('AUTH', { redis_params['username'], redis_params['password'] })
          need_exec = true
        else
          logger.warnx('Redis requires a password when username is supplied')
          return false, nil, addr
        end
      elseif redis_params['password'] then
        conn:add_cmd('AUTH', { redis_params['password'] })
        need_exec = true
      end

      if redis_params['db'] then
        conn:add_cmd('SELECT', { tostring(redis_params['db']) })
        need_exec = true
      elseif redis_params['dbname'] then
        conn:add_cmd('SELECT', { tostring(redis_params['dbname']) })
        need_exec = true
      end

      if need_exec then
        local exec_ret, res = conn:exec()

        if not exec_ret then
          logger.errx('cannot prepare redis connection (authentication or db selection failure): %s',
              res)
          addr:fail()
          return false, nil, addr
        end
      end
    end

    return ret, conn, addr
  end
  exports.redis_connect_sync = redis_connect_sync

  --[[[
  -- @function lua_redis.request(redis_params, attrs, req)
  -- Sends a request to Redis synchronously with coroutines or asynchronously using
  -- a callback (modern API)
  -- @param redis_params a table of redis server parameters
  -- @param attrs a table of redis request attributes (e.g. task, or ev_base + cfg + session)
  -- @param req a command string, or a table of request: a command + command options
  --   (the table is not modified)
  -- @return {result,data/connection,address} boolean result, connection object in case of async request and results if using coroutines, redis server address
  --]]
  exports.request = function(redis_params, attrs, req)
    local lua_util = require "lua_util"

    if not attrs or not redis_params or not req then
      logger.errx('invalid arguments for redis request')
      return false, nil, nil
    end

    if not (attrs.task or (attrs.config and attrs.ev_base)) then
      logger.errx('invalid attributes for redis request')
      return false, nil, nil
    end

    local opts = lua_util.shallowcopy(attrs)

    local log_obj = opts.task or opts.config

    local addr

    if opts.callback then
      -- Wrap callback
      local callback = opts.callback
      local function rspamd_redis_make_request_cb(err, data)
        if err then
          addr:fail()
        else
          addr:ok()
        end
        callback(err, data, addr)
      end
      opts.callback = rspamd_redis_make_request_cb
    end

    local rspamd_redis = require "rspamd_redis"
    local is_write = opts.is_write

    if opts.key then
      if is_write then
        addr = redis_params['write_servers']:get_upstream_by_hash(attrs.key)
      else
        addr = redis_params['read_servers']:get_upstream_by_hash(attrs.key)
      end
    else
      if is_write then
        addr = redis_params['write_servers']:get_upstream_master_slave(attrs.key)
      else
        addr = redis_params['read_servers']:get_upstream_round_robin(attrs.key)
      end
    end

    if not addr then
      -- Nowhere to send the request: report failure instead of
      -- dereferencing a nil upstream below
      logger.errx(log_obj, 'cannot select server to make redis request')
      return false, nil, nil
    end

    opts.host = addr:get_addr()
    opts.timeout = redis_params.timeout

    if type(req) == 'string' then
      opts.cmd = req
    else
      -- Split command and arguments without modifying the caller's table
      opts.cmd = req[1]
      local cmd_args = {}
      for i = 2, #req do
        cmd_args[i - 1] = req[i]
      end
      opts.args = cmd_args
    end

    if redis_params.username then
      opts.username = redis_params.username
    end

    if redis_params.password then
      opts.password = redis_params.password
    end

    if redis_params.db then
      opts.dbname = redis_params.db
    end

    lutil.debugm(N, log_obj, 'perform generic request to redis server' ..
        ' (host=%s, timeout=%s): cmd: %s, arguments: %s', addr,
        opts.timeout, opts.cmd, opts.args)

    if opts.callback then
      local ret, conn = rspamd_redis.make_request(opts)
      if not ret then
        logger.errx(log_obj, 'cannot execute redis request')
        addr:fail()
      end

      return ret, conn, addr
    else
      -- Coroutines version
      local ret, conn = rspamd_redis.connect_sync(opts)
      if not ret then
        logger.errx(log_obj, 'cannot execute redis request')
        addr:fail()
      else
        conn:add_cmd(opts.cmd, opts.args)
        -- Keep the documented third result (the upstream) on this path too
        local exec_ok, exec_data = conn:exec()
        return exec_ok, exec_data, addr
      end
      return false, nil, addr
    end
  end
  --[[[
  -- @function lua_redis.connect(redis_params, attrs)
  -- Connects to Redis synchronously with coroutines or asynchronously using a callback (modern API)
  -- @param redis_params a table of redis server parameters
  -- @param attrs a table of redis request attributes (e.g. task, or ev_base + cfg + session)
  -- @return {result,connection,address} boolean result, connection object, redis server address
  --   (result is false and address nil when no upstream could be selected)
  --]]
  exports.connect = function(redis_params, attrs)
    local lua_util = require "lua_util"

    if not attrs or not redis_params then
      logger.errx('invalid arguments for redis connect')
      return false, nil, nil
    end

    if not (attrs.task or (attrs.config and attrs.ev_base)) then
      logger.errx('invalid attributes for redis connect')
      return false, nil, nil
    end

    local opts = lua_util.shallowcopy(attrs)

    local log_obj = opts.task or opts.config

    local addr

    if opts.callback then
      -- Wrap callback
      local callback = opts.callback
      local function rspamd_redis_make_request_cb(err, data)
        if err then
          addr:fail()
        else
          addr:ok()
        end
        callback(err, data, addr)
      end
      opts.callback = rspamd_redis_make_request_cb
    end

    local rspamd_redis = require "rspamd_redis"
    local is_write = opts.is_write

    if opts.key then
      if is_write then
        addr = redis_params['write_servers']:get_upstream_by_hash(attrs.key)
      else
        addr = redis_params['read_servers']:get_upstream_by_hash(attrs.key)
      end
    else
      if is_write then
        addr = redis_params['write_servers']:get_upstream_master_slave(attrs.key)
      else
        addr = redis_params['read_servers']:get_upstream_round_robin(attrs.key)
      end
    end

    if not addr then
      -- Nowhere to connect: report failure instead of dereferencing a nil
      -- upstream below
      logger.errx(log_obj, 'cannot select server to make redis connect')
      return false, nil, nil
    end

    opts.host = addr:get_addr()
    opts.timeout = redis_params.timeout

    if redis_params.username then
      opts.username = redis_params.username
    end

    if redis_params.password then
      opts.password = redis_params.password
    end

    if redis_params.db then
      opts.dbname = redis_params.db
    end

    if opts.callback then
      local ret, conn = rspamd_redis.connect(opts)
      if not ret then
        logger.errx(log_obj, 'cannot execute redis connect')
        addr:fail()
      end

      return ret, conn, addr
    else
      -- Coroutines version
      local ret, conn = rspamd_redis.connect_sync(opts)
      if not ret then
        logger.errx(log_obj, 'cannot execute redis connect')
        addr:fail()
      else
        return true, conn, addr
      end

      return false, nil, addr
    end
  end
  -- Registry of documented Redis prefixes: prefix -> { module, description, ... }
  local redis_prefixes = {}
  --[[[
  -- @function lua_redis.register_prefix(prefix, module, description[, optional])
  -- Register new redis prefix for documentation purposes
  -- @param {string} prefix string prefix
  -- @param {string} module module name
  -- @param {string} description prefix description
  -- @param {table} optional optional kv pairs (e.g. pattern)
  --]]
  local function register_prefix(prefix, module, description, optional)
    local entry = { module = module, description = description }
    -- Extra attributes are copied last, so they may override module/description
    if type(optional) == 'table' then
      for field, value in pairs(optional) do
        entry[field] = value
      end
    end
    redis_prefixes[prefix] = entry
  end
  exports.register_prefix = register_prefix

  --[[[
  -- @function lua_redis.prefixes([mname])
  -- Returns prefixes for specific module (or all prefixes). Returns a table prefix -> table
  -- (without `mname` the registry table itself is returned, not a copy)
  --]]
  exports.prefixes = function(mname)
    if not mname then
      return redis_prefixes
    end
    -- An explicit loop keeps both the prefix (key) and its description
    -- table (value), matching the shape of the unfiltered result
    local filtered = {}
    for prefix, data in pairs(redis_prefixes) do
      if data.module == mname then
        filtered[prefix] = data
      end
    end
    return filtered
  end

  return exports