1//===- DropEquivalentBufferResults.cpp - Calling convention conversion ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass drops return values from functions if they are equivalent to one of
10// their arguments. E.g.:
11//
12// ```
13// func.func @foo(%m : memref<?xf32>) -> (memref<?xf32>) {
14// return %m : memref<?xf32>
15// }
16// ```
17//
// This function is rewritten to:
19//
20// ```
21// func.func @foo(%m : memref<?xf32>) {
22// return
23// }
24// ```
25//
26// All call sites are updated accordingly. If a function returns a cast of a
27// function argument, it is also considered equivalent. A cast is inserted at
28// the call site in that case.
29
30#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
31
32#include "mlir/Dialect/Func/IR/FuncOps.h"
33#include "mlir/Dialect/MemRef/IR/MemRef.h"
34
35namespace mlir {
36namespace bufferization {
37#define GEN_PASS_DEF_DROPEQUIVALENTBUFFERRESULTSPASS
38#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
39} // namespace bufferization
40} // namespace mlir
41
42using namespace mlir;
43
44/// Return the unique ReturnOp that terminates `funcOp`.
45/// Return nullptr if there is no such unique ReturnOp.
46static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
47 func::ReturnOp returnOp;
48 for (Block &b : funcOp.getBody()) {
49 if (auto candidateOp = dyn_cast<func::ReturnOp>(Val: b.getTerminator())) {
50 if (returnOp)
51 return nullptr;
52 returnOp = candidateOp;
53 }
54 }
55 return returnOp;
56}
57
58/// Return the func::FuncOp called by `callOp`.
59static func::FuncOp getCalledFunction(CallOpInterface callOp) {
60 SymbolRefAttr sym =
61 llvm::dyn_cast_if_present<SymbolRefAttr>(Val: callOp.getCallableForCallee());
62 if (!sym)
63 return nullptr;
64 return dyn_cast_or_null<func::FuncOp>(
65 Val: SymbolTable::lookupNearestSymbolFrom(from: callOp, symbol: sym));
66}
67
68LogicalResult
69mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
70 IRRewriter rewriter(module.getContext());
71
72 DenseMap<func::FuncOp, DenseSet<func::CallOp>> callerMap;
73 // Collect the mapping of functions to their call sites.
74 module.walk(callback: [&](func::CallOp callOp) {
75 if (func::FuncOp calledFunc = getCalledFunction(callOp)) {
76 callerMap[calledFunc].insert(V: callOp);
77 }
78 });
79
80 for (auto funcOp : module.getOps<func::FuncOp>()) {
81 if (funcOp.isExternal())
82 continue;
83 func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
84 // TODO: Support functions with multiple blocks.
85 if (!returnOp)
86 continue;
87
88 // Compute erased results.
89 SmallVector<Value> newReturnValues;
90 BitVector erasedResultIndices(funcOp.getFunctionType().getNumResults());
91 DenseMap<int64_t, int64_t> resultToArgs;
92 for (const auto &it : llvm::enumerate(First: returnOp.getOperands())) {
93 bool erased = false;
94 for (BlockArgument bbArg : funcOp.getArguments()) {
95 Value val = it.value();
96 while (auto castOp = val.getDefiningOp<memref::CastOp>())
97 val = castOp.getSource();
98
99 if (val == bbArg) {
100 resultToArgs[it.index()] = bbArg.getArgNumber();
101 erased = true;
102 break;
103 }
104 }
105
106 if (erased) {
107 erasedResultIndices.set(it.index());
108 } else {
109 newReturnValues.push_back(Elt: it.value());
110 }
111 }
112
113 // Update function.
114 if (failed(Result: funcOp.eraseResults(resultIndices: erasedResultIndices)))
115 return failure();
116 returnOp.getOperandsMutable().assign(values: newReturnValues);
117
118 // Update function calls.
119 for (func::CallOp callOp : callerMap[funcOp]) {
120 rewriter.setInsertionPoint(callOp);
121 auto newCallOp = rewriter.create<func::CallOp>(location: callOp.getLoc(), args&: funcOp,
122 args: callOp.getOperands());
123 SmallVector<Value> newResults;
124 int64_t nextResult = 0;
125 for (int64_t i = 0; i < callOp.getNumResults(); ++i) {
126 if (!resultToArgs.count(Val: i)) {
127 // This result was not erased.
128 newResults.push_back(Elt: newCallOp.getResult(i: nextResult++));
129 continue;
130 }
131
132 // This result was erased.
133 Value replacement = callOp.getOperand(i: resultToArgs[i]);
134 Type expectedType = callOp.getResult(i).getType();
135 if (replacement.getType() != expectedType) {
136 // A cast must be inserted at the call site.
137 replacement = rewriter.create<memref::CastOp>(
138 location: callOp.getLoc(), args&: expectedType, args&: replacement);
139 }
140 newResults.push_back(Elt: replacement);
141 }
142 rewriter.replaceOp(op: callOp, newValues: newResults);
143 }
144 }
145
146 return success();
147}
148
149namespace {
150struct DropEquivalentBufferResultsPass
151 : bufferization::impl::DropEquivalentBufferResultsPassBase<
152 DropEquivalentBufferResultsPass> {
153 void runOnOperation() override {
154 if (failed(Result: bufferization::dropEquivalentBufferResults(module: getOperation())))
155 return signalPassFailure();
156 }
157};
158} // namespace
159

source code of mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp