diff options
-rw-r--r-- | src/libmime/lang_detection.c | 3 | ||||
-rw-r--r-- | src/libmime/lang_detection_fasttext.cxx | 74 | ||||
-rw-r--r-- | src/libmime/lang_detection_fasttext.h | 2 |
3 files changed, 53 insertions, 26 deletions
diff --git a/src/libmime/lang_detection.c b/src/libmime/lang_detection.c index 62d04975c..d4d10b216 100644 --- a/src/libmime/lang_detection.c +++ b/src/libmime/lang_detection.c @@ -1836,8 +1836,7 @@ rspamd_language_detector_detect (struct rspamd_task *task, if (rspamd_lang_detection_fasttext_is_enabled(d->fasttext_detector)) { rspamd_fasttext_predict_result_t fasttext_predict_result = rspamd_lang_detection_fasttext_detect(d->fasttext_detector, - part->utf_stripped_content->data, - part->utf_stripped_content->len, 4); + part->utf_words, 4); ndetected = rspamd_lang_detection_fasttext_get_nlangs(fasttext_predict_result); diff --git a/src/libmime/lang_detection_fasttext.cxx b/src/libmime/lang_detection_fasttext.cxx index 7e16414bc..b75668670 100644 --- a/src/libmime/lang_detection_fasttext.cxx +++ b/src/libmime/lang_detection_fasttext.cxx @@ -21,12 +21,11 @@ #include "libserver/cfg_file.h" #include "libserver/logger.h" #include "fmt/core.h" +#include "stat_api.h" #include <exception> #include <string> #include <string_view> #include <vector> -#include <sstream> -#include <streambuf> #endif #ifdef WITH_FASTTEXT @@ -37,12 +36,6 @@ private: std::string model_fname; bool loaded; - struct one_shot_buf : public std::streambuf { - explicit one_shot_buf(const char *in, std::size_t sz) { - auto deconst_in = const_cast<char *>(in); - setg(deconst_in, deconst_in, deconst_in + sz); - } - }; public: explicit fasttext_langdet(struct rspamd_config *cfg) { const auto *ucl_obj = cfg->rcl_obj; @@ -74,27 +67,51 @@ public: ~fasttext_langdet() = default; auto is_enabled() const -> bool { return loaded; } - auto detect_language(const char *in, size_t len, int k) const -> std::vector<std::pair<fasttext::real, std::string>> * + auto word2vec(const char *in, std::size_t len, std::vector<std::int32_t> &word_ngramms) const { + if (!loaded) { + return; + } + + std::string tok{in, len}; + const auto &dic = ft.getDictionary(); + auto h = dic->hash(tok); + auto wid = dic->getId(tok, h); + auto type = wid < 0 ? dic->getType(tok) : dic->getType(wid); + + if (type == fasttext::entry_type::word) { + if (wid < 0) { + auto pipelined_word = fmt::format("{}{}{}", fasttext::Dictionary::BOW, tok, fasttext::Dictionary::EOW); + dic->computeSubwords(pipelined_word, word_ngramms); + } + else { + if (ft.getArgs().maxn <= 0) { + word_ngramms.push_back(wid); + } + else { + const auto ngrams = dic->getSubwords(wid); + word_ngramms.insert(word_ngramms.end(), ngrams.cbegin(), ngrams.cend()); + } + } + } + } + auto detect_language(std::vector<std::int32_t> &words, int k) + -> std::vector<std::pair<fasttext::real, std::string>> * { if (!loaded) { return nullptr; } - /* Hack to deal with streams without copies */ - one_shot_buf buf{in, len}; - auto stream = std::istream{&buf}; auto predictions = new std::vector<std::pair<fasttext::real, std::string>>; predictions->reserve(k); - auto res = ft.predictLine(stream, *predictions, k, 0.0f); + fasttext::Predictions line_predictions; + line_predictions.reserve(k); + ft.predict(k, words, line_predictions, 0.0f); + const auto *dict = ft.getDictionary().get(); - if (res) { - return predictions; + for (const auto &pred : line_predictions) { + predictions->push_back(std::make_pair(std::exp(pred.first), dict->getLabel(pred.second))); } - else { - delete predictions; - } - - return nullptr; + return predictions; } auto model_info(void) const -> std::string { @@ -150,15 +167,26 @@ bool rspamd_lang_detection_fasttext_is_enabled(void *ud) } rspamd_fasttext_predict_result_t rspamd_lang_detection_fasttext_detect(void *ud, - const char *in, size_t len, int k) + GArray *utf_words, + int k) { #ifndef WITH_FASTTEXT return nullptr; #else /* Avoid too long inputs */ - static const size_t max_fasttext_input_len = 1024 * 1024 * 1; + static const guint max_fasttext_input_len = 1024 * 1024; auto *real_model = FASTTEXT_MODEL_TO_C_API(ud); - auto *res = real_model->detect_language(in, std::min(max_fasttext_input_len, len), k); + std::vector<std::int32_t> words_vec; + words_vec.reserve(utf_words->len); + + for (auto i = 0; i < std::min(utf_words->len, max_fasttext_input_len); i++) { + const auto *w = &g_array_index (utf_words, rspamd_stat_token_t, i); + if (w->original.len > 0) { + real_model->word2vec(w->original.begin, w->original.len, words_vec); + } + } + + auto *res = real_model->detect_language(words_vec, k); return (rspamd_fasttext_predict_result_t)res; #endif diff --git a/src/libmime/lang_detection_fasttext.h b/src/libmime/lang_detection_fasttext.h index 2e8a9fe78..9fb1db222 100644 --- a/src/libmime/lang_detection_fasttext.h +++ b/src/libmime/lang_detection_fasttext.h @@ -52,7 +52,7 @@ typedef void * rspamd_fasttext_predict_result_t; * @return TRUE if language is detected */ rspamd_fasttext_predict_result_t rspamd_lang_detection_fasttext_detect(void *ud, - const char *in, size_t len, int k); + GArray *utf_words, int k); /** * Get number of languages detected |