}
struct bayes_task_closure {
- struct rspamd_classifier_runtime *rt;
+ double ham_prob;
+ double spam_prob;
+ guint64 processed_tokens;
+ guint64 total_hits;
struct rspamd_task *task;
};
/*
* In this callback we calculate local probabilities for tokens
*/
-static gboolean
-bayes_classify_callback (gpointer key, gpointer value, gpointer data)
+static void
+bayes_classify_token (struct rspamd_classifier *ctx,
+ rspamd_token_t *tok, struct bayes_task_closure *cl)
{
- rspamd_token_t *node = value;
- struct bayes_task_closure *cl = data;
- struct rspamd_classifier_runtime *rt;
guint i;
- struct rspamd_token_result *res;
+ gint id;
guint64 spam_count = 0, ham_count = 0, total_count = 0;
+ struct rspamd_statfile *st;
struct rspamd_task *task;
double spam_prob, spam_freq, ham_freq, bayes_spam_prob, bayes_ham_prob,
- ham_prob, fw, w, norm_sum, norm_sub;
+ ham_prob, fw, w, norm_sum, norm_sub, val;
- rt = cl->rt;
task = cl->task;
- for (i = rt->start_pos; i < rt->end_pos; i++) {
- res = &g_array_index (node->results, struct rspamd_token_result, i);
+ for (i = 0; i < ctx->statfiles_ids->len; i++) {
+ id = g_array_index (ctx->statfiles_ids, gint, i);
+ st = g_ptr_array_index (ctx->ctx->statfiles, id);
+ g_assert (st != NULL);
+ val = tok->values[id];
- if (res->value > 0) {
- if (res->st_runtime->st->is_spam) {
- spam_count += res->value;
+ if (val > 0) {
+ if (st->stcf->is_spam) {
+ spam_count += val;
}
else {
- ham_count += res->value;
+ ham_count += val;
}
- total_count += res->value;
- res->st_runtime->total_hits += res->value;
+
+ total_count += val;
+ cl->total_hits += val;
}
}
/* Probability for this token */
if (total_count > 0) {
- spam_freq = ((double)spam_count / MAX (1., (double)rt->total_spam));
- ham_freq = ((double)ham_count / MAX (1., (double)rt->total_ham));
+ spam_freq = ((double)spam_count / MAX (1., (double) ctx->spam_learns));
+ ham_freq = ((double)ham_count / MAX (1., (double)ctx->ham_learns));
spam_prob = spam_freq / (spam_freq + ham_freq);
ham_prob = ham_freq / (spam_freq + ham_freq);
- fw = feature_weight[node->window_idx % G_N_ELEMENTS (feature_weight)];
+ fw = feature_weight[tok->window_idx % G_N_ELEMENTS (feature_weight)];
norm_sum = (spam_freq + ham_freq) * (spam_freq + ham_freq);
norm_sub = (spam_freq - ham_freq) * (spam_freq - ham_freq);
w = (norm_sub) / (norm_sum) *
w = (norm_sub) / (norm_sum) *
(fw * total_count) / (4.0 * (1.0 + fw * total_count));
bayes_ham_prob = PROB_COMBINE (ham_prob, total_count, w, 0.5);
- rt->spam_prob += log (bayes_spam_prob);
- rt->ham_prob += log (bayes_ham_prob);
- res->cl_runtime->processed_tokens ++;
+ cl->spam_prob += log (bayes_spam_prob);
+ cl->ham_prob += log (bayes_ham_prob);
+ cl->processed_tokens ++;
msg_debug_bayes ("token: weight: %f, total_count: %L, "
"spam_count: %L, ham_count: %L,"
fw, total_count, spam_count, ham_count,
spam_prob, ham_prob,
bayes_spam_prob, bayes_ham_prob,
- rt->spam_prob, rt->ham_prob);
+ cl->spam_prob, cl->ham_prob);
}
-
- return FALSE;
}
/*
gboolean
bayes_classify (struct rspamd_classifier * ctx,
- GTree *input,
- struct rspamd_classifier_runtime *rt,
- struct rspamd_task *task)
+ GPtrArray *tokens,
+ struct rspamd_task *task)
{
double final_prob, h, s;
- guint maxhits = 0;
- struct rspamd_statfile_runtime *st, *selected_st = NULL;
- GList *cur;
char *sumbuf;
+ struct rspamd_statfile *st = NULL;
struct bayes_task_closure cl;
+ rspamd_token_t *tok;
+ guint i;
+ gint id;
+ GList *cur;
g_assert (ctx != NULL);
- g_assert (input != NULL);
- g_assert (rt != NULL);
- g_assert (rt->end_pos > rt->start_pos);
-
- if (rt->stage == RSPAMD_STAT_STAGE_PRE) {
- cl.rt = rt;
- cl.task = task;
- g_tree_foreach (input, bayes_classify_callback, &cl);
+ g_assert (tokens != NULL);
+
+ memset (&cl, 0, sizeof (cl));
+ cl.task = task;
+
+ for (i = 0; i < tokens->len; i ++) {
+ tok = g_ptr_array_index (tokens, i);
+
+ bayes_classify_token (ctx, tok, &cl);
+ }
+
+ h = 1 - inv_chi_square (task, cl.spam_prob, cl.processed_tokens);
+ s = 1 - inv_chi_square (task, cl.ham_prob, cl.processed_tokens);
+
+ if (isfinite (s) && isfinite (h)) {
+ final_prob = (s + 1.0 - h) / 2.;
+ msg_debug_bayes (
+ "<%s> got ham prob %.2f -> %.2f and spam prob %.2f -> %.2f,"
+ " %L tokens processed of %ud total tokens",
+ task->message_id,
+ cl.ham_prob,
+ h,
+ cl.spam_prob,
+ s,
+ cl.processed_tokens,
+ tokens->len);
}
else {
- h = 1 - inv_chi_square (task, rt->spam_prob, rt->processed_tokens);
- s = 1 - inv_chi_square (task, rt->ham_prob, rt->processed_tokens);
-
- if (isfinite (s) && isfinite (h)) {
- final_prob = (s + 1.0 - h) / 2.;
- msg_debug_bayes ("<%s> got ham prob %.2f -> %.2f and spam prob %.2f -> %.2f,"
- " %L tokens processed of %ud total tokens",
- task->message_id, rt->ham_prob, h, rt->spam_prob, s,
- rt->processed_tokens, g_tree_nnodes (input));
+ /*
+ * We have some overflow, hence we need to check which class
+ * is NaN
+ */
+ if (isfinite (h)) {
+ final_prob = 1.0;
+ msg_debug_bayes ("<%s> spam class is overflowed, as we have no"
+ " ham samples", task->message_id);
+ }
+ else if (isfinite (s)) {
+ final_prob = 0.0;
+ msg_debug_bayes ("<%s> ham class is overflowed, as we have no"
+ " spam samples", task->message_id);
}
else {
- /*
- * We have some overflow, hence we need to check which class
- * is NaN
- */
- if (isfinite (h)) {
- final_prob = 1.0;
- msg_debug_bayes ("<%s> spam class is overflowed, as we have no"
- " ham samples", task->message_id);
- }
- else if (isfinite (s)){
- final_prob = 0.0;
- msg_debug_bayes ("<%s> ham class is overflowed, as we have no"
- " spam samples", task->message_id);
- }
- else {
- final_prob = 0.5;
- msg_warn_bayes ("<%s> spam and ham classes are both overflowed",
- task->message_id);
- }
+ final_prob = 0.5;
+ msg_warn_bayes ("<%s> spam and ham classes are both overflowed",
+ task->message_id);
}
+ }
- if (rt->processed_tokens > 0 && fabs (final_prob - 0.5) > 0.05) {
-
- sumbuf = rspamd_mempool_alloc (task->task_pool, 32);
- cur = g_list_first (rt->st_runtime);
+ if (cl.processed_tokens > 0 && fabs (final_prob - 0.5) > 0.05) {
- while (cur) {
- st = (struct rspamd_statfile_runtime *)cur->data;
+ sumbuf = rspamd_mempool_alloc (task->task_pool, 32);
- if ((final_prob < 0.5 && !st->st->is_spam) ||
- (final_prob > 0.5 && st->st->is_spam)) {
- if (st->total_hits > maxhits) {
- maxhits = st->total_hits;
- selected_st = st;
- }
- }
+ /* Now we can have exactly one HAM and exactly one SPAM statfiles per classifier */
+ for (i = 0; i < ctx->statfiles_ids->len; i++) {
+ id = g_array_index (ctx->statfiles_ids, gint, i);
+ st = g_ptr_array_index (ctx->ctx->statfiles, id);
- cur = g_list_next (cur);
+ if (final_prob > 0.5 && st->stcf->is_spam) {
+ break;
}
-
- if (selected_st == NULL) {
- msg_err_bayes (
- "unexpected classifier error: cannot select desired statfile, "
- "prob: %.4f", final_prob);
+ else if (final_prob < 0.5 && !st->stcf->is_spam) {
+ break;
}
- else {
- /* Correctly scale HAM */
- if (final_prob < 0.5) {
- final_prob = 1.0 - final_prob;
- }
-
- rspamd_snprintf (sumbuf, 32, "%.2f%%", final_prob * 100.);
- final_prob = bayes_normalize_prob (final_prob);
+ }
- cur = g_list_prepend (NULL, sumbuf);
- rspamd_task_insert_result (task,
- selected_st->st->symbol,
- final_prob,
- cur);
- }
+ /* Correctly scale HAM */
+ if (final_prob < 0.5) {
+ final_prob = 1.0 - final_prob;
}
+
+ rspamd_snprintf (sumbuf, 32, "%.2f%%", final_prob * 100.);
+ final_prob = bayes_normalize_prob (final_prob);
+ g_assert (st != NULL);
+ cur = g_list_prepend (NULL, sumbuf);
+ rspamd_task_insert_result (task,
+ st->stcf->symbol,
+ final_prob,
+ cur);
}
return TRUE;
}
-static gboolean
-bayes_learn_spam_callback (gpointer key, gpointer value, gpointer data)
+gboolean
+bayes_learn_spam (struct rspamd_classifier * ctx,
+ GPtrArray *tokens,
+ struct rspamd_task *task,
+ gboolean is_spam,
+ GError **err)
{
- rspamd_token_t *node = value;
- struct rspamd_token_result *res;
- struct rspamd_classifier_runtime *rt = (struct rspamd_classifier_runtime *)data;
- guint i;
-
+ guint i, j;
+ gint id;
+ struct rspamd_statfile *st;
+ rspamd_token_t *tok;
- for (i = rt->start_pos; i < rt->end_pos; i++) {
- res = &g_array_index (node->results, struct rspamd_token_result, i);
-
- if (res->st_runtime) {
- if (res->st_runtime->st->is_spam) {
- res->value ++;
- }
- else if (res->value > 0) {
- /* Unlearning */
- res->value --;
- }
- }
- }
-
- return FALSE;
-}
-
-static gboolean
-bayes_learn_ham_callback (gpointer key, gpointer value, gpointer data)
-{
- rspamd_token_t *node = value;
- struct rspamd_token_result *res;
- struct rspamd_classifier_runtime *rt = (struct rspamd_classifier_runtime *)data;
- guint i;
+ g_assert (ctx != NULL);
+ g_assert (tokens != NULL);
+ for (i = 0; i < tokens->len; i++) {
+ tok = g_ptr_array_index (tokens, i);
- for (i = rt->start_pos; i < rt->end_pos; i++) {
- res = &g_array_index (node->results, struct rspamd_token_result, i);
+ for (j = 0; j < ctx->statfiles_ids->len; j++) {
+ id = g_array_index (ctx->statfiles_ids, gint, j);
+ st = g_ptr_array_index (ctx->ctx->statfiles, id);
+ g_assert (st != NULL);
- if (res->st_runtime) {
- if (!res->st_runtime->st->is_spam) {
- res->value ++;
+ if (is_spam) {
+ if (st->stcf->is_spam) {
+ tok->values[id]++;
+ }
+ else if (tok->values[id] > 0) {
+ /* Unlearning */
+ tok->values[id]--;
+ }
}
- else if (res->value > 0) {
- res->value --;
+ else {
+ if (!st->stcf->is_spam) {
+ tok->values[id]++;
+ }
+ else if (tok->values[id] > 0) {
+ /* Unlearning */
+ tok->values[id]--;
+ }
}
}
}
- return FALSE;
-}
-
-gboolean
-bayes_learn_spam (struct rspamd_classifier * ctx,
- GTree *input,
- struct rspamd_classifier_runtime *rt,
- struct rspamd_task *task,
- gboolean is_spam,
- GError **err)
-{
- g_assert (ctx != NULL);
- g_assert (input != NULL);
- g_assert (rt != NULL);
- g_assert (rt->end_pos > rt->start_pos);
-
- if (is_spam) {
- g_tree_foreach (input, bayes_learn_spam_callback, rt);
- }
- else {
- g_tree_foreach (input, bayes_learn_ham_callback, rt);
- }
-
-
return TRUE;
}