| 1 | //===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements lowering patterns from vector.contract to |
| 10 | // arm_neon.intr.smmla |
| 11 | // |
| 12 | //===--- |
| 13 | |
| 14 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 15 | #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" |
| 16 | #include "mlir/Dialect/ArmNeon/Transforms.h" |
| 17 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 18 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| 19 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
| 20 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| 21 | #include "mlir/IR/AffineMap.h" |
| 22 | #include "mlir/IR/PatternMatch.h" |
| 23 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 24 | |
| 25 | #define DEBUG_TYPE "lower-contract-to-arm-neon" |
| 26 | |
| 27 | using namespace mlir; |
| 28 | using namespace mlir::arm_neon; |
| 29 | |
| 30 | namespace { |
| 31 | |
| 32 | /// Return the shaped type with new element type. |
| 33 | static Type matchContainerType(Type element, Type container) { |
| 34 | if (auto shapedTy = dyn_cast<ShapedType>(container)) { |
| 35 | return shapedTy.clone(element); |
| 36 | } |
| 37 | return element; |
| 38 | } |
| 39 | |
| 40 | /// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile |
| 41 | /// any vector.contract into multiple smmla instructions with unrolling so long |
| 42 | /// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM |
| 43 | /// = 1 (either explicitly or inferred if LHS has only dimK) If no unrolling is |
| 44 | /// necessary, a single smmla instruction is emitted. |
| 45 | class LowerContractionToSMMLAPattern |
| 46 | : public OpRewritePattern<vector::ContractionOp> { |
| 47 | public: |
| 48 | using OpRewritePattern::OpRewritePattern; |
| 49 | LogicalResult matchAndRewrite(vector::ContractionOp op, |
| 50 | PatternRewriter &rewriter) const override { |
| 51 | Location loc = op.getLoc(); |
| 52 | // Infer tile sizes from operands. For vecmat, LHS may only have 1 dim. |
| 53 | // Note: RHS is not transposed. |
| 54 | mlir::VectorType lhsType = op.getLhsType(); |
| 55 | mlir::VectorType rhsType = op.getRhsType(); |
| 56 | // Avoid 0-D vectors and 1-D rhs: |
| 57 | if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2) |
| 58 | return failure(); |
| 59 | // This codegen does not work for scalable vectors. Return failure so this |
| 60 | // pattern is not accidentally chosen over patterns that lower to ArmSVE. |
| 61 | if (lhsType.isScalable() || rhsType.isScalable()) |
| 62 | return failure(); |
| 63 | auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0); |
| 64 | auto dimN = rhsType.getDimSize(0); |
| 65 | auto dimK = rhsType.getDimSize(1); |
| 66 | bool isVecmat = dimM == 1 ? true : false; |
| 67 | if (lhsType.getDimSize(lhsType.getRank() - 1) != |
| 68 | rhsType.getDimSize(rhsType.getRank() - 1)) { |
| 69 | return failure(); // dimK mismatch |
| 70 | } |
| 71 | // Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs for |
| 72 | // tiling. |
| 73 | if ((dimM % 2 != 0 && !isVecmat) || dimN % 2 != 0 || dimK % 8 != 0) { |
| 74 | return failure(); |
| 75 | } |
| 76 | |
| 77 | // Check iterator types for contract. All iterators except inner-most |
| 78 | // dimension must be parallel. |
| 79 | auto iteratorTypes = op.getIteratorTypesArray(); |
| 80 | if (iteratorTypes.size() > 3 || iteratorTypes[iteratorTypes.size() - 1] != |
| 81 | vector::IteratorType::reduction) { |
| 82 | return failure(); |
| 83 | } |
| 84 | if (llvm::any_of(ArrayRef<vector::IteratorType>(iteratorTypes).drop_back(1), |
| 85 | [](vector::IteratorType iteratorType) { |
| 86 | return iteratorType != vector::IteratorType::parallel; |
| 87 | })) { |
| 88 | return failure(); |
| 89 | } |
| 90 | |
| 91 | // Check two extsi inputs Rhs Lhs for contract. |
| 92 | arith::ExtSIOp origLhsExtOp = |
| 93 | dyn_cast_or_null<arith::ExtSIOp>(op.getLhs().getDefiningOp()); |
| 94 | arith::ExtSIOp origRhsExtOp = |
| 95 | dyn_cast_or_null<arith::ExtSIOp>(op.getRhs().getDefiningOp()); |
| 96 | if (!origLhsExtOp || !origRhsExtOp) { |
| 97 | return failure(); |
| 98 | } |
| 99 | |
| 100 | // Match any iX to i32 for X<8 then turn into an i8 output. Feed into |
| 101 | // following neon instruction. Check inputs for extsi are <=i8 |
| 102 | Value extsiLhs; |
| 103 | Value extsiRhs; |
| 104 | if (auto lhsExtInType = |
| 105 | dyn_cast<mlir::VectorType>(origLhsExtOp.getIn().getType())) { |
| 106 | if (lhsExtInType.getElementTypeBitWidth() <= 8) { |
| 107 | Type targetLhsExtTy = |
| 108 | matchContainerType(rewriter.getI8Type(), lhsExtInType); |
| 109 | extsiLhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetLhsExtTy, |
| 110 | origLhsExtOp.getIn()); |
| 111 | } |
| 112 | } |
| 113 | if (auto rhsExtInType = |
| 114 | dyn_cast<mlir::VectorType>(origRhsExtOp.getIn().getType())) { |
| 115 | if (rhsExtInType.getElementTypeBitWidth() <= 8) { |
| 116 | Type targetRhsExtTy = |
| 117 | matchContainerType(rewriter.getI8Type(), rhsExtInType); |
| 118 | extsiRhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetRhsExtTy, |
| 119 | origRhsExtOp.getIn()); |
| 120 | } |
| 121 | } |
| 122 | |
| 123 | if (!extsiLhs || !extsiRhs) { |
| 124 | return failure(); |
| 125 | } |
| 126 | |
| 127 | // Initial accumulator for the final result. This is the un-tiled result if |
| 128 | // tiling is done. |
| 129 | Value result = rewriter.create<arith::ConstantOp>( |
| 130 | loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType())); |
| 131 | |
| 132 | SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll(); |
| 133 | SmallVector<int64_t> smmlaShape = {2, 8}; |
| 134 | SmallVector<int64_t> loopOrder = {0, 1}; |
| 135 | if (unrolledSize.size() == 3) { |
| 136 | smmlaShape.insert(I: smmlaShape.begin(), Elt: isVecmat ? 1 : 2); |
| 137 | loopOrder.push_back(Elt: 2); |
| 138 | } |
| 139 | |
| 140 | // Keep track of the previous accumulator when tiling over K. |
| 141 | Value kAcc; |
| 142 | for (SmallVector<int64_t> offsets : |
| 143 | StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) { |
| 144 | // Helper to compute the new shape of each operand and extract the slice. |
| 145 | auto extractOperand = [&](Value operand, AffineMap permutationMap, |
| 146 | ArrayRef<int64_t> operandOffsets) { |
| 147 | SmallVector<int64_t> operandShape = |
| 148 | applyPermutationMap(permutationMap, ArrayRef<int64_t>(smmlaShape)); |
| 149 | SmallVector<int64_t> operandStrides(operandOffsets.size(), 1); |
| 150 | return rewriter.createOrFold<vector::ExtractStridedSliceOp>( |
| 151 | loc, operand, operandOffsets, operandShape, operandStrides); |
| 152 | }; |
| 153 | |
| 154 | // Extract tiled lhs, rhs, and acc |
| 155 | AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0]; |
| 156 | SmallVector<int64_t> lhsOffsets = |
| 157 | applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets)); |
| 158 | Value tiledLhs = extractOperand(extsiLhs, lhsPermutationMap, lhsOffsets); |
| 159 | AffineMap rhsPermutationMap = op.getIndexingMapsArray()[1]; |
| 160 | SmallVector<int64_t> rhsOffsets = |
| 161 | applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets)); |
| 162 | Value tiledRhs = extractOperand(extsiRhs, rhsPermutationMap, rhsOffsets); |
| 163 | AffineMap accPermutationMap = op.getIndexingMapsArray()[2]; |
| 164 | SmallVector<int64_t> accOffsets = |
| 165 | applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets)); |
| 166 | Value tiledAcc = |
| 167 | extractOperand(op.getAcc(), accPermutationMap, accOffsets); |
| 168 | |
| 169 | auto inputElementType = |
| 170 | cast<ShapedType>(tiledLhs.getType()).getElementType(); |
| 171 | auto accElementType = |
| 172 | cast<ShapedType>(tiledAcc.getType()).getElementType(); |
| 173 | auto inputExpandedType = VectorType::get({2, 8}, inputElementType); |
| 174 | auto outputExpandedType = VectorType::get({2, 2}, accElementType); |
| 175 | |
| 176 | // With vecmat, tiled LHS and ACC will contain only one of 2 necessary |
| 177 | // rows along dimM. Expand their shapes to match the smmla op. |
| 178 | if (isVecmat) { |
| 179 | auto expandForSMMLA = [&](Value tiledOperand, |
| 180 | VectorType expandedTypeType) { |
| 181 | auto emptyOperand = rewriter.create<arith::ConstantOp>( |
| 182 | loc, expandedTypeType, rewriter.getZeroAttr(expandedTypeType)); |
| 183 | SmallVector<int64_t> offsets( |
| 184 | cast<ShapedType>(emptyOperand.getType()).getRank(), 0); |
| 185 | SmallVector<int64_t> strides( |
| 186 | cast<ShapedType>(tiledOperand.getType()).getRank(), 1); |
| 187 | return rewriter.createOrFold<vector::InsertStridedSliceOp>( |
| 188 | loc, tiledOperand, emptyOperand, offsets, strides); |
| 189 | }; |
| 190 | tiledLhs = expandForSMMLA(tiledLhs, inputExpandedType); |
| 191 | tiledAcc = expandForSMMLA(tiledAcc, outputExpandedType); |
| 192 | } |
| 193 | |
| 194 | // Collapse tiled operands to 1D vectors required by smmla intrinsic |
| 195 | auto collapsedInputType = |
| 196 | VectorType::get(inputExpandedType.getNumElements(), inputElementType); |
| 197 | auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>( |
| 198 | tiledLhs.getLoc(), collapsedInputType, tiledLhs); |
| 199 | auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>( |
| 200 | tiledRhs.getLoc(), collapsedInputType, tiledRhs); |
| 201 | auto collapsedOutputType = |
| 202 | VectorType::get(outputExpandedType.getNumElements(), accElementType); |
| 203 | |
| 204 | bool initialKAcc = offsets.back() == 0; |
| 205 | Value collapsedRes; |
| 206 | if (!initialKAcc) { |
| 207 | collapsedRes = kAcc; |
| 208 | } else { |
| 209 | collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>( |
| 210 | tiledAcc.getLoc(), collapsedOutputType, tiledAcc); |
| 211 | } |
| 212 | |
| 213 | // Insert contract op |
| 214 | kAcc = rewriter.createOrFold<arm_neon::SmmlaOp>( |
| 215 | op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs, |
| 216 | collapsedRhs); |
| 217 | |
| 218 | // Reshape output back to 2D |
| 219 | Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>( |
| 220 | kAcc.getLoc(), tiledAcc.getType(), kAcc); |
| 221 | |
| 222 | // With vecmat, only one row of tiled ACC can be inserted into file result |
| 223 | if (isVecmat) { |
| 224 | tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0); |
| 225 | } |
| 226 | |
| 227 | // Insert the tiled result back into the non tiled result of the |
| 228 | // contract op. |
| 229 | SmallVector<int64_t> strides( |
| 230 | cast<ShapedType>(tiledRes.getType()).getRank(), 1); |
| 231 | result = rewriter.createOrFold<vector::InsertStridedSliceOp>( |
| 232 | loc, tiledRes, result, accOffsets, strides); |
| 233 | } |
| 234 | |
| 235 | rewriter.replaceOp(op, result); |
| 236 | return success(); |
| 237 | } |
| 238 | }; |
| 239 | |
| 240 | } // namespace |
| 241 | |
| 242 | void mlir::arm_neon::populateLowerContractionToSMMLAPatternPatterns( |
| 243 | RewritePatternSet &patterns) { |
| 244 | MLIRContext *context = patterns.getContext(); |
| 245 | patterns.add<LowerContractionToSMMLAPattern>(arg&: context, /*benefit=*/args: 2); |
| 246 | } |
| 247 | |