You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

lua_redis.lua 28KB


  1. --[[
  2. Copyright (c) 2017, Vsevolod Stakhov <vsevolod@highsecure.ru>
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ]]--
  13. local logger = require "rspamd_logger"
  14. local lutil = require "lua_util"
  15. local rspamd_util = require "rspamd_util"
  16. local exports = {}
  17. local E = {}
  18. --[[[
  19. -- @module lua_redis
  20. -- This module contains helper functions for working with Redis
  21. --]]
  -- Reads Redis connection settings from `options` and stores them in `result`.
  -- options       table with server definitions (`read_servers`, `servers` or
  --               `server`, optional `write_servers`) and optional `timeout`,
  --               `prefix`, `expand_keys`, `db`/`dbname`/`database`, `password`
  -- rspamd_config configuration object handed to upstream_list.create when set
  -- result        table that accumulates the parsed settings across calls
  -- Returns true when both read and write upstream lists were created,
  -- false otherwise (the non-server settings are stored in either case).
  local function try_load_redis_servers(options, rspamd_config, result)
    local default_port = 6379
    local default_timeout = 1.0
    local default_expand_keys = false
    local upstream_list = require "rspamd_upstream_list"

    -- Creates an upstream list from a server definition, passing the
    -- configuration object along only when one was supplied
    local function make_upstreams(definition)
      if rspamd_config then
        return upstream_list.create(rspamd_config, definition, default_port)
      end
      return upstream_list.create(definition, default_port)
    end

    -- The first definition present wins: read_servers, then servers, then server
    local upstreams_read, upstreams_write
    local read_definition = options['read_servers'] or options['servers']
      or options['server']
    if read_definition then
      upstreams_read = make_upstreams(read_definition)
    end

    -- Writes go to dedicated servers when configured, else to the read servers
    if upstreams_read then
      if options['write_servers'] then
        upstreams_write = make_upstreams(options['write_servers'])
      else
        upstreams_write = upstreams_read
      end
    end

    -- A timeout that is unset or still at the default may be overridden
    local current_timeout = result['timeout']
    if not current_timeout or current_timeout == default_timeout then
      if options['timeout'] then
        result['timeout'] = tonumber(options['timeout'])
      else
        result['timeout'] = default_timeout
      end
    end

    if options['prefix'] and not result['prefix'] then
      result['prefix'] = options['prefix']
    end

    -- Unlike the other settings, expand_keys is rewritten on every call
    local expand_keys = options['expand_keys']
    if type(expand_keys) ~= 'boolean' then
      expand_keys = default_expand_keys
    end
    result['expand_keys'] = expand_keys

    if not result['db'] then
      local db = options['db'] or options['dbname'] or options['database']
      if db then
        result['db'] = tostring(db)
      end
    end

    if options['password'] and not result['password'] then
      result['password'] = options['password']
    end

    if not (upstreams_write and upstreams_read) then
      return false
    end

    result.read_servers = upstreams_read
    result.write_servers = upstreams_write
    return true
  end
  100. exports.try_load_redis_servers = try_load_redis_servers
  -- Resolves Redis settings for a module. Lookup order:
  --   1. the module's own options (`module_opts`, or the module's section of
  --      the configuration when `module_opts` is not given): first their
  --      `redis` sub-table, then the options themselves
  --   2. unless `no_fallback` is set, the global `redis` section: its
  --      per-module sub-table when present, otherwise the section itself
  --      (skipped for modules listed in `disabled_modules`)
  -- Returns the settings table, or nil when no servers could be loaded.
  local function rspamd_parse_redis_server(module_name, module_opts, no_fallback)
    local result = {}

    -- Module-level options
    local local_opts = module_opts or rspamd_config:get_all_opt(module_name)
    if local_opts then
      if local_opts.redis and
          try_load_redis_servers(local_opts.redis, rspamd_config, result) then
        return result
      end
      if try_load_redis_servers(local_opts, rspamd_config, result) then
        return result
      end
    end

    if no_fallback then
      return nil
    end

    -- Global options
    local global_opts = rspamd_config:get_all_opt('redis')
    if global_opts then
      local module_section = global_opts[module_name]
      if module_section then
        if try_load_redis_servers(module_section, rspamd_config, result) then
          return result
        end
      else
        local loaded = try_load_redis_servers(global_opts, rspamd_config, result)

        -- Modules may opt out of the default server
        for _, disabled in ipairs(global_opts['disabled_modules'] or {}) do
          if disabled == module_name then
            logger.infox(rspamd_config, "NOT using default redis server for module %s: it is disabled",
              module_name)
            return nil
          end
        end

        if loaded then
          logger.infox(rspamd_config, "using default redis server for module %s",
            module_name)
        end
      end
    end

    if not result.read_servers then
      return nil
    end
    return result
  end
  160. --[[[
  161. -- @function lua_redis.parse_redis_server(module_name, module_opts, no_fallback)
  162. -- Extracts Redis server settings from configuration
  163. -- @param {string} module_name name of module to get settings for
  164. -- @param {table} module_opts settings for module or `nil` to fetch them from configuration
  165. -- @param {boolean} no_fallback should be `true` if global settings must not be used
  166. -- @return {table} redis server settings
  167. -- @example
  168. -- local rconfig = lua_redis.parse_redis_server('my_module')
  169. -- -- rconfig contains upstream_list objects in ['write_servers'] and ['read_servers']
  170. -- -- ['timeout'] contains timeout in seconds
  171. -- -- ['expand_keys'] if true tells that redis key expansion is enabled
  172. --]]
  173. exports.rspamd_parse_redis_server = rspamd_parse_redis_server
  174. exports.parse_redis_server = rspamd_parse_redis_server
  -- Key-position extractors for Redis commands. Each entry maps a lowercased
  -- command name to a function that takes the command's argument table and
  -- returns the list of argument indexes holding key names. Every extractor
  -- returns a table (empty for commands without key arguments) so that callers
  -- can always iterate the result.
  local process_cmd
  do
    -- Returns the list {first, first + step, ...} not exceeding `last`
    local function index_range(first, last, step)
      local idx_l = {}
      for i = first, last, step or 1 do
        idx_l[#idx_l + 1] = i
      end
      return idx_l
    end

    process_cmd = {
      -- Every argument except the first one
      bitop = function(args)
        return index_range(2, #args)
      end,
      -- Every argument except the last one
      blpop = function(args)
        return index_range(1, #args - 1)
      end,
      -- The second argument is the number of keys; keys start at the third.
      -- A missing or non-numeric count yields no keys instead of an error.
      eval = function(args)
        local numkeys = tonumber(args[2])
        if numkeys and numkeys >= 1 then
          return index_range(3, numkeys + 2)
        end
        return {}
      end,
      -- Only the first argument
      set = function()
        return {1}
      end,
      -- Every argument
      mget = function(args)
        return index_range(1, #args)
      end,
      -- Every odd argument (key/value pairs)
      mset = function(args)
        return index_range(1, #args, 2)
      end,
      -- Every argument except the first one
      sdiffstore = function(args)
        return index_range(2, #args)
      end,
      -- The first two arguments
      smove = function()
        return {1, 2}
      end,
      -- No key arguments at all
      script = function()
        return {}
      end
    }

    -- Registers `handler` for each of the listed command names
    local function alias(handler, names)
      for _, name in ipairs(names) do
        process_cmd[name] = handler
      end
    end

    alias(process_cmd.set, {
      'append', 'bitcount', 'bitfield', 'bitpos', 'decr', 'decrby', 'dump',
      'expire', 'expireat', 'geoadd', 'geodist', 'geohash', 'geopos',
      'georadius', 'georadiusbymember', 'get', 'getbit', 'getrange', 'getset',
      'hdel', 'hexists', 'hget', 'hgetall', 'hincrby', 'hincrbyfloat', 'hkeys',
      'hlen', 'hmget', 'hmset', 'hscan', 'hset', 'hsetnx', 'hstrlen', 'hvals',
      'incr', 'incrby', 'incrbyfloat', 'lindex', 'linsert', 'llen', 'lpop',
      'lpush', 'lpushx', 'lrange', 'lrem', 'lset', 'ltrim', 'move', 'persist',
      'pexpire', 'pexpireat', 'pfadd', 'pfcount', 'psetex', 'pttl', 'restore',
      'rpop', 'rpush', 'rpushx', 'sadd', 'scard', 'setbit', 'setex', 'setnx',
      'sismember',
      -- SMEMBERS takes the set key as its first argument
      'smembers',
      'sort', 'spop', 'srandmember', 'srem', 'sscan', 'strlen', 'ttl', 'type',
      'zadd', 'zcard', 'zcount', 'zincrby', 'zlexcount', 'zrange',
      'zrangebylex', 'zrank', 'zrem', 'zrembylex', 'zrembyrank', 'zrembyscore',
      'zrevrange', 'zrevrangebyscore', 'zrevrank', 'zscan', 'zscore',
    })

    alias(process_cmd.script, {
      'auth', 'bgrewriteaof', 'bgsave', 'client', 'cluster', 'command',
      'config', 'dbsize', 'debug', 'discard', 'echo', 'exec', 'flushall',
      'flushdb', 'info', 'keys', 'lastsave', 'migrate', 'monitor', 'multi',
      'object', 'ping', 'psubscribe', 'publish', 'pubsub', 'punsubscribe',
      'quit', 'randomkey', 'readonly', 'readwrite', 'role', 'save', 'scan',
      'select', 'slaveof', 'slowlog', 'subscribe', 'swapdb', 'sync', 'time',
      'unsubscribe', 'unwatch', 'wait',
    })

    alias(process_cmd.mget, {
      'del', 'exists', 'pfmerge', 'rename', 'renamenx', 'rpoplpush', 'sdiff',
      'sinterstore', 'sunion', 'sunionstore', 'touch', 'unlink', 'watch',
    })

    alias(process_cmd.blpop, { 'brpop', 'brpoplpush' })
    alias(process_cmd.mset, { 'msetnx' })
    alias(process_cmd.eval, { 'evalsha', 'zinterstore', 'zunionstore' })
  end
  -- Returns the list of indexes in `args` that hold key names for the Redis
  -- command `cmd` (matched case-insensitively). Always returns a table: it is
  -- empty for unknown commands (a warning is logged) and for extractors that
  -- report no keys, so callers can pass the result straight to ipairs.
  local function get_key_indexes(cmd, args)
    cmd = string.lower(cmd)
    local extractor = process_cmd[cmd]
    if not extractor then
      logger.warnx(rspamd_config, "Don't know how to extract keys for %s Redis command", cmd)
      return {}
    end
    -- Guard against extractors that return nothing
    return extractor(args) or {}
  end
  -- Metadata generators used for key expansion. Each entry takes a task and
  -- returns the value for the variable of the same name, or nil when the
  -- value is not available.
  local gen_meta
  do
    -- Returns `field` of the first address returned by task:get_from(kind),
    -- or nil when there is no such address
    local function first_from_field(task, kind, field)
      return ((task:get_from(kind) or E)[1] or E)[field]
    end

    gen_meta = {
      principal_recipient = function(task)
        return task:get_principal_recipient()
      end,
      -- Everything after the last '@' of the principal recipient
      principal_recipient_domain = function(task)
        local rcpt = task:get_principal_recipient()
        if rcpt then
          return string.match(rcpt, '.*@(.*)')
        end
      end,
      ip = function(task)
        local ip_addr = task:get_ip()
        if ip_addr and ip_addr:is_valid() then
          return ip_addr:to_string()
        end
      end,
      from = function(task)
        return first_from_field(task, 'smtp', 'addr')
      end,
      from_domain = function(task)
        return first_from_field(task, 'smtp', 'domain')
      end,
      -- SMTP sender domain, falling back to HELO when it is missing or empty
      from_domain_or_helo_domain = function(task)
        local domain = first_from_field(task, 'smtp', 'domain')
        if domain and #domain > 0 then
          return domain
        end
        return task:get_helo()
      end,
      mime_from = function(task)
        return first_from_field(task, 'mime', 'addr')
      end,
      mime_from_domain = function(task)
        return first_from_field(task, 'mime', 'domain')
      end,
    }
  end
  -- Wraps the domain generator `f`: the returned generator passes the value
  -- produced by `f` through rspamd_util.get_tld, and yields nothing when `f`
  -- produces no value.
  local function gen_get_esld(f)
    return function(task)
      local domain = f(task)
      if domain then
        return rspamd_util.get_tld(domain)
      end
    end
  end
  do
    -- `smtp_*` names are synonyms for the unprefixed SMTP sender generators
    for _, name in ipairs({ 'from', 'from_domain', 'from_domain_or_helo_domain' }) do
      gen_meta['smtp_' .. name] = gen_meta[name]
    end

    -- `esld_*` variants pass the corresponding domain through gen_get_esld
    for _, name in ipairs({ 'principal_recipient_domain', 'from_domain',
        'mime_from_domain', 'from_domain_or_helo_domain' }) do
      gen_meta['esld_' .. name] = gen_get_esld(gen_meta[name])
    end

    -- `esld_smtp_*` names share the generators created above
    gen_meta.esld_smtp_from_domain = gen_meta.esld_from_domain
    gen_meta.esld_smtp_from_domain_or_helo_domain =
      gen_meta.esld_from_domain_or_helo_domain
  end
  -- Builds a lazily populated metadata table for `task`. Looking up a name
  -- (case-insensitively) runs the matching gen_meta generator on first access
  -- and stores its result under the lowercased name; names without a
  -- generator yield nil.
  local function get_key_expansion_metadata(task)
    local function lookup(self, name)
      name = string.lower(name)

      -- Reached for non-lowercase spellings of an already stored name
      local cached = rawget(self, name)
      if cached then
        return cached
      end

      local generator = gen_meta[name]
      if not generator then
        return cached
      end

      local value = generator(task)
      rawset(self, name, value)
      return value
    end

    return setmetatable({}, { __index = lookup })
  end
  -- Performs an asynchronous Redis request on behalf of a task.
  -- task - rspamd_task
  -- redis_params - valid params returned by rspamd_parse_redis_server
  -- key - key to select upstream by hash, or nil to select round-robin (reads)
  --       or master-slave (writes)
  -- is_write - true if the request must go to the write servers
  -- callback - function(err, data, addr) called when the request completes
  -- command - redis command
  -- args - table of arguments; key arguments are expanded in place when
  --        redis_params.expand_keys is set
  -- extra_opts - table of optional request arguments merged into the request
  -- Returns: ret, conn, addr as produced by rspamd_redis.make_request plus the
  -- selected upstream; false, nil, nil when a mandatory argument is missing or
  -- no upstream could be selected.
  local function rspamd_redis_make_request(task, redis_params, key, is_write,
      callback, command, args, extra_opts)
    if not task or not redis_params or not callback or not command then
      return false, nil, nil
    end

    local addr

    -- Reports the outcome to the upstream before invoking the user callback
    local function rspamd_redis_make_request_cb(err, data)
      if err then
        addr:fail()
      else
        addr:ok()
      end
      callback(err, data, addr)
    end

    local rspamd_redis = require "rspamd_redis"

    local servers = redis_params[is_write and 'write_servers' or 'read_servers']
    if key then
      addr = servers:get_upstream_by_hash(key)
    elseif is_write then
      addr = servers:get_upstream_master_slave(key)
    else
      addr = servers:get_upstream_round_robin(key)
    end

    if not addr then
      logger.errx(task, 'cannot select server to make redis request')
      -- Nothing to send the request to; using `addr` below would raise
      return false, nil, nil
    end

    -- Key expansion needs an argument table; the `or {}` covers commands
    -- whose extractor reports no keys at all
    if redis_params['expand_keys'] and type(args) == 'table' then
      local m = get_key_expansion_metadata(task)
      local indexes = get_key_indexes(command, args) or {}
      for _, i in ipairs(indexes) do
        args[i] = lutil.template(args[i], m)
      end
    end

    local ip_addr = addr:get_addr()
    local options = {
      task = task,
      callback = rspamd_redis_make_request_cb,
      host = ip_addr,
      timeout = redis_params['timeout'],
      cmd = command,
      args = args
    }

    -- Caller-supplied options may override the defaults above, but never the
    -- credentials and database taken from redis_params
    if extra_opts then
      for k, v in pairs(extra_opts) do
        options[k] = v
      end
    end

    if redis_params['password'] then
      options['password'] = redis_params['password']
    end

    if redis_params['db'] then
      options['dbname'] = redis_params['db']
    end

    local ret, conn = rspamd_redis.make_request(options)
    if not ret then
      addr:fail()
      logger.warnx(task, "cannot make redis request to: %s", tostring(ip_addr))
    end

    return ret, conn, addr
  end
  533. --[[[
  534. -- @function lua_redis.redis_make_request(task, redis_params, key, is_write, callback, command, args)
  535. -- Sends a request to Redis
  536. -- @param {rspamd_task} task task object
  537. -- @param {table} redis_params redis configuration in format returned by lua_redis.parse_redis_server()
  538. -- @param {string} key key to use for sharding
  539. -- @param {boolean} is_write should be `true` if we are performing a write operation
  540. -- @param {function} callback callback function (first parameter is error if applicable, second is a 2D array (table))
  541. -- @param {string} command Redis command to run
  542. -- @param {table} args Numerically indexed table containing arguments for command
  543. --]]
  544. exports.rspamd_redis_make_request = rspamd_redis_make_request
  545. exports.redis_make_request = rspamd_redis_make_request
  -- Performs an asynchronous Redis request when no task is available.
  -- ev_base - event base object passed to rspamd_redis.make_request
  -- cfg - configuration object passed to rspamd_redis.make_request and logger
  -- redis_params - valid params returned by rspamd_parse_redis_server
  -- key - key to select upstream by hash, or nil to select round-robin (reads)
  --       or master-slave (writes)
  -- is_write - true if the request must go to the write servers
  -- callback - function(err, data, addr) called when the request completes
  -- command - redis command
  -- args - table of arguments
  -- extra_opts - table of optional request arguments merged into the request
  -- Returns: ret, conn, addr as produced by rspamd_redis.make_request plus the
  -- selected upstream; false, nil, nil when a mandatory argument is missing or
  -- no upstream could be selected.
  local function redis_make_request_taskless(ev_base, cfg, redis_params, key,
      is_write, callback, command, args, extra_opts)
    if not ev_base or not redis_params or not callback or not command then
      return false, nil, nil
    end

    local addr

    -- Reports the outcome to the upstream before invoking the user callback
    local function rspamd_redis_make_request_cb(err, data)
      if err then
        addr:fail()
      else
        addr:ok()
      end
      callback(err, data, addr)
    end

    local rspamd_redis = require "rspamd_redis"

    local servers = redis_params[is_write and 'write_servers' or 'read_servers']
    if key then
      addr = servers:get_upstream_by_hash(key)
    elseif is_write then
      addr = servers:get_upstream_master_slave(key)
    else
      addr = servers:get_upstream_round_robin(key)
    end

    if not addr then
      logger.errx(cfg, 'cannot select server to make redis request')
      -- Nothing to send the request to; using `addr` below would raise
      return false, nil, nil
    end

    local options = {
      ev_base = ev_base,
      config = cfg,
      callback = rspamd_redis_make_request_cb,
      host = addr:get_addr(),
      timeout = redis_params['timeout'],
      cmd = command,
      args = args
    }

    -- Caller-supplied options may override the defaults above, but never the
    -- credentials and database taken from redis_params
    if extra_opts then
      for k, v in pairs(extra_opts) do
        options[k] = v
      end
    end

    if redis_params['password'] then
      options['password'] = redis_params['password']
    end

    if redis_params['db'] then
      options['dbname'] = redis_params['db']
    end

    local ret, conn = rspamd_redis.make_request(options)
    if not ret then
      logger.errx('cannot execute redis request')
      addr:fail()
    end

    return ret, conn, addr
  end
  604. --[[[
  605. -- @function lua_redis.redis_make_request_taskless(ev_base, cfg, redis_params, key, is_write, callback, command, args)
  606. -- Sends a request to Redis in a context where `task` is not available for some specific use-cases
  607. -- Identical to redis_make_request() except that the `task` parameter is replaced by an `event base` object followed by the configuration object
  608. --]]
  609. exports.rspamd_redis_make_request_taskless = redis_make_request_taskless
  610. exports.redis_make_request_taskless = redis_make_request_taskless
  611. local redis_scripts = {
  612. }
  -- Finishes a script upload round: marks the script as loaded when a sha is
  -- known, then runs every queued waiter with the current `loaded` flag.
  -- The queue is emptied before the waiters run, so waiters registered while
  -- they run are kept for the next round.
  local function script_set_loaded(script)
    if script.sha then
      script.loaded = true
    end

    -- Snapshot the queue, then reset it
    local pending = {}
    for i, waiter in ipairs(script.waitq) do
      pending[i] = waiter
    end
    script.waitq = {}

    for i = 1, #pending do
      pending[i](script.loaded)
    end
  end
  -- Builds one `SCRIPT LOAD` request description per upstream of the script's
  -- read and write server lists and records the number of requests in
  -- script.in_flight. Returns the list of request option tables; each carries
  -- its upstream in the `upstream` field.
  local function prepare_redis_call(script)
    local params = script.redis_params
    local servers = {}
    local options = {}

    -- Appends every upstream of `list`; appending (rather than copying by
    -- index) keeps upstreams of one list from replacing those of the other
    local function append_upstreams(list)
      for _, upstream in ipairs(list:all_upstreams()) do
        servers[#servers + 1] = upstream
      end
    end

    if params.read_servers then
      append_upstreams(params.read_servers)
    end
    -- The write list may be the very same object as the read list; skip it
    -- then so the script is not uploaded twice to each server
    if params.write_servers and
        not rawequal(params.write_servers, params.read_servers) then
      append_upstreams(params.write_servers)
    end

    -- Call load script on each server, set loaded flag
    script.in_flight = #servers
    for _, s in ipairs(servers) do
      local cur_opts = {
        host = s:get_addr(),
        timeout = params['timeout'],
        cmd = 'SCRIPT',
        args = {'LOAD', script.script },
        upstream = s
      }

      if params['password'] then
        cur_opts['password'] = params['password']
      end

      if params['db'] then
        cur_opts['dbname'] = params['db']
      end

      options[#options + 1] = cur_opts
    end

    return options
  end
  -- Uploads `script` to every upstream prepared by prepare_redis_call, using
  -- `task` as the request context. Each completed or failed request lowers
  -- script.in_flight; once it reaches zero the waiters are released through
  -- script_set_loaded.
  local function load_script_task(script, task)
    local rspamd_redis = require "rspamd_redis"

    for _, opt in ipairs(prepare_redis_call(script)) do
      opt.task = task
      opt.callback = function(err, data)
        if not err then
          opt.upstream:ok()
          logger.infox(task,
            "uploaded redis script to %s with id %s, sha: %s",
            opt.upstream:get_addr(), script.id, data)
          script.sha = data -- We assume that sha is the same on all servers
        else
          logger.errx(task, 'cannot upload script to %s: %s',
            opt.upstream:get_addr(), err)
          opt.upstream:fail()
          script.fatal_error = err
        end

        script.in_flight = script.in_flight - 1
        if script.in_flight == 0 then
          script_set_loaded(script)
        end
      end

      -- A request that could not even be submitted counts as finished
      if not rspamd_redis.make_request(opt) then
        logger.errx('cannot execute redis request to load script on %s',
          opt.upstream:get_addr())
        script.in_flight = script.in_flight - 1
        opt.upstream:fail()
      end

      if script.in_flight == 0 then
        script_set_loaded(script)
      end
    end
  end
  -- Uploads `script` to every upstream prepared by prepare_redis_call, using
  -- the configuration object and event base as the request context. Each
  -- completed or failed request lowers script.in_flight; once it reaches zero
  -- the waiters are released through script_set_loaded.
  local function load_script_taskless(script, cfg, ev_base)
    local rspamd_redis = require "rspamd_redis"

    for _, opt in ipairs(prepare_redis_call(script)) do
      opt.config = cfg
      opt.ev_base = ev_base
      opt.callback = function(err, data)
        if not err then
          opt.upstream:ok()
          logger.infox(cfg,
            "uploaded redis script to %s with id %s, sha: %s",
            opt.upstream:get_addr(), script.id, data)
          script.sha = data -- We assume that sha is the same on all servers
        else
          logger.errx(cfg, 'cannot upload script to %s: %s',
            opt.upstream:get_addr(), err)
          opt.upstream:fail()
          script.fatal_error = err
        end

        script.in_flight = script.in_flight - 1
        if script.in_flight == 0 then
          script_set_loaded(script)
        end
      end

      -- A request that could not even be submitted counts as finished
      if not rspamd_redis.make_request(opt) then
        logger.errx('cannot execute redis request to load script on %s',
          opt.upstream:get_addr())
        script.in_flight = script.in_flight - 1
        opt.upstream:fail()
      end

      if script.in_flight == 0 then
        script_set_loaded(script)
      end
    end
  end
  -- On-load hook body: uploads `script` via load_script_taskless, but only
  -- when Redis parameters were supplied for it. The fourth argument is unused.
  local function load_redis_script(script, cfg, ev_base, _)
    if not script.redis_params then
      return
    end
    load_script_taskless(script, cfg, ev_base)
  end
  -- Registers a Redis script and schedules its upload for when the
  -- configuration's on-load hooks run.
  -- script       script source text
  -- redis_params Redis settings used to reach the servers
  -- Returns the id under which the script is stored in redis_scripts.
  local function add_redis_script(script, redis_params)
    local entry = {
      id = #redis_scripts + 1,
      script = script,
      redis_params = redis_params,
      loaded = false,
      waitq = {}, -- callbacks pending for script being loaded
    }

    -- Register on load function
    rspamd_config:add_on_load(function(cfg, ev_base, worker)
      load_redis_script(entry, cfg, ev_base, worker)
    end)

    table.insert(redis_scripts, entry)

    return #redis_scripts
  end
  749. exports.add_redis_script = add_redis_script
  -- Runs a previously registered script with EVALSHA.
  -- id       id returned by add_redis_script
  -- params   table with `task` (or `ev_base` for taskless use), `key` and
  --          `is_write`, forwarded to the request functions
  -- callback function(err, data) receiving the outcome
  -- keys     list of key names; nil is treated as an empty list
  -- args     optional list of additional script arguments
  -- Returns false when no script is registered under `id`, true otherwise
  -- (failures are then reported through `callback`).
  local function exec_redis_script(id, params, callback, keys, args)
    local script = redis_scripts[id]
    if not script then
      logger.errx("cannot find registered script with id %s", id)
      return false
    end

    if script.fatal_error then
      callback(script.fatal_error, nil)
      return true
    end

    if not script.redis_params then
      callback('no redis servers defined', nil)
      return true
    end

    -- Built on the first attempt and reused when the call is retried
    local redis_args = {}

    local function do_call(can_reload)
      local function redis_cb(err, data)
        if err and string.match(err, 'NOSCRIPT') then
          -- Schedule restart
          script.sha = nil
          if can_reload then
            table.insert(script.waitq, do_call)
            if script.in_flight == 0 then
              -- Reload scripts if this has not been initiated yet
              if params.task then
                load_script_task(script, params.task)
              else
                load_script_taskless(script, rspamd_config, params.ev_base)
              end
            end
            -- The queued retry reports the final outcome
            return
          end
        end

        callback(err, data)
      end

      if #redis_args == 0 then
        -- A missing key list means "no keys", mirroring the handling of args
        local key_list = type(keys) == 'table' and keys or {}
        table.insert(redis_args, script.sha)
        table.insert(redis_args, tostring(#key_list))
        for _, k in ipairs(key_list) do
          table.insert(redis_args, k)
        end

        if type(args) == 'table' then
          for _, a in ipairs(args) do
            table.insert(redis_args, a)
          end
        end
      end

      local sent
      if params.task then
        sent = rspamd_redis_make_request(params.task, script.redis_params,
          params.key, params.is_write, redis_cb, 'EVALSHA', redis_args)
      else
        sent = redis_make_request_taskless(params.ev_base, rspamd_config,
          script.redis_params,
          params.key, params.is_write, redis_cb, 'EVALSHA', redis_args)
      end

      if not sent then
        callback('Cannot make redis request', nil)
      end
    end

    if script.loaded then
      do_call(true)
    elseif not params.task then
      -- Delayed until scripts are loaded
      table.insert(script.waitq, do_call)
    else
      -- TODO: fix taskfull requests
      callback('NOSCRIPT', nil)
    end

    return true
  end
  827. exports.exec_redis_script = exec_redis_script
  -- Opens a synchronous Redis connection.
  -- redis_params - valid params returned by rspamd_parse_redis_server
  -- is_write - true to pick the upstream from the write servers
  -- key - key to select upstream by hash, or nil to select round-robin (reads)
  --       or master-slave (writes)
  -- cfg - configuration object used for logging
  -- Returns ret, conn, addr from rspamd_redis.connect_sync plus the selected
  -- upstream; `false, nil` when redis_params is missing or no upstream could
  -- be selected, and `false, nil, addr` when the connection attempt failed.
  -- AUTH and SELECT commands are queued on the new connection when a password
  -- or database is configured.
  local function redis_connect_sync(redis_params, is_write, key, cfg)
    if not redis_params then
      return false, nil
    end

    local rspamd_redis = require "rspamd_redis"

    local addr
    local servers = redis_params[is_write and 'write_servers' or 'read_servers']
    if key then
      addr = servers:get_upstream_by_hash(key)
    elseif is_write then
      addr = servers:get_upstream_master_slave(key)
    else
      addr = servers:get_upstream_round_robin(key)
    end

    if not addr then
      logger.errx(cfg, 'cannot select server to make redis request')
      -- Nothing to connect to; using `addr` below would raise
      return false, nil, nil
    end

    local options = {
      host = addr:get_addr(),
      timeout = redis_params['timeout'],
    }

    local ret, conn = rspamd_redis.connect_sync(options)
    if not ret then
      -- On failure the second return value is passed to the log message
      logger.errx('cannot execute redis request: %s', conn)
      addr:fail()

      return false, nil, addr
    end

    if conn then
      if redis_params['password'] then
        conn:add_cmd('AUTH', {redis_params['password']})
      end

      local db = redis_params['db'] or redis_params['dbname']
      if db then
        conn:add_cmd('SELECT', {tostring(db)})
      end
    end

    return ret, conn, addr
  end
  872. exports.redis_connect_sync = redis_connect_sync
  873. return exports