You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541
  1. --[[
  2. Copyright (c) 2017, Vsevolod Stakhov <vsevolod@highsecure.ru>
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ]]--
  13. local logger = require "rspamd_logger"
  14. local lutil = require "lua_util"
  15. local rspamd_util = require "rspamd_util"
  16. local ts = require("tableshape").types
  local exports = {}
  local E = {}
  local N = "lua_redis"

  -- Options accepted next to any form of a Redis servers definition.
  -- Time intervals may be numbers (seconds) or strings converted with
  -- lutil.parse_time_interval; sentinel_master_maxerrors strings go through
  -- tonumber.
  local common_schema = ts.shape {
    db = ts.string:is_optional(),
    dbname = ts.string:is_optional(),
    database = ts.string:is_optional(),
    prefix = ts.string:is_optional(),
    password = ts.string:is_optional(),
    expand_keys = ts.boolean:is_optional(),
    timeout = (ts.number + ts.string / lutil.parse_time_interval):is_optional(),
    sentinels = (ts.string + ts.array_of(ts.string)):is_optional(),
    sentinel_watch_time = (ts.number + ts.string / lutil.parse_time_interval):is_optional(),
    sentinel_masters_pattern = ts.string:is_optional(),
    sentinel_master_maxerrors = (ts.number + ts.string / tonumber):is_optional(),
  }

  -- Builds one alternative of the config schema: each key listed in `keys`
  -- must hold a servers list (a string or an array of strings), and the
  -- common options above are accepted as extra fields.
  local function redis_servers_shape(keys)
    local fields = {}
    for _, key in ipairs(keys) do
      fields[key] = ts.string + ts.array_of(ts.string)
    end
    return ts.shape(fields, {extra_opts = common_schema})
  end

  -- A Redis definition matches one of these forms:
  -- read_servers + write_servers, servers, or server
  local config_schema = redis_servers_shape({'read_servers', 'write_servers'})
      + redis_servers_shape({'servers'})
      + redis_servers_shape({'server'})

  exports.config_schema = config_schema
  --- Refresh Redis servers from a Redis Sentinel.
  -- Connects synchronously (coroutine API) to the next sentinel taken from
  -- params.sentinels, reads the masters list (optionally filtered by the Lua
  -- pattern params.sentinel_masters_pattern) and every master's slaves, then
  -- rebuilds the write list (masters) and the read list (masters plus slaves
  -- whose master-link-status is 'ok'). A rebuilt list that differs from the
  -- current params.read_servers_str / params.write_servers_str replaces
  -- params.read_servers / params.write_servers. New write upstreams get a
  -- failure watcher that re-queries the sentinel once a master exceeds
  -- params.sentinel_master_maxerrors failures.
  -- @param ev_base event base used for the sentinel connection
  -- @param params redis params; params.sentinels must be an upstream list
  -- @param initialised not used by this function (kept for callers)
  local function redis_query_sentinel(ev_base, params, initialised)
    -- Sentinel describes every server as a flat array of alternating field
    -- names and values; turn it into a {field = value} table
    local function flatten_redis_table(tbl)
      local res = {}
      for i = 1, #tbl, 2 do
        res[tbl[i]] = tbl[i + 1]
      end
      return res
    end

    -- Coroutines syntax
    local rspamd_redis = require "rspamd_redis"
    local addr = params.sentinels:get_upstream_round_robin()

    local is_ok, connection = rspamd_redis.connect_sync({
      host = addr:get_addr(),
      timeout = params.timeout,
      config = rspamd_config,
      ev_base = ev_base,
    })

    if not is_ok then
      logger.errx(rspamd_config, 'cannot connect sentinel at address: %s',
          tostring(addr:get_addr()))
      addr:fail()
      return
    end

    -- Get masters list
    connection:add_cmd('SENTINEL', {'masters'})
    local ok, result = connection:exec()

    if not ok or type(result) ~= 'table' then
      logger.errx(rspamd_config, 'cannot get data from Redis Sentinel %s: %s',
          addr:get_addr():to_string(true), result)
      addr:fail()
      return
    end

    local pattern = params.sentinel_masters_pattern
    local masters = {}
    for _, m in ipairs(result) do
      local master = type(m) == 'table' and flatten_redis_table(m) or {}
      local name = master.name
      if type(name) ~= 'string' then
        -- a master without a name can neither be matched nor keyed by name
        lutil.debugm(N, 'skip malformed master entry returned by sentinel')
      elseif pattern and not name:match(pattern) then
        lutil.debugm(N, 'skip master %s with ip %s and port %s, pattern %s',
            name, master.ip, master.port, pattern)
      else
        lutil.debugm(N, 'found master %s with ip %s and port %s',
            name, master.ip, master.port)
        masters[name] = master
      end
    end

    -- For each master we need to get a list of slaves
    for name, master in pairs(masters) do
      master.slaves = {}
      connection:add_cmd('SENTINEL', {'slaves', name})
      local slaves_ok, slave_result = connection:exec()
      -- the reply may be an error string; only iterate real arrays
      if slaves_ok and type(slave_result) == 'table' then
        for _, s in ipairs(slave_result) do
          if type(s) == 'table' then
            local slave = flatten_redis_table(s)
            lutil.debugm(N, rspamd_config,
                'found slave from master %s with ip %s and port %s',
                master.name, slave.ip, slave.port)
            master.slaves[#master.slaves + 1] = slave
          end
        end
      end
    end

    -- We now form new strings for masters and slaves: masters serve both
    -- reads and writes, healthy slaves serve reads only
    local read_servers_tbl, write_servers_tbl = {}, {}
    for _, master in pairs(masters) do
      local master_str = string.format('%s:%s', master.ip, master.port)
      write_servers_tbl[#write_servers_tbl + 1] = master_str
      read_servers_tbl[#read_servers_tbl + 1] = master_str
      for _, slave in ipairs(master.slaves) do
        if slave['master-link-status'] == 'ok' then
          read_servers_tbl[#read_servers_tbl + 1] = string.format(
              '%s:%s', slave.ip, slave.port)
        end
      end
    end

    -- Sorting makes the strings comparable with the previously stored ones
    table.sort(read_servers_tbl)
    table.sort(write_servers_tbl)
    local read_servers_str = table.concat(read_servers_tbl, ',')
    local write_servers_str = table.concat(write_servers_tbl, ',')

    lutil.debugm(N, rspamd_config,
        'new servers list: %s read; %s write',
        read_servers_str,
        write_servers_str)

    if read_servers_str ~= params.read_servers_str then
      local upstream_list = require "rspamd_upstream_list"
      local read_upstreams = upstream_list.create(rspamd_config,
          read_servers_str, 6379)
      if read_upstreams then
        logger.infox(rspamd_config, 'sentinel %s: replace read servers with new list: %s',
            addr:get_addr():to_string(true), read_servers_str)
        params.read_servers = read_upstreams
        params.read_servers_str = read_servers_str
      end
    end

    if write_servers_str ~= params.write_servers_str then
      local upstream_list = require "rspamd_upstream_list"
      local write_upstreams = upstream_list.create(rspamd_config,
          write_servers_str, 6379)
      if write_upstreams then
        logger.infox(rspamd_config, 'sentinel %s: replace write servers with new list: %s',
            addr:get_addr():to_string(true), write_servers_str)
        params.write_servers = write_upstreams
        params.write_servers_str = write_servers_str

        local queried = false
        local function monitor_failures(up, _, count)
          if count > params.sentinel_master_maxerrors and not queried then
            logger.infox(rspamd_config, 'sentinel: master with address %s, caused %s failures, try to query sentinel',
                up:get_addr():to_string(true), count)
            queried = true -- Avoid multiple checks caused by this monitor
            redis_query_sentinel(ev_base, params, true)
          end
        end

        write_upstreams:add_watcher('failure', monitor_failures)
      end
    end

    addr:ok()
  end
  --- Enable Redis Sentinel support for a redis params table.
  -- Replaces params.sentinels (a servers definition) with an upstream list,
  -- fills in defaults for the watch interval and the master error limit, and
  -- in every scanner worker schedules a periodic sentinel query whose return
  -- value (params.sentinel_watch_time) is passed back to add_periodic.
  -- @param params redis params table with a `sentinels` definition
  local function add_redis_sentinels(params)
    local upstream_list = require "rspamd_upstream_list"
    local sentinels_list = upstream_list.create(rspamd_config,
        params.sentinels, 5000)

    if not sentinels_list then
      logger.errx(rspamd_config, 'cannot load redis sentinels string: %s',
          params.sentinels)
      return
    end

    params.sentinels = sentinels_list
    params.sentinel_watch_time = params.sentinel_watch_time or 60 -- Each minute
    -- Maximum number of errors before rechecking
    params.sentinel_master_maxerrors = params.sentinel_master_maxerrors or 2

    rspamd_config:add_on_load(function(_, ev_base, worker)
      if not worker:is_scanner() then
        return
      end

      local initialised = false
      rspamd_config:add_periodic(ev_base, 0.0, function()
        redis_query_sentinel(ev_base, params, initialised)
        initialised = true
        return params.sentinel_watch_time
      end, false)
    end)
  end
  -- Loaded Redis definitions keyed by calculate_redis_hash() digest, so that
  -- identical definitions share the same upstream lists
  local cached_results = {}

  --- Compute a digest identifying a redis params table.
  -- String, number and boolean values are hashed together with their keys;
  -- nested tables are walked recursively; other values (e.g. upstream list
  -- userdata) are ignored. Keys are visited in sorted order because pairs()
  -- order is unspecified, so equal tables always give the same digest.
  -- @param params redis params table
  -- @return base32 digest string
  local function calculate_redis_hash(params)
    local cr = require "rspamd_cryptobox_hash"
    local h = cr.create()

    local function rec_hash(k, v)
      local tv = type(v)
      if tv == 'string' then
        h:update(tostring(k))
        h:update(v)
      elseif tv == 'number' or tv == 'boolean' then
        -- booleans (e.g. expand_keys) must be part of the digest, otherwise
        -- definitions differing only in them would share a cache entry
        h:update(tostring(k))
        h:update(tostring(v))
      elseif tv == 'table' then
        local keys = {}
        for kk in pairs(v) do
          keys[#keys + 1] = kk
        end
        -- keys may mix numbers and strings: compare their string forms
        table.sort(keys, function(a, b) return tostring(a) < tostring(b) end)
        for _, kk in ipairs(keys) do
          rec_hash(kk, v[kk])
        end
      end
    end

    rec_hash('top', params)
    return h:base32()
  end
  217. --[[[
  218. -- @module lua_redis
  219. -- This module contains helper functions for working with Redis
  220. --]]
  -- Creates an upstream list for a servers definition (a string or an array
  -- of strings); the config object is passed only when one is available.
  local function create_redis_upstreams(upstream_list, cfg, def, default_port)
    if cfg then
      return upstream_list.create(cfg, def, default_port)
    end
    return upstream_list.create(def, default_port)
  end

  --- Try to load Redis servers and options from `options` into `result`.
  -- Servers come from `read_servers` (+ optional `write_servers`), `servers`
  -- or `server`, in that order of preference; a definition with only
  -- `read_servers` is read-only. prefix, db, password and sentinel settings
  -- already present in `result` are kept; timeout is replaced only when
  -- unset or equal to the default; expand_keys is always set.
  -- Identical definitions are shared through `cached_results`.
  -- @param options table with a Redis definition
  -- @param rspamd_config config object or nil
  -- @param result table that receives the parsed settings
  -- @return true if read servers were loaded, false otherwise
  local function try_load_redis_servers(options, rspamd_config, result)
    local default_port = 6379
    local default_timeout = 1.0
    local default_expand_keys = false
    local upstream_list = require "rspamd_upstream_list"
    local read_only = true

    -- Try to get read servers:
    local upstreams_read, upstreams_write
    if options['read_servers'] then
      upstreams_read = create_redis_upstreams(upstream_list, rspamd_config,
          options['read_servers'], default_port)
      result.read_servers_str = options['read_servers']
    elseif options['servers'] then
      upstreams_read = create_redis_upstreams(upstream_list, rspamd_config,
          options['servers'], default_port)
      result.read_servers_str = options['servers']
      read_only = false
    elseif options['server'] then
      upstreams_read = create_redis_upstreams(upstream_list, rspamd_config,
          options['server'], default_port)
      result.read_servers_str = options['server']
      read_only = false
    end

    if upstreams_read then
      if options['write_servers'] then
        upstreams_write = create_redis_upstreams(upstream_list, rspamd_config,
            options['write_servers'], default_port)
        result.write_servers_str = options['write_servers']
        read_only = false
      elseif not read_only then
        -- `servers` / `server` are used for both reading and writing
        upstreams_write = upstreams_read
        result.write_servers_str = result.read_servers_str
      end
    end

    -- Store options
    if not result['timeout'] or result['timeout'] == default_timeout then
      if options['timeout'] then
        result['timeout'] = tonumber(options['timeout'])
      else
        result['timeout'] = default_timeout
      end
    end

    if options['prefix'] and not result['prefix'] then
      result['prefix'] = options['prefix']
    end

    if type(options['expand_keys']) == 'boolean' then
      result['expand_keys'] = options['expand_keys']
    else
      result['expand_keys'] = default_expand_keys
    end

    if not result['db'] then
      if options['db'] then
        result['db'] = tostring(options['db'])
      elseif options['dbname'] then
        result['db'] = tostring(options['dbname'])
      elseif options['database'] then
        result['db'] = tostring(options['database'])
      end
    end

    if options['password'] and not result['password'] then
      result['password'] = options['password']
    end

    -- Sentinel tuning options are read from `result` by add_redis_sentinels
    -- and redis_query_sentinel, so they must be copied here; otherwise the
    -- configured values are ignored and only the defaults apply
    if options['sentinel_masters_pattern'] and not result['sentinel_masters_pattern'] then
      result['sentinel_masters_pattern'] = options['sentinel_masters_pattern']
    end

    if options['sentinel_watch_time'] and not result['sentinel_watch_time'] then
      local watch_time = options['sentinel_watch_time']
      if type(watch_time) == 'string' then
        -- same conversion the config schema applies to this option
        watch_time = lutil.parse_time_interval(watch_time)
      end
      if type(watch_time) == 'number' then
        result['sentinel_watch_time'] = watch_time
      end
    end

    if options['sentinel_master_maxerrors'] and not result['sentinel_master_maxerrors'] then
      result['sentinel_master_maxerrors'] = tonumber(options['sentinel_master_maxerrors'])
    end

    if read_only and not upstreams_write then
      result.read_only = true
    elseif upstreams_write then
      result.read_only = false
    end

    if upstreams_read then
      result.read_servers = upstreams_read
      if upstreams_write then
        result.write_servers = upstreams_write
      end

      local h = calculate_redis_hash(result)
      if cached_results[h] then
        -- reuse the already loaded definition (and its upstream lists)
        for k, v in pairs(cached_results[h]) do
          result[k] = v
        end
        lutil.debugm(N, 'reused redis server: %s', result)
        return true
      end

      result.hash = h
      cached_results[h] = result

      if not result.read_only and options.sentinels then
        result.sentinels = options.sentinels
        add_redis_sentinels(result)
      end

      lutil.debugm(N, 'loaded redis server: %s', result)
      return true
    end

    lutil.debugm(N, rspamd_config,
        'cannot load redis server from obj: %s, processed to %s',
        options, result)

    return false
  end

  exports.try_load_redis_servers = try_load_redis_servers
  -- Returns true when `module_name` appears in a `disabled_modules` option,
  -- which may be a single module name or an array of names.
  local function is_redis_module_disabled(disabled, module_name)
    if type(disabled) == 'string' then
      return disabled == module_name
    elseif type(disabled) == 'table' then
      for _, v in ipairs(disabled) do
        if v == module_name then
          return true
        end
      end
    end
    return false
  end

  -- This function parses redis server definition using either
  -- specific server string for this module or global
  -- redis section
  local function rspamd_parse_redis_server(module_name, module_opts, no_fallback)
    local result = {}

    -- Try local options
    local opts
    lutil.debugm(N, rspamd_config, 'try load redis config for: %s', module_name)
    if not module_opts then
      opts = rspamd_config:get_all_opt(module_name)
    else
      opts = module_opts
    end

    if opts then
      if opts.redis and try_load_redis_servers(opts.redis, rspamd_config, result) then
        return result
      end

      if try_load_redis_servers(opts, rspamd_config, result) then
        return result
      end
    end

    if no_fallback then
      logger.infox(rspamd_config, "cannot find Redis definitions for %s and fallback is disabled",
          module_name)

      return nil
    end

    -- Try global options
    opts = rspamd_config:get_all_opt('redis')

    if opts then
      if opts[module_name] then
        if try_load_redis_servers(opts[module_name], rspamd_config, result) then
          return result
        end
      else
        -- Exclude disabled modules before loading anything, so a disabled
        -- module neither caches servers nor registers sentinel watchers
        if is_redis_module_disabled(opts['disabled_modules'], module_name) then
          logger.infox(rspamd_config, "NOT using default redis server for module %s: it is disabled",
              module_name)

          return nil
        end

        if try_load_redis_servers(opts, rspamd_config, result) then
          logger.infox(rspamd_config, "use default Redis settings for %s",
              module_name)
          return result
        end
      end
    end

    if result.read_servers then
      return result
    end

    return nil
  end

  --[[[
  -- @function lua_redis.parse_redis_server(module_name, module_opts, no_fallback)
  -- Extracts Redis server settings from configuration
  -- @param {string} module_name name of module to get settings for
  -- @param {table} module_opts settings for module or `nil` to fetch them from configuration
  -- @param {boolean} no_fallback should be `true` if global settings must not be used
  -- @return {table} redis server settings
  -- @example
  -- local rconfig = lua_redis.parse_redis_server('my_module')
  -- -- rconfig contains upstream_list objects in ['write_servers'] and ['read_servers']
  -- -- ['timeout'] contains timeout in seconds
  -- -- ['expand_keys'] if true tells that redis key expansion is enabled
  --]]
  exports.rspamd_parse_redis_server = rspamd_parse_redis_server
  exports.parse_redis_server = rspamd_parse_redis_server
  -- Appends to idx_l the indexes of the keys announced by a numkeys argument
  -- at position 2 (EVAL script numkeys key..., ZUNIONSTORE dst numkeys
  -- key...). A non-numeric numkeys adds nothing, and the range is capped at
  -- the number of supplied arguments.
  local function append_numkeys_indexes(args, idx_l)
    local numkeys = tonumber(args[2])
    if numkeys and numkeys >= 1 then
      for i = 3, math.min(numkeys + 2, #args) do
        idx_l[#idx_l + 1] = i
      end
    end
    return idx_l
  end

  -- Key extractors used for Redis key expansion. Each handler receives the
  -- command arguments (without the command name) and returns a table with
  -- the indexes of the arguments that are key names.
  local process_cmd = {
    -- BITOP operation destkey key [key ...]
    bitop = function(args)
      local idx_l = {}
      for i = 2, #args do
        idx_l[#idx_l + 1] = i
      end
      return idx_l
    end,
    -- BLPOP key [key ...] timeout
    blpop = function(args)
      local idx_l = {}
      for i = 1, #args - 1 do
        idx_l[#idx_l + 1] = i
      end
      return idx_l
    end,
    -- EVAL script numkeys key [key ...] arg [arg ...]
    eval = function(args)
      return append_numkeys_indexes(args, {})
    end,
    -- SET key value ...: only the first argument is a key
    set = function(_)
      return {1}
    end,
    -- MGET key [key ...]: every argument is a key
    mget = function(args)
      local idx_l = {}
      for i = 1, #args do
        idx_l[#idx_l + 1] = i
      end
      return idx_l
    end,
    -- MSET key value [key value ...]: keys are at odd positions
    mset = function(args)
      local idx_l = {}
      for i = 1, #args, 2 do
        idx_l[#idx_l + 1] = i
      end
      return idx_l
    end,
    -- SMOVE source destination member
    smove = function(_)
      return {1, 2}
    end,
    -- ZUNIONSTORE destination numkeys key [key ...] [WEIGHTS ...] [AGGREGATE ...]
    zunionstore = function(args)
      return append_numkeys_indexes(args, {1})
    end,
    -- Commands without key arguments; an empty table (not nil) keeps
    -- callers that iterate the result with ipairs safe
    script = function()
      return {}
    end,
  }

  do
    local function alias(handler, names)
      for _, name in ipairs(names) do
        process_cmd[name] = handler
      end
    end

    alias(process_cmd.set, {
      'append', 'bitcount', 'bitfield', 'bitpos', 'decr', 'decrby', 'dump',
      'expire', 'expireat', 'geoadd', 'geodist', 'geohash', 'geopos',
      'georadius', 'georadiusbymember', 'get', 'getbit', 'getrange', 'getset',
      'hdel', 'hexists', 'hget', 'hgetall', 'hincrby', 'hincrbyfloat', 'hkeys',
      'hlen', 'hmget', 'hmset', 'hscan', 'hset', 'hsetnx', 'hstrlen', 'hvals',
      'incr', 'incrby', 'incrbyfloat', 'lindex', 'linsert', 'llen', 'lpop',
      'lpush', 'lpushx', 'lrange', 'lrem', 'lset', 'ltrim', 'move', 'persist',
      'pexpire', 'pexpireat', 'pfadd', 'psetex', 'pttl', 'restore', 'rpop',
      'rpush', 'rpushx', 'sadd', 'scard', 'setbit', 'setex', 'setnx',
      'setrange', 'sismember', 'smembers', 'sort', 'spop', 'srandmember',
      'srem', 'sscan', 'strlen', 'ttl', 'type', 'zadd', 'zcard', 'zcount',
      'zincrby', 'zlexcount', 'zrange', 'zrangebylex', 'zrangebyscore',
      'zrank', 'zrem', 'zremrangebylex', 'zremrangebyrank', 'zremrangebyscore',
      'zrevrange', 'zrevrangebylex', 'zrevrangebyscore', 'zrevrank', 'zscan',
      'zscore',
    })

    alias(process_cmd.mget, {
      'del', 'exists', 'pfcount', 'pfmerge', 'rename', 'renamenx',
      'rpoplpush', 'sdiff', 'sdiffstore', 'sinter', 'sinterstore', 'sunion',
      'sunionstore', 'touch', 'unlink', 'watch',
    })

    alias(process_cmd.script, {
      'auth', 'bgrewriteaof', 'bgsave', 'client', 'cluster', 'command',
      'config', 'dbsize', 'debug', 'discard', 'echo', 'exec', 'flushall',
      'flushdb', 'info', 'keys', 'lastsave', 'migrate', 'monitor', 'multi',
      'object', 'ping', 'psubscribe', 'publish', 'pubsub', 'punsubscribe',
      'quit', 'randomkey', 'readonly', 'readwrite', 'role', 'save', 'scan',
      'select', 'slaveof', 'slowlog', 'subscribe', 'swapdb', 'sync', 'time',
      'unsubscribe', 'unwatch', 'wait',
    })

    alias(process_cmd.blpop, {'brpop', 'brpoplpush'})
    alias(process_cmd.eval, {'evalsha'})
    alias(process_cmd.mset, {'msetnx'})
    alias(process_cmd.zunionstore, {'zinterstore'})
  end
  --- Get the indexes of key arguments for a Redis command.
  -- @param cmd Redis command name (case-insensitive)
  -- @param args command arguments, without the command name
  -- @return table of argument indexes; empty for key-less or unknown
  --   commands (a warning is logged for unknown ones)
  local function get_key_indexes(cmd, args)
    cmd = string.lower(cmd)
    local handler = process_cmd[cmd]

    if not handler then
      logger.warnx(rspamd_config, "Don't know how to extract keys for %s Redis command", cmd)
      return {}
    end

    -- callers iterate the result with ipairs, so never hand back nil even
    -- if a handler returns nothing
    return handler(args or {}) or {}
  end
  -- Returns field `field` of the first address returned by
  -- task:get_from(kind), or nil when there is no such address.
  local function first_from_field(task, kind, field)
    local addrs = task:get_from(kind) or E
    local first = addrs[1] or E
    return first[field]
  end

  -- Generators for the variables available to Redis key expansion; each one
  -- takes a task and returns the value, or nil when it is not available.
  local gen_meta = {
    principal_recipient = function(task)
      return task:get_principal_recipient()
    end,
    principal_recipient_domain = function(task)
      local rcpt = task:get_principal_recipient()
      if rcpt then
        -- everything after the last '@'
        return string.match(rcpt, '.*@(.*)')
      end
    end,
    ip = function(task)
      local ip_addr = task:get_ip()
      if ip_addr and ip_addr:is_valid() then
        return ip_addr:to_string()
      end
    end,
    from = function(task)
      return first_from_field(task, 'smtp', 'addr')
    end,
    from_domain = function(task)
      return first_from_field(task, 'smtp', 'domain')
    end,
    from_domain_or_helo_domain = function(task)
      local domain = first_from_field(task, 'smtp', 'domain')
      if domain and #domain > 0 then
        return domain
      end
      return task:get_helo()
    end,
    mime_from = function(task)
      return first_from_field(task, 'mime', 'addr')
    end,
    mime_from_domain = function(task)
      return first_from_field(task, 'mime', 'domain')
    end,
  }

  -- Wraps a domain generator so that it returns rspamd_util.get_tld() of
  -- the produced domain (nil when the generator yields nothing).
  local function gen_get_esld(f)
    return function(task)
      local domain = f(task)
      if domain then
        return rspamd_util.get_tld(domain)
      end
    end
  end

  -- Aliases with an explicit `smtp_` prefix
  gen_meta.smtp_from = gen_meta.from
  gen_meta.smtp_from_domain = gen_meta.from_domain
  gen_meta.smtp_from_domain_or_helo_domain = gen_meta.from_domain_or_helo_domain

  -- eSLD variants of the domain generators
  gen_meta.esld_principal_recipient_domain = gen_get_esld(gen_meta.principal_recipient_domain)
  gen_meta.esld_from_domain = gen_get_esld(gen_meta.from_domain)
  gen_meta.esld_smtp_from_domain = gen_meta.esld_from_domain
  gen_meta.esld_mime_from_domain = gen_get_esld(gen_meta.mime_from_domain)
  gen_meta.esld_from_domain_or_helo_domain = gen_get_esld(gen_meta.from_domain_or_helo_domain)
  gen_meta.esld_smtp_from_domain_or_helo_domain = gen_meta.esld_from_domain_or_helo_domain
  --- Build a lazily populated table of key-expansion variables for `task`.
  -- Lookups are case-insensitive (the key is lower-cased). A known variable
  -- is computed on first access through gen_meta and stored in the table,
  -- so later lookups do not call the generator again.
  -- @param task rspamd task
  -- @return table with the lazy lookup metatable
  local function get_key_expansion_metadata(task)
    local lazy_meta = {}
    return setmetatable(lazy_meta, {
      __index = function(self, key)
        local name = string.lower(key)
        local cached = rawget(self, name)
        if cached then
          return cached
        end

        local generator = gen_meta[name]
        if not generator then
          return cached
        end

        local value = generator(task)
        rawset(self, name, value)
        return value
      end,
    })
  end
  697. -- Performs async call to redis hiding all complexity inside function
  698. -- task - rspamd_task
  699. -- redis_params - valid params returned by rspamd_parse_redis_server
  700. -- key - key to select upstream or nil to select round-robin/master-slave
  701. -- is_write - true if need to write to redis server
  702. -- callback - function to be called upon request is completed
  703. -- command - redis command
  704. -- args - table of arguments
  705. -- extra_opts - table of optional request arguments
  706. local function rspamd_redis_make_request(task, redis_params, key, is_write,
  707. callback, command, args, extra_opts)
  708. local addr
  709. local function rspamd_redis_make_request_cb(err, data)
  710. if err then
  711. addr:fail()
  712. else
  713. addr:ok()
  714. end
  715. callback(err, data, addr)
  716. end
  717. if not task or not redis_params or not callback or not command then
  718. return false,nil,nil
  719. end
  720. local rspamd_redis = require "rspamd_redis"
  721. if key then
  722. if is_write then
  723. addr = redis_params['write_servers']:get_upstream_by_hash(key)
  724. else
  725. addr = redis_params['read_servers']:get_upstream_by_hash(key)
  726. end
  727. else
  728. if is_write then
  729. addr = redis_params['write_servers']:get_upstream_master_slave(key)
  730. else
  731. addr = redis_params['read_servers']:get_upstream_round_robin(key)
  732. end
  733. end
  734. if not addr then
  735. logger.errx(task, 'cannot select server to make redis request')
  736. end
  737. if redis_params['expand_keys'] then
  738. local m = get_key_expansion_metadata(task)
  739. local indexes = get_key_indexes(command, args)
  740. for _, i in ipairs(indexes) do
  741. args[i] = lutil.template(args[i], m)
  742. end
  743. end
  744. local ip_addr = addr:get_addr()
  745. local options = {
  746. task = task,
  747. callback = rspamd_redis_make_request_cb,
  748. host = ip_addr,
  749. timeout = redis_params['timeout'],
  750. cmd = command,
  751. args = args
  752. }
  753. if extra_opts then
  754. for k,v in pairs(extra_opts) do
  755. options[k] = v
  756. end
  757. end
  758. if redis_params['password'] then
  759. options['password'] = redis_params['password']
  760. end
  761. if redis_params['db'] then
  762. options['dbname'] = redis_params['db']
  763. end
  764. lutil.debugm(N, task, 'perform request to redis server' ..
  765. ' (host=%s, timeout=%s): cmd: %s, arguments: %s', ip_addr,
  766. options.timeout, options.cmd, args)
  767. local ret,conn = rspamd_redis.make_request(options)
  768. if not ret then
  769. addr:fail()
  770. logger.warnx(task, "cannot make redis request to: %s", tostring(ip_addr))
  771. end
  772. return ret,conn,addr
  773. end
--[[[
-- @function lua_redis.redis_make_request(task, redis_params, key, is_write, callback, command, args, extra_opts)
-- Sends a request to Redis
-- @param {rspamd_task} task task object
-- @param {table} redis_params redis configuration in format returned by lua_redis.parse_redis_server()
-- @param {string} key key to use for sharding (nil selects an upstream without hashing)
-- @param {boolean} is_write should be `true` if we are performing a write operation
-- @param {function} callback callback function (first parameter is error if applicable, second is a 2D array (table), third is the upstream used)
-- @param {string} command Redis command to run
-- @param {table} args Numerically indexed table containing arguments for command
-- @param {table} extra_opts optional table of extra options copied into the request options
-- @return {boolean,conn,upstream} request status, connection object returned by rspamd_redis.make_request and the selected upstream
--]]
exports.rspamd_redis_make_request = rspamd_redis_make_request
exports.redis_make_request = rspamd_redis_make_request
  787. local function redis_make_request_taskless(ev_base, cfg, redis_params, key,
  788. is_write, callback, command, args, extra_opts)
  789. if not ev_base or not redis_params or not callback or not command then
  790. return false,nil,nil
  791. end
  792. local addr
  793. local function rspamd_redis_make_request_cb(err, data)
  794. if err then
  795. addr:fail()
  796. else
  797. addr:ok()
  798. end
  799. callback(err, data, addr)
  800. end
  801. local rspamd_redis = require "rspamd_redis"
  802. if key then
  803. if is_write then
  804. addr = redis_params['write_servers']:get_upstream_by_hash(key)
  805. else
  806. addr = redis_params['read_servers']:get_upstream_by_hash(key)
  807. end
  808. else
  809. if is_write then
  810. addr = redis_params['write_servers']:get_upstream_master_slave(key)
  811. else
  812. addr = redis_params['read_servers']:get_upstream_round_robin(key)
  813. end
  814. end
  815. if not addr then
  816. logger.errx(cfg, 'cannot select server to make redis request')
  817. end
  818. local options = {
  819. ev_base = ev_base,
  820. config = cfg,
  821. callback = rspamd_redis_make_request_cb,
  822. host = addr:get_addr(),
  823. timeout = redis_params['timeout'],
  824. cmd = command,
  825. args = args
  826. }
  827. if extra_opts then
  828. for k,v in pairs(extra_opts) do
  829. options[k] = v
  830. end
  831. end
  832. if redis_params['password'] then
  833. options['password'] = redis_params['password']
  834. end
  835. if redis_params['db'] then
  836. options['dbname'] = redis_params['db']
  837. end
  838. lutil.debugm(N, cfg, 'perform taskless request to redis server' ..
  839. ' (host=%s, timeout=%s): cmd: %s, arguments: %s', options.host,
  840. options.timeout, options.cmd, args)
  841. local ret,conn = rspamd_redis.make_request(options)
  842. if not ret then
  843. logger.errx('cannot execute redis request')
  844. addr:fail()
  845. end
  846. return ret,conn,addr
  847. end
--[[[
-- @function lua_redis.redis_make_request_taskless(ev_base, cfg, redis_params, key, is_write, callback, command, args, extra_opts)
-- Sends a request to Redis in context where `task` is not available for some specific use-cases
-- Same as redis_make_request() except that the first two parameters are an `event base` object and a config object
-- (used for logging and passed as the request `config`), and that no `expand_keys` templating is applied to `args`
--]]
exports.rspamd_redis_make_request_taskless = redis_make_request_taskless
exports.redis_make_request_taskless = redis_make_request_taskless
-- Registry of scripts added via add_redis_script(); a script's position in
-- this array is the id returned to callers and looked up by exec_redis_script()
local redis_scripts = {
}
  857. local function script_set_loaded(script)
  858. if script.sha then
  859. script.loaded = true
  860. end
  861. local wait_table = {}
  862. for _,s in ipairs(script.waitq) do
  863. table.insert(wait_table, s)
  864. end
  865. script.waitq = {}
  866. for _,s in ipairs(wait_table) do
  867. s(script.loaded)
  868. end
  869. end
  870. local function prepare_redis_call(script)
  871. local function merge_tables(t1, t2)
  872. for k,v in pairs(t2) do t1[k] = v end
  873. end
  874. local servers = {}
  875. local options = {}
  876. if script.redis_params.read_servers then
  877. merge_tables(servers, script.redis_params.read_servers:all_upstreams())
  878. end
  879. if script.redis_params.write_servers then
  880. merge_tables(servers, script.redis_params.write_servers:all_upstreams())
  881. end
  882. -- Call load script on each server, set loaded flag
  883. script.in_flight = #servers
  884. for _,s in ipairs(servers) do
  885. local cur_opts = {
  886. host = s:get_addr(),
  887. timeout = script.redis_params['timeout'],
  888. cmd = 'SCRIPT',
  889. args = {'LOAD', script.script },
  890. upstream = s
  891. }
  892. if script.redis_params['password'] then
  893. cur_opts['password'] = script.redis_params['password']
  894. end
  895. if script.redis_params['db'] then
  896. cur_opts['dbname'] = script.redis_params['db']
  897. end
  898. table.insert(options, cur_opts)
  899. end
  900. return options
  901. end
  902. local function load_script_task(script, task)
  903. local rspamd_redis = require "rspamd_redis"
  904. local opts = prepare_redis_call(script)
  905. for _,opt in ipairs(opts) do
  906. opt.task = task
  907. opt.callback = function(err, data)
  908. if err then
  909. logger.errx(task, 'cannot upload script to %s: %s',
  910. opt.upstream:get_addr(), err)
  911. opt.upstream:fail()
  912. script.fatal_error = err
  913. else
  914. opt.upstream:ok()
  915. logger.infox(task,
  916. "uploaded redis script to %s with id %s, sha: %s",
  917. opt.upstream:get_addr(), script.id, data)
  918. script.sha = data -- We assume that sha is the same on all servers
  919. end
  920. script.in_flight = script.in_flight - 1
  921. if script.in_flight == 0 then
  922. script_set_loaded(script)
  923. end
  924. end
  925. local ret = rspamd_redis.make_request(opt)
  926. if not ret then
  927. logger.errx('cannot execute redis request to load script on %s',
  928. opt.upstream:get_addr())
  929. script.in_flight = script.in_flight - 1
  930. opt.upstream:fail()
  931. end
  932. if script.in_flight == 0 then
  933. script_set_loaded(script)
  934. end
  935. end
  936. end
  937. local function load_script_taskless(script, cfg, ev_base)
  938. local rspamd_redis = require "rspamd_redis"
  939. local opts = prepare_redis_call(script)
  940. for _,opt in ipairs(opts) do
  941. opt.config = cfg
  942. opt.ev_base = ev_base
  943. opt.callback = function(err, data)
  944. if err then
  945. logger.errx(cfg, 'cannot upload script to %s: %s',
  946. opt.upstream:get_addr(), err)
  947. opt.upstream:fail()
  948. script.fatal_error = err
  949. else
  950. opt.upstream:ok()
  951. logger.infox(cfg,
  952. "uploaded redis script to %s with id %s, sha: %s",
  953. opt.upstream:get_addr(), script.id, data)
  954. script.sha = data -- We assume that sha is the same on all servers
  955. script.fatal_error = nil
  956. end
  957. script.in_flight = script.in_flight - 1
  958. if script.in_flight == 0 then
  959. script_set_loaded(script)
  960. end
  961. end
  962. local ret = rspamd_redis.make_request(opt)
  963. if not ret then
  964. logger.errx('cannot execute redis request to load script on %s',
  965. opt.upstream:get_addr())
  966. script.in_flight = script.in_flight - 1
  967. opt.upstream:fail()
  968. end
  969. if script.in_flight == 0 then
  970. script_set_loaded(script)
  971. end
  972. end
  973. end
  974. local function load_redis_script(script, cfg, ev_base, _)
  975. if script.redis_params then
  976. load_script_taskless(script, cfg, ev_base)
  977. end
  978. end
  979. local function add_redis_script(script, redis_params)
  980. local new_script = {
  981. loaded = false,
  982. redis_params = redis_params,
  983. script = script,
  984. waitq = {}, -- callbacks pending for script being loaded
  985. id = #redis_scripts + 1
  986. }
  987. -- Register on load function
  988. rspamd_config:add_on_load(function(cfg, ev_base, worker)
  989. local mult = 0.0
  990. rspamd_config:add_periodic(ev_base, 0.0, function()
  991. if not new_script.sha then
  992. load_redis_script(new_script, cfg, ev_base, worker)
  993. mult = mult + 1
  994. return 1.0 * mult -- Check one more time in one second
  995. end
  996. return false
  997. end, false)
  998. end)
  999. table.insert(redis_scripts, new_script)
  1000. return #redis_scripts
  1001. end
  1002. exports.add_redis_script = add_redis_script
--- Executes a script registered with add_redis_script() using EVALSHA.
-- @param id script id returned by add_redis_script()
-- @param params table with either `task` or `ev_base`, plus optional `key`
--   (upstream selection) and `is_write`
-- @param callback function(err, data) receiving the result or error
-- @param keys array of Redis keys passed to the script
-- @param args optional array of extra script arguments
-- @return false if no script is registered under `id`, true otherwise
--   (errors are then reported through `callback`)
local function exec_redis_script(id, params, callback, keys, args)
  -- EVALSHA arguments; built on the first do_call() and reused on retries
  local redis_args = {}

  if not redis_scripts[id] then
    logger.errx("cannot find registered script with id %s", id)
    return false
  end

  local script = redis_scripts[id]

  -- A previous upload reported an error: return it without querying Redis
  if script.fatal_error then
    callback(script.fatal_error, nil)
    return true
  end

  if not script.redis_params then
    callback('no redis servers defined', nil)
    return true
  end

  -- Issues EVALSHA. When `can_reload` is true, a NOSCRIPT reply queues this
  -- call again and triggers a re-upload if none is in progress.
  local function do_call(can_reload)
    local function redis_cb(err, data)
      if not err then
        callback(err, data)
      elseif string.match(err, 'NOSCRIPT') then
        -- Schedule restart
        -- NOTE(review): `script.loaded` is not reset here, so another
        -- exec_redis_script() call made before the reload completes takes the
        -- `do_call(true)` path while `script.sha` is nil; confirm whether
        -- that produces malformed EVALSHA arguments in practice.
        script.sha = nil
        if can_reload then
          table.insert(script.waitq, do_call)
          if script.in_flight == 0 then
            -- Reload scripts if this has not been initiated yet
            if params.task then
              load_script_task(script, params.task)
            else
              load_script_taskless(script, rspamd_config, params.ev_base)
            end
          end
        else
          callback(err, data)
        end
      else
        callback(err, data)
      end
    end

    -- EVALSHA <sha> <numkeys> <keys...> <args...>
    if #redis_args == 0 then
      table.insert(redis_args, script.sha)
      table.insert(redis_args, tostring(#keys))
      for _,k in ipairs(keys) do
        table.insert(redis_args, k)
      end

      if type(args) == 'table' then
        for _, a in ipairs(args) do
          table.insert(redis_args, a)
        end
      end
    end

    if params.task then
      if not rspamd_redis_make_request(params.task, script.redis_params,
          params.key, params.is_write, redis_cb, 'EVALSHA', redis_args) then
        callback('Cannot make redis request', nil)
      end
    else
      if not redis_make_request_taskless(params.ev_base, rspamd_config,
          script.redis_params,
          params.key, params.is_write, redis_cb, 'EVALSHA', redis_args) then
        callback('Cannot make redis request', nil)
      end
    end
  end

  if script.loaded then
    do_call(true)
  else
    -- Delayed until scripts are loaded
    if not params.task then
      table.insert(script.waitq, do_call)
    else
      -- TODO: fix taskfull requests
      callback('NOSCRIPT', nil)
    end
  end

  return true
end
exports.exec_redis_script = exec_redis_script
  1081. local function redis_connect_sync(redis_params, is_write, key, cfg, ev_base)
  1082. if not redis_params then
  1083. return false,nil
  1084. end
  1085. local rspamd_redis = require "rspamd_redis"
  1086. local addr
  1087. if key then
  1088. if is_write then
  1089. addr = redis_params['write_servers']:get_upstream_by_hash(key)
  1090. else
  1091. addr = redis_params['read_servers']:get_upstream_by_hash(key)
  1092. end
  1093. else
  1094. if is_write then
  1095. addr = redis_params['write_servers']:get_upstream_master_slave(key)
  1096. else
  1097. addr = redis_params['read_servers']:get_upstream_round_robin(key)
  1098. end
  1099. end
  1100. if not addr then
  1101. logger.errx(cfg, 'cannot select server to make redis request')
  1102. end
  1103. local options = {
  1104. host = addr:get_addr(),
  1105. timeout = redis_params['timeout'],
  1106. config = cfg or rspamd_config,
  1107. ev_base = ev_base or rspamadm_ev_base,
  1108. session = redis_params.session or rspamadm_session
  1109. }
  1110. for k,v in pairs(redis_params) do
  1111. options[k] = v
  1112. end
  1113. if not options.config then
  1114. logger.errx('config is not set')
  1115. return false,nil,addr
  1116. end
  1117. if not options.ev_base then
  1118. logger.errx('ev_base is not set')
  1119. return false,nil,addr
  1120. end
  1121. if not options.session then
  1122. logger.errx('session is not set')
  1123. return false,nil,addr
  1124. end
  1125. local ret,conn = rspamd_redis.connect_sync(options)
  1126. if not ret then
  1127. logger.errx('cannot execute redis request: %s', conn)
  1128. addr:fail()
  1129. return false,nil,addr
  1130. end
  1131. if conn then
  1132. if redis_params['password'] then
  1133. conn:add_cmd('AUTH', {redis_params['password']})
  1134. end
  1135. if redis_params['db'] then
  1136. conn:add_cmd('SELECT', {tostring(redis_params['db'])})
  1137. elseif redis_params['dbname'] then
  1138. conn:add_cmd('SELECT', {tostring(redis_params['dbname'])})
  1139. end
  1140. end
  1141. return ret,conn,addr
  1142. end
  1143. exports.redis_connect_sync = redis_connect_sync
  1144. --[[[
  1145. -- @function lua_redis.request(redis_params, attrs, req)
  1146. -- Sends a request to Redis synchronously with coroutines or asynchronously using
  1147. -- a callback (modern API)
  1148. -- @param redis_params a table of redis server parameters
  1149. -- @param attrs a table of redis request attributes (e.g. task, or ev_base + cfg + session)
  1150. -- @param req a table of request: a command + command options
  1151. -- @return {result,data/connection,address} boolean result, connection object in case of async request and results if using coroutines, redis server address
  1152. --]]
  1153. exports.request = function(redis_params, attrs, req)
  1154. local lua_util = require "lua_util"
  1155. if not attrs or not redis_params or not req then
  1156. logger.errx('invalid arguments for redis request')
  1157. return false,nil,nil
  1158. end
  1159. if not (attrs.task or (attrs.config and attrs.ev_base)) then
  1160. logger.errx('invalid attributes for redis request')
  1161. return false,nil,nil
  1162. end
  1163. local opts = lua_util.shallowcopy(attrs)
  1164. local log_obj = opts.task or opts.config
  1165. local addr
  1166. if opts.callback then
  1167. -- Wrap callback
  1168. local callback = opts.callback
  1169. local function rspamd_redis_make_request_cb(err, data)
  1170. if err then
  1171. addr:fail()
  1172. else
  1173. addr:ok()
  1174. end
  1175. callback(err, data, addr)
  1176. end
  1177. opts.callback = rspamd_redis_make_request_cb
  1178. end
  1179. local rspamd_redis = require "rspamd_redis"
  1180. local is_write = opts.is_write
  1181. if opts.key then
  1182. if is_write then
  1183. addr = redis_params['write_servers']:get_upstream_by_hash(attrs.key)
  1184. else
  1185. addr = redis_params['read_servers']:get_upstream_by_hash(attrs.key)
  1186. end
  1187. else
  1188. if is_write then
  1189. addr = redis_params['write_servers']:get_upstream_master_slave(attrs.key)
  1190. else
  1191. addr = redis_params['read_servers']:get_upstream_round_robin(attrs.key)
  1192. end
  1193. end
  1194. if not addr then
  1195. logger.errx(log_obj, 'cannot select server to make redis request')
  1196. end
  1197. opts.host = addr:get_addr()
  1198. opts.timeout = redis_params.timeout
  1199. if type(req) == 'string' then
  1200. opts.cmd = req
  1201. else
  1202. -- XXX: modifies the input table
  1203. opts.cmd = table.remove(req, 1);
  1204. opts.args = req
  1205. end
  1206. if redis_params.password then
  1207. opts.password = redis_params.password
  1208. end
  1209. if redis_params.db then
  1210. opts.dbname = redis_params.db
  1211. end
  1212. lutil.debugm(N, 'perform generic request to redis server' ..
  1213. ' (host=%s, timeout=%s): cmd: %s, arguments: %s', addr,
  1214. opts.timeout, opts.cmd, opts.args)
  1215. if opts.callback then
  1216. local ret,conn = rspamd_redis.make_request(opts)
  1217. if not ret then
  1218. logger.errx(log_obj, 'cannot execute redis request')
  1219. addr:fail()
  1220. end
  1221. return ret,conn,addr
  1222. else
  1223. -- Coroutines version
  1224. local ret,conn = rspamd_redis.connect_sync(opts)
  1225. if not ret then
  1226. logger.errx(log_obj, 'cannot execute redis request')
  1227. addr:fail()
  1228. else
  1229. conn:add_cmd(opts.cmd, opts.args)
  1230. return conn:exec()
  1231. end
  1232. return false,nil,addr
  1233. end
  1234. end
  1235. --[[[
  1236. -- @function lua_redis.connect(redis_params, attrs)
  1237. -- Connects to Redis synchronously with coroutines or asynchronously using a callback (modern API)
  1238. -- @param redis_params a table of redis server parameters
  1239. -- @param attrs a table of redis request attributes (e.g. task, or ev_base + cfg + session)
  1240. -- @return {result,connection,address} boolean result, connection object, redis server address
  1241. --]]
  1242. exports.connect = function(redis_params, attrs)
  1243. local lua_util = require "lua_util"
  1244. if not attrs or not redis_params then
  1245. logger.errx('invalid arguments for redis connect')
  1246. return false,nil,nil
  1247. end
  1248. if not (attrs.task or (attrs.config and attrs.ev_base)) then
  1249. logger.errx('invalid attributes for redis connect')
  1250. return false,nil,nil
  1251. end
  1252. local opts = lua_util.shallowcopy(attrs)
  1253. local log_obj = opts.task or opts.config
  1254. local addr
  1255. if opts.callback then
  1256. -- Wrap callback
  1257. local callback = opts.callback
  1258. local function rspamd_redis_make_request_cb(err, data)
  1259. if err then
  1260. addr:fail()
  1261. else
  1262. addr:ok()
  1263. end
  1264. callback(err, data, addr)
  1265. end
  1266. opts.callback = rspamd_redis_make_request_cb
  1267. end
  1268. local rspamd_redis = require "rspamd_redis"
  1269. local is_write = opts.is_write
  1270. if opts.key then
  1271. if is_write then
  1272. addr = redis_params['write_servers']:get_upstream_by_hash(attrs.key)
  1273. else
  1274. addr = redis_params['read_servers']:get_upstream_by_hash(attrs.key)
  1275. end
  1276. else
  1277. if is_write then
  1278. addr = redis_params['write_servers']:get_upstream_master_slave(attrs.key)
  1279. else
  1280. addr = redis_params['read_servers']:get_upstream_round_robin(attrs.key)
  1281. end
  1282. end
  1283. if not addr then
  1284. logger.errx(log_obj, 'cannot select server to make redis connect')
  1285. end
  1286. opts.host = addr:get_addr()
  1287. opts.timeout = redis_params.timeout
  1288. if redis_params.password then
  1289. opts.password = redis_params.password
  1290. end
  1291. if redis_params.db then
  1292. opts.dbname = redis_params.db
  1293. end
  1294. if opts.callback then
  1295. local ret,conn = rspamd_redis.connect(opts)
  1296. if not ret then
  1297. logger.errx(log_obj, 'cannot execute redis connect')
  1298. addr:fail()
  1299. end
  1300. return ret,conn,addr
  1301. else
  1302. -- Coroutines version
  1303. local ret,conn = rspamd_redis.connect_sync(opts)
  1304. if not ret then
  1305. logger.errx(log_obj, 'cannot execute redis connect')
  1306. addr:fail()
  1307. else
  1308. return true,conn,addr
  1309. end
  1310. return false,nil,addr
  1311. end
  1312. end
-- Module table holding the public lua_redis API
return exports