diff options
Diffstat (limited to 'common/network/TcpSocket.cxx')
-rw-r--r-- | common/network/TcpSocket.cxx | 144 |
1 files changed, 115 insertions, 29 deletions
diff --git a/common/network/TcpSocket.cxx b/common/network/TcpSocket.cxx index c5b86543..bf3a224c 100644 --- a/common/network/TcpSocket.cxx +++ b/common/network/TcpSocket.cxx @@ -35,19 +35,21 @@ #include <errno.h> #endif +#include <assert.h> +#include <ctype.h> #include <stdlib.h> +#include <string.h> #include <unistd.h> -#include <rdr/Exception.h> +#include <core/Configuration.h> +#include <core/Exception.h> +#include <core/LogWriter.h> +#include <core/string.h> #include <network/TcpSocket.h> -#include <rfb/LogWriter.h> -#include <rfb/Configuration.h> -#include <rfb/util.h> - #ifdef WIN32 -#include <os/winerrno.h> +#include <core/winerrno.h> #endif #ifndef INADDR_NONE @@ -68,12 +70,11 @@ #endif using namespace network; -using namespace rdr; -static rfb::LogWriter vlog("TcpSocket"); +static core::LogWriter vlog("TcpSocket"); -static rfb::BoolParameter UseIPv4("UseIPv4", "Use IPv4 for incoming and outgoing connections.", true); -static rfb::BoolParameter UseIPv6("UseIPv6", "Use IPv6 for incoming and outgoing connections.", true); +static core::BoolParameter UseIPv4("UseIPv4", "Use IPv4 for incoming and outgoing connections.", true); +static core::BoolParameter UseIPv6("UseIPv6", "Use IPv6 for incoming and outgoing connections.", true); /* Tunnelling support. */ int network::findFreeTcpPort (void) @@ -85,20 +86,105 @@ int network::findFreeTcpPort (void) addr.sin_addr.s_addr = INADDR_ANY; if ((sock = socket (AF_INET, SOCK_STREAM, 0)) < 0) - throw socket_error("Unable to create socket", errorNumber); + throw core::socket_error("Unable to create socket", errorNumber); addr.sin_port = 0; if (bind (sock, (struct sockaddr *)&addr, sizeof (addr)) < 0) - throw socket_error("Unable to find free port", errorNumber); + throw core::socket_error("Unable to find free port", errorNumber); socklen_t n = sizeof(addr); if (getsockname (sock, (struct sockaddr *)&addr, &n) < 0) - throw socket_error("Unable to get port number", errorNumber); + throw core::socket_error("Unable to get port number", errorNumber); closesocket (sock); return ntohs(addr.sin_port); } +static bool isAllSpace(const char *string) { + if (string == nullptr) + return false; + while(*string != '\0') { + if (! isspace(*string)) + return false; + string++; + } + return true; +} + +void network::getHostAndPort(const char* hi, std::string* host, + int* port, int basePort) +{ + const char* hostStart; + const char* hostEnd; + const char* portStart; + + if (hi == nullptr) + throw std::invalid_argument("NULL host specified"); + + // Trim leading whitespace + while(isspace(*hi)) + hi++; + + assert(host); + assert(port); + + if (hi[0] == '[') { + hostStart = &hi[1]; + hostEnd = strchr(hostStart, ']'); + if (hostEnd == nullptr) + throw std::invalid_argument("Unmatched [ in host"); + + portStart = hostEnd + 1; + if (isAllSpace(portStart)) + portStart = nullptr; + } else { + hostStart = &hi[0]; + hostEnd = strrchr(hostStart, ':'); + + if (hostEnd == nullptr) { + hostEnd = hostStart + strlen(hostStart); + portStart = nullptr; + } else { + if ((hostEnd > hostStart) && (hostEnd[-1] == ':')) + hostEnd--; + portStart = strchr(hostStart, ':'); + if (portStart != hostEnd) { + // We found more : in the host. This is probably an IPv6 address + hostEnd = hostStart + strlen(hostStart); + portStart = nullptr; + } + } + } + + // Back up past trailing space + while(isspace(*(hostEnd - 1)) && hostEnd > hostStart) + hostEnd--; + + if (hostStart == hostEnd) + *host = "localhost"; + else + *host = std::string(hostStart, hostEnd - hostStart); + + if (portStart == nullptr) + *port = basePort; + else { + char* end; + + if (portStart[0] != ':') + throw std::invalid_argument("Invalid port specified"); + + if (portStart[1] != ':') + *port = strtol(portStart + 1, &end, 10); + else + *port = strtol(portStart + 2, &end, 10); + if (*end != '\0' && ! isAllSpace(end)) + throw std::invalid_argument("Invalid port specified"); + + if ((portStart[1] != ':') && (*port < 100)) + *port += basePort; + } +} + int network::getSockPort(int sock) { vnc_sockaddr_t sa; @@ -137,7 +223,7 @@ TcpSocket::TcpSocket(const char *host, int port) hints.ai_next = nullptr; if ((result = getaddrinfo(host, nullptr, &hints, &ai)) != 0) { - throw getaddrinfo_error("Unable to resolve host by name", result); + throw core::getaddrinfo_error("Unable to resolve host by name", result); } sock = -1; @@ -178,7 +264,7 @@ TcpSocket::TcpSocket(const char *host, int port) if (sock == -1) { err = errorNumber; freeaddrinfo(ai); - throw socket_error("Unable to create socket", err); + throw core::socket_error("Unable to create socket", err); } /* Attempt to connect to the remote host */ @@ -205,7 +291,7 @@ TcpSocket::TcpSocket(const char *host, int port) if (err == 0) throw std::runtime_error("No useful address for host"); else - throw socket_error("Unable to connect to socket", err); + throw core::socket_error("Unable to connect to socket", err); } // Take proper ownership of the socket @@ -302,7 +388,7 @@ TcpListener::TcpListener(const struct sockaddr *listenaddr, int sock; if ((sock = socket (listenaddr->sa_family, SOCK_STREAM, 0)) < 0) - throw socket_error("Unable to create listening socket", errorNumber); + throw core::socket_error("Unable to create listening socket", errorNumber); memcpy (&sa, listenaddr, listenaddrlen); #ifdef IPV6_V6ONLY @@ -310,7 +396,7 @@ TcpListener::TcpListener(const struct sockaddr *listenaddr, if (setsockopt (sock, IPPROTO_IPV6, IPV6_V6ONLY, (char*)&one, sizeof(one))) { int e = errorNumber; closesocket(sock); - throw socket_error("Unable to set IPV6_V6ONLY", e); + throw core::socket_error("Unable to set IPV6_V6ONLY", e); } } #endif /* defined(IPV6_V6ONLY) */ @@ -328,14 +414,14 @@ TcpListener::TcpListener(const struct sockaddr *listenaddr, (char *)&one, sizeof(one)) < 0) { int e = errorNumber; closesocket(sock); - throw socket_error("Unable to create listening socket", e); + throw core::socket_error("Unable to create listening socket", e); } #endif if (bind(sock, &sa.u.sa, listenaddrlen) == -1) { int e = errorNumber; closesocket(sock); - throw socket_error("Failed to bind socket", e); + throw core::socket_error("Failed to bind socket", e); } listen(sock); @@ -446,7 +532,7 @@ void network::createTcpListeners(std::list<SocketListener*> *listeners, snprintf (service, sizeof (service) - 1, "%d", port); service[sizeof (service) - 1] = '\0'; if ((result = getaddrinfo(addr, service, &hints, &ai)) != 0) - throw getaddrinfo_error("Unable to resolve listening address", result); + throw core::getaddrinfo_error("Unable to resolve listening address", result); try { createTcpListeners(listeners, ai); @@ -485,7 +571,7 @@ void network::createTcpListeners(std::list<SocketListener*> *listeners, try { new_listeners.push_back(new TcpListener(current->ai_addr, current->ai_addrlen)); - } catch (socket_error& e) { + } catch (core::socket_error& e) { // Ignore this if it is due to lack of address family support on // the interface or on the system if (e.err != EADDRNOTAVAIL && e.err != EAFNOSUPPORT) { @@ -506,7 +592,7 @@ void network::createTcpListeners(std::list<SocketListener*> *listeners, TcpFilter::TcpFilter(const char* spec) { std::vector<std::string> patterns; - patterns = rfb::split(spec, ','); + patterns = core::split(spec, ','); for (size_t i = 0; i < patterns.size(); i++) { if (!patterns[i].empty()) @@ -608,11 +694,11 @@ TcpFilter::Pattern TcpFilter::parsePattern(const char* p) { initSockets(); - parts = rfb::split(&p[1], '/'); + parts = core::split(&p[1], '/'); if (parts.size() > 2) throw std::invalid_argument("Invalid filter specified"); - if (parts[0].empty()) { + if (parts.empty() || parts[0].empty()) { // Match any address memset (&pattern.address, 0, sizeof (pattern.address)); pattern.address.u.sa.sa_family = AF_UNSPEC; @@ -633,7 +719,7 @@ TcpFilter::Pattern TcpFilter::parsePattern(const char* p) { } if ((result = getaddrinfo (parts[0].c_str(), nullptr, &hints, &ai)) != 0) { - throw getaddrinfo_error("Unable to resolve host by name", result); + throw core::getaddrinfo_error("Unable to resolve host by name", result); } memcpy (&pattern.address.u.sa, ai->ai_addr, ai->ai_addrlen); @@ -666,9 +752,9 @@ TcpFilter::Pattern TcpFilter::parsePattern(const char* p) { family = pattern.address.u.sa.sa_family; if (pattern.prefixlen > (family == AF_INET ? 32: 128)) - throw std::invalid_argument(rfb::format("Invalid prefix length for " - "filter address: %u", - pattern.prefixlen)); + throw std::invalid_argument( + core::format("Invalid prefix length for filter address: %u", + pattern.prefixlen)); // Compute mask from address and prefix length memset (&pattern.mask, 0, sizeof (pattern.mask)); |