
bayes.c 14KB

/*-
 * Copyright 2016 Vsevolod Stakhov
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
/*
 * Bayesian classifier
 */
#include "classifiers.h"
#include "rspamd.h"
#include "stat_internal.h"
#include "math.h"

#define msg_err_bayes(...) rspamd_default_log_function(G_LOG_LEVEL_CRITICAL, \
        "bayes", task->task_pool->tag.uid, \
        RSPAMD_LOG_FUNC, \
        __VA_ARGS__)
#define msg_warn_bayes(...) rspamd_default_log_function(G_LOG_LEVEL_WARNING, \
        "bayes", task->task_pool->tag.uid, \
        RSPAMD_LOG_FUNC, \
        __VA_ARGS__)
#define msg_info_bayes(...) rspamd_default_log_function(G_LOG_LEVEL_INFO, \
        "bayes", task->task_pool->tag.uid, \
        RSPAMD_LOG_FUNC, \
        __VA_ARGS__)

INIT_LOG_MODULE_PUBLIC(bayes)

static inline GQuark
bayes_error_quark(void)
{
    return g_quark_from_static_string("bayes-error");
}

/**
 * Returns probability of chisquare > value with specified number of freedom
 * degrees
 * @param value value to test
 * @param freedom_deg number of degrees of freedom
 * @return
 */
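/*
 * A worked sketch of the math behind the loop below (standard Fisher's
 * method): for an even number of degrees of freedom 2n, the chi-square
 * upper tail has the closed form
 *
 *   P(chi^2_{2n} > 2x) = e^{-x} * sum_{i=0}^{n-1} x^i / i!
 *
 * The caller passes value = sum of log-probabilities (a negative number),
 * so with m = -value the loop accumulates
 * e^{value} * (1 + m + m^2/2! + ... + m^{n-1}/(n-1)!), which matches the
 * series above with x = m and n = freedom_deg.
 */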
static double
inv_chi_square(struct rspamd_task *task, double value, int freedom_deg)
{
    double prob, sum, m;
    int i;

    errno = 0;
    m = -value;
    prob = exp(value);

    if (errno == ERANGE) {
        /*
         * e^x where x is a large *NEGATIVE* number is OK, so we have a very
         * strong confidence that inv-chi-square is close to zero
         */
        msg_debug_bayes("exp overflow");

        if (value < 0) {
            return 0;
        }
        else {
            return 1.0;
        }
    }

    sum = prob;

    msg_debug_bayes("m: %f, probability: %g", m, prob);

    /*
     * m is our confidence in class
     * prob is e ^ x (a small value, since x is normally less than zero)
     * So we integrate over degrees of freedom and produce the total result
     * from 1.0 (no confidence) to 0.0 (full confidence)
     */
    for (i = 1; i < freedom_deg; i++) {
        prob *= m / (double) i;
        sum += prob;
        msg_debug_bayes("i=%d, probability: %g, sum: %g", i, prob, sum);
    }

    return MIN(1.0, sum);
}
struct bayes_task_closure {
    double ham_prob;
    double spam_prob;
    double meta_skip_prob;
    uint64_t processed_tokens;
    uint64_t total_hits;
    uint64_t text_tokens;
    struct rspamd_task *task;
};

/*
 * Mathematically we use pow(complexity, complexity), where complexity is the
 * window index
 */
static const double feature_weight[] = {0, 3125, 256, 27, 1, 0, 0, 0};
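/*
 * For illustration, the non-zero entries above follow pow(n, n) with n
 * shrinking as the window index grows: 3125 = 5^5, 256 = 4^4, 27 = 3^3,
 * then 1 for the widest window; index 0 and indices above 4 get zero weight.
 */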
#define PROB_COMBINE(prob, cnt, weight, assumed) (((weight) * (assumed) + (cnt) * (prob)) / ((weight) + (cnt)))
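/*
 * PROB_COMBINE is a weighted average of the observed token probability and
 * an assumed prior (0.5 everywhere in this file). As a rough example, a
 * token seen cnt = 3 times with prob = 0.9 and weight = 1 combines to
 * (1 * 0.5 + 3 * 0.9) / (1 + 3) = 0.8, so rarely seen tokens are pulled
 * towards the neutral 0.5 while frequent ones keep their observed value.
 */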
/*
 * In this callback we calculate local probabilities for tokens
 */
static void
bayes_classify_token(struct rspamd_classifier *ctx,
                     rspamd_token_t *tok, struct bayes_task_closure *cl)
{
    unsigned int i;
    int id;
    unsigned int spam_count = 0, ham_count = 0, total_count = 0;
    struct rspamd_statfile *st;
    struct rspamd_task *task;
    const char *token_type = "txt";
    double spam_prob, spam_freq, ham_freq, bayes_spam_prob, bayes_ham_prob,
        ham_prob, fw, w, val;

    task = cl->task;

#if 0
    if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_LUA_META) {
        /* Ignore lua metatokens for now */
        return;
    }
#endif

    if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_META && cl->meta_skip_prob > 0) {
        val = rspamd_random_double_fast();

        if (val <= cl->meta_skip_prob) {
            if (tok->t1 && tok->t2) {
                msg_debug_bayes(
                    "token(meta) %uL <%*s:%*s> probabilistically skipped",
                    tok->data,
                    (int) tok->t1->original.len, tok->t1->original.begin,
                    (int) tok->t2->original.len, tok->t2->original.begin);
            }

            return;
        }
    }

    for (i = 0; i < ctx->statfiles_ids->len; i++) {
        id = g_array_index(ctx->statfiles_ids, int, i);
        st = g_ptr_array_index(ctx->ctx->statfiles, id);
        g_assert(st != NULL);
        val = tok->values[id];

        if (val > 0) {
            if (st->stcf->is_spam) {
                spam_count += val;
            }
            else {
                ham_count += val;
            }

            total_count += val;
            cl->total_hits += val;
        }
    }

    /* Probability for this token */
    if (total_count >= ctx->cfg->min_token_hits) {
        spam_freq = ((double) spam_count / MAX(1., (double) ctx->spam_learns));
        ham_freq = ((double) ham_count / MAX(1., (double) ctx->ham_learns));
        spam_prob = spam_freq / (spam_freq + ham_freq);
        ham_prob = ham_freq / (spam_freq + ham_freq);

        if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_UNIGRAM) {
            fw = 1.0;
        }
        else {
            fw = feature_weight[tok->window_idx %
                                G_N_ELEMENTS(feature_weight)];
        }

        w = (fw * total_count) / (1.0 + fw * total_count);
        bayes_spam_prob = PROB_COMBINE(spam_prob, total_count, w, 0.5);

        if ((bayes_spam_prob > 0.5 && bayes_spam_prob < 0.5 + ctx->cfg->min_prob_strength) ||
            (bayes_spam_prob < 0.5 && bayes_spam_prob > 0.5 - ctx->cfg->min_prob_strength)) {
            msg_debug_bayes(
                "token %uL <%*s:%*s> skipped, probability not in range: %f",
                tok->data,
                (int) tok->t1->stemmed.len, tok->t1->stemmed.begin,
                (int) tok->t2->stemmed.len, tok->t2->stemmed.begin,
                bayes_spam_prob);

            return;
        }

        bayes_ham_prob = PROB_COMBINE(ham_prob, total_count, w, 0.5);

        cl->spam_prob += log(bayes_spam_prob);
        cl->ham_prob += log(bayes_ham_prob);
        cl->processed_tokens++;

        if (!(tok->flags & RSPAMD_STAT_TOKEN_FLAG_META)) {
            cl->text_tokens++;
        }
        else {
            token_type = "meta";
        }

        if (tok->t1 && tok->t2) {
            msg_debug_bayes("token(%s) %uL <%*s:%*s>: weight: %f, cf: %f, "
                            "total_count: %ud, "
                            "spam_count: %ud, ham_count: %ud,"
                            "spam_prob: %.3f, ham_prob: %.3f, "
                            "bayes_spam_prob: %.3f, bayes_ham_prob: %.3f, "
                            "current spam probability: %.3f, current ham probability: %.3f",
                            token_type,
                            tok->data,
                            (int) tok->t1->stemmed.len, tok->t1->stemmed.begin,
                            (int) tok->t2->stemmed.len, tok->t2->stemmed.begin,
                            fw, w, total_count, spam_count, ham_count,
                            spam_prob, ham_prob,
                            bayes_spam_prob, bayes_ham_prob,
                            cl->spam_prob, cl->ham_prob);
        }
        else {
            msg_debug_bayes("token(%s) %uL <?:?>: weight: %f, cf: %f, "
                            "total_count: %ud, "
                            "spam_count: %ud, ham_count: %ud,"
                            "spam_prob: %.3f, ham_prob: %.3f, "
                            "bayes_spam_prob: %.3f, bayes_ham_prob: %.3f, "
                            "current spam probability: %.3f, current ham probability: %.3f",
                            token_type,
                            tok->data,
                            fw, w, total_count, spam_count, ham_count,
                            spam_prob, ham_prob,
                            bayes_spam_prob, bayes_ham_prob,
                            cl->spam_prob, cl->ham_prob);
        }
    }
}
gboolean
bayes_init(struct rspamd_config *cfg,
           struct ev_loop *ev_base,
           struct rspamd_classifier *cl)
{
    cl->cfg->flags |= RSPAMD_FLAG_CLASSIFIER_INTEGER;

    return TRUE;
}

void bayes_fin(struct rspamd_classifier *cl)
{
}
gboolean
bayes_classify(struct rspamd_classifier *ctx,
               GPtrArray *tokens,
               struct rspamd_task *task)
{
    double final_prob, h, s, *pprob;
    char sumbuf[32];
    struct rspamd_statfile *st = NULL;
    struct bayes_task_closure cl;
    rspamd_token_t *tok;
    unsigned int i, text_tokens = 0;
    int id;

    g_assert(ctx != NULL);
    g_assert(tokens != NULL);

    memset(&cl, 0, sizeof(cl));
    cl.task = task;

    /* Check min learns */
    if (ctx->cfg->min_learns > 0) {
        if (ctx->ham_learns < ctx->cfg->min_learns) {
            msg_info_task("not classified as ham. The ham class needs more "
                          "training samples. Currently: %ul; minimum %ud required",
                          ctx->ham_learns, ctx->cfg->min_learns);

            return TRUE;
        }
        if (ctx->spam_learns < ctx->cfg->min_learns) {
            msg_info_task("not classified as spam. The spam class needs more "
                          "training samples. Currently: %ul; minimum %ud required",
                          ctx->spam_learns, ctx->cfg->min_learns);

            return TRUE;
        }
    }

    for (i = 0; i < tokens->len; i++) {
        tok = g_ptr_array_index(tokens, i);

        if (!(tok->flags & RSPAMD_STAT_TOKEN_FLAG_META)) {
            text_tokens++;
        }
    }

    if (text_tokens == 0) {
        msg_info_task("skipped classification as there are no text tokens. "
                      "Total tokens: %ud",
                      tokens->len);

        return TRUE;
    }

    /*
     * Skip some metatokens if we don't have enough text tokens
     */
    if (text_tokens > tokens->len - text_tokens) {
        cl.meta_skip_prob = 0.0;
    }
    else {
        cl.meta_skip_prob = 1.0 - text_tokens / tokens->len;
    }
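    /*
     * meta_skip_prob is consumed in bayes_classify_token(): each metatoken is
     * dropped with this probability, so metatokens only contribute fully when
     * text tokens form the majority of the token set.
     */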
    for (i = 0; i < tokens->len; i++) {
        tok = g_ptr_array_index(tokens, i);

        bayes_classify_token(ctx, tok, &cl);
    }

    if (cl.processed_tokens == 0) {
        msg_info_bayes("no tokens found in bayes database "
                       "(%ud total tokens, %ud text tokens), ignore stats",
                       tokens->len, text_tokens);

        return TRUE;
    }

    if (ctx->cfg->min_tokens > 0 &&
        cl.text_tokens < (int) (ctx->cfg->min_tokens * 0.1)) {
        msg_info_bayes("ignore bayes probability since we have "
                       "found too few text tokens: %uL (of %ud checked), "
                       "at least %d required",
                       cl.text_tokens,
                       text_tokens,
                       (int) (ctx->cfg->min_tokens * 0.1));

        return TRUE;
    }

    if (cl.spam_prob > -300 && cl.ham_prob > -300) {
        /* Fisher value is low enough to apply inv_chi_square */
        h = 1 - inv_chi_square(task, cl.spam_prob, cl.processed_tokens);
        s = 1 - inv_chi_square(task, cl.ham_prob, cl.processed_tokens);
    }
    else {
        /* Use naive method */
        if (cl.spam_prob < cl.ham_prob) {
            h = (1.0 - exp(cl.spam_prob - cl.ham_prob)) /
                (1.0 + exp(cl.spam_prob - cl.ham_prob));
            s = 1.0 - h;
        }
        else {
            s = (1.0 - exp(cl.ham_prob - cl.spam_prob)) /
                (1.0 + exp(cl.ham_prob - cl.spam_prob));
            h = 1.0 - s;
        }
    }
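    /*
     * Roughly speaking, h and s act as "hamminess" and "spamminess"
     * indicators in [0, 1]: h approaches 1 when the spam log-sum is strongly
     * negative (little evidence for spam), and s approaches 1 when the ham
     * log-sum is. The symmetric combination below, (s + 1 - h) / 2, stays at
     * 0.5 when the two cancel out and moves towards 1.0 for spam or 0.0 for
     * ham as one side dominates.
     */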
    if (isfinite(s) && isfinite(h)) {
        final_prob = (s + 1.0 - h) / 2.;
        msg_debug_bayes(
            "got ham probability %.2f -> %.2f and spam probability %.2f -> %.2f,"
            " %L tokens processed of %ud total tokens;"
            " %uL text tokens found of %ud text tokens)",
            cl.ham_prob,
            h,
            cl.spam_prob,
            s,
            cl.processed_tokens,
            tokens->len,
            cl.text_tokens,
            text_tokens);
    }
    else {
        /*
         * We have some overflow, hence we need to check which class
         * is NaN
         */
        if (isfinite(h)) {
            final_prob = 1.0;
            msg_debug_bayes("spam class is full: no"
                            " ham samples");
        }
        else if (isfinite(s)) {
            final_prob = 0.0;
            msg_debug_bayes("ham class is full: no"
                            " spam samples");
        }
        else {
            final_prob = 0.5;
            msg_warn_bayes("spam and ham classes are both full");
        }
    }

    pprob = rspamd_mempool_alloc(task->task_pool, sizeof(*pprob));
    *pprob = final_prob;
    rspamd_mempool_set_variable(task->task_pool, "bayes_prob", pprob, NULL);

    if (cl.processed_tokens > 0 && fabs(final_prob - 0.5) > 0.05) {
        /* Now we can have exactly one HAM and exactly one SPAM statfile per classifier */
        for (i = 0; i < ctx->statfiles_ids->len; i++) {
            id = g_array_index(ctx->statfiles_ids, int, i);
            st = g_ptr_array_index(ctx->ctx->statfiles, id);

            if (final_prob > 0.5 && st->stcf->is_spam) {
                break;
            }
            else if (final_prob < 0.5 && !st->stcf->is_spam) {
                break;
            }
        }

        /* Correctly scale HAM */
        if (final_prob < 0.5) {
            final_prob = 1.0 - final_prob;
        }

        /*
         * Bayes p is from 0.5 to 1.0, but confidence is from 0 to 1, so
         * we need to rescale it to display correctly
         */
        rspamd_snprintf(sumbuf, sizeof(sumbuf), "%.2f%%",
                        (final_prob - 0.5) * 200.);
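        /*
         * For example, a winning probability of 0.75 is displayed as
         * (0.75 - 0.5) * 200 = 50.00%, and 1.0 maps to 100.00%.
         */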
        final_prob = rspamd_normalize_probability(final_prob, 0.5);
        g_assert(st != NULL);

        if (final_prob > 1 || final_prob < 0) {
            msg_err_bayes("internal error: probability %f is outside of the "
                          "allowed range [0..1]",
                          final_prob);

            if (final_prob > 1) {
                final_prob = 1.0;
            }
            else {
                final_prob = 0.0;
            }
        }

        rspamd_task_insert_result(task,
                                  st->stcf->symbol,
                                  final_prob,
                                  sumbuf);
    }

    return TRUE;
}
gboolean
bayes_learn_spam(struct rspamd_classifier *ctx,
                 GPtrArray *tokens,
                 struct rspamd_task *task,
                 gboolean is_spam,
                 gboolean unlearn,
                 GError **err)
{
    unsigned int i, j, total_cnt, spam_cnt, ham_cnt;
    int id;
    struct rspamd_statfile *st;
    rspamd_token_t *tok;
    gboolean incrementing;

    g_assert(ctx != NULL);
    g_assert(tokens != NULL);

    incrementing = ctx->cfg->flags & RSPAMD_FLAG_CLASSIFIER_INCREMENTING_BACKEND;
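    /*
     * With an incrementing backend the per-token values written below appear
     * to be deltas (+1, -1 or 0) that the backend applies on its own, while
     * for other backends the counters in tok->values are adjusted in place.
     */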
    for (i = 0; i < tokens->len; i++) {
        total_cnt = 0;
        spam_cnt = 0;
        ham_cnt = 0;
        tok = g_ptr_array_index(tokens, i);

        for (j = 0; j < ctx->statfiles_ids->len; j++) {
            id = g_array_index(ctx->statfiles_ids, int, j);
            st = g_ptr_array_index(ctx->ctx->statfiles, id);
            g_assert(st != NULL);

            if (!!st->stcf->is_spam == !!is_spam) {
                if (incrementing) {
                    tok->values[id] = 1;
                }
                else {
                    tok->values[id]++;
                }

                total_cnt += tok->values[id];

                if (st->stcf->is_spam) {
                    spam_cnt += tok->values[id];
                }
                else {
                    ham_cnt += tok->values[id];
                }
            }
            else {
                if (tok->values[id] > 0 && unlearn) {
                    /* Unlearning */
                    if (incrementing) {
                        tok->values[id] = -1;
                    }
                    else {
                        tok->values[id]--;
                    }

                    if (st->stcf->is_spam) {
                        spam_cnt += tok->values[id];
                    }
                    else {
                        ham_cnt += tok->values[id];
                    }
                    total_cnt += tok->values[id];
                }
                else if (incrementing) {
                    tok->values[id] = 0;
                }
            }
        }

        if (tok->t1 && tok->t2) {
            msg_debug_bayes("token %uL <%*s:%*s>: window: %d, total_count: %d, "
                            "spam_count: %d, ham_count: %d",
                            tok->data,
                            (int) tok->t1->stemmed.len, tok->t1->stemmed.begin,
                            (int) tok->t2->stemmed.len, tok->t2->stemmed.begin,
                            tok->window_idx, total_cnt, spam_cnt, ham_cnt);
        }
        else {
            msg_debug_bayes("token %uL <?:?>: window: %d, total_count: %d, "
                            "spam_count: %d, ham_count: %d",
                            tok->data,
                            tok->window_idx, total_cnt, spam_cnt, ham_cnt);
        }
    }

    return TRUE;
}