123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247 |
- --[[
- Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- ]]
-
- -- This file contains functions to support Bayes statistics in Redis
-
- local exports = {}
- local lua_redis = require "lua_redis"
- local logger = require "rspamd_logger"
- local lua_util = require "lua_util"
- local ucl = require "ucl"
-
- local N = "bayes"
-
--- Build the classification functor for one statfile.
--- The returned function runs the classify script against `expanded_key`
--- with the statistics tokens and reports back through `callback`:
---   * Redis error: callback(task, false, err)
---   * success:     callback(task, true, data[1], data[2], data[3], data[4])
--- @param redis_params Redis parameters (not used by the functor itself)
--- @param classify_script_id id returned by lua_redis.load_redis_script_from_file
--- @return function(task, expanded_key, id, is_spam, stat_tokens, callback)
local function gen_classify_functor(redis_params, classify_script_id)
  return function(task, expanded_key, id, is_spam, stat_tokens, callback)
    -- Classification only reads from Redis
    local script_opts = {
      task = task,
      is_write = false,
      key = expanded_key,
    }

    local function on_reply(err, data)
      lua_util.debugm(N, task, 'classify redis cb: %s, %s', err, data)
      if not err then
        -- Forward the first four elements of the script reply unchanged
        callback(task, true, data[1], data[2], data[3], data[4])
        return
      end
      callback(task, false, err)
    end

    lua_redis.exec_redis_script(classify_script_id, script_opts, on_reply,
        { expanded_key, stat_tokens })
  end
end
-
--- Build the learn functor for one statfile.
--- The returned function runs the learn script (a write operation) against
--- `expanded_key`. The script arguments are the key, the stringified spam flag,
--- the symbol, the stringified unlearn flag and the statistics tokens, plus
--- `maybe_text_tokens` as a sixth argument when it is provided.
--- The outcome is reported as callback(task, true) or callback(task, false, err).
--- @param redis_params Redis parameters (not used by the functor itself)
--- @param learn_script_id id returned by lua_redis.load_redis_script_from_file
--- @return function(task, expanded_key, id, is_spam, symbol, is_unlearn, stat_tokens, callback, maybe_text_tokens)
local function gen_learn_functor(redis_params, learn_script_id)
  return function(task, expanded_key, id, is_spam, symbol, is_unlearn, stat_tokens, callback, maybe_text_tokens)
    local function on_reply(err, data)
      lua_util.debugm(N, task, 'learn redis cb: %s, %s', err, data)
      if not err then
        callback(task, true)
        return
      end
      callback(task, false, err)
    end

    local spam_flag = tostring(is_spam)
    local unlearn_flag = tostring(is_unlearn)

    -- Both arguments are built with a table constructor, so the array layout
    -- stays fixed even if `symbol` or `stat_tokens` happen to be nil.
    local script_args
    if maybe_text_tokens then
      script_args = { expanded_key, spam_flag, symbol, unlearn_flag, stat_tokens, maybe_text_tokens }
    else
      script_args = { expanded_key, spam_flag, symbol, unlearn_flag, stat_tokens }
    end

    lua_redis.exec_redis_script(learn_script_id,
        { task = task, is_write = true, key = expanded_key },
        on_reply, script_args)
  end
end
-
--- Resolve Redis connection parameters for a Bayes statfile/classifier.
--- Option sources are tried from the most specific to the most generic. The
--- first one that lua_redis.try_load_redis_servers accepts wins:
---   1. statfile_ucl.redis
---   2. statfile_ucl itself
---   3. classifier_ucl.backend
---   4. classifier_ucl.redis
---   5. classifier_ucl itself
---   6. the global `redis` section of rspamd_config
--- Either config table may be nil; missing sources are simply skipped.
--- @param classifier_ucl classifier configuration table (or nil)
--- @param statfile_ucl statfile configuration table (or nil)
--- @return Redis parameters, or nil (after logging an error) when no source works
local function load_redis_params(classifier_ucl, statfile_ucl)
  local candidates = {}

  -- Statfile-level options. The original code indexed statfile_ucl.redis
  -- before checking statfile_ucl for nil, so a nil statfile raised an error
  -- instead of falling back to the other sources.
  if statfile_ucl then
    if statfile_ucl.redis then
      candidates[#candidates + 1] = statfile_ucl.redis
    end
    candidates[#candidates + 1] = statfile_ucl
  end

  -- Classifier-level options
  if classifier_ucl then
    if classifier_ucl.backend then
      candidates[#candidates + 1] = classifier_ucl.backend
    end
    if classifier_ucl.redis then
      candidates[#candidates + 1] = classifier_ucl.redis
    end
    candidates[#candidates + 1] = classifier_ucl
  end

  for _, opts in ipairs(candidates) do
    local redis_params = lua_redis.try_load_redis_servers(opts, rspamd_config, true)
    if redis_params then
      return redis_params
    end
  end

  -- Global options, fetched only when nothing more specific worked
  local redis_params = lua_redis.try_load_redis_servers(rspamd_config:get_all_opt('redis'), rspamd_config, true)
  if redis_params then
    return redis_params
  end

  logger.err(rspamd_config, "cannot load Redis parameters for the classifier")
  return nil
end
-
---
--- Init bayes classifier for a single statfile.
--- Loads the classify/learn/stat Redis scripts. When `ev_base` is given, it also
--- registers a periodic job that runs the stat script in cursor steps and
--- accumulates the `users` and `revision` (number of learns) counters. When the
--- script returns a zero cursor, the accumulated totals are handed to
--- stat_periodic_cb(cfg, totals) and accumulation restarts from zero.
--- @param classifier_ucl ucl of the classifier config (`max_users` and
---   `monitor_timeout` are read from it)
--- @param statfile_ucl ucl of the statfile config (`monitor_timeout` is read from it)
--- @param symbol statfile symbol, passed to the stat script and used in log messages
--- @param is_spam selects the "learns_spam" or "learns_ham" counter for the stat script
--- @param ev_base event base; the periodic stat job is only registered when set
--- @param stat_periodic_cb function(cfg, totals), called once per completed iteration
--- @return a pair of (classify_functor, learn_functor) or `nil` in case of error
exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol, is_spam, ev_base, stat_periodic_cb)
  local redis_params = load_redis_params(classifier_ucl, statfile_ucl)
  if not redis_params then
    return nil
  end

  local function load_script(script_name)
    return lua_redis.load_redis_script_from_file(script_name, redis_params)
  end

  local classify_script_id = load_script("bayes_classify.lua")
  local learn_script_id = load_script("bayes_learn.lua")
  local stat_script_id = load_script("bayes_stat.lua")
  local max_users = classifier_ucl.max_users or 1000

  if ev_base then
    local function empty_totals()
      return { users = 0, revision = 0 }
    end

    -- Counters for the iteration currently in progress and its resume cursor
    local partial = empty_totals()
    local scan_cursor = 0

    rspamd_config:add_periodic(ev_base, 0.0, function(cfg, _)
      local function on_stat(err, data)
        lua_util.debugm(N, cfg, 'stat redis cb: %s, %s', err, data)

        if err then
          -- Keep the cursor and partial counters; the next run retries from here
          logger.warn(cfg, 'cannot get bayes statistics for %s: %s', symbol, err)
          return
        end

        local next_cursor = data[1]
        partial.users = partial.users + data[2]
        partial.revision = partial.revision + data[3]

        if next_cursor == 0 then
          -- Iteration finished: publish a snapshot and start counting afresh
          local totals = lua_util.shallowcopy(partial)
          partial = empty_totals()
          lua_util.debugm(N, cfg, 'final data: %s', totals)
          stat_periodic_cb(cfg, totals)
        end

        scan_cursor = next_cursor
      end

      lua_redis.exec_redis_script(stat_script_id,
          { ev_base = ev_base, cfg = cfg, is_write = false },
          on_stat,
          {
            tostring(scan_cursor),
            symbol,
            is_spam and "learns_spam" or "learns_ham",
            tostring(max_users),
          })

      -- NOTE(review): presumably the delay in seconds until the next run of
      -- this periodic; confirm against rspamd_config:add_periodic
      return statfile_ucl.monitor_timeout or classifier_ucl.monitor_timeout or 30.0
    end)
  end

  return gen_classify_functor(redis_params, classify_script_id), gen_learn_functor(redis_params, learn_script_id)
end
-
--- Build the functor that checks whether a message id is in the learn cache.
--- `conf` is packed to msgpack once, when the functor is built, and sent to the
--- check script on every call. The result is reported through `callback`:
---   * Redis error:           callback(task, false, err)
---   * numeric script reply:  callback(task, true, data)
---   * any other reply:       callback(task, false, 'not found')
--- @param redis_params Redis parameters (not used by the functor itself)
--- @param check_script_id id returned by lua_redis.load_redis_script_from_file
--- @param conf cache configuration table
--- @return function(task, cache_id, callback)
local function gen_cache_check_functor(redis_params, check_script_id, conf)
  local packed_conf = ucl.to_format(conf, 'msgpack')

  return function(task, cache_id, callback)
    local function on_reply(err, data)
      lua_util.debugm(N, task, 'check cache redis cb: %s, %s (%s)', err, data, type(data))

      if err then
        callback(task, false, err)
        return
      end

      if type(data) ~= 'number' then
        callback(task, false, 'not found')
        return
      end

      callback(task, true, data)
    end

    lua_util.debugm(N, task, 'checking cache: %s', cache_id)
    local script_opts = { task = task, is_write = false, key = cache_id }
    lua_redis.exec_redis_script(check_script_id, script_opts, on_reply, { cache_id, packed_conf })
  end
end
-
--- Build the functor that records a message id in the learn cache.
--- `conf` is packed to msgpack once, when the functor is built. Each call runs
--- the cache learn script (a write operation) with the id, "1"/"0" for
--- spam/ham and the packed config. This is best effort: no result is returned
--- to the caller, but Redis errors are now logged at warn level. Previously
--- they were only visible with debug logging enabled for this module.
--- @param redis_params Redis parameters (not used by the functor itself)
--- @param learn_script_id id returned by lua_redis.load_redis_script_from_file
--- @param conf cache configuration table
--- @return function(task, cache_id, is_spam)
local function gen_cache_learn_functor(redis_params, learn_script_id, conf)
  local packed_conf = ucl.to_format(conf, 'msgpack')
  return function(task, cache_id, is_spam)
    local function learn_redis_cb(err, data)
      lua_util.debugm(N, task, 'learn_cache redis cb: %s, %s', err, data)
      if err then
        -- Do not fail the caller, but make the failed cache write visible
        logger.warn(task, 'cannot learn id %s in bayes cache: %s', cache_id, err)
      end
    end

    lua_util.debugm(N, task, 'try to learn cache: %s', cache_id)
    lua_redis.exec_redis_script(learn_script_id,
        { task = task, is_write = true, key = cache_id },
        learn_redis_cb,
        { cache_id, is_spam and "1" or "0", packed_conf })
  end
end
-
--- Init the Bayes learn cache.
--- Merges the classifier options over the cache defaults and drops every key
--- that is not a known cache option. It then loads the cache check/learn
--- scripts and builds the corresponding functors.
--- @param classifier_ucl ucl of the classifier config
--- @param statfile_ucl ucl of the statfile config (used to resolve Redis params)
--- @return a pair of (cache_check_functor, cache_learn_functor) or `nil` if
---   Redis parameters cannot be loaded
exports.lua_bayes_init_cache = function(classifier_ucl, statfile_ucl)
  local redis_params = load_redis_params(classifier_ucl, statfile_ucl)
  if not redis_params then
    return nil
  end

  local defaults = {
    cache_prefix = "learned_ids",
    cache_max_elt = 10000, -- Maximum number of elements in the cache key
    cache_max_keys = 5, -- Maximum number of keys in the cache
    cache_elt_len = 32, -- Length of the element in the cache (will trim id to that value)
  }

  local cache_conf = lua_util.override_defaults(defaults, classifier_ucl)
  -- Prune every option that has no default (clearing keys during pairs is allowed)
  for opt_name in pairs(cache_conf) do
    if defaults[opt_name] == nil then
      cache_conf[opt_name] = nil
    end
  end

  local check_script_id = lua_redis.load_redis_script_from_file("bayes_cache_check.lua", redis_params)
  local learn_script_id = lua_redis.load_redis_script_from_file("bayes_cache_learn.lua", redis_params)

  local check_functor = gen_cache_check_functor(redis_params, check_script_id, cache_conf)
  local learn_functor = gen_cache_learn_functor(redis_params, learn_script_id, cache_conf)
  return check_functor, learn_functor
end
-
- return exports
|