1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9// UNSUPPORTED: c++03, c++11, c++14, c++17
10
11// <algorithm>
12
13// template<random_access_iterator I, sentinel_for<I> S, class Comp = ranges::less,
14// class Proj = identity>
15// requires sortable<I, Comp, Proj>
16// constexpr I
17// ranges::nth_element(I first, I nth, S last, Comp comp = {}, Proj proj = {}); // since C++20
18//
19// template<random_access_range R, class Comp = ranges::less, class Proj = identity>
20// requires sortable<iterator_t<R>, Comp, Proj>
21// constexpr borrowed_iterator_t<R>
22// ranges::nth_element(R&& r, iterator_t<R> nth, Comp comp = {}, Proj proj = {}); // since C++20
23
24#include <algorithm>
25#include <array>
26#include <concepts>
27#include <functional>
28#include <iterator>
29#include <optional>
30#include <ranges>
31
32#include "almost_satisfies_types.h"
33#include "test_iterators.h"
34
35// SFINAE tests.
36
37using BadComparator = ComparatorNotCopyable<int*>;
38static_assert(!std::sortable<int*, BadComparator>);
39
40template <class Iter, class Sent = sentinel_wrapper<Iter>, class Comp = std::ranges::less>
41concept HasNthElementIt = requires(Iter first, Iter nth, Sent last, Comp comp) {
42 std::ranges::nth_element(first, nth, last, comp);
43};
44
45static_assert(HasNthElementIt<int*>);
46static_assert(!HasNthElementIt<RandomAccessIteratorNotDerivedFrom>);
47static_assert(!HasNthElementIt<RandomAccessIteratorBadIndex>);
48static_assert(!HasNthElementIt<int*, SentinelForNotSemiregular>);
49static_assert(!HasNthElementIt<int*, SentinelForNotWeaklyEqualityComparableWith>);
50static_assert(!HasNthElementIt<int*, int*, BadComparator>);
51static_assert(!HasNthElementIt<const int*>); // Doesn't satisfy `sortable`.
52
53template <class Range, class Comp = std::ranges::less>
54concept HasNthElementR = requires(Range range, std::ranges::iterator_t<Range> nth, Comp comp) {
55 std::ranges::nth_element(range, nth, comp);
56};
57
58static_assert(HasNthElementR<UncheckedRange<int*>>);
59static_assert(!HasNthElementR<RandomAccessRangeNotDerivedFrom>);
60static_assert(!HasNthElementR<RandomAccessRangeBadIndex>);
61static_assert(!HasNthElementR<UncheckedRange<int*, SentinelForNotSemiregular>>);
62static_assert(!HasNthElementR<UncheckedRange<int*, SentinelForNotWeaklyEqualityComparableWith>>);
63static_assert(!HasNthElementR<UncheckedRange<int*>, BadComparator>);
64static_assert(!HasNthElementR<UncheckedRange<const int*>>); // Doesn't satisfy `sortable`.
65
66template <std::size_t N, class T, class Iter>
67constexpr void verify_nth(const std::array<T, N>& partially_sorted, std::size_t nth_index, Iter last, T expected_nth) {
68 // Note that the exact output of `nth_element` is unspecified and may vary between implementations.
69
70 assert(base(last) == partially_sorted.data() + partially_sorted.size());
71
72 auto b = partially_sorted.begin();
73 auto nth = b + nth_index;
74 auto e = partially_sorted.end();
75 if (nth == e)
76 return;
77
78 assert(*nth == expected_nth);
79
80 // All elements on the left are <= nth.
81 assert(std::all_of(b, nth, [&](const auto& v) { return v <= *nth; }));
82 // All elements on the right are >= nth.
83 assert(std::all_of(nth, e, [&](const auto& v) { return v >= *nth; }));
84
85 {
86 auto sorted = partially_sorted;
87 std::ranges::sort(sorted);
88
89 // The element at index `n` is the same as if the range were fully sorted.
90 assert(sorted[nth_index] == *nth);
91 }
92}
93
94template <class Iter, class Sent, std::size_t N>
95constexpr void test_one(std::array<int, N> input, std::size_t nth_index, std::optional<int> expected_nth = {}) {
96 assert(expected_nth || nth_index == N);
97
98 { // (iterator, sentinel) overload.
99 auto partially_sorted = input;
100 auto b = Iter(partially_sorted.data());
101 auto nth = b + nth_index;
102 auto e = Sent(Iter(partially_sorted.data() + partially_sorted.size()));
103
104 std::same_as<Iter> decltype(auto) last = std::ranges::nth_element(b, nth, e);
105 if (nth_index != N) {
106 verify_nth(partially_sorted, nth_index, last, *expected_nth);
107 } else {
108 assert(partially_sorted == input);
109 }
110 }
111
112 { // (range) overload.
113 auto partially_sorted = input;
114 auto b = Iter(partially_sorted.data());
115 auto nth = b + nth_index;
116 auto e = Sent(Iter(partially_sorted.data() + partially_sorted.size()));
117 auto range = std::ranges::subrange(b, e);
118
119 std::same_as<Iter> decltype(auto) last = std::ranges::nth_element(range, nth);
120 if (nth_index != N) {
121 verify_nth(partially_sorted, nth_index, last, *expected_nth);
122 } else {
123 assert(partially_sorted == input);
124 }
125 }
126}
127
128template <class Iter, class Sent, std::size_t N>
129constexpr void test_all_cases(std::array<int, N> input) {
130 auto sorted = input;
131 std::sort(sorted.begin(), sorted.end());
132
133 for (int n = 0; n != N; ++n) {
134 test_one<Iter, Sent, N>(input, n, sorted[n]);
135 }
136 test_one<Iter, Sent, N>(input, N);
137}
138
139constexpr void test_iterators() {
140 auto check = []<class Iter, class Sent> {
141 // Empty sequence.
142 test_one<Iter, Sent, 0>({}, 0);
143
144 // 1-element sequence.
145 test_all_cases<Iter, Sent>(std::array{1});
146
147 // 2-element sequence.
148 test_all_cases<Iter, Sent>(std::array{2, 1});
149
150 // 3-element sequence.
151 test_all_cases<Iter, Sent>(std::array{2, 1, 3});
152
153 // Longer sequence.
154 test_all_cases<Iter, Sent>(std::array{2, 1, 3, 6, 8, 4, 11, 5});
155
156 // Longer sequence with duplicates.
157 test_all_cases<Iter, Sent>(std::array{2, 1, 3, 6, 2, 8, 6});
158
159 // All elements are the same.
160 test_all_cases<Iter, Sent>(std::array{1, 1, 1, 1});
161
162 { // nth element is in the right place.
163 std::array input = {6, 5, 3, 1, 4, 2};
164 constexpr std::size_t N = input.size();
165 test_one<Iter, Sent, N>(input, 2, /*expected_nth=*/3);
166 }
167
168 // Already sorted.
169 test_all_cases<Iter, Sent>(std::array{1, 2, 3, 4, 5, 6});
170
171 // Descending.
172 test_all_cases<Iter, Sent>(std::array{6, 5, 4, 3, 2, 1});
173
174 // Repeating pattern.
175 test_all_cases<Iter, Sent>(std::array{2, 1, 2, 1, 2, 1});
176 };
177
178 check.operator()<random_access_iterator<int*>, random_access_iterator<int*>>();
179 check.operator()<random_access_iterator<int*>, sentinel_wrapper<random_access_iterator<int*>>>();
180 check.operator()<contiguous_iterator<int*>, contiguous_iterator<int*>>();
181 check.operator()<contiguous_iterator<int*>, sentinel_wrapper<contiguous_iterator<int*>>>();
182 check.operator()<int*, int*>();
183 check.operator()<int*, sentinel_wrapper<int*>>();
184}
185
186constexpr bool test() {
187 test_iterators();
188
189 { // A custom comparator works.
190 const std::array input = {1, 2, 3, 4, 5};
191 std::ranges::greater comp;
192
193 {
194 auto in = input;
195 auto last = std::ranges::nth_element(in.begin(), in.begin() + 1, in.end(), comp);
196 assert(in[1] == 4);
197 assert(last == in.end());
198 }
199
200 {
201 auto in = input;
202 auto last = std::ranges::nth_element(in, in.begin() + 1, comp);
203 assert(in[1] == 4);
204 assert(last == in.end());
205 }
206 }
207
208 { // A custom projection works.
209 struct A {
210 int a;
211 constexpr bool operator==(const A&) const = default;
212 };
213
214 const std::array input = {A{.a: 2}, A{.a: 1}, A{.a: 3}};
215
216 {
217 auto in = input;
218 auto last = std::ranges::nth_element(in.begin(), in.begin() + 1, in.end(), {}, &A::a);
219 assert(in[1] == A{.a: 2});
220 assert(last == in.end());
221 }
222
223 {
224 auto in = input;
225 auto last = std::ranges::nth_element(in, in.begin() + 1, {}, &A::a);
226 assert(in[1] == A{.a: 2});
227 assert(last == in.end());
228 }
229 }
230
231 { // `std::invoke` is used in the implementation.
232 struct S {
233 int i;
234 constexpr S(int i_) : i(i_) {}
235
236 constexpr bool comparator(const S& rhs) const { return i < rhs.i; }
237 constexpr const S& projection() const { return *this; }
238
239 constexpr bool operator==(const S&) const = default;
240 };
241
242 const std::array input = {S{2}, S{1}, S{3}};
243
244 {
245 auto in = input;
246 auto last = std::ranges::nth_element(in.begin(), in.begin() + 1, in.end(), &S::comparator, &S::projection);
247 assert(in[1] == S{2});
248 assert(last == in.end());
249 }
250
251 {
252 auto in = input;
253 auto last = std::ranges::nth_element(in, in.begin() + 1, &S::comparator, &S::projection);
254 assert(in[1] == S{2});
255 assert(last == in.end());
256 }
257 }
258
259 { // `std::ranges::dangling` is returned.
260 std::array in{1, 2, 3};
261 [[maybe_unused]] std::same_as<std::ranges::dangling> decltype(auto) result =
262 std::ranges::nth_element(std::move(in), in.begin());
263 }
264
265 return true;
266}
267
268int main(int, char**) {
269 test();
270 static_assert(test());
271
272 return 0;
273}
274

// Source: libcxx/test/std/algorithms/alg.sorting/alg.nth.element/ranges_nth_element.pass.cpp