1 | //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" |
10 | |
11 | #include "mlir/Dialect/Arith/IR/Arith.h" |
12 | #include "mlir/Dialect/Complex/IR/Complex.h" |
13 | #include "mlir/Dialect/Math/IR/Math.h" |
14 | #include "mlir/IR/ImplicitLocOpBuilder.h" |
15 | #include "mlir/IR/PatternMatch.h" |
16 | #include "mlir/Pass/Pass.h" |
17 | #include "mlir/Transforms/DialectConversion.h" |
18 | #include <memory> |
19 | #include <type_traits> |
20 | |
21 | namespace mlir { |
22 | #define GEN_PASS_DEF_CONVERTCOMPLEXTOSTANDARD |
23 | #include "mlir/Conversion/Passes.h.inc" |
24 | } // namespace mlir |
25 | |
26 | using namespace mlir; |
27 | |
28 | namespace { |
29 | |
// Selects which quantity computeAbs derives from a complex number.
enum class AbsFn {
  abs,  // The absolute value.
  sqrt, // The square root of the absolute value.
  rsqrt // The reciprocal square root of the absolute value.
};
31 | |
32 | // Returns the absolute value, its square root or its reciprocal square root. |
33 | Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf, |
34 | ImplicitLocOpBuilder &b, AbsFn fn = AbsFn::abs) { |
35 | Value one = b.create<arith::ConstantOp>(real.getType(), |
36 | b.getFloatAttr(real.getType(), 1.0)); |
37 | |
38 | Value absReal = b.create<math::AbsFOp>(real, fmf); |
39 | Value absImag = b.create<math::AbsFOp>(imag, fmf); |
40 | |
41 | Value max = b.create<arith::MaximumFOp>(absReal, absImag, fmf); |
42 | Value min = b.create<arith::MinimumFOp>(absReal, absImag, fmf); |
43 | Value ratio = b.create<arith::DivFOp>(min, max, fmf); |
44 | Value ratioSq = b.create<arith::MulFOp>(ratio, ratio, fmf); |
45 | Value ratioSqPlusOne = b.create<arith::AddFOp>(ratioSq, one, fmf); |
46 | Value result; |
47 | |
48 | if (fn == AbsFn::rsqrt) { |
49 | ratioSqPlusOne = b.create<math::RsqrtOp>(ratioSqPlusOne, fmf); |
50 | min = b.create<math::RsqrtOp>(min, fmf); |
51 | max = b.create<math::RsqrtOp>(max, fmf); |
52 | } |
53 | |
54 | if (fn == AbsFn::sqrt) { |
55 | Value quarter = b.create<arith::ConstantOp>( |
56 | real.getType(), b.getFloatAttr(real.getType(), 0.25)); |
57 | // sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily. |
58 | Value sqrt = b.create<math::SqrtOp>(max, fmf); |
59 | Value p025 = b.create<math::PowFOp>(ratioSqPlusOne, quarter, fmf); |
60 | result = b.create<arith::MulFOp>(sqrt, p025, fmf); |
61 | } else { |
62 | Value sqrt = b.create<math::SqrtOp>(ratioSqPlusOne, fmf); |
63 | result = b.create<arith::MulFOp>(max, sqrt, fmf); |
64 | } |
65 | |
66 | Value isNaN = |
67 | b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, result, result, fmf); |
68 | return b.create<arith::SelectOp>(isNaN, min, result); |
69 | } |
70 | |
71 | struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> { |
72 | using OpConversionPattern<complex::AbsOp>::OpConversionPattern; |
73 | |
74 | LogicalResult |
75 | matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor, |
76 | ConversionPatternRewriter &rewriter) const override { |
77 | ImplicitLocOpBuilder b(op.getLoc(), rewriter); |
78 | |
79 | arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue(); |
80 | |
81 | Value real = b.create<complex::ReOp>(adaptor.getComplex()); |
82 | Value imag = b.create<complex::ImOp>(adaptor.getComplex()); |
83 | rewriter.replaceOp(op, computeAbs(real, imag, fmf, b)); |
84 | |
85 | return success(); |
86 | } |
87 | }; |
88 | |
89 | // atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2)) |
90 | struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> { |
91 | using OpConversionPattern<complex::Atan2Op>::OpConversionPattern; |
92 | |
93 | LogicalResult |
94 | matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor, |
95 | ConversionPatternRewriter &rewriter) const override { |
96 | mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); |
97 | |
98 | auto type = cast<ComplexType>(op.getType()); |
99 | Type elementType = type.getElementType(); |
100 | arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); |
101 | |
102 | Value lhs = adaptor.getLhs(); |
103 | Value rhs = adaptor.getRhs(); |
104 | |
105 | Value rhsSquared = b.create<complex::MulOp>(type, rhs, rhs, fmf); |
106 | Value lhsSquared = b.create<complex::MulOp>(type, lhs, lhs, fmf); |
107 | Value rhsSquaredPlusLhsSquared = |
108 | b.create<complex::AddOp>(type, rhsSquared, lhsSquared, fmf); |
109 | Value sqrtOfRhsSquaredPlusLhsSquared = |
110 | b.create<complex::SqrtOp>(type, rhsSquaredPlusLhsSquared, fmf); |
111 | |
112 | Value zero = |
113 | b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType)); |
114 | Value one = b.create<arith::ConstantOp>(elementType, |
115 | b.getFloatAttr(elementType, 1)); |
116 | Value i = b.create<complex::CreateOp>(type, zero, one); |
117 | Value iTimesLhs = b.create<complex::MulOp>(i, lhs, fmf); |
118 | Value rhsPlusILhs = b.create<complex::AddOp>(rhs, iTimesLhs, fmf); |
119 | |
120 | Value divResult = b.create<complex::DivOp>( |
121 | rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared, fmf); |
122 | Value logResult = b.create<complex::LogOp>(divResult, fmf); |
123 | |
124 | Value negativeOne = b.create<arith::ConstantOp>( |
125 | elementType, b.getFloatAttr(elementType, -1)); |
126 | Value negativeI = b.create<complex::CreateOp>(type, zero, negativeOne); |
127 | |
128 | rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult, fmf); |
129 | return success(); |
130 | } |
131 | }; |
132 | |
133 | template <typename ComparisonOp, arith::CmpFPredicate p> |
134 | struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> { |
135 | using OpConversionPattern<ComparisonOp>::OpConversionPattern; |
136 | using ResultCombiner = |
137 | std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value, |
138 | arith::AndIOp, arith::OrIOp>; |
139 | |
140 | LogicalResult |
141 | matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor, |
142 | ConversionPatternRewriter &rewriter) const override { |
143 | auto loc = op.getLoc(); |
144 | auto type = cast<ComplexType>(adaptor.getLhs().getType()).getElementType(); |
145 | |
146 | Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getLhs()); |
147 | Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getLhs()); |
148 | Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getRhs()); |
149 | Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getRhs()); |
150 | Value realComparison = |
151 | rewriter.create<arith::CmpFOp>(loc, p, realLhs, realRhs); |
152 | Value imagComparison = |
153 | rewriter.create<arith::CmpFOp>(loc, p, imagLhs, imagRhs); |
154 | |
155 | rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison, |
156 | imagComparison); |
157 | return success(); |
158 | } |
159 | }; |
160 | |
161 | // Default conversion which applies the BinaryStandardOp separately on the real |
162 | // and imaginary parts. Can for example be used for complex::AddOp and |
163 | // complex::SubOp. |
164 | template <typename BinaryComplexOp, typename BinaryStandardOp> |
165 | struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> { |
166 | using OpConversionPattern<BinaryComplexOp>::OpConversionPattern; |
167 | |
168 | LogicalResult |
169 | matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor, |
170 | ConversionPatternRewriter &rewriter) const override { |
171 | auto type = cast<ComplexType>(adaptor.getLhs().getType()); |
172 | auto elementType = cast<FloatType>(type.getElementType()); |
173 | mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); |
174 | arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); |
175 | |
176 | Value realLhs = b.create<complex::ReOp>(elementType, adaptor.getLhs()); |
177 | Value realRhs = b.create<complex::ReOp>(elementType, adaptor.getRhs()); |
178 | Value resultReal = b.create<BinaryStandardOp>(elementType, realLhs, realRhs, |
179 | fmf.getValue()); |
180 | Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.getLhs()); |
181 | Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.getRhs()); |
182 | Value resultImag = b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs, |
183 | fmf.getValue()); |
184 | rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, |
185 | resultImag); |
186 | return success(); |
187 | } |
188 | }; |
189 | |
190 | template <typename TrigonometricOp> |
191 | struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> { |
192 | using OpAdaptor = typename OpConversionPattern<TrigonometricOp>::OpAdaptor; |
193 | |
194 | using OpConversionPattern<TrigonometricOp>::OpConversionPattern; |
195 | |
196 | LogicalResult |
197 | matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor, |
198 | ConversionPatternRewriter &rewriter) const override { |
199 | auto loc = op.getLoc(); |
200 | auto type = cast<ComplexType>(adaptor.getComplex().getType()); |
201 | auto elementType = cast<FloatType>(type.getElementType()); |
202 | arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); |
203 | |
204 | Value real = |
205 | rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex()); |
206 | Value imag = |
207 | rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex()); |
208 | |
209 | // Trigonometric ops use a set of common building blocks to convert to real |
210 | // ops. Here we create these building blocks and call into an op-specific |
211 | // implementation in the subclass to combine them. |
212 | Value half = rewriter.create<arith::ConstantOp>( |
213 | loc, elementType, rewriter.getFloatAttr(elementType, 0.5)); |
214 | Value exp = rewriter.create<math::ExpOp>(loc, imag, fmf); |
215 | Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp, fmf); |
216 | Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp, fmf); |
217 | Value sin = rewriter.create<math::SinOp>(loc, real, fmf); |
218 | Value cos = rewriter.create<math::CosOp>(loc, real, fmf); |
219 | |
220 | auto resultPair = |
221 | combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter, fmf); |
222 | |
223 | rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultPair.first, |
224 | resultPair.second); |
225 | return success(); |
226 | } |
227 | |
228 | virtual std::pair<Value, Value> |
229 | combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin, |
230 | Value cos, ConversionPatternRewriter &rewriter, |
231 | arith::FastMathFlagsAttr fmf) const = 0; |
232 | }; |
233 | |
234 | struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> { |
235 | using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion; |
236 | |
237 | std::pair<Value, Value> combine(Location loc, Value scaledExp, |
238 | Value reciprocalExp, Value sin, Value cos, |
239 | ConversionPatternRewriter &rewriter, |
240 | arith::FastMathFlagsAttr fmf) const override { |
241 | // Complex cosine is defined as; |
242 | // cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy))) |
243 | // Plugging in: |
244 | // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x)) |
245 | // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x))) |
246 | // and defining t := exp(y) |
247 | // We get: |
248 | // Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x |
249 | // Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x |
250 | Value sum = |
251 | rewriter.create<arith::AddFOp>(loc, reciprocalExp, scaledExp, fmf); |
252 | Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, cos, fmf); |
253 | Value diff = |
254 | rewriter.create<arith::SubFOp>(loc, reciprocalExp, scaledExp, fmf); |
255 | Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, sin, fmf); |
256 | return {resultReal, resultImag}; |
257 | } |
258 | }; |
259 | |
260 | struct DivOpConversion : public OpConversionPattern<complex::DivOp> { |
261 | using OpConversionPattern<complex::DivOp>::OpConversionPattern; |
262 | |
263 | LogicalResult |
264 | matchAndRewrite(complex::DivOp op, OpAdaptor adaptor, |
265 | ConversionPatternRewriter &rewriter) const override { |
266 | auto loc = op.getLoc(); |
267 | auto type = cast<ComplexType>(adaptor.getLhs().getType()); |
268 | auto elementType = cast<FloatType>(type.getElementType()); |
269 | arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); |
270 | |
271 | Value lhsReal = |
272 | rewriter.create<complex::ReOp>(loc, elementType, adaptor.getLhs()); |
273 | Value lhsImag = |
274 | rewriter.create<complex::ImOp>(loc, elementType, adaptor.getLhs()); |
275 | Value rhsReal = |
276 | rewriter.create<complex::ReOp>(loc, elementType, adaptor.getRhs()); |
277 | Value rhsImag = |
278 | rewriter.create<complex::ImOp>(loc, elementType, adaptor.getRhs()); |
279 | |
280 | // Smith's algorithm to divide complex numbers. It is just a bit smarter |
281 | // way to compute the following formula: |
282 | // (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i) |
283 | // = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) / |
284 | // ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i)) |
285 | // = ((lhsReal * rhsReal + lhsImag * rhsImag) + |
286 | // (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2 |
287 | // |
288 | // Depending on whether |rhsReal| < |rhsImag| we compute either |
289 | // rhsRealImagRatio = rhsReal / rhsImag |
290 | // rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio |
291 | // resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom |
292 | // resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom |
293 | // |
294 | // or |
295 | // |
296 | // rhsImagRealRatio = rhsImag / rhsReal |
297 | // rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio |
298 | // resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom |
299 | // resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom |
300 | // |
301 | // See https://dl.acm.org/citation.cfm?id=368661 for more details. |
302 | Value rhsRealImagRatio = |
303 | rewriter.create<arith::DivFOp>(loc, rhsReal, rhsImag, fmf); |
304 | Value rhsRealImagDenom = rewriter.create<arith::AddFOp>( |
305 | loc, rhsImag, |
306 | rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal, fmf), |
307 | fmf); |
308 | Value realNumerator1 = rewriter.create<arith::AddFOp>( |
309 | loc, |
310 | rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio, fmf), |
311 | lhsImag, fmf); |
312 | Value resultReal1 = rewriter.create<arith::DivFOp>(loc, realNumerator1, |
313 | rhsRealImagDenom, fmf); |
314 | Value imagNumerator1 = rewriter.create<arith::SubFOp>( |
315 | loc, |
316 | rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio, fmf), |
317 | lhsReal, fmf); |
318 | Value resultImag1 = rewriter.create<arith::DivFOp>(loc, imagNumerator1, |
319 | rhsRealImagDenom, fmf); |
320 | |
321 | Value rhsImagRealRatio = |
322 | rewriter.create<arith::DivFOp>(loc, rhsImag, rhsReal, fmf); |
323 | Value rhsImagRealDenom = rewriter.create<arith::AddFOp>( |
324 | loc, rhsReal, |
325 | rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag, fmf), |
326 | fmf); |
327 | Value realNumerator2 = rewriter.create<arith::AddFOp>( |
328 | loc, lhsReal, |
329 | rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio, fmf), |
330 | fmf); |
331 | Value resultReal2 = rewriter.create<arith::DivFOp>(loc, realNumerator2, |
332 | rhsImagRealDenom, fmf); |
333 | Value imagNumerator2 = rewriter.create<arith::SubFOp>( |
334 | loc, lhsImag, |
335 | rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio, fmf), |
336 | fmf); |
337 | Value resultImag2 = rewriter.create<arith::DivFOp>(loc, imagNumerator2, |
338 | rhsImagRealDenom, fmf); |
339 | |
340 | // Consider corner cases. |
341 | // Case 1. Zero denominator, numerator contains at most one NaN value. |
342 | Value zero = rewriter.create<arith::ConstantOp>( |
343 | loc, elementType, rewriter.getZeroAttr(elementType)); |
344 | Value rhsRealAbs = rewriter.create<math::AbsFOp>(loc, rhsReal, fmf); |
345 | Value rhsRealIsZero = rewriter.create<arith::CmpFOp>( |
346 | loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero); |
347 | Value rhsImagAbs = rewriter.create<math::AbsFOp>(loc, rhsImag, fmf); |
348 | Value rhsImagIsZero = rewriter.create<arith::CmpFOp>( |
349 | loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero); |
350 | Value lhsRealIsNotNaN = rewriter.create<arith::CmpFOp>( |
351 | loc, arith::CmpFPredicate::ORD, lhsReal, zero); |
352 | Value lhsImagIsNotNaN = rewriter.create<arith::CmpFOp>( |
353 | loc, arith::CmpFPredicate::ORD, lhsImag, zero); |
354 | Value lhsContainsNotNaNValue = |
355 | rewriter.create<arith::OrIOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN); |
356 | Value resultIsInfinity = rewriter.create<arith::AndIOp>( |
357 | loc, lhsContainsNotNaNValue, |
358 | rewriter.create<arith::AndIOp>(loc, rhsRealIsZero, rhsImagIsZero)); |
359 | Value inf = rewriter.create<arith::ConstantOp>( |
360 | loc, elementType, |
361 | rewriter.getFloatAttr( |
362 | elementType, APFloat::getInf(elementType.getFloatSemantics()))); |
363 | Value infWithSignOfRhsReal = |
364 | rewriter.create<math::CopySignOp>(loc, inf, rhsReal); |
365 | Value infinityResultReal = |
366 | rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal, fmf); |
367 | Value infinityResultImag = |
368 | rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag, fmf); |
369 | |
370 | // Case 2. Infinite numerator, finite denominator. |
371 | Value rhsRealFinite = rewriter.create<arith::CmpFOp>( |
372 | loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf); |
373 | Value rhsImagFinite = rewriter.create<arith::CmpFOp>( |
374 | loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf); |
375 | Value rhsFinite = |
376 | rewriter.create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite); |
377 | Value lhsRealAbs = rewriter.create<math::AbsFOp>(loc, lhsReal, fmf); |
378 | Value lhsRealInfinite = rewriter.create<arith::CmpFOp>( |
379 | loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf); |
380 | Value lhsImagAbs = rewriter.create<math::AbsFOp>(loc, lhsImag, fmf); |
381 | Value lhsImagInfinite = rewriter.create<arith::CmpFOp>( |
382 | loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf); |
383 | Value lhsInfinite = |
384 | rewriter.create<arith::OrIOp>(loc, lhsRealInfinite, lhsImagInfinite); |
385 | Value infNumFiniteDenom = |
386 | rewriter.create<arith::AndIOp>(loc, lhsInfinite, rhsFinite); |
387 | Value one = rewriter.create<arith::ConstantOp>( |
388 | loc, elementType, rewriter.getFloatAttr(elementType, 1)); |
389 | Value lhsRealIsInfWithSign = rewriter.create<math::CopySignOp>( |
390 | loc, rewriter.create<arith::SelectOp>(loc, lhsRealInfinite, one, zero), |
391 | lhsReal); |
392 | Value lhsImagIsInfWithSign = rewriter.create<math::CopySignOp>( |
393 | loc, rewriter.create<arith::SelectOp>(loc, lhsImagInfinite, one, zero), |
394 | lhsImag); |
395 | Value lhsRealIsInfWithSignTimesRhsReal = |
396 | rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal, fmf); |
397 | Value lhsImagIsInfWithSignTimesRhsImag = |
398 | rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag, fmf); |
399 | Value resultReal3 = rewriter.create<arith::MulFOp>( |
400 | loc, inf, |
401 | rewriter.create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal, |
402 | lhsImagIsInfWithSignTimesRhsImag, fmf), |
403 | fmf); |
404 | Value lhsRealIsInfWithSignTimesRhsImag = |
405 | rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag, fmf); |
406 | Value lhsImagIsInfWithSignTimesRhsReal = |
407 | rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal, fmf); |
408 | Value resultImag3 = rewriter.create<arith::MulFOp>( |
409 | loc, inf, |
410 | rewriter.create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal, |
411 | lhsRealIsInfWithSignTimesRhsImag, fmf), |
412 | fmf); |
413 | |
414 | // Case 3: Finite numerator, infinite denominator. |
415 | Value lhsRealFinite = rewriter.create<arith::CmpFOp>( |
416 | loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf); |
417 | Value lhsImagFinite = rewriter.create<arith::CmpFOp>( |
418 | loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf); |
419 | Value lhsFinite = |
420 | rewriter.create<arith::AndIOp>(loc, lhsRealFinite, lhsImagFinite); |
421 | Value rhsRealInfinite = rewriter.create<arith::CmpFOp>( |
422 | loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf); |
423 | Value rhsImagInfinite = rewriter.create<arith::CmpFOp>( |
424 | loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf); |
425 | Value rhsInfinite = |
426 | rewriter.create<arith::OrIOp>(loc, rhsRealInfinite, rhsImagInfinite); |
427 | Value finiteNumInfiniteDenom = |
428 | rewriter.create<arith::AndIOp>(loc, lhsFinite, rhsInfinite); |
429 | Value rhsRealIsInfWithSign = rewriter.create<math::CopySignOp>( |
430 | loc, rewriter.create<arith::SelectOp>(loc, rhsRealInfinite, one, zero), |
431 | rhsReal); |
432 | Value rhsImagIsInfWithSign = rewriter.create<math::CopySignOp>( |
433 | loc, rewriter.create<arith::SelectOp>(loc, rhsImagInfinite, one, zero), |
434 | rhsImag); |
435 | Value rhsRealIsInfWithSignTimesLhsReal = |
436 | rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign, fmf); |
437 | Value rhsImagIsInfWithSignTimesLhsImag = |
438 | rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign, fmf); |
439 | Value resultReal4 = rewriter.create<arith::MulFOp>( |
440 | loc, zero, |
441 | rewriter.create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal, |
442 | rhsImagIsInfWithSignTimesLhsImag, fmf), |
443 | fmf); |
444 | Value rhsRealIsInfWithSignTimesLhsImag = |
445 | rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign, fmf); |
446 | Value rhsImagIsInfWithSignTimesLhsReal = |
447 | rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign, fmf); |
448 | Value resultImag4 = rewriter.create<arith::MulFOp>( |
449 | loc, zero, |
450 | rewriter.create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag, |
451 | rhsImagIsInfWithSignTimesLhsReal, fmf), |
452 | fmf); |
453 | |
454 | Value realAbsSmallerThanImagAbs = rewriter.create<arith::CmpFOp>( |
455 | loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs); |
456 | Value resultReal = rewriter.create<arith::SelectOp>( |
457 | loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2); |
458 | Value resultImag = rewriter.create<arith::SelectOp>( |
459 | loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2); |
460 | Value resultRealSpecialCase3 = rewriter.create<arith::SelectOp>( |
461 | loc, finiteNumInfiniteDenom, resultReal4, resultReal); |
462 | Value resultImagSpecialCase3 = rewriter.create<arith::SelectOp>( |
463 | loc, finiteNumInfiniteDenom, resultImag4, resultImag); |
464 | Value resultRealSpecialCase2 = rewriter.create<arith::SelectOp>( |
465 | loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3); |
466 | Value resultImagSpecialCase2 = rewriter.create<arith::SelectOp>( |
467 | loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3); |
468 | Value resultRealSpecialCase1 = rewriter.create<arith::SelectOp>( |
469 | loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2); |
470 | Value resultImagSpecialCase1 = rewriter.create<arith::SelectOp>( |
471 | loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2); |
472 | |
473 | Value resultRealIsNaN = rewriter.create<arith::CmpFOp>( |
474 | loc, arith::CmpFPredicate::UNO, resultReal, zero); |
475 | Value resultImagIsNaN = rewriter.create<arith::CmpFOp>( |
476 | loc, arith::CmpFPredicate::UNO, resultImag, zero); |
477 | Value resultIsNaN = |
478 | rewriter.create<arith::AndIOp>(loc, resultRealIsNaN, resultImagIsNaN); |
479 | Value resultRealWithSpecialCases = rewriter.create<arith::SelectOp>( |
480 | loc, resultIsNaN, resultRealSpecialCase1, resultReal); |
481 | Value resultImagWithSpecialCases = rewriter.create<arith::SelectOp>( |
482 | loc, resultIsNaN, resultImagSpecialCase1, resultImag); |
483 | |
484 | rewriter.replaceOpWithNewOp<complex::CreateOp>( |
485 | op, type, resultRealWithSpecialCases, resultImagWithSpecialCases); |
486 | return success(); |
487 | } |
488 | }; |
489 | |
490 | struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> { |
491 | using OpConversionPattern<complex::ExpOp>::OpConversionPattern; |
492 | |
493 | LogicalResult |
494 | matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor, |
495 | ConversionPatternRewriter &rewriter) const override { |
496 | auto loc = op.getLoc(); |
497 | auto type = cast<ComplexType>(adaptor.getComplex().getType()); |
498 | auto elementType = cast<FloatType>(type.getElementType()); |
499 | arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); |
500 | |
501 | Value real = |
502 | rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex()); |
503 | Value imag = |
504 | rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex()); |
505 | Value expReal = rewriter.create<math::ExpOp>(loc, real, fmf.getValue()); |
506 | Value cosImag = rewriter.create<math::CosOp>(loc, imag, fmf.getValue()); |
507 | Value resultReal = |
508 | rewriter.create<arith::MulFOp>(loc, expReal, cosImag, fmf.getValue()); |
509 | Value sinImag = rewriter.create<math::SinOp>(loc, imag, fmf.getValue()); |
510 | Value resultImag = |
511 | rewriter.create<arith::MulFOp>(loc, expReal, sinImag, fmf.getValue()); |
512 | |
513 | rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, |
514 | resultImag); |
515 | return success(); |
516 | } |
517 | }; |
518 | |
519 | struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> { |
520 | using OpConversionPattern<complex::Expm1Op>::OpConversionPattern; |
521 | |
522 | LogicalResult |
523 | matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor, |
524 | ConversionPatternRewriter &rewriter) const override { |
525 | auto type = cast<ComplexType>(adaptor.getComplex().getType()); |
526 | auto elementType = cast<FloatType>(type.getElementType()); |
527 | arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); |
528 | |
529 | mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); |
530 | Value exp = b.create<complex::ExpOp>(adaptor.getComplex(), fmf.getValue()); |
531 | |
532 | Value real = b.create<complex::ReOp>(elementType, exp); |
533 | Value one = b.create<arith::ConstantOp>(elementType, |
534 | b.getFloatAttr(elementType, 1)); |
535 | Value realMinusOne = b.create<arith::SubFOp>(real, one, fmf.getValue()); |
536 | Value imag = b.create<complex::ImOp>(elementType, exp); |
537 | |
538 | rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realMinusOne, |
539 | imag); |
540 | return success(); |
541 | } |
542 | }; |
543 | |
544 | struct LogOpConversion : public OpConversionPattern<complex::LogOp> { |
545 | using OpConversionPattern<complex::LogOp>::OpConversionPattern; |
546 | |
547 | LogicalResult |
548 | matchAndRewrite(complex::LogOp op, OpAdaptor adaptor, |
549 | ConversionPatternRewriter &rewriter) const override { |
550 | auto type = cast<ComplexType>(adaptor.getComplex().getType()); |
551 | auto elementType = cast<FloatType>(type.getElementType()); |
552 | arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); |
553 | mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); |
554 | |
555 | Value abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex(), |
556 | fmf.getValue()); |
557 | Value resultReal = b.create<math::LogOp>(elementType, abs, fmf.getValue()); |
558 | Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex()); |
559 | Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex()); |
560 | Value resultImag = |
561 | b.create<math::Atan2Op>(elementType, imag, real, fmf.getValue()); |
562 | rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, |
563 | resultImag); |
564 | return success(); |
565 | } |
566 | }; |
567 | |
568 | struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> { |
569 | using OpConversionPattern<complex::Log1pOp>::OpConversionPattern; |
570 | |
571 | LogicalResult |
572 | matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor, |
573 | ConversionPatternRewriter &rewriter) const override { |
574 | auto type = cast<ComplexType>(adaptor.getComplex().getType()); |
575 | auto elementType = cast<FloatType>(type.getElementType()); |
576 | arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue(); |
577 | mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); |
578 | |
579 | Value real = b.create<complex::ReOp>(adaptor.getComplex()); |
580 | Value imag = b.create<complex::ImOp>(adaptor.getComplex()); |
581 | |
582 | Value half = b.create<arith::ConstantOp>(elementType, |
583 | b.getFloatAttr(elementType, 0.5)); |
584 | Value one = b.create<arith::ConstantOp>(elementType, |
585 | b.getFloatAttr(elementType, 1)); |
586 | Value realPlusOne = b.create<arith::AddFOp>(real, one, fmf); |
587 | Value absRealPlusOne = b.create<math::AbsFOp>(realPlusOne, fmf); |
588 | Value absImag = b.create<math::AbsFOp>(imag, fmf); |
589 | |
590 | Value maxAbs = b.create<arith::MaximumFOp>(absRealPlusOne, absImag, fmf); |
591 | Value minAbs = b.create<arith::MinimumFOp>(absRealPlusOne, absImag, fmf); |
592 | |
593 | Value useReal = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, |
594 | realPlusOne, absImag, fmf); |
595 | Value maxMinusOne = b.create<arith::SubFOp>(maxAbs, one, fmf); |
596 | Value maxAbsOfRealPlusOneAndImagMinusOne = |
597 | b.create<arith::SelectOp>(useReal, real, maxMinusOne); |
598 | Value minMaxRatio = b.create<arith::DivFOp>(minAbs, maxAbs, fmf); |
599 | Value logOfMaxAbsOfRealPlusOneAndImag = |
600 | b.create<math::Log1pOp>(maxAbsOfRealPlusOneAndImagMinusOne, fmf); |
601 | Value logOfSqrtPart = b.create<math::Log1pOp>( |
602 | b.create<arith::MulFOp>(minMaxRatio, minMaxRatio, fmf), fmf); |
603 | Value r = b.create<arith::AddFOp>( |
604 | b.create<arith::MulFOp>(half, logOfSqrtPart, fmf), |
605 | logOfMaxAbsOfRealPlusOneAndImag, fmf); |
606 | Value resultReal = b.create<arith::SelectOp>( |
607 | b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, r, r, fmf), minAbs, |
608 | r); |
609 | Value resultImag = b.create<math::Atan2Op>(imag, realPlusOne, fmf); |
610 | rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, |
611 | resultImag); |
612 | return success(); |
613 | } |
614 | }; |
615 | |
// Lowers `complex.mul` to `arith`/`math` operations on the real and imaginary
// parts of the operands.
//
// The "naive" product (a + ib)(c + id) = (ac - bd) + i(ad + bc) is emitted
// first. A second product is then computed from adjusted operands in which
// infinite components are replaced by signed ones/zeros and NaN components by
// signed zeros; that product is scaled by infinity and selected only when both
// components of the naive result are NaN and one of the three cases below
// holds. Everything is emitted as straight-line selects, no control flow.
// NOTE(review): the recovery steps look like the C99 Annex G reference
// multiplication (as in compiler-rt's __muldc3) -- confirm.
struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
  using OpConversionPattern<complex::MulOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    auto type = cast<ComplexType>(adaptor.getLhs().getType());
    auto elementType = cast<FloatType>(type.getElementType());
    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
    auto fmfValue = fmf.getValue();

    // Components of both operands and their absolute values. The absolute
    // values are taken from the unmodified operands and are used for the
    // infinity checks further down.
    Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.getLhs());
    Value lhsRealAbs = b.create<math::AbsFOp>(lhsReal, fmfValue);
    Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.getLhs());
    Value lhsImagAbs = b.create<math::AbsFOp>(lhsImag, fmfValue);
    Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.getRhs());
    Value rhsRealAbs = b.create<math::AbsFOp>(rhsReal, fmfValue);
    Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.getRhs());
    Value rhsImagAbs = b.create<math::AbsFOp>(rhsImag, fmfValue);

    // Naive real part: lhsReal * rhsReal - lhsImag * rhsImag. The absolute
    // values of the partial products feed the Case 3 check.
    Value lhsRealTimesRhsReal =
        b.create<arith::MulFOp>(lhsReal, rhsReal, fmfValue);
    Value lhsRealTimesRhsRealAbs =
        b.create<math::AbsFOp>(lhsRealTimesRhsReal, fmfValue);
    Value lhsImagTimesRhsImag =
        b.create<arith::MulFOp>(lhsImag, rhsImag, fmfValue);
    Value lhsImagTimesRhsImagAbs =
        b.create<math::AbsFOp>(lhsImagTimesRhsImag, fmfValue);
    Value real = b.create<arith::SubFOp>(lhsRealTimesRhsReal,
                                         lhsImagTimesRhsImag, fmfValue);

    // Naive imaginary part: lhsImag * rhsReal + lhsReal * rhsImag.
    Value lhsImagTimesRhsReal =
        b.create<arith::MulFOp>(lhsImag, rhsReal, fmfValue);
    Value lhsImagTimesRhsRealAbs =
        b.create<math::AbsFOp>(lhsImagTimesRhsReal, fmfValue);
    Value lhsRealTimesRhsImag =
        b.create<arith::MulFOp>(lhsReal, rhsImag, fmfValue);
    Value lhsRealTimesRhsImagAbs =
        b.create<math::AbsFOp>(lhsRealTimesRhsImag, fmfValue);
    Value imag = b.create<arith::AddFOp>(lhsImagTimesRhsReal,
                                         lhsRealTimesRhsImag, fmfValue);

    // Handle cases where the "naive" calculation results in NaN values.
    // `isNan` is true only when BOTH result components are NaN (UNO of a value
    // with itself is a NaN test). These comparisons carry no fast-math flags.
    Value realIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real);
    Value imagIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, imag, imag);
    Value isNan = b.create<arith::AndIOp>(realIsNan, imagIsNan);

    Value inf = b.create<arith::ConstantOp>(
        elementType,
        b.getFloatAttr(elementType,
                       APFloat::getInf(elementType.getFloatSemantics())));

    // Case 1. `lhsReal` or `lhsImag` are infinite.
    // Each lhs component becomes copysign(isInf ? 1 : 0, component), and NaN
    // rhs components become copysign(0, component).
    Value lhsRealIsInf =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
    Value lhsImagIsInf =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
    Value lhsIsInf = b.create<arith::OrIOp>(lhsRealIsInf, lhsImagIsInf);
    Value rhsRealIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsReal, rhsReal);
    Value rhsImagIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsImag, rhsImag);
    Value zero =
        b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
    Value one = b.create<arith::ConstantOp>(elementType,
                                            b.getFloatAttr(elementType, 1));
    Value lhsRealIsInfFloat =
        b.create<arith::SelectOp>(lhsRealIsInf, one, zero);
    lhsReal = b.create<arith::SelectOp>(
        lhsIsInf, b.create<math::CopySignOp>(lhsRealIsInfFloat, lhsReal),
        lhsReal);
    Value lhsImagIsInfFloat =
        b.create<arith::SelectOp>(lhsImagIsInf, one, zero);
    lhsImag = b.create<arith::SelectOp>(
        lhsIsInf, b.create<math::CopySignOp>(lhsImagIsInfFloat, lhsImag),
        lhsImag);
    Value lhsIsInfAndRhsRealIsNan =
        b.create<arith::AndIOp>(lhsIsInf, rhsRealIsNan);
    rhsReal = b.create<arith::SelectOp>(
        lhsIsInfAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
        rhsReal);
    Value lhsIsInfAndRhsImagIsNan =
        b.create<arith::AndIOp>(lhsIsInf, rhsImagIsNan);
    rhsImag = b.create<arith::SelectOp>(
        lhsIsInfAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
        rhsImag);

    // Case 2. `rhsReal` or `rhsImag` are infinite.
    // Mirror image of Case 1. The infinity tests use the absolute values of
    // the original rhs components, while the NaN tests on `lhsReal`/`lhsImag`
    // read the values already produced by the Case 1 selects.
    Value rhsRealIsInf =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
    Value rhsImagIsInf =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
    Value rhsIsInf = b.create<arith::OrIOp>(rhsRealIsInf, rhsImagIsInf);
    Value lhsRealIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsReal, lhsReal);
    Value lhsImagIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsImag, lhsImag);
    Value rhsRealIsInfFloat =
        b.create<arith::SelectOp>(rhsRealIsInf, one, zero);
    rhsReal = b.create<arith::SelectOp>(
        rhsIsInf, b.create<math::CopySignOp>(rhsRealIsInfFloat, rhsReal),
        rhsReal);
    Value rhsImagIsInfFloat =
        b.create<arith::SelectOp>(rhsImagIsInf, one, zero);
    rhsImag = b.create<arith::SelectOp>(
        rhsIsInf, b.create<math::CopySignOp>(rhsImagIsInfFloat, rhsImag),
        rhsImag);
    Value rhsIsInfAndLhsRealIsNan =
        b.create<arith::AndIOp>(rhsIsInf, lhsRealIsNan);
    lhsReal = b.create<arith::SelectOp>(
        rhsIsInfAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
        lhsReal);
    Value rhsIsInfAndLhsImagIsNan =
        b.create<arith::AndIOp>(rhsIsInf, lhsImagIsNan);
    lhsImag = b.create<arith::SelectOp>(
        rhsIsInfAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
        lhsImag);
    // A recalculation is wanted if either operand had an infinite component.
    Value recalc = b.create<arith::OrIOp>(lhsIsInf, rhsIsInf);

    // Case 3. One of the pairwise products of left hand side with right hand
    // side is infinite.
    // Only considered when neither Case 1 nor Case 2 applied (`notRecalc`). The
    // checks use the partial products of the naive calculation above.
    Value lhsRealTimesRhsRealIsInf = b.create<arith::CmpFOp>(
        arith::CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf);
    Value lhsImagTimesRhsImagIsInf = b.create<arith::CmpFOp>(
        arith::CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf);
    Value isSpecialCase = b.create<arith::OrIOp>(lhsRealTimesRhsRealIsInf,
                                                 lhsImagTimesRhsImagIsInf);
    Value lhsRealTimesRhsImagIsInf = b.create<arith::CmpFOp>(
        arith::CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf);
    isSpecialCase =
        b.create<arith::OrIOp>(isSpecialCase, lhsRealTimesRhsImagIsInf);
    Value lhsImagTimesRhsRealIsInf = b.create<arith::CmpFOp>(
        arith::CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf);
    isSpecialCase =
        b.create<arith::OrIOp>(isSpecialCase, lhsImagTimesRhsRealIsInf);
    Type i1Type = b.getI1Type();
    // Logical negation of an i1 value: xor with the constant 1.
    Value notRecalc = b.create<arith::XOrIOp>(
        recalc,
        b.create<arith::ConstantOp>(i1Type, b.getIntegerAttr(i1Type, 1)));
    isSpecialCase = b.create<arith::AndIOp>(isSpecialCase, notRecalc);
    // In this case every NaN component of either operand becomes a signed
    // zero.
    Value isSpecialCaseAndLhsRealIsNan =
        b.create<arith::AndIOp>(isSpecialCase, lhsRealIsNan);
    lhsReal = b.create<arith::SelectOp>(
        isSpecialCaseAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
        lhsReal);
    Value isSpecialCaseAndLhsImagIsNan =
        b.create<arith::AndIOp>(isSpecialCase, lhsImagIsNan);
    lhsImag = b.create<arith::SelectOp>(
        isSpecialCaseAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
        lhsImag);
    Value isSpecialCaseAndRhsRealIsNan =
        b.create<arith::AndIOp>(isSpecialCase, rhsRealIsNan);
    rhsReal = b.create<arith::SelectOp>(
        isSpecialCaseAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
        rhsReal);
    Value isSpecialCaseAndRhsImagIsNan =
        b.create<arith::AndIOp>(isSpecialCase, rhsImagIsNan);
    rhsImag = b.create<arith::SelectOp>(
        isSpecialCaseAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
        rhsImag);
    // The recalculated result is used only if one of the three cases applies
    // AND the naive result was NaN in both components.
    recalc = b.create<arith::OrIOp>(recalc, isSpecialCase);
    recalc = b.create<arith::AndIOp>(isNan, recalc);

    // Recalculate real part.
    // Uses the adjusted operands; the new value is multiplied by infinity
    // before being selected over the naive one.
    lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal, fmfValue);
    lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag, fmfValue);
    Value newReal = b.create<arith::SubFOp>(lhsRealTimesRhsReal,
                                            lhsImagTimesRhsImag, fmfValue);
    real = b.create<arith::SelectOp>(
        recalc, b.create<arith::MulFOp>(inf, newReal, fmfValue), real);

    // Recalculate imag part.
    lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal, fmfValue);
    lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag, fmfValue);
    Value newImag = b.create<arith::AddFOp>(lhsImagTimesRhsReal,
                                            lhsRealTimesRhsImag, fmfValue);
    imag = b.create<arith::SelectOp>(
        recalc, b.create<arith::MulFOp>(inf, newImag, fmfValue), imag);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
    return success();
  }
};
802 | |
803 | struct NegOpConversion : public OpConversionPattern<complex::NegOp> { |
804 | using OpConversionPattern<complex::NegOp>::OpConversionPattern; |
805 | |
806 | LogicalResult |
807 | matchAndRewrite(complex::NegOp op, OpAdaptor adaptor, |
808 | ConversionPatternRewriter &rewriter) const override { |
809 | auto loc = op.getLoc(); |
810 | auto type = cast<ComplexType>(adaptor.getComplex().getType()); |
811 | auto elementType = cast<FloatType>(type.getElementType()); |
812 | |
813 | Value real = |
814 | rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex()); |
815 | Value imag = |
816 | rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex()); |
817 | Value negReal = rewriter.create<arith::NegFOp>(loc, real); |
818 | Value negImag = rewriter.create<arith::NegFOp>(loc, imag); |
819 | rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag); |
820 | return success(); |
821 | } |
822 | }; |
823 | |
824 | struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> { |
825 | using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion; |
826 | |
827 | std::pair<Value, Value> combine(Location loc, Value scaledExp, |
828 | Value reciprocalExp, Value sin, Value cos, |
829 | ConversionPatternRewriter &rewriter, |
830 | arith::FastMathFlagsAttr fmf) const override { |
831 | // Complex sine is defined as; |
832 | // sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy))) |
833 | // Plugging in: |
834 | // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x)) |
835 | // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x))) |
836 | // and defining t := exp(y) |
837 | // We get: |
838 | // Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x |
839 | // Im(cos(x + iy)) = (0.5*t - 0.5/t) * cos x |
840 | Value sum = |
841 | rewriter.create<arith::AddFOp>(loc, scaledExp, reciprocalExp, fmf); |
842 | Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, sin, fmf); |
843 | Value diff = |
844 | rewriter.create<arith::SubFOp>(loc, scaledExp, reciprocalExp, fmf); |
845 | Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, cos, fmf); |
846 | return {resultReal, resultImag}; |
847 | } |
848 | }; |
849 | |
850 | // The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780. |
851 | struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> { |
852 | using OpConversionPattern<complex::SqrtOp>::OpConversionPattern; |
853 | |
854 | LogicalResult |
855 | matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor, |
856 | ConversionPatternRewriter &rewriter) const override { |
857 | ImplicitLocOpBuilder b(op.getLoc(), rewriter); |
858 | |
859 | auto type = cast<ComplexType>(op.getType()); |
860 | auto elementType = type.getElementType().cast<FloatType>(); |
861 | arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue(); |
862 | |
863 | auto cst = [&](APFloat v) { |
864 | return b.create<arith::ConstantOp>(elementType, |
865 | b.getFloatAttr(elementType, v)); |
866 | }; |
867 | const auto &floatSemantics = elementType.getFloatSemantics(); |
868 | Value zero = cst(APFloat::getZero(Sem: floatSemantics)); |
869 | Value half = b.create<arith::ConstantOp>(elementType, |
870 | b.getFloatAttr(elementType, 0.5)); |
871 | |
872 | Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex()); |
873 | Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex()); |
874 | Value absSqrt = computeAbs(real, imag, fmf, b, AbsFn::sqrt); |
875 | Value argArg = b.create<math::Atan2Op>(imag, real, fmf); |
876 | Value sqrtArg = b.create<arith::MulFOp>(argArg, half, fmf); |
877 | Value cos = b.create<math::CosOp>(sqrtArg, fmf); |
878 | Value sin = b.create<math::SinOp>(sqrtArg, fmf); |
879 | // sin(atan2(0, inf)) = 0, sqrt(abs(inf)) = inf, but we can't multiply |
880 | // 0 * inf. |
881 | Value sinIsZero = |
882 | b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, sin, zero, fmf); |
883 | |
884 | Value resultReal = b.create<arith::MulFOp>(absSqrt, cos, fmf); |
885 | Value resultImag = b.create<arith::SelectOp>( |
886 | sinIsZero, zero, b.create<arith::MulFOp>(absSqrt, sin, fmf)); |
887 | if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan | |
888 | arith::FastMathFlags::ninf)) { |
889 | Value inf = cst(APFloat::getInf(Sem: floatSemantics)); |
890 | Value negInf = cst(APFloat::getInf(Sem: floatSemantics, Negative: true)); |
891 | Value nan = cst(APFloat::getNaN(Sem: floatSemantics)); |
892 | Value absImag = b.create<math::AbsFOp>(elementType, imag, fmf); |
893 | |
894 | Value absImagIsInf = |
895 | b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf); |
896 | Value absImagIsNotInf = |
897 | b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, absImag, inf, fmf); |
898 | Value realIsInf = |
899 | b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, inf, fmf); |
900 | Value realIsNegInf = |
901 | b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, negInf, fmf); |
902 | |
903 | resultReal = b.create<arith::SelectOp>( |
904 | b.create<arith::AndIOp>(realIsNegInf, absImagIsNotInf), zero, |
905 | resultReal); |
906 | resultReal = b.create<arith::SelectOp>( |
907 | b.create<arith::OrIOp>(absImagIsInf, realIsInf), inf, resultReal); |
908 | |
909 | Value imagSignInf = b.create<math::CopySignOp>(inf, imag, fmf); |
910 | resultImag = b.create<arith::SelectOp>( |
911 | b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, absSqrt, absSqrt), |
912 | nan, resultImag); |
913 | resultImag = b.create<arith::SelectOp>( |
914 | b.create<arith::OrIOp>(absImagIsInf, realIsNegInf), imagSignInf, |
915 | resultImag); |
916 | } |
917 | |
918 | Value resultIsZero = |
919 | b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absSqrt, zero, fmf); |
920 | resultReal = b.create<arith::SelectOp>(resultIsZero, zero, resultReal); |
921 | resultImag = b.create<arith::SelectOp>(resultIsZero, zero, resultImag); |
922 | |
923 | rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, |
924 | resultImag); |
925 | return success(); |
926 | } |
927 | }; |
928 | |
929 | struct SignOpConversion : public OpConversionPattern<complex::SignOp> { |
930 | using OpConversionPattern<complex::SignOp>::OpConversionPattern; |
931 | |
932 | LogicalResult |
933 | matchAndRewrite(complex::SignOp op, OpAdaptor adaptor, |
934 | ConversionPatternRewriter &rewriter) const override { |
935 | auto type = cast<ComplexType>(adaptor.getComplex().getType()); |
936 | auto elementType = cast<FloatType>(type.getElementType()); |
937 | mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); |
938 | arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); |
939 | |
940 | Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex()); |
941 | Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex()); |
942 | Value zero = |
943 | b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType)); |
944 | Value realIsZero = |
945 | b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero); |
946 | Value imagIsZero = |
947 | b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero); |
948 | Value isZero = b.create<arith::AndIOp>(realIsZero, imagIsZero); |
949 | auto abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex(), fmf); |
950 | Value realSign = b.create<arith::DivFOp>(real, abs, fmf); |
951 | Value imagSign = b.create<arith::DivFOp>(imag, abs, fmf); |
952 | Value sign = b.create<complex::CreateOp>(type, realSign, imagSign); |
953 | rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isZero, |
954 | adaptor.getComplex(), sign); |
955 | return success(); |
956 | } |
957 | }; |
958 | |
959 | struct TanOpConversion : public OpConversionPattern<complex::TanOp> { |
960 | using OpConversionPattern<complex::TanOp>::OpConversionPattern; |
961 | |
962 | LogicalResult |
963 | matchAndRewrite(complex::TanOp op, OpAdaptor adaptor, |
964 | ConversionPatternRewriter &rewriter) const override { |
965 | auto loc = op.getLoc(); |
966 | arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); |
967 | |
968 | Value cos = rewriter.create<complex::CosOp>(loc, adaptor.getComplex(), fmf); |
969 | Value sin = rewriter.create<complex::SinOp>(loc, adaptor.getComplex(), fmf); |
970 | rewriter.replaceOpWithNewOp<complex::DivOp>(op, sin, cos, fmf); |
971 | return success(); |
972 | } |
973 | }; |
974 | |
975 | struct TanhOpConversion : public OpConversionPattern<complex::TanhOp> { |
976 | using OpConversionPattern<complex::TanhOp>::OpConversionPattern; |
977 | |
978 | LogicalResult |
979 | matchAndRewrite(complex::TanhOp op, OpAdaptor adaptor, |
980 | ConversionPatternRewriter &rewriter) const override { |
981 | ImplicitLocOpBuilder b(op.getLoc(), rewriter); |
982 | auto loc = op.getLoc(); |
983 | auto type = cast<ComplexType>(adaptor.getComplex().getType()); |
984 | auto elementType = cast<FloatType>(type.getElementType()); |
985 | arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue(); |
986 | const auto &floatSemantics = elementType.getFloatSemantics(); |
987 | |
988 | Value real = |
989 | b.create<complex::ReOp>(loc, elementType, adaptor.getComplex()); |
990 | Value imag = |
991 | b.create<complex::ImOp>(loc, elementType, adaptor.getComplex()); |
992 | |
993 | auto cst = [&](APFloat v) { |
994 | return b.create<arith::ConstantOp>(elementType, |
995 | b.getFloatAttr(elementType, v)); |
996 | }; |
997 | Value inf = cst(APFloat::getInf(Sem: floatSemantics)); |
998 | Value negOne = b.create<arith::ConstantOp>( |
999 | elementType, b.getFloatAttr(elementType, -1.0)); |
1000 | Value four = b.create<arith::ConstantOp>(elementType, |
1001 | b.getFloatAttr(elementType, 4.0)); |
1002 | Value twoReal = b.create<arith::AddFOp>(real, real, fmf); |
1003 | Value negTwoReal = b.create<arith::MulFOp>(negOne, twoReal, fmf); |
1004 | |
1005 | Value expTwoRealMinusOne = b.create<math::ExpM1Op>(twoReal, fmf); |
1006 | Value expNegTwoRealMinusOne = b.create<math::ExpM1Op>(negTwoReal, fmf); |
1007 | Value realNum = |
1008 | b.create<arith::SubFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf); |
1009 | |
1010 | Value cosImag = b.create<math::CosOp>(imag, fmf); |
1011 | Value cosImagSq = b.create<arith::MulFOp>(cosImag, cosImag, fmf); |
1012 | Value twoCosTwoImagPlusOne = b.create<arith::MulFOp>(cosImagSq, four, fmf); |
1013 | Value sinImag = b.create<math::SinOp>(imag, fmf); |
1014 | |
1015 | Value imagNum = b.create<arith::MulFOp>( |
1016 | four, b.create<arith::MulFOp>(cosImag, sinImag, fmf), fmf); |
1017 | |
1018 | Value expSumMinusTwo = |
1019 | b.create<arith::AddFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf); |
1020 | Value denom = |
1021 | b.create<arith::AddFOp>(expSumMinusTwo, twoCosTwoImagPlusOne, fmf); |
1022 | |
1023 | Value isInf = b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, |
1024 | expSumMinusTwo, inf, fmf); |
1025 | Value realLimit = b.create<math::CopySignOp>(negOne, real, fmf); |
1026 | |
1027 | Value resultReal = b.create<arith::SelectOp>( |
1028 | isInf, realLimit, b.create<arith::DivFOp>(realNum, denom, fmf)); |
1029 | Value resultImag = b.create<arith::DivFOp>(imagNum, denom, fmf); |
1030 | |
1031 | if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan | |
1032 | arith::FastMathFlags::ninf)) { |
1033 | Value absReal = b.create<math::AbsFOp>(real, fmf); |
1034 | Value zero = b.create<arith::ConstantOp>( |
1035 | elementType, b.getFloatAttr(elementType, 0.0)); |
1036 | Value nan = cst(APFloat::getNaN(Sem: floatSemantics)); |
1037 | |
1038 | Value absRealIsInf = |
1039 | b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf); |
1040 | Value imagIsZero = |
1041 | b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf); |
1042 | Value absRealIsNotInf = b.create<arith::XOrIOp>( |
1043 | absRealIsInf, b.create<arith::ConstantIntOp>(true, /*width=*/1)); |
1044 | |
1045 | Value imagNumIsNaN = b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, |
1046 | imagNum, imagNum, fmf); |
1047 | Value resultRealIsNaN = |
1048 | b.create<arith::AndIOp>(imagNumIsNaN, absRealIsNotInf); |
1049 | Value resultImagIsZero = b.create<arith::OrIOp>( |
1050 | imagIsZero, b.create<arith::AndIOp>(absRealIsInf, imagNumIsNaN)); |
1051 | |
1052 | resultReal = b.create<arith::SelectOp>(resultRealIsNaN, nan, resultReal); |
1053 | resultImag = |
1054 | b.create<arith::SelectOp>(resultImagIsZero, zero, resultImag); |
1055 | } |
1056 | |
1057 | rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, |
1058 | resultImag); |
1059 | return success(); |
1060 | } |
1061 | }; |
1062 | |
1063 | struct ConjOpConversion : public OpConversionPattern<complex::ConjOp> { |
1064 | using OpConversionPattern<complex::ConjOp>::OpConversionPattern; |
1065 | |
1066 | LogicalResult |
1067 | matchAndRewrite(complex::ConjOp op, OpAdaptor adaptor, |
1068 | ConversionPatternRewriter &rewriter) const override { |
1069 | auto loc = op.getLoc(); |
1070 | auto type = cast<ComplexType>(adaptor.getComplex().getType()); |
1071 | auto elementType = cast<FloatType>(type.getElementType()); |
1072 | Value real = |
1073 | rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex()); |
1074 | Value imag = |
1075 | rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex()); |
1076 | Value negImag = rewriter.create<arith::NegFOp>(loc, elementType, imag); |
1077 | |
1078 | rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, negImag); |
1079 | |
1080 | return success(); |
1081 | } |
1082 | }; |
1083 | |
1084 | /// Converts lhs^y = (a+bi)^(c+di) to |
1085 | /// (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)), |
1086 | /// where q = c*atan2(b,a)+0.5d*ln(a*a+b*b) |
1087 | static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder, |
1088 | ComplexType type, Value lhs, Value c, Value d, |
1089 | arith::FastMathFlags fmf) { |
1090 | auto elementType = cast<FloatType>(type.getElementType()); |
1091 | |
1092 | Value a = builder.create<complex::ReOp>(lhs); |
1093 | Value b = builder.create<complex::ImOp>(lhs); |
1094 | |
1095 | Value abs = builder.create<complex::AbsOp>(lhs, fmf); |
1096 | Value absToC = builder.create<math::PowFOp>(abs, c, fmf); |
1097 | |
1098 | Value negD = builder.create<arith::NegFOp>(d, fmf); |
1099 | Value argLhs = builder.create<math::Atan2Op>(b, a, fmf); |
1100 | Value negDArgLhs = builder.create<arith::MulFOp>(negD, argLhs, fmf); |
1101 | Value expNegDArgLhs = builder.create<math::ExpOp>(negDArgLhs, fmf); |
1102 | |
1103 | Value coeff = builder.create<arith::MulFOp>(absToC, expNegDArgLhs, fmf); |
1104 | Value lnAbs = builder.create<math::LogOp>(abs, fmf); |
1105 | Value cArgLhs = builder.create<arith::MulFOp>(c, argLhs, fmf); |
1106 | Value dLnAbs = builder.create<arith::MulFOp>(d, lnAbs, fmf); |
1107 | Value q = builder.create<arith::AddFOp>(cArgLhs, dLnAbs, fmf); |
1108 | Value cosQ = builder.create<math::CosOp>(q, fmf); |
1109 | Value sinQ = builder.create<math::SinOp>(q, fmf); |
1110 | |
1111 | Value inf = builder.create<arith::ConstantOp>( |
1112 | elementType, |
1113 | builder.getFloatAttr(elementType, |
1114 | APFloat::getInf(elementType.getFloatSemantics()))); |
1115 | Value zero = builder.create<arith::ConstantOp>( |
1116 | elementType, builder.getFloatAttr(elementType, 0.0)); |
1117 | Value one = builder.create<arith::ConstantOp>( |
1118 | elementType, builder.getFloatAttr(elementType, 1.0)); |
1119 | Value complexOne = builder.create<complex::CreateOp>(type, one, zero); |
1120 | Value complexZero = builder.create<complex::CreateOp>(type, zero, zero); |
1121 | Value complexInf = builder.create<complex::CreateOp>(type, inf, zero); |
1122 | |
1123 | // Case 0: |
1124 | // d^c is 0 if d is 0 and c > 0. 0^0 is defined to be 1.0, see |
1125 | // Branch Cuts for Complex Elementary Functions or Much Ado About |
1126 | // Nothing's Sign Bit, W. Kahan, Section 10. |
1127 | Value absEqZero = |
1128 | builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, abs, zero, fmf); |
1129 | Value dEqZero = |
1130 | builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, d, zero, fmf); |
1131 | Value cEqZero = |
1132 | builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, c, zero, fmf); |
1133 | Value bEqZero = |
1134 | builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, b, zero, fmf); |
1135 | |
1136 | Value zeroLeC = |
1137 | builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLE, zero, c, fmf); |
1138 | Value coeffCosQ = builder.create<arith::MulFOp>(coeff, cosQ, fmf); |
1139 | Value coeffSinQ = builder.create<arith::MulFOp>(coeff, sinQ, fmf); |
1140 | Value complexOneOrZero = |
1141 | builder.create<arith::SelectOp>(cEqZero, complexOne, complexZero); |
1142 | Value coeffCosSin = |
1143 | builder.create<complex::CreateOp>(type, coeffCosQ, coeffSinQ); |
1144 | Value cutoff0 = builder.create<arith::SelectOp>( |
1145 | builder.create<arith::AndIOp>( |
1146 | builder.create<arith::AndIOp>(absEqZero, dEqZero), zeroLeC), |
1147 | complexOneOrZero, coeffCosSin); |
1148 | |
1149 | // Case 1: |
1150 | // x^0 is defined to be 1 for any x, see |
1151 | // Branch Cuts for Complex Elementary Functions or Much Ado About |
1152 | // Nothing's Sign Bit, W. Kahan, Section 10. |
1153 | Value rhsEqZero = builder.create<arith::AndIOp>(cEqZero, dEqZero); |
1154 | Value cutoff1 = |
1155 | builder.create<arith::SelectOp>(rhsEqZero, complexOne, cutoff0); |
1156 | |
1157 | // Case 2: |
1158 | // 1^(c + d*i) = 1 + 0*i |
1159 | Value lhsEqOne = builder.create<arith::AndIOp>( |
1160 | builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, one, fmf), |
1161 | bEqZero); |
1162 | Value cutoff2 = |
1163 | builder.create<arith::SelectOp>(lhsEqOne, complexOne, cutoff1); |
1164 | |
1165 | // Case 3: |
1166 | // inf^(c + 0*i) = inf + 0*i, c > 0 |
1167 | Value lhsEqInf = builder.create<arith::AndIOp>( |
1168 | builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, inf, fmf), |
1169 | bEqZero); |
1170 | Value rhsGt0 = builder.create<arith::AndIOp>( |
1171 | dEqZero, |
1172 | builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, c, zero, fmf)); |
1173 | Value cutoff3 = builder.create<arith::SelectOp>( |
1174 | builder.create<arith::AndIOp>(lhsEqInf, rhsGt0), complexInf, cutoff2); |
1175 | |
1176 | // Case 4: |
1177 | // inf^(c + 0*i) = 0 + 0*i, c < 0 |
1178 | Value rhsLt0 = builder.create<arith::AndIOp>( |
1179 | dEqZero, |
1180 | builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, c, zero, fmf)); |
1181 | Value cutoff4 = builder.create<arith::SelectOp>( |
1182 | builder.create<arith::AndIOp>(lhsEqInf, rhsLt0), complexZero, cutoff3); |
1183 | |
1184 | return cutoff4; |
1185 | } |
1186 | |
1187 | struct PowOpConversion : public OpConversionPattern<complex::PowOp> { |
1188 | using OpConversionPattern<complex::PowOp>::OpConversionPattern; |
1189 | |
1190 | LogicalResult |
1191 | matchAndRewrite(complex::PowOp op, OpAdaptor adaptor, |
1192 | ConversionPatternRewriter &rewriter) const override { |
1193 | mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter); |
1194 | auto type = cast<ComplexType>(adaptor.getLhs().getType()); |
1195 | auto elementType = cast<FloatType>(type.getElementType()); |
1196 | |
1197 | Value c = builder.create<complex::ReOp>(elementType, adaptor.getRhs()); |
1198 | Value d = builder.create<complex::ImOp>(elementType, adaptor.getRhs()); |
1199 | |
1200 | rewriter.replaceOp(op, {powOpConversionImpl(builder, type, adaptor.getLhs(), |
1201 | c, d, op.getFastmath())}); |
1202 | return success(); |
1203 | } |
1204 | }; |
1205 | |
1206 | struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> { |
1207 | using OpConversionPattern<complex::RsqrtOp>::OpConversionPattern; |
1208 | |
1209 | LogicalResult |
1210 | matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor, |
1211 | ConversionPatternRewriter &rewriter) const override { |
1212 | mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); |
1213 | auto type = cast<ComplexType>(adaptor.getComplex().getType()); |
1214 | auto elementType = cast<FloatType>(type.getElementType()); |
1215 | |
1216 | arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue(); |
1217 | |
1218 | auto cst = [&](APFloat v) { |
1219 | return b.create<arith::ConstantOp>(elementType, |
1220 | b.getFloatAttr(elementType, v)); |
1221 | }; |
1222 | const auto &floatSemantics = elementType.getFloatSemantics(); |
1223 | Value zero = cst(APFloat::getZero(Sem: floatSemantics)); |
1224 | Value inf = cst(APFloat::getInf(Sem: floatSemantics)); |
1225 | Value negHalf = b.create<arith::ConstantOp>( |
1226 | elementType, b.getFloatAttr(elementType, -0.5)); |
1227 | Value nan = cst(APFloat::getNaN(Sem: floatSemantics)); |
1228 | |
1229 | Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex()); |
1230 | Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex()); |
1231 | Value absRsqrt = computeAbs(real, imag, fmf, b, AbsFn::rsqrt); |
1232 | Value argArg = b.create<math::Atan2Op>(imag, real, fmf); |
1233 | Value rsqrtArg = b.create<arith::MulFOp>(argArg, negHalf, fmf); |
1234 | Value cos = b.create<math::CosOp>(rsqrtArg, fmf); |
1235 | Value sin = b.create<math::SinOp>(rsqrtArg, fmf); |
1236 | |
1237 | Value resultReal = b.create<arith::MulFOp>(absRsqrt, cos, fmf); |
1238 | Value resultImag = b.create<arith::MulFOp>(absRsqrt, sin, fmf); |
1239 | |
1240 | if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan | |
1241 | arith::FastMathFlags::ninf)) { |
1242 | Value negOne = b.create<arith::ConstantOp>( |
1243 | elementType, b.getFloatAttr(elementType, -1)); |
1244 | |
1245 | Value realSignedZero = b.create<math::CopySignOp>(zero, real, fmf); |
1246 | Value imagSignedZero = b.create<math::CopySignOp>(zero, imag, fmf); |
1247 | Value negImagSignedZero = |
1248 | b.create<arith::MulFOp>(negOne, imagSignedZero, fmf); |
1249 | |
1250 | Value absReal = b.create<math::AbsFOp>(real, fmf); |
1251 | Value absImag = b.create<math::AbsFOp>(imag, fmf); |
1252 | |
1253 | Value absImagIsInf = |
1254 | b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf); |
1255 | Value realIsNan = |
1256 | b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real, fmf); |
1257 | Value realIsInf = |
1258 | b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf); |
1259 | Value inIsNanInf = b.create<arith::AndIOp>(absImagIsInf, realIsNan); |
1260 | |
1261 | Value resultIsZero = b.create<arith::OrIOp>(inIsNanInf, realIsInf); |
1262 | |
1263 | resultReal = |
1264 | b.create<arith::SelectOp>(resultIsZero, realSignedZero, resultReal); |
1265 | resultImag = b.create<arith::SelectOp>(resultIsZero, negImagSignedZero, |
1266 | resultImag); |
1267 | } |
1268 | |
1269 | Value isRealZero = |
1270 | b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero, fmf); |
1271 | Value isImagZero = |
1272 | b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf); |
1273 | Value isZero = b.create<arith::AndIOp>(isRealZero, isImagZero); |
1274 | |
1275 | resultReal = b.create<arith::SelectOp>(isZero, inf, resultReal); |
1276 | resultImag = b.create<arith::SelectOp>(isZero, nan, resultImag); |
1277 | |
1278 | rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, |
1279 | resultImag); |
1280 | return success(); |
1281 | } |
1282 | }; |
1283 | |
1284 | struct AngleOpConversion : public OpConversionPattern<complex::AngleOp> { |
1285 | using OpConversionPattern<complex::AngleOp>::OpConversionPattern; |
1286 | |
1287 | LogicalResult |
1288 | matchAndRewrite(complex::AngleOp op, OpAdaptor adaptor, |
1289 | ConversionPatternRewriter &rewriter) const override { |
1290 | auto loc = op.getLoc(); |
1291 | auto type = op.getType(); |
1292 | arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); |
1293 | |
1294 | Value real = |
1295 | rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex()); |
1296 | Value imag = |
1297 | rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex()); |
1298 | |
1299 | rewriter.replaceOpWithNewOp<math::Atan2Op>(op, imag, real, fmf); |
1300 | |
1301 | return success(); |
1302 | } |
1303 | }; |
1304 | |
1305 | } // namespace |
1306 | |
1307 | void mlir::populateComplexToStandardConversionPatterns( |
1308 | RewritePatternSet &patterns) { |
1309 | // clang-format off |
1310 | patterns.add< |
1311 | AbsOpConversion, |
1312 | AngleOpConversion, |
1313 | Atan2OpConversion, |
1314 | BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>, |
1315 | BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>, |
1316 | ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>, |
1317 | ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>, |
1318 | ConjOpConversion, |
1319 | CosOpConversion, |
1320 | DivOpConversion, |
1321 | ExpOpConversion, |
1322 | Expm1OpConversion, |
1323 | Log1pOpConversion, |
1324 | LogOpConversion, |
1325 | MulOpConversion, |
1326 | NegOpConversion, |
1327 | SignOpConversion, |
1328 | SinOpConversion, |
1329 | SqrtOpConversion, |
1330 | TanOpConversion, |
1331 | TanhOpConversion, |
1332 | PowOpConversion, |
1333 | RsqrtOpConversion |
1334 | >(patterns.getContext()); |
1335 | // clang-format on |
1336 | } |
1337 | |
1338 | namespace { |
1339 | struct ConvertComplexToStandardPass |
1340 | : public impl::ConvertComplexToStandardBase<ConvertComplexToStandardPass> { |
1341 | void runOnOperation() override; |
1342 | }; |
1343 | |
1344 | void ConvertComplexToStandardPass::runOnOperation() { |
1345 | // Convert to the Standard dialect using the converter defined above. |
1346 | RewritePatternSet patterns(&getContext()); |
1347 | populateComplexToStandardConversionPatterns(patterns); |
1348 | |
1349 | ConversionTarget target(getContext()); |
1350 | target.addLegalDialect<arith::ArithDialect, math::MathDialect>(); |
1351 | target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>(); |
1352 | if (failed( |
1353 | applyPartialConversion(getOperation(), target, std::move(patterns)))) |
1354 | signalPassFailure(); |
1355 | } |
1356 | } // namespace |
1357 | |
1358 | std::unique_ptr<Pass> mlir::createConvertComplexToStandardPass() { |
1359 | return std::make_unique<ConvertComplexToStandardPass>(); |
1360 | } |
1361 | |