diff options
Diffstat (limited to 'common/network')
-rw-r--r-- | common/network/CMakeLists.txt | 2 | ||||
-rw-r--r-- | common/network/Socket.cxx | 31 | ||||
-rw-r--r-- | common/network/Socket.h | 11 | ||||
-rw-r--r-- | common/network/TcpSocket.cxx | 144 | ||||
-rw-r--r-- | common/network/TcpSocket.h | 3 | ||||
-rw-r--r-- | common/network/UnixSocket.cxx | 23 |
6 files changed, 158 insertions, 56 deletions
diff --git a/common/network/CMakeLists.txt b/common/network/CMakeLists.txt index f08eaa31..42472b8d 100644 --- a/common/network/CMakeLists.txt +++ b/common/network/CMakeLists.txt @@ -7,7 +7,7 @@ if(NOT WIN32) endif() target_include_directories(network PUBLIC ${CMAKE_SOURCE_DIR}/common) -target_link_libraries(network os rdr rfb) +target_link_libraries(network core rdr) if(WIN32) target_link_libraries(network ws2_32) diff --git a/common/network/Socket.cxx b/common/network/Socket.cxx index 49abbc84..7fc39d1e 100644 --- a/common/network/Socket.cxx +++ b/common/network/Socket.cxx @@ -32,22 +32,23 @@ #define errorNumber errno #define closesocket close #include <sys/socket.h> -#endif - #include <unistd.h> #include <signal.h> #include <fcntl.h> #include <errno.h> +#endif -#include <rdr/Exception.h> +#include <core/Exception.h> +#include <core/LogWriter.h> -#include <network/Socket.h> +#include <rdr/FdInStream.h> +#include <rdr/FdOutStream.h> -#include <rfb/LogWriter.h> +#include <network/Socket.h> using namespace network; -static rfb::LogWriter vlog("Socket"); +static core::LogWriter vlog("Socket"); // -=- Socket initialisation static bool socketsInitialised = false; @@ -59,7 +60,7 @@ void network::initSockets() { WSADATA initResult; if (WSAStartup(requiredVersion, &initResult) != 0) - throw rdr::socket_error("Unable to initialise Winsock2", errorNumber); + throw core::socket_error("Unable to initialise Winsock2", errorNumber); #else signal(SIGPIPE, SIG_IGN); #endif @@ -99,6 +100,11 @@ Socket::~Socket() delete outstream; } +int Socket::getFd() +{ + return outstream->getFd(); +} + // if shutdown() is overridden then the override MUST call on to here void Socket::shutdown() { @@ -114,7 +120,7 @@ void Socket::shutdown() } isShutdown_ = true; - ::shutdown(getFd(), SHUT_RDWR); + ::shutdown(getFd(), SHUT_WR); } bool Socket::isShutdown() const @@ -122,6 +128,11 @@ bool Socket::isShutdown() const return isShutdown_; } +void Socket::cork(bool enable) +{ + outstream->cork(enable); +} + // Was there a "?" in the ConnectionFilter used to accept this Socket? void Socket::setRequiresQuery() { @@ -178,7 +189,7 @@ Socket* SocketListener::accept() { // Accept an incoming connection if ((new_sock = ::accept(fd, nullptr, nullptr)) < 0) - throw rdr::socket_error("Unable to accept new connection", errorNumber); + throw core::socket_error("Unable to accept new connection", errorNumber); // Create the socket object & check connection is allowed Socket* s = createSocket(new_sock); @@ -196,7 +207,7 @@ void SocketListener::listen(int sock) if (::listen(sock, 5) < 0) { int e = errorNumber; closesocket(sock); - throw rdr::socket_error("Unable to set socket to listening mode", e); + throw core::socket_error("Unable to set socket to listening mode", e); } fd = sock; diff --git a/common/network/Socket.h b/common/network/Socket.h index 34b8db8e..f1688c72 100644 --- a/common/network/Socket.h +++ b/common/network/Socket.h @@ -24,8 +24,11 @@ #include <list> #include <limits.h> -#include <rdr/FdInStream.h> -#include <rdr/FdOutStream.h> + +namespace rdr { + class FdInStream; + class FdOutStream; +} namespace network { @@ -40,12 +43,12 @@ namespace network { rdr::FdInStream &inStream() {return *instream;} rdr::FdOutStream &outStream() {return *outstream;} - int getFd() {return outstream->getFd();} + int getFd(); void shutdown(); bool isShutdown() const; - void cork(bool enable) { outstream->cork(enable); } + void cork(bool enable); // information about the remote end of the socket virtual const char* getPeerAddress() = 0; // a string e.g. "192.168.0.1" diff --git a/common/network/TcpSocket.cxx b/common/network/TcpSocket.cxx index c5b86543..bf3a224c 100644 --- a/common/network/TcpSocket.cxx +++ b/common/network/TcpSocket.cxx @@ -35,19 +35,21 @@ #include <errno.h> #endif +#include <assert.h> +#include <ctype.h> #include <stdlib.h> +#include <string.h> #include <unistd.h> -#include <rdr/Exception.h> +#include <core/Configuration.h> +#include <core/Exception.h> +#include <core/LogWriter.h> +#include <core/string.h> #include <network/TcpSocket.h> -#include <rfb/LogWriter.h> -#include <rfb/Configuration.h> -#include <rfb/util.h> - #ifdef WIN32 -#include <os/winerrno.h> +#include <core/winerrno.h> #endif #ifndef INADDR_NONE @@ -68,12 +70,11 @@ #endif using namespace network; -using namespace rdr; -static rfb::LogWriter vlog("TcpSocket"); +static core::LogWriter vlog("TcpSocket"); -static rfb::BoolParameter UseIPv4("UseIPv4", "Use IPv4 for incoming and outgoing connections.", true); -static rfb::BoolParameter UseIPv6("UseIPv6", "Use IPv6 for incoming and outgoing connections.", true); +static core::BoolParameter UseIPv4("UseIPv4", "Use IPv4 for incoming and outgoing connections.", true); +static core::BoolParameter UseIPv6("UseIPv6", "Use IPv6 for incoming and outgoing connections.", true); /* Tunnelling support. */ int network::findFreeTcpPort (void) @@ -85,20 +86,105 @@ int network::findFreeTcpPort (void) addr.sin_addr.s_addr = INADDR_ANY; if ((sock = socket (AF_INET, SOCK_STREAM, 0)) < 0) - throw socket_error("Unable to create socket", errorNumber); + throw core::socket_error("Unable to create socket", errorNumber); addr.sin_port = 0; if (bind (sock, (struct sockaddr *)&addr, sizeof (addr)) < 0) - throw socket_error("Unable to find free port", errorNumber); + throw core::socket_error("Unable to find free port", errorNumber); socklen_t n = sizeof(addr); if (getsockname (sock, (struct sockaddr *)&addr, &n) < 0) - throw socket_error("Unable to get port number", errorNumber); + throw core::socket_error("Unable to get port number", errorNumber); closesocket (sock); return ntohs(addr.sin_port); } +static bool isAllSpace(const char *string) { + if (string == nullptr) + return false; + while(*string != '\0') { + if (! isspace(*string)) + return false; + string++; + } + return true; +} + +void network::getHostAndPort(const char* hi, std::string* host, + int* port, int basePort) +{ + const char* hostStart; + const char* hostEnd; + const char* portStart; + + if (hi == nullptr) + throw std::invalid_argument("NULL host specified"); + + // Trim leading whitespace + while(isspace(*hi)) + hi++; + + assert(host); + assert(port); + + if (hi[0] == '[') { + hostStart = &hi[1]; + hostEnd = strchr(hostStart, ']'); + if (hostEnd == nullptr) + throw std::invalid_argument("Unmatched [ in host"); + + portStart = hostEnd + 1; + if (isAllSpace(portStart)) + portStart = nullptr; + } else { + hostStart = &hi[0]; + hostEnd = strrchr(hostStart, ':'); + + if (hostEnd == nullptr) { + hostEnd = hostStart + strlen(hostStart); + portStart = nullptr; + } else { + if ((hostEnd > hostStart) && (hostEnd[-1] == ':')) + hostEnd--; + portStart = strchr(hostStart, ':'); + if (portStart != hostEnd) { + // We found more : in the host. This is probably an IPv6 address + hostEnd = hostStart + strlen(hostStart); + portStart = nullptr; + } + } + } + + // Back up past trailing space + while(isspace(*(hostEnd - 1)) && hostEnd > hostStart) + hostEnd--; + + if (hostStart == hostEnd) + *host = "localhost"; + else + *host = std::string(hostStart, hostEnd - hostStart); + + if (portStart == nullptr) + *port = basePort; + else { + char* end; + + if (portStart[0] != ':') + throw std::invalid_argument("Invalid port specified"); + + if (portStart[1] != ':') + *port = strtol(portStart + 1, &end, 10); + else + *port = strtol(portStart + 2, &end, 10); + if (*end != '\0' && ! isAllSpace(end)) + throw std::invalid_argument("Invalid port specified"); + + if ((portStart[1] != ':') && (*port < 100)) + *port += basePort; + } +} + int network::getSockPort(int sock) { vnc_sockaddr_t sa; @@ -137,7 +223,7 @@ TcpSocket::TcpSocket(const char *host, int port) hints.ai_next = nullptr; if ((result = getaddrinfo(host, nullptr, &hints, &ai)) != 0) { - throw getaddrinfo_error("Unable to resolve host by name", result); + throw core::getaddrinfo_error("Unable to resolve host by name", result); } sock = -1; @@ -178,7 +264,7 @@ TcpSocket::TcpSocket(const char *host, int port) if (sock == -1) { err = errorNumber; freeaddrinfo(ai); - throw socket_error("Unable to create socket", err); + throw core::socket_error("Unable to create socket", err); } /* Attempt to connect to the remote host */ @@ -205,7 +291,7 @@ TcpSocket::TcpSocket(const char *host, int port) if (err == 0) throw std::runtime_error("No useful address for host"); else - throw socket_error("Unable to connect to socket", err); + throw core::socket_error("Unable to connect to socket", err); } // Take proper ownership of the socket @@ -302,7 +388,7 @@ TcpListener::TcpListener(const struct sockaddr *listenaddr, int sock; if ((sock = socket (listenaddr->sa_family, SOCK_STREAM, 0)) < 0) - throw socket_error("Unable to create listening socket", errorNumber); + throw core::socket_error("Unable to create listening socket", errorNumber); memcpy (&sa, listenaddr, listenaddrlen); #ifdef IPV6_V6ONLY @@ -310,7 +396,7 @@ TcpListener::TcpListener(const struct sockaddr *listenaddr, if (setsockopt (sock, IPPROTO_IPV6, IPV6_V6ONLY, (char*)&one, sizeof(one))) { int e = errorNumber; closesocket(sock); - throw socket_error("Unable to set IPV6_V6ONLY", e); + throw core::socket_error("Unable to set IPV6_V6ONLY", e); } } #endif /* defined(IPV6_V6ONLY) */ @@ -328,14 +414,14 @@ TcpListener::TcpListener(const struct sockaddr *listenaddr, (char *)&one, sizeof(one)) < 0) { int e = errorNumber; closesocket(sock); - throw socket_error("Unable to create listening socket", e); + throw core::socket_error("Unable to create listening socket", e); } #endif if (bind(sock, &sa.u.sa, listenaddrlen) == -1) { int e = errorNumber; closesocket(sock); - throw socket_error("Failed to bind socket", e); + throw core::socket_error("Failed to bind socket", e); } listen(sock); @@ -446,7 +532,7 @@ void network::createTcpListeners(std::list<SocketListener*> *listeners, snprintf (service, sizeof (service) - 1, "%d", port); service[sizeof (service) - 1] = '\0'; if ((result = getaddrinfo(addr, service, &hints, &ai)) != 0) - throw getaddrinfo_error("Unable to resolve listening address", result); + throw core::getaddrinfo_error("Unable to resolve listening address", result); try { createTcpListeners(listeners, ai); @@ -485,7 +571,7 @@ void network::createTcpListeners(std::list<SocketListener*> *listeners, try { new_listeners.push_back(new TcpListener(current->ai_addr, current->ai_addrlen)); - } catch (socket_error& e) { + } catch (core::socket_error& e) { // Ignore this if it is due to lack of address family support on // the interface or on the system if (e.err != EADDRNOTAVAIL && e.err != EAFNOSUPPORT) { @@ -506,7 +592,7 @@ void network::createTcpListeners(std::list<SocketListener*> *listeners, TcpFilter::TcpFilter(const char* spec) { std::vector<std::string> patterns; - patterns = rfb::split(spec, ','); + patterns = core::split(spec, ','); for (size_t i = 0; i < patterns.size(); i++) { if (!patterns[i].empty()) @@ -608,11 +694,11 @@ TcpFilter::Pattern TcpFilter::parsePattern(const char* p) { initSockets(); - parts = rfb::split(&p[1], '/'); + parts = core::split(&p[1], '/'); if (parts.size() > 2) throw std::invalid_argument("Invalid filter specified"); - if (parts[0].empty()) { + if (parts.empty() || parts[0].empty()) { // Match any address memset (&pattern.address, 0, sizeof (pattern.address)); pattern.address.u.sa.sa_family = AF_UNSPEC; @@ -633,7 +719,7 @@ TcpFilter::Pattern TcpFilter::parsePattern(const char* p) { } if ((result = getaddrinfo (parts[0].c_str(), nullptr, &hints, &ai)) != 0) { - throw getaddrinfo_error("Unable to resolve host by name", result); + throw core::getaddrinfo_error("Unable to resolve host by name", result); } memcpy (&pattern.address.u.sa, ai->ai_addr, ai->ai_addrlen); @@ -666,9 +752,9 @@ TcpFilter::Pattern TcpFilter::parsePattern(const char* p) { family = pattern.address.u.sa.sa_family; if (pattern.prefixlen > (family == AF_INET ? 32: 128)) - throw std::invalid_argument(rfb::format("Invalid prefix length for " - "filter address: %u", - pattern.prefixlen)); + throw std::invalid_argument( + core::format("Invalid prefix length for filter address: %u", + pattern.prefixlen)); // Compute mask from address and prefix length memset (&pattern.mask, 0, sizeof (pattern.mask)); diff --git a/common/network/TcpSocket.h b/common/network/TcpSocket.h index b029bff2..17854e10 100644 --- a/common/network/TcpSocket.h +++ b/common/network/TcpSocket.h @@ -49,6 +49,9 @@ namespace network { /* Tunnelling support. */ int findFreeTcpPort (void); + void getHostAndPort(const char* hi, std::string* host, + int* port, int basePort=5900); + int getSockPort(int sock); class TcpSocket : public Socket { diff --git a/common/network/UnixSocket.cxx b/common/network/UnixSocket.cxx index 48561245..9691cb23 100644 --- a/common/network/UnixSocket.cxx +++ b/common/network/UnixSocket.cxx @@ -28,17 +28,16 @@ #include <errno.h> #include <stdlib.h> #include <stddef.h> +#include <string.h> -#include <rdr/Exception.h> +#include <core/Exception.h> +#include <core/LogWriter.h> #include <network/UnixSocket.h> -#include <rfb/LogWriter.h> - using namespace network; -using namespace rdr; -static rfb::LogWriter vlog("UnixSocket"); +static core::LogWriter vlog("UnixSocket"); // -=- UnixSocket @@ -53,12 +52,12 @@ UnixSocket::UnixSocket(const char *path) socklen_t salen; if (strlen(path) >= sizeof(addr.sun_path)) - throw socket_error("Socket path is too long", ENAMETOOLONG); + throw core::socket_error("Socket path is too long", ENAMETOOLONG); // - Create a socket sock = socket(AF_UNIX, SOCK_STREAM, 0); if (sock == -1) - throw socket_error("Unable to create socket", errno); + throw core::socket_error("Unable to create socket", errno); // - Attempt to connect memset(&addr, 0, sizeof(addr)); @@ -72,7 +71,7 @@ UnixSocket::UnixSocket(const char *path) } if (result == -1) - throw socket_error("Unable to connect to socket", err); + throw core::socket_error("Unable to connect to socket", err); setFd(sock); } @@ -119,11 +118,11 @@ UnixListener::UnixListener(const char *path, int mode) int err, result; if (strlen(path) >= sizeof(addr.sun_path)) - throw socket_error("Socket path is too long", ENAMETOOLONG); + throw core::socket_error("Socket path is too long", ENAMETOOLONG); // - Create a socket if ((fd = socket(AF_UNIX, SOCK_STREAM, 0)) < 0) - throw socket_error("Unable to create listening socket", errno); + throw core::socket_error("Unable to create listening socket", errno); // - Delete existing socket (ignore result) unlink(path); @@ -138,14 +137,14 @@ UnixListener::UnixListener(const char *path, int mode) umask(saved_umask); if (result < 0) { close(fd); - throw socket_error("Unable to bind listening socket", err); + throw core::socket_error("Unable to bind listening socket", err); } // - Set socket mode if (chmod(path, mode) < 0) { err = errno; close(fd); - throw socket_error("Unable to set socket mode", err); + throw core::socket_error("Unable to set socket mode", err); } listen(fd); |