source.dussan.org Git - rspamd.git/commitdiff
[Project] Some further fixes vstakhov-fasttext-langdet 4473/head
author: Vsevolod Stakhov <vsevolod@rspamd.com>
Sat, 29 Apr 2023 16:44:16 +0000 (17:44 +0100)
committer: Vsevolod Stakhov <vsevolod@rspamd.com>
Sat, 29 Apr 2023 16:44:16 +0000 (17:44 +0100)
src/libmime/lang_detection.c
src/libmime/lang_detection_fasttext.cxx

index d8e81e075249eebbb5b993bf58a30de8a81d716b..62d04975c6396420ab430a127d8254f8a5cd42d8 100644 (file)
@@ -174,8 +174,10 @@ KHASH_INIT (rspamd_stopwords_hash, rspamd_ftok_t *,
                char, false,
                rspamd_ftok_hash, rspamd_ftok_equal);
 
+KHASH_INIT (rspamd_languages_hash, const gchar *, struct rspamd_language_elt *, true,
+               rspamd_str_hash, rspamd_str_equal);
 struct rspamd_lang_detector {
-       GPtrArray *languages;
+       khash_t(rspamd_languages_hash) *languages;
        khash_t(rspamd_trigram_hash) *trigrams[RSPAMD_LANGUAGE_MAX]; /* trigrams frequencies */
        struct rspamd_stop_word_elt stop_words[RSPAMD_LANGUAGE_MAX];
        khash_t(rspamd_stopwords_hash) *stop_words_norm;
@@ -686,7 +688,10 @@ rspamd_language_detector_read_file (struct rspamd_config *cfg,
                        skipped, loaded, nelt->stop_words,
                        rspamd_language_detector_print_flags (nelt));
 
-       g_ptr_array_add (d->languages, nelt);
+       int ret;
+       khiter_t k = kh_put(rspamd_languages_hash, d->languages, nelt->name, &ret);
+       g_assert (ret > 0); /* must be unique */
+       kh_value(d->languages, k) = nelt;
        ucl_object_unref (top);
 }
 
@@ -764,7 +769,7 @@ rspamd_language_detector_dtor (struct rspamd_lang_detector *d)
                }
 
                if (d->languages) {
-                       g_ptr_array_free (d->languages, TRUE);
+                       kh_destroy (rspamd_languages_hash, d->languages);
                }
 
                kh_destroy (rspamd_stopwords_hash, d->stop_words_norm);
@@ -833,7 +838,8 @@ rspamd_language_detector_init (struct rspamd_config *cfg)
        }
 
        ret = rspamd_mempool_alloc0 (cfg->cfg_pool, sizeof (*ret));
-       ret->languages = g_ptr_array_sized_new (gl.gl_pathc);
+       ret->languages = kh_init(rspamd_languages_hash);
+       kh_resize(rspamd_languages_hash, ret->languages, gl.gl_pathc);
        ret->uchar_converter = rspamd_get_utf8_converter ();
        ret->short_text_limit = short_text_limit;
        ret->stop_words_norm = kh_init (rspamd_stopwords_hash);
@@ -894,7 +900,7 @@ rspamd_language_detector_init (struct rspamd_config *cfg)
 
        msg_info_config ("loaded %d languages, "
                        "%d trigrams; %s",
-                       (gint)ret->languages->len,
+                       (gint)kh_size(ret->languages),
                        (gint)total, fasttext_status);
        g_free (fasttext_status);
 
@@ -1810,25 +1816,28 @@ rspamd_language_detector_detect (struct rspamd_task *task,
 
        guint nchinese = 0, nspecial = 0;
        rspamd_language_detector_unicode_scripts (task, part, &nchinese, &nspecial);
-       /* Apply unicode scripts heuristic */
 
-       if (rspamd_language_detector_try_uniscript (task, part, nchinese, nspecial)) {
-               ret = TRUE;
-       }
+       /* Disable internal language detection heuristics if we have fasttext */
+       if (!rspamd_lang_detection_fasttext_is_enabled(d->fasttext_detector)) {
+               /* Apply unicode scripts heuristic */
+               if (rspamd_language_detector_try_uniscript(task, part, nchinese, nspecial)) {
+                       ret = TRUE;
+               }
 
-       cat = rspamd_language_detector_get_category (part->unicode_scripts);
+               cat = rspamd_language_detector_get_category(part->unicode_scripts);
 
-       if (!ret && rspamd_language_detector_try_stop_words (task, d, part, cat)) {
-               ret = TRUE;
+               if (!ret && rspamd_language_detector_try_stop_words(task, d, part, cat)) {
+                       ret = TRUE;
+               }
        }
 
        if (!ret) {
                unsigned ndetected = 0;
                if (rspamd_lang_detection_fasttext_is_enabled(d->fasttext_detector)) {
-                       rspamd_fasttext_predict_result_t fasttext_predict_result;
-                       fasttext_predict_result = rspamd_lang_detection_fasttext_detect(d->fasttext_detector,
-                               part->utf_stripped_content->data,
-                               part->utf_stripped_content->len, 4);
+                       rspamd_fasttext_predict_result_t fasttext_predict_result =
+                               rspamd_lang_detection_fasttext_detect(d->fasttext_detector,
+                                       part->utf_stripped_content->data,
+                                       part->utf_stripped_content->len, 4);
 
                        ndetected = rspamd_lang_detection_fasttext_get_nlangs(fasttext_predict_result);
 
@@ -1851,6 +1860,12 @@ rspamd_language_detector_detect (struct rspamd_task *task,
                                                cand = kh_value(candidates, k);
                                                cand->lang = lang;
                                                cand->prob = rspamd_lang_detection_fasttext_get_prob(fasttext_predict_result, i);
+
+                                               /* Find the corresponding language elt */
+                                               k = kh_get(rspamd_languages_hash, d->languages, lang);
+                                               if (k != kh_end(d->languages)) {
+                                                       cand->elt = kh_value(d->languages, k);
+                                               }
                                        }
                                }
 
@@ -1864,6 +1879,8 @@ rspamd_language_detector_detect (struct rspamd_task *task,
                                        r = rs_detect_none;
                                }
                        }
+
+                       rspamd_fasttext_predict_result_destroy(fasttext_predict_result);
                }
                if (ndetected == 0) {
                        if (part->utf_words->len < default_short_text_limit) {
index eda4c2850067e458f20f5a99e83fa6a84c4aaebc..7e16414bc941117f3ad947aa4e9b528bcc90c179 100644 (file)
@@ -23,6 +23,7 @@
 #include "fmt/core.h"
 #include <exception>
 #include <string>
+#include <string_view>
 #include <vector>
 #include <sstream>
 #include <streambuf>
@@ -154,8 +155,10 @@ rspamd_fasttext_predict_result_t rspamd_lang_detection_fasttext_detect(void *ud,
 #ifndef WITH_FASTTEXT
        return nullptr;
 #else
+       /* Avoid too long inputs */
+       static const size_t max_fasttext_input_len = 1024 * 1024 * 1;
        auto *real_model = FASTTEXT_MODEL_TO_C_API(ud);
-       auto *res = real_model->detect_language(in, len, k);
+       auto *res = real_model->detect_language(in, std::min(max_fasttext_input_len, len), k);
 
        return (rspamd_fasttext_predict_result_t)res;
 #endif
@@ -188,8 +191,13 @@ rspamd_lang_detection_fasttext_get_lang(rspamd_fasttext_predict_result_t res, un
 #ifdef WITH_FASTTEXT
        auto *real_res = FASTTEXT_RESULT_TO_C_API(res);
 
-       if (real_res && real_res->size() < idx) {
-               return real_res->at(idx).second.c_str();
+       if (real_res && real_res->size() > idx) {
+               /* Fasttext returns result in form __label__<lang>, so we need to remove __label__ prefix */
+               auto lang = std::string_view{real_res->at(idx).second};
+               if (lang.size() > sizeof("__label__") && lang.substr(0, sizeof("__label__") - 1) == "__label__") {
+                       lang.remove_prefix(sizeof("__label__") - 1);
+               }
+               return lang.data();
        }
 #endif
        return nullptr;
@@ -201,7 +209,7 @@ rspamd_lang_detection_fasttext_get_prob(rspamd_fasttext_predict_result_t res, un
 #ifdef WITH_FASTTEXT
        auto *real_res = FASTTEXT_RESULT_TO_C_API(res);
 
-       if (real_res && real_res->size() < idx) {
+       if (real_res && real_res->size() > idx) {
                return real_res->at(idx).first;
        }
 #endif