| 1 | //===-- Exhaustive test template for math functions -------------*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "src/__support/CPP/type_traits.h" |
| 10 | #include "src/__support/FPUtil/FPBits.h" |
| 11 | #include "src/__support/macros/properties/types.h" |
| 12 | #include "test/UnitTest/FPMatcher.h" |
| 13 | #include "test/UnitTest/Test.h" |
| 14 | #include "utils/MPFRWrapper/MPFRUtils.h" |
| 15 | |
| 16 | #include <atomic> |
| 17 | #include <functional> |
| 18 | #include <iostream> |
| 19 | #include <mutex> |
| 20 | #include <sstream> |
| 21 | #include <thread> |
| 22 | #include <vector> |
| 23 | |
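// To test a math function exhaustively over a range of bit patterns, in
// parallel:
// 1. Define a Checker class that derives from LIBC_NAMESPACE::testing::Test
//    and provides:
//    - FloatType: the floating point (input) type to be tested.
//    - FPBits: LIBC_NAMESPACE::fputil::FPBits<FloatType>.
//    - StorageType: the integer type holding the bits of FloatType.
//    - uint64_t check(start, stop, rounding_mode): tests the bit patterns from
//      start to stop under the given rounding mode and returns the number of
//      failures. A two-argument checker instead provides
//      check(x_start, x_stop, y_start, y_stop, rounding_mode).
// 2. Use the LlvmLibcExhaustiveMathTest<Checker> class.
// 3. Call test_full_range(rounding, start, stop) for one rounding mode, or
//    test_full_range_all_roundings(start, stop) for all four rounding modes.
//    For two-argument checkers, append y_start, y_stop to either call. The
//    range is split into chunks that are tested on multiple threads.
// * For single-input, single-output math functions, use the convenience
//   template LlvmLibcUnaryOpExhaustiveMathTest<FloatType, Op, Func>. For
//   narrowing unary functions and two-argument functions, use
//   LlvmLibcUnaryNarrowingOpExhaustiveMathTest and
//   LlvmLibcBinaryOpExhaustiveMathTest respectively.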
// Short alias for the MPFR-based testing utilities used throughout this file.
namespace mpfr = LIBC_NAMESPACE::testing::mpfr;
| 38 | |
// Function type of a one-argument math function taking InType and returning
// OutType. InType defaults to OutType (the non-narrowing case).
template <typename OutType, typename InType = OutType>
using UnaryOp = auto(InType) -> OutType;
| 41 | |
| 42 | template <typename OutType, typename InType, mpfr::Operation Op, |
| 43 | UnaryOp<OutType, InType> Func> |
| 44 | struct UnaryOpChecker : public virtual LIBC_NAMESPACE::testing::Test { |
| 45 | using FloatType = InType; |
| 46 | using FPBits = LIBC_NAMESPACE::fputil::FPBits<FloatType>; |
| 47 | using StorageType = typename FPBits::StorageType; |
| 48 | |
| 49 | // Check in a range, return the number of failures. |
| 50 | uint64_t check(StorageType start, StorageType stop, |
| 51 | mpfr::RoundingMode rounding) { |
| 52 | mpfr::ForceRoundingMode r(rounding); |
| 53 | if (!r.success) |
| 54 | return (stop > start); |
| 55 | StorageType bits = start; |
| 56 | uint64_t failed = 0; |
| 57 | do { |
| 58 | FPBits xbits(bits); |
| 59 | FloatType x = xbits.get_val(); |
| 60 | bool correct = |
| 61 | TEST_MPFR_MATCH_ROUNDING_SILENTLY(Op, x, Func(x), 0.5, rounding); |
| 62 | failed += (!correct); |
| 63 | // Uncomment to print out failed values. |
| 64 | if (!correct) { |
| 65 | EXPECT_MPFR_MATCH_ROUNDING(Op, x, Func(x), 0.5, rounding); |
| 66 | } |
| 67 | } while (bits++ < stop); |
| 68 | return failed; |
| 69 | } |
| 70 | }; |
| 71 | |
// Function type of a two-argument math function taking (InType, InType) and
// returning OutType. InType defaults to OutType.
template <typename OutType, typename InType = OutType>
using BinaryOp = auto(InType, InType) -> OutType;
| 74 | |
| 75 | template <typename OutType, typename InType, mpfr::Operation Op, |
| 76 | BinaryOp<OutType, InType> Func> |
| 77 | struct BinaryOpChecker : public virtual LIBC_NAMESPACE::testing::Test { |
| 78 | using FloatType = InType; |
| 79 | using FPBits = LIBC_NAMESPACE::fputil::FPBits<FloatType>; |
| 80 | using StorageType = typename FPBits::StorageType; |
| 81 | |
| 82 | // Check in a range, return the number of failures. |
| 83 | uint64_t check(StorageType x_start, StorageType x_stop, StorageType y_start, |
| 84 | StorageType y_stop, mpfr::RoundingMode rounding) { |
| 85 | mpfr::ForceRoundingMode r(rounding); |
| 86 | if (!r.success) |
| 87 | return x_stop > x_start || y_stop > y_start; |
| 88 | StorageType xbits = x_start; |
| 89 | uint64_t failed = 0; |
| 90 | do { |
| 91 | FloatType x = FPBits(xbits).get_val(); |
| 92 | StorageType ybits = y_start; |
| 93 | do { |
| 94 | FloatType y = FPBits(ybits).get_val(); |
| 95 | mpfr::BinaryInput<FloatType> input{x, y}; |
| 96 | bool correct = TEST_MPFR_MATCH_ROUNDING_SILENTLY(Op, input, Func(x, y), |
| 97 | 0.5, rounding); |
| 98 | failed += (!correct); |
| 99 | // Uncomment to print out failed values. |
| 100 | if (!correct) { |
| 101 | EXPECT_MPFR_MATCH_ROUNDING(Op, input, Func(x, y), 0.5, rounding); |
| 102 | } |
| 103 | } while (ybits++ < y_stop); |
| 104 | } while (xbits++ < x_stop); |
| 105 | return failed; |
| 106 | } |
| 107 | }; |
| 108 | |
| 109 | // Checker class needs inherit from LIBC_NAMESPACE::testing::Test and provide |
| 110 | // StorageType and check method. |
| 111 | template <typename Checker, size_t Increment = 1 << 20> |
| 112 | struct LlvmLibcExhaustiveMathTest |
| 113 | : public virtual LIBC_NAMESPACE::testing::Test, |
| 114 | public Checker { |
| 115 | using FloatType = typename Checker::FloatType; |
| 116 | using FPBits = typename Checker::FPBits; |
| 117 | using StorageType = typename Checker::StorageType; |
| 118 | |
| 119 | void explain_failed_range(std::stringstream &msg, StorageType x_begin, |
| 120 | StorageType x_end) { |
| 121 | #ifdef LIBC_TYPES_HAS_FLOAT16 |
| 122 | using T = LIBC_NAMESPACE::cpp::conditional_t< |
| 123 | LIBC_NAMESPACE::cpp::is_same_v<FloatType, float16>, float, FloatType>; |
| 124 | #else |
| 125 | using T = FloatType; |
| 126 | #endif |
| 127 | |
| 128 | msg << x_begin << " to " << x_end << " [0x" << std::hex << x_begin << ", 0x" |
| 129 | << x_end << "), [" << std::hexfloat |
| 130 | << static_cast<T>(FPBits(x_begin).get_val()) << ", " |
| 131 | << static_cast<T>(FPBits(x_end).get_val()) << ")" ; |
| 132 | } |
| 133 | |
| 134 | void explain_failed_range(std::stringstream &msg, StorageType x_begin, |
| 135 | StorageType x_end, StorageType y_begin, |
| 136 | StorageType y_end) { |
| 137 | msg << "x " ; |
| 138 | explain_failed_range(msg, x_begin, x_end); |
| 139 | msg << ", y " ; |
| 140 | explain_failed_range(msg, y_begin, y_end); |
| 141 | } |
| 142 | |
| 143 | // Break [start, stop) into `nthreads` subintervals and apply *check to each |
| 144 | // subinterval in parallel. |
| 145 | template <typename... T> |
| 146 | void test_full_range(mpfr::RoundingMode rounding, StorageType start, |
| 147 | StorageType stop, T... ) { |
| 148 | int n_threads = std::thread::hardware_concurrency(); |
| 149 | std::vector<std::thread> thread_list; |
| 150 | std::mutex mx_cur_val; |
| 151 | int current_percent = -1; |
| 152 | StorageType current_value = start; |
| 153 | std::atomic<uint64_t> failed(0); |
| 154 | |
| 155 | for (int i = 0; i < n_threads; ++i) { |
| 156 | thread_list.emplace_back([&, this]() { |
| 157 | while (true) { |
| 158 | StorageType range_begin, range_end; |
| 159 | int new_percent = -1; |
| 160 | { |
| 161 | std::lock_guard<std::mutex> lock(mx_cur_val); |
| 162 | if (current_value == stop) |
| 163 | return; |
| 164 | |
| 165 | range_begin = current_value; |
| 166 | if (stop >= Increment && stop - Increment >= current_value) { |
| 167 | range_end = current_value + Increment; |
| 168 | } else { |
| 169 | range_end = stop; |
| 170 | } |
| 171 | current_value = range_end; |
| 172 | int pc = |
| 173 | static_cast<int>(100.0 * (range_end - start) / (stop - start)); |
| 174 | if (current_percent != pc) { |
| 175 | new_percent = pc; |
| 176 | current_percent = pc; |
| 177 | } |
| 178 | } |
| 179 | if (new_percent >= 0) { |
| 180 | std::stringstream msg; |
| 181 | msg << new_percent << "% is in process \r" ; |
| 182 | std::cout << msg.str() << std::flush; |
| 183 | } |
| 184 | |
| 185 | uint64_t failed_in_range = Checker::check( |
| 186 | range_begin, range_end, extra_range_bounds..., rounding); |
| 187 | if (failed_in_range > 0) { |
| 188 | std::stringstream msg; |
| 189 | msg << "Test failed for " << std::dec << failed_in_range |
| 190 | << " inputs in range: " ; |
| 191 | explain_failed_range(msg, range_begin, range_end, |
| 192 | extra_range_bounds...); |
| 193 | msg << "\n" ; |
| 194 | std::cerr << msg.str() << std::flush; |
| 195 | |
| 196 | failed.fetch_add(i: failed_in_range); |
| 197 | } |
| 198 | } |
| 199 | }); |
| 200 | } |
| 201 | |
| 202 | for (auto &thread : thread_list) { |
| 203 | if (thread.joinable()) { |
| 204 | thread.join(); |
| 205 | } |
| 206 | } |
| 207 | |
| 208 | std::cout << std::endl; |
| 209 | std::cout << "Test " << ((failed > 0) ? "FAILED" : "PASSED" ) << std::endl; |
| 210 | ASSERT_EQ(failed.load(), uint64_t(0)); |
| 211 | } |
| 212 | |
| 213 | void test_full_range_all_roundings(StorageType start, StorageType stop) { |
| 214 | std::cout << "-- Testing for FE_TONEAREST in range [0x" << std::hex << start |
| 215 | << ", 0x" << stop << ") --" << std::dec << std::endl; |
| 216 | test_full_range(mpfr::RoundingMode::Nearest, start, stop); |
| 217 | |
| 218 | std::cout << "-- Testing for FE_UPWARD in range [0x" << std::hex << start |
| 219 | << ", 0x" << stop << ") --" << std::dec << std::endl; |
| 220 | test_full_range(mpfr::RoundingMode::Upward, start, stop); |
| 221 | |
| 222 | std::cout << "-- Testing for FE_DOWNWARD in range [0x" << std::hex << start |
| 223 | << ", 0x" << stop << ") --" << std::dec << std::endl; |
| 224 | test_full_range(mpfr::RoundingMode::Downward, start, stop); |
| 225 | |
| 226 | std::cout << "-- Testing for FE_TOWARDZERO in range [0x" << std::hex |
| 227 | << start << ", 0x" << stop << ") --" << std::dec << std::endl; |
| 228 | test_full_range(mpfr::RoundingMode::TowardZero, start, stop); |
| 229 | } |
| 230 | |
| 231 | void test_full_range_all_roundings(StorageType x_start, StorageType x_stop, |
| 232 | StorageType y_start, StorageType y_stop) { |
| 233 | std::cout << "-- Testing for FE_TONEAREST in x range [0x" << std::hex |
| 234 | << x_start << ", 0x" << x_stop << "), y range [0x" << y_start |
| 235 | << ", 0x" << y_stop << ") --" << std::dec << std::endl; |
| 236 | test_full_range(mpfr::RoundingMode::Nearest, x_start, x_stop, y_start, |
| 237 | y_stop); |
| 238 | |
| 239 | std::cout << "-- Testing for FE_UPWARD in x range [0x" << std::hex |
| 240 | << x_start << ", 0x" << x_stop << "), y range [0x" << y_start |
| 241 | << ", 0x" << y_stop << ") --" << std::dec << std::endl; |
| 242 | test_full_range(mpfr::RoundingMode::Upward, x_start, x_stop, y_start, |
| 243 | y_stop); |
| 244 | |
| 245 | std::cout << "-- Testing for FE_DOWNWARD in x range [0x" << std::hex |
| 246 | << x_start << ", 0x" << x_stop << "), y range [0x" << y_start |
| 247 | << ", 0x" << y_stop << ") --" << std::dec << std::endl; |
| 248 | test_full_range(mpfr::RoundingMode::Downward, x_start, x_stop, y_start, |
| 249 | y_stop); |
| 250 | |
| 251 | std::cout << "-- Testing for FE_TOWARDZERO in x range [0x" << std::hex |
| 252 | << x_start << ", 0x" << x_stop << "), y range [0x" << y_start |
| 253 | << ", 0x" << y_stop << ") --" << std::dec << std::endl; |
| 254 | test_full_range(mpfr::RoundingMode::TowardZero, x_start, x_stop, y_start, |
| 255 | y_stop); |
| 256 | } |
| 257 | }; |
| 258 | |
| 259 | template <typename FloatType, mpfr::Operation Op, UnaryOp<FloatType> Func> |
| 260 | using LlvmLibcUnaryOpExhaustiveMathTest = |
| 261 | LlvmLibcExhaustiveMathTest<UnaryOpChecker<FloatType, FloatType, Op, Func>>; |
| 262 | |
| 263 | template <typename OutType, typename InType, mpfr::Operation Op, |
| 264 | UnaryOp<OutType, InType> Func> |
| 265 | using LlvmLibcUnaryNarrowingOpExhaustiveMathTest = |
| 266 | LlvmLibcExhaustiveMathTest<UnaryOpChecker<OutType, InType, Op, Func>>; |
| 267 | |
| 268 | template <typename FloatType, mpfr::Operation Op, BinaryOp<FloatType> Func> |
| 269 | using LlvmLibcBinaryOpExhaustiveMathTest = |
| 270 | LlvmLibcExhaustiveMathTest<BinaryOpChecker<FloatType, FloatType, Op, Func>, |
| 271 | 1 << 2>; |
| 272 | |