You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

lua_bayes_learn.lua 4.4KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. --[[
  2. Copyright (c) 2019, Vsevolod Stakhov <vsevolod@highsecure.ru>
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ]]--
  13. -- This file contains functions to simplify bayes classifier auto-learning
  14. local lua_util = require "lua_util"
  15. local lua_verdict = require "lua_verdict"
  16. local N = "lua_bayes" -- module name passed as the first argument to lua_util.debugm
  17. local exports = {} -- module table: can_learn and autolearn are attached to it and it is returned at the end of the file
  --- Tell whether an explicit learn request is still worth performing.
  -- Requests whose 'Learn-Type' request header equals 'bulk' are always
  -- accepted. Otherwise the 'bayes_prob' mempool variable (when present) is
  -- compared against fixed cut-offs: >= 0.95 for spam, <= 0.05 for ham.
  -- @param task the task being learned (must provide get_request_header and get_mempool)
  -- @param is_spam truthy when the message is being learned as spam, falsy for ham
  -- @param is_unlearn accepted for interface compatibility; not consulted here
  -- @return true when learning may proceed; otherwise false plus a reason string
  exports.can_learn = function(task, is_spam, is_unlearn)
    local learn_type = task:get_request_header('Learn-Type')
    if learn_type and tostring(learn_type) == 'bulk' then
      -- Bulk learning bypasses the "already classified" check entirely
      return true
    end

    local bayes_prob = task:get_mempool():get_variable('bayes_prob', 'double')
    if not bayes_prob then
      -- No stored probability, nothing to compare against
      return true
    end

    local class_name, already_classified
    if is_spam then
      class_name, already_classified = 'spam', bayes_prob >= 0.95
    else
      class_name, already_classified = 'ham', bayes_prob <= 0.05
    end

    if not already_classified then
      return true
    end

    -- Distance from the neutral 0.5 point, rescaled to a 0..100 percentage
    local confidence = math.abs((bayes_prob - 0.5) * 200.0)
    return false, string.format('already in class %s; probability %.2f%%',
      class_name, confidence)
  end
  --- Decide whether a message should be auto-learned and as which class.
  -- The verdict and score come from lua_verdict.get_specific_verdict("bayes", task).
  -- Two selection modes exist:
  --   * threshold mode, used only when BOTH conf.spam_threshold and
  --     conf.ham_threshold are set: learn spam when the verdict is 'spam' and
  --     score >= spam_threshold, learn ham when the verdict is 'ham' and
  --     score <= ham_threshold;
  --   * verdict mode, used otherwise when conf.learn_verdict is set: learn spam
  --     for 'spam' or 'junk' verdicts, learn ham for 'ham'.
  -- When conf.check_balance is set, a class that is already over-represented
  -- (per the 'spam_learns' / 'ham_learns' mempool variables) is not learned.
  -- @param task the task being processed
  -- @param conf autolearn settings table; fields read: spam_threshold,
  --   ham_threshold, learn_verdict, check_balance, min_balance (default 0.9)
  -- @return 'spam', 'ham', or nothing when no learning should happen
  exports.autolearn = function(task, conf)
    -- Emit a debug line describing why a threshold-based autolearn was chosen
    local function log_can_autolearn(verdict, score, threshold)
      local from = task:get_from('smtp')
      -- Guard from[1] as well: an empty sender table must not raise
      local from_addr = from and from[1] and from[1].addr or 'undef'

      local mime_rcpts = 'undef'
      local mr = task:get_recipients('mime')
      if mr then
        local addrs = {}
        for _, r in ipairs(mr) do
          -- Entries without an address are skipped instead of raising
          addrs[#addrs + 1] = r.addr
        end
        if #addrs > 0 then
          mime_rcpts = table.concat(addrs, ',')
        end
      end

      lua_util.debugm(N, task, 'id: %s, from: <%s>: can autolearn %s: score %s %s %s, mime_rcpts: <%s>',
        task:get_header('Message-Id') or '<undef>',
        from_addr,
        verdict,
        string.format("%.2f", score),
        verdict == 'ham' and '<=' or verdict == 'spam' and '>=' or '/',
        threshold,
        mime_rcpts)
    end

    -- We have autolearn config so let's figure out what is requested
    local verdict, score = lua_verdict.get_specific_verdict("bayes", task)
    local learn_spam, learn_ham = false, false

    if verdict == 'passthrough' then
      -- No need to autolearn
      lua_util.debugm(N, task, 'no need to autolearn - verdict: %s',
        verdict)
      return
    end

    if conf.spam_threshold and conf.ham_threshold then
      -- Threshold mode: both limits are known to be set in this branch
      if verdict == 'spam' then
        if score >= conf.spam_threshold then
          log_can_autolearn(verdict, score, conf.spam_threshold)
          learn_spam = true
        end
      elseif verdict == 'ham' then
        if score <= conf.ham_threshold then
          log_can_autolearn(verdict, score, conf.ham_threshold)
          learn_ham = true
        end
      end
    elseif conf.learn_verdict then
      -- Verdict mode: 'junk' is learned as spam too
      if verdict == 'spam' or verdict == 'junk' then
        learn_spam = true
      elseif verdict == 'ham' then
        learn_ham = true
      end
    end

    if conf.check_balance then
      -- Check balance of learns
      local spam_learns = task:get_mempool():get_variable('spam_learns', 'int64') or 0
      local ham_learns = task:get_mempool():get_variable('ham_learns', 'int64') or 0

      local min_balance = 0.9
      if conf.min_balance then min_balance = conf.min_balance end

      if spam_learns > 0 or ham_learns > 0 then
        -- A class may outnumber the other by at most 1 / min_balance
        local max_ratio = 1.0 / min_balance

        -- The +1 keeps the denominator non-zero when the other class has no learns
        local spam_learns_ratio = spam_learns / (ham_learns + 1)
        if spam_learns_ratio > max_ratio and learn_spam then
          -- Report the comparison that actually fired (ratio against max_ratio)
          lua_util.debugm(N, task,
            'skip learning spam, balance is not satisfied: %s > %s; %s spam learns; %s ham learns',
            spam_learns_ratio, max_ratio, spam_learns, ham_learns)
          learn_spam = false
        end

        local ham_learns_ratio = ham_learns / (spam_learns + 1)
        if ham_learns_ratio > max_ratio and learn_ham then
          lua_util.debugm(N, task,
            'skip learning ham, balance is not satisfied: %s > %s; %s spam learns; %s ham learns',
            ham_learns_ratio, max_ratio, spam_learns, ham_learns)
          learn_ham = false
        end
      end
    end

    if learn_spam then
      return 'spam'
    elseif learn_ham then
      return 'ham'
    end
  end
  122. return exports -- return the module table carrying can_learn and autolearn