summaryrefslogtreecommitdiffstats
path: root/lualib/rspamadm/vault.lua
diff options
context:
space:
mode:
Diffstat (limited to 'lualib/rspamadm/vault.lua')
-rw-r--r--lualib/rspamadm/vault.lua180
1 files changed, 180 insertions, 0 deletions
diff --git a/lualib/rspamadm/vault.lua b/lualib/rspamadm/vault.lua
index 9c01624fc..b11ba525e 100644
--- a/lualib/rspamadm/vault.lua
+++ b/lualib/rspamadm/vault.lua
@@ -91,6 +91,20 @@ newkey:option "-x --expire"
:convert(tonumber)
newkey:flag "-r --rewrite"
+local roll = parser:command "roll rollover"
+ :description "Perform keys rollover"
+roll:argument "domain"
+ :description "Domain to roll key(s) for"
+ :args "+"
+roll:option "-T --ttl"
+ :description "Validity period for old keys (days)"
+ :convert(tonumber)
+ :default "1"
+roll:flag "-r --remove-expired"
+ :description "Remove expired keys"
+roll:option "-x --expire"
+ :argname("<days>")
+ :convert(tonumber)
local function printf(fmt, ...)
if fmt then
@@ -252,7 +266,10 @@ local function create_and_push_key(opts, domain, existing)
selector = opts.selector,
domain = domain,
key = tostring(sk),
+ pubkey = tostring(pk),
alg = opts.algorithm,
+ bits = opts.bits or 0,
+ valid_start = os.time(),
}
}
}
@@ -346,6 +363,167 @@ local function newkey_handler(opts, domain)
end
end
+local function roll_handler(opts, domain)
+ local uri = vault_url(opts, domain)
+ local res = {
+ selectors = {}
+ }
+
+ local err,data = rspamd_http.request{
+ config = rspamd_config,
+ ev_base = rspamadm_ev_base,
+ session = rspamadm_session,
+ resolver = rspamadm_dns_resolver,
+ url = uri,
+ method = 'get',
+ headers = {
+ ['X-Vault-Token'] = opts.token
+ }
+ }
+
+ if is_http_error(err, data) or not data.content then
+ printf("No keys to roll for domain %s", domain)
+ os.exit(1)
+ else
+ local rep = parse_vault_reply(data.content)
+
+ if not rep or not rep.data then
+ printf('cannot parse reply for %s: %s', uri, data.content)
+ os.exit(1)
+ end
+
+ local elts = rep.data.selectors
+
+ if not elts then
+ printf("No keys to roll for domain %s", domain)
+ os.exit(1)
+ end
+
+ local nkeys = {} -- indexed by algorithm
+
+ local function insert_key(sel, add_expire)
+ if not nkeys[sel.alg] then
+ nkeys[sel.alg] = {}
+ end
+
+ if add_expire then
+ sel.valid_end = os.time() + opts.ttl * 3600 * 24
+ end
+
+ table.insert(nkeys[sel.alg], sel)
+ end
+
+ for _,sel in ipairs(elts) do
+ if sel.valid_end and sel.valid_end < os.time() then
+ if not opts.remove_expired then
+ insert_key(sel, false)
+ else
+ maybe_printf(opts, 'removed expired key for %s (selector %s, expire "%s"',
+ domain, sel.selector, os.date('%c', sel.valid_end))
+ end
+ else
+ insert_key(sel, true)
+ end
+ end
+
+ -- Now we need to ensure that all but one selectors have either expired or just a single key
+ for alg,keys in pairs(nkeys) do
+ table.sort(keys, function(k1, k2)
+ if k1.valid_end and k2.valid_end then
+ return k1.valid_end > k2.valid_end
+ elseif k1.valid_end then
+ return true
+ elseif k2.valid_end then
+ return false
+ end
+ return false
+ end)
+ -- Exclude the key with the highest expiration date and examine the rest
+ if not (#keys == 1 or fun.all(function(k)
+ return k.valid_end and k.valid_end < os.time()
+ end, fun.tail(keys))) then
+ printf('bad keys list for %s and %s algorithm', domain, alg)
+ fun.each(function(k)
+ if not k.valid_end then
+ printf('selector %s, algorithm %s has a key with no expire',
+ k.selector, k.alg)
+ elseif k.valid_end >= os.time() then
+ printf('selector %s, algorithm %s has a key that not yet expired: %s',
+ k.selector, k.alg, os.date('%c', k.valid_end))
+ end
+ end, fun.tail(keys))
+ os.exit(1)
+ end
+ -- OK to process
+ -- Insert keys for each algorithm in pairs <old_key(s)>, <new_key>
+ local sk,pk = genkey({algorithm = alg, bits = keys[1].bits})
+ local selector = string.format('%s-%s', alg,
+ os.date("%Y%m%d"))
+
+ if selector == keys[1].selector then
+ selector = selector .. '-1'
+ end
+ local nelt = {
+ selector = selector,
+ domain = domain,
+ key = tostring(sk),
+ pubkey = tostring(pk),
+ alg = alg,
+ bits = keys[1].bits,
+ valid_start = os.time(),
+ }
+
+ if opts.expire then
+ nelt.valid_end = os.time() + opts.expire * 3600 * 24
+ end
+
+ table.insert(res.selectors, nelt)
+ for _,k in ipairs(keys) do
+ table.insert(res.selectors, k)
+ end
+ end
+ end
+
+ -- We can now store res in the vault
+ err,data = rspamd_http.request{
+ config = rspamd_config,
+ ev_base = rspamadm_ev_base,
+ session = rspamadm_session,
+ resolver = rspamadm_dns_resolver,
+ url = uri,
+ method = 'put',
+ headers = {
+ ['X-Vault-Token'] = opts.token
+ },
+ body = {
+ ucl.to_format(res, 'json-compact')
+ },
+ }
+
+ if is_http_error(err, data) then
+ printf('cannot put request to the vault (%s), HTTP error code %s', uri, data.code)
+ maybe_print_vault_data(opts, data.content)
+ os.exit(1)
+ else
+ for _,key in ipairs(res.selectors) do
+ if not key.valid_end or key.valid_end > os.time() + opts.ttl * 3600 * 24 then
+ maybe_printf(opts,'rolled key for: %s, new selector: %s', domain, key.selector)
+ maybe_printf(opts, 'please place the corresponding public key as following:')
+
+ if opts.silent then
+ printf('%s', key.pubkey)
+ else
+ print_dkim_txt_record(key.pubkey, key.selector, key.alg)
+ end
+
+ end
+ end
+
+ maybe_printf(opts, 'your old keys will be valid until %s',
+ os.date('%c', os.time() + opts.ttl * 3600 * 24))
+ end
+end
+
local function handler(args)
local opts = parser:parse(args)
@@ -373,6 +551,8 @@ local function handler(args)
fun.each(function(d) show_handler(opts, d) end, opts.domain)
elseif command == 'newkey' then
fun.each(function(d) newkey_handler(opts, d) end, opts.domain)
+ elseif command == 'roll' then
+ fun.each(function(d) roll_handler(opts, d) end, opts.domain)
elseif command == 'delete' then
fun.each(function(d) delete_handler(opts, d) end, opts.domain)
else