aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/libstat/backends/backends.h3
-rw-r--r--src/libstat/backends/mmaped_file.c40
-rw-r--r--src/libstat/classifiers/bayes.c7
-rw-r--r--src/libstat/stat_config.c1
-rw-r--r--src/libstat/stat_process.c36
5 files changed, 72 insertions, 15 deletions
diff --git a/src/libstat/backends/backends.h b/src/libstat/backends/backends.h
index c7c4210fb..f8a2af72c 100644
--- a/src/libstat/backends/backends.h
+++ b/src/libstat/backends/backends.h
@@ -49,6 +49,7 @@ struct rspamd_stat_backend {
struct rspamd_token_result *res, gpointer ctx);
gulong (*total_learns)(struct rspamd_statfile_runtime *runtime, gpointer ctx);
gulong (*inc_learns)(struct rspamd_statfile_runtime *runtime, gpointer ctx);
+ gulong (*dec_learns)(struct rspamd_statfile_runtime *runtime, gpointer ctx);
ucl_object_t* (*get_stat)(struct rspamd_statfile_runtime *runtime, gpointer ctx);
gpointer ctx;
};
@@ -66,6 +67,8 @@ gulong rspamd_mmaped_file_total_learns (struct rspamd_statfile_runtime *runtime,
gpointer ctx);
gulong rspamd_mmaped_file_inc_learns (struct rspamd_statfile_runtime *runtime,
gpointer ctx);
+gulong rspamd_mmaped_file_dec_learns (struct rspamd_statfile_runtime *runtime,
+ gpointer ctx);
ucl_object_t * rspamd_mmaped_file_get_stat (struct rspamd_statfile_runtime *runtime,
gpointer ctx);
diff --git a/src/libstat/backends/mmaped_file.c b/src/libstat/backends/mmaped_file.c
index 0fb386f61..02ea17c28 100644
--- a/src/libstat/backends/mmaped_file.c
+++ b/src/libstat/backends/mmaped_file.c
@@ -291,6 +291,23 @@ rspamd_mmaped_file_inc_revision (rspamd_mmaped_file_t *file)
}
gboolean
+rspamd_mmaped_file_dec_revision (rspamd_mmaped_file_t *file)
+{
+ struct stat_file_header *header;
+
+ if (file == NULL || file->map == NULL) {
+ return FALSE;
+ }
+
+ header = (struct stat_file_header *)file->map;
+
+ header->revision--;
+
+ return TRUE;
+}
+
+
+gboolean
rspamd_mmaped_file_get_revision (rspamd_mmaped_file_t *file, guint64 *rev, time_t *time)
{
struct stat_file_header *header;
@@ -939,11 +956,7 @@ rspamd_mmaped_file_learn_token (rspamd_token_t *tok,
memcpy (&h2, tok->data + sizeof (h1), sizeof (h2));
rspamd_mmaped_file_set_block (ctx, mf, h1, h2, res->value);
- if (res->value > 0.0) {
- return TRUE;
- }
-
- return FALSE;
+ return TRUE;
}
gulong
@@ -977,6 +990,23 @@ rspamd_mmaped_file_inc_learns (struct rspamd_statfile_runtime *runtime,
return rev;
}
+gulong
+rspamd_mmaped_file_dec_learns (struct rspamd_statfile_runtime *runtime,
+ gpointer ctx)
+{
+ rspamd_mmaped_file_t *mf = (rspamd_mmaped_file_t *)runtime;
+ guint64 rev = 0;
+ time_t t;
+
+ if (mf != NULL) {
+ rspamd_mmaped_file_dec_revision (mf);
+ rspamd_mmaped_file_get_revision (mf, &rev, &t);
+ }
+
+ return rev;
+}
+
+
ucl_object_t *
rspamd_mmaped_file_get_stat (struct rspamd_statfile_runtime *runtime,
gpointer ctx)
diff --git a/src/libstat/classifiers/bayes.c b/src/libstat/classifiers/bayes.c
index be6c6f545..7932ceb9e 100644
--- a/src/libstat/classifiers/bayes.c
+++ b/src/libstat/classifiers/bayes.c
@@ -221,6 +221,10 @@ bayes_learn_spam_callback (gpointer key, gpointer value, gpointer data)
if (res->st_runtime->st->is_spam) {
res->value ++;
}
+ else if (res->value > 0) {
+ /* Unlearning */
+ res->value --;
+ }
}
return FALSE;
@@ -241,6 +245,9 @@ bayes_learn_ham_callback (gpointer key, gpointer value, gpointer data)
if (!res->st_runtime->st->is_spam) {
res->value ++;
}
+ else if (res->value > 0) {
+ res->value --;
+ }
}
return FALSE;
diff --git a/src/libstat/stat_config.c b/src/libstat/stat_config.c
index b8ad6ec30..17b5c54f5 100644
--- a/src/libstat/stat_config.c
+++ b/src/libstat/stat_config.c
@@ -53,6 +53,7 @@ static struct rspamd_stat_backend stat_backends[] = {
.learn_token = rspamd_mmaped_file_learn_token,
.total_learns = rspamd_mmaped_file_total_learns,
.inc_learns = rspamd_mmaped_file_inc_learns,
+ .dec_learns = rspamd_mmaped_file_dec_learns,
.get_stat = rspamd_mmaped_file_get_stat
}
};
diff --git a/src/libstat/stat_process.c b/src/libstat/stat_process.c
index 8b2a19429..1ce439c51 100644
--- a/src/libstat/stat_process.c
+++ b/src/libstat/stat_process.c
@@ -29,12 +29,17 @@
#include "lua/lua_common.h"
#include <utlist.h>
+#define RSPAMD_CLASSIFY_OP 0
+#define RSPAMD_LEARN_OP 1
+#define RSPAMD_UNLEARN_OP 2
+
struct preprocess_cb_data {
struct rspamd_task *task;
GList *classifier_runtimes;
struct rspamd_tokenizer_runtime *tok;
guint results_count;
gboolean unlearn;
+ gboolean spam;
};
static struct rspamd_tokenizer_runtime *
@@ -135,7 +140,7 @@ preprocess_init_stat_token (gpointer k, gpointer v, gpointer d)
static GList*
rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx,
struct rspamd_task *task, struct rspamd_tokenizer_runtime *tklist,
- lua_State *L, gboolean learn, gboolean spam, GError **err)
+ lua_State *L, gint op, gboolean spam, GError **err)
{
struct rspamd_classifier_config *clcf;
struct rspamd_statfile_config *stcf;
@@ -186,7 +191,7 @@ rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx,
stcf = (struct rspamd_statfile_config *)curst->data;
/* On learning skip statfiles that do not belong to class */
- if (learn && (spam != stcf->is_spam)) {
+ if (op == RSPAMD_LEARN_OP && (spam != stcf->is_spam)) {
curst = g_list_next (curst);
continue;
}
@@ -199,7 +204,8 @@ rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx,
continue;
}
- backend_runtime = bk->runtime (stcf, learn, bk->ctx);
+ backend_runtime = bk->runtime (stcf, op != RSPAMD_CLASSIFY_OP,
+ bk->ctx);
st_runtime = rspamd_mempool_alloc0 (task->task_pool,
sizeof (*st_runtime));
@@ -354,7 +360,7 @@ rspamd_stat_classify (struct rspamd_task *task, lua_State *L, GError **err)
/* Initialize classifiers and statfiles runtime */
if ((cl_runtimes = rspamd_stat_preprocess (st_ctx, task, tklist, L,
- FALSE, FALSE, err)) == NULL) {
+ RSPAMD_CLASSIFY_OP, FALSE, err)) == NULL) {
return RSPAMD_STAT_PROCESS_ERROR;
}
@@ -407,11 +413,12 @@ rspamd_stat_learn_token (gpointer k, gpointer v, gpointer d)
continue;
}
- res = &g_array_index (t->results, struct rspamd_token_result, i);
- curst = res->cl_runtime->st_runtime;
+
+ curst = cl_runtime->st_runtime;
while (curst) {
+ res = &g_array_index (t->results, struct rspamd_token_result, i);
st_runtime = (struct rspamd_statfile_runtime *)curst->data;
if (st_runtime->backend->learn_token (t, res,
@@ -432,6 +439,7 @@ rspamd_stat_learn_token (gpointer k, gpointer v, gpointer d)
i ++;
curst = g_list_next (curst);
}
+
cur = g_list_next (cur);
}
@@ -507,7 +515,7 @@ rspamd_stat_learn (struct rspamd_task *task, gboolean spam, lua_State *L,
/* Initialize classifiers and statfiles runtime */
if ((cl_runtimes = rspamd_stat_preprocess (st_ctx, task, tklist, L,
- TRUE, spam, err)) == NULL) {
+ unlearn ? RSPAMD_UNLEARN_OP : RSPAMD_LEARN_OP, spam, err)) == NULL) {
return RSPAMD_STAT_PROCESS_ERROR;
}
@@ -530,6 +538,7 @@ rspamd_stat_learn (struct rspamd_task *task, gboolean spam, lua_State *L,
cbdata.task = task;
cbdata.tok = cl_run->tok;
cbdata.unlearn = unlearn;
+ cbdata.spam = spam;
g_tree_foreach (cl_run->tok->tokens, rspamd_stat_learn_token,
&cbdata);
@@ -538,11 +547,18 @@ rspamd_stat_learn (struct rspamd_task *task, gboolean spam, lua_State *L,
while (curst) {
st_run = (struct rspamd_statfile_runtime *)curst->data;
- nrev = st_run->backend->inc_learns (st_run->backend_runtime,
+ if (unlearn && spam != st_run->st->is_spam) {
+ nrev = st_run->backend->dec_learns (st_run->backend_runtime,
+ st_run->backend->ctx);
+ msg_debug ("unlearned %s, new revision: %ul",
+ st_run->st->symbol, nrev);
+ }
+ else {
+ nrev = st_run->backend->inc_learns (st_run->backend_runtime,
st_run->backend->ctx);
-
- msg_debug ("learned %s, new revision: %ul",
+ msg_debug ("learned %s, new revision: %ul",
st_run->st->symbol, nrev);
+ }
curst = g_list_next (curst);
}