From: Vsevolod Stakhov Date: Wed, 6 Jan 2016 14:24:07 +0000 (+0000) Subject: Add learning implementation. X-Git-Tag: 1.1.0~125 X-Git-Url: https://source.dussan.org/?a=commitdiff_plain;h=df9ada40a53a804d2d90d9dfddc149a68c141a15;p=rspamd.git Add learning implementation. --- diff --git a/src/libserver/task.c b/src/libserver/task.c index 290101023..579cc3461 100644 --- a/src/libserver/task.c +++ b/src/libserver/task.c @@ -610,11 +610,7 @@ rspamd_learn_task_spam (struct rspamd_task *task, const gchar *classifier, GError **err) { - return rspamd_stat_learn (task, - is_spam, - task->cfg->lua_state, - classifier, - err); + return FALSE; } static gboolean @@ -999,7 +995,8 @@ rspamd_task_write_log (struct rspamd_task *task) g_assert (task != NULL); - if (task->cfg->log_format == NULL || task->flags & RSPAMD_TASK_FLAG_NO_LOG) { + if (task->cfg->log_format == NULL || + (task->flags & RSPAMD_TASK_FLAG_NO_LOG)) { return; } diff --git a/src/libserver/task.h b/src/libserver/task.h index ed18d99d0..901067ba4 100644 --- a/src/libserver/task.h +++ b/src/libserver/task.h @@ -65,8 +65,11 @@ enum rspamd_task_stage { RSPAMD_TASK_STAGE_CLASSIFIERS_POST = (1 << 7), RSPAMD_TASK_STAGE_COMPOSITES = (1 << 8), RSPAMD_TASK_STAGE_POST_FILTERS = (1 << 9), - RSPAMD_TASK_STAGE_DONE = (1 << 10), - RSPAMD_TASK_STAGE_REPLIED = (1 << 11) + RSPAMD_TASK_STAGE_LEARN_PRE = (1 << 10), + RSPAMD_TASK_STAGE_LEARN = (1 << 11), + RSPAMD_TASK_STAGE_LEARN_POST = (1 << 12), + RSPAMD_TASK_STAGE_DONE = (1 << 13), + RSPAMD_TASK_STAGE_REPLIED = (1 << 14) }; #define RSPAMD_TASK_PROCESS_ALL (RSPAMD_TASK_STAGE_CONNECT | \ @@ -79,10 +82,16 @@ enum rspamd_task_stage { RSPAMD_TASK_STAGE_CLASSIFIERS_POST | \ RSPAMD_TASK_STAGE_COMPOSITES | \ RSPAMD_TASK_STAGE_POST_FILTERS | \ + RSPAMD_TASK_STAGE_LEARN_PRE | \ + RSPAMD_TASK_STAGE_LEARN | \ + RSPAMD_TASK_STAGE_LEARN_POST | \ RSPAMD_TASK_STAGE_DONE) #define RSPAMD_TASK_PROCESS_LEARN (RSPAMD_TASK_STAGE_CONNECT | \ RSPAMD_TASK_STAGE_ENVELOPE | \ RSPAMD_TASK_STAGE_READ_MESSAGE | \ + RSPAMD_TASK_STAGE_CLASSIFIERS_PRE | \ + RSPAMD_TASK_STAGE_CLASSIFIERS | \ + RSPAMD_TASK_STAGE_CLASSIFIERS_POST | \ RSPAMD_TASK_STAGE_DONE) #define RSPAMD_TASK_FLAG_MIME (1 << 0) @@ -99,11 +108,14 @@ enum rspamd_task_stage { #define RSPAMD_TASK_FLAG_GTUBE (1 << 11) #define RSPAMD_TASK_FLAG_FILE (1 << 12) #define RSPAMD_TASK_FLAG_NO_STAT (1 << 13) +#define RSPAMD_TASK_FLAG_UNLEARN (1 << 14) +#define RSPAMD_TASK_FLAG_ALREADY_LEARNED (1 << 15) #define RSPAMD_TASK_IS_SKIPPED(task) (((task)->flags & RSPAMD_TASK_FLAG_SKIP)) #define RSPAMD_TASK_IS_JSON(task) (((task)->flags & RSPAMD_TASK_FLAG_JSON)) #define RSPAMD_TASK_IS_SPAMC(task) (((task)->flags & RSPAMD_TASK_FLAG_SPAMC)) #define RSPAMD_TASK_IS_PROCESSED(task) (((task)->processed_stages & RSPAMD_TASK_STAGE_DONE)) +#define RSPAMD_TASK_IS_CLASSIFIED(task) (((task)->processed_stages & RSPAMD_TASK_STAGE_CLASSIFIERS)) typedef gint (*protocol_reply_func)(struct rspamd_task *task); diff --git a/src/libstat/classifiers/bayes.c b/src/libstat/classifiers/bayes.c index 0915933f1..b08c70380 100644 --- a/src/libstat/classifiers/bayes.c +++ b/src/libstat/classifiers/bayes.c @@ -303,6 +303,7 @@ bayes_learn_spam (struct rspamd_classifier * ctx, GPtrArray *tokens, struct rspamd_task *task, gboolean is_spam, + gboolean unlearn, GError **err) { guint i, j; @@ -325,7 +326,7 @@ bayes_learn_spam (struct rspamd_classifier * ctx, if (st->stcf->is_spam) { tok->values[id]++; } - else if (tok->values[id] > 0) { + else if (tok->values[id] > 0 && unlearn) { /* Unlearning */ tok->values[id]--; } @@ -334,7 +335,7 @@ bayes_learn_spam (struct rspamd_classifier * ctx, if (!st->stcf->is_spam) { tok->values[id]++; } - else if (tok->values[id] > 0) { + else if (tok->values[id] > 0 && unlearn) { /* Unlearning */ tok->values[id]--; } diff --git a/src/libstat/classifiers/classifiers.h b/src/libstat/classifiers/classifiers.h index 86395c96d..6bafa8507 100644 --- a/src/libstat/classifiers/classifiers.h +++ b/src/libstat/classifiers/classifiers.h @@ -23,7 +23,9 @@ struct rspamd_stat_classifier { struct rspamd_task *task); gboolean (*learn_spam_func)(struct rspamd_classifier * ctx, GPtrArray *input, - struct rspamd_task *task, gboolean is_spam, + struct rspamd_task *task, + gboolean is_spam, + gboolean unlearn, GError **err); }; @@ -37,6 +39,7 @@ gboolean bayes_learn_spam (struct rspamd_classifier *ctx, GPtrArray *tokens, struct rspamd_task *task, gboolean is_spam, + gboolean unlearn, GError **err); #endif diff --git a/src/libstat/stat_api.h b/src/libstat/stat_api.h index a4a28a4bc..1cdd2f029 100644 --- a/src/libstat/stat_api.h +++ b/src/libstat/stat_api.h @@ -77,6 +77,7 @@ rspamd_stat_result_t rspamd_stat_classify (struct rspamd_task *task, */ rspamd_stat_result_t rspamd_stat_learn (struct rspamd_task *task, gboolean spam, lua_State *L, const gchar *classifier, + guint stage, GError **err); /** diff --git a/src/libstat/stat_process.c b/src/libstat/stat_process.c index 8a4269727..b2010391c 100644 --- a/src/libstat/stat_process.c +++ b/src/libstat/stat_process.c @@ -364,284 +364,282 @@ rspamd_stat_classify (struct rspamd_task *task, lua_State *L, guint stage, return ret; } -#if 0 static gboolean -rspamd_stat_learn_token (gpointer k, gpointer v, gpointer d) +rspamd_stat_cache_check (struct rspamd_stat_ctx *st_ctx, + struct rspamd_task *task, + const gchar *classifier, + gboolean spam, + GError **err) { - rspamd_token_t *t = (rspamd_token_t *)v; - struct preprocess_cb_data *cbdata = (struct preprocess_cb_data *)d; - struct rspamd_statfile_runtime *st_runtime; - struct rspamd_classifier_runtime *cl_runtime; - struct rspamd_token_result *res; - struct rspamd_task *task; - GList *cur, *curst; - gint i = 0; - - task = cbdata->task; - cur = g_list_first (cbdata->classifier_runtimes); + rspamd_learn_t learn_res = RSPAMD_LEARN_OK; + struct rspamd_classifier *cl; + guint i; - while (cur) { - cl_runtime = (struct rspamd_classifier_runtime *)cur->data; - - if (cl_runtime->clcf->min_tokens > 0 && - (guint32)g_tree_nnodes (cbdata->tok->tokens) < cl_runtime->clcf->min_tokens) { - /* Skip this classifier */ - msg_debug_task ("<%s> contains less tokens than required for %s classifier: " - "%ud < %ud", cbdata->task->message_id, cl_runtime->clcf->name, - g_tree_nnodes (cbdata->tok->tokens), - cl_runtime->clcf->min_tokens); - cur = g_list_next (cur); + /* Check whether we have learned that file */ + for (i = 0; i < st_ctx->classifiers->len; i ++) { + cl = g_ptr_array_index (st_ctx->classifiers, i); + + /* Skip other classifiers if they are not needed */ + if (classifier != NULL && (cl->cfg->name == NULL || + g_ascii_strcasecmp (classifier, cl->cfg->name) != 0)) { continue; } - curst = cl_runtime->st_runtime; - - while (curst) { - res = &g_array_index (t->results, struct rspamd_token_result, i); - st_runtime = (struct rspamd_statfile_runtime *)curst->data; - - if (cl_runtime->backend->learn_token (cbdata->task, t, res, - cl_runtime->backend->ctx)) { - cl_runtime->processed_tokens ++; - - if (cl_runtime->clcf->max_tokens > 0 && - cl_runtime->processed_tokens > cl_runtime->clcf->max_tokens) { - msg_debug_task ("message contains more tokens than allowed for %s classifier: " - "%uL > %ud", cl_runtime->clcf->name, - cl_runtime->processed_tokens, - cl_runtime->clcf->max_tokens); + if (cl->cache && cl->cachecf) { + learn_res = cl->cache->process (task, spam, + cl->cachecf); + } - return TRUE; - } - } + if (learn_res == RSPAMD_LEARN_INGORE) { + /* Do not learn twice */ + g_set_error (err, rspamd_stat_quark (), 404, "<%s> has been already " + "learned as %s, ignore it", task->message_id, + spam ? "spam" : "ham"); + task->flags |= RSPAMD_TASK_FLAG_ALREADY_LEARNED; - i ++; - curst = g_list_next (curst); + return FALSE; + } + else if (learn_res == RSPAMD_LEARN_UNLEARN) { + task->flags |= RSPAMD_TASK_FLAG_UNLEARN; + break; } - - cur = g_list_next (cur); } - - return FALSE; + return TRUE; } -rspamd_stat_result_t -rspamd_stat_learn (struct rspamd_task *task, - gboolean spam, - lua_State *L, - const gchar *classifier, - GError **err) +static gboolean +rspamd_stat_classifiers_learn (struct rspamd_stat_ctx *st_ctx, + struct rspamd_task *task, + const gchar *classifier, + gboolean spam, + GError **err) { - struct rspamd_stat_ctx *st_ctx; - struct rspamd_classifier_runtime *cl_run; - struct rspamd_statfile_runtime *st_run; - struct classifier_ctx *cl_ctx; - struct preprocess_cb_data cbdata; - GList *cl_runtimes; - GList *cur, *curst; - gboolean unlearn = FALSE; - rspamd_stat_result_t ret = RSPAMD_STAT_PROCESS_ERROR; - gulong nrev; - rspamd_learn_t learn_res = RSPAMD_LEARN_OK; + struct rspamd_classifier *cl; guint i; - gboolean learned = FALSE; - - st_ctx = rspamd_stat_get_ctx (); - g_assert (st_ctx != NULL); - - cur = g_list_first (task->cfg->classifiers); + gboolean learned = FALSE, too_small = FALSE, too_large = FALSE; /* Check whether we have learned that file */ - for (i = 0; i < st_ctx->caches_count; i ++) { - learn_res = st_ctx->caches[i].process (task, spam, - st_ctx->caches[i].ctx); + for (i = 0; i < st_ctx->classifiers->len; i ++) { + cl = g_ptr_array_index (st_ctx->classifiers, i); - if (learn_res == RSPAMD_LEARN_INGORE) { - /* Do not learn twice */ - g_set_error (err, rspamd_stat_quark (), 404, "<%s> has been already " - "learned as %s, ignore it", task->message_id, - spam ? "spam" : "ham"); - return RSPAMD_STAT_PROCESS_ERROR; + /* Skip other classifiers if they are not needed */ + if (classifier != NULL && (cl->cfg->name == NULL || + g_ascii_strcasecmp (classifier, cl->cfg->name) != 0)) { + continue; } - else if (learn_res == RSPAMD_LEARN_UNLEARN) { - unlearn = TRUE; + + /* Now check max and min tokens */ + if (cl->cfg->min_tokens > 0 && task->tokens->len < cl->cfg->min_tokens) { + msg_info_task ( + "<%s> contains less tokens than required for %s classifier: " + "%ud < %ud", + task->message_id, + cl->cfg->name, + task->tokens->len, + cl->cfg->min_tokens); + too_small = TRUE; + continue; + } + else if (cl->cfg->max_tokens > 0 && task->tokens->len > cl->cfg->max_tokens) { + msg_info_task ( + "<%s> contains more tokens than allowed for %s classifier: " + "%ud > %ud", + task->message_id, + cl->cfg->name, + task->tokens->len, + cl->cfg->max_tokens); + too_large = TRUE; + continue; + } + + if (cl->subrs->learn_spam_func (cl, task->tokens, task, spam, + task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) { + learned = TRUE; } } - /* Initialize classifiers and statfiles runtime */ - if ((cl_runtimes = rspamd_stat_preprocess (st_ctx, - task, - L, - unlearn ? RSPAMD_UNLEARN_OP : RSPAMD_LEARN_OP, - spam, - classifier, - err)) == NULL) { - return RSPAMD_STAT_PROCESS_ERROR; + if (!learned && err && *err == NULL) { + if (too_large) { + g_set_error (err, rspamd_stat_quark (), 400, + "<%s> contains more tokens than allowed for %s classifier: " + "%d > %d", + task->message_id, + cl->cfg->name, + task->tokens->len, + cl->cfg->max_tokens); + } + else if (too_small) { + g_set_error (err, rspamd_stat_quark (), 400, + "<%s> contains less tokens than required for %s classifier: " + "%d < %d", + task->message_id, + cl->cfg->name, + task->tokens->len, + cl->cfg->max_tokens); + } } - cur = cl_runtimes; + return learned; +} - while (cur) { - cl_run = (struct rspamd_classifier_runtime *)cur->data; +static gboolean +rspamd_stat_backends_learn (struct rspamd_stat_ctx *st_ctx, + struct rspamd_task *task, + const gchar *classifier, + gboolean spam, + GError **err) +{ + struct rspamd_classifier *cl; + struct rspamd_statfile *st; + gpointer bk_run; + guint i, j; + gint id; + gboolean res = TRUE; - curst = cl_run->st_runtime; + for (i = 0; i < st_ctx->classifiers->len; i ++) { + cl = g_ptr_array_index (st_ctx->classifiers, i); - /* Needed to finalize pre-process stage */ - while (curst) { - st_run = curst->data; - cl_run->backend->finalize_process (task, - st_run->backend_runtime, - cl_run->backend->ctx); - curst = g_list_next (curst); + /* Skip other classifiers if they are not needed */ + if (classifier != NULL && (cl->cfg->name == NULL || + g_ascii_strcasecmp (classifier, cl->cfg->name) != 0)) { + continue; } - if (cl_run->skipped) { - msg_info_task ( - "<%s> contains less tokens than required for %s classifier: " - "%ud < %ud", - task->message_id, - cl_run->clcf->name, - g_tree_nnodes (cl_run->tok->tokens), - cl_run->clcf->min_tokens); - } + for (j = 0; j < cl->statfiles_ids->len; j ++) { + id = g_array_index (cl->statfiles_ids, gint, j); + st = g_ptr_array_index (st_ctx->statfiles, id); + bk_run = g_ptr_array_index (task->stat_runtimes, id); - if (cl_run->cl && !cl_run->skipped) { - cl_ctx = cl_run->cl->init_func (task->task_pool, cl_run->clcf); - - if (cl_ctx != NULL) { - if (cl_run->cl->learn_spam_func (cl_ctx, cl_run->tok->tokens, - cl_run, task, spam, err)) { - msg_debug_task ("learned %s classifier %s", spam ? "spam" : "ham", - cl_run->clcf->name); - ret = RSPAMD_STAT_PROCESS_OK; - learned = TRUE; - - cbdata.classifier_runtimes = cur; - cbdata.task = task; - cbdata.tok = cl_run->tok; - cbdata.unlearn = unlearn; - cbdata.spam = spam; - g_tree_foreach (cl_run->tok->tokens, rspamd_stat_learn_token, - &cbdata); - - curst = g_list_first (cl_run->st_runtime); - - while (curst) { - st_run = (struct rspamd_statfile_runtime *)curst->data; - - if (unlearn && spam != st_run->st->is_spam) { - nrev = cl_run->backend->dec_learns (task, - st_run->backend_runtime, - cl_run->backend->ctx); - msg_debug_task ("unlearned %s, new revision: %ul", - st_run->st->symbol, nrev); - } - else { - nrev = cl_run->backend->inc_learns (task, - st_run->backend_runtime, - cl_run->backend->ctx); - msg_debug_task ("learned %s, new revision: %ul", - st_run->st->symbol, nrev); - } - - cl_run->backend->finalize_learn (task, - st_run->backend_runtime, - cl_run->backend->ctx); - - curst = g_list_next (curst); - } + g_assert (st != NULL); + + if (bk_run == NULL) { + /* XXX: must be error */ + continue; + } + + if (!task->flags & RSPAMD_TASK_FLAG_UNLEARN) { + if (spam != st->stcf->is_spam) { + /* If we are not unlearning, then do not touch another class */ + continue; } - else { - return RSPAMD_STAT_PROCESS_ERROR; + } + + if (!st->backend->learn_tokens (task, task->tokens, id, bk_run)) { + if (err && *err == NULL) { + g_set_error (err, rspamd_stat_quark (), 500, "Cannot push " + "learned results to the backend"); } + res = FALSE; } } - - cur = g_list_next (cur); - } - - if (!learned) { - g_set_error (err, rspamd_stat_quark (), 500, "message cannot be learned as " - "it has too few tokens for any classifier defined"); - } - else { - g_atomic_int_inc (&task->worker->srv->stat->messages_learned); } - return ret; + return res; } -rspamd_stat_result_t rspamd_stat_statistics (struct rspamd_task *task, - struct rspamd_config *cfg, - guint64 *total_learns, - ucl_object_t **target) +static gboolean +rspamd_stat_backends_post_learn (struct rspamd_stat_ctx *st_ctx, + struct rspamd_task *task, + const gchar *classifier, + gboolean spam) { - struct rspamd_classifier_config *clcf; - struct rspamd_statfile_config *stcf; - struct rspamd_stat_backend *bk; - gpointer backend_runtime; - GList *cur, *st_list = NULL, *curst; - ucl_object_t *res = NULL, *elt; - guint64 learns = 0; - - if (cfg != NULL && cfg->classifiers != NULL) { - res = ucl_object_typed_new (UCL_ARRAY); + struct rspamd_classifier *cl; + struct rspamd_statfile *st; + gpointer bk_run; + guint i, j; + gint id; + gboolean res = TRUE; - cur = g_list_first (cfg->classifiers); + for (i = 0; i < st_ctx->classifiers->len; i ++) { + cl = g_ptr_array_index (st_ctx->classifiers, i); - while (cur) { - clcf = (struct rspamd_classifier_config *)cur->data; + /* Skip other classifiers if they are not needed */ + if (classifier != NULL && (cl->cfg->name == NULL || + g_ascii_strcasecmp (classifier, cl->cfg->name) != 0)) { + continue; + } - st_list = clcf->statfiles; - curst = st_list; + for (j = 0; j < cl->statfiles_ids->len; j ++) { + id = g_array_index (cl->statfiles_ids, gint, j); + st = g_ptr_array_index (st_ctx->statfiles, id); + bk_run = g_ptr_array_index (task->stat_runtimes, id); - while (curst != NULL) { - stcf = (struct rspamd_statfile_config *)curst->data; + g_assert (st != NULL); - bk = rspamd_stat_get_backend (clcf->backend); + if (bk_run == NULL) { + /* XXX: must be error */ + continue; + } - if (bk == NULL) { - msg_warn ("backend of type %s is not defined", clcf->backend); - curst = g_list_next (curst); + if (!task->flags & RSPAMD_TASK_FLAG_UNLEARN) { + if (spam != st->stcf->is_spam) { + /* If we are not unlearning, then do not touch another class */ continue; } - backend_runtime = bk->runtime (task, stcf, FALSE, bk->ctx); - - learns += bk->total_learns (task, backend_runtime, bk->ctx); - elt = bk->get_stat (backend_runtime, bk->ctx); - - if (elt != NULL) { - ucl_array_append (res, elt); + st->backend->inc_learns (task, bk_run, st_ctx); + } + else { + if (spam == st->stcf->is_spam) { + st->backend->inc_learns (task, bk_run, st_ctx); + } + else { + st->backend->dec_learns (task, bk_run, st_ctx); } - - curst = g_list_next (curst); } - /* Next classifier */ - cur = g_list_next (cur); + st->backend->finalize_learn (task, bk_run, st_ctx); } + } + + return res; +} + +rspamd_stat_result_t +rspamd_stat_learn (struct rspamd_task *task, + gboolean spam, lua_State *L, const gchar *classifier, guint stage, + GError **err) +{ + struct rspamd_stat_ctx *st_ctx; + + /* + * We assume now that a task has been already classified before + * coming to learn + */ + g_assert (RSPAMD_TASK_IS_CLASSIFIED (task)); + + rspamd_stat_result_t ret = RSPAMD_STAT_PROCESS_OK; + + st_ctx = rspamd_stat_get_ctx (); + g_assert (st_ctx != NULL); - if (total_learns != NULL) { - *total_learns = learns; + if (stage == RSPAMD_TASK_STAGE_LEARN_PRE) { + /* Process classifiers */ + if (!rspamd_stat_cache_check (st_ctx, task, classifier, spam, err)) { + return RSPAMD_STAT_PROCESS_ERROR; } } + else if (stage == RSPAMD_TASK_STAGE_LEARN) { + /* Process classifiers */ + if (!rspamd_stat_classifiers_learn (st_ctx, task, classifier, + spam, err)) { + return RSPAMD_STAT_PROCESS_ERROR; + } - if (target) { - *target = res; + /* Process backends */ + if (!rspamd_stat_backends_learn (st_ctx, task, classifier, spam, err)) { + return RSPAMD_STAT_PROCESS_ERROR; + } + } + else if (stage == RSPAMD_TASK_STAGE_LEARN_POST) { + if (!rspamd_stat_backends_post_learn (st_ctx, task, classifier, spam)) { + return RSPAMD_STAT_PROCESS_ERROR; + } } - return RSPAMD_STAT_PROCESS_OK; -} -#else -/* TODO: finish learning */ -rspamd_stat_result_t rspamd_stat_learn (struct rspamd_task *task, - gboolean spam, lua_State *L, const gchar *classifier, - GError **err) -{ - return RSPAMD_STAT_PROCESS_ERROR; + return ret; } /** @@ -657,4 +655,3 @@ rspamd_stat_result_t rspamd_stat_statistics (struct rspamd_task *task, { return RSPAMD_STAT_PROCESS_ERROR; } -#endif