You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

lua_redis.lua 45KB


  1. --[[
  2. Copyright (c) 2017, Vsevolod Stakhov <vsevolod@highsecure.ru>
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ]]--
  13. local logger = require "rspamd_logger"
  14. local lutil = require "lua_util"
  15. local rspamd_util = require "rspamd_util"
  16. local ts = require("tableshape").types
  local exports = {}
  local E = {}
  local N = "lua_redis"

  -- Schemas describing the accepted redis configuration layouts.
  -- `common_schema` lists the options shared by every layout, while
  -- `config_schema` (exported) accepts one of three server layouts:
  -- read_servers + write_servers, servers, or server.
  local common_schema, config_schema
  do
    -- A single server definition string or a list of them (fresh type per call)
    local function server_list()
      return ts.string + ts.array_of(ts.string)
    end

    -- Optional interval: a number, or a string converted by lutil.parse_time_interval
    local function optional_interval()
      return (ts.number + ts.string / lutil.parse_time_interval):is_optional()
    end

    common_schema = ts.shape {
      timeout = optional_interval(),
      db = ts.string:is_optional(),
      database = ts.string:is_optional(),
      dbname = ts.string:is_optional(),
      prefix = ts.string:is_optional(),
      password = ts.string:is_optional(),
      expand_keys = ts.boolean:is_optional(),
      sentinels = server_list():is_optional(),
      sentinel_watch_time = optional_interval(),
      sentinel_masters_pattern = ts.string:is_optional(),
      sentinel_master_maxerrors = (ts.number + ts.string / tonumber):is_optional(),
    }

    local read_write_layout = ts.shape({
      read_servers = server_list(),
      write_servers = server_list(),
    }, {extra_opts = common_schema})
    local servers_layout = ts.shape({
      servers = server_list(),
    }, {extra_opts = common_schema})
    local server_layout = ts.shape({
      server = server_list(),
    }, {extra_opts = common_schema})

    config_schema = read_write_layout + servers_layout + server_layout
  end

  exports.config_schema = config_schema
  --[[
  -- Queries one Redis Sentinel (picked round-robin from `params.sentinels`)
  -- for the list of masters and their slaves, and replaces
  -- `params.read_servers` / `params.write_servers` with new upstream lists
  -- whenever the discovered server strings differ from the current ones.
  -- Masters go to both the read and write lists; slaves whose
  -- 'master-link-status' is 'ok' go to the read list only.
  -- When `params.sentinel_masters_pattern` is set, only masters whose name
  -- matches that Lua pattern are used.
  -- @param ev_base event base used for the synchronous sentinel connection
  -- @param params redis params; `params.sentinels` must already be an
  --   upstream list (see add_redis_sentinels)
  -- @param initialised currently unused; kept for interface compatibility
  --]]
  local function redis_query_sentinel(ev_base, params, initialised)
    -- Sentinel replies encode hashes as flat {k1, v1, k2, v2, ...} arrays
    local function flatten_redis_table(tbl)
      local res = {}
      for i = 1, #tbl, 2 do
        res[tbl[i]] = tbl[i + 1]
      end
      return res
    end

    -- Coroutines syntax
    local rspamd_redis = require "rspamd_redis"
    local upstream_list = require "rspamd_upstream_list"
    local addr = params.sentinels:get_upstream_round_robin()

    local is_ok, connection = rspamd_redis.connect_sync({
      host = addr:get_addr(),
      timeout = params.timeout,
      config = rspamd_config,
      ev_base = ev_base,
      no_pool = true,
    })

    if not is_ok then
      logger.errx(rspamd_config, 'cannot connect sentinel at address: %s',
          tostring(addr:get_addr()))
      addr:fail()
      return
    end

    -- Get masters list
    connection:add_cmd('SENTINEL', {'masters'})
    local ok, result = connection:exec()

    if not ok or type(result) ~= 'table' then
      logger.errx(rspamd_config, 'cannot get data from Redis Sentinel %s: %s',
          addr:get_addr():to_string(true), result)
      addr:fail()
      return
    end

    local pattern = params.sentinel_masters_pattern
    local masters = {}
    for _, m in ipairs(result) do
      local master = {}
      if type(m) == 'table' then
        master = flatten_redis_table(m)
      end

      if not master.name then
        -- Without a name the entry can be neither filtered nor indexed
        lutil.debugm(N, rspamd_config, 'skip master entry without name from sentinel %s',
            addr:get_addr():to_string(true))
      elseif pattern and not master.name:match(pattern) then
        lutil.debugm(N, rspamd_config, 'skip master %s with ip %s and port %s, pattern %s',
            master.name, master.ip, master.port, pattern)
      else
        lutil.debugm(N, rspamd_config, 'found master %s with ip %s and port %s',
            master.name, master.ip, master.port)
        masters[master.name] = master
      end
    end

    -- For each master we need to get a list of slaves
    for name, master in pairs(masters) do
      master.slaves = {}
      connection:add_cmd('SENTINEL', {'slaves', name})
      local slaves_ok, slave_result = connection:exec()

      if slaves_ok and type(slave_result) == 'table' then
        for _, s in ipairs(slave_result) do
          if type(s) == 'table' then
            local slave = flatten_redis_table(s)
            lutil.debugm(N, rspamd_config,
                'found slave from master %s with ip %s and port %s',
                master.name, slave.ip, slave.port)
            master.slaves[#master.slaves + 1] = slave
          end
        end
      else
        logger.warnx(rspamd_config, 'cannot get slaves of master %s from Redis Sentinel %s: %s',
            name, addr:get_addr():to_string(true), slave_result)
      end
    end

    -- We now form new strings for masters and slaves
    local read_servers_tbl, write_servers_tbl = {}, {}
    for _, master in pairs(masters) do
      local master_addr = string.format('%s:%s', master.ip, master.port)
      write_servers_tbl[#write_servers_tbl + 1] = master_addr
      read_servers_tbl[#read_servers_tbl + 1] = master_addr

      for _, slave in ipairs(master.slaves) do
        if slave['master-link-status'] == 'ok' then
          read_servers_tbl[#read_servers_tbl + 1] = string.format(
              '%s:%s', slave.ip, slave.port)
        end
      end
    end

    -- Sorting makes the strings comparable with the previously stored ones
    table.sort(read_servers_tbl)
    table.sort(write_servers_tbl)

    local read_servers_str = table.concat(read_servers_tbl, ',')
    local write_servers_str = table.concat(write_servers_tbl, ',')

    lutil.debugm(N, rspamd_config,
        'new servers list: %s read; %s write',
        read_servers_str,
        write_servers_str)

    if read_servers_str ~= params.read_servers_str then
      local read_upstreams = upstream_list.create(rspamd_config,
          read_servers_str, 6379)

      if read_upstreams then
        logger.infox(rspamd_config, 'sentinel %s: replace read servers with new list: %s',
            addr:get_addr():to_string(true), read_servers_str)
        params.read_servers = read_upstreams
        params.read_servers_str = read_servers_str
      end
    end

    if write_servers_str ~= params.write_servers_str then
      local write_upstreams = upstream_list.create(rspamd_config,
          write_servers_str, 6379)

      if write_upstreams then
        logger.infox(rspamd_config, 'sentinel %s: replace write servers with new list: %s',
            addr:get_addr():to_string(true), write_servers_str)
        params.write_servers = write_upstreams
        params.write_servers_str = write_servers_str

        -- Re-query the sentinel once a master accumulates too many failures
        local queried = false
        local function monitor_failures(up, _, count)
          local max_errors = tonumber(params.sentinel_master_maxerrors) or 2
          if count > max_errors and not queried then
            logger.infox(rspamd_config, 'sentinel: master with address %s, caused %s failures, try to query sentinel',
                up:get_addr():to_string(true), count)
            queried = true -- Avoid multiple checks caused by this monitor
            redis_query_sentinel(ev_base, params, true)
          end
        end

        write_upstreams:add_watcher('failure', monitor_failures)
      end
    end

    addr:ok()
  end
  --[[
  -- Turns the `params.sentinels` definition into an upstream list, fills in
  -- default sentinel settings, and registers a periodic sentinel query
  -- (see redis_query_sentinel) in scanner workers.
  -- On an unparseable sentinels definition it logs an error and leaves
  -- `params` unchanged.
  --]]
  local function add_redis_sentinels(params)
    local upstream_list = require "rspamd_upstream_list"
    local sentinel_upstreams = upstream_list.create(rspamd_config,
        params.sentinels, 5000)

    if not sentinel_upstreams then
      logger.errx(rspamd_config, 'cannot load redis sentinels string: %s',
          params.sentinels)
      return
    end

    params.sentinels = sentinel_upstreams
    -- Query period in seconds: once a minute unless configured
    params.sentinel_watch_time = params.sentinel_watch_time or 60
    -- Master failures tolerated before the sentinel is re-queried
    params.sentinel_master_maxerrors = params.sentinel_master_maxerrors or 2

    rspamd_config:add_on_load(function(_, ev_base, worker)
      if not worker:is_scanner() then
        return
      end

      local initialised = false
      rspamd_config:add_periodic(ev_base, 0.0, function()
        redis_query_sentinel(ev_base, params, initialised)
        initialised = true
        return params.sentinel_watch_time
      end, false)
    end)
  end
  -- Parsed redis params keyed by their calculate_redis_hash() digest
  local cached_results = {}

  --[[
  -- Computes a digest identifying a redis params table, so that equal
  -- configurations share one params object (see maybe_return_cached).
  -- String, number and boolean values are hashed together with their keys.
  -- Nested tables are walked recursively. Other value types (e.g. upstream
  -- userdata) are ignored.
  -- Keys are visited in sorted order so the digest does not depend on
  -- `pairs` iteration order.
  -- @param params table of redis parameters
  -- @return digest as returned by the hash object's base32() method
  --]]
  local function calculate_redis_hash(params)
    local cr = require "rspamd_cryptobox_hash"
    local h = cr.create()

    -- Length-prefix every fragment so that distinct inputs such as
    -- {ab = 'c'} and {a = 'bc'} never feed the same byte stream
    local function feed(s)
      h:update(string.format('%d:%s', #s, s))
    end

    -- Deterministic ordering for table keys of possibly mixed types
    local function key_less(a, b)
      local ta, tb = type(a), type(b)
      if ta ~= tb then
        return ta < tb
      end
      if ta == 'number' or ta == 'string' then
        return a < b
      end
      return tostring(a) < tostring(b)
    end

    local function rec_hash(k, v)
      local vt = type(v)
      if vt == 'string' or vt == 'number' or vt == 'boolean' then
        feed(type(k) .. ':' .. tostring(k))
        feed(vt)
        feed(tostring(v))
      elseif vt == 'table' then
        local keys = {}
        for kk in pairs(v) do
          keys[#keys + 1] = kk
        end
        table.sort(keys, key_less)

        feed(type(k) .. ':' .. tostring(k))
        feed('{')
        for _, kk in ipairs(keys) do
          rec_hash(kk, v[kk])
        end
        feed('}')
      end
    end

    rec_hash('top', params)
    return h:base32()
  end
  --[[
  -- Copies generic redis options from `options` into `redis_params`.
  -- Values already present in `redis_params` are kept, so callers can apply
  -- the most specific options first and broader defaults afterwards.
  -- `timeout` and `expand_keys` are the exception: a value still equal to the
  -- built-in default may be replaced by a later call.
  -- @param options raw option table (module or global `redis` section)
  -- @param redis_params table being filled
  --]]
  local function process_redis_opts(options, redis_params)
    local default_timeout = 1.0
    local default_expand_keys = false

    -- Accepts numbers, numeric strings and interval strings (e.g. '5s')
    local function to_interval(v)
      if type(v) == 'number' then
        return v
      elseif type(v) == 'string' then
        return tonumber(v) or lutil.parse_time_interval(v)
      end
      return nil
    end

    if not redis_params['timeout'] or redis_params['timeout'] == default_timeout then
      redis_params['timeout'] = to_interval(options['timeout']) or default_timeout
    end

    if options['prefix'] and not redis_params['prefix'] then
      redis_params['prefix'] = options['prefix']
    end

    -- Only replace an unset/default value: a module-level `expand_keys = true`
    -- must survive a later pass over global options that omit it
    if redis_params['expand_keys'] == nil
        or redis_params['expand_keys'] == default_expand_keys then
      if type(options['expand_keys']) == 'boolean' then
        redis_params['expand_keys'] = options['expand_keys']
      else
        redis_params['expand_keys'] = default_expand_keys
      end
    end

    if not redis_params['db'] then
      if options['db'] then
        redis_params['db'] = tostring(options['db'])
      elseif options['dbname'] then
        redis_params['db'] = tostring(options['dbname'])
      elseif options['database'] then
        redis_params['db'] = tostring(options['database'])
      end
    end

    if options['password'] and not redis_params['password'] then
      redis_params['password'] = options['password']
    end

    if not redis_params.sentinels and options.sentinels then
      redis_params.sentinels = options.sentinels
    end

    -- Sentinel tuning options declared in the config schema; without copying
    -- them here the sentinel code never sees configured values
    if redis_params.sentinel_masters_pattern == nil
        and type(options.sentinel_masters_pattern) == 'string' then
      redis_params.sentinel_masters_pattern = options.sentinel_masters_pattern
    end

    if redis_params.sentinel_watch_time == nil then
      redis_params.sentinel_watch_time = to_interval(options.sentinel_watch_time)
    end

    if redis_params.sentinel_master_maxerrors == nil
        and options.sentinel_master_maxerrors ~= nil then
      redis_params.sentinel_master_maxerrors = tonumber(options.sentinel_master_maxerrors)
    end
  end
  --[[
  -- Merges defaults from the global `redis` config section into
  -- `redis_params`. A module-specific subsection (`redis.<module>`) is applied
  -- first, then the section itself. Values already present are kept
  -- (see process_redis_opts).
  --]]
  local function enrich_defaults(rspamd_config, module, redis_params)
    if not rspamd_config then
      return
    end

    local global_opts = rspamd_config:get_all_opt('redis')
    if not global_opts then
      return
    end

    if module and global_opts[module] then
      process_redis_opts(global_opts[module], redis_params)
    end

    process_redis_opts(global_opts, redis_params)
  end
  --[[
  -- Deduplicates parsed redis params: if an identical configuration was
  -- already registered, the cached table is returned. Otherwise `redis_params`
  -- is stored, gets its `hash` field, and (for writable configs with
  -- sentinels) sentinel monitoring is set up.
  -- @param redis_params freshly parsed redis params
  -- @return the params table callers should use
  --]]
  local function maybe_return_cached(redis_params)
    local h = calculate_redis_hash(redis_params)
    local cached = cached_results[h]

    if cached then
      lutil.debugm(N, rspamd_config, 'reused redis server: %s', redis_params)
      return cached
    end

    redis_params.hash = h
    cached_results[h] = redis_params

    if not redis_params.read_only and redis_params.sentinels then
      add_redis_sentinels(redis_params)
    end

    lutil.debugm(N, rspamd_config, 'loaded new redis server: %s', redis_params)
    return redis_params
  end
  279. --[[[
  280. -- @module lua_redis
  281. -- This module contains helper functions for working with Redis
  282. --]]
  --[[
  -- Parses server definitions from `options` into `result`.
  -- Read servers come from the first of `read_servers`, `servers` or `server`.
  -- With `read_servers`, writes go only to `write_servers` (otherwise the
  -- config is read-only). With `servers`/`server`, the same upstreams serve
  -- writes unless `write_servers` is given.
  -- Generic options are copied via process_redis_opts.
  -- @return true when read upstreams were created, false otherwise
  --]]
  local function process_redis_options(options, rspamd_config, result)
    local default_port = 6379
    local upstream_list = require "rspamd_upstream_list"

    -- Creates an upstream list, passing the config object only when present
    local function make_upstreams(def)
      if rspamd_config then
        return upstream_list.create(rspamd_config, def, default_port)
      end
      return upstream_list.create(def, default_port)
    end

    local read_only = true
    local read_def
    if options['read_servers'] then
      read_def = options['read_servers']
    elseif options['servers'] then
      read_def = options['servers']
      read_only = false
    elseif options['server'] then
      read_def = options['server']
      read_only = false
    end

    local upstreams_read, upstreams_write
    if read_def then
      upstreams_read = make_upstreams(read_def)
      result.read_servers_str = read_def
    end

    if upstreams_read then
      local write_def = options['write_servers']
      if write_def then
        upstreams_write = make_upstreams(write_def)
        result.write_servers_str = write_def
        read_only = false
      elseif not read_only then
        upstreams_write = upstreams_read
        result.write_servers_str = result.read_servers_str
      end
    end

    -- Store options
    process_redis_opts(options, result)

    if upstreams_write then
      result.read_only = false
    elseif read_only then
      result.read_only = true
    end

    if not upstreams_read then
      lutil.debugm(N, rspamd_config,
          'cannot load redis server from obj: %s, processed to %s',
          options, result)
      return false
    end

    result.read_servers = upstreams_read
    if upstreams_write then
      result.write_servers = upstreams_write
    end

    return true
  end
  --[[[
  @function try_load_redis_servers(options, rspamd_config, no_fallback, module_name)
  Tries to load redis servers from the specified `options` object.
  Unless `no_fallback` is true, missing settings are filled from the global
  `redis` section (module subsection `module_name` first).
  Returns `redis_params` table or nil in case of failure
  --]]
  exports.try_load_redis_servers = function(options, rspamd_config, no_fallback, module_name)
    -- Nothing to parse; avoid indexing a nil options value
    if not options then
      return nil
    end

    local result = {}
    if not process_redis_options(options, rspamd_config, result) then
      return nil
    end

    if not no_fallback then
      enrich_defaults(rspamd_config, module_name, result)
    end

    return maybe_return_cached(result)
  end
  -- Parses the redis server definition for a module. Sources are tried in order:
  --   1. `<module>.redis` and then the module options themselves
  --      (or `module_opts` when provided);
  --   2. unless `no_fallback` is set, the global `redis` section: its
  --      `<module>` subsection if present, otherwise the section itself
  --      (skipped for modules listed in `disabled_modules`).
  -- Note: a single `result` table accumulates across attempts, so options
  -- copied by a failed attempt are visible to later ones.
  local function rspamd_parse_redis_server(module_name, module_opts, no_fallback)
    local result = {}

    -- Completes a successful module-level parse
    local function finish_with_defaults()
      if not no_fallback then
        enrich_defaults(rspamd_config, module_name, result)
      end
      return maybe_return_cached(result)
    end

    lutil.debugm(N, rspamd_config, 'try load redis config for: %s', module_name)

    -- Try local options
    local opts = module_opts or rspamd_config:get_all_opt(module_name)
    if opts then
      if opts.redis and process_redis_options(opts.redis, rspamd_config, result) then
        return finish_with_defaults()
      end
      if process_redis_options(opts, rspamd_config, result) then
        return finish_with_defaults()
      end
    end

    if no_fallback then
      logger.infox(rspamd_config, "cannot find Redis definitions for %s and fallback is disabled",
          module_name)
      return nil
    end

    -- Try global options
    local global_opts = rspamd_config:get_all_opt('redis')
    if global_opts then
      local module_section = global_opts[module_name]
      if module_section then
        if process_redis_options(module_section, rspamd_config, result) then
          return maybe_return_cached(result)
        end
      else
        local parsed = process_redis_options(global_opts, rspamd_config, result)

        -- Exclude disabled
        local disabled = global_opts['disabled_modules']
        if disabled then
          for _, name in ipairs(disabled) do
            if name == module_name then
              logger.infox(rspamd_config, "NOT using default redis server for module %s: it is disabled",
                  module_name)
              return nil
            end
          end
        end

        if parsed then
          logger.infox(rspamd_config, "use default Redis settings for %s",
              module_name)
          return maybe_return_cached(result)
        end
      end
    end

    if not result.read_servers then
      return nil
    end
    return maybe_return_cached(result)
  end
  --[[[
  -- @function lua_redis.parse_redis_server(module_name, module_opts, no_fallback)
  -- Extracts Redis server settings from configuration
  -- @param {string} module_name name of module to get settings for
  -- @param {table} module_opts settings for module or `nil` to fetch them from configuration
  -- @param {boolean} no_fallback should be `true` if global settings must not be used
  -- @return {table} redis server settings, or `nil` if no usable definition was found
  -- @example
  -- local rconfig = lua_redis.parse_redis_server('my_module')
  -- -- rconfig contains upstream_list objects in ['write_servers'] and ['read_servers']
  -- -- (['write_servers'] is absent for read-only configurations)
  -- -- ['timeout'] contains timeout in seconds
  -- -- ['expand_keys'] if true tells that redis key expansion is enabled
  --]]
  exports.rspamd_parse_redis_server, exports.parse_redis_server =
      rspamd_parse_redis_server, rspamd_parse_redis_server
  -- Maps lower-cased Redis command names to functions that return the
  -- (1-based) positions in the command's argument table that hold keys,
  -- i.e. the arguments subject to key expansion. Every handler returns a
  -- table, possibly empty.
  local process_cmd
  do
    -- {from, from + 1, ..., to}; empty when to < from
    local function index_range(from, to)
      local res = {}
      for i = from, to do
        res[#res + 1] = i
      end
      return res
    end

    -- Appends the key positions that follow a `numkeys` argument at position 2,
    -- capped at the number of arguments actually supplied
    local function append_numkeys(args, res)
      local numkeys = tonumber(args[2])
      if numkeys and numkeys >= 1 then
        for i = 3, math.min(numkeys + 2, #args) do
          res[#res + 1] = i
        end
      end
      return res
    end

    process_cmd = {
      -- BITOP operation destkey key [key ...]
      bitop = function(args)
        return index_range(2, #args)
      end,
      -- BLPOP key [key ...] timeout
      blpop = function(args)
        return index_range(1, #args - 1)
      end,
      -- EVAL script numkeys key [key ...] arg [arg ...]
      eval = function(args)
        return append_numkeys(args, {})
      end,
      -- ZUNIONSTORE destination numkeys key [key ...] [WEIGHTS ...] [AGGREGATE ...]
      zunionstore = function(args)
        return append_numkeys(args, {1})
      end,
      -- Only the first argument is a key
      set = function()
        return {1}
      end,
      -- Every argument is a key
      mget = function(args)
        return index_range(1, #args)
      end,
      -- MSET key value [key value ...]
      mset = function(args)
        local idx_l = {}
        for i = 1, #args, 2 do
          idx_l[#idx_l + 1] = i
        end
        return idx_l
      end,
      -- SMOVE source destination member
      smove = function()
        return {1, 2}
      end,
      -- No key arguments
      script = function()
        return {}
      end,
    }

    local function alias(handler, names)
      for _, name in ipairs(names) do
        process_cmd[name] = handler
      end
    end

    alias(process_cmd.set, {
      'append', 'bitcount', 'bitfield', 'bitpos', 'decr', 'decrby', 'dump',
      'expire', 'expireat', 'geoadd', 'geohash', 'geopos', 'geodist',
      'georadius', 'georadiusbymember', 'get', 'getbit', 'getrange', 'getset',
      'hdel', 'hexists', 'hget', 'hgetall', 'hincrby', 'hincrbyfloat', 'hkeys',
      'hlen', 'hmget', 'hmset', 'hscan', 'hset', 'hsetnx', 'hstrlen', 'hvals',
      'incr', 'incrby', 'incrbyfloat', 'lindex', 'linsert', 'llen', 'lpop',
      'lpush', 'lpushx', 'lrange', 'lrem', 'lset', 'ltrim', 'move', 'persist',
      'pexpire', 'pexpireat', 'pfadd', 'psetex', 'pttl', 'restore', 'rpop',
      'rpush', 'rpushx', 'sadd', 'scard', 'setbit', 'setex', 'setnx',
      'setrange', 'sismember', 'smembers', 'sort', 'spop', 'srandmember',
      'srem', 'sscan', 'strlen', 'ttl', 'type', 'zadd', 'zcard', 'zcount',
      'zincrby', 'zlexcount', 'zrange', 'zrangebylex', 'zrangebyscore',
      'zrank', 'zrem', 'zremrangebylex', 'zremrangebyrank', 'zremrangebyscore',
      'zrevrange', 'zrevrangebylex', 'zrevrangebyscore', 'zrevrank', 'zscan',
      'zscore',
      -- legacy (non-Redis) names kept for backward compatibility
      'zrembylex', 'zrembyrank', 'zrembyscore',
    })

    alias(process_cmd.mget, {
      'del', 'exists', 'pfcount', 'pfmerge', 'rename', 'renamenx', 'rpoplpush',
      'sdiff', 'sdiffstore', 'sinter', 'sinterstore', 'sunion', 'sunionstore',
      'touch', 'unlink', 'watch',
    })

    alias(process_cmd.script, {
      'auth', 'bgrewriteaof', 'bgsave', 'client', 'cluster', 'command',
      'config', 'dbsize', 'debug', 'discard', 'echo', 'exec', 'flushall',
      'flushdb', 'info', 'keys', 'lastsave', 'migrate', 'monitor', 'multi',
      'object', 'ping', 'psubscribe', 'pubsub', 'publish', 'punsubscribe',
      'quit', 'randomkey', 'readonly', 'readwrite', 'role', 'save', 'scan',
      'select', 'slaveof', 'slowlog', 'subscribe', 'swapdb', 'sync', 'time',
      'unsubscribe', 'unwatch', 'wait',
    })

    alias(process_cmd.blpop, {'brpop', 'brpoplpush'})
    alias(process_cmd.eval, {'evalsha'})
    alias(process_cmd.mset, {'msetnx'})
    alias(process_cmd.zunionstore, {'zinterstore'})
  end
  --[[
  -- Returns the argument positions holding keys for Redis command `cmd`
  -- (case-insensitive). It always returns a table, so callers can iterate
  -- the result directly. An unknown command logs a warning and yields an
  -- empty table.
  --]]
  local function get_key_indexes(cmd, args)
    local lcmd = string.lower(cmd)
    local handler = process_cmd[lcmd]

    if not handler then
      logger.warnx(rspamd_config, "Don't know how to extract keys for %s Redis command", lcmd)
      return {}
    end

    -- Guard against a nil args table and handlers that return nothing
    return handler(args or {}) or {}
  end
  -- Field `field` of the first sender returned by task:get_from(kind), or nil
  local function sender_field(task, kind, field)
    local senders = task:get_from(kind) or E
    local first = senders[1] or E
    return first[field]
  end

  -- Generators of the values available to key expansion templates,
  -- keyed by lower-case variable name; each takes a task
  local gen_meta = {}

  function gen_meta.principal_recipient(task)
    return task:get_principal_recipient()
  end

  function gen_meta.principal_recipient_domain(task)
    local rcpt = task:get_principal_recipient()
    if rcpt then
      return string.match(rcpt, '.*@(.*)')
    end
  end

  function gen_meta.ip(task)
    local addr = task:get_ip()
    if not (addr and addr:is_valid()) then
      return
    end
    return addr:to_string()
  end

  function gen_meta.from(task)
    return sender_field(task, 'smtp', 'addr')
  end

  function gen_meta.from_domain(task)
    return sender_field(task, 'smtp', 'domain')
  end

  function gen_meta.from_domain_or_helo_domain(task)
    local domain = sender_field(task, 'smtp', 'domain')
    if domain and #domain > 0 then
      return domain
    end
    return task:get_helo()
  end

  function gen_meta.mime_from(task)
    return sender_field(task, 'mime', 'addr')
  end

  function gen_meta.mime_from_domain(task)
    return sender_field(task, 'mime', 'domain')
  end

  -- Wraps a domain getter so that it yields rspamd_util.get_tld() of the domain
  local function gen_get_esld(getter)
    return function(task)
      local domain = getter(task)
      if domain then
        return rspamd_util.get_tld(domain)
      end
    end
  end

  gen_meta.smtp_from = gen_meta.from
  gen_meta.smtp_from_domain = gen_meta.from_domain
  gen_meta.smtp_from_domain_or_helo_domain = gen_meta.from_domain_or_helo_domain

  gen_meta.esld_principal_recipient_domain = gen_get_esld(gen_meta.principal_recipient_domain)
  gen_meta.esld_from_domain = gen_get_esld(gen_meta.from_domain)
  gen_meta.esld_smtp_from_domain = gen_meta.esld_from_domain
  gen_meta.esld_mime_from_domain = gen_get_esld(gen_meta.mime_from_domain)
  gen_meta.esld_from_domain_or_helo_domain = gen_get_esld(gen_meta.from_domain_or_helo_domain)
  gen_meta.esld_smtp_from_domain_or_helo_domain = gen_meta.esld_from_domain_or_helo_domain
  --[[
  -- Returns a lazily populated table for key expansion templates.
  -- Looking up a name (case-insensitively) calls the matching gen_meta
  -- generator with `task` once and memoizes the result in the table.
  -- Unknown names yield nil.
  --]]
  local function get_key_expansion_metadata(task)
    local lazy_meta = {}

    return setmetatable(lazy_meta, {
      __index = function(tbl, name)
        local lname = string.lower(name)
        local value = rawget(tbl, lname)
        if value then
          return value
        end

        local generator = gen_meta[lname]
        if generator then
          value = generator(task)
          rawset(tbl, lname, value)
        end

        return value
      end,
    })
  end
  -- Performs async call to redis hiding all complexity inside function
  -- task - rspamd_task
  -- redis_params - valid params returned by rspamd_parse_redis_server
  -- key - key to select upstream or nil to select round-robin/master-slave
  -- is_write - true if need to write to redis server
  -- callback - function to be called upon request is completed
  -- command - redis command
  -- args - table of arguments (key arguments are templated in place when
  --        redis_params.expand_keys is set)
  -- extra_opts - table of optional request arguments
  -- Returns the two values of rspamd_redis.make_request plus the selected
  -- upstream, or false, nil, nil when mandatory arguments are missing or no
  -- upstream can be selected.
  local function rspamd_redis_make_request(task, redis_params, key, is_write,
                                           callback, command, args, extra_opts)
    local addr
    local function rspamd_redis_make_request_cb(err, data)
      if err then
        addr:fail()
      else
        addr:ok()
      end
      callback(err, data, addr)
    end

    if not task or not redis_params or not callback or not command then
      return false, nil, nil
    end

    local rspamd_redis = require "rspamd_redis"

    -- Read-only configurations have no `write_servers`
    local servers
    if is_write then
      servers = redis_params['write_servers']
    else
      servers = redis_params['read_servers']
    end

    if servers then
      if key then
        addr = servers:get_upstream_by_hash(key)
      elseif is_write then
        addr = servers:get_upstream_master_slave(key)
      else
        addr = servers:get_upstream_round_robin(key)
      end
    end

    if not addr then
      logger.errx(task, 'cannot select server to make redis request')
      return false, nil, nil
    end

    if redis_params['expand_keys'] and args then
      local m = get_key_expansion_metadata(task)
      local indexes = get_key_indexes(command, args) or {}
      for _, i in ipairs(indexes) do
        -- Only string arguments can be templated
        if type(args[i]) == 'string' then
          args[i] = lutil.template(args[i], m)
        end
      end
    end

    local ip_addr = addr:get_addr()
    local options = {
      task = task,
      callback = rspamd_redis_make_request_cb,
      host = ip_addr,
      timeout = redis_params['timeout'],
      cmd = command,
      args = args
    }

    if extra_opts then
      for k, v in pairs(extra_opts) do
        options[k] = v
      end
    end

    if redis_params['password'] then
      options['password'] = redis_params['password']
    end

    if redis_params['db'] then
      options['dbname'] = redis_params['db']
    end

    lutil.debugm(N, task, 'perform request to redis server' ..
        ' (host=%s, timeout=%s): cmd: %s, arguments: %s', ip_addr,
        options.timeout, options.cmd, args)

    local ret, conn = rspamd_redis.make_request(options)
    if not ret then
      addr:fail()
      logger.warnx(task, "cannot make redis request to: %s", tostring(ip_addr))
    end

    return ret, conn, addr
  end
  --[[[
  -- @function lua_redis.redis_make_request(task, redis_params, key, is_write, callback, command, args, extra_opts)
  -- Sends a request to Redis
  -- @param {rspamd_task} task task object
  -- @param {table} redis_params redis configuration in format returned by lua_redis.parse_redis_server()
  -- @param {string} key key to use for sharding (or `nil` for round-robin/master-slave selection)
  -- @param {boolean} is_write should be `true` if we are performing a write operation
  -- @param {function} callback callback function (first parameter is error if applicable, second is a 2D array (table))
  -- @param {string} command Redis command to run
  -- @param {table} args Numerically indexed table containing arguments for command
  -- @param {table} extra_opts optional table whose fields are copied into the request options
  --]]
  exports.rspamd_redis_make_request, exports.redis_make_request =
      rspamd_redis_make_request, rspamd_redis_make_request
  --[[
  -- Same as rspamd_redis_make_request but for contexts without a task: the
  -- request is bound to `ev_base` and `cfg` instead. No key expansion is done.
  -- Returns the two values of rspamd_redis.make_request plus the selected
  -- upstream, or false, nil, nil when mandatory arguments are missing or no
  -- upstream can be selected.
  --]]
  local function redis_make_request_taskless(ev_base, cfg, redis_params, key,
                                             is_write, callback, command, args, extra_opts)
    if not ev_base or not redis_params or not callback or not command then
      return false, nil, nil
    end

    local addr
    local function rspamd_redis_make_request_cb(err, data)
      if err then
        addr:fail()
      else
        addr:ok()
      end
      callback(err, data, addr)
    end

    local rspamd_redis = require "rspamd_redis"

    -- Read-only configurations have no `write_servers`
    local servers
    if is_write then
      servers = redis_params['write_servers']
    else
      servers = redis_params['read_servers']
    end

    if servers then
      if key then
        addr = servers:get_upstream_by_hash(key)
      elseif is_write then
        addr = servers:get_upstream_master_slave(key)
      else
        addr = servers:get_upstream_round_robin(key)
      end
    end

    if not addr then
      logger.errx(cfg, 'cannot select server to make redis request')
      return false, nil, nil
    end

    local options = {
      ev_base = ev_base,
      config = cfg,
      callback = rspamd_redis_make_request_cb,
      host = addr:get_addr(),
      timeout = redis_params['timeout'],
      cmd = command,
      args = args
    }

    if extra_opts then
      for k, v in pairs(extra_opts) do
        options[k] = v
      end
    end

    if redis_params['password'] then
      options['password'] = redis_params['password']
    end

    if redis_params['db'] then
      options['dbname'] = redis_params['db']
    end

    lutil.debugm(N, cfg, 'perform taskless request to redis server' ..
        ' (host=%s, timeout=%s): cmd: %s, arguments: %s', options.host,
        options.timeout, options.cmd, args)

    local ret, conn = rspamd_redis.make_request(options)
    if not ret then
      logger.errx(cfg, 'cannot execute redis request')
      addr:fail()
    end

    return ret, conn, addr
  end
  886. --[[[
  887. -- @function lua_redis.redis_make_request_taskless(ev_base, redis_params, key, is_write, callback, command, args)
  888. -- Sends a request to Redis in context where `task` is not available for some specific use-cases
  889. -- Identical to redis_make_request() except in that first parameter is an `event base` object
  890. --]]
  891. exports.rspamd_redis_make_request_taskless = redis_make_request_taskless
  892. exports.redis_make_request_taskless = redis_make_request_taskless
  893. local redis_scripts = {
  894. }
  895. local function script_set_loaded(script)
  896. if script.sha then
  897. script.loaded = true
  898. end
  899. local wait_table = {}
  900. for _,s in ipairs(script.waitq) do
  901. table.insert(wait_table, s)
  902. end
  903. script.waitq = {}
  904. for _,s in ipairs(wait_table) do
  905. s(script.loaded)
  906. end
  907. end
  908. local function prepare_redis_call(script)
  909. local function merge_tables(t1, t2)
  910. for k,v in pairs(t2) do t1[k] = v end
  911. end
  912. local servers = {}
  913. local options = {}
  914. if script.redis_params.read_servers then
  915. merge_tables(servers, script.redis_params.read_servers:all_upstreams())
  916. end
  917. if script.redis_params.write_servers then
  918. merge_tables(servers, script.redis_params.write_servers:all_upstreams())
  919. end
  920. -- Call load script on each server, set loaded flag
  921. script.in_flight = #servers
  922. for _,s in ipairs(servers) do
  923. local cur_opts = {
  924. host = s:get_addr(),
  925. timeout = script.redis_params['timeout'],
  926. cmd = 'SCRIPT',
  927. args = {'LOAD', script.script },
  928. upstream = s
  929. }
  930. if script.redis_params['password'] then
  931. cur_opts['password'] = script.redis_params['password']
  932. end
  933. if script.redis_params['db'] then
  934. cur_opts['dbname'] = script.redis_params['db']
  935. end
  936. table.insert(options, cur_opts)
  937. end
  938. return options
  939. end
  940. local function load_script_task(script, task)
  941. local rspamd_redis = require "rspamd_redis"
  942. local opts = prepare_redis_call(script)
  943. for _,opt in ipairs(opts) do
  944. opt.task = task
  945. opt.callback = function(err, data)
  946. if err then
  947. logger.errx(task, 'cannot upload script to %s: %s; registered from: %s:%s',
  948. opt.upstream:get_addr(), err, script.caller.short_src, script.caller.currentline)
  949. opt.upstream:fail()
  950. script.fatal_error = err
  951. else
  952. opt.upstream:ok()
  953. logger.infox(task,
  954. "uploaded redis script to %s with id %s, sha: %s",
  955. opt.upstream:get_addr(), script.id, data)
  956. script.sha = data -- We assume that sha is the same on all servers
  957. end
  958. script.in_flight = script.in_flight - 1
  959. if script.in_flight == 0 then
  960. script_set_loaded(script)
  961. end
  962. end
  963. local ret = rspamd_redis.make_request(opt)
  964. if not ret then
  965. logger.errx('cannot execute redis request to load script on %s',
  966. opt.upstream:get_addr())
  967. script.in_flight = script.in_flight - 1
  968. opt.upstream:fail()
  969. end
  970. if script.in_flight == 0 then
  971. script_set_loaded(script)
  972. end
  973. end
  974. end
  975. local function load_script_taskless(script, cfg, ev_base)
  976. local rspamd_redis = require "rspamd_redis"
  977. local opts = prepare_redis_call(script)
  978. for _,opt in ipairs(opts) do
  979. opt.config = cfg
  980. opt.ev_base = ev_base
  981. opt.callback = function(err, data)
  982. if err then
  983. logger.errx(cfg, 'cannot upload script to %s: %s; registered from: %s:%s',
  984. opt.upstream:get_addr(), err, script.caller.short_src, script.caller.currentline)
  985. opt.upstream:fail()
  986. script.fatal_error = err
  987. else
  988. opt.upstream:ok()
  989. logger.infox(cfg,
  990. "uploaded redis script to %s with id %s, sha: %s",
  991. opt.upstream:get_addr(), script.id, data)
  992. script.sha = data -- We assume that sha is the same on all servers
  993. script.fatal_error = nil
  994. end
  995. script.in_flight = script.in_flight - 1
  996. if script.in_flight == 0 then
  997. script_set_loaded(script)
  998. end
  999. end
  1000. local ret = rspamd_redis.make_request(opt)
  1001. if not ret then
  1002. logger.errx('cannot execute redis request to load script on %s',
  1003. opt.upstream:get_addr())
  1004. script.in_flight = script.in_flight - 1
  1005. opt.upstream:fail()
  1006. end
  1007. if script.in_flight == 0 then
  1008. script_set_loaded(script)
  1009. end
  1010. end
  1011. end
  1012. local function load_redis_script(script, cfg, ev_base, _)
  1013. if script.redis_params then
  1014. load_script_taskless(script, cfg, ev_base)
  1015. end
  1016. end
  1017. local function add_redis_script(script, redis_params)
  1018. local caller = debug.getinfo(2)
  1019. local new_script = {
  1020. caller = caller,
  1021. loaded = false,
  1022. redis_params = redis_params,
  1023. script = script,
  1024. waitq = {}, -- callbacks pending for script being loaded
  1025. id = #redis_scripts + 1
  1026. }
  1027. -- Register on load function
  1028. rspamd_config:add_on_load(function(cfg, ev_base, worker)
  1029. local mult = 0.0
  1030. rspamd_config:add_periodic(ev_base, 0.0, function()
  1031. if not new_script.sha then
  1032. load_redis_script(new_script, cfg, ev_base, worker)
  1033. mult = mult + 1
  1034. return 1.0 * mult -- Check one more time in one second
  1035. end
  1036. return false
  1037. end, false)
  1038. end)
  1039. table.insert(redis_scripts, new_script)
  1040. return #redis_scripts
  1041. end
  1042. exports.add_redis_script = add_redis_script
  1043. local function exec_redis_script(id, params, callback, keys, args)
  1044. local redis_args = {}
  1045. if not redis_scripts[id] then
  1046. logger.errx("cannot find registered script with id %s", id)
  1047. return false
  1048. end
  1049. local script = redis_scripts[id]
  1050. if script.fatal_error then
  1051. callback(script.fatal_error, nil)
  1052. return true
  1053. end
  1054. if not script.redis_params then
  1055. callback('no redis servers defined', nil)
  1056. return true
  1057. end
  1058. local function do_call(can_reload)
  1059. local function redis_cb(err, data)
  1060. if not err then
  1061. callback(err, data)
  1062. elseif string.match(err, 'NOSCRIPT') then
  1063. -- Schedule restart
  1064. script.sha = nil
  1065. if can_reload then
  1066. table.insert(script.waitq, do_call)
  1067. if script.in_flight == 0 then
  1068. -- Reload scripts if this has not been initiated yet
  1069. if params.task then
  1070. load_script_task(script, params.task)
  1071. else
  1072. load_script_taskless(script, rspamd_config, params.ev_base)
  1073. end
  1074. end
  1075. else
  1076. callback(err, data)
  1077. end
  1078. else
  1079. callback(err, data)
  1080. end
  1081. end
  1082. if #redis_args == 0 then
  1083. table.insert(redis_args, script.sha)
  1084. table.insert(redis_args, tostring(#keys))
  1085. for _,k in ipairs(keys) do
  1086. table.insert(redis_args, k)
  1087. end
  1088. if type(args) == 'table' then
  1089. for _, a in ipairs(args) do
  1090. table.insert(redis_args, a)
  1091. end
  1092. end
  1093. end
  1094. if params.task then
  1095. if not rspamd_redis_make_request(params.task, script.redis_params,
  1096. params.key, params.is_write, redis_cb, 'EVALSHA', redis_args) then
  1097. callback('Cannot make redis request', nil)
  1098. end
  1099. else
  1100. if not redis_make_request_taskless(params.ev_base, rspamd_config,
  1101. script.redis_params,
  1102. params.key, params.is_write, redis_cb, 'EVALSHA', redis_args) then
  1103. callback('Cannot make redis request', nil)
  1104. end
  1105. end
  1106. end
  1107. if script.loaded then
  1108. do_call(true)
  1109. else
  1110. -- Delayed until scripts are loaded
  1111. if not params.task then
  1112. table.insert(script.waitq, do_call)
  1113. else
  1114. -- TODO: fix taskfull requests
  1115. callback('NOSCRIPT', nil)
  1116. end
  1117. end
  1118. return true
  1119. end
  1120. exports.exec_redis_script = exec_redis_script
  1121. local function redis_connect_sync(redis_params, is_write, key, cfg, ev_base)
  1122. if not redis_params then
  1123. return false,nil
  1124. end
  1125. local rspamd_redis = require "rspamd_redis"
  1126. local addr
  1127. if key then
  1128. if is_write then
  1129. addr = redis_params['write_servers']:get_upstream_by_hash(key)
  1130. else
  1131. addr = redis_params['read_servers']:get_upstream_by_hash(key)
  1132. end
  1133. else
  1134. if is_write then
  1135. addr = redis_params['write_servers']:get_upstream_master_slave(key)
  1136. else
  1137. addr = redis_params['read_servers']:get_upstream_round_robin(key)
  1138. end
  1139. end
  1140. if not addr then
  1141. logger.errx(cfg, 'cannot select server to make redis request')
  1142. end
  1143. local options = {
  1144. host = addr:get_addr(),
  1145. timeout = redis_params['timeout'],
  1146. config = cfg or rspamd_config,
  1147. ev_base = ev_base or rspamadm_ev_base,
  1148. session = redis_params.session or rspamadm_session
  1149. }
  1150. for k,v in pairs(redis_params) do
  1151. options[k] = v
  1152. end
  1153. if not options.config then
  1154. logger.errx('config is not set')
  1155. return false,nil,addr
  1156. end
  1157. if not options.ev_base then
  1158. logger.errx('ev_base is not set')
  1159. return false,nil,addr
  1160. end
  1161. if not options.session then
  1162. logger.errx('session is not set')
  1163. return false,nil,addr
  1164. end
  1165. local ret,conn = rspamd_redis.connect_sync(options)
  1166. if not ret then
  1167. logger.errx('cannot execute redis request: %s', conn)
  1168. addr:fail()
  1169. return false,nil,addr
  1170. end
  1171. if conn then
  1172. if redis_params['password'] then
  1173. conn:add_cmd('AUTH', {redis_params['password']})
  1174. end
  1175. if redis_params['db'] then
  1176. conn:add_cmd('SELECT', {tostring(redis_params['db'])})
  1177. elseif redis_params['dbname'] then
  1178. conn:add_cmd('SELECT', {tostring(redis_params['dbname'])})
  1179. end
  1180. end
  1181. return ret,conn,addr
  1182. end
  1183. exports.redis_connect_sync = redis_connect_sync
  1184. --[[[
  1185. -- @function lua_redis.request(redis_params, attrs, req)
  1186. -- Sends a request to Redis synchronously with coroutines or asynchronously using
  1187. -- a callback (modern API)
  1188. -- @param redis_params a table of redis server parameters
  1189. -- @param attrs a table of redis request attributes (e.g. task, or ev_base + cfg + session)
  1190. -- @param req a table of request: a command + command options
  1191. -- @return {result,data/connection,address} boolean result, connection object in case of async request and results if using coroutines, redis server address
  1192. --]]
  1193. exports.request = function(redis_params, attrs, req)
  1194. local lua_util = require "lua_util"
  1195. if not attrs or not redis_params or not req then
  1196. logger.errx('invalid arguments for redis request')
  1197. return false,nil,nil
  1198. end
  1199. if not (attrs.task or (attrs.config and attrs.ev_base)) then
  1200. logger.errx('invalid attributes for redis request')
  1201. return false,nil,nil
  1202. end
  1203. local opts = lua_util.shallowcopy(attrs)
  1204. local log_obj = opts.task or opts.config
  1205. local addr
  1206. if opts.callback then
  1207. -- Wrap callback
  1208. local callback = opts.callback
  1209. local function rspamd_redis_make_request_cb(err, data)
  1210. if err then
  1211. addr:fail()
  1212. else
  1213. addr:ok()
  1214. end
  1215. callback(err, data, addr)
  1216. end
  1217. opts.callback = rspamd_redis_make_request_cb
  1218. end
  1219. local rspamd_redis = require "rspamd_redis"
  1220. local is_write = opts.is_write
  1221. if opts.key then
  1222. if is_write then
  1223. addr = redis_params['write_servers']:get_upstream_by_hash(attrs.key)
  1224. else
  1225. addr = redis_params['read_servers']:get_upstream_by_hash(attrs.key)
  1226. end
  1227. else
  1228. if is_write then
  1229. addr = redis_params['write_servers']:get_upstream_master_slave(attrs.key)
  1230. else
  1231. addr = redis_params['read_servers']:get_upstream_round_robin(attrs.key)
  1232. end
  1233. end
  1234. if not addr then
  1235. logger.errx(log_obj, 'cannot select server to make redis request')
  1236. end
  1237. opts.host = addr:get_addr()
  1238. opts.timeout = redis_params.timeout
  1239. if type(req) == 'string' then
  1240. opts.cmd = req
  1241. else
  1242. -- XXX: modifies the input table
  1243. opts.cmd = table.remove(req, 1);
  1244. opts.args = req
  1245. end
  1246. if redis_params.password then
  1247. opts.password = redis_params.password
  1248. end
  1249. if redis_params.db then
  1250. opts.dbname = redis_params.db
  1251. end
  1252. lutil.debugm(N, 'perform generic request to redis server' ..
  1253. ' (host=%s, timeout=%s): cmd: %s, arguments: %s', addr,
  1254. opts.timeout, opts.cmd, opts.args)
  1255. if opts.callback then
  1256. local ret,conn = rspamd_redis.make_request(opts)
  1257. if not ret then
  1258. logger.errx(log_obj, 'cannot execute redis request')
  1259. addr:fail()
  1260. end
  1261. return ret,conn,addr
  1262. else
  1263. -- Coroutines version
  1264. local ret,conn = rspamd_redis.connect_sync(opts)
  1265. if not ret then
  1266. logger.errx(log_obj, 'cannot execute redis request')
  1267. addr:fail()
  1268. else
  1269. conn:add_cmd(opts.cmd, opts.args)
  1270. return conn:exec()
  1271. end
  1272. return false,nil,addr
  1273. end
  1274. end
  1275. --[[[
  1276. -- @function lua_redis.connect(redis_params, attrs)
  1277. -- Connects to Redis synchronously with coroutines or asynchronously using a callback (modern API)
  1278. -- @param redis_params a table of redis server parameters
  1279. -- @param attrs a table of redis request attributes (e.g. task, or ev_base + cfg + session)
  1280. -- @return {result,connection,address} boolean result, connection object, redis server address
  1281. --]]
  1282. exports.connect = function(redis_params, attrs)
  1283. local lua_util = require "lua_util"
  1284. if not attrs or not redis_params then
  1285. logger.errx('invalid arguments for redis connect')
  1286. return false,nil,nil
  1287. end
  1288. if not (attrs.task or (attrs.config and attrs.ev_base)) then
  1289. logger.errx('invalid attributes for redis connect')
  1290. return false,nil,nil
  1291. end
  1292. local opts = lua_util.shallowcopy(attrs)
  1293. local log_obj = opts.task or opts.config
  1294. local addr
  1295. if opts.callback then
  1296. -- Wrap callback
  1297. local callback = opts.callback
  1298. local function rspamd_redis_make_request_cb(err, data)
  1299. if err then
  1300. addr:fail()
  1301. else
  1302. addr:ok()
  1303. end
  1304. callback(err, data, addr)
  1305. end
  1306. opts.callback = rspamd_redis_make_request_cb
  1307. end
  1308. local rspamd_redis = require "rspamd_redis"
  1309. local is_write = opts.is_write
  1310. if opts.key then
  1311. if is_write then
  1312. addr = redis_params['write_servers']:get_upstream_by_hash(attrs.key)
  1313. else
  1314. addr = redis_params['read_servers']:get_upstream_by_hash(attrs.key)
  1315. end
  1316. else
  1317. if is_write then
  1318. addr = redis_params['write_servers']:get_upstream_master_slave(attrs.key)
  1319. else
  1320. addr = redis_params['read_servers']:get_upstream_round_robin(attrs.key)
  1321. end
  1322. end
  1323. if not addr then
  1324. logger.errx(log_obj, 'cannot select server to make redis connect')
  1325. end
  1326. opts.host = addr:get_addr()
  1327. opts.timeout = redis_params.timeout
  1328. if redis_params.password then
  1329. opts.password = redis_params.password
  1330. end
  1331. if redis_params.db then
  1332. opts.dbname = redis_params.db
  1333. end
  1334. if opts.callback then
  1335. local ret,conn = rspamd_redis.connect(opts)
  1336. if not ret then
  1337. logger.errx(log_obj, 'cannot execute redis connect')
  1338. addr:fail()
  1339. end
  1340. return ret,conn,addr
  1341. else
  1342. -- Coroutines version
  1343. local ret,conn = rspamd_redis.connect_sync(opts)
  1344. if not ret then
  1345. logger.errx(log_obj, 'cannot execute redis connect')
  1346. addr:fail()
  1347. else
  1348. return true,conn,addr
  1349. end
  1350. return false,nil,addr
  1351. end
  1352. end
  1353. local redis_prefixes = {}
  1354. --[[[
  1355. -- @function lua_redis.register_prefix(prefix, module, description[, optional])
  1356. -- Register new redis prefix for documentation purposes
  1357. -- @param {string} prefix string prefix
  1358. -- @param {string} module module name
  1359. -- @param {string} description prefix description
  1360. -- @param {table} optional optional kv pairs (e.g. pattern)
  1361. --]]
  1362. local function register_prefix(prefix, module, description, optional)
  1363. local pr = {
  1364. module = module,
  1365. description = description
  1366. }
  1367. if optional and type(optional) == 'table' then
  1368. for k,v in pairs(optional) do
  1369. pr[k] = v
  1370. end
  1371. end
  1372. redis_prefixes[prefix] = pr
  1373. end
  1374. exports.register_prefix = register_prefix
  1375. --[[[
  1376. -- @function lua_redis.prefixes([mname])
  1377. -- Returns prefixes for specific module (or all prefixes). Returns a table prefix -> table
  1378. --]]
  1379. exports.prefixes = function(mname)
  1380. if not mname then
  1381. return redis_prefixes
  1382. else
  1383. local fun = require "fun"
  1384. return fun.totable(fun.filter(function(_, data) return data.module == mname end,
  1385. redis_prefixes))
  1386. end
  1387. end
  1388. return exports