//===- OneShotModuleBufferize.cpp - Bufferization across Func. Boundaries ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Module Bufferization is an extension of One-Shot Bufferize that
// bufferizes function boundaries. It provides `BufferizableOpInterface`
// implementations for FuncOp, CallOp and ReturnOp.
//
// Module Bufferization is run via `runOneShotModuleBufferize(ModuleOp, ...)`.
// This function analyzes the given module and determines the order of analysis
// and bufferization: Functions that are called are processed before their
// respective callers.
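//
// A minimal usage sketch from within a pass (the variable names are
// illustrative and `statistics` is optional diagnostics output):
//
// ```
// OneShotBufferizationOptions options;
// options.bufferizeFunctionBoundaries = true;
// BufferizationState bufferizationState;
// if (failed(runOneShotModuleBufferize(moduleOp, options, bufferizationState,
//                                      /*statistics=*/nullptr)))
//   signalPassFailure();
// ```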
//
// After analyzing a FuncOp, additional information about its bbArgs is
// gathered and stored in `FuncAnalysisState`.
//
// * `aliasingFuncOpBBArgsAnalysis` determines the equivalent/aliasing bbArgs
//   for each tensor return value (if any).
// * `funcOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is
//   read/written.
//
// Module Bufferization implements the following calling convention.
//
// * In the absence of conflicts within a FuncOp, the FuncOp's bbArgs may
//   always be written to in-place.
// * If a tensor operand of a CallOp is read after the CallOp, the operand of
//   the CallOp must bufferize out-of-place.
//
// Example: The tensor.insert op bufferizes in-place because it is allowed to
// modify the buffer of `%t1` directly. The CallOp in `caller` must bufferize
// out-of-place because `%t0` is modified by the callee but read by the
// tensor.extract op. The analysis of CallOps decides whether an OpOperand must
// bufferize out-of-place based on results of `funcOpBbArgReadWriteAnalysis`.
// ```
// func @callee(%t1 : tensor<?xf32>) -> tensor<?xf32> {
//   %f = ... : f32
//   %0 = tensor.insert %f into %t1[...] : tensor<?xf32>
//   return %0 : tensor<?xf32>
// }
//
// func @caller() -> () {
//   %t0 = ... : tensor<?xf32>
//   %1 = call @callee(%t0) : (tensor<?xf32>) -> (tensor<?xf32>)
//   %2 = tensor.extract %1[...] : tensor<?xf32>
// }
// ```
//
// Note: If a function is external, `funcOpBbArgReadWriteAnalysis` cannot
// analyze the function body. In such a case, the CallOp analysis conservatively
// assumes that each tensor OpOperand is both read and written.
//
// Note: This conservative assumption can be overridden by annotating bbArgs
// with the buffer access attribute (`BufferizationDialect::
// kBufferAccessAttrName`, one of "read", "write", "read-write" or "none");
// see `funcOpBbArgReadWriteAnalysis`.
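//
// For example, an external function whose tensor bbArg is written but never
// read could be declared as follows (a sketch; `@ext` is a hypothetical
// function):
// ```
// func private @ext(tensor<?xf32> {bufferization.access = "write"})
//     -> tensor<?xf32>
// ```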

#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::bufferization::func_ext;

/// A mapping of FuncOps to their callers.
using FuncCallerMap = DenseMap<func::FuncOp, DenseSet<Operation *>>;

/// Get or create FuncAnalysisState.
static FuncAnalysisState &
getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {
  auto *result = state.getExtension<FuncAnalysisState>();
  if (result)
    return *result;
  return state.addExtension<FuncAnalysisState>();
}

namespace {

/// Annotate IR with the results of the analysis. For testing purposes only.
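/// E.g., `__equivalent_func_args__ = [0, -1]` on a func.return op indicates
/// that the first return value is equivalent to bbArg 0 and that no equivalent
/// bbArg was found for the second return value (encoded as -1).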
static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
                                          BlockArgument bbArg) {
  const char *kEquivalentArgsAttr = "__equivalent_func_args__";
  Operation *op = returnVal.getOwner();

  SmallVector<int64_t> equivBbArgs;
  if (op->hasAttr(kEquivalentArgsAttr)) {
    auto attr = cast<ArrayAttr>(op->getAttr(kEquivalentArgsAttr));
    equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) {
      return cast<IntegerAttr>(a).getValue().getSExtValue();
    }));
  } else {
    equivBbArgs.append(op->getNumOperands(), -1);
  }
  equivBbArgs[returnVal.getOperandNumber()] = bbArg.getArgNumber();

  OpBuilder b(op->getContext());
  op->setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs));
}

/// Store function BlockArguments that are equivalent to/aliasing a returned
/// value in FuncAnalysisState.
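/// E.g., for `func @f(%t: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>)`
/// that returns `%t` twice, both return values alias (and are equivalent to)
/// bbArg 0.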
static LogicalResult
aliasingFuncOpBBArgsAnalysis(func::FuncOp funcOp, OneShotAnalysisState &state,
                             FuncAnalysisState &funcState) {
  if (funcOp.getBody().empty()) {
    // No function body available. Conservatively assume that every tensor
    // return value may alias with any tensor bbArg.
    FunctionType type = funcOp.getFunctionType();
    for (const auto &inputIt : llvm::enumerate(type.getInputs())) {
      if (!isa<TensorType>(inputIt.value()))
        continue;
      for (const auto &resultIt : llvm::enumerate(type.getResults())) {
        if (!isa<TensorType>(resultIt.value()))
          continue;
        int64_t returnIdx = resultIt.index();
        int64_t bbArgIdx = inputIt.index();
        funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx);
      }
    }
    return success();
  }

  // Find all func.return ops.
  SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
  assert(!returnOps.empty() && "expected at least one ReturnOp");

  // Build alias sets. Merge all aliases from all func.return ops.
  for (BlockArgument bbArg : funcOp.getArguments()) {
    if (isa<RankedTensorType>(bbArg.getType())) {
      int64_t bbArgIdx = bbArg.getArgNumber();
      // Store aliases in a set, so that we don't add the same alias twice.
      SetVector<int64_t> aliases;
      for (func::ReturnOp returnOp : returnOps) {
        for (OpOperand &returnVal : returnOp->getOpOperands()) {
          if (isa<RankedTensorType>(returnVal.get().getType())) {
            int64_t returnIdx = returnVal.getOperandNumber();
            if (state.areAliasingBufferizedValues(returnVal.get(), bbArg))
              aliases.insert(returnIdx);
          }
        }
      }
      for (int64_t alias : aliases)
        funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(alias);
    }
  }

  // Build equivalence sets.
  // Helper function that finds an equivalent block argument index for the
  // given OpOperand. Return std::nullopt if no equivalent block argument could
  // be found.
  auto findEquivalentBlockArgIdx =
      [&](OpOperand &opOperand) -> std::optional<int64_t> {
    Value v = opOperand.get();
    if (!isa<TensorType>(v.getType()))
      return std::nullopt;
    for (BlockArgument bbArg : funcOp.getArguments()) {
      if (isa<RankedTensorType>(bbArg.getType())) {
        if (state.areEquivalentBufferizedValues(v, bbArg)) {
          if (state.getOptions().testAnalysisOnly)
            annotateEquivalentReturnBbArg(opOperand, bbArg);
          return bbArg.getArgNumber();
        }
      }
    }
    return std::nullopt;
  };

  int64_t numResults = returnOps.front()->getNumOperands();
  for (int64_t i = 0; i < numResults; ++i) {
    // Find the equivalent block argument index for the i-th operand of the
    // first func.return op.
    std::optional<int64_t> maybeEquiv =
        findEquivalentBlockArgIdx(returnOps.front()->getOpOperand(i));
    if (!maybeEquiv.has_value())
      continue;
    int64_t bbArgIdx = *maybeEquiv;
    bool allEquiv = true;

    // Check if all other func.return ops have the same equivalent block
    // argument for the i-th operand. In contrast to aliasing information,
    // which is just "merged", equivalence information must match across all
    // func.return ops.
    for (func::ReturnOp returnOp : ArrayRef(returnOps).drop_front()) {
      std::optional<int64_t> maybeEquiv =
          findEquivalentBlockArgIdx(returnOp->getOpOperand(i));
      if (maybeEquiv != bbArgIdx) {
        allEquiv = false;
        break;
      }
    }

    // All func.return ops have the same equivalent block argument for the i-th
    // operand.
    if (allEquiv)
      funcState.equivalentFuncArgs[funcOp][i] = bbArgIdx;
  }

  return success();
}

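/// Annotate the given FuncOp argument with the inferred access type ("read",
/// "write", "read-write" or "none") as the `kBufferAccessAttrName` arg
/// attribute. For testing purposes only.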
static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead,
                                  bool isWritten) {
  OpBuilder b(funcOp.getContext());
  Attribute accessType;
  if (isRead && isWritten) {
    accessType = b.getStringAttr("read-write");
  } else if (isRead) {
    accessType = b.getStringAttr("read");
  } else if (isWritten) {
    accessType = b.getStringAttr("write");
  } else {
    accessType = b.getStringAttr("none");
  }
  funcOp.setArgAttr(idx, BufferizationDialect::kBufferAccessAttrName,
                    accessType);
}

/// Determine which FuncOp bbArgs are read and which are written. When run on a
/// function with unknown ops, we conservatively assume that such ops bufferize
/// to a read + write.
static LogicalResult
funcOpBbArgReadWriteAnalysis(func::FuncOp funcOp, OneShotAnalysisState &state,
                             FuncAnalysisState &funcState) {
  for (int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e;
       ++idx) {
    // Skip non-tensor arguments.
    if (!isa<TensorType>(funcOp.getFunctionType().getInput(idx)))
      continue;
    bool isRead;
    bool isWritten;
    if (auto accessAttr = funcOp.getArgAttrOfType<StringAttr>(
            idx, BufferizationDialect::kBufferAccessAttrName)) {
      // Buffer access behavior is specified on the function. Skip the
      // analysis.
      StringRef str = accessAttr.getValue();
      isRead = str == "read" || str == "read-write";
      isWritten = str == "write" || str == "read-write";
    } else if (funcOp.getBody().empty()) {
      // If the function has no body, conservatively assume that all args are
      // read + written.
      isRead = true;
      isWritten = true;
    } else {
      // Analyze the body of the function.
      BlockArgument bbArg = funcOp.getArgument(idx);
      isRead = state.isValueRead(bbArg);
      isWritten = state.isValueWritten(bbArg);
    }

    if (state.getOptions().testAnalysisOnly)
      annotateFuncArgAccess(funcOp, idx, isRead, isWritten);
    if (isRead)
      funcState.readBbArgs[funcOp].insert(idx);
    if (isWritten)
      funcState.writtenBbArgs[funcOp].insert(idx);
  }

  return success();
}
} // namespace

/// Remove bufferization attributes on FuncOp arguments.
static void removeBufferizationAttributes(BlockArgument bbArg) {
  auto funcOp = cast<func::FuncOp>(bbArg.getOwner()->getParentOp());
  funcOp.removeArgAttr(bbArg.getArgNumber(),
                       BufferizationDialect::kBufferLayoutAttrName);
  funcOp.removeArgAttr(bbArg.getArgNumber(),
                       BufferizationDialect::kWritableAttrName);
}

/// Return the func::FuncOp called by `callOp`.
static func::FuncOp
getCalledFunction(func::CallOp callOp,
                  mlir::SymbolTableCollection &symbolTable) {
  SymbolRefAttr sym =
      llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
  if (!sym)
    return nullptr;
  return dyn_cast_or_null<func::FuncOp>(
      symbolTable.lookupNearestSymbolFrom(callOp, sym));
}

/// Return "true" if the given function signature has tensor semantics.
static bool hasTensorSignature(func::FuncOp funcOp) {
  return llvm::any_of(funcOp.getFunctionType().getInputs(),
                      llvm::IsaPred<TensorType>) ||
         llvm::any_of(funcOp.getFunctionType().getResults(),
                      llvm::IsaPred<TensorType>);
}

/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
/// callee-caller order (i.e., functions that do not call any other function
/// come first). Store all remaining functions (i.e., the ones that call each
/// other recursively) in `remainingFuncOps`. Does not traverse nested symbol
/// tables.
///
/// Store the map of FuncOp to all its callers in `callerMap`.
///
/// Return `failure()` if we are unable to retrieve the called FuncOp from
/// any func::CallOp.
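///
/// E.g., for a module where `@a` calls `@b` and where `@c` and `@d` call each
/// other (all with tensor signatures), `orderedFuncOps` is `[@b, @a]` and
/// `remainingFuncOps` contains `@c` and `@d` in an unspecified order.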
static LogicalResult getFuncOpsOrderedByCalls(
    ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
    SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap,
    SymbolTableCollection &symbolTables) {
  // For each FuncOp, the set of functions called by it (i.e. the union of
  // symbols of all nested func::CallOp).
  DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
  // For each FuncOp, the number of func::CallOp it contains.
  DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;

  for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
    // Collect function calls and populate the caller map.
    numberCallOpsContainedInFuncOp[funcOp] = 0;
    WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
      func::FuncOp calledFunction = getCalledFunction(callOp, symbolTables);
      assert(calledFunction && "could not retrieve called func::FuncOp");
      // If the called function does not have any tensors in its signature,
      // then it is not necessary to bufferize the callee before the caller.
      if (!hasTensorSignature(calledFunction))
        return WalkResult::skip();

      callerMap[calledFunction].insert(callOp);
      if (calledBy[calledFunction].insert(funcOp).second) {
        numberCallOpsContainedInFuncOp[funcOp]++;
      }
      return WalkResult::advance();
    });
    if (res.wasInterrupted())
      return failure();
  }

  // Iteratively remove function operations that do not call any of the
  // functions remaining in the callCounter map and add them to the ordered
  // list.
  SmallVector<func::FuncOp> worklist;

  for (const auto &entry : numberCallOpsContainedInFuncOp) {
    if (entry.second == 0)
      worklist.push_back(entry.first);
  }

  while (!worklist.empty()) {
    func::FuncOp func = worklist.pop_back_val();
    orderedFuncOps.push_back(func);

    for (func::FuncOp caller : calledBy[func]) {
      auto &count = numberCallOpsContainedInFuncOp[caller];

      if (--count == 0)
        worklist.push_back(caller);
    }

    numberCallOpsContainedInFuncOp.erase(func);
  }

  // Put all other functions into the list of remaining functions. These are
  // functions that call each other circularly.
  for (auto it : numberCallOpsContainedInFuncOp)
    remainingFuncOps.push_back(it.first);

  return success();
}

/// Helper function that extracts the source from a memref.cast. If the given
/// value is not a memref.cast result, simply returns the given value.
static Value unpackCast(Value v) {
  auto castOp = v.getDefiningOp<memref::CastOp>();
  if (!castOp)
    return v;
  return castOp.getSource();
}

/// Helper function that returns the return types (skipping casts) of the given
/// func.return ops. This function returns as many types as the return ops have
/// operands. If the type of the i-th operand (ignoring casts) is not the same
/// for all func.return ops, then the i-th returned type is an "empty" type.
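///
/// E.g., if one func.return op returns `%0 = memref.cast %m : memref<?xf32>
/// to memref<5xf32>` and another one returns `%m` directly, the common
/// cast-free type `memref<?xf32>` is used for that position; if the cast-free
/// types differ, an empty `Type()` is used instead.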
static SmallVector<Type> getReturnTypes(SmallVector<func::ReturnOp> returnOps) {
  assert(!returnOps.empty() && "expected at least one ReturnOp");
  int numOperands = returnOps.front()->getNumOperands();

  // Helper function that unpacks memref.cast ops and returns the type.
  auto getSourceType = [&](Value v) { return unpackCast(v).getType(); };

  SmallVector<Type> result;
  for (int i = 0; i < numOperands; ++i) {
    // Get the type of the i-th operand of the first func.return op.
    Type t = getSourceType(returnOps.front()->getOperand(i));

    // Check if all other func.return ops have a matching operand type.
    for (int j = 1; j < static_cast<int>(returnOps.size()); ++j)
      if (getSourceType(returnOps[j]->getOperand(i)) != t)
        t = Type();

    result.push_back(t);
  }

  return result;
}

/// Fold return values that are memref casts and update function return types.
///
/// During FuncOp bufferization, the exact type of the returned memrefs (if
/// any) is not known yet. Therefore, the bufferization uses memref types with
/// the most generic layout map as function return types. After bufferizing the
/// entire function body, a more concise memref type can potentially be used
/// for the return type of the function.
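///
/// E.g., a `return %0` where `%0 = memref.cast %alloc : memref<5xf32> to
/// memref<5xf32, strided<[?], offset: ?>>` is rewritten to `return %alloc`,
/// and the corresponding function result type is tightened to
/// `memref<5xf32>`.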
static void foldMemRefCasts(func::FuncOp funcOp) {
  // There is nothing to do for bodiless ops.
  if (funcOp.getBody().empty())
    return;

  // Compute the common result types of all return ops.
  SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
  SmallVector<Type> resultTypes = getReturnTypes(returnOps);

  // Remove direct casts.
  for (func::ReturnOp returnOp : returnOps) {
    for (OpOperand &operand : returnOp->getOpOperands()) {
      // Only fold the cast if a common cast-free result type was found for
      // this operand position.
      if (resultTypes[operand.getOperandNumber()]) {
        operand.set(unpackCast(operand.get()));
      }
    }
  }

  // Fill in the missing result types that were not the same among all
  // func.return ops.
  for (int i = 0; i < static_cast<int>(resultTypes.size()); ++i) {
    if (resultTypes[i])
      continue;
    resultTypes[i] = funcOp.getFunctionType().getResult(i);
  }

  // Update the function type.
  auto newFuncType = FunctionType::get(
      funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
  funcOp.setType(newFuncType);
}

LogicalResult
mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
                                     OneShotAnalysisState &state,
                                     BufferizationStatistics *statistics) {
  assert(state.getOptions().bufferizeFunctionBoundaries &&
         "expected that function boundary bufferization is activated");
  FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);

  // A list of non-circular functions in the order in which they are analyzed
  // and bufferized.
  SmallVector<func::FuncOp> orderedFuncOps;
  // A list of all other functions, i.e., functions that call each other
  // recursively. For these, we analyze the function body but not the function
  // boundary.
  SmallVector<func::FuncOp> remainingFuncOps;

  // A mapping of FuncOps to their callers.
  FuncCallerMap callerMap;

  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
                                      remainingFuncOps, callerMap,
                                      funcState.symbolTables)))
    return failure();

  // Analyze functions in order, starting with functions that do not call any
  // other functions.
  for (func::FuncOp funcOp : orderedFuncOps) {
    if (!state.getOptions().isOpAllowed(funcOp))
      continue;

    // Mark the function analysis as in progress.
    funcState.startFunctionAnalysis(funcOp);

    // Analyze funcOp.
    if (failed(analyzeOp(funcOp, state, statistics)))
      return failure();

    // Run some extra function analyses.
    if (failed(aliasingFuncOpBBArgsAnalysis(funcOp, state, funcState)) ||
        failed(funcOpBbArgReadWriteAnalysis(funcOp, state, funcState)))
      return failure();

    // Mark the function as fully analyzed.
    funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;
  }

  // Analyze all other functions. All function boundary analyses are skipped.
  for (func::FuncOp funcOp : remainingFuncOps) {
    if (!state.getOptions().isOpAllowed(funcOp))
      continue;

    // Analyze funcOp.
    if (failed(analyzeOp(funcOp, state, statistics)))
      return failure();

    // TODO: We currently skip all function argument analyses for functions
    // that call each other circularly. These analyses do not support recursive
    // calls yet. The `BufferizableOpInterface` implementations of `func`
    // dialect ops return conservative results in the absence of analysis
    // information.
  }

  return success();
}

void mlir::bufferization::removeBufferizationAttributesInModule(
    ModuleOp moduleOp) {
  for (auto op : moduleOp.getOps<func::FuncOp>()) {
    for (BlockArgument bbArg : op.getArguments())
      removeBufferizationAttributes(bbArg);
  }
}

LogicalResult mlir::bufferization::bufferizeModuleOp(
    ModuleOp moduleOp, const OneShotBufferizationOptions &options,
    BufferizationState &state, BufferizationStatistics *statistics) {
  assert(options.bufferizeFunctionBoundaries &&
         "expected that function boundary bufferization is activated");
  IRRewriter rewriter(moduleOp.getContext());

  // A list of non-circular functions in the order in which they are analyzed
  // and bufferized.
  SmallVector<func::FuncOp> orderedFuncOps;
  // A list of all other functions, i.e., functions that call each other
  // recursively. For these, we analyze the function body but not the function
  // boundary.
  SmallVector<func::FuncOp> remainingFuncOps;

  // A mapping of FuncOps to their callers.
  FuncCallerMap callerMap;

  // Try to bufferize functions in calling order, i.e., first bufferize
  // functions that do not call other functions. This allows us to infer
  // accurate buffer types for function return values. Functions that call
  // each other recursively are bufferized in an unspecified order at the end;
  // for those, we may end up with unnecessarily "complex" (in terms of layout
  // map) buffer types.
  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
                                      remainingFuncOps, callerMap,
                                      state.getSymbolTables())))
    return failure();
  llvm::append_range(orderedFuncOps, remainingFuncOps);

  // Bufferize functions.
  for (func::FuncOp funcOp : orderedFuncOps) {
    // Note: It would be good to apply cleanups here, but we cannot do so, as
    // the alias information would be invalidated.

    if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName())) {
      // This function was not analyzed and RaW conflicts were not resolved.
      // Buffer copies must be inserted before every write.
      OneShotBufferizationOptions updatedOptions = options;
      updatedOptions.copyBeforeWrite = true;
      if (failed(bufferizeOp(funcOp, updatedOptions, state, statistics)))
        return failure();
    } else {
      if (failed(bufferizeOp(funcOp, options, state, statistics)))
        return failure();
    }

    // Change buffer return types to more precise layout maps.
    if (options.inferFunctionResultLayout)
      foldMemRefCasts(funcOp);
  }

  // Bufferize all other ops.
  for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
    // Functions were already bufferized.
    if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>())
      continue;
    if (failed(bufferizeOp(&op, options, state, statistics)))
      return failure();
  }

  // Post-pass cleanup of function argument attributes.
  removeBufferizationAttributesInModule(moduleOp);

  return success();
}

LogicalResult mlir::bufferization::runOneShotModuleBufferize(
    ModuleOp moduleOp, const OneShotBufferizationOptions &options,
    BufferizationState &state, BufferizationStatistics *statistics) {
  assert(options.bufferizeFunctionBoundaries &&
         "expected that function boundary bufferization is activated");
  assert(!(options.copyBeforeWrite && options.testAnalysisOnly) &&
         "invalid combination of bufferization flags");
  if (!options.copyBeforeWrite) {
    if (options.noAnalysisFuncFilter.empty()) {
      if (failed(insertTensorCopies(moduleOp, options, state, statistics)))
        return failure();
    } else {
      // FuncOps whose names are specified in options.noAnalysisFuncFilter will
      // not be analyzed. Ops in these FuncOps will not be analyzed either.
      OpFilter::Entry::FilterFn analysisFilterFn = [=](Operation *op) {
        auto func = dyn_cast<func::FuncOp>(op);
        if (!func)
          func = op->getParentOfType<func::FuncOp>();
        if (func)
          return llvm::is_contained(options.noAnalysisFuncFilter,
                                    func.getSymName());
        return false;
      };
      OneShotBufferizationOptions updatedOptions(options);
      updatedOptions.opFilter.denyOperation(analysisFilterFn);
      if (failed(
              insertTensorCopies(moduleOp, updatedOptions, state, statistics)))
        return failure();
    }
  }
  if (options.testAnalysisOnly)
    return success();
  if (failed(bufferizeModuleOp(moduleOp, options, state, statistics)))
    return failure();
  return success();
}