1//===- RankReductionPatterns.cpp - Patterns related to rank reductions ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Affine/IR/AffineOps.h"
10#include "mlir/Dialect/Arith/Utils/Utils.h"
11#include "mlir/Dialect/Tensor/IR/Tensor.h"
12#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
13#include "mlir/IR/PatternMatch.h"
14#include "mlir/Interfaces/ValueBoundsOpInterface.h"
15#include "llvm/ADT/STLExtras.h"
16#include "llvm/Support/LogicalResult.h"
17
18using namespace mlir;
19using namespace mlir::tensor;
20
21namespace {
22/// Fold expand_shape(extract_slice) ops that cancel itself out.
23struct FoldExpandOfRankReducingExtract
24 : public OpRewritePattern<ExpandShapeOp> {
25 using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
26
27 LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp,
28 PatternRewriter &rewriter) const override {
29 RankedTensorType resultType = expandShapeOp.getResultType();
30 auto extractSliceOp =
31 expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>();
32 if (!extractSliceOp)
33 return failure();
34 RankedTensorType srcType = extractSliceOp.getSourceType();
35
36 // Only cases where the ExpandShapeOp can be folded away entirely are
37 // supported. Moreover, only simple cases where the resulting ExtractSliceOp
38 // has no rank-reduction anymore are supported at the moment.
39 RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
40 sourceTensorType: srcType, staticOffsets: extractSliceOp.getStaticOffsets(),
41 staticSizes: extractSliceOp.getStaticSizes(), staticStrides: extractSliceOp.getStaticStrides());
42 if (nonReducingExtractType != resultType)
43 return failure();
44
45 SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
46 SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
47 SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
48 rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
49 op: expandShapeOp, args: extractSliceOp.getSource(), args&: mixedOffsets, args&: mixedSizes,
50 args&: mixedStrides);
51 return success();
52 }
53};
54
55/// Fold collapse_shape which only removes static dimensions of size `1`
56/// into extract_slice.
57struct FoldUnPaddingCollapseIntoExtract
58 : public OpRewritePattern<tensor::CollapseShapeOp> {
59 using OpRewritePattern<tensor::CollapseShapeOp>::OpRewritePattern;
60
61 LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseShapeOp,
62 PatternRewriter &rewriter) const override {
63 auto extractSliceOp =
64 collapseShapeOp.getSrc().getDefiningOp<tensor::ExtractSliceOp>();
65 // Collapse cannot be folded away with multiple users of the extract slice
66 // and it is not necessarily beneficial to only convert the collapse into
67 // another extract slice.
68 if (!extractSliceOp || !extractSliceOp->hasOneUse())
69 return failure();
70
71 // Only fold away simple collapse where all removed dimensions have static
72 // size `1`.
73 SliceVerificationResult res = isRankReducedType(
74 originalType: collapseShapeOp.getSrcType(), candidateReducedType: collapseShapeOp.getResultType());
75 if (res != SliceVerificationResult::Success)
76 return rewriter.notifyMatchFailure(arg&: collapseShapeOp,
77 msg: "expected unpadding collapse");
78
79 Value unPaddedExtractSlice = rewriter.create<tensor::ExtractSliceOp>(
80 location: extractSliceOp.getLoc(), args: collapseShapeOp.getResultType(),
81 args: extractSliceOp.getSource(), args: extractSliceOp.getMixedOffsets(),
82 args: extractSliceOp.getMixedSizes(), args: extractSliceOp.getMixedStrides());
83 rewriter.replaceOp(op: collapseShapeOp, newValues: unPaddedExtractSlice);
84 return success();
85 }
86};
87
88/// Fold insert_slice(collapse_shape) ops that cancel itself out.
89template <typename OpTy>
90struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> {
91 using OpRewritePattern<OpTy>::OpRewritePattern;
92
93 LogicalResult matchAndRewrite(OpTy insertSliceOp,
94 PatternRewriter &rewriter) const override {
95 auto collapseShapeOp =
96 insertSliceOp.getSource().template getDefiningOp<CollapseShapeOp>();
97 if (!collapseShapeOp)
98 return failure();
99 RankedTensorType srcType = collapseShapeOp.getSrcType();
100
101 // Only cases where the CollapseShapeOp can be folded away entirely are
102 // supported. Moreover, only simple cases where the resulting InsertSliceOp
103 // has no rank-reduction anymore are supported at the moment.
104 RankedTensorType nonReducingInsertType =
105 RankedTensorType::get(shape: insertSliceOp.getStaticSizes(),
106 elementType: insertSliceOp.getDestType().getElementType());
107 if (nonReducingInsertType != srcType)
108 return failure();
109
110 SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
111 SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
112 SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
113 rewriter.replaceOpWithNewOp<OpTy>(insertSliceOp, collapseShapeOp.getSrc(),
114 insertSliceOp.getDest(), mixedOffsets,
115 mixedSizes, mixedStrides);
116 return success();
117 }
118};
119
120/// Fold expand_shape which only adds static dimensions of size `1`
121/// into insert_slice.
122template <typename OpTy>
123struct FoldPaddingExpandIntoInsert : public OpRewritePattern<OpTy> {
124 using OpRewritePattern<OpTy>::OpRewritePattern;
125
126 LogicalResult matchAndRewrite(OpTy insertSliceOp,
127 PatternRewriter &rewriter) const override {
128 auto expandShapeOp = insertSliceOp.getSource()
129 .template getDefiningOp<tensor::ExpandShapeOp>();
130 if (!expandShapeOp)
131 return failure();
132
133 // Only fold away simple expansion where all added dimensions have static
134 // size `1`.
135 SliceVerificationResult res = isRankReducedType(
136 expandShapeOp.getResultType(), expandShapeOp.getSrcType());
137 if (res != SliceVerificationResult::Success)
138 return rewriter.notifyMatchFailure(insertSliceOp,
139 "expected rank increasing expansion");
140
141 rewriter.modifyOpInPlace(insertSliceOp, [&]() {
142 insertSliceOp.getSourceMutable().assign(expandShapeOp.getSrc());
143 });
144 return success();
145 }
146};
147
/// Pattern to bubble up a tensor.expand_shape op through a producer
/// tensor.collapse_shape op that has non intersecting reassociations.
struct BubbleUpExpandThroughParallelCollapse
    : public OpRewritePattern<tensor::ExpandShapeOp> {
  using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
                                PatternRewriter &rewriter) const override {
    auto collapseOp =
        expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
    if (!collapseOp)
      return failure();
    auto expandReInds = expandOp.getReassociationIndices();
    auto collapseReInds = collapseOp.getReassociationIndices();

    // Special case where the collapsed tensor to expand is a 0-D tensor,
    // then the reassociation maps will be empty and not produce valid results.
    if (expandReInds.size() == 0) {
      return failure();
    }

    // Reshapes are parallel to each other (by construction the number of
    // reassociations specified in the collapse and expand are the same), if at
    // any position
    // 1. either the reassociation indices are of the same size, or
    // 2. either the reassociation in the collapse or the expand is of size 1.
    ArrayRef<int64_t> staticSourceSize = collapseOp.getSrcType().getShape();
    ArrayRef<int64_t> staticResultSize = expandOp.getStaticOutputShape();
    for (auto [expandReassociation, collapseReassociation] :
         llvm::zip_equal(expandReInds, collapseReInds)) {
      if (collapseReassociation.size() == expandReassociation.size()) {
        // Even if the reassociations are the same, the collapse/expand should
        // result in the same dimensions. i.e 4x8x2 into 64 should be expanded
        // into 4x8x2 again. In presence of dynamic dimensions one can only
        // verify "equality" when there is only one dynamic dimension present,
        // and all other static dimensions are equal.
        ArrayRef<int64_t> collapsedStaticShapes = staticSourceSize.slice(
            collapseReassociation.front(), collapseReassociation.size());
        int64_t numCollapsedDynamic =
            llvm::count_if(collapsedStaticShapes, ShapedType::isDynamic);
        ArrayRef<int64_t> expandedStaticShapes = staticResultSize.slice(
            expandReassociation.front(), expandReassociation.size());
        int64_t numExpandedDynamic =
            llvm::count_if(expandedStaticShapes, ShapedType::isDynamic);
        if (numCollapsedDynamic > 1 || numExpandedDynamic > 1 ||
            collapsedStaticShapes != expandedStaticShapes) {
          return failure();
        }
        continue;
      }
      // If the reassociations are not same, one or the other needs to be of
      // size one.
      if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
        return failure();
    }

    // Compute new reassociation indices and expanded/collapsed shapes.
    SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
    Location loc = expandOp->getLoc();
    SmallVector<OpFoldResult> sourceSizes =
        tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
    SmallVector<OpFoldResult> resultSizes = expandOp.getMixedOutputShape();
    SmallVector<OpFoldResult> newExpandSizes;

    // Running counters: newExpandIndex/newCollapseIndex number the dims of the
    // swapped expand/collapse results; sourceSizeIndex/resultSizeIndex walk
    // the original source/result size lists in lockstep with the groups.
    int64_t newExpandIndex = 0, newCollapseIndex = 0, sourceSizeIndex = 0,
            resultSizeIndex = 0;

    for (size_t idx = 0, idxEnd = collapseReInds.size(); idx < idxEnd; idx++) {
      auto &collapseReassociation = collapseReInds[idx];
      auto &expandReassociation = expandReInds[idx];

      // Case 1. The reassociations are same in the collapse producer
      // and expand consumer. In the swapped expand, each of the final
      // dimensions are kept as is in the expand and the collapse. So,
      // for every element in the `ReassociationIndices` vector add a new
      // `ReassociationIndices` vector for the swapped expand and collapse
      // (of size 1).
      if (collapseReassociation.size() == expandReassociation.size()) {
        for (size_t i = 0; i < collapseReassociation.size(); ++i) {
          newCollapseReInds.push_back({newCollapseIndex++});
          newExpandReInds.push_back({newExpandIndex++});
          newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
          sourceSizeIndex++;
        }
        continue;
      }

      // Case 2. The `ReassociationIndices` in the collapse is of size > 1 (and
      // in the expand is of size == 1). In this case, the original dimensions
      // are preserved on expansion and collapsed subsequently.
      if (collapseReassociation.size() != 1) {
        ReassociationIndices newCollapseReassociation;
        for (size_t i = 0; i < collapseReassociation.size(); ++i) {
          newCollapseReassociation.push_back(newCollapseIndex++);
          newExpandReInds.push_back({newExpandIndex++});
          newExpandSizes.push_back(sourceSizes[sourceSizeIndex++]);
        }
        resultSizeIndex++;
        newCollapseReInds.push_back(newCollapseReassociation);
        continue;
      }

      // Case 3. The `ReassociationIndices` in the expand is of size > 1 (and
      // in the collapse is of size == 1). In this case, the expansion happens
      // first and the expanded dimensions are preserved on collapse.
      ReassociationIndices newExpandReassociation;
      for (size_t i = 0; i < expandReassociation.size(); ++i) {
        newExpandReassociation.push_back(newExpandIndex++);
        newCollapseReInds.push_back({newCollapseIndex++});
        newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
      }
      newExpandReInds.push_back(newExpandReassociation);
      sourceSizeIndex++;
    }

    // Swap reshape order: emit expand_shape first, then collapse_shape.
    SmallVector<Value> dynamicSizes;
    SmallVector<int64_t> staticSizes;
    dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
    auto expandResultType = expandOp.getResultType().clone(staticSizes);
    Value newCollapseSrc = collapseOp.getSrc();
    // If the number of reassociation indices in the new `expand_shape` op
    // matches the number of dimensions of the result, then the expand_shape
    // is a no-op.
    if (newExpandReInds.size() != newExpandSizes.size()) {
      newCollapseSrc = rewriter.create<tensor::ExpandShapeOp>(
          loc, expandResultType, newCollapseSrc, newExpandReInds,
          newExpandSizes);
    }

    // If the number of reassociation indices in the new `collapse_shape` op
    // matches the number of dimensions of the source, then the collapse_shape
    // is a no-op.
    Value replacement = newCollapseSrc;
    if (newCollapseReInds.size() != newExpandSizes.size()) {
      replacement = rewriter.create<tensor::CollapseShapeOp>(
          loc, newCollapseSrc, newCollapseReInds);
    }
    rewriter.replaceOp(expandOp, replacement);
    return success();
  }
};
290
/// Converts `tensor.extract_slice(tensor.expand_shape)` to
/// `tensor.expand_shape(tensor.extract_slice)`.
///
/// For this transformation to be possible, the slice must be fully contiguous
/// within each reassociation group of the expand_shape. A slice is defined as
/// fully contiguous within a reassociation group if after flattening the
/// reassociation group to a single 1D range, then the slice taken out of the
/// group could be defined as a single contiguous subrange within that range.
///
/// Rank reducing slices are not supported.
///
/// Example:
/// The transformation is possible because each reassociation group has a
/// contiguous slice (i.e., [2x4->2x4], [2x8->1x5], [4x2x4->1x1x4]).
/// ```
/// BEFORE:
/// %reshape = tensor.expand_shape %in [[0, 1], [2, 3], [4, 5, 6]]
///     tensor<8x16x32xf32> to tensor<2x4x2x8x4x2x4xf32>
/// %slice = tensor.extract_slice %reshape ...
///     tensor<2x4x2x8x4x2x4xf32> to tensor<2x4x1x5x1x1x4xf32>
///
/// AFTER:
/// %slice = tensor.extract_slice %in ...
///     tensor<8x16x32xf32> to tensor<8x5x4xf32>
/// %reshape = tensor.expand_shape %slice [[0, 1], [2, 3], [4, 5, 6]]
///     tensor<8x5x4xf32> to tensor<2x4x1x5x1x1x4xf32>
/// ```
///
/// Note - this pattern could be extended to be a swap pattern between
/// `tensor.expand_shape` and `tensor.extract_slice`, but is currently
/// implemented only as a bubble up pattern for `tensor.extract_slice`.
struct BubbleUpExpandShapeThroughExtractSlice
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    auto expandShapeOp =
        sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();

    // Bail out early unless all preconditions (producer is an expand_shape,
    // unit strides, no rank reduction, contiguous slice per group) hold.
    if (checkPreconditionForBubbleUpExtractSlice(sliceOp, expandShapeOp,
                                                 rewriter)
            .failed())
      return failure();

    // The tensor.extract_slice before applying the pattern works on the result
    // of the tensor.expand_shape, so variables (i.e. inputs for ExtractSliceOp)
    // referring to the state before applying the pattern are named with the
    // prefix "expanded", and ones referring to the state after applying the
    // pattern are named with the prefix "collapsed".
    SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
    SmallVector<OpFoldResult> expandedShape =
        getMixedValues(expandShapeOp.getStaticOutputShape(),
                       expandShapeOp.getOutputShape(), rewriter);

    // Helper variables and function for accumulating the size values.
    Location loc = expandShapeOp->getLoc();
    AffineExpr d0, d1, d2;
    bindDims(rewriter.getContext(), d0, d1, d2);
    // Multiply two integers (folded through the affine machinery so constant
    // operands stay attributes).
    auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
      auto mulMap = AffineMap::get(2, 0, {d0 * d1});
      return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
                                                   {v1, v2});
    };

    // Compute new offsets, sizes, and strides for tensor.extract_slice.
    // The new tensor.extract_slice will work on a tensor that has a rank of
    // ReassociationIndices.size(). In the loop a single offset, size, and
    // stride value is computed per reassociation group.
    SmallVector<OpFoldResult> collapsedOffsets, collapsedSizes,
        collapsedStrides;
    for (const ReassociationIndices &indices :
         expandShapeOp.getReassociationIndices()) {
      // collapsedSize will hold the size of the single dim that represents the
      // reassociation group in the non expanded tensor.
      OpFoldResult collapsedSize = rewriter.getIndexAttr(1);
      // The reassocGroupSizes and reassocGroupOffsets are used to create an
      // affine.linearize_index op to linearize the single offset value required
      // for this reassociation group.
      SmallVector<OpFoldResult> reassocGroupSizes, reassocGroupOffsets;

      for (long expandedDim : indices) {
        // reassocGroupSizes and reassocGroupOffsets can be obtained directly
        // from the expanded state, but the collapsed size requires calculation
        // as it did not previously exist.
        reassocGroupSizes.push_back(expandedShape[expandedDim]);
        reassocGroupOffsets.push_back(expandedOffsets[expandedDim]);
        collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]);
      }

      // Linearize the per-dim offsets of the group into a single offset in the
      // collapsed dim.
      SmallVector<Value> offsetVals =
          llvm::map_to_vector(reassocGroupOffsets, [&](OpFoldResult ofr) {
            return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
          });
      OpFoldResult collapsedOffset =
          rewriter
              .create<affine::AffineLinearizeIndexOp>(loc, offsetVals,
                                                      reassocGroupSizes,
                                                      /*disjoint=*/true)
              .getResult();
      collapsedOffsets.push_back(collapsedOffset);
      collapsedSizes.push_back(collapsedSize);

      // Only unit stride is supported.
      collapsedStrides.push_back(rewriter.getIndexAttr(1));
    }

    // The shape of the result can be obtained from the sizes passed in.
    SmallVector<Value> dynDims;
    SmallVector<int64_t> shape;
    dispatchIndexOpFoldResults(expandedSizes, dynDims, shape);
    RankedTensorType resultType = RankedTensorType::get(
        shape, expandShapeOp.getResultType().getElementType());

    // Create a new ExtractSliceOp and ExpandShapeOp.
    Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
        loc, expandShapeOp.getSrc(), collapsedOffsets, collapsedSizes,
        collapsedStrides);
    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        sliceOp, resultType, newSliceOp,
        expandShapeOp.getReassociationIndices(), expandedSizes);
    return success();
  }

  // Helper function to check if all the required conditions for the
  // tensor.extract_slice to be bubbled up through the tensor.expand_shape are
  // met.
  LogicalResult
  checkPreconditionForBubbleUpExtractSlice(tensor::ExtractSliceOp sliceOp,
                                           tensor::ExpandShapeOp expandShapeOp,
                                           PatternRewriter &rewriter) const {

    if (!expandShapeOp) {
      return rewriter.notifyMatchFailure(
          sliceOp, "tensor.extract_slice source not produced by expand_shape");
    }

    if (!sliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(
          sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
                   "be supported in this transformation.");
    }

    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();

    // Rank-reducing slices are rejected: the result rank must equal the
    // number of slice sizes.
    if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
        sizes.size()) {
      return rewriter.notifyMatchFailure(sliceOp,
                                         "unimplemented: rank reducing slice");
    }

    SmallVector<OpFoldResult> outputShape =
        getMixedValues(expandShapeOp.getStaticOutputShape(),
                       expandShapeOp.getOutputShape(), rewriter);

    // True iff the slice starts at 0 and provably covers the whole dim
    // (equality checked via the value-bounds analysis, so dynamic sizes are
    // handled when they can be proven equal).
    std::function<bool(OpFoldResult, OpFoldResult, OpFoldResult)>
        isZeroOffsetAndFullSize =
            [](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) {
              if (!isZeroInteger(offset))
                return false;
              FailureOr<bool> maybeEqual =
                  ValueBoundsConstraintSet::areEqual(sliceSize, size);
              return llvm::succeeded(maybeEqual) && maybeEqual.value();
            };

    // Check that the slice is contiguous within each reassociation group.
    // The slice is contiguous only if after the first dimension where a non
    // unit slice is taken, the slice size on all subsequent dimensions of the
    // group is equal to the entire size of the dimension.
    // Examples of contiguous slices:
    //   full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10]
    //   full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10]
    // Examples of non contiguous slices:
    //   full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5]
    //   full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5]
    for (const ReassociationIndices &indices :
         expandShapeOp.getReassociationIndices()) {
      int64_t i = 0;
      int64_t e = indices.size();
      // Find the first expanded dim after the first dim with non-unit extracted
      // size.
      for (; i < e; ++i) {
        if (!isOneInteger(sizes[indices[i]])) {
          // +1 to skip the first non-unit size dim.
          i++;
          break;
        }
      }

      // Verify that all subsequent dimensions extract the full size of the
      // source tensor.
      for (; i < e; ++i) {
        int64_t expandedDim = indices[i];
        if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
                                     outputShape[expandedDim])) {
          return rewriter.notifyMatchFailure(
              sliceOp, "Not a contiguous slice of the expanded tensor.");
        }
      }
    }

    return success();
  }
};
498
/// Converts `tensor.extract_slice(tensor.collapse_shape)` to
/// `tensor.collapse_shape(tensor.extract_slice)`.
///
/// For this transformation to be possible - after bubbling up, the extraction
/// of the contiguous slice must be representable as a single slice obtained via
/// tensor.extract_slice within each reassociation group of the src.
///
/// In case the size and offset extracted are static then this is possible if
/// the following conditions are met within each reassociation group:
/// Let T be a tensor of shape [A0, A1, ..., An] (these are the sizes of the
/// dimensions in the reassociation group), and let S = [S0, S1, ..., Sn] be the
/// shape of a desired slice. A slice of shape S can be extracted as a
/// contiguous span of elements if and only if there exists an index k in {0, 1,
/// ..., n} such that:
///      S_i = 1 for all i < k (that is, all leading dimensions are singleton),
///      1 <= S_k <= A_k (that is, non trivial slicing occurs along exactly
///                       one dimension),
///      S_i = A_i for all i > k (that is, all trailing dimensions are preserved
///                               in full).
/// In other words, the slice shape S must be of the form:
/// [ 1, 1, ..., 1, Sk, Ak + 1, Ak + 2, ..., An ]
///
/// In case the size and/or offset extracted are dynamic then this is possible
/// only if there is single dimension in the reassociation group that has a size
/// not equal to 1.
/// In other words, the tensor shape must be of the form:
/// [ 1, 1, ..., 1, A, 1, ..., 1 ]
/// Note - it might be possible to enable this pattern for more cases when the
/// size/offset are dynamic via performing an analysis of the possible values
/// that could be given to the size/offset.
///
/// Example:
/// The transformation is possible because each reassociation group can be
/// represented as a contiguous slice (i.e., [8x16->2x16], [1x7->1x?],
/// [20->10]).
/// ```
/// BEFORE:
/// %collapse = tensor.collapse_shape %src [[0, 1], [2, 3], [4]] ...
///     tensor<8x16x1x7x20xf32> to tensor<128x7x20xf32>
/// %slice = tensor.extract_slice %collapse [0, 0, 0][32, %size, 10][1, 1, 1]
///     tensor<128x7x20xf32> to tensor<32x?x10xf32>
///
/// AFTER:
/// %slice = tensor.extract_slice %src [0, 0, 0, 0, 0][2, 16, 1, %size, 10]
///     [1, 1, 1, 1, 1] : tensor<8x16x1x7x20xf32> to tensor<2x16x1x?x10xf32>
/// %collapse = tensor.collapse_shape %slice [[0, 1], [2, 3], [4]] ...
///     tensor<2x16x1x?x10xf32> to tensor<32x?x10xf32>
/// ```
///
/// Negative example:
/// The transformation is not possible because we cannot use a single slice to
/// represent the reassociation group [2x3x10->???]. If we would want the
/// collapse to be after the extraction, we would need to extract multiple
/// slices and concat them together.
/// ```
/// %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<2x3x10xf32> into
/// tensor<60xf32> %extract = tensor.extract_slice %collapse[0][15][1] :
///     tensor<60xf32> to tensor<15xf32>
/// ```
/// If we would want the collapse to be after the extraction, a possible
/// alternate transformation could be to extract multiple slices and concat them
/// together:
/// ```
/// %extract_1 = tensor.extract_slice %src[0, 0, 0][1, 1, 10] :
///     tensor<2x3x10xf32> to tensor <1x1x10xf32>
/// %extract_2 = tensor.extract_slice %src[0, 1, 0][1, 1, 5] :
///     tensor<2x3x10xf32> to tensor <1x1x5xf32>
/// %concat = tosa.concat %extract_1, %extract_2 {axis = 0 : i32} :
///     (<1x1x10xf32>, <1x1x5xf32>) -> <1x1x15xf32>
/// %collapse = tensor.collapse_shape %concat [[0, 1, 2]] : tensor<1x1x15xf32>
///     to tensor<15xf32>
/// ```
/// But this is not the intended purpose of the transformation.
struct BubbleUpCollapseShapeThroughExtractSlice
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    auto collapseShapeOp =
        sliceOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
    if (!collapseShapeOp) {
      return rewriter.notifyMatchFailure(
          sliceOp,
          "tensor.extract_slice source not produced by tensor.collapse_shape");
    }

    if (!sliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(
          sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
                   "be supported in this transformation.");
    }

    // The tensor.extract_slice before applying the pattern works on the result
    // of the tensor.collapse_shape, so variables (i.e. inputs for
    // ExtractSliceOp) referring to the state before applying the pattern are
    // named with the prefix "collapsed", and ones referring to the state after
    // applying the pattern are named with the prefix "expanded".
    SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes();

    // Rank-reducing slices are rejected.
    if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
        collapsedSizes.size()) {
      return rewriter.notifyMatchFailure(sliceOp,
                                         "unimplemented: rank reducing slice");
    }

    ArrayRef<int64_t> srcShape = collapseShapeOp.getSrcType().getShape();
    SmallVector<ReassociationIndices, 4> reassociationIndices =
        collapseShapeOp.getReassociationIndices();

    // Compute new offsets, sizes, and strides for tensor.extract_slice.
    // The new tensor.extract_slice will work on a tensor that has a rank
    // equal to the rank of the src of the collapse_shape. In each iteration of
    // the loop, the offsets and sizes will be computed per reassociation group.
    SmallVector<OpFoldResult> expandedOffsets, expandedSizes;
    SmallVector<OpFoldResult> expandedStrides(srcShape.size(),
                                              rewriter.getIndexAttr(1));

    for (auto [collapsedSize, collapsedOffset, reassocIndices] :
         llvm::zip_equal(collapsedSizes, collapsedOffsets,
                         collapseShapeOp.getReassociationIndices())) {
      // CASE #1 - size and/or offset are dynamic.
      // In this case, the slice can be represented as a contiguous slice only
      // if there is a single dimension in the reassociation group that has a
      // size not equal to 1.
      if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
        int nonUnitSizeCount = 0;
        for (int64_t expandedShapeIdx : reassocIndices) {
          if (srcShape[expandedShapeIdx] != 1) {
            // The single non-unit dim absorbs the whole dynamic size/offset.
            nonUnitSizeCount++;
            expandedSizes.push_back(collapsedSize);
            expandedOffsets.push_back(collapsedOffset);
            continue;
          }

          expandedSizes.push_back(rewriter.getIndexAttr(1));
          expandedOffsets.push_back(rewriter.getIndexAttr(0));
        }

        if (nonUnitSizeCount != 1) {
          return rewriter.notifyMatchFailure(
              sliceOp,
              "unsupported: slice cannot be verified to be contiguous");
        }
        continue;
      }

      // CASE #2 - size and offset are static.
      // Verify that the slice can be represented as a contiguous slice of the
      // src of the collapse_shape.
      // Checking this is done on order of most internal dimensions first,
      // so traversal is done in reverse order of the reassociation group.
      // If the expected slice shape is [1, 1, ..., 1, Sk, Ak + 1, Ak + 2,
      // ..., An] then we first find the size and offset for n...k+1 then for k
      // and then for k-1...0.

      // currentCollapsedsize and currentCollapsedOffset are initialized with
      // the original collapsed size and offset and divided by the expanded
      // shape size in each dimension as we go along the reassociation group.
      // In essence we are spreading the original collapsed size and offset over
      // the various expanded slice dimensions.
      // The variables are used both to check the validity of the slice and to
      // compute the expanded sizes and offsets.
      int64_t currentCollapsedsize = getConstantIntValue(collapsedSize).value();
      int64_t currentCollapsedOffset =
          getConstantIntValue(collapsedOffset).value();

      SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;

      ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(),
                                                  reassocIndices.rend());
      int64_t idx = 0;
      int64_t reassocGroupSize = reassocIndices.size();

      // First handle the trailing dimensions where the slice size should be
      // equal to the tensor shape and the offset should be 0 (n...k+1).
      for (; idx < reassocGroupSize; ++idx) {
        int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];

        if (currentCollapsedsize < expandedShapeSize)
          break;

        // We need to make sure that the slice size can be set to the shape size
        // and the offset to 0.
        if ((currentCollapsedsize % expandedShapeSize) != 0 ||
            (currentCollapsedOffset % expandedShapeSize) != 0) {
          return rewriter.notifyMatchFailure(
              sliceOp, "unsupported: cannot be extracted as a contiguous slice "
                       "of the src of the collapse_shape");
        }

        groupExpandedSizes.push_back(rewriter.getIndexAttr(expandedShapeSize));
        groupExpandedOffsets.push_back(rewriter.getIndexAttr(0));

        currentCollapsedsize /= expandedShapeSize;
        currentCollapsedOffset /= expandedShapeSize;
      }

      // Now handle the first dim where slicing occurs on (k).
      if (idx < reassocGroupSize) {
        int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
        int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
        // We need to make sure that the slice size in this dim + offset will
        // not exceed the shape size.
        // NOTE(review): `>=` also rejects slices that end exactly at the dim
        // boundary (size + offset == dim size). That is conservative (only
        // rejects, never miscompiles) — confirm whether `>` was intended.
        if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) {
          return rewriter.notifyMatchFailure(
              sliceOp, "unsupported: slice cannot be extracted as a contiguous "
                       "slice of the src of the collapse_shape");
        }

        groupExpandedSizes.push_back(
            rewriter.getIndexAttr(currentCollapsedsize));
        groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));

        currentCollapsedOffset /= expandedShapeSize;
      }

      // Now handle the leading dimensions where the slice size is equal to 1
      // (k-1...0).
      // The size for these dimensions must be 1 because of how we constructed
      // the slice size of the expanded shape. We spread the original collapsed
      // size over the expanded shape sizes until we reached dimension k where
      // the remaining size was smaller than the expanded shape size, and spread
      // the remaining size on it. So, now we are left with only 1s.
      for (idx++; idx < reassocGroupSize; ++idx) {
        int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
        int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
        groupExpandedSizes.push_back(rewriter.getIndexAttr(1));
        groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));
        currentCollapsedOffset /= expandedShapeSize;
      }

      // The group was computed innermost-first; append in source dim order.
      expandedSizes.append(groupExpandedSizes.rbegin(),
                           groupExpandedSizes.rend());
      expandedOffsets.append(groupExpandedOffsets.rbegin(),
                             groupExpandedOffsets.rend());
    }

    Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
        collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), expandedOffsets,
        expandedSizes, expandedStrides);
    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
        sliceOp, sliceOp.getResultType(), newSliceOp,
        collapseShapeOp.getReassociationIndices());

    return success();
  }
};
748
749} // namespace
750
751void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
752 RewritePatternSet &patterns) {
753 patterns
754 .add<FoldExpandOfRankReducingExtract, FoldUnPaddingCollapseIntoExtract,
755 FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
756 FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>,
757 FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>,
758 FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
759 arg: patterns.getContext());
760}
761
762void mlir::tensor::populateBubbleUpExpandShapePatterns(
763 RewritePatternSet &patterns) {
764 patterns.add<BubbleUpExpandThroughParallelCollapse>(arg: patterns.getContext());
765}
766
767void mlir::tensor::populateBubbleUpExtractSliceOpPatterns(
768 RewritePatternSet &patterns) {
769 patterns.add<BubbleUpExpandShapeThroughExtractSlice,
770 BubbleUpCollapseShapeThroughExtractSlice>(arg: patterns.getContext());
771}
772

// Source file: mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp