您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

redis_backend.cxx 27KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139
  1. /*
  2. * Copyright 2024 Vsevolod Stakhov
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "config.h"
  17. #include "lua/lua_common.h"
  18. #include "rspamd.h"
  19. #include "stat_internal.h"
  20. #include "upstream.h"
  21. #include "libserver/mempool_vars_internal.h"
  22. #include "fmt/core.h"
  23. #include "libutil/cxx/error.hxx"
  24. #include <string>
  25. #include <cstdint>
  26. #include <vector>
  27. #include <optional>
  28. #define msg_debug_stat_redis(...) rspamd_conditional_debug_fast(nullptr, nullptr, \
  29. rspamd_stat_redis_log_id, "stat_redis", task->task_pool->tag.uid, \
  30. RSPAMD_LOG_FUNC, \
  31. __VA_ARGS__)
  32. INIT_LOG_MODULE(stat_redis)
  33. #define REDIS_CTX(p) (reinterpret_cast<struct redis_stat_ctx *>(p))
  34. #define REDIS_RUNTIME(p) (reinterpret_cast<struct redis_stat_runtime<float> *>(p))
  35. #define REDIS_DEFAULT_OBJECT "%s%l"
  36. #define REDIS_DEFAULT_USERS_OBJECT "%s%l%r"
  37. #define REDIS_DEFAULT_TIMEOUT 0.5
  38. #define REDIS_STAT_TIMEOUT 30
  39. #define REDIS_MAX_USERS 1000
  40. struct redis_stat_ctx {
  41. lua_State *L;
  42. struct rspamd_statfile_config *stcf;
  43. const char *redis_object = REDIS_DEFAULT_OBJECT;
  44. bool enable_users = false;
  45. bool store_tokens = false;
  46. bool enable_signatures = false;
  47. int cbref_user = -1;
  48. int cbref_classify = -1;
  49. int cbref_learn = -1;
  50. ucl_object_t *cur_stat = nullptr;
  51. explicit redis_stat_ctx(lua_State *_L)
  52. : L(_L)
  53. {
  54. }
  55. ~redis_stat_ctx()
  56. {
  57. if (cbref_user != -1) {
  58. luaL_unref(L, LUA_REGISTRYINDEX, cbref_user);
  59. }
  60. if (cbref_classify != -1) {
  61. luaL_unref(L, LUA_REGISTRYINDEX, cbref_classify);
  62. }
  63. if (cbref_learn != -1) {
  64. luaL_unref(L, LUA_REGISTRYINDEX, cbref_learn);
  65. }
  66. }
  67. };
  68. template<class T, std::enable_if_t<std::is_convertible_v<T, float>, bool> = true>
  69. struct redis_stat_runtime {
  70. struct redis_stat_ctx *ctx;
  71. struct rspamd_task *task;
  72. struct rspamd_statfile_config *stcf;
  73. GPtrArray *tokens = nullptr;
  74. const char *redis_object_expanded;
  75. std::uint64_t learned = 0;
  76. int id;
  77. std::vector<std::pair<int, T>> *results = nullptr;
  78. bool need_redis_call = true;
  79. std::optional<rspamd::util::error> err;
  80. using result_type = std::vector<std::pair<int, T>>;
  81. private:
  82. /* Called on connection termination */
  83. static void rt_dtor(gpointer data)
  84. {
  85. auto *rt = REDIS_RUNTIME(data);
  86. delete rt;
  87. }
  88. /* Avoid occasional deletion */
  89. ~redis_stat_runtime()
  90. {
  91. if (tokens) {
  92. g_ptr_array_unref(tokens);
  93. }
  94. delete results;
  95. }
  96. public:
  97. explicit redis_stat_runtime(struct redis_stat_ctx *_ctx, struct rspamd_task *_task, const char *_redis_object_expanded)
  98. : ctx(_ctx), task(_task), stcf(_ctx->stcf), redis_object_expanded(_redis_object_expanded)
  99. {
  100. rspamd_mempool_add_destructor(task->task_pool, redis_stat_runtime<T>::rt_dtor, this);
  101. }
  102. static auto maybe_recover_from_mempool(struct rspamd_task *task, const char *redis_object_expanded,
  103. bool is_spam) -> std::optional<redis_stat_runtime<T> *>
  104. {
  105. auto var_name = fmt::format("{}_{}", redis_object_expanded, is_spam ? "S" : "H");
  106. auto *res = rspamd_mempool_get_variable(task->task_pool, var_name.c_str());
  107. if (res) {
  108. msg_debug_bayes("recovered runtime from mempool at %s", var_name.c_str());
  109. return reinterpret_cast<redis_stat_runtime<T> *>(res);
  110. }
  111. else {
  112. msg_debug_bayes("no runtime at %s", var_name.c_str());
  113. return std::nullopt;
  114. }
  115. }
  116. void set_results(std::vector<std::pair<int, T>> *results)
  117. {
  118. this->results = results;
  119. }
  120. /* Propagate results from internal representation to the tokens array */
  121. auto process_tokens(GPtrArray *tokens) const -> bool
  122. {
  123. rspamd_token_t *tok;
  124. if (!results) {
  125. return false;
  126. }
  127. for (auto [idx, val]: *results) {
  128. tok = (rspamd_token_t *) g_ptr_array_index(tokens, idx - 1);
  129. tok->values[id] = val;
  130. }
  131. return true;
  132. }
  133. auto save_in_mempool(bool is_spam) const
  134. {
  135. auto var_name = fmt::format("{}_{}", redis_object_expanded, is_spam ? "S" : "H");
  136. /* We do not set destructor for the variable, as it should be already added on creation */
  137. rspamd_mempool_set_variable(task->task_pool, var_name.c_str(), (gpointer) this, nullptr);
  138. msg_debug_bayes("saved runtime in mempool at %s", var_name.c_str());
  139. }
  140. };
  /*
   * Null-safe member access: evaluates to nullptr when `task` is null and to
   * (task)->elt otherwise. Arguments and the whole expansion are parenthesised
   * so that compound expressions can be passed and the result can be embedded
   * into a larger expression safely.
   */
  #define GET_TASK_ELT(task, elt) ((task) == nullptr ? nullptr : (task)->elt)
  /* Module description string, used as the source of the error quark below */
  static const char *M = "redis statistics";
  143. static GQuark
  144. rspamd_redis_stat_quark(void)
  145. {
  146. return g_quark_from_static_string(M);
  147. }
  148. /*
  149. * Non-static for lua unit testing
  150. */
  151. gsize rspamd_redis_expand_object(const char *pattern,
  152. struct redis_stat_ctx *ctx,
  153. struct rspamd_task *task,
  154. char **target)
  155. {
  156. gsize tlen = 0;
  157. const char *p = pattern, *elt;
  158. char *d, *end;
  159. enum {
  160. just_char,
  161. percent_char,
  162. mod_char
  163. } state = just_char;
  164. struct rspamd_statfile_config *stcf;
  165. lua_State *L = nullptr;
  166. struct rspamd_task **ptask;
  167. const char *rcpt = nullptr;
  168. int err_idx;
  169. g_assert(ctx != nullptr);
  170. g_assert(task != nullptr);
  171. stcf = ctx->stcf;
  172. L = RSPAMD_LUA_CFG_STATE(task->cfg);
  173. g_assert(L != nullptr);
  174. if (ctx->enable_users) {
  175. if (ctx->cbref_user == -1) {
  176. rcpt = rspamd_task_get_principal_recipient(task);
  177. }
  178. else {
  179. /* Execute lua function to get userdata */
  180. lua_pushcfunction(L, &rspamd_lua_traceback);
  181. err_idx = lua_gettop(L);
  182. lua_rawgeti(L, LUA_REGISTRYINDEX, ctx->cbref_user);
  183. ptask = (struct rspamd_task **) lua_newuserdata(L, sizeof(struct rspamd_task *));
  184. *ptask = task;
  185. rspamd_lua_setclass(L, rspamd_task_classname, -1);
  186. if (lua_pcall(L, 1, 1, err_idx) != 0) {
  187. msg_err_task("call to user extraction script failed: %s",
  188. lua_tostring(L, -1));
  189. }
  190. else {
  191. rcpt = rspamd_mempool_strdup(task->task_pool, lua_tostring(L, -1));
  192. }
  193. /* Result + error function */
  194. lua_settop(L, err_idx - 1);
  195. }
  196. if (rcpt) {
  197. rspamd_mempool_set_variable(task->task_pool, "stat_user",
  198. (gpointer) rcpt, nullptr);
  199. }
  200. }
  201. /* Length calculation */
  202. while (*p) {
  203. switch (state) {
  204. case just_char:
  205. if (*p == '%') {
  206. state = percent_char;
  207. }
  208. else {
  209. tlen++;
  210. }
  211. p++;
  212. break;
  213. case percent_char:
  214. switch (*p) {
  215. case '%':
  216. tlen++;
  217. state = just_char;
  218. break;
  219. case 'u':
  220. elt = GET_TASK_ELT(task, auth_user);
  221. if (elt) {
  222. tlen += strlen(elt);
  223. }
  224. break;
  225. case 'r':
  226. if (rcpt == nullptr) {
  227. elt = rspamd_task_get_principal_recipient(task);
  228. }
  229. else {
  230. elt = rcpt;
  231. }
  232. if (elt) {
  233. tlen += strlen(elt);
  234. }
  235. break;
  236. case 'l':
  237. if (stcf->label) {
  238. tlen += strlen(stcf->label);
  239. }
  240. /* Label miss is OK */
  241. break;
  242. case 's':
  243. tlen += sizeof("RS") - 1;
  244. break;
  245. default:
  246. state = just_char;
  247. tlen++;
  248. break;
  249. }
  250. if (state == percent_char) {
  251. state = mod_char;
  252. }
  253. p++;
  254. break;
  255. case mod_char:
  256. switch (*p) {
  257. case 'd':
  258. p++;
  259. state = just_char;
  260. break;
  261. default:
  262. state = just_char;
  263. break;
  264. }
  265. break;
  266. }
  267. }
  268. if (target == nullptr) {
  269. return -1;
  270. }
  271. *target = (char *) rspamd_mempool_alloc(task->task_pool, tlen + 1);
  272. d = *target;
  273. end = d + tlen + 1;
  274. d[tlen] = '\0';
  275. p = pattern;
  276. state = just_char;
  277. /* Expand string */
  278. while (*p && d < end) {
  279. switch (state) {
  280. case just_char:
  281. if (*p == '%') {
  282. state = percent_char;
  283. }
  284. else {
  285. *d++ = *p;
  286. }
  287. p++;
  288. break;
  289. case percent_char:
  290. switch (*p) {
  291. case '%':
  292. *d++ = *p;
  293. state = just_char;
  294. break;
  295. case 'u':
  296. elt = GET_TASK_ELT(task, auth_user);
  297. if (elt) {
  298. d += rspamd_strlcpy(d, elt, end - d);
  299. }
  300. break;
  301. case 'r':
  302. if (rcpt == nullptr) {
  303. elt = rspamd_task_get_principal_recipient(task);
  304. }
  305. else {
  306. elt = rcpt;
  307. }
  308. if (elt) {
  309. d += rspamd_strlcpy(d, elt, end - d);
  310. }
  311. break;
  312. case 'l':
  313. if (stcf->label) {
  314. d += rspamd_strlcpy(d, stcf->label, end - d);
  315. }
  316. break;
  317. case 's':
  318. d += rspamd_strlcpy(d, "RS", end - d);
  319. break;
  320. default:
  321. state = just_char;
  322. *d++ = *p;
  323. break;
  324. }
  325. if (state == percent_char) {
  326. state = mod_char;
  327. }
  328. p++;
  329. break;
  330. case mod_char:
  331. switch (*p) {
  332. case 'd':
  333. /* TODO: not supported yet */
  334. p++;
  335. state = just_char;
  336. break;
  337. default:
  338. state = just_char;
  339. break;
  340. }
  341. break;
  342. }
  343. }
  344. return tlen;
  345. }
  346. static int
  347. rspamd_redis_stat_cb(lua_State *L)
  348. {
  349. const auto *cookie = lua_tostring(L, lua_upvalueindex(1));
  350. auto *cfg = lua_check_config(L, 1);
  351. auto *backend = REDIS_CTX(rspamd_mempool_get_variable(cfg->cfg_pool, cookie));
  352. if (backend == nullptr) {
  353. msg_err("internal error: cookie %s is not found", cookie);
  354. return 0;
  355. }
  356. auto *cur_obj = ucl_object_lua_import(L, 2);
  357. msg_debug_bayes_cfg("got stat object for %s", backend->stcf->symbol);
  358. /* Enrich with some default values that are meaningless for redis */
  359. ucl_object_insert_key(cur_obj,
  360. ucl_object_typed_new(UCL_INT), "used", 0, false);
  361. ucl_object_insert_key(cur_obj,
  362. ucl_object_typed_new(UCL_INT), "total", 0, false);
  363. ucl_object_insert_key(cur_obj,
  364. ucl_object_typed_new(UCL_INT), "size", 0, false);
  365. ucl_object_insert_key(cur_obj,
  366. ucl_object_fromstring(backend->stcf->symbol),
  367. "symbol", 0, false);
  368. ucl_object_insert_key(cur_obj, ucl_object_fromstring("redis"),
  369. "type", 0, false);
  370. ucl_object_insert_key(cur_obj, ucl_object_fromint(0),
  371. "languages", 0, false);
  372. if (backend->cur_stat) {
  373. ucl_object_unref(backend->cur_stat);
  374. }
  375. backend->cur_stat = cur_obj;
  376. return 0;
  377. }
  378. static void
  379. rspamd_redis_parse_classifier_opts(struct redis_stat_ctx *backend,
  380. const ucl_object_t *statfile_obj,
  381. const ucl_object_t *classifier_obj,
  382. struct rspamd_config *cfg)
  383. {
  384. const char *lua_script;
  385. const ucl_object_t *elt, *users_enabled;
  386. auto *L = RSPAMD_LUA_CFG_STATE(cfg);
  387. users_enabled = ucl_object_lookup_any(classifier_obj, "per_user",
  388. "users_enabled", nullptr);
  389. if (users_enabled != nullptr) {
  390. if (ucl_object_type(users_enabled) == UCL_BOOLEAN) {
  391. backend->enable_users = ucl_object_toboolean(users_enabled);
  392. backend->cbref_user = -1;
  393. }
  394. else if (ucl_object_type(users_enabled) == UCL_STRING) {
  395. lua_script = ucl_object_tostring(users_enabled);
  396. if (luaL_dostring(L, lua_script) != 0) {
  397. msg_err_config("cannot execute lua script for users "
  398. "extraction: %s",
  399. lua_tostring(L, -1));
  400. }
  401. else {
  402. if (lua_type(L, -1) == LUA_TFUNCTION) {
  403. backend->enable_users = TRUE;
  404. backend->cbref_user = luaL_ref(L,
  405. LUA_REGISTRYINDEX);
  406. }
  407. else {
  408. msg_err_config("lua script must return "
  409. "function(task) and not %s",
  410. lua_typename(L, lua_type(L, -1)));
  411. }
  412. }
  413. }
  414. }
  415. else {
  416. backend->enable_users = FALSE;
  417. backend->cbref_user = -1;
  418. }
  419. elt = ucl_object_lookup(classifier_obj, "prefix");
  420. if (elt == nullptr || ucl_object_type(elt) != UCL_STRING) {
  421. /* Default non-users statistics */
  422. if (backend->enable_users || backend->cbref_user != -1) {
  423. backend->redis_object = REDIS_DEFAULT_USERS_OBJECT;
  424. }
  425. else {
  426. backend->redis_object = REDIS_DEFAULT_OBJECT;
  427. }
  428. }
  429. else {
  430. /* XXX: sanity check */
  431. backend->redis_object = ucl_object_tostring(elt);
  432. }
  433. elt = ucl_object_lookup(classifier_obj, "store_tokens");
  434. if (elt) {
  435. backend->store_tokens = ucl_object_toboolean(elt);
  436. }
  437. else {
  438. backend->store_tokens = FALSE;
  439. }
  440. elt = ucl_object_lookup(classifier_obj, "signatures");
  441. if (elt) {
  442. backend->enable_signatures = ucl_object_toboolean(elt);
  443. }
  444. else {
  445. backend->enable_signatures = FALSE;
  446. }
  447. }
  448. gpointer
  449. rspamd_redis_init(struct rspamd_stat_ctx *ctx,
  450. struct rspamd_config *cfg, struct rspamd_statfile *st)
  451. {
  452. auto *L = RSPAMD_LUA_CFG_STATE(cfg);
  453. auto backend = std::make_unique<struct redis_stat_ctx>(L);
  454. lua_settop(L, 0);
  455. rspamd_redis_parse_classifier_opts(backend.get(), st->stcf->opts, st->classifier->cfg->opts, cfg);
  456. st->stcf->clcf->flags |= RSPAMD_FLAG_CLASSIFIER_INCREMENTING_BACKEND;
  457. backend->stcf = st->stcf;
  458. lua_pushcfunction(L, &rspamd_lua_traceback);
  459. auto err_idx = lua_gettop(L);
  460. /* Obtain function */
  461. if (!rspamd_lua_require_function(L, "lua_bayes_redis", "lua_bayes_init_statfile")) {
  462. msg_err_config("cannot require lua_bayes_redis.lua_bayes_init_statfile");
  463. lua_settop(L, err_idx - 1);
  464. return nullptr;
  465. }
  466. /* Push arguments */
  467. ucl_object_push_lua(L, st->classifier->cfg->opts, false);
  468. ucl_object_push_lua(L, st->stcf->opts, false);
  469. lua_pushstring(L, backend->stcf->symbol);
  470. lua_pushboolean(L, backend->stcf->is_spam);
  471. /* Push event loop if there is one available (e.g. we are not in rspamadm mode) */
  472. if (ctx->event_loop) {
  473. auto **pev_base = (struct ev_loop **) lua_newuserdata(L, sizeof(struct ev_loop *));
  474. *pev_base = ctx->event_loop;
  475. rspamd_lua_setclass(L, rspamd_ev_base_classname, -1);
  476. }
  477. else {
  478. lua_pushnil(L);
  479. }
  480. /* Store backend in random cookie */
  481. char *cookie = (char *) rspamd_mempool_alloc(cfg->cfg_pool, 16);
  482. rspamd_random_hex(cookie, 16);
  483. cookie[15] = '\0';
  484. rspamd_mempool_set_variable(cfg->cfg_pool, cookie, backend.get(), nullptr);
  485. /* Callback + 1 upvalue */
  486. lua_pushstring(L, cookie);
  487. lua_pushcclosure(L, &rspamd_redis_stat_cb, 1);
  488. if (lua_pcall(L, 6, 2, err_idx) != 0) {
  489. msg_err("call to lua_bayes_init_classifier "
  490. "script failed: %s",
  491. lua_tostring(L, -1));
  492. lua_settop(L, err_idx - 1);
  493. return nullptr;
  494. }
  495. /* Results are in the stack:
  496. * top - 1 - classifier function (idx = -2)
  497. * top - learn function (idx = -1)
  498. */
  499. lua_pushvalue(L, -2);
  500. backend->cbref_classify = luaL_ref(L, LUA_REGISTRYINDEX);
  501. lua_pushvalue(L, -1);
  502. backend->cbref_learn = luaL_ref(L, LUA_REGISTRYINDEX);
  503. lua_settop(L, err_idx - 1);
  504. return backend.release();
  505. }
  506. gpointer
  507. rspamd_redis_runtime(struct rspamd_task *task,
  508. struct rspamd_statfile_config *stcf,
  509. gboolean learn, gpointer c, int _id)
  510. {
  511. struct redis_stat_ctx *ctx = REDIS_CTX(c);
  512. char *object_expanded = nullptr;
  513. g_assert(ctx != nullptr);
  514. g_assert(stcf != nullptr);
  515. if (rspamd_redis_expand_object(ctx->redis_object, ctx, task,
  516. &object_expanded) == 0) {
  517. msg_err_task("expansion for %s failed for symbol %s "
  518. "(maybe learning per user classifier with no user or recipient)",
  519. learn ? "learning" : "classifying",
  520. stcf->symbol);
  521. return nullptr;
  522. }
  523. /* Look for the cached results */
  524. if (!learn) {
  525. auto maybe_existing = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
  526. object_expanded, stcf->is_spam);
  527. if (maybe_existing) {
  528. auto *rt = maybe_existing.value();
  529. /* Update stcf and ctx to correspond to what we have been asked */
  530. rt->stcf = stcf;
  531. rt->ctx = ctx;
  532. return rt;
  533. }
  534. }
  535. /* No cached result (or learn), create new one */
  536. auto *rt = new redis_stat_runtime<float>(ctx, task, object_expanded);
  537. if (!learn) {
  538. /*
  539. * For check, we also need to create the opposite class runtime to avoid
  540. * double call for Redis scripts.
  541. * This runtime will be filled later.
  542. */
  543. auto maybe_opposite_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
  544. object_expanded,
  545. !stcf->is_spam);
  546. if (!maybe_opposite_rt) {
  547. auto *opposite_rt = new redis_stat_runtime<float>(ctx, task, object_expanded);
  548. opposite_rt->save_in_mempool(!stcf->is_spam);
  549. opposite_rt->need_redis_call = false;
  550. }
  551. }
  552. rt->save_in_mempool(stcf->is_spam);
  553. return rt;
  554. }
  555. void rspamd_redis_close(gpointer p)
  556. {
  557. struct redis_stat_ctx *ctx = REDIS_CTX(p);
  558. delete ctx;
  559. }
  /*
   * Writes `st` to `out` as a messagepack string: fixstr (up to 31 bytes),
   * str 8, str 16 or str 32 header followed by the raw bytes.
   * The caller must provide enough room (see msgpack_str_len()).
   *
   * @return number of bytes written (header + payload)
   */
  static constexpr auto
  msgpack_emit_str(const std::string_view st, char *out) -> std::size_t
  {
      constexpr const unsigned char fixstr_tag = 0xA0, str8_tag = 0xd9, str16_tag = 0xda, str32_tag = 0xdb;
      const auto payload_len = st.size();
      std::size_t header_len = 0;

      if (payload_len <= 0x1F) {
          /* Length lives in the low 5 bits of the tag itself */
          out[header_len++] = static_cast<char>((payload_len | fixstr_tag) & 0xff);
      }
      else if (payload_len <= 0xff) {
          out[header_len++] = static_cast<char>(str8_tag);
          out[header_len++] = static_cast<char>(payload_len & 0xff);
      }
      else if (payload_len <= 0xffff) {
          /* 16 bit big-endian length */
          out[header_len++] = static_cast<char>(str16_tag);
          out[header_len++] = static_cast<char>((payload_len >> 8) & 0xff);
          out[header_len++] = static_cast<char>(payload_len & 0xff);
      }
      else {
          /* 32 bit big-endian length */
          out[header_len++] = static_cast<char>(str32_tag);
          out[header_len++] = static_cast<char>((payload_len >> 24) & 0xff);
          out[header_len++] = static_cast<char>((payload_len >> 16) & 0xff);
          out[header_len++] = static_cast<char>((payload_len >> 8) & 0xff);
          out[header_len++] = static_cast<char>(payload_len & 0xff);
      }

      memcpy(out + header_len, st.data(), payload_len);

      return header_len + payload_len;
  }
  /*
   * Number of bytes msgpack_emit_str() needs for a string of `len` bytes:
   * header size plus the payload.
   *
   * Header sizes: fixstr - 1, str 8 - 2, str 16 - 3, str 32 - 5 (tag byte plus
   * a 4 byte length).
   */
  static constexpr auto
  msgpack_str_len(std::size_t len) -> std::size_t
  {
      if (len <= 0x1F) {
          return 1 + len;
      }
      else if (len <= 0xff) {
          return 2 + len;
      }
      else if (len <= 0xffff) {
          return 3 + len;
      }
      else {
          /* 1 tag byte + 4 length bytes, must match the emitter */
          return 5 + len;
      }
  }
  606. /*
  607. * Serialise stat tokens to message pack
  608. */
  609. static char *
  610. rspamd_redis_serialize_tokens(struct rspamd_task *task, const char *prefix, GPtrArray *tokens, gsize *ser_len)
  611. {
  612. /* Each token is int64_t that requires 10 bytes (2 int32_t) + 4 bytes array len + 1 byte array magic */
  613. char max_int64_str[] = "18446744073709551615";
  614. auto prefix_len = strlen(prefix);
  615. std::size_t req_len = 5;
  616. rspamd_token_t *tok;
  617. /* Calculate required length */
  618. req_len += tokens->len * (msgpack_str_len(sizeof(max_int64_str) + prefix_len) + 1);
  619. auto *buf = (char *) rspamd_mempool_alloc(task->task_pool, req_len);
  620. auto *p = buf;
  621. /* Array */
  622. *p++ = (char) 0xdd;
  623. /* Length in big-endian (4 bytes) */
  624. *p++ = (char) ((tokens->len >> 24) & 0xff);
  625. *p++ = (char) ((tokens->len >> 16) & 0xff);
  626. *p++ = (char) ((tokens->len >> 8) & 0xff);
  627. *p++ = (char) (tokens->len & 0xff);
  628. int i;
  629. auto numbuf_len = sizeof(max_int64_str) + prefix_len + 1;
  630. auto *numbuf = (char *) g_alloca(numbuf_len);
  631. PTR_ARRAY_FOREACH(tokens, i, tok)
  632. {
  633. std::size_t r = rspamd_snprintf(numbuf, numbuf_len, "%s_%uL", prefix, tok->data);
  634. auto shift = msgpack_emit_str({numbuf, r}, p);
  635. p += shift;
  636. }
  637. *ser_len = p - buf;
  638. return buf;
  639. }
  640. static char *
  641. rspamd_redis_serialize_text_tokens(struct rspamd_task *task, GPtrArray *tokens, gsize *ser_len)
  642. {
  643. rspamd_token_t *tok;
  644. auto req_len = 5; /* Messagepack array prefix */
  645. int i;
  646. /*
  647. * First we need to determine the requested length
  648. */
  649. PTR_ARRAY_FOREACH(tokens, i, tok)
  650. {
  651. if (tok->t1 && tok->t2) {
  652. /* Two tokens */
  653. req_len += msgpack_str_len(tok->t1->stemmed.len) + msgpack_str_len(tok->t2->stemmed.len);
  654. }
  655. else if (tok->t1) {
  656. req_len += msgpack_str_len(tok->t1->stemmed.len);
  657. req_len += 1; /* null */
  658. }
  659. else {
  660. req_len += 2; /* 2 nulls */
  661. }
  662. }
  663. auto *buf = (char *) rspamd_mempool_alloc(task->task_pool, req_len);
  664. auto *p = buf;
  665. /* Array */
  666. std::uint32_t nlen = tokens->len * 2;
  667. nlen = GUINT32_TO_BE(nlen);
  668. *p++ = (char) 0xdd;
  669. /* Length in big-endian (4 bytes) */
  670. memcpy(p, &nlen, sizeof(nlen));
  671. p += sizeof(nlen);
  672. PTR_ARRAY_FOREACH(tokens, i, tok)
  673. {
  674. if (tok->t1 && tok->t2) {
  675. auto step = msgpack_emit_str({tok->t1->stemmed.begin, tok->t1->stemmed.len}, p);
  676. p += step;
  677. step = msgpack_emit_str({tok->t2->stemmed.begin, tok->t2->stemmed.len}, p);
  678. p += step;
  679. }
  680. else if (tok->t1) {
  681. auto step = msgpack_emit_str({tok->t1->stemmed.begin, tok->t1->stemmed.len}, p);
  682. p += step;
  683. *p++ = 0xc0;
  684. }
  685. else {
  686. *p++ = 0xc0;
  687. *p++ = 0xc0;
  688. }
  689. }
  690. *ser_len = p - buf;
  691. return buf;
  692. }
  693. static int
  694. rspamd_redis_classified(lua_State *L)
  695. {
  696. const auto *cookie = lua_tostring(L, lua_upvalueindex(1));
  697. auto *task = lua_check_task(L, 1);
  698. auto *rt = REDIS_RUNTIME(rspamd_mempool_get_variable(task->task_pool, cookie));
  699. if (rt == nullptr) {
  700. msg_err_task("internal error: cannot find runtime for cookie %s", cookie);
  701. return 0;
  702. }
  703. bool result = lua_toboolean(L, 2);
  704. if (result) {
  705. /* Indexes:
  706. * 3 - learned_ham (int)
  707. * 4 - learned_spam (int)
  708. * 5 - ham_tokens (pair<int, int>)
  709. * 6 - spam_tokens (pair<int, int>)
  710. */
  711. /*
  712. * We need to fill our runtime AND the opposite runtime
  713. */
  714. auto filler_func = [](redis_stat_runtime<float> *rt, lua_State *L, unsigned learned, int tokens_pos) {
  715. rt->learned = learned;
  716. redis_stat_runtime<float>::result_type *res;
  717. res = new redis_stat_runtime<float>::result_type();
  718. for (lua_pushnil(L); lua_next(L, tokens_pos); lua_pop(L, 1)) {
  719. lua_rawgeti(L, -1, 1);
  720. auto idx = lua_tointeger(L, -1);
  721. lua_pop(L, 1);
  722. lua_rawgeti(L, -1, 2);
  723. auto value = lua_tonumber(L, -1);
  724. lua_pop(L, 1);
  725. res->emplace_back(idx, value);
  726. }
  727. rt->set_results(res);
  728. };
  729. auto opposite_rt_maybe = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
  730. rt->redis_object_expanded,
  731. !rt->stcf->is_spam);
  732. if (!opposite_rt_maybe) {
  733. msg_err_task("internal error: cannot find opposite runtime for cookie %s", cookie);
  734. return 0;
  735. }
  736. if (rt->stcf->is_spam) {
  737. filler_func(rt, L, lua_tointeger(L, 4), 6);
  738. filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 3), 5);
  739. }
  740. else {
  741. filler_func(rt, L, lua_tointeger(L, 3), 5);
  742. filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 4), 6);
  743. }
  744. /* Mark task as being processed */
  745. task->flags |= RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS | RSPAMD_TASK_FLAG_HAS_HAM_TOKENS;
  746. /* Process all tokens */
  747. g_assert(rt->tokens != nullptr);
  748. rt->process_tokens(rt->tokens);
  749. opposite_rt_maybe.value()->process_tokens(rt->tokens);
  750. }
  751. else {
  752. /* Error message is on index 3 */
  753. const auto *err_msg = lua_tostring(L, 3);
  754. rt->err = rspamd::util::error(err_msg, 500);
  755. msg_err_task("cannot classify task: %s",
  756. err_msg);
  757. }
  758. return 0;
  759. }
  760. gboolean
  761. rspamd_redis_process_tokens(struct rspamd_task *task,
  762. GPtrArray *tokens,
  763. int id, gpointer p)
  764. {
  765. auto *rt = REDIS_RUNTIME(p);
  766. auto *L = rt->ctx->L;
  767. if (rspamd_session_blocked(task->s)) {
  768. return FALSE;
  769. }
  770. if (tokens == nullptr || tokens->len == 0) {
  771. return FALSE;
  772. }
  773. if (!rt->need_redis_call) {
  774. /* No need to do anything, as it is already done in the opposite class processing */
  775. /* However, we need to store id as it is needed for further tokens processing */
  776. rt->id = id;
  777. rt->tokens = g_ptr_array_ref(tokens);
  778. return TRUE;
  779. }
  780. gsize tokens_len;
  781. char *tokens_buf = rspamd_redis_serialize_tokens(task, rt->redis_object_expanded, tokens, &tokens_len);
  782. rt->id = id;
  783. lua_pushcfunction(L, &rspamd_lua_traceback);
  784. int err_idx = lua_gettop(L);
  785. /* Function arguments */
  786. lua_rawgeti(L, LUA_REGISTRYINDEX, rt->ctx->cbref_classify);
  787. rspamd_lua_task_push(L, task);
  788. lua_pushstring(L, rt->redis_object_expanded);
  789. lua_pushinteger(L, id);
  790. lua_pushboolean(L, rt->stcf->is_spam);
  791. lua_new_text(L, tokens_buf, tokens_len, false);
  792. /* Store rt in random cookie */
  793. char *cookie = (char *) rspamd_mempool_alloc(task->task_pool, 16);
  794. rspamd_random_hex(cookie, 16);
  795. cookie[15] = '\0';
  796. rspamd_mempool_set_variable(task->task_pool, cookie, rt, nullptr);
  797. /* Callback */
  798. lua_pushstring(L, cookie);
  799. lua_pushcclosure(L, &rspamd_redis_classified, 1);
  800. if (lua_pcall(L, 6, 0, err_idx) != 0) {
  801. msg_err_task("call to redis failed: %s", lua_tostring(L, -1));
  802. lua_settop(L, err_idx - 1);
  803. return FALSE;
  804. }
  805. rt->tokens = g_ptr_array_ref(tokens);
  806. lua_settop(L, err_idx - 1);
  807. return TRUE;
  808. }
  809. gboolean
  810. rspamd_redis_finalize_process(struct rspamd_task *task, gpointer runtime,
  811. gpointer ctx)
  812. {
  813. auto *rt = REDIS_RUNTIME(runtime);
  814. return !rt->err.has_value();
  815. }
  816. static int
  817. rspamd_redis_learned(lua_State *L)
  818. {
  819. const auto *cookie = lua_tostring(L, lua_upvalueindex(1));
  820. auto *task = lua_check_task(L, 1);
  821. auto *rt = REDIS_RUNTIME(rspamd_mempool_get_variable(task->task_pool, cookie));
  822. if (rt == nullptr) {
  823. msg_err_task("internal error: cannot find runtime for cookie %s", cookie);
  824. return 0;
  825. }
  826. bool result = lua_toboolean(L, 2);
  827. if (result) {
  828. /* TODO: write it */
  829. }
  830. else {
  831. /* Error message is on index 3 */
  832. const auto *err_msg = lua_tostring(L, 3);
  833. rt->err = rspamd::util::error(err_msg, 500);
  834. msg_err_task("cannot learn task: %s", err_msg);
  835. }
  836. return 0;
  837. }
  838. gboolean
  839. rspamd_redis_learn_tokens(struct rspamd_task *task,
  840. GPtrArray *tokens,
  841. int id, gpointer p)
  842. {
  843. auto *rt = REDIS_RUNTIME(p);
  844. auto *L = rt->ctx->L;
  845. if (rspamd_session_blocked(task->s)) {
  846. return FALSE;
  847. }
  848. if (tokens == nullptr || tokens->len == 0) {
  849. return FALSE;
  850. }
  851. gsize tokens_len;
  852. char *tokens_buf = rspamd_redis_serialize_tokens(task, rt->redis_object_expanded, tokens, &tokens_len);
  853. rt->id = id;
  854. gsize text_tokens_len = 0;
  855. char *text_tokens_buf = nullptr;
  856. if (rt->ctx->store_tokens) {
  857. text_tokens_buf = rspamd_redis_serialize_text_tokens(task, tokens, &text_tokens_len);
  858. }
  859. lua_pushcfunction(L, &rspamd_lua_traceback);
  860. int err_idx = lua_gettop(L);
  861. auto nargs = 8;
  862. /* Function arguments */
  863. lua_rawgeti(L, LUA_REGISTRYINDEX, rt->ctx->cbref_learn);
  864. rspamd_lua_task_push(L, task);
  865. lua_pushstring(L, rt->redis_object_expanded);
  866. lua_pushinteger(L, id);
  867. lua_pushboolean(L, rt->stcf->is_spam);
  868. lua_pushstring(L, rt->stcf->symbol);
  869. /* Detect unlearn */
  870. auto *tok = (rspamd_token_t *) g_ptr_array_index(task->tokens, 0);
  871. if (tok->values[id] > 0) {
  872. lua_pushboolean(L, FALSE);// Learn
  873. }
  874. else {
  875. lua_pushboolean(L, TRUE);// Unlearn
  876. }
  877. lua_new_text(L, tokens_buf, tokens_len, false);
  878. /* Store rt in random cookie */
  879. char *cookie = (char *) rspamd_mempool_alloc(task->task_pool, 16);
  880. rspamd_random_hex(cookie, 16);
  881. cookie[15] = '\0';
  882. rspamd_mempool_set_variable(task->task_pool, cookie, rt, nullptr);
  883. /* Callback */
  884. lua_pushstring(L, cookie);
  885. lua_pushcclosure(L, &rspamd_redis_learned, 1);
  886. if (text_tokens_len) {
  887. nargs = 9;
  888. lua_new_text(L, text_tokens_buf, text_tokens_len, false);
  889. }
  890. if (lua_pcall(L, nargs, 0, err_idx) != 0) {
  891. msg_err_task("call to script failed: %s", lua_tostring(L, -1));
  892. lua_settop(L, err_idx - 1);
  893. return FALSE;
  894. }
  895. rt->tokens = g_ptr_array_ref(tokens);
  896. lua_settop(L, err_idx - 1);
  897. return TRUE;
  898. }
  899. gboolean
  900. rspamd_redis_finalize_learn(struct rspamd_task *task, gpointer runtime,
  901. gpointer ctx, GError **err)
  902. {
  903. auto *rt = REDIS_RUNTIME(runtime);
  904. if (rt->err.has_value()) {
  905. rt->err->into_g_error_set(rspamd_redis_stat_quark(), err);
  906. return FALSE;
  907. }
  908. return TRUE;
  909. }
  910. gulong
  911. rspamd_redis_total_learns(struct rspamd_task *task, gpointer runtime,
  912. gpointer ctx)
  913. {
  914. auto *rt = REDIS_RUNTIME(runtime);
  915. return rt->learned;
  916. }
  917. gulong
  918. rspamd_redis_inc_learns(struct rspamd_task *task, gpointer runtime,
  919. gpointer ctx)
  920. {
  921. auto *rt = REDIS_RUNTIME(runtime);
  922. /* XXX: may cause races */
  923. return rt->learned + 1;
  924. }
  925. gulong
  926. rspamd_redis_dec_learns(struct rspamd_task *task, gpointer runtime,
  927. gpointer ctx)
  928. {
  929. auto *rt = REDIS_RUNTIME(runtime);
  930. /* XXX: may cause races */
  931. return rt->learned + 1;
  932. }
  933. gulong
  934. rspamd_redis_learns(struct rspamd_task *task, gpointer runtime,
  935. gpointer ctx)
  936. {
  937. auto *rt = REDIS_RUNTIME(runtime);
  938. return rt->learned;
  939. }
  940. ucl_object_t *
  941. rspamd_redis_get_stat(gpointer runtime,
  942. gpointer ctx)
  943. {
  944. auto *rt = REDIS_RUNTIME(runtime);
  945. return ucl_object_ref(rt->ctx->cur_stat);
  946. }
  947. gpointer
  948. rspamd_redis_load_tokenizer_config(gpointer runtime,
  949. gsize *len)
  950. {
  951. return nullptr;
  952. }