//===- RankReductionPatterns.cpp - Patterns related to rank reductions ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"

using namespace mlir;
using namespace mlir::tensor;

namespace {
/// Fold expand_shape(extract_slice) ops that cancel each other out.
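/// A sketch of the fold under assumed shapes (chosen purely for illustration):
/// ```
/// %0 = tensor.extract_slice %t[0, 0, 0][1, 16, 32][1, 1, 1]
///     tensor<8x16x32xf32> to tensor<16x32xf32>
/// %1 = tensor.expand_shape %0 [[0, 1], [2]]
///     tensor<16x32xf32> to tensor<1x16x32xf32>
/// ```
/// folds to a single non-rank-reducing extract_slice:
/// ```
/// %1 = tensor.extract_slice %t[0, 0, 0][1, 16, 32][1, 1, 1]
///     tensor<8x16x32xf32> to tensor<1x16x32xf32>
/// ```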
struct FoldExpandOfRankReducingExtract
    : public OpRewritePattern<ExpandShapeOp> {
  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType resultType = expandShapeOp.getResultType();
    auto extractSliceOp =
        expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>();
    if (!extractSliceOp)
      return failure();
    RankedTensorType srcType = extractSliceOp.getSourceType();

    // Only cases where the ExpandShapeOp can be folded away entirely are
    // supported. Moreover, only simple cases where the resulting
    // ExtractSliceOp has no rank-reduction anymore are supported at the
    // moment.
    RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
        srcType, extractSliceOp.getStaticOffsets(),
        extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
    if (nonReducingExtractType != resultType)
      return failure();

    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
        expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes,
        mixedStrides);
    return success();
  }
};

/// Fold collapse_shape which only removes static dimensions of size `1`
/// into extract_slice.
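/// A sketch under assumed shapes (illustrative only):
/// ```
/// %0 = tensor.extract_slice %t[0, 0, 0][1, 16, 32][1, 1, 1]
///     tensor<8x16x32xf32> to tensor<1x16x32xf32>
/// %1 = tensor.collapse_shape %0 [[0, 1], [2]]
///     tensor<1x16x32xf32> to tensor<16x32xf32>
/// ```
/// folds to a single rank-reducing extract_slice:
/// ```
/// %1 = tensor.extract_slice %t[0, 0, 0][1, 16, 32][1, 1, 1]
///     tensor<8x16x32xf32> to tensor<16x32xf32>
/// ```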
struct FoldUnPaddingCollapseIntoExtract
    : public OpRewritePattern<tensor::CollapseShapeOp> {
  using OpRewritePattern<tensor::CollapseShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseShapeOp,
                                PatternRewriter &rewriter) const override {
    auto extractSliceOp =
        collapseShapeOp.getSrc().getDefiningOp<tensor::ExtractSliceOp>();
    // The collapse cannot be folded away when the extract_slice has multiple
    // users, and it is not necessarily beneficial to only convert the collapse
    // into another extract_slice.
    if (!extractSliceOp || !extractSliceOp->hasOneUse())
      return failure();

    // Only fold away simple collapses where all removed dimensions have static
    // size `1`.
    SliceVerificationResult res = isRankReducedType(
        collapseShapeOp.getSrcType(), collapseShapeOp.getResultType());
    if (res != SliceVerificationResult::Success)
      return rewriter.notifyMatchFailure(collapseShapeOp,
                                         "expected unpadding collapse");

    Value unPaddedExtractSlice = rewriter.create<tensor::ExtractSliceOp>(
        extractSliceOp.getLoc(), collapseShapeOp.getResultType(),
        extractSliceOp.getSource(), extractSliceOp.getMixedOffsets(),
        extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
    rewriter.replaceOp(collapseShapeOp, unPaddedExtractSlice);
    return success();
  }
};

/// Fold insert_slice(collapse_shape) ops that cancel each other out.
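/// A sketch with tensor.insert_slice (the pattern is templated and also
/// applies to tensor.parallel_insert_slice; shapes are illustrative):
/// ```
/// %0 = tensor.collapse_shape %src [[0, 1], [2]]
///     tensor<1x16x32xf32> to tensor<16x32xf32>
/// %1 = tensor.insert_slice %0 into %dest[0, 0, 0][1, 16, 32][1, 1, 1]
///     tensor<16x32xf32> into tensor<8x16x32xf32>
/// ```
/// folds to a single non-rank-reducing insert_slice:
/// ```
/// %1 = tensor.insert_slice %src into %dest[0, 0, 0][1, 16, 32][1, 1, 1]
///     tensor<1x16x32xf32> into tensor<8x16x32xf32>
/// ```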
template <typename OpTy>
struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    auto collapseShapeOp =
        insertSliceOp.getSource().template getDefiningOp<CollapseShapeOp>();
    if (!collapseShapeOp)
      return failure();
    RankedTensorType srcType = collapseShapeOp.getSrcType();

    // Only cases where the CollapseShapeOp can be folded away entirely are
    // supported. Moreover, only simple cases where the resulting InsertSliceOp
    // has no rank-reduction anymore are supported at the moment.
    RankedTensorType nonReducingInsertType =
        RankedTensorType::get(insertSliceOp.getStaticSizes(),
                              insertSliceOp.getDestType().getElementType());
    if (nonReducingInsertType != srcType)
      return failure();

    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    rewriter.replaceOpWithNewOp<OpTy>(insertSliceOp, collapseShapeOp.getSrc(),
                                      insertSliceOp.getDest(), mixedOffsets,
                                      mixedSizes, mixedStrides);
    return success();
  }
};

/// Fold expand_shape which only adds static dimensions of size `1`
/// into insert_slice.
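/// A sketch with tensor.insert_slice (shapes are illustrative):
/// ```
/// %0 = tensor.expand_shape %src [[0, 1], [2]]
///     tensor<16x32xf32> to tensor<1x16x32xf32>
/// %1 = tensor.insert_slice %0 into %dest[0, 0, 0][1, 16, 32][1, 1, 1]
///     tensor<1x16x32xf32> into tensor<8x16x32xf32>
/// ```
/// folds to a rank-reducing insert_slice taken directly from %src:
/// ```
/// %1 = tensor.insert_slice %src into %dest[0, 0, 0][1, 16, 32][1, 1, 1]
///     tensor<16x32xf32> into tensor<8x16x32xf32>
/// ```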
template <typename OpTy>
struct FoldPaddingExpandIntoInsert : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    auto expandShapeOp = insertSliceOp.getSource()
                             .template getDefiningOp<tensor::ExpandShapeOp>();
    if (!expandShapeOp)
      return failure();

    // Only fold away simple expansion where all added dimensions have static
    // size `1`.
    SliceVerificationResult res = isRankReducedType(
        expandShapeOp.getResultType(), expandShapeOp.getSrcType());
    if (res != SliceVerificationResult::Success)
      return rewriter.notifyMatchFailure(insertSliceOp,
                                         "expected rank increasing expansion");

    rewriter.modifyOpInPlace(insertSliceOp, [&]() {
      insertSliceOp.getSourceMutable().assign(expandShapeOp.getSrc());
    });
    return success();
  }
};

/// Pattern to bubble up a tensor.expand_shape op through a producer
/// tensor.collapse_shape op that has non-intersecting reassociations.
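/// A sketch of the swap under assumed shapes (illustrative only):
/// ```
/// BEFORE:
/// %collapse = tensor.collapse_shape %in [[0, 1], [2]]
///     tensor<4x8x16xf32> to tensor<32x16xf32>
/// %expand = tensor.expand_shape %collapse [[0], [1, 2]]
///     tensor<32x16xf32> to tensor<32x4x4xf32>
///
/// AFTER:
/// %expand = tensor.expand_shape %in [[0], [1], [2, 3]]
///     tensor<4x8x16xf32> to tensor<4x8x4x4xf32>
/// %collapse = tensor.collapse_shape %expand [[0, 1], [2], [3]]
///     tensor<4x8x4x4xf32> to tensor<32x4x4xf32>
/// ```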
struct BubbleUpExpandThroughParallelCollapse
    : public OpRewritePattern<tensor::ExpandShapeOp> {
  using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
                                PatternRewriter &rewriter) const override {
    auto collapseOp =
        expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
    if (!collapseOp)
      return failure();
    auto expandReInds = expandOp.getReassociationIndices();
    auto collapseReInds = collapseOp.getReassociationIndices();

    // Special case: when the tensor to expand is a 0-D tensor, the
    // reassociation maps are empty and would not produce valid results.
    if (expandReInds.size() == 0) {
      return failure();
    }

    // The reshapes are parallel to each other (by construction, the number of
    // reassociations specified in the collapse and the expand is the same) if
    // at every position
    // 1. the reassociation indices are of the same size, or
    // 2. the reassociation in either the collapse or the expand is of size 1.
    ArrayRef<int64_t> staticSourceSize = collapseOp.getSrcType().getShape();
    ArrayRef<int64_t> staticResultSize = expandOp.getStaticOutputShape();
    for (auto [expandReassociation, collapseReassociation] :
         llvm::zip_equal(expandReInds, collapseReInds)) {
      if (collapseReassociation.size() == expandReassociation.size()) {
        // Even if the reassociations are the same, the collapse/expand should
        // result in the same dimensions, e.g. 4x8x2 collapsed into 64 should
        // be expanded into 4x8x2 again. In the presence of dynamic dimensions,
        // "equality" can only be verified when there is at most one dynamic
        // dimension present and all other static dimensions are equal.
        ArrayRef<int64_t> collapsedStaticShapes = staticSourceSize.slice(
            collapseReassociation.front(), collapseReassociation.size());
        int64_t numCollapsedDynamic =
            llvm::count_if(collapsedStaticShapes, ShapedType::isDynamic);
        ArrayRef<int64_t> expandedStaticShapes = staticResultSize.slice(
            expandReassociation.front(), expandReassociation.size());
        int64_t numExpandedDynamic =
            llvm::count_if(expandedStaticShapes, ShapedType::isDynamic);
        if (numCollapsedDynamic > 1 || numExpandedDynamic > 1 ||
            collapsedStaticShapes != expandedStaticShapes) {
          return failure();
        }
        continue;
      }
      // If the reassociations are not the same, one or the other needs to be
      // of size one.
      if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
        return failure();
    }

    // Compute new reassociation indices and expanded/collapsed shapes.
    SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
    Location loc = expandOp->getLoc();
    SmallVector<OpFoldResult> sourceSizes =
        tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
    SmallVector<OpFoldResult> resultSizes = expandOp.getMixedOutputShape();
    SmallVector<OpFoldResult> newExpandSizes;

    int64_t newExpandIndex = 0, newCollapseIndex = 0, sourceSizeIndex = 0,
            resultSizeIndex = 0;

    for (size_t idx = 0, idxEnd = collapseReInds.size(); idx < idxEnd; idx++) {
      auto &collapseReassociation = collapseReInds[idx];
      auto &expandReassociation = expandReInds[idx];

      // Case 1. The reassociations are the same in the collapse producer
      // and expand consumer. In the swapped expand, each of the final
      // dimensions is kept as is in both the expand and the collapse. So,
      // for every element in the `ReassociationIndices` vector, add a new
      // `ReassociationIndices` vector (of size 1) for the swapped expand and
      // collapse.
      if (collapseReassociation.size() == expandReassociation.size()) {
        for (size_t i = 0; i < collapseReassociation.size(); ++i) {
          newCollapseReInds.push_back({newCollapseIndex++});
          newExpandReInds.push_back({newExpandIndex++});
          newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
          sourceSizeIndex++;
        }
        continue;
      }

      // Case 2. The `ReassociationIndices` in the collapse is of size > 1 (and
      // in the expand is of size == 1). In this case, the original dimensions
      // are preserved on expansion and collapsed subsequently.
      if (collapseReassociation.size() != 1) {
        ReassociationIndices newCollapseReassociation;
        for (size_t i = 0; i < collapseReassociation.size(); ++i) {
          newCollapseReassociation.push_back(newCollapseIndex++);
          newExpandReInds.push_back({newExpandIndex++});
          newExpandSizes.push_back(sourceSizes[sourceSizeIndex++]);
        }
        resultSizeIndex++;
        newCollapseReInds.push_back(newCollapseReassociation);
        continue;
      }

      // Case 3. The `ReassociationIndices` in the expand is of size > 1 (and
      // in the collapse is of size == 1). In this case, the expansion happens
      // first and the expanded dimensions are preserved on collapse.
      ReassociationIndices newExpandReassociation;
      for (size_t i = 0; i < expandReassociation.size(); ++i) {
        newExpandReassociation.push_back(newExpandIndex++);
        newCollapseReInds.push_back({newCollapseIndex++});
        newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
      }
      newExpandReInds.push_back(newExpandReassociation);
      sourceSizeIndex++;
    }

    // Swap reshape order.
    SmallVector<Value> dynamicSizes;
    SmallVector<int64_t> staticSizes;
    dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
    auto expandResultType = expandOp.getResultType().clone(staticSizes);
    Value newCollapseSrc = collapseOp.getSrc();
    // If the number of reassociation indices in the new `expand_shape` op
    // matches the number of dimensions of the result, then the expand_shape
    // is a no-op.
    if (newExpandReInds.size() != newExpandSizes.size()) {
      newCollapseSrc = rewriter.create<tensor::ExpandShapeOp>(
          loc, expandResultType, newCollapseSrc, newExpandReInds,
          newExpandSizes);
    }

    // If the number of reassociation indices in the new `collapse_shape` op
    // matches the number of dimensions of the source, then the collapse_shape
    // is a no-op.
    Value replacement = newCollapseSrc;
    if (newCollapseReInds.size() != newExpandSizes.size()) {
      replacement = rewriter.create<tensor::CollapseShapeOp>(
          loc, newCollapseSrc, newCollapseReInds);
    }
    rewriter.replaceOp(expandOp, replacement);
    return success();
  }
};

/// Converts `tensor.extract_slice(tensor.expand_shape)` to
/// `tensor.expand_shape(tensor.extract_slice)`.
///
/// For this transformation to be possible, the slice must be fully contiguous
/// within each reassociation group of the expand_shape. A slice is fully
/// contiguous within a reassociation group if, after flattening the
/// reassociation group to a single 1D range, the slice taken out of the group
/// can be described as a single contiguous subrange within that range.
///
/// Rank reducing slices are not supported.
///
/// Example:
/// The transformation is possible because each reassociation group has a
/// contiguous slice (i.e., [2x4->2x4], [2x8->1x5], [4x2x4->1x1x4]).
/// ```
/// BEFORE:
/// %reshape = tensor.expand_shape %in [[0, 1], [2, 3], [4, 5, 6]]
///     tensor<8x16x32xf32> to tensor<2x4x2x8x4x2x4xf32>
/// %slice = tensor.extract_slice %reshape ...
///     tensor<2x4x2x8x4x2x4xf32> to tensor<2x4x1x5x1x1x4xf32>
///
/// AFTER:
/// %slice = tensor.extract_slice %in ...
///     tensor<8x16x32xf32> to tensor<8x5x4xf32>
/// %reshape = tensor.expand_shape %slice [[0, 1], [2, 3], [4, 5, 6]]
///     tensor<8x5x4xf32> to tensor<2x4x1x5x1x1x4xf32>
/// ```
///
/// Note - this pattern could be extended to be a swap pattern between
/// `tensor.expand_shape` and `tensor.extract_slice`, but is currently
/// implemented only as a bubble up pattern for `tensor.extract_slice`.
struct BubbleUpExpandShapeThroughExtractSlice
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    auto expandShapeOp =
        sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();

    if (checkPreconditionForBubbleUpExtractSlice(sliceOp, expandShapeOp,
                                                 rewriter)
            .failed())
      return failure();

    // The tensor.extract_slice before applying the pattern works on the result
    // of the tensor.expand_shape, so variables (i.e. inputs for
    // ExtractSliceOp) referring to the state before applying the pattern are
    // named with the prefix "expanded", and ones referring to the state after
    // applying the pattern are named with the prefix "collapsed".
    SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
    SmallVector<OpFoldResult> expandedShape =
        getMixedValues(expandShapeOp.getStaticOutputShape(),
                       expandShapeOp.getOutputShape(), rewriter);

    // Helper variables and function for accumulating the size values.
    Location loc = expandShapeOp->getLoc();
    AffineExpr d0, d1, d2;
    bindDims(rewriter.getContext(), d0, d1, d2);
    // Multiply two integers.
    auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
      auto mulMap = AffineMap::get(2, 0, {d0 * d1});
      return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
                                                   {v1, v2});
    };

    // Compute new offsets, sizes, and strides for tensor.extract_slice.
    // The new tensor.extract_slice will work on a tensor that has a rank of
    // ReassociationIndices.size(). In the loop, a single offset, size, and
    // stride value is computed per reassociation group.
    SmallVector<OpFoldResult> collapsedOffsets, collapsedSizes,
        collapsedStrides;
    for (const ReassociationIndices &indices :
         expandShapeOp.getReassociationIndices()) {
      // collapsedSize will hold the size of the single dim that represents the
      // reassociation group in the non-expanded tensor.
      OpFoldResult collapsedSize = rewriter.getIndexAttr(1);
      // The reassocGroupSizes and reassocGroupOffsets are used to create an
      // affine.linearize_index op to linearize the single offset value
      // required for this reassociation group.
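      // For example (numbers chosen for illustration): a group with expanded
      // sizes [2, 8] and expanded offsets [1, 3] linearizes to the single
      // collapsed offset 1 * 8 + 3 = 11.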
      SmallVector<OpFoldResult> reassocGroupSizes, reassocGroupOffsets;

      for (long expandedDim : indices) {
        // reassocGroupSizes and reassocGroupOffsets can be obtained directly
        // from the expanded state, but the collapsed size requires calculation
        // as it did not previously exist.
        reassocGroupSizes.push_back(expandedShape[expandedDim]);
        reassocGroupOffsets.push_back(expandedOffsets[expandedDim]);
        collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]);
      }

      SmallVector<Value> offsetVals =
          llvm::map_to_vector(reassocGroupOffsets, [&](OpFoldResult ofr) {
            return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
          });
      OpFoldResult collapsedOffset =
          rewriter
              .create<affine::AffineLinearizeIndexOp>(loc, offsetVals,
                                                      reassocGroupSizes,
                                                      /*disjoint=*/true)
              .getResult();
      collapsedOffsets.push_back(collapsedOffset);
      collapsedSizes.push_back(collapsedSize);

      // Only unit stride is supported.
      collapsedStrides.push_back(rewriter.getIndexAttr(1));
    }

    // The shape of the result can be obtained from the sizes passed in.
    SmallVector<Value> dynDims;
    SmallVector<int64_t> shape;
    dispatchIndexOpFoldResults(expandedSizes, dynDims, shape);
    RankedTensorType resultType = RankedTensorType::get(
        shape, expandShapeOp.getResultType().getElementType());

    // Create a new ExtractSliceOp and ExpandShapeOp.
    Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
        loc, expandShapeOp.getSrc(), collapsedOffsets, collapsedSizes,
        collapsedStrides);
    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        sliceOp, resultType, newSliceOp,
        expandShapeOp.getReassociationIndices(), expandedSizes);
    return success();
  }

  // Helper function to check if all the required conditions for the
  // tensor.extract_slice to be bubbled up through the tensor.expand_shape are
  // met.
  LogicalResult
  checkPreconditionForBubbleUpExtractSlice(tensor::ExtractSliceOp sliceOp,
                                           tensor::ExpandShapeOp expandShapeOp,
                                           PatternRewriter &rewriter) const {

    if (!expandShapeOp) {
      return rewriter.notifyMatchFailure(
          sliceOp, "tensor.extract_slice source not produced by expand_shape");
    }

    if (!sliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(
          sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
                   "be supported in this transformation.");
    }

    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();

    if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
        sizes.size()) {
      return rewriter.notifyMatchFailure(sliceOp,
                                         "unimplemented: rank reducing slice");
    }

    SmallVector<OpFoldResult> outputShape =
        getMixedValues(expandShapeOp.getStaticOutputShape(),
                       expandShapeOp.getOutputShape(), rewriter);

    std::function<bool(OpFoldResult, OpFoldResult, OpFoldResult)>
        isZeroOffsetAndFullSize =
            [](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) {
              if (!isZeroInteger(offset))
                return false;
              FailureOr<bool> maybeEqual =
                  ValueBoundsConstraintSet::areEqual(sliceSize, size);
              return llvm::succeeded(maybeEqual) && maybeEqual.value();
            };

    // Check that the slice is contiguous within each reassociation group.
    // The slice is contiguous only if, after the first dimension where a
    // non-unit slice is taken, the slice size on all subsequent dimensions of
    // the group is equal to the entire size of the dimension.
    // Examples of contiguous slices:
    // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10]
    // full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10]
    // Examples of non-contiguous slices:
    // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5]
    // full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5]
    for (const ReassociationIndices &indices :
         expandShapeOp.getReassociationIndices()) {
      int64_t i = 0;
      int64_t e = indices.size();
      // Find the first expanded dim after the first dim with non-unit
      // extracted size.
      for (; i < e; ++i) {
        if (!isOneInteger(sizes[indices[i]])) {
          // +1 to skip the first non-unit size dim.
          i++;
          break;
        }
      }

      // Verify that all subsequent dimensions extract the full size of the
      // source tensor.
      for (; i < e; ++i) {
        int64_t expandedDim = indices[i];
        if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
                                     outputShape[expandedDim])) {
          return rewriter.notifyMatchFailure(
              sliceOp, "Not a contiguous slice of the expanded tensor.");
        }
      }
    }

    return success();
  }
};

/// Converts `tensor.extract_slice(tensor.collapse_shape)` to
/// `tensor.collapse_shape(tensor.extract_slice)`.
///
/// For this transformation to be possible, the slice must (after bubbling up)
/// be representable as a single contiguous slice within each reassociation
/// group of the src, obtainable via a single tensor.extract_slice.
///
/// If the extracted size and offset are static, this is possible when the
/// following conditions are met within each reassociation group:
/// Let T be a tensor of shape [A0, A1, ..., An] (these are the sizes of the
/// dimensions in the reassociation group), and let S = [S0, S1, ..., Sn] be
/// the shape of a desired slice. A slice of shape S can be extracted as a
/// contiguous span of elements if and only if there exists an index k in
/// {0, 1, ..., n} such that:
///   S_i = 1 for all i < k (that is, all leading dimensions are singleton),
///   1 <= S_k <= A_k (that is, non-trivial slicing occurs along exactly
///   one dimension),
///   S_i = A_i for all i > k (that is, all trailing dimensions are preserved
///   in full).
/// In other words, the slice shape S must be of the form:
/// [ 1, 1, ..., 1, Sk, Ak+1, Ak+2, ..., An ]
///
/// If the extracted size and/or offset are dynamic, this is possible only if
/// there is a single dimension in the reassociation group that has a size not
/// equal to 1.
/// In other words, the tensor shape must be of the form:
/// [ 1, 1, ..., 1, A, 1, ..., 1 ]
/// Note - it might be possible to enable this pattern for more cases when the
/// size/offset are dynamic via performing an analysis of the possible values
/// that could be given to the size/offset.
///
/// Example:
/// The transformation is possible because each reassociation group can be
/// represented as a contiguous slice (i.e., [8x16->2x16], [1x7->1x?],
/// [20->10]).
/// ```
/// BEFORE:
/// %collapse = tensor.collapse_shape %src [[0, 1], [2, 3], [4]] ...
///     tensor<8x16x1x7x20xf32> to tensor<128x7x20xf32>
/// %slice = tensor.extract_slice %collapse [0, 0, 0][32, %size, 10][1, 1, 1]
///     tensor<128x7x20xf32> to tensor<32x?x10xf32>
///
/// AFTER:
/// %slice = tensor.extract_slice %src [0, 0, 0, 0, 0][2, 16, 1, %size, 10]
///     [1, 1, 1, 1, 1] : tensor<8x16x1x7x20xf32> to tensor<2x16x1x?x10xf32>
/// %collapse = tensor.collapse_shape %slice [[0, 1], [2, 3], [4]] ...
///     tensor<2x16x1x?x10xf32> to tensor<32x?x10xf32>
/// ```
///
/// Negative example:
/// The transformation is not possible because we cannot use a single slice to
/// represent the reassociation group [2x3x10->???].
/// ```
/// %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<2x3x10xf32>
///     into tensor<60xf32>
/// %extract = tensor.extract_slice %collapse[0][15][1] :
///     tensor<60xf32> to tensor<15xf32>
/// ```
/// If we wanted the collapse to happen after the extraction, a possible
/// alternate transformation could be to extract multiple slices and concat
/// them together:
/// ```
/// %extract_1 = tensor.extract_slice %src[0, 0, 0][1, 1, 10] :
///     tensor<2x3x10xf32> to tensor<1x1x10xf32>
/// %extract_2 = tensor.extract_slice %src[0, 1, 0][1, 1, 5] :
///     tensor<2x3x10xf32> to tensor<1x1x5xf32>
/// %concat = tosa.concat %extract_1, %extract_2 {axis = 2 : i32} :
///     (tensor<1x1x10xf32>, tensor<1x1x5xf32>) -> tensor<1x1x15xf32>
/// %collapse = tensor.collapse_shape %concat [[0, 1, 2]] : tensor<1x1x15xf32>
///     to tensor<15xf32>
/// ```
/// But this is not the intended purpose of the transformation.
struct BubbleUpCollapseShapeThroughExtractSlice
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    auto collapseShapeOp =
        sliceOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
    if (!collapseShapeOp) {
      return rewriter.notifyMatchFailure(
          sliceOp,
          "tensor.extract_slice source not produced by tensor.collapse_shape");
    }

    if (!sliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(
          sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
                   "be supported in this transformation.");
    }

    // The tensor.extract_slice before applying the pattern works on the result
    // of the tensor.collapse_shape, so variables (i.e. inputs for
    // ExtractSliceOp) referring to the state before applying the pattern are
    // named with the prefix "collapsed", and ones referring to the state after
    // applying the pattern are named with the prefix "expanded".
    SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes();

    if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
        collapsedSizes.size()) {
      return rewriter.notifyMatchFailure(sliceOp,
                                         "unimplemented: rank reducing slice");
    }

    ArrayRef<int64_t> srcShape = collapseShapeOp.getSrcType().getShape();
    SmallVector<ReassociationIndices, 4> reassociationIndices =
        collapseShapeOp.getReassociationIndices();

    // Compute new offsets, sizes, and strides for tensor.extract_slice.
    // The new tensor.extract_slice will work on a tensor that has a rank equal
    // to the rank of the src of the collapse_shape. In each iteration of the
    // loop, the offsets and sizes are computed per reassociation group.
    SmallVector<OpFoldResult> expandedOffsets, expandedSizes;
    SmallVector<OpFoldResult> expandedStrides(srcShape.size(),
                                              rewriter.getIndexAttr(1));

    for (auto [collapsedSize, collapsedOffset, reassocIndices] :
         llvm::zip_equal(collapsedSizes, collapsedOffsets,
                         reassociationIndices)) {
      // CASE #1 - size and/or offset are dynamic.
      // In this case, the slice can be represented as a contiguous slice only
      // if there is a single dimension in the reassociation group that has a
      // size not equal to 1.
      if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
        int nonUnitSizeCount = 0;
        for (int64_t expandedShapeIdx : reassocIndices) {
          if (srcShape[expandedShapeIdx] != 1) {
            nonUnitSizeCount++;
            expandedSizes.push_back(collapsedSize);
            expandedOffsets.push_back(collapsedOffset);
            continue;
          }

          expandedSizes.push_back(rewriter.getIndexAttr(1));
          expandedOffsets.push_back(rewriter.getIndexAttr(0));
        }

        if (nonUnitSizeCount != 1) {
          return rewriter.notifyMatchFailure(
              sliceOp,
              "unsupported: slice cannot be verified to be contiguous");
        }
        continue;
      }

      // CASE #2 - size and offset are static.
      // Verify that the slice can be represented as a contiguous slice of the
      // src of the collapse_shape.
      // Checking this is done in order of the innermost dimensions first, so
      // the traversal is done in reverse order of the reassociation group.
      // If the expected slice shape is [1, 1, ..., 1, Sk, Ak+1, Ak+2, ..., An],
      // we first find the size and offset for dims n...k+1, then for dim k,
      // and then for dims k-1...0.

      // currentCollapsedsize and currentCollapsedOffset are initialized with
      // the original collapsed size and offset and divided by the expanded
      // shape size in each dimension as we go along the reassociation group.
      // In essence we are spreading the original collapsed size and offset
      // over the various expanded slice dimensions.
      // The variables are used both to check the validity of the slice and to
      // compute the expanded sizes and offsets.
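      // Worked example (illustrative numbers): collapsed size 32 and offset 0
      // over a group with expanded sizes [8, 16]: the innermost dim takes
      // size 16 and offset 0 (32 % 16 == 0 and 0 % 16 == 0), leaving
      // 32 / 16 == 2 for the outer dim, so the expanded slice has sizes
      // [2, 16] and offsets [0, 0] (the [8x16->2x16] group in the example
      // above).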
      int64_t currentCollapsedsize = getConstantIntValue(collapsedSize).value();
      int64_t currentCollapsedOffset =
          getConstantIntValue(collapsedOffset).value();

      SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;

      ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(),
                                                  reassocIndices.rend());
      int64_t idx = 0;
      int64_t reassocGroupSize = reassocIndices.size();

      // First handle the trailing dimensions where the slice size should be
      // equal to the tensor shape and the offset should be 0 (n...k+1).
      for (; idx < reassocGroupSize; ++idx) {
        int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];

        if (currentCollapsedsize < expandedShapeSize)
          break;

        // We need to make sure that the slice size can be set to the shape
        // size and the offset to 0.
        if ((currentCollapsedsize % expandedShapeSize) != 0 ||
            (currentCollapsedOffset % expandedShapeSize) != 0) {
          return rewriter.notifyMatchFailure(
              sliceOp, "unsupported: cannot be extracted as a contiguous "
                       "slice of the src of the collapse_shape");
        }

        groupExpandedSizes.push_back(rewriter.getIndexAttr(expandedShapeSize));
        groupExpandedOffsets.push_back(rewriter.getIndexAttr(0));

        currentCollapsedsize /= expandedShapeSize;
        currentCollapsedOffset /= expandedShapeSize;
      }

      // Now handle the first dim where slicing occurs on (k).
      if (idx < reassocGroupSize) {
        int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
        int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
        // We need to make sure that the slice size in this dim + offset will
        // not exceed the shape size.
        if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) {
          return rewriter.notifyMatchFailure(
              sliceOp, "unsupported: slice cannot be extracted as a contiguous "
                       "slice of the src of the collapse_shape");
        }

        groupExpandedSizes.push_back(
            rewriter.getIndexAttr(currentCollapsedsize));
        groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));

        currentCollapsedOffset /= expandedShapeSize;
      }

      // Now handle the leading dimensions where the slice size is equal to 1
      // (k-1...0).
      // The size for these dimensions must be 1 because of how we constructed
      // the slice size of the expanded shape. We spread the original collapsed
      // size over the expanded shape sizes until we reached dimension k where
      // the remaining size was smaller than the expanded shape size, and
      // spread the remaining size on it. So, now we are left with only 1s.
      for (idx++; idx < reassocGroupSize; ++idx) {
        int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
        int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
        groupExpandedSizes.push_back(rewriter.getIndexAttr(1));
        groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));
        currentCollapsedOffset /= expandedShapeSize;
      }

      expandedSizes.append(groupExpandedSizes.rbegin(),
                           groupExpandedSizes.rend());
      expandedOffsets.append(groupExpandedOffsets.rbegin(),
                             groupExpandedOffsets.rend());
    }

    Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
        collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), expandedOffsets,
        expandedSizes, expandedStrides);
    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
        sliceOp, sliceOp.getResultType(), newSliceOp,
        collapseShapeOp.getReassociationIndices());

    return success();
  }
};

} // namespace

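// A minimal usage sketch for the populate functions below (the surrounding
// pass boilerplate and the greedy driver invocation are assumed here, not
// part of this file):
//   RewritePatternSet patterns(op->getContext());
//   tensor::populateReassociativeReshapeFoldingPatterns(patterns);
//   if (failed(applyPatternsGreedily(op, std::move(patterns))))
//     signalPassFailure();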
void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<FoldExpandOfRankReducingExtract, FoldUnPaddingCollapseIntoExtract,
           FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
           FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>,
           FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>,
           FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
          patterns.getContext());
}

void mlir::tensor::populateBubbleUpExpandShapePatterns(
    RewritePatternSet &patterns) {
  patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
}

void mlir::tensor::populateBubbleUpExtractSliceOpPatterns(
    RewritePatternSet &patterns) {
  patterns.add<BubbleUpExpandShapeThroughExtractSlice,
               BubbleUpCollapseShapeThroughExtractSlice>(
      patterns.getContext());
}