From 718238fd33017f346d1e84fe757481f9f147eb90 Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Wed, 1 Sep 2021 14:26:32 +0100 Subject: [PATCH] [Rework] Rework learn and add classify condition --- src/libserver/cfg_file.h | 1 + src/libserver/cfg_rcl.c | 30 ++++++- src/libstat/stat_process.c | 180 ++++++++++++++++++++----------------- src/lua/lua_common.c | 16 ++-- src/lua/lua_common.h | 2 +- 5 files changed, 140 insertions(+), 89 deletions(-) diff --git a/src/libserver/cfg_file.h b/src/libserver/cfg_file.h index 4d865e273..745f0fb22 100644 --- a/src/libserver/cfg_file.h +++ b/src/libserver/cfg_file.h @@ -192,6 +192,7 @@ struct rspamd_classifier_config { const gchar *backend; /**< name of statfile's backend */ ucl_object_t *opts; /**< other options */ GList *learn_conditions; /**< list of learn condition callbacks */ + GList *classify_conditions; /**< list of classify condition callbacks */ gchar *name; /**< unique name of classifier */ guint32 min_tokens; /**< minimal number of tokens to process classifier */ guint32 max_tokens; /**< maximum number of tokens */ diff --git a/src/libserver/cfg_rcl.c b/src/libserver/cfg_rcl.c index 717b16bea..e3c69c343 100644 --- a/src/libserver/cfg_rcl.c +++ b/src/libserver/cfg_rcl.c @@ -1299,7 +1299,7 @@ rspamd_rcl_classifier_handler (rspamd_mempool_t *pool, ccf->tokenizer = tkcf; /* Handle lua conditions */ - val = ucl_object_lookup_any (obj, "condition", "learn_condition", NULL); + val = ucl_object_lookup_any (obj, "learn_condition", NULL); if (val) { LL_FOREACH (val, cur) { @@ -1310,7 +1310,7 @@ rspamd_rcl_classifier_handler (rspamd_mempool_t *pool, lua_script = ucl_object_tolstring(cur, &slen); ref_idx = rspamd_lua_function_ref_from_str(L, - lua_script, slen, err); + lua_script, slen, "learn_condition", err); if (ref_idx == LUA_NOREF) { return FALSE; @@ -1325,6 +1325,32 @@ rspamd_rcl_classifier_handler (rspamd_mempool_t *pool, } } + val = ucl_object_lookup_any (obj, "classify_condition", NULL); + + if (val) { + LL_FOREACH 
(val, cur) { + if (ucl_object_type(cur) == UCL_STRING) { + const gchar *lua_script; + gsize slen; + gint ref_idx; + + lua_script = ucl_object_tolstring(cur, &slen); + ref_idx = rspamd_lua_function_ref_from_str(L, + lua_script, slen, "classify_condition", err); + + if (ref_idx == LUA_NOREF) { + return FALSE; + } + + rspamd_lua_add_ref_dtor (L, cfg->cfg_pool, ref_idx); + ccf->classify_conditions = rspamd_mempool_glist_append( + cfg->cfg_pool, + ccf->classify_conditions, + GINT_TO_POINTER (ref_idx)); + } + } + } + ccf->opts = (ucl_object_t *)obj; cfg->classifiers = g_list_prepend (cfg->classifiers, ccf); diff --git a/src/libstat/stat_process.c b/src/libstat/stat_process.c index 8ac4e499e..4e856b563 100644 --- a/src/libstat/stat_process.c +++ b/src/libstat/stat_process.c @@ -190,9 +190,75 @@ rspamd_stat_process_tokenize (struct rspamd_stat_ctx *st_ctx, b32_hout, g_free); } +static gboolean +rspamd_stat_classifier_is_skipped (struct rspamd_task *task, + struct rspamd_classifier *cl, gboolean is_learn, gboolean is_spam) +{ + GList *cur = is_learn ? cl->cfg->learn_conditions : cl->cfg->classify_conditions; + lua_State *L = task->cfg->lua_state; + gboolean ret = FALSE; + + while (cur) { + gint cb_ref = GPOINTER_TO_INT (cur->data); + gint old_top = lua_gettop (L); + + lua_rawgeti (L, LUA_REGISTRYINDEX, cb_ref); + /* Push task and two booleans: is_spam and is_unlearn */ + struct rspamd_task **ptask = lua_newuserdata (L, sizeof (*ptask)); + *ptask = task; + rspamd_lua_setclass (L, "rspamd{task}", -1); + + if (is_learn) { + lua_pushboolean(L, is_spam); + lua_pushboolean(L, + task->flags & RSPAMD_TASK_FLAG_UNLEARN ? 
true : false); + } + + if (lua_pcall (L, is_learn ? 3 : 1, LUA_MULTRET, 0) != 0) { /* nargs must match the pushes above: task only for classify */ + msg_err_task ("call to %s failed: %s", + "condition callback", + lua_tostring (L, -1)); + } + else { /* results start right above the saved stack top */ + if (lua_isboolean (L, old_top + 1)) { + if (!lua_toboolean (L, old_top + 1)) { + ret = TRUE; + } + } + + if (lua_isstring (L, old_top + 2)) { + if (ret) { + msg_notice_task ("%s condition for classifier %s returned: %s; skip classifier", + is_learn ? "learn" : "classify", cl->cfg->name, + lua_tostring(L, old_top + 2)); + } + else { + msg_info_task ("%s condition for classifier %s returned: %s", + is_learn ? "learn" : "classify", cl->cfg->name, + lua_tostring(L, old_top + 2)); + } + } + else if (ret) { + msg_notice_task("%s condition for classifier %s returned false; skip classifier", + is_learn ? "learn" : "classify", cl->cfg->name); + } + + if (ret) { + lua_settop (L, old_top); + break; + } + } + + lua_settop (L, old_top); + cur = g_list_next (cur); + } + + return ret; +} + static void rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx, - struct rspamd_task *task, gboolean learn) + struct rspamd_task *task, gboolean is_learn, gboolean is_spam) { guint i; struct rspamd_statfile *st; @@ -207,12 +273,39 @@ rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx, rspamd_mempool_add_destructor (task->task_pool, rspamd_ptr_array_free_hard, task->stat_runtimes); + /* Temporary set all stat_runtimes to some max size to distinguish from NULL */ + for (i = 0; i < st_ctx->statfiles->len; i ++) { + g_ptr_array_index (task->stat_runtimes, i) = GSIZE_TO_POINTER(G_MAXSIZE); + } + + for (i = 0; i < st_ctx->classifiers->len; i++) { + struct rspamd_classifier *cl = g_ptr_array_index (st_ctx->classifiers, i); + gboolean skip_classifier = FALSE; + + if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) { + skip_classifier = TRUE; + } + else { + if (rspamd_stat_classifier_is_skipped (task, cl, is_learn , is_spam)) { + skip_classifier = TRUE; + } + } + + if (skip_classifier) { + /* Set NULL for all statfiles indexed by id */ + for (int j = 0; j <
cl->statfiles_ids->len; j++) { + int id = g_array_index (cl->statfiles_ids, gint, j); + g_ptr_array_index (task->stat_runtimes, id) = NULL; + } + } + } + for (i = 0; i < st_ctx->statfiles->len; i ++) { st = g_ptr_array_index (st_ctx->statfiles, i); g_assert (st != NULL); - if (st->classifier->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) { - g_ptr_array_index (task->stat_runtimes, i) = NULL; + if (g_ptr_array_index (task->stat_runtimes, i) == NULL) { + /* The whole classifier is skipped */ continue; } @@ -224,7 +317,7 @@ rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx, continue; } - bk_run = st->backend->runtime (task, st->stcf, learn, st->bkcf); + bk_run = st->backend->runtime (task, st->stcf, is_learn, st->bkcf); if (bk_run == NULL) { msg_err_task ("cannot init backend %s for statfile %s", @@ -249,11 +342,6 @@ rspamd_stat_backends_process (struct rspamd_stat_ctx *st_ctx, for (i = 0; i < st_ctx->statfiles->len; i++) { st = g_ptr_array_index (st_ctx->statfiles, i); cl = st->classifier; - - if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) { - continue; - } - bk_run = g_ptr_array_index (task->stat_runtimes, i); if (bk_run != NULL) { @@ -302,10 +390,6 @@ rspamd_stat_classifiers_process (struct rspamd_stat_ctx *st_ctx, st = g_ptr_array_index (st_ctx->statfiles, i); cl = st->classifier; - if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) { - continue; - } - bk_run = g_ptr_array_index (task->stat_runtimes, i); g_assert (st != NULL); @@ -332,10 +416,6 @@ rspamd_stat_classifiers_process (struct rspamd_stat_ctx *st_ctx, /* Do not process classifiers on backend failures */ for (j = 0; j < cl->statfiles_ids->len; j++) { - if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) { - continue; - } - id = g_array_index (cl->statfiles_ids, gint, j); bk_run = g_ptr_array_index (task->stat_runtimes, id); st = g_ptr_array_index (st_ctx->statfiles, id); @@ -406,7 +486,7 @@ rspamd_stat_classify (struct rspamd_task *task, lua_State *L, guint stage, if (stage == 
RSPAMD_TASK_STAGE_CLASSIFIERS_PRE) { /* Preprocess tokens */ - rspamd_stat_preprocess (st_ctx, task, FALSE); + rspamd_stat_preprocess (st_ctx, task, FALSE, FALSE); } else if (stage == RSPAMD_TASK_STAGE_CLASSIFIERS) { /* Process backends */ @@ -490,13 +570,7 @@ rspamd_stat_classifiers_learn (struct rspamd_stat_ctx *st_ctx, { struct rspamd_classifier *cl, *sel = NULL; guint i; - gboolean learned = FALSE, too_small = FALSE, too_large = FALSE, - conditionally_skipped = FALSE; - lua_State *L; - struct rspamd_task **ptask; - GList *cur; - gint cb_ref; - gchar *cond_str = NULL; + gboolean learned = FALSE, too_small = FALSE, too_large = FALSE; if ((task->flags & RSPAMD_TASK_FLAG_ALREADY_LEARNED) && err != NULL && *err == NULL) { @@ -544,52 +618,6 @@ rspamd_stat_classifiers_learn (struct rspamd_stat_ctx *st_ctx, continue; } - /* Check all conditions for this classifier */ - cur = cl->cfg->learn_conditions; - L = task->cfg->lua_state; - - while (cur) { - cb_ref = GPOINTER_TO_INT (cur->data); - - gint old_top = lua_gettop (L); - lua_rawgeti (L, LUA_REGISTRYINDEX, cb_ref); - /* Push task and two booleans: is_spam and is_unlearn */ - ptask = lua_newuserdata (L, sizeof (*ptask)); - *ptask = task; - rspamd_lua_setclass (L, "rspamd{task}", -1); - lua_pushboolean (L, spam); - lua_pushboolean (L, - task->flags & RSPAMD_TASK_FLAG_UNLEARN ? 
true : false); - - if (lua_pcall (L, 3, LUA_MULTRET, 0) != 0) { - msg_err_task ("call to %s failed: %s", - "condition callback", - lua_tostring (L, -1)); - } - else { - if (lua_isboolean (L, 1)) { - if (!lua_toboolean (L, 1)) { - conditionally_skipped = TRUE; - /* Also check for error string if needed */ - if (lua_isstring (L, 2)) { - cond_str = rspamd_mempool_strdup (task->task_pool, - lua_tostring (L, 2)); - } - - lua_settop (L, old_top); - break; - } - } - } - - lua_settop (L, old_top); - cur = g_list_next (cur); - } - - if (conditionally_skipped) { - break; - } - if (cl->subrs->learn_spam_func (cl, task->tokens, task, spam, task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) { learned = TRUE; @@ -627,14 +655,6 @@ rspamd_stat_classifiers_learn (struct rspamd_stat_ctx *st_ctx, task->tokens->len, sel->cfg->min_tokens); } - else if (conditionally_skipped) { - g_set_error (err, rspamd_stat_quark (), 204, - "<%s> is skipped for %s classifier: " - "%s", - MESSAGE_FIELD (task, message_id), - sel->cfg->name, - cond_str ? 
cond_str : "unknown reason"); - } } return learned; @@ -828,7 +848,7 @@ rspamd_stat_learn (struct rspamd_task *task, if (stage == RSPAMD_TASK_STAGE_LEARN_PRE) { /* Process classifiers */ - rspamd_stat_preprocess (st_ctx, task, TRUE); + rspamd_stat_preprocess (st_ctx, task, TRUE, spam); if (!rspamd_stat_cache_check (st_ctx, task, classifier, spam, err)) { return RSPAMD_STAT_PROCESS_ERROR; diff --git a/src/lua/lua_common.c b/src/lua/lua_common.c index 5d874d507..ee29f9b9d 100644 --- a/src/lua/lua_common.c +++ b/src/lua/lua_common.c @@ -2294,7 +2294,7 @@ rspamd_lua_require_function (lua_State *L, const gchar *modname, gint rspamd_lua_function_ref_from_str (lua_State *L, const gchar *str, gsize slen, - GError **err) + const gchar *modname, GError **err) { gint err_idx, ref_idx; @@ -2302,11 +2302,12 @@ rspamd_lua_function_ref_from_str (lua_State *L, const gchar *str, gsize slen, err_idx = lua_gettop (L); /* Load file */ - if (luaL_loadbuffer (L, str, slen, "lua_embedded_str") != 0) { + if (luaL_loadbuffer (L, str, slen, modname) != 0) { g_set_error (err, lua_error_quark(), EINVAL, - "cannot load lua script: %s", + "%s: cannot load lua script: %s", + modname, lua_tostring (L, -1)); lua_settop (L, err_idx - 1); /* Error function */ @@ -2318,7 +2319,8 @@ rspamd_lua_function_ref_from_str (lua_State *L, const gchar *str, gsize slen, g_set_error (err, lua_error_quark(), EINVAL, - "cannot init lua script: %s", + "%s: cannot init lua script: %s", + modname, lua_tostring (L, -1)); lua_settop (L, err_idx - 1); @@ -2329,8 +2331,10 @@ rspamd_lua_function_ref_from_str (lua_State *L, const gchar *str, gsize slen, g_set_error (err, lua_error_quark(), EINVAL, - "cannot init lua script: " - "must return function"); + "%s: cannot init lua script: " + "must return function not %s", + modname, + lua_typename (L, lua_type (L, -1))); lua_settop (L, err_idx - 1); return LUA_NOREF; diff --git a/src/lua/lua_common.h b/src/lua/lua_common.h index b929ab864..10816d450 100644 --- 
a/src/lua/lua_common.h +++ b/src/lua/lua_common.h @@ -572,7 +572,7 @@ void rspamd_lua_add_ref_dtor (lua_State *L, rspamd_mempool_t *pool, * @return */ gint rspamd_lua_function_ref_from_str (lua_State *L, const gchar *str, gsize slen, - GError **err); + const gchar *modname, GError **err); /** * Tries to load some module using `require` and get some method from it -- 2.39.5