1//===- MergeConsecutiveInsertExtractSlicePatterns.cpp ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
10#include "mlir/Dialect/Tensor/IR/Tensor.h"
11#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
12#include "mlir/Dialect/Tensor/Utils/Utils.h"
13#include "mlir/IR/BuiltinTypes.h"
14#include "mlir/IR/OpDefinition.h"
15#include "mlir/IR/PatternMatch.h"
16
17using namespace mlir;
18using namespace mlir::tensor;
19
20namespace {
21/// Merges consecutive tensor.extract_slice ops into one.
22// TODO: move to FoldTensorSubsetOps and unify APIs with FoldMemRefAliasOps.
23struct MergeConsecutiveExtractSlice : public OpRewritePattern<ExtractSliceOp> {
24 using OpRewritePattern::OpRewritePattern;
25
26 LogicalResult matchAndRewrite(ExtractSliceOp nextOp,
27 PatternRewriter &rewriter) const override {
28 auto prevOp = nextOp.getSource().getDefiningOp<ExtractSliceOp>();
29 if (!prevOp)
30 return failure();
31
32 SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
33 if (failed(affine::mergeOffsetsSizesAndStrides(
34 rewriter, nextOp.getLoc(), prevOp, nextOp, prevOp.getDroppedDims(),
35 newOffsets, newSizes, newStrides)))
36 return failure();
37
38 rewriter.replaceOpWithNewOp<ExtractSliceOp>(nextOp, nextOp.getType(),
39 prevOp.getSource(), newOffsets,
40 newSizes, newStrides);
41 return success();
42 }
43};
44
45/// Merges consecutive tensor.insert_slice ops into one.
46// TODO: move to FoldTensorSubsetOps and unify APIs with FoldMemRefAliasOps.
47template <typename OpTy>
48struct MergeConsecutiveInsertSlice : public OpRewritePattern<OpTy> {
49 using OpRewritePattern<OpTy>::OpRewritePattern;
50
51 LogicalResult matchAndRewrite(OpTy nextOp,
52 PatternRewriter &rewriter) const override {
53 auto prevOp = nextOp.getSource().template getDefiningOp<InsertSliceOp>();
54 if (!prevOp)
55 return failure();
56
57 if (!prevOp.hasUnitStride() || !nextOp.hasUnitStride())
58 return failure();
59
60 // The first insert_slice op should be rank reducing to make sure we cover
61 // the full source tensor to be inserted in the second insert_slice op.
62 SliceVerificationResult result =
63 isRankReducedType(prevOp.getDestType(), prevOp.getSourceType());
64 if (result != SliceVerificationResult::Success)
65 return failure();
66
67 // Dynamic dimensions can pass rank reducing check in the above, e.g,
68 // inserting <?xf32> into <1x?x1xf32>. For such cases we cannot be certain
69 // the dynamic size covers the full tensor.
70 if (!prevOp.getSourceType().hasStaticShape() ||
71 !prevOp.getDestType().hasStaticShape())
72 return failure();
73
74 rewriter.replaceOpWithNewOp<OpTy>(
75 nextOp, prevOp.getSource(), nextOp.getDest(), nextOp.getMixedOffsets(),
76 nextOp.getMixedSizes(), nextOp.getMixedStrides());
77 return success();
78 }
79};
80
81/// Drop redundant rank expansion of insert_slice that are directly followed
82/// by extract_slice. E.g.:
83/// %0 = tensor.insert_slice ... : tensor<5x10xf32> into tensor<1x1x5x10xf32>
84/// %1 = tensor.extract_slice %0[0, 0, 2, 3] [1, 1, 2, 2] [1, 1, 1, 1]
85/// : tensor<1x1x5x10xf32> to tensor<2x2xf32>
86struct DropRedundantRankExpansionOnExtractSliceOfInsertSlice
87 : public OpRewritePattern<ExtractSliceOp> {
88 using OpRewritePattern::OpRewritePattern;
89
90 LogicalResult matchAndRewrite(ExtractSliceOp extractSliceOp,
91 PatternRewriter &rewriter) const override {
92 // Nothing to do if no dims are dropped.
93 llvm::SmallBitVector droppedDims = extractSliceOp.getDroppedDims();
94 if (droppedDims.none())
95 return failure();
96
97 // Look for tensor.insert_slice op that has an inverse rank expansion.
98 auto insertSliceOp =
99 extractSliceOp.getSource().getDefiningOp<InsertSliceOp>();
100 if (!insertSliceOp)
101 return failure();
102 llvm::SmallBitVector expandedDims = insertSliceOp.getDroppedDims();
103
104 // TODO: This could be extended to support cases where the dropped dims are
105 // a subset of the expanded dims.
106 if (expandedDims != droppedDims)
107 return failure();
108
109 // The tensor.insert_slice may not be redundant if it has multiple users.
110 if (!insertSliceOp->hasOneUse())
111 return failure();
112
113 // Only consider tensor.insert_slice ops that are pure rank-reductions.
114 // I.e., no elements are taken from the destination.
115 if (!isCastLikeInsertSliceOp(insertSliceOp))
116 return failure();
117
118 // Extract directly from the source.
119 OpBuilder::InsertionGuard g(rewriter);
120 rewriter.setInsertionPoint(extractSliceOp);
121 SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
122 for (int64_t i = 0, e = extractSliceOp.getSourceType().getRank(); i < e;
123 ++i) {
124 if (droppedDims.test(Idx: i))
125 continue;
126 newOffsets.push_back(Elt: extractSliceOp.getMixedOffsets()[i]);
127 newSizes.push_back(Elt: extractSliceOp.getMixedSizes()[i]);
128 newStrides.push_back(Elt: extractSliceOp.getMixedStrides()[i]);
129 }
130 rewriter.replaceOpWithNewOp<ExtractSliceOp>(
131 extractSliceOp, /*source=*/insertSliceOp.getSource(), newOffsets,
132 newSizes, newStrides);
133 rewriter.eraseOp(op: insertSliceOp);
134 return success();
135 }
136};
137
138/// Drop redundant rank expansion of insert_slice that direclty follows
139/// extract_slice.
140///
141/// This can be done when the insert_slice op purely expands ranks (adds unit
142/// dims) and the extrace_slice drops corresponding unit dims. For example:
143///
144/// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1]
145/// : tensor<2x8xf32> to tensor<8xf32>
146/// %inserted_slice = tensor.insert_slice %extracted_slice
147/// into %dest[0, 0] [1, 8] [1, 1]
148/// : tensor<8xf32> into tensor<1x8xf32>
149///
150/// can be folded into:
151///
152/// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1]
153/// : tensor<2x8xf32> to tensor<1x8xf32>
154struct DropRedundantRankExpansionOnInsertSliceOfExtractSlice final
155 : public OpRewritePattern<tensor::InsertSliceOp> {
156 using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
157
158 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
159 PatternRewriter &rewriter) const override {
160 auto extractSliceOp =
161 insertSliceOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
162 if (!extractSliceOp) {
163 return rewriter.notifyMatchFailure(insertSliceOp,
164 "source is not extract_slice");
165 }
166
167 // Can't fold if the extract_slice op has other users.
168 if (!extractSliceOp->hasOneUse()) {
169 return rewriter.notifyMatchFailure(insertSliceOp,
170 "source has multi-uses");
171 }
172
173 // Check if the insert_slice op purely expands ranks (add unit dims).
174 if (!isCastLikeInsertSliceOp(insertSliceOp)) {
175 return rewriter.notifyMatchFailure(insertSliceOp,
176 "insert_slice is not cast-like");
177 }
178
179 llvm::SmallBitVector extractDroppedDims = extractSliceOp.getDroppedDims();
180 llvm::SmallBitVector insertDroppedDims = insertSliceOp.getDroppedDims();
181 // Can't fold if the insert_slice op expands to more dims.
182 if (extractDroppedDims.size() < insertDroppedDims.size()) {
183 return rewriter.notifyMatchFailure(insertSliceOp,
184 "insert_slice expands more dims");
185 }
186
187 // Try to match the extract dropped dims to the insert dropped dims. This is
188 // done by scanning the dims of extract_slice and find the left-most one can
189 // match the dim of insert_slice. If a match is found, advance the dim of
190 // insert_slice to match the next one.
191 unsigned insertDimPos = 0;
192 for (unsigned extractDimPos = 0; extractDimPos < extractDroppedDims.size();
193 ++extractDimPos) {
194 // Matched all dims.
195 if (insertDimPos == insertDroppedDims.size())
196 break;
197
198 bool isExtractDropped = extractDroppedDims[extractDimPos];
199 bool isInsertDropped = insertDroppedDims[insertDimPos];
200 // Match if both sides drop/keep the dim. Advance and match the next dim
201 // of insert_slice.
202 if (isExtractDropped == isInsertDropped) {
203 insertDimPos += 1;
204 } else if (!isExtractDropped && isInsertDropped) {
205 // Not enough extract dropped dims to match the insert dropped dims.
206 return rewriter.notifyMatchFailure(insertSliceOp,
207 "insert_slice drops more unit dims");
208 }
209 // If the dim is dropped by extract_slice and not by insert_slice, look
210 // the next dim of extract_slice to see if it can match the current dim of
211 // insert_slice.
212 }
213 // Can't match some insert dims.
214 if (insertDimPos != insertDroppedDims.size()) {
215 return rewriter.notifyMatchFailure(insertSliceOp,
216 "insert_slice has unmatched dims");
217 }
218
219 rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
220 insertSliceOp, insertSliceOp.getType(), extractSliceOp.getSource(),
221 extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(),
222 extractSliceOp.getMixedStrides());
223 rewriter.eraseOp(op: extractSliceOp);
224
225 return success();
226 }
227};
228} // namespace
229
230void mlir::tensor::populateMergeConsecutiveInsertExtractSlicePatterns(
231 RewritePatternSet &patterns) {
232 patterns.add<MergeConsecutiveExtractSlice,
233 MergeConsecutiveInsertSlice<InsertSliceOp>,
234 MergeConsecutiveInsertSlice<ParallelInsertSliceOp>>(
235 patterns.getContext());
236}
237
238void mlir::tensor::populateDropRedundantInsertSliceRankExpansionPatterns(
239 RewritePatternSet &patterns) {
240 patterns.add<DropRedundantRankExpansionOnExtractSliceOfInsertSlice,
241 DropRedundantRankExpansionOnInsertSliceOfExtractSlice>(
242 arg: patterns.getContext());
243}
244

source code of mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp