aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/libserver/task.h1
-rw-r--r--src/libstat/classifiers/bayes.c82
-rw-r--r--src/libstat/stat_internal.h6
-rw-r--r--src/libstat/stat_process.c17
4 files changed, 67 insertions, 39 deletions
diff --git a/src/libserver/task.h b/src/libserver/task.h
index ded241b31..135e8bf92 100644
--- a/src/libserver/task.h
+++ b/src/libserver/task.h
@@ -143,6 +143,7 @@ struct rspamd_task {
struct event_base *ev_base; /**< Event base */
GThreadPool *classify_pool; /**< A pool of classify threads */
+ gpointer classify_data; /**< Opaque classifiers data */
struct {
enum rspamd_metric_action action; /**< Action of pre filters */
diff --git a/src/libstat/classifiers/bayes.c b/src/libstat/classifiers/bayes.c
index 7932ceb9e..823f5eff9 100644
--- a/src/libstat/classifiers/bayes.c
+++ b/src/libstat/classifiers/bayes.c
@@ -151,55 +151,59 @@ bayes_classify (struct classifier_ctx * ctx,
g_assert (rt != NULL);
g_assert (rt->end_pos > rt->start_pos);
- g_tree_foreach (input, bayes_classify_callback, rt);
-
- if (rt->spam_prob == 0) {
- final_prob = 0;
+ if (rt->stage == RSPAMD_STAT_STAGE_PRE) {
+ g_tree_foreach (input, bayes_classify_callback, rt);
}
else {
- h = 1 - inv_chi_square (-2. * rt->spam_prob,
- 2 * rt->processed_tokens);
- s = 1 - inv_chi_square (-2. * rt->ham_prob,
- 2 * rt->processed_tokens);
- final_prob = (s + 1 - h) / 2.;
- msg_debug ("<%s> got ham prob %.2f -> %.2f and spam prob %.2f -> %.2f",
- task->message_id, rt->ham_prob, h, rt->spam_prob, s);
- }
- if (rt->processed_tokens > 0 && fabs (final_prob - 0.5) > 0.05) {
+ if (rt->spam_prob == 0) {
+ final_prob = 0;
+ }
+ else {
+ h = 1 - inv_chi_square (-2. * rt->spam_prob,
+ 2 * rt->processed_tokens);
+ s = 1 - inv_chi_square (-2. * rt->ham_prob,
+ 2 * rt->processed_tokens);
+ final_prob = (s + 1 - h) / 2.;
+ msg_debug ("<%s> got ham prob %.2f -> %.2f and spam prob %.2f -> %.2f",
+ task->message_id, rt->ham_prob, h, rt->spam_prob, s);
+ }
+
+ if (rt->processed_tokens > 0 && fabs (final_prob - 0.5) > 0.05) {
- sumbuf = rspamd_mempool_alloc (task->task_pool, 32);
- cur = g_list_first (rt->st_runtime);
+ sumbuf = rspamd_mempool_alloc (task->task_pool, 32);
+ cur = g_list_first (rt->st_runtime);
- while (cur) {
- st = (struct rspamd_statfile_runtime *)cur->data;
+ while (cur) {
+ st = (struct rspamd_statfile_runtime *)cur->data;
- if ((final_prob < 0.5 && !st->st->is_spam) ||
- (final_prob > 0.5 && st->st->is_spam)) {
- if (st->total_hits > maxhits) {
- maxhits = st->total_hits;
- selected_st = st;
+ if ((final_prob < 0.5 && !st->st->is_spam) ||
+ (final_prob > 0.5 && st->st->is_spam)) {
+ if (st->total_hits > maxhits) {
+ maxhits = st->total_hits;
+ selected_st = st;
+ }
}
- }
- cur = g_list_next (cur);
- }
+ cur = g_list_next (cur);
+ }
- if (selected_st == NULL) {
- msg_err (
- "unexpected classifier error: cannot select desired statfile");
- }
- else {
- /* Calculate ham probability correctly */
- if (final_prob < 0.5) {
- final_prob = 1. - final_prob;
+ if (selected_st == NULL) {
+ msg_err (
+ "unexpected classifier error: cannot select desired statfile");
+ }
+ else {
+ /* Calculate ham probability correctly */
+ if (final_prob < 0.5) {
+ final_prob = 1. - final_prob;
+ }
+ rspamd_snprintf (sumbuf, 32, "%.2f%%", final_prob * 100.);
+ cur = g_list_prepend (NULL, sumbuf);
+ rspamd_task_insert_result (task,
+ selected_st->st->symbol,
+ final_prob,
+ cur);
}
- rspamd_snprintf (sumbuf, 32, "%.2f%%", final_prob * 100.);
- cur = g_list_prepend (NULL, sumbuf);
- rspamd_task_insert_result (task,
- selected_st->st->symbol,
- final_prob,
- cur);
}
}
diff --git a/src/libstat/stat_internal.h b/src/libstat/stat_internal.h
index 29bd937fb..051404814 100644
--- a/src/libstat/stat_internal.h
+++ b/src/libstat/stat_internal.h
@@ -30,6 +30,11 @@
#include "backends/backends.h"
#include "learn_cache/learn_cache.h"
+enum stat_process_stage {
+ RSPAMD_STAT_STAGE_PRE = 0,
+ RSPAMD_STAT_STAGE_POST
+};
+
struct rspamd_tokenizer_runtime {
GTree *tokens;
const gchar *name;
@@ -51,6 +56,7 @@ struct rspamd_classifier_runtime {
struct rspamd_tokenizer_runtime *tok;
double ham_prob;
double spam_prob;
+ enum stat_process_stage stage;
guint64 total_spam;
guint64 total_ham;
guint64 processed_tokens;
diff --git a/src/libstat/stat_process.c b/src/libstat/stat_process.c
index 311eaa0ea..4cb0f42bb 100644
--- a/src/libstat/stat_process.c
+++ b/src/libstat/stat_process.c
@@ -400,11 +400,28 @@ rspamd_stat_classify (struct rspamd_task *task, lua_State *L, GError **err)
while (cur) {
cl_run = (struct rspamd_classifier_runtime *)cur->data;
+ cl_run->stage = RSPAMD_STAT_STAGE_PRE;
if (cl_run->cl) {
cl_ctx = cl_run->cl->init_func (task->task_pool, cl_run->clcf);
if (cl_ctx != NULL) {
+ cl_run->cl->classify_func (cl_ctx, cl_run->tok->tokens,
+ cl_run, task);
+ }
+ }
+
+ cur = g_list_next (cur);
+ }
+
+ /* XXX: backend runtime post-processing */
+ /* Post-processing */
+ while (cur) {
+ cl_run = (struct rspamd_classifier_runtime *)cur->data;
+ cl_run->stage = RSPAMD_STAT_STAGE_POST;
+
+ if (cl_run->cl) {
+ if (cl_ctx != NULL) {
if (cl_run->cl->classify_func (cl_ctx, cl_run->tok->tokens,
cl_run, task)) {
ret = RSPAMD_STAT_PROCESS_OK;