]> source.dussan.org Git - tigervnc.git/commitdiff
IPv6 support for TcpFilter. 150/head
authorTim Waugh <twaugh@redhat.com>
Fri, 13 Mar 2015 16:07:29 +0000 (16:07 +0000)
committerTim Waugh <twaugh@redhat.com>
Fri, 13 Mar 2015 18:05:02 +0000 (18:05 +0000)
common/network/TcpSocket.cxx
common/network/TcpSocket.h

index 188651edc69155e0f5bd5e2ae07fb3bab9c56f89..6fca3018a72c28776c6c85181b15ebf220b41ca4 100644 (file)
@@ -31,7 +31,6 @@
 #include <sys/types.h>
 #include <sys/socket.h>
 #include <arpa/inet.h>
-#include <netinet/in.h>
 #include <netinet/tcp.h>
 #include <netdb.h>
 #include <errno.h>
 using namespace network;
 using namespace rdr;
 
-typedef struct vnc_sockaddr {
-       union {
-               sockaddr        sa;
-               sockaddr_in     sin;
-#ifdef HAVE_GETADDRINFO
-               sockaddr_in6    sin6;
-#endif
-       } u;
-} vnc_sockaddr_t;
-
 static rfb::LogWriter vlog("TcpSocket");
 
 static rfb::BoolParameter UseIPv4("UseIPv4", "Use IPv4 for incoming and outgoing connections.", true);
@@ -708,29 +697,69 @@ TcpFilter::~TcpFilter() {
 
 
 static bool
-patternMatchIP(const TcpFilter::Pattern& pattern, const char* value) {
-  unsigned long address = inet_addr((char *)value);
-  if (address == INADDR_NONE) return false;
-  return ((pattern.address & pattern.mask) == (address & pattern.mask));
+patternMatchIP(const TcpFilter::Pattern& pattern, vnc_sockaddr_t *sa) {
+  switch (pattern.address.u.sa.sa_family) {
+    unsigned long address;
+
+  case AF_INET:
+    if (sa->u.sa.sa_family != AF_INET)
+      return false;
+
+    address = sa->u.sin.sin_addr.s_addr;
+    if (address == htonl (INADDR_NONE)) return false;
+    return ((pattern.address.u.sin.sin_addr.s_addr &
+             pattern.mask.u.sin.sin_addr.s_addr) ==
+            (address & pattern.mask.u.sin.sin_addr.s_addr));
+
+  case AF_INET6:
+    if (sa->u.sa.sa_family != AF_INET6)
+      return false;
+
+    for (unsigned int n = 0; n < 16; n++) {
+      unsigned int bits = (n + 1) * 8;
+      unsigned int mask;
+      if (pattern.prefixlen > bits)
+        mask = 0xff;
+      else {
+        unsigned int lastbits = 0xff;
+        lastbits <<= bits - pattern.prefixlen;
+        mask = lastbits & 0xff;
+      }
+
+      if ((pattern.address.u.sin6.sin6_addr.s6_addr[n] & mask) !=
+          (sa->u.sin6.sin6_addr.s6_addr[n] & mask))
+        return false;
+
+      if (mask < 0xff)
+        break;
+    }
+
+    return true;
+
+  case AF_UNSPEC:
+    // Any address matches
+    return true;
+
+  default:
+    break;
+  }
+
+  return false;
 }
 
 bool
 TcpFilter::verifyConnection(Socket* s) {
   rfb::CharArray name;
-
-#ifdef HAVE_GETADDRINFO
   vnc_sockaddr_t sa;
   socklen_t sa_size = sizeof(sa);
-  if (getpeername(s->getFd(), &sa.u.sa, &sa_size) != 0 ||
-      sa.u.sa.sa_family != AF_INET)
-    /* Matching only works for IPv4 */
+
+  if (getpeername(s->getFd(), &sa.u.sa, &sa_size) != 0)
     return false;
-#endif /* HAVE_GETADDRINFO */
 
   name.buf = s->getPeerAddress();
   std::list<TcpFilter::Pattern>::iterator i;
   for (i=filter.begin(); i!=filter.end(); i++) {
-    if (patternMatchIP(*i, name.buf)) {
+    if (patternMatchIP(*i, &sa)) {
       switch ((*i).action) {
       case Accept:
         vlog.debug("ACCEPT %s", name.buf);
@@ -754,31 +783,105 @@ TcpFilter::verifyConnection(Socket* s) {
 TcpFilter::Pattern TcpFilter::parsePattern(const char* p) {
   TcpFilter::Pattern pattern;
 
-  bool expandMask = false;
-  rfb::CharArray addr, mask;
+  rfb::CharArray addr, pref;
+  bool prefix_specified;
+  int family;
 
-  if (rfb::strSplit(&p[1], '/', &addr.buf, &mask.buf)) {
-    if (rfb::strContains(mask.buf, '.')) {
-      pattern.mask = inet_addr(mask.buf);
+  prefix_specified = rfb::strSplit(&p[1], '/', &addr.buf, &pref.buf);
+  if (addr.buf[0] == '\0') {
+    // Match any address
+    memset (&pattern.address, 0, sizeof (pattern.address));
+    pattern.address.u.sa.sa_family = AF_UNSPEC;
+    pattern.prefixlen = 0;
+  } else {
+#ifdef HAVE_GETADDRINFO
+    struct addrinfo hints;
+    struct addrinfo *ai;
+    char *p = addr.buf;
+    int result;
+    memset (&hints, 0, sizeof (hints));
+    hints.ai_family = AF_UNSPEC;
+    hints.ai_flags = AI_NUMERICHOST;
+
+    // Take out brackets, if present
+    if (*p == '[') {
+      size_t len;
+      p++;
+      len = strlen (p);
+      if (len > 0 && p[len - 1] == ']')
+        p[len - 1] = '\0';
+    }
+
+    if ((result = getaddrinfo (p, NULL, &hints, &ai)) != 0) {
+      throw Exception("unable to resolve host by name: %s",
+                      gai_strerror(result));
+    }
+
+    memcpy (&pattern.address.u.sa, ai->ai_addr, ai->ai_addrlen);
+    freeaddrinfo (ai);
+#else
+    pattern.address.u.sa.sa_family = AF_INET;
+    pattern.address.u.sin.sin_addr.s_addr = inet_addr(addr.buf);
+#endif /* HAVE_GETADDRINFO */
+
+    family = pattern.address.u.sa.sa_family;
+
+    if (prefix_specified) {
+      if (family == AF_INET &&
+          rfb::strContains(pref.buf, '.')) {
+        throw Exception("mask no longer supported for filter, "
+                        "use prefix instead");
+      }
+
+      pattern.prefixlen = (unsigned int) atoi(pref.buf);
     } else {
-      pattern.mask = atoi(mask.buf);
-      expandMask = true;
+      switch (family) {
+      case AF_INET:
+        pattern.prefixlen = 32;
+        break;
+      case AF_INET6:
+        pattern.prefixlen = 128;
+        break;
+      default:
+        throw Exception("unknown address family");
+      }
     }
-  } else {
-    pattern.mask = 32;
-    expandMask = true;
-  }
-  if (expandMask) {
-    unsigned long expanded = 0;
-    // *** check endianness!
-    for (int i=0; i<(int)pattern.mask; i++)
-      expanded |= 1<<(31-i);
-    pattern.mask = htonl(expanded);
   }
 
-  pattern.address = inet_addr(addr.buf) & pattern.mask;
-  if ((pattern.address == INADDR_NONE) ||
-      (pattern.address == 0)) pattern.mask = 0;
+  if (pattern.prefixlen > (family == AF_INET ? 32: 128))
+    throw Exception("invalid prefix length for filter address: %u",
+                    pattern.prefixlen);
+
+  // Compute mask from address and prefix length
+  memset (&pattern.mask, 0, sizeof (pattern.mask));
+  switch (family) {
+    unsigned long mask;
+  case AF_INET:
+    mask = 0;
+    for (unsigned int i=0; i<pattern.prefixlen; i++)
+      mask |= 1<<(31-i);
+    pattern.mask.u.sin.sin_addr.s_addr = htonl(mask);
+    break;
+
+  case AF_INET6:
+    for (unsigned int n = 0; n < 16; n++) {
+      unsigned int bits = (n + 1) * 8;
+      if (pattern.prefixlen > bits)
+        pattern.mask.u.sin6.sin6_addr.s6_addr[n] = 0xff;
+      else {
+        unsigned int lastbits = 0xff;
+        lastbits <<= bits - pattern.prefixlen;
+        pattern.mask.u.sin6.sin6_addr.s6_addr[n] = lastbits & 0xff;
+        break;
+      }
+    }
+    break;
+  case AF_UNSPEC:
+    // No mask to compute
+    break;
+  default:
+    ; /* not reached */
+  }
 
   switch(p[0]) {
   case '+': pattern.action = TcpFilter::Accept; break;
@@ -790,21 +893,45 @@ TcpFilter::Pattern TcpFilter::parsePattern(const char* p) {
 }
 
 char* TcpFilter::patternToStr(const TcpFilter::Pattern& p) {
+  rfb::CharArray addr;
+#ifdef HAVE_GETADDRINFO
+  char buffer[INET6_ADDRSTRLEN + 2];
+
+  if (p.address.u.sa.sa_family == AF_INET) {
+    getnameinfo(&p.address.u.sa, sizeof(p.address.u.sin),
+                buffer, sizeof (buffer), NULL, 0, NI_NUMERICHOST);
+    addr.buf = rfb::strDup(buffer);
+  } else if (p.address.u.sa.sa_family == AF_INET6) {
+    buffer[0] = '[';
+    getnameinfo(&p.address.u.sa, sizeof(p.address.u.sin6),
+                buffer + 1, sizeof (buffer) - 2, NULL, 0, NI_NUMERICHOST);
+    strcat(buffer, "]");
+    addr.buf = rfb::strDup(buffer);
+  } else if (p.address.u.sa.sa_family == AF_UNSPEC)
+    addr.buf = rfb::strDup("");
+#else
   in_addr tmp;
-  rfb::CharArray addr, mask;
-  tmp.s_addr = p.address;
+  tmp.s_addr = p.address.u.sin.sin_addr.s_addr;
   addr.buf = rfb::strDup(inet_ntoa(tmp));
-  tmp.s_addr = p.mask;
-  mask.buf = rfb::strDup(inet_ntoa(tmp));
-  char* result = new char[strlen(addr.buf)+1+strlen(mask.buf)+1+1];
+#endif /* HAVE_GETADDRINFO */
+
+  char action;
   switch (p.action) {
-  case Accept: result[0] = '+'; break;
-  case Reject: result[0] = '-'; break;
-  case Query: result[0] = '?'; break;
+  case Accept: action = '+'; break;
+  case Reject: action = '-'; break;
+  default:
+  case Query: action = '?'; break;
   };
-  result[1] = 0;
-  strcat(result, addr.buf);
-  strcat(result, "/");
-  strcat(result, mask.buf);
+  size_t resultlen = (1                   // action
+                      + strlen (addr.buf) // address
+                      + 1                 // slash
+                      + 3                 // prefix length, max 128
+                      + 1);               // terminating nul
+  char* result = new char[resultlen];
+  if (addr.buf[0] == '\0')
+    snprintf(result, resultlen, "%c", action);
+  else
+    snprintf(result, resultlen, "%c%s/%u", action, addr.buf, p.prefixlen);
+
   return result;
 }
index e62738900c7f7eb3680baaad36bd3788808e0b05..29c09c69ba99754097fc67bcf9ab9ec248c84052 100644 (file)
 #ifndef __NETWORK_TCP_SOCKET_H__
 #define __NETWORK_TCP_SOCKET_H__
 
+#ifdef HAVE_CONFIG_H
+#include <config.h> /* for HAVE_GETADDRINFO */
+#endif
+
 #include <network/Socket.h>
-#include <sys/socket.h>
+#include <sys/socket.h> /* for socklen_t */
+#include <netinet/in.h> /* for struct sockaddr_in */
 
 #include <list>
 
@@ -85,6 +90,16 @@ namespace network {
                           const char *addr,
                           int port);
 
+  typedef struct vnc_sockaddr {
+    union {
+      sockaddr     sa;
+      sockaddr_in  sin;
+#ifdef HAVE_GETADDRINFO
+      sockaddr_in6 sin6;
+#endif
+    } u;
+  } vnc_sockaddr_t;
+
   class TcpFilter : public ConnectionFilter {
   public:
     TcpFilter(const char* filter);
@@ -95,8 +110,10 @@ namespace network {
     typedef enum {Accept, Reject, Query} Action;
     struct Pattern {
       Action action;
-      unsigned long address;
-      unsigned long mask;
+      vnc_sockaddr_t address;
+      unsigned int prefixlen;
+
+      vnc_sockaddr_t mask; // computed from address and prefix
     };
     static Pattern parsePattern(const char* s);
     static char* patternToStr(const Pattern& p);