1 | //===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements lowering patterns from vector.contract to |
10 | // arm_neon.intr.smmla |
11 | // |
//===----------------------------------------------------------------------===//
13 | |
14 | #include "mlir/Dialect/Arith/IR/Arith.h" |
15 | #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" |
16 | #include "mlir/Dialect/ArmNeon/Transforms.h" |
17 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
18 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
19 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
20 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
21 | #include "mlir/IR/AffineMap.h" |
22 | #include "mlir/IR/PatternMatch.h" |
23 | #include "mlir/Support/LogicalResult.h" |
24 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
25 | |
26 | #define DEBUG_TYPE "lower-contract-to-arm-neon" |
27 | |
28 | using namespace mlir; |
29 | using namespace mlir::arm_neon; |
30 | |
31 | namespace { |
32 | |
33 | /// Return the shaped type with new element type. |
34 | static Type matchContainerType(Type element, Type container) { |
35 | if (auto shapedTy = dyn_cast<ShapedType>(container)) { |
36 | return shapedTy.clone(element); |
37 | } |
38 | return element; |
39 | } |
40 | |
41 | /// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile |
42 | /// any vector.contract into multiple smmla instructions with unrolling so long |
43 | /// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM |
44 | /// = 1 (either explicitly or inferred if LHS has only dimK) If no unrolling is |
45 | /// necessary, a single smmla instruction is emitted. |
46 | class LowerContractionToSMMLAPattern |
47 | : public OpRewritePattern<vector::ContractionOp> { |
48 | public: |
49 | using OpRewritePattern::OpRewritePattern; |
50 | LogicalResult matchAndRewrite(vector::ContractionOp op, |
51 | PatternRewriter &rewriter) const override { |
52 | Location loc = op.getLoc(); |
53 | // Infer tile sizes from operands. For vecmat, LHS may only have 1 dim. |
54 | // Note: RHS is not transposed. |
55 | mlir::VectorType lhsType = op.getLhsType(); |
56 | mlir::VectorType rhsType = op.getRhsType(); |
57 | // Avoid 0-D vectors and 1-D rhs: |
58 | if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2) |
59 | return failure(); |
60 | auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0); |
61 | auto dimN = rhsType.getDimSize(0); |
62 | auto dimK = rhsType.getDimSize(1); |
63 | bool isVecmat = dimM == 1 ? true : false; |
64 | if (lhsType.getDimSize(lhsType.getRank() - 1) != |
65 | rhsType.getDimSize(rhsType.getRank() - 1)) { |
66 | return failure(); // dimK mismatch |
67 | } |
68 | // Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs for |
69 | // tiling. |
70 | if ((dimM % 2 != 0 && !isVecmat) || dimN % 2 != 0 || dimK % 8 != 0) { |
71 | return failure(); |
72 | } |
73 | |
74 | // Check iterator types for contract. All iterators except inner-most |
75 | // dimension must be parallel. |
76 | auto iteratorTypes = op.getIteratorTypesArray(); |
77 | if (iteratorTypes.size() > 3 || iteratorTypes[iteratorTypes.size() - 1] != |
78 | vector::IteratorType::reduction) { |
79 | return failure(); |
80 | } |
81 | if (llvm::any_of(ArrayRef<vector::IteratorType>(iteratorTypes).drop_back(1), |
82 | [](vector::IteratorType iteratorType) { |
83 | return iteratorType != vector::IteratorType::parallel; |
84 | })) { |
85 | return failure(); |
86 | } |
87 | |
88 | // Check two extsi inputs Rhs Lhs for contract. |
89 | arith::ExtSIOp origLhsExtOp = |
90 | dyn_cast_or_null<arith::ExtSIOp>(op.getLhs().getDefiningOp()); |
91 | arith::ExtSIOp origRhsExtOp = |
92 | dyn_cast_or_null<arith::ExtSIOp>(op.getRhs().getDefiningOp()); |
93 | if (!origLhsExtOp || !origRhsExtOp) { |
94 | return failure(); |
95 | } |
96 | |
97 | // Match any iX to i32 for X<8 then turn into an i8 output. Feed into |
98 | // following neon instruction. Check inputs for extsi are <=i8 |
99 | Value extsiLhs; |
100 | Value extsiRhs; |
101 | if (auto lhsExtInType = |
102 | dyn_cast<mlir::VectorType>(origLhsExtOp.getIn().getType())) { |
103 | if (lhsExtInType.getElementTypeBitWidth() <= 8) { |
104 | Type targetLhsExtTy = |
105 | matchContainerType(rewriter.getI8Type(), lhsExtInType); |
106 | extsiLhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetLhsExtTy, |
107 | origLhsExtOp.getIn()); |
108 | } |
109 | } |
110 | if (auto rhsExtInType = |
111 | dyn_cast<mlir::VectorType>(origRhsExtOp.getIn().getType())) { |
112 | if (rhsExtInType.getElementTypeBitWidth() <= 8) { |
113 | Type targetRhsExtTy = |
114 | matchContainerType(rewriter.getI8Type(), rhsExtInType); |
115 | extsiRhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetRhsExtTy, |
116 | origRhsExtOp.getIn()); |
117 | } |
118 | } |
119 | |
120 | if (!extsiLhs || !extsiRhs) { |
121 | return failure(); |
122 | } |
123 | |
124 | // Initial accumulator for the final result. This is the un-tiled result if |
125 | // tiling is done. |
126 | Value result = rewriter.create<arith::ConstantOp>( |
127 | loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType())); |
128 | |
129 | SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll(); |
130 | SmallVector<int64_t> smmlaShape{2, 8}; |
131 | SmallVector<int64_t> loopOrder{0, 1}; |
132 | if (unrolledSize.size() == 3) { |
133 | smmlaShape.insert(I: smmlaShape.begin(), Elt: isVecmat ? 1 : 2); |
134 | loopOrder.push_back(Elt: 2); |
135 | } |
136 | |
137 | // Keep track of the previous accumulator when tiling over K. |
138 | Value kAcc; |
139 | for (SmallVector<int64_t> offsets : |
140 | StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) { |
141 | // Helper to compute the new shape of each operand and extract the slice. |
142 | auto extractOperand = [&](Value operand, AffineMap permutationMap, |
143 | ArrayRef<int64_t> operandOffsets) { |
144 | SmallVector<int64_t> operandShape = |
145 | applyPermutationMap(permutationMap, ArrayRef<int64_t>(smmlaShape)); |
146 | SmallVector<int64_t> operandStrides(operandOffsets.size(), 1); |
147 | return rewriter.createOrFold<vector::ExtractStridedSliceOp>( |
148 | loc, operand, operandOffsets, operandShape, operandStrides); |
149 | }; |
150 | |
151 | // Extract tiled lhs, rhs, and acc |
152 | AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0]; |
153 | SmallVector<int64_t> lhsOffsets = |
154 | applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets)); |
155 | Value tiledLhs = extractOperand(extsiLhs, lhsPermutationMap, lhsOffsets); |
156 | AffineMap rhsPermutationMap = op.getIndexingMapsArray()[1]; |
157 | SmallVector<int64_t> rhsOffsets = |
158 | applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets)); |
159 | Value tiledRhs = extractOperand(extsiRhs, rhsPermutationMap, rhsOffsets); |
160 | AffineMap accPermutationMap = op.getIndexingMapsArray()[2]; |
161 | SmallVector<int64_t> accOffsets = |
162 | applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets)); |
163 | Value tiledAcc = |
164 | extractOperand(op.getAcc(), accPermutationMap, accOffsets); |
165 | |
166 | auto inputElementType = |
167 | cast<ShapedType>(tiledLhs.getType()).getElementType(); |
168 | auto accElementType = |
169 | cast<ShapedType>(tiledAcc.getType()).getElementType(); |
170 | auto inputExpandedType = VectorType::get({2, 8}, inputElementType); |
171 | auto outputExpandedType = VectorType::get({2, 2}, accElementType); |
172 | |
173 | // With vecmat, tiled LHS and ACC will contain only one of 2 necessary |
174 | // rows along dimM. Expand their shapes to match the smmla op. |
175 | if (isVecmat) { |
176 | auto expandForSMMLA = [&](Value tiledOperand, |
177 | VectorType expandedTypeType) { |
178 | auto emptyOperand = rewriter.create<arith::ConstantOp>( |
179 | loc, expandedTypeType, rewriter.getZeroAttr(expandedTypeType)); |
180 | SmallVector<int64_t> offsets( |
181 | cast<ShapedType>(emptyOperand.getType()).getRank(), 0); |
182 | SmallVector<int64_t> strides( |
183 | cast<ShapedType>(tiledOperand.getType()).getRank(), 1); |
184 | return rewriter.createOrFold<vector::InsertStridedSliceOp>( |
185 | loc, tiledOperand, emptyOperand, offsets, strides); |
186 | }; |
187 | tiledLhs = expandForSMMLA(tiledLhs, inputExpandedType); |
188 | tiledAcc = expandForSMMLA(tiledAcc, outputExpandedType); |
189 | } |
190 | |
191 | // Collapse tiled operands to 1D vectors required by smmla intrinsic |
192 | auto collapsedInputType = |
193 | VectorType::get(inputExpandedType.getNumElements(), inputElementType); |
194 | auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>( |
195 | tiledLhs.getLoc(), collapsedInputType, tiledLhs); |
196 | auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>( |
197 | tiledRhs.getLoc(), collapsedInputType, tiledRhs); |
198 | auto collapsedOutputType = |
199 | VectorType::get(outputExpandedType.getNumElements(), accElementType); |
200 | |
201 | bool initialKAcc = offsets.back() == 0; |
202 | Value collapsedRes; |
203 | if (!initialKAcc) { |
204 | collapsedRes = kAcc; |
205 | } else { |
206 | collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>( |
207 | tiledAcc.getLoc(), collapsedOutputType, tiledAcc); |
208 | } |
209 | |
210 | // Insert contract op |
211 | kAcc = rewriter.createOrFold<arm_neon::SmmlaOp>( |
212 | op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs, |
213 | collapsedRhs); |
214 | |
215 | // Reshape output back to 2D |
216 | Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>( |
217 | kAcc.getLoc(), tiledAcc.getType(), kAcc); |
218 | |
219 | // With vecmat, only one row of tiled ACC can be inserted into file result |
220 | if (isVecmat) { |
221 | tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0); |
222 | } |
223 | |
224 | // Insert the tiled result back into the non tiled result of the |
225 | // contract op. |
226 | SmallVector<int64_t> strides( |
227 | cast<ShapedType>(tiledRes.getType()).getRank(), 1); |
228 | result = rewriter.createOrFold<vector::InsertStridedSliceOp>( |
229 | loc, tiledRes, result, accOffsets, strides); |
230 | } |
231 | |
232 | rewriter.replaceOp(op, result); |
233 | return success(); |
234 | } |
235 | }; |
236 | |
237 | } // namespace |
238 | |
239 | void mlir::arm_neon::populateLowerContractionToSMMLAPatternPatterns( |
240 | RewritePatternSet &patterns) { |
241 | MLIRContext *context = patterns.getContext(); |
242 | patterns.add<LowerContractionToSMMLAPattern>(arg&: context, /*benefit=*/args: 1); |
243 | } |
244 | |