123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269 |
- /*
- * Copyright 2023 Vsevolod Stakhov
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
#include "lang_detection_fasttext.h"

#ifdef WITH_FASTTEXT
#include "fasttext/fasttext.h"
#include "libserver/cfg_file.h"
#include "libserver/logger.h"
#include "fmt/core.h"
#include "stat_api.h"
#include <algorithm>
#include <cmath>
#include <exception>
#include <memory>
#include <string>
#include <string_view>
#include <utility>
#include <vector>
#endif
-
- #ifdef WITH_FASTTEXT
-
- EXTERN_LOG_MODULE_DEF(langdet);
- #define msg_debug_lang_det(...) rspamd_conditional_debug_fast(nullptr, nullptr, \
- rspamd_langdet_log_id, "langdet", task->task_pool->tag.uid, \
- __FUNCTION__, \
- __VA_ARGS__)
-
- namespace rspamd::langdet {
- class fasttext_langdet {
- private:
- fasttext::FastText ft;
- std::string model_fname;
- bool loaded = false;
-
- public:
- explicit fasttext_langdet(struct rspamd_config *cfg)
- {
- const auto *ucl_obj = cfg->cfg_ucl_obj;
- const auto *opts_section = ucl_object_find_key(ucl_obj, "lang_detection");
-
- if (opts_section) {
- const auto *model = ucl_object_find_key(opts_section, "fasttext_model");
-
- if (model) {
- try {
- ft.loadModel(ucl_object_tostring(model));
- loaded = true;
- model_fname = std::string{ucl_object_tostring(model)};
- } catch (std::exception &e) {
- auto err_message = fmt::format("cannot load fasttext model: {}", e.what());
- msg_err_config("%s", err_message.c_str());
- loaded = false;
- }
- }
- }
- }
-
- /* Disallow multiple initialisation */
- fasttext_langdet() = delete;
- fasttext_langdet(const fasttext_langdet &) = delete;
- fasttext_langdet(fasttext_langdet &&) = delete;
-
- ~fasttext_langdet() = default;
-
- auto is_enabled() const -> bool
- {
- return loaded;
- }
- auto word2vec(const char *in, std::size_t len, std::vector<std::int32_t> &word_ngramms) const
- {
- if (!loaded) {
- return;
- }
-
- std::string tok{in, len};
- const auto &dic = ft.getDictionary();
- auto h = dic->hash(tok);
- auto wid = dic->getId(tok, h);
- auto type = wid < 0 ? dic->getType(tok) : dic->getType(wid);
-
- if (type == fasttext::entry_type::word) {
- if (wid < 0) {
- auto pipelined_word = fmt::format("{}{}{}", fasttext::Dictionary::BOW, tok, fasttext::Dictionary::EOW);
- dic->computeSubwords(pipelined_word, word_ngramms);
- }
- else {
- if (ft.getArgs().maxn <= 0) {
- word_ngramms.push_back(wid);
- }
- else {
- const auto ngrams = dic->getSubwords(wid);
- word_ngramms.insert(word_ngramms.end(), ngrams.cbegin(), ngrams.cend());
- }
- }
- }
- }
- auto detect_language(std::vector<std::int32_t> &words, int k)
- -> std::vector<std::pair<fasttext::real, std::string>> *
- {
- if (!loaded) {
- return nullptr;
- }
-
- auto predictions = new std::vector<std::pair<fasttext::real, std::string>>;
- predictions->reserve(k);
- fasttext::Predictions line_predictions;
- line_predictions.reserve(k);
- ft.predict(k, words, line_predictions, 0.0f);
- const auto *dict = ft.getDictionary().get();
-
- for (const auto &pred: line_predictions) {
- predictions->push_back(std::make_pair(std::exp(pred.first), dict->getLabel(pred.second)));
- }
- return predictions;
- }
-
- auto model_info(void) const -> const std::string
- {
- if (!loaded) {
- static const auto not_loaded = std::string{"fasttext model is not loaded"};
- return not_loaded;
- }
- else {
- return fmt::format("fasttext model {}: {} languages, {} tokens", model_fname,
- ft.getDictionary()->nlabels(), ft.getDictionary()->ntokens());
- }
- }
- };
- }// namespace rspamd::langdet
- #endif
-
- /* C API part */
- G_BEGIN_DECLS
-
/* Converts an opaque model handle from the C API back to the C++ detector object */
#define FASTTEXT_MODEL_TO_C_API(p) \
	reinterpret_cast<rspamd::langdet::fasttext_langdet *>(p)
/* Converts an opaque prediction result from the C API back to the vector of (score, label) pairs */
#define FASTTEXT_RESULT_TO_C_API(res) \
	reinterpret_cast<std::vector<std::pair<fasttext::real, std::string>> *>(res)
-
/*
 * Creates a fastText language detector configured from `cfg`.
 * Returns an opaque handle for the other functions of this API, or nullptr
 * when rspamd is built without fastText support.
 */
void *rspamd_lang_detection_fasttext_init(struct rspamd_config *cfg)
{
#ifdef WITH_FASTTEXT
	auto *detector = new rspamd::langdet::fasttext_langdet(cfg);

	return static_cast<void *>(detector);
#else
	return nullptr;
#endif
}
-
- char *rspamd_lang_detection_fasttext_show_info(void *ud)
- {
- #ifndef WITH_FASTTEXT
- return g_strdup("fasttext is not compiled in");
- #else
- auto model_info = FASTTEXT_MODEL_TO_C_API(ud)->model_info();
-
- return g_strdup(model_info.c_str());
- #endif
- }
-
/*
 * Tells whether the detector behind `ud` has a loaded model.
 * A null handle and builds without fastText support both yield false.
 */
bool rspamd_lang_detection_fasttext_is_enabled(void *ud)
{
#ifdef WITH_FASTTEXT
	const auto *detector = FASTTEXT_MODEL_TO_C_API(ud);

	return detector != nullptr && detector->is_enabled();
#else
	return false;
#endif
}
-
- rspamd_fasttext_predict_result_t rspamd_lang_detection_fasttext_detect(void *ud,
- struct rspamd_task *task,
- GArray *utf_words,
- int k)
- {
- #ifndef WITH_FASTTEXT
- return nullptr;
- #else
- /* Avoid too long inputs */
- static const unsigned int max_fasttext_input_len = 1024 * 1024;
- auto *real_model = FASTTEXT_MODEL_TO_C_API(ud);
- std::vector<std::int32_t> words_vec;
- words_vec.reserve(utf_words->len);
-
- for (auto i = 0; i < std::min(utf_words->len, max_fasttext_input_len); i++) {
- const auto *w = &g_array_index(utf_words, rspamd_stat_token_t, i);
- if (w->original.len > 0) {
- real_model->word2vec(w->original.begin, w->original.len, words_vec);
- }
- }
-
- msg_debug_lang_det("fasttext: got %z word tokens from %ud words", words_vec.size(), utf_words->len);
-
- auto *res = real_model->detect_language(words_vec, k);
-
- return (rspamd_fasttext_predict_result_t) res;
- #endif
- }
-
/*
 * Destroys the detector behind `ud`.
 * Does nothing in builds without fastText support.
 */
void rspamd_lang_detection_fasttext_destroy(void *ud)
{
#ifdef WITH_FASTTEXT
	auto *detector = FASTTEXT_MODEL_TO_C_API(ud);

	delete detector;
#endif
}
-
-
- unsigned int rspamd_lang_detection_fasttext_get_nlangs(rspamd_fasttext_predict_result_t res)
- {
- #ifdef WITH_FASTTEXT
- auto *real_res = FASTTEXT_RESULT_TO_C_API(res);
-
- if (real_res) {
- return real_res->size();
- }
- #endif
- return 0;
- }
-
- const char *
- rspamd_lang_detection_fasttext_get_lang(rspamd_fasttext_predict_result_t res, unsigned int idx)
- {
- #ifdef WITH_FASTTEXT
- auto *real_res = FASTTEXT_RESULT_TO_C_API(res);
-
- if (real_res && real_res->size() > idx) {
- /* Fasttext returns result in form __label__<lang>, so we need to remove __label__ prefix */
- auto lang = std::string_view{real_res->at(idx).second};
- if (lang.size() > sizeof("__label__") && lang.substr(0, sizeof("__label__") - 1) == "__label__") {
- lang.remove_prefix(sizeof("__label__") - 1);
- }
- return lang.data();
- }
- #endif
- return nullptr;
- }
-
- float rspamd_lang_detection_fasttext_get_prob(rspamd_fasttext_predict_result_t res, unsigned int idx)
- {
- #ifdef WITH_FASTTEXT
- auto *real_res = FASTTEXT_RESULT_TO_C_API(res);
-
- if (real_res && real_res->size() > idx) {
- return real_res->at(idx).first;
- }
- #endif
- return 0.0f;
- }
-
- void rspamd_fasttext_predict_result_destroy(rspamd_fasttext_predict_result_t res)
- {
- #ifdef WITH_FASTTEXT
- auto *real_res = FASTTEXT_RESULT_TO_C_API(res);
-
- delete real_res;
- #endif
- }
-
- G_END_DECLS
|