| 1 | //===- ArithToLLVM.cpp - Arithmetic to LLVM dialect conversion -------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" |
| 10 | |
| 11 | #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h" |
| 12 | #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" |
| 13 | #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" |
| 14 | #include "mlir/Conversion/LLVMCommon/VectorPattern.h" |
| 15 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 16 | #include "mlir/Dialect/Arith/Transforms/Passes.h" |
| 17 | #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" |
| 18 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| 19 | #include "mlir/IR/TypeUtilities.h" |
| 20 | #include "mlir/Pass/Pass.h" |
| 21 | #include <type_traits> |
| 22 | |
| 23 | namespace mlir { |
| 24 | #define GEN_PASS_DEF_ARITHTOLLVMCONVERSIONPASS |
| 25 | #include "mlir/Conversion/Passes.h.inc" |
| 26 | } // namespace mlir |
| 27 | |
| 28 | using namespace mlir; |
| 29 | |
| 30 | namespace { |
| 31 | |
| 32 | /// Operations whose conversion will depend on whether they are passed a |
| 33 | /// rounding mode attribute or not. |
| 34 | /// |
| 35 | /// `SourceOp` is the source operation; `TargetOp`, the operation it will lower |
| 36 | /// to; `AttrConvert` is the attribute conversion to convert the rounding mode |
| 37 | /// attribute. |
| 38 | template <typename SourceOp, typename TargetOp, bool Constrained, |
| 39 | template <typename, typename> typename AttrConvert = |
| 40 | AttrConvertPassThrough> |
| 41 | struct ConstrainedVectorConvertToLLVMPattern |
| 42 | : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert> { |
| 43 | using VectorConvertToLLVMPattern<SourceOp, TargetOp, |
| 44 | AttrConvert>::VectorConvertToLLVMPattern; |
| 45 | |
| 46 | LogicalResult |
| 47 | matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, |
| 48 | ConversionPatternRewriter &rewriter) const override { |
| 49 | if (Constrained != static_cast<bool>(op.getRoundingModeAttr())) |
| 50 | return failure(); |
| 51 | return VectorConvertToLLVMPattern<SourceOp, TargetOp, |
| 52 | AttrConvert>::matchAndRewrite(op, adaptor, |
| 53 | rewriter); |
| 54 | } |
| 55 | }; |
| 56 | |
| 57 | /// No-op bitcast. Propagate type input arg if converted source and dest types |
| 58 | /// are the same. |
| 59 | struct IdentityBitcastLowering final |
| 60 | : public OpConversionPattern<arith::BitcastOp> { |
| 61 | using OpConversionPattern::OpConversionPattern; |
| 62 | |
| 63 | LogicalResult |
| 64 | matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor, |
| 65 | ConversionPatternRewriter &rewriter) const final { |
| 66 | Value src = adaptor.getIn(); |
| 67 | Type resultType = getTypeConverter()->convertType(op.getType()); |
| 68 | if (src.getType() != resultType) |
| 69 | return rewriter.notifyMatchFailure(op, "Types are different" ); |
| 70 | |
| 71 | rewriter.replaceOp(op, src); |
| 72 | return success(); |
| 73 | } |
| 74 | }; |
| 75 | |
| 76 | //===----------------------------------------------------------------------===// |
| 77 | // Straightforward Op Lowerings |
| 78 | //===----------------------------------------------------------------------===// |
| 79 | |
| 80 | using AddFOpLowering = |
| 81 | VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp, |
| 82 | arith::AttrConvertFastMathToLLVM>; |
| 83 | using AddIOpLowering = |
| 84 | VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp, |
| 85 | arith::AttrConvertOverflowToLLVM>; |
| 86 | using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>; |
| 87 | using BitcastOpLowering = |
| 88 | VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>; |
| 89 | using DivFOpLowering = |
| 90 | VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp, |
| 91 | arith::AttrConvertFastMathToLLVM>; |
| 92 | using DivSIOpLowering = |
| 93 | VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>; |
| 94 | using DivUIOpLowering = |
| 95 | VectorConvertToLLVMPattern<arith::DivUIOp, LLVM::UDivOp>; |
| 96 | using ExtFOpLowering = VectorConvertToLLVMPattern<arith::ExtFOp, LLVM::FPExtOp>; |
| 97 | using ExtSIOpLowering = |
| 98 | VectorConvertToLLVMPattern<arith::ExtSIOp, LLVM::SExtOp>; |
| 99 | using ExtUIOpLowering = |
| 100 | VectorConvertToLLVMPattern<arith::ExtUIOp, LLVM::ZExtOp>; |
| 101 | using FPToSIOpLowering = |
| 102 | VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp>; |
| 103 | using FPToUIOpLowering = |
| 104 | VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>; |
| 105 | using MaximumFOpLowering = |
| 106 | VectorConvertToLLVMPattern<arith::MaximumFOp, LLVM::MaximumOp, |
| 107 | arith::AttrConvertFastMathToLLVM>; |
| 108 | using MaxNumFOpLowering = |
| 109 | VectorConvertToLLVMPattern<arith::MaxNumFOp, LLVM::MaxNumOp, |
| 110 | arith::AttrConvertFastMathToLLVM>; |
| 111 | using MaxSIOpLowering = |
| 112 | VectorConvertToLLVMPattern<arith::MaxSIOp, LLVM::SMaxOp>; |
| 113 | using MaxUIOpLowering = |
| 114 | VectorConvertToLLVMPattern<arith::MaxUIOp, LLVM::UMaxOp>; |
| 115 | using MinimumFOpLowering = |
| 116 | VectorConvertToLLVMPattern<arith::MinimumFOp, LLVM::MinimumOp, |
| 117 | arith::AttrConvertFastMathToLLVM>; |
| 118 | using MinNumFOpLowering = |
| 119 | VectorConvertToLLVMPattern<arith::MinNumFOp, LLVM::MinNumOp, |
| 120 | arith::AttrConvertFastMathToLLVM>; |
| 121 | using MinSIOpLowering = |
| 122 | VectorConvertToLLVMPattern<arith::MinSIOp, LLVM::SMinOp>; |
| 123 | using MinUIOpLowering = |
| 124 | VectorConvertToLLVMPattern<arith::MinUIOp, LLVM::UMinOp>; |
| 125 | using MulFOpLowering = |
| 126 | VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp, |
| 127 | arith::AttrConvertFastMathToLLVM>; |
| 128 | using MulIOpLowering = |
| 129 | VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp, |
| 130 | arith::AttrConvertOverflowToLLVM>; |
| 131 | using NegFOpLowering = |
| 132 | VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp, |
| 133 | arith::AttrConvertFastMathToLLVM>; |
| 134 | using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>; |
| 135 | using RemFOpLowering = |
| 136 | VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp, |
| 137 | arith::AttrConvertFastMathToLLVM>; |
| 138 | using RemSIOpLowering = |
| 139 | VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>; |
| 140 | using RemUIOpLowering = |
| 141 | VectorConvertToLLVMPattern<arith::RemUIOp, LLVM::URemOp>; |
| 142 | using SelectOpLowering = |
| 143 | VectorConvertToLLVMPattern<arith::SelectOp, LLVM::SelectOp>; |
| 144 | using ShLIOpLowering = |
| 145 | VectorConvertToLLVMPattern<arith::ShLIOp, LLVM::ShlOp, |
| 146 | arith::AttrConvertOverflowToLLVM>; |
| 147 | using ShRSIOpLowering = |
| 148 | VectorConvertToLLVMPattern<arith::ShRSIOp, LLVM::AShrOp>; |
| 149 | using ShRUIOpLowering = |
| 150 | VectorConvertToLLVMPattern<arith::ShRUIOp, LLVM::LShrOp>; |
| 151 | using SIToFPOpLowering = |
| 152 | VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>; |
| 153 | using SubFOpLowering = |
| 154 | VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp, |
| 155 | arith::AttrConvertFastMathToLLVM>; |
| 156 | using SubIOpLowering = |
| 157 | VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp, |
| 158 | arith::AttrConvertOverflowToLLVM>; |
| 159 | using TruncFOpLowering = |
| 160 | ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp, |
| 161 | false>; |
| 162 | using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern< |
| 163 | arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, true, |
| 164 | arith::AttrConverterConstrainedFPToLLVM>; |
| 165 | using TruncIOpLowering = |
| 166 | VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp>; |
| 167 | using UIToFPOpLowering = |
| 168 | VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp>; |
| 169 | using XOrIOpLowering = VectorConvertToLLVMPattern<arith::XOrIOp, LLVM::XOrOp>; |
| 170 | |
| 171 | //===----------------------------------------------------------------------===// |
| 172 | // Op Lowering Patterns |
| 173 | //===----------------------------------------------------------------------===// |
| 174 | |
| 175 | /// Directly lower to LLVM op. |
| 176 | struct ConstantOpLowering : public ConvertOpToLLVMPattern<arith::ConstantOp> { |
| 177 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
| 178 | |
| 179 | LogicalResult |
| 180 | matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, |
| 181 | ConversionPatternRewriter &rewriter) const override; |
| 182 | }; |
| 183 | |
| 184 | /// The lowering of index_cast becomes an integer conversion since index |
| 185 | /// becomes an integer. If the bit width of the source and target integer |
| 186 | /// types is the same, just erase the cast. If the target type is wider, |
| 187 | /// sign-extend the value, otherwise truncate it. |
| 188 | template <typename OpTy, typename ExtCastTy> |
| 189 | struct IndexCastOpLowering : public ConvertOpToLLVMPattern<OpTy> { |
| 190 | using ConvertOpToLLVMPattern<OpTy>::ConvertOpToLLVMPattern; |
| 191 | |
| 192 | LogicalResult |
| 193 | matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, |
| 194 | ConversionPatternRewriter &rewriter) const override; |
| 195 | }; |
| 196 | |
| 197 | using IndexCastOpSILowering = |
| 198 | IndexCastOpLowering<arith::IndexCastOp, LLVM::SExtOp>; |
| 199 | using IndexCastOpUILowering = |
| 200 | IndexCastOpLowering<arith::IndexCastUIOp, LLVM::ZExtOp>; |
| 201 | |
| 202 | struct AddUIExtendedOpLowering |
| 203 | : public ConvertOpToLLVMPattern<arith::AddUIExtendedOp> { |
| 204 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
| 205 | |
| 206 | LogicalResult |
| 207 | matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor, |
| 208 | ConversionPatternRewriter &rewriter) const override; |
| 209 | }; |
| 210 | |
| 211 | template <typename ArithMulOp, bool IsSigned> |
| 212 | struct MulIExtendedOpLowering : public ConvertOpToLLVMPattern<ArithMulOp> { |
| 213 | using ConvertOpToLLVMPattern<ArithMulOp>::ConvertOpToLLVMPattern; |
| 214 | |
| 215 | LogicalResult |
| 216 | matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor, |
| 217 | ConversionPatternRewriter &rewriter) const override; |
| 218 | }; |
| 219 | |
| 220 | using MulSIExtendedOpLowering = |
| 221 | MulIExtendedOpLowering<arith::MulSIExtendedOp, true>; |
| 222 | using MulUIExtendedOpLowering = |
| 223 | MulIExtendedOpLowering<arith::MulUIExtendedOp, false>; |
| 224 | |
| 225 | struct CmpIOpLowering : public ConvertOpToLLVMPattern<arith::CmpIOp> { |
| 226 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
| 227 | |
| 228 | LogicalResult |
| 229 | matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, |
| 230 | ConversionPatternRewriter &rewriter) const override; |
| 231 | }; |
| 232 | |
| 233 | struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> { |
| 234 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
| 235 | |
| 236 | LogicalResult |
| 237 | matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, |
| 238 | ConversionPatternRewriter &rewriter) const override; |
| 239 | }; |
| 240 | |
| 241 | } // namespace |
| 242 | |
| 243 | //===----------------------------------------------------------------------===// |
| 244 | // ConstantOpLowering |
| 245 | //===----------------------------------------------------------------------===// |
| 246 | |
| 247 | LogicalResult |
| 248 | ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, |
| 249 | ConversionPatternRewriter &rewriter) const { |
| 250 | return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(), |
| 251 | adaptor.getOperands(), op->getAttrs(), |
| 252 | *getTypeConverter(), rewriter); |
| 253 | } |
| 254 | |
| 255 | //===----------------------------------------------------------------------===// |
| 256 | // IndexCastOpLowering |
| 257 | //===----------------------------------------------------------------------===// |
| 258 | |
| 259 | template <typename OpTy, typename ExtCastTy> |
| 260 | LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite( |
| 261 | OpTy op, typename OpTy::Adaptor adaptor, |
| 262 | ConversionPatternRewriter &rewriter) const { |
| 263 | Type resultType = op.getResult().getType(); |
| 264 | Type targetElementType = |
| 265 | this->typeConverter->convertType(getElementTypeOrSelf(type: resultType)); |
| 266 | Type sourceElementType = |
| 267 | this->typeConverter->convertType(getElementTypeOrSelf(op.getIn())); |
| 268 | unsigned targetBits = targetElementType.getIntOrFloatBitWidth(); |
| 269 | unsigned sourceBits = sourceElementType.getIntOrFloatBitWidth(); |
| 270 | |
| 271 | if (targetBits == sourceBits) { |
| 272 | rewriter.replaceOp(op, adaptor.getIn()); |
| 273 | return success(); |
| 274 | } |
| 275 | |
| 276 | // Handle the scalar and 1D vector cases. |
| 277 | Type operandType = adaptor.getIn().getType(); |
| 278 | if (!isa<LLVM::LLVMArrayType>(operandType)) { |
| 279 | Type targetType = this->typeConverter->convertType(resultType); |
| 280 | if (targetBits < sourceBits) |
| 281 | rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType, |
| 282 | adaptor.getIn()); |
| 283 | else |
| 284 | rewriter.replaceOpWithNewOp<ExtCastTy>(op, targetType, adaptor.getIn()); |
| 285 | return success(); |
| 286 | } |
| 287 | |
| 288 | if (!isa<VectorType>(Val: resultType)) |
| 289 | return rewriter.notifyMatchFailure(op, "expected vector result type" ); |
| 290 | |
| 291 | return LLVM::detail::handleMultidimensionalVectors( |
| 292 | op: op.getOperation(), operands: adaptor.getOperands(), typeConverter: *(this->getTypeConverter()), |
| 293 | createOperand: [&](Type llvm1DVectorTy, ValueRange operands) -> Value { |
| 294 | typename OpTy::Adaptor adaptor(operands); |
| 295 | if (targetBits < sourceBits) { |
| 296 | return rewriter.create<LLVM::TruncOp>(op.getLoc(), llvm1DVectorTy, |
| 297 | adaptor.getIn()); |
| 298 | } |
| 299 | return rewriter.create<ExtCastTy>(op.getLoc(), llvm1DVectorTy, |
| 300 | adaptor.getIn()); |
| 301 | }, |
| 302 | rewriter); |
| 303 | } |
| 304 | |
| 305 | //===----------------------------------------------------------------------===// |
| 306 | // AddUIExtendedOpLowering |
| 307 | //===----------------------------------------------------------------------===// |
| 308 | |
| 309 | LogicalResult AddUIExtendedOpLowering::matchAndRewrite( |
| 310 | arith::AddUIExtendedOp op, OpAdaptor adaptor, |
| 311 | ConversionPatternRewriter &rewriter) const { |
| 312 | Type operandType = adaptor.getLhs().getType(); |
| 313 | Type sumResultType = op.getSum().getType(); |
| 314 | Type overflowResultType = op.getOverflow().getType(); |
| 315 | |
| 316 | if (!LLVM::isCompatibleType(type: operandType)) |
| 317 | return failure(); |
| 318 | |
| 319 | MLIRContext *ctx = rewriter.getContext(); |
| 320 | Location loc = op.getLoc(); |
| 321 | |
| 322 | // Handle the scalar and 1D vector cases. |
| 323 | if (!isa<LLVM::LLVMArrayType>(operandType)) { |
| 324 | Type newOverflowType = typeConverter->convertType(overflowResultType); |
| 325 | Type structType = |
| 326 | LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType}); |
| 327 | Value addOverflow = rewriter.create<LLVM::UAddWithOverflowOp>( |
| 328 | loc, structType, adaptor.getLhs(), adaptor.getRhs()); |
| 329 | Value = |
| 330 | rewriter.create<LLVM::ExtractValueOp>(loc, addOverflow, 0); |
| 331 | Value = |
| 332 | rewriter.create<LLVM::ExtractValueOp>(loc, addOverflow, 1); |
| 333 | rewriter.replaceOp(op, {sumExtracted, overflowExtracted}); |
| 334 | return success(); |
| 335 | } |
| 336 | |
| 337 | if (!isa<VectorType>(Val: sumResultType)) |
| 338 | return rewriter.notifyMatchFailure(arg&: loc, msg: "expected vector result types" ); |
| 339 | |
| 340 | return rewriter.notifyMatchFailure(arg&: loc, |
| 341 | msg: "ND vector types are not supported yet" ); |
| 342 | } |
| 343 | |
| 344 | //===----------------------------------------------------------------------===// |
| 345 | // MulIExtendedOpLowering |
| 346 | //===----------------------------------------------------------------------===// |
| 347 | |
| 348 | template <typename ArithMulOp, bool IsSigned> |
| 349 | LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite( |
| 350 | ArithMulOp op, typename ArithMulOp::Adaptor adaptor, |
| 351 | ConversionPatternRewriter &rewriter) const { |
| 352 | Type resultType = adaptor.getLhs().getType(); |
| 353 | |
| 354 | if (!LLVM::isCompatibleType(type: resultType)) |
| 355 | return failure(); |
| 356 | |
| 357 | Location loc = op.getLoc(); |
| 358 | |
| 359 | // Handle the scalar and 1D vector cases. Because LLVM does not have a |
| 360 | // matching extended multiplication intrinsic, perform regular multiplication |
| 361 | // on operands zero-extended to i(2*N) bits, and truncate the results back to |
| 362 | // iN types. |
| 363 | if (!isa<LLVM::LLVMArrayType>(resultType)) { |
| 364 | // Shift amount necessary to extract the high bits from widened result. |
| 365 | TypedAttr shiftValAttr; |
| 366 | |
| 367 | if (auto intTy = dyn_cast<IntegerType>(resultType)) { |
| 368 | unsigned resultBitwidth = intTy.getWidth(); |
| 369 | auto attrTy = rewriter.getIntegerType(resultBitwidth * 2); |
| 370 | shiftValAttr = rewriter.getIntegerAttr(attrTy, resultBitwidth); |
| 371 | } else { |
| 372 | auto vecTy = cast<VectorType>(resultType); |
| 373 | unsigned resultBitwidth = vecTy.getElementTypeBitWidth(); |
| 374 | auto attrTy = VectorType::get( |
| 375 | vecTy.getShape(), rewriter.getIntegerType(resultBitwidth * 2)); |
| 376 | shiftValAttr = SplatElementsAttr::get( |
| 377 | attrTy, APInt(resultBitwidth * 2, resultBitwidth)); |
| 378 | } |
| 379 | Type wideType = shiftValAttr.getType(); |
| 380 | assert(LLVM::isCompatibleType(wideType) && |
| 381 | "LLVM dialect should support all signless integer types" ); |
| 382 | |
| 383 | using LLVMExtOp = std::conditional_t<IsSigned, LLVM::SExtOp, LLVM::ZExtOp>; |
| 384 | Value lhsExt = rewriter.create<LLVMExtOp>(loc, wideType, adaptor.getLhs()); |
| 385 | Value rhsExt = rewriter.create<LLVMExtOp>(loc, wideType, adaptor.getRhs()); |
| 386 | Value mulExt = rewriter.create<LLVM::MulOp>(loc, wideType, lhsExt, rhsExt); |
| 387 | |
| 388 | // Split the 2*N-bit wide result into two N-bit values. |
| 389 | Value low = rewriter.create<LLVM::TruncOp>(loc, resultType, mulExt); |
| 390 | Value shiftVal = rewriter.create<LLVM::ConstantOp>(loc, shiftValAttr); |
| 391 | Value highExt = rewriter.create<LLVM::LShrOp>(loc, mulExt, shiftVal); |
| 392 | Value high = rewriter.create<LLVM::TruncOp>(loc, resultType, highExt); |
| 393 | |
| 394 | rewriter.replaceOp(op, {low, high}); |
| 395 | return success(); |
| 396 | } |
| 397 | |
| 398 | if (!isa<VectorType>(Val: resultType)) |
| 399 | return rewriter.notifyMatchFailure(op, "expected vector result type" ); |
| 400 | |
| 401 | return rewriter.notifyMatchFailure(op, |
| 402 | "ND vector types are not supported yet" ); |
| 403 | } |
| 404 | |
//===----------------------------------------------------------------------===//
// CmpIOpLowering
//===----------------------------------------------------------------------===//

/// Convert an arith.cmp predicate into the LLVM dialect CmpPredicate. The two
/// enums share numerical values, so the conversion is a plain cast.
///
/// `LLVMPredType` is the destination enum and must be given explicitly;
/// `PredType` is deduced from `pred`.
template <typename LLVMPredType, typename PredType>
static LLVMPredType convertCmpPredicate(PredType pred) {
  static_assert(std::is_enum<LLVMPredType>::value &&
                    std::is_enum<PredType>::value,
                "convertCmpPredicate only converts between enum types");
  return static_cast<LLVMPredType>(pred);
}
| 415 | |
| 416 | LogicalResult |
| 417 | CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, |
| 418 | ConversionPatternRewriter &rewriter) const { |
| 419 | Type operandType = adaptor.getLhs().getType(); |
| 420 | Type resultType = op.getResult().getType(); |
| 421 | |
| 422 | // Handle the scalar and 1D vector cases. |
| 423 | if (!isa<LLVM::LLVMArrayType>(operandType)) { |
| 424 | rewriter.replaceOpWithNewOp<LLVM::ICmpOp>( |
| 425 | op, typeConverter->convertType(resultType), |
| 426 | convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()), |
| 427 | adaptor.getLhs(), adaptor.getRhs()); |
| 428 | return success(); |
| 429 | } |
| 430 | |
| 431 | if (!isa<VectorType>(Val: resultType)) |
| 432 | return rewriter.notifyMatchFailure(op, "expected vector result type" ); |
| 433 | |
| 434 | return LLVM::detail::handleMultidimensionalVectors( |
| 435 | op: op.getOperation(), operands: adaptor.getOperands(), typeConverter: *getTypeConverter(), |
| 436 | createOperand: [&](Type llvm1DVectorTy, ValueRange operands) { |
| 437 | OpAdaptor adaptor(operands); |
| 438 | return rewriter.create<LLVM::ICmpOp>( |
| 439 | op.getLoc(), llvm1DVectorTy, |
| 440 | convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()), |
| 441 | adaptor.getLhs(), adaptor.getRhs()); |
| 442 | }, |
| 443 | rewriter); |
| 444 | } |
| 445 | |
| 446 | //===----------------------------------------------------------------------===// |
| 447 | // CmpFOpLowering |
| 448 | //===----------------------------------------------------------------------===// |
| 449 | |
| 450 | LogicalResult |
| 451 | CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, |
| 452 | ConversionPatternRewriter &rewriter) const { |
| 453 | Type operandType = adaptor.getLhs().getType(); |
| 454 | Type resultType = op.getResult().getType(); |
| 455 | LLVM::FastmathFlags fmf = |
| 456 | arith::convertArithFastMathFlagsToLLVM(op.getFastmath()); |
| 457 | |
| 458 | // Handle the scalar and 1D vector cases. |
| 459 | if (!isa<LLVM::LLVMArrayType>(operandType)) { |
| 460 | rewriter.replaceOpWithNewOp<LLVM::FCmpOp>( |
| 461 | op, typeConverter->convertType(resultType), |
| 462 | convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()), |
| 463 | adaptor.getLhs(), adaptor.getRhs(), fmf); |
| 464 | return success(); |
| 465 | } |
| 466 | |
| 467 | if (!isa<VectorType>(Val: resultType)) |
| 468 | return rewriter.notifyMatchFailure(op, "expected vector result type" ); |
| 469 | |
| 470 | return LLVM::detail::handleMultidimensionalVectors( |
| 471 | op: op.getOperation(), operands: adaptor.getOperands(), typeConverter: *getTypeConverter(), |
| 472 | createOperand: [&](Type llvm1DVectorTy, ValueRange operands) { |
| 473 | OpAdaptor adaptor(operands); |
| 474 | return rewriter.create<LLVM::FCmpOp>( |
| 475 | op.getLoc(), llvm1DVectorTy, |
| 476 | convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()), |
| 477 | adaptor.getLhs(), adaptor.getRhs(), fmf); |
| 478 | }, |
| 479 | rewriter); |
| 480 | } |
| 481 | |
| 482 | //===----------------------------------------------------------------------===// |
| 483 | // Pass Definition |
| 484 | //===----------------------------------------------------------------------===// |
| 485 | |
| 486 | namespace { |
| 487 | struct ArithToLLVMConversionPass |
| 488 | : public impl::ArithToLLVMConversionPassBase<ArithToLLVMConversionPass> { |
| 489 | using Base::Base; |
| 490 | |
| 491 | void runOnOperation() override { |
| 492 | LLVMConversionTarget target(getContext()); |
| 493 | RewritePatternSet patterns(&getContext()); |
| 494 | |
| 495 | LowerToLLVMOptions options(&getContext()); |
| 496 | if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) |
| 497 | options.overrideIndexBitwidth(indexBitwidth); |
| 498 | |
| 499 | LLVMTypeConverter converter(&getContext(), options); |
| 500 | arith::populateCeilFloorDivExpandOpsPatterns(patterns); |
| 501 | arith::populateArithToLLVMConversionPatterns(converter, patterns); |
| 502 | |
| 503 | if (failed(applyPartialConversion(getOperation(), target, |
| 504 | std::move(patterns)))) |
| 505 | signalPassFailure(); |
| 506 | } |
| 507 | }; |
| 508 | } // namespace |
| 509 | |
| 510 | //===----------------------------------------------------------------------===// |
| 511 | // ConvertToLLVMPatternInterface implementation |
| 512 | //===----------------------------------------------------------------------===// |
| 513 | |
| 514 | namespace { |
| 515 | /// Implement the interface to convert MemRef to LLVM. |
| 516 | struct ArithToLLVMDialectInterface : public ConvertToLLVMPatternInterface { |
| 517 | using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; |
| 518 | void loadDependentDialects(MLIRContext *context) const final { |
| 519 | context->loadDialect<LLVM::LLVMDialect>(); |
| 520 | } |
| 521 | |
| 522 | /// Hook for derived dialect interface to provide conversion patterns |
| 523 | /// and mark dialect legal for the conversion target. |
| 524 | void populateConvertToLLVMConversionPatterns( |
| 525 | ConversionTarget &target, LLVMTypeConverter &typeConverter, |
| 526 | RewritePatternSet &patterns) const final { |
| 527 | arith::populateCeilFloorDivExpandOpsPatterns(patterns); |
| 528 | arith::populateArithToLLVMConversionPatterns(converter: typeConverter, patterns); |
| 529 | } |
| 530 | }; |
| 531 | } // namespace |
| 532 | |
| 533 | void mlir::arith::registerConvertArithToLLVMInterface( |
| 534 | DialectRegistry ®istry) { |
| 535 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, arith::ArithDialect *dialect) { |
| 536 | dialect->addInterfaces<ArithToLLVMDialectInterface>(); |
| 537 | }); |
| 538 | } |
| 539 | |
| 540 | //===----------------------------------------------------------------------===// |
| 541 | // Pattern Population |
| 542 | //===----------------------------------------------------------------------===// |
| 543 | |
| 544 | void mlir::arith::populateArithToLLVMConversionPatterns( |
| 545 | const LLVMTypeConverter &converter, RewritePatternSet &patterns) { |
| 546 | |
| 547 | // Set a higher pattern benefit for IdentityBitcastLowering so it will run |
| 548 | // before BitcastOpLowering. |
| 549 | patterns.add<IdentityBitcastLowering>(arg: converter, args: patterns.getContext(), |
| 550 | /*patternBenefit*/ args: 10); |
| 551 | |
| 552 | // clang-format off |
| 553 | patterns.add< |
| 554 | AddFOpLowering, |
| 555 | AddIOpLowering, |
| 556 | AndIOpLowering, |
| 557 | AddUIExtendedOpLowering, |
| 558 | BitcastOpLowering, |
| 559 | ConstantOpLowering, |
| 560 | CmpFOpLowering, |
| 561 | CmpIOpLowering, |
| 562 | DivFOpLowering, |
| 563 | DivSIOpLowering, |
| 564 | DivUIOpLowering, |
| 565 | ExtFOpLowering, |
| 566 | ExtSIOpLowering, |
| 567 | ExtUIOpLowering, |
| 568 | FPToSIOpLowering, |
| 569 | FPToUIOpLowering, |
| 570 | IndexCastOpSILowering, |
| 571 | IndexCastOpUILowering, |
| 572 | MaximumFOpLowering, |
| 573 | MaxNumFOpLowering, |
| 574 | MaxSIOpLowering, |
| 575 | MaxUIOpLowering, |
| 576 | MinimumFOpLowering, |
| 577 | MinNumFOpLowering, |
| 578 | MinSIOpLowering, |
| 579 | MinUIOpLowering, |
| 580 | MulFOpLowering, |
| 581 | MulIOpLowering, |
| 582 | MulSIExtendedOpLowering, |
| 583 | MulUIExtendedOpLowering, |
| 584 | NegFOpLowering, |
| 585 | OrIOpLowering, |
| 586 | RemFOpLowering, |
| 587 | RemSIOpLowering, |
| 588 | RemUIOpLowering, |
| 589 | SelectOpLowering, |
| 590 | ShLIOpLowering, |
| 591 | ShRSIOpLowering, |
| 592 | ShRUIOpLowering, |
| 593 | SIToFPOpLowering, |
| 594 | SubFOpLowering, |
| 595 | SubIOpLowering, |
| 596 | TruncFOpLowering, |
| 597 | ConstrainedTruncFOpLowering, |
| 598 | TruncIOpLowering, |
| 599 | UIToFPOpLowering, |
| 600 | XOrIOpLowering |
| 601 | >(converter); |
| 602 | // clang-format on |
| 603 | } |
| 604 | |