aboutsummaryrefslogtreecommitdiffstats
path: root/src/libstat/classifiers/bayes.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/libstat/classifiers/bayes.c')
-rw-r--r--src/libstat/classifiers/bayes.c652
1 files changed, 613 insertions, 39 deletions
diff --git a/src/libstat/classifiers/bayes.c b/src/libstat/classifiers/bayes.c
index 93b5149da..dbae98cc2 100644
--- a/src/libstat/classifiers/bayes.c
+++ b/src/libstat/classifiers/bayes.c
@@ -1,11 +1,11 @@
-/*-
- * Copyright 2016 Vsevolod Stakhov
+/*
+ * Copyright 2025 Vsevolod Stakhov
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
- * http://www.apache.org/licenses/LICENSE-2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
@@ -53,10 +53,26 @@ static double
inv_chi_square(struct rspamd_task *task, double value, int freedom_deg)
{
double prob, sum, m;
+ double log_prob, log_m;
int i;
errno = 0;
m = -value;
+
+ /* Handle extreme negative values that would cause exp() underflow */
+ if (value < -700) {
+ /* Very strong confidence, return 0 */
+ msg_debug_bayes("extreme negative value: %f, returning 0", value);
+ return 0.0;
+ }
+
+ /* Handle extreme positive values that would cause overflow */
+ if (value > 700) {
+ /* No confidence, return 1 */
+ msg_debug_bayes("extreme positive value: %f, returning 1", value);
+ return 1.0;
+ }
+
prob = exp(value);
if (errno == ERANGE) {
@@ -75,6 +91,8 @@ inv_chi_square(struct rspamd_task *task, double value, int freedom_deg)
}
sum = prob;
+ log_prob = value; /* log of current prob term */
+ log_m = log(fabs(m)); /* log of |m| for numerical stability */
msg_debug_bayes("m: %f, probability: %g", m, prob);
@@ -83,24 +101,60 @@ inv_chi_square(struct rspamd_task *task, double value, int freedom_deg)
* prob is e ^ x (small value since x is normally less than zero
* So we integrate over degrees of freedom and produce the total result
* from 1.0 (no confidence) to 0.0 (full confidence)
+ * Use logarithmic arithmetic to prevent overflow
*/
for (i = 1; i < freedom_deg; i++) {
- prob *= m / (double) i;
+ /* Calculate next term using logarithms to prevent overflow */
+ log_prob += log_m - log((double) i);
+
+ /* Check if the log probability is too negative (term becomes negligible) */
+ if (log_prob < -700) {
+ msg_debug_bayes("term %d became negligible, stopping series", i);
+ break;
+ }
+
+ /* Check if the log probability is too positive (would cause overflow) */
+ if (log_prob > 700) {
+ msg_debug_bayes("series diverging at term %d, returning 1.0", i);
+ return 1.0;
+ }
+
+ prob = exp(log_prob);
sum += prob;
- msg_debug_bayes("i=%d, probability: %g, sum: %g", i, prob, sum);
+ msg_debug_bayes("i=%d, log_prob: %g, probability: %g, sum: %g", i, log_prob, prob, sum);
+
+ /* Early termination if sum is getting too large */
+ if (sum > 1e10) {
+ msg_debug_bayes("sum too large (%g), returning 1.0", sum);
+ return 1.0;
+ }
}
return MIN(1.0, sum);
}
struct bayes_task_closure {
- double ham_prob;
- double spam_prob;
+ double ham_prob; /* Kept for binary compatibility */
+ double spam_prob; /* Kept for binary compatibility */
+ double meta_skip_prob;
+ uint64_t processed_tokens;
+ uint64_t total_hits;
+ uint64_t text_tokens;
+ struct rspamd_task *task;
+};
+
+/* Multi-class classification closure */
+struct bayes_multiclass_closure {
+ double *class_log_probs; /* Array of log probabilities for each class */
+ uint64_t *class_learns; /* Learning counts for each class */
+ char **class_names; /* Array of class names */
+ unsigned int num_classes; /* Number of classes */
double meta_skip_prob;
uint64_t processed_tokens;
uint64_t total_hits;
uint64_t text_tokens;
struct rspamd_task *task;
+ struct rspamd_classifier_config *cfg;
};
/*
@@ -122,7 +176,6 @@ bayes_classify_token(struct rspamd_classifier *ctx,
unsigned int spam_count = 0, ham_count = 0, total_count = 0;
struct rspamd_statfile *st;
struct rspamd_task *task;
- const char *token_type = "txt";
double spam_prob, spam_freq, ham_freq, bayes_spam_prob, bayes_ham_prob,
ham_prob, fw, w, val;
@@ -211,41 +264,379 @@ bayes_classify_token(struct rspamd_classifier *ctx,
if (!(tok->flags & RSPAMD_STAT_TOKEN_FLAG_META)) {
cl->text_tokens++;
}
+ }
+}
+
+/*
+ * Multinomial token classification for multi-class Bayes
+ */
+static void
+bayes_classify_token_multiclass(struct rspamd_classifier *ctx,
+ rspamd_token_t *tok,
+ struct bayes_multiclass_closure *cl)
+{
+ unsigned int i, j;
+ int id;
+ struct rspamd_statfile *st;
+ struct rspamd_task *task;
+ double val, fw, w;
+ guint64 *class_counts;
+ guint64 total_count = 0;
+
+ task = cl->task;
+
+ /* Skip meta tokens probabilistically if configured */
+ if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_META && cl->meta_skip_prob > 0) {
+ val = rspamd_random_double_fast();
+ if (val <= cl->meta_skip_prob) {
+ return;
+ }
+ }
+
+ /* Allocate array for class counts */
+ class_counts = g_alloca(cl->num_classes * sizeof(guint64));
+ memset(class_counts, 0, cl->num_classes * sizeof(guint64));
+
+ /* Collect counts for each class */
+ for (i = 0; i < ctx->statfiles_ids->len; i++) {
+ id = g_array_index(ctx->statfiles_ids, int, i);
+ st = g_ptr_array_index(ctx->ctx->statfiles, id);
+ g_assert(st != NULL);
+ val = tok->values[id];
+
+ if (val > 0) {
+ /* Direct O(1) class index lookup instead of O(N) string comparison */
+ if (st->stcf->class_name && st->stcf->class_index < cl->num_classes) {
+ unsigned int class_idx = st->stcf->class_index;
+ class_counts[class_idx] += val;
+ total_count += val;
+ cl->total_hits += val;
+ }
+ else {
+ msg_debug_bayes("invalid class_index %ud >= %ud for statfile %s",
+ st->stcf->class_index, cl->num_classes, st->stcf->symbol);
+ }
+ }
+ }
+
+ /* Calculate multinomial probability for this token */
+ if (total_count >= ctx->cfg->min_token_hits) {
+ /* Feature weight calculation */
+ if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_UNIGRAM) {
+ fw = 1.0;
+ }
else {
- token_type = "meta";
+ fw = feature_weight[tok->window_idx % G_N_ELEMENTS(feature_weight)];
}
- if (tok->t1 && tok->t2) {
- msg_debug_bayes("token(%s) %uL <%*s:%*s>: weight: %f, cf: %f, "
- "total_count: %ud, "
- "spam_count: %ud, ham_count: %ud,"
- "spam_prob: %.3f, ham_prob: %.3f, "
- "bayes_spam_prob: %.3f, bayes_ham_prob: %.3f, "
- "current spam probability: %.3f, current ham probability: %.3f",
- token_type,
- tok->data,
- (int) tok->t1->stemmed.len, tok->t1->stemmed.begin,
- (int) tok->t2->stemmed.len, tok->t2->stemmed.begin,
- fw, w, total_count, spam_count, ham_count,
- spam_prob, ham_prob,
- bayes_spam_prob, bayes_ham_prob,
- cl->spam_prob, cl->ham_prob);
+ w = (fw * total_count) / (1.0 + fw * total_count);
+
+ /* Apply multinomial model for each class */
+ for (j = 0; j < cl->num_classes; j++) {
+ /* Skip classes with insufficient learns */
+ if (ctx->cfg->min_learns > 0 && cl->class_learns[j] < ctx->cfg->min_learns) {
+ continue;
+ }
+
+ double class_freq = (double) class_counts[j] / MAX(1.0, (double) cl->class_learns[j]);
+ double class_prob = PROB_COMBINE(class_freq, total_count, w, 1.0 / cl->num_classes);
+
+ /* Ensure probability is properly bounded [0, 1] */
+ class_prob = MAX(0.0, MIN(1.0, class_prob));
+
+ /* Skip probabilities too close to uniform (1/num_classes) */
+ double uniform_prior = 1.0 / cl->num_classes;
+ if (fabs(class_prob - uniform_prior) < ctx->cfg->min_prob_strength) {
+ continue;
+ }
+
+ cl->class_log_probs[j] += log(class_prob);
+ }
+
+ cl->processed_tokens++;
+ if (!(tok->flags & RSPAMD_STAT_TOKEN_FLAG_META)) {
+ cl->text_tokens++;
+ }
+
+ /* Per-token debug logging removed to reduce verbosity */
+ }
+}
+
+/*
+ * Multinomial Bayes classification with Fisher confidence
+ */
+static gboolean
+bayes_classify_multiclass(struct rspamd_classifier *ctx,
+ GPtrArray *tokens,
+ struct rspamd_task *task)
+{
+ struct bayes_multiclass_closure cl;
+ rspamd_token_t *tok;
+ unsigned int i, j, text_tokens = 0;
+ int id;
+ struct rspamd_statfile *st;
+ rspamd_multiclass_result_t *result;
+ double *normalized_probs;
+ double max_log_prob = -INFINITY;
+ unsigned int winning_class_idx = 0;
+ double confidence;
+
+ g_assert(ctx != NULL);
+ g_assert(tokens != NULL);
+
+ /* Initialize multi-class closure */
+ memset(&cl, 0, sizeof(cl));
+ cl.task = task;
+ cl.cfg = ctx->cfg;
+
+ /* Get class information from classifier config */
+ if (!ctx->cfg->class_names) {
+ msg_debug_bayes("no class_names array in classifier config");
+ return TRUE; /* Fall back to binary mode */
+ }
+ if (ctx->cfg->class_names->len < 2) {
+ msg_debug_bayes("insufficient classes: %ud < 2", (unsigned int) ctx->cfg->class_names->len);
+ return TRUE; /* Fall back to binary mode */
+ }
+ if (!ctx->cfg->class_names->pdata) {
+ msg_debug_bayes("class_names->pdata is NULL");
+ return TRUE; /* Fall back to binary mode */
+ }
+
+ cl.num_classes = ctx->cfg->class_names->len;
+ cl.class_names = (char **) ctx->cfg->class_names->pdata;
+
+ /* Debug: verify class names are accessible */
+ msg_debug_bayes("multiclass setup: ctx->cfg->class_names=%p, len=%ud, pdata=%p",
+ ctx->cfg->class_names, (unsigned int) ctx->cfg->class_names->len, ctx->cfg->class_names->pdata);
+ msg_debug_bayes("multiclass setup: cl.num_classes=%ud, cl.class_names=%p",
+ cl.num_classes, cl.class_names);
+ cl.class_log_probs = g_alloca(cl.num_classes * sizeof(double));
+ cl.class_learns = g_alloca(cl.num_classes * sizeof(uint64_t));
+
+ /* Initialize probabilities and get learning counts */
+ for (i = 0; i < cl.num_classes; i++) {
+ cl.class_log_probs[i] = 0.0;
+ cl.class_learns[i] = 0;
+ }
+
+ /* Collect learning counts for each class */
+ for (i = 0; i < ctx->statfiles_ids->len; i++) {
+ id = g_array_index(ctx->statfiles_ids, int, i);
+ st = g_ptr_array_index(ctx->ctx->statfiles, id);
+ g_assert(st != NULL);
+
+ for (j = 0; j < cl.num_classes; j++) {
+ if (st->stcf->class_name &&
+ strcmp(st->stcf->class_name, cl.class_names[j]) == 0) {
+ cl.class_learns[j] += st->backend->total_learns(task,
+ g_ptr_array_index(task->stat_runtimes, id), ctx->ctx);
+ break;
+ }
+ }
+ }
+
+ /* Check minimum learns requirement - count viable classes */
+ unsigned int viable_classes = 0;
+ if (ctx->cfg->min_learns > 0) {
+ for (i = 0; i < cl.num_classes; i++) {
+ if (cl.class_learns[i] >= ctx->cfg->min_learns) {
+ viable_classes++;
+ }
+ else {
+ msg_info_task("class %s excluded from classification: %uL learns < %ud minimum",
+ cl.class_names[i], cl.class_learns[i], ctx->cfg->min_learns);
+ }
+ }
+
+ if (viable_classes == 0) {
+ msg_info_task("no classes have sufficient training samples for classification");
+ return TRUE;
+ }
+
+ msg_info_bayes("multiclass classification: %ud/%ud classes have sufficient learns",
+ viable_classes, cl.num_classes);
+ }
+
+ /* Count text tokens */
+ for (i = 0; i < tokens->len; i++) {
+ tok = g_ptr_array_index(tokens, i);
+ if (!(tok->flags & RSPAMD_STAT_TOKEN_FLAG_META)) {
+ text_tokens++;
+ }
+ }
+
+ if (text_tokens == 0) {
+ msg_info_task("skipped classification as there are no text tokens. "
+ "Total tokens: %ud",
+ tokens->len);
+ return TRUE;
+ }
+
+ /* Set meta token skip probability */
+ if (text_tokens > tokens->len - text_tokens) {
+ cl.meta_skip_prob = 0.0;
+ }
+ else {
+ cl.meta_skip_prob = 1.0 - (double) text_tokens / tokens->len;
+ }
+
+ /* Process all tokens */
+ for (i = 0; i < tokens->len; i++) {
+ tok = g_ptr_array_index(tokens, i);
+ bayes_classify_token_multiclass(ctx, tok, &cl);
+ }
+
+ if (cl.processed_tokens == 0) {
+ /* Debug: check why no tokens were processed */
+ msg_debug_bayes("examining token values for debugging:");
+ for (i = 0; i < MIN(tokens->len, 10); i++) { /* Check first 10 tokens */
+ tok = g_ptr_array_index(tokens, i);
+ for (j = 0; j < ctx->statfiles_ids->len; j++) {
+ id = g_array_index(ctx->statfiles_ids, int, j);
+ if (tok->values[id] > 0) {
+ struct rspamd_statfile *st = g_ptr_array_index(ctx->ctx->statfiles, id);
+ msg_debug_bayes("token %ud: values[%d] = %.2f (class=%s, symbol=%s)",
+ i, id, tok->values[id],
+ st->stcf->class_name ? st->stcf->class_name : "unknown",
+ st->stcf->symbol);
+ }
+ }
+ }
+
+ msg_info_bayes("no tokens found in bayes database "
+ "(%ud total tokens, %ud text tokens), ignore stats",
+ tokens->len, text_tokens);
+ return TRUE;
+ }
+
+ if (ctx->cfg->min_tokens > 0 &&
+ cl.text_tokens < (int) (ctx->cfg->min_tokens * 0.1)) {
+ msg_info_bayes("ignore bayes probability since we have "
+ "found too few text tokens: %uL (of %ud checked), "
+ "at least %d required",
+ cl.text_tokens, text_tokens,
+ (int) (ctx->cfg->min_tokens * 0.1));
+ return TRUE;
+ }
+
+ /* Normalize probabilities using softmax */
+ normalized_probs = g_alloca(cl.num_classes * sizeof(double));
+
+ /* Find maximum for numerical stability - only consider classes with sufficient training */
+ for (i = 0; i < cl.num_classes; i++) {
+ msg_debug_bayes("class %s, log_prob: %.2f", cl.class_names[i], cl.class_log_probs[i]);
+ /* Only consider classes that have sufficient training data */
+ if (ctx->cfg->min_learns > 0 && cl.class_learns[i] < ctx->cfg->min_learns) {
+ msg_debug_bayes("skipping class %s in winner selection: %uL learns < %ud minimum",
+ cl.class_names[i], cl.class_learns[i], ctx->cfg->min_learns);
+ continue;
+ }
+ if (cl.class_log_probs[i] > max_log_prob) {
+ max_log_prob = cl.class_log_probs[i];
+ winning_class_idx = i;
+ }
+ }
+
+ /* Apply softmax normalization */
+ double sum_exp = 0.0;
+ for (i = 0; i < cl.num_classes; i++) {
+ normalized_probs[i] = exp(cl.class_log_probs[i] - max_log_prob);
+ sum_exp += normalized_probs[i];
+ }
+
+ if (sum_exp > 0) {
+ for (i = 0; i < cl.num_classes; i++) {
+ normalized_probs[i] /= sum_exp;
+ }
+ }
+ else {
+ /* Fallback to uniform distribution */
+ for (i = 0; i < cl.num_classes; i++) {
+ normalized_probs[i] = 1.0 / cl.num_classes;
+ }
+ }
+
+ /* Calculate confidence using Fisher method for the winning class */
+ if (max_log_prob > -300) {
+ if (max_log_prob > 0) {
+ /* Positive log prob means very strong evidence - high confidence */
+ confidence = 0.95; /* High confidence for positive log probabilities */
+ msg_debug_bayes("positive log_prob (%g), setting high confidence", max_log_prob);
}
else {
- msg_debug_bayes("token(%s) %uL <?:?>: weight: %f, cf: %f, "
- "total_count: %ud, "
- "spam_count: %ud, ham_count: %ud,"
- "spam_prob: %.3f, ham_prob: %.3f, "
- "bayes_spam_prob: %.3f, bayes_ham_prob: %.3f, "
- "current spam probability: %.3f, current ham probability: %.3f",
- token_type,
- tok->data,
- fw, w, total_count, spam_count, ham_count,
- spam_prob, ham_prob,
- bayes_spam_prob, bayes_ham_prob,
- cl->spam_prob, cl->ham_prob);
+ /* Negative log prob - use Fisher method as intended */
+ double fisher_result = inv_chi_square(task, max_log_prob, cl.processed_tokens);
+ confidence = 1.0 - fisher_result;
+
+ msg_debug_bayes("fisher_result: %g, max_log_prob: %g, condition check: fisher_result > 0.999 = %s, max_log_prob > -50 = %s",
+ fisher_result, max_log_prob,
+ fisher_result > 0.999 ? "true" : "false",
+ max_log_prob > -50 ? "true" : "false");
+
+ /* Handle case where Fisher method indicates extreme confidence */
+ if (fisher_result > 0.999 && max_log_prob > -100) {
+ /* Large magnitude negative log prob means strong evidence */
+ confidence = 0.90;
+ msg_debug_bayes("extreme negative log_prob (%g), setting high confidence", max_log_prob);
+ }
}
}
+ else {
+ confidence = normalized_probs[winning_class_idx];
+ }
+
+ /* Create and store multiclass result */
+ result = g_new0(rspamd_multiclass_result_t, 1);
+ result->class_names = g_new(char *, cl.num_classes);
+ result->probabilities = g_new(double, cl.num_classes);
+ result->num_classes = cl.num_classes;
+ result->winning_class = cl.class_names[winning_class_idx]; /* Reference, not copy */
+ result->confidence = confidence;
+
+ for (i = 0; i < cl.num_classes; i++) {
+ result->class_names[i] = g_strdup(cl.class_names[i]);
+ result->probabilities[i] = normalized_probs[i];
+ }
+
+ rspamd_task_set_multiclass_result(task, result);
+
+ msg_info_bayes("MULTICLASS_RESULT: winning_class='%s', confidence=%.3f, normalized_prob=%.3f, tokens=%uL",
+ cl.class_names[winning_class_idx], confidence,
+ normalized_probs[winning_class_idx], cl.processed_tokens);
+
+ /* Insert symbol for winning class if confidence is significant */
+ if (confidence > 0.05) {
+ char sumbuf[32];
+ double final_prob = rspamd_normalize_probability(confidence, 0.5);
+
+ rspamd_snprintf(sumbuf, sizeof(sumbuf), "%.2f%%", confidence * 100.0);
+
+ /* Find the statfile for the winning class to get the symbol */
+ for (i = 0; i < ctx->statfiles_ids->len; i++) {
+ id = g_array_index(ctx->statfiles_ids, int, i);
+ st = g_ptr_array_index(ctx->ctx->statfiles, id);
+
+ if (st->stcf->class_name &&
+ strcmp(st->stcf->class_name, cl.class_names[winning_class_idx]) == 0) {
+ msg_info_bayes("SYMBOL_INSERT: symbol='%s', final_prob=%.3f, confidence_display='%s'",
+ st->stcf->symbol, final_prob, sumbuf);
+ rspamd_task_insert_result(task, st->stcf->symbol, final_prob, sumbuf);
+ break;
+ }
+ }
+
+ msg_debug_bayes("multiclass classification: winning class '%s' with "
+ "probability %.3f, confidence %.3f, %uL tokens processed",
+ cl.class_names[winning_class_idx],
+ normalized_probs[winning_class_idx],
+ confidence, cl.processed_tokens);
+ }
+ else {
+ msg_info_bayes("SYMBOL_SKIPPED: confidence=%.3f <= 0.05, no symbol inserted", confidence);
+ }
+
+ return TRUE;
}
@@ -279,6 +670,37 @@ bayes_classify(struct rspamd_classifier *ctx,
g_assert(ctx != NULL);
g_assert(tokens != NULL);
+ /* Check if this is a multi-class classifier */
+ msg_debug_bayes("classification check: class_names=%p, len=%uz",
+ ctx->cfg->class_names,
+ ctx->cfg->class_names ? ctx->cfg->class_names->len : 0);
+
+ if (ctx->cfg->class_names && ctx->cfg->class_names->len >= 2) {
+ /* Verify that at least one statfile has class_name set (indicating new multi-class config) */
+ gboolean has_class_names = FALSE;
+ for (i = 0; i < ctx->statfiles_ids->len; i++) {
+ int id = g_array_index(ctx->statfiles_ids, int, i);
+ struct rspamd_statfile *st = g_ptr_array_index(ctx->ctx->statfiles, id);
+ msg_debug_bayes("checking statfile %s: class_name=%s, is_spam_converted=%s",
+ st->stcf->symbol,
+ st->stcf->class_name ? st->stcf->class_name : "NULL",
+ st->stcf->is_spam_converted ? "true" : "false");
+ if (st->stcf->class_name) {
+ has_class_names = TRUE;
+ }
+ }
+
+ msg_debug_bayes("has_class_names=%s", has_class_names ? "true" : "false");
+
+ if (has_class_names) {
+ msg_debug_bayes("using multiclass classification with %ud classes",
+ (unsigned int) ctx->cfg->class_names->len);
+ return bayes_classify_multiclass(ctx, tokens, task);
+ }
+ }
+
+ /* Fall back to binary classification */
+ msg_debug_bayes("using binary classification");
memset(&cl, 0, sizeof(cl));
cl.task = task;
@@ -286,14 +708,14 @@ bayes_classify(struct rspamd_classifier *ctx,
if (ctx->cfg->min_learns > 0) {
if (ctx->ham_learns < ctx->cfg->min_learns) {
msg_info_task("not classified as ham. The ham class needs more "
- "training samples. Currently: %ul; minimum %ud required",
+ "training samples. Currently: %uL; minimum %ud required",
ctx->ham_learns, ctx->cfg->min_learns);
return TRUE;
}
if (ctx->spam_learns < ctx->cfg->min_learns) {
msg_info_task("not classified as spam. The spam class needs more "
- "training samples. Currently: %ul; minimum %ud required",
+ "training samples. Currently: %uL; minimum %ud required",
ctx->spam_learns, ctx->cfg->min_learns);
return TRUE;
@@ -374,7 +796,7 @@ bayes_classify(struct rspamd_classifier *ctx,
final_prob = (s + 1.0 - h) / 2.;
msg_debug_bayes(
"got ham probability %.2f -> %.2f and spam probability %.2f -> %.2f,"
- " %L tokens processed of %ud total tokens;"
+ " %uL tokens processed of %ud total tokens;"
" %uL text tokens found of %ud text tokens)",
cl.ham_prob,
h,
@@ -549,3 +971,155 @@ bayes_learn_spam(struct rspamd_classifier *ctx,
return TRUE;
}
+
+gboolean
+bayes_learn_class(struct rspamd_classifier *ctx,
+ GPtrArray *tokens,
+ struct rspamd_task *task,
+ const char *class_name,
+ gboolean unlearn,
+ GError **err)
+{
+ unsigned int i, j, total_cnt;
+ int id;
+ struct rspamd_statfile *st;
+ rspamd_token_t *tok;
+ gboolean incrementing;
+ unsigned int *class_counts = NULL;
+ struct rspamd_statfile **class_statfiles = NULL;
+ unsigned int num_classes = 0;
+
+ g_assert(ctx != NULL);
+ g_assert(tokens != NULL);
+ g_assert(class_name != NULL);
+
+ msg_info_bayes("LEARN_CLASS: class='%s', unlearn=%s, tokens=%ud",
+ class_name, unlearn ? "true" : "false", tokens->len);
+
+ incrementing = ctx->cfg->flags & RSPAMD_FLAG_CLASSIFIER_INCREMENTING_BACKEND;
+
+ /* Count classes and prepare arrays for multi-class learning */
+ if (ctx->cfg->class_names && ctx->cfg->class_names->len > 0) {
+ num_classes = ctx->cfg->class_names->len;
+ class_counts = g_alloca(num_classes * sizeof(unsigned int));
+ class_statfiles = g_alloca(num_classes * sizeof(struct rspamd_statfile *));
+ memset(class_counts, 0, num_classes * sizeof(unsigned int));
+ memset(class_statfiles, 0, num_classes * sizeof(struct rspamd_statfile *));
+ }
+
+ for (i = 0; i < tokens->len; i++) {
+ total_cnt = 0;
+ tok = g_ptr_array_index(tokens, i);
+
+ /* Reset class counts for this token */
+ if (num_classes > 0) {
+ memset(class_counts, 0, num_classes * sizeof(unsigned int));
+ }
+
+ for (j = 0; j < ctx->statfiles_ids->len; j++) {
+ id = g_array_index(ctx->statfiles_ids, int, j);
+ st = g_ptr_array_index(ctx->ctx->statfiles, id);
+ g_assert(st != NULL);
+
+ /* Determine if this statfile matches our target class */
+ gboolean is_target_class = FALSE;
+ if (st->stcf->class_name) {
+ /* Multi-class: exact class name match */
+ is_target_class = (strcmp(st->stcf->class_name, class_name) == 0);
+ }
+ else {
+ /* Legacy binary: map class_name to spam/ham */
+ if (strcmp(class_name, "spam") == 0 || strcmp(class_name, "S") == 0) {
+ is_target_class = st->stcf->is_spam;
+ }
+ else if (strcmp(class_name, "ham") == 0 || strcmp(class_name, "H") == 0) {
+ is_target_class = !st->stcf->is_spam;
+ }
+ }
+
+ if (is_target_class) {
+ /* Learning: increment the target class */
+ if (incrementing) {
+ tok->values[id] = 1;
+ }
+ else {
+ tok->values[id]++;
+ }
+ total_cnt += tok->values[id];
+
+ /* Track class counts for debugging */
+ if (num_classes > 0) {
+ for (unsigned int k = 0; k < num_classes; k++) {
+ const char *check_class = (const char *) g_ptr_array_index(ctx->cfg->class_names, k);
+ if (st->stcf->class_name && strcmp(st->stcf->class_name, check_class) == 0) {
+ class_counts[k] += tok->values[id];
+ class_statfiles[k] = st;
+ break;
+ }
+ }
+ }
+ }
+ else {
+ /* Unlearning: decrement other classes if unlearn flag is set */
+ if (tok->values[id] > 0 && unlearn) {
+ if (incrementing) {
+ tok->values[id] = -1;
+ }
+ else {
+ tok->values[id]--;
+ }
+ total_cnt += tok->values[id];
+
+ /* Track class counts for debugging */
+ if (num_classes > 0) {
+ for (unsigned int k = 0; k < num_classes; k++) {
+ const char *check_class = (const char *) g_ptr_array_index(ctx->cfg->class_names, k);
+ if (st->stcf->class_name && strcmp(st->stcf->class_name, check_class) == 0) {
+ class_counts[k] += tok->values[id];
+ class_statfiles[k] = st;
+ break;
+ }
+ }
+ }
+ }
+ else if (incrementing) {
+ tok->values[id] = 0;
+ }
+ }
+ }
+
+ /* Debug logging */
+ if (tok->t1 && tok->t2) {
+ if (num_classes > 0) {
+ GString *debug_str = g_string_new("");
+ for (unsigned int k = 0; k < num_classes; k++) {
+ const char *check_class = (const char *) g_ptr_array_index(ctx->cfg->class_names, k);
+ g_string_append_printf(debug_str, "%s:%d ", check_class, class_counts[k]);
+ }
+ msg_debug_bayes("token %uL <%*s:%*s>: window: %d, total_count: %d, "
+ "class_counts: %s",
+ tok->data,
+ (int) tok->t1->stemmed.len, tok->t1->stemmed.begin,
+ (int) tok->t2->stemmed.len, tok->t2->stemmed.begin,
+ tok->window_idx, total_cnt, debug_str->str);
+ g_string_free(debug_str, TRUE);
+ }
+ else {
+ msg_debug_bayes("token %uL <%*s:%*s>: window: %d, total_count: %d, "
+ "class: %s",
+ tok->data,
+ (int) tok->t1->stemmed.len, tok->t1->stemmed.begin,
+ (int) tok->t2->stemmed.len, tok->t2->stemmed.begin,
+ tok->window_idx, total_cnt, class_name);
+ }
+ }
+ else {
+ msg_debug_bayes("token %uL <?:?>: window: %d, total_count: %d, "
+ "class: %s",
+ tok->data,
+ tok->window_idx, total_cnt, class_name);
+ }
+ }
+
+ return TRUE;
+}