source.dussan.org Git - rspamd.git/commitdiff
Add substages for classification
author Vsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 29 Dec 2015 15:16:17 +0000 (15:16 +0000)
committer Vsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 29 Dec 2015 15:16:17 +0000 (15:16 +0000)
src/libserver/task.c
src/libserver/task.h
src/libstat/stat_api.h
src/libstat/stat_process.c

index 758b9080701751af7fe3d3a2645a1ea8179f4de2..4f3a9d72c9bd3522b16a6bf6b0233ae9b1805a75 100644 (file)
@@ -440,7 +440,9 @@ rspamd_task_process (struct rspamd_task *task, guint stages)
                break;
 
        case RSPAMD_TASK_STAGE_CLASSIFIERS:
-               if (rspamd_stat_classify (task, task->cfg->lua_state, &stat_error) ==
+       case RSPAMD_TASK_STAGE_CLASSIFIERS_PRE:
+       case RSPAMD_TASK_STAGE_CLASSIFIERS_POST:
+               if (rspamd_stat_classify (task, task->cfg->lua_state, st, &stat_error) ==
                                RSPAMD_STAT_PROCESS_ERROR) {
                        msg_err_task ("classify error: %e", stat_error);
                        g_error_free (stat_error);
index 2047d701c2eee086415f1592383c8cd6381bf317..359b2f41f284010b0355e88989f867a549e1698b 100644 (file)
@@ -60,11 +60,13 @@ enum rspamd_task_stage {
        RSPAMD_TASK_STAGE_READ_MESSAGE = (1 << 2),
        RSPAMD_TASK_STAGE_PRE_FILTERS = (1 << 3),
        RSPAMD_TASK_STAGE_FILTERS = (1 << 4),
-       RSPAMD_TASK_STAGE_CLASSIFIERS = (1 << 5),
-       RSPAMD_TASK_STAGE_COMPOSITES = (1 << 6),
-       RSPAMD_TASK_STAGE_POST_FILTERS = (1 << 7),
-       RSPAMD_TASK_STAGE_DONE = (1 << 8),
-       RSPAMD_TASK_STAGE_REPLIED = (1 << 9)
+       RSPAMD_TASK_STAGE_CLASSIFIERS_PRE = (1 << 5),
+       RSPAMD_TASK_STAGE_CLASSIFIERS = (1 << 6),
+       RSPAMD_TASK_STAGE_CLASSIFIERS_POST = (1 << 7),
+       RSPAMD_TASK_STAGE_COMPOSITES = (1 << 8),
+       RSPAMD_TASK_STAGE_POST_FILTERS = (1 << 9),
+       RSPAMD_TASK_STAGE_DONE = (1 << 10),
+       RSPAMD_TASK_STAGE_REPLIED = (1 << 11)
 };
 
 #define RSPAMD_TASK_PROCESS_ALL (RSPAMD_TASK_STAGE_CONNECT | \
@@ -72,7 +74,9 @@ enum rspamd_task_stage {
                RSPAMD_TASK_STAGE_READ_MESSAGE | \
                RSPAMD_TASK_STAGE_PRE_FILTERS | \
                RSPAMD_TASK_STAGE_FILTERS | \
+               RSPAMD_TASK_STAGE_CLASSIFIERS_PRE | \
                RSPAMD_TASK_STAGE_CLASSIFIERS | \
+               RSPAMD_TASK_STAGE_CLASSIFIERS_POST | \
                RSPAMD_TASK_STAGE_COMPOSITES | \
                RSPAMD_TASK_STAGE_POST_FILTERS | \
                RSPAMD_TASK_STAGE_DONE)
@@ -154,6 +158,7 @@ struct rspamd_task {
 
        GList *messages;                                                                /**< list of messages that would be reported            */
        struct rspamd_re_runtime *re_rt;                                /**< regexp runtime                                                                     */
+       GList *cl_runtimes;                                                             /**< classifiers runtime                                                        */
        struct rspamd_config *cfg;                                              /**< pointer to config object                                           */
        GError *err;
        rspamd_mempool_t *task_pool;                                    /**< memory pool for task                                                       */
index ba5dc4a409b40cee8d39edfac46f1a7c0e22feca..a4a28a4bc9ff9727b272c8696fa97d4bad5d7fd6 100644 (file)
@@ -63,7 +63,7 @@ void rspamd_stat_close (void);
  * @return TRUE if task has been classified
  */
 rspamd_stat_result_t rspamd_stat_classify (struct rspamd_task *task,
-               lua_State *L, GError **err);
+               lua_State *L, guint stage, GError **err);
 
 
 /**
index 57482974d1cb0c457abba48739afeb5df4967d4b..89df10a07d2e06ef0be94b2261dba4942de4b319 100644 (file)
@@ -540,7 +540,8 @@ rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx,
 }
 
 rspamd_stat_result_t
-rspamd_stat_classify (struct rspamd_task *task, lua_State *L, GError **err)
+rspamd_stat_classify (struct rspamd_task *task, lua_State *L, guint stage,
+               GError **err)
 {
        struct rspamd_stat_ctx *st_ctx;
        struct rspamd_statfile_runtime *st_run;
@@ -552,63 +553,83 @@ rspamd_stat_classify (struct rspamd_task *task, lua_State *L, GError **err)
 
        st_ctx = rspamd_stat_get_ctx ();
        g_assert (st_ctx != NULL);
+       cl_runtimes = task->cl_runtimes;
+
+       if (stage == RSPAMD_TASK_STAGE_CLASSIFIERS_PRE) {
+               /* Initialize classifiers and statfiles runtime */
+               if (task->cl_runtimes == NULL) {
+                       if ((cl_runtimes = rspamd_stat_preprocess (st_ctx, task, L,
+                                       RSPAMD_CLASSIFY_OP, FALSE, NULL, err)) == NULL) {
+                               return RSPAMD_STAT_PROCESS_OK;
+                       }
 
-       /* Initialize classifiers and statfiles runtime */
-       if ((cl_runtimes = rspamd_stat_preprocess (st_ctx, task, L,
-                       RSPAMD_CLASSIFY_OP, FALSE, NULL, err)) == NULL) {
-               return RSPAMD_STAT_PROCESS_OK;
-       }
+                       task->cl_runtimes = cl_runtimes;
+               }
 
-       cur = cl_runtimes;
+               cur = cl_runtimes;
 
-       while (cur) {
-               cl_run = (struct rspamd_classifier_runtime *)cur->data;
-               cl_run->stage = RSPAMD_STAT_STAGE_PRE;
+               while (cur) {
+                       cl_run = (struct rspamd_classifier_runtime *) cur->data;
+                       cl_run->stage = RSPAMD_STAT_STAGE_PRE;
 
-               if (cl_run->cl) {
-                       cl_ctx = cl_run->cl->init_func (task->task_pool, cl_run->clcf);
+                       if (cl_run->cl) {
+                               cl_ctx = cl_run->cl->init_func (task->task_pool, cl_run->clcf);
 
-                       if (cl_ctx != NULL) {
-                               cl_run->cl->classify_func (cl_ctx, cl_run->tok->tokens,
-                                               cl_run, task);
+                               if (cl_ctx != NULL) {
+                                       cl_run->cl->classify_func (cl_ctx, cl_run->tok->tokens,
+                                                       cl_run, task);
+                               }
                        }
-               }
 
-               cur = g_list_next (cur);
+                       cur = g_list_next (cur);
+               }
        }
+       else if (stage == RSPAMD_TASK_STAGE_CLASSIFIERS) {
+               cur = cl_runtimes;
+               while (cur) {
+                       cl_run = (struct rspamd_classifier_runtime *) cur->data;
+                       cl_run->stage = RSPAMD_STAT_STAGE_POST;
 
-       /* XXX: backend runtime post-processing */
-       /* Post-processing */
-       cur = cl_runtimes;
-       while (cur) {
-               cl_run = (struct rspamd_classifier_runtime *)cur->data;
-               cl_run->stage = RSPAMD_STAT_STAGE_POST;
+                       if (cl_run->skipped) {
+                               cur = g_list_next (cur);
+                               continue;
+                       }
+
+                       if (cl_run->cl) {
+                               if (cl_ctx != NULL) {
+                                       if (cl_run->cl->classify_func (cl_ctx, cl_run->tok->tokens,
+                                                       cl_run, task)) {
+                                               ret = RSPAMD_STAT_PROCESS_OK;
+                                       }
+                               }
+                       }
 
-               if (cl_run->skipped) {
                        cur = g_list_next (cur);
-                       continue;
                }
+       }
+       else if (stage == RSPAMD_TASK_STAGE_CLASSIFIERS_POST) {
+               cur = cl_runtimes;
+               while (cur) {
+                       cl_run = (struct rspamd_classifier_runtime *) cur->data;
+                       cl_run->stage = RSPAMD_STAT_STAGE_POST;
 
-               if (cl_run->cl) {
-                       if (cl_ctx != NULL) {
-                               if (cl_run->cl->classify_func (cl_ctx, cl_run->tok->tokens,
-                                               cl_run, task)) {
-                                       ret = RSPAMD_STAT_PROCESS_OK;
-                               }
+                       if (cl_run->skipped) {
+                               cur = g_list_next (cur);
+                               continue;
                        }
-               }
 
-               curst = cl_run->st_runtime;
+                       curst = cl_run->st_runtime;
 
-               while (curst) {
-                       st_run = curst->data;
-                       cl_run->backend->finalize_process (task,
-                                       st_run->backend_runtime,
-                                       cl_run->backend->ctx);
-                       curst = g_list_next (curst);
-               }
+                       while (curst) {
+                               st_run = curst->data;
+                               cl_run->backend->finalize_process (task,
+                                               st_run->backend_runtime,
+                                               cl_run->backend->ctx);
+                               curst = g_list_next (curst);
+                       }
 
-               cur = g_list_next (cur);
+                       cur = g_list_next (cur);
+               }
        }
 
        return ret;