//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering patterns from vector.contract to
// arm_neon.intr.smmla
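//
// (Per the Arm ISA, smmla multiplies a 2x8 matrix of signed 8-bit integers by
// an 8x2 matrix, whose operand is also held row-major as 2x8, i.e. it is
// consumed transposed, and accumulates into a 2x2 matrix of i32 values. This
// is the origin of the [2, 2, 8] tile shape used throughout this pattern.)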
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/ArmNeon/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "lower-contract-to-arm-neon"

using namespace mlir;
using namespace mlir::arm_neon;

namespace {

/// Return the shaped type with the new element type; if the container is not
/// shaped, return the element type itself.
static Type matchContainerType(Type element, Type container) {
  if (auto shapedTy = dyn_cast<ShapedType>(container)) {
    return shapedTy.clone(element);
  }
  return element;
}

/// Lowering from vector::ContractionOp to the Arm Neon smmla intrinsic. This
/// tiles any vector.contract into multiple smmla instructions via unrolling,
/// so long as [2, 2, 8] is a divisor of its shape. It can also process
/// vecmats with dimM = 1 (either explicit, or inferred when the LHS has only
/// dimK). If no unrolling is necessary, a single smmla instruction is emitted.
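///
/// For example, a single [2, 2, 8]-shaped contraction over sign-extended i8
/// operands maps onto one smmla (a sketch; indexing maps and iterator
/// attributes are elided):
///
///   %lhs = arith.extsi %a : vector<2x8xi8> to vector<2x8xi32>
///   %rhs = arith.extsi %b : vector<2x8xi8> to vector<2x8xi32>
///   %res = vector.contract {...} %lhs, %rhs, %acc
///            : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>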
class LowerContractionToSMMLAPattern
    : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // Infer tile sizes from operands. For vecmat, LHS may only have 1 dim.
    // Note: RHS is not transposed.
    mlir::VectorType lhsType = op.getLhsType();
    mlir::VectorType rhsType = op.getRhsType();
    // Avoid 0-D vectors and 1-D rhs:
    if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2)
      return failure();
    // This codegen does not work for scalable vectors. Return failure so this
    // pattern is not accidentally chosen over patterns that lower to ArmSVE.
    if (lhsType.isScalable() || rhsType.isScalable())
      return failure();
    auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
    auto dimN = rhsType.getDimSize(0);
    auto dimK = rhsType.getDimSize(1);
    bool isVecmat = dimM == 1;
    if (lhsType.getDimSize(lhsType.getRank() - 1) !=
        rhsType.getDimSize(rhsType.getRank() - 1)) {
      return failure(); // dimK mismatch
    }
    // Unrolling can tile any input whose shape is a multiple of the
    // [2, 2, 8] smmla tile.
    if ((dimM % 2 != 0 && !isVecmat) || dimN % 2 != 0 || dimK % 8 != 0) {
      return failure();
    }
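    // E.g. a (dimM, dimN, dimK) = (4, 6, 16) contraction passes this check
    // and unrolls into (4/2) * (6/2) * (16/8) = 12 smmla tiles.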

    // Check iterator types for contract. All iterators except inner-most
    // dimension must be parallel.
    auto iteratorTypes = op.getIteratorTypesArray();
    if (iteratorTypes.size() > 3 || iteratorTypes[iteratorTypes.size() - 1] !=
                                        vector::IteratorType::reduction) {
      return failure();
    }
    if (llvm::any_of(ArrayRef<vector::IteratorType>(iteratorTypes).drop_back(1),
                     [](vector::IteratorType iteratorType) {
                       return iteratorType != vector::IteratorType::parallel;
                     })) {
      return failure();
    }
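    // I.e. the accepted forms are (parallel, parallel, reduction) for a
    // matmul and (parallel, reduction) for a vecmat.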

    // Check that both contract operands are produced by a sign extension
    // (extsi).
    arith::ExtSIOp origLhsExtOp =
        dyn_cast_or_null<arith::ExtSIOp>(op.getLhs().getDefiningOp());
    arith::ExtSIOp origRhsExtOp =
        dyn_cast_or_null<arith::ExtSIOp>(op.getRhs().getDefiningOp());
    if (!origLhsExtOp || !origRhsExtOp) {
      return failure();
    }
    // Match any iX (X <= 8) sign-extended to i32, and re-extend the original
    // narrow input to i8 so it can feed the Neon smmla instruction, which
    // consumes i8 operands.
    Value extsiLhs;
    Value extsiRhs;
    if (auto lhsExtInType =
            dyn_cast<mlir::VectorType>(origLhsExtOp.getIn().getType())) {
      if (lhsExtInType.getElementTypeBitWidth() <= 8) {
        Type targetLhsExtTy =
            matchContainerType(rewriter.getI8Type(), lhsExtInType);
        extsiLhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetLhsExtTy,
                                                         origLhsExtOp.getIn());
      }
    }
    if (auto rhsExtInType =
            dyn_cast<mlir::VectorType>(origRhsExtOp.getIn().getType())) {
      if (rhsExtInType.getElementTypeBitWidth() <= 8) {
        Type targetRhsExtTy =
            matchContainerType(rewriter.getI8Type(), rhsExtInType);
        extsiRhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetRhsExtTy,
                                                         origRhsExtOp.getIn());
      }
    }

    if (!extsiLhs || !extsiRhs) {
      return failure();
    }

    // Initial accumulator for the final result. This is the un-tiled result if
    // tiling is done.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType()));

    SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll();
    SmallVector<int64_t> smmlaShape = {2, 8};
    SmallVector<int64_t> loopOrder = {0, 1};
    if (unrolledSize.size() == 3) {
      smmlaShape.insert(smmlaShape.begin(), isVecmat ? 1 : 2);
      loopOrder.push_back(2);
    }
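    // K is the fastest-varying dimension in `loopOrder`, so consecutive tiles
    // along K are visited back-to-back and can accumulate through the
    // previous smmla result rather than re-reading the accumulator slice.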

    // Keep track of the previous accumulator when tiling over K.
    Value kAcc;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
      // Helper to compute the new shape of each operand and extract the slice.
      auto extractOperand = [&](Value operand, AffineMap permutationMap,
                                ArrayRef<int64_t> operandOffsets) {
        SmallVector<int64_t> operandShape =
            applyPermutationMap(permutationMap, ArrayRef<int64_t>(smmlaShape));
        SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
        return rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            loc, operand, operandOffsets, operandShape, operandStrides);
      };

      // Extract tiled lhs, rhs, and acc
      AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0];
      SmallVector<int64_t> lhsOffsets =
          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
      Value tiledLhs = extractOperand(extsiLhs, lhsPermutationMap, lhsOffsets);
      AffineMap rhsPermutationMap = op.getIndexingMapsArray()[1];
      SmallVector<int64_t> rhsOffsets =
          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
      Value tiledRhs = extractOperand(extsiRhs, rhsPermutationMap, rhsOffsets);
      AffineMap accPermutationMap = op.getIndexingMapsArray()[2];
      SmallVector<int64_t> accOffsets =
          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
      Value tiledAcc =
          extractOperand(op.getAcc(), accPermutationMap, accOffsets);

      auto inputElementType =
          cast<ShapedType>(tiledLhs.getType()).getElementType();
      auto accElementType =
          cast<ShapedType>(tiledAcc.getType()).getElementType();
      auto inputExpandedType = VectorType::get({2, 8}, inputElementType);
      auto outputExpandedType = VectorType::get({2, 2}, accElementType);

      // With vecmat, tiled LHS and ACC will contain only one of 2 necessary
      // rows along dimM. Expand their shapes to match the smmla op.
      if (isVecmat) {
        auto expandForSMMLA = [&](Value tiledOperand,
                                  VectorType expandedTypeType) {
          auto emptyOperand = rewriter.create<arith::ConstantOp>(
              loc, expandedTypeType, rewriter.getZeroAttr(expandedTypeType));
          SmallVector<int64_t> offsets(
              cast<ShapedType>(emptyOperand.getType()).getRank(), 0);
          SmallVector<int64_t> strides(
              cast<ShapedType>(tiledOperand.getType()).getRank(), 1);
          return rewriter.createOrFold<vector::InsertStridedSliceOp>(
              loc, tiledOperand, emptyOperand, offsets, strides);
        };
        tiledLhs = expandForSMMLA(tiledLhs, inputExpandedType);
        tiledAcc = expandForSMMLA(tiledAcc, outputExpandedType);
      }

      // Collapse tiled operands to 1D vectors required by smmla intrinsic
      auto collapsedInputType =
          VectorType::get(inputExpandedType.getNumElements(), inputElementType);
      auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
          tiledLhs.getLoc(), collapsedInputType, tiledLhs);
      auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
          tiledRhs.getLoc(), collapsedInputType, tiledRhs);
      auto collapsedOutputType =
          VectorType::get(outputExpandedType.getNumElements(), accElementType);

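      // Chain the accumulator along K: only the first tile in each K chain
      // reads the contraction's accumulator; later tiles reuse the previous
      // smmla result (kAcc).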
      bool initialKAcc = offsets.back() == 0;
      Value collapsedRes;
      if (!initialKAcc) {
        collapsedRes = kAcc;
      } else {
        collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
            tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
      }

      // Insert contract op
      kAcc = rewriter.createOrFold<arm_neon::SmmlaOp>(
          op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
          collapsedRhs);

      // Reshape output back to 2D
      Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
          kAcc.getLoc(), tiledAcc.getType(), kAcc);

      // With vecmat, only one row of the tiled ACC can be inserted into the
      // final result.
      if (isVecmat) {
        tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
      }

      // Insert the tiled result back into the non tiled result of the
      // contract op.
      SmallVector<int64_t> strides(
          cast<ShapedType>(tiledRes.getType()).getRank(), 1);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, tiledRes, result, accOffsets, strides);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};

} // namespace

void mlir::arm_neon::populateLowerContractionToSMMLAPatternPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/2);
}
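
// A minimal usage sketch: these patterns are typically applied with the
// greedy rewrite driver from a pass's runOnOperation(), e.g.:
//
//   RewritePatternSet patterns(&getContext());
//   arm_neon::populateLowerContractionToSMMLAPatternPatterns(patterns);
//   if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                           std::move(patterns))))
//     signalPassFailure();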