]> source.dussan.org Git - rspamd.git/commitdiff
[Feature] Implement symbols augmentations
authorVsevolod Stakhov <vsevolod@rspamd.com>
Sat, 14 May 2022 12:05:14 +0000 (13:05 +0100)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Sat, 14 May 2022 12:05:14 +0000 (13:05 +0100)
src/libserver/symcache/symcache_impl.cxx
src/libserver/symcache/symcache_internal.hxx
src/libserver/symcache/symcache_item.cxx
src/libserver/symcache/symcache_item.hxx

index f76188c9faad6d642c4f94fd30bf7fe7945a99b9..e557f62124b7a48264ed287a6085503cfd5aafd4 100644 (file)
@@ -425,6 +425,7 @@ auto symcache::add_dependency(int id_from, std::string_view to, int virtual_id_f
 
 auto symcache::resort() -> void
 {
+       auto log_func = RSPAMD_LOG_FUNC;
        auto ord = std::make_shared<order_generation>(filters.size() +
                        prefilters.size() +
                        composites.size() +
@@ -436,6 +437,7 @@ auto symcache::resort() -> void
        for (auto &it: filters) {
                if (it) {
                        total_hits += it->st->total_hits;
+                       /* Unmask topological order */
                        it->order = 0;
                        ord->d.emplace_back(it);
                }
@@ -484,16 +486,16 @@ auto symcache::resort() -> void
                        }
                }
                else if (tsort_is_marked(it, tsort_mask::TEMP)) {
-                       msg_err_cache("cyclic dependencies found when checking '%s'!",
+                       msg_err_cache_lambda("cyclic dependencies found when checking '%s'!",
                                        it->symbol.c_str());
                        return;
                }
 
                tsort_mark(it, tsort_mask::TEMP);
-               msg_debug_cache("visiting node: %s (%d)", it->symbol.c_str(), cur_order);
+               msg_debug_cache_lambda("visiting node: %s (%d)", it->symbol.c_str(), cur_order);
 
                for (const auto &dep: it->deps) {
-                       msg_debug_cache ("visiting dep: %s (%d)", dep.item->symbol.c_str(), cur_order + 1);
+                       msg_debug_cache_lambda("visiting dep: %s (%d)", dep.item->symbol.c_str(), cur_order + 1);
                        rec(dep.item.get(), cur_order + 1, rec);
                }
 
@@ -528,16 +530,26 @@ auto symcache::resort() -> void
                if (o1 == o2) {
                        /* No topological order */
                        if (it1->priority == it2->priority) {
-                               auto avg_freq = ((double) total_hits / used_items);
-                               auto avg_weight = (total_weight / used_items);
-                               auto f1 = (double) it1->st->total_hits / avg_freq;
-                               auto f2 = (double) it2->st->total_hits / avg_freq;
-                               auto weight1 = std::fabs(it1->st->weight) / avg_weight;
-                               auto weight2 = std::fabs(it2->st->weight) / avg_weight;
-                               auto t1 = it1->st->avg_time;
-                               auto t2 = it2->st->avg_time;
-                               w1 = score_functor(weight1, f1, t1);
-                               w2 = score_functor(weight2, f2, t2);
+
+                               auto augmentations1 = it1->get_augmentation_weight();
+                               auto augmentations2 = it2->get_augmentation_weight();
+
+                               if (augmentations1 == augmentations2) {
+                                       auto avg_freq = ((double) total_hits / used_items);
+                                       auto avg_weight = (total_weight / used_items);
+                                       auto f1 = (double) it1->st->total_hits / avg_freq;
+                                       auto f2 = (double) it2->st->total_hits / avg_freq;
+                                       auto weight1 = std::fabs(it1->st->weight) / avg_weight;
+                                       auto weight2 = std::fabs(it2->st->weight) / avg_weight;
+                                       auto t1 = it1->st->avg_time;
+                                       auto t2 = it2->st->avg_time;
+                                       w1 = score_functor(weight1, f1, t1);
+                                       w2 = score_functor(weight2, f2, t2);
+                               }
+                               else {
+                                       w1 = augmentations1;
+                                       w2 = augmentations2;
+                               }
                        }
                        else {
                                /* Strict sorting */
index 84ae8de7f05271ebccbc5af87dc8fa5a77f06584..6a96eb5474327f61fcc425bf5033777cbf55829c 100644 (file)
         "symcache", log_tag(), \
         RSPAMD_LOG_FUNC, \
         __VA_ARGS__)
+#define msg_err_cache_lambda(...) rspamd_default_log_function (G_LOG_LEVEL_CRITICAL, \
+        "symcache", log_tag(), \
+        log_func, \
+        __VA_ARGS__)
 #define msg_err_cache_task(...) rspamd_default_log_function (G_LOG_LEVEL_CRITICAL, \
         "symcache", task->task_pool->tag.uid, \
         RSPAMD_LOG_FUNC, \
         ::rspamd::symcache::rspamd_symcache_log_id, "symcache", log_tag(), \
         RSPAMD_LOG_FUNC, \
         __VA_ARGS__)
+#define msg_debug_cache_lambda(...)  rspamd_conditional_debug_fast (NULL, NULL, \
+        ::rspamd::symcache::rspamd_symcache_log_id, "symcache", log_tag(), \
+        log_func, \
+        __VA_ARGS__)
 #define msg_debug_cache_task(...)  rspamd_conditional_debug_fast (NULL, NULL, \
         ::rspamd::symcache::rspamd_symcache_log_id, "symcache", task->task_pool->tag.uid, \
         RSPAMD_LOG_FUNC, \
index 70c1921bbc3dc581303c28402480fde31608c8b0..091e6cbf9967bf267aa1c0534aa7644fcdd89a1d 100644 (file)
 #include "symcache_item.hxx"
 #include "fmt/core.h"
 #include "libserver/task.h"
+#include "libutil/cxx/util.hxx"
+#include <numeric>
+#include <functional>
 
 namespace rspamd::symcache {
 
+/* A list of internal augmentations that are known to Rspamd with their weight */
+static const auto known_augmentations =
+               robin_hood::unordered_flat_map<std::string, int, rspamd::smart_str_hash, rspamd::smart_str_equal>{
+                               {"passthrough", 10},
+                               {"single_network", 1},
+                               {"no_network", 0},
+                               {"many_network", 1},
+                               {"important", 5},
+               };
+
 auto cache_item::get_parent(const symcache &cache) const -> const cache_item *
 {
        if (is_virtual()) {
@@ -347,6 +360,30 @@ auto cache_item::is_allowed(struct rspamd_task *task, bool exec_only) const -> b
        return true;
 }
 
+auto
+cache_item::add_augmentation(const symcache &cache, std::string_view augmentation) -> bool {
+       auto log_tag = [&]() { return cache.log_tag(); };
+
+       if (augmentations.contains(augmentation)) {
+               msg_warn_cache("duplicate augmentation: %s", augmentation.data());
+       }
+
+       augmentations.insert(std::string(augmentation));
+
+       return known_augmentations.contains(augmentation);
+}
+
+auto
+cache_item::get_augmentation_weight() const -> int
+{
+       return std::accumulate(std::begin(augmentations), std::end(augmentations),
+                                                 0, [](int acc, const std::string &augmentation) {
+               int zero = 0; /* C++ limitation of the cref */
+               return acc + rspamd::find_map(known_augmentations, augmentation).value_or(std::cref<int>(zero));
+       });
+}
+
+
 auto virtual_item::get_parent(const symcache &cache) const -> const cache_item *
 {
        if (parent) {
index 40e2d67c1601853c3538903a5b3e05da4d10b7ff..70203770a41b34442475ae1ca4eaf039158ffcac 100644 (file)
@@ -32,6 +32,7 @@
 #include "contrib/expected/expected.hpp"
 #include "contrib/libev/ev.h"
 #include "symcache_runtime.hxx"
+#include "libutil/cxx/hash_util.hxx"
 
 namespace rspamd::symcache {
 
@@ -193,6 +194,9 @@ struct cache_item : std::enable_shared_from_this<cache_item> {
        id_list exec_only_ids{};
        id_list forbidden_ids{};
 
+       /* Set of augmentations */
+       robin_hood::unordered_flat_set<std::string, rspamd::smart_str_hash, rspamd::smart_str_equal> augmentations;
+
        /* Dependencies */
        std::vector<cache_dependency> deps;
        /* Reverse dependencies */
@@ -378,6 +382,19 @@ public:
                }
        }
 
+       /**
+        * Add an augmentation to the item, returns `true` if augmentation is known and unique, false otherwise
+        * @param augmentation
+        * @return
+        */
+       auto add_augmentation(const symcache &cache, std::string_view augmentation) -> bool;
+
+       /**
+        * Return sum weight of all known augmentations
+        * @return
+        */
+       auto get_augmentation_weight() const -> int;
+
 private:
        /**
         * Constructor for a normal symbols with callback