1 | //===-- MPCUtils.h ----------------------------------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef LLVM_LIBC_UTILS_MPCWRAPPER_MPCUTILS_H |
10 | #define LLVM_LIBC_UTILS_MPCWRAPPER_MPCUTILS_H |
11 | |
12 | #include "src/__support/CPP/type_traits.h" |
13 | #include "src/__support/complex_type.h" |
14 | #include "src/__support/macros/config.h" |
15 | #include "src/__support/macros/properties/complex_types.h" |
16 | #include "src/__support/macros/properties/types.h" |
17 | #include "test/UnitTest/RoundingModeUtils.h" |
18 | #include "test/UnitTest/Test.h" |
19 | |
20 | #include <stdint.h> |
21 | |
22 | namespace LIBC_NAMESPACE_DECL { |
23 | namespace testing { |
24 | namespace mpc { |
25 | |
// Operations that can be checked against MPC. Operations are grouped by
// signature; each group is delimited by a Begin*/End* sentinel pair.
// is_valid_operation() classifies an operation by comparing it (with `<`)
// against these sentinels, so a new operation must be added strictly between
// the sentinels of the group whose signature it has, and the sentinels
// themselves must never be used as operations.
enum class Operation {
  // Operations which take a single complex floating point number as input
  // and produce a single floating point number as output which has the same
  // floating point type as the real/imaginary part of the input.
  BeginUnaryOperationsSingleOutputDifferentOutputType,
  Carg,
  Cabs,
  EndUnaryOperationsSingleOutputDifferentOutputType,

  // Operations which take a single complex floating point number as input
  // and produce a single complex floating point number of the same kind
  // as output.
  BeginUnaryOperationsSingleOutputSameOutputType,
  Cproj,
  Csqrt,
  Clog,
  Cexp,
  Csinh,
  Ccosh,
  Ctanh,
  Casinh,
  Cacosh,
  Catanh,
  Csin,
  Ccos,
  Ctan,
  Casin,
  Cacos,
  Catan,
  EndUnaryOperationsSingleOutputSameOutputType,

  // Operations which take two complex floating point numbers as input
  // and produce a single complex floating point number of the same kind
  // as output.
  BeginBinaryOperationsSingleOutput,
  Cpow,
  EndBinaryOperationsSingleOutput,
};
64 | |
65 | using LIBC_NAMESPACE::fputil::testing::RoundingMode; |
66 | |
67 | template <typename T> struct BinaryInput { |
68 | static_assert(LIBC_NAMESPACE::cpp::is_complex_v<T>, |
69 | "Template parameter of BinaryInput must be a complex floating " |
70 | "point type." ); |
71 | |
72 | using Type = T; |
73 | T x, y; |
74 | }; |
75 | |
76 | namespace internal { |
77 | |
78 | template <typename InputType, typename OutputType> |
79 | bool compare_unary_operation_single_output_same_type(Operation op, |
80 | InputType input, |
81 | OutputType libc_output, |
82 | double ulp_tolerance, |
83 | RoundingMode rounding); |
84 | |
85 | template <typename InputType, typename OutputType> |
86 | bool compare_unary_operation_single_output_different_type( |
87 | Operation op, InputType input, OutputType libc_output, double ulp_tolerance, |
88 | RoundingMode rounding); |
89 | |
90 | template <typename InputType, typename OutputType> |
91 | bool compare_binary_operation_one_output(Operation op, |
92 | const BinaryInput<InputType> &input, |
93 | OutputType libc_output, |
94 | double ulp_tolerance, |
95 | RoundingMode rounding); |
96 | |
97 | template <typename InputType, typename OutputType> |
98 | void explain_unary_operation_single_output_same_type_error( |
99 | Operation op, InputType input, OutputType match_value, double ulp_tolerance, |
100 | RoundingMode rounding); |
101 | |
102 | template <typename InputType, typename OutputType> |
103 | void explain_unary_operation_single_output_different_type_error( |
104 | Operation op, InputType input, OutputType match_value, double ulp_tolerance, |
105 | RoundingMode rounding); |
106 | |
107 | template <typename InputType, typename OutputType> |
108 | void explain_binary_operation_one_output_error( |
109 | Operation op, const BinaryInput<InputType> &input, OutputType match_value, |
110 | double ulp_tolerance, RoundingMode rounding); |
111 | |
112 | template <Operation op, typename InputType, typename OutputType> |
113 | class MPCMatcher : public testing::Matcher<OutputType> { |
114 | private: |
115 | InputType input; |
116 | OutputType match_value; |
117 | double ulp_tolerance; |
118 | RoundingMode rounding; |
119 | |
120 | public: |
121 | MPCMatcher(InputType testInput, double ulp_tolerance, RoundingMode rounding) |
122 | : input(testInput), ulp_tolerance(ulp_tolerance), rounding(rounding) {} |
123 | |
124 | bool match(OutputType libcResult) { |
125 | match_value = libcResult; |
126 | return match(input, match_value); |
127 | } |
128 | |
129 | void explainError() override { // NOLINT |
130 | explain_error(input, match_value); |
131 | } |
132 | |
133 | private: |
134 | template <typename InType, typename OutType> |
135 | bool match(InType in, OutType out) { |
136 | if (cpp::is_same_v<InType, OutType>) { |
137 | return compare_unary_operation_single_output_same_type( |
138 | op, in, out, ulp_tolerance, rounding); |
139 | } else { |
140 | return compare_unary_operation_single_output_different_type( |
141 | op, in, out, ulp_tolerance, rounding); |
142 | } |
143 | } |
144 | |
145 | template <typename T, typename U> |
146 | bool match(const BinaryInput<T> &in, U out) { |
147 | return compare_binary_operation_one_output(op, in, out, ulp_tolerance, |
148 | rounding); |
149 | } |
150 | |
151 | template <typename InType, typename OutType> |
152 | void explain_error(InType in, OutType out) { |
153 | if (cpp::is_same_v<InType, OutType>) { |
154 | explain_unary_operation_single_output_same_type_error( |
155 | op, in, out, ulp_tolerance, rounding); |
156 | } else { |
157 | explain_unary_operation_single_output_different_type_error( |
158 | op, in, out, ulp_tolerance, rounding); |
159 | } |
160 | } |
161 | |
162 | template <typename T, typename U> |
163 | void explain_error(const BinaryInput<T> &in, U out) { |
164 | explain_binary_operation_one_output_error(op, in, out, ulp_tolerance, |
165 | rounding); |
166 | } |
167 | }; |
168 | |
169 | } // namespace internal |
170 | |
171 | // Return true if the input and ouput types for the operation op are valid |
172 | // types. |
173 | template <Operation op, typename InputType, typename OutputType> |
174 | constexpr bool is_valid_operation() { |
175 | return (Operation::BeginBinaryOperationsSingleOutput < op && |
176 | op < Operation::EndBinaryOperationsSingleOutput && |
177 | cpp::is_complex_type_same<InputType, OutputType>() && |
178 | cpp::is_complex_v<InputType>) || |
179 | (Operation::BeginUnaryOperationsSingleOutputSameOutputType < op && |
180 | op < Operation::EndUnaryOperationsSingleOutputSameOutputType && |
181 | cpp::is_complex_type_same<InputType, OutputType>() && |
182 | cpp::is_complex_v<InputType>) || |
183 | (Operation::BeginUnaryOperationsSingleOutputDifferentOutputType < op && |
184 | op < Operation::EndUnaryOperationsSingleOutputDifferentOutputType && |
185 | cpp::is_same_v<make_real_t<InputType>, OutputType> && |
186 | cpp::is_complex_v<InputType>); |
187 | } |
188 | |
189 | template <Operation op, typename InputType, typename OutputType> |
190 | cpp::enable_if_t<is_valid_operation<op, InputType, OutputType>(), |
191 | internal::MPCMatcher<op, InputType, OutputType>> |
192 | get_mpc_matcher(InputType input, [[maybe_unused]] OutputType output, |
193 | double ulp_tolerance, RoundingMode rounding) { |
194 | return internal::MPCMatcher<op, InputType, OutputType>(input, ulp_tolerance, |
195 | rounding); |
196 | } |
197 | |
198 | } // namespace mpc |
199 | } // namespace testing |
200 | } // namespace LIBC_NAMESPACE_DECL |
201 | |
// Non-fatal check that `match_value` (typically a call to the libc function
// under test) agrees with MPC for operation `op` on `input`, within
// `ulp_tolerance`, using RoundingMode::Nearest. `match_value` appears twice in
// the expansion (as the checked value and as the argument used to deduce the
// matcher's output type), so the expression should not have side effects
// that matter.
#define EXPECT_MPC_MATCH_DEFAULT(op, input, match_value, ulp_tolerance)        \
  EXPECT_THAT(match_value,                                                     \
              LIBC_NAMESPACE::testing::mpc::get_mpc_matcher<op>(               \
                  input, match_value, ulp_tolerance,                           \
                  LIBC_NAMESPACE::fputil::testing::RoundingMode::Nearest))

// Like EXPECT_MPC_MATCH_DEFAULT, but the matcher is given the explicit
// `rounding` mode. This macro does not itself change the floating-point
// environment's rounding mode (see EXPECT_MPC_MATCH_ALL_ROUNDING for that).
#define EXPECT_MPC_MATCH_ROUNDING(op, input, match_value, ulp_tolerance,       \
                                  rounding)                                    \
  EXPECT_THAT(match_value, LIBC_NAMESPACE::testing::mpc::get_mpc_matcher<op>(  \
                               input, match_value, ulp_tolerance, rounding))
212 | |
// Runs EXPECT_MPC_MATCH_ROUNDING with `rounding` forced as the current
// rounding mode; the check is skipped if the mode could not be set. The guard
// type is fully qualified so this helper does not depend on a namespace alias
// declared by its caller, and the guard's name avoids reserved identifiers
// (such as ones containing a double underscore).
#define EXPECT_MPC_MATCH_ALL_ROUNDING_HELPER(op, input, match_value,           \
                                             ulp_tolerance, rounding)          \
  {                                                                            \
    LIBC_NAMESPACE::fputil::testing::ForceRoundingMode                         \
        libc_mpc_rounding_guard_(rounding);                                    \
    if (libc_mpc_rounding_guard_.success) {                                    \
      EXPECT_MPC_MATCH_ROUNDING(op, input, match_value, ulp_tolerance,         \
                                rounding);                                     \
    }                                                                          \
  }

// Checks `match_value` against MPC under each rounding mode in turn;
// `match_value` is re-evaluated after each mode is forced. Assumes the
// RoundingMode enumerators are the values 0..3 (TODO confirm against
// RoundingModeUtils.h). The loop variables use a `libc_mpc_` prefix so they
// do not capture caller variables (e.g. `i` in `x[i]`) used in the
// `input`/`match_value` arguments.
#define EXPECT_MPC_MATCH_ALL_ROUNDING(op, input, match_value, ulp_tolerance)   \
  {                                                                            \
    for (int libc_mpc_rm_index_ = 0; libc_mpc_rm_index_ < 4;                   \
         ++libc_mpc_rm_index_) {                                               \
      LIBC_NAMESPACE::fputil::testing::RoundingMode libc_mpc_r_mode_ =         \
          static_cast<LIBC_NAMESPACE::fputil::testing::RoundingMode>(          \
              libc_mpc_rm_index_);                                             \
      EXPECT_MPC_MATCH_ALL_ROUNDING_HELPER(op, input, match_value,             \
                                           ulp_tolerance, libc_mpc_r_mode_);   \
    }                                                                          \
  }
232 | |
// Evaluates to the bool returned by MPCMatcher::match for `match_value`,
// without reporting through EXPECT_/ASSERT_. Useful when the caller wants to
// act on the comparison result itself.
#define TEST_MPC_MATCH_ROUNDING(op, input, match_value, ulp_tolerance,         \
                                rounding)                                      \
  LIBC_NAMESPACE::testing::mpc::get_mpc_matcher<op>(input, match_value,        \
                                                    ulp_tolerance, rounding)   \
      .match(match_value)
238 | |
// ASSERT_THAT counterpart of EXPECT_MPC_MATCH_DEFAULT: checks `match_value`
// against MPC for `op` on `input`, within `ulp_tolerance`, using
// RoundingMode::Nearest. `match_value` appears twice in the expansion.
#define ASSERT_MPC_MATCH_DEFAULT(op, input, match_value, ulp_tolerance)        \
  ASSERT_THAT(match_value,                                                     \
              LIBC_NAMESPACE::testing::mpc::get_mpc_matcher<op>(               \
                  input, match_value, ulp_tolerance,                           \
                  LIBC_NAMESPACE::fputil::testing::RoundingMode::Nearest))

// ASSERT_THAT counterpart of EXPECT_MPC_MATCH_ROUNDING. The matcher is given
// `rounding`; the floating-point environment is not changed here.
#define ASSERT_MPC_MATCH_ROUNDING(op, input, match_value, ulp_tolerance,       \
                                  rounding)                                    \
  ASSERT_THAT(match_value, LIBC_NAMESPACE::testing::mpc::get_mpc_matcher<op>(  \
                               input, match_value, ulp_tolerance, rounding))
249 | |
// Runs ASSERT_MPC_MATCH_ROUNDING with `rounding` forced as the current
// rounding mode; the check is skipped if the mode could not be set. The guard
// type is fully qualified so this helper does not depend on a namespace alias
// declared by its caller, and the guard's name avoids reserved identifiers
// (such as ones containing a double underscore).
#define ASSERT_MPC_MATCH_ALL_ROUNDING_HELPER(op, input, match_value,           \
                                             ulp_tolerance, rounding)          \
  {                                                                            \
    LIBC_NAMESPACE::fputil::testing::ForceRoundingMode                         \
        libc_mpc_rounding_guard_(rounding);                                    \
    if (libc_mpc_rounding_guard_.success) {                                    \
      ASSERT_MPC_MATCH_ROUNDING(op, input, match_value, ulp_tolerance,         \
                                rounding);                                     \
    }                                                                          \
  }

// ASSERT_ counterpart of EXPECT_MPC_MATCH_ALL_ROUNDING: checks `match_value`
// under each rounding mode in turn, re-evaluating it after each mode is
// forced. Assumes the RoundingMode enumerators are the values 0..3 (TODO
// confirm against RoundingModeUtils.h). The loop variables use a `libc_mpc_`
// prefix so they do not capture caller variables (e.g. `i` in `x[i]`) used in
// the `input`/`match_value` arguments.
#define ASSERT_MPC_MATCH_ALL_ROUNDING(op, input, match_value, ulp_tolerance)   \
  {                                                                            \
    for (int libc_mpc_rm_index_ = 0; libc_mpc_rm_index_ < 4;                   \
         ++libc_mpc_rm_index_) {                                               \
      LIBC_NAMESPACE::fputil::testing::RoundingMode libc_mpc_r_mode_ =         \
          static_cast<LIBC_NAMESPACE::fputil::testing::RoundingMode>(          \
              libc_mpc_rm_index_);                                             \
      ASSERT_MPC_MATCH_ALL_ROUNDING_HELPER(op, input, match_value,             \
                                           ulp_tolerance, libc_mpc_r_mode_);   \
    }                                                                          \
  }
269 | |
270 | #endif // LLVM_LIBC_UTILS_MPCWRAPPER_MPCUTILS_H |
271 | |