aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2019-07-24 15:03:29 +0100
committerVsevolod Stakhov <vsevolod@highsecure.ru>2019-07-24 15:03:29 +0100
commite1fadcc80b5f6a3d566224b0ed1a74d7a9dbc9ed (patch)
tree718bee26b22d6582e63a7bffc8d50104b866c131
parent701a711049ee01373bc3862cc441fc3065c8dbc2 (diff)
downloadrspamd-e1fadcc80b5f6a3d566224b0ed1a74d7a9dbc9ed.tar.gz
rspamd-e1fadcc80b5f6a3d566224b0ed1a74d7a9dbc9ed.zip
[Feature] Improve autolearning
-rw-r--r--conf/statistic.conf2
-rw-r--r--lualib/lua_bayes_learn.lua67
-rw-r--r--src/libserver/mempool_vars_internal.h2
-rw-r--r--src/libstat/backends/redis_backend.c26
-rw-r--r--src/libstat/stat_internal.h1
-rw-r--r--src/libstat/stat_process.c76
6 files changed, 172 insertions, 2 deletions
diff --git a/conf/statistic.conf b/conf/statistic.conf
index 1e78c73cd..396564a23 100644
--- a/conf/statistic.conf
+++ b/conf/statistic.conf
@@ -41,7 +41,7 @@ classifier "bayes" {
symbol = "BAYES_SPAM";
spam = true;
}
- learn_condition = "return require("lua_bayes_learn").can_learn"
+ learn_condition = 'return require("lua_bayes_learn").can_learn';
.include(try=true; priority=1) "$LOCAL_CONFDIR/local.d/classifier-bayes.conf"
.include(try=true; priority=10) "$LOCAL_CONFDIR/override.d/classifier-bayes.conf"
diff --git a/lualib/lua_bayes_learn.lua b/lualib/lua_bayes_learn.lua
index 7df52a2ef..5a46265e7 100644
--- a/lualib/lua_bayes_learn.lua
+++ b/lualib/lua_bayes_learn.lua
@@ -16,6 +16,10 @@ limitations under the License.
-- This file contains functions to simplify bayes classifier auto-learning
+local lua_util = require "lua_util"
+
+local N = "lua_bayes"
+
local exports = {}
exports.can_learn = function(task, is_spam, is_unlearn)
@@ -46,4 +50,67 @@ exports.can_learn = function(task, is_spam, is_unlearn)
return true
end
+exports.autolearn = function(task, conf)
+ -- We have autolearn config so let's figure out what is requested
+ local verdict,score = lua_util.get_task_verdict(task)
+ local learn_spam,learn_ham = false, false
+
+ if verdict == 'passthrough' then
+ -- No need to autolearn
+ lua_util.debugm(N, task, 'no need to autolearn - verdict: %s',
+ verdict)
+ return
+ end
+
+ if conf.spam_threshold and conf.ham_threshold then
+ if verdict == 'spam' then
+ if conf.spam_threshold and score >= conf.spam_threshold then
+ lua_util.debugm(N, task, 'can autolearn spam: score %s >= %s',
+ score, conf.spam_threshold)
+ learn_spam = true
+ end
+ elseif verdict == 'ham' then
+ if conf.ham_threshold and score <= conf.ham_threshold then
+ lua_util.debugm(N, task, 'can autolearn ham: score %s <= %s',
+ score, conf.ham_threshold)
+ learn_ham = true
+ end
+ end
+ end
+
+ if conf.check_balance then
+ -- Check balance of learns
+ local spam_learns = task:get_mempool():get_variable('spam_learns', 'int64') or 0
+ local ham_learns = task:get_mempool():get_variable('ham_learns', 'int64') or 0
+
+ local min_balance = 0.9
+ if conf.min_balance then min_balance = conf.min_balance end
+
+ if spam_learns > 0 or ham_learns > 0 then
+ local max_ratio = 1.0 / min_balance
+ local spam_learns_ratio = spam_learns / (ham_learns + 1)
+ if spam_learns_ratio > max_ratio and learn_spam then
+ lua_util.debugm(N, task,
+ 'skip learning spam, balance is not satisfied: %s < %s; %s spam learns; %s ham learns',
+ spam_learns_ratio, min_balance, spam_learns, ham_learns)
+ learn_spam = false
+ end
+
+ local ham_learns_ratio = ham_learns / (spam_learns + 1)
+ if ham_learns_ratio > max_ratio and learn_ham then
+ lua_util.debugm(N, task,
+ 'skip learning ham, balance is not satisfied: %s < %s; %s spam learns; %s ham learns',
+ ham_learns_ratio, min_balance, spam_learns, ham_learns)
+ learn_ham = false
+ end
+ end
+ end
+
+ if learn_spam then
+ return 'spam'
+ elseif learn_ham then
+ return 'ham'
+ end
+end
+
return exports \ No newline at end of file
diff --git a/src/libserver/mempool_vars_internal.h b/src/libserver/mempool_vars_internal.h
index c062d44d4..576635a9b 100644
--- a/src/libserver/mempool_vars_internal.h
+++ b/src/libserver/mempool_vars_internal.h
@@ -38,5 +38,7 @@
#define RSPAMD_MEMPOOL_ARC_SIGN_SELECTOR "arc_selector"
#define RSPAMD_MEMPOOL_STAT_SIGNATURE "stat_signature"
#define RSPAMD_MEMPOOL_FUZZY_RESULT "fuzzy_hashes"
+#define RSPAMD_MEMPOOL_SPAM_LEARNS "spam_learns"
+#define RSPAMD_MEMPOOL_HAM_LEARNS "ham_learns"
#endif
diff --git a/src/libstat/backends/redis_backend.c b/src/libstat/backends/redis_backend.c
index 7263b3c16..9ac6fb445 100644
--- a/src/libstat/backends/redis_backend.c
+++ b/src/libstat/backends/redis_backend.c
@@ -1230,6 +1230,32 @@ rspamd_redis_connected (redisAsyncContext *c, gpointer r, gpointer priv)
rt->redis_object_expanded, rt->learned);
rspamd_upstream_ok (rt->selected);
+ /* Save learn count in mempool variable */
+ gint64 *learns_cnt;
+ const gchar *var_name;
+
+ if (rt->stcf->is_spam) {
+ var_name = RSPAMD_MEMPOOL_SPAM_LEARNS;
+ }
+ else {
+ var_name = RSPAMD_MEMPOOL_HAM_LEARNS;
+ }
+
+ learns_cnt = rspamd_mempool_get_variable (task->task_pool,
+ var_name);
+
+ if (learns_cnt) {
+ (*learns_cnt) += rt->learned;
+ }
+ else {
+ learns_cnt = rspamd_mempool_alloc (task->task_pool,
+ sizeof (*learns_cnt));
+ *learns_cnt = rt->learned;
+ rspamd_mempool_set_variable (task->task_pool,
+ var_name,
+ learns_cnt, NULL);
+ }
+
if (rt->learned >= rt->stcf->clcf->min_learns && rt->learned > 0) {
rspamd_fstring_t *query = rspamd_redis_tokens_to_query (
task,
diff --git a/src/libstat/stat_internal.h b/src/libstat/stat_internal.h
index 967a3c4d6..5e2578177 100644
--- a/src/libstat/stat_internal.h
+++ b/src/libstat/stat_internal.h
@@ -43,6 +43,7 @@ struct rspamd_classifier {
gpointer cachecf;
gulong spam_learns;
gulong ham_learns;
+ gint autolearn_cbref;
struct rspamd_classifier_config *cfg;
struct rspamd_stat_classifier *subrs;
gpointer specific;
diff --git a/src/libstat/stat_process.c b/src/libstat/stat_process.c
index 034e1a5be..d720a77ab 100644
--- a/src/libstat/stat_process.c
+++ b/src/libstat/stat_process.c
@@ -906,6 +906,19 @@ rspamd_stat_has_classifier_symbols (struct rspamd_task *task,
return FALSE;
}
+struct cl_cbref_dtor_data {
+ lua_State *L;
+ gint ref_idx;
+};
+
+static void
+rspamd_stat_cbref_dtor (void *d)
+{
+ struct cl_cbref_dtor_data *data = (struct cl_cbref_dtor_data *)d;
+
+ luaL_unref (data->L, LUA_REGISTRYINDEX, data->ref_idx);
+}
+
gboolean
rspamd_stat_check_autolearn (struct rspamd_task *task)
{
@@ -925,6 +938,8 @@ rspamd_stat_check_autolearn (struct rspamd_task *task)
st_ctx = rspamd_stat_get_ctx ();
g_assert (st_ctx != NULL);
+ L = task->cfg->lua_state;
+
for (i = 0; i < st_ctx->classifiers->len; i ++) {
cl = g_ptr_array_index (st_ctx->classifiers, i);
ret = FALSE;
@@ -933,6 +948,7 @@ rspamd_stat_check_autolearn (struct rspamd_task *task)
obj = ucl_object_lookup (cl->cfg->opts, "autolearn");
if (ucl_object_type (obj) == UCL_BOOLEAN) {
+ /* Legacy true/false */
if (ucl_object_toboolean (obj)) {
/*
* Default learning algorithm:
@@ -956,6 +972,7 @@ rspamd_stat_check_autolearn (struct rspamd_task *task)
}
}
else if (ucl_object_type (obj) == UCL_ARRAY && obj->len == 2) {
+ /* Legacy thresholds */
/*
* We have an array of 2 elements, treat it as a
* ham_score, spam_score
@@ -994,8 +1011,8 @@ rspamd_stat_check_autolearn (struct rspamd_task *task)
}
}
else if (ucl_object_type (obj) == UCL_STRING) {
+ /* Legacy sript */
lua_script = ucl_object_tostring (obj);
- L = task->cfg->lua_state;
if (luaL_dostring (L, lua_script) != 0) {
msg_err_task ("cannot execute lua script for autolearn "
@@ -1018,6 +1035,7 @@ rspamd_stat_check_autolearn (struct rspamd_task *task)
else {
lua_ret = lua_tostring (L, -1);
+ /* We can have immediate results */
if (lua_ret) {
if (strcmp (lua_ret, "ham") == 0) {
task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
@@ -1041,6 +1059,62 @@ rspamd_stat_check_autolearn (struct rspamd_task *task)
}
}
}
+ else if (ucl_object_type (obj) == UCL_OBJECT) {
+ /* Try to find autolearn callback */
+ if (cl->autolearn_cbref == 0) {
+ /* We don't have preprocessed cb id, so try to get it */
+ if (!rspamd_lua_require_function (L, "lua_bayes_learn",
+ "autolearn")) {
+ msg_err_task ("cannot get autolearn library from "
+ "`lua_bayes_learn`");
+ }
+ else {
+ struct cl_cbref_dtor_data *dtor_data;
+
+ dtor_data = (struct cl_cbref_dtor_data *)
+ rspamd_mempool_alloc (task->cfg->cfg_pool,
+ sizeof (*dtor_data));
+ cl->autolearn_cbref = luaL_ref (L, LUA_REGISTRYINDEX);
+ dtor_data->L = L;
+ dtor_data->ref_idx = cl->autolearn_cbref;
+ rspamd_mempool_add_destructor (task->cfg->cfg_pool,
+ rspamd_stat_cbref_dtor, dtor_data);
+ }
+ }
+
+ if (cl->autolearn_cbref != -1) {
+ lua_pushcfunction (L, &rspamd_lua_traceback);
+ err_idx = lua_gettop (L);
+ lua_rawgeti (L, LUA_REGISTRYINDEX, cl->autolearn_cbref);
+
+ ptask = lua_newuserdata (L, sizeof (struct rspamd_task *));
+ *ptask = task;
+ rspamd_lua_setclass (L, "rspamd{task}", -1);
+ /* Push the whole object as well */
+ ucl_object_push_lua (L, obj, true);
+
+ if (lua_pcall (L, 2, 1, err_idx) != 0) {
+ msg_err_task ("call to autolearn script failed: "
+ "%s", lua_tostring (L, -1));
+ }
+ else {
+ lua_ret = lua_tostring (L, -1);
+
+ if (lua_ret) {
+ if (strcmp (lua_ret, "ham") == 0) {
+ task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
+ ret = TRUE;
+ }
+ else if (strcmp (lua_ret, "spam") == 0) {
+ task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM;
+ ret = TRUE;
+ }
+ }
+ }
+
+ lua_settop (L, err_idx - 1);
+ }
+ }
if (ret) {
/* Do not autolearn if we have this symbol already */