1//===- ArithToLLVM.cpp - Arithmetic to LLVM dialect conversion -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
10
11#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
12#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
13#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
14#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/Arith/Transforms/Passes.h"
17#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
18#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19#include "mlir/IR/TypeUtilities.h"
20#include "mlir/Pass/Pass.h"
21#include <type_traits>
22
23namespace mlir {
24#define GEN_PASS_DEF_ARITHTOLLVMCONVERSIONPASS
25#include "mlir/Conversion/Passes.h.inc"
26} // namespace mlir
27
28using namespace mlir;
29
30namespace {
31
32/// Operations whose conversion will depend on whether they are passed a
33/// rounding mode attribute or not.
34///
35/// `SourceOp` is the source operation; `TargetOp`, the operation it will lower
36/// to; `AttrConvert` is the attribute conversion to convert the rounding mode
37/// attribute.
38template <typename SourceOp, typename TargetOp, bool Constrained,
39 template <typename, typename> typename AttrConvert =
40 AttrConvertPassThrough>
41struct ConstrainedVectorConvertToLLVMPattern
42 : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert> {
43 using VectorConvertToLLVMPattern<SourceOp, TargetOp,
44 AttrConvert>::VectorConvertToLLVMPattern;
45
46 LogicalResult
47 matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
48 ConversionPatternRewriter &rewriter) const override {
49 if (Constrained != static_cast<bool>(op.getRoundingModeAttr()))
50 return failure();
51 return VectorConvertToLLVMPattern<SourceOp, TargetOp,
52 AttrConvert>::matchAndRewrite(op, adaptor,
53 rewriter);
54 }
55};
56
57/// No-op bitcast. Propagate type input arg if converted source and dest types
58/// are the same.
59struct IdentityBitcastLowering final
60 : public OpConversionPattern<arith::BitcastOp> {
61 using OpConversionPattern::OpConversionPattern;
62
63 LogicalResult
64 matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor,
65 ConversionPatternRewriter &rewriter) const final {
66 Value src = adaptor.getIn();
67 Type resultType = getTypeConverter()->convertType(op.getType());
68 if (src.getType() != resultType)
69 return rewriter.notifyMatchFailure(op, "Types are different");
70
71 rewriter.replaceOp(op, src);
72 return success();
73 }
74};
75
76//===----------------------------------------------------------------------===//
77// Straightforward Op Lowerings
78//===----------------------------------------------------------------------===//
79
80using AddFOpLowering =
81 VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
82 arith::AttrConvertFastMathToLLVM>;
83using AddIOpLowering =
84 VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp,
85 arith::AttrConvertOverflowToLLVM>;
86using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>;
87using BitcastOpLowering =
88 VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
89using DivFOpLowering =
90 VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
91 arith::AttrConvertFastMathToLLVM>;
92using DivSIOpLowering =
93 VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>;
94using DivUIOpLowering =
95 VectorConvertToLLVMPattern<arith::DivUIOp, LLVM::UDivOp>;
96using ExtFOpLowering = VectorConvertToLLVMPattern<arith::ExtFOp, LLVM::FPExtOp>;
97using ExtSIOpLowering =
98 VectorConvertToLLVMPattern<arith::ExtSIOp, LLVM::SExtOp>;
99using ExtUIOpLowering =
100 VectorConvertToLLVMPattern<arith::ExtUIOp, LLVM::ZExtOp>;
101using FPToSIOpLowering =
102 VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp>;
103using FPToUIOpLowering =
104 VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>;
105using MaximumFOpLowering =
106 VectorConvertToLLVMPattern<arith::MaximumFOp, LLVM::MaximumOp,
107 arith::AttrConvertFastMathToLLVM>;
108using MaxNumFOpLowering =
109 VectorConvertToLLVMPattern<arith::MaxNumFOp, LLVM::MaxNumOp,
110 arith::AttrConvertFastMathToLLVM>;
111using MaxSIOpLowering =
112 VectorConvertToLLVMPattern<arith::MaxSIOp, LLVM::SMaxOp>;
113using MaxUIOpLowering =
114 VectorConvertToLLVMPattern<arith::MaxUIOp, LLVM::UMaxOp>;
115using MinimumFOpLowering =
116 VectorConvertToLLVMPattern<arith::MinimumFOp, LLVM::MinimumOp,
117 arith::AttrConvertFastMathToLLVM>;
118using MinNumFOpLowering =
119 VectorConvertToLLVMPattern<arith::MinNumFOp, LLVM::MinNumOp,
120 arith::AttrConvertFastMathToLLVM>;
121using MinSIOpLowering =
122 VectorConvertToLLVMPattern<arith::MinSIOp, LLVM::SMinOp>;
123using MinUIOpLowering =
124 VectorConvertToLLVMPattern<arith::MinUIOp, LLVM::UMinOp>;
125using MulFOpLowering =
126 VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
127 arith::AttrConvertFastMathToLLVM>;
128using MulIOpLowering =
129 VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp,
130 arith::AttrConvertOverflowToLLVM>;
131using NegFOpLowering =
132 VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
133 arith::AttrConvertFastMathToLLVM>;
134using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>;
135using RemFOpLowering =
136 VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
137 arith::AttrConvertFastMathToLLVM>;
138using RemSIOpLowering =
139 VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>;
140using RemUIOpLowering =
141 VectorConvertToLLVMPattern<arith::RemUIOp, LLVM::URemOp>;
142using SelectOpLowering =
143 VectorConvertToLLVMPattern<arith::SelectOp, LLVM::SelectOp>;
144using ShLIOpLowering =
145 VectorConvertToLLVMPattern<arith::ShLIOp, LLVM::ShlOp,
146 arith::AttrConvertOverflowToLLVM>;
147using ShRSIOpLowering =
148 VectorConvertToLLVMPattern<arith::ShRSIOp, LLVM::AShrOp>;
149using ShRUIOpLowering =
150 VectorConvertToLLVMPattern<arith::ShRUIOp, LLVM::LShrOp>;
151using SIToFPOpLowering =
152 VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>;
153using SubFOpLowering =
154 VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
155 arith::AttrConvertFastMathToLLVM>;
156using SubIOpLowering =
157 VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
158 arith::AttrConvertOverflowToLLVM>;
159using TruncFOpLowering =
160 ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
161 false>;
162using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<
163 arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, true,
164 arith::AttrConverterConstrainedFPToLLVM>;
165using TruncIOpLowering =
166 VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp>;
167using UIToFPOpLowering =
168 VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp>;
169using XOrIOpLowering = VectorConvertToLLVMPattern<arith::XOrIOp, LLVM::XOrOp>;
170
171//===----------------------------------------------------------------------===//
172// Op Lowering Patterns
173//===----------------------------------------------------------------------===//
174
175/// Directly lower to LLVM op.
176struct ConstantOpLowering : public ConvertOpToLLVMPattern<arith::ConstantOp> {
177 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
178
179 LogicalResult
180 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
181 ConversionPatternRewriter &rewriter) const override;
182};
183
184/// The lowering of index_cast becomes an integer conversion since index
185/// becomes an integer. If the bit width of the source and target integer
186/// types is the same, just erase the cast. If the target type is wider,
187/// sign-extend the value, otherwise truncate it.
188template <typename OpTy, typename ExtCastTy>
189struct IndexCastOpLowering : public ConvertOpToLLVMPattern<OpTy> {
190 using ConvertOpToLLVMPattern<OpTy>::ConvertOpToLLVMPattern;
191
192 LogicalResult
193 matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
194 ConversionPatternRewriter &rewriter) const override;
195};
196
197using IndexCastOpSILowering =
198 IndexCastOpLowering<arith::IndexCastOp, LLVM::SExtOp>;
199using IndexCastOpUILowering =
200 IndexCastOpLowering<arith::IndexCastUIOp, LLVM::ZExtOp>;
201
202struct AddUIExtendedOpLowering
203 : public ConvertOpToLLVMPattern<arith::AddUIExtendedOp> {
204 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
205
206 LogicalResult
207 matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
208 ConversionPatternRewriter &rewriter) const override;
209};
210
211template <typename ArithMulOp, bool IsSigned>
212struct MulIExtendedOpLowering : public ConvertOpToLLVMPattern<ArithMulOp> {
213 using ConvertOpToLLVMPattern<ArithMulOp>::ConvertOpToLLVMPattern;
214
215 LogicalResult
216 matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
217 ConversionPatternRewriter &rewriter) const override;
218};
219
220using MulSIExtendedOpLowering =
221 MulIExtendedOpLowering<arith::MulSIExtendedOp, true>;
222using MulUIExtendedOpLowering =
223 MulIExtendedOpLowering<arith::MulUIExtendedOp, false>;
224
225struct CmpIOpLowering : public ConvertOpToLLVMPattern<arith::CmpIOp> {
226 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
227
228 LogicalResult
229 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
230 ConversionPatternRewriter &rewriter) const override;
231};
232
233struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
234 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
235
236 LogicalResult
237 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
238 ConversionPatternRewriter &rewriter) const override;
239};
240
241} // namespace
242
243//===----------------------------------------------------------------------===//
244// ConstantOpLowering
245//===----------------------------------------------------------------------===//
246
247LogicalResult
248ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
249 ConversionPatternRewriter &rewriter) const {
250 return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(),
251 adaptor.getOperands(), op->getAttrs(),
252 *getTypeConverter(), rewriter);
253}
254
255//===----------------------------------------------------------------------===//
256// IndexCastOpLowering
257//===----------------------------------------------------------------------===//
258
259template <typename OpTy, typename ExtCastTy>
260LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
261 OpTy op, typename OpTy::Adaptor adaptor,
262 ConversionPatternRewriter &rewriter) const {
263 Type resultType = op.getResult().getType();
264 Type targetElementType =
265 this->typeConverter->convertType(getElementTypeOrSelf(type: resultType));
266 Type sourceElementType =
267 this->typeConverter->convertType(getElementTypeOrSelf(op.getIn()));
268 unsigned targetBits = targetElementType.getIntOrFloatBitWidth();
269 unsigned sourceBits = sourceElementType.getIntOrFloatBitWidth();
270
271 if (targetBits == sourceBits) {
272 rewriter.replaceOp(op, adaptor.getIn());
273 return success();
274 }
275
276 // Handle the scalar and 1D vector cases.
277 Type operandType = adaptor.getIn().getType();
278 if (!isa<LLVM::LLVMArrayType>(operandType)) {
279 Type targetType = this->typeConverter->convertType(resultType);
280 if (targetBits < sourceBits)
281 rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType,
282 adaptor.getIn());
283 else
284 rewriter.replaceOpWithNewOp<ExtCastTy>(op, targetType, adaptor.getIn());
285 return success();
286 }
287
288 if (!isa<VectorType>(Val: resultType))
289 return rewriter.notifyMatchFailure(op, "expected vector result type");
290
291 return LLVM::detail::handleMultidimensionalVectors(
292 op: op.getOperation(), operands: adaptor.getOperands(), typeConverter: *(this->getTypeConverter()),
293 createOperand: [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
294 typename OpTy::Adaptor adaptor(operands);
295 if (targetBits < sourceBits) {
296 return rewriter.create<LLVM::TruncOp>(op.getLoc(), llvm1DVectorTy,
297 adaptor.getIn());
298 }
299 return rewriter.create<ExtCastTy>(op.getLoc(), llvm1DVectorTy,
300 adaptor.getIn());
301 },
302 rewriter);
303}
304
305//===----------------------------------------------------------------------===//
306// AddUIExtendedOpLowering
307//===----------------------------------------------------------------------===//
308
309LogicalResult AddUIExtendedOpLowering::matchAndRewrite(
310 arith::AddUIExtendedOp op, OpAdaptor adaptor,
311 ConversionPatternRewriter &rewriter) const {
312 Type operandType = adaptor.getLhs().getType();
313 Type sumResultType = op.getSum().getType();
314 Type overflowResultType = op.getOverflow().getType();
315
316 if (!LLVM::isCompatibleType(type: operandType))
317 return failure();
318
319 MLIRContext *ctx = rewriter.getContext();
320 Location loc = op.getLoc();
321
322 // Handle the scalar and 1D vector cases.
323 if (!isa<LLVM::LLVMArrayType>(operandType)) {
324 Type newOverflowType = typeConverter->convertType(overflowResultType);
325 Type structType =
326 LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType});
327 Value addOverflow = rewriter.create<LLVM::UAddWithOverflowOp>(
328 loc, structType, adaptor.getLhs(), adaptor.getRhs());
329 Value sumExtracted =
330 rewriter.create<LLVM::ExtractValueOp>(loc, addOverflow, 0);
331 Value overflowExtracted =
332 rewriter.create<LLVM::ExtractValueOp>(loc, addOverflow, 1);
333 rewriter.replaceOp(op, {sumExtracted, overflowExtracted});
334 return success();
335 }
336
337 if (!isa<VectorType>(Val: sumResultType))
338 return rewriter.notifyMatchFailure(arg&: loc, msg: "expected vector result types");
339
340 return rewriter.notifyMatchFailure(arg&: loc,
341 msg: "ND vector types are not supported yet");
342}
343
344//===----------------------------------------------------------------------===//
345// MulIExtendedOpLowering
346//===----------------------------------------------------------------------===//
347
348template <typename ArithMulOp, bool IsSigned>
349LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite(
350 ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
351 ConversionPatternRewriter &rewriter) const {
352 Type resultType = adaptor.getLhs().getType();
353
354 if (!LLVM::isCompatibleType(type: resultType))
355 return failure();
356
357 Location loc = op.getLoc();
358
359 // Handle the scalar and 1D vector cases. Because LLVM does not have a
360 // matching extended multiplication intrinsic, perform regular multiplication
361 // on operands zero-extended to i(2*N) bits, and truncate the results back to
362 // iN types.
363 if (!isa<LLVM::LLVMArrayType>(resultType)) {
364 // Shift amount necessary to extract the high bits from widened result.
365 TypedAttr shiftValAttr;
366
367 if (auto intTy = dyn_cast<IntegerType>(resultType)) {
368 unsigned resultBitwidth = intTy.getWidth();
369 auto attrTy = rewriter.getIntegerType(resultBitwidth * 2);
370 shiftValAttr = rewriter.getIntegerAttr(attrTy, resultBitwidth);
371 } else {
372 auto vecTy = cast<VectorType>(resultType);
373 unsigned resultBitwidth = vecTy.getElementTypeBitWidth();
374 auto attrTy = VectorType::get(
375 vecTy.getShape(), rewriter.getIntegerType(resultBitwidth * 2));
376 shiftValAttr = SplatElementsAttr::get(
377 attrTy, APInt(resultBitwidth * 2, resultBitwidth));
378 }
379 Type wideType = shiftValAttr.getType();
380 assert(LLVM::isCompatibleType(wideType) &&
381 "LLVM dialect should support all signless integer types");
382
383 using LLVMExtOp = std::conditional_t<IsSigned, LLVM::SExtOp, LLVM::ZExtOp>;
384 Value lhsExt = rewriter.create<LLVMExtOp>(loc, wideType, adaptor.getLhs());
385 Value rhsExt = rewriter.create<LLVMExtOp>(loc, wideType, adaptor.getRhs());
386 Value mulExt = rewriter.create<LLVM::MulOp>(loc, wideType, lhsExt, rhsExt);
387
388 // Split the 2*N-bit wide result into two N-bit values.
389 Value low = rewriter.create<LLVM::TruncOp>(loc, resultType, mulExt);
390 Value shiftVal = rewriter.create<LLVM::ConstantOp>(loc, shiftValAttr);
391 Value highExt = rewriter.create<LLVM::LShrOp>(loc, mulExt, shiftVal);
392 Value high = rewriter.create<LLVM::TruncOp>(loc, resultType, highExt);
393
394 rewriter.replaceOp(op, {low, high});
395 return success();
396 }
397
398 if (!isa<VectorType>(Val: resultType))
399 return rewriter.notifyMatchFailure(op, "expected vector result type");
400
401 return rewriter.notifyMatchFailure(op,
402 "ND vector types are not supported yet");
403}
404
//===----------------------------------------------------------------------===//
// CmpIOpLowering
//===----------------------------------------------------------------------===//

/// Maps an arith comparison predicate onto the corresponding LLVM dialect
/// predicate enum.
///
/// The arith and LLVM predicate enums share numerical values, so the mapping
/// is a value-preserving cast between the two enum types.
template <typename LLVMPredType, typename PredType>
static constexpr LLVMPredType convertCmpPredicate(PredType pred) {
  return static_cast<LLVMPredType>(pred);
}
415
416LogicalResult
417CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
418 ConversionPatternRewriter &rewriter) const {
419 Type operandType = adaptor.getLhs().getType();
420 Type resultType = op.getResult().getType();
421
422 // Handle the scalar and 1D vector cases.
423 if (!isa<LLVM::LLVMArrayType>(operandType)) {
424 rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
425 op, typeConverter->convertType(resultType),
426 convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
427 adaptor.getLhs(), adaptor.getRhs());
428 return success();
429 }
430
431 if (!isa<VectorType>(Val: resultType))
432 return rewriter.notifyMatchFailure(op, "expected vector result type");
433
434 return LLVM::detail::handleMultidimensionalVectors(
435 op: op.getOperation(), operands: adaptor.getOperands(), typeConverter: *getTypeConverter(),
436 createOperand: [&](Type llvm1DVectorTy, ValueRange operands) {
437 OpAdaptor adaptor(operands);
438 return rewriter.create<LLVM::ICmpOp>(
439 op.getLoc(), llvm1DVectorTy,
440 convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
441 adaptor.getLhs(), adaptor.getRhs());
442 },
443 rewriter);
444}
445
446//===----------------------------------------------------------------------===//
447// CmpFOpLowering
448//===----------------------------------------------------------------------===//
449
450LogicalResult
451CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
452 ConversionPatternRewriter &rewriter) const {
453 Type operandType = adaptor.getLhs().getType();
454 Type resultType = op.getResult().getType();
455 LLVM::FastmathFlags fmf =
456 arith::convertArithFastMathFlagsToLLVM(op.getFastmath());
457
458 // Handle the scalar and 1D vector cases.
459 if (!isa<LLVM::LLVMArrayType>(operandType)) {
460 rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
461 op, typeConverter->convertType(resultType),
462 convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
463 adaptor.getLhs(), adaptor.getRhs(), fmf);
464 return success();
465 }
466
467 if (!isa<VectorType>(Val: resultType))
468 return rewriter.notifyMatchFailure(op, "expected vector result type");
469
470 return LLVM::detail::handleMultidimensionalVectors(
471 op: op.getOperation(), operands: adaptor.getOperands(), typeConverter: *getTypeConverter(),
472 createOperand: [&](Type llvm1DVectorTy, ValueRange operands) {
473 OpAdaptor adaptor(operands);
474 return rewriter.create<LLVM::FCmpOp>(
475 op.getLoc(), llvm1DVectorTy,
476 convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
477 adaptor.getLhs(), adaptor.getRhs(), fmf);
478 },
479 rewriter);
480}
481
482//===----------------------------------------------------------------------===//
483// Pass Definition
484//===----------------------------------------------------------------------===//
485
486namespace {
487struct ArithToLLVMConversionPass
488 : public impl::ArithToLLVMConversionPassBase<ArithToLLVMConversionPass> {
489 using Base::Base;
490
491 void runOnOperation() override {
492 LLVMConversionTarget target(getContext());
493 RewritePatternSet patterns(&getContext());
494
495 LowerToLLVMOptions options(&getContext());
496 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
497 options.overrideIndexBitwidth(indexBitwidth);
498
499 LLVMTypeConverter converter(&getContext(), options);
500 arith::populateCeilFloorDivExpandOpsPatterns(patterns);
501 arith::populateArithToLLVMConversionPatterns(converter, patterns);
502
503 if (failed(applyPartialConversion(getOperation(), target,
504 std::move(patterns))))
505 signalPassFailure();
506 }
507};
508} // namespace
509
510//===----------------------------------------------------------------------===//
511// ConvertToLLVMPatternInterface implementation
512//===----------------------------------------------------------------------===//
513
514namespace {
515/// Implement the interface to convert MemRef to LLVM.
516struct ArithToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
517 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
518 void loadDependentDialects(MLIRContext *context) const final {
519 context->loadDialect<LLVM::LLVMDialect>();
520 }
521
522 /// Hook for derived dialect interface to provide conversion patterns
523 /// and mark dialect legal for the conversion target.
524 void populateConvertToLLVMConversionPatterns(
525 ConversionTarget &target, LLVMTypeConverter &typeConverter,
526 RewritePatternSet &patterns) const final {
527 arith::populateCeilFloorDivExpandOpsPatterns(patterns);
528 arith::populateArithToLLVMConversionPatterns(converter: typeConverter, patterns);
529 }
530};
531} // namespace
532
533void mlir::arith::registerConvertArithToLLVMInterface(
534 DialectRegistry &registry) {
535 registry.addExtension(extensionFn: +[](MLIRContext *ctx, arith::ArithDialect *dialect) {
536 dialect->addInterfaces<ArithToLLVMDialectInterface>();
537 });
538}
539
540//===----------------------------------------------------------------------===//
541// Pattern Population
542//===----------------------------------------------------------------------===//
543
544void mlir::arith::populateArithToLLVMConversionPatterns(
545 const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
546
547 // Set a higher pattern benefit for IdentityBitcastLowering so it will run
548 // before BitcastOpLowering.
549 patterns.add<IdentityBitcastLowering>(arg: converter, args: patterns.getContext(),
550 /*patternBenefit*/ args: 10);
551
552 // clang-format off
553 patterns.add<
554 AddFOpLowering,
555 AddIOpLowering,
556 AndIOpLowering,
557 AddUIExtendedOpLowering,
558 BitcastOpLowering,
559 ConstantOpLowering,
560 CmpFOpLowering,
561 CmpIOpLowering,
562 DivFOpLowering,
563 DivSIOpLowering,
564 DivUIOpLowering,
565 ExtFOpLowering,
566 ExtSIOpLowering,
567 ExtUIOpLowering,
568 FPToSIOpLowering,
569 FPToUIOpLowering,
570 IndexCastOpSILowering,
571 IndexCastOpUILowering,
572 MaximumFOpLowering,
573 MaxNumFOpLowering,
574 MaxSIOpLowering,
575 MaxUIOpLowering,
576 MinimumFOpLowering,
577 MinNumFOpLowering,
578 MinSIOpLowering,
579 MinUIOpLowering,
580 MulFOpLowering,
581 MulIOpLowering,
582 MulSIExtendedOpLowering,
583 MulUIExtendedOpLowering,
584 NegFOpLowering,
585 OrIOpLowering,
586 RemFOpLowering,
587 RemSIOpLowering,
588 RemUIOpLowering,
589 SelectOpLowering,
590 ShLIOpLowering,
591 ShRSIOpLowering,
592 ShRUIOpLowering,
593 SIToFPOpLowering,
594 SubFOpLowering,
595 SubIOpLowering,
596 TruncFOpLowering,
597 ConstrainedTruncFOpLowering,
598 TruncIOpLowering,
599 UIToFPOpLowering,
600 XOrIOpLowering
601 >(converter);
602 // clang-format on
603}
604

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp