1 | //===-- Exhaustive test template for math functions -------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "src/__support/CPP/type_traits.h" |
10 | #include "src/__support/FPUtil/FPBits.h" |
11 | #include "src/__support/macros/properties/types.h" |
12 | #include "test/UnitTest/FPMatcher.h" |
13 | #include "test/UnitTest/Test.h" |
14 | #include "utils/MPFRWrapper/MPFRUtils.h" |
15 | |
16 | #include <atomic> |
17 | #include <functional> |
18 | #include <iostream> |
19 | #include <mutex> |
20 | #include <sstream> |
21 | #include <thread> |
22 | #include <vector> |
23 | |
24 | // To test exhaustively for inputs in the range [start, stop) in parallel: |
25 | // 1. Define a Checker class with: |
26 | // - FloatType: define floating point type to be used. |
27 | // - FPBits: fputil::FPBits<FloatType>. |
28 | // - StorageType: define bit type for the corresponding floating point type. |
29 | // - uint64_t check(start, stop, rounding_mode): a method to test in given |
30 | // range for a given rounding mode, which returns the number of |
31 | // failures. |
32 | // 2. Use LlvmLibcExhaustiveMathTest<Checker> class |
// 3. Call: test_full_range(rounding, start, stop)
//    or test_full_range_all_roundings(start, stop).
//    The number of worker threads is chosen automatically.
35 | // * For single input single output math function, use the convenient template: |
36 | // LlvmLibcUnaryOpExhaustiveMathTest<FloatType, Op, Func>. |
37 | namespace mpfr = LIBC_NAMESPACE::testing::mpfr; |
38 | |
// Function type of a one-argument math routine: takes an `Arg` and returns a
// `Result`. When only one type is given, the argument type is the result type.
template <typename Result, typename Arg = Result>
using UnaryOp = Result(Arg);
41 | |
42 | template <typename OutType, typename InType, mpfr::Operation Op, |
43 | UnaryOp<OutType, InType> Func> |
44 | struct UnaryOpChecker : public virtual LIBC_NAMESPACE::testing::Test { |
45 | using FloatType = InType; |
46 | using FPBits = LIBC_NAMESPACE::fputil::FPBits<FloatType>; |
47 | using StorageType = typename FPBits::StorageType; |
48 | |
49 | // Check in a range, return the number of failures. |
50 | uint64_t check(StorageType start, StorageType stop, |
51 | mpfr::RoundingMode rounding) { |
52 | mpfr::ForceRoundingMode r(rounding); |
53 | if (!r.success) |
54 | return (stop > start); |
55 | StorageType bits = start; |
56 | uint64_t failed = 0; |
57 | do { |
58 | FPBits xbits(bits); |
59 | FloatType x = xbits.get_val(); |
60 | bool correct = |
61 | TEST_MPFR_MATCH_ROUNDING_SILENTLY(Op, x, Func(x), 0.5, rounding); |
62 | failed += (!correct); |
63 | // Uncomment to print out failed values. |
64 | if (!correct) { |
65 | EXPECT_MPFR_MATCH_ROUNDING(Op, x, Func(x), 0.5, rounding); |
66 | } |
67 | } while (bits++ < stop); |
68 | return failed; |
69 | } |
70 | }; |
71 | |
// Function type of a two-argument math routine: both arguments are `Arg` and
// the return value is `Result`. When only one type is given, the argument type
// is the result type.
template <typename Result, typename Arg = Result>
using BinaryOp = Result(Arg, Arg);
74 | |
75 | template <typename OutType, typename InType, mpfr::Operation Op, |
76 | BinaryOp<OutType, InType> Func> |
77 | struct BinaryOpChecker : public virtual LIBC_NAMESPACE::testing::Test { |
78 | using FloatType = InType; |
79 | using FPBits = LIBC_NAMESPACE::fputil::FPBits<FloatType>; |
80 | using StorageType = typename FPBits::StorageType; |
81 | |
82 | // Check in a range, return the number of failures. |
83 | uint64_t check(StorageType x_start, StorageType x_stop, StorageType y_start, |
84 | StorageType y_stop, mpfr::RoundingMode rounding) { |
85 | mpfr::ForceRoundingMode r(rounding); |
86 | if (!r.success) |
87 | return x_stop > x_start || y_stop > y_start; |
88 | StorageType xbits = x_start; |
89 | uint64_t failed = 0; |
90 | do { |
91 | FloatType x = FPBits(xbits).get_val(); |
92 | StorageType ybits = y_start; |
93 | do { |
94 | FloatType y = FPBits(ybits).get_val(); |
95 | mpfr::BinaryInput<FloatType> input{x, y}; |
96 | bool correct = TEST_MPFR_MATCH_ROUNDING_SILENTLY(Op, input, Func(x, y), |
97 | 0.5, rounding); |
98 | failed += (!correct); |
99 | // Uncomment to print out failed values. |
100 | if (!correct) { |
101 | EXPECT_MPFR_MATCH_ROUNDING(Op, input, Func(x, y), 0.5, rounding); |
102 | } |
103 | } while (ybits++ < y_stop); |
104 | } while (xbits++ < x_stop); |
105 | return failed; |
106 | } |
107 | }; |
108 | |
109 | // Checker class needs inherit from LIBC_NAMESPACE::testing::Test and provide |
110 | // StorageType and check method. |
111 | template <typename Checker, size_t Increment = 1 << 20> |
112 | struct LlvmLibcExhaustiveMathTest |
113 | : public virtual LIBC_NAMESPACE::testing::Test, |
114 | public Checker { |
115 | using FloatType = typename Checker::FloatType; |
116 | using FPBits = typename Checker::FPBits; |
117 | using StorageType = typename Checker::StorageType; |
118 | |
119 | void explain_failed_range(std::stringstream &msg, StorageType x_begin, |
120 | StorageType x_end) { |
121 | #ifdef LIBC_TYPES_HAS_FLOAT16 |
122 | using T = LIBC_NAMESPACE::cpp::conditional_t< |
123 | LIBC_NAMESPACE::cpp::is_same_v<FloatType, float16>, float, FloatType>; |
124 | #else |
125 | using T = FloatType; |
126 | #endif |
127 | |
128 | msg << x_begin << " to " << x_end << " [0x" << std::hex << x_begin << ", 0x" |
129 | << x_end << "), [" << std::hexfloat |
130 | << static_cast<T>(FPBits(x_begin).get_val()) << ", " |
131 | << static_cast<T>(FPBits(x_end).get_val()) << ")" ; |
132 | } |
133 | |
134 | void explain_failed_range(std::stringstream &msg, StorageType x_begin, |
135 | StorageType x_end, StorageType y_begin, |
136 | StorageType y_end) { |
137 | msg << "x " ; |
138 | explain_failed_range(msg, x_begin, x_end); |
139 | msg << ", y " ; |
140 | explain_failed_range(msg, y_begin, y_end); |
141 | } |
142 | |
143 | // Break [start, stop) into `nthreads` subintervals and apply *check to each |
144 | // subinterval in parallel. |
145 | template <typename... T> |
146 | void test_full_range(mpfr::RoundingMode rounding, StorageType start, |
147 | StorageType stop, T... ) { |
148 | int n_threads = std::thread::hardware_concurrency(); |
149 | std::vector<std::thread> thread_list; |
150 | std::mutex mx_cur_val; |
151 | int current_percent = -1; |
152 | StorageType current_value = start; |
153 | std::atomic<uint64_t> failed(0); |
154 | |
155 | for (int i = 0; i < n_threads; ++i) { |
156 | thread_list.emplace_back([&, this]() { |
157 | while (true) { |
158 | StorageType range_begin, range_end; |
159 | int new_percent = -1; |
160 | { |
161 | std::lock_guard<std::mutex> lock(mx_cur_val); |
162 | if (current_value == stop) |
163 | return; |
164 | |
165 | range_begin = current_value; |
166 | if (stop >= Increment && stop - Increment >= current_value) { |
167 | range_end = current_value + Increment; |
168 | } else { |
169 | range_end = stop; |
170 | } |
171 | current_value = range_end; |
172 | int pc = |
173 | static_cast<int>(100.0 * (range_end - start) / (stop - start)); |
174 | if (current_percent != pc) { |
175 | new_percent = pc; |
176 | current_percent = pc; |
177 | } |
178 | } |
179 | if (new_percent >= 0) { |
180 | std::stringstream msg; |
181 | msg << new_percent << "% is in process \r" ; |
182 | std::cout << msg.str() << std::flush; |
183 | } |
184 | |
185 | uint64_t failed_in_range = Checker::check( |
186 | range_begin, range_end, extra_range_bounds..., rounding); |
187 | if (failed_in_range > 0) { |
188 | std::stringstream msg; |
189 | msg << "Test failed for " << std::dec << failed_in_range |
190 | << " inputs in range: " ; |
191 | explain_failed_range(msg, range_begin, range_end, |
192 | extra_range_bounds...); |
193 | msg << "\n" ; |
194 | std::cerr << msg.str() << std::flush; |
195 | |
196 | failed.fetch_add(i: failed_in_range); |
197 | } |
198 | } |
199 | }); |
200 | } |
201 | |
202 | for (auto &thread : thread_list) { |
203 | if (thread.joinable()) { |
204 | thread.join(); |
205 | } |
206 | } |
207 | |
208 | std::cout << std::endl; |
209 | std::cout << "Test " << ((failed > 0) ? "FAILED" : "PASSED" ) << std::endl; |
210 | ASSERT_EQ(failed.load(), uint64_t(0)); |
211 | } |
212 | |
213 | void test_full_range_all_roundings(StorageType start, StorageType stop) { |
214 | std::cout << "-- Testing for FE_TONEAREST in range [0x" << std::hex << start |
215 | << ", 0x" << stop << ") --" << std::dec << std::endl; |
216 | test_full_range(mpfr::RoundingMode::Nearest, start, stop); |
217 | |
218 | std::cout << "-- Testing for FE_UPWARD in range [0x" << std::hex << start |
219 | << ", 0x" << stop << ") --" << std::dec << std::endl; |
220 | test_full_range(mpfr::RoundingMode::Upward, start, stop); |
221 | |
222 | std::cout << "-- Testing for FE_DOWNWARD in range [0x" << std::hex << start |
223 | << ", 0x" << stop << ") --" << std::dec << std::endl; |
224 | test_full_range(mpfr::RoundingMode::Downward, start, stop); |
225 | |
226 | std::cout << "-- Testing for FE_TOWARDZERO in range [0x" << std::hex |
227 | << start << ", 0x" << stop << ") --" << std::dec << std::endl; |
228 | test_full_range(mpfr::RoundingMode::TowardZero, start, stop); |
229 | } |
230 | |
231 | void test_full_range_all_roundings(StorageType x_start, StorageType x_stop, |
232 | StorageType y_start, StorageType y_stop) { |
233 | std::cout << "-- Testing for FE_TONEAREST in x range [0x" << std::hex |
234 | << x_start << ", 0x" << x_stop << "), y range [0x" << y_start |
235 | << ", 0x" << y_stop << ") --" << std::dec << std::endl; |
236 | test_full_range(mpfr::RoundingMode::Nearest, x_start, x_stop, y_start, |
237 | y_stop); |
238 | |
239 | std::cout << "-- Testing for FE_UPWARD in x range [0x" << std::hex |
240 | << x_start << ", 0x" << x_stop << "), y range [0x" << y_start |
241 | << ", 0x" << y_stop << ") --" << std::dec << std::endl; |
242 | test_full_range(mpfr::RoundingMode::Upward, x_start, x_stop, y_start, |
243 | y_stop); |
244 | |
245 | std::cout << "-- Testing for FE_DOWNWARD in x range [0x" << std::hex |
246 | << x_start << ", 0x" << x_stop << "), y range [0x" << y_start |
247 | << ", 0x" << y_stop << ") --" << std::dec << std::endl; |
248 | test_full_range(mpfr::RoundingMode::Downward, x_start, x_stop, y_start, |
249 | y_stop); |
250 | |
251 | std::cout << "-- Testing for FE_TOWARDZERO in x range [0x" << std::hex |
252 | << x_start << ", 0x" << x_stop << "), y range [0x" << y_start |
253 | << ", 0x" << y_stop << ") --" << std::dec << std::endl; |
254 | test_full_range(mpfr::RoundingMode::TowardZero, x_start, x_stop, y_start, |
255 | y_stop); |
256 | } |
257 | }; |
258 | |
259 | template <typename FloatType, mpfr::Operation Op, UnaryOp<FloatType> Func> |
260 | using LlvmLibcUnaryOpExhaustiveMathTest = |
261 | LlvmLibcExhaustiveMathTest<UnaryOpChecker<FloatType, FloatType, Op, Func>>; |
262 | |
263 | template <typename OutType, typename InType, mpfr::Operation Op, |
264 | UnaryOp<OutType, InType> Func> |
265 | using LlvmLibcUnaryNarrowingOpExhaustiveMathTest = |
266 | LlvmLibcExhaustiveMathTest<UnaryOpChecker<OutType, InType, Op, Func>>; |
267 | |
268 | template <typename FloatType, mpfr::Operation Op, BinaryOp<FloatType> Func> |
269 | using LlvmLibcBinaryOpExhaustiveMathTest = |
270 | LlvmLibcExhaustiveMathTest<BinaryOpChecker<FloatType, FloatType, Op, Func>, |
271 | 1 << 2>; |
272 | |