diff options
author | Vsevolod Stakhov <vsevolod@rspamd.com> | 2023-04-29 15:47:15 +0100 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@rspamd.com> | 2023-04-29 15:47:15 +0100 |
commit | 264b9f2c480a1b0240acb8183a8d7470691aff11 (patch) | |
tree | aeecf4738499e48c0405903abebc5434975b39ba | |
parent | fea5bdc79758530a3c28970c9c19d05e9932de74 (diff) | |
download | rspamd-264b9f2c480a1b0240acb8183a8d7470691aff11.tar.gz rspamd-264b9f2c480a1b0240acb8183a8d7470691aff11.zip |
[Project] Implement fasttext language detection
-rw-r--r-- | src/libmime/lang_detection.c | 169 | ||||
-rw-r--r-- | src/libmime/lang_detection_fasttext.cxx | 43 | ||||
-rw-r--r-- | src/libmime/lang_detection_fasttext.h | 17 |
3 files changed, 158 insertions, 71 deletions
diff --git a/src/libmime/lang_detection.c b/src/libmime/lang_detection.c index 09591438e..211dfe48b 100644 --- a/src/libmime/lang_detection.c +++ b/src/libmime/lang_detection.c @@ -1801,88 +1801,132 @@ rspamd_language_detector_detect (struct rspamd_task *task, } if (!ret) { - if (part->utf_words->len < default_short_text_limit) { - r = rs_detect_none; - msg_debug_lang_det ("text is too short for trigrams detection: " - "%d words; at least %d words required", + unsigned ndetected = 0; + if (rspamd_lang_detection_fasttext_is_enabled(d->fasttext_detector)) { + rspamd_fasttext_predict_result_t fasttext_predict_result; + fasttext_predict_result = rspamd_lang_detection_fasttext_detect(d->fasttext_detector, + part->utf_stripped_content->data, + part->utf_stripped_content->len, 4); + + ndetected = rspamd_lang_detection_fasttext_get_nlangs(fasttext_predict_result); + + if (ndetected > 0) { + candidates = kh_init (rspamd_candidates_hash); + kh_resize (rspamd_candidates_hash, candidates, ndetected); + + /* Now fill all results where probability is above threshold */ + float max_prob = rspamd_lang_detection_fasttext_get_prob(fasttext_predict_result, 0); + + for (unsigned int i = 0; i < ndetected; i ++) { + float prob = rspamd_lang_detection_fasttext_get_prob(fasttext_predict_result, i); + if (prob > max_prob * 0.75) { + char *lang = rspamd_mempool_strdup(task->task_pool, + rspamd_lang_detection_fasttext_get_lang(fasttext_predict_result, i)); + int tmp; + khiter_t k = kh_put (rspamd_candidates_hash, candidates, lang, &tmp); + + kh_value(candidates, k) = rspamd_mempool_alloc0(task->task_pool, sizeof(*cand)); + cand = kh_value(candidates, k); + cand->lang = lang; + cand->prob = rspamd_lang_detection_fasttext_get_prob(fasttext_predict_result, i); + } + } + + if (kh_size(candidates) == 1) { + r = rs_detect_single; + } + else if (kh_size(candidates) > 1) { + r = rs_detect_multiple; + } + else { + r = rs_detect_none; + } + } + } + if (ndetected == 0) { + if (part->utf_words->len < default_short_text_limit) { + r = rs_detect_none; + msg_debug_lang_det ("text is too short for trigrams detection: " + "%d words; at least %d words required", (int)part->utf_words->len, (int)default_short_text_limit); - switch (cat) { - case RSPAMD_LANGUAGE_CYRILLIC: - rspamd_language_detector_set_language (task, part, "ru", NULL); - break; - case RSPAMD_LANGUAGE_DEVANAGARI: - rspamd_language_detector_set_language (task, part, "hi", NULL); - break; - case RSPAMD_LANGUAGE_ARAB: - rspamd_language_detector_set_language (task, part, "ar", NULL); - break; - default: - case RSPAMD_LANGUAGE_LATIN: - rspamd_language_detector_set_language (task, part, "en", NULL); - break; - } - msg_debug_lang_det ("set %s language based on symbols category", + switch (cat) { + case RSPAMD_LANGUAGE_CYRILLIC: + rspamd_language_detector_set_language (task, part, "ru", NULL); + break; + case RSPAMD_LANGUAGE_DEVANAGARI: + rspamd_language_detector_set_language (task, part, "hi", NULL); + break; + case RSPAMD_LANGUAGE_ARAB: + rspamd_language_detector_set_language (task, part, "ar", NULL); + break; + default: + case RSPAMD_LANGUAGE_LATIN: + rspamd_language_detector_set_language (task, part, "en", NULL); + break; + } + msg_debug_lang_det ("set %s language based on symbols category", part->language); - candidates = kh_init (rspamd_candidates_hash); - } - else { - candidates = kh_init (rspamd_candidates_hash); - kh_resize (rspamd_candidates_hash, candidates, 32); + candidates = kh_init (rspamd_candidates_hash); + } + else { + candidates = kh_init (rspamd_candidates_hash); + kh_resize (rspamd_candidates_hash, candidates, 32); - r = rspamd_language_detector_try_ngramm (task, + r = rspamd_language_detector_try_ngramm (task, default_words, d, part->utf_words, cat, candidates); - if (r == rs_detect_none) { - msg_debug_lang_det ("no trigrams found, fallback to english"); - rspamd_language_detector_set_language (task, part, "en", NULL); - } else if (r == rs_detect_multiple) { - /* Check our guess */ - - mean = 0.0; - std = 0.0; - cand_len = 0; - - /* Check distribution */ - kh_foreach_value (candidates, cand, { - if (!isnan (cand->prob)) { - mean += cand->prob; - cand_len++; - } - }); + if (r == rs_detect_none) { + msg_debug_lang_det ("no trigrams found, fallback to english"); + rspamd_language_detector_set_language (task, part, "en", NULL); + } else if (r == rs_detect_multiple) { + /* Check our guess */ - if (cand_len > 0) { - mean /= cand_len; + mean = 0.0; + std = 0.0; + cand_len = 0; + /* Check distribution */ kh_foreach_value (candidates, cand, { - gdouble err; if (!isnan (cand->prob)) { - err = cand->prob - mean; - std += fabs (err); + mean += cand->prob; + cand_len++; } }); - std /= cand_len; - } + if (cand_len > 0) { + mean /= cand_len; - msg_debug_lang_det ("trigrams checked, %d candidates, %.3f mean, %.4f stddev", + kh_foreach_value (candidates, cand, { + gdouble err; + if (!isnan (cand->prob)) { + err = cand->prob - mean; + std += fabs (err); + } + }); + + std /= cand_len; + } + + msg_debug_lang_det ("trigrams checked, %d candidates, %.3f mean, %.4f stddev", cand_len, mean, std); - if (cand_len > 0 && std / fabs (mean) < 0.25) { - msg_debug_lang_det ("apply frequency heuristic sorting"); - frequency_heuristic_applied = TRUE; - cbd.d = d; - cbd.mean = mean; - cbd.std = std; - cbd.flags = RSPAMD_LANG_FLAG_DEFAULT; + if (cand_len > 0 && std / fabs (mean) < 0.25) { + msg_debug_lang_det ("apply frequency heuristic sorting"); + frequency_heuristic_applied = TRUE; + cbd.d = d; + cbd.mean = mean; + cbd.std = std; + cbd.flags = RSPAMD_LANG_FLAG_DEFAULT; - if (part->nwords < default_words / 2) { - cbd.flags |= RSPAMD_LANG_FLAG_SHORT; + if (part->nwords < default_words / 2) { + cbd.flags |= RSPAMD_LANG_FLAG_SHORT; + } } } } @@ -1909,7 +1953,9 @@ rspamd_language_detector_detect (struct rspamd_task *task, if (result->len > 0 && !frequency_heuristic_applied) { cand = g_ptr_array_index (result, 0); - cand->elt->occurrences++; + if (cand->elt) { + cand->elt->occurrences++; + } d->total_occurrences++; } @@ -1918,6 +1964,7 @@ rspamd_language_detector_detect (struct rspamd_task *task, } part->languages = result; + part->language = ((struct rspamd_lang_detector_res *)g_ptr_array_index (result, 0))->lang; ret = TRUE; } else if (part->languages == NULL) { diff --git a/src/libmime/lang_detection_fasttext.cxx b/src/libmime/lang_detection_fasttext.cxx index 9ede47a6e..eda4c2850 100644 --- a/src/libmime/lang_detection_fasttext.cxx +++ b/src/libmime/lang_detection_fasttext.cxx @@ -72,8 +72,8 @@ public: ~fasttext_langdet() = default; - - auto detect_language(const char *in, size_t len, int k) -> std::vector<std::pair<fasttext::real, std::string>> * + auto is_enabled() const -> bool { return loaded; } + auto detect_language(const char *in, size_t len, int k) const -> std::vector<std::pair<fasttext::real, std::string>> * { if (!loaded) { return nullptr; @@ -135,6 +135,19 @@ char *rspamd_lang_detection_fasttext_show_info(void *ud) #endif } +bool rspamd_lang_detection_fasttext_is_enabled(void *ud) +{ +#ifdef WITH_FASTTEXT + auto *real_model = FASTTEXT_MODEL_TO_C_API(ud); + + if (real_model) { + return real_model->is_enabled(); + } +#endif + + return false; +} + rspamd_fasttext_predict_result_t rspamd_lang_detection_fasttext_detect(void *ud, const char *in, size_t len, int k) { @@ -155,27 +168,41 @@ void rspamd_lang_detection_fasttext_destroy(void *ud) #endif } + +guint +rspamd_lang_detection_fasttext_get_nlangs(rspamd_fasttext_predict_result_t res) +{ +#ifdef WITH_FASTTEXT + auto *real_res = FASTTEXT_RESULT_TO_C_API(res); + + if (real_res) { + return real_res->size(); + } +#endif + return 0; +} + const char * -rspamd_lang_detection_fasttext_get_lang(rspamd_fasttext_predict_result_t res) +rspamd_lang_detection_fasttext_get_lang(rspamd_fasttext_predict_result_t res, unsigned int idx) { #ifdef WITH_FASTTEXT auto *real_res = FASTTEXT_RESULT_TO_C_API(res); - if (real_res && !real_res->empty()) { - return real_res->front().second.c_str(); + if (real_res && real_res->size() < idx) { + return real_res->at(idx).second.c_str(); } #endif return nullptr; } float -rspamd_lang_detection_fasttext_get_prob(rspamd_fasttext_predict_result_t res) +rspamd_lang_detection_fasttext_get_prob(rspamd_fasttext_predict_result_t res, unsigned int idx) { #ifdef WITH_FASTTEXT auto *real_res = FASTTEXT_RESULT_TO_C_API(res); - if (real_res && !real_res->empty()) { - return real_res->front().first; + if (real_res && real_res->size() < idx) { + return real_res->at(idx).first; } #endif return 0.0f; diff --git a/src/libmime/lang_detection_fasttext.h b/src/libmime/lang_detection_fasttext.h index 71e253940..2e8a9fe78 100644 --- a/src/libmime/lang_detection_fasttext.h +++ b/src/libmime/lang_detection_fasttext.h @@ -28,6 +28,13 @@ struct rspamd_config; void* rspamd_lang_detection_fasttext_init(struct rspamd_config *cfg); /** + * Check if fasttext language detector is enabled + * @param ud + * @return + */ +bool rspamd_lang_detection_fasttext_is_enabled(void *ud); + +/** * Show info about fasttext language detector * @param ud * @return @@ -48,18 +55,24 @@ rspamd_fasttext_predict_result_t rspamd_lang_detection_fasttext_detect(void *ud, const char *in, size_t len, int k); /** + * Get number of languages detected + * @param ud + * @return + */ +guint rspamd_lang_detection_fasttext_get_nlangs(rspamd_fasttext_predict_result_t ud); +/** * Get language from fasttext result * @param res * @return */ -const char *rspamd_lang_detection_fasttext_get_lang(rspamd_fasttext_predict_result_t res); +const char *rspamd_lang_detection_fasttext_get_lang(rspamd_fasttext_predict_result_t res, unsigned int idx); /** * Get probability from fasttext result * @param res * @return */ -float rspamd_lang_detection_fasttext_get_prob(rspamd_fasttext_predict_result_t res); +float rspamd_lang_detection_fasttext_get_prob(rspamd_fasttext_predict_result_t res, unsigned int idx); /** * Destroy fasttext result |