//===- EmptyTensorElimination.cpp - tensor.empty op elimination -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 | |
9 | #include "mlir/Dialect/Bufferization/Transforms/Passes.h" |
10 | |
11 | #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" |
12 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
13 | #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" |
14 | #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" |
15 | #include "mlir/Dialect/Bufferization/Transforms/Transforms.h" |
16 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
17 | #include "mlir/IR/Dominance.h" |
18 | #include "mlir/Interfaces/SubsetOpInterface.h" |
19 | #include "mlir/Pass/Pass.h" |
20 | |
21 | namespace mlir { |
22 | namespace bufferization { |
23 | #define GEN_PASS_DEF_EMPTYTENSORELIMINATIONPASS |
24 | #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc" |
25 | } // namespace bufferization |
26 | } // namespace mlir |
27 | |
28 | using namespace mlir; |
29 | using namespace mlir::bufferization; |
30 | |
31 | /// Return true if all `neededValues` are in scope at the given |
32 | /// `insertionPoint`. |
33 | static bool |
34 | neededValuesDominateInsertionPoint(const DominanceInfo &domInfo, |
35 | Operation *insertionPoint, |
36 | const SmallVector<Value> &neededValues) { |
37 | for (Value val : neededValues) { |
38 | if (auto bbArg = dyn_cast<BlockArgument>(val)) { |
39 | Block *owner = bbArg.getOwner(); |
40 | if (!owner->findAncestorOpInBlock(op&: *insertionPoint)) |
41 | return false; |
42 | } else { |
43 | auto opResult = cast<OpResult>(val); |
44 | if (!domInfo.properlyDominates(opResult.getOwner(), insertionPoint)) |
45 | return false; |
46 | } |
47 | } |
48 | return true; |
49 | } |
50 | |
51 | /// Find a valid insertion point for a replacement of `emptyTensorOp`'s |
52 | /// use of `user` operation, assuming that the replacement may use any |
53 | /// value from `neededValues`. |
54 | static Operation * |
55 | findValidInsertionPoint(Operation *emptyTensorOp, Operation *user, |
56 | const SmallVector<Value> &neededValues) { |
57 | DominanceInfo domInfo; |
58 | Operation *candidateInsertionPoint = emptyTensorOp; |
59 | |
60 | // Gather all possible insertion points: the location of |
61 | // `candidateInsertionPoint` and right after the definition of each value in |
62 | // `neededValues`. |
63 | SmallVector<Operation *> insertionPointCandidates; |
64 | insertionPointCandidates.push_back(Elt: candidateInsertionPoint); |
65 | for (Value val : neededValues) { |
66 | // Note: The anchor op is using all of `neededValues`, so: |
67 | // * in case of a block argument: There must be at least one op in the block |
68 | // (the anchor op or one of its parents). |
69 | // * in case of an OpResult: There must be at least one op right after the |
70 | // defining op (the anchor op or one of its |
71 | // parents). |
72 | if (auto bbArg = dyn_cast<BlockArgument>(Val&: val)) { |
73 | insertionPointCandidates.push_back( |
74 | Elt: &bbArg.getOwner()->getOperations().front()); |
75 | } else { |
76 | insertionPointCandidates.push_back(Elt: val.getDefiningOp()->getNextNode()); |
77 | } |
78 | } |
79 | |
80 | // Select first matching insertion point. |
81 | for (Operation *insertionPoint : insertionPointCandidates) { |
82 | // Check if all needed values are in scope. |
83 | if (!neededValuesDominateInsertionPoint(domInfo, insertionPoint, |
84 | neededValues)) |
85 | continue; |
86 | // Check if the insertion point is before the use to be replaced. |
87 | if (!domInfo.dominates(a: insertionPoint, b: user)) |
88 | continue; |
89 | return insertionPoint; |
90 | } |
91 | |
92 | // No suitable insertion point was found. |
93 | return nullptr; |
94 | } |
95 | |
96 | Value mlir::bufferization::(RewriterBase &rewriter, |
97 | SubsetInsertionOpInterface op, |
98 | tensor::EmptyOp emptyTensorOp, |
99 | Operation *user) { |
100 | |
101 | mlir::OpBuilder::InsertionGuard guard(rewriter); |
102 | // All values that are needed to create the replacement op. |
103 | SmallVector<Value> neededValues = op.getValuesNeededToBuildSubsetExtraction(); |
104 | // Find a suitable insertion point. If no suitable insertion point |
105 | // for the replacement can be found, return an empty value to skip |
106 | // this replacement. |
107 | Operation *insertionPoint = |
108 | findValidInsertionPoint(emptyTensorOp, user, neededValues); |
109 | if (!insertionPoint) |
110 | return {}; |
111 | |
112 | rewriter.setInsertionPoint(insertionPoint); |
113 | Value replacement = |
114 | op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc()); |
115 | return replacement; |
116 | } |
117 | |
118 | LogicalResult mlir::bufferization::eliminateEmptyTensors( |
119 | RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state, |
120 | ControlBuildSubsetExtractionFn ) { |
121 | OpBuilder::InsertionGuard g(rewriter); |
122 | llvm::DenseSet<OpOperand *> visitedOpOperands; |
123 | op->walk([&](SubsetInsertionOpInterface op) { |
124 | visitedOpOperands.clear(); |
125 | OpOperand &source = op.getSourceOperand(); |
126 | // Skip operands that do not bufferize inplace. "tensor.empty" could still |
127 | // be replaced, but the transformation may not be beneficial. |
128 | if (!state.isInPlace(opOperand&: source)) |
129 | return WalkResult::skip(); |
130 | |
131 | // Find tensor.empty ops on the reverse SSA use-def chain. Only follow |
132 | // equivalent tensors. I.e., stop when there are ops such as extract_slice |
133 | // on the path. |
134 | TraversalConfig config; |
135 | config.followEquivalentOnly = true; |
136 | config.alwaysIncludeLeaves = false; |
137 | // Replace only if the types match or are static <-> dynamic casts. We do |
138 | // not support slices or reshapes. |
139 | // TODO: This could be extended to support IR such as: |
140 | // %0 = tensor.empty() : tensor<128xf32> |
141 | // %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>) |
142 | // %2 = tensor.expand_shape %1 ... |
143 | // %3 = tensor.insert_slice %2 into ... |
144 | config.followSameTypeOrCastsOnly = true; |
145 | SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain( |
146 | &source, /*condition=*/ |
147 | [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); }, config, |
148 | &visitedOpOperands); |
149 | |
150 | for (Value v : emptyTensors) { |
151 | auto emptyTensorOp = v.getDefiningOp<tensor::EmptyOp>(); |
152 | assert(emptyTensorOp && "expected tensor.empty op" ); |
153 | // Find the use to be replaced from the use-def chain. |
154 | auto iter = llvm::find_if( |
155 | visitedOpOperands, [&emptyTensorOp](OpOperand *opOperand) { |
156 | return llvm::count(emptyTensorOp->getUses(), *opOperand); |
157 | }); |
158 | |
159 | assert(iter != visitedOpOperands.end() && "could not find use" ); |
160 | OpOperand *useToBeReplaced = *iter; |
161 | Operation *user = useToBeReplaced->getOwner(); |
162 | auto replacement = subsetsExtractionFn(rewriter, op, emptyTensorOp, user); |
163 | if (!replacement) |
164 | continue; |
165 | if (emptyTensorOp == replacement.getDefiningOp()) |
166 | continue; |
167 | if (replacement.getType() != v.getType()) { |
168 | if (cast<ShapedType>(replacement.getType()).getElementType() != |
169 | cast<ShapedType>(v.getType()).getElementType()) |
170 | continue; |
171 | rewriter.setInsertionPointAfterValue(replacement); |
172 | replacement = rewriter.create<tensor::CastOp>(v.getLoc(), v.getType(), |
173 | replacement); |
174 | } |
175 | // Replace the specific use of the tensor::EmptyOp. |
176 | rewriter.modifyOpInPlace(user, [&]() { |
177 | user->setOperand(useToBeReplaced->getOperandNumber(), replacement); |
178 | }); |
179 | state.resetCache(); |
180 | } |
181 | |
182 | return WalkResult::advance(); |
183 | }); |
184 | |
185 | return success(); |
186 | } |
187 | |
188 | namespace { |
189 | struct EmptyTensorElimination |
190 | : public bufferization::impl::EmptyTensorEliminationPassBase< |
191 | EmptyTensorElimination> { |
192 | using Base::Base; |
193 | |
194 | void runOnOperation() override; |
195 | |
196 | void getDependentDialects(DialectRegistry ®istry) const override { |
197 | registry |
198 | .insert<bufferization::BufferizationDialect, tensor::TensorDialect>(); |
199 | } |
200 | }; |
201 | } // namespace |
202 | |
203 | LogicalResult mlir::bufferization::eliminateEmptyTensors(RewriterBase &rewriter, |
204 | Operation *op) { |
205 | auto moduleOp = dyn_cast<ModuleOp>(op); |
206 | OneShotBufferizationOptions options; |
207 | options.allowReturnAllocsFromLoops = true; |
208 | if (moduleOp) |
209 | options.bufferizeFunctionBoundaries = true; |
210 | OneShotAnalysisState state(op, options); |
211 | if (moduleOp) { |
212 | // Module analysis takes into account function boundaries. |
213 | if (failed(analyzeModuleOp(moduleOp, state))) |
214 | return failure(); |
215 | } else { |
216 | // Regular One-Shot Bufferize ignores func.func block arguments, func.call, |
217 | // func.return. |
218 | if (failed(Result: analyzeOp(op, state))) |
219 | return failure(); |
220 | } |
221 | |
222 | return bufferization::eliminateEmptyTensors(rewriter, op, state); |
223 | } |
224 | |
225 | void EmptyTensorElimination::runOnOperation() { |
226 | IRRewriter rewriter(getOperation()->getContext()); |
227 | if (failed(bufferization::eliminateEmptyTensors(rewriter, getOperation()))) |
228 | signalPassFailure(); |
229 | } |
230 | |