1//===- BufferResultsToOutParams.cpp - Calling convention conversion -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
10
11#include "mlir/Dialect/Func/IR/FuncOps.h"
12#include "mlir/Dialect/MemRef/IR/MemRef.h"
13#include "mlir/IR/Operation.h"
14#include "mlir/Pass/Pass.h"
15
16namespace mlir {
17namespace bufferization {
18#define GEN_PASS_DEF_BUFFERRESULTSTOOUTPARAMS
19#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
20} // namespace bufferization
21} // namespace mlir
22
23using namespace mlir;
24using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;
25
26/// Return `true` if the given MemRef type has a fully dynamic layout.
27static bool hasFullyDynamicLayoutMap(MemRefType type) {
28 int64_t offset;
29 SmallVector<int64_t, 4> strides;
30 if (failed(getStridesAndOffset(type, strides, offset)))
31 return false;
32 if (!llvm::all_of(strides, ShapedType::isDynamic))
33 return false;
34 if (!ShapedType::isDynamic(offset))
35 return false;
36 return true;
37}
38
39/// Return `true` if the given MemRef type has a static identity layout (i.e.,
40/// no layout).
41static bool hasStaticIdentityLayout(MemRefType type) {
42 return type.getLayout().isIdentity();
43}
44
45// Updates the func op and entry block.
46//
47// Any args appended to the entry block are added to `appendedEntryArgs`.
48// If `addResultAttribute` is true, adds the unit attribute `bufferize.result`
49// to each newly created function argument.
50static LogicalResult
51updateFuncOp(func::FuncOp func,
52 SmallVectorImpl<BlockArgument> &appendedEntryArgs,
53 bool addResultAttribute) {
54 auto functionType = func.getFunctionType();
55
56 // Collect information about the results will become appended arguments.
57 SmallVector<Type, 6> erasedResultTypes;
58 BitVector erasedResultIndices(functionType.getNumResults());
59 for (const auto &resultType : llvm::enumerate(functionType.getResults())) {
60 if (auto memrefType = dyn_cast<MemRefType>(resultType.value())) {
61 if (!hasStaticIdentityLayout(memrefType) &&
62 !hasFullyDynamicLayoutMap(memrefType)) {
63 // Only buffers with static identity layout can be allocated. These can
64 // be casted to memrefs with fully dynamic layout map. Other layout maps
65 // are not supported.
66 return func->emitError()
67 << "cannot create out param for result with unsupported layout";
68 }
69 erasedResultIndices.set(resultType.index());
70 erasedResultTypes.push_back(memrefType);
71 }
72 }
73
74 // Add the new arguments to the function type.
75 auto newArgTypes = llvm::to_vector<6>(
76 llvm::concat<const Type>(functionType.getInputs(), erasedResultTypes));
77 auto newFunctionType = FunctionType::get(func.getContext(), newArgTypes,
78 functionType.getResults());
79 func.setType(newFunctionType);
80
81 // Transfer the result attributes to arg attributes.
82 auto erasedIndicesIt = erasedResultIndices.set_bits_begin();
83 for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) {
84 func.setArgAttrs(functionType.getNumInputs() + i,
85 func.getResultAttrs(*erasedIndicesIt));
86 if (addResultAttribute)
87 func.setArgAttr(functionType.getNumInputs() + i,
88 StringAttr::get(func.getContext(), "bufferize.result"),
89 UnitAttr::get(func.getContext()));
90 }
91
92 // Erase the results.
93 func.eraseResults(erasedResultIndices);
94
95 // Add the new arguments to the entry block if the function is not external.
96 if (func.isExternal())
97 return success();
98 Location loc = func.getLoc();
99 for (Type type : erasedResultTypes)
100 appendedEntryArgs.push_back(Elt: func.front().addArgument(type, loc));
101
102 return success();
103}
104
105// Updates all ReturnOps in the scope of the given func::FuncOp by either
106// keeping them as return values or copying the associated buffer contents into
107// the given out-params.
108static LogicalResult updateReturnOps(func::FuncOp func,
109 ArrayRef<BlockArgument> appendedEntryArgs,
110 MemCpyFn memCpyFn) {
111 auto res = func.walk([&](func::ReturnOp op) {
112 SmallVector<Value, 6> copyIntoOutParams;
113 SmallVector<Value, 6> keepAsReturnOperands;
114 for (Value operand : op.getOperands()) {
115 if (isa<MemRefType>(operand.getType()))
116 copyIntoOutParams.push_back(operand);
117 else
118 keepAsReturnOperands.push_back(operand);
119 }
120 OpBuilder builder(op);
121 for (auto t : llvm::zip(t&: copyIntoOutParams, u&: appendedEntryArgs)) {
122 if (failed(
123 memCpyFn(builder, op.getLoc(), std::get<0>(t&: t), std::get<1>(t&: t))))
124 return WalkResult::interrupt();
125 }
126 builder.create<func::ReturnOp>(op.getLoc(), keepAsReturnOperands);
127 op.erase();
128 return WalkResult::advance();
129 });
130 return failure(res.wasInterrupted());
131}
132
133// Updates all CallOps in the scope of the given ModuleOp by allocating
134// temporary buffers for newly introduced out params.
135static LogicalResult
136updateCalls(ModuleOp module,
137 const bufferization::BufferResultsToOutParamsOpts &options) {
138 bool didFail = false;
139 SymbolTable symtab(module);
140 module.walk([&](func::CallOp op) {
141 auto callee = symtab.lookup<func::FuncOp>(op.getCallee());
142 if (!callee) {
143 op.emitError() << "cannot find callee '" << op.getCallee() << "' in "
144 << "symbol table";
145 didFail = true;
146 return;
147 }
148 if (!options.filterFn(&callee))
149 return;
150 SmallVector<Value, 6> replaceWithNewCallResults;
151 SmallVector<Value, 6> replaceWithOutParams;
152 for (OpResult result : op.getResults()) {
153 if (isa<MemRefType>(result.getType()))
154 replaceWithOutParams.push_back(result);
155 else
156 replaceWithNewCallResults.push_back(result);
157 }
158 SmallVector<Value, 6> outParams;
159 OpBuilder builder(op);
160 for (Value memref : replaceWithOutParams) {
161 if (!cast<MemRefType>(memref.getType()).hasStaticShape()) {
162 op.emitError()
163 << "cannot create out param for dynamically shaped result";
164 didFail = true;
165 return;
166 }
167 auto memrefType = cast<MemRefType>(memref.getType());
168 auto allocType =
169 MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
170 AffineMap(), memrefType.getMemorySpace());
171 Value outParam = builder.create<memref::AllocOp>(op.getLoc(), allocType);
172 if (!hasStaticIdentityLayout(memrefType)) {
173 // Layout maps are already checked in `updateFuncOp`.
174 assert(hasFullyDynamicLayoutMap(memrefType) &&
175 "layout map not supported");
176 outParam =
177 builder.create<memref::CastOp>(op.getLoc(), memrefType, outParam);
178 }
179 memref.replaceAllUsesWith(newValue: outParam);
180 outParams.push_back(Elt: outParam);
181 }
182
183 auto newOperands = llvm::to_vector<6>(op.getOperands());
184 newOperands.append(outParams.begin(), outParams.end());
185 auto newResultTypes = llvm::to_vector<6>(Range: llvm::map_range(
186 C&: replaceWithNewCallResults, F: [](Value v) { return v.getType(); }));
187 auto newCall = builder.create<func::CallOp>(op.getLoc(), op.getCalleeAttr(),
188 newResultTypes, newOperands);
189 for (auto t : llvm::zip(replaceWithNewCallResults, newCall.getResults()))
190 std::get<0>(t).replaceAllUsesWith(std::get<1>(t));
191 op.erase();
192 });
193
194 return failure(isFailure: didFail);
195}
196
197LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
198 ModuleOp module,
199 const bufferization::BufferResultsToOutParamsOpts &options) {
200 for (auto func : module.getOps<func::FuncOp>()) {
201 if (!options.filterFn(&func))
202 continue;
203 SmallVector<BlockArgument, 6> appendedEntryArgs;
204 if (failed(
205 updateFuncOp(func, appendedEntryArgs, options.addResultAttribute)))
206 return failure();
207 if (func.isExternal())
208 continue;
209 auto defaultMemCpyFn = [](OpBuilder &builder, Location loc, Value from,
210 Value to) {
211 builder.create<memref::CopyOp>(loc, from, to);
212 return success();
213 };
214 if (failed(updateReturnOps(func, appendedEntryArgs,
215 options.memCpyFn.value_or(defaultMemCpyFn)))) {
216 return failure();
217 }
218 }
219 if (failed(updateCalls(module, options)))
220 return failure();
221 return success();
222}
223
224namespace {
225struct BufferResultsToOutParamsPass
226 : bufferization::impl::BufferResultsToOutParamsBase<
227 BufferResultsToOutParamsPass> {
228 explicit BufferResultsToOutParamsPass(
229 const bufferization::BufferResultsToOutParamsOpts &options)
230 : options(options) {}
231
232 void runOnOperation() override {
233 // Convert from pass options in tablegen to BufferResultsToOutParamsOpts.
234 if (addResultAttribute)
235 options.addResultAttribute = true;
236
237 if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(),
238 options)))
239 return signalPassFailure();
240 }
241
242private:
243 bufferization::BufferResultsToOutParamsOpts options;
244};
245} // namespace
246
247std::unique_ptr<Pass> mlir::bufferization::createBufferResultsToOutParamsPass(
248 const bufferization::BufferResultsToOutParamsOpts &options) {
249 return std::make_unique<BufferResultsToOutParamsPass>(args: options);
250}
251

source code of mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp