You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

lua_redis.lua 46KB


  1. --[[
  2. Copyright (c) 2017, Vsevolod Stakhov <vsevolod@highsecure.ru>
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ]]--
  13. local logger = require "rspamd_logger"
  14. local lutil = require "lua_util"
  15. local rspamd_util = require "rspamd_util"
  16. local ts = require("tableshape").types
  17. local exports = {}
  18. local E = {}
  19. local N = "lua_redis"
  -- Database selector accepted by every Redis definition form. It may be
  -- written as a number (`db = 1`) or a string (`db = "1"`). Numbers are
  -- normalised to strings, matching process_redis_opts, which always
  -- tostring()s the value.
  local db_schema = (ts.number / tostring + ts.string):is_optional()

  -- Options shared by all server-definition forms below.
  local common_schema = ts.shape {
    timeout = (ts.number + ts.string / lutil.parse_time_interval):is_optional(),
    db = db_schema,
    database = db_schema,
    dbname = db_schema,
    prefix = ts.string:is_optional(),
    password = ts.string:is_optional(),
    expand_keys = ts.boolean:is_optional(),
    sentinels = (ts.string + ts.array_of(ts.string)):is_optional(),
    sentinel_watch_time = (ts.number + ts.string / lutil.parse_time_interval):is_optional(),
    sentinel_masters_pattern = ts.string:is_optional(),
    sentinel_master_maxerrors = (ts.number + ts.string / tonumber):is_optional(),
  }

  -- A Redis definition is one of:
  --   * separate read_servers + write_servers
  --   * servers (used for both reads and writes)
  --   * server (same as servers)
  -- each combined with the common options above.
  local config_schema =
    ts.shape({
      read_servers = ts.string + ts.array_of(ts.string),
      write_servers = ts.string + ts.array_of(ts.string),
    }, {extra_opts = common_schema}) +
    ts.shape({
      servers = ts.string + ts.array_of(ts.string),
    }, {extra_opts = common_schema}) +
    ts.shape({
      server = ts.string + ts.array_of(ts.string),
    }, {extra_opts = common_schema})

  exports.config_schema = config_schema
  -- Converts a flat Redis reply {k1, v1, k2, v2, ...} into {k1 = v1, k2 = v2}.
  local function sentinel_flatten_reply(tbl)
    local res = {}
    for i = 1, #tbl, 2 do
      res[tbl[i]] = tbl[i + 1]
    end
    return res
  end

  -- Wraps an IPv6 literal (any address containing ':') in brackets so that
  -- it can be joined with a port as "host:port".
  local function sentinel_format_ip(ip)
    if ip:match(":") then
      return "[" .. ip .. "]"
    end
    return ip
  end

  --[[
  Queries one Redis Sentinel (chosen round-robin from params.sentinels) for
  its masters and each master's slaves. It then rebuilds params.read_servers
  and params.write_servers (plus the *_str fields) whenever the resulting
  comma-separated server lists differ from the current ones.
  - masters are used for writes and reads;
  - slaves whose 'master-link-status' is 'ok' are added to reads only;
  - when params.sentinel_masters_pattern is set, only masters whose name
    matches it are kept.
  A failure watcher is attached to new write upstreams. After more than
  params.sentinel_master_maxerrors failures it re-queries the sentinel
  (once per upstream list).
  `initialised` is part of the call signature but is not used here.
  --]]
  local function redis_query_sentinel(ev_base, params, initialised)
    local rspamd_redis = require "rspamd_redis"
    local addr = params.sentinels:get_upstream_round_robin()

    local is_ok, connection = rspamd_redis.connect_sync({
      host = addr:get_addr(),
      timeout = params.timeout,
      config = rspamd_config,
      ev_base = ev_base,
      no_pool = true,
    })

    if not is_ok then
      logger.errx(rspamd_config, 'cannot connect sentinel at address: %s',
          tostring(addr:get_addr()))
      addr:fail()
      return
    end

    -- Get masters list
    connection:add_cmd('SENTINEL', {'masters'})
    local ok, result = connection:exec()

    if not (ok and result and type(result) == 'table') then
      logger.errx(rspamd_config, 'cannot get data from Redis Sentinel %s: %s',
          addr:get_addr():to_string(true), result)
      addr:fail()
      return
    end

    local masters = {}
    for _, m in ipairs(result) do
      -- Skip malformed entries instead of failing the whole refresh
      if type(m) == 'table' then
        local master = sentinel_flatten_reply(m)
        if master.name and master.ip then
          master.ip = sentinel_format_ip(master.ip)
          local pattern = params.sentinel_masters_pattern
          if pattern and not master.name:match(pattern) then
            lutil.debugm(N, 'skip master %s with ip %s and port %s, pattern %s',
                master.name, master.ip, master.port, pattern)
          else
            lutil.debugm(N, 'found master %s with ip %s and port %s',
                master.name, master.ip, master.port)
            masters[master.name] = master
          end
        end
      end
    end

    -- For each master we need to get a list of slaves
    for name, master in pairs(masters) do
      master.slaves = {}
      connection:add_cmd('SENTINEL', {'slaves', name})
      local slaves_ok, slave_result = connection:exec()
      if slaves_ok and type(slave_result) == 'table' then
        for _, s in ipairs(slave_result) do
          if type(s) == 'table' then
            local slave = sentinel_flatten_reply(s)
            lutil.debugm(N, rspamd_config,
                'found slave form master %s with ip %s and port %s',
                master.name, slave.ip, slave.port)
            if slave.ip then
              slave.ip = sentinel_format_ip(slave.ip)
              master.slaves[#master.slaves + 1] = slave
            end
          end
        end
      end
    end

    -- We now form new strings for masters and slaves
    local read_servers_tbl, write_servers_tbl = {}, {}
    for _, master in pairs(masters) do
      local master_str = string.format('%s:%s', master.ip, master.port)
      write_servers_tbl[#write_servers_tbl + 1] = master_str
      read_servers_tbl[#read_servers_tbl + 1] = master_str
      for _, slave in ipairs(master.slaves) do
        if slave['master-link-status'] == 'ok' then
          read_servers_tbl[#read_servers_tbl + 1] = string.format(
              '%s:%s', slave.ip, slave.port)
        end
      end
    end

    -- Sorting makes the strings comparable with the previously stored ones
    table.sort(read_servers_tbl)
    table.sort(write_servers_tbl)

    local read_servers_str = table.concat(read_servers_tbl, ',')
    local write_servers_str = table.concat(write_servers_tbl, ',')

    lutil.debugm(N, rspamd_config,
        'new servers list: %s read; %s write',
        read_servers_str,
        write_servers_str)

    local upstream_list = require "rspamd_upstream_list"

    if read_servers_str ~= params.read_servers_str then
      local read_upstreams = upstream_list.create(rspamd_config,
          read_servers_str, 6379)
      if read_upstreams then
        logger.infox(rspamd_config, 'sentinel %s: replace read servers with new list: %s',
            addr:get_addr():to_string(true), read_servers_str)
        params.read_servers = read_upstreams
        params.read_servers_str = read_servers_str
      end
    end

    if write_servers_str ~= params.write_servers_str then
      local write_upstreams = upstream_list.create(rspamd_config,
          write_servers_str, 6379)
      if write_upstreams then
        logger.infox(rspamd_config, 'sentinel %s: replace write servers with new list: %s',
            addr:get_addr():to_string(true), write_servers_str)
        params.write_servers = write_upstreams
        params.write_servers_str = write_servers_str

        -- Re-query the sentinel at most once per upstream list when a
        -- master keeps failing.
        local queried = false
        local function monitor_failures(up, _, count)
          if count > params.sentinel_master_maxerrors and not queried then
            logger.infox(rspamd_config, 'sentinel: master with address %s, caused %s failures, try to query sentinel',
                up:get_addr():to_string(true), count)
            queried = true -- Avoid multiple checks caused by this monitor
            redis_query_sentinel(ev_base, params, true)
          end
        end

        write_upstreams:add_watcher('failure', monitor_failures)
      end
    end

    addr:ok()
  end
  --[[
  Replaces params.sentinels (a servers definition) with an upstream list and
  fills in defaults:
  - sentinel_watch_time: 60 seconds between sentinel queries;
  - sentinel_master_maxerrors: 2 failures before re-querying.
  On load, every scanner worker registers a periodic job that queries the
  sentinel immediately and then every sentinel_watch_time seconds.
  Logs an error and returns without changes if the sentinels string cannot
  be parsed.
  --]]
  local function add_redis_sentinels(params)
    local upstream_list = require "rspamd_upstream_list"
    local sentinels = upstream_list.create(rspamd_config, params.sentinels, 5000)

    if not sentinels then
      logger.errx(rspamd_config, 'cannot load redis sentinels string: %s',
          params.sentinels)
      return
    end

    params.sentinels = sentinels
    params.sentinel_watch_time = params.sentinel_watch_time or 60 -- Each minute
    -- Maximum number of errors before rechecking
    params.sentinel_master_maxerrors = params.sentinel_master_maxerrors or 2

    rspamd_config:add_on_load(function(_, ev_base, worker)
      if not worker:is_scanner() then
        return
      end

      local initialised = false
      rspamd_config:add_periodic(ev_base, 0.0, function()
        redis_query_sentinel(ev_base, params, initialised)
        initialised = true
        return params.sentinel_watch_time
      end, false)
    end)
  end
  -- Parsed redis parameter tables, keyed by calculate_redis_hash(params)
  local cached_results = {}

  -- Strict total order over table keys of mixed types. pairs() traversal
  -- order is unspecified in Lua, so keys are sorted with this before hashing.
  local function redis_hash_key_less(a, b)
    local ta, tb = type(a), type(b)
    if ta ~= tb then
      return ta < tb
    end
    if ta == 'number' or ta == 'string' then
      return a < b
    end
    return tostring(a) < tostring(b)
  end

  --[[
  Computes a deterministic digest of the scalar contents of `params`:
  strings, numbers and booleans, recursing into nested tables. Other
  values (e.g. upstream list userdata, functions) are ignored.
  - Every key and value is length-prefixed, so adjacent items cannot run
    together (e.g. {ab = 'c'} vs {a = 'bc'}).
  - Booleans such as expand_keys are included, so configurations that
    differ only in a flag do not share a cache entry.
  - Keys are converted with tostring() before being fed to the hash, so
    array-valued options (numeric keys) hash correctly.
  Returns the base32 form of the digest.
  --]]
  local function calculate_redis_hash(params)
    local cr = require "rspamd_cryptobox_hash"
    local h = cr.create()

    local function feed(x)
      local s = tostring(x)
      h:update(#s .. ':')
      h:update(s)
    end

    local function rec_hash(k, v)
      local vt = type(v)
      if vt == 'string' or vt == 'number' or vt == 'boolean' then
        feed(k)
        feed(v)
      elseif vt == 'table' then
        local keys = {}
        for kk in pairs(v) do
          keys[#keys + 1] = kk
        end
        table.sort(keys, redis_hash_key_less)
        feed(k)
        h:update('{')
        for _, kk in ipairs(keys) do
          rec_hash(kk, v[kk])
        end
        h:update('}')
      end
    end

    rec_hash('top', params)
    return h:base32()
  end
  --[[
  Merges generic Redis options (timeout, prefix, expand_keys, db, password,
  sentinels) from `options` into `redis_params`.
  Callers apply this repeatedly, most specific section first. Values already
  present in `redis_params` are kept, with two exceptions that a later
  section may still override:
  - a timeout equal to the default (1.0);
  - an expand_keys that is not `true`.
  db is taken from the first of options.db, options.dbname and
  options.database that is set, and stored as a string.
  --]]
  local function process_redis_opts(options, redis_params)
    local default_timeout = 1.0
    local default_expand_keys = false

    if not redis_params['timeout'] or redis_params['timeout'] == default_timeout then
      if options['timeout'] then
        redis_params['timeout'] = tonumber(options['timeout'])
      else
        redis_params['timeout'] = default_timeout
      end
    end

    if options['prefix'] and not redis_params['prefix'] then
      redis_params['prefix'] = options['prefix']
    end

    -- Never let a later, more generic section that does not mention
    -- expand_keys reset an explicit `true` from an earlier section.
    if not redis_params['expand_keys'] then
      if type(options['expand_keys']) == 'boolean' then
        redis_params['expand_keys'] = options['expand_keys']
      elseif redis_params['expand_keys'] == nil then
        redis_params['expand_keys'] = default_expand_keys
      end
    end

    if not redis_params['db'] then
      local db = options['db'] or options['dbname'] or options['database']
      if db then
        redis_params['db'] = tostring(db)
      end
    end

    if options['password'] and not redis_params['password'] then
      redis_params['password'] = options['password']
    end

    if not redis_params.sentinels and options.sentinels then
      redis_params.sentinels = options.sentinels
    end
  end
  --[[
  Fills unset options in `redis_params` from the global `redis` section of
  `rspamd_config`. The per-module subsection redis.<module> is applied
  first (when `module` is given and that subsection exists), then the
  top-level redis options.
  Does nothing when there is no config object or no `redis` section.
  --]]
  local function enrich_defaults(rspamd_config, module, redis_params)
    if not rspamd_config then
      return
    end

    local opts = rspamd_config:get_all_opt('redis')
    if not opts then
      return
    end

    local module_section = module and opts[module]
    if module_section then
      process_redis_opts(module_section, redis_params)
    end

    process_redis_opts(opts, redis_params)
  end
  --[[
  Deduplicates parsed redis parameter tables.
  If a table with the same hash was seen before, returns that cached table
  and discards `redis_params`. Otherwise it stores the hash in
  redis_params.hash and caches `redis_params`. For writable configurations
  that define sentinels it also sets up sentinel monitoring.
  Returns `redis_params` in that case.
  --]]
  local function maybe_return_cached(redis_params)
    local h = calculate_redis_hash(redis_params)
    local previous = cached_results[h]

    if previous then
      lutil.debugm(N, 'reused redis server: %s', redis_params)
      return previous
    end

    redis_params.hash = h
    cached_results[h] = redis_params

    if redis_params.sentinels and not redis_params.read_only then
      add_redis_sentinels(redis_params)
    end

    lutil.debugm(N, 'loaded new redis server: %s', redis_params)
    return redis_params
  end
  287. --[[[
  288. -- @module lua_redis
  289. -- This module contains helper functions for working with Redis
  290. --]]
  --[[
  Builds upstream lists from one options table and stores them in `result`.
  Servers are looked up in this order:
  - read_servers: read-only unless write_servers is also given;
  - servers: used for reads and, without write_servers, for writes too;
  - server: same as servers.
  Also stores read_servers_str / write_servers_str and read_only, and merges
  generic options via process_redis_opts.
  Returns true if read upstreams were created, false otherwise.
  `rspamd_config` may be nil; then upstream lists are created without it.
  --]]
  local function process_redis_options(options, rspamd_config, result)
    local default_port = 6379
    local upstream_list = require "rspamd_upstream_list"

    -- Creates an upstream list, passing the config object only when present
    local function make_upstreams(servers_def)
      if rspamd_config then
        return upstream_list.create(rspamd_config, servers_def, default_port)
      end
      return upstream_list.create(servers_def, default_port)
    end

    local read_def
    local read_only = true
    if options['read_servers'] then
      read_def = options['read_servers']
    elseif options['servers'] then
      read_def, read_only = options['servers'], false
    elseif options['server'] then
      read_def, read_only = options['server'], false
    end

    local upstreams_read, upstreams_write
    if read_def then
      upstreams_read = make_upstreams(read_def)
      result.read_servers_str = read_def
    end

    if upstreams_read then
      local write_def = options['write_servers']
      if write_def then
        upstreams_write = make_upstreams(write_def)
        result.write_servers_str = write_def
        read_only = false
      elseif not read_only then
        -- servers/server definitions serve writes as well
        upstreams_write = upstreams_read
        result.write_servers_str = result.read_servers_str
      end
    end

    -- Store options
    process_redis_opts(options, result)

    if upstreams_write then
      result.read_only = false
    elseif read_only then
      result.read_only = true
    end

    if not upstreams_read then
      lutil.debugm(N, rspamd_config,
          'cannot load redis server from obj: %s, processed to %s',
          options, result)
      return false
    end

    result.read_servers = upstreams_read
    if upstreams_write then
      result.write_servers = upstreams_write
    end
    return true
  end
  --[[[
  @function try_load_redis_servers(options, rspamd_config, no_fallback, module_name)
  Tries to load redis servers from the specified `options` object.
  Unless `no_fallback` is set, unset options are filled from the global
  `redis` section (including its `module_name` subsection, if given).
  Returns `redis_params` table or nil in case of failure
  --]]
  exports.try_load_redis_servers = function(options, rspamd_config, no_fallback, module_name)
    local result = {}
    if not process_redis_options(options, rspamd_config, result) then
      return
    end

    if not no_fallback then
      enrich_defaults(rspamd_config, module_name, result)
    end

    return maybe_return_cached(result)
  end
  --[[
  Parses the Redis server definition for `module_name`, trying in order:
  1. module_opts.redis (or the module's own config section's `redis` table);
  2. module_opts itself (or the module's own config section);
  3. unless `no_fallback` is set, the global `redis` section: its
     `module_name` subsection if present, otherwise the top-level options
     (skipped when the module is listed in redis.disabled_modules).
  Successful module-local results are enriched with global defaults unless
  `no_fallback` is set. All results go through maybe_return_cached.
  Returns the redis params table, or nil.
  --]]
  local function rspamd_parse_redis_server(module_name, module_opts, no_fallback)
    local result = {}

    -- Finalises a successful parse of module-local settings
    local function accept_module_result()
      if not no_fallback then
        enrich_defaults(rspamd_config, module_name, result)
      end
      return maybe_return_cached(result)
    end

    lutil.debugm(N, rspamd_config, 'try load redis config for: %s', module_name)

    local opts = module_opts
    if not opts then
      opts = rspamd_config:get_all_opt(module_name)
    end

    if opts then
      if opts.redis and process_redis_options(opts.redis, rspamd_config, result) then
        return accept_module_result()
      end
      if process_redis_options(opts, rspamd_config, result) then
        return accept_module_result()
      end
    end

    if no_fallback then
      logger.infox(rspamd_config, "cannot find Redis definitions for %s and fallback is disabled",
          module_name)
      return nil
    end

    -- Try global options
    local global_opts = rspamd_config:get_all_opt('redis')
    if global_opts then
      local module_section = global_opts[module_name]
      if module_section then
        if process_redis_options(module_section, rspamd_config, result) then
          return maybe_return_cached(result)
        end
      else
        local loaded = process_redis_options(global_opts, rspamd_config, result)

        -- Exclude disabled
        local disabled = global_opts['disabled_modules']
        if disabled then
          for _, name in ipairs(disabled) do
            if name == module_name then
              logger.infox(rspamd_config, "NOT using default redis server for module %s: it is disabled",
                  module_name)
              return nil
            end
          end
        end

        if loaded then
          logger.infox(rspamd_config, "use default Redis settings for %s",
              module_name)
          return maybe_return_cached(result)
        end
      end
    end

    if result.read_servers then
      return maybe_return_cached(result)
    end

    return nil
  end
  --[[[
  -- @function lua_redis.parse_redis_server(module_name, module_opts, no_fallback)
  -- Extracts Redis server settings from configuration
  -- @param {string} module_name name of module to get settings for
  -- @param {table} module_opts settings for module or `nil` to fetch them from configuration
  -- @param {boolean} no_fallback should be `true` if global settings must not be used
  -- @return {table} redis server settings, or `nil` if no definition was found
  -- @example
  -- local rconfig = lua_redis.parse_redis_server('my_module')
  -- -- rconfig contains upstream_list objects in ['write_servers'] and ['read_servers']
  -- -- ['timeout'] contains timeout in seconds
  -- -- ['expand_keys'] if true tells that redis key expansion is enabled
  --]]
  exports.rspamd_parse_redis_server, exports.parse_redis_server =
      rspamd_parse_redis_server, rspamd_parse_redis_server
  --[[
  Key-position extractors for Redis commands (lower-case names). They decide
  which arguments are key names and therefore get template expansion when
  `expand_keys` is enabled.
  Each handler receives the command's argument table (without the command
  name) and returns a list of argument indexes. Commands without key
  arguments return an empty list (never nil, because callers iterate the
  result with ipairs).
  --]]
  local process_cmd = {
    -- BITOP operation destkey key [key ...]: everything after the operation
    bitop = function(args)
      local idx_l = {}
      for i = 2, #args do
        idx_l[#idx_l + 1] = i
      end
      return idx_l
    end,
    -- BLPOP key [key ...] timeout: everything except the trailing timeout
    blpop = function(args)
      local idx_l = {}
      for i = 1, #args - 1 do
        idx_l[#idx_l + 1] = i
      end
      return idx_l
    end,
    -- EVAL script numkeys key [key ...] arg [arg ...]
    eval = function(args)
      local idx_l = {}
      local numkeys = tonumber(args[2])
      if numkeys and numkeys >= 1 then
        for i = 3, numkeys + 2 do
          idx_l[#idx_l + 1] = i
        end
      end
      return idx_l
    end,
    -- One key in the first position: GET key, SET key value, ...
    set = function()
      return {1}
    end,
    -- Every argument is a key: MGET key [key ...], DEL key [key ...], ...
    mget = function(args)
      local idx_l = {}
      for i = 1, #args do
        idx_l[#idx_l + 1] = i
      end
      return idx_l
    end,
    -- MSET key value [key value ...]: odd positions are keys
    mset = function(args)
      local idx_l = {}
      for i = 1, #args, 2 do
        idx_l[#idx_l + 1] = i
      end
      return idx_l
    end,
    -- SMOVE source destination member
    smove = function()
      return {1, 2}
    end,
    -- ZUNIONSTORE destination numkeys key [key ...] [WEIGHTS ...] [AGGREGATE ...]:
    -- the destination and the numkeys source keys are all keys
    zunionstore = function(args)
      local idx_l = {1}
      local numkeys = tonumber(args[2])
      if numkeys and numkeys >= 1 then
        for i = 3, numkeys + 2 do
          idx_l[#idx_l + 1] = i
        end
      end
      return idx_l
    end,
    -- Commands without key arguments
    script = function()
      return {}
    end,
  }

  do
    -- Handler name -> list of commands sharing that key layout.
    -- Every command name appears in exactly one list.
    local aliases = {
      set = {
        'append', 'bitcount', 'bitfield', 'bitpos', 'decr', 'decrby', 'dump',
        'expire', 'expireat', 'geoadd', 'geodist', 'geohash', 'geopos',
        'georadius', 'georadiusbymember', 'get', 'getbit', 'getrange', 'getset',
        'hdel', 'hexists', 'hget', 'hgetall', 'hincrby', 'hincrbyfloat', 'hkeys',
        'hlen', 'hmget', 'hmset', 'hscan', 'hset', 'hsetnx', 'hstrlen', 'hvals',
        'incr', 'incrby', 'incrbyfloat', 'lindex', 'linsert', 'llen', 'lpop',
        'lpush', 'lpushx', 'lrange', 'lrem', 'lset', 'ltrim', 'move', 'persist',
        'pexpire', 'pexpireat', 'pfadd', 'psetex', 'pttl', 'restore', 'rpop',
        'rpush', 'rpushx', 'sadd', 'scard', 'setbit', 'setex', 'setnx',
        'setrange', 'sismember', 'smembers', 'sort', 'spop', 'srandmember',
        'srem', 'sscan', 'strlen', 'ttl', 'type', 'zadd', 'zcard', 'zcount',
        'zincrby', 'zlexcount', 'zrange', 'zrangebylex', 'zrangebyscore',
        'zrank', 'zrem',
        -- legacy (non-Redis) names kept for backward compatibility
        'zrembylex', 'zrembyrank', 'zrembyscore',
        'zremrangebylex', 'zremrangebyrank', 'zremrangebyscore',
        'zrevrange', 'zrevrangebylex', 'zrevrangebyscore', 'zrevrank', 'zscan',
        'zscore',
      },
      mget = {
        'del', 'exists', 'pfcount', 'pfmerge', 'rename', 'renamenx',
        'rpoplpush', 'sdiff', 'sdiffstore', 'sinter', 'sinterstore', 'sunion',
        'sunionstore', 'touch', 'unlink', 'watch',
      },
      blpop = {'brpop', 'brpoplpush'},
      eval = {'evalsha'},
      mset = {'msetnx'},
      zunionstore = {'zinterstore'},
      script = {
        'auth', 'bgrewriteaof', 'bgsave', 'client', 'cluster', 'command',
        'config', 'dbsize', 'debug', 'discard', 'echo', 'exec', 'flushall',
        'flushdb', 'info', 'keys', 'lastsave', 'migrate', 'monitor', 'multi',
        'object', 'ping', 'psubscribe', 'publish', 'pubsub', 'punsubscribe',
        'quit', 'randomkey', 'readonly', 'readwrite', 'role', 'save', 'scan',
        'select', 'slaveof', 'slowlog', 'subscribe', 'swapdb', 'sync', 'time',
        'unsubscribe', 'unwatch', 'wait',
      },
    }

    for handler_name, names in pairs(aliases) do
      local handler = process_cmd[handler_name]
      for _, name in ipairs(names) do
        process_cmd[name] = handler
      end
    end
  end
  --[[
  Returns the list of argument indexes in `args` that hold key names for
  Redis command `cmd` (matched case-insensitively against process_cmd).
  Unknown commands are logged with a warning.
  Always returns a table (possibly empty), even if a handler returns nil or
  `args` is nil, so callers can iterate the result with ipairs.
  --]]
  local function get_key_indexes(cmd, args)
    cmd = string.lower(cmd)
    local handler = process_cmd[cmd]

    if not handler then
      logger.warnx(rspamd_config, "Don't know how to extract keys for %s Redis command", cmd)
      return {}
    end

    return handler(args or {}) or {}
  end
  -- Returns `field` of the first address from task:get_from(kind), or nil
  -- when there is no such address.
  local function first_from_field(task, kind, field)
    local addrs = task:get_from(kind) or E
    local first = addrs[1] or E
    return first[field]
  end

  -- Generators for the variables available to key templates when
  -- `expand_keys` is enabled. Each takes a task and returns a value or nil.
  local gen_meta = {
    principal_recipient = function(task)
      return task:get_principal_recipient()
    end,
    -- Everything after the last '@' of the principal recipient
    principal_recipient_domain = function(task)
      local rcpt = task:get_principal_recipient()
      if not rcpt then return end
      return string.match(rcpt, '.*@(.*)')
    end,
    ip = function(task)
      local ip = task:get_ip()
      if not (ip and ip:is_valid()) then return end
      return ip:to_string()
    end,
    from = function(task)
      return first_from_field(task, 'smtp', 'addr')
    end,
    from_domain = function(task)
      return first_from_field(task, 'smtp', 'domain')
    end,
    -- SMTP sender domain when non-empty, otherwise the HELO string
    from_domain_or_helo_domain = function(task)
      local domain = first_from_field(task, 'smtp', 'domain')
      if domain and #domain > 0 then return domain end
      return task:get_helo()
    end,
    mime_from = function(task)
      return first_from_field(task, 'mime', 'addr')
    end,
    mime_from_domain = function(task)
      return first_from_field(task, 'mime', 'domain')
    end,
  }

  -- Wraps a domain generator so that it returns rspamd_util.get_tld() of the
  -- domain, or nil when the domain is missing.
  local function gen_get_esld(domain_getter)
    return function(task)
      local domain = domain_getter(task)
      if not domain then return end
      return rspamd_util.get_tld(domain)
    end
  end

  gen_meta.smtp_from = gen_meta.from
  gen_meta.smtp_from_domain = gen_meta.from_domain
  gen_meta.smtp_from_domain_or_helo_domain = gen_meta.from_domain_or_helo_domain
  gen_meta.esld_principal_recipient_domain = gen_get_esld(gen_meta.principal_recipient_domain)
  gen_meta.esld_from_domain = gen_get_esld(gen_meta.from_domain)
  gen_meta.esld_smtp_from_domain = gen_meta.esld_from_domain
  gen_meta.esld_mime_from_domain = gen_get_esld(gen_meta.mime_from_domain)
  gen_meta.esld_from_domain_or_helo_domain = gen_get_esld(gen_meta.from_domain_or_helo_domain)
  gen_meta.esld_smtp_from_domain_or_helo_domain = gen_meta.esld_from_domain_or_helo_domain
  --[[
  Returns a lazily populated table of key-template variables for `task`.
  Lookups are case-insensitive (names are lower-cased). On first access a
  value is produced by the matching gen_meta generator and stored under the
  lower-cased name. Names without a generator yield nil.
  --]]
  local function get_key_expansion_metadata(task)
    local function lookup(cache, name)
      local lname = string.lower(name)
      local known = rawget(cache, lname)
      if known then
        return known
      end

      local generator = gen_meta[lname]
      if not generator then
        return known
      end

      local value = generator(task)
      rawset(cache, lname, value)
      return value
    end

    return setmetatable({}, { __index = lookup })
  end
  -- Performs async call to redis hiding all complexity inside function
  -- task - rspamd_task
  -- redis_params - valid params returned by rspamd_parse_redis_server
  -- key - key to select upstream by hash, or nil to select master/slave
  --       (writes) or round-robin (reads)
  -- is_write - true if need to write to redis server (uses write_servers)
  -- callback - called as callback(err, data, addr) when the request completes;
  --            the selected upstream is marked ok/failed before the call
  -- command - redis command
  -- args - table of arguments; when expand_keys is enabled, key arguments are
  --        template-expanded in a copy, so the caller's table is not modified
  -- extra_opts - table of optional request arguments merged into the request
  -- Returns ret, conn, addr from rspamd_redis.make_request plus the selected
  -- upstream; false, nil, nil when arguments are missing or no upstream
  -- can be selected.
  local function rspamd_redis_make_request(task, redis_params, key, is_write,
      callback, command, args, extra_opts)
    if not task or not redis_params or not callback or not command then
      return false, nil, nil
    end

    local addr
    local function rspamd_redis_make_request_cb(err, data)
      if err then
        addr:fail()
      else
        addr:ok()
      end
      if callback then
        callback(err, data, addr)
      end
    end

    local rspamd_redis = require "rspamd_redis"

    -- Read-only configurations have no write_servers; never fall back to
    -- read servers for writes.
    local servers
    if is_write then
      servers = redis_params['write_servers']
    else
      servers = redis_params['read_servers']
    end

    if servers then
      if key then
        addr = servers:get_upstream_by_hash(key)
      elseif is_write then
        addr = servers:get_upstream_master_slave(key)
      else
        addr = servers:get_upstream_round_robin(key)
      end
    end

    if not addr then
      logger.errx(task, 'cannot select server to make redis request')
      return false, nil, nil
    end

    if redis_params['expand_keys'] and type(args) == 'table' then
      local indexes = get_key_indexes(command, args) or {}
      if #indexes > 0 then
        local m = get_key_expansion_metadata(task)
        local expanded = {}
        for i = 1, #args do
          expanded[i] = args[i]
        end
        for _, i in ipairs(indexes) do
          if type(expanded[i]) == 'string' then
            expanded[i] = lutil.template(expanded[i], m)
          end
        end
        args = expanded
      end
    end

    local ip_addr = addr:get_addr()
    local options = {
      task = task,
      callback = rspamd_redis_make_request_cb,
      host = ip_addr,
      timeout = redis_params['timeout'],
      cmd = command,
      args = args
    }

    if extra_opts then
      for k, v in pairs(extra_opts) do
        options[k] = v
      end
    end

    if redis_params['password'] then
      options['password'] = redis_params['password']
    end

    if redis_params['db'] then
      options['dbname'] = redis_params['db']
    end

    lutil.debugm(N, task, 'perform request to redis server' ..
        ' (host=%s, timeout=%s): cmd: %s', ip_addr,
        options.timeout, options.cmd)

    local ret, conn = rspamd_redis.make_request(options)
    if not ret then
      addr:fail()
      logger.warnx(task, "cannot make redis request to: %s", tostring(ip_addr))
    end

    return ret, conn, addr
  end
  --[[[
  -- @function lua_redis.redis_make_request(task, redis_params, key, is_write, callback, command, args, extra_opts)
  -- Sends a request to Redis
  -- @param {rspamd_task} task task object
  -- @param {table} redis_params redis configuration in format returned by lua_redis.parse_redis_server()
  -- @param {string} key key to use for sharding
  -- @param {boolean} is_write should be `true` if we are performing a write operating
  -- @param {function} callback callback function (first parameter is error if applicable, second is a 2D array (table))
  -- @param {string} command Redis command to run
  -- @param {table} args Numerically indexed table containing arguments for command
  -- @param {table} extra_opts optional table of extra options merged into the request (may be nil)
  -- @return request status and connection from rspamd_redis.make_request, plus the selected upstream
  --]]
  exports.rspamd_redis_make_request, exports.redis_make_request =
      rspamd_redis_make_request, rspamd_redis_make_request
  835. local function redis_make_request_taskless(ev_base, cfg, redis_params, key,
  836. is_write, callback, command, args, extra_opts)
  837. if not ev_base or not redis_params or not callback or not command then
  838. return false,nil,nil
  839. end
  840. local addr
  841. local function rspamd_redis_make_request_cb(err, data)
  842. if err then
  843. addr:fail()
  844. else
  845. addr:ok()
  846. end
  847. if callback then
  848. callback(err, data, addr)
  849. end
  850. end
  851. local rspamd_redis = require "rspamd_redis"
  852. if key then
  853. if is_write then
  854. addr = redis_params['write_servers']:get_upstream_by_hash(key)
  855. else
  856. addr = redis_params['read_servers']:get_upstream_by_hash(key)
  857. end
  858. else
  859. if is_write then
  860. addr = redis_params['write_servers']:get_upstream_master_slave(key)
  861. else
  862. addr = redis_params['read_servers']:get_upstream_round_robin(key)
  863. end
  864. end
  865. if not addr then
  866. logger.errx(cfg, 'cannot select server to make redis request')
  867. end
  868. local options = {
  869. ev_base = ev_base,
  870. config = cfg,
  871. callback = rspamd_redis_make_request_cb,
  872. host = addr:get_addr(),
  873. timeout = redis_params['timeout'],
  874. cmd = command,
  875. args = args
  876. }
  877. if extra_opts then
  878. for k,v in pairs(extra_opts) do
  879. options[k] = v
  880. end
  881. end
  882. if redis_params['password'] then
  883. options['password'] = redis_params['password']
  884. end
  885. if redis_params['db'] then
  886. options['dbname'] = redis_params['db']
  887. end
  888. lutil.debugm(N, cfg, 'perform taskless request to redis server' ..
  889. ' (host=%s, timeout=%s): cmd: %s', options.host,
  890. options.timeout, options.cmd)
  891. local ret,conn = rspamd_redis.make_request(options)
  892. if not ret then
  893. logger.errx('cannot execute redis request')
  894. addr:fail()
  895. end
  896. return ret,conn,addr
  897. end
  898. --[[[
-- @function lua_redis.redis_make_request_taskless(ev_base, cfg, redis_params, key, is_write, callback, command, args)
-- Sends a request to Redis in contexts where `task` is not available (for some specific use-cases)
-- Identical to redis_make_request() except that the first parameter is an `event base` object and the second one is a config object
  902. --]]
  903. exports.rspamd_redis_make_request_taskless = redis_make_request_taskless
  904. exports.redis_make_request_taskless = redis_make_request_taskless
  905. local redis_scripts = {
  906. }
  907. local function script_set_loaded(script)
  908. if script.sha then
  909. script.loaded = true
  910. end
  911. local wait_table = {}
  912. for _,s in ipairs(script.waitq) do
  913. table.insert(wait_table, s)
  914. end
  915. script.waitq = {}
  916. for _,s in ipairs(wait_table) do
  917. s(script.loaded)
  918. end
  919. end
  920. local function prepare_redis_call(script)
  921. local function merge_tables(t1, t2)
  922. for k,v in pairs(t2) do t1[k] = v end
  923. end
  924. local servers = {}
  925. local options = {}
  926. if script.redis_params.read_servers then
  927. merge_tables(servers, script.redis_params.read_servers:all_upstreams())
  928. end
  929. if script.redis_params.write_servers then
  930. merge_tables(servers, script.redis_params.write_servers:all_upstreams())
  931. end
  932. -- Call load script on each server, set loaded flag
  933. script.in_flight = #servers
  934. for _,s in ipairs(servers) do
  935. local cur_opts = {
  936. host = s:get_addr(),
  937. timeout = script.redis_params['timeout'],
  938. cmd = 'SCRIPT',
  939. args = {'LOAD', script.script },
  940. upstream = s
  941. }
  942. if script.redis_params['password'] then
  943. cur_opts['password'] = script.redis_params['password']
  944. end
  945. if script.redis_params['db'] then
  946. cur_opts['dbname'] = script.redis_params['db']
  947. end
  948. table.insert(options, cur_opts)
  949. end
  950. return options
  951. end
  952. local function load_script_task(script, task)
  953. local rspamd_redis = require "rspamd_redis"
  954. local opts = prepare_redis_call(script)
  955. for _,opt in ipairs(opts) do
  956. opt.task = task
  957. opt.callback = function(err, data)
  958. if err then
  959. logger.errx(task, 'cannot upload script to %s: %s; registered from: %s:%s',
  960. opt.upstream:get_addr():to_string(true),
  961. err, script.caller.short_src, script.caller.currentline)
  962. opt.upstream:fail()
  963. script.fatal_error = err
  964. else
  965. opt.upstream:ok()
  966. logger.infox(task,
  967. "uploaded redis script to %s with id %s, sha: %s",
  968. opt.upstream:get_addr():to_string(true),
  969. script.id, data)
  970. script.sha = data -- We assume that sha is the same on all servers
  971. end
  972. script.in_flight = script.in_flight - 1
  973. if script.in_flight == 0 then
  974. script_set_loaded(script)
  975. end
  976. end
  977. local ret = rspamd_redis.make_request(opt)
  978. if not ret then
  979. logger.errx('cannot execute redis request to load script on %s',
  980. opt.upstream:get_addr())
  981. script.in_flight = script.in_flight - 1
  982. opt.upstream:fail()
  983. end
  984. if script.in_flight == 0 then
  985. script_set_loaded(script)
  986. end
  987. end
  988. end
  989. local function load_script_taskless(script, cfg, ev_base)
  990. local rspamd_redis = require "rspamd_redis"
  991. local opts = prepare_redis_call(script)
  992. for _,opt in ipairs(opts) do
  993. opt.config = cfg
  994. opt.ev_base = ev_base
  995. opt.callback = function(err, data)
  996. if err then
  997. logger.errx(cfg, 'cannot upload script to %s: %s; registered from: %s:%s',
  998. opt.upstream:get_addr():to_string(true),
  999. err, script.caller.short_src, script.caller.currentline)
  1000. opt.upstream:fail()
  1001. script.fatal_error = err
  1002. else
  1003. opt.upstream:ok()
  1004. logger.infox(cfg,
  1005. "uploaded redis script to %s with id %s, sha: %s",
  1006. opt.upstream:get_addr():to_string(true), script.id, data)
  1007. script.sha = data -- We assume that sha is the same on all servers
  1008. script.fatal_error = nil
  1009. end
  1010. script.in_flight = script.in_flight - 1
  1011. if script.in_flight == 0 then
  1012. script_set_loaded(script)
  1013. end
  1014. end
  1015. local ret = rspamd_redis.make_request(opt)
  1016. if not ret then
  1017. logger.errx('cannot execute redis request to load script on %s',
  1018. opt.upstream:get_addr())
  1019. script.in_flight = script.in_flight - 1
  1020. opt.upstream:fail()
  1021. end
  1022. if script.in_flight == 0 then
  1023. script_set_loaded(script)
  1024. end
  1025. end
  1026. end
  1027. local function load_redis_script(script, cfg, ev_base, _)
  1028. if script.redis_params then
  1029. load_script_taskless(script, cfg, ev_base)
  1030. end
  1031. end
  1032. local function add_redis_script(script, redis_params)
  1033. local caller = debug.getinfo(2)
  1034. local new_script = {
  1035. caller = caller,
  1036. loaded = false,
  1037. redis_params = redis_params,
  1038. script = script,
  1039. waitq = {}, -- callbacks pending for script being loaded
  1040. id = #redis_scripts + 1
  1041. }
  1042. -- Register on load function
  1043. rspamd_config:add_on_load(function(cfg, ev_base, worker)
  1044. local mult = 0.0
  1045. rspamd_config:add_periodic(ev_base, 0.0, function()
  1046. if not new_script.sha then
  1047. load_redis_script(new_script, cfg, ev_base, worker)
  1048. mult = mult + 1
  1049. return 1.0 * mult -- Check one more time in one second
  1050. end
  1051. return false
  1052. end, false)
  1053. end)
  1054. table.insert(redis_scripts, new_script)
  1055. return #redis_scripts
  1056. end
  1057. exports.add_redis_script = add_redis_script
-- Runs a script registered with add_redis_script() using EVALSHA.
-- @param id script id as returned by add_redis_script()
-- @param params table with either `task` or `ev_base`, plus optional `key`
--   (used for upstream selection) and `is_write`
-- @param callback function(err, data) receiving the EVALSHA reply or an error
-- @param keys array of key names passed to the script (its length is sent as
--   the numkeys argument)
-- @param args optional array of extra script arguments
-- @return false if no script is registered under `id` (callback is not
--   called); true otherwise (callback has been called or is scheduled)
local function exec_redis_script(id, params, callback, keys, args)
  -- Filled lazily by the first do_call() and reused by any later do_call() of
  -- this invocation (e.g. a retry after NOSCRIPT).
  -- NOTE(review): a retry therefore reuses the sha captured on the first
  -- attempt; this relies on the reloaded sha being identical (Redis derives
  -- it from the script body) — confirm.
  local redis_args = {}

  if not redis_scripts[id] then
    logger.errx("cannot find registered script with id %s", id)
    return false
  end

  local script = redis_scripts[id]

  -- A previous upload error is reported immediately, without contacting redis
  if script.fatal_error then
    callback(script.fatal_error, nil)
    return true
  end

  if not script.redis_params then
    callback('no redis servers defined', nil)
    return true
  end

  -- Sends EVALSHA; `can_reload` controls whether a NOSCRIPT reply triggers a
  -- re-upload and a queued retry, or is passed straight to the callback
  local function do_call(can_reload)
    local function redis_cb(err, data)
      if not err then
        callback(err, data)
      elseif string.match(err, 'NOSCRIPT') then
        -- Schedule restart
        -- NOTE(review): script.loaded is left true here, so concurrent calls
        -- still take the do_call(true) path while script.sha is nil
        script.sha = nil
        if can_reload then
          -- Re-run this call when script_set_loaded() drains the wait queue
          table.insert(script.waitq, do_call)
          if script.in_flight == 0 then
            -- Reload scripts if this has not been initiated yet
            if params.task then
              load_script_task(script, params.task)
            else
              load_script_taskless(script, rspamd_config, params.ev_base)
            end
          end
        else
          callback(err, data)
        end
      else
        callback(err, data)
      end
    end

    -- EVALSHA arguments: sha, numkeys, keys..., args...
    if #redis_args == 0 then
      table.insert(redis_args, script.sha)
      table.insert(redis_args, tostring(#keys))
      for _,k in ipairs(keys) do
        table.insert(redis_args, k)
      end

      if type(args) == 'table' then
        for _, a in ipairs(args) do
          table.insert(redis_args, a)
        end
      end
    end

    if params.task then
      if not rspamd_redis_make_request(params.task, script.redis_params,
          params.key, params.is_write, redis_cb, 'EVALSHA', redis_args) then
        callback('Cannot make redis request', nil)
      end
    else
      if not redis_make_request_taskless(params.ev_base, rspamd_config,
          script.redis_params,
          params.key, params.is_write, redis_cb, 'EVALSHA', redis_args) then
        callback('Cannot make redis request', nil)
      end
    end
  end

  if script.loaded then
    do_call(true)
  else
    -- Delayed until scripts are loaded
    if not params.task then
      table.insert(script.waitq, do_call)
    else
      -- TODO: fix taskfull requests
      -- NOTE(review): every such call starts its own load_script_task(),
      -- whose prepare_redis_call() resets script.in_flight even when a load
      -- is already in progress
      table.insert(script.waitq, function()
        if script.loaded then
          do_call(false)
        else
          callback('NOSCRIPT', nil)
        end
      end)
      load_script_task(script, params.task)
    end
  end

  return true
end

exports.exec_redis_script = exec_redis_script
  1143. local function redis_connect_sync(redis_params, is_write, key, cfg, ev_base)
  1144. if not redis_params then
  1145. return false,nil
  1146. end
  1147. local rspamd_redis = require "rspamd_redis"
  1148. local addr
  1149. if key then
  1150. if is_write then
  1151. addr = redis_params['write_servers']:get_upstream_by_hash(key)
  1152. else
  1153. addr = redis_params['read_servers']:get_upstream_by_hash(key)
  1154. end
  1155. else
  1156. if is_write then
  1157. addr = redis_params['write_servers']:get_upstream_master_slave(key)
  1158. else
  1159. addr = redis_params['read_servers']:get_upstream_round_robin(key)
  1160. end
  1161. end
  1162. if not addr then
  1163. logger.errx(cfg, 'cannot select server to make redis request')
  1164. end
  1165. local options = {
  1166. host = addr:get_addr(),
  1167. timeout = redis_params['timeout'],
  1168. config = cfg or rspamd_config,
  1169. ev_base = ev_base or rspamadm_ev_base,
  1170. session = redis_params.session or rspamadm_session
  1171. }
  1172. for k,v in pairs(redis_params) do
  1173. options[k] = v
  1174. end
  1175. if not options.config then
  1176. logger.errx('config is not set')
  1177. return false,nil,addr
  1178. end
  1179. if not options.ev_base then
  1180. logger.errx('ev_base is not set')
  1181. return false,nil,addr
  1182. end
  1183. if not options.session then
  1184. logger.errx('session is not set')
  1185. return false,nil,addr
  1186. end
  1187. local ret,conn = rspamd_redis.connect_sync(options)
  1188. if not ret then
  1189. logger.errx('cannot execute redis request: %s', conn)
  1190. addr:fail()
  1191. return false,nil,addr
  1192. end
  1193. if conn then
  1194. if redis_params['password'] then
  1195. conn:add_cmd('AUTH', {redis_params['password']})
  1196. end
  1197. if redis_params['db'] then
  1198. conn:add_cmd('SELECT', {tostring(redis_params['db'])})
  1199. elseif redis_params['dbname'] then
  1200. conn:add_cmd('SELECT', {tostring(redis_params['dbname'])})
  1201. end
  1202. end
  1203. return ret,conn,addr
  1204. end
  1205. exports.redis_connect_sync = redis_connect_sync
  1206. --[[[
  1207. -- @function lua_redis.request(redis_params, attrs, req)
  1208. -- Sends a request to Redis synchronously with coroutines or asynchronously using
  1209. -- a callback (modern API)
  1210. -- @param redis_params a table of redis server parameters
-- @param attrs a table of redis request attributes (e.g. task, or ev_base + config + session)
  1212. -- @param req a table of request: a command + command options
  1213. -- @return {result,data/connection,address} boolean result, connection object in case of async request and results if using coroutines, redis server address
  1214. --]]
  1215. exports.request = function(redis_params, attrs, req)
  1216. local lua_util = require "lua_util"
  1217. if not attrs or not redis_params or not req then
  1218. logger.errx('invalid arguments for redis request')
  1219. return false,nil,nil
  1220. end
  1221. if not (attrs.task or (attrs.config and attrs.ev_base)) then
  1222. logger.errx('invalid attributes for redis request')
  1223. return false,nil,nil
  1224. end
  1225. local opts = lua_util.shallowcopy(attrs)
  1226. local log_obj = opts.task or opts.config
  1227. local addr
  1228. if opts.callback then
  1229. -- Wrap callback
  1230. local callback = opts.callback
  1231. local function rspamd_redis_make_request_cb(err, data)
  1232. if err then
  1233. addr:fail()
  1234. else
  1235. addr:ok()
  1236. end
  1237. callback(err, data, addr)
  1238. end
  1239. opts.callback = rspamd_redis_make_request_cb
  1240. end
  1241. local rspamd_redis = require "rspamd_redis"
  1242. local is_write = opts.is_write
  1243. if opts.key then
  1244. if is_write then
  1245. addr = redis_params['write_servers']:get_upstream_by_hash(attrs.key)
  1246. else
  1247. addr = redis_params['read_servers']:get_upstream_by_hash(attrs.key)
  1248. end
  1249. else
  1250. if is_write then
  1251. addr = redis_params['write_servers']:get_upstream_master_slave(attrs.key)
  1252. else
  1253. addr = redis_params['read_servers']:get_upstream_round_robin(attrs.key)
  1254. end
  1255. end
  1256. if not addr then
  1257. logger.errx(log_obj, 'cannot select server to make redis request')
  1258. end
  1259. opts.host = addr:get_addr()
  1260. opts.timeout = redis_params.timeout
  1261. if type(req) == 'string' then
  1262. opts.cmd = req
  1263. else
  1264. -- XXX: modifies the input table
  1265. opts.cmd = table.remove(req, 1);
  1266. opts.args = req
  1267. end
  1268. if redis_params.password then
  1269. opts.password = redis_params.password
  1270. end
  1271. if redis_params.db then
  1272. opts.dbname = redis_params.db
  1273. end
  1274. lutil.debugm(N, 'perform generic request to redis server' ..
  1275. ' (host=%s, timeout=%s): cmd: %s, arguments: %s', addr,
  1276. opts.timeout, opts.cmd, opts.args)
  1277. if opts.callback then
  1278. local ret,conn = rspamd_redis.make_request(opts)
  1279. if not ret then
  1280. logger.errx(log_obj, 'cannot execute redis request')
  1281. addr:fail()
  1282. end
  1283. return ret,conn,addr
  1284. else
  1285. -- Coroutines version
  1286. local ret,conn = rspamd_redis.connect_sync(opts)
  1287. if not ret then
  1288. logger.errx(log_obj, 'cannot execute redis request')
  1289. addr:fail()
  1290. else
  1291. conn:add_cmd(opts.cmd, opts.args)
  1292. return conn:exec()
  1293. end
  1294. return false,nil,addr
  1295. end
  1296. end
  1297. --[[[
  1298. -- @function lua_redis.connect(redis_params, attrs)
  1299. -- Connects to Redis synchronously with coroutines or asynchronously using a callback (modern API)
  1300. -- @param redis_params a table of redis server parameters
-- @param attrs a table of redis request attributes (e.g. task, or ev_base + config + session)
  1302. -- @return {result,connection,address} boolean result, connection object, redis server address
  1303. --]]
  1304. exports.connect = function(redis_params, attrs)
  1305. local lua_util = require "lua_util"
  1306. if not attrs or not redis_params then
  1307. logger.errx('invalid arguments for redis connect')
  1308. return false,nil,nil
  1309. end
  1310. if not (attrs.task or (attrs.config and attrs.ev_base)) then
  1311. logger.errx('invalid attributes for redis connect')
  1312. return false,nil,nil
  1313. end
  1314. local opts = lua_util.shallowcopy(attrs)
  1315. local log_obj = opts.task or opts.config
  1316. local addr
  1317. if opts.callback then
  1318. -- Wrap callback
  1319. local callback = opts.callback
  1320. local function rspamd_redis_make_request_cb(err, data)
  1321. if err then
  1322. addr:fail()
  1323. else
  1324. addr:ok()
  1325. end
  1326. callback(err, data, addr)
  1327. end
  1328. opts.callback = rspamd_redis_make_request_cb
  1329. end
  1330. local rspamd_redis = require "rspamd_redis"
  1331. local is_write = opts.is_write
  1332. if opts.key then
  1333. if is_write then
  1334. addr = redis_params['write_servers']:get_upstream_by_hash(attrs.key)
  1335. else
  1336. addr = redis_params['read_servers']:get_upstream_by_hash(attrs.key)
  1337. end
  1338. else
  1339. if is_write then
  1340. addr = redis_params['write_servers']:get_upstream_master_slave(attrs.key)
  1341. else
  1342. addr = redis_params['read_servers']:get_upstream_round_robin(attrs.key)
  1343. end
  1344. end
  1345. if not addr then
  1346. logger.errx(log_obj, 'cannot select server to make redis connect')
  1347. end
  1348. opts.host = addr:get_addr()
  1349. opts.timeout = redis_params.timeout
  1350. if redis_params.password then
  1351. opts.password = redis_params.password
  1352. end
  1353. if redis_params.db then
  1354. opts.dbname = redis_params.db
  1355. end
  1356. if opts.callback then
  1357. local ret,conn = rspamd_redis.connect(opts)
  1358. if not ret then
  1359. logger.errx(log_obj, 'cannot execute redis connect')
  1360. addr:fail()
  1361. end
  1362. return ret,conn,addr
  1363. else
  1364. -- Coroutines version
  1365. local ret,conn = rspamd_redis.connect_sync(opts)
  1366. if not ret then
  1367. logger.errx(log_obj, 'cannot execute redis connect')
  1368. addr:fail()
  1369. else
  1370. return true,conn,addr
  1371. end
  1372. return false,nil,addr
  1373. end
  1374. end
  1375. local redis_prefixes = {}
  1376. --[[[
  1377. -- @function lua_redis.register_prefix(prefix, module, description[, optional])
  1378. -- Register new redis prefix for documentation purposes
  1379. -- @param {string} prefix string prefix
  1380. -- @param {string} module module name
  1381. -- @param {string} description prefix description
  1382. -- @param {table} optional optional kv pairs (e.g. pattern)
  1383. --]]
  1384. local function register_prefix(prefix, module, description, optional)
  1385. local pr = {
  1386. module = module,
  1387. description = description
  1388. }
  1389. if optional and type(optional) == 'table' then
  1390. for k,v in pairs(optional) do
  1391. pr[k] = v
  1392. end
  1393. end
  1394. redis_prefixes[prefix] = pr
  1395. end
  1396. exports.register_prefix = register_prefix
  1397. --[[[
  1398. -- @function lua_redis.prefixes([mname])
  1399. -- Returns prefixes for specific module (or all prefixes). Returns a table prefix -> table
  1400. --]]
  1401. exports.prefixes = function(mname)
  1402. if not mname then
  1403. return redis_prefixes
  1404. else
  1405. local fun = require "fun"
  1406. return fun.totable(fun.filter(function(_, data) return data.module == mname end,
  1407. redis_prefixes))
  1408. end
  1409. end
  1410. return exports