| 1 | // -*- C++ -*- |
| 2 | //===----------------------------------------------------------------------===// |
| 3 | // |
| 4 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 5 | // See https://llvm.org/LICENSE.txt for license information. |
| 6 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 7 | // |
| 8 | // Kokkos v. 4.0 |
| 9 | // Copyright (2022) National Technology & Engineering |
| 10 | // Solutions of Sandia, LLC (NTESS). |
| 11 | // |
| 12 | // Under the terms of Contract DE-NA0003525 with NTESS, |
| 13 | // the U.S. Government retains certain rights in this software. |
| 14 | // |
| 15 | //===---------------------------------------------------------------------===// |
| 16 | |
| 17 | #ifndef TEST_STD_CONTAINERS_VIEWS_MDSPAN_CUSTOM_TEST_LAYOUTS_H |
| 18 | #define TEST_STD_CONTAINERS_VIEWS_MDSPAN_CUSTOM_TEST_LAYOUTS_H |
| 19 | |
| 20 | #include <algorithm> |
| 21 | #include <array> |
| 22 | #include <cassert> |
| 23 | #include <cinttypes> |
| 24 | #include <concepts> |
| 25 | #include <cstddef> |
| 26 | #include <limits> |
| 27 | #include <mdspan> |
| 28 | #include <type_traits> |
| 29 | #include <utility> |
| 30 | |
// Layout that wraps indices to test some idiosyncratic behavior:
// - it is basically a layout_left where each index i is first wrapped, i.e. i % Wrap
// - only accepts integral types as indices
// - is_always_strided and is_always_unique are false
// - is_strided and is_unique are true if no extent is larger than Wrap
// - not default constructible
// - not extents constructible (except that Wrap == 8 accepts an rvalue extents_type;
//   Wrap == 8 also only converts from rvalue mappings)
// - not trivially copyable
// - does not check dynamic to static extent conversion in converting ctor
// - counts swap calls (swap_counter) so tests can check, via side effects, that
//   mdspan::swap calls the mapping's swap via ADL
| 41 | |
// Empty tag type passed as the second argument of layout_wrapping_integral's
// mapping(const extents_type&, not_extents_constructible_tag) constructor.
struct not_extents_constructible_tag {
};
| 43 | |
// Layout policy; only declares its nested mapping<Extents>, which is defined out of line below.
template <size_t WrapArg>
struct layout_wrapping_integral {
  template <class Extents>
  class mapping;
};
| 50 | |
// Mapping of layout_wrapping_integral<WrapArg>: a layout_left mapping in which every index is
// first reduced modulo Wrap and every extent is clamped to Wrap when computing offsets.
template <size_t WrapArg>
template <class Extents>
class layout_wrapping_integral<WrapArg>::mapping {
  // WrapArg converted once to index_type so that all index arithmetic stays in index_type.
  static constexpr typename Extents::index_type Wrap = static_cast<typename Extents::index_type>(WrapArg);

public:
  using extents_type = Extents;
  using index_type   = typename extents_type::index_type;
  using size_type    = typename extents_type::size_type;
  using rank_type    = typename extents_type::rank_type;
  // Spelled with WrapArg (not the index_type-converted Wrap) so that layout_type::mapping<Extents>
  // names exactly this class even if WrapArg is not representable in index_type.
  using layout_type = layout_wrapping_integral<WrapArg>;

private:
  // Extent of rank r as used by the index computation: min(ext.extent(r), Wrap).
  static constexpr index_type wrapped_extent(const extents_type& ext, rank_type r) noexcept {
    return ext.extent(r) < Wrap ? ext.extent(r) : Wrap;
  }

  // True when no extent exceeds Wrap, i.e. no two in-bounds multi-indices wrap onto the same offset.
  constexpr bool extents_within_wrap() const noexcept {
    for (rank_type r = 0; r < extents_type::rank(); r++) {
      if (extents_.extent(r) > Wrap)
        return false;
    }
    return true;
  }

  // Whether required_span_size() for a mapping over `ext` fits in index_type. Uses the same
  // Wrap-clamped extents as required_span_size() (including for rank 0).
  static constexpr bool required_span_size_is_representable(const extents_type& ext) {
    if constexpr (extents_type::rank() == 0) {
      return true;
    } else {
      index_type prod = wrapped_extent(ext, 0);
      for (rank_type r = 1; r < extents_type::rank(); r++) {
        if (__builtin_mul_overflow(prod, wrapped_extent(ext, r), &prod))
          return false;
      }
      return true;
    }
  }

  // Shared by the converting constructors: copies every dynamic extent of `other` without
  // checking it against this mapping's static extents (deliberately unchecked for testing).
  template <class OtherExtents>
  static constexpr extents_type convert_extents(const OtherExtents& other) noexcept {
    std::array<index_type, extents_type::rank_dynamic()> dyn_extents{};
    rank_type count = 0;
    for (rank_type r = 0; r < extents_type::rank(); r++) {
      if (extents_type::static_extent(r) == std::dynamic_extent)
        dyn_extents[count++] = static_cast<index_type>(other.extent(r));
    }
    return extents_type(dyn_extents);
  }

public:
  constexpr mapping() noexcept = delete;
  // User-provided on purpose so that the mapping is not trivially copyable.
  constexpr mapping(const mapping& other) noexcept : extents_(other.extents()) {}
  // Only the Wrap == 8 variant is constructible from extents, and only from an rvalue.
  constexpr mapping(extents_type&& ext) noexcept
    requires(Wrap == 8)
      : extents_(ext) {}
  constexpr mapping(const extents_type& ext, not_extents_constructible_tag) noexcept : extents_(ext) {}

  // Converting constructor from an lvalue mapping (all Wrap values except 8).
  template <class OtherExtents>
    requires(std::is_constructible_v<extents_type, OtherExtents> && (Wrap != 8))
  constexpr explicit(!std::is_convertible_v<OtherExtents, extents_type>)
      mapping(const mapping<OtherExtents>& other) noexcept
      : extents_(convert_extents(other.extents())) {}

  // Converting constructor from an rvalue mapping (Wrap == 8 only).
  template <class OtherExtents>
    requires(std::is_constructible_v<extents_type, OtherExtents> && (Wrap == 8))
  constexpr explicit(!std::is_convertible_v<OtherExtents, extents_type>)
      mapping(mapping<OtherExtents>&& other) noexcept
      : extents_(convert_extents(other.extents())) {}

  constexpr mapping& operator=(const mapping& other) noexcept {
    extents_ = other.extents_;
    return *this;
  }

  constexpr const extents_type& extents() const noexcept { return extents_; }

  // Product of the Wrap-clamped extents (1 for rank 0).
  constexpr index_type required_span_size() const noexcept {
    index_type size = 1;
    for (rank_type r = 0; r < extents_type::rank(); r++)
      size *= wrapped_extent(extents_, r);
    return size;
  }

  template <std::integral... Indices>
    requires((sizeof...(Indices) == extents_type::rank()) && (std::is_convertible_v<Indices, index_type> && ...) &&
             (std::is_nothrow_constructible_v<index_type, Indices> && ...))
  constexpr index_type operator()(Indices... idx) const noexcept {
    // Wrap every index first (for a signed index_type, a negative index gives a non-positive remainder).
    std::array<index_type, extents_type::rank()> idx_a{static_cast<index_type>(static_cast<index_type>(idx) % Wrap)...};
    // Horner evaluation starting from the last rank: idx[0] + e0 * (idx[1] + e1 * (...)),
    // i.e. layout_left ordering with every extent clamped to Wrap.
    return [&]<size_t... Pos>(std::index_sequence<Pos...>) {
      index_type res = 0;
      ((res = idx_a[extents_type::rank() - 1 - Pos] +
              wrapped_extent(extents_, extents_type::rank() - 1 - Pos) * res),
       ...);
      return res;
    }(std::make_index_sequence<sizeof...(Indices)>());
  }

  static constexpr bool is_always_unique() noexcept { return false; }
  static constexpr bool is_always_exhaustive() noexcept { return true; }
  static constexpr bool is_always_strided() noexcept { return false; }

  constexpr bool is_unique() const noexcept { return extents_within_wrap(); }
  static constexpr bool is_exhaustive() noexcept { return true; }
  constexpr bool is_strided() const noexcept { return extents_within_wrap(); }

  // layout_left stride, consistent with operator(): the product of the (Wrap-clamped) extents of
  // all ranks below r. The layout-mapping requirements only ask for this when is_strided() is true,
  // in which case no extent exceeds Wrap and the clamping is a no-op.
  constexpr index_type stride(rank_type r) const noexcept
    requires(extents_type::rank() > 0)
  {
    index_type s = 1;
    for (rank_type i = 0; i < r; i++)
      s *= wrapped_extent(extents_, i);
    return s;
  }

  template <class OtherExtents>
    requires(OtherExtents::rank() == extents_type::rank())
  friend constexpr bool operator==(const mapping& lhs, const mapping<OtherExtents>& rhs) noexcept {
    return lhs.extents() == rhs.extents();
  }

  // Found via ADL by mdspan::swap; the counter lets tests observe that it was called.
  friend constexpr void swap(mapping& x, mapping& y) noexcept {
    using std::swap;
    swap(x.extents_, y.extents_);
    if (!std::is_constant_evaluated()) {
      swap_counter()++;
    }
  }

  // Number of runtime (non-constant-evaluated) swap() calls for this mapping type.
  static int& swap_counter() {
    static int value = 0;
    return value;
  }

private:
  extents_type extents_{};
};
| 192 | |
// Builds a std::layout_left mapping over `exts`; the layout argument only selects this overload.
template <class Extents>
constexpr auto construct_mapping(std::layout_left, Extents exts) {
  using mapping_type = std::layout_left::mapping<Extents>;
  return mapping_type(exts);
}
| 197 | |
// Builds a std::layout_right mapping over `exts`; the layout argument only selects this overload.
template <class Extents>
constexpr auto construct_mapping(std::layout_right, Extents exts) {
  using mapping_type = std::layout_right::mapping<Extents>;
  return mapping_type(exts);
}
| 202 | |
// Builds a layout_wrapping_integral<Wraps> mapping over `exts`. That mapping is not constructible
// from an lvalue extents alone, so the tagged (extents, not_extents_constructible_tag) ctor is used.
template <size_t Wraps, class Extents>
constexpr auto construct_mapping(layout_wrapping_integral<Wraps>, Extents exts) {
  using mapping_type = typename layout_wrapping_integral<Wraps>::template mapping<Extents>;
  return mapping_type(exts, not_extents_constructible_tag{});
}
| 207 | |
// This layout does not check convertibility of extents in its converting constructor,
// which allows triggering mdspan's constructor static assertion on extents convertibility.
// It also allows negative strides and offsets via runtime constructor arguments (offset, scaling).
// Layout policy; only declares its nested mapping<Extents>, which is defined out of line below.
struct always_convertible_layout {
  template <class Extents>
  class mapping;
};
| 216 | |
// Mapping of always_convertible_layout: a layout_left linear index that is multiplied by scaling_
// and shifted by offset_ (both runtime values, possibly negative).
template <class Extents>
class always_convertible_layout::mapping {
public:
  using extents_type = Extents;
  using index_type   = typename extents_type::index_type;
  using size_type    = typename extents_type::size_type;
  using rank_type    = typename extents_type::rank_type;
  using layout_type  = always_convertible_layout;

private:
  // Whether the product of all extents of `ext` fits in index_type.
  static constexpr bool required_span_size_is_representable(const extents_type& ext) {
    if constexpr (extents_type::rank() == 0) {
      return true;
    } else {
      index_type prod = ext.extent(0);
      for (rank_type r = 1; r < extents_type::rank(); r++) {
        if (__builtin_mul_overflow(prod, ext.extent(r), &prod))
          return false;
      }
      return true;
    }
  }

  // Converts extents without any convertibility check: for equal ranks every dynamic extent of
  // `other` is copied unchecked; for differing ranks a default-constructed extents is returned.
  template <class OtherExtents>
  static constexpr extents_type convert_extents(const OtherExtents& other) noexcept {
    if constexpr (extents_type::rank() == OtherExtents::rank()) {
      std::array<index_type, extents_type::rank_dynamic()> dyn_extents{};
      rank_type count = 0;
      for (rank_type r = 0; r < extents_type::rank(); r++) {
        if (extents_type::static_extent(r) == std::dynamic_extent)
          dyn_extents[count++] = static_cast<index_type>(other.extent(r));
      }
      return extents_type(dyn_extents);
    } else {
      return extents_type();
    }
  }

public:
  constexpr mapping() noexcept = delete;
  constexpr mapping(const mapping& other) noexcept
      : extents_(other.extents_), offset_(other.offset_), scaling_(other.scaling_) {}
  constexpr mapping(const extents_type& ext, index_type offset = 0, index_type scaling = 1) noexcept
      : extents_(ext), offset_(offset), scaling_(scaling) {}

  // Unconstrained on purpose: accepts any mapping of this layout, whatever its extents.
  template <class OtherExtents>
  constexpr mapping(const mapping<OtherExtents>& other) noexcept
      : extents_(convert_extents(other.extents())), offset_(other.offset()), scaling_(other.scaling()) {}

  constexpr mapping& operator=(const mapping& other) noexcept {
    extents_ = other.extents_;
    offset_  = other.offset_;
    scaling_ = other.scaling_;
    return *this;
  }

  constexpr const extents_type& extents() const noexcept { return extents_; }
  // Runtime shift added to every mapped index.
  constexpr index_type offset() const noexcept { return offset_; }
  // Runtime factor applied to the layout_left linear index.
  constexpr index_type scaling() const noexcept { return scaling_; }

  // Returns max(product(extents) * scaling_ + offset_, offset_); the max keeps the result from
  // dropping below offset_ when scaling_ is negative.
  // NOTE(review): this is not always 1 + the largest value operator() can return (e.g. scaling_ != 1).
  constexpr index_type required_span_size() const noexcept {
    index_type size = 1;
    for (rank_type r = 0; r < extents_type::rank(); r++)
      size *= extents_.extent(r);
    return std::max(size * scaling_ + offset_, offset_);
  }

  template <std::integral... Indices>
    requires((sizeof...(Indices) == extents_type::rank()) && (std::is_convertible_v<Indices, index_type> && ...) &&
             (std::is_nothrow_constructible_v<index_type, Indices> && ...))
  constexpr index_type operator()(Indices... idx) const noexcept {
    std::array<index_type, extents_type::rank()> idx_a{static_cast<index_type>(idx)...};
    // offset_ + scaling_ * (idx[0] + e0 * (idx[1] + e1 * (...))), i.e. a scaled/shifted layout_left.
    return offset_ +
           scaling_ * ([&]<size_t... Pos>(std::index_sequence<Pos...>) {
             index_type res = 0;
             ((res = idx_a[extents_type::rank() - 1 - Pos] + extents_.extent(extents_type::rank() - 1 - Pos) * res),
              ...);
             return res;
           }(std::make_index_sequence<sizeof...(Indices)>()));
  }

  // Claimed unconditionally. NOTE(review): with scaling_ == 0 every index maps to offset_ (not
  // unique), and with |scaling_| > 1 offsets have gaps (not exhaustive); tests presumably do not
  // rely on these answers for such mappings.
  static constexpr bool is_always_unique() noexcept { return true; }
  static constexpr bool is_always_exhaustive() noexcept { return true; }
  static constexpr bool is_always_strided() noexcept { return true; }

  static constexpr bool is_unique() noexcept { return true; }
  static constexpr bool is_exhaustive() noexcept { return true; }
  static constexpr bool is_strided() noexcept { return true; }

  // layout_left stride times scaling_, consistent with operator().
  constexpr index_type stride(rank_type r) const noexcept
    requires(extents_type::rank() > 0)
  {
    index_type s = 1;
    for (rank_type i = 0; i < r; i++)
      s *= extents_.extent(i);
    return s * scaling_;
  }

  // Equal when extents, offset and scaling all match. Uses the public accessors (the operand is a
  // different specialization) and std::cmp_equal because the index types may differ in signedness
  // while offsets/scalings may be negative.
  template <class OtherExtents>
    requires(OtherExtents::rank() == extents_type::rank())
  friend constexpr bool operator==(const mapping& lhs, const mapping<OtherExtents>& rhs) noexcept {
    return lhs.extents() == rhs.extents() && std::cmp_equal(lhs.offset(), rhs.offset()) &&
           std::cmp_equal(lhs.scaling(), rhs.scaling());
  }

  // Exchanges the complete mapping state; found via ADL by mdspan::swap, and the counter lets
  // tests observe that it was called.
  friend constexpr void swap(mapping& x, mapping& y) noexcept {
    using std::swap;
    swap(x.extents_, y.extents_);
    swap(x.offset_, y.offset_);
    swap(x.scaling_, y.scaling_);
    if (!std::is_constant_evaluated()) {
      swap_counter()++;
    }
  }

  // Number of runtime (non-constant-evaluated) swap() calls for this mapping type.
  static int& swap_counter() {
    static int value = 0;
    return value;
  }

private:
  extents_type extents_{};
  index_type offset_{};
  index_type scaling_{};
};
| 338 | #endif // TEST_STD_CONTAINERS_VIEWS_MDSPAN_CUSTOM_TEST_LAYOUTS_H |
| 339 | |