1//===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
10
11#include "mlir/Dialect/Arith/IR/Arith.h"
12#include "mlir/Dialect/Complex/IR/Complex.h"
13#include "mlir/Dialect/Math/IR/Math.h"
14#include "mlir/IR/ImplicitLocOpBuilder.h"
15#include "mlir/IR/PatternMatch.h"
16#include "mlir/Pass/Pass.h"
17#include "mlir/Transforms/DialectConversion.h"
18#include <memory>
19#include <type_traits>
20
21namespace mlir {
22#define GEN_PASS_DEF_CONVERTCOMPLEXTOSTANDARD
23#include "mlir/Conversion/Passes.h.inc"
24} // namespace mlir
25
26using namespace mlir;
27
28namespace {
29
30enum class AbsFn { abs, sqrt, rsqrt };
31
32// Returns the absolute value, its square root or its reciprocal square root.
33Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
34 ImplicitLocOpBuilder &b, AbsFn fn = AbsFn::abs) {
35 Value one = b.create<arith::ConstantOp>(real.getType(),
36 b.getFloatAttr(real.getType(), 1.0));
37
38 Value absReal = b.create<math::AbsFOp>(real, fmf);
39 Value absImag = b.create<math::AbsFOp>(imag, fmf);
40
41 Value max = b.create<arith::MaximumFOp>(absReal, absImag, fmf);
42 Value min = b.create<arith::MinimumFOp>(absReal, absImag, fmf);
43 Value ratio = b.create<arith::DivFOp>(min, max, fmf);
44 Value ratioSq = b.create<arith::MulFOp>(ratio, ratio, fmf);
45 Value ratioSqPlusOne = b.create<arith::AddFOp>(ratioSq, one, fmf);
46 Value result;
47
48 if (fn == AbsFn::rsqrt) {
49 ratioSqPlusOne = b.create<math::RsqrtOp>(ratioSqPlusOne, fmf);
50 min = b.create<math::RsqrtOp>(min, fmf);
51 max = b.create<math::RsqrtOp>(max, fmf);
52 }
53
54 if (fn == AbsFn::sqrt) {
55 Value quarter = b.create<arith::ConstantOp>(
56 real.getType(), b.getFloatAttr(real.getType(), 0.25));
57 // sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily.
58 Value sqrt = b.create<math::SqrtOp>(max, fmf);
59 Value p025 = b.create<math::PowFOp>(ratioSqPlusOne, quarter, fmf);
60 result = b.create<arith::MulFOp>(sqrt, p025, fmf);
61 } else {
62 Value sqrt = b.create<math::SqrtOp>(ratioSqPlusOne, fmf);
63 result = b.create<arith::MulFOp>(max, sqrt, fmf);
64 }
65
66 Value isNaN =
67 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, result, result, fmf);
68 return b.create<arith::SelectOp>(isNaN, min, result);
69}
70
71struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
72 using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
73
74 LogicalResult
75 matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
76 ConversionPatternRewriter &rewriter) const override {
77 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
78
79 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
80
81 Value real = b.create<complex::ReOp>(adaptor.getComplex());
82 Value imag = b.create<complex::ImOp>(adaptor.getComplex());
83 rewriter.replaceOp(op, computeAbs(real, imag, fmf, b));
84
85 return success();
86 }
87};
88
89// atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2))
90struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> {
91 using OpConversionPattern<complex::Atan2Op>::OpConversionPattern;
92
93 LogicalResult
94 matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor,
95 ConversionPatternRewriter &rewriter) const override {
96 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
97
98 auto type = cast<ComplexType>(op.getType());
99 Type elementType = type.getElementType();
100 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
101
102 Value lhs = adaptor.getLhs();
103 Value rhs = adaptor.getRhs();
104
105 Value rhsSquared = b.create<complex::MulOp>(type, rhs, rhs, fmf);
106 Value lhsSquared = b.create<complex::MulOp>(type, lhs, lhs, fmf);
107 Value rhsSquaredPlusLhsSquared =
108 b.create<complex::AddOp>(type, rhsSquared, lhsSquared, fmf);
109 Value sqrtOfRhsSquaredPlusLhsSquared =
110 b.create<complex::SqrtOp>(type, rhsSquaredPlusLhsSquared, fmf);
111
112 Value zero =
113 b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
114 Value one = b.create<arith::ConstantOp>(elementType,
115 b.getFloatAttr(elementType, 1));
116 Value i = b.create<complex::CreateOp>(type, zero, one);
117 Value iTimesLhs = b.create<complex::MulOp>(i, lhs, fmf);
118 Value rhsPlusILhs = b.create<complex::AddOp>(rhs, iTimesLhs, fmf);
119
120 Value divResult = b.create<complex::DivOp>(
121 rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared, fmf);
122 Value logResult = b.create<complex::LogOp>(divResult, fmf);
123
124 Value negativeOne = b.create<arith::ConstantOp>(
125 elementType, b.getFloatAttr(elementType, -1));
126 Value negativeI = b.create<complex::CreateOp>(type, zero, negativeOne);
127
128 rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult, fmf);
129 return success();
130 }
131};
132
133template <typename ComparisonOp, arith::CmpFPredicate p>
134struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
135 using OpConversionPattern<ComparisonOp>::OpConversionPattern;
136 using ResultCombiner =
137 std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
138 arith::AndIOp, arith::OrIOp>;
139
140 LogicalResult
141 matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor,
142 ConversionPatternRewriter &rewriter) const override {
143 auto loc = op.getLoc();
144 auto type = cast<ComplexType>(adaptor.getLhs().getType()).getElementType();
145
146 Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getLhs());
147 Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getLhs());
148 Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getRhs());
149 Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getRhs());
150 Value realComparison =
151 rewriter.create<arith::CmpFOp>(loc, p, realLhs, realRhs);
152 Value imagComparison =
153 rewriter.create<arith::CmpFOp>(loc, p, imagLhs, imagRhs);
154
155 rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison,
156 imagComparison);
157 return success();
158 }
159};
160
161// Default conversion which applies the BinaryStandardOp separately on the real
162// and imaginary parts. Can for example be used for complex::AddOp and
163// complex::SubOp.
164template <typename BinaryComplexOp, typename BinaryStandardOp>
165struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
166 using OpConversionPattern<BinaryComplexOp>::OpConversionPattern;
167
168 LogicalResult
169 matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor,
170 ConversionPatternRewriter &rewriter) const override {
171 auto type = cast<ComplexType>(adaptor.getLhs().getType());
172 auto elementType = cast<FloatType>(type.getElementType());
173 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
174 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
175
176 Value realLhs = b.create<complex::ReOp>(elementType, adaptor.getLhs());
177 Value realRhs = b.create<complex::ReOp>(elementType, adaptor.getRhs());
178 Value resultReal = b.create<BinaryStandardOp>(elementType, realLhs, realRhs,
179 fmf.getValue());
180 Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.getLhs());
181 Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.getRhs());
182 Value resultImag = b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs,
183 fmf.getValue());
184 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
185 resultImag);
186 return success();
187 }
188};
189
190template <typename TrigonometricOp>
191struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
192 using OpAdaptor = typename OpConversionPattern<TrigonometricOp>::OpAdaptor;
193
194 using OpConversionPattern<TrigonometricOp>::OpConversionPattern;
195
196 LogicalResult
197 matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor,
198 ConversionPatternRewriter &rewriter) const override {
199 auto loc = op.getLoc();
200 auto type = cast<ComplexType>(adaptor.getComplex().getType());
201 auto elementType = cast<FloatType>(type.getElementType());
202 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
203
204 Value real =
205 rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
206 Value imag =
207 rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
208
209 // Trigonometric ops use a set of common building blocks to convert to real
210 // ops. Here we create these building blocks and call into an op-specific
211 // implementation in the subclass to combine them.
212 Value half = rewriter.create<arith::ConstantOp>(
213 loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
214 Value exp = rewriter.create<math::ExpOp>(loc, imag, fmf);
215 Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp, fmf);
216 Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp, fmf);
217 Value sin = rewriter.create<math::SinOp>(loc, real, fmf);
218 Value cos = rewriter.create<math::CosOp>(loc, real, fmf);
219
220 auto resultPair =
221 combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter, fmf);
222
223 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultPair.first,
224 resultPair.second);
225 return success();
226 }
227
228 virtual std::pair<Value, Value>
229 combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
230 Value cos, ConversionPatternRewriter &rewriter,
231 arith::FastMathFlagsAttr fmf) const = 0;
232};
233
234struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> {
235 using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion;
236
237 std::pair<Value, Value> combine(Location loc, Value scaledExp,
238 Value reciprocalExp, Value sin, Value cos,
239 ConversionPatternRewriter &rewriter,
240 arith::FastMathFlagsAttr fmf) const override {
241 // Complex cosine is defined as;
242 // cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy)))
243 // Plugging in:
244 // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
245 // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
246 // and defining t := exp(y)
247 // We get:
248 // Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x
249 // Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x
250 Value sum =
251 rewriter.create<arith::AddFOp>(loc, reciprocalExp, scaledExp, fmf);
252 Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, cos, fmf);
253 Value diff =
254 rewriter.create<arith::SubFOp>(loc, reciprocalExp, scaledExp, fmf);
255 Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, sin, fmf);
256 return {resultReal, resultImag};
257 }
258};
259
// Lowers complex.div to real arithmetic.
//
// The quotient is computed with Smith's algorithm (both branches are emitted
// and one is picked with a select). If both parts of that result are NaN, the
// result is patched with one of three special cases, in priority order:
//   1. zero denominator          -> signed infinity times the numerator,
//   2. infinite numerator, finite denominator,
//   3. finite numerator, infinite denominator -> signed zero.
// NOTE(review): this special-case structure appears to mirror the example
// complex division in C99 Annex G (G.5.1); confirm before relying on exact
// conformance.
struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
  using OpConversionPattern<complex::DivOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto type = cast<ComplexType>(adaptor.getLhs().getType());
    auto elementType = cast<FloatType>(type.getElementType());
    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();

    // Split both operands into their real and imaginary parts.
    Value lhsReal =
        rewriter.create<complex::ReOp>(loc, elementType, adaptor.getLhs());
    Value lhsImag =
        rewriter.create<complex::ImOp>(loc, elementType, adaptor.getLhs());
    Value rhsReal =
        rewriter.create<complex::ReOp>(loc, elementType, adaptor.getRhs());
    Value rhsImag =
        rewriter.create<complex::ImOp>(loc, elementType, adaptor.getRhs());

    // Smith's algorithm to divide complex numbers. It is just a bit smarter
    // way to compute the following formula:
    //  (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i)
    //    = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) /
    //          ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i))
    //    = ((lhsReal * rhsReal + lhsImag * rhsImag) +
    //          (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2
    //
    // Depending on whether |rhsReal| < |rhsImag| we compute either
    //   rhsRealImagRatio = rhsReal / rhsImag
    //   rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio
    //   resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom
    //   resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom
    //
    // or
    //
    //   rhsImagRealRatio = rhsImag / rhsReal
    //   rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio
    //   resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom
    //   resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom
    //
    // See https://dl.acm.org/citation.cfm?id=368661 for more details.
    //
    // Both variants are emitted unconditionally; the one to use is selected
    // further down via `realAbsSmallerThanImagAbs`.

    // Variant 1 (intended for |rhsReal| < |rhsImag|).
    Value rhsRealImagRatio =
        rewriter.create<arith::DivFOp>(loc, rhsReal, rhsImag, fmf);
    Value rhsRealImagDenom = rewriter.create<arith::AddFOp>(
        loc, rhsImag,
        rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal, fmf),
        fmf);
    Value realNumerator1 = rewriter.create<arith::AddFOp>(
        loc,
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio, fmf),
        lhsImag, fmf);
    Value resultReal1 = rewriter.create<arith::DivFOp>(loc, realNumerator1,
                                                       rhsRealImagDenom, fmf);
    Value imagNumerator1 = rewriter.create<arith::SubFOp>(
        loc,
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio, fmf),
        lhsReal, fmf);
    Value resultImag1 = rewriter.create<arith::DivFOp>(loc, imagNumerator1,
                                                       rhsRealImagDenom, fmf);

    // Variant 2 (used otherwise).
    Value rhsImagRealRatio =
        rewriter.create<arith::DivFOp>(loc, rhsImag, rhsReal, fmf);
    Value rhsImagRealDenom = rewriter.create<arith::AddFOp>(
        loc, rhsReal,
        rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag, fmf),
        fmf);
    Value realNumerator2 = rewriter.create<arith::AddFOp>(
        loc, lhsReal,
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio, fmf),
        fmf);
    Value resultReal2 = rewriter.create<arith::DivFOp>(loc, realNumerator2,
                                                       rhsImagRealDenom, fmf);
    Value imagNumerator2 = rewriter.create<arith::SubFOp>(
        loc, lhsImag,
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio, fmf),
        fmf);
    Value resultImag2 = rewriter.create<arith::DivFOp>(loc, imagNumerator2,
                                                       rhsImagRealDenom, fmf);

    // Consider corner cases.
    // Case 1. Zero denominator, numerator contains at most one NaN value.
    // Result: copysign(inf, rhsReal) * lhs, computed component-wise.
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elementType, rewriter.getZeroAttr(elementType));
    Value rhsRealAbs = rewriter.create<math::AbsFOp>(loc, rhsReal, fmf);
    Value rhsRealIsZero = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero);
    Value rhsImagAbs = rewriter.create<math::AbsFOp>(loc, rhsImag, fmf);
    Value rhsImagIsZero = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero);
    // ORD against the constant zero is true exactly when the operand is not
    // NaN.
    Value lhsRealIsNotNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ORD, lhsReal, zero);
    Value lhsImagIsNotNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ORD, lhsImag, zero);
    Value lhsContainsNotNaNValue =
        rewriter.create<arith::OrIOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
    Value resultIsInfinity = rewriter.create<arith::AndIOp>(
        loc, lhsContainsNotNaNValue,
        rewriter.create<arith::AndIOp>(loc, rhsRealIsZero, rhsImagIsZero));
    Value inf = rewriter.create<arith::ConstantOp>(
        loc, elementType,
        rewriter.getFloatAttr(
            elementType, APFloat::getInf(elementType.getFloatSemantics())));
    Value infWithSignOfRhsReal =
        rewriter.create<math::CopySignOp>(loc, inf, rhsReal);
    Value infinityResultReal =
        rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal, fmf);
    Value infinityResultImag =
        rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag, fmf);

    // Case 2. Infinite numerator, finite denominator.
    // Each lhs component is replaced by copysign(isinf ? 1 : 0, component) and
    // the naive product with conj(rhs) is scaled by inf.
    // ONE (ordered and not equal) is false for NaN, so a NaN rhs component
    // does not count as finite.
    Value rhsRealFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf);
    Value rhsImagFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf);
    Value rhsFinite =
        rewriter.create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite);
    Value lhsRealAbs = rewriter.create<math::AbsFOp>(loc, lhsReal, fmf);
    Value lhsRealInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
    Value lhsImagAbs = rewriter.create<math::AbsFOp>(loc, lhsImag, fmf);
    Value lhsImagInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
    Value lhsInfinite =
        rewriter.create<arith::OrIOp>(loc, lhsRealInfinite, lhsImagInfinite);
    Value infNumFiniteDenom =
        rewriter.create<arith::AndIOp>(loc, lhsInfinite, rhsFinite);
    Value one = rewriter.create<arith::ConstantOp>(
        loc, elementType, rewriter.getFloatAttr(elementType, 1));
    Value lhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, lhsRealInfinite, one, zero),
        lhsReal);
    Value lhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, lhsImagInfinite, one, zero),
        lhsImag);
    Value lhsRealIsInfWithSignTimesRhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal, fmf);
    Value lhsImagIsInfWithSignTimesRhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag, fmf);
    Value resultReal3 = rewriter.create<arith::MulFOp>(
        loc, inf,
        rewriter.create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal,
                                       lhsImagIsInfWithSignTimesRhsImag, fmf),
        fmf);
    Value lhsRealIsInfWithSignTimesRhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag, fmf);
    Value lhsImagIsInfWithSignTimesRhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal, fmf);
    Value resultImag3 = rewriter.create<arith::MulFOp>(
        loc, inf,
        rewriter.create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal,
                                       lhsRealIsInfWithSignTimesRhsImag, fmf),
        fmf);

    // Case 3: Finite numerator, infinite denominator.
    // Each rhs component is replaced by copysign(isinf ? 1 : 0, component) and
    // the naive product is multiplied by zero, producing a signed zero.
    Value lhsRealFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf);
    Value lhsImagFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf);
    Value lhsFinite =
        rewriter.create<arith::AndIOp>(loc, lhsRealFinite, lhsImagFinite);
    Value rhsRealInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
    Value rhsImagInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
    Value rhsInfinite =
        rewriter.create<arith::OrIOp>(loc, rhsRealInfinite, rhsImagInfinite);
    Value finiteNumInfiniteDenom =
        rewriter.create<arith::AndIOp>(loc, lhsFinite, rhsInfinite);
    Value rhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, rhsRealInfinite, one, zero),
        rhsReal);
    Value rhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, rhsImagInfinite, one, zero),
        rhsImag);
    Value rhsRealIsInfWithSignTimesLhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign, fmf);
    Value rhsImagIsInfWithSignTimesLhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign, fmf);
    Value resultReal4 = rewriter.create<arith::MulFOp>(
        loc, zero,
        rewriter.create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal,
                                       rhsImagIsInfWithSignTimesLhsImag, fmf),
        fmf);
    Value rhsRealIsInfWithSignTimesLhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign, fmf);
    Value rhsImagIsInfWithSignTimesLhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign, fmf);
    Value resultImag4 = rewriter.create<arith::MulFOp>(
        loc, zero,
        rewriter.create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag,
                                       rhsImagIsInfWithSignTimesLhsReal, fmf),
        fmf);

    // Pick Smith's variant. OLT is false if either magnitude is NaN, in which
    // case variant 2 is used.
    Value realAbsSmallerThanImagAbs = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
    Value resultReal = rewriter.create<arith::SelectOp>(
        loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2);
    Value resultImag = rewriter.create<arith::SelectOp>(
        loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2);
    // Layer the special cases so that Case 1 overrides Case 2, which
    // overrides Case 3, which overrides Smith's result.
    Value resultRealSpecialCase3 = rewriter.create<arith::SelectOp>(
        loc, finiteNumInfiniteDenom, resultReal4, resultReal);
    Value resultImagSpecialCase3 = rewriter.create<arith::SelectOp>(
        loc, finiteNumInfiniteDenom, resultImag4, resultImag);
    Value resultRealSpecialCase2 = rewriter.create<arith::SelectOp>(
        loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
    Value resultImagSpecialCase2 = rewriter.create<arith::SelectOp>(
        loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
    Value resultRealSpecialCase1 = rewriter.create<arith::SelectOp>(
        loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2);
    Value resultImagSpecialCase1 = rewriter.create<arith::SelectOp>(
        loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2);

    // The special cases are only consulted when Smith's result is NaN in both
    // components.
    Value resultRealIsNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::UNO, resultReal, zero);
    Value resultImagIsNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::UNO, resultImag, zero);
    Value resultIsNaN =
        rewriter.create<arith::AndIOp>(loc, resultRealIsNaN, resultImagIsNaN);
    Value resultRealWithSpecialCases = rewriter.create<arith::SelectOp>(
        loc, resultIsNaN, resultRealSpecialCase1, resultReal);
    Value resultImagWithSpecialCases = rewriter.create<arith::SelectOp>(
        loc, resultIsNaN, resultImagSpecialCase1, resultImag);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(
        op, type, resultRealWithSpecialCases, resultImagWithSpecialCases);
    return success();
  }
};
489
490struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
491 using OpConversionPattern<complex::ExpOp>::OpConversionPattern;
492
493 LogicalResult
494 matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
495 ConversionPatternRewriter &rewriter) const override {
496 auto loc = op.getLoc();
497 auto type = cast<ComplexType>(adaptor.getComplex().getType());
498 auto elementType = cast<FloatType>(type.getElementType());
499 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
500
501 Value real =
502 rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
503 Value imag =
504 rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
505 Value expReal = rewriter.create<math::ExpOp>(loc, real, fmf.getValue());
506 Value cosImag = rewriter.create<math::CosOp>(loc, imag, fmf.getValue());
507 Value resultReal =
508 rewriter.create<arith::MulFOp>(loc, expReal, cosImag, fmf.getValue());
509 Value sinImag = rewriter.create<math::SinOp>(loc, imag, fmf.getValue());
510 Value resultImag =
511 rewriter.create<arith::MulFOp>(loc, expReal, sinImag, fmf.getValue());
512
513 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
514 resultImag);
515 return success();
516 }
517};
518
519struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
520 using OpConversionPattern<complex::Expm1Op>::OpConversionPattern;
521
522 LogicalResult
523 matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor,
524 ConversionPatternRewriter &rewriter) const override {
525 auto type = cast<ComplexType>(adaptor.getComplex().getType());
526 auto elementType = cast<FloatType>(type.getElementType());
527 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
528
529 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
530 Value exp = b.create<complex::ExpOp>(adaptor.getComplex(), fmf.getValue());
531
532 Value real = b.create<complex::ReOp>(elementType, exp);
533 Value one = b.create<arith::ConstantOp>(elementType,
534 b.getFloatAttr(elementType, 1));
535 Value realMinusOne = b.create<arith::SubFOp>(real, one, fmf.getValue());
536 Value imag = b.create<complex::ImOp>(elementType, exp);
537
538 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realMinusOne,
539 imag);
540 return success();
541 }
542};
543
544struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
545 using OpConversionPattern<complex::LogOp>::OpConversionPattern;
546
547 LogicalResult
548 matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
549 ConversionPatternRewriter &rewriter) const override {
550 auto type = cast<ComplexType>(adaptor.getComplex().getType());
551 auto elementType = cast<FloatType>(type.getElementType());
552 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
553 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
554
555 Value abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex(),
556 fmf.getValue());
557 Value resultReal = b.create<math::LogOp>(elementType, abs, fmf.getValue());
558 Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
559 Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
560 Value resultImag =
561 b.create<math::Atan2Op>(elementType, imag, real, fmf.getValue());
562 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
563 resultImag);
564 return success();
565 }
566};
567
568struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
569 using OpConversionPattern<complex::Log1pOp>::OpConversionPattern;
570
571 LogicalResult
572 matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
573 ConversionPatternRewriter &rewriter) const override {
574 auto type = cast<ComplexType>(adaptor.getComplex().getType());
575 auto elementType = cast<FloatType>(type.getElementType());
576 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
577 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
578
579 Value real = b.create<complex::ReOp>(adaptor.getComplex());
580 Value imag = b.create<complex::ImOp>(adaptor.getComplex());
581
582 Value half = b.create<arith::ConstantOp>(elementType,
583 b.getFloatAttr(elementType, 0.5));
584 Value one = b.create<arith::ConstantOp>(elementType,
585 b.getFloatAttr(elementType, 1));
586 Value realPlusOne = b.create<arith::AddFOp>(real, one, fmf);
587 Value absRealPlusOne = b.create<math::AbsFOp>(realPlusOne, fmf);
588 Value absImag = b.create<math::AbsFOp>(imag, fmf);
589
590 Value maxAbs = b.create<arith::MaximumFOp>(absRealPlusOne, absImag, fmf);
591 Value minAbs = b.create<arith::MinimumFOp>(absRealPlusOne, absImag, fmf);
592
593 Value useReal = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT,
594 realPlusOne, absImag, fmf);
595 Value maxMinusOne = b.create<arith::SubFOp>(maxAbs, one, fmf);
596 Value maxAbsOfRealPlusOneAndImagMinusOne =
597 b.create<arith::SelectOp>(useReal, real, maxMinusOne);
598 Value minMaxRatio = b.create<arith::DivFOp>(minAbs, maxAbs, fmf);
599 Value logOfMaxAbsOfRealPlusOneAndImag =
600 b.create<math::Log1pOp>(maxAbsOfRealPlusOneAndImagMinusOne, fmf);
601 Value logOfSqrtPart = b.create<math::Log1pOp>(
602 b.create<arith::MulFOp>(minMaxRatio, minMaxRatio, fmf), fmf);
603 Value r = b.create<arith::AddFOp>(
604 b.create<arith::MulFOp>(half, logOfSqrtPart, fmf),
605 logOfMaxAbsOfRealPlusOneAndImag, fmf);
606 Value resultReal = b.create<arith::SelectOp>(
607 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, r, r, fmf), minAbs,
608 r);
609 Value resultImag = b.create<math::Atan2Op>(imag, realPlusOne, fmf);
610 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
611 resultImag);
612 return success();
613 }
614};
615
616struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
617 using OpConversionPattern<complex::MulOp>::OpConversionPattern;
618
619 LogicalResult
620 matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
621 ConversionPatternRewriter &rewriter) const override {
622 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
623 auto type = cast<ComplexType>(adaptor.getLhs().getType());
624 auto elementType = cast<FloatType>(type.getElementType());
625 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
626 auto fmfValue = fmf.getValue();
627
628 Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.getLhs());
629 Value lhsRealAbs = b.create<math::AbsFOp>(lhsReal, fmfValue);
630 Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.getLhs());
631 Value lhsImagAbs = b.create<math::AbsFOp>(lhsImag, fmfValue);
632 Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.getRhs());
633 Value rhsRealAbs = b.create<math::AbsFOp>(rhsReal, fmfValue);
634 Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.getRhs());
635 Value rhsImagAbs = b.create<math::AbsFOp>(rhsImag, fmfValue);
636
637 Value lhsRealTimesRhsReal =
638 b.create<arith::MulFOp>(lhsReal, rhsReal, fmfValue);
639 Value lhsRealTimesRhsRealAbs =
640 b.create<math::AbsFOp>(lhsRealTimesRhsReal, fmfValue);
641 Value lhsImagTimesRhsImag =
642 b.create<arith::MulFOp>(lhsImag, rhsImag, fmfValue);
643 Value lhsImagTimesRhsImagAbs =
644 b.create<math::AbsFOp>(lhsImagTimesRhsImag, fmfValue);
645 Value real = b.create<arith::SubFOp>(lhsRealTimesRhsReal,
646 lhsImagTimesRhsImag, fmfValue);
647
648 Value lhsImagTimesRhsReal =
649 b.create<arith::MulFOp>(lhsImag, rhsReal, fmfValue);
650 Value lhsImagTimesRhsRealAbs =
651 b.create<math::AbsFOp>(lhsImagTimesRhsReal, fmfValue);
652 Value lhsRealTimesRhsImag =
653 b.create<arith::MulFOp>(lhsReal, rhsImag, fmfValue);
654 Value lhsRealTimesRhsImagAbs =
655 b.create<math::AbsFOp>(lhsRealTimesRhsImag, fmfValue);
656 Value imag = b.create<arith::AddFOp>(lhsImagTimesRhsReal,
657 lhsRealTimesRhsImag, fmfValue);
658
659 // Handle cases where the "naive" calculation results in NaN values.
660 Value realIsNan =
661 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real);
662 Value imagIsNan =
663 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, imag, imag);
664 Value isNan = b.create<arith::AndIOp>(realIsNan, imagIsNan);
665
666 Value inf = b.create<arith::ConstantOp>(
667 elementType,
668 b.getFloatAttr(elementType,
669 APFloat::getInf(elementType.getFloatSemantics())));
670
671 // Case 1. `lhsReal` or `lhsImag` are infinite.
672 Value lhsRealIsInf =
673 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
674 Value lhsImagIsInf =
675 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
676 Value lhsIsInf = b.create<arith::OrIOp>(lhsRealIsInf, lhsImagIsInf);
677 Value rhsRealIsNan =
678 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsReal, rhsReal);
679 Value rhsImagIsNan =
680 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsImag, rhsImag);
681 Value zero =
682 b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
683 Value one = b.create<arith::ConstantOp>(elementType,
684 b.getFloatAttr(elementType, 1));
685 Value lhsRealIsInfFloat =
686 b.create<arith::SelectOp>(lhsRealIsInf, one, zero);
687 lhsReal = b.create<arith::SelectOp>(
688 lhsIsInf, b.create<math::CopySignOp>(lhsRealIsInfFloat, lhsReal),
689 lhsReal);
690 Value lhsImagIsInfFloat =
691 b.create<arith::SelectOp>(lhsImagIsInf, one, zero);
692 lhsImag = b.create<arith::SelectOp>(
693 lhsIsInf, b.create<math::CopySignOp>(lhsImagIsInfFloat, lhsImag),
694 lhsImag);
695 Value lhsIsInfAndRhsRealIsNan =
696 b.create<arith::AndIOp>(lhsIsInf, rhsRealIsNan);
697 rhsReal = b.create<arith::SelectOp>(
698 lhsIsInfAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
699 rhsReal);
700 Value lhsIsInfAndRhsImagIsNan =
701 b.create<arith::AndIOp>(lhsIsInf, rhsImagIsNan);
702 rhsImag = b.create<arith::SelectOp>(
703 lhsIsInfAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
704 rhsImag);
705
706 // Case 2. `rhsReal` or `rhsImag` are infinite.
707 Value rhsRealIsInf =
708 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
709 Value rhsImagIsInf =
710 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
711 Value rhsIsInf = b.create<arith::OrIOp>(rhsRealIsInf, rhsImagIsInf);
712 Value lhsRealIsNan =
713 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsReal, lhsReal);
714 Value lhsImagIsNan =
715 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsImag, lhsImag);
716 Value rhsRealIsInfFloat =
717 b.create<arith::SelectOp>(rhsRealIsInf, one, zero);
718 rhsReal = b.create<arith::SelectOp>(
719 rhsIsInf, b.create<math::CopySignOp>(rhsRealIsInfFloat, rhsReal),
720 rhsReal);
721 Value rhsImagIsInfFloat =
722 b.create<arith::SelectOp>(rhsImagIsInf, one, zero);
723 rhsImag = b.create<arith::SelectOp>(
724 rhsIsInf, b.create<math::CopySignOp>(rhsImagIsInfFloat, rhsImag),
725 rhsImag);
726 Value rhsIsInfAndLhsRealIsNan =
727 b.create<arith::AndIOp>(rhsIsInf, lhsRealIsNan);
728 lhsReal = b.create<arith::SelectOp>(
729 rhsIsInfAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
730 lhsReal);
731 Value rhsIsInfAndLhsImagIsNan =
732 b.create<arith::AndIOp>(rhsIsInf, lhsImagIsNan);
733 lhsImag = b.create<arith::SelectOp>(
734 rhsIsInfAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
735 lhsImag);
736 Value recalc = b.create<arith::OrIOp>(lhsIsInf, rhsIsInf);
737
738 // Case 3. One of the pairwise products of left hand side with right hand
739 // side is infinite.
740 Value lhsRealTimesRhsRealIsInf = b.create<arith::CmpFOp>(
741 arith::CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf);
742 Value lhsImagTimesRhsImagIsInf = b.create<arith::CmpFOp>(
743 arith::CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf);
744 Value isSpecialCase = b.create<arith::OrIOp>(lhsRealTimesRhsRealIsInf,
745 lhsImagTimesRhsImagIsInf);
746 Value lhsRealTimesRhsImagIsInf = b.create<arith::CmpFOp>(
747 arith::CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf);
748 isSpecialCase =
749 b.create<arith::OrIOp>(isSpecialCase, lhsRealTimesRhsImagIsInf);
750 Value lhsImagTimesRhsRealIsInf = b.create<arith::CmpFOp>(
751 arith::CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf);
752 isSpecialCase =
753 b.create<arith::OrIOp>(isSpecialCase, lhsImagTimesRhsRealIsInf);
754 Type i1Type = b.getI1Type();
755 Value notRecalc = b.create<arith::XOrIOp>(
756 recalc,
757 b.create<arith::ConstantOp>(i1Type, b.getIntegerAttr(i1Type, 1)));
758 isSpecialCase = b.create<arith::AndIOp>(isSpecialCase, notRecalc);
759 Value isSpecialCaseAndLhsRealIsNan =
760 b.create<arith::AndIOp>(isSpecialCase, lhsRealIsNan);
761 lhsReal = b.create<arith::SelectOp>(
762 isSpecialCaseAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
763 lhsReal);
764 Value isSpecialCaseAndLhsImagIsNan =
765 b.create<arith::AndIOp>(isSpecialCase, lhsImagIsNan);
766 lhsImag = b.create<arith::SelectOp>(
767 isSpecialCaseAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
768 lhsImag);
769 Value isSpecialCaseAndRhsRealIsNan =
770 b.create<arith::AndIOp>(isSpecialCase, rhsRealIsNan);
771 rhsReal = b.create<arith::SelectOp>(
772 isSpecialCaseAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
773 rhsReal);
774 Value isSpecialCaseAndRhsImagIsNan =
775 b.create<arith::AndIOp>(isSpecialCase, rhsImagIsNan);
776 rhsImag = b.create<arith::SelectOp>(
777 isSpecialCaseAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
778 rhsImag);
779 recalc = b.create<arith::OrIOp>(recalc, isSpecialCase);
780 recalc = b.create<arith::AndIOp>(isNan, recalc);
781
782 // Recalculate real part.
783 lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal, fmfValue);
784 lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag, fmfValue);
785 Value newReal = b.create<arith::SubFOp>(lhsRealTimesRhsReal,
786 lhsImagTimesRhsImag, fmfValue);
787 real = b.create<arith::SelectOp>(
788 recalc, b.create<arith::MulFOp>(inf, newReal, fmfValue), real);
789
790 // Recalculate imag part.
791 lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal, fmfValue);
792 lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag, fmfValue);
793 Value newImag = b.create<arith::AddFOp>(lhsImagTimesRhsReal,
794 lhsRealTimesRhsImag, fmfValue);
795 imag = b.create<arith::SelectOp>(
796 recalc, b.create<arith::MulFOp>(inf, newImag, fmfValue), imag);
797
798 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
799 return success();
800 }
801};
802
803struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
804 using OpConversionPattern<complex::NegOp>::OpConversionPattern;
805
806 LogicalResult
807 matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
808 ConversionPatternRewriter &rewriter) const override {
809 auto loc = op.getLoc();
810 auto type = cast<ComplexType>(adaptor.getComplex().getType());
811 auto elementType = cast<FloatType>(type.getElementType());
812
813 Value real =
814 rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
815 Value imag =
816 rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
817 Value negReal = rewriter.create<arith::NegFOp>(loc, real);
818 Value negImag = rewriter.create<arith::NegFOp>(loc, imag);
819 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
820 return success();
821 }
822};
823
824struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
825 using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion;
826
827 std::pair<Value, Value> combine(Location loc, Value scaledExp,
828 Value reciprocalExp, Value sin, Value cos,
829 ConversionPatternRewriter &rewriter,
830 arith::FastMathFlagsAttr fmf) const override {
831 // Complex sine is defined as;
832 // sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy)))
833 // Plugging in:
834 // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
835 // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
836 // and defining t := exp(y)
837 // We get:
838 // Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x
839 // Im(cos(x + iy)) = (0.5*t - 0.5/t) * cos x
840 Value sum =
841 rewriter.create<arith::AddFOp>(loc, scaledExp, reciprocalExp, fmf);
842 Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, sin, fmf);
843 Value diff =
844 rewriter.create<arith::SubFOp>(loc, scaledExp, reciprocalExp, fmf);
845 Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, cos, fmf);
846 return {resultReal, resultImag};
847 }
848};
849
850// The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
851struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
852 using OpConversionPattern<complex::SqrtOp>::OpConversionPattern;
853
854 LogicalResult
855 matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
856 ConversionPatternRewriter &rewriter) const override {
857 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
858
859 auto type = cast<ComplexType>(op.getType());
860 auto elementType = type.getElementType().cast<FloatType>();
861 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
862
863 auto cst = [&](APFloat v) {
864 return b.create<arith::ConstantOp>(elementType,
865 b.getFloatAttr(elementType, v));
866 };
867 const auto &floatSemantics = elementType.getFloatSemantics();
868 Value zero = cst(APFloat::getZero(Sem: floatSemantics));
869 Value half = b.create<arith::ConstantOp>(elementType,
870 b.getFloatAttr(elementType, 0.5));
871
872 Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
873 Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
874 Value absSqrt = computeAbs(real, imag, fmf, b, AbsFn::sqrt);
875 Value argArg = b.create<math::Atan2Op>(imag, real, fmf);
876 Value sqrtArg = b.create<arith::MulFOp>(argArg, half, fmf);
877 Value cos = b.create<math::CosOp>(sqrtArg, fmf);
878 Value sin = b.create<math::SinOp>(sqrtArg, fmf);
879 // sin(atan2(0, inf)) = 0, sqrt(abs(inf)) = inf, but we can't multiply
880 // 0 * inf.
881 Value sinIsZero =
882 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, sin, zero, fmf);
883
884 Value resultReal = b.create<arith::MulFOp>(absSqrt, cos, fmf);
885 Value resultImag = b.create<arith::SelectOp>(
886 sinIsZero, zero, b.create<arith::MulFOp>(absSqrt, sin, fmf));
887 if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
888 arith::FastMathFlags::ninf)) {
889 Value inf = cst(APFloat::getInf(Sem: floatSemantics));
890 Value negInf = cst(APFloat::getInf(Sem: floatSemantics, Negative: true));
891 Value nan = cst(APFloat::getNaN(Sem: floatSemantics));
892 Value absImag = b.create<math::AbsFOp>(elementType, imag, fmf);
893
894 Value absImagIsInf =
895 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
896 Value absImagIsNotInf =
897 b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, absImag, inf, fmf);
898 Value realIsInf =
899 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, inf, fmf);
900 Value realIsNegInf =
901 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, negInf, fmf);
902
903 resultReal = b.create<arith::SelectOp>(
904 b.create<arith::AndIOp>(realIsNegInf, absImagIsNotInf), zero,
905 resultReal);
906 resultReal = b.create<arith::SelectOp>(
907 b.create<arith::OrIOp>(absImagIsInf, realIsInf), inf, resultReal);
908
909 Value imagSignInf = b.create<math::CopySignOp>(inf, imag, fmf);
910 resultImag = b.create<arith::SelectOp>(
911 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, absSqrt, absSqrt),
912 nan, resultImag);
913 resultImag = b.create<arith::SelectOp>(
914 b.create<arith::OrIOp>(absImagIsInf, realIsNegInf), imagSignInf,
915 resultImag);
916 }
917
918 Value resultIsZero =
919 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absSqrt, zero, fmf);
920 resultReal = b.create<arith::SelectOp>(resultIsZero, zero, resultReal);
921 resultImag = b.create<arith::SelectOp>(resultIsZero, zero, resultImag);
922
923 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
924 resultImag);
925 return success();
926 }
927};
928
929struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
930 using OpConversionPattern<complex::SignOp>::OpConversionPattern;
931
932 LogicalResult
933 matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
934 ConversionPatternRewriter &rewriter) const override {
935 auto type = cast<ComplexType>(adaptor.getComplex().getType());
936 auto elementType = cast<FloatType>(type.getElementType());
937 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
938 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
939
940 Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
941 Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
942 Value zero =
943 b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
944 Value realIsZero =
945 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
946 Value imagIsZero =
947 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
948 Value isZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
949 auto abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex(), fmf);
950 Value realSign = b.create<arith::DivFOp>(real, abs, fmf);
951 Value imagSign = b.create<arith::DivFOp>(imag, abs, fmf);
952 Value sign = b.create<complex::CreateOp>(type, realSign, imagSign);
953 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isZero,
954 adaptor.getComplex(), sign);
955 return success();
956 }
957};
958
959struct TanOpConversion : public OpConversionPattern<complex::TanOp> {
960 using OpConversionPattern<complex::TanOp>::OpConversionPattern;
961
962 LogicalResult
963 matchAndRewrite(complex::TanOp op, OpAdaptor adaptor,
964 ConversionPatternRewriter &rewriter) const override {
965 auto loc = op.getLoc();
966 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
967
968 Value cos = rewriter.create<complex::CosOp>(loc, adaptor.getComplex(), fmf);
969 Value sin = rewriter.create<complex::SinOp>(loc, adaptor.getComplex(), fmf);
970 rewriter.replaceOpWithNewOp<complex::DivOp>(op, sin, cos, fmf);
971 return success();
972 }
973};
974
975struct TanhOpConversion : public OpConversionPattern<complex::TanhOp> {
976 using OpConversionPattern<complex::TanhOp>::OpConversionPattern;
977
978 LogicalResult
979 matchAndRewrite(complex::TanhOp op, OpAdaptor adaptor,
980 ConversionPatternRewriter &rewriter) const override {
981 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
982 auto loc = op.getLoc();
983 auto type = cast<ComplexType>(adaptor.getComplex().getType());
984 auto elementType = cast<FloatType>(type.getElementType());
985 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
986 const auto &floatSemantics = elementType.getFloatSemantics();
987
988 Value real =
989 b.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
990 Value imag =
991 b.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
992
993 auto cst = [&](APFloat v) {
994 return b.create<arith::ConstantOp>(elementType,
995 b.getFloatAttr(elementType, v));
996 };
997 Value inf = cst(APFloat::getInf(Sem: floatSemantics));
998 Value negOne = b.create<arith::ConstantOp>(
999 elementType, b.getFloatAttr(elementType, -1.0));
1000 Value four = b.create<arith::ConstantOp>(elementType,
1001 b.getFloatAttr(elementType, 4.0));
1002 Value twoReal = b.create<arith::AddFOp>(real, real, fmf);
1003 Value negTwoReal = b.create<arith::MulFOp>(negOne, twoReal, fmf);
1004
1005 Value expTwoRealMinusOne = b.create<math::ExpM1Op>(twoReal, fmf);
1006 Value expNegTwoRealMinusOne = b.create<math::ExpM1Op>(negTwoReal, fmf);
1007 Value realNum =
1008 b.create<arith::SubFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf);
1009
1010 Value cosImag = b.create<math::CosOp>(imag, fmf);
1011 Value cosImagSq = b.create<arith::MulFOp>(cosImag, cosImag, fmf);
1012 Value twoCosTwoImagPlusOne = b.create<arith::MulFOp>(cosImagSq, four, fmf);
1013 Value sinImag = b.create<math::SinOp>(imag, fmf);
1014
1015 Value imagNum = b.create<arith::MulFOp>(
1016 four, b.create<arith::MulFOp>(cosImag, sinImag, fmf), fmf);
1017
1018 Value expSumMinusTwo =
1019 b.create<arith::AddFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf);
1020 Value denom =
1021 b.create<arith::AddFOp>(expSumMinusTwo, twoCosTwoImagPlusOne, fmf);
1022
1023 Value isInf = b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
1024 expSumMinusTwo, inf, fmf);
1025 Value realLimit = b.create<math::CopySignOp>(negOne, real, fmf);
1026
1027 Value resultReal = b.create<arith::SelectOp>(
1028 isInf, realLimit, b.create<arith::DivFOp>(realNum, denom, fmf));
1029 Value resultImag = b.create<arith::DivFOp>(imagNum, denom, fmf);
1030
1031 if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
1032 arith::FastMathFlags::ninf)) {
1033 Value absReal = b.create<math::AbsFOp>(real, fmf);
1034 Value zero = b.create<arith::ConstantOp>(
1035 elementType, b.getFloatAttr(elementType, 0.0));
1036 Value nan = cst(APFloat::getNaN(Sem: floatSemantics));
1037
1038 Value absRealIsInf =
1039 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf);
1040 Value imagIsZero =
1041 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf);
1042 Value absRealIsNotInf = b.create<arith::XOrIOp>(
1043 absRealIsInf, b.create<arith::ConstantIntOp>(true, /*width=*/1));
1044
1045 Value imagNumIsNaN = b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO,
1046 imagNum, imagNum, fmf);
1047 Value resultRealIsNaN =
1048 b.create<arith::AndIOp>(imagNumIsNaN, absRealIsNotInf);
1049 Value resultImagIsZero = b.create<arith::OrIOp>(
1050 imagIsZero, b.create<arith::AndIOp>(absRealIsInf, imagNumIsNaN));
1051
1052 resultReal = b.create<arith::SelectOp>(resultRealIsNaN, nan, resultReal);
1053 resultImag =
1054 b.create<arith::SelectOp>(resultImagIsZero, zero, resultImag);
1055 }
1056
1057 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
1058 resultImag);
1059 return success();
1060 }
1061};
1062
1063struct ConjOpConversion : public OpConversionPattern<complex::ConjOp> {
1064 using OpConversionPattern<complex::ConjOp>::OpConversionPattern;
1065
1066 LogicalResult
1067 matchAndRewrite(complex::ConjOp op, OpAdaptor adaptor,
1068 ConversionPatternRewriter &rewriter) const override {
1069 auto loc = op.getLoc();
1070 auto type = cast<ComplexType>(adaptor.getComplex().getType());
1071 auto elementType = cast<FloatType>(type.getElementType());
1072 Value real =
1073 rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
1074 Value imag =
1075 rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
1076 Value negImag = rewriter.create<arith::NegFOp>(loc, elementType, imag);
1077
1078 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, negImag);
1079
1080 return success();
1081 }
1082};
1083
/// Lowers lhs^rhs for lhs = a + bi and rhs = c + di to
///   (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
/// where q = c*atan2(b,a) + 0.5d*ln(a*a+b*b).
/// Since abs = complex.abs(lhs) = sqrt(a*a+b*b), the first factor is emitted
/// as pow(abs, c) and 0.5d*ln(a*a+b*b) as d*ln(abs).
///
/// The generic value is then overridden by a chain of selects for special
/// inputs (cases 0-4 below); when several cases apply, the later one wins.
///
/// `builder` supplies the location of every created op; `type` is the
/// complex type of `lhs` and of the result; `c` and `d` are the real and
/// imaginary parts of the exponent; `fmf` is forwarded to the floating-point
/// arithmetic, comparison and math ops. Returns the complex result.
static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
                                 ComplexType type, Value lhs, Value c, Value d,
                                 arith::FastMathFlags fmf) {
  auto elementType = cast<FloatType>(type.getElementType());

  Value a = builder.create<complex::ReOp>(lhs);
  Value b = builder.create<complex::ImOp>(lhs);

  // |lhs|^c == (a*a+b*b)^(0.5c).
  Value abs = builder.create<complex::AbsOp>(lhs, fmf);
  Value absToC = builder.create<math::PowFOp>(abs, c, fmf);

  // exp(-d * atan2(b, a)).
  Value negD = builder.create<arith::NegFOp>(d, fmf);
  Value argLhs = builder.create<math::Atan2Op>(b, a, fmf);
  Value negDArgLhs = builder.create<arith::MulFOp>(negD, argLhs, fmf);
  Value expNegDArgLhs = builder.create<math::ExpOp>(negDArgLhs, fmf);

  // Result magnitude `coeff` and angle q = c*atan2(b, a) + d*ln|lhs|.
  Value coeff = builder.create<arith::MulFOp>(absToC, expNegDArgLhs, fmf);
  Value lnAbs = builder.create<math::LogOp>(abs, fmf);
  Value cArgLhs = builder.create<arith::MulFOp>(c, argLhs, fmf);
  Value dLnAbs = builder.create<arith::MulFOp>(d, lnAbs, fmf);
  Value q = builder.create<arith::AddFOp>(cArgLhs, dLnAbs, fmf);
  Value cosQ = builder.create<math::CosOp>(q, fmf);
  Value sinQ = builder.create<math::SinOp>(q, fmf);

  Value inf = builder.create<arith::ConstantOp>(
      elementType,
      builder.getFloatAttr(elementType,
                           APFloat::getInf(elementType.getFloatSemantics())));
  Value zero = builder.create<arith::ConstantOp>(
      elementType, builder.getFloatAttr(elementType, 0.0));
  Value one = builder.create<arith::ConstantOp>(
      elementType, builder.getFloatAttr(elementType, 1.0));
  Value complexOne = builder.create<complex::CreateOp>(type, one, zero);
  Value complexZero = builder.create<complex::CreateOp>(type, zero, zero);
  Value complexInf = builder.create<complex::CreateOp>(type, inf, zero);

  // Case 0:
  // 0^(c + 0*i) is 0 if c > 0, and 0^0 is defined to be 1.0, see
  // Branch Cuts for Complex Elementary Functions or Much Ado About
  // Nothing's Sign Bit, W. Kahan, Section 10.
  // (Applies when |lhs| == 0, d == 0 and 0 <= c; otherwise the generic
  // coeff * (cos(q) + i*sin(q)) value is used.)
  Value absEqZero =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, abs, zero, fmf);
  Value dEqZero =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, d, zero, fmf);
  Value cEqZero =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, c, zero, fmf);
  Value bEqZero =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, b, zero, fmf);

  Value zeroLeC =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLE, zero, c, fmf);
  Value coeffCosQ = builder.create<arith::MulFOp>(coeff, cosQ, fmf);
  Value coeffSinQ = builder.create<arith::MulFOp>(coeff, sinQ, fmf);
  Value complexOneOrZero =
      builder.create<arith::SelectOp>(cEqZero, complexOne, complexZero);
  Value coeffCosSin =
      builder.create<complex::CreateOp>(type, coeffCosQ, coeffSinQ);
  Value cutoff0 = builder.create<arith::SelectOp>(
      builder.create<arith::AndIOp>(
          builder.create<arith::AndIOp>(absEqZero, dEqZero), zeroLeC),
      complexOneOrZero, coeffCosSin);

  // Case 1:
  // x^0 is defined to be 1 for any x, see
  // Branch Cuts for Complex Elementary Functions or Much Ado About
  // Nothing's Sign Bit, W. Kahan, Section 10.
  Value rhsEqZero = builder.create<arith::AndIOp>(cEqZero, dEqZero);
  Value cutoff1 =
      builder.create<arith::SelectOp>(rhsEqZero, complexOne, cutoff0);

  // Case 2:
  // 1^(c + d*i) = 1 + 0*i
  Value lhsEqOne = builder.create<arith::AndIOp>(
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, one, fmf),
      bEqZero);
  Value cutoff2 =
      builder.create<arith::SelectOp>(lhsEqOne, complexOne, cutoff1);

  // Case 3:
  // inf^(c + 0*i) = inf + 0*i, c > 0
  // (Only lhs == +inf + 0*i is matched here.)
  Value lhsEqInf = builder.create<arith::AndIOp>(
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, inf, fmf),
      bEqZero);
  Value rhsGt0 = builder.create<arith::AndIOp>(
      dEqZero,
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, c, zero, fmf));
  Value cutoff3 = builder.create<arith::SelectOp>(
      builder.create<arith::AndIOp>(lhsEqInf, rhsGt0), complexInf, cutoff2);

  // Case 4:
  // inf^(c + 0*i) = 0 + 0*i, c < 0
  Value rhsLt0 = builder.create<arith::AndIOp>(
      dEqZero,
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, c, zero, fmf));
  Value cutoff4 = builder.create<arith::SelectOp>(
      builder.create<arith::AndIOp>(lhsEqInf, rhsLt0), complexZero, cutoff3);

  return cutoff4;
}
1186
1187struct PowOpConversion : public OpConversionPattern<complex::PowOp> {
1188 using OpConversionPattern<complex::PowOp>::OpConversionPattern;
1189
1190 LogicalResult
1191 matchAndRewrite(complex::PowOp op, OpAdaptor adaptor,
1192 ConversionPatternRewriter &rewriter) const override {
1193 mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
1194 auto type = cast<ComplexType>(adaptor.getLhs().getType());
1195 auto elementType = cast<FloatType>(type.getElementType());
1196
1197 Value c = builder.create<complex::ReOp>(elementType, adaptor.getRhs());
1198 Value d = builder.create<complex::ImOp>(elementType, adaptor.getRhs());
1199
1200 rewriter.replaceOp(op, {powOpConversionImpl(builder, type, adaptor.getLhs(),
1201 c, d, op.getFastmath())});
1202 return success();
1203 }
1204};
1205
1206struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> {
1207 using OpConversionPattern<complex::RsqrtOp>::OpConversionPattern;
1208
1209 LogicalResult
1210 matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor,
1211 ConversionPatternRewriter &rewriter) const override {
1212 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
1213 auto type = cast<ComplexType>(adaptor.getComplex().getType());
1214 auto elementType = cast<FloatType>(type.getElementType());
1215
1216 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
1217
1218 auto cst = [&](APFloat v) {
1219 return b.create<arith::ConstantOp>(elementType,
1220 b.getFloatAttr(elementType, v));
1221 };
1222 const auto &floatSemantics = elementType.getFloatSemantics();
1223 Value zero = cst(APFloat::getZero(Sem: floatSemantics));
1224 Value inf = cst(APFloat::getInf(Sem: floatSemantics));
1225 Value negHalf = b.create<arith::ConstantOp>(
1226 elementType, b.getFloatAttr(elementType, -0.5));
1227 Value nan = cst(APFloat::getNaN(Sem: floatSemantics));
1228
1229 Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
1230 Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
1231 Value absRsqrt = computeAbs(real, imag, fmf, b, AbsFn::rsqrt);
1232 Value argArg = b.create<math::Atan2Op>(imag, real, fmf);
1233 Value rsqrtArg = b.create<arith::MulFOp>(argArg, negHalf, fmf);
1234 Value cos = b.create<math::CosOp>(rsqrtArg, fmf);
1235 Value sin = b.create<math::SinOp>(rsqrtArg, fmf);
1236
1237 Value resultReal = b.create<arith::MulFOp>(absRsqrt, cos, fmf);
1238 Value resultImag = b.create<arith::MulFOp>(absRsqrt, sin, fmf);
1239
1240 if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
1241 arith::FastMathFlags::ninf)) {
1242 Value negOne = b.create<arith::ConstantOp>(
1243 elementType, b.getFloatAttr(elementType, -1));
1244
1245 Value realSignedZero = b.create<math::CopySignOp>(zero, real, fmf);
1246 Value imagSignedZero = b.create<math::CopySignOp>(zero, imag, fmf);
1247 Value negImagSignedZero =
1248 b.create<arith::MulFOp>(negOne, imagSignedZero, fmf);
1249
1250 Value absReal = b.create<math::AbsFOp>(real, fmf);
1251 Value absImag = b.create<math::AbsFOp>(imag, fmf);
1252
1253 Value absImagIsInf =
1254 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
1255 Value realIsNan =
1256 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real, fmf);
1257 Value realIsInf =
1258 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf);
1259 Value inIsNanInf = b.create<arith::AndIOp>(absImagIsInf, realIsNan);
1260
1261 Value resultIsZero = b.create<arith::OrIOp>(inIsNanInf, realIsInf);
1262
1263 resultReal =
1264 b.create<arith::SelectOp>(resultIsZero, realSignedZero, resultReal);
1265 resultImag = b.create<arith::SelectOp>(resultIsZero, negImagSignedZero,
1266 resultImag);
1267 }
1268
1269 Value isRealZero =
1270 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero, fmf);
1271 Value isImagZero =
1272 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf);
1273 Value isZero = b.create<arith::AndIOp>(isRealZero, isImagZero);
1274
1275 resultReal = b.create<arith::SelectOp>(isZero, inf, resultReal);
1276 resultImag = b.create<arith::SelectOp>(isZero, nan, resultImag);
1277
1278 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
1279 resultImag);
1280 return success();
1281 }
1282};
1283
1284struct AngleOpConversion : public OpConversionPattern<complex::AngleOp> {
1285 using OpConversionPattern<complex::AngleOp>::OpConversionPattern;
1286
1287 LogicalResult
1288 matchAndRewrite(complex::AngleOp op, OpAdaptor adaptor,
1289 ConversionPatternRewriter &rewriter) const override {
1290 auto loc = op.getLoc();
1291 auto type = op.getType();
1292 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
1293
1294 Value real =
1295 rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
1296 Value imag =
1297 rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
1298
1299 rewriter.replaceOpWithNewOp<math::Atan2Op>(op, imag, real, fmf);
1300
1301 return success();
1302 }
1303};
1304
1305} // namespace
1306
1307void mlir::populateComplexToStandardConversionPatterns(
1308 RewritePatternSet &patterns) {
1309 // clang-format off
1310 patterns.add<
1311 AbsOpConversion,
1312 AngleOpConversion,
1313 Atan2OpConversion,
1314 BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
1315 BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
1316 ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
1317 ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
1318 ConjOpConversion,
1319 CosOpConversion,
1320 DivOpConversion,
1321 ExpOpConversion,
1322 Expm1OpConversion,
1323 Log1pOpConversion,
1324 LogOpConversion,
1325 MulOpConversion,
1326 NegOpConversion,
1327 SignOpConversion,
1328 SinOpConversion,
1329 SqrtOpConversion,
1330 TanOpConversion,
1331 TanhOpConversion,
1332 PowOpConversion,
1333 RsqrtOpConversion
1334 >(patterns.getContext());
1335 // clang-format on
1336}
1337
1338namespace {
1339struct ConvertComplexToStandardPass
1340 : public impl::ConvertComplexToStandardBase<ConvertComplexToStandardPass> {
1341 void runOnOperation() override;
1342};
1343
1344void ConvertComplexToStandardPass::runOnOperation() {
1345 // Convert to the Standard dialect using the converter defined above.
1346 RewritePatternSet patterns(&getContext());
1347 populateComplexToStandardConversionPatterns(patterns);
1348
1349 ConversionTarget target(getContext());
1350 target.addLegalDialect<arith::ArithDialect, math::MathDialect>();
1351 target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
1352 if (failed(
1353 applyPartialConversion(getOperation(), target, std::move(patterns))))
1354 signalPassFailure();
1355}
1356} // namespace
1357
1358std::unique_ptr<Pass> mlir::createConvertComplexToStandardPass() {
1359 return std::make_unique<ConvertComplexToStandardPass>();
1360}
1361

// source code of mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp