1 | //===- BufferResultsToOutParams.cpp - Calling convention conversion -------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Bufferization/Transforms/Passes.h" |
10 | |
11 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
12 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
13 | #include "mlir/IR/Operation.h" |
14 | #include "mlir/Pass/Pass.h" |
15 | |
16 | namespace mlir { |
17 | namespace bufferization { |
18 | #define GEN_PASS_DEF_BUFFERRESULTSTOOUTPARAMS |
19 | #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc" |
20 | } // namespace bufferization |
21 | } // namespace mlir |
22 | |
23 | using namespace mlir; |
24 | using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn; |
25 | |
26 | /// Return `true` if the given MemRef type has a fully dynamic layout. |
27 | static bool hasFullyDynamicLayoutMap(MemRefType type) { |
28 | int64_t offset; |
29 | SmallVector<int64_t, 4> strides; |
30 | if (failed(getStridesAndOffset(type, strides, offset))) |
31 | return false; |
32 | if (!llvm::all_of(strides, ShapedType::isDynamic)) |
33 | return false; |
34 | if (!ShapedType::isDynamic(offset)) |
35 | return false; |
36 | return true; |
37 | } |
38 | |
39 | /// Return `true` if the given MemRef type has a static identity layout (i.e., |
40 | /// no layout). |
41 | static bool hasStaticIdentityLayout(MemRefType type) { |
42 | return type.getLayout().isIdentity(); |
43 | } |
44 | |
45 | // Updates the func op and entry block. |
46 | // |
47 | // Any args appended to the entry block are added to `appendedEntryArgs`. |
48 | // If `addResultAttribute` is true, adds the unit attribute `bufferize.result` |
49 | // to each newly created function argument. |
50 | static LogicalResult |
51 | updateFuncOp(func::FuncOp func, |
52 | SmallVectorImpl<BlockArgument> &appendedEntryArgs, |
53 | bool addResultAttribute) { |
54 | auto functionType = func.getFunctionType(); |
55 | |
56 | // Collect information about the results will become appended arguments. |
57 | SmallVector<Type, 6> erasedResultTypes; |
58 | BitVector erasedResultIndices(functionType.getNumResults()); |
59 | for (const auto &resultType : llvm::enumerate(functionType.getResults())) { |
60 | if (auto memrefType = dyn_cast<MemRefType>(resultType.value())) { |
61 | if (!hasStaticIdentityLayout(memrefType) && |
62 | !hasFullyDynamicLayoutMap(memrefType)) { |
63 | // Only buffers with static identity layout can be allocated. These can |
64 | // be casted to memrefs with fully dynamic layout map. Other layout maps |
65 | // are not supported. |
66 | return func->emitError() |
67 | << "cannot create out param for result with unsupported layout" ; |
68 | } |
69 | erasedResultIndices.set(resultType.index()); |
70 | erasedResultTypes.push_back(memrefType); |
71 | } |
72 | } |
73 | |
74 | // Add the new arguments to the function type. |
75 | auto newArgTypes = llvm::to_vector<6>( |
76 | llvm::concat<const Type>(functionType.getInputs(), erasedResultTypes)); |
77 | auto newFunctionType = FunctionType::get(func.getContext(), newArgTypes, |
78 | functionType.getResults()); |
79 | func.setType(newFunctionType); |
80 | |
81 | // Transfer the result attributes to arg attributes. |
82 | auto erasedIndicesIt = erasedResultIndices.set_bits_begin(); |
83 | for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) { |
84 | func.setArgAttrs(functionType.getNumInputs() + i, |
85 | func.getResultAttrs(*erasedIndicesIt)); |
86 | if (addResultAttribute) |
87 | func.setArgAttr(functionType.getNumInputs() + i, |
88 | StringAttr::get(func.getContext(), "bufferize.result" ), |
89 | UnitAttr::get(func.getContext())); |
90 | } |
91 | |
92 | // Erase the results. |
93 | func.eraseResults(erasedResultIndices); |
94 | |
95 | // Add the new arguments to the entry block if the function is not external. |
96 | if (func.isExternal()) |
97 | return success(); |
98 | Location loc = func.getLoc(); |
99 | for (Type type : erasedResultTypes) |
100 | appendedEntryArgs.push_back(Elt: func.front().addArgument(type, loc)); |
101 | |
102 | return success(); |
103 | } |
104 | |
105 | // Updates all ReturnOps in the scope of the given func::FuncOp by either |
106 | // keeping them as return values or copying the associated buffer contents into |
107 | // the given out-params. |
108 | static LogicalResult updateReturnOps(func::FuncOp func, |
109 | ArrayRef<BlockArgument> appendedEntryArgs, |
110 | MemCpyFn memCpyFn) { |
111 | auto res = func.walk([&](func::ReturnOp op) { |
112 | SmallVector<Value, 6> copyIntoOutParams; |
113 | SmallVector<Value, 6> keepAsReturnOperands; |
114 | for (Value operand : op.getOperands()) { |
115 | if (isa<MemRefType>(operand.getType())) |
116 | copyIntoOutParams.push_back(operand); |
117 | else |
118 | keepAsReturnOperands.push_back(operand); |
119 | } |
120 | OpBuilder builder(op); |
121 | for (auto t : llvm::zip(t&: copyIntoOutParams, u&: appendedEntryArgs)) { |
122 | if (failed( |
123 | memCpyFn(builder, op.getLoc(), std::get<0>(t&: t), std::get<1>(t&: t)))) |
124 | return WalkResult::interrupt(); |
125 | } |
126 | builder.create<func::ReturnOp>(op.getLoc(), keepAsReturnOperands); |
127 | op.erase(); |
128 | return WalkResult::advance(); |
129 | }); |
130 | return failure(res.wasInterrupted()); |
131 | } |
132 | |
133 | // Updates all CallOps in the scope of the given ModuleOp by allocating |
134 | // temporary buffers for newly introduced out params. |
135 | static LogicalResult |
136 | updateCalls(ModuleOp module, |
137 | const bufferization::BufferResultsToOutParamsOpts &options) { |
138 | bool didFail = false; |
139 | SymbolTable symtab(module); |
140 | module.walk([&](func::CallOp op) { |
141 | auto callee = symtab.lookup<func::FuncOp>(op.getCallee()); |
142 | if (!callee) { |
143 | op.emitError() << "cannot find callee '" << op.getCallee() << "' in " |
144 | << "symbol table" ; |
145 | didFail = true; |
146 | return; |
147 | } |
148 | if (!options.filterFn(&callee)) |
149 | return; |
150 | SmallVector<Value, 6> replaceWithNewCallResults; |
151 | SmallVector<Value, 6> replaceWithOutParams; |
152 | for (OpResult result : op.getResults()) { |
153 | if (isa<MemRefType>(result.getType())) |
154 | replaceWithOutParams.push_back(result); |
155 | else |
156 | replaceWithNewCallResults.push_back(result); |
157 | } |
158 | SmallVector<Value, 6> outParams; |
159 | OpBuilder builder(op); |
160 | for (Value memref : replaceWithOutParams) { |
161 | if (!cast<MemRefType>(memref.getType()).hasStaticShape()) { |
162 | op.emitError() |
163 | << "cannot create out param for dynamically shaped result" ; |
164 | didFail = true; |
165 | return; |
166 | } |
167 | auto memrefType = cast<MemRefType>(memref.getType()); |
168 | auto allocType = |
169 | MemRefType::get(memrefType.getShape(), memrefType.getElementType(), |
170 | AffineMap(), memrefType.getMemorySpace()); |
171 | Value outParam = builder.create<memref::AllocOp>(op.getLoc(), allocType); |
172 | if (!hasStaticIdentityLayout(memrefType)) { |
173 | // Layout maps are already checked in `updateFuncOp`. |
174 | assert(hasFullyDynamicLayoutMap(memrefType) && |
175 | "layout map not supported" ); |
176 | outParam = |
177 | builder.create<memref::CastOp>(op.getLoc(), memrefType, outParam); |
178 | } |
179 | memref.replaceAllUsesWith(newValue: outParam); |
180 | outParams.push_back(Elt: outParam); |
181 | } |
182 | |
183 | auto newOperands = llvm::to_vector<6>(op.getOperands()); |
184 | newOperands.append(outParams.begin(), outParams.end()); |
185 | auto newResultTypes = llvm::to_vector<6>(Range: llvm::map_range( |
186 | C&: replaceWithNewCallResults, F: [](Value v) { return v.getType(); })); |
187 | auto newCall = builder.create<func::CallOp>(op.getLoc(), op.getCalleeAttr(), |
188 | newResultTypes, newOperands); |
189 | for (auto t : llvm::zip(replaceWithNewCallResults, newCall.getResults())) |
190 | std::get<0>(t).replaceAllUsesWith(std::get<1>(t)); |
191 | op.erase(); |
192 | }); |
193 | |
194 | return failure(isFailure: didFail); |
195 | } |
196 | |
197 | LogicalResult mlir::bufferization::promoteBufferResultsToOutParams( |
198 | ModuleOp module, |
199 | const bufferization::BufferResultsToOutParamsOpts &options) { |
200 | for (auto func : module.getOps<func::FuncOp>()) { |
201 | if (!options.filterFn(&func)) |
202 | continue; |
203 | SmallVector<BlockArgument, 6> appendedEntryArgs; |
204 | if (failed( |
205 | updateFuncOp(func, appendedEntryArgs, options.addResultAttribute))) |
206 | return failure(); |
207 | if (func.isExternal()) |
208 | continue; |
209 | auto defaultMemCpyFn = [](OpBuilder &builder, Location loc, Value from, |
210 | Value to) { |
211 | builder.create<memref::CopyOp>(loc, from, to); |
212 | return success(); |
213 | }; |
214 | if (failed(updateReturnOps(func, appendedEntryArgs, |
215 | options.memCpyFn.value_or(defaultMemCpyFn)))) { |
216 | return failure(); |
217 | } |
218 | } |
219 | if (failed(updateCalls(module, options))) |
220 | return failure(); |
221 | return success(); |
222 | } |
223 | |
224 | namespace { |
225 | struct BufferResultsToOutParamsPass |
226 | : bufferization::impl::BufferResultsToOutParamsBase< |
227 | BufferResultsToOutParamsPass> { |
228 | explicit BufferResultsToOutParamsPass( |
229 | const bufferization::BufferResultsToOutParamsOpts &options) |
230 | : options(options) {} |
231 | |
232 | void runOnOperation() override { |
233 | // Convert from pass options in tablegen to BufferResultsToOutParamsOpts. |
234 | if (addResultAttribute) |
235 | options.addResultAttribute = true; |
236 | |
237 | if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(), |
238 | options))) |
239 | return signalPassFailure(); |
240 | } |
241 | |
242 | private: |
243 | bufferization::BufferResultsToOutParamsOpts options; |
244 | }; |
245 | } // namespace |
246 | |
247 | std::unique_ptr<Pass> mlir::bufferization::createBufferResultsToOutParamsPass( |
248 | const bufferization::BufferResultsToOutParamsOpts &options) { |
249 | return std::make_unique<BufferResultsToOutParamsPass>(args: options); |
250 | } |
251 | |