]> source.dussan.org Git - rspamd.git/commitdiff
[Rework] Use frequencies distribution in language detector
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 24 Jan 2018 20:56:00 +0000 (20:56 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 24 Jan 2018 20:57:09 +0000 (20:57 +0000)
src/libmime/lang_detection.c

index 2bdda30049a847cd64ed7eaceb999bd9fef71842..ed2aa1681ff5f802d072c5ade2a8c12c3e54bfd3 100644 (file)
@@ -769,18 +769,70 @@ rspamd_language_detector_try_ngramm (struct rspamd_task *task,
        return rs_detect_multiple;
 }
 
+struct rspamd_frequency_sort_cbdata {
+       struct rspamd_lang_detector *d;
+       gdouble std;
+       gdouble mean;
+};
+
+static gint
+rspamd_language_detector_cmp_heuristic (gconstpointer a, gconstpointer b,
+               gpointer ud)
+{
+       struct rspamd_frequency_sort_cbdata *cbd = ud;
+       const struct rspamd_lang_detector_res
+                       *canda = *(const struct rspamd_lang_detector_res **)a,
+                       *candb = *(const struct rspamd_lang_detector_res **)b;
+       gdouble diff;
+
+       diff = fabs (canda->prob - candb->prob);
+
+       if (diff > cbd->std) {
+               /* Generic case */
+               if (canda->prob > candb->prob) {
+                       return -1;
+               } else if (candb->prob > canda->prob) {
+                       return 1;
+               }
+
+               return 0;
+       }
+       else {
+               gdouble proba_adjusted, probb_adjusted, freqa, freqb;
+
+               freqa = ((gdouble)canda->elt->occurencies) /
+                               (gdouble)cbd->d->total_occurencies;
+               freqb = ((gdouble)candb->elt->occurencies) /
+                               (gdouble)cbd->d->total_occurencies;
+
+               proba_adjusted = canda->prob * freqa;
+               probb_adjusted = candb->prob * freqb;
+
+               if (proba_adjusted > probb_adjusted) {
+                       return -1;
+               } else if (probb_adjusted > proba_adjusted) {
+                       return 1;
+               }
+
+               return 0;
+       }
+}
+
 GPtrArray *
 rspamd_language_detector_detect (struct rspamd_task *task,
                struct rspamd_lang_detector *d,
                GArray *ucs_tokens, gsize words_len)
 {
-       GHashTable *candidates, *tcandidates;
+       GHashTable *candidates;
        GPtrArray *result;
        GHashTableIter it;
        gpointer k, v;
        gdouble mean, std;
        struct rspamd_lang_detector_res *cand;
        enum rspamd_language_detected_type r;
+       struct rspamd_frequency_sort_cbdata cbd;
+       /* Check if we have sorted candidates based on frequency */
+       gboolean frequency_heuristic_applied = FALSE;
 
        if (ucs_tokens->len == 0) {
                return g_ptr_array_new ();
@@ -789,110 +841,61 @@ rspamd_language_detector_detect (struct rspamd_task *task,
        candidates = g_hash_table_new_full (rspamd_str_hash, rspamd_str_equal,
                        NULL, g_free);
 
-       if (words_len < d->short_text_limit) {
-               /* For short text, start directly from trigramms */
-               msg_debug_lang_det ("text is less than %z words: %z, start with trigramms",
-                               d->short_text_limit, words_len);
+       msg_debug_lang_det ("text is less than %z words: %z, start with trigramms",
+                       d->short_text_limit, words_len);
+       r = rspamd_language_detector_try_ngramm (task, default_words, d,
+                       ucs_tokens, rs_trigramm,
+                       candidates);
+
+       if (r == rs_detect_none) {
+               msg_debug_lang_det ("short mode; no trigramms found, switch to bigramms");
                r = rspamd_language_detector_try_ngramm (task, default_words, d,
-                               ucs_tokens, rs_trigramm,
+                               ucs_tokens, rs_bigramm,
                                candidates);
 
                if (r == rs_detect_none) {
-                       msg_debug_lang_det ("short mode; no trigramms found, switch to bigramms");
-                       r = rspamd_language_detector_try_ngramm (task, default_words, d,
-                                       ucs_tokens, rs_bigramm,
+                       msg_debug_lang_det ("short mode; no trigramms found, "
+                                       "switch to unigramms");
+                       r = rspamd_language_detector_try_ngramm (task, default_words,
+                                       d, ucs_tokens, rs_unigramm,
                                        candidates);
-
-                       if (r == rs_detect_none) {
-                               msg_debug_lang_det ("short mode; no trigramms found, "
-                                               "switch to unigramms");
-                               r = rspamd_language_detector_try_ngramm (task, default_words,
-                                               d, ucs_tokens, rs_unigramm,
-                                               candidates);
-                       }
                }
        }
-       else {
-               /* Start with unigramms */
-               r = rspamd_language_detector_try_ngramm (task, default_words,
-                               d, ucs_tokens, rs_unigramm,
-                               candidates);
+       else if (r == rs_detect_multiple) {
+               /* Check our guess */
+               msg_debug_lang_det ("unigramms pass finished, found %d candidates",
+                               (gint)g_hash_table_size (candidates));
+               mean = 0.0;
+               std = 0.0;
+               g_hash_table_iter_init (&it, candidates);
+
+               /* Check distirbution */
+               while (g_hash_table_iter_next (&it, &k, &v)) {
+                       cand = (struct rspamd_lang_detector_res *) v;
+                       mean += cand->prob;
+               }
 
-               switch (r) {
-               case rs_detect_none:
-               case rs_detect_single:
-                       msg_debug_lang_det ("no unigramms found, try bigramms");
-                       break;
-               case rs_detect_multiple:
-                       /* Try to improve guess */
-                       msg_debug_lang_det ("unigramms pass finished, found %d candidates",
-                                       (gint)g_hash_table_size (candidates));
-                       tcandidates = g_hash_table_new_full (rspamd_str_hash, rspamd_str_equal,
-                                       NULL, g_free);
-                       r = rspamd_language_detector_try_ngramm (task, default_words,
-                                       d, ucs_tokens, rs_trigramm,
-                                       tcandidates);
+               mean /= g_hash_table_size (candidates);
 
-                       switch (r) {
-                       case rs_detect_none:
-                               /* Revert to unigramms result */
-                               g_hash_table_unref (tcandidates);
-                               break;
-                       case rs_detect_single:
-                               /* We have good enough result, return it */
-                               g_hash_table_unref (candidates);
-                               candidates = tcandidates;
-                               break;
-                       case rs_detect_multiple:
-                               mean = 0.0;
-                               std = 0.0;
-                               g_hash_table_iter_init (&it, tcandidates);
-
-                               /* Check distirbution */
-                               while (g_hash_table_iter_next (&it, &k, &v)) {
-                                       cand = (struct rspamd_lang_detector_res *) v;
-                                       mean += cand->prob;
-                               }
-
-                               mean /= g_hash_table_size (tcandidates);
-
-                               g_hash_table_iter_init (&it, tcandidates);
-                               while (g_hash_table_iter_next (&it, &k, &v)) {
-                                       gdouble err;
-                                       cand = (struct rspamd_lang_detector_res *) v;
-                                       err = cand->prob - mean;
-                                       std += fabs (err);
-                               }
-
-                               std /= g_hash_table_size (tcandidates);
-                               g_hash_table_unref (candidates);
-                               candidates = tcandidates;
-
-                               msg_debug_lang_det ("trigramms checked, %.3f mean, %.4f stddev",
-                                               mean, std);
-
-                               if (std / fabs (mean) < 0.01) {
-                                       /* Try trigramms */
-                                       tcandidates = g_hash_table_new_full (rspamd_str_hash,
-                                                       rspamd_str_equal,
-                                                       NULL, g_free);
-
-                                       r = rspamd_language_detector_try_ngramm (task,
-                                                       default_words * 2,
-                                                       d,
-                                                       ucs_tokens,
-                                                       rs_trigramm,
-                                                       tcandidates);
-
-                                       if (r != rs_detect_none) {
-                                               /* TODO: check if we have better distribution here */
-                                               g_hash_table_unref (candidates);
-                                               candidates = tcandidates;
-                                       }
-                               }
-                               break;
-                       }
-                       break;
+               g_hash_table_iter_init (&it, candidates);
+               while (g_hash_table_iter_next (&it, &k, &v)) {
+                       gdouble err;
+                       cand = (struct rspamd_lang_detector_res *) v;
+                       err = cand->prob - mean;
+                       std += fabs (err);
+               }
+
+               std /= g_hash_table_size (candidates);
+
+               msg_debug_lang_det ("trigramms checked, %.3f mean, %.4f stddev",
+                               mean, std);
+
+               if (std / fabs (mean) < 0.1) {
+                       msg_debug_lang_det ("apply frequency heuristic sorting");
+                       frequency_heuristic_applied = TRUE;
+                       cbd.d = d;
+                       cbd.mean = mean;
+                       cbd.std = std;
                }
        }
 
@@ -907,10 +910,17 @@ rspamd_language_detector_detect (struct rspamd_task *task,
                g_hash_table_iter_steal (&it);
        }
 
-       g_ptr_array_sort (result, rspamd_language_detector_cmp);
+       if (frequency_heuristic_applied) {
+               g_ptr_array_sort_with_data (result,
+                               rspamd_language_detector_cmp_heuristic, (gpointer)&cbd);
+       }
+       else {
+               g_ptr_array_sort (result, rspamd_language_detector_cmp);
+       }
        g_hash_table_unref (candidates);
 
-       if (result->len > 0) {
+
+       if (result->len > 0 && !frequency_heuristic_applied) {
                cand = g_ptr_array_index (result, 0);
                cand->elt->occurencies ++;
                d->total_occurencies ++;