aboutsummaryrefslogtreecommitdiffstats
path: root/lualib/lua_bayes_learn.lua
blob: 89470edba9b205e29a6112862187aaa39d55c72f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
--[[
Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
]]--

-- This file contains functions to simplify bayes classifier auto-learning

local lua_util = require "lua_util"
local lua_verdict = require "lua_verdict"
local logger = require "rspamd_logger"
local N = "lua_bayes"

local exports = {}

-- Checks whether a learn request for `task` is worth performing.
--
-- Unless the request carries a `Learn-Type: bulk` header, the `bayes_prob`
-- mempool variable is consulted (presumably the classifier's spam probability
-- in the [0, 1] range - confirm against the code that sets it). Learning is
-- refused when that value already places the message firmly in the requested
-- class: >= 0.95 when learning spam, <= 0.05 when learning ham.
--
-- @param task rspamd task object
-- @param is_spam truthy when learning as spam, falsy when learning as ham
-- @param is_unlearn accepted as part of the signature; not consulted here
-- @return true when learning may proceed; otherwise false followed by a
--         human-readable reason string
exports.can_learn = function(task, is_spam, is_unlearn)
  local learn_type = task:get_request_header('Learn-Type')
  if learn_type and tostring(learn_type) == 'bulk' then
    -- Bulk learning bypasses the "already in class" check
    return true
  end

  local prob = task:get_mempool():get_variable('bayes_prob', 'double')
  if not prob then
    -- No stored probability, so there is nothing to compare against
    return true
  end

  local class_name, already_in_class
  if is_spam then
    class_name, already_in_class = 'spam', prob >= 0.95
  else
    class_name, already_in_class = 'ham', prob <= 0.05
  end

  if not already_in_class then
    return true
  end

  -- Distance from the neutral 0.5 point, rescaled to a 0..100 percentage
  local confidence = math.abs((prob - 0.5) * 200.0)
  return false, string.format(
      'already in class %s; probability %.2f%%',
      class_name, confidence)
end

-- Decides whether `task` should be auto-learned by the bayes classifier and,
-- if so, as which class.
--
-- Flow:
--   1. Tasks without a queue id are skipped.
--   2. The "bayes" specific verdict and its score are obtained from
--      `lua_verdict`; a 'passthrough' verdict is skipped.
--   3. If both `conf.spam_threshold` and `conf.ham_threshold` are set, the
--      score is compared against the threshold matching the verdict
--      (`conf.junk_threshold` is used for 'junk' and is optional).
--      Otherwise, if `conf.learn_verdict` is set, the verdict alone decides.
--   4. If `conf.check_balance` is set, learning a class is cancelled when
--      that class already outnumbers the other by more than
--      `1 / conf.min_balance` (`min_balance` defaults to 0.9).
--
-- @param task rspamd task object
-- @param conf autolearn settings table; fields read here: spam_threshold,
--        ham_threshold, junk_threshold, learn_verdict, check_balance,
--        min_balance
-- @return 'spam' or 'ham' when the message should be learned, nothing otherwise
exports.autolearn = function(task, conf)
  -- Writes an info-level log line describing why the message qualifies
  local function log_can_autolearn(verdict, score, threshold)
    local from = task:get_from('smtp')
    -- `from` may be nil or an empty list; never index a missing first entry
    local from_addr = from and from[1] and from[1].addr or 'undef'

    local mime_rcpts = 'undef'
    local mr = task:get_recipients('mime')
    if mr then
      local r_addrs = {}
      for _, r in ipairs(mr) do
        r_addrs[#r_addrs + 1] = r.addr
      end
      if #r_addrs > 0 then
        mime_rcpts = table.concat(r_addrs, ',')
      end
    end

    logger.info(task, 'id: %s, from: <%s>: can autolearn %s: score %s %s %s, mime_rcpts: <%s>',
        task:get_header('Message-Id') or '<undef>',
        from_addr,
        verdict,
        string.format("%.2f", score),
        verdict == 'ham' and '<=' or verdict == 'spam' and '>=' or '/',
        threshold,
        mime_rcpts)
  end

  if not task:get_queue_id() then
    -- We should skip messages that come from `rspamc` or webui as they are usually
    -- not intended for autolearn at all
    lua_util.debugm(N, task, 'no need to autolearn - queue id is missing')
    return
  end

  -- We have autolearn config so let's figure out what is requested
  local verdict, score = lua_verdict.get_specific_verdict("bayes", task)
  local learn_spam, learn_ham = false, false

  if verdict == 'passthrough' then
    -- No need to autolearn
    lua_util.debugm(N, task, 'no need to autolearn - verdict: %s',
        verdict)
    return
  end

  if conf.spam_threshold and conf.ham_threshold then
    -- Threshold mode: spam/ham thresholds are known to be set at this point,
    -- only the junk threshold remains optional
    if verdict == 'spam' then
      if score >= conf.spam_threshold then
        log_can_autolearn(verdict, score, conf.spam_threshold)
        learn_spam = true
      end
    elseif verdict == 'junk' then
      if conf.junk_threshold and score >= conf.junk_threshold then
        log_can_autolearn(verdict, score, conf.junk_threshold)
        learn_spam = true
      end
    elseif verdict == 'ham' then
      if score <= conf.ham_threshold then
        log_can_autolearn(verdict, score, conf.ham_threshold)
        learn_ham = true
      end
    end
  elseif conf.learn_verdict then
    -- Verdict mode: trust the verdict itself, no score comparison
    if verdict == 'spam' or verdict == 'junk' then
      learn_spam = true
    elseif verdict == 'ham' then
      learn_ham = true
    end
  end

  if conf.check_balance then
    -- Check balance of learns
    local spam_learns = task:get_mempool():get_variable('spam_learns', 'int64') or 0
    local ham_learns = task:get_mempool():get_variable('ham_learns', 'int64') or 0

    local min_balance = 0.9
    if conf.min_balance then
      min_balance = conf.min_balance
    end

    if spam_learns > 0 or ham_learns > 0 then
      local max_ratio = 1.0 / min_balance

      -- The +1 in the denominators avoids division by zero when the other
      -- class has no learns yet
      local spam_learns_ratio = spam_learns / (ham_learns + 1)
      if spam_learns_ratio > max_ratio and learn_spam then
        -- Log the comparison that was actually evaluated (ratio vs max_ratio)
        lua_util.debugm(N, task,
            'skip learning spam, balance is not satisfied: %s > %s; %s spam learns; %s ham learns',
            spam_learns_ratio, max_ratio, spam_learns, ham_learns)
        learn_spam = false
      end

      local ham_learns_ratio = ham_learns / (spam_learns + 1)
      if ham_learns_ratio > max_ratio and learn_ham then
        lua_util.debugm(N, task,
            'skip learning ham, balance is not satisfied: %s > %s; %s spam learns; %s ham learns',
            ham_learns_ratio, max_ratio, spam_learns, ham_learns)
        learn_ham = false
      end
    end
  end

  if learn_spam then
    return 'spam'
  elseif learn_ham then
    return 'ham'
  end
end

return exports