//===- EmptyTensorElimination.cpp - tensor.empty op elimination -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/Transforms/Passes.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/SubsetOpInterface.h"
#include "mlir/Pass/Pass.h"

namespace mlir {
namespace bufferization {
#define GEN_PASS_DEF_EMPTYTENSORELIMINATION
#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
} // namespace bufferization
} // namespace mlir

using namespace mlir;
using namespace mlir::bufferization;

/// Return true if all `neededValues` are in scope at the given
/// `insertionPoint`.
static bool
neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
                                   Operation *insertionPoint,
                                   const SmallVector<Value> &neededValues) {
  for (Value val : neededValues) {
    if (auto bbArg = dyn_cast<BlockArgument>(val)) {
      Block *owner = bbArg.getOwner();
      if (!owner->findAncestorOpInBlock(*insertionPoint))
        return false;
    } else {
      auto opResult = cast<OpResult>(val);
      if (!domInfo.properlyDominates(opResult.getOwner(), insertionPoint))
        return false;
    }
  }
  return true;
}

/// Return true if the given `insertionPoint` dominates all uses of
/// `emptyTensorOp`.
static bool insertionPointDominatesUses(const DominanceInfo &domInfo,
                                        Operation *insertionPoint,
                                        Operation *emptyTensorOp) {
  return llvm::all_of(emptyTensorOp->getUsers(), [&](Operation *user) {
    return domInfo.dominates(insertionPoint, user);
  });
}

/// Find a valid insertion point for a replacement of `emptyTensorOp`, assuming
/// that the replacement may use any value from `neededValues`.
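/// Returns nullptr if no suitable insertion point exists.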
static Operation *
findValidInsertionPoint(Operation *emptyTensorOp,
                        const SmallVector<Value> &neededValues) {
  DominanceInfo domInfo;

  // Gather all possible insertion points: the location of `emptyTensorOp` and
  // right after the definition of each value in `neededValues`.
  SmallVector<Operation *> insertionPointCandidates;
  insertionPointCandidates.push_back(emptyTensorOp);
  for (Value val : neededValues) {
    // Note: The anchor op is using all of `neededValues`, so:
    // * in case of a block argument: There must be at least one op in the
    //   block (the anchor op or one of its parents).
    // * in case of an OpResult: There must be at least one op right after the
    //   defining op (the anchor op or one of its parents).
    if (auto bbArg = dyn_cast<BlockArgument>(val)) {
      insertionPointCandidates.push_back(
          &bbArg.getOwner()->getOperations().front());
    } else {
      insertionPointCandidates.push_back(val.getDefiningOp()->getNextNode());
    }
  }

  // Select the first matching insertion point.
  for (Operation *insertionPoint : insertionPointCandidates) {
    // Check if all needed values are in scope.
    if (!neededValuesDominateInsertionPoint(domInfo, insertionPoint,
                                            neededValues))
      continue;
    // Check if the insertion point is before all uses.
    if (!insertionPointDominatesUses(domInfo, insertionPoint, emptyTensorOp))
      continue;
    return insertionPoint;
  }

  // No suitable insertion point was found.
  return nullptr;
}

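/// For each in-place subset insertion anchor (an op implementing
/// `SubsetInsertionOpInterface`), tensor.empty ops that feed the inserted
/// value through equivalent tensors are replaced with a matching subset
/// extraction of the destination. Illustrative sketch (schematic IR; the
/// exact ops and types depend on the input and on the analysis `state`):
///
///   %0 = tensor.empty() : tensor<10xf32>
///   %1 = linalg.fill ins(%f : f32) outs(%0 : tensor<10xf32>)
///   %2 = tensor.insert_slice %1 into %t[0] [10] [1] : ...
///
/// is rewritten into:
///
///   %0 = tensor.extract_slice %t[0] [10] [1] : ...
///   %1 = linalg.fill ins(%f : f32) outs(%0 : tensor<10xf32>)
///   %2 = tensor.insert_slice %1 into %t[0] [10] [1] : ...
///
/// so that %1 can later bufferize in-place into a subview of the buffer of %t
/// and no separate allocation is needed for the former tensor.empty.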
LogicalResult mlir::bufferization::eliminateEmptyTensors(
    RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
  OpBuilder::InsertionGuard g(rewriter);

  op->walk([&](SubsetInsertionOpInterface op) {
    OpOperand &source = op.getSourceOperand();
    // Skip operands that do not bufferize inplace. "tensor.empty" could still
    // be replaced, but the transformation may not be beneficial.
    if (!state.isInPlace(source))
      return WalkResult::skip();

    // All values that are needed to create the replacement op.
    SmallVector<Value> neededValues =
        op.getValuesNeededToBuildSubsetExtraction();

    // Find tensor.empty ops on the reverse SSA use-def chain. Only follow
    // equivalent tensors. I.e., stop when there are ops such as extract_slice
    // on the path.
    TraversalConfig config;
    config.followEquivalentOnly = true;
    config.alwaysIncludeLeaves = false;
    // Replace only if the types match or are static <-> dynamic casts. We do
    // not support slices or reshapes.
    // TODO: This could be extended to support IR such as:
    // %0 = tensor.empty() : tensor<128xf32>
    // %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>)
    // %2 = tensor.expand_shape %1 ...
    // %3 = tensor.insert_slice %2 into ...
    config.followSameTypeOrCastsOnly = true;
    SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
        source.get(), /*condition=*/
        [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
        config);

    for (Value v : emptyTensors) {
      Operation *emptyTensorOp = v.getDefiningOp();

      // Find a suitable insertion point. If no suitable insertion point for
      // the replacement can be found, skip this replacement.
      Operation *insertionPoint =
          findValidInsertionPoint(emptyTensorOp, neededValues);
      if (!insertionPoint)
        continue;

      rewriter.setInsertionPoint(insertionPoint);
      Value replacement =
          op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc());
      if (!replacement)
        continue;
      if (emptyTensorOp == replacement.getDefiningOp())
        continue;
      if (replacement.getType() != v.getType()) {
        rewriter.setInsertionPointAfterValue(replacement);
        replacement = rewriter.create<tensor::CastOp>(v.getLoc(), v.getType(),
                                                      replacement);
      }
      // Replace the tensor::EmptyOp.
      rewriter.replaceOp(emptyTensorOp, replacement);
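      // The IR was modified; reset cached analysis information that may have
      // become stale.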
      state.resetCache();
    }

    return WalkResult::advance();
  });

  return success();
}

namespace {
struct EmptyTensorElimination
    : public bufferization::impl::EmptyTensorEliminationBase<
          EmptyTensorElimination> {
  EmptyTensorElimination() = default;

  void runOnOperation() override;

  void getDependentDialects(DialectRegistry &registry) const override {
    registry
        .insert<bufferization::BufferizationDialect, tensor::TensorDialect>();
  }
};
} // namespace

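/// Convenience overload: run One-Shot Bufferize analysis on `op` (module-level
/// analysis if `op` is a ModuleOp) and then eliminate tensor.empty ops based
/// on the analysis results.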
LogicalResult
mlir::bufferization::eliminateEmptyTensors(RewriterBase &rewriter,
                                           Operation *op) {
  auto moduleOp = dyn_cast<ModuleOp>(op);
  OneShotBufferizationOptions options;
  options.allowReturnAllocsFromLoops = true;
  if (moduleOp)
    options.bufferizeFunctionBoundaries = true;
  OneShotAnalysisState state(op, options);
  if (moduleOp) {
    // Module analysis takes into account function boundaries.
    if (failed(analyzeModuleOp(moduleOp, state)))
      return failure();
  } else {
    // Regular One-Shot Bufferize ignores func.func block arguments, func.call,
    // func.return.
    if (failed(analyzeOp(op, state)))
      return failure();
  }

  return bufferization::eliminateEmptyTensors(rewriter, op, state);
}

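/// Pass entry point: run empty tensor elimination on the operation the pass
/// is scheduled on.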
void EmptyTensorElimination::runOnOperation() {
  IRRewriter rewriter(getOperation()->getContext());
  if (failed(bufferization::eliminateEmptyTensors(rewriter, getOperation())))
    signalPassFailure();
}

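/// Factory function for the empty tensor elimination pass.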
std::unique_ptr<Pass> mlir::bufferization::createEmptyTensorEliminationPass() {
  return std::make_unique<EmptyTensorElimination>();
}