您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316
  1. package ecdh
  2. import (
  3. "bytes"
  4. "crypto"
  5. "crypto/aes"
  6. "crypto/elliptic"
  7. "encoding/binary"
  8. "errors"
  9. "github.com/keybase/go-crypto/curve25519"
  10. "io"
  11. "math/big"
  12. )
  13. type PublicKey struct {
  14. elliptic.Curve
  15. X, Y *big.Int
  16. }
  17. type PrivateKey struct {
  18. PublicKey
  19. X *big.Int
  20. }
  21. // KDF implements Key Derivation Function as described in
  22. // https://tools.ietf.org/html/rfc6637#section-7
  23. func (e *PublicKey) KDF(S []byte, kdfParams []byte, hash crypto.Hash) []byte {
  24. sLen := (e.Curve.Params().P.BitLen() + 7) / 8
  25. buf := new(bytes.Buffer)
  26. buf.Write([]byte{0, 0, 0, 1})
  27. if sLen > len(S) {
  28. // zero-pad the S. If we got invalid S (bigger than curve's
  29. // P), we are going to produce invalid key. Garbage in,
  30. // garbage out.
  31. buf.Write(make([]byte, sLen-len(S)))
  32. }
  33. buf.Write(S)
  34. buf.Write(kdfParams)
  35. hashw := hash.New()
  36. hashw.Write(buf.Bytes())
  37. key := hashw.Sum(nil)
  38. return key
  39. }
  40. // AESKeyUnwrap implements RFC 3394 Key Unwrapping. See
  41. // http://tools.ietf.org/html/rfc3394#section-2.2.1
  42. // Note: The second described algorithm ("index-based") is implemented
  43. // here.
  44. func AESKeyUnwrap(key, cipherText []byte) ([]byte, error) {
  45. if len(cipherText)%8 != 0 {
  46. return nil, errors.New("cipherText must by a multiple of 64 bits")
  47. }
  48. cipher, err := aes.NewCipher(key)
  49. if err != nil {
  50. return nil, err
  51. }
  52. nblocks := len(cipherText)/8 - 1
  53. // 1) Initialize variables.
  54. // - Set A = C[0]
  55. var A [aes.BlockSize]byte
  56. copy(A[:8], cipherText[:8])
  57. // For i = 1 to n
  58. // Set R[i] = C[i]
  59. R := make([]byte, len(cipherText)-8)
  60. copy(R, cipherText[8:])
  61. // 2) Compute intermediate values.
  62. for j := 5; j >= 0; j-- {
  63. for i := nblocks - 1; i >= 0; i-- {
  64. // B = AES-1(K, (A ^ t) | R[i]) where t = n*j+i
  65. // A = MSB(64, B)
  66. t := uint64(nblocks*j + i + 1)
  67. At := binary.BigEndian.Uint64(A[:8]) ^ t
  68. binary.BigEndian.PutUint64(A[:8], At)
  69. copy(A[8:], R[i*8:i*8+8])
  70. cipher.Decrypt(A[:], A[:])
  71. // R[i] = LSB(B, 64)
  72. copy(R[i*8:i*8+8], A[8:])
  73. }
  74. }
  75. // 3) Output results.
  76. // If A is an appropriate initial value (see 2.2.3),
  77. for i := 0; i < 8; i++ {
  78. if A[i] != 0xA6 {
  79. return nil, errors.New("Failed to unwrap key (A is not IV)")
  80. }
  81. }
  82. return R, nil
  83. }
  84. // AESKeyWrap implements RFC 3394 Key Wrapping. See
  85. // https://tools.ietf.org/html/rfc3394#section-2.2.2
  86. // Note: The second described algorithm ("index-based") is implemented
  87. // here.
  88. func AESKeyWrap(key, plainText []byte) ([]byte, error) {
  89. if len(plainText)%8 != 0 {
  90. return nil, errors.New("plainText must be a multiple of 64 bits")
  91. }
  92. cipher, err := aes.NewCipher(key) // NewCipher checks key size
  93. if err != nil {
  94. return nil, err
  95. }
  96. nblocks := len(plainText) / 8
  97. // 1) Initialize variables.
  98. var A [aes.BlockSize]byte
  99. // Section 2.2.3.1 -- Initial Value
  100. // http://tools.ietf.org/html/rfc3394#section-2.2.3.1
  101. for i := 0; i < 8; i++ {
  102. A[i] = 0xA6
  103. }
  104. // For i = 1 to n
  105. // Set R[i] = P[i]
  106. R := make([]byte, len(plainText))
  107. copy(R, plainText)
  108. // 2) Calculate intermediate values.
  109. for j := 0; j <= 5; j++ {
  110. for i := 0; i < nblocks; i++ {
  111. // B = AES(K, A | R[i])
  112. copy(A[8:], R[i*8:i*8+8])
  113. cipher.Encrypt(A[:], A[:])
  114. // (Assume B = A)
  115. // A = MSB(64, B) ^ t where t = (n*j)+1
  116. t := uint64(j*nblocks + i + 1)
  117. At := binary.BigEndian.Uint64(A[:8]) ^ t
  118. binary.BigEndian.PutUint64(A[:8], At)
  119. // R[i] = LSB(64, B)
  120. copy(R[i*8:i*8+8], A[8:])
  121. }
  122. }
  123. // 3) Output results.
  124. // Set C[0] = A
  125. // For i = 1 to n
  126. // C[i] = R[i]
  127. return append(A[:8], R...), nil
  128. }
  129. // PadBuffer pads byte buffer buf to a length being multiple of
  130. // blockLen. Additional bytes appended to the buffer have value of the
  131. // number padded bytes. E.g. if the buffer is 3 bytes short of being
  132. // 40 bytes total, the appended bytes will be [03, 03, 03].
  133. func PadBuffer(buf []byte, blockLen int) []byte {
  134. padding := blockLen - (len(buf) % blockLen)
  135. if padding == 0 {
  136. return buf
  137. }
  138. padBuf := make([]byte, padding)
  139. for i := 0; i < padding; i++ {
  140. padBuf[i] = byte(padding)
  141. }
  142. return append(buf, padBuf...)
  143. }
  144. // UnpadBuffer verifies that buffer contains proper padding and
  145. // returns buffer without the padding, or nil if the padding was
  146. // invalid.
  147. func UnpadBuffer(buf []byte, dataLen int) []byte {
  148. padding := len(buf) - dataLen
  149. outBuf := buf[:dataLen]
  150. for i := dataLen; i < len(buf); i++ {
  151. if buf[i] != byte(padding) {
  152. // Invalid padding - bail out
  153. return nil
  154. }
  155. }
  156. return outBuf
  157. }
  158. func (e *PublicKey) Encrypt(random io.Reader, kdfParams []byte, plain []byte, hash crypto.Hash, kdfKeySize int) (Vx *big.Int, Vy *big.Int, C []byte, err error) {
  159. // Vx, Vy - encryption key
  160. // Note for Curve 25519 - curve25519 library already does key
  161. // clamping in scalarMult, so we can use generic random scalar
  162. // generation from elliptic.
  163. priv, Vx, Vy, err := elliptic.GenerateKey(e.Curve, random)
  164. if err != nil {
  165. return nil, nil, nil, err
  166. }
  167. // Sx, Sy - shared secret
  168. Sx, _ := e.Curve.ScalarMult(e.X, e.Y, priv)
  169. // Encrypt the payload with KDF-ed S as the encryption key. Pass
  170. // the ciphertext along with V to the recipient. Recipient can
  171. // generate S using V and their priv key, and then KDF(S), on
  172. // their own, to get encryption key and decrypt the ciphertext,
  173. // revealing encryption key for symmetric encryption later.
  174. plain = PadBuffer(plain, 8)
  175. key := e.KDF(Sx.Bytes(), kdfParams, hash)
  176. // Take only as many bytes from key as the key length (the hash
  177. // result might be bigger)
  178. encrypted, err := AESKeyWrap(key[:kdfKeySize], plain)
  179. return Vx, Vy, encrypted, nil
  180. }
  181. func (e *PrivateKey) DecryptShared(X, Y *big.Int) []byte {
  182. Sx, _ := e.Curve.ScalarMult(X, Y, e.X.Bytes())
  183. return Sx.Bytes()
  184. }
  185. func countBits(buffer []byte) int {
  186. var headerLen int
  187. switch buffer[0] {
  188. case 0x4:
  189. headerLen = 3
  190. case 0x40:
  191. headerLen = 7
  192. default:
  193. // Unexpected header - but we can still count the bits.
  194. val := buffer[0]
  195. headerLen = 0
  196. for val > 0 {
  197. val = val / 2
  198. headerLen++
  199. }
  200. }
  201. return headerLen + (len(buffer)-1)*8
  202. }
  203. // elliptic.Marshal and elliptic.Unmarshal only marshals uncompressed
  204. // 0x4 MPI types. These functions will check if the curve is cv25519,
  205. // and if so, use 0x40 compressed type to (un)marshal. Otherwise,
  206. // elliptic.(Un)marshal will be called.
  207. // Marshal encodes point into either 0x4 uncompressed point form, or
  208. // 0x40 compressed point for Curve 25519.
  209. func Marshal(curve elliptic.Curve, x, y *big.Int) (buf []byte, bitSize int) {
  210. // NOTE: Read more about MPI encoding in the RFC:
  211. // https://tools.ietf.org/html/rfc4880#section-3.2
  212. // We are required to encode size in bits, counting from the most-
  213. // significant non-zero bit. So assuming that the buffer never
  214. // starts with 0x00, we only need to count bits in the first byte
  215. // - and in current implentation it will always be 0x4 or 0x40.
  216. cv, ok := curve25519.ToCurve25519(curve)
  217. if ok {
  218. buf = cv.MarshalType40(x, y)
  219. } else {
  220. buf = elliptic.Marshal(curve, x, y)
  221. }
  222. return buf, countBits(buf)
  223. }
  224. // Unmarshal converts point, serialized by Marshal, into x, y pair.
  225. // For 0x40 compressed points (for Curve 25519), y will always be 0.
  226. // It is an error if point is not on the curve, On error, x = nil.
  227. func Unmarshal(curve elliptic.Curve, data []byte) (x, y *big.Int) {
  228. cv, ok := curve25519.ToCurve25519(curve)
  229. if ok {
  230. return cv.UnmarshalType40(data)
  231. }
  232. return elliptic.Unmarshal(curve, data)
  233. }
  234. func GenerateKey(curve elliptic.Curve, random io.Reader) (priv *PrivateKey, err error) {
  235. var privBytes []byte
  236. var Vx, Vy *big.Int
  237. if _, ok := curve25519.ToCurve25519(curve); ok {
  238. privBytes = make([]byte, 32)
  239. _, err = io.ReadFull(random, privBytes)
  240. if err != nil {
  241. return nil, err
  242. }
  243. // NOTE: PGP expect scalars in reverse order than Curve 25519
  244. // go library. That's why this trimming is backwards compared
  245. // to curve25519.go
  246. privBytes[31] &= 248
  247. privBytes[0] &= 127
  248. privBytes[0] |= 64
  249. Vx,Vy = curve.ScalarBaseMult(privBytes)
  250. } else {
  251. privBytes, Vx, Vy, err = elliptic.GenerateKey(curve, random)
  252. if err != nil {
  253. return nil, err
  254. }
  255. }
  256. priv = &PrivateKey{}
  257. priv.X = new(big.Int).SetBytes(privBytes)
  258. priv.PublicKey.Curve = curve
  259. priv.PublicKey.X = Vx
  260. priv.PublicKey.Y = Vy
  261. return priv, nil
  262. }