diff options
Diffstat (limited to 'common/rdr')
-rw-r--r-- | common/rdr/CMakeLists.txt | 2 | ||||
-rw-r--r-- | common/rdr/InStream.h | 8 | ||||
-rw-r--r-- | common/rdr/TLSException.cxx | 23 | ||||
-rw-r--r-- | common/rdr/TLSException.h | 6 | ||||
-rw-r--r-- | common/rdr/TLSInStream.cxx | 91 | ||||
-rw-r--r-- | common/rdr/TLSInStream.h | 18 | ||||
-rw-r--r-- | common/rdr/TLSOutStream.cxx | 77 | ||||
-rw-r--r-- | common/rdr/TLSOutStream.h | 17 | ||||
-rw-r--r-- | common/rdr/TLSSocket.cxx | 228 | ||||
-rw-r--r-- | common/rdr/TLSSocket.h | 81 |
10 files changed, 365 insertions, 186 deletions
diff --git a/common/rdr/CMakeLists.txt b/common/rdr/CMakeLists.txt index 55c69132..526b2971 100644 --- a/common/rdr/CMakeLists.txt +++ b/common/rdr/CMakeLists.txt @@ -12,6 +12,7 @@ add_library(rdr STATIC TLSException.cxx TLSInStream.cxx TLSOutStream.cxx + TLSSocket.cxx ZlibInStream.cxx ZlibOutStream.cxx) @@ -27,7 +28,6 @@ endif() if (NETTLE_FOUND) target_include_directories(rdr SYSTEM PUBLIC ${NETTLE_INCLUDE_DIRS}) target_link_libraries(rdr ${NETTLE_LIBRARIES}) - target_link_directories(rdr PUBLIC ${NETTLE_LIBRARY_DIRS}) endif() if(WIN32) target_link_libraries(rdr ws2_32) diff --git a/common/rdr/InStream.h b/common/rdr/InStream.h index 5bec276d..7ad4996f 100644 --- a/common/rdr/InStream.h +++ b/common/rdr/InStream.h @@ -187,9 +187,7 @@ namespace rdr { private: const uint8_t* restorePoint; -#ifdef RFB_INSTREAM_CHECK size_t checkedBytes; -#endif inline void check(size_t bytes) { #ifdef RFB_INSTREAM_CHECK @@ -209,11 +207,7 @@ namespace rdr { protected: - InStream() : restorePoint(nullptr) -#ifdef RFB_INSTREAM_CHECK - ,checkedBytes(0) -#endif - {} + InStream() : restorePoint(nullptr), checkedBytes(0) {} const uint8_t* ptr; const uint8_t* end; }; diff --git a/common/rdr/TLSException.cxx b/common/rdr/TLSException.cxx index a1896af4..8c93a3d3 100644 --- a/common/rdr/TLSException.cxx +++ b/common/rdr/TLSException.cxx @@ -35,11 +35,28 @@ using namespace rdr; #ifdef HAVE_GNUTLS -tls_error::tls_error(const char* s, int err_) noexcept +tls_error::tls_error(const char* s, int err_, int alert_) noexcept : std::runtime_error(core::format("%s: %s (%d)", s, - gnutls_strerror(err_), err_)), - err(err_) + strerror(err_, alert_), err_)), + err(err_), alert(alert_) { } + +const char* tls_error::strerror(int err_, int alert_) const noexcept +{ + const char* msg; + + msg = nullptr; + + if ((alert_ != -1) && + ((err_ == GNUTLS_E_WARNING_ALERT_RECEIVED) || + (err_ == GNUTLS_E_FATAL_ALERT_RECEIVED))) + msg = gnutls_alert_get_name((gnutls_alert_description_t)alert_); + + if (msg == nullptr) + msg = gnutls_strerror(err_); + + return msg; +} #endif /* HAVE_GNUTLS */ diff --git a/common/rdr/TLSException.h b/common/rdr/TLSException.h index b35a675f..75ee94f5 100644 --- a/common/rdr/TLSException.h +++ b/common/rdr/TLSException.h @@ -27,8 +27,10 @@ namespace rdr { class tls_error : public std::runtime_error { public: - int err; - tls_error(const char* s, int err_) noexcept; + int err, alert; + tls_error(const char* s, int err_, int alert_=-1) noexcept; + private: + const char* strerror(int err_, int alert_) const noexcept; }; } diff --git a/common/rdr/TLSInStream.cxx b/common/rdr/TLSInStream.cxx index 223e3ee6..3e5ea2be 100644 --- a/common/rdr/TLSInStream.cxx +++ b/common/rdr/TLSInStream.cxx @@ -1,7 +1,7 @@ /* Copyright (C) 2002-2005 RealVNC Ltd. All Rights Reserved. * Copyright (C) 2005 Martin Koegler * Copyright (C) 2010 TigerVNC Team - * Copyright (C) 2012-2021 Pierre Ossman for Cendio AB + * Copyright 2012-2025 Pierre Ossman for Cendio AB * * This is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by @@ -23,73 +23,25 @@ #include <config.h> #endif -#include <core/Exception.h> -#include <core/LogWriter.h> - -#include <rdr/TLSException.h> #include <rdr/TLSInStream.h> +#include <rdr/TLSSocket.h> -#include <errno.h> +#ifdef HAVE_GNUTLS -#ifdef HAVE_GNUTLS using namespace rdr; -static core::LogWriter vlog("TLSInStream"); - -ssize_t TLSInStream::pull(gnutls_transport_ptr_t str, void* data, size_t size) -{ - TLSInStream* self= (TLSInStream*) str; - InStream *in = self->in; - - self->streamEmpty = false; - self->saved_exception = nullptr; - - try { - if (!in->hasData(1)) { - self->streamEmpty = true; - gnutls_transport_set_errno(self->session, EAGAIN); - return -1; - } - - if (in->avail() < size) - size = in->avail(); - - in->readBytes((uint8_t*)data, size); - } catch (end_of_stream&) { - return 0; - } catch (std::exception& e) { - core::socket_error* se; - vlog.error("Failure reading TLS data: %s", e.what()); - se = dynamic_cast<core::socket_error*>(&e); - if (se) - gnutls_transport_set_errno(self->session, se->err); - else - gnutls_transport_set_errno(self->session, EINVAL); - self->saved_exception = std::current_exception(); - return -1; - } - - return size; -} - -TLSInStream::TLSInStream(InStream* _in, gnutls_session_t _session) - : session(_session), in(_in) +TLSInStream::TLSInStream(TLSSocket* sock_) + : sock(sock_) { - gnutls_transport_ptr_t recv, send; - - gnutls_transport_set_pull_function(session, pull); - gnutls_transport_get_ptr2(session, &recv, &send); - gnutls_transport_set_ptr2(session, this, send); } TLSInStream::~TLSInStream() { - gnutls_transport_set_pull_function(session, nullptr); } bool TLSInStream::fillBuffer() { - size_t n = readTLS((uint8_t*) end, availSpace()); + size_t n = sock->readTLS((uint8_t*) end, availSpace()); if (n == 0) return false; end += n; @@ -97,35 +49,4 @@ bool TLSInStream::fillBuffer() return true; } -size_t TLSInStream::readTLS(uint8_t* buf, size_t len) -{ - int n; - - while (true) { - streamEmpty = false; - n = gnutls_record_recv(session, (void *) buf, len); - if (n == GNUTLS_E_INTERRUPTED || n == GNUTLS_E_AGAIN) { - // GnuTLS returns GNUTLS_E_AGAIN for a bunch of other scenarios - // other than the pull function returning EAGAIN, so we have to - // double check that the underlying stream really is empty - if (!streamEmpty) - continue; - else - return 0; - } - break; - }; - - if (n == GNUTLS_E_PULL_ERROR) - std::rethrow_exception(saved_exception); - - if (n < 0) - throw tls_error("readTLS", n); - - if (n == 0) - throw end_of_stream(); - - return n; -} - #endif diff --git a/common/rdr/TLSInStream.h b/common/rdr/TLSInStream.h index bc9c74ba..94266e50 100644 --- a/common/rdr/TLSInStream.h +++ b/common/rdr/TLSInStream.h @@ -1,5 +1,6 @@ /* Copyright (C) 2005 Martin Koegler * Copyright (C) 2010 TigerVNC Team + * Copyright 2012-2025 Pierre Ossman for Cendio AB * * This is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by @@ -22,28 +23,25 @@ #ifdef HAVE_GNUTLS -#include <gnutls/gnutls.h> #include <rdr/BufferedInStream.h> namespace rdr { + class TLSSocket; + class TLSInStream : public BufferedInStream { public: - TLSInStream(InStream* in, gnutls_session_t session); + TLSInStream(TLSSocket* sock); virtual ~TLSInStream(); private: bool fillBuffer() override; - size_t readTLS(uint8_t* buf, size_t len); - static ssize_t pull(gnutls_transport_ptr_t str, void* data, size_t size); - - gnutls_session_t session; - InStream* in; - bool streamEmpty; - std::exception_ptr saved_exception; + TLSSocket* sock; }; -}; + +} #endif + #endif diff --git a/common/rdr/TLSOutStream.cxx b/common/rdr/TLSOutStream.cxx index c3ae2d0a..ba9d182f 100644 --- a/common/rdr/TLSOutStream.cxx +++ b/common/rdr/TLSOutStream.cxx @@ -1,7 +1,7 @@ /* Copyright (C) 2002-2005 RealVNC Ltd. All Rights Reserved. * Copyright (C) 2005 Martin Koegler * Copyright (C) 2010 TigerVNC Team - * Copyright (C) 2012-2021 Pierre Ossman for Cendio AB + * Copyright 2012-2025 Pierre Ossman for Cendio AB * * This is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by @@ -23,103 +23,42 @@ #include <config.h> #endif -#include <core/Exception.h> -#include <core/LogWriter.h> - -#include <rdr/TLSException.h> #include <rdr/TLSOutStream.h> - -#include <errno.h> +#include <rdr/TLSSocket.h> #ifdef HAVE_GNUTLS -using namespace rdr; -static core::LogWriter vlog("TLSOutStream"); +using namespace rdr; -ssize_t TLSOutStream::push(gnutls_transport_ptr_t str, const void* data, - size_t size) +TLSOutStream::TLSOutStream(TLSSocket* sock_) + : sock(sock_) { - TLSOutStream* self= (TLSOutStream*) str; - OutStream *out = self->out; - - self->saved_exception = nullptr; - - try { - out->writeBytes((const uint8_t*)data, size); - out->flush(); - } catch (std::exception& e) { - core::socket_error* se; - vlog.error("Failure sending TLS data: %s", e.what()); - se = dynamic_cast<core::socket_error*>(&e); - if (se) - gnutls_transport_set_errno(self->session, se->err); - else - gnutls_transport_set_errno(self->session, EINVAL); - self->saved_exception = std::current_exception(); - return -1; - } - - return size; -} - -TLSOutStream::TLSOutStream(OutStream* _out, gnutls_session_t _session) - : session(_session), out(_out) -{ - gnutls_transport_ptr_t recv, send; - - gnutls_transport_set_push_function(session, push); - gnutls_transport_get_ptr2(session, &recv, &send); - gnutls_transport_set_ptr2(session, recv, this); } TLSOutStream::~TLSOutStream() { -#if 0 - try { -// flush(); - } catch (Exception&) { - } -#endif - gnutls_transport_set_push_function(session, nullptr); } void TLSOutStream::flush() { BufferedOutStream::flush(); - out->flush(); + sock->out->flush(); } void TLSOutStream::cork(bool enable) { BufferedOutStream::cork(enable); - out->cork(enable); + sock->out->cork(enable); } bool TLSOutStream::flushBuffer() { while (sentUpTo < ptr) { - size_t n = writeTLS(sentUpTo, ptr - sentUpTo); + size_t n = sock->writeTLS(sentUpTo, ptr - sentUpTo); sentUpTo += n; } return true; } -size_t TLSOutStream::writeTLS(const uint8_t* data, size_t length) -{ - int n; - - n = gnutls_record_send(session, data, length); - if (n == GNUTLS_E_INTERRUPTED || n == GNUTLS_E_AGAIN) - return 0; - - if (n == GNUTLS_E_PUSH_ERROR) - std::rethrow_exception(saved_exception); - - if (n < 0) - throw tls_error("writeTLS", n); - - return n; -} - #endif diff --git a/common/rdr/TLSOutStream.h b/common/rdr/TLSOutStream.h index 0ae9c460..aa9572ba 100644 --- a/common/rdr/TLSOutStream.h +++ b/common/rdr/TLSOutStream.h @@ -21,14 +21,16 @@ #define __RDR_TLSOUTSTREAM_H__ #ifdef HAVE_GNUTLS -#include <gnutls/gnutls.h> + #include <rdr/BufferedOutStream.h> namespace rdr { + class TLSSocket; + class TLSOutStream : public BufferedOutStream { public: - TLSOutStream(OutStream* out, gnutls_session_t session); + TLSOutStream(TLSSocket* out); virtual ~TLSOutStream(); void flush() override; @@ -36,15 +38,12 @@ namespace rdr { private: bool flushBuffer() override; - size_t writeTLS(const uint8_t* data, size_t length); - static ssize_t push(gnutls_transport_ptr_t str, const void* data, size_t size); - - gnutls_session_t session; - OutStream* out; - std::exception_ptr saved_exception; + TLSSocket* sock; }; -}; + +} #endif + #endif diff --git a/common/rdr/TLSSocket.cxx b/common/rdr/TLSSocket.cxx new file mode 100644 index 00000000..a29e41c1 --- /dev/null +++ b/common/rdr/TLSSocket.cxx @@ -0,0 +1,228 @@ +/* Copyright (C) 2002-2005 RealVNC Ltd. All Rights Reserved. + * Copyright (C) 2005 Martin Koegler + * Copyright (C) 2010 TigerVNC Team + * Copyright (C) 2012-2025 Pierre Ossman for Cendio AB + * + * This is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This software is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this software; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, + * USA. + */ + +#ifdef HAVE_CONFIG_H +#include <config.h> +#endif + +#include <core/Exception.h> +#include <core/LogWriter.h> + +#include <rdr/InStream.h> +#include <rdr/OutStream.h> +#include <rdr/TLSException.h> +#include <rdr/TLSSocket.h> + +#include <errno.h> + +#ifdef HAVE_GNUTLS + +using namespace rdr; + +static core::LogWriter vlog("TLSSocket"); + +TLSSocket::TLSSocket(InStream* in_, OutStream* out_, + gnutls_session_t session_) + : session(session_), in(in_), out(out_), tlsin(this), tlsout(this) +{ + gnutls_transport_set_pull_function( + session, [](gnutls_transport_ptr_t sock, void* data, size_t size) { + return ((TLSSocket*)sock)->pull(data, size); + }); + gnutls_transport_set_push_function( + session, [](gnutls_transport_ptr_t sock, const void* data, size_t size) { + return ((TLSSocket*)sock)->push(data, size); + }); + gnutls_transport_set_ptr(session, this); +} + +TLSSocket::~TLSSocket() +{ + gnutls_transport_set_pull_function(session, nullptr); + gnutls_transport_set_push_function(session, nullptr); + gnutls_transport_set_ptr(session, nullptr); +} + +bool TLSSocket::handshake() +{ + int err; + + err = gnutls_handshake(session); + if (err != GNUTLS_E_SUCCESS) { + gnutls_alert_description_t alert; + const char* msg; + + if ((err == GNUTLS_E_PULL_ERROR) || (err == GNUTLS_E_PUSH_ERROR)) + std::rethrow_exception(saved_exception); + + alert = gnutls_alert_get(session); + msg = nullptr; + + if ((err == GNUTLS_E_WARNING_ALERT_RECEIVED) || + (err == GNUTLS_E_FATAL_ALERT_RECEIVED)) + msg = gnutls_alert_get_name(alert); + + if (msg == nullptr) + msg = gnutls_strerror(err); + + if (!gnutls_error_is_fatal(err)) { + vlog.debug("Deferring completion of TLS handshake: %s", msg); + return false; + } + + vlog.error("TLS Handshake failed: %s\n", msg); + gnutls_alert_send_appropriate(session, err); + throw rdr::tls_error("TLS Handshake failed", err, alert); + } + + return true; +} + +void TLSSocket::shutdown() +{ + int ret; + + try { + if (tlsout.hasBufferedData()) { + tlsout.cork(false); + tlsout.flush(); + if (tlsout.hasBufferedData()) + vlog.error("Failed to flush remaining socket data on close"); + } + } catch (std::exception& e) { + vlog.error("Failed to flush remaining socket data on close: %s", e.what()); + } + + // FIXME: We can't currently wait for the response, so we only send + // our close and hope for the best + ret = gnutls_bye(session, GNUTLS_SHUT_WR); + if ((ret != GNUTLS_E_SUCCESS) && (ret != GNUTLS_E_INVALID_SESSION)) + vlog.error("TLS shutdown failed: %s", gnutls_strerror(ret)); +} + +size_t TLSSocket::readTLS(uint8_t* buf, size_t len) +{ + int n; + + while (true) { + streamEmpty = false; + n = gnutls_record_recv(session, (void *) buf, len); + if (n == GNUTLS_E_INTERRUPTED || n == GNUTLS_E_AGAIN) { + // GnuTLS returns GNUTLS_E_AGAIN for a bunch of other scenarios + // other than the pull function returning EAGAIN, so we have to + // double check that the underlying stream really is empty + if (!streamEmpty) + continue; + else + return 0; + } + break; + }; + + if (n == GNUTLS_E_PULL_ERROR) + std::rethrow_exception(saved_exception); + + if (n < 0) { + gnutls_alert_send_appropriate(session, n); + throw tls_error("readTLS", n, gnutls_alert_get(session)); + } + + if (n == 0) + throw end_of_stream(); + + return n; +} + +size_t TLSSocket::writeTLS(const uint8_t* data, size_t length) +{ + int n; + + n = gnutls_record_send(session, data, length); + if (n == GNUTLS_E_INTERRUPTED || n == GNUTLS_E_AGAIN) + return 0; + + if (n == GNUTLS_E_PUSH_ERROR) + std::rethrow_exception(saved_exception); + + if (n < 0) { + gnutls_alert_send_appropriate(session, n); + throw tls_error("writeTLS", n, gnutls_alert_get(session)); + } + + return n; +} + +ssize_t TLSSocket::pull(void* data, size_t size) +{ + streamEmpty = false; + saved_exception = nullptr; + + try { + if (!in->hasData(1)) { + streamEmpty = true; + gnutls_transport_set_errno(session, EAGAIN); + return -1; + } + + if (in->avail() < size) + size = in->avail(); + + in->readBytes((uint8_t*)data, size); + } catch (end_of_stream&) { + return 0; + } catch (std::exception& e) { + core::socket_error* se; + vlog.error("Failure reading TLS data: %s", e.what()); + se = dynamic_cast<core::socket_error*>(&e); + if (se) + gnutls_transport_set_errno(session, se->err); + else + gnutls_transport_set_errno(session, EINVAL); + saved_exception = std::current_exception(); + return -1; + } + + return size; +} + +ssize_t TLSSocket::push(const void* data, size_t size) +{ + saved_exception = nullptr; + + try { + out->writeBytes((const uint8_t*)data, size); + out->flush(); + } catch (std::exception& e) { + core::socket_error* se; + vlog.error("Failure sending TLS data: %s", e.what()); + se = dynamic_cast<core::socket_error*>(&e); + if (se) + gnutls_transport_set_errno(session, se->err); + else + gnutls_transport_set_errno(session, EINVAL); + saved_exception = std::current_exception(); + return -1; + } + + return size; +} + +#endif diff --git a/common/rdr/TLSSocket.h b/common/rdr/TLSSocket.h new file mode 100644 index 00000000..ca29f8bc --- /dev/null +++ b/common/rdr/TLSSocket.h @@ -0,0 +1,81 @@ +/* Copyright (C) 2005 Martin Koegler + * Copyright (C) 2010 TigerVNC Team + * Copyright 2012-2025 Pierre Ossman for Cendio AB + * + * This is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This software is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this software; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, + * USA. + */ + +#ifndef __RDR_TLSSOCKET_H__ +#define __RDR_TLSSOCKET_H__ + +#ifdef HAVE_GNUTLS + +#include <exception> + +#include <gnutls/gnutls.h> + +#include <rdr/TLSInStream.h> +#include <rdr/TLSOutStream.h> + +namespace rdr { + + class InStream; + class OutStream; + + class TLSInStream; + class TLSOutStream; + + class TLSSocket { + public: + TLSSocket(InStream* in, OutStream* out, gnutls_session_t session); + virtual ~TLSSocket(); + + TLSInStream& inStream() { return tlsin; } + TLSOutStream& outStream() { return tlsout; } + + bool handshake(); + void shutdown(); + + protected: + /* Used by the stream classes */ + size_t readTLS(uint8_t* buf, size_t len); + size_t writeTLS(const uint8_t* data, size_t length); + + friend TLSInStream; + friend TLSOutStream; + + private: + ssize_t pull(void* data, size_t size); + ssize_t push(const void* data, size_t size); + + gnutls_session_t session; + + InStream* in; + OutStream* out; + + TLSInStream tlsin; + TLSOutStream tlsout; + + bool streamEmpty; + + std::exception_ptr saved_exception; + }; + +} + +#endif + +#endif |