You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

stat_process.c 32KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251
  1. /*
  2. * Copyright 2024 Vsevolod Stakhov
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "config.h"
  17. #include "stat_api.h"
  18. #include "rspamd.h"
  19. #include "stat_internal.h"
  20. #include "libmime/message.h"
  21. #include "libmime/images.h"
  22. #include "libserver/html/html.h"
  23. #include "lua/lua_common.h"
  24. #include "lua/lua_classnames.h"
  25. #include "libserver/mempool_vars_internal.h"
  26. #include "utlist.h"
  27. #include <math.h>
  /*
   * Numeric codes for the statistics operations (classify / learn / unlearn,
   * going by their names).  They are not referenced in the visible part of
   * this file.
   */
  #define RSPAMD_CLASSIFY_OP 0
  #define RSPAMD_LEARN_OP 1
  #define RSPAMD_UNLEARN_OP 2
  31. static const gdouble similarity_threshold = 80.0;
  32. static void
  33. rspamd_stat_tokenize_parts_metadata(struct rspamd_stat_ctx *st_ctx,
  34. struct rspamd_task *task)
  35. {
  36. GArray *ar;
  37. rspamd_stat_token_t elt;
  38. guint i;
  39. lua_State *L = task->cfg->lua_state;
  40. ar = g_array_sized_new(FALSE, FALSE, sizeof(elt), 16);
  41. memset(&elt, 0, sizeof(elt));
  42. elt.flags = RSPAMD_STAT_TOKEN_FLAG_META;
  43. if (st_ctx->lua_stat_tokens_ref != -1) {
  44. gint err_idx, ret;
  45. struct rspamd_task **ptask;
  46. lua_pushcfunction(L, &rspamd_lua_traceback);
  47. err_idx = lua_gettop(L);
  48. lua_rawgeti(L, LUA_REGISTRYINDEX, st_ctx->lua_stat_tokens_ref);
  49. ptask = lua_newuserdata(L, sizeof(*ptask));
  50. *ptask = task;
  51. rspamd_lua_setclass(L, rspamd_task_classname, -1);
  52. if ((ret = lua_pcall(L, 1, 1, err_idx)) != 0) {
  53. msg_err_task("call to stat_tokens lua "
  54. "script failed (%d): %s",
  55. ret, lua_tostring(L, -1));
  56. }
  57. else {
  58. if (lua_type(L, -1) != LUA_TTABLE) {
  59. msg_err_task("stat_tokens invocation must return "
  60. "table and not %s",
  61. lua_typename(L, lua_type(L, -1)));
  62. }
  63. else {
  64. guint vlen;
  65. rspamd_ftok_t tok;
  66. vlen = rspamd_lua_table_size(L, -1);
  67. for (i = 0; i < vlen; i++) {
  68. lua_rawgeti(L, -1, i + 1);
  69. tok.begin = lua_tolstring(L, -1, &tok.len);
  70. if (tok.begin && tok.len > 0) {
  71. elt.original.begin =
  72. rspamd_mempool_ftokdup(task->task_pool, &tok);
  73. elt.original.len = tok.len;
  74. elt.stemmed.begin = elt.original.begin;
  75. elt.stemmed.len = elt.original.len;
  76. elt.normalized.begin = elt.original.begin;
  77. elt.normalized.len = elt.original.len;
  78. g_array_append_val(ar, elt);
  79. }
  80. lua_pop(L, 1);
  81. }
  82. }
  83. }
  84. lua_settop(L, 0);
  85. }
  86. if (ar->len > 0) {
  87. st_ctx->tokenizer->tokenize_func(st_ctx,
  88. task,
  89. ar,
  90. TRUE,
  91. "M",
  92. task->tokens);
  93. }
  94. rspamd_mempool_add_destructor(task->task_pool,
  95. rspamd_array_free_hard, ar);
  96. }
  97. /*
  98. * Tokenize task using the tokenizer specified
  99. */
  100. void rspamd_stat_process_tokenize(struct rspamd_stat_ctx *st_ctx,
  101. struct rspamd_task *task)
  102. {
  103. struct rspamd_mime_text_part *part;
  104. rspamd_cryptobox_hash_state_t hst;
  105. rspamd_token_t *st_tok;
  106. guint i, reserved_len = 0;
  107. gdouble *pdiff;
  108. guchar hout[rspamd_cryptobox_HASHBYTES];
  109. gchar *b32_hout;
  110. if (st_ctx == NULL) {
  111. st_ctx = rspamd_stat_get_ctx();
  112. }
  113. g_assert(st_ctx != NULL);
  114. PTR_ARRAY_FOREACH(MESSAGE_FIELD(task, text_parts), i, part)
  115. {
  116. if (!IS_TEXT_PART_EMPTY(part) && part->utf_words != NULL) {
  117. reserved_len += part->utf_words->len;
  118. }
  119. /* XXX: normal window size */
  120. reserved_len += 5;
  121. }
  122. task->tokens = g_ptr_array_sized_new(reserved_len);
  123. rspamd_mempool_add_destructor(task->task_pool,
  124. rspamd_ptr_array_free_hard, task->tokens);
  125. rspamd_mempool_notify_alloc(task->task_pool, reserved_len * sizeof(gpointer));
  126. pdiff = rspamd_mempool_get_variable(task->task_pool, "parts_distance");
  127. PTR_ARRAY_FOREACH(MESSAGE_FIELD(task, text_parts), i, part)
  128. {
  129. if (!IS_TEXT_PART_EMPTY(part) && part->utf_words != NULL) {
  130. st_ctx->tokenizer->tokenize_func(st_ctx, task,
  131. part->utf_words, IS_TEXT_PART_UTF(part),
  132. NULL, task->tokens);
  133. }
  134. if (pdiff != NULL && (1.0 - *pdiff) * 100.0 > similarity_threshold) {
  135. msg_debug_bayes("message has two common parts (%.2f), so skip the last one",
  136. *pdiff);
  137. break;
  138. }
  139. }
  140. if (task->meta_words != NULL) {
  141. st_ctx->tokenizer->tokenize_func(st_ctx,
  142. task,
  143. task->meta_words,
  144. TRUE,
  145. "SUBJECT",
  146. task->tokens);
  147. }
  148. rspamd_stat_tokenize_parts_metadata(st_ctx, task);
  149. /* Produce signature */
  150. rspamd_cryptobox_hash_init(&hst, NULL, 0);
  151. PTR_ARRAY_FOREACH(task->tokens, i, st_tok)
  152. {
  153. rspamd_cryptobox_hash_update(&hst, (guchar *) &st_tok->data,
  154. sizeof(st_tok->data));
  155. }
  156. rspamd_cryptobox_hash_final(&hst, hout);
  157. b32_hout = rspamd_encode_base32(hout, sizeof(hout), RSPAMD_BASE32_DEFAULT);
  158. /*
  159. * We need to strip it to 32 characters providing ~160 bits of
  160. * hash distribution
  161. */
  162. b32_hout[32] = '\0';
  163. rspamd_mempool_set_variable(task->task_pool, RSPAMD_MEMPOOL_STAT_SIGNATURE,
  164. b32_hout, g_free);
  165. }
  166. static gboolean
  167. rspamd_stat_classifier_is_skipped(struct rspamd_task *task,
  168. struct rspamd_classifier *cl, gboolean is_learn, gboolean is_spam)
  169. {
  170. GList *cur = is_learn ? cl->cfg->learn_conditions : cl->cfg->classify_conditions;
  171. lua_State *L = task->cfg->lua_state;
  172. gboolean ret = FALSE;
  173. while (cur) {
  174. gint cb_ref = GPOINTER_TO_INT(cur->data);
  175. gint old_top = lua_gettop(L);
  176. gint nargs;
  177. lua_rawgeti(L, LUA_REGISTRYINDEX, cb_ref);
  178. /* Push task and two booleans: is_spam and is_unlearn */
  179. struct rspamd_task **ptask = lua_newuserdata(L, sizeof(*ptask));
  180. *ptask = task;
  181. rspamd_lua_setclass(L, rspamd_task_classname, -1);
  182. if (is_learn) {
  183. lua_pushboolean(L, is_spam);
  184. lua_pushboolean(L,
  185. task->flags & RSPAMD_TASK_FLAG_UNLEARN ? true : false);
  186. nargs = 3;
  187. }
  188. else {
  189. nargs = 1;
  190. }
  191. if (lua_pcall(L, nargs, LUA_MULTRET, 0) != 0) {
  192. msg_err_task("call to %s failed: %s",
  193. "condition callback",
  194. lua_tostring(L, -1));
  195. }
  196. else {
  197. if (lua_isboolean(L, 1)) {
  198. if (!lua_toboolean(L, 1)) {
  199. ret = TRUE;
  200. }
  201. }
  202. if (lua_isstring(L, 2)) {
  203. if (ret) {
  204. msg_notice_task("%s condition for classifier %s returned: %s; skip classifier",
  205. is_learn ? "learn" : "classify", cl->cfg->name,
  206. lua_tostring(L, 2));
  207. }
  208. else {
  209. msg_info_task("%s condition for classifier %s returned: %s",
  210. is_learn ? "learn" : "classify", cl->cfg->name,
  211. lua_tostring(L, 2));
  212. }
  213. }
  214. else if (ret) {
  215. msg_notice_task("%s condition for classifier %s returned false; skip classifier",
  216. is_learn ? "learn" : "classify", cl->cfg->name);
  217. }
  218. if (ret) {
  219. lua_settop(L, old_top);
  220. break;
  221. }
  222. }
  223. lua_settop(L, old_top);
  224. cur = g_list_next(cur);
  225. }
  226. return ret;
  227. }
  228. static void
  229. rspamd_stat_preprocess(struct rspamd_stat_ctx *st_ctx,
  230. struct rspamd_task *task, gboolean is_learn, gboolean is_spam)
  231. {
  232. guint i;
  233. struct rspamd_statfile *st;
  234. gpointer bk_run;
  235. if (task->tokens == NULL) {
  236. rspamd_stat_process_tokenize(st_ctx, task);
  237. }
  238. task->stat_runtimes = g_ptr_array_sized_new(st_ctx->statfiles->len);
  239. g_ptr_array_set_size(task->stat_runtimes, st_ctx->statfiles->len);
  240. rspamd_mempool_add_destructor(task->task_pool,
  241. rspamd_ptr_array_free_hard, task->stat_runtimes);
  242. /* Temporary set all stat_runtimes to some max size to distinguish from NULL */
  243. for (i = 0; i < st_ctx->statfiles->len; i++) {
  244. g_ptr_array_index(task->stat_runtimes, i) = GSIZE_TO_POINTER(G_MAXSIZE);
  245. }
  246. for (i = 0; i < st_ctx->classifiers->len; i++) {
  247. struct rspamd_classifier *cl = g_ptr_array_index(st_ctx->classifiers, i);
  248. gboolean skip_classifier = FALSE;
  249. if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) {
  250. skip_classifier = TRUE;
  251. }
  252. else {
  253. if (rspamd_stat_classifier_is_skipped(task, cl, is_learn, is_spam)) {
  254. skip_classifier = TRUE;
  255. }
  256. }
  257. if (skip_classifier) {
  258. /* Set NULL for all statfiles indexed by id */
  259. for (int j = 0; j < cl->statfiles_ids->len; j++) {
  260. int id = g_array_index(cl->statfiles_ids, gint, j);
  261. g_ptr_array_index(task->stat_runtimes, id) = NULL;
  262. }
  263. }
  264. }
  265. for (i = 0; i < st_ctx->statfiles->len; i++) {
  266. st = g_ptr_array_index(st_ctx->statfiles, i);
  267. g_assert(st != NULL);
  268. if (g_ptr_array_index(task->stat_runtimes, i) == NULL) {
  269. /* The whole classifier is skipped */
  270. continue;
  271. }
  272. if (is_learn && st->backend->read_only) {
  273. /* Read only backend, skip it */
  274. g_ptr_array_index(task->stat_runtimes, i) = NULL;
  275. continue;
  276. }
  277. if (!is_learn && !rspamd_symcache_is_symbol_enabled(task, task->cfg->cache,
  278. st->stcf->symbol)) {
  279. g_ptr_array_index(task->stat_runtimes, i) = NULL;
  280. msg_debug_bayes("symbol %s is disabled, skip classification",
  281. st->stcf->symbol);
  282. continue;
  283. }
  284. bk_run = st->backend->runtime(task, st->stcf, is_learn, st->bkcf, i);
  285. if (bk_run == NULL) {
  286. msg_err_task("cannot init backend %s for statfile %s",
  287. st->backend->name, st->stcf->symbol);
  288. }
  289. g_ptr_array_index(task->stat_runtimes, i) = bk_run;
  290. }
  291. }
  292. static void
  293. rspamd_stat_backends_process(struct rspamd_stat_ctx *st_ctx,
  294. struct rspamd_task *task)
  295. {
  296. guint i;
  297. struct rspamd_statfile *st;
  298. gpointer bk_run;
  299. g_assert(task->stat_runtimes != NULL);
  300. for (i = 0; i < st_ctx->statfiles->len; i++) {
  301. st = g_ptr_array_index(st_ctx->statfiles, i);
  302. bk_run = g_ptr_array_index(task->stat_runtimes, i);
  303. if (bk_run != NULL) {
  304. st->backend->process_tokens(task, task->tokens, i, bk_run);
  305. }
  306. }
  307. }
  308. static void
  309. rspamd_stat_classifiers_process(struct rspamd_stat_ctx *st_ctx,
  310. struct rspamd_task *task)
  311. {
  312. guint i, j, id;
  313. struct rspamd_classifier *cl;
  314. struct rspamd_statfile *st;
  315. gpointer bk_run;
  316. gboolean skip;
  317. if (st_ctx->classifiers->len == 0) {
  318. return;
  319. }
  320. /*
  321. * Do not classify a message if some class is missing
  322. */
  323. if (!(task->flags & RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS)) {
  324. msg_info_task("skip statistics as SPAM class is missing");
  325. return;
  326. }
  327. if (!(task->flags & RSPAMD_TASK_FLAG_HAS_HAM_TOKENS)) {
  328. msg_info_task("skip statistics as HAM class is missing");
  329. return;
  330. }
  331. for (i = 0; i < st_ctx->classifiers->len; i++) {
  332. cl = g_ptr_array_index(st_ctx->classifiers, i);
  333. cl->spam_learns = 0;
  334. cl->ham_learns = 0;
  335. }
  336. g_assert(task->stat_runtimes != NULL);
  337. for (i = 0; i < st_ctx->statfiles->len; i++) {
  338. st = g_ptr_array_index(st_ctx->statfiles, i);
  339. cl = st->classifier;
  340. bk_run = g_ptr_array_index(task->stat_runtimes, i);
  341. g_assert(st != NULL);
  342. if (bk_run != NULL) {
  343. if (st->stcf->is_spam) {
  344. cl->spam_learns += st->backend->total_learns(task,
  345. bk_run,
  346. st_ctx);
  347. }
  348. else {
  349. cl->ham_learns += st->backend->total_learns(task,
  350. bk_run,
  351. st_ctx);
  352. }
  353. }
  354. }
  355. for (i = 0; i < st_ctx->classifiers->len; i++) {
  356. cl = g_ptr_array_index(st_ctx->classifiers, i);
  357. g_assert(cl != NULL);
  358. skip = FALSE;
  359. /* Do not process classifiers on backend failures */
  360. for (j = 0; j < cl->statfiles_ids->len; j++) {
  361. id = g_array_index(cl->statfiles_ids, gint, j);
  362. bk_run = g_ptr_array_index(task->stat_runtimes, id);
  363. st = g_ptr_array_index(st_ctx->statfiles, id);
  364. if (bk_run != NULL) {
  365. if (!st->backend->finalize_process(task, bk_run, st_ctx)) {
  366. skip = TRUE;
  367. break;
  368. }
  369. }
  370. }
  371. /* Ensure that all symbols enabled */
  372. if (!skip && !(cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND)) {
  373. for (j = 0; j < cl->statfiles_ids->len; j++) {
  374. id = g_array_index(cl->statfiles_ids, gint, j);
  375. bk_run = g_ptr_array_index(task->stat_runtimes, id);
  376. st = g_ptr_array_index(st_ctx->statfiles, id);
  377. if (bk_run == NULL) {
  378. skip = TRUE;
  379. msg_debug_bayes("disable classifier %s as statfile symbol %s is disabled",
  380. cl->cfg->name, st->stcf->symbol);
  381. break;
  382. }
  383. }
  384. }
  385. if (!skip) {
  386. if (cl->cfg->min_tokens > 0 && task->tokens->len < cl->cfg->min_tokens) {
  387. msg_debug_bayes(
  388. "contains less tokens than required for %s classifier: "
  389. "%ud < %ud",
  390. cl->cfg->name,
  391. task->tokens->len,
  392. cl->cfg->min_tokens);
  393. continue;
  394. }
  395. else if (cl->cfg->max_tokens > 0 && task->tokens->len > cl->cfg->max_tokens) {
  396. msg_debug_bayes(
  397. "contains more tokens than allowed for %s classifier: "
  398. "%ud > %ud",
  399. cl->cfg->name,
  400. task->tokens->len,
  401. cl->cfg->max_tokens);
  402. continue;
  403. }
  404. cl->subrs->classify_func(cl, task->tokens, task);
  405. }
  406. }
  407. }
  408. rspamd_stat_result_t
  409. rspamd_stat_classify(struct rspamd_task *task, lua_State *L, guint stage,
  410. GError **err)
  411. {
  412. struct rspamd_stat_ctx *st_ctx;
  413. rspamd_stat_result_t ret = RSPAMD_STAT_PROCESS_OK;
  414. st_ctx = rspamd_stat_get_ctx();
  415. g_assert(st_ctx != NULL);
  416. if (st_ctx->classifiers->len == 0) {
  417. task->processed_stages |= stage;
  418. return ret;
  419. }
  420. if (stage == RSPAMD_TASK_STAGE_CLASSIFIERS_PRE) {
  421. /* Preprocess tokens */
  422. rspamd_stat_preprocess(st_ctx, task, FALSE, FALSE);
  423. }
  424. else if (stage == RSPAMD_TASK_STAGE_CLASSIFIERS) {
  425. /* Process backends */
  426. rspamd_stat_backends_process(st_ctx, task);
  427. }
  428. else if (stage == RSPAMD_TASK_STAGE_CLASSIFIERS_POST) {
  429. /* Process classifiers */
  430. rspamd_stat_classifiers_process(st_ctx, task);
  431. }
  432. task->processed_stages |= stage;
  433. return ret;
  434. }
  435. static gboolean
  436. rspamd_stat_cache_check(struct rspamd_stat_ctx *st_ctx,
  437. struct rspamd_task *task,
  438. const gchar *classifier,
  439. gboolean spam,
  440. GError **err)
  441. {
  442. rspamd_learn_t learn_res = RSPAMD_LEARN_OK;
  443. struct rspamd_classifier *cl, *sel = NULL;
  444. gpointer rt;
  445. guint i;
  446. /* Check whether we have learned that file */
  447. for (i = 0; i < st_ctx->classifiers->len; i++) {
  448. cl = g_ptr_array_index(st_ctx->classifiers, i);
  449. /* Skip other classifiers if they are not needed */
  450. if (classifier != NULL && (cl->cfg->name == NULL ||
  451. g_ascii_strcasecmp(classifier, cl->cfg->name) != 0)) {
  452. continue;
  453. }
  454. sel = cl;
  455. if (sel->cache && sel->cachecf) {
  456. rt = cl->cache->runtime(task, sel->cachecf, FALSE);
  457. learn_res = cl->cache->check(task, spam, rt);
  458. }
  459. if (learn_res == RSPAMD_LEARN_IGNORE) {
  460. /* Do not learn twice */
  461. g_set_error(err, rspamd_stat_quark(), 404, "<%s> has been already "
  462. "learned as %s, ignore it",
  463. MESSAGE_FIELD(task, message_id),
  464. spam ? "spam" : "ham");
  465. task->flags |= RSPAMD_TASK_FLAG_ALREADY_LEARNED;
  466. return FALSE;
  467. }
  468. else if (learn_res == RSPAMD_LEARN_UNLEARN) {
  469. task->flags |= RSPAMD_TASK_FLAG_UNLEARN;
  470. break;
  471. }
  472. }
  473. if (sel == NULL) {
  474. if (classifier) {
  475. g_set_error(err, rspamd_stat_quark(), 404, "cannot find classifier "
  476. "with name %s",
  477. classifier);
  478. }
  479. else {
  480. g_set_error(err, rspamd_stat_quark(), 404, "no classifiers defined");
  481. }
  482. return FALSE;
  483. }
  484. return TRUE;
  485. }
  486. static gboolean
  487. rspamd_stat_classifiers_learn(struct rspamd_stat_ctx *st_ctx,
  488. struct rspamd_task *task,
  489. const gchar *classifier,
  490. gboolean spam,
  491. GError **err)
  492. {
  493. struct rspamd_classifier *cl, *sel = NULL;
  494. guint i;
  495. gboolean learned = FALSE, too_small = FALSE, too_large = FALSE;
  496. if ((task->flags & RSPAMD_TASK_FLAG_ALREADY_LEARNED) && err != NULL &&
  497. *err == NULL) {
  498. /* Do not learn twice */
  499. g_set_error(err, rspamd_stat_quark(), 208, "<%s> has been already "
  500. "learned as %s, ignore it",
  501. MESSAGE_FIELD(task, message_id),
  502. spam ? "spam" : "ham");
  503. return FALSE;
  504. }
  505. /* Check whether we have learned that file */
  506. for (i = 0; i < st_ctx->classifiers->len; i++) {
  507. cl = g_ptr_array_index(st_ctx->classifiers, i);
  508. /* Skip other classifiers if they are not needed */
  509. if (classifier != NULL && (cl->cfg->name == NULL ||
  510. g_ascii_strcasecmp(classifier, cl->cfg->name) != 0)) {
  511. continue;
  512. }
  513. sel = cl;
  514. /* Now check max and min tokens */
  515. if (cl->cfg->min_tokens > 0 && task->tokens->len < cl->cfg->min_tokens) {
  516. msg_info_task(
  517. "<%s> contains less tokens than required for %s classifier: "
  518. "%ud < %ud",
  519. MESSAGE_FIELD(task, message_id),
  520. cl->cfg->name,
  521. task->tokens->len,
  522. cl->cfg->min_tokens);
  523. too_small = TRUE;
  524. continue;
  525. }
  526. else if (cl->cfg->max_tokens > 0 && task->tokens->len > cl->cfg->max_tokens) {
  527. msg_info_task(
  528. "<%s> contains more tokens than allowed for %s classifier: "
  529. "%ud > %ud",
  530. MESSAGE_FIELD(task, message_id),
  531. cl->cfg->name,
  532. task->tokens->len,
  533. cl->cfg->max_tokens);
  534. too_large = TRUE;
  535. continue;
  536. }
  537. if (cl->subrs->learn_spam_func(cl, task->tokens, task, spam,
  538. task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) {
  539. learned = TRUE;
  540. }
  541. }
  542. if (sel == NULL) {
  543. if (classifier) {
  544. g_set_error(err, rspamd_stat_quark(), 404, "cannot find classifier "
  545. "with name %s",
  546. classifier);
  547. }
  548. else {
  549. g_set_error(err, rspamd_stat_quark(), 404, "no classifiers defined");
  550. }
  551. return FALSE;
  552. }
  553. if (!learned && err && *err == NULL) {
  554. if (too_large) {
  555. g_set_error(err, rspamd_stat_quark(), 204,
  556. "<%s> contains more tokens than allowed for %s classifier: "
  557. "%d > %d",
  558. MESSAGE_FIELD(task, message_id),
  559. sel->cfg->name,
  560. task->tokens->len,
  561. sel->cfg->max_tokens);
  562. }
  563. else if (too_small) {
  564. g_set_error(err, rspamd_stat_quark(), 204,
  565. "<%s> contains less tokens than required for %s classifier: "
  566. "%d < %d",
  567. MESSAGE_FIELD(task, message_id),
  568. sel->cfg->name,
  569. task->tokens->len,
  570. sel->cfg->min_tokens);
  571. }
  572. }
  573. return learned;
  574. }
  575. static gboolean
  576. rspamd_stat_backends_learn(struct rspamd_stat_ctx *st_ctx,
  577. struct rspamd_task *task,
  578. const gchar *classifier,
  579. gboolean spam,
  580. GError **err)
  581. {
  582. struct rspamd_classifier *cl, *sel = NULL;
  583. struct rspamd_statfile *st;
  584. gpointer bk_run;
  585. guint i, j;
  586. gint id;
  587. gboolean res = FALSE, backend_found = FALSE;
  588. for (i = 0; i < st_ctx->classifiers->len; i++) {
  589. cl = g_ptr_array_index(st_ctx->classifiers, i);
  590. /* Skip other classifiers if they are not needed */
  591. if (classifier != NULL && (cl->cfg->name == NULL ||
  592. g_ascii_strcasecmp(classifier, cl->cfg->name) != 0)) {
  593. continue;
  594. }
  595. if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) {
  596. res = TRUE;
  597. continue;
  598. }
  599. sel = cl;
  600. for (j = 0; j < cl->statfiles_ids->len; j++) {
  601. id = g_array_index(cl->statfiles_ids, gint, j);
  602. st = g_ptr_array_index(st_ctx->statfiles, id);
  603. bk_run = g_ptr_array_index(task->stat_runtimes, id);
  604. g_assert(st != NULL);
  605. if (bk_run == NULL) {
  606. /* XXX: must be error */
  607. if (task->result->passthrough_result) {
  608. /* Passthrough email, cannot learn */
  609. g_set_error(err, rspamd_stat_quark(), 204,
  610. "Cannot learn statistics when passthrough "
  611. "result has been set; not classified");
  612. res = FALSE;
  613. goto end;
  614. }
  615. msg_debug_task("no runtime for backend %s; classifier %s; symbol %s",
  616. st->backend->name, cl->cfg->name, st->stcf->symbol);
  617. continue;
  618. }
  619. /* We set sel merely when we have runtime */
  620. backend_found = TRUE;
  621. if (!(task->flags & RSPAMD_TASK_FLAG_UNLEARN)) {
  622. if (!!spam != !!st->stcf->is_spam) {
  623. /* If we are not unlearning, then do not touch another class */
  624. continue;
  625. }
  626. }
  627. if (!st->backend->learn_tokens(task, task->tokens, id, bk_run)) {
  628. g_set_error(err, rspamd_stat_quark(), 500,
  629. "Cannot push "
  630. "learned results to the backend");
  631. res = FALSE;
  632. goto end;
  633. }
  634. else {
  635. if (!!spam == !!st->stcf->is_spam) {
  636. st->backend->inc_learns(task, bk_run, st_ctx);
  637. }
  638. else if (task->flags & RSPAMD_TASK_FLAG_UNLEARN) {
  639. st->backend->dec_learns(task, bk_run, st_ctx);
  640. }
  641. res = TRUE;
  642. }
  643. }
  644. }
  645. end:
  646. if (!res) {
  647. if (err && *err) {
  648. /* Error has been set already */
  649. return res;
  650. }
  651. if (sel == NULL) {
  652. if (classifier) {
  653. g_set_error(err, rspamd_stat_quark(), 404, "cannot find classifier "
  654. "with name %s",
  655. classifier);
  656. }
  657. else {
  658. g_set_error(err, rspamd_stat_quark(), 404, "no classifiers defined");
  659. }
  660. return FALSE;
  661. }
  662. else if (!backend_found) {
  663. g_set_error(err, rspamd_stat_quark(), 204, "all learn conditions "
  664. "denied learning %s in %s",
  665. spam ? "spam" : "ham",
  666. classifier ? classifier : "default classifier");
  667. }
  668. else {
  669. g_set_error(err, rspamd_stat_quark(), 404, "cannot find statfile "
  670. "backend to learn %s in %s",
  671. spam ? "spam" : "ham",
  672. classifier ? classifier : "default classifier");
  673. }
  674. }
  675. return res;
  676. }
  677. static gboolean
  678. rspamd_stat_backends_post_learn(struct rspamd_stat_ctx *st_ctx,
  679. struct rspamd_task *task,
  680. const gchar *classifier,
  681. gboolean spam,
  682. GError **err)
  683. {
  684. struct rspamd_classifier *cl;
  685. struct rspamd_statfile *st;
  686. gpointer bk_run, cache_run;
  687. guint i, j;
  688. gint id;
  689. gboolean res = TRUE;
  690. for (i = 0; i < st_ctx->classifiers->len; i++) {
  691. cl = g_ptr_array_index(st_ctx->classifiers, i);
  692. /* Skip other classifiers if they are not needed */
  693. if (classifier != NULL && (cl->cfg->name == NULL ||
  694. g_ascii_strcasecmp(classifier, cl->cfg->name) != 0)) {
  695. continue;
  696. }
  697. if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) {
  698. res = TRUE;
  699. continue;
  700. }
  701. for (j = 0; j < cl->statfiles_ids->len; j++) {
  702. id = g_array_index(cl->statfiles_ids, gint, j);
  703. st = g_ptr_array_index(st_ctx->statfiles, id);
  704. bk_run = g_ptr_array_index(task->stat_runtimes, id);
  705. g_assert(st != NULL);
  706. if (bk_run == NULL) {
  707. /* XXX: must be error */
  708. continue;
  709. }
  710. if (!st->backend->finalize_learn(task, bk_run, st_ctx, err)) {
  711. return RSPAMD_STAT_PROCESS_ERROR;
  712. }
  713. }
  714. if (cl->cache) {
  715. cache_run = cl->cache->runtime(task, cl->cachecf, TRUE);
  716. cl->cache->learn(task, spam, cache_run);
  717. }
  718. }
  719. g_atomic_int_add(&task->worker->srv->stat->messages_learned, 1);
  720. return res;
  721. }
  722. rspamd_stat_result_t
  723. rspamd_stat_learn(struct rspamd_task *task,
  724. gboolean spam, lua_State *L, const gchar *classifier, guint stage,
  725. GError **err)
  726. {
  727. struct rspamd_stat_ctx *st_ctx;
  728. rspamd_stat_result_t ret = RSPAMD_STAT_PROCESS_OK;
  729. /*
  730. * We assume now that a task has been already classified before
  731. * coming to learn
  732. */
  733. g_assert(RSPAMD_TASK_IS_CLASSIFIED(task));
  734. st_ctx = rspamd_stat_get_ctx();
  735. g_assert(st_ctx != NULL);
  736. if (st_ctx->classifiers->len == 0) {
  737. task->processed_stages |= stage;
  738. return ret;
  739. }
  740. if (stage == RSPAMD_TASK_STAGE_LEARN_PRE) {
  741. /* Process classifiers */
  742. rspamd_stat_preprocess(st_ctx, task, TRUE, spam);
  743. if (!rspamd_stat_cache_check(st_ctx, task, classifier, spam, err)) {
  744. return RSPAMD_STAT_PROCESS_ERROR;
  745. }
  746. }
  747. else if (stage == RSPAMD_TASK_STAGE_LEARN) {
  748. /* Process classifiers */
  749. if (!rspamd_stat_classifiers_learn(st_ctx, task, classifier,
  750. spam, err)) {
  751. if (err && *err == NULL) {
  752. g_set_error(err, rspamd_stat_quark(), 500,
  753. "Unknown statistics error, found when learning classifiers;"
  754. " classifier: %s",
  755. task->classifier);
  756. }
  757. return RSPAMD_STAT_PROCESS_ERROR;
  758. }
  759. /* Process backends */
  760. if (!rspamd_stat_backends_learn(st_ctx, task, classifier, spam, err)) {
  761. if (err && *err == NULL) {
  762. g_set_error(err, rspamd_stat_quark(), 500,
  763. "Unknown statistics error, found when storing data on backend;"
  764. " classifier: %s",
  765. task->classifier);
  766. }
  767. return RSPAMD_STAT_PROCESS_ERROR;
  768. }
  769. }
  770. else if (stage == RSPAMD_TASK_STAGE_LEARN_POST) {
  771. if (!rspamd_stat_backends_post_learn(st_ctx, task, classifier, spam, err)) {
  772. return RSPAMD_STAT_PROCESS_ERROR;
  773. }
  774. }
  775. task->processed_stages |= stage;
  776. return ret;
  777. }
  778. static gboolean
  779. rspamd_stat_has_classifier_symbols(struct rspamd_task *task,
  780. struct rspamd_scan_result *mres,
  781. struct rspamd_classifier *cl)
  782. {
  783. guint i;
  784. gint id;
  785. struct rspamd_statfile *st;
  786. struct rspamd_stat_ctx *st_ctx;
  787. gboolean is_spam;
  788. if (mres == NULL) {
  789. return FALSE;
  790. }
  791. st_ctx = rspamd_stat_get_ctx();
  792. is_spam = !!(task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM);
  793. for (i = 0; i < cl->statfiles_ids->len; i++) {
  794. id = g_array_index(cl->statfiles_ids, gint, i);
  795. st = g_ptr_array_index(st_ctx->statfiles, id);
  796. if (rspamd_task_find_symbol_result(task, st->stcf->symbol, NULL)) {
  797. if (is_spam == !!st->stcf->is_spam) {
  798. msg_debug_bayes("do not autolearn %s as symbol %s is already "
  799. "added",
  800. is_spam ? "spam" : "ham", st->stcf->symbol);
  801. return TRUE;
  802. }
  803. }
  804. }
  805. return FALSE;
  806. }
  807. gboolean
  808. rspamd_stat_check_autolearn(struct rspamd_task *task)
  809. {
  810. struct rspamd_stat_ctx *st_ctx;
  811. struct rspamd_classifier *cl;
  812. const ucl_object_t *obj, *elt1, *elt2;
  813. struct rspamd_scan_result *mres = NULL;
  814. struct rspamd_task **ptask;
  815. lua_State *L;
  816. guint i;
  817. gint err_idx;
  818. gboolean ret = FALSE;
  819. gdouble ham_score, spam_score;
  820. const gchar *lua_script, *lua_ret;
  821. g_assert(RSPAMD_TASK_IS_CLASSIFIED(task));
  822. st_ctx = rspamd_stat_get_ctx();
  823. g_assert(st_ctx != NULL);
  824. L = task->cfg->lua_state;
  825. for (i = 0; i < st_ctx->classifiers->len; i++) {
  826. cl = g_ptr_array_index(st_ctx->classifiers, i);
  827. ret = FALSE;
  828. if (cl->cfg->opts) {
  829. obj = ucl_object_lookup(cl->cfg->opts, "autolearn");
  830. if (ucl_object_type(obj) == UCL_BOOLEAN) {
  831. /* Legacy true/false */
  832. if (ucl_object_toboolean(obj)) {
  833. /*
  834. * Default learning algorithm:
  835. *
  836. * - We learn spam if action is ACTION_REJECT
  837. * - We learn ham if score is less than zero
  838. */
  839. mres = task->result;
  840. if (mres) {
  841. if (mres->score > rspamd_task_get_required_score(task, mres)) {
  842. task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM;
  843. ret = TRUE;
  844. }
  845. else if (mres->score < 0) {
  846. task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
  847. ret = TRUE;
  848. }
  849. }
  850. }
  851. }
  852. else if (ucl_object_type(obj) == UCL_ARRAY && obj->len == 2) {
  853. /* Legacy thresholds */
  854. /*
  855. * We have an array of 2 elements, treat it as a
  856. * ham_score, spam_score
  857. */
  858. elt1 = ucl_array_find_index(obj, 0);
  859. elt2 = ucl_array_find_index(obj, 1);
  860. if ((ucl_object_type(elt1) == UCL_FLOAT ||
  861. ucl_object_type(elt1) == UCL_INT) &&
  862. (ucl_object_type(elt2) == UCL_FLOAT ||
  863. ucl_object_type(elt2) == UCL_INT)) {
  864. ham_score = ucl_object_todouble(elt1);
  865. spam_score = ucl_object_todouble(elt2);
  866. if (ham_score > spam_score) {
  867. gdouble t;
  868. t = ham_score;
  869. ham_score = spam_score;
  870. spam_score = t;
  871. }
  872. mres = task->result;
  873. if (mres) {
  874. if (mres->score >= spam_score) {
  875. task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM;
  876. ret = TRUE;
  877. }
  878. else if (mres->score <= ham_score) {
  879. task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
  880. ret = TRUE;
  881. }
  882. }
  883. }
  884. }
  885. else if (ucl_object_type(obj) == UCL_STRING) {
  886. /* Legacy script */
  887. lua_script = ucl_object_tostring(obj);
  888. if (luaL_dostring(L, lua_script) != 0) {
  889. msg_err_task("cannot execute lua script for autolearn "
  890. "extraction: %s",
  891. lua_tostring(L, -1));
  892. }
  893. else {
  894. if (lua_type(L, -1) == LUA_TFUNCTION) {
  895. lua_pushcfunction(L, &rspamd_lua_traceback);
  896. err_idx = lua_gettop(L);
  897. lua_pushvalue(L, -2); /* Function itself */
  898. ptask = lua_newuserdata(L, sizeof(struct rspamd_task *));
  899. *ptask = task;
  900. rspamd_lua_setclass(L, rspamd_task_classname, -1);
  901. if (lua_pcall(L, 1, 1, err_idx) != 0) {
  902. msg_err_task("call to autolearn script failed: "
  903. "%s",
  904. lua_tostring(L, -1));
  905. }
  906. else {
  907. lua_ret = lua_tostring(L, -1);
  908. /* We can have immediate results */
  909. if (lua_ret) {
  910. if (strcmp(lua_ret, "ham") == 0) {
  911. task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
  912. ret = TRUE;
  913. }
  914. else if (strcmp(lua_ret, "spam") == 0) {
  915. task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM;
  916. ret = TRUE;
  917. }
  918. }
  919. }
  920. /* Result + error function + original function */
  921. lua_pop(L, 3);
  922. }
  923. else {
  924. msg_err_task("lua script must return "
  925. "function(task) and not %s",
  926. lua_typename(L, lua_type(
  927. L, -1)));
  928. }
  929. }
  930. }
  931. else if (ucl_object_type(obj) == UCL_OBJECT) {
  932. /* Try to find autolearn callback */
  933. if (cl->autolearn_cbref == 0) {
  934. /* We don't have preprocessed cb id, so try to get it */
  935. if (!rspamd_lua_require_function(L, "lua_bayes_learn",
  936. "autolearn")) {
  937. msg_err_task("cannot get autolearn library from "
  938. "`lua_bayes_learn`");
  939. }
  940. else {
  941. cl->autolearn_cbref = luaL_ref(L, LUA_REGISTRYINDEX);
  942. }
  943. }
  944. if (cl->autolearn_cbref != -1) {
  945. lua_pushcfunction(L, &rspamd_lua_traceback);
  946. err_idx = lua_gettop(L);
  947. lua_rawgeti(L, LUA_REGISTRYINDEX, cl->autolearn_cbref);
  948. ptask = lua_newuserdata(L, sizeof(struct rspamd_task *));
  949. *ptask = task;
  950. rspamd_lua_setclass(L, rspamd_task_classname, -1);
  951. /* Push the whole object as well */
  952. ucl_object_push_lua(L, obj, true);
  953. if (lua_pcall(L, 2, 1, err_idx) != 0) {
  954. msg_err_task("call to autolearn script failed: "
  955. "%s",
  956. lua_tostring(L, -1));
  957. }
  958. else {
  959. lua_ret = lua_tostring(L, -1);
  960. if (lua_ret) {
  961. if (strcmp(lua_ret, "ham") == 0) {
  962. task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
  963. ret = TRUE;
  964. }
  965. else if (strcmp(lua_ret, "spam") == 0) {
  966. task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM;
  967. ret = TRUE;
  968. }
  969. }
  970. }
  971. lua_settop(L, err_idx - 1);
  972. }
  973. }
  974. if (ret) {
  975. /* Do not autolearn if we have this symbol already */
  976. if (rspamd_stat_has_classifier_symbols(task, mres, cl)) {
  977. ret = FALSE;
  978. task->flags &= ~(RSPAMD_TASK_FLAG_LEARN_HAM |
  979. RSPAMD_TASK_FLAG_LEARN_SPAM);
  980. }
  981. else if (mres != NULL) {
  982. if (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM) {
  983. msg_info_task("<%s>: autolearn ham for classifier "
  984. "'%s' as message's "
  985. "score is negative: %.2f",
  986. MESSAGE_FIELD(task, message_id), cl->cfg->name,
  987. mres->score);
  988. }
  989. else {
  990. msg_info_task("<%s>: autolearn spam for classifier "
  991. "'%s' as message's "
  992. "action is reject, score: %.2f",
  993. MESSAGE_FIELD(task, message_id), cl->cfg->name,
  994. mres->score);
  995. }
  996. task->classifier = cl->cfg->name;
  997. break;
  998. }
  999. }
  1000. }
  1001. }
  1002. return ret;
  1003. }
  1004. /**
  1005. * Get the overall statistics for all statfile backends
  1006. * @param cfg configuration
  1007. * @param total_learns the total number of learns is stored here
  1008. * @return array of statistical information
  1009. */
  1010. rspamd_stat_result_t
  1011. rspamd_stat_statistics(struct rspamd_task *task,
  1012. struct rspamd_config *cfg,
  1013. uint64_t *total_learns,
  1014. ucl_object_t **target)
  1015. {
  1016. struct rspamd_stat_ctx *st_ctx;
  1017. struct rspamd_classifier *cl;
  1018. struct rspamd_statfile *st;
  1019. gpointer backend_runtime;
  1020. ucl_object_t *res = NULL, *elt;
  1021. uint64_t learns = 0;
  1022. guint i, j;
  1023. gint id;
  1024. st_ctx = rspamd_stat_get_ctx();
  1025. g_assert(st_ctx != NULL);
  1026. res = ucl_object_typed_new(UCL_ARRAY);
  1027. for (i = 0; i < st_ctx->classifiers->len; i++) {
  1028. cl = g_ptr_array_index(st_ctx->classifiers, i);
  1029. if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) {
  1030. continue;
  1031. }
  1032. for (j = 0; j < cl->statfiles_ids->len; j++) {
  1033. id = g_array_index(cl->statfiles_ids, gint, j);
  1034. st = g_ptr_array_index(st_ctx->statfiles, id);
  1035. backend_runtime = st->backend->runtime(task, st->stcf, FALSE,
  1036. st->bkcf, id);
  1037. elt = st->backend->get_stat(backend_runtime, st->bkcf);
  1038. if (elt && ucl_object_type(elt) == UCL_OBJECT) {
  1039. const ucl_object_t *rev = ucl_object_lookup(elt, "revision");
  1040. learns += ucl_object_toint(rev);
  1041. }
  1042. else {
  1043. learns += st->backend->total_learns(task, backend_runtime,
  1044. st->bkcf);
  1045. }
  1046. if (elt != NULL) {
  1047. ucl_array_append(res, elt);
  1048. }
  1049. }
  1050. }
  1051. if (total_learns != NULL) {
  1052. *total_learns = learns;
  1053. }
  1054. if (target) {
  1055. *target = res;
  1056. }
  1057. else {
  1058. ucl_object_unref(res);
  1059. }
  1060. return RSPAMD_STAT_PROCESS_OK;
  1061. }