]> source.dussan.org Git - rspamd.git/commitdiff
[Minor] Enable make_shared like behaviour
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 22 Jul 2021 16:08:36 +0000 (17:08 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 22 Jul 2021 16:08:36 +0000 (17:08 +0100)
src/libutil/cxx/local_shared_ptr.hxx
test/rspamd_cxx_unit_utils.hxx

index 8c8751447d3f11d6395722e68c38aabbcc4dbdbd..5a9cccb38fcb28082d2b05d5a887d24561ac14d5 100644 (file)
  */
 namespace rspamd {
 
+namespace detail {
+
+class ref_cnt {
+public:
+       using refcount_t = int;
+
+       constexpr auto add_shared() -> refcount_t {
+               return ++ref_shared;
+       }
+       constexpr auto release_shared() -> refcount_t {
+               return --ref_shared;
+       }
+       constexpr auto release_weak() -> refcount_t {
+               return --ref_weak;
+       }
+       constexpr auto shared_count() const -> refcount_t {
+               return ref_shared;
+       }
+       virtual ~ref_cnt() {}
+       virtual void dispose() = 0;
+private:
+       refcount_t ref_weak = 0;
+       refcount_t ref_shared = 1;
+};
+
 template <class T>
-class local_weak_ptr {
-       typedef T element_type;
+class obj_and_refcnt : public ref_cnt {
+private:
+       typedef typename std::aligned_storage<sizeof(T), std::alignment_of<T>::value>::type storage_type;
+       storage_type storage;
+       bool initialized;
+       virtual void dispose() override {
+               if (initialized) {
+                       T *p = reinterpret_cast<T *>(&storage);
+                       p->~T();
+                       initialized = false;
+               }
+       }
+public:
+       template <typename... Args>
+       explicit obj_and_refcnt(Args&&... args) : initialized(true)
+       {
+               new(&storage) T(std::forward<Args>(args)...);
+       }
+       auto get(void) -> T* {
+               if (initialized) {
+                       return reinterpret_cast<T *>(&storage);
+               }
+
+               return nullptr;
+       }
+       virtual ~obj_and_refcnt() = default;
+};
 
+template <class T, class D = typename std::default_delete<T>>
+class ptr_and_refcnt : public ref_cnt {
+private:
+       T* ptr;
+       D deleter;
+       virtual void dispose() override {
+               deleter(ptr);
+               ptr = nullptr;
+       }
+public:
+       explicit ptr_and_refcnt(T *_ptr, D d = std::default_delete<T>()) : ptr(_ptr),
+                       deleter(std::move(d)) {}
+       virtual ~ptr_and_refcnt() = default;
 };
 
+}
+
+template <class T> class local_weak_ptr;
+
 template <class T>
 class local_shared_ptr {
 public:
@@ -48,8 +115,15 @@ public:
 
        template<class Y, typename std::enable_if<
                std::is_convertible<Y*, element_type*>::value, bool>::type = true>
-       explicit local_shared_ptr(Y* p) : px(p), cnt(new local_shared_ptr::control) {
-               cnt->add_shared();
+       explicit local_shared_ptr(Y* p) : px(p), cnt(new detail::ptr_and_refcnt(p))
+       {
+       }
+
+       // custom deleter
+       template<class Y, class D, typename std::enable_if<
+                       std::is_convertible<Y*, element_type*>::value, bool>::type = true>
+       explicit local_shared_ptr(Y* p, D d) : px(p), cnt(new detail::ptr_and_refcnt(p, std::move(d)))
+       {
        }
 
        local_shared_ptr(const local_shared_ptr& r) noexcept : px(r.px), cnt(r.cnt) {
@@ -67,8 +141,7 @@ public:
        ~local_shared_ptr() {
                if (cnt) {
                        if (cnt->release_shared() <= 0) {
-                               delete px;
-                               px = nullptr;
+                               cnt->dispose();
 
                                if (cnt->release_weak() <= 0) {
                                        delete cnt;
@@ -135,31 +208,24 @@ public:
        }
 
 private:
-       class control {
-       public:
-               using refcount_t = int;
-
-               constexpr auto add_shared() -> refcount_t {
-                       return ++ref_shared;
-               }
-               constexpr auto release_shared() -> refcount_t {
-                       return --ref_shared;
-               }
-               constexpr auto release_weak() -> refcount_t {
-                       return --ref_weak;
-               }
-               constexpr auto shared_count() const -> refcount_t {
-                       return ref_shared;
-               }
-       private:
-               refcount_t ref_weak = 0;
-               refcount_t ref_shared = 0;
-       };
-
        T *px; // contained pointer
-       control *cnt;
+       detail::ref_cnt *cnt;
+
+       template<class _T, class ... Args>
+       friend local_shared_ptr<_T> local_make_shared(Args && ... args);
 };
 
+template<class T, class ... Args>
+local_shared_ptr<T> local_make_shared(Args && ... args)
+{
+       local_shared_ptr<T> ptr;
+       auto tmp_object = new detail::obj_and_refcnt<T>(std::forward<Args>(args)...);
+       ptr.px = tmp_object->get();
+       ptr.cnt = tmp_object;
+
+       return ptr;
+}
+
 }
 
 #endif //RSPAMD_LOCAL_SHARED_PTR_HXX
index a6cbc3a32ce7f97ada668066873cf5764e1c7b3c..be5d193f48f86820231aceebe267148c723a11b9 100644 (file)
@@ -206,6 +206,36 @@ TEST_CASE("shared_ptr dtor") {
        CHECK(t == true);
 }
 
+TEST_CASE("make_shared dtor") {
+       bool t;
+
+       {
+               auto pi = rspamd::local_make_shared<deleter_test>(t);
+
+               CHECK((!pi ? false : true));
+               CHECK(!!pi);
+               CHECK(pi.get() != nullptr);
+               CHECK(pi.use_count() == 1);
+               CHECK(pi.unique());
+               CHECK(t == false);
+
+               rspamd::local_shared_ptr<deleter_test> pi2(pi);
+               CHECK(pi2 == pi);
+               CHECK(pi.use_count() == 2);
+               pi.reset();
+               CHECK(!(pi2 == pi));
+               CHECK(pi2.use_count() == 1);
+               CHECK(t == false);
+
+               pi = pi2;
+               CHECK(pi2 == pi);
+               CHECK(pi.use_count() == 2);
+               CHECK(t == false);
+       }
+
+       CHECK(t == true);
+}
+
 }
 
 #endif