]> source.dussan.org Git - rspamd.git/commitdiff
[Feature] Add ability to statically maintain disabled/enabled patterns
authorVsevolod Stakhov <vsevolod@rspamd.com>
Sun, 17 Jul 2022 15:58:03 +0000 (16:58 +0100)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Sun, 17 Jul 2022 15:58:03 +0000 (16:58 +0100)
src/libserver/symcache/symcache_internal.hxx

index 063777a715b654d27fdd1abf4ca60e483db4f113..ea583a19978ab57a23063c4af96e4615941e4981 100644 (file)
@@ -143,6 +143,97 @@ public:
                sym(_sym), cbref(_cbref), L(_L) {}
 };
 
+class delayed_symbol_elt {
+private:
+       std::variant<std::string, rspamd_regexp_t *> content;
+public:
+       /* Disable copy */
+       delayed_symbol_elt() = delete;
+       delayed_symbol_elt(const delayed_symbol_elt &) = delete;
+       delayed_symbol_elt &operator=(const delayed_symbol_elt &) = delete;
+       /* Enable move */
+       delayed_symbol_elt(delayed_symbol_elt &&other) noexcept = default;
+       delayed_symbol_elt &operator=(delayed_symbol_elt &&other) noexcept = default;
+
+       explicit delayed_symbol_elt(std::string_view elt) noexcept {
+               if (!elt.empty() && elt[0] == '/') {
+                       /* Possibly regexp */
+                       auto *re = rspamd_regexp_new_len(elt.data(), elt.size(), nullptr, nullptr);
+
+                       if (re != nullptr) {
+                               std::get<rspamd_regexp_t *>(content) = re;
+                       }
+                       else {
+                               std::get<std::string>(content) = elt;
+                       }
+               }
+               else {
+                       std::get<std::string>(content) = elt;
+               }
+       }
+
+       ~delayed_symbol_elt() {
+               if (std::holds_alternative<rspamd_regexp_t *>(content)) {
+                       rspamd_regexp_unref(std::get<rspamd_regexp_t *>(content));
+               }
+       }
+
+       auto matches(std::string_view what) const -> bool {
+               return std::visit([&](auto &elt) {
+                       using T = typeof(elt);
+                       if constexpr (std::is_same_v<T, rspamd_regexp_t *>) {
+                               if (rspamd_regexp_match(elt, what.data(), what.size(), false)) {
+                                       return true;
+                               }
+                       }
+                       else if constexpr (std::is_same_v<T, std::string>) {
+                               return elt == what;
+                       }
+
+                       return false;
+               },
+               content);
+       }
+
+       auto to_string_view() const -> std::string_view {
+               return std::visit([&](auto &elt) {
+                       using T = typeof(elt);
+                       if constexpr (std::is_same_v<T, rspamd_regexp_t *>) {
+                               return std::string_view{rspamd_regexp_get_pattern(elt)};
+                       }
+                       else if constexpr (std::is_same_v<T, std::string>) {
+                               return std::string_view{elt};
+                       }
+
+                       return std::string_view{};
+               },
+               content);
+       }
+};
+
+struct delayed_symbol_elt_equal {
+       using is_transparent = void;
+       auto operator()(const delayed_symbol_elt &a, const delayed_symbol_elt &b) const {
+               return a.to_string_view() == b.to_string_view();
+       }
+       auto operator()(const delayed_symbol_elt &a, const std::string_view &b) const {
+               return a.to_string_view() == b;
+       }
+       auto operator()(const std::string_view &a, const delayed_symbol_elt &b) const {
+               return a == b.to_string_view();
+       }
+};
+
+struct delayed_symbol_elt_hash {
+       using is_transparent = void;
+       auto operator()(const delayed_symbol_elt &a) const {
+               return ankerl::unordered_dense::hash<std::string_view>()(a.to_string_view());
+       }
+       auto operator()(const std::string_view &a) const {
+               return ankerl::unordered_dense::hash<std::string_view>()(a);
+       }
+};
+
 class symcache {
 private:
        using items_ptr_vec = std::vector<cache_item_ptr>;
@@ -167,6 +258,11 @@ private:
        /* These are stored within pointer to clean up after init */
        std::unique_ptr<std::vector<delayed_cache_dependency>> delayed_deps;
        std::unique_ptr<std::vector<delayed_cache_condition>> delayed_conditions;
+       /* Delayed statically enabled or disabled symbols */
+       using delayed_symbol_names = ankerl::unordered_dense::set<delayed_symbol_elt,
+               delayed_symbol_elt_hash, delayed_symbol_elt_equal>;
+       std::unique_ptr<delayed_symbol_names> disabled_symbols;
+       std::unique_ptr<delayed_symbol_names> enabled_symbols;
 
        rspamd_mempool_t *static_pool;
        std::uint64_t cksum;
@@ -261,6 +357,44 @@ public:
                delayed_deps->emplace_back(from, to);
        }
 
+       /**
+        * Adds a symbol to the list of the disabled symbols
+        * @param sym
+        * @return
+        */
+       auto disable_symbol_delayed(std::string_view sym) -> bool {
+               if (!disabled_symbols) {
+                       disabled_symbols = std::make_unique<delayed_symbol_names>();
+               }
+
+               if (!disabled_symbols->contains(sym)) {
+                       disabled_symbols->emplace(sym);
+
+                       return true;
+               }
+
+               return false;
+       }
+
+       /**
+        * Adds a symbol to the list of the enabled symbols
+        * @param sym
+        * @return
+        */
+       auto enable_symbol_delayed(std::string_view sym) -> bool {
+               if (!enabled_symbols) {
+                       enabled_symbols = std::make_unique<delayed_symbol_names>();
+               }
+
+               if (!enabled_symbols->contains(sym)) {
+                       enabled_symbols->emplace(sym);
+
+                       return true;
+               }
+
+               return false;
+       }
+
        /**
         * Initialises the symbols cache, must be called after all symbols are added
         * and the config file is loaded