You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

lang_detection_fasttext.cxx 7.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
  1. /*
  2. * Copyright 2023 Vsevolod Stakhov
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "lang_detection_fasttext.h"
#ifdef WITH_FASTTEXT
#include "fasttext/fasttext.h"
#include "libserver/cfg_file.h"
#include "libserver/logger.h"
#include "fmt/core.h"
#include "stat_api.h"
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <exception>
#include <memory>
#include <string>
#include <string_view>
#include <utility>
#include <vector>
#endif
  27. #ifdef WITH_FASTTEXT
  28. EXTERN_LOG_MODULE_DEF(langdet);
  29. #define msg_debug_lang_det(...) rspamd_conditional_debug_fast(nullptr, nullptr, \
  30. rspamd_langdet_log_id, "langdet", task->task_pool->tag.uid, \
  31. __FUNCTION__, \
  32. __VA_ARGS__)
  33. namespace rspamd::langdet {
  34. class fasttext_langdet {
  35. private:
  36. fasttext::FastText ft;
  37. std::string model_fname;
  38. bool loaded = false;
  39. public:
  40. explicit fasttext_langdet(struct rspamd_config *cfg)
  41. {
  42. const auto *ucl_obj = cfg->cfg_ucl_obj;
  43. const auto *opts_section = ucl_object_find_key(ucl_obj, "lang_detection");
  44. if (opts_section) {
  45. const auto *model = ucl_object_find_key(opts_section, "fasttext_model");
  46. if (model) {
  47. try {
  48. ft.loadModel(ucl_object_tostring(model));
  49. loaded = true;
  50. model_fname = std::string{ucl_object_tostring(model)};
  51. } catch (std::exception &e) {
  52. auto err_message = fmt::format("cannot load fasttext model: {}", e.what());
  53. msg_err_config("%s", err_message.c_str());
  54. loaded = false;
  55. }
  56. }
  57. }
  58. }
  59. /* Disallow multiple initialisation */
  60. fasttext_langdet() = delete;
  61. fasttext_langdet(const fasttext_langdet &) = delete;
  62. fasttext_langdet(fasttext_langdet &&) = delete;
  63. ~fasttext_langdet() = default;
  64. auto is_enabled() const -> bool
  65. {
  66. return loaded;
  67. }
  68. auto word2vec(const char *in, std::size_t len, std::vector<std::int32_t> &word_ngramms) const
  69. {
  70. if (!loaded) {
  71. return;
  72. }
  73. std::string tok{in, len};
  74. const auto &dic = ft.getDictionary();
  75. auto h = dic->hash(tok);
  76. auto wid = dic->getId(tok, h);
  77. auto type = wid < 0 ? dic->getType(tok) : dic->getType(wid);
  78. if (type == fasttext::entry_type::word) {
  79. if (wid < 0) {
  80. auto pipelined_word = fmt::format("{}{}{}", fasttext::Dictionary::BOW, tok, fasttext::Dictionary::EOW);
  81. dic->computeSubwords(pipelined_word, word_ngramms);
  82. }
  83. else {
  84. if (ft.getArgs().maxn <= 0) {
  85. word_ngramms.push_back(wid);
  86. }
  87. else {
  88. const auto ngrams = dic->getSubwords(wid);
  89. word_ngramms.insert(word_ngramms.end(), ngrams.cbegin(), ngrams.cend());
  90. }
  91. }
  92. }
  93. }
  94. auto detect_language(std::vector<std::int32_t> &words, int k)
  95. -> std::vector<std::pair<fasttext::real, std::string>> *
  96. {
  97. if (!loaded) {
  98. return nullptr;
  99. }
  100. auto predictions = new std::vector<std::pair<fasttext::real, std::string>>;
  101. predictions->reserve(k);
  102. fasttext::Predictions line_predictions;
  103. line_predictions.reserve(k);
  104. ft.predict(k, words, line_predictions, 0.0f);
  105. const auto *dict = ft.getDictionary().get();
  106. for (const auto &pred: line_predictions) {
  107. predictions->push_back(std::make_pair(std::exp(pred.first), dict->getLabel(pred.second)));
  108. }
  109. return predictions;
  110. }
  111. auto model_info(void) const -> const std::string
  112. {
  113. if (!loaded) {
  114. static const auto not_loaded = std::string{"fasttext model is not loaded"};
  115. return not_loaded;
  116. }
  117. else {
  118. return fmt::format("fasttext model {}: {} languages, {} tokens", model_fname,
  119. ft.getDictionary()->nlabels(), ft.getDictionary()->ntokens());
  120. }
  121. }
  122. };
  123. }// namespace rspamd::langdet
  124. #endif
  125. /* C API part */
  126. G_BEGIN_DECLS
  127. #define FASTTEXT_MODEL_TO_C_API(p) reinterpret_cast<rspamd::langdet::fasttext_langdet *>(p)
  128. #define FASTTEXT_RESULT_TO_C_API(res) reinterpret_cast<std::vector<std::pair<fasttext::real, std::string>> *>(res)
  129. void *rspamd_lang_detection_fasttext_init(struct rspamd_config *cfg)
  130. {
  131. #ifndef WITH_FASTTEXT
  132. return nullptr;
  133. #else
  134. return (void *) new rspamd::langdet::fasttext_langdet(cfg);
  135. #endif
  136. }
  137. char *rspamd_lang_detection_fasttext_show_info(void *ud)
  138. {
  139. #ifndef WITH_FASTTEXT
  140. return g_strdup("fasttext is not compiled in");
  141. #else
  142. auto model_info = FASTTEXT_MODEL_TO_C_API(ud)->model_info();
  143. return g_strdup(model_info.c_str());
  144. #endif
  145. }
  146. bool rspamd_lang_detection_fasttext_is_enabled(void *ud)
  147. {
  148. #ifdef WITH_FASTTEXT
  149. auto *real_model = FASTTEXT_MODEL_TO_C_API(ud);
  150. if (real_model) {
  151. return real_model->is_enabled();
  152. }
  153. #endif
  154. return false;
  155. }
  156. rspamd_fasttext_predict_result_t rspamd_lang_detection_fasttext_detect(void *ud,
  157. struct rspamd_task *task,
  158. GArray *utf_words,
  159. int k)
  160. {
  161. #ifndef WITH_FASTTEXT
  162. return nullptr;
  163. #else
  164. /* Avoid too long inputs */
  165. static const unsigned int max_fasttext_input_len = 1024 * 1024;
  166. auto *real_model = FASTTEXT_MODEL_TO_C_API(ud);
  167. std::vector<std::int32_t> words_vec;
  168. words_vec.reserve(utf_words->len);
  169. for (auto i = 0; i < std::min(utf_words->len, max_fasttext_input_len); i++) {
  170. const auto *w = &g_array_index(utf_words, rspamd_stat_token_t, i);
  171. if (w->original.len > 0) {
  172. real_model->word2vec(w->original.begin, w->original.len, words_vec);
  173. }
  174. }
  175. msg_debug_lang_det("fasttext: got %z word tokens from %ud words", words_vec.size(), utf_words->len);
  176. auto *res = real_model->detect_language(words_vec, k);
  177. return (rspamd_fasttext_predict_result_t) res;
  178. #endif
  179. }
  180. void rspamd_lang_detection_fasttext_destroy(void *ud)
  181. {
  182. #ifdef WITH_FASTTEXT
  183. delete FASTTEXT_MODEL_TO_C_API(ud);
  184. #endif
  185. }
  186. unsigned int rspamd_lang_detection_fasttext_get_nlangs(rspamd_fasttext_predict_result_t res)
  187. {
  188. #ifdef WITH_FASTTEXT
  189. auto *real_res = FASTTEXT_RESULT_TO_C_API(res);
  190. if (real_res) {
  191. return real_res->size();
  192. }
  193. #endif
  194. return 0;
  195. }
  196. const char *
  197. rspamd_lang_detection_fasttext_get_lang(rspamd_fasttext_predict_result_t res, unsigned int idx)
  198. {
  199. #ifdef WITH_FASTTEXT
  200. auto *real_res = FASTTEXT_RESULT_TO_C_API(res);
  201. if (real_res && real_res->size() > idx) {
  202. /* Fasttext returns result in form __label__<lang>, so we need to remove __label__ prefix */
  203. auto lang = std::string_view{real_res->at(idx).second};
  204. if (lang.size() > sizeof("__label__") && lang.substr(0, sizeof("__label__") - 1) == "__label__") {
  205. lang.remove_prefix(sizeof("__label__") - 1);
  206. }
  207. return lang.data();
  208. }
  209. #endif
  210. return nullptr;
  211. }
  212. float rspamd_lang_detection_fasttext_get_prob(rspamd_fasttext_predict_result_t res, unsigned int idx)
  213. {
  214. #ifdef WITH_FASTTEXT
  215. auto *real_res = FASTTEXT_RESULT_TO_C_API(res);
  216. if (real_res && real_res->size() > idx) {
  217. return real_res->at(idx).first;
  218. }
  219. #endif
  220. return 0.0f;
  221. }
  222. void rspamd_fasttext_predict_result_destroy(rspamd_fasttext_predict_result_t res)
  223. {
  224. #ifdef WITH_FASTTEXT
  225. auto *real_res = FASTTEXT_RESULT_TO_C_API(res);
  226. delete real_res;
  227. #endif
  228. }
  229. G_END_DECLS