aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2013-12-10 16:01:14 +0000
committerVsevolod Stakhov <vsevolod@highsecure.ru>2013-12-10 16:01:14 +0000
commit546d3478737b90f774d566b49b819b3d96bc121b (patch)
treef5e1685cd459f5ec73fb87fbee907b72bd46f55b /src
parent3e35fe3a949f1d79a262e9df7ffbcb5232f99365 (diff)
downloadrspamd-546d3478737b90f774d566b49b819b3d96bc121b.tar.gz
rspamd-546d3478737b90f774d566b49b819b3d96bc121b.zip
Rework rspamd DNS resolver.
Diffstat (limited to 'src')
-rw-r--r--src/dns.c217
-rw-r--r--src/dns.h3
-rw-r--r--src/main.h10
-rw-r--r--src/util.h11
-rw-r--r--src/worker.c14
-rw-r--r--src/worker_util.c2
6 files changed, 144 insertions, 113 deletions
diff --git a/src/dns.c b/src/dns.c
index a6e87aa32..996bc086e 100644
--- a/src/dns.c
+++ b/src/dns.c
@@ -52,14 +52,11 @@ static const unsigned initial_bias = 72;
static const gint dns_port = 53;
+static void dns_read_cb (gint fd, short what, void *arg);
+static void dns_retransmit_handler (gint fd, short what, void *arg);
+
+#define DNS_RANDOM g_random_int
-#ifdef HAVE_ARC4RANDOM
-#define DNS_RANDOM arc4random
-#elif defined HAVE_RANDOM
-#define DNS_RANDOM random
-#else
-#define DNS_RANDOM rand
-#endif
#define UDP_PACKET_SIZE 4096
@@ -209,8 +206,6 @@ punycode_label_toascii(const gunichar *in, gsize in_len, gchar *out,
#define DNS_K_TEA_CYCLES 32
#define DNS_K_TEA_MAGIC 0x9E3779B9U
-static void dns_retransmit_handler (gint fd, short what, void *arg);
-
static void
dns_k_tea_init(struct dns_k_tea *tea, guint32 key[], guint cycles)
@@ -434,6 +429,30 @@ dns_k_shuffle16 (guint16 n, guint s)
return ((0xff00 & (a << 8)) | (0x00ff & (b << 0)));
} /* dns_k_shuffle16() */
+static gint
+make_dns_socket (struct rspamd_dns_server *serv)
+{
+ gint sock;
+
+ sock = socket (serv->addr.sa.sa_family, SOCK_DGRAM, 0);
+ if (sock == -1) {
+ msg_warn ("socket failed: %d, '%s'", errno, strerror (errno));
+ return -1;
+ }
+
+ if (make_socket_nonblocking (sock) == -1) {
+ close (sock);
+ return -1;
+ }
+ if (fcntl (sock, F_SETFD, FD_CLOEXEC) == -1) {
+ msg_warn ("fcntl failed: %d, '%s'", errno, strerror (errno));
+ close (sock);
+ return -1;
+ }
+
+ return sock;
+}
+
struct dns_request_key {
guint16 id;
guint16 port;
@@ -672,7 +691,7 @@ make_a_req (struct rspamd_dns_request *req, const gchar *name)
*p = htons (DNS_C_IN);
req->pos += sizeof (guint16) * 2;
req->type = DNS_REQUEST_A;
- req->requested_name = name;
+ req->requested_name = memory_pool_strdup (req->pool, name);
}
#ifdef HAVE_INET_PTON
@@ -689,7 +708,7 @@ make_aaa_req (struct rspamd_dns_request *req, const gchar *name)
*p = htons (DNS_C_IN);
req->pos += sizeof (guint16) * 2;
req->type = DNS_REQUEST_AAA;
- req->requested_name = name;
+ req->requested_name = memory_pool_strdup (req->pool, name);
}
#endif
@@ -706,7 +725,7 @@ make_txt_req (struct rspamd_dns_request *req, const gchar *name)
*p = htons (DNS_C_IN);
req->pos += sizeof (guint16) * 2;
req->type = DNS_REQUEST_TXT;
- req->requested_name = name;
+ req->requested_name = memory_pool_strdup (req->pool, name);
}
static void
@@ -722,7 +741,7 @@ make_mx_req (struct rspamd_dns_request *req, const gchar *name)
*p = htons (DNS_C_IN);
req->pos += sizeof (guint16) * 2;
req->type = DNS_REQUEST_MX;
- req->requested_name = name;
+ req->requested_name = memory_pool_strdup (req->pool, name);
}
static void
@@ -744,7 +763,7 @@ make_srv_req (struct rspamd_dns_request *req, const gchar *service, const gchar
*p = htons (DNS_C_IN);
req->pos += sizeof (guint16) * 2;
req->type = DNS_REQUEST_SRV;
- req->requested_name = name;
+ req->requested_name = memory_pool_strdup (req->pool, name);
}
static void
@@ -760,11 +779,11 @@ make_spf_req (struct rspamd_dns_request *req, const gchar *name)
*p = htons (DNS_C_IN);
req->pos += sizeof (guint16) * 2;
req->type = DNS_REQUEST_SPF;
- req->requested_name = name;
+ req->requested_name = memory_pool_strdup (req->pool, name);
}
static guint16
-rspamd_bind_to_random_port (int sock)
+rspamd_bind_to_random_port (int sock, int af)
{
union sa_union su;
socklen_t slen = sizeof (su);
@@ -772,21 +791,21 @@ rspamd_bind_to_random_port (int sock)
const int max_retries = 10;
int retries = 0;
- if (getsockname (sock, &su.sa, &slen) != -1) {
+ memset (&su, 0, sizeof (su));
+ su.sa.sa_family = af;
- while (retries < max_retries) {
- ret = g_random_int_range (1024, G_MAXUINT16 - 1);
- if (su.sa.sa_family == AF_INET) {
- su.s4.sin_port = htons (ret);
- }
- else if (su.sa.sa_family == AF_INET6) {
- su.s6.sin6_port = htons (ret);
- }
- if (bind (sock, &su.sa, slen) != -1) {
- return ret;
- }
- retries ++;
+ while (retries < max_retries) {
+ ret = g_random_int_range (1024, G_MAXUINT16 - 1);
+ if (af == AF_INET) {
+ su.s4.sin_port = htons (ret);
}
+ else if (af == AF_INET6) {
+ su.s6.sin6_port = htons (ret);
+ }
+ if (bind (sock, &su.sa, slen) != -1) {
+ return ret;
+ }
+ retries ++;
}
return 0;
@@ -797,15 +816,14 @@ send_dns_request (struct rspamd_dns_request *req)
{
gint r;
- req->port = rspamd_bind_to_random_port (req->sock);
- req->key = ((guint32)req->port) << 16 + req->id;
- r = send (req->sock, req->packet, req->pos, 0);
+ r = sendto (req->sock, req->packet, req->pos, 0, &req->server->addr.sa,
+ req->server->addr.sa.sa_family == AF_INET ? sizeof (struct sockaddr_in) :
+ sizeof (struct sockaddr_in6));
if (r == -1) {
if (errno == EAGAIN) {
event_set (&req->io_event, req->sock, EV_WRITE, dns_retransmit_handler, req);
event_base_set (req->resolver->ev_base, &req->io_event);
event_add (&req->io_event, &req->tv);
- register_async_event (req->session, (event_finalizer_t)event_del, &req->io_event, g_quark_from_static_string ("dns resolver"));
return 0;
}
else {
@@ -818,9 +836,13 @@ send_dns_request (struct rspamd_dns_request *req)
event_set (&req->io_event, req->sock, EV_WRITE, dns_retransmit_handler, req);
event_base_set (req->resolver->ev_base, &req->io_event);
event_add (&req->io_event, &req->tv);
- register_async_event (req->session, (event_finalizer_t)event_del, &req->io_event, g_quark_from_static_string ("dns resolver"));
return 0;
}
+ else {
+ event_set (&req->io_event, req->sock, EV_READ, dns_read_cb, req);
+ event_base_set (req->resolver->ev_base, &req->io_event);
+ event_add (&req->io_event, NULL);
+ }
return 1;
}
@@ -830,7 +852,11 @@ dns_fin_cb (gpointer arg)
{
struct rspamd_dns_request *req = arg;
+ if (req->sock != -1) {
+ close (req->sock);
+ }
event_del (&req->timer_event);
+ event_del (&req->io_event);
g_hash_table_remove (req->resolver->requests, &req->key);
}
@@ -1219,11 +1245,10 @@ dns_parse_rr (guint8 *in, union rspamd_reply_element *elt, guint8 **pos, struct
static gboolean
dns_parse_reply (guint8 *in, gint r, struct rspamd_dns_resolver *resolver,
- guint16 port, struct rspamd_dns_request **req_out,
+ guint16 port, struct rspamd_dns_request *req,
struct rspamd_dns_reply **_rep)
{
struct dns_header *header = (struct dns_header *)in;
- struct rspamd_dns_request *req;
struct rspamd_dns_reply *rep;
union rspamd_reply_element *elt;
guint8 *pos;
@@ -1240,16 +1265,19 @@ dns_parse_reply (guint8 *in, gint r, struct rspamd_dns_resolver *resolver,
/* Now try to find corresponding request */
id = header->qid;
key = ((guint32)port) << 16 + id;
- if ((req = g_hash_table_lookup (resolver->requests, &key)) == NULL) {
+ if (g_hash_table_lookup (resolver->requests, &key) == NULL) {
/* No such requests found */
+ msg_info ("corrupted DNS packet received, must have %d:%d, but %d:%d was expected",
+ (gint)id, (gint)port, (gint)req->id, (gint)req->port);
return FALSE;
}
- *req_out = req;
/*
* Now we have request and query data is now at the end of header, so compare
* request QR section and reply QR section
*/
- if ((pos = dns_request_reply_cmp (req, in + sizeof (struct dns_header), r - sizeof (struct dns_header))) == NULL) {
+ if ((pos = dns_request_reply_cmp (req, in + sizeof (struct dns_header),
+ r - sizeof (struct dns_header))) == NULL) {
+ msg_info ("query and answer differs, skip DNS packet");
return FALSE;
}
/*
@@ -1311,8 +1339,7 @@ dns_check_throttling (struct rspamd_dns_resolver *resolver)
static void
dns_read_cb (gint fd, short what, void *arg)
{
- struct rspamd_dns_resolver *resolver = arg;
- struct rspamd_dns_request *req = NULL;
+ struct rspamd_dns_request *req = arg;
gint r;
struct rspamd_dns_reply *rep;
guint8 in[UDP_PACKET_SIZE];
@@ -1323,24 +1350,20 @@ dns_read_cb (gint fd, short what, void *arg)
/* This function is called each time when we have data on one of server's sockets */
/* First read packet from socket */
- r = recvfrom (fd, in, sizeof (in), 0, &su.sa, &slen);
- if (r > (gint)(sizeof (struct dns_header) + sizeof (struct dns_query))) {
- if (su.sa.sa_family == AF_INET) {
- port = ntohs (su.s4.sin_port);
- }
- else if (su.sa.sa_family == AF_INET6) {
- port = ntohs (su.s6.sin6_port);
- }
- if (dns_parse_reply (in, r, resolver, port, &req, &rep)) {
- /* Decrease errors count */
- if (rep->request->resolver->errors > 0) {
- rep->request->resolver->errors --;
+ if (what == EV_READ) {
+ r = recvfrom (fd, in, sizeof (in), 0, &su.sa, &slen);
+ if (r > (gint)(sizeof (struct dns_header) + sizeof (struct dns_query))) {
+ if (dns_parse_reply (in, r, req->resolver, req->port, req, &rep)) {
+ /* Decrease errors count */
+ if (rep->request->resolver->errors > 0) {
+ rep->request->resolver->errors --;
+ }
+ upstream_ok (&rep->request->server->up, rep->request->time);
+ rep->request->func (rep, rep->request->arg);
}
- upstream_ok (&rep->request->server->up, rep->request->time);
- rep->request->func (rep, rep->request->arg);
- remove_normal_event (req->session, dns_fin_cb, req);
}
}
+ remove_normal_event (req->session, dns_fin_cb, req);
}
static void
@@ -1353,7 +1376,8 @@ dns_timer_cb (gint fd, short what, void *arg)
/* Retransmit dns request */
req->retransmits ++;
if (req->retransmits >= req->resolver->max_retransmits) {
- msg_err ("maximum number of retransmits expired for resolving %s of type %s", req->requested_name, dns_strtype (req->type));
+ msg_err ("maximum number of retransmits expired for resolving %s of type %s",
+ req->requested_name, dns_strtype (req->type));
rep = memory_pool_alloc0 (req->pool, sizeof (struct rspamd_dns_reply));
rep->request = req;
rep->code = DNS_RC_SERVFAIL;
@@ -1387,11 +1411,7 @@ dns_timer_cb (gint fd, short what, void *arg)
return;
}
- if (req->server->sock == -1) {
- req->server->sock = make_universal_socket (req->server->name,
- dns_port, SOCK_DGRAM, TRUE, FALSE, FALSE);
- }
- req->sock = req->server->sock;
+ req->sock = make_dns_socket (req->server);
if (req->sock == -1) {
rep = memory_pool_alloc0 (req->pool, sizeof (struct rspamd_dns_reply));
@@ -1549,11 +1569,7 @@ make_dns_request (struct rspamd_dns_resolver *resolver,
return FALSE;
}
- if (req->server->sock == -1) {
- req->server->sock = make_universal_socket (req->server->name,
- dns_port, SOCK_DGRAM, TRUE, FALSE, FALSE);
- }
- req->sock = req->server->sock;
+ req->sock = make_dns_socket (req->server);
if (req->sock == -1) {
return FALSE;
@@ -1565,21 +1581,27 @@ make_dns_request (struct rspamd_dns_resolver *resolver,
event_base_set (req->resolver->ev_base, &req->timer_event);
/* Now send request to server */
+ req->id = dns_k_permutor_step (resolver->permutor);
+ req->port = rspamd_bind_to_random_port (req->sock, req->server->addr.sa.sa_family);
+ req->key = ((guint32)req->port) << 16 + req->id;
+ /* Add request to hash table */
+ header = (struct dns_header *)req->packet;
+ while (g_hash_table_lookup (resolver->requests, &req->key)) {
+ /* Check for unique id */
+ req->id = dns_k_permutor_step (resolver->permutor);
+ req->key ^= req->id;
+ }
+ header->qid = req->id;
r = send_dns_request (req);
if (r == 1) {
/* Add timer event */
evtimer_add (&req->timer_event, &req->tv);
- /* Add request to hash table */
- while (g_hash_table_lookup (resolver->requests, &req->key)) {
- /* Check for unique id */
- header = (struct dns_header *)req->packet;
- header->qid = dns_k_permutor_step (resolver->permutor);
- req->id = header->qid;
- }
+
g_hash_table_insert (resolver->requests, &req->key, req);
- register_async_event (session, (event_finalizer_t)dns_fin_cb, req, g_quark_from_static_string ("dns resolver"));
+ register_async_event (session, (event_finalizer_t)dns_fin_cb, req,
+ g_quark_from_static_string ("dns resolver"));
}
else if (r == -1) {
return FALSE;
@@ -1594,7 +1616,7 @@ static gboolean
parse_resolv_conf (struct rspamd_dns_resolver *resolver)
{
FILE *r;
- gchar buf[BUFSIZ], *p, addr_holder[16];
+ gchar buf[BUFSIZ], *p;
struct rspamd_dns_server *new;
r = fopen (RESOLV_CONF, "r");
@@ -1617,9 +1639,16 @@ parse_resolv_conf (struct rspamd_dns_resolver *resolver)
continue;
}
else {
- if (inet_pton (AF_INET6, p, addr_holder) == 1 ||
- inet_pton (AF_INET, p, addr_holder) == 1) {
- new = &resolver->servers[resolver->servers_num];
+ new = &resolver->servers[resolver->servers_num];
+ if (inet_pton (AF_INET6, p, &new->addr.s6.sin6_addr) == 1) {
+ new->addr.sa.sa_family = AF_INET6;
+ new->addr.s6.sin6_port = htons (dns_port);
+ new->name = memory_pool_strdup (resolver->static_pool, p);
+ resolver->servers_num ++;
+ }
+ else if (inet_pton (AF_INET, p, &new->addr.s4.sin_addr) == 1) {
+ new->addr.sa.sa_family = AF_INET;
+ new->addr.s4.sin_port = htons (dns_port);
new->name = memory_pool_strdup (resolver->static_pool, p);
resolver->servers_num ++;
}
@@ -1655,8 +1684,8 @@ dns_resolver_init (struct event_base *ev_base, struct config_file *cfg)
{
GList *cur;
struct rspamd_dns_resolver *new;
- gchar *begin, *p, *err, addr_holder[16];
- gint priority, i;
+ gchar *begin, *p, *err;
+ gint priority;
struct rspamd_dns_server *serv;
new = memory_pool_alloc0 (cfg->cfg_pool, sizeof (struct rspamd_dns_resolver));
@@ -1711,8 +1740,16 @@ dns_resolver_init (struct event_base *ev_base, struct config_file *cfg)
priority = 0;
}
serv = &new->servers[new->servers_num];
- if (inet_pton (AF_INET6, p, addr_holder) == 1 ||
- inet_pton (AF_INET, p, addr_holder) == 1) {
+ if (inet_pton (AF_INET6, p, &serv->addr.s6.sin6_addr) == 1) {
+ serv->addr.sa.sa_family = AF_INET6;
+ serv->addr.s6.sin6_port = htons (dns_port);
+ serv->name = memory_pool_strdup (new->static_pool, begin);
+ serv->up.priority = priority;
+ new->servers_num ++;
+ }
+ else if (inet_pton (AF_INET, p, &serv->addr.s4.sin_addr) == 1) {
+ serv->addr.sa.sa_family = AF_INET;
+ serv->addr.s4.sin_port = htons (dns_port);
serv->name = memory_pool_strdup (new->static_pool, begin);
serv->up.priority = priority;
new->servers_num ++;
@@ -1734,20 +1771,6 @@ dns_resolver_init (struct event_base *ev_base, struct config_file *cfg)
}
}
- /* Now init all servers */
- for (i = 0; i < new->servers_num; i ++) {
- serv = &new->servers[i];
- serv->sock = make_universal_socket (serv->name, dns_port,
- SOCK_DGRAM, TRUE, FALSE, FALSE);
- if (serv->sock == -1) {
- msg_warn ("cannot create socket to server %s", serv->name);
- }
- else {
- event_set (&serv->ev, serv->sock, EV_READ | EV_PERSIST, dns_read_cb, new);
- event_base_set (new->ev_base, &serv->ev);
- event_add (&serv->ev, NULL);
- }
- }
return new;
}
diff --git a/src/dns.h b/src/dns.h
index 4889d271c..6a6cda358 100644
--- a/src/dns.h
+++ b/src/dns.h
@@ -5,6 +5,7 @@
#include "mem_pool.h"
#include "events.h"
#include "upstream.h"
+#include "util.h"
#define MAX_SERVERS 16
@@ -23,7 +24,7 @@ typedef void (*dns_callback_type) (struct rspamd_dns_reply *reply, gpointer arg)
struct rspamd_dns_server {
struct upstream up; /**< upstream structure */
gchar *name; /**< name of DNS server */
- gint sock; /**< persistent socket */
+ union sa_union addr; /**< address storage */
struct event ev;
};
diff --git a/src/main.h b/src/main.h
index a104eb116..e3a13ee65 100644
--- a/src/main.h
+++ b/src/main.h
@@ -122,16 +122,6 @@ struct process_exception {
};
/**
- * Union that would be used for storing sockaddrs
- */
-union sa_union {
- struct sockaddr_storage ss;
- struct sockaddr sa;
- struct sockaddr_in s4;
- struct sockaddr_in6 s6;
-};
-
-/**
* Control session object
*/
struct controller_command;
diff --git a/src/util.h b/src/util.h
index 4edf5cb95..edcfa6122 100644
--- a/src/util.h
+++ b/src/util.h
@@ -14,6 +14,17 @@ struct workq;
struct statfile;
struct classifier_config;
+/**
+ * Union that is used for storing sockaddrs
+ */
+union sa_union {
+ struct sockaddr_storage ss;
+ struct sockaddr sa;
+ struct sockaddr_in s4;
+ struct sockaddr_in6 s6;
+ struct sockaddr_un su;
+};
+
/*
* Create socket and bind or connect it to specified address and port
*/
diff --git a/src/worker.c b/src/worker.c
index bb43afba8..95355bdd3 100644
--- a/src/worker.c
+++ b/src/worker.c
@@ -502,8 +502,9 @@ accept_socket (gint fd, short what, void *arg)
struct rspamd_worker_ctx *ctx;
union sa_union su;
struct worker_task *new_task;
+ char ip_str[INET6_ADDRSTRLEN + 1];
- socklen_t addrlen = sizeof (su.ss);
+ socklen_t addrlen = sizeof (su);
gint nfd;
ctx = worker->ctx;
@@ -514,7 +515,7 @@ accept_socket (gint fd, short what, void *arg)
}
if ((nfd =
- accept_from_socket (fd, (struct sockaddr *) &su.ss, &addrlen)) == -1) {
+ accept_from_socket (fd, &su.sa, &addrlen)) == -1) {
msg_warn ("accept failed: %s", strerror (errno));
return;
}
@@ -525,16 +526,21 @@ accept_socket (gint fd, short what, void *arg)
new_task = construct_task (worker);
- if (su.ss.ss_family == AF_UNIX) {
+ if (su.sa.sa_family == AF_UNIX) {
msg_info ("accepted connection from unix socket");
new_task->client_addr.s_addr = INADDR_NONE;
}
- else if (su.ss.ss_family == AF_INET) {
+ else if (su.sa.sa_family == AF_INET) {
msg_info ("accepted connection from %s port %d",
inet_ntoa (su.s4.sin_addr), ntohs (su.s4.sin_port));
memcpy (&new_task->client_addr, &su.s4.sin_addr,
sizeof (struct in_addr));
}
+ else if (su.sa.sa_family == AF_INET6) {
+ msg_info ("accepted connection from %s port %d",
+ inet_ntop (su.sa.sa_family, &su.s6.sin6_addr, ip_str, sizeof (ip_str)),
+ ntohs (su.s6.sin6_port));
+ }
/* Copy some variables */
new_task->sock = nfd;
diff --git a/src/worker_util.c b/src/worker_util.c
index e8d8f7423..599d098c9 100644
--- a/src/worker_util.c
+++ b/src/worker_util.c
@@ -152,7 +152,6 @@ free_task (struct worker_task *task, gboolean is_soft)
if (task->received) {
g_list_free (task->received);
}
- memory_pool_delete (task->task_pool);
if (task->dispatcher) {
if (is_soft) {
/* Plan dispatcher shutdown */
@@ -165,6 +164,7 @@ free_task (struct worker_task *task, gboolean is_soft)
if (task->sock != -1) {
close (task->sock);
}
+ memory_pool_delete (task->task_pool);
g_slice_free1 (sizeof (struct worker_task), task);
}
}