1 | //===-- Exhaustive test template for math functions -------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "src/__support/CPP/type_traits.h" |
10 | #include "src/__support/FPUtil/FPBits.h" |
11 | #include "test/UnitTest/FPMatcher.h" |
12 | #include "test/UnitTest/Test.h" |
13 | #include "utils/MPFRWrapper/MPFRUtils.h" |
14 | |
15 | #include <atomic> |
16 | #include <functional> |
17 | #include <iostream> |
18 | #include <mutex> |
19 | #include <sstream> |
20 | #include <thread> |
21 | #include <vector> |
22 | |
// To test exhaustively for inputs in the range [start, stop) in parallel:
// 1. Define a Checker class that inherits from LIBC_NAMESPACE::testing::Test
//    and provides:
//    - FloatType: the floating point type under test.
//    - FPBits: fputil::FPBits<FloatType>.
//    - StorageType: the bit (storage) type for FloatType, i.e.
//      FPBits::StorageType.
//    - uint64_t check(start, stop, rounding_mode): a method that tests the
//      given range under the given rounding mode and returns the number of
//      failures.
// 2. Use LlvmLibcExhaustiveMathTest<Checker> as the test fixture.
// 3. Call: test_full_range(start, stop, rounding)
//    or test_full_range_all_roundings(start, stop).
// * For a single-input, single-output math function, use the convenience
//   template LlvmLibcUnaryOpExhaustiveMathTest<FloatType, Op, Func>.
//   Note that UnaryOpChecker::check also evaluates `stop` itself, so the
//   final bit pattern of the range is tested as well.
// Shorthand for the MPFR-based testing utilities (operations, rounding modes,
// matchers) referenced throughout this file.
namespace mpfr = LIBC_NAMESPACE::testing::mpfr;
37 | |
// Function type of a unary math function taking an `InType` and returning an
// `OutType`. `InType` defaults to `OutType`, so `UnaryOp<T>` is exactly
// `T(T)`; a distinct `InType` describes functions whose result type differs
// from the argument type (e.g. `UnaryOp<float, double>` is `float(double)`).
template <typename OutType, typename InType = OutType>
using UnaryOp = OutType(InType);
39 | |
40 | template <typename T, mpfr::Operation Op, UnaryOp<T> Func> |
41 | struct UnaryOpChecker : public virtual LIBC_NAMESPACE::testing::Test { |
42 | using FloatType = T; |
43 | using FPBits = LIBC_NAMESPACE::fputil::FPBits<FloatType>; |
44 | using StorageType = typename FPBits::StorageType; |
45 | |
46 | static constexpr UnaryOp<FloatType> *FUNC = Func; |
47 | |
48 | // Check in a range, return the number of failures. |
49 | uint64_t check(StorageType start, StorageType stop, |
50 | mpfr::RoundingMode rounding) { |
51 | mpfr::ForceRoundingMode r(rounding); |
52 | if (!r.success) |
53 | return (stop > start); |
54 | StorageType bits = start; |
55 | uint64_t failed = 0; |
56 | do { |
57 | FPBits xbits(bits); |
58 | FloatType x = xbits.get_val(); |
59 | bool correct = |
60 | TEST_MPFR_MATCH_ROUNDING_SILENTLY(Op, x, FUNC(x), 0.5, rounding); |
61 | failed += (!correct); |
62 | // Uncomment to print out failed values. |
63 | // if (!correct) { |
64 | // TEST_MPFR_MATCH(Op::Operation, x, Op::func(x), 0.5, rounding); |
65 | // } |
66 | } while (bits++ < stop); |
67 | return failed; |
68 | } |
69 | }; |
70 | |
71 | // Checker class needs inherit from LIBC_NAMESPACE::testing::Test and provide |
72 | // StorageType and check method. |
73 | template <typename Checker> |
74 | struct LlvmLibcExhaustiveMathTest |
75 | : public virtual LIBC_NAMESPACE::testing::Test, |
76 | public Checker { |
77 | using FloatType = typename Checker::FloatType; |
78 | using FPBits = typename Checker::FPBits; |
79 | using StorageType = typename Checker::StorageType; |
80 | |
81 | static constexpr StorageType INCREMENT = (1 << 20); |
82 | |
83 | // Break [start, stop) into `nthreads` subintervals and apply *check to each |
84 | // subinterval in parallel. |
85 | void test_full_range(StorageType start, StorageType stop, |
86 | mpfr::RoundingMode rounding) { |
87 | int n_threads = std::thread::hardware_concurrency(); |
88 | std::vector<std::thread> thread_list; |
89 | std::mutex mx_cur_val; |
90 | int current_percent = -1; |
91 | StorageType current_value = start; |
92 | std::atomic<uint64_t> failed(0); |
93 | |
94 | for (int i = 0; i < n_threads; ++i) { |
95 | thread_list.emplace_back([&, this]() { |
96 | while (true) { |
97 | StorageType range_begin, range_end; |
98 | int new_percent = -1; |
99 | { |
100 | std::lock_guard<std::mutex> lock(mx_cur_val); |
101 | if (current_value == stop) |
102 | return; |
103 | |
104 | range_begin = current_value; |
105 | if (stop >= INCREMENT && stop - INCREMENT >= current_value) { |
106 | range_end = current_value + INCREMENT; |
107 | } else { |
108 | range_end = stop; |
109 | } |
110 | current_value = range_end; |
111 | int pc = 100.0 * (range_end - start) / (stop - start); |
112 | if (current_percent != pc) { |
113 | new_percent = pc; |
114 | current_percent = pc; |
115 | } |
116 | } |
117 | if (new_percent >= 0) { |
118 | std::stringstream msg; |
119 | msg << new_percent << "% is in process \r" ; |
120 | std::cout << msg.str() << std::flush; |
121 | } |
122 | |
123 | uint64_t failed_in_range = |
124 | Checker::check(range_begin, range_end, rounding); |
125 | if (failed_in_range > 0) { |
126 | std::stringstream msg; |
127 | msg << "Test failed for " << std::dec << failed_in_range |
128 | << " inputs in range: " << range_begin << " to " << range_end |
129 | << " [0x" << std::hex << range_begin << ", 0x" << range_end |
130 | << "), [" << std::hexfloat << FPBits(range_begin).get_val() |
131 | << ", " << FPBits(range_end).get_val() << ")\n" ; |
132 | std::cerr << msg.str() << std::flush; |
133 | |
134 | failed.fetch_add(i: failed_in_range); |
135 | } |
136 | } |
137 | }); |
138 | } |
139 | |
140 | for (auto &thread : thread_list) { |
141 | if (thread.joinable()) { |
142 | thread.join(); |
143 | } |
144 | } |
145 | |
146 | std::cout << std::endl; |
147 | std::cout << "Test " << ((failed > 0) ? "FAILED" : "PASSED" ) << std::endl; |
148 | ASSERT_EQ(failed.load(), uint64_t(0)); |
149 | } |
150 | |
151 | void test_full_range_all_roundings(StorageType start, StorageType stop) { |
152 | std::cout << "-- Testing for FE_TONEAREST in range [0x" << std::hex << start |
153 | << ", 0x" << stop << ") --" << std::dec << std::endl; |
154 | test_full_range(start, stop, rounding: mpfr::RoundingMode::Nearest); |
155 | |
156 | std::cout << "-- Testing for FE_UPWARD in range [0x" << std::hex << start |
157 | << ", 0x" << stop << ") --" << std::dec << std::endl; |
158 | test_full_range(start, stop, rounding: mpfr::RoundingMode::Upward); |
159 | |
160 | std::cout << "-- Testing for FE_DOWNWARD in range [0x" << std::hex << start |
161 | << ", 0x" << stop << ") --" << std::dec << std::endl; |
162 | test_full_range(start, stop, rounding: mpfr::RoundingMode::Downward); |
163 | |
164 | std::cout << "-- Testing for FE_TOWARDZERO in range [0x" << std::hex |
165 | << start << ", 0x" << stop << ") --" << std::dec << std::endl; |
166 | test_full_range(start, stop, rounding: mpfr::RoundingMode::TowardZero); |
167 | }; |
168 | }; |
169 | |
170 | template <typename FloatType, mpfr::Operation Op, UnaryOp<FloatType> Func> |
171 | using LlvmLibcUnaryOpExhaustiveMathTest = |
172 | LlvmLibcExhaustiveMathTest<UnaryOpChecker<FloatType, Op, Func>>; |
173 | |