//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering patterns from vector.contract to
// arm_neon.intr.smmla
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/ArmNeon/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "lower-contract-to-arm-neon"

using namespace mlir;
using namespace mlir::arm_neon;

namespace {

/// Return the container type with its element type replaced by `element`, or
/// `element` itself if the container is not a shaped type.
static Type matchContainerType(Type element, Type container) {
  if (auto shapedTy = dyn_cast<ShapedType>(container)) {
    return shapedTy.clone(element);
  }
  return element;
}

/// Lowering from a vector.contract op to the arm neon smmla intrinsic. This
/// tiles any vector.contract into multiple smmla instructions with unrolling
/// so long as its shape is a multiple of [2, 2, 8]. It can also process
/// vecmats with dimM = 1 (either explicit, or inferred when the LHS has only
/// dimK). If no unrolling is necessary, a single smmla instruction is emitted.
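///
/// As a sketch (illustrative only: types are abbreviated, the contract's
/// attributes are elided, and the value names are hypothetical), a
/// single-tile contract such as
///
///   %res = vector.contract ... %lhs, %rhs, %acc
///            : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
///
/// where %lhs and %rhs are arith.extsi of i8 (or narrower) inputs %lhs8 and
/// %rhs8, lowers to roughly:
///
///   %l = vector.shape_cast %lhs8 : vector<2x8xi8> to vector<16xi8>
///   %r = vector.shape_cast %rhs8 : vector<2x8xi8> to vector<16xi8>
///   %a = vector.shape_cast %acc : vector<2x2xi32> to vector<4xi32>
///   %s = arm_neon.intr.smmla %a, %l, %r : vector<16xi8> to vector<4xi32>
///   %res = vector.shape_cast %s : vector<4xi32> to vector<2x2xi32>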
class LowerContractionToSMMLAPattern
    : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // Infer tile sizes from operands. For vecmat, LHS may only have 1 dim.
    // Note: RHS is not transposed.
    mlir::VectorType lhsType = op.getLhsType();
    mlir::VectorType rhsType = op.getRhsType();
    // Avoid 0-D vectors and 1-D rhs:
    if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2)
      return failure();
    // This codegen does not work for scalable vectors. Return failure so this
    // pattern is not accidentally chosen over patterns that lower to ArmSVE.
    if (lhsType.isScalable() || rhsType.isScalable())
      return failure();
    auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
    auto dimN = rhsType.getDimSize(0);
    auto dimK = rhsType.getDimSize(1);
    bool isVecmat = dimM == 1;
    if (lhsType.getDimSize(lhsType.getRank() - 1) !=
        rhsType.getDimSize(rhsType.getRank() - 1)) {
      return failure(); // dimK mismatch
    }
    // Unrolling can handle any input whose shape is a multiple of the
    // [2, 2, 8] tile (with dimM == 1 allowed for vecmats).
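    // e.g. a (dimM, dimN, dimK) = (4, 4, 16) contract unrolls into
    // (4/2) * (4/2) * (16/8) = 8 smmla ops.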
    if ((dimM % 2 != 0 && !isVecmat) || dimN % 2 != 0 || dimK % 8 != 0) {
      return failure();
    }

    // Check iterator types for contract. All iterators except the inner-most
    // dimension must be parallel.
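    // e.g. a plain matmul contract has iterators
    // [parallel, parallel, reduction].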
    auto iteratorTypes = op.getIteratorTypesArray();
    if (iteratorTypes.size() > 3 || iteratorTypes[iteratorTypes.size() - 1] !=
                                        vector::IteratorType::reduction) {
      return failure();
    }
    if (llvm::any_of(ArrayRef<vector::IteratorType>(iteratorTypes).drop_back(1),
                     [](vector::IteratorType iteratorType) {
                       return iteratorType != vector::IteratorType::parallel;
                     })) {
      return failure();
    }

    // Check that both the LHS and RHS of the contract are produced by
    // arith.extsi ops.
    arith::ExtSIOp origLhsExtOp =
        dyn_cast_or_null<arith::ExtSIOp>(op.getLhs().getDefiningOp());
    arith::ExtSIOp origRhsExtOp =
        dyn_cast_or_null<arith::ExtSIOp>(op.getRhs().getDefiningOp());
    if (!origLhsExtOp || !origRhsExtOp) {
      return failure();
    }

    // The contract's operands are sign-extended from iX (X <= 8) to i32.
    // Re-extend the original inputs to i8, the element type the following
    // neon instruction consumes.
    Value extsiLhs;
    Value extsiRhs;
    if (auto lhsExtInType =
            dyn_cast<mlir::VectorType>(origLhsExtOp.getIn().getType())) {
      if (lhsExtInType.getElementTypeBitWidth() <= 8) {
        Type targetLhsExtTy =
            matchContainerType(rewriter.getI8Type(), lhsExtInType);
        extsiLhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetLhsExtTy,
                                                         origLhsExtOp.getIn());
      }
    }
    if (auto rhsExtInType =
            dyn_cast<mlir::VectorType>(origRhsExtOp.getIn().getType())) {
      if (rhsExtInType.getElementTypeBitWidth() <= 8) {
        Type targetRhsExtTy =
            matchContainerType(rewriter.getI8Type(), rhsExtInType);
        extsiRhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetRhsExtTy,
                                                         origRhsExtOp.getIn());
      }
    }

    if (!extsiLhs || !extsiRhs) {
      return failure();
    }

    // Initial accumulator for the final result. This is the un-tiled result
    // that tiled partial results are inserted into.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType()));

    SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll();
    SmallVector<int64_t> smmlaShape = {2, 8};
    SmallVector<int64_t> loopOrder = {0, 1};
    if (unrolledSize.size() == 3) {
      smmlaShape.insert(smmlaShape.begin(), isVecmat ? 1 : 2);
      loopOrder.push_back(2);
    }
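    // e.g. for a rank-3 (matmul) iteration space this leaves
    // smmlaShape = {2, 2, 8} and loopOrder = {0, 1, 2}; a vecmat uses
    // smmlaShape = {1, 2, 8} instead.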

    // Keep track of the previous accumulator when tiling over K.
    Value kAcc;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
      // Helper to compute the new shape of each operand and extract the slice.
      auto extractOperand = [&](Value operand, AffineMap permutationMap,
                                ArrayRef<int64_t> operandOffsets) {
        SmallVector<int64_t> operandShape =
            applyPermutationMap(permutationMap, ArrayRef<int64_t>(smmlaShape));
        SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
        return rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            loc, operand, operandOffsets, operandShape, operandStrides);
      };

      // Extract tiled lhs, rhs, and acc
      AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0];
      SmallVector<int64_t> lhsOffsets =
          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
      Value tiledLhs = extractOperand(extsiLhs, lhsPermutationMap, lhsOffsets);
      AffineMap rhsPermutationMap = op.getIndexingMapsArray()[1];
      SmallVector<int64_t> rhsOffsets =
          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
      Value tiledRhs = extractOperand(extsiRhs, rhsPermutationMap, rhsOffsets);
      AffineMap accPermutationMap = op.getIndexingMapsArray()[2];
      SmallVector<int64_t> accOffsets =
          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
      Value tiledAcc =
          extractOperand(op.getAcc(), accPermutationMap, accOffsets);

      auto inputElementType =
          cast<ShapedType>(tiledLhs.getType()).getElementType();
      auto accElementType =
          cast<ShapedType>(tiledAcc.getType()).getElementType();
      auto inputExpandedType = VectorType::get({2, 8}, inputElementType);
      auto outputExpandedType = VectorType::get({2, 2}, accElementType);

      // With vecmat, the tiled LHS and ACC contain only one of the two rows
      // the smmla op needs along dimM. Expand their shapes to match.
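      // e.g. a vector<1x8xi8> LHS tile is inserted into a zeroed
      // vector<2x8xi8> so the intrinsic still sees a full 2x8 tile.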
      if (isVecmat) {
        auto expandForSMMLA = [&](Value tiledOperand,
                                  VectorType expandedType) {
          auto emptyOperand = rewriter.create<arith::ConstantOp>(
              loc, expandedType, rewriter.getZeroAttr(expandedType));
          SmallVector<int64_t> offsets(
              cast<ShapedType>(emptyOperand.getType()).getRank(), 0);
          SmallVector<int64_t> strides(
              cast<ShapedType>(tiledOperand.getType()).getRank(), 1);
          return rewriter.createOrFold<vector::InsertStridedSliceOp>(
              loc, tiledOperand, emptyOperand, offsets, strides);
        };
        tiledLhs = expandForSMMLA(tiledLhs, inputExpandedType);
        tiledAcc = expandForSMMLA(tiledAcc, outputExpandedType);
      }

      // Collapse tiled operands into the 1-D vectors required by the smmla
      // intrinsic.
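      // e.g. vector<2x8xi8> collapses to vector<16xi8>, and the
      // vector<2x2xi32> accumulator collapses to vector<4xi32>.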
      auto collapsedInputType =
          VectorType::get(inputExpandedType.getNumElements(), inputElementType);
      auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
          tiledLhs.getLoc(), collapsedInputType, tiledLhs);
      auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
          tiledRhs.getLoc(), collapsedInputType, tiledRhs);
      auto collapsedOutputType =
          VectorType::get(outputExpandedType.getNumElements(), accElementType);

      bool initialKAcc = offsets.back() == 0;
      Value collapsedRes;
      if (!initialKAcc) {
        collapsedRes = kAcc;
      } else {
        collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
            tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
      }

      // Emit the smmla intrinsic for this tile.
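      // SMMLA multiplies a 2x8 tile of i8 by the transpose of another 2x8
      // tile of i8 (hence the non-transposed RHS) and accumulates into a 2x2
      // tile of i32, with all three tiles flattened to 1-D vectors.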
      kAcc = rewriter.createOrFold<arm_neon::SmmlaOp>(
          op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
          collapsedRhs);

      // Reshape output back to 2D
      Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
          kAcc.getLoc(), tiledAcc.getType(), kAcc);

      // With vecmat, only one row of the tiled ACC is inserted into the final
      // result.
      if (isVecmat) {
        tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
      }

      // Insert the tiled result back into the un-tiled result of the
      // contract op.
      SmallVector<int64_t> strides(
          cast<ShapedType>(tiledRes.getType()).getRank(), 1);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, tiledRes, result, accOffsets, strides);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};

} // namespace

void mlir::arm_neon::populateLowerContractionToSMMLAPatternPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/2);
}