]> source.dussan.org Git - rspamd.git/commitdiff
* Add throttling detection mechanic for dns resolver
authorVsevolod Stakhov <vsevolod@rambler-co.ru>
Wed, 23 Mar 2011 14:08:58 +0000 (17:08 +0300)
committerVsevolod Stakhov <vsevolod@rambler-co.ru>
Wed, 23 Mar 2011 14:08:58 +0000 (17:08 +0300)
* Improve phishing module adding ability to define 'strict' phishing domains

src/cfg_file.h
src/cfg_utils.c
src/cfg_xml.c
src/dns.c
src/dns.h
src/plugins/lua/phishing.lua

index 3a9a7d12e6493917146cc9ec891999bfd9a2f6cc..6fb4065e5d48f91b01e8faf15b63a90fd6090927 100644 (file)
@@ -320,6 +320,8 @@ struct config_file {
 
        guint32 dns_timeout;                                                    /**< timeout in milliseconds for waiting for dns reply  */
        guint32 dns_retransmits;                                                /**< maximum retransmits count                                                  */
+       guint32 dns_throttling_errors;                                  /**< maximum errors for starting resolver throttling    */
+       guint32 dns_throttling_time;                                    /**< time in seconds for DNS throttling                                 */
        GList *nameservers;                                                             /**< list of nameservers or NULL to parse resolv.conf   */
 };
 
index f4cf48529df822fc09b5a8857bf37287238d4174..9acb442b5d92d30563df1c01f0135d9b1143861d 100644 (file)
@@ -170,7 +170,9 @@ init_defaults (struct config_file *cfg)
 
        cfg->dns_timeout = 1000;
        cfg->dns_retransmits = 5;
-
+       /* After 20 errors do throttling for 10 seconds */
+       cfg->dns_throttling_errors = 20;
+       cfg->dns_throttling_time = 10000;
 
        cfg->max_statfile_size = DEFAULT_STATFILE_SIZE;
        cfg->modules_opts = g_hash_table_new (g_str_hash, g_str_equal);
index 2bc83e10943ee55cf3d506fe69631dbab36d68f7..43b2e0441977169860a9ca25421b7ab375e1e585 100644 (file)
@@ -184,6 +184,18 @@ static struct xml_parser_rule grammar[] = {
                                G_STRUCT_OFFSET (struct config_file, dns_retransmits),
                                NULL
                        },
+                       {
+                               "dns_throttling_errors",
+                               xml_handle_uint32,
+                               G_STRUCT_OFFSET (struct config_file, dns_throttling_errors),
+                               NULL
+                       },
+                       {
+                               "dns_throttling_time",
+                               xml_handle_seconds,
+                               G_STRUCT_OFFSET (struct config_file, dns_throttling_time),
+                               NULL
+                       },
                        NULL_ATTR
                },
                NULL_DEF_ATTR
index d29ea5696d3d154767f2e7ae421161c840ab196e..798f4250d7faadec8d78c58c76826dff960f1c7b 100644 (file)
--- a/src/dns.c
+++ b/src/dns.c
@@ -546,7 +546,7 @@ send_dns_request (struct rspamd_dns_request *req)
                } 
                else {
                        msg_err ("send failed: %s for server %s", strerror (errno), req->server->name);
-                       upstream_fail (&req->server->up, time (NULL));
+                       upstream_fail (&req->server->up, req->time);
                        return -1;
                }
        }
@@ -963,6 +963,28 @@ dns_parse_reply (guint8 *in, gint r, struct rspamd_dns_resolver *resolver, struc
        return rep;
 }
 
+static void
+dns_throttling_cb (gint fd, short what, void *arg)
+{
+       struct rspamd_dns_resolver *resolver = arg;
+
+       resolver->throttling = FALSE;
+       resolver->errors = 0;
+       msg_info ("stop DNS throttling after %d seconds", (int)resolver->throttling_time.tv_sec);
+}
+
+static void
+dns_check_throttling (struct rspamd_dns_resolver *resolver)
+{
+       if (resolver->errors > resolver->max_errors) {
+               msg_info ("starting DNS throttling after %ud errors", resolver->errors);
+               /* Init throttling timeout */
+               resolver->throttling = TRUE;
+               evtimer_set (&resolver->throttling_event, dns_throttling_cb, resolver);
+               event_add (&resolver->throttling_event, &resolver->throttling_time);
+       }
+}
+
 static void
 dns_read_cb (gint fd, short what, void *arg)
 {
@@ -978,6 +1000,10 @@ dns_read_cb (gint fd, short what, void *arg)
        r = read (fd, in, sizeof (in));
        if (r > sizeof (struct dns_header) + sizeof (struct dns_query)) {
                if ((rep = dns_parse_reply (in, r, resolver, &req)) != NULL) {
+                       /* Decrease errors count */
+                       if (rep->request->resolver->errors > 0) {
+                               rep->request->resolver->errors --;
+                       }
                        upstream_ok (&rep->request->server->up, rep->request->time);
                        rep->request->func (rep, rep->request->arg);
                }
@@ -1001,12 +1027,14 @@ dns_timer_cb (gint fd, short what, void *arg)
                upstream_fail (&rep->request->server->up, rep->request->time);
                remove_normal_event (req->session, dns_fin_cb, req);
                req->func (rep, req->arg);
+               dns_check_throttling (req->resolver);
+               req->resolver->errors ++;
                return;
        }
        /* Select other server */
        req->server = (struct rspamd_dns_server *)get_upstream_round_robin (req->resolver->servers, 
                        req->resolver->servers_num, sizeof (struct rspamd_dns_server),
-                       time (NULL), DEFAULT_UPSTREAM_ERROR_TIME, DEFAULT_UPSTREAM_DEAD_TIME, DEFAULT_UPSTREAM_MAXERRORS);
+                       req->time, DEFAULT_UPSTREAM_ERROR_TIME, DEFAULT_UPSTREAM_DEAD_TIME, DEFAULT_UPSTREAM_MAXERRORS);
        if (req->server == NULL) {
                rep = memory_pool_alloc0 (req->pool, sizeof (struct rspamd_dns_reply));
                rep->request = req;
@@ -1065,7 +1093,8 @@ dns_retransmit_handler (gint fd, short what, void *arg)
                        rep->code = DNS_RC_SERVFAIL;
                        upstream_fail (&rep->request->server->up, rep->request->time);
                        req->func (rep, req->arg);
-
+                       req->resolver->errors ++;
+                       dns_check_throttling (req->resolver);
                        return;
                }
                r = send_dns_request (req);
@@ -1101,6 +1130,11 @@ make_dns_request (struct rspamd_dns_resolver *resolver,
        const gchar *name, *service, *proto;
        gint r;
 
+       /* Check throttling */
+       if (resolver->throttling) {
+               return FALSE;
+       }
+
        req = memory_pool_alloc (pool, sizeof (struct rspamd_dns_request));
        req->pool = pool;
        req->session = session;
@@ -1248,6 +1282,9 @@ dns_resolver_init (struct config_file *cfg)
        new->static_pool = cfg->cfg_pool;
        new->request_timeout = cfg->dns_timeout;
        new->max_retransmits = cfg->dns_retransmits;
+       new->max_errors = cfg->dns_throttling_errors;
+       new->throttling_time.tv_sec = cfg->dns_throttling_time / 1000;
+       new->throttling_time.tv_usec = (cfg->dns_throttling_time - new->throttling_time.tv_sec * 1000) * 1000;
 
        if (cfg->nameservers == NULL) {
                /* Parse resolv.conf */
index 9896a6fa445711dc01057407c4bf73b385327e72..67cb50e9fbe49bb01ac71361f9cdfaef8e74fe48 100644 (file)
--- a/src/dns.h
+++ b/src/dns.h
@@ -49,7 +49,12 @@ struct rspamd_dns_resolver {
        struct dns_k_permutor *permutor;        /**< permutor for randomizing request id        */
        guint request_timeout;
        guint max_retransmits;
+       guint max_errors;
        memory_pool_t *static_pool;                     /**< permament pool (cfg_pool)                          */
+       gboolean throttling;                            /**< dns servers are busy                                       */
+       guint errors;                                           /**< resolver errors                                            */
+       struct timeval throttling_time;         /**< throttling time                                            */
+       struct event throttling_event;          /**< throttling event                                           */
 };
 
 struct dns_header;
index 11568875057f1dbcd43add0f3c4f8deb78b9b42a..1e648768ace9427043b1261e3591ff1058d72d65 100644 (file)
@@ -3,6 +3,7 @@
 --
 local symbol = 'PHISHED_URL'
 local domains = nil
+local strict_domains = {}
 
 function phishing_cb (task)
        local urls = task:get_urls();
@@ -10,17 +11,31 @@ function phishing_cb (task)
        if urls then
                for _,url in ipairs(urls) do
                        if url:is_phished() then
+                               local purl = url:get_phished()
+                               if table.maxn(strict_domains) > 0 then
+                                       local _,_,tld = string.find(purl:get_host(), '([a-zA-Z0-9%-]+\.[a-zA-Z0-9%-]+)$')
+                                       local found = false
+                                       if tld then
+                                               for _,rule in ipairs(strict_domains) do
+                                                       if rule['map']:get_key(tld) then
+                                                               task:insert_result(rule['symbol'], 1, purl:get_host())
+                                                               found = true
+                                                       end
+                                               end
+                                               if found then
+                                                       return
+                                               end
+                                       end
+                               end
                                if domains then
-                                       local _,_,tld = string.find(url:get_phished():get_host(), '([a-zA-Z0-9%-]+\.[a-zA-Z0-9%-]+)$')
+                                       local _,_,tld = string.find(purl:get_host(), '([a-zA-Z0-9%-]+\.[a-zA-Z0-9%-]+)$')
                                        if tld then
                                                if domains:get_key(tld) then
-                                                       if url:is_phished() then
-                                                               task:insert_result(symbol, 1, url:get_host())
-                                                       end
+                                                       task:insert_result(symbol, 1, purl:get_host())
                                                end
                                        end
                                else            
-                                       task:insert_result(symbol, 1, url:get_phished():get_host())
+                                       task:insert_result(symbol, 1, purl:get_host())
                                end
                        end
                end
@@ -32,6 +47,7 @@ if type(rspamd_config.get_api_version) ~= 'nil' then
        if rspamd_config:get_api_version() >= 1 then
                rspamd_config:register_module_option('phishing', 'symbol', 'string')
                rspamd_config:register_module_option('phishing', 'domains', 'map')
+               rspamd_config:register_module_option('phishing', 'strict_domains', 'string')
        end
 end
 
@@ -43,8 +59,35 @@ if opts then
         -- Register symbol's callback
         rspamd_config:register_symbol(symbol, 1.0, 'phishing_cb')
     end
-       if opts['domains'] then
+       if opts['domains'] and type(opt['domains']) == 'string' then
                domains = rspamd_config:add_hash_map (opts['domains'])
        end
+       if opts['strict_domains'] then
+               local sd = {}
+               if type(opts['strict_domains']) == 'table' then
+                       sd = opts['strict_domains']
+               else
+                       sd[1] = opts['strict_domains']
+               end
+               for _,d in ipairs(sd) do
+                       local s, _ = string.find(d, ':')
+                       if s then
+                               local sym = string.sub(d, s + 1, -1)
+                               local map = string.sub(d, 1, s - 1)
+                               if type(rspamd_config.get_api_version) ~= 'nil' then
+                                       rspamd_config:register_virtual_symbol(sym, 1)
+                               end
+                               local rmap = rspamd_config:add_hash_map (map)
+                               if rmap then
+                                       local rule = {symbol = sym, map = rmap}
+                                       table.insert(strict_domains, rule)
+                               else
+                                       rspamd_logger.info('cannot add map: ' .. map .. ' for symbol: ' .. sym)
+                               end
+                       else
+                               rspamd_logger.info('strict_domains option must be in format <map>:<symbol>')
+                       end
+               end
+       end
     -- If no symbol defined, do not register this module
 end