]> source.dussan.org Git - rspamd.git/commitdiff
[Minor] Fix taskfull scripts reload
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 16 Dec 2017 18:03:38 +0000 (18:03 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 16 Dec 2017 18:03:38 +0000 (18:03 +0000)
lualib/lua_redis.lua

index d88d174894e2da72130076ccedce815c5ee37e79..eace60fa82fb5499e19d4245b2d08c6d46b2d2d5 100644 (file)
@@ -609,28 +609,30 @@ exports.redis_make_request_taskless = redis_make_request_taskless
 local redis_scripts = {
 }
 
-local function load_redis_script(script, cfg, ev_base, _)
-  local function merge_tables(t1, t2)
-    for k,v in pairs(t2) do t1[k] = v end
+local function script_set_loaded(script)
+  if script.sha then
+    script.loaded = true
   end
 
-  local function set_loaded()
-    if script.sha then
-      script.loaded = true
-    end
+  local wait_table = {}
+  for _,s in ipairs(script.waitq) do
+    table.insert(wait_table, s)
+  end
 
-    local wait_table = {}
-    for _,s in ipairs(script.waitq) do
-      table.insert(wait_table, s)
-    end
+  script.waitq = {}
 
-    script.waitq = {}
+  for _,s in ipairs(wait_table) do
+    s(script.loaded)
+  end
+end
 
-    for _,s in ipairs(wait_table) do
-      s(script)
-    end
+local function prepare_redis_call(script)
+  local function merge_tables(t1, t2)
+    for k,v in pairs(t2) do t1[k] = v end
   end
+
   local servers = {}
+  local options = {}
 
   if script.redis_params.read_servers then
     merge_tables(servers, script.redis_params.read_servers:all_upstreams())
@@ -642,52 +644,104 @@ local function load_redis_script(script, cfg, ev_base, _)
   -- Call load script on each server, set loaded flag
   script.in_flight = #servers
   for _,s in ipairs(servers) do
-    local function script_cb(err, data)
+    local cur_opts = {
+      host = s:get_addr(),
+      timeout = script.redis_params['timeout'],
+      cmd = 'SCRIPT',
+      args = {'LOAD', script.script },
+      upstream = s
+    }
+
+    if script.redis_params['password'] then
+      cur_opts['password'] = script.redis_params['password']
+    end
+
+    if script.redis_params['db'] then
+      cur_opts['dbname'] = script.redis_params['db']
+    end
+
+    table.insert(options, cur_opts)
+  end
+
+  return options
+end
+
+local function load_script_task(script, task)
+  local rspamd_redis = require "rspamd_redis"
+  local opts = prepare_redis_call(script)
+
+  for _,opt in ipairs(opts) do
+    opt.task = task
+    opt.callback = function(err, data)
       if err then
-        s:fail()
+        opt.upstream:fail()
       else
-        s:ok()
+        opt.upstream:ok()
+        logger.infox(task,
+          "loaded redis script with id %s, sha: %s", script.id, data)
         script.sha = data -- We assume that sha is the same on all servers
       end
       script.in_flight = script.in_flight - 1
 
       if script.in_flight == 0 then
-        set_loaded(script)
+        script_set_loaded(script)
       end
     end
 
-    local rspamd_redis = require "rspamd_redis"
+    local ret = rspamd_redis.make_request(opt)
 
-    local options = {
-      ev_base = ev_base,
-      config = cfg,
-      callback = script_cb,
-      host = s:get_addr(),
-      timeout = script.redis_params['timeout'],
-      cmd = 'SCRIPT',
-      args = {'LOAD', script.script}
-    }
+    if not ret then
+      logger.errx('cannot execute redis request to load script')
+      script.in_flight = script.in_flight - 1
+      opt.upstream:fail()
+    end
 
-    if script.redis_params['password'] then
-      options['password'] = script.redis_params['password']
+    if script.in_flight == 0 then
+      script_set_loaded(script)
     end
+  end
+end
 
-    if script.redis_params['db'] then
-      options['dbname'] = script.redis_params['db']
+local function load_script_taskless(script, cfg, ev_base)
+  local rspamd_redis = require "rspamd_redis"
+  local opts = prepare_redis_call(script)
+
+  for _,opt in ipairs(opts) do
+    opt.config = cfg
+    opt.ev_base = ev_base
+    opt.callback = function(err, data)
+      if err then
+        opt.upstream:fail()
+      else
+        opt.upstream:ok()
+        logger.infox(cfg,
+          "loaded redis script with id %s, sha: %s", script.id, data)
+        script.sha = data -- We assume that sha is the same on all servers
+      end
+      script.in_flight = script.in_flight - 1
+
+      if script.in_flight == 0 then
+        script_set_loaded(script)
+      end
     end
+    local ret = rspamd_redis.make_request(opt)
 
-    local ret = rspamd_redis.make_request(options)
     if not ret then
       logger.errx('cannot execute redis request to load script')
       script.in_flight = script.in_flight - 1
+      opt.upstream:fail()
     end
-  end
 
-  if script.in_flight == 0 then
-    set_loaded(script)
+    if script.in_flight == 0 then
+      script_set_loaded(script)
+    end
   end
 end
 
+local function load_redis_script(script, cfg, ev_base, _)
+  load_script_taskless(script, cfg, ev_base)
+end
+
 local function add_redis_script(script, redis_params)
   local new_script = {
     loaded = false,
@@ -709,34 +763,47 @@ end
 exports.add_redis_script = add_redis_script
 
 local function exec_redis_script(id, params, callback, args)
+  local logger = require "rspamd_logger"
+  local args_modified = false
+
   if not redis_scripts[id] then
+      logger.errx("cannot find registered script with id %s", id)
     return false
   end
 
   local script = redis_scripts[id]
 
-  local function do_call()
+  local function do_call(can_reload)
     local function redis_cb(err, data)
       if not err then
         callback(err, data)
-      elseif err == 'NOSCRIPT' then
+      elseif string.match(err, 'NOSCRIPT') then
         -- Schedule restart
-        table.insert(script.waitq, do_call)
-        if script.in_flight ~= 0 then
-          -- Reload scripts if this has not been initiated yet
-          if params.task then
-            load_redis_script(script, rspamd_config,
-              params.task:get_ev_base(), nil)
-          else
-            load_redis_script(script, rspamd_config,
-              params.ev_base, nil)
+        script.sha = nil
+        if can_reload then
+          table.insert(script.waitq, do_call)
+          if script.in_flight == 0 then
+            -- Reload scripts if this has not been initiated yet
+            if params.task then
+              load_script_task(script, params.task)
+            else
+              load_script_taskless(script, rspamd_config, params.ev_base)
+            end
           end
+        else
+          callback(err, data)
         end
       else
         callback(err, data)
       end
     end
 
+    if not args_modified then
+      table.insert(args, 1, tostring(#args))
+      table.insert(args, 1, script.sha)
+      args_modified = true
+    end
+
     if params.task then
       if not rspamd_redis_make_request(params.task, script.redis_params,
         params.key, params.is_write, redis_cb, 'EVALSHA', args) then
@@ -751,11 +818,16 @@ local function exec_redis_script(id, params, callback, args)
     end
   end
 
-  if not script.loaded then
-    do_call()
+  if script.loaded then
+    do_call(true)
   else
     -- Delayed until scripts are loaded
-    table.insert(script.waitq, do_call)
+    if not params.task then
+      table.insert(script.waitq, do_call)
+    else
+      -- TODO: fix taskfull requests
+      callback('NOSCRIPT', nil)
+    end
   end
 
   return true