]> source.dussan.org Git - rspamd.git/commitdiff
[Rework] Rework learn and add classify condition
author Vsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 1 Sep 2021 13:26:32 +0000 (14:26 +0100)
committer Vsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 1 Sep 2021 13:26:32 +0000 (14:26 +0100)
src/libserver/cfg_file.h
src/libserver/cfg_rcl.c
src/libstat/stat_process.c
src/lua/lua_common.c
src/lua/lua_common.h

index 4d865e2733eaa4c275377da5e08df4755a69deb4..745f0fb2208838c9ab16262f792c0276c6b42e9e 100644 (file)
@@ -192,6 +192,7 @@ struct rspamd_classifier_config {
        const gchar *backend;                           /**< name of statfile's backend                                                 */
        ucl_object_t *opts;                             /**< other options                                      */
        GList *learn_conditions;                        /**< list of learn condition callbacks                                  */
+       GList *classify_conditions;                     /**< list of classify condition callbacks                                       */
        gchar *name;                                    /**< unique name of classifier                                                  */
        guint32 min_tokens;                             /**< minimal number of tokens to process classifier     */
        guint32 max_tokens;                             /**< maximum number of tokens                                                   */
index 717b16bea99628153d242cb67664b69e6d149838..e3c69c3431f94c70c7862e122f9a35436c6af985 100644 (file)
@@ -1299,7 +1299,7 @@ rspamd_rcl_classifier_handler (rspamd_mempool_t *pool,
        ccf->tokenizer = tkcf;
 
        /* Handle lua conditions */
-       val = ucl_object_lookup_any (obj, "condition", "learn_condition", NULL);
+       val = ucl_object_lookup_any (obj, "learn_condition", NULL);
 
        if (val) {
                LL_FOREACH (val, cur) {
@@ -1310,7 +1310,7 @@ rspamd_rcl_classifier_handler (rspamd_mempool_t *pool,
 
                                lua_script = ucl_object_tolstring(cur, &slen);
                                ref_idx = rspamd_lua_function_ref_from_str(L,
-                                               lua_script, slen, err);
+                                               lua_script, slen, "learn_condition", err);
 
                                if (ref_idx == LUA_NOREF) {
                                        return FALSE;
@@ -1325,6 +1325,32 @@ rspamd_rcl_classifier_handler (rspamd_mempool_t *pool,
                }
        }
 
+       val = ucl_object_lookup_any (obj, "classify_condition", NULL);
+
+       if (val) {
+               LL_FOREACH (val, cur) {
+                       if (ucl_object_type(cur) == UCL_STRING) {
+                               const gchar *lua_script;
+                               gsize slen;
+                               gint ref_idx;
+
+                               lua_script = ucl_object_tolstring(cur, &slen);
+                               ref_idx = rspamd_lua_function_ref_from_str(L,
+                                               lua_script, slen, "classify_condition", err);
+
+                               if (ref_idx == LUA_NOREF) {
+                                       return FALSE;
+                               }
+
+                               rspamd_lua_add_ref_dtor (L, cfg->cfg_pool, ref_idx);
+                               ccf->classify_conditions = rspamd_mempool_glist_append(
+                                               cfg->cfg_pool,
+                                               ccf->classify_conditions,
+                                               GINT_TO_POINTER (ref_idx));
+                       }
+               }
+       }
+
        ccf->opts = (ucl_object_t *)obj;
        cfg->classifiers = g_list_prepend (cfg->classifiers, ccf);
 
index 8ac4e499ee8e5fa29a14d480db4e55b7fe70bf24..4e856b563194f27b2f00d3eb2b591908098bb97b 100644 (file)
@@ -190,9 +190,75 @@ rspamd_stat_process_tokenize (struct rspamd_stat_ctx *st_ctx,
                        b32_hout, g_free);
 }
 
+/**
+ * Run the Lua condition callbacks attached to a classifier and tell whether
+ * the classifier must be skipped for this task.
+ *
+ * Callbacks are taken from `learn_conditions` when `is_learn` is set and from
+ * `classify_conditions` otherwise. Learn conditions are called with
+ * (task, is_spam, is_unlearn); classify conditions are called with (task) only.
+ * A callback returning boolean `false` as its first result causes the skip and
+ * stops the iteration; a string second result is logged. A callback that
+ * fails is logged and does not cause a skip.
+ *
+ * @param task task being processed
+ * @param cl classifier whose conditions are evaluated
+ * @param is_learn TRUE to check learn conditions, FALSE for classify conditions
+ * @param is_spam learn class; only passed to learn conditions
+ * @return TRUE if some condition rejected the classifier, FALSE otherwise
+ */
+static gboolean
+rspamd_stat_classifier_is_skipped (struct rspamd_task *task,
+               struct rspamd_classifier *cl, gboolean is_learn, gboolean is_spam)
+{
+       GList *cur = is_learn ? cl->cfg->learn_conditions : cl->cfg->classify_conditions;
+       lua_State *L = task->cfg->lua_state;
+       gboolean ret = FALSE;
+
+       while (cur) {
+               gint cb_ref = GPOINTER_TO_INT (cur->data);
+               gint old_top = lua_gettop (L);
+               /* Must match the number of values pushed after the function */
+               gint nargs = 1;
+
+               lua_rawgeti (L, LUA_REGISTRYINDEX, cb_ref);
+               /* Push task; learn conditions also get is_spam and is_unlearn */
+               struct rspamd_task **ptask = lua_newuserdata (L, sizeof (*ptask));
+               *ptask = task;
+               rspamd_lua_setclass (L, "rspamd{task}", -1);
+
+               if (is_learn) {
+                       lua_pushboolean(L, is_spam);
+                       lua_pushboolean(L,
+                                       task->flags & RSPAMD_TASK_FLAG_UNLEARN ? true : false);
+                       nargs = 3;
+               }
+
+               if (lua_pcall (L, nargs, LUA_MULTRET, 0) != 0) {
+                       msg_err_task ("call to %s failed: %s",
+                                       "condition callback",
+                                       lua_tostring (L, -1));
+               }
+               else {
+                       /*
+                        * Results replace the function and its arguments, so they
+                        * start right above the saved top; do not assume that the
+                        * stack was empty on entry.
+                        */
+                       gint verdict_idx = old_top + 1;
+                       gint reason_idx = old_top + 2;
+
+                       if (lua_isboolean (L, verdict_idx)) {
+                               if (!lua_toboolean (L, verdict_idx)) {
+                                       ret = TRUE;
+                               }
+                       }
+
+                       if (lua_isstring (L, reason_idx)) {
+                               if (ret) {
+                                       msg_notice_task ("%s condition for classifier %s returned: %s; skip classifier",
+                                                       is_learn ? "learn" : "classify", cl->cfg->name,
+                                                       lua_tostring(L, reason_idx));
+                               }
+                               else {
+                                       msg_info_task ("%s condition for classifier %s returned: %s",
+                                                       is_learn ? "learn" : "classify", cl->cfg->name,
+                                                       lua_tostring(L, reason_idx));
+                               }
+                       }
+                       else if (ret) {
+                               msg_notice_task("%s condition for classifier %s returned false; skip classifier",
+                                               is_learn ? "learn" : "classify", cl->cfg->name);
+                       }
+               }
+
+               /* Drop results (or the error message) before going on */
+               lua_settop (L, old_top);
+
+               if (ret) {
+                       /* First rejecting condition wins */
+                       break;
+               }
+
+               cur = g_list_next (cur);
+       }
+
+       return ret;
+}
+
 static void
 rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx,
-               struct rspamd_task *task, gboolean learn)
+               struct rspamd_task *task, gboolean is_learn, gboolean is_spam)
 {
        guint i;
        struct rspamd_statfile *st;
@@ -207,12 +273,39 @@ rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx,
        rspamd_mempool_add_destructor (task->task_pool,
                        rspamd_ptr_array_free_hard, task->stat_runtimes);
 
+       /* Temporary set all stat_runtimes to some max size to distinguish from NULL */
+       for (i = 0; i < st_ctx->statfiles->len; i ++) {
+               g_ptr_array_index (task->stat_runtimes, i) = GSIZE_TO_POINTER(G_MAXSIZE);
+       }
+
+       for (i = 0; i < st_ctx->classifiers->len; i++) {
+               struct rspamd_classifier *cl = g_ptr_array_index (st_ctx->classifiers, i);
+               gboolean skip_classifier = FALSE;
+
+               if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) {
+                       skip_classifier = TRUE;
+               }
+               else {
+                       if (rspamd_stat_classifier_is_skipped (task, cl, is_learn , is_spam)) {
+                               skip_classifier = TRUE;
+                       }
+               }
+
+               if (skip_classifier) {
+                       /* Set NULL for all statfiles indexed by id */
+                       for (int j = 0; j < cl->statfiles_ids->len; j++) {
+                               int id = g_array_index (cl->statfiles_ids, gint, j);
+                               g_ptr_array_index (task->stat_runtimes, id) = NULL;
+                       }
+               }
+       }
+
        for (i = 0; i < st_ctx->statfiles->len; i ++) {
                st = g_ptr_array_index (st_ctx->statfiles, i);
                g_assert (st != NULL);
 
-               if (st->classifier->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) {
-                       g_ptr_array_index (task->stat_runtimes, i) = NULL;
+               if (g_ptr_array_index (task->stat_runtimes, i) == NULL) {
+                       /* The whole classifier is skipped */
                        continue;
                }
 
@@ -224,7 +317,7 @@ rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx,
                        continue;
                }
 
-               bk_run = st->backend->runtime (task, st->stcf, learn, st->bkcf);
+               bk_run = st->backend->runtime (task, st->stcf, is_learn, st->bkcf);
 
                if (bk_run == NULL) {
                        msg_err_task ("cannot init backend %s for statfile %s",
@@ -249,11 +342,6 @@ rspamd_stat_backends_process (struct rspamd_stat_ctx *st_ctx,
        for (i = 0; i < st_ctx->statfiles->len; i++) {
                st = g_ptr_array_index (st_ctx->statfiles, i);
                cl = st->classifier;
-
-               if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) {
-                       continue;
-               }
-
                bk_run = g_ptr_array_index (task->stat_runtimes, i);
 
                if (bk_run != NULL) {
@@ -302,10 +390,6 @@ rspamd_stat_classifiers_process (struct rspamd_stat_ctx *st_ctx,
                st = g_ptr_array_index (st_ctx->statfiles, i);
                cl = st->classifier;
 
-               if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) {
-                       continue;
-               }
-
                bk_run = g_ptr_array_index (task->stat_runtimes, i);
                g_assert (st != NULL);
 
@@ -332,10 +416,6 @@ rspamd_stat_classifiers_process (struct rspamd_stat_ctx *st_ctx,
 
                /* Do not process classifiers on backend failures */
                for (j = 0; j < cl->statfiles_ids->len; j++) {
-                       if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) {
-                               continue;
-                       }
-
                        id = g_array_index (cl->statfiles_ids, gint, j);
                        bk_run =  g_ptr_array_index (task->stat_runtimes, id);
                        st = g_ptr_array_index (st_ctx->statfiles, id);
@@ -406,7 +486,7 @@ rspamd_stat_classify (struct rspamd_task *task, lua_State *L, guint stage,
 
        if (stage == RSPAMD_TASK_STAGE_CLASSIFIERS_PRE) {
                /* Preprocess tokens */
-               rspamd_stat_preprocess (st_ctx, task, FALSE);
+               rspamd_stat_preprocess (st_ctx, task, FALSE, FALSE);
        }
        else if (stage == RSPAMD_TASK_STAGE_CLASSIFIERS) {
                /* Process backends */
@@ -490,13 +570,7 @@ rspamd_stat_classifiers_learn (struct rspamd_stat_ctx *st_ctx,
 {
        struct rspamd_classifier *cl, *sel = NULL;
        guint i;
-       gboolean learned = FALSE, too_small = FALSE, too_large = FALSE,
-                       conditionally_skipped = FALSE;
-       lua_State *L;
-       struct rspamd_task **ptask;
-       GList *cur;
-       gint cb_ref;
-       gchar *cond_str = NULL;
+       gboolean learned = FALSE, too_small = FALSE, too_large = FALSE;
 
        if ((task->flags & RSPAMD_TASK_FLAG_ALREADY_LEARNED) && err != NULL &&
                        *err == NULL) {
@@ -544,52 +618,6 @@ rspamd_stat_classifiers_learn (struct rspamd_stat_ctx *st_ctx,
                        continue;
                }
 
-               /* Check all conditions for this classifier */
-               cur = cl->cfg->learn_conditions;
-               L = task->cfg->lua_state;
-
-               while (cur) {
-                       cb_ref = GPOINTER_TO_INT (cur->data);
-
-                       gint old_top = lua_gettop (L);
-                       lua_rawgeti (L, LUA_REGISTRYINDEX, cb_ref);
-                       /* Push task and two booleans: is_spam and is_unlearn */
-                       ptask = lua_newuserdata (L, sizeof (*ptask));
-                       *ptask = task;
-                       rspamd_lua_setclass (L, "rspamd{task}", -1);
-                       lua_pushboolean (L, spam);
-                       lua_pushboolean (L,
-                                       task->flags & RSPAMD_TASK_FLAG_UNLEARN ? true : false);
-
-                       if (lua_pcall (L, 3, LUA_MULTRET, 0) != 0) {
-                               msg_err_task ("call to %s failed: %s",
-                                               "condition callback",
-                                               lua_tostring (L, -1));
-                       }
-                       else {
-                               if (lua_isboolean (L, 1)) {
-                                       if (!lua_toboolean (L, 1)) {
-                                               conditionally_skipped = TRUE;
-                                               /* Also check for error string if needed */
-                                               if (lua_isstring (L, 2)) {
-                                                       cond_str = rspamd_mempool_strdup (task->task_pool,
-                                                                       lua_tostring (L, 2));
-                                               }
-
-                                               lua_settop (L, old_top);
-                                               break;
-                                       }
-                               }
-                       }
-
-                       lua_settop (L, old_top);
-                       cur = g_list_next (cur);
-               }
-
-               if (conditionally_skipped) {
-                       break;
-               }
-
                if (cl->subrs->learn_spam_func (cl, task->tokens, task, spam,
                                task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) {
                        learned = TRUE;
@@ -627,14 +655,6 @@ rspamd_stat_classifiers_learn (struct rspamd_stat_ctx *st_ctx,
                                        task->tokens->len,
                                        sel->cfg->min_tokens);
                }
-               else if (conditionally_skipped) {
-                       g_set_error (err, rspamd_stat_quark (), 204,
-                                       "<%s> is skipped for %s classifier: "
-                                       "%s",
-                                       MESSAGE_FIELD (task, message_id),
-                                       sel->cfg->name,
-                                       cond_str ? cond_str : "unknown reason");
-               }
        }
 
        return learned;
@@ -828,7 +848,7 @@ rspamd_stat_learn (struct rspamd_task *task,
 
        if (stage == RSPAMD_TASK_STAGE_LEARN_PRE) {
                /* Process classifiers */
-               rspamd_stat_preprocess (st_ctx, task, TRUE);
+               rspamd_stat_preprocess (st_ctx, task, TRUE, spam);
 
                if (!rspamd_stat_cache_check (st_ctx, task, classifier, spam, err)) {
                        return RSPAMD_STAT_PROCESS_ERROR;
index 5d874d507ea7c769d16f9be6e21c3462c93d80a8..ee29f9b9d1fabf160f6927722218463f6a37e3a7 100644 (file)
@@ -2294,7 +2294,7 @@ rspamd_lua_require_function (lua_State *L, const gchar *modname,
 
 gint
 rspamd_lua_function_ref_from_str (lua_State *L, const gchar *str, gsize slen,
-                                                                 GError **err)
+                                                                 const gchar *modname, GError **err)
 {
        gint err_idx, ref_idx;
 
@@ -2302,11 +2302,12 @@ rspamd_lua_function_ref_from_str (lua_State *L, const gchar *str, gsize slen,
        err_idx = lua_gettop (L);
 
        /* Load file */
-       if (luaL_loadbuffer (L, str, slen, "lua_embedded_str") != 0) {
+       if (luaL_loadbuffer (L, str, slen, modname) != 0) {
                g_set_error (err,
                                lua_error_quark(),
                                EINVAL,
-                               "cannot load lua script: %s",
+                               "%s: cannot load lua script: %s",
+                               modname,
                                lua_tostring (L, -1));
                lua_settop (L, err_idx - 1); /* Error function */
 
@@ -2318,7 +2319,8 @@ rspamd_lua_function_ref_from_str (lua_State *L, const gchar *str, gsize slen,
                g_set_error (err,
                                lua_error_quark(),
                                EINVAL,
-                               "cannot init lua script: %s",
+                               "%s: cannot init lua script: %s",
+                               modname,
                                lua_tostring (L, -1));
                lua_settop (L, err_idx - 1);
 
@@ -2329,8 +2331,10 @@ rspamd_lua_function_ref_from_str (lua_State *L, const gchar *str, gsize slen,
                g_set_error (err,
                                lua_error_quark(),
                                EINVAL,
-                               "cannot init lua script: "
-                               "must return function");
+                               "%s: cannot init lua script: "
+                               "must return function not %s",
+                               modname,
+                               lua_typename (L, lua_type (L, -1)));
                lua_settop (L, err_idx - 1);
 
                return LUA_NOREF;
index b929ab86421fb371799751002827d228f1261399..10816d450d196e5c9dc8e2e6b2e86fd2ea3dceb4 100644 (file)
@@ -572,7 +572,7 @@ void rspamd_lua_add_ref_dtor (lua_State *L, rspamd_mempool_t *pool,
  * @return
  */
 gint rspamd_lua_function_ref_from_str (lua_State *L, const gchar *str, gsize slen,
-                                                                          GError **err);
+                                                                          const gchar *modname, GError **err);
 
 /**
 * Tries to load some module using `require` and get some method from it