//===- EmptyTensorElimination.cpp - tensor.empty op elimination -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

9#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
10
11#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
14#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
15#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
16#include "mlir/Dialect/Tensor/IR/Tensor.h"
17#include "mlir/IR/Dominance.h"
18#include "mlir/Interfaces/SubsetOpInterface.h"
19
20namespace mlir {
21namespace bufferization {
22#define GEN_PASS_DEF_EMPTYTENSORELIMINATIONPASS
23#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
24} // namespace bufferization
25} // namespace mlir
26
27using namespace mlir;
28using namespace mlir::bufferization;
29
30/// Return true if all `neededValues` are in scope at the given
31/// `insertionPoint`.
32static bool
33neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
34 Operation *insertionPoint,
35 const SmallVector<Value> &neededValues) {
36 for (Value val : neededValues) {
37 if (auto bbArg = dyn_cast<BlockArgument>(Val&: val)) {
38 Block *owner = bbArg.getOwner();
39 if (!owner->findAncestorOpInBlock(op&: *insertionPoint))
40 return false;
41 } else {
42 auto opResult = cast<OpResult>(Val&: val);
43 if (!domInfo.properlyDominates(a: opResult.getOwner(), b: insertionPoint))
44 return false;
45 }
46 }
47 return true;
48}
49
50/// Find a valid insertion point for a replacement of `emptyTensorOp`'s
51/// use of `user` operation, assuming that the replacement may use any
52/// value from `neededValues`.
53static Operation *
54findValidInsertionPoint(Operation *emptyTensorOp, Operation *user,
55 const SmallVector<Value> &neededValues) {
56 DominanceInfo domInfo;
57 Operation *candidateInsertionPoint = emptyTensorOp;
58
59 // Gather all possible insertion points: the location of
60 // `candidateInsertionPoint` and right after the definition of each value in
61 // `neededValues`.
62 SmallVector<Operation *> insertionPointCandidates;
63 insertionPointCandidates.push_back(Elt: candidateInsertionPoint);
64 for (Value val : neededValues) {
65 // Note: The anchor op is using all of `neededValues`, so:
66 // * in case of a block argument: There must be at least one op in the block
67 // (the anchor op or one of its parents).
68 // * in case of an OpResult: There must be at least one op right after the
69 // defining op (the anchor op or one of its
70 // parents).
71 if (auto bbArg = dyn_cast<BlockArgument>(Val&: val)) {
72 insertionPointCandidates.push_back(
73 Elt: &bbArg.getOwner()->getOperations().front());
74 } else {
75 insertionPointCandidates.push_back(Elt: val.getDefiningOp()->getNextNode());
76 }
77 }
78
79 // Select first matching insertion point.
80 for (Operation *insertionPoint : insertionPointCandidates) {
81 // Check if all needed values are in scope.
82 if (!neededValuesDominateInsertionPoint(domInfo, insertionPoint,
83 neededValues))
84 continue;
85 // Check if the insertion point is before the use to be replaced.
86 if (!domInfo.dominates(a: insertionPoint, b: user))
87 continue;
88 return insertionPoint;
89 }
90
91 // No suitable insertion point was found.
92 return nullptr;
93}
94
95Value mlir::bufferization::buildSubsetExtraction(RewriterBase &rewriter,
96 SubsetInsertionOpInterface op,
97 tensor::EmptyOp emptyTensorOp,
98 Operation *user) {
99
100 mlir::OpBuilder::InsertionGuard guard(rewriter);
101 // All values that are needed to create the replacement op.
102 SmallVector<Value> neededValues = op.getValuesNeededToBuildSubsetExtraction();
103 // Find a suitable insertion point. If no suitable insertion point
104 // for the replacement can be found, return an empty value to skip
105 // this replacement.
106 Operation *insertionPoint =
107 findValidInsertionPoint(emptyTensorOp, user, neededValues);
108 if (!insertionPoint)
109 return {};
110
111 rewriter.setInsertionPoint(insertionPoint);
112 Value replacement =
113 op.buildSubsetExtraction(builder&: rewriter, loc: emptyTensorOp->getLoc());
114 return replacement;
115}
116
117LogicalResult mlir::bufferization::eliminateEmptyTensors(
118 RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state,
119 ControlBuildSubsetExtractionFn subsetsExtractionFn) {
120 OpBuilder::InsertionGuard g(rewriter);
121 llvm::DenseSet<OpOperand *> visitedOpOperands;
122 op->walk(callback: [&](SubsetInsertionOpInterface op) {
123 visitedOpOperands.clear();
124 OpOperand &source = op.getSourceOperand();
125 // Skip operands that do not bufferize inplace. "tensor.empty" could still
126 // be replaced, but the transformation may not be beneficial.
127 if (!state.isInPlace(opOperand&: source))
128 return WalkResult::skip();
129
130 // Find tensor.empty ops on the reverse SSA use-def chain. Only follow
131 // equivalent tensors. I.e., stop when there are ops such as extract_slice
132 // on the path.
133 TraversalConfig config;
134 config.followEquivalentOnly = true;
135 config.alwaysIncludeLeaves = false;
136 // Replace only if the types match or are static <-> dynamic casts. We do
137 // not support slices or reshapes.
138 // TODO: This could be extended to support IR such as:
139 // %0 = tensor.empty() : tensor<128xf32>
140 // %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>)
141 // %2 = tensor.expand_shape %1 ...
142 // %3 = tensor.insert_slice %2 into ...
143 config.followSameTypeOrCastsOnly = true;
144 SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
145 opOperand: &source, /*condition=*/
146 [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); }, config,
147 visitedOpOperands: &visitedOpOperands);
148
149 for (Value v : emptyTensors) {
150 auto emptyTensorOp = v.getDefiningOp<tensor::EmptyOp>();
151 assert(emptyTensorOp && "expected tensor.empty op");
152 // Find the use to be replaced from the use-def chain.
153 auto iter = llvm::find_if(
154 Range&: visitedOpOperands, P: [&emptyTensorOp](OpOperand *opOperand) {
155 return llvm::count(Range: emptyTensorOp->getUses(), Element: *opOperand);
156 });
157
158 assert(iter != visitedOpOperands.end() && "could not find use");
159 OpOperand *useToBeReplaced = *iter;
160 Operation *user = useToBeReplaced->getOwner();
161 auto replacement = subsetsExtractionFn(rewriter, op, emptyTensorOp, user);
162 if (!replacement)
163 continue;
164 if (emptyTensorOp == replacement.getDefiningOp())
165 continue;
166 if (replacement.getType() != v.getType()) {
167 if (cast<ShapedType>(Val: replacement.getType()).getElementType() !=
168 cast<ShapedType>(Val: v.getType()).getElementType())
169 continue;
170 rewriter.setInsertionPointAfterValue(replacement);
171 replacement = rewriter.create<tensor::CastOp>(location: v.getLoc(), args: v.getType(),
172 args&: replacement);
173 }
174 // Replace the specific use of the tensor::EmptyOp.
175 rewriter.modifyOpInPlace(root: user, callable: [&]() {
176 user->setOperand(idx: useToBeReplaced->getOperandNumber(), value: replacement);
177 });
178 state.resetCache();
179 }
180
181 return WalkResult::advance();
182 });
183
184 return success();
185}
186
187namespace {
188struct EmptyTensorElimination
189 : public bufferization::impl::EmptyTensorEliminationPassBase<
190 EmptyTensorElimination> {
191 using Base::Base;
192
193 void runOnOperation() override;
194
195 void getDependentDialects(DialectRegistry &registry) const override {
196 registry
197 .insert<bufferization::BufferizationDialect, tensor::TensorDialect>();
198 }
199};
200} // namespace
201
202LogicalResult mlir::bufferization::eliminateEmptyTensors(RewriterBase &rewriter,
203 Operation *op) {
204 auto moduleOp = dyn_cast<ModuleOp>(Val: op);
205 OneShotBufferizationOptions options;
206 options.allowReturnAllocsFromLoops = true;
207 if (moduleOp)
208 options.bufferizeFunctionBoundaries = true;
209 OneShotAnalysisState state(op, options);
210 if (moduleOp) {
211 // Module analysis takes into account function boundaries.
212 if (failed(Result: analyzeModuleOp(moduleOp, state)))
213 return failure();
214 } else {
215 // Regular One-Shot Bufferize ignores func.func block arguments, func.call,
216 // func.return.
217 if (failed(Result: analyzeOp(op, state)))
218 return failure();
219 }
220
221 return bufferization::eliminateEmptyTensors(rewriter, op, state);
222}
223
224void EmptyTensorElimination::runOnOperation() {
225 IRRewriter rewriter(getOperation()->getContext());
226 if (failed(Result: bufferization::eliminateEmptyTensors(rewriter, op: getOperation())))
227 signalPassFailure();
228}
229
// Source: mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp