//===- EmptyTensorElimination.cpp - tensor.empty op elimination -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/Transforms/Passes.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/SubsetOpInterface.h"
#include "mlir/Pass/Pass.h"

namespace mlir {
namespace bufferization {
#define GEN_PASS_DEF_EMPTYTENSORELIMINATION
#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
} // namespace bufferization
} // namespace mlir

using namespace mlir;
using namespace mlir::bufferization;

31/// Return true if all `neededValues` are in scope at the given
32/// `insertionPoint`.
33static bool
34neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
35 Operation *insertionPoint,
36 const SmallVector<Value> &neededValues) {
37 for (Value val : neededValues) {
38 if (auto bbArg = dyn_cast<BlockArgument>(val)) {
39 Block *owner = bbArg.getOwner();
40 if (!owner->findAncestorOpInBlock(op&: *insertionPoint))
41 return false;
42 } else {
43 auto opResult = cast<OpResult>(val);
44 if (!domInfo.properlyDominates(opResult.getOwner(), insertionPoint))
45 return false;
46 }
47 }
48 return true;
49}
50
51/// Return true if the given `insertionPoint` dominates all uses of
52/// `emptyTensorOp`.
53static bool insertionPointDominatesUses(const DominanceInfo &domInfo,
54 Operation *insertionPoint,
55 Operation *emptyTensorOp) {
56 return llvm::all_of(Range: emptyTensorOp->getUsers(), P: [&](Operation *user) {
57 return domInfo.dominates(a: insertionPoint, b: user);
58 });
59}
60
61/// Find a valid insertion point for a replacement of `emptyTensorOp`, assuming
62/// that the replacement may use any value from `neededValues`.
63static Operation *
64findValidInsertionPoint(Operation *emptyTensorOp,
65 const SmallVector<Value> &neededValues) {
66 DominanceInfo domInfo;
67
68 // Gather all possible insertion points: the location of `emptyTensorOp` and
69 // right after the definition of each value in `neededValues`.
70 SmallVector<Operation *> insertionPointCandidates;
71 insertionPointCandidates.push_back(Elt: emptyTensorOp);
72 for (Value val : neededValues) {
73 // Note: The anchor op is using all of `neededValues`, so:
74 // * in case of a block argument: There must be at least one op in the block
75 // (the anchor op or one of its parents).
76 // * in case of an OpResult: There must be at least one op right after the
77 // defining op (the anchor op or one of its
78 // parents).
79 if (auto bbArg = dyn_cast<BlockArgument>(Val&: val)) {
80 insertionPointCandidates.push_back(
81 Elt: &bbArg.getOwner()->getOperations().front());
82 } else {
83 insertionPointCandidates.push_back(Elt: val.getDefiningOp()->getNextNode());
84 }
85 }
86
87 // Select first matching insertion point.
88 for (Operation *insertionPoint : insertionPointCandidates) {
89 // Check if all needed values are in scope.
90 if (!neededValuesDominateInsertionPoint(domInfo, insertionPoint,
91 neededValues))
92 continue;
93 // Check if the insertion point is before all uses.
94 if (!insertionPointDominatesUses(domInfo, insertionPoint, emptyTensorOp))
95 continue;
96 return insertionPoint;
97 }
98
99 // No suitable insertion point was found.
100 return nullptr;
101}
102
103LogicalResult mlir::bufferization::eliminateEmptyTensors(
104 RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
105 OpBuilder::InsertionGuard g(rewriter);
106
107 op->walk([&](SubsetInsertionOpInterface op) {
108 OpOperand &source = op.getSourceOperand();
109 // Skip operands that do not bufferize inplace. "tensor.empty" could still
110 // be replaced, but the transformation may not be beneficial.
111 if (!state.isInPlace(opOperand&: source))
112 return WalkResult::skip();
113
114 // All values that are needed to create the replacement op.
115 SmallVector<Value> neededValues =
116 op.getValuesNeededToBuildSubsetExtraction();
117
118 // Find tensor.empty ops on the reverse SSA use-def chain. Only follow
119 // equivalent tensors. I.e., stop when there are ops such as extract_slice
120 // on the path.
121 TraversalConfig config;
122 config.followEquivalentOnly = true;
123 config.alwaysIncludeLeaves = false;
124 // Replace only if the types match or are static <-> dynamic casts. We do
125 // not support slices or reshapes.
126 // TODO: This could be extended to support IR such as:
127 // %0 = tensor.empty() : tensor<128xf32>
128 // %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>)
129 // %2 = tensor.expand_shape %1 ...
130 // %3 = tensor.insert_slice %2 into ...
131 config.followSameTypeOrCastsOnly = true;
132 SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
133 source.get(), /*condition=*/
134 [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
135 config);
136
137 for (Value v : emptyTensors) {
138 Operation *emptyTensorOp = v.getDefiningOp();
139
140 // Find a suitable insertion point. If no suitable insertion point for
141 // the replacement can be found, skip this replacement.
142 Operation *insertionPoint =
143 findValidInsertionPoint(emptyTensorOp, neededValues);
144 if (!insertionPoint)
145 continue;
146
147 rewriter.setInsertionPoint(insertionPoint);
148 Value replacement =
149 op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc());
150 if (!replacement)
151 continue;
152 if (emptyTensorOp == replacement.getDefiningOp())
153 continue;
154 if (replacement.getType() != v.getType()) {
155 rewriter.setInsertionPointAfterValue(replacement);
156 replacement = rewriter.create<tensor::CastOp>(v.getLoc(), v.getType(),
157 replacement);
158 }
159 // Replace the tensor::EmptyOp.
160 rewriter.replaceOp(emptyTensorOp, replacement);
161 state.resetCache();
162 }
163
164 return WalkResult::advance();
165 });
166
167 return success();
168}
169
170namespace {
171struct EmptyTensorElimination
172 : public bufferization::impl::EmptyTensorEliminationBase<
173 EmptyTensorElimination> {
174 EmptyTensorElimination() = default;
175
176 void runOnOperation() override;
177
178 void getDependentDialects(DialectRegistry &registry) const override {
179 registry
180 .insert<bufferization::BufferizationDialect, tensor::TensorDialect>();
181 }
182};
183} // namespace
184
185LogicalResult mlir::bufferization::eliminateEmptyTensors(RewriterBase &rewriter,
186 Operation *op) {
187 auto moduleOp = dyn_cast<ModuleOp>(op);
188 OneShotBufferizationOptions options;
189 options.allowReturnAllocsFromLoops = true;
190 if (moduleOp)
191 options.bufferizeFunctionBoundaries = true;
192 OneShotAnalysisState state(op, options);
193 if (moduleOp) {
194 // Module analysis takes into account function boundaries.
195 if (failed(analyzeModuleOp(moduleOp, state)))
196 return failure();
197 } else {
198 // Regular One-Shot Bufferize ignores func.func block arguments, func.call,
199 // func.return.
200 if (failed(result: analyzeOp(op, state)))
201 return failure();
202 }
203
204 return bufferization::eliminateEmptyTensors(rewriter, op, state);
205}
206
207void EmptyTensorElimination::runOnOperation() {
208 IRRewriter rewriter(getOperation()->getContext());
209 if (failed(bufferization::eliminateEmptyTensors(rewriter, getOperation())))
210 signalPassFailure();
211}
212
213std::unique_ptr<Pass> mlir::bufferization::createEmptyTensorEliminationPass() {
214 return std::make_unique<EmptyTensorElimination>();
215}
216

// Source: mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp