Warning: This file is not a C or C++ file. It does not have highlighting.
1 | // -*- C++ -*- |
---|---|
2 | //===----------------------------------------------------------------------===// |
3 | // |
4 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
5 | // See https://llvm.org/LICENSE.txt for license information. |
6 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
7 | // |
8 | // Kokkos v. 4.0 |
9 | // Copyright (2022) National Technology & Engineering |
10 | // Solutions of Sandia, LLC (NTESS). |
11 | // |
12 | // Under the terms of Contract DE-NA0003525 with NTESS, |
13 | // the U.S. Government retains certain rights in this software. |
14 | // |
15 | //===---------------------------------------------------------------------===// |
16 | |
17 | #ifndef _LIBCPP___MDSPAN_LAYOUT_STRIDE_H |
18 | #define _LIBCPP___MDSPAN_LAYOUT_STRIDE_H |
19 | |
20 | #include <__assert> |
21 | #include <__concepts/same_as.h> |
22 | #include <__config> |
23 | #include <__fwd/mdspan.h> |
24 | #include <__mdspan/extents.h> |
25 | #include <__memory/addressof.h> |
26 | #include <__type_traits/common_type.h> |
27 | #include <__type_traits/is_constructible.h> |
28 | #include <__type_traits/is_convertible.h> |
29 | #include <__type_traits/is_integral.h> |
30 | #include <__type_traits/is_nothrow_constructible.h> |
31 | #include <__type_traits/is_same.h> |
32 | #include <__utility/as_const.h> |
33 | #include <__utility/integer_sequence.h> |
34 | #include <__utility/swap.h> |
35 | #include <array> |
36 | #include <limits> |
37 | #include <span> |
38 | |
39 | #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER) |
40 | # pragma GCC system_header |
41 | #endif |
42 | |
43 | _LIBCPP_PUSH_MACROS |
44 | #include <__undef_macros> |
45 | |
46 | _LIBCPP_BEGIN_NAMESPACE_STD |
47 | |
48 | #if _LIBCPP_STD_VER >= 23 |
49 | |
50 | namespace __mdspan_detail { |
51 | template <class _Layout, class _Mapping> |
52 | constexpr bool __is_mapping_of = |
53 | is_same_v<typename _Layout::template mapping<typename _Mapping::extents_type>, _Mapping>; |
54 | |
55 | template <class _Mapping> |
56 | concept __layout_mapping_alike = requires { |
57 | requires __is_mapping_of<typename _Mapping::layout_type, _Mapping>; |
58 | requires __is_extents_v<typename _Mapping::extents_type>; |
59 | { _Mapping::is_always_strided() } -> same_as<bool>; |
60 | { _Mapping::is_always_exhaustive() } -> same_as<bool>; |
61 | { _Mapping::is_always_unique() } -> same_as<bool>; |
62 | bool_constant<_Mapping::is_always_strided()>::value; |
63 | bool_constant<_Mapping::is_always_exhaustive()>::value; |
64 | bool_constant<_Mapping::is_always_unique()>::value; |
65 | }; |
66 | } // namespace __mdspan_detail |
67 | |
68 | template <class _Extents> |
69 | class layout_stride::mapping { |
70 | public: |
71 | static_assert(__mdspan_detail::__is_extents<_Extents>::value, |
72 | "layout_stride::mapping template argument must be a specialization of extents."); |
73 | |
74 | using extents_type = _Extents; |
75 | using index_type = typename extents_type::index_type; |
76 | using size_type = typename extents_type::size_type; |
77 | using rank_type = typename extents_type::rank_type; |
78 | using layout_type = layout_stride; |
79 | |
80 | private: |
81 | static constexpr rank_type __rank_ = extents_type::rank(); |
82 | |
83 | // Used for default construction check and mandates |
84 | _LIBCPP_HIDE_FROM_ABI static constexpr bool __required_span_size_is_representable(const extents_type& __ext) { |
85 | if constexpr (__rank_ == 0) |
86 | return true; |
87 | |
88 | index_type __prod = __ext.extent(0); |
89 | for (rank_type __r = 1; __r < __rank_; __r++) { |
90 | bool __overflowed = __builtin_mul_overflow(__prod, __ext.extent(__r), std::addressof(__prod)); |
91 | if (__overflowed) |
92 | return false; |
93 | } |
94 | return true; |
95 | } |
96 | |
97 | template <class _OtherIndexType> |
98 | _LIBCPP_HIDE_FROM_ABI static constexpr bool |
99 | __required_span_size_is_representable(const extents_type& __ext, span<_OtherIndexType, __rank_> __strides) { |
100 | if constexpr (__rank_ == 0) |
101 | return true; |
102 | |
103 | index_type __size = 1; |
104 | for (rank_type __r = 0; __r < __rank_; __r++) { |
105 | // We can only check correct conversion of _OtherIndexType if it is an integral |
106 | if constexpr (is_integral_v<_OtherIndexType>) { |
107 | using _CommonType = common_type_t<index_type, _OtherIndexType>; |
108 | if (static_cast<_CommonType>(__strides[__r]) > static_cast<_CommonType>(numeric_limits<index_type>::max())) |
109 | return false; |
110 | } |
111 | if (__ext.extent(__r) == static_cast<index_type>(0)) |
112 | return true; |
113 | index_type __prod = (__ext.extent(__r) - 1); |
114 | bool __overflowed_mul = |
115 | __builtin_mul_overflow(__prod, static_cast<index_type>(__strides[__r]), std::addressof(__prod)); |
116 | if (__overflowed_mul) |
117 | return false; |
118 | bool __overflowed_add = __builtin_add_overflow(__size, __prod, std::addressof(__size)); |
119 | if (__overflowed_add) |
120 | return false; |
121 | } |
122 | return true; |
123 | } |
124 | |
125 | // compute offset of a strided layout mapping |
126 | template <class _StridedMapping> |
127 | _LIBCPP_HIDE_FROM_ABI static constexpr index_type __offset(const _StridedMapping& __mapping) { |
128 | if constexpr (_StridedMapping::extents_type::rank() == 0) { |
129 | return static_cast<index_type>(__mapping()); |
130 | } else if (__mapping.required_span_size() == static_cast<typename _StridedMapping::index_type>(0)) { |
131 | return static_cast<index_type>(0); |
132 | } else { |
133 | return [&]<size_t... _Pos>(index_sequence<_Pos...>) { |
134 | return static_cast<index_type>(__mapping((_Pos ? 0 : 0)...)); |
135 | }(make_index_sequence<__rank_>()); |
136 | } |
137 | } |
138 | |
139 | // compute the permutation for sorting the stride array |
140 | // we never actually sort the stride array |
141 | _LIBCPP_HIDE_FROM_ABI constexpr void __bubble_sort_by_strides(array<rank_type, __rank_>& __permute) const { |
142 | for (rank_type __i = __rank_ - 1; __i > 0; __i--) { |
143 | for (rank_type __r = 0; __r < __i; __r++) { |
144 | if (__strides_[__permute[__r]] > __strides_[__permute[__r + 1]]) { |
145 | swap(__permute[__r], __permute[__r + 1]); |
146 | } else { |
147 | // if two strides are the same then one of the associated extents must be 1 or 0 |
148 | // both could be, but you can't have one larger than 1 come first |
149 | if ((__strides_[__permute[__r]] == __strides_[__permute[__r + 1]]) && |
150 | (__extents_.extent(__permute[__r]) > static_cast<index_type>(1))) |
151 | swap(__permute[__r], __permute[__r + 1]); |
152 | } |
153 | } |
154 | } |
155 | } |
156 | |
157 | static_assert(extents_type::rank_dynamic() > 0 || __required_span_size_is_representable(extents_type()), |
158 | "layout_stride::mapping product of static extents must be representable as index_type."); |
159 | |
160 | public: |
161 | // [mdspan.layout.stride.cons], constructors |
162 | _LIBCPP_HIDE_FROM_ABI constexpr mapping() noexcept : __extents_(extents_type()) { |
163 | // Note the nominal precondition is covered by above static assert since |
164 | // if rank_dynamic is != 0 required_span_size is zero for default construction |
165 | if constexpr (__rank_ > 0) { |
166 | index_type __stride = 1; |
167 | for (rank_type __r = __rank_ - 1; __r > static_cast<rank_type>(0); __r--) { |
168 | __strides_[__r] = __stride; |
169 | __stride *= __extents_.extent(__r); |
170 | } |
171 | __strides_[0] = __stride; |
172 | } |
173 | } |
174 | |
175 | _LIBCPP_HIDE_FROM_ABI constexpr mapping(const mapping&) noexcept = default; |
176 | |
177 | template <class _OtherIndexType> |
178 | requires(is_convertible_v<const _OtherIndexType&, index_type> && |
179 | is_nothrow_constructible_v<index_type, const _OtherIndexType&>) |
180 | _LIBCPP_HIDE_FROM_ABI constexpr mapping(const extents_type& __ext, span<_OtherIndexType, __rank_> __strides) noexcept |
181 | : __extents_(__ext), __strides_([&]<size_t... _Pos>(index_sequence<_Pos...>) { |
182 | return __mdspan_detail::__possibly_empty_array<index_type, __rank_>{ |
183 | static_cast<index_type>(std::as_const(__strides[_Pos]))...}; |
184 | }(make_index_sequence<__rank_>())) { |
185 | _LIBCPP_ASSERT_VALID_ELEMENT_ACCESS( |
186 | ([&]<size_t... _Pos>(index_sequence<_Pos...>) { |
187 | // For integrals we can do a pre-conversion check, for other types not |
188 | if constexpr (is_integral_v<_OtherIndexType>) { |
189 | return ((__strides[_Pos] > static_cast<_OtherIndexType>(0)) && ... && true); |
190 | } else { |
191 | return ((static_cast<index_type>(__strides[_Pos]) > static_cast<index_type>(0)) && ... && true); |
192 | } |
193 | }(make_index_sequence<__rank_>())), |
194 | "layout_stride::mapping ctor: all strides must be greater than 0"); |
195 | _LIBCPP_ASSERT_VALID_ELEMENT_ACCESS( |
196 | __required_span_size_is_representable(__ext, __strides), |
197 | "layout_stride::mapping ctor: required span size is not representable as index_type."); |
198 | if constexpr (__rank_ > 1) { |
199 | _LIBCPP_ASSERT_UNCATEGORIZED( |
200 | ([&]<size_t... _Pos>(index_sequence<_Pos...>) { |
201 | // basically sort the dimensions based on strides and extents, sorting is represented in permute array |
202 | array<rank_type, __rank_> __permute{_Pos...}; |
203 | __bubble_sort_by_strides(__permute); |
204 | |
205 | // check that this permutations represents a growing set |
206 | for (rank_type __i = 1; __i < __rank_; __i++) |
207 | if (static_cast<index_type>(__strides[__permute[__i]]) < |
208 | static_cast<index_type>(__strides[__permute[__i - 1]]) * __extents_.extent(__permute[__i - 1])) |
209 | return false; |
210 | return true; |
211 | }(make_index_sequence<__rank_>())), |
212 | "layout_stride::mapping ctor: the provided extents and strides lead to a non-unique mapping"); |
213 | } |
214 | } |
215 | |
216 | template <class _OtherIndexType> |
217 | requires(is_convertible_v<const _OtherIndexType&, index_type> && |
218 | is_nothrow_constructible_v<index_type, const _OtherIndexType&>) |
219 | _LIBCPP_HIDE_FROM_ABI constexpr mapping(const extents_type& __ext, |
220 | const array<_OtherIndexType, __rank_>& __strides) noexcept |
221 | : mapping(__ext, span(__strides)) {} |
222 | |
223 | template <class _StridedLayoutMapping> |
224 | requires(__mdspan_detail::__layout_mapping_alike<_StridedLayoutMapping> && |
225 | is_constructible_v<extents_type, typename _StridedLayoutMapping::extents_type> && |
226 | _StridedLayoutMapping::is_always_unique() && _StridedLayoutMapping::is_always_strided()) |
227 | _LIBCPP_HIDE_FROM_ABI constexpr explicit( |
228 | !(is_convertible_v<typename _StridedLayoutMapping::extents_type, extents_type> && |
229 | (__mdspan_detail::__is_mapping_of<layout_left, _StridedLayoutMapping> || |
230 | __mdspan_detail::__is_mapping_of<layout_right, _StridedLayoutMapping> || |
231 | __mdspan_detail::__is_mapping_of<layout_stride, _StridedLayoutMapping>))) |
232 | mapping(const _StridedLayoutMapping& __other) noexcept |
233 | : __extents_(__other.extents()), __strides_([&]<size_t... _Pos>(index_sequence<_Pos...>) { |
234 | // stride() only compiles for rank > 0 |
235 | if constexpr (__rank_ > 0) { |
236 | return __mdspan_detail::__possibly_empty_array<index_type, __rank_>{ |
237 | static_cast<index_type>(__other.stride(_Pos))...}; |
238 | } else { |
239 | return __mdspan_detail::__possibly_empty_array<index_type, 0>{}; |
240 | } |
241 | }(make_index_sequence<__rank_>())) { |
242 | // stride() only compiles for rank > 0 |
243 | if constexpr (__rank_ > 0) { |
244 | _LIBCPP_ASSERT_VALID_ELEMENT_ACCESS( |
245 | ([&]<size_t... _Pos>(index_sequence<_Pos...>) { |
246 | return ((static_cast<index_type>(__other.stride(_Pos)) > static_cast<index_type>(0)) && ... && true); |
247 | }(make_index_sequence<__rank_>())), |
248 | "layout_stride::mapping converting ctor: all strides must be greater than 0"); |
249 | } |
250 | _LIBCPP_ASSERT_VALID_ELEMENT_ACCESS( |
251 | __mdspan_detail::__is_representable_as<index_type>(__other.required_span_size()), |
252 | "layout_stride::mapping converting ctor: other.required_span_size() must be representable as index_type."); |
253 | _LIBCPP_ASSERT_VALID_ELEMENT_ACCESS(static_cast<index_type>(0) == __offset(__other), |
254 | "layout_stride::mapping converting ctor: base offset of mapping must be zero."); |
255 | } |
256 | |
257 | _LIBCPP_HIDE_FROM_ABI constexpr mapping& operator=(const mapping&) noexcept = default; |
258 | |
259 | // [mdspan.layout.stride.obs], observers |
260 | _LIBCPP_HIDE_FROM_ABI constexpr const extents_type& extents() const noexcept { return __extents_; } |
261 | |
262 | _LIBCPP_HIDE_FROM_ABI constexpr array<index_type, __rank_> strides() const noexcept { |
263 | return [&]<size_t... _Pos>(index_sequence<_Pos...>) { |
264 | return array<index_type, __rank_>{__strides_[_Pos]...}; |
265 | }(make_index_sequence<__rank_>()); |
266 | } |
267 | |
268 | _LIBCPP_HIDE_FROM_ABI constexpr index_type required_span_size() const noexcept { |
269 | if constexpr (__rank_ == 0) { |
270 | return static_cast<index_type>(1); |
271 | } else { |
272 | return [&]<size_t... _Pos>(index_sequence<_Pos...>) { |
273 | if ((__extents_.extent(_Pos) * ... * 1) == 0) |
274 | return static_cast<index_type>(0); |
275 | else |
276 | return static_cast<index_type>( |
277 | static_cast<index_type>(1) + |
278 | (((__extents_.extent(_Pos) - static_cast<index_type>(1)) * __strides_[_Pos]) + ... + |
279 | static_cast<index_type>(0))); |
280 | }(make_index_sequence<__rank_>()); |
281 | } |
282 | } |
283 | |
284 | template <class... _Indices> |
285 | requires((sizeof...(_Indices) == __rank_) && (is_convertible_v<_Indices, index_type> && ...) && |
286 | (is_nothrow_constructible_v<index_type, _Indices> && ...)) |
287 | _LIBCPP_HIDE_FROM_ABI constexpr index_type operator()(_Indices... __idx) const noexcept { |
288 | // Mappings are generally meant to be used for accessing allocations and are meant to guarantee to never |
289 | // return a value exceeding required_span_size(), which is used to know how large an allocation one needs |
290 | // Thus, this is a canonical point in multi-dimensional data structures to make invalid element access checks |
291 | // However, mdspan does check this on its own, so for now we avoid double checking in hardened mode |
292 | _LIBCPP_ASSERT_UNCATEGORIZED(__mdspan_detail::__is_multidimensional_index_in(__extents_, __idx...), |
293 | "layout_stride::mapping: out of bounds indexing"); |
294 | return [&]<size_t... _Pos>(index_sequence<_Pos...>) { |
295 | return ((static_cast<index_type>(__idx) * __strides_[_Pos]) + ... + index_type(0)); |
296 | }(make_index_sequence<sizeof...(_Indices)>()); |
297 | } |
298 | |
299 | _LIBCPP_HIDE_FROM_ABI static constexpr bool is_always_unique() noexcept { return true; } |
300 | _LIBCPP_HIDE_FROM_ABI static constexpr bool is_always_exhaustive() noexcept { return false; } |
301 | _LIBCPP_HIDE_FROM_ABI static constexpr bool is_always_strided() noexcept { return true; } |
302 | |
303 | _LIBCPP_HIDE_FROM_ABI static constexpr bool is_unique() noexcept { return true; } |
304 | // The answer of this function is fairly complex in the case where one or more |
305 | // extents are zero. |
306 | // Technically it is meaningless to query is_exhaustive() in that case, but unfortunately |
307 | // the way the standard defines this function, we can't give a simple true or false then. |
308 | _LIBCPP_HIDE_FROM_ABI constexpr bool is_exhaustive() const noexcept { |
309 | if constexpr (__rank_ == 0) |
310 | return true; |
311 | else { |
312 | index_type __span_size = required_span_size(); |
313 | if (__span_size == static_cast<index_type>(0)) { |
314 | if constexpr (__rank_ == 1) |
315 | return __strides_[0] == 1; |
316 | else { |
317 | rank_type __r_largest = 0; |
318 | for (rank_type __r = 1; __r < __rank_; __r++) |
319 | if (__strides_[__r] > __strides_[__r_largest]) |
320 | __r_largest = __r; |
321 | for (rank_type __r = 0; __r < __rank_; __r++) |
322 | if (__extents_.extent(__r) == 0 && __r != __r_largest) |
323 | return false; |
324 | return true; |
325 | } |
326 | } else { |
327 | return required_span_size() == [&]<size_t... _Pos>(index_sequence<_Pos...>) { |
328 | return (__extents_.extent(_Pos) * ... * static_cast<index_type>(1)); |
329 | }(make_index_sequence<__rank_>()); |
330 | } |
331 | } |
332 | } |
333 | _LIBCPP_HIDE_FROM_ABI static constexpr bool is_strided() noexcept { return true; } |
334 | |
335 | // according to the standard layout_stride does not have a constraint on stride(r) for rank>0 |
336 | // it still has the precondition though |
337 | _LIBCPP_HIDE_FROM_ABI constexpr index_type stride(rank_type __r) const noexcept { |
338 | _LIBCPP_ASSERT_VALID_ELEMENT_ACCESS(__r < __rank_, "layout_stride::mapping::stride(): invalid rank index"); |
339 | return __strides_[__r]; |
340 | } |
341 | |
342 | template <class _OtherMapping> |
343 | requires(__mdspan_detail::__layout_mapping_alike<_OtherMapping> && |
344 | (_OtherMapping::extents_type::rank() == __rank_) && _OtherMapping::is_always_strided()) |
345 | _LIBCPP_HIDE_FROM_ABI friend constexpr bool operator==(const mapping& __lhs, const _OtherMapping& __rhs) noexcept { |
346 | if (__offset(__rhs)) |
347 | return false; |
348 | if constexpr (__rank_ == 0) |
349 | return true; |
350 | else { |
351 | return __lhs.extents() == __rhs.extents() && [&]<size_t... _Pos>(index_sequence<_Pos...>) { |
352 | // avoid warning when comparing signed and unsigner integers and pick the wider of two types |
353 | using _CommonType = common_type_t<index_type, typename _OtherMapping::index_type>; |
354 | return ((static_cast<_CommonType>(__lhs.stride(_Pos)) == static_cast<_CommonType>(__rhs.stride(_Pos))) && ... && |
355 | true); |
356 | }(make_index_sequence<__rank_>()); |
357 | } |
358 | } |
359 | |
360 | private: |
361 | _LIBCPP_NO_UNIQUE_ADDRESS extents_type __extents_{}; |
362 | _LIBCPP_NO_UNIQUE_ADDRESS __mdspan_detail::__possibly_empty_array<index_type, __rank_> __strides_{}; |
363 | }; |
364 | |
365 | #endif // _LIBCPP_STD_VER >= 23 |
366 | |
367 | _LIBCPP_END_NAMESPACE_STD |
368 | |
369 | _LIBCPP_POP_MACROS |
370 | |
371 | #endif // _LIBCPP___MDSPAN_LAYOUT_STRIDE_H |
372 |
Warning: This file is not a C or C++ file. It does not have highlighting.