about summary refs log tree commit diff stats
path: root/src
diff options
context:
space:
mode:
author Vsevolod Stakhov <vsevolod@highsecure.ru> 2016-01-06 14:24:07 +0000
committer Vsevolod Stakhov <vsevolod@highsecure.ru> 2016-01-06 14:24:07 +0000
commit df9ada40a53a804d2d90d9dfddc149a68c141a15 (patch)
tree ecf56c7cf18b8a95b2fc6a38933ea7dd142c432b /src
parent 1622570f58b5f5b184f97cd75a52a98cc0b1721a (diff)
download rspamd-df9ada40a53a804d2d90d9dfddc149a68c141a15.tar.gz
rspamd-df9ada40a53a804d2d90d9dfddc149a68c141a15.zip
Add learning implementation.
Diffstat (limited to 'src')
-rw-r--r-- src/libserver/task.c 9
-rw-r--r-- src/libserver/task.h 16
-rw-r--r-- src/libstat/classifiers/bayes.c 5
-rw-r--r-- src/libstat/classifiers/classifiers.h 5
-rw-r--r-- src/libstat/stat_api.h 1
-rw-r--r-- src/libstat/stat_process.c 437
6 files changed, 242 insertions, 231 deletions
diff --git a/src/libserver/task.c b/src/libserver/task.c
index 290101023..579cc3461 100644
--- a/src/libserver/task.c
+++ b/src/libserver/task.c
@@ -610,11 +610,7 @@ rspamd_learn_task_spam (struct rspamd_task *task,
const gchar *classifier,
GError **err)
{
- return rspamd_stat_learn (task,
- is_spam,
- task->cfg->lua_state,
- classifier,
- err);
+ return FALSE;
}
static gboolean
@@ -999,7 +995,8 @@ rspamd_task_write_log (struct rspamd_task *task)
g_assert (task != NULL);
- if (task->cfg->log_format == NULL || task->flags & RSPAMD_TASK_FLAG_NO_LOG) {
+ if (task->cfg->log_format == NULL ||
+ (task->flags & RSPAMD_TASK_FLAG_NO_LOG)) {
return;
}
diff --git a/src/libserver/task.h b/src/libserver/task.h
index ed18d99d0..901067ba4 100644
--- a/src/libserver/task.h
+++ b/src/libserver/task.h
@@ -65,8 +65,11 @@ enum rspamd_task_stage {
RSPAMD_TASK_STAGE_CLASSIFIERS_POST = (1 << 7),
RSPAMD_TASK_STAGE_COMPOSITES = (1 << 8),
RSPAMD_TASK_STAGE_POST_FILTERS = (1 << 9),
- RSPAMD_TASK_STAGE_DONE = (1 << 10),
- RSPAMD_TASK_STAGE_REPLIED = (1 << 11)
+ RSPAMD_TASK_STAGE_LEARN_PRE = (1 << 10),
+ RSPAMD_TASK_STAGE_LEARN = (1 << 11),
+ RSPAMD_TASK_STAGE_LEARN_POST = (1 << 12),
+ RSPAMD_TASK_STAGE_DONE = (1 << 13),
+ RSPAMD_TASK_STAGE_REPLIED = (1 << 14)
};
#define RSPAMD_TASK_PROCESS_ALL (RSPAMD_TASK_STAGE_CONNECT | \
@@ -79,10 +82,16 @@ enum rspamd_task_stage {
RSPAMD_TASK_STAGE_CLASSIFIERS_POST | \
RSPAMD_TASK_STAGE_COMPOSITES | \
RSPAMD_TASK_STAGE_POST_FILTERS | \
+ RSPAMD_TASK_STAGE_LEARN_PRE | \
+ RSPAMD_TASK_STAGE_LEARN | \
+ RSPAMD_TASK_STAGE_LEARN_POST | \
RSPAMD_TASK_STAGE_DONE)
#define RSPAMD_TASK_PROCESS_LEARN (RSPAMD_TASK_STAGE_CONNECT | \
RSPAMD_TASK_STAGE_ENVELOPE | \
RSPAMD_TASK_STAGE_READ_MESSAGE | \
+ RSPAMD_TASK_STAGE_CLASSIFIERS_PRE | \
+ RSPAMD_TASK_STAGE_CLASSIFIERS | \
+ RSPAMD_TASK_STAGE_CLASSIFIERS_POST | \
RSPAMD_TASK_STAGE_DONE)
#define RSPAMD_TASK_FLAG_MIME (1 << 0)
@@ -99,11 +108,14 @@ enum rspamd_task_stage {
#define RSPAMD_TASK_FLAG_GTUBE (1 << 11)
#define RSPAMD_TASK_FLAG_FILE (1 << 12)
#define RSPAMD_TASK_FLAG_NO_STAT (1 << 13)
+#define RSPAMD_TASK_FLAG_UNLEARN (1 << 14)
+#define RSPAMD_TASK_FLAG_ALREADY_LEARNED (1 << 15)
#define RSPAMD_TASK_IS_SKIPPED(task) (((task)->flags & RSPAMD_TASK_FLAG_SKIP))
#define RSPAMD_TASK_IS_JSON(task) (((task)->flags & RSPAMD_TASK_FLAG_JSON))
#define RSPAMD_TASK_IS_SPAMC(task) (((task)->flags & RSPAMD_TASK_FLAG_SPAMC))
#define RSPAMD_TASK_IS_PROCESSED(task) (((task)->processed_stages & RSPAMD_TASK_STAGE_DONE))
+#define RSPAMD_TASK_IS_CLASSIFIED(task) (((task)->processed_stages & RSPAMD_TASK_STAGE_CLASSIFIERS))
typedef gint (*protocol_reply_func)(struct rspamd_task *task);
diff --git a/src/libstat/classifiers/bayes.c b/src/libstat/classifiers/bayes.c
index 0915933f1..b08c70380 100644
--- a/src/libstat/classifiers/bayes.c
+++ b/src/libstat/classifiers/bayes.c
@@ -303,6 +303,7 @@ bayes_learn_spam (struct rspamd_classifier * ctx,
GPtrArray *tokens,
struct rspamd_task *task,
gboolean is_spam,
+ gboolean unlearn,
GError **err)
{
guint i, j;
@@ -325,7 +326,7 @@ bayes_learn_spam (struct rspamd_classifier * ctx,
if (st->stcf->is_spam) {
tok->values[id]++;
}
- else if (tok->values[id] > 0) {
+ else if (tok->values[id] > 0 && unlearn) {
/* Unlearning */
tok->values[id]--;
}
@@ -334,7 +335,7 @@ bayes_learn_spam (struct rspamd_classifier * ctx,
if (!st->stcf->is_spam) {
tok->values[id]++;
}
- else if (tok->values[id] > 0) {
+ else if (tok->values[id] > 0 && unlearn) {
/* Unlearning */
tok->values[id]--;
}
diff --git a/src/libstat/classifiers/classifiers.h b/src/libstat/classifiers/classifiers.h
index 86395c96d..6bafa8507 100644
--- a/src/libstat/classifiers/classifiers.h
+++ b/src/libstat/classifiers/classifiers.h
@@ -23,7 +23,9 @@ struct rspamd_stat_classifier {
struct rspamd_task *task);
gboolean (*learn_spam_func)(struct rspamd_classifier * ctx,
GPtrArray *input,
- struct rspamd_task *task, gboolean is_spam,
+ struct rspamd_task *task,
+ gboolean is_spam,
+ gboolean unlearn,
GError **err);
};
@@ -37,6 +39,7 @@ gboolean bayes_learn_spam (struct rspamd_classifier *ctx,
GPtrArray *tokens,
struct rspamd_task *task,
gboolean is_spam,
+ gboolean unlearn,
GError **err);
#endif
diff --git a/src/libstat/stat_api.h b/src/libstat/stat_api.h
index a4a28a4bc..1cdd2f029 100644
--- a/src/libstat/stat_api.h
+++ b/src/libstat/stat_api.h
@@ -77,6 +77,7 @@ rspamd_stat_result_t rspamd_stat_classify (struct rspamd_task *task,
*/
rspamd_stat_result_t rspamd_stat_learn (struct rspamd_task *task,
gboolean spam, lua_State *L, const gchar *classifier,
+ guint stage,
GError **err);
/**
diff --git a/src/libstat/stat_process.c b/src/libstat/stat_process.c
index 8a4269727..b2010391c 100644
--- a/src/libstat/stat_process.c
+++ b/src/libstat/stat_process.c
@@ -364,284 +364,282 @@ rspamd_stat_classify (struct rspamd_task *task, lua_State *L, guint stage,
return ret;
}
-#if 0
static gboolean
-rspamd_stat_learn_token (gpointer k, gpointer v, gpointer d)
+rspamd_stat_cache_check (struct rspamd_stat_ctx *st_ctx,
+ struct rspamd_task *task,
+ const gchar *classifier,
+ gboolean spam,
+ GError **err)
{
- rspamd_token_t *t = (rspamd_token_t *)v;
- struct preprocess_cb_data *cbdata = (struct preprocess_cb_data *)d;
- struct rspamd_statfile_runtime *st_runtime;
- struct rspamd_classifier_runtime *cl_runtime;
- struct rspamd_token_result *res;
- struct rspamd_task *task;
- GList *cur, *curst;
- gint i = 0;
-
- task = cbdata->task;
- cur = g_list_first (cbdata->classifier_runtimes);
+ rspamd_learn_t learn_res = RSPAMD_LEARN_OK;
+ struct rspamd_classifier *cl;
+ guint i;
- while (cur) {
- cl_runtime = (struct rspamd_classifier_runtime *)cur->data;
-
- if (cl_runtime->clcf->min_tokens > 0 &&
- (guint32)g_tree_nnodes (cbdata->tok->tokens) < cl_runtime->clcf->min_tokens) {
- /* Skip this classifier */
- msg_debug_task ("<%s> contains less tokens than required for %s classifier: "
- "%ud < %ud", cbdata->task->message_id, cl_runtime->clcf->name,
- g_tree_nnodes (cbdata->tok->tokens),
- cl_runtime->clcf->min_tokens);
- cur = g_list_next (cur);
+ /* Check whether we have learned that file */
+ for (i = 0; i < st_ctx->classifiers->len; i ++) {
+ cl = g_ptr_array_index (st_ctx->classifiers, i);
+
+ /* Skip other classifiers if they are not needed */
+ if (classifier != NULL && (cl->cfg->name == NULL ||
+ g_ascii_strcasecmp (classifier, cl->cfg->name) != 0)) {
continue;
}
- curst = cl_runtime->st_runtime;
-
- while (curst) {
- res = &g_array_index (t->results, struct rspamd_token_result, i);
- st_runtime = (struct rspamd_statfile_runtime *)curst->data;
-
- if (cl_runtime->backend->learn_token (cbdata->task, t, res,
- cl_runtime->backend->ctx)) {
- cl_runtime->processed_tokens ++;
-
- if (cl_runtime->clcf->max_tokens > 0 &&
- cl_runtime->processed_tokens > cl_runtime->clcf->max_tokens) {
- msg_debug_task ("message contains more tokens than allowed for %s classifier: "
- "%uL > %ud", cl_runtime->clcf->name,
- cl_runtime->processed_tokens,
- cl_runtime->clcf->max_tokens);
+ if (cl->cache && cl->cachecf) {
+ learn_res = cl->cache->process (task, spam,
+ cl->cachecf);
+ }
- return TRUE;
- }
- }
+ if (learn_res == RSPAMD_LEARN_INGORE) {
+ /* Do not learn twice */
+ g_set_error (err, rspamd_stat_quark (), 404, "<%s> has been already "
+ "learned as %s, ignore it", task->message_id,
+ spam ? "spam" : "ham");
+ task->flags |= RSPAMD_TASK_FLAG_ALREADY_LEARNED;
- i ++;
- curst = g_list_next (curst);
+ return FALSE;
+ }
+ else if (learn_res == RSPAMD_LEARN_UNLEARN) {
+ task->flags |= RSPAMD_TASK_FLAG_UNLEARN;
+ break;
}
-
- cur = g_list_next (cur);
}
-
- return FALSE;
+ return TRUE;
}
-rspamd_stat_result_t
-rspamd_stat_learn (struct rspamd_task *task,
- gboolean spam,
- lua_State *L,
- const gchar *classifier,
- GError **err)
+static gboolean
+rspamd_stat_classifiers_learn (struct rspamd_stat_ctx *st_ctx,
+ struct rspamd_task *task,
+ const gchar *classifier,
+ gboolean spam,
+ GError **err)
{
- struct rspamd_stat_ctx *st_ctx;
- struct rspamd_classifier_runtime *cl_run;
- struct rspamd_statfile_runtime *st_run;
- struct classifier_ctx *cl_ctx;
- struct preprocess_cb_data cbdata;
- GList *cl_runtimes;
- GList *cur, *curst;
- gboolean unlearn = FALSE;
- rspamd_stat_result_t ret = RSPAMD_STAT_PROCESS_ERROR;
- gulong nrev;
- rspamd_learn_t learn_res = RSPAMD_LEARN_OK;
+ struct rspamd_classifier *cl;
guint i;
- gboolean learned = FALSE;
-
- st_ctx = rspamd_stat_get_ctx ();
- g_assert (st_ctx != NULL);
-
- cur = g_list_first (task->cfg->classifiers);
+ gboolean learned = FALSE, too_small = FALSE, too_large = FALSE;
/* Check whether we have learned that file */
- for (i = 0; i < st_ctx->caches_count; i ++) {
- learn_res = st_ctx->caches[i].process (task, spam,
- st_ctx->caches[i].ctx);
+ for (i = 0; i < st_ctx->classifiers->len; i ++) {
+ cl = g_ptr_array_index (st_ctx->classifiers, i);
- if (learn_res == RSPAMD_LEARN_INGORE) {
- /* Do not learn twice */
- g_set_error (err, rspamd_stat_quark (), 404, "<%s> has been already "
- "learned as %s, ignore it", task->message_id,
- spam ? "spam" : "ham");
- return RSPAMD_STAT_PROCESS_ERROR;
+ /* Skip other classifiers if they are not needed */
+ if (classifier != NULL && (cl->cfg->name == NULL ||
+ g_ascii_strcasecmp (classifier, cl->cfg->name) != 0)) {
+ continue;
}
- else if (learn_res == RSPAMD_LEARN_UNLEARN) {
- unlearn = TRUE;
+
+ /* Now check max and min tokens */
+ if (cl->cfg->min_tokens > 0 && task->tokens->len < cl->cfg->min_tokens) {
+ msg_info_task (
+ "<%s> contains less tokens than required for %s classifier: "
+ "%ud < %ud",
+ task->message_id,
+ cl->cfg->name,
+ task->tokens->len,
+ cl->cfg->min_tokens);
+ too_small = TRUE;
+ continue;
+ }
+ else if (cl->cfg->max_tokens > 0 && task->tokens->len > cl->cfg->max_tokens) {
+ msg_info_task (
+ "<%s> contains more tokens than allowed for %s classifier: "
+ "%ud > %ud",
+ task->message_id,
+ cl->cfg->name,
+ task->tokens->len,
+ cl->cfg->max_tokens);
+ too_large = TRUE;
+ continue;
+ }
+
+ if (cl->subrs->learn_spam_func (cl, task->tokens, task, spam,
+ task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) {
+ learned = TRUE;
}
}
- /* Initialize classifiers and statfiles runtime */
- if ((cl_runtimes = rspamd_stat_preprocess (st_ctx,
- task,
- L,
- unlearn ? RSPAMD_UNLEARN_OP : RSPAMD_LEARN_OP,
- spam,
- classifier,
- err)) == NULL) {
- return RSPAMD_STAT_PROCESS_ERROR;
+ if (!learned && err && *err == NULL) { /* NOTE(review): cl is whatever the loop last assigned (unset if there were no classifiers), not necessarily the one that hit the limit */
+ if (too_large) {
+ g_set_error (err, rspamd_stat_quark (), 400,
+ "<%s> contains more tokens than allowed for %s classifier: "
+ "%d > %d",
+ task->message_id,
+ cl->cfg->name,
+ task->tokens->len,
+ cl->cfg->max_tokens);
+ }
+ else if (too_small) {
+ g_set_error (err, rspamd_stat_quark (), 400,
+ "<%s> contains less tokens than required for %s classifier: "
+ "%d < %d",
+ task->message_id,
+ cl->cfg->name,
+ task->tokens->len,
+ cl->cfg->min_tokens);
+ }
}
- cur = cl_runtimes;
+ return learned;
+}
- while (cur) {
- cl_run = (struct rspamd_classifier_runtime *)cur->data;
+static gboolean
+rspamd_stat_backends_learn (struct rspamd_stat_ctx *st_ctx,
+ struct rspamd_task *task,
+ const gchar *classifier,
+ gboolean spam,
+ GError **err)
+{
+ struct rspamd_classifier *cl;
+ struct rspamd_statfile *st;
+ gpointer bk_run;
+ guint i, j;
+ gint id;
+ gboolean res = TRUE;
- curst = cl_run->st_runtime;
+ for (i = 0; i < st_ctx->classifiers->len; i ++) {
+ cl = g_ptr_array_index (st_ctx->classifiers, i);
- /* Needed to finalize pre-process stage */
- while (curst) {
- st_run = curst->data;
- cl_run->backend->finalize_process (task,
- st_run->backend_runtime,
- cl_run->backend->ctx);
- curst = g_list_next (curst);
+ /* Skip other classifiers if they are not needed */
+ if (classifier != NULL && (cl->cfg->name == NULL ||
+ g_ascii_strcasecmp (classifier, cl->cfg->name) != 0)) {
+ continue;
}
- if (cl_run->skipped) {
- msg_info_task (
- "<%s> contains less tokens than required for %s classifier: "
- "%ud < %ud",
- task->message_id,
- cl_run->clcf->name,
- g_tree_nnodes (cl_run->tok->tokens),
- cl_run->clcf->min_tokens);
- }
+ for (j = 0; j < cl->statfiles_ids->len; j ++) {
+ id = g_array_index (cl->statfiles_ids, gint, j);
+ st = g_ptr_array_index (st_ctx->statfiles, id);
+ bk_run = g_ptr_array_index (task->stat_runtimes, id);
- if (cl_run->cl && !cl_run->skipped) {
- cl_ctx = cl_run->cl->init_func (task->task_pool, cl_run->clcf);
-
- if (cl_ctx != NULL) {
- if (cl_run->cl->learn_spam_func (cl_ctx, cl_run->tok->tokens,
- cl_run, task, spam, err)) {
- msg_debug_task ("learned %s classifier %s", spam ? "spam" : "ham",
- cl_run->clcf->name);
- ret = RSPAMD_STAT_PROCESS_OK;
- learned = TRUE;
-
- cbdata.classifier_runtimes = cur;
- cbdata.task = task;
- cbdata.tok = cl_run->tok;
- cbdata.unlearn = unlearn;
- cbdata.spam = spam;
- g_tree_foreach (cl_run->tok->tokens, rspamd_stat_learn_token,
- &cbdata);
-
- curst = g_list_first (cl_run->st_runtime);
-
- while (curst) {
- st_run = (struct rspamd_statfile_runtime *)curst->data;
-
- if (unlearn && spam != st_run->st->is_spam) {
- nrev = cl_run->backend->dec_learns (task,
- st_run->backend_runtime,
- cl_run->backend->ctx);
- msg_debug_task ("unlearned %s, new revision: %ul",
- st_run->st->symbol, nrev);
- }
- else {
- nrev = cl_run->backend->inc_learns (task,
- st_run->backend_runtime,
- cl_run->backend->ctx);
- msg_debug_task ("learned %s, new revision: %ul",
- st_run->st->symbol, nrev);
- }
-
- cl_run->backend->finalize_learn (task,
- st_run->backend_runtime,
- cl_run->backend->ctx);
-
- curst = g_list_next (curst);
- }
+ g_assert (st != NULL);
+
+ if (bk_run == NULL) {
+ /* XXX: must be error */
+ continue;
+ }
+
+ if (!(task->flags & RSPAMD_TASK_FLAG_UNLEARN)) { /* parentheses required: '!' binds tighter than '&' */
+ if (spam != st->stcf->is_spam) {
+ /* If we are not unlearning, then do not touch another class */
+ continue;
}
- else {
- return RSPAMD_STAT_PROCESS_ERROR;
+ }
+
+ if (!st->backend->learn_tokens (task, task->tokens, id, bk_run)) {
+ if (err && *err == NULL) {
+ g_set_error (err, rspamd_stat_quark (), 500, "Cannot push "
+ "learned results to the backend");
}
+ res = FALSE;
}
}
-
- cur = g_list_next (cur);
- }
-
- if (!learned) {
- g_set_error (err, rspamd_stat_quark (), 500, "message cannot be learned as "
- "it has too few tokens for any classifier defined");
- }
- else {
- g_atomic_int_inc (&task->worker->srv->stat->messages_learned);
}
- return ret;
+ return res;
}
-rspamd_stat_result_t rspamd_stat_statistics (struct rspamd_task *task,
- struct rspamd_config *cfg,
- guint64 *total_learns,
- ucl_object_t **target)
+static gboolean
+rspamd_stat_backends_post_learn (struct rspamd_stat_ctx *st_ctx,
+ struct rspamd_task *task,
+ const gchar *classifier,
+ gboolean spam)
{
- struct rspamd_classifier_config *clcf;
- struct rspamd_statfile_config *stcf;
- struct rspamd_stat_backend *bk;
- gpointer backend_runtime;
- GList *cur, *st_list = NULL, *curst;
- ucl_object_t *res = NULL, *elt;
- guint64 learns = 0;
-
- if (cfg != NULL && cfg->classifiers != NULL) {
- res = ucl_object_typed_new (UCL_ARRAY);
+ struct rspamd_classifier *cl;
+ struct rspamd_statfile *st;
+ gpointer bk_run;
+ guint i, j;
+ gint id;
+ gboolean res = TRUE;
- cur = g_list_first (cfg->classifiers);
+ for (i = 0; i < st_ctx->classifiers->len; i ++) {
+ cl = g_ptr_array_index (st_ctx->classifiers, i);
- while (cur) {
- clcf = (struct rspamd_classifier_config *)cur->data;
+ /* Skip other classifiers if they are not needed */
+ if (classifier != NULL && (cl->cfg->name == NULL ||
+ g_ascii_strcasecmp (classifier, cl->cfg->name) != 0)) {
+ continue;
+ }
- st_list = clcf->statfiles;
- curst = st_list;
+ for (j = 0; j < cl->statfiles_ids->len; j ++) {
+ id = g_array_index (cl->statfiles_ids, gint, j);
+ st = g_ptr_array_index (st_ctx->statfiles, id);
+ bk_run = g_ptr_array_index (task->stat_runtimes, id);
- while (curst != NULL) {
- stcf = (struct rspamd_statfile_config *)curst->data;
+ g_assert (st != NULL);
- bk = rspamd_stat_get_backend (clcf->backend);
+ if (bk_run == NULL) {
+ /* XXX: must be error */
+ continue;
+ }
- if (bk == NULL) {
- msg_warn ("backend of type %s is not defined", clcf->backend);
- curst = g_list_next (curst);
+ if (!(task->flags & RSPAMD_TASK_FLAG_UNLEARN)) { /* parentheses required: '!' binds tighter than '&' */
+ if (spam != st->stcf->is_spam) {
+ /* If we are not unlearning, then do not touch another class */
continue;
}
- backend_runtime = bk->runtime (task, stcf, FALSE, bk->ctx);
-
- learns += bk->total_learns (task, backend_runtime, bk->ctx);
- elt = bk->get_stat (backend_runtime, bk->ctx);
-
- if (elt != NULL) {
- ucl_array_append (res, elt);
+ st->backend->inc_learns (task, bk_run, st_ctx);
+ }
+ else {
+ if (spam == st->stcf->is_spam) {
+ st->backend->inc_learns (task, bk_run, st_ctx);
+ }
+ else {
+ st->backend->dec_learns (task, bk_run, st_ctx);
}
-
- curst = g_list_next (curst);
}
- /* Next classifier */
- cur = g_list_next (cur);
+ st->backend->finalize_learn (task, bk_run, st_ctx);
}
+ }
+
+ return res;
+}
+
+rspamd_stat_result_t
+rspamd_stat_learn (struct rspamd_task *task,
+ gboolean spam, lua_State *L, const gchar *classifier, guint stage,
+ GError **err)
+{
+ struct rspamd_stat_ctx *st_ctx;
+
+ /*
+ * We assume now that a task has been already classified before
+ * coming to learn
+ */
+ g_assert (RSPAMD_TASK_IS_CLASSIFIED (task));
+
+ rspamd_stat_result_t ret = RSPAMD_STAT_PROCESS_OK;
+
+ st_ctx = rspamd_stat_get_ctx ();
+ g_assert (st_ctx != NULL);
- if (total_learns != NULL) {
- *total_learns = learns;
+ if (stage == RSPAMD_TASK_STAGE_LEARN_PRE) {
+ /* Process classifiers */
+ if (!rspamd_stat_cache_check (st_ctx, task, classifier, spam, err)) {
+ return RSPAMD_STAT_PROCESS_ERROR;
}
}
+ else if (stage == RSPAMD_TASK_STAGE_LEARN) {
+ /* Process classifiers */
+ if (!rspamd_stat_classifiers_learn (st_ctx, task, classifier,
+ spam, err)) {
+ return RSPAMD_STAT_PROCESS_ERROR;
+ }
- if (target) {
- *target = res;
+ /* Process backends */
+ if (!rspamd_stat_backends_learn (st_ctx, task, classifier, spam, err)) {
+ return RSPAMD_STAT_PROCESS_ERROR;
+ }
+ }
+ else if (stage == RSPAMD_TASK_STAGE_LEARN_POST) {
+ if (!rspamd_stat_backends_post_learn (st_ctx, task, classifier, spam)) {
+ return RSPAMD_STAT_PROCESS_ERROR;
+ }
}
- return RSPAMD_STAT_PROCESS_OK;
-}
-#else
-/* TODO: finish learning */
-rspamd_stat_result_t rspamd_stat_learn (struct rspamd_task *task,
- gboolean spam, lua_State *L, const gchar *classifier,
- GError **err)
-{
- return RSPAMD_STAT_PROCESS_ERROR;
+ return ret;
}
/**
@@ -657,4 +655,3 @@ rspamd_stat_result_t rspamd_stat_statistics (struct rspamd_task *task,
{
return RSPAMD_STAT_PROCESS_ERROR;
}
-#endif