]> source.dussan.org Git - rspamd.git/commitdiff
Rework rspamd DNS resolver.
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 10 Dec 2013 16:01:14 +0000 (16:01 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 10 Dec 2013 16:01:14 +0000 (16:01 +0000)
src/dns.c
src/dns.h
src/main.h
src/util.h
src/worker.c
src/worker_util.c

index a6e87aa3262d40edfc74c8a687071cbda941d056..996bc086e27495eb26fcd52e18477066a35605a3 100644 (file)
--- a/src/dns.c
+++ b/src/dns.c
@@ -52,14 +52,11 @@ static const unsigned initial_bias = 72;
 
 static const gint dns_port = 53;
 
+static void dns_read_cb (gint fd, short what, void *arg);
+static void dns_retransmit_handler (gint fd, short what, void *arg);
+
+#define DNS_RANDOM g_random_int
 
-#ifdef HAVE_ARC4RANDOM
-#define DNS_RANDOM arc4random
-#elif defined HAVE_RANDOM
-#define DNS_RANDOM random
-#else
-#define DNS_RANDOM rand
-#endif
 
 #define UDP_PACKET_SIZE 4096
 
@@ -209,8 +206,6 @@ punycode_label_toascii(const gunichar *in, gsize in_len, gchar *out,
 #define DNS_K_TEA_CYCLES       32
 #define DNS_K_TEA_MAGIC                0x9E3779B9U
 
-static void dns_retransmit_handler (gint fd, short what, void *arg);
-
 
 static void 
 dns_k_tea_init(struct dns_k_tea *tea, guint32 key[], guint cycles) 
@@ -434,6 +429,30 @@ dns_k_shuffle16 (guint16 n, guint s)
        return ((0xff00 & (a << 8)) | (0x00ff & (b << 0)));
 } /* dns_k_shuffle16() */
 
+static gint
+make_dns_socket (struct rspamd_dns_server *serv)
+{
+       gint sock;
+
+       sock = socket (serv->addr.sa.sa_family, SOCK_DGRAM, 0);
+       if (sock == -1) {
+               msg_warn ("socket failed: %d, '%s'", errno, strerror (errno));
+               return -1;
+       }
+
+       if (make_socket_nonblocking (sock) == -1) {
+               close (sock);
+               return -1;
+       }
+       if (fcntl (sock, F_SETFD, FD_CLOEXEC) == -1) {
+               msg_warn ("fcntl failed: %d, '%s'", errno, strerror (errno));
+               close (sock);
+               return -1;
+       }
+
+       return sock;
+}
+
 struct dns_request_key {
        guint16 id;
        guint16 port;
@@ -672,7 +691,7 @@ make_a_req (struct rspamd_dns_request *req, const gchar *name)
        *p = htons (DNS_C_IN);
        req->pos += sizeof (guint16) * 2;
        req->type = DNS_REQUEST_A;
-       req->requested_name = name;
+       req->requested_name = memory_pool_strdup (req->pool, name);
 }
 
 #ifdef HAVE_INET_PTON
@@ -689,7 +708,7 @@ make_aaa_req (struct rspamd_dns_request *req, const gchar *name)
        *p = htons (DNS_C_IN);
        req->pos += sizeof (guint16) * 2;
        req->type = DNS_REQUEST_AAA;
-       req->requested_name = name;
+       req->requested_name = memory_pool_strdup (req->pool, name);
 }
 #endif
 
@@ -706,7 +725,7 @@ make_txt_req (struct rspamd_dns_request *req, const gchar *name)
        *p = htons (DNS_C_IN);
        req->pos += sizeof (guint16) * 2;
        req->type = DNS_REQUEST_TXT;
-       req->requested_name = name;
+       req->requested_name = memory_pool_strdup (req->pool, name);
 }
 
 static void
@@ -722,7 +741,7 @@ make_mx_req (struct rspamd_dns_request *req, const gchar *name)
        *p = htons (DNS_C_IN);
        req->pos += sizeof (guint16) * 2;
        req->type = DNS_REQUEST_MX;
-       req->requested_name = name;
+       req->requested_name = memory_pool_strdup (req->pool, name);
 }
 
 static void
@@ -744,7 +763,7 @@ make_srv_req (struct rspamd_dns_request *req, const gchar *service, const gchar
        *p = htons (DNS_C_IN);
        req->pos += sizeof (guint16) * 2;
        req->type = DNS_REQUEST_SRV;
-       req->requested_name = name;
+       req->requested_name = memory_pool_strdup (req->pool, name);
 }
 
 static void
@@ -760,11 +779,11 @@ make_spf_req (struct rspamd_dns_request *req, const gchar *name)
        *p = htons (DNS_C_IN);
        req->pos += sizeof (guint16) * 2;
        req->type = DNS_REQUEST_SPF;
-       req->requested_name = name;
+       req->requested_name = memory_pool_strdup (req->pool, name);
 }
 
 static guint16
-rspamd_bind_to_random_port (int sock)
+rspamd_bind_to_random_port (int sock, int af)
 {
        union sa_union su;
        socklen_t slen = sizeof (su);
@@ -772,21 +791,21 @@ rspamd_bind_to_random_port (int sock)
        const int max_retries = 10;
        int retries = 0;
 
-       if (getsockname (sock, &su.sa, &slen) != -1) {
+       memset (&su, 0, sizeof (su));
+       su.sa.sa_family = af;
 
-               while (retries < max_retries) {
-                       ret = g_random_int_range (1024, G_MAXUINT16 - 1);
-                       if (su.sa.sa_family == AF_INET) {
-                               su.s4.sin_port = htons (ret);
-                       }
-                       else if (su.sa.sa_family == AF_INET6) {
-                               su.s6.sin6_port = htons (ret);
-                       }
-                       if (bind (sock, &su.sa, slen) != -1) {
-                               return ret;
-                       }
-                       retries ++;
+       while (retries < max_retries) {
+               ret = g_random_int_range (1024, G_MAXUINT16 - 1);
+               if (af == AF_INET) {
+                       su.s4.sin_port = htons (ret);
                }
+               else if (af == AF_INET6) {
+                       su.s6.sin6_port = htons (ret);
+               }
+               if (bind (sock, &su.sa, slen) != -1) {
+                       return ret;
+               }
+               retries ++;
        }
 
        return 0;
@@ -797,15 +816,14 @@ send_dns_request (struct rspamd_dns_request *req)
 {
        gint r;
 
-       req->port = rspamd_bind_to_random_port (req->sock);
-       req->key = ((guint32)req->port) << 16 + req->id;
-       r = send (req->sock, req->packet, req->pos, 0);
+       r = sendto (req->sock, req->packet, req->pos, 0, &req->server->addr.sa,
+                       req->server->addr.sa.sa_family == AF_INET ? sizeof (struct sockaddr_in) :
+                                       sizeof (struct sockaddr_in6));
        if (r == -1) {
                if (errno == EAGAIN) {
                        event_set (&req->io_event, req->sock, EV_WRITE, dns_retransmit_handler, req);
                        event_base_set (req->resolver->ev_base, &req->io_event);
                        event_add (&req->io_event, &req->tv);
-                       register_async_event (req->session, (event_finalizer_t)event_del, &req->io_event, g_quark_from_static_string ("dns resolver"));
                        return 0;
                } 
                else {
@@ -818,9 +836,13 @@ send_dns_request (struct rspamd_dns_request *req)
                event_set (&req->io_event, req->sock, EV_WRITE, dns_retransmit_handler, req);
                event_base_set (req->resolver->ev_base, &req->io_event);
                event_add (&req->io_event, &req->tv);
-               register_async_event (req->session, (event_finalizer_t)event_del, &req->io_event, g_quark_from_static_string ("dns resolver"));
                return 0;
        }
+       else {
+               event_set (&req->io_event, req->sock, EV_READ, dns_read_cb, req);
+               event_base_set (req->resolver->ev_base, &req->io_event);
+               event_add (&req->io_event, NULL);
+       }
        
        return 1;
 }
@@ -830,7 +852,11 @@ dns_fin_cb (gpointer arg)
 {
        struct rspamd_dns_request *req = arg;
        
+       if (req->sock != -1) {
+               close (req->sock);
+       }
        event_del (&req->timer_event);
+       event_del (&req->io_event);
        g_hash_table_remove (req->resolver->requests, &req->key);
 }
 
@@ -1219,11 +1245,10 @@ dns_parse_rr (guint8 *in, union rspamd_reply_element *elt, guint8 **pos, struct
 
 static gboolean
 dns_parse_reply (guint8 *in, gint r, struct rspamd_dns_resolver *resolver,
-               guint16 port, struct rspamd_dns_request **req_out,
+               guint16 port, struct rspamd_dns_request *req,
                struct rspamd_dns_reply **_rep)
 {
        struct dns_header *header = (struct dns_header *)in;
-       struct rspamd_dns_request      *req;
        struct rspamd_dns_reply        *rep;
        union rspamd_reply_element     *elt;
        guint8                         *pos;
@@ -1240,16 +1265,19 @@ dns_parse_reply (guint8 *in, gint r, struct rspamd_dns_resolver *resolver,
        /* Now try to find corresponding request */
        id = header->qid;
        key = ((guint32)port) << 16 + id;
-       if ((req = g_hash_table_lookup (resolver->requests, &key)) == NULL) {
+       if (g_hash_table_lookup (resolver->requests, &key) == NULL) {
                /* No such requests found */
+               msg_info ("corrupted DNS packet received, must have %d:%d, but %d:%d was expected",
+                               (gint)id, (gint)port, (gint)req->id, (gint)req->port);
                return FALSE;
        }
-       *req_out = req;
        /* 
         * Now we have request and query data is now at the end of header, so compare
         * request QR section and reply QR section
         */
-       if ((pos = dns_request_reply_cmp (req, in + sizeof (struct dns_header), r - sizeof (struct dns_header))) == NULL) {
+       if ((pos = dns_request_reply_cmp (req, in + sizeof (struct dns_header),
+                       r - sizeof (struct dns_header))) == NULL) {
+               msg_info ("query and answer differs, skip DNS packet");
                return FALSE;
        }
        /*
@@ -1311,8 +1339,7 @@ dns_check_throttling (struct rspamd_dns_resolver *resolver)
 static void
 dns_read_cb (gint fd, short what, void *arg)
 {
-       struct rspamd_dns_resolver     *resolver = arg;
-       struct rspamd_dns_request      *req = NULL;
+       struct rspamd_dns_request      *req = arg;
        gint                            r;
        struct rspamd_dns_reply        *rep;
        guint8                          in[UDP_PACKET_SIZE];
@@ -1323,24 +1350,20 @@ dns_read_cb (gint fd, short what, void *arg)
        /* This function is called each time when we have data on one of server's sockets */
        
        /* First read packet from socket */
-       r = recvfrom (fd, in, sizeof (in), 0, &su.sa, &slen);
-       if (r > (gint)(sizeof (struct dns_header) + sizeof (struct dns_query))) {
-               if (su.sa.sa_family == AF_INET) {
-                       port = ntohs (su.s4.sin_port);
-               }
-               else if (su.sa.sa_family == AF_INET6) {
-                       port = ntohs (su.s6.sin6_port);
-               }
-               if (dns_parse_reply (in, r, resolver, port, &req, &rep)) {
-                       /* Decrease errors count */
-                       if (rep->request->resolver->errors > 0) {
-                               rep->request->resolver->errors --;
+       if (what == EV_READ) {
+               r = recvfrom (fd, in, sizeof (in), 0, &su.sa, &slen);
+               if (r > (gint)(sizeof (struct dns_header) + sizeof (struct dns_query))) {
+                       if (dns_parse_reply (in, r, req->resolver, req->port, req, &rep)) {
+                               /* Decrease errors count */
+                               if (rep->request->resolver->errors > 0) {
+                                       rep->request->resolver->errors --;
+                               }
+                               upstream_ok (&rep->request->server->up, rep->request->time);
+                               rep->request->func (rep, rep->request->arg);
                        }
-                       upstream_ok (&rep->request->server->up, rep->request->time);
-                       rep->request->func (rep, rep->request->arg);
-                       remove_normal_event (req->session, dns_fin_cb, req);
                }
        }
+       remove_normal_event (req->session, dns_fin_cb, req);
 }
 
 static void
@@ -1353,7 +1376,8 @@ dns_timer_cb (gint fd, short what, void *arg)
        /* Retransmit dns request */
        req->retransmits ++;
        if (req->retransmits >= req->resolver->max_retransmits) {
-               msg_err ("maximum number of retransmits expired for resolving %s of type %s", req->requested_name, dns_strtype (req->type));
+               msg_err ("maximum number of retransmits expired for resolving %s of type %s",
+                               req->requested_name, dns_strtype (req->type));
                rep = memory_pool_alloc0 (req->pool, sizeof (struct rspamd_dns_reply));
                rep->request = req;
                rep->code = DNS_RC_SERVFAIL;
@@ -1387,11 +1411,7 @@ dns_timer_cb (gint fd, short what, void *arg)
                return;
        }
        
-       if (req->server->sock == -1) {
-               req->server->sock =  make_universal_socket (req->server->name,
-                               dns_port, SOCK_DGRAM, TRUE, FALSE, FALSE);
-       }
-       req->sock = req->server->sock;
+       req->sock = make_dns_socket (req->server);
 
        if (req->sock == -1) {
                rep = memory_pool_alloc0 (req->pool, sizeof (struct rspamd_dns_reply));
@@ -1549,11 +1569,7 @@ make_dns_request (struct rspamd_dns_resolver *resolver,
                return FALSE;
        }
        
-       if (req->server->sock == -1) {
-               req->server->sock =  make_universal_socket (req->server->name,
-                               dns_port, SOCK_DGRAM, TRUE, FALSE, FALSE);
-       }
-       req->sock = req->server->sock;
+       req->sock = make_dns_socket (req->server);
 
        if (req->sock == -1) {
                return FALSE;
@@ -1565,21 +1581,27 @@ make_dns_request (struct rspamd_dns_resolver *resolver,
        event_base_set (req->resolver->ev_base, &req->timer_event);
        
        /* Now send request to server */
+       req->id = dns_k_permutor_step (resolver->permutor);
+       req->port = rspamd_bind_to_random_port (req->sock, req->server->addr.sa.sa_family);
+       req->key = ((guint32)req->port) << 16 + req->id;
+       /* Add request to hash table */
+       header = (struct dns_header *)req->packet;
+       while (g_hash_table_lookup (resolver->requests, &req->key)) {
+               /* Check for unique id */
+               req->id = dns_k_permutor_step (resolver->permutor);
+               req->key ^= req->id;
+       }
+       header->qid = req->id;
        r = send_dns_request (req);
 
        if (r == 1) {
                /* Add timer event */
                evtimer_add (&req->timer_event, &req->tv);
 
-               /* Add request to hash table */
-               while (g_hash_table_lookup (resolver->requests, &req->key)) {
-                       /* Check for unique id */
-                       header = (struct dns_header *)req->packet;
-                       header->qid = dns_k_permutor_step (resolver->permutor);
-                       req->id = header->qid;
-               }
+
                g_hash_table_insert (resolver->requests, &req->key, req);
-               register_async_event (session, (event_finalizer_t)dns_fin_cb, req, g_quark_from_static_string ("dns resolver"));
+               register_async_event (session, (event_finalizer_t)dns_fin_cb, req,
+                               g_quark_from_static_string ("dns resolver"));
        }
        else if (r == -1) {
                return FALSE;
@@ -1594,7 +1616,7 @@ static gboolean
 parse_resolv_conf (struct rspamd_dns_resolver *resolver)
 {
        FILE                           *r;
-       gchar                           buf[BUFSIZ], *p, addr_holder[16];
+       gchar                           buf[BUFSIZ], *p;
        struct rspamd_dns_server       *new;
 
        r = fopen (RESOLV_CONF, "r");
@@ -1617,9 +1639,16 @@ parse_resolv_conf (struct rspamd_dns_resolver *resolver)
                                        continue;
                                }
                                else {
-                                       if (inet_pton (AF_INET6, p, addr_holder) == 1 ||
-                                                       inet_pton (AF_INET, p, addr_holder) == 1) {
-                                               new = &resolver->servers[resolver->servers_num];
+                                       new = &resolver->servers[resolver->servers_num];
+                                       if (inet_pton (AF_INET6, p, &new->addr.s6.sin6_addr) == 1) {
+                                               new->addr.sa.sa_family = AF_INET6;
+                                               new->addr.s6.sin6_port = htons (dns_port);
+                                               new->name = memory_pool_strdup (resolver->static_pool, p);
+                                               resolver->servers_num ++;
+                                       }
+                                       else if (inet_pton (AF_INET, p, &new->addr.s4.sin_addr) == 1) {
+                                               new->addr.sa.sa_family = AF_INET;
+                                               new->addr.s4.sin_port = htons (dns_port);
                                                new->name = memory_pool_strdup (resolver->static_pool, p);
                                                resolver->servers_num ++;
                                        }
@@ -1655,8 +1684,8 @@ dns_resolver_init (struct event_base *ev_base, struct config_file *cfg)
 {
        GList                          *cur;
        struct rspamd_dns_resolver     *new;
-       gchar                          *begin, *p, *err, addr_holder[16];
-       gint                            priority, i;
+       gchar                          *begin, *p, *err;
+       gint                            priority;
        struct rspamd_dns_server       *serv;
        
        new = memory_pool_alloc0 (cfg->cfg_pool, sizeof (struct rspamd_dns_resolver));
@@ -1711,8 +1740,16 @@ dns_resolver_init (struct event_base *ev_base, struct config_file *cfg)
                                priority = 0;
                        }
                        serv = &new->servers[new->servers_num];
-                       if (inet_pton (AF_INET6, p, addr_holder) == 1 ||
-                               inet_pton (AF_INET, p, addr_holder) == 1) {
+                       if (inet_pton (AF_INET6, p, &serv->addr.s6.sin6_addr) == 1) {
+                               serv->addr.sa.sa_family = AF_INET6;
+                               serv->addr.s6.sin6_port = htons (dns_port);
+                               serv->name = memory_pool_strdup (new->static_pool, begin);
+                               serv->up.priority = priority;
+                               new->servers_num ++;
+                       }
+                       else if (inet_pton (AF_INET, p, &serv->addr.s4.sin_addr) == 1) {
+                               serv->addr.sa.sa_family = AF_INET;
+                               serv->addr.s4.sin_port = htons (dns_port);
                                serv->name = memory_pool_strdup (new->static_pool, begin);
                                serv->up.priority = priority;
                                new->servers_num ++;
@@ -1734,20 +1771,6 @@ dns_resolver_init (struct event_base *ev_base, struct config_file *cfg)
                }
 
        }
-       /* Now init all servers */
-       for (i = 0; i < new->servers_num; i ++) {
-               serv = &new->servers[i];
-               serv->sock = make_universal_socket (serv->name, dns_port,
-                               SOCK_DGRAM, TRUE, FALSE, FALSE);
-               if (serv->sock == -1) {
-                       msg_warn ("cannot create socket to server %s", serv->name);
-               }
-               else {
-                       event_set (&serv->ev, serv->sock, EV_READ | EV_PERSIST, dns_read_cb, new);
-                       event_base_set (new->ev_base, &serv->ev);
-                       event_add (&serv->ev, NULL);
-               }
-       }
 
        return new;
 }
index 4889d271cd0a59307a7a350023e6b27d9c0fc4e9..6a6cda358a0221eb7928228a62faf29adef3bf7b 100644 (file)
--- a/src/dns.h
+++ b/src/dns.h
@@ -5,6 +5,7 @@
 #include "mem_pool.h"
 #include "events.h"
 #include "upstream.h"
+#include "util.h"
 
 #define MAX_SERVERS 16
 
@@ -23,7 +24,7 @@ typedef void (*dns_callback_type) (struct rspamd_dns_reply *reply, gpointer arg)
 struct rspamd_dns_server {
        struct upstream up;                                     /**< upstream structure                                         */
        gchar *name;                                                    /**< name of DNS server                                         */
-       gint sock;                                                      /**< persistent socket                                          */
+       union sa_union addr;                            /**< address storage                                            */
        struct event ev;
 };
 
index a104eb116f779ca215440ba18f8cf91c464a426b..e3a13ee650d31813b6a049a5cf973601235a11f0 100644 (file)
@@ -121,16 +121,6 @@ struct process_exception {
        gsize len;
 };
 
-/**
- * Union that would be used for storing sockaddrs
- */
-union sa_union {
-       struct sockaddr_storage ss;
-       struct sockaddr sa;
-       struct sockaddr_in s4;
-       struct sockaddr_in6 s6;
-};
-
 /**
  * Control session object
  */
index 4edf5cb9560a8aea196669cb23b47608cd87ba1a..edcfa612230bc6f1a404071e4e7c9eeac25f6053 100644 (file)
@@ -14,6 +14,17 @@ struct workq;
 struct statfile;
 struct classifier_config;
 
+/**
+ * Union that is used for storing sockaddrs
+ */
+union sa_union {
+       struct sockaddr_storage ss;
+       struct sockaddr sa;
+       struct sockaddr_in s4;
+       struct sockaddr_in6 s6;
+       struct sockaddr_un su;
+};
+
 /*
  * Create socket and bind or connect it to specified address and port
  */
index bb43afba8d5b193515e68af4e1d57ae2c9679b94..95355bdd3fde572516db207576cb139fd9ae8002 100644 (file)
@@ -502,8 +502,9 @@ accept_socket (gint fd, short what, void *arg)
        struct rspamd_worker_ctx       *ctx;
        union sa_union                  su;
        struct worker_task             *new_task;
+       char                            ip_str[INET6_ADDRSTRLEN + 1];
 
-       socklen_t                       addrlen = sizeof (su.ss);
+       socklen_t                       addrlen = sizeof (su);
        gint                            nfd;
 
        ctx = worker->ctx;
@@ -514,7 +515,7 @@ accept_socket (gint fd, short what, void *arg)
        }
 
        if ((nfd =
-                       accept_from_socket (fd, (struct sockaddr *) &su.ss, &addrlen)) == -1) {
+                       accept_from_socket (fd, &su.sa, &addrlen)) == -1) {
                msg_warn ("accept failed: %s", strerror (errno));
                return;
        }
@@ -525,16 +526,21 @@ accept_socket (gint fd, short what, void *arg)
 
        new_task = construct_task (worker);
 
-       if (su.ss.ss_family == AF_UNIX) {
+       if (su.sa.sa_family == AF_UNIX) {
                msg_info ("accepted connection from unix socket");
                new_task->client_addr.s_addr = INADDR_NONE;
        }
-       else if (su.ss.ss_family == AF_INET) {
+       else if (su.sa.sa_family == AF_INET) {
                msg_info ("accepted connection from %s port %d",
                                inet_ntoa (su.s4.sin_addr), ntohs (su.s4.sin_port));
                memcpy (&new_task->client_addr, &su.s4.sin_addr,
                                sizeof (struct in_addr));
        }
+       else if (su.sa.sa_family == AF_INET6) {
+               msg_info ("accepted connection from %s port %d",
+                               inet_ntop (su.sa.sa_family, &su.s6.sin6_addr, ip_str, sizeof (ip_str)),
+                               ntohs (su.s6.sin6_port));
+       }
 
        /* Copy some variables */
        new_task->sock = nfd;
index e8d8f7423a29b784979f6f956414a90d136aa0ff..599d098c99f3f79ddebc63ca07a6a2b8ab48969d 100644 (file)
@@ -152,7 +152,6 @@ free_task (struct worker_task *task, gboolean is_soft)
                if (task->received) {
                        g_list_free (task->received);
                }
-               memory_pool_delete (task->task_pool);
                if (task->dispatcher) {
                        if (is_soft) {
                                /* Plan dispatcher shutdown */
@@ -165,6 +164,7 @@ free_task (struct worker_task *task, gboolean is_soft)
                if (task->sock != -1) {
                        close (task->sock);
                }
+               memory_pool_delete (task->task_pool);
                g_slice_free1 (sizeof (struct worker_task), task);
        }
 }