]> source.dussan.org Git - rspamd.git/commitdiff
[Fix] Feed fasttext language model with the pre-tokenized words
authorVsevolod Stakhov <vsevolod@rspamd.com>
Tue, 2 May 2023 16:56:14 +0000 (17:56 +0100)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Tue, 2 May 2023 16:56:14 +0000 (17:56 +0100)
src/libmime/lang_detection.c
src/libmime/lang_detection_fasttext.cxx
src/libmime/lang_detection_fasttext.h

index 62d04975c6396420ab430a127d8254f8a5cd42d8..d4d10b216d14331d5b53bd8461b4d7b880c3b02b 100644 (file)
@@ -1836,8 +1836,7 @@ rspamd_language_detector_detect (struct rspamd_task *task,
                if (rspamd_lang_detection_fasttext_is_enabled(d->fasttext_detector)) {
                        rspamd_fasttext_predict_result_t fasttext_predict_result =
                                rspamd_lang_detection_fasttext_detect(d->fasttext_detector,
-                                       part->utf_stripped_content->data,
-                                       part->utf_stripped_content->len, 4);
+                                       part->utf_words, 4);
 
                        ndetected = rspamd_lang_detection_fasttext_get_nlangs(fasttext_predict_result);
 
index 7e16414bc941117f3ad947aa4e9b528bcc90c179..b75668670baf34c875ba34c7393975128f066a98 100644 (file)
 #include "libserver/cfg_file.h"
 #include "libserver/logger.h"
 #include "fmt/core.h"
+#include "stat_api.h"
 #include <exception>
 #include <string>
 #include <string_view>
 #include <vector>
-#include <sstream>
-#include <streambuf>
 #endif
 
 #ifdef WITH_FASTTEXT
@@ -37,12 +36,6 @@ private:
        std::string model_fname;
        bool loaded;
 
-       struct one_shot_buf : public std::streambuf {
-               explicit one_shot_buf(const char *in, std::size_t sz) {
-                       auto deconst_in = const_cast<char *>(in);
-                       setg(deconst_in, deconst_in, deconst_in + sz);
-               }
-       };
 public:
        explicit fasttext_langdet(struct rspamd_config *cfg) {
                const auto *ucl_obj = cfg->rcl_obj;
@@ -74,27 +67,51 @@ public:
        ~fasttext_langdet() = default;
 
        auto is_enabled() const -> bool { return loaded; }
-       auto detect_language(const char *in, size_t len, int k) const -> std::vector<std::pair<fasttext::real, std::string>> *
+       auto word2vec(const char *in, std::size_t len, std::vector<std::int32_t> &word_ngramms) const {
+               if (!loaded) {
+                       return;
+               }
+
+               std::string tok{in, len};
+               const auto &dic = ft.getDictionary();
+               auto h = dic->hash(tok);
+               auto wid = dic->getId(tok, h);
+               auto type = wid < 0 ? dic->getType(tok) : dic->getType(wid);
+
+               if (type == fasttext::entry_type::word) {
+                       if (wid < 0) {
+                               auto pipelined_word = fmt::format("{}{}{}", fasttext::Dictionary::BOW, tok, fasttext::Dictionary::EOW);
+                               dic->computeSubwords(pipelined_word, word_ngramms);
+                       }
+                       else {
+                               if (ft.getArgs().maxn <= 0) {
+                                       word_ngramms.push_back(wid);
+                               }
+                               else {
+                                       const auto ngrams = dic->getSubwords(wid);
+                                       word_ngramms.insert(word_ngramms.end(), ngrams.cbegin(), ngrams.cend());
+                               }
+                       }
+               }
+       }
+       auto detect_language(std::vector<std::int32_t> &words, int k)
+               -> std::vector<std::pair<fasttext::real, std::string>> *
        {
                if (!loaded) {
                        return nullptr;
                }
 
-               /* Hack to deal with streams without copies */
-               one_shot_buf buf{in, len};
-               auto stream = std::istream{&buf};
                auto predictions = new std::vector<std::pair<fasttext::real, std::string>>;
                predictions->reserve(k);
-               auto res = ft.predictLine(stream, *predictions, k, 0.0f);
+               fasttext::Predictions line_predictions;
+               line_predictions.reserve(k);
+               ft.predict(k, words, line_predictions, 0.0f);
+               const auto *dict = ft.getDictionary().get();
 
-               if (res) {
-                       return predictions;
+               for (const auto &pred : line_predictions) {
+                       predictions->push_back(std::make_pair(std::exp(pred.first), dict->getLabel(pred.second)));
                }
-               else {
-                       delete predictions;
-               }
-
-               return nullptr;
+               return predictions;
        }
 
        auto model_info(void) const -> std::string {
@@ -150,15 +167,26 @@ bool rspamd_lang_detection_fasttext_is_enabled(void *ud)
 }
 
 rspamd_fasttext_predict_result_t rspamd_lang_detection_fasttext_detect(void *ud,
-                                                                                          const char *in, size_t len, int k)
+                                                                                                                                          GArray *utf_words,
+                                                                                                                                          int k)
 {
 #ifndef WITH_FASTTEXT
        return nullptr;
 #else
        /* Avoid too long inputs */
-       static const size_t max_fasttext_input_len = 1024 * 1024 * 1;
+       static const guint max_fasttext_input_len = 1024 * 1024;
        auto *real_model = FASTTEXT_MODEL_TO_C_API(ud);
-       auto *res = real_model->detect_language(in, std::min(max_fasttext_input_len, len), k);
+       std::vector<std::int32_t> words_vec;
+       words_vec.reserve(utf_words->len);
+
+       for (auto i = 0; i < std::min(utf_words->len, max_fasttext_input_len); i++) {
+               const auto *w = &g_array_index (utf_words, rspamd_stat_token_t, i);
+               if (w->original.len > 0) {
+                       real_model->word2vec(w->original.begin, w->original.len, words_vec);
+               }
+       }
+
+       auto *res = real_model->detect_language(words_vec, k);
 
        return (rspamd_fasttext_predict_result_t)res;
 #endif
index 2e8a9fe78c3eb68c8c8aae78560026d344e60f1c..9fb1db222d5cb79fd56ae64a6dd981be07df6ae3 100644 (file)
@@ -52,7 +52,7 @@ typedef  void * rspamd_fasttext_predict_result_t;
  * @return TRUE if language is detected
  */
 rspamd_fasttext_predict_result_t rspamd_lang_detection_fasttext_detect(void *ud,
-               const char *in, size_t len, int k);
+               GArray *utf_words, int k);
 
 /**
  * Get number of languages detected