You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660
  1. --[[
  2. Copyright (c) 2017, Vsevolod Stakhov <vsevolod@highsecure.ru>
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ]]--
  13. local logger = require "rspamd_logger"
  14. local lutil = require "lua_util"
  15. local rspamd_util = require "rspamd_util"
  16. local ts = require("tableshape").types
  local exports = {}
  local E = {}
  local N = "lua_redis"

  -- `common_schema` describes the connection options shared by every form of
  -- redis definition; `config_schema` accepts any of the three server styles
  -- (split read/write servers, `servers`, or `server`) plus those options.
  local common_schema, config_schema
  do
    -- A time value: a number, or a string converted by lutil.parse_time_interval
    local function time_interval()
      return ts.number + ts.string / lutil.parse_time_interval
    end

    -- A single server string or an array of server strings
    local function server_list()
      return ts.string + ts.array_of(ts.string)
    end

    -- Shape with the given server fields that also admits the common options
    local function servers_shape(fields)
      return ts.shape(fields, {extra_opts = common_schema})
    end

    common_schema = ts.shape {
      timeout = time_interval():is_optional(),
      db = ts.string:is_optional(),
      database = ts.string:is_optional(),
      dbname = ts.string:is_optional(),
      prefix = ts.string:is_optional(),
      password = ts.string:is_optional(),
      expand_keys = ts.boolean:is_optional(),
      sentinels = server_list():is_optional(),
      sentinel_watch_time = time_interval():is_optional(),
      sentinel_masters_pattern = ts.string:is_optional(),
      sentinel_master_maxerrors = (ts.number + ts.string / tonumber):is_optional(),
    }

    config_schema = servers_shape({
      read_servers = server_list(),
      write_servers = server_list(),
    }) + servers_shape({
      servers = server_list(),
    }) + servers_shape({
      server = server_list(),
    })
  end

  exports.config_schema = config_schema
  -- Convert a flat Redis reply {k1, v1, k2, v2, ...} into a {k1 = v1, ...} map
  local function flatten_redis_table(tbl)
    local res = {}
    for i = 1, #tbl, 2 do
      res[tbl[i]] = tbl[i + 1]
    end
    return res
  end

  -- Wrap an IPv6 address in brackets so it can be joined with ':port';
  -- any other value (IPv4 address, nil) is returned unchanged
  local function wrap_ipv6_address(ip)
    if type(ip) == 'string' and ip:match(':') then
      return '[' .. ip .. ']'
    end
    return ip
  end

  --[[
  Query one Redis Sentinel (picked round-robin from `params.sentinels`) for
  its masters and their slaves. When the resulting server lists differ from
  `params.read_servers_str` / `params.write_servers_str`, replace
  `params.read_servers` / `params.write_servers` with new upstream lists.
  A failure watcher on the new write list re-queries the sentinel once the
  master exceeds `params.sentinel_master_maxerrors` failures.

  @param ev_base event base used for the synchronous sentinel connection
  @param params redis params; `sentinels` must already be an upstream list
  @param initialised not used by this function; kept for call compatibility
  ]]
  local function redis_query_sentinel(ev_base, params, initialised)
    local rspamd_redis = require "rspamd_redis"
    local upstream_list = require "rspamd_upstream_list"

    local addr = params.sentinels:get_upstream_round_robin()
    if not addr then
      logger.errx(rspamd_config, 'cannot select a sentinel address to query')
      return
    end

    local is_ok, connection = rspamd_redis.connect_sync({
      host = addr:get_addr(),
      timeout = params.timeout,
      config = rspamd_config,
      ev_base = ev_base,
      no_pool = true,
    })

    if not is_ok then
      logger.errx(rspamd_config, 'cannot connect sentinel at address: %s',
          tostring(addr:get_addr()))
      addr:fail()
      return
    end

    -- Get masters list
    connection:add_cmd('SENTINEL', {'masters'})
    local ok, result = connection:exec()

    if not ok or type(result) ~= 'table' then
      logger.errx(rspamd_config, 'cannot get data from Redis Sentinel %s: %s',
          addr:get_addr():to_string(true), result)
      addr:fail()
      return
    end

    local pattern = params.sentinel_masters_pattern
    local masters = {}
    for _, m in ipairs(result) do
      if type(m) == 'table' then
        local master = flatten_redis_table(m)
        master.ip = wrap_ipv6_address(master.ip)
        if master.name == nil then
          -- A nameless entry cannot be stored in `masters` (nil table key)
          lutil.debugm(N, 'skip master without a name, ip %s and port %s',
              master.ip, master.port)
        elseif pattern and not master.name:match(pattern) then
          lutil.debugm(N, 'skip master %s with ip %s and port %s, pattern %s',
              master.name, master.ip, master.port, pattern)
        else
          lutil.debugm(N, 'found master %s with ip %s and port %s',
              master.name, master.ip, master.port)
          masters[master.name] = master
        end
      end
    end

    -- For each master we need to get a list of slaves
    for name, master in pairs(masters) do
      master.slaves = {}
      connection:add_cmd('SENTINEL', {'slaves', name})
      local slaves_ok, slaves_result = connection:exec()
      if slaves_ok and type(slaves_result) == 'table' then
        for _, s in ipairs(slaves_result) do
          if type(s) == 'table' then
            local slave = flatten_redis_table(s)
            lutil.debugm(N, rspamd_config,
                'found slave for master %s with ip %s and port %s',
                master.name, slave.ip, slave.port)
            slave.ip = wrap_ipv6_address(slave.ip)
            master.slaves[#master.slaves + 1] = slave
          end
        end
      else
        logger.errx(rspamd_config, 'cannot get slaves for master %s from Redis Sentinel %s: %s',
            name, addr:get_addr():to_string(true), slaves_result)
      end
    end

    -- Masters take writes; masters plus slaves with a healthy link serve reads
    local read_servers_tbl, write_servers_tbl = {}, {}
    for _, master in pairs(masters) do
      local master_str = string.format('%s:%s', master.ip, master.port)
      write_servers_tbl[#write_servers_tbl + 1] = master_str
      read_servers_tbl[#read_servers_tbl + 1] = master_str
      for _, slave in ipairs(master.slaves) do
        if slave['master-link-status'] == 'ok' then
          read_servers_tbl[#read_servers_tbl + 1] = string.format(
              '%s:%s', slave.ip, slave.port)
        end
      end
    end

    -- Sorting makes the string comparison below independent of reply order
    table.sort(read_servers_tbl)
    table.sort(write_servers_tbl)
    local read_servers_str = table.concat(read_servers_tbl, ',')
    local write_servers_str = table.concat(write_servers_tbl, ',')
    lutil.debugm(N, rspamd_config,
        'new servers list: %s read; %s write',
        read_servers_str,
        write_servers_str)

    if read_servers_str ~= params.read_servers_str then
      local read_upstreams = upstream_list.create(rspamd_config,
          read_servers_str, 6379)
      if read_upstreams then
        logger.infox(rspamd_config, 'sentinel %s: replace read servers with new list: %s',
            addr:get_addr():to_string(true), read_servers_str)
        params.read_servers = read_upstreams
        params.read_servers_str = read_servers_str
      end
    end

    if write_servers_str ~= params.write_servers_str then
      local write_upstreams = upstream_list.create(rspamd_config,
          write_servers_str, 6379)
      if write_upstreams then
        logger.infox(rspamd_config, 'sentinel %s: replace write servers with new list: %s',
            addr:get_addr():to_string(true), write_servers_str)
        params.write_servers = write_upstreams
        params.write_servers_str = write_servers_str

        local queried = false
        local function monitor_failures(up, _, count)
          if count > params.sentinel_master_maxerrors and not queried then
            logger.infox(rspamd_config, 'sentinel: master with address %s, caused %s failures, try to query sentinel',
                up:get_addr():to_string(true), count)
            queried = true -- Avoid multiple checks caused by this monitor
            redis_query_sentinel(ev_base, params, true)
          end
        end
        write_upstreams:add_watcher('failure', monitor_failures)
      end
    end

    addr:ok()
  end
  --[[
  Convert `params.sentinels` (server string or list) into an upstream list,
  fill in the sentinel defaults (watch time 60, max errors 2) and, once the
  configuration is loaded, schedule a periodic sentinel query on scanner
  workers. The periodic callback returns `params.sentinel_watch_time` as
  its next interval. If the sentinels definition cannot be parsed, an
  error is logged and `params` is left untouched.
  ]]
  local function add_redis_sentinels(params)
    local upstream_list = require "rspamd_upstream_list"
    local sentinels = upstream_list.create(rspamd_config, params.sentinels, 5000)

    if not sentinels then
      logger.errx(rspamd_config, 'cannot load redis sentinels string: %s',
          params.sentinels)
      return
    end

    params.sentinels = sentinels
    params.sentinel_watch_time = params.sentinel_watch_time or 60 -- Each minute
    -- Maximum number of errors before rechecking
    params.sentinel_master_maxerrors = params.sentinel_master_maxerrors or 2

    rspamd_config:add_on_load(function(_, ev_base, worker)
      if not worker:is_scanner() then
        return
      end
      local initialised = false
      rspamd_config:add_periodic(ev_base, 0.0, function()
        redis_query_sentinel(ev_base, params, initialised)
        initialised = true
        return params.sentinel_watch_time
      end, false)
    end)
  end
  local cached_results = {}

  --[[
  Compute a fingerprint of a redis_params table; used as the key of
  `cached_results`.

  Each token fed to the hash is tagged and length-prefixed, so distinct
  tables cannot yield the same byte stream (e.g. {a='bc'} vs {ab='c'}).
  Table keys are visited in sorted order, so equal tables always hash
  equally regardless of `pairs` order. Strings, numbers and booleans are
  hashed (booleans matter for options such as `expand_keys`); values of
  other types (functions, userdata) are ignored.
  ]]
  local function calculate_redis_hash(params)
    local cr = require "rspamd_cryptobox_hash"
    local h = cr.create()

    -- Feed one tagged, length-prefixed token
    local function feed(tag, s)
      h:update(string.format('%s%d:', tag, #s))
      h:update(s)
    end

    -- Type-qualified representation so that 1 and '1' stay distinct keys
    local function key_repr(k)
      return type(k) .. ':' .. tostring(k)
    end

    local function rec_hash(k, v)
      local vt = type(v)
      if vt == 'string' then
        feed('k', key_repr(k))
        feed('s', v)
      elseif vt == 'number' or vt == 'boolean' then
        feed('k', key_repr(k))
        feed(vt == 'number' and 'n' or 'b', tostring(v))
      elseif vt == 'table' then
        local keys = {}
        for kk in pairs(v) do
          keys[#keys + 1] = kk
        end
        table.sort(keys, function(a, b) return key_repr(a) < key_repr(b) end)
        feed('k', key_repr(k))
        h:update('{')
        for _, kk in ipairs(keys) do
          rec_hash(kk, v[kk])
        end
        h:update('}')
      end
    end

    rec_hash('top', params)
    return h:base32()
  end
  --[[
  Merge Redis connection options from `options` into `redis_params`
  (modified in place).

  This is called repeatedly with progressively more generic option tables
  (a module section first, then the global `redis` section). Values already
  present in `redis_params` therefore win. The exceptions are `timeout`
  (default 1.0) and `expand_keys` (default false), which a later table may
  still change while they hold their default value.

  @param options raw option table to read from
  @param redis_params params table being filled
  ]]
  local function process_redis_opts(options, redis_params)
    local default_timeout = 1.0
    local default_expand_keys = false

    if not redis_params['timeout'] or redis_params['timeout'] == default_timeout then
      if options['timeout'] then
        redis_params['timeout'] = tonumber(options['timeout'])
      else
        redis_params['timeout'] = default_timeout
      end
    end

    if options['prefix'] and not redis_params['prefix'] then
      redis_params['prefix'] = options['prefix']
    end

    -- A more generic table lacking `expand_keys` must not reset a value that
    -- an earlier (more specific) table set explicitly.
    local cur_expand = redis_params['expand_keys']
    if type(options['expand_keys']) == 'boolean' then
      if cur_expand == nil or cur_expand == default_expand_keys then
        redis_params['expand_keys'] = options['expand_keys']
      end
    elseif cur_expand == nil then
      redis_params['expand_keys'] = default_expand_keys
    end

    if not redis_params['db'] then
      if options['db'] then
        redis_params['db'] = tostring(options['db'])
      elseif options['dbname'] then
        redis_params['db'] = tostring(options['dbname'])
      elseif options['database'] then
        redis_params['db'] = tostring(options['database'])
      end
    end

    if options['password'] and not redis_params['password'] then
      redis_params['password'] = options['password']
    end

    if not redis_params.sentinels and options.sentinels then
      redis_params.sentinels = options.sentinels
    end

    if options['sentinel_masters_pattern'] and not redis_params['sentinel_masters_pattern'] then
      redis_params['sentinel_masters_pattern'] = options['sentinel_masters_pattern']
    end

    -- Sentinel tuning knobs declared in the schema; when unset (or not
    -- numeric) the sentinel setup code applies its own defaults.
    if options['sentinel_watch_time'] and not redis_params['sentinel_watch_time'] then
      redis_params['sentinel_watch_time'] = tonumber(options['sentinel_watch_time'])
    end

    if options['sentinel_master_maxerrors'] and not redis_params['sentinel_master_maxerrors'] then
      redis_params['sentinel_master_maxerrors'] = tonumber(options['sentinel_master_maxerrors'])
    end
  end
  -- Fill `redis_params` from the global `redis` configuration section: first
  -- from its `module` subsection (when both exist), then from the section
  -- itself. Does nothing without a config object or a `redis` section.
  local function enrich_defaults(rspamd_config, module, redis_params)
    if not rspamd_config then
      return
    end

    local global_opts = rspamd_config:get_all_opt('redis')
    if not global_opts then
      return
    end

    if module and global_opts[module] then
      process_redis_opts(global_opts[module], redis_params)
    end
    process_redis_opts(global_opts, redis_params)
  end
  --[[
  Deduplicate redis definitions by hash. If an equal definition was already
  registered, return that one. Otherwise register `redis_params`, record its
  hash in `redis_params.hash`, attach sentinel handling when the definition
  has sentinels and is not read-only, and return it.
  ]]
  local function maybe_return_cached(redis_params)
    local hash = calculate_redis_hash(redis_params)
    local known = cached_results[hash]

    if known then
      lutil.debugm(N, 'reused redis server: %s', redis_params)
      return known
    end

    redis_params.hash = hash
    cached_results[hash] = redis_params

    if redis_params.sentinels and not redis_params.read_only then
      add_redis_sentinels(redis_params)
    end

    lutil.debugm(N, 'loaded new redis server: %s', redis_params)
    return redis_params
  end
  290. --[[[
  291. -- @module lua_redis
  292. -- This module contains helper functions for working with Redis
  293. --]]
  --[[
  Build upstream lists from a redis definition table and store them, with
  the connection options, in `result` (modified in place).

  Read servers come from the first present key among `read_servers`,
  `servers` and `server`. For `servers`/`server` the same list is also used
  for writes unless `write_servers` is given; `read_servers` without
  `write_servers` yields a read-only definition.

  @param options table with server definitions and connection options
  @param rspamd_config config object for the upstream constructor, or nil
  @param result table receiving read_servers, write_servers,
         read_servers_str, write_servers_str, read_only and the options
         handled by process_redis_opts
  @return true if read servers were loaded, false otherwise
  ]]
  local function process_redis_options(options, rspamd_config, result)
    local default_port = 6379
    local upstream_list = require "rspamd_upstream_list"

    -- Create an upstream list, passing the config object only when present
    local function make_upstreams(spec)
      if rspamd_config then
        return upstream_list.create(rspamd_config, spec, default_port)
      end
      return upstream_list.create(spec, default_port)
    end

    local read_only = true
    local read_spec
    if options['read_servers'] then
      read_spec = options['read_servers']
    elseif options['servers'] then
      read_spec = options['servers']
      read_only = false
    elseif options['server'] then
      read_spec = options['server']
      read_only = false
    end

    local upstreams_read, upstreams_write
    if read_spec then
      upstreams_read = make_upstreams(read_spec)
      result.read_servers_str = read_spec
    end

    if upstreams_read then
      local write_spec = options['write_servers']
      if write_spec then
        upstreams_write = make_upstreams(write_spec)
        result.write_servers_str = write_spec
        read_only = false
      elseif not read_only then
        upstreams_write = upstreams_read
        result.write_servers_str = result.read_servers_str
      end
    end

    -- Store options
    process_redis_opts(options, result)

    if upstreams_write then
      result.read_only = false
    elseif read_only then
      result.read_only = true
    end

    if not upstreams_read then
      lutil.debugm(N, rspamd_config,
          'cannot load redis server from obj: %s, processed to %s',
          options, result)
      return false
    end

    result.read_servers = upstreams_read
    if upstreams_write then
      result.write_servers = upstreams_write
    end
    return true
  end
  --[[[
  @function try_load_redis_servers(options, rspamd_config, no_fallback, module_name)
  Tries to load redis servers from the specified `options` object.
  Unless `no_fallback` is true, additional settings are merged from the
  global `redis` section (and its `module_name` subsection, when given).
  Returns `redis_params` table or nil in case of failure
  --]]
  exports.try_load_redis_servers = function(options, rspamd_config, no_fallback, module_name)
    local result = {}
    if not process_redis_options(options, rspamd_config, result) then
      return
    end
    if not no_fallback then
      enrich_defaults(rspamd_config, module_name, result)
    end
    return maybe_return_cached(result)
  end
  -- Returns true when `module_name` is listed in the `disabled_modules`
  -- option of the global redis section; accepts a single string or a list.
  local function is_redis_disabled_for(redis_opts, module_name)
    local disabled = redis_opts['disabled_modules']
    if type(disabled) == 'string' then
      return disabled == module_name
    elseif type(disabled) == 'table' then
      for _, v in ipairs(disabled) do
        if v == module_name then
          return true
        end
      end
    end
    return false
  end

  --[[
  Resolve the Redis server definition for a module.

  Lookup order:
    1. `module_opts` (or the module's own config section when nil): its
       `redis` subsection first, then its top-level keys; global defaults
       are merged in unless `no_fallback` is true;
    2. unless `no_fallback`: the global `redis` section's `module_name`
       subsection, otherwise the global section itself, provided the module
       is not listed in `disabled_modules`.
  Returns the (possibly cached) redis params table, or nil.
  ]]
  local function rspamd_parse_redis_server(module_name, module_opts, no_fallback)
    local result = {}
    lutil.debugm(N, rspamd_config, 'try load redis config for: %s', module_name)

    -- Try local options
    local opts = module_opts or rspamd_config:get_all_opt(module_name)
    if opts then
      if opts.redis and process_redis_options(opts.redis, rspamd_config, result) then
        if not no_fallback then
          enrich_defaults(rspamd_config, module_name, result)
        end
        return maybe_return_cached(result)
      end

      if process_redis_options(opts, rspamd_config, result) then
        if not no_fallback then
          enrich_defaults(rspamd_config, module_name, result)
        end
        return maybe_return_cached(result)
      end
    end

    if no_fallback then
      logger.infox(rspamd_config, "cannot find Redis definitions for %s and fallback is disabled",
          module_name)
      return nil
    end

    -- Try global options
    opts = rspamd_config:get_all_opt('redis')
    if opts then
      if opts[module_name] then
        if process_redis_options(opts[module_name], rspamd_config, result) then
          return maybe_return_cached(result)
        end
      else
        -- Exclude disabled modules before building default upstreams
        if is_redis_disabled_for(opts, module_name) then
          logger.infox(rspamd_config, "NOT using default redis server for module %s: it is disabled",
              module_name)
          return nil
        end

        if process_redis_options(opts, rspamd_config, result) then
          logger.infox(rspamd_config, "use default Redis settings for %s",
              module_name)
          return maybe_return_cached(result)
        end
      end
    end

    if result.read_servers then
      return maybe_return_cached(result)
    end

    return nil
  end
  --[[[
  -- @function lua_redis.parse_redis_server(module_name, module_opts, no_fallback)
  -- Extracts Redis server settings from configuration
  -- @param {string} module_name name of module to get settings for
  -- @param {table} module_opts settings for module or `nil` to fetch them from configuration
  -- @param {boolean} no_fallback should be `true` if global settings must not be used
  -- @return {table} redis server settings, or nil when no usable definition is found
  -- @example
  -- local rconfig = lua_redis.parse_redis_server('my_module')
  -- -- rconfig contains upstream_list objects in ['write_servers'] and ['read_servers']
  -- -- ['timeout'] contains timeout in seconds
  -- -- ['expand_keys'] if true tells that redis key expansion is enabled
  --]]
  exports.parse_redis_server = rspamd_parse_redis_server
  -- Same function, also exported under its full name
  exports.rspamd_parse_redis_server = exports.parse_redis_server
  --[[
  Key-position extractors, indexed by lower-case Redis command name.
  Each handler receives the command arguments (without the command name
  itself) and returns the list of argument indexes holding key names.
  Commands without key arguments map to a handler returning an empty list,
  so callers can always iterate the result.
  ]]
  local process_cmd = {
    -- BITOP operation destkey key [key ...]: everything after the operation
    bitop = function(args)
      local idx_l = {}
      for i = 2, #args do
        idx_l[#idx_l + 1] = i
      end
      return idx_l
    end,
    -- BLPOP key [key ...] timeout: everything but the trailing timeout
    blpop = function(args)
      local idx_l = {}
      for i = 1, #args - 1 do
        idx_l[#idx_l + 1] = i
      end
      return idx_l
    end,
    -- EVAL script numkeys key [key ...] arg [arg ...]
    eval = function(args)
      local idx_l = {}
      local numkeys = tonumber(args[2])
      if numkeys and numkeys >= 1 then
        for i = 3, numkeys + 2 do
          idx_l[#idx_l + 1] = i
        end
      end
      return idx_l
    end,
    -- Single key as the first argument
    set = function()
      return {1}
    end,
    -- Every argument is a key
    mget = function(args)
      local idx_l = {}
      for i = 1, #args do
        idx_l[#idx_l + 1] = i
      end
      return idx_l
    end,
    -- MSET key value [key value ...]: every odd argument
    mset = function(args)
      local idx_l = {}
      for i = 1, #args, 2 do
        idx_l[#idx_l + 1] = i
      end
      return idx_l
    end,
    -- SMOVE source destination member
    smove = function()
      return {1, 2}
    end,
    -- ZINTERSTORE/ZUNIONSTORE destination numkeys key [key ...] [options]
    zinterstore = function(args)
      local idx_l = {1}
      local numkeys = tonumber(args[2])
      if numkeys and numkeys >= 1 then
        for i = 3, numkeys + 2 do
          idx_l[#idx_l + 1] = i
        end
      end
      return idx_l
    end,
    -- No key arguments
    script = function()
      return {}
    end,
  }

  do
    -- Commands sharing one of the handlers above, keyed by handler name
    local cmd_aliases = {
      set = {
        'append', 'bitcount', 'bitfield', 'bitpos', 'decr', 'decrby', 'dump',
        'expire', 'expireat', 'geoadd', 'geodist', 'geohash', 'geopos',
        'georadius', 'georadiusbymember', 'get', 'getbit', 'getrange',
        'getset', 'hdel', 'hexists', 'hget', 'hgetall', 'hincrby',
        'hincrbyfloat', 'hkeys', 'hlen', 'hmget', 'hmset', 'hscan', 'hset',
        'hsetnx', 'hstrlen', 'hvals', 'incr', 'incrby', 'incrbyfloat',
        'lindex', 'linsert', 'llen', 'lpop', 'lpush', 'lpushx', 'lrange',
        'lrem', 'lset', 'ltrim', 'move', 'persist', 'pexpire', 'pexpireat',
        'pfadd', 'psetex', 'pttl', 'restore', 'rpop', 'rpush', 'rpushx',
        'sadd', 'scard', 'setbit', 'setex', 'setnx', 'setrange', 'sismember',
        'smembers', 'sort', 'spop', 'srandmember', 'srem', 'sscan', 'strlen',
        'ttl', 'type', 'zadd', 'zcard', 'zcount', 'zincrby', 'zlexcount',
        'zrange', 'zrangebylex', 'zrangebyscore', 'zrank', 'zrem',
        'zremrangebylex', 'zremrangebyrank', 'zremrangebyscore',
        -- legacy misspelled names, kept so existing lookups still resolve
        'zrembylex', 'zrembyrank', 'zrembyscore',
        'zrevrange', 'zrevrangebylex', 'zrevrangebyscore', 'zrevrank',
        'zscan', 'zscore',
      },
      mget = {
        'del', 'exists', 'pfcount', 'pfmerge', 'rename', 'renamenx',
        'rpoplpush', 'sdiff', 'sdiffstore', 'sinter', 'sinterstore', 'sunion',
        'sunionstore', 'touch', 'unlink', 'watch',
      },
      mset = {'msetnx'},
      blpop = {'brpop', 'brpoplpush'},
      eval = {'evalsha'},
      zinterstore = {'zunionstore'},
      script = {
        'auth', 'bgrewriteaof', 'bgsave', 'client', 'cluster', 'command',
        'config', 'dbsize', 'debug', 'discard', 'echo', 'exec', 'flushall',
        'flushdb', 'info', 'keys', 'lastsave', 'migrate', 'monitor', 'multi',
        'object', 'ping', 'psubscribe', 'pubsub', 'publish', 'punsubscribe',
        'quit', 'randomkey', 'readonly', 'readwrite', 'role', 'save', 'scan',
        'select', 'slaveof', 'slowlog', 'subscribe', 'swapdb', 'sync', 'time',
        'unsubscribe', 'unwatch', 'wait',
      },
    }

    for handler, names in pairs(cmd_aliases) do
      for _, name in ipairs(names) do
        process_cmd[name] = process_cmd[handler]
      end
    end
  end
  --[[
  Return the list of indexes in `args` that hold Redis keys for command
  `cmd` (matched case-insensitively against `process_cmd`).
  Unknown commands are logged and yield an empty list. A handler returning
  nothing also yields an empty list, so the result can always be iterated
  with ipairs.
  ]]
  local function get_key_indexes(cmd, args)
    local lcmd = string.lower(cmd)
    local handler = process_cmd[lcmd]

    if not handler then
      logger.warnx(rspamd_config, "Don't know how to extract keys for %s Redis command", lcmd)
      return {}
    end

    return handler(args) or {}
  end
  -- Generators for the variables available when expanding templated Redis
  -- keys. Each entry maps a lower-case variable name to function(task)
  -- returning the value, or nothing when it is unavailable for the task.
  local gen_meta
  do
    -- `field` of the first address returned by task:get_from(source), or nil
    local function first_from_field(task, source, field)
      local from_list = task:get_from(source) or E
      local first = from_list[1] or E
      return first[field]
    end

    gen_meta = {
      principal_recipient = function(task)
        return task:get_principal_recipient()
      end,
      principal_recipient_domain = function(task)
        local rcpt = task:get_principal_recipient()
        if rcpt then
          -- everything after the last '@'
          return string.match(rcpt, '.*@(.*)')
        end
      end,
      ip = function(task)
        local addr = task:get_ip()
        if not (addr and addr:is_valid()) then
          return
        end
        return addr:to_string()
      end,
      from = function(task)
        return first_from_field(task, 'smtp', 'addr')
      end,
      from_domain = function(task)
        return first_from_field(task, 'smtp', 'domain')
      end,
      from_domain_or_helo_domain = function(task)
        local domain = first_from_field(task, 'smtp', 'domain')
        if domain and #domain > 0 then
          return domain
        end
        -- empty or missing envelope domain: fall back to HELO
        return task:get_helo()
      end,
      mime_from = function(task)
        return first_from_field(task, 'mime', 'addr')
      end,
      mime_from_domain = function(task)
        return first_from_field(task, 'mime', 'domain')
      end,
    }
  end
  -- Wraps a domain getter `f(task)`: the returned function passes f's result
  -- through rspamd_util.get_tld() (presumably yielding the eSLD, as the
  -- esld_* variable names suggest), and returns nothing when f yields nil/false.
  local function gen_get_esld(f)
    local function esld_getter(task)
      local domain = f(task)
      if domain then
        return rspamd_util.get_tld(domain)
      end
    end
    return esld_getter
  end
  -- `smtp_`-prefixed spellings are synonyms for the envelope-sender getters
  for _, alias in ipairs({
    {'smtp_from', 'from'},
    {'smtp_from_domain', 'from_domain'},
    {'smtp_from_domain_or_helo_domain', 'from_domain_or_helo_domain'},
  }) do
    gen_meta[alias[1]] = gen_meta[alias[2]]
  end

  -- `esld_` variants wrap the underlying domain getter with gen_get_esld()
  for _, base in ipairs({
    'principal_recipient_domain',
    'from_domain',
    'mime_from_domain',
    'from_domain_or_helo_domain',
  }) do
    gen_meta['esld_' .. base] = gen_get_esld(gen_meta[base])
  end

  -- SMTP spellings of the eSLD getters share the closures built above
  gen_meta.esld_smtp_from_domain = gen_meta.esld_from_domain
  gen_meta.esld_smtp_from_domain_or_helo_domain = gen_meta.esld_from_domain_or_helo_domain
  -- Returns a lazily populated variable table for key templating. Reading a
  -- name (case-insensitively) runs the matching gen_meta generator for `task`
  -- and stores the result under the lower-cased name. A nil result is not
  -- stored, so the generator is called again on the next read.
  local function get_key_expansion_metadata(task)
    local function lookup(tbl, name)
      local lname = string.lower(name)
      local cached = rawget(tbl, lname)
      if cached then
        return cached
      end

      local gen = gen_meta[lname]
      if not gen then
        return cached
      end

      local value = gen(task)
      rawset(tbl, lname, value)
      return value
    end

    return setmetatable({}, { __index = lookup })
  end
  -- Performs async call to redis hiding all complexity inside function
  -- task - rspamd_task
  -- redis_params - valid params returned by rspamd_parse_redis_server
  -- key - key to select upstream or nil to select round-robin/master-slave
  -- is_write - true if need to write to redis server
  -- callback - function to be called upon request is completed; it receives
  --            (err, data, upstream)
  -- command - redis command
  -- args - table of arguments; when redis_params.expand_keys is set, the key
  --        arguments are template-expanded in place
  -- extra_opts - table of optional request arguments, copied over the
  --              generated request options
  -- Returns ret, conn, upstream. Returns false,nil,nil when a required argument
  -- is missing or no upstream could be selected.
  local function rspamd_redis_make_request(task, redis_params, key, is_write,
      callback, command, args, extra_opts)
    local addr

    local function rspamd_redis_make_request_cb(err, data)
      if err then
        addr:fail()
      else
        addr:ok()
      end
      if callback then
        callback(err, data, addr)
      end
    end

    if not task or not redis_params or not callback or not command then
      return false,nil,nil
    end

    local rspamd_redis = require "rspamd_redis"

    if key then
      if is_write then
        addr = redis_params['write_servers']:get_upstream_by_hash(key)
      else
        addr = redis_params['read_servers']:get_upstream_by_hash(key)
      end
    else
      if is_write then
        addr = redis_params['write_servers']:get_upstream_master_slave(key)
      else
        addr = redis_params['read_servers']:get_upstream_round_robin(key)
      end
    end

    if not addr then
      logger.errx(task, 'cannot select server to make redis request')
      -- Nothing to send the request to; continuing would index a nil upstream
      return false,nil,nil
    end

    -- Without args there are no key arguments to expand
    if redis_params['expand_keys'] and args then
      local m = get_key_expansion_metadata(task)
      local indexes = get_key_indexes(command, args)
      for _, i in ipairs(indexes) do
        args[i] = lutil.template(args[i], m)
      end
    end

    local ip_addr = addr:get_addr()
    local options = {
      task = task,
      callback = rspamd_redis_make_request_cb,
      host = ip_addr,
      timeout = redis_params['timeout'],
      cmd = command,
      args = args
    }

    if extra_opts then
      for k,v in pairs(extra_opts) do
        options[k] = v
      end
    end

    if redis_params['password'] then
      options['password'] = redis_params['password']
    end

    if redis_params['db'] then
      options['dbname'] = redis_params['db']
    end

    lutil.debugm(N, task, 'perform request to redis server' ..
        ' (host=%s, timeout=%s): cmd: %s', ip_addr,
        options.timeout, options.cmd)

    local ret,conn = rspamd_redis.make_request(options)

    if not ret then
      addr:fail()
      logger.warnx(task, "cannot make redis request to: %s", tostring(ip_addr))
    end

    return ret,conn,addr
  end
  --[[[
  -- @function lua_redis.redis_make_request(task, redis_params, key, is_write, callback, command, args[, extra_opts])
  -- Sends a request to Redis
  -- @param {rspamd_task} task task object
  -- @param {table} redis_params redis configuration in format returned by lua_redis.parse_redis_server()
  -- @param {string} key key to use for sharding
  -- @param {boolean} is_write should be `true` if we are performing a write operation
  -- @param {function} callback callback function (first parameter is error if applicable, second is a 2D array (table), third is the upstream used)
  -- @param {string} command Redis command to run
  -- @param {table} args Numerically indexed table containing arguments for command
  -- @param {table} extra_opts optional table of extra options copied over the generated request options
  -- @return {boolean,connection,upstream} request status, connection object and the selected upstream
  --]]
  exports.rspamd_redis_make_request = rspamd_redis_make_request
  exports.redis_make_request = rspamd_redis_make_request
  -- Task-less variant of rspamd_redis_make_request(): the request is bound
  -- to `ev_base` and `cfg` instead of a task, and keys are not
  -- template-expanded. Returns ret, conn, upstream. Returns false,nil,nil when
  -- a required argument is missing or no upstream could be selected.
  local function redis_make_request_taskless(ev_base, cfg, redis_params, key,
      is_write, callback, command, args, extra_opts)
    if not ev_base or not redis_params or not callback or not command then
      return false,nil,nil
    end

    local addr
    local function rspamd_redis_make_request_cb(err, data)
      if err then
        addr:fail()
      else
        addr:ok()
      end
      if callback then
        callback(err, data, addr)
      end
    end

    local rspamd_redis = require "rspamd_redis"

    if key then
      if is_write then
        addr = redis_params['write_servers']:get_upstream_by_hash(key)
      else
        addr = redis_params['read_servers']:get_upstream_by_hash(key)
      end
    else
      if is_write then
        addr = redis_params['write_servers']:get_upstream_master_slave(key)
      else
        addr = redis_params['read_servers']:get_upstream_round_robin(key)
      end
    end

    if not addr then
      logger.errx(cfg, 'cannot select server to make redis request')
      -- Nothing to send the request to; continuing would index a nil upstream
      return false,nil,nil
    end

    local options = {
      ev_base = ev_base,
      config = cfg,
      callback = rspamd_redis_make_request_cb,
      host = addr:get_addr(),
      timeout = redis_params['timeout'],
      cmd = command,
      args = args
    }

    if extra_opts then
      for k,v in pairs(extra_opts) do
        options[k] = v
      end
    end

    if redis_params['password'] then
      options['password'] = redis_params['password']
    end

    if redis_params['db'] then
      options['dbname'] = redis_params['db']
    end

    lutil.debugm(N, cfg, 'perform taskless request to redis server' ..
        ' (host=%s, timeout=%s): cmd: %s', options.host,
        options.timeout, options.cmd)

    local ret,conn = rspamd_redis.make_request(options)

    if not ret then
      logger.errx(cfg, 'cannot execute redis request')
      addr:fail()
    end

    return ret,conn,addr
  end
  --[[[
  -- @function lua_redis.redis_make_request_taskless(ev_base, cfg, redis_params, key, is_write, callback, command, args[, extra_opts])
  -- Sends a request to Redis in context where `task` is not available for some specific use-cases
  -- Identical to redis_make_request() except that the first two parameters are an
  -- `event base` object and a config object instead of a task, and that keys are
  -- never template-expanded
  --]]
  exports.rspamd_redis_make_request_taskless = redis_make_request_taskless
  exports.redis_make_request_taskless = redis_make_request_taskless
  -- Scripts registered via add_redis_script(); a script's id is its index here
  local redis_scripts = {}

  -- Marks `script` as loaded if an upload produced a sha, then drains its wait
  -- queue, calling every pending callback with the current `loaded` flag.
  -- The queue is replaced with a fresh table before any callback runs, so a
  -- callback that queues itself again waits for the next load.
  local function script_set_loaded(script)
    if script.sha then
      script.loaded = true
    end

    local pending = script.waitq
    script.waitq = {}

    for _, waiter in ipairs(pending) do
      waiter(script.loaded)
    end
  end
  -- Builds one `SCRIPT LOAD` request option table per upstream of `script`:
  -- read servers first, then write servers. Sets script.in_flight to the
  -- number of requests the caller is about to issue. Each option table carries
  -- its `upstream` so that callbacks can mark it ok or failed.
  local function prepare_redis_call(script)
    local params = script.redis_params
    local servers = {}

    -- Append instead of copying by index: with t[k] = v, write servers
    -- overwrote read servers that shared the same array index, so the script
    -- was never uploaded to those read servers. An upstream present in both
    -- lists is loaded twice, which is harmless: SCRIPT LOAD of an already
    -- cached script just returns its sha again.
    local function append_all(dst, src)
      for _, v in ipairs(src) do
        dst[#dst + 1] = v
      end
    end

    if params.read_servers then
      append_all(servers, params.read_servers:all_upstreams())
    end
    if params.write_servers then
      append_all(servers, params.write_servers:all_upstreams())
    end

    -- Call load script on each server, set loaded flag
    script.in_flight = #servers

    local options = {}
    for _, s in ipairs(servers) do
      local cur_opts = {
        host = s:get_addr(),
        timeout = params['timeout'],
        cmd = 'SCRIPT',
        args = {'LOAD', script.script},
        upstream = s,
      }

      if params['password'] then
        cur_opts['password'] = params['password']
      end

      if params['db'] then
        cur_opts['dbname'] = params['db']
      end

      options[#options + 1] = cur_opts
    end

    return options
  end
  -- Uploads `script` (SCRIPT LOAD) to every upstream returned by
  -- prepare_redis_call(), within `task`. A successful upload stores the sha
  -- and clears any earlier fatal error; a failed one records the error in
  -- script.fatal_error. Once no request is outstanding, script_set_loaded()
  -- drains the wait queue.
  local function load_script_task(script, task)
    local rspamd_redis = require "rspamd_redis"
    local opts = prepare_redis_call(script)

    for _,opt in ipairs(opts) do
      opt.task = task
      opt.callback = function(err, data)
        if err then
          logger.errx(task, 'cannot upload script to %s: %s; registered from: %s:%s',
              opt.upstream:get_addr():to_string(true),
              err, script.caller.short_src, script.caller.currentline)
          opt.upstream:fail()
          script.fatal_error = err
        else
          opt.upstream:ok()
          logger.infox(task,
              "uploaded redis script to %s with id %s, sha: %s",
              opt.upstream:get_addr():to_string(true),
              script.id, data)
          script.sha = data -- We assume that sha is the same on all servers
          -- Clear an error left by an earlier failed upload, as the taskless
          -- loader does; otherwise exec_redis_script keeps failing forever
          script.fatal_error = nil
        end
        script.in_flight = script.in_flight - 1

        if script.in_flight == 0 then
          script_set_loaded(script)
        end
      end

      local ret = rspamd_redis.make_request(opt)

      if not ret then
        logger.errx('cannot execute redis request to load script on %s',
            opt.upstream:get_addr())
        script.in_flight = script.in_flight - 1
        opt.upstream:fail()
      end

      if script.in_flight == 0 then
        script_set_loaded(script)
      end
    end

    if #opts == 0 then
      -- No upstreams: no callback will ever fire, so release waiters now
      script_set_loaded(script)
    end
  end
  -- Task-less counterpart of load_script_task(): uploads `script` (SCRIPT
  -- LOAD) to every upstream returned by prepare_redis_call(), using `cfg` and
  -- `ev_base`. The sha and fatal error are recorded as in load_script_task().
  -- Once no request is outstanding, script_set_loaded() drains the wait queue.
  local function load_script_taskless(script, cfg, ev_base)
    local rspamd_redis = require "rspamd_redis"
    local opts = prepare_redis_call(script)

    for _,opt in ipairs(opts) do
      opt.config = cfg
      opt.ev_base = ev_base
      opt.callback = function(err, data)
        if err then
          logger.errx(cfg, 'cannot upload script to %s: %s; registered from: %s:%s',
              opt.upstream:get_addr():to_string(true),
              err, script.caller.short_src, script.caller.currentline)
          opt.upstream:fail()
          script.fatal_error = err
        else
          opt.upstream:ok()
          logger.infox(cfg,
              "uploaded redis script to %s with id %s, sha: %s",
              opt.upstream:get_addr():to_string(true), script.id, data)
          script.sha = data -- We assume that sha is the same on all servers
          script.fatal_error = nil
        end
        script.in_flight = script.in_flight - 1

        if script.in_flight == 0 then
          script_set_loaded(script)
        end
      end

      local ret = rspamd_redis.make_request(opt)

      if not ret then
        logger.errx('cannot execute redis request to load script on %s',
            opt.upstream:get_addr())
        script.in_flight = script.in_flight - 1
        opt.upstream:fail()
      end

      if script.in_flight == 0 then
        script_set_loaded(script)
      end
    end

    if #opts == 0 then
      -- No upstreams: no callback will ever fire, so release waiters now
      script_set_loaded(script)
    end
  end
  -- On-load upload hook: uploads `script` outside of any task. Scripts
  -- registered without redis_params are skipped. The trailing (worker)
  -- argument is accepted but unused.
  local function load_redis_script(script, cfg, ev_base, _)
    if not script.redis_params then
      return
    end
    load_script_taskless(script, cfg, ev_base)
  end
  -- Registers Lua `script` for execution via exec_redis_script(). When the
  -- config is loaded, a periodic retries the upload until a sha is known,
  -- waiting 1s, 2s, 3s, ... between attempts. Returns the script id.
  local function add_redis_script(script, redis_params)
    local caller = debug.getinfo(2) -- where the script was registered from (for logs)

    local new_script = {
      caller = caller,
      loaded = false,
      redis_params = redis_params,
      script = script,
      waitq = {}, -- callbacks pending for script being loaded
      id = #redis_scripts + 1,
    }

    local function schedule_upload(cfg, ev_base, worker)
      local attempts = 0.0
      rspamd_config:add_periodic(ev_base, 0.0, function()
        if new_script.sha then
          return false -- uploaded: stop the periodic
        end
        load_redis_script(new_script, cfg, ev_base, worker)
        attempts = attempts + 1
        return 1.0 * attempts -- next check after a linearly growing delay
      end, false)
    end

    -- Register on load function
    rspamd_config:add_on_load(schedule_upload)

    table.insert(redis_scripts, new_script)
    return #redis_scripts
  end
  exports.add_redis_script = add_redis_script
  -- Runs registered script `id` with EVALSHA.
  -- params: either `task` or `ev_base` (taskless), plus optional `key` and
  --         `is_write` used for upstream selection
  -- callback(err, data) receives the outcome
  -- keys / args: arrays appended after the sha and the key count (keys may be nil)
  -- Returns false if `id` is unknown, true otherwise; errors are delivered
  -- through the callback. On a NOSCRIPT reply the script is re-uploaded once
  -- and the call retried.
  local function exec_redis_script(id, params, callback, keys, args)
    if not redis_scripts[id] then
      logger.errx("cannot find registered script with id %s", id)
      return false
    end

    local script = redis_scripts[id]

    if script.fatal_error then
      callback(script.fatal_error, nil)
      return true
    end

    if not script.redis_params then
      callback('no redis servers defined', nil)
      return true
    end

    -- A missing keys table means "no keys" instead of failing on #keys
    keys = keys or {}

    -- EVALSHA arguments: sha, numkeys, keys..., extra args...
    local function build_evalsha_args(sha)
      local redis_args = {sha, tostring(#keys)}
      for _,k in ipairs(keys) do
        table.insert(redis_args, k)
      end
      if type(args) == 'table' then
        for _, a in ipairs(args) do
          table.insert(redis_args, a)
        end
      end
      return redis_args
    end

    local function do_call(can_reload)
      local function redis_cb(err, data)
        if not err then
          callback(err, data)
        elseif string.match(err, 'NOSCRIPT') then
          -- Schedule restart
          script.sha = nil
          if can_reload then
            table.insert(script.waitq, do_call)
            if script.in_flight == 0 then
              -- Reload scripts if this has not been initiated yet
              if params.task then
                load_script_task(script, params.task)
              else
                load_script_taskless(script, rspamd_config, params.ev_base)
              end
            end
          else
            callback(err, data)
          end
        else
          callback(err, data)
        end
      end

      if not script.sha then
        -- No sha is known (the upload failed, or it was reset after NOSCRIPT).
        -- table.insert(t, nil) would silently drop it and shift every EVALSHA
        -- argument, so report the script as missing instead.
        callback('NOSCRIPT', nil)
        return
      end

      -- Rebuilt on each attempt so a retry after a reload uses the current sha
      local redis_args = build_evalsha_args(script.sha)

      if params.task then
        if not rspamd_redis_make_request(params.task, script.redis_params,
            params.key, params.is_write, redis_cb, 'EVALSHA', redis_args) then
          callback('Cannot make redis request', nil)
        end
      else
        if not redis_make_request_taskless(params.ev_base, rspamd_config,
            script.redis_params,
            params.key, params.is_write, redis_cb, 'EVALSHA', redis_args) then
          callback('Cannot make redis request', nil)
        end
      end
    end

    if script.loaded then
      do_call(true)
    else
      -- Delayed until scripts are loaded
      if not params.task then
        table.insert(script.waitq, do_call)
      else
        -- TODO: fix taskfull requests
        table.insert(script.waitq, function()
          if script.loaded then
            do_call(false)
          else
            callback('NOSCRIPT', nil)
          end
        end)
        load_script_task(script, params.task)
      end
    end

    return true
  end
  exports.exec_redis_script = exec_redis_script

  -- Opens a coroutine-based (sync) connection to an upstream chosen from
  -- redis_params by `key` / `is_write`, then queues AUTH and SELECT on it
  -- when a password or db is configured. Every field of redis_params is
  -- copied into the connection options and overrides the defaults set here.
  -- Returns ret, conn, upstream; false with nil conn on failure.
  local function redis_connect_sync(redis_params, is_write, key, cfg, ev_base)
    if not redis_params then
      return false,nil
    end

    local rspamd_redis = require "rspamd_redis"
    local addr

    if key then
      if is_write then
        addr = redis_params['write_servers']:get_upstream_by_hash(key)
      else
        addr = redis_params['read_servers']:get_upstream_by_hash(key)
      end
    else
      if is_write then
        addr = redis_params['write_servers']:get_upstream_master_slave(key)
      else
        addr = redis_params['read_servers']:get_upstream_round_robin(key)
      end
    end

    if not addr then
      logger.errx(cfg, 'cannot select server to make redis request')
      -- Nothing to connect to; continuing would index a nil upstream
      return false,nil,nil
    end

    local options = {
      host = addr:get_addr(),
      timeout = redis_params['timeout'],
      config = cfg or rspamd_config,
      ev_base = ev_base or rspamadm_ev_base,
      session = redis_params.session or rspamadm_session
    }

    for k,v in pairs(redis_params) do
      options[k] = v
    end

    if not options.config then
      logger.errx('config is not set')
      return false,nil,addr
    end

    if not options.ev_base then
      logger.errx('ev_base is not set')
      return false,nil,addr
    end

    if not options.session then
      logger.errx('session is not set')
      return false,nil,addr
    end

    local ret,conn = rspamd_redis.connect_sync(options)
    if not ret then
      logger.errx('cannot execute redis request: %s', conn)
      addr:fail()

      return false,nil,addr
    end

    if conn then
      if redis_params['password'] then
        conn:add_cmd('AUTH', {redis_params['password']})
      end

      if redis_params['db'] then
        conn:add_cmd('SELECT', {tostring(redis_params['db'])})
      elseif redis_params['dbname'] then
        conn:add_cmd('SELECT', {tostring(redis_params['dbname'])})
      end
    end

    return ret,conn,addr
  end
  exports.redis_connect_sync = redis_connect_sync
  --[[[
  -- @function lua_redis.request(redis_params, attrs, req)
  -- Sends a request to Redis synchronously with coroutines or asynchronously using
  -- a callback (modern API)
  -- @param redis_params a table of redis server parameters
  -- @param attrs a table of redis request attributes (e.g. task, or ev_base + cfg + session)
  -- @param req a command string, or a table whose first element is the command and the
  -- remaining elements are its arguments (the table is not modified)
  -- @return {result,data/connection,address} boolean result, connection object in case of async request and results if using coroutines, redis server address
  --]]
  exports.request = function(redis_params, attrs, req)
    local lua_util = require "lua_util"

    if not attrs or not redis_params or not req then
      logger.errx('invalid arguments for redis request')
      return false,nil,nil
    end

    if not (attrs.task or (attrs.config and attrs.ev_base)) then
      logger.errx('invalid attributes for redis request')
      return false,nil,nil
    end

    local opts = lua_util.shallowcopy(attrs)
    local log_obj = opts.task or opts.config
    local addr

    if opts.callback then
      -- Wrap callback so that the upstream learns the request outcome
      local callback = opts.callback
      local function rspamd_redis_make_request_cb(err, data)
        if err then
          addr:fail()
        else
          addr:ok()
        end
        callback(err, data, addr)
      end
      opts.callback = rspamd_redis_make_request_cb
    end

    local rspamd_redis = require "rspamd_redis"
    local is_write = opts.is_write

    if opts.key then
      if is_write then
        addr = redis_params['write_servers']:get_upstream_by_hash(opts.key)
      else
        addr = redis_params['read_servers']:get_upstream_by_hash(opts.key)
      end
    else
      if is_write then
        addr = redis_params['write_servers']:get_upstream_master_slave(opts.key)
      else
        addr = redis_params['read_servers']:get_upstream_round_robin(opts.key)
      end
    end

    if not addr then
      logger.errx(log_obj, 'cannot select server to make redis request')
      -- Nothing to send the request to; continuing would index a nil upstream
      return false,nil,nil
    end

    opts.host = addr:get_addr()
    opts.timeout = redis_params.timeout

    if type(req) == 'string' then
      opts.cmd = req
    else
      -- Split into command + arguments without modifying the caller's table
      opts.cmd = req[1]
      local cmd_args = {}
      for i = 2, #req do
        cmd_args[i - 1] = req[i]
      end
      opts.args = cmd_args
    end

    if redis_params.password then
      opts.password = redis_params.password
    end

    if redis_params.db then
      opts.dbname = redis_params.db
    end

    -- debugm expects the log object before the format string
    lutil.debugm(N, log_obj, 'perform generic request to redis server' ..
        ' (host=%s, timeout=%s): cmd: %s, arguments: %s', addr,
        opts.timeout, opts.cmd, opts.args)

    if opts.callback then
      local ret,conn = rspamd_redis.make_request(opts)
      if not ret then
        logger.errx(log_obj, 'cannot execute redis request')
        addr:fail()
      end

      return ret,conn,addr
    else
      -- Coroutines version
      local ret,conn = rspamd_redis.connect_sync(opts)
      if not ret then
        logger.errx(log_obj, 'cannot execute redis request')
        addr:fail()
      else
        conn:add_cmd(opts.cmd, opts.args)
        return conn:exec()
      end
      return false,nil,addr
    end
  end
  --[[[
  -- @function lua_redis.connect(redis_params, attrs)
  -- Connects to Redis synchronously with coroutines or asynchronously using a callback (modern API)
  -- @param redis_params a table of redis server parameters
  -- @param attrs a table of redis request attributes (e.g. task, or ev_base + cfg + session)
  -- @return {result,connection,address} boolean result, connection object, redis server address
  --]]
  exports.connect = function(redis_params, attrs)
    local lua_util = require "lua_util"

    if not attrs or not redis_params then
      logger.errx('invalid arguments for redis connect')
      return false,nil,nil
    end

    if not (attrs.task or (attrs.config and attrs.ev_base)) then
      logger.errx('invalid attributes for redis connect')
      return false,nil,nil
    end

    local opts = lua_util.shallowcopy(attrs)
    local log_obj = opts.task or opts.config
    local addr

    if opts.callback then
      -- Wrap callback so that the upstream learns the connection outcome
      local callback = opts.callback
      local function rspamd_redis_make_request_cb(err, data)
        if err then
          addr:fail()
        else
          addr:ok()
        end
        callback(err, data, addr)
      end
      opts.callback = rspamd_redis_make_request_cb
    end

    local rspamd_redis = require "rspamd_redis"
    local is_write = opts.is_write

    if opts.key then
      if is_write then
        addr = redis_params['write_servers']:get_upstream_by_hash(opts.key)
      else
        addr = redis_params['read_servers']:get_upstream_by_hash(opts.key)
      end
    else
      if is_write then
        addr = redis_params['write_servers']:get_upstream_master_slave(opts.key)
      else
        addr = redis_params['read_servers']:get_upstream_round_robin(opts.key)
      end
    end

    if not addr then
      logger.errx(log_obj, 'cannot select server to make redis connect')
      -- Nothing to connect to; continuing would index a nil upstream
      return false,nil,nil
    end

    opts.host = addr:get_addr()
    opts.timeout = redis_params.timeout

    if redis_params.password then
      opts.password = redis_params.password
    end

    if redis_params.db then
      opts.dbname = redis_params.db
    end

    if opts.callback then
      local ret,conn = rspamd_redis.connect(opts)
      if not ret then
        logger.errx(log_obj, 'cannot execute redis connect')
        addr:fail()
      end

      return ret,conn,addr
    else
      -- Coroutines version
      local ret,conn = rspamd_redis.connect_sync(opts)
      if not ret then
        logger.errx(log_obj, 'cannot execute redis connect')
        addr:fail()
      else
        return true,conn,addr
      end

      return false,nil,addr
    end
  end
  -- prefix -> { module = ..., description = ..., <optional fields> }
  local redis_prefixes = {}

  --[[[
  -- @function lua_redis.register_prefix(prefix, module, description[, optional])
  -- Register new redis prefix for documentation purposes
  -- @param {string} prefix string prefix
  -- @param {string} module module name
  -- @param {string} description prefix description
  -- @param {table} optional optional kv pairs (e.g. pattern)
  --]]
  local function register_prefix(prefix, module, description, optional)
    local entry = { module = module, description = description }

    -- Optional fields are copied last, so they may override module/description
    if type(optional) == 'table' then
      for field, value in pairs(optional) do
        entry[field] = value
      end
    end

    redis_prefixes[prefix] = entry
  end
  exports.register_prefix = register_prefix

  --[[[
  -- @function lua_redis.prefixes([mname])
  -- Returns prefixes for specific module (or all prefixes). Returns a table prefix -> table
  --]]
  exports.prefixes = function(mname)
    if not mname then
      return redis_prefixes
    end

    -- Keep the documented prefix -> data mapping. luafun's totable() over a
    -- key/value iterator keeps only the first value of each step (the
    -- prefix), so it produced an array of prefixes instead.
    local res = {}
    for prefix, data in pairs(redis_prefixes) do
      if data.module == mname then
        res[prefix] = data
      end
    end

    return res
  end

  return exports