1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
|
--[[
Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
]]--
-- This file contains functions to simplify bayes classifier auto-learning
local lua_util = require "lua_util"
local lua_verdict = require "lua_verdict"
local logger = require "rspamd_logger"
-- Module name passed as the first argument to lua_util.debugm calls below
local N = "lua_bayes"
-- Table of public functions; returned at the end of this module
local exports = {}
--- Decide whether a message may be learned as the requested class.
-- Requests carrying `Learn-Type: bulk` are always allowed. Otherwise, if the
-- mempool variable `bayes_prob` is set and already places the message in the
-- requested class (>= 0.95 for spam, <= 0.05 for ham), learning is refused.
-- @param task rspamd task
-- @param is_spam true to learn as spam, false to learn as ham
-- @param is_unlearn accepted for API compatibility; not used here
-- @return true if learning may proceed, otherwise false plus a reason string
exports.can_learn = function(task, is_spam, is_unlearn)
  local learn_type = task:get_request_header('Learn-Type')
  if learn_type and tostring(learn_type) == 'bulk' then
    -- Bulk learning skips the "already classified" check entirely
    return true
  end

  local prob = task:get_mempool():get_variable('bayes_prob', 'double')
  if not prob then
    return true
  end

  local class_name, already_classified
  if is_spam then
    class_name, already_classified = 'spam', prob >= 0.95
  else
    class_name, already_classified = 'ham', prob <= 0.05
  end

  if not already_classified then
    return true
  end

  -- Distance of prob from 0.5, scaled to a 0..100 percentage
  local confidence = math.abs((prob - 0.5) * 200.0)
  return false, string.format('already in class %s; probability %.2f%%',
      class_name, confidence)
end
--- Log that a message qualifies for Bayes autolearning.
-- @param task rspamd task being processed
-- @param verdict verdict string that matched ('spam', 'junk' or 'ham')
-- @param score message score
-- @param op comparison operator used against the threshold ('>=' or '<=')
-- @param threshold threshold the score was compared with
local function log_can_autolearn(task, verdict, score, op, threshold)
  local from = task:get_from('smtp')
  -- Guard against an empty sender list as well as a missing one
  local sender = 'undef'
  if from and from[1] and from[1].addr then
    sender = from[1].addr
  end

  local mime_rcpts = 'undef'
  local mr = task:get_recipients('mime')
  if mr then
    local r_addrs = {}
    for _, r in ipairs(mr) do
      r_addrs[#r_addrs + 1] = r.addr
    end
    if #r_addrs > 0 then
      mime_rcpts = table.concat(r_addrs, ',')
    end
  end

  logger.info(task, 'id: %s, from: <%s>: can autolearn %s: score %s %s %s, mime_rcpts: <%s>',
      task:get_header('Message-Id') or '<undef>',
      sender,
      verdict,
      string.format("%.2f", score),
      op,
      threshold,
      mime_rcpts)
end

--- Enforce the spam/ham learn balance requested by `conf.check_balance`.
-- Reads the `spam_learns` and `ham_learns` counters from the task mempool
-- (missing counters count as 0). When at least one learn is recorded, a class
-- whose count divided by (other class count + 1) exceeds `1 / min_balance`
-- is not learned. `conf.min_balance` defaults to 0.9.
-- @param task rspamd task being processed
-- @param conf autolearn configuration table
-- @param learn_spam whether spam learning is currently requested
-- @param learn_ham whether ham learning is currently requested
-- @return learn_spam, learn_ham after applying the balance check
local function apply_learn_balance(task, conf, learn_spam, learn_ham)
  local mempool = task:get_mempool()
  local spam_learns = mempool:get_variable('spam_learns', 'int64') or 0
  local ham_learns = mempool:get_variable('ham_learns', 'int64') or 0
  local min_balance = conf.min_balance or 0.9

  if spam_learns > 0 or ham_learns > 0 then
    local max_ratio = 1.0 / min_balance

    local spam_learns_ratio = spam_learns / (ham_learns + 1)
    if spam_learns_ratio > max_ratio and learn_spam then
      -- Log the comparison that actually failed: ratio exceeded max_ratio
      lua_util.debugm(N, task,
          'skip learning spam, balance is not satisfied: %s > %s; %s spam learns; %s ham learns',
          spam_learns_ratio, max_ratio, spam_learns, ham_learns)
      learn_spam = false
    end

    local ham_learns_ratio = ham_learns / (spam_learns + 1)
    if ham_learns_ratio > max_ratio and learn_ham then
      lua_util.debugm(N, task,
          'skip learning ham, balance is not satisfied: %s > %s; %s spam learns; %s ham learns',
          ham_learns_ratio, max_ratio, spam_learns, ham_learns)
      learn_ham = false
    end
  end

  return learn_spam, learn_ham
end

--- Decide whether a message should be auto-learned by the Bayes classifier.
-- Messages without a queue id and messages with a 'passthrough' verdict are
-- skipped. When both `conf.spam_threshold` and `conf.ham_threshold` are set,
-- the "bayes" verdict score is compared against the per-verdict threshold
-- (`junk_threshold` is optional). Otherwise, if `conf.learn_verdict` is set,
-- the verdict alone decides. `conf.check_balance` additionally limits learns
-- when one class is heavily overrepresented.
-- @param task rspamd task being processed
-- @param conf autolearn configuration table
-- @return 'spam', 'ham', or nil when the message should not be learned
exports.autolearn = function(task, conf)
  if not task:get_queue_id() then
    -- We should skip messages that come from `rspamc` or webui as they are
    -- usually not intended for autolearn at all
    lua_util.debugm(N, task, 'no need to autolearn - queue id is missing')
    return
  end

  local verdict, score = lua_verdict.get_specific_verdict("bayes", task)
  if verdict == 'passthrough' then
    lua_util.debugm(N, task, 'no need to autolearn - verdict: %s', verdict)
    return
  end

  local learn_spam, learn_ham = false, false

  if conf.spam_threshold and conf.ham_threshold then
    if verdict == 'spam' then
      if score >= conf.spam_threshold then
        log_can_autolearn(task, verdict, score, '>=', conf.spam_threshold)
        learn_spam = true
      end
    elseif verdict == 'junk' then
      if conf.junk_threshold and score >= conf.junk_threshold then
        log_can_autolearn(task, verdict, score, '>=', conf.junk_threshold)
        learn_spam = true
      end
    elseif verdict == 'ham' then
      if score <= conf.ham_threshold then
        log_can_autolearn(task, verdict, score, '<=', conf.ham_threshold)
        learn_ham = true
      end
    end
  elseif conf.learn_verdict then
    if verdict == 'spam' or verdict == 'junk' then
      learn_spam = true
    elseif verdict == 'ham' then
      learn_ham = true
    end
  end

  if conf.check_balance then
    learn_spam, learn_ham = apply_learn_balance(task, conf, learn_spam, learn_ham)
  end

  if learn_spam then
    return 'spam'
  elseif learn_ham then
    return 'ham'
  end
end
-- Expose the public functions (can_learn, autolearn) to `require` callers
return exports
|