1//===- ComplexToLLVM.cpp - conversion from Complex to LLVM dialect --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
10
11#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
12#include "mlir/Conversion/ComplexCommon/DivisionConverter.h"
13#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
14#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
15#include "mlir/Conversion/LLVMCommon/Pattern.h"
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/Complex/IR/Complex.h"
18#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19#include "mlir/Pass/Pass.h"
20
21namespace mlir {
22#define GEN_PASS_DEF_CONVERTCOMPLEXTOLLVMPASS
23#include "mlir/Conversion/Passes.h.inc"
24} // namespace mlir
25
26using namespace mlir;
27using namespace mlir::LLVM;
28using namespace mlir::arith;
29
30//===----------------------------------------------------------------------===//
31// ComplexStructBuilder implementation.
32//===----------------------------------------------------------------------===//
33
34static constexpr unsigned kRealPosInComplexNumberStruct = 0;
35static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1;
36
37ComplexStructBuilder ComplexStructBuilder::poison(OpBuilder &builder,
38 Location loc, Type type) {
39 Value val = builder.create<LLVM::PoisonOp>(loc, type);
40 return ComplexStructBuilder(val);
41}
42
43void ComplexStructBuilder::setReal(OpBuilder &builder, Location loc,
44 Value real) {
45 setPtr(builder, loc, pos: kRealPosInComplexNumberStruct, ptr: real);
46}
47
48Value ComplexStructBuilder::real(OpBuilder &builder, Location loc) {
49 return extractPtr(builder, loc, pos: kRealPosInComplexNumberStruct);
50}
51
52void ComplexStructBuilder::setImaginary(OpBuilder &builder, Location loc,
53 Value imaginary) {
54 setPtr(builder, loc, pos: kImaginaryPosInComplexNumberStruct, ptr: imaginary);
55}
56
57Value ComplexStructBuilder::imaginary(OpBuilder &builder, Location loc) {
58 return extractPtr(builder, loc, pos: kImaginaryPosInComplexNumberStruct);
59}
60
61//===----------------------------------------------------------------------===//
62// Conversion patterns.
63//===----------------------------------------------------------------------===//
64
65namespace {
66
67struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> {
68 using ConvertOpToLLVMPattern<complex::AbsOp>::ConvertOpToLLVMPattern;
69
70 LogicalResult
71 matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
72 ConversionPatternRewriter &rewriter) const override {
73 auto loc = op.getLoc();
74
75 ComplexStructBuilder complexStruct(adaptor.getComplex());
76 Value real = complexStruct.real(builder&: rewriter, loc: op.getLoc());
77 Value imag = complexStruct.imaginary(builder&: rewriter, loc: op.getLoc());
78
79 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
80 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
81 op.getContext(),
82 convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
83 Value sqNorm = rewriter.create<LLVM::FAddOp>(
84 loc, rewriter.create<LLVM::FMulOp>(loc, real, real, fmf),
85 rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf);
86
87 rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm);
88 return success();
89 }
90};
91
92struct ConstantOpLowering : public ConvertOpToLLVMPattern<complex::ConstantOp> {
93 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
94
95 LogicalResult
96 matchAndRewrite(complex::ConstantOp op, OpAdaptor adaptor,
97 ConversionPatternRewriter &rewriter) const override {
98 return LLVM::detail::oneToOneRewrite(
99 op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(),
100 op->getAttrs(), *getTypeConverter(), rewriter);
101 }
102};
103
104struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> {
105 using ConvertOpToLLVMPattern<complex::CreateOp>::ConvertOpToLLVMPattern;
106
107 LogicalResult
108 matchAndRewrite(complex::CreateOp complexOp, OpAdaptor adaptor,
109 ConversionPatternRewriter &rewriter) const override {
110 // Pack real and imaginary part in a complex number struct.
111 auto loc = complexOp.getLoc();
112 auto structType = typeConverter->convertType(complexOp.getType());
113 auto complexStruct =
114 ComplexStructBuilder::poison(builder&: rewriter, loc: loc, type: structType);
115 complexStruct.setReal(rewriter, loc, adaptor.getReal());
116 complexStruct.setImaginary(rewriter, loc, adaptor.getImaginary());
117
118 rewriter.replaceOp(complexOp, {complexStruct});
119 return success();
120 }
121};
122
123struct ReOpConversion : public ConvertOpToLLVMPattern<complex::ReOp> {
124 using ConvertOpToLLVMPattern<complex::ReOp>::ConvertOpToLLVMPattern;
125
126 LogicalResult
127 matchAndRewrite(complex::ReOp op, OpAdaptor adaptor,
128 ConversionPatternRewriter &rewriter) const override {
129 // Extract real part from the complex number struct.
130 ComplexStructBuilder complexStruct(adaptor.getComplex());
131 Value real = complexStruct.real(builder&: rewriter, loc: op.getLoc());
132 rewriter.replaceOp(op, real);
133
134 return success();
135 }
136};
137
138struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> {
139 using ConvertOpToLLVMPattern<complex::ImOp>::ConvertOpToLLVMPattern;
140
141 LogicalResult
142 matchAndRewrite(complex::ImOp op, OpAdaptor adaptor,
143 ConversionPatternRewriter &rewriter) const override {
144 // Extract imaginary part from the complex number struct.
145 ComplexStructBuilder complexStruct(adaptor.getComplex());
146 Value imaginary = complexStruct.imaginary(builder&: rewriter, loc: op.getLoc());
147 rewriter.replaceOp(op, imaginary);
148
149 return success();
150 }
151};
152
153struct BinaryComplexOperands {
154 std::complex<Value> lhs;
155 std::complex<Value> rhs;
156};
157
158template <typename OpTy>
159BinaryComplexOperands
160unpackBinaryComplexOperands(OpTy op, typename OpTy::Adaptor adaptor,
161 ConversionPatternRewriter &rewriter) {
162 auto loc = op.getLoc();
163
164 // Extract real and imaginary values from operands.
165 BinaryComplexOperands unpacked;
166 ComplexStructBuilder lhs(adaptor.getLhs());
167 unpacked.lhs.real(lhs.real(builder&: rewriter, loc));
168 unpacked.lhs.imag(lhs.imaginary(builder&: rewriter, loc));
169 ComplexStructBuilder rhs(adaptor.getRhs());
170 unpacked.rhs.real(rhs.real(builder&: rewriter, loc));
171 unpacked.rhs.imag(rhs.imaginary(builder&: rewriter, loc));
172
173 return unpacked;
174}
175
176struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> {
177 using ConvertOpToLLVMPattern<complex::AddOp>::ConvertOpToLLVMPattern;
178
179 LogicalResult
180 matchAndRewrite(complex::AddOp op, OpAdaptor adaptor,
181 ConversionPatternRewriter &rewriter) const override {
182 auto loc = op.getLoc();
183 BinaryComplexOperands arg =
184 unpackBinaryComplexOperands<complex::AddOp>(op, adaptor, rewriter);
185
186 // Initialize complex number struct for result.
187 auto structType = typeConverter->convertType(op.getType());
188 auto result = ComplexStructBuilder::poison(builder&: rewriter, loc: loc, type: structType);
189
190 // Emit IR to add complex numbers.
191 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
192 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
193 op.getContext(),
194 convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
195 Value real =
196 rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
197 Value imag =
198 rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
199 result.setReal(rewriter, loc, real);
200 result.setImaginary(rewriter, loc, imag);
201
202 rewriter.replaceOp(op, {result});
203 return success();
204 }
205};
206
207struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> {
208 DivOpConversion(const LLVMTypeConverter &converter,
209 complex::ComplexRangeFlags target)
210 : ConvertOpToLLVMPattern<complex::DivOp>(converter),
211 complexRange(target) {}
212
213 using ConvertOpToLLVMPattern<complex::DivOp>::ConvertOpToLLVMPattern;
214
215 LogicalResult
216 matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
217 ConversionPatternRewriter &rewriter) const override {
218 auto loc = op.getLoc();
219 BinaryComplexOperands arg =
220 unpackBinaryComplexOperands<complex::DivOp>(op, adaptor, rewriter);
221
222 // Initialize complex number struct for result.
223 auto structType = typeConverter->convertType(op.getType());
224 auto result = ComplexStructBuilder::poison(builder&: rewriter, loc: loc, type: structType);
225
226 // Emit IR to add complex numbers.
227 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
228 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
229 op.getContext(),
230 convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
231 Value rhsRe = arg.rhs.real();
232 Value rhsIm = arg.rhs.imag();
233 Value lhsRe = arg.lhs.real();
234 Value lhsIm = arg.lhs.imag();
235
236 Value resultRe, resultIm;
237
238 if (complexRange == complex::ComplexRangeFlags::basic ||
239 complexRange == complex::ComplexRangeFlags::none) {
240 mlir::complex::convertDivToLLVMUsingAlgebraic(
241 rewriter, loc, lhsRe, lhsIm, rhsRe, rhsIm, fmf, &resultRe, &resultIm);
242 } else if (complexRange == complex::ComplexRangeFlags::improved) {
243 mlir::complex::convertDivToLLVMUsingRangeReduction(
244 rewriter, loc, lhsRe, lhsIm, rhsRe, rhsIm, fmf, &resultRe, &resultIm);
245 }
246
247 result.setReal(rewriter, loc, resultRe);
248 result.setImaginary(rewriter, loc, resultIm);
249
250 rewriter.replaceOp(op, {result});
251 return success();
252 }
253
254private:
255 complex::ComplexRangeFlags complexRange;
256};
257
258struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {
259 using ConvertOpToLLVMPattern<complex::MulOp>::ConvertOpToLLVMPattern;
260
261 LogicalResult
262 matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
263 ConversionPatternRewriter &rewriter) const override {
264 auto loc = op.getLoc();
265 BinaryComplexOperands arg =
266 unpackBinaryComplexOperands<complex::MulOp>(op, adaptor, rewriter);
267
268 // Initialize complex number struct for result.
269 auto structType = typeConverter->convertType(op.getType());
270 auto result = ComplexStructBuilder::poison(builder&: rewriter, loc: loc, type: structType);
271
272 // Emit IR to add complex numbers.
273 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
274 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
275 op.getContext(),
276 convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
277 Value rhsRe = arg.rhs.real();
278 Value rhsIm = arg.rhs.imag();
279 Value lhsRe = arg.lhs.real();
280 Value lhsIm = arg.lhs.imag();
281
282 Value real = rewriter.create<LLVM::FSubOp>(
283 loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, lhsRe, fmf),
284 rewriter.create<LLVM::FMulOp>(loc, rhsIm, lhsIm, fmf), fmf);
285
286 Value imag = rewriter.create<LLVM::FAddOp>(
287 loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
288 rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
289
290 result.setReal(rewriter, loc, real);
291 result.setImaginary(rewriter, loc, imag);
292
293 rewriter.replaceOp(op, {result});
294 return success();
295 }
296};
297
298struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
299 using ConvertOpToLLVMPattern<complex::SubOp>::ConvertOpToLLVMPattern;
300
301 LogicalResult
302 matchAndRewrite(complex::SubOp op, OpAdaptor adaptor,
303 ConversionPatternRewriter &rewriter) const override {
304 auto loc = op.getLoc();
305 BinaryComplexOperands arg =
306 unpackBinaryComplexOperands<complex::SubOp>(op, adaptor, rewriter);
307
308 // Initialize complex number struct for result.
309 auto structType = typeConverter->convertType(op.getType());
310 auto result = ComplexStructBuilder::poison(builder&: rewriter, loc: loc, type: structType);
311
312 // Emit IR to substract complex numbers.
313 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
314 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
315 op.getContext(),
316 convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
317 Value real =
318 rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
319 Value imag =
320 rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
321 result.setReal(rewriter, loc, real);
322 result.setImaginary(rewriter, loc, imag);
323
324 rewriter.replaceOp(op, {result});
325 return success();
326 }
327};
328} // namespace
329
330void mlir::populateComplexToLLVMConversionPatterns(
331 const LLVMTypeConverter &converter, RewritePatternSet &patterns,
332 complex::ComplexRangeFlags complexRange) {
333 // clang-format off
334 patterns.add<
335 AbsOpConversion,
336 AddOpConversion,
337 ConstantOpLowering,
338 CreateOpConversion,
339 ImOpConversion,
340 MulOpConversion,
341 ReOpConversion,
342 SubOpConversion
343 >(arg: converter);
344
345 patterns.add<DivOpConversion>(converter, complexRange);
346 // clang-format on
347}
348
349namespace {
350struct ConvertComplexToLLVMPass
351 : public impl::ConvertComplexToLLVMPassBase<ConvertComplexToLLVMPass> {
352 using Base::Base;
353
354 void runOnOperation() override;
355};
356} // namespace
357
358void ConvertComplexToLLVMPass::runOnOperation() {
359 // Convert to the LLVM IR dialect using the converter defined above.
360 RewritePatternSet patterns(&getContext());
361 LLVMTypeConverter converter(&getContext());
362 populateComplexToLLVMConversionPatterns(converter, patterns, complexRange);
363
364 LLVMConversionTarget target(getContext());
365 target.addIllegalDialect<complex::ComplexDialect>();
366 if (failed(
367 applyPartialConversion(getOperation(), target, std::move(patterns))))
368 signalPassFailure();
369}
370
371//===----------------------------------------------------------------------===//
372// ConvertToLLVMPatternInterface implementation
373//===----------------------------------------------------------------------===//
374
375namespace {
376/// Implement the interface to convert MemRef to LLVM.
377struct ComplexToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
378 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
379 void loadDependentDialects(MLIRContext *context) const final {
380 context->loadDialect<LLVM::LLVMDialect>();
381 }
382
383 /// Hook for derived dialect interface to provide conversion patterns
384 /// and mark dialect legal for the conversion target.
385 void populateConvertToLLVMConversionPatterns(
386 ConversionTarget &target, LLVMTypeConverter &typeConverter,
387 RewritePatternSet &patterns) const final {
388 populateComplexToLLVMConversionPatterns(converter: typeConverter, patterns);
389 }
390};
391} // namespace
392
393void mlir::registerConvertComplexToLLVMInterface(DialectRegistry &registry) {
394 registry.addExtension(
395 extensionFn: +[](MLIRContext *ctx, complex::ComplexDialect *dialect) {
396 dialect->addInterfaces<ComplexToLLVMDialectInterface>();
397 });
398}
399

// Provided by KDAB
//
// Privacy Policy
// Learn to use CMake with our Intro Training
// Find out more
//
// source code of mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp