guint32 dns_timeout; /**< timeout in milliseconds for waiting for dns reply */
guint32 dns_retransmits; /**< maximum retransmits count */
+ guint32 dns_throttling_errors; /**< maximum errors for starting resolver throttling */
+ guint32 dns_throttling_time; /**< time in seconds for DNS throttling */
GList *nameservers; /**< list of nameservers or NULL to parse resolv.conf */
};
cfg->dns_timeout = 1000;
cfg->dns_retransmits = 5;
-
+ /* After 20 errors do throttling for 10 seconds */
+ cfg->dns_throttling_errors = 20;
+ cfg->dns_throttling_time = 10000;
cfg->max_statfile_size = DEFAULT_STATFILE_SIZE;
cfg->modules_opts = g_hash_table_new (g_str_hash, g_str_equal);
G_STRUCT_OFFSET (struct config_file, dns_retransmits),
NULL
},
+ {
+ "dns_throttling_errors",
+ xml_handle_uint32,
+ G_STRUCT_OFFSET (struct config_file, dns_throttling_errors),
+ NULL
+ },
+ {
+ "dns_throttling_time",
+ xml_handle_seconds,
+ G_STRUCT_OFFSET (struct config_file, dns_throttling_time),
+ NULL
+ },
NULL_ATTR
},
NULL_DEF_ATTR
}
else {
msg_err ("send failed: %s for server %s", strerror (errno), req->server->name);
- upstream_fail (&req->server->up, time (NULL));
+ upstream_fail (&req->server->up, req->time);
return -1;
}
}
return rep;
}
+static void
+dns_throttling_cb (gint fd, short what, void *arg)
+{
+ struct rspamd_dns_resolver *resolver = arg;
+
+ resolver->throttling = FALSE;
+ resolver->errors = 0;
+ msg_info ("stop DNS throttling after %d seconds", (int)resolver->throttling_time.tv_sec);
+}
+
+static void
+dns_check_throttling (struct rspamd_dns_resolver *resolver)
+{
+ if (resolver->errors > resolver->max_errors) {
+ msg_info ("starting DNS throttling after %ud errors", resolver->errors);
+ /* Init throttling timeout */
+ resolver->throttling = TRUE;
+ evtimer_set (&resolver->throttling_event, dns_throttling_cb, resolver);
+ event_add (&resolver->throttling_event, &resolver->throttling_time);
+ }
+}
+
static void
dns_read_cb (gint fd, short what, void *arg)
{
r = read (fd, in, sizeof (in));
if (r > sizeof (struct dns_header) + sizeof (struct dns_query)) {
if ((rep = dns_parse_reply (in, r, resolver, &req)) != NULL) {
+ /* Decrease errors count */
+ if (rep->request->resolver->errors > 0) {
+ rep->request->resolver->errors --;
+ }
upstream_ok (&rep->request->server->up, rep->request->time);
rep->request->func (rep, rep->request->arg);
}
upstream_fail (&rep->request->server->up, rep->request->time);
remove_normal_event (req->session, dns_fin_cb, req);
req->func (rep, req->arg);
+ dns_check_throttling (req->resolver);
+ req->resolver->errors ++;
return;
}
/* Select other server */
req->server = (struct rspamd_dns_server *)get_upstream_round_robin (req->resolver->servers,
req->resolver->servers_num, sizeof (struct rspamd_dns_server),
- time (NULL), DEFAULT_UPSTREAM_ERROR_TIME, DEFAULT_UPSTREAM_DEAD_TIME, DEFAULT_UPSTREAM_MAXERRORS);
+ req->time, DEFAULT_UPSTREAM_ERROR_TIME, DEFAULT_UPSTREAM_DEAD_TIME, DEFAULT_UPSTREAM_MAXERRORS);
if (req->server == NULL) {
rep = memory_pool_alloc0 (req->pool, sizeof (struct rspamd_dns_reply));
rep->request = req;
rep->code = DNS_RC_SERVFAIL;
upstream_fail (&rep->request->server->up, rep->request->time);
req->func (rep, req->arg);
-
+ req->resolver->errors ++;
+ dns_check_throttling (req->resolver);
return;
}
r = send_dns_request (req);
const gchar *name, *service, *proto;
gint r;
+ /* Check throttling */
+ if (resolver->throttling) {
+ return FALSE;
+ }
+
req = memory_pool_alloc (pool, sizeof (struct rspamd_dns_request));
req->pool = pool;
req->session = session;
new->static_pool = cfg->cfg_pool;
new->request_timeout = cfg->dns_timeout;
new->max_retransmits = cfg->dns_retransmits;
+ new->max_errors = cfg->dns_throttling_errors;
+ new->throttling_time.tv_sec = cfg->dns_throttling_time / 1000;
+ new->throttling_time.tv_usec = (cfg->dns_throttling_time - new->throttling_time.tv_sec * 1000) * 1000;
if (cfg->nameservers == NULL) {
/* Parse resolv.conf */
struct dns_k_permutor *permutor; /**< permutor for randomizing request id */
guint request_timeout;
guint max_retransmits;
+ guint max_errors;
memory_pool_t *static_pool; /**< permament pool (cfg_pool) */
+ gboolean throttling; /**< dns servers are busy */
+ guint errors; /**< resolver errors */
+ struct timeval throttling_time; /**< throttling time */
+ struct event throttling_event; /**< throttling event */
};
struct dns_header;
--
local symbol = 'PHISHED_URL'
local domains = nil
+local strict_domains = {}
function phishing_cb (task)
local urls = task:get_urls();
if urls then
for _,url in ipairs(urls) do
if url:is_phished() then
+ local purl = url:get_phished()
+ if table.maxn(strict_domains) > 0 then
+ local _,_,tld = string.find(purl:get_host(), '([a-zA-Z0-9%-]+\.[a-zA-Z0-9%-]+)$')
+ local found = false
+ if tld then
+ for _,rule in ipairs(strict_domains) do
+ if rule['map']:get_key(tld) then
+ task:insert_result(rule['symbol'], 1, purl:get_host())
+ found = true
+ end
+ end
+ if found then
+ return
+ end
+ end
+ end
if domains then
- local _,_,tld = string.find(url:get_phished():get_host(), '([a-zA-Z0-9%-]+\.[a-zA-Z0-9%-]+)$')
+ local _,_,tld = string.find(purl:get_host(), '([a-zA-Z0-9%-]+\.[a-zA-Z0-9%-]+)$')
if tld then
if domains:get_key(tld) then
- if url:is_phished() then
- task:insert_result(symbol, 1, url:get_host())
- end
+ task:insert_result(symbol, 1, purl:get_host())
end
end
else
- task:insert_result(symbol, 1, url:get_phished():get_host())
+ task:insert_result(symbol, 1, purl:get_host())
end
end
end
if rspamd_config:get_api_version() >= 1 then
rspamd_config:register_module_option('phishing', 'symbol', 'string')
rspamd_config:register_module_option('phishing', 'domains', 'map')
+ rspamd_config:register_module_option('phishing', 'strict_domains', 'string')
end
end
-- Register symbol's callback
rspamd_config:register_symbol(symbol, 1.0, 'phishing_cb')
end
- if opts['domains'] then
+ if opts['domains'] and type(opt['domains']) == 'string' then
domains = rspamd_config:add_hash_map (opts['domains'])
end
+ if opts['strict_domains'] then
+ local sd = {}
+ if type(opts['strict_domains']) == 'table' then
+ sd = opts['strict_domains']
+ else
+ sd[1] = opts['strict_domains']
+ end
+ for _,d in ipairs(sd) do
+ local s, _ = string.find(d, ':')
+ if s then
+ local sym = string.sub(d, s + 1, -1)
+ local map = string.sub(d, 1, s - 1)
+ if type(rspamd_config.get_api_version) ~= 'nil' then
+ rspamd_config:register_virtual_symbol(sym, 1)
+ end
+ local rmap = rspamd_config:add_hash_map (map)
+ if rmap then
+ local rule = {symbol = sym, map = rmap}
+ table.insert(strict_domains, rule)
+ else
+ rspamd_logger.info('cannot add map: ' .. map .. ' for symbol: ' .. sym)
+ end
+ else
+ rspamd_logger.info('strict_domains option must be in format <map>:<symbol>')
+ end
+ end
+ end
-- If no symbol defined, do not register this module
end