From: Vsevolod Stakhov Date: Tue, 10 Dec 2013 16:01:14 +0000 (+0000) Subject: Rework rspamd DNS resolver. X-Git-Tag: 0.6.3~3 X-Git-Url: https://source.dussan.org/?a=commitdiff_plain;h=546d3478737b90f774d566b49b819b3d96bc121b;p=rspamd.git Rework rspamd DNS resolver. --- diff --git a/src/dns.c b/src/dns.c index a6e87aa32..996bc086e 100644 --- a/src/dns.c +++ b/src/dns.c @@ -52,14 +52,11 @@ static const unsigned initial_bias = 72; static const gint dns_port = 53; +static void dns_read_cb (gint fd, short what, void *arg); +static void dns_retransmit_handler (gint fd, short what, void *arg); + +#define DNS_RANDOM g_random_int -#ifdef HAVE_ARC4RANDOM -#define DNS_RANDOM arc4random -#elif defined HAVE_RANDOM -#define DNS_RANDOM random -#else -#define DNS_RANDOM rand -#endif #define UDP_PACKET_SIZE 4096 @@ -209,8 +206,6 @@ punycode_label_toascii(const gunichar *in, gsize in_len, gchar *out, #define DNS_K_TEA_CYCLES 32 #define DNS_K_TEA_MAGIC 0x9E3779B9U -static void dns_retransmit_handler (gint fd, short what, void *arg); - static void dns_k_tea_init(struct dns_k_tea *tea, guint32 key[], guint cycles) @@ -434,6 +429,30 @@ dns_k_shuffle16 (guint16 n, guint s) return ((0xff00 & (a << 8)) | (0x00ff & (b << 0))); } /* dns_k_shuffle16() */ +static gint +make_dns_socket (struct rspamd_dns_server *serv) +{ + gint sock; + + sock = socket (serv->addr.sa.sa_family, SOCK_DGRAM, 0); + if (sock == -1) { + msg_warn ("socket failed: %d, '%s'", errno, strerror (errno)); + return -1; + } + + if (make_socket_nonblocking (sock) == -1) { + close (sock); + return -1; + } + if (fcntl (sock, F_SETFD, FD_CLOEXEC) == -1) { + msg_warn ("fcntl failed: %d, '%s'", errno, strerror (errno)); + close (sock); + return -1; + } + + return sock; +} + struct dns_request_key { guint16 id; guint16 port; @@ -672,7 +691,7 @@ make_a_req (struct rspamd_dns_request *req, const gchar *name) *p = htons (DNS_C_IN); req->pos += sizeof (guint16) * 2; req->type = DNS_REQUEST_A; - req->requested_name = name; + req->requested_name = memory_pool_strdup (req->pool, name); } #ifdef HAVE_INET_PTON @@ -689,7 +708,7 @@ make_aaa_req (struct rspamd_dns_request *req, const gchar *name) *p = htons (DNS_C_IN); req->pos += sizeof (guint16) * 2; req->type = DNS_REQUEST_AAA; - req->requested_name = name; + req->requested_name = memory_pool_strdup (req->pool, name); } #endif @@ -706,7 +725,7 @@ make_txt_req (struct rspamd_dns_request *req, const gchar *name) *p = htons (DNS_C_IN); req->pos += sizeof (guint16) * 2; req->type = DNS_REQUEST_TXT; - req->requested_name = name; + req->requested_name = memory_pool_strdup (req->pool, name); } static void @@ -722,7 +741,7 @@ make_mx_req (struct rspamd_dns_request *req, const gchar *name) *p = htons (DNS_C_IN); req->pos += sizeof (guint16) * 2; req->type = DNS_REQUEST_MX; - req->requested_name = name; + req->requested_name = memory_pool_strdup (req->pool, name); } static void @@ -744,7 +763,7 @@ make_srv_req (struct rspamd_dns_request *req, const gchar *service, const gchar *p = htons (DNS_C_IN); req->pos += sizeof (guint16) * 2; req->type = DNS_REQUEST_SRV; - req->requested_name = name; + req->requested_name = memory_pool_strdup (req->pool, name); } static void @@ -760,11 +779,11 @@ make_spf_req (struct rspamd_dns_request *req, const gchar *name) *p = htons (DNS_C_IN); req->pos += sizeof (guint16) * 2; req->type = DNS_REQUEST_SPF; - req->requested_name = name; + req->requested_name = memory_pool_strdup (req->pool, name); } static guint16 -rspamd_bind_to_random_port (int sock) +rspamd_bind_to_random_port (int sock, int af) { union sa_union su; socklen_t slen = sizeof (su); @@ -772,21 +791,21 @@ rspamd_bind_to_random_port (int sock) const int max_retries = 10; int retries = 0; - if (getsockname (sock, &su.sa, &slen) != -1) { + memset (&su, 0, sizeof (su)); + su.sa.sa_family = af; - while (retries < max_retries) { - ret = g_random_int_range (1024, G_MAXUINT16 - 1); - if (su.sa.sa_family == AF_INET) { - su.s4.sin_port = htons (ret); - } - else if (su.sa.sa_family == AF_INET6) { - su.s6.sin6_port = htons (ret); - } - if (bind (sock, &su.sa, slen) != -1) { - return ret; - } - retries ++; + while (retries < max_retries) { + ret = g_random_int_range (1024, G_MAXUINT16 - 1); + if (af == AF_INET) { + su.s4.sin_port = htons (ret); } + else if (af == AF_INET6) { + su.s6.sin6_port = htons (ret); + } + if (bind (sock, &su.sa, slen) != -1) { + return ret; + } + retries ++; } return 0; @@ -797,15 +816,14 @@ send_dns_request (struct rspamd_dns_request *req) { gint r; - req->port = rspamd_bind_to_random_port (req->sock); - req->key = ((guint32)req->port) << 16 + req->id; - r = send (req->sock, req->packet, req->pos, 0); + r = sendto (req->sock, req->packet, req->pos, 0, &req->server->addr.sa, + req->server->addr.sa.sa_family == AF_INET ? sizeof (struct sockaddr_in) : + sizeof (struct sockaddr_in6)); if (r == -1) { if (errno == EAGAIN) { event_set (&req->io_event, req->sock, EV_WRITE, dns_retransmit_handler, req); event_base_set (req->resolver->ev_base, &req->io_event); event_add (&req->io_event, &req->tv); - register_async_event (req->session, (event_finalizer_t)event_del, &req->io_event, g_quark_from_static_string ("dns resolver")); return 0; } else { @@ -818,9 +836,13 @@ send_dns_request (struct rspamd_dns_request *req) event_set (&req->io_event, req->sock, EV_WRITE, dns_retransmit_handler, req); event_base_set (req->resolver->ev_base, &req->io_event); event_add (&req->io_event, &req->tv); - register_async_event (req->session, (event_finalizer_t)event_del, &req->io_event, g_quark_from_static_string ("dns resolver")); return 0; } + else { + event_set (&req->io_event, req->sock, EV_READ, dns_read_cb, req); + event_base_set (req->resolver->ev_base, &req->io_event); + event_add (&req->io_event, NULL); + } return 1; } @@ -830,7 +852,11 @@ dns_fin_cb (gpointer arg) { struct rspamd_dns_request *req = arg; + if (req->sock != -1) { + close (req->sock); + } event_del (&req->timer_event); + event_del (&req->io_event); g_hash_table_remove (req->resolver->requests, &req->key); } @@ -1219,11 +1245,10 @@ dns_parse_rr (guint8 *in, union rspamd_reply_element *elt, guint8 **pos, struct static gboolean dns_parse_reply (guint8 *in, gint r, struct rspamd_dns_resolver *resolver, - guint16 port, struct rspamd_dns_request **req_out, + guint16 port, struct rspamd_dns_request *req, struct rspamd_dns_reply **_rep) { struct dns_header *header = (struct dns_header *)in; - struct rspamd_dns_request *req; struct rspamd_dns_reply *rep; union rspamd_reply_element *elt; guint8 *pos; @@ -1240,16 +1265,19 @@ dns_parse_reply (guint8 *in, gint r, struct rspamd_dns_resolver *resolver, /* Now try to find corresponding request */ id = header->qid; key = ((guint32)port) << 16 + id; - if ((req = g_hash_table_lookup (resolver->requests, &key)) == NULL) { + if (g_hash_table_lookup (resolver->requests, &key) == NULL) { /* No such requests found */ + msg_info ("corrupted DNS packet received, must have %d:%d, but %d:%d was expected", + (gint)id, (gint)port, (gint)req->id, (gint)req->port); return FALSE; } - *req_out = req; /* * Now we have request and query data is now at the end of header, so compare * request QR section and reply QR section */ - if ((pos = dns_request_reply_cmp (req, in + sizeof (struct dns_header), r - sizeof (struct dns_header))) == NULL) { + if ((pos = dns_request_reply_cmp (req, in + sizeof (struct dns_header), + r - sizeof (struct dns_header))) == NULL) { + msg_info ("query and answer differs, skip DNS packet"); return FALSE; } /* @@ -1311,8 +1339,7 @@ dns_check_throttling (struct rspamd_dns_resolver *resolver) static void dns_read_cb (gint fd, short what, void *arg) { - struct rspamd_dns_resolver *resolver = arg; - struct rspamd_dns_request *req = NULL; + struct rspamd_dns_request *req = arg; gint r; struct rspamd_dns_reply *rep; guint8 in[UDP_PACKET_SIZE]; @@ -1323,24 +1350,20 @@ dns_read_cb (gint fd, short what, void *arg) /* This function is called each time when we have data on one of server's sockets */ /* First read packet from socket */ - r = recvfrom (fd, in, sizeof (in), 0, &su.sa, &slen); - if (r > (gint)(sizeof (struct dns_header) + sizeof (struct dns_query))) { - if (su.sa.sa_family == AF_INET) { - port = ntohs (su.s4.sin_port); - } - else if (su.sa.sa_family == AF_INET6) { - port = ntohs (su.s6.sin6_port); - } - if (dns_parse_reply (in, r, resolver, port, &req, &rep)) { - /* Decrease errors count */ - if (rep->request->resolver->errors > 0) { - rep->request->resolver->errors --; + if (what == EV_READ) { + r = recvfrom (fd, in, sizeof (in), 0, &su.sa, &slen); + if (r > (gint)(sizeof (struct dns_header) + sizeof (struct dns_query))) { + if (dns_parse_reply (in, r, req->resolver, req->port, req, &rep)) { + /* Decrease errors count */ + if (rep->request->resolver->errors > 0) { + rep->request->resolver->errors --; + } + upstream_ok (&rep->request->server->up, rep->request->time); + rep->request->func (rep, rep->request->arg); } - upstream_ok (&rep->request->server->up, rep->request->time); - rep->request->func (rep, rep->request->arg); - remove_normal_event (req->session, dns_fin_cb, req); } } + remove_normal_event (req->session, dns_fin_cb, req); } static void @@ -1353,7 +1376,8 @@ dns_timer_cb (gint fd, short what, void *arg) /* Retransmit dns request */ req->retransmits ++; if (req->retransmits >= req->resolver->max_retransmits) { - msg_err ("maximum number of retransmits expired for resolving %s of type %s", req->requested_name, dns_strtype (req->type)); + msg_err ("maximum number of retransmits expired for resolving %s of type %s", + req->requested_name, dns_strtype (req->type)); rep = memory_pool_alloc0 (req->pool, sizeof (struct rspamd_dns_reply)); rep->request = req; rep->code = DNS_RC_SERVFAIL; @@ -1387,11 +1411,7 @@ dns_timer_cb (gint fd, short what, void *arg) return; } - if (req->server->sock == -1) { - req->server->sock = make_universal_socket (req->server->name, - dns_port, SOCK_DGRAM, TRUE, FALSE, FALSE); - } - req->sock = req->server->sock; + req->sock = make_dns_socket (req->server); if (req->sock == -1) { rep = memory_pool_alloc0 (req->pool, sizeof (struct rspamd_dns_reply)); @@ -1549,11 +1569,7 @@ make_dns_request (struct rspamd_dns_resolver *resolver, return FALSE; } - if (req->server->sock == -1) { - req->server->sock = make_universal_socket (req->server->name, - dns_port, SOCK_DGRAM, TRUE, FALSE, FALSE); - } - req->sock = req->server->sock; + req->sock = make_dns_socket (req->server); if (req->sock == -1) { return FALSE; @@ -1565,21 +1581,27 @@ make_dns_request (struct rspamd_dns_resolver *resolver, event_base_set (req->resolver->ev_base, &req->timer_event); /* Now send request to server */ + req->id = dns_k_permutor_step (resolver->permutor); + req->port = rspamd_bind_to_random_port (req->sock, req->server->addr.sa.sa_family); + req->key = ((guint32)req->port) << 16 + req->id; + /* Add request to hash table */ + header = (struct dns_header *)req->packet; + while (g_hash_table_lookup (resolver->requests, &req->key)) { + /* Check for unique id */ + req->id = dns_k_permutor_step (resolver->permutor); + req->key ^= req->id; + } + header->qid = req->id; r = send_dns_request (req); if (r == 1) { /* Add timer event */ evtimer_add (&req->timer_event, &req->tv); - /* Add request to hash table */ - while (g_hash_table_lookup (resolver->requests, &req->key)) { - /* Check for unique id */ - header = (struct dns_header *)req->packet; - header->qid = dns_k_permutor_step (resolver->permutor); - req->id = header->qid; - } + g_hash_table_insert (resolver->requests, &req->key, req); - register_async_event (session, (event_finalizer_t)dns_fin_cb, req, g_quark_from_static_string ("dns resolver")); + register_async_event (session, (event_finalizer_t)dns_fin_cb, req, + g_quark_from_static_string ("dns resolver")); } else if (r == -1) { return FALSE; @@ -1594,7 +1616,7 @@ static gboolean parse_resolv_conf (struct rspamd_dns_resolver *resolver) { FILE *r; - gchar buf[BUFSIZ], *p, addr_holder[16]; + gchar buf[BUFSIZ], *p; struct rspamd_dns_server *new; r = fopen (RESOLV_CONF, "r"); @@ -1617,9 +1639,16 @@ parse_resolv_conf (struct rspamd_dns_resolver *resolver) continue; } else { - if (inet_pton (AF_INET6, p, addr_holder) == 1 || - inet_pton (AF_INET, p, addr_holder) == 1) { - new = &resolver->servers[resolver->servers_num]; + new = &resolver->servers[resolver->servers_num]; + if (inet_pton (AF_INET6, p, &new->addr.s6.sin6_addr) == 1) { + new->addr.sa.sa_family = AF_INET6; + new->addr.s6.sin6_port = htons (dns_port); + new->name = memory_pool_strdup (resolver->static_pool, p); + resolver->servers_num ++; + } + else if (inet_pton (AF_INET, p, &new->addr.s4.sin_addr) == 1) { + new->addr.sa.sa_family = AF_INET; + new->addr.s4.sin_port = htons (dns_port); new->name = memory_pool_strdup (resolver->static_pool, p); resolver->servers_num ++; } @@ -1655,8 +1684,8 @@ dns_resolver_init (struct event_base *ev_base, struct config_file *cfg) { GList *cur; struct rspamd_dns_resolver *new; - gchar *begin, *p, *err, addr_holder[16]; - gint priority, i; + gchar *begin, *p, *err; + gint priority; struct rspamd_dns_server *serv; new = memory_pool_alloc0 (cfg->cfg_pool, sizeof (struct rspamd_dns_resolver)); @@ -1711,8 +1740,16 @@ dns_resolver_init (struct event_base *ev_base, struct config_file *cfg) priority = 0; } serv = &new->servers[new->servers_num]; - if (inet_pton (AF_INET6, p, addr_holder) == 1 || - inet_pton (AF_INET, p, addr_holder) == 1) { + if (inet_pton (AF_INET6, p, &serv->addr.s6.sin6_addr) == 1) { + serv->addr.sa.sa_family = AF_INET6; + serv->addr.s6.sin6_port = htons (dns_port); + serv->name = memory_pool_strdup (new->static_pool, begin); + serv->up.priority = priority; + new->servers_num ++; + } + else if (inet_pton (AF_INET, p, &serv->addr.s4.sin_addr) == 1) { + serv->addr.sa.sa_family = AF_INET; + serv->addr.s4.sin_port = htons (dns_port); serv->name = memory_pool_strdup (new->static_pool, begin); serv->up.priority = priority; new->servers_num ++; @@ -1734,20 +1771,6 @@ dns_resolver_init (struct event_base *ev_base, struct config_file *cfg) } } - /* Now init all servers */ - for (i = 0; i < new->servers_num; i ++) { - serv = &new->servers[i]; - serv->sock = make_universal_socket (serv->name, dns_port, - SOCK_DGRAM, TRUE, FALSE, FALSE); - if (serv->sock == -1) { - msg_warn ("cannot create socket to server %s", serv->name); - } - else { - event_set (&serv->ev, serv->sock, EV_READ | EV_PERSIST, dns_read_cb, new); - event_base_set (new->ev_base, &serv->ev); - event_add (&serv->ev, NULL); - } - } return new; } diff --git a/src/dns.h b/src/dns.h index 4889d271c..6a6cda358 100644 --- a/src/dns.h +++ b/src/dns.h @@ -5,6 +5,7 @@ #include "mem_pool.h" #include "events.h" #include "upstream.h" +#include "util.h" #define MAX_SERVERS 16 @@ -23,7 +24,7 @@ typedef void (*dns_callback_type) (struct rspamd_dns_reply *reply, gpointer arg) struct rspamd_dns_server { struct upstream up; /**< upstream structure */ gchar *name; /**< name of DNS server */ - gint sock; /**< persistent socket */ + union sa_union addr; /**< address storage */ struct event ev; }; diff --git a/src/main.h b/src/main.h index a104eb116..e3a13ee65 100644 --- a/src/main.h +++ b/src/main.h @@ -121,16 +121,6 @@ struct process_exception { gsize len; }; -/** - * Union that would be used for storing sockaddrs - */ -union sa_union { - struct sockaddr_storage ss; - struct sockaddr sa; - struct sockaddr_in s4; - struct sockaddr_in6 s6; -}; - /** * Control session object */ diff --git a/src/util.h b/src/util.h index 4edf5cb95..edcfa6122 100644 --- a/src/util.h +++ b/src/util.h @@ -14,6 +14,17 @@ struct workq; struct statfile; struct classifier_config; +/** + * Union that is used for storing sockaddrs + */ +union sa_union { + struct sockaddr_storage ss; + struct sockaddr sa; + struct sockaddr_in s4; + struct sockaddr_in6 s6; + struct sockaddr_un su; +}; + /* * Create socket and bind or connect it to specified address and port */ diff --git a/src/worker.c b/src/worker.c index bb43afba8..95355bdd3 100644 --- a/src/worker.c +++ b/src/worker.c @@ -502,8 +502,9 @@ accept_socket (gint fd, short what, void *arg) struct rspamd_worker_ctx *ctx; union sa_union su; struct worker_task *new_task; + char ip_str[INET6_ADDRSTRLEN + 1]; - socklen_t addrlen = sizeof (su.ss); + socklen_t addrlen = sizeof (su); gint nfd; ctx = worker->ctx; @@ -514,7 +515,7 @@ accept_socket (gint fd, short what, void *arg) } if ((nfd = - accept_from_socket (fd, (struct sockaddr *) &su.ss, &addrlen)) == -1) { + accept_from_socket (fd, &su.sa, &addrlen)) == -1) { msg_warn ("accept failed: %s", strerror (errno)); return; } @@ -525,16 +526,21 @@ accept_socket (gint fd, short what, void *arg) new_task = construct_task (worker); - if (su.ss.ss_family == AF_UNIX) { + if (su.sa.sa_family == AF_UNIX) { msg_info ("accepted connection from unix socket"); new_task->client_addr.s_addr = INADDR_NONE; } - else if (su.ss.ss_family == AF_INET) { + else if (su.sa.sa_family == AF_INET) { msg_info ("accepted connection from %s port %d", inet_ntoa (su.s4.sin_addr), ntohs (su.s4.sin_port)); memcpy (&new_task->client_addr, &su.s4.sin_addr, sizeof (struct in_addr)); } + else if (su.sa.sa_family == AF_INET6) { + msg_info ("accepted connection from %s port %d", + inet_ntop (su.sa.sa_family, &su.s6.sin6_addr, ip_str, sizeof (ip_str)), + ntohs (su.s6.sin6_port)); + } /* Copy some variables */ new_task->sock = nfd; diff --git a/src/worker_util.c b/src/worker_util.c index e8d8f7423..599d098c9 100644 --- a/src/worker_util.c +++ b/src/worker_util.c @@ -152,7 +152,6 @@ free_task (struct worker_task *task, gboolean is_soft) if (task->received) { g_list_free (task->received); } - memory_pool_delete (task->task_pool); if (task->dispatcher) { if (is_soft) { /* Plan dispatcher shutdown */ @@ -165,6 +164,7 @@ free_task (struct worker_task *task, gboolean is_soft) if (task->sock != -1) { close (task->sock); } + memory_pool_delete (task->task_pool); g_slice_free1 (sizeof (struct worker_task), task); } }