| 1 | //===- SwapExtractSliceWithProducerPatterns.cpp ---------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
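// Swap a `tensor.extract_slice` with the producer of its source if the producer
// implements the `TilingInterface`. When used in conjunction with tiling this
// effectively tiles + fuses the producer with its consumer. The file also
// provides the consumer-side counterpart: tiling a `TilingInterface` consumer
// using the offsets and sizes of a set of `tensor.insert_slice` ops.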
| 12 | // |
| 13 | //===----------------------------------------------------------------------===// |
| 14 | |
| 15 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 16 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| 17 | #include "mlir/Dialect/Tensor/Transforms/Transforms.h" |
| 18 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
| 19 | #include "mlir/Interfaces/TilingInterface.h" |
| 20 | #include "llvm/Support/Debug.h" |
| 21 | |
| 22 | #define DEBUG_TYPE "tensor-swap-slices" |
| 23 | |
| 24 | using namespace mlir; |
| 25 | |
| 26 | FailureOr<TilingResult> tensor::( |
| 27 | OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producer) { |
| 28 | auto producerOp = dyn_cast<TilingInterface>(Val: producer.getOwner()); |
| 29 | if (!producerOp) |
| 30 | return failure(); |
| 31 | |
| 32 | // `TilingInterface` currently only supports strides being 1. |
| 33 | if (!llvm::all_of(Range: sliceOp.getMixedStrides(), P: isOneInteger)) |
| 34 | return failure(); |
| 35 | |
| 36 | FailureOr<TilingResult> tiledResult = producerOp.generateResultTileValue( |
| 37 | b&: builder, resultNumber: producer.getResultNumber(), offsets: sliceOp.getMixedOffsets(), |
| 38 | sizes: sliceOp.getMixedSizes()); |
| 39 | if (failed(Result: tiledResult)) |
| 40 | return failure(); |
| 41 | |
| 42 | // For cases where the slice was rank-reducing, create a rank-reducing slice |
| 43 | // to get the same type back. |
| 44 | llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims(); |
| 45 | if (droppedDims.any()) { |
| 46 | assert(tiledResult->tiledValues.size() == 1 && |
| 47 | "expected only a single tiled result value to replace the extract " |
| 48 | "slice" ); |
| 49 | SmallVector<OpFoldResult> offsets(sliceOp.getSourceType().getRank(), |
| 50 | builder.getIndexAttr(value: 0)); |
| 51 | SmallVector<OpFoldResult> strides(sliceOp.getSourceType().getRank(), |
| 52 | builder.getIndexAttr(value: 1)); |
| 53 | auto newSliceOp = builder.create<tensor::ExtractSliceOp>( |
| 54 | location: sliceOp.getLoc(), args: sliceOp.getType(), args&: tiledResult->tiledValues[0], |
| 55 | args&: offsets, args: sliceOp.getMixedSizes(), args&: strides); |
| 56 | tiledResult->tiledValues[0] = newSliceOp; |
| 57 | } |
| 58 | |
| 59 | return *tiledResult; |
| 60 | } |
| 61 | |
| 62 | FailureOr<TilingResult> tensor::replaceInsertSlicesWithTiledConsumer( |
| 63 | OpBuilder &builder, ArrayRef<tensor::InsertSliceOp> sliceOps, |
| 64 | ArrayRef<OpOperand *> consumerOperands) { |
| 65 | if (sliceOps.empty()) { |
| 66 | LLVM_DEBUG( |
| 67 | { llvm::dbgs() << "expected candidate slices list to be non-empty" ; }); |
| 68 | return failure(); |
| 69 | } |
| 70 | if (sliceOps.size() != consumerOperands.size()) { |
| 71 | LLVM_DEBUG({ |
| 72 | llvm::dbgs() |
| 73 | << "expected as many operands as the number of slices passed" ; |
| 74 | }); |
| 75 | return failure(); |
| 76 | } |
| 77 | auto consumerOp = |
| 78 | dyn_cast<TilingInterface>(Val: consumerOperands.front()->getOwner()); |
| 79 | if (!consumerOp) |
| 80 | return failure(); |
| 81 | for (auto opOperand : consumerOperands.drop_front()) { |
| 82 | if (opOperand->getOwner() != consumerOp) { |
| 83 | LLVM_DEBUG({ |
| 84 | llvm::dbgs() |
| 85 | << "expected all consumer operands to be from the same operation" ; |
| 86 | }); |
| 87 | return failure(); |
| 88 | } |
| 89 | } |
| 90 | |
| 91 | auto consumerOperandNums = llvm::map_to_vector( |
| 92 | C&: consumerOperands, F: [](OpOperand *opOperand) -> unsigned { |
| 93 | return opOperand->getOperandNumber(); |
| 94 | }); |
| 95 | SmallVector<SmallVector<OpFoldResult>> allOffsets; |
| 96 | SmallVector<SmallVector<OpFoldResult>> allSizes; |
| 97 | for (auto sliceOp : sliceOps) { |
| 98 | |
| 99 | // `TilingInterface` currently only supports strides being 1. |
| 100 | if (!llvm::all_of(Range: sliceOp.getMixedStrides(), P: isOneInteger)) |
| 101 | return failure(); |
| 102 | |
| 103 | SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets(); |
| 104 | SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes(); |
| 105 | allOffsets.emplace_back(Args: std::move(offsets)); |
| 106 | allSizes.emplace_back(Args: std::move(sizes)); |
| 107 | } |
| 108 | FailureOr<TilingResult> tiledResult = |
| 109 | consumerOp.getTiledImplementationFromOperandTiles( |
| 110 | b&: builder, operandNumbers: consumerOperandNums, allOffsets, allSizes); |
| 111 | if (failed(Result: tiledResult)) |
| 112 | return failure(); |
| 113 | |
| 114 | return *tiledResult; |
| 115 | } |
| 116 | |