1 | //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" |
10 | |
11 | #include "mlir/Conversion/ComplexCommon/DivisionConverter.h" |
12 | #include "mlir/Dialect/Arith/IR/Arith.h" |
13 | #include "mlir/Dialect/Complex/IR/Complex.h" |
14 | #include "mlir/Dialect/Math/IR/Math.h" |
15 | #include "mlir/IR/ImplicitLocOpBuilder.h" |
16 | #include "mlir/IR/PatternMatch.h" |
17 | #include "mlir/Pass/Pass.h" |
18 | #include "mlir/Transforms/DialectConversion.h" |
19 | #include <memory> |
20 | #include <type_traits> |
21 | |
22 | namespace mlir { |
23 | #define GEN_PASS_DEF_CONVERTCOMPLEXTOSTANDARDPASS |
24 | #include "mlir/Conversion/Passes.h.inc" |
25 | } // namespace mlir |
26 | |
27 | using namespace mlir; |
28 | |
29 | namespace { |
30 | |
// Selects which function of the complex magnitude computeAbs() emits.
enum class AbsFn {
  abs,  // The absolute value itself.
  sqrt, // The square root of the absolute value.
  rsqrt // The reciprocal square root of the absolute value.
};
32 | |
33 | // Returns the absolute value, its square root or its reciprocal square root. |
34 | Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf, |
35 | ImplicitLocOpBuilder &b, AbsFn fn = AbsFn::abs) { |
36 | Value one = b.create<arith::ConstantOp>(real.getType(), |
37 | b.getFloatAttr(real.getType(), 1.0)); |
38 | |
39 | Value absReal = b.create<math::AbsFOp>(real, fmf); |
40 | Value absImag = b.create<math::AbsFOp>(imag, fmf); |
41 | |
42 | Value max = b.create<arith::MaximumFOp>(absReal, absImag, fmf); |
43 | Value min = b.create<arith::MinimumFOp>(absReal, absImag, fmf); |
44 | |
45 | // The lowering below requires NaNs and infinities to work correctly. |
46 | arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear( |
47 | fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf); |
48 | Value ratio = b.create<arith::DivFOp>(min, max, fmfWithNaNInf); |
49 | Value ratioSq = b.create<arith::MulFOp>(ratio, ratio, fmfWithNaNInf); |
50 | Value ratioSqPlusOne = b.create<arith::AddFOp>(ratioSq, one, fmfWithNaNInf); |
51 | Value result; |
52 | |
53 | if (fn == AbsFn::rsqrt) { |
54 | ratioSqPlusOne = b.create<math::RsqrtOp>(ratioSqPlusOne, fmfWithNaNInf); |
55 | min = b.create<math::RsqrtOp>(min, fmfWithNaNInf); |
56 | max = b.create<math::RsqrtOp>(max, fmfWithNaNInf); |
57 | } |
58 | |
59 | if (fn == AbsFn::sqrt) { |
60 | Value quarter = b.create<arith::ConstantOp>( |
61 | real.getType(), b.getFloatAttr(real.getType(), 0.25)); |
62 | // sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily. |
63 | Value sqrt = b.create<math::SqrtOp>(max, fmfWithNaNInf); |
64 | Value p025 = b.create<math::PowFOp>(ratioSqPlusOne, quarter, fmfWithNaNInf); |
65 | result = b.create<arith::MulFOp>(sqrt, p025, fmfWithNaNInf); |
66 | } else { |
67 | Value sqrt = b.create<math::SqrtOp>(ratioSqPlusOne, fmfWithNaNInf); |
68 | result = b.create<arith::MulFOp>(max, sqrt, fmfWithNaNInf); |
69 | } |
70 | |
71 | Value isNaN = b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, result, |
72 | result, fmfWithNaNInf); |
73 | return b.create<arith::SelectOp>(isNaN, min, result); |
74 | } |
75 | |
76 | struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> { |
77 | using OpConversionPattern<complex::AbsOp>::OpConversionPattern; |
78 | |
79 | LogicalResult |
80 | matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor, |
81 | ConversionPatternRewriter &rewriter) const override { |
82 | ImplicitLocOpBuilder b(op.getLoc(), rewriter); |
83 | |
84 | arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue(); |
85 | |
86 | Value real = b.create<complex::ReOp>(adaptor.getComplex()); |
87 | Value imag = b.create<complex::ImOp>(adaptor.getComplex()); |
88 | rewriter.replaceOp(op, computeAbs(real, imag, fmf, b)); |
89 | |
90 | return success(); |
91 | } |
92 | }; |
93 | |
94 | // atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2)) |
95 | struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> { |
96 | using OpConversionPattern<complex::Atan2Op>::OpConversionPattern; |
97 | |
98 | LogicalResult |
99 | matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor, |
100 | ConversionPatternRewriter &rewriter) const override { |
101 | mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); |
102 | |
103 | auto type = cast<ComplexType>(op.getType()); |
104 | Type elementType = type.getElementType(); |
105 | arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); |
106 | |
107 | Value lhs = adaptor.getLhs(); |
108 | Value rhs = adaptor.getRhs(); |
109 | |
110 | Value rhsSquared = b.create<complex::MulOp>(type, rhs, rhs, fmf); |
111 | Value lhsSquared = b.create<complex::MulOp>(type, lhs, lhs, fmf); |
112 | Value rhsSquaredPlusLhsSquared = |
113 | b.create<complex::AddOp>(type, rhsSquared, lhsSquared, fmf); |
114 | Value sqrtOfRhsSquaredPlusLhsSquared = |
115 | b.create<complex::SqrtOp>(type, rhsSquaredPlusLhsSquared, fmf); |
116 | |
117 | Value zero = |
118 | b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType)); |
119 | Value one = b.create<arith::ConstantOp>(elementType, |
120 | b.getFloatAttr(elementType, 1)); |
121 | Value i = b.create<complex::CreateOp>(type, zero, one); |
122 | Value iTimesLhs = b.create<complex::MulOp>(i, lhs, fmf); |
123 | Value rhsPlusILhs = b.create<complex::AddOp>(rhs, iTimesLhs, fmf); |
124 | |
125 | Value divResult = b.create<complex::DivOp>( |
126 | rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared, fmf); |
127 | Value logResult = b.create<complex::LogOp>(divResult, fmf); |
128 | |
129 | Value negativeOne = b.create<arith::ConstantOp>( |
130 | elementType, b.getFloatAttr(elementType, -1)); |
131 | Value negativeI = b.create<complex::CreateOp>(type, zero, negativeOne); |
132 | |
133 | rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult, fmf); |
134 | return success(); |
135 | } |
136 | }; |
137 | |
138 | template <typename ComparisonOp, arith::CmpFPredicate p> |
139 | struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> { |
140 | using OpConversionPattern<ComparisonOp>::OpConversionPattern; |
141 | using ResultCombiner = |
142 | std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value, |
143 | arith::AndIOp, arith::OrIOp>; |
144 | |
145 | LogicalResult |
146 | matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor, |
147 | ConversionPatternRewriter &rewriter) const override { |
148 | auto loc = op.getLoc(); |
149 | auto type = cast<ComplexType>(adaptor.getLhs().getType()).getElementType(); |
150 | |
151 | Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getLhs()); |
152 | Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getLhs()); |
153 | Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getRhs()); |
154 | Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getRhs()); |
155 | Value realComparison = |
156 | rewriter.create<arith::CmpFOp>(loc, p, realLhs, realRhs); |
157 | Value imagComparison = |
158 | rewriter.create<arith::CmpFOp>(loc, p, imagLhs, imagRhs); |
159 | |
160 | rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison, |
161 | imagComparison); |
162 | return success(); |
163 | } |
164 | }; |
165 | |
166 | // Default conversion which applies the BinaryStandardOp separately on the real |
167 | // and imaginary parts. Can for example be used for complex::AddOp and |
168 | // complex::SubOp. |
169 | template <typename BinaryComplexOp, typename BinaryStandardOp> |
170 | struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> { |
171 | using OpConversionPattern<BinaryComplexOp>::OpConversionPattern; |
172 | |
173 | LogicalResult |
174 | matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor, |
175 | ConversionPatternRewriter &rewriter) const override { |
176 | auto type = cast<ComplexType>(adaptor.getLhs().getType()); |
177 | auto elementType = cast<FloatType>(type.getElementType()); |
178 | mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); |
179 | arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); |
180 | |
181 | Value realLhs = b.create<complex::ReOp>(elementType, adaptor.getLhs()); |
182 | Value realRhs = b.create<complex::ReOp>(elementType, adaptor.getRhs()); |
183 | Value resultReal = b.create<BinaryStandardOp>(elementType, realLhs, realRhs, |
184 | fmf.getValue()); |
185 | Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.getLhs()); |
186 | Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.getRhs()); |
187 | Value resultImag = b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs, |
188 | fmf.getValue()); |
189 | rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, |
190 | resultImag); |
191 | return success(); |
192 | } |
193 | }; |
194 | |
195 | template <typename TrigonometricOp> |
196 | struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> { |
197 | using OpAdaptor = typename OpConversionPattern<TrigonometricOp>::OpAdaptor; |
198 | |
199 | using OpConversionPattern<TrigonometricOp>::OpConversionPattern; |
200 | |
201 | LogicalResult |
202 | matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor, |
203 | ConversionPatternRewriter &rewriter) const override { |
204 | auto loc = op.getLoc(); |
205 | auto type = cast<ComplexType>(adaptor.getComplex().getType()); |
206 | auto elementType = cast<FloatType>(type.getElementType()); |
207 | arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); |
208 | |
209 | Value real = |
210 | rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex()); |
211 | Value imag = |
212 | rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex()); |
213 | |
214 | // Trigonometric ops use a set of common building blocks to convert to real |
215 | // ops. Here we create these building blocks and call into an op-specific |
216 | // implementation in the subclass to combine them. |
217 | Value half = rewriter.create<arith::ConstantOp>( |
218 | loc, elementType, rewriter.getFloatAttr(elementType, 0.5)); |
219 | Value exp = rewriter.create<math::ExpOp>(loc, imag, fmf); |
220 | Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp, fmf); |
221 | Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp, fmf); |
222 | Value sin = rewriter.create<math::SinOp>(loc, real, fmf); |
223 | Value cos = rewriter.create<math::CosOp>(loc, real, fmf); |
224 | |
225 | auto resultPair = |
226 | combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter, fmf); |
227 | |
228 | rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultPair.first, |
229 | resultPair.second); |
230 | return success(); |
231 | } |
232 | |
233 | virtual std::pair<Value, Value> |
234 | combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin, |
235 | Value cos, ConversionPatternRewriter &rewriter, |
236 | arith::FastMathFlagsAttr fmf) const = 0; |
237 | }; |
238 | |
239 | struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> { |
240 | using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion; |
241 | |
242 | std::pair<Value, Value> combine(Location loc, Value scaledExp, |
243 | Value reciprocalExp, Value sin, Value cos, |
244 | ConversionPatternRewriter &rewriter, |
245 | arith::FastMathFlagsAttr fmf) const override { |
246 | // Complex cosine is defined as; |
247 | // cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy))) |
248 | // Plugging in: |
249 | // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x)) |
250 | // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x))) |
251 | // and defining t := exp(y) |
252 | // We get: |
253 | // Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x |
254 | // Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x |
255 | Value sum = |
256 | rewriter.create<arith::AddFOp>(loc, reciprocalExp, scaledExp, fmf); |
257 | Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, cos, fmf); |
258 | Value diff = |
259 | rewriter.create<arith::SubFOp>(loc, reciprocalExp, scaledExp, fmf); |
260 | Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, sin, fmf); |
261 | return {resultReal, resultImag}; |
262 | } |
263 | }; |
264 | |
265 | struct DivOpConversion : public OpConversionPattern<complex::DivOp> { |
266 | DivOpConversion(MLIRContext *context, complex::ComplexRangeFlags target) |
267 | : OpConversionPattern<complex::DivOp>(context), complexRange(target) {} |
268 | |
269 | using OpConversionPattern<complex::DivOp>::OpConversionPattern; |
270 | |
271 | LogicalResult |
272 | matchAndRewrite(complex::DivOp op, OpAdaptor adaptor, |
273 | ConversionPatternRewriter &rewriter) const override { |
274 | auto loc = op.getLoc(); |
275 | auto type = cast<ComplexType>(adaptor.getLhs().getType()); |
276 | auto elementType = cast<FloatType>(type.getElementType()); |
277 | arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); |
278 | |
279 | Value lhsReal = |
280 | rewriter.create<complex::ReOp>(loc, elementType, adaptor.getLhs()); |
281 | Value lhsImag = |
282 | rewriter.create<complex::ImOp>(loc, elementType, adaptor.getLhs()); |
283 | Value rhsReal = |
284 | rewriter.create<complex::ReOp>(loc, elementType, adaptor.getRhs()); |
285 | Value rhsImag = |
286 | rewriter.create<complex::ImOp>(loc, elementType, adaptor.getRhs()); |
287 | |
288 | Value resultReal, resultImag; |
289 | |
290 | if (complexRange == complex::ComplexRangeFlags::basic || |
291 | complexRange == complex::ComplexRangeFlags::none) { |
292 | mlir::complex::convertDivToStandardUsingAlgebraic( |
293 | rewriter, loc, lhsReal, lhsImag, rhsReal, rhsImag, fmf, &resultReal, |
294 | &resultImag); |
295 | } else if (complexRange == complex::ComplexRangeFlags::improved) { |
296 | mlir::complex::convertDivToStandardUsingRangeReduction( |
297 | rewriter, loc, lhsReal, lhsImag, rhsReal, rhsImag, fmf, &resultReal, |
298 | &resultImag); |
299 | } |
300 | |
301 | rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, |
302 | resultImag); |
303 | |
304 | return success(); |
305 | } |
306 | |
307 | private: |
308 | complex::ComplexRangeFlags complexRange; |
309 | }; |
310 | |
311 | struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> { |
312 | using OpConversionPattern<complex::ExpOp>::OpConversionPattern; |
313 | |
314 | LogicalResult |
315 | matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor, |
316 | ConversionPatternRewriter &rewriter) const override { |
317 | auto loc = op.getLoc(); |
318 | auto type = cast<ComplexType>(adaptor.getComplex().getType()); |
319 | auto elementType = cast<FloatType>(type.getElementType()); |
320 | arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); |
321 | |
322 | Value real = |
323 | rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex()); |
324 | Value imag = |
325 | rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex()); |
326 | Value expReal = rewriter.create<math::ExpOp>(loc, real, fmf.getValue()); |
327 | Value cosImag = rewriter.create<math::CosOp>(loc, imag, fmf.getValue()); |
328 | Value resultReal = |
329 | rewriter.create<arith::MulFOp>(loc, expReal, cosImag, fmf.getValue()); |
330 | Value sinImag = rewriter.create<math::SinOp>(loc, imag, fmf.getValue()); |
331 | Value resultImag = |
332 | rewriter.create<arith::MulFOp>(loc, expReal, sinImag, fmf.getValue()); |
333 | |
334 | rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, |
335 | resultImag); |
336 | return success(); |
337 | } |
338 | }; |
339 | |
340 | Value evaluatePolynomial(ImplicitLocOpBuilder &b, Value arg, |
341 | ArrayRef<double> coefficients, |
342 | arith::FastMathFlagsAttr fmf) { |
343 | auto argType = mlir::cast<FloatType>(arg.getType()); |
344 | Value poly = |
345 | b.create<arith::ConstantOp>(b.getFloatAttr(argType, coefficients[0])); |
346 | for (unsigned i = 1; i < coefficients.size(); ++i) { |
347 | poly = b.create<math::FmaOp>( |
348 | poly, arg, |
349 | b.create<arith::ConstantOp>(b.getFloatAttr(argType, coefficients[i])), |
350 | fmf); |
351 | } |
352 | return poly; |
353 | } |
354 | |
355 | struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> { |
356 | using OpConversionPattern<complex::Expm1Op>::OpConversionPattern; |
357 | |
358 | // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i |
359 | // [handle inaccuracies when a and/or b are small] |
360 | // = ((e^a - 1) * cos(b) + cos(b) - 1) + e^a*sin(b)i |
361 | // = (expm1(a) * cos(b) + cosm1(b)) + e^a*sin(b)i |
362 | LogicalResult |
363 | matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor, |
364 | ConversionPatternRewriter &rewriter) const override { |
365 | auto type = op.getType(); |
366 | auto elemType = mlir::cast<FloatType>(type.getElementType()); |
367 | |
368 | arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); |
369 | ImplicitLocOpBuilder b(op.getLoc(), rewriter); |
370 | Value real = b.create<complex::ReOp>(adaptor.getComplex()); |
371 | Value imag = b.create<complex::ImOp>(adaptor.getComplex()); |
372 | |
373 | Value zero = b.create<arith::ConstantOp>(b.getFloatAttr(elemType, 0.0)); |
374 | Value one = b.create<arith::ConstantOp>(b.getFloatAttr(elemType, 1.0)); |
375 | |
376 | Value expm1Real = b.create<math::ExpM1Op>(real, fmf); |
377 | Value expReal = b.create<arith::AddFOp>(expm1Real, one, fmf); |
378 | |
379 | Value sinImag = b.create<math::SinOp>(imag, fmf); |
380 | Value cosm1Imag = emitCosm1(imag, fmf, b); |
381 | Value cosImag = b.create<arith::AddFOp>(cosm1Imag, one, fmf); |
382 | |
383 | Value realResult = b.create<arith::AddFOp>( |
384 | b.create<arith::MulFOp>(expm1Real, cosImag, fmf), cosm1Imag, fmf); |
385 | |
386 | Value imagIsZero = b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, |
387 | zero, fmf.getValue()); |
388 | Value imagResult = b.create<arith::SelectOp>( |
389 | imagIsZero, zero, b.create<arith::MulFOp>(expReal, sinImag, fmf)); |
390 | |
391 | rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realResult, |
392 | imagResult); |
393 | return success(); |
394 | } |
395 | |
396 | private: |
397 | Value emitCosm1(Value arg, arith::FastMathFlagsAttr fmf, |
398 | ImplicitLocOpBuilder &b) const { |
399 | auto argType = mlir::cast<FloatType>(arg.getType()); |
400 | auto negHalf = b.create<arith::ConstantOp>(b.getFloatAttr(argType, -0.5)); |
401 | auto negOne = b.create<arith::ConstantOp>(b.getFloatAttr(argType, -1.0)); |
402 | |
403 | // Algorithm copied from cephes cosm1. |
404 | SmallVector<double, 7> kCoeffs{ |
405 | 4.7377507964246204691685E-14, -1.1470284843425359765671E-11, |
406 | 2.0876754287081521758361E-9, -2.7557319214999787979814E-7, |
407 | 2.4801587301570552304991E-5, -1.3888888888888872993737E-3, |
408 | 4.1666666666666666609054E-2, |
409 | }; |
410 | Value cos = b.create<math::CosOp>(arg, fmf); |
411 | Value forLargeArg = b.create<arith::AddFOp>(cos, negOne, fmf); |
412 | |
413 | Value argPow2 = b.create<arith::MulFOp>(arg, arg, fmf); |
414 | Value argPow4 = b.create<arith::MulFOp>(argPow2, argPow2, fmf); |
415 | Value poly = evaluatePolynomial(b, argPow2, kCoeffs, fmf); |
416 | |
417 | auto forSmallArg = |
418 | b.create<arith::AddFOp>(b.create<arith::MulFOp>(argPow4, poly, fmf), |
419 | b.create<arith::MulFOp>(negHalf, argPow2, fmf)); |
420 | |
421 | // (pi/4)^2 is approximately 0.61685 |
422 | Value piOver4Pow2 = |
423 | b.create<arith::ConstantOp>(b.getFloatAttr(argType, 0.61685)); |
424 | Value cond = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, argPow2, |
425 | piOver4Pow2, fmf.getValue()); |
426 | return b.create<arith::SelectOp>(cond, forLargeArg, forSmallArg); |
427 | } |
428 | }; |
429 | |
430 | struct LogOpConversion : public OpConversionPattern<complex::LogOp> { |
431 | using OpConversionPattern<complex::LogOp>::OpConversionPattern; |
432 | |
433 | LogicalResult |
434 | matchAndRewrite(complex::LogOp op, OpAdaptor adaptor, |
435 | ConversionPatternRewriter &rewriter) const override { |
436 | auto type = cast<ComplexType>(adaptor.getComplex().getType()); |
437 | auto elementType = cast<FloatType>(type.getElementType()); |
438 | arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); |
439 | mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); |
440 | |
441 | Value abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex(), |
442 | fmf.getValue()); |
443 | Value resultReal = b.create<math::LogOp>(elementType, abs, fmf.getValue()); |
444 | Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex()); |
445 | Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex()); |
446 | Value resultImag = |
447 | b.create<math::Atan2Op>(elementType, imag, real, fmf.getValue()); |
448 | rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, |
449 | resultImag); |
450 | return success(); |
451 | } |
452 | }; |
453 | |
454 | struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> { |
455 | using OpConversionPattern<complex::Log1pOp>::OpConversionPattern; |
456 | |
457 | LogicalResult |
458 | matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor, |
459 | ConversionPatternRewriter &rewriter) const override { |
460 | auto type = cast<ComplexType>(adaptor.getComplex().getType()); |
461 | auto elementType = cast<FloatType>(type.getElementType()); |
462 | arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue(); |
463 | mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); |
464 | |
465 | Value real = b.create<complex::ReOp>(adaptor.getComplex()); |
466 | Value imag = b.create<complex::ImOp>(adaptor.getComplex()); |
467 | |
468 | Value half = b.create<arith::ConstantOp>(elementType, |
469 | b.getFloatAttr(elementType, 0.5)); |
470 | Value one = b.create<arith::ConstantOp>(elementType, |
471 | b.getFloatAttr(elementType, 1)); |
472 | Value realPlusOne = b.create<arith::AddFOp>(real, one, fmf); |
473 | Value absRealPlusOne = b.create<math::AbsFOp>(realPlusOne, fmf); |
474 | Value absImag = b.create<math::AbsFOp>(imag, fmf); |
475 | |
476 | Value maxAbs = b.create<arith::MaximumFOp>(absRealPlusOne, absImag, fmf); |
477 | Value minAbs = b.create<arith::MinimumFOp>(absRealPlusOne, absImag, fmf); |
478 | |
479 | Value useReal = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, |
480 | realPlusOne, absImag, fmf); |
481 | Value maxMinusOne = b.create<arith::SubFOp>(maxAbs, one, fmf); |
482 | Value maxAbsOfRealPlusOneAndImagMinusOne = |
483 | b.create<arith::SelectOp>(useReal, real, maxMinusOne); |
484 | arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear( |
485 | fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf); |
486 | Value minMaxRatio = b.create<arith::DivFOp>(minAbs, maxAbs, fmfWithNaNInf); |
487 | Value logOfMaxAbsOfRealPlusOneAndImag = |
488 | b.create<math::Log1pOp>(maxAbsOfRealPlusOneAndImagMinusOne, fmf); |
489 | Value logOfSqrtPart = b.create<math::Log1pOp>( |
490 | b.create<arith::MulFOp>(minMaxRatio, minMaxRatio, fmfWithNaNInf), |
491 | fmfWithNaNInf); |
492 | Value r = b.create<arith::AddFOp>( |
493 | b.create<arith::MulFOp>(half, logOfSqrtPart, fmfWithNaNInf), |
494 | logOfMaxAbsOfRealPlusOneAndImag, fmfWithNaNInf); |
495 | Value resultReal = b.create<arith::SelectOp>( |
496 | b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, r, r, fmfWithNaNInf), |
497 | minAbs, r); |
498 | Value resultImag = b.create<math::Atan2Op>(imag, realPlusOne, fmf); |
499 | rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, |
500 | resultImag); |
501 | return success(); |
502 | } |
503 | }; |
504 | |
505 | struct MulOpConversion : public OpConversionPattern<complex::MulOp> { |
506 | using OpConversionPattern<complex::MulOp>::OpConversionPattern; |
507 | |
508 | LogicalResult |
509 | matchAndRewrite(complex::MulOp op, OpAdaptor adaptor, |
510 | ConversionPatternRewriter &rewriter) const override { |
511 | mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); |
512 | auto type = cast<ComplexType>(adaptor.getLhs().getType()); |
513 | auto elementType = cast<FloatType>(type.getElementType()); |
514 | arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); |
515 | auto fmfValue = fmf.getValue(); |
516 | Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.getLhs()); |
517 | Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.getLhs()); |
518 | Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.getRhs()); |
519 | Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.getRhs()); |
520 | Value lhsRealTimesRhsReal = |
521 | b.create<arith::MulFOp>(lhsReal, rhsReal, fmfValue); |
522 | Value lhsImagTimesRhsImag = |
523 | b.create<arith::MulFOp>(lhsImag, rhsImag, fmfValue); |
524 | Value real = b.create<arith::SubFOp>(lhsRealTimesRhsReal, |
525 | lhsImagTimesRhsImag, fmfValue); |
526 | Value lhsImagTimesRhsReal = |
527 | b.create<arith::MulFOp>(lhsImag, rhsReal, fmfValue); |
528 | Value lhsRealTimesRhsImag = |
529 | b.create<arith::MulFOp>(lhsReal, rhsImag, fmfValue); |
530 | Value imag = b.create<arith::AddFOp>(lhsImagTimesRhsReal, |
531 | lhsRealTimesRhsImag, fmfValue); |
532 | rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag); |
533 | return success(); |
534 | } |
535 | }; |
536 | |
537 | struct NegOpConversion : public OpConversionPattern<complex::NegOp> { |
538 | using OpConversionPattern<complex::NegOp>::OpConversionPattern; |
539 | |
540 | LogicalResult |
541 | matchAndRewrite(complex::NegOp op, OpAdaptor adaptor, |
542 | ConversionPatternRewriter &rewriter) const override { |
543 | auto loc = op.getLoc(); |
544 | auto type = cast<ComplexType>(adaptor.getComplex().getType()); |
545 | auto elementType = cast<FloatType>(type.getElementType()); |
546 | |
547 | Value real = |
548 | rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex()); |
549 | Value imag = |
550 | rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex()); |
551 | Value negReal = rewriter.create<arith::NegFOp>(loc, real); |
552 | Value negImag = rewriter.create<arith::NegFOp>(loc, imag); |
553 | rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag); |
554 | return success(); |
555 | } |
556 | }; |
557 | |
558 | struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> { |
559 | using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion; |
560 | |
561 | std::pair<Value, Value> combine(Location loc, Value scaledExp, |
562 | Value reciprocalExp, Value sin, Value cos, |
563 | ConversionPatternRewriter &rewriter, |
564 | arith::FastMathFlagsAttr fmf) const override { |
565 | // Complex sine is defined as; |
566 | // sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy))) |
567 | // Plugging in: |
568 | // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x)) |
569 | // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x))) |
570 | // and defining t := exp(y) |
571 | // We get: |
572 | // Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x |
573 | // Im(cos(x + iy)) = (0.5*t - 0.5/t) * cos x |
574 | Value sum = |
575 | rewriter.create<arith::AddFOp>(loc, scaledExp, reciprocalExp, fmf); |
576 | Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, sin, fmf); |
577 | Value diff = |
578 | rewriter.create<arith::SubFOp>(loc, scaledExp, reciprocalExp, fmf); |
579 | Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, cos, fmf); |
580 | return {resultReal, resultImag}; |
581 | } |
582 | }; |
583 | |
584 | // The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780. |
585 | struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> { |
586 | using OpConversionPattern<complex::SqrtOp>::OpConversionPattern; |
587 | |
588 | LogicalResult |
589 | matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor, |
590 | ConversionPatternRewriter &rewriter) const override { |
591 | ImplicitLocOpBuilder b(op.getLoc(), rewriter); |
592 | |
593 | auto type = cast<ComplexType>(op.getType()); |
594 | auto elementType = cast<FloatType>(type.getElementType()); |
595 | arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue(); |
596 | |
597 | auto cst = [&](APFloat v) { |
598 | return b.create<arith::ConstantOp>(elementType, |
599 | b.getFloatAttr(elementType, v)); |
600 | }; |
601 | const auto &floatSemantics = elementType.getFloatSemantics(); |
602 | Value zero = cst(APFloat::getZero(Sem: floatSemantics)); |
603 | Value half = b.create<arith::ConstantOp>(elementType, |
604 | b.getFloatAttr(elementType, 0.5)); |
605 | |
606 | Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex()); |
607 | Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex()); |
608 | Value absSqrt = computeAbs(real, imag, fmf, b, AbsFn::sqrt); |
609 | Value argArg = b.create<math::Atan2Op>(imag, real, fmf); |
610 | Value sqrtArg = b.create<arith::MulFOp>(argArg, half, fmf); |
611 | Value cos = b.create<math::CosOp>(sqrtArg, fmf); |
612 | Value sin = b.create<math::SinOp>(sqrtArg, fmf); |
613 | // sin(atan2(0, inf)) = 0, sqrt(abs(inf)) = inf, but we can't multiply |
614 | // 0 * inf. |
615 | Value sinIsZero = |
616 | b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, sin, zero, fmf); |
617 | |
618 | Value resultReal = b.create<arith::MulFOp>(absSqrt, cos, fmf); |
619 | Value resultImag = b.create<arith::SelectOp>( |
620 | sinIsZero, zero, b.create<arith::MulFOp>(absSqrt, sin, fmf)); |
621 | if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan | |
622 | arith::FastMathFlags::ninf)) { |
623 | Value inf = cst(APFloat::getInf(Sem: floatSemantics)); |
624 | Value negInf = cst(APFloat::getInf(Sem: floatSemantics, Negative: true)); |
625 | Value nan = cst(APFloat::getNaN(Sem: floatSemantics)); |
626 | Value absImag = b.create<math::AbsFOp>(elementType, imag, fmf); |
627 | |
628 | Value absImagIsInf = |
629 | b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf); |
630 | Value absImagIsNotInf = |
631 | b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, absImag, inf, fmf); |
632 | Value realIsInf = |
633 | b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, inf, fmf); |
634 | Value realIsNegInf = |
635 | b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, negInf, fmf); |
636 | |
637 | resultReal = b.create<arith::SelectOp>( |
638 | b.create<arith::AndIOp>(realIsNegInf, absImagIsNotInf), zero, |
639 | resultReal); |
640 | resultReal = b.create<arith::SelectOp>( |
641 | b.create<arith::OrIOp>(absImagIsInf, realIsInf), inf, resultReal); |
642 | |
643 | Value imagSignInf = b.create<math::CopySignOp>(inf, imag, fmf); |
644 | resultImag = b.create<arith::SelectOp>( |
645 | b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, absSqrt, absSqrt), |
646 | nan, resultImag); |
647 | resultImag = b.create<arith::SelectOp>( |
648 | b.create<arith::OrIOp>(absImagIsInf, realIsNegInf), imagSignInf, |
649 | resultImag); |
650 | } |
651 | |
652 | Value resultIsZero = |
653 | b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absSqrt, zero, fmf); |
654 | resultReal = b.create<arith::SelectOp>(resultIsZero, zero, resultReal); |
655 | resultImag = b.create<arith::SelectOp>(resultIsZero, zero, resultImag); |
656 | |
657 | rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, |
658 | resultImag); |
659 | return success(); |
660 | } |
661 | }; |
662 | |
663 | struct SignOpConversion : public OpConversionPattern<complex::SignOp> { |
664 | using OpConversionPattern<complex::SignOp>::OpConversionPattern; |
665 | |
666 | LogicalResult |
667 | matchAndRewrite(complex::SignOp op, OpAdaptor adaptor, |
668 | ConversionPatternRewriter &rewriter) const override { |
669 | auto type = cast<ComplexType>(adaptor.getComplex().getType()); |
670 | auto elementType = cast<FloatType>(type.getElementType()); |
671 | mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); |
672 | arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); |
673 | |
674 | Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex()); |
675 | Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex()); |
676 | Value zero = |
677 | b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType)); |
678 | Value realIsZero = |
679 | b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero); |
680 | Value imagIsZero = |
681 | b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero); |
682 | Value isZero = b.create<arith::AndIOp>(realIsZero, imagIsZero); |
683 | auto abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex(), fmf); |
684 | Value realSign = b.create<arith::DivFOp>(real, abs, fmf); |
685 | Value imagSign = b.create<arith::DivFOp>(imag, abs, fmf); |
686 | Value sign = b.create<complex::CreateOp>(type, realSign, imagSign); |
687 | rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isZero, |
688 | adaptor.getComplex(), sign); |
689 | return success(); |
690 | } |
691 | }; |
692 | |
693 | template <typename Op> |
694 | struct TanTanhOpConversion : public OpConversionPattern<Op> { |
695 | using OpConversionPattern<Op>::OpConversionPattern; |
696 | |
697 | LogicalResult |
698 | matchAndRewrite(Op op, typename Op::Adaptor adaptor, |
699 | ConversionPatternRewriter &rewriter) const override { |
700 | ImplicitLocOpBuilder b(op.getLoc(), rewriter); |
701 | auto loc = op.getLoc(); |
702 | auto type = cast<ComplexType>(adaptor.getComplex().getType()); |
703 | auto elementType = cast<FloatType>(type.getElementType()); |
704 | arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue(); |
705 | const auto &floatSemantics = elementType.getFloatSemantics(); |
706 | |
707 | Value real = |
708 | b.create<complex::ReOp>(loc, elementType, adaptor.getComplex()); |
709 | Value imag = |
710 | b.create<complex::ImOp>(loc, elementType, adaptor.getComplex()); |
711 | Value negOne = b.create<arith::ConstantOp>( |
712 | elementType, b.getFloatAttr(elementType, -1.0)); |
713 | |
714 | if constexpr (std::is_same_v<Op, complex::TanOp>) { |
715 | // tan(x+yi) = -i*tanh(-y + xi) |
716 | std::swap(real, imag); |
717 | real = b.create<arith::MulFOp>(real, negOne, fmf); |
718 | } |
719 | |
720 | auto cst = [&](APFloat v) { |
721 | return b.create<arith::ConstantOp>(elementType, |
722 | b.getFloatAttr(elementType, v)); |
723 | }; |
724 | Value inf = cst(APFloat::getInf(Sem: floatSemantics)); |
725 | Value four = b.create<arith::ConstantOp>(elementType, |
726 | b.getFloatAttr(elementType, 4.0)); |
727 | Value twoReal = b.create<arith::AddFOp>(real, real, fmf); |
728 | Value negTwoReal = b.create<arith::MulFOp>(negOne, twoReal, fmf); |
729 | |
730 | Value expTwoRealMinusOne = b.create<math::ExpM1Op>(twoReal, fmf); |
731 | Value expNegTwoRealMinusOne = b.create<math::ExpM1Op>(negTwoReal, fmf); |
732 | Value realNum = |
733 | b.create<arith::SubFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf); |
734 | |
735 | Value cosImag = b.create<math::CosOp>(imag, fmf); |
736 | Value cosImagSq = b.create<arith::MulFOp>(cosImag, cosImag, fmf); |
737 | Value twoCosTwoImagPlusOne = b.create<arith::MulFOp>(cosImagSq, four, fmf); |
738 | Value sinImag = b.create<math::SinOp>(imag, fmf); |
739 | |
740 | Value imagNum = b.create<arith::MulFOp>( |
741 | four, b.create<arith::MulFOp>(cosImag, sinImag, fmf), fmf); |
742 | |
743 | Value expSumMinusTwo = |
744 | b.create<arith::AddFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf); |
745 | Value denom = |
746 | b.create<arith::AddFOp>(expSumMinusTwo, twoCosTwoImagPlusOne, fmf); |
747 | |
748 | Value isInf = b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, |
749 | expSumMinusTwo, inf, fmf); |
750 | Value realLimit = b.create<math::CopySignOp>(negOne, real, fmf); |
751 | |
752 | Value resultReal = b.create<arith::SelectOp>( |
753 | isInf, realLimit, b.create<arith::DivFOp>(realNum, denom, fmf)); |
754 | Value resultImag = b.create<arith::DivFOp>(imagNum, denom, fmf); |
755 | |
756 | if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan | |
757 | arith::FastMathFlags::ninf)) { |
758 | Value absReal = b.create<math::AbsFOp>(real, fmf); |
759 | Value zero = b.create<arith::ConstantOp>( |
760 | elementType, b.getFloatAttr(elementType, 0.0)); |
761 | Value nan = cst(APFloat::getNaN(Sem: floatSemantics)); |
762 | |
763 | Value absRealIsInf = |
764 | b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf); |
765 | Value imagIsZero = |
766 | b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf); |
767 | Value absRealIsNotInf = b.create<arith::XOrIOp>( |
768 | absRealIsInf, b.create<arith::ConstantIntOp>(true, /*width=*/1)); |
769 | |
770 | Value imagNumIsNaN = b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, |
771 | imagNum, imagNum, fmf); |
772 | Value resultRealIsNaN = |
773 | b.create<arith::AndIOp>(imagNumIsNaN, absRealIsNotInf); |
774 | Value resultImagIsZero = b.create<arith::OrIOp>( |
775 | imagIsZero, b.create<arith::AndIOp>(absRealIsInf, imagNumIsNaN)); |
776 | |
777 | resultReal = b.create<arith::SelectOp>(resultRealIsNaN, nan, resultReal); |
778 | resultImag = |
779 | b.create<arith::SelectOp>(resultImagIsZero, zero, resultImag); |
780 | } |
781 | |
782 | if constexpr (std::is_same_v<Op, complex::TanOp>) { |
783 | // tan(x+yi) = -i*tanh(-y + xi) |
784 | std::swap(resultReal, resultImag); |
785 | resultImag = b.create<arith::MulFOp>(resultImag, negOne, fmf); |
786 | } |
787 | |
788 | rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, |
789 | resultImag); |
790 | return success(); |
791 | } |
792 | }; |
793 | |
794 | struct ConjOpConversion : public OpConversionPattern<complex::ConjOp> { |
795 | using OpConversionPattern<complex::ConjOp>::OpConversionPattern; |
796 | |
797 | LogicalResult |
798 | matchAndRewrite(complex::ConjOp op, OpAdaptor adaptor, |
799 | ConversionPatternRewriter &rewriter) const override { |
800 | auto loc = op.getLoc(); |
801 | auto type = cast<ComplexType>(adaptor.getComplex().getType()); |
802 | auto elementType = cast<FloatType>(type.getElementType()); |
803 | Value real = |
804 | rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex()); |
805 | Value imag = |
806 | rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex()); |
807 | Value negImag = rewriter.create<arith::NegFOp>(loc, elementType, imag); |
808 | |
809 | rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, negImag); |
810 | |
811 | return success(); |
812 | } |
813 | }; |
814 | |
815 | /// Converts lhs^y = (a+bi)^(c+di) to |
816 | /// (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)), |
817 | /// where q = c*atan2(b,a)+0.5d*ln(a*a+b*b) |
818 | static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder, |
819 | ComplexType type, Value lhs, Value c, Value d, |
820 | arith::FastMathFlags fmf) { |
821 | auto elementType = cast<FloatType>(type.getElementType()); |
822 | |
823 | Value a = builder.create<complex::ReOp>(lhs); |
824 | Value b = builder.create<complex::ImOp>(lhs); |
825 | |
826 | Value abs = builder.create<complex::AbsOp>(lhs, fmf); |
827 | Value absToC = builder.create<math::PowFOp>(abs, c, fmf); |
828 | |
829 | Value negD = builder.create<arith::NegFOp>(d, fmf); |
830 | Value argLhs = builder.create<math::Atan2Op>(b, a, fmf); |
831 | Value negDArgLhs = builder.create<arith::MulFOp>(negD, argLhs, fmf); |
832 | Value expNegDArgLhs = builder.create<math::ExpOp>(negDArgLhs, fmf); |
833 | |
834 | Value coeff = builder.create<arith::MulFOp>(absToC, expNegDArgLhs, fmf); |
835 | Value lnAbs = builder.create<math::LogOp>(abs, fmf); |
836 | Value cArgLhs = builder.create<arith::MulFOp>(c, argLhs, fmf); |
837 | Value dLnAbs = builder.create<arith::MulFOp>(d, lnAbs, fmf); |
838 | Value q = builder.create<arith::AddFOp>(cArgLhs, dLnAbs, fmf); |
839 | Value cosQ = builder.create<math::CosOp>(q, fmf); |
840 | Value sinQ = builder.create<math::SinOp>(q, fmf); |
841 | |
842 | Value inf = builder.create<arith::ConstantOp>( |
843 | elementType, |
844 | builder.getFloatAttr(elementType, |
845 | APFloat::getInf(elementType.getFloatSemantics()))); |
846 | Value zero = builder.create<arith::ConstantOp>( |
847 | elementType, builder.getFloatAttr(elementType, 0.0)); |
848 | Value one = builder.create<arith::ConstantOp>( |
849 | elementType, builder.getFloatAttr(elementType, 1.0)); |
850 | Value complexOne = builder.create<complex::CreateOp>(type, one, zero); |
851 | Value complexZero = builder.create<complex::CreateOp>(type, zero, zero); |
852 | Value complexInf = builder.create<complex::CreateOp>(type, inf, zero); |
853 | |
854 | // Case 0: |
855 | // d^c is 0 if d is 0 and c > 0. 0^0 is defined to be 1.0, see |
856 | // Branch Cuts for Complex Elementary Functions or Much Ado About |
857 | // Nothing's Sign Bit, W. Kahan, Section 10. |
858 | Value absEqZero = |
859 | builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, abs, zero, fmf); |
860 | Value dEqZero = |
861 | builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, d, zero, fmf); |
862 | Value cEqZero = |
863 | builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, c, zero, fmf); |
864 | Value bEqZero = |
865 | builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, b, zero, fmf); |
866 | |
867 | Value zeroLeC = |
868 | builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLE, zero, c, fmf); |
869 | Value coeffCosQ = builder.create<arith::MulFOp>(coeff, cosQ, fmf); |
870 | Value coeffSinQ = builder.create<arith::MulFOp>(coeff, sinQ, fmf); |
871 | Value complexOneOrZero = |
872 | builder.create<arith::SelectOp>(cEqZero, complexOne, complexZero); |
873 | Value coeffCosSin = |
874 | builder.create<complex::CreateOp>(type, coeffCosQ, coeffSinQ); |
875 | Value cutoff0 = builder.create<arith::SelectOp>( |
876 | builder.create<arith::AndIOp>( |
877 | builder.create<arith::AndIOp>(absEqZero, dEqZero), zeroLeC), |
878 | complexOneOrZero, coeffCosSin); |
879 | |
880 | // Case 1: |
881 | // x^0 is defined to be 1 for any x, see |
882 | // Branch Cuts for Complex Elementary Functions or Much Ado About |
883 | // Nothing's Sign Bit, W. Kahan, Section 10. |
884 | Value rhsEqZero = builder.create<arith::AndIOp>(cEqZero, dEqZero); |
885 | Value cutoff1 = |
886 | builder.create<arith::SelectOp>(rhsEqZero, complexOne, cutoff0); |
887 | |
888 | // Case 2: |
889 | // 1^(c + d*i) = 1 + 0*i |
890 | Value lhsEqOne = builder.create<arith::AndIOp>( |
891 | builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, one, fmf), |
892 | bEqZero); |
893 | Value cutoff2 = |
894 | builder.create<arith::SelectOp>(lhsEqOne, complexOne, cutoff1); |
895 | |
896 | // Case 3: |
897 | // inf^(c + 0*i) = inf + 0*i, c > 0 |
898 | Value lhsEqInf = builder.create<arith::AndIOp>( |
899 | builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, inf, fmf), |
900 | bEqZero); |
901 | Value rhsGt0 = builder.create<arith::AndIOp>( |
902 | dEqZero, |
903 | builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, c, zero, fmf)); |
904 | Value cutoff3 = builder.create<arith::SelectOp>( |
905 | builder.create<arith::AndIOp>(lhsEqInf, rhsGt0), complexInf, cutoff2); |
906 | |
907 | // Case 4: |
908 | // inf^(c + 0*i) = 0 + 0*i, c < 0 |
909 | Value rhsLt0 = builder.create<arith::AndIOp>( |
910 | dEqZero, |
911 | builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, c, zero, fmf)); |
912 | Value cutoff4 = builder.create<arith::SelectOp>( |
913 | builder.create<arith::AndIOp>(lhsEqInf, rhsLt0), complexZero, cutoff3); |
914 | |
915 | return cutoff4; |
916 | } |
917 | |
918 | struct PowOpConversion : public OpConversionPattern<complex::PowOp> { |
919 | using OpConversionPattern<complex::PowOp>::OpConversionPattern; |
920 | |
921 | LogicalResult |
922 | matchAndRewrite(complex::PowOp op, OpAdaptor adaptor, |
923 | ConversionPatternRewriter &rewriter) const override { |
924 | mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter); |
925 | auto type = cast<ComplexType>(adaptor.getLhs().getType()); |
926 | auto elementType = cast<FloatType>(type.getElementType()); |
927 | |
928 | Value c = builder.create<complex::ReOp>(elementType, adaptor.getRhs()); |
929 | Value d = builder.create<complex::ImOp>(elementType, adaptor.getRhs()); |
930 | |
931 | rewriter.replaceOp(op, {powOpConversionImpl(builder, type, adaptor.getLhs(), |
932 | c, d, op.getFastmath())}); |
933 | return success(); |
934 | } |
935 | }; |
936 | |
937 | struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> { |
938 | using OpConversionPattern<complex::RsqrtOp>::OpConversionPattern; |
939 | |
940 | LogicalResult |
941 | matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor, |
942 | ConversionPatternRewriter &rewriter) const override { |
943 | mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); |
944 | auto type = cast<ComplexType>(adaptor.getComplex().getType()); |
945 | auto elementType = cast<FloatType>(type.getElementType()); |
946 | |
947 | arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue(); |
948 | |
949 | auto cst = [&](APFloat v) { |
950 | return b.create<arith::ConstantOp>(elementType, |
951 | b.getFloatAttr(elementType, v)); |
952 | }; |
953 | const auto &floatSemantics = elementType.getFloatSemantics(); |
954 | Value zero = cst(APFloat::getZero(Sem: floatSemantics)); |
955 | Value inf = cst(APFloat::getInf(Sem: floatSemantics)); |
956 | Value negHalf = b.create<arith::ConstantOp>( |
957 | elementType, b.getFloatAttr(elementType, -0.5)); |
958 | Value nan = cst(APFloat::getNaN(Sem: floatSemantics)); |
959 | |
960 | Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex()); |
961 | Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex()); |
962 | Value absRsqrt = computeAbs(real, imag, fmf, b, AbsFn::rsqrt); |
963 | Value argArg = b.create<math::Atan2Op>(imag, real, fmf); |
964 | Value rsqrtArg = b.create<arith::MulFOp>(argArg, negHalf, fmf); |
965 | Value cos = b.create<math::CosOp>(rsqrtArg, fmf); |
966 | Value sin = b.create<math::SinOp>(rsqrtArg, fmf); |
967 | |
968 | Value resultReal = b.create<arith::MulFOp>(absRsqrt, cos, fmf); |
969 | Value resultImag = b.create<arith::MulFOp>(absRsqrt, sin, fmf); |
970 | |
971 | if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan | |
972 | arith::FastMathFlags::ninf)) { |
973 | Value negOne = b.create<arith::ConstantOp>( |
974 | elementType, b.getFloatAttr(elementType, -1)); |
975 | |
976 | Value realSignedZero = b.create<math::CopySignOp>(zero, real, fmf); |
977 | Value imagSignedZero = b.create<math::CopySignOp>(zero, imag, fmf); |
978 | Value negImagSignedZero = |
979 | b.create<arith::MulFOp>(negOne, imagSignedZero, fmf); |
980 | |
981 | Value absReal = b.create<math::AbsFOp>(real, fmf); |
982 | Value absImag = b.create<math::AbsFOp>(imag, fmf); |
983 | |
984 | Value absImagIsInf = |
985 | b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf); |
986 | Value realIsNan = |
987 | b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real, fmf); |
988 | Value realIsInf = |
989 | b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf); |
990 | Value inIsNanInf = b.create<arith::AndIOp>(absImagIsInf, realIsNan); |
991 | |
992 | Value resultIsZero = b.create<arith::OrIOp>(inIsNanInf, realIsInf); |
993 | |
994 | resultReal = |
995 | b.create<arith::SelectOp>(resultIsZero, realSignedZero, resultReal); |
996 | resultImag = b.create<arith::SelectOp>(resultIsZero, negImagSignedZero, |
997 | resultImag); |
998 | } |
999 | |
1000 | Value isRealZero = |
1001 | b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero, fmf); |
1002 | Value isImagZero = |
1003 | b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf); |
1004 | Value isZero = b.create<arith::AndIOp>(isRealZero, isImagZero); |
1005 | |
1006 | resultReal = b.create<arith::SelectOp>(isZero, inf, resultReal); |
1007 | resultImag = b.create<arith::SelectOp>(isZero, nan, resultImag); |
1008 | |
1009 | rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, |
1010 | resultImag); |
1011 | return success(); |
1012 | } |
1013 | }; |
1014 | |
1015 | struct AngleOpConversion : public OpConversionPattern<complex::AngleOp> { |
1016 | using OpConversionPattern<complex::AngleOp>::OpConversionPattern; |
1017 | |
1018 | LogicalResult |
1019 | matchAndRewrite(complex::AngleOp op, OpAdaptor adaptor, |
1020 | ConversionPatternRewriter &rewriter) const override { |
1021 | auto loc = op.getLoc(); |
1022 | auto type = op.getType(); |
1023 | arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); |
1024 | |
1025 | Value real = |
1026 | rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex()); |
1027 | Value imag = |
1028 | rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex()); |
1029 | |
1030 | rewriter.replaceOpWithNewOp<math::Atan2Op>(op, imag, real, fmf); |
1031 | |
1032 | return success(); |
1033 | } |
1034 | }; |
1035 | |
1036 | } // namespace |
1037 | |
1038 | void mlir::populateComplexToStandardConversionPatterns( |
1039 | RewritePatternSet &patterns, complex::ComplexRangeFlags complexRange) { |
1040 | // clang-format off |
1041 | patterns.add< |
1042 | AbsOpConversion, |
1043 | AngleOpConversion, |
1044 | Atan2OpConversion, |
1045 | BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>, |
1046 | BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>, |
1047 | ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>, |
1048 | ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>, |
1049 | ConjOpConversion, |
1050 | CosOpConversion, |
1051 | ExpOpConversion, |
1052 | Expm1OpConversion, |
1053 | Log1pOpConversion, |
1054 | LogOpConversion, |
1055 | MulOpConversion, |
1056 | NegOpConversion, |
1057 | SignOpConversion, |
1058 | SinOpConversion, |
1059 | SqrtOpConversion, |
1060 | TanTanhOpConversion<complex::TanOp>, |
1061 | TanTanhOpConversion<complex::TanhOp>, |
1062 | PowOpConversion, |
1063 | RsqrtOpConversion |
1064 | >(patterns.getContext()); |
1065 | |
1066 | patterns.add<DivOpConversion>(patterns.getContext(), complexRange); |
1067 | |
1068 | // clang-format on |
1069 | } |
1070 | |
1071 | namespace { |
1072 | struct ConvertComplexToStandardPass |
1073 | : public impl::ConvertComplexToStandardPassBase< |
1074 | ConvertComplexToStandardPass> { |
1075 | using Base::Base; |
1076 | |
1077 | void runOnOperation() override; |
1078 | }; |
1079 | |
1080 | void ConvertComplexToStandardPass::runOnOperation() { |
1081 | // Convert to the Standard dialect using the converter defined above. |
1082 | RewritePatternSet patterns(&getContext()); |
1083 | populateComplexToStandardConversionPatterns(patterns, complexRange); |
1084 | |
1085 | ConversionTarget target(getContext()); |
1086 | target.addLegalDialect<arith::ArithDialect, math::MathDialect>(); |
1087 | target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>(); |
1088 | if (failed( |
1089 | applyPartialConversion(getOperation(), target, std::move(patterns)))) |
1090 | signalPassFailure(); |
1091 | } |
1092 | } // namespace |
1093 |
Definitions
- AbsFn
- computeAbs
- AbsOpConversion
- matchAndRewrite
- Atan2OpConversion
- matchAndRewrite
- ComparisonOpConversion
- matchAndRewrite
- BinaryComplexOpConversion
- matchAndRewrite
- TrigonometricOpConversion
- matchAndRewrite
- CosOpConversion
- combine
- DivOpConversion
- DivOpConversion
- matchAndRewrite
- ExpOpConversion
- matchAndRewrite
- evaluatePolynomial
- Expm1OpConversion
- matchAndRewrite
- emitCosm1
- LogOpConversion
- matchAndRewrite
- Log1pOpConversion
- matchAndRewrite
- MulOpConversion
- matchAndRewrite
- NegOpConversion
- matchAndRewrite
- SinOpConversion
- combine
- SqrtOpConversion
- matchAndRewrite
- SignOpConversion
- matchAndRewrite
- TanTanhOpConversion
- matchAndRewrite
- ConjOpConversion
- matchAndRewrite
- powOpConversionImpl
- PowOpConversion
- matchAndRewrite
- RsqrtOpConversion
- matchAndRewrite
- AngleOpConversion
- matchAndRewrite
- populateComplexToStandardConversionPatterns
- ConvertComplexToStandardPass
Improve your Profiling and Debugging skills
Find out more