]> source.dussan.org Git - rspamd.git/commitdiff
Rework ip_score plugin.
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 22 Jul 2015 14:27:50 +0000 (15:27 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 22 Jul 2015 14:27:50 +0000 (15:27 +0100)
- Add new normalization
- Store data as pairs score/total
- Use new mempool variables API

src/plugins/lua/ip_score.lua

index b1d116fc4781cd8f9cb541969a97017d3e3b3e94..1e1fd92a13c5f05f8e5e401d0398a60056df5bd9 100644 (file)
@@ -43,7 +43,7 @@ local options = {
     ['reject'] = 1.0,
     ['add header'] = 0.25,
     ['rewrite subject'] = 0.25,
-    ['no action'] = -1.0
+    ['no action'] = 1.0
   },
   scores = { -- how each component is evaluated
     ['asn'] = 0.5,
@@ -84,13 +84,26 @@ local function asn_check(task)
   
   if ip and ip:is_valid() then
     local req_name = rspamd_logger.slog("%1.%2",
-      table.concat(ip:inversed_str_octets(), '.'), asn_provider)
+      table.concat(ip:inversed_str_octets(), '.'), options['asn_provider'])
     
     task:get_resolver():resolve_txt(task:get_session(), task:get_mempool(),
         req_name, asn_dns_cb)
   end
 end
 
+local function ip_score_hash_key(asn, country, ipnet, ip)
+  -- We use the most common attribute as hashing key
+  if country then
+    return country
+  elseif asn then
+    return asn
+  elseif ipnet then
+    return ipnet
+  else
+    return ip:to_string()
+  end
+end
+
 local function ip_score_get_task_vars(task)
   local pool = task:get_mempool()
   local asn, country, ipnet
@@ -109,6 +122,17 @@ end
 
 -- Set score based on metric's action
 local ip_score_set = function(task)
+  local function new_score_set(score, old_score, old_total)
+    local new_total
+    if old_total == -1 then
+      new_total = 1
+    else
+      new_total = old_total + 1
+    end
+    
+    return old_score + score, new_total
+  end
+
   local score_set_cb = function(task, err, data)
     if err then
       rspamd_logger.infox('got error while IP score changing: %1', err)
@@ -120,7 +144,7 @@ local ip_score_set = function(task)
   if not ip or not ip:is_valid() then
     return
   end
-
+  
   -- Check whitelist
   if whitelist then
     if whitelist:get_key(ip) then
@@ -136,41 +160,34 @@ local ip_score_set = function(task)
         ipnet_score,total_ipnet,
         ip_score, total_ip = pool:get_variable('ip_score', 
         'double,double,double,double,double,double,double,double')
-  
 
-  rspamd_logger.infox('%1', action)
-  local score = 0
-  if scores[action] then
-    score = scores[action]
+  local score_mult = 0
+  if options['actions'][action] then
+    score_mult = options['actions'][action]
   end
-  
+  local score = task:get_metric_score(options['metric'])[1]
+  if action == 'no action' and score > 0 then
+    score_mult = 0
+  end
+
+  score = score_mult * rspamd_util.tanh (2.718 * score)
+
   if score ~= 0 then
-    local hkey = ip:to_string()
+    local hkey = ip_score_hash_key(asn, country, ipnet, ip)
     local upstream = upstreams:get_upstream_by_hash(hkey)
     local addr = upstream:get_addr()
+    
+    asn_score,total_asn = new_score_set(score, asn_score, total_asn)
+    country_score,total_country = new_score_set(score, country_score, total_country)
+    ipnet_score,total_ipnet = new_score_set(score, ipnet_score, total_ipnet)
+    ip_score,total_ip = new_score_set(score, ip_score, total_ip)
+    
     rspamd_redis.make_request(task, addr, score_set_cb, 
-      'HINCRBY', {score_hash, hkey, score})
-    if country then
-      hkey = country_prefix .. country
-      local upstream = upstreams:get_upstream_by_hash(hkey)
-      local addr = upstream:get_addr()
-      rspamd_redis.make_request(task, addr, score_set_cb, 
-        'HINCRBY', {score_hash, hkey, score})
-    end
-    if asn then
-      hkey = asn_prefix .. asn
-      local upstream = upstreams:get_upstream_by_hash(hkey)
-      local addr = upstream:get_addr()
-      rspamd_redis.make_request(task, addr, score_set_cb, 
-        'HINCRBY', {score_hash, hkey, score})
-    end
-    if ipnet then
-      hkey = ipnet_prefix .. ipnet
-      local upstream = upstreams:get_upstream_by_hash(hkey)
-      local addr = upstream:get_addr()
-      rspamd_redis.make_request(task, addr, score_set_cb, 
-        'HINCRBY', {score_hash, hkey, score})
-    end
+      'HMSET', {options['hash'], 
+      options['asn_prefix'] .. asn, string.format('%f|%d', asn_score, total_asn),
+      options['country_prefix'] .. country, string.format('%f|%d', country_score, total_country),
+      options['ipnet_prefix'] .. ipnet, string.format('%f|%d', ipnet_score, total_ipnet),
+      ip:to_string(), string.format('%f|%d', ip_score, total_ip)})
   end
 end
 
@@ -251,31 +268,32 @@ local ip_score_check = function(task)
       
       if total_score ~= 0 then
         task:insert_result(options['symbol'], total_score, description)
+      end
     end
   end
   
   local function create_get_command(ip, asn, country, ipnet)
     local cmd = 'HMGET'
     
-    local args = {score_hash}
+    local args = {options['hash']}
     
     if asn then
-      table.insert(args, asn_prefix .. asn)
+      table.insert(args, options['asn_prefix'] .. asn)
     else
       -- fake arg
-      table.insert(args, asn_prefix)
+      table.insert(args, options['asn_prefix'])
     end
     if country then
-      table.insert(args, country_prefix .. country)
+      table.insert(args, options['country_prefix'] .. country)
     else
       -- fake arg
-      table.insert(args, country_prefix)
+      table.insert(args, options['country_prefix'])
     end
     if ipnet then
-      table.insert(args, ipnet_prefix .. ipnet)
+      table.insert(args, options['ipnet_prefix'] .. ipnet)
     else
       -- fake arg
-      table.insert(args, ipnet_prefix)
+      table.insert(args, options['ipnet_prefix'])
     end
     
     table.insert(args, ip:to_string())
@@ -293,7 +311,8 @@ local ip_score_check = function(task)
     end
 
     local cmd, args = create_get_command(ip, asn, country, ipnet)
-    local upstream = upstreams:get_upstream_by_hash(ip:to_string())
+    local upstream = upstreams:get_upstream_by_hash(
+      ip_score_hash_key(asn, country, ipnet, ip))
     local addr = upstream:get_addr()
     rspamd_redis.make_request(task, addr, ip_score_redis_cb, cmd, args)
   end
@@ -321,11 +340,11 @@ end
 
 
 configure_ip_score_module()
-if upstreams and normalize_score > 0 then
+if upstreams then
   -- Register ip_score module
-  if asn_provider then
+  if options['asn_provider'] then
     rspamd_config:register_pre_filter(asn_check)
   end
-  rspamd_config:register_symbol(symbol, 1.0, ip_score_check)
+  rspamd_config:register_symbol(options['symbol'], 1.0, ip_score_check)
   rspamd_config:register_post_filter(ip_score_set)
 end