summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/libstat/classifiers/bayes.c277
-rw-r--r--src/libstat/classifiers/classifiers.h11
2 files changed, 100 insertions, 188 deletions
diff --git a/src/libstat/classifiers/bayes.c b/src/libstat/classifiers/bayes.c
index 6e068b79d..54d696721 100644
--- a/src/libstat/classifiers/bayes.c
+++ b/src/libstat/classifiers/bayes.c
@@ -39,63 +39,6 @@ bayes_error_quark (void)
return g_quark_from_static_string ("bayes-error");
}
-struct bayes_statfile_data {
- guint64 hits;
- guint64 total_hits;
- double value;
- struct rspamd_statfile_config *st;
-};
-
-struct bayes_callback_data {
- struct classifier_ctx *ctx;
- gboolean in_class;
- time_t now;
- struct bayes_statfile_data *statfiles;
- guint32 statfiles_num;
- guint64 total_spam;
- guint64 total_ham;
- guint64 processed_tokens;
- gsize max_tokens;
- double spam_probability;
- double ham_probability;
-};
-
-static gboolean
-bayes_learn_callback (gpointer key, gpointer value, gpointer data)
-{
- rspamd_token_t *node = key;
- struct bayes_callback_data *cd = data;
- gint c;
- guint64 v;
-
- c = (cd->in_class) ? 1 : -1;
-
- /* Consider that not found blocks have value 1 */
- /* XXX implement getting */
- if (v == 0 && c > 0) {
- /* XXX: add token to the backend */
- cd->processed_tokens++;
- }
- else if (v != 0) {
- if (G_LIKELY (c > 0)) {
- v++;
- }
- else if (c < 0) {
- if (v != 0) {
- v--;
- }
- }
- /* XXX: Implement setting */
- cd->processed_tokens++;
- }
-
- if (cd->max_tokens != 0 && cd->processed_tokens > cd->max_tokens) {
- /* Stop learning on max tokens */
- return TRUE;
- }
- return FALSE;
-}
-
/**
* Returns probability of chisquare > value with specified number of freedom
* degrees
@@ -142,45 +85,35 @@ inv_chi_square (gdouble value, gint freedom_deg)
static gboolean
bayes_classify_callback (gpointer key, gpointer value, gpointer data)
{
-
- rspamd_token_t *node = key;
- struct bayes_callback_data *cd = data;
+ rspamd_token_t *node = value;
+ struct rspamd_classifier_runtime *rt = (struct rspamd_classifier_runtime *)data;
guint i;
- struct bayes_statfile_data *cur;
+ struct rspamd_token_result *res;
guint64 spam_count = 0, ham_count = 0, total_count = 0;
double spam_prob, spam_freq, ham_freq, bayes_spam_prob;
- for (i = 0; i < cd->statfiles_num; i++) {
- cur = &cd->statfiles[i];
- /*
- * XXX: Implement getting
- */
- if (cur->value > 0) {
- cur->total_hits += cur->value;
- if (cur->st->is_spam) {
- spam_count += cur->value;
+ for (i = rt->start_pos; i < rt->end_pos; i++) {
+ res = &g_array_index (node->results, struct rspamd_token_result, i);
+
+ if (res->value > 0) {
+ if (res->st_runtime->st->is_spam) {
+ spam_count += res->value;
}
else {
- ham_count += cur->value;
+ ham_count += res->value;
}
- total_count += cur->value;
+ total_count += res->value;
}
}
/* Probability for this token */
if (total_count > 0) {
- spam_freq = ((double)spam_count / MAX (1., (double)cd->total_spam));
- ham_freq = ((double)ham_count / MAX (1., (double)cd->total_ham));
+ spam_freq = ((double)spam_count / MAX (1., (double)rt->total_spam));
+ ham_freq = ((double)ham_count / MAX (1., (double)rt->total_ham));
spam_prob = spam_freq / (spam_freq + ham_freq);
bayes_spam_prob = (0.5 + spam_prob * total_count) / (1. + total_count);
- cd->spam_probability += log (bayes_spam_prob);
- cd->ham_probability += log (1. - bayes_spam_prob);
- cd->processed_tokens++;
- }
-
- if (cd->max_tokens != 0 && cd->processed_tokens > cd->max_tokens) {
- /* Stop classifying on max tokens */
- return TRUE;
+ rt->spam_prob += log (bayes_spam_prob);
+ rt->ham_prob += log (1. - bayes_spam_prob);
}
return FALSE;
@@ -202,91 +135,53 @@ bayes_init (rspamd_mempool_t *pool, struct rspamd_classifier_config *cfg)
gboolean
bayes_classify (struct classifier_ctx * ctx,
GTree *input,
+ struct rspamd_classifier_runtime *rt,
struct rspamd_task *task)
{
- struct bayes_callback_data data;
- gchar *value;
- gint nodes, i = 0, selected_st = -1, cnt;
- gint minnodes;
- guint64 maxhits = 0, rev;
double final_prob, h, s;
- struct rspamd_statfile_config *st;
+ guint maxhits = 0;
+ struct rspamd_statfile_runtime *st, *selected_st = NULL;
GList *cur;
char *sumbuf;
g_assert (ctx != NULL);
+ g_assert (input != NULL);
+ g_assert (rt != NULL);
+ g_assert (rt->end_pos > rt->start_pos);
- if (ctx->cfg->opts &&
- (value = g_hash_table_lookup (ctx->cfg->opts, "min_tokens")) != NULL) {
- minnodes = strtol (value, NULL, 10);
- nodes = g_tree_nnodes (input);
+ g_tree_foreach (input, bayes_classify_callback, &rt);
- if (nodes < minnodes) {
- return FALSE;
- }
- }
-
- cur = ctx->cfg->statfiles;
-#if 0
- cur = rspamd_lua_call_cls_pre_callbacks (ctx->cfg, task, FALSE, FALSE, L);
- if (cur) {
- rspamd_mempool_add_destructor (task->task_pool,
- (rspamd_mempool_destruct_t)g_list_free, cur);
- }
- else {
- cur = ctx->cfg->statfiles;
- }
-#endif
-
-
- data.statfiles_num = g_list_length (cur);
- data.statfiles = g_new0 (struct bayes_statfile_data, data.statfiles_num);
- data.now = time (NULL);
- data.ctx = ctx;
-
- data.processed_tokens = 0;
- data.spam_probability = 0;
- data.ham_probability = 0;
- data.total_ham = 0;
- data.total_spam = 0;
- if (ctx->cfg->opts &&
- (value = g_hash_table_lookup (ctx->cfg->opts, "max_tokens")) != NULL) {
- minnodes = rspamd_config_parse_limit (value, -1);
- data.max_tokens = minnodes;
- }
- else {
- data.max_tokens = 0;
- }
-
- cnt = i;
-
- g_tree_foreach (input, bayes_classify_callback, &data);
-
- if (data.processed_tokens == 0 || data.spam_probability == 0) {
+ if (rt->spam_prob == 0) {
final_prob = 0;
}
else {
- h = 1 - inv_chi_square (-2. * data.spam_probability,
- 2 * data.processed_tokens);
- s = 1 - inv_chi_square (-2. * data.ham_probability,
- 2 * data.processed_tokens);
+ h = 1 - inv_chi_square (-2. * rt->spam_prob,
+ 2 * rt->processed_tokens);
+ s = 1 - inv_chi_square (-2. * rt->ham_prob,
+ 2 * rt->processed_tokens);
final_prob = (s + 1 - h) / 2.;
}
- if (data.processed_tokens > 0 && fabs (final_prob - 0.5) > 0.05) {
+ if (rt->processed_tokens > 0 && fabs (final_prob - 0.5) > 0.05) {
sumbuf = rspamd_mempool_alloc (task->task_pool, 32);
- for (i = 0; i < cnt; i++) {
- if ((final_prob > 0.5 && !data.statfiles[i].st->is_spam) ||
- (final_prob < 0.5 && data.statfiles[i].st->is_spam)) {
- continue;
- }
- if (data.statfiles[i].total_hits > maxhits) {
- maxhits = data.statfiles[i].total_hits;
- selected_st = i;
+ cur = g_list_first (rt->st_runtime);
+
+ while (cur) {
+ st = (struct rspamd_statfile_runtime *)cur->data;
+
+ if ((final_prob < 0.5 && !st->st->is_spam) ||
+ (final_prob > 0.5 && st->st->is_spam)) {
+ if (st->total_hits > maxhits) {
+ maxhits = st->total_hits;
+ selected_st = st;
+ }
}
+
+ cur = g_list_next (cur);
}
- if (selected_st == -1) {
+
+ if (selected_st == NULL) {
msg_err (
"unexpected classifier error: cannot select desired statfile");
}
@@ -298,65 +193,75 @@ bayes_classify (struct classifier_ctx * ctx,
rspamd_snprintf (sumbuf, 32, "%.2f%%", final_prob * 100.);
cur = g_list_prepend (NULL, sumbuf);
rspamd_task_insert_result (task,
- data.statfiles[selected_st].st->symbol,
+ selected_st->st->symbol,
final_prob,
cur);
}
}
- g_free (data.statfiles);
-
return TRUE;
}
+static gboolean
+bayes_learn_spam_callback (gpointer key, gpointer value, gpointer data)
+{
+ rspamd_token_t *node = value;
+ struct rspamd_token_result *res;
+ struct rspamd_classifier_runtime *rt = (struct rspamd_classifier_runtime *)data;
+ guint i;
+
+
+ for (i = rt->start_pos; i < rt->end_pos; i++) {
+ res = &g_array_index (node->results, struct rspamd_token_result, i);
+
+ if (res->st_runtime->st->is_spam) {
+ res->value ++;
+ }
+ }
+
+ return FALSE;
+}
+
+static gboolean
+bayes_learn_ham_callback (gpointer key, gpointer value, gpointer data)
+{
+ rspamd_token_t *node = value;
+ struct rspamd_token_result *res;
+ struct rspamd_classifier_runtime *rt = (struct rspamd_classifier_runtime *)data;
+ guint i;
+
+
+ for (i = rt->start_pos; i < rt->end_pos; i++) {
+ res = &g_array_index (node->results, struct rspamd_token_result, i);
+
+ if (!res->st_runtime->st->is_spam) {
+ res->value ++;
+ }
+ }
+
+ return FALSE;
+}
+
gboolean
bayes_learn_spam (struct classifier_ctx * ctx,
GTree *input,
+ struct rspamd_classifier_runtime *rt,
struct rspamd_task *task,
gboolean is_spam,
GError **err)
{
- struct bayes_callback_data data;
- gchar *value;
- gint nodes;
- gint minnodes;
- struct rspamd_statfile_config *st;
- GList *cur;
- gboolean skip_labels;
-
g_assert (ctx != NULL);
+ g_assert (input != NULL);
+ g_assert (rt != NULL);
+ g_assert (rt->end_pos > rt->start_pos);
- if (ctx->cfg->opts &&
- (value = g_hash_table_lookup (ctx->cfg->opts, "min_tokens")) != NULL) {
- minnodes = strtol (value, NULL, 10);
- nodes = g_tree_nnodes (input);
-
- if (nodes < minnodes) {
- g_set_error (err,
- bayes_error_quark (), /* error domain */
- 1, /* error code */
- "message contains too few tokens: %d, while min is %d",
- nodes, (int)minnodes);
- return FALSE;
- }
- }
-
- data.now = time (NULL);
- data.ctx = ctx;
- data.in_class = TRUE;
-
- data.processed_tokens = 0;
- if (ctx->cfg->opts &&
- (value = g_hash_table_lookup (ctx->cfg->opts, "max_tokens")) != NULL) {
- minnodes = rspamd_config_parse_limit (value, -1);
- data.max_tokens = minnodes;
+ if (is_spam) {
+ g_tree_foreach (input, bayes_learn_spam_callback, rt);
}
else {
- data.max_tokens = 0;
+ g_tree_foreach (input, bayes_learn_ham_callback, rt);
}
- g_tree_foreach (input, bayes_learn_callback, &data);
-
return TRUE;
}
diff --git a/src/libstat/classifiers/classifiers.h b/src/libstat/classifiers/classifiers.h
index e2bf57f81..9a30039df 100644
--- a/src/libstat/classifiers/classifiers.h
+++ b/src/libstat/classifiers/classifiers.h
@@ -18,14 +18,19 @@ struct classifier_ctx {
struct rspamd_classifier_config *cfg;
};
+struct token_node_s;
+struct rspamd_classifier_runtime;
+
struct rspamd_stat_classifier {
char *name;
struct classifier_ctx * (*init_func)(rspamd_mempool_t *pool,
struct rspamd_classifier_config *cf);
gboolean (*classify_func)(struct classifier_ctx * ctx,
- GTree *input, struct rspamd_task *task);
+ GTree *input, struct rspamd_classifier_runtime *rt,
+ struct rspamd_task *task);
gboolean (*learn_spam_func)(struct classifier_ctx * ctx,
- GTree *input, struct rspamd_task *task, gboolean is_spam,
+ GTree *input, struct rspamd_classifier_runtime *rt,
+ struct rspamd_task *task, gboolean is_spam,
GError **err);
};
@@ -34,9 +39,11 @@ struct classifier_ctx * bayes_init (rspamd_mempool_t *pool,
struct rspamd_classifier_config *cf);
gboolean bayes_classify (struct classifier_ctx * ctx,
GTree *input,
+ struct rspamd_classifier_runtime *rt,
struct rspamd_task *task);
gboolean bayes_learn_spam (struct classifier_ctx * ctx,
GTree *input,
+ struct rspamd_classifier_runtime *rt,
struct rspamd_task *task,
gboolean is_spam,
GError **err);