aboutsummaryrefslogtreecommitdiffstats
path: root/src/libstat
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2021-09-01 14:26:32 +0100
committerVsevolod Stakhov <vsevolod@highsecure.ru>2021-09-01 14:26:32 +0100
commit718238fd33017f346d1e84fe757481f9f147eb90 (patch)
tree34140ab35d6b9709d3c8ff45c8c1a7501ee44dd9 /src/libstat
parent6b80e5120a9edeebee4e266fc17c81e2a5ddaf40 (diff)
downloadrspamd-718238fd33017f346d1e84fe757481f9f147eb90.tar.gz
rspamd-718238fd33017f346d1e84fe757481f9f147eb90.zip
[Rework] Rework learn and add classify condition
Diffstat (limited to 'src/libstat')
-rw-r--r--src/libstat/stat_process.c180
1 files changed, 100 insertions, 80 deletions
diff --git a/src/libstat/stat_process.c b/src/libstat/stat_process.c
index 8ac4e499e..4e856b563 100644
--- a/src/libstat/stat_process.c
+++ b/src/libstat/stat_process.c
@@ -190,9 +190,75 @@ rspamd_stat_process_tokenize (struct rspamd_stat_ctx *st_ctx,
b32_hout, g_free);
}
+static gboolean
+rspamd_stat_classifier_is_skipped (struct rspamd_task *task,
+ struct rspamd_classifier *cl, gboolean is_learn, gboolean is_spam)
+{
+ GList *cur = is_learn ? cl->cfg->learn_conditions : cl->cfg->classify_conditions;
+ lua_State *L = task->cfg->lua_state;
+ gboolean ret = FALSE;
+
+ while (cur) {
+ gint cb_ref = GPOINTER_TO_INT (cur->data);
+ gint old_top = lua_gettop (L);
+
+ lua_rawgeti (L, LUA_REGISTRYINDEX, cb_ref);
+ /* Push task and two booleans: is_spam and is_unlearn */
+ struct rspamd_task **ptask = lua_newuserdata (L, sizeof (*ptask));
+ *ptask = task;
+ rspamd_lua_setclass (L, "rspamd{task}", -1);
+
+ if (is_learn) {
+ lua_pushboolean(L, is_spam);
+ lua_pushboolean(L,
+ task->flags & RSPAMD_TASK_FLAG_UNLEARN ? true : false);
+ }
+
+ if (lua_pcall (L, 3, LUA_MULTRET, 0) != 0) {
+ msg_err_task ("call to %s failed: %s",
+ "condition callback",
+ lua_tostring (L, -1));
+ }
+ else {
+ if (lua_isboolean (L, 1)) {
+ if (!lua_toboolean (L, 1)) {
+ ret = TRUE;
+ }
+ }
+
+ if (lua_isstring (L, 2)) {
+ if (ret) {
+ msg_notice_task ("%s condition for classifier %s returned: %s; skip classifier",
+ is_learn ? "learn" : "classify", cl->cfg->name,
+ lua_tostring(L, 2));
+ }
+ else {
+ msg_info_task ("%s condition for classifier %s returned: %s",
+ is_learn ? "learn" : "classify", cl->cfg->name,
+ lua_tostring(L, 2));
+ }
+ }
+ else if (ret) {
+ msg_notice_task("%s condition for classifier %s returned false; skip classifier",
+ is_learn ? "learn" : "classify", cl->cfg->name);
+ }
+
+ if (ret) {
+ lua_settop (L, old_top);
+ break;
+ }
+ }
+
+ lua_settop (L, old_top);
+ cur = g_list_next (cur);
+ }
+
+ return ret;
+}
+
static void
rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx,
- struct rspamd_task *task, gboolean learn)
+ struct rspamd_task *task, gboolean is_learn, gboolean is_spam)
{
guint i;
struct rspamd_statfile *st;
@@ -207,12 +273,39 @@ rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx,
rspamd_mempool_add_destructor (task->task_pool,
rspamd_ptr_array_free_hard, task->stat_runtimes);
+ /* Temporary set all stat_runtimes to some max size to distinguish from NULL */
+ for (i = 0; i < st_ctx->statfiles->len; i ++) {
+ g_ptr_array_index (task->stat_runtimes, i) = GSIZE_TO_POINTER(G_MAXSIZE);
+ }
+
+ for (i = 0; i < st_ctx->classifiers->len; i++) {
+ struct rspamd_classifier *cl = g_ptr_array_index (st_ctx->classifiers, i);
+ gboolean skip_classifier = FALSE;
+
+ if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) {
+ skip_classifier = TRUE;
+ }
+ else {
+ if (rspamd_stat_classifier_is_skipped (task, cl, is_learn , is_spam)) {
+ skip_classifier = TRUE;
+ }
+ }
+
+ if (skip_classifier) {
+ /* Set NULL for all statfiles indexed by id */
+ for (int j = 0; j < cl->statfiles_ids->len; j++) {
+ int id = g_array_index (cl->statfiles_ids, gint, j);
+ g_ptr_array_index (task->stat_runtimes, id) = NULL;
+ }
+ }
+ }
+
for (i = 0; i < st_ctx->statfiles->len; i ++) {
st = g_ptr_array_index (st_ctx->statfiles, i);
g_assert (st != NULL);
- if (st->classifier->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) {
- g_ptr_array_index (task->stat_runtimes, i) = NULL;
+ if (g_ptr_array_index (task->stat_runtimes, i) == NULL) {
+ /* The whole classifier is skipped */
continue;
}
@@ -224,7 +317,7 @@ rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx,
continue;
}
- bk_run = st->backend->runtime (task, st->stcf, learn, st->bkcf);
+ bk_run = st->backend->runtime (task, st->stcf, is_learn, st->bkcf);
if (bk_run == NULL) {
msg_err_task ("cannot init backend %s for statfile %s",
@@ -249,11 +342,6 @@ rspamd_stat_backends_process (struct rspamd_stat_ctx *st_ctx,
for (i = 0; i < st_ctx->statfiles->len; i++) {
st = g_ptr_array_index (st_ctx->statfiles, i);
cl = st->classifier;
-
- if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) {
- continue;
- }
-
bk_run = g_ptr_array_index (task->stat_runtimes, i);
if (bk_run != NULL) {
@@ -302,10 +390,6 @@ rspamd_stat_classifiers_process (struct rspamd_stat_ctx *st_ctx,
st = g_ptr_array_index (st_ctx->statfiles, i);
cl = st->classifier;
- if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) {
- continue;
- }
-
bk_run = g_ptr_array_index (task->stat_runtimes, i);
g_assert (st != NULL);
@@ -332,10 +416,6 @@ rspamd_stat_classifiers_process (struct rspamd_stat_ctx *st_ctx,
/* Do not process classifiers on backend failures */
for (j = 0; j < cl->statfiles_ids->len; j++) {
- if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) {
- continue;
- }
-
id = g_array_index (cl->statfiles_ids, gint, j);
bk_run = g_ptr_array_index (task->stat_runtimes, id);
st = g_ptr_array_index (st_ctx->statfiles, id);
@@ -406,7 +486,7 @@ rspamd_stat_classify (struct rspamd_task *task, lua_State *L, guint stage,
if (stage == RSPAMD_TASK_STAGE_CLASSIFIERS_PRE) {
/* Preprocess tokens */
- rspamd_stat_preprocess (st_ctx, task, FALSE);
+ rspamd_stat_preprocess (st_ctx, task, FALSE, FALSE);
}
else if (stage == RSPAMD_TASK_STAGE_CLASSIFIERS) {
/* Process backends */
@@ -490,13 +570,7 @@ rspamd_stat_classifiers_learn (struct rspamd_stat_ctx *st_ctx,
{
struct rspamd_classifier *cl, *sel = NULL;
guint i;
- gboolean learned = FALSE, too_small = FALSE, too_large = FALSE,
- conditionally_skipped = FALSE;
- lua_State *L;
- struct rspamd_task **ptask;
- GList *cur;
- gint cb_ref;
- gchar *cond_str = NULL;
+ gboolean learned = FALSE, too_small = FALSE, too_large = FALSE;
if ((task->flags & RSPAMD_TASK_FLAG_ALREADY_LEARNED) && err != NULL &&
*err == NULL) {
@@ -544,52 +618,6 @@ rspamd_stat_classifiers_learn (struct rspamd_stat_ctx *st_ctx,
continue;
}
- /* Check all conditions for this classifier */
- cur = cl->cfg->learn_conditions;
- L = task->cfg->lua_state;
-
- while (cur) {
- cb_ref = GPOINTER_TO_INT (cur->data);
-
- gint old_top = lua_gettop (L);
- lua_rawgeti (L, LUA_REGISTRYINDEX, cb_ref);
- /* Push task and two booleans: is_spam and is_unlearn */
- ptask = lua_newuserdata (L, sizeof (*ptask));
- *ptask = task;
- rspamd_lua_setclass (L, "rspamd{task}", -1);
- lua_pushboolean (L, spam);
- lua_pushboolean (L,
- task->flags & RSPAMD_TASK_FLAG_UNLEARN ? true : false);
-
- if (lua_pcall (L, 3, LUA_MULTRET, 0) != 0) {
- msg_err_task ("call to %s failed: %s",
- "condition callback",
- lua_tostring (L, -1));
- }
- else {
- if (lua_isboolean (L, 1)) {
- if (!lua_toboolean (L, 1)) {
- conditionally_skipped = TRUE;
- /* Also check for error string if needed */
- if (lua_isstring (L, 2)) {
- cond_str = rspamd_mempool_strdup (task->task_pool,
- lua_tostring (L, 2));
- }
-
- lua_settop (L, old_top);
- break;
- }
- }
- }
-
- lua_settop (L, old_top);
- cur = g_list_next (cur);
- }
-
- if (conditionally_skipped) {
- break;
- }
-
if (cl->subrs->learn_spam_func (cl, task->tokens, task, spam,
task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) {
learned = TRUE;
@@ -627,14 +655,6 @@ rspamd_stat_classifiers_learn (struct rspamd_stat_ctx *st_ctx,
task->tokens->len,
sel->cfg->min_tokens);
}
- else if (conditionally_skipped) {
- g_set_error (err, rspamd_stat_quark (), 204,
- "<%s> is skipped for %s classifier: "
- "%s",
- MESSAGE_FIELD (task, message_id),
- sel->cfg->name,
- cond_str ? cond_str : "unknown reason");
- }
}
return learned;
@@ -828,7 +848,7 @@ rspamd_stat_learn (struct rspamd_task *task,
if (stage == RSPAMD_TASK_STAGE_LEARN_PRE) {
/* Process classifiers */
- rspamd_stat_preprocess (st_ctx, task, TRUE);
+ rspamd_stat_preprocess (st_ctx, task, TRUE, spam);
if (!rspamd_stat_cache_check (st_ctx, task, classifier, spam, err)) {
return RSPAMD_STAT_PROCESS_ERROR;