//===- LowerContractionToNeonI8MMPattern.cpp - Contract to I8MM -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering patterns from vector.contract to operations
// that map to instructions from the Neon FEAT_I8MM extension.
//
// TODO: There may be opportunities to unify this with a similar pattern
// for SVE. See:
//   https://github.com/llvm/llvm-project/issues/145559
//   LowerContractionToSVEI8MMPattern.cpp
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/ArmNeon/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/PatternMatch.h"

#define DEBUG_TYPE "lower-contract-to-arm-neon"

using namespace mlir;
using namespace mlir::arm_neon;

namespace {

35/// Return the shaped type with new element type.
36static Type matchContainerType(Type element, Type container) {
37 if (auto shapedTy = dyn_cast<ShapedType>(Val&: container)) {
38 return shapedTy.clone(elementType: element);
39 }
40 return element;
41}
42
43// Get the operand of a `vector.contract`. This function is intended to abstract
44// away from the particular way a value is extended before feeding it into the
45// `vector.contract` - via zero-extend or an explicit or implicit sign-extend
46// (for implicit sign-extension see `vector.contract` documentation).
47//
48// The template parameter `Op` indicates the extension operation (explicit or
49// implicit) for which we are checking.
50//
51// Return success only for extensions from `iN` (N <= 8) to `i32`.
52template <typename Op>
53std::optional<Value> getExtOperand(Value v) {
54
55 static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
56 "Must be instantiated with either sign- or zero- extension op");
57
58 // If the operand is not defined by an explicit extend operation of the
59 // accepted operation type allow for an implicit sign-extension.
60 auto extOp = dyn_cast_or_null<Op>(v.getDefiningOp());
61 if (!extOp) {
62 if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
63 auto eltTy = cast<VectorType>(Val: v.getType()).getElementType();
64 if (!eltTy.isSignlessInteger() || eltTy.getIntOrFloatBitWidth() > 8)
65 return {};
66 return v;
67 }
68 return {};
69 }
70
71 // If the operand is defined by an explicit extend operation of the accepted
72 // operation type, check it's extended from `iN` (N <= 8) to `i32`.
73 auto inOp = extOp.getIn();
74 auto inTy = dyn_cast<VectorType>(inOp.getType());
75 if (!inTy)
76 return {};
77 auto inEltTy = inTy.getElementType();
78 if (!inEltTy.isSignlessInteger() || inEltTy.getIntOrFloatBitWidth() > 8)
79 return {};
80
81 auto outTy = dyn_cast<VectorType>(extOp.getType());
82 if (!(outTy && outTy.getElementType().isSignlessInteger(32)))
83 return {};
84
85 return inOp;
86}
87
// Designate the operation (resp. instruction) used to do sub-tile matrix
// multiplications.
enum class MMLA {
  Signed,      // smmla
  Unsigned,    // ummla
  Mixed,       // usmmla
  MixedSwapped // usmmla with LHS and RHS swapped
};

97// Create the matrix mulitply and accumulate operation according to `op`.
98Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
99 mlir::Type accType, Value acc, Value lhs, Value rhs) {
100 switch (op) {
101 case MMLA::Signed:
102 return rewriter.createOrFold<arm_neon::SmmlaOp>(location: loc, args&: accType, args&: acc, args&: lhs,
103 args&: rhs);
104 case MMLA::Unsigned:
105 return rewriter.createOrFold<arm_neon::UmmlaOp>(location: loc, args&: accType, args&: acc, args&: lhs,
106 args&: rhs);
107 case MMLA::Mixed:
108 return rewriter.createOrFold<arm_neon::UsmmlaOp>(location: loc, args&: accType, args&: acc, args&: lhs,
109 args&: rhs);
110 case MMLA::MixedSwapped:
111 // The accumulator comes transposed and the result will be transposed
112 // later, so all we have to do here is swap the operands.
113 return rewriter.createOrFold<arm_neon::UsmmlaOp>(location: loc, args&: accType, args&: acc, args&: rhs,
114 args&: lhs);
115 }
116}
117
118/// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile
119/// any vector.contract into multiple smmla instructions with unrolling so long
120/// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
121/// = 1 (either explicitly or inferred if LHS has only dimK) If no unrolling is
122/// necessary, a single smmla instruction is emitted.
123class LowerContractionToNeonI8MMPattern
124 : public OpRewritePattern<vector::ContractionOp> {
125public:
126 using OpRewritePattern::OpRewritePattern;
127 LogicalResult matchAndRewrite(vector::ContractionOp op,
128 PatternRewriter &rewriter) const override {
129 Location loc = op.getLoc();
130 // Infer tile sizes from operands. For vecmat, LHS may only have 1 dim.
131 // Note: RHS is not transposed.
132 mlir::VectorType lhsType = op.getLhsType();
133 mlir::VectorType rhsType = op.getRhsType();
134 // Avoid 0-D vectors and 1-D rhs:
135 if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2)
136 return failure();
137 // This codegen does not work for scalable vectors. Return failure so this
138 // pattern is not accidentally chosen over patterns that lower to ArmSVE.
139 if (lhsType.isScalable() || rhsType.isScalable())
140 return failure();
141 auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(idx: 0);
142 auto dimN = rhsType.getDimSize(idx: 0);
143 auto dimK = rhsType.getDimSize(idx: 1);
144 bool isVecmat = dimM == 1 ? true : false;
145 if (lhsType.getDimSize(idx: lhsType.getRank() - 1) !=
146 rhsType.getDimSize(idx: rhsType.getRank() - 1)) {
147 return failure(); // dimK mismatch
148 }
149 // Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs for
150 // tiling.
151 if ((dimM % 2 != 0 && !isVecmat) || dimN % 2 != 0 || dimK % 8 != 0) {
152 return failure();
153 }
154
155 // Check iterator types for contract. All iterators except inner-most
156 // dimension must be parallel.
157 auto iteratorTypes = op.getIteratorTypesArray();
158 if (iteratorTypes.size() > 3 || iteratorTypes[iteratorTypes.size() - 1] !=
159 vector::IteratorType::reduction) {
160 return failure();
161 }
162 if (llvm::any_of(Range: ArrayRef<vector::IteratorType>(iteratorTypes).drop_back(N: 1),
163 P: [](vector::IteratorType iteratorType) {
164 return iteratorType != vector::IteratorType::parallel;
165 })) {
166 return failure();
167 }
168
169 // Check inputs are sign-/zero- extensions from iN (N <= 8) to i32. Get the
170 // values before the extension. All four signed/unsigned combinations for
171 // input operands are supported, but they are lowered to different
172 // operations. Determine which is the appropriate operation to lower to.
173 MMLA mmlaOp = MMLA::Signed;
174 auto maybeLhs = getExtOperand<arith::ExtSIOp>(v: op.getLhs());
175 if (!maybeLhs) {
176 mmlaOp = MMLA::Unsigned;
177 maybeLhs = getExtOperand<arith::ExtUIOp>(v: op.getLhs());
178 }
179 if (!maybeLhs)
180 return failure();
181
182 auto maybeRhs = getExtOperand<arith::ExtSIOp>(v: op.getRhs());
183 if (maybeRhs) {
184 if (mmlaOp == MMLA::Unsigned)
185 mmlaOp = MMLA::Mixed;
186 } else {
187 if (mmlaOp == MMLA::Signed)
188 mmlaOp = MMLA::MixedSwapped;
189 maybeRhs = getExtOperand<arith::ExtUIOp>(v: op.getRhs());
190 }
191 if (!maybeRhs)
192 return failure();
193
194 Value origLhs = *maybeLhs;
195 Value origRhs = *maybeRhs;
196
197 // Match any iX to i32 for X<8 then turn into an i8 output. Feed into
198 // following neon instruction. Check inputs for extsi are <=i8
199 Value extLhs;
200 Value extRhs;
201 if (auto lhsExtInType = dyn_cast<mlir::VectorType>(Val: origLhs.getType())) {
202 if (lhsExtInType.getElementTypeBitWidth() <= 8) {
203 Type targetLhsExtTy =
204 matchContainerType(element: rewriter.getI8Type(), container: lhsExtInType);
205 if (mmlaOp == MMLA::Signed || mmlaOp == MMLA::Mixed)
206 extLhs = rewriter.createOrFold<arith::ExtSIOp>(location: loc, args&: targetLhsExtTy,
207 args&: origLhs);
208 else
209 extLhs = rewriter.createOrFold<arith::ExtUIOp>(location: loc, args&: targetLhsExtTy,
210 args&: origLhs);
211 }
212 }
213 if (auto rhsExtInType = dyn_cast<mlir::VectorType>(Val: origRhs.getType())) {
214 if (rhsExtInType.getElementTypeBitWidth() <= 8) {
215 Type targetRhsExtTy =
216 matchContainerType(element: rewriter.getI8Type(), container: rhsExtInType);
217 if (mmlaOp == MMLA::Unsigned || mmlaOp == MMLA::Mixed)
218 extRhs = rewriter.createOrFold<arith::ExtUIOp>(location: loc, args&: targetRhsExtTy,
219 args&: origRhs);
220 else
221 extRhs = rewriter.createOrFold<arith::ExtSIOp>(location: loc, args&: targetRhsExtTy,
222 args&: origRhs);
223 }
224 }
225
226 if (!extLhs || !extRhs) {
227 return failure();
228 }
229
230 // Initial accumulator for the final result. This is the un-tiled result if
231 // tiling is done.
232 Value result = rewriter.create<arith::ConstantOp>(
233 location: loc, args: op.getResultType(), args: rewriter.getZeroAttr(type: op.getResultType()));
234
235 SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll();
236 SmallVector<int64_t> smmlaShape = {2, 8};
237 SmallVector<int64_t> loopOrder = {0, 1};
238 if (unrolledSize.size() == 3) {
239 smmlaShape.insert(I: smmlaShape.begin(), Elt: isVecmat ? 1 : 2);
240 loopOrder.push_back(Elt: 2);
241 }
242
243 // Keep track of the previous accumulator when tiling over K.
244 Value kAcc;
245 for (SmallVector<int64_t> offsets :
246 StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
247 // Helper to compute the new shape of each operand and extract the slice.
248 auto extractOperand = [&](Value operand, AffineMap permutationMap,
249 ArrayRef<int64_t> operandOffsets) {
250 SmallVector<int64_t> operandShape =
251 applyPermutationMap(map: permutationMap, source: ArrayRef<int64_t>(smmlaShape));
252 SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
253 return rewriter.createOrFold<vector::ExtractStridedSliceOp>(
254 location: loc, args&: operand, args&: operandOffsets, args&: operandShape, args&: operandStrides);
255 };
256
257 // Extract tiled lhs, rhs, and acc
258 AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0];
259 SmallVector<int64_t> lhsOffsets =
260 applyPermutationMap(map: lhsPermutationMap, source: ArrayRef<int64_t>(offsets));
261 Value tiledLhs = extractOperand(extLhs, lhsPermutationMap, lhsOffsets);
262 AffineMap rhsPermutationMap = op.getIndexingMapsArray()[1];
263 SmallVector<int64_t> rhsOffsets =
264 applyPermutationMap(map: rhsPermutationMap, source: ArrayRef<int64_t>(offsets));
265 Value tiledRhs = extractOperand(extRhs, rhsPermutationMap, rhsOffsets);
266 AffineMap accPermutationMap = op.getIndexingMapsArray()[2];
267 SmallVector<int64_t> accOffsets =
268 applyPermutationMap(map: accPermutationMap, source: ArrayRef<int64_t>(offsets));
269 Value tiledAcc =
270 extractOperand(op.getAcc(), accPermutationMap, accOffsets);
271
272 auto inputElementType =
273 cast<ShapedType>(Val: tiledLhs.getType()).getElementType();
274 auto accElementType =
275 cast<ShapedType>(Val: tiledAcc.getType()).getElementType();
276 auto inputExpandedType = VectorType::get(shape: {2, 8}, elementType: inputElementType);
277 auto outputExpandedType = VectorType::get(shape: {2, 2}, elementType: accElementType);
278
279 // With vecmat, tiled LHS and ACC will contain only one of 2 necessary
280 // rows along dimM. Expand their shapes to match the smmla op.
281 if (isVecmat) {
282 auto expandForSMMLA = [&](Value tiledOperand,
283 VectorType expandedTypeType) {
284 auto emptyOperand = rewriter.create<arith::ConstantOp>(
285 location: loc, args&: expandedTypeType, args: rewriter.getZeroAttr(type: expandedTypeType));
286 SmallVector<int64_t> offsets(
287 cast<ShapedType>(Val: emptyOperand.getType()).getRank(), 0);
288 SmallVector<int64_t> strides(
289 cast<ShapedType>(Val: tiledOperand.getType()).getRank(), 1);
290 return rewriter.createOrFold<vector::InsertStridedSliceOp>(
291 location: loc, args&: tiledOperand, args&: emptyOperand, args&: offsets, args&: strides);
292 };
293 tiledLhs = expandForSMMLA(tiledLhs, inputExpandedType);
294 tiledAcc = expandForSMMLA(tiledAcc, outputExpandedType);
295 }
296
297 // Transpose ACC if doing signed by unsigned multiplication, because we're
298 // using the instruction for unsigned by signed multiplication with
299 // reversed operands.
300 if (mmlaOp == MMLA::MixedSwapped)
301 tiledAcc = rewriter.create<vector::TransposeOp>(
302 location: loc, args&: tiledAcc, args: ArrayRef<int64_t>({1, 0}));
303
304 // Collapse tiled operands to 1D vectors required by smmla intrinsic
305 auto collapsedInputType =
306 VectorType::get(shape: inputExpandedType.getNumElements(), elementType: inputElementType);
307 auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
308 location: tiledLhs.getLoc(), args&: collapsedInputType, args&: tiledLhs);
309 auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
310 location: tiledRhs.getLoc(), args&: collapsedInputType, args&: tiledRhs);
311 auto collapsedOutputType =
312 VectorType::get(shape: outputExpandedType.getNumElements(), elementType: accElementType);
313
314 bool initialKAcc = offsets.back() == 0;
315 Value collapsedRes;
316 if (!initialKAcc) {
317 collapsedRes = kAcc;
318 } else {
319 collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
320 location: tiledAcc.getLoc(), args&: collapsedOutputType, args&: tiledAcc);
321 }
322
323 // Insert contract op
324 kAcc = createMMLA(rewriter, op: mmlaOp, loc: op.getLoc(), accType: collapsedRes.getType(),
325 acc: collapsedRes, lhs: collapsedLhs, rhs: collapsedRhs);
326
327 // Reshape output back to 2D
328 Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
329 location: kAcc.getLoc(), args: tiledAcc.getType(), args&: kAcc);
330
331 // Because of the reversed operands the result is obtained transposed.
332 // Transpose it back,
333 if (mmlaOp == MMLA::MixedSwapped)
334 tiledRes = rewriter.create<vector::TransposeOp>(
335 location: loc, args&: tiledRes, args: ArrayRef<int64_t>({1, 0}));
336
337 // With vecmat, only one row of tiled ACC can be inserted into the final
338 // result
339 if (isVecmat) {
340 tiledRes = rewriter.createOrFold<vector::ExtractOp>(location: loc, args&: tiledRes, args: 0);
341 }
342
343 // Insert the tiled result back into the non tiled result of the
344 // contract op.
345 SmallVector<int64_t> strides(
346 cast<ShapedType>(Val: tiledRes.getType()).getRank(), 1);
347 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
348 location: loc, args&: tiledRes, args&: result, args&: accOffsets, args&: strides);
349 }
350
351 rewriter.replaceOp(op, newValues: result);
352 return success();
353 }
354};

} // namespace

358void mlir::arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(
359 RewritePatternSet &patterns) {
360 MLIRContext *context = patterns.getContext();
361 patterns.add<LowerContractionToNeonI8MMPattern>(arg&: context, /*benefit=*/args: 2);
362}

// Source: mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp