1 | //===- ArithToLLVM.cpp - Arithmetic to LLVM dialect conversion -------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" |
10 | |
11 | #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h" |
12 | #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" |
13 | #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" |
14 | #include "mlir/Conversion/LLVMCommon/VectorPattern.h" |
15 | #include "mlir/Dialect/Arith/IR/Arith.h" |
16 | #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" |
17 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
18 | #include "mlir/IR/TypeUtilities.h" |
19 | #include "mlir/Pass/Pass.h" |
20 | #include <type_traits> |
21 | |
22 | namespace mlir { |
23 | #define GEN_PASS_DEF_ARITHTOLLVMCONVERSIONPASS |
24 | #include "mlir/Conversion/Passes.h.inc" |
25 | } // namespace mlir |
26 | |
27 | using namespace mlir; |
28 | |
29 | namespace { |
30 | |
31 | /// Operations whose conversion will depend on whether they are passed a |
32 | /// rounding mode attribute or not. |
33 | /// |
34 | /// `SourceOp` is the source operation; `TargetOp`, the operation it will lower |
35 | /// to; `AttrConvert` is the attribute conversion to convert the rounding mode |
36 | /// attribute. |
37 | template <typename SourceOp, typename TargetOp, bool Constrained, |
38 | template <typename, typename> typename AttrConvert = |
39 | AttrConvertPassThrough> |
40 | struct ConstrainedVectorConvertToLLVMPattern |
41 | : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert> { |
42 | using VectorConvertToLLVMPattern<SourceOp, TargetOp, |
43 | AttrConvert>::VectorConvertToLLVMPattern; |
44 | |
45 | LogicalResult |
46 | matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, |
47 | ConversionPatternRewriter &rewriter) const override { |
48 | if (Constrained != static_cast<bool>(op.getRoundingModeAttr())) |
49 | return failure(); |
50 | return VectorConvertToLLVMPattern<SourceOp, TargetOp, |
51 | AttrConvert>::matchAndRewrite(op, adaptor, |
52 | rewriter); |
53 | } |
54 | }; |
55 | |
56 | //===----------------------------------------------------------------------===// |
57 | // Straightforward Op Lowerings |
58 | //===----------------------------------------------------------------------===// |
59 | |
60 | using AddFOpLowering = |
61 | VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp, |
62 | arith::AttrConvertFastMathToLLVM>; |
63 | using AddIOpLowering = |
64 | VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp, |
65 | arith::AttrConvertOverflowToLLVM>; |
66 | using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>; |
67 | using BitcastOpLowering = |
68 | VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>; |
69 | using DivFOpLowering = |
70 | VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp, |
71 | arith::AttrConvertFastMathToLLVM>; |
72 | using DivSIOpLowering = |
73 | VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>; |
74 | using DivUIOpLowering = |
75 | VectorConvertToLLVMPattern<arith::DivUIOp, LLVM::UDivOp>; |
76 | using ExtFOpLowering = VectorConvertToLLVMPattern<arith::ExtFOp, LLVM::FPExtOp>; |
77 | using ExtSIOpLowering = |
78 | VectorConvertToLLVMPattern<arith::ExtSIOp, LLVM::SExtOp>; |
79 | using ExtUIOpLowering = |
80 | VectorConvertToLLVMPattern<arith::ExtUIOp, LLVM::ZExtOp>; |
81 | using FPToSIOpLowering = |
82 | VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp>; |
83 | using FPToUIOpLowering = |
84 | VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>; |
85 | using MaximumFOpLowering = |
86 | VectorConvertToLLVMPattern<arith::MaximumFOp, LLVM::MaximumOp, |
87 | arith::AttrConvertFastMathToLLVM>; |
88 | using MaxNumFOpLowering = |
89 | VectorConvertToLLVMPattern<arith::MaxNumFOp, LLVM::MaxNumOp, |
90 | arith::AttrConvertFastMathToLLVM>; |
91 | using MaxSIOpLowering = |
92 | VectorConvertToLLVMPattern<arith::MaxSIOp, LLVM::SMaxOp>; |
93 | using MaxUIOpLowering = |
94 | VectorConvertToLLVMPattern<arith::MaxUIOp, LLVM::UMaxOp>; |
95 | using MinimumFOpLowering = |
96 | VectorConvertToLLVMPattern<arith::MinimumFOp, LLVM::MinimumOp, |
97 | arith::AttrConvertFastMathToLLVM>; |
98 | using MinNumFOpLowering = |
99 | VectorConvertToLLVMPattern<arith::MinNumFOp, LLVM::MinNumOp, |
100 | arith::AttrConvertFastMathToLLVM>; |
101 | using MinSIOpLowering = |
102 | VectorConvertToLLVMPattern<arith::MinSIOp, LLVM::SMinOp>; |
103 | using MinUIOpLowering = |
104 | VectorConvertToLLVMPattern<arith::MinUIOp, LLVM::UMinOp>; |
105 | using MulFOpLowering = |
106 | VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp, |
107 | arith::AttrConvertFastMathToLLVM>; |
108 | using MulIOpLowering = |
109 | VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp, |
110 | arith::AttrConvertOverflowToLLVM>; |
111 | using NegFOpLowering = |
112 | VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp, |
113 | arith::AttrConvertFastMathToLLVM>; |
114 | using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>; |
115 | using RemFOpLowering = |
116 | VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp, |
117 | arith::AttrConvertFastMathToLLVM>; |
118 | using RemSIOpLowering = |
119 | VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>; |
120 | using RemUIOpLowering = |
121 | VectorConvertToLLVMPattern<arith::RemUIOp, LLVM::URemOp>; |
122 | using SelectOpLowering = |
123 | VectorConvertToLLVMPattern<arith::SelectOp, LLVM::SelectOp>; |
124 | using ShLIOpLowering = |
125 | VectorConvertToLLVMPattern<arith::ShLIOp, LLVM::ShlOp, |
126 | arith::AttrConvertOverflowToLLVM>; |
127 | using ShRSIOpLowering = |
128 | VectorConvertToLLVMPattern<arith::ShRSIOp, LLVM::AShrOp>; |
129 | using ShRUIOpLowering = |
130 | VectorConvertToLLVMPattern<arith::ShRUIOp, LLVM::LShrOp>; |
131 | using SIToFPOpLowering = |
132 | VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>; |
133 | using SubFOpLowering = |
134 | VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp, |
135 | arith::AttrConvertFastMathToLLVM>; |
136 | using SubIOpLowering = |
137 | VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp, |
138 | arith::AttrConvertOverflowToLLVM>; |
139 | using TruncFOpLowering = |
140 | ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp, |
141 | false>; |
142 | using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern< |
143 | arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, true, |
144 | arith::AttrConverterConstrainedFPToLLVM>; |
145 | using TruncIOpLowering = |
146 | VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp>; |
147 | using UIToFPOpLowering = |
148 | VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp>; |
149 | using XOrIOpLowering = VectorConvertToLLVMPattern<arith::XOrIOp, LLVM::XOrOp>; |
150 | |
151 | //===----------------------------------------------------------------------===// |
152 | // Op Lowering Patterns |
153 | //===----------------------------------------------------------------------===// |
154 | |
155 | /// Directly lower to LLVM op. |
156 | struct ConstantOpLowering : public ConvertOpToLLVMPattern<arith::ConstantOp> { |
157 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
158 | |
159 | LogicalResult |
160 | matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, |
161 | ConversionPatternRewriter &rewriter) const override; |
162 | }; |
163 | |
164 | /// The lowering of index_cast becomes an integer conversion since index |
165 | /// becomes an integer. If the bit width of the source and target integer |
166 | /// types is the same, just erase the cast. If the target type is wider, |
167 | /// sign-extend the value, otherwise truncate it. |
168 | template <typename OpTy, typename ExtCastTy> |
169 | struct IndexCastOpLowering : public ConvertOpToLLVMPattern<OpTy> { |
170 | using ConvertOpToLLVMPattern<OpTy>::ConvertOpToLLVMPattern; |
171 | |
172 | LogicalResult |
173 | matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, |
174 | ConversionPatternRewriter &rewriter) const override; |
175 | }; |
176 | |
177 | using IndexCastOpSILowering = |
178 | IndexCastOpLowering<arith::IndexCastOp, LLVM::SExtOp>; |
179 | using IndexCastOpUILowering = |
180 | IndexCastOpLowering<arith::IndexCastUIOp, LLVM::ZExtOp>; |
181 | |
182 | struct AddUIExtendedOpLowering |
183 | : public ConvertOpToLLVMPattern<arith::AddUIExtendedOp> { |
184 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
185 | |
186 | LogicalResult |
187 | matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor, |
188 | ConversionPatternRewriter &rewriter) const override; |
189 | }; |
190 | |
191 | template <typename ArithMulOp, bool IsSigned> |
192 | struct MulIExtendedOpLowering : public ConvertOpToLLVMPattern<ArithMulOp> { |
193 | using ConvertOpToLLVMPattern<ArithMulOp>::ConvertOpToLLVMPattern; |
194 | |
195 | LogicalResult |
196 | matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor, |
197 | ConversionPatternRewriter &rewriter) const override; |
198 | }; |
199 | |
200 | using MulSIExtendedOpLowering = |
201 | MulIExtendedOpLowering<arith::MulSIExtendedOp, true>; |
202 | using MulUIExtendedOpLowering = |
203 | MulIExtendedOpLowering<arith::MulUIExtendedOp, false>; |
204 | |
205 | struct CmpIOpLowering : public ConvertOpToLLVMPattern<arith::CmpIOp> { |
206 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
207 | |
208 | LogicalResult |
209 | matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, |
210 | ConversionPatternRewriter &rewriter) const override; |
211 | }; |
212 | |
213 | struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> { |
214 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
215 | |
216 | LogicalResult |
217 | matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, |
218 | ConversionPatternRewriter &rewriter) const override; |
219 | }; |
220 | |
221 | } // namespace |
222 | |
223 | //===----------------------------------------------------------------------===// |
224 | // ConstantOpLowering |
225 | //===----------------------------------------------------------------------===// |
226 | |
227 | LogicalResult |
228 | ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, |
229 | ConversionPatternRewriter &rewriter) const { |
230 | return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(), |
231 | adaptor.getOperands(), op->getAttrs(), |
232 | *getTypeConverter(), rewriter); |
233 | } |
234 | |
235 | //===----------------------------------------------------------------------===// |
236 | // IndexCastOpLowering |
237 | //===----------------------------------------------------------------------===// |
238 | |
239 | template <typename OpTy, typename ExtCastTy> |
240 | LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite( |
241 | OpTy op, typename OpTy::Adaptor adaptor, |
242 | ConversionPatternRewriter &rewriter) const { |
243 | Type resultType = op.getResult().getType(); |
244 | Type targetElementType = |
245 | this->typeConverter->convertType(getElementTypeOrSelf(type: resultType)); |
246 | Type sourceElementType = |
247 | this->typeConverter->convertType(getElementTypeOrSelf(op.getIn())); |
248 | unsigned targetBits = targetElementType.getIntOrFloatBitWidth(); |
249 | unsigned sourceBits = sourceElementType.getIntOrFloatBitWidth(); |
250 | |
251 | if (targetBits == sourceBits) { |
252 | rewriter.replaceOp(op, adaptor.getIn()); |
253 | return success(); |
254 | } |
255 | |
256 | // Handle the scalar and 1D vector cases. |
257 | Type operandType = adaptor.getIn().getType(); |
258 | if (!isa<LLVM::LLVMArrayType>(operandType)) { |
259 | Type targetType = this->typeConverter->convertType(resultType); |
260 | if (targetBits < sourceBits) |
261 | rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType, |
262 | adaptor.getIn()); |
263 | else |
264 | rewriter.replaceOpWithNewOp<ExtCastTy>(op, targetType, adaptor.getIn()); |
265 | return success(); |
266 | } |
267 | |
268 | if (!isa<VectorType>(Val: resultType)) |
269 | return rewriter.notifyMatchFailure(op, "expected vector result type" ); |
270 | |
271 | return LLVM::detail::handleMultidimensionalVectors( |
272 | op: op.getOperation(), operands: adaptor.getOperands(), typeConverter: *(this->getTypeConverter()), |
273 | createOperand: [&](Type llvm1DVectorTy, ValueRange operands) -> Value { |
274 | typename OpTy::Adaptor adaptor(operands); |
275 | if (targetBits < sourceBits) { |
276 | return rewriter.create<LLVM::TruncOp>(op.getLoc(), llvm1DVectorTy, |
277 | adaptor.getIn()); |
278 | } |
279 | return rewriter.create<ExtCastTy>(op.getLoc(), llvm1DVectorTy, |
280 | adaptor.getIn()); |
281 | }, |
282 | rewriter); |
283 | } |
284 | |
285 | //===----------------------------------------------------------------------===// |
286 | // AddUIExtendedOpLowering |
287 | //===----------------------------------------------------------------------===// |
288 | |
289 | LogicalResult AddUIExtendedOpLowering::matchAndRewrite( |
290 | arith::AddUIExtendedOp op, OpAdaptor adaptor, |
291 | ConversionPatternRewriter &rewriter) const { |
292 | Type operandType = adaptor.getLhs().getType(); |
293 | Type sumResultType = op.getSum().getType(); |
294 | Type overflowResultType = op.getOverflow().getType(); |
295 | |
296 | if (!LLVM::isCompatibleType(type: operandType)) |
297 | return failure(); |
298 | |
299 | MLIRContext *ctx = rewriter.getContext(); |
300 | Location loc = op.getLoc(); |
301 | |
302 | // Handle the scalar and 1D vector cases. |
303 | if (!isa<LLVM::LLVMArrayType>(operandType)) { |
304 | Type newOverflowType = typeConverter->convertType(overflowResultType); |
305 | Type structType = |
306 | LLVM::LLVMStructType::getLiteral(context: ctx, types: {sumResultType, newOverflowType}); |
307 | Value addOverflow = rewriter.create<LLVM::UAddWithOverflowOp>( |
308 | loc, structType, adaptor.getLhs(), adaptor.getRhs()); |
309 | Value = |
310 | rewriter.create<LLVM::ExtractValueOp>(loc, addOverflow, 0); |
311 | Value = |
312 | rewriter.create<LLVM::ExtractValueOp>(loc, addOverflow, 1); |
313 | rewriter.replaceOp(op, {sumExtracted, overflowExtracted}); |
314 | return success(); |
315 | } |
316 | |
317 | if (!isa<VectorType>(Val: sumResultType)) |
318 | return rewriter.notifyMatchFailure(arg&: loc, msg: "expected vector result types" ); |
319 | |
320 | return rewriter.notifyMatchFailure(arg&: loc, |
321 | msg: "ND vector types are not supported yet" ); |
322 | } |
323 | |
324 | //===----------------------------------------------------------------------===// |
325 | // MulIExtendedOpLowering |
326 | //===----------------------------------------------------------------------===// |
327 | |
328 | template <typename ArithMulOp, bool IsSigned> |
329 | LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite( |
330 | ArithMulOp op, typename ArithMulOp::Adaptor adaptor, |
331 | ConversionPatternRewriter &rewriter) const { |
332 | Type resultType = adaptor.getLhs().getType(); |
333 | |
334 | if (!LLVM::isCompatibleType(type: resultType)) |
335 | return failure(); |
336 | |
337 | Location loc = op.getLoc(); |
338 | |
339 | // Handle the scalar and 1D vector cases. Because LLVM does not have a |
340 | // matching extended multiplication intrinsic, perform regular multiplication |
341 | // on operands zero-extended to i(2*N) bits, and truncate the results back to |
342 | // iN types. |
343 | if (!isa<LLVM::LLVMArrayType>(resultType)) { |
344 | // Shift amount necessary to extract the high bits from widened result. |
345 | TypedAttr shiftValAttr; |
346 | |
347 | if (auto intTy = dyn_cast<IntegerType>(resultType)) { |
348 | unsigned resultBitwidth = intTy.getWidth(); |
349 | auto attrTy = rewriter.getIntegerType(resultBitwidth * 2); |
350 | shiftValAttr = rewriter.getIntegerAttr(attrTy, resultBitwidth); |
351 | } else { |
352 | auto vecTy = cast<VectorType>(resultType); |
353 | unsigned resultBitwidth = vecTy.getElementTypeBitWidth(); |
354 | auto attrTy = VectorType::get( |
355 | vecTy.getShape(), rewriter.getIntegerType(resultBitwidth * 2)); |
356 | shiftValAttr = SplatElementsAttr::get( |
357 | attrTy, APInt(resultBitwidth * 2, resultBitwidth)); |
358 | } |
359 | Type wideType = shiftValAttr.getType(); |
360 | assert(LLVM::isCompatibleType(wideType) && |
361 | "LLVM dialect should support all signless integer types" ); |
362 | |
363 | using LLVMExtOp = std::conditional_t<IsSigned, LLVM::SExtOp, LLVM::ZExtOp>; |
364 | Value lhsExt = rewriter.create<LLVMExtOp>(loc, wideType, adaptor.getLhs()); |
365 | Value rhsExt = rewriter.create<LLVMExtOp>(loc, wideType, adaptor.getRhs()); |
366 | Value mulExt = rewriter.create<LLVM::MulOp>(loc, wideType, lhsExt, rhsExt); |
367 | |
368 | // Split the 2*N-bit wide result into two N-bit values. |
369 | Value low = rewriter.create<LLVM::TruncOp>(loc, resultType, mulExt); |
370 | Value shiftVal = rewriter.create<LLVM::ConstantOp>(loc, shiftValAttr); |
371 | Value highExt = rewriter.create<LLVM::LShrOp>(loc, mulExt, shiftVal); |
372 | Value high = rewriter.create<LLVM::TruncOp>(loc, resultType, highExt); |
373 | |
374 | rewriter.replaceOp(op, {low, high}); |
375 | return success(); |
376 | } |
377 | |
378 | if (!isa<VectorType>(Val: resultType)) |
379 | return rewriter.notifyMatchFailure(op, "expected vector result type" ); |
380 | |
381 | return rewriter.notifyMatchFailure(op, |
382 | "ND vector types are not supported yet" ); |
383 | } |
384 | |
385 | //===----------------------------------------------------------------------===// |
386 | // CmpIOpLowering |
387 | //===----------------------------------------------------------------------===// |
388 | |
/// Maps an arith comparison predicate onto the corresponding LLVM dialect
/// predicate enum. The arith and LLVM predicate enums are numbered
/// identically, so the mapping passes the enumerator's underlying integer
/// value straight through.
template <typename LLVMPredType, typename PredType>
static LLVMPredType convertCmpPredicate(PredType pred) {
  const auto rawValue = static_cast<std::underlying_type_t<PredType>>(pred);
  return static_cast<LLVMPredType>(rawValue);
}
395 | |
396 | LogicalResult |
397 | CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, |
398 | ConversionPatternRewriter &rewriter) const { |
399 | Type operandType = adaptor.getLhs().getType(); |
400 | Type resultType = op.getResult().getType(); |
401 | |
402 | // Handle the scalar and 1D vector cases. |
403 | if (!isa<LLVM::LLVMArrayType>(operandType)) { |
404 | rewriter.replaceOpWithNewOp<LLVM::ICmpOp>( |
405 | op, typeConverter->convertType(resultType), |
406 | convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()), |
407 | adaptor.getLhs(), adaptor.getRhs()); |
408 | return success(); |
409 | } |
410 | |
411 | if (!isa<VectorType>(Val: resultType)) |
412 | return rewriter.notifyMatchFailure(op, "expected vector result type" ); |
413 | |
414 | return LLVM::detail::handleMultidimensionalVectors( |
415 | op: op.getOperation(), operands: adaptor.getOperands(), typeConverter: *getTypeConverter(), |
416 | createOperand: [&](Type llvm1DVectorTy, ValueRange operands) { |
417 | OpAdaptor adaptor(operands); |
418 | return rewriter.create<LLVM::ICmpOp>( |
419 | op.getLoc(), llvm1DVectorTy, |
420 | convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()), |
421 | adaptor.getLhs(), adaptor.getRhs()); |
422 | }, |
423 | rewriter); |
424 | } |
425 | |
426 | //===----------------------------------------------------------------------===// |
427 | // CmpFOpLowering |
428 | //===----------------------------------------------------------------------===// |
429 | |
430 | LogicalResult |
431 | CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, |
432 | ConversionPatternRewriter &rewriter) const { |
433 | Type operandType = adaptor.getLhs().getType(); |
434 | Type resultType = op.getResult().getType(); |
435 | LLVM::FastmathFlags fmf = |
436 | arith::convertArithFastMathFlagsToLLVM(op.getFastmath()); |
437 | |
438 | // Handle the scalar and 1D vector cases. |
439 | if (!isa<LLVM::LLVMArrayType>(operandType)) { |
440 | rewriter.replaceOpWithNewOp<LLVM::FCmpOp>( |
441 | op, typeConverter->convertType(resultType), |
442 | convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()), |
443 | adaptor.getLhs(), adaptor.getRhs(), fmf); |
444 | return success(); |
445 | } |
446 | |
447 | if (!isa<VectorType>(Val: resultType)) |
448 | return rewriter.notifyMatchFailure(op, "expected vector result type" ); |
449 | |
450 | return LLVM::detail::handleMultidimensionalVectors( |
451 | op: op.getOperation(), operands: adaptor.getOperands(), typeConverter: *getTypeConverter(), |
452 | createOperand: [&](Type llvm1DVectorTy, ValueRange operands) { |
453 | OpAdaptor adaptor(operands); |
454 | return rewriter.create<LLVM::FCmpOp>( |
455 | op.getLoc(), llvm1DVectorTy, |
456 | convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()), |
457 | adaptor.getLhs(), adaptor.getRhs(), fmf); |
458 | }, |
459 | rewriter); |
460 | } |
461 | |
462 | //===----------------------------------------------------------------------===// |
463 | // Pass Definition |
464 | //===----------------------------------------------------------------------===// |
465 | |
466 | namespace { |
467 | struct ArithToLLVMConversionPass |
468 | : public impl::ArithToLLVMConversionPassBase<ArithToLLVMConversionPass> { |
469 | using Base::Base; |
470 | |
471 | void runOnOperation() override { |
472 | LLVMConversionTarget target(getContext()); |
473 | RewritePatternSet patterns(&getContext()); |
474 | |
475 | LowerToLLVMOptions options(&getContext()); |
476 | if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) |
477 | options.overrideIndexBitwidth(indexBitwidth); |
478 | |
479 | LLVMTypeConverter converter(&getContext(), options); |
480 | mlir::arith::populateArithToLLVMConversionPatterns(converter, patterns); |
481 | |
482 | if (failed(applyPartialConversion(getOperation(), target, |
483 | std::move(patterns)))) |
484 | signalPassFailure(); |
485 | } |
486 | }; |
487 | } // namespace |
488 | |
489 | //===----------------------------------------------------------------------===// |
490 | // ConvertToLLVMPatternInterface implementation |
491 | //===----------------------------------------------------------------------===// |
492 | |
493 | namespace { |
494 | /// Implement the interface to convert MemRef to LLVM. |
495 | struct ArithToLLVMDialectInterface : public ConvertToLLVMPatternInterface { |
496 | using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; |
497 | void loadDependentDialects(MLIRContext *context) const final { |
498 | context->loadDialect<LLVM::LLVMDialect>(); |
499 | } |
500 | |
501 | /// Hook for derived dialect interface to provide conversion patterns |
502 | /// and mark dialect legal for the conversion target. |
503 | void populateConvertToLLVMConversionPatterns( |
504 | ConversionTarget &target, LLVMTypeConverter &typeConverter, |
505 | RewritePatternSet &patterns) const final { |
506 | arith::populateArithToLLVMConversionPatterns(converter&: typeConverter, patterns); |
507 | } |
508 | }; |
509 | } // namespace |
510 | |
511 | void mlir::arith::registerConvertArithToLLVMInterface( |
512 | DialectRegistry ®istry) { |
513 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, arith::ArithDialect *dialect) { |
514 | dialect->addInterfaces<ArithToLLVMDialectInterface>(); |
515 | }); |
516 | } |
517 | |
518 | //===----------------------------------------------------------------------===// |
519 | // Pattern Population |
520 | //===----------------------------------------------------------------------===// |
521 | |
522 | void mlir::arith::populateArithToLLVMConversionPatterns( |
523 | LLVMTypeConverter &converter, RewritePatternSet &patterns) { |
524 | // clang-format off |
525 | patterns.add< |
526 | AddFOpLowering, |
527 | AddIOpLowering, |
528 | AndIOpLowering, |
529 | AddUIExtendedOpLowering, |
530 | BitcastOpLowering, |
531 | ConstantOpLowering, |
532 | CmpFOpLowering, |
533 | CmpIOpLowering, |
534 | DivFOpLowering, |
535 | DivSIOpLowering, |
536 | DivUIOpLowering, |
537 | ExtFOpLowering, |
538 | ExtSIOpLowering, |
539 | ExtUIOpLowering, |
540 | FPToSIOpLowering, |
541 | FPToUIOpLowering, |
542 | IndexCastOpSILowering, |
543 | IndexCastOpUILowering, |
544 | MaximumFOpLowering, |
545 | MaxNumFOpLowering, |
546 | MaxSIOpLowering, |
547 | MaxUIOpLowering, |
548 | MinimumFOpLowering, |
549 | MinNumFOpLowering, |
550 | MinSIOpLowering, |
551 | MinUIOpLowering, |
552 | MulFOpLowering, |
553 | MulIOpLowering, |
554 | MulSIExtendedOpLowering, |
555 | MulUIExtendedOpLowering, |
556 | NegFOpLowering, |
557 | OrIOpLowering, |
558 | RemFOpLowering, |
559 | RemSIOpLowering, |
560 | RemUIOpLowering, |
561 | SelectOpLowering, |
562 | ShLIOpLowering, |
563 | ShRSIOpLowering, |
564 | ShRUIOpLowering, |
565 | SIToFPOpLowering, |
566 | SubFOpLowering, |
567 | SubIOpLowering, |
568 | TruncFOpLowering, |
569 | ConstrainedTruncFOpLowering, |
570 | TruncIOpLowering, |
571 | UIToFPOpLowering, |
572 | XOrIOpLowering |
573 | >(converter); |
574 | // clang-format on |
575 | } |
576 | |