Você não pode selecionar mais de 25 tópicos Os tópicos devem começar com uma letra ou um número, podem incluir traços ('-') e podem ter até 35 caracteres.

lua_redis.lua 24KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908
  1. --[[
  2. Copyright (c) 2017, Vsevolod Stakhov <vsevolod@highsecure.ru>
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ]]--
  13. local logger = require "rspamd_logger"
  14. local lutil = require "lua_util"
  15. local rspamd_util = require "rspamd_util"
  -- `exports` collects the public API returned at the end of the module;
  -- `E` is a shared empty table used as a nil-safe fallback when indexing
  -- optional values, e.g. `((maybe_list or E)[1] or E).field`.
  local exports, E = {}, {}
  --- Parse the redis server definition for a module.
  --
  -- Sources are tried in order; the first one that yields servers wins:
  --   1. `opts.redis`, then `opts` itself, where `opts` is `module_opts`
  --      or, when that is nil, `rspamd_config:get_all_opt(module_name)`;
  --   2. unless `no_fallback` is set, the global `redis` section:
  --      `redis[module_name]` when present, otherwise the section itself
  --      (nil is returned when `module_name` is listed in
  --      `redis.disabled_modules`).
  --
  -- Settings found in earlier, unsuccessful sources are kept (prefix, db,
  -- password; a timeout only while it still equals the default), while
  -- `expand_keys` is taken from the last source tried.
  --
  -- @param module_name string module name (also its config section name)
  -- @param module_opts table|nil explicit module options
  -- @param no_fallback boolean|nil never consult the global `redis` section
  -- @return table|nil parsed parameters: `read_servers`, `write_servers`,
  --   `timeout`, `expand_keys` and optionally `prefix`, `db`, `password`
  local function rspamd_parse_redis_server(module_name, module_opts, no_fallback)
    local result = {}
    local default_port = 6379
    local default_timeout = 1.0
    local default_expand_keys = false
    local upstream_list = require "rspamd_upstream_list"

    -- Merge one option table into `result`. Returns true when it provided
    -- read and write upstreams (write servers default to the read servers).
    local function try_load_redis_servers(options)
      local upstreams_read, upstreams_write

      -- Read servers: `read_servers`, else `servers`, else `server`
      if options['read_servers'] then
        upstreams_read = upstream_list.create(rspamd_config,
            options['read_servers'], default_port)
      elseif options['servers'] then
        upstreams_read = upstream_list.create(rspamd_config,
            options['servers'], default_port)
      elseif options['server'] then
        upstreams_read = upstream_list.create(rspamd_config,
            options['server'], default_port)
      end

      if upstreams_read then
        if options['write_servers'] then
          upstreams_write = upstream_list.create(rspamd_config,
              options['write_servers'], default_port)
        else
          upstreams_write = upstreams_read
        end
      end

      -- Store options
      if not result['timeout'] or result['timeout'] == default_timeout then
        -- tonumber() yields nil for missing or malformed values; fall back
        -- to the default instead of leaving the timeout unset.
        result['timeout'] = tonumber(options['timeout']) or default_timeout
      end

      if options['prefix'] and not result['prefix'] then
        result['prefix'] = options['prefix']
      end

      if type(options['expand_keys']) == 'boolean' then
        result['expand_keys'] = options['expand_keys']
      else
        result['expand_keys'] = default_expand_keys
      end

      if not result['db'] then
        if options['db'] then
          result['db'] = tostring(options['db'])
        elseif options['dbname'] then
          result['db'] = tostring(options['dbname'])
        end
      end

      if options['password'] and not result['password'] then
        result['password'] = options['password']
      end

      if upstreams_write and upstreams_read then
        result.read_servers = upstreams_read
        result.write_servers = upstreams_write
        return true
      end

      return false
    end

    -- Try module-local options
    local opts
    if not module_opts then
      opts = rspamd_config:get_all_opt(module_name)
    else
      opts = module_opts
    end

    if opts then
      if opts.redis and try_load_redis_servers(opts.redis) then
        return result
      end
      if try_load_redis_servers(opts) then
        return result
      end
    end

    if no_fallback then return nil end

    -- Try global options
    opts = rspamd_config:get_all_opt('redis')
    if opts then
      if opts[module_name] then
        if try_load_redis_servers(opts[module_name]) then
          return result
        end
      else
        local ret = try_load_redis_servers(opts)
        -- Exclude disabled
        if opts['disabled_modules'] then
          for _,v in ipairs(opts['disabled_modules']) do
            if v == module_name then
              logger.infox(rspamd_config, "NOT using default redis server for module %s: it is disabled",
                  module_name)
              return nil
            end
          end
        end
        if ret then
          logger.infox(rspamd_config, "using default redis server for module %s",
              module_name)
        end
      end
    end

    if result.read_servers then
      return result
    else
      return nil
    end
  end
  -- Public API: the server-definition parser, exported under both names
  exports['rspamd_parse_redis_server'] = rspamd_parse_redis_server
  exports['parse_redis_server'] = rspamd_parse_redis_server
  --- Key-position extractors used for key expansion.
  --
  -- Maps a lower-cased Redis command name to a function that takes the
  -- command's argument list (without the command name) and returns the list
  -- of indexes in it that hold key names. Commands without key arguments
  -- return an empty list, never nil.
  local process_cmd = {
    -- BITOP operation destkey key [key ...]: everything after the operation
    bitop = function(args)
      local idx_l = {}
      for i = 2, #args do
        idx_l[#idx_l + 1] = i
      end
      return idx_l
    end,
    -- BLPOP key [key ...] timeout: everything but the trailing timeout
    blpop = function(args)
      local idx_l = {}
      for i = 1, #args - 1 do
        idx_l[#idx_l + 1] = i
      end
      return idx_l
    end,
    -- EVAL script numkeys key [key ...] arg [arg ...]
    eval = function(args)
      local idx_l = {}
      -- numkeys is often passed as a string; comparing a string with a
      -- number raises an error in Lua, so convert it first.
      local numkeys = tonumber(args[2])
      if numkeys and numkeys >= 1 then
        -- never point past the supplied arguments
        for i = 3, math.min(numkeys + 2, #args) do
          idx_l[#idx_l + 1] = i
        end
      end
      return idx_l
    end,
    -- One key in the first position: GET key, HSET key field value, ...
    set = function(_)
      return {1}
    end,
    -- Every argument is a key: MGET key [key ...], DEL key [key ...]
    mget = function(args)
      local idx_l = {}
      for i = 1, #args do
        idx_l[#idx_l + 1] = i
      end
      return idx_l
    end,
    -- MSET key value [key value ...]: every other argument
    mset = function(args)
      local idx_l = {}
      for i = 1, #args, 2 do
        idx_l[#idx_l + 1] = i
      end
      return idx_l
    end,
    -- SMOVE source destination member
    smove = function(_)
      return {1, 2}
    end,
    -- ZINTERSTORE destination numkeys key [key ...] [WEIGHTS ...] [AGGREGATE ...]:
    -- the destination and the numkeys source keys
    zinterstore = function(args)
      local idx_l = {1}
      local numkeys = tonumber(args[2])
      if numkeys and numkeys >= 1 then
        for i = 3, math.min(numkeys + 2, #args) do
          idx_l[#idx_l + 1] = i
        end
      end
      return idx_l
    end,
    -- Commands without key arguments. An empty list (rather than nil) lets
    -- callers iterate the result directly.
    script = function(_)
      return {}
    end,
  }

  -- Commands that share one of the extractors above, grouped by extractor.
  do
    local aliases = {
      script = {
        'auth', 'bgrewriteaof', 'bgsave', 'client', 'cluster', 'command',
        'config', 'dbsize', 'debug', 'discard', 'echo', 'exec', 'flushall',
        'flushdb', 'info', 'keys', 'lastsave', 'migrate', 'monitor', 'multi',
        'object', 'ping', 'psubscribe', 'pubsub', 'publish', 'punsubscribe',
        'quit', 'randomkey', 'readonly', 'readwrite', 'role', 'save', 'scan',
        'select', 'slaveof', 'slowlog', 'subscribe', 'swapdb', 'sync', 'time',
        'unsubscribe', 'unwatch', 'wait',
      },
      set = {
        'append', 'bitcount', 'bitfield', 'bitpos', 'decr', 'decrby', 'dump',
        'expire', 'expireat', 'geoadd', 'geodist', 'geohash', 'geopos',
        'georadius', 'georadiusbymember', 'get', 'getbit', 'getrange',
        'getset', 'hdel', 'hexists', 'hget', 'hgetall', 'hincrby',
        'hincrbyfloat', 'hkeys', 'hlen', 'hmget', 'hmset', 'hscan', 'hset',
        'hsetnx', 'hstrlen', 'hvals', 'incr', 'incrby', 'incrbyfloat',
        'lindex', 'linsert', 'llen', 'lpop', 'lpush', 'lpushx', 'lrange',
        'lrem', 'lset', 'ltrim', 'move', 'persist', 'pexpire', 'pexpireat',
        'pfadd', 'psetex', 'pttl', 'restore', 'rpop', 'rpush', 'rpushx',
        'sadd', 'scard', 'setbit', 'setex', 'setnx', 'sismember',
        -- SMEMBERS key: the key must be expanded like any other
        'smembers',
        'sort', 'spop', 'srandmember', 'srem', 'sscan', 'strlen', 'ttl',
        'type', 'zadd', 'zcard', 'zcount', 'zincrby', 'zlexcount', 'zrange',
        'zrangebylex', 'zrangebyscore', 'zrank', 'zrem',
        -- real command names, plus the legacy zrem* spellings kept as-is
        'zremrangebylex', 'zremrangebyrank', 'zremrangebyscore',
        'zrembylex', 'zrembyrank', 'zrembyscore',
        'zrevrange', 'zrevrangebylex', 'zrevrangebyscore', 'zrevrank',
        'zscan', 'zscore',
      },
      mget = {
        'del', 'exists', 'pfcount', 'pfmerge', 'rename', 'renamenx',
        'rpoplpush', 'sdiff', 'sinter', 'sunion', 'touch', 'unlink', 'watch',
        -- *STORE destination key [key ...]: the destination is a key too
        'sdiffstore', 'sinterstore', 'sunionstore',
      },
      blpop = { 'brpop', 'brpoplpush' },
      eval = { 'evalsha' },
      mset = { 'msetnx' },
      zinterstore = { 'zunionstore' },
    }

    for extractor, names in pairs(aliases) do
      for _, name in ipairs(names) do
        process_cmd[name] = process_cmd[extractor]
      end
    end
  end
  --- Return the indexes of `args` that hold Redis key names for `cmd`.
  -- @param cmd string Redis command name (case-insensitive)
  -- @param args table command arguments, without the command itself
  -- @return table list of indexes into `args`; empty when the command takes
  --   no keys or is unknown (unknown commands are logged)
  local function get_key_indexes(cmd, args)
    cmd = string.lower(cmd)
    local handler = process_cmd[cmd]
    if not handler then
      logger.warnx(rspamd_config, "Don't know how to extract keys for %s Redis command", cmd)
      return {}
    end
    -- A handler may return nothing (keyless commands); always hand back a
    -- table so callers can iterate the result with ipairs().
    return handler(args) or {}
  end
  -- Key-expansion variables: each generator takes a task and returns the
  -- value, or nil when it is not available.

  -- First entry of task:get_from(source), or the shared empty table E.
  local function first_from(task, source)
    local list = task:get_from(source) or E
    return list[1] or E
  end

  local gen_meta = {}

  function gen_meta.principal_recipient(task)
    return task:get_principal_recipient()
  end

  function gen_meta.principal_recipient_domain(task)
    local rcpt = task:get_principal_recipient()
    if rcpt then
      -- everything after the last '@'
      return string.match(rcpt, '.*@(.*)')
    end
  end

  function gen_meta.ip(task)
    local addr = task:get_ip()
    if not (addr and addr:is_valid()) then
      return
    end
    return addr:to_string()
  end

  function gen_meta.from(task)
    return first_from(task, 'smtp')['addr']
  end

  function gen_meta.from_domain(task)
    return first_from(task, 'smtp')['domain']
  end

  -- SMTP sender domain when non-empty, otherwise the HELO value
  function gen_meta.from_domain_or_helo_domain(task)
    local domain = first_from(task, 'smtp')['domain']
    if not domain or #domain == 0 then
      return task:get_helo()
    end
    return domain
  end

  function gen_meta.mime_from(task)
    return first_from(task, 'mime')['addr']
  end

  function gen_meta.mime_from_domain(task)
    return first_from(task, 'mime')['domain']
  end

  -- Wrap a domain generator so that it yields rspamd_util.get_tld() of the
  -- domain, or nothing when the wrapped generator yields nil.
  local function gen_get_esld(f)
    return function(task)
      local domain = f(task)
      if domain then
        return rspamd_util.get_tld(domain)
      end
    end
  end

  -- `smtp_*` aliases
  gen_meta.smtp_from = gen_meta.from
  gen_meta.smtp_from_domain = gen_meta.from_domain
  gen_meta.smtp_from_domain_or_helo_domain = gen_meta.from_domain_or_helo_domain

  -- `esld_*` variants
  gen_meta.esld_principal_recipient_domain = gen_get_esld(gen_meta.principal_recipient_domain)
  gen_meta.esld_from_domain = gen_get_esld(gen_meta.from_domain)
  gen_meta.esld_smtp_from_domain = gen_meta.esld_from_domain
  gen_meta.esld_mime_from_domain = gen_get_esld(gen_meta.mime_from_domain)
  gen_meta.esld_from_domain_or_helo_domain = gen_get_esld(gen_meta.from_domain_or_helo_domain)
  gen_meta.esld_smtp_from_domain_or_helo_domain = gen_meta.esld_from_domain_or_helo_domain
  --- Build a lazily-populated metadata table for key expansion.
  -- Reading a name (case-insensitive) that has a generator in `gen_meta`
  -- calls that generator with `task` and stores the result in the table
  -- under the lower-cased name; names without a generator yield nil.
  -- @param task rspamd_task
  -- @return table
  local function get_key_expansion_metadata(task)
    return setmetatable({}, {
      __index = function(cache, name)
        local key = string.lower(name)
        local cached = rawget(cache, key)
        if cached then
          return cached
        end
        local generator = gen_meta[key]
        if not generator then
          return cached
        end
        local value = generator(task)
        rawset(cache, key, value)
        return value
      end,
    })
  end
  --- Perform an asynchronous redis request on behalf of a task, hiding
  -- upstream selection, authentication, database selection and key expansion.
  --
  -- @param task rspamd_task the request is made for
  -- @param redis_params table valid params returned by rspamd_parse_redis_server
  -- @param key string|nil key used to pick an upstream by hash; nil selects
  --   master/slave (writes) or round-robin (reads)
  -- @param is_write boolean true when the request needs a write server
  -- @param callback function(err, data, upstream) called upon completion
  -- @param command string redis command
  -- @param args table command arguments; when `redis_params.expand_keys` is
  --   set, key arguments are expanded in place
  -- @param extra_opts table|nil optional request options, copied over the
  --   generated ones
  -- @return ret, conn from rspamd_redis.make_request plus the selected
  --   upstream; false,nil,nil on invalid arguments or when no upstream
  --   could be selected
  local function rspamd_redis_make_request(task, redis_params, key, is_write,
      callback, command, args, extra_opts)
    if not task or not redis_params or not callback or not command then
      return false,nil,nil
    end

    local addr
    -- Report the outcome to the upstream, then to the caller
    local function rspamd_redis_make_request_cb(err, data)
      if err then
        addr:fail()
      else
        addr:ok()
      end
      callback(err, data, addr)
    end

    local rspamd_redis = require "rspamd_redis"

    if key then
      if is_write then
        addr = redis_params['write_servers']:get_upstream_by_hash(key)
      else
        addr = redis_params['read_servers']:get_upstream_by_hash(key)
      end
    else
      if is_write then
        addr = redis_params['write_servers']:get_upstream_master_slave(key)
      else
        addr = redis_params['read_servers']:get_upstream_round_robin(key)
      end
    end

    if not addr then
      logger.errx(task, 'cannot select server to make redis request')
      -- Nothing to send the request to: do not index a nil upstream below
      return false,nil,nil
    end

    if redis_params['expand_keys'] and args then
      local m = get_key_expansion_metadata(task)
      local indexes = get_key_indexes(command, args)
      for _, i in ipairs(indexes) do
        args[i] = lutil.template(args[i], m)
      end
    end

    local ip_addr = addr:get_addr()
    local options = {
      task = task,
      callback = rspamd_redis_make_request_cb,
      host = ip_addr,
      timeout = redis_params['timeout'],
      cmd = command,
      args = args
    }

    if extra_opts then
      for k,v in pairs(extra_opts) do
        options[k] = v
      end
    end

    if redis_params['password'] then
      options['password'] = redis_params['password']
    end

    if redis_params['db'] then
      options['dbname'] = redis_params['db']
    end

    local ret,conn = rspamd_redis.make_request(options)
    if not ret then
      addr:fail()
      logger.warnx(task, "cannot make redis request to: %s", tostring(ip_addr))
    end

    return ret,conn,addr
  end
  -- Public API: the task-bound request helper, exported under both names
  exports['rspamd_redis_make_request'] = rspamd_redis_make_request
  exports['redis_make_request'] = rspamd_redis_make_request
  --- Perform an asynchronous redis request without a task (e.g. from
  -- on_load hooks), hiding upstream selection, authentication and database
  -- selection. No key expansion is done here.
  --
  -- @param ev_base event base to run the request on
  -- @param cfg rspamd_config, passed to the request and used for logging
  -- @param redis_params table valid params returned by rspamd_parse_redis_server
  -- @param key string|nil key used to pick an upstream by hash; nil selects
  --   master/slave (writes) or round-robin (reads)
  -- @param is_write boolean true when the request needs a write server
  -- @param callback function(err, data, upstream) called upon completion
  -- @param command string redis command
  -- @param args table command arguments
  -- @param extra_opts table|nil optional request options, copied over the
  --   generated ones
  -- @return ret, conn from rspamd_redis.make_request plus the selected
  --   upstream; false,nil,nil on invalid arguments or when no upstream
  --   could be selected
  local function redis_make_request_taskless(ev_base, cfg, redis_params, key,
      is_write, callback, command, args, extra_opts)
    if not ev_base or not redis_params or not callback or not command then
      return false,nil,nil
    end

    local addr
    -- Report the outcome to the upstream, then to the caller
    local function rspamd_redis_make_request_cb(err, data)
      if err then
        addr:fail()
      else
        addr:ok()
      end
      callback(err, data, addr)
    end

    local rspamd_redis = require "rspamd_redis"

    if key then
      if is_write then
        addr = redis_params['write_servers']:get_upstream_by_hash(key)
      else
        addr = redis_params['read_servers']:get_upstream_by_hash(key)
      end
    else
      if is_write then
        addr = redis_params['write_servers']:get_upstream_master_slave(key)
      else
        addr = redis_params['read_servers']:get_upstream_round_robin(key)
      end
    end

    if not addr then
      logger.errx(cfg, 'cannot select server to make redis request')
      -- Nothing to send the request to: do not index a nil upstream below
      return false,nil,nil
    end

    local options = {
      ev_base = ev_base,
      config = cfg,
      callback = rspamd_redis_make_request_cb,
      host = addr:get_addr(),
      timeout = redis_params['timeout'],
      cmd = command,
      args = args
    }

    if extra_opts then
      for k,v in pairs(extra_opts) do
        options[k] = v
      end
    end

    if redis_params['password'] then
      options['password'] = redis_params['password']
    end

    if redis_params['db'] then
      options['dbname'] = redis_params['db']
    end

    local ret,conn = rspamd_redis.make_request(options)
    if not ret then
      logger.errx(cfg, 'cannot execute redis request')
      addr:fail()
    end

    return ret,conn,addr
  end
  -- Public API: the taskless request helper, exported under both names
  exports['rspamd_redis_make_request_taskless'] = redis_make_request_taskless
  exports['redis_make_request_taskless'] = redis_make_request_taskless
  -- Registry of scripts registered through add_redis_script; a script's id
  -- is its index in this list.
  local redis_scripts = {}

  --- Finish a (re)load attempt for `script`.
  -- Marks the script as loaded when a SHA is known, then runs every queued
  -- callback with the current `loaded` flag. The queue is swapped for a
  -- fresh one first, so callbacks that queue themselves again land in the
  -- new queue and are not run in this pass.
  local function script_set_loaded(script)
    if script.sha then
      script.loaded = true
    end

    local pending = script.waitq
    script.waitq = {}

    for _, resume in ipairs(pending) do
      resume(script.loaded)
    end
  end
  --- Build one `SCRIPT LOAD` request option table per redis server of
  -- `script.redis_params`, and set `script.in_flight` to the number of
  -- requests that will be issued.
  --
  -- Upstreams of the read and the write list are all included. The previous
  -- index-wise merge let write servers overwrite read servers at the same
  -- positions, so some servers never received the script. When both fields
  -- reference the same list object it is used once. Separate lists that
  -- share a host load the script on that host twice.
  --
  -- @param script table registered script (see add_redis_script)
  -- @return table list of request option tables; each carries its upstream
  --   in the `upstream` field
  local function prepare_redis_call(script)
    local params = script.redis_params
    local servers = {}

    local function append_upstreams(list)
      for _, up in ipairs(list:all_upstreams()) do
        servers[#servers + 1] = up
      end
    end

    if params.read_servers then
      append_upstreams(params.read_servers)
    end
    if params.write_servers and
        not rawequal(params.write_servers, params.read_servers) then
      append_upstreams(params.write_servers)
    end

    -- Call load script on each server; callers use in_flight to detect the
    -- completion of all requests
    script.in_flight = #servers

    local options = {}
    for _, s in ipairs(servers) do
      local cur_opts = {
        host = s:get_addr(),
        timeout = params['timeout'],
        cmd = 'SCRIPT',
        args = {'LOAD', script.script },
        upstream = s
      }

      if params['password'] then
        cur_opts['password'] = params['password']
      end

      if params['db'] then
        cur_opts['dbname'] = params['db']
      end

      options[#options + 1] = cur_opts
    end

    return options
  end
  --- (Re)load `script` on all of its redis servers in the context of `task`.
  -- A successful reply stores the returned SHA in `script.sha` (assumed to
  -- be the same on all servers). Once no request is left in flight,
  -- script_set_loaded() runs the queued callbacks. Requests that cannot be
  -- issued count as finished and mark their upstream as failed.
  -- @param script table registered script
  -- @param task rspamd_task
  local function load_script_task(script, task)
    local rspamd_redis = require "rspamd_redis"
    local requests = prepare_redis_call(script)

    for _, req in ipairs(requests) do
      req.task = task
      req.callback = function(err, data)
        if not err then
          req.upstream:ok()
          logger.infox(task,
            "loaded redis script with id %s, sha: %s", script.id, data)
          script.sha = data
        else
          req.upstream:fail()
        end

        script.in_flight = script.in_flight - 1
        if script.in_flight == 0 then
          script_set_loaded(script)
        end
      end

      if not rspamd_redis.make_request(req) then
        logger.errx('cannot execute redis request to load script')
        script.in_flight = script.in_flight - 1
        req.upstream:fail()
      end

      if script.in_flight == 0 then
        script_set_loaded(script)
      end
    end
  end
  --- (Re)load `script` on all of its redis servers without a task.
  -- Mirrors load_script_task: successful replies store the SHA in
  -- `script.sha`, and script_set_loaded() runs once nothing is in flight.
  -- @param script table registered script
  -- @param cfg rspamd_config passed to each request and used for logging
  -- @param ev_base event base the requests run on
  local function load_script_taskless(script, cfg, ev_base)
    local rspamd_redis = require "rspamd_redis"
    local requests = prepare_redis_call(script)

    local function one_request_done()
      script.in_flight = script.in_flight - 1
    end

    for _, req in ipairs(requests) do
      req.config = cfg
      req.ev_base = ev_base
      req.callback = function(err, data)
        if not err then
          req.upstream:ok()
          logger.infox(cfg,
            "loaded redis script with id %s, sha: %s", script.id, data)
          -- assumed to be the same sha on all servers
          script.sha = data
        else
          req.upstream:fail()
        end

        one_request_done()
        if script.in_flight == 0 then
          script_set_loaded(script)
        end
      end

      if not rspamd_redis.make_request(req) then
        logger.errx('cannot execute redis request to load script')
        one_request_done()
        req.upstream:fail()
      end

      if script.in_flight == 0 then
        script_set_loaded(script)
      end
    end
  end
  --- on_load hook body: load `script` without a task. The fourth argument
  -- (the worker) is part of the hook signature but unused.
  local function load_redis_script(script, cfg, ev_base, _)
    load_script_taskless(script, cfg, ev_base)
  end

  --- Register a Lua script to be loaded on the redis servers from
  -- `redis_params` once the configuration is loaded.
  -- @param script string Lua script body
  -- @param redis_params table as returned by rspamd_parse_redis_server
  -- @return number script id to pass to exec_redis_script
  local function add_redis_script(script, redis_params)
    local entry = {
      id = #redis_scripts + 1,
      script = script,
      redis_params = redis_params,
      loaded = false,
      waitq = {}, -- callbacks waiting for the script to be (re)loaded
    }

    -- Load the script when the configuration is loaded
    local function on_load(cfg, ev_base, worker)
      load_redis_script(entry, cfg, ev_base, worker)
    end
    rspamd_config:add_on_load(on_load)

    table.insert(redis_scripts, entry)
    return #redis_scripts
  end
  -- Public API: script registration
  exports['add_redis_script'] = add_redis_script
  --- Execute a script registered with add_redis_script via EVALSHA.
  -- Every element of `args` is passed as a key (numkeys = #args). On a
  -- NOSCRIPT reply the script is reloaded and the call retried.
  -- @param id number script id returned by add_redis_script
  -- @param params table request parameters: `task` or `ev_base`, plus
  --   optional `key` and `is_write` (see rspamd_redis_make_request)
  -- @param callback function(err, data)
  -- @param args table|nil script arguments; modified in place (the SHA and
  --   the key count are prepended)
  -- @return boolean false when no script with `id` is registered, else true
  local function exec_redis_script(id, params, callback, args)
    local args_modified = false

    if not redis_scripts[id] then
      logger.errx("cannot find registered script with id %s", id)
      return false
    end

    local script = redis_scripts[id]
    -- `#args` below requires a table
    args = args or {}

    local function do_call(can_reload)
      if not script.sha then
        -- No SHA: it was cleared by a NOSCRIPT reply, or the last load
        -- failed. EVALSHA would then get nil inserted into `args`; wait for
        -- a reload that is already running, otherwise fail the call.
        if can_reload and (script.in_flight or 0) > 0 then
          table.insert(script.waitq, do_call)
        else
          callback('NOSCRIPT', nil)
        end
        return
      end

      local function redis_cb(err, data)
        if not err then
          callback(err, data)
        elseif string.match(err, 'NOSCRIPT') then
          -- Schedule restart
          script.sha = nil
          if can_reload then
            table.insert(script.waitq, do_call)
            if (script.in_flight or 0) == 0 then
              -- Reload scripts if this has not been initiated yet
              if params.task then
                load_script_task(script, params.task)
              else
                load_script_taskless(script, rspamd_config, params.ev_base)
              end
            end
          else
            callback(err, data)
          end
        else
          callback(err, data)
        end
      end

      if args_modified then
        -- Retry: refresh the SHA, the script may have been reloaded
        args[1] = script.sha
      else
        -- EVALSHA sha numkeys key...: every caller argument is a key
        table.insert(args, 1, tostring(#args))
        table.insert(args, 1, script.sha)
        args_modified = true
      end

      if params.task then
        if not rspamd_redis_make_request(params.task, script.redis_params,
            params.key, params.is_write, redis_cb, 'EVALSHA', args) then
          callback('Cannot make redis request', nil)
        end
      else
        if not redis_make_request_taskless(params.ev_base, rspamd_config,
            script.redis_params,
            params.key, params.is_write, redis_cb, 'EVALSHA', args) then
          callback('Cannot make redis request', nil)
        end
      end
    end

    if script.loaded then
      do_call(true)
    else
      -- Delayed until scripts are loaded
      if not params.task then
        table.insert(script.waitq, do_call)
      else
        -- TODO: fix taskfull requests
        callback('NOSCRIPT', nil)
      end
    end

    return true
  end
  -- Public API: script execution
  exports['exec_redis_script'] = exec_redis_script
  --- Open a synchronous connection to one of the servers in `redis_params`.
  -- On a successful connection, AUTH and SELECT commands are queued when a
  -- password or database is configured.
  -- @param redis_params table as returned by rspamd_parse_redis_server
  -- @param is_write boolean select a write server
  -- @param key string|nil key used to pick an upstream by hash; nil selects
  --   master/slave (writes) or round-robin (reads)
  -- @param cfg rspamd_config used for logging
  -- @return ret, conn from rspamd_redis.connect_sync plus the selected
  --   upstream; false,nil when there are no params or no upstream
  local function redis_connect_sync(redis_params, is_write, key, cfg)
    if not redis_params then
      return false,nil
    end

    local rspamd_redis = require "rspamd_redis"
    local addr

    if key then
      if is_write then
        addr = redis_params['write_servers']:get_upstream_by_hash(key)
      else
        addr = redis_params['read_servers']:get_upstream_by_hash(key)
      end
    else
      if is_write then
        addr = redis_params['write_servers']:get_upstream_master_slave(key)
      else
        addr = redis_params['read_servers']:get_upstream_round_robin(key)
      end
    end

    if not addr then
      logger.errx(cfg, 'cannot select server to make redis request')
      -- Nothing to connect to: do not index a nil upstream below
      return false,nil
    end

    local options = {
      host = addr:get_addr(),
      timeout = redis_params['timeout'],
    }

    local ret,conn = rspamd_redis.connect_sync(options)
    if not ret then
      logger.errx('cannot execute redis request: %s', conn)
      addr:fail()
    elseif conn then
      -- Only a real connection gets commands queued; on failure the second
      -- return value is merely logged as the error above.
      if redis_params['password'] then
        conn:add_cmd('AUTH', {redis_params['password']})
      end

      if redis_params['db'] then
        conn:add_cmd('SELECT', {tostring(redis_params['db'])})
      elseif redis_params['dbname'] then
        conn:add_cmd('SELECT', {tostring(redis_params['dbname'])})
      end
    end

    return ret,conn,addr
  end
  -- Public API: synchronous connections
  exports['redis_connect_sync'] = redis_connect_sync

  -- The module table handed to `require "lua_redis"` callers
  return exports