]> source.dussan.org Git - rspamd.git/commitdiff
[Minor] Follow-up for static disabling of the symbols
authorVsevolod Stakhov <vsevolod@rspamd.com>
Sun, 17 Jul 2022 16:57:55 +0000 (17:57 +0100)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Sun, 17 Jul 2022 16:57:55 +0000 (17:57 +0100)
src/libserver/symcache/symcache_impl.cxx
src/libserver/symcache/symcache_item.hxx

index 77f7ea19ff7e029e850642ebb52f75eb19a7cd4e..0ec6d11ec379c983adfd828728b380f399522446 100644 (file)
@@ -40,6 +40,63 @@ auto symcache::init() -> bool
                res = load_items();
        }
 
+       ankerl::unordered_dense::set<int> disabled_ids;
+       /* Process enabled/disabled symbols */
+       for (const auto &it: items_by_id) {
+               if (disabled_symbols) {
+                       /*
+                        * Due to the ability to add patterns, this is now O(N^2), but it is done
+                        * once on configuration and the amount of static patterns is usually low
+                        * The possible optimization is to store non patterns in a different set to check it
+                        * quickly. However, it is unlikely that this would be used to something really heavy.
+                        */
+                       for (const auto &disable_pat: *disabled_symbols) {
+                               if (disable_pat.matches(it->get_name())) {
+                                       msg_debug_cache("symbol %s matches %*s disable pattern", it->get_name().c_str(),
+                                               (int)disable_pat.to_string_view().size(), disable_pat.to_string_view().data());
+                                       auto need_disable = true;
+
+                                       if (enabled_symbols) {
+                                               for (const auto &enable_pat: *enabled_symbols) {
+                                                       if (enable_pat.matches(it->get_name())) {
+                                                               msg_debug_cache("symbol %s matches %*s enable pattern; skip disabling", it->get_name().c_str(),
+                                                                               (int)enable_pat.to_string_view().size(), enable_pat.to_string_view().data());
+                                                               need_disable = false;
+                                                               break;
+                                                       }
+                                               }
+                                       }
+
+                                       if (need_disable) {
+                                               disabled_ids.insert(it->id);
+
+                                               if (it->is_virtual()) {
+                                                       auto real_elt = it->get_parent(*this);
+
+                                                       if (real_elt) {
+                                                               disabled_ids.insert(real_elt->id);
+
+                                                               for (const auto &cld : real_elt->get_children().value().get()) {
+                                                                       msg_debug_cache("symbol %s is a virtual sibling of the disabled symbol %s",
+                                                                                       cld->get_name().c_str(), it->get_name().c_str());
+                                                                       disabled_ids.insert(cld->id);
+                                                               }
+                                                       }
+                                               }
+                                               else {
+                                                       /* Also disable all virtual children of this element */
+                                                       for (const auto &cld : it->get_children().value().get()) {
+                                                               msg_debug_cache("symbol %s is a virtual child of the disabled symbol %s",
+                                                                               cld->get_name().c_str(), it->get_name().c_str());
+                                                               disabled_ids.insert(cld->id);
+                                                       }
+                                               }
+                                       }
+                               }
+                       }
+               }
+       }
+
        /* Deal with the delayed dependencies */
        msg_debug_cache("resolving delayed dependencies: %d in list", (int)delayed_deps->size());
        for (const auto &delayed_dep: *delayed_deps) {
@@ -53,18 +110,52 @@ auto symcache::init() -> bool
                                        delayed_dep.to.data(), delayed_dep.from.data());
                }
                else {
-                       msg_debug_cache("delayed between %s(%d:%d) -> %s",
-                                       delayed_dep.from.data(),
-                                       real_item->id, virt_item->id,
-                                       delayed_dep.to.data());
-                       add_dependency(real_item->id, delayed_dep.to, virt_item != real_item ?
-                                                                                                                 virt_item->id : -1);
+
+                       if (!disabled_ids.contains(real_item->id)) {
+                               msg_debug_cache("delayed between %s(%d:%d) -> %s",
+                                               delayed_dep.from.data(),
+                                               real_item->id, virt_item->id,
+                                               delayed_dep.to.data());
+                               add_dependency(real_item->id, delayed_dep.to,
+                                               virt_item != real_item ? virt_item->id : -1);
+                       }
+                       else {
+                               msg_debug_cache("no delayed between %s(%d:%d) -> %s; %s is disabled",
+                                               delayed_dep.from.data(),
+                                               real_item->id, virt_item->id,
+                                               delayed_dep.to.data(),
+                                               delayed_dep.from.data());
+                       }
                }
        }
 
        /* Remove delayed dependencies, as they are no longer needed at this point */
        delayed_deps.reset();
 
+       /* Physically remove ids that are disabled statically */
+       for (auto id_to_disable : disabled_ids) {
+               /*
+                * This erasure is inefficient, we can swap the last element with the removed id
+                * But in this way, our ids are still sorted by addition
+                */
+
+               /* Preserve refcount here */
+               auto deleted_element_refcount = items_by_id[id_to_disable];
+               items_by_id.erase(std::begin(items_by_id) + id_to_disable);
+               items_by_symbol.erase(deleted_element_refcount->get_name());
+
+               auto &additional_vec = get_item_specific_vector(*deleted_element_refcount);
+               std::erase_if(additional_vec, [id_to_disable](const cache_item_ptr &elt) {
+                       return elt->id == id_to_disable;
+               });
+
+               /* Refcount is dropped, so the symbol should be freed, ensure that nothing else owns this symbol */
+               g_assert(deleted_element_refcount.use_count() == 1);
+       }
+
+       /* Remove no longer used stuff */
+       enabled_symbols.reset();
+       disabled_symbols.reset();
 
        /* Deal with the delayed conditions */
        msg_debug_cache("resolving delayed conditions: %d in list", (int)delayed_conditions->size());
index 50e3212654030fe5d10e39bce6df376bc6395fea..f0bedea5c1f7bd1143a88fe5f821e35898c69b69 100644 (file)
@@ -26,6 +26,7 @@
 #include <memory>
 #include <variant>
 #include <algorithm>
+#include <optional>
 
 #include "rspamd_symcache.h"
 #include "symcache_id_list.hxx"
@@ -139,6 +140,10 @@ public:
        auto add_child(const cache_item_ptr &ptr) -> void {
                virtual_children.push_back(ptr);
        }
+
+       auto get_childen() const -> const std::vector<cache_item_ptr>& {
+               return virtual_children;
+       }
 };
 
 class virtual_item {
@@ -415,6 +420,21 @@ public:
                }
        }
 
+       /**
+        * Returns virtual children for a normal item
+        * @param ptr
+        * @return
+        */
+       auto get_children() const -> std::optional<std::reference_wrapper<const std::vector<cache_item_ptr>>> {
+               if (std::holds_alternative<normal_item>(specific)) {
+                       const auto &filter_data = std::get<normal_item>(specific);
+
+                       return std::cref(filter_data.get_childen());
+               }
+
+               return std::nullopt;
+       }
+
 private:
        /**
         * Constructor for a normal symbols with callback