]> source.dussan.org Git - rspamd.git/commitdiff
Add learning implementation.
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 6 Jan 2016 14:24:07 +0000 (14:24 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 6 Jan 2016 14:24:07 +0000 (14:24 +0000)
src/libserver/task.c
src/libserver/task.h
src/libstat/classifiers/bayes.c
src/libstat/classifiers/classifiers.h
src/libstat/stat_api.h
src/libstat/stat_process.c

index 290101023427b070b3e494cb606094271a62b0c2..579cc3461f4c244aafc09a662e3f41482736fcb4 100644 (file)
@@ -610,11 +610,7 @@ rspamd_learn_task_spam (struct rspamd_task *task,
        const gchar *classifier,
        GError **err)
 {
-       return rspamd_stat_learn (task,
-                       is_spam,
-                       task->cfg->lua_state,
-                       classifier,
-                       err);
+       return FALSE;
 }
 
 static gboolean
@@ -999,7 +995,8 @@ rspamd_task_write_log (struct rspamd_task *task)
 
        g_assert (task != NULL);
 
-       if (task->cfg->log_format == NULL || task->flags & RSPAMD_TASK_FLAG_NO_LOG) {
+       if (task->cfg->log_format == NULL ||
+                       (task->flags & RSPAMD_TASK_FLAG_NO_LOG)) {
                return;
        }
 
index ed18d99d0795348b532f5310fb6448c0ade4eec0..901067ba4c1d158a6286e63cd149653d40b1aec5 100644 (file)
@@ -65,8 +65,11 @@ enum rspamd_task_stage {
        RSPAMD_TASK_STAGE_CLASSIFIERS_POST = (1 << 7),
        RSPAMD_TASK_STAGE_COMPOSITES = (1 << 8),
        RSPAMD_TASK_STAGE_POST_FILTERS = (1 << 9),
-       RSPAMD_TASK_STAGE_DONE = (1 << 10),
-       RSPAMD_TASK_STAGE_REPLIED = (1 << 11)
+       RSPAMD_TASK_STAGE_LEARN_PRE = (1 << 10),
+       RSPAMD_TASK_STAGE_LEARN = (1 << 11),
+       RSPAMD_TASK_STAGE_LEARN_POST = (1 << 12),
+       RSPAMD_TASK_STAGE_DONE = (1 << 13),
+       RSPAMD_TASK_STAGE_REPLIED = (1 << 14)
 };
 
 #define RSPAMD_TASK_PROCESS_ALL (RSPAMD_TASK_STAGE_CONNECT | \
@@ -79,10 +82,16 @@ enum rspamd_task_stage {
                RSPAMD_TASK_STAGE_CLASSIFIERS_POST | \
                RSPAMD_TASK_STAGE_COMPOSITES | \
                RSPAMD_TASK_STAGE_POST_FILTERS | \
+               RSPAMD_TASK_STAGE_LEARN_PRE | \
+               RSPAMD_TASK_STAGE_LEARN | \
+               RSPAMD_TASK_STAGE_LEARN_POST | \
                RSPAMD_TASK_STAGE_DONE)
 #define RSPAMD_TASK_PROCESS_LEARN (RSPAMD_TASK_STAGE_CONNECT | \
                RSPAMD_TASK_STAGE_ENVELOPE | \
                RSPAMD_TASK_STAGE_READ_MESSAGE | \
+               RSPAMD_TASK_STAGE_CLASSIFIERS_PRE | \
+               RSPAMD_TASK_STAGE_CLASSIFIERS | \
+               RSPAMD_TASK_STAGE_CLASSIFIERS_POST | \
                RSPAMD_TASK_STAGE_DONE)
 
 #define RSPAMD_TASK_FLAG_MIME (1 << 0)
@@ -99,11 +108,14 @@ enum rspamd_task_stage {
 #define RSPAMD_TASK_FLAG_GTUBE (1 << 11)
 #define RSPAMD_TASK_FLAG_FILE (1 << 12)
 #define RSPAMD_TASK_FLAG_NO_STAT (1 << 13)
+#define RSPAMD_TASK_FLAG_UNLEARN (1 << 14)
+#define RSPAMD_TASK_FLAG_ALREADY_LEARNED (1 << 15)
 
 #define RSPAMD_TASK_IS_SKIPPED(task) (((task)->flags & RSPAMD_TASK_FLAG_SKIP))
 #define RSPAMD_TASK_IS_JSON(task) (((task)->flags & RSPAMD_TASK_FLAG_JSON))
 #define RSPAMD_TASK_IS_SPAMC(task) (((task)->flags & RSPAMD_TASK_FLAG_SPAMC))
 #define RSPAMD_TASK_IS_PROCESSED(task) (((task)->processed_stages & RSPAMD_TASK_STAGE_DONE))
+#define RSPAMD_TASK_IS_CLASSIFIED(task) (((task)->processed_stages & RSPAMD_TASK_STAGE_CLASSIFIERS))
 
 typedef gint (*protocol_reply_func)(struct rspamd_task *task);
 
index 0915933f1485b34fc112b68d6b199e77157567c6..b08c703808e94eab6eaf7a00602ee7f44ed5d88f 100644 (file)
@@ -303,6 +303,7 @@ bayes_learn_spam (struct rspamd_classifier * ctx,
                GPtrArray *tokens,
                struct rspamd_task *task,
                gboolean is_spam,
+               gboolean unlearn,
                GError **err)
 {
        guint i, j;
@@ -325,7 +326,7 @@ bayes_learn_spam (struct rspamd_classifier * ctx,
                                if (st->stcf->is_spam) {
                                        tok->values[id]++;
                                }
-                               else if (tok->values[id] > 0) {
+                               else if (tok->values[id] > 0 && unlearn) {
                                        /* Unlearning */
                                        tok->values[id]--;
                                }
@@ -334,7 +335,7 @@ bayes_learn_spam (struct rspamd_classifier * ctx,
                                if (!st->stcf->is_spam) {
                                        tok->values[id]++;
                                }
-                               else if (tok->values[id] > 0) {
+                               else if (tok->values[id] > 0 && unlearn) {
                                        /* Unlearning */
                                        tok->values[id]--;
                                }
index 86395c96d0a06f92bfac95f08f4c07b8c11b6c28..6bafa8507bc00e48ad098941528326597c1b2564 100644 (file)
@@ -23,7 +23,9 @@ struct rspamd_stat_classifier {
                        struct rspamd_task *task);
        gboolean (*learn_spam_func)(struct rspamd_classifier * ctx,
                        GPtrArray *input,
-                       struct rspamd_task *task, gboolean is_spam,
+                       struct rspamd_task *task,
+                       gboolean is_spam,
+                       gboolean unlearn,
                        GError **err);
 };
 
@@ -37,6 +39,7 @@ gboolean bayes_learn_spam (struct rspamd_classifier *ctx,
                GPtrArray *tokens,
                struct rspamd_task *task,
                gboolean is_spam,
+               gboolean unlearn,
                GError **err);
 
 #endif
index a4a28a4bc9ff9727b272c8696fa97d4bad5d7fd6..1cdd2f0292485f98d4bea8845b493d11cc426798 100644 (file)
@@ -77,6 +77,7 @@ rspamd_stat_result_t rspamd_stat_classify (struct rspamd_task *task,
  */
 rspamd_stat_result_t rspamd_stat_learn (struct rspamd_task *task,
                gboolean spam, lua_State *L, const gchar *classifier,
+               guint stage,
                GError **err);
 
 /**
index 8a426972713bd4b7a498247acbd4b22ef5dd42b8..b2010391cea35baa58a1a39884955c53ab2d425c 100644 (file)
@@ -364,284 +364,282 @@ rspamd_stat_classify (struct rspamd_task *task, lua_State *L, guint stage,
        return ret;
 }
 
-#if 0
 static gboolean
-rspamd_stat_learn_token (gpointer k, gpointer v, gpointer d)
+rspamd_stat_cache_check (struct rspamd_stat_ctx *st_ctx,
+               struct rspamd_task *task,
+                const gchar *classifier,
+                gboolean spam,
+                GError **err)
 {
-       rspamd_token_t *t = (rspamd_token_t *)v;
-       struct preprocess_cb_data *cbdata = (struct preprocess_cb_data *)d;
-       struct rspamd_statfile_runtime *st_runtime;
-       struct rspamd_classifier_runtime *cl_runtime;
-       struct rspamd_token_result *res;
-       struct rspamd_task *task;
-       GList *cur, *curst;
-       gint i = 0;
-
-       task = cbdata->task;
-       cur = g_list_first (cbdata->classifier_runtimes);
+       rspamd_learn_t learn_res = RSPAMD_LEARN_OK;
+       struct rspamd_classifier *cl;
+       guint i;
 
-       while (cur) {
-               cl_runtime = (struct rspamd_classifier_runtime *)cur->data;
-
-               if (cl_runtime->clcf->min_tokens > 0 &&
-                               (guint32)g_tree_nnodes (cbdata->tok->tokens) < cl_runtime->clcf->min_tokens) {
-                       /* Skip this classifier */
-                       msg_debug_task ("<%s> contains less tokens than required for %s classifier: "
-                                       "%ud < %ud", cbdata->task->message_id, cl_runtime->clcf->name,
-                                       g_tree_nnodes (cbdata->tok->tokens),
-                                       cl_runtime->clcf->min_tokens);
-                       cur = g_list_next (cur);
+       /* Check whether we have learned that file */
+       for (i = 0; i < st_ctx->classifiers->len; i ++) {
+               cl = g_ptr_array_index (st_ctx->classifiers, i);
+
+               /* Skip other classifiers if they are not needed */
+               if (classifier != NULL && (cl->cfg->name == NULL ||
+                               g_ascii_strcasecmp (classifier, cl->cfg->name) != 0)) {
                        continue;
                }
 
-               curst = cl_runtime->st_runtime;
-
-               while (curst) {
-                       res = &g_array_index (t->results, struct rspamd_token_result, i);
-                       st_runtime = (struct rspamd_statfile_runtime *)curst->data;
-
-                       if (cl_runtime->backend->learn_token (cbdata->task, t, res,
-                                       cl_runtime->backend->ctx)) {
-                               cl_runtime->processed_tokens ++;
-
-                               if (cl_runtime->clcf->max_tokens > 0 &&
-                                               cl_runtime->processed_tokens > cl_runtime->clcf->max_tokens) {
-                                       msg_debug_task ("message contains more tokens than allowed for %s classifier: "
-                                                       "%uL > %ud", cl_runtime->clcf->name,
-                                                       cl_runtime->processed_tokens,
-                                                       cl_runtime->clcf->max_tokens);
+               if (cl->cache && cl->cachecf) {
+                       learn_res = cl->cache->process (task, spam,
+                                       cl->cachecf);
+               }
 
-                                       return TRUE;
-                               }
-                       }
+               if (learn_res == RSPAMD_LEARN_INGORE) {
+                       /* Do not learn twice */
+                       g_set_error (err, rspamd_stat_quark (), 404, "<%s> has been already "
+                                       "learned as %s, ignore it", task->message_id,
+                                       spam ? "spam" : "ham");
+                       task->flags |= RSPAMD_TASK_FLAG_ALREADY_LEARNED;
 
-                       i ++;
-                       curst = g_list_next (curst);
+                       return FALSE;
+               }
+               else if (learn_res == RSPAMD_LEARN_UNLEARN) {
+                       task->flags |= RSPAMD_TASK_FLAG_UNLEARN;
+                       break;
                }
-
-               cur = g_list_next (cur);
        }
 
-
-       return FALSE;
+       return TRUE;
 }
 
-rspamd_stat_result_t
-rspamd_stat_learn (struct rspamd_task *task,
-               gboolean spam,
-               lua_State *L,
-               const gchar *classifier,
-               GError **err)
+static gboolean
+rspamd_stat_classifiers_learn (struct rspamd_stat_ctx *st_ctx,
+               struct rspamd_task *task,
+                const gchar *classifier,
+                gboolean spam,
+                GError **err)
 {
-       struct rspamd_stat_ctx *st_ctx;
-       struct rspamd_classifier_runtime *cl_run;
-       struct rspamd_statfile_runtime *st_run;
-       struct classifier_ctx *cl_ctx;
-       struct preprocess_cb_data cbdata;
-       GList *cl_runtimes;
-       GList *cur, *curst;
-       gboolean unlearn = FALSE;
-       rspamd_stat_result_t ret = RSPAMD_STAT_PROCESS_ERROR;
-       gulong nrev;
-       rspamd_learn_t learn_res = RSPAMD_LEARN_OK;
+       struct rspamd_classifier *cl;
        guint i;
-       gboolean learned = FALSE;
-
-       st_ctx = rspamd_stat_get_ctx ();
-       g_assert (st_ctx != NULL);
-
-       cur = g_list_first (task->cfg->classifiers);
+       gboolean learned = FALSE, too_small = FALSE, too_large = FALSE;
 
        /* Check whether we have learned that file */
-       for (i = 0; i < st_ctx->caches_count; i ++) {
-               learn_res = st_ctx->caches[i].process (task, spam,
-                               st_ctx->caches[i].ctx);
+       for (i = 0; i < st_ctx->classifiers->len; i ++) {
+               cl = g_ptr_array_index (st_ctx->classifiers, i);
 
-               if (learn_res == RSPAMD_LEARN_INGORE) {
-                       /* Do not learn twice */
-                       g_set_error (err, rspamd_stat_quark (), 404, "<%s> has been already "
-                                       "learned as %s, ignore it", task->message_id,
-                                       spam ? "spam" : "ham");
-                       return RSPAMD_STAT_PROCESS_ERROR;
+               /* Skip other classifiers if they are not needed */
+               if (classifier != NULL && (cl->cfg->name == NULL ||
+                               g_ascii_strcasecmp (classifier, cl->cfg->name) != 0)) {
+                       continue;
                }
-               else if (learn_res == RSPAMD_LEARN_UNLEARN) {
-                       unlearn = TRUE;
+
+               /* Now check max and min tokens */
+               if (cl->cfg->min_tokens > 0 && task->tokens->len < cl->cfg->min_tokens) {
+                       msg_info_task (
+                               "<%s> contains less tokens than required for %s classifier: "
+                                               "%ud < %ud",
+                               task->message_id,
+                               cl->cfg->name,
+                               task->tokens->len,
+                               cl->cfg->min_tokens);
+                       too_small = TRUE;
+                       continue;
+               }
+               else if (cl->cfg->max_tokens > 0 && task->tokens->len > cl->cfg->max_tokens) {
+                       msg_info_task (
+                               "<%s> contains more tokens than allowed for %s classifier: "
+                                               "%ud > %ud",
+                               task->message_id,
+                               cl->cfg->name,
+                               task->tokens->len,
+                               cl->cfg->max_tokens);
+                       too_large = TRUE;
+                       continue;
+               }
+
+               if (cl->subrs->learn_spam_func (cl, task->tokens, task, spam,
+                               task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) {
+                       learned = TRUE;
                }
        }
 
-       /* Initialize classifiers and statfiles runtime */
-       if ((cl_runtimes = rspamd_stat_preprocess (st_ctx,
-                       task,
-                       L,
-                       unlearn ? RSPAMD_UNLEARN_OP : RSPAMD_LEARN_OP,
-                       spam,
-                       classifier,
-                       err)) == NULL) {
-               return RSPAMD_STAT_PROCESS_ERROR;
+       if (!learned && err && *err == NULL) {
+               if (too_large) {
+                       g_set_error (err, rspamd_stat_quark (), 400,
+                                       "<%s> contains more tokens than allowed for %s classifier: "
+                                       "%d > %d",
+                                       task->message_id,
+                                       cl->cfg->name,
+                                       task->tokens->len,
+                                       cl->cfg->max_tokens);
+               }
+               else if (too_small) {
+                       g_set_error (err, rspamd_stat_quark (), 400,
+                                       "<%s> contains less tokens than required for %s classifier: "
+                                       "%d < %d",
+                                       task->message_id,
+                                       cl->cfg->name,
+                                       task->tokens->len,
+                                       cl->cfg->max_tokens);
+               }
        }
 
-       cur = cl_runtimes;
+       return learned;
+}
 
-       while (cur) {
-               cl_run = (struct rspamd_classifier_runtime *)cur->data;
+static gboolean
+rspamd_stat_backends_learn (struct rspamd_stat_ctx *st_ctx,
+               struct rspamd_task *task,
+                const gchar *classifier,
+                gboolean spam,
+                GError **err)
+{
+       struct rspamd_classifier *cl;
+       struct rspamd_statfile *st;
+       gpointer bk_run;
+       guint i, j;
+       gint id;
+       gboolean res = TRUE;
 
-               curst = cl_run->st_runtime;
+       for (i = 0; i < st_ctx->classifiers->len; i ++) {
+               cl = g_ptr_array_index (st_ctx->classifiers, i);
 
-               /* Needed to finalize pre-process stage */
-               while (curst) {
-                       st_run = curst->data;
-                       cl_run->backend->finalize_process (task,
-                                       st_run->backend_runtime,
-                                       cl_run->backend->ctx);
-                       curst = g_list_next (curst);
+               /* Skip other classifiers if they are not needed */
+               if (classifier != NULL && (cl->cfg->name == NULL ||
+                               g_ascii_strcasecmp (classifier, cl->cfg->name) != 0)) {
+                       continue;
                }
 
-               if (cl_run->skipped) {
-                       msg_info_task (
-                                       "<%s> contains less tokens than required for %s classifier: "
-                                                       "%ud < %ud",
-                                       task->message_id,
-                                       cl_run->clcf->name,
-                                       g_tree_nnodes (cl_run->tok->tokens),
-                                       cl_run->clcf->min_tokens);
-               }
+               for (j = 0; j < cl->statfiles_ids->len; j ++) {
+                       id = g_array_index (cl->statfiles_ids, gint, j);
+                       st = g_ptr_array_index (st_ctx->statfiles, id);
+                       bk_run = g_ptr_array_index (task->stat_runtimes, id);
 
-               if (cl_run->cl && !cl_run->skipped) {
-                       cl_ctx = cl_run->cl->init_func (task->task_pool, cl_run->clcf);
-
-                       if (cl_ctx != NULL) {
-                               if (cl_run->cl->learn_spam_func (cl_ctx, cl_run->tok->tokens,
-                                               cl_run, task, spam, err)) {
-                                       msg_debug_task ("learned %s classifier %s", spam ? "spam" : "ham",
-                                                       cl_run->clcf->name);
-                                       ret = RSPAMD_STAT_PROCESS_OK;
-                                       learned = TRUE;
-
-                                       cbdata.classifier_runtimes = cur;
-                                       cbdata.task = task;
-                                       cbdata.tok = cl_run->tok;
-                                       cbdata.unlearn = unlearn;
-                                       cbdata.spam = spam;
-                                       g_tree_foreach (cl_run->tok->tokens, rspamd_stat_learn_token,
-                                                       &cbdata);
-
-                                       curst = g_list_first (cl_run->st_runtime);
-
-                                       while (curst) {
-                                               st_run = (struct rspamd_statfile_runtime *)curst->data;
-
-                                               if (unlearn && spam != st_run->st->is_spam) {
-                                                       nrev = cl_run->backend->dec_learns (task,
-                                                                       st_run->backend_runtime,
-                                                                       cl_run->backend->ctx);
-                                                       msg_debug_task ("unlearned %s, new revision: %ul",
-                                                                       st_run->st->symbol, nrev);
-                                               }
-                                               else {
-                                                       nrev = cl_run->backend->inc_learns (task,
-                                                               st_run->backend_runtime,
-                                                               cl_run->backend->ctx);
-                                                       msg_debug_task ("learned %s, new revision: %ul",
-                                                               st_run->st->symbol, nrev);
-                                               }
-
-                                               cl_run->backend->finalize_learn (task,
-                                                               st_run->backend_runtime,
-                                                               cl_run->backend->ctx);
-
-                                               curst = g_list_next (curst);
-                                       }
+                       g_assert (st != NULL);
+
+                       if (bk_run == NULL) {
+                               /* XXX: must be error */
+                               continue;
+                       }
+
+                       if (!task->flags & RSPAMD_TASK_FLAG_UNLEARN) {
+                               if (spam != st->stcf->is_spam) {
+                                       /* If we are not unlearning, then do not touch another class */
+                                       continue;
                                }
-                               else {
-                                       return RSPAMD_STAT_PROCESS_ERROR;
+                       }
+
+                       if (!st->backend->learn_tokens (task, task->tokens, id, bk_run)) {
+                               if (err && *err == NULL) {
+                                       g_set_error (err, rspamd_stat_quark (), 500, "Cannot push "
+                                                       "learned results to the backend");
                                }
 
+                               res = FALSE;
                        }
                }
-
-               cur = g_list_next (cur);
-       }
-
-       if (!learned) {
-               g_set_error (err, rspamd_stat_quark (), 500, "message cannot be learned as "
-                               "it has too few tokens for any classifier defined");
-       }
-       else {
-               g_atomic_int_inc (&task->worker->srv->stat->messages_learned);
        }
 
-       return ret;
+       return res;
 }
 
-rspamd_stat_result_t rspamd_stat_statistics (struct rspamd_task *task,
-               struct rspamd_config *cfg,
-               guint64 *total_learns,
-               ucl_object_t **target)
+static gboolean
+rspamd_stat_backends_post_learn (struct rspamd_stat_ctx *st_ctx,
+               struct rspamd_task *task,
+                const gchar *classifier,
+                gboolean spam)
 {
-       struct rspamd_classifier_config *clcf;
-       struct rspamd_statfile_config *stcf;
-       struct rspamd_stat_backend *bk;
-       gpointer backend_runtime;
-       GList *cur, *st_list = NULL, *curst;
-       ucl_object_t *res = NULL, *elt;
-       guint64 learns = 0;
-
-       if (cfg != NULL && cfg->classifiers != NULL) {
-               res = ucl_object_typed_new (UCL_ARRAY);
+       struct rspamd_classifier *cl;
+       struct rspamd_statfile *st;
+       gpointer bk_run;
+       guint i, j;
+       gint id;
+       gboolean res = TRUE;
 
-               cur = g_list_first (cfg->classifiers);
+       for (i = 0; i < st_ctx->classifiers->len; i ++) {
+               cl = g_ptr_array_index (st_ctx->classifiers, i);
 
-               while (cur) {
-                       clcf = (struct rspamd_classifier_config *)cur->data;
+               /* Skip other classifiers if they are not needed */
+               if (classifier != NULL && (cl->cfg->name == NULL ||
+                               g_ascii_strcasecmp (classifier, cl->cfg->name) != 0)) {
+                       continue;
+               }
 
-                       st_list = clcf->statfiles;
-                       curst = st_list;
+               for (j = 0; j < cl->statfiles_ids->len; j ++) {
+                       id = g_array_index (cl->statfiles_ids, gint, j);
+                       st = g_ptr_array_index (st_ctx->statfiles, id);
+                       bk_run = g_ptr_array_index (task->stat_runtimes, id);
 
-                       while (curst != NULL) {
-                               stcf = (struct rspamd_statfile_config *)curst->data;
+                       g_assert (st != NULL);
 
-                               bk = rspamd_stat_get_backend (clcf->backend);
+                       if (bk_run == NULL) {
+                               /* XXX: must be error */
+                               continue;
+                       }
 
-                               if (bk == NULL) {
-                                       msg_warn ("backend of type %s is not defined", clcf->backend);
-                                       curst = g_list_next (curst);
+                       if (!task->flags & RSPAMD_TASK_FLAG_UNLEARN) {
+                               if (spam != st->stcf->is_spam) {
+                                       /* If we are not unlearning, then do not touch another class */
                                        continue;
                                }
 
-                               backend_runtime = bk->runtime (task, stcf, FALSE, bk->ctx);
-
-                               learns += bk->total_learns (task, backend_runtime, bk->ctx);
-                               elt = bk->get_stat (backend_runtime, bk->ctx);
-
-                               if (elt != NULL) {
-                                       ucl_array_append (res, elt);
+                               st->backend->inc_learns (task, bk_run, st_ctx);
+                       }
+                       else {
+                               if (spam == st->stcf->is_spam) {
+                                       st->backend->inc_learns (task, bk_run, st_ctx);
+                               }
+                               else {
+                                       st->backend->dec_learns (task, bk_run, st_ctx);
                                }
-
-                               curst = g_list_next (curst);
                        }
 
-                       /* Next classifier */
-                       cur = g_list_next (cur);
+                       st->backend->finalize_learn (task, bk_run, st_ctx);
                }
+       }
+
+       return res;
+}
+
+rspamd_stat_result_t
+rspamd_stat_learn (struct rspamd_task *task,
+               gboolean spam, lua_State *L, const gchar *classifier, guint stage,
+               GError **err)
+{
+       struct rspamd_stat_ctx *st_ctx;
+
+       /*
+        * We assume now that a task has been already classified before
+        * coming to learn
+        */
+       g_assert (RSPAMD_TASK_IS_CLASSIFIED (task));
+
+       rspamd_stat_result_t ret = RSPAMD_STAT_PROCESS_OK;
+
+       st_ctx = rspamd_stat_get_ctx ();
+       g_assert (st_ctx != NULL);
 
-               if (total_learns != NULL) {
-                       *total_learns = learns;
+       if (stage == RSPAMD_TASK_STAGE_LEARN_PRE) {
+               /* Process classifiers */
+               if (!rspamd_stat_cache_check (st_ctx, task, classifier, spam, err)) {
+                       return RSPAMD_STAT_PROCESS_ERROR;
                }
        }
+       else if (stage == RSPAMD_TASK_STAGE_LEARN) {
+               /* Process classifiers */
+               if (!rspamd_stat_classifiers_learn (st_ctx, task, classifier,
+                               spam, err)) {
+                       return RSPAMD_STAT_PROCESS_ERROR;
+               }
 
-       if (target) {
-               *target = res;
+               /* Process backends */
+               if (!rspamd_stat_backends_learn (st_ctx, task, classifier, spam, err)) {
+                       return RSPAMD_STAT_PROCESS_ERROR;
+               }
+       }
+       else if (stage == RSPAMD_TASK_STAGE_LEARN_POST) {
+               if (!rspamd_stat_backends_post_learn (st_ctx, task, classifier, spam)) {
+                       return RSPAMD_STAT_PROCESS_ERROR;
+               }
        }
 
-       return RSPAMD_STAT_PROCESS_OK;
-}
-#else
-/* TODO: finish learning */
-rspamd_stat_result_t rspamd_stat_learn (struct rspamd_task *task,
-               gboolean spam, lua_State *L, const gchar *classifier,
-               GError **err)
-{
-       return RSPAMD_STAT_PROCESS_ERROR;
+       return ret;
 }
 
 /**
@@ -657,4 +655,3 @@ rspamd_stat_result_t rspamd_stat_statistics (struct rspamd_task *task,
 {
        return RSPAMD_STAT_PROCESS_ERROR;
 }
-#endif