1 | //===- ComplexToLLVM.cpp - conversion from Complex to LLVM dialect --------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" |
10 | |
11 | #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h" |
12 | #include "mlir/Conversion/ComplexCommon/DivisionConverter.h" |
13 | #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" |
14 | #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" |
15 | #include "mlir/Conversion/LLVMCommon/Pattern.h" |
16 | #include "mlir/Dialect/Arith/IR/Arith.h" |
17 | #include "mlir/Dialect/Complex/IR/Complex.h" |
18 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
19 | #include "mlir/Pass/Pass.h" |
20 | |
21 | namespace mlir { |
22 | #define GEN_PASS_DEF_CONVERTCOMPLEXTOLLVMPASS |
23 | #include "mlir/Conversion/Passes.h.inc" |
24 | } // namespace mlir |
25 | |
26 | using namespace mlir; |
27 | using namespace mlir::LLVM; |
28 | using namespace mlir::arith; |
29 | |
30 | //===----------------------------------------------------------------------===// |
31 | // ComplexStructBuilder implementation. |
32 | //===----------------------------------------------------------------------===// |
33 | |
/// Field index of the real part inside the struct that represents a complex
/// number.
static constexpr unsigned kRealPosInComplexNumberStruct = 0u;
/// Field index of the imaginary part; it directly follows the real part.
static constexpr unsigned kImaginaryPosInComplexNumberStruct =
    kRealPosInComplexNumberStruct + 1u;
36 | |
37 | ComplexStructBuilder ComplexStructBuilder::poison(OpBuilder &builder, |
38 | Location loc, Type type) { |
39 | Value val = builder.create<LLVM::PoisonOp>(loc, type); |
40 | return ComplexStructBuilder(val); |
41 | } |
42 | |
43 | void ComplexStructBuilder::setReal(OpBuilder &builder, Location loc, |
44 | Value real) { |
45 | setPtr(builder, loc, pos: kRealPosInComplexNumberStruct, ptr: real); |
46 | } |
47 | |
48 | Value ComplexStructBuilder::real(OpBuilder &builder, Location loc) { |
49 | return extractPtr(builder, loc, pos: kRealPosInComplexNumberStruct); |
50 | } |
51 | |
52 | void ComplexStructBuilder::setImaginary(OpBuilder &builder, Location loc, |
53 | Value imaginary) { |
54 | setPtr(builder, loc, pos: kImaginaryPosInComplexNumberStruct, ptr: imaginary); |
55 | } |
56 | |
57 | Value ComplexStructBuilder::imaginary(OpBuilder &builder, Location loc) { |
58 | return extractPtr(builder, loc, pos: kImaginaryPosInComplexNumberStruct); |
59 | } |
60 | |
61 | //===----------------------------------------------------------------------===// |
62 | // Conversion patterns. |
63 | //===----------------------------------------------------------------------===// |
64 | |
65 | namespace { |
66 | |
67 | struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> { |
68 | using ConvertOpToLLVMPattern<complex::AbsOp>::ConvertOpToLLVMPattern; |
69 | |
70 | LogicalResult |
71 | matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor, |
72 | ConversionPatternRewriter &rewriter) const override { |
73 | auto loc = op.getLoc(); |
74 | |
75 | ComplexStructBuilder complexStruct(adaptor.getComplex()); |
76 | Value real = complexStruct.real(builder&: rewriter, loc: op.getLoc()); |
77 | Value imag = complexStruct.imaginary(builder&: rewriter, loc: op.getLoc()); |
78 | |
79 | arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr(); |
80 | LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( |
81 | op.getContext(), |
82 | convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue())); |
83 | Value sqNorm = rewriter.create<LLVM::FAddOp>( |
84 | loc, rewriter.create<LLVM::FMulOp>(loc, real, real, fmf), |
85 | rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf); |
86 | |
87 | rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm); |
88 | return success(); |
89 | } |
90 | }; |
91 | |
92 | struct ConstantOpLowering : public ConvertOpToLLVMPattern<complex::ConstantOp> { |
93 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
94 | |
95 | LogicalResult |
96 | matchAndRewrite(complex::ConstantOp op, OpAdaptor adaptor, |
97 | ConversionPatternRewriter &rewriter) const override { |
98 | return LLVM::detail::oneToOneRewrite( |
99 | op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(), |
100 | op->getAttrs(), *getTypeConverter(), rewriter); |
101 | } |
102 | }; |
103 | |
104 | struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> { |
105 | using ConvertOpToLLVMPattern<complex::CreateOp>::ConvertOpToLLVMPattern; |
106 | |
107 | LogicalResult |
108 | matchAndRewrite(complex::CreateOp complexOp, OpAdaptor adaptor, |
109 | ConversionPatternRewriter &rewriter) const override { |
110 | // Pack real and imaginary part in a complex number struct. |
111 | auto loc = complexOp.getLoc(); |
112 | auto structType = typeConverter->convertType(complexOp.getType()); |
113 | auto complexStruct = |
114 | ComplexStructBuilder::poison(builder&: rewriter, loc: loc, type: structType); |
115 | complexStruct.setReal(rewriter, loc, adaptor.getReal()); |
116 | complexStruct.setImaginary(rewriter, loc, adaptor.getImaginary()); |
117 | |
118 | rewriter.replaceOp(complexOp, {complexStruct}); |
119 | return success(); |
120 | } |
121 | }; |
122 | |
123 | struct ReOpConversion : public ConvertOpToLLVMPattern<complex::ReOp> { |
124 | using ConvertOpToLLVMPattern<complex::ReOp>::ConvertOpToLLVMPattern; |
125 | |
126 | LogicalResult |
127 | matchAndRewrite(complex::ReOp op, OpAdaptor adaptor, |
128 | ConversionPatternRewriter &rewriter) const override { |
129 | // Extract real part from the complex number struct. |
130 | ComplexStructBuilder complexStruct(adaptor.getComplex()); |
131 | Value real = complexStruct.real(builder&: rewriter, loc: op.getLoc()); |
132 | rewriter.replaceOp(op, real); |
133 | |
134 | return success(); |
135 | } |
136 | }; |
137 | |
138 | struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> { |
139 | using ConvertOpToLLVMPattern<complex::ImOp>::ConvertOpToLLVMPattern; |
140 | |
141 | LogicalResult |
142 | matchAndRewrite(complex::ImOp op, OpAdaptor adaptor, |
143 | ConversionPatternRewriter &rewriter) const override { |
144 | // Extract imaginary part from the complex number struct. |
145 | ComplexStructBuilder complexStruct(adaptor.getComplex()); |
146 | Value imaginary = complexStruct.imaginary(builder&: rewriter, loc: op.getLoc()); |
147 | rewriter.replaceOp(op, imaginary); |
148 | |
149 | return success(); |
150 | } |
151 | }; |
152 | |
153 | struct BinaryComplexOperands { |
154 | std::complex<Value> lhs; |
155 | std::complex<Value> rhs; |
156 | }; |
157 | |
158 | template <typename OpTy> |
159 | BinaryComplexOperands |
160 | unpackBinaryComplexOperands(OpTy op, typename OpTy::Adaptor adaptor, |
161 | ConversionPatternRewriter &rewriter) { |
162 | auto loc = op.getLoc(); |
163 | |
164 | // Extract real and imaginary values from operands. |
165 | BinaryComplexOperands unpacked; |
166 | ComplexStructBuilder lhs(adaptor.getLhs()); |
167 | unpacked.lhs.real(lhs.real(builder&: rewriter, loc)); |
168 | unpacked.lhs.imag(lhs.imaginary(builder&: rewriter, loc)); |
169 | ComplexStructBuilder rhs(adaptor.getRhs()); |
170 | unpacked.rhs.real(rhs.real(builder&: rewriter, loc)); |
171 | unpacked.rhs.imag(rhs.imaginary(builder&: rewriter, loc)); |
172 | |
173 | return unpacked; |
174 | } |
175 | |
176 | struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> { |
177 | using ConvertOpToLLVMPattern<complex::AddOp>::ConvertOpToLLVMPattern; |
178 | |
179 | LogicalResult |
180 | matchAndRewrite(complex::AddOp op, OpAdaptor adaptor, |
181 | ConversionPatternRewriter &rewriter) const override { |
182 | auto loc = op.getLoc(); |
183 | BinaryComplexOperands arg = |
184 | unpackBinaryComplexOperands<complex::AddOp>(op, adaptor, rewriter); |
185 | |
186 | // Initialize complex number struct for result. |
187 | auto structType = typeConverter->convertType(op.getType()); |
188 | auto result = ComplexStructBuilder::poison(builder&: rewriter, loc: loc, type: structType); |
189 | |
190 | // Emit IR to add complex numbers. |
191 | arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr(); |
192 | LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( |
193 | op.getContext(), |
194 | convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue())); |
195 | Value real = |
196 | rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf); |
197 | Value imag = |
198 | rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf); |
199 | result.setReal(rewriter, loc, real); |
200 | result.setImaginary(rewriter, loc, imag); |
201 | |
202 | rewriter.replaceOp(op, {result}); |
203 | return success(); |
204 | } |
205 | }; |
206 | |
207 | struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> { |
208 | DivOpConversion(const LLVMTypeConverter &converter, |
209 | complex::ComplexRangeFlags target) |
210 | : ConvertOpToLLVMPattern<complex::DivOp>(converter), |
211 | complexRange(target) {} |
212 | |
213 | using ConvertOpToLLVMPattern<complex::DivOp>::ConvertOpToLLVMPattern; |
214 | |
215 | LogicalResult |
216 | matchAndRewrite(complex::DivOp op, OpAdaptor adaptor, |
217 | ConversionPatternRewriter &rewriter) const override { |
218 | auto loc = op.getLoc(); |
219 | BinaryComplexOperands arg = |
220 | unpackBinaryComplexOperands<complex::DivOp>(op, adaptor, rewriter); |
221 | |
222 | // Initialize complex number struct for result. |
223 | auto structType = typeConverter->convertType(op.getType()); |
224 | auto result = ComplexStructBuilder::poison(builder&: rewriter, loc: loc, type: structType); |
225 | |
226 | // Emit IR to add complex numbers. |
227 | arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr(); |
228 | LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( |
229 | op.getContext(), |
230 | convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue())); |
231 | Value rhsRe = arg.rhs.real(); |
232 | Value rhsIm = arg.rhs.imag(); |
233 | Value lhsRe = arg.lhs.real(); |
234 | Value lhsIm = arg.lhs.imag(); |
235 | |
236 | Value resultRe, resultIm; |
237 | |
238 | if (complexRange == complex::ComplexRangeFlags::basic || |
239 | complexRange == complex::ComplexRangeFlags::none) { |
240 | mlir::complex::convertDivToLLVMUsingAlgebraic( |
241 | rewriter, loc, lhsRe, lhsIm, rhsRe, rhsIm, fmf, &resultRe, &resultIm); |
242 | } else if (complexRange == complex::ComplexRangeFlags::improved) { |
243 | mlir::complex::convertDivToLLVMUsingRangeReduction( |
244 | rewriter, loc, lhsRe, lhsIm, rhsRe, rhsIm, fmf, &resultRe, &resultIm); |
245 | } |
246 | |
247 | result.setReal(rewriter, loc, resultRe); |
248 | result.setImaginary(rewriter, loc, resultIm); |
249 | |
250 | rewriter.replaceOp(op, {result}); |
251 | return success(); |
252 | } |
253 | |
254 | private: |
255 | complex::ComplexRangeFlags complexRange; |
256 | }; |
257 | |
258 | struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> { |
259 | using ConvertOpToLLVMPattern<complex::MulOp>::ConvertOpToLLVMPattern; |
260 | |
261 | LogicalResult |
262 | matchAndRewrite(complex::MulOp op, OpAdaptor adaptor, |
263 | ConversionPatternRewriter &rewriter) const override { |
264 | auto loc = op.getLoc(); |
265 | BinaryComplexOperands arg = |
266 | unpackBinaryComplexOperands<complex::MulOp>(op, adaptor, rewriter); |
267 | |
268 | // Initialize complex number struct for result. |
269 | auto structType = typeConverter->convertType(op.getType()); |
270 | auto result = ComplexStructBuilder::poison(builder&: rewriter, loc: loc, type: structType); |
271 | |
272 | // Emit IR to add complex numbers. |
273 | arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr(); |
274 | LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( |
275 | op.getContext(), |
276 | convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue())); |
277 | Value rhsRe = arg.rhs.real(); |
278 | Value rhsIm = arg.rhs.imag(); |
279 | Value lhsRe = arg.lhs.real(); |
280 | Value lhsIm = arg.lhs.imag(); |
281 | |
282 | Value real = rewriter.create<LLVM::FSubOp>( |
283 | loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, lhsRe, fmf), |
284 | rewriter.create<LLVM::FMulOp>(loc, rhsIm, lhsIm, fmf), fmf); |
285 | |
286 | Value imag = rewriter.create<LLVM::FAddOp>( |
287 | loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf), |
288 | rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf); |
289 | |
290 | result.setReal(rewriter, loc, real); |
291 | result.setImaginary(rewriter, loc, imag); |
292 | |
293 | rewriter.replaceOp(op, {result}); |
294 | return success(); |
295 | } |
296 | }; |
297 | |
298 | struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> { |
299 | using ConvertOpToLLVMPattern<complex::SubOp>::ConvertOpToLLVMPattern; |
300 | |
301 | LogicalResult |
302 | matchAndRewrite(complex::SubOp op, OpAdaptor adaptor, |
303 | ConversionPatternRewriter &rewriter) const override { |
304 | auto loc = op.getLoc(); |
305 | BinaryComplexOperands arg = |
306 | unpackBinaryComplexOperands<complex::SubOp>(op, adaptor, rewriter); |
307 | |
308 | // Initialize complex number struct for result. |
309 | auto structType = typeConverter->convertType(op.getType()); |
310 | auto result = ComplexStructBuilder::poison(builder&: rewriter, loc: loc, type: structType); |
311 | |
312 | // Emit IR to substract complex numbers. |
313 | arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr(); |
314 | LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( |
315 | op.getContext(), |
316 | convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue())); |
317 | Value real = |
318 | rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf); |
319 | Value imag = |
320 | rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf); |
321 | result.setReal(rewriter, loc, real); |
322 | result.setImaginary(rewriter, loc, imag); |
323 | |
324 | rewriter.replaceOp(op, {result}); |
325 | return success(); |
326 | } |
327 | }; |
328 | } // namespace |
329 | |
330 | void mlir::populateComplexToLLVMConversionPatterns( |
331 | const LLVMTypeConverter &converter, RewritePatternSet &patterns, |
332 | complex::ComplexRangeFlags complexRange) { |
333 | // clang-format off |
334 | patterns.add< |
335 | AbsOpConversion, |
336 | AddOpConversion, |
337 | ConstantOpLowering, |
338 | CreateOpConversion, |
339 | ImOpConversion, |
340 | MulOpConversion, |
341 | ReOpConversion, |
342 | SubOpConversion |
343 | >(arg: converter); |
344 | |
345 | patterns.add<DivOpConversion>(converter, complexRange); |
346 | // clang-format on |
347 | } |
348 | |
349 | namespace { |
350 | struct ConvertComplexToLLVMPass |
351 | : public impl::ConvertComplexToLLVMPassBase<ConvertComplexToLLVMPass> { |
352 | using Base::Base; |
353 | |
354 | void runOnOperation() override; |
355 | }; |
356 | } // namespace |
357 | |
358 | void ConvertComplexToLLVMPass::runOnOperation() { |
359 | // Convert to the LLVM IR dialect using the converter defined above. |
360 | RewritePatternSet patterns(&getContext()); |
361 | LLVMTypeConverter converter(&getContext()); |
362 | populateComplexToLLVMConversionPatterns(converter, patterns, complexRange); |
363 | |
364 | LLVMConversionTarget target(getContext()); |
365 | target.addIllegalDialect<complex::ComplexDialect>(); |
366 | if (failed( |
367 | applyPartialConversion(getOperation(), target, std::move(patterns)))) |
368 | signalPassFailure(); |
369 | } |
370 | |
371 | //===----------------------------------------------------------------------===// |
372 | // ConvertToLLVMPatternInterface implementation |
373 | //===----------------------------------------------------------------------===// |
374 | |
375 | namespace { |
376 | /// Implement the interface to convert MemRef to LLVM. |
377 | struct ComplexToLLVMDialectInterface : public ConvertToLLVMPatternInterface { |
378 | using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; |
379 | void loadDependentDialects(MLIRContext *context) const final { |
380 | context->loadDialect<LLVM::LLVMDialect>(); |
381 | } |
382 | |
383 | /// Hook for derived dialect interface to provide conversion patterns |
384 | /// and mark dialect legal for the conversion target. |
385 | void populateConvertToLLVMConversionPatterns( |
386 | ConversionTarget &target, LLVMTypeConverter &typeConverter, |
387 | RewritePatternSet &patterns) const final { |
388 | populateComplexToLLVMConversionPatterns(converter: typeConverter, patterns); |
389 | } |
390 | }; |
391 | } // namespace |
392 | |
393 | void mlir::registerConvertComplexToLLVMInterface(DialectRegistry ®istry) { |
394 | registry.addExtension( |
395 | extensionFn: +[](MLIRContext *ctx, complex::ComplexDialect *dialect) { |
396 | dialect->addInterfaces<ComplexToLLVMDialectInterface>(); |
397 | }); |
398 | } |
399 |
Definitions
- kRealPosInComplexNumberStruct
- kImaginaryPosInComplexNumberStruct
- poison
- setReal
- real
- setImaginary
- imaginary
- AbsOpConversion
- matchAndRewrite
- ConstantOpLowering
- matchAndRewrite
- CreateOpConversion
- matchAndRewrite
- ReOpConversion
- matchAndRewrite
- ImOpConversion
- matchAndRewrite
- BinaryComplexOperands
- unpackBinaryComplexOperands
- AddOpConversion
- matchAndRewrite
- DivOpConversion
- DivOpConversion
- matchAndRewrite
- MulOpConversion
- matchAndRewrite
- SubOpConversion
- matchAndRewrite
- populateComplexToLLVMConversionPatterns
- ConvertComplexToLLVMPass
- runOnOperation
- ComplexToLLVMDialectInterface
- loadDependentDialects
- populateConvertToLLVMConversionPatterns
Learn to use CMake with our Intro Training
Find out more