1 | //===- ComplexToLLVM.cpp - conversion from Complex to LLVM dialect --------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" |
10 | |
11 | #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h" |
12 | #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" |
13 | #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" |
14 | #include "mlir/Conversion/LLVMCommon/Pattern.h" |
15 | #include "mlir/Dialect/Arith/IR/Arith.h" |
16 | #include "mlir/Dialect/Complex/IR/Complex.h" |
17 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
18 | #include "mlir/Pass/Pass.h" |
19 | |
20 | namespace mlir { |
21 | #define GEN_PASS_DEF_CONVERTCOMPLEXTOLLVMPASS |
22 | #include "mlir/Conversion/Passes.h.inc" |
23 | } // namespace mlir |
24 | |
25 | using namespace mlir; |
26 | using namespace mlir::LLVM; |
27 | using namespace mlir::arith; |
28 | |
29 | //===----------------------------------------------------------------------===// |
30 | // ComplexStructBuilder implementation. |
31 | //===----------------------------------------------------------------------===// |
32 | |
// Field indices of the real and imaginary components inside the LLVM struct
// value that ComplexStructBuilder reads from and writes to.
namespace {
constexpr unsigned kRealPosInComplexNumberStruct = 0;
constexpr unsigned kImaginaryPosInComplexNumberStruct = 1;
} // namespace
35 | |
36 | ComplexStructBuilder ComplexStructBuilder::undef(OpBuilder &builder, |
37 | Location loc, Type type) { |
38 | Value val = builder.create<LLVM::UndefOp>(loc, type); |
39 | return ComplexStructBuilder(val); |
40 | } |
41 | |
42 | void ComplexStructBuilder::setReal(OpBuilder &builder, Location loc, |
43 | Value real) { |
44 | setPtr(builder, loc, pos: kRealPosInComplexNumberStruct, ptr: real); |
45 | } |
46 | |
47 | Value ComplexStructBuilder::real(OpBuilder &builder, Location loc) { |
48 | return extractPtr(builder, loc, pos: kRealPosInComplexNumberStruct); |
49 | } |
50 | |
51 | void ComplexStructBuilder::setImaginary(OpBuilder &builder, Location loc, |
52 | Value imaginary) { |
53 | setPtr(builder, loc, pos: kImaginaryPosInComplexNumberStruct, ptr: imaginary); |
54 | } |
55 | |
56 | Value ComplexStructBuilder::imaginary(OpBuilder &builder, Location loc) { |
57 | return extractPtr(builder, loc, pos: kImaginaryPosInComplexNumberStruct); |
58 | } |
59 | |
60 | //===----------------------------------------------------------------------===// |
61 | // Conversion patterns. |
62 | //===----------------------------------------------------------------------===// |
63 | |
64 | namespace { |
65 | |
66 | struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> { |
67 | using ConvertOpToLLVMPattern<complex::AbsOp>::ConvertOpToLLVMPattern; |
68 | |
69 | LogicalResult |
70 | matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor, |
71 | ConversionPatternRewriter &rewriter) const override { |
72 | auto loc = op.getLoc(); |
73 | |
74 | ComplexStructBuilder complexStruct(adaptor.getComplex()); |
75 | Value real = complexStruct.real(builder&: rewriter, loc: op.getLoc()); |
76 | Value imag = complexStruct.imaginary(builder&: rewriter, loc: op.getLoc()); |
77 | |
78 | arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr(); |
79 | LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( |
80 | op.getContext(), |
81 | convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue())); |
82 | Value sqNorm = rewriter.create<LLVM::FAddOp>( |
83 | loc, rewriter.create<LLVM::FMulOp>(loc, real, real, fmf), |
84 | rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf); |
85 | |
86 | rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm); |
87 | return success(); |
88 | } |
89 | }; |
90 | |
91 | struct ConstantOpLowering : public ConvertOpToLLVMPattern<complex::ConstantOp> { |
92 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
93 | |
94 | LogicalResult |
95 | matchAndRewrite(complex::ConstantOp op, OpAdaptor adaptor, |
96 | ConversionPatternRewriter &rewriter) const override { |
97 | return LLVM::detail::oneToOneRewrite( |
98 | op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(), |
99 | op->getAttrs(), *getTypeConverter(), rewriter); |
100 | } |
101 | }; |
102 | |
103 | struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> { |
104 | using ConvertOpToLLVMPattern<complex::CreateOp>::ConvertOpToLLVMPattern; |
105 | |
106 | LogicalResult |
107 | matchAndRewrite(complex::CreateOp complexOp, OpAdaptor adaptor, |
108 | ConversionPatternRewriter &rewriter) const override { |
109 | // Pack real and imaginary part in a complex number struct. |
110 | auto loc = complexOp.getLoc(); |
111 | auto structType = typeConverter->convertType(complexOp.getType()); |
112 | auto complexStruct = ComplexStructBuilder::undef(builder&: rewriter, loc: loc, type: structType); |
113 | complexStruct.setReal(rewriter, loc, adaptor.getReal()); |
114 | complexStruct.setImaginary(rewriter, loc, adaptor.getImaginary()); |
115 | |
116 | rewriter.replaceOp(complexOp, {complexStruct}); |
117 | return success(); |
118 | } |
119 | }; |
120 | |
121 | struct ReOpConversion : public ConvertOpToLLVMPattern<complex::ReOp> { |
122 | using ConvertOpToLLVMPattern<complex::ReOp>::ConvertOpToLLVMPattern; |
123 | |
124 | LogicalResult |
125 | matchAndRewrite(complex::ReOp op, OpAdaptor adaptor, |
126 | ConversionPatternRewriter &rewriter) const override { |
127 | // Extract real part from the complex number struct. |
128 | ComplexStructBuilder complexStruct(adaptor.getComplex()); |
129 | Value real = complexStruct.real(builder&: rewriter, loc: op.getLoc()); |
130 | rewriter.replaceOp(op, real); |
131 | |
132 | return success(); |
133 | } |
134 | }; |
135 | |
136 | struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> { |
137 | using ConvertOpToLLVMPattern<complex::ImOp>::ConvertOpToLLVMPattern; |
138 | |
139 | LogicalResult |
140 | matchAndRewrite(complex::ImOp op, OpAdaptor adaptor, |
141 | ConversionPatternRewriter &rewriter) const override { |
142 | // Extract imaginary part from the complex number struct. |
143 | ComplexStructBuilder complexStruct(adaptor.getComplex()); |
144 | Value imaginary = complexStruct.imaginary(builder&: rewriter, loc: op.getLoc()); |
145 | rewriter.replaceOp(op, imaginary); |
146 | |
147 | return success(); |
148 | } |
149 | }; |
150 | |
151 | struct BinaryComplexOperands { |
152 | std::complex<Value> lhs; |
153 | std::complex<Value> rhs; |
154 | }; |
155 | |
156 | template <typename OpTy> |
157 | BinaryComplexOperands |
158 | unpackBinaryComplexOperands(OpTy op, typename OpTy::Adaptor adaptor, |
159 | ConversionPatternRewriter &rewriter) { |
160 | auto loc = op.getLoc(); |
161 | |
162 | // Extract real and imaginary values from operands. |
163 | BinaryComplexOperands unpacked; |
164 | ComplexStructBuilder lhs(adaptor.getLhs()); |
165 | unpacked.lhs.real(lhs.real(builder&: rewriter, loc)); |
166 | unpacked.lhs.imag(lhs.imaginary(builder&: rewriter, loc)); |
167 | ComplexStructBuilder rhs(adaptor.getRhs()); |
168 | unpacked.rhs.real(rhs.real(builder&: rewriter, loc)); |
169 | unpacked.rhs.imag(rhs.imaginary(builder&: rewriter, loc)); |
170 | |
171 | return unpacked; |
172 | } |
173 | |
174 | struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> { |
175 | using ConvertOpToLLVMPattern<complex::AddOp>::ConvertOpToLLVMPattern; |
176 | |
177 | LogicalResult |
178 | matchAndRewrite(complex::AddOp op, OpAdaptor adaptor, |
179 | ConversionPatternRewriter &rewriter) const override { |
180 | auto loc = op.getLoc(); |
181 | BinaryComplexOperands arg = |
182 | unpackBinaryComplexOperands<complex::AddOp>(op, adaptor, rewriter); |
183 | |
184 | // Initialize complex number struct for result. |
185 | auto structType = typeConverter->convertType(op.getType()); |
186 | auto result = ComplexStructBuilder::undef(builder&: rewriter, loc: loc, type: structType); |
187 | |
188 | // Emit IR to add complex numbers. |
189 | arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr(); |
190 | LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( |
191 | op.getContext(), |
192 | convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue())); |
193 | Value real = |
194 | rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf); |
195 | Value imag = |
196 | rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf); |
197 | result.setReal(rewriter, loc, real); |
198 | result.setImaginary(rewriter, loc, imag); |
199 | |
200 | rewriter.replaceOp(op, {result}); |
201 | return success(); |
202 | } |
203 | }; |
204 | |
205 | struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> { |
206 | using ConvertOpToLLVMPattern<complex::DivOp>::ConvertOpToLLVMPattern; |
207 | |
208 | LogicalResult |
209 | matchAndRewrite(complex::DivOp op, OpAdaptor adaptor, |
210 | ConversionPatternRewriter &rewriter) const override { |
211 | auto loc = op.getLoc(); |
212 | BinaryComplexOperands arg = |
213 | unpackBinaryComplexOperands<complex::DivOp>(op, adaptor, rewriter); |
214 | |
215 | // Initialize complex number struct for result. |
216 | auto structType = typeConverter->convertType(op.getType()); |
217 | auto result = ComplexStructBuilder::undef(builder&: rewriter, loc: loc, type: structType); |
218 | |
219 | // Emit IR to add complex numbers. |
220 | arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr(); |
221 | LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( |
222 | op.getContext(), |
223 | convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue())); |
224 | Value rhsRe = arg.rhs.real(); |
225 | Value rhsIm = arg.rhs.imag(); |
226 | Value lhsRe = arg.lhs.real(); |
227 | Value lhsIm = arg.lhs.imag(); |
228 | |
229 | Value rhsSqNorm = rewriter.create<LLVM::FAddOp>( |
230 | loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, rhsRe, fmf), |
231 | rewriter.create<LLVM::FMulOp>(loc, rhsIm, rhsIm, fmf), fmf); |
232 | |
233 | Value resultReal = rewriter.create<LLVM::FAddOp>( |
234 | loc, rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRe, fmf), |
235 | rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsIm, fmf), fmf); |
236 | |
237 | Value resultImag = rewriter.create<LLVM::FSubOp>( |
238 | loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf), |
239 | rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf); |
240 | |
241 | result.setReal( |
242 | rewriter, loc, |
243 | rewriter.create<LLVM::FDivOp>(loc, resultReal, rhsSqNorm, fmf)); |
244 | result.setImaginary( |
245 | rewriter, loc, |
246 | rewriter.create<LLVM::FDivOp>(loc, resultImag, rhsSqNorm, fmf)); |
247 | |
248 | rewriter.replaceOp(op, {result}); |
249 | return success(); |
250 | } |
251 | }; |
252 | |
253 | struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> { |
254 | using ConvertOpToLLVMPattern<complex::MulOp>::ConvertOpToLLVMPattern; |
255 | |
256 | LogicalResult |
257 | matchAndRewrite(complex::MulOp op, OpAdaptor adaptor, |
258 | ConversionPatternRewriter &rewriter) const override { |
259 | auto loc = op.getLoc(); |
260 | BinaryComplexOperands arg = |
261 | unpackBinaryComplexOperands<complex::MulOp>(op, adaptor, rewriter); |
262 | |
263 | // Initialize complex number struct for result. |
264 | auto structType = typeConverter->convertType(op.getType()); |
265 | auto result = ComplexStructBuilder::undef(builder&: rewriter, loc: loc, type: structType); |
266 | |
267 | // Emit IR to add complex numbers. |
268 | arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr(); |
269 | LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( |
270 | op.getContext(), |
271 | convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue())); |
272 | Value rhsRe = arg.rhs.real(); |
273 | Value rhsIm = arg.rhs.imag(); |
274 | Value lhsRe = arg.lhs.real(); |
275 | Value lhsIm = arg.lhs.imag(); |
276 | |
277 | Value real = rewriter.create<LLVM::FSubOp>( |
278 | loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, lhsRe, fmf), |
279 | rewriter.create<LLVM::FMulOp>(loc, rhsIm, lhsIm, fmf), fmf); |
280 | |
281 | Value imag = rewriter.create<LLVM::FAddOp>( |
282 | loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf), |
283 | rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf); |
284 | |
285 | result.setReal(rewriter, loc, real); |
286 | result.setImaginary(rewriter, loc, imag); |
287 | |
288 | rewriter.replaceOp(op, {result}); |
289 | return success(); |
290 | } |
291 | }; |
292 | |
293 | struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> { |
294 | using ConvertOpToLLVMPattern<complex::SubOp>::ConvertOpToLLVMPattern; |
295 | |
296 | LogicalResult |
297 | matchAndRewrite(complex::SubOp op, OpAdaptor adaptor, |
298 | ConversionPatternRewriter &rewriter) const override { |
299 | auto loc = op.getLoc(); |
300 | BinaryComplexOperands arg = |
301 | unpackBinaryComplexOperands<complex::SubOp>(op, adaptor, rewriter); |
302 | |
303 | // Initialize complex number struct for result. |
304 | auto structType = typeConverter->convertType(op.getType()); |
305 | auto result = ComplexStructBuilder::undef(builder&: rewriter, loc: loc, type: structType); |
306 | |
307 | // Emit IR to substract complex numbers. |
308 | arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr(); |
309 | LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( |
310 | op.getContext(), |
311 | convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue())); |
312 | Value real = |
313 | rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf); |
314 | Value imag = |
315 | rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf); |
316 | result.setReal(rewriter, loc, real); |
317 | result.setImaginary(rewriter, loc, imag); |
318 | |
319 | rewriter.replaceOp(op, {result}); |
320 | return success(); |
321 | } |
322 | }; |
323 | } // namespace |
324 | |
325 | void mlir::populateComplexToLLVMConversionPatterns( |
326 | LLVMTypeConverter &converter, RewritePatternSet &patterns) { |
327 | // clang-format off |
328 | patterns.add< |
329 | AbsOpConversion, |
330 | AddOpConversion, |
331 | ConstantOpLowering, |
332 | CreateOpConversion, |
333 | DivOpConversion, |
334 | ImOpConversion, |
335 | MulOpConversion, |
336 | ReOpConversion, |
337 | SubOpConversion |
338 | >(arg&: converter); |
339 | // clang-format on |
340 | } |
341 | |
342 | namespace { |
343 | struct ConvertComplexToLLVMPass |
344 | : public impl::ConvertComplexToLLVMPassBase<ConvertComplexToLLVMPass> { |
345 | using Base::Base; |
346 | |
347 | void runOnOperation() override; |
348 | }; |
349 | } // namespace |
350 | |
351 | void ConvertComplexToLLVMPass::runOnOperation() { |
352 | // Convert to the LLVM IR dialect using the converter defined above. |
353 | RewritePatternSet patterns(&getContext()); |
354 | LLVMTypeConverter converter(&getContext()); |
355 | populateComplexToLLVMConversionPatterns(converter, patterns); |
356 | |
357 | LLVMConversionTarget target(getContext()); |
358 | target.addIllegalDialect<complex::ComplexDialect>(); |
359 | if (failed( |
360 | applyPartialConversion(getOperation(), target, std::move(patterns)))) |
361 | signalPassFailure(); |
362 | } |
363 | |
364 | //===----------------------------------------------------------------------===// |
365 | // ConvertToLLVMPatternInterface implementation |
366 | //===----------------------------------------------------------------------===// |
367 | |
368 | namespace { |
369 | /// Implement the interface to convert MemRef to LLVM. |
370 | struct ComplexToLLVMDialectInterface : public ConvertToLLVMPatternInterface { |
371 | using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; |
372 | void loadDependentDialects(MLIRContext *context) const final { |
373 | context->loadDialect<LLVM::LLVMDialect>(); |
374 | } |
375 | |
376 | /// Hook for derived dialect interface to provide conversion patterns |
377 | /// and mark dialect legal for the conversion target. |
378 | void populateConvertToLLVMConversionPatterns( |
379 | ConversionTarget &target, LLVMTypeConverter &typeConverter, |
380 | RewritePatternSet &patterns) const final { |
381 | populateComplexToLLVMConversionPatterns(converter&: typeConverter, patterns); |
382 | } |
383 | }; |
384 | } // namespace |
385 | |
386 | void mlir::registerConvertComplexToLLVMInterface(DialectRegistry ®istry) { |
387 | registry.addExtension( |
388 | extensionFn: +[](MLIRContext *ctx, complex::ComplexDialect *dialect) { |
389 | dialect->addInterfaces<ComplexToLLVMDialectInterface>(); |
390 | }); |
391 | } |
392 | |