diff options
Diffstat (limited to 'src/libmime/lang_detection_fasttext.cxx')
-rw-r--r-- | src/libmime/lang_detection_fasttext.cxx | 45 |
1 files changed, 24 insertions, 21 deletions
diff --git a/src/libmime/lang_detection_fasttext.cxx b/src/libmime/lang_detection_fasttext.cxx index d9e4e7192..f06e8ccb6 100644 --- a/src/libmime/lang_detection_fasttext.cxx +++ b/src/libmime/lang_detection_fasttext.cxx @@ -30,10 +30,10 @@ #ifdef WITH_FASTTEXT EXTERN_LOG_MODULE_DEF(langdet); -#define msg_debug_lang_det(...) rspamd_conditional_debug_fast (nullptr, nullptr, \ - rspamd_langdet_log_id, "langdet", task->task_pool->tag.uid, \ - __FUNCTION__, \ - __VA_ARGS__) +#define msg_debug_lang_det(...) rspamd_conditional_debug_fast(nullptr, nullptr, \ + rspamd_langdet_log_id, "langdet", task->task_pool->tag.uid, \ + __FUNCTION__, \ + __VA_ARGS__) namespace rspamd::langdet { class fasttext_langdet { @@ -43,7 +43,8 @@ private: bool loaded; public: - explicit fasttext_langdet(struct rspamd_config *cfg) { + explicit fasttext_langdet(struct rspamd_config *cfg) + { const auto *ucl_obj = cfg->rcl_obj; const auto *opts_section = ucl_object_find_key(ucl_obj, "lang_detection"); @@ -55,8 +56,7 @@ public: ft.loadModel(ucl_object_tostring(model)); loaded = true; model_fname = std::string{ucl_object_tostring(model)}; - } - catch (std::exception &e) { + } catch (std::exception &e) { auto err_message = fmt::format("cannot load fasttext model: {}", e.what()); msg_err_config("%s", err_message.c_str()); loaded = false; @@ -72,8 +72,12 @@ public: ~fasttext_langdet() = default; - auto is_enabled() const -> bool { return loaded; } - auto word2vec(const char *in, std::size_t len, std::vector<std::int32_t> &word_ngramms) const { + auto is_enabled() const -> bool + { + return loaded; + } + auto word2vec(const char *in, std::size_t len, std::vector<std::int32_t> &word_ngramms) const + { if (!loaded) { return; } @@ -114,23 +118,24 @@ public: ft.predict(k, words, line_predictions, 0.0f); const auto *dict = ft.getDictionary().get(); - for (const auto &pred : line_predictions) { + for (const auto &pred: line_predictions) { predictions->push_back(std::make_pair(std::exp(pred.first), dict->getLabel(pred.second))); } return predictions; } - auto model_info(void) const -> std::string { + auto model_info(void) const -> std::string + { if (!loaded) { return "fasttext model is not loaded"; } else { return fmt::format("fasttext model {}: {} languages, {} tokens", model_fname, - ft.getDictionary()->nlabels(), ft.getDictionary()->ntokens()); + ft.getDictionary()->nlabels(), ft.getDictionary()->ntokens()); } } }; -} +}// namespace rspamd::langdet #endif /* C API part */ @@ -139,12 +144,12 @@ G_BEGIN_DECLS #define FASTTEXT_MODEL_TO_C_API(p) reinterpret_cast<rspamd::langdet::fasttext_langdet *>(p) #define FASTTEXT_RESULT_TO_C_API(res) reinterpret_cast<std::vector<std::pair<fasttext::real, std::string>> *>(res) -void* rspamd_lang_detection_fasttext_init(struct rspamd_config *cfg) +void *rspamd_lang_detection_fasttext_init(struct rspamd_config *cfg) { #ifndef WITH_FASTTEXT return nullptr; #else - return (void *)new rspamd::langdet::fasttext_langdet(cfg); + return (void *) new rspamd::langdet::fasttext_langdet(cfg); #endif } @@ -187,7 +192,7 @@ rspamd_fasttext_predict_result_t rspamd_lang_detection_fasttext_detect(void *ud, words_vec.reserve(utf_words->len); for (auto i = 0; i < std::min(utf_words->len, max_fasttext_input_len); i++) { - const auto *w = &g_array_index (utf_words, rspamd_stat_token_t, i); + const auto *w = &g_array_index(utf_words, rspamd_stat_token_t, i); if (w->original.len > 0) { real_model->word2vec(w->original.begin, w->original.len, words_vec); } @@ -197,7 +202,7 @@ rspamd_fasttext_predict_result_t rspamd_lang_detection_fasttext_detect(void *ud, auto *res = real_model->detect_language(words_vec, k); - return (rspamd_fasttext_predict_result_t)res; + return (rspamd_fasttext_predict_result_t) res; #endif } @@ -209,8 +214,7 @@ void rspamd_lang_detection_fasttext_destroy(void *ud) } -guint -rspamd_lang_detection_fasttext_get_nlangs(rspamd_fasttext_predict_result_t res) +guint rspamd_lang_detection_fasttext_get_nlangs(rspamd_fasttext_predict_result_t res) { #ifdef WITH_FASTTEXT auto *real_res = FASTTEXT_RESULT_TO_C_API(res); @@ -240,8 +244,7 @@ rspamd_lang_detection_fasttext_get_lang(rspamd_fasttext_predict_result_t res, un return nullptr; } -float -rspamd_lang_detection_fasttext_get_prob(rspamd_fasttext_predict_result_t res, unsigned int idx) +float rspamd_lang_detection_fasttext_get_prob(rspamd_fasttext_predict_result_t res, unsigned int idx) { #ifdef WITH_FASTTEXT auto *real_res = FASTTEXT_RESULT_TO_C_API(res); |