//===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

9#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
10
11#include "mlir/Conversion/ComplexCommon/DivisionConverter.h"
12#include "mlir/Dialect/Arith/IR/Arith.h"
13#include "mlir/Dialect/Complex/IR/Complex.h"
14#include "mlir/Dialect/Math/IR/Math.h"
15#include "mlir/IR/ImplicitLocOpBuilder.h"
16#include "mlir/IR/PatternMatch.h"
17#include "mlir/Transforms/DialectConversion.h"
18#include <type_traits>
19
20namespace mlir {
21#define GEN_PASS_DEF_CONVERTCOMPLEXTOSTANDARDPASS
22#include "mlir/Conversion/Passes.h.inc"
23} // namespace mlir
24
25using namespace mlir;
26
27namespace {
28
/// Selects which function of the complex magnitude `computeAbs` emits: the
/// magnitude itself, its square root, or its reciprocal square root.
enum class AbsFn { abs, sqrt, rsqrt };

31// Returns the absolute value, its square root or its reciprocal square root.
32Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
33 ImplicitLocOpBuilder &b, AbsFn fn = AbsFn::abs) {
34 Value one = b.create<arith::ConstantOp>(args: real.getType(),
35 args: b.getFloatAttr(type: real.getType(), value: 1.0));
36
37 Value absReal = b.create<math::AbsFOp>(args&: real, args&: fmf);
38 Value absImag = b.create<math::AbsFOp>(args&: imag, args&: fmf);
39
40 Value max = b.create<arith::MaximumFOp>(args&: absReal, args&: absImag, args&: fmf);
41 Value min = b.create<arith::MinimumFOp>(args&: absReal, args&: absImag, args&: fmf);
42
43 // The lowering below requires NaNs and infinities to work correctly.
44 arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
45 bits: fmf, bit: arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
46 Value ratio = b.create<arith::DivFOp>(args&: min, args&: max, args&: fmfWithNaNInf);
47 Value ratioSq = b.create<arith::MulFOp>(args&: ratio, args&: ratio, args&: fmfWithNaNInf);
48 Value ratioSqPlusOne = b.create<arith::AddFOp>(args&: ratioSq, args&: one, args&: fmfWithNaNInf);
49 Value result;
50
51 if (fn == AbsFn::rsqrt) {
52 ratioSqPlusOne = b.create<math::RsqrtOp>(args&: ratioSqPlusOne, args&: fmfWithNaNInf);
53 min = b.create<math::RsqrtOp>(args&: min, args&: fmfWithNaNInf);
54 max = b.create<math::RsqrtOp>(args&: max, args&: fmfWithNaNInf);
55 }
56
57 if (fn == AbsFn::sqrt) {
58 Value quarter = b.create<arith::ConstantOp>(
59 args: real.getType(), args: b.getFloatAttr(type: real.getType(), value: 0.25));
60 // sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily.
61 Value sqrt = b.create<math::SqrtOp>(args&: max, args&: fmfWithNaNInf);
62 Value p025 = b.create<math::PowFOp>(args&: ratioSqPlusOne, args&: quarter, args&: fmfWithNaNInf);
63 result = b.create<arith::MulFOp>(args&: sqrt, args&: p025, args&: fmfWithNaNInf);
64 } else {
65 Value sqrt = b.create<math::SqrtOp>(args&: ratioSqPlusOne, args&: fmfWithNaNInf);
66 result = b.create<arith::MulFOp>(args&: max, args&: sqrt, args&: fmfWithNaNInf);
67 }
68
69 Value isNaN = b.create<arith::CmpFOp>(args: arith::CmpFPredicate::UNO, args&: result,
70 args&: result, args&: fmfWithNaNInf);
71 return b.create<arith::SelectOp>(args&: isNaN, args&: min, args&: result);
72}
73
74struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
75 using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
76
77 LogicalResult
78 matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
79 ConversionPatternRewriter &rewriter) const override {
80 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
81
82 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
83
84 Value real = b.create<complex::ReOp>(args: adaptor.getComplex());
85 Value imag = b.create<complex::ImOp>(args: adaptor.getComplex());
86 rewriter.replaceOp(op, newValues: computeAbs(real, imag, fmf, b));
87
88 return success();
89 }
90};
91
92// atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2))
93struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> {
94 using OpConversionPattern<complex::Atan2Op>::OpConversionPattern;
95
96 LogicalResult
97 matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor,
98 ConversionPatternRewriter &rewriter) const override {
99 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
100
101 auto type = cast<ComplexType>(Val: op.getType());
102 Type elementType = type.getElementType();
103 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
104
105 Value lhs = adaptor.getLhs();
106 Value rhs = adaptor.getRhs();
107
108 Value rhsSquared = b.create<complex::MulOp>(args&: type, args&: rhs, args&: rhs, args&: fmf);
109 Value lhsSquared = b.create<complex::MulOp>(args&: type, args&: lhs, args&: lhs, args&: fmf);
110 Value rhsSquaredPlusLhsSquared =
111 b.create<complex::AddOp>(args&: type, args&: rhsSquared, args&: lhsSquared, args&: fmf);
112 Value sqrtOfRhsSquaredPlusLhsSquared =
113 b.create<complex::SqrtOp>(args&: type, args&: rhsSquaredPlusLhsSquared, args&: fmf);
114
115 Value zero =
116 b.create<arith::ConstantOp>(args&: elementType, args: b.getZeroAttr(type: elementType));
117 Value one = b.create<arith::ConstantOp>(args&: elementType,
118 args: b.getFloatAttr(type: elementType, value: 1));
119 Value i = b.create<complex::CreateOp>(args&: type, args&: zero, args&: one);
120 Value iTimesLhs = b.create<complex::MulOp>(args&: i, args&: lhs, args&: fmf);
121 Value rhsPlusILhs = b.create<complex::AddOp>(args&: rhs, args&: iTimesLhs, args&: fmf);
122
123 Value divResult = b.create<complex::DivOp>(
124 args&: rhsPlusILhs, args&: sqrtOfRhsSquaredPlusLhsSquared, args&: fmf);
125 Value logResult = b.create<complex::LogOp>(args&: divResult, args&: fmf);
126
127 Value negativeOne = b.create<arith::ConstantOp>(
128 args&: elementType, args: b.getFloatAttr(type: elementType, value: -1));
129 Value negativeI = b.create<complex::CreateOp>(args&: type, args&: zero, args&: negativeOne);
130
131 rewriter.replaceOpWithNewOp<complex::MulOp>(op, args&: negativeI, args&: logResult, args&: fmf);
132 return success();
133 }
134};
135
136template <typename ComparisonOp, arith::CmpFPredicate p>
137struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
138 using OpConversionPattern<ComparisonOp>::OpConversionPattern;
139 using ResultCombiner =
140 std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
141 arith::AndIOp, arith::OrIOp>;
142
143 LogicalResult
144 matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor,
145 ConversionPatternRewriter &rewriter) const override {
146 auto loc = op.getLoc();
147 auto type = cast<ComplexType>(adaptor.getLhs().getType()).getElementType();
148
149 Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getLhs());
150 Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getLhs());
151 Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getRhs());
152 Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getRhs());
153 Value realComparison =
154 rewriter.create<arith::CmpFOp>(loc, p, realLhs, realRhs);
155 Value imagComparison =
156 rewriter.create<arith::CmpFOp>(loc, p, imagLhs, imagRhs);
157
158 rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison,
159 imagComparison);
160 return success();
161 }
162};
163
164// Default conversion which applies the BinaryStandardOp separately on the real
165// and imaginary parts. Can for example be used for complex::AddOp and
166// complex::SubOp.
167template <typename BinaryComplexOp, typename BinaryStandardOp>
168struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
169 using OpConversionPattern<BinaryComplexOp>::OpConversionPattern;
170
171 LogicalResult
172 matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor,
173 ConversionPatternRewriter &rewriter) const override {
174 auto type = cast<ComplexType>(adaptor.getLhs().getType());
175 auto elementType = cast<FloatType>(type.getElementType());
176 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
177 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
178
179 Value realLhs = b.create<complex::ReOp>(elementType, adaptor.getLhs());
180 Value realRhs = b.create<complex::ReOp>(elementType, adaptor.getRhs());
181 Value resultReal = b.create<BinaryStandardOp>(elementType, realLhs, realRhs,
182 fmf.getValue());
183 Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.getLhs());
184 Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.getRhs());
185 Value resultImag = b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs,
186 fmf.getValue());
187 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
188 resultImag);
189 return success();
190 }
191};
192
193template <typename TrigonometricOp>
194struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
195 using OpAdaptor = typename OpConversionPattern<TrigonometricOp>::OpAdaptor;
196
197 using OpConversionPattern<TrigonometricOp>::OpConversionPattern;
198
199 LogicalResult
200 matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor,
201 ConversionPatternRewriter &rewriter) const override {
202 auto loc = op.getLoc();
203 auto type = cast<ComplexType>(adaptor.getComplex().getType());
204 auto elementType = cast<FloatType>(type.getElementType());
205 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
206
207 Value real =
208 rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
209 Value imag =
210 rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
211
212 // Trigonometric ops use a set of common building blocks to convert to real
213 // ops. Here we create these building blocks and call into an op-specific
214 // implementation in the subclass to combine them.
215 Value half = rewriter.create<arith::ConstantOp>(
216 loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
217 Value exp = rewriter.create<math::ExpOp>(loc, imag, fmf);
218 Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp, fmf);
219 Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp, fmf);
220 Value sin = rewriter.create<math::SinOp>(loc, real, fmf);
221 Value cos = rewriter.create<math::CosOp>(loc, real, fmf);
222
223 auto resultPair =
224 combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter, fmf);
225
226 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultPair.first,
227 resultPair.second);
228 return success();
229 }
230
231 virtual std::pair<Value, Value>
232 combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
233 Value cos, ConversionPatternRewriter &rewriter,
234 arith::FastMathFlagsAttr fmf) const = 0;
235};
236
237struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> {
238 using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion;
239
240 std::pair<Value, Value> combine(Location loc, Value scaledExp,
241 Value reciprocalExp, Value sin, Value cos,
242 ConversionPatternRewriter &rewriter,
243 arith::FastMathFlagsAttr fmf) const override {
244 // Complex cosine is defined as;
245 // cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy)))
246 // Plugging in:
247 // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
248 // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
249 // and defining t := exp(y)
250 // We get:
251 // Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x
252 // Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x
253 Value sum =
254 rewriter.create<arith::AddFOp>(location: loc, args&: reciprocalExp, args&: scaledExp, args&: fmf);
255 Value resultReal = rewriter.create<arith::MulFOp>(location: loc, args&: sum, args&: cos, args&: fmf);
256 Value diff =
257 rewriter.create<arith::SubFOp>(location: loc, args&: reciprocalExp, args&: scaledExp, args&: fmf);
258 Value resultImag = rewriter.create<arith::MulFOp>(location: loc, args&: diff, args&: sin, args&: fmf);
259 return {resultReal, resultImag};
260 }
261};
262
263struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
264 DivOpConversion(MLIRContext *context, complex::ComplexRangeFlags target)
265 : OpConversionPattern<complex::DivOp>(context), complexRange(target) {}
266
267 using OpConversionPattern<complex::DivOp>::OpConversionPattern;
268
269 LogicalResult
270 matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
271 ConversionPatternRewriter &rewriter) const override {
272 auto loc = op.getLoc();
273 auto type = cast<ComplexType>(Val: adaptor.getLhs().getType());
274 auto elementType = cast<FloatType>(Val: type.getElementType());
275 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
276
277 Value lhsReal =
278 rewriter.create<complex::ReOp>(location: loc, args&: elementType, args: adaptor.getLhs());
279 Value lhsImag =
280 rewriter.create<complex::ImOp>(location: loc, args&: elementType, args: adaptor.getLhs());
281 Value rhsReal =
282 rewriter.create<complex::ReOp>(location: loc, args&: elementType, args: adaptor.getRhs());
283 Value rhsImag =
284 rewriter.create<complex::ImOp>(location: loc, args&: elementType, args: adaptor.getRhs());
285
286 Value resultReal, resultImag;
287
288 if (complexRange == complex::ComplexRangeFlags::basic ||
289 complexRange == complex::ComplexRangeFlags::none) {
290 mlir::complex::convertDivToStandardUsingAlgebraic(
291 rewriter, loc, lhsRe: lhsReal, lhsIm: lhsImag, rhsRe: rhsReal, rhsIm: rhsImag, fmf, resultRe: &resultReal,
292 resultIm: &resultImag);
293 } else if (complexRange == complex::ComplexRangeFlags::improved) {
294 mlir::complex::convertDivToStandardUsingRangeReduction(
295 rewriter, loc, lhsRe: lhsReal, lhsIm: lhsImag, rhsRe: rhsReal, rhsIm: rhsImag, fmf, resultRe: &resultReal,
296 resultIm: &resultImag);
297 }
298
299 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, args&: type, args&: resultReal,
300 args&: resultImag);
301
302 return success();
303 }
304
305private:
306 complex::ComplexRangeFlags complexRange;
307};
308
309struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
310 using OpConversionPattern<complex::ExpOp>::OpConversionPattern;
311
312 LogicalResult
313 matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
314 ConversionPatternRewriter &rewriter) const override {
315 auto loc = op.getLoc();
316 auto type = cast<ComplexType>(Val: adaptor.getComplex().getType());
317 auto elementType = cast<FloatType>(Val: type.getElementType());
318 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
319
320 Value real =
321 rewriter.create<complex::ReOp>(location: loc, args&: elementType, args: adaptor.getComplex());
322 Value imag =
323 rewriter.create<complex::ImOp>(location: loc, args&: elementType, args: adaptor.getComplex());
324 Value expReal = rewriter.create<math::ExpOp>(location: loc, args&: real, args: fmf.getValue());
325 Value cosImag = rewriter.create<math::CosOp>(location: loc, args&: imag, args: fmf.getValue());
326 Value resultReal =
327 rewriter.create<arith::MulFOp>(location: loc, args&: expReal, args&: cosImag, args: fmf.getValue());
328 Value sinImag = rewriter.create<math::SinOp>(location: loc, args&: imag, args: fmf.getValue());
329 Value resultImag =
330 rewriter.create<arith::MulFOp>(location: loc, args&: expReal, args&: sinImag, args: fmf.getValue());
331
332 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, args&: type, args&: resultReal,
333 args&: resultImag);
334 return success();
335 }
336};
337
338Value evaluatePolynomial(ImplicitLocOpBuilder &b, Value arg,
339 ArrayRef<double> coefficients,
340 arith::FastMathFlagsAttr fmf) {
341 auto argType = mlir::cast<FloatType>(Val: arg.getType());
342 Value poly =
343 b.create<arith::ConstantOp>(args: b.getFloatAttr(type: argType, value: coefficients[0]));
344 for (unsigned i = 1; i < coefficients.size(); ++i) {
345 poly = b.create<math::FmaOp>(
346 args&: poly, args&: arg,
347 args: b.create<arith::ConstantOp>(args: b.getFloatAttr(type: argType, value: coefficients[i])),
348 args&: fmf);
349 }
350 return poly;
351}
352
353struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
354 using OpConversionPattern<complex::Expm1Op>::OpConversionPattern;
355
356 // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i
357 // [handle inaccuracies when a and/or b are small]
358 // = ((e^a - 1) * cos(b) + cos(b) - 1) + e^a*sin(b)i
359 // = (expm1(a) * cos(b) + cosm1(b)) + e^a*sin(b)i
360 LogicalResult
361 matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor,
362 ConversionPatternRewriter &rewriter) const override {
363 auto type = op.getType();
364 auto elemType = mlir::cast<FloatType>(Val: type.getElementType());
365
366 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
367 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
368 Value real = b.create<complex::ReOp>(args: adaptor.getComplex());
369 Value imag = b.create<complex::ImOp>(args: adaptor.getComplex());
370
371 Value zero = b.create<arith::ConstantOp>(args: b.getFloatAttr(type: elemType, value: 0.0));
372 Value one = b.create<arith::ConstantOp>(args: b.getFloatAttr(type: elemType, value: 1.0));
373
374 Value expm1Real = b.create<math::ExpM1Op>(args&: real, args&: fmf);
375 Value expReal = b.create<arith::AddFOp>(args&: expm1Real, args&: one, args&: fmf);
376
377 Value sinImag = b.create<math::SinOp>(args&: imag, args&: fmf);
378 Value cosm1Imag = emitCosm1(arg: imag, fmf, b);
379 Value cosImag = b.create<arith::AddFOp>(args&: cosm1Imag, args&: one, args&: fmf);
380
381 Value realResult = b.create<arith::AddFOp>(
382 args: b.create<arith::MulFOp>(args&: expm1Real, args&: cosImag, args&: fmf), args&: cosm1Imag, args&: fmf);
383
384 Value imagIsZero = b.create<arith::CmpFOp>(args: arith::CmpFPredicate::OEQ, args&: imag,
385 args&: zero, args: fmf.getValue());
386 Value imagResult = b.create<arith::SelectOp>(
387 args&: imagIsZero, args&: zero, args: b.create<arith::MulFOp>(args&: expReal, args&: sinImag, args&: fmf));
388
389 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, args&: type, args&: realResult,
390 args&: imagResult);
391 return success();
392 }
393
394private:
395 Value emitCosm1(Value arg, arith::FastMathFlagsAttr fmf,
396 ImplicitLocOpBuilder &b) const {
397 auto argType = mlir::cast<FloatType>(Val: arg.getType());
398 auto negHalf = b.create<arith::ConstantOp>(args: b.getFloatAttr(type: argType, value: -0.5));
399 auto negOne = b.create<arith::ConstantOp>(args: b.getFloatAttr(type: argType, value: -1.0));
400
401 // Algorithm copied from cephes cosm1.
402 SmallVector<double, 7> kCoeffs{
403 4.7377507964246204691685E-14, -1.1470284843425359765671E-11,
404 2.0876754287081521758361E-9, -2.7557319214999787979814E-7,
405 2.4801587301570552304991E-5, -1.3888888888888872993737E-3,
406 4.1666666666666666609054E-2,
407 };
408 Value cos = b.create<math::CosOp>(args&: arg, args&: fmf);
409 Value forLargeArg = b.create<arith::AddFOp>(args&: cos, args&: negOne, args&: fmf);
410
411 Value argPow2 = b.create<arith::MulFOp>(args&: arg, args&: arg, args&: fmf);
412 Value argPow4 = b.create<arith::MulFOp>(args&: argPow2, args&: argPow2, args&: fmf);
413 Value poly = evaluatePolynomial(b, arg: argPow2, coefficients: kCoeffs, fmf);
414
415 auto forSmallArg =
416 b.create<arith::AddFOp>(args: b.create<arith::MulFOp>(args&: argPow4, args&: poly, args&: fmf),
417 args: b.create<arith::MulFOp>(args&: negHalf, args&: argPow2, args&: fmf));
418
419 // (pi/4)^2 is approximately 0.61685
420 Value piOver4Pow2 =
421 b.create<arith::ConstantOp>(args: b.getFloatAttr(type: argType, value: 0.61685));
422 Value cond = b.create<arith::CmpFOp>(args: arith::CmpFPredicate::OGE, args&: argPow2,
423 args&: piOver4Pow2, args: fmf.getValue());
424 return b.create<arith::SelectOp>(args&: cond, args&: forLargeArg, args&: forSmallArg);
425 }
426};
427
428struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
429 using OpConversionPattern<complex::LogOp>::OpConversionPattern;
430
431 LogicalResult
432 matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
433 ConversionPatternRewriter &rewriter) const override {
434 auto type = cast<ComplexType>(Val: adaptor.getComplex().getType());
435 auto elementType = cast<FloatType>(Val: type.getElementType());
436 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
437 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
438
439 Value abs = b.create<complex::AbsOp>(args&: elementType, args: adaptor.getComplex(),
440 args: fmf.getValue());
441 Value resultReal = b.create<math::LogOp>(args&: elementType, args&: abs, args: fmf.getValue());
442 Value real = b.create<complex::ReOp>(args&: elementType, args: adaptor.getComplex());
443 Value imag = b.create<complex::ImOp>(args&: elementType, args: adaptor.getComplex());
444 Value resultImag =
445 b.create<math::Atan2Op>(args&: elementType, args&: imag, args&: real, args: fmf.getValue());
446 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, args&: type, args&: resultReal,
447 args&: resultImag);
448 return success();
449 }
450};
451
452struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
453 using OpConversionPattern<complex::Log1pOp>::OpConversionPattern;
454
455 LogicalResult
456 matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
457 ConversionPatternRewriter &rewriter) const override {
458 auto type = cast<ComplexType>(Val: adaptor.getComplex().getType());
459 auto elementType = cast<FloatType>(Val: type.getElementType());
460 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
461 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
462
463 Value real = b.create<complex::ReOp>(args: adaptor.getComplex());
464 Value imag = b.create<complex::ImOp>(args: adaptor.getComplex());
465
466 Value half = b.create<arith::ConstantOp>(args&: elementType,
467 args: b.getFloatAttr(type: elementType, value: 0.5));
468 Value one = b.create<arith::ConstantOp>(args&: elementType,
469 args: b.getFloatAttr(type: elementType, value: 1));
470 Value realPlusOne = b.create<arith::AddFOp>(args&: real, args&: one, args&: fmf);
471 Value absRealPlusOne = b.create<math::AbsFOp>(args&: realPlusOne, args&: fmf);
472 Value absImag = b.create<math::AbsFOp>(args&: imag, args&: fmf);
473
474 Value maxAbs = b.create<arith::MaximumFOp>(args&: absRealPlusOne, args&: absImag, args&: fmf);
475 Value minAbs = b.create<arith::MinimumFOp>(args&: absRealPlusOne, args&: absImag, args&: fmf);
476
477 Value useReal = b.create<arith::CmpFOp>(args: arith::CmpFPredicate::OGT,
478 args&: realPlusOne, args&: absImag, args&: fmf);
479 Value maxMinusOne = b.create<arith::SubFOp>(args&: maxAbs, args&: one, args&: fmf);
480 Value maxAbsOfRealPlusOneAndImagMinusOne =
481 b.create<arith::SelectOp>(args&: useReal, args&: real, args&: maxMinusOne);
482 arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
483 bits: fmf, bit: arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
484 Value minMaxRatio = b.create<arith::DivFOp>(args&: minAbs, args&: maxAbs, args&: fmfWithNaNInf);
485 Value logOfMaxAbsOfRealPlusOneAndImag =
486 b.create<math::Log1pOp>(args&: maxAbsOfRealPlusOneAndImagMinusOne, args&: fmf);
487 Value logOfSqrtPart = b.create<math::Log1pOp>(
488 args: b.create<arith::MulFOp>(args&: minMaxRatio, args&: minMaxRatio, args&: fmfWithNaNInf),
489 args&: fmfWithNaNInf);
490 Value r = b.create<arith::AddFOp>(
491 args: b.create<arith::MulFOp>(args&: half, args&: logOfSqrtPart, args&: fmfWithNaNInf),
492 args&: logOfMaxAbsOfRealPlusOneAndImag, args&: fmfWithNaNInf);
493 Value resultReal = b.create<arith::SelectOp>(
494 args: b.create<arith::CmpFOp>(args: arith::CmpFPredicate::UNO, args&: r, args&: r, args&: fmfWithNaNInf),
495 args&: minAbs, args&: r);
496 Value resultImag = b.create<math::Atan2Op>(args&: imag, args&: realPlusOne, args&: fmf);
497 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, args&: type, args&: resultReal,
498 args&: resultImag);
499 return success();
500 }
501};
502
503struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
504 using OpConversionPattern<complex::MulOp>::OpConversionPattern;
505
506 LogicalResult
507 matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
508 ConversionPatternRewriter &rewriter) const override {
509 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
510 auto type = cast<ComplexType>(Val: adaptor.getLhs().getType());
511 auto elementType = cast<FloatType>(Val: type.getElementType());
512 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
513 auto fmfValue = fmf.getValue();
514 Value lhsReal = b.create<complex::ReOp>(args&: elementType, args: adaptor.getLhs());
515 Value lhsImag = b.create<complex::ImOp>(args&: elementType, args: adaptor.getLhs());
516 Value rhsReal = b.create<complex::ReOp>(args&: elementType, args: adaptor.getRhs());
517 Value rhsImag = b.create<complex::ImOp>(args&: elementType, args: adaptor.getRhs());
518 Value lhsRealTimesRhsReal =
519 b.create<arith::MulFOp>(args&: lhsReal, args&: rhsReal, args&: fmfValue);
520 Value lhsImagTimesRhsImag =
521 b.create<arith::MulFOp>(args&: lhsImag, args&: rhsImag, args&: fmfValue);
522 Value real = b.create<arith::SubFOp>(args&: lhsRealTimesRhsReal,
523 args&: lhsImagTimesRhsImag, args&: fmfValue);
524 Value lhsImagTimesRhsReal =
525 b.create<arith::MulFOp>(args&: lhsImag, args&: rhsReal, args&: fmfValue);
526 Value lhsRealTimesRhsImag =
527 b.create<arith::MulFOp>(args&: lhsReal, args&: rhsImag, args&: fmfValue);
528 Value imag = b.create<arith::AddFOp>(args&: lhsImagTimesRhsReal,
529 args&: lhsRealTimesRhsImag, args&: fmfValue);
530 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, args&: type, args&: real, args&: imag);
531 return success();
532 }
533};
534
535struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
536 using OpConversionPattern<complex::NegOp>::OpConversionPattern;
537
538 LogicalResult
539 matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
540 ConversionPatternRewriter &rewriter) const override {
541 auto loc = op.getLoc();
542 auto type = cast<ComplexType>(Val: adaptor.getComplex().getType());
543 auto elementType = cast<FloatType>(Val: type.getElementType());
544
545 Value real =
546 rewriter.create<complex::ReOp>(location: loc, args&: elementType, args: adaptor.getComplex());
547 Value imag =
548 rewriter.create<complex::ImOp>(location: loc, args&: elementType, args: adaptor.getComplex());
549 Value negReal = rewriter.create<arith::NegFOp>(location: loc, args&: real);
550 Value negImag = rewriter.create<arith::NegFOp>(location: loc, args&: imag);
551 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, args&: type, args&: negReal, args&: negImag);
552 return success();
553 }
554};
555
556struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
557 using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion;
558
559 std::pair<Value, Value> combine(Location loc, Value scaledExp,
560 Value reciprocalExp, Value sin, Value cos,
561 ConversionPatternRewriter &rewriter,
562 arith::FastMathFlagsAttr fmf) const override {
563 // Complex sine is defined as;
564 // sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy)))
565 // Plugging in:
566 // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
567 // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
568 // and defining t := exp(y)
569 // We get:
570 // Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x
571 // Im(cos(x + iy)) = (0.5*t - 0.5/t) * cos x
572 Value sum =
573 rewriter.create<arith::AddFOp>(location: loc, args&: scaledExp, args&: reciprocalExp, args&: fmf);
574 Value resultReal = rewriter.create<arith::MulFOp>(location: loc, args&: sum, args&: sin, args&: fmf);
575 Value diff =
576 rewriter.create<arith::SubFOp>(location: loc, args&: scaledExp, args&: reciprocalExp, args&: fmf);
577 Value resultImag = rewriter.create<arith::MulFOp>(location: loc, args&: diff, args&: cos, args&: fmf);
578 return {resultReal, resultImag};
579 }
580};
581
582// The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
583struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
584 using OpConversionPattern<complex::SqrtOp>::OpConversionPattern;
585
586 LogicalResult
587 matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
588 ConversionPatternRewriter &rewriter) const override {
589 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
590
591 auto type = cast<ComplexType>(Val: op.getType());
592 auto elementType = cast<FloatType>(Val: type.getElementType());
593 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
594
595 auto cst = [&](APFloat v) {
596 return b.create<arith::ConstantOp>(args&: elementType,
597 args: b.getFloatAttr(type: elementType, value: v));
598 };
599 const auto &floatSemantics = elementType.getFloatSemantics();
600 Value zero = cst(APFloat::getZero(Sem: floatSemantics));
601 Value half = b.create<arith::ConstantOp>(args&: elementType,
602 args: b.getFloatAttr(type: elementType, value: 0.5));
603
604 Value real = b.create<complex::ReOp>(args&: elementType, args: adaptor.getComplex());
605 Value imag = b.create<complex::ImOp>(args&: elementType, args: adaptor.getComplex());
606 Value absSqrt = computeAbs(real, imag, fmf, b, fn: AbsFn::sqrt);
607 Value argArg = b.create<math::Atan2Op>(args&: imag, args&: real, args&: fmf);
608 Value sqrtArg = b.create<arith::MulFOp>(args&: argArg, args&: half, args&: fmf);
609 Value cos = b.create<math::CosOp>(args&: sqrtArg, args&: fmf);
610 Value sin = b.create<math::SinOp>(args&: sqrtArg, args&: fmf);
611 // sin(atan2(0, inf)) = 0, sqrt(abs(inf)) = inf, but we can't multiply
612 // 0 * inf.
613 Value sinIsZero =
614 b.create<arith::CmpFOp>(args: arith::CmpFPredicate::OEQ, args&: sin, args&: zero, args&: fmf);
615
616 Value resultReal = b.create<arith::MulFOp>(args&: absSqrt, args&: cos, args&: fmf);
617 Value resultImag = b.create<arith::SelectOp>(
618 args&: sinIsZero, args&: zero, args: b.create<arith::MulFOp>(args&: absSqrt, args&: sin, args&: fmf));
619 if (!arith::bitEnumContainsAll(bits: fmf, bit: arith::FastMathFlags::nnan |
620 arith::FastMathFlags::ninf)) {
621 Value inf = cst(APFloat::getInf(Sem: floatSemantics));
622 Value negInf = cst(APFloat::getInf(Sem: floatSemantics, Negative: true));
623 Value nan = cst(APFloat::getNaN(Sem: floatSemantics));
624 Value absImag = b.create<math::AbsFOp>(args&: elementType, args&: imag, args&: fmf);
625
626 Value absImagIsInf =
627 b.create<arith::CmpFOp>(args: arith::CmpFPredicate::OEQ, args&: absImag, args&: inf, args&: fmf);
628 Value absImagIsNotInf =
629 b.create<arith::CmpFOp>(args: arith::CmpFPredicate::ONE, args&: absImag, args&: inf, args&: fmf);
630 Value realIsInf =
631 b.create<arith::CmpFOp>(args: arith::CmpFPredicate::OEQ, args&: real, args&: inf, args&: fmf);
632 Value realIsNegInf =
633 b.create<arith::CmpFOp>(args: arith::CmpFPredicate::OEQ, args&: real, args&: negInf, args&: fmf);
634
635 resultReal = b.create<arith::SelectOp>(
636 args: b.create<arith::AndIOp>(args&: realIsNegInf, args&: absImagIsNotInf), args&: zero,
637 args&: resultReal);
638 resultReal = b.create<arith::SelectOp>(
639 args: b.create<arith::OrIOp>(args&: absImagIsInf, args&: realIsInf), args&: inf, args&: resultReal);
640
641 Value imagSignInf = b.create<math::CopySignOp>(args&: inf, args&: imag, args&: fmf);
642 resultImag = b.create<arith::SelectOp>(
643 args: b.create<arith::CmpFOp>(args: arith::CmpFPredicate::UNO, args&: absSqrt, args&: absSqrt),
644 args&: nan, args&: resultImag);
645 resultImag = b.create<arith::SelectOp>(
646 args: b.create<arith::OrIOp>(args&: absImagIsInf, args&: realIsNegInf), args&: imagSignInf,
647 args&: resultImag);
648 }
649
650 Value resultIsZero =
651 b.create<arith::CmpFOp>(args: arith::CmpFPredicate::OEQ, args&: absSqrt, args&: zero, args&: fmf);
652 resultReal = b.create<arith::SelectOp>(args&: resultIsZero, args&: zero, args&: resultReal);
653 resultImag = b.create<arith::SelectOp>(args&: resultIsZero, args&: zero, args&: resultImag);
654
655 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, args&: type, args&: resultReal,
656 args&: resultImag);
657 return success();
658 }
659};
660
661struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
662 using OpConversionPattern<complex::SignOp>::OpConversionPattern;
663
664 LogicalResult
665 matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
666 ConversionPatternRewriter &rewriter) const override {
667 auto type = cast<ComplexType>(Val: adaptor.getComplex().getType());
668 auto elementType = cast<FloatType>(Val: type.getElementType());
669 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
670 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
671
672 Value real = b.create<complex::ReOp>(args&: elementType, args: adaptor.getComplex());
673 Value imag = b.create<complex::ImOp>(args&: elementType, args: adaptor.getComplex());
674 Value zero =
675 b.create<arith::ConstantOp>(args&: elementType, args: b.getZeroAttr(type: elementType));
676 Value realIsZero =
677 b.create<arith::CmpFOp>(args: arith::CmpFPredicate::OEQ, args&: real, args&: zero);
678 Value imagIsZero =
679 b.create<arith::CmpFOp>(args: arith::CmpFPredicate::OEQ, args&: imag, args&: zero);
680 Value isZero = b.create<arith::AndIOp>(args&: realIsZero, args&: imagIsZero);
681 auto abs = b.create<complex::AbsOp>(args&: elementType, args: adaptor.getComplex(), args&: fmf);
682 Value realSign = b.create<arith::DivFOp>(args&: real, args&: abs, args&: fmf);
683 Value imagSign = b.create<arith::DivFOp>(args&: imag, args&: abs, args&: fmf);
684 Value sign = b.create<complex::CreateOp>(args&: type, args&: realSign, args&: imagSign);
685 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, args&: isZero,
686 args: adaptor.getComplex(), args&: sign);
687 return success();
688 }
689};
690
691template <typename Op>
692struct TanTanhOpConversion : public OpConversionPattern<Op> {
693 using OpConversionPattern<Op>::OpConversionPattern;
694
695 LogicalResult
696 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
697 ConversionPatternRewriter &rewriter) const override {
698 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
699 auto loc = op.getLoc();
700 auto type = cast<ComplexType>(adaptor.getComplex().getType());
701 auto elementType = cast<FloatType>(type.getElementType());
702 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
703 const auto &floatSemantics = elementType.getFloatSemantics();
704
705 Value real =
706 b.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
707 Value imag =
708 b.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
709 Value negOne = b.create<arith::ConstantOp>(
710 elementType, b.getFloatAttr(elementType, -1.0));
711
712 if constexpr (std::is_same_v<Op, complex::TanOp>) {
713 // tan(x+yi) = -i*tanh(-y + xi)
714 std::swap(a&: real, b&: imag);
715 real = b.create<arith::MulFOp>(args&: real, args&: negOne, args&: fmf);
716 }
717
718 auto cst = [&](APFloat v) {
719 return b.create<arith::ConstantOp>(elementType,
720 b.getFloatAttr(elementType, v));
721 };
722 Value inf = cst(APFloat::getInf(Sem: floatSemantics));
723 Value four = b.create<arith::ConstantOp>(elementType,
724 b.getFloatAttr(elementType, 4.0));
725 Value twoReal = b.create<arith::AddFOp>(args&: real, args&: real, args&: fmf);
726 Value negTwoReal = b.create<arith::MulFOp>(args&: negOne, args&: twoReal, args&: fmf);
727
728 Value expTwoRealMinusOne = b.create<math::ExpM1Op>(args&: twoReal, args&: fmf);
729 Value expNegTwoRealMinusOne = b.create<math::ExpM1Op>(args&: negTwoReal, args&: fmf);
730 Value realNum =
731 b.create<arith::SubFOp>(args&: expTwoRealMinusOne, args&: expNegTwoRealMinusOne, args&: fmf);
732
733 Value cosImag = b.create<math::CosOp>(args&: imag, args&: fmf);
734 Value cosImagSq = b.create<arith::MulFOp>(args&: cosImag, args&: cosImag, args&: fmf);
735 Value twoCosTwoImagPlusOne = b.create<arith::MulFOp>(args&: cosImagSq, args&: four, args&: fmf);
736 Value sinImag = b.create<math::SinOp>(args&: imag, args&: fmf);
737
738 Value imagNum = b.create<arith::MulFOp>(
739 args&: four, args: b.create<arith::MulFOp>(args&: cosImag, args&: sinImag, args&: fmf), args&: fmf);
740
741 Value expSumMinusTwo =
742 b.create<arith::AddFOp>(args&: expTwoRealMinusOne, args&: expNegTwoRealMinusOne, args&: fmf);
743 Value denom =
744 b.create<arith::AddFOp>(args&: expSumMinusTwo, args&: twoCosTwoImagPlusOne, args&: fmf);
745
746 Value isInf = b.create<arith::CmpFOp>(args: arith::CmpFPredicate::OEQ,
747 args&: expSumMinusTwo, args&: inf, args&: fmf);
748 Value realLimit = b.create<math::CopySignOp>(args&: negOne, args&: real, args&: fmf);
749
750 Value resultReal = b.create<arith::SelectOp>(
751 args&: isInf, args&: realLimit, args: b.create<arith::DivFOp>(args&: realNum, args&: denom, args&: fmf));
752 Value resultImag = b.create<arith::DivFOp>(args&: imagNum, args&: denom, args&: fmf);
753
754 if (!arith::bitEnumContainsAll(bits: fmf, bit: arith::FastMathFlags::nnan |
755 arith::FastMathFlags::ninf)) {
756 Value absReal = b.create<math::AbsFOp>(args&: real, args&: fmf);
757 Value zero = b.create<arith::ConstantOp>(
758 elementType, b.getFloatAttr(elementType, 0.0));
759 Value nan = cst(APFloat::getNaN(Sem: floatSemantics));
760
761 Value absRealIsInf =
762 b.create<arith::CmpFOp>(args: arith::CmpFPredicate::OEQ, args&: absReal, args&: inf, args&: fmf);
763 Value imagIsZero =
764 b.create<arith::CmpFOp>(args: arith::CmpFPredicate::OEQ, args&: imag, args&: zero, args&: fmf);
765 Value absRealIsNotInf = b.create<arith::XOrIOp>(
766 args&: absRealIsInf, args: b.create<arith::ConstantIntOp>(args: true, /*width=*/args: 1));
767
768 Value imagNumIsNaN = b.create<arith::CmpFOp>(args: arith::CmpFPredicate::UNO,
769 args&: imagNum, args&: imagNum, args&: fmf);
770 Value resultRealIsNaN =
771 b.create<arith::AndIOp>(args&: imagNumIsNaN, args&: absRealIsNotInf);
772 Value resultImagIsZero = b.create<arith::OrIOp>(
773 args&: imagIsZero, args: b.create<arith::AndIOp>(args&: absRealIsInf, args&: imagNumIsNaN));
774
775 resultReal = b.create<arith::SelectOp>(args&: resultRealIsNaN, args&: nan, args&: resultReal);
776 resultImag =
777 b.create<arith::SelectOp>(args&: resultImagIsZero, args&: zero, args&: resultImag);
778 }
779
780 if constexpr (std::is_same_v<Op, complex::TanOp>) {
781 // tan(x+yi) = -i*tanh(-y + xi)
782 std::swap(a&: resultReal, b&: resultImag);
783 resultImag = b.create<arith::MulFOp>(args&: resultImag, args&: negOne, args&: fmf);
784 }
785
786 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
787 resultImag);
788 return success();
789 }
790};
791
792struct ConjOpConversion : public OpConversionPattern<complex::ConjOp> {
793 using OpConversionPattern<complex::ConjOp>::OpConversionPattern;
794
795 LogicalResult
796 matchAndRewrite(complex::ConjOp op, OpAdaptor adaptor,
797 ConversionPatternRewriter &rewriter) const override {
798 auto loc = op.getLoc();
799 auto type = cast<ComplexType>(Val: adaptor.getComplex().getType());
800 auto elementType = cast<FloatType>(Val: type.getElementType());
801 Value real =
802 rewriter.create<complex::ReOp>(location: loc, args&: elementType, args: adaptor.getComplex());
803 Value imag =
804 rewriter.create<complex::ImOp>(location: loc, args&: elementType, args: adaptor.getComplex());
805 Value negImag = rewriter.create<arith::NegFOp>(location: loc, args&: elementType, args&: imag);
806
807 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, args&: type, args&: real, args&: negImag);
808
809 return success();
810 }
811};
812
813/// Converts lhs^y = (a+bi)^(c+di) to
814/// (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
815/// where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
816static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
817 ComplexType type, Value lhs, Value c, Value d,
818 arith::FastMathFlags fmf) {
819 auto elementType = cast<FloatType>(Val: type.getElementType());
820
821 Value a = builder.create<complex::ReOp>(args&: lhs);
822 Value b = builder.create<complex::ImOp>(args&: lhs);
823
824 Value abs = builder.create<complex::AbsOp>(args&: lhs, args&: fmf);
825 Value absToC = builder.create<math::PowFOp>(args&: abs, args&: c, args&: fmf);
826
827 Value negD = builder.create<arith::NegFOp>(args&: d, args&: fmf);
828 Value argLhs = builder.create<math::Atan2Op>(args&: b, args&: a, args&: fmf);
829 Value negDArgLhs = builder.create<arith::MulFOp>(args&: negD, args&: argLhs, args&: fmf);
830 Value expNegDArgLhs = builder.create<math::ExpOp>(args&: negDArgLhs, args&: fmf);
831
832 Value coeff = builder.create<arith::MulFOp>(args&: absToC, args&: expNegDArgLhs, args&: fmf);
833 Value lnAbs = builder.create<math::LogOp>(args&: abs, args&: fmf);
834 Value cArgLhs = builder.create<arith::MulFOp>(args&: c, args&: argLhs, args&: fmf);
835 Value dLnAbs = builder.create<arith::MulFOp>(args&: d, args&: lnAbs, args&: fmf);
836 Value q = builder.create<arith::AddFOp>(args&: cArgLhs, args&: dLnAbs, args&: fmf);
837 Value cosQ = builder.create<math::CosOp>(args&: q, args&: fmf);
838 Value sinQ = builder.create<math::SinOp>(args&: q, args&: fmf);
839
840 Value inf = builder.create<arith::ConstantOp>(
841 args&: elementType,
842 args: builder.getFloatAttr(type: elementType,
843 value: APFloat::getInf(Sem: elementType.getFloatSemantics())));
844 Value zero = builder.create<arith::ConstantOp>(
845 args&: elementType, args: builder.getFloatAttr(type: elementType, value: 0.0));
846 Value one = builder.create<arith::ConstantOp>(
847 args&: elementType, args: builder.getFloatAttr(type: elementType, value: 1.0));
848 Value complexOne = builder.create<complex::CreateOp>(args&: type, args&: one, args&: zero);
849 Value complexZero = builder.create<complex::CreateOp>(args&: type, args&: zero, args&: zero);
850 Value complexInf = builder.create<complex::CreateOp>(args&: type, args&: inf, args&: zero);
851
852 // Case 0:
853 // d^c is 0 if d is 0 and c > 0. 0^0 is defined to be 1.0, see
854 // Branch Cuts for Complex Elementary Functions or Much Ado About
855 // Nothing's Sign Bit, W. Kahan, Section 10.
856 Value absEqZero =
857 builder.create<arith::CmpFOp>(args: arith::CmpFPredicate::OEQ, args&: abs, args&: zero, args&: fmf);
858 Value dEqZero =
859 builder.create<arith::CmpFOp>(args: arith::CmpFPredicate::OEQ, args&: d, args&: zero, args&: fmf);
860 Value cEqZero =
861 builder.create<arith::CmpFOp>(args: arith::CmpFPredicate::OEQ, args&: c, args&: zero, args&: fmf);
862 Value bEqZero =
863 builder.create<arith::CmpFOp>(args: arith::CmpFPredicate::OEQ, args&: b, args&: zero, args&: fmf);
864
865 Value zeroLeC =
866 builder.create<arith::CmpFOp>(args: arith::CmpFPredicate::OLE, args&: zero, args&: c, args&: fmf);
867 Value coeffCosQ = builder.create<arith::MulFOp>(args&: coeff, args&: cosQ, args&: fmf);
868 Value coeffSinQ = builder.create<arith::MulFOp>(args&: coeff, args&: sinQ, args&: fmf);
869 Value complexOneOrZero =
870 builder.create<arith::SelectOp>(args&: cEqZero, args&: complexOne, args&: complexZero);
871 Value coeffCosSin =
872 builder.create<complex::CreateOp>(args&: type, args&: coeffCosQ, args&: coeffSinQ);
873 Value cutoff0 = builder.create<arith::SelectOp>(
874 args: builder.create<arith::AndIOp>(
875 args: builder.create<arith::AndIOp>(args&: absEqZero, args&: dEqZero), args&: zeroLeC),
876 args&: complexOneOrZero, args&: coeffCosSin);
877
878 // Case 1:
879 // x^0 is defined to be 1 for any x, see
880 // Branch Cuts for Complex Elementary Functions or Much Ado About
881 // Nothing's Sign Bit, W. Kahan, Section 10.
882 Value rhsEqZero = builder.create<arith::AndIOp>(args&: cEqZero, args&: dEqZero);
883 Value cutoff1 =
884 builder.create<arith::SelectOp>(args&: rhsEqZero, args&: complexOne, args&: cutoff0);
885
886 // Case 2:
887 // 1^(c + d*i) = 1 + 0*i
888 Value lhsEqOne = builder.create<arith::AndIOp>(
889 args: builder.create<arith::CmpFOp>(args: arith::CmpFPredicate::OEQ, args&: a, args&: one, args&: fmf),
890 args&: bEqZero);
891 Value cutoff2 =
892 builder.create<arith::SelectOp>(args&: lhsEqOne, args&: complexOne, args&: cutoff1);
893
894 // Case 3:
895 // inf^(c + 0*i) = inf + 0*i, c > 0
896 Value lhsEqInf = builder.create<arith::AndIOp>(
897 args: builder.create<arith::CmpFOp>(args: arith::CmpFPredicate::OEQ, args&: a, args&: inf, args&: fmf),
898 args&: bEqZero);
899 Value rhsGt0 = builder.create<arith::AndIOp>(
900 args&: dEqZero,
901 args: builder.create<arith::CmpFOp>(args: arith::CmpFPredicate::OGT, args&: c, args&: zero, args&: fmf));
902 Value cutoff3 = builder.create<arith::SelectOp>(
903 args: builder.create<arith::AndIOp>(args&: lhsEqInf, args&: rhsGt0), args&: complexInf, args&: cutoff2);
904
905 // Case 4:
906 // inf^(c + 0*i) = 0 + 0*i, c < 0
907 Value rhsLt0 = builder.create<arith::AndIOp>(
908 args&: dEqZero,
909 args: builder.create<arith::CmpFOp>(args: arith::CmpFPredicate::OLT, args&: c, args&: zero, args&: fmf));
910 Value cutoff4 = builder.create<arith::SelectOp>(
911 args: builder.create<arith::AndIOp>(args&: lhsEqInf, args&: rhsLt0), args&: complexZero, args&: cutoff3);
912
913 return cutoff4;
914}
915
916struct PowOpConversion : public OpConversionPattern<complex::PowOp> {
917 using OpConversionPattern<complex::PowOp>::OpConversionPattern;
918
919 LogicalResult
920 matchAndRewrite(complex::PowOp op, OpAdaptor adaptor,
921 ConversionPatternRewriter &rewriter) const override {
922 mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
923 auto type = cast<ComplexType>(Val: adaptor.getLhs().getType());
924 auto elementType = cast<FloatType>(Val: type.getElementType());
925
926 Value c = builder.create<complex::ReOp>(args&: elementType, args: adaptor.getRhs());
927 Value d = builder.create<complex::ImOp>(args&: elementType, args: adaptor.getRhs());
928
929 rewriter.replaceOp(op, newValues: {powOpConversionImpl(builder, type, lhs: adaptor.getLhs(),
930 c, d, fmf: op.getFastmath())});
931 return success();
932 }
933};
934
935struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> {
936 using OpConversionPattern<complex::RsqrtOp>::OpConversionPattern;
937
938 LogicalResult
939 matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor,
940 ConversionPatternRewriter &rewriter) const override {
941 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
942 auto type = cast<ComplexType>(Val: adaptor.getComplex().getType());
943 auto elementType = cast<FloatType>(Val: type.getElementType());
944
945 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
946
947 auto cst = [&](APFloat v) {
948 return b.create<arith::ConstantOp>(args&: elementType,
949 args: b.getFloatAttr(type: elementType, value: v));
950 };
951 const auto &floatSemantics = elementType.getFloatSemantics();
952 Value zero = cst(APFloat::getZero(Sem: floatSemantics));
953 Value inf = cst(APFloat::getInf(Sem: floatSemantics));
954 Value negHalf = b.create<arith::ConstantOp>(
955 args&: elementType, args: b.getFloatAttr(type: elementType, value: -0.5));
956 Value nan = cst(APFloat::getNaN(Sem: floatSemantics));
957
958 Value real = b.create<complex::ReOp>(args&: elementType, args: adaptor.getComplex());
959 Value imag = b.create<complex::ImOp>(args&: elementType, args: adaptor.getComplex());
960 Value absRsqrt = computeAbs(real, imag, fmf, b, fn: AbsFn::rsqrt);
961 Value argArg = b.create<math::Atan2Op>(args&: imag, args&: real, args&: fmf);
962 Value rsqrtArg = b.create<arith::MulFOp>(args&: argArg, args&: negHalf, args&: fmf);
963 Value cos = b.create<math::CosOp>(args&: rsqrtArg, args&: fmf);
964 Value sin = b.create<math::SinOp>(args&: rsqrtArg, args&: fmf);
965
966 Value resultReal = b.create<arith::MulFOp>(args&: absRsqrt, args&: cos, args&: fmf);
967 Value resultImag = b.create<arith::MulFOp>(args&: absRsqrt, args&: sin, args&: fmf);
968
969 if (!arith::bitEnumContainsAll(bits: fmf, bit: arith::FastMathFlags::nnan |
970 arith::FastMathFlags::ninf)) {
971 Value negOne = b.create<arith::ConstantOp>(
972 args&: elementType, args: b.getFloatAttr(type: elementType, value: -1));
973
974 Value realSignedZero = b.create<math::CopySignOp>(args&: zero, args&: real, args&: fmf);
975 Value imagSignedZero = b.create<math::CopySignOp>(args&: zero, args&: imag, args&: fmf);
976 Value negImagSignedZero =
977 b.create<arith::MulFOp>(args&: negOne, args&: imagSignedZero, args&: fmf);
978
979 Value absReal = b.create<math::AbsFOp>(args&: real, args&: fmf);
980 Value absImag = b.create<math::AbsFOp>(args&: imag, args&: fmf);
981
982 Value absImagIsInf =
983 b.create<arith::CmpFOp>(args: arith::CmpFPredicate::OEQ, args&: absImag, args&: inf, args&: fmf);
984 Value realIsNan =
985 b.create<arith::CmpFOp>(args: arith::CmpFPredicate::UNO, args&: real, args&: real, args&: fmf);
986 Value realIsInf =
987 b.create<arith::CmpFOp>(args: arith::CmpFPredicate::OEQ, args&: absReal, args&: inf, args&: fmf);
988 Value inIsNanInf = b.create<arith::AndIOp>(args&: absImagIsInf, args&: realIsNan);
989
990 Value resultIsZero = b.create<arith::OrIOp>(args&: inIsNanInf, args&: realIsInf);
991
992 resultReal =
993 b.create<arith::SelectOp>(args&: resultIsZero, args&: realSignedZero, args&: resultReal);
994 resultImag = b.create<arith::SelectOp>(args&: resultIsZero, args&: negImagSignedZero,
995 args&: resultImag);
996 }
997
998 Value isRealZero =
999 b.create<arith::CmpFOp>(args: arith::CmpFPredicate::OEQ, args&: real, args&: zero, args&: fmf);
1000 Value isImagZero =
1001 b.create<arith::CmpFOp>(args: arith::CmpFPredicate::OEQ, args&: imag, args&: zero, args&: fmf);
1002 Value isZero = b.create<arith::AndIOp>(args&: isRealZero, args&: isImagZero);
1003
1004 resultReal = b.create<arith::SelectOp>(args&: isZero, args&: inf, args&: resultReal);
1005 resultImag = b.create<arith::SelectOp>(args&: isZero, args&: nan, args&: resultImag);
1006
1007 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, args&: type, args&: resultReal,
1008 args&: resultImag);
1009 return success();
1010 }
1011};
1012
1013struct AngleOpConversion : public OpConversionPattern<complex::AngleOp> {
1014 using OpConversionPattern<complex::AngleOp>::OpConversionPattern;
1015
1016 LogicalResult
1017 matchAndRewrite(complex::AngleOp op, OpAdaptor adaptor,
1018 ConversionPatternRewriter &rewriter) const override {
1019 auto loc = op.getLoc();
1020 auto type = op.getType();
1021 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
1022
1023 Value real =
1024 rewriter.create<complex::ReOp>(location: loc, args&: type, args: adaptor.getComplex());
1025 Value imag =
1026 rewriter.create<complex::ImOp>(location: loc, args&: type, args: adaptor.getComplex());
1027
1028 rewriter.replaceOpWithNewOp<math::Atan2Op>(op, args&: imag, args&: real, args&: fmf);
1029
1030 return success();
1031 }
1032};
1033
1034} // namespace
1035
1036void mlir::populateComplexToStandardConversionPatterns(
1037 RewritePatternSet &patterns, complex::ComplexRangeFlags complexRange) {
1038 // clang-format off
1039 patterns.add<
1040 AbsOpConversion,
1041 AngleOpConversion,
1042 Atan2OpConversion,
1043 BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
1044 BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
1045 ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
1046 ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
1047 ConjOpConversion,
1048 CosOpConversion,
1049 ExpOpConversion,
1050 Expm1OpConversion,
1051 Log1pOpConversion,
1052 LogOpConversion,
1053 MulOpConversion,
1054 NegOpConversion,
1055 SignOpConversion,
1056 SinOpConversion,
1057 SqrtOpConversion,
1058 TanTanhOpConversion<complex::TanOp>,
1059 TanTanhOpConversion<complex::TanhOp>,
1060 PowOpConversion,
1061 RsqrtOpConversion
1062 >(arg: patterns.getContext());
1063
1064 patterns.add<DivOpConversion>(arg: patterns.getContext(), args&: complexRange);
1065
1066 // clang-format on
1067}
1068
1069namespace {
1070struct ConvertComplexToStandardPass
1071 : public impl::ConvertComplexToStandardPassBase<
1072 ConvertComplexToStandardPass> {
1073 using Base::Base;
1074
1075 void runOnOperation() override;
1076};
1077
1078void ConvertComplexToStandardPass::runOnOperation() {
1079 // Convert to the Standard dialect using the converter defined above.
1080 RewritePatternSet patterns(&getContext());
1081 populateComplexToStandardConversionPatterns(patterns, complexRange);
1082
1083 ConversionTarget target(getContext());
1084 target.addLegalDialect<arith::ArithDialect, math::MathDialect>();
1085 target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
1086 if (failed(
1087 Result: applyPartialConversion(op: getOperation(), target, patterns: std::move(patterns))))
1088 signalPassFailure();
1089}
1090} // namespace
1091

source code of mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp