]> source.dussan.org Git - rspamd.git/commitdiff
[Feature] Add verdict library in lua
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 4 Nov 2019 17:53:58 +0000 (17:53 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 4 Nov 2019 17:53:58 +0000 (17:53 +0000)
lualib/lua_bayes_learn.lua
lualib/lua_util.lua
lualib/lua_verdict.lua [new file with mode: 0644]
src/plugins/lua/clustering.lua
src/plugins/lua/neural.lua
src/plugins/lua/ratelimit.lua
src/plugins/lua/reputation.lua

index 066e86a4df04a14576c4c1b540eb1b3c12262d69..ae8d901f8b6139e1385280fc71ec97a4b32dbb6f 100644 (file)
@@ -17,7 +17,7 @@ limitations under the License.
 -- This file contains functions to simplify bayes classifier auto-learning
 
 local lua_util = require "lua_util"
-
+local lua_verdict = require "lua_verdict"
 local N = "lua_bayes"
 
 local exports = {}
@@ -76,7 +76,7 @@ exports.autolearn = function(task, conf)
   end
 
   -- We have autolearn config so let's figure out what is requested
-  local verdict,score = lua_util.get_task_verdict(task)
+  local verdict,score = lua_verdict.get_specific_verdict("bayes", task)
   local learn_spam,learn_ham = false, false
 
   if verdict == 'passthrough' then
@@ -98,6 +98,12 @@ exports.autolearn = function(task, conf)
         learn_ham = true
       end
     end
+  elseif conf.learn_verdict then
+    if verdict == 'spam' or verdict == 'junk' then
+      learn_spam = true
+    elseif verdict == 'ham' then
+      learn_ham = true
+    end
   end
 
   if conf.check_balance then
index 842f079b2d024957fb4b7f9a601ee647097cf893..bda8b0c02c5fd4c2ff4c00db98e95cede8acaa4a 100644 (file)
@@ -1026,35 +1026,9 @@ end
 -- * `uncertain`: all other cases
 --]]
 exports.get_task_verdict = function(task)
-  local result = task:get_metric_result()
+  local lua_verdict = require "lua_verdict"
 
-  if result then
-
-    if result.passthrough then
-      return 'passthrough',nil
-    end
-
-    local score = result.score
-
-    local action = result.action
-
-    if action == 'reject' and result.npositive > 1 then
-      return 'spam',score
-    elseif action == 'no action' then
-      if score < 0 or result.nnegative > 3 then
-        return 'ham',score
-      end
-    else
-      -- All colors of junk
-      if action == 'add header' or action == 'rewrite subject' then
-        if result.npositive > 2 then
-          return 'junk',score
-        end
-      end
-    end
-
-    return 'uncertain',score
-  end
+  return lua_verdict.get_default_verdict(task)
 end
 
 ---[[[
diff --git a/lualib/lua_verdict.lua b/lualib/lua_verdict.lua
new file mode 100644 (file)
index 0000000..d6a1634
--- /dev/null
@@ -0,0 +1,189 @@
+--[[
+Copyright (c) 2019, Vsevolod Stakhov <vsevolod@highsecure.ru>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local exports = {}
+
+---[[[
+-- @function lua_verdict.get_default_verdict(task)
+-- Returns verdict for a task + score if certain, must be called from idempotent filters only
+-- Returns string:
+-- * `spam`: if message have over reject threshold and has more than one positive rule
+-- * `junk`: if a message has between score between [add_header/rewrite subject] to reject thresholds and has more than two positive rules
+-- * `passthrough`: if a message has been passed through some short-circuit rule
+-- * `ham`: if a message has overall score below junk level **and** more than three negative rule, or negative total score
+-- * `uncertain`: all other cases
+--]]
+local function default_verdict_function(task)
+  local result = task:get_metric_result()
+
+  if result then
+
+    if result.passthrough then
+      return 'passthrough',nil
+    end
+
+    local score = result.score
+
+    local action = result.action
+
+    if action == 'reject' and result.npositive > 1 then
+      return 'spam',score
+    elseif action == 'no action' then
+      if score < 0 or result.nnegative > 3 then
+        return 'ham',score
+      end
+    else
+      -- All colors of junk
+      if action == 'add header' or action == 'rewrite subject' then
+        if result.npositive > 2 then
+          return 'junk',score
+        end
+      end
+    end
+
+    return 'uncertain',score
+  end
+end
+
+local default_possible_verdicts = {
+  passthrough = {
+    can_learn = false,
+    description = 'message has passthrough result',
+  },
+  spam = {
+    can_learn = 'spam',
+    description = 'message is likely spam',
+  },
+  junk = {
+    can_learn = 'spam',
+    description = 'message is likely possible spam',
+  },
+  ham = {
+    can_learn = 'ham',
+    description = 'message is likely ham',
+  },
+  uncertain = {
+    can_learn = false,
+    description = 'not certainity in verdict'
+  }
+}
+
+-- Verdict functions specific for modules
+local specific_verdicts = {
+  default = {
+    callback = default_verdict_function,
+    possible_verdicts = default_possible_verdicts
+  }
+}
+
+local default_verdict = specific_verdicts.default
+
+exports.get_default_verdict = default_verdict.callback
+exports.set_verdict_function = function(func, what)
+  assert(type(func) == 'function')
+  if not what then
+    -- Default verdict
+    local existing = specific_verdicts.default.callback
+    specific_verdicts.default.callback = func
+    exports.get_default_verdict = func
+
+    return existing
+  else
+    local existing = specific_verdicts[what]
+
+    if not existing then
+      specific_verdicts[what] = {
+        callback = func,
+        possible_verdicts = default_possible_verdicts
+      }
+    else
+      existing = existing.callback
+    end
+
+    specific_verdicts[what].callback = func
+    return existing
+  end
+end
+
+exports.set_verdict_table = function(verdict_tbl, what)
+  assert(type(verdict_tbl) == 'table' and
+    type(verdict_tbl.callback) == 'function' and
+    type(verdict_tbl.possible_verdicts) == 'table')
+
+  if not what then
+    -- Default verdict
+    local existing = specific_verdicts.default
+    specific_verdicts.default = verdict_tbl
+    exports.get_default_verdict = specific_verdicts.default.callback
+
+    return existing
+  else
+    local existing = specific_verdicts[what]
+    specific_verdicts[what] = verdict_tbl
+    return existing
+  end
+end
+
+exports.get_specific_verdict = function(what, task)
+  if specific_verdicts[what] then
+    return specific_verdicts[what].callback(task)
+  end
+
+  return exports.get_default_verdict(task)
+end
+
+exports.get_possible_verdicts = function(what)
+  local lua_util = require "lua_util"
+  if what then
+    if specific_verdicts[what] then
+      return lua_util.keys(specific_verdicts[what].possible_verdicts)
+    end
+  else
+    return lua_util.keys(specific_verdicts.default.possible_verdicts)
+  end
+
+  return nil
+end
+
+exports.can_learn = function(verdict, what)
+  if what then
+    if specific_verdicts[what] and specific_verdicts[what].possible_verdicts[verdict] then
+      return specific_verdicts[what].possible_verdicts[verdict].can_learn
+    end
+  else
+    if specific_verdicts.default.possible_verdicts[verdict] then
+      return specific_verdicts.default.possible_verdicts[verdict].can_learn
+    end
+  end
+
+  return nil -- To distinguish from `false` that could happen in can_learn
+end
+
+exports.describe = function(verdict, what)
+  if what then
+    if specific_verdicts[what] and specific_verdicts[what].possible_verdicts[verdict] then
+      return specific_verdicts[what].possible_verdicts[verdict].description
+    end
+  else
+    if specific_verdicts.default.possible_verdicts[verdict] then
+      return specific_verdicts.default.possible_verdicts[verdict].description
+    end
+  end
+
+  return nil
+end
+
+return exports
\ No newline at end of file
index 0b514476c8421fa4e47cca382b8c45dab3a497f7..45e5d4bc14ff6ff4c11020402e0d19072e4be1f4 100644 (file)
@@ -24,6 +24,7 @@ local N = 'clustering'
 
 local rspamd_logger = require "rspamd_logger"
 local lua_util = require "lua_util"
+local lua_verdict = require "lua_verdict"
 local lua_redis = require "lua_redis"
 local lua_selectors = require "lua_selectors"
 local ts = require("tableshape").types
@@ -190,7 +191,7 @@ local function clusterting_idempotent_cb(task, rule)
   if task:has_flag('skip') then return end
   if not rule.allow_local and lua_util.is_rspamc_or_controller(task) then return end
 
-  local verdict = lua_util.get_task_verdict(task)
+  local verdict = lua_verdict.get_specific_verdict(N, task)
   local score
 
   if verdict == 'ham' then
index cb3a5130003a125d3e56126814f0739359176a3e..d96ca896ade116243289ab946fcf80cbd6e7a5d0 100644 (file)
@@ -28,6 +28,7 @@ local fun = require "fun"
 local lua_settings = require "lua_settings"
 local meta_functions = require "lua_meta"
 local ts = require("tableshape").types
+local lua_verdict = require "lua_verdict"
 local N = "neural"
 
 -- Module vars
@@ -1164,7 +1165,7 @@ local function ann_push_vector(task)
     return
   end
 
-  local verdict,score = lua_util.get_task_verdict(task)
+  local verdict,score = lua_verdict.get_specific_verdict(N, task)
 
   if verdict == 'passthrough' then
     lua_util.debugm(N, task, 'ignore task as its verdict is %s(%s)',
index bd77b1ad6e8909143702df1b81e1a89d5681c127..731bd7ac36417547501f6e4749931efb654e3d81 100644 (file)
@@ -26,6 +26,7 @@ local lua_redis = require "lua_redis"
 local fun = require "fun"
 local lua_maps = require "lua_maps"
 local lua_util = require "lua_util"
+local lua_verdict = require "lua_verdict"
 local rspamd_hash = require "rspamd_cryptobox_hash"
 local lua_selectors = require "lua_selectors"
 local ts = require("tableshape").types
@@ -666,7 +667,7 @@ local function ratelimit_update_cb(task)
       return
     end
 
-    local verdict = lua_util.get_task_verdict(task)
+    local verdict = lua_verdict.get_specific_verdict(N, task)
     local _,nrcpt = task:has_recipients('smtp')
     if not nrcpt or nrcpt <= 0 then
       nrcpt = 1
index 8b2b8c5a4aab6c76ed4e84a49a5cd0f8c837bdcb..f88973d71fbcf8252c4119a47480030a2964c34c 100644 (file)
@@ -109,9 +109,10 @@ end
 
 -- Extracts task score and subtracts score of the rule itself
 local function extract_task_score(task, rule)
-  local _,score = lua_util.get_task_verdict(task)
+  local lua_verdict = require "lua_verdict"
+  local verdict,score = lua_verdict.get_specific_verdict(N, task)
 
-  if not score then return nil end
+  if not score or verdict == 'passthrough' then return nil end
 
   return sub_symbol_score(task, rule, score)
 end