You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

lua_classifier.c 6.4KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. /*
  2. * Copyright 2024 Vsevolod Stakhov
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "classifiers.h"
  17. #include "cfg_file.h"
  18. #include "stat_internal.h"
  19. #include "lua/lua_common.h"
  20. struct rspamd_lua_classifier_ctx {
  21. gchar *name;
  22. gint classify_ref;
  23. gint learn_ref;
  24. };
  25. static GHashTable *lua_classifiers = NULL;
  26. #define msg_err_luacl(...) rspamd_default_log_function(G_LOG_LEVEL_CRITICAL, \
  27. "luacl", task->task_pool->tag.uid, \
  28. RSPAMD_LOG_FUNC, \
  29. __VA_ARGS__)
  30. #define msg_warn_luacl(...) rspamd_default_log_function(G_LOG_LEVEL_WARNING, \
  31. "luacl", task->task_pool->tag.uid, \
  32. RSPAMD_LOG_FUNC, \
  33. __VA_ARGS__)
  34. #define msg_info_luacl(...) rspamd_default_log_function(G_LOG_LEVEL_INFO, \
  35. "luacl", task->task_pool->tag.uid, \
  36. RSPAMD_LOG_FUNC, \
  37. __VA_ARGS__)
  38. #define msg_debug_luacl(...) rspamd_conditional_debug_fast(NULL, task->from_addr, \
  39. rspamd_luacl_log_id, "luacl", task->task_pool->tag.uid, \
  40. RSPAMD_LOG_FUNC, \
  41. __VA_ARGS__)
  42. INIT_LOG_MODULE(luacl)
  43. gboolean
  44. lua_classifier_init(struct rspamd_config *cfg,
  45. struct ev_loop *ev_base,
  46. struct rspamd_classifier *cl)
  47. {
  48. struct rspamd_lua_classifier_ctx *ctx;
  49. lua_State *L = cl->ctx->cfg->lua_state;
  50. gint cb_classify = -1, cb_learn = -1;
  51. if (lua_classifiers == NULL) {
  52. lua_classifiers = g_hash_table_new_full(rspamd_strcase_hash,
  53. rspamd_strcase_equal, g_free, g_free);
  54. }
  55. ctx = g_hash_table_lookup(lua_classifiers, cl->subrs->name);
  56. if (ctx != NULL) {
  57. msg_err_config("duplicate lua classifier definition: %s",
  58. cl->subrs->name);
  59. return FALSE;
  60. }
  61. lua_getglobal(L, "rspamd_classifiers");
  62. if (lua_type(L, -1) != LUA_TTABLE) {
  63. msg_err_config("cannot register classifier %s: no rspamd_classifier global",
  64. cl->subrs->name);
  65. lua_pop(L, 1);
  66. return FALSE;
  67. }
  68. lua_pushstring(L, cl->subrs->name);
  69. lua_gettable(L, -2);
  70. if (lua_type(L, -1) != LUA_TTABLE) {
  71. msg_err_config("cannot register classifier %s: bad lua type: %s",
  72. cl->subrs->name, lua_typename(L, lua_type(L, -1)));
  73. lua_pop(L, 2);
  74. return FALSE;
  75. }
  76. lua_pushstring(L, "classify");
  77. lua_gettable(L, -2);
  78. if (lua_type(L, -1) != LUA_TFUNCTION) {
  79. msg_err_config("cannot register classifier %s: bad lua type for classify: %s",
  80. cl->subrs->name, lua_typename(L, lua_type(L, -1)));
  81. lua_pop(L, 3);
  82. return FALSE;
  83. }
  84. cb_classify = luaL_ref(L, LUA_REGISTRYINDEX);
  85. lua_pushstring(L, "learn");
  86. lua_gettable(L, -2);
  87. if (lua_type(L, -1) != LUA_TFUNCTION) {
  88. msg_err_config("cannot register classifier %s: bad lua type for learn: %s",
  89. cl->subrs->name, lua_typename(L, lua_type(L, -1)));
  90. lua_pop(L, 3);
  91. return FALSE;
  92. }
  93. cb_learn = luaL_ref(L, LUA_REGISTRYINDEX);
  94. lua_pop(L, 2); /* Table + global */
  95. ctx = g_malloc0(sizeof(*ctx));
  96. ctx->name = g_strdup(cl->subrs->name);
  97. ctx->classify_ref = cb_classify;
  98. ctx->learn_ref = cb_learn;
  99. cl->cfg->flags |= RSPAMD_FLAG_CLASSIFIER_NO_BACKEND;
  100. g_hash_table_insert(lua_classifiers, ctx->name, ctx);
  101. return TRUE;
  102. }
  103. gboolean
  104. lua_classifier_classify(struct rspamd_classifier *cl,
  105. GPtrArray *tokens,
  106. struct rspamd_task *task)
  107. {
  108. struct rspamd_lua_classifier_ctx *ctx;
  109. struct rspamd_task **ptask;
  110. struct rspamd_classifier_config **pcfg;
  111. lua_State *L;
  112. rspamd_token_t *tok;
  113. guint i;
  114. guint64 v;
  115. ctx = g_hash_table_lookup(lua_classifiers, cl->subrs->name);
  116. g_assert(ctx != NULL);
  117. L = task->cfg->lua_state;
  118. lua_rawgeti(L, LUA_REGISTRYINDEX, ctx->classify_ref);
  119. ptask = lua_newuserdata(L, sizeof(*ptask));
  120. *ptask = task;
  121. rspamd_lua_setclass(L, rspamd_task_classname, -1);
  122. pcfg = lua_newuserdata(L, sizeof(*pcfg));
  123. *pcfg = cl->cfg;
  124. rspamd_lua_setclass(L, "rspamd{classifier}", -1);
  125. lua_createtable(L, tokens->len, 0);
  126. for (i = 0; i < tokens->len; i++) {
  127. tok = g_ptr_array_index(tokens, i);
  128. v = tok->data;
  129. lua_createtable(L, 3, 0);
  130. /* High word, low word, order */
  131. lua_pushinteger(L, (guint32) (v >> 32));
  132. lua_rawseti(L, -2, 1);
  133. lua_pushinteger(L, (guint32) (v));
  134. lua_rawseti(L, -2, 2);
  135. lua_pushinteger(L, tok->window_idx);
  136. lua_rawseti(L, -2, 3);
  137. lua_rawseti(L, -2, i + 1);
  138. }
  139. if (lua_pcall(L, 3, 0, 0) != 0) {
  140. msg_err_luacl("error running classify function for %s: %s", ctx->name,
  141. lua_tostring(L, -1));
  142. lua_pop(L, 1);
  143. return FALSE;
  144. }
  145. return TRUE;
  146. }
  147. gboolean
  148. lua_classifier_learn_spam(struct rspamd_classifier *cl,
  149. GPtrArray *tokens,
  150. struct rspamd_task *task,
  151. gboolean is_spam,
  152. gboolean unlearn,
  153. GError **err)
  154. {
  155. struct rspamd_lua_classifier_ctx *ctx;
  156. struct rspamd_task **ptask;
  157. struct rspamd_classifier_config **pcfg;
  158. lua_State *L;
  159. rspamd_token_t *tok;
  160. guint i;
  161. guint64 v;
  162. ctx = g_hash_table_lookup(lua_classifiers, cl->subrs->name);
  163. g_assert(ctx != NULL);
  164. L = task->cfg->lua_state;
  165. lua_rawgeti(L, LUA_REGISTRYINDEX, ctx->learn_ref);
  166. ptask = lua_newuserdata(L, sizeof(*ptask));
  167. *ptask = task;
  168. rspamd_lua_setclass(L, rspamd_task_classname, -1);
  169. pcfg = lua_newuserdata(L, sizeof(*pcfg));
  170. *pcfg = cl->cfg;
  171. rspamd_lua_setclass(L, "rspamd{classifier}", -1);
  172. lua_createtable(L, tokens->len, 0);
  173. for (i = 0; i < tokens->len; i++) {
  174. tok = g_ptr_array_index(tokens, i);
  175. v = 0;
  176. v = tok->data;
  177. lua_createtable(L, 3, 0);
  178. /* High word, low word, order */
  179. lua_pushinteger(L, (guint32) (v >> 32));
  180. lua_rawseti(L, -2, 1);
  181. lua_pushinteger(L, (guint32) (v));
  182. lua_rawseti(L, -2, 2);
  183. lua_pushinteger(L, tok->window_idx);
  184. lua_rawseti(L, -2, 3);
  185. lua_rawseti(L, -2, i + 1);
  186. }
  187. lua_pushboolean(L, is_spam);
  188. lua_pushboolean(L, unlearn);
  189. if (lua_pcall(L, 5, 0, 0) != 0) {
  190. msg_err_luacl("error running learn function for %s: %s", ctx->name,
  191. lua_tostring(L, -1));
  192. lua_pop(L, 1);
  193. return FALSE;
  194. }
  195. return TRUE;
  196. }