]> source.dussan.org Git - rspamd.git/commitdiff
Use sockets pool for DNS requests.
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 18 Dec 2013 11:11:36 +0000 (11:11 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 18 Dec 2013 11:11:36 +0000 (11:11 +0000)
Inspired by: Vadim Goncharov

src/cfg_file.h
src/cfg_utils.c
src/dns.c
src/dns.h

index ed45d7d95adc78190590b1bb9e4ebf815dd627d6..41c37bd34439d4b9d5d8ae94607b97d5378cede4 100644 (file)
@@ -36,9 +36,6 @@
 #define DEFAULT_SCORE 10.0
 #define DEFAULT_REJECT_SCORE 999.0
 
-#define yyerror parse_err
-#define yywarn parse_warn
-
 struct expression;
 struct tokenizer;
 struct classifier;
@@ -379,6 +376,7 @@ struct config_file {
        guint32 dns_retransmits;                                                /**< maximum retransmits count                                                  */
        guint32 dns_throttling_errors;                                  /**< maximum errors for starting resolver throttling    */
        guint32 dns_throttling_time;                                    /**< time in seconds for DNS throttling                                 */
+       guint32 dns_io_per_server;                                              /**< number of sockets per DNS server                                   */
        GList *nameservers;                                                             /**< list of nameservers or NULL to parse resolv.conf   */
 };
 
index 181e44f51ff88a3a0c33daadbc008169943b3af8..d6c2b3fc99cfb17f92fc1929494ea0fa2f2c9f53 100644 (file)
@@ -233,6 +233,8 @@ init_defaults (struct config_file *cfg)
        /* After 20 errors do throttling for 10 seconds */
        cfg->dns_throttling_errors = 20;
        cfg->dns_throttling_time = 10000;
+       /* 16 sockets per DNS server */
+       cfg->dns_io_per_server = 16;
 
        cfg->statfile_sync_interval = 60000;
        cfg->statfile_sync_timeout = 20000;
index 9328c57a7c19023396d2c15626693be461de9c20..77063d9c1c06795f50af731da0abbf5531e173c0 100644 (file)
--- a/src/dns.c
+++ b/src/dns.c
@@ -33,6 +33,7 @@
 #include "config.h"
 #include "dns.h"
 #include "main.h"
+#include "utlist.h"
 #ifdef HAVE_OPENSSL
 #include <openssl/rand.h>
 #endif
@@ -767,6 +768,7 @@ static gint
 send_dns_request (struct rspamd_dns_request *req)
 {
        gint r;
+       struct rspamd_dns_server *serv = req->io->srv;
 
        r = send (req->sock, req->packet, req->pos, 0);
        if (r == -1) {
@@ -774,12 +776,13 @@ send_dns_request (struct rspamd_dns_request *req)
                        event_set (&req->io_event, req->sock, EV_WRITE, dns_retransmit_handler, req);
                        event_base_set (req->resolver->ev_base, &req->io_event);
                        event_add (&req->io_event, &req->tv);
-                       register_async_event (req->session, (event_finalizer_t)event_del, &req->io_event, g_quark_from_static_string ("dns resolver"));
+                       register_async_event (req->session, (event_finalizer_t)event_del, &req->io_event,
+                                       g_quark_from_static_string ("dns resolver"));
                        return 0;
                } 
                else {
-                       msg_err ("send failed: %s for server %s", strerror (errno), req->server->name);
-                       upstream_fail (&req->server->up, req->time);
+                       msg_err ("send failed: %s for server %s", strerror (errno), serv->name);
+                       upstream_fail (&serv->up, req->time);
                        return -1;
                }
        }
@@ -787,7 +790,8 @@ send_dns_request (struct rspamd_dns_request *req)
                event_set (&req->io_event, req->sock, EV_WRITE, dns_retransmit_handler, req);
                event_base_set (req->resolver->ev_base, &req->io_event);
                event_add (&req->io_event, &req->tv);
-               register_async_event (req->session, (event_finalizer_t)event_del, &req->io_event, g_quark_from_static_string ("dns resolver"));
+               register_async_event (req->session, (event_finalizer_t)event_del, &req->io_event,
+                               g_quark_from_static_string ("dns resolver"));
                return 0;
        }
        
@@ -800,7 +804,7 @@ dns_fin_cb (gpointer arg)
        struct rspamd_dns_request *req = arg;
        
        event_del (&req->timer_event);
-       g_hash_table_remove (req->resolver->requests, &req->id);
+       g_hash_table_remove (req->io->requests, &req->id);
 }
 
 static guint8 *
@@ -1187,12 +1191,13 @@ dns_parse_rr (guint8 *in, union rspamd_reply_element *elt, guint8 **pos, struct
 }
 
 static gboolean
-dns_parse_reply (guint8 *in, gint r, struct rspamd_dns_resolver *resolver,
+dns_parse_reply (gint sock, guint8 *in, gint r, struct rspamd_dns_resolver *resolver,
                struct rspamd_dns_request **req_out, struct rspamd_dns_reply **_rep)
 {
        struct dns_header *header = (struct dns_header *)in;
        struct rspamd_dns_request      *req;
        struct rspamd_dns_reply        *rep;
+       struct rspamd_dns_io_channel   *ioc;
        union rspamd_reply_element     *elt;
        guint8                         *pos;
        guint16                         id;
@@ -1204,9 +1209,15 @@ dns_parse_reply (guint8 *in, gint r, struct rspamd_dns_resolver *resolver,
                return FALSE;
        }
 
+       /* Find io channel */
+       if ((ioc = g_hash_table_lookup (resolver->io_channels, GINT_TO_POINTER (sock))) == NULL) {
+               msg_err ("io channel is not found for this resolver: %d", sock);
+               return FALSE;
+       }
+
        /* Now try to find corresponding request */
        id = header->qid;
-       if ((req = g_hash_table_lookup (resolver->requests, &id)) == NULL) {
+       if ((req = g_hash_table_lookup (ioc->requests, &id)) == NULL) {
                /* No such requests found */
                return FALSE;
        }
@@ -1215,7 +1226,8 @@ dns_parse_reply (guint8 *in, gint r, struct rspamd_dns_resolver *resolver,
         * Now we have request and query data is now at the end of header, so compare
         * request QR section and reply QR section
         */
-       if ((pos = dns_request_reply_cmp (req, in + sizeof (struct dns_header), r - sizeof (struct dns_header))) == NULL) {
+       if ((pos = dns_request_reply_cmp (req, in + sizeof (struct dns_header),
+                       r - sizeof (struct dns_header))) == NULL) {
                return FALSE;
        }
        /*
@@ -1288,12 +1300,12 @@ dns_read_cb (gint fd, short what, void *arg)
        /* First read packet from socket */
        r = read (fd, in, sizeof (in));
        if (r > (gint)(sizeof (struct dns_header) + sizeof (struct dns_query))) {
-               if (dns_parse_reply (in, r, resolver, &req, &rep)) {
+               if (dns_parse_reply (fd, in, r, resolver, &req, &rep)) {
                        /* Decrease errors count */
                        if (rep->request->resolver->errors > 0) {
                                rep->request->resolver->errors --;
                        }
-                       upstream_ok (&rep->request->server->up, rep->request->time);
+                       upstream_ok (&rep->request->io->srv->up, rep->request->time);
                        rep->request->func (rep, rep->request->arg);
                        remove_normal_event (req->session, dns_fin_cb, req);
                }
@@ -1305,73 +1317,66 @@ dns_timer_cb (gint fd, short what, void *arg)
 {
        struct rspamd_dns_request *req = arg;
        struct rspamd_dns_reply *rep;
+       struct rspamd_dns_server *serv;
        gint                            r;
        
        /* Retransmit dns request */
        req->retransmits ++;
+       serv = req->io->srv;
        if (req->retransmits >= req->resolver->max_retransmits) {
                msg_err ("maximum number of retransmits expired for resolving %s of type %s", req->requested_name, dns_strtype (req->type));
-               rep = memory_pool_alloc0 (req->pool, sizeof (struct rspamd_dns_reply));
-               rep->request = req;
-               rep->code = DNS_RC_SERVFAIL;
-               upstream_fail (&rep->request->server->up, rep->request->time);
                dns_check_throttling (req->resolver);
                req->resolver->errors ++;
-
-               req->func (rep, req->arg);
-               remove_normal_event (req->session, dns_fin_cb, req);
-
-               return;
+               goto err;
        }
        /* Select other server */
        if (req->resolver->is_master_slave) {
-               req->server = (struct rspamd_dns_server *)get_upstream_master_slave (req->resolver->servers,
+               serv = (struct rspamd_dns_server *)get_upstream_master_slave (req->resolver->servers,
                                        req->resolver->servers_num, sizeof (struct rspamd_dns_server),
                                        req->time, DEFAULT_UPSTREAM_ERROR_TIME, DEFAULT_UPSTREAM_DEAD_TIME, DEFAULT_UPSTREAM_MAXERRORS);
        }
        else {
-               req->server = (struct rspamd_dns_server *)get_upstream_round_robin (req->resolver->servers,
+               serv = (struct rspamd_dns_server *)get_upstream_round_robin (req->resolver->servers,
                        req->resolver->servers_num, sizeof (struct rspamd_dns_server),
                        req->time, DEFAULT_UPSTREAM_ERROR_TIME, DEFAULT_UPSTREAM_DEAD_TIME, DEFAULT_UPSTREAM_MAXERRORS);
        }
-       if (req->server == NULL) {
-               rep = memory_pool_alloc0 (req->pool, sizeof (struct rspamd_dns_reply));
-               rep->request = req;
-               rep->code = DNS_RC_SERVFAIL;
-
-               req->func (rep, req->arg);
-               remove_normal_event (req->session, dns_fin_cb, req);
-               return;
+       if (serv == NULL) {
+               goto err;
        }
+
+       req->io = serv->cur_io_channel;
+       if (req->io == NULL) {
+               msg_err ("cannot find suitable io channel for the server %s", serv->name);
+               goto err;
+       }
+       serv->cur_io_channel = serv->cur_io_channel->next;
        
-       if (req->server->sock == -1) {
-               req->server->sock =  make_universal_socket (req->server->name, dns_port, SOCK_DGRAM, TRUE, FALSE, FALSE);
+       if (req->io->sock == -1) {
+               req->io->sock =  make_universal_socket (serv->name, dns_port, SOCK_DGRAM, TRUE, FALSE, FALSE);
        }
-       req->sock = req->server->sock;
+       req->sock = req->io->sock;
 
        if (req->sock == -1) {
-               rep = memory_pool_alloc0 (req->pool, sizeof (struct rspamd_dns_reply));
-               rep->request = req;
-               rep->code = DNS_RC_SERVFAIL;
-               upstream_fail (&rep->request->server->up, rep->request->time);
-
-               req->func (rep, req->arg);
-               remove_normal_event (req->session, dns_fin_cb, req);
-
-               return;
+               goto err;
        }
        /* Add other retransmit event */
        r = send_dns_request (req);
        if (r == -1) {
-               rep = memory_pool_alloc0 (req->pool, sizeof (struct rspamd_dns_reply));
-               rep->request = req;
-               rep->code = DNS_RC_SERVFAIL;
-               upstream_fail (&rep->request->server->up, rep->request->time);
-               req->func (rep, req->arg);
-               remove_normal_event (req->session, dns_fin_cb, req);
-               return;
+               goto err;
        }
        evtimer_add (&req->timer_event, &req->tv);
+
+       return;
+err:
+       rep = memory_pool_alloc0 (req->pool, sizeof (struct rspamd_dns_reply));
+       rep->request = req;
+       rep->code = DNS_RC_SERVFAIL;
+       if (serv) {
+               upstream_fail (&serv->up, rep->request->time);
+       }
+       req->func (rep, req->arg);
+       remove_normal_event (req->session, dns_fin_cb, req);
+       return;
 }
 
 static void
@@ -1379,9 +1384,11 @@ dns_retransmit_handler (gint fd, short what, void *arg)
 {
        struct rspamd_dns_request *req = arg;
        struct rspamd_dns_reply *rep;
+       struct rspamd_dns_server *serv;
        gint r;
 
        remove_normal_event (req->session, (event_finalizer_t)event_del, &req->io_event);
+       serv = req->io->srv;
 
        if (what == EV_WRITE) {
                /* Retransmit dns request */
@@ -1392,7 +1399,7 @@ dns_retransmit_handler (gint fd, short what, void *arg)
                        rep = memory_pool_alloc0 (req->pool, sizeof (struct rspamd_dns_reply));
                        rep->request = req;
                        rep->code = DNS_RC_SERVFAIL;
-                       upstream_fail (&rep->request->server->up, rep->request->time);
+                       upstream_fail (&serv->up, rep->request->time);
                        req->resolver->errors ++;
                        dns_check_throttling (req->resolver);
 
@@ -1405,7 +1412,7 @@ dns_retransmit_handler (gint fd, short what, void *arg)
                        rep = memory_pool_alloc0 (req->pool, sizeof (struct rspamd_dns_reply));
                        rep->request = req;
                        rep->code = DNS_RC_SERVFAIL;
-                       upstream_fail (&rep->request->server->up, rep->request->time);
+                       upstream_fail (&serv->up, rep->request->time);
                        req->func (rep, req->arg);
 
                }
@@ -1417,8 +1424,9 @@ dns_retransmit_handler (gint fd, short what, void *arg)
                        evtimer_add (&req->timer_event, &req->tv);
 
                        /* Add request to hash table */
-                       g_hash_table_insert (req->resolver->requests, &req->id, req);
-                       register_async_event (req->session, (event_finalizer_t)dns_fin_cb, req, g_quark_from_static_string ("dns resolver"));
+                       g_hash_table_insert (req->io->requests, &req->id, req);
+                       register_async_event (req->session, (event_finalizer_t)dns_fin_cb,
+                                       req, g_quark_from_static_string ("dns resolver"));
                }
        }
 }
@@ -1430,9 +1438,11 @@ make_dns_request (struct rspamd_dns_resolver *resolver,
 {
        va_list args;
        struct rspamd_dns_request *req;
+       struct rspamd_dns_server *serv;
        struct in_addr *addr;
        const gchar *name, *service, *proto;
        gint r;
+       const gint max_id_cycles = 32;
        struct dns_header *header;
 
        /* If no DNS servers defined silently return FALSE */
@@ -1495,24 +1505,29 @@ make_dns_request (struct rspamd_dns_resolver *resolver,
        req->retransmits = 0;
        req->time = time (NULL);
        if (resolver->is_master_slave) {
-               req->server = (struct rspamd_dns_server *)get_upstream_master_slave (resolver->servers,
+               serv = (struct rspamd_dns_server *)get_upstream_master_slave (resolver->servers,
                                resolver->servers_num, sizeof (struct rspamd_dns_server),
                                req->time, DEFAULT_UPSTREAM_ERROR_TIME, DEFAULT_UPSTREAM_DEAD_TIME, DEFAULT_UPSTREAM_MAXERRORS);
        }
        else {
-               req->server = (struct rspamd_dns_server *)get_upstream_round_robin (resolver->servers,
+               serv = (struct rspamd_dns_server *)get_upstream_round_robin (resolver->servers,
                                resolver->servers_num, sizeof (struct rspamd_dns_server),
                                req->time, DEFAULT_UPSTREAM_ERROR_TIME, DEFAULT_UPSTREAM_DEAD_TIME, DEFAULT_UPSTREAM_MAXERRORS);
        }
-       if (req->server == NULL) {
+       if (serv == NULL) {
                msg_err ("cannot find suitable server for request");
                return FALSE;
        }
        
-       if (req->server->sock == -1) {
-               req->server->sock =  make_universal_socket (req->server->name, dns_port, SOCK_DGRAM, TRUE, FALSE, FALSE);
+       /* Now select IO channel */
+
+       req->io = serv->cur_io_channel;
+       if (req->io == NULL) {
+               msg_err ("cannot find suitable io channel for the server %s", serv->name);
+               return FALSE;
        }
-       req->sock = req->server->sock;
+       serv->cur_io_channel = serv->cur_io_channel->next;
+       req->sock = req->io->sock;
 
        if (req->sock == -1) {
                return FALSE;
@@ -1531,14 +1546,20 @@ make_dns_request (struct rspamd_dns_resolver *resolver,
                evtimer_add (&req->timer_event, &req->tv);
 
                /* Add request to hash table */
-               while (g_hash_table_lookup (resolver->requests, &req->id)) {
+               r = 0;
+               while (g_hash_table_lookup (req->io->requests, &req->id)) {
                        /* Check for unique id */
                        header = (struct dns_header *)req->packet;
                        header->qid = dns_k_permutor_step (resolver->permutor);
                        req->id = header->qid;
+                       if (++r > max_id_cycles) {
+                               msg_err ("cannot generate new id for server %s", serv->name);
+                               return FALSE;
+                       }
                }
-               g_hash_table_insert (resolver->requests, &req->id, req);
-               register_async_event (session, (event_finalizer_t)dns_fin_cb, req, g_quark_from_static_string ("dns resolver"));
+               g_hash_table_insert (req->io->requests, &req->id, req);
+               register_async_event (session, (event_finalizer_t)dns_fin_cb, req,
+                               g_quark_from_static_string ("dns resolver"));
        }
        else if (r == -1) {
                return FALSE;
@@ -1615,14 +1636,15 @@ dns_resolver_init (struct event_base *ev_base, struct config_file *cfg)
        GList                          *cur;
        struct rspamd_dns_resolver     *new;
        gchar                          *begin, *p, *err, addr_holder[16];
-       gint                            priority, i;
+       gint                            priority, i, j;
        struct rspamd_dns_server       *serv;
+       struct rspamd_dns_io_channel   *ioc;
        
        new = memory_pool_alloc0 (cfg->cfg_pool, sizeof (struct rspamd_dns_resolver));
        new->ev_base = ev_base;
-       new->requests = g_hash_table_new (dns_id_hash, dns_id_equal);
        new->permutor = memory_pool_alloc (cfg->cfg_pool, sizeof (struct dns_k_permutor));
        dns_k_permutor_init (new->permutor, 0, G_MAXUINT16);
+       new->io_channels = g_hash_table_new (g_direct_hash, g_direct_equal);
        new->static_pool = cfg->cfg_pool;
        new->request_timeout = cfg->dns_timeout;
        new->max_retransmits = cfg->dns_retransmits;
@@ -1693,17 +1715,28 @@ dns_resolver_init (struct event_base *ev_base, struct config_file *cfg)
                }
 
        }
-       /* Now init all servers */
+       /* Now init io channels to all servers */
        for (i = 0; i < new->servers_num; i ++) {
                serv = &new->servers[i];
-               serv->sock = make_universal_socket (serv->name, dns_port, SOCK_DGRAM, TRUE, FALSE, FALSE);
-               if (serv->sock == -1) {
-                       msg_warn ("cannot create socket to server %s", serv->name);
-               }
-               else {
-                       event_set (&serv->ev, serv->sock, EV_READ | EV_PERSIST, dns_read_cb, new);
-                       event_base_set (new->ev_base, &serv->ev);
-                       event_add (&serv->ev, NULL);
+               for (j = 0; j < (gint)cfg->dns_io_per_server; j ++) {
+                       ioc = memory_pool_alloc (new->static_pool, sizeof (struct rspamd_dns_io_channel));
+                       ioc->sock = make_universal_socket (serv->name, dns_port, SOCK_DGRAM, TRUE, FALSE, FALSE);
+                       if (ioc->sock == -1) {
+                               msg_warn ("cannot create socket to server %s", serv->name);
+                       }
+                       else {
+                               ioc->requests = g_hash_table_new (dns_id_hash, dns_id_equal);
+                               memory_pool_add_destructor (new->static_pool, (pool_destruct_func)g_hash_table_unref,
+                                               ioc->requests);
+                               ioc->srv = serv;
+                               ioc->resolver = new;
+                               event_set (&ioc->ev, ioc->sock, EV_READ | EV_PERSIST, dns_read_cb, new);
+                               event_base_set (new->ev_base, &ioc->ev);
+                               event_add (&ioc->ev, NULL);
+                               CDL_PREPEND (serv->io_channels, ioc);
+                               serv->cur_io_channel = ioc;
+                               g_hash_table_insert (new->io_channels, GINT_TO_POINTER (ioc->sock), ioc);
+                       }
                }
        }
 
index b04aa285f33c3a8727a648e54c89c060dd25c229..0b701391b77c9faa18b3bfc031023cfbed776836 100644 (file)
--- a/src/dns.h
+++ b/src/dns.h
@@ -17,14 +17,27 @@ struct rspamd_dns_reply;
 struct config_file;
 
 typedef void (*dns_callback_type) (struct rspamd_dns_reply *reply, gpointer arg);
+
 /**
- * Implements DNS server
+ * Represents DNS server
  */
 struct rspamd_dns_server {
        struct upstream up;                                     /**< upstream structure                                         */
        gchar *name;                                                    /**< name of DNS server                                         */
+       struct rspamd_dns_io_channel *io_channels;
+       struct rspamd_dns_io_channel *cur_io_channel;
+};
+
+/**
+ * IO channel for a specific DNS server
+ */
+struct rspamd_dns_io_channel {
+       struct rspamd_dns_server *srv;
+       struct rspamd_dns_resolver *resolver;
        gint sock;                                                      /**< persistent socket                                          */
        struct event ev;
+       GHashTable *requests;                           /**< requests in flight                                         */
+       struct rspamd_dns_io_channel *prev, *next;
 };
 
 #define DNS_K_TEA_KEY_SIZE     16
@@ -44,11 +57,11 @@ struct dns_k_permutor {
 struct rspamd_dns_resolver {
        struct rspamd_dns_server servers[MAX_SERVERS];
        gint servers_num;                                       /**< number of DNS servers registered           */
-       GHashTable *requests;                           /**< requests in flight                                         */
        struct dns_k_permutor *permutor;        /**< permutor for randomizing request id        */
        guint request_timeout;
        guint max_retransmits;
        guint max_errors;
+       GHashTable *io_channels;                        /**< hash of io chains indexed by socket        */
        memory_pool_t *static_pool;                     /**< permament pool (cfg_pool)                          */
        gboolean throttling;                            /**< dns servers are busy                                       */
        gboolean is_master_slave;                       /**< if this is true, then select upstreams as master/slave */
@@ -74,7 +87,7 @@ enum rspamd_request_type {
 struct rspamd_dns_request {
        memory_pool_t *pool;                            /**< pool associated with request                       */
        struct rspamd_dns_resolver *resolver;
-       struct rspamd_dns_server *server;
+       struct rspamd_dns_io_channel *io;
        dns_callback_type func;
        gpointer arg;
        struct event timer_event;
@@ -234,12 +247,12 @@ struct dns_query {
 
 /* Rspamd DNS API */
 
-/*
+/**
  * Init DNS resolver, params are obtained from a config file or system file /etc/resolv.conf
  */
 struct rspamd_dns_resolver *dns_resolver_init (struct event_base *ev_base, struct config_file *cfg);
 
-/*
+/**
  * Make a DNS request
  * @param resolver resolver object
  * @param session async session to register event
@@ -254,12 +267,12 @@ gboolean make_dns_request (struct rspamd_dns_resolver *resolver,
                struct rspamd_async_session *session, memory_pool_t *pool, dns_callback_type cb, 
                gpointer ud, enum rspamd_request_type type, ...);
 
-/*
+/**
  * Get textual presentation of DNS error code
  */
 const gchar *dns_strerror (enum dns_rcode rcode);
 
-/*
+/**
  * Get textual presentation of DNS request type
  */
 const gchar *dns_strtype (enum rspamd_request_type type);