| 1 | //===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" |
| 10 | #include "mlir/Conversion/LLVMCommon/Pattern.h" |
| 11 | #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" |
| 12 | #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" |
| 13 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 14 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| 15 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
| 16 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| 17 | #include "mlir/IR/BuiltinOps.h" |
| 18 | #include "mlir/IR/PatternMatch.h" |
| 19 | |
| 20 | using namespace mlir; |
| 21 | using namespace mlir::arm_sve; |
| 22 | |
| 23 | using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>; |
| 24 | using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>; |
| 25 | using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>; |
| 26 | using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>; |
| 27 | using UsmmlaOpLowering = OneToOneConvertToLLVMPattern<UsmmlaOp, UsmmlaIntrOp>; |
| 28 | using DupQLaneLowering = |
| 29 | OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>; |
| 30 | using ScalableMaskedAddIOpLowering = |
| 31 | OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp, |
| 32 | ScalableMaskedAddIIntrOp>; |
| 33 | using ScalableMaskedAddFOpLowering = |
| 34 | OneToOneConvertToLLVMPattern<ScalableMaskedAddFOp, |
| 35 | ScalableMaskedAddFIntrOp>; |
| 36 | using ScalableMaskedSubIOpLowering = |
| 37 | OneToOneConvertToLLVMPattern<ScalableMaskedSubIOp, |
| 38 | ScalableMaskedSubIIntrOp>; |
| 39 | using ScalableMaskedSubFOpLowering = |
| 40 | OneToOneConvertToLLVMPattern<ScalableMaskedSubFOp, |
| 41 | ScalableMaskedSubFIntrOp>; |
| 42 | using ScalableMaskedMulIOpLowering = |
| 43 | OneToOneConvertToLLVMPattern<ScalableMaskedMulIOp, |
| 44 | ScalableMaskedMulIIntrOp>; |
| 45 | using ScalableMaskedMulFOpLowering = |
| 46 | OneToOneConvertToLLVMPattern<ScalableMaskedMulFOp, |
| 47 | ScalableMaskedMulFIntrOp>; |
| 48 | using ScalableMaskedSDivIOpLowering = |
| 49 | OneToOneConvertToLLVMPattern<ScalableMaskedSDivIOp, |
| 50 | ScalableMaskedSDivIIntrOp>; |
| 51 | using ScalableMaskedUDivIOpLowering = |
| 52 | OneToOneConvertToLLVMPattern<ScalableMaskedUDivIOp, |
| 53 | ScalableMaskedUDivIIntrOp>; |
| 54 | using ScalableMaskedDivFOpLowering = |
| 55 | OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp, |
| 56 | ScalableMaskedDivFIntrOp>; |
| 57 | |
| 58 | namespace { |
| 59 | |
| 60 | /// Unrolls a conversion to/from equivalent vector types, to allow using a |
| 61 | /// conversion intrinsic that only supports 1-D vector types. |
| 62 | /// |
| 63 | /// Example: |
| 64 | /// ``` |
| 65 | /// %result = arm_sve.convert_to_svbool %source : vector<2x[4]xi1> |
| 66 | /// ``` |
| 67 | /// is rewritten into: |
| 68 | /// ``` |
| 69 | /// %cst = arith.constant dense<false> : vector<2x[16]xi1> |
| 70 | /// %1 = vector.extract %source[0] : vector<[4]xi1> from vector<2x[4]xi1> |
| 71 | /// %2 = "arm_sve.intr.convert.to.svbool"(%1) |
| 72 | /// : (vector<[4]xi1>) -> vector<[16]xi1> |
| 73 | /// %3 = vector.insert %2, %cst[0] : vector<[16]xi1> into vector<2x[16]xi1> |
| 74 | /// %4 = vector.extract %source[1] : vector<[4]xi1> from vector<2x[4]xi1> |
| 75 | /// %5 = "arm_sve.intr.convert.to.svbool"(%4) |
| 76 | /// : (vector<[4]xi1>) -> vector<[16]xi1> |
| 77 | /// %result = vector.insert %5, %3[1] : vector<[16]xi1> into vector<2x[16]xi1> |
| 78 | /// ``` |
| 79 | template <typename Op, typename IntrOp> |
| 80 | struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern<Op> { |
| 81 | using ConvertOpToLLVMPattern<Op>::ConvertOpToLLVMPattern; |
| 82 | |
| 83 | LogicalResult |
| 84 | matchAndRewrite(Op convertOp, typename Op::Adaptor, |
| 85 | ConversionPatternRewriter &rewriter) const override { |
| 86 | auto loc = convertOp.getLoc(); |
| 87 | |
| 88 | auto source = convertOp.getSource(); |
| 89 | VectorType sourceType = source.getType(); |
| 90 | VectorType resultType = convertOp.getResult().getType(); |
| 91 | |
| 92 | Value result = rewriter.create<arith::ConstantOp>( |
| 93 | loc, resultType, rewriter.getZeroAttr(resultType)); |
| 94 | |
| 95 | // We want to iterate over the input vector in steps of the trailing |
| 96 | // dimension. So this creates tile shape where all leading dimensions are 1, |
| 97 | // and the trailing dimension step is the size of the dimension. |
| 98 | SmallVector<int64_t> tileShape(sourceType.getRank(), 1); |
| 99 | tileShape.back() = sourceType.getShape().back(); |
| 100 | |
| 101 | // Iterate over all scalable mask/predicate slices of the source vector. |
| 102 | for (SmallVector<int64_t> index : |
| 103 | StaticTileOffsetRange(sourceType.getShape(), tileShape)) { |
| 104 | auto extractOrInsertPosition = ArrayRef(index).drop_back(); |
| 105 | auto sourceVector = rewriter.create<vector::ExtractOp>( |
| 106 | loc, source, extractOrInsertPosition); |
| 107 | VectorType convertedType = |
| 108 | VectorType::Builder(llvm::cast<VectorType>(sourceVector.getType())) |
| 109 | .setDim(0, resultType.getShape().back()); |
| 110 | auto convertedVector = |
| 111 | rewriter.create<IntrOp>(loc, TypeRange{convertedType}, sourceVector); |
| 112 | result = rewriter.create<vector::InsertOp>(loc, convertedVector, result, |
| 113 | extractOrInsertPosition); |
| 114 | } |
| 115 | |
| 116 | rewriter.replaceOp(convertOp, result); |
| 117 | return success(); |
| 118 | } |
| 119 | }; |
| 120 | |
| 121 | using ConvertToSvboolOpLowering = |
| 122 | SvboolConversionOpLowering<ConvertToSvboolOp, ConvertToSvboolIntrOp>; |
| 123 | |
| 124 | using ConvertFromSvboolOpLowering = |
| 125 | SvboolConversionOpLowering<ConvertFromSvboolOp, ConvertFromSvboolIntrOp>; |
| 126 | |
| 127 | using ZipX2OpLowering = OneToOneConvertToLLVMPattern<ZipX2Op, ZipX2IntrOp>; |
| 128 | using ZipX4OpLowering = OneToOneConvertToLLVMPattern<ZipX4Op, ZipX4IntrOp>; |
| 129 | |
| 130 | /// Lower `arm_sve.psel` to LLVM intrinsics. This is almost a 1-to-1 conversion |
| 131 | /// but first input (P1) and result predicates need conversion to/from svbool. |
| 132 | struct PselOpLowering : public ConvertOpToLLVMPattern<PselOp> { |
| 133 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
| 134 | |
| 135 | LogicalResult |
| 136 | matchAndRewrite(PselOp pselOp, PselOp::Adaptor adaptor, |
| 137 | ConversionPatternRewriter &rewriter) const override { |
| 138 | auto svboolType = VectorType::get(16, rewriter.getI1Type(), true); |
| 139 | auto loc = pselOp.getLoc(); |
| 140 | auto svboolP1 = rewriter.create<ConvertToSvboolIntrOp>(loc, svboolType, |
| 141 | adaptor.getP1()); |
| 142 | auto indexI32 = rewriter.create<arith::IndexCastOp>( |
| 143 | loc, rewriter.getI32Type(), pselOp.getIndex()); |
| 144 | auto pselIntr = rewriter.create<PselIntrOp>(loc, svboolType, svboolP1, |
| 145 | pselOp.getP2(), indexI32); |
| 146 | rewriter.replaceOpWithNewOp<ConvertFromSvboolIntrOp>( |
| 147 | pselOp, adaptor.getP1().getType(), pselIntr); |
| 148 | return success(); |
| 149 | } |
| 150 | }; |
| 151 | |
| 152 | /// Converts `vector.create_mask` ops that match the size of an SVE predicate |
| 153 | /// to the `whilelt` intrinsic. This produces more canonical codegen than the |
| 154 | /// generic LLVM lowering, see https://github.com/llvm/llvm-project/issues/81840 |
| 155 | /// for more details. Note that we can't use (the more general) active.lane.mask |
| 156 | /// as its semantics don't neatly map on to `vector.create_mask`, as it does an |
| 157 | /// unsigned comparison (whereas `create_mask` is signed), and is UB/posion if |
| 158 | /// `n` is zero (whereas `create_mask` just returns an all-false mask). |
| 159 | struct CreateMaskOpLowering |
| 160 | : public ConvertOpToLLVMPattern<vector::CreateMaskOp> { |
| 161 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
| 162 | |
| 163 | LogicalResult |
| 164 | matchAndRewrite(vector::CreateMaskOp createMaskOp, |
| 165 | vector::CreateMaskOp::Adaptor adaptor, |
| 166 | ConversionPatternRewriter &rewriter) const override { |
| 167 | auto maskType = createMaskOp.getVectorType(); |
| 168 | if (maskType.getRank() != 1 || !maskType.isScalable()) |
| 169 | return rewriter.notifyMatchFailure(createMaskOp, "not 1-D and scalable" ); |
| 170 | |
| 171 | // TODO: Support masks which are multiples of SVE predicates. |
| 172 | auto maskBaseSize = maskType.getDimSize(0); |
| 173 | if (maskBaseSize < 2 || maskBaseSize > 16 || |
| 174 | !llvm::isPowerOf2_32(Value: uint32_t(maskBaseSize))) |
| 175 | return rewriter.notifyMatchFailure(createMaskOp, |
| 176 | "not SVE predicate-sized" ); |
| 177 | |
| 178 | auto loc = createMaskOp.getLoc(); |
| 179 | auto zero = rewriter.create<LLVM::ZeroOp>(loc, rewriter.getI64Type()); |
| 180 | rewriter.replaceOpWithNewOp<WhileLTIntrOp>(createMaskOp, maskType, zero, |
| 181 | adaptor.getOperands()[0]); |
| 182 | return success(); |
| 183 | } |
| 184 | }; |
| 185 | |
| 186 | } // namespace |
| 187 | |
| 188 | /// Populate the given list with patterns that convert from ArmSVE to LLVM. |
| 189 | void mlir::populateArmSVELegalizeForLLVMExportPatterns( |
| 190 | const LLVMTypeConverter &converter, RewritePatternSet &patterns) { |
| 191 | // Populate conversion patterns |
| 192 | |
| 193 | // clang-format off |
| 194 | patterns.add<ConvertFromSvboolOpLowering, |
| 195 | ConvertToSvboolOpLowering, |
| 196 | DupQLaneLowering, |
| 197 | PselOpLowering, |
| 198 | ScalableMaskedAddFOpLowering, |
| 199 | ScalableMaskedAddIOpLowering, |
| 200 | ScalableMaskedDivFOpLowering, |
| 201 | ScalableMaskedMulFOpLowering, |
| 202 | ScalableMaskedMulIOpLowering, |
| 203 | ScalableMaskedSDivIOpLowering, |
| 204 | ScalableMaskedSubFOpLowering, |
| 205 | ScalableMaskedSubIOpLowering, |
| 206 | ScalableMaskedUDivIOpLowering, |
| 207 | SmmlaOpLowering, |
| 208 | UdotOpLowering, |
| 209 | UmmlaOpLowering, |
| 210 | UsmmlaOpLowering, |
| 211 | ZipX2OpLowering, |
| 212 | ZipX4OpLowering, |
| 213 | SdotOpLowering>(converter); |
| 214 | // Add vector.create_mask conversion with a high benefit as it produces much |
| 215 | // nicer code than the generic lowering. |
| 216 | patterns.add<CreateMaskOpLowering>(arg: converter, /*benefit=*/args: 4096); |
| 217 | // clang-format on |
| 218 | } |
| 219 | |
| 220 | void mlir::configureArmSVELegalizeForExportTarget( |
| 221 | LLVMConversionTarget &target) { |
| 222 | // clang-format off |
| 223 | target.addLegalOp<ConvertFromSvboolIntrOp, |
| 224 | ConvertToSvboolIntrOp, |
| 225 | DupQLaneIntrOp, |
| 226 | PselIntrOp, |
| 227 | ScalableMaskedAddFIntrOp, |
| 228 | ScalableMaskedAddIIntrOp, |
| 229 | ScalableMaskedDivFIntrOp, |
| 230 | ScalableMaskedMulFIntrOp, |
| 231 | ScalableMaskedMulIIntrOp, |
| 232 | ScalableMaskedSDivIIntrOp, |
| 233 | ScalableMaskedSubFIntrOp, |
| 234 | ScalableMaskedSubIIntrOp, |
| 235 | ScalableMaskedUDivIIntrOp, |
| 236 | SmmlaIntrOp, |
| 237 | UdotIntrOp, |
| 238 | UmmlaIntrOp, |
| 239 | UsmmlaIntrOp, |
| 240 | WhileLTIntrOp, |
| 241 | ZipX2IntrOp, |
| 242 | ZipX4IntrOp, |
| 243 | SdotIntrOp>(); |
| 244 | target.addIllegalOp<ConvertFromSvboolOp, |
| 245 | ConvertToSvboolOp, |
| 246 | DupQLaneOp, |
| 247 | PselOp, |
| 248 | ScalableMaskedAddFOp, |
| 249 | ScalableMaskedAddIOp, |
| 250 | ScalableMaskedDivFOp, |
| 251 | ScalableMaskedMulFOp, |
| 252 | ScalableMaskedMulIOp, |
| 253 | ScalableMaskedSDivIOp, |
| 254 | ScalableMaskedSubFOp, |
| 255 | ScalableMaskedSubIOp, |
| 256 | ScalableMaskedUDivIOp, |
| 257 | SmmlaOp, |
| 258 | UdotOp, |
| 259 | UmmlaOp, |
| 260 | UsmmlaOp, |
| 261 | ZipX2Op, |
| 262 | ZipX4Op, |
| 263 | SdotOp>(); |
| 264 | // clang-format on |
| 265 | } |
| 266 | |