1 | // -*- C++ -*- |
2 | //===----------------------------------------------------------------------===// |
3 | // |
4 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
5 | // See https://llvm.org/LICENSE.txt for license information. |
6 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
7 | // |
8 | // Kokkos v. 4.0 |
9 | // Copyright (2022) National Technology & Engineering |
10 | // Solutions of Sandia, LLC (NTESS). |
11 | // |
12 | // Under the terms of Contract DE-NA0003525 with NTESS, |
13 | // the U.S. Government retains certain rights in this software. |
14 | // |
15 | //===---------------------------------------------------------------------===// |
16 | |
17 | #ifndef TEST_STD_CONTAINERS_VIEWS_MDSPAN_CUSTOM_TEST_LAYOUTS_H |
18 | #define TEST_STD_CONTAINERS_VIEWS_MDSPAN_CUSTOM_TEST_LAYOUTS_H |
19 | |
20 | #include <algorithm> |
21 | #include <array> |
22 | #include <cassert> |
23 | #include <cinttypes> |
24 | #include <concepts> |
25 | #include <cstddef> |
26 | #include <limits> |
27 | #include <mdspan> |
28 | #include <type_traits> |
29 | #include <utility> |
30 | |
31 | // Layout that wraps indices to test some idiosyncratic behavior |
// - basically it is a layout_left where indices are first wrapped, i.e. i % Wrap
33 | // - only accepts integers as indices |
34 | // - is_always_strided and is_always_unique are false |
// - is_strided and is_unique are true if no extent is larger than Wrap
36 | // - not default constructible |
37 | // - not extents constructible |
38 | // - not trivially copyable |
39 | // - does not check dynamic to static extent conversion in converting ctor |
40 | // - check via side-effects that mdspan::swap calls mappings swap via ADL |
41 | |
// Empty tag type. Passing it selects the
// mapping(const extents_type&, not_extents_constructible_tag) constructor of
// layout_wrapping_integral<...>::mapping.
struct not_extents_constructible_tag {};

// Layout policy whose mapping reduces every index modulo WrapModulus before
// linearizing it. Only the nested mapping template is declared here; its
// definition is provided out of line.
template <size_t WrapModulus>
class layout_wrapping_integral {
public:
  template <class Extents>
  class mapping;
};
50 | |
// Definition of layout_wrapping_integral<WrapArg>::mapping.
//
// Each index is first reduced modulo Wrap. The wrapped index tuple is then
// linearized in layout_left order (dimension 0 varies fastest), using
// min(extent(r), Wrap) as the effective extent of every dimension.
template <size_t WrapArg>
template <class Extents>
class layout_wrapping_integral<WrapArg>::mapping {
  // The wrap factor expressed in this mapping's index type.
  static constexpr typename Extents::index_type Wrap = static_cast<typename Extents::index_type>(WrapArg);

public:
  using extents_type = Extents;
  using index_type   = typename extents_type::index_type;
  using size_type    = typename extents_type::size_type;
  using rank_type    = typename extents_type::rank_type;
  // Spelled with the original size_t argument so it names the enclosing policy
  // exactly. Going through the index_type-typed Wrap could fail to compile
  // (a negative value converted to size_t is narrowing in a template argument)
  // or name a different policy when WrapArg is not representable in index_type.
  using layout_type = layout_wrapping_integral<WrapArg>;

private:
  // Effective extent of dimension r. Indices wrap at Wrap, so no dimension
  // spans more than Wrap positions.
  static constexpr index_type wrapped_extent(const extents_type& ext, rank_type r) noexcept {
    return ext.extent(r) < Wrap ? ext.extent(r) : Wrap;
  }

  // True if the product of the wrapped extents (the value required_span_size()
  // computes) fits in index_type. No member calls it at present.
  static constexpr bool required_span_size_is_representable(const extents_type& ext) {
    if constexpr (extents_type::rank() == 0) {
      return true;
    } else {
      // Seed with the wrapped first extent as well, so every factor matches
      // required_span_size(). The else-branch also keeps extent(0) from being
      // instantiated for rank-0 extents.
      index_type prod = wrapped_extent(ext, 0);
      for (rank_type r = 1; r < extents_type::rank(); r++) {
        if (__builtin_mul_overflow(prod, wrapped_extent(ext, r), &prod))
          return false;
      }
      return true;
    }
  }

  // Builds extents_type from another extents object by copying the values of
  // this type's dynamic extents. It deliberately does not check that the other
  // object's values agree with this type's static extents.
  template <class OtherExtents>
  static constexpr extents_type copy_dynamic_extents(const OtherExtents& other_ext) noexcept {
    std::array<index_type, extents_type::rank_dynamic()> dyn_extents;
    rank_type count = 0;
    for (rank_type r = 0; r < extents_type::rank(); r++) {
      if (extents_type::static_extent(r) == std::dynamic_extent)
        dyn_extents[count++] = other_ext.extent(r);
    }
    return extents_type(dyn_extents);
  }

  // Shared predicate behind is_unique() and is_strided().
  constexpr bool no_extent_exceeds_wrap() const noexcept {
    for (rank_type r = 0; r < extents_type::rank(); r++) {
      if (extents_.extent(r) > Wrap)
        return false;
    }
    return true;
  }

public:
  // Not default constructible.
  constexpr mapping() noexcept = delete;
  // User-provided, so the mapping is not trivially copyable.
  constexpr mapping(const mapping& other) noexcept : extents_(other.extents()) {}
  // Extents-only constructor: available only for Wrap == 8, and only from rvalues.
  constexpr mapping(extents_type&& ext) noexcept
    requires(Wrap == 8)
      : extents_(ext) {}
  // General constructor from extents, selected by the tag argument.
  constexpr mapping(const extents_type& ext, not_extents_constructible_tag) noexcept : extents_(ext) {}

  // Converting constructor from a const lvalue (only when Wrap != 8).
  // It is explicit unless OtherExtents converts implicitly to extents_type.
  template <class OtherExtents>
    requires(std::is_constructible_v<extents_type, OtherExtents> && (Wrap != 8))
  constexpr explicit(!std::is_convertible_v<OtherExtents, extents_type>)
      mapping(const mapping<OtherExtents>& other) noexcept
      : extents_(copy_dynamic_extents(other.extents())) {}

  // Converting constructor that binds only rvalues (only when Wrap == 8).
  template <class OtherExtents>
    requires(std::is_constructible_v<extents_type, OtherExtents> && (Wrap == 8))
  constexpr explicit(!std::is_convertible_v<OtherExtents, extents_type>)
      mapping(mapping<OtherExtents>&& other) noexcept
      : extents_(copy_dynamic_extents(other.extents())) {}

  constexpr mapping& operator=(const mapping& other) noexcept {
    extents_ = other.extents_;
    return *this;
  }

  constexpr const extents_type& extents() const noexcept { return extents_; }

  // Product of the wrapped extents.
  constexpr index_type required_span_size() const noexcept {
    index_type size = 1;
    for (rank_type r = 0; r < extents_type::rank(); r++)
      size *= wrapped_extent(extents_, r);
    return size;
  }

  // Maps an index tuple to an offset. Each index is reduced modulo Wrap, then
  // the tuple is folded with Horner's scheme from the last dimension to the
  // first, so dimension 0 varies fastest (layout_left order).
  template <std::integral... Indices>
    requires((sizeof...(Indices) == extents_type::rank()) && (std::is_convertible_v<Indices, index_type> && ...) &&
             (std::is_nothrow_constructible_v<index_type, Indices> && ...))
  constexpr index_type operator()(Indices... idx) const noexcept {
    std::array<index_type, extents_type::rank()> idx_a{static_cast<index_type>(static_cast<index_type>(idx) % Wrap)...};
    return [&]<size_t... Pos>(std::index_sequence<Pos...>) {
      index_type res = 0;
      ((res = idx_a[extents_type::rank() - 1 - Pos] + wrapped_extent(extents_, extents_type::rank() - 1 - Pos) * res),
       ...);
      return res;
    }(std::make_index_sequence<sizeof...(Indices)>());
  }

  static constexpr bool is_always_unique() noexcept { return false; }
  static constexpr bool is_always_exhaustive() noexcept { return true; }
  static constexpr bool is_always_strided() noexcept { return false; }

  // Both return true exactly when no extent is larger than Wrap.
  constexpr bool is_unique() const noexcept { return no_extent_exceeds_wrap(); }
  static constexpr bool is_exhaustive() noexcept { return true; }
  constexpr bool is_strided() const noexcept { return no_extent_exceeds_wrap(); }

  // Offset difference between index tuples that differ by one in dimension r.
  // This is the product of the wrapped extents of dimensions 0..r-1. It must
  // mirror operator(), where dimension 0 varies fastest; a product over
  // dimensions r+1..rank-1 would describe layout_right instead. The result is
  // only meaningful when is_strided() is true.
  constexpr index_type stride(rank_type r) const noexcept
    requires(extents_type::rank() > 0)
  {
    index_type s = 1;
    for (rank_type i = 0; i < r; i++)
      s *= wrapped_extent(extents_, i);
    return s;
  }

  template <class OtherExtents>
    requires(OtherExtents::rank() == extents_type::rank())
  friend constexpr bool operator==(const mapping& lhs, const mapping<OtherExtents>& rhs) noexcept {
    return lhs.extents() == rhs.extents();
  }

  // Swap found by ADL. Outside constant evaluation it bumps swap_counter(), so
  // tests can observe that this overload was used.
  friend constexpr void swap(mapping& x, mapping& y) noexcept {
    swap(x.extents_, y.extents_);
    if (!std::is_constant_evaluated()) {
      swap_counter()++;
    }
  }

  static int& swap_counter() {
    static int value = 0;
    return value;
  }

private:
  extents_type extents_{};
};
192 | |
// Builds the std::layout_left mapping over `exts`. The policy argument only
// selects this overload.
template <class Extents>
constexpr auto construct_mapping(std::layout_left, Extents exts) {
  using result_mapping = std::layout_left::mapping<Extents>;
  return result_mapping(exts);
}
197 | |
// Builds the std::layout_right mapping over `exts`. The policy argument only
// selects this overload.
template <class Extents>
constexpr auto construct_mapping(std::layout_right, Extents exts) {
  using result_mapping = std::layout_right::mapping<Extents>;
  return result_mapping(exts);
}
202 | |
203 | template <size_t Wraps, class Extents> |
204 | constexpr auto construct_mapping(layout_wrapping_integral<Wraps>, Extents exts) { |
205 | return typename layout_wrapping_integral<Wraps>::template mapping<Extents>(exts, not_extents_constructible_tag{}); |
206 | } |
207 | |
// Layout policy whose mapping's converting constructor performs no check that
// the source extents are convertible. This makes it possible to trigger
// mdspan's constructor static assertion on extents convertibility. Its mapping
// also accepts runtime offset and scaling arguments, which permits negative
// offsets and strides when index_type is signed.
class always_convertible_layout {
public:
  // Defined out of line.
  template <class Extents>
  class mapping;
};
216 | |
// Definition of always_convertible_layout::mapping.
//
// Maps an index tuple to offset_ + scaling_ * L(i...), where L is the
// layout_left linearization (dimension 0 varies fastest) over the extents.
template <class Extents>
class always_convertible_layout::mapping {
public:
  using extents_type = Extents;
  using index_type   = typename extents_type::index_type;
  using size_type    = typename extents_type::size_type;
  using rank_type    = typename extents_type::rank_type;
  using layout_type  = always_convertible_layout;

private:
  // True if the product of all extents fits in index_type. offset_ and
  // scaling_ are not taken into account. No member calls it at present.
  static constexpr bool required_span_size_is_representable(const extents_type& ext) {
    if constexpr (extents_type::rank() == 0) {
      return true;
    } else {
      // The else-branch keeps extent(0) from being instantiated for rank 0.
      index_type prod = ext.extent(0);
      for (rank_type r = 1; r < extents_type::rank(); r++) {
        if (__builtin_mul_overflow(prod, ext.extent(r), &prod))
          return false;
      }
      return true;
    }
  }

  // Compares offset_ and scaling_ of two mappings whose index types may differ.
  // This is a member, not inline code in operator==: the
  // `template <class> friend class mapping;` declaration grants access to
  // other specializations' private data only to members of mapping<...>. A
  // friend function of one specialization is not a friend of the others.
  // std::cmp_equal compares by value even when one index type is signed and
  // the other unsigned.
  template <class OtherExtents>
  static constexpr bool same_offset_and_scaling(const mapping& lhs, const mapping<OtherExtents>& rhs) noexcept {
    return std::cmp_equal(lhs.offset_, rhs.offset_) && std::cmp_equal(lhs.scaling_, rhs.scaling_);
  }

public:
  // Not default constructible.
  constexpr mapping() noexcept = delete;
  constexpr mapping(const mapping& other) noexcept
      : extents_(other.extents_), offset_(other.offset_), scaling_(other.scaling_) {}
  // Implicit from extents. The offset and scaling are taken as given, with no
  // validation.
  constexpr mapping(const extents_type& ext, index_type offset = 0, index_type scaling = 1) noexcept
      : extents_(ext), offset_(offset), scaling_(scaling) {}

  // Deliberately unconstrained: no convertibility or rank check. When the
  // ranks match, the values of this type's dynamic extents are copied from
  // `other`; otherwise extents_ is default-constructed. offset_ and scaling_
  // are always copied.
  template <class OtherExtents>
  constexpr mapping(const mapping<OtherExtents>& other) noexcept {
    if constexpr (extents_type::rank() == OtherExtents::rank()) {
      std::array<index_type, extents_type::rank_dynamic()> dyn_extents;
      rank_type count = 0;
      for (rank_type r = 0; r < extents_type::rank(); r++) {
        if (extents_type::static_extent(r) == std::dynamic_extent) {
          dyn_extents[count++] = other.extents().extent(r);
        }
      }
      extents_ = extents_type(dyn_extents);
    } else {
      extents_ = extents_type();
    }
    offset_  = other.offset_;
    scaling_ = other.scaling_;
  }

  constexpr mapping& operator=(const mapping& other) noexcept {
    extents_ = other.extents_;
    offset_  = other.offset_;
    scaling_ = other.scaling_;
    return *this;
  }

  constexpr const extents_type& extents() const noexcept { return extents_; }

  // Returns max(offset_ + scaling_ * product(extents), offset_).
  // NOTE(review): this is not the standard "1 + largest m(i...)" value (nor 0
  // for an empty index space). It is kept as-is because tests may depend on
  // these exact values. The explicit template argument lets std::max deduce
  // for index types narrower than int, where the first argument is promoted.
  constexpr index_type required_span_size() const noexcept {
    index_type size = 1;
    for (size_t r = 0; r < extents_type::rank(); r++)
      size *= extents_.extent(r);
    return std::max<index_type>(size * scaling_ + offset_, offset_);
  }

  // offset_ + scaling_ * (layout_left linearization of idx...), evaluated with
  // Horner's scheme from the last dimension to the first.
  template <std::integral... Indices>
    requires((sizeof...(Indices) == extents_type::rank()) && (std::is_convertible_v<Indices, index_type> && ...) &&
             (std::is_nothrow_constructible_v<index_type, Indices> && ...))
  constexpr index_type operator()(Indices... idx) const noexcept {
    std::array<index_type, extents_type::rank()> idx_a{static_cast<index_type>(idx)...};
    return offset_ +
           scaling_ * ([&]<size_t... Pos>(std::index_sequence<Pos...>) {
             index_type res = 0;
             ((res = idx_a[extents_type::rank() - 1 - Pos] + extents_.extent(extents_type::rank() - 1 - Pos) * res),
              ...);
             return res;
           }(std::make_index_sequence<sizeof...(Indices)>()));
  }

  static constexpr bool is_always_unique() noexcept { return true; }
  static constexpr bool is_always_exhaustive() noexcept { return true; }
  static constexpr bool is_always_strided() noexcept { return true; }

  static constexpr bool is_unique() noexcept { return true; }
  static constexpr bool is_exhaustive() noexcept { return true; }
  static constexpr bool is_strided() noexcept { return true; }

  // Product of the extents of dimensions 0..r-1, times scaling_.
  constexpr index_type stride(rank_type r) const noexcept
    requires(extents_type::rank() > 0)
  {
    index_type s = 1;
    for (rank_type i = 0; i < r; i++)
      s *= extents_.extent(i);
    return s * scaling_;
  }

  // Equal when extents, offset and scaling all compare equal.
  template <class OtherExtents>
    requires(OtherExtents::rank() == extents_type::rank())
  friend constexpr bool operator==(const mapping& lhs, const mapping<OtherExtents>& rhs) noexcept {
    return lhs.extents() == rhs.extents() && mapping::same_offset_and_scaling(lhs, rhs);
  }

  // Swaps the complete state (extents, offset and scaling). Outside constant
  // evaluation it bumps swap_counter(), so tests can observe that this ADL
  // overload was used.
  friend constexpr void swap(mapping& x, mapping& y) noexcept {
    using std::swap;
    swap(x.extents_, y.extents_);
    swap(x.offset_, y.offset_);
    swap(x.scaling_, y.scaling_);
    if (!std::is_constant_evaluated()) {
      swap_counter()++;
    }
  }

  static int& swap_counter() {
    static int value = 0;
    return value;
  }

private:
  // Lets members of every mapping<OtherExtents> (the converting constructor
  // and same_offset_and_scaling) read this specialization's private data.
  template <class>
  friend class mapping;

  extents_type extents_{};
  index_type offset_{};
  index_type scaling_{};
};
338 | #endif // TEST_STD_CONTAINERS_VIEWS_MDSPAN_CUSTOM_TEST_LAYOUTS_H |
339 | |