]> source.dussan.org Git - rspamd.git/commitdiff
[Fix] Really fix the language detector statistical heuristic
authorVsevolod Stakhov <vsevolod@rspamd.com>
Thu, 18 Jan 2024 14:13:41 +0000 (14:13 +0000)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Thu, 18 Jan 2024 14:13:41 +0000 (14:13 +0000)
src/libmime/lang_detection.c

index c44aa2b049363bbeacb7ab4bbbdeb391c3091b74..383005ad10513d497459bb4d232346b523811f57 100644 (file)
@@ -1335,14 +1335,15 @@ rspamd_language_detector_cmp_heuristic(gconstpointer a, gconstpointer b,
                                                                           gpointer ud)
 {
        struct rspamd_frequency_sort_cbdata *cbd = ud;
-       const struct rspamd_lang_detector_res
-               *canda = *(const struct rspamd_lang_detector_res **) a,
-               *candb = *(const struct rspamd_lang_detector_res **) b;
+       struct rspamd_lang_detector_res
+               *canda = *(struct rspamd_lang_detector_res **) a,
+               *candb = *(struct rspamd_lang_detector_res **) b;
        gdouble adj;
        gdouble proba_adjusted, probb_adjusted, freqa, freqb;
 
        if (cbd->d->total_occurrences == 0) {
-               return 0;
+               /* Not enough data, compare directly */
+               return rspamd_language_detector_cmp(a, b);
        }
 
        freqa = ((gdouble) canda->elt->occurrences) /
@@ -1387,6 +1388,10 @@ rspamd_language_detector_cmp_heuristic(gconstpointer a, gconstpointer b,
                probb_adjusted += cbd->std * adj;
        }
 
+       /* Hack: adjust probability directly */
+       canda->prob = proba_adjusted;
+       candb->prob = probb_adjusted;
+
        if (proba_adjusted > probb_adjusted) {
                return -1;
        }
@@ -1998,7 +2003,7 @@ rspamd_language_detector_detect(struct rspamd_task *task,
 
                        kh_foreach_value(candidates, cand, {
                                if (!isnan(cand->prob)) {
-                                       msg_debug_lang_det("final probability %s -> %.2f", cand->lang,
+                                       msg_debug_lang_det("pre-sorting probability %s -> %.2f", cand->lang,
                                                                           cand->prob);
                                        g_ptr_array_add(result, cand);
                                }
@@ -2006,18 +2011,18 @@ rspamd_language_detector_detect(struct rspamd_task *task,
 
                        if (frequency_heuristic_applied) {
                                g_ptr_array_sort_with_data(result,
-                                                                                  rspamd_language_detector_cmp_heuristic, (gpointer) &cbd);
+                                                                                  rspamd_language_detector_cmp_heuristic,
+                                                                                  (gpointer) &cbd);
                        }
                        else {
                                g_ptr_array_sort(result, rspamd_language_detector_cmp);
                        }
 
-                       if (result->len > 0 && !frequency_heuristic_applied) {
-                               cand = g_ptr_array_index(result, 0);
-                               if (cand->elt) {
-                                       cand->elt->occurrences++;
-                               }
-                               d->total_occurrences++;
+                       int i;
+                       PTR_ARRAY_FOREACH(result, i, cand)
+                       {
+                               msg_debug_lang_det("final probability %s -> %.2f", cand->lang,
+                                                                  cand->prob);
                        }
 
                        if (part->languages != NULL) {
@@ -2035,6 +2040,15 @@ rspamd_language_detector_detect(struct rspamd_task *task,
                kh_destroy(rspamd_candidates_hash, candidates);
        }
 
+       /* Update internal stat */
+       if (part->languages != NULL && part->languages->len > 0 && !frequency_heuristic_applied) {
+               cand = g_ptr_array_index(part->languages, 0);
+               if (cand->elt) {
+                       cand->elt->occurrences++;
+               }
+               d->total_occurrences++;
+       }
+
        end_ticks = rspamd_get_ticks(TRUE);
        msg_debug_lang_det("detected languages in %.0f ticks",
                                           (end_ticks - start_ticks));