From 84bba58ac10018f5ad541331ad3b84ea1b1119b6 Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Tue, 29 Dec 2015 15:16:17 +0000 Subject: [PATCH] Add substages for classification --- src/libserver/task.c | 4 +- src/libserver/task.h | 15 ++++-- src/libstat/stat_api.h | 2 +- src/libstat/stat_process.c | 103 ++++++++++++++++++++++--------------- 4 files changed, 76 insertions(+), 48 deletions(-) diff --git a/src/libserver/task.c b/src/libserver/task.c index 758b90807..4f3a9d72c 100644 --- a/src/libserver/task.c +++ b/src/libserver/task.c @@ -440,7 +440,9 @@ rspamd_task_process (struct rspamd_task *task, guint stages) break; case RSPAMD_TASK_STAGE_CLASSIFIERS: - if (rspamd_stat_classify (task, task->cfg->lua_state, &stat_error) == + case RSPAMD_TASK_STAGE_CLASSIFIERS_PRE: + case RSPAMD_TASK_STAGE_CLASSIFIERS_POST: + if (rspamd_stat_classify (task, task->cfg->lua_state, st, &stat_error) == RSPAMD_STAT_PROCESS_ERROR) { msg_err_task ("classify error: %e", stat_error); g_error_free (stat_error); diff --git a/src/libserver/task.h b/src/libserver/task.h index 2047d701c..359b2f41f 100644 --- a/src/libserver/task.h +++ b/src/libserver/task.h @@ -60,11 +60,13 @@ enum rspamd_task_stage { RSPAMD_TASK_STAGE_READ_MESSAGE = (1 << 2), RSPAMD_TASK_STAGE_PRE_FILTERS = (1 << 3), RSPAMD_TASK_STAGE_FILTERS = (1 << 4), - RSPAMD_TASK_STAGE_CLASSIFIERS = (1 << 5), - RSPAMD_TASK_STAGE_COMPOSITES = (1 << 6), - RSPAMD_TASK_STAGE_POST_FILTERS = (1 << 7), - RSPAMD_TASK_STAGE_DONE = (1 << 8), - RSPAMD_TASK_STAGE_REPLIED = (1 << 9) + RSPAMD_TASK_STAGE_CLASSIFIERS_PRE = (1 << 5), + RSPAMD_TASK_STAGE_CLASSIFIERS = (1 << 6), + RSPAMD_TASK_STAGE_CLASSIFIERS_POST = (1 << 7), + RSPAMD_TASK_STAGE_COMPOSITES = (1 << 8), + RSPAMD_TASK_STAGE_POST_FILTERS = (1 << 9), + RSPAMD_TASK_STAGE_DONE = (1 << 10), + RSPAMD_TASK_STAGE_REPLIED = (1 << 11) }; #define RSPAMD_TASK_PROCESS_ALL (RSPAMD_TASK_STAGE_CONNECT | \ @@ -72,7 +74,9 @@ enum rspamd_task_stage { RSPAMD_TASK_STAGE_READ_MESSAGE | \ RSPAMD_TASK_STAGE_PRE_FILTERS | \ RSPAMD_TASK_STAGE_FILTERS | \ + RSPAMD_TASK_STAGE_CLASSIFIERS_PRE | \ RSPAMD_TASK_STAGE_CLASSIFIERS | \ + RSPAMD_TASK_STAGE_CLASSIFIERS_POST | \ RSPAMD_TASK_STAGE_COMPOSITES | \ RSPAMD_TASK_STAGE_POST_FILTERS | \ RSPAMD_TASK_STAGE_DONE) @@ -154,6 +158,7 @@ struct rspamd_task { GList *messages; /**< list of messages that would be reported */ struct rspamd_re_runtime *re_rt; /**< regexp runtime */ + GList *cl_runtimes; /**< classifiers runtime */ struct rspamd_config *cfg; /**< pointer to config object */ GError *err; rspamd_mempool_t *task_pool; /**< memory pool for task */ diff --git a/src/libstat/stat_api.h b/src/libstat/stat_api.h index ba5dc4a40..a4a28a4bc 100644 --- a/src/libstat/stat_api.h +++ b/src/libstat/stat_api.h @@ -63,7 +63,7 @@ void rspamd_stat_close (void); * @return TRUE if task has been classified */ rspamd_stat_result_t rspamd_stat_classify (struct rspamd_task *task, - lua_State *L, GError **err); + lua_State *L, guint stage, GError **err); /** diff --git a/src/libstat/stat_process.c b/src/libstat/stat_process.c index 57482974d..89df10a07 100644 --- a/src/libstat/stat_process.c +++ b/src/libstat/stat_process.c @@ -540,7 +540,8 @@ rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx, } rspamd_stat_result_t -rspamd_stat_classify (struct rspamd_task *task, lua_State *L, GError **err) +rspamd_stat_classify (struct rspamd_task *task, lua_State *L, guint stage, + GError **err) { struct rspamd_stat_ctx *st_ctx; struct rspamd_statfile_runtime *st_run; @@ -552,63 +553,83 @@ rspamd_stat_classify (struct rspamd_task *task, lua_State *L, GError **err) st_ctx = rspamd_stat_get_ctx (); g_assert (st_ctx != NULL); + cl_runtimes = task->cl_runtimes; + + if (stage == RSPAMD_TASK_STAGE_CLASSIFIERS_PRE) { + /* Initialize classifiers and statfiles runtime */ + if (task->cl_runtimes == NULL) { + if ((cl_runtimes = rspamd_stat_preprocess (st_ctx, task, L, + RSPAMD_CLASSIFY_OP, FALSE, NULL, err)) == NULL) { + return RSPAMD_STAT_PROCESS_OK; + } - /* Initialize classifiers and statfiles runtime */ - if ((cl_runtimes = rspamd_stat_preprocess (st_ctx, task, L, - RSPAMD_CLASSIFY_OP, FALSE, NULL, err)) == NULL) { - return RSPAMD_STAT_PROCESS_OK; - } + task->cl_runtimes = cl_runtimes; + } - cur = cl_runtimes; + cur = cl_runtimes; - while (cur) { - cl_run = (struct rspamd_classifier_runtime *)cur->data; - cl_run->stage = RSPAMD_STAT_STAGE_PRE; + while (cur) { + cl_run = (struct rspamd_classifier_runtime *) cur->data; + cl_run->stage = RSPAMD_STAT_STAGE_PRE; - if (cl_run->cl) { - cl_ctx = cl_run->cl->init_func (task->task_pool, cl_run->clcf); + if (cl_run->cl) { + cl_ctx = cl_run->cl->init_func (task->task_pool, cl_run->clcf); - if (cl_ctx != NULL) { - cl_run->cl->classify_func (cl_ctx, cl_run->tok->tokens, - cl_run, task); + if (cl_ctx != NULL) { + cl_run->cl->classify_func (cl_ctx, cl_run->tok->tokens, + cl_run, task); + } } - } - cur = g_list_next (cur); + cur = g_list_next (cur); + } } + else if (stage == RSPAMD_TASK_STAGE_CLASSIFIERS) { + cur = cl_runtimes; + while (cur) { + cl_run = (struct rspamd_classifier_runtime *) cur->data; + cl_run->stage = RSPAMD_STAT_STAGE_POST; - /* XXX: backend runtime post-processing */ - /* Post-processing */ - cur = cl_runtimes; - while (cur) { - cl_run = (struct rspamd_classifier_runtime *)cur->data; - cl_run->stage = RSPAMD_STAT_STAGE_POST; + if (cl_run->skipped) { + cur = g_list_next (cur); + continue; + } + + if (cl_run->cl) { + if (cl_ctx != NULL) { + if (cl_run->cl->classify_func (cl_ctx, cl_run->tok->tokens, + cl_run, task)) { + ret = RSPAMD_STAT_PROCESS_OK; + } + } + } - if (cl_run->skipped) { cur = g_list_next (cur); - continue; } + } + else if (stage == RSPAMD_TASK_STAGE_CLASSIFIERS_POST) { + cur = cl_runtimes; + while (cur) { + cl_run = (struct rspamd_classifier_runtime *) cur->data; + cl_run->stage = RSPAMD_STAT_STAGE_POST; - if (cl_run->cl) { - if (cl_ctx != NULL) { - if (cl_run->cl->classify_func (cl_ctx, cl_run->tok->tokens, - cl_run, task)) { - ret = RSPAMD_STAT_PROCESS_OK; - } + if (cl_run->skipped) { + cur = g_list_next (cur); + continue; } - } - curst = cl_run->st_runtime; + curst = cl_run->st_runtime; - while (curst) { - st_run = curst->data; - cl_run->backend->finalize_process (task, - st_run->backend_runtime, - cl_run->backend->ctx); - curst = g_list_next (curst); - } + while (curst) { + st_run = curst->data; + cl_run->backend->finalize_process (task, + st_run->backend_runtime, + cl_run->backend->ctx); + curst = g_list_next (curst); + } - cur = g_list_next (cur); + cur = g_list_next (cur); + } } return ret; -- 2.39.5