From: Pierre Ossman Date: Thu, 14 May 2020 16:49:39 +0000 (+0200) Subject: Change streams to be asynchronous X-Git-Tag: v1.11.90~74^2~1 X-Git-Url: https://source.dussan.org/?a=commitdiff_plain;h=ad0f0618fa2ca13d7b916f22eccc5ba3201482cb;p=tigervnc.git Change streams to be asynchronous Major restructuring of how streams work. Neither input nor output streams are now blocking. This avoids stalling the rest of the client or server when a peer is slow or unresponsive. Note that this puts an extra burden on users of streams to make sure they are allowed to do their work once the underlying transports are ready (e.g. monitoring fds). --- diff --git a/common/rdr/BufferedInStream.cxx b/common/rdr/BufferedInStream.cxx index 14b73563..5a2694b4 100644 --- a/common/rdr/BufferedInStream.cxx +++ b/common/rdr/BufferedInStream.cxx @@ -47,7 +47,7 @@ size_t BufferedInStream::pos() return offset + ptr - start; } -bool BufferedInStream::overrun(size_t needed, bool wait) +bool BufferedInStream::overrun(size_t needed) { struct timeval now; @@ -112,7 +112,7 @@ bool BufferedInStream::overrun(size_t needed, bool wait) } while (avail() < needed) { - if (!fillBuffer(start + bufSize - end, wait)) + if (!fillBuffer(start + bufSize - end)) return false; } diff --git a/common/rdr/BufferedInStream.h b/common/rdr/BufferedInStream.h index 24b5a23c..84405255 100644 --- a/common/rdr/BufferedInStream.h +++ b/common/rdr/BufferedInStream.h @@ -38,9 +38,9 @@ namespace rdr { virtual size_t pos(); private: - virtual bool fillBuffer(size_t maxSize, bool wait) = 0; + virtual bool fillBuffer(size_t maxSize) = 0; - virtual bool overrun(size_t needed, bool wait); + virtual bool overrun(size_t needed); private: size_t bufSize; diff --git a/common/rdr/BufferedOutStream.cxx b/common/rdr/BufferedOutStream.cxx index 930b80b9..c8f6ce9c 100644 --- a/common/rdr/BufferedOutStream.cxx +++ b/common/rdr/BufferedOutStream.cxx @@ -60,7 +60,7 @@ void BufferedOutStream::flush() len = (ptr - sentUpTo); - if (!flushBuffer(false)) + if (!flushBuffer()) break; offset += len - (ptr - sentUpTo); @@ -148,4 +148,6 @@ void BufferedOutStream::overrun(size_t needed) gettimeofday(&lastSizeCheck, NULL); peakUsage = totalNeeded; + + return; } diff --git a/common/rdr/BufferedOutStream.h b/common/rdr/BufferedOutStream.h index dd64a136..b01d1fee 100644 --- a/common/rdr/BufferedOutStream.h +++ b/common/rdr/BufferedOutStream.h @@ -45,10 +45,9 @@ namespace rdr { private: // flushBuffer() requests that the stream be flushed. Returns true if it is // able to progress the output (which might still not mean any bytes - // actually moved) and can be called again. If wait is true then it will - // block until all data has been written. + // actually moved) and can be called again. - virtual bool flushBuffer(bool wait) = 0; + virtual bool flushBuffer() = 0; virtual void overrun(size_t needed); diff --git a/common/rdr/Exception.h b/common/rdr/Exception.h index eb3c8a9d..e5bff80d 100644 --- a/common/rdr/Exception.h +++ b/common/rdr/Exception.h @@ -47,10 +47,6 @@ namespace rdr { GAIException(const char* s, int err_); }; - struct TimedOut : public Exception { - TimedOut() : Exception("Timed out") {} - }; - struct EndOfStream : public Exception { EndOfStream() : Exception("End of stream") {} }; diff --git a/common/rdr/FdInStream.cxx b/common/rdr/FdInStream.cxx index 27de92bb..ecc34ecd 100644 --- a/common/rdr/FdInStream.cxx +++ b/common/rdr/FdInStream.cxx @@ -46,17 +46,8 @@ using namespace rdr; -enum { DEFAULT_BUF_SIZE = 8192 }; - -FdInStream::FdInStream(int fd_, int timeoutms_, - bool closeWhenDone_) - : fd(fd_), closeWhenDone(closeWhenDone_), - timeoutms(timeoutms_), blockCallback(0) -{ -} - -FdInStream::FdInStream(int fd_, FdInStreamBlockCallback* blockCallback_) - : fd(fd_), timeoutms(0), blockCallback(blockCallback_) +FdInStream::FdInStream(int fd_, bool closeWhenDone_) + : fd(fd_), closeWhenDone(closeWhenDone_) { } @@ -66,20 +57,9 @@ FdInStream::~FdInStream() } -void FdInStream::setTimeout(int timeoutms_) { - timeoutms = timeoutms_; -} - -void FdInStream::setBlockCallback(FdInStreamBlockCallback* blockCallback_) -{ - blockCallback = blockCallback_; - timeoutms = 0; -} - - -bool FdInStream::fillBuffer(size_t maxSize, bool wait) +bool FdInStream::fillBuffer(size_t maxSize) { - size_t n = readWithTimeoutOrCallback((U8*)end, maxSize, wait); + size_t n = readFd((U8*)end, maxSize); if (n == 0) return false; end += n; @@ -88,55 +68,43 @@ bool FdInStream::fillBuffer(size_t maxSize, bool wait) } // -// readWithTimeoutOrCallback() reads up to the given length in bytes from the -// file descriptor into a buffer. If the wait argument is false, then zero is -// returned if no bytes can be read without blocking. Otherwise if a -// blockCallback is set, it will be called (repeatedly) instead of blocking. -// If alternatively there is a timeout set and that timeout expires, it throws -// a TimedOut exception. Otherwise it returns the number of bytes read. It +// readFd() reads up to the given length in bytes from the +// file descriptor into a buffer. Zero is +// returned if no bytes can be read. Otherwise it returns the number of bytes read. It // never attempts to recv() unless select() indicates that the fd is readable - // this means it can be used on an fd which has been set non-blocking. It also // has to cope with the annoying possibility of both select() and recv() // returning EINTR. // -size_t FdInStream::readWithTimeoutOrCallback(void* buf, size_t len, bool wait) +size_t FdInStream::readFd(void* buf, size_t len) { int n; - while (true) { - do { - fd_set fds; - struct timeval tv; - struct timeval* tvp = &tv; - - if (!wait) { - tv.tv_sec = tv.tv_usec = 0; - } else if (timeoutms != -1) { - tv.tv_sec = timeoutms / 1000; - tv.tv_usec = (timeoutms % 1000) * 1000; - } else { - tvp = 0; - } - - FD_ZERO(&fds); - FD_SET(fd, &fds); - n = select(fd+1, &fds, 0, 0, tvp); - } while (n < 0 && errno == EINTR); - - if (n > 0) break; - if (n < 0) throw SystemException("select",errno); - if (!wait) return 0; - if (!blockCallback) throw TimedOut(); - - blockCallback->blockCallback(); - } + do { + fd_set fds; + struct timeval tv; + + tv.tv_sec = tv.tv_usec = 0; + + FD_ZERO(&fds); + FD_SET(fd, &fds); + n = select(fd+1, &fds, 0, 0, &tv); + } while (n < 0 && errno == EINTR); + + if (n < 0) + throw SystemException("select",errno); + + if (n == 0) + return 0; do { n = ::recv(fd, (char*)buf, len, 0); } while (n < 0 && errno == EINTR); - if (n < 0) throw SystemException("read",errno); - if (n == 0) throw EndOfStream(); + if (n < 0) + throw SystemException("read",errno); + if (n == 0) + throw EndOfStream(); return n; } diff --git a/common/rdr/FdInStream.h b/common/rdr/FdInStream.h index 0203389b..f732ceaa 100644 --- a/common/rdr/FdInStream.h +++ b/common/rdr/FdInStream.h @@ -27,33 +27,22 @@ namespace rdr { - class FdInStreamBlockCallback { - public: - virtual void blockCallback() = 0; - virtual ~FdInStreamBlockCallback() {} - }; - class FdInStream : public BufferedInStream { public: - FdInStream(int fd, int timeoutms=-1, bool closeWhenDone_=false); - FdInStream(int fd, FdInStreamBlockCallback* blockCallback); + FdInStream(int fd, bool closeWhenDone_=false); virtual ~FdInStream(); - void setTimeout(int timeoutms); - void setBlockCallback(FdInStreamBlockCallback* blockCallback); int getFd() { return fd; } private: - virtual bool fillBuffer(size_t maxSize, bool wait); + virtual bool fillBuffer(size_t maxSize); - size_t readWithTimeoutOrCallback(void* buf, size_t len, bool wait=true); + size_t readFd(void* buf, size_t len); int fd; bool closeWhenDone; - int timeoutms; - FdInStreamBlockCallback* blockCallback; size_t offset; U8* start; diff --git a/common/rdr/FdOutStream.cxx b/common/rdr/FdOutStream.cxx index 3405838d..b52fc85d 100644 --- a/common/rdr/FdOutStream.cxx +++ b/common/rdr/FdOutStream.cxx @@ -49,27 +49,14 @@ using namespace rdr; -FdOutStream::FdOutStream(int fd_, bool blocking_, int timeoutms_) - : fd(fd_), blocking(blocking_), timeoutms(timeoutms_) +FdOutStream::FdOutStream(int fd_) + : fd(fd_) { gettimeofday(&lastWrite, NULL); } FdOutStream::~FdOutStream() { - try { - while (sentUpTo != ptr) - flushBuffer(true); - } catch (Exception&) { - } -} - -void FdOutStream::setTimeout(int timeoutms_) { - timeoutms = timeoutms_; -} - -void FdOutStream::setBlocking(bool blocking_) { - blocking = blocking_; } unsigned FdOutStream::getIdleTime() @@ -87,20 +74,11 @@ void FdOutStream::cork(bool enable) #endif } -bool FdOutStream::flushBuffer(bool wait) +bool FdOutStream::flushBuffer() { - size_t n = writeWithTimeout((const void*) sentUpTo, - ptr - sentUpTo, - (blocking || wait)? timeoutms : 0); - - // Timeout? - if (n == 0) { - // If non-blocking then we're done here - if (!blocking && !wait) - return false; - - throw TimedOut(); - } + size_t n = writeFd((const void*) sentUpTo, ptr - sentUpTo); + if (n == 0) + return false; sentUpTo += n; @@ -108,34 +86,27 @@ bool FdOutStream::flushBuffer(bool wait) } // -// writeWithTimeout() writes up to the given length in bytes from the given -// buffer to the file descriptor. If there is a timeout set and that timeout -// expires, it throws a TimedOut exception. Otherwise it returns the number of -// bytes written. It never attempts to send() unless select() indicates that -// the fd is writable - this means it can be used on an fd which has been set -// non-blocking. It also has to cope with the annoying possibility of both -// select() and send() returning EINTR. +// writeFd() writes up to the given length in bytes from the given +// buffer to the file descriptor. It returns the number of bytes written. It +// never attempts to send() unless select() indicates that the fd is writable +// - this means it can be used on an fd which has been set non-blocking. It +// also has to cope with the annoying possibility of both select() and send() +// returning EINTR. // -size_t FdOutStream::writeWithTimeout(const void* data, size_t length, int timeoutms) +size_t FdOutStream::writeFd(const void* data, size_t length) { int n; do { fd_set fds; struct timeval tv; - struct timeval* tvp = &tv; - if (timeoutms != -1) { - tv.tv_sec = timeoutms / 1000; - tv.tv_usec = (timeoutms % 1000) * 1000; - } else { - tvp = NULL; - } + tv.tv_sec = tv.tv_usec = 0; FD_ZERO(&fds); FD_SET(fd, &fds); - n = select(fd+1, 0, &fds, 0, tvp); + n = select(fd+1, 0, &fds, 0, &tv); } while (n < 0 && errno == EINTR); if (n < 0) diff --git a/common/rdr/FdOutStream.h b/common/rdr/FdOutStream.h index b1fb74c0..80804da4 100644 --- a/common/rdr/FdOutStream.h +++ b/common/rdr/FdOutStream.h @@ -34,11 +34,9 @@ namespace rdr { public: - FdOutStream(int fd, bool blocking=true, int timeoutms=-1); + FdOutStream(int fd); virtual ~FdOutStream(); - void setTimeout(int timeoutms); - void setBlocking(bool blocking); int getFd() { return fd; } unsigned getIdleTime(); @@ -46,11 +44,9 @@ namespace rdr { virtual void cork(bool enable); private: - virtual bool flushBuffer(bool wait); - size_t writeWithTimeout(const void* data, size_t length, int timeoutms); + virtual bool flushBuffer(); + size_t writeFd(const void* data, size_t length); int fd; - bool blocking; - int timeoutms; struct timeval lastWrite; }; diff --git a/common/rdr/FileInStream.cxx b/common/rdr/FileInStream.cxx index 66dfe766..9975fde6 100644 --- a/common/rdr/FileInStream.cxx +++ b/common/rdr/FileInStream.cxx @@ -39,7 +39,7 @@ FileInStream::~FileInStream(void) { } } -bool FileInStream::fillBuffer(size_t maxSize, bool wait) +bool FileInStream::fillBuffer(size_t maxSize) { size_t n = fread((U8 *)end, 1, maxSize, file); if (n == 0) { diff --git a/common/rdr/FileInStream.h b/common/rdr/FileInStream.h index 268f5375..619397f0 100644 --- a/common/rdr/FileInStream.h +++ b/common/rdr/FileInStream.h @@ -34,7 +34,7 @@ namespace rdr { ~FileInStream(void); private: - virtual bool fillBuffer(size_t maxSize, bool wait); + virtual bool fillBuffer(size_t maxSize); private: FILE *file; diff --git a/common/rdr/HexInStream.cxx b/common/rdr/HexInStream.cxx index 322432c0..66bbf174 100644 --- a/common/rdr/HexInStream.cxx +++ b/common/rdr/HexInStream.cxx @@ -73,7 +73,7 @@ decodeError: bool HexInStream::fillBuffer(size_t maxSize, bool wait) { - if (!in_stream.check(2, wait)) + if (!in_stream.hasData(2)) return false; size_t length = min(in_stream.avail()/2, maxSize); diff --git a/common/rdr/InStream.h b/common/rdr/InStream.h index 5d873011..60ea4997 100644 --- a/common/rdr/InStream.h +++ b/common/rdr/InStream.h @@ -1,4 +1,5 @@ /* Copyright (C) 2002-2005 RealVNC Ltd. All Rights Reserved. + * Copyright 2014-2020 Pierre Ossman for Cendio AB * * This is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by @@ -28,6 +29,10 @@ #include #include // for memcpy +// Check that callers are using InStream properly, +// useful when writing new protocol handling +#undef RFB_INSTREAM_CHECK + namespace rdr { class InStream { @@ -39,29 +44,79 @@ namespace rdr { // avail() returns the number of bytes that are currenctly directly // available from the stream. - inline size_t avail() - { + inline size_t avail() { +#ifdef RFB_INSTREAM_CHECK + checkedBytes = end - ptr; +#endif + return end - ptr; } - // check() ensures there is buffer data for at least needed bytes. Returns - // true once the data is available. If wait is false, then instead of - // blocking to wait for the bytes, false is returned if the bytes are not - // immediately available. + // hasData() ensures there is at least "length" bytes of buffer data, + // possibly trying to fetch more data if there isn't enough right away + + inline bool hasData(size_t length) { +#ifdef RFB_INSTREAM_CHECK + checkedBytes = 0; +#endif + + if (length > (size_t)(end - ptr)) { + if (restorePoint != NULL) { + bool ret; + size_t restoreDiff; + + restoreDiff = ptr - restorePoint; + ptr = restorePoint; + + ret = overrun(length + restoreDiff); - inline size_t check(size_t needed, bool wait=true) - { - if (needed > avail()) - return overrun(needed, wait); + restorePoint = ptr; + ptr += restoreDiff; + + if (!ret) + return false; + } else { + if (!overrun(length)) + return false; + } + } + +#ifdef RFB_INSTREAM_CHECK + checkedBytes = length; +#endif return true; } - // checkNoWait() tries to make sure that the given number of bytes can - // be read without blocking. It returns true if this is the case, false - // otherwise. The length must be "small" (less than the buffer size). + inline bool hasDataOrRestore(size_t length) { + if (hasData(length)) + return true; + gotoRestorePoint(); + return false; + } - inline bool checkNoWait(size_t length) { return check(length, false); } + inline void setRestorePoint() { +#ifdef RFB_INSTREAM_CHECK + if (restorePoint != NULL) + throw Exception("Nested use of input stream restore point"); +#endif + restorePoint = ptr; + } + inline void clearRestorePoint() { +#ifdef RFB_INSTREAM_CHECK + if (restorePoint == NULL) + throw Exception("Incorrect clearing of input stream restore point"); +#endif + restorePoint = NULL; + } + inline void gotoRestorePoint() { +#ifdef RFB_INSTREAM_CHECK + if (restorePoint == NULL) + throw Exception("Incorrect activation of input stream restore point"); +#endif + ptr = restorePoint; + clearRestorePoint(); + } // readU/SN() methods read unsigned and signed N-bit integers. @@ -76,24 +131,19 @@ namespace rdr { inline S16 readS16() { return (S16)readU16(); } inline S32 readS32() { return (S32)readU32(); } + // skip() ignores a number of bytes on the stream + inline void skip(size_t bytes) { - while (bytes > 0) { - size_t n = check(1, bytes); - ptr += n; - bytes -= n; - } + check(bytes); + ptr += bytes; } // readBytes() reads an exact number of bytes. void readBytes(void* data, size_t length) { - while (length > 0) { - size_t n = check(1, length); - memcpy(data, ptr, n); - ptr += n; - data = (U8*)data + n; - length -= n; - } + check(length); + memcpy(data, ptr, length); + ptr += length; } // readOpaqueN() reads a quantity without byte-swapping. @@ -113,24 +163,45 @@ namespace rdr { // to the buffer. This is useful for a stream which is a wrapper around an // some other stream API. - inline const U8* getptr(size_t length) { check(length); return ptr; } + inline const U8* getptr(size_t length) { check(length); +#ifdef RFB_INSTREAM_CHECK + checkedBytes += length; +#endif + return ptr; } inline void setptr(size_t length) { if (length > avail()) throw Exception("Input stream overflow"); skip(length); } private: + const U8* restorePoint; +#ifdef RFB_INSTREAM_CHECK + size_t checkedBytes; +#endif + + inline void check(size_t bytes) { +#ifdef RFB_INSTREAM_CHECK + if (bytes > checkedBytes) + throw Exception("Input stream used without underrun check"); + checkedBytes -= bytes; +#endif + if (bytes > (size_t)(end - ptr)) + throw Exception("InStream buffer underrun"); + } + // overrun() is implemented by a derived class to cope with buffer overrun. - // It ensures there are at least needed bytes of buffer data. Returns true - // once the data is available. If wait is false, then instead of blocking - // to wait for the bytes, false is returned if the bytes are not - // immediately available. + // It tries to ensure there are at least needed bytes of buffer data. + // Returns true if it managed to satisfy the request, or false otherwise. - virtual bool overrun(size_t needed, bool wait=true) = 0; + virtual bool overrun(size_t needed) = 0; protected: - InStream() {} + InStream() : restorePoint(NULL) +#ifdef RFB_INSTREAM_CHECK + ,checkedBytes(0) +#endif + {} const U8* ptr; const U8* end; }; diff --git a/common/rdr/MemInStream.h b/common/rdr/MemInStream.h index 83740dd9..a5196594 100644 --- a/common/rdr/MemInStream.h +++ b/common/rdr/MemInStream.h @@ -41,6 +41,12 @@ namespace rdr { { ptr = start; end = start + len; + +#ifdef RFB_INSTREAM_CHECK + // MemInStream cannot add more data, so callers are assumed to already + // new the total size + avail(); +#endif } virtual ~MemInStream() { @@ -53,7 +59,7 @@ namespace rdr { private: - bool overrun(size_t needed, bool wait) { throw EndOfStream(); } + bool overrun(size_t needed) { throw EndOfStream(); } const U8* start; bool deleteWhenDone; }; diff --git a/common/rdr/OutStream.h b/common/rdr/OutStream.h index f432520f..61d5100b 100644 --- a/common/rdr/OutStream.h +++ b/common/rdr/OutStream.h @@ -49,14 +49,6 @@ namespace rdr { return end - ptr; } - // check() ensures there is buffer space for at least needed bytes. - - inline void check(size_t needed) - { - if (needed > avail()) - overrun(needed); - } - // writeU/SN() methods write unsigned and signed N-bit integers. inline void writeU8( U8 u) { check(1); *ptr++ = u; } @@ -136,6 +128,12 @@ namespace rdr { private: + inline void check(size_t length) + { + if (length > avail()) + overrun(length); + } + // overrun() is implemented by a derived class to cope with buffer overrun. // It ensures there are at least needed bytes of buffer space. diff --git a/common/rdr/RandomStream.cxx b/common/rdr/RandomStream.cxx index 6333be3f..e2da0957 100644 --- a/common/rdr/RandomStream.cxx +++ b/common/rdr/RandomStream.cxx @@ -79,7 +79,7 @@ RandomStream::~RandomStream() { #endif } -bool RandomStream::fillBuffer(size_t maxSize, bool wait) { +bool RandomStream::fillBuffer(size_t maxSize) { #ifdef RFB_HAVE_WINCRYPT if (provider) { if (!CryptGenRandom(provider, maxSize, (U8*)end)) diff --git a/common/rdr/RandomStream.h b/common/rdr/RandomStream.h index 08ae0ff6..58986433 100644 --- a/common/rdr/RandomStream.h +++ b/common/rdr/RandomStream.h @@ -40,7 +40,7 @@ namespace rdr { virtual ~RandomStream(); private: - virtual bool fillBuffer(size_t maxSize, bool wait); + virtual bool fillBuffer(size_t maxSize); private: static unsigned int seed; diff --git a/common/rdr/TLSInStream.cxx b/common/rdr/TLSInStream.cxx index ba20e752..2339956d 100644 --- a/common/rdr/TLSInStream.cxx +++ b/common/rdr/TLSInStream.cxx @@ -39,7 +39,7 @@ ssize_t TLSInStream::pull(gnutls_transport_ptr_t str, void* data, size_t size) InStream *in = self->in; try { - if (!in->check(1, false)) { + if (!in->hasData(1)) { gnutls_transport_set_errno(self->session, EAGAIN); return -1; } @@ -74,23 +74,22 @@ TLSInStream::~TLSInStream() gnutls_transport_set_pull_function(session, NULL); } -bool TLSInStream::fillBuffer(size_t maxSize, bool wait) +bool TLSInStream::fillBuffer(size_t maxSize) { - size_t n = readTLS((U8*) end, maxSize, wait); - if (!wait && n == 0) + size_t n = readTLS((U8*) end, maxSize); + if (n == 0) return false; end += n; return true; } -size_t TLSInStream::readTLS(U8* buf, size_t len, bool wait) +size_t TLSInStream::readTLS(U8* buf, size_t len) { int n; if (gnutls_record_check_pending(session) == 0) { - n = in->check(1, wait); - if (n == 0) + if (!in->hasData(1)) return 0; } diff --git a/common/rdr/TLSInStream.h b/common/rdr/TLSInStream.h index 9779c68e..df5ebb48 100644 --- a/common/rdr/TLSInStream.h +++ b/common/rdr/TLSInStream.h @@ -37,8 +37,8 @@ namespace rdr { virtual ~TLSInStream(); private: - virtual bool fillBuffer(size_t maxSize, bool wait); - size_t readTLS(U8* buf, size_t len, bool wait); + virtual bool fillBuffer(size_t maxSize); + size_t readTLS(U8* buf, size_t len); static ssize_t pull(gnutls_transport_ptr_t str, void* data, size_t size); gnutls_session_t session; diff --git a/common/rdr/ZlibInStream.cxx b/common/rdr/ZlibInStream.cxx index 26977228..0cacc21f 100644 --- a/common/rdr/ZlibInStream.cxx +++ b/common/rdr/ZlibInStream.cxx @@ -45,7 +45,7 @@ void ZlibInStream::setUnderlying(InStream* is, size_t bytesIn_) void ZlibInStream::flushUnderlying() { while (bytesIn > 0) { - if (!check(1)) + if (!hasData(1)) throw Exception("ZlibInStream: failed to flush remaining stream data"); skip(avail()); } @@ -85,7 +85,7 @@ void ZlibInStream::deinit() zs = NULL; } -bool ZlibInStream::fillBuffer(size_t maxSize, bool wait) +bool ZlibInStream::fillBuffer(size_t maxSize) { if (!underlying) throw Exception("ZlibInStream overrun: no underlying stream"); @@ -93,8 +93,8 @@ bool ZlibInStream::fillBuffer(size_t maxSize, bool wait) zs->next_out = (U8*)end; zs->avail_out = maxSize; - size_t n = underlying->check(1, wait); - if (n == 0) return false; + if (!underlying->hasData(1)) + return false; size_t length = underlying->avail(); if (length > bytesIn) length = bytesIn; diff --git a/common/rdr/ZlibInStream.h b/common/rdr/ZlibInStream.h index 1597b54a..302c99d8 100644 --- a/common/rdr/ZlibInStream.h +++ b/common/rdr/ZlibInStream.h @@ -44,7 +44,7 @@ namespace rdr { void init(); void deinit(); - virtual bool fillBuffer(size_t maxSize, bool wait); + virtual bool fillBuffer(size_t maxSize); private: InStream* underlying; diff --git a/common/rfb/CConnection.cxx b/common/rfb/CConnection.cxx index cb1b84bd..9c957a54 100644 --- a/common/rfb/CConnection.cxx +++ b/common/rfb/CConnection.cxx @@ -120,16 +120,19 @@ void CConnection::initialiseProtocol() state_ = RFBSTATE_PROTOCOL_VERSION; } -void CConnection::processMsg() +bool CConnection::processMsg() { switch (state_) { - case RFBSTATE_PROTOCOL_VERSION: processVersionMsg(); break; - case RFBSTATE_SECURITY_TYPES: processSecurityTypesMsg(); break; - case RFBSTATE_SECURITY: processSecurityMsg(); break; - case RFBSTATE_SECURITY_RESULT: processSecurityResultMsg(); break; - case RFBSTATE_INITIALISATION: processInitMsg(); break; - case RFBSTATE_NORMAL: reader_->readMsg(); break; + case RFBSTATE_PROTOCOL_VERSION: return processVersionMsg(); break; + case RFBSTATE_SECURITY_TYPES: return processSecurityTypesMsg(); break; + case RFBSTATE_SECURITY: return processSecurityMsg(); break; + case RFBSTATE_SECURITY_RESULT: return processSecurityResultMsg(); break; + case RFBSTATE_SECURITY_REASON: return processSecurityReasonMsg(); break; + case RFBSTATE_INITIALISATION: return processInitMsg(); break; + case RFBSTATE_NORMAL: return reader_->readMsg(); break; + case RFBSTATE_CLOSING: + throw Exception("CConnection::processMsg: called while closing"); case RFBSTATE_UNINITIALISED: throw Exception("CConnection::processMsg: not initialised yet?"); default: @@ -137,7 +140,7 @@ void CConnection::processMsg() } } -void CConnection::processVersionMsg() +bool CConnection::processVersionMsg() { char verStr[27]; // FIXME: gcc has some bug in format-overflow int majorVersion; @@ -145,8 +148,8 @@ void CConnection::processVersionMsg() vlog.debug("reading protocol version"); - if (!is->checkNoWait(12)) - return; + if (!is->hasData(12)) + return false; is->readBytes(verStr, 12); verStr[12] = '\0'; @@ -184,10 +187,12 @@ void CConnection::processVersionMsg() vlog.info("Using RFB protocol version %d.%d", server.majorVersion, server.minorVersion); + + return true; } -void CConnection::processSecurityTypesMsg() +bool CConnection::processSecurityTypesMsg() { vlog.debug("processing security types message"); @@ -200,10 +205,13 @@ void CConnection::processSecurityTypesMsg() // legacy 3.3 server may only offer "vnc authentication" or "none" + if (!is->hasData(4)) + return false; + secType = is->readU32(); if (secType == secTypeInvalid) { - throwConnFailedException(); - + state_ = RFBSTATE_SECURITY_REASON; + return true; } else if (secType == secTypeNone || secType == secTypeVncAuth) { std::list::iterator i; for (i = secTypes.begin(); i != secTypes.end(); i++) @@ -223,9 +231,21 @@ void CConnection::processSecurityTypesMsg() // >=3.7 server will offer us a list + if (!is->hasData(1)) + return false; + + is->setRestorePoint(); + int nServerSecTypes = is->readU8(); - if (nServerSecTypes == 0) - throwConnFailedException(); + + if (!is->hasDataOrRestore(nServerSecTypes)) + return false; + is->clearRestorePoint(); + + if (nServerSecTypes == 0) { + state_ = RFBSTATE_SECURITY_REASON; + return true; + } std::list::iterator j; @@ -263,32 +283,38 @@ void CConnection::processSecurityTypesMsg() state_ = RFBSTATE_SECURITY; csecurity = security.GetCSecurity(this, secType); - processSecurityMsg(); + + return true; } -void CConnection::processSecurityMsg() +bool CConnection::processSecurityMsg() { vlog.debug("processing security message"); - if (csecurity->processMsg()) { - state_ = RFBSTATE_SECURITY_RESULT; - processSecurityResultMsg(); - } + if (!csecurity->processMsg()) + return false; + + state_ = RFBSTATE_SECURITY_RESULT; + + return true; } -void CConnection::processSecurityResultMsg() +bool CConnection::processSecurityResultMsg() { vlog.debug("processing security result message"); int result; + if (server.beforeVersion(3,8) && csecurity->getType() == secTypeNone) { result = secResultOK; } else { - if (!is->checkNoWait(1)) return; + if (!is->hasData(4)) + return false; result = is->readU32(); } + switch (result) { case secResultOK: securityCompleted(); - return; + return true; case secResultFailed: vlog.debug("auth failed"); break; @@ -298,30 +324,42 @@ void CConnection::processSecurityResultMsg() default: throw Exception("Unknown security result from server"); } - state_ = RFBSTATE_INVALID; - if (server.beforeVersion(3,8)) + + if (server.beforeVersion(3,8)) { + state_ = RFBSTATE_INVALID; throw AuthFailureException(); + } + + state_ = RFBSTATE_SECURITY_REASON; + return true; +} + +bool CConnection::processSecurityReasonMsg() +{ + vlog.debug("processing security reason message"); + + if (!is->hasData(4)) + return false; + + is->setRestorePoint(); + rdr::U32 len = is->readU32(); + if (!is->hasDataOrRestore(len)) + return false; + is->clearRestorePoint(); + CharArray reason(len + 1); is->readBytes(reason.buf, len); reason.buf[len] = '\0'; + + state_ = RFBSTATE_INVALID; throw AuthFailureException(reason.buf); } -void CConnection::processInitMsg() +bool CConnection::processInitMsg() { vlog.debug("reading server initialisation"); - reader_->readServerInit(); -} - -void CConnection::throwConnFailedException() -{ - state_ = RFBSTATE_INVALID; - rdr::U32 len = is->readU32(); - CharArray reason(len + 1); - is->readBytes(reason.buf, len); - reason.buf[len] = '\0'; - throw ConnFailedException(reason.buf); + return reader_->readServerInit(); } void CConnection::securityCompleted() @@ -429,11 +467,13 @@ void CConnection::serverInit(int width, int height, } } -void CConnection::readAndDecodeRect(const Rect& r, int encoding, +bool CConnection::readAndDecodeRect(const Rect& r, int encoding, ModifiablePixelBuffer* pb) { - decoder.decodeRect(r, encoding, pb); + if (!decoder.decodeRect(r, encoding, pb)) + return false; decoder.flush(); + return true; } void CConnection::framebufferUpdateStart() @@ -474,9 +514,9 @@ void CConnection::framebufferUpdateEnd() } } -void CConnection::dataRect(const Rect& r, int encoding) +bool CConnection::dataRect(const Rect& r, int encoding) { - decoder.decodeRect(r, encoding, framebuffer); + return decoder.decodeRect(r, encoding, framebuffer); } void CConnection::serverCutText(const char* str) diff --git a/common/rfb/CConnection.h b/common/rfb/CConnection.h index 3857be4d..3c2b04e8 100644 --- a/common/rfb/CConnection.h +++ b/common/rfb/CConnection.h @@ -84,7 +84,7 @@ namespace rfb { // In this case, processMsg should always process the available RFB // message before returning. // NB: In either case, you must have called initialiseProtocol() first. - void processMsg(); + bool processMsg(); // close() gracefully shuts down the connection to the server and // should be called before terminating the underlying network @@ -107,12 +107,12 @@ namespace rfb { const PixelFormat& pf, const char* name); - virtual void readAndDecodeRect(const Rect& r, int encoding, + virtual bool readAndDecodeRect(const Rect& r, int encoding, ModifiablePixelBuffer* pb); virtual void framebufferUpdateStart(); virtual void framebufferUpdateEnd(); - virtual void dataRect(const Rect& r, int encoding); + virtual bool dataRect(const Rect& r, int encoding); virtual void serverCutText(const char* str); @@ -216,6 +216,7 @@ namespace rfb { RFBSTATE_SECURITY_TYPES, RFBSTATE_SECURITY, RFBSTATE_SECURITY_RESULT, + RFBSTATE_SECURITY_REASON, RFBSTATE_INITIALISATION, RFBSTATE_NORMAL, RFBSTATE_CLOSING, @@ -249,13 +250,13 @@ namespace rfb { virtual void fence(rdr::U32 flags, unsigned len, const char data[]); private: - void processVersionMsg(); - void processSecurityTypesMsg(); - void processSecurityMsg(); - void processSecurityResultMsg(); - void processInitMsg(); + bool processVersionMsg(); + bool processSecurityTypesMsg(); + bool processSecurityMsg(); + bool processSecurityResultMsg(); + bool processSecurityReasonMsg(); + bool processInitMsg(); void throwAuthFailureException(); - void throwConnFailedException(); void securityCompleted(); void requestNewUpdate(); diff --git a/common/rfb/CMsgHandler.h b/common/rfb/CMsgHandler.h index 84dd115c..5b14806a 100644 --- a/common/rfb/CMsgHandler.h +++ b/common/rfb/CMsgHandler.h @@ -61,12 +61,12 @@ namespace rfb { const PixelFormat& pf, const char* name) = 0; - virtual void readAndDecodeRect(const Rect& r, int encoding, + virtual bool readAndDecodeRect(const Rect& r, int encoding, ModifiablePixelBuffer* pb) = 0; virtual void framebufferUpdateStart(); virtual void framebufferUpdateEnd(); - virtual void dataRect(const Rect& r, int encoding) = 0; + virtual bool dataRect(const Rect& r, int encoding) = 0; virtual void setColourMapEntries(int firstColour, int nColours, rdr::U16* rgbs) = 0; diff --git a/common/rfb/CMsgReader.cxx b/common/rfb/CMsgReader.cxx index a015ec99..40fb5912 100644 --- a/common/rfb/CMsgReader.cxx +++ b/common/rfb/CMsgReader.cxx @@ -39,7 +39,7 @@ using namespace rfb; CMsgReader::CMsgReader(CMsgHandler* handler_, rdr::InStream* is_) : imageBufIdealSize(0), handler(handler_), is(is_), - nUpdateRectsLeft(0) + state(MSGSTATE_IDLE), cursorEncoding(-1) { } @@ -47,149 +47,246 @@ CMsgReader::~CMsgReader() { } -void CMsgReader::readServerInit() +bool CMsgReader::readServerInit() { - int width = is->readU16(); - int height = is->readU16(); + int width, height; + rdr::U32 len; + + if (!is->hasData(2 + 2 + 16 + 4)) + return false; + + is->setRestorePoint(); + + width = is->readU16(); + height = is->readU16(); + PixelFormat pf; pf.read(is); - rdr::U32 len = is->readU32(); + + len = is->readU32(); + if (!is->hasDataOrRestore(len)) + return false; + is->clearRestorePoint(); CharArray name(len + 1); is->readBytes(name.buf, len); name.buf[len] = '\0'; handler->serverInit(width, height, pf, name.buf); + + return true; } -void CMsgReader::readMsg() +bool CMsgReader::readMsg() { - if (nUpdateRectsLeft == 0) { - int type = is->readU8(); + if (state == MSGSTATE_IDLE) { + if (!is->hasData(1)) + return false; + + currentMsgType = is->readU8(); + state = MSGSTATE_MESSAGE; + } + + if (currentMsgType != msgTypeFramebufferUpdate) { + bool ret; - switch (type) { + switch (currentMsgType) { case msgTypeSetColourMapEntries: - readSetColourMapEntries(); + ret = readSetColourMapEntries(); break; case msgTypeBell: - readBell(); + ret = readBell(); break; case msgTypeServerCutText: - readServerCutText(); + ret = readServerCutText(); break; case msgTypeFramebufferUpdate: - readFramebufferUpdate(); + ret = readFramebufferUpdate(); break; case msgTypeServerFence: - readFence(); + ret = readFence(); break; case msgTypeEndOfContinuousUpdates: - readEndOfContinuousUpdates(); + ret = readEndOfContinuousUpdates(); break; default: - vlog.error("unknown message type %d", type); - throw Exception("unknown message type"); + throw Exception("Unknown message type %d", currentMsgType); } + + if (ret) + state = MSGSTATE_IDLE; + + return ret; } else { - int x = is->readU16(); - int y = is->readU16(); - int w = is->readU16(); - int h = is->readU16(); - int encoding = is->readS32(); + if (state == MSGSTATE_MESSAGE) { + if (!readFramebufferUpdate()) + return false; + + // Empty update? + if (nUpdateRectsLeft == 0) { + state = MSGSTATE_IDLE; + handler->framebufferUpdateEnd(); + return true; + } + + state = MSGSTATE_RECT_HEADER; + } + + if (state == MSGSTATE_RECT_HEADER) { + if (!is->hasData(12)) + return false; + + int x = is->readU16(); + int y = is->readU16(); + int w = is->readU16(); + int h = is->readU16(); + + dataRect.setXYWH(x, y, w, h); + + rectEncoding = is->readS32(); + + state = MSGSTATE_RECT_DATA; + } + + bool ret; - switch (encoding) { + switch (rectEncoding) { case pseudoEncodingLastRect: nUpdateRectsLeft = 1; // this rectangle is the last one + ret = true; break; case pseudoEncodingXCursor: - readSetXCursor(w, h, Point(x,y)); + ret = readSetXCursor(dataRect.width(), dataRect.height(), dataRect.tl); break; case pseudoEncodingCursor: - readSetCursor(w, h, Point(x,y)); + ret = readSetCursor(dataRect.width(), dataRect.height(), dataRect.tl); break; case pseudoEncodingCursorWithAlpha: - readSetCursorWithAlpha(w, h, Point(x,y)); + ret = readSetCursorWithAlpha(dataRect.width(), dataRect.height(), dataRect.tl); break; case pseudoEncodingVMwareCursor: - readSetVMwareCursor(w, h, Point(x,y)); + ret = readSetVMwareCursor(dataRect.width(), dataRect.height(), dataRect.tl); break; case pseudoEncodingDesktopName: - readSetDesktopName(x, y, w, h); + ret = readSetDesktopName(dataRect.tl.x, dataRect.tl.y, + dataRect.width(), dataRect.height()); break; case pseudoEncodingDesktopSize: - handler->setDesktopSize(w, h); + handler->setDesktopSize(dataRect.width(), dataRect.height()); + ret = true; break; case pseudoEncodingExtendedDesktopSize: - readExtendedDesktopSize(x, y, w, h); + ret = readExtendedDesktopSize(dataRect.tl.x, dataRect.tl.y, + dataRect.width(), dataRect.height()); break; case pseudoEncodingLEDState: - readLEDState(); + ret = readLEDState(); break; case pseudoEncodingVMwareLEDState: - readVMwareLEDState(); + ret = readVMwareLEDState(); break; case pseudoEncodingQEMUKeyEvent: handler->supportsQEMUKeyEvent(); + ret = true; break; default: - readRect(Rect(x, y, x+w, y+h), encoding); + ret = readRect(dataRect, rectEncoding); break; }; - nUpdateRectsLeft--; - if (nUpdateRectsLeft == 0) - handler->framebufferUpdateEnd(); + if (ret) { + state = MSGSTATE_RECT_HEADER; + nUpdateRectsLeft--; + if (nUpdateRectsLeft == 0) { + state = MSGSTATE_IDLE; + handler->framebufferUpdateEnd(); + } + } + + return ret; } } -void CMsgReader::readSetColourMapEntries() +bool CMsgReader::readSetColourMapEntries() { + if (!is->hasData(5)) + return false; + + is->setRestorePoint(); + is->skip(1); int firstColour = is->readU16(); int nColours = is->readU16(); + + if (!is->hasDataOrRestore(nColours * 3 * 2)) + return false; + is->clearRestorePoint(); + rdr::U16Array rgbs(nColours * 3); for (int i = 0; i < nColours * 3; i++) rgbs.buf[i] = is->readU16(); handler->setColourMapEntries(firstColour, nColours, rgbs.buf); + + return true; } -void CMsgReader::readBell() +bool CMsgReader::readBell() { handler->bell(); + return true; } -void CMsgReader::readServerCutText() +bool CMsgReader::readServerCutText() { + if (!is->hasData(7)) + return false; + + is->setRestorePoint(); + is->skip(3); rdr::U32 len = is->readU32(); if (len & 0x80000000) { rdr::S32 slen = len; slen = -slen; - readExtendedClipboard(slen); - return; + if (readExtendedClipboard(slen)) { + is->clearRestorePoint(); + return true; + } else { + is->gotoRestorePoint(); + return false; + } } + if (!is->hasDataOrRestore(len)) + return false; + is->clearRestorePoint(); + if (len > (size_t)maxCutText) { is->skip(len); vlog.error("cut text too long (%d bytes) - ignoring",len); - return; + return true; } CharArray ca(len); is->readBytes(ca.buf, len); CharArray filtered(convertLF(ca.buf, len)); handler->serverCutText(filtered.buf); + + return true; } -void CMsgReader::readExtendedClipboard(rdr::S32 len) +bool CMsgReader::readExtendedClipboard(rdr::S32 len) { rdr::U32 flags; rdr::U32 action; + if (!is->hasData(len)) + return false; + if (len < 4) throw Exception("Invalid extended clipboard message"); if (len > maxCutText) { vlog.error("Extended clipboard message too long (%d bytes) - ignoring", len); is->skip(len); - return; + return true; } flags = is->readU32(); @@ -231,7 +328,14 @@ void CMsgReader::readExtendedClipboard(rdr::S32 len) if (!(flags & 1 << i)) continue; + if (!zis.hasData(4)) + throw Exception("Extended clipboard decode error"); + lengths[num] = zis.readU32(); + + if (!zis.hasData(lengths[num])) + throw Exception("Extended clipboard decode error"); + if (lengths[num] > (size_t)maxCutText) { vlog.error("Extended clipboard data too long (%d bytes) - ignoring", (unsigned)lengths[num]); @@ -271,43 +375,63 @@ void CMsgReader::readExtendedClipboard(rdr::S32 len) throw Exception("Invalid extended clipboard action"); } } + + return true; } -void CMsgReader::readFence() +bool CMsgReader::readFence() { rdr::U32 flags; rdr::U8 len; char data[64]; + if (!is->hasData(8)) + return false; + + is->setRestorePoint(); + is->skip(3); flags = is->readU32(); len = is->readU8(); + + if (!is->hasDataOrRestore(len)) + return false; + is->clearRestorePoint(); + if (len > sizeof(data)) { vlog.error("Ignoring fence with too large payload"); is->skip(len); - return; + return true; } is->readBytes(data, len); handler->fence(flags, len, data); + + return true; } -void CMsgReader::readEndOfContinuousUpdates() +bool CMsgReader::readEndOfContinuousUpdates() { handler->endOfContinuousUpdates(); + return true; } -void CMsgReader::readFramebufferUpdate() +bool CMsgReader::readFramebufferUpdate() { + if (!is->hasData(3)) + return false; + is->skip(1); nUpdateRectsLeft = is->readU16(); handler->framebufferUpdateStart(); + + return true; } -void CMsgReader::readRect(const Rect& r, int encoding) +bool CMsgReader::readRect(const Rect& r, int encoding) { if ((r.br.x > handler->server.width()) || (r.br.y > handler->server.height())) { @@ -320,10 +444,10 @@ void CMsgReader::readRect(const Rect& r, int encoding) if (r.is_empty()) vlog.error("zero size rect"); - handler->dataRect(r, encoding); + return handler->dataRect(r, encoding); } -void CMsgReader::readSetXCursor(int width, int height, const Point& hotspot) +bool CMsgReader::readSetXCursor(int width, int height, const Point& hotspot) { if (width > maxCursorSize || height > maxCursorSize) throw Exception("Too big cursor"); @@ -341,6 +465,9 @@ void CMsgReader::readSetXCursor(int width, int height, const Point& hotspot) int x, y; rdr::U8* out; + if (!is->hasData(3 + 3 + data_len + mask_len)) + return false; + pr = is->readU8(); pg = is->readU8(); pb = is->readU8(); @@ -380,9 +507,11 @@ void CMsgReader::readSetXCursor(int width, int height, const Point& hotspot) } handler->setCursor(width, height, hotspot, rgba.buf); + + return true; } -void CMsgReader::readSetCursor(int width, int height, const Point& hotspot) +bool CMsgReader::readSetCursor(int width, int height, const Point& hotspot) { if (width > maxCursorSize || height > maxCursorSize) throw Exception("Too big cursor"); @@ -397,6 +526,9 @@ void CMsgReader::readSetCursor(int width, int height, const Point& hotspot) rdr::U8* in; rdr::U8* out; + if (!is->hasData(data_len + mask_len)) + return false; + is->readBytes(data.buf, data_len); is->readBytes(mask.buf, mask_len); @@ -421,29 +553,44 @@ void CMsgReader::readSetCursor(int width, int height, const Point& hotspot) } handler->setCursor(width, height, hotspot, rgba.buf); + + return true; } -void CMsgReader::readSetCursorWithAlpha(int width, int height, const Point& hotspot) +bool CMsgReader::readSetCursorWithAlpha(int width, int height, const Point& hotspot) { if (width > maxCursorSize || height > maxCursorSize) throw Exception("Too big cursor"); - int encoding; - const PixelFormat rgbaPF(32, 32, false, true, 255, 255, 255, 16, 8, 0); ManagedPixelBuffer pb(rgbaPF, width, height); PixelFormat origPF; + bool ret; + rdr::U8* buf; int stride; - encoding = is->readS32(); + // We can't use restore points as the decoder likely wants to as well, so + // we need to keep track of the read encoding + + if (cursorEncoding == -1) { + if (!is->hasData(4)) + return false; + + cursorEncoding = is->readS32(); + } origPF = handler->server.pf(); handler->server.setPF(rgbaPF); - handler->readAndDecodeRect(pb.getRect(), encoding, &pb); + ret = handler->readAndDecodeRect(pb.getRect(), cursorEncoding, &pb); handler->server.setPF(origPF); + if (!ret) + return false; + + cursorEncoding = -1; + // On-wire data has pre-multiplied alpha, but we store it // non-pre-multiplied buf = pb.getBufferRW(pb.getRect(), &stride); @@ -467,18 +614,25 @@ void CMsgReader::readSetCursorWithAlpha(int width, int height, const Point& hots handler->setCursor(width, height, hotspot, pb.getBuffer(pb.getRect(), &stride)); + + return true; } -void CMsgReader::readSetVMwareCursor(int width, int height, const Point& hotspot) +bool CMsgReader::readSetVMwareCursor(int width, int height, const Point& hotspot) { if (width > maxCursorSize || height > maxCursorSize) throw Exception("Too big cursor"); rdr::U8 type; + if (!is->hasData(2)) + return false; + type = is->readU8(); is->skip(1); + is->setRestorePoint(); + if (type == 0) { int len = width * height * (handler->server.pf().bpp/8); rdr::U8Array andMask(len); @@ -491,6 +645,10 @@ void CMsgReader::readSetVMwareCursor(int width, int height, const Point& hotspot rdr::U8* out; int Bpp; + if (!is->hasDataOrRestore(len + len)) + return false; + is->clearRestorePoint(); + is->readBytes(andMask.buf, len); is->readBytes(xorMask.buf, len); @@ -548,6 +706,10 @@ void CMsgReader::readSetVMwareCursor(int width, int height, const Point& hotspot } else if (type == 1) { rdr::U8Array data(width*height*4); + if (!is->hasDataOrRestore(width*height*4)) + return false; + is->clearRestorePoint(); + // FIXME: Is alpha premultiplied? is->readBytes(data.buf, width*height*4); @@ -555,11 +717,25 @@ void CMsgReader::readSetVMwareCursor(int width, int height, const Point& hotspot } else { throw Exception("Unknown cursor type"); } + + return true; } -void CMsgReader::readSetDesktopName(int x, int y, int w, int h) +bool CMsgReader::readSetDesktopName(int x, int y, int w, int h) { - rdr::U32 len = is->readU32(); + rdr::U32 len; + + if (!is->hasData(4)) + return false; + + is->setRestorePoint(); + + len = is->readU32(); + + if (!is->hasDataOrRestore(len)) + return false; + is->clearRestorePoint(); + CharArray name(len + 1); is->readBytes(name.buf, len); name.buf[len] = '\0'; @@ -569,18 +745,29 @@ void CMsgReader::readSetDesktopName(int x, int y, int w, int h) } else { handler->setName(name.buf); } + + return true; } -void CMsgReader::readExtendedDesktopSize(int x, int y, int w, int h) +bool CMsgReader::readExtendedDesktopSize(int x, int y, int w, int h) { unsigned int screens, i; rdr::U32 id, flags; int sx, sy, sw, sh; ScreenSet layout; + if (!is->hasData(4)) + return false; + + is->setRestorePoint(); + screens = is->readU8(); is->skip(3); + if (!is->hasDataOrRestore(16 * screens)) + return false; + is->clearRestorePoint(); + for (i = 0;i < screens;i++) { id = is->readU32(); sx = is->readU16(); @@ -593,25 +780,37 @@ void CMsgReader::readExtendedDesktopSize(int x, int y, int w, int h) } handler->setExtendedDesktopSize(x, y, w, h, layout); + + return true; } -void CMsgReader::readLEDState() +bool CMsgReader::readLEDState() { rdr::U8 state; + if (!is->hasData(1)) + return false; + state = is->readU8(); handler->setLEDState(state); + + return true; } -void CMsgReader::readVMwareLEDState() +bool CMsgReader::readVMwareLEDState() { rdr::U32 state; + if (!is->hasData(4)) + return false; + state = is->readU32(); // As luck has it, this extension uses the same bit definitions, // so no conversion required handler->setLEDState(state); + + return true; } diff --git a/common/rfb/CMsgReader.h b/common/rfb/CMsgReader.h index 050990a9..ab55aed8 100644 --- a/common/rfb/CMsgReader.h +++ b/common/rfb/CMsgReader.h @@ -40,39 +40,55 @@ namespace rfb { CMsgReader(CMsgHandler* handler, rdr::InStream* is); virtual ~CMsgReader(); - void readServerInit(); + bool readServerInit(); // readMsg() reads a message, calling the handler as appropriate. - void readMsg(); + bool readMsg(); rdr::InStream* getInStream() { return is; } int imageBufIdealSize; protected: - void readSetColourMapEntries(); - void readBell(); - void readServerCutText(); - void readExtendedClipboard(rdr::S32 len); - void readFence(); - void readEndOfContinuousUpdates(); - - void readFramebufferUpdate(); - - void readRect(const Rect& r, int encoding); - - void readSetXCursor(int width, int height, const Point& hotspot); - void readSetCursor(int width, int height, const Point& hotspot); - void readSetCursorWithAlpha(int width, int height, const Point& hotspot); - void readSetVMwareCursor(int width, int height, const Point& hotspot); - void readSetDesktopName(int x, int y, int w, int h); - void readExtendedDesktopSize(int x, int y, int w, int h); - void readLEDState(); - void readVMwareLEDState(); - + bool readSetColourMapEntries(); + bool readBell(); + bool readServerCutText(); + bool readExtendedClipboard(rdr::S32 len); + bool readFence(); + bool readEndOfContinuousUpdates(); + + bool readFramebufferUpdate(); + + bool readRect(const Rect& r, int encoding); + + bool readSetXCursor(int width, int height, const Point& hotspot); + bool readSetCursor(int width, int height, const Point& hotspot); + bool readSetCursorWithAlpha(int width, int height, const Point& hotspot); + bool readSetVMwareCursor(int width, int height, const Point& hotspot); + bool readSetDesktopName(int x, int y, int w, int h); + bool readExtendedDesktopSize(int x, int y, int w, int h); + bool readLEDState(); + bool readVMwareLEDState(); + + private: CMsgHandler* handler; rdr::InStream* is; + + enum stateEnum { + MSGSTATE_IDLE, + MSGSTATE_MESSAGE, + MSGSTATE_RECT_HEADER, + MSGSTATE_RECT_DATA, + }; + + stateEnum state; + + rdr::U8 currentMsgType; int nUpdateRectsLeft; + Rect dataRect; + int rectEncoding; + + int cursorEncoding; static const int maxCursorSize = 256; }; diff --git a/common/rfb/CSecurityTLS.cxx b/common/rfb/CSecurityTLS.cxx index 374ec7f3..4fcaa7a9 100644 --- a/common/rfb/CSecurityTLS.cxx +++ b/common/rfb/CSecurityTLS.cxx @@ -154,7 +154,7 @@ bool CSecurityTLS::processMsg() client = cc; if (!session) { - if (!is->checkNoWait(1)) + if (!is->hasData(1)) return false; if (is->readU8() == 0) @@ -180,8 +180,10 @@ bool CSecurityTLS::processMsg() int err; err = gnutls_handshake(session); if (err != GNUTLS_E_SUCCESS) { - if (!gnutls_error_is_fatal(err)) + if (!gnutls_error_is_fatal(err)) { + vlog.debug("Deferring completion of TLS handshake: %s", gnutls_strerror(err)); return false; + } vlog.error("TLS Handshake failed: %s\n", gnutls_strerror (err)); shutdown(false); diff --git a/common/rfb/CSecurityVeNCrypt.cxx b/common/rfb/CSecurityVeNCrypt.cxx index 22201dd2..98dad494 100644 --- a/common/rfb/CSecurityVeNCrypt.cxx +++ b/common/rfb/CSecurityVeNCrypt.cxx @@ -51,7 +51,6 @@ CSecurityVeNCrypt::CSecurityVeNCrypt(CConnection* cc, SecurityClient* sec) chosenType = secTypeVeNCrypt; nAvailableTypes = 0; availableTypes = NULL; - iAvailableType = 0; } CSecurityVeNCrypt::~CSecurityVeNCrypt() @@ -64,16 +63,20 @@ bool CSecurityVeNCrypt::processMsg() { InStream* is = cc->getInStream(); OutStream* os = cc->getOutStream(); - + /* get major, minor versions, send what we can support (or 0.0 for can't support it) */ if (!haveRecvdMajorVersion) { + if (!is->hasData(1)) + return false; + majorVersion = is->readU8(); haveRecvdMajorVersion = true; - - return false; } if (!haveRecvdMinorVersion) { + if (!is->hasData(1)) + return false; + minorVersion = is->readU8(); haveRecvdMinorVersion = true; } @@ -100,47 +103,48 @@ bool CSecurityVeNCrypt::processMsg() } haveSentVersion = true; - return false; } /* Check that the server is OK */ if (!haveAgreedVersion) { + if (!is->hasData(1)) + return false; + if (is->readU8()) throw AuthFailureException("The server reported it could not support the " "VeNCrypt version"); haveAgreedVersion = true; - return false; } /* get a number of types */ if (!haveNumberOfTypes) { + if (!is->hasData(1)) + return false; + nAvailableTypes = is->readU8(); - iAvailableType = 0; if (!nAvailableTypes) throw AuthFailureException("The server reported no VeNCrypt sub-types"); availableTypes = new rdr::U32[nAvailableTypes]; haveNumberOfTypes = true; - return false; } if (nAvailableTypes) { /* read in the types possible */ if (!haveListOfTypes) { - if (is->checkNoWait(4)) { - availableTypes[iAvailableType++] = is->readU32(); - haveListOfTypes = (iAvailableType >= nAvailableTypes); - vlog.debug("Server offers security type %s (%d)", - secTypeName(availableTypes[iAvailableType - 1]), - availableTypes[iAvailableType - 1]); - - if (!haveListOfTypes) - return false; - - } else - return false; + if (!is->hasData(4 * nAvailableTypes)) + return false; + + for (int i = 0;i < nAvailableTypes;i++) { + availableTypes[i] = is->readU32(); + vlog.debug("Server offers security type %s (%d)", + secTypeName(availableTypes[i]), + availableTypes[i]); + } + + haveListOfTypes = true; } /* make a choice and send it to the server, meanwhile set up the stack */ diff --git a/common/rfb/CSecurityVeNCrypt.h b/common/rfb/CSecurityVeNCrypt.h index d015e8f2..1e2a7e68 100644 --- a/common/rfb/CSecurityVeNCrypt.h +++ b/common/rfb/CSecurityVeNCrypt.h @@ -55,7 +55,6 @@ namespace rfb { rdr::U32 chosenType; rdr::U8 nAvailableTypes; rdr::U32 *availableTypes; - rdr::U8 iAvailableType; }; } #endif diff --git a/common/rfb/CSecurityVncAuth.cxx b/common/rfb/CSecurityVncAuth.cxx index 6a87498c..78a3a061 100644 --- a/common/rfb/CSecurityVncAuth.cxx +++ b/common/rfb/CSecurityVncAuth.cxx @@ -45,6 +45,9 @@ bool CSecurityVncAuth::processMsg() rdr::InStream* is = cc->getInStream(); rdr::OutStream* os = cc->getOutStream(); + if (!is->hasData(vncAuthChallengeSize)) + return false; + // Read the challenge & obtain the user's password rdr::U8 challenge[vncAuthChallengeSize]; is->readBytes(challenge, vncAuthChallengeSize); diff --git a/common/rfb/CopyRectDecoder.cxx b/common/rfb/CopyRectDecoder.cxx index ecf50323..bca8da48 100644 --- a/common/rfb/CopyRectDecoder.cxx +++ b/common/rfb/CopyRectDecoder.cxx @@ -31,10 +31,13 @@ CopyRectDecoder::~CopyRectDecoder() { } -void CopyRectDecoder::readRect(const Rect& r, rdr::InStream* is, +bool CopyRectDecoder::readRect(const Rect& r, rdr::InStream* is, const ServerParams& server, rdr::OutStream* os) { + if (!is->hasData(4)) + return false; os->copyBytes(is, 4); + return true; } diff --git a/common/rfb/CopyRectDecoder.h b/common/rfb/CopyRectDecoder.h index 546266e1..5100eb2f 100644 --- a/common/rfb/CopyRectDecoder.h +++ b/common/rfb/CopyRectDecoder.h @@ -26,7 +26,7 @@ namespace rfb { public: CopyRectDecoder(); virtual ~CopyRectDecoder(); - virtual void readRect(const Rect& r, rdr::InStream* is, + virtual bool readRect(const Rect& r, rdr::InStream* is, const ServerParams& server, rdr::OutStream* os); virtual void getAffectedRegion(const Rect& rect, const void* buffer, size_t buflen, const ServerParams& server, diff --git a/common/rfb/DecodeManager.cxx b/common/rfb/DecodeManager.cxx index 80c10510..c003ab40 100644 --- a/common/rfb/DecodeManager.cxx +++ b/common/rfb/DecodeManager.cxx @@ -103,7 +103,7 @@ DecodeManager::~DecodeManager() delete decoders[i]; } -void DecodeManager::decodeRect(const Rect& r, int encoding, +bool DecodeManager::decodeRect(const Rect& r, int encoding, ModifiablePixelBuffer* pb) { Decoder *decoder; @@ -133,19 +133,21 @@ void DecodeManager::decodeRect(const Rect& r, int encoding, if (threads.empty()) { bufferStream = freeBuffers.front(); bufferStream->clear(); - decoder->readRect(r, conn->getInStream(), conn->server, bufferStream); + if (!decoder->readRect(r, conn->getInStream(), conn->server, bufferStream)) + return false; try { decoder->decodeRect(r, bufferStream->data(), bufferStream->length(), conn->server, pb); } catch (rdr::Exception& e) { throw Exception("Error decoding rect: %s", e.str()); } - return; + return true; } // Wait for an available memory buffer queueMutex->lock(); + // FIXME: Should we return and let other things run here? while (freeBuffers.empty()) producerCond->wait(); @@ -160,7 +162,8 @@ void DecodeManager::decodeRect(const Rect& r, int encoding, // Read the rect bufferStream->clear(); - decoder->readRect(r, conn->getInStream(), conn->server, bufferStream); + if (!decoder->readRect(r, conn->getInStream(), conn->server, bufferStream)) + return false; // Then try to put it on the queue entry = new QueueEntry; @@ -190,6 +193,8 @@ void DecodeManager::decodeRect(const Rect& r, int encoding, consumerCond->signal(); queueMutex->unlock(); + + return true; } void DecodeManager::flush() diff --git a/common/rfb/DecodeManager.h b/common/rfb/DecodeManager.h index 058d8240..289686b5 100644 --- a/common/rfb/DecodeManager.h +++ b/common/rfb/DecodeManager.h @@ -47,7 +47,7 @@ namespace rfb { DecodeManager(CConnection *conn); ~DecodeManager(); - void decodeRect(const Rect& r, int encoding, + bool decodeRect(const Rect& r, int encoding, ModifiablePixelBuffer* pb); void flush(); diff --git a/common/rfb/Decoder.h b/common/rfb/Decoder.h index e074f3ec..cb206a0d 100644 --- a/common/rfb/Decoder.h +++ b/common/rfb/Decoder.h @@ -52,7 +52,7 @@ namespace rfb { // InStream to the OutStream, possibly changing it along the way to // make it easier to decode. This function will always be called in // a serial manner on the main thread. - virtual void readRect(const Rect& r, rdr::InStream* is, + virtual bool readRect(const Rect& r, rdr::InStream* is, const ServerParams& server, rdr::OutStream* os)=0; // These functions will be called from any of the worker threads. diff --git a/common/rfb/HextileDecoder.cxx b/common/rfb/HextileDecoder.cxx index 742dfb28..34392ea8 100644 --- a/common/rfb/HextileDecoder.cxx +++ b/common/rfb/HextileDecoder.cxx @@ -44,12 +44,14 @@ HextileDecoder::~HextileDecoder() { } -void HextileDecoder::readRect(const Rect& r, rdr::InStream* is, +bool HextileDecoder::readRect(const Rect& r, rdr::InStream* is, const ServerParams& server, rdr::OutStream* os) { Rect t; size_t bytesPerPixel; + is->setRestorePoint(); + bytesPerPixel = server.pf().bpp/8; for (t.tl.y = r.tl.y; t.tl.y < r.br.y; t.tl.y += 16) { @@ -61,33 +63,57 @@ void HextileDecoder::readRect(const Rect& r, rdr::InStream* is, t.br.x = __rfbmin(r.br.x, t.tl.x + 16); + if (!is->hasDataOrRestore(1)) + return false; + tileType = is->readU8(); os->writeU8(tileType); if (tileType & hextileRaw) { + if (!is->hasDataOrRestore(t.area() * bytesPerPixel)) + return false; os->copyBytes(is, t.area() * bytesPerPixel); continue; } - if (tileType & hextileBgSpecified) + + if (tileType & hextileBgSpecified) { + if (!is->hasDataOrRestore(bytesPerPixel)) + return false; os->copyBytes(is, bytesPerPixel); + } - if (tileType & hextileFgSpecified) + if (tileType & hextileFgSpecified) { + if (!is->hasDataOrRestore(bytesPerPixel)) + return false; os->copyBytes(is, bytesPerPixel); + } if (tileType & hextileAnySubrects) { rdr::U8 nSubrects; + if (!is->hasDataOrRestore(1)) + return false; + nSubrects = is->readU8(); os->writeU8(nSubrects); - if (tileType & hextileSubrectsColoured) + if (tileType & hextileSubrectsColoured) { + if (!is->hasDataOrRestore(nSubrects * (bytesPerPixel + 2))) + return false; os->copyBytes(is, nSubrects * (bytesPerPixel + 2)); - else + } else { + if (!is->hasDataOrRestore(nSubrects * 2)) + return false; os->copyBytes(is, nSubrects * 2); + } } } } + + is->clearRestorePoint(); + + return true; } void HextileDecoder::decodeRect(const Rect& r, const void* buffer, diff --git a/common/rfb/HextileDecoder.h b/common/rfb/HextileDecoder.h index b8515bfc..2c42be54 100644 --- a/common/rfb/HextileDecoder.h +++ b/common/rfb/HextileDecoder.h @@ -26,7 +26,7 @@ namespace rfb { public: HextileDecoder(); virtual ~HextileDecoder(); - virtual void readRect(const Rect& r, rdr::InStream* is, + virtual bool readRect(const Rect& r, rdr::InStream* is, const ServerParams& server, rdr::OutStream* os); virtual void decodeRect(const Rect& r, const void* buffer, size_t buflen, const ServerParams& server, diff --git a/common/rfb/RREDecoder.cxx b/common/rfb/RREDecoder.cxx index 70a7ddb2..af821cb9 100644 --- a/common/rfb/RREDecoder.cxx +++ b/common/rfb/RREDecoder.cxx @@ -44,15 +44,30 @@ RREDecoder::~RREDecoder() { } -void RREDecoder::readRect(const Rect& r, rdr::InStream* is, +bool RREDecoder::readRect(const Rect& r, rdr::InStream* is, const ServerParams& server, rdr::OutStream* os) { rdr::U32 numRects; + size_t len; + + if (!is->hasData(4)) + return false; + + is->setRestorePoint(); numRects = is->readU32(); os->writeU32(numRects); - os->copyBytes(is, server.pf().bpp/8 + numRects * (server.pf().bpp/8 + 8)); + len = server.pf().bpp/8 + numRects * (server.pf().bpp/8 + 8); + + if (!is->hasDataOrRestore(len)) + return false; + + is->clearRestorePoint(); + + os->copyBytes(is, len); + + return true; } void RREDecoder::decodeRect(const Rect& r, const void* buffer, diff --git a/common/rfb/RREDecoder.h b/common/rfb/RREDecoder.h index f47eddad..b8ec18f6 100644 --- a/common/rfb/RREDecoder.h +++ b/common/rfb/RREDecoder.h @@ -26,7 +26,7 @@ namespace rfb { public: RREDecoder(); virtual ~RREDecoder(); - virtual void readRect(const Rect& r, rdr::InStream* is, + virtual bool readRect(const Rect& r, rdr::InStream* is, const ServerParams& server, rdr::OutStream* os); virtual void decodeRect(const Rect& r, const void* buffer, size_t buflen, const ServerParams& server, diff --git a/common/rfb/RawDecoder.cxx b/common/rfb/RawDecoder.cxx index 61235047..a7648f97 100644 --- a/common/rfb/RawDecoder.cxx +++ b/common/rfb/RawDecoder.cxx @@ -33,10 +33,13 @@ RawDecoder::~RawDecoder() { } -void RawDecoder::readRect(const Rect& r, rdr::InStream* is, +bool RawDecoder::readRect(const Rect& r, rdr::InStream* is, const ServerParams& server, rdr::OutStream* os) { + if (!is->hasData(r.area() * (server.pf().bpp/8))) + return false; os->copyBytes(is, r.area() * (server.pf().bpp/8)); + return true; } void RawDecoder::decodeRect(const Rect& r, const void* buffer, diff --git a/common/rfb/RawDecoder.h b/common/rfb/RawDecoder.h index 4ab80717..2661ea57 100644 --- a/common/rfb/RawDecoder.h +++ b/common/rfb/RawDecoder.h @@ -25,7 +25,7 @@ namespace rfb { public: RawDecoder(); virtual ~RawDecoder(); - virtual void readRect(const Rect& r, rdr::InStream* is, + virtual bool readRect(const Rect& r, rdr::InStream* is, const ServerParams& server, rdr::OutStream* os); virtual void decodeRect(const Rect& r, const void* buffer, size_t buflen, const ServerParams& server, diff --git a/common/rfb/SConnection.cxx b/common/rfb/SConnection.cxx index e06fc6bb..1c9ca3e7 100644 --- a/common/rfb/SConnection.cxx +++ b/common/rfb/SConnection.cxx @@ -86,18 +86,20 @@ void SConnection::initialiseProtocol() state_ = RFBSTATE_PROTOCOL_VERSION; } -void SConnection::processMsg() +bool SConnection::processMsg() { switch (state_) { - case RFBSTATE_PROTOCOL_VERSION: processVersionMsg(); break; - case RFBSTATE_SECURITY_TYPE: processSecurityTypeMsg(); break; - case RFBSTATE_SECURITY: processSecurityMsg(); break; - case RFBSTATE_SECURITY_FAILURE: processSecurityFailure(); break; - case RFBSTATE_INITIALISATION: processInitMsg(); break; - case RFBSTATE_NORMAL: reader_->readMsg(); break; + case RFBSTATE_PROTOCOL_VERSION: return processVersionMsg(); break; + case RFBSTATE_SECURITY_TYPE: return processSecurityTypeMsg(); break; + case RFBSTATE_SECURITY: return processSecurityMsg(); break; + case RFBSTATE_SECURITY_FAILURE: return processSecurityFailure(); break; + case RFBSTATE_INITIALISATION: return processInitMsg(); break; + case RFBSTATE_NORMAL: return reader_->readMsg(); break; case RFBSTATE_QUERYING: throw Exception("SConnection::processMsg: bogus data from client while " "querying"); + case RFBSTATE_CLOSING: + throw Exception("SConnection::processMsg: called while closing"); case RFBSTATE_UNINITIALISED: throw Exception("SConnection::processMsg: not initialised yet?"); default: @@ -105,7 +107,7 @@ void SConnection::processMsg() } } -void SConnection::processVersionMsg() +bool SConnection::processVersionMsg() { char verStr[13]; int majorVersion; @@ -113,8 +115,8 @@ void SConnection::processVersionMsg() vlog.debug("reading protocol version"); - if (!is->checkNoWait(12)) - return; + if (!is->hasData(12)) + return false; is->readBytes(verStr, 12); verStr[12] = '\0'; @@ -172,8 +174,7 @@ void SConnection::processVersionMsg() if (*i == secTypeNone) os->flush(); state_ = RFBSTATE_SECURITY; ssecurity = security.GetSSecurity(this, *i); - processSecurityMsg(); - return; + return true; } // list supported security types for >=3.7 clients @@ -186,15 +187,23 @@ void SConnection::processVersionMsg() os->writeU8(*i); os->flush(); state_ = RFBSTATE_SECURITY_TYPE; + + return true; } -void SConnection::processSecurityTypeMsg() +bool SConnection::processSecurityTypeMsg() { vlog.debug("processing security type message"); + + if (!is->hasData(1)) + return false; + int secType = is->readU8(); processSecurityType(secType); + + return true; } void SConnection::processSecurityType(int secType) @@ -218,16 +227,14 @@ void SConnection::processSecurityType(int secType) } catch (rdr::Exception& e) { throwConnFailedException("%s", e.str()); } - - processSecurityMsg(); } -void SConnection::processSecurityMsg() +bool SConnection::processSecurityMsg() { vlog.debug("processing security message"); try { if (!ssecurity->processMsg()) - return; + return false; } catch (AuthFailureException& e) { vlog.error("AuthFailureException: %s", e.str()); state_ = RFBSTATE_SECURITY_FAILURE; @@ -235,28 +242,41 @@ void SConnection::processSecurityMsg() // to make it difficult to brute force a password authFailureMsg.replaceBuf(strDup(e.str())); authFailureTimer.start(100); + return true; } state_ = RFBSTATE_QUERYING; setAccessRights(ssecurity->getAccessRights()); queryConnection(ssecurity->getUserName()); + + // If the connection got approved right away then we can continue + if (state_ == RFBSTATE_INITIALISATION) + return true; + + // Otherwise we need to wait for the result + // (or give up if if was rejected) + return false; } -void SConnection::processSecurityFailure() +bool SConnection::processSecurityFailure() { // Silently drop any data if we are currently delaying an // authentication failure response as otherwise we would close // the connection on unexpected data, and an attacker could use // that to detect our delayed state. - while (is->checkNoWait(1)) - is->skip(1); + if (!is->hasData(1)) + return false; + + is->skip(is->avail()); + + return true; } -void SConnection::processInitMsg() +bool SConnection::processInitMsg() { vlog.debug("reading client initialisation"); - reader_->readClientInit(); + return reader_->readClientInit(); } bool SConnection::handleAuthFailureTimeout(Timer* t) diff --git a/common/rfb/SConnection.h b/common/rfb/SConnection.h index e7bbf2c3..b333086a 100644 --- a/common/rfb/SConnection.h +++ b/common/rfb/SConnection.h @@ -60,7 +60,7 @@ namespace rfb { // processMsg() should be called whenever there is data to read on the // InStream. You must have called initialiseProtocol() first. - void processMsg(); + bool processMsg(); // approveConnection() is called to either accept or reject the connection. // If accept is false, the reason string gives the reason for the @@ -235,12 +235,12 @@ namespace rfb { bool readyForSetColourMapEntries; - void processVersionMsg(); - void processSecurityTypeMsg(); + bool processVersionMsg(); + bool processSecurityTypeMsg(); void processSecurityType(int secType); - void processSecurityMsg(); - void processSecurityFailure(); - void processInitMsg(); + bool processSecurityMsg(); + bool processSecurityFailure(); + bool processInitMsg(); bool handleAuthFailureTimeout(Timer* t); diff --git a/common/rfb/SMsgReader.cxx b/common/rfb/SMsgReader.cxx index dc7ddea6..944f9315 100644 --- a/common/rfb/SMsgReader.cxx +++ b/common/rfb/SMsgReader.cxx @@ -38,7 +38,7 @@ static LogWriter vlog("SMsgReader"); static IntParameter maxCutText("MaxCutText", "Maximum permitted length of an incoming clipboard update", 256*1024); SMsgReader::SMsgReader(SMsgHandler* handler_, rdr::InStream* is_) - : handler(handler_), is(is_) + : handler(handler_), is(is_), state(MSGSTATE_IDLE) { } @@ -46,71 +46,105 @@ SMsgReader::~SMsgReader() { } -void SMsgReader::readClientInit() +bool SMsgReader::readClientInit() { + if (!is->hasData(1)) + return false; bool shared = is->readU8(); handler->clientInit(shared); + return true; } -void SMsgReader::readMsg() +bool SMsgReader::readMsg() { - int msgType = is->readU8(); - switch (msgType) { + bool ret; + + if (state == MSGSTATE_IDLE) { + if (!is->hasData(1)) + return false; + + currentMsgType = is->readU8(); + state = MSGSTATE_MESSAGE; + } + + switch (currentMsgType) { case msgTypeSetPixelFormat: - readSetPixelFormat(); + ret = readSetPixelFormat(); break; case msgTypeSetEncodings: - readSetEncodings(); + ret = readSetEncodings(); break; case msgTypeSetDesktopSize: - readSetDesktopSize(); + ret = readSetDesktopSize(); break; case msgTypeFramebufferUpdateRequest: - readFramebufferUpdateRequest(); + ret = readFramebufferUpdateRequest(); break; case msgTypeEnableContinuousUpdates: - readEnableContinuousUpdates(); + ret = readEnableContinuousUpdates(); break; case msgTypeClientFence: - readFence(); + ret = readFence(); break; case msgTypeKeyEvent: - readKeyEvent(); + ret = readKeyEvent(); break; case msgTypePointerEvent: - readPointerEvent(); + ret = readPointerEvent(); break; case msgTypeClientCutText: - readClientCutText(); + ret = readClientCutText(); break; case msgTypeQEMUClientMessage: - readQEMUMessage(); + ret = readQEMUMessage(); break; default: - vlog.error("unknown message type %d", msgType); + vlog.error("unknown message type %d", currentMsgType); throw Exception("unknown message type"); } + + if (ret) + state = MSGSTATE_IDLE; + + return ret; } -void SMsgReader::readSetPixelFormat() +bool SMsgReader::readSetPixelFormat() { + if (!is->hasData(3 + 16)) + return false; is->skip(3); PixelFormat pf; pf.read(is); handler->setPixelFormat(pf); + return true; } -void SMsgReader::readSetEncodings() +bool SMsgReader::readSetEncodings() { + if (!is->hasData(3)) + return false; + + is->setRestorePoint(); + is->skip(1); + int nEncodings = is->readU16(); + + if (!is->hasDataOrRestore(nEncodings * 4)) + return false; + is->clearRestorePoint(); + rdr::S32Array encodings(nEncodings); for (int i = 0; i < nEncodings; i++) encodings.buf[i] = is->readU32(); + handler->setEncodings(nEncodings, encodings.buf); + + return true; } -void SMsgReader::readSetDesktopSize() +bool SMsgReader::readSetDesktopSize() { int width, height; int screens, i; @@ -118,6 +152,11 @@ void SMsgReader::readSetDesktopSize() int sx, sy, sw, sh; ScreenSet layout; + if (!is->hasData(7)) + return true; + + is->setRestorePoint(); + is->skip(1); width = is->readU16(); @@ -126,6 +165,10 @@ void SMsgReader::readSetDesktopSize() screens = is->readU8(); is->skip(1); + if (!is->hasDataOrRestore(screens * 24)) + return false; + is->clearRestorePoint(); + for (i = 0;i < screens;i++) { id = is->readU32(); sx = is->readU16(); @@ -138,23 +181,31 @@ void SMsgReader::readSetDesktopSize() } handler->setDesktopSize(width, height, layout); + + return true; } -void SMsgReader::readFramebufferUpdateRequest() +bool SMsgReader::readFramebufferUpdateRequest() { + if (!is->hasData(17)) + return false; bool inc = is->readU8(); int x = is->readU16(); int y = is->readU16(); int w = is->readU16(); int h = is->readU16(); handler->framebufferUpdateRequest(Rect(x, y, x+w, y+h), inc); + return true; } -void SMsgReader::readEnableContinuousUpdates() +bool SMsgReader::readEnableContinuousUpdates() { bool enable; int x, y, w, h; + if (!is->hasData(17)) + return false; + enable = is->readU8(); x = is->readU16(); @@ -163,81 +214,121 @@ void SMsgReader::readEnableContinuousUpdates() h = is->readU16(); handler->enableContinuousUpdates(enable, x, y, w, h); + + return true; } -void SMsgReader::readFence() +bool SMsgReader::readFence() { rdr::U32 flags; rdr::U8 len; char data[64]; + if (!is->hasData(8)) + return false; + + is->setRestorePoint(); + is->skip(3); flags = is->readU32(); len = is->readU8(); + + if (!is->hasDataOrRestore(len)) + return false; + is->clearRestorePoint(); + if (len > sizeof(data)) { vlog.error("Ignoring fence with too large payload"); is->skip(len); - return; + return true; } is->readBytes(data, len); handler->fence(flags, len, data); + + return true; } -void SMsgReader::readKeyEvent() +bool SMsgReader::readKeyEvent() { + if (!is->hasData(7)) + return false; bool down = is->readU8(); is->skip(2); rdr::U32 key = is->readU32(); handler->keyEvent(key, 0, down); + return true; } -void SMsgReader::readPointerEvent() +bool SMsgReader::readPointerEvent() { + if (!is->hasData(5)) + return false; int mask = is->readU8(); int x = is->readU16(); int y = is->readU16(); handler->pointerEvent(Point(x, y), mask); + return true; } -void SMsgReader::readClientCutText() +bool SMsgReader::readClientCutText() { + if (!is->hasData(7)) + return false; + + is->setRestorePoint(); + is->skip(3); rdr::U32 len = is->readU32(); if (len & 0x80000000) { rdr::S32 slen = len; slen = -slen; - readExtendedClipboard(slen); - return; + if (readExtendedClipboard(slen)) { + is->clearRestorePoint(); + return true; + } else { + is->gotoRestorePoint(); + return false; + } } + if (!is->hasDataOrRestore(len)) + return false; + is->clearRestorePoint(); + if (len > (size_t)maxCutText) { is->skip(len); vlog.error("Cut text too long (%d bytes) - ignoring", len); - return; + return true; } + CharArray ca(len); is->readBytes(ca.buf, len); CharArray filtered(convertLF(ca.buf, len)); handler->clientCutText(filtered.buf); + + return true; } -void SMsgReader::readExtendedClipboard(rdr::S32 len) +bool SMsgReader::readExtendedClipboard(rdr::S32 len) { rdr::U32 flags; rdr::U32 action; + if (!is->hasData(len)) + return false; + if (len < 4) throw Exception("Invalid extended clipboard message"); if (len > maxCutText) { vlog.error("Extended clipboard message too long (%d bytes) - ignoring", len); is->skip(len); - return; + return true; } flags = is->readU32(); @@ -279,7 +370,14 @@ void SMsgReader::readExtendedClipboard(rdr::S32 len) if (!(flags & 1 << i)) continue; + if (!zis.hasData(4)) + throw Exception("Extended clipboard decode error"); + lengths[num] = zis.readU32(); + + if (!zis.hasData(lengths[num])) + throw Exception("Extended clipboard decode error"); + if (lengths[num] > (size_t)maxCutText) { vlog.error("Extended clipboard data too long (%d bytes) - ignoring", (unsigned)lengths[num]); @@ -319,28 +417,50 @@ void SMsgReader::readExtendedClipboard(rdr::S32 len) throw Exception("Invalid extended clipboard action"); } } + + return true; } -void SMsgReader::readQEMUMessage() +bool SMsgReader::readQEMUMessage() { - int subType = is->readU8(); + int subType; + bool ret; + + if (!is->hasData(1)) + return false; + + is->setRestorePoint(); + + subType = is->readU8(); + switch (subType) { case qemuExtendedKeyEvent: - readQEMUKeyEvent(); + ret = readQEMUKeyEvent(); break; default: throw Exception("unknown QEMU submessage type %d", subType); } + + if (!ret) { + is->gotoRestorePoint(); + return false; + } else { + is->clearRestorePoint(); + return true; + } } -void SMsgReader::readQEMUKeyEvent() +bool SMsgReader::readQEMUKeyEvent() { + if (!is->hasData(10)) + return false; bool down = is->readU16(); rdr::U32 keysym = is->readU32(); rdr::U32 keycode = is->readU32(); if (!keycode) { vlog.error("Key event without keycode - ignoring"); - return; + return true; } handler->keyEvent(keysym, keycode, down); + return true; } diff --git a/common/rfb/SMsgReader.h b/common/rfb/SMsgReader.h index 4991fd38..acc872ed 100644 --- a/common/rfb/SMsgReader.h +++ b/common/rfb/SMsgReader.h @@ -34,33 +34,43 @@ namespace rfb { SMsgReader(SMsgHandler* handler, rdr::InStream* is); virtual ~SMsgReader(); - void readClientInit(); + bool readClientInit(); // readMsg() reads a message, calling the handler as appropriate. - void readMsg(); + bool readMsg(); rdr::InStream* getInStream() { return is; } protected: - void readSetPixelFormat(); - void readSetEncodings(); - void readSetDesktopSize(); + bool readSetPixelFormat(); + bool readSetEncodings(); + bool readSetDesktopSize(); - void readFramebufferUpdateRequest(); - void readEnableContinuousUpdates(); + bool readFramebufferUpdateRequest(); + bool readEnableContinuousUpdates(); - void readFence(); + bool readFence(); - void readKeyEvent(); - void readPointerEvent(); - void readClientCutText(); - void readExtendedClipboard(rdr::S32 len); + bool readKeyEvent(); + bool readPointerEvent(); + bool readClientCutText(); + bool readExtendedClipboard(rdr::S32 len); - void readQEMUMessage(); - void readQEMUKeyEvent(); + bool readQEMUMessage(); + bool readQEMUKeyEvent(); + private: SMsgHandler* handler; rdr::InStream* is; + + enum stateEnum { + MSGSTATE_IDLE, + MSGSTATE_MESSAGE, + }; + + stateEnum state; + + rdr::U8 currentMsgType; }; } #endif diff --git a/common/rfb/SSecurityPlain.cxx b/common/rfb/SSecurityPlain.cxx index f577c0d6..6ae19557 100644 --- a/common/rfb/SSecurityPlain.cxx +++ b/common/rfb/SSecurityPlain.cxx @@ -84,7 +84,7 @@ bool SSecurityPlain::processMsg() throw AuthFailureException("No password validator configured"); if (state == 0) { - if (!is->checkNoWait(8)) + if (!is->hasData(8)) return false; ulen = is->readU32(); @@ -99,7 +99,7 @@ bool SSecurityPlain::processMsg() } if (state == 1) { - if (!is->checkNoWait(ulen + plen)) + if (!is->hasData(ulen + plen)) return false; state = 2; pw = new char[plen + 1]; diff --git a/common/rfb/SSecurityVeNCrypt.cxx b/common/rfb/SSecurityVeNCrypt.cxx index d522ef6f..135742c0 100644 --- a/common/rfb/SSecurityVeNCrypt.cxx +++ b/common/rfb/SSecurityVeNCrypt.cxx @@ -78,19 +78,21 @@ bool SSecurityVeNCrypt::processMsg() os->writeU8(2); haveSentVersion = true; os->flush(); - - return false; } /* Receive back highest version that client can support (up to and including ours) */ if (!haveRecvdMajorVersion) { + if (!is->hasData(1)) + return false; + majorVersion = is->readU8(); haveRecvdMajorVersion = true; - - return false; } if (!haveRecvdMinorVersion) { + if (!is->hasData(1)) + return false; + minorVersion = is->readU8(); haveRecvdMinorVersion = true; @@ -140,14 +142,15 @@ bool SSecurityVeNCrypt::processMsg() os->flush(); haveSentTypes = true; - return false; } else throw AuthFailureException("There are no VeNCrypt sub-types to send to the client"); } /* get type back from client (must be one of the ones we sent) */ if (!haveChosenType) { - is->check(4); + if (!is->hasData(4)) + return false; + chosenType = is->readU32(); for (i = 0; i < numTypes; i++) { diff --git a/common/rfb/SSecurityVncAuth.cxx b/common/rfb/SSecurityVncAuth.cxx index 882f0b08..c2a348b9 100644 --- a/common/rfb/SSecurityVncAuth.cxx +++ b/common/rfb/SSecurityVncAuth.cxx @@ -49,7 +49,7 @@ VncAuthPasswdParameter SSecurityVncAuth::vncAuthPasswd "access the server", &SSecurityVncAuth::vncAuthPasswdFile); SSecurityVncAuth::SSecurityVncAuth(SConnection* sc) - : SSecurity(sc), sentChallenge(false), responsePos(0), + : SSecurity(sc), sentChallenge(false), pg(&vncAuthPasswd), accessRights(0) { } @@ -78,6 +78,8 @@ bool SSecurityVncAuth::processMsg() if (!sentChallenge) { rdr::RandomStream rs; + if (!rs.hasData(vncAuthChallengeSize)) + throw Exception("Could not generate random data for VNC auth challenge"); rs.readBytes(challenge, vncAuthChallengeSize); os->writeBytes(challenge, vncAuthChallengeSize); os->flush(); @@ -85,10 +87,10 @@ bool SSecurityVncAuth::processMsg() return false; } - while (responsePos < vncAuthChallengeSize && is->checkNoWait(1)) - response[responsePos++] = is->readU8(); + if (!is->hasData(vncAuthChallengeSize)) + return false; - if (responsePos < vncAuthChallengeSize) return false; + is->readBytes(response, vncAuthChallengeSize); PlainPasswd passwd, passwdReadOnly; pg->getVncAuthPasswd(&passwd, &passwdReadOnly); diff --git a/common/rfb/SSecurityVncAuth.h b/common/rfb/SSecurityVncAuth.h index fe00b031..94d5aaf2 100644 --- a/common/rfb/SSecurityVncAuth.h +++ b/common/rfb/SSecurityVncAuth.h @@ -64,7 +64,6 @@ namespace rfb { rdr::U8 challenge[vncAuthChallengeSize]; rdr::U8 response[vncAuthChallengeSize]; bool sentChallenge; - int responsePos; VncAuthPasswdGetter* pg; SConnection::AccessRights accessRights; }; diff --git a/common/rfb/ServerCore.cxx b/common/rfb/ServerCore.cxx index b1097a3e..8f49848c 100644 --- a/common/rfb/ServerCore.cxx +++ b/common/rfb/ServerCore.cxx @@ -42,11 +42,6 @@ rfb::IntParameter rfb::Server::maxIdleTime ("MaxIdleTime", "Terminate after s seconds of user inactivity", 0, 0); -rfb::IntParameter rfb::Server::clientWaitTimeMillis -("ClientWaitTimeMillis", - "The number of milliseconds to wait for a client which is no longer " - "responding", - 20000, 0); rfb::IntParameter rfb::Server::compareFB ("CompareFB", "Perform pixel comparison on framebuffer to reduce unnecessary updates " diff --git a/common/rfb/ServerCore.h b/common/rfb/ServerCore.h index f915c7a7..20a740a8 100644 --- a/common/rfb/ServerCore.h +++ b/common/rfb/ServerCore.h @@ -36,7 +36,6 @@ namespace rfb { static IntParameter maxDisconnectionTime; static IntParameter maxConnectionTime; static IntParameter maxIdleTime; - static IntParameter clientWaitTimeMillis; static IntParameter compareFB; static IntParameter frameRate; static BoolParameter protocol3_3; diff --git a/common/rfb/TightDecoder.cxx b/common/rfb/TightDecoder.cxx index ebc98b06..fe03e453 100644 --- a/common/rfb/TightDecoder.cxx +++ b/common/rfb/TightDecoder.cxx @@ -54,11 +54,16 @@ TightDecoder::~TightDecoder() { } -void TightDecoder::readRect(const Rect& r, rdr::InStream* is, +bool TightDecoder::readRect(const Rect& r, rdr::InStream* is, const ServerParams& server, rdr::OutStream* os) { rdr::U8 comp_ctl; + if (!is->hasData(1)) + return false; + + is->setRestorePoint(); + comp_ctl = is->readU8(); os->writeU8(comp_ctl); @@ -66,21 +71,38 @@ void TightDecoder::readRect(const Rect& r, rdr::InStream* is, // "Fill" compression type. if (comp_ctl == tightFill) { - if (server.pf().is888()) + if (server.pf().is888()) { + if (!is->hasDataOrRestore(3)) + return false; os->copyBytes(is, 3); - else + } else { + if (!is->hasDataOrRestore(server.pf().bpp/8)) + return false; os->copyBytes(is, server.pf().bpp/8); - return; + } + is->clearRestorePoint(); + return true; } // "JPEG" compression type. if (comp_ctl == tightJpeg) { rdr::U32 len; + // FIXME: Might be less than 3 bytes + if (!is->hasDataOrRestore(3)) + return false; + len = readCompact(is); os->writeOpaque32(len); + + if (!is->hasDataOrRestore(len)) + return false; + os->copyBytes(is, len); - return; + + is->clearRestorePoint(); + + return true; } // Quit on unsupported compression type. @@ -98,18 +120,29 @@ void TightDecoder::readRect(const Rect& r, rdr::InStream* is, if ((comp_ctl & tightExplicitFilter) != 0) { rdr::U8 filterId; + if (!is->hasDataOrRestore(1)) + return false; + filterId = is->readU8(); os->writeU8(filterId); switch (filterId) { case tightFilterPalette: + if (!is->hasDataOrRestore(1)) + return false; + palSize = is->readU8() + 1; os->writeU8(palSize - 1); - if (server.pf().is888()) + if (server.pf().is888()) { + if (!is->hasDataOrRestore(palSize * 3)) + return false; os->copyBytes(is, palSize * 3); - else + } else { + if (!is->hasDataOrRestore(palSize * server.pf().bpp/8)) + return false; os->copyBytes(is, palSize * server.pf().bpp/8); + } break; case tightFilterGradient: if (server.pf().bpp == 8) @@ -137,15 +170,29 @@ void TightDecoder::readRect(const Rect& r, rdr::InStream* is, dataSize = r.height() * rowSize; - if (dataSize < TIGHT_MIN_TO_COMPRESS) + if (dataSize < TIGHT_MIN_TO_COMPRESS) { + if (!is->hasDataOrRestore(dataSize)) + return false; os->copyBytes(is, dataSize); - else { + } else { rdr::U32 len; + // FIXME: Might be less than 3 bytes + if (!is->hasDataOrRestore(3)) + return false; + len = readCompact(is); os->writeOpaque32(len); + + if (!is->hasDataOrRestore(len)) + return false; + os->copyBytes(is, len); } + + is->clearRestorePoint(); + + return true; } bool TightDecoder::doRectsConflict(const Rect& rectA, @@ -339,6 +386,8 @@ void TightDecoder::decodeRect(const Rect& r, const void* buffer, // Allocate buffer and decompress the data netbuf = new rdr::U8[dataSize]; + if (!zis[streamId].hasData(dataSize)) + throw Exception("Tight decode error"); zis[streamId].readBytes(netbuf, dataSize); zis[streamId].flushUnderlying(); diff --git a/common/rfb/TightDecoder.h b/common/rfb/TightDecoder.h index 28b6c30f..763c82d6 100644 --- a/common/rfb/TightDecoder.h +++ b/common/rfb/TightDecoder.h @@ -31,7 +31,7 @@ namespace rfb { public: TightDecoder(); virtual ~TightDecoder(); - virtual void readRect(const Rect& r, rdr::InStream* is, + virtual bool readRect(const Rect& r, rdr::InStream* is, const ServerParams& server, rdr::OutStream* os); virtual bool doRectsConflict(const Rect& rectA, const void* bufferA, diff --git a/common/rfb/VNCSConnectionST.cxx b/common/rfb/VNCSConnectionST.cxx index 00f640b3..c4ec733b 100644 --- a/common/rfb/VNCSConnectionST.cxx +++ b/common/rfb/VNCSConnectionST.cxx @@ -57,9 +57,6 @@ VNCSConnectionST::VNCSConnectionST(VNCServerST* server_, network::Socket *s, setStreams(&sock->inStream(), &sock->outStream()); peerEndpoint.buf = sock->getPeerEndpoint(); - // Configure the socket - setSocketTimeouts(); - // Kick off the idle timer if (rfb::Server::idleTimeout) { // minimum of 15 seconds while authenticating @@ -152,26 +149,23 @@ void VNCSConnectionST::processMessages() { if (state() == RFBSTATE_CLOSING) return; try { - // - Now set appropriate socket timeouts and process data - setSocketTimeouts(); - inProcessMessages = true; // Get the underlying transport to build large packets if we send // multiple small responses. getOutStream()->cork(true); - while (getInStream()->checkNoWait(1)) { - if (pendingSyncFence) { + while (true) { + if (pendingSyncFence) syncFence = true; - pendingSyncFence = false; - } - processMsg(); + if (!processMsg()) + break; if (syncFence) { writer()->writeFence(fenceFlags, fenceDataLen, fenceData); syncFence = false; + pendingSyncFence = false; } } @@ -195,7 +189,6 @@ void VNCSConnectionST::flushSocket() { if (state() == RFBSTATE_CLOSING) return; try { - setSocketTimeouts(); sock->outStream().flush(); // Flushing the socket might release an update that was previously // delayed because of congestion. @@ -1150,12 +1143,3 @@ void VNCSConnectionST::setLEDState(unsigned int ledstate) if (client.supportsLEDState()) writer()->writeLEDState(); } - -void VNCSConnectionST::setSocketTimeouts() -{ - int timeoutms = rfb::Server::clientWaitTimeMillis; - if (timeoutms == 0) - timeoutms = -1; - sock->inStream().setTimeout(timeoutms); - sock->outStream().setTimeout(timeoutms); -} diff --git a/common/rfb/VNCSConnectionST.h b/common/rfb/VNCSConnectionST.h index 46a2b28b..06fdf541 100644 --- a/common/rfb/VNCSConnectionST.h +++ b/common/rfb/VNCSConnectionST.h @@ -155,7 +155,6 @@ namespace rfb { void setCursor(); void setDesktopName(const char *name); void setLEDState(unsigned int state); - void setSocketTimeouts(); private: network::Socket* sock; diff --git a/common/rfb/ZRLEDecoder.cxx b/common/rfb/ZRLEDecoder.cxx index 9d1ff6b6..4fba0c22 100644 --- a/common/rfb/ZRLEDecoder.cxx +++ b/common/rfb/ZRLEDecoder.cxx @@ -21,6 +21,7 @@ #include #include +#include #include #include #include @@ -29,7 +30,6 @@ using namespace rfb; static inline rdr::U32 readOpaque24A(rdr::InStream* is) { - is->check(3); rdr::U32 r=0; ((rdr::U8*)&r)[0] = is->readU8(); ((rdr::U8*)&r)[1] = is->readU8(); @@ -39,7 +39,6 @@ static inline rdr::U32 readOpaque24A(rdr::InStream* is) } static inline rdr::U32 readOpaque24B(rdr::InStream* is) { - is->check(3); rdr::U32 r=0; ((rdr::U8*)&r)[1] = is->readU8(); ((rdr::U8*)&r)[2] = is->readU8(); @@ -47,6 +46,12 @@ static inline rdr::U32 readOpaque24B(rdr::InStream* is) return r; } +static inline void zlibHasData(rdr::ZlibInStream* zis, size_t length) +{ + if (!zis->hasData(length)) + throw Exception("ZRLE decode error"); +} + #define BPP 8 #include #undef BPP @@ -71,14 +76,27 @@ ZRLEDecoder::~ZRLEDecoder() { } -void ZRLEDecoder::readRect(const Rect& r, rdr::InStream* is, +bool ZRLEDecoder::readRect(const Rect& r, rdr::InStream* is, const ServerParams& server, rdr::OutStream* os) { rdr::U32 len; + if (!is->hasData(4)) + return false; + + is->setRestorePoint(); + len = is->readU32(); os->writeU32(len); + + if (!is->hasDataOrRestore(len)) + return false; + + is->clearRestorePoint(); + os->copyBytes(is, len); + + return true; } void ZRLEDecoder::decodeRect(const Rect& r, const void* buffer, diff --git a/common/rfb/ZRLEDecoder.h b/common/rfb/ZRLEDecoder.h index a530586e..115f8fb8 100644 --- a/common/rfb/ZRLEDecoder.h +++ b/common/rfb/ZRLEDecoder.h @@ -27,7 +27,7 @@ namespace rfb { public: ZRLEDecoder(); virtual ~ZRLEDecoder(); - virtual void readRect(const Rect& r, rdr::InStream* is, + virtual bool readRect(const Rect& r, rdr::InStream* is, const ServerParams& server, rdr::OutStream* os); virtual void decodeRect(const Rect& r, const void* buffer, size_t buflen, const ServerParams& server, diff --git a/common/rfb/zrleDecode.h b/common/rfb/zrleDecode.h index f4325385..998e51ed 100644 --- a/common/rfb/zrleDecode.h +++ b/common/rfb/zrleDecode.h @@ -22,11 +22,6 @@ // This file is #included after having set the following macro: // BPP - 8, 16 or 32 -#include -#include -#include -#include - namespace rfb { // CONCAT2E concatenates its arguments, expanding them if they are macros @@ -63,11 +58,17 @@ void ZRLE_DECODE (const Rect& r, rdr::InStream* is, t.br.x = __rfbmin(r.br.x, t.tl.x + 64); + zlibHasData(zis, 1); int mode = zis->readU8(); bool rle = mode & 128; int palSize = mode & 127; PIXEL_T palette[128]; +#ifdef CPIXEL + zlibHasData(zis, 3 * palSize); +#else + zlibHasData(zis, BPP/8 * palSize); +#endif for (int i = 0; i < palSize; i++) { palette[i] = READ_PIXEL(zis); } @@ -84,10 +85,12 @@ void ZRLE_DECODE (const Rect& r, rdr::InStream* is, // raw #ifdef CPIXEL + zlibHasData(zis, 3 * t.area()); for (PIXEL_T* ptr = buf; ptr < buf+t.area(); ptr++) { *ptr = READ_PIXEL(zis); } #else + zlibHasData(zis, BPP/8 * t.area()); zis->readBytes(buf, t.area() * (BPP / 8)); #endif @@ -106,6 +109,7 @@ void ZRLE_DECODE (const Rect& r, rdr::InStream* is, while (ptr < eol) { if (nbits == 0) { + zlibHasData(zis, 1); byte = zis->readU8(); nbits = 8; } @@ -125,10 +129,16 @@ void ZRLE_DECODE (const Rect& r, rdr::InStream* is, PIXEL_T* ptr = buf; PIXEL_T* end = ptr + t.area(); while (ptr < end) { +#ifdef CPIXEL + zlibHasData(zis, 3); +#else + zlibHasData(zis, BPP/8); +#endif PIXEL_T pix = READ_PIXEL(zis); int len = 1; int b; do { + zlibHasData(zis, 1); b = zis->readU8(); len += b; } while (b == 255); @@ -147,11 +157,13 @@ void ZRLE_DECODE (const Rect& r, rdr::InStream* is, PIXEL_T* ptr = buf; PIXEL_T* end = ptr + t.area(); while (ptr < end) { + zlibHasData(zis, 1); int index = zis->readU8(); int len = 1; if (index & 128) { int b; do { + zlibHasData(zis, 1); b = zis->readU8(); len += b; } while (b == 255); diff --git a/tests/perf/decperf.cxx b/tests/perf/decperf.cxx index e1307070..a6c65a22 100644 --- a/tests/perf/decperf.cxx +++ b/tests/perf/decperf.cxx @@ -102,6 +102,8 @@ void DummyOutStream::flush() void DummyOutStream::overrun(size_t needed) { flush(); + if (avail() < needed) + throw rdr::Exception("Insufficient dummy output buffer"); } CConn::CConn(const char *filename) diff --git a/tests/perf/encperf.cxx b/tests/perf/encperf.cxx index 6bcb6f74..41c309c1 100644 --- a/tests/perf/encperf.cxx +++ b/tests/perf/encperf.cxx @@ -95,7 +95,7 @@ public: virtual void setCursor(int, int, const rfb::Point&, const rdr::U8*); virtual void framebufferUpdateStart(); virtual void framebufferUpdateEnd(); - virtual void dataRect(const rfb::Rect&, int); + virtual bool dataRect(const rfb::Rect&, int); virtual void setColourMapEntries(int, int, rdr::U16*); virtual void bell(); virtual void serverCutText(const char*); @@ -159,6 +159,8 @@ void DummyOutStream::flush() void DummyOutStream::overrun(size_t needed) { flush(); + if (avail() < needed) + throw rdr::Exception("Insufficient dummy output buffer"); } CConn::CConn(const char *filename) @@ -241,12 +243,15 @@ void CConn::framebufferUpdateEnd() encodeTime += getCpuCounter(); } -void CConn::dataRect(const rfb::Rect &r, int encoding) +bool CConn::dataRect(const rfb::Rect &r, int encoding) { - CConnection::dataRect(r, encoding); + if (!CConnection::dataRect(r, encoding)) + return false; if (encoding != rfb::encodingCopyRect) // FIXME updates.add_changed(rfb::Region(r)); + + return true; } void CConn::setColourMapEntries(int, int, rdr::U16*) diff --git a/unix/vncserver/vncserver.in b/unix/vncserver/vncserver.in index 9f3a4750..2964df33 100755 --- a/unix/vncserver/vncserver.in +++ b/unix/vncserver/vncserver.in @@ -103,7 +103,6 @@ my %config; # override these where present. $default_opts{desktop} = $desktopName; $default_opts{auth} = $xauthorityFile; -$default_opts{rfbwait} = 30000; $default_opts{rfbauth} = "$vncUserDir/passwd"; $default_opts{rfbport} = $vncPort; $default_opts{fp} = $fontPath if ($fontPath); diff --git a/unix/x0vncserver/x0vncserver.cxx b/unix/x0vncserver/x0vncserver.cxx index a9782ada..1531de60 100644 --- a/unix/x0vncserver/x0vncserver.cxx +++ b/unix/x0vncserver/x0vncserver.cxx @@ -352,7 +352,6 @@ int main(int argc, char** argv) if (FD_ISSET((*i)->getFd(), &rfds)) { Socket* sock = (*i)->accept(); if (sock) { - sock->outStream().setBlocking(false); server.addSocket(sock); } else { vlog.status("Client connection rejected"); diff --git a/unix/x0vncserver/x0vncserver.man b/unix/x0vncserver/x0vncserver.man index b54fcb48..094abbe9 100644 --- a/unix/x0vncserver/x0vncserver.man +++ b/unix/x0vncserver/x0vncserver.man @@ -298,13 +298,6 @@ Terminate when a client has been connected for \fIN\fP seconds. Default is Terminate after \fIN\fP seconds of user inactivity. Default is 0. . .TP -.B \-ClientWaitTimeMillis \fItime\fP -Time in milliseconds to wait for a viewer which is blocking the server. This is -necessary because the server is single-threaded and sometimes blocks until the -viewer has finished sending or receiving a message - note that this does not -mean an update will be aborted after this time. Default is 20000 (20 seconds). -. -.TP .B \-AcceptCutText .TQ .B \-SendCutText diff --git a/unix/xserver/hw/vnc/XserverDesktop.cc b/unix/xserver/hw/vnc/XserverDesktop.cc index 8215c936..6f707299 100644 --- a/unix/xserver/hw/vnc/XserverDesktop.cc +++ b/unix/xserver/hw/vnc/XserverDesktop.cc @@ -311,7 +311,6 @@ bool XserverDesktop::handleListenerEvent(int fd, return false; Socket* sock = (*i)->accept(); - sock->outStream().setBlocking(false); vlog.debug("new client, sock %d", sock->getFd()); sockserv->addSocket(sock); vncSetNotifyFd(sock->getFd(), screenIndex, true, false); @@ -393,7 +392,6 @@ void XserverDesktop::blockHandler(int* timeout) void XserverDesktop::addClient(Socket* sock, bool reverse) { vlog.debug("new client, sock %d reverse %d",sock->getFd(),reverse); - sock->outStream().setBlocking(false); server->addSocket(sock, reverse); vncSetNotifyFd(sock->getFd(), screenIndex, true, false); } diff --git a/unix/xserver/hw/vnc/Xvnc.man b/unix/xserver/hw/vnc/Xvnc.man index 83621c08..2d0089d7 100644 --- a/unix/xserver/hw/vnc/Xvnc.man +++ b/unix/xserver/hw/vnc/Xvnc.man @@ -98,13 +98,6 @@ connections from viewers, instead of listening on a TCP port. Specifies the mode of the Unix domain socket. The default is 0600. . .TP -.B \-rfbwait \fItime\fP, \-ClientWaitTimeMillis \fItime\fP -Time in milliseconds to wait for a viewer which is blocking the server. This is -necessary because the server is single-threaded and sometimes blocks until the -viewer has finished sending or receiving a message - note that this does not -mean an update will be aborted after this time. Default is 20000 (20 seconds). -. -.TP .B \-rfbauth \fIpasswd-file\fP, \-PasswordFile \fIpasswd-file\fP Password file for VNC authentication. There is no default, you should specify the password file explicitly. Password file should be created with diff --git a/unix/xserver/hw/vnc/vncExtInit.cc b/unix/xserver/hw/vnc/vncExtInit.cc index 6ab306b1..43f83088 100644 --- a/unix/xserver/hw/vnc/vncExtInit.cc +++ b/unix/xserver/hw/vnc/vncExtInit.cc @@ -73,8 +73,6 @@ struct CaseInsensitiveCompare { typedef std::set ParamSet; static ParamSet allowOverrideSet; -rfb::AliasParameter rfbwait("rfbwait", "Alias for ClientWaitTimeMillis", - &rfb::Server::clientWaitTimeMillis); rfb::IntParameter rfbport("rfbport", "TCP port to listen for RFB protocol",0); rfb::StringParameter rfbunixpath("rfbunixpath", "Unix socket to listen for RFB protocol", ""); rfb::IntParameter rfbunixmode("rfbunixmode", "Unix socket access mode", 0600); diff --git a/vncviewer/CConn.cxx b/vncviewer/CConn.cxx index e7362c8e..68dd031b 100644 --- a/vncviewer/CConn.cxx +++ b/vncviewer/CConn.cxx @@ -116,9 +116,6 @@ CConn::CConn(const char* vncServerName, network::Socket* socket=NULL) Fl::add_fd(sock->getFd(), FL_READ | FL_EXCEPT, socketEvent, this); - // See callback below - sock->inStream().setBlockCallback(this); - setServerName(serverHost); setStreams(&sock->inStream(), &sock->outStream()); @@ -228,22 +225,11 @@ unsigned CConn::getPosition() return sock->inStream().pos(); } -// The RFB core is not properly asynchronous, so it calls this callback -// whenever it needs to block to wait for more data. Since FLTK is -// monitoring the socket, we just make sure FLTK gets to run. - -void CConn::blockCallback() -{ - run_mainloop(); - - if (should_exit()) - throw rdr::Exception("Termination requested"); -} - void CConn::socketEvent(FL_SOCKET fd, void *data) { CConn *cc; static bool recursing = false; + int when; assert(data); cc = (CConn*)data; @@ -255,10 +241,14 @@ void CConn::socketEvent(FL_SOCKET fd, void *data) recursing = true; try { + // We might have been called to flush unwritten socket data + cc->sock->outStream().flush(); + + cc->sock->outStream().cork(true); + // processMsg() only processes one message, so we need to loop // until the buffers are empty or things will stall. - do { - cc->processMsg(); + while (cc->processMsg()) { // Make sure that the FLTK handling and the timers gets some CPU // time in case of back to back messages @@ -268,7 +258,10 @@ void CConn::socketEvent(FL_SOCKET fd, void *data) // Also check if we need to stop reading and terminate if (should_exit()) break; - } while (cc->getInStream()->checkNoWait(1)); + } + + cc->sock->outStream().cork(false); + cc->sock->outStream().flush(); } catch (rdr::EndOfStream& e) { vlog.info("%s", e.str()); exit_vncviewer(); @@ -280,6 +273,12 @@ void CConn::socketEvent(FL_SOCKET fd, void *data) exit_vncviewer(e.str()); } + when = FL_READ | FL_EXCEPT; + if (cc->sock->outStream().hasBufferedData()) + when |= FL_WRITE; + + Fl::add_fd(fd, when, socketEvent, data); + recursing = false; } @@ -402,14 +401,19 @@ void CConn::bell() fl_beep(); } -void CConn::dataRect(const Rect& r, int encoding) +bool CConn::dataRect(const Rect& r, int encoding) { + bool ret; + if (encoding != encodingCopyRect) lastServerEncoding = encoding; - CConnection::dataRect(r, encoding); + ret = CConnection::dataRect(r, encoding); + + if (ret) + pixelCount += r.area(); - pixelCount += r.area(); + return ret; } void CConn::setCursor(int width, int height, const Point& hotspot, diff --git a/vncviewer/CConn.h b/vncviewer/CConn.h index 25dff875..ad3fb797 100644 --- a/vncviewer/CConn.h +++ b/vncviewer/CConn.h @@ -29,8 +29,7 @@ namespace network { class Socket; } class DesktopWindow; -class CConn : public rfb::CConnection, - public rdr::FdInStreamBlockCallback +class CConn : public rfb::CConnection { public: CConn(const char* vncServerName, network::Socket* sock); @@ -42,9 +41,6 @@ public: unsigned getPixelCount(); unsigned getPosition(); - // FdInStreamBlockCallback methods - void blockCallback(); - // Callback when socket is ready (or broken) static void socketEvent(FL_SOCKET fd, void *data); @@ -63,7 +59,7 @@ public: void framebufferUpdateStart(); void framebufferUpdateEnd(); - void dataRect(const rfb::Rect& r, int encoding); + bool dataRect(const rfb::Rect& r, int encoding); void setCursor(int width, int height, const rfb::Point& hotspot, const rdr::U8* data); diff --git a/win/rfb_win32/SocketManager.cxx b/win/rfb_win32/SocketManager.cxx index 0092d94d..393e2191 100644 --- a/win/rfb_win32/SocketManager.cxx +++ b/win/rfb_win32/SocketManager.cxx @@ -170,6 +170,13 @@ int SocketManager::checkTimeouts() { j_next = j; j_next++; if (j->second.sock->isShutdown()) shutdownSocks.push_back(j->second.sock); + else { + long eventMask = FD_READ | FD_CLOSE; + if (j->second.sock->outStream().hasBufferedData()) + eventMask |= FD_WRITE; + if (WSAEventSelect(j->second.sock->getFd(), j->first, eventMask) == SOCKET_ERROR) + throw rdr::SystemException("unable to adjust WSAEventSelect:%u", WSAGetLastError()); + } } std::list::iterator k; @@ -213,6 +220,13 @@ void SocketManager::processEvent(HANDLE event) { try { // Process data from an active connection + WSANETWORKEVENTS events; + long eventMask; + + // Fetch why this event notification triggered + if (WSAEnumNetworkEvents(ci.sock->getFd(), event, &events) == SOCKET_ERROR) + throw rdr::SystemException("unable to get WSAEnumNetworkEvents:%u", WSAGetLastError()); + // Cancel event notification for this socket if (WSAEventSelect(ci.sock->getFd(), event, 0) == SOCKET_ERROR) throw rdr::SystemException("unable to disable WSAEventSelect:%u", WSAGetLastError()); @@ -220,16 +234,29 @@ void SocketManager::processEvent(HANDLE event) { // Reset the event object WSAResetEvent(event); + // Call the socket server to process the event - ci.server->processSocketReadEvent(ci.sock); - if (ci.sock->isShutdown()) { - remSocket(ci.sock); - return; + if (events.lNetworkEvents & FD_WRITE) { + ci.server->processSocketWriteEvent(ci.sock); + if (ci.sock->isShutdown()) { + remSocket(ci.sock); + return; + } + } + if (events.lNetworkEvents & (FD_READ | FD_CLOSE)) { + ci.server->processSocketReadEvent(ci.sock); + if (ci.sock->isShutdown()) { + remSocket(ci.sock); + return; + } } // Re-instate the required socket event // If the read event is still valid, the event object gets set here - if (WSAEventSelect(ci.sock->getFd(), event, FD_READ | FD_CLOSE) == SOCKET_ERROR) + eventMask = FD_READ | FD_CLOSE; + if (ci.sock->outStream().hasBufferedData()) + eventMask |= FD_WRITE; + if (WSAEventSelect(ci.sock->getFd(), event, eventMask) == SOCKET_ERROR) throw rdr::SystemException("unable to re-enable WSAEventSelect:%u", WSAGetLastError()); } catch (rdr::Exception& e) { vlog.error("%s", e.str());