aboutsummaryrefslogtreecommitdiffstats
path: root/src/libstat/stat_process.c
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2015-01-26 22:20:29 +0000
committerVsevolod Stakhov <vsevolod@highsecure.ru>2015-01-26 22:20:29 +0000
commitf117de8a3c06894d5c10d39abf72611c31bf31a3 (patch)
tree2ca37497102f6f23fafaa88b01e55a927c08d38c /src/libstat/stat_process.c
parent5525b96d6c9d3f4c36167f4e5abf48c485ac5a07 (diff)
downloadrspamd-f117de8a3c06894d5c10d39abf72611c31bf31a3.tar.gz
rspamd-f117de8a3c06894d5c10d39abf72611c31bf31a3.zip
Fixing learning.
Diffstat (limited to 'src/libstat/stat_process.c')
-rw-r--r--src/libstat/stat_process.c87
1 files changed, 79 insertions, 8 deletions
diff --git a/src/libstat/stat_process.c b/src/libstat/stat_process.c
index e00c0aaa8..17c8059a9 100644
--- a/src/libstat/stat_process.c
+++ b/src/libstat/stat_process.c
@@ -135,7 +135,7 @@ preprocess_init_stat_token (gpointer k, gpointer v, gpointer d)
static GList*
rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx,
struct rspamd_task *task, struct rspamd_tokenizer_runtime *tklist,
- lua_State *L, GError **err)
+ lua_State *L, gboolean learn, GError **err)
{
struct rspamd_classifier_config *clcf;
struct rspamd_statfile_config *stcf;
@@ -193,7 +193,7 @@ rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx,
continue;
}
- backend_runtime = bk->runtime (stcf, bk->ctx);
+ backend_runtime = bk->runtime (stcf, learn, bk->ctx);
st_runtime = rspamd_mempool_alloc0 (task->task_pool,
sizeof (*st_runtime));
@@ -243,8 +243,6 @@ rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx,
cbdata.results_count = result_size;
cbdata.classifier_runtimes = cl_runtimes;
cbdata.task = task;
-
- /* Allocate token results */
cbdata.tok = cl_runtime->tok;
g_tree_foreach (cl_runtime->tok->tokens, preprocess_init_stat_token,
&cbdata);
@@ -346,7 +344,7 @@ rspamd_stat_classify (struct rspamd_task *task, lua_State *L, GError **err)
}
/* Initialize classifiers and statfiles runtime */
- if ((cl_runtimes = rspamd_stat_preprocess (st_ctx, task, tklist, L, err))
+ if ((cl_runtimes = rspamd_stat_preprocess (st_ctx, task, tklist, L, FALSE, err))
== NULL) {
return FALSE;
}
@@ -371,6 +369,65 @@ rspamd_stat_classify (struct rspamd_task *task, lua_State *L, GError **err)
return ret;
}
+static gboolean
+rspamd_stat_learn_token (gpointer k, gpointer v, gpointer d)
+{
+ rspamd_token_t *t = (rspamd_token_t *)v;
+ struct preprocess_cb_data *cbdata = (struct preprocess_cb_data *)d;
+ struct rspamd_statfile_runtime *st_runtime;
+ struct rspamd_classifier_runtime *cl_runtime;
+ struct rspamd_token_result *res;
+ GList *cur, *curst;
+ gint i = 0;
+
+ cur = g_list_first (cbdata->classifier_runtimes);
+
+ while (cur) {
+ cl_runtime = (struct rspamd_classifier_runtime *)cur->data;
+
+ if (cl_runtime->clcf->min_tokens > 0 &&
+ (guint32)g_tree_nnodes (cbdata->tok->tokens) < cl_runtime->clcf->min_tokens) {
+ /* Skip this classifier */
+ msg_debug ("<%s> contains less tokens than required for %s classifier: "
+ "%ud < %ud", cbdata->task->message_id, cl_runtime->clcf->name,
+ g_tree_nnodes (cbdata->tok->tokens),
+ cl_runtime->clcf->min_tokens);
+ cur = g_list_next (cur);
+ continue;
+ }
+
+ res = &g_array_index (t->results, struct rspamd_token_result, i);
+
+ curst = res->cl_runtime->st_runtime;
+
+ while (curst) {
+ st_runtime = (struct rspamd_statfile_runtime *)curst->data;
+
+ if (st_runtime->backend->learn_token (t, res,
+ st_runtime->backend->ctx)) {
+ cl_runtime->processed_tokens ++;
+
+ if (cl_runtime->clcf->max_tokens > 0 &&
+ cl_runtime->processed_tokens > cl_runtime->clcf->max_tokens) {
+ msg_debug ("<%s> contains more tokens than allowed for %s classifier: "
+ "%ud > %ud", cbdata->task, cl_runtime->clcf->name,
+ cl_runtime->processed_tokens,
+ cl_runtime->clcf->max_tokens);
+
+ return TRUE;
+ }
+ }
+
+ i ++;
+ curst = g_list_next (curst);
+ }
+ cur = g_list_next (cur);
+ }
+
+
+ return FALSE;
+}
+
gboolean
rspamd_stat_learn (struct rspamd_task *task, gboolean spam, lua_State *L,
GError **err)
@@ -381,6 +438,7 @@ rspamd_stat_learn (struct rspamd_task *task, gboolean spam, lua_State *L,
struct rspamd_tokenizer_runtime *tklist = NULL, *tok;
struct rspamd_classifier_runtime *cl_run;
struct classifier_ctx *cl_ctx;
+ struct preprocess_cb_data cbdata;
GList *cl_runtimes;
GList *cur;
gboolean ret = FALSE;
@@ -416,7 +474,7 @@ rspamd_stat_learn (struct rspamd_task *task, gboolean spam, lua_State *L,
}
/* Initialize classifiers and statfiles runtime */
- if ((cl_runtimes = rspamd_stat_preprocess (st_ctx, task, tklist, L, err))
+ if ((cl_runtimes = rspamd_stat_preprocess (st_ctx, task, tklist, L, TRUE, err))
== NULL) {
return FALSE;
}
@@ -430,8 +488,21 @@ rspamd_stat_learn (struct rspamd_task *task, gboolean spam, lua_State *L,
cl_ctx = cl_run->cl->init_func (task->task_pool, cl_run->clcf);
if (cl_ctx != NULL) {
- ret |= cl_run->cl->learn_spam_func (cl_ctx, cl_run->tok->tokens,
- cl_run, task, spam, err);
+ if (cl_run->cl->learn_spam_func (cl_ctx, cl_run->tok->tokens,
+ cl_run, task, spam, err)) {
+ ret = TRUE;
+
+ cbdata.classifier_runtimes = cur;
+ cbdata.task = task;
+ cbdata.tok = cl_run->tok;
+ g_tree_foreach (cl_run->tok->tokens, rspamd_stat_learn_token,
+ &cbdata);
+
+ }
+ else {
+ return FALSE;
+ }
+
}
}