]> source.dussan.org Git - rspamd.git/commitdiff
[Feature] Improve autolearning
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 24 Jul 2019 14:03:29 +0000 (15:03 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 24 Jul 2019 14:03:29 +0000 (15:03 +0100)
conf/statistic.conf
lualib/lua_bayes_learn.lua
src/libserver/mempool_vars_internal.h
src/libstat/backends/redis_backend.c
src/libstat/stat_internal.h
src/libstat/stat_process.c

index 1e78c73cd17775115f1d742cb946d85660f77497..396564a2349c98f636c16cdfd7dfe18b8dab244d 100644 (file)
@@ -41,7 +41,7 @@ classifier "bayes" {
     symbol = "BAYES_SPAM";
     spam = true;
   }
-  learn_condition = "return require("lua_bayes_learn").can_learn"
+  learn_condition = 'return require("lua_bayes_learn").can_learn';
 
   .include(try=true; priority=1) "$LOCAL_CONFDIR/local.d/classifier-bayes.conf"
   .include(try=true; priority=10) "$LOCAL_CONFDIR/override.d/classifier-bayes.conf"
index 7df52a2ef90063c0fa5cf2b59a7a72aac58fb32e..5a46265e7a5b60ea82a3c0c0098adf70032a8a17 100644 (file)
@@ -16,6 +16,10 @@ limitations under the License.
 
 -- This file contains functions to simplify bayes classifier auto-learning
 
+local lua_util = require "lua_util"
+
+local N = "lua_bayes"
+
 local exports = {}
 
 exports.can_learn = function(task, is_spam, is_unlearn)
@@ -46,4 +50,67 @@ exports.can_learn = function(task, is_spam, is_unlearn)
   return true
 end
 
+exports.autolearn = function(task, conf)
+  -- We have autolearn config so let's figure out what is requested
+  local verdict,score = lua_util.get_task_verdict(task)
+  local learn_spam,learn_ham = false, false
+
+  if verdict == 'passthrough' then
+    -- No need to autolearn
+    lua_util.debugm(N, task, 'no need to autolearn - verdict: %s',
+        verdict)
+    return
+  end
+
+  if conf.spam_threshold and conf.ham_threshold then
+    if verdict == 'spam' then
+      if conf.spam_threshold and score >= conf.spam_threshold then
+        lua_util.debugm(N, task, 'can autolearn spam: score %s >= %s',
+            score, conf.spam_threshold)
+        learn_spam = true
+      end
+    elseif verdict == 'ham' then
+      if conf.ham_threshold and score <= conf.ham_threshold then
+        lua_util.debugm(N, task, 'can autolearn ham: score %s <= %s',
+            score, conf.ham_threshold)
+        learn_ham = true
+      end
+    end
+  end
+
+  if conf.check_balance then
+    -- Check balance of learns
+    local spam_learns = task:get_mempool():get_variable('spam_learns', 'int64') or 0
+    local ham_learns = task:get_mempool():get_variable('ham_learns', 'int64') or 0
+
+    local min_balance = 0.9
+    if conf.min_balance then min_balance = conf.min_balance end
+
+    if spam_learns > 0 or ham_learns > 0 then
+      local max_ratio = 1.0 / min_balance
+      local spam_learns_ratio = spam_learns / (ham_learns + 1)
+      if  spam_learns_ratio > max_ratio and learn_spam then
+        lua_util.debugm(N, task,
+            'skip learning spam, balance is not satisfied: %s < %s; %s spam learns; %s ham learns',
+            spam_learns_ratio, min_balance, spam_learns, ham_learns)
+        learn_spam = false
+      end
+
+      local ham_learns_ratio = ham_learns / (spam_learns + 1)
+      if  ham_learns_ratio > max_ratio and learn_ham then
+        lua_util.debugm(N, task,
+            'skip learning ham, balance is not satisfied: %s < %s; %s spam learns; %s ham learns',
+            ham_learns_ratio, min_balance, spam_learns, ham_learns)
+        learn_ham = false
+      end
+    end
+  end
+
+  if learn_spam then
+    return 'spam'
+  elseif learn_ham then
+    return 'ham'
+  end
+end
+
 return exports
\ No newline at end of file
index c062d44d475ce1c3bc5ef6232c24c3800e215e32..576635a9b44c885f3814c653793dc4d80cd5ccc0 100644 (file)
@@ -38,5 +38,7 @@
 #define RSPAMD_MEMPOOL_ARC_SIGN_SELECTOR "arc_selector"
 #define RSPAMD_MEMPOOL_STAT_SIGNATURE "stat_signature"
 #define RSPAMD_MEMPOOL_FUZZY_RESULT "fuzzy_hashes"
+#define RSPAMD_MEMPOOL_SPAM_LEARNS "spam_learns"
+#define RSPAMD_MEMPOOL_HAM_LEARNS "ham_learns"
 
 #endif
index 7263b3c1687485895ce97c5646c97740d4de24cb..9ac6fb4456152b083e3c9a096149b9d357b99f28 100644 (file)
@@ -1230,6 +1230,32 @@ rspamd_redis_connected (redisAsyncContext *c, gpointer r, gpointer priv)
                                        rt->redis_object_expanded, rt->learned);
                        rspamd_upstream_ok (rt->selected);
 
+                       /* Save learn count in mempool variable */
+                       gint64 *learns_cnt;
+                       const gchar *var_name;
+
+                       if (rt->stcf->is_spam) {
+                               var_name = RSPAMD_MEMPOOL_SPAM_LEARNS;
+                       }
+                       else {
+                               var_name = RSPAMD_MEMPOOL_HAM_LEARNS;
+                       }
+
+                       learns_cnt = rspamd_mempool_get_variable (task->task_pool,
+                                       var_name);
+
+                       if (learns_cnt) {
+                               (*learns_cnt) += rt->learned;
+                       }
+                       else {
+                               learns_cnt = rspamd_mempool_alloc (task->task_pool,
+                                               sizeof (*learns_cnt));
+                               *learns_cnt = rt->learned;
+                               rspamd_mempool_set_variable (task->task_pool,
+                                               var_name,
+                                               learns_cnt, NULL);
+                       }
+
                        if (rt->learned >= rt->stcf->clcf->min_learns && rt->learned > 0) {
                                rspamd_fstring_t *query = rspamd_redis_tokens_to_query (
                                                task,
index 967a3c4d66da0de80000a86b0e56b4f6d4565730..5e2578177786146a7692a47d43df22aa5735a115 100644 (file)
@@ -43,6 +43,7 @@ struct rspamd_classifier {
        gpointer cachecf;
        gulong spam_learns;
        gulong ham_learns;
+       gint autolearn_cbref;
        struct rspamd_classifier_config *cfg;
        struct rspamd_stat_classifier *subrs;
        gpointer specific;
index 034e1a5be228f0cfc52b5470805caafcf1b6c120..d720a77ab3bb9565aeb7032c2eaaf0897a80e487 100644 (file)
@@ -906,6 +906,19 @@ rspamd_stat_has_classifier_symbols (struct rspamd_task *task,
        return FALSE;
 }
 
+struct cl_cbref_dtor_data {
+       lua_State *L;
+       gint ref_idx;
+};
+
+static void
+rspamd_stat_cbref_dtor (void *d)
+{
+       struct cl_cbref_dtor_data *data = (struct cl_cbref_dtor_data *)d;
+
+       luaL_unref (data->L, LUA_REGISTRYINDEX, data->ref_idx);
+}
+
 gboolean
 rspamd_stat_check_autolearn (struct rspamd_task *task)
 {
@@ -925,6 +938,8 @@ rspamd_stat_check_autolearn (struct rspamd_task *task)
        st_ctx = rspamd_stat_get_ctx ();
        g_assert (st_ctx != NULL);
 
+       L = task->cfg->lua_state;
+
        for (i = 0; i < st_ctx->classifiers->len; i ++) {
                cl = g_ptr_array_index (st_ctx->classifiers, i);
                ret = FALSE;
@@ -933,6 +948,7 @@ rspamd_stat_check_autolearn (struct rspamd_task *task)
                        obj = ucl_object_lookup (cl->cfg->opts, "autolearn");
 
                        if (ucl_object_type (obj) == UCL_BOOLEAN) {
+                               /* Legacy true/false */
                                if (ucl_object_toboolean (obj)) {
                                        /*
                                         * Default learning algorithm:
@@ -956,6 +972,7 @@ rspamd_stat_check_autolearn (struct rspamd_task *task)
                                }
                        }
                        else if (ucl_object_type (obj) == UCL_ARRAY && obj->len == 2) {
+                               /* Legacy thresholds */
                                /*
                                 * We have an array of 2 elements, treat it as a
                                 * ham_score, spam_score
@@ -994,8 +1011,8 @@ rspamd_stat_check_autolearn (struct rspamd_task *task)
                                }
                        }
                        else if (ucl_object_type (obj) == UCL_STRING) {
+                               /* Legacy sript */
                                lua_script = ucl_object_tostring (obj);
-                               L = task->cfg->lua_state;
 
                                if (luaL_dostring (L, lua_script) != 0) {
                                        msg_err_task ("cannot execute lua script for autolearn "
@@ -1018,6 +1035,7 @@ rspamd_stat_check_autolearn (struct rspamd_task *task)
                                                else {
                                                        lua_ret = lua_tostring (L, -1);
 
+                                                       /* We can have immediate results */
                                                        if (lua_ret) {
                                                                if (strcmp (lua_ret, "ham") == 0) {
                                                                        task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
@@ -1041,6 +1059,62 @@ rspamd_stat_check_autolearn (struct rspamd_task *task)
                                        }
                                }
                        }
+                       else if (ucl_object_type (obj) == UCL_OBJECT) {
+                               /* Try to find autolearn callback */
+                               if (cl->autolearn_cbref == 0) {
+                                       /* We don't have preprocessed cb id, so try to get it */
+                                       if (!rspamd_lua_require_function (L, "lua_bayes_learn",
+                                                       "autolearn")) {
+                                               msg_err_task ("cannot get autolearn library from "
+                                                                         "`lua_bayes_learn`");
+                                       }
+                                       else {
+                                               struct cl_cbref_dtor_data *dtor_data;
+
+                                               dtor_data = (struct cl_cbref_dtor_data *)
+                                                               rspamd_mempool_alloc (task->cfg->cfg_pool,
+                                                                       sizeof (*dtor_data));
+                                               cl->autolearn_cbref = luaL_ref (L, LUA_REGISTRYINDEX);
+                                               dtor_data->L = L;
+                                               dtor_data->ref_idx = cl->autolearn_cbref;
+                                               rspamd_mempool_add_destructor (task->cfg->cfg_pool,
+                                                               rspamd_stat_cbref_dtor, dtor_data);
+                                       }
+                               }
+
+                               if (cl->autolearn_cbref != -1) {
+                                       lua_pushcfunction (L, &rspamd_lua_traceback);
+                                       err_idx = lua_gettop (L);
+                                       lua_rawgeti (L, LUA_REGISTRYINDEX, cl->autolearn_cbref);
+
+                                       ptask = lua_newuserdata (L, sizeof (struct rspamd_task *));
+                                       *ptask = task;
+                                       rspamd_lua_setclass (L, "rspamd{task}", -1);
+                                       /* Push the whole object as well */
+                                       ucl_object_push_lua (L, obj, true);
+
+                                       if (lua_pcall (L, 2, 1, err_idx) != 0) {
+                                               msg_err_task ("call to autolearn script failed: "
+                                                                         "%s", lua_tostring (L, -1));
+                                       }
+                                       else {
+                                               lua_ret = lua_tostring (L, -1);
+
+                                               if (lua_ret) {
+                                                       if (strcmp (lua_ret, "ham") == 0) {
+                                                               task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
+                                                               ret = TRUE;
+                                                       }
+                                                       else if (strcmp (lua_ret, "spam") == 0) {
+                                                               task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM;
+                                                               ret = TRUE;
+                                                       }
+                                               }
+                                       }
+
+                                       lua_settop (L, err_idx - 1);
+                               }
+                       }
 
                        if (ret) {
                                /* Do not autolearn if we have this symbol already */