//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering patterns from vector.contract to
// arm_neon.intr.smmla
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/ArmNeon/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "lower-contract-to-arm-neon"

using namespace mlir;
using namespace mlir::arm_neon;

namespace {

/// Return the shaped type with new element type.
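/// For example, (element = i8, container = vector<2x8xi32>) yields
/// vector<2x8xi8>, while a non-shaped container yields the element type
/// itself.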
static Type matchContainerType(Type element, Type container) {
  if (auto shapedTy = dyn_cast<ShapedType>(container)) {
    return shapedTy.clone(element);
  }
  return element;
}

/// Lowering from a vector::contractOp to the arm neon smmla intrinsic. This
/// will tile any vector.contract into multiple smmla instructions with
/// unrolling, as long as [2, 2, 8] divides its shape. It can also process
/// vecmats with dimM = 1 (either explicit, or inferred when the LHS has only
/// dimK). If no unrolling is necessary, a single smmla instruction is emitted.
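///
/// As an illustrative sketch (indexing maps and attributes elided), a single
/// 2x2x8 tile such as
///
///   %lhs = arith.extsi %lhs8 : vector<2x8xi8> to vector<2x8xi32>
///   %rhs = arith.extsi %rhs8 : vector<2x8xi8> to vector<2x8xi32>
///   %res = vector.contract ... %lhs, %rhs, %acc : ... into vector<2x2xi32>
///
/// lowers roughly to
///
///   %l = vector.shape_cast %lhs8 : vector<2x8xi8> to vector<16xi8>
///   %r = vector.shape_cast %rhs8 : vector<2x8xi8> to vector<16xi8>
///   %a = vector.shape_cast %acc : vector<2x2xi32> to vector<4xi32>
///   %s = arm_neon.intr.smmla %a, %l, %r : vector<16xi8> to vector<4xi32>
///   %res = vector.shape_cast %s : vector<4xi32> to vector<2x2xi32>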
class LowerContractionToSMMLAPattern
    : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // Infer tile sizes from operands. For vecmat, LHS may only have 1 dim.
    // Note: RHS is not transposed.
    mlir::VectorType lhsType = op.getLhsType();
    mlir::VectorType rhsType = op.getRhsType();
    // Avoid 0-D vectors and 1-D rhs:
    if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2)
      return failure();
    auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
    auto dimN = rhsType.getDimSize(0);
    auto dimK = rhsType.getDimSize(1);
    bool isVecmat = dimM == 1;
    if (lhsType.getDimSize(lhsType.getRank() - 1) !=
        rhsType.getDimSize(rhsType.getRank() - 1)) {
      return failure(); // dimK mismatch
    }
    // The unrolling can tile any input whose shape is a multiple of
    // [2, 2, 8] ([1, 2, 8] for vecmat).
    if ((dimM % 2 != 0 && !isVecmat) || dimN % 2 != 0 || dimK % 8 != 0) {
      return failure();
    }
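    // For example, a [4, 6, 16] contraction is covered by 2 * 3 * 2 = 12
    // smmla tiles.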

    // Check iterator types for contract. All iterators except inner-most
    // dimension must be parallel.
    auto iteratorTypes = op.getIteratorTypesArray();
    if (iteratorTypes.size() > 3 || iteratorTypes[iteratorTypes.size() - 1] !=
                                        vector::IteratorType::reduction) {
      return failure();
    }
    if (llvm::any_of(ArrayRef<vector::IteratorType>(iteratorTypes).drop_back(1),
                     [](vector::IteratorType iteratorType) {
                       return iteratorType != vector::IteratorType::parallel;
                     })) {
      return failure();
    }
    // Check that both contract inputs, LHS and RHS, are produced by
    // sign-extension (extsi) ops.
    arith::ExtSIOp origLhsExtOp =
        dyn_cast_or_null<arith::ExtSIOp>(op.getLhs().getDefiningOp());
    arith::ExtSIOp origRhsExtOp =
        dyn_cast_or_null<arith::ExtSIOp>(op.getRhs().getDefiningOp());
    if (!origLhsExtOp || !origRhsExtOp) {
      return failure();
    }
    // Match extsi ops whose inputs are at most 8 bits wide and rewrite them
    // to extend to i8 instead, so their results can feed the neon instruction
    // emitted below.
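    // For example, an extsi from vector<2x8xi4> to vector<2x8xi32> is
    // replaced by an extsi from vector<2x8xi4> to vector<2x8xi8>.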
    Value extsiLhs;
    Value extsiRhs;
    if (auto lhsExtInType =
            dyn_cast<mlir::VectorType>(origLhsExtOp.getIn().getType())) {
      if (lhsExtInType.getElementTypeBitWidth() <= 8) {
        Type targetLhsExtTy =
            matchContainerType(rewriter.getI8Type(), lhsExtInType);
        extsiLhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetLhsExtTy,
                                                         origLhsExtOp.getIn());
      }
    }
    if (auto rhsExtInType =
            dyn_cast<mlir::VectorType>(origRhsExtOp.getIn().getType())) {
      if (rhsExtInType.getElementTypeBitWidth() <= 8) {
        Type targetRhsExtTy =
            matchContainerType(rewriter.getI8Type(), rhsExtInType);
        extsiRhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetRhsExtTy,
                                                         origRhsExtOp.getIn());
      }
    }

    if (!extsiLhs || !extsiRhs) {
      return failure();
    }

    // Initial accumulator for the final result. This is the full-size,
    // un-tiled result that each tiled partial result is inserted into.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType()));

    SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll();
    SmallVector<int64_t> smmlaShape{2, 8};
    SmallVector<int64_t> loopOrder{0, 1};
    if (unrolledSize.size() == 3) {
      smmlaShape.insert(smmlaShape.begin(), isVecmat ? 1 : 2);
      loopOrder.push_back(2);
    }

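    // StaticTileOffsetRange below visits the K dimension innermost (it is
    // last in loopOrder), so all K-tiles of a given [M, N] tile appear
    // consecutively and can chain one accumulator.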
    // Keep track of the previous accumulator when tiling over K.
    Value kAcc;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
      // Helper to compute the new shape of each operand and extract the slice.
      auto extractOperand = [&](Value operand, AffineMap permutationMap,
                                ArrayRef<int64_t> operandOffsets) {
        SmallVector<int64_t> operandShape =
            applyPermutationMap(permutationMap, ArrayRef<int64_t>(smmlaShape));
        SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
        return rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            loc, operand, operandOffsets, operandShape, operandStrides);
      };

      // Extract tiled lhs, rhs, and acc.
      AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0];
      SmallVector<int64_t> lhsOffsets =
          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
      Value tiledLhs = extractOperand(extsiLhs, lhsPermutationMap, lhsOffsets);
      AffineMap rhsPermutationMap = op.getIndexingMapsArray()[1];
      SmallVector<int64_t> rhsOffsets =
          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
      Value tiledRhs = extractOperand(extsiRhs, rhsPermutationMap, rhsOffsets);
      AffineMap accPermutationMap = op.getIndexingMapsArray()[2];
      SmallVector<int64_t> accOffsets =
          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
      Value tiledAcc =
          extractOperand(op.getAcc(), accPermutationMap, accOffsets);

      auto inputElementType =
          cast<ShapedType>(tiledLhs.getType()).getElementType();
      auto accElementType =
          cast<ShapedType>(tiledAcc.getType()).getElementType();
      auto inputExpandedType = VectorType::get({2, 8}, inputElementType);
      auto outputExpandedType = VectorType::get({2, 2}, accElementType);

      // With vecmat, tiled LHS and ACC will contain only one of the 2 rows
      // the smmla op expects along dimM. Expand their shapes to match the
      // smmla op by inserting the single row into a zero vector; the zero row
      // contributes nothing to the accumulation and is dropped afterwards.
      if (isVecmat) {
        auto expandForSMMLA = [&](Value tiledOperand,
                                  VectorType expandedType) {
          auto emptyOperand = rewriter.create<arith::ConstantOp>(
              loc, expandedType, rewriter.getZeroAttr(expandedType));
          SmallVector<int64_t> offsets(
              cast<ShapedType>(emptyOperand.getType()).getRank(), 0);
          SmallVector<int64_t> strides(
              cast<ShapedType>(tiledOperand.getType()).getRank(), 1);
          return rewriter.createOrFold<vector::InsertStridedSliceOp>(
              loc, tiledOperand, emptyOperand, offsets, strides);
        };
        tiledLhs = expandForSMMLA(tiledLhs, inputExpandedType);
        tiledAcc = expandForSMMLA(tiledAcc, outputExpandedType);
      }

      // Collapse tiled operands to 1D vectors required by the smmla intrinsic
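      // (e.g. vector<2x8xi8> -> vector<16xi8> for the inputs and
      // vector<2x2xi32> -> vector<4xi32> for the accumulator).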
      auto collapsedInputType =
          VectorType::get(inputExpandedType.getNumElements(), inputElementType);
      auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
          tiledLhs.getLoc(), collapsedInputType, tiledLhs);
      auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
          tiledRhs.getLoc(), collapsedInputType, tiledRhs);
      auto collapsedOutputType =
          VectorType::get(outputExpandedType.getNumElements(), accElementType);

      bool initialKAcc = offsets.back() == 0;
      Value collapsedRes;
      if (!initialKAcc) {
        collapsedRes = kAcc;
      } else {
        collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
            tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
      }

      // Emit the smmla op for this tile, accumulating into collapsedRes.
      kAcc = rewriter.createOrFold<arm_neon::SmmlaOp>(
          op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
          collapsedRhs);

      // Reshape output back to 2D
      Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
          kAcc.getLoc(), tiledAcc.getType(), kAcc);

      // With vecmat, only one row of the tiled ACC can be inserted into the
      // final result.
      if (isVecmat) {
        tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
      }

      // Insert the tiled result back into the non-tiled result of the
      // contract op.
      SmallVector<int64_t> strides(
          cast<ShapedType>(tiledRes.getType()).getRank(), 1);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, tiledRes, result, accOffsets, strides);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};

} // namespace

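/// A minimal usage sketch (hypothetical pass body; `funcOp` is assumed),
/// driving this pattern with the greedy rewriter already included above:
///
///   RewritePatternSet patterns(funcOp.getContext());
///   arm_neon::populateLowerContractionToSMMLAPatternPatterns(patterns);
///   if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
///     signalPassFailure();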
void mlir::arm_neon::populateLowerContractionToSMMLAPatternPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/1);
}