aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2016-01-05 18:02:47 +0000
committerVsevolod Stakhov <vsevolod@highsecure.ru>2016-01-05 18:02:47 +0000
commit8be715956820c4b14d3a3d7f41e50e4e69b7e679 (patch)
tree934d45c3e84b5e4c98982dee360267bd74e693a2
parenta53bd05ff9eac93f444b7bf8dd1607e16301ef5b (diff)
downloadrspamd-8be715956820c4b14d3a3d7f41e50e4e69b7e679.tar.gz
rspamd-8be715956820c4b14d3a3d7f41e50e4e69b7e679.zip
Fix bayes classifier for the new architecture
-rw-r--r--src/libstat/classifiers/bayes.c299
-rw-r--r--src/libstat/classifiers/classifiers.h33
-rw-r--r--src/libstat/stat_config.c1
-rw-r--r--src/libstat/stat_internal.h26
-rw-r--r--src/libstat/stat_process.c2
5 files changed, 157 insertions, 204 deletions
diff --git a/src/libstat/classifiers/bayes.c b/src/libstat/classifiers/bayes.c
index a271a424a..0915933f1 100644
--- a/src/libstat/classifiers/bayes.c
+++ b/src/libstat/classifiers/bayes.c
@@ -90,7 +90,10 @@ inv_chi_square (struct rspamd_task *task, gdouble value, gint freedom_deg)
}
struct bayes_task_closure {
- struct rspamd_classifier_runtime *rt;
+ double ham_prob;
+ double spam_prob;
+ guint64 processed_tokens;
+ guint64 total_hits;
struct rspamd_task *task;
};
@@ -104,44 +107,46 @@ static const double feature_weight[] = { 0, 1, 4, 27, 256, 3125, 46656, 823543 }
/*
* In this callback we calculate local probabilities for tokens
*/
-static gboolean
-bayes_classify_callback (gpointer key, gpointer value, gpointer data)
+static void
+bayes_classify_token (struct rspamd_classifier *ctx,
+ rspamd_token_t *tok, struct bayes_task_closure *cl)
{
- rspamd_token_t *node = value;
- struct bayes_task_closure *cl = data;
- struct rspamd_classifier_runtime *rt;
guint i;
- struct rspamd_token_result *res;
+ gint id;
guint64 spam_count = 0, ham_count = 0, total_count = 0;
+ struct rspamd_statfile *st;
struct rspamd_task *task;
double spam_prob, spam_freq, ham_freq, bayes_spam_prob, bayes_ham_prob,
- ham_prob, fw, w, norm_sum, norm_sub;
+ ham_prob, fw, w, norm_sum, norm_sub, val;
- rt = cl->rt;
task = cl->task;
- for (i = rt->start_pos; i < rt->end_pos; i++) {
- res = &g_array_index (node->results, struct rspamd_token_result, i);
+ for (i = 0; i < ctx->statfiles_ids->len; i++) {
+ id = g_array_index (ctx->statfiles_ids, gint, i);
+ st = g_ptr_array_index (ctx->ctx->statfiles, id);
+ g_assert (st != NULL);
+ val = tok->values[id];
- if (res->value > 0) {
- if (res->st_runtime->st->is_spam) {
- spam_count += res->value;
+ if (val > 0) {
+ if (st->stcf->is_spam) {
+ spam_count += val;
}
else {
- ham_count += res->value;
+ ham_count += val;
}
- total_count += res->value;
- res->st_runtime->total_hits += res->value;
+
+ total_count += val;
+ cl->total_hits += val;
}
}
/* Probability for this token */
if (total_count > 0) {
- spam_freq = ((double)spam_count / MAX (1., (double)rt->total_spam));
- ham_freq = ((double)ham_count / MAX (1., (double)rt->total_ham));
+ spam_freq = ((double)spam_count / MAX (1., (double) ctx->spam_learns));
+ ham_freq = ((double)ham_count / MAX (1., (double)ctx->ham_learns));
spam_prob = spam_freq / (spam_freq + ham_freq);
ham_prob = ham_freq / (spam_freq + ham_freq);
- fw = feature_weight[node->window_idx % G_N_ELEMENTS (feature_weight)];
+ fw = feature_weight[tok->window_idx % G_N_ELEMENTS (feature_weight)];
norm_sum = (spam_freq + ham_freq) * (spam_freq + ham_freq);
norm_sub = (spam_freq - ham_freq) * (spam_freq - ham_freq);
w = (norm_sub) / (norm_sum) *
@@ -151,9 +156,9 @@ bayes_classify_callback (gpointer key, gpointer value, gpointer data)
w = (norm_sub) / (norm_sum) *
(fw * total_count) / (4.0 * (1.0 + fw * total_count));
bayes_ham_prob = PROB_COMBINE (ham_prob, total_count, w, 0.5);
- rt->spam_prob += log (bayes_spam_prob);
- rt->ham_prob += log (bayes_ham_prob);
- res->cl_runtime->processed_tokens ++;
+ cl->spam_prob += log (bayes_spam_prob);
+ cl->ham_prob += log (bayes_ham_prob);
+ cl->processed_tokens ++;
msg_debug_bayes ("token: weight: %f, total_count: %L, "
"spam_count: %L, ham_count: %L,"
@@ -163,10 +168,8 @@ bayes_classify_callback (gpointer key, gpointer value, gpointer data)
fw, total_count, spam_count, ham_count,
spam_prob, ham_prob,
bayes_spam_prob, bayes_ham_prob,
- rt->spam_prob, rt->ham_prob);
+ cl->spam_prob, cl->ham_prob);
}
-
- return FALSE;
}
/*
@@ -198,176 +201,146 @@ bayes_init (rspamd_mempool_t *pool, struct rspamd_classifier *cl)
gboolean
bayes_classify (struct rspamd_classifier * ctx,
- GTree *input,
- struct rspamd_classifier_runtime *rt,
- struct rspamd_task *task)
+ GPtrArray *tokens,
+ struct rspamd_task *task)
{
double final_prob, h, s;
- guint maxhits = 0;
- struct rspamd_statfile_runtime *st, *selected_st = NULL;
- GList *cur;
char *sumbuf;
+ struct rspamd_statfile *st = NULL;
struct bayes_task_closure cl;
+ rspamd_token_t *tok;
+ guint i;
+ gint id;
+ GList *cur;
g_assert (ctx != NULL);
- g_assert (input != NULL);
- g_assert (rt != NULL);
- g_assert (rt->end_pos > rt->start_pos);
-
- if (rt->stage == RSPAMD_STAT_STAGE_PRE) {
- cl.rt = rt;
- cl.task = task;
- g_tree_foreach (input, bayes_classify_callback, &cl);
+ g_assert (tokens != NULL);
+
+ memset (&cl, 0, sizeof (cl));
+ cl.task = task;
+
+ for (i = 0; i < tokens->len; i ++) {
+ tok = g_ptr_array_index (tokens, i);
+
+ bayes_classify_token (ctx, tok, &cl);
+ }
+
+ h = 1 - inv_chi_square (task, cl.spam_prob, cl.processed_tokens);
+ s = 1 - inv_chi_square (task, cl.ham_prob, cl.processed_tokens);
+
+ if (isfinite (s) && isfinite (h)) {
+ final_prob = (s + 1.0 - h) / 2.;
+ msg_debug_bayes (
+ "<%s> got ham prob %.2f -> %.2f and spam prob %.2f -> %.2f,"
+ " %L tokens processed of %ud total tokens",
+ task->message_id,
+ cl.ham_prob,
+ h,
+ cl.spam_prob,
+ s,
+ cl.processed_tokens,
+ tokens->len);
}
else {
- h = 1 - inv_chi_square (task, rt->spam_prob, rt->processed_tokens);
- s = 1 - inv_chi_square (task, rt->ham_prob, rt->processed_tokens);
-
- if (isfinite (s) && isfinite (h)) {
- final_prob = (s + 1.0 - h) / 2.;
- msg_debug_bayes ("<%s> got ham prob %.2f -> %.2f and spam prob %.2f -> %.2f,"
- " %L tokens processed of %ud total tokens",
- task->message_id, rt->ham_prob, h, rt->spam_prob, s,
- rt->processed_tokens, g_tree_nnodes (input));
+ /*
+ * We have some overflow, hence we need to check which class
+ * is NaN
+ */
+ if (isfinite (h)) {
+ final_prob = 1.0;
+ msg_debug_bayes ("<%s> spam class is overflowed, as we have no"
+ " ham samples", task->message_id);
+ }
+ else if (isfinite (s)) {
+ final_prob = 0.0;
+ msg_debug_bayes ("<%s> ham class is overflowed, as we have no"
+ " spam samples", task->message_id);
}
else {
- /*
- * We have some overflow, hence we need to check which class
- * is NaN
- */
- if (isfinite (h)) {
- final_prob = 1.0;
- msg_debug_bayes ("<%s> spam class is overflowed, as we have no"
- " ham samples", task->message_id);
- }
- else if (isfinite (s)){
- final_prob = 0.0;
- msg_debug_bayes ("<%s> ham class is overflowed, as we have no"
- " spam samples", task->message_id);
- }
- else {
- final_prob = 0.5;
- msg_warn_bayes ("<%s> spam and ham classes are both overflowed",
- task->message_id);
- }
+ final_prob = 0.5;
+ msg_warn_bayes ("<%s> spam and ham classes are both overflowed",
+ task->message_id);
}
+ }
- if (rt->processed_tokens > 0 && fabs (final_prob - 0.5) > 0.05) {
-
- sumbuf = rspamd_mempool_alloc (task->task_pool, 32);
- cur = g_list_first (rt->st_runtime);
+ if (cl.processed_tokens > 0 && fabs (final_prob - 0.5) > 0.05) {
- while (cur) {
- st = (struct rspamd_statfile_runtime *)cur->data;
+ sumbuf = rspamd_mempool_alloc (task->task_pool, 32);
- if ((final_prob < 0.5 && !st->st->is_spam) ||
- (final_prob > 0.5 && st->st->is_spam)) {
- if (st->total_hits > maxhits) {
- maxhits = st->total_hits;
- selected_st = st;
- }
- }
+ /* Now we can have exactly one HAM and exactly one SPAM statfiles per classifier */
+ for (i = 0; i < ctx->statfiles_ids->len; i++) {
+ id = g_array_index (ctx->statfiles_ids, gint, i);
+ st = g_ptr_array_index (ctx->ctx->statfiles, id);
- cur = g_list_next (cur);
+ if (final_prob > 0.5 && st->stcf->is_spam) {
+ break;
}
-
- if (selected_st == NULL) {
- msg_err_bayes (
- "unexpected classifier error: cannot select desired statfile, "
- "prob: %.4f", final_prob);
+ else if (final_prob < 0.5 && !st->stcf->is_spam) {
+ break;
}
- else {
- /* Correctly scale HAM */
- if (final_prob < 0.5) {
- final_prob = 1.0 - final_prob;
- }
-
- rspamd_snprintf (sumbuf, 32, "%.2f%%", final_prob * 100.);
- final_prob = bayes_normalize_prob (final_prob);
+ }
- cur = g_list_prepend (NULL, sumbuf);
- rspamd_task_insert_result (task,
- selected_st->st->symbol,
- final_prob,
- cur);
- }
+ /* Correctly scale HAM */
+ if (final_prob < 0.5) {
+ final_prob = 1.0 - final_prob;
}
+
+ rspamd_snprintf (sumbuf, 32, "%.2f%%", final_prob * 100.);
+ final_prob = bayes_normalize_prob (final_prob);
+ g_assert (st != NULL);
+ cur = g_list_prepend (NULL, sumbuf);
+ rspamd_task_insert_result (task,
+ st->stcf->symbol,
+ final_prob,
+ cur);
}
return TRUE;
}
-static gboolean
-bayes_learn_spam_callback (gpointer key, gpointer value, gpointer data)
+gboolean
+bayes_learn_spam (struct rspamd_classifier * ctx,
+ GPtrArray *tokens,
+ struct rspamd_task *task,
+ gboolean is_spam,
+ GError **err)
{
- rspamd_token_t *node = value;
- struct rspamd_token_result *res;
- struct rspamd_classifier_runtime *rt = (struct rspamd_classifier_runtime *)data;
- guint i;
-
+ guint i, j;
+ gint id;
+ struct rspamd_statfile *st;
+ rspamd_token_t *tok;
- for (i = rt->start_pos; i < rt->end_pos; i++) {
- res = &g_array_index (node->results, struct rspamd_token_result, i);
-
- if (res->st_runtime) {
- if (res->st_runtime->st->is_spam) {
- res->value ++;
- }
- else if (res->value > 0) {
- /* Unlearning */
- res->value --;
- }
- }
- }
-
- return FALSE;
-}
-
-static gboolean
-bayes_learn_ham_callback (gpointer key, gpointer value, gpointer data)
-{
- rspamd_token_t *node = value;
- struct rspamd_token_result *res;
- struct rspamd_classifier_runtime *rt = (struct rspamd_classifier_runtime *)data;
- guint i;
+ g_assert (ctx != NULL);
+ g_assert (tokens != NULL);
+ for (i = 0; i < tokens->len; i++) {
+ tok = g_ptr_array_index (tokens, i);
- for (i = rt->start_pos; i < rt->end_pos; i++) {
- res = &g_array_index (node->results, struct rspamd_token_result, i);
+ for (j = 0; j < ctx->statfiles_ids->len; j++) {
+ id = g_array_index (ctx->statfiles_ids, gint, j);
+ st = g_ptr_array_index (ctx->ctx->statfiles, id);
+ g_assert (st != NULL);
- if (res->st_runtime) {
- if (!res->st_runtime->st->is_spam) {
- res->value ++;
+ if (is_spam) {
+ if (st->stcf->is_spam) {
+ tok->values[id]++;
+ }
+ else if (tok->values[id] > 0) {
+ /* Unlearning */
+ tok->values[id]--;
+ }
}
- else if (res->value > 0) {
- res->value --;
+ else {
+ if (!st->stcf->is_spam) {
+ tok->values[id]++;
+ }
+ else if (tok->values[id] > 0) {
+ /* Unlearning */
+ tok->values[id]--;
+ }
}
}
}
- return FALSE;
-}
-
-gboolean
-bayes_learn_spam (struct rspamd_classifier * ctx,
- GTree *input,
- struct rspamd_classifier_runtime *rt,
- struct rspamd_task *task,
- gboolean is_spam,
- GError **err)
-{
- g_assert (ctx != NULL);
- g_assert (input != NULL);
- g_assert (rt != NULL);
- g_assert (rt->end_pos > rt->start_pos);
-
- if (is_spam) {
- g_tree_foreach (input, bayes_learn_spam_callback, rt);
- }
- else {
- g_tree_foreach (input, bayes_learn_ham_callback, rt);
- }
-
-
return TRUE;
}
diff --git a/src/libstat/classifiers/classifiers.h b/src/libstat/classifiers/classifiers.h
index 62abb0052..52b9a89f7 100644
--- a/src/libstat/classifiers/classifiers.h
+++ b/src/libstat/classifiers/classifiers.h
@@ -12,34 +12,31 @@ struct rspamd_task;
struct rspamd_classifier;
struct token_node_s;
-struct rspamd_classifier_runtime;
struct rspamd_stat_classifier {
char *name;
void (*init_func)(rspamd_mempool_t *pool,
- struct rspamd_classifier *cl);
+ struct rspamd_classifier *cl);
gboolean (*classify_func)(struct rspamd_classifier * ctx,
- GTree *input, struct rspamd_classifier_runtime *rt,
- struct rspamd_task *task);
+ GPtrArray *tokens,
+ struct rspamd_task *task);
gboolean (*learn_spam_func)(struct rspamd_classifier * ctx,
- GTree *input, struct rspamd_classifier_runtime *rt,
- struct rspamd_task *task, gboolean is_spam,
- GError **err);
+ GPtrArray *input,
+ struct rspamd_task *task, gboolean is_spam,
+ GError **err);
};
/* Bayes algorithm */
void bayes_init (rspamd_mempool_t *pool,
- struct rspamd_classifier *);
-gboolean bayes_classify (struct rspamd_classifier * ctx,
- GTree *input,
- struct rspamd_classifier_runtime *rt,
- struct rspamd_task *task);
-gboolean bayes_learn_spam (struct rspamd_classifier * ctx,
- GTree *input,
- struct rspamd_classifier_runtime *rt,
- struct rspamd_task *task,
- gboolean is_spam,
- GError **err);
+ struct rspamd_classifier *);
+gboolean bayes_classify (struct rspamd_classifier *ctx,
+ GPtrArray *tokens,
+ struct rspamd_task *task);
+gboolean bayes_learn_spam (struct rspamd_classifier *ctx,
+ GPtrArray *tokens,
+ struct rspamd_task *task,
+ gboolean is_spam,
+ GError **err);
#endif
/*
diff --git a/src/libstat/stat_config.c b/src/libstat/stat_config.c
index 647079367..1cf19d412 100644
--- a/src/libstat/stat_config.c
+++ b/src/libstat/stat_config.c
@@ -133,6 +133,7 @@ rspamd_stat_init (struct rspamd_config *cfg)
cl = g_slice_alloc0 (sizeof (*cl));
cl->cfg = clf;
+ cl->ctx = stat_ctx;
cl->statfiles_ids = g_array_new (FALSE, FALSE, sizeof (gint));
/* Init classifier cache */
diff --git a/src/libstat/stat_internal.h b/src/libstat/stat_internal.h
index 09fd87fb6..edced84de 100644
--- a/src/libstat/stat_internal.h
+++ b/src/libstat/stat_internal.h
@@ -30,11 +30,6 @@
#include "backends/backends.h"
#include "learn_cache/learn_cache.h"
-enum stat_process_stage {
- RSPAMD_STAT_STAGE_PRE = 0,
- RSPAMD_STAT_STAGE_POST
-};
-
struct rspamd_statfile_runtime {
struct rspamd_statfile_config *st;
gpointer backend_runtime;
@@ -42,29 +37,14 @@ struct rspamd_statfile_runtime {
guint64 total_hits;
};
-struct rspamd_classifier_runtime {
- struct rspamd_classifier_config *clcf;
- struct classifier_ctx *clctx;
- struct rspamd_stat_classifier *cl;
- struct rspamd_stat_backend *backend;
- struct rspamd_tokenizer_runtime *tok;
- double ham_prob;
- double spam_prob;
- enum stat_process_stage stage;
- guint64 total_spam;
- guint64 total_ham;
- guint64 processed_tokens;
- GList *st_runtime;
- guint start_pos;
- guint end_pos;
- gboolean skipped;
-};
-
/* Common classifier structure */
struct rspamd_classifier {
+ struct rspamd_stat_ctx *ctx;
struct rspamd_stat_cache *cache;
gpointer cachecf;
GArray *statfiles_ids;
+ gulong spam_learns;
+ gulong ham_learns;
struct rspamd_classifier_config *cfg;
};
diff --git a/src/libstat/stat_process.c b/src/libstat/stat_process.c
index 8bdf394b1..1506f4d48 100644
--- a/src/libstat/stat_process.c
+++ b/src/libstat/stat_process.c
@@ -37,6 +37,7 @@
static const gint similarity_treshold = 80;
+#if 0
struct preprocess_cb_data {
struct rspamd_task *task;
GList *classifier_runtimes;
@@ -910,3 +911,4 @@ rspamd_stat_result_t rspamd_stat_statistics (struct rspamd_task *task,
return RSPAMD_STAT_PROCESS_OK;
}
+#endif