1//===- EmptyTensorElimination.cpp - tensor.empty op elimination -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
10
11#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
14#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
15#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
16#include "mlir/Dialect/Tensor/IR/Tensor.h"
17#include "mlir/IR/Dominance.h"
18#include "mlir/Interfaces/SubsetOpInterface.h"
19#include "mlir/Pass/Pass.h"
20
21namespace mlir {
22namespace bufferization {
23#define GEN_PASS_DEF_EMPTYTENSORELIMINATIONPASS
24#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
25} // namespace bufferization
26} // namespace mlir
27
28using namespace mlir;
29using namespace mlir::bufferization;
30
31/// Return true if all `neededValues` are in scope at the given
32/// `insertionPoint`.
33static bool
34neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
35 Operation *insertionPoint,
36 const SmallVector<Value> &neededValues) {
37 for (Value val : neededValues) {
38 if (auto bbArg = dyn_cast<BlockArgument>(val)) {
39 Block *owner = bbArg.getOwner();
40 if (!owner->findAncestorOpInBlock(op&: *insertionPoint))
41 return false;
42 } else {
43 auto opResult = cast<OpResult>(val);
44 if (!domInfo.properlyDominates(opResult.getOwner(), insertionPoint))
45 return false;
46 }
47 }
48 return true;
49}
50
51/// Find a valid insertion point for a replacement of `emptyTensorOp`'s
52/// use of `user` operation, assuming that the replacement may use any
53/// value from `neededValues`.
54static Operation *
55findValidInsertionPoint(Operation *emptyTensorOp, Operation *user,
56 const SmallVector<Value> &neededValues) {
57 DominanceInfo domInfo;
58 Operation *candidateInsertionPoint = emptyTensorOp;
59
60 // Gather all possible insertion points: the location of
61 // `candidateInsertionPoint` and right after the definition of each value in
62 // `neededValues`.
63 SmallVector<Operation *> insertionPointCandidates;
64 insertionPointCandidates.push_back(Elt: candidateInsertionPoint);
65 for (Value val : neededValues) {
66 // Note: The anchor op is using all of `neededValues`, so:
67 // * in case of a block argument: There must be at least one op in the block
68 // (the anchor op or one of its parents).
69 // * in case of an OpResult: There must be at least one op right after the
70 // defining op (the anchor op or one of its
71 // parents).
72 if (auto bbArg = dyn_cast<BlockArgument>(Val&: val)) {
73 insertionPointCandidates.push_back(
74 Elt: &bbArg.getOwner()->getOperations().front());
75 } else {
76 insertionPointCandidates.push_back(Elt: val.getDefiningOp()->getNextNode());
77 }
78 }
79
80 // Select first matching insertion point.
81 for (Operation *insertionPoint : insertionPointCandidates) {
82 // Check if all needed values are in scope.
83 if (!neededValuesDominateInsertionPoint(domInfo, insertionPoint,
84 neededValues))
85 continue;
86 // Check if the insertion point is before the use to be replaced.
87 if (!domInfo.dominates(a: insertionPoint, b: user))
88 continue;
89 return insertionPoint;
90 }
91
92 // No suitable insertion point was found.
93 return nullptr;
94}
95
96Value mlir::bufferization::buildSubsetExtraction(RewriterBase &rewriter,
97 SubsetInsertionOpInterface op,
98 tensor::EmptyOp emptyTensorOp,
99 Operation *user) {
100
101 mlir::OpBuilder::InsertionGuard guard(rewriter);
102 // All values that are needed to create the replacement op.
103 SmallVector<Value> neededValues = op.getValuesNeededToBuildSubsetExtraction();
104 // Find a suitable insertion point. If no suitable insertion point
105 // for the replacement can be found, return an empty value to skip
106 // this replacement.
107 Operation *insertionPoint =
108 findValidInsertionPoint(emptyTensorOp, user, neededValues);
109 if (!insertionPoint)
110 return {};
111
112 rewriter.setInsertionPoint(insertionPoint);
113 Value replacement =
114 op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc());
115 return replacement;
116}
117
118LogicalResult mlir::bufferization::eliminateEmptyTensors(
119 RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state,
120 ControlBuildSubsetExtractionFn subsetsExtractionFn) {
121 OpBuilder::InsertionGuard g(rewriter);
122 llvm::DenseSet<OpOperand *> visitedOpOperands;
123 op->walk([&](SubsetInsertionOpInterface op) {
124 visitedOpOperands.clear();
125 OpOperand &source = op.getSourceOperand();
126 // Skip operands that do not bufferize inplace. "tensor.empty" could still
127 // be replaced, but the transformation may not be beneficial.
128 if (!state.isInPlace(opOperand&: source))
129 return WalkResult::skip();
130
131 // Find tensor.empty ops on the reverse SSA use-def chain. Only follow
132 // equivalent tensors. I.e., stop when there are ops such as extract_slice
133 // on the path.
134 TraversalConfig config;
135 config.followEquivalentOnly = true;
136 config.alwaysIncludeLeaves = false;
137 // Replace only if the types match or are static <-> dynamic casts. We do
138 // not support slices or reshapes.
139 // TODO: This could be extended to support IR such as:
140 // %0 = tensor.empty() : tensor<128xf32>
141 // %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>)
142 // %2 = tensor.expand_shape %1 ...
143 // %3 = tensor.insert_slice %2 into ...
144 config.followSameTypeOrCastsOnly = true;
145 SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
146 &source, /*condition=*/
147 [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); }, config,
148 &visitedOpOperands);
149
150 for (Value v : emptyTensors) {
151 auto emptyTensorOp = v.getDefiningOp<tensor::EmptyOp>();
152 assert(emptyTensorOp && "expected tensor.empty op");
153 // Find the use to be replaced from the use-def chain.
154 auto iter = llvm::find_if(
155 visitedOpOperands, [&emptyTensorOp](OpOperand *opOperand) {
156 return llvm::count(emptyTensorOp->getUses(), *opOperand);
157 });
158
159 assert(iter != visitedOpOperands.end() && "could not find use");
160 OpOperand *useToBeReplaced = *iter;
161 Operation *user = useToBeReplaced->getOwner();
162 auto replacement = subsetsExtractionFn(rewriter, op, emptyTensorOp, user);
163 if (!replacement)
164 continue;
165 if (emptyTensorOp == replacement.getDefiningOp())
166 continue;
167 if (replacement.getType() != v.getType()) {
168 if (cast<ShapedType>(replacement.getType()).getElementType() !=
169 cast<ShapedType>(v.getType()).getElementType())
170 continue;
171 rewriter.setInsertionPointAfterValue(replacement);
172 replacement = rewriter.create<tensor::CastOp>(v.getLoc(), v.getType(),
173 replacement);
174 }
175 // Replace the specific use of the tensor::EmptyOp.
176 rewriter.modifyOpInPlace(user, [&]() {
177 user->setOperand(useToBeReplaced->getOperandNumber(), replacement);
178 });
179 state.resetCache();
180 }
181
182 return WalkResult::advance();
183 });
184
185 return success();
186}
187
188namespace {
189struct EmptyTensorElimination
190 : public bufferization::impl::EmptyTensorEliminationPassBase<
191 EmptyTensorElimination> {
192 using Base::Base;
193
194 void runOnOperation() override;
195
196 void getDependentDialects(DialectRegistry &registry) const override {
197 registry
198 .insert<bufferization::BufferizationDialect, tensor::TensorDialect>();
199 }
200};
201} // namespace
202
203LogicalResult mlir::bufferization::eliminateEmptyTensors(RewriterBase &rewriter,
204 Operation *op) {
205 auto moduleOp = dyn_cast<ModuleOp>(op);
206 OneShotBufferizationOptions options;
207 options.allowReturnAllocsFromLoops = true;
208 if (moduleOp)
209 options.bufferizeFunctionBoundaries = true;
210 OneShotAnalysisState state(op, options);
211 if (moduleOp) {
212 // Module analysis takes into account function boundaries.
213 if (failed(analyzeModuleOp(moduleOp, state)))
214 return failure();
215 } else {
216 // Regular One-Shot Bufferize ignores func.func block arguments, func.call,
217 // func.return.
218 if (failed(Result: analyzeOp(op, state)))
219 return failure();
220 }
221
222 return bufferization::eliminateEmptyTensors(rewriter, op, state);
223}
224
225void EmptyTensorElimination::runOnOperation() {
226 IRRewriter rewriter(getOperation()->getContext());
227 if (failed(bufferization::eliminateEmptyTensors(rewriter, getOperation())))
228 signalPassFailure();
229}
230

source code of mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp