diff options
Diffstat (limited to 'common')
58 files changed, 1037 insertions, 516 deletions
diff --git a/common/rdr/BufferedInStream.cxx b/common/rdr/BufferedInStream.cxx index 14b73563..5a2694b4 100644 --- a/common/rdr/BufferedInStream.cxx +++ b/common/rdr/BufferedInStream.cxx @@ -47,7 +47,7 @@ size_t BufferedInStream::pos() return offset + ptr - start; } -bool BufferedInStream::overrun(size_t needed, bool wait) +bool BufferedInStream::overrun(size_t needed) { struct timeval now; @@ -112,7 +112,7 @@ bool BufferedInStream::overrun(size_t needed, bool wait) } while (avail() < needed) { - if (!fillBuffer(start + bufSize - end, wait)) + if (!fillBuffer(start + bufSize - end)) return false; } diff --git a/common/rdr/BufferedInStream.h b/common/rdr/BufferedInStream.h index 24b5a23c..84405255 100644 --- a/common/rdr/BufferedInStream.h +++ b/common/rdr/BufferedInStream.h @@ -38,9 +38,9 @@ namespace rdr { virtual size_t pos(); private: - virtual bool fillBuffer(size_t maxSize, bool wait) = 0; + virtual bool fillBuffer(size_t maxSize) = 0; - virtual bool overrun(size_t needed, bool wait); + virtual bool overrun(size_t needed); private: size_t bufSize; diff --git a/common/rdr/BufferedOutStream.cxx b/common/rdr/BufferedOutStream.cxx index 930b80b9..c8f6ce9c 100644 --- a/common/rdr/BufferedOutStream.cxx +++ b/common/rdr/BufferedOutStream.cxx @@ -60,7 +60,7 @@ void BufferedOutStream::flush() len = (ptr - sentUpTo); - if (!flushBuffer(false)) + if (!flushBuffer()) break; offset += len - (ptr - sentUpTo); @@ -148,4 +148,6 @@ void BufferedOutStream::overrun(size_t needed) gettimeofday(&lastSizeCheck, NULL); peakUsage = totalNeeded; + + return; } diff --git a/common/rdr/BufferedOutStream.h b/common/rdr/BufferedOutStream.h index dd64a136..b01d1fee 100644 --- a/common/rdr/BufferedOutStream.h +++ b/common/rdr/BufferedOutStream.h @@ -45,10 +45,9 @@ namespace rdr { private: // flushBuffer() requests that the stream be flushed. Returns true if it is // able to progress the output (which might still not mean any bytes - // actually moved) and can be called again. If wait is true then it will - // block until all data has been written. + // actually moved) and can be called again. - virtual bool flushBuffer(bool wait) = 0; + virtual bool flushBuffer() = 0; virtual void overrun(size_t needed); diff --git a/common/rdr/Exception.h b/common/rdr/Exception.h index eb3c8a9d..e5bff80d 100644 --- a/common/rdr/Exception.h +++ b/common/rdr/Exception.h @@ -47,10 +47,6 @@ namespace rdr { GAIException(const char* s, int err_); }; - struct TimedOut : public Exception { - TimedOut() : Exception("Timed out") {} - }; - struct EndOfStream : public Exception { EndOfStream() : Exception("End of stream") {} }; diff --git a/common/rdr/FdInStream.cxx b/common/rdr/FdInStream.cxx index 27de92bb..ecc34ecd 100644 --- a/common/rdr/FdInStream.cxx +++ b/common/rdr/FdInStream.cxx @@ -46,17 +46,8 @@ using namespace rdr; -enum { DEFAULT_BUF_SIZE = 8192 }; - -FdInStream::FdInStream(int fd_, int timeoutms_, - bool closeWhenDone_) - : fd(fd_), closeWhenDone(closeWhenDone_), - timeoutms(timeoutms_), blockCallback(0) -{ -} - -FdInStream::FdInStream(int fd_, FdInStreamBlockCallback* blockCallback_) - : fd(fd_), timeoutms(0), blockCallback(blockCallback_) +FdInStream::FdInStream(int fd_, bool closeWhenDone_) + : fd(fd_), closeWhenDone(closeWhenDone_) { } @@ -66,20 +57,9 @@ FdInStream::~FdInStream() } -void FdInStream::setTimeout(int timeoutms_) { - timeoutms = timeoutms_; -} - -void FdInStream::setBlockCallback(FdInStreamBlockCallback* blockCallback_) -{ - blockCallback = blockCallback_; - timeoutms = 0; -} - - -bool FdInStream::fillBuffer(size_t maxSize, bool wait) +bool FdInStream::fillBuffer(size_t maxSize) { - size_t n = readWithTimeoutOrCallback((U8*)end, maxSize, wait); + size_t n = readFd((U8*)end, maxSize); if (n == 0) return false; end += n; @@ -88,55 +68,43 @@ bool FdInStream::fillBuffer(size_t maxSize, bool wait) } // -// readWithTimeoutOrCallback() reads up to the given length in bytes from the -// file descriptor into a buffer. If the wait argument is false, then zero is -// returned if no bytes can be read without blocking. Otherwise if a -// blockCallback is set, it will be called (repeatedly) instead of blocking. -// If alternatively there is a timeout set and that timeout expires, it throws -// a TimedOut exception. Otherwise it returns the number of bytes read. It +// readFd() reads up to the given length in bytes from the +// file descriptor into a buffer. Zero is +// returned if no bytes can be read. Otherwise it returns the number of bytes read. It // never attempts to recv() unless select() indicates that the fd is readable - // this means it can be used on an fd which has been set non-blocking. It also // has to cope with the annoying possibility of both select() and recv() // returning EINTR. // -size_t FdInStream::readWithTimeoutOrCallback(void* buf, size_t len, bool wait) +size_t FdInStream::readFd(void* buf, size_t len) { int n; - while (true) { - do { - fd_set fds; - struct timeval tv; - struct timeval* tvp = &tv; - - if (!wait) { - tv.tv_sec = tv.tv_usec = 0; - } else if (timeoutms != -1) { - tv.tv_sec = timeoutms / 1000; - tv.tv_usec = (timeoutms % 1000) * 1000; - } else { - tvp = 0; - } - - FD_ZERO(&fds); - FD_SET(fd, &fds); - n = select(fd+1, &fds, 0, 0, tvp); - } while (n < 0 && errno == EINTR); - - if (n > 0) break; - if (n < 0) throw SystemException("select",errno); - if (!wait) return 0; - if (!blockCallback) throw TimedOut(); - - blockCallback->blockCallback(); - } + do { + fd_set fds; + struct timeval tv; + + tv.tv_sec = tv.tv_usec = 0; + + FD_ZERO(&fds); + FD_SET(fd, &fds); + n = select(fd+1, &fds, 0, 0, &tv); + } while (n < 0 && errno == EINTR); + + if (n < 0) + throw SystemException("select",errno); + + if (n == 0) + return 0; do { n = ::recv(fd, (char*)buf, len, 0); } while (n < 0 && errno == EINTR); - if (n < 0) throw SystemException("read",errno); - if (n == 0) throw EndOfStream(); + if (n < 0) + throw SystemException("read",errno); + if (n == 0) + throw EndOfStream(); return n; } diff --git a/common/rdr/FdInStream.h b/common/rdr/FdInStream.h index 0203389b..f732ceaa 100644 --- a/common/rdr/FdInStream.h +++ b/common/rdr/FdInStream.h @@ -27,33 +27,22 @@ namespace rdr { - class FdInStreamBlockCallback { - public: - virtual void blockCallback() = 0; - virtual ~FdInStreamBlockCallback() {} - }; - class FdInStream : public BufferedInStream { public: - FdInStream(int fd, int timeoutms=-1, bool closeWhenDone_=false); - FdInStream(int fd, FdInStreamBlockCallback* blockCallback); + FdInStream(int fd, bool closeWhenDone_=false); virtual ~FdInStream(); - void setTimeout(int timeoutms); - void setBlockCallback(FdInStreamBlockCallback* blockCallback); int getFd() { return fd; } private: - virtual bool fillBuffer(size_t maxSize, bool wait); + virtual bool fillBuffer(size_t maxSize); - size_t readWithTimeoutOrCallback(void* buf, size_t len, bool wait=true); + size_t readFd(void* buf, size_t len); int fd; bool closeWhenDone; - int timeoutms; - FdInStreamBlockCallback* blockCallback; size_t offset; U8* start; diff --git a/common/rdr/FdOutStream.cxx b/common/rdr/FdOutStream.cxx index 3405838d..b52fc85d 100644 --- a/common/rdr/FdOutStream.cxx +++ b/common/rdr/FdOutStream.cxx @@ -49,27 +49,14 @@ using namespace rdr; -FdOutStream::FdOutStream(int fd_, bool blocking_, int timeoutms_) - : fd(fd_), blocking(blocking_), timeoutms(timeoutms_) +FdOutStream::FdOutStream(int fd_) + : fd(fd_) { gettimeofday(&lastWrite, NULL); } FdOutStream::~FdOutStream() { - try { - while (sentUpTo != ptr) - flushBuffer(true); - } catch (Exception&) { - } -} - -void FdOutStream::setTimeout(int timeoutms_) { - timeoutms = timeoutms_; -} - -void FdOutStream::setBlocking(bool blocking_) { - blocking = blocking_; } unsigned FdOutStream::getIdleTime() @@ -87,20 +74,11 @@ void FdOutStream::cork(bool enable) #endif } -bool FdOutStream::flushBuffer(bool wait) +bool FdOutStream::flushBuffer() { - size_t n = writeWithTimeout((const void*) sentUpTo, - ptr - sentUpTo, - (blocking || wait)? timeoutms : 0); - - // Timeout? - if (n == 0) { - // If non-blocking then we're done here - if (!blocking && !wait) - return false; - - throw TimedOut(); - } + size_t n = writeFd((const void*) sentUpTo, ptr - sentUpTo); + if (n == 0) + return false; sentUpTo += n; @@ -108,34 +86,27 @@ bool FdOutStream::flushBuffer(bool wait) } // -// writeWithTimeout() writes up to the given length in bytes from the given -// buffer to the file descriptor. If there is a timeout set and that timeout -// expires, it throws a TimedOut exception. Otherwise it returns the number of -// bytes written. It never attempts to send() unless select() indicates that -// the fd is writable - this means it can be used on an fd which has been set -// non-blocking. It also has to cope with the annoying possibility of both -// select() and send() returning EINTR. +// writeFd() writes up to the given length in bytes from the given +// buffer to the file descriptor. It returns the number of bytes written. It +// never attempts to send() unless select() indicates that the fd is writable +// - this means it can be used on an fd which has been set non-blocking. It +// also has to cope with the annoying possibility of both select() and send() +// returning EINTR. // -size_t FdOutStream::writeWithTimeout(const void* data, size_t length, int timeoutms) +size_t FdOutStream::writeFd(const void* data, size_t length) { int n; do { fd_set fds; struct timeval tv; - struct timeval* tvp = &tv; - if (timeoutms != -1) { - tv.tv_sec = timeoutms / 1000; - tv.tv_usec = (timeoutms % 1000) * 1000; - } else { - tvp = NULL; - } + tv.tv_sec = tv.tv_usec = 0; FD_ZERO(&fds); FD_SET(fd, &fds); - n = select(fd+1, 0, &fds, 0, tvp); + n = select(fd+1, 0, &fds, 0, &tv); } while (n < 0 && errno == EINTR); if (n < 0) diff --git a/common/rdr/FdOutStream.h b/common/rdr/FdOutStream.h index b1fb74c0..80804da4 100644 --- a/common/rdr/FdOutStream.h +++ b/common/rdr/FdOutStream.h @@ -34,11 +34,9 @@ namespace rdr { public: - FdOutStream(int fd, bool blocking=true, int timeoutms=-1); + FdOutStream(int fd); virtual ~FdOutStream(); - void setTimeout(int timeoutms); - void setBlocking(bool blocking); int getFd() { return fd; } unsigned getIdleTime(); @@ -46,11 +44,9 @@ namespace rdr { virtual void cork(bool enable); private: - virtual bool flushBuffer(bool wait); - size_t writeWithTimeout(const void* data, size_t length, int timeoutms); + virtual bool flushBuffer(); + size_t writeFd(const void* data, size_t length); int fd; - bool blocking; - int timeoutms; struct timeval lastWrite; }; diff --git a/common/rdr/FileInStream.cxx b/common/rdr/FileInStream.cxx index 66dfe766..9975fde6 100644 --- a/common/rdr/FileInStream.cxx +++ b/common/rdr/FileInStream.cxx @@ -39,7 +39,7 @@ FileInStream::~FileInStream(void) { } } -bool FileInStream::fillBuffer(size_t maxSize, bool wait) +bool FileInStream::fillBuffer(size_t maxSize) { size_t n = fread((U8 *)end, 1, maxSize, file); if (n == 0) { diff --git a/common/rdr/FileInStream.h b/common/rdr/FileInStream.h index 268f5375..619397f0 100644 --- a/common/rdr/FileInStream.h +++ b/common/rdr/FileInStream.h @@ -34,7 +34,7 @@ namespace rdr { ~FileInStream(void); private: - virtual bool fillBuffer(size_t maxSize, bool wait); + virtual bool fillBuffer(size_t maxSize); private: FILE *file; diff --git a/common/rdr/HexInStream.cxx b/common/rdr/HexInStream.cxx index 322432c0..66bbf174 100644 --- a/common/rdr/HexInStream.cxx +++ b/common/rdr/HexInStream.cxx @@ -73,7 +73,7 @@ decodeError: bool HexInStream::fillBuffer(size_t maxSize, bool wait) { - if (!in_stream.check(2, wait)) + if (!in_stream.hasData(2)) return false; size_t length = min(in_stream.avail()/2, maxSize); diff --git a/common/rdr/InStream.h b/common/rdr/InStream.h index 5d873011..60ea4997 100644 --- a/common/rdr/InStream.h +++ b/common/rdr/InStream.h @@ -1,4 +1,5 @@ /* Copyright (C) 2002-2005 RealVNC Ltd. All Rights Reserved. + * Copyright 2014-2020 Pierre Ossman for Cendio AB * * This is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by @@ -28,6 +29,10 @@ #include <rdr/Exception.h> #include <string.h> // for memcpy +// Check that callers are using InStream properly, +// useful when writing new protocol handling +#undef RFB_INSTREAM_CHECK + namespace rdr { class InStream { @@ -39,29 +44,79 @@ namespace rdr { // avail() returns the number of bytes that are currenctly directly // available from the stream. - inline size_t avail() - { + inline size_t avail() { +#ifdef RFB_INSTREAM_CHECK + checkedBytes = end - ptr; +#endif + return end - ptr; } - // check() ensures there is buffer data for at least needed bytes. Returns - // true once the data is available. If wait is false, then instead of - // blocking to wait for the bytes, false is returned if the bytes are not - // immediately available. + // hasData() ensures there is at least "length" bytes of buffer data, + // possibly trying to fetch more data if there isn't enough right away + + inline bool hasData(size_t length) { +#ifdef RFB_INSTREAM_CHECK + checkedBytes = 0; +#endif + + if (length > (size_t)(end - ptr)) { + if (restorePoint != NULL) { + bool ret; + size_t restoreDiff; + + restoreDiff = ptr - restorePoint; + ptr = restorePoint; + + ret = overrun(length + restoreDiff); - inline size_t check(size_t needed, bool wait=true) - { - if (needed > avail()) - return overrun(needed, wait); + restorePoint = ptr; + ptr += restoreDiff; + + if (!ret) + return false; + } else { + if (!overrun(length)) + return false; + } + } + +#ifdef RFB_INSTREAM_CHECK + checkedBytes = length; +#endif return true; } - // checkNoWait() tries to make sure that the given number of bytes can - // be read without blocking. It returns true if this is the case, false - // otherwise. The length must be "small" (less than the buffer size). + inline bool hasDataOrRestore(size_t length) { + if (hasData(length)) + return true; + gotoRestorePoint(); + return false; + } - inline bool checkNoWait(size_t length) { return check(length, false); } + inline void setRestorePoint() { +#ifdef RFB_INSTREAM_CHECK + if (restorePoint != NULL) + throw Exception("Nested use of input stream restore point"); +#endif + restorePoint = ptr; + } + inline void clearRestorePoint() { +#ifdef RFB_INSTREAM_CHECK + if (restorePoint == NULL) + throw Exception("Incorrect clearing of input stream restore point"); +#endif + restorePoint = NULL; + } + inline void gotoRestorePoint() { +#ifdef RFB_INSTREAM_CHECK + if (restorePoint == NULL) + throw Exception("Incorrect activation of input stream restore point"); +#endif + ptr = restorePoint; + clearRestorePoint(); + } // readU/SN() methods read unsigned and signed N-bit integers. @@ -76,24 +131,19 @@ namespace rdr { inline S16 readS16() { return (S16)readU16(); } inline S32 readS32() { return (S32)readU32(); } + // skip() ignores a number of bytes on the stream + inline void skip(size_t bytes) { - while (bytes > 0) { - size_t n = check(1, bytes); - ptr += n; - bytes -= n; - } + check(bytes); + ptr += bytes; } // readBytes() reads an exact number of bytes. void readBytes(void* data, size_t length) { - while (length > 0) { - size_t n = check(1, length); - memcpy(data, ptr, n); - ptr += n; - data = (U8*)data + n; - length -= n; - } + check(length); + memcpy(data, ptr, length); + ptr += length; } // readOpaqueN() reads a quantity without byte-swapping. @@ -113,24 +163,45 @@ namespace rdr { // to the buffer. This is useful for a stream which is a wrapper around an // some other stream API. - inline const U8* getptr(size_t length) { check(length); return ptr; } + inline const U8* getptr(size_t length) { check(length); +#ifdef RFB_INSTREAM_CHECK + checkedBytes += length; +#endif + return ptr; } inline void setptr(size_t length) { if (length > avail()) throw Exception("Input stream overflow"); skip(length); } private: + const U8* restorePoint; +#ifdef RFB_INSTREAM_CHECK + size_t checkedBytes; +#endif + + inline void check(size_t bytes) { +#ifdef RFB_INSTREAM_CHECK + if (bytes > checkedBytes) + throw Exception("Input stream used without underrun check"); + checkedBytes -= bytes; +#endif + if (bytes > (size_t)(end - ptr)) + throw Exception("InStream buffer underrun"); + } + // overrun() is implemented by a derived class to cope with buffer overrun. - // It ensures there are at least needed bytes of buffer data. Returns true - // once the data is available. If wait is false, then instead of blocking - // to wait for the bytes, false is returned if the bytes are not - // immediately available. + // It tries to ensure there are at least needed bytes of buffer data. + // Returns true if it managed to satisfy the request, or false otherwise. - virtual bool overrun(size_t needed, bool wait=true) = 0; + virtual bool overrun(size_t needed) = 0; protected: - InStream() {} + InStream() : restorePoint(NULL) +#ifdef RFB_INSTREAM_CHECK + ,checkedBytes(0) +#endif + {} const U8* ptr; const U8* end; }; diff --git a/common/rdr/MemInStream.h b/common/rdr/MemInStream.h index 83740dd9..a5196594 100644 --- a/common/rdr/MemInStream.h +++ b/common/rdr/MemInStream.h @@ -41,6 +41,12 @@ namespace rdr { { ptr = start; end = start + len; + +#ifdef RFB_INSTREAM_CHECK + // MemInStream cannot add more data, so callers are assumed to already + // new the total size + avail(); +#endif } virtual ~MemInStream() { @@ -53,7 +59,7 @@ namespace rdr { private: - bool overrun(size_t needed, bool wait) { throw EndOfStream(); } + bool overrun(size_t needed) { throw EndOfStream(); } const U8* start; bool deleteWhenDone; }; diff --git a/common/rdr/OutStream.h b/common/rdr/OutStream.h index f432520f..61d5100b 100644 --- a/common/rdr/OutStream.h +++ b/common/rdr/OutStream.h @@ -49,14 +49,6 @@ namespace rdr { return end - ptr; } - // check() ensures there is buffer space for at least needed bytes. - - inline void check(size_t needed) - { - if (needed > avail()) - overrun(needed); - } - // writeU/SN() methods write unsigned and signed N-bit integers. inline void writeU8( U8 u) { check(1); *ptr++ = u; } @@ -136,6 +128,12 @@ namespace rdr { private: + inline void check(size_t length) + { + if (length > avail()) + overrun(length); + } + // overrun() is implemented by a derived class to cope with buffer overrun. // It ensures there are at least needed bytes of buffer space. diff --git a/common/rdr/RandomStream.cxx b/common/rdr/RandomStream.cxx index 6333be3f..e2da0957 100644 --- a/common/rdr/RandomStream.cxx +++ b/common/rdr/RandomStream.cxx @@ -79,7 +79,7 @@ RandomStream::~RandomStream() { #endif } -bool RandomStream::fillBuffer(size_t maxSize, bool wait) { +bool RandomStream::fillBuffer(size_t maxSize) { #ifdef RFB_HAVE_WINCRYPT if (provider) { if (!CryptGenRandom(provider, maxSize, (U8*)end)) diff --git a/common/rdr/RandomStream.h b/common/rdr/RandomStream.h index 08ae0ff6..58986433 100644 --- a/common/rdr/RandomStream.h +++ b/common/rdr/RandomStream.h @@ -40,7 +40,7 @@ namespace rdr { virtual ~RandomStream(); private: - virtual bool fillBuffer(size_t maxSize, bool wait); + virtual bool fillBuffer(size_t maxSize); private: static unsigned int seed; diff --git a/common/rdr/TLSInStream.cxx b/common/rdr/TLSInStream.cxx index ba20e752..2339956d 100644 --- a/common/rdr/TLSInStream.cxx +++ b/common/rdr/TLSInStream.cxx @@ -39,7 +39,7 @@ ssize_t TLSInStream::pull(gnutls_transport_ptr_t str, void* data, size_t size) InStream *in = self->in; try { - if (!in->check(1, false)) { + if (!in->hasData(1)) { gnutls_transport_set_errno(self->session, EAGAIN); return -1; } @@ -74,23 +74,22 @@ TLSInStream::~TLSInStream() gnutls_transport_set_pull_function(session, NULL); } -bool TLSInStream::fillBuffer(size_t maxSize, bool wait) +bool TLSInStream::fillBuffer(size_t maxSize) { - size_t n = readTLS((U8*) end, maxSize, wait); - if (!wait && n == 0) + size_t n = readTLS((U8*) end, maxSize); + if (n == 0) return false; end += n; return true; } -size_t TLSInStream::readTLS(U8* buf, size_t len, bool wait) +size_t TLSInStream::readTLS(U8* buf, size_t len) { int n; if (gnutls_record_check_pending(session) == 0) { - n = in->check(1, wait); - if (n == 0) + if (!in->hasData(1)) return 0; } diff --git a/common/rdr/TLSInStream.h b/common/rdr/TLSInStream.h index 9779c68e..df5ebb48 100644 --- a/common/rdr/TLSInStream.h +++ b/common/rdr/TLSInStream.h @@ -37,8 +37,8 @@ namespace rdr { virtual ~TLSInStream(); private: - virtual bool fillBuffer(size_t maxSize, bool wait); - size_t readTLS(U8* buf, size_t len, bool wait); + virtual bool fillBuffer(size_t maxSize); + size_t readTLS(U8* buf, size_t len); static ssize_t pull(gnutls_transport_ptr_t str, void* data, size_t size); gnutls_session_t session; diff --git a/common/rdr/ZlibInStream.cxx b/common/rdr/ZlibInStream.cxx index 26977228..0cacc21f 100644 --- a/common/rdr/ZlibInStream.cxx +++ b/common/rdr/ZlibInStream.cxx @@ -45,7 +45,7 @@ void ZlibInStream::setUnderlying(InStream* is, size_t bytesIn_) void ZlibInStream::flushUnderlying() { while (bytesIn > 0) { - if (!check(1)) + if (!hasData(1)) throw Exception("ZlibInStream: failed to flush remaining stream data"); skip(avail()); } @@ -85,7 +85,7 @@ void ZlibInStream::deinit() zs = NULL; } -bool ZlibInStream::fillBuffer(size_t maxSize, bool wait) +bool ZlibInStream::fillBuffer(size_t maxSize) { if (!underlying) throw Exception("ZlibInStream overrun: no underlying stream"); @@ -93,8 +93,8 @@ bool ZlibInStream::fillBuffer(size_t maxSize, bool wait) zs->next_out = (U8*)end; zs->avail_out = maxSize; - size_t n = underlying->check(1, wait); - if (n == 0) return false; + if (!underlying->hasData(1)) + return false; size_t length = underlying->avail(); if (length > bytesIn) length = bytesIn; diff --git a/common/rdr/ZlibInStream.h b/common/rdr/ZlibInStream.h index 1597b54a..302c99d8 100644 --- a/common/rdr/ZlibInStream.h +++ b/common/rdr/ZlibInStream.h @@ -44,7 +44,7 @@ namespace rdr { void init(); void deinit(); - virtual bool fillBuffer(size_t maxSize, bool wait); + virtual bool fillBuffer(size_t maxSize); private: InStream* underlying; diff --git a/common/rfb/CConnection.cxx b/common/rfb/CConnection.cxx index cb1b84bd..9c957a54 100644 --- a/common/rfb/CConnection.cxx +++ b/common/rfb/CConnection.cxx @@ -120,16 +120,19 @@ void CConnection::initialiseProtocol() state_ = RFBSTATE_PROTOCOL_VERSION; } -void CConnection::processMsg() +bool CConnection::processMsg() { switch (state_) { - case RFBSTATE_PROTOCOL_VERSION: processVersionMsg(); break; - case RFBSTATE_SECURITY_TYPES: processSecurityTypesMsg(); break; - case RFBSTATE_SECURITY: processSecurityMsg(); break; - case RFBSTATE_SECURITY_RESULT: processSecurityResultMsg(); break; - case RFBSTATE_INITIALISATION: processInitMsg(); break; - case RFBSTATE_NORMAL: reader_->readMsg(); break; + case RFBSTATE_PROTOCOL_VERSION: return processVersionMsg(); break; + case RFBSTATE_SECURITY_TYPES: return processSecurityTypesMsg(); break; + case RFBSTATE_SECURITY: return processSecurityMsg(); break; + case RFBSTATE_SECURITY_RESULT: return processSecurityResultMsg(); break; + case RFBSTATE_SECURITY_REASON: return processSecurityReasonMsg(); break; + case RFBSTATE_INITIALISATION: return processInitMsg(); break; + case RFBSTATE_NORMAL: return reader_->readMsg(); break; + case RFBSTATE_CLOSING: + throw Exception("CConnection::processMsg: called while closing"); case RFBSTATE_UNINITIALISED: throw Exception("CConnection::processMsg: not initialised yet?"); default: @@ -137,7 +140,7 @@ void CConnection::processMsg() } } -void CConnection::processVersionMsg() +bool CConnection::processVersionMsg() { char verStr[27]; // FIXME: gcc has some bug in format-overflow int majorVersion; @@ -145,8 +148,8 @@ void CConnection::processVersionMsg() vlog.debug("reading protocol version"); - if (!is->checkNoWait(12)) - return; + if (!is->hasData(12)) + return false; is->readBytes(verStr, 12); verStr[12] = '\0'; @@ -184,10 +187,12 @@ void CConnection::processVersionMsg() vlog.info("Using RFB protocol version %d.%d", server.majorVersion, server.minorVersion); + + return true; } -void CConnection::processSecurityTypesMsg() +bool CConnection::processSecurityTypesMsg() { vlog.debug("processing security types message"); @@ -200,10 +205,13 @@ void CConnection::processSecurityTypesMsg() // legacy 3.3 server may only offer "vnc authentication" or "none" + if (!is->hasData(4)) + return false; + secType = is->readU32(); if (secType == secTypeInvalid) { - throwConnFailedException(); - + state_ = RFBSTATE_SECURITY_REASON; + return true; } else if (secType == secTypeNone || secType == secTypeVncAuth) { std::list<rdr::U8>::iterator i; for (i = secTypes.begin(); i != secTypes.end(); i++) @@ -223,9 +231,21 @@ void CConnection::processSecurityTypesMsg() // >=3.7 server will offer us a list + if (!is->hasData(1)) + return false; + + is->setRestorePoint(); + int nServerSecTypes = is->readU8(); - if (nServerSecTypes == 0) - throwConnFailedException(); + + if (!is->hasDataOrRestore(nServerSecTypes)) + return false; + is->clearRestorePoint(); + + if (nServerSecTypes == 0) { + state_ = RFBSTATE_SECURITY_REASON; + return true; + } std::list<rdr::U8>::iterator j; @@ -263,32 +283,38 @@ void CConnection::processSecurityTypesMsg() state_ = RFBSTATE_SECURITY; csecurity = security.GetCSecurity(this, secType); - processSecurityMsg(); + + return true; } -void CConnection::processSecurityMsg() +bool CConnection::processSecurityMsg() { vlog.debug("processing security message"); - if (csecurity->processMsg()) { - state_ = RFBSTATE_SECURITY_RESULT; - processSecurityResultMsg(); - } + if (!csecurity->processMsg()) + return false; + + state_ = RFBSTATE_SECURITY_RESULT; + + return true; } -void CConnection::processSecurityResultMsg() +bool CConnection::processSecurityResultMsg() { vlog.debug("processing security result message"); int result; + if (server.beforeVersion(3,8) && csecurity->getType() == secTypeNone) { result = secResultOK; } else { - if (!is->checkNoWait(1)) return; + if (!is->hasData(4)) + return false; result = is->readU32(); } + switch (result) { case secResultOK: securityCompleted(); - return; + return true; case secResultFailed: vlog.debug("auth failed"); break; @@ -298,30 +324,42 @@ void CConnection::processSecurityResultMsg() default: throw Exception("Unknown security result from server"); } - state_ = RFBSTATE_INVALID; - if (server.beforeVersion(3,8)) + + if (server.beforeVersion(3,8)) { + state_ = RFBSTATE_INVALID; throw AuthFailureException(); + } + + state_ = RFBSTATE_SECURITY_REASON; + return true; +} + +bool CConnection::processSecurityReasonMsg() +{ + vlog.debug("processing security reason message"); + + if (!is->hasData(4)) + return false; + + is->setRestorePoint(); + rdr::U32 len = is->readU32(); + if (!is->hasDataOrRestore(len)) + return false; + is->clearRestorePoint(); + CharArray reason(len + 1); is->readBytes(reason.buf, len); reason.buf[len] = '\0'; + + state_ = RFBSTATE_INVALID; throw AuthFailureException(reason.buf); } -void CConnection::processInitMsg() +bool CConnection::processInitMsg() { vlog.debug("reading server initialisation"); - reader_->readServerInit(); -} - -void CConnection::throwConnFailedException() -{ - state_ = RFBSTATE_INVALID; - rdr::U32 len = is->readU32(); - CharArray reason(len + 1); - is->readBytes(reason.buf, len); - reason.buf[len] = '\0'; - throw ConnFailedException(reason.buf); + return reader_->readServerInit(); } void CConnection::securityCompleted() @@ -429,11 +467,13 @@ void CConnection::serverInit(int width, int height, } } -void CConnection::readAndDecodeRect(const Rect& r, int encoding, +bool CConnection::readAndDecodeRect(const Rect& r, int encoding, ModifiablePixelBuffer* pb) { - decoder.decodeRect(r, encoding, pb); + if (!decoder.decodeRect(r, encoding, pb)) + return false; decoder.flush(); + return true; } void CConnection::framebufferUpdateStart() @@ -474,9 +514,9 @@ void CConnection::framebufferUpdateEnd() } } -void CConnection::dataRect(const Rect& r, int encoding) +bool CConnection::dataRect(const Rect& r, int encoding) { - decoder.decodeRect(r, encoding, framebuffer); + return decoder.decodeRect(r, encoding, framebuffer); } void CConnection::serverCutText(const char* str) diff --git a/common/rfb/CConnection.h b/common/rfb/CConnection.h index 3857be4d..3c2b04e8 100644 --- a/common/rfb/CConnection.h +++ b/common/rfb/CConnection.h @@ -84,7 +84,7 @@ namespace rfb { // In this case, processMsg should always process the available RFB // message before returning. // NB: In either case, you must have called initialiseProtocol() first. - void processMsg(); + bool processMsg(); // close() gracefully shuts down the connection to the server and // should be called before terminating the underlying network @@ -107,12 +107,12 @@ namespace rfb { const PixelFormat& pf, const char* name); - virtual void readAndDecodeRect(const Rect& r, int encoding, + virtual bool readAndDecodeRect(const Rect& r, int encoding, ModifiablePixelBuffer* pb); virtual void framebufferUpdateStart(); virtual void framebufferUpdateEnd(); - virtual void dataRect(const Rect& r, int encoding); + virtual bool dataRect(const Rect& r, int encoding); virtual void serverCutText(const char* str); @@ -216,6 +216,7 @@ namespace rfb { RFBSTATE_SECURITY_TYPES, RFBSTATE_SECURITY, RFBSTATE_SECURITY_RESULT, + RFBSTATE_SECURITY_REASON, RFBSTATE_INITIALISATION, RFBSTATE_NORMAL, RFBSTATE_CLOSING, @@ -249,13 +250,13 @@ namespace rfb { virtual void fence(rdr::U32 flags, unsigned len, const char data[]); private: - void processVersionMsg(); - void processSecurityTypesMsg(); - void processSecurityMsg(); - void processSecurityResultMsg(); - void processInitMsg(); + bool processVersionMsg(); + bool processSecurityTypesMsg(); + bool processSecurityMsg(); + bool processSecurityResultMsg(); + bool processSecurityReasonMsg(); + bool processInitMsg(); void throwAuthFailureException(); - void throwConnFailedException(); void securityCompleted(); void requestNewUpdate(); diff --git a/common/rfb/CMsgHandler.h b/common/rfb/CMsgHandler.h index 84dd115c..5b14806a 100644 --- a/common/rfb/CMsgHandler.h +++ b/common/rfb/CMsgHandler.h @@ -61,12 +61,12 @@ namespace rfb { const PixelFormat& pf, const char* name) = 0; - virtual void readAndDecodeRect(const Rect& r, int encoding, + virtual bool readAndDecodeRect(const Rect& r, int encoding, ModifiablePixelBuffer* pb) = 0; virtual void framebufferUpdateStart(); virtual void framebufferUpdateEnd(); - virtual void dataRect(const Rect& r, int encoding) = 0; + virtual bool dataRect(const Rect& r, int encoding) = 0; virtual void setColourMapEntries(int firstColour, int nColours, rdr::U16* rgbs) = 0; diff --git a/common/rfb/CMsgReader.cxx b/common/rfb/CMsgReader.cxx index a015ec99..40fb5912 100644 --- a/common/rfb/CMsgReader.cxx +++ b/common/rfb/CMsgReader.cxx @@ -39,7 +39,7 @@ using namespace rfb; CMsgReader::CMsgReader(CMsgHandler* handler_, rdr::InStream* is_) : imageBufIdealSize(0), handler(handler_), is(is_), - nUpdateRectsLeft(0) + state(MSGSTATE_IDLE), cursorEncoding(-1) { } @@ -47,149 +47,246 @@ CMsgReader::~CMsgReader() { } -void CMsgReader::readServerInit() +bool CMsgReader::readServerInit() { - int width = is->readU16(); - int height = is->readU16(); + int width, height; + rdr::U32 len; + + if (!is->hasData(2 + 2 + 16 + 4)) + return false; + + is->setRestorePoint(); + + width = is->readU16(); + height = is->readU16(); + PixelFormat pf; pf.read(is); - rdr::U32 len = is->readU32(); + + len = is->readU32(); + if (!is->hasDataOrRestore(len)) + return false; + is->clearRestorePoint(); CharArray name(len + 1); is->readBytes(name.buf, len); name.buf[len] = '\0'; handler->serverInit(width, height, pf, name.buf); + + return true; } -void CMsgReader::readMsg() +bool CMsgReader::readMsg() { - if (nUpdateRectsLeft == 0) { - int type = is->readU8(); + if (state == MSGSTATE_IDLE) { + if (!is->hasData(1)) + return false; + + currentMsgType = is->readU8(); + state = MSGSTATE_MESSAGE; + } + + if (currentMsgType != msgTypeFramebufferUpdate) { + bool ret; - switch (type) { + switch (currentMsgType) { case msgTypeSetColourMapEntries: - readSetColourMapEntries(); + ret = readSetColourMapEntries(); break; case msgTypeBell: - readBell(); + ret = readBell(); break; case msgTypeServerCutText: - readServerCutText(); + ret = readServerCutText(); break; case msgTypeFramebufferUpdate: - readFramebufferUpdate(); + ret = readFramebufferUpdate(); break; case msgTypeServerFence: - readFence(); + ret = readFence(); break; case msgTypeEndOfContinuousUpdates: - readEndOfContinuousUpdates(); + ret = readEndOfContinuousUpdates(); break; default: - vlog.error("unknown message type %d", type); - throw Exception("unknown message type"); + throw Exception("Unknown message type %d", currentMsgType); } + + if (ret) + state = MSGSTATE_IDLE; + + return ret; } else { - int x = is->readU16(); - int y = is->readU16(); - int w = is->readU16(); - int h = is->readU16(); - int encoding = is->readS32(); + if (state == MSGSTATE_MESSAGE) { + if (!readFramebufferUpdate()) + return false; + + // Empty update? + if (nUpdateRectsLeft == 0) { + state = MSGSTATE_IDLE; + handler->framebufferUpdateEnd(); + return true; + } + + state = MSGSTATE_RECT_HEADER; + } + + if (state == MSGSTATE_RECT_HEADER) { + if (!is->hasData(12)) + return false; + + int x = is->readU16(); + int y = is->readU16(); + int w = is->readU16(); + int h = is->readU16(); + + dataRect.setXYWH(x, y, w, h); + + rectEncoding = is->readS32(); + + state = MSGSTATE_RECT_DATA; + } + + bool ret; - switch (encoding) { + switch (rectEncoding) { case pseudoEncodingLastRect: nUpdateRectsLeft = 1; // this rectangle is the last one + ret = true; break; case pseudoEncodingXCursor: - readSetXCursor(w, h, Point(x,y)); + ret = readSetXCursor(dataRect.width(), dataRect.height(), dataRect.tl); break; case pseudoEncodingCursor: - readSetCursor(w, h, Point(x,y)); + ret = readSetCursor(dataRect.width(), dataRect.height(), dataRect.tl); break; case pseudoEncodingCursorWithAlpha: - readSetCursorWithAlpha(w, h, Point(x,y)); + ret = readSetCursorWithAlpha(dataRect.width(), dataRect.height(), dataRect.tl); break; case pseudoEncodingVMwareCursor: - readSetVMwareCursor(w, h, Point(x,y)); + ret = readSetVMwareCursor(dataRect.width(), dataRect.height(), dataRect.tl); break; case pseudoEncodingDesktopName: - readSetDesktopName(x, y, w, h); + ret = readSetDesktopName(dataRect.tl.x, dataRect.tl.y, + dataRect.width(), dataRect.height()); break; case pseudoEncodingDesktopSize: - handler->setDesktopSize(w, h); + handler->setDesktopSize(dataRect.width(), dataRect.height()); + ret = true; break; case pseudoEncodingExtendedDesktopSize: - readExtendedDesktopSize(x, y, w, h); + ret = readExtendedDesktopSize(dataRect.tl.x, dataRect.tl.y, + dataRect.width(), dataRect.height()); break; case pseudoEncodingLEDState: - readLEDState(); + ret = readLEDState(); break; case pseudoEncodingVMwareLEDState: - readVMwareLEDState(); + ret = readVMwareLEDState(); break; case pseudoEncodingQEMUKeyEvent: handler->supportsQEMUKeyEvent(); + ret = true; break; default: - readRect(Rect(x, y, x+w, y+h), encoding); + ret = readRect(dataRect, rectEncoding); break; }; - nUpdateRectsLeft--; - if (nUpdateRectsLeft == 0) - handler->framebufferUpdateEnd(); + if (ret) { + state = MSGSTATE_RECT_HEADER; + nUpdateRectsLeft--; + if (nUpdateRectsLeft == 0) { + state = MSGSTATE_IDLE; + handler->framebufferUpdateEnd(); + } + } + + return ret; } } -void CMsgReader::readSetColourMapEntries() +bool CMsgReader::readSetColourMapEntries() { + if (!is->hasData(5)) + return false; + + is->setRestorePoint(); + is->skip(1); int firstColour = is->readU16(); int nColours = is->readU16(); + + if (!is->hasDataOrRestore(nColours * 3 * 2)) + return false; + is->clearRestorePoint(); + rdr::U16Array rgbs(nColours * 3); for (int i = 0; i < nColours * 3; i++) rgbs.buf[i] = is->readU16(); handler->setColourMapEntries(firstColour, nColours, rgbs.buf); + + return true; } -void CMsgReader::readBell() +bool CMsgReader::readBell() { handler->bell(); + return true; } -void CMsgReader::readServerCutText() +bool CMsgReader::readServerCutText() { + if (!is->hasData(7)) + return false; + + is->setRestorePoint(); + is->skip(3); rdr::U32 len = is->readU32(); if (len & 0x80000000) { rdr::S32 slen = len; slen = -slen; - readExtendedClipboard(slen); - return; + if (readExtendedClipboard(slen)) { + is->clearRestorePoint(); + return true; + } else { + is->gotoRestorePoint(); + return false; + } } + if (!is->hasDataOrRestore(len)) + return false; + is->clearRestorePoint(); + if (len > (size_t)maxCutText) { is->skip(len); vlog.error("cut text too long (%d bytes) - ignoring",len); - return; + return true; } CharArray ca(len); is->readBytes(ca.buf, len); CharArray filtered(convertLF(ca.buf, len)); handler->serverCutText(filtered.buf); + + return true; } -void CMsgReader::readExtendedClipboard(rdr::S32 len) +bool CMsgReader::readExtendedClipboard(rdr::S32 len) { rdr::U32 flags; rdr::U32 action; + if (!is->hasData(len)) + return false; + if (len < 4) throw Exception("Invalid extended clipboard message"); if (len > maxCutText) { vlog.error("Extended clipboard message too long (%d bytes) - ignoring", len); is->skip(len); - return; + return true; } flags = is->readU32(); @@ -231,7 +328,14 @@ void CMsgReader::readExtendedClipboard(rdr::S32 len) if (!(flags & 1 << i)) continue; + if (!zis.hasData(4)) + throw Exception("Extended clipboard decode error"); + lengths[num] = zis.readU32(); + + if (!zis.hasData(lengths[num])) + throw Exception("Extended clipboard decode error"); + if (lengths[num] > (size_t)maxCutText) { vlog.error("Extended clipboard data too long (%d bytes) - ignoring", (unsigned)lengths[num]); @@ -271,43 +375,63 @@ void CMsgReader::readExtendedClipboard(rdr::S32 len) throw Exception("Invalid extended clipboard action"); } } + + return true; } -void CMsgReader::readFence() +bool CMsgReader::readFence() { rdr::U32 flags; rdr::U8 len; char data[64]; + if (!is->hasData(8)) + return false; + + is->setRestorePoint(); + is->skip(3); flags = is->readU32(); len = is->readU8(); + + if (!is->hasDataOrRestore(len)) + return false; + is->clearRestorePoint(); + if (len > sizeof(data)) { vlog.error("Ignoring fence with too large payload"); is->skip(len); - return; + return true; } is->readBytes(data, len); handler->fence(flags, len, data); + + return true; } -void CMsgReader::readEndOfContinuousUpdates() +bool CMsgReader::readEndOfContinuousUpdates() { handler->endOfContinuousUpdates(); + return true; } -void CMsgReader::readFramebufferUpdate() +bool CMsgReader::readFramebufferUpdate() { + if (!is->hasData(3)) + return false; + is->skip(1); nUpdateRectsLeft = is->readU16(); handler->framebufferUpdateStart(); + + return true; } -void CMsgReader::readRect(const Rect& r, int encoding) +bool CMsgReader::readRect(const Rect& r, int encoding) { if ((r.br.x > handler->server.width()) || (r.br.y > handler->server.height())) { @@ -320,10 +444,10 @@ void CMsgReader::readRect(const Rect& r, int encoding) if (r.is_empty()) vlog.error("zero size rect"); - handler->dataRect(r, encoding); + return handler->dataRect(r, encoding); } -void CMsgReader::readSetXCursor(int width, int height, const Point& hotspot) +bool CMsgReader::readSetXCursor(int width, int height, const Point& hotspot) { if (width > maxCursorSize || height > maxCursorSize) throw Exception("Too big cursor"); @@ -341,6 +465,9 @@ void CMsgReader::readSetXCursor(int width, int height, const Point& hotspot) int x, y; rdr::U8* out; + if (!is->hasData(3 + 3 + data_len + mask_len)) + return false; + pr = is->readU8(); pg = is->readU8(); pb = is->readU8(); @@ -380,9 +507,11 @@ void CMsgReader::readSetXCursor(int width, int height, const Point& hotspot) } handler->setCursor(width, height, hotspot, rgba.buf); + + return true; } -void CMsgReader::readSetCursor(int width, int height, const Point& hotspot) +bool CMsgReader::readSetCursor(int width, int height, const Point& hotspot) { if (width > maxCursorSize || height > maxCursorSize) throw Exception("Too big cursor"); @@ -397,6 +526,9 @@ void CMsgReader::readSetCursor(int width, int height, const Point& hotspot) rdr::U8* in; rdr::U8* out; + if (!is->hasData(data_len + mask_len)) + return false; + is->readBytes(data.buf, data_len); is->readBytes(mask.buf, mask_len); @@ -421,29 +553,44 @@ void CMsgReader::readSetCursor(int width, int height, const Point& hotspot) } handler->setCursor(width, height, hotspot, rgba.buf); + + return true; } -void CMsgReader::readSetCursorWithAlpha(int width, int height, const Point& hotspot) +bool CMsgReader::readSetCursorWithAlpha(int width, int height, const Point& hotspot) { if (width > maxCursorSize || height > maxCursorSize) throw Exception("Too big cursor"); - int encoding; - const PixelFormat rgbaPF(32, 32, false, true, 255, 255, 255, 16, 8, 0); ManagedPixelBuffer pb(rgbaPF, width, height); PixelFormat origPF; + bool ret; + rdr::U8* buf; int stride; - encoding = is->readS32(); + // We can't use restore points as the decoder likely wants to as well, so + // we need to keep track of the read encoding + + if (cursorEncoding == -1) { + if (!is->hasData(4)) + return false; + + cursorEncoding = is->readS32(); + } origPF = handler->server.pf(); handler->server.setPF(rgbaPF); - handler->readAndDecodeRect(pb.getRect(), encoding, &pb); + ret = handler->readAndDecodeRect(pb.getRect(), cursorEncoding, &pb); handler->server.setPF(origPF); + if (!ret) + return false; + + cursorEncoding = -1; + // On-wire data has pre-multiplied alpha, but we store it // non-pre-multiplied buf = pb.getBufferRW(pb.getRect(), &stride); @@ -467,18 +614,25 @@ void CMsgReader::readSetCursorWithAlpha(int width, int height, const Point& hots handler->setCursor(width, height, hotspot, pb.getBuffer(pb.getRect(), &stride)); + + return true; } -void CMsgReader::readSetVMwareCursor(int width, int height, const Point& hotspot) +bool CMsgReader::readSetVMwareCursor(int width, int height, const Point& hotspot) { if (width > maxCursorSize || height > maxCursorSize) throw Exception("Too big cursor"); rdr::U8 type; + if (!is->hasData(2)) + return false; + type = is->readU8(); is->skip(1); + is->setRestorePoint(); + if (type == 0) { int len = width * height * (handler->server.pf().bpp/8); rdr::U8Array andMask(len); @@ -491,6 +645,10 @@ void CMsgReader::readSetVMwareCursor(int width, int height, const Point& hotspot rdr::U8* out; int Bpp; + if (!is->hasDataOrRestore(len + len)) + return false; + is->clearRestorePoint(); + is->readBytes(andMask.buf, len); is->readBytes(xorMask.buf, len); @@ -548,6 +706,10 @@ void CMsgReader::readSetVMwareCursor(int width, int height, const Point& hotspot } else if (type == 1) { rdr::U8Array data(width*height*4); + if (!is->hasDataOrRestore(width*height*4)) + return false; + is->clearRestorePoint(); + // FIXME: Is alpha premultiplied? is->readBytes(data.buf, width*height*4); @@ -555,11 +717,25 @@ void CMsgReader::readSetVMwareCursor(int width, int height, const Point& hotspot } else { throw Exception("Unknown cursor type"); } + + return true; } -void CMsgReader::readSetDesktopName(int x, int y, int w, int h) +bool CMsgReader::readSetDesktopName(int x, int y, int w, int h) { - rdr::U32 len = is->readU32(); + rdr::U32 len; + + if (!is->hasData(4)) + return false; + + is->setRestorePoint(); + + len = is->readU32(); + + if (!is->hasDataOrRestore(len)) + return false; + is->clearRestorePoint(); + CharArray name(len + 1); is->readBytes(name.buf, len); name.buf[len] = '\0'; @@ -569,18 +745,29 @@ void CMsgReader::readSetDesktopName(int x, int y, int w, int h) } else { handler->setName(name.buf); } + + return true; } -void CMsgReader::readExtendedDesktopSize(int x, int y, int w, int h) +bool CMsgReader::readExtendedDesktopSize(int x, int y, int w, int h) { unsigned int screens, i; rdr::U32 id, flags; int sx, sy, sw, sh; ScreenSet layout; + if (!is->hasData(4)) + return false; + + is->setRestorePoint(); + screens = is->readU8(); is->skip(3); + if (!is->hasDataOrRestore(16 * screens)) + return false; + is->clearRestorePoint(); + for (i = 0;i < screens;i++) { id = is->readU32(); sx = is->readU16(); @@ -593,25 +780,37 @@ void CMsgReader::readExtendedDesktopSize(int x, int y, int w, int h) } handler->setExtendedDesktopSize(x, y, w, h, layout); + + return true; } -void CMsgReader::readLEDState() +bool CMsgReader::readLEDState() { rdr::U8 state; + if (!is->hasData(1)) + return false; + state = is->readU8(); handler->setLEDState(state); + + return true; } -void CMsgReader::readVMwareLEDState() +bool CMsgReader::readVMwareLEDState() { rdr::U32 state; + if (!is->hasData(4)) + return false; + state = is->readU32(); // As luck has it, this extension uses the same bit definitions, // so no conversion required handler->setLEDState(state); + + return true; } diff --git a/common/rfb/CMsgReader.h b/common/rfb/CMsgReader.h index 050990a9..ab55aed8 100644 --- a/common/rfb/CMsgReader.h +++ b/common/rfb/CMsgReader.h @@ -40,39 +40,55 @@ namespace rfb { CMsgReader(CMsgHandler* handler, rdr::InStream* is); virtual ~CMsgReader(); - void readServerInit(); + bool readServerInit(); // readMsg() reads a message, calling the handler as appropriate. - void readMsg(); + bool readMsg(); rdr::InStream* getInStream() { return is; } int imageBufIdealSize; protected: - void readSetColourMapEntries(); - void readBell(); - void readServerCutText(); - void readExtendedClipboard(rdr::S32 len); - void readFence(); - void readEndOfContinuousUpdates(); - - void readFramebufferUpdate(); - - void readRect(const Rect& r, int encoding); - - void readSetXCursor(int width, int height, const Point& hotspot); - void readSetCursor(int width, int height, const Point& hotspot); - void readSetCursorWithAlpha(int width, int height, const Point& hotspot); - void readSetVMwareCursor(int width, int height, const Point& hotspot); - void readSetDesktopName(int x, int y, int w, int h); - void readExtendedDesktopSize(int x, int y, int w, int h); - void readLEDState(); - void readVMwareLEDState(); - + bool readSetColourMapEntries(); + bool readBell(); + bool readServerCutText(); + bool readExtendedClipboard(rdr::S32 len); + bool readFence(); + bool readEndOfContinuousUpdates(); + + bool readFramebufferUpdate(); + + bool readRect(const Rect& r, int encoding); + + bool readSetXCursor(int width, int height, const Point& hotspot); + bool readSetCursor(int width, int height, const Point& hotspot); + bool readSetCursorWithAlpha(int width, int height, const Point& hotspot); + bool readSetVMwareCursor(int width, int height, const Point& hotspot); + bool readSetDesktopName(int x, int y, int w, int h); + bool readExtendedDesktopSize(int x, int y, int w, int h); + bool readLEDState(); + bool readVMwareLEDState(); + + private: CMsgHandler* handler; rdr::InStream* is; + + enum stateEnum { + MSGSTATE_IDLE, + MSGSTATE_MESSAGE, + MSGSTATE_RECT_HEADER, + MSGSTATE_RECT_DATA, + }; + + stateEnum state; + + rdr::U8 currentMsgType; int nUpdateRectsLeft; + Rect dataRect; + int rectEncoding; + + int cursorEncoding; static const int maxCursorSize = 256; }; diff --git a/common/rfb/CSecurityTLS.cxx b/common/rfb/CSecurityTLS.cxx index 374ec7f3..4fcaa7a9 100644 --- a/common/rfb/CSecurityTLS.cxx +++ b/common/rfb/CSecurityTLS.cxx @@ -154,7 +154,7 @@ bool CSecurityTLS::processMsg() client = cc; if (!session) { - if (!is->checkNoWait(1)) + if (!is->hasData(1)) return false; if (is->readU8() == 0) @@ -180,8 +180,10 @@ bool CSecurityTLS::processMsg() int err; err = gnutls_handshake(session); if (err != GNUTLS_E_SUCCESS) { - if (!gnutls_error_is_fatal(err)) + if (!gnutls_error_is_fatal(err)) { + vlog.debug("Deferring completion of TLS handshake: %s", gnutls_strerror(err)); return false; + } vlog.error("TLS Handshake failed: %s\n", gnutls_strerror (err)); shutdown(false); diff --git a/common/rfb/CSecurityVeNCrypt.cxx b/common/rfb/CSecurityVeNCrypt.cxx index 22201dd2..98dad494 100644 --- a/common/rfb/CSecurityVeNCrypt.cxx +++ b/common/rfb/CSecurityVeNCrypt.cxx @@ -51,7 +51,6 @@ CSecurityVeNCrypt::CSecurityVeNCrypt(CConnection* cc, SecurityClient* sec) chosenType = secTypeVeNCrypt; nAvailableTypes = 0; availableTypes = NULL; - iAvailableType = 0; } CSecurityVeNCrypt::~CSecurityVeNCrypt() @@ -64,16 +63,20 @@ bool CSecurityVeNCrypt::processMsg() { InStream* is = cc->getInStream(); OutStream* os = cc->getOutStream(); - + /* get major, minor versions, send what we can support (or 0.0 for can't support it) */ if (!haveRecvdMajorVersion) { + if (!is->hasData(1)) + return false; + majorVersion = is->readU8(); haveRecvdMajorVersion = true; - - return false; } if (!haveRecvdMinorVersion) { + if (!is->hasData(1)) + return false; + minorVersion = is->readU8(); haveRecvdMinorVersion = true; } @@ -100,47 +103,48 @@ bool CSecurityVeNCrypt::processMsg() } haveSentVersion = true; - return false; } /* Check that the server is OK */ if (!haveAgreedVersion) { + if (!is->hasData(1)) + return false; + if (is->readU8()) throw AuthFailureException("The server reported it could not support the " "VeNCrypt version"); haveAgreedVersion = true; - return false; } /* get a number of types */ if (!haveNumberOfTypes) { + if (!is->hasData(1)) + return false; + nAvailableTypes = is->readU8(); - iAvailableType = 0; if (!nAvailableTypes) throw AuthFailureException("The server reported no VeNCrypt sub-types"); availableTypes = new rdr::U32[nAvailableTypes]; haveNumberOfTypes = true; - return false; } if (nAvailableTypes) { /* read in the types possible */ if (!haveListOfTypes) { - if (is->checkNoWait(4)) { - availableTypes[iAvailableType++] = is->readU32(); - haveListOfTypes = (iAvailableType >= nAvailableTypes); - vlog.debug("Server offers security type %s (%d)", - secTypeName(availableTypes[iAvailableType - 1]), - availableTypes[iAvailableType - 1]); - - if (!haveListOfTypes) - return false; - - } else - return false; + if (!is->hasData(4 * nAvailableTypes)) + return false; + + for (int i = 0;i < nAvailableTypes;i++) { + availableTypes[i] = is->readU32(); + vlog.debug("Server offers security type %s (%d)", + secTypeName(availableTypes[i]), + availableTypes[i]); + } + + haveListOfTypes = true; } /* make a choice and send it to the server, meanwhile set up the stack */ diff --git a/common/rfb/CSecurityVeNCrypt.h b/common/rfb/CSecurityVeNCrypt.h index d015e8f2..1e2a7e68 100644 --- a/common/rfb/CSecurityVeNCrypt.h +++ b/common/rfb/CSecurityVeNCrypt.h @@ -55,7 +55,6 @@ namespace rfb { rdr::U32 chosenType; rdr::U8 nAvailableTypes; rdr::U32 *availableTypes; - rdr::U8 iAvailableType; }; } #endif diff --git a/common/rfb/CSecurityVncAuth.cxx b/common/rfb/CSecurityVncAuth.cxx index 6a87498c..78a3a061 100644 --- a/common/rfb/CSecurityVncAuth.cxx +++ b/common/rfb/CSecurityVncAuth.cxx @@ -45,6 +45,9 @@ bool CSecurityVncAuth::processMsg() rdr::InStream* is = cc->getInStream(); rdr::OutStream* os = cc->getOutStream(); + if (!is->hasData(vncAuthChallengeSize)) + return false; + // Read the challenge & obtain the user's password rdr::U8 challenge[vncAuthChallengeSize]; is->readBytes(challenge, vncAuthChallengeSize); diff --git a/common/rfb/CopyRectDecoder.cxx b/common/rfb/CopyRectDecoder.cxx index ecf50323..bca8da48 100644 --- a/common/rfb/CopyRectDecoder.cxx +++ b/common/rfb/CopyRectDecoder.cxx @@ -31,10 +31,13 @@ CopyRectDecoder::~CopyRectDecoder() { } -void CopyRectDecoder::readRect(const Rect& r, rdr::InStream* is, +bool CopyRectDecoder::readRect(const Rect& r, rdr::InStream* is, const ServerParams& server, rdr::OutStream* os) { + if (!is->hasData(4)) + return false; os->copyBytes(is, 4); + return true; } diff --git a/common/rfb/CopyRectDecoder.h b/common/rfb/CopyRectDecoder.h index 546266e1..5100eb2f 100644 --- a/common/rfb/CopyRectDecoder.h +++ b/common/rfb/CopyRectDecoder.h @@ -26,7 +26,7 @@ namespace rfb { public: CopyRectDecoder(); virtual ~CopyRectDecoder(); - virtual void readRect(const Rect& r, rdr::InStream* is, + virtual bool readRect(const Rect& r, rdr::InStream* is, const ServerParams& server, rdr::OutStream* os); virtual void getAffectedRegion(const Rect& rect, const void* buffer, size_t buflen, const ServerParams& server, diff --git a/common/rfb/DecodeManager.cxx b/common/rfb/DecodeManager.cxx index 80c10510..c003ab40 100644 --- a/common/rfb/DecodeManager.cxx +++ b/common/rfb/DecodeManager.cxx @@ -103,7 +103,7 @@ DecodeManager::~DecodeManager() delete decoders[i]; } -void DecodeManager::decodeRect(const Rect& r, int encoding, +bool DecodeManager::decodeRect(const Rect& r, int encoding, ModifiablePixelBuffer* pb) { Decoder *decoder; @@ -133,19 +133,21 @@ void DecodeManager::decodeRect(const Rect& r, int encoding, if (threads.empty()) { bufferStream = freeBuffers.front(); bufferStream->clear(); - decoder->readRect(r, conn->getInStream(), conn->server, bufferStream); + if (!decoder->readRect(r, conn->getInStream(), conn->server, bufferStream)) + return false; try { decoder->decodeRect(r, bufferStream->data(), bufferStream->length(), conn->server, pb); } catch (rdr::Exception& e) { throw Exception("Error decoding rect: %s", e.str()); } - return; + return true; } // Wait for an available memory buffer queueMutex->lock(); + // FIXME: Should we return and let other things run here? while (freeBuffers.empty()) producerCond->wait(); @@ -160,7 +162,8 @@ void DecodeManager::decodeRect(const Rect& r, int encoding, // Read the rect bufferStream->clear(); - decoder->readRect(r, conn->getInStream(), conn->server, bufferStream); + if (!decoder->readRect(r, conn->getInStream(), conn->server, bufferStream)) + return false; // Then try to put it on the queue entry = new QueueEntry; @@ -190,6 +193,8 @@ void DecodeManager::decodeRect(const Rect& r, int encoding, consumerCond->signal(); queueMutex->unlock(); + + return true; } void DecodeManager::flush() diff --git a/common/rfb/DecodeManager.h b/common/rfb/DecodeManager.h index 058d8240..289686b5 100644 --- a/common/rfb/DecodeManager.h +++ b/common/rfb/DecodeManager.h @@ -47,7 +47,7 @@ namespace rfb { DecodeManager(CConnection *conn); ~DecodeManager(); - void decodeRect(const Rect& r, int encoding, + bool decodeRect(const Rect& r, int encoding, ModifiablePixelBuffer* pb); void flush(); diff --git a/common/rfb/Decoder.h b/common/rfb/Decoder.h index e074f3ec..cb206a0d 100644 --- a/common/rfb/Decoder.h +++ b/common/rfb/Decoder.h @@ -52,7 +52,7 @@ namespace rfb { // InStream to the OutStream, possibly changing it along the way to // make it easier to decode. This function will always be called in // a serial manner on the main thread. - virtual void readRect(const Rect& r, rdr::InStream* is, + virtual bool readRect(const Rect& r, rdr::InStream* is, const ServerParams& server, rdr::OutStream* os)=0; // These functions will be called from any of the worker threads. diff --git a/common/rfb/HextileDecoder.cxx b/common/rfb/HextileDecoder.cxx index 742dfb28..34392ea8 100644 --- a/common/rfb/HextileDecoder.cxx +++ b/common/rfb/HextileDecoder.cxx @@ -44,12 +44,14 @@ HextileDecoder::~HextileDecoder() { } -void HextileDecoder::readRect(const Rect& r, rdr::InStream* is, +bool HextileDecoder::readRect(const Rect& r, rdr::InStream* is, const ServerParams& server, rdr::OutStream* os) { Rect t; size_t bytesPerPixel; + is->setRestorePoint(); + bytesPerPixel = server.pf().bpp/8; for (t.tl.y = r.tl.y; t.tl.y < r.br.y; t.tl.y += 16) { @@ -61,33 +63,57 @@ void HextileDecoder::readRect(const Rect& r, rdr::InStream* is, t.br.x = __rfbmin(r.br.x, t.tl.x + 16); + if (!is->hasDataOrRestore(1)) + return false; + tileType = is->readU8(); os->writeU8(tileType); if (tileType & hextileRaw) { + if (!is->hasDataOrRestore(t.area() * bytesPerPixel)) + return false; os->copyBytes(is, t.area() * bytesPerPixel); continue; } - if (tileType & hextileBgSpecified) + + if (tileType & hextileBgSpecified) { + if (!is->hasDataOrRestore(bytesPerPixel)) + return false; os->copyBytes(is, bytesPerPixel); + } - if (tileType & hextileFgSpecified) + if (tileType & hextileFgSpecified) { + if (!is->hasDataOrRestore(bytesPerPixel)) + return false; os->copyBytes(is, bytesPerPixel); + } if (tileType & hextileAnySubrects) { rdr::U8 nSubrects; + if (!is->hasDataOrRestore(1)) + return false; + nSubrects = is->readU8(); os->writeU8(nSubrects); - if (tileType & hextileSubrectsColoured) + if (tileType & hextileSubrectsColoured) { + if (!is->hasDataOrRestore(nSubrects * (bytesPerPixel + 2))) + return false; os->copyBytes(is, nSubrects * (bytesPerPixel + 2)); - else + } else { + if (!is->hasDataOrRestore(nSubrects * 2)) + return false; os->copyBytes(is, nSubrects * 2); + } } } } + + is->clearRestorePoint(); + + return true; } void HextileDecoder::decodeRect(const Rect& r, const void* buffer, diff --git a/common/rfb/HextileDecoder.h b/common/rfb/HextileDecoder.h index b8515bfc..2c42be54 100644 --- a/common/rfb/HextileDecoder.h +++ b/common/rfb/HextileDecoder.h @@ -26,7 +26,7 @@ namespace rfb { public: HextileDecoder(); virtual ~HextileDecoder(); - virtual void readRect(const Rect& r, rdr::InStream* is, + virtual bool readRect(const Rect& r, rdr::InStream* is, const ServerParams& server, rdr::OutStream* os); virtual void decodeRect(const Rect& r, const void* buffer, size_t buflen, const ServerParams& server, diff --git a/common/rfb/RREDecoder.cxx b/common/rfb/RREDecoder.cxx index 70a7ddb2..af821cb9 100644 --- a/common/rfb/RREDecoder.cxx +++ b/common/rfb/RREDecoder.cxx @@ -44,15 +44,30 @@ RREDecoder::~RREDecoder() { } -void RREDecoder::readRect(const Rect& r, rdr::InStream* is, +bool RREDecoder::readRect(const Rect& r, rdr::InStream* is, const ServerParams& server, rdr::OutStream* os) { rdr::U32 numRects; + size_t len; + + if (!is->hasData(4)) + return false; + + is->setRestorePoint(); numRects = is->readU32(); os->writeU32(numRects); - os->copyBytes(is, server.pf().bpp/8 + numRects * (server.pf().bpp/8 + 8)); + len = server.pf().bpp/8 + numRects * (server.pf().bpp/8 + 8); + + if (!is->hasDataOrRestore(len)) + return false; + + is->clearRestorePoint(); + + os->copyBytes(is, len); + + return true; } void RREDecoder::decodeRect(const Rect& r, const void* buffer, diff --git a/common/rfb/RREDecoder.h b/common/rfb/RREDecoder.h index f47eddad..b8ec18f6 100644 --- a/common/rfb/RREDecoder.h +++ b/common/rfb/RREDecoder.h @@ -26,7 +26,7 @@ namespace rfb { public: RREDecoder(); virtual ~RREDecoder(); - virtual void readRect(const Rect& r, rdr::InStream* is, + virtual bool readRect(const Rect& r, rdr::InStream* is, const ServerParams& server, rdr::OutStream* os); virtual void decodeRect(const Rect& r, const void* buffer, size_t buflen, const ServerParams& server, diff --git a/common/rfb/RawDecoder.cxx b/common/rfb/RawDecoder.cxx index 61235047..a7648f97 100644 --- a/common/rfb/RawDecoder.cxx +++ b/common/rfb/RawDecoder.cxx @@ -33,10 +33,13 @@ RawDecoder::~RawDecoder() { } -void RawDecoder::readRect(const Rect& r, rdr::InStream* is, +bool RawDecoder::readRect(const Rect& r, rdr::InStream* is, const ServerParams& server, rdr::OutStream* os) { + if (!is->hasData(r.area() * (server.pf().bpp/8))) + return false; os->copyBytes(is, r.area() * (server.pf().bpp/8)); + return true; } void RawDecoder::decodeRect(const Rect& r, const void* buffer, diff --git a/common/rfb/RawDecoder.h b/common/rfb/RawDecoder.h index 4ab80717..2661ea57 100644 --- a/common/rfb/RawDecoder.h +++ b/common/rfb/RawDecoder.h @@ -25,7 +25,7 @@ namespace rfb { public: RawDecoder(); virtual ~RawDecoder(); - virtual void readRect(const Rect& r, rdr::InStream* is, + virtual bool readRect(const Rect& r, rdr::InStream* is, const ServerParams& server, rdr::OutStream* os); virtual void decodeRect(const Rect& r, const void* buffer, size_t buflen, const ServerParams& server, diff --git a/common/rfb/SConnection.cxx b/common/rfb/SConnection.cxx index e06fc6bb..1c9ca3e7 100644 --- a/common/rfb/SConnection.cxx +++ b/common/rfb/SConnection.cxx @@ -86,18 +86,20 @@ void SConnection::initialiseProtocol() state_ = RFBSTATE_PROTOCOL_VERSION; } -void SConnection::processMsg() +bool SConnection::processMsg() { switch (state_) { - case RFBSTATE_PROTOCOL_VERSION: processVersionMsg(); break; - case RFBSTATE_SECURITY_TYPE: processSecurityTypeMsg(); break; - case RFBSTATE_SECURITY: processSecurityMsg(); break; - case RFBSTATE_SECURITY_FAILURE: processSecurityFailure(); break; - case RFBSTATE_INITIALISATION: processInitMsg(); break; - case RFBSTATE_NORMAL: reader_->readMsg(); break; + case RFBSTATE_PROTOCOL_VERSION: return processVersionMsg(); break; + case RFBSTATE_SECURITY_TYPE: return processSecurityTypeMsg(); break; + case RFBSTATE_SECURITY: return processSecurityMsg(); break; + case RFBSTATE_SECURITY_FAILURE: return processSecurityFailure(); break; + case RFBSTATE_INITIALISATION: return processInitMsg(); break; + case RFBSTATE_NORMAL: return reader_->readMsg(); break; case RFBSTATE_QUERYING: throw Exception("SConnection::processMsg: bogus data from client while " "querying"); + case RFBSTATE_CLOSING: + throw Exception("SConnection::processMsg: called while closing"); case RFBSTATE_UNINITIALISED: throw Exception("SConnection::processMsg: not initialised yet?"); default: @@ -105,7 +107,7 @@ void SConnection::processMsg() } } -void SConnection::processVersionMsg() +bool SConnection::processVersionMsg() { char verStr[13]; int majorVersion; @@ -113,8 +115,8 @@ void SConnection::processVersionMsg() vlog.debug("reading protocol version"); - if (!is->checkNoWait(12)) - return; + if (!is->hasData(12)) + return false; is->readBytes(verStr, 12); verStr[12] = '\0'; @@ -172,8 +174,7 @@ void SConnection::processVersionMsg() if (*i == secTypeNone) os->flush(); state_ = RFBSTATE_SECURITY; ssecurity = security.GetSSecurity(this, *i); - processSecurityMsg(); - return; + return true; } // list supported security types for >=3.7 clients @@ -186,15 +187,23 @@ void SConnection::processVersionMsg() os->writeU8(*i); os->flush(); state_ = RFBSTATE_SECURITY_TYPE; + + return true; } -void SConnection::processSecurityTypeMsg() +bool SConnection::processSecurityTypeMsg() { vlog.debug("processing security type message"); + + if (!is->hasData(1)) + return false; + int secType = is->readU8(); processSecurityType(secType); + + return true; } void SConnection::processSecurityType(int secType) @@ -218,16 +227,14 @@ void SConnection::processSecurityType(int secType) } catch (rdr::Exception& e) { throwConnFailedException("%s", e.str()); } - - processSecurityMsg(); } -void SConnection::processSecurityMsg() +bool SConnection::processSecurityMsg() { vlog.debug("processing security message"); try { if (!ssecurity->processMsg()) - return; + return false; } catch (AuthFailureException& e) { vlog.error("AuthFailureException: %s", e.str()); state_ = RFBSTATE_SECURITY_FAILURE; @@ -235,28 +242,41 @@ void SConnection::processSecurityMsg() // to make it difficult to brute force a password authFailureMsg.replaceBuf(strDup(e.str())); authFailureTimer.start(100); + return true; } state_ = RFBSTATE_QUERYING; setAccessRights(ssecurity->getAccessRights()); queryConnection(ssecurity->getUserName()); + + // If the connection got approved right away then we can continue + if (state_ == RFBSTATE_INITIALISATION) + return true; + + // Otherwise we need to wait for the result + // (or give up if if was rejected) + return false; } -void SConnection::processSecurityFailure() +bool SConnection::processSecurityFailure() { // Silently drop any data if we are currently delaying an // authentication failure response as otherwise we would close // the connection on unexpected data, and an attacker could use // that to detect our delayed state. - while (is->checkNoWait(1)) - is->skip(1); + if (!is->hasData(1)) + return false; + + is->skip(is->avail()); + + return true; } -void SConnection::processInitMsg() +bool SConnection::processInitMsg() { vlog.debug("reading client initialisation"); - reader_->readClientInit(); + return reader_->readClientInit(); } bool SConnection::handleAuthFailureTimeout(Timer* t) diff --git a/common/rfb/SConnection.h b/common/rfb/SConnection.h index e7bbf2c3..b333086a 100644 --- a/common/rfb/SConnection.h +++ b/common/rfb/SConnection.h @@ -60,7 +60,7 @@ namespace rfb { // processMsg() should be called whenever there is data to read on the // InStream. You must have called initialiseProtocol() first. - void processMsg(); + bool processMsg(); // approveConnection() is called to either accept or reject the connection. // If accept is false, the reason string gives the reason for the @@ -235,12 +235,12 @@ namespace rfb { bool readyForSetColourMapEntries; - void processVersionMsg(); - void processSecurityTypeMsg(); + bool processVersionMsg(); + bool processSecurityTypeMsg(); void processSecurityType(int secType); - void processSecurityMsg(); - void processSecurityFailure(); - void processInitMsg(); + bool processSecurityMsg(); + bool processSecurityFailure(); + bool processInitMsg(); bool handleAuthFailureTimeout(Timer* t); diff --git a/common/rfb/SMsgReader.cxx b/common/rfb/SMsgReader.cxx index dc7ddea6..944f9315 100644 --- a/common/rfb/SMsgReader.cxx +++ b/common/rfb/SMsgReader.cxx @@ -38,7 +38,7 @@ static LogWriter vlog("SMsgReader"); static IntParameter maxCutText("MaxCutText", "Maximum permitted length of an incoming clipboard update", 256*1024); SMsgReader::SMsgReader(SMsgHandler* handler_, rdr::InStream* is_) - : handler(handler_), is(is_) + : handler(handler_), is(is_), state(MSGSTATE_IDLE) { } @@ -46,71 +46,105 @@ SMsgReader::~SMsgReader() { } -void SMsgReader::readClientInit() +bool SMsgReader::readClientInit() { + if (!is->hasData(1)) + return false; bool shared = is->readU8(); handler->clientInit(shared); + return true; } -void SMsgReader::readMsg() +bool SMsgReader::readMsg() { - int msgType = is->readU8(); - switch (msgType) { + bool ret; + + if (state == MSGSTATE_IDLE) { + if (!is->hasData(1)) + return false; + + currentMsgType = is->readU8(); + state = MSGSTATE_MESSAGE; + } + + switch (currentMsgType) { case msgTypeSetPixelFormat: - readSetPixelFormat(); + ret = readSetPixelFormat(); break; case msgTypeSetEncodings: - readSetEncodings(); + ret = readSetEncodings(); break; case msgTypeSetDesktopSize: - readSetDesktopSize(); + ret = readSetDesktopSize(); break; case msgTypeFramebufferUpdateRequest: - readFramebufferUpdateRequest(); + ret = readFramebufferUpdateRequest(); break; case msgTypeEnableContinuousUpdates: - readEnableContinuousUpdates(); + ret = readEnableContinuousUpdates(); break; case msgTypeClientFence: - readFence(); + ret = readFence(); break; case msgTypeKeyEvent: - readKeyEvent(); + ret = readKeyEvent(); break; case msgTypePointerEvent: - readPointerEvent(); + ret = readPointerEvent(); break; case msgTypeClientCutText: - readClientCutText(); + ret = readClientCutText(); break; case msgTypeQEMUClientMessage: - readQEMUMessage(); + ret = readQEMUMessage(); break; default: - vlog.error("unknown message type %d", msgType); + vlog.error("unknown message type %d", currentMsgType); throw Exception("unknown message type"); } + + if (ret) + state = MSGSTATE_IDLE; + + return ret; } -void SMsgReader::readSetPixelFormat() +bool SMsgReader::readSetPixelFormat() { + if (!is->hasData(3 + 16)) + return false; is->skip(3); PixelFormat pf; pf.read(is); handler->setPixelFormat(pf); + return true; } -void SMsgReader::readSetEncodings() +bool SMsgReader::readSetEncodings() { + if (!is->hasData(3)) + return false; + + is->setRestorePoint(); + is->skip(1); + int nEncodings = is->readU16(); + + if (!is->hasDataOrRestore(nEncodings * 4)) + return false; + is->clearRestorePoint(); + rdr::S32Array encodings(nEncodings); for (int i = 0; i < nEncodings; i++) encodings.buf[i] = is->readU32(); + handler->setEncodings(nEncodings, encodings.buf); + + return true; } -void SMsgReader::readSetDesktopSize() +bool SMsgReader::readSetDesktopSize() { int width, height; int screens, i; @@ -118,6 +152,11 @@ void SMsgReader::readSetDesktopSize() int sx, sy, sw, sh; ScreenSet layout; + if (!is->hasData(7)) + return true; + + is->setRestorePoint(); + is->skip(1); width = is->readU16(); @@ -126,6 +165,10 @@ void SMsgReader::readSetDesktopSize() screens = is->readU8(); is->skip(1); + if (!is->hasDataOrRestore(screens * 24)) + return false; + is->clearRestorePoint(); + for (i = 0;i < screens;i++) { id = is->readU32(); sx = is->readU16(); @@ -138,23 +181,31 @@ void SMsgReader::readSetDesktopSize() } handler->setDesktopSize(width, height, layout); + + return true; } -void SMsgReader::readFramebufferUpdateRequest() +bool SMsgReader::readFramebufferUpdateRequest() { + if (!is->hasData(17)) + return false; bool inc = is->readU8(); int x = is->readU16(); int y = is->readU16(); int w = is->readU16(); int h = is->readU16(); handler->framebufferUpdateRequest(Rect(x, y, x+w, y+h), inc); + return true; } -void SMsgReader::readEnableContinuousUpdates() +bool SMsgReader::readEnableContinuousUpdates() { bool enable; int x, y, w, h; + if (!is->hasData(17)) + return false; + enable = is->readU8(); x = is->readU16(); @@ -163,81 +214,121 @@ void SMsgReader::readEnableContinuousUpdates() h = is->readU16(); handler->enableContinuousUpdates(enable, x, y, w, h); + + return true; } -void SMsgReader::readFence() +bool SMsgReader::readFence() { rdr::U32 flags; rdr::U8 len; char data[64]; + if (!is->hasData(8)) + return false; + + is->setRestorePoint(); + is->skip(3); flags = is->readU32(); len = is->readU8(); + + if (!is->hasDataOrRestore(len)) + return false; + is->clearRestorePoint(); + if (len > sizeof(data)) { vlog.error("Ignoring fence with too large payload"); is->skip(len); - return; + return true; } is->readBytes(data, len); handler->fence(flags, len, data); + + return true; } -void SMsgReader::readKeyEvent() +bool SMsgReader::readKeyEvent() { + if (!is->hasData(7)) + return false; bool down = is->readU8(); is->skip(2); rdr::U32 key = is->readU32(); handler->keyEvent(key, 0, down); + return true; } -void SMsgReader::readPointerEvent() +bool SMsgReader::readPointerEvent() { + if (!is->hasData(5)) + return false; int mask = is->readU8(); int x = is->readU16(); int y = is->readU16(); handler->pointerEvent(Point(x, y), mask); + return true; } -void SMsgReader::readClientCutText() +bool SMsgReader::readClientCutText() { + if (!is->hasData(7)) + return false; + + is->setRestorePoint(); + is->skip(3); rdr::U32 len = is->readU32(); if (len & 0x80000000) { rdr::S32 slen = len; slen = -slen; - readExtendedClipboard(slen); - return; + if (readExtendedClipboard(slen)) { + is->clearRestorePoint(); + return true; + } else { + is->gotoRestorePoint(); + return false; + } } + if (!is->hasDataOrRestore(len)) + return false; + is->clearRestorePoint(); + if (len > (size_t)maxCutText) { is->skip(len); vlog.error("Cut text too long (%d bytes) - ignoring", len); - return; + return true; } + CharArray ca(len); is->readBytes(ca.buf, len); CharArray filtered(convertLF(ca.buf, len)); handler->clientCutText(filtered.buf); + + return true; } -void SMsgReader::readExtendedClipboard(rdr::S32 len) +bool SMsgReader::readExtendedClipboard(rdr::S32 len) { rdr::U32 flags; rdr::U32 action; + if (!is->hasData(len)) + return false; + if (len < 4) throw Exception("Invalid extended clipboard message"); if (len > maxCutText) { vlog.error("Extended clipboard message too long (%d bytes) - ignoring", len); is->skip(len); - return; + return true; } flags = is->readU32(); @@ -279,7 +370,14 @@ void SMsgReader::readExtendedClipboard(rdr::S32 len) if (!(flags & 1 << i)) continue; + if (!zis.hasData(4)) + throw Exception("Extended clipboard decode error"); + lengths[num] = zis.readU32(); + + if (!zis.hasData(lengths[num])) + throw Exception("Extended clipboard decode error"); + if (lengths[num] > (size_t)maxCutText) { vlog.error("Extended clipboard data too long (%d bytes) - ignoring", (unsigned)lengths[num]); @@ -319,28 +417,50 @@ void SMsgReader::readExtendedClipboard(rdr::S32 len) throw Exception("Invalid extended clipboard action"); } } + + return true; } -void SMsgReader::readQEMUMessage() +bool SMsgReader::readQEMUMessage() { - int subType = is->readU8(); + int subType; + bool ret; + + if (!is->hasData(1)) + return false; + + is->setRestorePoint(); + + subType = is->readU8(); + switch (subType) { case qemuExtendedKeyEvent: - readQEMUKeyEvent(); + ret = readQEMUKeyEvent(); break; default: throw Exception("unknown QEMU submessage type %d", subType); } + + if (!ret) { + is->gotoRestorePoint(); + return false; + } else { + is->clearRestorePoint(); + return true; + } } -void SMsgReader::readQEMUKeyEvent() +bool SMsgReader::readQEMUKeyEvent() { + if (!is->hasData(10)) + return false; bool down = is->readU16(); rdr::U32 keysym = is->readU32(); rdr::U32 keycode = is->readU32(); if (!keycode) { vlog.error("Key event without keycode - ignoring"); - return; + return true; } handler->keyEvent(keysym, keycode, down); + return true; } diff --git a/common/rfb/SMsgReader.h b/common/rfb/SMsgReader.h index 4991fd38..acc872ed 100644 --- a/common/rfb/SMsgReader.h +++ b/common/rfb/SMsgReader.h @@ -34,33 +34,43 @@ namespace rfb { SMsgReader(SMsgHandler* handler, rdr::InStream* is); virtual ~SMsgReader(); - void readClientInit(); + bool readClientInit(); // readMsg() reads a message, calling the handler as appropriate. - void readMsg(); + bool readMsg(); rdr::InStream* getInStream() { return is; } protected: - void readSetPixelFormat(); - void readSetEncodings(); - void readSetDesktopSize(); + bool readSetPixelFormat(); + bool readSetEncodings(); + bool readSetDesktopSize(); - void readFramebufferUpdateRequest(); - void readEnableContinuousUpdates(); + bool readFramebufferUpdateRequest(); + bool readEnableContinuousUpdates(); - void readFence(); + bool readFence(); - void readKeyEvent(); - void readPointerEvent(); - void readClientCutText(); - void readExtendedClipboard(rdr::S32 len); + bool readKeyEvent(); + bool readPointerEvent(); + bool readClientCutText(); + bool readExtendedClipboard(rdr::S32 len); - void readQEMUMessage(); - void readQEMUKeyEvent(); + bool readQEMUMessage(); + bool readQEMUKeyEvent(); + private: SMsgHandler* handler; rdr::InStream* is; + + enum stateEnum { + MSGSTATE_IDLE, + MSGSTATE_MESSAGE, + }; + + stateEnum state; + + rdr::U8 currentMsgType; }; } #endif diff --git a/common/rfb/SSecurityPlain.cxx b/common/rfb/SSecurityPlain.cxx index f577c0d6..6ae19557 100644 --- a/common/rfb/SSecurityPlain.cxx +++ b/common/rfb/SSecurityPlain.cxx @@ -84,7 +84,7 @@ bool SSecurityPlain::processMsg() throw AuthFailureException("No password validator configured"); if (state == 0) { - if (!is->checkNoWait(8)) + if (!is->hasData(8)) return false; ulen = is->readU32(); @@ -99,7 +99,7 @@ bool SSecurityPlain::processMsg() } if (state == 1) { - if (!is->checkNoWait(ulen + plen)) + if (!is->hasData(ulen + plen)) return false; state = 2; pw = new char[plen + 1]; diff --git a/common/rfb/SSecurityVeNCrypt.cxx b/common/rfb/SSecurityVeNCrypt.cxx index d522ef6f..135742c0 100644 --- a/common/rfb/SSecurityVeNCrypt.cxx +++ b/common/rfb/SSecurityVeNCrypt.cxx @@ -78,19 +78,21 @@ bool SSecurityVeNCrypt::processMsg() os->writeU8(2); haveSentVersion = true; os->flush(); - - return false; } /* Receive back highest version that client can support (up to and including ours) */ if (!haveRecvdMajorVersion) { + if (!is->hasData(1)) + return false; + majorVersion = is->readU8(); haveRecvdMajorVersion = true; - - return false; } if (!haveRecvdMinorVersion) { + if (!is->hasData(1)) + return false; + minorVersion = is->readU8(); haveRecvdMinorVersion = true; @@ -140,14 +142,15 @@ bool SSecurityVeNCrypt::processMsg() os->flush(); haveSentTypes = true; - return false; } else throw AuthFailureException("There are no VeNCrypt sub-types to send to the client"); } /* get type back from client (must be one of the ones we sent) */ if (!haveChosenType) { - is->check(4); + if (!is->hasData(4)) + return false; + chosenType = is->readU32(); for (i = 0; i < numTypes; i++) { diff --git a/common/rfb/SSecurityVncAuth.cxx b/common/rfb/SSecurityVncAuth.cxx index 882f0b08..c2a348b9 100644 --- a/common/rfb/SSecurityVncAuth.cxx +++ b/common/rfb/SSecurityVncAuth.cxx @@ -49,7 +49,7 @@ VncAuthPasswdParameter SSecurityVncAuth::vncAuthPasswd "access the server", &SSecurityVncAuth::vncAuthPasswdFile); SSecurityVncAuth::SSecurityVncAuth(SConnection* sc) - : SSecurity(sc), sentChallenge(false), responsePos(0), + : SSecurity(sc), sentChallenge(false), pg(&vncAuthPasswd), accessRights(0) { } @@ -78,6 +78,8 @@ bool SSecurityVncAuth::processMsg() if (!sentChallenge) { rdr::RandomStream rs; + if (!rs.hasData(vncAuthChallengeSize)) + throw Exception("Could not generate random data for VNC auth challenge"); rs.readBytes(challenge, vncAuthChallengeSize); os->writeBytes(challenge, vncAuthChallengeSize); os->flush(); @@ -85,10 +87,10 @@ bool SSecurityVncAuth::processMsg() return false; } - while (responsePos < vncAuthChallengeSize && is->checkNoWait(1)) - response[responsePos++] = is->readU8(); + if (!is->hasData(vncAuthChallengeSize)) + return false; - if (responsePos < vncAuthChallengeSize) return false; + is->readBytes(response, vncAuthChallengeSize); PlainPasswd passwd, passwdReadOnly; pg->getVncAuthPasswd(&passwd, &passwdReadOnly); diff --git a/common/rfb/SSecurityVncAuth.h b/common/rfb/SSecurityVncAuth.h index fe00b031..94d5aaf2 100644 --- a/common/rfb/SSecurityVncAuth.h +++ b/common/rfb/SSecurityVncAuth.h @@ -64,7 +64,6 @@ namespace rfb { rdr::U8 challenge[vncAuthChallengeSize]; rdr::U8 response[vncAuthChallengeSize]; bool sentChallenge; - int responsePos; VncAuthPasswdGetter* pg; SConnection::AccessRights accessRights; }; diff --git a/common/rfb/ServerCore.cxx b/common/rfb/ServerCore.cxx index b1097a3e..8f49848c 100644 --- a/common/rfb/ServerCore.cxx +++ b/common/rfb/ServerCore.cxx @@ -42,11 +42,6 @@ rfb::IntParameter rfb::Server::maxIdleTime ("MaxIdleTime", "Terminate after s seconds of user inactivity", 0, 0); -rfb::IntParameter rfb::Server::clientWaitTimeMillis -("ClientWaitTimeMillis", - "The number of milliseconds to wait for a client which is no longer " - "responding", - 20000, 0); rfb::IntParameter rfb::Server::compareFB ("CompareFB", "Perform pixel comparison on framebuffer to reduce unnecessary updates " diff --git a/common/rfb/ServerCore.h b/common/rfb/ServerCore.h index f915c7a7..20a740a8 100644 --- a/common/rfb/ServerCore.h +++ b/common/rfb/ServerCore.h @@ -36,7 +36,6 @@ namespace rfb { static IntParameter maxDisconnectionTime; static IntParameter maxConnectionTime; static IntParameter maxIdleTime; - static IntParameter clientWaitTimeMillis; static IntParameter compareFB; static IntParameter frameRate; static BoolParameter protocol3_3; diff --git a/common/rfb/TightDecoder.cxx b/common/rfb/TightDecoder.cxx index ebc98b06..fe03e453 100644 --- a/common/rfb/TightDecoder.cxx +++ b/common/rfb/TightDecoder.cxx @@ -54,11 +54,16 @@ TightDecoder::~TightDecoder() { } -void TightDecoder::readRect(const Rect& r, rdr::InStream* is, +bool TightDecoder::readRect(const Rect& r, rdr::InStream* is, const ServerParams& server, rdr::OutStream* os) { rdr::U8 comp_ctl; + if (!is->hasData(1)) + return false; + + is->setRestorePoint(); + comp_ctl = is->readU8(); os->writeU8(comp_ctl); @@ -66,21 +71,38 @@ void TightDecoder::readRect(const Rect& r, rdr::InStream* is, // "Fill" compression type. if (comp_ctl == tightFill) { - if (server.pf().is888()) + if (server.pf().is888()) { + if (!is->hasDataOrRestore(3)) + return false; os->copyBytes(is, 3); - else + } else { + if (!is->hasDataOrRestore(server.pf().bpp/8)) + return false; os->copyBytes(is, server.pf().bpp/8); - return; + } + is->clearRestorePoint(); + return true; } // "JPEG" compression type. if (comp_ctl == tightJpeg) { rdr::U32 len; + // FIXME: Might be less than 3 bytes + if (!is->hasDataOrRestore(3)) + return false; + len = readCompact(is); os->writeOpaque32(len); + + if (!is->hasDataOrRestore(len)) + return false; + os->copyBytes(is, len); - return; + + is->clearRestorePoint(); + + return true; } // Quit on unsupported compression type. @@ -98,18 +120,29 @@ void TightDecoder::readRect(const Rect& r, rdr::InStream* is, if ((comp_ctl & tightExplicitFilter) != 0) { rdr::U8 filterId; + if (!is->hasDataOrRestore(1)) + return false; + filterId = is->readU8(); os->writeU8(filterId); switch (filterId) { case tightFilterPalette: + if (!is->hasDataOrRestore(1)) + return false; + palSize = is->readU8() + 1; os->writeU8(palSize - 1); - if (server.pf().is888()) + if (server.pf().is888()) { + if (!is->hasDataOrRestore(palSize * 3)) + return false; os->copyBytes(is, palSize * 3); - else + } else { + if (!is->hasDataOrRestore(palSize * server.pf().bpp/8)) + return false; os->copyBytes(is, palSize * server.pf().bpp/8); + } break; case tightFilterGradient: if (server.pf().bpp == 8) @@ -137,15 +170,29 @@ void TightDecoder::readRect(const Rect& r, rdr::InStream* is, dataSize = r.height() * rowSize; - if (dataSize < TIGHT_MIN_TO_COMPRESS) + if (dataSize < TIGHT_MIN_TO_COMPRESS) { + if (!is->hasDataOrRestore(dataSize)) + return false; os->copyBytes(is, dataSize); - else { + } else { rdr::U32 len; + // FIXME: Might be less than 3 bytes + if (!is->hasDataOrRestore(3)) + return false; + len = readCompact(is); os->writeOpaque32(len); + + if (!is->hasDataOrRestore(len)) + return false; + os->copyBytes(is, len); } + + is->clearRestorePoint(); + + return true; } bool TightDecoder::doRectsConflict(const Rect& rectA, @@ -339,6 +386,8 @@ void TightDecoder::decodeRect(const Rect& r, const void* buffer, // Allocate buffer and decompress the data netbuf = new rdr::U8[dataSize]; + if (!zis[streamId].hasData(dataSize)) + throw Exception("Tight decode error"); zis[streamId].readBytes(netbuf, dataSize); zis[streamId].flushUnderlying(); diff --git a/common/rfb/TightDecoder.h b/common/rfb/TightDecoder.h index 28b6c30f..763c82d6 100644 --- a/common/rfb/TightDecoder.h +++ b/common/rfb/TightDecoder.h @@ -31,7 +31,7 @@ namespace rfb { public: TightDecoder(); virtual ~TightDecoder(); - virtual void readRect(const Rect& r, rdr::InStream* is, + virtual bool readRect(const Rect& r, rdr::InStream* is, const ServerParams& server, rdr::OutStream* os); virtual bool doRectsConflict(const Rect& rectA, const void* bufferA, diff --git a/common/rfb/VNCSConnectionST.cxx b/common/rfb/VNCSConnectionST.cxx index 00f640b3..c4ec733b 100644 --- a/common/rfb/VNCSConnectionST.cxx +++ b/common/rfb/VNCSConnectionST.cxx @@ -57,9 +57,6 @@ VNCSConnectionST::VNCSConnectionST(VNCServerST* server_, network::Socket *s, setStreams(&sock->inStream(), &sock->outStream()); peerEndpoint.buf = sock->getPeerEndpoint(); - // Configure the socket - setSocketTimeouts(); - // Kick off the idle timer if (rfb::Server::idleTimeout) { // minimum of 15 seconds while authenticating @@ -152,26 +149,23 @@ void VNCSConnectionST::processMessages() { if (state() == RFBSTATE_CLOSING) return; try { - // - Now set appropriate socket timeouts and process data - setSocketTimeouts(); - inProcessMessages = true; // Get the underlying transport to build large packets if we send // multiple small responses. getOutStream()->cork(true); - while (getInStream()->checkNoWait(1)) { - if (pendingSyncFence) { + while (true) { + if (pendingSyncFence) syncFence = true; - pendingSyncFence = false; - } - processMsg(); + if (!processMsg()) + break; if (syncFence) { writer()->writeFence(fenceFlags, fenceDataLen, fenceData); syncFence = false; + pendingSyncFence = false; } } @@ -195,7 +189,6 @@ void VNCSConnectionST::flushSocket() { if (state() == RFBSTATE_CLOSING) return; try { - setSocketTimeouts(); sock->outStream().flush(); // Flushing the socket might release an update that was previously // delayed because of congestion. @@ -1150,12 +1143,3 @@ void VNCSConnectionST::setLEDState(unsigned int ledstate) if (client.supportsLEDState()) writer()->writeLEDState(); } - -void VNCSConnectionST::setSocketTimeouts() -{ - int timeoutms = rfb::Server::clientWaitTimeMillis; - if (timeoutms == 0) - timeoutms = -1; - sock->inStream().setTimeout(timeoutms); - sock->outStream().setTimeout(timeoutms); -} diff --git a/common/rfb/VNCSConnectionST.h b/common/rfb/VNCSConnectionST.h index 46a2b28b..06fdf541 100644 --- a/common/rfb/VNCSConnectionST.h +++ b/common/rfb/VNCSConnectionST.h @@ -155,7 +155,6 @@ namespace rfb { void setCursor(); void setDesktopName(const char *name); void setLEDState(unsigned int state); - void setSocketTimeouts(); private: network::Socket* sock; diff --git a/common/rfb/ZRLEDecoder.cxx b/common/rfb/ZRLEDecoder.cxx index 9d1ff6b6..4fba0c22 100644 --- a/common/rfb/ZRLEDecoder.cxx +++ b/common/rfb/ZRLEDecoder.cxx @@ -21,6 +21,7 @@ #include <rdr/MemInStream.h> #include <rdr/OutStream.h> +#include <rfb/Exception.h> #include <rfb/ServerParams.h> #include <rfb/PixelBuffer.h> #include <rfb/ZRLEDecoder.h> @@ -29,7 +30,6 @@ using namespace rfb; static inline rdr::U32 readOpaque24A(rdr::InStream* is) { - is->check(3); rdr::U32 r=0; ((rdr::U8*)&r)[0] = is->readU8(); ((rdr::U8*)&r)[1] = is->readU8(); @@ -39,7 +39,6 @@ static inline rdr::U32 readOpaque24A(rdr::InStream* is) } static inline rdr::U32 readOpaque24B(rdr::InStream* is) { - is->check(3); rdr::U32 r=0; ((rdr::U8*)&r)[1] = is->readU8(); ((rdr::U8*)&r)[2] = is->readU8(); @@ -47,6 +46,12 @@ static inline rdr::U32 readOpaque24B(rdr::InStream* is) return r; } +static inline void zlibHasData(rdr::ZlibInStream* zis, size_t length) +{ + if (!zis->hasData(length)) + throw Exception("ZRLE decode error"); +} + #define BPP 8 #include <rfb/zrleDecode.h> #undef BPP @@ -71,14 +76,27 @@ ZRLEDecoder::~ZRLEDecoder() { } -void ZRLEDecoder::readRect(const Rect& r, rdr::InStream* is, +bool ZRLEDecoder::readRect(const Rect& r, rdr::InStream* is, const ServerParams& server, rdr::OutStream* os) { rdr::U32 len; + if (!is->hasData(4)) + return false; + + is->setRestorePoint(); + len = is->readU32(); os->writeU32(len); + + if (!is->hasDataOrRestore(len)) + return false; + + is->clearRestorePoint(); + os->copyBytes(is, len); + + return true; } void ZRLEDecoder::decodeRect(const Rect& r, const void* buffer, diff --git a/common/rfb/ZRLEDecoder.h b/common/rfb/ZRLEDecoder.h index a530586e..115f8fb8 100644 --- a/common/rfb/ZRLEDecoder.h +++ b/common/rfb/ZRLEDecoder.h @@ -27,7 +27,7 @@ namespace rfb { public: ZRLEDecoder(); virtual ~ZRLEDecoder(); - virtual void readRect(const Rect& r, rdr::InStream* is, + virtual bool readRect(const Rect& r, rdr::InStream* is, const ServerParams& server, rdr::OutStream* os); virtual void decodeRect(const Rect& r, const void* buffer, size_t buflen, const ServerParams& server, diff --git a/common/rfb/zrleDecode.h b/common/rfb/zrleDecode.h index f4325385..998e51ed 100644 --- a/common/rfb/zrleDecode.h +++ b/common/rfb/zrleDecode.h @@ -22,11 +22,6 @@ // This file is #included after having set the following macro: // BPP - 8, 16 or 32 -#include <stdio.h> -#include <rdr/InStream.h> -#include <rdr/ZlibInStream.h> -#include <rfb/Exception.h> - namespace rfb { // CONCAT2E concatenates its arguments, expanding them if they are macros @@ -63,11 +58,17 @@ void ZRLE_DECODE (const Rect& r, rdr::InStream* is, t.br.x = __rfbmin(r.br.x, t.tl.x + 64); + zlibHasData(zis, 1); int mode = zis->readU8(); bool rle = mode & 128; int palSize = mode & 127; PIXEL_T palette[128]; +#ifdef CPIXEL + zlibHasData(zis, 3 * palSize); +#else + zlibHasData(zis, BPP/8 * palSize); +#endif for (int i = 0; i < palSize; i++) { palette[i] = READ_PIXEL(zis); } @@ -84,10 +85,12 @@ void ZRLE_DECODE (const Rect& r, rdr::InStream* is, // raw #ifdef CPIXEL + zlibHasData(zis, 3 * t.area()); for (PIXEL_T* ptr = buf; ptr < buf+t.area(); ptr++) { *ptr = READ_PIXEL(zis); } #else + zlibHasData(zis, BPP/8 * t.area()); zis->readBytes(buf, t.area() * (BPP / 8)); #endif @@ -106,6 +109,7 @@ void ZRLE_DECODE (const Rect& r, rdr::InStream* is, while (ptr < eol) { if (nbits == 0) { + zlibHasData(zis, 1); byte = zis->readU8(); nbits = 8; } @@ -125,10 +129,16 @@ void ZRLE_DECODE (const Rect& r, rdr::InStream* is, PIXEL_T* ptr = buf; PIXEL_T* end = ptr + t.area(); while (ptr < end) { +#ifdef CPIXEL + zlibHasData(zis, 3); +#else + zlibHasData(zis, BPP/8); +#endif PIXEL_T pix = READ_PIXEL(zis); int len = 1; int b; do { + zlibHasData(zis, 1); b = zis->readU8(); len += b; } while (b == 255); @@ -147,11 +157,13 @@ void ZRLE_DECODE (const Rect& r, rdr::InStream* is, PIXEL_T* ptr = buf; PIXEL_T* end = ptr + t.area(); while (ptr < end) { + zlibHasData(zis, 1); int index = zis->readU8(); int len = 1; if (index & 128) { int b; do { + zlibHasData(zis, 1); b = zis->readU8(); len += b; } while (b == 255); |