Nevar pievienot vairāk kā 25 tēmas. Tēmai ir jāsākas ar burtu vai ciparu, tā var saturēt domu zīmes ('-') un var būt līdz 35 simboliem gara.

redis_cache.cxx 6.5KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. /*
  2. * Copyright 2024 Vsevolod Stakhov
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "config.h"
  17. // Include early to avoid `extern "C"` issues
  18. #include "lua/lua_common.h"
  19. #include "learn_cache.h"
  20. #include "rspamd.h"
  21. #include "stat_api.h"
  22. #include "stat_internal.h"
  23. #include "cryptobox.h"
  24. #include "ucl.h"
  25. #include "libmime/message.h"
  26. #include <memory>
  27. struct rspamd_redis_cache_ctx {
  28. lua_State *L;
  29. struct rspamd_statfile_config *stcf;
  30. int check_ref = -1;
  31. int learn_ref = -1;
  32. rspamd_redis_cache_ctx() = delete;
  33. explicit rspamd_redis_cache_ctx(lua_State *L)
  34. : L(L)
  35. {
  36. }
  37. ~rspamd_redis_cache_ctx()
  38. {
  39. if (check_ref != -1) {
  40. luaL_unref(L, LUA_REGISTRYINDEX, check_ref);
  41. }
  42. if (learn_ref != -1) {
  43. luaL_unref(L, LUA_REGISTRYINDEX, learn_ref);
  44. }
  45. }
  46. };
  47. static void
  48. rspamd_stat_cache_redis_generate_id(struct rspamd_task *task)
  49. {
  50. rspamd_cryptobox_hash_state_t st;
  51. rspamd_cryptobox_hash_init(&st, nullptr, 0);
  52. const auto *user = (const char *) rspamd_mempool_get_variable(task->task_pool, "stat_user");
  53. /* Use dedicated hash space for per users cache */
  54. if (user != nullptr) {
  55. rspamd_cryptobox_hash_update(&st, (const unsigned char *) user, strlen(user));
  56. }
  57. for (auto i = 0; i < task->tokens->len; i++) {
  58. const auto *tok = (rspamd_token_t *) g_ptr_array_index(task->tokens, i);
  59. rspamd_cryptobox_hash_update(&st, (const unsigned char *) &tok->data,
  60. sizeof(tok->data));
  61. }
  62. unsigned char out[rspamd_cryptobox_HASHBYTES];
  63. rspamd_cryptobox_hash_final(&st, out);
  64. auto *b32out = rspamd_mempool_alloc_array_type(task->task_pool,
  65. sizeof(out) * 8 / 5 + 3, char);
  66. auto out_sz = rspamd_encode_base32_buf(out, sizeof(out), b32out,
  67. sizeof(out) * 8 / 5 + 2, RSPAMD_BASE32_DEFAULT);
  68. if (out_sz > 0) {
  69. /* Zero terminate */
  70. b32out[out_sz] = '\0';
  71. rspamd_mempool_set_variable(task->task_pool, "words_hash", b32out, nullptr);
  72. }
  73. }
  74. gpointer
  75. rspamd_stat_cache_redis_init(struct rspamd_stat_ctx *ctx,
  76. struct rspamd_config *cfg,
  77. struct rspamd_statfile *st,
  78. const ucl_object_t *cf)
  79. {
  80. std::unique_ptr<rspamd_redis_cache_ctx> cache_ctx = std::make_unique<rspamd_redis_cache_ctx>(RSPAMD_LUA_CFG_STATE(cfg));
  81. auto *L = RSPAMD_LUA_CFG_STATE(cfg);
  82. lua_settop(L, 0);
  83. lua_pushcfunction(L, &rspamd_lua_traceback);
  84. auto err_idx = lua_gettop(L);
  85. /* Obtain function */
  86. if (!rspamd_lua_require_function(L, "lua_bayes_redis", "lua_bayes_init_cache")) {
  87. msg_err_config("cannot require lua_bayes_redis.lua_bayes_init_cache");
  88. lua_settop(L, err_idx - 1);
  89. return nullptr;
  90. }
  91. /* Push arguments */
  92. ucl_object_push_lua(L, st->classifier->cfg->opts, false);
  93. ucl_object_push_lua(L, st->stcf->opts, false);
  94. if (lua_pcall(L, 2, 2, err_idx) != 0) {
  95. msg_err("call to lua_bayes_init_cache "
  96. "script failed: %s",
  97. lua_tostring(L, -1));
  98. lua_settop(L, err_idx - 1);
  99. return nullptr;
  100. }
  101. /*
  102. * Results are in the stack:
  103. * top - 1 - check function (idx = -2)
  104. * top - learn function (idx = -1)
  105. */
  106. lua_pushvalue(L, -2);
  107. cache_ctx->check_ref = luaL_ref(L, LUA_REGISTRYINDEX);
  108. lua_pushvalue(L, -1);
  109. cache_ctx->learn_ref = luaL_ref(L, LUA_REGISTRYINDEX);
  110. lua_settop(L, err_idx - 1);
  111. return (gpointer) cache_ctx.release();
  112. }
  113. gpointer
  114. rspamd_stat_cache_redis_runtime(struct rspamd_task *task,
  115. gpointer c, gboolean learn)
  116. {
  117. auto *ctx = (struct rspamd_redis_cache_ctx *) c;
  118. if (task->tokens == nullptr || task->tokens->len == 0) {
  119. return nullptr;
  120. }
  121. if (!learn) {
  122. /* On check, we produce words_hash variable, on learn it is guaranteed to be set */
  123. rspamd_stat_cache_redis_generate_id(task);
  124. }
  125. return (void *) ctx;
  126. }
  127. static int
  128. rspamd_stat_cache_checked(lua_State *L)
  129. {
  130. auto *task = lua_check_task(L, 1);
  131. auto res = lua_toboolean(L, 2);
  132. if (res) {
  133. auto val = lua_tointeger(L, 3);
  134. if ((val > 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM)) ||
  135. (val <= 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM))) {
  136. /* Already learned */
  137. msg_info_task("<%s> has been already "
  138. "learned as %s, ignore it",
  139. MESSAGE_FIELD(task, message_id),
  140. (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) ? "spam" : "ham");
  141. task->flags |= RSPAMD_TASK_FLAG_ALREADY_LEARNED;
  142. }
  143. else {
  144. /* Unlearn flag */
  145. task->flags |= RSPAMD_TASK_FLAG_UNLEARN;
  146. }
  147. }
  148. /* Ignore errors for now, as we can do nothing about them at the moment */
  149. return 0;
  150. }
  151. int rspamd_stat_cache_redis_check(struct rspamd_task *task,
  152. gboolean is_spam,
  153. gpointer runtime)
  154. {
  155. auto *ctx = (struct rspamd_redis_cache_ctx *) runtime;
  156. auto *h = (char *) rspamd_mempool_get_variable(task->task_pool, "words_hash");
  157. if (h == nullptr) {
  158. return RSPAMD_LEARN_IGNORE;
  159. }
  160. auto *L = ctx->L;
  161. lua_pushcfunction(L, &rspamd_lua_traceback);
  162. int err_idx = lua_gettop(L);
  163. /* Function arguments */
  164. lua_rawgeti(L, LUA_REGISTRYINDEX, ctx->check_ref);
  165. rspamd_lua_task_push(L, task);
  166. lua_pushstring(L, h);
  167. lua_pushcclosure(L, &rspamd_stat_cache_checked, 0);
  168. if (lua_pcall(L, 3, 0, err_idx) != 0) {
  169. msg_err_task("call to redis failed: %s", lua_tostring(L, -1));
  170. lua_settop(L, err_idx - 1);
  171. return RSPAMD_LEARN_IGNORE;
  172. }
  173. /* We need to return OK every time */
  174. return RSPAMD_LEARN_OK;
  175. }
  176. int rspamd_stat_cache_redis_learn(struct rspamd_task *task,
  177. gboolean is_spam,
  178. gpointer runtime)
  179. {
  180. auto *ctx = (struct rspamd_redis_cache_ctx *) runtime;
  181. if (rspamd_session_blocked(task->s)) {
  182. return RSPAMD_LEARN_IGNORE;
  183. }
  184. auto *h = (char *) rspamd_mempool_get_variable(task->task_pool, "words_hash");
  185. g_assert(h != nullptr);
  186. auto *L = ctx->L;
  187. lua_pushcfunction(L, &rspamd_lua_traceback);
  188. int err_idx = lua_gettop(L);
  189. /* Function arguments */
  190. lua_rawgeti(L, LUA_REGISTRYINDEX, ctx->learn_ref);
  191. rspamd_lua_task_push(L, task);
  192. lua_pushstring(L, h);
  193. lua_pushboolean(L, is_spam);
  194. if (lua_pcall(L, 3, 0, err_idx) != 0) {
  195. msg_err_task("call to redis failed: %s", lua_tostring(L, -1));
  196. lua_settop(L, err_idx - 1);
  197. return RSPAMD_LEARN_IGNORE;
  198. }
  199. /* We need to return OK every time */
  200. return RSPAMD_LEARN_OK;
  201. }
  202. void rspamd_stat_cache_redis_close(gpointer c)
  203. {
  204. auto *ctx = (struct rspamd_redis_cache_ctx *) c;
  205. delete ctx;
  206. }