You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

lua_redis.lua 34KB


  1. --[[
  2. Copyright (c) 2017, Vsevolod Stakhov <vsevolod@highsecure.ru>
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ]]--
  13. local logger = require "rspamd_logger"
  14. local lutil = require "lua_util"
  15. local rspamd_util = require "rspamd_util"
-- Public API of this module: functions are attached as `exports.<name>`
local exports = {}
-- Shared empty table used as a nil-safe default in chained lookups such as
-- `((list or E)[1] or E).field`; treat it as read-only
local E = {}
  18. --[[[
  19. -- @module lua_redis
  20. -- This module contains helper functions for working with Redis
  21. --]]
  22. local function try_load_redis_servers(options, rspamd_config, result)
  23. local default_port = 6379
  24. local default_timeout = 1.0
  25. local default_expand_keys = false
  26. local upstream_list = require "rspamd_upstream_list"
  27. local read_only = true
  28. -- Try to get read servers:
  29. local upstreams_read, upstreams_write
  30. if options['read_servers'] then
  31. if rspamd_config then
  32. upstreams_read = upstream_list.create(rspamd_config,
  33. options['read_servers'], default_port)
  34. else
  35. upstreams_read = upstream_list.create(options['read_servers'],
  36. default_port)
  37. end
  38. elseif options['servers'] then
  39. if rspamd_config then
  40. upstreams_read = upstream_list.create(rspamd_config,
  41. options['servers'], default_port)
  42. else
  43. upstreams_read = upstream_list.create(options['servers'], default_port)
  44. end
  45. read_only = false
  46. elseif options['server'] then
  47. if rspamd_config then
  48. upstreams_read = upstream_list.create(rspamd_config,
  49. options['server'], default_port)
  50. else
  51. upstreams_read = upstream_list.create(options['server'], default_port)
  52. end
  53. read_only = false
  54. end
  55. if upstreams_read then
  56. if options['write_servers'] then
  57. if rspamd_config then
  58. upstreams_write = upstream_list.create(rspamd_config,
  59. options['write_servers'], default_port)
  60. else
  61. upstreams_write = upstream_list.create(options['write_servers'],
  62. default_port)
  63. end
  64. read_only = false
  65. elseif not read_only then
  66. upstreams_write = upstreams_read
  67. end
  68. end
  69. -- Store options
  70. if not result['timeout'] or result['timeout'] == default_timeout then
  71. if options['timeout'] then
  72. result['timeout'] = tonumber(options['timeout'])
  73. else
  74. result['timeout'] = default_timeout
  75. end
  76. end
  77. if options['prefix'] and not result['prefix'] then
  78. result['prefix'] = options['prefix']
  79. end
  80. if type(options['expand_keys']) == 'boolean' then
  81. result['expand_keys'] = options['expand_keys']
  82. else
  83. result['expand_keys'] = default_expand_keys
  84. end
  85. if not result['db'] then
  86. if options['db'] then
  87. result['db'] = tostring(options['db'])
  88. elseif options['dbname'] then
  89. result['db'] = tostring(options['dbname'])
  90. elseif options['database'] then
  91. result['db'] = tostring(options['database'])
  92. end
  93. end
  94. if options['password'] and not result['password'] then
  95. result['password'] = options['password']
  96. end
  97. if read_only and not upstreams_write then
  98. result.read_only = true
  99. elseif upstreams_write then
  100. result.read_only = false
  101. end
  102. if upstreams_read then
  103. result.read_servers = upstreams_read
  104. if upstreams_write then
  105. result.write_servers = upstreams_write
  106. end
  107. return true
  108. end
  109. return false
  110. end
-- Also exported for use outside this module
exports.try_load_redis_servers = try_load_redis_servers
  112. -- This function parses redis server definition using either
  113. -- specific server string for this module or global
  114. -- redis section
  115. local function rspamd_parse_redis_server(module_name, module_opts, no_fallback)
  116. local result = {}
  117. -- Try local options
  118. local opts
  119. if not module_opts then
  120. opts = rspamd_config:get_all_opt(module_name)
  121. else
  122. opts = module_opts
  123. end
  124. if opts then
  125. local ret
  126. if opts.redis then
  127. ret = try_load_redis_servers(opts.redis, rspamd_config, result)
  128. if ret then
  129. return result
  130. end
  131. end
  132. ret = try_load_redis_servers(opts, rspamd_config, result)
  133. if ret then
  134. return result
  135. end
  136. end
  137. if no_fallback then return nil end
  138. -- Try global options
  139. opts = rspamd_config:get_all_opt('redis')
  140. if opts then
  141. local ret
  142. if opts[module_name] then
  143. ret = try_load_redis_servers(opts[module_name], rspamd_config, result)
  144. if ret then
  145. return result
  146. end
  147. else
  148. ret = try_load_redis_servers(opts, rspamd_config, result)
  149. -- Exclude disabled
  150. if opts['disabled_modules'] then
  151. for _,v in ipairs(opts['disabled_modules']) do
  152. if v == module_name then
  153. logger.infox(rspamd_config, "NOT using default redis server for module %s: it is disabled",
  154. module_name)
  155. return nil
  156. end
  157. end
  158. end
  159. if ret then
  160. logger.infox(rspamd_config, "using default redis server for module %s",
  161. module_name)
  162. end
  163. end
  164. end
  165. if result.read_servers then
  166. return result
  167. else
  168. return nil
  169. end
  170. end
--[[[
-- @function lua_redis.parse_redis_server(module_name, module_opts, no_fallback)
-- Extracts Redis server settings from configuration
-- @param {string} module_name name of module to get settings for
-- @param {table} module_opts settings for module or `nil` to fetch them from configuration
-- @param {boolean} no_fallback should be `true` if global settings must not be used
-- @return {table} redis server settings, or `nil` when no servers are configured
-- for the module or the module is listed in the global `redis.disabled_modules`
-- @example
-- local rconfig = lua_redis.parse_redis_server('my_module')
-- -- rconfig contains upstream_list objects in ['write_servers'] and ['read_servers']
-- -- ['timeout'] contains timeout in seconds
-- -- ['expand_keys'] if true tells that redis key expansion is enabled
-- -- ['read_only'] is true when only read servers are configured
--]]
exports.rspamd_parse_redis_server = rspamd_parse_redis_server
exports.parse_redis_server = rspamd_parse_redis_server
  186. local process_cmd = {
  187. bitop = function(args)
  188. local idx_l = {}
  189. for i = 2, #args do
  190. table.insert(idx_l, i)
  191. end
  192. return idx_l
  193. end,
  194. blpop = function(args)
  195. local idx_l = {}
  196. for i = 1, #args -1 do
  197. table.insert(idx_l, i)
  198. end
  199. return idx_l
  200. end,
  201. eval = function(args)
  202. local idx_l = {}
  203. local numkeys = args[2]
  204. if numkeys and tonumber(numkeys) >= 1 then
  205. for i = 3, numkeys + 2 do
  206. table.insert(idx_l, i)
  207. end
  208. end
  209. return idx_l
  210. end,
  211. set = function(args)
  212. return {1}
  213. end,
  214. mget = function(args)
  215. local idx_l = {}
  216. for i = 1, #args do
  217. table.insert(idx_l, i)
  218. end
  219. return idx_l
  220. end,
  221. mset = function(args)
  222. local idx_l = {}
  223. for i = 1, #args, 2 do
  224. table.insert(idx_l, i)
  225. end
  226. return idx_l
  227. end,
  228. sdiffstore = function(args)
  229. local idx_l = {}
  230. for i = 2, #args do
  231. table.insert(idx_l, i)
  232. end
  233. return idx_l
  234. end,
  235. smove = function(args)
  236. return {1, 2}
  237. end,
  238. script = function() end
  239. }
  240. process_cmd.append = process_cmd.set
  241. process_cmd.auth = process_cmd.script
  242. process_cmd.bgrewriteaof = process_cmd.script
  243. process_cmd.bgsave = process_cmd.script
  244. process_cmd.bitcount = process_cmd.set
  245. process_cmd.bitfield = process_cmd.set
  246. process_cmd.bitpos = process_cmd.set
  247. process_cmd.brpop = process_cmd.blpop
  248. process_cmd.brpoplpush = process_cmd.blpop
  249. process_cmd.client = process_cmd.script
  250. process_cmd.cluster = process_cmd.script
  251. process_cmd.command = process_cmd.script
  252. process_cmd.config = process_cmd.script
  253. process_cmd.dbsize = process_cmd.script
  254. process_cmd.debug = process_cmd.script
  255. process_cmd.decr = process_cmd.set
  256. process_cmd.decrby = process_cmd.set
  257. process_cmd.del = process_cmd.mget
  258. process_cmd.discard = process_cmd.script
  259. process_cmd.dump = process_cmd.set
  260. process_cmd.echo = process_cmd.script
  261. process_cmd.evalsha = process_cmd.eval
  262. process_cmd.exec = process_cmd.script
  263. process_cmd.exists = process_cmd.mget
  264. process_cmd.expire = process_cmd.set
  265. process_cmd.expireat = process_cmd.set
  266. process_cmd.flushall = process_cmd.script
  267. process_cmd.flushdb = process_cmd.script
  268. process_cmd.geoadd = process_cmd.set
  269. process_cmd.geohash = process_cmd.set
  270. process_cmd.geopos = process_cmd.set
  271. process_cmd.geodist = process_cmd.set
  272. process_cmd.georadius = process_cmd.set
  273. process_cmd.georadiusbymember = process_cmd.set
  274. process_cmd.get = process_cmd.set
  275. process_cmd.getbit = process_cmd.set
  276. process_cmd.getrange = process_cmd.set
  277. process_cmd.getset = process_cmd.set
  278. process_cmd.hdel = process_cmd.set
  279. process_cmd.hexists = process_cmd.set
  280. process_cmd.hget = process_cmd.set
  281. process_cmd.hgetall = process_cmd.set
  282. process_cmd.hincrby = process_cmd.set
  283. process_cmd.hincrbyfloat = process_cmd.set
  284. process_cmd.hkeys = process_cmd.set
  285. process_cmd.hlen = process_cmd.set
  286. process_cmd.hmget = process_cmd.set
  287. process_cmd.hmset = process_cmd.set
  288. process_cmd.hscan = process_cmd.set
  289. process_cmd.hset = process_cmd.set
  290. process_cmd.hsetnx = process_cmd.set
  291. process_cmd.hstrlen = process_cmd.set
  292. process_cmd.hvals = process_cmd.set
  293. process_cmd.incr = process_cmd.set
  294. process_cmd.incrby = process_cmd.set
  295. process_cmd.incrbyfloat = process_cmd.set
  296. process_cmd.info = process_cmd.script
  297. process_cmd.keys = process_cmd.script
  298. process_cmd.lastsave = process_cmd.script
  299. process_cmd.lindex = process_cmd.set
  300. process_cmd.linsert = process_cmd.set
  301. process_cmd.llen = process_cmd.set
  302. process_cmd.lpop = process_cmd.set
  303. process_cmd.lpush = process_cmd.set
  304. process_cmd.lpushx = process_cmd.set
  305. process_cmd.lrange = process_cmd.set
  306. process_cmd.lrem = process_cmd.set
  307. process_cmd.lset = process_cmd.set
  308. process_cmd.ltrim = process_cmd.set
  309. process_cmd.migrate = process_cmd.script
  310. process_cmd.monitor = process_cmd.script
  311. process_cmd.move = process_cmd.set
  312. process_cmd.msetnx = process_cmd.mset
  313. process_cmd.multi = process_cmd.script
  314. process_cmd.object = process_cmd.script
  315. process_cmd.persist = process_cmd.set
  316. process_cmd.pexpire = process_cmd.set
  317. process_cmd.pexpireat = process_cmd.set
  318. process_cmd.pfadd = process_cmd.set
  319. process_cmd.pfcount = process_cmd.set
  320. process_cmd.pfmerge = process_cmd.mget
  321. process_cmd.ping = process_cmd.script
  322. process_cmd.psetex = process_cmd.set
  323. process_cmd.psubscribe = process_cmd.script
  324. process_cmd.pubsub = process_cmd.script
  325. process_cmd.pttl = process_cmd.set
  326. process_cmd.publish = process_cmd.script
  327. process_cmd.punsubscribe = process_cmd.script
  328. process_cmd.quit = process_cmd.script
  329. process_cmd.randomkey = process_cmd.script
  330. process_cmd.readonly = process_cmd.script
  331. process_cmd.readwrite = process_cmd.script
  332. process_cmd.rename = process_cmd.mget
  333. process_cmd.renamenx = process_cmd.mget
  334. process_cmd.restore = process_cmd.set
  335. process_cmd.role = process_cmd.script
  336. process_cmd.rpop = process_cmd.set
  337. process_cmd.rpoplpush = process_cmd.mget
  338. process_cmd.rpush = process_cmd.set
  339. process_cmd.rpushx = process_cmd.set
  340. process_cmd.sadd = process_cmd.set
  341. process_cmd.save = process_cmd.script
  342. process_cmd.scard = process_cmd.set
  343. process_cmd.sdiff = process_cmd.mget
  344. process_cmd.select = process_cmd.script
  345. process_cmd.setbit = process_cmd.set
  346. process_cmd.setex = process_cmd.set
  347. process_cmd.setnx = process_cmd.set
  348. process_cmd.sinterstore = process_cmd.sdiff
  349. process_cmd.sismember = process_cmd.set
  350. process_cmd.slaveof = process_cmd.script
  351. process_cmd.slowlog = process_cmd.script
  352. process_cmd.smembers = process_cmd.script
  353. process_cmd.sort = process_cmd.set
  354. process_cmd.spop = process_cmd.set
  355. process_cmd.srandmember = process_cmd.set
  356. process_cmd.srem = process_cmd.set
  357. process_cmd.strlen = process_cmd.set
  358. process_cmd.subscribe = process_cmd.script
  359. process_cmd.sunion = process_cmd.mget
  360. process_cmd.sunionstore = process_cmd.mget
  361. process_cmd.swapdb = process_cmd.script
  362. process_cmd.sync = process_cmd.script
  363. process_cmd.time = process_cmd.script
  364. process_cmd.touch = process_cmd.mget
  365. process_cmd.ttl = process_cmd.set
  366. process_cmd.type = process_cmd.set
  367. process_cmd.unsubscribe = process_cmd.script
  368. process_cmd.unlink = process_cmd.mget
  369. process_cmd.unwatch = process_cmd.script
  370. process_cmd.wait = process_cmd.script
  371. process_cmd.watch = process_cmd.mget
  372. process_cmd.zadd = process_cmd.set
  373. process_cmd.zcard = process_cmd.set
  374. process_cmd.zcount = process_cmd.set
  375. process_cmd.zincrby = process_cmd.set
  376. process_cmd.zinterstore = process_cmd.eval
  377. process_cmd.zlexcount = process_cmd.set
  378. process_cmd.zrange = process_cmd.set
  379. process_cmd.zrangebylex = process_cmd.set
  380. process_cmd.zrank = process_cmd.set
  381. process_cmd.zrem = process_cmd.set
  382. process_cmd.zrembylex = process_cmd.set
  383. process_cmd.zrembyrank = process_cmd.set
  384. process_cmd.zrembyscore = process_cmd.set
  385. process_cmd.zrevrange = process_cmd.set
  386. process_cmd.zrevrangebyscore = process_cmd.set
  387. process_cmd.zrevrank = process_cmd.set
  388. process_cmd.zscore = process_cmd.set
  389. process_cmd.zunionstore = process_cmd.eval
  390. process_cmd.scan = process_cmd.script
  391. process_cmd.sscan = process_cmd.set
  392. process_cmd.hscan = process_cmd.set
  393. process_cmd.zscan = process_cmd.set
  394. local function get_key_indexes(cmd, args)
  395. local idx_l = {}
  396. cmd = string.lower(cmd)
  397. if process_cmd[cmd] then
  398. idx_l = process_cmd[cmd](args)
  399. else
  400. logger.warnx(rspamd_config, "Don't know how to extract keys for %s Redis command", cmd)
  401. end
  402. return idx_l
  403. end
  404. local gen_meta = {
  405. principal_recipient = function(task)
  406. return task:get_principal_recipient()
  407. end,
  408. principal_recipient_domain = function(task)
  409. local p = task:get_principal_recipient()
  410. if not p then return end
  411. return string.match(p, '.*@(.*)')
  412. end,
  413. ip = function(task)
  414. local i = task:get_ip()
  415. if i and i:is_valid() then return i:to_string() end
  416. end,
  417. from = function(task)
  418. return ((task:get_from('smtp') or E)[1] or E)['addr']
  419. end,
  420. from_domain = function(task)
  421. return ((task:get_from('smtp') or E)[1] or E)['domain']
  422. end,
  423. from_domain_or_helo_domain = function(task)
  424. local d = ((task:get_from('smtp') or E)[1] or E)['domain']
  425. if d and #d > 0 then return d end
  426. return task:get_helo()
  427. end,
  428. mime_from = function(task)
  429. return ((task:get_from('mime') or E)[1] or E)['addr']
  430. end,
  431. mime_from_domain = function(task)
  432. return ((task:get_from('mime') or E)[1] or E)['domain']
  433. end,
  434. }
  435. local function gen_get_esld(f)
  436. return function(task)
  437. local d = f(task)
  438. if not d then return end
  439. return rspamd_util.get_tld(d)
  440. end
  441. end
  442. gen_meta.smtp_from = gen_meta.from
  443. gen_meta.smtp_from_domain = gen_meta.from_domain
  444. gen_meta.smtp_from_domain_or_helo_domain = gen_meta.from_domain_or_helo_domain
  445. gen_meta.esld_principal_recipient_domain = gen_get_esld(gen_meta.principal_recipient_domain)
  446. gen_meta.esld_from_domain = gen_get_esld(gen_meta.from_domain)
  447. gen_meta.esld_smtp_from_domain = gen_meta.esld_from_domain
  448. gen_meta.esld_mime_from_domain = gen_get_esld(gen_meta.mime_from_domain)
  449. gen_meta.esld_from_domain_or_helo_domain = gen_get_esld(gen_meta.from_domain_or_helo_domain)
  450. gen_meta.esld_smtp_from_domain_or_helo_domain = gen_meta.esld_from_domain_or_helo_domain
  451. local function get_key_expansion_metadata(task)
  452. local md_mt = {
  453. __index = function(self, k)
  454. k = string.lower(k)
  455. local v = rawget(self, k)
  456. if v then
  457. return v
  458. end
  459. if gen_meta[k] then
  460. v = gen_meta[k](task)
  461. rawset(self, k, v)
  462. end
  463. return v
  464. end,
  465. }
  466. local lazy_meta = {}
  467. setmetatable(lazy_meta, md_mt)
  468. return lazy_meta
  469. end
  470. -- Performs async call to redis hiding all complexity inside function
  471. -- task - rspamd_task
  472. -- redis_params - valid params returned by rspamd_parse_redis_server
  473. -- key - key to select upstream or nil to select round-robin/master-slave
  474. -- is_write - true if need to write to redis server
  475. -- callback - function to be called upon request is completed
  476. -- command - redis command
  477. -- args - table of arguments
  478. -- extra_opts - table of optional request arguments
  479. local function rspamd_redis_make_request(task, redis_params, key, is_write,
  480. callback, command, args, extra_opts)
  481. local addr
  482. local function rspamd_redis_make_request_cb(err, data)
  483. if err then
  484. addr:fail()
  485. else
  486. addr:ok()
  487. end
  488. callback(err, data, addr)
  489. end
  490. if not task or not redis_params or not callback or not command then
  491. return false,nil,nil
  492. end
  493. local rspamd_redis = require "rspamd_redis"
  494. if key then
  495. if is_write then
  496. addr = redis_params['write_servers']:get_upstream_by_hash(key)
  497. else
  498. addr = redis_params['read_servers']:get_upstream_by_hash(key)
  499. end
  500. else
  501. if is_write then
  502. addr = redis_params['write_servers']:get_upstream_master_slave(key)
  503. else
  504. addr = redis_params['read_servers']:get_upstream_round_robin(key)
  505. end
  506. end
  507. if not addr then
  508. logger.errx(task, 'cannot select server to make redis request')
  509. end
  510. if redis_params['expand_keys'] then
  511. local m = get_key_expansion_metadata(task)
  512. local indexes = get_key_indexes(command, args)
  513. for _, i in ipairs(indexes) do
  514. args[i] = lutil.template(args[i], m)
  515. end
  516. end
  517. local ip_addr = addr:get_addr()
  518. local options = {
  519. task = task,
  520. callback = rspamd_redis_make_request_cb,
  521. host = ip_addr,
  522. timeout = redis_params['timeout'],
  523. cmd = command,
  524. args = args
  525. }
  526. if extra_opts then
  527. for k,v in pairs(extra_opts) do
  528. options[k] = v
  529. end
  530. end
  531. if redis_params['password'] then
  532. options['password'] = redis_params['password']
  533. end
  534. if redis_params['db'] then
  535. options['dbname'] = redis_params['db']
  536. end
  537. local ret,conn = rspamd_redis.make_request(options)
  538. if not ret then
  539. addr:fail()
  540. logger.warnx(task, "cannot make redis request to: %s", tostring(ip_addr))
  541. end
  542. return ret,conn,addr
  543. end
--[[[
-- @function lua_redis.redis_make_request(task, redis_params, key, is_write, callback, command, args, extra_opts)
-- Sends a request to Redis
-- @param {rspamd_task} task task object
-- @param {table} redis_params redis configuration in format returned by lua_redis.parse_redis_server()
-- @param {string} key key to use for sharding (`nil` picks a server round-robin for reads, master-slave for writes)
-- @param {boolean} is_write should be `true` if we are performing a write operation
-- @param {function} callback callback function (first parameter is error if applicable, second is a 2D array (table), third is the upstream used)
-- @param {string} command Redis command to run
-- @param {table} args Numerically indexed table containing arguments for command
-- @param {table} extra_opts optional table whose entries are copied into the request options
-- @return {boolean,connection,upstream} whether the request was started, the connection returned by rspamd_redis.make_request, and the selected upstream
--]]
exports.rspamd_redis_make_request = rspamd_redis_make_request
exports.redis_make_request = rspamd_redis_make_request
  557. local function redis_make_request_taskless(ev_base, cfg, redis_params, key,
  558. is_write, callback, command, args, extra_opts)
  559. if not ev_base or not redis_params or not callback or not command then
  560. return false,nil,nil
  561. end
  562. local addr
  563. local function rspamd_redis_make_request_cb(err, data)
  564. if err then
  565. addr:fail()
  566. else
  567. addr:ok()
  568. end
  569. callback(err, data, addr)
  570. end
  571. local rspamd_redis = require "rspamd_redis"
  572. if key then
  573. if is_write then
  574. addr = redis_params['write_servers']:get_upstream_by_hash(key)
  575. else
  576. addr = redis_params['read_servers']:get_upstream_by_hash(key)
  577. end
  578. else
  579. if is_write then
  580. addr = redis_params['write_servers']:get_upstream_master_slave(key)
  581. else
  582. addr = redis_params['read_servers']:get_upstream_round_robin(key)
  583. end
  584. end
  585. if not addr then
  586. logger.errx(cfg, 'cannot select server to make redis request')
  587. end
  588. local options = {
  589. ev_base = ev_base,
  590. config = cfg,
  591. callback = rspamd_redis_make_request_cb,
  592. host = addr:get_addr(),
  593. timeout = redis_params['timeout'],
  594. cmd = command,
  595. args = args
  596. }
  597. if extra_opts then
  598. for k,v in pairs(extra_opts) do
  599. options[k] = v
  600. end
  601. end
  602. if redis_params['password'] then
  603. options['password'] = redis_params['password']
  604. end
  605. if redis_params['db'] then
  606. options['dbname'] = redis_params['db']
  607. end
  608. local ret,conn = rspamd_redis.make_request(options)
  609. if not ret then
  610. logger.errx('cannot execute redis request')
  611. addr:fail()
  612. end
  613. return ret,conn,addr
  614. end
--[[[
-- @function lua_redis.redis_make_request_taskless(ev_base, cfg, redis_params, key, is_write, callback, command, args, extra_opts)
-- Sends a request to Redis in context where `task` is not available for some specific use-cases
-- Identical to redis_make_request() except in that first parameter is an `event base` object
-- and the second one is the config object (used for logging and passed to the request);
-- unlike redis_make_request(), no key expansion is performed
--]]
exports.rspamd_redis_make_request_taskless = redis_make_request_taskless
exports.redis_make_request_taskless = redis_make_request_taskless
  622. local redis_scripts = {
  623. }
  624. local function script_set_loaded(script)
  625. if script.sha then
  626. script.loaded = true
  627. end
  628. local wait_table = {}
  629. for _,s in ipairs(script.waitq) do
  630. table.insert(wait_table, s)
  631. end
  632. script.waitq = {}
  633. for _,s in ipairs(wait_table) do
  634. s(script.loaded)
  635. end
  636. end
  637. local function prepare_redis_call(script)
  638. local function merge_tables(t1, t2)
  639. for k,v in pairs(t2) do t1[k] = v end
  640. end
  641. local servers = {}
  642. local options = {}
  643. if script.redis_params.read_servers then
  644. merge_tables(servers, script.redis_params.read_servers:all_upstreams())
  645. end
  646. if script.redis_params.write_servers then
  647. merge_tables(servers, script.redis_params.write_servers:all_upstreams())
  648. end
  649. -- Call load script on each server, set loaded flag
  650. script.in_flight = #servers
  651. for _,s in ipairs(servers) do
  652. local cur_opts = {
  653. host = s:get_addr(),
  654. timeout = script.redis_params['timeout'],
  655. cmd = 'SCRIPT',
  656. args = {'LOAD', script.script },
  657. upstream = s
  658. }
  659. if script.redis_params['password'] then
  660. cur_opts['password'] = script.redis_params['password']
  661. end
  662. if script.redis_params['db'] then
  663. cur_opts['dbname'] = script.redis_params['db']
  664. end
  665. table.insert(options, cur_opts)
  666. end
  667. return options
  668. end
  669. local function load_script_task(script, task)
  670. local rspamd_redis = require "rspamd_redis"
  671. local opts = prepare_redis_call(script)
  672. for _,opt in ipairs(opts) do
  673. opt.task = task
  674. opt.callback = function(err, data)
  675. if err then
  676. logger.errx(task, 'cannot upload script to %s: %s',
  677. opt.upstream:get_addr(), err)
  678. opt.upstream:fail()
  679. script.fatal_error = err
  680. else
  681. opt.upstream:ok()
  682. logger.infox(task,
  683. "uploaded redis script to %s with id %s, sha: %s",
  684. opt.upstream:get_addr(), script.id, data)
  685. script.sha = data -- We assume that sha is the same on all servers
  686. end
  687. script.in_flight = script.in_flight - 1
  688. if script.in_flight == 0 then
  689. script_set_loaded(script)
  690. end
  691. end
  692. local ret = rspamd_redis.make_request(opt)
  693. if not ret then
  694. logger.errx('cannot execute redis request to load script on %s',
  695. opt.upstream:get_addr())
  696. script.in_flight = script.in_flight - 1
  697. opt.upstream:fail()
  698. end
  699. if script.in_flight == 0 then
  700. script_set_loaded(script)
  701. end
  702. end
  703. end
  704. local function load_script_taskless(script, cfg, ev_base)
  705. local rspamd_redis = require "rspamd_redis"
  706. local opts = prepare_redis_call(script)
  707. for _,opt in ipairs(opts) do
  708. opt.config = cfg
  709. opt.ev_base = ev_base
  710. opt.callback = function(err, data)
  711. if err then
  712. logger.errx(cfg, 'cannot upload script to %s: %s',
  713. opt.upstream:get_addr(), err)
  714. opt.upstream:fail()
  715. script.fatal_error = err
  716. else
  717. opt.upstream:ok()
  718. logger.infox(cfg,
  719. "uploaded redis script to %s with id %s, sha: %s",
  720. opt.upstream:get_addr(), script.id, data)
  721. script.sha = data -- We assume that sha is the same on all servers
  722. script.fatal_error = nil
  723. end
  724. script.in_flight = script.in_flight - 1
  725. if script.in_flight == 0 then
  726. script_set_loaded(script)
  727. end
  728. end
  729. local ret = rspamd_redis.make_request(opt)
  730. if not ret then
  731. logger.errx('cannot execute redis request to load script on %s',
  732. opt.upstream:get_addr())
  733. script.in_flight = script.in_flight - 1
  734. opt.upstream:fail()
  735. end
  736. if script.in_flight == 0 then
  737. script_set_loaded(script)
  738. end
  739. end
  740. end
  741. local function load_redis_script(script, cfg, ev_base, _)
  742. if script.redis_params then
  743. load_script_taskless(script, cfg, ev_base)
  744. end
  745. end
  746. local function add_redis_script(script, redis_params)
  747. local new_script = {
  748. loaded = false,
  749. redis_params = redis_params,
  750. script = script,
  751. waitq = {}, -- callbacks pending for script being loaded
  752. id = #redis_scripts + 1
  753. }
  754. -- Register on load function
  755. rspamd_config:add_on_load(function(cfg, ev_base, worker)
  756. local mult = 0.0
  757. rspamd_config:add_periodic(ev_base, 0.0, function()
  758. if not new_script.sha then
  759. load_redis_script(new_script, cfg, ev_base, worker)
  760. mult = mult + 1
  761. return 1.0 * mult -- Check one more time in one second
  762. end
  763. return false
  764. end, false)
  765. end)
  766. table.insert(redis_scripts, new_script)
  767. return #redis_scripts
  768. end
  769. exports.add_redis_script = add_redis_script
  770. local function exec_redis_script(id, params, callback, keys, args)
  771. local redis_args = {}
  772. if not redis_scripts[id] then
  773. logger.errx("cannot find registered script with id %s", id)
  774. return false
  775. end
  776. local script = redis_scripts[id]
  777. if script.fatal_error then
  778. callback(script.fatal_error, nil)
  779. return true
  780. end
  781. if not script.redis_params then
  782. callback('no redis servers defined', nil)
  783. return true
  784. end
  785. local function do_call(can_reload)
  786. local function redis_cb(err, data)
  787. if not err then
  788. callback(err, data)
  789. elseif string.match(err, 'NOSCRIPT') then
  790. -- Schedule restart
  791. script.sha = nil
  792. if can_reload then
  793. table.insert(script.waitq, do_call)
  794. if script.in_flight == 0 then
  795. -- Reload scripts if this has not been initiated yet
  796. if params.task then
  797. load_script_task(script, params.task)
  798. else
  799. load_script_taskless(script, rspamd_config, params.ev_base)
  800. end
  801. end
  802. else
  803. callback(err, data)
  804. end
  805. else
  806. callback(err, data)
  807. end
  808. end
  809. if #redis_args == 0 then
  810. table.insert(redis_args, script.sha)
  811. table.insert(redis_args, tostring(#keys))
  812. for _,k in ipairs(keys) do
  813. table.insert(redis_args, k)
  814. end
  815. if type(args) == 'table' then
  816. for _, a in ipairs(args) do
  817. table.insert(redis_args, a)
  818. end
  819. end
  820. end
  821. if params.task then
  822. if not rspamd_redis_make_request(params.task, script.redis_params,
  823. params.key, params.is_write, redis_cb, 'EVALSHA', redis_args) then
  824. callback('Cannot make redis request', nil)
  825. end
  826. else
  827. if not redis_make_request_taskless(params.ev_base, rspamd_config,
  828. script.redis_params,
  829. params.key, params.is_write, redis_cb, 'EVALSHA', redis_args) then
  830. callback('Cannot make redis request', nil)
  831. end
  832. end
  833. end
  834. if script.loaded then
  835. do_call(true)
  836. else
  837. -- Delayed until scripts are loaded
  838. if not params.task then
  839. table.insert(script.waitq, do_call)
  840. else
  841. -- TODO: fix taskfull requests
  842. callback('NOSCRIPT', nil)
  843. end
  844. end
  845. return true
  846. end
-- Public entry point for running scripts registered with add_redis_script()
exports.exec_redis_script = exec_redis_script
  --[[[
  -- @function lua_redis.redis_connect_sync(redis_params, is_write, key, cfg, ev_base)
  -- Opens a synchronous (coroutine based) connection to a Redis server and
  -- queues AUTH / SELECT commands when a password or a database is configured
  -- @param redis_params a table of redis server parameters
  -- @param is_write pick the server from `write_servers` rather than `read_servers`
  -- @param key optional key; when present the server is chosen by hashing it
  -- @param cfg optional config object (defaults to the global `rspamd_config`)
  -- @param ev_base optional event base (defaults to the global `rspamadm_ev_base`)
  -- @return {result,connection,address} boolean result, connection object, selected upstream
  --         (`false,nil` when no server could be selected)
  --]]
  local function redis_connect_sync(redis_params, is_write, key, cfg, ev_base)
    if not redis_params then
      return false,nil
    end

    if not cfg then
      cfg = rspamd_config
    end
    if not ev_base then
      ev_base = rspamadm_ev_base
    end

    local rspamd_redis = require "rspamd_redis"

    local upstreams
    if is_write then
      upstreams = redis_params['write_servers']
    else
      upstreams = redis_params['read_servers']
    end

    local addr
    if upstreams then
      if key then
        addr = upstreams:get_upstream_by_hash(key)
      elseif is_write then
        addr = upstreams:get_upstream_master_slave(key)
      else
        addr = upstreams:get_upstream_round_robin(key)
      end
    end

    if not addr then
      logger.errx(cfg, 'cannot select server to make redis request')
      -- Bail out: continuing would index a nil upstream below
      return false,nil
    end

    local options = {
      host = addr:get_addr(),
      timeout = redis_params['timeout'],
      config = cfg,
      ev_base = ev_base
    }
    -- Every field of redis_params is forwarded to the connection and may
    -- override the defaults set above
    for k,v in pairs(redis_params) do
      options[k] = v
    end

    local ret,conn = rspamd_redis.connect_sync(options)
    if not ret then
      logger.errx('cannot execute redis request: %s', conn)
      addr:fail()
      return false,nil,addr
    end

    if conn then
      if redis_params['password'] then
        conn:add_cmd('AUTH', {redis_params['password']})
      end
      if redis_params['db'] then
        conn:add_cmd('SELECT', {tostring(redis_params['db'])})
      elseif redis_params['dbname'] then
        conn:add_cmd('SELECT', {tostring(redis_params['dbname'])})
      end
    end

    return ret,conn,addr
  end

  exports.redis_connect_sync = redis_connect_sync
  --[[[
  -- @function lua_redis.request(redis_params, attrs, req)
  -- Sends a request to Redis synchronously with coroutines or asynchronously using
  -- a callback (modern API)
  -- @param redis_params a table of redis server parameters
  -- @param attrs a table of redis request attributes (e.g. task, or ev_base + config + session);
  --   `callback` selects the asynchronous mode, `key` and `is_write` select the server
  -- @param req a command string, or a table holding a command followed by its arguments
  --   (the table is left unmodified)
  -- @return {result,data/connection,address} in asynchronous mode: boolean result, connection
  --   object and redis server address; in coroutine mode: the results of executing the
  --   command, or false,nil,address when the connection fails
  --]]
  exports.request = function(redis_params, attrs, req)
    if not attrs or not redis_params or not req then
      logger.errx('invalid arguments for redis request')
      return false,nil,nil
    end

    if not (attrs.task or (attrs.config and attrs.ev_base)) then
      logger.errx('invalid attributes for redis request')
      return false,nil,nil
    end

    local opts = lutil.shallowcopy(attrs)
    local log_obj = opts.task or opts.config
    local addr

    if opts.callback then
      -- Wrap callback: update the upstream state before handing the reply
      -- (and the upstream itself) to the caller
      local callback = opts.callback
      local function rspamd_redis_make_request_cb(err, data)
        if err then
          addr:fail()
        else
          addr:ok()
        end
        callback(err, data, addr)
      end
      opts.callback = rspamd_redis_make_request_cb
    end

    local rspamd_redis = require "rspamd_redis"
    local is_write = opts.is_write

    local upstreams
    if is_write then
      upstreams = redis_params['write_servers']
    else
      upstreams = redis_params['read_servers']
    end

    if upstreams then
      if opts.key then
        addr = upstreams:get_upstream_by_hash(opts.key)
      elseif is_write then
        addr = upstreams:get_upstream_master_slave(opts.key)
      else
        addr = upstreams:get_upstream_round_robin(opts.key)
      end
    end

    if not addr then
      logger.errx(log_obj, 'cannot select server to make redis request')
      -- Bail out: continuing would index a nil upstream below
      return false,nil,nil
    end

    opts.host = addr:get_addr()
    opts.timeout = redis_params.timeout

    if type(req) == 'string' then
      opts.cmd = req
    else
      -- Copy the arguments instead of popping the command out of `req`,
      -- so the caller's table stays intact and can be reused
      local args = {}
      for i = 2, #req do
        args[i - 1] = req[i]
      end
      opts.cmd = req[1]
      opts.args = args
    end

    if redis_params.password then
      opts.password = redis_params.password
    end

    if redis_params.db then
      opts.dbname = redis_params.db
    elseif redis_params.dbname then
      -- Also accept the `dbname` spelling of the database option
      opts.dbname = redis_params.dbname
    end

    if opts.callback then
      local ret,conn = rspamd_redis.make_request(opts)
      if not ret then
        logger.errx(log_obj, 'cannot execute redis request')
        addr:fail()
      end
      return ret,conn,addr
    end

    -- Coroutines version
    local ret,conn = rspamd_redis.connect_sync(opts)
    if not ret then
      logger.errx(log_obj, 'cannot execute redis request')
      addr:fail()
      return false,nil,addr
    end

    conn:add_cmd(opts.cmd, opts.args)
    return conn:exec()
  end
  --[[[
  -- @function lua_redis.connect(redis_params, attrs)
  -- Connects to Redis synchronously with coroutines or asynchronously using a callback (modern API)
  -- @param redis_params a table of redis server parameters
  -- @param attrs a table of redis request attributes (e.g. task, or ev_base + config + session);
  --   `callback` selects the asynchronous mode, `key` and `is_write` select the server
  -- @return {result,connection,address} boolean result, connection object, redis server address
  --]]
  exports.connect = function(redis_params, attrs)
    if not attrs or not redis_params then
      logger.errx('invalid arguments for redis connect')
      return false,nil,nil
    end

    if not (attrs.task or (attrs.config and attrs.ev_base)) then
      logger.errx('invalid attributes for redis connect')
      return false,nil,nil
    end

    local opts = lutil.shallowcopy(attrs)
    local log_obj = opts.task or opts.config
    local addr

    if opts.callback then
      -- Wrap callback: update the upstream state before handing the result
      -- (and the upstream itself) to the caller
      local callback = opts.callback
      local function rspamd_redis_make_request_cb(err, data)
        if err then
          addr:fail()
        else
          addr:ok()
        end
        callback(err, data, addr)
      end
      opts.callback = rspamd_redis_make_request_cb
    end

    local rspamd_redis = require "rspamd_redis"
    local is_write = opts.is_write

    local upstreams
    if is_write then
      upstreams = redis_params['write_servers']
    else
      upstreams = redis_params['read_servers']
    end

    if upstreams then
      if opts.key then
        addr = upstreams:get_upstream_by_hash(opts.key)
      elseif is_write then
        addr = upstreams:get_upstream_master_slave(opts.key)
      else
        addr = upstreams:get_upstream_round_robin(opts.key)
      end
    end

    if not addr then
      logger.errx(log_obj, 'cannot select server to make redis connect')
      -- Bail out: continuing would index a nil upstream below
      return false,nil,nil
    end

    opts.host = addr:get_addr()
    opts.timeout = redis_params.timeout

    if redis_params.password then
      opts.password = redis_params.password
    end

    if redis_params.db then
      opts.dbname = redis_params.db
    elseif redis_params.dbname then
      -- Also accept the `dbname` spelling of the database option
      opts.dbname = redis_params.dbname
    end

    if opts.callback then
      local ret,conn = rspamd_redis.connect(opts)
      if not ret then
        logger.errx(log_obj, 'cannot execute redis connect')
        addr:fail()
      end
      return ret,conn,addr
    end

    -- Coroutines version
    local ret,conn = rspamd_redis.connect_sync(opts)
    if not ret then
      logger.errx(log_obj, 'cannot execute redis connect')
      addr:fail()
      return false,nil,addr
    end

    return true,conn,addr
  end
  -- Hand the populated module table back to `require`
  return exports