1//===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
10
11#include "mlir/Conversion/ComplexCommon/DivisionConverter.h"
12#include "mlir/Dialect/Arith/IR/Arith.h"
13#include "mlir/Dialect/Complex/IR/Complex.h"
14#include "mlir/Dialect/Math/IR/Math.h"
15#include "mlir/IR/ImplicitLocOpBuilder.h"
16#include "mlir/IR/PatternMatch.h"
17#include "mlir/Pass/Pass.h"
18#include "mlir/Transforms/DialectConversion.h"
19#include <memory>
20#include <type_traits>
21
22namespace mlir {
23#define GEN_PASS_DEF_CONVERTCOMPLEXTOSTANDARDPASS
24#include "mlir/Conversion/Passes.h.inc"
25} // namespace mlir
26
27using namespace mlir;
28
29namespace {
30
31enum class AbsFn { abs, sqrt, rsqrt };
32
33// Returns the absolute value, its square root or its reciprocal square root.
34Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
35 ImplicitLocOpBuilder &b, AbsFn fn = AbsFn::abs) {
36 Value one = b.create<arith::ConstantOp>(real.getType(),
37 b.getFloatAttr(real.getType(), 1.0));
38
39 Value absReal = b.create<math::AbsFOp>(real, fmf);
40 Value absImag = b.create<math::AbsFOp>(imag, fmf);
41
42 Value max = b.create<arith::MaximumFOp>(absReal, absImag, fmf);
43 Value min = b.create<arith::MinimumFOp>(absReal, absImag, fmf);
44
45 // The lowering below requires NaNs and infinities to work correctly.
46 arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
47 fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
48 Value ratio = b.create<arith::DivFOp>(min, max, fmfWithNaNInf);
49 Value ratioSq = b.create<arith::MulFOp>(ratio, ratio, fmfWithNaNInf);
50 Value ratioSqPlusOne = b.create<arith::AddFOp>(ratioSq, one, fmfWithNaNInf);
51 Value result;
52
53 if (fn == AbsFn::rsqrt) {
54 ratioSqPlusOne = b.create<math::RsqrtOp>(ratioSqPlusOne, fmfWithNaNInf);
55 min = b.create<math::RsqrtOp>(min, fmfWithNaNInf);
56 max = b.create<math::RsqrtOp>(max, fmfWithNaNInf);
57 }
58
59 if (fn == AbsFn::sqrt) {
60 Value quarter = b.create<arith::ConstantOp>(
61 real.getType(), b.getFloatAttr(real.getType(), 0.25));
62 // sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily.
63 Value sqrt = b.create<math::SqrtOp>(max, fmfWithNaNInf);
64 Value p025 = b.create<math::PowFOp>(ratioSqPlusOne, quarter, fmfWithNaNInf);
65 result = b.create<arith::MulFOp>(sqrt, p025, fmfWithNaNInf);
66 } else {
67 Value sqrt = b.create<math::SqrtOp>(ratioSqPlusOne, fmfWithNaNInf);
68 result = b.create<arith::MulFOp>(max, sqrt, fmfWithNaNInf);
69 }
70
71 Value isNaN = b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, result,
72 result, fmfWithNaNInf);
73 return b.create<arith::SelectOp>(isNaN, min, result);
74}
75
76struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
77 using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
78
79 LogicalResult
80 matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
81 ConversionPatternRewriter &rewriter) const override {
82 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
83
84 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
85
86 Value real = b.create<complex::ReOp>(adaptor.getComplex());
87 Value imag = b.create<complex::ImOp>(adaptor.getComplex());
88 rewriter.replaceOp(op, computeAbs(real, imag, fmf, b));
89
90 return success();
91 }
92};
93
94// atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2))
95struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> {
96 using OpConversionPattern<complex::Atan2Op>::OpConversionPattern;
97
98 LogicalResult
99 matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor,
100 ConversionPatternRewriter &rewriter) const override {
101 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
102
103 auto type = cast<ComplexType>(op.getType());
104 Type elementType = type.getElementType();
105 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
106
107 Value lhs = adaptor.getLhs();
108 Value rhs = adaptor.getRhs();
109
110 Value rhsSquared = b.create<complex::MulOp>(type, rhs, rhs, fmf);
111 Value lhsSquared = b.create<complex::MulOp>(type, lhs, lhs, fmf);
112 Value rhsSquaredPlusLhsSquared =
113 b.create<complex::AddOp>(type, rhsSquared, lhsSquared, fmf);
114 Value sqrtOfRhsSquaredPlusLhsSquared =
115 b.create<complex::SqrtOp>(type, rhsSquaredPlusLhsSquared, fmf);
116
117 Value zero =
118 b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
119 Value one = b.create<arith::ConstantOp>(elementType,
120 b.getFloatAttr(elementType, 1));
121 Value i = b.create<complex::CreateOp>(type, zero, one);
122 Value iTimesLhs = b.create<complex::MulOp>(i, lhs, fmf);
123 Value rhsPlusILhs = b.create<complex::AddOp>(rhs, iTimesLhs, fmf);
124
125 Value divResult = b.create<complex::DivOp>(
126 rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared, fmf);
127 Value logResult = b.create<complex::LogOp>(divResult, fmf);
128
129 Value negativeOne = b.create<arith::ConstantOp>(
130 elementType, b.getFloatAttr(elementType, -1));
131 Value negativeI = b.create<complex::CreateOp>(type, zero, negativeOne);
132
133 rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult, fmf);
134 return success();
135 }
136};
137
138template <typename ComparisonOp, arith::CmpFPredicate p>
139struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
140 using OpConversionPattern<ComparisonOp>::OpConversionPattern;
141 using ResultCombiner =
142 std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
143 arith::AndIOp, arith::OrIOp>;
144
145 LogicalResult
146 matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor,
147 ConversionPatternRewriter &rewriter) const override {
148 auto loc = op.getLoc();
149 auto type = cast<ComplexType>(adaptor.getLhs().getType()).getElementType();
150
151 Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getLhs());
152 Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getLhs());
153 Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getRhs());
154 Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getRhs());
155 Value realComparison =
156 rewriter.create<arith::CmpFOp>(loc, p, realLhs, realRhs);
157 Value imagComparison =
158 rewriter.create<arith::CmpFOp>(loc, p, imagLhs, imagRhs);
159
160 rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison,
161 imagComparison);
162 return success();
163 }
164};
165
166// Default conversion which applies the BinaryStandardOp separately on the real
167// and imaginary parts. Can for example be used for complex::AddOp and
168// complex::SubOp.
169template <typename BinaryComplexOp, typename BinaryStandardOp>
170struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
171 using OpConversionPattern<BinaryComplexOp>::OpConversionPattern;
172
173 LogicalResult
174 matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor,
175 ConversionPatternRewriter &rewriter) const override {
176 auto type = cast<ComplexType>(adaptor.getLhs().getType());
177 auto elementType = cast<FloatType>(type.getElementType());
178 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
179 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
180
181 Value realLhs = b.create<complex::ReOp>(elementType, adaptor.getLhs());
182 Value realRhs = b.create<complex::ReOp>(elementType, adaptor.getRhs());
183 Value resultReal = b.create<BinaryStandardOp>(elementType, realLhs, realRhs,
184 fmf.getValue());
185 Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.getLhs());
186 Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.getRhs());
187 Value resultImag = b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs,
188 fmf.getValue());
189 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
190 resultImag);
191 return success();
192 }
193};
194
195template <typename TrigonometricOp>
196struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
197 using OpAdaptor = typename OpConversionPattern<TrigonometricOp>::OpAdaptor;
198
199 using OpConversionPattern<TrigonometricOp>::OpConversionPattern;
200
201 LogicalResult
202 matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor,
203 ConversionPatternRewriter &rewriter) const override {
204 auto loc = op.getLoc();
205 auto type = cast<ComplexType>(adaptor.getComplex().getType());
206 auto elementType = cast<FloatType>(type.getElementType());
207 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
208
209 Value real =
210 rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
211 Value imag =
212 rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
213
214 // Trigonometric ops use a set of common building blocks to convert to real
215 // ops. Here we create these building blocks and call into an op-specific
216 // implementation in the subclass to combine them.
217 Value half = rewriter.create<arith::ConstantOp>(
218 loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
219 Value exp = rewriter.create<math::ExpOp>(loc, imag, fmf);
220 Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp, fmf);
221 Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp, fmf);
222 Value sin = rewriter.create<math::SinOp>(loc, real, fmf);
223 Value cos = rewriter.create<math::CosOp>(loc, real, fmf);
224
225 auto resultPair =
226 combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter, fmf);
227
228 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultPair.first,
229 resultPair.second);
230 return success();
231 }
232
233 virtual std::pair<Value, Value>
234 combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
235 Value cos, ConversionPatternRewriter &rewriter,
236 arith::FastMathFlagsAttr fmf) const = 0;
237};
238
239struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> {
240 using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion;
241
242 std::pair<Value, Value> combine(Location loc, Value scaledExp,
243 Value reciprocalExp, Value sin, Value cos,
244 ConversionPatternRewriter &rewriter,
245 arith::FastMathFlagsAttr fmf) const override {
246 // Complex cosine is defined as;
247 // cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy)))
248 // Plugging in:
249 // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
250 // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
251 // and defining t := exp(y)
252 // We get:
253 // Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x
254 // Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x
255 Value sum =
256 rewriter.create<arith::AddFOp>(loc, reciprocalExp, scaledExp, fmf);
257 Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, cos, fmf);
258 Value diff =
259 rewriter.create<arith::SubFOp>(loc, reciprocalExp, scaledExp, fmf);
260 Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, sin, fmf);
261 return {resultReal, resultImag};
262 }
263};
264
265struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
266 DivOpConversion(MLIRContext *context, complex::ComplexRangeFlags target)
267 : OpConversionPattern<complex::DivOp>(context), complexRange(target) {}
268
269 using OpConversionPattern<complex::DivOp>::OpConversionPattern;
270
271 LogicalResult
272 matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
273 ConversionPatternRewriter &rewriter) const override {
274 auto loc = op.getLoc();
275 auto type = cast<ComplexType>(adaptor.getLhs().getType());
276 auto elementType = cast<FloatType>(type.getElementType());
277 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
278
279 Value lhsReal =
280 rewriter.create<complex::ReOp>(loc, elementType, adaptor.getLhs());
281 Value lhsImag =
282 rewriter.create<complex::ImOp>(loc, elementType, adaptor.getLhs());
283 Value rhsReal =
284 rewriter.create<complex::ReOp>(loc, elementType, adaptor.getRhs());
285 Value rhsImag =
286 rewriter.create<complex::ImOp>(loc, elementType, adaptor.getRhs());
287
288 Value resultReal, resultImag;
289
290 if (complexRange == complex::ComplexRangeFlags::basic ||
291 complexRange == complex::ComplexRangeFlags::none) {
292 mlir::complex::convertDivToStandardUsingAlgebraic(
293 rewriter, loc, lhsReal, lhsImag, rhsReal, rhsImag, fmf, &resultReal,
294 &resultImag);
295 } else if (complexRange == complex::ComplexRangeFlags::improved) {
296 mlir::complex::convertDivToStandardUsingRangeReduction(
297 rewriter, loc, lhsReal, lhsImag, rhsReal, rhsImag, fmf, &resultReal,
298 &resultImag);
299 }
300
301 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
302 resultImag);
303
304 return success();
305 }
306
307private:
308 complex::ComplexRangeFlags complexRange;
309};
310
311struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
312 using OpConversionPattern<complex::ExpOp>::OpConversionPattern;
313
314 LogicalResult
315 matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
316 ConversionPatternRewriter &rewriter) const override {
317 auto loc = op.getLoc();
318 auto type = cast<ComplexType>(adaptor.getComplex().getType());
319 auto elementType = cast<FloatType>(type.getElementType());
320 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
321
322 Value real =
323 rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
324 Value imag =
325 rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
326 Value expReal = rewriter.create<math::ExpOp>(loc, real, fmf.getValue());
327 Value cosImag = rewriter.create<math::CosOp>(loc, imag, fmf.getValue());
328 Value resultReal =
329 rewriter.create<arith::MulFOp>(loc, expReal, cosImag, fmf.getValue());
330 Value sinImag = rewriter.create<math::SinOp>(loc, imag, fmf.getValue());
331 Value resultImag =
332 rewriter.create<arith::MulFOp>(loc, expReal, sinImag, fmf.getValue());
333
334 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
335 resultImag);
336 return success();
337 }
338};
339
340Value evaluatePolynomial(ImplicitLocOpBuilder &b, Value arg,
341 ArrayRef<double> coefficients,
342 arith::FastMathFlagsAttr fmf) {
343 auto argType = mlir::cast<FloatType>(arg.getType());
344 Value poly =
345 b.create<arith::ConstantOp>(b.getFloatAttr(argType, coefficients[0]));
346 for (unsigned i = 1; i < coefficients.size(); ++i) {
347 poly = b.create<math::FmaOp>(
348 poly, arg,
349 b.create<arith::ConstantOp>(b.getFloatAttr(argType, coefficients[i])),
350 fmf);
351 }
352 return poly;
353}
354
355struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
356 using OpConversionPattern<complex::Expm1Op>::OpConversionPattern;
357
358 // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i
359 // [handle inaccuracies when a and/or b are small]
360 // = ((e^a - 1) * cos(b) + cos(b) - 1) + e^a*sin(b)i
361 // = (expm1(a) * cos(b) + cosm1(b)) + e^a*sin(b)i
362 LogicalResult
363 matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor,
364 ConversionPatternRewriter &rewriter) const override {
365 auto type = op.getType();
366 auto elemType = mlir::cast<FloatType>(type.getElementType());
367
368 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
369 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
370 Value real = b.create<complex::ReOp>(adaptor.getComplex());
371 Value imag = b.create<complex::ImOp>(adaptor.getComplex());
372
373 Value zero = b.create<arith::ConstantOp>(b.getFloatAttr(elemType, 0.0));
374 Value one = b.create<arith::ConstantOp>(b.getFloatAttr(elemType, 1.0));
375
376 Value expm1Real = b.create<math::ExpM1Op>(real, fmf);
377 Value expReal = b.create<arith::AddFOp>(expm1Real, one, fmf);
378
379 Value sinImag = b.create<math::SinOp>(imag, fmf);
380 Value cosm1Imag = emitCosm1(imag, fmf, b);
381 Value cosImag = b.create<arith::AddFOp>(cosm1Imag, one, fmf);
382
383 Value realResult = b.create<arith::AddFOp>(
384 b.create<arith::MulFOp>(expm1Real, cosImag, fmf), cosm1Imag, fmf);
385
386 Value imagIsZero = b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag,
387 zero, fmf.getValue());
388 Value imagResult = b.create<arith::SelectOp>(
389 imagIsZero, zero, b.create<arith::MulFOp>(expReal, sinImag, fmf));
390
391 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realResult,
392 imagResult);
393 return success();
394 }
395
396private:
397 Value emitCosm1(Value arg, arith::FastMathFlagsAttr fmf,
398 ImplicitLocOpBuilder &b) const {
399 auto argType = mlir::cast<FloatType>(arg.getType());
400 auto negHalf = b.create<arith::ConstantOp>(b.getFloatAttr(argType, -0.5));
401 auto negOne = b.create<arith::ConstantOp>(b.getFloatAttr(argType, -1.0));
402
403 // Algorithm copied from cephes cosm1.
404 SmallVector<double, 7> kCoeffs{
405 4.7377507964246204691685E-14, -1.1470284843425359765671E-11,
406 2.0876754287081521758361E-9, -2.7557319214999787979814E-7,
407 2.4801587301570552304991E-5, -1.3888888888888872993737E-3,
408 4.1666666666666666609054E-2,
409 };
410 Value cos = b.create<math::CosOp>(arg, fmf);
411 Value forLargeArg = b.create<arith::AddFOp>(cos, negOne, fmf);
412
413 Value argPow2 = b.create<arith::MulFOp>(arg, arg, fmf);
414 Value argPow4 = b.create<arith::MulFOp>(argPow2, argPow2, fmf);
415 Value poly = evaluatePolynomial(b, argPow2, kCoeffs, fmf);
416
417 auto forSmallArg =
418 b.create<arith::AddFOp>(b.create<arith::MulFOp>(argPow4, poly, fmf),
419 b.create<arith::MulFOp>(negHalf, argPow2, fmf));
420
421 // (pi/4)^2 is approximately 0.61685
422 Value piOver4Pow2 =
423 b.create<arith::ConstantOp>(b.getFloatAttr(argType, 0.61685));
424 Value cond = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, argPow2,
425 piOver4Pow2, fmf.getValue());
426 return b.create<arith::SelectOp>(cond, forLargeArg, forSmallArg);
427 }
428};
429
430struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
431 using OpConversionPattern<complex::LogOp>::OpConversionPattern;
432
433 LogicalResult
434 matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
435 ConversionPatternRewriter &rewriter) const override {
436 auto type = cast<ComplexType>(adaptor.getComplex().getType());
437 auto elementType = cast<FloatType>(type.getElementType());
438 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
439 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
440
441 Value abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex(),
442 fmf.getValue());
443 Value resultReal = b.create<math::LogOp>(elementType, abs, fmf.getValue());
444 Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
445 Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
446 Value resultImag =
447 b.create<math::Atan2Op>(elementType, imag, real, fmf.getValue());
448 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
449 resultImag);
450 return success();
451 }
452};
453
454struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
455 using OpConversionPattern<complex::Log1pOp>::OpConversionPattern;
456
457 LogicalResult
458 matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
459 ConversionPatternRewriter &rewriter) const override {
460 auto type = cast<ComplexType>(adaptor.getComplex().getType());
461 auto elementType = cast<FloatType>(type.getElementType());
462 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
463 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
464
465 Value real = b.create<complex::ReOp>(adaptor.getComplex());
466 Value imag = b.create<complex::ImOp>(adaptor.getComplex());
467
468 Value half = b.create<arith::ConstantOp>(elementType,
469 b.getFloatAttr(elementType, 0.5));
470 Value one = b.create<arith::ConstantOp>(elementType,
471 b.getFloatAttr(elementType, 1));
472 Value realPlusOne = b.create<arith::AddFOp>(real, one, fmf);
473 Value absRealPlusOne = b.create<math::AbsFOp>(realPlusOne, fmf);
474 Value absImag = b.create<math::AbsFOp>(imag, fmf);
475
476 Value maxAbs = b.create<arith::MaximumFOp>(absRealPlusOne, absImag, fmf);
477 Value minAbs = b.create<arith::MinimumFOp>(absRealPlusOne, absImag, fmf);
478
479 Value useReal = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT,
480 realPlusOne, absImag, fmf);
481 Value maxMinusOne = b.create<arith::SubFOp>(maxAbs, one, fmf);
482 Value maxAbsOfRealPlusOneAndImagMinusOne =
483 b.create<arith::SelectOp>(useReal, real, maxMinusOne);
484 arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
485 fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
486 Value minMaxRatio = b.create<arith::DivFOp>(minAbs, maxAbs, fmfWithNaNInf);
487 Value logOfMaxAbsOfRealPlusOneAndImag =
488 b.create<math::Log1pOp>(maxAbsOfRealPlusOneAndImagMinusOne, fmf);
489 Value logOfSqrtPart = b.create<math::Log1pOp>(
490 b.create<arith::MulFOp>(minMaxRatio, minMaxRatio, fmfWithNaNInf),
491 fmfWithNaNInf);
492 Value r = b.create<arith::AddFOp>(
493 b.create<arith::MulFOp>(half, logOfSqrtPart, fmfWithNaNInf),
494 logOfMaxAbsOfRealPlusOneAndImag, fmfWithNaNInf);
495 Value resultReal = b.create<arith::SelectOp>(
496 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, r, r, fmfWithNaNInf),
497 minAbs, r);
498 Value resultImag = b.create<math::Atan2Op>(imag, realPlusOne, fmf);
499 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
500 resultImag);
501 return success();
502 }
503};
504
505struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
506 using OpConversionPattern<complex::MulOp>::OpConversionPattern;
507
508 LogicalResult
509 matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
510 ConversionPatternRewriter &rewriter) const override {
511 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
512 auto type = cast<ComplexType>(adaptor.getLhs().getType());
513 auto elementType = cast<FloatType>(type.getElementType());
514 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
515 auto fmfValue = fmf.getValue();
516 Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.getLhs());
517 Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.getLhs());
518 Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.getRhs());
519 Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.getRhs());
520 Value lhsRealTimesRhsReal =
521 b.create<arith::MulFOp>(lhsReal, rhsReal, fmfValue);
522 Value lhsImagTimesRhsImag =
523 b.create<arith::MulFOp>(lhsImag, rhsImag, fmfValue);
524 Value real = b.create<arith::SubFOp>(lhsRealTimesRhsReal,
525 lhsImagTimesRhsImag, fmfValue);
526 Value lhsImagTimesRhsReal =
527 b.create<arith::MulFOp>(lhsImag, rhsReal, fmfValue);
528 Value lhsRealTimesRhsImag =
529 b.create<arith::MulFOp>(lhsReal, rhsImag, fmfValue);
530 Value imag = b.create<arith::AddFOp>(lhsImagTimesRhsReal,
531 lhsRealTimesRhsImag, fmfValue);
532 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
533 return success();
534 }
535};
536
537struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
538 using OpConversionPattern<complex::NegOp>::OpConversionPattern;
539
540 LogicalResult
541 matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
542 ConversionPatternRewriter &rewriter) const override {
543 auto loc = op.getLoc();
544 auto type = cast<ComplexType>(adaptor.getComplex().getType());
545 auto elementType = cast<FloatType>(type.getElementType());
546
547 Value real =
548 rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
549 Value imag =
550 rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
551 Value negReal = rewriter.create<arith::NegFOp>(loc, real);
552 Value negImag = rewriter.create<arith::NegFOp>(loc, imag);
553 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
554 return success();
555 }
556};
557
558struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
559 using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion;
560
561 std::pair<Value, Value> combine(Location loc, Value scaledExp,
562 Value reciprocalExp, Value sin, Value cos,
563 ConversionPatternRewriter &rewriter,
564 arith::FastMathFlagsAttr fmf) const override {
565 // Complex sine is defined as;
566 // sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy)))
567 // Plugging in:
568 // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
569 // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
570 // and defining t := exp(y)
571 // We get:
572 // Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x
573 // Im(cos(x + iy)) = (0.5*t - 0.5/t) * cos x
574 Value sum =
575 rewriter.create<arith::AddFOp>(loc, scaledExp, reciprocalExp, fmf);
576 Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, sin, fmf);
577 Value diff =
578 rewriter.create<arith::SubFOp>(loc, scaledExp, reciprocalExp, fmf);
579 Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, cos, fmf);
580 return {resultReal, resultImag};
581 }
582};
583
584// The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
585struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
586 using OpConversionPattern<complex::SqrtOp>::OpConversionPattern;
587
588 LogicalResult
589 matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
590 ConversionPatternRewriter &rewriter) const override {
591 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
592
593 auto type = cast<ComplexType>(op.getType());
594 auto elementType = cast<FloatType>(type.getElementType());
595 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
596
597 auto cst = [&](APFloat v) {
598 return b.create<arith::ConstantOp>(elementType,
599 b.getFloatAttr(elementType, v));
600 };
601 const auto &floatSemantics = elementType.getFloatSemantics();
602 Value zero = cst(APFloat::getZero(Sem: floatSemantics));
603 Value half = b.create<arith::ConstantOp>(elementType,
604 b.getFloatAttr(elementType, 0.5));
605
606 Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
607 Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
608 Value absSqrt = computeAbs(real, imag, fmf, b, AbsFn::sqrt);
609 Value argArg = b.create<math::Atan2Op>(imag, real, fmf);
610 Value sqrtArg = b.create<arith::MulFOp>(argArg, half, fmf);
611 Value cos = b.create<math::CosOp>(sqrtArg, fmf);
612 Value sin = b.create<math::SinOp>(sqrtArg, fmf);
613 // sin(atan2(0, inf)) = 0, sqrt(abs(inf)) = inf, but we can't multiply
614 // 0 * inf.
615 Value sinIsZero =
616 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, sin, zero, fmf);
617
618 Value resultReal = b.create<arith::MulFOp>(absSqrt, cos, fmf);
619 Value resultImag = b.create<arith::SelectOp>(
620 sinIsZero, zero, b.create<arith::MulFOp>(absSqrt, sin, fmf));
621 if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
622 arith::FastMathFlags::ninf)) {
623 Value inf = cst(APFloat::getInf(Sem: floatSemantics));
624 Value negInf = cst(APFloat::getInf(Sem: floatSemantics, Negative: true));
625 Value nan = cst(APFloat::getNaN(Sem: floatSemantics));
626 Value absImag = b.create<math::AbsFOp>(elementType, imag, fmf);
627
628 Value absImagIsInf =
629 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
630 Value absImagIsNotInf =
631 b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, absImag, inf, fmf);
632 Value realIsInf =
633 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, inf, fmf);
634 Value realIsNegInf =
635 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, negInf, fmf);
636
637 resultReal = b.create<arith::SelectOp>(
638 b.create<arith::AndIOp>(realIsNegInf, absImagIsNotInf), zero,
639 resultReal);
640 resultReal = b.create<arith::SelectOp>(
641 b.create<arith::OrIOp>(absImagIsInf, realIsInf), inf, resultReal);
642
643 Value imagSignInf = b.create<math::CopySignOp>(inf, imag, fmf);
644 resultImag = b.create<arith::SelectOp>(
645 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, absSqrt, absSqrt),
646 nan, resultImag);
647 resultImag = b.create<arith::SelectOp>(
648 b.create<arith::OrIOp>(absImagIsInf, realIsNegInf), imagSignInf,
649 resultImag);
650 }
651
652 Value resultIsZero =
653 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absSqrt, zero, fmf);
654 resultReal = b.create<arith::SelectOp>(resultIsZero, zero, resultReal);
655 resultImag = b.create<arith::SelectOp>(resultIsZero, zero, resultImag);
656
657 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
658 resultImag);
659 return success();
660 }
661};
662
663struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
664 using OpConversionPattern<complex::SignOp>::OpConversionPattern;
665
666 LogicalResult
667 matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
668 ConversionPatternRewriter &rewriter) const override {
669 auto type = cast<ComplexType>(adaptor.getComplex().getType());
670 auto elementType = cast<FloatType>(type.getElementType());
671 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
672 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
673
674 Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
675 Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
676 Value zero =
677 b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
678 Value realIsZero =
679 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
680 Value imagIsZero =
681 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
682 Value isZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
683 auto abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex(), fmf);
684 Value realSign = b.create<arith::DivFOp>(real, abs, fmf);
685 Value imagSign = b.create<arith::DivFOp>(imag, abs, fmf);
686 Value sign = b.create<complex::CreateOp>(type, realSign, imagSign);
687 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isZero,
688 adaptor.getComplex(), sign);
689 return success();
690 }
691};
692
693template <typename Op>
694struct TanTanhOpConversion : public OpConversionPattern<Op> {
695 using OpConversionPattern<Op>::OpConversionPattern;
696
697 LogicalResult
698 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
699 ConversionPatternRewriter &rewriter) const override {
700 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
701 auto loc = op.getLoc();
702 auto type = cast<ComplexType>(adaptor.getComplex().getType());
703 auto elementType = cast<FloatType>(type.getElementType());
704 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
705 const auto &floatSemantics = elementType.getFloatSemantics();
706
707 Value real =
708 b.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
709 Value imag =
710 b.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
711 Value negOne = b.create<arith::ConstantOp>(
712 elementType, b.getFloatAttr(elementType, -1.0));
713
714 if constexpr (std::is_same_v<Op, complex::TanOp>) {
715 // tan(x+yi) = -i*tanh(-y + xi)
716 std::swap(real, imag);
717 real = b.create<arith::MulFOp>(real, negOne, fmf);
718 }
719
720 auto cst = [&](APFloat v) {
721 return b.create<arith::ConstantOp>(elementType,
722 b.getFloatAttr(elementType, v));
723 };
724 Value inf = cst(APFloat::getInf(Sem: floatSemantics));
725 Value four = b.create<arith::ConstantOp>(elementType,
726 b.getFloatAttr(elementType, 4.0));
727 Value twoReal = b.create<arith::AddFOp>(real, real, fmf);
728 Value negTwoReal = b.create<arith::MulFOp>(negOne, twoReal, fmf);
729
730 Value expTwoRealMinusOne = b.create<math::ExpM1Op>(twoReal, fmf);
731 Value expNegTwoRealMinusOne = b.create<math::ExpM1Op>(negTwoReal, fmf);
732 Value realNum =
733 b.create<arith::SubFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf);
734
735 Value cosImag = b.create<math::CosOp>(imag, fmf);
736 Value cosImagSq = b.create<arith::MulFOp>(cosImag, cosImag, fmf);
737 Value twoCosTwoImagPlusOne = b.create<arith::MulFOp>(cosImagSq, four, fmf);
738 Value sinImag = b.create<math::SinOp>(imag, fmf);
739
740 Value imagNum = b.create<arith::MulFOp>(
741 four, b.create<arith::MulFOp>(cosImag, sinImag, fmf), fmf);
742
743 Value expSumMinusTwo =
744 b.create<arith::AddFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf);
745 Value denom =
746 b.create<arith::AddFOp>(expSumMinusTwo, twoCosTwoImagPlusOne, fmf);
747
748 Value isInf = b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
749 expSumMinusTwo, inf, fmf);
750 Value realLimit = b.create<math::CopySignOp>(negOne, real, fmf);
751
752 Value resultReal = b.create<arith::SelectOp>(
753 isInf, realLimit, b.create<arith::DivFOp>(realNum, denom, fmf));
754 Value resultImag = b.create<arith::DivFOp>(imagNum, denom, fmf);
755
756 if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
757 arith::FastMathFlags::ninf)) {
758 Value absReal = b.create<math::AbsFOp>(real, fmf);
759 Value zero = b.create<arith::ConstantOp>(
760 elementType, b.getFloatAttr(elementType, 0.0));
761 Value nan = cst(APFloat::getNaN(Sem: floatSemantics));
762
763 Value absRealIsInf =
764 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf);
765 Value imagIsZero =
766 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf);
767 Value absRealIsNotInf = b.create<arith::XOrIOp>(
768 absRealIsInf, b.create<arith::ConstantIntOp>(true, /*width=*/1));
769
770 Value imagNumIsNaN = b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO,
771 imagNum, imagNum, fmf);
772 Value resultRealIsNaN =
773 b.create<arith::AndIOp>(imagNumIsNaN, absRealIsNotInf);
774 Value resultImagIsZero = b.create<arith::OrIOp>(
775 imagIsZero, b.create<arith::AndIOp>(absRealIsInf, imagNumIsNaN));
776
777 resultReal = b.create<arith::SelectOp>(resultRealIsNaN, nan, resultReal);
778 resultImag =
779 b.create<arith::SelectOp>(resultImagIsZero, zero, resultImag);
780 }
781
782 if constexpr (std::is_same_v<Op, complex::TanOp>) {
783 // tan(x+yi) = -i*tanh(-y + xi)
784 std::swap(resultReal, resultImag);
785 resultImag = b.create<arith::MulFOp>(resultImag, negOne, fmf);
786 }
787
788 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
789 resultImag);
790 return success();
791 }
792};
793
794struct ConjOpConversion : public OpConversionPattern<complex::ConjOp> {
795 using OpConversionPattern<complex::ConjOp>::OpConversionPattern;
796
797 LogicalResult
798 matchAndRewrite(complex::ConjOp op, OpAdaptor adaptor,
799 ConversionPatternRewriter &rewriter) const override {
800 auto loc = op.getLoc();
801 auto type = cast<ComplexType>(adaptor.getComplex().getType());
802 auto elementType = cast<FloatType>(type.getElementType());
803 Value real =
804 rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
805 Value imag =
806 rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
807 Value negImag = rewriter.create<arith::NegFOp>(loc, elementType, imag);
808
809 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, negImag);
810
811 return success();
812 }
813};
814
815/// Converts lhs^y = (a+bi)^(c+di) to
816/// (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
817/// where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
818static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
819 ComplexType type, Value lhs, Value c, Value d,
820 arith::FastMathFlags fmf) {
821 auto elementType = cast<FloatType>(type.getElementType());
822
823 Value a = builder.create<complex::ReOp>(lhs);
824 Value b = builder.create<complex::ImOp>(lhs);
825
826 Value abs = builder.create<complex::AbsOp>(lhs, fmf);
827 Value absToC = builder.create<math::PowFOp>(abs, c, fmf);
828
829 Value negD = builder.create<arith::NegFOp>(d, fmf);
830 Value argLhs = builder.create<math::Atan2Op>(b, a, fmf);
831 Value negDArgLhs = builder.create<arith::MulFOp>(negD, argLhs, fmf);
832 Value expNegDArgLhs = builder.create<math::ExpOp>(negDArgLhs, fmf);
833
834 Value coeff = builder.create<arith::MulFOp>(absToC, expNegDArgLhs, fmf);
835 Value lnAbs = builder.create<math::LogOp>(abs, fmf);
836 Value cArgLhs = builder.create<arith::MulFOp>(c, argLhs, fmf);
837 Value dLnAbs = builder.create<arith::MulFOp>(d, lnAbs, fmf);
838 Value q = builder.create<arith::AddFOp>(cArgLhs, dLnAbs, fmf);
839 Value cosQ = builder.create<math::CosOp>(q, fmf);
840 Value sinQ = builder.create<math::SinOp>(q, fmf);
841
842 Value inf = builder.create<arith::ConstantOp>(
843 elementType,
844 builder.getFloatAttr(elementType,
845 APFloat::getInf(elementType.getFloatSemantics())));
846 Value zero = builder.create<arith::ConstantOp>(
847 elementType, builder.getFloatAttr(elementType, 0.0));
848 Value one = builder.create<arith::ConstantOp>(
849 elementType, builder.getFloatAttr(elementType, 1.0));
850 Value complexOne = builder.create<complex::CreateOp>(type, one, zero);
851 Value complexZero = builder.create<complex::CreateOp>(type, zero, zero);
852 Value complexInf = builder.create<complex::CreateOp>(type, inf, zero);
853
854 // Case 0:
855 // d^c is 0 if d is 0 and c > 0. 0^0 is defined to be 1.0, see
856 // Branch Cuts for Complex Elementary Functions or Much Ado About
857 // Nothing's Sign Bit, W. Kahan, Section 10.
858 Value absEqZero =
859 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, abs, zero, fmf);
860 Value dEqZero =
861 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, d, zero, fmf);
862 Value cEqZero =
863 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, c, zero, fmf);
864 Value bEqZero =
865 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, b, zero, fmf);
866
867 Value zeroLeC =
868 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLE, zero, c, fmf);
869 Value coeffCosQ = builder.create<arith::MulFOp>(coeff, cosQ, fmf);
870 Value coeffSinQ = builder.create<arith::MulFOp>(coeff, sinQ, fmf);
871 Value complexOneOrZero =
872 builder.create<arith::SelectOp>(cEqZero, complexOne, complexZero);
873 Value coeffCosSin =
874 builder.create<complex::CreateOp>(type, coeffCosQ, coeffSinQ);
875 Value cutoff0 = builder.create<arith::SelectOp>(
876 builder.create<arith::AndIOp>(
877 builder.create<arith::AndIOp>(absEqZero, dEqZero), zeroLeC),
878 complexOneOrZero, coeffCosSin);
879
880 // Case 1:
881 // x^0 is defined to be 1 for any x, see
882 // Branch Cuts for Complex Elementary Functions or Much Ado About
883 // Nothing's Sign Bit, W. Kahan, Section 10.
884 Value rhsEqZero = builder.create<arith::AndIOp>(cEqZero, dEqZero);
885 Value cutoff1 =
886 builder.create<arith::SelectOp>(rhsEqZero, complexOne, cutoff0);
887
888 // Case 2:
889 // 1^(c + d*i) = 1 + 0*i
890 Value lhsEqOne = builder.create<arith::AndIOp>(
891 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, one, fmf),
892 bEqZero);
893 Value cutoff2 =
894 builder.create<arith::SelectOp>(lhsEqOne, complexOne, cutoff1);
895
896 // Case 3:
897 // inf^(c + 0*i) = inf + 0*i, c > 0
898 Value lhsEqInf = builder.create<arith::AndIOp>(
899 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, inf, fmf),
900 bEqZero);
901 Value rhsGt0 = builder.create<arith::AndIOp>(
902 dEqZero,
903 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, c, zero, fmf));
904 Value cutoff3 = builder.create<arith::SelectOp>(
905 builder.create<arith::AndIOp>(lhsEqInf, rhsGt0), complexInf, cutoff2);
906
907 // Case 4:
908 // inf^(c + 0*i) = 0 + 0*i, c < 0
909 Value rhsLt0 = builder.create<arith::AndIOp>(
910 dEqZero,
911 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, c, zero, fmf));
912 Value cutoff4 = builder.create<arith::SelectOp>(
913 builder.create<arith::AndIOp>(lhsEqInf, rhsLt0), complexZero, cutoff3);
914
915 return cutoff4;
916}
917
918struct PowOpConversion : public OpConversionPattern<complex::PowOp> {
919 using OpConversionPattern<complex::PowOp>::OpConversionPattern;
920
921 LogicalResult
922 matchAndRewrite(complex::PowOp op, OpAdaptor adaptor,
923 ConversionPatternRewriter &rewriter) const override {
924 mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
925 auto type = cast<ComplexType>(adaptor.getLhs().getType());
926 auto elementType = cast<FloatType>(type.getElementType());
927
928 Value c = builder.create<complex::ReOp>(elementType, adaptor.getRhs());
929 Value d = builder.create<complex::ImOp>(elementType, adaptor.getRhs());
930
931 rewriter.replaceOp(op, {powOpConversionImpl(builder, type, adaptor.getLhs(),
932 c, d, op.getFastmath())});
933 return success();
934 }
935};
936
937struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> {
938 using OpConversionPattern<complex::RsqrtOp>::OpConversionPattern;
939
940 LogicalResult
941 matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor,
942 ConversionPatternRewriter &rewriter) const override {
943 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
944 auto type = cast<ComplexType>(adaptor.getComplex().getType());
945 auto elementType = cast<FloatType>(type.getElementType());
946
947 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
948
949 auto cst = [&](APFloat v) {
950 return b.create<arith::ConstantOp>(elementType,
951 b.getFloatAttr(elementType, v));
952 };
953 const auto &floatSemantics = elementType.getFloatSemantics();
954 Value zero = cst(APFloat::getZero(Sem: floatSemantics));
955 Value inf = cst(APFloat::getInf(Sem: floatSemantics));
956 Value negHalf = b.create<arith::ConstantOp>(
957 elementType, b.getFloatAttr(elementType, -0.5));
958 Value nan = cst(APFloat::getNaN(Sem: floatSemantics));
959
960 Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
961 Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
962 Value absRsqrt = computeAbs(real, imag, fmf, b, AbsFn::rsqrt);
963 Value argArg = b.create<math::Atan2Op>(imag, real, fmf);
964 Value rsqrtArg = b.create<arith::MulFOp>(argArg, negHalf, fmf);
965 Value cos = b.create<math::CosOp>(rsqrtArg, fmf);
966 Value sin = b.create<math::SinOp>(rsqrtArg, fmf);
967
968 Value resultReal = b.create<arith::MulFOp>(absRsqrt, cos, fmf);
969 Value resultImag = b.create<arith::MulFOp>(absRsqrt, sin, fmf);
970
971 if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
972 arith::FastMathFlags::ninf)) {
973 Value negOne = b.create<arith::ConstantOp>(
974 elementType, b.getFloatAttr(elementType, -1));
975
976 Value realSignedZero = b.create<math::CopySignOp>(zero, real, fmf);
977 Value imagSignedZero = b.create<math::CopySignOp>(zero, imag, fmf);
978 Value negImagSignedZero =
979 b.create<arith::MulFOp>(negOne, imagSignedZero, fmf);
980
981 Value absReal = b.create<math::AbsFOp>(real, fmf);
982 Value absImag = b.create<math::AbsFOp>(imag, fmf);
983
984 Value absImagIsInf =
985 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
986 Value realIsNan =
987 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real, fmf);
988 Value realIsInf =
989 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf);
990 Value inIsNanInf = b.create<arith::AndIOp>(absImagIsInf, realIsNan);
991
992 Value resultIsZero = b.create<arith::OrIOp>(inIsNanInf, realIsInf);
993
994 resultReal =
995 b.create<arith::SelectOp>(resultIsZero, realSignedZero, resultReal);
996 resultImag = b.create<arith::SelectOp>(resultIsZero, negImagSignedZero,
997 resultImag);
998 }
999
1000 Value isRealZero =
1001 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero, fmf);
1002 Value isImagZero =
1003 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf);
1004 Value isZero = b.create<arith::AndIOp>(isRealZero, isImagZero);
1005
1006 resultReal = b.create<arith::SelectOp>(isZero, inf, resultReal);
1007 resultImag = b.create<arith::SelectOp>(isZero, nan, resultImag);
1008
1009 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
1010 resultImag);
1011 return success();
1012 }
1013};
1014
1015struct AngleOpConversion : public OpConversionPattern<complex::AngleOp> {
1016 using OpConversionPattern<complex::AngleOp>::OpConversionPattern;
1017
1018 LogicalResult
1019 matchAndRewrite(complex::AngleOp op, OpAdaptor adaptor,
1020 ConversionPatternRewriter &rewriter) const override {
1021 auto loc = op.getLoc();
1022 auto type = op.getType();
1023 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
1024
1025 Value real =
1026 rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
1027 Value imag =
1028 rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
1029
1030 rewriter.replaceOpWithNewOp<math::Atan2Op>(op, imag, real, fmf);
1031
1032 return success();
1033 }
1034};
1035
1036} // namespace
1037
1038void mlir::populateComplexToStandardConversionPatterns(
1039 RewritePatternSet &patterns, complex::ComplexRangeFlags complexRange) {
1040 // clang-format off
1041 patterns.add<
1042 AbsOpConversion,
1043 AngleOpConversion,
1044 Atan2OpConversion,
1045 BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
1046 BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
1047 ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
1048 ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
1049 ConjOpConversion,
1050 CosOpConversion,
1051 ExpOpConversion,
1052 Expm1OpConversion,
1053 Log1pOpConversion,
1054 LogOpConversion,
1055 MulOpConversion,
1056 NegOpConversion,
1057 SignOpConversion,
1058 SinOpConversion,
1059 SqrtOpConversion,
1060 TanTanhOpConversion<complex::TanOp>,
1061 TanTanhOpConversion<complex::TanhOp>,
1062 PowOpConversion,
1063 RsqrtOpConversion
1064 >(patterns.getContext());
1065
1066 patterns.add<DivOpConversion>(patterns.getContext(), complexRange);
1067
1068 // clang-format on
1069}
1070
1071namespace {
1072struct ConvertComplexToStandardPass
1073 : public impl::ConvertComplexToStandardPassBase<
1074 ConvertComplexToStandardPass> {
1075 using Base::Base;
1076
1077 void runOnOperation() override;
1078};
1079
1080void ConvertComplexToStandardPass::runOnOperation() {
1081 // Convert to the Standard dialect using the converter defined above.
1082 RewritePatternSet patterns(&getContext());
1083 populateComplexToStandardConversionPatterns(patterns, complexRange);
1084
1085 ConversionTarget target(getContext());
1086 target.addLegalDialect<arith::ArithDialect, math::MathDialect>();
1087 target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
1088 if (failed(
1089 applyPartialConversion(getOperation(), target, std::move(patterns))))
1090 signalPassFailure();
1091}
1092} // namespace
1093

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp