You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

lua_bayes_learn.lua 4.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. --[[
  2. Copyright (c) 2019, Vsevolod Stakhov <vsevolod@highsecure.ru>
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ]]--
  13. -- This file contains functions to simplify bayes classifier auto-learning
  14. local lua_util = require "lua_util"
  15. local N = "lua_bayes"
  16. local exports = {}
  --- Checks whether learning a message as the requested class is worthwhile.
  -- Learning is always allowed when the `Learn-Type` request header equals
  -- 'bulk', or when no `bayes_prob` variable is stored in the task mempool.
  -- Otherwise learning is refused if the stored probability already places
  -- the message in the requested class (>= 0.95 for spam, <= 0.05 for ham).
  -- @param task object providing `get_request_header` and `get_mempool`
  -- @param is_spam truthy to learn as spam, falsy to learn as ham
  -- @param is_unlearn accepted for interface compatibility; not read here
  -- @return true when learning may proceed; otherwise false plus a
  --   human-readable reason string
  exports.can_learn = function(task, is_spam, is_unlearn)
    local header = task:get_request_header('Learn-Type')
    if header and tostring(header) == 'bulk' then
      -- Bulk learning bypasses the "already classified" check
      return true
    end

    local bayes_prob = task:get_mempool():get_variable('bayes_prob', 'double')
    if not bayes_prob then
      -- Nothing recorded for this task, so there is nothing to compare with
      return true
    end

    local class_name, already_classified
    if is_spam then
      class_name, already_classified = 'spam', bayes_prob >= 0.95
    else
      class_name, already_classified = 'ham', bayes_prob <= 0.05
    end

    if not already_classified then
      return true
    end

    -- Distance from the neutral 0.5 point, rescaled to a 0..100 percentage
    return false, string.format(
      'already in class %s; probability %.2f%%',
      class_name, math.abs((bayes_prob - 0.5) * 200.0))
  end
  --- Decides whether a task should be auto-learned, and as which class.
  -- The verdict and score come from `lua_util.get_task_verdict`. A message
  -- qualifies as spam when the verdict is 'spam' and the score is at least
  -- `conf.spam_threshold`, and as ham when the verdict is 'ham' and the score
  -- is at most `conf.ham_threshold`; both thresholds must be configured for
  -- either check to run. When `conf.check_balance` is set, learning a class
  -- is skipped if that class already outweighs the other by more than
  -- 1 / min_balance (`conf.min_balance`, default 0.9), based on the
  -- `spam_learns` / `ham_learns` mempool variables.
  -- @param task object providing `get_from`, `get_recipients`, `get_header`
  --   and `get_mempool`
  -- @param conf table with optional `spam_threshold`, `ham_threshold`,
  --   `check_balance` and `min_balance` fields
  -- @return 'spam', 'ham', or nothing when the message should not be learned
  exports.autolearn = function(task, conf)
    -- Writes a debug record describing a message that passed its threshold
    local function log_can_autolearn(verdict, score, threshold)
      local from = task:get_from('smtp')

      -- Collect addresses first and join once instead of growing a string
      local mime_rcpts = 'undef'
      local mr = task:get_recipients('mime')
      if mr then
        local addrs = {}
        for _, r in ipairs(mr) do
          addrs[#addrs + 1] = r.addr
        end
        if #addrs > 0 then
          mime_rcpts = table.concat(addrs, ',')
        end
      end

      lua_util.debugm(N, task,
        'id: %s, from: <%s>: can autolearn %s: score %s %s %s, mime_rcpts: <%s>',
        task:get_header('Message-Id') or '<undef>',
        -- Guard from[1] as well: an empty sender list must not raise here
        from and from[1] and from[1].addr or 'undef',
        verdict,
        string.format("%.2f", score),
        verdict == 'ham' and '<=' or verdict == 'spam' and '>=' or '/',
        threshold,
        mime_rcpts)
    end

    -- We have autolearn config so let's figure out what is requested
    local verdict, score = lua_util.get_task_verdict(task)
    local learn_spam, learn_ham = false, false

    if verdict == 'passthrough' then
      -- No need to autolearn
      lua_util.debugm(N, task, 'no need to autolearn - verdict: %s',
        verdict)
      return
    end

    if conf.spam_threshold and conf.ham_threshold then
      if verdict == 'spam' then
        if score >= conf.spam_threshold then
          log_can_autolearn(verdict, score, conf.spam_threshold)
          learn_spam = true
        end
      elseif verdict == 'ham' then
        if score <= conf.ham_threshold then
          log_can_autolearn(verdict, score, conf.ham_threshold)
          learn_ham = true
        end
      end
    end

    if conf.check_balance then
      -- Check balance of learns
      local spam_learns = task:get_mempool():get_variable('spam_learns', 'int64') or 0
      local ham_learns = task:get_mempool():get_variable('ham_learns', 'int64') or 0

      local min_balance = 0.9
      if conf.min_balance then min_balance = conf.min_balance end

      if spam_learns > 0 or ham_learns > 0 then
        local max_ratio = 1.0 / min_balance

        -- The +1 in the denominators avoids dividing by zero when the
        -- opposite class has not been learned yet
        local spam_learns_ratio = spam_learns / (ham_learns + 1)
        if spam_learns_ratio > max_ratio and learn_spam then
          -- Log the comparison that was actually made (ratio vs max_ratio)
          lua_util.debugm(N, task,
            'skip learning spam, balance is not satisfied: %s > %s; %s spam learns; %s ham learns',
            spam_learns_ratio, max_ratio, spam_learns, ham_learns)
          learn_spam = false
        end

        local ham_learns_ratio = ham_learns / (spam_learns + 1)
        if ham_learns_ratio > max_ratio and learn_ham then
          lua_util.debugm(N, task,
            'skip learning ham, balance is not satisfied: %s > %s; %s spam learns; %s ham learns',
            ham_learns_ratio, max_ratio, spam_learns, ham_learns)
          learn_ham = false
        end
      end
    end

    if learn_spam then
      return 'spam'
    elseif learn_ham then
      return 'ham'
    end
  end
  115. return exports