diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2021-09-01 14:26:32 +0100 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2021-09-01 14:26:32 +0100 |
commit | 718238fd33017f346d1e84fe757481f9f147eb90 (patch) | |
tree | 34140ab35d6b9709d3c8ff45c8c1a7501ee44dd9 /src/libstat/stat_process.c | |
parent | 6b80e5120a9edeebee4e266fc17c81e2a5ddaf40 (diff) | |
download | rspamd-718238fd33017f346d1e84fe757481f9f147eb90.tar.gz rspamd-718238fd33017f346d1e84fe757481f9f147eb90.zip |
[Rework] Rework learn and add classify condition
Diffstat (limited to 'src/libstat/stat_process.c')
-rw-r--r-- | src/libstat/stat_process.c | 180 |
1 files changed, 100 insertions, 80 deletions
diff --git a/src/libstat/stat_process.c b/src/libstat/stat_process.c index 8ac4e499e..4e856b563 100644 --- a/src/libstat/stat_process.c +++ b/src/libstat/stat_process.c @@ -190,9 +190,75 @@ rspamd_stat_process_tokenize (struct rspamd_stat_ctx *st_ctx, b32_hout, g_free); } +static gboolean +rspamd_stat_classifier_is_skipped (struct rspamd_task *task, + struct rspamd_classifier *cl, gboolean is_learn, gboolean is_spam) +{ + GList *cur = is_learn ? cl->cfg->learn_conditions : cl->cfg->classify_conditions; + lua_State *L = task->cfg->lua_state; + gboolean ret = FALSE; + + while (cur) { + gint cb_ref = GPOINTER_TO_INT (cur->data); + gint old_top = lua_gettop (L); + + lua_rawgeti (L, LUA_REGISTRYINDEX, cb_ref); + /* Push task and two booleans: is_spam and is_unlearn */ + struct rspamd_task **ptask = lua_newuserdata (L, sizeof (*ptask)); + *ptask = task; + rspamd_lua_setclass (L, "rspamd{task}", -1); + + if (is_learn) { + lua_pushboolean(L, is_spam); + lua_pushboolean(L, + task->flags & RSPAMD_TASK_FLAG_UNLEARN ? true : false); + } + + if (lua_pcall (L, 3, LUA_MULTRET, 0) != 0) { + msg_err_task ("call to %s failed: %s", + "condition callback", + lua_tostring (L, -1)); + } + else { + if (lua_isboolean (L, 1)) { + if (!lua_toboolean (L, 1)) { + ret = TRUE; + } + } + + if (lua_isstring (L, 2)) { + if (ret) { + msg_notice_task ("%s condition for classifier %s returned: %s; skip classifier", + is_learn ? "learn" : "classify", cl->cfg->name, + lua_tostring(L, 2)); + } + else { + msg_info_task ("%s condition for classifier %s returned: %s", + is_learn ? "learn" : "classify", cl->cfg->name, + lua_tostring(L, 2)); + } + } + else if (ret) { + msg_notice_task("%s condition for classifier %s returned false; skip classifier", + is_learn ? "learn" : "classify", cl->cfg->name); + } + + if (ret) { + lua_settop (L, old_top); + break; + } + } + + lua_settop (L, old_top); + cur = g_list_next (cur); + } + + return ret; +} + static void rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx, - struct rspamd_task *task, gboolean learn) + struct rspamd_task *task, gboolean is_learn, gboolean is_spam) { guint i; struct rspamd_statfile *st; @@ -207,12 +273,39 @@ rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx, rspamd_mempool_add_destructor (task->task_pool, rspamd_ptr_array_free_hard, task->stat_runtimes); + /* Temporary set all stat_runtimes to some max size to distinguish from NULL */ + for (i = 0; i < st_ctx->statfiles->len; i ++) { + g_ptr_array_index (task->stat_runtimes, i) = GSIZE_TO_POINTER(G_MAXSIZE); + } + + for (i = 0; i < st_ctx->classifiers->len; i++) { + struct rspamd_classifier *cl = g_ptr_array_index (st_ctx->classifiers, i); + gboolean skip_classifier = FALSE; + + if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) { + skip_classifier = TRUE; + } + else { + if (rspamd_stat_classifier_is_skipped (task, cl, is_learn , is_spam)) { + skip_classifier = TRUE; + } + } + + if (skip_classifier) { + /* Set NULL for all statfiles indexed by id */ + for (int j = 0; j < cl->statfiles_ids->len; j++) { + int id = g_array_index (cl->statfiles_ids, gint, j); + g_ptr_array_index (task->stat_runtimes, id) = NULL; + } + } + } + for (i = 0; i < st_ctx->statfiles->len; i ++) { st = g_ptr_array_index (st_ctx->statfiles, i); g_assert (st != NULL); - if (st->classifier->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) { - g_ptr_array_index (task->stat_runtimes, i) = NULL; + if (g_ptr_array_index (task->stat_runtimes, i) == NULL) { + /* The whole classifier is skipped */ continue; } @@ -224,7 +317,7 @@ rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx, continue; } - bk_run = st->backend->runtime (task, st->stcf, learn, st->bkcf); + bk_run = st->backend->runtime (task, st->stcf, is_learn, st->bkcf); if (bk_run == NULL) { msg_err_task ("cannot init backend %s for statfile %s", @@ -249,11 +342,6 @@ rspamd_stat_backends_process (struct rspamd_stat_ctx *st_ctx, for (i = 0; i < st_ctx->statfiles->len; i++) { st = g_ptr_array_index (st_ctx->statfiles, i); cl = st->classifier; - - if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) { - continue; - } - bk_run = g_ptr_array_index (task->stat_runtimes, i); if (bk_run != NULL) { @@ -302,10 +390,6 @@ rspamd_stat_classifiers_process (struct rspamd_stat_ctx *st_ctx, st = g_ptr_array_index (st_ctx->statfiles, i); cl = st->classifier; - if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) { - continue; - } - bk_run = g_ptr_array_index (task->stat_runtimes, i); g_assert (st != NULL); @@ -332,10 +416,6 @@ rspamd_stat_classifiers_process (struct rspamd_stat_ctx *st_ctx, /* Do not process classifiers on backend failures */ for (j = 0; j < cl->statfiles_ids->len; j++) { - if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) { - continue; - } - id = g_array_index (cl->statfiles_ids, gint, j); bk_run = g_ptr_array_index (task->stat_runtimes, id); st = g_ptr_array_index (st_ctx->statfiles, id); @@ -406,7 +486,7 @@ rspamd_stat_classify (struct rspamd_task *task, lua_State *L, guint stage, if (stage == RSPAMD_TASK_STAGE_CLASSIFIERS_PRE) { /* Preprocess tokens */ - rspamd_stat_preprocess (st_ctx, task, FALSE); + rspamd_stat_preprocess (st_ctx, task, FALSE, FALSE); } else if (stage == RSPAMD_TASK_STAGE_CLASSIFIERS) { /* Process backends */ @@ -490,13 +570,7 @@ rspamd_stat_classifiers_learn (struct rspamd_stat_ctx *st_ctx, { struct rspamd_classifier *cl, *sel = NULL; guint i; - gboolean learned = FALSE, too_small = FALSE, too_large = FALSE, - conditionally_skipped = FALSE; - lua_State *L; - struct rspamd_task **ptask; - GList *cur; - gint cb_ref; - gchar *cond_str = NULL; + gboolean learned = FALSE, too_small = FALSE, too_large = FALSE; if ((task->flags & RSPAMD_TASK_FLAG_ALREADY_LEARNED) && err != NULL && *err == NULL) { @@ -544,52 +618,6 @@ rspamd_stat_classifiers_learn (struct rspamd_stat_ctx *st_ctx, continue; } - /* Check all conditions for this classifier */ - cur = cl->cfg->learn_conditions; - L = task->cfg->lua_state; - - while (cur) { - cb_ref = GPOINTER_TO_INT (cur->data); - - gint old_top = lua_gettop (L); - lua_rawgeti (L, LUA_REGISTRYINDEX, cb_ref); - /* Push task and two booleans: is_spam and is_unlearn */ - ptask = lua_newuserdata (L, sizeof (*ptask)); - *ptask = task; - rspamd_lua_setclass (L, "rspamd{task}", -1); - lua_pushboolean (L, spam); - lua_pushboolean (L, - task->flags & RSPAMD_TASK_FLAG_UNLEARN ? true : false); - - if (lua_pcall (L, 3, LUA_MULTRET, 0) != 0) { - msg_err_task ("call to %s failed: %s", - "condition callback", - lua_tostring (L, -1)); - } - else { - if (lua_isboolean (L, 1)) { - if (!lua_toboolean (L, 1)) { - conditionally_skipped = TRUE; - /* Also check for error string if needed */ - if (lua_isstring (L, 2)) { - cond_str = rspamd_mempool_strdup (task->task_pool, - lua_tostring (L, 2)); - } - - lua_settop (L, old_top); - break; - } - } - } - - lua_settop (L, old_top); - cur = g_list_next (cur); - } - - if (conditionally_skipped) { - break; - } - if (cl->subrs->learn_spam_func (cl, task->tokens, task, spam, task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) { learned = TRUE; @@ -627,14 +655,6 @@ rspamd_stat_classifiers_learn (struct rspamd_stat_ctx *st_ctx, task->tokens->len, sel->cfg->min_tokens); } - else if (conditionally_skipped) { - g_set_error (err, rspamd_stat_quark (), 204, - "<%s> is skipped for %s classifier: " - "%s", - MESSAGE_FIELD (task, message_id), - sel->cfg->name, - cond_str ? cond_str : "unknown reason"); - } } return learned; @@ -828,7 +848,7 @@ rspamd_stat_learn (struct rspamd_task *task, if (stage == RSPAMD_TASK_STAGE_LEARN_PRE) { /* Process classifiers */ - rspamd_stat_preprocess (st_ctx, task, TRUE); + rspamd_stat_preprocess (st_ctx, task, TRUE, spam); if (!rspamd_stat_cache_check (st_ctx, task, classifier, spam, err)) { return RSPAMD_STAT_PROCESS_ERROR; |