aboutsummaryrefslogtreecommitdiffstats
path: root/common/rfb/SSecurityRSAAES.cxx
diff options
context:
space:
mode:
Diffstat (limited to 'common/rfb/SSecurityRSAAES.cxx')
-rw-r--r--common/rfb/SSecurityRSAAES.cxx598
1 files changed, 598 insertions, 0 deletions
diff --git a/common/rfb/SSecurityRSAAES.cxx b/common/rfb/SSecurityRSAAES.cxx
new file mode 100644
index 00000000..15d2e97b
--- /dev/null
+++ b/common/rfb/SSecurityRSAAES.cxx
@@ -0,0 +1,598 @@
+/* Copyright (C) 2022 Dinglan Peng
+ *
+ * This is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 2 of the License, or
+ * (at your option) any later version.
+ *
+ * This software is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this software; if not, write to the Free Software
+ * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307,
+ * USA.
+ */
+
+#ifdef HAVE_CONFIG_H
+#include <config.h>
+#endif
+
+#ifndef HAVE_NETTLE
+#error "This source should not be compiled without HAVE_NETTLE defined"
+#endif
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <assert.h>
+
+#include <nettle/bignum.h>
+#include <nettle/sha1.h>
+#include <nettle/sha2.h>
+#include <nettle/base64.h>
+#include <nettle/asn1.h>
+#include <rfb/SSecurityRSAAES.h>
+#include <rfb/SConnection.h>
+#include <rfb/LogWriter.h>
+#include <rfb/Exception.h>
+#include <rdr/AESInStream.h>
+#include <rdr/AESOutStream.h>
+#if !defined(WIN32) && !defined(__APPLE__)
+#include <rfb/UnixPasswordValidator.h>
+#endif
+#ifdef WIN32
+#include <rfb/WinPasswdValidator.h>
+#endif
+#include <rfb/SSecurityVncAuth.h>
+
+enum {
+ SendPublicKey,
+ ReadPublicKey,
+ ReadRandom,
+ ReadHash,
+ ReadCredentials,
+};
+
+const int MinKeyLength = 1024;
+const int MaxKeyLength = 8192;
+const size_t MaxKeyFileSize = 32 * 1024;
+
+using namespace rfb;
+
+StringParameter SSecurityRSAAES::keyFile
+("RSAKey", "Path to the RSA key for the RSA-AES security types in "
+ "PEM format", "", ConfServer);
+BoolParameter SSecurityRSAAES::requireUsername
+("RequireUsername", "Require username for the RSA-AES security types",
+ false, ConfServer);
+
+SSecurityRSAAES::SSecurityRSAAES(SConnection* sc, rdr::U32 _secType,
+ int _keySize, bool _isAllEncrypted)
+ : SSecurity(sc), state(SendPublicKey),
+ keySize(_keySize), isAllEncrypted(_isAllEncrypted), secType(_secType),
+ serverKey(), clientKey(),
+ serverKeyN(NULL), serverKeyE(NULL), clientKeyN(NULL), clientKeyE(NULL),
+ accessRights(SConnection::AccessDefault),
+ rais(NULL), raos(NULL), rawis(NULL), rawos(NULL)
+{
+ assert(keySize == 128 || keySize == 256);
+}
+
+SSecurityRSAAES::~SSecurityRSAAES()
+{
+ cleanup();
+}
+
+void SSecurityRSAAES::cleanup()
+{
+ if (serverKeyN)
+ delete[] serverKeyN;
+ if (serverKeyE)
+ delete[] serverKeyE;
+ if (clientKeyN)
+ delete[] clientKeyN;
+ if (clientKeyE)
+ delete[] clientKeyE;
+ if (serverKey.size)
+ rsa_private_key_clear(&serverKey);
+ if (clientKey.size)
+ rsa_public_key_clear(&clientKey);
+ if (isAllEncrypted && rawis && rawos)
+ sc->setStreams(rawis, rawos);
+ if (rais)
+ delete rais;
+ if (raos)
+ delete raos;
+}
+
+static inline ssize_t findSubstr(rdr::U8* data, size_t size, const char *pattern)
+{
+ size_t patternLength = strlen(pattern);
+ for (size_t i = 0; i + patternLength < size; ++i) {
+ for (size_t j = 0; j < patternLength; ++j)
+ if (data[i + j] != pattern[j])
+ goto next;
+ return i;
+next:
+ continue;
+ }
+ return -1;
+}
+
+static bool loadPEM(rdr::U8* data, size_t size, const char *begin,
+ const char *end, rdr::U8** der, size_t *derSize)
+{
+ ssize_t pos1 = findSubstr(data, size, begin);
+ if (pos1 == -1)
+ return false;
+ pos1 += strlen(begin);
+ ssize_t base64Size = findSubstr(data + pos1, size - pos1, end);
+ if (base64Size == -1)
+ return false;
+ char *derBase64 = (char *)data + pos1;
+ if (!base64Size)
+ return false;
+ *der = new rdr::U8[BASE64_DECODE_LENGTH(base64Size)];
+ struct base64_decode_ctx ctx;
+ base64_decode_init(&ctx);
+ if (!base64_decode_update(&ctx, derSize, *der, base64Size, derBase64))
+ return false;
+ if (!base64_decode_final(&ctx))
+ return false;
+ return true;
+}
+
+void SSecurityRSAAES::loadPrivateKey()
+{
+ FILE* file = fopen(keyFile.getData(), "rb");
+ if (!file)
+ throw ConnFailedException("failed to open key file");
+ fseek(file, 0, SEEK_END);
+ size_t size = ftell(file);
+ if (size == 0 || size > MaxKeyFileSize) {
+ fclose(file);
+ throw ConnFailedException("size of key file is zero or too big");
+ }
+ fseek(file, 0, SEEK_SET);
+ rdr::U8Array data(size);
+ if (fread(data.buf, 1, size, file) != size) {
+ fclose(file);
+ throw ConnFailedException("failed to read key");
+ }
+ fclose(file);
+
+ rdr::U8Array der;
+ size_t derSize;
+ if (loadPEM(data.buf, size, "-----BEGIN RSA PRIVATE KEY-----\n",
+ "-----END RSA PRIVATE KEY-----", &der.buf, &derSize)) {
+ loadPKCS1Key(der.buf, derSize);
+ return;
+ }
+ if (der.buf)
+ delete[] der.takeBuf();
+ if (loadPEM(data.buf, size, "-----BEGIN PRIVATE KEY-----\n",
+ "-----END PRIVATE KEY-----", &der.buf, &derSize)) {
+ loadPKCS8Key(der.buf, derSize);
+ return;
+ }
+ throw ConnFailedException("failed to import key");
+}
+
+void SSecurityRSAAES::loadPKCS1Key(const rdr::U8* data, size_t size)
+{
+ struct rsa_public_key pub;
+ rsa_private_key_init(&serverKey);
+ rsa_public_key_init(&pub);
+ if (!rsa_keypair_from_der(&pub, &serverKey, 0, size, data)) {
+ rsa_private_key_clear(&serverKey);
+ rsa_public_key_clear(&pub);
+ throw ConnFailedException("failed to import key");
+ }
+ serverKeyLength = serverKey.size * 8;
+ serverKeyN = new rdr::U8[serverKey.size];
+ serverKeyE = new rdr::U8[serverKey.size];
+ nettle_mpz_get_str_256(serverKey.size, serverKeyN, pub.n);
+ nettle_mpz_get_str_256(serverKey.size, serverKeyE, pub.e);
+ rsa_public_key_clear(&pub);
+}
+
+void SSecurityRSAAES::loadPKCS8Key(const rdr::U8* data, size_t size)
+{
+ struct asn1_der_iterator i, j;
+ uint32_t version;
+ const char* rsaIdentifier = "\x2a\x86\x48\x86\xf7\x0d\x01\x01\x01";
+ const size_t rsaIdentifierLength = 9;
+ enum asn1_iterator_result res = asn1_der_iterator_first(&i, size, data);
+ if (res != ASN1_ITERATOR_CONSTRUCTED)
+ goto failed;
+ if (i.type != ASN1_SEQUENCE)
+ goto failed;
+ if (asn1_der_decode_constructed_last(&i) != ASN1_ITERATOR_PRIMITIVE)
+ goto failed;
+ if (!(i.type == ASN1_INTEGER &&
+ asn1_der_get_uint32(&i, &version) &&
+ version == 0))
+ goto failed;
+ if (!(asn1_der_iterator_next(&i) == ASN1_ITERATOR_CONSTRUCTED &&
+ i.type == ASN1_SEQUENCE &&
+ asn1_der_decode_constructed(&i, &j) == ASN1_ITERATOR_PRIMITIVE &&
+ j.type == ASN1_IDENTIFIER &&
+ j.length == rsaIdentifierLength &&
+ memcmp(j.data, rsaIdentifier, rsaIdentifierLength) == 0))
+ goto failed;
+ if (!(asn1_der_iterator_next(&i) == ASN1_ITERATOR_PRIMITIVE &&
+ i.type == ASN1_OCTETSTRING && i.length))
+ goto failed;
+ loadPKCS1Key(i.data, i.length);
+ return;
+failed:
+ throw ConnFailedException("failed to import key");
+}
+
+bool SSecurityRSAAES::processMsg()
+{
+ switch (state) {
+ case SendPublicKey:
+ loadPrivateKey();
+ writePublicKey();
+ state = ReadPublicKey;
+ // fall through
+ case ReadPublicKey:
+ if (readPublicKey()) {
+ writeRandom();
+ state = ReadRandom;
+ }
+ return false;
+ case ReadRandom:
+ if (readRandom()) {
+ setCipher();
+ writeHash();
+ state = ReadHash;
+ }
+ return false;
+ case ReadHash:
+ if (readHash()) {
+ clearSecrets();
+ writeSubtype();
+ state = ReadCredentials;
+ }
+ return false;
+ case ReadCredentials:
+ if (readCredentials()) {
+ if (requireUsername)
+ verifyUserPass();
+ else
+ verifyPass();
+ return true;
+ }
+ return false;
+ }
+ assert(!"unreachable");
+ return false;
+}
+
+void SSecurityRSAAES::writePublicKey()
+{
+ rdr::OutStream* os = sc->getOutStream();
+ os->writeU32(serverKeyLength);
+ os->writeBytes(serverKeyN, serverKey.size);
+ os->writeBytes(serverKeyE, serverKey.size);
+ os->flush();
+}
+
+bool SSecurityRSAAES::readPublicKey()
+{
+ rdr::InStream* is = sc->getInStream();
+ if (!is->hasData(4))
+ return false;
+ is->setRestorePoint();
+ clientKeyLength = is->readU32();
+ if (clientKeyLength < MinKeyLength)
+ throw ConnFailedException("client key is too short");
+ if (clientKeyLength > MaxKeyLength)
+ throw ConnFailedException("client key is too long");
+ size_t size = (clientKeyLength + 7) / 8;
+ if (!is->hasDataOrRestore(size * 2))
+ return false;
+ is->clearRestorePoint();
+ clientKeyE = new rdr::U8[size];
+ clientKeyN = new rdr::U8[size];
+ is->readBytes(clientKeyN, size);
+ is->readBytes(clientKeyE, size);
+ rsa_public_key_init(&clientKey);
+ nettle_mpz_set_str_256_u(clientKey.n, size, clientKeyN);
+ nettle_mpz_set_str_256_u(clientKey.e, size, clientKeyE);
+ if (!rsa_public_key_prepare(&clientKey))
+ throw ConnFailedException("client key is invalid");
+ return true;
+}
+
+static void random_func(void* ctx, size_t length, uint8_t* dst)
+{
+ rdr::RandomStream* rs = (rdr::RandomStream*)ctx;
+ if (!rs->hasData(length))
+ throw ConnFailedException("failed to encrypt random");
+ rs->readBytes(dst, length);
+}
+
+void SSecurityRSAAES::writeRandom()
+{
+ rdr::OutStream* os = sc->getOutStream();
+ if (!rs.hasData(keySize / 8))
+ throw ConnFailedException("failed to generate random");
+ rs.readBytes(serverRandom, keySize / 8);
+ mpz_t x;
+ mpz_init(x);
+ int res;
+ try {
+ res = rsa_encrypt(&clientKey, &rs, random_func, keySize / 8,
+ serverRandom, x);
+ } catch (...) {
+ mpz_clear(x);
+ throw;
+ }
+ if (!res) {
+ mpz_clear(x);
+ throw ConnFailedException("failed to encrypt random");
+ }
+ rdr::U8* buffer = new rdr::U8[clientKey.size];
+ nettle_mpz_get_str_256(clientKey.size, buffer, x);
+ mpz_clear(x);
+ os->writeU16(clientKey.size);
+ os->writeBytes(buffer, clientKey.size);
+ os->flush();
+ delete[] buffer;
+}
+
+bool SSecurityRSAAES::readRandom()
+{
+ rdr::InStream* is = sc->getInStream();
+ if (!is->hasData(2))
+ return false;
+ is->setRestorePoint();
+ size_t size = is->readU16();
+ if (size != serverKey.size)
+ throw ConnFailedException("server key length doesn't match");
+ if (!is->hasDataOrRestore(size))
+ return false;
+ is->clearRestorePoint();
+ rdr::U8* buffer = new rdr::U8[size];
+ is->readBytes(buffer, size);
+ size_t randomSize = keySize / 8;
+ mpz_t x;
+ nettle_mpz_init_set_str_256_u(x, size, buffer);
+ delete[] buffer;
+ if (!rsa_decrypt(&serverKey, &randomSize, clientRandom, x) ||
+ randomSize != (size_t)keySize / 8) {
+ mpz_clear(x);
+ throw ConnFailedException("failed to decrypt client random");
+ }
+ mpz_clear(x);
+ return true;
+}
+
+void SSecurityRSAAES::setCipher()
+{
+ rawis = sc->getInStream();
+ rawos = sc->getOutStream();
+ rdr::U8 key[32];
+ if (keySize == 128) {
+ struct sha1_ctx ctx;
+ sha1_init(&ctx);
+ sha1_update(&ctx, 16, serverRandom);
+ sha1_update(&ctx, 16, clientRandom);
+ sha1_digest(&ctx, 16, key);
+ rais = new rdr::AESInStream(rawis, key, 128);
+ sha1_init(&ctx);
+ sha1_update(&ctx, 16, clientRandom);
+ sha1_update(&ctx, 16, serverRandom);
+ sha1_digest(&ctx, 16, key);
+ raos = new rdr::AESOutStream(rawos, key, 128);
+ } else {
+ struct sha256_ctx ctx;
+ sha256_init(&ctx);
+ sha256_update(&ctx, 32, serverRandom);
+ sha256_update(&ctx, 32, clientRandom);
+ sha256_digest(&ctx, 32, key);
+ rais = new rdr::AESInStream(rawis, key, 256);
+ sha256_init(&ctx);
+ sha256_update(&ctx, 32, clientRandom);
+ sha256_update(&ctx, 32, serverRandom);
+ sha256_digest(&ctx, 32, key);
+ raos = new rdr::AESOutStream(rawos, key, 256);
+ }
+ if (isAllEncrypted)
+ sc->setStreams(rais, raos);
+}
+
+void SSecurityRSAAES::writeHash()
+{
+ rdr::U8 hash[32];
+ size_t len = serverKeyLength;
+ rdr::U8 lenServerKey[4] = {
+ (rdr::U8)((len & 0xff000000) >> 24),
+ (rdr::U8)((len & 0xff0000) >> 16),
+ (rdr::U8)((len & 0xff00) >> 8),
+ (rdr::U8)(len & 0xff)
+ };
+ len = clientKeyLength;
+ rdr::U8 lenClientKey[4] = {
+ (rdr::U8)((len & 0xff000000) >> 24),
+ (rdr::U8)((len & 0xff0000) >> 16),
+ (rdr::U8)((len & 0xff00) >> 8),
+ (rdr::U8)(len & 0xff)
+ };
+ int hashSize;
+ if (keySize == 128) {
+ hashSize = 20;
+ struct sha1_ctx ctx;
+ sha1_init(&ctx);
+ sha1_update(&ctx, 4, lenServerKey);
+ sha1_update(&ctx, serverKey.size, serverKeyN);
+ sha1_update(&ctx, serverKey.size, serverKeyE);
+ sha1_update(&ctx, 4, lenClientKey);
+ sha1_update(&ctx, clientKey.size, clientKeyN);
+ sha1_update(&ctx, clientKey.size, clientKeyE);
+ sha1_digest(&ctx, hashSize, hash);
+ } else {
+ hashSize = 32;
+ struct sha256_ctx ctx;
+ sha256_init(&ctx);
+ sha256_update(&ctx, 4, lenServerKey);
+ sha256_update(&ctx, serverKey.size, serverKeyN);
+ sha256_update(&ctx, serverKey.size, serverKeyE);
+ sha256_update(&ctx, 4, lenClientKey);
+ sha256_update(&ctx, clientKey.size, clientKeyN);
+ sha256_update(&ctx, clientKey.size, clientKeyE);
+ sha256_digest(&ctx, hashSize, hash);
+ }
+ raos->writeBytes(hash, hashSize);
+ raos->flush();
+}
+
+bool SSecurityRSAAES::readHash()
+{
+ rdr::U8 hash[32];
+ rdr::U8 realHash[32];
+ int hashSize = keySize == 128 ? 20 : 32;
+ if (!rais->hasData(hashSize))
+ return false;
+ rais->readBytes(hash, hashSize);
+ size_t len = serverKeyLength;
+ rdr::U8 lenServerKey[4] = {
+ (rdr::U8)((len & 0xff000000) >> 24),
+ (rdr::U8)((len & 0xff0000) >> 16),
+ (rdr::U8)((len & 0xff00) >> 8),
+ (rdr::U8)(len & 0xff)
+ };
+ len = clientKeyLength;
+ rdr::U8 lenClientKey[4] = {
+ (rdr::U8)((len & 0xff000000) >> 24),
+ (rdr::U8)((len & 0xff0000) >> 16),
+ (rdr::U8)((len & 0xff00) >> 8),
+ (rdr::U8)(len & 0xff)
+ };
+ if (keySize == 128) {
+ struct sha1_ctx ctx;
+ sha1_init(&ctx);
+ sha1_update(&ctx, 4, lenClientKey);
+ sha1_update(&ctx, clientKey.size, clientKeyN);
+ sha1_update(&ctx, clientKey.size, clientKeyE);
+ sha1_update(&ctx, 4, lenServerKey);
+ sha1_update(&ctx, serverKey.size, serverKeyN);
+ sha1_update(&ctx, serverKey.size, serverKeyE);
+ sha1_digest(&ctx, hashSize, realHash);
+ } else {
+ struct sha256_ctx ctx;
+ sha256_init(&ctx);
+ sha256_update(&ctx, 4, lenClientKey);
+ sha256_update(&ctx, clientKey.size, clientKeyN);
+ sha256_update(&ctx, clientKey.size, clientKeyE);
+ sha256_update(&ctx, 4, lenServerKey);
+ sha256_update(&ctx, serverKey.size, serverKeyN);
+ sha256_update(&ctx, serverKey.size, serverKeyE);
+ sha256_digest(&ctx, hashSize, realHash);
+ }
+ if (memcmp(hash, realHash, hashSize) != 0)
+ throw ConnFailedException("hash doesn't match");
+ return true;
+}
+
+void SSecurityRSAAES::clearSecrets()
+{
+ rsa_private_key_clear(&serverKey);
+ rsa_public_key_clear(&clientKey);
+ serverKey.size = 0;
+ clientKey.size = 0;
+ delete[] serverKeyN;
+ delete[] serverKeyE;
+ delete[] clientKeyN;
+ delete[] clientKeyE;
+ serverKeyN = NULL;
+ serverKeyE = NULL;
+ clientKeyN = NULL;
+ clientKeyE = NULL;
+ memset(serverRandom, 0, sizeof(serverRandom));
+ memset(clientRandom, 0, sizeof(clientRandom));
+}
+
+void SSecurityRSAAES::writeSubtype()
+{
+ if (requireUsername)
+ raos->writeU8(secTypeRA2UserPass);
+ else
+ raos->writeU8(secTypeRA2Pass);
+ raos->flush();
+}
+
+bool SSecurityRSAAES::readCredentials()
+{
+ rais->setRestorePoint();
+ if (!rais->hasData(1))
+ return false;
+ rdr::U8 lenUsername = rais->readU8();
+ if (!rais->hasDataOrRestore(lenUsername + 1))
+ return false;
+ if (!username.buf) {
+ username.replaceBuf(new char[lenUsername + 1]);
+ rais->readBytes(username.buf, lenUsername);
+ username.buf[lenUsername] = 0;
+ } else {
+ rais->skip(lenUsername);
+ }
+ rdr::U8 lenPassword = rais->readU8();
+ if (!rais->hasDataOrRestore(lenPassword))
+ return false;
+ password.replaceBuf(new char[lenPassword + 1]);
+ rais->readBytes(password.buf, lenPassword);
+ password.buf[lenPassword] = 0;
+ rais->clearRestorePoint();
+ return true;
+}
+
+void SSecurityRSAAES::verifyUserPass()
+{
+#ifndef __APPLE__
+#ifdef WIN32
+ WinPasswdValidator* valid = new WinPasswdValidator();
+#elif !defined(__APPLE__)
+ UnixPasswordValidator *valid = new UnixPasswordValidator();
+#endif
+ if (!valid->validate(sc, username.buf, password.buf)) {
+ delete valid;
+ throw AuthFailureException("invalid password or username");
+ }
+ delete valid;
+#else
+ throw AuthFailureException("No password validator configured");
+#endif
+}
+
+void SSecurityRSAAES::verifyPass()
+{
+ VncAuthPasswdGetter* pg = &SSecurityVncAuth::vncAuthPasswd;
+ PlainPasswd passwd, passwdReadOnly;
+ pg->getVncAuthPasswd(&passwd, &passwdReadOnly);
+
+ if (!passwd.buf)
+ throw AuthFailureException("No password configured for VNC Auth");
+
+ if (strcmp(password.buf, passwd.buf) == 0) {
+ accessRights = SConnection::AccessDefault;
+ return;
+ }
+
+ if (passwdReadOnly.buf && strcmp(password.buf, passwdReadOnly.buf) == 0) {
+ accessRights = SConnection::AccessView;
+ return;
+ }
+
+ throw AuthFailureException();
+}
+
+const char* SSecurityRSAAES::getUserName() const
+{
+ return username.buf;
+}