You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

lua_redis.lua 24KB


  1. --[[
  2. Copyright (c) 2017, Vsevolod Stakhov <vsevolod@highsecure.ru>
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ]]--
  13. local logger = require "rspamd_logger"
  14. local lutil = require "lua_util"
  15. local rspamd_util = require "rspamd_util"
  16. local exports = {}
  17. local E = {}
  -- Parses the Redis server definition for a module.
  --
  -- Sources are tried in order; the first one that yields servers wins:
  --   1. `module_opts` (or the module's own config section): its `redis`
  --      subsection first, then the section itself;
  --   2. unless `no_fallback` is set, the global `redis` section: the
  --      per-module subsection `redis.<module_name>` when present, otherwise
  --      the global servers (refused when the module is listed in
  --      `redis.disabled_modules`).
  -- Settings (timeout, prefix, db, password, expand_keys) found in an earlier,
  -- more specific source are kept even if the servers come from a later one.
  --
  -- @param module_name name of the module (config lookup and logging)
  -- @param module_opts optional module options table; when nil the options
  --   are fetched with rspamd_config:get_all_opt(module_name)
  -- @param no_fallback when true, never fall back to the global `redis` section
  -- @return table with read_servers, write_servers, timeout, expand_keys and
  --   optionally prefix, db, password; nil when no servers were found
  local function rspamd_parse_redis_server(module_name, module_opts, no_fallback)
    local result = {}
    local default_port = 6379
    local default_timeout = 1.0
    local default_expand_keys = false
    local upstream_list = require "rspamd_upstream_list"

    -- Loads servers and settings from one options table into `result`.
    -- Returns true when both read and write upstreams could be created.
    local function try_load_redis_servers(options)
      local upstreams_read, upstreams_write
      -- read_servers take precedence over servers, then server
      local read_def = options['read_servers'] or options['servers'] or options['server']
      if read_def then
        upstreams_read = upstream_list.create(rspamd_config, read_def, default_port)
      end
      if upstreams_read then
        if options['write_servers'] then
          upstreams_write = upstream_list.create(rspamd_config,
            options['write_servers'], default_port)
        else
          upstreams_write = upstreams_read
        end
      end

      -- A non-default timeout from an earlier source is kept; an unparsable
      -- value falls back to the default instead of leaving timeout unset.
      if not result['timeout'] or result['timeout'] == default_timeout then
        local timeout = options['timeout'] and tonumber(options['timeout'])
        result['timeout'] = timeout or default_timeout
      end
      if options['prefix'] and not result['prefix'] then
        result['prefix'] = options['prefix']
      end
      -- An explicit boolean from an earlier (more specific) source wins, like
      -- the other settings, so the global section cannot reset it.
      if type(options['expand_keys']) == 'boolean' and result['expand_keys'] == nil then
        result['expand_keys'] = options['expand_keys']
      end
      if not result['db'] then
        if options['db'] then
          result['db'] = tostring(options['db'])
        elseif options['dbname'] then
          result['db'] = tostring(options['dbname'])
        end
      end
      if options['password'] and not result['password'] then
        result['password'] = options['password']
      end

      if upstreams_write and upstreams_read then
        result.read_servers = upstreams_read
        result.write_servers = upstreams_write
        if result['expand_keys'] == nil then
          result['expand_keys'] = default_expand_keys
        end
        return true
      end
      return false
    end

    -- Try local options
    local opts = module_opts or rspamd_config:get_all_opt(module_name)
    if opts then
      if opts.redis and try_load_redis_servers(opts.redis) then
        return result
      end
      if try_load_redis_servers(opts) then
        return result
      end
    end

    if no_fallback then return nil end

    -- Try global options
    opts = rspamd_config:get_all_opt('redis')
    if opts then
      if opts[module_name] then
        if try_load_redis_servers(opts[module_name]) then
          return result
        end
      else
        local ret = try_load_redis_servers(opts)
        -- Exclude disabled; a single module may be configured as a plain string
        local disabled = opts['disabled_modules']
        if type(disabled) == 'string' then
          disabled = {disabled}
        end
        if type(disabled) == 'table' then
          for _,v in ipairs(disabled) do
            if v == module_name then
              logger.infox(rspamd_config, "NOT using default redis server for module %s: it is disabled",
                module_name)
              return nil
            end
          end
        end
        if ret then
          logger.infox(rspamd_config, "using default redis server for module %s",
            module_name)
        end
      end
    end

    if result.read_servers then
      return result
    end
    return nil
  end
  exports.rspamd_parse_redis_server = rspamd_parse_redis_server
  exports.parse_redis_server = rspamd_parse_redis_server
  -- Key-position extractors used when `expand_keys` is enabled. Each handler
  -- receives the command arguments (without the command name) and returns
  -- the list of indexes in that table which hold Redis keys.
  local process_cmd = {
    -- BITOP operation destkey key [key ...]: everything after the operation
    bitop = function(args)
      local idx_l = {}
      for i = 2, #args do
        table.insert(idx_l, i)
      end
      return idx_l
    end,
    -- BLPOP key [key ...] timeout: everything except the trailing timeout
    blpop = function(args)
      local idx_l = {}
      for i = 1, #args - 1 do
        table.insert(idx_l, i)
      end
      return idx_l
    end,
    -- EVAL script numkeys key [key ...] arg [arg ...]. numkeys is usually a
    -- string (exec_redis_script passes tostring(#args)), so convert it, and
    -- never report positions beyond the arguments actually present.
    eval = function(args)
      local idx_l = {}
      local numkeys = tonumber(args[2]) or 0
      for i = 3, math.min(numkeys + 2, #args) do
        table.insert(idx_l, i)
      end
      return idx_l
    end,
    -- Commands whose only key is the first argument
    set = function(_)
      return {1}
    end,
    -- Commands where every argument is a key
    mget = function(args)
      local idx_l = {}
      for i = 1, #args do
        table.insert(idx_l, i)
      end
      return idx_l
    end,
    -- MSET key value [key value ...]: keys sit at odd positions
    mset = function(args)
      local idx_l = {}
      for i = 1, #args, 2 do
        table.insert(idx_l, i)
      end
      return idx_l
    end,
    -- SMOVE source destination member
    smove = function(_)
      return {1, 2}
    end,
    -- ZINTERSTORE/ZUNIONSTORE destination numkeys key [key ...]:
    -- the destination is a key as well as the numkeys source keys
    zinterstore = function(args)
      local idx_l = {1}
      local numkeys = tonumber(args[2]) or 0
      for i = 3, math.min(numkeys + 2, #args) do
        table.insert(idx_l, i)
      end
      return idx_l
    end,
    -- Commands without key arguments: an empty list (never nil) so that
    -- callers can iterate the result directly
    script = function(_)
      return {}
    end,
  }

  -- Registers `handler` for every command name in `names`
  local function alias_cmds(handler, names)
    for _, name in ipairs(names) do
      process_cmd[name] = handler
    end
  end

  alias_cmds(process_cmd.script, {
    'auth', 'bgrewriteaof', 'bgsave', 'client', 'cluster', 'command', 'config',
    'dbsize', 'debug', 'discard', 'echo', 'exec', 'flushall', 'flushdb', 'info',
    'keys', 'lastsave', 'migrate', 'monitor', 'multi', 'object', 'ping',
    'psubscribe', 'pubsub', 'publish', 'punsubscribe', 'quit', 'randomkey',
    'readonly', 'readwrite', 'role', 'save', 'scan', 'select', 'slaveof',
    'slowlog', 'subscribe', 'swapdb', 'sync', 'time', 'unsubscribe', 'unwatch',
    'wait',
  })

  alias_cmds(process_cmd.set, {
    'append', 'bitcount', 'bitfield', 'bitpos', 'decr', 'decrby', 'dump',
    'expire', 'expireat', 'geoadd', 'geodist', 'geohash', 'geopos', 'georadius',
    'georadiusbymember', 'get', 'getbit', 'getrange', 'getset', 'hdel',
    'hexists', 'hget', 'hgetall', 'hincrby', 'hincrbyfloat', 'hkeys', 'hlen',
    'hmget', 'hmset', 'hscan', 'hset', 'hsetnx', 'hstrlen', 'hvals', 'incr',
    'incrby', 'incrbyfloat', 'lindex', 'linsert', 'llen', 'lpop', 'lpush',
    'lpushx', 'lrange', 'lrem', 'lset', 'ltrim', 'move', 'persist', 'pexpire',
    'pexpireat', 'pfadd', 'psetex', 'pttl', 'restore', 'rpop', 'rpush',
    'rpushx', 'sadd', 'scard', 'setbit', 'setex', 'setnx', 'sismember',
    'smembers', 'sort', 'spop', 'srandmember', 'srem', 'sscan', 'strlen', 'ttl',
    'type', 'zadd', 'zcard', 'zcount', 'zincrby', 'zlexcount', 'zrange',
    'zrangebylex', 'zrangebyscore', 'zrank', 'zrem', 'zremrangebylex',
    'zremrangebyrank', 'zremrangebyscore', 'zrevrange', 'zrevrangebylex',
    'zrevrangebyscore', 'zrevrank', 'zscan', 'zscore',
    -- legacy (misspelled) names kept for backward compatibility
    'zrembylex', 'zrembyrank', 'zrembyscore',
  })

  alias_cmds(process_cmd.mget, {
    'del', 'exists', 'pfcount', 'pfmerge', 'rename', 'renamenx', 'rpoplpush',
    'sdiff', 'sdiffstore', 'sinter', 'sinterstore', 'sunion', 'sunionstore',
    'touch', 'unlink', 'watch',
  })

  alias_cmds(process_cmd.blpop, {'brpop', 'brpoplpush'})
  alias_cmds(process_cmd.eval, {'evalsha'})
  alias_cmds(process_cmd.mset, {'msetnx'})
  alias_cmds(process_cmd.zinterstore, {'zunionstore'})
  -- Returns the positions in `args` that hold Redis keys for `cmd`
  -- (matched case-insensitively against the process_cmd table).
  -- Always returns a table, possibly empty, so the result can be iterated
  -- directly; an unknown command is logged and yields no positions.
  local function get_key_indexes(cmd, args)
    local idx_l
    cmd = string.lower(cmd)
    local handler = process_cmd[cmd]
    if handler then
      -- handlers take the length of args, so never hand them nil
      idx_l = handler(args or {})
    else
      logger.warnx(rspamd_config, "Don't know how to extract keys for %s Redis command", cmd)
    end
    return idx_l or {}
  end
  -- Returns `field` ('addr' or 'domain') of the first address reported by
  -- task:get_from(source), or nil when there is no such address.
  local function first_from_field(task, source, field)
    local addrs = task:get_from(source) or E
    local first = addrs[1] or E
    return first[field]
  end

  -- Generators of the variables available for key expansion. Each one
  -- takes a task and returns the value, or nil when it cannot be determined.
  local gen_meta = {}

  function gen_meta.principal_recipient(task)
    return task:get_principal_recipient()
  end

  function gen_meta.principal_recipient_domain(task)
    local rcpt = task:get_principal_recipient()
    if rcpt then
      return string.match(rcpt, '.*@(.*)')
    end
  end

  function gen_meta.ip(task)
    local addr = task:get_ip()
    if not addr or not addr:is_valid() then
      return
    end
    return addr:to_string()
  end

  function gen_meta.from(task)
    return first_from_field(task, 'smtp', 'addr')
  end

  function gen_meta.from_domain(task)
    return first_from_field(task, 'smtp', 'domain')
  end

  -- SMTP sender domain, or the HELO name when the domain is missing or empty
  function gen_meta.from_domain_or_helo_domain(task)
    local domain = first_from_field(task, 'smtp', 'domain')
    if not domain or #domain == 0 then
      return task:get_helo()
    end
    return domain
  end

  function gen_meta.mime_from(task)
    return first_from_field(task, 'mime', 'addr')
  end

  function gen_meta.mime_from_domain(task)
    return first_from_field(task, 'mime', 'domain')
  end

  -- Wraps a domain generator so that its non-nil result is passed through
  -- rspamd_util.get_tld (presumably reducing it to the eSLD, as the esld_*
  -- names suggest).
  local function gen_get_esld(f)
    return function(task)
      local domain = f(task)
      if domain then
        return rspamd_util.get_tld(domain)
      end
    end
  end

  -- smtp_* names are aliases of the plain (SMTP-based) generators
  gen_meta.smtp_from = gen_meta.from
  gen_meta.smtp_from_domain = gen_meta.from_domain
  gen_meta.smtp_from_domain_or_helo_domain = gen_meta.from_domain_or_helo_domain

  gen_meta.esld_principal_recipient_domain = gen_get_esld(gen_meta.principal_recipient_domain)
  gen_meta.esld_from_domain = gen_get_esld(gen_meta.from_domain)
  gen_meta.esld_smtp_from_domain = gen_meta.esld_from_domain
  gen_meta.esld_mime_from_domain = gen_get_esld(gen_meta.mime_from_domain)
  gen_meta.esld_from_domain_or_helo_domain = gen_get_esld(gen_meta.from_domain_or_helo_domain)
  gen_meta.esld_smtp_from_domain_or_helo_domain = gen_meta.esld_from_domain_or_helo_domain
  -- Builds a lazily-filled table of key-template variables for `task`.
  -- A lookup lower-cases the requested name; on a miss the matching gen_meta
  -- generator is run once and its result is stored in the table under the
  -- lower-cased name. Names without a generator yield the raw stored value
  -- (normally nil).
  local function get_key_expansion_metadata(task)
    local function lazy_index(tbl, name)
      local key = string.lower(name)
      local cached = rawget(tbl, key)
      if cached then
        return cached
      end
      local generator = gen_meta[key]
      if not generator then
        return cached
      end
      local value = generator(task)
      rawset(tbl, key, value)
      return value
    end

    return setmetatable({}, { __index = lazy_index })
  end
  -- Performs an async call to redis, hiding all the complexity inside.
  -- task - rspamd_task
  -- redis_params - valid params returned by rspamd_parse_redis_server
  -- key - key to select upstream or nil to select round-robin/master-slave
  -- is_write - true if need to write to redis server
  -- callback - function(err, data, addr) called when the request completes;
  --            the upstream is marked failed/ok according to `err` first
  -- command - redis command
  -- args - table of arguments (may be nil); when redis_params.expand_keys is
  --        set, the key arguments are template-expanded in place
  -- Returns ret, conn, addr; or false, nil, nil when the arguments are
  -- invalid or no upstream could be selected.
  local function rspamd_redis_make_request(task, redis_params, key, is_write, callback, command, args)
    if not task or not redis_params or not callback or not command then
      return false,nil,nil
    end

    local addr
    local function rspamd_redis_make_request_cb(err, data)
      if err then
        addr:fail()
      else
        addr:ok()
      end
      callback(err, data, addr)
    end

    local rspamd_redis = require "rspamd_redis"
    local servers
    if is_write then
      servers = redis_params['write_servers']
    else
      servers = redis_params['read_servers']
    end
    if servers then
      if key then
        addr = servers:get_upstream_by_hash(key)
      elseif is_write then
        addr = servers:get_upstream_master_slave(key)
      else
        addr = servers:get_upstream_round_robin(key)
      end
    end

    if not addr then
      -- Nothing to send the request to: report failure instead of
      -- dereferencing a nil upstream below
      logger.errx(task, 'cannot select server to make redis request')
      return false,nil,nil
    end

    if redis_params['expand_keys'] and args then
      local m = get_key_expansion_metadata(task)
      local indexes = get_key_indexes(command, args) or {}
      for _, i in ipairs(indexes) do
        args[i] = lutil.template(args[i], m)
      end
    end

    local ip_addr = addr:get_addr()
    local options = {
      task = task,
      callback = rspamd_redis_make_request_cb,
      host = ip_addr,
      timeout = redis_params['timeout'],
      cmd = command,
      args = args
    }
    if redis_params['password'] then
      options['password'] = redis_params['password']
    end
    if redis_params['db'] then
      options['dbname'] = redis_params['db']
    end

    local ret,conn = rspamd_redis.make_request(options)
    if not ret then
      addr:fail()
      logger.warnx(task, "cannot make redis request to: %s", tostring(ip_addr))
    end
    return ret,conn,addr
  end
  exports.rspamd_redis_make_request = rspamd_redis_make_request
  exports.redis_make_request = rspamd_redis_make_request
  -- Performs an async call to redis without a task (e.g. from on_load hooks).
  -- ev_base - event base to run the request on
  -- cfg - rspamd config object (also used for logging)
  -- redis_params - valid params returned by rspamd_parse_redis_server
  -- key - key to select upstream or nil to select round-robin/master-slave
  -- is_write - true if need to write to redis server
  -- callback - function(err, data, addr) called when the request completes;
  --            the upstream is marked failed/ok according to `err` first
  -- command - redis command
  -- args - table of arguments (not key-expanded: there is no task)
  -- Returns ret, conn, addr; or false, nil, nil when the arguments are
  -- invalid or no upstream could be selected.
  local function redis_make_request_taskless(ev_base, cfg, redis_params, key, is_write, callback, command, args)
    if not ev_base or not redis_params or not callback or not command then
      return false,nil,nil
    end

    local addr
    local function rspamd_redis_make_request_cb(err, data)
      if err then
        addr:fail()
      else
        addr:ok()
      end
      callback(err, data, addr)
    end

    local rspamd_redis = require "rspamd_redis"
    local servers
    if is_write then
      servers = redis_params['write_servers']
    else
      servers = redis_params['read_servers']
    end
    if servers then
      if key then
        addr = servers:get_upstream_by_hash(key)
      elseif is_write then
        addr = servers:get_upstream_master_slave(key)
      else
        addr = servers:get_upstream_round_robin(key)
      end
    end

    if not addr then
      -- Nothing to send the request to: report failure instead of
      -- dereferencing a nil upstream below
      logger.errx(cfg, 'cannot select server to make redis request')
      return false,nil,nil
    end

    local options = {
      ev_base = ev_base,
      config = cfg,
      callback = rspamd_redis_make_request_cb,
      host = addr:get_addr(),
      timeout = redis_params['timeout'],
      cmd = command,
      args = args
    }
    if redis_params['password'] then
      options['password'] = redis_params['password']
    end
    if redis_params['db'] then
      options['dbname'] = redis_params['db']
    end

    local ret,conn = rspamd_redis.make_request(options)
    if not ret then
      logger.errx('cannot execute redis request')
      addr:fail()
    end
    return ret,conn,addr
  end
  exports.rspamd_redis_make_request_taskless = redis_make_request_taskless
  exports.redis_make_request_taskless = redis_make_request_taskless
  -- Registry of scripts registered via add_redis_script; the array index is
  -- the script id.
  local redis_scripts = {}

  -- Completes a (re)load round for `script`. It is flagged as loaded when a
  -- SHA has been obtained. Then every callback waiting in script.waitq is
  -- invoked with the current `loaded` flag. Callbacks queued while the queue
  -- is being drained land in a fresh queue and are not run in this round.
  local function script_set_loaded(script)
    if script.sha then
      script.loaded = true
    end

    local pending = script.waitq
    script.waitq = {}
    for _, waiter in ipairs(pending) do
      waiter(script.loaded)
    end
  end
  -- Builds one 'SCRIPT LOAD' request options table per upstream the script
  -- may run on: all read servers, plus the write servers when they are a
  -- different upstream list. Sets script.in_flight to the number of requests.
  -- Each options table keeps its upstream in `upstream`; callers add the
  -- task/config, event base and callback before sending.
  local function prepare_redis_call(script)
    local params = script.redis_params
    local servers = {}

    -- all_upstreams() returns an array, so append instead of merging by key:
    -- a key-wise merge lets write servers overwrite read servers, and the
    -- script would then never be loaded on some read servers.
    local function add_upstreams(list)
      if not list then return end
      for _, s in ipairs(list:all_upstreams()) do
        servers[#servers + 1] = s
      end
    end

    add_upstreams(params.read_servers)
    -- the parser reuses the read list when no write servers are configured
    if params.write_servers ~= params.read_servers then
      add_upstreams(params.write_servers)
    end

    local options = {}
    for _, s in ipairs(servers) do
      local cur_opts = {
        host = s:get_addr(),
        timeout = params['timeout'],
        cmd = 'SCRIPT',
        args = {'LOAD', script.script },
        upstream = s
      }
      if params['password'] then
        cur_opts['password'] = params['password']
      end
      if params['db'] then
        cur_opts['dbname'] = params['db']
      end
      table.insert(options, cur_opts)
    end

    -- Loaders decrement this as each request completes or fails
    script.in_flight = #options
    return options
  end
  -- Sends 'SCRIPT LOAD' for `script` to each relevant upstream in the context
  -- of `task`. A successful reply stores the returned SHA in script.sha.
  -- Every reply, and every request that fails to start, decrements
  -- script.in_flight. When it reaches zero, script_set_loaded() releases
  -- the callbacks waiting for the script.
  local function load_script_task(script, task)
    local rspamd_redis = require "rspamd_redis"
    local opts = prepare_redis_call(script)

    for _,opt in ipairs(opts) do
      opt.task = task
      opt.callback = function(err, data)
        if err then
          opt.upstream:fail()
        else
          opt.upstream:ok()
          logger.infox(task,
            "loaded redis script with id %s, sha: %s", script.id, data)
          script.sha = data -- We assume that sha is the same on all servers
        end
        script.in_flight = script.in_flight - 1
        if script.in_flight == 0 then
          script_set_loaded(script)
        end
      end

      local ret = rspamd_redis.make_request(opt)
      if not ret then
        logger.errx('cannot execute redis request to load script')
        script.in_flight = script.in_flight - 1
        opt.upstream:fail()
      end
    end

    -- Checked once after all requests were issued. This also releases the
    -- wait queue when there are no upstreams at all or none could be started.
    if script.in_flight == 0 then
      script_set_loaded(script)
    end
  end
  -- Sends 'SCRIPT LOAD' for `script` to each relevant upstream without a
  -- task, using `cfg` and `ev_base`. A successful reply stores the returned
  -- SHA in script.sha. Every reply, and every request that fails to start,
  -- decrements script.in_flight. When it reaches zero, script_set_loaded()
  -- releases the callbacks waiting for the script.
  local function load_script_taskless(script, cfg, ev_base)
    local rspamd_redis = require "rspamd_redis"
    local opts = prepare_redis_call(script)

    for _,opt in ipairs(opts) do
      opt.config = cfg
      opt.ev_base = ev_base
      opt.callback = function(err, data)
        if err then
          opt.upstream:fail()
        else
          opt.upstream:ok()
          logger.infox(cfg,
            "loaded redis script with id %s, sha: %s", script.id, data)
          script.sha = data -- We assume that sha is the same on all servers
        end
        script.in_flight = script.in_flight - 1
        if script.in_flight == 0 then
          script_set_loaded(script)
        end
      end

      local ret = rspamd_redis.make_request(opt)
      if not ret then
        logger.errx('cannot execute redis request to load script')
        script.in_flight = script.in_flight - 1
        opt.upstream:fail()
      end
    end

    -- Checked once after all requests were issued. This also releases the
    -- wait queue when there are no upstreams at all or none could be started.
    if script.in_flight == 0 then
      script_set_loaded(script)
    end
  end
  -- on_load hook body: loads `script` on its upstreams without a task.
  -- The worker argument is unused.
  local function load_redis_script(script, cfg, ev_base, _)
    load_script_taskless(script, cfg, ev_base)
  end

  -- Registers a Lua script to be loaded into Redis when rspamd loads its
  -- configuration (via rspamd_config:add_on_load).
  -- @param script Lua source of the script
  -- @param redis_params params returned by rspamd_parse_redis_server
  -- @return numeric id to pass to exec_redis_script
  local function add_redis_script(script, redis_params)
    local new_script = {
      loaded = false,
      redis_params = redis_params,
      script = script,
      waitq = {}, -- callbacks pending for script being loaded
      -- pending SCRIPT LOAD requests; starting at 0 (not nil) lets the
      -- "no reload in progress" checks (in_flight == 0) work from the start
      in_flight = 0,
      id = #redis_scripts + 1
    }
    -- Register on load function
    rspamd_config:add_on_load(function(cfg, ev_base, worker)
      load_redis_script(new_script, cfg, ev_base, worker)
    end)
    table.insert(redis_scripts, new_script)
    return #redis_scripts
  end
  exports.add_redis_script = add_redis_script
  -- Executes a script registered with add_redis_script using EVALSHA.
  -- @param id script id returned by add_redis_script
  -- @param params table with either `task` or `ev_base`, plus optional
  --   `key` (upstream selection) and `is_write`
  -- @param callback function(err, data) receiving the result
  -- @param args script arguments; all of them are passed as KEYS
  --   (numkeys = #args). The caller's table is not modified.
  -- @return false when no script has this id, true otherwise
  local function exec_redis_script(id, params, callback, args)
    local script = redis_scripts[id]
    if not script then
      logger.errx("cannot find registered script with id %s", id)
      return false
    end

    -- EVALSHA <sha> <numkeys> <keys...>. Built once as a private copy; slot 1
    -- is refreshed with the current SHA on every attempt so a retry after a
    -- reload never sends a stale value.
    args = args or {}
    local call_args = {false, tostring(#args)}
    for i = 1, #args do
      call_args[i + 2] = args[i]
    end

    local do_call

    -- Queues a call to run once the script has finished (re)loading. If the
    -- load failed, the caller gets a NOSCRIPT error instead of a request
    -- carrying no SHA.
    local function call_when_loaded(can_reload)
      table.insert(script.waitq, function(loaded)
        if loaded then
          do_call(can_reload)
        else
          callback('NOSCRIPT', nil)
        end
      end)
    end

    -- Starts reloading the script unless a reload is already in progress
    local function reload_script()
      if (script.in_flight or 0) == 0 then
        if params.task then
          load_script_task(script, params.task)
        else
          load_script_taskless(script, rspamd_config, params.ev_base)
        end
      end
    end

    do_call = function(can_reload)
      local function redis_cb(err, data)
        if err and string.match(err, 'NOSCRIPT') then
          -- The server lost the script: mark it unloaded so new calls wait
          -- for the reload instead of being sent with a missing SHA
          script.sha = nil
          script.loaded = false
          if can_reload then
            -- retry exactly once, after the reload, to avoid an endless loop
            call_when_loaded(false)
            reload_script()
          else
            reload_script()
            callback(err, data)
          end
        else
          callback(err, data)
        end
      end

      if not script.sha then
        callback('NOSCRIPT', nil)
        return
      end
      call_args[1] = script.sha

      local ok
      if params.task then
        ok = rspamd_redis_make_request(params.task, script.redis_params,
          params.key, params.is_write, redis_cb, 'EVALSHA', call_args)
      else
        ok = redis_make_request_taskless(params.ev_base, rspamd_config,
          script.redis_params,
          params.key, params.is_write, redis_cb, 'EVALSHA', call_args)
      end
      if not ok then
        callback('Cannot make redis request', nil)
      end
    end

    if script.loaded then
      do_call(true)
    elseif not params.task then
      -- Delayed until scripts are loaded
      call_when_loaded(true)
    else
      -- TODO: fix taskfull requests
      callback('NOSCRIPT', nil)
    end

    return true
  end
  exports.exec_redis_script = exec_redis_script
  -- Opens a synchronous connection to a redis upstream.
  -- @param redis_params params returned by rspamd_parse_redis_server
  -- @param is_write true to pick from write servers, otherwise read servers
  -- @param key key to select upstream by hash, or nil for master-slave
  --   (write) / round-robin (read) selection
  -- @param cfg config object used for logging
  -- @return ret, conn, addr: the result of rspamd_redis.connect_sync plus the
  --   selected upstream. On success, AUTH/SELECT are queued on conn as
  --   configured. Returns false, nil when there are no params or no upstream.
  local function redis_connect_sync(redis_params, is_write, key, cfg)
    if not redis_params then
      return false,nil
    end

    local rspamd_redis = require "rspamd_redis"
    local servers
    if is_write then
      servers = redis_params['write_servers']
    else
      servers = redis_params['read_servers']
    end

    local addr
    if servers then
      if key then
        addr = servers:get_upstream_by_hash(key)
      elseif is_write then
        addr = servers:get_upstream_master_slave(key)
      else
        addr = servers:get_upstream_round_robin(key)
      end
    end

    if not addr then
      -- Nothing to connect to: report failure instead of dereferencing nil
      logger.errx(cfg, 'cannot select server to make redis request')
      return false,nil
    end

    local options = {
      host = addr:get_addr(),
      timeout = redis_params['timeout'],
    }

    local ret,conn = rspamd_redis.connect_sync(options)
    if not ret then
      -- conn is presumably an error description here (it is logged as one),
      -- so it must not be used as a connection
      logger.errx('cannot execute redis request: %s', conn)
      addr:fail()
    elseif conn then
      if redis_params['password'] then
        conn:add_cmd('AUTH', {redis_params['password']})
      end
      if redis_params['db'] then
        conn:add_cmd('SELECT', {tostring(redis_params['db'])})
      elseif redis_params['dbname'] then
        conn:add_cmd('SELECT', {tostring(redis_params['dbname'])})
      end
    end

    return ret,conn,addr
  end
  exports.redis_connect_sync = redis_connect_sync
  return exports