//===- BlockPackMatmul.cpp - Linalg matmul block packing ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallVector.h"

#include <optional>

namespace mlir {
#define GEN_PASS_DEF_LINALGBLOCKPACKMATMUL
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;
28/// Return constant range span or nullopt, otherwise.
29static std::optional<int64_t> getConstantRange(const Range &range) {
30 std::optional<int64_t> stride = getConstantIntValue(ofr: range.stride);
31 if (!stride || *stride != 1)
32 return std::nullopt;
33 std::optional<int64_t> offset = getConstantIntValue(ofr: range.offset);
34 if (!offset)
35 return std::nullopt;
36 std::optional<int64_t> size = getConstantIntValue(ofr: range.size);
37 if (!size)
38 return std::nullopt;
39 return (*size - *offset);
40}
41
42/// Return true if all dimensions are fully divisible by the respective tiles.
43static bool validateFullTilesOnDims(linalg::LinalgOp linalgOp,
44 ArrayRef<OpFoldResult> tiles,
45 ArrayRef<int64_t> dims) {
46 if (dims.size() != tiles.size() || tiles.empty())
47 return false;
48
49 FailureOr<ContractionDimensions> contractDims =
50 inferContractionDims(linalgOp);
51 if (failed(Result: contractDims))
52 return false;
53 unsigned batchDimsOffset = contractDims->batch.size();
54
55 // Skip the batch dimension if present.
56 // Offset all dimensions accordingly.
57 SmallVector<int64_t, 3> offsetDims(dims);
58 for (size_t i = 0; i < offsetDims.size(); i++)
59 offsetDims[i] += batchDimsOffset;
60
61 auto tileOp = cast<TilingInterface>(Val: linalgOp.getOperation());
62 OpBuilder builder(tileOp);
63 OpBuilder::InsertionGuard guard(builder);
64 SmallVector<Range> iterationDomain = tileOp.getIterationDomain(b&: builder);
65
66 for (auto dim : llvm::enumerate(First&: offsetDims)) {
67 if (dim.value() >= static_cast<int64_t>(iterationDomain.size()))
68 return false;
69
70 std::optional<int64_t> tileSize = getConstantIntValue(ofr: tiles[dim.index()]);
71 std::optional<int64_t> rangeOnDim =
72 getConstantRange(range: iterationDomain[dim.value()]);
73
74 // If the tile factor or the range are non-constant, the tile size is
75 // considered to be invalid.
76 if (!tileSize || !rangeOnDim)
77 return false;
78
79 // The dimension must be fully divisible by the tile.
80 if (*rangeOnDim % *tileSize != 0)
81 return false;
82 }
83
84 return true;
85}
86
87/// Return failure or packed matmul with one of its operands transposed.
88static FailureOr<PackTransposeResult>
89transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
90 linalg::PackOp packOp, AffineMap operandMap,
91 ArrayRef<unsigned> blocksStartDimPos,
92 bool transposeOuterBlocks, bool transposeInnerBlocks) {
93 assert(operandMap.getNumDims() >= 4 &&
94 "expected at least 4D prepacked matmul");
95 assert(blocksStartDimPos.size() >= 2 &&
96 "expected starting outer and inner block positions");
97
98 // Bias toward innermost dimensions.
99 unsigned outerBlockPos = operandMap.getNumResults() - 4;
100 unsigned innerBlockPos = operandMap.getNumResults() - 2;
101
102 // Transpose control options define the desired block and element layout.
103 // Block transposition (outer dimensions) or element transposition (inner
104 // dimensions) may not be necessary depending on the original matmul data
105 // layout.
106 bool isOuterTransposed =
107 operandMap.getDimPosition(idx: outerBlockPos) != blocksStartDimPos.end()[-2];
108 bool isInnerTransposed =
109 operandMap.getDimPosition(idx: innerBlockPos) != blocksStartDimPos.back();
110
111 // Transpose only the dimensions that need that to conform to the provided
112 // transpotion settings.
113 SmallVector<int64_t> innerPerm = {0, 1};
114 if (isInnerTransposed != transposeInnerBlocks)
115 innerPerm = {1, 0};
116 SmallVector<int64_t> outerPerm = {0, 1};
117 if (isOuterTransposed != transposeOuterBlocks)
118 outerPerm = {1, 0};
119
120 // Leave the outer dimensions, like batch, unchanged by offsetting all
121 // outer dimensions permutations.
122 SmallVector<int64_t> offsetPerms;
123 for (auto i : llvm::seq(Begin: 0u, End: outerBlockPos))
124 offsetPerms.push_back(Elt: i);
125 for (auto perm : outerPerm)
126 offsetPerms.push_back(Elt: perm + outerBlockPos);
127 outerPerm = offsetPerms;
128
129 FailureOr<PackTransposeResult> packTransposedMatmul =
130 packTranspose(rewriter, packOp, linalgOp,
131 /*maybeUnPackOp=*/nullptr, outerPerm, innerPerm);
132
133 return packTransposedMatmul;
134}
135
136/// Pack a matmul operation into blocked 4D layout.
137FailureOr<PackResult>
138linalg::blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
139 const ControlBlockPackMatmulFn &controlPackMatmul) {
140 // Check to not let go the batch_matmul with extended semantic, through this
141 // transform.
142 if (auto *batchMatmulOp = dyn_cast<linalg::BatchMatmulOp>(Val: &linalgOp)) {
143 if (batchMatmulOp->hasUserDefinedMaps()) {
144 return rewriter.notifyMatchFailure(
145 arg&: *batchMatmulOp,
146 msg: "only batch_matmul ops with non-extended semantics are supported");
147 }
148 }
149
150 if (linalgOp.hasPureBufferSemantics())
151 return rewriter.notifyMatchFailure(arg&: linalgOp, msg: "require tensor semantics");
152
153 std::optional<BlockPackMatmulOptions> options = controlPackMatmul(linalgOp);
154 if (!options)
155 return rewriter.notifyMatchFailure(arg&: linalgOp, msg: "invalid packing options");
156
157 if (options->blockFactors.size() != 3)
158 return rewriter.notifyMatchFailure(arg&: linalgOp, msg: "require 3 tile factors");
159
160 SmallVector<OpFoldResult> mnkTiles =
161 getAsOpFoldResult(arrayAttr: rewriter.getI64ArrayAttr(values: options->blockFactors));
162
163 // If padding is disabled, make sure that dimensions can be packed cleanly.
164 if (!options->allowPadding &&
165 !validateFullTilesOnDims(linalgOp, tiles: mnkTiles, dims: options->mnkOrder)) {
166 return rewriter.notifyMatchFailure(arg&: linalgOp,
167 msg: "expect packing full tiles only");
168 }
169
170 OpBuilder::InsertionGuard guard(rewriter);
171 // The op is replaced, we need to set the insertion point after it.
172 rewriter.setInsertionPointAfter(linalgOp);
173
174 // Pack the matmul operation into blocked layout with two levels of
175 // subdivision:
176 // - major 2D blocks - outer dimensions, consist of minor blocks
177 // - minor 2D blocks - inner dimensions, consist of scalar elements
178 FailureOr<PackResult> packedMatmul = packMatmulGreedily(
179 rewriter, linalgOp, mnkPackedSizes: mnkTiles, mnkPaddedSizesNextMultipleOf: options->mnkPaddedSizesNextMultipleOf,
180 mnkOrder: options->mnkOrder);
181 if (failed(Result: packedMatmul))
182 return failure();
183
184 assert(packedMatmul->packOps.size() == 3 &&
185 "invalid number of pack ops after matmul packing");
186 assert(packedMatmul->unPackOps.size() == 1 &&
187 "invalid number of unpack ops after matmul packing");
188
189 FailureOr<ContractionDimensions> contractDims =
190 inferContractionDims(linalgOp: packedMatmul->packedLinalgOp);
191 if (failed(Result: contractDims))
192 return failure();
193
194 auto genericOp =
195 dyn_cast<linalg::GenericOp>(Val: packedMatmul->packedLinalgOp.getOperation());
196 SmallVector<AffineMap> maps = genericOp.getIndexingMapsArray();
197
198 // Transpose LHS matrix according to the options.
199 FailureOr<PackTransposeResult> packedLhs = transposePackedMatmul(
200 rewriter, linalgOp: packedMatmul->packedLinalgOp, packOp: packedMatmul->packOps[0], operandMap: maps[0],
201 blocksStartDimPos: contractDims->m, transposeOuterBlocks: options->lhsTransposeOuterBlocks,
202 transposeInnerBlocks: options->lhsTransposeInnerBlocks);
203 if (failed(Result: packedLhs))
204 return failure();
205
206 // Update results.
207 packedMatmul->packOps[0] = packedLhs->transposedPackOp;
208 packedMatmul->packedLinalgOp = packedLhs->transposedLinalgOp;
209
210 // Transpose RHS matrix according to the options.
211 FailureOr<PackTransposeResult> packedRhs = transposePackedMatmul(
212 rewriter, linalgOp: packedMatmul->packedLinalgOp, packOp: packedMatmul->packOps[1], operandMap: maps[1],
213 blocksStartDimPos: contractDims->k, transposeOuterBlocks: options->rhsTransposeOuterBlocks,
214 transposeInnerBlocks: options->rhsTransposeInnerBlocks);
215 if (failed(Result: packedRhs))
216 return failure();
217
218 // Update results.
219 packedMatmul->packOps[1] = packedRhs->transposedPackOp;
220 packedMatmul->packedLinalgOp = packedRhs->transposedLinalgOp;
221
222 return packedMatmul;
223}
224
225namespace {
226template <typename OpTy>
227struct BlockPackMatmul : public OpRewritePattern<OpTy> {
228 BlockPackMatmul(MLIRContext *context, ControlBlockPackMatmulFn fun,
229 PatternBenefit benefit = 1)
230 : OpRewritePattern<OpTy>(context, benefit), controlFn(std::move(fun)) {}
231
232 LogicalResult matchAndRewrite(OpTy linalgOp,
233 PatternRewriter &rewriter) const override {
234 FailureOr<PackResult> packedMatmul =
235 blockPackMatmul(rewriter, linalgOp, controlFn);
236 if (failed(Result: packedMatmul))
237 return failure();
238 return success();
239 }
240
241private:
242 ControlBlockPackMatmulFn controlFn;
243};
244
245template <>
246struct BlockPackMatmul<linalg::GenericOp>
247 : public OpRewritePattern<linalg::GenericOp> {
248 BlockPackMatmul(MLIRContext *context, ControlBlockPackMatmulFn fun,
249 PatternBenefit benefit = 1)
250 : OpRewritePattern<linalg::GenericOp>(context, benefit),
251 controlFn(std::move(fun)) {}
252
253 LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
254 PatternRewriter &rewriter) const override {
255 // Match suitable generics.
256 if (!linalg::isaContractionOpInterface(linalgOp)) {
257 return rewriter.notifyMatchFailure(arg&: linalgOp, msg: "not a contraction");
258 }
259
260 using MapList = ArrayRef<ArrayRef<AffineExpr>>;
261 auto infer = [&](MapList m) {
262 return AffineMap::inferFromExprList(exprsList: m, context: linalgOp.getContext());
263 };
264
265 AffineExpr i, j, k;
266 bindDims(ctx: linalgOp->getContext(), exprs&: i, exprs&: j, exprs&: k);
267 SmallVector<AffineMap> maps = linalgOp.getIndexingMapsArray();
268
269 // For now, only match simple matmuls.
270 if (!(maps == infer({{i, k}, {k, j}, {i, j}}) ||
271 maps == infer({{k, i}, {k, j}, {i, j}}) ||
272 maps == infer({{i, k}, {j, k}, {i, j}}))) {
273 return rewriter.notifyMatchFailure(arg&: linalgOp, msg: "not a suitable matmul");
274 }
275
276 FailureOr<PackResult> packedMatmul =
277 blockPackMatmul(rewriter, linalgOp, controlPackMatmul: controlFn);
278 if (failed(Result: packedMatmul))
279 return failure();
280 return success();
281 }
282
283private:
284 ControlBlockPackMatmulFn controlFn;
285};
286
287/// Convert linalg matmul ops to block layout and back.
288struct LinalgBlockPackMatmul
289 : public impl::LinalgBlockPackMatmulBase<LinalgBlockPackMatmul> {
290 using LinalgBlockPackMatmulBase::LinalgBlockPackMatmulBase;
291
292 void runOnOperation() override {
293 Operation *op = getOperation();
294 RewritePatternSet patterns(&getContext());
295
296 ControlBlockPackMatmulFn controlFn =
297 [&](linalg::LinalgOp op) -> BlockPackMatmulOptions {
298 BlockPackMatmulOptions options;
299 options.blockFactors = SmallVector<int64_t>{*blockFactors};
300 options.allowPadding = allowPadding;
301 options.mnkPaddedSizesNextMultipleOf =
302 SmallVector<int64_t>{*mnkPaddedSizesNextMultipleOf};
303 if (!mnkOrder.empty())
304 options.mnkOrder = SmallVector<int64_t>{*mnkOrder};
305 options.lhsTransposeOuterBlocks = lhsTransposeOuterBlocks;
306 options.lhsTransposeInnerBlocks = lhsTransposeInnerBlocks;
307 options.rhsTransposeOuterBlocks = rhsTransposeOuterBlocks;
308 options.rhsTransposeInnerBlocks = rhsTransposeInnerBlocks;
309 return options;
310 };
311
312 linalg::populateBlockPackMatmulPatterns(patterns, controlFn);
313 if (failed(Result: applyPatternsGreedily(op, patterns: std::move(patterns))))
314 return signalPassFailure();
315 }
316};
317} // namespace
318
319void linalg::populateBlockPackMatmulPatterns(
320 RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn) {
321 patterns.add<BlockPackMatmul<linalg::GenericOp>,
322 BlockPackMatmul<linalg::MatmulOp>,
323 BlockPackMatmul<linalg::BatchMatmulOp>,
324 BlockPackMatmul<linalg::MatmulTransposeAOp>,
325 BlockPackMatmul<linalg::BatchMatmulTransposeAOp>,
326 BlockPackMatmul<linalg::MatmulTransposeBOp>,
327 BlockPackMatmul<linalg::BatchMatmulTransposeBOp>>(
328 arg: patterns.getContext(), args: controlFn);
329}
330

source code of mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp