1//===-- Exhaustive test template for math functions -------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "src/__support/CPP/type_traits.h"
10#include "src/__support/FPUtil/FPBits.h"
11#include "src/__support/macros/properties/types.h"
12#include "test/UnitTest/FPMatcher.h"
13#include "test/UnitTest/Test.h"
14#include "utils/MPFRWrapper/MPFRUtils.h"
15
16#include <atomic>
17#include <functional>
18#include <iostream>
19#include <mutex>
20#include <sstream>
21#include <thread>
22#include <vector>
23
24// To test exhaustively for inputs in the range [start, stop) in parallel:
25// 1. Define a Checker class with:
26// - FloatType: define floating point type to be used.
27// - FPBits: fputil::FPBits<FloatType>.
28// - StorageType: define bit type for the corresponding floating point type.
29// - uint64_t check(start, stop, rounding_mode): a method to test in given
30// range for a given rounding mode, which returns the number of
31// failures.
32// 2. Use LlvmLibcExhaustiveMathTest<Checker> class
// 3. Call: test_full_range(rounding, start, stop)
//       or test_full_range_all_roundings(start, stop).
35// * For single input single output math function, use the convenient template:
36// LlvmLibcUnaryOpExhaustiveMathTest<FloatType, Op, Func>.
37namespace mpfr = LIBC_NAMESPACE::testing::mpfr;
38
// Function type of a one-argument routine that takes an InType and returns an
// OutType. InType defaults to OutType for the common same-type case.
template <typename OutType, typename InType = OutType>
using UnaryOp = auto(InType) -> OutType;
41
42template <typename OutType, typename InType, mpfr::Operation Op,
43 UnaryOp<OutType, InType> Func>
44struct UnaryOpChecker : public virtual LIBC_NAMESPACE::testing::Test {
45 using FloatType = InType;
46 using FPBits = LIBC_NAMESPACE::fputil::FPBits<FloatType>;
47 using StorageType = typename FPBits::StorageType;
48
49 // Check in a range, return the number of failures.
50 uint64_t check(StorageType start, StorageType stop,
51 mpfr::RoundingMode rounding) {
52 mpfr::ForceRoundingMode r(rounding);
53 if (!r.success)
54 return (stop > start);
55 StorageType bits = start;
56 uint64_t failed = 0;
57 do {
58 FPBits xbits(bits);
59 FloatType x = xbits.get_val();
60 bool correct =
61 TEST_MPFR_MATCH_ROUNDING_SILENTLY(Op, x, Func(x), 0.5, rounding);
62 failed += (!correct);
63 // Uncomment to print out failed values.
64 if (!correct) {
65 EXPECT_MPFR_MATCH_ROUNDING(Op, x, Func(x), 0.5, rounding);
66 }
67 } while (bits++ < stop);
68 return failed;
69 }
70};
71
// Function type of a two-argument routine whose operands are both InType and
// whose result is OutType. InType defaults to OutType.
template <typename OutType, typename InType = OutType>
using BinaryOp = auto(InType, InType) -> OutType;
74
75template <typename OutType, typename InType, mpfr::Operation Op,
76 BinaryOp<OutType, InType> Func>
77struct BinaryOpChecker : public virtual LIBC_NAMESPACE::testing::Test {
78 using FloatType = InType;
79 using FPBits = LIBC_NAMESPACE::fputil::FPBits<FloatType>;
80 using StorageType = typename FPBits::StorageType;
81
82 // Check in a range, return the number of failures.
83 uint64_t check(StorageType x_start, StorageType x_stop, StorageType y_start,
84 StorageType y_stop, mpfr::RoundingMode rounding) {
85 mpfr::ForceRoundingMode r(rounding);
86 if (!r.success)
87 return x_stop > x_start || y_stop > y_start;
88 StorageType xbits = x_start;
89 uint64_t failed = 0;
90 do {
91 FloatType x = FPBits(xbits).get_val();
92 StorageType ybits = y_start;
93 do {
94 FloatType y = FPBits(ybits).get_val();
95 mpfr::BinaryInput<FloatType> input{x, y};
96 bool correct = TEST_MPFR_MATCH_ROUNDING_SILENTLY(Op, input, Func(x, y),
97 0.5, rounding);
98 failed += (!correct);
99 // Uncomment to print out failed values.
100 if (!correct) {
101 EXPECT_MPFR_MATCH_ROUNDING(Op, input, Func(x, y), 0.5, rounding);
102 }
103 } while (ybits++ < y_stop);
104 } while (xbits++ < x_stop);
105 return failed;
106 }
107};
108
109// Checker class needs inherit from LIBC_NAMESPACE::testing::Test and provide
110// StorageType and check method.
111template <typename Checker, size_t Increment = 1 << 20>
112struct LlvmLibcExhaustiveMathTest
113 : public virtual LIBC_NAMESPACE::testing::Test,
114 public Checker {
115 using FloatType = typename Checker::FloatType;
116 using FPBits = typename Checker::FPBits;
117 using StorageType = typename Checker::StorageType;
118
119 void explain_failed_range(std::stringstream &msg, StorageType x_begin,
120 StorageType x_end) {
121#ifdef LIBC_TYPES_HAS_FLOAT16
122 using T = LIBC_NAMESPACE::cpp::conditional_t<
123 LIBC_NAMESPACE::cpp::is_same_v<FloatType, float16>, float, FloatType>;
124#else
125 using T = FloatType;
126#endif
127
128 msg << x_begin << " to " << x_end << " [0x" << std::hex << x_begin << ", 0x"
129 << x_end << "), [" << std::hexfloat
130 << static_cast<T>(FPBits(x_begin).get_val()) << ", "
131 << static_cast<T>(FPBits(x_end).get_val()) << ")";
132 }
133
134 void explain_failed_range(std::stringstream &msg, StorageType x_begin,
135 StorageType x_end, StorageType y_begin,
136 StorageType y_end) {
137 msg << "x ";
138 explain_failed_range(msg, x_begin, x_end);
139 msg << ", y ";
140 explain_failed_range(msg, y_begin, y_end);
141 }
142
143 // Break [start, stop) into `nthreads` subintervals and apply *check to each
144 // subinterval in parallel.
145 template <typename... T>
146 void test_full_range(mpfr::RoundingMode rounding, StorageType start,
147 StorageType stop, T... extra_range_bounds) {
148 int n_threads = std::thread::hardware_concurrency();
149 std::vector<std::thread> thread_list;
150 std::mutex mx_cur_val;
151 int current_percent = -1;
152 StorageType current_value = start;
153 std::atomic<uint64_t> failed(0);
154
155 for (int i = 0; i < n_threads; ++i) {
156 thread_list.emplace_back([&, this]() {
157 while (true) {
158 StorageType range_begin, range_end;
159 int new_percent = -1;
160 {
161 std::lock_guard<std::mutex> lock(mx_cur_val);
162 if (current_value == stop)
163 return;
164
165 range_begin = current_value;
166 if (stop >= Increment && stop - Increment >= current_value) {
167 range_end = current_value + Increment;
168 } else {
169 range_end = stop;
170 }
171 current_value = range_end;
172 int pc =
173 static_cast<int>(100.0 * (range_end - start) / (stop - start));
174 if (current_percent != pc) {
175 new_percent = pc;
176 current_percent = pc;
177 }
178 }
179 if (new_percent >= 0) {
180 std::stringstream msg;
181 msg << new_percent << "% is in process \r";
182 std::cout << msg.str() << std::flush;
183 }
184
185 uint64_t failed_in_range = Checker::check(
186 range_begin, range_end, extra_range_bounds..., rounding);
187 if (failed_in_range > 0) {
188 std::stringstream msg;
189 msg << "Test failed for " << std::dec << failed_in_range
190 << " inputs in range: ";
191 explain_failed_range(msg, range_begin, range_end,
192 extra_range_bounds...);
193 msg << "\n";
194 std::cerr << msg.str() << std::flush;
195
196 failed.fetch_add(i: failed_in_range);
197 }
198 }
199 });
200 }
201
202 for (auto &thread : thread_list) {
203 if (thread.joinable()) {
204 thread.join();
205 }
206 }
207
208 std::cout << std::endl;
209 std::cout << "Test " << ((failed > 0) ? "FAILED" : "PASSED") << std::endl;
210 ASSERT_EQ(failed.load(), uint64_t(0));
211 }
212
213 void test_full_range_all_roundings(StorageType start, StorageType stop) {
214 std::cout << "-- Testing for FE_TONEAREST in range [0x" << std::hex << start
215 << ", 0x" << stop << ") --" << std::dec << std::endl;
216 test_full_range(mpfr::RoundingMode::Nearest, start, stop);
217
218 std::cout << "-- Testing for FE_UPWARD in range [0x" << std::hex << start
219 << ", 0x" << stop << ") --" << std::dec << std::endl;
220 test_full_range(mpfr::RoundingMode::Upward, start, stop);
221
222 std::cout << "-- Testing for FE_DOWNWARD in range [0x" << std::hex << start
223 << ", 0x" << stop << ") --" << std::dec << std::endl;
224 test_full_range(mpfr::RoundingMode::Downward, start, stop);
225
226 std::cout << "-- Testing for FE_TOWARDZERO in range [0x" << std::hex
227 << start << ", 0x" << stop << ") --" << std::dec << std::endl;
228 test_full_range(mpfr::RoundingMode::TowardZero, start, stop);
229 }
230
231 void test_full_range_all_roundings(StorageType x_start, StorageType x_stop,
232 StorageType y_start, StorageType y_stop) {
233 std::cout << "-- Testing for FE_TONEAREST in x range [0x" << std::hex
234 << x_start << ", 0x" << x_stop << "), y range [0x" << y_start
235 << ", 0x" << y_stop << ") --" << std::dec << std::endl;
236 test_full_range(mpfr::RoundingMode::Nearest, x_start, x_stop, y_start,
237 y_stop);
238
239 std::cout << "-- Testing for FE_UPWARD in x range [0x" << std::hex
240 << x_start << ", 0x" << x_stop << "), y range [0x" << y_start
241 << ", 0x" << y_stop << ") --" << std::dec << std::endl;
242 test_full_range(mpfr::RoundingMode::Upward, x_start, x_stop, y_start,
243 y_stop);
244
245 std::cout << "-- Testing for FE_DOWNWARD in x range [0x" << std::hex
246 << x_start << ", 0x" << x_stop << "), y range [0x" << y_start
247 << ", 0x" << y_stop << ") --" << std::dec << std::endl;
248 test_full_range(mpfr::RoundingMode::Downward, x_start, x_stop, y_start,
249 y_stop);
250
251 std::cout << "-- Testing for FE_TOWARDZERO in x range [0x" << std::hex
252 << x_start << ", 0x" << x_stop << "), y range [0x" << y_start
253 << ", 0x" << y_stop << ") --" << std::dec << std::endl;
254 test_full_range(mpfr::RoundingMode::TowardZero, x_start, x_stop, y_start,
255 y_stop);
256 }
257};
258
259template <typename FloatType, mpfr::Operation Op, UnaryOp<FloatType> Func>
260using LlvmLibcUnaryOpExhaustiveMathTest =
261 LlvmLibcExhaustiveMathTest<UnaryOpChecker<FloatType, FloatType, Op, Func>>;
262
263template <typename OutType, typename InType, mpfr::Operation Op,
264 UnaryOp<OutType, InType> Func>
265using LlvmLibcUnaryNarrowingOpExhaustiveMathTest =
266 LlvmLibcExhaustiveMathTest<UnaryOpChecker<OutType, InType, Op, Func>>;
267
268template <typename FloatType, mpfr::Operation Op, BinaryOp<FloatType> Func>
269using LlvmLibcBinaryOpExhaustiveMathTest =
270 LlvmLibcExhaustiveMathTest<BinaryOpChecker<FloatType, FloatType, Op, Func>,
271 1 << 2>;
272

// Source: libc/test/src/math/exhaustive/exhaustive_test.h