return ret;
}
-#if 0
static gboolean
-rspamd_stat_learn_token (gpointer k, gpointer v, gpointer d)
+rspamd_stat_cache_check (struct rspamd_stat_ctx *st_ctx,
+ struct rspamd_task *task,
+ const gchar *classifier,
+ gboolean spam,
+ GError **err)
{
- rspamd_token_t *t = (rspamd_token_t *)v;
- struct preprocess_cb_data *cbdata = (struct preprocess_cb_data *)d;
- struct rspamd_statfile_runtime *st_runtime;
- struct rspamd_classifier_runtime *cl_runtime;
- struct rspamd_token_result *res;
- struct rspamd_task *task;
- GList *cur, *curst;
- gint i = 0;
-
- task = cbdata->task;
- cur = g_list_first (cbdata->classifier_runtimes);
+ rspamd_learn_t learn_res = RSPAMD_LEARN_OK;
+ struct rspamd_classifier *cl;
+ guint i;
- while (cur) {
- cl_runtime = (struct rspamd_classifier_runtime *)cur->data;
-
- if (cl_runtime->clcf->min_tokens > 0 &&
- (guint32)g_tree_nnodes (cbdata->tok->tokens) < cl_runtime->clcf->min_tokens) {
- /* Skip this classifier */
- msg_debug_task ("<%s> contains less tokens than required for %s classifier: "
- "%ud < %ud", cbdata->task->message_id, cl_runtime->clcf->name,
- g_tree_nnodes (cbdata->tok->tokens),
- cl_runtime->clcf->min_tokens);
- cur = g_list_next (cur);
+ /* Check whether we have learned that file */
+ for (i = 0; i < st_ctx->classifiers->len; i ++) {
+ cl = g_ptr_array_index (st_ctx->classifiers, i);
+
+ /* Skip other classifiers if they are not needed */
+ if (classifier != NULL && (cl->cfg->name == NULL ||
+ g_ascii_strcasecmp (classifier, cl->cfg->name) != 0)) {
continue;
}
- curst = cl_runtime->st_runtime;
-
- while (curst) {
- res = &g_array_index (t->results, struct rspamd_token_result, i);
- st_runtime = (struct rspamd_statfile_runtime *)curst->data;
-
- if (cl_runtime->backend->learn_token (cbdata->task, t, res,
- cl_runtime->backend->ctx)) {
- cl_runtime->processed_tokens ++;
-
- if (cl_runtime->clcf->max_tokens > 0 &&
- cl_runtime->processed_tokens > cl_runtime->clcf->max_tokens) {
- msg_debug_task ("message contains more tokens than allowed for %s classifier: "
- "%uL > %ud", cl_runtime->clcf->name,
- cl_runtime->processed_tokens,
- cl_runtime->clcf->max_tokens);
+ if (cl->cache && cl->cachecf) {
+ learn_res = cl->cache->process (task, spam,
+ cl->cachecf);
+ }
- return TRUE;
- }
- }
+ if (learn_res == RSPAMD_LEARN_INGORE) {
+ /* Do not learn twice */
+ g_set_error (err, rspamd_stat_quark (), 404, "<%s> has been already "
+ "learned as %s, ignore it", task->message_id,
+ spam ? "spam" : "ham");
+ task->flags |= RSPAMD_TASK_FLAG_ALREADY_LEARNED;
- i ++;
- curst = g_list_next (curst);
+ return FALSE;
+ }
+ else if (learn_res == RSPAMD_LEARN_UNLEARN) {
+ task->flags |= RSPAMD_TASK_FLAG_UNLEARN;
+ break;
}
-
- cur = g_list_next (cur);
}
-
- return FALSE;
+ return TRUE;
}
-rspamd_stat_result_t
-rspamd_stat_learn (struct rspamd_task *task,
- gboolean spam,
- lua_State *L,
- const gchar *classifier,
- GError **err)
+static gboolean
+rspamd_stat_classifiers_learn (struct rspamd_stat_ctx *st_ctx,
+ struct rspamd_task *task,
+ const gchar *classifier,
+ gboolean spam,
+ GError **err)
{
- struct rspamd_stat_ctx *st_ctx;
- struct rspamd_classifier_runtime *cl_run;
- struct rspamd_statfile_runtime *st_run;
- struct classifier_ctx *cl_ctx;
- struct preprocess_cb_data cbdata;
- GList *cl_runtimes;
- GList *cur, *curst;
- gboolean unlearn = FALSE;
- rspamd_stat_result_t ret = RSPAMD_STAT_PROCESS_ERROR;
- gulong nrev;
- rspamd_learn_t learn_res = RSPAMD_LEARN_OK;
+ struct rspamd_classifier *cl;
guint i;
- gboolean learned = FALSE;
-
- st_ctx = rspamd_stat_get_ctx ();
- g_assert (st_ctx != NULL);
-
- cur = g_list_first (task->cfg->classifiers);
+ gboolean learned = FALSE, too_small = FALSE, too_large = FALSE;
/* Check whether we have learned that file */
- for (i = 0; i < st_ctx->caches_count; i ++) {
- learn_res = st_ctx->caches[i].process (task, spam,
- st_ctx->caches[i].ctx);
+ for (i = 0; i < st_ctx->classifiers->len; i ++) {
+ cl = g_ptr_array_index (st_ctx->classifiers, i);
- if (learn_res == RSPAMD_LEARN_INGORE) {
- /* Do not learn twice */
- g_set_error (err, rspamd_stat_quark (), 404, "<%s> has been already "
- "learned as %s, ignore it", task->message_id,
- spam ? "spam" : "ham");
- return RSPAMD_STAT_PROCESS_ERROR;
+ /* Skip other classifiers if they are not needed */
+ if (classifier != NULL && (cl->cfg->name == NULL ||
+ g_ascii_strcasecmp (classifier, cl->cfg->name) != 0)) {
+ continue;
}
- else if (learn_res == RSPAMD_LEARN_UNLEARN) {
- unlearn = TRUE;
+
+ /* Now check max and min tokens */
+ if (cl->cfg->min_tokens > 0 && task->tokens->len < cl->cfg->min_tokens) {
+ msg_info_task (
+ "<%s> contains less tokens than required for %s classifier: "
+ "%ud < %ud",
+ task->message_id,
+ cl->cfg->name,
+ task->tokens->len,
+ cl->cfg->min_tokens);
+ too_small = TRUE;
+ continue;
+ }
+ else if (cl->cfg->max_tokens > 0 && task->tokens->len > cl->cfg->max_tokens) {
+ msg_info_task (
+ "<%s> contains more tokens than allowed for %s classifier: "
+ "%ud > %ud",
+ task->message_id,
+ cl->cfg->name,
+ task->tokens->len,
+ cl->cfg->max_tokens);
+ too_large = TRUE;
+ continue;
+ }
+
+ if (cl->subrs->learn_spam_func (cl, task->tokens, task, spam,
+ task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) {
+ learned = TRUE;
}
}
- /* Initialize classifiers and statfiles runtime */
- if ((cl_runtimes = rspamd_stat_preprocess (st_ctx,
- task,
- L,
- unlearn ? RSPAMD_UNLEARN_OP : RSPAMD_LEARN_OP,
- spam,
- classifier,
- err)) == NULL) {
- return RSPAMD_STAT_PROCESS_ERROR;
+ if (!learned && err && *err == NULL) {
+ if (too_large) {
+ g_set_error (err, rspamd_stat_quark (), 400,
+ "<%s> contains more tokens than allowed for %s classifier: "
+ "%d > %d",
+ task->message_id,
+ cl->cfg->name,
+ task->tokens->len,
+ cl->cfg->max_tokens);
+ }
+ else if (too_small) {
+ g_set_error (err, rspamd_stat_quark (), 400,
+ "<%s> contains less tokens than required for %s classifier: "
+ "%d < %d",
+ task->message_id,
+ cl->cfg->name,
+ task->tokens->len,
+ cl->cfg->max_tokens);
+ }
}
- cur = cl_runtimes;
+ return learned;
+}
- while (cur) {
- cl_run = (struct rspamd_classifier_runtime *)cur->data;
+static gboolean
+rspamd_stat_backends_learn (struct rspamd_stat_ctx *st_ctx,
+ struct rspamd_task *task,
+ const gchar *classifier,
+ gboolean spam,
+ GError **err)
+{
+ struct rspamd_classifier *cl;
+ struct rspamd_statfile *st;
+ gpointer bk_run;
+ guint i, j;
+ gint id;
+ gboolean res = TRUE;
- curst = cl_run->st_runtime;
+ for (i = 0; i < st_ctx->classifiers->len; i ++) {
+ cl = g_ptr_array_index (st_ctx->classifiers, i);
- /* Needed to finalize pre-process stage */
- while (curst) {
- st_run = curst->data;
- cl_run->backend->finalize_process (task,
- st_run->backend_runtime,
- cl_run->backend->ctx);
- curst = g_list_next (curst);
+ /* Skip other classifiers if they are not needed */
+ if (classifier != NULL && (cl->cfg->name == NULL ||
+ g_ascii_strcasecmp (classifier, cl->cfg->name) != 0)) {
+ continue;
}
- if (cl_run->skipped) {
- msg_info_task (
- "<%s> contains less tokens than required for %s classifier: "
- "%ud < %ud",
- task->message_id,
- cl_run->clcf->name,
- g_tree_nnodes (cl_run->tok->tokens),
- cl_run->clcf->min_tokens);
- }
+ for (j = 0; j < cl->statfiles_ids->len; j ++) {
+ id = g_array_index (cl->statfiles_ids, gint, j);
+ st = g_ptr_array_index (st_ctx->statfiles, id);
+ bk_run = g_ptr_array_index (task->stat_runtimes, id);
- if (cl_run->cl && !cl_run->skipped) {
- cl_ctx = cl_run->cl->init_func (task->task_pool, cl_run->clcf);
-
- if (cl_ctx != NULL) {
- if (cl_run->cl->learn_spam_func (cl_ctx, cl_run->tok->tokens,
- cl_run, task, spam, err)) {
- msg_debug_task ("learned %s classifier %s", spam ? "spam" : "ham",
- cl_run->clcf->name);
- ret = RSPAMD_STAT_PROCESS_OK;
- learned = TRUE;
-
- cbdata.classifier_runtimes = cur;
- cbdata.task = task;
- cbdata.tok = cl_run->tok;
- cbdata.unlearn = unlearn;
- cbdata.spam = spam;
- g_tree_foreach (cl_run->tok->tokens, rspamd_stat_learn_token,
- &cbdata);
-
- curst = g_list_first (cl_run->st_runtime);
-
- while (curst) {
- st_run = (struct rspamd_statfile_runtime *)curst->data;
-
- if (unlearn && spam != st_run->st->is_spam) {
- nrev = cl_run->backend->dec_learns (task,
- st_run->backend_runtime,
- cl_run->backend->ctx);
- msg_debug_task ("unlearned %s, new revision: %ul",
- st_run->st->symbol, nrev);
- }
- else {
- nrev = cl_run->backend->inc_learns (task,
- st_run->backend_runtime,
- cl_run->backend->ctx);
- msg_debug_task ("learned %s, new revision: %ul",
- st_run->st->symbol, nrev);
- }
-
- cl_run->backend->finalize_learn (task,
- st_run->backend_runtime,
- cl_run->backend->ctx);
-
- curst = g_list_next (curst);
- }
+ g_assert (st != NULL);
+
+ if (bk_run == NULL) {
+ /* XXX: must be error */
+ continue;
+ }
+
+ if (!task->flags & RSPAMD_TASK_FLAG_UNLEARN) {
+ if (spam != st->stcf->is_spam) {
+ /* If we are not unlearning, then do not touch another class */
+ continue;
}
- else {
- return RSPAMD_STAT_PROCESS_ERROR;
+ }
+
+ if (!st->backend->learn_tokens (task, task->tokens, id, bk_run)) {
+ if (err && *err == NULL) {
+ g_set_error (err, rspamd_stat_quark (), 500, "Cannot push "
+ "learned results to the backend");
}
+ res = FALSE;
}
}
-
- cur = g_list_next (cur);
- }
-
- if (!learned) {
- g_set_error (err, rspamd_stat_quark (), 500, "message cannot be learned as "
- "it has too few tokens for any classifier defined");
- }
- else {
- g_atomic_int_inc (&task->worker->srv->stat->messages_learned);
}
- return ret;
+ return res;
}
-rspamd_stat_result_t rspamd_stat_statistics (struct rspamd_task *task,
- struct rspamd_config *cfg,
- guint64 *total_learns,
- ucl_object_t **target)
+static gboolean
+rspamd_stat_backends_post_learn (struct rspamd_stat_ctx *st_ctx,
+ struct rspamd_task *task,
+ const gchar *classifier,
+ gboolean spam)
{
- struct rspamd_classifier_config *clcf;
- struct rspamd_statfile_config *stcf;
- struct rspamd_stat_backend *bk;
- gpointer backend_runtime;
- GList *cur, *st_list = NULL, *curst;
- ucl_object_t *res = NULL, *elt;
- guint64 learns = 0;
-
- if (cfg != NULL && cfg->classifiers != NULL) {
- res = ucl_object_typed_new (UCL_ARRAY);
+ struct rspamd_classifier *cl;
+ struct rspamd_statfile *st;
+ gpointer bk_run;
+ guint i, j;
+ gint id;
+ gboolean res = TRUE;
- cur = g_list_first (cfg->classifiers);
+ for (i = 0; i < st_ctx->classifiers->len; i ++) {
+ cl = g_ptr_array_index (st_ctx->classifiers, i);
- while (cur) {
- clcf = (struct rspamd_classifier_config *)cur->data;
+ /* Skip other classifiers if they are not needed */
+ if (classifier != NULL && (cl->cfg->name == NULL ||
+ g_ascii_strcasecmp (classifier, cl->cfg->name) != 0)) {
+ continue;
+ }
- st_list = clcf->statfiles;
- curst = st_list;
+ for (j = 0; j < cl->statfiles_ids->len; j ++) {
+ id = g_array_index (cl->statfiles_ids, gint, j);
+ st = g_ptr_array_index (st_ctx->statfiles, id);
+ bk_run = g_ptr_array_index (task->stat_runtimes, id);
- while (curst != NULL) {
- stcf = (struct rspamd_statfile_config *)curst->data;
+ g_assert (st != NULL);
- bk = rspamd_stat_get_backend (clcf->backend);
+ if (bk_run == NULL) {
+ /* XXX: must be error */
+ continue;
+ }
- if (bk == NULL) {
- msg_warn ("backend of type %s is not defined", clcf->backend);
- curst = g_list_next (curst);
+ if (!task->flags & RSPAMD_TASK_FLAG_UNLEARN) {
+ if (spam != st->stcf->is_spam) {
+ /* If we are not unlearning, then do not touch another class */
continue;
}
- backend_runtime = bk->runtime (task, stcf, FALSE, bk->ctx);
-
- learns += bk->total_learns (task, backend_runtime, bk->ctx);
- elt = bk->get_stat (backend_runtime, bk->ctx);
-
- if (elt != NULL) {
- ucl_array_append (res, elt);
+ st->backend->inc_learns (task, bk_run, st_ctx);
+ }
+ else {
+ if (spam == st->stcf->is_spam) {
+ st->backend->inc_learns (task, bk_run, st_ctx);
+ }
+ else {
+ st->backend->dec_learns (task, bk_run, st_ctx);
}
-
- curst = g_list_next (curst);
}
- /* Next classifier */
- cur = g_list_next (cur);
+ st->backend->finalize_learn (task, bk_run, st_ctx);
}
+ }
+
+ return res;
+}
+
+rspamd_stat_result_t
+rspamd_stat_learn (struct rspamd_task *task,
+ gboolean spam, lua_State *L, const gchar *classifier, guint stage,
+ GError **err)
+{
+ struct rspamd_stat_ctx *st_ctx;
+
+ /*
+ * We assume now that a task has been already classified before
+ * coming to learn
+ */
+ g_assert (RSPAMD_TASK_IS_CLASSIFIED (task));
+
+ rspamd_stat_result_t ret = RSPAMD_STAT_PROCESS_OK;
+
+ st_ctx = rspamd_stat_get_ctx ();
+ g_assert (st_ctx != NULL);
- if (total_learns != NULL) {
- *total_learns = learns;
+ if (stage == RSPAMD_TASK_STAGE_LEARN_PRE) {
+ /* Process classifiers */
+ if (!rspamd_stat_cache_check (st_ctx, task, classifier, spam, err)) {
+ return RSPAMD_STAT_PROCESS_ERROR;
}
}
+ else if (stage == RSPAMD_TASK_STAGE_LEARN) {
+ /* Process classifiers */
+ if (!rspamd_stat_classifiers_learn (st_ctx, task, classifier,
+ spam, err)) {
+ return RSPAMD_STAT_PROCESS_ERROR;
+ }
- if (target) {
- *target = res;
+ /* Process backends */
+ if (!rspamd_stat_backends_learn (st_ctx, task, classifier, spam, err)) {
+ return RSPAMD_STAT_PROCESS_ERROR;
+ }
+ }
+ else if (stage == RSPAMD_TASK_STAGE_LEARN_POST) {
+ if (!rspamd_stat_backends_post_learn (st_ctx, task, classifier, spam)) {
+ return RSPAMD_STAT_PROCESS_ERROR;
+ }
}
- return RSPAMD_STAT_PROCESS_OK;
-}
-#else
-/* TODO: finish learning */
-rspamd_stat_result_t rspamd_stat_learn (struct rspamd_task *task,
- gboolean spam, lua_State *L, const gchar *classifier,
- GError **err)
-{
- return RSPAMD_STAT_PROCESS_ERROR;
+ return ret;
}
/**
{
return RSPAMD_STAT_PROCESS_ERROR;
}
-#endif