1 | //===- DropEquivalentBufferResults.cpp - Calling convention conversion ----===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This pass drops return values from functions if they are equivalent to one of |
10 | // their arguments. E.g.: |
11 | // |
12 | // ``` |
13 | // func.func @foo(%m : memref<?xf32>) -> (memref<?xf32>) { |
14 | // return %m : memref<?xf32> |
15 | // } |
16 | // ``` |
17 | // |
// This function is rewritten to:
19 | // |
20 | // ``` |
21 | // func.func @foo(%m : memref<?xf32>) { |
22 | // return |
23 | // } |
24 | // ``` |
25 | // |
26 | // All call sites are updated accordingly. If a function returns a cast of a |
27 | // function argument, it is also considered equivalent. A cast is inserted at |
28 | // the call site in that case. |
29 | |
30 | #include "mlir/Dialect/Bufferization/Transforms/Passes.h" |
31 | |
32 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
33 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
34 | #include "mlir/IR/Operation.h" |
35 | #include "mlir/Pass/Pass.h" |
36 | |
37 | namespace mlir { |
38 | namespace bufferization { |
39 | #define GEN_PASS_DEF_DROPEQUIVALENTBUFFERRESULTSPASS |
40 | #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc" |
41 | } // namespace bufferization |
42 | } // namespace mlir |
43 | |
44 | using namespace mlir; |
45 | |
46 | /// Return the unique ReturnOp that terminates `funcOp`. |
47 | /// Return nullptr if there is no such unique ReturnOp. |
48 | static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) { |
49 | func::ReturnOp returnOp; |
50 | for (Block &b : funcOp.getBody()) { |
51 | if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) { |
52 | if (returnOp) |
53 | return nullptr; |
54 | returnOp = candidateOp; |
55 | } |
56 | } |
57 | return returnOp; |
58 | } |
59 | |
60 | /// Return the func::FuncOp called by `callOp`. |
61 | static func::FuncOp getCalledFunction(CallOpInterface callOp) { |
62 | SymbolRefAttr sym = |
63 | llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee()); |
64 | if (!sym) |
65 | return nullptr; |
66 | return dyn_cast_or_null<func::FuncOp>( |
67 | SymbolTable::lookupNearestSymbolFrom(callOp, sym)); |
68 | } |
69 | |
70 | LogicalResult |
71 | mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) { |
72 | IRRewriter rewriter(module.getContext()); |
73 | |
74 | DenseMap<func::FuncOp, DenseSet<func::CallOp>> callerMap; |
75 | // Collect the mapping of functions to their call sites. |
76 | module.walk([&](func::CallOp callOp) { |
77 | if (func::FuncOp calledFunc = getCalledFunction(callOp)) { |
78 | callerMap[calledFunc].insert(callOp); |
79 | } |
80 | }); |
81 | |
82 | for (auto funcOp : module.getOps<func::FuncOp>()) { |
83 | if (funcOp.isExternal()) |
84 | continue; |
85 | func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); |
86 | // TODO: Support functions with multiple blocks. |
87 | if (!returnOp) |
88 | continue; |
89 | |
90 | // Compute erased results. |
91 | SmallVector<Value> newReturnValues; |
92 | BitVector erasedResultIndices(funcOp.getFunctionType().getNumResults()); |
93 | DenseMap<int64_t, int64_t> resultToArgs; |
94 | for (const auto &it : llvm::enumerate(returnOp.getOperands())) { |
95 | bool erased = false; |
96 | for (BlockArgument bbArg : funcOp.getArguments()) { |
97 | Value val = it.value(); |
98 | while (auto castOp = val.getDefiningOp<memref::CastOp>()) |
99 | val = castOp.getSource(); |
100 | |
101 | if (val == bbArg) { |
102 | resultToArgs[it.index()] = bbArg.getArgNumber(); |
103 | erased = true; |
104 | break; |
105 | } |
106 | } |
107 | |
108 | if (erased) { |
109 | erasedResultIndices.set(it.index()); |
110 | } else { |
111 | newReturnValues.push_back(it.value()); |
112 | } |
113 | } |
114 | |
115 | // Update function. |
116 | if (failed(funcOp.eraseResults(erasedResultIndices))) |
117 | return failure(); |
118 | returnOp.getOperandsMutable().assign(newReturnValues); |
119 | |
120 | // Update function calls. |
121 | for (func::CallOp callOp : callerMap[funcOp]) { |
122 | rewriter.setInsertionPoint(callOp); |
123 | auto newCallOp = rewriter.create<func::CallOp>(callOp.getLoc(), funcOp, |
124 | callOp.getOperands()); |
125 | SmallVector<Value> newResults; |
126 | int64_t nextResult = 0; |
127 | for (int64_t i = 0; i < callOp.getNumResults(); ++i) { |
128 | if (!resultToArgs.count(i)) { |
129 | // This result was not erased. |
130 | newResults.push_back(newCallOp.getResult(nextResult++)); |
131 | continue; |
132 | } |
133 | |
134 | // This result was erased. |
135 | Value replacement = callOp.getOperand(resultToArgs[i]); |
136 | Type expectedType = callOp.getResult(i).getType(); |
137 | if (replacement.getType() != expectedType) { |
138 | // A cast must be inserted at the call site. |
139 | replacement = rewriter.create<memref::CastOp>( |
140 | callOp.getLoc(), expectedType, replacement); |
141 | } |
142 | newResults.push_back(replacement); |
143 | } |
144 | rewriter.replaceOp(callOp, newResults); |
145 | } |
146 | } |
147 | |
148 | return success(); |
149 | } |
150 | |
151 | namespace { |
152 | struct DropEquivalentBufferResultsPass |
153 | : bufferization::impl::DropEquivalentBufferResultsPassBase< |
154 | DropEquivalentBufferResultsPass> { |
155 | void runOnOperation() override { |
156 | if (failed(bufferization::dropEquivalentBufferResults(module: getOperation()))) |
157 | return signalPassFailure(); |
158 | } |
159 | }; |
160 | } // namespace |
161 | |