//===- BlockPackMatmul.cpp - Linalg matmul block packing -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"

#include <optional>

namespace mlir {
#define GEN_PASS_DEF_LINALGBLOCKPACKMATMUL
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

/// Return the constant span of the range, or nullopt otherwise.
static std::optional<int64_t> getConstantRange(const Range &range) {
  std::optional<int64_t> stride = getConstantIntValue(range.stride);
  if (!stride || *stride != 1)
    return std::nullopt;
  std::optional<int64_t> offset = getConstantIntValue(range.offset);
  if (!offset)
    return std::nullopt;
  std::optional<int64_t> size = getConstantIntValue(range.size);
  if (!size)
    return std::nullopt;
  return (*size - *offset);
}

/// Return true if all dimensions are fully divisible by the respective tiles.
static bool validateFullTilesOnDims(linalg::LinalgOp linalgOp,
                                    ArrayRef<OpFoldResult> tiles,
                                    ArrayRef<int64_t> dims) {
  if (dims.size() != tiles.size() || tiles.empty())
    return false;

  FailureOr<ContractionDimensions> contractDims =
      inferContractionDims(linalgOp);
  if (failed(contractDims))
    return false;
  unsigned batchDimsOffset = contractDims->batch.size();

  // Skip the batch dimension if present.
  // Offset all dimensions accordingly.
  SmallVector<int64_t, 3> offsetDims(dims);
  for (size_t i = 0; i < offsetDims.size(); i++)
    offsetDims[i] += batchDimsOffset;

  auto tileOp = cast<TilingInterface>(linalgOp.getOperation());
  OpBuilder builder(tileOp);
  OpBuilder::InsertionGuard guard(builder);
  SmallVector<Range> iterationDomain = tileOp.getIterationDomain(builder);

  for (auto dim : llvm::enumerate(offsetDims)) {
    if (dim.value() >= static_cast<int64_t>(iterationDomain.size()))
      return false;

    std::optional<int64_t> tileSize = getConstantIntValue(tiles[dim.index()]);
    std::optional<int64_t> rangeOnDim =
        getConstantRange(iterationDomain[dim.value()]);

    // If the tile factor or the range is non-constant, the tile size is
    // considered invalid.
    if (!tileSize || !rangeOnDim)
      return false;

    // The dimension must be fully divisible by the tile.
    if (*rangeOnDim % *tileSize != 0)
      return false;
  }

  return true;
}

/// Return the packed matmul with one of its operands transposed, or failure.
static FailureOr<PackTransposeResult>
transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
                      linalg::PackOp packOp, AffineMap operandMap,
                      ArrayRef<unsigned> blocksStartDimPos,
                      bool transposeOuterBlocks, bool transposeInnerBlocks) {
  assert(operandMap.getNumDims() >= 4 &&
         "expected at least 4D prepacked matmul");
  assert(blocksStartDimPos.size() >= 2 &&
         "expected starting outer and inner block positions");

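  // The trailing four results of a packed operand map are the two outer block
  // dimensions followed by the two inner block dimensions; pick the first of
  // each pair.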
  // Bias toward innermost dimensions.
  unsigned outerBlockPos = operandMap.getNumResults() - 4;
  unsigned innerBlockPos = operandMap.getNumResults() - 2;

  // Transpose control options define the desired block and element layout.
  // Block transposition (outer dimensions) or element transposition (inner
  // dimensions) may not be necessary depending on the original matmul data
  // layout.
  bool isOuterTransposed =
      operandMap.getDimPosition(outerBlockPos) != blocksStartDimPos.end()[-2];
  bool isInnerTransposed =
      operandMap.getDimPosition(innerBlockPos) != blocksStartDimPos.back();

  // Transpose only the dimensions that need it to conform to the provided
  // transposition settings.
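  // An identity permutation {0, 1} keeps the current layout, while {1, 0}
  // swaps the two block dimensions.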
  SmallVector<int64_t> innerPerm = {0, 1};
  if (isInnerTransposed != transposeInnerBlocks)
    innerPerm = {1, 0};
  SmallVector<int64_t> outerPerm = {0, 1};
  if (isOuterTransposed != transposeOuterBlocks)
    outerPerm = {1, 0};

  // Leave the leading outer dimensions, such as batch, unchanged by offsetting
  // the outer block permutation entries accordingly.
  SmallVector<int64_t> offsetPerms;
  for (auto i : llvm::seq(0u, outerBlockPos))
    offsetPerms.push_back(i);
  for (auto perm : outerPerm)
    offsetPerms.push_back(perm + outerBlockPos);
  outerPerm = offsetPerms;

  FailureOr<PackTransposeResult> packTransposedMatmul =
      packTranspose(rewriter, packOp, linalgOp,
                    /*maybeUnPackOp=*/nullptr, outerPerm, innerPerm);

  return packTransposedMatmul;
}

/// Pack a matmul operation into a blocked 4D layout.
FailureOr<PackResult>
linalg::blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
                        const ControlBlockPackMatmulFn &controlPackMatmul) {
  // Do not let batch_matmul ops with extended semantics through this
  // transform.
  if (auto *batchMatmulOp = dyn_cast<linalg::BatchMatmulOp>(&linalgOp)) {
    if (batchMatmulOp->hasUserDefinedMaps()) {
      return rewriter.notifyMatchFailure(
          *batchMatmulOp,
          "only batch_matmul ops with non-extended semantics are supported");
    }
  }

  if (linalgOp.hasPureBufferSemantics())
    return rewriter.notifyMatchFailure(linalgOp, "require tensor semantics");

  std::optional<BlockPackMatmulOptions> options = controlPackMatmul(linalgOp);
  if (!options)
    return rewriter.notifyMatchFailure(linalgOp, "invalid packing options");

  if (options->blockFactors.size() != 3)
    return rewriter.notifyMatchFailure(linalgOp, "require 3 tile factors");

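  // The three block factors are the (m, n, k) tile sizes used for the packed
  // layout.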
  SmallVector<OpFoldResult> mnkTiles =
      getAsOpFoldResult(rewriter.getI64ArrayAttr(options->blockFactors));

  // If padding is disabled, make sure that dimensions can be packed cleanly.
  if (!options->allowPadding &&
      !validateFullTilesOnDims(linalgOp, mnkTiles, options->mnkOrder)) {
    return rewriter.notifyMatchFailure(linalgOp,
                                       "expect packing full tiles only");
  }

  OpBuilder::InsertionGuard guard(rewriter);
  // The op is replaced, so set the insertion point after it.
  rewriter.setInsertionPointAfter(linalgOp);

  // Pack the matmul operation into blocked layout with two levels of
  // subdivision:
  // - major 2D blocks - outer dimensions, consist of minor blocks
  // - minor 2D blocks - inner dimensions, consist of scalar elements
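  // For example, with the default transposition options a plain matmul is
  // expected to end up in the blocked form:
  //   C[MB][NB][mb][nb] += A[MB][KB][mb][kb] * B[NB][KB][nb][kb]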
  FailureOr<PackResult> packedMatmul = packMatmulGreedily(
      rewriter, linalgOp, mnkTiles, options->mnkPaddedSizesNextMultipleOf,
      options->mnkOrder);
  if (failed(packedMatmul))
    return failure();

  assert(packedMatmul->packOps.size() == 3 &&
         "invalid number of pack ops after matmul packing");
  assert(packedMatmul->unPackOps.size() == 1 &&
         "invalid number of unpack ops after matmul packing");

  FailureOr<ContractionDimensions> contractDims =
      inferContractionDims(packedMatmul->packedLinalgOp);
  if (failed(contractDims))
    return failure();

  auto genericOp =
      dyn_cast<linalg::GenericOp>(packedMatmul->packedLinalgOp.getOperation());
  SmallVector<AffineMap> maps = genericOp.getIndexingMapsArray();

  // Transpose LHS matrix according to the options.
  FailureOr<PackTransposeResult> packedLhs = transposePackedMatmul(
      rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[0], maps[0],
      contractDims->m, options->lhsTransposeOuterBlocks,
      options->lhsTransposeInnerBlocks);
  if (failed(packedLhs))
    return failure();

  // Update results.
  packedMatmul->packOps[0] = packedLhs->transposedPackOp;
  packedMatmul->packedLinalgOp = packedLhs->transposedLinalgOp;

  // Transpose RHS matrix according to the options.
  FailureOr<PackTransposeResult> packedRhs = transposePackedMatmul(
      rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[1], maps[1],
      contractDims->k, options->rhsTransposeOuterBlocks,
      options->rhsTransposeInnerBlocks);
  if (failed(packedRhs))
    return failure();

  // Update results.
  packedMatmul->packOps[1] = packedRhs->transposedPackOp;
  packedMatmul->packedLinalgOp = packedRhs->transposedLinalgOp;

  return packedMatmul;
}

namespace {
template <typename OpTy>
struct BlockPackMatmul : public OpRewritePattern<OpTy> {
  BlockPackMatmul(MLIRContext *context, ControlBlockPackMatmulFn fun,
                  PatternBenefit benefit = 1)
      : OpRewritePattern<OpTy>(context, benefit), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(OpTy linalgOp,
                                PatternRewriter &rewriter) const override {
    FailureOr<PackResult> packedMatmul =
        blockPackMatmul(rewriter, linalgOp, controlFn);
    if (failed(packedMatmul))
      return failure();
    return success();
  }

private:
  ControlBlockPackMatmulFn controlFn;
};

template <>
struct BlockPackMatmul<linalg::GenericOp>
    : public OpRewritePattern<linalg::GenericOp> {
  BlockPackMatmul(MLIRContext *context, ControlBlockPackMatmulFn fun,
                  PatternBenefit benefit = 1)
      : OpRewritePattern<linalg::GenericOp>(context, benefit),
        controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
                                PatternRewriter &rewriter) const override {
    // Match suitable generics.
    if (!linalg::isaContractionOpInterface(linalgOp)) {
      return rewriter.notifyMatchFailure(linalgOp, "not a contraction");
    }

    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [&](MapList m) {
      return AffineMap::inferFromExprList(m, linalgOp.getContext());
    };

    AffineExpr i, j, k;
    bindDims(linalgOp->getContext(), i, j, k);
    SmallVector<AffineMap> maps = linalgOp.getIndexingMapsArray();

    // For now, only match simple matmuls.
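    // The map sets below correspond to a plain matmul, a matmul with
    // transposed A, and a matmul with transposed B, respectively.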
    if (!(maps == infer({{i, k}, {k, j}, {i, j}}) ||
          maps == infer({{k, i}, {k, j}, {i, j}}) ||
          maps == infer({{i, k}, {j, k}, {i, j}}))) {
      return rewriter.notifyMatchFailure(linalgOp, "not a suitable matmul");
    }

    FailureOr<PackResult> packedMatmul =
        blockPackMatmul(rewriter, linalgOp, controlFn);
    if (failed(packedMatmul))
      return failure();
    return success();
  }

private:
  ControlBlockPackMatmulFn controlFn;
};

/// Convert linalg matmul ops to block layout and back.
struct LinalgBlockPackMatmul
    : public impl::LinalgBlockPackMatmulBase<LinalgBlockPackMatmul> {
  using LinalgBlockPackMatmulBase::LinalgBlockPackMatmulBase;

  void runOnOperation() override {
    Operation *op = getOperation();
    RewritePatternSet patterns(&getContext());

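    // Derive the packing options for each candidate op from the pass options.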
    ControlBlockPackMatmulFn controlFn =
        [&](linalg::LinalgOp op) -> BlockPackMatmulOptions {
      BlockPackMatmulOptions options;
      options.blockFactors = SmallVector<int64_t>{*blockFactors};
      options.allowPadding = allowPadding;
      options.mnkPaddedSizesNextMultipleOf =
          SmallVector<int64_t>{*mnkPaddedSizesNextMultipleOf};
      if (!mnkOrder.empty())
        options.mnkOrder = SmallVector<int64_t>{*mnkOrder};
      options.lhsTransposeOuterBlocks = lhsTransposeOuterBlocks;
      options.lhsTransposeInnerBlocks = lhsTransposeInnerBlocks;
      options.rhsTransposeOuterBlocks = rhsTransposeOuterBlocks;
      options.rhsTransposeInnerBlocks = rhsTransposeInnerBlocks;
      return options;
    };

    linalg::populateBlockPackMatmulPatterns(patterns, controlFn);
    if (failed(applyPatternsGreedily(op, std::move(patterns))))
      return signalPassFailure();
  }
};
} // namespace

void linalg::populateBlockPackMatmulPatterns(
    RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn) {
  patterns.add<BlockPackMatmul<linalg::GenericOp>,
               BlockPackMatmul<linalg::MatmulOp>,
               BlockPackMatmul<linalg::BatchMatmulOp>,
               BlockPackMatmul<linalg::MatmulTransposeAOp>,
               BlockPackMatmul<linalg::BatchMatmulTransposeAOp>,
               BlockPackMatmul<linalg::MatmulTransposeBOp>,
               BlockPackMatmul<linalg::BatchMatmulTransposeBOp>>(
      patterns.getContext(), controlFn);
}