]> source.dussan.org Git - tigervnc.git/commitdiff
Change streams to be asynchronous
authorPierre Ossman <ossman@cendio.se>
Thu, 14 May 2020 16:49:39 +0000 (18:49 +0200)
committerPierre Ossman <ossman@cendio.se>
Thu, 21 May 2020 10:59:02 +0000 (12:59 +0200)
Major restructuring of how streams work. Neither input nor output
streams are now blocking. This avoids stalling the rest of the client or
server when a peer is slow or unresponsive.

Note that this puts an extra burden on users of streams to make sure
they are allowed to do their work once the underlying transports are
ready (e.g. monitoring fds).

69 files changed:
common/rdr/BufferedInStream.cxx
common/rdr/BufferedInStream.h
common/rdr/BufferedOutStream.cxx
common/rdr/BufferedOutStream.h
common/rdr/Exception.h
common/rdr/FdInStream.cxx
common/rdr/FdInStream.h
common/rdr/FdOutStream.cxx
common/rdr/FdOutStream.h
common/rdr/FileInStream.cxx
common/rdr/FileInStream.h
common/rdr/HexInStream.cxx
common/rdr/InStream.h
common/rdr/MemInStream.h
common/rdr/OutStream.h
common/rdr/RandomStream.cxx
common/rdr/RandomStream.h
common/rdr/TLSInStream.cxx
common/rdr/TLSInStream.h
common/rdr/ZlibInStream.cxx
common/rdr/ZlibInStream.h
common/rfb/CConnection.cxx
common/rfb/CConnection.h
common/rfb/CMsgHandler.h
common/rfb/CMsgReader.cxx
common/rfb/CMsgReader.h
common/rfb/CSecurityTLS.cxx
common/rfb/CSecurityVeNCrypt.cxx
common/rfb/CSecurityVeNCrypt.h
common/rfb/CSecurityVncAuth.cxx
common/rfb/CopyRectDecoder.cxx
common/rfb/CopyRectDecoder.h
common/rfb/DecodeManager.cxx
common/rfb/DecodeManager.h
common/rfb/Decoder.h
common/rfb/HextileDecoder.cxx
common/rfb/HextileDecoder.h
common/rfb/RREDecoder.cxx
common/rfb/RREDecoder.h
common/rfb/RawDecoder.cxx
common/rfb/RawDecoder.h
common/rfb/SConnection.cxx
common/rfb/SConnection.h
common/rfb/SMsgReader.cxx
common/rfb/SMsgReader.h
common/rfb/SSecurityPlain.cxx
common/rfb/SSecurityVeNCrypt.cxx
common/rfb/SSecurityVncAuth.cxx
common/rfb/SSecurityVncAuth.h
common/rfb/ServerCore.cxx
common/rfb/ServerCore.h
common/rfb/TightDecoder.cxx
common/rfb/TightDecoder.h
common/rfb/VNCSConnectionST.cxx
common/rfb/VNCSConnectionST.h
common/rfb/ZRLEDecoder.cxx
common/rfb/ZRLEDecoder.h
common/rfb/zrleDecode.h
tests/perf/decperf.cxx
tests/perf/encperf.cxx
unix/vncserver/vncserver.in
unix/x0vncserver/x0vncserver.cxx
unix/x0vncserver/x0vncserver.man
unix/xserver/hw/vnc/XserverDesktop.cc
unix/xserver/hw/vnc/Xvnc.man
unix/xserver/hw/vnc/vncExtInit.cc
vncviewer/CConn.cxx
vncviewer/CConn.h
win/rfb_win32/SocketManager.cxx

index 14b735639ce40f6c9c4d721f52d55a8ae3dc15ca..5a2694b440769bf597ca6003bee5156714a52bf7 100644 (file)
@@ -47,7 +47,7 @@ size_t BufferedInStream::pos()
   return offset + ptr - start;
 }
 
-bool BufferedInStream::overrun(size_t needed, bool wait)
+bool BufferedInStream::overrun(size_t needed)
 {
   struct timeval now;
 
@@ -112,7 +112,7 @@ bool BufferedInStream::overrun(size_t needed, bool wait)
   }
 
   while (avail() < needed) {
-    if (!fillBuffer(start + bufSize - end, wait))
+    if (!fillBuffer(start + bufSize - end))
       return false;
   }
 
index 24b5a23cbaa04a250a5aaafa2ec2086d8cce77ca..84405255e11ec8de795d5c30a91bc096635c1364 100644 (file)
@@ -38,9 +38,9 @@ namespace rdr {
     virtual size_t pos();
 
   private:
-    virtual bool fillBuffer(size_t maxSize, bool wait) = 0;
+    virtual bool fillBuffer(size_t maxSize) = 0;
 
-    virtual bool overrun(size_t needed, bool wait);
+    virtual bool overrun(size_t needed);
 
   private:
     size_t bufSize;
index 930b80b99349e2c73ca06d7db793171da41a49cb..c8f6ce9cb6d733f0bfcea49f86f77e18c71a15f0 100644 (file)
@@ -60,7 +60,7 @@ void BufferedOutStream::flush()
 
     len = (ptr - sentUpTo);
 
-    if (!flushBuffer(false))
+    if (!flushBuffer())
       break;
 
     offset += len - (ptr - sentUpTo);
@@ -148,4 +148,6 @@ void BufferedOutStream::overrun(size_t needed)
 
   gettimeofday(&lastSizeCheck, NULL);
   peakUsage = totalNeeded;
+
+  return;
 }
index dd64a1364d4139fb0a08e858b29eb8f7937f9acb..b01d1fee3137a74bfa1ae97214068ec6da9c1fc7 100644 (file)
@@ -45,10 +45,9 @@ namespace rdr {
   private:
     // flushBuffer() requests that the stream be flushed. Returns true if it is
     // able to progress the output (which might still not mean any bytes
-    // actually moved) and can be called again. If wait is true then it will
-    // block until all data has been written.
+    // actually moved) and can be called again.
 
-    virtual bool flushBuffer(bool wait) = 0;
+    virtual bool flushBuffer() = 0;
 
     virtual void overrun(size_t needed);
 
index eb3c8a9dd30dade21e1850ad854de126c13375e4..e5bff80d1c937f40107986386f25abaa2f7c280f 100644 (file)
@@ -47,10 +47,6 @@ namespace rdr {
     GAIException(const char* s, int err_);
   };
 
-  struct TimedOut : public Exception {
-    TimedOut() : Exception("Timed out") {}
-  };
-
   struct EndOfStream : public Exception {
     EndOfStream() : Exception("End of stream") {}
   };
index 27de92bbf0f5c5d8b2b7c3183de0d5e587b44144..ecc34ecdcb284bec542257424937165937c5da40 100644 (file)
 
 using namespace rdr;
 
-enum { DEFAULT_BUF_SIZE = 8192 };
-
-FdInStream::FdInStream(int fd_, int timeoutms_,
-                       bool closeWhenDone_)
-  : fd(fd_), closeWhenDone(closeWhenDone_),
-    timeoutms(timeoutms_), blockCallback(0)
-{
-}
-
-FdInStream::FdInStream(int fd_, FdInStreamBlockCallback* blockCallback_)
-  : fd(fd_), timeoutms(0), blockCallback(blockCallback_)
+FdInStream::FdInStream(int fd_, bool closeWhenDone_)
+  : fd(fd_), closeWhenDone(closeWhenDone_)
 {
 }
 
@@ -66,20 +57,9 @@ FdInStream::~FdInStream()
 }
 
 
-void FdInStream::setTimeout(int timeoutms_) {
-  timeoutms = timeoutms_;
-}
-
-void FdInStream::setBlockCallback(FdInStreamBlockCallback* blockCallback_)
-{
-  blockCallback = blockCallback_;
-  timeoutms = 0;
-}
-
-
-bool FdInStream::fillBuffer(size_t maxSize, bool wait)
+bool FdInStream::fillBuffer(size_t maxSize)
 {
-  size_t n = readWithTimeoutOrCallback((U8*)end, maxSize, wait);
+  size_t n = readFd((U8*)end, maxSize);
   if (n == 0)
     return false;
   end += n;
@@ -88,55 +68,43 @@ bool FdInStream::fillBuffer(size_t maxSize, bool wait)
 }
 
 //
-// readWithTimeoutOrCallback() reads up to the given length in bytes from the
-// file descriptor into a buffer.  If the wait argument is false, then zero is
-// returned if no bytes can be read without blocking.  Otherwise if a
-// blockCallback is set, it will be called (repeatedly) instead of blocking.
-// If alternatively there is a timeout set and that timeout expires, it throws
-// a TimedOut exception.  Otherwise it returns the number of bytes read.  It
+// readFd() reads up to the given length in bytes from the
+// file descriptor into a buffer. Zero is
+// returned if no bytes can be read. Otherwise it returns the number of bytes read.  It
 // never attempts to recv() unless select() indicates that the fd is readable -
 // this means it can be used on an fd which has been set non-blocking.  It also
 // has to cope with the annoying possibility of both select() and recv()
 // returning EINTR.
 //
 
-size_t FdInStream::readWithTimeoutOrCallback(void* buf, size_t len, bool wait)
+size_t FdInStream::readFd(void* buf, size_t len)
 {
   int n;
-  while (true) {
-    do {
-      fd_set fds;
-      struct timeval tv;
-      struct timeval* tvp = &tv;
-
-      if (!wait) {
-        tv.tv_sec = tv.tv_usec = 0;
-      } else if (timeoutms != -1) {
-        tv.tv_sec = timeoutms / 1000;
-        tv.tv_usec = (timeoutms % 1000) * 1000;
-      } else {
-        tvp = 0;
-      }
-
-      FD_ZERO(&fds);
-      FD_SET(fd, &fds);
-      n = select(fd+1, &fds, 0, 0, tvp);
-    } while (n < 0 && errno == EINTR);
-
-    if (n > 0) break;
-    if (n < 0) throw SystemException("select",errno);
-    if (!wait) return 0;
-    if (!blockCallback) throw TimedOut();
-
-    blockCallback->blockCallback();
-  }
+  do {
+    fd_set fds;
+    struct timeval tv;
+
+    tv.tv_sec = tv.tv_usec = 0;
+
+    FD_ZERO(&fds);
+    FD_SET(fd, &fds);
+    n = select(fd+1, &fds, 0, 0, &tv);
+  } while (n < 0 && errno == EINTR);
+
+  if (n < 0)
+    throw SystemException("select",errno);
+
+  if (n == 0)
+    return 0;
 
   do {
     n = ::recv(fd, (char*)buf, len, 0);
   } while (n < 0 && errno == EINTR);
 
-  if (n < 0) throw SystemException("read",errno);
-  if (n == 0) throw EndOfStream();
+  if (n < 0)
+    throw SystemException("read",errno);
+  if (n == 0)
+    throw EndOfStream();
 
   return n;
 }
index 0203389b6a57c4495cf8acb748449cc866641ed9..f732ceaaf17ca51adb8712488a87c9b1a067c118 100644 (file)
 
 namespace rdr {
 
-  class FdInStreamBlockCallback {
-  public:
-    virtual void blockCallback() = 0;
-    virtual ~FdInStreamBlockCallback() {}
-  };
-
   class FdInStream : public BufferedInStream {
 
   public:
 
-    FdInStream(int fd, int timeoutms=-1, bool closeWhenDone_=false);
-    FdInStream(int fd, FdInStreamBlockCallback* blockCallback);
+    FdInStream(int fd, bool closeWhenDone_=false);
     virtual ~FdInStream();
 
-    void setTimeout(int timeoutms);
-    void setBlockCallback(FdInStreamBlockCallback* blockCallback);
     int getFd() { return fd; }
 
   private:
-    virtual bool fillBuffer(size_t maxSize, bool wait);
+    virtual bool fillBuffer(size_t maxSize);
 
-    size_t readWithTimeoutOrCallback(void* buf, size_t len, bool wait=true);
+    size_t readFd(void* buf, size_t len);
 
     int fd;
     bool closeWhenDone;
-    int timeoutms;
-    FdInStreamBlockCallback* blockCallback;
 
     size_t offset;
     U8* start;
index 3405838dfa82aff22f2f235db257bd764a99a022..b52fc85dd133a29de7c445b9744a3a1e63b8e3a9 100644 (file)
 
 using namespace rdr;
 
-FdOutStream::FdOutStream(int fd_, bool blocking_, int timeoutms_)
-  : fd(fd_), blocking(blocking_), timeoutms(timeoutms_)
+FdOutStream::FdOutStream(int fd_)
+  : fd(fd_)
 {
   gettimeofday(&lastWrite, NULL);
 }
 
 FdOutStream::~FdOutStream()
 {
-  try {
-    while (sentUpTo != ptr)
-      flushBuffer(true);
-  } catch (Exception&) {
-  }
-}
-
-void FdOutStream::setTimeout(int timeoutms_) {
-  timeoutms = timeoutms_;
-}
-
-void FdOutStream::setBlocking(bool blocking_) {
-  blocking = blocking_;
 }
 
 unsigned FdOutStream::getIdleTime()
@@ -87,20 +74,11 @@ void FdOutStream::cork(bool enable)
 #endif
 }
 
-bool FdOutStream::flushBuffer(bool wait)
+bool FdOutStream::flushBuffer()
 {
-  size_t n = writeWithTimeout((const void*) sentUpTo,
-                              ptr - sentUpTo,
-                              (blocking || wait)? timeoutms : 0);
-
-  // Timeout?
-  if (n == 0) {
-    // If non-blocking then we're done here
-    if (!blocking && !wait)
-      return false;
-
-    throw TimedOut();
-  }
+  size_t n = writeFd((const void*) sentUpTo, ptr - sentUpTo);
+  if (n == 0)
+    return false;
 
   sentUpTo += n;
 
@@ -108,34 +86,27 @@ bool FdOutStream::flushBuffer(bool wait)
 }
 
 //
-// writeWithTimeout() writes up to the given length in bytes from the given
-// buffer to the file descriptor.  If there is a timeout set and that timeout
-// expires, it throws a TimedOut exception.  Otherwise it returns the number of
-// bytes written.  It never attempts to send() unless select() indicates that
-// the fd is writable - this means it can be used on an fd which has been set
-// non-blocking.  It also has to cope with the annoying possibility of both
-// select() and send() returning EINTR.
+// writeFd() writes up to the given length in bytes from the given
+// buffer to the file descriptor. It returns the number of bytes written.  It
+// never attempts to send() unless select() indicates that the fd is writable
+// - this means it can be used on an fd which has been set non-blocking.  It
+// also has to cope with the annoying possibility of both select() and send()
+// returning EINTR.
 //
 
-size_t FdOutStream::writeWithTimeout(const void* data, size_t length, int timeoutms)
+size_t FdOutStream::writeFd(const void* data, size_t length)
 {
   int n;
 
   do {
     fd_set fds;
     struct timeval tv;
-    struct timeval* tvp = &tv;
 
-    if (timeoutms != -1) {
-      tv.tv_sec = timeoutms / 1000;
-      tv.tv_usec = (timeoutms % 1000) * 1000;
-    } else {
-      tvp = NULL;
-    }
+    tv.tv_sec = tv.tv_usec = 0;
 
     FD_ZERO(&fds);
     FD_SET(fd, &fds);
-    n = select(fd+1, 0, &fds, 0, tvp);
+    n = select(fd+1, 0, &fds, 0, &tv);
   } while (n < 0 && errno == EINTR);
 
   if (n < 0)
index b1fb74c0f40daa583a6281b76fd60f49c3413754..80804da4518ced89d63c5c895744de500b64e0e8 100644 (file)
@@ -34,11 +34,9 @@ namespace rdr {
 
   public:
 
-    FdOutStream(int fd, bool blocking=true, int timeoutms=-1);
+    FdOutStream(int fd);
     virtual ~FdOutStream();
 
-    void setTimeout(int timeoutms);
-    void setBlocking(bool blocking);
     int getFd() { return fd; }
 
     unsigned getIdleTime();
@@ -46,11 +44,9 @@ namespace rdr {
     virtual void cork(bool enable);
 
   private:
-    virtual bool flushBuffer(bool wait);
-    size_t writeWithTimeout(const void* data, size_t length, int timeoutms);
+    virtual bool flushBuffer();
+    size_t writeFd(const void* data, size_t length);
     int fd;
-    bool blocking;
-    int timeoutms;
     struct timeval lastWrite;
   };
 
index 66dfe76691fd4e5b531b7f8f4ab7cae13130d1db..9975fde6054cf7ec331719d3c3be0c45a7651588 100644 (file)
@@ -39,7 +39,7 @@ FileInStream::~FileInStream(void) {
   }
 }
 
-bool FileInStream::fillBuffer(size_t maxSize, bool wait)
+bool FileInStream::fillBuffer(size_t maxSize)
 {
   size_t n = fread((U8 *)end, 1, maxSize, file);
   if (n == 0) {
index 268f5375f8704e72bfd4ef25d7f387a55b6512b9..619397f0e90e8faf058c371ffacbc69994d84b17 100644 (file)
@@ -34,7 +34,7 @@ namespace rdr {
     ~FileInStream(void);
 
   private:
-    virtual bool fillBuffer(size_t maxSize, bool wait);
+    virtual bool fillBuffer(size_t maxSize);
 
   private:
     FILE *file;
index 322432c0f91b5514179c9d8a3ff1ef8294760333..66bbf174ae91af72e88cb3c81170555e04f0778a 100644 (file)
@@ -73,7 +73,7 @@ decodeError:
 
 
 bool HexInStream::fillBuffer(size_t maxSize, bool wait) {
-  if (!in_stream.check(2, wait))
+  if (!in_stream.hasData(2))
     return false;
 
   size_t length = min(in_stream.avail()/2, maxSize);
index 5d873011a61433992d532523f5dc8a51ec23e5f0..60ea4997ee3985fac4d6698bbb59b7049994921f 100644 (file)
@@ -1,4 +1,5 @@
 /* Copyright (C) 2002-2005 RealVNC Ltd.  All Rights Reserved.
+ * Copyright 2014-2020 Pierre Ossman for Cendio AB
  * 
  * This is free software; you can redistribute it and/or modify
  * it under the terms of the GNU General Public License as published by
 #include <rdr/Exception.h>
 #include <string.h> // for memcpy
 
+// Check that callers are using InStream properly,
+// useful when writing new protocol handling
+#undef RFB_INSTREAM_CHECK
+
 namespace rdr {
 
   class InStream {
@@ -39,29 +44,79 @@ namespace rdr {
     // avail() returns the number of bytes that are currenctly directly
     // available from the stream.
 
-    inline size_t avail()
-    {
+    inline size_t avail() {
+#ifdef RFB_INSTREAM_CHECK
+      checkedBytes = end - ptr;
+#endif
+
       return end - ptr;
     }
 
-    // check() ensures there is buffer data for at least needed bytes. Returns
-    // true once the data is available. If wait is false, then instead of
-    // blocking to wait for the bytes, false is returned if the bytes are not
-    // immediately available.
+    // hasData() ensures there is at least "length" bytes of buffer data,
+    // possibly trying to fetch more data if there isn't enough right away
+
+    inline bool hasData(size_t length) {
+#ifdef RFB_INSTREAM_CHECK
+      checkedBytes = 0;
+#endif
+
+      if (length > (size_t)(end - ptr)) {
+        if (restorePoint != NULL) {
+          bool ret;
+          size_t restoreDiff;
+
+          restoreDiff = ptr - restorePoint;
+          ptr = restorePoint;
+
+          ret = overrun(length + restoreDiff);
 
-    inline size_t check(size_t needed, bool wait=true)
-    {
-      if (needed > avail())
-        return overrun(needed, wait);
+          restorePoint = ptr;
+          ptr += restoreDiff;
+
+          if (!ret)
+            return false;
+        } else {
+          if (!overrun(length))
+            return false;
+        }
+      }
+
+#ifdef RFB_INSTREAM_CHECK
+      checkedBytes = length;
+#endif
 
       return true;
     }
 
-    // checkNoWait() tries to make sure that the given number of bytes can
-    // be read without blocking.  It returns true if this is the case, false
-    // otherwise.  The length must be "small" (less than the buffer size).
+    inline bool hasDataOrRestore(size_t length) {
+      if (hasData(length))
+        return true;
+      gotoRestorePoint();
+      return false;
+    }
 
-    inline bool checkNoWait(size_t length) { return check(length, false); }
+    inline void setRestorePoint() {
+#ifdef RFB_INSTREAM_CHECK
+      if (restorePoint != NULL)
+        throw Exception("Nested use of input stream restore point");
+#endif
+      restorePoint = ptr;
+    }
+    inline void clearRestorePoint() {
+#ifdef RFB_INSTREAM_CHECK
+      if (restorePoint == NULL)
+        throw Exception("Incorrect clearing of input stream restore point");
+#endif
+      restorePoint = NULL;
+    }
+    inline void gotoRestorePoint() {
+#ifdef RFB_INSTREAM_CHECK
+      if (restorePoint == NULL)
+        throw Exception("Incorrect activation of input stream restore point");
+#endif
+      ptr = restorePoint;
+      clearRestorePoint();
+    }
 
     // readU/SN() methods read unsigned and signed N-bit integers.
 
@@ -76,24 +131,19 @@ namespace rdr {
     inline S16 readS16() { return (S16)readU16(); }
     inline S32 readS32() { return (S32)readU32(); }
 
+    // skip() ignores a number of bytes on the stream
+
     inline void skip(size_t bytes) {
-      while (bytes > 0) {
-        size_t n = check(1, bytes);
-        ptr += n;
-        bytes -= n;
-      }
+      check(bytes);
+      ptr += bytes;
     }
 
     // readBytes() reads an exact number of bytes.
 
     void readBytes(void* data, size_t length) {
-      while (length > 0) {
-        size_t n = check(1, length);
-        memcpy(data, ptr, n);
-        ptr += n;
-        data = (U8*)data + n;
-        length -= n;
-      }
+      check(length);
+      memcpy(data, ptr, length);
+      ptr += length;
     }
 
     // readOpaqueN() reads a quantity without byte-swapping.
@@ -113,24 +163,45 @@ namespace rdr {
     // to the buffer. This is useful for a stream which is a wrapper around an
     // some other stream API.
 
-    inline const U8* getptr(size_t length) { check(length); return ptr; }
+    inline const U8* getptr(size_t length) { check(length);
+#ifdef RFB_INSTREAM_CHECK
+                                             checkedBytes += length;
+#endif
+                                             return ptr; }
     inline void setptr(size_t length) { if (length > avail())
                                           throw Exception("Input stream overflow");
                                         skip(length); }
 
   private:
 
+    const U8* restorePoint;
+#ifdef RFB_INSTREAM_CHECK
+    size_t checkedBytes;
+#endif
+
+    inline void check(size_t bytes) {
+#ifdef RFB_INSTREAM_CHECK
+      if (bytes > checkedBytes)
+        throw Exception("Input stream used without underrun check");
+      checkedBytes -= bytes;
+#endif
+      if (bytes > (size_t)(end - ptr))
+        throw Exception("InStream buffer underrun");
+    }
+
     // overrun() is implemented by a derived class to cope with buffer overrun.
-    // It ensures there are at least needed bytes of buffer data. Returns true
-    // once the data is available. If wait is false, then instead of blocking
-    // to wait for the bytes, false is returned if the bytes are not
-    // immediately available.
+    // It tries to ensure there are at least needed bytes of buffer data.
+    // Returns true if it managed to satisfy the request, or false otherwise.
 
-    virtual bool overrun(size_t needed, bool wait=true) = 0;
+    virtual bool overrun(size_t needed) = 0;
 
   protected:
 
-    InStream() {}
+    InStream() : restorePoint(NULL)
+#ifdef RFB_INSTREAM_CHECK
+      ,checkedBytes(0)
+#endif
+     {}
     const U8* ptr;
     const U8* end;
   };
index 83740dd9ffd8457cb51a1447dc8a75aa8a267f80..a519659428c0f695ab86d6021a9058f33d5972a5 100644 (file)
@@ -41,6 +41,12 @@ namespace rdr {
     {
       ptr = start;
       end = start + len;
+
+#ifdef RFB_INSTREAM_CHECK
+      // MemInStream cannot add more data, so callers are assumed to already
+      // new the total size
+      avail();
+#endif
     }
 
     virtual ~MemInStream() {
@@ -53,7 +59,7 @@ namespace rdr {
 
   private:
 
-    bool overrun(size_t needed, bool wait) { throw EndOfStream(); }
+    bool overrun(size_t needed) { throw EndOfStream(); }
     const U8* start;
     bool deleteWhenDone;
   };
index f432520f983af5733d33c4356258193167fb2006..61d5100b424ff78df0fccd77799f8c41b52bdedf 100644 (file)
@@ -49,14 +49,6 @@ namespace rdr {
       return end - ptr;
     }
 
-    // check() ensures there is buffer space for at least needed bytes.
-
-    inline void check(size_t needed)
-    {
-      if (needed > avail())
-        overrun(needed);
-    }
-
     // writeU/SN() methods write unsigned and signed N-bit integers.
 
     inline void writeU8( U8  u) { check(1); *ptr++ = u; }
@@ -136,6 +128,12 @@ namespace rdr {
 
   private:
 
+    inline void check(size_t length)
+    {
+      if (length > avail())
+        overrun(length);
+    }
+
     // overrun() is implemented by a derived class to cope with buffer overrun.
     // It ensures there are at least needed bytes of buffer space.
 
index 6333be3f063df352195ec704e651a0c9af96f3a9..e2da095732acf55a2987ab62199db378478b0921 100644 (file)
@@ -79,7 +79,7 @@ RandomStream::~RandomStream() {
 #endif
 }
 
-bool RandomStream::fillBuffer(size_t maxSize, bool wait) {
+bool RandomStream::fillBuffer(size_t maxSize) {
 #ifdef RFB_HAVE_WINCRYPT
   if (provider) {
     if (!CryptGenRandom(provider, maxSize, (U8*)end))
index 08ae0ff6f16be50e70bb6ba6e6850a641bce072f..58986433b464c074d3698e4bc26feba2b766e739 100644 (file)
@@ -40,7 +40,7 @@ namespace rdr {
     virtual ~RandomStream();
 
   private:
-    virtual bool fillBuffer(size_t maxSize, bool wait);
+    virtual bool fillBuffer(size_t maxSize);
 
   private:
     static unsigned int seed;
index ba20e752ffb608a57f2c58055541ad95340e4a29..2339956d66e390072f0c933d17f7b633fb98939b 100644 (file)
@@ -39,7 +39,7 @@ ssize_t TLSInStream::pull(gnutls_transport_ptr_t str, void* data, size_t size)
   InStream *in = self->in;
 
   try {
-    if (!in->check(1, false)) {
+    if (!in->hasData(1)) {
       gnutls_transport_set_errno(self->session, EAGAIN);
       return -1;
     }
@@ -74,23 +74,22 @@ TLSInStream::~TLSInStream()
   gnutls_transport_set_pull_function(session, NULL);
 }
 
-bool TLSInStream::fillBuffer(size_t maxSize, bool wait)
+bool TLSInStream::fillBuffer(size_t maxSize)
 {
-  size_t n = readTLS((U8*) end, maxSize, wait);
-  if (!wait && n == 0)
+  size_t n = readTLS((U8*) end, maxSize);
+  if (n == 0)
     return false;
   end += n;
 
   return true;
 }
 
-size_t TLSInStream::readTLS(U8* buf, size_t len, bool wait)
+size_t TLSInStream::readTLS(U8* buf, size_t len)
 {
   int n;
 
   if (gnutls_record_check_pending(session) == 0) {
-    n = in->check(1, wait);
-    if (n == 0)
+    if (!in->hasData(1))
       return 0;
   }
 
index 9779c68e4779343f3c67f6795b46d63e55ef0541..df5ebb4883af92c4a2997e4e0ddbb76a66087b02 100644 (file)
@@ -37,8 +37,8 @@ namespace rdr {
     virtual ~TLSInStream();
 
   private:
-    virtual bool fillBuffer(size_t maxSize, bool wait);
-    size_t readTLS(U8* buf, size_t len, bool wait);
+    virtual bool fillBuffer(size_t maxSize);
+    size_t readTLS(U8* buf, size_t len);
     static ssize_t pull(gnutls_transport_ptr_t str, void* data, size_t size);
 
     gnutls_session_t session;
index 26977228e2a304a227d5d4abea828c57a478a064..0cacc21f3f18bf4fd90dee1d1684118da3720458 100644 (file)
@@ -45,7 +45,7 @@ void ZlibInStream::setUnderlying(InStream* is, size_t bytesIn_)
 void ZlibInStream::flushUnderlying()
 {
   while (bytesIn > 0) {
-    if (!check(1))
+    if (!hasData(1))
       throw Exception("ZlibInStream: failed to flush remaining stream data");
     skip(avail());
   }
@@ -85,7 +85,7 @@ void ZlibInStream::deinit()
   zs = NULL;
 }
 
-bool ZlibInStream::fillBuffer(size_t maxSize, bool wait)
+bool ZlibInStream::fillBuffer(size_t maxSize)
 {
   if (!underlying)
     throw Exception("ZlibInStream overrun: no underlying stream");
@@ -93,8 +93,8 @@ bool ZlibInStream::fillBuffer(size_t maxSize, bool wait)
   zs->next_out = (U8*)end;
   zs->avail_out = maxSize;
 
-  size_t n = underlying->check(1, wait);
-  if (n == 0) return false;
+  if (!underlying->hasData(1))
+    return false;
   size_t length = underlying->avail();
   if (length > bytesIn)
     length = bytesIn;
index 1597b54a546f4ee228e4d49093311ca1e431feab..302c99d8887810686e7c7eb79f26079712ab8fe1 100644 (file)
@@ -44,7 +44,7 @@ namespace rdr {
     void init();
     void deinit();
 
-    virtual bool fillBuffer(size_t maxSize, bool wait);
+    virtual bool fillBuffer(size_t maxSize);
 
   private:
     InStream* underlying;
index cb1b84bdeb3cfd186d3034016efbb180fcfdaa49..9c957a540e93f690cbfc227187a7daf1ba9c845c 100644 (file)
@@ -120,16 +120,19 @@ void CConnection::initialiseProtocol()
   state_ = RFBSTATE_PROTOCOL_VERSION;
 }
 
-void CConnection::processMsg()
+bool CConnection::processMsg()
 {
   switch (state_) {
 
-  case RFBSTATE_PROTOCOL_VERSION: processVersionMsg();       break;
-  case RFBSTATE_SECURITY_TYPES:   processSecurityTypesMsg(); break;
-  case RFBSTATE_SECURITY:         processSecurityMsg();      break;
-  case RFBSTATE_SECURITY_RESULT:  processSecurityResultMsg(); break;
-  case RFBSTATE_INITIALISATION:   processInitMsg();          break;
-  case RFBSTATE_NORMAL:           reader_->readMsg();        break;
+  case RFBSTATE_PROTOCOL_VERSION: return processVersionMsg();        break;
+  case RFBSTATE_SECURITY_TYPES:   return processSecurityTypesMsg();  break;
+  case RFBSTATE_SECURITY:         return processSecurityMsg();       break;
+  case RFBSTATE_SECURITY_RESULT:  return processSecurityResultMsg(); break;
+  case RFBSTATE_SECURITY_REASON:  return processSecurityReasonMsg(); break;
+  case RFBSTATE_INITIALISATION:   return processInitMsg();           break;
+  case RFBSTATE_NORMAL:           return reader_->readMsg();         break;
+  case RFBSTATE_CLOSING:
+    throw Exception("CConnection::processMsg: called while closing");
   case RFBSTATE_UNINITIALISED:
     throw Exception("CConnection::processMsg: not initialised yet?");
   default:
@@ -137,7 +140,7 @@ void CConnection::processMsg()
   }
 }
 
-void CConnection::processVersionMsg()
+bool CConnection::processVersionMsg()
 {
   char verStr[27]; // FIXME: gcc has some bug in format-overflow
   int majorVersion;
@@ -145,8 +148,8 @@ void CConnection::processVersionMsg()
 
   vlog.debug("reading protocol version");
 
-  if (!is->checkNoWait(12))
-    return;
+  if (!is->hasData(12))
+    return false;
 
   is->readBytes(verStr, 12);
   verStr[12] = '\0';
@@ -184,10 +187,12 @@ void CConnection::processVersionMsg()
 
   vlog.info("Using RFB protocol version %d.%d",
             server.majorVersion, server.minorVersion);
+
+  return true;
 }
 
 
-void CConnection::processSecurityTypesMsg()
+bool CConnection::processSecurityTypesMsg()
 {
   vlog.debug("processing security types message");
 
@@ -200,10 +205,13 @@ void CConnection::processSecurityTypesMsg()
 
     // legacy 3.3 server may only offer "vnc authentication" or "none"
 
+    if (!is->hasData(4))
+      return false;
+
     secType = is->readU32();
     if (secType == secTypeInvalid) {
-      throwConnFailedException();
-
+      state_ = RFBSTATE_SECURITY_REASON;
+      return true;
     } else if (secType == secTypeNone || secType == secTypeVncAuth) {
       std::list<rdr::U8>::iterator i;
       for (i = secTypes.begin(); i != secTypes.end(); i++)
@@ -223,9 +231,21 @@ void CConnection::processSecurityTypesMsg()
 
     // >=3.7 server will offer us a list
 
+    if (!is->hasData(1))
+      return false;
+
+    is->setRestorePoint();
+
     int nServerSecTypes = is->readU8();
-    if (nServerSecTypes == 0)
-      throwConnFailedException();
+
+    if (!is->hasDataOrRestore(nServerSecTypes))
+      return false;
+    is->clearRestorePoint();
+
+    if (nServerSecTypes == 0) {
+      state_ = RFBSTATE_SECURITY_REASON;
+      return true;
+    }
 
     std::list<rdr::U8>::iterator j;
 
@@ -263,32 +283,38 @@ void CConnection::processSecurityTypesMsg()
 
   state_ = RFBSTATE_SECURITY;
   csecurity = security.GetCSecurity(this, secType);
-  processSecurityMsg();
+
+  return true;
 }
 
-void CConnection::processSecurityMsg()
+bool CConnection::processSecurityMsg()
 {
   vlog.debug("processing security message");
-  if (csecurity->processMsg()) {
-    state_ = RFBSTATE_SECURITY_RESULT;
-    processSecurityResultMsg();
-  }
+  if (!csecurity->processMsg())
+    return false;
+
+  state_ = RFBSTATE_SECURITY_RESULT;
+
+  return true;
 }
 
-void CConnection::processSecurityResultMsg()
+bool CConnection::processSecurityResultMsg()
 {
   vlog.debug("processing security result message");
   int result;
+
   if (server.beforeVersion(3,8) && csecurity->getType() == secTypeNone) {
     result = secResultOK;
   } else {
-    if (!is->checkNoWait(1)) return;
+    if (!is->hasData(4))
+      return false;
     result = is->readU32();
   }
+
   switch (result) {
   case secResultOK:
     securityCompleted();
-    return;
+    return true;
   case secResultFailed:
     vlog.debug("auth failed");
     break;
@@ -298,30 +324,42 @@ void CConnection::processSecurityResultMsg()
   default:
     throw Exception("Unknown security result from server");
   }
-  state_ = RFBSTATE_INVALID;
-  if (server.beforeVersion(3,8))
+
+  if (server.beforeVersion(3,8)) {
+    state_ = RFBSTATE_INVALID;
     throw AuthFailureException();
+  }
+
+  state_ = RFBSTATE_SECURITY_REASON;
+  return true;
+}
+
+bool CConnection::processSecurityReasonMsg()
+{
+  vlog.debug("processing security reason message");
+
+  if (!is->hasData(4))
+    return false;
+
+  is->setRestorePoint();
+
   rdr::U32 len = is->readU32();
+  if (!is->hasDataOrRestore(len))
+    return false;
+  is->clearRestorePoint();
+
   CharArray reason(len + 1);
   is->readBytes(reason.buf, len);
   reason.buf[len] = '\0';
+
+  state_ = RFBSTATE_INVALID;
   throw AuthFailureException(reason.buf);
 }
 
-void CConnection::processInitMsg()
+bool CConnection::processInitMsg()
 {
   vlog.debug("reading server initialisation");
-  reader_->readServerInit();
-}
-
-void CConnection::throwConnFailedException()
-{
-  state_ = RFBSTATE_INVALID;
-  rdr::U32 len = is->readU32();
-  CharArray reason(len + 1);
-  is->readBytes(reason.buf, len);
-  reason.buf[len] = '\0';
-  throw ConnFailedException(reason.buf);
+  return reader_->readServerInit();
 }
 
 void CConnection::securityCompleted()
@@ -429,11 +467,13 @@ void CConnection::serverInit(int width, int height,
   }
 }
 
-void CConnection::readAndDecodeRect(const Rect& r, int encoding,
+bool CConnection::readAndDecodeRect(const Rect& r, int encoding,
                                     ModifiablePixelBuffer* pb)
 {
-  decoder.decodeRect(r, encoding, pb);
+  if (!decoder.decodeRect(r, encoding, pb))
+    return false;
   decoder.flush();
+  return true;
 }
 
 void CConnection::framebufferUpdateStart()
@@ -474,9 +514,9 @@ void CConnection::framebufferUpdateEnd()
   }
 }
 
-void CConnection::dataRect(const Rect& r, int encoding)
+bool CConnection::dataRect(const Rect& r, int encoding)
 {
-  decoder.decodeRect(r, encoding, framebuffer);
+  return decoder.decodeRect(r, encoding, framebuffer);
 }
 
 void CConnection::serverCutText(const char* str)
index 3857be4d7c8ab0b175fd0e0c604036541ab3290e..3c2b04e8c493fd4f5f17f735d57be91e303bb8b7 100644 (file)
@@ -84,7 +84,7 @@ namespace rfb {
     //   In this case, processMsg should always process the available RFB
     //   message before returning.
     // NB: In either case, you must have called initialiseProtocol() first.
-    void processMsg();
+    bool processMsg();
 
     // close() gracefully shuts down the connection to the server and
     // should be called before terminating the underlying network
@@ -107,12 +107,12 @@ namespace rfb {
                             const PixelFormat& pf,
                             const char* name);
 
-    virtual void readAndDecodeRect(const Rect& r, int encoding,
+    virtual bool readAndDecodeRect(const Rect& r, int encoding,
                                    ModifiablePixelBuffer* pb);
 
     virtual void framebufferUpdateStart();
     virtual void framebufferUpdateEnd();
-    virtual void dataRect(const Rect& r, int encoding);
+    virtual bool dataRect(const Rect& r, int encoding);
 
     virtual void serverCutText(const char* str);
 
@@ -216,6 +216,7 @@ namespace rfb {
       RFBSTATE_SECURITY_TYPES,
       RFBSTATE_SECURITY,
       RFBSTATE_SECURITY_RESULT,
+      RFBSTATE_SECURITY_REASON,
       RFBSTATE_INITIALISATION,
       RFBSTATE_NORMAL,
       RFBSTATE_CLOSING,
@@ -249,13 +250,13 @@ namespace rfb {
     virtual void fence(rdr::U32 flags, unsigned len, const char data[]);
 
   private:
-    void processVersionMsg();
-    void processSecurityTypesMsg();
-    void processSecurityMsg();
-    void processSecurityResultMsg();
-    void processInitMsg();
+    bool processVersionMsg();
+    bool processSecurityTypesMsg();
+    bool processSecurityMsg();
+    bool processSecurityResultMsg();
+    bool processSecurityReasonMsg();
+    bool processInitMsg();
     void throwAuthFailureException();
-    void throwConnFailedException();
     void securityCompleted();
 
     void requestNewUpdate();
index 84dd115ce11336eebc7644fbde59345444f0542d..5b14806ad1a869a7dafbec4072b7b15e2720c030 100644 (file)
@@ -61,12 +61,12 @@ namespace rfb {
                             const PixelFormat& pf,
                             const char* name) = 0;
 
-    virtual void readAndDecodeRect(const Rect& r, int encoding,
+    virtual bool readAndDecodeRect(const Rect& r, int encoding,
                                    ModifiablePixelBuffer* pb) = 0;
 
     virtual void framebufferUpdateStart();
     virtual void framebufferUpdateEnd();
-    virtual void dataRect(const Rect& r, int encoding) = 0;
+    virtual bool dataRect(const Rect& r, int encoding) = 0;
 
     virtual void setColourMapEntries(int firstColour, int nColours,
                                     rdr::U16* rgbs) = 0;
index a015ec9906216aecf97f5d8c08538c4426eb18c3..40fb5912a15755a5b0eb9c07c0352c1f906af89d 100644 (file)
@@ -39,7 +39,7 @@ using namespace rfb;
 
 CMsgReader::CMsgReader(CMsgHandler* handler_, rdr::InStream* is_)
   : imageBufIdealSize(0), handler(handler_), is(is_),
-    nUpdateRectsLeft(0)
+    state(MSGSTATE_IDLE), cursorEncoding(-1)
 {
 }
 
@@ -47,149 +47,246 @@ CMsgReader::~CMsgReader()
 {
 }
 
-void CMsgReader::readServerInit()
+bool CMsgReader::readServerInit()
 {
-  int width = is->readU16();
-  int height = is->readU16();
+  int width, height;
+  rdr::U32 len;
+
+  if (!is->hasData(2 + 2 + 16 + 4))
+    return false;
+
+  is->setRestorePoint();
+
+  width = is->readU16();
+  height = is->readU16();
+
   PixelFormat pf;
   pf.read(is);
-  rdr::U32 len = is->readU32();
+
+  len = is->readU32();
+  if (!is->hasDataOrRestore(len))
+    return false;
+  is->clearRestorePoint();
   CharArray name(len + 1);
   is->readBytes(name.buf, len);
   name.buf[len] = '\0';
   handler->serverInit(width, height, pf, name.buf);
+
+  return true;
 }
 
-void CMsgReader::readMsg()
+bool CMsgReader::readMsg()
 {
-  if (nUpdateRectsLeft == 0) {
-    int type = is->readU8();
+  if (state == MSGSTATE_IDLE) {
+    if (!is->hasData(1))
+      return false;
+
+    currentMsgType = is->readU8();
+    state = MSGSTATE_MESSAGE;
+  }
+
+  if (currentMsgType != msgTypeFramebufferUpdate) {
+    bool ret;
 
-    switch (type) {
+    switch (currentMsgType) {
     case msgTypeSetColourMapEntries:
-      readSetColourMapEntries();
+      ret = readSetColourMapEntries();
       break;
     case msgTypeBell:
-      readBell();
+      ret = readBell();
       break;
     case msgTypeServerCutText:
-      readServerCutText();
+      ret = readServerCutText();
       break;
     case msgTypeFramebufferUpdate:
-      readFramebufferUpdate();
+      ret = readFramebufferUpdate();
       break;
     case msgTypeServerFence:
-      readFence();
+      ret = readFence();
       break;
     case msgTypeEndOfContinuousUpdates:
-      readEndOfContinuousUpdates();
+      ret = readEndOfContinuousUpdates();
       break;
     default:
-      vlog.error("unknown message type %d", type);
-      throw Exception("unknown message type");
+      throw Exception("Unknown message type %d", currentMsgType);
     }
+
+    if (ret)
+      state = MSGSTATE_IDLE;
+
+    return ret;
   } else {
-    int x = is->readU16();
-    int y = is->readU16();
-    int w = is->readU16();
-    int h = is->readU16();
-    int encoding = is->readS32();
+    if (state == MSGSTATE_MESSAGE) {
+      if (!readFramebufferUpdate())
+        return false;
+
+      // Empty update?
+      if (nUpdateRectsLeft == 0) {
+        state = MSGSTATE_IDLE;
+        handler->framebufferUpdateEnd();
+        return true;
+      }
+
+      state = MSGSTATE_RECT_HEADER;
+    }
+
+    if (state == MSGSTATE_RECT_HEADER) {
+      if (!is->hasData(12))
+        return false;
+
+      int x = is->readU16();
+      int y = is->readU16();
+      int w = is->readU16();
+      int h = is->readU16();
+
+      dataRect.setXYWH(x, y, w, h);
+
+      rectEncoding = is->readS32();
+
+      state = MSGSTATE_RECT_DATA;
+    }
+
+    bool ret;
 
-    switch (encoding) {
+    switch (rectEncoding) {
     case pseudoEncodingLastRect:
       nUpdateRectsLeft = 1;     // this rectangle is the last one
+      ret = true;
       break;
     case pseudoEncodingXCursor:
-      readSetXCursor(w, h, Point(x,y));
+      ret = readSetXCursor(dataRect.width(), dataRect.height(), dataRect.tl);
       break;
     case pseudoEncodingCursor:
-      readSetCursor(w, h, Point(x,y));
+      ret = readSetCursor(dataRect.width(), dataRect.height(), dataRect.tl);
       break;
     case pseudoEncodingCursorWithAlpha:
-      readSetCursorWithAlpha(w, h, Point(x,y));
+      ret = readSetCursorWithAlpha(dataRect.width(), dataRect.height(), dataRect.tl);
       break;
     case pseudoEncodingVMwareCursor:
-      readSetVMwareCursor(w, h, Point(x,y));
+      ret = readSetVMwareCursor(dataRect.width(), dataRect.height(), dataRect.tl);
       break;
     case pseudoEncodingDesktopName:
-      readSetDesktopName(x, y, w, h);
+      ret = readSetDesktopName(dataRect.tl.x, dataRect.tl.y,
+                               dataRect.width(), dataRect.height());
       break;
     case pseudoEncodingDesktopSize:
-      handler->setDesktopSize(w, h);
+      handler->setDesktopSize(dataRect.width(), dataRect.height());
+      ret = true;
       break;
     case pseudoEncodingExtendedDesktopSize:
-      readExtendedDesktopSize(x, y, w, h);
+      ret = readExtendedDesktopSize(dataRect.tl.x, dataRect.tl.y,
+                                    dataRect.width(), dataRect.height());
       break;
     case pseudoEncodingLEDState:
-      readLEDState();
+      ret = readLEDState();
       break;
     case pseudoEncodingVMwareLEDState:
-      readVMwareLEDState();
+      ret = readVMwareLEDState();
       break;
     case pseudoEncodingQEMUKeyEvent:
       handler->supportsQEMUKeyEvent();
+      ret = true;
       break;
     default:
-      readRect(Rect(x, y, x+w, y+h), encoding);
+      ret = readRect(dataRect, rectEncoding);
       break;
     };
 
-    nUpdateRectsLeft--;
-    if (nUpdateRectsLeft == 0)
-      handler->framebufferUpdateEnd();
+    if (ret) {
+      state = MSGSTATE_RECT_HEADER;
+      nUpdateRectsLeft--;
+      if (nUpdateRectsLeft == 0) {
+        state = MSGSTATE_IDLE;
+        handler->framebufferUpdateEnd();
+      }
+    }
+
+    return ret;
   }
 }
 
-void CMsgReader::readSetColourMapEntries()
+bool CMsgReader::readSetColourMapEntries()
 {
+  if (!is->hasData(5))
+    return false;
+
+  is->setRestorePoint();
+
   is->skip(1);
   int firstColour = is->readU16();
   int nColours = is->readU16();
+
+  if (!is->hasDataOrRestore(nColours * 3 * 2))
+    return false;
+  is->clearRestorePoint();
+
   rdr::U16Array rgbs(nColours * 3);
   for (int i = 0; i < nColours * 3; i++)
     rgbs.buf[i] = is->readU16();
   handler->setColourMapEntries(firstColour, nColours, rgbs.buf);
+
+  return true;
 }
 
-void CMsgReader::readBell()
+bool CMsgReader::readBell()
 {
   handler->bell();
+  return true;
 }
 
-void CMsgReader::readServerCutText()
+bool CMsgReader::readServerCutText()
 {
+  if (!is->hasData(7))
+    return false;
+
+  is->setRestorePoint();
+
   is->skip(3);
   rdr::U32 len = is->readU32();
 
   if (len & 0x80000000) {
     rdr::S32 slen = len;
     slen = -slen;
-    readExtendedClipboard(slen);
-    return;
+    if (readExtendedClipboard(slen)) {
+      is->clearRestorePoint();
+      return true;
+    } else {
+      is->gotoRestorePoint();
+      return false;
+    }
   }
 
+  if (!is->hasDataOrRestore(len))
+    return false;
+  is->clearRestorePoint();
+
   if (len > (size_t)maxCutText) {
     is->skip(len);
     vlog.error("cut text too long (%d bytes) - ignoring",len);
-    return;
+    return true;
   }
   CharArray ca(len);
   is->readBytes(ca.buf, len);
   CharArray filtered(convertLF(ca.buf, len));
   handler->serverCutText(filtered.buf);
+
+  return true;
 }
 
-void CMsgReader::readExtendedClipboard(rdr::S32 len)
+bool CMsgReader::readExtendedClipboard(rdr::S32 len)
 {
   rdr::U32 flags;
   rdr::U32 action;
 
+  if (!is->hasData(len))
+    return false;
+
   if (len < 4)
     throw Exception("Invalid extended clipboard message");
   if (len > maxCutText) {
     vlog.error("Extended clipboard message too long (%d bytes) - ignoring", len);
     is->skip(len);
-    return;
+    return true;
   }
 
   flags = is->readU32();
@@ -231,7 +328,14 @@ void CMsgReader::readExtendedClipboard(rdr::S32 len)
       if (!(flags & 1 << i))
         continue;
 
+      if (!zis.hasData(4))
+        throw Exception("Extended clipboard decode error");
+
       lengths[num] = zis.readU32();
+
+      if (!zis.hasData(lengths[num]))
+        throw Exception("Extended clipboard decode error");
+
       if (lengths[num] > (size_t)maxCutText) {
         vlog.error("Extended clipboard data too long (%d bytes) - ignoring",
                    (unsigned)lengths[num]);
@@ -271,43 +375,63 @@ void CMsgReader::readExtendedClipboard(rdr::S32 len)
       throw Exception("Invalid extended clipboard action");
     }
   }
+
+  return true;
 }
 
-void CMsgReader::readFence()
+bool CMsgReader::readFence()
 {
   rdr::U32 flags;
   rdr::U8 len;
   char data[64];
 
+  if (!is->hasData(8))
+    return false;
+
+  is->setRestorePoint();
+
   is->skip(3);
 
   flags = is->readU32();
 
   len = is->readU8();
+
+  if (!is->hasDataOrRestore(len))
+    return false;
+  is->clearRestorePoint();
+
   if (len > sizeof(data)) {
     vlog.error("Ignoring fence with too large payload");
     is->skip(len);
-    return;
+    return true;
   }
 
   is->readBytes(data, len);
 
   handler->fence(flags, len, data);
+
+  return true;
 }
 
-void CMsgReader::readEndOfContinuousUpdates()
+bool CMsgReader::readEndOfContinuousUpdates()
 {
   handler->endOfContinuousUpdates();
+  return true;
 }
 
-void CMsgReader::readFramebufferUpdate()
+bool CMsgReader::readFramebufferUpdate()
 {
+  if (!is->hasData(3))
+    return false;
+
   is->skip(1);
   nUpdateRectsLeft = is->readU16();
   handler->framebufferUpdateStart();
+
+  return true;
 }
 
-void CMsgReader::readRect(const Rect& r, int encoding)
+bool CMsgReader::readRect(const Rect& r, int encoding)
 {
   if ((r.br.x > handler->server.width()) ||
       (r.br.y > handler->server.height())) {
@@ -320,10 +444,10 @@ void CMsgReader::readRect(const Rect& r, int encoding)
   if (r.is_empty())
     vlog.error("zero size rect");
 
-  handler->dataRect(r, encoding);
+  return handler->dataRect(r, encoding);
 }
 
-void CMsgReader::readSetXCursor(int width, int height, const Point& hotspot)
+bool CMsgReader::readSetXCursor(int width, int height, const Point& hotspot)
 {
   if (width > maxCursorSize || height > maxCursorSize)
     throw Exception("Too big cursor");
@@ -341,6 +465,9 @@ void CMsgReader::readSetXCursor(int width, int height, const Point& hotspot)
     int x, y;
     rdr::U8* out;
 
+    if (!is->hasData(3 + 3 + data_len + mask_len))
+      return false;
+
     pr = is->readU8();
     pg = is->readU8();
     pb = is->readU8();
@@ -380,9 +507,11 @@ void CMsgReader::readSetXCursor(int width, int height, const Point& hotspot)
   }
 
   handler->setCursor(width, height, hotspot, rgba.buf);
+
+  return true;
 }
 
-void CMsgReader::readSetCursor(int width, int height, const Point& hotspot)
+bool CMsgReader::readSetCursor(int width, int height, const Point& hotspot)
 {
   if (width > maxCursorSize || height > maxCursorSize)
     throw Exception("Too big cursor");
@@ -397,6 +526,9 @@ void CMsgReader::readSetCursor(int width, int height, const Point& hotspot)
   rdr::U8* in;
   rdr::U8* out;
 
+  if (!is->hasData(data_len + mask_len))
+    return false;
+
   is->readBytes(data.buf, data_len);
   is->readBytes(mask.buf, mask_len);
 
@@ -421,29 +553,44 @@ void CMsgReader::readSetCursor(int width, int height, const Point& hotspot)
   }
 
   handler->setCursor(width, height, hotspot, rgba.buf);
+
+  return true;
 }
 
-void CMsgReader::readSetCursorWithAlpha(int width, int height, const Point& hotspot)
+bool CMsgReader::readSetCursorWithAlpha(int width, int height, const Point& hotspot)
 {
   if (width > maxCursorSize || height > maxCursorSize)
     throw Exception("Too big cursor");
 
-  int encoding;
-
   const PixelFormat rgbaPF(32, 32, false, true, 255, 255, 255, 16, 8, 0);
   ManagedPixelBuffer pb(rgbaPF, width, height);
   PixelFormat origPF;
 
+  bool ret;
+
   rdr::U8* buf;
   int stride;
 
-  encoding = is->readS32();
+  // We can't use restore points as the decoder likely wants to as well, so
+  // we need to keep track of the read encoding
+
+  if (cursorEncoding == -1) {
+    if (!is->hasData(4))
+      return false;
+
+    cursorEncoding = is->readS32();
+  }
 
   origPF = handler->server.pf();
   handler->server.setPF(rgbaPF);
-  handler->readAndDecodeRect(pb.getRect(), encoding, &pb);
+  ret = handler->readAndDecodeRect(pb.getRect(), cursorEncoding, &pb);
   handler->server.setPF(origPF);
 
+  if (!ret)
+    return false;
+
+  cursorEncoding = -1;
+
   // On-wire data has pre-multiplied alpha, but we store it
   // non-pre-multiplied
   buf = pb.getBufferRW(pb.getRect(), &stride);
@@ -467,18 +614,25 @@ void CMsgReader::readSetCursorWithAlpha(int width, int height, const Point& hots
 
   handler->setCursor(width, height, hotspot,
                      pb.getBuffer(pb.getRect(), &stride));
+
+  return true;
 }
 
-void CMsgReader::readSetVMwareCursor(int width, int height, const Point& hotspot)
+bool CMsgReader::readSetVMwareCursor(int width, int height, const Point& hotspot)
 {
   if (width > maxCursorSize || height > maxCursorSize)
     throw Exception("Too big cursor");
 
   rdr::U8 type;
 
+  if (!is->hasData(2))
+    return false;
+
   type = is->readU8();
   is->skip(1);
 
+  is->setRestorePoint();
+
   if (type == 0) {
     int len = width * height * (handler->server.pf().bpp/8);
     rdr::U8Array andMask(len);
@@ -491,6 +645,10 @@ void CMsgReader::readSetVMwareCursor(int width, int height, const Point& hotspot
     rdr::U8* out;
     int Bpp;
 
+    if (!is->hasDataOrRestore(len + len))
+      return false;
+    is->clearRestorePoint();
+
     is->readBytes(andMask.buf, len);
     is->readBytes(xorMask.buf, len);
 
@@ -548,6 +706,10 @@ void CMsgReader::readSetVMwareCursor(int width, int height, const Point& hotspot
   } else if (type == 1) {
     rdr::U8Array data(width*height*4);
 
+    if (!is->hasDataOrRestore(width*height*4))
+      return false;
+    is->clearRestorePoint();
+
     // FIXME: Is alpha premultiplied?
     is->readBytes(data.buf, width*height*4);
 
@@ -555,11 +717,25 @@ void CMsgReader::readSetVMwareCursor(int width, int height, const Point& hotspot
   } else {
     throw Exception("Unknown cursor type");
   }
+
+  return true;
 }
 
-void CMsgReader::readSetDesktopName(int x, int y, int w, int h)
+bool CMsgReader::readSetDesktopName(int x, int y, int w, int h)
 {
-  rdr::U32 len = is->readU32();
+  rdr::U32 len;
+
+  if (!is->hasData(4))
+    return false;
+
+  is->setRestorePoint();
+
+  len = is->readU32();
+
+  if (!is->hasDataOrRestore(len))
+    return false;
+  is->clearRestorePoint();
+
   CharArray name(len + 1);
   is->readBytes(name.buf, len);
   name.buf[len] = '\0';
@@ -569,18 +745,29 @@ void CMsgReader::readSetDesktopName(int x, int y, int w, int h)
   } else {
     handler->setName(name.buf);
   }
+
+  return true;
 }
 
-void CMsgReader::readExtendedDesktopSize(int x, int y, int w, int h)
+bool CMsgReader::readExtendedDesktopSize(int x, int y, int w, int h)
 {
   unsigned int screens, i;
   rdr::U32 id, flags;
   int sx, sy, sw, sh;
   ScreenSet layout;
 
+  if (!is->hasData(4))
+    return false;
+
+  is->setRestorePoint();
+
   screens = is->readU8();
   is->skip(3);
 
+  if (!is->hasDataOrRestore(16 * screens))
+    return false;
+  is->clearRestorePoint();
+
   for (i = 0;i < screens;i++) {
     id = is->readU32();
     sx = is->readU16();
@@ -593,25 +780,37 @@ void CMsgReader::readExtendedDesktopSize(int x, int y, int w, int h)
   }
 
   handler->setExtendedDesktopSize(x, y, w, h, layout);
+
+  return true;
 }
 
-void CMsgReader::readLEDState()
+bool CMsgReader::readLEDState()
 {
   rdr::U8 state;
 
+  if (!is->hasData(1))
+    return false;
+
   state = is->readU8();
 
   handler->setLEDState(state);
+
+  return true;
 }
 
-void CMsgReader::readVMwareLEDState()
+bool CMsgReader::readVMwareLEDState()
 {
   rdr::U32 state;
 
+  if (!is->hasData(4))
+    return false;
+
   state = is->readU32();
 
   // As luck has it, this extension uses the same bit definitions,
   // so no conversion required
 
   handler->setLEDState(state);
+
+  return true;
 }
index 050990a989ab64ca4c2c2f47dbab2f83725b1930..ab55aed8b36933cc1cc2ed4f371007cb6ef3092c 100644 (file)
@@ -40,39 +40,55 @@ namespace rfb {
     CMsgReader(CMsgHandler* handler, rdr::InStream* is);
     virtual ~CMsgReader();
 
-    void readServerInit();
+    bool readServerInit();
 
     // readMsg() reads a message, calling the handler as appropriate.
-    void readMsg();
+    bool readMsg();
 
     rdr::InStream* getInStream() { return is; }
 
     int imageBufIdealSize;
 
   protected:
-    void readSetColourMapEntries();
-    void readBell();
-    void readServerCutText();
-    void readExtendedClipboard(rdr::S32 len);
-    void readFence();
-    void readEndOfContinuousUpdates();
-
-    void readFramebufferUpdate();
-
-    void readRect(const Rect& r, int encoding);
-
-    void readSetXCursor(int width, int height, const Point& hotspot);
-    void readSetCursor(int width, int height, const Point& hotspot);
-    void readSetCursorWithAlpha(int width, int height, const Point& hotspot);
-    void readSetVMwareCursor(int width, int height, const Point& hotspot);
-    void readSetDesktopName(int x, int y, int w, int h);
-    void readExtendedDesktopSize(int x, int y, int w, int h);
-    void readLEDState();
-    void readVMwareLEDState();
-
+    bool readSetColourMapEntries();
+    bool readBell();
+    bool readServerCutText();
+    bool readExtendedClipboard(rdr::S32 len);
+    bool readFence();
+    bool readEndOfContinuousUpdates();
+
+    bool readFramebufferUpdate();
+
+    bool readRect(const Rect& r, int encoding);
+
+    bool readSetXCursor(int width, int height, const Point& hotspot);
+    bool readSetCursor(int width, int height, const Point& hotspot);
+    bool readSetCursorWithAlpha(int width, int height, const Point& hotspot);
+    bool readSetVMwareCursor(int width, int height, const Point& hotspot);
+    bool readSetDesktopName(int x, int y, int w, int h);
+    bool readExtendedDesktopSize(int x, int y, int w, int h);
+    bool readLEDState();
+    bool readVMwareLEDState();
+
+  private:
     CMsgHandler* handler;
     rdr::InStream* is;
+
+    enum stateEnum {
+      MSGSTATE_IDLE,
+      MSGSTATE_MESSAGE,
+      MSGSTATE_RECT_HEADER,
+      MSGSTATE_RECT_DATA,
+    };
+
+    stateEnum state;
+
+    rdr::U8 currentMsgType;
     int nUpdateRectsLeft;
+    Rect dataRect;
+    int rectEncoding;
+
+    int cursorEncoding;
 
     static const int maxCursorSize = 256;
   };
index 374ec7f3da31009a8044f03668d75d934bd989b2..4fcaa7a9f5e5e4ec0b38777a9848b50c986d6621 100644 (file)
@@ -154,7 +154,7 @@ bool CSecurityTLS::processMsg()
   client = cc;
 
   if (!session) {
-    if (!is->checkNoWait(1))
+    if (!is->hasData(1))
       return false;
 
     if (is->readU8() == 0)
@@ -180,8 +180,10 @@ bool CSecurityTLS::processMsg()
   int err;
   err = gnutls_handshake(session);
   if (err != GNUTLS_E_SUCCESS) {
-    if (!gnutls_error_is_fatal(err))
+    if (!gnutls_error_is_fatal(err)) {
+      vlog.debug("Deferring completion of TLS handshake: %s", gnutls_strerror(err));
       return false;
+    }
 
     vlog.error("TLS Handshake failed: %s\n", gnutls_strerror (err));
     shutdown(false);
index 22201dd26a7240ab671fc2ec9ff5e16f83ee1b7c..98dad494b43dcbc64a764d0fbe0c43e503c98a48 100644 (file)
@@ -51,7 +51,6 @@ CSecurityVeNCrypt::CSecurityVeNCrypt(CConnection* cc, SecurityClient* sec)
   chosenType = secTypeVeNCrypt;
   nAvailableTypes = 0;
   availableTypes = NULL;
-  iAvailableType = 0;
 }
 
 CSecurityVeNCrypt::~CSecurityVeNCrypt()
@@ -64,16 +63,20 @@ bool CSecurityVeNCrypt::processMsg()
 {
   InStream* is = cc->getInStream();
   OutStream* os = cc->getOutStream();
-       
+
   /* get major, minor versions, send what we can support (or 0.0 for can't support it) */
   if (!haveRecvdMajorVersion) {
+    if (!is->hasData(1))
+      return false;
+
     majorVersion = is->readU8();
     haveRecvdMajorVersion = true;
-
-    return false;
   }
 
   if (!haveRecvdMinorVersion) {
+    if (!is->hasData(1))
+      return false;
+
     minorVersion = is->readU8();
     haveRecvdMinorVersion = true;
   }
@@ -100,47 +103,48 @@ bool CSecurityVeNCrypt::processMsg()
      }
 
      haveSentVersion = true;
-     return false;
   }
 
   /* Check that the server is OK */
   if (!haveAgreedVersion) {
+    if (!is->hasData(1))
+      return false;
+
     if (is->readU8())
       throw AuthFailureException("The server reported it could not support the "
                                 "VeNCrypt version");
 
     haveAgreedVersion = true;
-    return false;
   }
   
   /* get a number of types */
   if (!haveNumberOfTypes) {
+    if (!is->hasData(1))
+      return false;
+
     nAvailableTypes = is->readU8();
-    iAvailableType = 0;
 
     if (!nAvailableTypes)
       throw AuthFailureException("The server reported no VeNCrypt sub-types");
 
     availableTypes = new rdr::U32[nAvailableTypes];
     haveNumberOfTypes = true;
-    return false;
   }
 
   if (nAvailableTypes) {
     /* read in the types possible */
     if (!haveListOfTypes) {
-      if (is->checkNoWait(4)) {
-       availableTypes[iAvailableType++] = is->readU32();
-       haveListOfTypes = (iAvailableType >= nAvailableTypes);
-       vlog.debug("Server offers security type %s (%d)",
-                  secTypeName(availableTypes[iAvailableType - 1]),
-                  availableTypes[iAvailableType - 1]);
-
-       if (!haveListOfTypes)
-         return false;
-
-      } else
-       return false;
+      if (!is->hasData(4 * nAvailableTypes))
+        return false;
+
+      for (int i = 0;i < nAvailableTypes;i++) {
+        availableTypes[i] = is->readU32();
+        vlog.debug("Server offers security type %s (%d)",
+                   secTypeName(availableTypes[i]),
+                   availableTypes[i]);
+      }
+
+      haveListOfTypes = true;
     }
 
     /* make a choice and send it to the server, meanwhile set up the stack */
index d015e8f2cdb952735078dd761257db920140b4d6..1e2a7e686ae689fec09186a5d16821eb0afa11ec 100644 (file)
@@ -55,7 +55,6 @@ namespace rfb {
     rdr::U32 chosenType;
     rdr::U8 nAvailableTypes;
     rdr::U32 *availableTypes;
-    rdr::U8 iAvailableType;
   };
 }
 #endif
index 6a87498c64bb368a38a666d2a48fd6ccb157cb32..78a3a061519056f63eb036377d61260972cd5939 100644 (file)
@@ -45,6 +45,9 @@ bool CSecurityVncAuth::processMsg()
   rdr::InStream* is = cc->getInStream();
   rdr::OutStream* os = cc->getOutStream();
 
+  if (!is->hasData(vncAuthChallengeSize))
+    return false;
+
   // Read the challenge & obtain the user's password
   rdr::U8 challenge[vncAuthChallengeSize];
   is->readBytes(challenge, vncAuthChallengeSize);
index ecf503238e0aae536a7ac6088451c84b4b72f23d..bca8da481d3cd01b5104ba0539f3923b973a59d9 100644 (file)
@@ -31,10 +31,13 @@ CopyRectDecoder::~CopyRectDecoder()
 {
 }
 
-void CopyRectDecoder::readRect(const Rect& r, rdr::InStream* is,
+bool CopyRectDecoder::readRect(const Rect& r, rdr::InStream* is,
                                const ServerParams& server, rdr::OutStream* os)
 {
+  if (!is->hasData(4))
+    return false;
   os->copyBytes(is, 4);
+  return true;
 }
 
 
index 546266e10e19e4aa2c2a4ceb01c73afc63be6d28..5100eb2f9964affa97524fec39d9e2fa3a81ecf8 100644 (file)
@@ -26,7 +26,7 @@ namespace rfb {
   public:
     CopyRectDecoder();
     virtual ~CopyRectDecoder();
-    virtual void readRect(const Rect& r, rdr::InStream* is,
+    virtual bool readRect(const Rect& r, rdr::InStream* is,
                           const ServerParams& server, rdr::OutStream* os);
     virtual void getAffectedRegion(const Rect& rect, const void* buffer,
                                    size_t buflen, const ServerParams& server,
index 80c105103de143db32a5de28094b812e0b42347f..c003ab4003e79c49d468e6249fbdb9b9c6e7e5a7 100644 (file)
@@ -103,7 +103,7 @@ DecodeManager::~DecodeManager()
     delete decoders[i];
 }
 
-void DecodeManager::decodeRect(const Rect& r, int encoding,
+bool DecodeManager::decodeRect(const Rect& r, int encoding,
                                ModifiablePixelBuffer* pb)
 {
   Decoder *decoder;
@@ -133,19 +133,21 @@ void DecodeManager::decodeRect(const Rect& r, int encoding,
   if (threads.empty()) {
     bufferStream = freeBuffers.front();
     bufferStream->clear();
-    decoder->readRect(r, conn->getInStream(), conn->server, bufferStream);
+    if (!decoder->readRect(r, conn->getInStream(), conn->server, bufferStream))
+      return false;
     try {
       decoder->decodeRect(r, bufferStream->data(), bufferStream->length(),
                           conn->server, pb);
     } catch (rdr::Exception& e) {
       throw Exception("Error decoding rect: %s", e.str());
     }
-    return;
+    return true;
   }
 
   // Wait for an available memory buffer
   queueMutex->lock();
 
+  // FIXME: Should we return and let other things run here?
   while (freeBuffers.empty())
     producerCond->wait();
 
@@ -160,7 +162,8 @@ void DecodeManager::decodeRect(const Rect& r, int encoding,
 
   // Read the rect
   bufferStream->clear();
-  decoder->readRect(r, conn->getInStream(), conn->server, bufferStream);
+  if (!decoder->readRect(r, conn->getInStream(), conn->server, bufferStream))
+    return false;
 
   // Then try to put it on the queue
   entry = new QueueEntry;
@@ -190,6 +193,8 @@ void DecodeManager::decodeRect(const Rect& r, int encoding,
   consumerCond->signal();
 
   queueMutex->unlock();
+
+  return true;
 }
 
 void DecodeManager::flush()
index 058d824057bf496aa736ebeb46629b57a1349934..289686b587ffe78389dc4dc9ec71f7e3a6357aec 100644 (file)
@@ -47,7 +47,7 @@ namespace rfb {
     DecodeManager(CConnection *conn);
     ~DecodeManager();
 
-    void decodeRect(const Rect& r, int encoding,
+    bool decodeRect(const Rect& r, int encoding,
                     ModifiablePixelBuffer* pb);
 
     void flush();
index e074f3ec85f2b075577d0f0f931e892b320fc2a9..cb206a0d6e03d52ebdb650e6b6bfd90bf2b72bfc 100644 (file)
@@ -52,7 +52,7 @@ namespace rfb {
     // InStream to the OutStream, possibly changing it along the way to
     // make it easier to decode. This function will always be called in
     // a serial manner on the main thread.
-    virtual void readRect(const Rect& r, rdr::InStream* is,
+    virtual bool readRect(const Rect& r, rdr::InStream* is,
                           const ServerParams& server, rdr::OutStream* os)=0;
 
     // These functions will be called from any of the worker threads.
index 742dfb28752db87137332d10f8bff095641984eb..34392ea8966ef5ff434a2a562175e34aa0f0a610 100644 (file)
@@ -44,12 +44,14 @@ HextileDecoder::~HextileDecoder()
 {
 }
 
-void HextileDecoder::readRect(const Rect& r, rdr::InStream* is,
+bool HextileDecoder::readRect(const Rect& r, rdr::InStream* is,
                               const ServerParams& server, rdr::OutStream* os)
 {
   Rect t;
   size_t bytesPerPixel;
 
+  is->setRestorePoint();
+
   bytesPerPixel = server.pf().bpp/8;
 
   for (t.tl.y = r.tl.y; t.tl.y < r.br.y; t.tl.y += 16) {
@@ -61,33 +63,57 @@ void HextileDecoder::readRect(const Rect& r, rdr::InStream* is,
 
       t.br.x = __rfbmin(r.br.x, t.tl.x + 16);
 
+      if (!is->hasDataOrRestore(1))
+        return false;
+
       tileType = is->readU8();
       os->writeU8(tileType);
 
       if (tileType & hextileRaw) {
+        if (!is->hasDataOrRestore(t.area() * bytesPerPixel))
+          return false;
         os->copyBytes(is, t.area() * bytesPerPixel);
         continue;
       }
 
-      if (tileType & hextileBgSpecified)
+
+      if (tileType & hextileBgSpecified) {
+        if (!is->hasDataOrRestore(bytesPerPixel))
+          return false;
         os->copyBytes(is, bytesPerPixel);
+      }
 
-      if (tileType & hextileFgSpecified)
+      if (tileType & hextileFgSpecified) {
+        if (!is->hasDataOrRestore(bytesPerPixel))
+          return false;
         os->copyBytes(is, bytesPerPixel);
+      }
 
       if (tileType & hextileAnySubrects) {
         rdr::U8 nSubrects;
 
+        if (!is->hasDataOrRestore(1))
+          return false;
+
         nSubrects = is->readU8();
         os->writeU8(nSubrects);
 
-        if (tileType & hextileSubrectsColoured)
+        if (tileType & hextileSubrectsColoured) {
+          if (!is->hasDataOrRestore(nSubrects * (bytesPerPixel + 2)))
+            return false;
           os->copyBytes(is, nSubrects * (bytesPerPixel + 2));
-        else
+        } else {
+          if (!is->hasDataOrRestore(nSubrects * 2))
+            return false;
           os->copyBytes(is, nSubrects * 2);
+        }
       }
     }
   }
+
+  is->clearRestorePoint();
+
+  return true;
 }
 
 void HextileDecoder::decodeRect(const Rect& r, const void* buffer,
index b8515bfc9530fc3faffa39b1c21f100b7ba60ba3..2c42be54b19f65af8c9b6648a37e4fc07c862046 100644 (file)
@@ -26,7 +26,7 @@ namespace rfb {
   public:
     HextileDecoder();
     virtual ~HextileDecoder();
-    virtual void readRect(const Rect& r, rdr::InStream* is,
+    virtual bool readRect(const Rect& r, rdr::InStream* is,
                           const ServerParams& server, rdr::OutStream* os);
     virtual void decodeRect(const Rect& r, const void* buffer,
                             size_t buflen, const ServerParams& server,
index 70a7ddb2c845bbd4f233606b0515b983139644e8..af821cb98f8c3321f8b99d7ebd92edba6dff39bc 100644 (file)
@@ -44,15 +44,30 @@ RREDecoder::~RREDecoder()
 {
 }
 
-void RREDecoder::readRect(const Rect& r, rdr::InStream* is,
+bool RREDecoder::readRect(const Rect& r, rdr::InStream* is,
                           const ServerParams& server, rdr::OutStream* os)
 {
   rdr::U32 numRects;
+  size_t len;
+
+  if (!is->hasData(4))
+    return false;
+
+  is->setRestorePoint();
 
   numRects = is->readU32();
   os->writeU32(numRects);
 
-  os->copyBytes(is, server.pf().bpp/8 + numRects * (server.pf().bpp/8 + 8));
+  len = server.pf().bpp/8 + numRects * (server.pf().bpp/8 + 8);
+
+  if (!is->hasDataOrRestore(len))
+    return false;
+
+  is->clearRestorePoint();
+
+  os->copyBytes(is, len);
+
+  return true;
 }
 
 void RREDecoder::decodeRect(const Rect& r, const void* buffer,
index f47eddade5b9937380584dc87b83d9120b71a531..b8ec18f6e30d77deea45e8c47265c94a797d480f 100644 (file)
@@ -26,7 +26,7 @@ namespace rfb {
   public:
     RREDecoder();
     virtual ~RREDecoder();
-    virtual void readRect(const Rect& r, rdr::InStream* is,
+    virtual bool readRect(const Rect& r, rdr::InStream* is,
                           const ServerParams& server, rdr::OutStream* os);
     virtual void decodeRect(const Rect& r, const void* buffer,
                             size_t buflen, const ServerParams& server,
index 61235047bf1f621fc1d13eec42bc344500ecb1ec..a7648f973b07677d2adf97113eef752555dbbf5f 100644 (file)
@@ -33,10 +33,13 @@ RawDecoder::~RawDecoder()
 {
 }
 
-void RawDecoder::readRect(const Rect& r, rdr::InStream* is,
+bool RawDecoder::readRect(const Rect& r, rdr::InStream* is,
                           const ServerParams& server, rdr::OutStream* os)
 {
+  if (!is->hasData(r.area() * (server.pf().bpp/8)))
+    return false;
   os->copyBytes(is, r.area() * (server.pf().bpp/8));
+  return true;
 }
 
 void RawDecoder::decodeRect(const Rect& r, const void* buffer,
index 4ab8071770dba27a19e70a650fb8e42248b32267..2661ea57d4965cd922fce753b13e6b8d60f52ba1 100644 (file)
@@ -25,7 +25,7 @@ namespace rfb {
   public:
     RawDecoder();
     virtual ~RawDecoder();
-    virtual void readRect(const Rect& r, rdr::InStream* is,
+    virtual bool readRect(const Rect& r, rdr::InStream* is,
                           const ServerParams& server, rdr::OutStream* os);
     virtual void decodeRect(const Rect& r, const void* buffer,
                             size_t buflen, const ServerParams& server,
index e06fc6bba23eb07378f0034ae8263c58bb18c473..1c9ca3e705cbb992622cbe2cfb3a7052716cb719 100644 (file)
@@ -86,18 +86,20 @@ void SConnection::initialiseProtocol()
   state_ = RFBSTATE_PROTOCOL_VERSION;
 }
 
-void SConnection::processMsg()
+bool SConnection::processMsg()
 {
   switch (state_) {
-  case RFBSTATE_PROTOCOL_VERSION: processVersionMsg();      break;
-  case RFBSTATE_SECURITY_TYPE:    processSecurityTypeMsg(); break;
-  case RFBSTATE_SECURITY:         processSecurityMsg();     break;
-  case RFBSTATE_SECURITY_FAILURE: processSecurityFailure(); break;
-  case RFBSTATE_INITIALISATION:   processInitMsg();         break;
-  case RFBSTATE_NORMAL:           reader_->readMsg();       break;
+  case RFBSTATE_PROTOCOL_VERSION: return processVersionMsg();      break;
+  case RFBSTATE_SECURITY_TYPE:    return processSecurityTypeMsg(); break;
+  case RFBSTATE_SECURITY:         return processSecurityMsg();     break;
+  case RFBSTATE_SECURITY_FAILURE: return processSecurityFailure(); break;
+  case RFBSTATE_INITIALISATION:   return processInitMsg();         break;
+  case RFBSTATE_NORMAL:           return reader_->readMsg();       break;
   case RFBSTATE_QUERYING:
     throw Exception("SConnection::processMsg: bogus data from client while "
                     "querying");
+  case RFBSTATE_CLOSING:
+    throw Exception("SConnection::processMsg: called while closing");
   case RFBSTATE_UNINITIALISED:
     throw Exception("SConnection::processMsg: not initialised yet?");
   default:
@@ -105,7 +107,7 @@ void SConnection::processMsg()
   }
 }
 
-void SConnection::processVersionMsg()
+bool SConnection::processVersionMsg()
 {
   char verStr[13];
   int majorVersion;
@@ -113,8 +115,8 @@ void SConnection::processVersionMsg()
 
   vlog.debug("reading protocol version");
 
-  if (!is->checkNoWait(12))
-    return;
+  if (!is->hasData(12))
+    return false;
 
   is->readBytes(verStr, 12);
   verStr[12] = '\0';
@@ -172,8 +174,7 @@ void SConnection::processVersionMsg()
     if (*i == secTypeNone) os->flush();
     state_ = RFBSTATE_SECURITY;
     ssecurity = security.GetSSecurity(this, *i);
-    processSecurityMsg();
-    return;
+    return true;
   }
 
   // list supported security types for >=3.7 clients
@@ -186,15 +187,23 @@ void SConnection::processVersionMsg()
     os->writeU8(*i);
   os->flush();
   state_ = RFBSTATE_SECURITY_TYPE;
+
+  return true;
 }
 
 
-void SConnection::processSecurityTypeMsg()
+bool SConnection::processSecurityTypeMsg()
 {
   vlog.debug("processing security type message");
+
+  if (!is->hasData(1))
+    return false;
+
   int secType = is->readU8();
 
   processSecurityType(secType);
+
+  return true;
 }
 
 void SConnection::processSecurityType(int secType)
@@ -218,16 +227,14 @@ void SConnection::processSecurityType(int secType)
   } catch (rdr::Exception& e) {
     throwConnFailedException("%s", e.str());
   }
-
-  processSecurityMsg();
 }
 
-void SConnection::processSecurityMsg()
+bool SConnection::processSecurityMsg()
 {
   vlog.debug("processing security message");
   try {
     if (!ssecurity->processMsg())
-      return;
+      return false;
   } catch (AuthFailureException& e) {
     vlog.error("AuthFailureException: %s", e.str());
     state_ = RFBSTATE_SECURITY_FAILURE;
@@ -235,28 +242,41 @@ void SConnection::processSecurityMsg()
     // to make it difficult to brute force a password
     authFailureMsg.replaceBuf(strDup(e.str()));
     authFailureTimer.start(100);
+    return true;
   }
 
   state_ = RFBSTATE_QUERYING;
   setAccessRights(ssecurity->getAccessRights());
   queryConnection(ssecurity->getUserName());
+
+  // If the connection got approved right away then we can continue
+  if (state_ == RFBSTATE_INITIALISATION)
+    return true;
+
+  // Otherwise we need to wait for the result
+  // (or give up if if was rejected)
+  return false;
 }
 
-void SConnection::processSecurityFailure()
+bool SConnection::processSecurityFailure()
 {
   // Silently drop any data if we are currently delaying an
   // authentication failure response as otherwise we would close
   // the connection on unexpected data, and an attacker could use
   // that to detect our delayed state.
 
-  while (is->checkNoWait(1))
-    is->skip(1);
+  if (!is->hasData(1))
+    return false;
+
+  is->skip(is->avail());
+
+  return true;
 }
 
-void SConnection::processInitMsg()
+bool SConnection::processInitMsg()
 {
   vlog.debug("reading client initialisation");
-  reader_->readClientInit();
+  return reader_->readClientInit();
 }
 
 bool SConnection::handleAuthFailureTimeout(Timer* t)
index e7bbf2c33d718535f5218e488dcc91559e8f6aad..b333086a188b35dc40c5159769684f1e887b219e 100644 (file)
@@ -60,7 +60,7 @@ namespace rfb {
 
     // processMsg() should be called whenever there is data to read on the
     // InStream.  You must have called initialiseProtocol() first.
-    void processMsg();
+    bool processMsg();
 
     // approveConnection() is called to either accept or reject the connection.
     // If accept is false, the reason string gives the reason for the
@@ -235,12 +235,12 @@ namespace rfb {
 
     bool readyForSetColourMapEntries;
 
-    void processVersionMsg();
-    void processSecurityTypeMsg();
+    bool processVersionMsg();
+    bool processSecurityTypeMsg();
     void processSecurityType(int secType);
-    void processSecurityMsg();
-    void processSecurityFailure();
-    void processInitMsg();
+    bool processSecurityMsg();
+    bool processSecurityFailure();
+    bool processInitMsg();
 
     bool handleAuthFailureTimeout(Timer* t);
 
index dc7ddea6b8d4558a595adab7d508e86d1b7b3842..944f9315a00d84938f49a5d5e2af25f4038b89d7 100644 (file)
@@ -38,7 +38,7 @@ static LogWriter vlog("SMsgReader");
 static IntParameter maxCutText("MaxCutText", "Maximum permitted length of an incoming clipboard update", 256*1024);
 
 SMsgReader::SMsgReader(SMsgHandler* handler_, rdr::InStream* is_)
-  : handler(handler_), is(is_)
+  : handler(handler_), is(is_), state(MSGSTATE_IDLE)
 {
 }
 
@@ -46,71 +46,105 @@ SMsgReader::~SMsgReader()
 {
 }
 
-void SMsgReader::readClientInit()
+bool SMsgReader::readClientInit()
 {
+  if (!is->hasData(1))
+    return false;
   bool shared = is->readU8();
   handler->clientInit(shared);
+  return true;
 }
 
-void SMsgReader::readMsg()
+bool SMsgReader::readMsg()
 {
-  int msgType = is->readU8();
-  switch (msgType) {
+  bool ret;
+
+  if (state == MSGSTATE_IDLE) {
+    if (!is->hasData(1))
+      return false;
+
+    currentMsgType = is->readU8();
+    state = MSGSTATE_MESSAGE;
+  }
+
+  switch (currentMsgType) {
   case msgTypeSetPixelFormat:
-    readSetPixelFormat();
+    ret = readSetPixelFormat();
     break;
   case msgTypeSetEncodings:
-    readSetEncodings();
+    ret = readSetEncodings();
     break;
   case msgTypeSetDesktopSize:
-    readSetDesktopSize();
+    ret = readSetDesktopSize();
     break;
   case msgTypeFramebufferUpdateRequest:
-    readFramebufferUpdateRequest();
+    ret = readFramebufferUpdateRequest();
     break;
   case msgTypeEnableContinuousUpdates:
-    readEnableContinuousUpdates();
+    ret = readEnableContinuousUpdates();
     break;
   case msgTypeClientFence:
-    readFence();
+    ret = readFence();
     break;
   case msgTypeKeyEvent:
-    readKeyEvent();
+    ret = readKeyEvent();
     break;
   case msgTypePointerEvent:
-    readPointerEvent();
+    ret = readPointerEvent();
     break;
   case msgTypeClientCutText:
-    readClientCutText();
+    ret = readClientCutText();
     break;
   case msgTypeQEMUClientMessage:
-    readQEMUMessage();
+    ret = readQEMUMessage();
     break;
   default:
-    vlog.error("unknown message type %d", msgType);
+    vlog.error("unknown message type %d", currentMsgType);
     throw Exception("unknown message type");
   }
+
+  if (ret)
+    state = MSGSTATE_IDLE;
+
+  return ret;
 }
 
-void SMsgReader::readSetPixelFormat()
+bool SMsgReader::readSetPixelFormat()
 {
+  if (!is->hasData(3 + 16))
+    return false;
   is->skip(3);
   PixelFormat pf;
   pf.read(is);
   handler->setPixelFormat(pf);
+  return true;
 }
 
-void SMsgReader::readSetEncodings()
+bool SMsgReader::readSetEncodings()
 {
+  if (!is->hasData(3))
+    return false;
+
+  is->setRestorePoint();
+
   is->skip(1);
+
   int nEncodings = is->readU16();
+
+  if (!is->hasDataOrRestore(nEncodings * 4))
+    return false;
+  is->clearRestorePoint();
+
   rdr::S32Array encodings(nEncodings);
   for (int i = 0; i < nEncodings; i++)
     encodings.buf[i] = is->readU32();
+
   handler->setEncodings(nEncodings, encodings.buf);
+
+  return true;
 }
 
-void SMsgReader::readSetDesktopSize()
+bool SMsgReader::readSetDesktopSize()
 {
   int width, height;
   int screens, i;
@@ -118,6 +152,11 @@ void SMsgReader::readSetDesktopSize()
   int sx, sy, sw, sh;
   ScreenSet layout;
 
+  if (!is->hasData(7))
+    return true;
+
+  is->setRestorePoint();
+
   is->skip(1);
 
   width = is->readU16();
@@ -126,6 +165,10 @@ void SMsgReader::readSetDesktopSize()
   screens = is->readU8();
   is->skip(1);
 
+  if (!is->hasDataOrRestore(screens * 24))
+    return false;
+  is->clearRestorePoint();
+
   for (i = 0;i < screens;i++) {
     id = is->readU32();
     sx = is->readU16();
@@ -138,23 +181,31 @@ void SMsgReader::readSetDesktopSize()
   }
 
   handler->setDesktopSize(width, height, layout);
+
+  return true;
 }
 
-void SMsgReader::readFramebufferUpdateRequest()
+bool SMsgReader::readFramebufferUpdateRequest()
 {
+  if (!is->hasData(17))
+    return false;
   bool inc = is->readU8();
   int x = is->readU16();
   int y = is->readU16();
   int w = is->readU16();
   int h = is->readU16();
   handler->framebufferUpdateRequest(Rect(x, y, x+w, y+h), inc);
+  return true;
 }
 
-void SMsgReader::readEnableContinuousUpdates()
+bool SMsgReader::readEnableContinuousUpdates()
 {
   bool enable;
   int x, y, w, h;
 
+  if (!is->hasData(17))
+    return false;
+
   enable = is->readU8();
 
   x = is->readU16();
@@ -163,81 +214,121 @@ void SMsgReader::readEnableContinuousUpdates()
   h = is->readU16();
 
   handler->enableContinuousUpdates(enable, x, y, w, h);
+
+  return true;
 }
 
-void SMsgReader::readFence()
+bool SMsgReader::readFence()
 {
   rdr::U32 flags;
   rdr::U8 len;
   char data[64];
 
+  if (!is->hasData(8))
+    return false;
+
+  is->setRestorePoint();
+
   is->skip(3);
 
   flags = is->readU32();
 
   len = is->readU8();
+
+  if (!is->hasDataOrRestore(len))
+    return false;
+  is->clearRestorePoint();
+
   if (len > sizeof(data)) {
     vlog.error("Ignoring fence with too large payload");
     is->skip(len);
-    return;
+    return true;
   }
 
   is->readBytes(data, len);
   
   handler->fence(flags, len, data);
+
+  return true;
 }
 
-void SMsgReader::readKeyEvent()
+bool SMsgReader::readKeyEvent()
 {
+  if (!is->hasData(7))
+    return false;
   bool down = is->readU8();
   is->skip(2);
   rdr::U32 key = is->readU32();
   handler->keyEvent(key, 0, down);
+  return true;
 }
 
-void SMsgReader::readPointerEvent()
+bool SMsgReader::readPointerEvent()
 {
+  if (!is->hasData(5))
+    return false;
   int mask = is->readU8();
   int x = is->readU16();
   int y = is->readU16();
   handler->pointerEvent(Point(x, y), mask);
+  return true;
 }
 
 
-void SMsgReader::readClientCutText()
+bool SMsgReader::readClientCutText()
 {
+  if (!is->hasData(7))
+    return false;
+
+  is->setRestorePoint();
+
   is->skip(3);
   rdr::U32 len = is->readU32();
 
   if (len & 0x80000000) {
     rdr::S32 slen = len;
     slen = -slen;
-    readExtendedClipboard(slen);
-    return;
+    if (readExtendedClipboard(slen)) {
+      is->clearRestorePoint();
+      return true;
+    } else {
+      is->gotoRestorePoint();
+      return false;
+    }
   }
 
+  if (!is->hasDataOrRestore(len))
+    return false;
+  is->clearRestorePoint();
+
   if (len > (size_t)maxCutText) {
     is->skip(len);
     vlog.error("Cut text too long (%d bytes) - ignoring", len);
-    return;
+    return true;
   }
+
   CharArray ca(len);
   is->readBytes(ca.buf, len);
   CharArray filtered(convertLF(ca.buf, len));
   handler->clientCutText(filtered.buf);
+
+  return true;
 }
 
-void SMsgReader::readExtendedClipboard(rdr::S32 len)
+bool SMsgReader::readExtendedClipboard(rdr::S32 len)
 {
   rdr::U32 flags;
   rdr::U32 action;
 
+  if (!is->hasData(len))
+    return false;
+
   if (len < 4)
     throw Exception("Invalid extended clipboard message");
   if (len > maxCutText) {
     vlog.error("Extended clipboard message too long (%d bytes) - ignoring", len);
     is->skip(len);
-    return;
+    return true;
   }
 
   flags = is->readU32();
@@ -279,7 +370,14 @@ void SMsgReader::readExtendedClipboard(rdr::S32 len)
       if (!(flags & 1 << i))
         continue;
 
+      if (!zis.hasData(4))
+        throw Exception("Extended clipboard decode error");
+
       lengths[num] = zis.readU32();
+
+      if (!zis.hasData(lengths[num]))
+        throw Exception("Extended clipboard decode error");
+
       if (lengths[num] > (size_t)maxCutText) {
         vlog.error("Extended clipboard data too long (%d bytes) - ignoring",
                    (unsigned)lengths[num]);
@@ -319,28 +417,50 @@ void SMsgReader::readExtendedClipboard(rdr::S32 len)
       throw Exception("Invalid extended clipboard action");
     }
   }
+
+  return true;
 }
 
-void SMsgReader::readQEMUMessage()
+bool SMsgReader::readQEMUMessage()
 {
-  int subType = is->readU8();
+  int subType;
+  bool ret;
+
+  if (!is->hasData(1))
+    return false;
+
+  is->setRestorePoint();
+
+  subType = is->readU8();
+
   switch (subType) {
   case qemuExtendedKeyEvent:
-    readQEMUKeyEvent();
+    ret = readQEMUKeyEvent();
     break;
   default:
     throw Exception("unknown QEMU submessage type %d", subType);
   }
+
+  if (!ret) {
+    is->gotoRestorePoint();
+    return false;
+  } else {
+    is->clearRestorePoint();
+    return true;
+  }
 }
 
-void SMsgReader::readQEMUKeyEvent()
+bool SMsgReader::readQEMUKeyEvent()
 {
+  if (!is->hasData(10))
+    return false;
   bool down = is->readU16();
   rdr::U32 keysym = is->readU32();
   rdr::U32 keycode = is->readU32();
   if (!keycode) {
     vlog.error("Key event without keycode - ignoring");
-    return;
+    return true;
   }
   handler->keyEvent(keysym, keycode, down);
+  return true;
 }
index 4991fd38ca966f0d2075f94d6c25a6b2c733d649..acc872ed7cfe45c638286655486ec773061289cf 100644 (file)
@@ -34,33 +34,43 @@ namespace rfb {
     SMsgReader(SMsgHandler* handler, rdr::InStream* is);
     virtual ~SMsgReader();
 
-    void readClientInit();
+    bool readClientInit();
 
     // readMsg() reads a message, calling the handler as appropriate.
-    void readMsg();
+    bool readMsg();
 
     rdr::InStream* getInStream() { return is; }
 
   protected:
-    void readSetPixelFormat();
-    void readSetEncodings();
-    void readSetDesktopSize();
+    bool readSetPixelFormat();
+    bool readSetEncodings();
+    bool readSetDesktopSize();
 
-    void readFramebufferUpdateRequest();
-    void readEnableContinuousUpdates();
+    bool readFramebufferUpdateRequest();
+    bool readEnableContinuousUpdates();
 
-    void readFence();
+    bool readFence();
 
-    void readKeyEvent();
-    void readPointerEvent();
-    void readClientCutText();
-    void readExtendedClipboard(rdr::S32 len);
+    bool readKeyEvent();
+    bool readPointerEvent();
+    bool readClientCutText();
+    bool readExtendedClipboard(rdr::S32 len);
 
-    void readQEMUMessage();
-    void readQEMUKeyEvent();
+    bool readQEMUMessage();
+    bool readQEMUKeyEvent();
 
+  private:
     SMsgHandler* handler;
     rdr::InStream* is;
+
+    enum stateEnum {
+      MSGSTATE_IDLE,
+      MSGSTATE_MESSAGE,
+    };
+
+    stateEnum state;
+
+    rdr::U8 currentMsgType;
   };
 }
 #endif
index f577c0d6fa1fc2657ccd16b9e06d23134e15dd10..6ae19557129e30bd77bed8d952af830083f34d3d 100644 (file)
@@ -84,7 +84,7 @@ bool SSecurityPlain::processMsg()
     throw AuthFailureException("No password validator configured");
 
   if (state == 0) {
-    if (!is->checkNoWait(8))
+    if (!is->hasData(8))
       return false;
 
     ulen = is->readU32();
@@ -99,7 +99,7 @@ bool SSecurityPlain::processMsg()
   }
 
   if (state == 1) {
-    if (!is->checkNoWait(ulen + plen))
+    if (!is->hasData(ulen + plen))
       return false;
     state = 2;
     pw = new char[plen + 1];
index d522ef6fc48ee94c63e82381e02d09ea53d4e776..135742c09feeb6422b255faa99b54085260521dd 100644 (file)
@@ -78,19 +78,21 @@ bool SSecurityVeNCrypt::processMsg()
     os->writeU8(2);
     haveSentVersion = true;
     os->flush();
-
-    return false;
   }
 
   /* Receive back highest version that client can support (up to and including ours) */
   if (!haveRecvdMajorVersion) {
+    if (!is->hasData(1))
+      return false;
+
     majorVersion = is->readU8();
     haveRecvdMajorVersion = true;
-
-    return false;
   }
 
   if (!haveRecvdMinorVersion) {
+    if (!is->hasData(1))
+      return false;
+
     minorVersion = is->readU8();
     haveRecvdMinorVersion = true;
 
@@ -140,14 +142,15 @@ bool SSecurityVeNCrypt::processMsg()
 
       os->flush(); 
       haveSentTypes = true;
-      return false;
     } else
       throw AuthFailureException("There are no VeNCrypt sub-types to send to the client");
   }
 
   /* get type back from client (must be one of the ones we sent) */
   if (!haveChosenType) {
-    is->check(4);
+    if (!is->hasData(4))
+      return false;
+
     chosenType = is->readU32();
 
     for (i = 0; i < numTypes; i++) {
index 882f0b086154f25c1ef169c6cf7162d0636033cc..c2a348b93ab6ada4634540d7a943f5537dc7f784 100644 (file)
@@ -49,7 +49,7 @@ VncAuthPasswdParameter SSecurityVncAuth::vncAuthPasswd
  "access the server", &SSecurityVncAuth::vncAuthPasswdFile);
 
 SSecurityVncAuth::SSecurityVncAuth(SConnection* sc)
-  : SSecurity(sc), sentChallenge(false), responsePos(0),
+  : SSecurity(sc), sentChallenge(false),
     pg(&vncAuthPasswd), accessRights(0)
 {
 }
@@ -78,6 +78,8 @@ bool SSecurityVncAuth::processMsg()
 
   if (!sentChallenge) {
     rdr::RandomStream rs;
+    if (!rs.hasData(vncAuthChallengeSize))
+      throw Exception("Could not generate random data for VNC auth challenge");
     rs.readBytes(challenge, vncAuthChallengeSize);
     os->writeBytes(challenge, vncAuthChallengeSize);
     os->flush();
@@ -85,10 +87,10 @@ bool SSecurityVncAuth::processMsg()
     return false;
   }
 
-  while (responsePos < vncAuthChallengeSize && is->checkNoWait(1))
-    response[responsePos++] = is->readU8();
+  if (!is->hasData(vncAuthChallengeSize))
+    return false;
 
-  if (responsePos < vncAuthChallengeSize) return false;
+  is->readBytes(response, vncAuthChallengeSize);
 
   PlainPasswd passwd, passwdReadOnly;
   pg->getVncAuthPasswd(&passwd, &passwdReadOnly);
index fe00b031ba67b0cde2c1906545e905d54b61c58d..94d5aaf2f25e620c44694ab174603524be582b44 100644 (file)
@@ -64,7 +64,6 @@ namespace rfb {
     rdr::U8 challenge[vncAuthChallengeSize];
     rdr::U8 response[vncAuthChallengeSize];
     bool sentChallenge;
-    int responsePos;
     VncAuthPasswdGetter* pg;
     SConnection::AccessRights accessRights;
   };
index b1097a3ee45c2223f7190313d4a7d8a348bda398..8f49848c91d28d7757d3ae752804beea882b2fdf 100644 (file)
@@ -42,11 +42,6 @@ rfb::IntParameter rfb::Server::maxIdleTime
 ("MaxIdleTime",
  "Terminate after s seconds of user inactivity", 
  0, 0);
-rfb::IntParameter rfb::Server::clientWaitTimeMillis
-("ClientWaitTimeMillis",
- "The number of milliseconds to wait for a client which is no longer "
- "responding",
- 20000, 0);
 rfb::IntParameter rfb::Server::compareFB
 ("CompareFB",
  "Perform pixel comparison on framebuffer to reduce unnecessary updates "
index f915c7a7a948c3f12e2839aaa2d6670e6952e70a..20a740a8589f3d420d7ddb16f53560d9106eb5b3 100644 (file)
@@ -36,7 +36,6 @@ namespace rfb {
     static IntParameter maxDisconnectionTime;
     static IntParameter maxConnectionTime;
     static IntParameter maxIdleTime;
-    static IntParameter clientWaitTimeMillis;
     static IntParameter compareFB;
     static IntParameter frameRate;
     static BoolParameter protocol3_3;
index ebc98b06ec5219ce04bd96400d3b311b47c176a2..fe03e453a65c65b3b1659d5e92e9959ea08c12b4 100644 (file)
@@ -54,11 +54,16 @@ TightDecoder::~TightDecoder()
 {
 }
 
-void TightDecoder::readRect(const Rect& r, rdr::InStream* is,
+bool TightDecoder::readRect(const Rect& r, rdr::InStream* is,
                             const ServerParams& server, rdr::OutStream* os)
 {
   rdr::U8 comp_ctl;
 
+  if (!is->hasData(1))
+    return false;
+
+  is->setRestorePoint();
+
   comp_ctl = is->readU8();
   os->writeU8(comp_ctl);
 
@@ -66,21 +71,38 @@ void TightDecoder::readRect(const Rect& r, rdr::InStream* is,
 
   // "Fill" compression type.
   if (comp_ctl == tightFill) {
-    if (server.pf().is888())
+    if (server.pf().is888()) {
+      if (!is->hasDataOrRestore(3))
+        return false;
       os->copyBytes(is, 3);
-    else
+    } else {
+      if (!is->hasDataOrRestore(server.pf().bpp/8))
+        return false;
       os->copyBytes(is, server.pf().bpp/8);
-    return;
+    }
+    is->clearRestorePoint();
+    return true;
   }
 
   // "JPEG" compression type.
   if (comp_ctl == tightJpeg) {
     rdr::U32 len;
 
+    // FIXME: Might be less than 3 bytes
+    if (!is->hasDataOrRestore(3))
+      return false;
+
     len = readCompact(is);
     os->writeOpaque32(len);
+
+    if (!is->hasDataOrRestore(len))
+      return false;
+
     os->copyBytes(is, len);
-    return;
+
+    is->clearRestorePoint();
+
+    return true;
   }
 
   // Quit on unsupported compression type.
@@ -98,18 +120,29 @@ void TightDecoder::readRect(const Rect& r, rdr::InStream* is,
   if ((comp_ctl & tightExplicitFilter) != 0) {
     rdr::U8 filterId;
 
+    if (!is->hasDataOrRestore(1))
+      return false;
+
     filterId = is->readU8();
     os->writeU8(filterId);
 
     switch (filterId) {
     case tightFilterPalette:
+      if (!is->hasDataOrRestore(1))
+        return false;
+
       palSize = is->readU8() + 1;
       os->writeU8(palSize - 1);
 
-      if (server.pf().is888())
+      if (server.pf().is888()) {
+        if (!is->hasDataOrRestore(palSize * 3))
+          return false;
         os->copyBytes(is, palSize * 3);
-      else
+      } else {
+        if (!is->hasDataOrRestore(palSize * server.pf().bpp/8))
+          return false;
         os->copyBytes(is, palSize * server.pf().bpp/8);
+      }
       break;
     case tightFilterGradient:
       if (server.pf().bpp == 8)
@@ -137,15 +170,29 @@ void TightDecoder::readRect(const Rect& r, rdr::InStream* is,
 
   dataSize = r.height() * rowSize;
 
-  if (dataSize < TIGHT_MIN_TO_COMPRESS)
+  if (dataSize < TIGHT_MIN_TO_COMPRESS) {
+    if (!is->hasDataOrRestore(dataSize))
+      return false;
     os->copyBytes(is, dataSize);
-  else {
+  else {
     rdr::U32 len;
 
+    // FIXME: Might be less than 3 bytes
+    if (!is->hasDataOrRestore(3))
+      return false;
+
     len = readCompact(is);
     os->writeOpaque32(len);
+
+    if (!is->hasDataOrRestore(len))
+      return false;
+
     os->copyBytes(is, len);
   }
+
+  is->clearRestorePoint();
+
+  return true;
 }
 
 bool TightDecoder::doRectsConflict(const Rect& rectA,
@@ -339,6 +386,8 @@ void TightDecoder::decodeRect(const Rect& r, const void* buffer,
     // Allocate buffer and decompress the data
     netbuf = new rdr::U8[dataSize];
 
+    if (!zis[streamId].hasData(dataSize))
+      throw Exception("Tight decode error");
     zis[streamId].readBytes(netbuf, dataSize);
 
     zis[streamId].flushUnderlying();
index 28b6c30fd48628975357cf9bf6fe4d58aa05df0a..763c82d6642e32dad152f5b302743026b3b4a1e6 100644 (file)
@@ -31,7 +31,7 @@ namespace rfb {
   public:
     TightDecoder();
     virtual ~TightDecoder();
-    virtual void readRect(const Rect& r, rdr::InStream* is,
+    virtual bool readRect(const Rect& r, rdr::InStream* is,
                           const ServerParams& server, rdr::OutStream* os);
     virtual bool doRectsConflict(const Rect& rectA,
                                  const void* bufferA,
index 00f640b364b83fa456ec87fd427ed6be8fb9b914..c4ec733b71576838518cf3d4d6ccbc8b15b29047 100644 (file)
@@ -57,9 +57,6 @@ VNCSConnectionST::VNCSConnectionST(VNCServerST* server_, network::Socket *s,
   setStreams(&sock->inStream(), &sock->outStream());
   peerEndpoint.buf = sock->getPeerEndpoint();
 
-  // Configure the socket
-  setSocketTimeouts();
-
   // Kick off the idle timer
   if (rfb::Server::idleTimeout) {
     // minimum of 15 seconds while authenticating
@@ -152,26 +149,23 @@ void VNCSConnectionST::processMessages()
 {
   if (state() == RFBSTATE_CLOSING) return;
   try {
-    // - Now set appropriate socket timeouts and process data
-    setSocketTimeouts();
-
     inProcessMessages = true;
 
     // Get the underlying transport to build large packets if we send
     // multiple small responses.
     getOutStream()->cork(true);
 
-    while (getInStream()->checkNoWait(1)) {
-      if (pendingSyncFence) {
+    while (true) {
+      if (pendingSyncFence)
         syncFence = true;
-        pendingSyncFence = false;
-      }
 
-      processMsg();
+      if (!processMsg())
+        break;
 
       if (syncFence) {
         writer()->writeFence(fenceFlags, fenceDataLen, fenceData);
         syncFence = false;
+        pendingSyncFence = false;
       }
     }
 
@@ -195,7 +189,6 @@ void VNCSConnectionST::flushSocket()
 {
   if (state() == RFBSTATE_CLOSING) return;
   try {
-    setSocketTimeouts();
     sock->outStream().flush();
     // Flushing the socket might release an update that was previously
     // delayed because of congestion.
@@ -1150,12 +1143,3 @@ void VNCSConnectionST::setLEDState(unsigned int ledstate)
   if (client.supportsLEDState())
     writer()->writeLEDState();
 }
-
-void VNCSConnectionST::setSocketTimeouts()
-{
-  int timeoutms = rfb::Server::clientWaitTimeMillis;
-  if (timeoutms == 0)
-    timeoutms = -1;
-  sock->inStream().setTimeout(timeoutms);
-  sock->outStream().setTimeout(timeoutms);
-}
index 46a2b28b12209d4ee94ddbf1bf000dbe2f87f443..06fdf541974c5f4c7cd9328ad5c7391204f20798 100644 (file)
@@ -155,7 +155,6 @@ namespace rfb {
     void setCursor();
     void setDesktopName(const char *name);
     void setLEDState(unsigned int state);
-    void setSocketTimeouts();
 
   private:
     network::Socket* sock;
index 9d1ff6b68888addde18f66a8224e22ceefe54bd3..4fba0c226e7477e2eb96469739579c0340b6d927 100644 (file)
@@ -21,6 +21,7 @@
 #include <rdr/MemInStream.h>
 #include <rdr/OutStream.h>
 
+#include <rfb/Exception.h>
 #include <rfb/ServerParams.h>
 #include <rfb/PixelBuffer.h>
 #include <rfb/ZRLEDecoder.h>
@@ -29,7 +30,6 @@ using namespace rfb;
 
 static inline rdr::U32 readOpaque24A(rdr::InStream* is)
 {
-  is->check(3);
   rdr::U32 r=0;
   ((rdr::U8*)&r)[0] = is->readU8();
   ((rdr::U8*)&r)[1] = is->readU8();
@@ -39,7 +39,6 @@ static inline rdr::U32 readOpaque24A(rdr::InStream* is)
 }
 static inline rdr::U32 readOpaque24B(rdr::InStream* is)
 {
-  is->check(3);
   rdr::U32 r=0;
   ((rdr::U8*)&r)[1] = is->readU8();
   ((rdr::U8*)&r)[2] = is->readU8();
@@ -47,6 +46,12 @@ static inline rdr::U32 readOpaque24B(rdr::InStream* is)
   return r;
 }
 
+static inline void zlibHasData(rdr::ZlibInStream* zis, size_t length)
+{
+  if (!zis->hasData(length))
+    throw Exception("ZRLE decode error");
+}
+
 #define BPP 8
 #include <rfb/zrleDecode.h>
 #undef BPP
@@ -71,14 +76,27 @@ ZRLEDecoder::~ZRLEDecoder()
 {
 }
 
-void ZRLEDecoder::readRect(const Rect& r, rdr::InStream* is,
+bool ZRLEDecoder::readRect(const Rect& r, rdr::InStream* is,
                            const ServerParams& server, rdr::OutStream* os)
 {
   rdr::U32 len;
 
+  if (!is->hasData(4))
+    return false;
+
+  is->setRestorePoint();
+
   len = is->readU32();
   os->writeU32(len);
+
+  if (!is->hasDataOrRestore(len))
+    return false;
+
+  is->clearRestorePoint();
+
   os->copyBytes(is, len);
+
+  return true;
 }
 
 void ZRLEDecoder::decodeRect(const Rect& r, const void* buffer,
index a530586e14cb2d44939f85af2d17f33c81ab0c4b..115f8fb8ca2cea0fc8fb782b66ce751f115252f4 100644 (file)
@@ -27,7 +27,7 @@ namespace rfb {
   public:
     ZRLEDecoder();
     virtual ~ZRLEDecoder();
-    virtual void readRect(const Rect& r, rdr::InStream* is,
+    virtual bool readRect(const Rect& r, rdr::InStream* is,
                           const ServerParams& server, rdr::OutStream* os);
     virtual void decodeRect(const Rect& r, const void* buffer,
                             size_t buflen, const ServerParams& server,
index f4325385c812ccca5bca9998fa76a49fffc127f8..998e51edc819eb698442f189c7f86e6a9733bd4d 100644 (file)
 // This file is #included after having set the following macro:
 // BPP                - 8, 16 or 32
 
-#include <stdio.h>
-#include <rdr/InStream.h>
-#include <rdr/ZlibInStream.h>
-#include <rfb/Exception.h>
-
 namespace rfb {
 
 // CONCAT2E concatenates its arguments, expanding them if they are macros
@@ -63,11 +58,17 @@ void ZRLE_DECODE (const Rect& r, rdr::InStream* is,
 
       t.br.x = __rfbmin(r.br.x, t.tl.x + 64);
 
+      zlibHasData(zis, 1);
       int mode = zis->readU8();
       bool rle = mode & 128;
       int palSize = mode & 127;
       PIXEL_T palette[128];
 
+#ifdef CPIXEL
+      zlibHasData(zis, 3 * palSize);
+#else
+      zlibHasData(zis, BPP/8 * palSize);
+#endif
       for (int i = 0; i < palSize; i++) {
         palette[i] = READ_PIXEL(zis);
       }
@@ -84,10 +85,12 @@ void ZRLE_DECODE (const Rect& r, rdr::InStream* is,
           // raw
 
 #ifdef CPIXEL
+          zlibHasData(zis, 3 * t.area());
           for (PIXEL_T* ptr = buf; ptr < buf+t.area(); ptr++) {
             *ptr = READ_PIXEL(zis);
           }
 #else
+          zlibHasData(zis, BPP/8 * t.area());
           zis->readBytes(buf, t.area() * (BPP / 8));
 #endif
 
@@ -106,6 +109,7 @@ void ZRLE_DECODE (const Rect& r, rdr::InStream* is,
 
             while (ptr < eol) {
               if (nbits == 0) {
+                zlibHasData(zis, 1);
                 byte = zis->readU8();
                 nbits = 8;
               }
@@ -125,10 +129,16 @@ void ZRLE_DECODE (const Rect& r, rdr::InStream* is,
           PIXEL_T* ptr = buf;
           PIXEL_T* end = ptr + t.area();
           while (ptr < end) {
+#ifdef CPIXEL
+            zlibHasData(zis, 3);
+#else
+            zlibHasData(zis, BPP/8);
+#endif
             PIXEL_T pix = READ_PIXEL(zis);
             int len = 1;
             int b;
             do {
+              zlibHasData(zis, 1);
               b = zis->readU8();
               len += b;
             } while (b == 255);
@@ -147,11 +157,13 @@ void ZRLE_DECODE (const Rect& r, rdr::InStream* is,
           PIXEL_T* ptr = buf;
           PIXEL_T* end = ptr + t.area();
           while (ptr < end) {
+            zlibHasData(zis, 1);
             int index = zis->readU8();
             int len = 1;
             if (index & 128) {
               int b;
               do {
+                zlibHasData(zis, 1);
                 b = zis->readU8();
                 len += b;
               } while (b == 255);
index e13070701662e82132fc99e31138e7de0f187770..a6c65a221dfa7c6a2eda22901f3b304c839e5f92 100644 (file)
@@ -102,6 +102,8 @@ void DummyOutStream::flush()
 void DummyOutStream::overrun(size_t needed)
 {
   flush();
+  if (avail() < needed)
+    throw rdr::Exception("Insufficient dummy output buffer");
 }
 
 CConn::CConn(const char *filename)
index 6bcb6f7447439795b1f06a39d4ad7ab71ed15c2e..41c309c17480f7ed1521978212663b421979d345 100644 (file)
@@ -95,7 +95,7 @@ public:
   virtual void setCursor(int, int, const rfb::Point&, const rdr::U8*);
   virtual void framebufferUpdateStart();
   virtual void framebufferUpdateEnd();
-  virtual void dataRect(const rfb::Rect&, int);
+  virtual bool dataRect(const rfb::Rect&, int);
   virtual void setColourMapEntries(int, int, rdr::U16*);
   virtual void bell();
   virtual void serverCutText(const char*);
@@ -159,6 +159,8 @@ void DummyOutStream::flush()
 void DummyOutStream::overrun(size_t needed)
 {
   flush();
+  if (avail() < needed)
+    throw rdr::Exception("Insufficient dummy output buffer");
 }
 
 CConn::CConn(const char *filename)
@@ -241,12 +243,15 @@ void CConn::framebufferUpdateEnd()
   encodeTime += getCpuCounter();
 }
 
-void CConn::dataRect(const rfb::Rect &r, int encoding)
+bool CConn::dataRect(const rfb::Rect &r, int encoding)
 {
-  CConnection::dataRect(r, encoding);
+  if (!CConnection::dataRect(r, encoding))
+    return false;
 
   if (encoding != rfb::encodingCopyRect) // FIXME
     updates.add_changed(rfb::Region(r));
+
+  return true;
 }
 
 void CConn::setColourMapEntries(int, int, rdr::U16*)
index 9f3a475017a4c334eb20310168c1d6c67fc465f5..2964df3337aac5aaaac4e80c0b84848290bddfbe 100755 (executable)
@@ -103,7 +103,6 @@ my %config;
 # override these where present.
 $default_opts{desktop} = $desktopName;
 $default_opts{auth} = $xauthorityFile;
-$default_opts{rfbwait} = 30000;
 $default_opts{rfbauth} = "$vncUserDir/passwd";
 $default_opts{rfbport} = $vncPort;
 $default_opts{fp} = $fontPath if ($fontPath);
index a9782ada77634760f24a94973c916f4f25e9f0bb..1531de60b162a5abad84eb0d1d0ca6c74cc2cc8c 100644 (file)
@@ -352,7 +352,6 @@ int main(int argc, char** argv)
         if (FD_ISSET((*i)->getFd(), &rfds)) {
           Socket* sock = (*i)->accept();
           if (sock) {
-            sock->outStream().setBlocking(false);
             server.addSocket(sock);
           } else {
             vlog.status("Client connection rejected");
index b54fcb48ad78f534c1770d443593d46c50d78360..094abbe9bc62531a414b2d87fd7ddcc60c13ab71 100644 (file)
@@ -298,13 +298,6 @@ Terminate when a client has been connected for \fIN\fP seconds.  Default is
 Terminate after \fIN\fP seconds of user inactivity.  Default is 0.
 .
 .TP
-.B \-ClientWaitTimeMillis \fItime\fP
-Time in milliseconds to wait for a viewer which is blocking the server. This is
-necessary because the server is single-threaded and sometimes blocks until the
-viewer has finished sending or receiving a message - note that this does not
-mean an update will be aborted after this time.  Default is 20000 (20 seconds).
-.
-.TP
 .B \-AcceptCutText
 .TQ
 .B \-SendCutText
index 8215c936c8133208206ed7aa63e8f58781e78b5e..6f707299db9840a11db3be676f1a6a3fdb336f38 100644 (file)
@@ -311,7 +311,6 @@ bool XserverDesktop::handleListenerEvent(int fd,
     return false;
 
   Socket* sock = (*i)->accept();
-  sock->outStream().setBlocking(false);
   vlog.debug("new client, sock %d", sock->getFd());
   sockserv->addSocket(sock);
   vncSetNotifyFd(sock->getFd(), screenIndex, true, false);
@@ -393,7 +392,6 @@ void XserverDesktop::blockHandler(int* timeout)
 void XserverDesktop::addClient(Socket* sock, bool reverse)
 {
   vlog.debug("new client, sock %d reverse %d",sock->getFd(),reverse);
-  sock->outStream().setBlocking(false);
   server->addSocket(sock, reverse);
   vncSetNotifyFd(sock->getFd(), screenIndex, true, false);
 }
index 83621c08dbae7be1893ffc37f67181353c1bec25..2d0089d7c1e6d485d75de8c1cec9f27b640f5821 100644 (file)
@@ -98,13 +98,6 @@ connections from viewers, instead of listening on a TCP port.
 Specifies the mode of the Unix domain socket.  The default is 0600.
 .
 .TP
-.B \-rfbwait \fItime\fP, \-ClientWaitTimeMillis \fItime\fP
-Time in milliseconds to wait for a viewer which is blocking the server. This is
-necessary because the server is single-threaded and sometimes blocks until the
-viewer has finished sending or receiving a message - note that this does not
-mean an update will be aborted after this time.  Default is 20000 (20 seconds).
-.
-.TP
 .B \-rfbauth \fIpasswd-file\fP, \-PasswordFile \fIpasswd-file\fP
 Password file for VNC authentication.  There is no default, you should
 specify the password file explicitly.  Password file should be created with
index 6ab306b142dcf73b40eb701002359169031b73b6..43f830881d935e89a1fe553818ceab07fa9b0bce 100644 (file)
@@ -73,8 +73,6 @@ struct CaseInsensitiveCompare {
 typedef std::set<std::string, CaseInsensitiveCompare> ParamSet;
 static ParamSet allowOverrideSet;
 
-rfb::AliasParameter rfbwait("rfbwait", "Alias for ClientWaitTimeMillis",
-                            &rfb::Server::clientWaitTimeMillis);
 rfb::IntParameter rfbport("rfbport", "TCP port to listen for RFB protocol",0);
 rfb::StringParameter rfbunixpath("rfbunixpath", "Unix socket to listen for RFB protocol", "");
 rfb::IntParameter rfbunixmode("rfbunixmode", "Unix socket access mode", 0600);
index e7362c8ee718d42e111cd2364fa1267ebb2d5707..68dd031b4f33f271ce16346cb8dac2f1c8807479 100644 (file)
@@ -116,9 +116,6 @@ CConn::CConn(const char* vncServerName, network::Socket* socket=NULL)
 
   Fl::add_fd(sock->getFd(), FL_READ | FL_EXCEPT, socketEvent, this);
 
-  // See callback below
-  sock->inStream().setBlockCallback(this);
-
   setServerName(serverHost);
   setStreams(&sock->inStream(), &sock->outStream());
 
@@ -228,22 +225,11 @@ unsigned CConn::getPosition()
   return sock->inStream().pos();
 }
 
-// The RFB core is not properly asynchronous, so it calls this callback
-// whenever it needs to block to wait for more data. Since FLTK is
-// monitoring the socket, we just make sure FLTK gets to run.
-
-void CConn::blockCallback()
-{
-  run_mainloop();
-
-  if (should_exit())
-    throw rdr::Exception("Termination requested");
-}
-
 void CConn::socketEvent(FL_SOCKET fd, void *data)
 {
   CConn *cc;
   static bool recursing = false;
+  int when;
 
   assert(data);
   cc = (CConn*)data;
@@ -255,10 +241,14 @@ void CConn::socketEvent(FL_SOCKET fd, void *data)
   recursing = true;
 
   try {
+    // We might have been called to flush unwritten socket data
+    cc->sock->outStream().flush();
+
+    cc->sock->outStream().cork(true);
+
     // processMsg() only processes one message, so we need to loop
     // until the buffers are empty or things will stall.
-    do {
-      cc->processMsg();
+    while (cc->processMsg()) {
 
       // Make sure that the FLTK handling and the timers gets some CPU
       // time in case of back to back messages
@@ -268,7 +258,10 @@ void CConn::socketEvent(FL_SOCKET fd, void *data)
        // Also check if we need to stop reading and terminate
        if (should_exit())
          break;
-    } while (cc->getInStream()->checkNoWait(1));
+    }
+
+    cc->sock->outStream().cork(false);
+    cc->sock->outStream().flush();
   } catch (rdr::EndOfStream& e) {
     vlog.info("%s", e.str());
     exit_vncviewer();
@@ -280,6 +273,12 @@ void CConn::socketEvent(FL_SOCKET fd, void *data)
       exit_vncviewer(e.str());
   }
 
+  when = FL_READ | FL_EXCEPT;
+  if (cc->sock->outStream().hasBufferedData())
+    when |= FL_WRITE;
+
+  Fl::add_fd(fd, when, socketEvent, data);
+
   recursing = false;
 }
 
@@ -402,14 +401,19 @@ void CConn::bell()
   fl_beep();
 }
 
-void CConn::dataRect(const Rect& r, int encoding)
+bool CConn::dataRect(const Rect& r, int encoding)
 {
+  bool ret;
+
   if (encoding != encodingCopyRect)
     lastServerEncoding = encoding;
 
-  CConnection::dataRect(r, encoding);
+  ret = CConnection::dataRect(r, encoding);
+
+  if (ret)
+    pixelCount += r.area();
 
-  pixelCount += r.area();
+  return ret;
 }
 
 void CConn::setCursor(int width, int height, const Point& hotspot,
index 25dff875abd62d8a0f86544eb29cb3a524d9f7c3..ad3fb797bb49f95a1b3766203145220d30c9a702 100644 (file)
@@ -29,8 +29,7 @@ namespace network { class Socket; }
 
 class DesktopWindow;
 
-class CConn : public rfb::CConnection,
-              public rdr::FdInStreamBlockCallback
+class CConn : public rfb::CConnection
 {
 public:
   CConn(const char* vncServerName, network::Socket* sock);
@@ -42,9 +41,6 @@ public:
   unsigned getPixelCount();
   unsigned getPosition();
 
-  // FdInStreamBlockCallback methods
-  void blockCallback();
-
   // Callback when socket is ready (or broken)
   static void socketEvent(FL_SOCKET fd, void *data);
 
@@ -63,7 +59,7 @@ public:
 
   void framebufferUpdateStart();
   void framebufferUpdateEnd();
-  void dataRect(const rfb::Rect& r, int encoding);
+  bool dataRect(const rfb::Rect& r, int encoding);
 
   void setCursor(int width, int height, const rfb::Point& hotspot,
                  const rdr::U8* data);
index 0092d94dd02a8099e8b09e7843dd221b6674487c..393e2191e15be4e557e07174fb7eacdcb52f8f64 100644 (file)
@@ -170,6 +170,13 @@ int SocketManager::checkTimeouts() {
     j_next = j; j_next++;
     if (j->second.sock->isShutdown())
       shutdownSocks.push_back(j->second.sock);
+    else {
+      long eventMask = FD_READ | FD_CLOSE;
+      if (j->second.sock->outStream().hasBufferedData())
+        eventMask |= FD_WRITE;
+      if (WSAEventSelect(j->second.sock->getFd(), j->first, eventMask) == SOCKET_ERROR)
+        throw rdr::SystemException("unable to adjust WSAEventSelect:%u", WSAGetLastError());
+    }
   }
 
   std::list<network::Socket*>::iterator k;
@@ -213,6 +220,13 @@ void SocketManager::processEvent(HANDLE event) {
     try {
       // Process data from an active connection
 
+      WSANETWORKEVENTS events;
+      long eventMask;
+
+      // Fetch why this event notification triggered
+      if (WSAEnumNetworkEvents(ci.sock->getFd(), event, &events) == SOCKET_ERROR)
+        throw rdr::SystemException("unable to get WSAEnumNetworkEvents:%u", WSAGetLastError());
+
       // Cancel event notification for this socket
       if (WSAEventSelect(ci.sock->getFd(), event, 0) == SOCKET_ERROR)
         throw rdr::SystemException("unable to disable WSAEventSelect:%u", WSAGetLastError());
@@ -220,16 +234,29 @@ void SocketManager::processEvent(HANDLE event) {
       // Reset the event object
       WSAResetEvent(event);
 
+
       // Call the socket server to process the event
-      ci.server->processSocketReadEvent(ci.sock);
-      if (ci.sock->isShutdown()) {
-        remSocket(ci.sock);
-        return;
+      if (events.lNetworkEvents & FD_WRITE) {
+        ci.server->processSocketWriteEvent(ci.sock);
+        if (ci.sock->isShutdown()) {
+          remSocket(ci.sock);
+          return;
+        }
+      }
+      if (events.lNetworkEvents & (FD_READ | FD_CLOSE)) {
+        ci.server->processSocketReadEvent(ci.sock);
+        if (ci.sock->isShutdown()) {
+          remSocket(ci.sock);
+          return;
+        }
       }
 
       // Re-instate the required socket event
       // If the read event is still valid, the event object gets set here
-      if (WSAEventSelect(ci.sock->getFd(), event, FD_READ | FD_CLOSE) == SOCKET_ERROR)
+      eventMask = FD_READ | FD_CLOSE;
+      if (ci.sock->outStream().hasBufferedData())
+        eventMask |= FD_WRITE;
+      if (WSAEventSelect(ci.sock->getFd(), event, eventMask) == SOCKET_ERROR)
         throw rdr::SystemException("unable to re-enable WSAEventSelect:%u", WSAGetLastError());
     } catch (rdr::Exception& e) {
       vlog.error("%s", e.str());