1//===-- MPFRUtils.h ---------------------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef LLVM_LIBC_UTILS_MPFRWRAPPER_MPFRUTILS_H
10#define LLVM_LIBC_UTILS_MPFRWRAPPER_MPFRUTILS_H
11
12#include "src/__support/CPP/type_traits.h"
13#include "src/__support/macros/config.h"
14#include "test/UnitTest/RoundingModeUtils.h"
15#include "test/UnitTest/Test.h"
16
17#include <stdint.h>
18
19namespace LIBC_NAMESPACE_DECL {
20namespace testing {
21namespace mpfr {
22
// Identifies the mathematical operation whose libc result is checked against
// MPFR. Enumerators are grouped by the shape of their inputs and outputs, and
// every group is bracketed by a Begin*/End* sentinel pair so that group
// membership can be tested with a range comparison (Begin < op && op < End).
// The sentinels are markers only; they do not name an operation.
enum class Operation : int {
  // Operations which take a single floating point number as input
  // and produce a single floating point number as output. The input
  // and output floating point numbers are of the same kind.
  BeginUnaryOperationsSingleOutput,
  Abs,
  Acos,
  Acosh,
  Acospi,
  Asin,
  Asinh,
  Atan,
  Atanh,
  Cbrt,
  Ceil,
  Cos,
  Cosh,
  Cospi,
  Erf,
  Exp,
  Exp2,
  Exp2m1,
  Exp10,
  Exp10m1,
  Expm1,
  Floor,
  Log,
  Log2,
  Log10,
  Log1p,
  Mod2PI,
  ModPIOver2,
  ModPIOver4,
  Round,
  RoundEven,
  Sin,
  Sinpi,
  Sinh,
  Sqrt,
  Tan,
  Tanh,
  Tanpi,
  Trunc,
  EndUnaryOperationsSingleOutput,

  // Operations which take a single floating point number as input
  // but produce two outputs. The first output is a floating point
  // number of the same type as the input. The second output is of type
  // 'int'.
  BeginUnaryOperationsTwoOutputs,
  Frexp, // Floating point output, the first output, is the fractional part.
  EndUnaryOperationsTwoOutputs,

  // Operations which take two floating point numbers of the same type as
  // input and produce a single floating point number of the same type as
  // output.
  BeginBinaryOperationsSingleOutput,
  Add,
  Atan2,
  Div,
  Fmod,
  Hypot,
  Mul,
  Pow,
  Sub,
  EndBinaryOperationsSingleOutput,

  // Operations which take two floating point numbers of the same type as
  // input and produce two outputs. The first output is a floating point number
  // of the same type as the inputs. The second output is of type 'int'.
  BeginBinaryOperationsTwoOutputs,
  RemQuo, // The first output (floating point) is the remainder.
  EndBinaryOperationsTwoOutputs,

  // Operations which take three floating point numbers of the same type as
  // input and produce a single floating point number of the same type as
  // output.
  //
  // `BeginTernaryOperationsSingleOuput` is a historical misspelling that is
  // kept so existing code keeps compiling. The correctly spelled alias below
  // has the same value, so the enumerators after it keep their values.
  BeginTernaryOperationsSingleOuput,
  BeginTernaryOperationsSingleOutput = BeginTernaryOperationsSingleOuput,
  Fma,
  EndTernaryOperationsSingleOutput,
};
104
105using LIBC_NAMESPACE::fputil::testing::ForceRoundingMode;
106using LIBC_NAMESPACE::fputil::testing::RoundingMode;
107
108template <typename T> struct BinaryInput {
109 static_assert(
110 LIBC_NAMESPACE::cpp::is_floating_point_v<T>,
111 "Template parameter of BinaryInput must be a floating point type.");
112
113 using Type = T;
114 T x, y;
115};
116
117template <typename T> struct TernaryInput {
118 static_assert(
119 LIBC_NAMESPACE::cpp::is_floating_point_v<T>,
120 "Template parameter of TernaryInput must be a floating point type.");
121
122 using Type = T;
123 T x, y, z;
124};
125
// Result of an operation that produces two outputs: a value of type T plus an
// 'int' (the "TwoOutputs" groups of Operation).
template <typename T> struct BinaryOutput {
  // First output, of type T.
  T f;
  // Second output, of type 'int'.
  int i;
};
130
131namespace internal {
132
133template <typename T1, typename T2>
134struct AreMatchingBinaryInputAndBinaryOutput {
135 static constexpr bool VALUE = false;
136};
137
138template <typename T>
139struct AreMatchingBinaryInputAndBinaryOutput<BinaryInput<T>, BinaryOutput<T>> {
140 static constexpr bool VALUE = cpp::is_floating_point_v<T>;
141};
142
143template <typename T> struct IsBinaryInput {
144 static constexpr bool VALUE = false;
145};
146
147template <typename T> struct IsBinaryInput<BinaryInput<T>> {
148 static constexpr bool VALUE = true;
149};
150
151template <typename T> struct IsTernaryInput {
152 static constexpr bool VALUE = false;
153};
154
155template <typename T> struct IsTernaryInput<TernaryInput<T>> {
156 static constexpr bool VALUE = true;
157};
158
159template <typename T> struct MakeScalarInput : cpp::type_identity<T> {};
160
161template <typename T>
162struct MakeScalarInput<BinaryInput<T>> : cpp::type_identity<T> {};
163
164template <typename T>
165struct MakeScalarInput<TernaryInput<T>> : cpp::type_identity<T> {};
166
167template <typename InputType, typename OutputType>
168bool compare_unary_operation_single_output(Operation op, InputType input,
169 OutputType libc_output,
170 double ulp_tolerance,
171 RoundingMode rounding);
172template <typename T>
173bool compare_unary_operation_two_outputs(Operation op, T input,
174 const BinaryOutput<T> &libc_output,
175 double ulp_tolerance,
176 RoundingMode rounding);
177template <typename T>
178bool compare_binary_operation_two_outputs(Operation op,
179 const BinaryInput<T> &input,
180 const BinaryOutput<T> &libc_output,
181 double ulp_tolerance,
182 RoundingMode rounding);
183
184template <typename InputType, typename OutputType>
185bool compare_binary_operation_one_output(Operation op,
186 const BinaryInput<InputType> &input,
187 OutputType libc_output,
188 double ulp_tolerance,
189 RoundingMode rounding);
190
191template <typename InputType, typename OutputType>
192bool compare_ternary_operation_one_output(Operation op,
193 const TernaryInput<InputType> &input,
194 OutputType libc_output,
195 double ulp_tolerance,
196 RoundingMode rounding);
197
198template <typename InputType, typename OutputType>
199void explain_unary_operation_single_output_error(Operation op, InputType input,
200 OutputType match_value,
201 double ulp_tolerance,
202 RoundingMode rounding);
203template <typename T>
204void explain_unary_operation_two_outputs_error(
205 Operation op, T input, const BinaryOutput<T> &match_value,
206 double ulp_tolerance, RoundingMode rounding);
207template <typename T>
208void explain_binary_operation_two_outputs_error(
209 Operation op, const BinaryInput<T> &input,
210 const BinaryOutput<T> &match_value, double ulp_tolerance,
211 RoundingMode rounding);
212
213template <typename InputType, typename OutputType>
214void explain_binary_operation_one_output_error(
215 Operation op, const BinaryInput<InputType> &input, OutputType match_value,
216 double ulp_tolerance, RoundingMode rounding);
217
218template <typename InputType, typename OutputType>
219void explain_ternary_operation_one_output_error(
220 Operation op, const TernaryInput<InputType> &input, OutputType match_value,
221 double ulp_tolerance, RoundingMode rounding);
222
223template <Operation op, bool silent, typename InputType, typename OutputType>
224class MPFRMatcher : public testing::Matcher<OutputType> {
225 InputType input;
226 OutputType match_value;
227 double ulp_tolerance;
228 RoundingMode rounding;
229
230public:
231 MPFRMatcher(InputType testInput, double ulp_tolerance, RoundingMode rounding)
232 : input(testInput), ulp_tolerance(ulp_tolerance), rounding(rounding) {}
233
234 bool match(OutputType libcResult) {
235 match_value = libcResult;
236 return match(input, match_value);
237 }
238
239 // This method is marked with NOLINT because the name `explainError` does not
240 // conform to the coding style.
241 void explainError() override { // NOLINT
242 explain_error(input, match_value);
243 }
244
245 // Whether the `explainError` step is skipped or not.
246 bool is_silent() const override { return silent; }
247
248private:
249 template <typename InType, typename OutType>
250 bool match(InType in, OutType out) {
251 return compare_unary_operation_single_output(op, in, out, ulp_tolerance,
252 rounding);
253 }
254
255 template <typename T> bool match(T in, const BinaryOutput<T> &out) {
256 return compare_unary_operation_two_outputs(op, in, out, ulp_tolerance,
257 rounding);
258 }
259
260 template <typename T, typename U>
261 bool match(const BinaryInput<T> &in, U out) {
262 return compare_binary_operation_one_output(op, in, out, ulp_tolerance,
263 rounding);
264 }
265
266 template <typename T>
267 bool match(BinaryInput<T> in, const BinaryOutput<T> &out) {
268 return compare_binary_operation_two_outputs(op, in, out, ulp_tolerance,
269 rounding);
270 }
271
272 template <typename InType, typename OutType>
273 bool match(const TernaryInput<InType> &in, OutType out) {
274 return compare_ternary_operation_one_output(op, in, out, ulp_tolerance,
275 rounding);
276 }
277
278 template <typename InType, typename OutType>
279 void explain_error(InType in, OutType out) {
280 explain_unary_operation_single_output_error(op, in, out, ulp_tolerance,
281 rounding);
282 }
283
284 template <typename T> void explain_error(T in, const BinaryOutput<T> &out) {
285 explain_unary_operation_two_outputs_error(op, in, out, ulp_tolerance,
286 rounding);
287 }
288
289 template <typename T>
290 void explain_error(const BinaryInput<T> &in, const BinaryOutput<T> &out) {
291 explain_binary_operation_two_outputs_error(op, in, out, ulp_tolerance,
292 rounding);
293 }
294
295 template <typename T, typename U>
296 void explain_error(const BinaryInput<T> &in, U out) {
297 explain_binary_operation_one_output_error(op, in, out, ulp_tolerance,
298 rounding);
299 }
300
301 template <typename InType, typename OutType>
302 void explain_error(const TernaryInput<InType> &in, OutType out) {
303 explain_ternary_operation_one_output_error(op, in, out, ulp_tolerance,
304 rounding);
305 }
306};
307
308} // namespace internal
309
310// Return true if the input and ouput types for the operation op are valid
311// types.
312template <Operation op, typename InputType, typename OutputType>
313constexpr bool is_valid_operation() {
314 constexpr bool IS_NARROWING_OP =
315 (op == Operation::Sqrt && cpp::is_floating_point_v<InputType> &&
316 cpp::is_floating_point_v<OutputType> &&
317 sizeof(OutputType) <= sizeof(InputType)) ||
318 (Operation::BeginBinaryOperationsSingleOutput < op &&
319 op < Operation::EndBinaryOperationsSingleOutput &&
320 internal::IsBinaryInput<InputType>::VALUE &&
321 cpp::is_floating_point_v<
322 typename internal::MakeScalarInput<InputType>::type> &&
323 cpp::is_floating_point_v<OutputType>) ||
324 (op == Operation::Fma && internal::IsTernaryInput<InputType>::VALUE &&
325 cpp::is_floating_point_v<
326 typename internal::MakeScalarInput<InputType>::type> &&
327 cpp::is_floating_point_v<OutputType>);
328 if (IS_NARROWING_OP)
329 return true;
330 return (Operation::BeginUnaryOperationsSingleOutput < op &&
331 op < Operation::EndUnaryOperationsSingleOutput &&
332 cpp::is_same_v<InputType, OutputType> &&
333 cpp::is_floating_point_v<InputType>) ||
334 (Operation::BeginUnaryOperationsTwoOutputs < op &&
335 op < Operation::EndUnaryOperationsTwoOutputs &&
336 cpp::is_floating_point_v<InputType> &&
337 cpp::is_same_v<OutputType, BinaryOutput<InputType>>) ||
338 (Operation::BeginBinaryOperationsSingleOutput < op &&
339 op < Operation::EndBinaryOperationsSingleOutput &&
340 cpp::is_floating_point_v<OutputType> &&
341 cpp::is_same_v<InputType, BinaryInput<OutputType>>) ||
342 (Operation::BeginBinaryOperationsTwoOutputs < op &&
343 op < Operation::EndBinaryOperationsTwoOutputs &&
344 internal::AreMatchingBinaryInputAndBinaryOutput<InputType,
345 OutputType>::VALUE) ||
346 (Operation::BeginTernaryOperationsSingleOuput < op &&
347 op < Operation::EndTernaryOperationsSingleOutput &&
348 cpp::is_floating_point_v<OutputType> &&
349 cpp::is_same_v<InputType, TernaryInput<OutputType>>);
350}
351
352template <Operation op, typename InputType, typename OutputType>
353__attribute__((no_sanitize("address"))) cpp::enable_if_t<
354 is_valid_operation<op, InputType, OutputType>(),
355 internal::MPFRMatcher<op, /*is_silent*/ false, InputType, OutputType>>
356get_mpfr_matcher(InputType input, OutputType output_unused,
357 double ulp_tolerance, RoundingMode rounding) {
358 return internal::MPFRMatcher<op, /*is_silent*/ false, InputType, OutputType>(
359 input, ulp_tolerance, rounding);
360}
361
362template <Operation op, typename InputType, typename OutputType>
363__attribute__((no_sanitize("address"))) cpp::enable_if_t<
364 is_valid_operation<op, InputType, OutputType>(),
365 internal::MPFRMatcher<op, /*is_silent*/ true, InputType, OutputType>>
366get_silent_mpfr_matcher(InputType input, OutputType output_unused,
367 double ulp_tolerance, RoundingMode rounding) {
368 return internal::MPFRMatcher<op, /*is_silent*/ true, InputType, OutputType>(
369 input, ulp_tolerance, rounding);
370}
371
// Rounding helpers. Only declarations are visible in this header.

// Returns a T computed from `x` and `mode`; presumably `x` rounded according
// to `mode` using MPFR -- confirm against the definition.
template <typename T> T round(T x, RoundingMode mode);

// Both overloads take `result` by non-const reference, so it is presumably an
// output parameter with the bool return signalling success -- confirm against
// the definitions. The second overload additionally takes the rounding mode.
template <typename T> bool round_to_long(T x, long &result);
template <typename T> bool round_to_long(T x, RoundingMode mode, long &result);
376
377} // namespace mpfr
378} // namespace testing
379} // namespace LIBC_NAMESPACE_DECL
380
// Filler passed as the last argument of GET_MPFR_MACRO so that its variadic
// part is never empty, which avoids the compiler warning
// `gnu-zero-variadic-macro-arguments`. Expands to 0 whatever it is given.
#define GET_MPFR_DUMMY_ARG(...) 0

// Argument-count dispatcher: expands to its sixth argument. Callers pass
// their own arguments followed by a list of candidate macro names, so the
// number of caller arguments decides which candidate lands in sixth place.
#define GET_MPFR_MACRO(ARG1, ARG2, ARG3, ARG4, ARG5, NAME, ...) NAME
386
// Wraps EXPECT_THAT around an MPFR matcher for OP applied to INPUT, using
// RoundingMode::Nearest. MATCH_VALUE is the libc result under test.
#define EXPECT_MPFR_MATCH_DEFAULT(OP, INPUT, MATCH_VALUE, ULP_TOLERANCE)       \
  EXPECT_THAT(MATCH_VALUE,                                                     \
              LIBC_NAMESPACE::testing::mpfr::get_mpfr_matcher<OP>(             \
                  INPUT, MATCH_VALUE, ULP_TOLERANCE,                           \
                  LIBC_NAMESPACE::testing::mpfr::RoundingMode::Nearest))

// Same as EXPECT_MPFR_MATCH_DEFAULT but with an explicit rounding mode.
#define EXPECT_MPFR_MATCH_ROUNDING(OP, INPUT, MATCH_VALUE, ULP_TOLERANCE,      \
                                   ROUNDING)                                   \
  EXPECT_THAT(MATCH_VALUE,                                                     \
              LIBC_NAMESPACE::testing::mpfr::get_mpfr_matcher<OP>(             \
                  INPUT, MATCH_VALUE, ULP_TOLERANCE, ROUNDING))

// Front end: four arguments select EXPECT_MPFR_MATCH_DEFAULT, five select
// EXPECT_MPFR_MATCH_ROUNDING.
#define EXPECT_MPFR_MATCH(...)                                                 \
  GET_MPFR_MACRO(__VA_ARGS__, EXPECT_MPFR_MATCH_ROUNDING,                      \
                 EXPECT_MPFR_MATCH_DEFAULT, GET_MPFR_DUMMY_ARG)                \
  (__VA_ARGS__)
403
// Expression form of the check: builds the matcher for OP, INPUT and ROUNDING
// and calls its match() on MATCH_VALUE directly, without going through
// EXPECT_THAT or ASSERT_THAT. The expansion is the value returned by match().
#define TEST_MPFR_MATCH_ROUNDING(OP, INPUT, MATCH_VALUE, ULP_TOLERANCE,        \
                                 ROUNDING)                                     \
  LIBC_NAMESPACE::testing::mpfr::get_mpfr_matcher<OP>(INPUT, MATCH_VALUE,      \
                                                      ULP_TOLERANCE, ROUNDING) \
      .match(MATCH_VALUE)
409
// Four-argument target of TEST_MPFR_MATCH: like TEST_MPFR_MATCH_ROUNDING, but
// with the rounding mode fixed to RoundingMode::Nearest. Expands to the value
// returned by the matcher's match(), not to an EXPECT_THAT expectation.
#define TEST_MPFR_MATCH_DEFAULT(op, input, match_value, ulp_tolerance)         \
  LIBC_NAMESPACE::testing::mpfr::get_mpfr_matcher<op>(                         \
      input, match_value, ulp_tolerance,                                       \
      LIBC_NAMESPACE::testing::mpfr::RoundingMode::Nearest)                    \
      .match(match_value)

// Front end: four arguments select TEST_MPFR_MATCH_DEFAULT, five select
// TEST_MPFR_MATCH_ROUNDING. Both forms expand to a match() call, so the macro
// can be used as a condition regardless of the argument count.
#define TEST_MPFR_MATCH(...)                                                   \
  GET_MPFR_MACRO(__VA_ARGS__, TEST_MPFR_MATCH_ROUNDING,                        \
                 TEST_MPFR_MATCH_DEFAULT, GET_MPFR_DUMMY_ARG)                  \
  (__VA_ARGS__)
414
// Runs EXPECT_MPFR_MATCH once for each of the rounding modes Nearest, Upward,
// Downward and TowardZero, in that order.
//
// For every mode a ForceRoundingMode object is constructed first, and the
// expectation is evaluated only if that object's `success` member is true;
// otherwise the mode is skipped.
//
// `input` and `match_value` are pasted into every branch, so an argument such
// as a function call is evaluated again for each mode, after the
// ForceRoundingMode object for that mode has been constructed.
//
// All four ForceRoundingMode objects share one scope and are destroyed only at
// the closing brace. The expansion is a braced block, not an expression, and
// it introduces a block-local namespace alias named `mpfr`.
#define EXPECT_MPFR_MATCH_ALL_ROUNDING(op, input, match_value, ulp_tolerance)  \
  {                                                                            \
    namespace mpfr = LIBC_NAMESPACE::testing::mpfr;                            \
    mpfr::ForceRoundingMode __r1(mpfr::RoundingMode::Nearest);                 \
    if (__r1.success) {                                                        \
      EXPECT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                 \
                        mpfr::RoundingMode::Nearest);                          \
    }                                                                          \
    mpfr::ForceRoundingMode __r2(mpfr::RoundingMode::Upward);                  \
    if (__r2.success) {                                                        \
      EXPECT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                 \
                        mpfr::RoundingMode::Upward);                           \
    }                                                                          \
    mpfr::ForceRoundingMode __r3(mpfr::RoundingMode::Downward);                \
    if (__r3.success) {                                                        \
      EXPECT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                 \
                        mpfr::RoundingMode::Downward);                         \
    }                                                                          \
    mpfr::ForceRoundingMode __r4(mpfr::RoundingMode::TowardZero);              \
    if (__r4.success) {                                                        \
      EXPECT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                 \
                        mpfr::RoundingMode::TowardZero);                       \
    }                                                                          \
  }
439
// Same shape as TEST_MPFR_MATCH_ROUNDING, but the matcher comes from
// get_silent_mpfr_matcher, so its is_silent() returns true. The expansion is
// the value returned by match().
#define TEST_MPFR_MATCH_ROUNDING_SILENTLY(OP, INPUT, MATCH_VALUE,              \
                                          ULP_TOLERANCE, ROUNDING)             \
  LIBC_NAMESPACE::testing::mpfr::get_silent_mpfr_matcher<OP>(                  \
      INPUT, MATCH_VALUE, ULP_TOLERANCE, ROUNDING)                             \
      .match(MATCH_VALUE)
445
// Wraps ASSERT_THAT around an MPFR matcher for OP applied to INPUT, using
// RoundingMode::Nearest. MATCH_VALUE is the libc result under test.
#define ASSERT_MPFR_MATCH_DEFAULT(OP, INPUT, MATCH_VALUE, ULP_TOLERANCE)       \
  ASSERT_THAT(MATCH_VALUE,                                                     \
              LIBC_NAMESPACE::testing::mpfr::get_mpfr_matcher<OP>(             \
                  INPUT, MATCH_VALUE, ULP_TOLERANCE,                           \
                  LIBC_NAMESPACE::testing::mpfr::RoundingMode::Nearest))

// Same as ASSERT_MPFR_MATCH_DEFAULT but with an explicit rounding mode.
#define ASSERT_MPFR_MATCH_ROUNDING(OP, INPUT, MATCH_VALUE, ULP_TOLERANCE,      \
                                   ROUNDING)                                   \
  ASSERT_THAT(MATCH_VALUE,                                                     \
              LIBC_NAMESPACE::testing::mpfr::get_mpfr_matcher<OP>(             \
                  INPUT, MATCH_VALUE, ULP_TOLERANCE, ROUNDING))

// Front end: four arguments select ASSERT_MPFR_MATCH_DEFAULT, five select
// ASSERT_MPFR_MATCH_ROUNDING.
#define ASSERT_MPFR_MATCH(...)                                                 \
  GET_MPFR_MACRO(__VA_ARGS__, ASSERT_MPFR_MATCH_ROUNDING,                      \
                 ASSERT_MPFR_MATCH_DEFAULT, GET_MPFR_DUMMY_ARG)                \
  (__VA_ARGS__)
462
// Runs ASSERT_MPFR_MATCH once for each of the rounding modes Nearest, Upward,
// Downward and TowardZero, in that order.
//
// For every mode a ForceRoundingMode object is constructed first, and the
// assertion is evaluated only if that object's `success` member is true;
// otherwise the mode is skipped.
//
// `input` and `match_value` are pasted into every branch, so an argument such
// as a function call is evaluated again for each mode, after the
// ForceRoundingMode object for that mode has been constructed.
//
// All four ForceRoundingMode objects share one scope and are destroyed only at
// the closing brace. The expansion is a braced block, not an expression, and
// it introduces a block-local namespace alias named `mpfr`.
#define ASSERT_MPFR_MATCH_ALL_ROUNDING(op, input, match_value, ulp_tolerance)  \
  {                                                                            \
    namespace mpfr = LIBC_NAMESPACE::testing::mpfr;                            \
    mpfr::ForceRoundingMode __r1(mpfr::RoundingMode::Nearest);                 \
    if (__r1.success) {                                                        \
      ASSERT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                 \
                        mpfr::RoundingMode::Nearest);                          \
    }                                                                          \
    mpfr::ForceRoundingMode __r2(mpfr::RoundingMode::Upward);                  \
    if (__r2.success) {                                                        \
      ASSERT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                 \
                        mpfr::RoundingMode::Upward);                           \
    }                                                                          \
    mpfr::ForceRoundingMode __r3(mpfr::RoundingMode::Downward);                \
    if (__r3.success) {                                                        \
      ASSERT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                 \
                        mpfr::RoundingMode::Downward);                         \
    }                                                                          \
    mpfr::ForceRoundingMode __r4(mpfr::RoundingMode::TowardZero);              \
    if (__r4.success) {                                                        \
      ASSERT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                 \
                        mpfr::RoundingMode::TowardZero);                       \
    }                                                                          \
  }
487
488#endif // LLVM_LIBC_UTILS_MPFRWRAPPER_MPFRUTILS_H
489

// Source: libc/utils/MPFRWrapper/MPFRUtils.h