1 | //===- ArithToLLVM.cpp - Arithmetic to LLVM dialect conversion -------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" |
10 | |
11 | #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h" |
12 | #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" |
13 | #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" |
14 | #include "mlir/Conversion/LLVMCommon/VectorPattern.h" |
15 | #include "mlir/Dialect/Arith/IR/Arith.h" |
16 | #include "mlir/Dialect/Arith/Transforms/Passes.h" |
17 | #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" |
18 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
19 | #include "mlir/IR/TypeUtilities.h" |
20 | #include "mlir/Pass/Pass.h" |
21 | #include <type_traits> |
22 | |
23 | namespace mlir { |
24 | #define GEN_PASS_DEF_ARITHTOLLVMCONVERSIONPASS |
25 | #include "mlir/Conversion/Passes.h.inc" |
26 | } // namespace mlir |
27 | |
28 | using namespace mlir; |
29 | |
30 | namespace { |
31 | |
32 | /// Operations whose conversion will depend on whether they are passed a |
33 | /// rounding mode attribute or not. |
34 | /// |
35 | /// `SourceOp` is the source operation; `TargetOp`, the operation it will lower |
36 | /// to; `AttrConvert` is the attribute conversion to convert the rounding mode |
37 | /// attribute. |
38 | template <typename SourceOp, typename TargetOp, bool Constrained, |
39 | template <typename, typename> typename AttrConvert = |
40 | AttrConvertPassThrough> |
41 | struct ConstrainedVectorConvertToLLVMPattern |
42 | : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert> { |
43 | using VectorConvertToLLVMPattern<SourceOp, TargetOp, |
44 | AttrConvert>::VectorConvertToLLVMPattern; |
45 | |
46 | LogicalResult |
47 | matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, |
48 | ConversionPatternRewriter &rewriter) const override { |
49 | if (Constrained != static_cast<bool>(op.getRoundingModeAttr())) |
50 | return failure(); |
51 | return VectorConvertToLLVMPattern<SourceOp, TargetOp, |
52 | AttrConvert>::matchAndRewrite(op, adaptor, |
53 | rewriter); |
54 | } |
55 | }; |
56 | |
57 | /// No-op bitcast. Propagate type input arg if converted source and dest types |
58 | /// are the same. |
59 | struct IdentityBitcastLowering final |
60 | : public OpConversionPattern<arith::BitcastOp> { |
61 | using OpConversionPattern::OpConversionPattern; |
62 | |
63 | LogicalResult |
64 | matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor, |
65 | ConversionPatternRewriter &rewriter) const final { |
66 | Value src = adaptor.getIn(); |
67 | Type resultType = getTypeConverter()->convertType(op.getType()); |
68 | if (src.getType() != resultType) |
69 | return rewriter.notifyMatchFailure(op, "Types are different"); |
70 | |
71 | rewriter.replaceOp(op, src); |
72 | return success(); |
73 | } |
74 | }; |
75 | |
76 | //===----------------------------------------------------------------------===// |
77 | // Straightforward Op Lowerings |
78 | //===----------------------------------------------------------------------===// |
79 | |
80 | using AddFOpLowering = |
81 | VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp, |
82 | arith::AttrConvertFastMathToLLVM>; |
83 | using AddIOpLowering = |
84 | VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp, |
85 | arith::AttrConvertOverflowToLLVM>; |
86 | using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>; |
87 | using BitcastOpLowering = |
88 | VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>; |
89 | using DivFOpLowering = |
90 | VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp, |
91 | arith::AttrConvertFastMathToLLVM>; |
92 | using DivSIOpLowering = |
93 | VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>; |
94 | using DivUIOpLowering = |
95 | VectorConvertToLLVMPattern<arith::DivUIOp, LLVM::UDivOp>; |
96 | using ExtFOpLowering = VectorConvertToLLVMPattern<arith::ExtFOp, LLVM::FPExtOp>; |
97 | using ExtSIOpLowering = |
98 | VectorConvertToLLVMPattern<arith::ExtSIOp, LLVM::SExtOp>; |
99 | using ExtUIOpLowering = |
100 | VectorConvertToLLVMPattern<arith::ExtUIOp, LLVM::ZExtOp>; |
101 | using FPToSIOpLowering = |
102 | VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp>; |
103 | using FPToUIOpLowering = |
104 | VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>; |
105 | using MaximumFOpLowering = |
106 | VectorConvertToLLVMPattern<arith::MaximumFOp, LLVM::MaximumOp, |
107 | arith::AttrConvertFastMathToLLVM>; |
108 | using MaxNumFOpLowering = |
109 | VectorConvertToLLVMPattern<arith::MaxNumFOp, LLVM::MaxNumOp, |
110 | arith::AttrConvertFastMathToLLVM>; |
111 | using MaxSIOpLowering = |
112 | VectorConvertToLLVMPattern<arith::MaxSIOp, LLVM::SMaxOp>; |
113 | using MaxUIOpLowering = |
114 | VectorConvertToLLVMPattern<arith::MaxUIOp, LLVM::UMaxOp>; |
115 | using MinimumFOpLowering = |
116 | VectorConvertToLLVMPattern<arith::MinimumFOp, LLVM::MinimumOp, |
117 | arith::AttrConvertFastMathToLLVM>; |
118 | using MinNumFOpLowering = |
119 | VectorConvertToLLVMPattern<arith::MinNumFOp, LLVM::MinNumOp, |
120 | arith::AttrConvertFastMathToLLVM>; |
121 | using MinSIOpLowering = |
122 | VectorConvertToLLVMPattern<arith::MinSIOp, LLVM::SMinOp>; |
123 | using MinUIOpLowering = |
124 | VectorConvertToLLVMPattern<arith::MinUIOp, LLVM::UMinOp>; |
125 | using MulFOpLowering = |
126 | VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp, |
127 | arith::AttrConvertFastMathToLLVM>; |
128 | using MulIOpLowering = |
129 | VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp, |
130 | arith::AttrConvertOverflowToLLVM>; |
131 | using NegFOpLowering = |
132 | VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp, |
133 | arith::AttrConvertFastMathToLLVM>; |
134 | using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>; |
135 | using RemFOpLowering = |
136 | VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp, |
137 | arith::AttrConvertFastMathToLLVM>; |
138 | using RemSIOpLowering = |
139 | VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>; |
140 | using RemUIOpLowering = |
141 | VectorConvertToLLVMPattern<arith::RemUIOp, LLVM::URemOp>; |
142 | using SelectOpLowering = |
143 | VectorConvertToLLVMPattern<arith::SelectOp, LLVM::SelectOp>; |
144 | using ShLIOpLowering = |
145 | VectorConvertToLLVMPattern<arith::ShLIOp, LLVM::ShlOp, |
146 | arith::AttrConvertOverflowToLLVM>; |
147 | using ShRSIOpLowering = |
148 | VectorConvertToLLVMPattern<arith::ShRSIOp, LLVM::AShrOp>; |
149 | using ShRUIOpLowering = |
150 | VectorConvertToLLVMPattern<arith::ShRUIOp, LLVM::LShrOp>; |
151 | using SIToFPOpLowering = |
152 | VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>; |
153 | using SubFOpLowering = |
154 | VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp, |
155 | arith::AttrConvertFastMathToLLVM>; |
156 | using SubIOpLowering = |
157 | VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp, |
158 | arith::AttrConvertOverflowToLLVM>; |
159 | using TruncFOpLowering = |
160 | ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp, |
161 | false>; |
162 | using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern< |
163 | arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, true, |
164 | arith::AttrConverterConstrainedFPToLLVM>; |
165 | using TruncIOpLowering = |
166 | VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp>; |
167 | using UIToFPOpLowering = |
168 | VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp>; |
169 | using XOrIOpLowering = VectorConvertToLLVMPattern<arith::XOrIOp, LLVM::XOrOp>; |
170 | |
171 | //===----------------------------------------------------------------------===// |
172 | // Op Lowering Patterns |
173 | //===----------------------------------------------------------------------===// |
174 | |
175 | /// Directly lower to LLVM op. |
176 | struct ConstantOpLowering : public ConvertOpToLLVMPattern<arith::ConstantOp> { |
177 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
178 | |
179 | LogicalResult |
180 | matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, |
181 | ConversionPatternRewriter &rewriter) const override; |
182 | }; |
183 | |
184 | /// The lowering of index_cast becomes an integer conversion since index |
185 | /// becomes an integer. If the bit width of the source and target integer |
186 | /// types is the same, just erase the cast. If the target type is wider, |
187 | /// sign-extend the value, otherwise truncate it. |
188 | template <typename OpTy, typename ExtCastTy> |
189 | struct IndexCastOpLowering : public ConvertOpToLLVMPattern<OpTy> { |
190 | using ConvertOpToLLVMPattern<OpTy>::ConvertOpToLLVMPattern; |
191 | |
192 | LogicalResult |
193 | matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, |
194 | ConversionPatternRewriter &rewriter) const override; |
195 | }; |
196 | |
197 | using IndexCastOpSILowering = |
198 | IndexCastOpLowering<arith::IndexCastOp, LLVM::SExtOp>; |
199 | using IndexCastOpUILowering = |
200 | IndexCastOpLowering<arith::IndexCastUIOp, LLVM::ZExtOp>; |
201 | |
202 | struct AddUIExtendedOpLowering |
203 | : public ConvertOpToLLVMPattern<arith::AddUIExtendedOp> { |
204 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
205 | |
206 | LogicalResult |
207 | matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor, |
208 | ConversionPatternRewriter &rewriter) const override; |
209 | }; |
210 | |
211 | template <typename ArithMulOp, bool IsSigned> |
212 | struct MulIExtendedOpLowering : public ConvertOpToLLVMPattern<ArithMulOp> { |
213 | using ConvertOpToLLVMPattern<ArithMulOp>::ConvertOpToLLVMPattern; |
214 | |
215 | LogicalResult |
216 | matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor, |
217 | ConversionPatternRewriter &rewriter) const override; |
218 | }; |
219 | |
220 | using MulSIExtendedOpLowering = |
221 | MulIExtendedOpLowering<arith::MulSIExtendedOp, true>; |
222 | using MulUIExtendedOpLowering = |
223 | MulIExtendedOpLowering<arith::MulUIExtendedOp, false>; |
224 | |
225 | struct CmpIOpLowering : public ConvertOpToLLVMPattern<arith::CmpIOp> { |
226 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
227 | |
228 | LogicalResult |
229 | matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, |
230 | ConversionPatternRewriter &rewriter) const override; |
231 | }; |
232 | |
233 | struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> { |
234 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
235 | |
236 | LogicalResult |
237 | matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, |
238 | ConversionPatternRewriter &rewriter) const override; |
239 | }; |
240 | |
241 | } // namespace |
242 | |
243 | //===----------------------------------------------------------------------===// |
244 | // ConstantOpLowering |
245 | //===----------------------------------------------------------------------===// |
246 | |
247 | LogicalResult |
248 | ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, |
249 | ConversionPatternRewriter &rewriter) const { |
250 | return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(), |
251 | adaptor.getOperands(), op->getAttrs(), |
252 | *getTypeConverter(), rewriter); |
253 | } |
254 | |
255 | //===----------------------------------------------------------------------===// |
256 | // IndexCastOpLowering |
257 | //===----------------------------------------------------------------------===// |
258 | |
259 | template <typename OpTy, typename ExtCastTy> |
260 | LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite( |
261 | OpTy op, typename OpTy::Adaptor adaptor, |
262 | ConversionPatternRewriter &rewriter) const { |
263 | Type resultType = op.getResult().getType(); |
264 | Type targetElementType = |
265 | this->typeConverter->convertType(getElementTypeOrSelf(type: resultType)); |
266 | Type sourceElementType = |
267 | this->typeConverter->convertType(getElementTypeOrSelf(op.getIn())); |
268 | unsigned targetBits = targetElementType.getIntOrFloatBitWidth(); |
269 | unsigned sourceBits = sourceElementType.getIntOrFloatBitWidth(); |
270 | |
271 | if (targetBits == sourceBits) { |
272 | rewriter.replaceOp(op, adaptor.getIn()); |
273 | return success(); |
274 | } |
275 | |
276 | // Handle the scalar and 1D vector cases. |
277 | Type operandType = adaptor.getIn().getType(); |
278 | if (!isa<LLVM::LLVMArrayType>(operandType)) { |
279 | Type targetType = this->typeConverter->convertType(resultType); |
280 | if (targetBits < sourceBits) |
281 | rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType, |
282 | adaptor.getIn()); |
283 | else |
284 | rewriter.replaceOpWithNewOp<ExtCastTy>(op, targetType, adaptor.getIn()); |
285 | return success(); |
286 | } |
287 | |
288 | if (!isa<VectorType>(Val: resultType)) |
289 | return rewriter.notifyMatchFailure(op, "expected vector result type"); |
290 | |
291 | return LLVM::detail::handleMultidimensionalVectors( |
292 | op: op.getOperation(), operands: adaptor.getOperands(), typeConverter: *(this->getTypeConverter()), |
293 | createOperand: [&](Type llvm1DVectorTy, ValueRange operands) -> Value { |
294 | typename OpTy::Adaptor adaptor(operands); |
295 | if (targetBits < sourceBits) { |
296 | return rewriter.create<LLVM::TruncOp>(op.getLoc(), llvm1DVectorTy, |
297 | adaptor.getIn()); |
298 | } |
299 | return rewriter.create<ExtCastTy>(op.getLoc(), llvm1DVectorTy, |
300 | adaptor.getIn()); |
301 | }, |
302 | rewriter); |
303 | } |
304 | |
305 | //===----------------------------------------------------------------------===// |
306 | // AddUIExtendedOpLowering |
307 | //===----------------------------------------------------------------------===// |
308 | |
309 | LogicalResult AddUIExtendedOpLowering::matchAndRewrite( |
310 | arith::AddUIExtendedOp op, OpAdaptor adaptor, |
311 | ConversionPatternRewriter &rewriter) const { |
312 | Type operandType = adaptor.getLhs().getType(); |
313 | Type sumResultType = op.getSum().getType(); |
314 | Type overflowResultType = op.getOverflow().getType(); |
315 | |
316 | if (!LLVM::isCompatibleType(type: operandType)) |
317 | return failure(); |
318 | |
319 | MLIRContext *ctx = rewriter.getContext(); |
320 | Location loc = op.getLoc(); |
321 | |
322 | // Handle the scalar and 1D vector cases. |
323 | if (!isa<LLVM::LLVMArrayType>(operandType)) { |
324 | Type newOverflowType = typeConverter->convertType(overflowResultType); |
325 | Type structType = |
326 | LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType}); |
327 | Value addOverflow = rewriter.create<LLVM::UAddWithOverflowOp>( |
328 | loc, structType, adaptor.getLhs(), adaptor.getRhs()); |
329 | Value sumExtracted = |
330 | rewriter.create<LLVM::ExtractValueOp>(loc, addOverflow, 0); |
331 | Value overflowExtracted = |
332 | rewriter.create<LLVM::ExtractValueOp>(loc, addOverflow, 1); |
333 | rewriter.replaceOp(op, {sumExtracted, overflowExtracted}); |
334 | return success(); |
335 | } |
336 | |
337 | if (!isa<VectorType>(Val: sumResultType)) |
338 | return rewriter.notifyMatchFailure(arg&: loc, msg: "expected vector result types"); |
339 | |
340 | return rewriter.notifyMatchFailure(arg&: loc, |
341 | msg: "ND vector types are not supported yet"); |
342 | } |
343 | |
344 | //===----------------------------------------------------------------------===// |
345 | // MulIExtendedOpLowering |
346 | //===----------------------------------------------------------------------===// |
347 | |
348 | template <typename ArithMulOp, bool IsSigned> |
349 | LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite( |
350 | ArithMulOp op, typename ArithMulOp::Adaptor adaptor, |
351 | ConversionPatternRewriter &rewriter) const { |
352 | Type resultType = adaptor.getLhs().getType(); |
353 | |
354 | if (!LLVM::isCompatibleType(type: resultType)) |
355 | return failure(); |
356 | |
357 | Location loc = op.getLoc(); |
358 | |
359 | // Handle the scalar and 1D vector cases. Because LLVM does not have a |
360 | // matching extended multiplication intrinsic, perform regular multiplication |
361 | // on operands zero-extended to i(2*N) bits, and truncate the results back to |
362 | // iN types. |
363 | if (!isa<LLVM::LLVMArrayType>(resultType)) { |
364 | // Shift amount necessary to extract the high bits from widened result. |
365 | TypedAttr shiftValAttr; |
366 | |
367 | if (auto intTy = dyn_cast<IntegerType>(resultType)) { |
368 | unsigned resultBitwidth = intTy.getWidth(); |
369 | auto attrTy = rewriter.getIntegerType(resultBitwidth * 2); |
370 | shiftValAttr = rewriter.getIntegerAttr(attrTy, resultBitwidth); |
371 | } else { |
372 | auto vecTy = cast<VectorType>(resultType); |
373 | unsigned resultBitwidth = vecTy.getElementTypeBitWidth(); |
374 | auto attrTy = VectorType::get( |
375 | vecTy.getShape(), rewriter.getIntegerType(resultBitwidth * 2)); |
376 | shiftValAttr = SplatElementsAttr::get( |
377 | attrTy, APInt(resultBitwidth * 2, resultBitwidth)); |
378 | } |
379 | Type wideType = shiftValAttr.getType(); |
380 | assert(LLVM::isCompatibleType(wideType) && |
381 | "LLVM dialect should support all signless integer types"); |
382 | |
383 | using LLVMExtOp = std::conditional_t<IsSigned, LLVM::SExtOp, LLVM::ZExtOp>; |
384 | Value lhsExt = rewriter.create<LLVMExtOp>(loc, wideType, adaptor.getLhs()); |
385 | Value rhsExt = rewriter.create<LLVMExtOp>(loc, wideType, adaptor.getRhs()); |
386 | Value mulExt = rewriter.create<LLVM::MulOp>(loc, wideType, lhsExt, rhsExt); |
387 | |
388 | // Split the 2*N-bit wide result into two N-bit values. |
389 | Value low = rewriter.create<LLVM::TruncOp>(loc, resultType, mulExt); |
390 | Value shiftVal = rewriter.create<LLVM::ConstantOp>(loc, shiftValAttr); |
391 | Value highExt = rewriter.create<LLVM::LShrOp>(loc, mulExt, shiftVal); |
392 | Value high = rewriter.create<LLVM::TruncOp>(loc, resultType, highExt); |
393 | |
394 | rewriter.replaceOp(op, {low, high}); |
395 | return success(); |
396 | } |
397 | |
398 | if (!isa<VectorType>(Val: resultType)) |
399 | return rewriter.notifyMatchFailure(op, "expected vector result type"); |
400 | |
401 | return rewriter.notifyMatchFailure(op, |
402 | "ND vector types are not supported yet"); |
403 | } |
404 | |
//===----------------------------------------------------------------------===//
// CmpIOpLowering
//===----------------------------------------------------------------------===//

/// Maps an Arith comparison predicate onto the corresponding LLVM dialect
/// predicate. The two enumerations are assumed to share enumerator values, so
/// the conversion goes through the source's underlying integer value.
template <typename LLVMPredType, typename PredType>
static LLVMPredType convertCmpPredicate(PredType pred) {
  using RawPredicate = std::underlying_type_t<PredType>;
  const RawPredicate rawValue = static_cast<RawPredicate>(pred);
  return static_cast<LLVMPredType>(rawValue);
}
415 | |
416 | LogicalResult |
417 | CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, |
418 | ConversionPatternRewriter &rewriter) const { |
419 | Type operandType = adaptor.getLhs().getType(); |
420 | Type resultType = op.getResult().getType(); |
421 | |
422 | // Handle the scalar and 1D vector cases. |
423 | if (!isa<LLVM::LLVMArrayType>(operandType)) { |
424 | rewriter.replaceOpWithNewOp<LLVM::ICmpOp>( |
425 | op, typeConverter->convertType(resultType), |
426 | convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()), |
427 | adaptor.getLhs(), adaptor.getRhs()); |
428 | return success(); |
429 | } |
430 | |
431 | if (!isa<VectorType>(Val: resultType)) |
432 | return rewriter.notifyMatchFailure(op, "expected vector result type"); |
433 | |
434 | return LLVM::detail::handleMultidimensionalVectors( |
435 | op: op.getOperation(), operands: adaptor.getOperands(), typeConverter: *getTypeConverter(), |
436 | createOperand: [&](Type llvm1DVectorTy, ValueRange operands) { |
437 | OpAdaptor adaptor(operands); |
438 | return rewriter.create<LLVM::ICmpOp>( |
439 | op.getLoc(), llvm1DVectorTy, |
440 | convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()), |
441 | adaptor.getLhs(), adaptor.getRhs()); |
442 | }, |
443 | rewriter); |
444 | } |
445 | |
446 | //===----------------------------------------------------------------------===// |
447 | // CmpFOpLowering |
448 | //===----------------------------------------------------------------------===// |
449 | |
450 | LogicalResult |
451 | CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, |
452 | ConversionPatternRewriter &rewriter) const { |
453 | Type operandType = adaptor.getLhs().getType(); |
454 | Type resultType = op.getResult().getType(); |
455 | LLVM::FastmathFlags fmf = |
456 | arith::convertArithFastMathFlagsToLLVM(op.getFastmath()); |
457 | |
458 | // Handle the scalar and 1D vector cases. |
459 | if (!isa<LLVM::LLVMArrayType>(operandType)) { |
460 | rewriter.replaceOpWithNewOp<LLVM::FCmpOp>( |
461 | op, typeConverter->convertType(resultType), |
462 | convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()), |
463 | adaptor.getLhs(), adaptor.getRhs(), fmf); |
464 | return success(); |
465 | } |
466 | |
467 | if (!isa<VectorType>(Val: resultType)) |
468 | return rewriter.notifyMatchFailure(op, "expected vector result type"); |
469 | |
470 | return LLVM::detail::handleMultidimensionalVectors( |
471 | op: op.getOperation(), operands: adaptor.getOperands(), typeConverter: *getTypeConverter(), |
472 | createOperand: [&](Type llvm1DVectorTy, ValueRange operands) { |
473 | OpAdaptor adaptor(operands); |
474 | return rewriter.create<LLVM::FCmpOp>( |
475 | op.getLoc(), llvm1DVectorTy, |
476 | convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()), |
477 | adaptor.getLhs(), adaptor.getRhs(), fmf); |
478 | }, |
479 | rewriter); |
480 | } |
481 | |
482 | //===----------------------------------------------------------------------===// |
483 | // Pass Definition |
484 | //===----------------------------------------------------------------------===// |
485 | |
486 | namespace { |
487 | struct ArithToLLVMConversionPass |
488 | : public impl::ArithToLLVMConversionPassBase<ArithToLLVMConversionPass> { |
489 | using Base::Base; |
490 | |
491 | void runOnOperation() override { |
492 | LLVMConversionTarget target(getContext()); |
493 | RewritePatternSet patterns(&getContext()); |
494 | |
495 | LowerToLLVMOptions options(&getContext()); |
496 | if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) |
497 | options.overrideIndexBitwidth(indexBitwidth); |
498 | |
499 | LLVMTypeConverter converter(&getContext(), options); |
500 | arith::populateCeilFloorDivExpandOpsPatterns(patterns); |
501 | arith::populateArithToLLVMConversionPatterns(converter, patterns); |
502 | |
503 | if (failed(applyPartialConversion(getOperation(), target, |
504 | std::move(patterns)))) |
505 | signalPassFailure(); |
506 | } |
507 | }; |
508 | } // namespace |
509 | |
510 | //===----------------------------------------------------------------------===// |
511 | // ConvertToLLVMPatternInterface implementation |
512 | //===----------------------------------------------------------------------===// |
513 | |
514 | namespace { |
515 | /// Implement the interface to convert MemRef to LLVM. |
516 | struct ArithToLLVMDialectInterface : public ConvertToLLVMPatternInterface { |
517 | using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; |
518 | void loadDependentDialects(MLIRContext *context) const final { |
519 | context->loadDialect<LLVM::LLVMDialect>(); |
520 | } |
521 | |
522 | /// Hook for derived dialect interface to provide conversion patterns |
523 | /// and mark dialect legal for the conversion target. |
524 | void populateConvertToLLVMConversionPatterns( |
525 | ConversionTarget &target, LLVMTypeConverter &typeConverter, |
526 | RewritePatternSet &patterns) const final { |
527 | arith::populateCeilFloorDivExpandOpsPatterns(patterns); |
528 | arith::populateArithToLLVMConversionPatterns(converter: typeConverter, patterns); |
529 | } |
530 | }; |
531 | } // namespace |
532 | |
533 | void mlir::arith::registerConvertArithToLLVMInterface( |
534 | DialectRegistry ®istry) { |
535 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, arith::ArithDialect *dialect) { |
536 | dialect->addInterfaces<ArithToLLVMDialectInterface>(); |
537 | }); |
538 | } |
539 | |
540 | //===----------------------------------------------------------------------===// |
541 | // Pattern Population |
542 | //===----------------------------------------------------------------------===// |
543 | |
544 | void mlir::arith::populateArithToLLVMConversionPatterns( |
545 | const LLVMTypeConverter &converter, RewritePatternSet &patterns) { |
546 | |
547 | // Set a higher pattern benefit for IdentityBitcastLowering so it will run |
548 | // before BitcastOpLowering. |
549 | patterns.add<IdentityBitcastLowering>(arg: converter, args: patterns.getContext(), |
550 | /*patternBenefit*/ args: 10); |
551 | |
552 | // clang-format off |
553 | patterns.add< |
554 | AddFOpLowering, |
555 | AddIOpLowering, |
556 | AndIOpLowering, |
557 | AddUIExtendedOpLowering, |
558 | BitcastOpLowering, |
559 | ConstantOpLowering, |
560 | CmpFOpLowering, |
561 | CmpIOpLowering, |
562 | DivFOpLowering, |
563 | DivSIOpLowering, |
564 | DivUIOpLowering, |
565 | ExtFOpLowering, |
566 | ExtSIOpLowering, |
567 | ExtUIOpLowering, |
568 | FPToSIOpLowering, |
569 | FPToUIOpLowering, |
570 | IndexCastOpSILowering, |
571 | IndexCastOpUILowering, |
572 | MaximumFOpLowering, |
573 | MaxNumFOpLowering, |
574 | MaxSIOpLowering, |
575 | MaxUIOpLowering, |
576 | MinimumFOpLowering, |
577 | MinNumFOpLowering, |
578 | MinSIOpLowering, |
579 | MinUIOpLowering, |
580 | MulFOpLowering, |
581 | MulIOpLowering, |
582 | MulSIExtendedOpLowering, |
583 | MulUIExtendedOpLowering, |
584 | NegFOpLowering, |
585 | OrIOpLowering, |
586 | RemFOpLowering, |
587 | RemSIOpLowering, |
588 | RemUIOpLowering, |
589 | SelectOpLowering, |
590 | ShLIOpLowering, |
591 | ShRSIOpLowering, |
592 | ShRUIOpLowering, |
593 | SIToFPOpLowering, |
594 | SubFOpLowering, |
595 | SubIOpLowering, |
596 | TruncFOpLowering, |
597 | ConstrainedTruncFOpLowering, |
598 | TruncIOpLowering, |
599 | UIToFPOpLowering, |
600 | XOrIOpLowering |
601 | >(converter); |
602 | // clang-format on |
603 | } |
604 |
Definitions
- ConstrainedVectorConvertToLLVMPattern
- matchAndRewrite
- IdentityBitcastLowering
- matchAndRewrite
- ConstantOpLowering
- IndexCastOpLowering
- AddUIExtendedOpLowering
- MulIExtendedOpLowering
- CmpIOpLowering
- CmpFOpLowering
- matchAndRewrite
- matchAndRewrite
- matchAndRewrite
- matchAndRewrite
- convertCmpPredicate
- matchAndRewrite
- matchAndRewrite
- ArithToLLVMConversionPass
- runOnOperation
- ArithToLLVMDialectInterface
- loadDependentDialects
- populateConvertToLLVMConversionPatterns
- registerConvertArithToLLVMInterface
Learn to use CMake with our Intro Training
Find out more