]> source.dussan.org Git - rspamd.git/commitdiff
[Feature] Further improvements of language detector by using khash
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 17 Apr 2018 15:54:53 +0000 (16:54 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 17 Apr 2018 15:54:53 +0000 (16:54 +0100)
src/libmime/lang_detection.c

index 0d9b40ec1eb836a6e3367b8c9f3394d588c07e02..363486209adb4cd9cc974bf6005b97c0a27944d4 100644 (file)
@@ -18,6 +18,7 @@
 #include "libutil/logger.h"
 #include "libcryptobox/cryptobox.h"
 #include "ucl.h"
+#include "khash.h"
 #include <glob.h>
 #include <unicode/utf8.h>
 #include <unicode/ucnv.h>
@@ -108,16 +109,6 @@ struct rspamd_ngramm_chain {
        gchar *utf;
 };
 
-struct rspamd_lang_detector {
-       GPtrArray *languages;
-       GHashTable *unigramms; /* unigramms frequencies */
-       GHashTable *trigramms; /* trigramms frequencies */
-       GHashTable *unicode_scripts; /* indexed by unicode script */
-       UConverter *uchar_converter;
-       gsize short_text_limit;
-       gsize total_occurencies; /* number of all languages found */
-};
-
 #define msg_debug_lang_det(...)  rspamd_conditional_debug_fast (NULL, NULL, \
         rspamd_langdet_log_id, "langdet", task->task_pool->tag.uid, \
         G_STRFUNC, \
@@ -158,29 +149,44 @@ rspamd_language_search_str (const gchar *key, const gchar *elts[], size_t nelts)
 }
 
 static guint
-rspamd_unigram_hash (gconstpointer key)
+rspamd_unigram_hash_func (gconstpointer key)
 {
        return rspamd_cryptobox_fast_hash (key, sizeof (UChar), rspamd_hash_seed ());
 }
 
 static gboolean
-rspamd_unigram_equal (gconstpointer v, gconstpointer v2)
+rspamd_unigram_equal_func (gconstpointer v, gconstpointer v2)
 {
        return memcmp (v, v2, sizeof (UChar)) == 0;
 }
 
 static guint
-rspamd_trigram_hash (gconstpointer key)
+rspamd_trigram_hash_func (gconstpointer key)
 {
        return rspamd_cryptobox_fast_hash (key, 3 * sizeof (UChar), rspamd_hash_seed ());
 }
 
 static gboolean
-rspamd_trigram_equal (gconstpointer v, gconstpointer v2)
+rspamd_trigram_equal_func (gconstpointer v, gconstpointer v2)
 {
        return memcmp (v, v2, 3 * sizeof (UChar)) == 0;
 }
 
+KHASH_INIT (rspamd_unigram_hash, UChar *, struct rspamd_ngramm_chain, true,
+               rspamd_unigram_hash_func, rspamd_unigram_equal_func);
+KHASH_INIT (rspamd_trigram_hash, UChar *, struct rspamd_ngramm_chain, true,
+               rspamd_trigram_hash_func, rspamd_trigram_equal_func);
+
+struct rspamd_lang_detector {
+       GPtrArray *languages;
+       khash_t(rspamd_unigram_hash) *unigramms; /* unigramms frequencies */
+       khash_t(rspamd_trigram_hash) *trigramms; /* trigramms frequencies */
+       GHashTable *unicode_scripts; /* indexed by unicode script */
+       UConverter *uchar_converter;
+       gsize short_text_limit;
+       gsize total_occurencies; /* number of all languages found */
+};
+
 static void
 rspamd_language_detector_ucs_lowercase (UChar *s, gsize len)
 {
@@ -220,40 +226,54 @@ rspamd_language_detector_init_ngramm (struct rspamd_config *cfg,
                struct rspamd_language_elt *lelt,
                struct rspamd_language_ucs_elt *ucs, guint len, guint freq, guint total)
 {
-       GHashTable *target;
-       struct rspamd_ngramm_chain *chain;
+       struct rspamd_ngramm_chain *chain = NULL, st_chain;
        struct rspamd_ngramm_elt *elt;
+       khiter_t k;
        guint i;
        gboolean found;
 
        switch (len) {
        case 1:
-               target = d->unigramms;
+               k = kh_get (rspamd_unigram_hash, d->unigramms, ucs->s);
+               if (k != kh_end (d->unigramms)) {
+                       chain = &kh_value (d->unigramms, k);
+               }
                break;
        case 2:
                g_assert_not_reached ();
                break;
        case 3:
-               target = d->trigramms;
+               k = kh_get (rspamd_trigram_hash, d->trigramms, ucs->s);
+               if (k != kh_end (d->trigramms)) {
+                       chain = &kh_value (d->trigramms, k);
+               }
                break;
        default:
                g_assert_not_reached ();
                break;
        }
 
-       chain = g_hash_table_lookup (target, ucs->s);
-
        if (chain == NULL) {
                /* New element */
-               chain = rspamd_mempool_alloc0 (cfg->cfg_pool, sizeof (*chain));
+               chain = &st_chain;
+               memset (chain, 0, sizeof (st_chain));
                chain->languages = g_ptr_array_sized_new (32);
+               rspamd_mempool_add_destructor (cfg->cfg_pool, rspamd_ptr_array_free_hard,
+                               chain->languages);
                chain->utf = rspamd_mempool_strdup (cfg->cfg_pool, ucs->utf);
                elt = rspamd_mempool_alloc (cfg->cfg_pool, sizeof (*elt));
                elt->elt = lelt;
                elt->prob = ((gdouble)freq) / ((gdouble)total);
                g_ptr_array_add (chain->languages, elt);
 
-               g_hash_table_insert (target, ucs->s, chain);
+               if (len == 1) {
+                       k = kh_put (rspamd_unigram_hash, d->unigramms, ucs->s, &i);
+                       kh_value (d->unigramms, k) = *chain;
+               }
+               else {
+                       k = kh_put (rspamd_trigram_hash, d->trigramms, ucs->s, &i);
+                       kh_value (d->trigramms, k) = *chain;
+               }
        }
        else {
                /* Check sanity */
@@ -609,9 +629,7 @@ rspamd_language_detector_init (struct rspamd_config *cfg)
        size_t i, short_text_limit = default_short_text_limit;
        UErrorCode uc_err = U_ZERO_ERROR;
        GString *languages_pattern;
-       GHashTableIter it;
-       gpointer k, v;
-       struct rspamd_ngramm_chain *chain;
+       struct rspamd_ngramm_chain *chain, schain;
        gchar *fname;
        struct rspamd_lang_detector *ret = NULL;
 
@@ -648,10 +666,8 @@ rspamd_language_detector_init (struct rspamd_config *cfg)
        ret->uchar_converter = ucnv_open ("UTF-8", &uc_err);
        ret->short_text_limit = short_text_limit;
        /* Map from ngramm in ucs32 to GPtrArray of rspamd_language_elt */
-       ret->unigramms = g_hash_table_new_full (rspamd_unigram_hash,
-                       rspamd_unigram_equal, NULL, rspamd_ptr_array_free_hard);
-       ret->trigramms = g_hash_table_new_full (rspamd_trigram_hash,
-                       rspamd_trigram_equal, NULL, rspamd_ptr_array_free_hard);
+       ret->unigramms = kh_init (rspamd_unigram_hash);
+       ret->trigramms = kh_init (rspamd_trigram_hash);
        ret->unicode_scripts = g_hash_table_new (g_int_hash, g_int_equal);
 
        g_assert (uc_err == U_ZERO_ERROR);
@@ -671,20 +687,18 @@ rspamd_language_detector_init (struct rspamd_config *cfg)
                g_free (fname);
        }
 
-       g_hash_table_iter_init (&it, ret->trigramms);
-
-       while (g_hash_table_iter_next (&it, &k, &v)) {
-               chain = (struct rspamd_ngramm_chain *)v;
+       kh_foreach_value (ret->trigramms, schain, {
+               chain = &schain;
                rspamd_language_detector_process_chain (cfg, chain);
-       }
+       });
 
        msg_info_config ("loaded %d languages, %d unicode only languages, "
                        "%d unigramms, "
                        "%d trigramms",
                        (gint)ret->languages->len,
                        (gint)g_hash_table_size (ret->unicode_scripts),
-                       (gint)g_hash_table_size (ret->unigramms),
-                       (gint)g_hash_table_size (ret->trigramms));
+                       (gint)kh_size (ret->unigramms),
+                       (gint)kh_size (ret->trigramms));
 end:
        if (gl.gl_pathc > 0) {
                globfree (&gl);
@@ -869,24 +883,27 @@ rspamd_language_detector_process_ngramm_full (struct rspamd_task *task,
                GHashTable *candidates)
 {
        guint i;
-       struct rspamd_ngramm_chain *chain;
+       struct rspamd_ngramm_chain *chain = NULL;
        struct rspamd_ngramm_elt *elt;
        struct rspamd_lang_detector_res *cand;
-       GHashTable *ngramms;
+       khiter_t k;
        gdouble prob;
 
        switch (type) {
        case rs_unigramm:
-               ngramms = d->unigramms;
+               k = kh_get (rspamd_unigram_hash, d->unigramms, window);
+               if (k != kh_end (d->unigramms)) {
+                       chain = &kh_value (d->unigramms, k);
+               }
                break;
        case rs_trigramm:
-               ngramms = d->trigramms;
+               k = kh_get (rspamd_trigram_hash, d->trigramms, window);
+               if (k != kh_end (d->trigramms)) {
+                       chain = &kh_value (d->trigramms, k);
+               }
                break;
        }
 
-
-       chain = g_hash_table_lookup (ngramms, window);
-
        if (chain) {
                PTR_ARRAY_FOREACH (chain->languages, i, elt) {
                        cand = g_hash_table_lookup (candidates, elt->elt->name);