1//===-- MPCUtils.h ----------------------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef LLVM_LIBC_UTILS_MPCWRAPPER_MPCUTILS_H
10#define LLVM_LIBC_UTILS_MPCWRAPPER_MPCUTILS_H
11
12#include "src/__support/CPP/type_traits.h"
13#include "src/__support/complex_type.h"
14#include "src/__support/macros/config.h"
15#include "src/__support/macros/properties/complex_types.h"
16#include "src/__support/macros/properties/types.h"
17#include "test/UnitTest/RoundingModeUtils.h"
18#include "test/UnitTest/Test.h"
19
20#include <stdint.h>
21
22namespace LIBC_NAMESPACE_DECL {
23namespace testing {
24namespace mpc {
25
// Operations whose libc results can be checked against MPC by the matchers in
// this file.
//
// The Begin*/End* enumerators are range sentinels, not operations:
// is_valid_operation() classifies an operation by comparing it against them
// with `<`. A new operation must therefore be declared strictly between the
// sentinels of its category, and the relative order of enumerators matters.
enum class Operation {
  // Operations which take a single complex floating point number as input
  // and produce a single floating point number as output which has the same
  // floating point type as the real/imaginary part of the input.
  BeginUnaryOperationsSingleOutputDifferentOutputType,
  Carg,
  Cabs,
  EndUnaryOperationsSingleOutputDifferentOutputType,

  // Operations which take a single complex floating point number as input
  // and produce a single complex floating point number of the same kind
  // as output.
  BeginUnaryOperationsSingleOutputSameOutputType,
  Cproj,
  Csqrt,
  Clog,
  Cexp,
  Csinh,
  Ccosh,
  Ctanh,
  Casinh,
  Cacosh,
  Catanh,
  Csin,
  Ccos,
  Ctan,
  Casin,
  Cacos,
  Catan,
  EndUnaryOperationsSingleOutputSameOutputType,

  // Operations which take two complex floating point numbers as input
  // and produce a single complex floating point number of the same kind
  // as output.
  BeginBinaryOperationsSingleOutput,
  Cpow,
  EndBinaryOperationsSingleOutput,
};
64
65using LIBC_NAMESPACE::fputil::testing::RoundingMode;
66
67template <typename T> struct BinaryInput {
68 static_assert(LIBC_NAMESPACE::cpp::is_complex_v<T>,
69 "Template parameter of BinaryInput must be a complex floating "
70 "point type.");
71
72 using Type = T;
73 T x, y;
74};
75
76namespace internal {
77
78template <typename InputType, typename OutputType>
79bool compare_unary_operation_single_output_same_type(Operation op,
80 InputType input,
81 OutputType libc_output,
82 double ulp_tolerance,
83 RoundingMode rounding);
84
85template <typename InputType, typename OutputType>
86bool compare_unary_operation_single_output_different_type(
87 Operation op, InputType input, OutputType libc_output, double ulp_tolerance,
88 RoundingMode rounding);
89
90template <typename InputType, typename OutputType>
91bool compare_binary_operation_one_output(Operation op,
92 const BinaryInput<InputType> &input,
93 OutputType libc_output,
94 double ulp_tolerance,
95 RoundingMode rounding);
96
97template <typename InputType, typename OutputType>
98void explain_unary_operation_single_output_same_type_error(
99 Operation op, InputType input, OutputType match_value, double ulp_tolerance,
100 RoundingMode rounding);
101
102template <typename InputType, typename OutputType>
103void explain_unary_operation_single_output_different_type_error(
104 Operation op, InputType input, OutputType match_value, double ulp_tolerance,
105 RoundingMode rounding);
106
107template <typename InputType, typename OutputType>
108void explain_binary_operation_one_output_error(
109 Operation op, const BinaryInput<InputType> &input, OutputType match_value,
110 double ulp_tolerance, RoundingMode rounding);
111
112template <Operation op, typename InputType, typename OutputType>
113class MPCMatcher : public testing::Matcher<OutputType> {
114private:
115 InputType input;
116 OutputType match_value;
117 double ulp_tolerance;
118 RoundingMode rounding;
119
120public:
121 MPCMatcher(InputType testInput, double ulp_tolerance, RoundingMode rounding)
122 : input(testInput), ulp_tolerance(ulp_tolerance), rounding(rounding) {}
123
124 bool match(OutputType libcResult) {
125 match_value = libcResult;
126 return match(input, match_value);
127 }
128
129 void explainError() override { // NOLINT
130 explain_error(input, match_value);
131 }
132
133private:
134 template <typename InType, typename OutType>
135 bool match(InType in, OutType out) {
136 if (cpp::is_same_v<InType, OutType>) {
137 return compare_unary_operation_single_output_same_type(
138 op, in, out, ulp_tolerance, rounding);
139 } else {
140 return compare_unary_operation_single_output_different_type(
141 op, in, out, ulp_tolerance, rounding);
142 }
143 }
144
145 template <typename T, typename U>
146 bool match(const BinaryInput<T> &in, U out) {
147 return compare_binary_operation_one_output(op, in, out, ulp_tolerance,
148 rounding);
149 }
150
151 template <typename InType, typename OutType>
152 void explain_error(InType in, OutType out) {
153 if (cpp::is_same_v<InType, OutType>) {
154 explain_unary_operation_single_output_same_type_error(
155 op, in, out, ulp_tolerance, rounding);
156 } else {
157 explain_unary_operation_single_output_different_type_error(
158 op, in, out, ulp_tolerance, rounding);
159 }
160 }
161
162 template <typename T, typename U>
163 void explain_error(const BinaryInput<T> &in, U out) {
164 explain_binary_operation_one_output_error(op, in, out, ulp_tolerance,
165 rounding);
166 }
167};
168
169} // namespace internal
170
171// Return true if the input and ouput types for the operation op are valid
172// types.
173template <Operation op, typename InputType, typename OutputType>
174constexpr bool is_valid_operation() {
175 return (Operation::BeginBinaryOperationsSingleOutput < op &&
176 op < Operation::EndBinaryOperationsSingleOutput &&
177 cpp::is_complex_type_same<InputType, OutputType>() &&
178 cpp::is_complex_v<InputType>) ||
179 (Operation::BeginUnaryOperationsSingleOutputSameOutputType < op &&
180 op < Operation::EndUnaryOperationsSingleOutputSameOutputType &&
181 cpp::is_complex_type_same<InputType, OutputType>() &&
182 cpp::is_complex_v<InputType>) ||
183 (Operation::BeginUnaryOperationsSingleOutputDifferentOutputType < op &&
184 op < Operation::EndUnaryOperationsSingleOutputDifferentOutputType &&
185 cpp::is_same_v<make_real_t<InputType>, OutputType> &&
186 cpp::is_complex_v<InputType>);
187}
188
189template <Operation op, typename InputType, typename OutputType>
190cpp::enable_if_t<is_valid_operation<op, InputType, OutputType>(),
191 internal::MPCMatcher<op, InputType, OutputType>>
192get_mpc_matcher(InputType input, [[maybe_unused]] OutputType output,
193 double ulp_tolerance, RoundingMode rounding) {
194 return internal::MPCMatcher<op, InputType, OutputType>(input, ulp_tolerance,
195 rounding);
196}
197
198} // namespace mpc
199} // namespace testing
200} // namespace LIBC_NAMESPACE_DECL
201
// Checks `match_value` (the libc result) with EXPECT_THAT against the MPC
// reference for operation `op` on `input`, within `ulp_tolerance`, using
// RoundingMode::Nearest.
#define EXPECT_MPC_MATCH_DEFAULT(op, input, match_value, ulp_tolerance)        \
  EXPECT_THAT(match_value,                                                     \
              LIBC_NAMESPACE::testing::mpc::get_mpc_matcher<op>(               \
                  input, match_value, ulp_tolerance,                           \
                  LIBC_NAMESPACE::fputil::testing::RoundingMode::Nearest))

// Same as EXPECT_MPC_MATCH_DEFAULT, but the caller supplies the rounding mode
// passed to the matcher. This macro does not itself change the floating
// point environment's rounding mode.
#define EXPECT_MPC_MATCH_ROUNDING(op, input, match_value, ulp_tolerance,       \
                                  rounding)                                    \
  EXPECT_THAT(match_value, LIBC_NAMESPACE::testing::mpc::get_mpc_matcher<op>(  \
                               input, match_value, ulp_tolerance, rounding))
212
// Helper for EXPECT_MPC_MATCH_ALL_ROUNDING. It constructs a
// MPCRND::ForceRoundingMode for `rounding` and runs EXPECT_MPC_MATCH_ROUNDING
// only if that object reports success. `MPCRND` must be the namespace alias
// declared by the calling macro. The local avoids the reserved name `__r`
// (identifiers containing a double underscore are reserved).
#define EXPECT_MPC_MATCH_ALL_ROUNDING_HELPER(op, input, match_value,           \
                                             ulp_tolerance, rounding)          \
  {                                                                            \
    MPCRND::ForceRoundingMode mpc_force_rounding(rounding);                    \
    if (mpc_force_rounding.success) {                                          \
      EXPECT_MPC_MATCH_ROUNDING(op, input, match_value, ulp_tolerance,         \
                                rounding);                                     \
    }                                                                          \
  }

// Runs the check once for each rounding mode obtained by casting 0..3 to
// fputil::testing::RoundingMode.
// NOTE(review): assumes those four values are exactly the enumerators --
// confirm against RoundingModeUtils.h.
// The loop counter has a macro-specific name. With a plain `i`, an `input` or
// `match_value` argument mentioning a caller variable `i` would silently refer
// to the loop counter instead.
#define EXPECT_MPC_MATCH_ALL_ROUNDING(op, input, match_value, ulp_tolerance)   \
  {                                                                            \
    namespace MPCRND = LIBC_NAMESPACE::fputil::testing;                        \
    for (int mpc_rounding_index = 0; mpc_rounding_index < 4;                   \
         mpc_rounding_index++) {                                               \
      MPCRND::RoundingMode r_mode =                                            \
          static_cast<MPCRND::RoundingMode>(mpc_rounding_index);               \
      EXPECT_MPC_MATCH_ALL_ROUNDING_HELPER(op, input, match_value,             \
                                           ulp_tolerance, r_mode);             \
    }                                                                          \
  }
232
// Evaluates to the bool returned by MPCMatcher::match() for `match_value`
// against the MPC reference for `op` on `input`, without going through
// EXPECT_THAT/ASSERT_THAT. The whole expression is parenthesized so the macro
// composes safely inside larger expressions.
#define TEST_MPC_MATCH_ROUNDING(op, input, match_value, ulp_tolerance,         \
                                rounding)                                      \
  (LIBC_NAMESPACE::testing::mpc::get_mpc_matcher<op>(                          \
       input, match_value, ulp_tolerance, rounding)                            \
       .match(match_value))
238
// ASSERT_THAT counterpart of EXPECT_MPC_MATCH_DEFAULT. It checks `match_value`
// against the MPC reference for operation `op` on `input`, within
// `ulp_tolerance`, using RoundingMode::Nearest.
#define ASSERT_MPC_MATCH_DEFAULT(op, input, match_value, ulp_tolerance)        \
  ASSERT_THAT(match_value,                                                     \
              LIBC_NAMESPACE::testing::mpc::get_mpc_matcher<op>(               \
                  input, match_value, ulp_tolerance,                           \
                  LIBC_NAMESPACE::fputil::testing::RoundingMode::Nearest))

// ASSERT_THAT counterpart of EXPECT_MPC_MATCH_ROUNDING. The caller supplies
// the rounding mode passed to the matcher.
#define ASSERT_MPC_MATCH_ROUNDING(op, input, match_value, ulp_tolerance,       \
                                  rounding)                                    \
  ASSERT_THAT(match_value, LIBC_NAMESPACE::testing::mpc::get_mpc_matcher<op>(  \
                               input, match_value, ulp_tolerance, rounding))
249
// Helper for ASSERT_MPC_MATCH_ALL_ROUNDING. It constructs a
// MPCRND::ForceRoundingMode for `rounding` and runs ASSERT_MPC_MATCH_ROUNDING
// only if that object reports success. `MPCRND` must be the namespace alias
// declared by the calling macro. The local avoids the reserved name `__r`
// (identifiers containing a double underscore are reserved).
#define ASSERT_MPC_MATCH_ALL_ROUNDING_HELPER(op, input, match_value,           \
                                             ulp_tolerance, rounding)          \
  {                                                                            \
    MPCRND::ForceRoundingMode mpc_force_rounding(rounding);                    \
    if (mpc_force_rounding.success) {                                          \
      ASSERT_MPC_MATCH_ROUNDING(op, input, match_value, ulp_tolerance,         \
                                rounding);                                     \
    }                                                                          \
  }

// Runs the assertion once for each rounding mode obtained by casting 0..3 to
// fputil::testing::RoundingMode.
// NOTE(review): assumes those four values are exactly the enumerators --
// confirm against RoundingModeUtils.h.
// The loop counter has a macro-specific name. With a plain `i`, an `input` or
// `match_value` argument mentioning a caller variable `i` would silently refer
// to the loop counter instead.
#define ASSERT_MPC_MATCH_ALL_ROUNDING(op, input, match_value, ulp_tolerance)   \
  {                                                                            \
    namespace MPCRND = LIBC_NAMESPACE::fputil::testing;                        \
    for (int mpc_rounding_index = 0; mpc_rounding_index < 4;                   \
         mpc_rounding_index++) {                                               \
      MPCRND::RoundingMode r_mode =                                            \
          static_cast<MPCRND::RoundingMode>(mpc_rounding_index);               \
      ASSERT_MPC_MATCH_ALL_ROUNDING_HELPER(op, input, match_value,             \
                                           ulp_tolerance, r_mode);             \
    }                                                                          \
  }
269
270#endif // LLVM_LIBC_UTILS_MPCWRAPPER_MPCUTILS_H
271

source code of libc/utils/MPCWrapper/MPCUtils.h