Przeglądaj źródła

[Project] Implement fasttext language detection

tags/3.6
Vsevolod Stakhov 1 rok temu
rodzic
commit
264b9f2c48
No account linked to committer's email address

+ 108
- 61
src/libmime/lang_detection.c Wyświetl plik

@@ -1801,88 +1801,132 @@ rspamd_language_detector_detect (struct rspamd_task *task,
}

if (!ret) {
if (part->utf_words->len < default_short_text_limit) {
r = rs_detect_none;
msg_debug_lang_det ("text is too short for trigrams detection: "
"%d words; at least %d words required",
unsigned ndetected = 0;
if (rspamd_lang_detection_fasttext_is_enabled(d->fasttext_detector)) {
rspamd_fasttext_predict_result_t fasttext_predict_result;
fasttext_predict_result = rspamd_lang_detection_fasttext_detect(d->fasttext_detector,
part->utf_stripped_content->data,
part->utf_stripped_content->len, 4);

ndetected = rspamd_lang_detection_fasttext_get_nlangs(fasttext_predict_result);

if (ndetected > 0) {
candidates = kh_init (rspamd_candidates_hash);
kh_resize (rspamd_candidates_hash, candidates, ndetected);

/* Now fill all results where probability is above threshold */
float max_prob = rspamd_lang_detection_fasttext_get_prob(fasttext_predict_result, 0);

for (unsigned int i = 0; i < ndetected; i ++) {
float prob = rspamd_lang_detection_fasttext_get_prob(fasttext_predict_result, i);
if (prob > max_prob * 0.75) {
char *lang = rspamd_mempool_strdup(task->task_pool,
rspamd_lang_detection_fasttext_get_lang(fasttext_predict_result, i));
int tmp;
khiter_t k = kh_put (rspamd_candidates_hash, candidates, lang, &tmp);

kh_value(candidates, k) = rspamd_mempool_alloc0(task->task_pool, sizeof(*cand));
cand = kh_value(candidates, k);
cand->lang = lang;
cand->prob = rspamd_lang_detection_fasttext_get_prob(fasttext_predict_result, i);
}
}

if (kh_size(candidates) == 1) {
r = rs_detect_single;
}
else if (kh_size(candidates) > 1) {
r = rs_detect_multiple;
}
else {
r = rs_detect_none;
}
}
}
if (ndetected == 0) {
if (part->utf_words->len < default_short_text_limit) {
r = rs_detect_none;
msg_debug_lang_det ("text is too short for trigrams detection: "
"%d words; at least %d words required",
(int)part->utf_words->len,
(int)default_short_text_limit);
switch (cat) {
case RSPAMD_LANGUAGE_CYRILLIC:
rspamd_language_detector_set_language (task, part, "ru", NULL);
break;
case RSPAMD_LANGUAGE_DEVANAGARI:
rspamd_language_detector_set_language (task, part, "hi", NULL);
break;
case RSPAMD_LANGUAGE_ARAB:
rspamd_language_detector_set_language (task, part, "ar", NULL);
break;
default:
case RSPAMD_LANGUAGE_LATIN:
rspamd_language_detector_set_language (task, part, "en", NULL);
break;
}
msg_debug_lang_det ("set %s language based on symbols category",
switch (cat) {
case RSPAMD_LANGUAGE_CYRILLIC:
rspamd_language_detector_set_language (task, part, "ru", NULL);
break;
case RSPAMD_LANGUAGE_DEVANAGARI:
rspamd_language_detector_set_language (task, part, "hi", NULL);
break;
case RSPAMD_LANGUAGE_ARAB:
rspamd_language_detector_set_language (task, part, "ar", NULL);
break;
default:
case RSPAMD_LANGUAGE_LATIN:
rspamd_language_detector_set_language (task, part, "en", NULL);
break;
}
msg_debug_lang_det ("set %s language based on symbols category",
part->language);

candidates = kh_init (rspamd_candidates_hash);
}
else {
candidates = kh_init (rspamd_candidates_hash);
kh_resize (rspamd_candidates_hash, candidates, 32);
candidates = kh_init (rspamd_candidates_hash);
}
else {
candidates = kh_init (rspamd_candidates_hash);
kh_resize (rspamd_candidates_hash, candidates, 32);

r = rspamd_language_detector_try_ngramm (task,
r = rspamd_language_detector_try_ngramm (task,
default_words,
d,
part->utf_words,
cat,
candidates);

if (r == rs_detect_none) {
msg_debug_lang_det ("no trigrams found, fallback to english");
rspamd_language_detector_set_language (task, part, "en", NULL);
} else if (r == rs_detect_multiple) {
/* Check our guess */

mean = 0.0;
std = 0.0;
cand_len = 0;

/* Check distribution */
kh_foreach_value (candidates, cand, {
if (!isnan (cand->prob)) {
mean += cand->prob;
cand_len++;
}
});
if (r == rs_detect_none) {
msg_debug_lang_det ("no trigrams found, fallback to english");
rspamd_language_detector_set_language (task, part, "en", NULL);
} else if (r == rs_detect_multiple) {
/* Check our guess */

if (cand_len > 0) {
mean /= cand_len;
mean = 0.0;
std = 0.0;
cand_len = 0;

/* Check distribution */
kh_foreach_value (candidates, cand, {
gdouble err;
if (!isnan (cand->prob)) {
err = cand->prob - mean;
std += fabs (err);
mean += cand->prob;
cand_len++;
}
});

std /= cand_len;
}
if (cand_len > 0) {
mean /= cand_len;

msg_debug_lang_det ("trigrams checked, %d candidates, %.3f mean, %.4f stddev",
kh_foreach_value (candidates, cand, {
gdouble err;
if (!isnan (cand->prob)) {
err = cand->prob - mean;
std += fabs (err);
}
});

std /= cand_len;
}

msg_debug_lang_det ("trigrams checked, %d candidates, %.3f mean, %.4f stddev",
cand_len, mean, std);

if (cand_len > 0 && std / fabs (mean) < 0.25) {
msg_debug_lang_det ("apply frequency heuristic sorting");
frequency_heuristic_applied = TRUE;
cbd.d = d;
cbd.mean = mean;
cbd.std = std;
cbd.flags = RSPAMD_LANG_FLAG_DEFAULT;
if (cand_len > 0 && std / fabs (mean) < 0.25) {
msg_debug_lang_det ("apply frequency heuristic sorting");
frequency_heuristic_applied = TRUE;
cbd.d = d;
cbd.mean = mean;
cbd.std = std;
cbd.flags = RSPAMD_LANG_FLAG_DEFAULT;

if (part->nwords < default_words / 2) {
cbd.flags |= RSPAMD_LANG_FLAG_SHORT;
if (part->nwords < default_words / 2) {
cbd.flags |= RSPAMD_LANG_FLAG_SHORT;
}
}
}
}
@@ -1909,7 +1953,9 @@ rspamd_language_detector_detect (struct rspamd_task *task,

if (result->len > 0 && !frequency_heuristic_applied) {
cand = g_ptr_array_index (result, 0);
cand->elt->occurrences++;
if (cand->elt) {
cand->elt->occurrences++;
}
d->total_occurrences++;
}

@@ -1918,6 +1964,7 @@ rspamd_language_detector_detect (struct rspamd_task *task,
}

part->languages = result;
part->language = ((struct rspamd_lang_detector_res *)g_ptr_array_index (result, 0))->lang;
ret = TRUE;
}
else if (part->languages == NULL) {

+ 35
- 8
src/libmime/lang_detection_fasttext.cxx Wyświetl plik

@@ -72,8 +72,8 @@ public:

~fasttext_langdet() = default;

auto detect_language(const char *in, size_t len, int k) -> std::vector<std::pair<fasttext::real, std::string>> *
auto is_enabled() const -> bool { return loaded; }
auto detect_language(const char *in, size_t len, int k) const -> std::vector<std::pair<fasttext::real, std::string>> *
{
if (!loaded) {
return nullptr;
@@ -135,6 +135,19 @@ char *rspamd_lang_detection_fasttext_show_info(void *ud)
#endif
}

bool rspamd_lang_detection_fasttext_is_enabled(void *ud)
{
#ifdef WITH_FASTTEXT
auto *real_model = FASTTEXT_MODEL_TO_C_API(ud);

if (real_model) {
return real_model->is_enabled();
}
#endif

return false;
}

rspamd_fasttext_predict_result_t rspamd_lang_detection_fasttext_detect(void *ud,
const char *in, size_t len, int k)
{
@@ -155,27 +168,41 @@ void rspamd_lang_detection_fasttext_destroy(void *ud)
#endif
}


guint
rspamd_lang_detection_fasttext_get_nlangs(rspamd_fasttext_predict_result_t res)
{
#ifdef WITH_FASTTEXT
auto *real_res = FASTTEXT_RESULT_TO_C_API(res);

if (real_res) {
return real_res->size();
}
#endif
return 0;
}

const char *
rspamd_lang_detection_fasttext_get_lang(rspamd_fasttext_predict_result_t res)
rspamd_lang_detection_fasttext_get_lang(rspamd_fasttext_predict_result_t res, unsigned int idx)
{
#ifdef WITH_FASTTEXT
auto *real_res = FASTTEXT_RESULT_TO_C_API(res);

if (real_res && !real_res->empty()) {
return real_res->front().second.c_str();
if (real_res && real_res->size() < idx) {
return real_res->at(idx).second.c_str();
}
#endif
return nullptr;
}

float
rspamd_lang_detection_fasttext_get_prob(rspamd_fasttext_predict_result_t res)
rspamd_lang_detection_fasttext_get_prob(rspamd_fasttext_predict_result_t res, unsigned int idx)
{
#ifdef WITH_FASTTEXT
auto *real_res = FASTTEXT_RESULT_TO_C_API(res);

if (real_res && !real_res->empty()) {
return real_res->front().first;
if (real_res && real_res->size() < idx) {
return real_res->at(idx).first;
}
#endif
return 0.0f;

+ 15
- 2
src/libmime/lang_detection_fasttext.h Wyświetl plik

@@ -27,6 +27,13 @@ struct rspamd_config;
*/
void* rspamd_lang_detection_fasttext_init(struct rspamd_config *cfg);

/**
* Check if fasttext language detector is enabled
* @param ud
* @return
*/
bool rspamd_lang_detection_fasttext_is_enabled(void *ud);

/**
* Show info about fasttext language detector
* @param ud
@@ -47,19 +54,25 @@ typedef void * rspamd_fasttext_predict_result_t;
rspamd_fasttext_predict_result_t rspamd_lang_detection_fasttext_detect(void *ud,
const char *in, size_t len, int k);

/**
* Get number of languages detected
* @param ud
* @return
*/
guint rspamd_lang_detection_fasttext_get_nlangs(rspamd_fasttext_predict_result_t ud);
/**
* Get language from fasttext result
* @param res
* @return
*/
const char *rspamd_lang_detection_fasttext_get_lang(rspamd_fasttext_predict_result_t res);
const char *rspamd_lang_detection_fasttext_get_lang(rspamd_fasttext_predict_result_t res, unsigned int idx);

/**
* Get probability from fasttext result
* @param res
* @return
*/
float rspamd_lang_detection_fasttext_get_prob(rspamd_fasttext_predict_result_t res);
float rspamd_lang_detection_fasttext_get_prob(rspamd_fasttext_predict_result_t res, unsigned int idx);

/**
* Destroy fasttext result

Ładowanie…
Anuluj
Zapisz