1//===- RankReductionPatterns.cpp - Patterns related to rank reductions ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Affine/IR/AffineOps.h"
10#include "mlir/Dialect/Arith/Utils/Utils.h"
11#include "mlir/Dialect/Tensor/IR/Tensor.h"
12#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
13#include "mlir/IR/PatternMatch.h"
14#include "mlir/Interfaces/ValueBoundsOpInterface.h"
15#include "llvm/ADT/STLExtras.h"
16#include "llvm/Support/Debug.h"
17#include "llvm/Support/LogicalResult.h"
18
19using namespace mlir;
20using namespace mlir::tensor;
21
22namespace {
23/// Fold expand_shape(extract_slice) ops that cancel itself out.
24struct FoldExpandOfRankReducingExtract
25 : public OpRewritePattern<ExpandShapeOp> {
26 using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
27
28 LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp,
29 PatternRewriter &rewriter) const override {
30 RankedTensorType resultType = expandShapeOp.getResultType();
31 auto extractSliceOp =
32 expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>();
33 if (!extractSliceOp)
34 return failure();
35 RankedTensorType srcType = extractSliceOp.getSourceType();
36
37 // Only cases where the ExpandShapeOp can be folded away entirely are
38 // supported. Moreover, only simple cases where the resulting ExtractSliceOp
39 // has no rank-reduction anymore are supported at the moment.
40 RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
41 srcType, extractSliceOp.getStaticOffsets(),
42 extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
43 if (nonReducingExtractType != resultType)
44 return failure();
45
46 SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
47 SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
48 SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
49 rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
50 expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes,
51 mixedStrides);
52 return success();
53 }
54};
55
56/// Fold collapse_shape which only removes static dimensions of size `1`
57/// into extract_slice.
58struct FoldUnPaddingCollapseIntoExtract
59 : public OpRewritePattern<tensor::CollapseShapeOp> {
60 using OpRewritePattern<tensor::CollapseShapeOp>::OpRewritePattern;
61
62 LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseShapeOp,
63 PatternRewriter &rewriter) const override {
64 auto extractSliceOp =
65 collapseShapeOp.getSrc().getDefiningOp<tensor::ExtractSliceOp>();
66 // Collapse cannot be folded away with multiple users of the extract slice
67 // and it is not necessarily beneficial to only convert the collapse into
68 // another extract slice.
69 if (!extractSliceOp || !extractSliceOp->hasOneUse())
70 return failure();
71
72 // Only fold away simple collapse where all removed dimensions have static
73 // size `1`.
74 SliceVerificationResult res = isRankReducedType(
75 collapseShapeOp.getSrcType(), collapseShapeOp.getResultType());
76 if (res != SliceVerificationResult::Success)
77 return rewriter.notifyMatchFailure(collapseShapeOp,
78 "expected unpadding collapse");
79
80 Value unPaddedExtractSlice = rewriter.create<tensor::ExtractSliceOp>(
81 extractSliceOp.getLoc(), collapseShapeOp.getResultType(),
82 extractSliceOp.getSource(), extractSliceOp.getMixedOffsets(),
83 extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
84 rewriter.replaceOp(collapseShapeOp, unPaddedExtractSlice);
85 return success();
86 }
87};
88
89/// Fold insert_slice(collapse_shape) ops that cancel itself out.
90template <typename OpTy>
91struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> {
92 using OpRewritePattern<OpTy>::OpRewritePattern;
93
94 LogicalResult matchAndRewrite(OpTy insertSliceOp,
95 PatternRewriter &rewriter) const override {
96 auto collapseShapeOp =
97 insertSliceOp.getSource().template getDefiningOp<CollapseShapeOp>();
98 if (!collapseShapeOp)
99 return failure();
100 RankedTensorType srcType = collapseShapeOp.getSrcType();
101
102 // Only cases where the CollapseShapeOp can be folded away entirely are
103 // supported. Moreover, only simple cases where the resulting InsertSliceOp
104 // has no rank-reduction anymore are supported at the moment.
105 RankedTensorType nonReducingInsertType =
106 RankedTensorType::get(insertSliceOp.getStaticSizes(),
107 insertSliceOp.getDestType().getElementType());
108 if (nonReducingInsertType != srcType)
109 return failure();
110
111 SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
112 SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
113 SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
114 rewriter.replaceOpWithNewOp<OpTy>(insertSliceOp, collapseShapeOp.getSrc(),
115 insertSliceOp.getDest(), mixedOffsets,
116 mixedSizes, mixedStrides);
117 return success();
118 }
119};
120
121/// Fold expand_shape which only adds static dimensions of size `1`
122/// into insert_slice.
123template <typename OpTy>
124struct FoldPaddingExpandIntoInsert : public OpRewritePattern<OpTy> {
125 using OpRewritePattern<OpTy>::OpRewritePattern;
126
127 LogicalResult matchAndRewrite(OpTy insertSliceOp,
128 PatternRewriter &rewriter) const override {
129 auto expandShapeOp = insertSliceOp.getSource()
130 .template getDefiningOp<tensor::ExpandShapeOp>();
131 if (!expandShapeOp)
132 return failure();
133
134 // Only fold away simple expansion where all added dimensions have static
135 // size `1`.
136 SliceVerificationResult res = isRankReducedType(
137 expandShapeOp.getResultType(), expandShapeOp.getSrcType());
138 if (res != SliceVerificationResult::Success)
139 return rewriter.notifyMatchFailure(insertSliceOp,
140 "expected rank increasing expansion");
141
142 rewriter.modifyOpInPlace(insertSliceOp, [&]() {
143 insertSliceOp.getSourceMutable().assign(expandShapeOp.getSrc());
144 });
145 return success();
146 }
147};
148
149/// Pattern to bubble up a tensor.expand_shape op through a producer
150/// tensor.collapse_shape op that has non intersecting reassociations.
151struct BubbleUpExpandThroughParallelCollapse
152 : public OpRewritePattern<tensor::ExpandShapeOp> {
153 using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;
154
155 LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
156 PatternRewriter &rewriter) const override {
157 auto collapseOp =
158 expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
159 if (!collapseOp)
160 return failure();
161 auto expandReInds = expandOp.getReassociationIndices();
162 auto collapseReInds = collapseOp.getReassociationIndices();
163
164 // Special case where the collapsed tensor to expand is a 0-D tensor,
165 // then the reassociation maps will be empty and not produce valid results.
166 if (expandReInds.size() == 0) {
167 return failure();
168 }
169
170 // Reshapes are parallel to each other (by construction the number of
171 // reassociations specified in the collapse and expand are the same), if at
172 // any position
173 // 1. either the reassociation indices are of the same size, or
174 // 2. either the reassociation in the collapse or the expand is of size 1.
175 ArrayRef<int64_t> staticSourceSize = collapseOp.getSrcType().getShape();
176 ArrayRef<int64_t> staticResultSize = expandOp.getStaticOutputShape();
177 for (auto [expandReassociation, collapseReassociation] :
178 llvm::zip_equal(expandReInds, collapseReInds)) {
179 if (collapseReassociation.size() == expandReassociation.size()) {
180 // Even if the reassociations are the same, the collapse/expand should
181 // result in the same dimensions. i.e 4x8x2 into 64 should be expanded
182 // into 4x8x2 again. In presense of dynamic dimensions one can only
183 // verify "equality" when there is only one dynamic dimension present,
184 // and all other static dimensions are equal.
185 ArrayRef<int64_t> collapsedStaticShapes = staticSourceSize.slice(
186 collapseReassociation.front(), collapseReassociation.size());
187 int64_t numCollapsedDynamic =
188 llvm::count_if(collapsedStaticShapes, ShapedType::isDynamic);
189 ArrayRef<int64_t> expandedStaticShapes = staticResultSize.slice(
190 expandReassociation.front(), expandReassociation.size());
191 int64_t numExpandedDynamic =
192 llvm::count_if(expandedStaticShapes, ShapedType::isDynamic);
193 if (numCollapsedDynamic > 1 || numExpandedDynamic > 1 ||
194 collapsedStaticShapes != expandedStaticShapes) {
195 return failure();
196 }
197 continue;
198 }
199 // If the reassociations are not same, one or the other needs to be of
200 // size one.
201 if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
202 return failure();
203 }
204
205 // Compute new reassociation indices and expanded/collaped shapes.
206 SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
207 Location loc = expandOp->getLoc();
208 SmallVector<OpFoldResult> sourceSizes =
209 tensor::getMixedSizes(builder&: rewriter, loc, value: collapseOp.getSrc());
210 SmallVector<OpFoldResult> resultSizes = expandOp.getMixedOutputShape();
211 SmallVector<OpFoldResult> newExpandSizes;
212
213 int64_t newExpandIndex = 0, newCollapseIndex = 0, sourceSizeIndex = 0,
214 resultSizeIndex = 0;
215
216 for (size_t idx = 0, idxEnd = collapseReInds.size(); idx < idxEnd; idx++) {
217 auto &collapseReassociation = collapseReInds[idx];
218 auto &expandReassociation = expandReInds[idx];
219
220 // Case 1. The reassociations are same in the collapse producer
221 // and expand consumer. In the swapped expand, each of the final
222 // dimensions are kept as is in the expand and the collapse. So,
223 // for every element in the `ReassocationIndices` vector add a new
224 // `ReassociationIndices` vector for the swapped expand and collapse
225 // (of size 1).
226 if (collapseReassociation.size() == expandReassociation.size()) {
227 for (size_t i = 0; i < collapseReassociation.size(); ++i) {
228 newCollapseReInds.push_back(Elt: {newCollapseIndex++});
229 newExpandReInds.push_back(Elt: {newExpandIndex++});
230 newExpandSizes.push_back(Elt: resultSizes[resultSizeIndex++]);
231 sourceSizeIndex++;
232 }
233 continue;
234 }
235
236 // Case 2. The `ReassociationIndices` in the collapse is of size > 1 (and
237 // in the expand is of size == 1). In this case, the original dimensions
238 // are preserved on expansion and collapsed subsequently.
239 if (collapseReassociation.size() != 1) {
240 ReassociationIndices newCollapseReassociation;
241 for (size_t i = 0; i < collapseReassociation.size(); ++i) {
242 newCollapseReassociation.push_back(Elt: newCollapseIndex++);
243 newExpandReInds.push_back(Elt: {newExpandIndex++});
244 newExpandSizes.push_back(Elt: sourceSizes[sourceSizeIndex++]);
245 }
246 resultSizeIndex++;
247 newCollapseReInds.push_back(Elt: newCollapseReassociation);
248 continue;
249 }
250
251 // Case 3. The `ReassociationIndices` in the expand is of size > 1 (and
252 // in the collapse is of size == 1). In this case, the expansion happens
253 // first and the expanded dimensions are preserved on collapse.
254 ReassociationIndices newExpandReassociation;
255 for (size_t i = 0; i < expandReassociation.size(); ++i) {
256 newExpandReassociation.push_back(Elt: newExpandIndex++);
257 newCollapseReInds.push_back(Elt: {newCollapseIndex++});
258 newExpandSizes.push_back(Elt: resultSizes[resultSizeIndex++]);
259 }
260 newExpandReInds.push_back(Elt: newExpandReassociation);
261 sourceSizeIndex++;
262 }
263
264 // Swap reshape order.
265 SmallVector<Value> dynamicSizes;
266 SmallVector<int64_t> staticSizes;
267 dispatchIndexOpFoldResults(ofrs: newExpandSizes, dynamicVec&: dynamicSizes, staticVec&: staticSizes);
268 auto expandResultType = expandOp.getResultType().clone(staticSizes);
269 Value newCollapseSrc = collapseOp.getSrc();
270 // If the number of reassociation indices in the new `expand_shape` op
271 // matches the number of dimensions of the result, then the expand_shape
272 // is a no-op.
273 if (newExpandReInds.size() != newExpandSizes.size()) {
274 newCollapseSrc = rewriter.create<tensor::ExpandShapeOp>(
275 loc, expandResultType, newCollapseSrc, newExpandReInds,
276 newExpandSizes);
277 }
278
279 // If the number of reassociation indices in the new `collapse_shape` op
280 // matches the number of dimensions of the source, then the collapse_shape
281 // is a no-op.
282 Value replacement = newCollapseSrc;
283 if (newCollapseReInds.size() != newExpandSizes.size()) {
284 replacement = rewriter.create<tensor::CollapseShapeOp>(
285 loc, newCollapseSrc, newCollapseReInds);
286 }
287 rewriter.replaceOp(expandOp, replacement);
288 return success();
289 }
290};
291
292/// Converts `tensor.extract_slice(tensor.expand_shape)` to
293/// `tensor.expand_shape(tensor.extract_slice)`.
294///
295/// For this transformation to be possible, the slice must be fully contiguous
296/// within each reassociation group of the expand_shape. A slice is defined as
297/// fully contiguous within a reassociation group if after flattening the
298/// reassociation group to a single 1D range, then the slice taken out of the
299/// group could be defined as a single contiguous subrange within that range.
300///
301/// Rank reducing slices are not supported.
302///
303/// Example:
304/// The transformation is possible because each reassociation group has a
305/// contiguous slice (i.e., [2x4->2x4], [2x8->1x5], [4x2x4->1x1x4]).
306/// ```
307/// BEFORE:
308/// %reshape = tensor.expand_shape %in [[0, 1], [2, 3], [4, 5, 6]]
309/// tensor<8x16x32xf32> to tensor<2x4x2x8x4x2x4xf32>
310/// %slice = tensor.extract_slice %reshape ...
311/// tensor<2x4x2x8x4x2x4xf32> to tensor<2x4x1x5x1x1x4xf32>
312///
313/// AFTER:
314/// %slice = tensor.extract_slice %in ...
315/// tensor<8x16x32xf32> to tensor<8x5x4xf32>
316/// %reshape = tensor.expand_shape %slice [[0, 1], [2, 3], [4, 5, 6]]
317/// tensor<8x5x4xf32> to tensor<2x4x1x5x1x1x4xf32>
318/// ```
319///
320/// Note - this pattern could be extended to be a swap pattern between
321/// `tensor.expand_shape` and `tensor.extract_slice`, but is currently
322/// implemented only as a bubble up pattern for `tensor.extract_slice`.
323struct BubbleUpExpandShapeThroughExtractSlice
324 : public OpRewritePattern<tensor::ExtractSliceOp> {
325 using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
326
327 LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
328 PatternRewriter &rewriter) const override {
329 auto expandShapeOp =
330 sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
331
332 if (checkPreconditionForBubbleUpExtractSlice(sliceOp, expandShapeOp,
333 rewriter)
334 .failed())
335 return failure();
336
337 // The tensor.extract_slice before applying the pattern works on the result
338 // of the tensor.expand_shape, so variables (i.e. inputs for ExtractSliceOp)
339 // referring to the state before applying the pattern are named with the
340 // prefix "expanded", and ones referring to the state after applying the
341 // pattern are named with the prefix "collapsed".
342 SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets();
343 SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
344 SmallVector<OpFoldResult> expandedShape =
345 getMixedValues(expandShapeOp.getStaticOutputShape(),
346 expandShapeOp.getOutputShape(), rewriter);
347
348 // Helper variables and function for accumulating the size values.
349 Location loc = expandShapeOp->getLoc();
350 AffineExpr d0, d1, d2;
351 bindDims(ctx: rewriter.getContext(), exprs&: d0, exprs&: d1, exprs&: d2);
352 // Multiply two integers.
353 auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
354 auto mulMap = AffineMap::get(dimCount: 2, symbolCount: 0, result: {d0 * d1});
355 return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
356 {v1, v2});
357 };
358
359 // Compute new offsets, sizes, and strides for tensor.extract_slice.
360 // The new tensor.extract_slice will work on a tensor that has has a rank of
361 // ReassociationIndices.size(). In the loop a single offset, size, and
362 // stride value is computed per reassociation group.
363 SmallVector<OpFoldResult> collapsedOffsets, collapsedSizes,
364 collapsedStrides;
365 for (const ReassociationIndices &indices :
366 expandShapeOp.getReassociationIndices()) {
367 // collapsedSize will hold the size of the single dim that represents the
368 // reassociation group in the non expanded tensor.
369 OpFoldResult collapsedSize = rewriter.getIndexAttr(1);
370 // The reassocGroupSizes and reassocGroupOffsets are used to create an
371 // affine.linearize_index op to linearize the single offset value required
372 // for this reassociation group.
373 SmallVector<OpFoldResult> reassocGroupSizes, reassocGroupOffsets;
374
375 for (long expandedDim : indices) {
376 // reassocGroupSizes and reassocGroupOffsets can be obtained directly
377 // from the expanded state, but the collapsed size requires calculation
378 // as it did not previously exist.
379 reassocGroupSizes.push_back(expandedShape[expandedDim]);
380 reassocGroupOffsets.push_back(expandedOffsets[expandedDim]);
381 collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]);
382 }
383
384 SmallVector<Value> offsetVals =
385 llvm::map_to_vector(reassocGroupOffsets, [&](OpFoldResult ofr) {
386 return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
387 });
388 OpFoldResult collapsedOffset =
389 rewriter
390 .create<affine::AffineLinearizeIndexOp>(loc, offsetVals,
391 reassocGroupSizes,
392 /*disjoint=*/true)
393 .getResult();
394 collapsedOffsets.push_back(collapsedOffset);
395 collapsedSizes.push_back(collapsedSize);
396
397 // Only unit stride is supported.
398 collapsedStrides.push_back(rewriter.getIndexAttr(1));
399 }
400
401 // The shape of the result can be obtained from the sizes passed in.
402 SmallVector<Value> dynDims;
403 SmallVector<int64_t> shape;
404 dispatchIndexOpFoldResults(ofrs: expandedSizes, dynamicVec&: dynDims, staticVec&: shape);
405 RankedTensorType resultType = RankedTensorType::get(
406 shape, expandShapeOp.getResultType().getElementType());
407
408 // Create a new ExtractSliceOp and ExpandShapeOp.
409 Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
410 loc, expandShapeOp.getSrc(), collapsedOffsets, collapsedSizes,
411 collapsedStrides);
412 rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
413 sliceOp, resultType, newSliceOp,
414 expandShapeOp.getReassociationIndices(), expandedSizes);
415 return success();
416 }
417
418 // Helper function to check if all the required conditions for the
419 // tensor.extract_slice to be bubbled up through the tensor.expand_shape are
420 // met.
421 LogicalResult
422 checkPreconditionForBubbleUpExtractSlice(tensor::ExtractSliceOp sliceOp,
423 tensor::ExpandShapeOp expandShapeOp,
424 PatternRewriter &rewriter) const {
425
426 if (!expandShapeOp) {
427 return rewriter.notifyMatchFailure(
428 sliceOp, "tensor.extract_slice source not produced by expand_shape");
429 }
430
431 if (!sliceOp.hasUnitStride()) {
432 return rewriter.notifyMatchFailure(
433 sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
434 "be supported in this transformation.");
435 }
436
437 SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
438 SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
439
440 if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
441 sizes.size()) {
442 return rewriter.notifyMatchFailure(sliceOp,
443 "unimplemented: rank reducing slice");
444 }
445
446 SmallVector<OpFoldResult> outputShape =
447 getMixedValues(expandShapeOp.getStaticOutputShape(),
448 expandShapeOp.getOutputShape(), rewriter);
449
450 std::function<bool(OpFoldResult, OpFoldResult, OpFoldResult)>
451 isZeroOffsetAndFullSize =
452 [](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) {
453 if (!isZeroInteger(v: offset))
454 return false;
455 FailureOr<bool> maybeEqual =
456 ValueBoundsConstraintSet::areEqual(var1: sliceSize, var2: size);
457 return llvm::succeeded(Result: maybeEqual) && maybeEqual.value();
458 };
459
460 // Check that the slice is contiguous within each reassociation group.
461 // The slice is contiguous only if after the first dimension where a non
462 // unit slice is taken, the slice size on all subsequent dimensions of the
463 // group is equal to the entire size of the dimension.
464 // Examples of contiguous slices:
465 // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10]
466 // full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10]
467 // Examples of non contiguous slices:
468 // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5]
469 // full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5]
470 for (const ReassociationIndices &indices :
471 expandShapeOp.getReassociationIndices()) {
472 int64_t i = 0;
473 int64_t e = indices.size();
474 // Find the first expanded dim after the first dim with non-unit extracted
475 // size.
476 for (; i < e; ++i) {
477 if (!isOneInteger(sizes[indices[i]])) {
478 // +1 to skip the first non-unit size dim.
479 i++;
480 break;
481 }
482 }
483
484 // Verify that all subsequent dimensions extract the full size of the
485 // source tensor.
486 for (; i < e; ++i) {
487 int64_t expandedDim = indices[i];
488 if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
489 outputShape[expandedDim])) {
490 return rewriter.notifyMatchFailure(
491 sliceOp, "Not a contiguous slice of the expanded tensor.");
492 }
493 }
494 }
495
496 return success();
497 }
498};
499
/// Converts `tensor.extract_slice(tensor.collapse_shape)` to
/// `tensor.collapse_shape(tensor.extract_slice)`.
///
/// For this transformation to be possible - after bubbling up, the extraction
/// of the contiguous slice must be representable as a single slice obtained via
/// tensor.extract_slice within each reassociation group of the src.
///
/// In case the size and offset extracted are static then this is possible if
/// the following conditions are met within each reassociation group:
/// Let T be a tensor of shape [A0, A1, ..., An] (these are the sizes of the
/// dimensions in the reassociation group), and let S = [S0, S1, ..., Sn] be the
/// shape of a desired slice. A slice of shape S can be extracted as a
/// contiguous span of elements if and only if there exists an index k in {0, 1,
/// ..., n} such that:
///       S_i = 1 for all i < k (that is, all leading dimensions are singleton),
///       1 <= S_k <= A_k (that is, non trivial slicing occurs along exactly
///                        one dimension),
///       S_i = A_i for all i > k (that is, all trailing dimensions are preserved
///       in full).
/// In other words, the slice shape S must be of the form:
/// [ 1, 1, ..., 1, Sk, Ak + 1, Ak + 2, ...,An ]
///
/// In case the size and/or offset extracted are dynamic then this is possible
/// only if there is single dimension in the reassociation group that has a size
/// not equal to 1.
/// In other words, the tensor shape must be of the form:
/// [ 1, 1, ..., 1, A, 1, ...,1 ]
/// Note - it might be possible to enable this pattern for more cases when the
/// size/offset are dynamic via performing an analysis of the possible values
/// that could be given to the size/offset.
///
/// Example:
/// The transformation is possible because each reassociation group can be
/// represented as a contiguous slice (i.e., [8x16->2x16], [1x7->1x?],
/// [20->10]).
/// ```
/// BEFORE:
/// %collapse = tensor.collapse_shape %src [[0, 1], [2, 3], [4]] ...
///     tensor<8x16x1x7x20xf32> to tensor<128x7x20xf32>
/// %slice = tensor.extract_slice %slice [0, 0, 0][32, %size, 10][1, 1, 1]
///     tensor<128x7x20xf32> to tensor<32x?x10xf32>
///
/// AFTER:
/// %slice = tensor.extract_slice %src [0, 0, 0, 0, 0][2, 16, 1, %size, 10]
///     [1, 1, 1, 1, 1] : tensor<8x16x1x7x20xf32> to tensor<2x16x1x?x10xf32>
/// %collapse = tensor.collapse_shape %slice [[0, 1], [2, 3], [4]] ...
///     tensor<2x16x1x?x10xf32> to tensor<32x?x10xf32>
/// ```
///
/// Negative example:
/// The transformation is not possible because we cannot use a single slice to
/// represent the reassociation group [2x3x10->???]. If we would want the
/// collapse to be after the extraction, we would need to extract multiple
/// slices and concat them together.
/// ```
/// %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<2x3x10xf32> into
/// tensor<60xf32> %extract = tensor.extract_slice %collapse[0][15][1] :
///                           tensor<60xf32> to tensor<15xf32>
/// ```
/// If we would want the collapse to be after the extraction, a possible
/// alternate transformation could be to extract multiple slices and concat them
/// together:
/// ```
/// %extract_1 = tensor.extract_slice %src[0, 0, 0][1, 1, 10] :
///              tensor<2x3x10xf32> to tensor <1x1x10xf32>
/// %extract_2 = tensor.extract_slice %src[0, 1, 0][1, 1, 5] :
///              tensor<2x3x10xf32> to tensor <1x1x5xf32>
/// %concat = tosa.concat %extract_1, %extract_2 {axis = 0 : i32} :
///           (<1x1x10xf32>, <1x1x5xf32>) -> <1x1x15xf32>
/// %collapse = tensor.collapse_shape %concat [[0, 1, 2]] : tensor<1x1x15xf32>
///             to tensor<15xf32>
/// ```
/// But this is not the intended purpose of the transformation.
struct BubbleUpCollapseShapeThroughExtractSlice
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    // The pattern only applies when the slice's source is produced by a
    // collapse_shape.
    auto collapseShapeOp =
        sliceOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
    if (!collapseShapeOp) {
      return rewriter.notifyMatchFailure(
          sliceOp,
          "tensor.extract_slice source not produced by tensor.collapse_shape");
    }

    // Non-unit strides would break the contiguity reasoning below.
    if (!sliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(
          sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
                   "be supported in this transformation.");
    }

    // The tensor.extract_slice before applying the pattern works on the result
    // of the tensor.collapse_shape, so variables (i.e. inputs for
    // ExtractSliceOp) referring to the state before applying the pattern are
    // named with the prefix "collapsed", and ones referring to the state after
    // applying the pattern are named with the prefix "expanded".
    SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes();

    // Rank-reducing slices drop dimensions, which the index bookkeeping below
    // does not handle.
    if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
        collapsedSizes.size()) {
      return rewriter.notifyMatchFailure(sliceOp,
                                         "unimplemented: rank reducing slice");
    }

    ArrayRef<int64_t> srcShape = collapseShapeOp.getSrcType().getShape();
    SmallVector<ReassociationIndices, 4> reassociationIndices =
        collapseShapeOp.getReassociationIndices();

    // Compute new offsets, sizes, and strides for tensor.extract_slice.
    // The new tensor.extract_slice will work on a tensor that has a rank
    // equal to the rank of the src of the collapse_shape. In each iteration of
    // the loop, the offsets and sizes will be computed per reassociation group.
    SmallVector<OpFoldResult> expandedOffsets, expandedSizes;
    // Strides are all 1 (checked above via hasUnitStride).
    SmallVector<OpFoldResult> expandedStrides(srcShape.size(),
                                              rewriter.getIndexAttr(1));

    for (auto [collapsedSize, collapsedOffset, reassocIndices] :
         llvm::zip_equal(collapsedSizes, collapsedOffsets,
                         collapseShapeOp.getReassociationIndices())) {
      // CASE #1 - size and/or offset are dynamic.
      // In this case, the slice can be represented as a contiguous slice only
      // if there is a single dimension in the reassociation group that has a
      // size not equal to 1.
      if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
        int nonUnitSizeCount = 0;
        for (int64_t expandedShapeIdx : reassocIndices) {
          if (srcShape[expandedShapeIdx] != 1) {
            // The single non-unit dim absorbs the whole dynamic size/offset.
            nonUnitSizeCount++;
            expandedSizes.push_back(collapsedSize);
            expandedOffsets.push_back(collapsedOffset);
            continue;
          }

          // Unit dims are extracted in full at offset 0.
          expandedSizes.push_back(rewriter.getIndexAttr(1));
          expandedOffsets.push_back(rewriter.getIndexAttr(0));
        }

        if (nonUnitSizeCount != 1) {
          return rewriter.notifyMatchFailure(
              sliceOp,
              "unsupported: slice cannot be verified to be contiguous");
        }
        continue;
      }

      // CASE #2 = size and offset are static.
      // Verify that the slice can be represented as a contiguous slice of the
      // src of the collapse_shape.
      // Checking this is done on order of most internal dimensions first,
      // so traversal is done in reverse order of the reassociation group.
      // If the expected slice shape is [1, 1, ..., 1, Sk, Ak + 1, Ak + 2,
      // ...,An] then we first find the size and offset for n...k+1 then for k
      // and then for k-1...0.

      // currentCollapsedsize and currentCollapsedOffset are initialized with
      // the original collapsed size and offset and divided by the expanded
      // shape size in each dimension as we go along the reassociation group.
      // In essence we are spreading the original collapsed size and offset over
      // the various expanded slice dimensions.
      // The variables are used both to check the validity of the slice and to
      // compute the expanded sizes and offsets.
      int64_t currentCollapsedsize = getConstantIntValue(collapsedSize).value();
      int64_t currentCollapsedOffset =
          getConstantIntValue(collapsedOffset).value();

      SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;

      ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(),
                                                  reassocIndices.rend());
      int64_t idx = 0;
      int64_t reassocGroupSize = reassocIndices.size();

      // First handle the trailing dimensions where the slice size should be
      // equal to the tensor shape and the offset should be 0 (n...k+1).
      for (; idx < reassocGroupSize; ++idx) {
        int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];

        if (currentCollapsedsize < expandedShapeSize)
          break;

        // We need to make sure that the slice size can be set to the shape size
        // and the offset to 0.
        if ((currentCollapsedsize % expandedShapeSize) != 0 ||
            (currentCollapsedOffset % expandedShapeSize) != 0) {
          return rewriter.notifyMatchFailure(
              sliceOp, "unsupported: cannot be extracted as a contiguous slice "
                       "of the src of the collapse_shape");
        }

        groupExpandedSizes.push_back(rewriter.getIndexAttr(expandedShapeSize));
        groupExpandedOffsets.push_back(rewriter.getIndexAttr(0));

        currentCollapsedsize /= expandedShapeSize;
        currentCollapsedOffset /= expandedShapeSize;
      }

      // Now handle the first dim where slicing occurs on (k).
      if (idx < reassocGroupSize) {
        int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
        int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
        // We need to make sure that the slice size in this dim + offset will
        // not exceed the shape size.
        // NOTE(review): this rejects the boundary case where
        // currentCollapsedsize + offsetInDim == expandedShapeSize, i.e. a
        // slice that exactly reaches the end of the dimension. That looks
        // conservative (a missed fold, not a miscompile) — confirm whether
        // `>` was intended.
        if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) {
          return rewriter.notifyMatchFailure(
              sliceOp, "unsupported: slice cannot be extracted as a contiguous "
                       "slice of the src of the collapse_shape");
        }

        groupExpandedSizes.push_back(
            rewriter.getIndexAttr(currentCollapsedsize));
        groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));

        currentCollapsedOffset /= expandedShapeSize;
      }

      // Now handle the leading dimensions where the slice size is equal to 1
      // (k-1...0).
      // The size for these dimensions must be 1 because of how we constructed
      // the slice size of the expanded shape. We spread the original collapsed
      // size over the expanded shape sizes until we reached dimension k where
      // the remaining size was smaller than the expanded shape size, and spread
      // the remaining size on it. So, now we are left with only 1s.
      for (idx++; idx < reassocGroupSize; ++idx) {
        int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
        int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
        groupExpandedSizes.push_back(rewriter.getIndexAttr(1));
        groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));
        currentCollapsedOffset /= expandedShapeSize;
      }

      // The group was traversed innermost-first, so reverse before appending
      // to restore source dimension order.
      expandedSizes.append(groupExpandedSizes.rbegin(),
                           groupExpandedSizes.rend());
      expandedOffsets.append(groupExpandedOffsets.rbegin(),
                             groupExpandedOffsets.rend());
    }

    // Emit the bubbled-up slice on the collapse's source, then collapse the
    // sliced result back to the original result type.
    Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
        collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), expandedOffsets,
        expandedSizes, expandedStrides);
    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
        sliceOp, sliceOp.getResultType(), newSliceOp,
        collapseShapeOp.getReassociationIndices());

    return success();
  }
};
749
750} // namespace
751
752void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
753 RewritePatternSet &patterns) {
754 patterns
755 .add<FoldExpandOfRankReducingExtract, FoldUnPaddingCollapseIntoExtract,
756 FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
757 FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>,
758 FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>,
759 FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
760 patterns.getContext());
761}
762
763void mlir::tensor::populateBubbleUpExpandShapePatterns(
764 RewritePatternSet &patterns) {
765 patterns.add<BubbleUpExpandThroughParallelCollapse>(arg: patterns.getContext());
766}
767
768void mlir::tensor::populateBubbleUpExtractSliceOpPatterns(
769 RewritePatternSet &patterns) {
770 patterns.add<BubbleUpExpandShapeThroughExtractSlice,
771 BubbleUpCollapseShapeThroughExtractSlice>(arg: patterns.getContext());
772}
773

// Original source: mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
// (code-browser boilerplate removed — the trailing text was not valid C++).