1//===- DropEquivalentBufferResults.cpp - Calling convention conversion ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass drops return values from functions if they are equivalent to one of
10// their arguments. E.g.:
11//
12// ```
13// func.func @foo(%m : memref<?xf32>) -> (memref<?xf32>) {
14// return %m : memref<?xf32>
15// }
16// ```
17//
// This function is rewritten to:
19//
20// ```
21// func.func @foo(%m : memref<?xf32>) {
22// return
23// }
24// ```
25//
26// All call sites are updated accordingly. If a function returns a cast of a
27// function argument, it is also considered equivalent. A cast is inserted at
28// the call site in that case.
29
30#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
31
32#include "mlir/Dialect/Func/IR/FuncOps.h"
33#include "mlir/Dialect/MemRef/IR/MemRef.h"
34#include "mlir/IR/Operation.h"
35#include "mlir/Pass/Pass.h"
36
37namespace mlir {
38namespace bufferization {
39#define GEN_PASS_DEF_DROPEQUIVALENTBUFFERRESULTSPASS
40#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
41} // namespace bufferization
42} // namespace mlir
43
44using namespace mlir;
45
46/// Return the unique ReturnOp that terminates `funcOp`.
47/// Return nullptr if there is no such unique ReturnOp.
48static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
49 func::ReturnOp returnOp;
50 for (Block &b : funcOp.getBody()) {
51 if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
52 if (returnOp)
53 return nullptr;
54 returnOp = candidateOp;
55 }
56 }
57 return returnOp;
58}
59
60/// Return the func::FuncOp called by `callOp`.
61static func::FuncOp getCalledFunction(CallOpInterface callOp) {
62 SymbolRefAttr sym =
63 llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
64 if (!sym)
65 return nullptr;
66 return dyn_cast_or_null<func::FuncOp>(
67 SymbolTable::lookupNearestSymbolFrom(callOp, sym));
68}
69
70LogicalResult
71mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
72 IRRewriter rewriter(module.getContext());
73
74 DenseMap<func::FuncOp, DenseSet<func::CallOp>> callerMap;
75 // Collect the mapping of functions to their call sites.
76 module.walk([&](func::CallOp callOp) {
77 if (func::FuncOp calledFunc = getCalledFunction(callOp)) {
78 callerMap[calledFunc].insert(callOp);
79 }
80 });
81
82 for (auto funcOp : module.getOps<func::FuncOp>()) {
83 if (funcOp.isExternal())
84 continue;
85 func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
86 // TODO: Support functions with multiple blocks.
87 if (!returnOp)
88 continue;
89
90 // Compute erased results.
91 SmallVector<Value> newReturnValues;
92 BitVector erasedResultIndices(funcOp.getFunctionType().getNumResults());
93 DenseMap<int64_t, int64_t> resultToArgs;
94 for (const auto &it : llvm::enumerate(returnOp.getOperands())) {
95 bool erased = false;
96 for (BlockArgument bbArg : funcOp.getArguments()) {
97 Value val = it.value();
98 while (auto castOp = val.getDefiningOp<memref::CastOp>())
99 val = castOp.getSource();
100
101 if (val == bbArg) {
102 resultToArgs[it.index()] = bbArg.getArgNumber();
103 erased = true;
104 break;
105 }
106 }
107
108 if (erased) {
109 erasedResultIndices.set(it.index());
110 } else {
111 newReturnValues.push_back(it.value());
112 }
113 }
114
115 // Update function.
116 if (failed(funcOp.eraseResults(erasedResultIndices)))
117 return failure();
118 returnOp.getOperandsMutable().assign(newReturnValues);
119
120 // Update function calls.
121 for (func::CallOp callOp : callerMap[funcOp]) {
122 rewriter.setInsertionPoint(callOp);
123 auto newCallOp = rewriter.create<func::CallOp>(callOp.getLoc(), funcOp,
124 callOp.getOperands());
125 SmallVector<Value> newResults;
126 int64_t nextResult = 0;
127 for (int64_t i = 0; i < callOp.getNumResults(); ++i) {
128 if (!resultToArgs.count(i)) {
129 // This result was not erased.
130 newResults.push_back(newCallOp.getResult(nextResult++));
131 continue;
132 }
133
134 // This result was erased.
135 Value replacement = callOp.getOperand(resultToArgs[i]);
136 Type expectedType = callOp.getResult(i).getType();
137 if (replacement.getType() != expectedType) {
138 // A cast must be inserted at the call site.
139 replacement = rewriter.create<memref::CastOp>(
140 callOp.getLoc(), expectedType, replacement);
141 }
142 newResults.push_back(replacement);
143 }
144 rewriter.replaceOp(callOp, newResults);
145 }
146 }
147
148 return success();
149}
150
151namespace {
152struct DropEquivalentBufferResultsPass
153 : bufferization::impl::DropEquivalentBufferResultsPassBase<
154 DropEquivalentBufferResultsPass> {
155 void runOnOperation() override {
156 if (failed(bufferization::dropEquivalentBufferResults(module: getOperation())))
157 return signalPassFailure();
158 }
159};
160} // namespace
161

source code of mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp