1 | //===- BufferResultsToOutParams.cpp - Calling convention conversion -------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h" |
10 | #include "mlir/Dialect/Bufferization/Transforms/Passes.h" |
11 | |
12 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
13 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
14 | #include "mlir/IR/Operation.h" |
15 | #include "mlir/Pass/Pass.h" |
16 | |
17 | namespace mlir { |
18 | namespace bufferization { |
19 | #define GEN_PASS_DEF_BUFFERRESULTSTOOUTPARAMSPASS |
20 | #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc" |
21 | } // namespace bufferization |
22 | } // namespace mlir |
23 | |
24 | using namespace mlir; |
25 | using AllocationFn = bufferization::BufferResultsToOutParamsOpts::AllocationFn; |
26 | using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn; |
27 | |
28 | /// Return `true` if the given MemRef type has a fully dynamic layout. |
29 | static bool hasFullyDynamicLayoutMap(MemRefType type) { |
30 | int64_t offset; |
31 | SmallVector<int64_t, 4> strides; |
32 | if (failed(type.getStridesAndOffset(strides, offset))) |
33 | return false; |
34 | if (!llvm::all_of(strides, ShapedType::isDynamic)) |
35 | return false; |
36 | if (!ShapedType::isDynamic(offset)) |
37 | return false; |
38 | return true; |
39 | } |
40 | |
41 | /// Return `true` if the given MemRef type has a static identity layout (i.e., |
42 | /// no layout). |
43 | static bool hasStaticIdentityLayout(MemRefType type) { |
44 | return type.getLayout().isIdentity(); |
45 | } |
46 | |
47 | // Updates the func op and entry block. |
48 | // |
49 | // Any args appended to the entry block are added to `appendedEntryArgs`. |
50 | // If `addResultAttribute` is true, adds the unit attribute `bufferize.result` |
51 | // to each newly created function argument. |
52 | static LogicalResult |
53 | updateFuncOp(func::FuncOp func, |
54 | SmallVectorImpl<BlockArgument> &appendedEntryArgs, |
55 | bool addResultAttribute) { |
56 | auto functionType = func.getFunctionType(); |
57 | |
58 | // Collect information about the results will become appended arguments. |
59 | SmallVector<Type, 6> erasedResultTypes; |
60 | BitVector erasedResultIndices(functionType.getNumResults()); |
61 | for (const auto &resultType : llvm::enumerate(functionType.getResults())) { |
62 | if (auto memrefType = dyn_cast<MemRefType>(resultType.value())) { |
63 | if (!hasStaticIdentityLayout(memrefType) && |
64 | !hasFullyDynamicLayoutMap(memrefType)) { |
65 | // Only buffers with static identity layout can be allocated. These can |
66 | // be casted to memrefs with fully dynamic layout map. Other layout maps |
67 | // are not supported. |
68 | return func->emitError() |
69 | << "cannot create out param for result with unsupported layout" ; |
70 | } |
71 | erasedResultIndices.set(resultType.index()); |
72 | erasedResultTypes.push_back(memrefType); |
73 | } |
74 | } |
75 | |
76 | // Add the new arguments to the function type. |
77 | auto newArgTypes = llvm::to_vector<6>( |
78 | llvm::concat<const Type>(functionType.getInputs(), erasedResultTypes)); |
79 | auto newFunctionType = FunctionType::get(func.getContext(), newArgTypes, |
80 | functionType.getResults()); |
81 | func.setType(newFunctionType); |
82 | |
83 | // Transfer the result attributes to arg attributes. |
84 | auto erasedIndicesIt = erasedResultIndices.set_bits_begin(); |
85 | for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) { |
86 | func.setArgAttrs(functionType.getNumInputs() + i, |
87 | func.getResultAttrs(*erasedIndicesIt)); |
88 | if (addResultAttribute) |
89 | func.setArgAttr(functionType.getNumInputs() + i, |
90 | StringAttr::get(func.getContext(), "bufferize.result" ), |
91 | UnitAttr::get(func.getContext())); |
92 | } |
93 | |
94 | // Erase the results. |
95 | if (failed(func.eraseResults(erasedResultIndices))) |
96 | return failure(); |
97 | |
98 | // Add the new arguments to the entry block if the function is not external. |
99 | if (func.isExternal()) |
100 | return success(); |
101 | Location loc = func.getLoc(); |
102 | for (Type type : erasedResultTypes) |
103 | appendedEntryArgs.push_back(Elt: func.front().addArgument(type, loc)); |
104 | |
105 | return success(); |
106 | } |
107 | |
108 | // Updates all ReturnOps in the scope of the given func::FuncOp by either |
109 | // keeping them as return values or copying the associated buffer contents into |
110 | // the given out-params. |
111 | static LogicalResult |
112 | updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs, |
113 | const bufferization::BufferResultsToOutParamsOpts &options) { |
114 | auto res = func.walk([&](func::ReturnOp op) { |
115 | SmallVector<Value, 6> copyIntoOutParams; |
116 | SmallVector<Value, 6> keepAsReturnOperands; |
117 | for (Value operand : op.getOperands()) { |
118 | if (isa<MemRefType>(operand.getType())) |
119 | copyIntoOutParams.push_back(operand); |
120 | else |
121 | keepAsReturnOperands.push_back(operand); |
122 | } |
123 | OpBuilder builder(op); |
124 | for (auto [orig, arg] : llvm::zip(t&: copyIntoOutParams, u&: appendedEntryArgs)) { |
125 | if (options.hoistStaticAllocs && |
126 | isa_and_nonnull<bufferization::AllocationOpInterface>( |
127 | orig.getDefiningOp()) && |
128 | mlir::cast<MemRefType>(orig.getType()).hasStaticShape()) { |
129 | orig.replaceAllUsesWith(newValue: arg); |
130 | orig.getDefiningOp()->erase(); |
131 | } else { |
132 | if (failed(options.memCpyFn(builder, op.getLoc(), orig, arg))) |
133 | return WalkResult::interrupt(); |
134 | } |
135 | } |
136 | builder.create<func::ReturnOp>(op.getLoc(), keepAsReturnOperands); |
137 | op.erase(); |
138 | return WalkResult::advance(); |
139 | }); |
140 | return failure(res.wasInterrupted()); |
141 | } |
142 | |
143 | // Updates all CallOps in the scope of the given ModuleOp by allocating |
144 | // temporary buffers for newly introduced out params. |
145 | static LogicalResult |
146 | updateCalls(ModuleOp module, |
147 | const bufferization::BufferResultsToOutParamsOpts &options) { |
148 | bool didFail = false; |
149 | SymbolTable symtab(module); |
150 | module.walk([&](func::CallOp op) { |
151 | auto callee = symtab.lookup<func::FuncOp>(op.getCallee()); |
152 | if (!callee) { |
153 | op.emitError() << "cannot find callee '" << op.getCallee() << "' in " |
154 | << "symbol table" ; |
155 | didFail = true; |
156 | return; |
157 | } |
158 | if (!options.filterFn(&callee)) |
159 | return; |
160 | SmallVector<Value, 6> replaceWithNewCallResults; |
161 | SmallVector<Value, 6> replaceWithOutParams; |
162 | for (OpResult result : op.getResults()) { |
163 | if (isa<MemRefType>(result.getType())) |
164 | replaceWithOutParams.push_back(result); |
165 | else |
166 | replaceWithNewCallResults.push_back(result); |
167 | } |
168 | SmallVector<Value, 6> outParams; |
169 | OpBuilder builder(op); |
170 | for (Value memref : replaceWithOutParams) { |
171 | if (!cast<MemRefType>(memref.getType()).hasStaticShape()) { |
172 | op.emitError() |
173 | << "cannot create out param for dynamically shaped result" ; |
174 | didFail = true; |
175 | return; |
176 | } |
177 | auto memrefType = cast<MemRefType>(memref.getType()); |
178 | auto allocType = |
179 | MemRefType::get(memrefType.getShape(), memrefType.getElementType(), |
180 | AffineMap(), memrefType.getMemorySpace()); |
181 | auto maybeOutParam = |
182 | options.allocationFn(builder, op.getLoc(), allocType); |
183 | if (failed(maybeOutParam)) { |
184 | op.emitError() << "failed to create allocation op" ; |
185 | didFail = true; |
186 | return; |
187 | } |
188 | Value outParam = maybeOutParam.value(); |
189 | if (!hasStaticIdentityLayout(memrefType)) { |
190 | // Layout maps are already checked in `updateFuncOp`. |
191 | assert(hasFullyDynamicLayoutMap(memrefType) && |
192 | "layout map not supported" ); |
193 | outParam = |
194 | builder.create<memref::CastOp>(op.getLoc(), memrefType, outParam); |
195 | } |
196 | memref.replaceAllUsesWith(newValue: outParam); |
197 | outParams.push_back(Elt: outParam); |
198 | } |
199 | |
200 | auto newOperands = llvm::to_vector<6>(op.getOperands()); |
201 | newOperands.append(outParams.begin(), outParams.end()); |
202 | auto newResultTypes = llvm::to_vector<6>(Range: llvm::map_range( |
203 | C&: replaceWithNewCallResults, F: [](Value v) { return v.getType(); })); |
204 | auto newCall = builder.create<func::CallOp>(op.getLoc(), op.getCalleeAttr(), |
205 | newResultTypes, newOperands); |
206 | for (auto t : llvm::zip(replaceWithNewCallResults, newCall.getResults())) |
207 | std::get<0>(t).replaceAllUsesWith(std::get<1>(t)); |
208 | op.erase(); |
209 | }); |
210 | |
211 | return failure(IsFailure: didFail); |
212 | } |
213 | |
214 | LogicalResult mlir::bufferization::promoteBufferResultsToOutParams( |
215 | ModuleOp module, |
216 | const bufferization::BufferResultsToOutParamsOpts &options) { |
217 | for (auto func : module.getOps<func::FuncOp>()) { |
218 | if (!options.filterFn(&func)) |
219 | continue; |
220 | SmallVector<BlockArgument, 6> appendedEntryArgs; |
221 | if (failed( |
222 | updateFuncOp(func, appendedEntryArgs, options.addResultAttribute))) |
223 | return failure(); |
224 | if (func.isExternal()) |
225 | continue; |
226 | if (failed(updateReturnOps(func, appendedEntryArgs, options))) { |
227 | return failure(); |
228 | } |
229 | } |
230 | if (failed(updateCalls(module, options))) |
231 | return failure(); |
232 | return success(); |
233 | } |
234 | |
235 | namespace { |
236 | struct BufferResultsToOutParamsPass |
237 | : bufferization::impl::BufferResultsToOutParamsPassBase< |
238 | BufferResultsToOutParamsPass> { |
239 | using Base::Base; |
240 | |
241 | void runOnOperation() override { |
242 | // Convert from pass options in tablegen to BufferResultsToOutParamsOpts. |
243 | if (addResultAttribute) |
244 | options.addResultAttribute = true; |
245 | if (hoistStaticAllocs) |
246 | options.hoistStaticAllocs = true; |
247 | |
248 | if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(), |
249 | options))) |
250 | return signalPassFailure(); |
251 | } |
252 | |
253 | private: |
254 | bufferization::BufferResultsToOutParamsOpts options; |
255 | }; |
256 | } // namespace |
257 | |