1//===- BufferResultsToOutParams.cpp - Calling convention conversion -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
10#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
11
12#include "mlir/Dialect/Func/IR/FuncOps.h"
13#include "mlir/Dialect/MemRef/IR/MemRef.h"
14#include "mlir/IR/Operation.h"
15
16namespace mlir {
17namespace bufferization {
18#define GEN_PASS_DEF_BUFFERRESULTSTOOUTPARAMSPASS
19#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
20} // namespace bufferization
21} // namespace mlir
22
23using namespace mlir;
24using AllocationFn = bufferization::BufferResultsToOutParamsOpts::AllocationFn;
25using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;
26
27/// Return `true` if the given MemRef type has a fully dynamic layout.
28static bool hasFullyDynamicLayoutMap(MemRefType type) {
29 int64_t offset;
30 SmallVector<int64_t, 4> strides;
31 if (failed(Result: type.getStridesAndOffset(strides, offset)))
32 return false;
33 if (!llvm::all_of(Range&: strides, P: ShapedType::isDynamic))
34 return false;
35 if (ShapedType::isStatic(dValue: offset))
36 return false;
37 return true;
38}
39
40/// Return `true` if the given MemRef type has a static identity layout (i.e.,
41/// no layout).
42static bool hasStaticIdentityLayout(MemRefType type) {
43 return type.getLayout().isIdentity();
44}
45
46// Updates the func op and entry block.
47//
48// Any args appended to the entry block are added to `appendedEntryArgs`.
49// If `addResultAttribute` is true, adds the unit attribute `bufferize.result`
50// to each newly created function argument.
51static LogicalResult
52updateFuncOp(func::FuncOp func,
53 SmallVectorImpl<BlockArgument> &appendedEntryArgs,
54 bool addResultAttribute) {
55 auto functionType = func.getFunctionType();
56
57 // Collect information about the results will become appended arguments.
58 SmallVector<Type, 6> erasedResultTypes;
59 BitVector erasedResultIndices(functionType.getNumResults());
60 for (const auto &resultType : llvm::enumerate(First: functionType.getResults())) {
61 if (auto memrefType = dyn_cast<MemRefType>(Val: resultType.value())) {
62 if (!hasStaticIdentityLayout(type: memrefType) &&
63 !hasFullyDynamicLayoutMap(type: memrefType)) {
64 // Only buffers with static identity layout can be allocated. These can
65 // be casted to memrefs with fully dynamic layout map. Other layout maps
66 // are not supported.
67 return func->emitError()
68 << "cannot create out param for result with unsupported layout";
69 }
70 erasedResultIndices.set(resultType.index());
71 erasedResultTypes.push_back(Elt: memrefType);
72 }
73 }
74
75 // Add the new arguments to the function type.
76 auto newArgTypes = llvm::to_vector<6>(
77 Range: llvm::concat<const Type>(Ranges: functionType.getInputs(), Ranges&: erasedResultTypes));
78 auto newFunctionType = FunctionType::get(context: func.getContext(), inputs: newArgTypes,
79 results: functionType.getResults());
80 func.setType(newFunctionType);
81
82 // Transfer the result attributes to arg attributes.
83 auto erasedIndicesIt = erasedResultIndices.set_bits_begin();
84 for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) {
85 func.setArgAttrs(index: functionType.getNumInputs() + i,
86 attributes: func.getResultAttrs(index: *erasedIndicesIt));
87 if (addResultAttribute)
88 func.setArgAttr(index: functionType.getNumInputs() + i,
89 name: StringAttr::get(context: func.getContext(), bytes: "bufferize.result"),
90 value: UnitAttr::get(context: func.getContext()));
91 }
92
93 // Erase the results.
94 if (failed(Result: func.eraseResults(resultIndices: erasedResultIndices)))
95 return failure();
96
97 // Add the new arguments to the entry block if the function is not external.
98 if (func.isExternal())
99 return success();
100 Location loc = func.getLoc();
101 for (Type type : erasedResultTypes)
102 appendedEntryArgs.push_back(Elt: func.front().addArgument(type, loc));
103
104 return success();
105}
106
107// Updates all ReturnOps in the scope of the given func::FuncOp by either
108// keeping them as return values or copying the associated buffer contents into
109// the given out-params.
110static LogicalResult
111updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
112 const bufferization::BufferResultsToOutParamsOpts &options) {
113 auto res = func.walk(callback: [&](func::ReturnOp op) {
114 SmallVector<Value, 6> copyIntoOutParams;
115 SmallVector<Value, 6> keepAsReturnOperands;
116 for (Value operand : op.getOperands()) {
117 if (isa<MemRefType>(Val: operand.getType()))
118 copyIntoOutParams.push_back(Elt: operand);
119 else
120 keepAsReturnOperands.push_back(Elt: operand);
121 }
122 OpBuilder builder(op);
123 for (auto [orig, arg] : llvm::zip(t&: copyIntoOutParams, u&: appendedEntryArgs)) {
124 if (options.hoistStaticAllocs &&
125 isa_and_nonnull<bufferization::AllocationOpInterface>(
126 Val: orig.getDefiningOp()) &&
127 mlir::cast<MemRefType>(Val: orig.getType()).hasStaticShape()) {
128 orig.replaceAllUsesWith(newValue: arg);
129 orig.getDefiningOp()->erase();
130 } else {
131 if (failed(Result: options.memCpyFn(builder, op.getLoc(), orig, arg)))
132 return WalkResult::interrupt();
133 }
134 }
135 builder.create<func::ReturnOp>(location: op.getLoc(), args&: keepAsReturnOperands);
136 op.erase();
137 return WalkResult::advance();
138 });
139 return failure(IsFailure: res.wasInterrupted());
140}
141
142// Updates all CallOps in the scope of the given ModuleOp by allocating
143// temporary buffers for newly introduced out params.
144static LogicalResult
145updateCalls(ModuleOp module,
146 const bufferization::BufferResultsToOutParamsOpts &options) {
147 bool didFail = false;
148 SymbolTable symtab(module);
149 module.walk(callback: [&](func::CallOp op) {
150 auto callee = symtab.lookup<func::FuncOp>(name: op.getCallee());
151 if (!callee) {
152 op.emitError() << "cannot find callee '" << op.getCallee() << "' in "
153 << "symbol table";
154 didFail = true;
155 return;
156 }
157 if (!options.filterFn(&callee))
158 return;
159 SmallVector<Value, 6> replaceWithNewCallResults;
160 SmallVector<Value, 6> replaceWithOutParams;
161 for (OpResult result : op.getResults()) {
162 if (isa<MemRefType>(Val: result.getType()))
163 replaceWithOutParams.push_back(Elt: result);
164 else
165 replaceWithNewCallResults.push_back(Elt: result);
166 }
167 SmallVector<Value, 6> outParams;
168 OpBuilder builder(op);
169 for (Value memref : replaceWithOutParams) {
170 if (!cast<MemRefType>(Val: memref.getType()).hasStaticShape()) {
171 op.emitError()
172 << "cannot create out param for dynamically shaped result";
173 didFail = true;
174 return;
175 }
176 auto memrefType = cast<MemRefType>(Val: memref.getType());
177 auto allocType =
178 MemRefType::get(shape: memrefType.getShape(), elementType: memrefType.getElementType(),
179 map: AffineMap(), memorySpace: memrefType.getMemorySpace());
180 auto maybeOutParam =
181 options.allocationFn(builder, op.getLoc(), allocType);
182 if (failed(Result: maybeOutParam)) {
183 op.emitError() << "failed to create allocation op";
184 didFail = true;
185 return;
186 }
187 Value outParam = maybeOutParam.value();
188 if (!hasStaticIdentityLayout(type: memrefType)) {
189 // Layout maps are already checked in `updateFuncOp`.
190 assert(hasFullyDynamicLayoutMap(memrefType) &&
191 "layout map not supported");
192 outParam =
193 builder.create<memref::CastOp>(location: op.getLoc(), args&: memrefType, args&: outParam);
194 }
195 memref.replaceAllUsesWith(newValue: outParam);
196 outParams.push_back(Elt: outParam);
197 }
198
199 auto newOperands = llvm::to_vector<6>(Range: op.getOperands());
200 newOperands.append(in_start: outParams.begin(), in_end: outParams.end());
201 auto newResultTypes = llvm::to_vector<6>(Range: llvm::map_range(
202 C&: replaceWithNewCallResults, F: [](Value v) { return v.getType(); }));
203 auto newCall = builder.create<func::CallOp>(location: op.getLoc(), args: op.getCalleeAttr(),
204 args&: newResultTypes, args&: newOperands);
205 for (auto t : llvm::zip(t&: replaceWithNewCallResults, u: newCall.getResults()))
206 std::get<0>(t&: t).replaceAllUsesWith(newValue: std::get<1>(t&: t));
207 op.erase();
208 });
209
210 return failure(IsFailure: didFail);
211}
212
213LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
214 ModuleOp module,
215 const bufferization::BufferResultsToOutParamsOpts &options) {
216 for (auto func : module.getOps<func::FuncOp>()) {
217 if (!options.filterFn(&func))
218 continue;
219 SmallVector<BlockArgument, 6> appendedEntryArgs;
220 if (failed(
221 Result: updateFuncOp(func, appendedEntryArgs, addResultAttribute: options.addResultAttribute)))
222 return failure();
223 if (func.isExternal())
224 continue;
225 if (failed(Result: updateReturnOps(func, appendedEntryArgs, options))) {
226 return failure();
227 }
228 }
229 if (failed(Result: updateCalls(module, options)))
230 return failure();
231 return success();
232}
233
234namespace {
235struct BufferResultsToOutParamsPass
236 : bufferization::impl::BufferResultsToOutParamsPassBase<
237 BufferResultsToOutParamsPass> {
238 using Base::Base;
239
240 void runOnOperation() override {
241 // Convert from pass options in tablegen to BufferResultsToOutParamsOpts.
242 if (addResultAttribute)
243 options.addResultAttribute = true;
244 if (hoistStaticAllocs)
245 options.hoistStaticAllocs = true;
246
247 if (failed(Result: bufferization::promoteBufferResultsToOutParams(module: getOperation(),
248 options)))
249 return signalPassFailure();
250 }
251
252private:
253 bufferization::BufferResultsToOutParamsOpts options;
254};
255} // namespace
256

source code of mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp