]> source.dussan.org Git - rspamd.git/commitdiff
[Feature] Further optimization of the lang_detection
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 17 Apr 2018 16:53:55 +0000 (17:53 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 17 Apr 2018 16:53:55 +0000 (17:53 +0100)
src/libmime/lang_detection.c

index 363486209adb4cd9cc974bf6005b97c0a27944d4..84d23ac6302a096406902ea34fa34db0ba01512e 100644 (file)
@@ -172,10 +172,13 @@ rspamd_trigram_equal_func (gconstpointer v, gconstpointer v2)
        return memcmp (v, v2, 3 * sizeof (UChar)) == 0;
 }
 
-KHASH_INIT (rspamd_unigram_hash, UChar *, struct rspamd_ngramm_chain, true,
+KHASH_INIT (rspamd_unigram_hash, const UChar *, struct rspamd_ngramm_chain, true,
                rspamd_unigram_hash_func, rspamd_unigram_equal_func);
-KHASH_INIT (rspamd_trigram_hash, UChar *, struct rspamd_ngramm_chain, true,
+KHASH_INIT (rspamd_trigram_hash, const UChar *, struct rspamd_ngramm_chain, true,
                rspamd_trigram_hash_func, rspamd_trigram_equal_func);
+KHASH_INIT (rspamd_candidates_hash, const gchar *,
+               struct rspamd_lang_detector_res *, true,
+               rspamd_str_hash, rspamd_str_equal);
 
 struct rspamd_lang_detector {
        GPtrArray *languages;
@@ -880,9 +883,10 @@ static void
 rspamd_language_detector_process_ngramm_full (struct rspamd_task *task,
                struct rspamd_lang_detector *d,
                UChar *window, enum rspamd_language_gramm_type type,
-               GHashTable *candidates)
+               khash_t(rspamd_candidates_hash) *candidates)
 {
        guint i;
+       gint ret;
        struct rspamd_ngramm_chain *chain = NULL;
        struct rspamd_ngramm_elt *elt;
        struct rspamd_lang_detector_res *cand;
@@ -906,23 +910,33 @@ rspamd_language_detector_process_ngramm_full (struct rspamd_task *task,
 
        if (chain) {
                PTR_ARRAY_FOREACH (chain->languages, i, elt) {
-                       cand = g_hash_table_lookup (candidates, elt->elt->name);
                        prob = elt->prob;
 
                        if (prob < chain->mean) {
                                continue;
                        }
+
+                       k = kh_get (rspamd_candidates_hash, candidates, elt->elt->name);
+                       if (k != kh_end (candidates)) {
+                               cand = kh_value (candidates, k);
+                       }
+                       else {
+                               cand = NULL;
+                       }
+
 #ifdef NGRAMMS_DEBUG
                        msg_err ("gramm: %s, lang: %s, prob: %.3f", chain->utf,
                                        elt->elt->name, log2 (elt->prob));
 #endif
                        if (cand == NULL) {
-                               cand = g_malloc (sizeof (*cand));
+                               cand = rspamd_mempool_alloc (task->task_pool, sizeof (*cand));
                                cand->elt = elt->elt;
                                cand->lang = elt->elt->name;
                                cand->prob = prob;
 
-                               g_hash_table_insert (candidates, (gpointer)cand->lang, cand);
+                               k = kh_put (rspamd_candidates_hash, candidates, elt->elt->name,
+                                               &ret);
+                               kh_value (candidates, k) = cand;
                        } else {
                                /* Update guess */
                                cand->prob += prob;
@@ -934,7 +948,8 @@ rspamd_language_detector_process_ngramm_full (struct rspamd_task *task,
 static void
 rspamd_language_detector_detect_word (struct rspamd_task *task,
                struct rspamd_lang_detector *d,
-               rspamd_stat_token_t *tok, GHashTable *candidates,
+               rspamd_stat_token_t *tok,
+               khash_t(rspamd_candidates_hash) *candidates,
                enum rspamd_language_gramm_type type)
 {
        guint wlen;
@@ -963,58 +978,65 @@ static const gdouble cutoff_limit = -8.0;
  * Converts frequencies to log probabilities, filter those candidates who
  * has the lowest probabilities
  */
-static void
-rspamd_language_detector_filter_negligible (struct rspamd_task *task,
-               GHashTable *candidates)
-{
-       GHashTableIter it;
-       gpointer k, v;
-       struct rspamd_lang_detector_res *cand;
-       guint filtered = 0;
-       gdouble max_prob = -(G_MAXDOUBLE);
-
-       /* Normalize step */
-       g_hash_table_iter_init (&it, candidates);
-
-       while (g_hash_table_iter_next (&it, &k, &v)) {
-               cand = (struct rspamd_lang_detector_res *)v;
 
+static inline void
+rspamd_language_detector_filter_step1 (struct rspamd_task *task,
+               struct rspamd_lang_detector_res *cand,
+               gdouble *max_prob, guint *filtered)
+{
+       if (!isnan (cand->prob)) {
                if (cand->prob == 0) {
-                       msg_debug_lang_det ("exclude language %s: %.3f",
-                                       cand->lang, cand->prob, max_prob);
-                       g_hash_table_iter_remove (&it);
-                       filtered ++;
+                       cand->prob = NAN;
+                       msg_debug_lang_det (
+                                       "exclude language %s",
+                                       cand->lang);
+                       (*filtered)++;
                }
                else {
                        cand->prob = log2 (cand->prob);
                        if (cand->prob < cutoff_limit) {
-                               msg_debug_lang_det ("exclude language %s: %.3f, cutoff limit: %.3f",
+                               msg_debug_lang_det (
+                                               "exclude language %s: %.3f, cutoff limit: %.3f",
                                                cand->lang, cand->prob, cutoff_limit);
-                               g_hash_table_iter_remove (&it);
-                               filtered ++;
+                               cand->prob = NAN;
+                               (*filtered)++;
                        }
-                       else if (cand->prob > max_prob) {
-                               max_prob = cand->prob;
+                       else if (cand->prob > *max_prob) {
+                               *max_prob = cand->prob;
                        }
                }
        }
+}
 
-       g_hash_table_iter_init (&it, candidates);
-       /* Filter step */
-       while (g_hash_table_iter_next (&it, &k, &v)) {
-               cand = (struct rspamd_lang_detector_res *) v;
-
-               /*
+static inline void
+rspamd_language_detector_filter_step2 (struct rspamd_task *task,
+               struct rspamd_lang_detector_res *cand,
+               gdouble max_prob, guint *filtered)
+{
+       /*
                 * Probabilities are logarithmic, so if prob1 - prob2 > 4, it means that
                 * prob2 is 2^4 less than prob1
                 */
-               if (max_prob - cand->prob > 1) {
-                       msg_debug_lang_det ("exclude language %s: %.3f (%.3f max)",
-                                       cand->lang, cand->prob, max_prob);
-                       g_hash_table_iter_remove (&it);
-                       filtered ++;
-               }
+       if (!isnan (cand->prob) && max_prob - cand->prob > 1) {
+               msg_debug_lang_det ("exclude language %s: %.3f (%.3f max)",
+                               cand->lang, cand->prob, max_prob);
+               cand->prob = NAN;
+               (*filtered) ++;
        }
+}
+
+static void
+rspamd_language_detector_filter_negligible (struct rspamd_task *task,
+               khash_t(rspamd_candidates_hash) *candidates)
+{
+       struct rspamd_lang_detector_res *cand;
+       guint filtered = 0;
+       gdouble max_prob = -(G_MAXDOUBLE);
+
+       kh_foreach_value (candidates, cand,
+                       rspamd_language_detector_filter_step1 (task, cand, &max_prob, &filtered));
+       kh_foreach_value (candidates, cand,
+                       rspamd_language_detector_filter_step2 (task, cand, max_prob, &filtered));
 
        msg_debug_lang_det ("removed %d languages", filtered);
 }
@@ -1025,16 +1047,15 @@ rspamd_language_detector_is_unicode (struct rspamd_task *task,
                GArray *ucs_tokens,
                goffset *selected_words,
                gsize nparts,
-               GHashTable *candidates)
+               khash_t(rspamd_candidates_hash) *candidates)
 {
        guint i, j, total_found = 0, total_checked = 0;
        rspamd_stat_token_t *tok;
        UChar t;
-       gint uc_script;
+       gint uc_script, ret;
+       khint_t k;
        struct rspamd_language_elt *elt;
        struct rspamd_lang_detector_res *cand;
-       GHashTableIter it;
-       gpointer k, v;
 
        for (i = 0; i < nparts; i++) {
                tok = &g_array_index (ucs_tokens, rspamd_stat_token_t,
@@ -1047,15 +1068,23 @@ rspamd_language_detector_is_unicode (struct rspamd_task *task,
                        elt = g_hash_table_lookup (d->unicode_scripts, &uc_script);
 
                        if (elt) {
-                               cand = g_hash_table_lookup (candidates, elt->name);
+                               k = kh_get (rspamd_candidates_hash, candidates, elt->name);
+                               if (k != kh_end (candidates)) {
+                                       cand = kh_value (candidates, k);
+                               }
+                               else {
+                                       cand = NULL;
+                               }
 
                                if (cand == NULL) {
-                                       cand = g_malloc (sizeof (*cand));
+                                       cand = rspamd_mempool_alloc (task->task_pool,
+                                                       sizeof (*cand));
                                        cand->elt = elt;
                                        cand->lang = elt->name;
                                        cand->prob = 1;
 
-                                       g_hash_table_insert (candidates, (gpointer)cand->lang, cand);
+                                       k = kh_put (rspamd_candidates_hash, candidates, elt->name, &ret);
+                                       kh_value (candidates, k) = cand;
                                } else {
                                        /* Update guess */
                                        cand->prob ++;
@@ -1079,13 +1108,9 @@ rspamd_language_detector_is_unicode (struct rspamd_task *task,
        }
        else {
                /* Filter candidates */
-               g_hash_table_iter_init (&it, candidates);
-
-               while (g_hash_table_iter_next (&it, &k, &v)) {
-                       cand = (struct rspamd_lang_detector_res *)v;
-
+               kh_foreach_value (candidates, cand, {
                        cand->prob = cand->prob / total_checked;
-               }
+               });
        }
 
        return TRUE;
@@ -1096,7 +1121,7 @@ rspamd_language_detector_detect_type (struct rspamd_task *task,
                guint nwords,
                struct rspamd_lang_detector *d,
                GArray *ucs_tokens,
-               GHashTable *candidates,
+               khash_t(rspamd_candidates_hash) *candidates,
                enum rspamd_language_gramm_type type) {
        guint nparts = MIN (ucs_tokens->len, nwords);
        goffset *selected_words;
@@ -1108,7 +1133,7 @@ rspamd_language_detector_detect_type (struct rspamd_task *task,
        msg_debug_lang_det ("randomly selected %d words", nparts);
 
        /* Check unicode scripts */
-       if (g_hash_table_size (candidates) != 0 ||
+       if (kh_size (candidates) != 0 ||
                        !rspamd_language_detector_is_unicode (task, d, ucs_tokens,
                                        selected_words, nparts, candidates)) {
 
@@ -1155,9 +1180,10 @@ rspamd_language_detector_try_ngramm (struct rspamd_task *task,
                struct rspamd_lang_detector *d,
                GArray *ucs_tokens,
                enum rspamd_language_gramm_type type,
-               GHashTable *candidates)
+               khash_t(rspamd_candidates_hash) *candidates)
 {
-       guint cand_len;
+       guint cand_len = 0;
+       struct rspamd_lang_detector_res *cand;
 
        rspamd_language_detector_detect_type (task,
                        nwords,
@@ -1166,7 +1192,11 @@ rspamd_language_detector_try_ngramm (struct rspamd_task *task,
                        candidates,
                        type);
 
-       cand_len = g_hash_table_size (candidates);
+       kh_foreach_value (candidates, cand, {
+               if (!isnan (cand->prob)) {
+                       cand_len ++;
+               }
+       });
 
        if (cand_len == 0) {
                return rs_detect_none;
@@ -1262,11 +1292,10 @@ rspamd_language_detector_detect (struct rspamd_task *task,
                struct rspamd_lang_detector *d,
                GArray *ucs_tokens, gsize words_len)
 {
-       GHashTable *candidates;
+       khash_t(rspamd_candidates_hash) *candidates;
        GPtrArray *result;
-       GHashTableIter it;
-       gpointer k, v;
        gdouble mean, std, start_ticks, end_ticks;
+       guint cand_len;
        struct rspamd_lang_detector_res *cand;
        enum rspamd_language_detected_type r;
        struct rspamd_frequency_sort_cbdata cbd;
@@ -1278,8 +1307,8 @@ rspamd_language_detector_detect (struct rspamd_task *task,
        }
 
        start_ticks = rspamd_get_ticks (TRUE);
-       candidates = g_hash_table_new_full (rspamd_str_hash, rspamd_str_equal,
-                       NULL, g_free);
+       candidates = kh_init (rspamd_candidates_hash);
+       kh_resize (rspamd_candidates_hash, candidates, 32);
 
        r = rspamd_language_detector_try_ngramm (task, default_words, d,
                        ucs_tokens, rs_trigramm,
@@ -1293,35 +1322,37 @@ rspamd_language_detector_detect (struct rspamd_task *task,
        }
        else if (r == rs_detect_multiple) {
                /* Check our guess */
-               msg_debug_lang_det ("trigramms pass finished, found %d candidates",
-                               (gint)g_hash_table_size (candidates));
 
                mean = 0.0;
                std = 0.0;
-               g_hash_table_iter_init (&it, candidates);
+               cand_len = 0;
 
                /* Check distirbution */
-               while (g_hash_table_iter_next (&it, &k, &v)) {
-                       cand = (struct rspamd_lang_detector_res *) v;
-                       mean += cand->prob;
-               }
+               kh_foreach_value (candidates, cand, {
+                       if (!isnan (cand->prob)) {
+                               mean += cand->prob;
+                               cand_len ++;
+                       }
+               });
 
-               mean /= g_hash_table_size (candidates);
+               if (cand_len > 0) {
+                       mean /= cand_len;
 
-               g_hash_table_iter_init (&it, candidates);
-               while (g_hash_table_iter_next (&it, &k, &v)) {
-                       gdouble err;
-                       cand = (struct rspamd_lang_detector_res *) v;
-                       err = cand->prob - mean;
-                       std += fabs (err);
-               }
+                       kh_foreach_value (candidates, cand, {
+                               gdouble err;
+                               if (!isnan (cand->prob)) {
+                                       err = cand->prob - mean;
+                                       std += fabs (err);
+                               }
+                       });
 
-               std /= g_hash_table_size (candidates);
+                       std /= cand_len;
+               }
 
-               msg_debug_lang_det ("trigramms checked, %.3f mean, %.4f stddev",
-                               mean, std);
+               msg_debug_lang_det ("trigramms checked, %d candidates, %.3f mean, %.4f stddev",
+                               cand_len, mean, std);
 
-               if (std / fabs (mean) < 0.25) {
+               if (cand_len > 0 && std / fabs (mean) < 0.25) {
                        msg_debug_lang_det ("apply frequency heuristic sorting");
                        frequency_heuristic_applied = TRUE;
                        cbd.d = d;
@@ -1336,15 +1367,15 @@ rspamd_language_detector_detect (struct rspamd_task *task,
        }
 
        /* Now, convert hash to array and sort it */
-       result = g_ptr_array_new_full (g_hash_table_size (candidates), g_free);
-       g_hash_table_iter_init (&it, candidates);
+       result = g_ptr_array_sized_new (kh_size (candidates));
 
-       while (g_hash_table_iter_next (&it, &k, &v)) {
-               cand = (struct rspamd_lang_detector_res *) v;
-               msg_debug_lang_det ("final probability %s -> %.2f", cand->lang, cand->prob);
-               g_ptr_array_add (result, cand);
-               g_hash_table_iter_steal (&it);
-       }
+       kh_foreach_value (candidates, cand, {
+               if (!isnan (cand->prob)) {
+                       msg_debug_lang_det ("final probability %s -> %.2f", cand->lang,
+                                       cand->prob);
+                       g_ptr_array_add (result, cand);
+               }
+       });
 
        if (frequency_heuristic_applied) {
                g_ptr_array_sort_with_data (result,
@@ -1353,8 +1384,8 @@ rspamd_language_detector_detect (struct rspamd_task *task,
        else {
                g_ptr_array_sort (result, rspamd_language_detector_cmp);
        }
-       g_hash_table_unref (candidates);
 
+       kh_destroy (rspamd_candidates_hash, candidates);
 
        if (result->len > 0 && !frequency_heuristic_applied) {
                cand = g_ptr_array_index (result, 0);