1//===- FoldIntoPackAndUnpackPatterns.cpp ----------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Linalg/IR/Linalg.h"
10#include "mlir/Dialect/Tensor/IR/Tensor.h"
11#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
12#include "mlir/Dialect/Utils/IndexingUtils.h"
13#include "mlir/IR/PatternMatch.h"
14
15namespace mlir {
16namespace linalg {
17namespace {
18
19/// Returns the number of shape sizes that is either dynamic or greater than 1.
20static int64_t getNumGtOneDims(ArrayRef<int64_t> shape) {
21 return llvm::count_if(
22 Range&: shape, P: [](int64_t v) { return ShapedType::isDynamic(v) || v > 1; });
23}
24
25/// Returns success() if there is only 1 dimension size in non-packed domain
26/// being greater than 1 and packing only happens on the dimension.
27/// Note: this method should only be used by pack/unpack to reshape conversion.
28/// It assumes that non-unit inner tile size must be used by the non-unit
29/// dimension.
30static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
31 ArrayRef<int64_t> srcShape,
32 ArrayRef<int64_t> innerPackTileSize) {
33 if (getNumGtOneDims(shape: srcShape) > 1) {
34 return rewriter.notifyMatchFailure(
35 arg&: op, msg: "expects non-packed domain to have at most one non-unit dims");
36 }
37 // Non-unit inner tile size must be used by the non-unit dimension. If not, it
38 // will faill on getting reassociation maps.
39 if (getNumGtOneDims(shape: innerPackTileSize) > 1) {
40 return rewriter.notifyMatchFailure(
41 arg&: op, msg: "expects at most one non-unit inner tiles");
42 }
43 return success();
44}
45
46// If the `linalgOp` represents a transpose, return the permutation vector for
47// the transpose. Otherwise, return failure.
48static FailureOr<SmallVector<int64_t>>
49getTransposeOpPermutation(linalg::LinalgOp linalgOp) {
50 if (auto transposeOp = dyn_cast<linalg::TransposeOp>(linalgOp.getOperation()))
51 return SmallVector<int64_t>(transposeOp.getPermutation());
52 if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
53 return failure();
54
55 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
56 return failure();
57 auto mapRange = linalgOp.getIndexingMapsArray();
58 if (!mapRange.front().isPermutation() || !mapRange.back().isPermutation() ||
59 mapRange.front() == mapRange.back()) {
60 return failure();
61 }
62 if (!llvm::hasSingleElement(linalgOp.getBlock()->getOperations()))
63 return failure();
64 AffineMap outMap = mapRange.back();
65 AffineMap inMap = mapRange.front();
66 // To get the permutation, look at each output index and find which
67 // dimension in the input we're reading from for that index.
68 return llvm::map_to_vector(C: outMap.getResults(),
69 F: [&](AffineExpr expr) -> int64_t {
70 return *inMap.getResultPosition(input: expr);
71 });
72}
73
74/// Packing one-dimensional tensor can be expressed as an expand shape op.
75struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
76 using OpRewritePattern<PackOp>::OpRewritePattern;
77
78 FailureOr<Value>
79 insertExpand(RewriterBase &rewriter, Location loc, Value operand,
80 Type newOperandType,
81 ArrayRef<ReassociationIndices> reassociation) const {
82 if (operand.getType() == newOperandType)
83 return operand;
84 return rewriter
85 .create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
86 reassociation)
87 .getResult();
88 }
89
90 /// Returns success() if it is only packing on the innermost dimension.
91 LogicalResult isPackOnInnerMostDim(RewriterBase &rewriter,
92 PackOp packOp) const {
93 auto outerDimsPerm = packOp.getOuterDimsPerm();
94 if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
95 return rewriter.notifyMatchFailure(
96 packOp,
97 "expects outer_dims_perm is empty or an identity permutation");
98 }
99
100 int64_t srcRank = packOp.getSourceRank();
101 ArrayRef<int64_t> dimsPos = packOp.getInnerDimsPos();
102 if (dimsPos.size() != 1 || (dimsPos[0] + 1 != srcRank)) {
103 return rewriter.notifyMatchFailure(
104 packOp, "expects packing at the innermost dimension");
105 }
106 return success();
107 }
108
109 LogicalResult matchAndRewrite(PackOp packOp,
110 PatternRewriter &rewriter) const override {
111 if (packOp.getPaddingValue())
112 return rewriter.notifyMatchFailure(packOp, "expects no padding value");
113
114 RankedTensorType sourceType = packOp.getSourceType();
115 if (failed(isPackOnInnerMostDim(rewriter, packOp)) &&
116 failed(isPackOn1D(rewriter, packOp, sourceType.getShape(),
117 packOp.getStaticTiles())) &&
118 !packOp.isLikePad()) {
119 return failure();
120 }
121
122 RankedTensorType destType = packOp.getDestType();
123 auto reassociation =
124 getReassociationIndicesForReshape(sourceType, destType);
125 if (!reassociation)
126 return failure();
127 FailureOr<Value> expanded =
128 insertExpand(rewriter, loc: packOp.getLoc(), operand: packOp.getSource(), newOperandType: destType,
129 reassociation: *reassociation);
130 if (failed(Result: expanded)) {
131 return rewriter.notifyMatchFailure(
132 packOp, "unable to expand source of tensor.pack");
133 }
134 rewriter.replaceOp(packOp, *expanded);
135 return success();
136 }
137};
138
139struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
140 using OpRewritePattern<UnPackOp>::OpRewritePattern;
141
142 Value insertCollapse(RewriterBase &rewriter, Location loc, Value operand,
143 Type newOperandType, ArrayAttr reassociation) const {
144 if (operand.getType() == newOperandType)
145 return operand;
146 return rewriter.create<tensor::CollapseShapeOp>(loc, newOperandType,
147 operand, reassociation);
148 }
149
150 /// Returns success() if it is unpacking on the innermost dimension.
151 LogicalResult isUnpackOnInnerMostDim(RewriterBase &rewriter,
152 UnPackOp unpackOp) const {
153 auto outerDimsPerm = unpackOp.getOuterDimsPerm();
154 if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
155 return rewriter.notifyMatchFailure(
156 unpackOp,
157 "expects outer_dims_perm is empty or an identity permutation");
158 }
159
160 RankedTensorType sourceType = unpackOp.getSourceType();
161 RankedTensorType destType = unpackOp.getDestType();
162 if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
163 return rewriter.notifyMatchFailure(unpackOp, "expects static shapes");
164
165 ArrayRef<int64_t> dimsPos = unpackOp.getInnerDimsPos();
166 if (dimsPos.size() != 1 || (dimsPos[0] + 1 != destType.getRank())) {
167 return rewriter.notifyMatchFailure(
168 unpackOp, "expects unpacking on the innermost dimension");
169 }
170
171 return success();
172 }
173
174 LogicalResult matchAndRewrite(UnPackOp unpackOp,
175 PatternRewriter &rewriter) const override {
176 RankedTensorType destType = unpackOp.getDestType();
177 if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
178 failed(isPackOn1D(rewriter, unpackOp, destType.getShape(),
179 unpackOp.getStaticTiles())) &&
180 !unpackOp.isLikeUnPad()) {
181 return failure();
182 }
183
184 RankedTensorType sourceType = unpackOp.getSourceType();
185 auto reassociation =
186 getReassociationIndicesForReshape(sourceType, destType);
187 if (!reassociation)
188 return failure();
189 Value collapsed = insertCollapse(
190 rewriter, loc: unpackOp.getLoc(), operand: unpackOp.getSource(), newOperandType: destType,
191 reassociation: getReassociationIndicesAttribute(rewriter, *reassociation));
192 rewriter.replaceOp(unpackOp, collapsed);
193 return success();
194 }
195};
196
197/// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
198/// the pad op has zero low paddings, or if `pack` has no padding values.
199struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
200 using OpRewritePattern<PackOp>::OpRewritePattern;
201
202 LogicalResult matchAndRewrite(PackOp packOp,
203 PatternRewriter &rewriter) const override {
204 auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
205
206 if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad())
207 return failure();
208
209 Value constantPaddingValue = padOp.getConstantPaddingValue();
210 if (!constantPaddingValue)
211 return failure();
212
213 if (auto paddingValue = packOp.getPaddingValue())
214 if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue))
215 return failure();
216
217 rewriter.replaceOpWithNewOp<PackOp>(
218 packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
219 packOp.getMixedTiles(), constantPaddingValue,
220 packOp.getOuterDimsPerm());
221 return success();
222 }
223};
224
225/// Fold a `unpack` -> `extract_slice` into the `unpack` since it already
226/// has extract_slice semantics.
227struct FoldUnpackWithExtractSliceOp
228 : public OpRewritePattern<tensor::ExtractSliceOp> {
229 using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
230
231 LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
232 PatternRewriter &rewriter) const override {
233 auto unpackOp = sliceOp.getSource().getDefiningOp<UnPackOp>();
234 if (!unpackOp)
235 return failure();
236
237 if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) {
238 return rewriter.notifyMatchFailure(
239 sliceOp, "rank-reduced folding is not supported");
240 }
241
242 // Check all offsets are zeros, and all strides are ones.
243 if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
244 !areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) {
245 return rewriter.notifyMatchFailure(
246 sliceOp, "expects offsets to be 0s and strides to be 1s");
247 }
248
249 // Create a new empty output tensor.
250 Type elementType = unpackOp.getDestType().getElementType();
251 Value output = rewriter.create<tensor::EmptyOp>(
252 sliceOp.getLoc(), sliceOp.getMixedSizes(), elementType);
253 rewriter.replaceOpWithNewOp<UnPackOp>(
254 sliceOp, unpackOp.getSource(), output, unpackOp.getInnerDimsPos(),
255 unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm());
256 return success();
257 }
258};
259
260// Applies 'permutation' on 'inVec' and stores the result in resVec.
261// 'inVec' may be empty, in that case it's one-to-one mapping with permutation.
262// `rank` sets the boundary for permutation i.e., the permutation dim can't be
263// greater than the rank specified. If it's so then return false.
264// For e.g., permutation {1, 0, 3, 2} with rank 2 is allowed since the values in
265// permutation[:rank] doesn't exceed rank, whereas, permutation {1, 3, 0, 2} is
266// not allowed since `3` exceeds the value of the rank in the given range.
267static bool checkAndPermute(ArrayRef<int64_t> permutation,
268 ArrayRef<int64_t> inVec,
269 SmallVectorImpl<int64_t> &resVec, int64_t rank) {
270
271 for (unsigned int i = 0; i < rank; ++i) {
272 int64_t remappedPosition = permutation[i];
273 if (remappedPosition >= rank)
274 return false;
275 if (!inVec.empty())
276 remappedPosition = inVec[remappedPosition];
277 resVec.push_back(Elt: remappedPosition);
278 }
279
280 return true;
281}
282
283/// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
284/// semantics.
285struct FoldProducerPackWithConsumerLinalgTransposeOp
286 : public OpInterfaceRewritePattern<linalg::LinalgOp> {
287 using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
288
289 LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
290 PatternRewriter &rewriter) const override {
291 auto packOp = linalgOp->getOperand(0).getDefiningOp<PackOp>();
292
293 if (!packOp)
294 return failure();
295
296 FailureOr<SmallVector<int64_t>> maybePerm =
297 getTransposeOpPermutation(linalgOp);
298 if (failed(Result: maybePerm))
299 return failure();
300
301 auto innerDimsPos = packOp.getInnerDimsPos();
302 auto mixedInnerTiles = packOp.getMixedTiles();
303 auto outerDimsPerm = packOp.getOuterDimsPerm();
304 auto transposePerm = maybePerm.value();
305 SmallVector<int64_t> newOuterDimsPermVec;
306 SmallVector<int64_t> newInnerDimsPosVec;
307 SmallVector<OpFoldResult> newMixedInnerTilesVec;
308 int64_t srcRank = packOp.getSourceRank();
309
310 if (!checkAndPermute(transposePerm, outerDimsPerm, newOuterDimsPermVec,
311 srcRank))
312 return rewriter.notifyMatchFailure(
313 linalgOp,
314 "Cannot fold in tensor.pack if a tile dimension was transposed "
315 "with a non-tile dimension in linalg.transpose.");
316
317 // Process transpose operation for tiled inner dimensions
318 for (unsigned int i = srcRank; i < transposePerm.size(); ++i) {
319 int64_t remappedPosition = transposePerm[i] - srcRank;
320 newMixedInnerTilesVec.push_back(Elt: mixedInnerTiles[remappedPosition]);
321 newInnerDimsPosVec.push_back(Elt: innerDimsPos[remappedPosition]);
322 }
323
324 Value output = packOp.createDestinationTensor(
325 rewriter, linalgOp.getLoc(), packOp.getSource(), newMixedInnerTilesVec,
326 newInnerDimsPosVec, newOuterDimsPermVec);
327
328 rewriter.replaceOpWithNewOp<PackOp>(
329 linalgOp, packOp.getSource(), output, newInnerDimsPosVec,
330 newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec);
331
332 return success();
333 }
334};
335
336/// Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose
337/// semantics.
338struct FoldConsumerPackWithProducerLinalgTransposeOp
339 : public OpRewritePattern<PackOp> {
340 using OpRewritePattern<PackOp>::OpRewritePattern;
341
342 LogicalResult matchAndRewrite(PackOp packOp,
343 PatternRewriter &rewriter) const override {
344 auto linalgOp = packOp.getSource().getDefiningOp<linalg::LinalgOp>();
345 if (!linalgOp)
346 return failure();
347
348 FailureOr<SmallVector<int64_t>> maybePerm =
349 getTransposeOpPermutation(linalgOp);
350 if (failed(Result: maybePerm))
351 return failure();
352
353 auto transposePermutation = maybePerm.value();
354 auto outerDimsPerm = packOp.getOuterDimsPerm();
355 auto innerDimsPos = packOp.getInnerDimsPos();
356 SmallVector<int64_t> newInnerDimsPosVec;
357 SmallVector<int64_t> newOuterDimsPermVec =
358 llvm::to_vector(transposePermutation);
359
360 if (!outerDimsPerm.empty())
361 applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);
362
363 // Can't use applyPermutationToVector for newInnerDimsPosVec since input and
364 // permutation rank won't necessarily be equal in all cases.
365 for (auto dim : innerDimsPos)
366 newInnerDimsPosVec.push_back(transposePermutation[dim]);
367
368 Value output = packOp.createDestinationTensor(
369 rewriter, packOp.getLoc(), linalgOp->getOperand(0),
370 packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec);
371
372 rewriter.replaceOpWithNewOp<PackOp>(
373 packOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
374 packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPermVec);
375
376 return success();
377 }
378};
379
380/// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
381/// transpose semantics.
382struct FoldProducerUnPackWithConsumerLinalgTransposeOp
383 : public OpInterfaceRewritePattern<linalg::LinalgOp> {
384 using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
385
386 LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
387 PatternRewriter &rewriter) const override {
388 auto unPackOp = linalgOp->getOperand(0).getDefiningOp<UnPackOp>();
389
390 if (!unPackOp)
391 return failure();
392
393 FailureOr<SmallVector<int64_t>> maybePerm =
394 getTransposeOpPermutation(linalgOp);
395 if (failed(Result: maybePerm))
396 return failure();
397
398 auto outerDimsPerm = unPackOp.getOuterDimsPerm();
399 auto innerDimsPos = unPackOp.getInnerDimsPos();
400 SmallVector<int64_t> newInnerDimsPosVec;
401 SmallVector<int64_t> newOuterDimsPermVec =
402 invertPermutationVector(permutation: maybePerm.value());
403
404 // Can't use applyPermutationToVector for newInnerDimsPosVec since input and
405 // permutation rank won't necessarily be equal in all cases.
406 for (auto dim : innerDimsPos)
407 newInnerDimsPosVec.push_back(newOuterDimsPermVec[dim]);
408
409 if (!outerDimsPerm.empty())
410 applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);
411
412 // Reuse the destination of the transpose op.
413 rewriter.replaceOpWithNewOp<UnPackOp>(
414 linalgOp, unPackOp.getSource(), linalgOp.getDpsInits()[0],
415 newInnerDimsPosVec, unPackOp.getMixedTiles(), newOuterDimsPermVec);
416
417 return success();
418 }
419};
420
421/// Fold 'transpose' -> 'unpack' into 'unpack' since 'unpack' already has
422/// transpose semantics.
423struct FoldConsumerUnPackWithProducerLinalgTransposeOp
424 : public OpRewritePattern<UnPackOp> {
425 using OpRewritePattern<UnPackOp>::OpRewritePattern;
426
427 LogicalResult matchAndRewrite(UnPackOp unPackOp,
428 PatternRewriter &rewriter) const override {
429 auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>();
430 if (!linalgOp)
431 return failure();
432
433 FailureOr<SmallVector<int64_t>> maybePerm =
434 getTransposeOpPermutation(linalgOp);
435 if (failed(Result: maybePerm))
436 return failure();
437
438 SmallVector<SmallVector<OpFoldResult>> unpackOpResultDims;
439 if (failed(reifyResultShapes(rewriter, unPackOp, unpackOpResultDims))) {
440 return failure();
441 }
442
443 SmallVector<int64_t> inverseTransposePerm =
444 invertPermutationVector(permutation: maybePerm.value());
445 auto outerDimsPerm = unPackOp.getOuterDimsPerm();
446 auto innerDimsPos = unPackOp.getInnerDimsPos();
447 int64_t destRank = unPackOp.getSourceRank() - innerDimsPos.size();
448 auto mixedInnerTilesVec = unPackOp.getMixedTiles();
449 SmallVector<int64_t> newOuterDimsPermVec;
450 SmallVector<int64_t> newInnerDimsPosVec;
451 SmallVector<OpFoldResult> newMixedInnerTilesVec;
452 if (!checkAndPermute(inverseTransposePerm, outerDimsPerm,
453 newOuterDimsPermVec, destRank))
454 return rewriter.notifyMatchFailure(
455 unPackOp,
456 "Cannot fold in tensor.unpack if a tile dimension was transposed "
457 "with a non-tile dimension in linalg.transpose.");
458
459 // Process transpose operation for tiled inner dimensions
460 for (unsigned int i = destRank; i < inverseTransposePerm.size(); ++i) {
461 int64_t remappedPosition = inverseTransposePerm[i] - destRank;
462 newMixedInnerTilesVec.push_back(Elt: mixedInnerTilesVec[remappedPosition]);
463 newInnerDimsPosVec.push_back(Elt: innerDimsPos[remappedPosition]);
464 }
465
466 auto elemType =
467 cast<ShapedType>(unPackOp->getResultTypes()[0]).getElementType();
468 Value output = rewriter.create<tensor::EmptyOp>(
469 unPackOp->getLoc(), unpackOpResultDims[0], elemType);
470
471 rewriter.replaceOpWithNewOp<UnPackOp>(
472 unPackOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
473 newMixedInnerTilesVec, newOuterDimsPermVec);
474
475 return success();
476 }
477};
478
479/// tensor.empty does not define any tensor contents, so an unpadded pack
480/// can be folded away.
481struct FoldEmptyTensorWithPackOp : public OpRewritePattern<PackOp> {
482 using OpRewritePattern<PackOp>::OpRewritePattern;
483
484 LogicalResult matchAndRewrite(PackOp packOp,
485 PatternRewriter &rewriter) const override {
486 // Check for tensor.empty source.
487 auto emptyOp = packOp.getSource().getDefiningOp<tensor::EmptyOp>();
488 if (!emptyOp)
489 return failure();
490
491 // Check for padding.
492 // Packing with padding cannot be simply removed.
493 if (packOp.getPaddingValue())
494 return rewriter.notifyMatchFailure(packOp, "expects no padding value");
495
496 // Replace the pack directly with its destination.
497 rewriter.replaceOp(packOp, packOp.getDest());
498
499 return success();
500 }
501};
502
503/// tensor.empty does not define any tensor contents, so an unpack
504/// can be folded away.
505struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern<UnPackOp> {
506 using OpRewritePattern<UnPackOp>::OpRewritePattern;
507
508 LogicalResult matchAndRewrite(UnPackOp unPackOp,
509 PatternRewriter &rewriter) const override {
510 // Check for tensor.empty source.
511 auto emptyOp = unPackOp.getSource().getDefiningOp<tensor::EmptyOp>();
512 if (!emptyOp)
513 return failure();
514
515 // Replace the unpack directly with its destination.
516 rewriter.replaceOp(unPackOp, unPackOp.getDest());
517
518 return success();
519 }
520};
521
522} // namespace
523
524void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
525 patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
526 FoldProducerPackWithConsumerLinalgTransposeOp,
527 FoldConsumerPackWithProducerLinalgTransposeOp,
528 FoldConsumerUnPackWithProducerLinalgTransposeOp,
529 FoldProducerUnPackWithConsumerLinalgTransposeOp>(
530 arg: patterns.getContext());
531}
532
533void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns) {
534 patterns.add<SimplifyPackToExpandShape, SimplifyUnPackToCollapseShape>(
535 arg: patterns.getContext());
536}
537
538void populateFoldPackUnpackIntoTensorEmptyPatterns(
539 RewritePatternSet &patterns) {
540 patterns.add<FoldEmptyTensorWithPackOp, FoldEmptyTensorWithUnPackOp>(
541 arg: patterns.getContext());
542}
543
544} // namespace linalg
545} // namespace mlir
546

// Source file: mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp