]> source.dussan.org Git - rspamd.git/commitdiff
[Feature] Fann scores now uses metadata from a message
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 12 Sep 2016 15:15:44 +0000 (16:15 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 12 Sep 2016 15:15:44 +0000 (16:15 +0100)
By introducing of extra data, it is now possible to train ANN with
metadata of messages improving quality of filtering.

src/plugins/lua/fann_scores.lua

index 3e4ec2fc5a3290db0e0fcb20491f7ef8a475a87b..b53fed5553c632a3a59a7dcff546570ec41318ab 100644 (file)
@@ -25,6 +25,7 @@ local fann_symbol_ham = 'FANN_HAM'
 require "fun" ()
 local ucl = require "ucl"
 
+local module_log_id = 0x100
 -- Module vars
 -- ANNs indexed by settings id
 local data = {
@@ -34,28 +35,225 @@ local data = {
     epoch = 0,
   }
 }
+
 local fann_file
 local max_trains = 1000
 local max_epoch = 100
 local use_settings = false
-local opts = rspamd_config:get_all_opt("fann_scores")
-if not (opts and type(opts) == 'table') then
-  rspamd_logger.infox(rspamd_config, 'Module is unconfigured')
-  return
+
+
+-- Metafunctions
+local function fann_size_function(task)
+  local sizes = {
+    100,
+    200,
+    500,
+    1000,
+    2000,
+    4000,
+    10000,
+    20000,
+    30000,
+    100000,
+    200000,
+    400000,
+    800000,
+    1000000,
+    2000000,
+    8000000,
+  }
+
+  local size = task:get_size()
+  for i = 1,#sizes do
+    if sizes[i] >= size then
+      return {i / #sizes}
+    end
+  end
+
+  return {0}
+end
+
+local function fann_images_function(task)
+  local images = task:get_images()
+  local ntotal = 0
+  local njpg = 0
+  local npng = 0
+  local nlarge = 0
+  local nsmall = 0
+
+  if images then
+    for _,img in ipairs(images) do
+      if img:get_type() == 'png' then
+        npng = npng + 1
+      elseif img:get_type() == 'jpeg' then
+        njpg = njpg + 1
+      end
+
+      local w = img:get_width()
+      local h = img:get_height()
+
+      if w > 0 and h > 0 then
+        if w + h > 256 then
+          nlarge = nlarge + 1
+        else
+          nsmall = nsmall + 1
+        end
+      end
+
+      ntotal = ntotal + 1
+    end
+  end
+
+  return {ntotal,njpg,npng,nlarge,nsmall}
+end
+
+local function fann_nparts_function(task)
+  local nattachments = 0
+  local ntextparts = 0
+
+  local tp = task:get_text_parts()
+  if tp then
+    ntextparts = #tp
+  end
+
+  local parts = task:get_parts()
+
+  if parts then
+    for _,p in ipairs(parts) do
+      if p:get_filename() then
+        nattachments = nattachments + 1
+      end
+    end
+  end
+
+  return {ntextparts, nattachments}
+end
+
+local function fann_encoding_function(task)
+  local nutf = 0
+  local nother = 0
+
+  local tp = task:get_text_parts()
+  if tp then
+    for _,p in ipairs(tp) do
+      if p:is_utf() then
+        nutf = nutf + 1
+      else
+        nother = nother + 1
+      end
+    end
+  end
+
+  return {nutf, nother}
+end
+
+local function fann_recipients_function(task)
+  local nmime = 0
+  local nsmtp = 0
+
+  if task:has_recipients('mime') then
+    nmime = #(task:get_recipients('mime'))
+  end
+  if task:has_recipients('smtp') then
+    nsmtp = #(task:get_recipients('smtp'))
+  end
+
+  return {nmime,nsmtp}
+end
+
+local function fann_received_function(task)
+  return {#(task:get_received_headers())}
 end
 
-local function symbols_to_fann_vector(syms)
+local function fann_urls_function(task)
+  if task:has_urls() then
+    return {#(task:get_urls())}
+  end
+
+  return {0}
+end
+
+local function fann_attachments_function(task)
+end
+
+local metafunctions = {
+  {
+    cb = fann_size_function,
+    ninputs = 1,
+  },
+  {
+    cb = fann_images_function,
+    ninputs = 5,
+    -- 1 - number of images,
+    -- 2 - number of png images,
+    -- 3 - number of jpeg images
+    -- 4 - number of large images (> 128 x 128)
+    -- 5 - number of small images (< 128 x 128)
+  },
+  {
+    cb = fann_nparts_function,
+    ninputs = 2,
+    -- 1 - number of text parts
+    -- 2 - number of attachments
+  },
+  {
+    cb = fann_encoding_function,
+    ninputs = 2,
+    -- 1 - number of utf parts
+    -- 2 - number of non-utf parts
+  },
+  {
+    cb = fann_recipients_function,
+    ninputs = 2,
+    -- 1 - number of mime rcpt
+    -- 2 - number of smtp rcpt
+  },
+  {
+    cb = fann_received_function,
+    ninputs = 1,
+  },
+  {
+    cb = fann_urls_function,
+    ninputs = 1,
+  },
+}
+
+local function gen_metatokens(task)
+  local metatokens = {}
+  for _,mt in ipairs(metafunctions) do
+    local ct = mt.cb(task)
+
+    for _,tok in ipairs(ct) do
+      table.insert(metatokens, tok)
+    end
+  end
+
+  rspamd_logger.errx(task, "tokens: %s", metatokens)
+
+  return metatokens
+end
+
+local function count_metatokens()
+  local total = 0
+  for _,mt in ipairs(metafunctions) do
+    total = total + mt.ninputs
+  end
+
+  return total
+end
+
+local function symbols_to_fann_vector(syms, scores)
   local learn_data = {}
   local matched_symbols = {}
   local n = rspamd_config:get_symbols_count()
 
-  each(function(s)
-    matched_symbols[s + 1] = 1
-  end, syms)
+  each(function(s, score)
+     matched_symbols[s + 1] = score
+  end, zip(syms, scores))
 
   for i=1,n do
     if matched_symbols[i] then
-      learn_data[i] = 1
+      learn_data[i] = math.abs(matched_symbols[i])
     else
       learn_data[i] = 0
     end
@@ -85,7 +283,7 @@ local function load_fann(id)
   rspamd_util.unlock_file(fd) -- closes fd
 
   if data[id].fann then
-    local n = rspamd_config:get_symbols_count()
+    local n = rspamd_config:get_symbols_count() + count_metatokens()
 
     if n ~= data[id].fann:get_inputs() then
       rspamd_logger.infox(rspamd_config, 'fann has incorrect number of inputs: %s, %s symbols' ..
@@ -115,7 +313,7 @@ end
 
 local function check_fann(id)
   if data[id].fann then
-    local n = rspamd_config:get_symbols_count()
+    local n = rspamd_config:get_symbols_count() + count_metatokens
 
     if n ~= data[id].fann:get_inputs() then
       rspamd_logger.infox(rspamd_config, 'fann has incorrect number of inputs: %s, %s symbols' ..
@@ -151,8 +349,13 @@ local function fann_scores_filter(task)
   check_fann(id)
 
   if data[id].fann then
-    local symbols = task:get_symbols_numeric()
-    local fann_data = symbols_to_fann_vector(symbols)
+    local symbols,scores = task:get_symbols_numeric()
+    local fann_data = symbols_to_fann_vector(symbols, scores)
+    local mt = gen_metatokens(task)
+
+    for _,tok in ipairs(mt) do
+      table.insert(fann_data, tok)
+    end
 
     local out = data[id].fann:test(fann_data)
     local result = rspamd_util.tanh(2 * (out[1] - 0.5))
@@ -177,8 +380,8 @@ local function create_train_fann(n, id)
   data[id].epoch = 0
 end
 
-local function fann_train_callback(score, required_score,results, cf, id, opts)
-  local n = cf:get_symbols_count()
+local function fann_train_callback(score, required_score, results, cf, id, opts, extra)
+  local n = cf:get_symbols_count() + count_metatokens()
   local fname = gen_fann_file(id)
 
   if not data[id].fann_train then
@@ -240,8 +443,11 @@ local function fann_train_callback(score, required_score,results, cf, id, opts)
 
   if learn_spam or learn_ham then
     local learn_data = symbols_to_fann_vector(
-      map(function(r) return r[1] end, results)
+      map(function(r) return r[1] end, results),
+      map(function(r) return r[2] end, results)
     )
+    -- Add filtered meta tokens
+    each(function(e) table.insert(learn_data, e) end, extra)
 
     if learn_spam then
       data[id].fann_train:train(learn_data, {1.0})
@@ -253,6 +459,14 @@ local function fann_train_callback(score, required_score,results, cf, id, opts)
   end
 end
 
+-- Initialization part
+
+local opts = rspamd_config:get_all_opt("fann_scores")
+if not (opts and type(opts) == 'table') then
+  rspamd_logger.infox(rspamd_config, 'Module is unconfigured')
+  return
+end
+
 if not rspamd_fann.is_enabled() then
   rspamd_logger.errx(rspamd_config, 'fann is not compiled in rspamd, this ' ..
     'module is eventually disabled')
@@ -294,15 +508,26 @@ else
           max_epoch = opts['train']['max_epoch']
         end
         cfg:register_worker_script("log_helper",
-          function(score, req_score, results, cf, id)
+          function(score, req_score, results, cf, id, extra)
+            -- map (snd x) (filter (fst x == module_id) extra)
+            local extra_fann = map(function(e) return e[2] end,
+              filter(function(e) return e[1] == module_log_id end, extra))
             if use_settings then
               fann_train_callback(score, req_score, results, cf,
-                tostring(id), opts['train'])
+                tostring(id), opts['train'], extra_fann)
             else
-              fann_train_callback(score, req_score, results, cf, '0', opts['train'])
+              fann_train_callback(score, req_score, results, cf, '0',
+                opts['train'], extra_fann)
             end
         end)
       end)
+      rspamd_plugins["fann_score"] = {
+        log_callback = function(task)
+          return totable(map(
+            function(tok) return {module_log_id, tok} end,
+            gen_metatokens(task)))
+        end
+      }
     end
   end
 end