//===- RankReductionPatterns.cpp - Patterns related to rank reductions ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 | |
9 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
10 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
11 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
12 | #include "mlir/Dialect/Tensor/Transforms/Transforms.h" |
13 | #include "mlir/IR/PatternMatch.h" |
14 | #include "mlir/Interfaces/ValueBoundsOpInterface.h" |
15 | #include "llvm/ADT/STLExtras.h" |
16 | #include "llvm/Support/Debug.h" |
17 | #include "llvm/Support/LogicalResult.h" |
18 | |
19 | using namespace mlir; |
20 | using namespace mlir::tensor; |
21 | |
22 | namespace { |
23 | /// Fold expand_shape(extract_slice) ops that cancel itself out. |
24 | struct FoldExpandOfRankReducingExtract |
25 | : public OpRewritePattern<ExpandShapeOp> { |
26 | using OpRewritePattern<ExpandShapeOp>::OpRewritePattern; |
27 | |
28 | LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp, |
29 | PatternRewriter &rewriter) const override { |
30 | RankedTensorType resultType = expandShapeOp.getResultType(); |
31 | auto extractSliceOp = |
32 | expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>(); |
33 | if (!extractSliceOp) |
34 | return failure(); |
35 | RankedTensorType srcType = extractSliceOp.getSourceType(); |
36 | |
37 | // Only cases where the ExpandShapeOp can be folded away entirely are |
38 | // supported. Moreover, only simple cases where the resulting ExtractSliceOp |
39 | // has no rank-reduction anymore are supported at the moment. |
40 | RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType( |
41 | srcType, extractSliceOp.getStaticOffsets(), |
42 | extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides()); |
43 | if (nonReducingExtractType != resultType) |
44 | return failure(); |
45 | |
46 | SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets(); |
47 | SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes(); |
48 | SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides(); |
49 | rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>( |
50 | expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes, |
51 | mixedStrides); |
52 | return success(); |
53 | } |
54 | }; |
55 | |
56 | /// Fold collapse_shape which only removes static dimensions of size `1` |
57 | /// into extract_slice. |
58 | struct FoldUnPaddingCollapseIntoExtract |
59 | : public OpRewritePattern<tensor::CollapseShapeOp> { |
60 | using OpRewritePattern<tensor::CollapseShapeOp>::OpRewritePattern; |
61 | |
62 | LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseShapeOp, |
63 | PatternRewriter &rewriter) const override { |
64 | auto extractSliceOp = |
65 | collapseShapeOp.getSrc().getDefiningOp<tensor::ExtractSliceOp>(); |
66 | // Collapse cannot be folded away with multiple users of the extract slice |
67 | // and it is not necessarily beneficial to only convert the collapse into |
68 | // another extract slice. |
69 | if (!extractSliceOp || !extractSliceOp->hasOneUse()) |
70 | return failure(); |
71 | |
72 | // Only fold away simple collapse where all removed dimensions have static |
73 | // size `1`. |
74 | SliceVerificationResult res = isRankReducedType( |
75 | collapseShapeOp.getSrcType(), collapseShapeOp.getResultType()); |
76 | if (res != SliceVerificationResult::Success) |
77 | return rewriter.notifyMatchFailure(collapseShapeOp, |
78 | "expected unpadding collapse"); |
79 | |
80 | Value unPaddedExtractSlice = rewriter.create<tensor::ExtractSliceOp>( |
81 | extractSliceOp.getLoc(), collapseShapeOp.getResultType(), |
82 | extractSliceOp.getSource(), extractSliceOp.getMixedOffsets(), |
83 | extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides()); |
84 | rewriter.replaceOp(collapseShapeOp, unPaddedExtractSlice); |
85 | return success(); |
86 | } |
87 | }; |
88 | |
89 | /// Fold insert_slice(collapse_shape) ops that cancel itself out. |
90 | template <typename OpTy> |
91 | struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> { |
92 | using OpRewritePattern<OpTy>::OpRewritePattern; |
93 | |
94 | LogicalResult matchAndRewrite(OpTy insertSliceOp, |
95 | PatternRewriter &rewriter) const override { |
96 | auto collapseShapeOp = |
97 | insertSliceOp.getSource().template getDefiningOp<CollapseShapeOp>(); |
98 | if (!collapseShapeOp) |
99 | return failure(); |
100 | RankedTensorType srcType = collapseShapeOp.getSrcType(); |
101 | |
102 | // Only cases where the CollapseShapeOp can be folded away entirely are |
103 | // supported. Moreover, only simple cases where the resulting InsertSliceOp |
104 | // has no rank-reduction anymore are supported at the moment. |
105 | RankedTensorType nonReducingInsertType = |
106 | RankedTensorType::get(insertSliceOp.getStaticSizes(), |
107 | insertSliceOp.getDestType().getElementType()); |
108 | if (nonReducingInsertType != srcType) |
109 | return failure(); |
110 | |
111 | SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets(); |
112 | SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes(); |
113 | SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides(); |
114 | rewriter.replaceOpWithNewOp<OpTy>(insertSliceOp, collapseShapeOp.getSrc(), |
115 | insertSliceOp.getDest(), mixedOffsets, |
116 | mixedSizes, mixedStrides); |
117 | return success(); |
118 | } |
119 | }; |
120 | |
121 | /// Fold expand_shape which only adds static dimensions of size `1` |
122 | /// into insert_slice. |
123 | template <typename OpTy> |
124 | struct FoldPaddingExpandIntoInsert : public OpRewritePattern<OpTy> { |
125 | using OpRewritePattern<OpTy>::OpRewritePattern; |
126 | |
127 | LogicalResult matchAndRewrite(OpTy insertSliceOp, |
128 | PatternRewriter &rewriter) const override { |
129 | auto expandShapeOp = insertSliceOp.getSource() |
130 | .template getDefiningOp<tensor::ExpandShapeOp>(); |
131 | if (!expandShapeOp) |
132 | return failure(); |
133 | |
134 | // Only fold away simple expansion where all added dimensions have static |
135 | // size `1`. |
136 | SliceVerificationResult res = isRankReducedType( |
137 | expandShapeOp.getResultType(), expandShapeOp.getSrcType()); |
138 | if (res != SliceVerificationResult::Success) |
139 | return rewriter.notifyMatchFailure(insertSliceOp, |
140 | "expected rank increasing expansion"); |
141 | |
142 | rewriter.modifyOpInPlace(insertSliceOp, [&]() { |
143 | insertSliceOp.getSourceMutable().assign(expandShapeOp.getSrc()); |
144 | }); |
145 | return success(); |
146 | } |
147 | }; |
148 | |
149 | /// Pattern to bubble up a tensor.expand_shape op through a producer |
150 | /// tensor.collapse_shape op that has non intersecting reassociations. |
151 | struct BubbleUpExpandThroughParallelCollapse |
152 | : public OpRewritePattern<tensor::ExpandShapeOp> { |
153 | using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern; |
154 | |
155 | LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp, |
156 | PatternRewriter &rewriter) const override { |
157 | auto collapseOp = |
158 | expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>(); |
159 | if (!collapseOp) |
160 | return failure(); |
161 | auto expandReInds = expandOp.getReassociationIndices(); |
162 | auto collapseReInds = collapseOp.getReassociationIndices(); |
163 | |
164 | // Special case where the collapsed tensor to expand is a 0-D tensor, |
165 | // then the reassociation maps will be empty and not produce valid results. |
166 | if (expandReInds.size() == 0) { |
167 | return failure(); |
168 | } |
169 | |
170 | // Reshapes are parallel to each other (by construction the number of |
171 | // reassociations specified in the collapse and expand are the same), if at |
172 | // any position |
173 | // 1. either the reassociation indices are of the same size, or |
174 | // 2. either the reassociation in the collapse or the expand is of size 1. |
175 | ArrayRef<int64_t> staticSourceSize = collapseOp.getSrcType().getShape(); |
176 | ArrayRef<int64_t> staticResultSize = expandOp.getStaticOutputShape(); |
177 | for (auto [expandReassociation, collapseReassociation] : |
178 | llvm::zip_equal(expandReInds, collapseReInds)) { |
179 | if (collapseReassociation.size() == expandReassociation.size()) { |
180 | // Even if the reassociations are the same, the collapse/expand should |
181 | // result in the same dimensions. i.e 4x8x2 into 64 should be expanded |
182 | // into 4x8x2 again. In presense of dynamic dimensions one can only |
183 | // verify "equality" when there is only one dynamic dimension present, |
184 | // and all other static dimensions are equal. |
185 | ArrayRef<int64_t> collapsedStaticShapes = staticSourceSize.slice( |
186 | collapseReassociation.front(), collapseReassociation.size()); |
187 | int64_t numCollapsedDynamic = |
188 | llvm::count_if(collapsedStaticShapes, ShapedType::isDynamic); |
189 | ArrayRef<int64_t> expandedStaticShapes = staticResultSize.slice( |
190 | expandReassociation.front(), expandReassociation.size()); |
191 | int64_t numExpandedDynamic = |
192 | llvm::count_if(expandedStaticShapes, ShapedType::isDynamic); |
193 | if (numCollapsedDynamic > 1 || numExpandedDynamic > 1 || |
194 | collapsedStaticShapes != expandedStaticShapes) { |
195 | return failure(); |
196 | } |
197 | continue; |
198 | } |
199 | // If the reassociations are not same, one or the other needs to be of |
200 | // size one. |
201 | if (collapseReassociation.size() != 1 && expandReassociation.size() != 1) |
202 | return failure(); |
203 | } |
204 | |
205 | // Compute new reassociation indices and expanded/collaped shapes. |
206 | SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds; |
207 | Location loc = expandOp->getLoc(); |
208 | SmallVector<OpFoldResult> sourceSizes = |
209 | tensor::getMixedSizes(builder&: rewriter, loc, value: collapseOp.getSrc()); |
210 | SmallVector<OpFoldResult> resultSizes = expandOp.getMixedOutputShape(); |
211 | SmallVector<OpFoldResult> newExpandSizes; |
212 | |
213 | int64_t newExpandIndex = 0, newCollapseIndex = 0, sourceSizeIndex = 0, |
214 | resultSizeIndex = 0; |
215 | |
216 | for (size_t idx = 0, idxEnd = collapseReInds.size(); idx < idxEnd; idx++) { |
217 | auto &collapseReassociation = collapseReInds[idx]; |
218 | auto &expandReassociation = expandReInds[idx]; |
219 | |
220 | // Case 1. The reassociations are same in the collapse producer |
221 | // and expand consumer. In the swapped expand, each of the final |
222 | // dimensions are kept as is in the expand and the collapse. So, |
223 | // for every element in the `ReassocationIndices` vector add a new |
224 | // `ReassociationIndices` vector for the swapped expand and collapse |
225 | // (of size 1). |
226 | if (collapseReassociation.size() == expandReassociation.size()) { |
227 | for (size_t i = 0; i < collapseReassociation.size(); ++i) { |
228 | newCollapseReInds.push_back(Elt: {newCollapseIndex++}); |
229 | newExpandReInds.push_back(Elt: {newExpandIndex++}); |
230 | newExpandSizes.push_back(Elt: resultSizes[resultSizeIndex++]); |
231 | sourceSizeIndex++; |
232 | } |
233 | continue; |
234 | } |
235 | |
236 | // Case 2. The `ReassociationIndices` in the collapse is of size > 1 (and |
237 | // in the expand is of size == 1). In this case, the original dimensions |
238 | // are preserved on expansion and collapsed subsequently. |
239 | if (collapseReassociation.size() != 1) { |
240 | ReassociationIndices newCollapseReassociation; |
241 | for (size_t i = 0; i < collapseReassociation.size(); ++i) { |
242 | newCollapseReassociation.push_back(Elt: newCollapseIndex++); |
243 | newExpandReInds.push_back(Elt: {newExpandIndex++}); |
244 | newExpandSizes.push_back(Elt: sourceSizes[sourceSizeIndex++]); |
245 | } |
246 | resultSizeIndex++; |
247 | newCollapseReInds.push_back(Elt: newCollapseReassociation); |
248 | continue; |
249 | } |
250 | |
251 | // Case 3. The `ReassociationIndices` in the expand is of size > 1 (and |
252 | // in the collapse is of size == 1). In this case, the expansion happens |
253 | // first and the expanded dimensions are preserved on collapse. |
254 | ReassociationIndices newExpandReassociation; |
255 | for (size_t i = 0; i < expandReassociation.size(); ++i) { |
256 | newExpandReassociation.push_back(Elt: newExpandIndex++); |
257 | newCollapseReInds.push_back(Elt: {newCollapseIndex++}); |
258 | newExpandSizes.push_back(Elt: resultSizes[resultSizeIndex++]); |
259 | } |
260 | newExpandReInds.push_back(Elt: newExpandReassociation); |
261 | sourceSizeIndex++; |
262 | } |
263 | |
264 | // Swap reshape order. |
265 | SmallVector<Value> dynamicSizes; |
266 | SmallVector<int64_t> staticSizes; |
267 | dispatchIndexOpFoldResults(ofrs: newExpandSizes, dynamicVec&: dynamicSizes, staticVec&: staticSizes); |
268 | auto expandResultType = expandOp.getResultType().clone(staticSizes); |
269 | Value newCollapseSrc = collapseOp.getSrc(); |
270 | // If the number of reassociation indices in the new `expand_shape` op |
271 | // matches the number of dimensions of the result, then the expand_shape |
272 | // is a no-op. |
273 | if (newExpandReInds.size() != newExpandSizes.size()) { |
274 | newCollapseSrc = rewriter.create<tensor::ExpandShapeOp>( |
275 | loc, expandResultType, newCollapseSrc, newExpandReInds, |
276 | newExpandSizes); |
277 | } |
278 | |
279 | // If the number of reassociation indices in the new `collapse_shape` op |
280 | // matches the number of dimensions of the source, then the collapse_shape |
281 | // is a no-op. |
282 | Value replacement = newCollapseSrc; |
283 | if (newCollapseReInds.size() != newExpandSizes.size()) { |
284 | replacement = rewriter.create<tensor::CollapseShapeOp>( |
285 | loc, newCollapseSrc, newCollapseReInds); |
286 | } |
287 | rewriter.replaceOp(expandOp, replacement); |
288 | return success(); |
289 | } |
290 | }; |
291 | |
292 | /// Converts `tensor.extract_slice(tensor.expand_shape)` to |
293 | /// `tensor.expand_shape(tensor.extract_slice)`. |
294 | /// |
295 | /// For this transformation to be possible, the slice must be fully contiguous |
296 | /// within each reassociation group of the expand_shape. A slice is defined as |
297 | /// fully contiguous within a reassociation group if after flattening the |
298 | /// reassociation group to a single 1D range, then the slice taken out of the |
299 | /// group could be defined as a single contiguous subrange within that range. |
300 | /// |
301 | /// Rank reducing slices are not supported. |
302 | /// |
303 | /// Example: |
304 | /// The transformation is possible because each reassociation group has a |
305 | /// contiguous slice (i.e., [2x4->2x4], [2x8->1x5], [4x2x4->1x1x4]). |
306 | /// ``` |
307 | /// BEFORE: |
308 | /// %reshape = tensor.expand_shape %in [[0, 1], [2, 3], [4, 5, 6]] |
309 | /// tensor<8x16x32xf32> to tensor<2x4x2x8x4x2x4xf32> |
310 | /// %slice = tensor.extract_slice %reshape ... |
311 | /// tensor<2x4x2x8x4x2x4xf32> to tensor<2x4x1x5x1x1x4xf32> |
312 | /// |
313 | /// AFTER: |
314 | /// %slice = tensor.extract_slice %in ... |
315 | /// tensor<8x16x32xf32> to tensor<8x5x4xf32> |
316 | /// %reshape = tensor.expand_shape %slice [[0, 1], [2, 3], [4, 5, 6]] |
317 | /// tensor<8x5x4xf32> to tensor<2x4x1x5x1x1x4xf32> |
318 | /// ``` |
319 | /// |
320 | /// Note - this pattern could be extended to be a swap pattern between |
321 | /// `tensor.expand_shape` and `tensor.extract_slice`, but is currently |
322 | /// implemented only as a bubble up pattern for `tensor.extract_slice`. |
323 | struct BubbleUpExpandShapeThroughExtractSlice |
324 | : public OpRewritePattern<tensor::ExtractSliceOp> { |
325 | using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern; |
326 | |
327 | LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp, |
328 | PatternRewriter &rewriter) const override { |
329 | auto expandShapeOp = |
330 | sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>(); |
331 | |
332 | if (checkPreconditionForBubbleUpExtractSlice(sliceOp, expandShapeOp, |
333 | rewriter) |
334 | .failed()) |
335 | return failure(); |
336 | |
337 | // The tensor.extract_slice before applying the pattern works on the result |
338 | // of the tensor.expand_shape, so variables (i.e. inputs for ExtractSliceOp) |
339 | // referring to the state before applying the pattern are named with the |
340 | // prefix "expanded", and ones referring to the state after applying the |
341 | // pattern are named with the prefix "collapsed". |
342 | SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets(); |
343 | SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes(); |
344 | SmallVector<OpFoldResult> expandedShape = |
345 | getMixedValues(expandShapeOp.getStaticOutputShape(), |
346 | expandShapeOp.getOutputShape(), rewriter); |
347 | |
348 | // Helper variables and function for accumulating the size values. |
349 | Location loc = expandShapeOp->getLoc(); |
350 | AffineExpr d0, d1, d2; |
351 | bindDims(ctx: rewriter.getContext(), exprs&: d0, exprs&: d1, exprs&: d2); |
352 | // Multiply two integers. |
353 | auto mul = [&](OpFoldResult v1, OpFoldResult v2) { |
354 | auto mulMap = AffineMap::get(dimCount: 2, symbolCount: 0, result: {d0 * d1}); |
355 | return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap, |
356 | {v1, v2}); |
357 | }; |
358 | |
359 | // Compute new offsets, sizes, and strides for tensor.extract_slice. |
360 | // The new tensor.extract_slice will work on a tensor that has has a rank of |
361 | // ReassociationIndices.size(). In the loop a single offset, size, and |
362 | // stride value is computed per reassociation group. |
363 | SmallVector<OpFoldResult> collapsedOffsets, collapsedSizes, |
364 | collapsedStrides; |
365 | for (const ReassociationIndices &indices : |
366 | expandShapeOp.getReassociationIndices()) { |
367 | // collapsedSize will hold the size of the single dim that represents the |
368 | // reassociation group in the non expanded tensor. |
369 | OpFoldResult collapsedSize = rewriter.getIndexAttr(1); |
370 | // The reassocGroupSizes and reassocGroupOffsets are used to create an |
371 | // affine.linearize_index op to linearize the single offset value required |
372 | // for this reassociation group. |
373 | SmallVector<OpFoldResult> reassocGroupSizes, reassocGroupOffsets; |
374 | |
375 | for (long expandedDim : indices) { |
376 | // reassocGroupSizes and reassocGroupOffsets can be obtained directly |
377 | // from the expanded state, but the collapsed size requires calculation |
378 | // as it did not previously exist. |
379 | reassocGroupSizes.push_back(expandedShape[expandedDim]); |
380 | reassocGroupOffsets.push_back(expandedOffsets[expandedDim]); |
381 | collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]); |
382 | } |
383 | |
384 | SmallVector<Value> offsetVals = |
385 | llvm::map_to_vector(reassocGroupOffsets, [&](OpFoldResult ofr) { |
386 | return getValueOrCreateConstantIndexOp(rewriter, loc, ofr); |
387 | }); |
388 | OpFoldResult collapsedOffset = |
389 | rewriter |
390 | .create<affine::AffineLinearizeIndexOp>(loc, offsetVals, |
391 | reassocGroupSizes, |
392 | /*disjoint=*/true) |
393 | .getResult(); |
394 | collapsedOffsets.push_back(collapsedOffset); |
395 | collapsedSizes.push_back(collapsedSize); |
396 | |
397 | // Only unit stride is supported. |
398 | collapsedStrides.push_back(rewriter.getIndexAttr(1)); |
399 | } |
400 | |
401 | // The shape of the result can be obtained from the sizes passed in. |
402 | SmallVector<Value> dynDims; |
403 | SmallVector<int64_t> shape; |
404 | dispatchIndexOpFoldResults(ofrs: expandedSizes, dynamicVec&: dynDims, staticVec&: shape); |
405 | RankedTensorType resultType = RankedTensorType::get( |
406 | shape, expandShapeOp.getResultType().getElementType()); |
407 | |
408 | // Create a new ExtractSliceOp and ExpandShapeOp. |
409 | Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>( |
410 | loc, expandShapeOp.getSrc(), collapsedOffsets, collapsedSizes, |
411 | collapsedStrides); |
412 | rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>( |
413 | sliceOp, resultType, newSliceOp, |
414 | expandShapeOp.getReassociationIndices(), expandedSizes); |
415 | return success(); |
416 | } |
417 | |
418 | // Helper function to check if all the required conditions for the |
419 | // tensor.extract_slice to be bubbled up through the tensor.expand_shape are |
420 | // met. |
421 | LogicalResult |
422 | checkPreconditionForBubbleUpExtractSlice(tensor::ExtractSliceOp sliceOp, |
423 | tensor::ExpandShapeOp expandShapeOp, |
424 | PatternRewriter &rewriter) const { |
425 | |
426 | if (!expandShapeOp) { |
427 | return rewriter.notifyMatchFailure( |
428 | sliceOp, "tensor.extract_slice source not produced by expand_shape"); |
429 | } |
430 | |
431 | if (!sliceOp.hasUnitStride()) { |
432 | return rewriter.notifyMatchFailure( |
433 | sliceOp, "unsupported: non-unit stride. Only contiguous slices can " |
434 | "be supported in this transformation."); |
435 | } |
436 | |
437 | SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets(); |
438 | SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes(); |
439 | |
440 | if (static_cast<size_t>(sliceOp.getResultType().getRank()) != |
441 | sizes.size()) { |
442 | return rewriter.notifyMatchFailure(sliceOp, |
443 | "unimplemented: rank reducing slice"); |
444 | } |
445 | |
446 | SmallVector<OpFoldResult> outputShape = |
447 | getMixedValues(expandShapeOp.getStaticOutputShape(), |
448 | expandShapeOp.getOutputShape(), rewriter); |
449 | |
450 | std::function<bool(OpFoldResult, OpFoldResult, OpFoldResult)> |
451 | isZeroOffsetAndFullSize = |
452 | [](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) { |
453 | if (!isZeroInteger(v: offset)) |
454 | return false; |
455 | FailureOr<bool> maybeEqual = |
456 | ValueBoundsConstraintSet::areEqual(var1: sliceSize, var2: size); |
457 | return llvm::succeeded(Result: maybeEqual) && maybeEqual.value(); |
458 | }; |
459 | |
460 | // Check that the slice is contiguous within each reassociation group. |
461 | // The slice is contiguous only if after the first dimension where a non |
462 | // unit slice is taken, the slice size on all subsequent dimensions of the |
463 | // group is equal to the entire size of the dimension. |
464 | // Examples of contiguous slices: |
465 | // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10] |
466 | // full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10] |
467 | // Examples of non contiguous slices: |
468 | // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5] |
469 | // full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5] |
470 | for (const ReassociationIndices &indices : |
471 | expandShapeOp.getReassociationIndices()) { |
472 | int64_t i = 0; |
473 | int64_t e = indices.size(); |
474 | // Find the first expanded dim after the first dim with non-unit extracted |
475 | // size. |
476 | for (; i < e; ++i) { |
477 | if (!isOneInteger(sizes[indices[i]])) { |
478 | // +1 to skip the first non-unit size dim. |
479 | i++; |
480 | break; |
481 | } |
482 | } |
483 | |
484 | // Verify that all subsequent dimensions extract the full size of the |
485 | // source tensor. |
486 | for (; i < e; ++i) { |
487 | int64_t expandedDim = indices[i]; |
488 | if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim], |
489 | outputShape[expandedDim])) { |
490 | return rewriter.notifyMatchFailure( |
491 | sliceOp, "Not a contiguous slice of the expanded tensor."); |
492 | } |
493 | } |
494 | } |
495 | |
496 | return success(); |
497 | } |
498 | }; |
499 | |
500 | /// Converts `tensor.extract_slice(tensor.collapse_shape)` to |
501 | /// `tensor.collapse_shape(tensor.extract_slice)`. |
502 | /// |
503 | /// For this transformation to be possible - after bubbling up, the extraction |
504 | /// of the contiguous slice must be representable as a single slice obtained via |
505 | /// tensor.extract_slice within each reassociation group of the src. |
506 | /// |
507 | /// In case the size and offset extracted are static then this is possible if |
508 | /// the following conditions are met within each reassociation group: |
509 | /// Let T be a tensor of shape [A0, A1, ..., An] (these are the sizes of the |
510 | /// dimensions in the reassociation group), and let S = [S0, S1, ..., Sn] be the |
511 | /// shape of a desired slice. A slice of shape S can be extracted as a |
512 | /// contiguous span of elements if and only if there exists an index k in {0, 1, |
513 | /// ..., n} such that: |
514 | /// S_i = 1 for all i < k (that is, all leading dimensions are singleton), |
515 | /// 1 <= S_k <= A_k (that is, non trivial slicing occurs along exactly |
516 | /// one dimension), |
517 | /// S_i = A_i for all i > k (that is, all trailing dimensions are preserved |
518 | /// in full). |
519 | /// In other words, the slice shape S must be of the form: |
520 | /// [ 1, 1, ..., 1, Sk, Ak + 1, Ak + 2, ...,An ] |
521 | /// |
522 | /// In case the size and/or offset extracted are dynamic then this is possible |
523 | /// only if there is single dimension in the reassociation group that has a size |
524 | /// not equal to 1. |
525 | /// In other words, the tensor shape must be of the form: |
526 | /// [ 1, 1, ..., 1, A, 1, ...,1 ] |
527 | /// Note - it might be possible to enable this pattern for more cases when the |
528 | /// size/offset are dynamic via performing an analysis of the possible values |
529 | /// that could be given to the size/offset. |
530 | /// |
531 | /// Example: |
532 | /// The transformation is possible because each reassociation group can be |
533 | /// represented as a contiguous slice (i.e., [8x16->2x16], [1x7->1x?], |
534 | /// [20->10]). |
535 | /// ``` |
536 | /// BEFORE: |
537 | /// %collapse = tensor.collapse_shape %src [[0, 1], [2, 3], [4]] ... |
538 | /// tensor<8x16x1x7x20f32> to tensor<128x7x20xf32> |
539 | /// %slice = tensor.extract_slice %slice [0, 0, 0][32, %size, 10][1, 1, 1] |
540 | /// tensor<128x7x20xf32> to tensor<32x?x10xf32> |
541 | /// |
542 | /// AFTER: |
543 | /// %slice = tensor.extract_slice %src [0, 0, 0, 0, 0][2, 16, 1, %size, 10] |
544 | // [1, 1, 1, 1, 1] : tensor<8x16x1x7x20f32> to tensor<2x16x1x?x10xf32> |
545 | /// %collapse = tensor.collapse_shape %slice [[0, 1], [2, 3], [4]] ... |
546 | /// tensor<2x16x1x?x10xf32> to tensor<32x?x10xf32> |
547 | /// ``` |
548 | /// |
549 | /// Negative example: |
550 | /// The transformation is not possible because we cannot use a single slice to |
551 | /// represent the reassociation group [2x3x10->???]. If we would want the |
552 | /// collapse to be after the extraction, we would need to extract multiple |
553 | /// slices and concat them together. |
554 | /// ``` |
555 | /// %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<2x3x10xf32> into |
556 | /// tensor<60xf32> %extract = tensor.extract_slice %collapse[0][15][1] : |
557 | /// tensor<60xf32> to tensor<15xf32> |
558 | /// ``` |
559 | /// If we would want the collapse to be after the extraction, a possible |
560 | /// alternate transformation could be to extract multiple slices and concat them |
561 | /// together: |
562 | /// ``` |
563 | /// %extract_1 = tensor.extract_slice %src[0, 0, 0][1, 1, 10] : |
564 | /// tensor<2x3x10xf32> to tensor <1x1x10xf32> |
565 | /// %extract_2 = tensor.extract_slice %src[0, 1, 0][1, 1, 5] : |
566 | /// tensor<2x3x10xf32> to tensor <1x1x5xf32> |
567 | /// %concat = tosa.concat %extract_1, %extract_2 {axis = 0 : i32} : |
568 | /// (<1x1x10xf32>, <1x1x5xf32>) -> <1x1x15xf32> |
569 | /// %collapse = tensor.collapse_shape %concat [[0, 1, 2]] : tensor<1x1x15xf32> |
570 | /// to tensor<15xf32> |
571 | /// ``` |
572 | /// But this is not the intended purpose of the transformation. |
573 | struct BubbleUpCollapseShapeThroughExtractSlice |
574 | : public OpRewritePattern<tensor::ExtractSliceOp> { |
575 | using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern; |
576 | |
577 | LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp, |
578 | PatternRewriter &rewriter) const override { |
579 | auto collapseShapeOp = |
580 | sliceOp.getSource().getDefiningOp<tensor::CollapseShapeOp>(); |
581 | if (!collapseShapeOp) { |
582 | return rewriter.notifyMatchFailure( |
583 | sliceOp, |
584 | "tensor.extract_slice source not produced by tensor.collapse_shape"); |
585 | } |
586 | |
587 | if (!sliceOp.hasUnitStride()) { |
588 | return rewriter.notifyMatchFailure( |
589 | sliceOp, "unsupported: non-unit stride. Only contiguous slices can " |
590 | "be supported in this transformation."); |
591 | } |
592 | |
593 | // The tensor.extract_slice before applying the pattern works on the result |
594 | // of the tensor.collapse_shape, so variables (i.e. inputs for |
595 | // ExtractSliceOp) referring to the state before applying the pattern are |
596 | // named with the prefix "collapsed", and ones referring to the state after |
597 | // applying the pattern are named with the prefix "expanded". |
598 | SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets(); |
599 | SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes(); |
600 | |
601 | if (static_cast<size_t>(sliceOp.getResultType().getRank()) != |
602 | collapsedSizes.size()) { |
603 | return rewriter.notifyMatchFailure(sliceOp, |
604 | "unimplemented: rank reducing slice"); |
605 | } |
606 | |
607 | ArrayRef<int64_t> srcShape = collapseShapeOp.getSrcType().getShape(); |
608 | SmallVector<ReassociationIndices, 4> reassociationIndices = |
609 | collapseShapeOp.getReassociationIndices(); |
610 | |
611 | // Compute new offsets, sizes, and strides for tensor.extract_slice. |
612 | // The new tensor.extract_slice will work on a tensor that has has a rank |
613 | // equal to the rank of the src of the collapse_shape. In each iteration of |
614 | // the loop, the offsets and sizes will be computed per reassociation group. |
615 | SmallVector<OpFoldResult> expandedOffsets, expandedSizes; |
616 | SmallVector<OpFoldResult> expandedStrides(srcShape.size(), |
617 | rewriter.getIndexAttr(1)); |
618 | |
619 | for (auto [collapsedSize, collapsedOffset, reassocIndices] : |
620 | llvm::zip_equal(collapsedSizes, collapsedOffsets, |
621 | collapseShapeOp.getReassociationIndices())) { |
622 | // CASE #1 - size and/or offset are dynamic. |
623 | // In this case, the slice can be represented as a contiguous slice only |
624 | // if there is a single dimension in the reassociation group that has a |
625 | // size not equal to 1. |
626 | if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) { |
627 | int nonUnitSizeCount = 0; |
628 | for (int64_t expandedShapeIdx : reassocIndices) { |
629 | if (srcShape[expandedShapeIdx] != 1) { |
630 | nonUnitSizeCount++; |
631 | expandedSizes.push_back(collapsedSize); |
632 | expandedOffsets.push_back(collapsedOffset); |
633 | continue; |
634 | } |
635 | |
636 | expandedSizes.push_back(rewriter.getIndexAttr(1)); |
637 | expandedOffsets.push_back(rewriter.getIndexAttr(0)); |
638 | } |
639 | |
640 | if (nonUnitSizeCount != 1) { |
641 | return rewriter.notifyMatchFailure( |
642 | sliceOp, |
643 | "unsupported: slice cannot be verified to be contiguous"); |
644 | } |
645 | continue; |
646 | } |
647 | |
648 | // CASE #2 = size and offset are static. |
649 | // Verify that the slice can be represented as a contiguous slice of the |
650 | // src of the collapse_shape. |
651 | // Checking this is done on order of most internal dimensions first, |
652 | // so traversal is done in reverse order of the reassociation group. |
653 | // If the expected slice shape is [1, 1, ..., 1, Sk, Ak + 1, Ak + 2, |
654 | // ...,An] then we first find the size and offset for n...k+1 then for k |
655 | // and then for k-1...0. |
656 | |
657 | // currentCollapsedsize and currentCollapsedOffset are initialized with |
658 | // the original collapsed size and offset and divided by the expanded |
659 | // shape size in each dimension as we go along the reassociation group. |
660 | // In essence we are spreading the original collapsed size and offset over |
661 | // the various expanded slice dimensions. |
662 | // The variables are used both to check the validity of the slice and to |
663 | // compute the expanded sizes and offsets. |
664 | int64_t currentCollapsedsize = getConstantIntValue(collapsedSize).value(); |
665 | int64_t currentCollapsedOffset = |
666 | getConstantIntValue(collapsedOffset).value(); |
667 | |
668 | SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets; |
669 | |
670 | ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(), |
671 | reassocIndices.rend()); |
672 | int64_t idx = 0; |
673 | int64_t reassocGroupSize = reassocIndices.size(); |
674 | |
675 | // First handle the trailing dimensions where the slice size should be |
676 | // equal to the tensor shape and the offset should be 0 (n...k+1). |
677 | for (; idx < reassocGroupSize; ++idx) { |
678 | int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]]; |
679 | |
680 | if (currentCollapsedsize < expandedShapeSize) |
681 | break; |
682 | |
683 | // We need to make sure that the slice size can be set to the shape size |
684 | // and the offset to 0. |
685 | if ((currentCollapsedsize % expandedShapeSize) != 0 || |
686 | (currentCollapsedOffset % expandedShapeSize) != 0) { |
687 | return rewriter.notifyMatchFailure( |
688 | sliceOp, "unsupported: cannot be extracted as a contiguous slice " |
689 | "of the src of the collapse_shape"); |
690 | } |
691 | |
692 | groupExpandedSizes.push_back(rewriter.getIndexAttr(expandedShapeSize)); |
693 | groupExpandedOffsets.push_back(rewriter.getIndexAttr(0)); |
694 | |
695 | currentCollapsedsize /= expandedShapeSize; |
696 | currentCollapsedOffset /= expandedShapeSize; |
697 | } |
698 | |
699 | // Now handle the first dim where slicing occurs on (k). |
700 | if (idx < reassocGroupSize) { |
701 | int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]]; |
702 | int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize; |
703 | // We need to make sure that the slice size in this dim + offset will |
704 | // not exceed the shape size. |
705 | if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) { |
706 | return rewriter.notifyMatchFailure( |
707 | sliceOp, "unsupported: slice cannot be extracted as a contiguous " |
708 | "slice of the src of the collapse_shape"); |
709 | } |
710 | |
711 | groupExpandedSizes.push_back( |
712 | rewriter.getIndexAttr(currentCollapsedsize)); |
713 | groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim)); |
714 | |
715 | currentCollapsedOffset /= expandedShapeSize; |
716 | } |
717 | |
718 | // Now handle the leading dimensions where the slice size is equal to 1 |
719 | // (k-1...0). |
720 | // The size for these dimensions must be 1 because of how we constructed |
721 | // the slice size of the expanded shape. We spread the original collapsed |
722 | // size over the expanded shape sizes until we reached dimension k where |
723 | // the remaining size was smaller than the expanded shape size, and spread |
724 | // the remaining size on it. So, now we are left with only 1s. |
725 | for (idx++; idx < reassocGroupSize; ++idx) { |
726 | int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]]; |
727 | int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize; |
728 | groupExpandedSizes.push_back(rewriter.getIndexAttr(1)); |
729 | groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim)); |
730 | currentCollapsedOffset /= expandedShapeSize; |
731 | } |
732 | |
733 | expandedSizes.append(groupExpandedSizes.rbegin(), |
734 | groupExpandedSizes.rend()); |
735 | expandedOffsets.append(groupExpandedOffsets.rbegin(), |
736 | groupExpandedOffsets.rend()); |
737 | } |
738 | |
739 | Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>( |
740 | collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), expandedOffsets, |
741 | expandedSizes, expandedStrides); |
742 | rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>( |
743 | sliceOp, sliceOp.getResultType(), newSliceOp, |
744 | collapseShapeOp.getReassociationIndices()); |
745 | |
746 | return success(); |
747 | } |
748 | }; |
749 | |
750 | } // namespace |
751 | |
752 | void mlir::tensor::populateReassociativeReshapeFoldingPatterns( |
753 | RewritePatternSet &patterns) { |
754 | patterns |
755 | .add<FoldExpandOfRankReducingExtract, FoldUnPaddingCollapseIntoExtract, |
756 | FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>, |
757 | FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>, |
758 | FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>, |
759 | FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>( |
760 | patterns.getContext()); |
761 | } |
762 | |
763 | void mlir::tensor::populateBubbleUpExpandShapePatterns( |
764 | RewritePatternSet &patterns) { |
765 | patterns.add<BubbleUpExpandThroughParallelCollapse>(arg: patterns.getContext()); |
766 | } |
767 | |
768 | void mlir::tensor::populateBubbleUpExtractSliceOpPatterns( |
769 | RewritePatternSet &patterns) { |
770 | patterns.add<BubbleUpExpandShapeThroughExtractSlice, |
771 | BubbleUpCollapseShapeThroughExtractSlice>(arg: patterns.getContext()); |
772 | } |
773 |