//===- BlockPackMatmul.cpp - Linalg matmul block packing ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
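//
// This file implements patterns and a pass that pack Linalg matmul-like
// operations (linalg.matmul, linalg.batch_matmul, their transposed variants,
// and matching linalg.generic contractions) into a blocked layout with outer
// block dimensions iterating over inner tiles of scalar elements.
//
//===----------------------------------------------------------------------===//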

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"

#include <optional>

namespace mlir {
#define GEN_PASS_DEF_LINALGBLOCKPACKMATMUL
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

/// Return the constant span of the range, or std::nullopt otherwise.
static std::optional<int64_t> getConstantRange(const Range &range) {
  std::optional<int64_t> stride = getConstantIntValue(range.stride);
  if (!stride || *stride != 1)
    return std::nullopt;
  std::optional<int64_t> offset = getConstantIntValue(range.offset);
  if (!offset)
    return std::nullopt;
  std::optional<int64_t> size = getConstantIntValue(range.size);
  if (!size)
    return std::nullopt;
  return (*size - *offset);
}

/// Return true if all dimensions are fully divisible by the respective tiles.
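/// For example, block factors {32, 16, 64} divide a 128x64x256 (MxNxK) matmul
/// into full tiles, whereas M = 100 would leave a partial tile and make this
/// check fail.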
static bool validateFullTilesOnDims(linalg::LinalgOp linalgOp,
                                    ArrayRef<OpFoldResult> tiles,
                                    ArrayRef<int64_t> dims) {
  if (dims.size() != tiles.size() || tiles.empty())
    return false;

  FailureOr<ContractionDimensions> contractDims =
      inferContractionDims(linalgOp);
  if (failed(contractDims))
    return false;
  unsigned batchDimsOffset = contractDims->batch.size();

  // Skip the batch dimension if present.
  // Offset all dimensions accordingly.
  SmallVector<int64_t, 3> offsetDims(dims);
  for (size_t i = 0; i < offsetDims.size(); i++)
    offsetDims[i] += batchDimsOffset;

  auto tileOp = cast<TilingInterface>(linalgOp.getOperation());
  OpBuilder builder(tileOp);
  OpBuilder::InsertionGuard guard(builder);
  SmallVector<Range> iterationDomain = tileOp.getIterationDomain(builder);

  for (auto dim : llvm::enumerate(offsetDims)) {
    if (dim.value() >= static_cast<int64_t>(iterationDomain.size()))
      return false;

    std::optional<int64_t> tileSize = getConstantIntValue(tiles[dim.index()]);
    std::optional<int64_t> rangeOnDim =
        getConstantRange(iterationDomain[dim.value()]);

    // If the tile factor or the range is non-constant, the tile size is
    // considered to be invalid.
    if (!tileSize || !rangeOnDim)
      return false;

    // The dimension must be fully divisible by the tile.
    if (*rangeOnDim % *tileSize != 0)
      return false;
  }

  return true;
}

/// Return failure or packed matmul with one of its operands transposed.
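/// A permutation is applied to the outer (block) and/or inner (element)
/// dimensions only when the operand's current layout differs from the
/// requested transposition; dimensions that already conform are left as is.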
static FailureOr<PackTransposeResult>
transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
                      linalg::PackOp packOp, AffineMap operandMap,
                      ArrayRef<unsigned> blocksStartDimPos,
                      bool transposeOuterBlocks, bool transposeInnerBlocks) {
  assert(operandMap.getNumDims() >= 4 &&
         "expected at least 4D prepacked matmul");
  assert(blocksStartDimPos.size() >= 2 &&
         "expected starting outer and inner block positions");

  // Bias toward the innermost dimensions.
  unsigned outerBlockPos = operandMap.getNumResults() - 4;
  unsigned innerBlockPos = operandMap.getNumResults() - 2;

  // Transpose control options define the desired block and element layout.
  // Block transposition (outer dimensions) or element transposition (inner
  // dimensions) may not be necessary depending on the original matmul data
  // layout.
  bool isOuterTransposed =
      operandMap.getDimPosition(outerBlockPos) != blocksStartDimPos.end()[-2];
  bool isInnerTransposed =
      operandMap.getDimPosition(innerBlockPos) != blocksStartDimPos.back();

  // Transpose only the dimensions that need it to conform to the provided
  // transposition settings.
  SmallVector<int64_t> innerPerm = {0, 1};
  if (isInnerTransposed != transposeInnerBlocks)
    innerPerm = {1, 0};
  SmallVector<int64_t> outerPerm = {0, 1};
  if (isOuterTransposed != transposeOuterBlocks)
    outerPerm = {1, 0};

  // Leave the outer dimensions, like batch, unchanged by offsetting all
  // outer dimension permutations.
  SmallVector<int64_t> offsetPerms;
  for (auto i : llvm::seq(0u, outerBlockPos))
    offsetPerms.push_back(i);
  for (auto perm : outerPerm)
    offsetPerms.push_back(perm + outerBlockPos);
  outerPerm = offsetPerms;

  FailureOr<PackTransposeResult> packTransposedMatmul =
      packTranspose(rewriter, packOp, linalgOp,
                    /*maybeUnPackOp=*/nullptr, outerPerm, innerPerm);

  return packTransposedMatmul;
}

/// Pack a matmul operation into blocked 4D layout.
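///
/// The operands are subdivided into major (outer) blocks made of minor
/// (inner) blocks of scalar elements. For instance, one commonly targeted
/// layout for an MxNxK matmul is:
///   C[MB][NB][mb][nb] += A[MB][KB][mb][kb] * B[NB][KB][nb][kb]
/// where uppercase indices address blocks and lowercase indices address
/// elements within a block; the exact block and element order is controlled
/// by the transposition options.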
FailureOr<PackResult>
linalg::blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
                        const ControlBlockPackMatmulFn &controlPackMatmul) {
  // Do not let batch_matmul ops with extended semantics (user-defined
  // indexing maps) through this transform.
  if (auto batchMatmulOp =
          dyn_cast<linalg::BatchMatmulOp>(linalgOp.getOperation())) {
    if (batchMatmulOp.hasUserDefinedMaps()) {
      return rewriter.notifyMatchFailure(
          batchMatmulOp,
          "only batch_matmul ops with non-extended semantics are supported");
    }
  }

  if (linalgOp.hasPureBufferSemantics())
    return rewriter.notifyMatchFailure(linalgOp, "require tensor semantics");

  std::optional<BlockPackMatmulOptions> options = controlPackMatmul(linalgOp);
  if (!options)
    return rewriter.notifyMatchFailure(linalgOp, "invalid packing options");

  if (options->blockFactors.size() != 3)
    return rewriter.notifyMatchFailure(linalgOp, "require 3 tile factors");

  SmallVector<OpFoldResult> mnkTiles =
      getAsOpFoldResult(rewriter.getI64ArrayAttr(options->blockFactors));

  // If padding is disabled, make sure that the dimensions can be packed
  // cleanly into full tiles.
  if (!options->allowPadding &&
      !validateFullTilesOnDims(linalgOp, mnkTiles, options->mnkOrder)) {
    return rewriter.notifyMatchFailure(linalgOp,
                                       "expect packing full tiles only");
  }

  OpBuilder::InsertionGuard guard(rewriter);
  // The op is replaced, so set the insertion point after it.
  rewriter.setInsertionPointAfter(linalgOp);

  // Pack the matmul operation into a blocked layout with two levels of
  // subdivision:
  //   - major 2D blocks - outer dimensions, consisting of minor blocks
  //   - minor 2D blocks - inner dimensions, consisting of scalar elements
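  // For a matmul, this creates one linalg.pack per operand and a single
  // linalg.unpack that restores the original layout of the result (see the
  // assertions below).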
  FailureOr<PackResult> packedMatmul = packMatmulGreedily(
      rewriter, linalgOp, mnkTiles, options->mnkPaddedSizesNextMultipleOf,
      options->mnkOrder);
  if (failed(packedMatmul))
    return failure();

  assert(packedMatmul->packOps.size() == 3 &&
         "invalid number of pack ops after matmul packing");
  assert(packedMatmul->unPackOps.size() == 1 &&
         "invalid number of unpack ops after matmul packing");

  FailureOr<ContractionDimensions> contractDims =
      inferContractionDims(packedMatmul->packedLinalgOp);
  if (failed(contractDims))
    return failure();

  auto genericOp =
      dyn_cast<linalg::GenericOp>(packedMatmul->packedLinalgOp.getOperation());
  SmallVector<AffineMap> maps = genericOp.getIndexingMapsArray();

  // Transpose LHS matrix according to the options.
  FailureOr<PackTransposeResult> packedLhs = transposePackedMatmul(
      rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[0], maps[0],
      contractDims->m, options->lhsTransposeOuterBlocks,
      options->lhsTransposeInnerBlocks);
  if (failed(packedLhs))
    return failure();

  // Update results.
  packedMatmul->packOps[0] = packedLhs->transposedPackOp;
  packedMatmul->packedLinalgOp = packedLhs->transposedLinalgOp;

  // Transpose RHS matrix according to the options.
  FailureOr<PackTransposeResult> packedRhs = transposePackedMatmul(
      rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[1], maps[1],
      contractDims->k, options->rhsTransposeOuterBlocks,
      options->rhsTransposeInnerBlocks);
  if (failed(packedRhs))
    return failure();

  // Update results.
  packedMatmul->packOps[1] = packedRhs->transposedPackOp;
  packedMatmul->packedLinalgOp = packedRhs->transposedLinalgOp;

  return packedMatmul;
}

namespace {
template <typename OpTy>
struct BlockPackMatmul : public OpRewritePattern<OpTy> {
  BlockPackMatmul(MLIRContext *context, ControlBlockPackMatmulFn fun,
                  PatternBenefit benefit = 1)
      : OpRewritePattern<OpTy>(context, benefit), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(OpTy linalgOp,
                                PatternRewriter &rewriter) const override {
    FailureOr<PackResult> packedMatmul =
        blockPackMatmul(rewriter, linalgOp, controlFn);
    if (failed(packedMatmul))
      return failure();
    return success();
  }

private:
  ControlBlockPackMatmulFn controlFn;
};

template <>
struct BlockPackMatmul<linalg::GenericOp>
    : public OpRewritePattern<linalg::GenericOp> {
  BlockPackMatmul(MLIRContext *context, ControlBlockPackMatmulFn fun,
                  PatternBenefit benefit = 1)
      : OpRewritePattern<linalg::GenericOp>(context, benefit),
        controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
                                PatternRewriter &rewriter) const override {
    // Match suitable generics.
    if (!linalg::isaContractionOpInterface(linalgOp)) {
      return rewriter.notifyMatchFailure(linalgOp, "not a contraction");
    }

    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [&](MapList m) {
      return AffineMap::inferFromExprList(m, linalgOp.getContext());
    };

    AffineExpr i, j, k;
    bindDims(linalgOp->getContext(), i, j, k);
    SmallVector<AffineMap> maps = linalgOp.getIndexingMapsArray();

    // For now, only match simple matmuls.
    if (!(maps == infer({{i, k}, {k, j}, {i, j}}) ||
          maps == infer({{k, i}, {k, j}, {i, j}}) ||
          maps == infer({{i, k}, {j, k}, {i, j}}))) {
      return rewriter.notifyMatchFailure(linalgOp, "not a suitable matmul");
    }

    FailureOr<PackResult> packedMatmul =
        blockPackMatmul(rewriter, linalgOp, controlFn);
    if (failed(packedMatmul))
      return failure();
    return success();
  }

private:
  ControlBlockPackMatmulFn controlFn;
};

/// Convert linalg matmul ops to block layout and back.
struct LinalgBlockPackMatmul
    : public impl::LinalgBlockPackMatmulBase<LinalgBlockPackMatmul> {
  using LinalgBlockPackMatmulBase::LinalgBlockPackMatmulBase;

  void runOnOperation() override {
    Operation *op = getOperation();
    RewritePatternSet patterns(&getContext());

    ControlBlockPackMatmulFn controlFn =
        [&](linalg::LinalgOp op) -> BlockPackMatmulOptions {
      BlockPackMatmulOptions options;
      options.blockFactors = SmallVector<int64_t>{*blockFactors};
      options.allowPadding = allowPadding;
      options.mnkPaddedSizesNextMultipleOf =
          SmallVector<int64_t>{*mnkPaddedSizesNextMultipleOf};
      if (!mnkOrder.empty())
        options.mnkOrder = SmallVector<int64_t>{*mnkOrder};
      options.lhsTransposeOuterBlocks = lhsTransposeOuterBlocks;
      options.lhsTransposeInnerBlocks = lhsTransposeInnerBlocks;
      options.rhsTransposeOuterBlocks = rhsTransposeOuterBlocks;
      options.rhsTransposeInnerBlocks = rhsTransposeInnerBlocks;
      return options;
    };

    linalg::populateBlockPackMatmulPatterns(patterns, controlFn);
    if (failed(applyPatternsGreedily(op, std::move(patterns))))
      return signalPassFailure();
  }
};
} // namespace

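// Example client usage (an illustrative sketch, assuming a RewritePatternSet
// named `patterns` has already been created):
//
//   linalg::ControlBlockPackMatmulFn control = [](linalg::LinalgOp op) {
//     linalg::BlockPackMatmulOptions opts;
//     opts.blockFactors = {32, 16, 64};
//     return opts;
//   };
//   linalg::populateBlockPackMatmulPatterns(patterns, control);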
void linalg::populateBlockPackMatmulPatterns(
    RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn) {
  patterns.add<BlockPackMatmul<linalg::GenericOp>,
               BlockPackMatmul<linalg::MatmulOp>,
               BlockPackMatmul<linalg::BatchMatmulOp>,
               BlockPackMatmul<linalg::MatmulTransposeAOp>,
               BlockPackMatmul<linalg::BatchMatmulTransposeAOp>,
               BlockPackMatmul<linalg::MatmulTransposeBOp>,
               BlockPackMatmul<linalg::BatchMatmulTransposeBOp>>(
      patterns.getContext(), controlFn);
}