You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

stat_process.c 32KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263
  1. /*
  2. * Copyright 2024 Vsevolod Stakhov
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "config.h"
  17. #include "stat_api.h"
  18. #include "rspamd.h"
  19. #include "stat_internal.h"
  20. #include "libmime/message.h"
  21. #include "libmime/images.h"
  22. #include "libserver/html/html.h"
  23. #include "lua/lua_common.h"
  24. #include "lua/lua_classnames.h"
  25. #include "libserver/mempool_vars_internal.h"
  26. #include "utlist.h"
  27. #include <math.h>
  /*
   * Operation codes.
   * NOTE(review): none of these is referenced in the visible part of this
   * file; presumably they distinguish classify/learn/unlearn requests - confirm.
   */
  #define RSPAMD_CLASSIFY_OP 0
  #define RSPAMD_LEARN_OP 1
  #define RSPAMD_UNLEARN_OP 2
  /*
   * Percentage compared against (1.0 - parts_distance) * 100.0 in
   * rspamd_stat_process_tokenize(); once exceeded, the remaining text parts
   * are not tokenized.
   */
  static const double similarity_threshold = 80.0;
  32. static void
  33. rspamd_stat_tokenize_parts_metadata(struct rspamd_stat_ctx *st_ctx,
  34. struct rspamd_task *task)
  35. {
  36. GArray *ar;
  37. rspamd_stat_token_t elt;
  38. unsigned int i;
  39. lua_State *L = task->cfg->lua_state;
  40. ar = g_array_sized_new(FALSE, FALSE, sizeof(elt), 16);
  41. memset(&elt, 0, sizeof(elt));
  42. elt.flags = RSPAMD_STAT_TOKEN_FLAG_META;
  43. if (st_ctx->lua_stat_tokens_ref != -1) {
  44. int err_idx, ret;
  45. struct rspamd_task **ptask;
  46. lua_pushcfunction(L, &rspamd_lua_traceback);
  47. err_idx = lua_gettop(L);
  48. lua_rawgeti(L, LUA_REGISTRYINDEX, st_ctx->lua_stat_tokens_ref);
  49. ptask = lua_newuserdata(L, sizeof(*ptask));
  50. *ptask = task;
  51. rspamd_lua_setclass(L, rspamd_task_classname, -1);
  52. if ((ret = lua_pcall(L, 1, 1, err_idx)) != 0) {
  53. msg_err_task("call to stat_tokens lua "
  54. "script failed (%d): %s",
  55. ret, lua_tostring(L, -1));
  56. }
  57. else {
  58. if (lua_type(L, -1) != LUA_TTABLE) {
  59. msg_err_task("stat_tokens invocation must return "
  60. "table and not %s",
  61. lua_typename(L, lua_type(L, -1)));
  62. }
  63. else {
  64. unsigned int vlen;
  65. rspamd_ftok_t tok;
  66. vlen = rspamd_lua_table_size(L, -1);
  67. for (i = 0; i < vlen; i++) {
  68. lua_rawgeti(L, -1, i + 1);
  69. tok.begin = lua_tolstring(L, -1, &tok.len);
  70. if (tok.begin && tok.len > 0) {
  71. elt.original.begin =
  72. rspamd_mempool_ftokdup(task->task_pool, &tok);
  73. elt.original.len = tok.len;
  74. elt.stemmed.begin = elt.original.begin;
  75. elt.stemmed.len = elt.original.len;
  76. elt.normalized.begin = elt.original.begin;
  77. elt.normalized.len = elt.original.len;
  78. g_array_append_val(ar, elt);
  79. }
  80. lua_pop(L, 1);
  81. }
  82. }
  83. }
  84. lua_settop(L, 0);
  85. }
  86. if (ar->len > 0) {
  87. st_ctx->tokenizer->tokenize_func(st_ctx,
  88. task,
  89. ar,
  90. TRUE,
  91. "M",
  92. task->tokens);
  93. }
  94. rspamd_mempool_add_destructor(task->task_pool,
  95. rspamd_array_free_hard, ar);
  96. }
  97. /*
  98. * Tokenize task using the tokenizer specified
  99. */
  100. void rspamd_stat_process_tokenize(struct rspamd_stat_ctx *st_ctx,
  101. struct rspamd_task *task)
  102. {
  103. struct rspamd_mime_text_part *part;
  104. rspamd_cryptobox_hash_state_t hst;
  105. rspamd_token_t *st_tok;
  106. unsigned int i, reserved_len = 0;
  107. double *pdiff;
  108. unsigned char hout[rspamd_cryptobox_HASHBYTES];
  109. char *b32_hout;
  110. if (st_ctx == NULL) {
  111. st_ctx = rspamd_stat_get_ctx();
  112. }
  113. g_assert(st_ctx != NULL);
  114. PTR_ARRAY_FOREACH(MESSAGE_FIELD(task, text_parts), i, part)
  115. {
  116. if (!IS_TEXT_PART_EMPTY(part) && part->utf_words != NULL) {
  117. reserved_len += part->utf_words->len;
  118. }
  119. /* XXX: normal window size */
  120. reserved_len += 5;
  121. }
  122. task->tokens = g_ptr_array_sized_new(reserved_len);
  123. rspamd_mempool_add_destructor(task->task_pool,
  124. rspamd_ptr_array_free_hard, task->tokens);
  125. rspamd_mempool_notify_alloc(task->task_pool, reserved_len * sizeof(gpointer));
  126. pdiff = rspamd_mempool_get_variable(task->task_pool, "parts_distance");
  127. PTR_ARRAY_FOREACH(MESSAGE_FIELD(task, text_parts), i, part)
  128. {
  129. if (!IS_TEXT_PART_EMPTY(part) && part->utf_words != NULL) {
  130. st_ctx->tokenizer->tokenize_func(st_ctx, task,
  131. part->utf_words, IS_TEXT_PART_UTF(part),
  132. NULL, task->tokens);
  133. }
  134. if (pdiff != NULL && (1.0 - *pdiff) * 100.0 > similarity_threshold) {
  135. msg_debug_bayes("message has two common parts (%.2f), so skip the last one",
  136. *pdiff);
  137. break;
  138. }
  139. }
  140. if (task->meta_words != NULL) {
  141. st_ctx->tokenizer->tokenize_func(st_ctx,
  142. task,
  143. task->meta_words,
  144. TRUE,
  145. "SUBJECT",
  146. task->tokens);
  147. }
  148. rspamd_stat_tokenize_parts_metadata(st_ctx, task);
  149. /* Produce signature */
  150. rspamd_cryptobox_hash_init(&hst, NULL, 0);
  151. PTR_ARRAY_FOREACH(task->tokens, i, st_tok)
  152. {
  153. rspamd_cryptobox_hash_update(&hst, (unsigned char *) &st_tok->data,
  154. sizeof(st_tok->data));
  155. }
  156. rspamd_cryptobox_hash_final(&hst, hout);
  157. b32_hout = rspamd_encode_base32(hout, sizeof(hout), RSPAMD_BASE32_DEFAULT);
  158. /*
  159. * We need to strip it to 32 characters providing ~160 bits of
  160. * hash distribution
  161. */
  162. b32_hout[32] = '\0';
  163. rspamd_mempool_set_variable(task->task_pool, RSPAMD_MEMPOOL_STAT_SIGNATURE,
  164. b32_hout, g_free);
  165. }
  166. static gboolean
  167. rspamd_stat_classifier_is_skipped(struct rspamd_task *task,
  168. struct rspamd_classifier *cl, gboolean is_learn, gboolean is_spam)
  169. {
  170. GList *cur = is_learn ? cl->cfg->learn_conditions : cl->cfg->classify_conditions;
  171. lua_State *L = task->cfg->lua_state;
  172. gboolean ret = FALSE;
  173. while (cur) {
  174. int cb_ref = GPOINTER_TO_INT(cur->data);
  175. int old_top = lua_gettop(L);
  176. int nargs;
  177. lua_rawgeti(L, LUA_REGISTRYINDEX, cb_ref);
  178. /* Push task and two booleans: is_spam and is_unlearn */
  179. struct rspamd_task **ptask = lua_newuserdata(L, sizeof(*ptask));
  180. *ptask = task;
  181. rspamd_lua_setclass(L, rspamd_task_classname, -1);
  182. if (is_learn) {
  183. lua_pushboolean(L, is_spam);
  184. lua_pushboolean(L,
  185. task->flags & RSPAMD_TASK_FLAG_UNLEARN ? true : false);
  186. nargs = 3;
  187. }
  188. else {
  189. nargs = 1;
  190. }
  191. if (lua_pcall(L, nargs, LUA_MULTRET, 0) != 0) {
  192. msg_err_task("call to %s failed: %s",
  193. "condition callback",
  194. lua_tostring(L, -1));
  195. }
  196. else {
  197. if (lua_isboolean(L, 1)) {
  198. if (!lua_toboolean(L, 1)) {
  199. ret = TRUE;
  200. }
  201. }
  202. if (lua_isstring(L, 2)) {
  203. if (ret) {
  204. msg_notice_task("%s condition for classifier %s returned: %s; skip classifier",
  205. is_learn ? "learn" : "classify", cl->cfg->name,
  206. lua_tostring(L, 2));
  207. }
  208. else {
  209. msg_info_task("%s condition for classifier %s returned: %s",
  210. is_learn ? "learn" : "classify", cl->cfg->name,
  211. lua_tostring(L, 2));
  212. }
  213. }
  214. else if (ret) {
  215. msg_notice_task("%s condition for classifier %s returned false; skip classifier",
  216. is_learn ? "learn" : "classify", cl->cfg->name);
  217. }
  218. if (ret) {
  219. lua_settop(L, old_top);
  220. break;
  221. }
  222. }
  223. lua_settop(L, old_top);
  224. cur = g_list_next(cur);
  225. }
  226. return ret;
  227. }
  228. static void
  229. rspamd_stat_preprocess(struct rspamd_stat_ctx *st_ctx,
  230. struct rspamd_task *task, gboolean is_learn, gboolean is_spam)
  231. {
  232. unsigned int i;
  233. struct rspamd_statfile *st;
  234. gpointer bk_run;
  235. if (task->tokens == NULL) {
  236. rspamd_stat_process_tokenize(st_ctx, task);
  237. }
  238. task->stat_runtimes = g_ptr_array_sized_new(st_ctx->statfiles->len);
  239. g_ptr_array_set_size(task->stat_runtimes, st_ctx->statfiles->len);
  240. rspamd_mempool_add_destructor(task->task_pool,
  241. rspamd_ptr_array_free_hard, task->stat_runtimes);
  242. /* Temporary set all stat_runtimes to some max size to distinguish from NULL */
  243. for (i = 0; i < st_ctx->statfiles->len; i++) {
  244. g_ptr_array_index(task->stat_runtimes, i) = GSIZE_TO_POINTER(G_MAXSIZE);
  245. }
  246. for (i = 0; i < st_ctx->classifiers->len; i++) {
  247. struct rspamd_classifier *cl = g_ptr_array_index(st_ctx->classifiers, i);
  248. gboolean skip_classifier = FALSE;
  249. if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) {
  250. skip_classifier = TRUE;
  251. }
  252. else {
  253. if (rspamd_stat_classifier_is_skipped(task, cl, is_learn, is_spam)) {
  254. skip_classifier = TRUE;
  255. }
  256. }
  257. if (skip_classifier) {
  258. /* Set NULL for all statfiles indexed by id */
  259. for (int j = 0; j < cl->statfiles_ids->len; j++) {
  260. int id = g_array_index(cl->statfiles_ids, int, j);
  261. g_ptr_array_index(task->stat_runtimes, id) = NULL;
  262. }
  263. }
  264. }
  265. for (i = 0; i < st_ctx->statfiles->len; i++) {
  266. st = g_ptr_array_index(st_ctx->statfiles, i);
  267. g_assert(st != NULL);
  268. if (g_ptr_array_index(task->stat_runtimes, i) == NULL) {
  269. /* The whole classifier is skipped */
  270. continue;
  271. }
  272. if (is_learn && st->backend->read_only) {
  273. /* Read only backend, skip it */
  274. g_ptr_array_index(task->stat_runtimes, i) = NULL;
  275. continue;
  276. }
  277. if (!is_learn && !rspamd_symcache_is_symbol_enabled(task, task->cfg->cache,
  278. st->stcf->symbol)) {
  279. g_ptr_array_index(task->stat_runtimes, i) = NULL;
  280. msg_debug_bayes("symbol %s is disabled, skip classification",
  281. st->stcf->symbol);
  282. /* We need to disable the whole classifier for this! */
  283. struct rspamd_classifier *cl = st->classifier;
  284. for (int j = 0; j < st_ctx->statfiles->len; j++) {
  285. struct rspamd_statfile *nst = g_ptr_array_index(st_ctx->statfiles, j);
  286. if (st != nst && nst->classifier == cl) {
  287. g_ptr_array_index(task->stat_runtimes, j) = NULL;
  288. msg_debug_bayes("symbol %s is disabled, skip classification for %s as well",
  289. st->stcf->symbol, nst->stcf->symbol);
  290. }
  291. }
  292. continue;
  293. }
  294. bk_run = st->backend->runtime(task, st->stcf, is_learn, st->bkcf, i);
  295. if (bk_run == NULL) {
  296. msg_err_task("cannot init backend %s for statfile %s",
  297. st->backend->name, st->stcf->symbol);
  298. }
  299. g_ptr_array_index(task->stat_runtimes, i) = bk_run;
  300. }
  301. }
  302. static void
  303. rspamd_stat_backends_process(struct rspamd_stat_ctx *st_ctx,
  304. struct rspamd_task *task)
  305. {
  306. unsigned int i;
  307. struct rspamd_statfile *st;
  308. gpointer bk_run;
  309. g_assert(task->stat_runtimes != NULL);
  310. for (i = 0; i < st_ctx->statfiles->len; i++) {
  311. st = g_ptr_array_index(st_ctx->statfiles, i);
  312. bk_run = g_ptr_array_index(task->stat_runtimes, i);
  313. if (bk_run != NULL) {
  314. st->backend->process_tokens(task, task->tokens, i, bk_run);
  315. }
  316. }
  317. }
  318. static void
  319. rspamd_stat_classifiers_process(struct rspamd_stat_ctx *st_ctx,
  320. struct rspamd_task *task)
  321. {
  322. unsigned int i, j, id;
  323. struct rspamd_classifier *cl;
  324. struct rspamd_statfile *st;
  325. gpointer bk_run;
  326. gboolean skip;
  327. if (st_ctx->classifiers->len == 0) {
  328. return;
  329. }
  330. /*
  331. * Do not classify a message if some class is missing
  332. */
  333. if (!(task->flags & RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS)) {
  334. msg_info_task("skip statistics as SPAM class is missing");
  335. return;
  336. }
  337. if (!(task->flags & RSPAMD_TASK_FLAG_HAS_HAM_TOKENS)) {
  338. msg_info_task("skip statistics as HAM class is missing");
  339. return;
  340. }
  341. for (i = 0; i < st_ctx->classifiers->len; i++) {
  342. cl = g_ptr_array_index(st_ctx->classifiers, i);
  343. cl->spam_learns = 0;
  344. cl->ham_learns = 0;
  345. }
  346. g_assert(task->stat_runtimes != NULL);
  347. for (i = 0; i < st_ctx->statfiles->len; i++) {
  348. st = g_ptr_array_index(st_ctx->statfiles, i);
  349. cl = st->classifier;
  350. bk_run = g_ptr_array_index(task->stat_runtimes, i);
  351. g_assert(st != NULL);
  352. if (bk_run != NULL) {
  353. if (st->stcf->is_spam) {
  354. cl->spam_learns += st->backend->total_learns(task,
  355. bk_run,
  356. st_ctx);
  357. }
  358. else {
  359. cl->ham_learns += st->backend->total_learns(task,
  360. bk_run,
  361. st_ctx);
  362. }
  363. }
  364. }
  365. for (i = 0; i < st_ctx->classifiers->len; i++) {
  366. cl = g_ptr_array_index(st_ctx->classifiers, i);
  367. g_assert(cl != NULL);
  368. skip = FALSE;
  369. /* Do not process classifiers on backend failures */
  370. for (j = 0; j < cl->statfiles_ids->len; j++) {
  371. id = g_array_index(cl->statfiles_ids, int, j);
  372. bk_run = g_ptr_array_index(task->stat_runtimes, id);
  373. st = g_ptr_array_index(st_ctx->statfiles, id);
  374. if (bk_run != NULL) {
  375. if (!st->backend->finalize_process(task, bk_run, st_ctx)) {
  376. skip = TRUE;
  377. break;
  378. }
  379. }
  380. }
  381. /* Ensure that all symbols enabled */
  382. if (!skip && !(cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND)) {
  383. for (j = 0; j < cl->statfiles_ids->len; j++) {
  384. id = g_array_index(cl->statfiles_ids, int, j);
  385. bk_run = g_ptr_array_index(task->stat_runtimes, id);
  386. st = g_ptr_array_index(st_ctx->statfiles, id);
  387. if (bk_run == NULL) {
  388. skip = TRUE;
  389. msg_debug_bayes("disable classifier %s as statfile symbol %s is disabled",
  390. cl->cfg->name, st->stcf->symbol);
  391. break;
  392. }
  393. }
  394. }
  395. if (!skip) {
  396. if (cl->cfg->min_tokens > 0 && task->tokens->len < cl->cfg->min_tokens) {
  397. msg_debug_bayes(
  398. "contains less tokens than required for %s classifier: "
  399. "%ud < %ud",
  400. cl->cfg->name,
  401. task->tokens->len,
  402. cl->cfg->min_tokens);
  403. continue;
  404. }
  405. else if (cl->cfg->max_tokens > 0 && task->tokens->len > cl->cfg->max_tokens) {
  406. msg_debug_bayes(
  407. "contains more tokens than allowed for %s classifier: "
  408. "%ud > %ud",
  409. cl->cfg->name,
  410. task->tokens->len,
  411. cl->cfg->max_tokens);
  412. continue;
  413. }
  414. cl->subrs->classify_func(cl, task->tokens, task);
  415. }
  416. }
  417. }
  418. rspamd_stat_result_t
  419. rspamd_stat_classify(struct rspamd_task *task, lua_State *L, unsigned int stage,
  420. GError **err)
  421. {
  422. struct rspamd_stat_ctx *st_ctx;
  423. rspamd_stat_result_t ret = RSPAMD_STAT_PROCESS_OK;
  424. st_ctx = rspamd_stat_get_ctx();
  425. g_assert(st_ctx != NULL);
  426. if (st_ctx->classifiers->len == 0) {
  427. task->processed_stages |= stage;
  428. return ret;
  429. }
  430. if (stage == RSPAMD_TASK_STAGE_CLASSIFIERS_PRE) {
  431. /* Preprocess tokens */
  432. rspamd_stat_preprocess(st_ctx, task, FALSE, FALSE);
  433. }
  434. else if (stage == RSPAMD_TASK_STAGE_CLASSIFIERS) {
  435. /* Process backends */
  436. rspamd_stat_backends_process(st_ctx, task);
  437. }
  438. else if (stage == RSPAMD_TASK_STAGE_CLASSIFIERS_POST) {
  439. /* Process classifiers */
  440. rspamd_stat_classifiers_process(st_ctx, task);
  441. }
  442. task->processed_stages |= stage;
  443. return ret;
  444. }
  445. static gboolean
  446. rspamd_stat_cache_check(struct rspamd_stat_ctx *st_ctx,
  447. struct rspamd_task *task,
  448. const char *classifier,
  449. gboolean spam,
  450. GError **err)
  451. {
  452. rspamd_learn_t learn_res = RSPAMD_LEARN_OK;
  453. struct rspamd_classifier *cl, *sel = NULL;
  454. gpointer rt;
  455. unsigned int i;
  456. /* Check whether we have learned that file */
  457. for (i = 0; i < st_ctx->classifiers->len; i++) {
  458. cl = g_ptr_array_index(st_ctx->classifiers, i);
  459. /* Skip other classifiers if they are not needed */
  460. if (classifier != NULL && (cl->cfg->name == NULL ||
  461. g_ascii_strcasecmp(classifier, cl->cfg->name) != 0)) {
  462. continue;
  463. }
  464. sel = cl;
  465. if (sel->cache && sel->cachecf) {
  466. rt = cl->cache->runtime(task, sel->cachecf, FALSE);
  467. learn_res = cl->cache->check(task, spam, rt);
  468. }
  469. if (learn_res == RSPAMD_LEARN_IGNORE) {
  470. /* Do not learn twice */
  471. g_set_error(err, rspamd_stat_quark(), 404, "<%s> has been already "
  472. "learned as %s, ignore it",
  473. MESSAGE_FIELD(task, message_id),
  474. spam ? "spam" : "ham");
  475. task->flags |= RSPAMD_TASK_FLAG_ALREADY_LEARNED;
  476. return FALSE;
  477. }
  478. else if (learn_res == RSPAMD_LEARN_UNLEARN) {
  479. task->flags |= RSPAMD_TASK_FLAG_UNLEARN;
  480. break;
  481. }
  482. }
  483. if (sel == NULL) {
  484. if (classifier) {
  485. g_set_error(err, rspamd_stat_quark(), 404, "cannot find classifier "
  486. "with name %s",
  487. classifier);
  488. }
  489. else {
  490. g_set_error(err, rspamd_stat_quark(), 404, "no classifiers defined");
  491. }
  492. return FALSE;
  493. }
  494. return TRUE;
  495. }
  496. static gboolean
  497. rspamd_stat_classifiers_learn(struct rspamd_stat_ctx *st_ctx,
  498. struct rspamd_task *task,
  499. const char *classifier,
  500. gboolean spam,
  501. GError **err)
  502. {
  503. struct rspamd_classifier *cl, *sel = NULL;
  504. unsigned int i;
  505. gboolean learned = FALSE, too_small = FALSE, too_large = FALSE;
  506. if ((task->flags & RSPAMD_TASK_FLAG_ALREADY_LEARNED) && err != NULL &&
  507. *err == NULL) {
  508. /* Do not learn twice */
  509. g_set_error(err, rspamd_stat_quark(), 208, "<%s> has been already "
  510. "learned as %s, ignore it",
  511. MESSAGE_FIELD(task, message_id),
  512. spam ? "spam" : "ham");
  513. return FALSE;
  514. }
  515. /* Check whether we have learned that file */
  516. for (i = 0; i < st_ctx->classifiers->len; i++) {
  517. cl = g_ptr_array_index(st_ctx->classifiers, i);
  518. /* Skip other classifiers if they are not needed */
  519. if (classifier != NULL && (cl->cfg->name == NULL ||
  520. g_ascii_strcasecmp(classifier, cl->cfg->name) != 0)) {
  521. continue;
  522. }
  523. sel = cl;
  524. /* Now check max and min tokens */
  525. if (cl->cfg->min_tokens > 0 && task->tokens->len < cl->cfg->min_tokens) {
  526. msg_info_task(
  527. "<%s> contains less tokens than required for %s classifier: "
  528. "%ud < %ud",
  529. MESSAGE_FIELD(task, message_id),
  530. cl->cfg->name,
  531. task->tokens->len,
  532. cl->cfg->min_tokens);
  533. too_small = TRUE;
  534. continue;
  535. }
  536. else if (cl->cfg->max_tokens > 0 && task->tokens->len > cl->cfg->max_tokens) {
  537. msg_info_task(
  538. "<%s> contains more tokens than allowed for %s classifier: "
  539. "%ud > %ud",
  540. MESSAGE_FIELD(task, message_id),
  541. cl->cfg->name,
  542. task->tokens->len,
  543. cl->cfg->max_tokens);
  544. too_large = TRUE;
  545. continue;
  546. }
  547. if (cl->subrs->learn_spam_func(cl, task->tokens, task, spam,
  548. task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) {
  549. learned = TRUE;
  550. }
  551. }
  552. if (sel == NULL) {
  553. if (classifier) {
  554. g_set_error(err, rspamd_stat_quark(), 404, "cannot find classifier "
  555. "with name %s",
  556. classifier);
  557. }
  558. else {
  559. g_set_error(err, rspamd_stat_quark(), 404, "no classifiers defined");
  560. }
  561. return FALSE;
  562. }
  563. if (!learned && err && *err == NULL) {
  564. if (too_large) {
  565. g_set_error(err, rspamd_stat_quark(), 204,
  566. "<%s> contains more tokens than allowed for %s classifier: "
  567. "%d > %d",
  568. MESSAGE_FIELD(task, message_id),
  569. sel->cfg->name,
  570. task->tokens->len,
  571. sel->cfg->max_tokens);
  572. }
  573. else if (too_small) {
  574. g_set_error(err, rspamd_stat_quark(), 204,
  575. "<%s> contains less tokens than required for %s classifier: "
  576. "%d < %d",
  577. MESSAGE_FIELD(task, message_id),
  578. sel->cfg->name,
  579. task->tokens->len,
  580. sel->cfg->min_tokens);
  581. }
  582. }
  583. return learned;
  584. }
  585. static gboolean
  586. rspamd_stat_backends_learn(struct rspamd_stat_ctx *st_ctx,
  587. struct rspamd_task *task,
  588. const char *classifier,
  589. gboolean spam,
  590. GError **err)
  591. {
  592. struct rspamd_classifier *cl, *sel = NULL;
  593. struct rspamd_statfile *st;
  594. gpointer bk_run;
  595. unsigned int i, j;
  596. int id;
  597. gboolean res = FALSE, backend_found = FALSE;
  598. for (i = 0; i < st_ctx->classifiers->len; i++) {
  599. cl = g_ptr_array_index(st_ctx->classifiers, i);
  600. /* Skip other classifiers if they are not needed */
  601. if (classifier != NULL && (cl->cfg->name == NULL ||
  602. g_ascii_strcasecmp(classifier, cl->cfg->name) != 0)) {
  603. continue;
  604. }
  605. if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) {
  606. res = TRUE;
  607. continue;
  608. }
  609. sel = cl;
  610. for (j = 0; j < cl->statfiles_ids->len; j++) {
  611. id = g_array_index(cl->statfiles_ids, int, j);
  612. st = g_ptr_array_index(st_ctx->statfiles, id);
  613. bk_run = g_ptr_array_index(task->stat_runtimes, id);
  614. g_assert(st != NULL);
  615. if (bk_run == NULL) {
  616. /* XXX: must be error */
  617. if (task->result->passthrough_result) {
  618. /* Passthrough email, cannot learn */
  619. g_set_error(err, rspamd_stat_quark(), 204,
  620. "Cannot learn statistics when passthrough "
  621. "result has been set; not classified");
  622. res = FALSE;
  623. goto end;
  624. }
  625. msg_debug_task("no runtime for backend %s; classifier %s; symbol %s",
  626. st->backend->name, cl->cfg->name, st->stcf->symbol);
  627. continue;
  628. }
  629. /* We set sel merely when we have runtime */
  630. backend_found = TRUE;
  631. if (!(task->flags & RSPAMD_TASK_FLAG_UNLEARN)) {
  632. if (!!spam != !!st->stcf->is_spam) {
  633. /* If we are not unlearning, then do not touch another class */
  634. continue;
  635. }
  636. }
  637. if (!st->backend->learn_tokens(task, task->tokens, id, bk_run)) {
  638. g_set_error(err, rspamd_stat_quark(), 500,
  639. "Cannot push "
  640. "learned results to the backend");
  641. res = FALSE;
  642. goto end;
  643. }
  644. else {
  645. if (!!spam == !!st->stcf->is_spam) {
  646. st->backend->inc_learns(task, bk_run, st_ctx);
  647. }
  648. else if (task->flags & RSPAMD_TASK_FLAG_UNLEARN) {
  649. st->backend->dec_learns(task, bk_run, st_ctx);
  650. }
  651. res = TRUE;
  652. }
  653. }
  654. }
  655. end:
  656. if (!res) {
  657. if (err && *err) {
  658. /* Error has been set already */
  659. return res;
  660. }
  661. if (sel == NULL) {
  662. if (classifier) {
  663. g_set_error(err, rspamd_stat_quark(), 404, "cannot find classifier "
  664. "with name %s",
  665. classifier);
  666. }
  667. else {
  668. g_set_error(err, rspamd_stat_quark(), 404, "no classifiers defined");
  669. }
  670. return FALSE;
  671. }
  672. else if (!backend_found) {
  673. g_set_error(err, rspamd_stat_quark(), 204, "all learn conditions "
  674. "denied learning %s in %s",
  675. spam ? "spam" : "ham",
  676. classifier ? classifier : "default classifier");
  677. }
  678. else {
  679. g_set_error(err, rspamd_stat_quark(), 404, "cannot find statfile "
  680. "backend to learn %s in %s",
  681. spam ? "spam" : "ham",
  682. classifier ? classifier : "default classifier");
  683. }
  684. }
  685. return res;
  686. }
  687. static gboolean
  688. rspamd_stat_backends_post_learn(struct rspamd_stat_ctx *st_ctx,
  689. struct rspamd_task *task,
  690. const char *classifier,
  691. gboolean spam,
  692. GError **err)
  693. {
  694. struct rspamd_classifier *cl;
  695. struct rspamd_statfile *st;
  696. gpointer bk_run, cache_run;
  697. unsigned int i, j;
  698. int id;
  699. gboolean res = TRUE;
  700. for (i = 0; i < st_ctx->classifiers->len; i++) {
  701. cl = g_ptr_array_index(st_ctx->classifiers, i);
  702. /* Skip other classifiers if they are not needed */
  703. if (classifier != NULL && (cl->cfg->name == NULL ||
  704. g_ascii_strcasecmp(classifier, cl->cfg->name) != 0)) {
  705. continue;
  706. }
  707. if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) {
  708. res = TRUE;
  709. continue;
  710. }
  711. for (j = 0; j < cl->statfiles_ids->len; j++) {
  712. id = g_array_index(cl->statfiles_ids, int, j);
  713. st = g_ptr_array_index(st_ctx->statfiles, id);
  714. bk_run = g_ptr_array_index(task->stat_runtimes, id);
  715. g_assert(st != NULL);
  716. if (bk_run == NULL) {
  717. /* XXX: must be error */
  718. continue;
  719. }
  720. if (!st->backend->finalize_learn(task, bk_run, st_ctx, err)) {
  721. return RSPAMD_STAT_PROCESS_ERROR;
  722. }
  723. }
  724. if (cl->cache) {
  725. cache_run = cl->cache->runtime(task, cl->cachecf, TRUE);
  726. cl->cache->learn(task, spam, cache_run);
  727. }
  728. }
  729. g_atomic_int_add(&task->worker->srv->stat->messages_learned, 1);
  730. return res;
  731. }
  732. rspamd_stat_result_t
  733. rspamd_stat_learn(struct rspamd_task *task,
  734. gboolean spam, lua_State *L, const char *classifier, unsigned int stage,
  735. GError **err)
  736. {
  737. struct rspamd_stat_ctx *st_ctx;
  738. rspamd_stat_result_t ret = RSPAMD_STAT_PROCESS_OK;
  739. /*
  740. * We assume now that a task has been already classified before
  741. * coming to learn
  742. */
  743. g_assert(RSPAMD_TASK_IS_CLASSIFIED(task));
  744. st_ctx = rspamd_stat_get_ctx();
  745. g_assert(st_ctx != NULL);
  746. if (st_ctx->classifiers->len == 0) {
  747. task->processed_stages |= stage;
  748. return ret;
  749. }
  750. if (stage == RSPAMD_TASK_STAGE_LEARN_PRE) {
  751. /* Process classifiers */
  752. rspamd_stat_preprocess(st_ctx, task, TRUE, spam);
  753. if (!rspamd_stat_cache_check(st_ctx, task, classifier, spam, err)) {
  754. return RSPAMD_STAT_PROCESS_ERROR;
  755. }
  756. }
  757. else if (stage == RSPAMD_TASK_STAGE_LEARN) {
  758. /* Process classifiers */
  759. if (!rspamd_stat_classifiers_learn(st_ctx, task, classifier,
  760. spam, err)) {
  761. if (err && *err == NULL) {
  762. g_set_error(err, rspamd_stat_quark(), 500,
  763. "Unknown statistics error, found when learning classifiers;"
  764. " classifier: %s",
  765. task->classifier);
  766. }
  767. return RSPAMD_STAT_PROCESS_ERROR;
  768. }
  769. /* Process backends */
  770. if (!rspamd_stat_backends_learn(st_ctx, task, classifier, spam, err)) {
  771. if (err && *err == NULL) {
  772. g_set_error(err, rspamd_stat_quark(), 500,
  773. "Unknown statistics error, found when storing data on backend;"
  774. " classifier: %s",
  775. task->classifier);
  776. }
  777. return RSPAMD_STAT_PROCESS_ERROR;
  778. }
  779. }
  780. else if (stage == RSPAMD_TASK_STAGE_LEARN_POST) {
  781. if (!rspamd_stat_backends_post_learn(st_ctx, task, classifier, spam, err)) {
  782. return RSPAMD_STAT_PROCESS_ERROR;
  783. }
  784. }
  785. task->processed_stages |= stage;
  786. return ret;
  787. }
  788. static gboolean
  789. rspamd_stat_has_classifier_symbols(struct rspamd_task *task,
  790. struct rspamd_scan_result *mres,
  791. struct rspamd_classifier *cl)
  792. {
  793. unsigned int i;
  794. int id;
  795. struct rspamd_statfile *st;
  796. struct rspamd_stat_ctx *st_ctx;
  797. gboolean is_spam;
  798. if (mres == NULL) {
  799. return FALSE;
  800. }
  801. st_ctx = rspamd_stat_get_ctx();
  802. is_spam = !!(task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM);
  803. for (i = 0; i < cl->statfiles_ids->len; i++) {
  804. id = g_array_index(cl->statfiles_ids, int, i);
  805. st = g_ptr_array_index(st_ctx->statfiles, id);
  806. if (rspamd_task_find_symbol_result(task, st->stcf->symbol, NULL)) {
  807. if (is_spam == !!st->stcf->is_spam) {
  808. msg_debug_bayes("do not autolearn %s as symbol %s is already "
  809. "added",
  810. is_spam ? "spam" : "ham", st->stcf->symbol);
  811. return TRUE;
  812. }
  813. }
  814. }
  815. return FALSE;
  816. }
  817. gboolean
  818. rspamd_stat_check_autolearn(struct rspamd_task *task)
  819. {
  820. struct rspamd_stat_ctx *st_ctx;
  821. struct rspamd_classifier *cl;
  822. const ucl_object_t *obj, *elt1, *elt2;
  823. struct rspamd_scan_result *mres = NULL;
  824. struct rspamd_task **ptask;
  825. lua_State *L;
  826. unsigned int i;
  827. int err_idx;
  828. gboolean ret = FALSE;
  829. double ham_score, spam_score;
  830. const char *lua_script, *lua_ret;
  831. g_assert(RSPAMD_TASK_IS_CLASSIFIED(task));
  832. st_ctx = rspamd_stat_get_ctx();
  833. g_assert(st_ctx != NULL);
  834. L = task->cfg->lua_state;
  835. for (i = 0; i < st_ctx->classifiers->len; i++) {
  836. cl = g_ptr_array_index(st_ctx->classifiers, i);
  837. ret = FALSE;
  838. if (cl->cfg->opts) {
  839. obj = ucl_object_lookup(cl->cfg->opts, "autolearn");
  840. if (ucl_object_type(obj) == UCL_BOOLEAN) {
  841. /* Legacy true/false */
  842. if (ucl_object_toboolean(obj)) {
  843. /*
  844. * Default learning algorithm:
  845. *
  846. * - We learn spam if action is ACTION_REJECT
  847. * - We learn ham if score is less than zero
  848. */
  849. mres = task->result;
  850. if (mres) {
  851. if (mres->score > rspamd_task_get_required_score(task, mres)) {
  852. task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM;
  853. ret = TRUE;
  854. }
  855. else if (mres->score < 0) {
  856. task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
  857. ret = TRUE;
  858. }
  859. }
  860. }
  861. }
  862. else if (ucl_object_type(obj) == UCL_ARRAY && obj->len == 2) {
  863. /* Legacy thresholds */
  864. /*
  865. * We have an array of 2 elements, treat it as a
  866. * ham_score, spam_score
  867. */
  868. elt1 = ucl_array_find_index(obj, 0);
  869. elt2 = ucl_array_find_index(obj, 1);
  870. if ((ucl_object_type(elt1) == UCL_FLOAT ||
  871. ucl_object_type(elt1) == UCL_INT) &&
  872. (ucl_object_type(elt2) == UCL_FLOAT ||
  873. ucl_object_type(elt2) == UCL_INT)) {
  874. ham_score = ucl_object_todouble(elt1);
  875. spam_score = ucl_object_todouble(elt2);
  876. if (ham_score > spam_score) {
  877. double t;
  878. t = ham_score;
  879. ham_score = spam_score;
  880. spam_score = t;
  881. }
  882. mres = task->result;
  883. if (mres) {
  884. if (mres->score >= spam_score) {
  885. task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM;
  886. ret = TRUE;
  887. }
  888. else if (mres->score <= ham_score) {
  889. task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
  890. ret = TRUE;
  891. }
  892. }
  893. }
  894. }
  895. else if (ucl_object_type(obj) == UCL_STRING) {
  896. /* Legacy script */
  897. lua_script = ucl_object_tostring(obj);
  898. if (luaL_dostring(L, lua_script) != 0) {
  899. msg_err_task("cannot execute lua script for autolearn "
  900. "extraction: %s",
  901. lua_tostring(L, -1));
  902. }
  903. else {
  904. if (lua_type(L, -1) == LUA_TFUNCTION) {
  905. lua_pushcfunction(L, &rspamd_lua_traceback);
  906. err_idx = lua_gettop(L);
  907. lua_pushvalue(L, -2); /* Function itself */
  908. ptask = lua_newuserdata(L, sizeof(struct rspamd_task *));
  909. *ptask = task;
  910. rspamd_lua_setclass(L, rspamd_task_classname, -1);
  911. if (lua_pcall(L, 1, 1, err_idx) != 0) {
  912. msg_err_task("call to autolearn script failed: "
  913. "%s",
  914. lua_tostring(L, -1));
  915. }
  916. else {
  917. lua_ret = lua_tostring(L, -1);
  918. /* We can have immediate results */
  919. if (lua_ret) {
  920. if (strcmp(lua_ret, "ham") == 0) {
  921. task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
  922. ret = TRUE;
  923. }
  924. else if (strcmp(lua_ret, "spam") == 0) {
  925. task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM;
  926. ret = TRUE;
  927. }
  928. }
  929. }
  930. /* Result + error function + original function */
  931. lua_pop(L, 3);
  932. }
  933. else {
  934. msg_err_task("lua script must return "
  935. "function(task) and not %s",
  936. lua_typename(L, lua_type(
  937. L, -1)));
  938. }
  939. }
  940. }
  941. else if (ucl_object_type(obj) == UCL_OBJECT) {
  942. /* Try to find autolearn callback */
  943. if (cl->autolearn_cbref == 0) {
  944. /* We don't have preprocessed cb id, so try to get it */
  945. if (!rspamd_lua_require_function(L, "lua_bayes_learn",
  946. "autolearn")) {
  947. msg_err_task("cannot get autolearn library from "
  948. "`lua_bayes_learn`");
  949. }
  950. else {
  951. cl->autolearn_cbref = luaL_ref(L, LUA_REGISTRYINDEX);
  952. }
  953. }
  954. if (cl->autolearn_cbref != -1) {
  955. lua_pushcfunction(L, &rspamd_lua_traceback);
  956. err_idx = lua_gettop(L);
  957. lua_rawgeti(L, LUA_REGISTRYINDEX, cl->autolearn_cbref);
  958. ptask = lua_newuserdata(L, sizeof(struct rspamd_task *));
  959. *ptask = task;
  960. rspamd_lua_setclass(L, rspamd_task_classname, -1);
  961. /* Push the whole object as well */
  962. ucl_object_push_lua(L, obj, true);
  963. if (lua_pcall(L, 2, 1, err_idx) != 0) {
  964. msg_err_task("call to autolearn script failed: "
  965. "%s",
  966. lua_tostring(L, -1));
  967. }
  968. else {
  969. lua_ret = lua_tostring(L, -1);
  970. if (lua_ret) {
  971. if (strcmp(lua_ret, "ham") == 0) {
  972. task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
  973. ret = TRUE;
  974. }
  975. else if (strcmp(lua_ret, "spam") == 0) {
  976. task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM;
  977. ret = TRUE;
  978. }
  979. }
  980. }
  981. lua_settop(L, err_idx - 1);
  982. }
  983. }
  984. if (ret) {
  985. /* Do not autolearn if we have this symbol already */
  986. if (rspamd_stat_has_classifier_symbols(task, mres, cl)) {
  987. ret = FALSE;
  988. task->flags &= ~(RSPAMD_TASK_FLAG_LEARN_HAM |
  989. RSPAMD_TASK_FLAG_LEARN_SPAM);
  990. }
  991. else if (mres != NULL) {
  992. if (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM) {
  993. msg_info_task("<%s>: autolearn ham for classifier "
  994. "'%s' as message's "
  995. "score is negative: %.2f",
  996. MESSAGE_FIELD(task, message_id), cl->cfg->name,
  997. mres->score);
  998. }
  999. else {
  1000. msg_info_task("<%s>: autolearn spam for classifier "
  1001. "'%s' as message's "
  1002. "action is reject, score: %.2f",
  1003. MESSAGE_FIELD(task, message_id), cl->cfg->name,
  1004. mres->score);
  1005. }
  1006. task->classifier = cl->cfg->name;
  1007. break;
  1008. }
  1009. }
  1010. }
  1011. }
  1012. return ret;
  1013. }
  1014. /**
  1015. * Get the overall statistics for all statfile backends
  1016. * @param cfg configuration
  1017. * @param total_learns the total number of learns is stored here
  1018. * @return array of statistical information
  1019. */
  1020. rspamd_stat_result_t
  1021. rspamd_stat_statistics(struct rspamd_task *task,
  1022. struct rspamd_config *cfg,
  1023. uint64_t *total_learns,
  1024. ucl_object_t **target)
  1025. {
  1026. struct rspamd_stat_ctx *st_ctx;
  1027. struct rspamd_classifier *cl;
  1028. struct rspamd_statfile *st;
  1029. gpointer backend_runtime;
  1030. ucl_object_t *res = NULL, *elt;
  1031. uint64_t learns = 0;
  1032. unsigned int i, j;
  1033. int id;
  1034. st_ctx = rspamd_stat_get_ctx();
  1035. g_assert(st_ctx != NULL);
  1036. res = ucl_object_typed_new(UCL_ARRAY);
  1037. for (i = 0; i < st_ctx->classifiers->len; i++) {
  1038. cl = g_ptr_array_index(st_ctx->classifiers, i);
  1039. if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) {
  1040. continue;
  1041. }
  1042. for (j = 0; j < cl->statfiles_ids->len; j++) {
  1043. id = g_array_index(cl->statfiles_ids, int, j);
  1044. st = g_ptr_array_index(st_ctx->statfiles, id);
  1045. backend_runtime = st->backend->runtime(task, st->stcf, FALSE,
  1046. st->bkcf, id);
  1047. elt = st->backend->get_stat(backend_runtime, st->bkcf);
  1048. if (elt && ucl_object_type(elt) == UCL_OBJECT) {
  1049. const ucl_object_t *rev = ucl_object_lookup(elt, "revision");
  1050. learns += ucl_object_toint(rev);
  1051. }
  1052. else {
  1053. learns += st->backend->total_learns(task, backend_runtime,
  1054. st->bkcf);
  1055. }
  1056. if (elt != NULL) {
  1057. ucl_array_append(res, elt);
  1058. }
  1059. }
  1060. }
  1061. if (total_learns != NULL) {
  1062. *total_learns = learns;
  1063. }
  1064. if (target) {
  1065. *target = res;
  1066. }
  1067. else {
  1068. ucl_object_unref(res);
  1069. }
  1070. return RSPAMD_STAT_PROCESS_OK;
  1071. }