summaryrefslogtreecommitdiffstats
path: root/src/libstat/stat_process.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/libstat/stat_process.c')
-rw-r--r--src/libstat/stat_process.c36
1 files changed, 26 insertions, 10 deletions
diff --git a/src/libstat/stat_process.c b/src/libstat/stat_process.c
index 8b2a19429..1ce439c51 100644
--- a/src/libstat/stat_process.c
+++ b/src/libstat/stat_process.c
@@ -29,12 +29,17 @@
#include "lua/lua_common.h"
#include <utlist.h>
+#define RSPAMD_CLASSIFY_OP 0
+#define RSPAMD_LEARN_OP 1
+#define RSPAMD_UNLEARN_OP 2
+
struct preprocess_cb_data {
struct rspamd_task *task;
GList *classifier_runtimes;
struct rspamd_tokenizer_runtime *tok;
guint results_count;
gboolean unlearn;
+ gboolean spam;
};
static struct rspamd_tokenizer_runtime *
@@ -135,7 +140,7 @@ preprocess_init_stat_token (gpointer k, gpointer v, gpointer d)
static GList*
rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx,
struct rspamd_task *task, struct rspamd_tokenizer_runtime *tklist,
- lua_State *L, gboolean learn, gboolean spam, GError **err)
+ lua_State *L, gint op, gboolean spam, GError **err)
{
struct rspamd_classifier_config *clcf;
struct rspamd_statfile_config *stcf;
@@ -186,7 +191,7 @@ rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx,
stcf = (struct rspamd_statfile_config *)curst->data;
/* On learning skip statfiles that do not belong to class */
- if (learn && (spam != stcf->is_spam)) {
+ if (op == RSPAMD_LEARN_OP && (spam != stcf->is_spam)) {
curst = g_list_next (curst);
continue;
}
@@ -199,7 +204,8 @@ rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx,
continue;
}
- backend_runtime = bk->runtime (stcf, learn, bk->ctx);
+ backend_runtime = bk->runtime (stcf, op != RSPAMD_CLASSIFY_OP,
+ bk->ctx);
st_runtime = rspamd_mempool_alloc0 (task->task_pool,
sizeof (*st_runtime));
@@ -354,7 +360,7 @@ rspamd_stat_classify (struct rspamd_task *task, lua_State *L, GError **err)
/* Initialize classifiers and statfiles runtime */
if ((cl_runtimes = rspamd_stat_preprocess (st_ctx, task, tklist, L,
- FALSE, FALSE, err)) == NULL) {
+ RSPAMD_CLASSIFY_OP, FALSE, err)) == NULL) {
return RSPAMD_STAT_PROCESS_ERROR;
}
@@ -407,11 +413,12 @@ rspamd_stat_learn_token (gpointer k, gpointer v, gpointer d)
continue;
}
- res = &g_array_index (t->results, struct rspamd_token_result, i);
- curst = res->cl_runtime->st_runtime;
+
+ curst = cl_runtime->st_runtime;
while (curst) {
+ res = &g_array_index (t->results, struct rspamd_token_result, i);
st_runtime = (struct rspamd_statfile_runtime *)curst->data;
if (st_runtime->backend->learn_token (t, res,
@@ -432,6 +439,7 @@ rspamd_stat_learn_token (gpointer k, gpointer v, gpointer d)
i ++;
curst = g_list_next (curst);
}
+
cur = g_list_next (cur);
}
@@ -507,7 +515,7 @@ rspamd_stat_learn (struct rspamd_task *task, gboolean spam, lua_State *L,
/* Initialize classifiers and statfiles runtime */
if ((cl_runtimes = rspamd_stat_preprocess (st_ctx, task, tklist, L,
- TRUE, spam, err)) == NULL) {
+ unlearn ? RSPAMD_UNLEARN_OP : RSPAMD_LEARN_OP, spam, err)) == NULL) {
return RSPAMD_STAT_PROCESS_ERROR;
}
@@ -530,6 +538,7 @@ rspamd_stat_learn (struct rspamd_task *task, gboolean spam, lua_State *L,
cbdata.task = task;
cbdata.tok = cl_run->tok;
cbdata.unlearn = unlearn;
+ cbdata.spam = spam;
g_tree_foreach (cl_run->tok->tokens, rspamd_stat_learn_token,
&cbdata);
@@ -538,11 +547,18 @@ rspamd_stat_learn (struct rspamd_task *task, gboolean spam, lua_State *L,
while (curst) {
st_run = (struct rspamd_statfile_runtime *)curst->data;
- nrev = st_run->backend->inc_learns (st_run->backend_runtime,
+ if (unlearn && spam != st_run->st->is_spam) {
+ nrev = st_run->backend->dec_learns (st_run->backend_runtime,
+ st_run->backend->ctx);
+ msg_debug ("unlearned %s, new revision: %ul",
+ st_run->st->symbol, nrev);
+ }
+ else {
+ nrev = st_run->backend->inc_learns (st_run->backend_runtime,
st_run->backend->ctx);
-
- msg_debug ("learned %s, new revision: %ul",
+ msg_debug ("learned %s, new revision: %ul",
st_run->st->symbol, nrev);
+ }
curst = g_list_next (curst);
}