123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551 |
- /*-
- * Copyright 2016 Vsevolod Stakhov
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- /*
- * Bayesian classifier
- */
- #include "classifiers.h"
- #include "rspamd.h"
- #include "stat_internal.h"
- #include "math.h"
-
- #define msg_err_bayes(...) rspamd_default_log_function(G_LOG_LEVEL_CRITICAL, \
- "bayes", task->task_pool->tag.uid, \
- RSPAMD_LOG_FUNC, \
- __VA_ARGS__)
- #define msg_warn_bayes(...) rspamd_default_log_function(G_LOG_LEVEL_WARNING, \
- "bayes", task->task_pool->tag.uid, \
- RSPAMD_LOG_FUNC, \
- __VA_ARGS__)
- #define msg_info_bayes(...) rspamd_default_log_function(G_LOG_LEVEL_INFO, \
- "bayes", task->task_pool->tag.uid, \
- RSPAMD_LOG_FUNC, \
- __VA_ARGS__)
-
- INIT_LOG_MODULE_PUBLIC(bayes)
-
- static inline GQuark
- bayes_error_quark(void)
- {
- return g_quark_from_static_string("bayes-error");
- }
-
/**
 * Chi-square survival function used by Fisher's method.
 *
 * With m = -value, computes
 *     min(1, e^(-m) * sum_{i=0}^{freedom_deg-1} m^i / i!)
 * which equals P(X > 2m) for X ~ chi^2 with 2 * freedom_deg degrees of
 * freedom. The caller passes a sum of natural logarithms of per-token
 * probabilities, so `value` is normally <= 0.
 *
 * @param task task used only for debug logging
 * @param value sum of log-probabilities (i.e. -chi / 2)
 * @param freedom_deg number of summed terms (half the chi-square degrees of freedom)
 * @return probability in [0, 1]; 1.0 means no confidence, 0.0 full confidence.
 *         A NaN `value` propagates as NaN.
 */
static double
inv_chi_square(struct rspamd_task *task, double value, int freedom_deg)
{
	double prob, sum, m;
	int i;

	errno = 0;
	m = -value;
	prob = exp(value);

	/*
	 * exp() reports range errors through errno only when math_errhandling
	 * includes MATH_ERRNO (not the case with e.g. -fno-math-errno), so an
	 * overflow to infinity is also detected directly. Otherwise inf * m
	 * below would turn the sum into NaN.
	 */
	if (errno == ERANGE || isinf(prob)) {
		/*
		 * e^value underflowed (value is a large negative number: the result is
		 * essentially zero, i.e. very strong confidence) or overflowed (value
		 * is a large positive number: no confidence at all).
		 */
		msg_debug_bayes("exp overflow");

		if (value < 0) {
			return 0;
		}
		else {
			return 1.0;
		}
	}

	sum = prob;

	msg_debug_bayes("m: %f, probability: %g", m, prob);

	/*
	 * m is our confidence in the class and prob starts at e^value, which is
	 * small because value is normally negative. Each step turns the term
	 * e^(-m) * m^(i-1) / (i-1)! into e^(-m) * m^i / i!, and the terms are
	 * summed. The result ranges from 1.0 (no confidence) to 0.0
	 * (full confidence).
	 */
	for (i = 1; i < freedom_deg; i++) {
		prob *= m / (double) i;
		sum += prob;
		msg_debug_bayes("i=%d, probability: %g, sum: %g", i, prob, sum);
	}

	return MIN(1.0, sum);
}
-
/*
 * State accumulated while classifying the tokens of a single task.
 * bayes_classify() zero-initialises it and bayes_classify_token() updates it.
 */
struct bayes_task_closure {
	/* Running sums of log(probability) over the processed tokens */
	double ham_prob;
	double spam_prob;
	/* Probability of randomly skipping a meta token */
	double meta_skip_prob;

	/* Tokens that contributed to ham_prob/spam_prob */
	uint64_t processed_tokens;
	/* Sum of all positive per-statfile token values seen */
	uint64_t total_hits;
	/* Processed tokens that are not meta tokens */
	uint64_t text_tokens;

	/* Task being classified (needed by the logging macros) */
	struct rspamd_task *task;
};
-
/*
 * Weight of a non-unigram token, indexed by tok->window_idx modulo the table
 * size (8).
 *
 * Indices 1..3 hold k^k for k = 5, 4, 3. Index 4 holds 1. Every other index
 * (0, 5, 6, 7) is 0.
 */
static const double feature_weight[8] = {
	[1] = 3125, /* 5^5 */
	[2] = 256,  /* 4^4 */
	[3] = 27,   /* 3^3 */
	[4] = 1,
};

/*
 * Weighted average of an observed probability `prob`, backed by `cnt`
 * observations, and a prior `assumed` with weight `weight`:
 *     (cnt * prob + weight * assumed) / (cnt + weight)
 */
#define PROB_COMBINE(prob, cnt, weight, assumed) \
	(((cnt) * (prob) + (weight) * (assumed)) / ((cnt) + (weight)))
- /*
- * In this callback we calculate local probabilities for tokens
- */
- static void
- bayes_classify_token(struct rspamd_classifier *ctx,
- rspamd_token_t *tok, struct bayes_task_closure *cl)
- {
- unsigned int i;
- int id;
- unsigned int spam_count = 0, ham_count = 0, total_count = 0;
- struct rspamd_statfile *st;
- struct rspamd_task *task;
- const char *token_type = "txt";
- double spam_prob, spam_freq, ham_freq, bayes_spam_prob, bayes_ham_prob,
- ham_prob, fw, w, val;
-
- task = cl->task;
-
- #if 0
- if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_LUA_META) {
- /* Ignore lua metatokens for now */
- return;
- }
- #endif
-
- if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_META && cl->meta_skip_prob > 0) {
- val = rspamd_random_double_fast();
-
- if (val <= cl->meta_skip_prob) {
- if (tok->t1 && tok->t2) {
- msg_debug_bayes(
- "token(meta) %uL <%*s:%*s> probabilistically skipped",
- tok->data,
- (int) tok->t1->original.len, tok->t1->original.begin,
- (int) tok->t2->original.len, tok->t2->original.begin);
- }
-
- return;
- }
- }
-
- for (i = 0; i < ctx->statfiles_ids->len; i++) {
- id = g_array_index(ctx->statfiles_ids, int, i);
- st = g_ptr_array_index(ctx->ctx->statfiles, id);
- g_assert(st != NULL);
- val = tok->values[id];
-
- if (val > 0) {
- if (st->stcf->is_spam) {
- spam_count += val;
- }
- else {
- ham_count += val;
- }
-
- total_count += val;
- cl->total_hits += val;
- }
- }
-
- /* Probability for this token */
- if (total_count >= ctx->cfg->min_token_hits) {
- spam_freq = ((double) spam_count / MAX(1., (double) ctx->spam_learns));
- ham_freq = ((double) ham_count / MAX(1., (double) ctx->ham_learns));
- spam_prob = spam_freq / (spam_freq + ham_freq);
- ham_prob = ham_freq / (spam_freq + ham_freq);
-
- if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_UNIGRAM) {
- fw = 1.0;
- }
- else {
- fw = feature_weight[tok->window_idx %
- G_N_ELEMENTS(feature_weight)];
- }
-
-
- w = (fw * total_count) / (1.0 + fw * total_count);
-
- bayes_spam_prob = PROB_COMBINE(spam_prob, total_count, w, 0.5);
-
- if ((bayes_spam_prob > 0.5 && bayes_spam_prob < 0.5 + ctx->cfg->min_prob_strength) ||
- (bayes_spam_prob < 0.5 && bayes_spam_prob > 0.5 - ctx->cfg->min_prob_strength)) {
- msg_debug_bayes(
- "token %uL <%*s:%*s> skipped, probability not in range: %f",
- tok->data,
- (int) tok->t1->stemmed.len, tok->t1->stemmed.begin,
- (int) tok->t2->stemmed.len, tok->t2->stemmed.begin,
- bayes_spam_prob);
-
- return;
- }
-
- bayes_ham_prob = PROB_COMBINE(ham_prob, total_count, w, 0.5);
-
- cl->spam_prob += log(bayes_spam_prob);
- cl->ham_prob += log(bayes_ham_prob);
- cl->processed_tokens++;
-
- if (!(tok->flags & RSPAMD_STAT_TOKEN_FLAG_META)) {
- cl->text_tokens++;
- }
- else {
- token_type = "meta";
- }
-
- if (tok->t1 && tok->t2) {
- msg_debug_bayes("token(%s) %uL <%*s:%*s>: weight: %f, cf: %f, "
- "total_count: %ud, "
- "spam_count: %ud, ham_count: %ud,"
- "spam_prob: %.3f, ham_prob: %.3f, "
- "bayes_spam_prob: %.3f, bayes_ham_prob: %.3f, "
- "current spam probability: %.3f, current ham probability: %.3f",
- token_type,
- tok->data,
- (int) tok->t1->stemmed.len, tok->t1->stemmed.begin,
- (int) tok->t2->stemmed.len, tok->t2->stemmed.begin,
- fw, w, total_count, spam_count, ham_count,
- spam_prob, ham_prob,
- bayes_spam_prob, bayes_ham_prob,
- cl->spam_prob, cl->ham_prob);
- }
- else {
- msg_debug_bayes("token(%s) %uL <?:?>: weight: %f, cf: %f, "
- "total_count: %ud, "
- "spam_count: %ud, ham_count: %ud,"
- "spam_prob: %.3f, ham_prob: %.3f, "
- "bayes_spam_prob: %.3f, bayes_ham_prob: %.3f, "
- "current spam probability: %.3f, current ham probability: %.3f",
- token_type,
- tok->data,
- fw, w, total_count, spam_count, ham_count,
- spam_prob, ham_prob,
- bayes_spam_prob, bayes_ham_prob,
- cl->spam_prob, cl->ham_prob);
- }
- }
- }
-
-
- gboolean
- bayes_init(struct rspamd_config *cfg,
- struct ev_loop *ev_base,
- struct rspamd_classifier *cl)
- {
- cl->cfg->flags |= RSPAMD_FLAG_CLASSIFIER_INTEGER;
-
- return TRUE;
- }
-
/*
 * Classifier teardown hook. bayes_init() allocates nothing, so there is
 * nothing to release here.
 */
void
bayes_fin(struct rspamd_classifier *cl)
{
	(void) cl;
}
-
- gboolean
- bayes_classify(struct rspamd_classifier *ctx,
- GPtrArray *tokens,
- struct rspamd_task *task)
- {
- double final_prob, h, s, *pprob;
- char sumbuf[32];
- struct rspamd_statfile *st = NULL;
- struct bayes_task_closure cl;
- rspamd_token_t *tok;
- unsigned int i, text_tokens = 0;
- int id;
-
- g_assert(ctx != NULL);
- g_assert(tokens != NULL);
-
- memset(&cl, 0, sizeof(cl));
- cl.task = task;
-
- /* Check min learns */
- if (ctx->cfg->min_learns > 0) {
- if (ctx->ham_learns < ctx->cfg->min_learns) {
- msg_info_task("not classified as ham. The ham class needs more "
- "training samples. Currently: %ul; minimum %ud required",
- ctx->ham_learns, ctx->cfg->min_learns);
-
- return TRUE;
- }
- if (ctx->spam_learns < ctx->cfg->min_learns) {
- msg_info_task("not classified as spam. The spam class needs more "
- "training samples. Currently: %ul; minimum %ud required",
- ctx->spam_learns, ctx->cfg->min_learns);
-
- return TRUE;
- }
- }
-
- for (i = 0; i < tokens->len; i++) {
- tok = g_ptr_array_index(tokens, i);
- if (!(tok->flags & RSPAMD_STAT_TOKEN_FLAG_META)) {
- text_tokens++;
- }
- }
-
- if (text_tokens == 0) {
- msg_info_task("skipped classification as there are no text tokens. "
- "Total tokens: %ud",
- tokens->len);
-
- return TRUE;
- }
-
- /*
- * Skip some metatokens if we don't have enough text tokens
- */
- if (text_tokens > tokens->len - text_tokens) {
- cl.meta_skip_prob = 0.0;
- }
- else {
- cl.meta_skip_prob = 1.0 - text_tokens / tokens->len;
- }
-
- for (i = 0; i < tokens->len; i++) {
- tok = g_ptr_array_index(tokens, i);
-
- bayes_classify_token(ctx, tok, &cl);
- }
-
- if (cl.processed_tokens == 0) {
- msg_info_bayes("no tokens found in bayes database "
- "(%ud total tokens, %ud text tokens), ignore stats",
- tokens->len, text_tokens);
-
- return TRUE;
- }
-
- if (ctx->cfg->min_tokens > 0 &&
- cl.text_tokens < (int) (ctx->cfg->min_tokens * 0.1)) {
- msg_info_bayes("ignore bayes probability since we have "
- "found too few text tokens: %uL (of %ud checked), "
- "at least %d required",
- cl.text_tokens,
- text_tokens,
- (int) (ctx->cfg->min_tokens * 0.1));
-
- return TRUE;
- }
-
- if (cl.spam_prob > -300 && cl.ham_prob > -300) {
- /* Fisher value is low enough to apply inv_chi_square */
- h = 1 - inv_chi_square(task, cl.spam_prob, cl.processed_tokens);
- s = 1 - inv_chi_square(task, cl.ham_prob, cl.processed_tokens);
- }
- else {
- /* Use naive method */
- if (cl.spam_prob < cl.ham_prob) {
- h = (1.0 - exp(cl.spam_prob - cl.ham_prob)) /
- (1.0 + exp(cl.spam_prob - cl.ham_prob));
- s = 1.0 - h;
- }
- else {
- s = (1.0 - exp(cl.ham_prob - cl.spam_prob)) /
- (1.0 + exp(cl.ham_prob - cl.spam_prob));
- h = 1.0 - s;
- }
- }
-
- if (isfinite(s) && isfinite(h)) {
- final_prob = (s + 1.0 - h) / 2.;
- msg_debug_bayes(
- "got ham probability %.2f -> %.2f and spam probability %.2f -> %.2f,"
- " %L tokens processed of %ud total tokens;"
- " %uL text tokens found of %ud text tokens)",
- cl.ham_prob,
- h,
- cl.spam_prob,
- s,
- cl.processed_tokens,
- tokens->len,
- cl.text_tokens,
- text_tokens);
- }
- else {
- /*
- * We have some overflow, hence we need to check which class
- * is NaN
- */
- if (isfinite(h)) {
- final_prob = 1.0;
- msg_debug_bayes("spam class is full: no"
- " ham samples");
- }
- else if (isfinite(s)) {
- final_prob = 0.0;
- msg_debug_bayes("ham class is full: no"
- " spam samples");
- }
- else {
- final_prob = 0.5;
- msg_warn_bayes("spam and ham classes are both full");
- }
- }
-
- pprob = rspamd_mempool_alloc(task->task_pool, sizeof(*pprob));
- *pprob = final_prob;
- rspamd_mempool_set_variable(task->task_pool, "bayes_prob", pprob, NULL);
-
- if (cl.processed_tokens > 0 && fabs(final_prob - 0.5) > 0.05) {
- /* Now we can have exactly one HAM and exactly one SPAM statfiles per classifier */
- for (i = 0; i < ctx->statfiles_ids->len; i++) {
- id = g_array_index(ctx->statfiles_ids, int, i);
- st = g_ptr_array_index(ctx->ctx->statfiles, id);
-
- if (final_prob > 0.5 && st->stcf->is_spam) {
- break;
- }
- else if (final_prob < 0.5 && !st->stcf->is_spam) {
- break;
- }
- }
-
- /* Correctly scale HAM */
- if (final_prob < 0.5) {
- final_prob = 1.0 - final_prob;
- }
-
- /*
- * Bayes p is from 0.5 to 1.0, but confidence is from 0 to 1, so
- * we need to rescale it to display correctly
- */
- rspamd_snprintf(sumbuf, sizeof(sumbuf), "%.2f%%",
- (final_prob - 0.5) * 200.);
- final_prob = rspamd_normalize_probability(final_prob, 0.5);
- g_assert(st != NULL);
-
- if (final_prob > 1 || final_prob < 0) {
- msg_err_bayes("internal error: probability %f is outside of the "
- "allowed range [0..1]",
- final_prob);
-
- if (final_prob > 1) {
- final_prob = 1.0;
- }
- else {
- final_prob = 0.0;
- }
- }
-
- rspamd_task_insert_result(task,
- st->stcf->symbol,
- final_prob,
- sumbuf);
- }
-
- return TRUE;
- }
-
- gboolean
- bayes_learn_spam(struct rspamd_classifier *ctx,
- GPtrArray *tokens,
- struct rspamd_task *task,
- gboolean is_spam,
- gboolean unlearn,
- GError **err)
- {
- unsigned int i, j, total_cnt, spam_cnt, ham_cnt;
- int id;
- struct rspamd_statfile *st;
- rspamd_token_t *tok;
- gboolean incrementing;
-
- g_assert(ctx != NULL);
- g_assert(tokens != NULL);
-
- incrementing = ctx->cfg->flags & RSPAMD_FLAG_CLASSIFIER_INCREMENTING_BACKEND;
-
- for (i = 0; i < tokens->len; i++) {
- total_cnt = 0;
- spam_cnt = 0;
- ham_cnt = 0;
- tok = g_ptr_array_index(tokens, i);
-
- for (j = 0; j < ctx->statfiles_ids->len; j++) {
- id = g_array_index(ctx->statfiles_ids, int, j);
- st = g_ptr_array_index(ctx->ctx->statfiles, id);
- g_assert(st != NULL);
-
- if (!!st->stcf->is_spam == !!is_spam) {
- if (incrementing) {
- tok->values[id] = 1;
- }
- else {
- tok->values[id]++;
- }
-
- total_cnt += tok->values[id];
-
- if (st->stcf->is_spam) {
- spam_cnt += tok->values[id];
- }
- else {
- ham_cnt += tok->values[id];
- }
- }
- else {
- if (tok->values[id] > 0 && unlearn) {
- /* Unlearning */
- if (incrementing) {
- tok->values[id] = -1;
- }
- else {
- tok->values[id]--;
- }
-
- if (st->stcf->is_spam) {
- spam_cnt += tok->values[id];
- }
- else {
- ham_cnt += tok->values[id];
- }
- total_cnt += tok->values[id];
- }
- else if (incrementing) {
- tok->values[id] = 0;
- }
- }
- }
-
- if (tok->t1 && tok->t2) {
- msg_debug_bayes("token %uL <%*s:%*s>: window: %d, total_count: %d, "
- "spam_count: %d, ham_count: %d",
- tok->data,
- (int) tok->t1->stemmed.len, tok->t1->stemmed.begin,
- (int) tok->t2->stemmed.len, tok->t2->stemmed.begin,
- tok->window_idx, total_cnt, spam_cnt, ham_cnt);
- }
- else {
- msg_debug_bayes("token %uL <?:?>: window: %d, total_count: %d, "
- "spam_count: %d, ham_count: %d",
- tok->data,
- tok->window_idx, total_cnt, spam_cnt, ham_cnt);
- }
- }
-
- return TRUE;
- }
|