//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 | |
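// <algorithm>

// Test that, with _LIBCPP_DEBUG_RANDOMIZE_UNSPECIFIED_STABILITY defined,
// std::nth_element randomizes the (unspecified) arrangement of the elements it
// does not fully order, while still satisfying its postconditions.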
12 | |
13 | // UNSUPPORTED: c++03 |
14 | // ADDITIONAL_COMPILE_FLAGS: -D_LIBCPP_DEBUG_RANDOMIZE_UNSPECIFIED_STABILITY |
15 | |
16 | #include <algorithm> |
17 | #include <array> |
18 | #include <cassert> |
19 | #include <functional> |
20 | #include <iterator> |
21 | #include <vector> |
22 | |
23 | #include "test_macros.h" |
24 | |
// Element type used by every test in this file: a single int ordered through
// operator< on `value`. The comparison is constexpr so the type can also be
// used in the C++20 constant-evaluation test.
struct MyType {
  int value{0};

  constexpr bool operator<(const MyType& rhs) const {
    return value < rhs.value;
  }
};
29 | |
30 | std::vector<MyType> deterministic() { |
31 | static constexpr int kSize = 100; |
32 | std::vector<MyType> v; |
33 | v.resize(new_size: kSize); |
34 | for (int i = 0; i < kSize; ++i) { |
35 | v[i].value = (i % 2 ? i : kSize / 2 + i); |
36 | } |
37 | std::__nth_element<std::_ClassicAlgPolicy>(v.begin(), v.begin() + kSize / 2, v.end(), std::less<MyType>()); |
38 | return v; |
39 | } |
40 | |
41 | void test_randomization() { |
42 | static constexpr int kSize = 100; |
43 | std::vector<MyType> v; |
44 | v.resize(new_size: kSize); |
45 | for (int i = 0; i < kSize; ++i) { |
46 | v[i].value = (i % 2 ? i : kSize / 2 + i); |
47 | } |
48 | auto deterministic_v = deterministic(); |
49 | std::nth_element(first: v.begin(), nth: v.begin() + kSize / 2, last: v.end()); |
50 | bool all_equal = true; |
51 | for (int i = 0; i < kSize; ++i) { |
52 | if (v[i].value != deterministic_v[i].value) { |
53 | all_equal = false; |
54 | } |
55 | } |
56 | assert(!all_equal); |
57 | } |
58 | |
59 | void test_same() { |
60 | static constexpr int kSize = 100; |
61 | std::vector<MyType> v; |
62 | v.resize(new_size: kSize); |
63 | for (int i = 0; i < kSize; ++i) { |
64 | v[i].value = (i % 2 ? i : kSize / 2 + i); |
65 | } |
66 | auto snapshot_v = v; |
67 | auto snapshot_custom_v = v; |
68 | std::nth_element(first: v.begin(), nth: v.begin() + kSize / 2, last: v.end()); |
69 | std::nth_element(first: snapshot_v.begin(), nth: snapshot_v.begin() + kSize / 2, last: snapshot_v.end()); |
70 | std::nth_element(first: snapshot_custom_v.begin(), nth: snapshot_custom_v.begin() + kSize / 2, last: snapshot_custom_v.end(), comp: std::less<MyType>()); |
71 | bool all_equal = true; |
72 | for (int i = 0; i < kSize; ++i) { |
73 | if (v[i].value != snapshot_v[i].value || v[i].value != snapshot_custom_v[i].value) { |
74 | all_equal = false; |
75 | } |
76 | if (i < kSize / 2) { |
77 | assert(v[i].value <= v[kSize / 2].value); |
78 | } |
79 | } |
80 | assert(all_equal); |
81 | } |
82 | |
#if TEST_STD_VER > 17
// Checks that std::nth_element works in constant evaluation (constexpr since
// C++20) and still meets its postconditions there.
constexpr bool test_constexpr() {
  std::array<MyType, 10> v;
  // Fill with 9, 8, ..., 0: strictly decreasing, pairwise distinct values.
  for (int i = 9; i >= 0; --i) {
    v[9 - i].value = i;
  }
  std::nth_element(v.begin(), v.begin() + 5, v.end());
  // With values 0..9, index 5 of the sorted order holds 5. Checking only the
  // partition would also accept a wrong pivot such as 9, because every element
  // satisfies `m.value <= 9`.
  return v[5].value == 5 &&
         std::is_partitioned(v.begin(), v.end(), [&](const MyType& m) { return m.value <= v[5].value; });
}
#endif
93 | |
94 | int main(int, char**) { |
95 | test_randomization(); |
96 | test_same(); |
97 | #if TEST_STD_VER > 17 |
98 | static_assert(test_constexpr(), "" ); |
99 | #endif |
100 | return 0; |
101 | } |
102 | |