1//===- BubbleUpExtractSlice.cpp - bubble up tensor.extract_slice ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements patterns that transform linalg.<op> +
// tensor.extract_slice into tensor.extract_slice + linalg.<op> to reduce
// the computation for the linalg op.
12//
13//===----------------------------------------------------------------------===//
14
15#include "mlir/Dialect/Affine/IR/AffineOps.h"
16#include "mlir/Dialect/Arith/Utils/Utils.h"
17#include "mlir/Dialect/Linalg/IR/Linalg.h"
18#include "mlir/Dialect/Linalg/Passes.h"
19#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
20#include "mlir/Dialect/Linalg/Utils/Utils.h"
21#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22
23using namespace mlir;
24using namespace mlir::linalg;
25
26namespace {
27/// Bubble up extract_slice above Linalg operation.
28///
29/// A sequence of operations
30///
31/// ```mlir
32/// %0 = linalg.<op> ... arg0, arg1, ...
33/// %1 = tensor.extract_slice %0 ...
34/// ```
35///
36/// can be replaced with
37///
38/// ```mlir
39/// %0 = tensor.extract_slice %arg0
40/// %1 = tensor.extract_slice %arg1
41/// %2 = linalg.<op> ... %0, %1, ...
42/// ```
43///
44/// This results in the reduce computation of the linalg operation.
45///
46struct BubbleUpExtractSliceOpPattern
47 : OpRewritePattern<tensor::ExtractSliceOp> {
48 using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
49
50 LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
51 PatternRewriter &rewriter) const final {
52 Value source = sliceOp.getSource();
53 auto linalgOp = source.getDefiningOp<LinalgOp>();
54 if (!linalgOp) {
55 return rewriter.notifyMatchFailure(sliceOp,
56 "expected source to be linalg op");
57 }
58
59 // TODO: we might relax this if we want heuristics to detect that all uses
60 // are small portion of the output.
61 if (!linalgOp->hasOneUse()) {
62 return rewriter.notifyMatchFailure(sliceOp,
63 "expected single use of linalg op");
64 }
65
66 if (linalgOp.getNumDpsInits() != 1) {
67 return rewriter.notifyMatchFailure(sliceOp,
68 "expected single output of linalg op");
69 }
70
71 if (!linalgOp.hasPureTensorSemantics()) {
72 return rewriter.notifyMatchFailure(sliceOp,
73 "expected tensor of linalg op");
74 }
75
76 if (!sliceOp.hasUnitStride())
77 return rewriter.notifyMatchFailure(sliceOp, "expected unit stride");
78
79 if (sliceOp.getType().getRank() != sliceOp.getSourceType().getRank()) {
80 return rewriter.notifyMatchFailure(sliceOp, "expected no rank reduction");
81 }
82
83 OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
84 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(outOperand);
85 if (!indexingMap.isProjectedPermutation()) {
86 return rewriter.notifyMatchFailure(
87 sliceOp, "expected a projected permutation for output");
88 }
89
90 auto linalgLoc = linalgOp.getLoc();
91 SmallVector<OpFoldResult> allShapeSizes =
92 linalgOp.createFlatListOfOperandDims(rewriter, linalgLoc);
93 AffineMap shapeSizesToLoopsMap = linalgOp.getShapesToLoopsMap();
94 if (!shapeSizesToLoopsMap) {
95 return rewriter.notifyMatchFailure(
96 linalgOp, "failed to get loops map from shape sizes");
97 }
98 SmallVector<OpFoldResult> sizeBounds =
99 affine::makeComposedFoldedMultiResultAffineApply(
100 rewriter, linalgLoc, shapeSizesToLoopsMap, allShapeSizes);
101
102 // The offsets and sizes from the slice operation only give you the tile
103 // size of the output. Use that compute the tile sizes and offsets of the
104 // loops. For loops not used to access the output, set the tile sizes to
105 // loop bounds and set the offset to 0.
106 SmallVector<OpFoldResult> tileOffsets(sizeBounds.size(),
107 rewriter.getIndexAttr(0));
108 SmallVector<OpFoldResult> tileSizes = sizeBounds;
109 for (auto const &result : enumerate(indexingMap.getResults())) {
110 unsigned position = cast<AffineDimExpr>(result.value()).getPosition();
111 tileOffsets[position] = sliceOp.getMixedOffsets()[result.index()];
112 tileSizes[position] = sliceOp.getMixedSizes()[result.index()];
113 }
114
115 SmallVector<Value> valuesToTile = linalgOp->getOperands();
116 SmallVector<Value> tiledOperands =
117 makeTiledShapes(rewriter, linalgLoc, linalgOp, valuesToTile,
118 tileOffsets, tileSizes, sizeBounds,
119 /*omitPartialTileCheck=*/true);
120
121 SmallVector<Type, 4> resultTensorTypes;
122 for (OpOperand &opOperand : linalgOp.getDpsInitsMutable())
123 resultTensorTypes.push_back(
124 tiledOperands[opOperand.getOperandNumber()].getType());
125
126 Operation *newOp =
127 clone(rewriter, linalgOp, resultTensorTypes, tiledOperands);
128 rewriter.replaceOp(sliceOp, newOp->getResults());
129 return success();
130 }
131};
132} // namespace
133
134void mlir::linalg::populateBubbleUpExtractSliceOpPatterns(
135 RewritePatternSet &patterns) {
136 auto *context = patterns.getContext();
137 patterns.add<BubbleUpExtractSliceOpPattern>(arg&: context);
138}
139

source code of mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp