//===- DropEquivalentBufferResults.cpp - Drop equivalent results ----------===//
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This pass drops return values from functions if they are equivalent to one of |
10 | // their arguments. E.g.: |
11 | // |
12 | // ``` |
13 | // func.func @foo(%m : memref<?xf32>) -> (memref<?xf32>) { |
14 | // return %m : memref<?xf32> |
15 | // } |
16 | // ``` |
17 | // |
// This function is rewritten to:
19 | // |
20 | // ``` |
21 | // func.func @foo(%m : memref<?xf32>) { |
22 | // return |
23 | // } |
24 | // ``` |
25 | // |
26 | // All call sites are updated accordingly. If a function returns a cast of a |
27 | // function argument, it is also considered equivalent. A cast is inserted at |
28 | // the call site in that case. |
29 | |
30 | #include "mlir/Dialect/Bufferization/Transforms/Passes.h" |
31 | |
32 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
33 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
34 | #include "mlir/IR/Operation.h" |
35 | #include "mlir/Pass/Pass.h" |
36 | |
37 | namespace mlir { |
38 | namespace bufferization { |
39 | #define GEN_PASS_DEF_DROPEQUIVALENTBUFFERRESULTS |
40 | #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc" |
41 | } // namespace bufferization |
42 | } // namespace mlir |
43 | |
44 | using namespace mlir; |
45 | |
46 | /// Return the unique ReturnOp that terminates `funcOp`. |
47 | /// Return nullptr if there is no such unique ReturnOp. |
48 | static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) { |
49 | func::ReturnOp returnOp; |
50 | for (Block &b : funcOp.getBody()) { |
51 | if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) { |
52 | if (returnOp) |
53 | return nullptr; |
54 | returnOp = candidateOp; |
55 | } |
56 | } |
57 | return returnOp; |
58 | } |
59 | |
60 | /// Return the func::FuncOp called by `callOp`. |
61 | static func::FuncOp getCalledFunction(CallOpInterface callOp) { |
62 | SymbolRefAttr sym = |
63 | llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee()); |
64 | if (!sym) |
65 | return nullptr; |
66 | return dyn_cast_or_null<func::FuncOp>( |
67 | SymbolTable::lookupNearestSymbolFrom(callOp, sym)); |
68 | } |
69 | |
70 | LogicalResult |
71 | mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) { |
72 | IRRewriter rewriter(module.getContext()); |
73 | |
74 | for (auto funcOp : module.getOps<func::FuncOp>()) { |
75 | if (funcOp.isExternal()) |
76 | continue; |
77 | func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); |
78 | // TODO: Support functions with multiple blocks. |
79 | if (!returnOp) |
80 | continue; |
81 | |
82 | // Compute erased results. |
83 | SmallVector<Value> newReturnValues; |
84 | BitVector erasedResultIndices(funcOp.getFunctionType().getNumResults()); |
85 | DenseMap<int64_t, int64_t> resultToArgs; |
86 | for (const auto &it : llvm::enumerate(returnOp.getOperands())) { |
87 | bool erased = false; |
88 | for (BlockArgument bbArg : funcOp.getArguments()) { |
89 | Value val = it.value(); |
90 | while (auto castOp = val.getDefiningOp<memref::CastOp>()) |
91 | val = castOp.getSource(); |
92 | |
93 | if (val == bbArg) { |
94 | resultToArgs[it.index()] = bbArg.getArgNumber(); |
95 | erased = true; |
96 | break; |
97 | } |
98 | } |
99 | |
100 | if (erased) { |
101 | erasedResultIndices.set(it.index()); |
102 | } else { |
103 | newReturnValues.push_back(it.value()); |
104 | } |
105 | } |
106 | |
107 | // Update function. |
108 | funcOp.eraseResults(erasedResultIndices); |
109 | returnOp.getOperandsMutable().assign(newReturnValues); |
110 | |
111 | // Update function calls. |
112 | module.walk([&](func::CallOp callOp) { |
113 | if (getCalledFunction(callOp) != funcOp) |
114 | return WalkResult::skip(); |
115 | |
116 | rewriter.setInsertionPoint(callOp); |
117 | auto newCallOp = rewriter.create<func::CallOp>(callOp.getLoc(), funcOp, |
118 | callOp.getOperands()); |
119 | SmallVector<Value> newResults; |
120 | int64_t nextResult = 0; |
121 | for (int64_t i = 0; i < callOp.getNumResults(); ++i) { |
122 | if (!resultToArgs.count(i)) { |
123 | // This result was not erased. |
124 | newResults.push_back(newCallOp.getResult(nextResult++)); |
125 | continue; |
126 | } |
127 | |
128 | // This result was erased. |
129 | Value replacement = callOp.getOperand(resultToArgs[i]); |
130 | Type expectedType = callOp.getResult(i).getType(); |
131 | if (replacement.getType() != expectedType) { |
132 | // A cast must be inserted at the call site. |
133 | replacement = rewriter.create<memref::CastOp>( |
134 | callOp.getLoc(), expectedType, replacement); |
135 | } |
136 | newResults.push_back(replacement); |
137 | } |
138 | rewriter.replaceOp(callOp, newResults); |
139 | return WalkResult::advance(); |
140 | }); |
141 | } |
142 | |
143 | return success(); |
144 | } |
145 | |
146 | namespace { |
147 | struct DropEquivalentBufferResultsPass |
148 | : bufferization::impl::DropEquivalentBufferResultsBase< |
149 | DropEquivalentBufferResultsPass> { |
150 | void runOnOperation() override { |
151 | if (failed(bufferization::dropEquivalentBufferResults(module: getOperation()))) |
152 | return signalPassFailure(); |
153 | } |
154 | }; |
155 | } // namespace |
156 | |
157 | std::unique_ptr<Pass> |
158 | mlir::bufferization::createDropEquivalentBufferResultsPass() { |
159 | return std::make_unique<DropEquivalentBufferResultsPass>(); |
160 | } |
161 | |