diff options
author | Peng Liu <winner245@hotmail.com> | 2025-02-26 12:18:25 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-02-26 12:18:25 -0500 |
commit | 7717a549e91c4fb554b78fce38e75b0147fb6cac (patch) | |
tree | 961ce4ed8e1658652b4e67de94333cc7ab16d17f /libcxx/include/__algorithm/equal.h | |
parent | 7ffeab3121c984cc00f79b0a78f372a4f7526e3b (diff) | |
download | llvm-7717a549e91c4fb554b78fce38e75b0147fb6cac.zip llvm-7717a549e91c4fb554b78fce38e75b0147fb6cac.tar.gz llvm-7717a549e91c4fb554b78fce38e75b0147fb6cac.tar.bz2 |
[libc++] Optimize ranges::equal for vector<bool>::iterator (#121084)
This PR optimizes the performance of `std::ranges::equal` for
`vector<bool>::iterator`, addressing a subtask outlined in issue #64038.
The optimizations yield performance improvements of up to 188x for
aligned equality comparison and 82x for unaligned equality
comparison. Moreover, comprehensive tests covering up to 4 storage words
(256 bits) with odd and even bit sizes are provided, which validate the
proposed optimizations in this patch.
Diffstat (limited to 'libcxx/include/__algorithm/equal.h')
-rw-r--r-- | libcxx/include/__algorithm/equal.h | 156 |
1 file changed, 156 insertions, 0 deletions
// Reviewer notes for the patch below. The patch text itself is a unified diff whose line breaks
// were collapsed, so each `+` marks the start of an added line and each `//` comment inside it
// originally ended at the next `+`.
//
// The hunks add four includes (<__algorithm/min.h>, <__fwd/bit_reference.h>,
// <__memory/pointer_traits.h>, <__type_traits/is_same.h>) and four function templates that all
// operate on __bit_iterator<_Cp, _IsConst> arguments, reading whole storage words through the
// iterator's __seg_ member and using its __ctz_ member as the bit offset inside that word.
//
// __equal_unaligned(__first1, __last1, __first2)
//   Compares the __n = __last1 - __first1 bits starting at __first1 with the bits starting at
//   __first2. Returns true when __n <= 0 or when no mismatch is found, false at the first
//   mismatching group of bits. Three phases:
//     1. first word  - only when __first1.__ctz_ != 0. The leading
//                      __dn = min(__bits_per_word - __first1.__ctz_, __n) bits of range 1 are
//                      isolated in __b and compared with range 2 in up to two pieces: __ddn bits
//                      that still fit in the current word of range 2 (with __b shifted left or
//                      right by the difference of the two offsets), then the remaining __dn bits
//                      taken from the low end of the next word of range 2.
//     2. middle words - every full word __b of range 1 is compared against two adjacent words of
//                      range 2: `__b << __first2.__ctz_` under mask __m, then, after advancing
//                      __first2.__seg_, `__b >> __clz_r` under mask ~__m.
//     3. last word   - the trailing __n (< __bits_per_word) bits, again split into at most two
//                      pieces across a word boundary of range 2.
//   The `__first1.__ctz_ = 0` assignment is commented out in the patch; nothing after the first
//   phase reads __first1.__ctz_.
//   NOTE(review): phase 2 shifts by __clz_r = __bits_per_word - __first2.__ctz_, which equals the
//   word width if __first2.__ctz_ is 0 at that point. Both visible callers only reach this
//   function when the two offsets differ - confirm no other caller exists.
//
// __equal_aligned(__first1, __last1, __first2)
//   Same result contract, but every mask is derived from __first1.__ctz_ only and words of the
//   two ranges are compared position for position (partial first word, full middle words,
//   partial last word). Both visible callers invoke it only when
//   __first1.__ctz_ == __first2.__ctz_.
//
// __equal_iter_impl(__first1, __last1, __first2, _BinaryPredicate)   [__bit_iterator overload]
//   Participates only when __desugars_to_v<__equal_tag, _BinaryPredicate, bool, bool> holds. The
//   predicate parameter is unnamed and unused. Dispatches to __equal_aligned when the two
//   __ctz_ offsets are equal and to __equal_unaligned otherwise.
//
// __equal_impl(__first1, __last1, __first2, <last2>, _Pred&, _Proj1&, _Proj2&)
//                                                                     [__bit_iterator overload]
//   Participates only when the predicate desugars to __equal_tag and both projections satisfy
//   __is_identity. The second-range end iterator, predicate and projections are unnamed and
//   unused; the dispatch is the same as in the __equal_iter_impl overload above.
//
// NOTE(review): every mask is built from `~__storage_type(0)` and then shifted. If
// __storage_type is narrower than int, integral promotion converts that operand to a negative
// int before the shift, so `>>` would not shift in zero bits and `__b << k` could keep bits
// above the storage width. SOURCE does not show which storage types can reach this code -
// confirm behaviour for small storage types.
diff --git a/libcxx/include/__algorithm/equal.h b/libcxx/include/__algorithm/equal.h index a276bb9..4cac965 100644 --- a/libcxx/include/__algorithm/equal.h +++ b/libcxx/include/__algorithm/equal.h @@ -11,16 +11,20 @@ #define _LIBCPP___ALGORITHM_EQUAL_H #include <__algorithm/comp.h> +#include <__algorithm/min.h> #include <__algorithm/unwrap_iter.h> #include <__config> #include <__functional/identity.h> +#include <__fwd/bit_reference.h> #include <__iterator/distance.h> #include <__iterator/iterator_traits.h> +#include <__memory/pointer_traits.h> #include <__string/constexpr_c_functions.h> #include <__type_traits/desugars_to.h> #include <__type_traits/enable_if.h> #include <__type_traits/invoke.h> #include <__type_traits/is_equality_comparable.h> +#include <__type_traits/is_same.h> #include <__type_traits/is_volatile.h> #include <__utility/move.h> @@ -33,6 +37,136 @@ _LIBCPP_PUSH_MACROS _LIBCPP_BEGIN_NAMESPACE_STD +template <class _Cp, bool _IsConst1, bool _IsConst2> +[[__nodiscard__]] _LIBCPP_CONSTEXPR_SINCE_CXX20 _LIBCPP_HIDE_FROM_ABI bool +__equal_unaligned(__bit_iterator<_Cp, _IsConst1> __first1, + __bit_iterator<_Cp, _IsConst1> __last1, + __bit_iterator<_Cp, _IsConst2> __first2) { + using _It = __bit_iterator<_Cp, _IsConst1>; + using difference_type = typename _It::difference_type; + using __storage_type = typename _It::__storage_type; + + const int __bits_per_word = _It::__bits_per_word; + difference_type __n = __last1 - __first1; + if (__n > 0) { + // do first word + if (__first1.__ctz_ != 0) { + unsigned __clz_f = __bits_per_word - __first1.__ctz_; + difference_type __dn = std::min(static_cast<difference_type>(__clz_f), __n); + __n -= __dn; + __storage_type __m = (~__storage_type(0) << __first1.__ctz_) & (~__storage_type(0) >> (__clz_f - __dn)); + __storage_type __b = *__first1.__seg_ & __m; + unsigned __clz_r = __bits_per_word - __first2.__ctz_; + __storage_type __ddn = std::min<__storage_type>(__dn, __clz_r); + __m = (~__storage_type(0) << __first2.__ctz_) & 
(~__storage_type(0) >> (__clz_r - __ddn)); + if (__first2.__ctz_ > __first1.__ctz_) { + if ((*__first2.__seg_ & __m) != (__b << (__first2.__ctz_ - __first1.__ctz_))) + return false; + } else { + if ((*__first2.__seg_ & __m) != (__b >> (__first1.__ctz_ - __first2.__ctz_))) + return false; + } + __first2.__seg_ += (__ddn + __first2.__ctz_) / __bits_per_word; + __first2.__ctz_ = static_cast<unsigned>((__ddn + __first2.__ctz_) % __bits_per_word); + __dn -= __ddn; + if (__dn > 0) { + __m = ~__storage_type(0) >> (__bits_per_word - __dn); + if ((*__first2.__seg_ & __m) != (__b >> (__first1.__ctz_ + __ddn))) + return false; + __first2.__ctz_ = static_cast<unsigned>(__dn); + } + ++__first1.__seg_; + // __first1.__ctz_ = 0; + } + // __first1.__ctz_ == 0; + // do middle words + unsigned __clz_r = __bits_per_word - __first2.__ctz_; + __storage_type __m = ~__storage_type(0) << __first2.__ctz_; + for (; __n >= __bits_per_word; __n -= __bits_per_word, ++__first1.__seg_) { + __storage_type __b = *__first1.__seg_; + if ((*__first2.__seg_ & __m) != (__b << __first2.__ctz_)) + return false; + ++__first2.__seg_; + if ((*__first2.__seg_ & ~__m) != (__b >> __clz_r)) + return false; + } + // do last word + if (__n > 0) { + __m = ~__storage_type(0) >> (__bits_per_word - __n); + __storage_type __b = *__first1.__seg_ & __m; + __storage_type __dn = std::min(__n, static_cast<difference_type>(__clz_r)); + __m = (~__storage_type(0) << __first2.__ctz_) & (~__storage_type(0) >> (__clz_r - __dn)); + if ((*__first2.__seg_ & __m) != (__b << __first2.__ctz_)) + return false; + __first2.__seg_ += (__dn + __first2.__ctz_) / __bits_per_word; + __first2.__ctz_ = static_cast<unsigned>((__dn + __first2.__ctz_) % __bits_per_word); + __n -= __dn; + if (__n > 0) { + __m = ~__storage_type(0) >> (__bits_per_word - __n); + if ((*__first2.__seg_ & __m) != (__b >> __dn)) + return false; + } + } + } + return true; +} + +template <class _Cp, bool _IsConst1, bool _IsConst2> +[[__nodiscard__]] 
_LIBCPP_CONSTEXPR_SINCE_CXX20 _LIBCPP_HIDE_FROM_ABI bool +__equal_aligned(__bit_iterator<_Cp, _IsConst1> __first1, + __bit_iterator<_Cp, _IsConst1> __last1, + __bit_iterator<_Cp, _IsConst2> __first2) { + using _It = __bit_iterator<_Cp, _IsConst1>; + using difference_type = typename _It::difference_type; + using __storage_type = typename _It::__storage_type; + + const int __bits_per_word = _It::__bits_per_word; + difference_type __n = __last1 - __first1; + if (__n > 0) { + // do first word + if (__first1.__ctz_ != 0) { + unsigned __clz = __bits_per_word - __first1.__ctz_; + difference_type __dn = std::min(static_cast<difference_type>(__clz), __n); + __n -= __dn; + __storage_type __m = (~__storage_type(0) << __first1.__ctz_) & (~__storage_type(0) >> (__clz - __dn)); + if ((*__first2.__seg_ & __m) != (*__first1.__seg_ & __m)) + return false; + ++__first2.__seg_; + ++__first1.__seg_; + // __first1.__ctz_ = 0; + // __first2.__ctz_ = 0; + } + // __first1.__ctz_ == 0; + // __first2.__ctz_ == 0; + // do middle words + for (; __n >= __bits_per_word; __n -= __bits_per_word, ++__first1.__seg_, ++__first2.__seg_) + if (*__first2.__seg_ != *__first1.__seg_) + return false; + // do last word + if (__n > 0) { + __storage_type __m = ~__storage_type(0) >> (__bits_per_word - __n); + if ((*__first2.__seg_ & __m) != (*__first1.__seg_ & __m)) + return false; + } + } + return true; +} + +template <class _Cp, + bool _IsConst1, + bool _IsConst2, + class _BinaryPredicate, + __enable_if_t<__desugars_to_v<__equal_tag, _BinaryPredicate, bool, bool>, int> = 0> +[[__nodiscard__]] inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool __equal_iter_impl( + __bit_iterator<_Cp, _IsConst1> __first1, + __bit_iterator<_Cp, _IsConst1> __last1, + __bit_iterator<_Cp, _IsConst2> __first2, + _BinaryPredicate) { + if (__first1.__ctz_ == __first2.__ctz_) + return std::__equal_aligned(__first1, __last1, __first2); + return std::__equal_unaligned(__first1, __last1, __first2); +} + template <class 
_InputIterator1, class _InputIterator2, class _BinaryPredicate> [[__nodiscard__]] inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool __equal_iter_impl( _InputIterator1 __first1, _InputIterator1 __last1, _InputIterator2 __first2, _BinaryPredicate& __pred) { @@ -94,6 +228,28 @@ __equal_impl(_Tp* __first1, _Tp* __last1, _Up* __first2, _Up*, _Pred&, _Proj1&, return std::__constexpr_memcmp_equal(__first1, __first2, __element_count(__last1 - __first1)); } +template <class _Cp, + bool _IsConst1, + bool _IsConst2, + class _Pred, + class _Proj1, + class _Proj2, + __enable_if_t<__desugars_to_v<__equal_tag, _Pred, bool, bool> && __is_identity<_Proj1>::value && + __is_identity<_Proj2>::value, + int> = 0> +[[__nodiscard__]] inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool __equal_impl( + __bit_iterator<_Cp, _IsConst1> __first1, + __bit_iterator<_Cp, _IsConst1> __last1, + __bit_iterator<_Cp, _IsConst2> __first2, + __bit_iterator<_Cp, _IsConst2>, + _Pred&, + _Proj1&, + _Proj2&) { + if (__first1.__ctz_ == __first2.__ctz_) + return std::__equal_aligned(__first1, __last1, __first2); + return std::__equal_unaligned(__first1, __last1, __first2); +} + template <class _InputIterator1, class _InputIterator2, class _BinaryPredicate> [[__nodiscard__]] inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool equal(_InputIterator1 __first1, |