]> source.dussan.org Git - rspamd.git/commitdiff
[Project] Implement keys rotation in the vault
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 30 Apr 2019 15:20:32 +0000 (16:20 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 30 Apr 2019 15:20:53 +0000 (16:20 +0100)
lualib/rspamadm/vault.lua

index 9c01624fc498098c9b716370a1e4878d3401502b..b11ba525ee412e810411826b48ea892a2562c035 100644 (file)
@@ -91,6 +91,20 @@ newkey:option "-x --expire"
       :convert(tonumber)
 newkey:flag "-r --rewrite"
 
+local roll = parser:command "roll rollover"
+                   :description "Perform keys rollover"
+roll:argument "domain"
+    :description "Domain to roll key(s) for"
+    :args "+"
+roll:option "-T --ttl"
+    :description "Validity period for old keys (days)"
+    :convert(tonumber)
+    :default "1"
+roll:flag "-r --remove-expired"
+    :description "Remove expired keys"
+roll:option "-x --expire"
+    :argname("<days>")
+    :convert(tonumber)
 
 local function printf(fmt, ...)
   if fmt then
@@ -252,7 +266,10 @@ local function create_and_push_key(opts, domain, existing)
         selector = opts.selector,
         domain = domain,
         key = tostring(sk),
+        pubkey = tostring(pk),
         alg = opts.algorithm,
+        bits = opts.bits or 0,
+        valid_start = os.time(),
       }
     }
   }
@@ -346,6 +363,167 @@ local function newkey_handler(opts, domain)
   end
 end
 
+local function roll_handler(opts, domain)
+  local uri = vault_url(opts, domain)
+  local res = {
+    selectors = {}
+  }
+
+  local err,data = rspamd_http.request{
+    config = rspamd_config,
+    ev_base = rspamadm_ev_base,
+    session = rspamadm_session,
+    resolver = rspamadm_dns_resolver,
+    url = uri,
+    method = 'get',
+    headers = {
+      ['X-Vault-Token'] = opts.token
+    }
+  }
+
+  if is_http_error(err, data) or not data.content then
+    printf("No keys to roll for domain %s", domain)
+    os.exit(1)
+  else
+    local rep = parse_vault_reply(data.content)
+
+    if not rep or not rep.data then
+      printf('cannot parse reply for %s: %s', uri, data.content)
+      os.exit(1)
+    end
+
+    local elts = rep.data.selectors
+
+    if not elts then
+      printf("No keys to roll for domain %s", domain)
+      os.exit(1)
+    end
+
+    local nkeys = {} -- indexed by algorithm
+
+    local function insert_key(sel, add_expire)
+      if not nkeys[sel.alg] then
+        nkeys[sel.alg] = {}
+      end
+
+      if add_expire then
+        sel.valid_end = os.time() + opts.ttl * 3600 * 24
+      end
+
+      table.insert(nkeys[sel.alg], sel)
+    end
+
+    for _,sel in ipairs(elts) do
+      if sel.valid_end and sel.valid_end < os.time() then
+        if not opts.remove_expired then
+          insert_key(sel, false)
+        else
+          maybe_printf(opts, 'removed expired key for %s (selector %s, expire "%s"',
+              domain, sel.selector, os.date('%c', sel.valid_end))
+        end
+      else
+        insert_key(sel, true)
+      end
+    end
+
+    -- Now we need to ensure that all but one selectors have either expired or just a single key
+    for alg,keys in pairs(nkeys) do
+      table.sort(keys, function(k1, k2)
+        if k1.valid_end and k2.valid_end then
+          return k1.valid_end > k2.valid_end
+        elseif k1.valid_end then
+          return true
+        elseif k2.valid_end then
+          return false
+        end
+        return false
+      end)
+      -- Exclude the key with the highest expiration date and examine the rest
+      if not (#keys == 1 or fun.all(function(k)
+            return k.valid_end and k.valid_end < os.time()
+          end, fun.tail(keys))) then
+        printf('bad keys list for %s and %s algorithm', domain, alg)
+        fun.each(function(k)
+          if not k.valid_end then
+            printf('selector %s, algorithm %s has a key with no expire',
+                k.selector, k.alg)
+          elseif k.valid_end >= os.time() then
+            printf('selector %s, algorithm %s has a key that not yet expired: %s',
+                k.selector, k.alg, os.date('%c', k.valid_end))
+          end
+        end, fun.tail(keys))
+        os.exit(1)
+      end
+      -- OK to process
+      -- Insert keys for each algorithm in pairs <old_key(s)>, <new_key>
+      local sk,pk = genkey({algorithm = alg, bits = keys[1].bits})
+      local selector = string.format('%s-%s', alg,
+          os.date("%Y%m%d"))
+
+      if selector == keys[1].selector then
+        selector = selector .. '-1'
+      end
+      local nelt = {
+        selector = selector,
+        domain = domain,
+        key = tostring(sk),
+        pubkey = tostring(pk),
+        alg = alg,
+        bits = keys[1].bits,
+        valid_start = os.time(),
+      }
+
+      if opts.expire then
+        nelt.valid_end = os.time() + opts.expire * 3600 * 24
+      end
+
+      table.insert(res.selectors, nelt)
+      for _,k in ipairs(keys) do
+        table.insert(res.selectors, k)
+      end
+    end
+  end
+
+  -- We can now store res in the vault
+  err,data = rspamd_http.request{
+    config = rspamd_config,
+    ev_base = rspamadm_ev_base,
+    session = rspamadm_session,
+    resolver = rspamadm_dns_resolver,
+    url = uri,
+    method = 'put',
+    headers = {
+      ['X-Vault-Token'] = opts.token
+    },
+    body = {
+      ucl.to_format(res, 'json-compact')
+    },
+  }
+
+  if is_http_error(err, data) then
+    printf('cannot put request to the vault (%s), HTTP error code %s', uri, data.code)
+    maybe_print_vault_data(opts, data.content)
+    os.exit(1)
+  else
+    for _,key in ipairs(res.selectors) do
+      if not key.valid_end or key.valid_end > os.time() + opts.ttl * 3600 * 24  then
+        maybe_printf(opts,'rolled key for: %s, new selector: %s', domain, key.selector)
+        maybe_printf(opts, 'please place the corresponding public key as following:')
+
+        if opts.silent then
+          printf('%s', key.pubkey)
+        else
+          print_dkim_txt_record(key.pubkey, key.selector, key.alg)
+        end
+
+      end
+    end
+
+    maybe_printf(opts, 'your old keys will be valid until %s',
+        os.date('%c', os.time() + opts.ttl * 3600 * 24))
+  end
+end
+
 local function handler(args)
   local opts = parser:parse(args)
 
@@ -373,6 +551,8 @@ local function handler(args)
     fun.each(function(d) show_handler(opts, d) end, opts.domain)
   elseif command == 'newkey' then
     fun.each(function(d) newkey_handler(opts, d) end, opts.domain)
+  elseif command == 'roll' then
+    fun.each(function(d) roll_handler(opts, d) end, opts.domain)
   elseif command == 'delete' then
     fun.each(function(d) delete_handler(opts, d) end, opts.domain)
   else