]> source.dussan.org Git - rspamd.git/commitdiff
[Minor] Neural: Further fixes
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Sun, 7 Jul 2019 18:45:08 +0000 (19:45 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Sun, 7 Jul 2019 18:45:08 +0000 (19:45 +0100)
src/plugins/lua/neural.lua

index d2c0191e7a8a40159a41dedc266606340510c93b..b0f307803fb48ff83bcc0a8d9b619aa549eadce5 100644 (file)
@@ -280,7 +280,6 @@ local function new_ann_profile(task, rule, set, version)
   end
 
   lua_redis.redis_make_request(task,
-      rspamd_config,
       rule.redis,
       nil,
       true, -- is write
@@ -347,21 +346,28 @@ local function create_ann(n, nlayers)
 end
 
 
-local function ann_train_callback(rule, task, score, required_score, set)
+local function ann_push_task_result(rule, task, verdict, score, set)
   local train_opts = rule.train
 
+
   local learn_spam, learn_ham
 
   if train_opts.autotrain then
+    if verdict == 'passthrough' or verdict == 'uncertain' then
+      lua_util.debugm(N, task, 'ignore task as its verdict is %s(%s)',
+          verdict, score)
+    end
+
     if train_opts['spam_score'] then
       learn_spam = score >= train_opts['spam_score']
     else
-      learn_spam = score >= required_score
+      learn_spam = verdict == 'spam' or verdict == 'junk'
     end
+
     if train_opts['ham_score'] then
       learn_ham = score <= train_opts['ham_score']
     else
-      learn_ham = score < 0
+      learn_ham = verdict == 'ham'
     end
   else
     -- Train by request header
@@ -408,7 +414,7 @@ local function ann_train_callback(rule, task, score, required_score, set)
             true, -- is write
             learn_vec_cb, --callback
             'LPUSH', -- command
-            { set.ann.redis_prefix .. '_' .. learn_type, str} -- arguments
+            { set.ann.redis_key .. '_' .. learn_type, str} -- arguments
         )
       else
         if err then
@@ -948,6 +954,7 @@ local function check_anns(worker, cfg, ev_base, rule, process_callback)
         rspamd_logger.errx(cfg, 'cannot get ANNs list from redis: %s',
             err)
       elseif type(data) == 'table' then
+        lua_util.debugm(N, cfg, 'process element %s:%s', rule.prefix, set.name)
         process_callback(worker, ev_base, rule, set, fun.map(load_ann_profile, data))
       end
     end
@@ -1000,12 +1007,12 @@ end
 local function ann_push_vector(task)
   if task:has_flag('skip') then return end
   if not settings.allow_local and lua_util.is_rspamc_or_controller(task) then return end
-  local scores = task:get_metric_score()
+  local verdict,score = lua_util.get_task_verdict(task)
   for _,rule in pairs(settings.rules) do
     local sid = task:get_settings_id() or -1
 
     if rule.settings[sid] then
-      ann_train_callback(rule, task, scores[1], scores[2], rule.settings[sid])
+      ann_push_task_result(rule, task, verdict, score, rule.settings[sid])
     end
 
   end
@@ -1124,6 +1131,10 @@ local id = rspamd_config:register_symbol({
   callback = ann_scores_filter
 })
 
+settings = lua_util.override_defaults(settings, module_config)
+settings.rules = {} -- Reset unless validated further in the cycle
+
+-- Check all rules
 for k,r in pairs(rules) do
   local rule_elt = lua_util.override_defaults(default_options, r)
   rule_elt['redis'] = redis_params