1 | //===----------------------------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | // <algorithm> |
10 | |
11 | // UNSUPPORTED: c++03, c++11, c++14, c++17 |
12 | |
13 | // template<bidirectional_iterator I, sentinel_for<I> S, class Comp = ranges::less, |
14 | // class Proj = identity> |
15 | // requires sortable<I, Comp, Proj> |
16 | // constexpr ranges::next_permutation_result<I> |
17 | // ranges::next_permutation(I first, S last, Comp comp = {}, Proj proj = {}); |
18 | // template<bidirectional_range R, class Comp = ranges::less, |
19 | // class Proj = identity> |
20 | // requires sortable<iterator_t<R>, Comp, Proj> |
21 | // constexpr ranges::next_permutation_result<borrowed_iterator_t<R>> |
22 | // ranges::next_permutation(R&& r, Comp comp = {}, Proj proj = {}); |
23 | |
24 | #include <algorithm> |
25 | #include <array> |
26 | #include <cassert> |
27 | #include <cstddef> |
28 | #include <ranges> |
29 | |
30 | #include "almost_satisfies_types.h" |
31 | #include "test_iterators.h" |
32 | |
33 | template <class Iter, class Sent = sentinel_wrapper<Iter>> |
34 | concept HasNextPermutationIt = requires(Iter first, Sent last) { std::ranges::next_permutation(first, last); }; |
35 | |
36 | static_assert(HasNextPermutationIt<int*>); |
37 | static_assert(!HasNextPermutationIt<BidirectionalIteratorNotDerivedFrom>); |
38 | static_assert(!HasNextPermutationIt<BidirectionalIteratorNotDecrementable>); |
39 | static_assert(!HasNextPermutationIt<int*, SentinelForNotSemiregular>); |
40 | static_assert(!HasNextPermutationIt<int*, SentinelForNotWeaklyEqualityComparableWith>); |
41 | static_assert(!HasNextPermutationIt<const int*>); // not sortable |
42 | |
43 | template <class Range> |
44 | concept HasNextPermutationR = requires(Range range) { std::ranges::next_permutation(range); }; |
45 | |
46 | static_assert(HasNextPermutationR<UncheckedRange<int*>>); |
47 | static_assert(!HasNextPermutationR<BidirectionalRangeNotDerivedFrom>); |
48 | static_assert(!HasNextPermutationR<BidirectionalRangeNotDecrementable>); |
49 | static_assert(!HasNextPermutationR<BidirectionalRangeNotSentinelSemiregular>); |
50 | static_assert(!HasNextPermutationR<BidirectionalRangeNotSentinelWeaklyEqualityComparableWith>); |
51 | static_assert(!HasNextPermutationR<UncheckedRange<const int*>>); // not sortable |
52 | |
// Returns `i!`, i.e. the number of distinct permutations of `i` distinct elements.
//
// The product is accumulated in `std::size_t`, so the result is exact as long as it fits in that type
// (`i <= 20` for a 64-bit `std::size_t`, `i <= 12` for a 32-bit one). Unlike a fixed lookup table, there is no
// out-of-bounds access for larger arguments.
constexpr std::size_t factorial(std::size_t i) {
  std::size_t result = 1;
  for (std::size_t k = 2; k <= i; ++k)
    result *= k;
  return result;
}
57 | |
58 | template <class Iter, class Range, class Func> |
59 | constexpr bool run_next_permutation(Func call_next_permutation, Range permuted, Range previous) { |
60 | using Result = std::ranges::next_permutation_result<Iter>; |
61 | |
62 | std::same_as<Result> decltype(auto) ret = call_next_permutation(permuted); |
63 | assert(ret.in == permuted.end()); |
64 | bool next_found = ret.found; |
65 | |
66 | if (std::ranges::distance(permuted) > 1) { |
67 | if (next_found) { |
68 | assert(std::ranges::lexicographical_compare(previous, permuted)); |
69 | } else { |
70 | assert(std::ranges::lexicographical_compare(permuted, previous)); |
71 | assert(std::ranges::is_sorted(permuted)); |
72 | } |
73 | } |
74 | |
75 | return next_found; |
76 | } |
77 | |
78 | template <class Iter, class Sent, class Func> |
79 | constexpr void test_next_permutations(Func call_next_permutation) { |
80 | std::array input = {1, 2, 3, 4}; |
81 | auto current_permutation = input; |
82 | auto previous_permutation = current_permutation; |
83 | |
84 | // For all subarrays of `input` from `[0, 0]` to `[0, N - 1]`, call `next_permutation` until no next permutation |
85 | // exists. |
86 | // The number of permutations must equal `N!`. `run_next_permutation` checks that each next permutation is |
87 | // lexicographically greater than the previous. If these two conditions hold (the number of permutations is `N!`, and |
88 | // each permutation is lexicographically greater than the previous one), it follows that the |
89 | // `ranges::next_permutation` algorithm works correctly. |
90 | for (std::size_t i = 0; i <= current_permutation.size(); ++i) { |
91 | std::size_t count = 0; |
92 | bool next_found = true; |
93 | |
94 | while (next_found) { |
95 | ++count; |
96 | previous_permutation = current_permutation; |
97 | |
98 | auto current_subrange = std::ranges::subrange( |
99 | Iter(current_permutation.data()), Sent(Iter(current_permutation.data() + i))); |
100 | auto previous_subrange = std::ranges::subrange( |
101 | Iter(previous_permutation.data()), Sent(Iter(previous_permutation.data() + i))); |
102 | |
103 | next_found = run_next_permutation<Iter>(call_next_permutation, current_subrange, previous_subrange); |
104 | } |
105 | |
106 | assert(count == factorial(i)); |
107 | } |
108 | } |
109 | |
110 | template <class Iter, class Sent> |
111 | constexpr void test_all_permutations() { |
112 | test_next_permutations<Iter, Sent>([](auto&& range) { |
113 | return std::ranges::next_permutation(range.begin(), range.end()); |
114 | }); |
115 | |
116 | test_next_permutations<Iter, Sent>([](auto&& range) { |
117 | return std::ranges::next_permutation(range); |
118 | }); |
119 | } |
120 | |
121 | template <class Iter, class Sent, int N> |
122 | constexpr void test_one(const std::array<int, N> input, bool expected_found, std::array<int, N> expected) { |
123 | using Result = std::ranges::next_permutation_result<Iter>; |
124 | |
125 | { // (iterator, sentinel) overload. |
126 | auto in = input; |
127 | auto begin = Iter(in.data()); |
128 | auto end = Sent(Iter(in.data() + in.size())); |
129 | |
130 | std::same_as<Result> decltype(auto) result = std::ranges::next_permutation(begin, end); |
131 | assert(result.found == expected_found); |
132 | assert(result.in == end); |
133 | assert(in == expected); |
134 | } |
135 | |
136 | { // (range) overload. |
137 | auto in = input; |
138 | auto begin = Iter(in.data()); |
139 | auto end = Sent(Iter(in.data() + in.size())); |
140 | auto range = std::ranges::subrange(begin, end); |
141 | |
142 | std::same_as<Result> decltype(auto) result = std::ranges::next_permutation(range); |
143 | assert(result.found == expected_found); |
144 | assert(result.in == end); |
145 | assert(in == expected); |
146 | } |
147 | } |
148 | |
149 | template <class Iter, class Sent> |
150 | constexpr void test_iter_sent() { |
151 | test_all_permutations<Iter, Sent>(); |
152 | |
153 | // Empty range. |
154 | test_one<Iter, Sent, 0>({}, false, {}); |
155 | // 1-element range. |
156 | test_one<Iter, Sent, 1>({1}, false, {1}); |
157 | // 2-element range. |
158 | test_one<Iter, Sent, 2>({1, 2}, true, {2, 1}); |
159 | test_one<Iter, Sent, 2>({2, 1}, false, {1, 2}); |
160 | // Longer sequence. |
161 | test_one<Iter, Sent, 8>({1, 2, 3, 4, 5, 6, 7, 8}, true, {1, 2, 3, 4, 5, 6, 8, 7}); |
162 | // Longer sequence, permutations exhausted. |
163 | test_one<Iter, Sent, 8>({8, 7, 6, 5, 4, 3, 2, 1}, false, {1, 2, 3, 4, 5, 6, 7, 8}); |
164 | } |
165 | |
166 | template <class Iter> |
167 | constexpr void test_iter() { |
168 | test_iter_sent<Iter, Iter>(); |
169 | test_iter_sent<Iter, sentinel_wrapper<Iter>>(); |
170 | test_iter_sent<Iter, sized_sentinel<Iter>>(); |
171 | } |
172 | |
173 | constexpr void test_iterators() { |
174 | test_iter<bidirectional_iterator<int*>>(); |
175 | test_iter<random_access_iterator<int*>>(); |
176 | test_iter<contiguous_iterator<int*>>(); |
177 | test_iter<int*>(); |
178 | } |
179 | |
180 | constexpr bool test() { |
181 | test_iterators(); |
182 | |
183 | { // A custom predicate works. |
184 | struct A { |
185 | int i; |
186 | constexpr bool comp(const A& rhs) const { return i > rhs.i; } |
187 | constexpr bool operator==(const A&) const = default; |
188 | }; |
189 | const std::array input = {A{.i: 1}, A{.i: 2}, A{.i: 3}, A{.i: 4}, A{.i: 5}}; |
190 | std::array expected = {A{.i: 5}, A{.i: 4}, A{.i: 3}, A{.i: 2}, A{.i: 1}}; |
191 | |
192 | { // (iterator, sentinel) overload. |
193 | auto in = input; |
194 | auto result = std::ranges::next_permutation(in.begin(), in.end(), &A::comp); |
195 | |
196 | assert(result.found == false); |
197 | assert(result.in == in.end()); |
198 | assert(in == expected); |
199 | } |
200 | |
201 | { // (range) overload. |
202 | auto in = input; |
203 | auto result = std::ranges::next_permutation(in, &A::comp); |
204 | |
205 | assert(result.found == false); |
206 | assert(result.in == in.end()); |
207 | assert(in == expected); |
208 | } |
209 | } |
210 | |
211 | { // A custom projection works. |
212 | struct A { |
213 | int i; |
214 | constexpr A negate() const { return A{.i: i * -1}; } |
215 | constexpr auto operator<=>(const A&) const = default; |
216 | }; |
217 | const std::array input = {A{.i: 1}, A{.i: 2}, A{.i: 3}, A{.i: 4}, A{.i: 5}}; |
218 | std::array expected = {A{.i: 5}, A{.i: 4}, A{.i: 3}, A{.i: 2}, A{.i: 1}}; |
219 | |
220 | { // (iterator, sentinel) overload. |
221 | auto in = input; |
222 | auto result = std::ranges::next_permutation(in.begin(), in.end(), {}, &A::negate); |
223 | |
224 | assert(result.found == false); |
225 | assert(result.in == in.end()); |
226 | assert(in == expected); |
227 | } |
228 | |
229 | { // (range) overload. |
230 | auto in = input; |
231 | auto result = std::ranges::next_permutation(in, {}, &A::negate); |
232 | |
233 | assert(result.found == false); |
234 | assert(result.in == in.end()); |
235 | assert(in == expected); |
236 | } |
237 | } |
238 | |
239 | { // Complexity: At most `(last - first) / 2` swaps. |
240 | const std::array input = {1, 2, 3, 4, 5, 6}; |
241 | |
242 | { // (iterator, sentinel) overload. |
243 | auto in = input; |
244 | int swaps_count = 0; |
245 | auto begin = adl::Iterator::TrackSwaps(in.data(), swaps_count); |
246 | auto end = adl::Iterator::TrackSwaps(in.data() + in.size(), swaps_count); |
247 | |
248 | std::ranges::next_permutation(begin, end); |
249 | assert(swaps_count <= (base(end) - base(begin) + 1) / 2); |
250 | } |
251 | |
252 | { // (range) overload. |
253 | auto in = input; |
254 | int swaps_count = 0; |
255 | auto begin = adl::Iterator::TrackSwaps(in.data(), swaps_count); |
256 | auto end = adl::Iterator::TrackSwaps(in.data() + in.size(), swaps_count); |
257 | |
258 | std::ranges::next_permutation(std::ranges::subrange(begin, end)); |
259 | assert(swaps_count <= (base(end) - base(begin) + 1) / 2); |
260 | } |
261 | } |
262 | |
263 | return true; |
264 | } |
265 | |
266 | int main(int, char**) { |
267 | test(); |
268 | static_assert(test()); |
269 | |
270 | return 0; |
271 | } |
272 | |