1//===- BufferResultsToOutParams.cpp - Calling convention conversion -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
10#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
11
12#include "mlir/Dialect/Func/IR/FuncOps.h"
13#include "mlir/Dialect/MemRef/IR/MemRef.h"
14#include "mlir/IR/Operation.h"
15#include "mlir/Pass/Pass.h"
16
17namespace mlir {
18namespace bufferization {
19#define GEN_PASS_DEF_BUFFERRESULTSTOOUTPARAMSPASS
20#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
21} // namespace bufferization
22} // namespace mlir
23
24using namespace mlir;
25using AllocationFn = bufferization::BufferResultsToOutParamsOpts::AllocationFn;
26using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;
27
28/// Return `true` if the given MemRef type has a fully dynamic layout.
29static bool hasFullyDynamicLayoutMap(MemRefType type) {
30 int64_t offset;
31 SmallVector<int64_t, 4> strides;
32 if (failed(type.getStridesAndOffset(strides, offset)))
33 return false;
34 if (!llvm::all_of(strides, ShapedType::isDynamic))
35 return false;
36 if (!ShapedType::isDynamic(offset))
37 return false;
38 return true;
39}
40
41/// Return `true` if the given MemRef type has a static identity layout (i.e.,
42/// no layout).
43static bool hasStaticIdentityLayout(MemRefType type) {
44 return type.getLayout().isIdentity();
45}
46
47// Updates the func op and entry block.
48//
49// Any args appended to the entry block are added to `appendedEntryArgs`.
50// If `addResultAttribute` is true, adds the unit attribute `bufferize.result`
51// to each newly created function argument.
52static LogicalResult
53updateFuncOp(func::FuncOp func,
54 SmallVectorImpl<BlockArgument> &appendedEntryArgs,
55 bool addResultAttribute) {
56 auto functionType = func.getFunctionType();
57
58 // Collect information about the results will become appended arguments.
59 SmallVector<Type, 6> erasedResultTypes;
60 BitVector erasedResultIndices(functionType.getNumResults());
61 for (const auto &resultType : llvm::enumerate(functionType.getResults())) {
62 if (auto memrefType = dyn_cast<MemRefType>(resultType.value())) {
63 if (!hasStaticIdentityLayout(memrefType) &&
64 !hasFullyDynamicLayoutMap(memrefType)) {
65 // Only buffers with static identity layout can be allocated. These can
66 // be casted to memrefs with fully dynamic layout map. Other layout maps
67 // are not supported.
68 return func->emitError()
69 << "cannot create out param for result with unsupported layout";
70 }
71 erasedResultIndices.set(resultType.index());
72 erasedResultTypes.push_back(memrefType);
73 }
74 }
75
76 // Add the new arguments to the function type.
77 auto newArgTypes = llvm::to_vector<6>(
78 llvm::concat<const Type>(functionType.getInputs(), erasedResultTypes));
79 auto newFunctionType = FunctionType::get(func.getContext(), newArgTypes,
80 functionType.getResults());
81 func.setType(newFunctionType);
82
83 // Transfer the result attributes to arg attributes.
84 auto erasedIndicesIt = erasedResultIndices.set_bits_begin();
85 for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) {
86 func.setArgAttrs(functionType.getNumInputs() + i,
87 func.getResultAttrs(*erasedIndicesIt));
88 if (addResultAttribute)
89 func.setArgAttr(functionType.getNumInputs() + i,
90 StringAttr::get(func.getContext(), "bufferize.result"),
91 UnitAttr::get(func.getContext()));
92 }
93
94 // Erase the results.
95 if (failed(func.eraseResults(erasedResultIndices)))
96 return failure();
97
98 // Add the new arguments to the entry block if the function is not external.
99 if (func.isExternal())
100 return success();
101 Location loc = func.getLoc();
102 for (Type type : erasedResultTypes)
103 appendedEntryArgs.push_back(Elt: func.front().addArgument(type, loc));
104
105 return success();
106}
107
108// Updates all ReturnOps in the scope of the given func::FuncOp by either
109// keeping them as return values or copying the associated buffer contents into
110// the given out-params.
111static LogicalResult
112updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
113 const bufferization::BufferResultsToOutParamsOpts &options) {
114 auto res = func.walk([&](func::ReturnOp op) {
115 SmallVector<Value, 6> copyIntoOutParams;
116 SmallVector<Value, 6> keepAsReturnOperands;
117 for (Value operand : op.getOperands()) {
118 if (isa<MemRefType>(operand.getType()))
119 copyIntoOutParams.push_back(operand);
120 else
121 keepAsReturnOperands.push_back(operand);
122 }
123 OpBuilder builder(op);
124 for (auto [orig, arg] : llvm::zip(t&: copyIntoOutParams, u&: appendedEntryArgs)) {
125 if (options.hoistStaticAllocs &&
126 isa_and_nonnull<bufferization::AllocationOpInterface>(
127 orig.getDefiningOp()) &&
128 mlir::cast<MemRefType>(orig.getType()).hasStaticShape()) {
129 orig.replaceAllUsesWith(newValue: arg);
130 orig.getDefiningOp()->erase();
131 } else {
132 if (failed(options.memCpyFn(builder, op.getLoc(), orig, arg)))
133 return WalkResult::interrupt();
134 }
135 }
136 builder.create<func::ReturnOp>(op.getLoc(), keepAsReturnOperands);
137 op.erase();
138 return WalkResult::advance();
139 });
140 return failure(res.wasInterrupted());
141}
142
143// Updates all CallOps in the scope of the given ModuleOp by allocating
144// temporary buffers for newly introduced out params.
145static LogicalResult
146updateCalls(ModuleOp module,
147 const bufferization::BufferResultsToOutParamsOpts &options) {
148 bool didFail = false;
149 SymbolTable symtab(module);
150 module.walk([&](func::CallOp op) {
151 auto callee = symtab.lookup<func::FuncOp>(op.getCallee());
152 if (!callee) {
153 op.emitError() << "cannot find callee '" << op.getCallee() << "' in "
154 << "symbol table";
155 didFail = true;
156 return;
157 }
158 if (!options.filterFn(&callee))
159 return;
160 SmallVector<Value, 6> replaceWithNewCallResults;
161 SmallVector<Value, 6> replaceWithOutParams;
162 for (OpResult result : op.getResults()) {
163 if (isa<MemRefType>(result.getType()))
164 replaceWithOutParams.push_back(result);
165 else
166 replaceWithNewCallResults.push_back(result);
167 }
168 SmallVector<Value, 6> outParams;
169 OpBuilder builder(op);
170 for (Value memref : replaceWithOutParams) {
171 if (!cast<MemRefType>(memref.getType()).hasStaticShape()) {
172 op.emitError()
173 << "cannot create out param for dynamically shaped result";
174 didFail = true;
175 return;
176 }
177 auto memrefType = cast<MemRefType>(memref.getType());
178 auto allocType =
179 MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
180 AffineMap(), memrefType.getMemorySpace());
181 auto maybeOutParam =
182 options.allocationFn(builder, op.getLoc(), allocType);
183 if (failed(maybeOutParam)) {
184 op.emitError() << "failed to create allocation op";
185 didFail = true;
186 return;
187 }
188 Value outParam = maybeOutParam.value();
189 if (!hasStaticIdentityLayout(memrefType)) {
190 // Layout maps are already checked in `updateFuncOp`.
191 assert(hasFullyDynamicLayoutMap(memrefType) &&
192 "layout map not supported");
193 outParam =
194 builder.create<memref::CastOp>(op.getLoc(), memrefType, outParam);
195 }
196 memref.replaceAllUsesWith(newValue: outParam);
197 outParams.push_back(Elt: outParam);
198 }
199
200 auto newOperands = llvm::to_vector<6>(op.getOperands());
201 newOperands.append(outParams.begin(), outParams.end());
202 auto newResultTypes = llvm::to_vector<6>(Range: llvm::map_range(
203 C&: replaceWithNewCallResults, F: [](Value v) { return v.getType(); }));
204 auto newCall = builder.create<func::CallOp>(op.getLoc(), op.getCalleeAttr(),
205 newResultTypes, newOperands);
206 for (auto t : llvm::zip(replaceWithNewCallResults, newCall.getResults()))
207 std::get<0>(t).replaceAllUsesWith(std::get<1>(t));
208 op.erase();
209 });
210
211 return failure(IsFailure: didFail);
212}
213
214LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
215 ModuleOp module,
216 const bufferization::BufferResultsToOutParamsOpts &options) {
217 for (auto func : module.getOps<func::FuncOp>()) {
218 if (!options.filterFn(&func))
219 continue;
220 SmallVector<BlockArgument, 6> appendedEntryArgs;
221 if (failed(
222 updateFuncOp(func, appendedEntryArgs, options.addResultAttribute)))
223 return failure();
224 if (func.isExternal())
225 continue;
226 if (failed(updateReturnOps(func, appendedEntryArgs, options))) {
227 return failure();
228 }
229 }
230 if (failed(updateCalls(module, options)))
231 return failure();
232 return success();
233}
234
235namespace {
236struct BufferResultsToOutParamsPass
237 : bufferization::impl::BufferResultsToOutParamsPassBase<
238 BufferResultsToOutParamsPass> {
239 using Base::Base;
240
241 void runOnOperation() override {
242 // Convert from pass options in tablegen to BufferResultsToOutParamsOpts.
243 if (addResultAttribute)
244 options.addResultAttribute = true;
245 if (hoistStaticAllocs)
246 options.hoistStaticAllocs = true;
247
248 if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(),
249 options)))
250 return signalPassFailure();
251 }
252
253private:
254 bufferization::BufferResultsToOutParamsOpts options;
255};
256} // namespace
257

source code of mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp