| 1 | //===- RankReductionPatterns.cpp - Patterns related to rank reductions ----===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 10 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
| 11 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| 12 | #include "mlir/Dialect/Tensor/Transforms/Transforms.h" |
| 13 | #include "mlir/IR/PatternMatch.h" |
| 14 | #include "mlir/Interfaces/ValueBoundsOpInterface.h" |
| 15 | #include "llvm/ADT/STLExtras.h" |
| 16 | #include "llvm/Support/LogicalResult.h" |
| 17 | |
| 18 | using namespace mlir; |
| 19 | using namespace mlir::tensor; |
| 20 | |
| 21 | namespace { |
| 22 | /// Fold expand_shape(extract_slice) ops that cancel itself out. |
| 23 | struct FoldExpandOfRankReducingExtract |
| 24 | : public OpRewritePattern<ExpandShapeOp> { |
| 25 | using OpRewritePattern<ExpandShapeOp>::OpRewritePattern; |
| 26 | |
| 27 | LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp, |
| 28 | PatternRewriter &rewriter) const override { |
| 29 | RankedTensorType resultType = expandShapeOp.getResultType(); |
| 30 | auto = |
| 31 | expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>(); |
| 32 | if (!extractSliceOp) |
| 33 | return failure(); |
| 34 | RankedTensorType srcType = extractSliceOp.getSourceType(); |
| 35 | |
| 36 | // Only cases where the ExpandShapeOp can be folded away entirely are |
| 37 | // supported. Moreover, only simple cases where the resulting ExtractSliceOp |
| 38 | // has no rank-reduction anymore are supported at the moment. |
| 39 | RankedTensorType = ExtractSliceOp::inferResultType( |
| 40 | sourceTensorType: srcType, staticOffsets: extractSliceOp.getStaticOffsets(), |
| 41 | staticSizes: extractSliceOp.getStaticSizes(), staticStrides: extractSliceOp.getStaticStrides()); |
| 42 | if (nonReducingExtractType != resultType) |
| 43 | return failure(); |
| 44 | |
| 45 | SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets(); |
| 46 | SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes(); |
| 47 | SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides(); |
| 48 | rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>( |
| 49 | op: expandShapeOp, args: extractSliceOp.getSource(), args&: mixedOffsets, args&: mixedSizes, |
| 50 | args&: mixedStrides); |
| 51 | return success(); |
| 52 | } |
| 53 | }; |
| 54 | |
| 55 | /// Fold collapse_shape which only removes static dimensions of size `1` |
| 56 | /// into extract_slice. |
| 57 | struct |
| 58 | : public OpRewritePattern<tensor::CollapseShapeOp> { |
| 59 | using OpRewritePattern<tensor::CollapseShapeOp>::OpRewritePattern; |
| 60 | |
| 61 | LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseShapeOp, |
| 62 | PatternRewriter &rewriter) const override { |
| 63 | auto = |
| 64 | collapseShapeOp.getSrc().getDefiningOp<tensor::ExtractSliceOp>(); |
| 65 | // Collapse cannot be folded away with multiple users of the extract slice |
| 66 | // and it is not necessarily beneficial to only convert the collapse into |
| 67 | // another extract slice. |
| 68 | if (!extractSliceOp || !extractSliceOp->hasOneUse()) |
| 69 | return failure(); |
| 70 | |
| 71 | // Only fold away simple collapse where all removed dimensions have static |
| 72 | // size `1`. |
| 73 | SliceVerificationResult res = isRankReducedType( |
| 74 | originalType: collapseShapeOp.getSrcType(), candidateReducedType: collapseShapeOp.getResultType()); |
| 75 | if (res != SliceVerificationResult::Success) |
| 76 | return rewriter.notifyMatchFailure(arg&: collapseShapeOp, |
| 77 | msg: "expected unpadding collapse" ); |
| 78 | |
| 79 | Value = rewriter.create<tensor::ExtractSliceOp>( |
| 80 | location: extractSliceOp.getLoc(), args: collapseShapeOp.getResultType(), |
| 81 | args: extractSliceOp.getSource(), args: extractSliceOp.getMixedOffsets(), |
| 82 | args: extractSliceOp.getMixedSizes(), args: extractSliceOp.getMixedStrides()); |
| 83 | rewriter.replaceOp(op: collapseShapeOp, newValues: unPaddedExtractSlice); |
| 84 | return success(); |
| 85 | } |
| 86 | }; |
| 87 | |
| 88 | /// Fold insert_slice(collapse_shape) ops that cancel itself out. |
| 89 | template <typename OpTy> |
| 90 | struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> { |
| 91 | using OpRewritePattern<OpTy>::OpRewritePattern; |
| 92 | |
| 93 | LogicalResult matchAndRewrite(OpTy insertSliceOp, |
| 94 | PatternRewriter &rewriter) const override { |
| 95 | auto collapseShapeOp = |
| 96 | insertSliceOp.getSource().template getDefiningOp<CollapseShapeOp>(); |
| 97 | if (!collapseShapeOp) |
| 98 | return failure(); |
| 99 | RankedTensorType srcType = collapseShapeOp.getSrcType(); |
| 100 | |
| 101 | // Only cases where the CollapseShapeOp can be folded away entirely are |
| 102 | // supported. Moreover, only simple cases where the resulting InsertSliceOp |
| 103 | // has no rank-reduction anymore are supported at the moment. |
| 104 | RankedTensorType nonReducingInsertType = |
| 105 | RankedTensorType::get(shape: insertSliceOp.getStaticSizes(), |
| 106 | elementType: insertSliceOp.getDestType().getElementType()); |
| 107 | if (nonReducingInsertType != srcType) |
| 108 | return failure(); |
| 109 | |
| 110 | SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets(); |
| 111 | SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes(); |
| 112 | SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides(); |
| 113 | rewriter.replaceOpWithNewOp<OpTy>(insertSliceOp, collapseShapeOp.getSrc(), |
| 114 | insertSliceOp.getDest(), mixedOffsets, |
| 115 | mixedSizes, mixedStrides); |
| 116 | return success(); |
| 117 | } |
| 118 | }; |
| 119 | |
| 120 | /// Fold expand_shape which only adds static dimensions of size `1` |
| 121 | /// into insert_slice. |
| 122 | template <typename OpTy> |
| 123 | struct FoldPaddingExpandIntoInsert : public OpRewritePattern<OpTy> { |
| 124 | using OpRewritePattern<OpTy>::OpRewritePattern; |
| 125 | |
| 126 | LogicalResult matchAndRewrite(OpTy insertSliceOp, |
| 127 | PatternRewriter &rewriter) const override { |
| 128 | auto expandShapeOp = insertSliceOp.getSource() |
| 129 | .template getDefiningOp<tensor::ExpandShapeOp>(); |
| 130 | if (!expandShapeOp) |
| 131 | return failure(); |
| 132 | |
| 133 | // Only fold away simple expansion where all added dimensions have static |
| 134 | // size `1`. |
| 135 | SliceVerificationResult res = isRankReducedType( |
| 136 | expandShapeOp.getResultType(), expandShapeOp.getSrcType()); |
| 137 | if (res != SliceVerificationResult::Success) |
| 138 | return rewriter.notifyMatchFailure(insertSliceOp, |
| 139 | "expected rank increasing expansion" ); |
| 140 | |
| 141 | rewriter.modifyOpInPlace(insertSliceOp, [&]() { |
| 142 | insertSliceOp.getSourceMutable().assign(expandShapeOp.getSrc()); |
| 143 | }); |
| 144 | return success(); |
| 145 | } |
| 146 | }; |
| 147 | |
| 148 | /// Pattern to bubble up a tensor.expand_shape op through a producer |
| 149 | /// tensor.collapse_shape op that has non intersecting reassociations. |
| 150 | struct BubbleUpExpandThroughParallelCollapse |
| 151 | : public OpRewritePattern<tensor::ExpandShapeOp> { |
| 152 | using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern; |
| 153 | |
| 154 | LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp, |
| 155 | PatternRewriter &rewriter) const override { |
| 156 | auto collapseOp = |
| 157 | expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>(); |
| 158 | if (!collapseOp) |
| 159 | return failure(); |
| 160 | auto expandReInds = expandOp.getReassociationIndices(); |
| 161 | auto collapseReInds = collapseOp.getReassociationIndices(); |
| 162 | |
| 163 | // Special case where the collapsed tensor to expand is a 0-D tensor, |
| 164 | // then the reassociation maps will be empty and not produce valid results. |
| 165 | if (expandReInds.size() == 0) { |
| 166 | return failure(); |
| 167 | } |
| 168 | |
| 169 | // Reshapes are parallel to each other (by construction the number of |
| 170 | // reassociations specified in the collapse and expand are the same), if at |
| 171 | // any position |
| 172 | // 1. either the reassociation indices are of the same size, or |
| 173 | // 2. either the reassociation in the collapse or the expand is of size 1. |
| 174 | ArrayRef<int64_t> staticSourceSize = collapseOp.getSrcType().getShape(); |
| 175 | ArrayRef<int64_t> staticResultSize = expandOp.getStaticOutputShape(); |
| 176 | for (auto [expandReassociation, collapseReassociation] : |
| 177 | llvm::zip_equal(t&: expandReInds, u&: collapseReInds)) { |
| 178 | if (collapseReassociation.size() == expandReassociation.size()) { |
| 179 | // Even if the reassociations are the same, the collapse/expand should |
| 180 | // result in the same dimensions. i.e 4x8x2 into 64 should be expanded |
| 181 | // into 4x8x2 again. In presense of dynamic dimensions one can only |
| 182 | // verify "equality" when there is only one dynamic dimension present, |
| 183 | // and all other static dimensions are equal. |
| 184 | ArrayRef<int64_t> collapsedStaticShapes = staticSourceSize.slice( |
| 185 | N: collapseReassociation.front(), M: collapseReassociation.size()); |
| 186 | int64_t numCollapsedDynamic = |
| 187 | llvm::count_if(Range&: collapsedStaticShapes, P: ShapedType::isDynamic); |
| 188 | ArrayRef<int64_t> expandedStaticShapes = staticResultSize.slice( |
| 189 | N: expandReassociation.front(), M: expandReassociation.size()); |
| 190 | int64_t numExpandedDynamic = |
| 191 | llvm::count_if(Range&: expandedStaticShapes, P: ShapedType::isDynamic); |
| 192 | if (numCollapsedDynamic > 1 || numExpandedDynamic > 1 || |
| 193 | collapsedStaticShapes != expandedStaticShapes) { |
| 194 | return failure(); |
| 195 | } |
| 196 | continue; |
| 197 | } |
| 198 | // If the reassociations are not same, one or the other needs to be of |
| 199 | // size one. |
| 200 | if (collapseReassociation.size() != 1 && expandReassociation.size() != 1) |
| 201 | return failure(); |
| 202 | } |
| 203 | |
| 204 | // Compute new reassociation indices and expanded/collaped shapes. |
| 205 | SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds; |
| 206 | Location loc = expandOp->getLoc(); |
| 207 | SmallVector<OpFoldResult> sourceSizes = |
| 208 | tensor::getMixedSizes(builder&: rewriter, loc, value: collapseOp.getSrc()); |
| 209 | SmallVector<OpFoldResult> resultSizes = expandOp.getMixedOutputShape(); |
| 210 | SmallVector<OpFoldResult> newExpandSizes; |
| 211 | |
| 212 | int64_t newExpandIndex = 0, newCollapseIndex = 0, sourceSizeIndex = 0, |
| 213 | resultSizeIndex = 0; |
| 214 | |
| 215 | for (size_t idx = 0, idxEnd = collapseReInds.size(); idx < idxEnd; idx++) { |
| 216 | auto &collapseReassociation = collapseReInds[idx]; |
| 217 | auto &expandReassociation = expandReInds[idx]; |
| 218 | |
| 219 | // Case 1. The reassociations are same in the collapse producer |
| 220 | // and expand consumer. In the swapped expand, each of the final |
| 221 | // dimensions are kept as is in the expand and the collapse. So, |
| 222 | // for every element in the `ReassocationIndices` vector add a new |
| 223 | // `ReassociationIndices` vector for the swapped expand and collapse |
| 224 | // (of size 1). |
| 225 | if (collapseReassociation.size() == expandReassociation.size()) { |
| 226 | for (size_t i = 0; i < collapseReassociation.size(); ++i) { |
| 227 | newCollapseReInds.push_back(Elt: {newCollapseIndex++}); |
| 228 | newExpandReInds.push_back(Elt: {newExpandIndex++}); |
| 229 | newExpandSizes.push_back(Elt: resultSizes[resultSizeIndex++]); |
| 230 | sourceSizeIndex++; |
| 231 | } |
| 232 | continue; |
| 233 | } |
| 234 | |
| 235 | // Case 2. The `ReassociationIndices` in the collapse is of size > 1 (and |
| 236 | // in the expand is of size == 1). In this case, the original dimensions |
| 237 | // are preserved on expansion and collapsed subsequently. |
| 238 | if (collapseReassociation.size() != 1) { |
| 239 | ReassociationIndices newCollapseReassociation; |
| 240 | for (size_t i = 0; i < collapseReassociation.size(); ++i) { |
| 241 | newCollapseReassociation.push_back(Elt: newCollapseIndex++); |
| 242 | newExpandReInds.push_back(Elt: {newExpandIndex++}); |
| 243 | newExpandSizes.push_back(Elt: sourceSizes[sourceSizeIndex++]); |
| 244 | } |
| 245 | resultSizeIndex++; |
| 246 | newCollapseReInds.push_back(Elt: newCollapseReassociation); |
| 247 | continue; |
| 248 | } |
| 249 | |
| 250 | // Case 3. The `ReassociationIndices` in the expand is of size > 1 (and |
| 251 | // in the collapse is of size == 1). In this case, the expansion happens |
| 252 | // first and the expanded dimensions are preserved on collapse. |
| 253 | ReassociationIndices newExpandReassociation; |
| 254 | for (size_t i = 0; i < expandReassociation.size(); ++i) { |
| 255 | newExpandReassociation.push_back(Elt: newExpandIndex++); |
| 256 | newCollapseReInds.push_back(Elt: {newCollapseIndex++}); |
| 257 | newExpandSizes.push_back(Elt: resultSizes[resultSizeIndex++]); |
| 258 | } |
| 259 | newExpandReInds.push_back(Elt: newExpandReassociation); |
| 260 | sourceSizeIndex++; |
| 261 | } |
| 262 | |
| 263 | // Swap reshape order. |
| 264 | SmallVector<Value> dynamicSizes; |
| 265 | SmallVector<int64_t> staticSizes; |
| 266 | dispatchIndexOpFoldResults(ofrs: newExpandSizes, dynamicVec&: dynamicSizes, staticVec&: staticSizes); |
| 267 | auto expandResultType = expandOp.getResultType().clone(shape: staticSizes); |
| 268 | Value newCollapseSrc = collapseOp.getSrc(); |
| 269 | // If the number of reassociation indices in the new `expand_shape` op |
| 270 | // matches the number of dimensions of the result, then the expand_shape |
| 271 | // is a no-op. |
| 272 | if (newExpandReInds.size() != newExpandSizes.size()) { |
| 273 | newCollapseSrc = rewriter.create<tensor::ExpandShapeOp>( |
| 274 | location: loc, args&: expandResultType, args&: newCollapseSrc, args&: newExpandReInds, |
| 275 | args&: newExpandSizes); |
| 276 | } |
| 277 | |
| 278 | // If the number of reassociation indices in the new `collapse_shape` op |
| 279 | // matches the number of dimensions of the source, then the collapse_shape |
| 280 | // is a no-op. |
| 281 | Value replacement = newCollapseSrc; |
| 282 | if (newCollapseReInds.size() != newExpandSizes.size()) { |
| 283 | replacement = rewriter.create<tensor::CollapseShapeOp>( |
| 284 | location: loc, args&: newCollapseSrc, args&: newCollapseReInds); |
| 285 | } |
| 286 | rewriter.replaceOp(op: expandOp, newValues: replacement); |
| 287 | return success(); |
| 288 | } |
| 289 | }; |
| 290 | |
| 291 | /// Converts `tensor.extract_slice(tensor.expand_shape)` to |
| 292 | /// `tensor.expand_shape(tensor.extract_slice)`. |
| 293 | /// |
| 294 | /// For this transformation to be possible, the slice must be fully contiguous |
| 295 | /// within each reassociation group of the expand_shape. A slice is defined as |
| 296 | /// fully contiguous within a reassociation group if after flattening the |
| 297 | /// reassociation group to a single 1D range, then the slice taken out of the |
| 298 | /// group could be defined as a single contiguous subrange within that range. |
| 299 | /// |
| 300 | /// Rank reducing slices are not supported. |
| 301 | /// |
| 302 | /// Example: |
| 303 | /// The transformation is possible because each reassociation group has a |
| 304 | /// contiguous slice (i.e., [2x4->2x4], [2x8->1x5], [4x2x4->1x1x4]). |
| 305 | /// ``` |
| 306 | /// BEFORE: |
| 307 | /// %reshape = tensor.expand_shape %in [[0, 1], [2, 3], [4, 5, 6]] |
| 308 | /// tensor<8x16x32xf32> to tensor<2x4x2x8x4x2x4xf32> |
| 309 | /// %slice = tensor.extract_slice %reshape ... |
| 310 | /// tensor<2x4x2x8x4x2x4xf32> to tensor<2x4x1x5x1x1x4xf32> |
| 311 | /// |
| 312 | /// AFTER: |
| 313 | /// %slice = tensor.extract_slice %in ... |
| 314 | /// tensor<8x16x32xf32> to tensor<8x5x4xf32> |
| 315 | /// %reshape = tensor.expand_shape %slice [[0, 1], [2, 3], [4, 5, 6]] |
| 316 | /// tensor<8x5x4xf32> to tensor<2x4x1x5x1x1x4xf32> |
| 317 | /// ``` |
| 318 | /// |
| 319 | /// Note - this pattern could be extended to be a swap pattern between |
| 320 | /// `tensor.expand_shape` and `tensor.extract_slice`, but is currently |
| 321 | /// implemented only as a bubble up pattern for `tensor.extract_slice`. |
| 322 | struct BubbleUpExpandShapeThroughExtractSlice |
| 323 | : public OpRewritePattern<tensor::ExtractSliceOp> { |
| 324 | using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern; |
| 325 | |
| 326 | LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp, |
| 327 | PatternRewriter &rewriter) const override { |
| 328 | auto expandShapeOp = |
| 329 | sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>(); |
| 330 | |
| 331 | if (checkPreconditionForBubbleUpExtractSlice(sliceOp, expandShapeOp, |
| 332 | rewriter) |
| 333 | .failed()) |
| 334 | return failure(); |
| 335 | |
| 336 | // The tensor.extract_slice before applying the pattern works on the result |
| 337 | // of the tensor.expand_shape, so variables (i.e. inputs for ExtractSliceOp) |
| 338 | // referring to the state before applying the pattern are named with the |
| 339 | // prefix "expanded", and ones referring to the state after applying the |
| 340 | // pattern are named with the prefix "collapsed". |
| 341 | SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets(); |
| 342 | SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes(); |
| 343 | SmallVector<OpFoldResult> expandedShape = |
| 344 | getMixedValues(staticValues: expandShapeOp.getStaticOutputShape(), |
| 345 | dynamicValues: expandShapeOp.getOutputShape(), b&: rewriter); |
| 346 | |
| 347 | // Helper variables and function for accumulating the size values. |
| 348 | Location loc = expandShapeOp->getLoc(); |
| 349 | AffineExpr d0, d1, d2; |
| 350 | bindDims(ctx: rewriter.getContext(), exprs&: d0, exprs&: d1, exprs&: d2); |
| 351 | // Multiply two integers. |
| 352 | auto mul = [&](OpFoldResult v1, OpFoldResult v2) { |
| 353 | auto mulMap = AffineMap::get(dimCount: 2, symbolCount: 0, result: {d0 * d1}); |
| 354 | return affine::makeComposedFoldedAffineApply(b&: rewriter, loc, map: mulMap, |
| 355 | operands: {v1, v2}); |
| 356 | }; |
| 357 | |
| 358 | // Compute new offsets, sizes, and strides for tensor.extract_slice. |
| 359 | // The new tensor.extract_slice will work on a tensor that has has a rank of |
| 360 | // ReassociationIndices.size(). In the loop a single offset, size, and |
| 361 | // stride value is computed per reassociation group. |
| 362 | SmallVector<OpFoldResult> collapsedOffsets, collapsedSizes, |
| 363 | collapsedStrides; |
| 364 | for (const ReassociationIndices &indices : |
| 365 | expandShapeOp.getReassociationIndices()) { |
| 366 | // collapsedSize will hold the size of the single dim that represents the |
| 367 | // reassociation group in the non expanded tensor. |
| 368 | OpFoldResult collapsedSize = rewriter.getIndexAttr(value: 1); |
| 369 | // The reassocGroupSizes and reassocGroupOffsets are used to create an |
| 370 | // affine.linearize_index op to linearize the single offset value required |
| 371 | // for this reassociation group. |
| 372 | SmallVector<OpFoldResult> reassocGroupSizes, reassocGroupOffsets; |
| 373 | |
| 374 | for (long expandedDim : indices) { |
| 375 | // reassocGroupSizes and reassocGroupOffsets can be obtained directly |
| 376 | // from the expanded state, but the collapsed size requires calculation |
| 377 | // as it did not previously exist. |
| 378 | reassocGroupSizes.push_back(Elt: expandedShape[expandedDim]); |
| 379 | reassocGroupOffsets.push_back(Elt: expandedOffsets[expandedDim]); |
| 380 | collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]); |
| 381 | } |
| 382 | |
| 383 | SmallVector<Value> offsetVals = |
| 384 | llvm::map_to_vector(C&: reassocGroupOffsets, F: [&](OpFoldResult ofr) { |
| 385 | return getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr); |
| 386 | }); |
| 387 | OpFoldResult collapsedOffset = |
| 388 | rewriter |
| 389 | .create<affine::AffineLinearizeIndexOp>(location: loc, args&: offsetVals, |
| 390 | args&: reassocGroupSizes, |
| 391 | /*disjoint=*/args: true) |
| 392 | .getResult(); |
| 393 | collapsedOffsets.push_back(Elt: collapsedOffset); |
| 394 | collapsedSizes.push_back(Elt: collapsedSize); |
| 395 | |
| 396 | // Only unit stride is supported. |
| 397 | collapsedStrides.push_back(Elt: rewriter.getIndexAttr(value: 1)); |
| 398 | } |
| 399 | |
| 400 | // The shape of the result can be obtained from the sizes passed in. |
| 401 | SmallVector<Value> dynDims; |
| 402 | SmallVector<int64_t> shape; |
| 403 | dispatchIndexOpFoldResults(ofrs: expandedSizes, dynamicVec&: dynDims, staticVec&: shape); |
| 404 | RankedTensorType resultType = RankedTensorType::get( |
| 405 | shape, elementType: expandShapeOp.getResultType().getElementType()); |
| 406 | |
| 407 | // Create a new ExtractSliceOp and ExpandShapeOp. |
| 408 | Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>( |
| 409 | location: loc, args: expandShapeOp.getSrc(), args&: collapsedOffsets, args&: collapsedSizes, |
| 410 | args&: collapsedStrides); |
| 411 | rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>( |
| 412 | op: sliceOp, args&: resultType, args&: newSliceOp, |
| 413 | args: expandShapeOp.getReassociationIndices(), args&: expandedSizes); |
| 414 | return success(); |
| 415 | } |
| 416 | |
| 417 | // Helper function to check if all the required conditions for the |
| 418 | // tensor.extract_slice to be bubbled up through the tensor.expand_shape are |
| 419 | // met. |
| 420 | LogicalResult |
| 421 | checkPreconditionForBubbleUpExtractSlice(tensor::ExtractSliceOp sliceOp, |
| 422 | tensor::ExpandShapeOp expandShapeOp, |
| 423 | PatternRewriter &rewriter) const { |
| 424 | |
| 425 | if (!expandShapeOp) { |
| 426 | return rewriter.notifyMatchFailure( |
| 427 | arg&: sliceOp, msg: "tensor.extract_slice source not produced by expand_shape" ); |
| 428 | } |
| 429 | |
| 430 | if (!sliceOp.hasUnitStride()) { |
| 431 | return rewriter.notifyMatchFailure( |
| 432 | arg&: sliceOp, msg: "unsupported: non-unit stride. Only contiguous slices can " |
| 433 | "be supported in this transformation." ); |
| 434 | } |
| 435 | |
| 436 | SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets(); |
| 437 | SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes(); |
| 438 | |
| 439 | if (static_cast<size_t>(sliceOp.getResultType().getRank()) != |
| 440 | sizes.size()) { |
| 441 | return rewriter.notifyMatchFailure(arg&: sliceOp, |
| 442 | msg: "unimplemented: rank reducing slice" ); |
| 443 | } |
| 444 | |
| 445 | SmallVector<OpFoldResult> outputShape = |
| 446 | getMixedValues(staticValues: expandShapeOp.getStaticOutputShape(), |
| 447 | dynamicValues: expandShapeOp.getOutputShape(), b&: rewriter); |
| 448 | |
| 449 | std::function<bool(OpFoldResult, OpFoldResult, OpFoldResult)> |
| 450 | isZeroOffsetAndFullSize = |
| 451 | [](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) { |
| 452 | if (!isZeroInteger(v: offset)) |
| 453 | return false; |
| 454 | FailureOr<bool> maybeEqual = |
| 455 | ValueBoundsConstraintSet::areEqual(var1: sliceSize, var2: size); |
| 456 | return llvm::succeeded(Result: maybeEqual) && maybeEqual.value(); |
| 457 | }; |
| 458 | |
| 459 | // Check that the slice is contiguous within each reassociation group. |
| 460 | // The slice is contiguous only if after the first dimension where a non |
| 461 | // unit slice is taken, the slice size on all subsequent dimensions of the |
| 462 | // group is equal to the entire size of the dimension. |
| 463 | // Examples of contiguous slices: |
| 464 | // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10] |
| 465 | // full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10] |
| 466 | // Examples of non contiguous slices: |
| 467 | // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5] |
| 468 | // full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5] |
| 469 | for (const ReassociationIndices &indices : |
| 470 | expandShapeOp.getReassociationIndices()) { |
| 471 | int64_t i = 0; |
| 472 | int64_t e = indices.size(); |
| 473 | // Find the first expanded dim after the first dim with non-unit extracted |
| 474 | // size. |
| 475 | for (; i < e; ++i) { |
| 476 | if (!isOneInteger(v: sizes[indices[i]])) { |
| 477 | // +1 to skip the first non-unit size dim. |
| 478 | i++; |
| 479 | break; |
| 480 | } |
| 481 | } |
| 482 | |
| 483 | // Verify that all subsequent dimensions extract the full size of the |
| 484 | // source tensor. |
| 485 | for (; i < e; ++i) { |
| 486 | int64_t expandedDim = indices[i]; |
| 487 | if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim], |
| 488 | outputShape[expandedDim])) { |
| 489 | return rewriter.notifyMatchFailure( |
| 490 | arg&: sliceOp, msg: "Not a contiguous slice of the expanded tensor." ); |
| 491 | } |
| 492 | } |
| 493 | } |
| 494 | |
| 495 | return success(); |
| 496 | } |
| 497 | }; |
| 498 | |
| 499 | /// Converts `tensor.extract_slice(tensor.collapse_shape)` to |
| 500 | /// `tensor.collapse_shape(tensor.extract_slice)`. |
| 501 | /// |
| 502 | /// For this transformation to be possible - after bubbling up, the extraction |
| 503 | /// of the contiguous slice must be representable as a single slice obtained via |
| 504 | /// tensor.extract_slice within each reassociation group of the src. |
| 505 | /// |
| 506 | /// In case the size and offset extracted are static then this is possible if |
| 507 | /// the following conditions are met within each reassociation group: |
| 508 | /// Let T be a tensor of shape [A0, A1, ..., An] (these are the sizes of the |
| 509 | /// dimensions in the reassociation group), and let S = [S0, S1, ..., Sn] be the |
| 510 | /// shape of a desired slice. A slice of shape S can be extracted as a |
| 511 | /// contiguous span of elements if and only if there exists an index k in {0, 1, |
| 512 | /// ..., n} such that: |
| 513 | /// S_i = 1 for all i < k (that is, all leading dimensions are singleton), |
| 514 | /// 1 <= S_k <= A_k (that is, non trivial slicing occurs along exactly |
| 515 | /// one dimension), |
| 516 | /// S_i = A_i for all i > k (that is, all trailing dimensions are preserved |
| 517 | /// in full). |
| 518 | /// In other words, the slice shape S must be of the form: |
| 519 | /// [ 1, 1, ..., 1, Sk, Ak + 1, Ak + 2, ...,An ] |
| 520 | /// |
| 521 | /// In case the size and/or offset extracted are dynamic then this is possible |
| 522 | /// only if there is single dimension in the reassociation group that has a size |
| 523 | /// not equal to 1. |
| 524 | /// In other words, the tensor shape must be of the form: |
| 525 | /// [ 1, 1, ..., 1, A, 1, ...,1 ] |
| 526 | /// Note - it might be possible to enable this pattern for more cases when the |
| 527 | /// size/offset are dynamic via performing an analysis of the possible values |
| 528 | /// that could be given to the size/offset. |
| 529 | /// |
| 530 | /// Example: |
| 531 | /// The transformation is possible because each reassociation group can be |
| 532 | /// represented as a contiguous slice (i.e., [8x16->2x16], [1x7->1x?], |
| 533 | /// [20->10]). |
| 534 | /// ``` |
| 535 | /// BEFORE: |
| 536 | /// %collapse = tensor.collapse_shape %src [[0, 1], [2, 3], [4]] ... |
| 537 | /// tensor<8x16x1x7x20f32> to tensor<128x7x20xf32> |
/// %slice = tensor.extract_slice %collapse [0, 0, 0][32, %size, 10][1, 1, 1]
| 539 | /// tensor<128x7x20xf32> to tensor<32x?x10xf32> |
| 540 | /// |
| 541 | /// AFTER: |
| 542 | /// %slice = tensor.extract_slice %src [0, 0, 0, 0, 0][2, 16, 1, %size, 10] |
| 543 | // [1, 1, 1, 1, 1] : tensor<8x16x1x7x20f32> to tensor<2x16x1x?x10xf32> |
| 544 | /// %collapse = tensor.collapse_shape %slice [[0, 1], [2, 3], [4]] ... |
| 545 | /// tensor<2x16x1x?x10xf32> to tensor<32x?x10xf32> |
| 546 | /// ``` |
| 547 | /// |
| 548 | /// Negative example: |
| 549 | /// The transformation is not possible because we cannot use a single slice to |
| 550 | /// represent the reassociation group [2x3x10->???]. If we would want the |
| 551 | /// collapse to be after the extraction, we would need to extract multiple |
| 552 | /// slices and concat them together. |
| 553 | /// ``` |
| 554 | /// %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<2x3x10xf32> into |
| 555 | /// tensor<60xf32> %extract = tensor.extract_slice %collapse[0][15][1] : |
| 556 | /// tensor<60xf32> to tensor<15xf32> |
| 557 | /// ``` |
| 558 | /// If we would want the collapse to be after the extraction, a possible |
| 559 | /// alternate transformation could be to extract multiple slices and concat them |
| 560 | /// together: |
| 561 | /// ``` |
| 562 | /// %extract_1 = tensor.extract_slice %src[0, 0, 0][1, 1, 10] : |
| 563 | /// tensor<2x3x10xf32> to tensor <1x1x10xf32> |
| 564 | /// %extract_2 = tensor.extract_slice %src[0, 1, 0][1, 1, 5] : |
| 565 | /// tensor<2x3x10xf32> to tensor <1x1x5xf32> |
| 566 | /// %concat = tosa.concat %extract_1, %extract_2 {axis = 0 : i32} : |
| 567 | /// (<1x1x10xf32>, <1x1x5xf32>) -> <1x1x15xf32> |
| 568 | /// %collapse = tensor.collapse_shape %concat [[0, 1, 2]] : tensor<1x1x15xf32> |
| 569 | /// to tensor<15xf32> |
| 570 | /// ``` |
| 571 | /// But this is not the intended purpose of the transformation. |
| 572 | struct |
| 573 | : public OpRewritePattern<tensor::ExtractSliceOp> { |
| 574 | using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern; |
| 575 | |
| 576 | LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp, |
| 577 | PatternRewriter &rewriter) const override { |
| 578 | auto collapseShapeOp = |
| 579 | sliceOp.getSource().getDefiningOp<tensor::CollapseShapeOp>(); |
| 580 | if (!collapseShapeOp) { |
| 581 | return rewriter.notifyMatchFailure( |
| 582 | arg&: sliceOp, |
| 583 | msg: "tensor.extract_slice source not produced by tensor.collapse_shape" ); |
| 584 | } |
| 585 | |
| 586 | if (!sliceOp.hasUnitStride()) { |
| 587 | return rewriter.notifyMatchFailure( |
| 588 | arg&: sliceOp, msg: "unsupported: non-unit stride. Only contiguous slices can " |
| 589 | "be supported in this transformation." ); |
| 590 | } |
| 591 | |
| 592 | // The tensor.extract_slice before applying the pattern works on the result |
| 593 | // of the tensor.collapse_shape, so variables (i.e. inputs for |
| 594 | // ExtractSliceOp) referring to the state before applying the pattern are |
| 595 | // named with the prefix "collapsed", and ones referring to the state after |
| 596 | // applying the pattern are named with the prefix "expanded". |
| 597 | SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets(); |
| 598 | SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes(); |
| 599 | |
| 600 | if (static_cast<size_t>(sliceOp.getResultType().getRank()) != |
| 601 | collapsedSizes.size()) { |
| 602 | return rewriter.notifyMatchFailure(arg&: sliceOp, |
| 603 | msg: "unimplemented: rank reducing slice" ); |
| 604 | } |
| 605 | |
| 606 | ArrayRef<int64_t> srcShape = collapseShapeOp.getSrcType().getShape(); |
| 607 | SmallVector<ReassociationIndices, 4> reassociationIndices = |
| 608 | collapseShapeOp.getReassociationIndices(); |
| 609 | |
| 610 | // Compute new offsets, sizes, and strides for tensor.extract_slice. |
| 611 | // The new tensor.extract_slice will work on a tensor that has has a rank |
| 612 | // equal to the rank of the src of the collapse_shape. In each iteration of |
| 613 | // the loop, the offsets and sizes will be computed per reassociation group. |
| 614 | SmallVector<OpFoldResult> expandedOffsets, expandedSizes; |
| 615 | SmallVector<OpFoldResult> expandedStrides(srcShape.size(), |
| 616 | rewriter.getIndexAttr(value: 1)); |
| 617 | |
| 618 | for (auto [collapsedSize, collapsedOffset, reassocIndices] : |
| 619 | llvm::zip_equal(t&: collapsedSizes, u&: collapsedOffsets, |
| 620 | args: collapseShapeOp.getReassociationIndices())) { |
| 621 | // CASE #1 - size and/or offset are dynamic. |
| 622 | // In this case, the slice can be represented as a contiguous slice only |
| 623 | // if there is a single dimension in the reassociation group that has a |
| 624 | // size not equal to 1. |
| 625 | if (isa<Value>(Val: collapsedSize) || isa<Value>(Val: collapsedOffset)) { |
| 626 | int nonUnitSizeCount = 0; |
| 627 | for (int64_t expandedShapeIdx : reassocIndices) { |
| 628 | if (srcShape[expandedShapeIdx] != 1) { |
| 629 | nonUnitSizeCount++; |
| 630 | expandedSizes.push_back(Elt: collapsedSize); |
| 631 | expandedOffsets.push_back(Elt: collapsedOffset); |
| 632 | continue; |
| 633 | } |
| 634 | |
| 635 | expandedSizes.push_back(Elt: rewriter.getIndexAttr(value: 1)); |
| 636 | expandedOffsets.push_back(Elt: rewriter.getIndexAttr(value: 0)); |
| 637 | } |
| 638 | |
| 639 | if (nonUnitSizeCount != 1) { |
| 640 | return rewriter.notifyMatchFailure( |
| 641 | arg&: sliceOp, |
| 642 | msg: "unsupported: slice cannot be verified to be contiguous" ); |
| 643 | } |
| 644 | continue; |
| 645 | } |
| 646 | |
      // CASE #2 - size and offset are static.
| 648 | // Verify that the slice can be represented as a contiguous slice of the |
| 649 | // src of the collapse_shape. |
| 650 | // Checking this is done on order of most internal dimensions first, |
| 651 | // so traversal is done in reverse order of the reassociation group. |
| 652 | // If the expected slice shape is [1, 1, ..., 1, Sk, Ak + 1, Ak + 2, |
| 653 | // ...,An] then we first find the size and offset for n...k+1 then for k |
| 654 | // and then for k-1...0. |
| 655 | |
| 656 | // currentCollapsedsize and currentCollapsedOffset are initialized with |
| 657 | // the original collapsed size and offset and divided by the expanded |
| 658 | // shape size in each dimension as we go along the reassociation group. |
| 659 | // In essence we are spreading the original collapsed size and offset over |
| 660 | // the various expanded slice dimensions. |
| 661 | // The variables are used both to check the validity of the slice and to |
| 662 | // compute the expanded sizes and offsets. |
| 663 | int64_t currentCollapsedsize = getConstantIntValue(ofr: collapsedSize).value(); |
| 664 | int64_t currentCollapsedOffset = |
| 665 | getConstantIntValue(ofr: collapsedOffset).value(); |
| 666 | |
| 667 | SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets; |
| 668 | |
| 669 | ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(), |
| 670 | reassocIndices.rend()); |
| 671 | int64_t idx = 0; |
| 672 | int64_t reassocGroupSize = reassocIndices.size(); |
| 673 | |
| 674 | // First handle the trailing dimensions where the slice size should be |
| 675 | // equal to the tensor shape and the offset should be 0 (n...k+1). |
| 676 | for (; idx < reassocGroupSize; ++idx) { |
| 677 | int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]]; |
| 678 | |
| 679 | if (currentCollapsedsize < expandedShapeSize) |
| 680 | break; |
| 681 | |
| 682 | // We need to make sure that the slice size can be set to the shape size |
| 683 | // and the offset to 0. |
| 684 | if ((currentCollapsedsize % expandedShapeSize) != 0 || |
| 685 | (currentCollapsedOffset % expandedShapeSize) != 0) { |
| 686 | return rewriter.notifyMatchFailure( |
| 687 | arg&: sliceOp, msg: "unsupported: cannot be extracted as a contiguous slice " |
| 688 | "of the src of the collapse_shape" ); |
| 689 | } |
| 690 | |
| 691 | groupExpandedSizes.push_back(Elt: rewriter.getIndexAttr(value: expandedShapeSize)); |
| 692 | groupExpandedOffsets.push_back(Elt: rewriter.getIndexAttr(value: 0)); |
| 693 | |
| 694 | currentCollapsedsize /= expandedShapeSize; |
| 695 | currentCollapsedOffset /= expandedShapeSize; |
| 696 | } |
| 697 | |
| 698 | // Now handle the first dim where slicing occurs on (k). |
| 699 | if (idx < reassocGroupSize) { |
| 700 | int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]]; |
| 701 | int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize; |
| 702 | // We need to make sure that the slice size in this dim + offset will |
| 703 | // not exceed the shape size. |
| 704 | if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) { |
| 705 | return rewriter.notifyMatchFailure( |
| 706 | arg&: sliceOp, msg: "unsupported: slice cannot be extracted as a contiguous " |
| 707 | "slice of the src of the collapse_shape" ); |
| 708 | } |
| 709 | |
| 710 | groupExpandedSizes.push_back( |
| 711 | Elt: rewriter.getIndexAttr(value: currentCollapsedsize)); |
| 712 | groupExpandedOffsets.push_back(Elt: rewriter.getIndexAttr(value: offsetInDim)); |
| 713 | |
| 714 | currentCollapsedOffset /= expandedShapeSize; |
| 715 | } |
| 716 | |
| 717 | // Now handle the leading dimensions where the slice size is equal to 1 |
| 718 | // (k-1...0). |
| 719 | // The size for these dimensions must be 1 because of how we constructed |
| 720 | // the slice size of the expanded shape. We spread the original collapsed |
| 721 | // size over the expanded shape sizes until we reached dimension k where |
| 722 | // the remaining size was smaller than the expanded shape size, and spread |
| 723 | // the remaining size on it. So, now we are left with only 1s. |
| 724 | for (idx++; idx < reassocGroupSize; ++idx) { |
| 725 | int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]]; |
| 726 | int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize; |
| 727 | groupExpandedSizes.push_back(Elt: rewriter.getIndexAttr(value: 1)); |
| 728 | groupExpandedOffsets.push_back(Elt: rewriter.getIndexAttr(value: offsetInDim)); |
| 729 | currentCollapsedOffset /= expandedShapeSize; |
| 730 | } |
| 731 | |
| 732 | expandedSizes.append(in_start: groupExpandedSizes.rbegin(), |
| 733 | in_end: groupExpandedSizes.rend()); |
| 734 | expandedOffsets.append(in_start: groupExpandedOffsets.rbegin(), |
| 735 | in_end: groupExpandedOffsets.rend()); |
| 736 | } |
| 737 | |
| 738 | Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>( |
| 739 | location: collapseShapeOp->getLoc(), args: collapseShapeOp.getSrc(), args&: expandedOffsets, |
| 740 | args&: expandedSizes, args&: expandedStrides); |
| 741 | rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>( |
| 742 | op: sliceOp, args: sliceOp.getResultType(), args&: newSliceOp, |
| 743 | args: collapseShapeOp.getReassociationIndices()); |
| 744 | |
| 745 | return success(); |
| 746 | } |
| 747 | }; |
| 748 | |
| 749 | } // namespace |
| 750 | |
| 751 | void mlir::tensor::populateReassociativeReshapeFoldingPatterns( |
| 752 | RewritePatternSet &patterns) { |
| 753 | patterns |
| 754 | .add<FoldExpandOfRankReducingExtract, FoldUnPaddingCollapseIntoExtract, |
| 755 | FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>, |
| 756 | FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>, |
| 757 | FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>, |
| 758 | FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>( |
| 759 | arg: patterns.getContext()); |
| 760 | } |
| 761 | |
| 762 | void mlir::tensor::populateBubbleUpExpandShapePatterns( |
| 763 | RewritePatternSet &patterns) { |
| 764 | patterns.add<BubbleUpExpandThroughParallelCollapse>(arg: patterns.getContext()); |
| 765 | } |
| 766 | |
| 767 | void mlir::tensor::( |
| 768 | RewritePatternSet &patterns) { |
| 769 | patterns.add<BubbleUpExpandShapeThroughExtractSlice, |
| 770 | BubbleUpCollapseShapeThroughExtractSlice>(arg: patterns.getContext()); |
| 771 | } |
| 772 | |