1//===- ComplexToLLVM.cpp - conversion from Complex to LLVM dialect --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
10
11#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
12#include "mlir/Conversion/ComplexCommon/DivisionConverter.h"
13#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
14#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
15#include "mlir/Conversion/LLVMCommon/Pattern.h"
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/Complex/IR/Complex.h"
18#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19
20namespace mlir {
21#define GEN_PASS_DEF_CONVERTCOMPLEXTOLLVMPASS
22#include "mlir/Conversion/Passes.h.inc"
23} // namespace mlir
24
25using namespace mlir;
26using namespace mlir::LLVM;
27using namespace mlir::arith;
28
29//===----------------------------------------------------------------------===//
30// ComplexStructBuilder implementation.
31//===----------------------------------------------------------------------===//
32
// Field indices of the real and imaginary parts inside the struct value that
// ComplexStructBuilder wraps. The imaginary part immediately follows the real
// part.
static constexpr unsigned kRealPosInComplexNumberStruct = 0;
static constexpr unsigned kImaginaryPosInComplexNumberStruct =
    kRealPosInComplexNumberStruct + 1;
35
36ComplexStructBuilder ComplexStructBuilder::poison(OpBuilder &builder,
37 Location loc, Type type) {
38 Value val = builder.create<LLVM::PoisonOp>(location: loc, args&: type);
39 return ComplexStructBuilder(val);
40}
41
42void ComplexStructBuilder::setReal(OpBuilder &builder, Location loc,
43 Value real) {
44 setPtr(builder, loc, pos: kRealPosInComplexNumberStruct, ptr: real);
45}
46
47Value ComplexStructBuilder::real(OpBuilder &builder, Location loc) {
48 return extractPtr(builder, loc, pos: kRealPosInComplexNumberStruct);
49}
50
51void ComplexStructBuilder::setImaginary(OpBuilder &builder, Location loc,
52 Value imaginary) {
53 setPtr(builder, loc, pos: kImaginaryPosInComplexNumberStruct, ptr: imaginary);
54}
55
56Value ComplexStructBuilder::imaginary(OpBuilder &builder, Location loc) {
57 return extractPtr(builder, loc, pos: kImaginaryPosInComplexNumberStruct);
58}
59
60//===----------------------------------------------------------------------===//
61// Conversion patterns.
62//===----------------------------------------------------------------------===//
63
64namespace {
65
66struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> {
67 using ConvertOpToLLVMPattern<complex::AbsOp>::ConvertOpToLLVMPattern;
68
69 LogicalResult
70 matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
71 ConversionPatternRewriter &rewriter) const override {
72 auto loc = op.getLoc();
73
74 ComplexStructBuilder complexStruct(adaptor.getComplex());
75 Value real = complexStruct.real(builder&: rewriter, loc: op.getLoc());
76 Value imag = complexStruct.imaginary(builder&: rewriter, loc: op.getLoc());
77
78 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
79 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
80 context: op.getContext(),
81 value: convertArithFastMathFlagsToLLVM(arithFMF: complexFMFAttr.getValue()));
82 Value sqNorm = rewriter.create<LLVM::FAddOp>(
83 location: loc, args: rewriter.create<LLVM::FMulOp>(location: loc, args&: real, args&: real, args&: fmf),
84 args: rewriter.create<LLVM::FMulOp>(location: loc, args&: imag, args&: imag, args&: fmf), args&: fmf);
85
86 rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, args&: sqNorm);
87 return success();
88 }
89};
90
91struct ConstantOpLowering : public ConvertOpToLLVMPattern<complex::ConstantOp> {
92 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
93
94 LogicalResult
95 matchAndRewrite(complex::ConstantOp op, OpAdaptor adaptor,
96 ConversionPatternRewriter &rewriter) const override {
97 return LLVM::detail::oneToOneRewrite(
98 op, targetOp: LLVM::ConstantOp::getOperationName(), operands: adaptor.getOperands(),
99 targetAttrs: op->getAttrs(), typeConverter: *getTypeConverter(), rewriter);
100 }
101};
102
103struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> {
104 using ConvertOpToLLVMPattern<complex::CreateOp>::ConvertOpToLLVMPattern;
105
106 LogicalResult
107 matchAndRewrite(complex::CreateOp complexOp, OpAdaptor adaptor,
108 ConversionPatternRewriter &rewriter) const override {
109 // Pack real and imaginary part in a complex number struct.
110 auto loc = complexOp.getLoc();
111 auto structType = typeConverter->convertType(t: complexOp.getType());
112 auto complexStruct =
113 ComplexStructBuilder::poison(builder&: rewriter, loc, type: structType);
114 complexStruct.setReal(builder&: rewriter, loc, real: adaptor.getReal());
115 complexStruct.setImaginary(builder&: rewriter, loc, imaginary: adaptor.getImaginary());
116
117 rewriter.replaceOp(op: complexOp, newValues: {complexStruct});
118 return success();
119 }
120};
121
122struct ReOpConversion : public ConvertOpToLLVMPattern<complex::ReOp> {
123 using ConvertOpToLLVMPattern<complex::ReOp>::ConvertOpToLLVMPattern;
124
125 LogicalResult
126 matchAndRewrite(complex::ReOp op, OpAdaptor adaptor,
127 ConversionPatternRewriter &rewriter) const override {
128 // Extract real part from the complex number struct.
129 ComplexStructBuilder complexStruct(adaptor.getComplex());
130 Value real = complexStruct.real(builder&: rewriter, loc: op.getLoc());
131 rewriter.replaceOp(op, newValues: real);
132
133 return success();
134 }
135};
136
137struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> {
138 using ConvertOpToLLVMPattern<complex::ImOp>::ConvertOpToLLVMPattern;
139
140 LogicalResult
141 matchAndRewrite(complex::ImOp op, OpAdaptor adaptor,
142 ConversionPatternRewriter &rewriter) const override {
143 // Extract imaginary part from the complex number struct.
144 ComplexStructBuilder complexStruct(adaptor.getComplex());
145 Value imaginary = complexStruct.imaginary(builder&: rewriter, loc: op.getLoc());
146 rewriter.replaceOp(op, newValues: imaginary);
147
148 return success();
149 }
150};
151
152struct BinaryComplexOperands {
153 std::complex<Value> lhs;
154 std::complex<Value> rhs;
155};
156
157template <typename OpTy>
158BinaryComplexOperands
159unpackBinaryComplexOperands(OpTy op, typename OpTy::Adaptor adaptor,
160 ConversionPatternRewriter &rewriter) {
161 auto loc = op.getLoc();
162
163 // Extract real and imaginary values from operands.
164 BinaryComplexOperands unpacked;
165 ComplexStructBuilder lhs(adaptor.getLhs());
166 unpacked.lhs.real(lhs.real(builder&: rewriter, loc));
167 unpacked.lhs.imag(lhs.imaginary(builder&: rewriter, loc));
168 ComplexStructBuilder rhs(adaptor.getRhs());
169 unpacked.rhs.real(rhs.real(builder&: rewriter, loc));
170 unpacked.rhs.imag(rhs.imaginary(builder&: rewriter, loc));
171
172 return unpacked;
173}
174
175struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> {
176 using ConvertOpToLLVMPattern<complex::AddOp>::ConvertOpToLLVMPattern;
177
178 LogicalResult
179 matchAndRewrite(complex::AddOp op, OpAdaptor adaptor,
180 ConversionPatternRewriter &rewriter) const override {
181 auto loc = op.getLoc();
182 BinaryComplexOperands arg =
183 unpackBinaryComplexOperands<complex::AddOp>(op, adaptor, rewriter);
184
185 // Initialize complex number struct for result.
186 auto structType = typeConverter->convertType(t: op.getType());
187 auto result = ComplexStructBuilder::poison(builder&: rewriter, loc, type: structType);
188
189 // Emit IR to add complex numbers.
190 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
191 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
192 context: op.getContext(),
193 value: convertArithFastMathFlagsToLLVM(arithFMF: complexFMFAttr.getValue()));
194 Value real =
195 rewriter.create<LLVM::FAddOp>(location: loc, args: arg.lhs.real(), args: arg.rhs.real(), args&: fmf);
196 Value imag =
197 rewriter.create<LLVM::FAddOp>(location: loc, args: arg.lhs.imag(), args: arg.rhs.imag(), args&: fmf);
198 result.setReal(builder&: rewriter, loc, real);
199 result.setImaginary(builder&: rewriter, loc, imaginary: imag);
200
201 rewriter.replaceOp(op, newValues: {result});
202 return success();
203 }
204};
205
206struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> {
207 DivOpConversion(const LLVMTypeConverter &converter,
208 complex::ComplexRangeFlags target)
209 : ConvertOpToLLVMPattern<complex::DivOp>(converter),
210 complexRange(target) {}
211
212 using ConvertOpToLLVMPattern<complex::DivOp>::ConvertOpToLLVMPattern;
213
214 LogicalResult
215 matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
216 ConversionPatternRewriter &rewriter) const override {
217 auto loc = op.getLoc();
218 BinaryComplexOperands arg =
219 unpackBinaryComplexOperands<complex::DivOp>(op, adaptor, rewriter);
220
221 // Initialize complex number struct for result.
222 auto structType = typeConverter->convertType(t: op.getType());
223 auto result = ComplexStructBuilder::poison(builder&: rewriter, loc, type: structType);
224
225 // Emit IR to add complex numbers.
226 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
227 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
228 context: op.getContext(),
229 value: convertArithFastMathFlagsToLLVM(arithFMF: complexFMFAttr.getValue()));
230 Value rhsRe = arg.rhs.real();
231 Value rhsIm = arg.rhs.imag();
232 Value lhsRe = arg.lhs.real();
233 Value lhsIm = arg.lhs.imag();
234
235 Value resultRe, resultIm;
236
237 if (complexRange == complex::ComplexRangeFlags::basic ||
238 complexRange == complex::ComplexRangeFlags::none) {
239 mlir::complex::convertDivToLLVMUsingAlgebraic(
240 rewriter, loc, lhsRe, lhsIm, rhsRe, rhsIm, fmf, resultRe: &resultRe, resultIm: &resultIm);
241 } else if (complexRange == complex::ComplexRangeFlags::improved) {
242 mlir::complex::convertDivToLLVMUsingRangeReduction(
243 rewriter, loc, lhsRe, lhsIm, rhsRe, rhsIm, fmf, resultRe: &resultRe, resultIm: &resultIm);
244 }
245
246 result.setReal(builder&: rewriter, loc, real: resultRe);
247 result.setImaginary(builder&: rewriter, loc, imaginary: resultIm);
248
249 rewriter.replaceOp(op, newValues: {result});
250 return success();
251 }
252
253private:
254 complex::ComplexRangeFlags complexRange;
255};
256
257struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {
258 using ConvertOpToLLVMPattern<complex::MulOp>::ConvertOpToLLVMPattern;
259
260 LogicalResult
261 matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
262 ConversionPatternRewriter &rewriter) const override {
263 auto loc = op.getLoc();
264 BinaryComplexOperands arg =
265 unpackBinaryComplexOperands<complex::MulOp>(op, adaptor, rewriter);
266
267 // Initialize complex number struct for result.
268 auto structType = typeConverter->convertType(t: op.getType());
269 auto result = ComplexStructBuilder::poison(builder&: rewriter, loc, type: structType);
270
271 // Emit IR to add complex numbers.
272 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
273 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
274 context: op.getContext(),
275 value: convertArithFastMathFlagsToLLVM(arithFMF: complexFMFAttr.getValue()));
276 Value rhsRe = arg.rhs.real();
277 Value rhsIm = arg.rhs.imag();
278 Value lhsRe = arg.lhs.real();
279 Value lhsIm = arg.lhs.imag();
280
281 Value real = rewriter.create<LLVM::FSubOp>(
282 location: loc, args: rewriter.create<LLVM::FMulOp>(location: loc, args&: rhsRe, args&: lhsRe, args&: fmf),
283 args: rewriter.create<LLVM::FMulOp>(location: loc, args&: rhsIm, args&: lhsIm, args&: fmf), args&: fmf);
284
285 Value imag = rewriter.create<LLVM::FAddOp>(
286 location: loc, args: rewriter.create<LLVM::FMulOp>(location: loc, args&: lhsIm, args&: rhsRe, args&: fmf),
287 args: rewriter.create<LLVM::FMulOp>(location: loc, args&: lhsRe, args&: rhsIm, args&: fmf), args&: fmf);
288
289 result.setReal(builder&: rewriter, loc, real);
290 result.setImaginary(builder&: rewriter, loc, imaginary: imag);
291
292 rewriter.replaceOp(op, newValues: {result});
293 return success();
294 }
295};
296
297struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
298 using ConvertOpToLLVMPattern<complex::SubOp>::ConvertOpToLLVMPattern;
299
300 LogicalResult
301 matchAndRewrite(complex::SubOp op, OpAdaptor adaptor,
302 ConversionPatternRewriter &rewriter) const override {
303 auto loc = op.getLoc();
304 BinaryComplexOperands arg =
305 unpackBinaryComplexOperands<complex::SubOp>(op, adaptor, rewriter);
306
307 // Initialize complex number struct for result.
308 auto structType = typeConverter->convertType(t: op.getType());
309 auto result = ComplexStructBuilder::poison(builder&: rewriter, loc, type: structType);
310
311 // Emit IR to substract complex numbers.
312 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
313 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
314 context: op.getContext(),
315 value: convertArithFastMathFlagsToLLVM(arithFMF: complexFMFAttr.getValue()));
316 Value real =
317 rewriter.create<LLVM::FSubOp>(location: loc, args: arg.lhs.real(), args: arg.rhs.real(), args&: fmf);
318 Value imag =
319 rewriter.create<LLVM::FSubOp>(location: loc, args: arg.lhs.imag(), args: arg.rhs.imag(), args&: fmf);
320 result.setReal(builder&: rewriter, loc, real);
321 result.setImaginary(builder&: rewriter, loc, imaginary: imag);
322
323 rewriter.replaceOp(op, newValues: {result});
324 return success();
325 }
326};
327} // namespace
328
329void mlir::populateComplexToLLVMConversionPatterns(
330 const LLVMTypeConverter &converter, RewritePatternSet &patterns,
331 complex::ComplexRangeFlags complexRange) {
332 // clang-format off
333 patterns.add<
334 AbsOpConversion,
335 AddOpConversion,
336 ConstantOpLowering,
337 CreateOpConversion,
338 ImOpConversion,
339 MulOpConversion,
340 ReOpConversion,
341 SubOpConversion
342 >(arg: converter);
343
344 patterns.add<DivOpConversion>(arg: converter, args&: complexRange);
345 // clang-format on
346}
347
348namespace {
349struct ConvertComplexToLLVMPass
350 : public impl::ConvertComplexToLLVMPassBase<ConvertComplexToLLVMPass> {
351 using Base::Base;
352
353 void runOnOperation() override;
354};
355} // namespace
356
357void ConvertComplexToLLVMPass::runOnOperation() {
358 // Convert to the LLVM IR dialect using the converter defined above.
359 RewritePatternSet patterns(&getContext());
360 LLVMTypeConverter converter(&getContext());
361 populateComplexToLLVMConversionPatterns(converter, patterns, complexRange);
362
363 LLVMConversionTarget target(getContext());
364 target.addIllegalDialect<complex::ComplexDialect>();
365 if (failed(
366 Result: applyPartialConversion(op: getOperation(), target, patterns: std::move(patterns))))
367 signalPassFailure();
368}
369
370//===----------------------------------------------------------------------===//
371// ConvertToLLVMPatternInterface implementation
372//===----------------------------------------------------------------------===//
373
374namespace {
375/// Implement the interface to convert MemRef to LLVM.
376struct ComplexToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
377 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
378 void loadDependentDialects(MLIRContext *context) const final {
379 context->loadDialect<LLVM::LLVMDialect>();
380 }
381
382 /// Hook for derived dialect interface to provide conversion patterns
383 /// and mark dialect legal for the conversion target.
384 void populateConvertToLLVMConversionPatterns(
385 ConversionTarget &target, LLVMTypeConverter &typeConverter,
386 RewritePatternSet &patterns) const final {
387 populateComplexToLLVMConversionPatterns(converter: typeConverter, patterns);
388 }
389};
390} // namespace
391
392void mlir::registerConvertComplexToLLVMInterface(DialectRegistry &registry) {
393 registry.addExtension(
394 extensionFn: +[](MLIRContext *ctx, complex::ComplexDialect *dialect) {
395 dialect->addInterfaces<ComplexToLLVMDialectInterface>();
396 });
397}
398

source code of mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp