1//===- ComplexToLLVM.cpp - conversion from Complex to LLVM dialect --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
10
11#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
12#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
13#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
14#include "mlir/Conversion/LLVMCommon/Pattern.h"
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/Complex/IR/Complex.h"
17#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18#include "mlir/Pass/Pass.h"
19
20namespace mlir {
21#define GEN_PASS_DEF_CONVERTCOMPLEXTOLLVMPASS
22#include "mlir/Conversion/Passes.h.inc"
23} // namespace mlir
24
25using namespace mlir;
26using namespace mlir::LLVM;
27using namespace mlir::arith;
28
29//===----------------------------------------------------------------------===//
30// ComplexStructBuilder implementation.
31//===----------------------------------------------------------------------===//
32
33static constexpr unsigned kRealPosInComplexNumberStruct = 0;
34static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1;
35
36ComplexStructBuilder ComplexStructBuilder::undef(OpBuilder &builder,
37 Location loc, Type type) {
38 Value val = builder.create<LLVM::UndefOp>(loc, type);
39 return ComplexStructBuilder(val);
40}
41
42void ComplexStructBuilder::setReal(OpBuilder &builder, Location loc,
43 Value real) {
44 setPtr(builder, loc, pos: kRealPosInComplexNumberStruct, ptr: real);
45}
46
47Value ComplexStructBuilder::real(OpBuilder &builder, Location loc) {
48 return extractPtr(builder, loc, pos: kRealPosInComplexNumberStruct);
49}
50
51void ComplexStructBuilder::setImaginary(OpBuilder &builder, Location loc,
52 Value imaginary) {
53 setPtr(builder, loc, pos: kImaginaryPosInComplexNumberStruct, ptr: imaginary);
54}
55
56Value ComplexStructBuilder::imaginary(OpBuilder &builder, Location loc) {
57 return extractPtr(builder, loc, pos: kImaginaryPosInComplexNumberStruct);
58}
59
60//===----------------------------------------------------------------------===//
61// Conversion patterns.
62//===----------------------------------------------------------------------===//
63
64namespace {
65
66struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> {
67 using ConvertOpToLLVMPattern<complex::AbsOp>::ConvertOpToLLVMPattern;
68
69 LogicalResult
70 matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
71 ConversionPatternRewriter &rewriter) const override {
72 auto loc = op.getLoc();
73
74 ComplexStructBuilder complexStruct(adaptor.getComplex());
75 Value real = complexStruct.real(builder&: rewriter, loc: op.getLoc());
76 Value imag = complexStruct.imaginary(builder&: rewriter, loc: op.getLoc());
77
78 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
79 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
80 op.getContext(),
81 convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
82 Value sqNorm = rewriter.create<LLVM::FAddOp>(
83 loc, rewriter.create<LLVM::FMulOp>(loc, real, real, fmf),
84 rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf);
85
86 rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm);
87 return success();
88 }
89};
90
91struct ConstantOpLowering : public ConvertOpToLLVMPattern<complex::ConstantOp> {
92 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
93
94 LogicalResult
95 matchAndRewrite(complex::ConstantOp op, OpAdaptor adaptor,
96 ConversionPatternRewriter &rewriter) const override {
97 return LLVM::detail::oneToOneRewrite(
98 op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(),
99 op->getAttrs(), *getTypeConverter(), rewriter);
100 }
101};
102
103struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> {
104 using ConvertOpToLLVMPattern<complex::CreateOp>::ConvertOpToLLVMPattern;
105
106 LogicalResult
107 matchAndRewrite(complex::CreateOp complexOp, OpAdaptor adaptor,
108 ConversionPatternRewriter &rewriter) const override {
109 // Pack real and imaginary part in a complex number struct.
110 auto loc = complexOp.getLoc();
111 auto structType = typeConverter->convertType(complexOp.getType());
112 auto complexStruct = ComplexStructBuilder::undef(builder&: rewriter, loc: loc, type: structType);
113 complexStruct.setReal(rewriter, loc, adaptor.getReal());
114 complexStruct.setImaginary(rewriter, loc, adaptor.getImaginary());
115
116 rewriter.replaceOp(complexOp, {complexStruct});
117 return success();
118 }
119};
120
121struct ReOpConversion : public ConvertOpToLLVMPattern<complex::ReOp> {
122 using ConvertOpToLLVMPattern<complex::ReOp>::ConvertOpToLLVMPattern;
123
124 LogicalResult
125 matchAndRewrite(complex::ReOp op, OpAdaptor adaptor,
126 ConversionPatternRewriter &rewriter) const override {
127 // Extract real part from the complex number struct.
128 ComplexStructBuilder complexStruct(adaptor.getComplex());
129 Value real = complexStruct.real(builder&: rewriter, loc: op.getLoc());
130 rewriter.replaceOp(op, real);
131
132 return success();
133 }
134};
135
136struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> {
137 using ConvertOpToLLVMPattern<complex::ImOp>::ConvertOpToLLVMPattern;
138
139 LogicalResult
140 matchAndRewrite(complex::ImOp op, OpAdaptor adaptor,
141 ConversionPatternRewriter &rewriter) const override {
142 // Extract imaginary part from the complex number struct.
143 ComplexStructBuilder complexStruct(adaptor.getComplex());
144 Value imaginary = complexStruct.imaginary(builder&: rewriter, loc: op.getLoc());
145 rewriter.replaceOp(op, imaginary);
146
147 return success();
148 }
149};
150
151struct BinaryComplexOperands {
152 std::complex<Value> lhs;
153 std::complex<Value> rhs;
154};
155
156template <typename OpTy>
157BinaryComplexOperands
158unpackBinaryComplexOperands(OpTy op, typename OpTy::Adaptor adaptor,
159 ConversionPatternRewriter &rewriter) {
160 auto loc = op.getLoc();
161
162 // Extract real and imaginary values from operands.
163 BinaryComplexOperands unpacked;
164 ComplexStructBuilder lhs(adaptor.getLhs());
165 unpacked.lhs.real(lhs.real(builder&: rewriter, loc));
166 unpacked.lhs.imag(lhs.imaginary(builder&: rewriter, loc));
167 ComplexStructBuilder rhs(adaptor.getRhs());
168 unpacked.rhs.real(rhs.real(builder&: rewriter, loc));
169 unpacked.rhs.imag(rhs.imaginary(builder&: rewriter, loc));
170
171 return unpacked;
172}
173
174struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> {
175 using ConvertOpToLLVMPattern<complex::AddOp>::ConvertOpToLLVMPattern;
176
177 LogicalResult
178 matchAndRewrite(complex::AddOp op, OpAdaptor adaptor,
179 ConversionPatternRewriter &rewriter) const override {
180 auto loc = op.getLoc();
181 BinaryComplexOperands arg =
182 unpackBinaryComplexOperands<complex::AddOp>(op, adaptor, rewriter);
183
184 // Initialize complex number struct for result.
185 auto structType = typeConverter->convertType(op.getType());
186 auto result = ComplexStructBuilder::undef(builder&: rewriter, loc: loc, type: structType);
187
188 // Emit IR to add complex numbers.
189 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
190 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
191 op.getContext(),
192 convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
193 Value real =
194 rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
195 Value imag =
196 rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
197 result.setReal(rewriter, loc, real);
198 result.setImaginary(rewriter, loc, imag);
199
200 rewriter.replaceOp(op, {result});
201 return success();
202 }
203};
204
205struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> {
206 using ConvertOpToLLVMPattern<complex::DivOp>::ConvertOpToLLVMPattern;
207
208 LogicalResult
209 matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
210 ConversionPatternRewriter &rewriter) const override {
211 auto loc = op.getLoc();
212 BinaryComplexOperands arg =
213 unpackBinaryComplexOperands<complex::DivOp>(op, adaptor, rewriter);
214
215 // Initialize complex number struct for result.
216 auto structType = typeConverter->convertType(op.getType());
217 auto result = ComplexStructBuilder::undef(builder&: rewriter, loc: loc, type: structType);
218
219 // Emit IR to add complex numbers.
220 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
221 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
222 op.getContext(),
223 convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
224 Value rhsRe = arg.rhs.real();
225 Value rhsIm = arg.rhs.imag();
226 Value lhsRe = arg.lhs.real();
227 Value lhsIm = arg.lhs.imag();
228
229 Value rhsSqNorm = rewriter.create<LLVM::FAddOp>(
230 loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, rhsRe, fmf),
231 rewriter.create<LLVM::FMulOp>(loc, rhsIm, rhsIm, fmf), fmf);
232
233 Value resultReal = rewriter.create<LLVM::FAddOp>(
234 loc, rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRe, fmf),
235 rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsIm, fmf), fmf);
236
237 Value resultImag = rewriter.create<LLVM::FSubOp>(
238 loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
239 rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
240
241 result.setReal(
242 rewriter, loc,
243 rewriter.create<LLVM::FDivOp>(loc, resultReal, rhsSqNorm, fmf));
244 result.setImaginary(
245 rewriter, loc,
246 rewriter.create<LLVM::FDivOp>(loc, resultImag, rhsSqNorm, fmf));
247
248 rewriter.replaceOp(op, {result});
249 return success();
250 }
251};
252
253struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {
254 using ConvertOpToLLVMPattern<complex::MulOp>::ConvertOpToLLVMPattern;
255
256 LogicalResult
257 matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
258 ConversionPatternRewriter &rewriter) const override {
259 auto loc = op.getLoc();
260 BinaryComplexOperands arg =
261 unpackBinaryComplexOperands<complex::MulOp>(op, adaptor, rewriter);
262
263 // Initialize complex number struct for result.
264 auto structType = typeConverter->convertType(op.getType());
265 auto result = ComplexStructBuilder::undef(builder&: rewriter, loc: loc, type: structType);
266
267 // Emit IR to add complex numbers.
268 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
269 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
270 op.getContext(),
271 convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
272 Value rhsRe = arg.rhs.real();
273 Value rhsIm = arg.rhs.imag();
274 Value lhsRe = arg.lhs.real();
275 Value lhsIm = arg.lhs.imag();
276
277 Value real = rewriter.create<LLVM::FSubOp>(
278 loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, lhsRe, fmf),
279 rewriter.create<LLVM::FMulOp>(loc, rhsIm, lhsIm, fmf), fmf);
280
281 Value imag = rewriter.create<LLVM::FAddOp>(
282 loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
283 rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
284
285 result.setReal(rewriter, loc, real);
286 result.setImaginary(rewriter, loc, imag);
287
288 rewriter.replaceOp(op, {result});
289 return success();
290 }
291};
292
293struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
294 using ConvertOpToLLVMPattern<complex::SubOp>::ConvertOpToLLVMPattern;
295
296 LogicalResult
297 matchAndRewrite(complex::SubOp op, OpAdaptor adaptor,
298 ConversionPatternRewriter &rewriter) const override {
299 auto loc = op.getLoc();
300 BinaryComplexOperands arg =
301 unpackBinaryComplexOperands<complex::SubOp>(op, adaptor, rewriter);
302
303 // Initialize complex number struct for result.
304 auto structType = typeConverter->convertType(op.getType());
305 auto result = ComplexStructBuilder::undef(builder&: rewriter, loc: loc, type: structType);
306
307 // Emit IR to substract complex numbers.
308 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
309 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
310 op.getContext(),
311 convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
312 Value real =
313 rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
314 Value imag =
315 rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
316 result.setReal(rewriter, loc, real);
317 result.setImaginary(rewriter, loc, imag);
318
319 rewriter.replaceOp(op, {result});
320 return success();
321 }
322};
323} // namespace
324
325void mlir::populateComplexToLLVMConversionPatterns(
326 LLVMTypeConverter &converter, RewritePatternSet &patterns) {
327 // clang-format off
328 patterns.add<
329 AbsOpConversion,
330 AddOpConversion,
331 ConstantOpLowering,
332 CreateOpConversion,
333 DivOpConversion,
334 ImOpConversion,
335 MulOpConversion,
336 ReOpConversion,
337 SubOpConversion
338 >(arg&: converter);
339 // clang-format on
340}
341
342namespace {
343struct ConvertComplexToLLVMPass
344 : public impl::ConvertComplexToLLVMPassBase<ConvertComplexToLLVMPass> {
345 using Base::Base;
346
347 void runOnOperation() override;
348};
349} // namespace
350
351void ConvertComplexToLLVMPass::runOnOperation() {
352 // Convert to the LLVM IR dialect using the converter defined above.
353 RewritePatternSet patterns(&getContext());
354 LLVMTypeConverter converter(&getContext());
355 populateComplexToLLVMConversionPatterns(converter, patterns);
356
357 LLVMConversionTarget target(getContext());
358 target.addIllegalDialect<complex::ComplexDialect>();
359 if (failed(
360 applyPartialConversion(getOperation(), target, std::move(patterns))))
361 signalPassFailure();
362}
363
364//===----------------------------------------------------------------------===//
365// ConvertToLLVMPatternInterface implementation
366//===----------------------------------------------------------------------===//
367
368namespace {
369/// Implement the interface to convert MemRef to LLVM.
370struct ComplexToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
371 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
372 void loadDependentDialects(MLIRContext *context) const final {
373 context->loadDialect<LLVM::LLVMDialect>();
374 }
375
376 /// Hook for derived dialect interface to provide conversion patterns
377 /// and mark dialect legal for the conversion target.
378 void populateConvertToLLVMConversionPatterns(
379 ConversionTarget &target, LLVMTypeConverter &typeConverter,
380 RewritePatternSet &patterns) const final {
381 populateComplexToLLVMConversionPatterns(converter&: typeConverter, patterns);
382 }
383};
384} // namespace
385
386void mlir::registerConvertComplexToLLVMInterface(DialectRegistry &registry) {
387 registry.addExtension(
388 extensionFn: +[](MLIRContext *ctx, complex::ComplexDialect *dialect) {
389 dialect->addInterfaces<ComplexToLLVMDialectInterface>();
390 });
391}
392

source code of mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp