/*- * Copyright 2021 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef RSPAMD_MIME_STRING_HXX #define RSPAMD_MIME_STRING_HXX #pragma once #include #include #include #include #include #include #include #include "function2/function2.hpp" #include "unicode/utf8.h" #include "contrib/fastutf8/fastutf8.h" namespace rspamd::mime { /* * The motivation for another string is to have utf8 valid string replacing * all bad things with FFFFD replacement character and filtering \0 and other * strange stuff defined by policies * This string always exclude \0 characters and ignore them! This is how MUA acts, * and we also store a flag about bad characters */ template> class basic_mime_string; using mime_string = basic_mime_string; /* Helpers for type safe flags */ enum class mime_string_flags : std::uint8_t { MIME_STRING_DEFAULT = 0, MIME_STRING_SEEN_ZEROES = 0x1 << 0, MIME_STRING_SEEN_INVALID = 0x1 << 1, }; mime_string_flags operator |(mime_string_flags lhs, mime_string_flags rhs) { using ut = std::underlying_type::type; return static_cast(static_cast(lhs) | static_cast(rhs)); } mime_string_flags operator &(mime_string_flags lhs, mime_string_flags rhs) { using ut = std::underlying_type::type; return static_cast(static_cast(lhs) & static_cast(rhs)); } bool operator !(mime_string_flags fl) { return fl == mime_string_flags::MIME_STRING_DEFAULT; } // Codepoint iterator base class template struct iterator_base { template friend class basic_mime_string; public: using value_type = typename Container::value_type; using difference_type = typename Container::difference_type; using codepoint_type = typename Container::codepoint_type; using reference_type = codepoint_type; using iterator_category = std::bidirectional_iterator_tag; bool operator==(const iterator_base &it) const noexcept { return idx == it.idx; } bool operator!=(const iterator_base &it) const noexcept { return idx != it.idx; } iterator_base(difference_type index, Container *instance) noexcept: idx(index), cont_instance(instance) {} iterator_base() noexcept = default; iterator_base(const iterator_base &) noexcept = default; iterator_base &operator=(const iterator_base &) noexcept = default; Container *get_instance() const noexcept { return cont_instance; } codepoint_type get_value() const noexcept { auto i = idx; codepoint_type uc; U8_NEXT_UNSAFE(cont_instance->data(), i, uc); return uc; } protected: difference_type idx; Container* cont_instance = nullptr; protected: void advance(difference_type n) noexcept { if (n > 0) { U8_FWD_N_UNSAFE(cont_instance->data(), idx, n); } else if (n < 0) { U8_BACK_N_UNSAFE(cont_instance->data(), idx, (-n)); } } void increment() noexcept { codepoint_type uc; U8_NEXT_UNSAFE(cont_instance->data(), idx, uc); } void decrement() noexcept { codepoint_type uc; U8_PREV_UNSAFE(cont_instance->data(), idx, uc); } }; // Partial spec for raw Byte-based iterator base template struct iterator_base { template friend class basic_string; public: using value_type = typename Container::value_type; using difference_type = typename Container::difference_type; using reference_type = value_type; using iterator_category = std::bidirectional_iterator_tag; bool operator==( const iterator_base& it ) const noexcept { return idx == it.idx; } bool operator!=( const iterator_base& it ) const noexcept { return idx != it.idx; } iterator_base(difference_type index, Container *instance) noexcept: idx(index), cont_instance(instance) {} iterator_base() noexcept = default; iterator_base( const iterator_base& ) noexcept = default; iterator_base& operator=( const iterator_base& ) noexcept = default; Container* get_instance() const noexcept { return cont_instance; } value_type get_value() const noexcept { return cont_instance->storage.at(idx, std::nothrow); } protected: difference_type idx; Container* cont_instance = nullptr; protected: //! Advance the iterator n times (negative values allowed!) void advance( difference_type n ) noexcept { idx += n; } void increment() noexcept { idx ++; } void decrement() noexcept { idx --; } }; template struct iterator; template struct const_iterator; template struct iterator : iterator_base { iterator(typename iterator_base::difference_type index, Container *instance) noexcept: iterator_base(index, instance) { } iterator() noexcept = default; iterator(const iterator &) noexcept = default; iterator &operator=(const iterator &) noexcept = default; /* Disallow creating from const_iterator */ iterator(const const_iterator &) = delete; /* Prefix */ iterator &operator++() noexcept { this->increment(); return *this; } /* Postfix */ iterator operator++(int) noexcept { iterator tmp{this->idx, this->cont_instance}; this->increment(); return tmp; } /* Prefix */ iterator &operator--() noexcept { this->decrement(); return *this; } /* Postfix */ iterator operator--(int) noexcept { iterator tmp{this->idx, this->cont_instance}; this->decrement(); return tmp; } iterator operator+(typename iterator_base::difference_type n) const noexcept { iterator it{*this}; it.advance(n); return it; } iterator &operator+=(typename iterator_base::difference_type n) noexcept { this->advance(n); return *this; } iterator operator-(typename iterator_base::difference_type n) const noexcept { iterator it{*this}; it.advance(-n); return it; } iterator &operator-=(typename iterator_base::difference_type n) noexcept { this->advance(-n); return *this; } typename iterator::reference_type operator*() const noexcept { return this->get_value(); } }; template struct const_iterator : iterator { const_iterator(typename iterator_base::difference_type index, const Container *instance) noexcept: iterator(index, const_cast(instance)) { } const_iterator(const iterator &other) noexcept: iterator(other) { } const_iterator() noexcept = default; const_iterator(const const_iterator &) noexcept = default; const_iterator &operator=(const const_iterator &) noexcept = default; const typename iterator::reference_type operator*() const noexcept { return this->get_value(); } }; template class basic_mime_string : private Allocator { public: using storage_type = std::basic_string, Allocator>; using view_type = std::basic_string_view>; using filter_type = fu2::function_view; using codepoint_type = UChar32; using value_type = T; using difference_type = std::ptrdiff_t; using iterator = rspamd::mime::iterator; using const_iterator = rspamd::mime::const_iterator; using raw_iterator = rspamd::mime::iterator; using raw_const_iterator = rspamd::mime::const_iterator; /* Ctors */ basic_mime_string() noexcept : Allocator() {} explicit basic_mime_string(const Allocator& alloc) noexcept : Allocator(alloc) {} basic_mime_string(const T* str, std::size_t sz, const Allocator& alloc = Allocator()) noexcept : Allocator(alloc) { append_c_string_unfiltered(str, sz); } basic_mime_string(const storage_type &st, const Allocator& alloc = Allocator()) noexcept : basic_mime_string(st.data(), st.size(), alloc) {} basic_mime_string(const view_type &st, const Allocator& alloc = Allocator()) noexcept : basic_mime_string(st.data(), st.size(), alloc) {} /** * Creates a string with a filter function. It is calee responsibility to * ensure that the filter functor survives long enough to work with a string * @param str * @param sz * @param filt * @param alloc */ basic_mime_string(const T* str, std::size_t sz, filter_type &&filt, const Allocator& alloc = Allocator()) noexcept : Allocator(alloc), filter_func(std::forward(filt)) { append_c_string_filtered(str, sz); } basic_mime_string(const storage_type &st, filter_type &&filt, const Allocator& alloc = Allocator()) noexcept : basic_mime_string(st.data(), st.size(), std::forward(filt), alloc) {} basic_mime_string(const view_type &st, filter_type &&filt, const Allocator& alloc = Allocator()) noexcept : basic_mime_string(st.data(), st.size(), std::forward(filt), alloc) {} auto size() const -> std::size_t { return storage.size(); } auto data() const -> const T* { return storage.data(); } constexpr auto has_zeroes() const -> bool { return !!(flags & mime_string_flags::MIME_STRING_SEEN_ZEROES); } constexpr auto has_invalid() const -> bool { return !!(flags & mime_string_flags::MIME_STRING_SEEN_INVALID); } /** * Assign mime string from another string using move operation if a source string * is utf8 valid. * If this function returns false, then ownership has not been transferred * and the `other` string is unmodified as well as the storage * @param other * @return */ [[nodiscard]] auto assign_if_valid(storage_type &&other) -> bool { if (filter_func) { /* No way */ return false; } if (rspamd_fast_utf8_validate((const unsigned char *)other.data(), other.size()) == 0) { std::swap(storage, other); return true; } return false; } /** * Copy to the internal storage discarding the contained value * @param other * @return */ auto assign_copy(const storage_type &other) { storage.clear(); if (filter_func) { append_c_string_filtered(other.data(), other.size()); } else { append_c_string_unfiltered(other.data(), other.size()); } } auto append(const T* str, std::size_t size) -> std::size_t { if (filter_func) { return append_c_string_filtered(str, size); } else { return append_c_string_unfiltered(str, size); } } auto append(const storage_type &other) -> std::size_t { return append(other.data(), other.size()); } auto append(const view_type &other) -> std::size_t { return append(other.data(), other.size()); } auto operator ==(const basic_mime_string &other) { return other.storage == storage; } auto operator ==(const storage_type &other) { return other == storage; } auto operator ==(const view_type &other) { return other == storage; } auto operator ==(const T* other) { if (other == NULL) { return false; } auto olen = strlen(other); if (storage.size() == olen) { return memcmp(storage.data(), other, olen) == 0; } return false; } inline iterator begin() noexcept { return {0, this}; } inline const_iterator begin() const noexcept { return {0, this}; } inline raw_iterator raw_begin() noexcept { return {0, this}; } inline raw_const_iterator raw_begin() const noexcept { return {0, this}; } inline iterator end() noexcept { return {(difference_type) size(), this}; } inline const_iterator end() const noexcept { return {(difference_type) size(), this}; } inline raw_iterator raw_end() noexcept { return {(difference_type) size(), this}; } inline raw_const_iterator raw_end() const noexcept { return {(difference_type) size(), this}; } /* For doctest stringify */ friend std::ostream& operator<< (std::ostream& os, const T& value) { os << value.storage; return os; } private: mime_string_flags flags = mime_string_flags::MIME_STRING_DEFAULT; storage_type storage; filter_type filter_func; auto append_c_string_unfiltered(const T* str, std::size_t len) -> std::size_t { /* This is fast path */ const auto *p = str; const auto *end = str + len; std::ptrdiff_t err_offset; auto orig_size = storage.size(); storage.reserve(len + storage.size()); if (memchr(str, 0, len) != NULL) { /* Fallback to slow path */ flags = flags | mime_string_flags::MIME_STRING_SEEN_ZEROES; return append_c_string_filtered(str, len); } while (p < end && len > 0 && (err_offset = rspamd_fast_utf8_validate((const unsigned char *)p, len)) > 0) { auto cur_offset = err_offset - 1; storage.append(p, cur_offset); while (cur_offset < len) { auto tmp = cur_offset; UChar32 uc; U8_NEXT(p, cur_offset, len, uc); if (uc < 0) { storage.append("\uFFFD"); flags = flags | mime_string_flags::MIME_STRING_SEEN_INVALID; } else { cur_offset = tmp; break; } } p += cur_offset; len = end - p; } storage.append(p, len); return storage.size() - orig_size; } auto append_c_string_filtered(const T* str, std::size_t len) -> std::size_t { std::ptrdiff_t i = 0, o = 0; UChar32 uc; char tmp[4]; auto orig_size = storage.size(); /* Slow path */ storage.reserve(len + storage.size()); while (i < len) { U8_NEXT(str, i, len, uc); if (uc < 0) { /* Replace with 0xFFFD */ storage.append("\uFFFD"); flags = flags | mime_string_flags::MIME_STRING_SEEN_INVALID; } else { if (filter_func) { uc = filter_func(uc); } if (uc == 0) { /* Special case, ignore it */ flags = flags | mime_string_flags::MIME_STRING_SEEN_ZEROES; } else { o = 0; U8_APPEND_UNSAFE(tmp, o, uc); storage.append(tmp, o); } } } return storage.size() - orig_size; } }; } #endif //RSPAMD_MIME_STRING_HXX