Browse Source

Fix bayes classifier for the new architecture

tags/1.1.0
Vsevolod Stakhov 8 years ago
parent
commit
8be7159568

+ 136
- 163
src/libstat/classifiers/bayes.c View File

@@ -90,7 +90,10 @@ inv_chi_square (struct rspamd_task *task, gdouble value, gint freedom_deg)
}

/*
 * Accumulator passed to the per-token classification code; the visible
 * code reads and updates it through `cl->...` / `cl....`.
 */
struct bayes_task_closure {
/*
 * Classifier runtime, read as `cl->rt` by the per-token code.
 * NOTE(review): this listing interleaves pre- and post-change lines; one
 * visible path only zeroes the closure and sets `task` - confirm whether
 * this member is still present after the change.
 */
struct rspamd_classifier_runtime *rt;
/* Running sum of log (bayes_ham_prob) over the tokens processed so far */
double ham_prob;
/* Running sum of log (bayes_spam_prob) over the tokens processed so far */
double spam_prob;
/* Incremented once per token for which the log-probabilities were added */
guint64 processed_tokens;
/* Sum of the positive per-statfile token values seen so far */
guint64 total_hits;
/* Task being classified; the per-token code reads it back as `cl->task` */
struct rspamd_task *task;
};

@@ -104,44 +107,46 @@ static const double feature_weight[] = { 0, 1, 4, 27, 256, 3125, 46656, 823543 }
/*
* In this callback we calculate local probabilities for tokens
*/
static gboolean
bayes_classify_callback (gpointer key, gpointer value, gpointer data)
static void
bayes_classify_token (struct rspamd_classifier *ctx,
rspamd_token_t *tok, struct bayes_task_closure *cl)
{
rspamd_token_t *node = value;
struct bayes_task_closure *cl = data;
struct rspamd_classifier_runtime *rt;
guint i;
struct rspamd_token_result *res;
gint id;
guint64 spam_count = 0, ham_count = 0, total_count = 0;
struct rspamd_statfile *st;
struct rspamd_task *task;
double spam_prob, spam_freq, ham_freq, bayes_spam_prob, bayes_ham_prob,
ham_prob, fw, w, norm_sum, norm_sub;
ham_prob, fw, w, norm_sum, norm_sub, val;

rt = cl->rt;
task = cl->task;

for (i = rt->start_pos; i < rt->end_pos; i++) {
res = &g_array_index (node->results, struct rspamd_token_result, i);
for (i = 0; i < ctx->statfiles_ids->len; i++) {
id = g_array_index (ctx->statfiles_ids, gint, i);
st = g_ptr_array_index (ctx->ctx->statfiles, id);
g_assert (st != NULL);
val = tok->values[id];

if (res->value > 0) {
if (res->st_runtime->st->is_spam) {
spam_count += res->value;
if (val > 0) {
if (st->stcf->is_spam) {
spam_count += val;
}
else {
ham_count += res->value;
ham_count += val;
}
total_count += res->value;
res->st_runtime->total_hits += res->value;

total_count += val;
cl->total_hits += val;
}
}

/* Probability for this token */
if (total_count > 0) {
spam_freq = ((double)spam_count / MAX (1., (double)rt->total_spam));
ham_freq = ((double)ham_count / MAX (1., (double)rt->total_ham));
spam_freq = ((double)spam_count / MAX (1., (double) ctx->spam_learns));
ham_freq = ((double)ham_count / MAX (1., (double)ctx->ham_learns));
spam_prob = spam_freq / (spam_freq + ham_freq);
ham_prob = ham_freq / (spam_freq + ham_freq);
fw = feature_weight[node->window_idx % G_N_ELEMENTS (feature_weight)];
fw = feature_weight[tok->window_idx % G_N_ELEMENTS (feature_weight)];
norm_sum = (spam_freq + ham_freq) * (spam_freq + ham_freq);
norm_sub = (spam_freq - ham_freq) * (spam_freq - ham_freq);
w = (norm_sub) / (norm_sum) *
@@ -151,9 +156,9 @@ bayes_classify_callback (gpointer key, gpointer value, gpointer data)
w = (norm_sub) / (norm_sum) *
(fw * total_count) / (4.0 * (1.0 + fw * total_count));
bayes_ham_prob = PROB_COMBINE (ham_prob, total_count, w, 0.5);
rt->spam_prob += log (bayes_spam_prob);
rt->ham_prob += log (bayes_ham_prob);
res->cl_runtime->processed_tokens ++;
cl->spam_prob += log (bayes_spam_prob);
cl->ham_prob += log (bayes_ham_prob);
cl->processed_tokens ++;

msg_debug_bayes ("token: weight: %f, total_count: %L, "
"spam_count: %L, ham_count: %L,"
@@ -163,10 +168,8 @@ bayes_classify_callback (gpointer key, gpointer value, gpointer data)
fw, total_count, spam_count, ham_count,
spam_prob, ham_prob,
bayes_spam_prob, bayes_ham_prob,
rt->spam_prob, rt->ham_prob);
cl->spam_prob, cl->ham_prob);
}

return FALSE;
}

/*
@@ -198,176 +201,146 @@ bayes_init (rspamd_mempool_t *pool, struct rspamd_classifier *cl)

gboolean
bayes_classify (struct rspamd_classifier * ctx,
GTree *input,
struct rspamd_classifier_runtime *rt,
struct rspamd_task *task)
GPtrArray *tokens,
struct rspamd_task *task)
{
double final_prob, h, s;
guint maxhits = 0;
struct rspamd_statfile_runtime *st, *selected_st = NULL;
GList *cur;
char *sumbuf;
struct rspamd_statfile *st = NULL;
struct bayes_task_closure cl;
rspamd_token_t *tok;
guint i;
gint id;
GList *cur;

g_assert (ctx != NULL);
g_assert (input != NULL);
g_assert (rt != NULL);
g_assert (rt->end_pos > rt->start_pos);

if (rt->stage == RSPAMD_STAT_STAGE_PRE) {
cl.rt = rt;
cl.task = task;
g_tree_foreach (input, bayes_classify_callback, &cl);
g_assert (tokens != NULL);

memset (&cl, 0, sizeof (cl));
cl.task = task;

for (i = 0; i < tokens->len; i ++) {
tok = g_ptr_array_index (tokens, i);

bayes_classify_token (ctx, tok, &cl);
}

h = 1 - inv_chi_square (task, cl.spam_prob, cl.processed_tokens);
s = 1 - inv_chi_square (task, cl.ham_prob, cl.processed_tokens);

if (isfinite (s) && isfinite (h)) {
final_prob = (s + 1.0 - h) / 2.;
msg_debug_bayes (
"<%s> got ham prob %.2f -> %.2f and spam prob %.2f -> %.2f,"
" %L tokens processed of %ud total tokens",
task->message_id,
cl.ham_prob,
h,
cl.spam_prob,
s,
cl.processed_tokens,
tokens->len);
}
else {
h = 1 - inv_chi_square (task, rt->spam_prob, rt->processed_tokens);
s = 1 - inv_chi_square (task, rt->ham_prob, rt->processed_tokens);

if (isfinite (s) && isfinite (h)) {
final_prob = (s + 1.0 - h) / 2.;
msg_debug_bayes ("<%s> got ham prob %.2f -> %.2f and spam prob %.2f -> %.2f,"
" %L tokens processed of %ud total tokens",
task->message_id, rt->ham_prob, h, rt->spam_prob, s,
rt->processed_tokens, g_tree_nnodes (input));
/*
* We have some overflow, hence we need to check which class
* is NaN
*/
if (isfinite (h)) {
final_prob = 1.0;
msg_debug_bayes ("<%s> spam class is overflowed, as we have no"
" ham samples", task->message_id);
}
else if (isfinite (s)) {
final_prob = 0.0;
msg_debug_bayes ("<%s> ham class is overflowed, as we have no"
" spam samples", task->message_id);
}
else {
/*
* We have some overflow, hence we need to check which class
* is NaN
*/
if (isfinite (h)) {
final_prob = 1.0;
msg_debug_bayes ("<%s> spam class is overflowed, as we have no"
" ham samples", task->message_id);
}
else if (isfinite (s)){
final_prob = 0.0;
msg_debug_bayes ("<%s> ham class is overflowed, as we have no"
" spam samples", task->message_id);
}
else {
final_prob = 0.5;
msg_warn_bayes ("<%s> spam and ham classes are both overflowed",
task->message_id);
}
final_prob = 0.5;
msg_warn_bayes ("<%s> spam and ham classes are both overflowed",
task->message_id);
}
}

if (rt->processed_tokens > 0 && fabs (final_prob - 0.5) > 0.05) {

sumbuf = rspamd_mempool_alloc (task->task_pool, 32);
cur = g_list_first (rt->st_runtime);
if (cl.processed_tokens > 0 && fabs (final_prob - 0.5) > 0.05) {

while (cur) {
st = (struct rspamd_statfile_runtime *)cur->data;
sumbuf = rspamd_mempool_alloc (task->task_pool, 32);

if ((final_prob < 0.5 && !st->st->is_spam) ||
(final_prob > 0.5 && st->st->is_spam)) {
if (st->total_hits > maxhits) {
maxhits = st->total_hits;
selected_st = st;
}
}
/* Now we can have exactly one HAM and exactly one SPAM statfiles per classifier */
for (i = 0; i < ctx->statfiles_ids->len; i++) {
id = g_array_index (ctx->statfiles_ids, gint, i);
st = g_ptr_array_index (ctx->ctx->statfiles, id);

cur = g_list_next (cur);
if (final_prob > 0.5 && st->stcf->is_spam) {
break;
}

if (selected_st == NULL) {
msg_err_bayes (
"unexpected classifier error: cannot select desired statfile, "
"prob: %.4f", final_prob);
else if (final_prob < 0.5 && !st->stcf->is_spam) {
break;
}
else {
/* Correctly scale HAM */
if (final_prob < 0.5) {
final_prob = 1.0 - final_prob;
}

rspamd_snprintf (sumbuf, 32, "%.2f%%", final_prob * 100.);
final_prob = bayes_normalize_prob (final_prob);
}

cur = g_list_prepend (NULL, sumbuf);
rspamd_task_insert_result (task,
selected_st->st->symbol,
final_prob,
cur);
}
/* Correctly scale HAM */
if (final_prob < 0.5) {
final_prob = 1.0 - final_prob;
}

rspamd_snprintf (sumbuf, 32, "%.2f%%", final_prob * 100.);
final_prob = bayes_normalize_prob (final_prob);
g_assert (st != NULL);
cur = g_list_prepend (NULL, sumbuf);
rspamd_task_insert_result (task,
st->stcf->symbol,
final_prob,
cur);
}

return TRUE;
}

static gboolean
bayes_learn_spam_callback (gpointer key, gpointer value, gpointer data)
gboolean
bayes_learn_spam (struct rspamd_classifier * ctx,
GPtrArray *tokens,
struct rspamd_task *task,
gboolean is_spam,
GError **err)
{
rspamd_token_t *node = value;
struct rspamd_token_result *res;
struct rspamd_classifier_runtime *rt = (struct rspamd_classifier_runtime *)data;
guint i;

guint i, j;
gint id;
struct rspamd_statfile *st;
rspamd_token_t *tok;

for (i = rt->start_pos; i < rt->end_pos; i++) {
res = &g_array_index (node->results, struct rspamd_token_result, i);

if (res->st_runtime) {
if (res->st_runtime->st->is_spam) {
res->value ++;
}
else if (res->value > 0) {
/* Unlearning */
res->value --;
}
}
}

return FALSE;
}

static gboolean
bayes_learn_ham_callback (gpointer key, gpointer value, gpointer data)
{
rspamd_token_t *node = value;
struct rspamd_token_result *res;
struct rspamd_classifier_runtime *rt = (struct rspamd_classifier_runtime *)data;
guint i;
g_assert (ctx != NULL);
g_assert (tokens != NULL);

for (i = 0; i < tokens->len; i++) {
tok = g_ptr_array_index (tokens, i);

for (i = rt->start_pos; i < rt->end_pos; i++) {
res = &g_array_index (node->results, struct rspamd_token_result, i);
for (j = 0; j < ctx->statfiles_ids->len; j++) {
id = g_array_index (ctx->statfiles_ids, gint, j);
st = g_ptr_array_index (ctx->ctx->statfiles, id);
g_assert (st != NULL);

if (res->st_runtime) {
if (!res->st_runtime->st->is_spam) {
res->value ++;
if (is_spam) {
if (st->stcf->is_spam) {
tok->values[id]++;
}
else if (tok->values[id] > 0) {
/* Unlearning */
tok->values[id]--;
}
}
else if (res->value > 0) {
res->value --;
else {
if (!st->stcf->is_spam) {
tok->values[id]++;
}
else if (tok->values[id] > 0) {
/* Unlearning */
tok->values[id]--;
}
}
}
}

return FALSE;
}

gboolean
bayes_learn_spam (struct rspamd_classifier * ctx,
GTree *input,
struct rspamd_classifier_runtime *rt,
struct rspamd_task *task,
gboolean is_spam,
GError **err)
{
g_assert (ctx != NULL);
g_assert (input != NULL);
g_assert (rt != NULL);
g_assert (rt->end_pos > rt->start_pos);

if (is_spam) {
g_tree_foreach (input, bayes_learn_spam_callback, rt);
}
else {
g_tree_foreach (input, bayes_learn_ham_callback, rt);
}


return TRUE;
}

+ 15
- 18
src/libstat/classifiers/classifiers.h View File

@@ -12,34 +12,31 @@ struct rspamd_task;
struct rspamd_classifier;

struct token_node_s;
struct rspamd_classifier_runtime;

struct rspamd_stat_classifier {
char *name;
void (*init_func)(rspamd_mempool_t *pool,
struct rspamd_classifier *cl);
struct rspamd_classifier *cl);
gboolean (*classify_func)(struct rspamd_classifier * ctx,
GTree *input, struct rspamd_classifier_runtime *rt,
struct rspamd_task *task);
GPtrArray *tokens,
struct rspamd_task *task);
gboolean (*learn_spam_func)(struct rspamd_classifier * ctx,
GTree *input, struct rspamd_classifier_runtime *rt,
struct rspamd_task *task, gboolean is_spam,
GError **err);
GPtrArray *input,
struct rspamd_task *task, gboolean is_spam,
GError **err);
};

/* Bayes algorithm */
void bayes_init (rspamd_mempool_t *pool,
struct rspamd_classifier *);
gboolean bayes_classify (struct rspamd_classifier * ctx,
GTree *input,
struct rspamd_classifier_runtime *rt,
struct rspamd_task *task);
gboolean bayes_learn_spam (struct rspamd_classifier * ctx,
GTree *input,
struct rspamd_classifier_runtime *rt,
struct rspamd_task *task,
gboolean is_spam,
GError **err);
struct rspamd_classifier *);
gboolean bayes_classify (struct rspamd_classifier *ctx,
GPtrArray *tokens,
struct rspamd_task *task);
gboolean bayes_learn_spam (struct rspamd_classifier *ctx,
GPtrArray *tokens,
struct rspamd_task *task,
gboolean is_spam,
GError **err);

#endif
/*

+ 1
- 0
src/libstat/stat_config.c View File

@@ -133,6 +133,7 @@ rspamd_stat_init (struct rspamd_config *cfg)

cl = g_slice_alloc0 (sizeof (*cl));
cl->cfg = clf;
cl->ctx = stat_ctx;
cl->statfiles_ids = g_array_new (FALSE, FALSE, sizeof (gint));

/* Init classifier cache */

+ 3
- 23
src/libstat/stat_internal.h View File

@@ -30,11 +30,6 @@
#include "backends/backends.h"
#include "learn_cache/learn_cache.h"

/*
 * Processing stage stored in `rspamd_classifier_runtime.stage`.
 * In the visible `bayes_classify` code the PRE stage runs the per-token
 * pass, while the other branch combines the accumulated probabilities.
 * NOTE(review): this listing is a diff without +/- markers and this enum
 * sits in a hunk with net removals - confirm it still exists afterwards.
 */
enum stat_process_stage {
RSPAMD_STAT_STAGE_PRE = 0,
RSPAMD_STAT_STAGE_POST
};

struct rspamd_statfile_runtime {
struct rspamd_statfile_config *st;
gpointer backend_runtime;
@@ -42,29 +37,14 @@ struct rspamd_statfile_runtime {
guint64 total_hits;
};

/*
 * Per-classifier runtime state consumed by the `rt->...` code paths of the
 * bayes classifier.
 * NOTE(review): this listing is a diff without +/- markers and this struct
 * sits in a hunk with net removals - confirm it still exists afterwards.
 */
struct rspamd_classifier_runtime {
struct rspamd_classifier_config *clcf;
struct classifier_ctx *clctx;
struct rspamd_stat_classifier *cl;
struct rspamd_stat_backend *backend;
struct rspamd_tokenizer_runtime *tok;
/* Running sum of log (bayes_ham_prob), fed to inv_chi_square () */
double ham_prob;
/* Running sum of log (bayes_spam_prob), fed to inv_chi_square () */
double spam_prob;
/* Compared against RSPAMD_STAT_STAGE_PRE to pick the classify branch */
enum stat_process_stage stage;
/* Denominator of the per-token spam frequency (clamped to at least 1) */
guint64 total_spam;
/* Denominator of the per-token ham frequency (clamped to at least 1) */
guint64 total_ham;
/* Passed to inv_chi_square () as its `freedom_deg` argument */
guint64 processed_tokens;
/* List walked as `struct rspamd_statfile_runtime *` when selecting a symbol */
GList *st_runtime;
/* Half-open index range [start_pos, end_pos) into each token's results array */
guint start_pos;
guint end_pos;
gboolean skipped;
};

/*
 * Common classifier structure: ties a classifier to the statistics context,
 * its statfiles and its learn counters.
 */
struct rspamd_classifier {
/* Owning stat context; its `statfiles` pointer array is indexed by statfile id */
struct rspamd_stat_ctx *ctx;
/* NOTE(review): presumably the learn cache and its config - not used in the visible code, confirm */
struct rspamd_stat_cache *cache;
gpointer cachecf;
/* Array of gint statfile ids, each used as an index into ctx->statfiles */
GArray *statfiles_ids;
/* Denominators of the per-token spam/ham frequencies (clamped to at least 1) */
gulong spam_learns;
gulong ham_learns;
/* Classifier configuration, assigned in rspamd_stat_init () */
struct rspamd_classifier_config *cfg;
};


+ 2
- 0
src/libstat/stat_process.c View File

@@ -37,6 +37,7 @@

static const gint similarity_treshold = 80;

#if 0
struct preprocess_cb_data {
struct rspamd_task *task;
GList *classifier_runtimes;
@@ -910,3 +911,4 @@ rspamd_stat_result_t rspamd_stat_statistics (struct rspamd_task *task,

return RSPAMD_STAT_PROCESS_OK;
}
#endif

Loading…
Cancel
Save