From e1fadcc80b5f6a3d566224b0ed1a74d7a9dbc9ed Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Wed, 24 Jul 2019 15:03:29 +0100 Subject: [PATCH] [Feature] Improve autolearning --- conf/statistic.conf | 2 +- lualib/lua_bayes_learn.lua | 67 +++++++++++++++++++++++ src/libserver/mempool_vars_internal.h | 2 + src/libstat/backends/redis_backend.c | 26 +++++++++ src/libstat/stat_internal.h | 1 + src/libstat/stat_process.c | 76 ++++++++++++++++++++++++++- 6 files changed, 172 insertions(+), 2 deletions(-) diff --git a/conf/statistic.conf b/conf/statistic.conf index 1e78c73cd..396564a23 100644 --- a/conf/statistic.conf +++ b/conf/statistic.conf @@ -41,7 +41,7 @@ classifier "bayes" { symbol = "BAYES_SPAM"; spam = true; } - learn_condition = "return require("lua_bayes_learn").can_learn" + learn_condition = 'return require("lua_bayes_learn").can_learn'; .include(try=true; priority=1) "$LOCAL_CONFDIR/local.d/classifier-bayes.conf" .include(try=true; priority=10) "$LOCAL_CONFDIR/override.d/classifier-bayes.conf" diff --git a/lualib/lua_bayes_learn.lua b/lualib/lua_bayes_learn.lua index 7df52a2ef..5a46265e7 100644 --- a/lualib/lua_bayes_learn.lua +++ b/lualib/lua_bayes_learn.lua @@ -16,6 +16,10 @@ limitations under the License. -- This file contains functions to simplify bayes classifier auto-learning +local lua_util = require "lua_util" + +local N = "lua_bayes" + local exports = {} exports.can_learn = function(task, is_spam, is_unlearn) @@ -46,4 +50,67 @@ exports.can_learn = function(task, is_spam, is_unlearn) return true end +exports.autolearn = function(task, conf) + -- We have autolearn config so let's figure out what is requested + local verdict,score = lua_util.get_task_verdict(task) + local learn_spam,learn_ham = false, false + + if verdict == 'passthrough' then + -- No need to autolearn + lua_util.debugm(N, task, 'no need to autolearn - verdict: %s', + verdict) + return + end + + if conf.spam_threshold and conf.ham_threshold then + if verdict == 'spam' then + if conf.spam_threshold and score >= conf.spam_threshold then + lua_util.debugm(N, task, 'can autolearn spam: score %s >= %s', + score, conf.spam_threshold) + learn_spam = true + end + elseif verdict == 'ham' then + if conf.ham_threshold and score <= conf.ham_threshold then + lua_util.debugm(N, task, 'can autolearn ham: score %s <= %s', + score, conf.ham_threshold) + learn_ham = true + end + end + end + + if conf.check_balance then + -- Check balance of learns + local spam_learns = task:get_mempool():get_variable('spam_learns', 'int64') or 0 + local ham_learns = task:get_mempool():get_variable('ham_learns', 'int64') or 0 + + local min_balance = 0.9 + if conf.min_balance then min_balance = conf.min_balance end + + if spam_learns > 0 or ham_learns > 0 then + local max_ratio = 1.0 / min_balance + local spam_learns_ratio = spam_learns / (ham_learns + 1) + if spam_learns_ratio > max_ratio and learn_spam then + lua_util.debugm(N, task, + 'skip learning spam, balance is not satisfied: %s < %s; %s spam learns; %s ham learns', + spam_learns_ratio, min_balance, spam_learns, ham_learns) + learn_spam = false + end + + local ham_learns_ratio = ham_learns / (spam_learns + 1) + if ham_learns_ratio > max_ratio and learn_ham then + lua_util.debugm(N, task, + 'skip learning ham, balance is not satisfied: %s < %s; %s spam learns; %s ham learns', + ham_learns_ratio, min_balance, spam_learns, ham_learns) + learn_ham = false + end + end + end + + if learn_spam then + return 'spam' + elseif learn_ham then + return 'ham' + end +end + return exports \ No newline at end of file diff --git a/src/libserver/mempool_vars_internal.h b/src/libserver/mempool_vars_internal.h index c062d44d4..576635a9b 100644 --- a/src/libserver/mempool_vars_internal.h +++ b/src/libserver/mempool_vars_internal.h @@ -38,5 +38,7 @@ #define RSPAMD_MEMPOOL_ARC_SIGN_SELECTOR "arc_selector" #define RSPAMD_MEMPOOL_STAT_SIGNATURE "stat_signature" #define RSPAMD_MEMPOOL_FUZZY_RESULT "fuzzy_hashes" +#define RSPAMD_MEMPOOL_SPAM_LEARNS "spam_learns" +#define RSPAMD_MEMPOOL_HAM_LEARNS "ham_learns" #endif diff --git a/src/libstat/backends/redis_backend.c b/src/libstat/backends/redis_backend.c index 7263b3c16..9ac6fb445 100644 --- a/src/libstat/backends/redis_backend.c +++ b/src/libstat/backends/redis_backend.c @@ -1230,6 +1230,32 @@ rspamd_redis_connected (redisAsyncContext *c, gpointer r, gpointer priv) rt->redis_object_expanded, rt->learned); rspamd_upstream_ok (rt->selected); + /* Save learn count in mempool variable */ + gint64 *learns_cnt; + const gchar *var_name; + + if (rt->stcf->is_spam) { + var_name = RSPAMD_MEMPOOL_SPAM_LEARNS; + } + else { + var_name = RSPAMD_MEMPOOL_HAM_LEARNS; + } + + learns_cnt = rspamd_mempool_get_variable (task->task_pool, + var_name); + + if (learns_cnt) { + (*learns_cnt) += rt->learned; + } + else { + learns_cnt = rspamd_mempool_alloc (task->task_pool, + sizeof (*learns_cnt)); + *learns_cnt = rt->learned; + rspamd_mempool_set_variable (task->task_pool, + var_name, + learns_cnt, NULL); + } + if (rt->learned >= rt->stcf->clcf->min_learns && rt->learned > 0) { rspamd_fstring_t *query = rspamd_redis_tokens_to_query ( task, diff --git a/src/libstat/stat_internal.h b/src/libstat/stat_internal.h index 967a3c4d6..5e2578177 100644 --- a/src/libstat/stat_internal.h +++ b/src/libstat/stat_internal.h @@ -43,6 +43,7 @@ struct rspamd_classifier { gpointer cachecf; gulong spam_learns; gulong ham_learns; + gint autolearn_cbref; struct rspamd_classifier_config *cfg; struct rspamd_stat_classifier *subrs; gpointer specific; diff --git a/src/libstat/stat_process.c b/src/libstat/stat_process.c index 034e1a5be..d720a77ab 100644 --- a/src/libstat/stat_process.c +++ b/src/libstat/stat_process.c @@ -906,6 +906,19 @@ rspamd_stat_has_classifier_symbols (struct rspamd_task *task, return FALSE; } +struct cl_cbref_dtor_data { + lua_State *L; + gint ref_idx; +}; + +static void +rspamd_stat_cbref_dtor (void *d) +{ + struct cl_cbref_dtor_data *data = (struct cl_cbref_dtor_data *)d; + + luaL_unref (data->L, LUA_REGISTRYINDEX, data->ref_idx); +} + gboolean rspamd_stat_check_autolearn (struct rspamd_task *task) { @@ -925,6 +938,8 @@ rspamd_stat_check_autolearn (struct rspamd_task *task) st_ctx = rspamd_stat_get_ctx (); g_assert (st_ctx != NULL); + L = task->cfg->lua_state; + for (i = 0; i < st_ctx->classifiers->len; i ++) { cl = g_ptr_array_index (st_ctx->classifiers, i); ret = FALSE; @@ -933,6 +948,7 @@ rspamd_stat_check_autolearn (struct rspamd_task *task) obj = ucl_object_lookup (cl->cfg->opts, "autolearn"); if (ucl_object_type (obj) == UCL_BOOLEAN) { + /* Legacy true/false */ if (ucl_object_toboolean (obj)) { /* * Default learning algorithm: @@ -956,6 +972,7 @@ rspamd_stat_check_autolearn (struct rspamd_task *task) } } else if (ucl_object_type (obj) == UCL_ARRAY && obj->len == 2) { + /* Legacy thresholds */ /* * We have an array of 2 elements, treat it as a * ham_score, spam_score @@ -994,8 +1011,8 @@ rspamd_stat_check_autolearn (struct rspamd_task *task) } } else if (ucl_object_type (obj) == UCL_STRING) { + /* Legacy sript */ lua_script = ucl_object_tostring (obj); - L = task->cfg->lua_state; if (luaL_dostring (L, lua_script) != 0) { msg_err_task ("cannot execute lua script for autolearn " @@ -1018,6 +1035,7 @@ rspamd_stat_check_autolearn (struct rspamd_task *task) else { lua_ret = lua_tostring (L, -1); + /* We can have immediate results */ if (lua_ret) { if (strcmp (lua_ret, "ham") == 0) { task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM; @@ -1041,6 +1059,62 @@ rspamd_stat_check_autolearn (struct rspamd_task *task) } } } + else if (ucl_object_type (obj) == UCL_OBJECT) { + /* Try to find autolearn callback */ + if (cl->autolearn_cbref == 0) { + /* We don't have preprocessed cb id, so try to get it */ + if (!rspamd_lua_require_function (L, "lua_bayes_learn", + "autolearn")) { + msg_err_task ("cannot get autolearn library from " + "`lua_bayes_learn`"); + } + else { + struct cl_cbref_dtor_data *dtor_data; + + dtor_data = (struct cl_cbref_dtor_data *) + rspamd_mempool_alloc (task->cfg->cfg_pool, + sizeof (*dtor_data)); + cl->autolearn_cbref = luaL_ref (L, LUA_REGISTRYINDEX); + dtor_data->L = L; + dtor_data->ref_idx = cl->autolearn_cbref; + rspamd_mempool_add_destructor (task->cfg->cfg_pool, + rspamd_stat_cbref_dtor, dtor_data); + } + } + + if (cl->autolearn_cbref != -1) { + lua_pushcfunction (L, &rspamd_lua_traceback); + err_idx = lua_gettop (L); + lua_rawgeti (L, LUA_REGISTRYINDEX, cl->autolearn_cbref); + + ptask = lua_newuserdata (L, sizeof (struct rspamd_task *)); + *ptask = task; + rspamd_lua_setclass (L, "rspamd{task}", -1); + /* Push the whole object as well */ + ucl_object_push_lua (L, obj, true); + + if (lua_pcall (L, 2, 1, err_idx) != 0) { + msg_err_task ("call to autolearn script failed: " + "%s", lua_tostring (L, -1)); + } + else { + lua_ret = lua_tostring (L, -1); + + if (lua_ret) { + if (strcmp (lua_ret, "ham") == 0) { + task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM; + ret = TRUE; + } + else if (strcmp (lua_ret, "spam") == 0) { + task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM; + ret = TRUE; + } + } + } + + lua_settop (L, err_idx - 1); + } + } if (ret) { /* Do not autolearn if we have this symbol already */ -- 2.39.5