You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

lua_bayes_learn.lua 3.5KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. --[[
  2. Copyright (c) 2019, Vsevolod Stakhov <vsevolod@highsecure.ru>
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ]]--
  13. -- This file contains functions to simplify bayes classifier auto-learning
  14. local lua_util = require "lua_util"
  15. local N = "lua_bayes"
  16. local exports = {}
  17. exports.can_learn = function(task, is_spam, is_unlearn)
  18. local learn_type = task:get_request_header('Learn-Type')
  19. if not (learn_type and tostring(learn_type) == 'bulk') then
  20. local prob = task:get_mempool():get_variable('bayes_prob', 'double')
  21. if prob then
  22. local in_class = false
  23. local cl
  24. if is_spam then
  25. cl = 'spam'
  26. in_class = prob >= 0.95
  27. else
  28. cl = 'ham'
  29. in_class = prob <= 0.05
  30. end
  31. if in_class then
  32. return false,string.format(
  33. 'already in class %s; probability %.2f%%',
  34. cl, math.abs((prob - 0.5) * 200.0))
  35. end
  36. end
  37. end
  38. return true
  39. end
  40. exports.autolearn = function(task, conf)
  41. -- We have autolearn config so let's figure out what is requested
  42. local verdict,score = lua_util.get_task_verdict(task)
  43. local learn_spam,learn_ham = false, false
  44. if verdict == 'passthrough' then
  45. -- No need to autolearn
  46. lua_util.debugm(N, task, 'no need to autolearn - verdict: %s',
  47. verdict)
  48. return
  49. end
  50. if conf.spam_threshold and conf.ham_threshold then
  51. if verdict == 'spam' then
  52. if conf.spam_threshold and score >= conf.spam_threshold then
  53. lua_util.debugm(N, task, 'can autolearn spam: score %s >= %s',
  54. score, conf.spam_threshold)
  55. learn_spam = true
  56. end
  57. elseif verdict == 'ham' then
  58. if conf.ham_threshold and score <= conf.ham_threshold then
  59. lua_util.debugm(N, task, 'can autolearn ham: score %s <= %s',
  60. score, conf.ham_threshold)
  61. learn_ham = true
  62. end
  63. end
  64. end
  65. if conf.check_balance then
  66. -- Check balance of learns
  67. local spam_learns = task:get_mempool():get_variable('spam_learns', 'int64') or 0
  68. local ham_learns = task:get_mempool():get_variable('ham_learns', 'int64') or 0
  69. local min_balance = 0.9
  70. if conf.min_balance then min_balance = conf.min_balance end
  71. if spam_learns > 0 or ham_learns > 0 then
  72. local max_ratio = 1.0 / min_balance
  73. local spam_learns_ratio = spam_learns / (ham_learns + 1)
  74. if spam_learns_ratio > max_ratio and learn_spam then
  75. lua_util.debugm(N, task,
  76. 'skip learning spam, balance is not satisfied: %s < %s; %s spam learns; %s ham learns',
  77. spam_learns_ratio, min_balance, spam_learns, ham_learns)
  78. learn_spam = false
  79. end
  80. local ham_learns_ratio = ham_learns / (spam_learns + 1)
  81. if ham_learns_ratio > max_ratio and learn_ham then
  82. lua_util.debugm(N, task,
  83. 'skip learning ham, balance is not satisfied: %s < %s; %s spam learns; %s ham learns',
  84. ham_learns_ratio, min_balance, spam_learns, ham_learns)
  85. learn_ham = false
  86. end
  87. end
  88. end
  89. if learn_spam then
  90. return 'spam'
  91. elseif learn_ham then
  92. return 'ham'
  93. end
  94. end
  95. return exports