You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

lua_bayes_learn.lua 4.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. --[[
  2. Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ]]--
  13. -- This file contains functions to simplify bayes classifier auto-learning
  14. local lua_util = require "lua_util"
  15. local lua_verdict = require "lua_verdict"
  16. local N = "lua_bayes"
  17. local exports = {}
  --- Tells whether a message may be learned as the requested class.
  -- Learning is refused only when the 'bayes_prob' mempool variable already
  -- places the message firmly in the requested class (>= 0.95 for spam,
  -- <= 0.05 for ham). Requests whose 'Learn-Type' header is 'bulk' skip
  -- this check entirely.
  -- @param task object providing get_request_header() and get_mempool()
  -- @param is_spam truthy when learning as spam, falsy when learning as ham
  -- @param is_unlearn accepted for interface compatibility; not used here
  -- @return true when learning is allowed; otherwise false plus a reason string
  exports.can_learn = function(task, is_spam, is_unlearn)
    local header = task:get_request_header('Learn-Type')
    if header and tostring(header) == 'bulk' then
      -- Bulk learning bypasses the "already in class" check
      return true
    end

    -- presumably the probability stored by an earlier classification pass;
    -- without it there is nothing to compare against
    local bayes_prob = task:get_mempool():get_variable('bayes_prob', 'double')
    if not bayes_prob then
      return true
    end

    local class_name, already_there
    if is_spam then
      class_name, already_there = 'spam', bayes_prob >= 0.95
    else
      class_name, already_there = 'ham', bayes_prob <= 0.05
    end

    if not already_there then
      return true
    end

    -- Distance from the neutral 0.5 point, rescaled to 0..100
    local confidence = math.abs((bayes_prob - 0.5) * 200.0)
    return false, string.format(
        'already in class %s; probability %.2f%%',
        class_name, confidence)
  end
  --- Decides whether a message should be auto-learned, and as which class.
  -- The verdict and score come from lua_verdict.get_specific_verdict("bayes", task).
  -- Two modes are supported, selected by `conf`:
  --   * threshold mode (both conf.spam_threshold and conf.ham_threshold set):
  --     'spam' learns spam when score >= spam_threshold, 'junk' learns spam
  --     when conf.junk_threshold is set and score >= junk_threshold, 'ham'
  --     learns ham when score <= ham_threshold;
  --   * verdict mode (conf.learn_verdict set): 'spam'/'junk' learn spam,
  --     'ham' learns ham, regardless of score.
  -- When conf.check_balance is set, learning a class is suppressed if that
  -- class already outnumbers the other by more than 1 / min_balance
  -- (conf.min_balance, default 0.9), using the 'spam_learns' / 'ham_learns'
  -- mempool variables.
  -- @param task task object (headers, addresses, mempool)
  -- @param conf autolearn configuration table (fields listed above)
  -- @return 'spam', 'ham', or nothing when the message should not be learned
  exports.autolearn = function(task, conf)
    -- Emits a debug line describing why the message qualifies for autolearn
    local function log_can_autolearn(verdict, score, threshold)
      local from = task:get_from('smtp')

      -- Collect recipient addresses, then join once instead of growing a string
      local rcpt_addrs = {}
      local mr = task:get_recipients('mime')
      if mr then
        for _, r in ipairs(mr) do
          -- Entries without an address are skipped instead of raising on concat
          if r.addr then
            rcpt_addrs[#rcpt_addrs + 1] = r.addr
          end
        end
      end
      local mime_rcpts = 'undef'
      if #rcpt_addrs > 0 then
        mime_rcpts = table.concat(rcpt_addrs, ',')
      end

      lua_util.debugm(N, task, 'id: %s, from: <%s>: can autolearn %s: score %s %s %s, mime_rcpts: <%s>',
          task:get_header('Message-Id') or '<undef>',
          -- Guard against an empty address list as well as a missing one
          from and from[1] and from[1].addr or 'undef',
          verdict,
          string.format("%.2f", score),
          verdict == 'ham' and '<=' or verdict == 'spam' and '>=' or '/',
          threshold,
          mime_rcpts)
    end

    -- We have autolearn config so let's figure out what is requested
    local verdict, score = lua_verdict.get_specific_verdict("bayes", task)
    local learn_spam, learn_ham = false, false

    if verdict == 'passthrough' then
      -- No need to autolearn
      lua_util.debugm(N, task, 'no need to autolearn - verdict: %s',
          verdict)
      return
    end

    if conf.spam_threshold and conf.ham_threshold then
      -- Threshold mode: a missing score cannot be compared, so nothing is learned
      if score then
        if verdict == 'spam' then
          if score >= conf.spam_threshold then
            log_can_autolearn(verdict, score, conf.spam_threshold)
            learn_spam = true
          end
        elseif verdict == 'junk' then
          -- junk_threshold is optional even when the other two are set
          if conf.junk_threshold and score >= conf.junk_threshold then
            log_can_autolearn(verdict, score, conf.junk_threshold)
            learn_spam = true
          end
        elseif verdict == 'ham' then
          if score <= conf.ham_threshold then
            log_can_autolearn(verdict, score, conf.ham_threshold)
            learn_ham = true
          end
        end
      end
    elseif conf.learn_verdict then
      -- Verdict mode: trust the verdict alone
      if verdict == 'spam' or verdict == 'junk' then
        learn_spam = true
      elseif verdict == 'ham' then
        learn_ham = true
      end
    end

    if conf.check_balance then
      -- Check balance of learns
      local pool = task:get_mempool()
      local spam_learns = pool:get_variable('spam_learns', 'int64') or 0
      local ham_learns = pool:get_variable('ham_learns', 'int64') or 0
      local min_balance = conf.min_balance or 0.9

      if spam_learns > 0 or ham_learns > 0 then
        local max_ratio = 1.0 / min_balance

        -- +1 in the denominator avoids division by zero when one class is empty
        local spam_learns_ratio = spam_learns / (ham_learns + 1)
        if spam_learns_ratio > max_ratio and learn_spam then
          -- Report the comparison that was actually performed (ratio vs max_ratio)
          lua_util.debugm(N, task,
              'skip learning spam, balance is not satisfied: %s > %s; %s spam learns; %s ham learns',
              spam_learns_ratio, max_ratio, spam_learns, ham_learns)
          learn_spam = false
        end

        local ham_learns_ratio = ham_learns / (spam_learns + 1)
        if ham_learns_ratio > max_ratio and learn_ham then
          lua_util.debugm(N, task,
              'skip learning ham, balance is not satisfied: %s > %s; %s spam learns; %s ham learns',
              ham_learns_ratio, max_ratio, spam_learns, ham_learns)
          learn_ham = false
        end
      end
    end

    if learn_spam then
      return 'spam'
    elseif learn_ham then
      return 'ham'
    end
  end
  129. return exports