1//===- SwapExtractSliceWithProducerPatterns.cpp ---------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
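// Swap a `tensor.extract_slice` with the producer of its source if the producer
// implements the `TilingInterface`. When used in conjunction with tiling, this
// effectively tiles + fuses the producer with its consumer. Conversely, a
// consumer implementing the `TilingInterface` can be tiled from the offsets and
// sizes of a list of `tensor.insert_slice` ops, one per consumer operand.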
12//
13//===----------------------------------------------------------------------===//
14
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/Tensor/IR/Tensor.h"
17#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
18#include "mlir/Dialect/Utils/StaticValueUtils.h"
19#include "mlir/Interfaces/TilingInterface.h"
20#include "llvm/Support/Debug.h"
21
22#define DEBUG_TYPE "tensor-swap-slices"
23
24using namespace mlir;
25
26FailureOr<TilingResult> tensor::replaceExtractSliceWithTiledProducer(
27 OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producer) {
28 auto producerOp = dyn_cast<TilingInterface>(Val: producer.getOwner());
29 if (!producerOp)
30 return failure();
31
32 // `TilingInterface` currently only supports strides being 1.
33 if (!llvm::all_of(Range: sliceOp.getMixedStrides(), P: isOneInteger))
34 return failure();
35
36 FailureOr<TilingResult> tiledResult = producerOp.generateResultTileValue(
37 b&: builder, resultNumber: producer.getResultNumber(), offsets: sliceOp.getMixedOffsets(),
38 sizes: sliceOp.getMixedSizes());
39 if (failed(Result: tiledResult))
40 return failure();
41
42 // For cases where the slice was rank-reducing, create a rank-reducing slice
43 // to get the same type back.
44 llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
45 if (droppedDims.any()) {
46 assert(tiledResult->tiledValues.size() == 1 &&
47 "expected only a single tiled result value to replace the extract "
48 "slice");
49 SmallVector<OpFoldResult> offsets(sliceOp.getSourceType().getRank(),
50 builder.getIndexAttr(value: 0));
51 SmallVector<OpFoldResult> strides(sliceOp.getSourceType().getRank(),
52 builder.getIndexAttr(value: 1));
53 auto newSliceOp = builder.create<tensor::ExtractSliceOp>(
54 location: sliceOp.getLoc(), args: sliceOp.getType(), args&: tiledResult->tiledValues[0],
55 args&: offsets, args: sliceOp.getMixedSizes(), args&: strides);
56 tiledResult->tiledValues[0] = newSliceOp;
57 }
58
59 return *tiledResult;
60}
61
62FailureOr<TilingResult> tensor::replaceInsertSlicesWithTiledConsumer(
63 OpBuilder &builder, ArrayRef<tensor::InsertSliceOp> sliceOps,
64 ArrayRef<OpOperand *> consumerOperands) {
65 if (sliceOps.empty()) {
66 LLVM_DEBUG(
67 { llvm::dbgs() << "expected candidate slices list to be non-empty"; });
68 return failure();
69 }
70 if (sliceOps.size() != consumerOperands.size()) {
71 LLVM_DEBUG({
72 llvm::dbgs()
73 << "expected as many operands as the number of slices passed";
74 });
75 return failure();
76 }
77 auto consumerOp =
78 dyn_cast<TilingInterface>(Val: consumerOperands.front()->getOwner());
79 if (!consumerOp)
80 return failure();
81 for (auto opOperand : consumerOperands.drop_front()) {
82 if (opOperand->getOwner() != consumerOp) {
83 LLVM_DEBUG({
84 llvm::dbgs()
85 << "expected all consumer operands to be from the same operation";
86 });
87 return failure();
88 }
89 }
90
91 auto consumerOperandNums = llvm::map_to_vector(
92 C&: consumerOperands, F: [](OpOperand *opOperand) -> unsigned {
93 return opOperand->getOperandNumber();
94 });
95 SmallVector<SmallVector<OpFoldResult>> allOffsets;
96 SmallVector<SmallVector<OpFoldResult>> allSizes;
97 for (auto sliceOp : sliceOps) {
98
99 // `TilingInterface` currently only supports strides being 1.
100 if (!llvm::all_of(Range: sliceOp.getMixedStrides(), P: isOneInteger))
101 return failure();
102
103 SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
104 SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
105 allOffsets.emplace_back(Args: std::move(offsets));
106 allSizes.emplace_back(Args: std::move(sizes));
107 }
108 FailureOr<TilingResult> tiledResult =
109 consumerOp.getTiledImplementationFromOperandTiles(
110 b&: builder, operandNumbers: consumerOperandNums, allOffsets, allSizes);
111 if (failed(Result: tiledResult))
112 return failure();
113
114 return *tiledResult;
115}
116

source code of mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp