You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

lua_bayes_redis.lua 5.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. --[[
  2. Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ]]
  13. -- This file contains functions to support Bayes statistics in Redis
  14. local exports = {}
  15. local lua_redis = require "lua_redis"
  16. local logger = require "rspamd_logger"
  17. local lua_util = require "lua_util"
  18. local N = "bayes"
  --- Build the classify callback for a Redis-backed Bayes statfile.
  --- The returned closure runs the classify script (`classify_script_id`) for
  --- `expanded_key` and `stat_tokens`, then reports through `callback`:
  ---   * on error:   callback(task, false, err)
  ---   * on success: callback(task, true, r1, r2, r3, r4), where r1..r4 are the
  ---     first four elements of the script reply
  --- `redis_params`, `id` and `is_spam` are accepted for interface compatibility
  --- but are not used here.
  local function gen_classify_functor(redis_params, classify_script_id)
    return function(task, expanded_key, id, is_spam, stat_tokens, callback)
      local exec_opts = { task = task, is_write = false, key = expanded_key }
      local script_args = { expanded_key, stat_tokens }

      lua_redis.exec_redis_script(classify_script_id, exec_opts,
          function(err, data)
            lua_util.debugm(N, task, 'classify redis cb: %s, %s', err, data)
            if not err then
              callback(task, true, data[1], data[2], data[3], data[4])
              return
            end
            callback(task, false, err)
          end,
          script_args)
    end
  end
  --- Build the learn callback for a Redis-backed Bayes statfile.
  --- The returned closure runs the learn script (`learn_script_id`) with the
  --- arguments { expanded_key, tostring(is_spam), symbol, tostring(is_unlearn),
  --- stat_tokens } and reports through `callback`:
  ---   * on error:   callback(task, false, err)
  ---   * on success: callback(task, true)
  --- `redis_params` and `id` are accepted for interface compatibility but are
  --- not used here.
  local function gen_learn_functor(redis_params, learn_script_id)
    return function(task, expanded_key, id, is_spam, symbol, is_unlearn, stat_tokens, callback)
      local function learn_redis_cb(err, data)
        lua_util.debugm(N, task, 'learn redis cb: %s, %s', err, data)
        if err then
          callback(task, false, err)
        else
          callback(task, true)
        end
      end

      -- Learning (and unlearning) modifies the statistics stored in Redis, so
      -- the script must be marked as a write operation; with is_write = false
      -- lua_redis would pick a server from the read pool instead.
      lua_redis.exec_redis_script(learn_script_id,
          { task = task, is_write = true, key = expanded_key },
          learn_redis_cb,
          { expanded_key, tostring(is_spam), symbol, tostring(is_unlearn), stat_tokens })
    end
  end
  --- Load Redis parameters from a single options table.
  --- Falsy `opts` are skipped instead of being handed to
  --- `lua_redis.try_load_redis_servers`, so a missing section (for example no
  --- global `redis` options) simply yields nil.
  --- @param opts options table to load from (may be nil)
  --- @return Redis parameters, or nil if none could be loaded
  local function load_bayes_redis_opts(opts)
    if not opts then
      return nil
    end
    return lua_redis.try_load_redis_servers(opts, rspamd_config, true)
  end

  ---
  --- Init bayes classifier
  --- Redis parameters are looked up from the most specific to the most general
  --- place: the statfile `redis` section, the statfile itself, the classifier
  --- `backend`, the classifier `redis` section, the classifier itself and
  --- finally the global `redis` options.
  --- A periodic task is also registered. It repeatedly runs bayes_stat.lua with a
  --- cursor, accumulates the returned counters and, once the returned cursor is
  --- 0, passes the totals to `stat_periodic_cb`.
  --- @param classifier_ucl ucl of the classifier config
  --- @param statfile_ucl ucl of the statfile config
  --- @param symbol statfile symbol; passed to the stat script and used in log messages
  --- @param is_spam true for a spam statfile ("learns_spam" is passed to the stat
  ---        script), false for a ham one ("learns_ham")
  --- @param ev_base event base used for the periodic statistics task
  --- @param stat_periodic_cb called as stat_periodic_cb(cfg, { users = ..., revision = ... })
  ---        after each complete pass
  --- @return a pair of (classify_functor, learn_functor) or `nil` in case of error
  exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol, is_spam, ev_base, stat_periodic_cb)
    local redis_params
    -- Try load from statfile options (statfile_ucl is checked before it is indexed)
    if statfile_ucl then
      redis_params = load_bayes_redis_opts(statfile_ucl.redis)
          or load_bayes_redis_opts(statfile_ucl)
    end
    -- Then classifier options, then global options; `or` stops at the first hit
    redis_params = redis_params
        or load_bayes_redis_opts(classifier_ucl.backend)
        or load_bayes_redis_opts(classifier_ucl.redis)
        or load_bayes_redis_opts(classifier_ucl)
        or load_bayes_redis_opts(rspamd_config:get_all_opt('redis'))

    if not redis_params then
      logger.err(rspamd_config, "cannot load Redis parameters for the classifier")
      return nil
    end

    local classify_script_id = lua_redis.load_redis_script_from_file("bayes_classify.lua", redis_params)
    local learn_script_id = lua_redis.load_redis_script_from_file("bayes_learn.lua", redis_params)
    local stat_script_id = lua_redis.load_redis_script_from_file("bayes_stat.lua", redis_params)
    -- Forwarded to bayes_stat.lua as its 4th argument; its exact meaning is
    -- defined by that script.
    local max_users = classifier_ucl.max_users or 1000

    -- Counters accumulated over the pass that is currently in progress
    local current_data = {
      users = 0,
      revision = 0,
    }
    -- Counters of the last completed pass
    local final_data = {
      users = 0,
      revision = 0, -- number of learns
    }
    local cursor = 0

    rspamd_config:add_periodic(ev_base, 0.0, function(cfg, _)
      local function stat_redis_cb(err, data)
        lua_util.debugm(N, cfg, 'stat redis cb: %s, %s', err, data)
        if err then
          -- Keep the cursor and partial counters: the next tick retries this step
          logger.warn(cfg, 'cannot get bayes statistics for %s: %s', symbol, err)
          return
        end

        local new_cursor = data[1]
        current_data.users = current_data.users + (tonumber(data[2]) or 0)
        current_data.revision = current_data.revision + (tonumber(data[3]) or 0)
        -- The cursor may arrive as a number or as a string ("0") depending on
        -- how the script returns it; compare numerically so the end of the
        -- pass is always detected.
        if tonumber(new_cursor) == 0 then
          -- Done iteration
          final_data = lua_util.shallowcopy(current_data)
          current_data = {
            users = 0,
            revision = 0,
          }
          lua_util.debugm(N, cfg, 'final data: %s', final_data)
          stat_periodic_cb(cfg, final_data)
        end
        -- Store the cursor as received so it is sent back unchanged
        cursor = new_cursor
      end

      lua_redis.exec_redis_script(stat_script_id,
          { ev_base = ev_base, cfg = cfg, is_write = false },
          stat_redis_cb, { tostring(cursor),
                           symbol,
                           is_spam and "learns_spam" or "learns_ham",
                           tostring(max_users) })
      -- Value returned to add_periodic: statfile setting, then classifier setting, then 30 seconds
      return (statfile_ucl and statfile_ucl.monitor_timeout) or classifier_ucl.monitor_timeout or 30.0
    end)

    return gen_classify_functor(redis_params, classify_script_id), gen_learn_functor(redis_params, learn_script_id)
  end
  129. return exports