]> source.dussan.org Git - rspamd.git/commitdiff
[Feature] Implement preliminary code for fann autolearn
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 5 Apr 2016 16:26:43 +0000 (17:26 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 5 Apr 2016 16:26:43 +0000 (17:26 +0100)
src/plugins/lua/fann_scores.lua

index 66a5e18796576d7aeffe08b878f16ed9e78c40bf..f430d915085bbe07f09f037c7f9f602db8be42f3 100644 (file)
@@ -25,71 +25,192 @@ local ucl = require "ucl"
 
 -- Module vars
 local fann
-local symbols
-local nsymbols = 0
+local fann_train
+local fann_file
+local ntrains = 0
+local max_trains = 1000
+local epoch = 0
+local max_epoch = 100
+local fann_mtime = 0
 local opts = rspamd_config:get_all_opt("fann_scores")
 
-local function fann_scores_filter(task)
-  local fann_input = {}
+local function load_fann()
+  local err,st = rspamd_util.stat(fann_file)
+
+  if err then
+    return false
+  end
+
+  fann = rspamd_fann.load(fann_file)
 
-  for sym,idx in pairs(symbols) do
-    if task:has_symbol(sym) then
-      fann_input[idx + 1] = 1
+  if fann then
+    local n = rspamd_config:get_symbols_count()
+
+    if n ~= fann:get_inputs() then
+      rspamd_logger.infox(rspamd_config, 'fann has incorrect number of inputs: %s, %s symbols' ..
+      ' is found in the cache', fann:get_inputs(), n)
+      fann = nil
     else
-      fann_input[idx + 1] = 0
+      rspamd_logger.infox(rspamd_config, 'loaded fann from %s', fann_file)
+      return true
+    end
+  end
+
+  return false
+end
+
+local function check_fann()
+  local n = rspamd_config:get_symbols_count()
+
+  if fann then
+    local n = rspamd_config:get_symbols_count()
+
+    if n ~= fann:get_inputs() then
+      rspamd_logger.infox(rspamd_config, 'fann has incorrect number of inputs: %s, %s symbols' ..
+      ' is found in the cache', fann:get_inputs(), n)
+      fann = nil
     end
   end
 
-  local out = fann:test(nsymbols, fann_input)
-  local result = rspamd_util.tanh(2 * (out[1] - 0.5))
-  local symscore = string.format('%.3f', out[1])
-  rspamd_logger.infox(task, 'fann score: %s', symscore)
+  local err,st = rspamd_util.stat(fann_file)
 
-  task:insert_result(fann_symbol, result, symscore)
+  if not err then
+    local mtime = st['mtime']
+
+    if mtime > fann_mtime then
+      fann_mtime = mtime
+      fann = nil
+    end
+  end
 end
 
-if not rspamd_fann.is_enabled() then
-  rspamd_logger.errx(rspamd_config, 'fann is not compiled in rspamd, this ' ..
-    'module is eventually disabled')
-else
-  if not opts['fann_file'] or not opts['symbols_file'] then
-    rspamd_logger.errx(rspamd_config, 'fann_scores module requires ' ..
-      '`fann_file` and `symbols_file` to be specified')
-  else
-    fann = rspamd_fann.load(opts['fann_file'])
+local function fann_scores_filter(task)
+  check_fann()
 
-    if not fann then
-      rspamd_logger.errx(rspamd_config, 'cannot load fann from %s',
-        opts['fann_file'])
-      return
+  if fann then
+    local fann_input = {}
+
+    for sym,idx in pairs(symbols) do
+      if task:has_symbol(sym) then
+        fann_input[idx + 1] = 1
+      else
+        fann_input[idx + 1] = 0
+      end
     end
-    -- Parse symbols
-    local parser = ucl.parser()
-    local res, err = parser:parse_file(opts['symbols_file'])
+
+    local out = fann:test(nsymbols, fann_input)
+    local result = rspamd_util.tanh(2 * (out[1] - 0.5))
+    local symscore = string.format('%.3f', out[1])
+    rspamd_logger.infox(task, 'fann score: %s', symscore)
+
+    task:insert_result(fann_symbol, result, symscore)
+  else
+    if load_fann() then
+      fann_scores_filter(task)
+    end
+  end
+end
+
+local function create_train_fann(n)
+  fann_train = rspamd_fann.create(3, n, n / 2, 1)
+  ntrains = 0
+  epoch = 0
+end
+
+local function fann_train(score, required_score,results, cf, opts)
+  local n = cf:get_symbols_count()
+
+  if not fann_train then
+    create_train_fann(n)
+  end
+
+  if fann_train:get_inputs() ~= n then
+    rspamd_logger.infox(cf, 'fann has incorrect number of inputs: %s, %s symbols' ..
+      ' is found in the cache', fann_train:get_inputs(), n)
+    create_train_fann(n)
+  end
+
+  if ntrains > max_trains then
+    -- Store fann on disk
+    res = fann_train:save(fann_file)
+
     if not res then
-      rspamd_logger.errx(rspamd_config, 'cannot load symbols from %s: %s',
-        opts['symbols_file'], err)
-      return
+      rspamd_logger.errx(cf, 'cannot save fann in %s', fann_file)
+    else
+      ntrains = 0
+      epoch = epoch + 1
     end
+  end
+
+  if epoch > max_epoch then
+    -- Re-create fann
+    rspamd_logger.infox(cf, 'create new fann in %s after %s epoches', fann_file,
+      max_epoch)
+    create_train_fann(n)
+  end
+
+  local learn_spam, learn_ham = false, false
+  if opts['spam_score'] then
+    learn_spam = score >= opts['spam_score']
+  else
+    learn_spam = score >= required_score
+  end
+  if opts['ham_score'] then
+    learn_ham = score <= opts['ham_score']
+  else
+    learn_ham = score < 0
+  end
+
+  if learn_spam or learn_ham then
+    local learn_data = {}
+    local matched_symbols = {}
 
-    symbols = parser:get_object()
+    for _,sym in ipairs(results) do
+      matched_symbols[sym[1] + 1] = 1
+    end
 
-    -- Check sanity
-    for _,s in pairs(symbols) do nsymbols = nsymbols + 1 end
-    if fann:get_inputs() ~= nsymbols then
-      rspamd_logger.errx(rspamd_config, 'fann number of inputs: %s is not equal' ..
-          ' to symbols count: %s',
-        fann:get_inputs(), nsymbols)
-      return
+    for i=1,(n + 1) do
+      if matched_symbols[i] then
+        learn_data[i] = 1
+      else
+        learn_data[i] = 0
+      end
     end
 
-    if fann:get_outputs() ~= 1 then
-      rspamd_logger.errx(rspamd_config, 'fann nuber of outputs is invalid: %s',
-        fann:get_outputs())
-      return
+    if learn_spam then
+      fann_train:train(learn_data, 1.0)
+    else
+      fann_train:train(learn_data, 0.0)
     end
 
+    trains = trains + 1
+  end
+end
+
+if not rspamd_fann.is_enabled() then
+  rspamd_logger.errx(rspamd_config, 'fann is not compiled in rspamd, this ' ..
+    'module is eventually disabled')
+else
+  if not opts['fann_file'] then
+    rspamd_logger.errx(rspamd_config, 'fann_scores module requires ' ..
+      '`fann_file` to be specified')
+  else
+    fann_file = opts['fann_file']
     rspamd_config:set_metric_symbol(fann_symbol, 3.0, 'Experimental FANN adjustment')
     rspamd_config:register_post_filter(fann_scores_filter)
+
+    if opts['train'] then
+      rspamd_config:add_on_load(function(cfg)
+        if opts['train']['max_train'] then
+          max_trains = opts['train']['max_train']
+        end
+        if opts['train']['max_epoch'] then
+          max_trains = opts['train']['max_epoch']
+        end
+        cfg:register_worker_script("log_helper", function(score, req_score, results, cf)
+          fann_train(score, req_score, results, cf, opts['train'])
+        end)
+      end)
+    end
   end
 end