//===- ModuleBufferization.cpp - Bufferization across Func. Boundaries ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Module Bufferization is an extension of One-Shot Bufferize that
// bufferizes function boundaries. It provides `BufferizableOpInterface`
// implementations for FuncOp, CallOp and ReturnOp.
//
// Module Bufferization is run via `runOneShotModuleBufferize(ModuleOp, ...)`.
// This function analyzes the given module and determines the order of analysis
// and bufferization: Functions that are called are processed before their
// respective callers.
//
// After analyzing a FuncOp, additional information about its bbArgs is
// gathered and stored in `FuncAnalysisState`.
//
// * `aliasingFuncOpBBArgsAnalysis` determines the equivalent/aliasing bbArgs
//   for each tensor return value (if any).
// * `funcOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is
//   read/written.
//
// Module Bufferization implements the following calling convention.
//
// * In the absence of conflicts within a FuncOp, the FuncOp's bbArgs may
//   always be written to in-place.
// * If a tensor operand of a CallOp is read after the CallOp, the operand of
//   the CallOp must bufferize out-of-place.
//
// Example: The tensor.insert op bufferizes in-place because it is allowed to
// modify the buffer of `%t1` directly. The CallOp in `caller` must bufferize
// out-of-place because `%t0` is modified by the callee but read by the
// tensor.extract op. The analysis of CallOps decides whether an OpOperand must
// bufferize out-of-place based on results of `funcOpBbArgReadWriteAnalysis`.
// ```
// func @callee(%t1 : tensor<?xf32>) -> tensor<?xf32> {
//   %f = ... : f32
//   %0 = tensor.insert %f into %t1[...] : tensor<?xf32>
//   return %0 : tensor<?xf32>
// }
//
// func @caller() -> () {
//   %t0 = ... : tensor<?xf32>
//   %1 = call @callee(%t0) : (tensor<?xf32>) -> (tensor<?xf32>)
//   %2 = tensor.extract %1[...] : tensor<?xf32>
// }
// ```
//
// Note: If a function is external, `funcOpBbArgReadWriteAnalysis` cannot
// analyze the function body. In such a case, the CallOp analysis conservatively
// assumes that each tensor OpOperand is both read and written.
//
// TODO: Add FuncOp attributes so that bbArgs of external FuncOps can be marked
// as "not reading" and/or "not writing".

#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::bufferization::func_ext;

/// A mapping of FuncOps to their callers.
using FuncCallerMap = DenseMap<func::FuncOp, DenseSet<Operation *>>;

/// Get or create FuncAnalysisState.
static FuncAnalysisState &
getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {
  auto *result = state.getExtension<FuncAnalysisState>();
  if (result)
    return *result;
  return state.addExtension<FuncAnalysisState>();
}

namespace {

/// Annotate IR with the results of the analysis. For testing purposes only.
static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
                                          BlockArgument bbArg) {
  const char *kEquivalentArgsAttr = "__equivalent_func_args__";
  Operation *op = returnVal.getOwner();

  SmallVector<int64_t> equivBbArgs;
  if (op->hasAttr(kEquivalentArgsAttr)) {
    auto attr = cast<ArrayAttr>(op->getAttr(kEquivalentArgsAttr));
    equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) {
      return cast<IntegerAttr>(a).getValue().getSExtValue();
    }));
  } else {
    equivBbArgs.append(op->getNumOperands(), -1);
  }
  equivBbArgs[returnVal.getOperandNumber()] = bbArg.getArgNumber();

  OpBuilder b(op->getContext());
  op->setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs));
}
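
// Illustrative example of the annotation produced above: with
// `testAnalysisOnly`, a func.return whose first operand was found equivalent
// to bbArg #0 and whose second operand has no equivalent bbArg would carry
// (hypothetical IR; unannotated operands are filled with -1):
// ```
// return {__equivalent_func_args__ = [0, -1]} %0, %1
//     : tensor<?xf32>, tensor<?xf32>
// ```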

/// Store function BlockArguments that are equivalent to/aliasing a returned
/// value in FuncAnalysisState.
static LogicalResult
aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
                             FuncAnalysisState &funcState) {
  if (funcOp.getBody().empty()) {
    // No function body available. Conservatively assume that every tensor
    // return value may alias with any tensor bbArg.
    FunctionType type = funcOp.getFunctionType();
    for (const auto &inputIt : llvm::enumerate(type.getInputs())) {
      if (!isa<TensorType>(inputIt.value()))
        continue;
      for (const auto &resultIt : llvm::enumerate(type.getResults())) {
        if (!isa<TensorType>(resultIt.value()))
          continue;
        int64_t returnIdx = resultIt.index();
        int64_t bbArgIdx = inputIt.index();
        funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx);
      }
    }
    return success();
  }

  // Find all func.return ops.
  SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
  assert(!returnOps.empty() && "expected at least one ReturnOp");

  // Build alias sets. Merge all aliases from all func.return ops.
  for (BlockArgument bbArg : funcOp.getArguments()) {
    if (isa<RankedTensorType>(bbArg.getType())) {
      int64_t bbArgIdx = bbArg.getArgNumber();
      // Store aliases in a set, so that we don't add the same alias twice.
      SetVector<int64_t> aliases;
      for (func::ReturnOp returnOp : returnOps) {
        for (OpOperand &returnVal : returnOp->getOpOperands()) {
          if (isa<RankedTensorType>(returnVal.get().getType())) {
            int64_t returnIdx = returnVal.getOperandNumber();
            if (state.areAliasingBufferizedValues(returnVal.get(), bbArg))
              aliases.insert(returnIdx);
          }
        }
      }
      for (int64_t alias : aliases)
        funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(alias);
    }
  }

  // Build equivalence sets.
  // Helper function that finds an equivalent block argument index for the
  // given OpOperand. Return std::nullopt if no equivalent block argument could
  // be found.
  auto findEquivalentBlockArgIdx =
      [&](OpOperand &opOperand) -> std::optional<int64_t> {
    Value v = opOperand.get();
    if (!isa<TensorType>(v.getType()))
      return std::nullopt;
    for (BlockArgument bbArg : funcOp.getArguments()) {
      if (isa<RankedTensorType>(bbArg.getType())) {
        if (state.areEquivalentBufferizedValues(v, bbArg)) {
          if (state.getOptions().testAnalysisOnly)
            annotateEquivalentReturnBbArg(opOperand, bbArg);
          return bbArg.getArgNumber();
        }
      }
    }
    return std::nullopt;
  };

  int64_t numResults = returnOps.front()->getNumOperands();
  for (int64_t i = 0; i < numResults; ++i) {
    // Find the equivalent block argument index for the i-th operand of the
    // first func.return op.
    std::optional<int64_t> maybeEquiv =
        findEquivalentBlockArgIdx(returnOps.front()->getOpOperand(i));
    if (!maybeEquiv.has_value())
      continue;
    int64_t bbArgIdx = *maybeEquiv;
    bool allEquiv = true;

    // Check if all other func.return ops have the same equivalent block
    // argument for the i-th operand. In contrast to aliasing information,
    // which is just "merged", equivalence information must match across all
    // func.return ops.
    for (func::ReturnOp returnOp : ArrayRef(returnOps).drop_front()) {
      std::optional<int64_t> maybeEquiv =
          findEquivalentBlockArgIdx(returnOp->getOpOperand(i));
      if (maybeEquiv != bbArgIdx) {
        allEquiv = false;
        break;
      }
    }

    // All func.return ops have the same equivalent block argument for the
    // i-th operand.
    if (allEquiv)
      funcState.equivalentFuncArgs[funcOp][i] = bbArgIdx;
  }

  return success();
}
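
// Illustrative example (hypothetical IR): aliasing information is merged
// across all func.return ops, but equivalence must agree on all of them. In
// the following function, the returned value aliases both tensor bbArgs but
// is equivalent to neither, so no entry is added to `equivalentFuncArgs`:
// ```
// func @f(%c : i1, %t0 : tensor<?xf32>, %t1 : tensor<?xf32>) -> tensor<?xf32> {
//   cf.cond_br %c, ^bb1, ^bb2
// ^bb1:
//   return %t0 : tensor<?xf32>
// ^bb2:
//   return %t1 : tensor<?xf32>
// }
// ```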

static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead,
                                  bool isWritten) {
  OpBuilder b(funcOp.getContext());
  Attribute accessType;
  if (isRead && isWritten) {
    accessType = b.getStringAttr("read-write");
  } else if (isRead) {
    accessType = b.getStringAttr("read");
  } else if (isWritten) {
    accessType = b.getStringAttr("write");
  } else {
    accessType = b.getStringAttr("none");
  }
  funcOp.setArgAttr(idx, BufferizationDialect::kBufferAccessAttrName,
                    accessType);
}

/// Determine which FuncOp bbArgs are read and which are written. When run on a
/// function with unknown ops, we conservatively assume that such ops bufferize
/// to a read + write.
static LogicalResult
funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
                             FuncAnalysisState &funcState) {
  for (int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e;
       ++idx) {
    // Skip non-tensor arguments.
    if (!isa<TensorType>(funcOp.getFunctionType().getInput(idx)))
      continue;
    bool isRead;
    bool isWritten;
    if (auto accessAttr = funcOp.getArgAttrOfType<StringAttr>(
            idx, BufferizationDialect::kBufferAccessAttrName)) {
      // Buffer access behavior is specified on the function. Skip the
      // analysis.
      StringRef str = accessAttr.getValue();
      isRead = str == "read" || str == "read-write";
      isWritten = str == "write" || str == "read-write";
    } else if (funcOp.getBody().empty()) {
      // If the function has no body, conservatively assume that all args are
      // read + written.
      isRead = true;
      isWritten = true;
    } else {
      // Analyze the body of the function.
      BlockArgument bbArg = funcOp.getArgument(idx);
      isRead = state.isValueRead(bbArg);
      isWritten = state.isValueWritten(bbArg);
    }

    if (state.getOptions().testAnalysisOnly)
      annotateFuncArgAccess(funcOp, idx, isRead, isWritten);
    if (isRead)
      funcState.readBbArgs[funcOp].insert(idx);
    if (isWritten)
      funcState.writtenBbArgs[funcOp].insert(idx);
  }

  return success();
}
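
// Illustrative example of the annotation produced by the analysis above: with
// `testAnalysisOnly`, the access kind is attached to each tensor argument
// (hypothetical IR):
// ```
// func @f(%t0 : tensor<?xf32> {bufferization.access = "read"},
//         %t1 : tensor<?xf32> {bufferization.access = "read-write"})
// ```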
} // namespace

/// Remove bufferization attributes on FuncOp arguments.
static void removeBufferizationAttributes(BlockArgument bbArg) {
  auto funcOp = cast<func::FuncOp>(bbArg.getOwner()->getParentOp());
  funcOp.removeArgAttr(bbArg.getArgNumber(),
                       BufferizationDialect::kBufferLayoutAttrName);
  funcOp.removeArgAttr(bbArg.getArgNumber(),
                       BufferizationDialect::kWritableAttrName);
}

/// Return the func::FuncOp called by `callOp`.
static func::FuncOp
getCalledFunction(func::CallOp callOp,
                  mlir::SymbolTableCollection &symbolTable) {
  SymbolRefAttr sym =
      llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
  if (!sym)
    return nullptr;
  return dyn_cast_or_null<func::FuncOp>(
      symbolTable.lookupNearestSymbolFrom(callOp, sym));
}

/// Return "true" if the given function signature has tensor semantics.
static bool hasTensorSignature(func::FuncOp funcOp) {
  return llvm::any_of(funcOp.getFunctionType().getInputs(),
                      llvm::IsaPred<TensorType>) ||
         llvm::any_of(funcOp.getFunctionType().getResults(),
                      llvm::IsaPred<TensorType>);
}

/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
/// callee-caller order (i.e., callees without callers first). Store all
/// remaining functions (i.e., the ones that call each other recursively) in
/// `remainingFuncOps`. Does not traverse nested symbol tables.
///
/// Store the map of FuncOp to all its callers in `callerMap`.
///
/// Return `failure()` if we are unable to retrieve the called FuncOp from
/// any func::CallOp.
static LogicalResult getFuncOpsOrderedByCalls(
    ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
    SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap,
    SymbolTableCollection &symbolTables) {
  // For each FuncOp, the set of functions called by it (i.e., the union of
  // symbols of all nested func::CallOp).
  DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
  // For each FuncOp, the number of distinct functions (with tensor
  // signatures) that it calls.
  DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;

  for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
    // Collect function calls and populate the caller map.
    numberCallOpsContainedInFuncOp[funcOp] = 0;
    WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
      func::FuncOp calledFunction = getCalledFunction(callOp, symbolTables);
      assert(calledFunction && "could not retrieve called func::FuncOp");
      // If the called function does not have any tensors in its signature,
      // then it is not necessary to bufferize the callee before the caller.
      if (!hasTensorSignature(calledFunction))
        return WalkResult::skip();

      callerMap[calledFunction].insert(callOp);
      if (calledBy[calledFunction].insert(funcOp).second) {
        numberCallOpsContainedInFuncOp[funcOp]++;
      }
      return WalkResult::advance();
    });
    if (res.wasInterrupted())
      return failure();
  }

  // Iteratively remove functions that do not call any of the functions
  // remaining in `numberCallOpsContainedInFuncOp` and append them to the
  // ordered list.
  SmallVector<func::FuncOp> worklist;

  for (const auto &entry : numberCallOpsContainedInFuncOp) {
    if (entry.second == 0)
      worklist.push_back(entry.first);
  }

  while (!worklist.empty()) {
    func::FuncOp func = worklist.pop_back_val();
    orderedFuncOps.push_back(func);

    for (func::FuncOp caller : calledBy[func]) {
      auto &count = numberCallOpsContainedInFuncOp[caller];

      if (--count == 0)
        worklist.push_back(caller);
    }

    numberCallOpsContainedInFuncOp.erase(func);
  }

  // Put all other functions in the list of remaining functions. These are
  // functions that call each other circularly.
  for (auto it : numberCallOpsContainedInFuncOp)
    remainingFuncOps.push_back(it.first);

  return success();
}
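
// Illustrative example (hypothetical module): if @a calls @b and @b calls @c,
// and all three have tensor signatures, `getFuncOpsOrderedByCalls` produces
// orderedFuncOps = [@c, @b, @a]. If, in addition, @d and @e call each other,
// both are placed in `remainingFuncOps` in unspecified order.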

/// Helper function that extracts the source from a memref.cast. If the given
/// value is not a memref.cast result, simply returns the given value.
static Value unpackCast(Value v) {
  auto castOp = v.getDefiningOp<memref::CastOp>();
  if (!castOp)
    return v;
  return castOp.getSource();
}

/// Helper function that returns the return types (skipping casts) of the given
/// func.return ops. This function returns as many types as the return ops have
/// operands. If the i-th operand is not the same for all func.return ops, then
/// the i-th returned type is an "empty" type.
static SmallVector<Type> getReturnTypes(SmallVector<func::ReturnOp> returnOps) {
  assert(!returnOps.empty() && "expected at least one ReturnOp");
  int numOperands = returnOps.front()->getNumOperands();

  // Helper function that unpacks memref.cast ops and returns the type.
  auto getSourceType = [&](Value v) { return unpackCast(v).getType(); };

  SmallVector<Type> result;
  for (int i = 0; i < numOperands; ++i) {
    // Get the type of the i-th operand of the first func.return op.
    Type t = getSourceType(returnOps.front()->getOperand(i));

    // Check if all other func.return ops have a matching operand type.
    for (int j = 1; j < static_cast<int>(returnOps.size()); ++j)
      if (getSourceType(returnOps[j]->getOperand(i)) != t)
        t = Type();

    result.push_back(t);
  }

  return result;
}
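
// Illustrative example: if one func.return returns %m via
// `memref.cast %m : memref<5xf32> to memref<?xf32>` and another func.return
// returns a value of type memref<?xf32> directly, the unpacked source types
// for that operand disagree (memref<5xf32> vs. memref<?xf32>), so the common
// type for that operand is the "empty" Type().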

/// Fold return values that are memref casts and update function return types.
///
/// During FuncOp bufferization, the exact type of the returned memrefs (if any)
/// is not known yet. Therefore, the bufferization uses memref types with the
/// most generic layout map as function return types. After bufferizing the
/// entire function body, a more concise memref type can potentially be used for
/// the return type of the function.
static void foldMemRefCasts(func::FuncOp funcOp) {
  // There is nothing to do for bodiless ops.
  if (funcOp.getBody().empty())
    return;

  // Compute the common result types of all return ops.
  SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
  SmallVector<Type> resultTypes = getReturnTypes(returnOps);

  // Fold direct casts: if a common result type was found for this operand,
  // replace the returned value with the cast's source.
  for (func::ReturnOp returnOp : returnOps) {
    for (OpOperand &operand : returnOp->getOpOperands()) {
      if (resultTypes[operand.getOperandNumber()]) {
        operand.set(unpackCast(operand.get()));
      }
    }
  }

  // Fill in the missing result types that were not the same among all
  // func.return ops.
  for (int i = 0; i < static_cast<int>(resultTypes.size()); ++i) {
    if (resultTypes[i])
      continue;
    resultTypes[i] = funcOp.getFunctionType().getResult(i);
  }

  // Update the function type.
  auto newFuncType = FunctionType::get(
      funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
  funcOp.setType(newFuncType);
}
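
// Illustrative example (hypothetical IR): a bufferized function such as
// ```
// func @f() -> memref<?xf32, strided<[?], offset: ?>> {
//   %m = memref.alloc() : memref<5xf32>
//   %0 = memref.cast %m : memref<5xf32> to memref<?xf32, strided<[?], offset: ?>>
//   return %0 : memref<?xf32, strided<[?], offset: ?>>
// }
// ```
// is rewritten to return %m directly, and the function type is tightened to
// () -> memref<5xf32>.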

LogicalResult
mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
                                     OneShotAnalysisState &state,
                                     BufferizationStatistics *statistics) {
  assert(state.getOptions().bufferizeFunctionBoundaries &&
         "expected that function boundary bufferization is activated");
  FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);

  // A list of non-circular functions in the order in which they are analyzed
  // and bufferized.
  SmallVector<func::FuncOp> orderedFuncOps;
  // A list of all other functions. I.e., functions that call each other
  // recursively. For these, we analyze the function body but not the function
  // boundary.
  SmallVector<func::FuncOp> remainingFuncOps;

  // A mapping of FuncOps to their callers.
  FuncCallerMap callerMap;

  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
                                      remainingFuncOps, callerMap,
                                      funcState.symbolTables)))
    return failure();

  // Analyze functions in order, starting with functions that do not call any
  // other functions.
  for (func::FuncOp funcOp : orderedFuncOps) {
    if (!state.getOptions().isOpAllowed(funcOp))
      continue;

    // Mark the function as "analysis in progress".
    funcState.startFunctionAnalysis(funcOp);

    // Analyze funcOp.
    if (failed(analyzeOp(funcOp, state, statistics)))
      return failure();

    // Run some extra function analyses.
    if (failed(aliasingFuncOpBBArgsAnalysis(funcOp, state, funcState)) ||
        failed(funcOpBbArgReadWriteAnalysis(funcOp, state, funcState)))
      return failure();

    // Mark the function as fully analyzed.
    funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;
  }

  // Analyze all other functions. All function boundary analyses are skipped.
  for (func::FuncOp funcOp : remainingFuncOps) {
    if (!state.getOptions().isOpAllowed(funcOp))
      continue;

    // Analyze funcOp.
    if (failed(analyzeOp(funcOp, state, statistics)))
      return failure();

    // TODO: We currently skip all function argument analyses for functions
    // that call each other circularly. These analyses do not support recursive
    // calls yet. The `BufferizableOpInterface` implementations of `func`
    // dialect ops return conservative results in the absence of analysis
    // information.
  }

  return success();
}

void mlir::bufferization::removeBufferizationAttributesInModule(
    ModuleOp moduleOp) {
  for (auto op : moduleOp.getOps<func::FuncOp>()) {
    for (BlockArgument bbArg : op.getArguments())
      removeBufferizationAttributes(bbArg);
  }
}

LogicalResult mlir::bufferization::bufferizeModuleOp(
    ModuleOp moduleOp, const OneShotBufferizationOptions &options,
    BufferizationState &state, BufferizationStatistics *statistics) {
  assert(options.bufferizeFunctionBoundaries &&
         "expected that function boundary bufferization is activated");
  IRRewriter rewriter(moduleOp.getContext());

  // A list of non-circular functions in the order in which they are analyzed
  // and bufferized.
  SmallVector<func::FuncOp> orderedFuncOps;
  // A list of all other functions. I.e., functions that call each other
  // recursively. For these, we analyze the function body but not the function
  // boundary.
  SmallVector<func::FuncOp> remainingFuncOps;

  // A mapping of FuncOps to their callers.
  FuncCallerMap callerMap;

  // Try to bufferize functions in calling order. I.e., first bufferize
  // functions that do not call other functions. This allows us to infer
  // accurate buffer types for function return values. Functions that call
  // each other recursively are bufferized in an unspecified order at the end.
  // We may use unnecessarily "complex" (in terms of layout map) buffer types.
  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
                                      remainingFuncOps, callerMap,
                                      state.getSymbolTables())))
    return failure();
  llvm::append_range(orderedFuncOps, remainingFuncOps);

  // Bufferize functions.
  for (func::FuncOp funcOp : orderedFuncOps) {
    // Note: It would be good to apply cleanups here, but we cannot because the
    // aliasInfo would be invalidated.

    if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName())) {
      // This function was not analyzed and RaW conflicts were not resolved.
      // Buffer copies must be inserted before every write.
      OneShotBufferizationOptions updatedOptions = options;
      updatedOptions.copyBeforeWrite = true;
      if (failed(bufferizeOp(funcOp, updatedOptions, state, statistics)))
        return failure();
    } else {
      if (failed(bufferizeOp(funcOp, options, state, statistics)))
        return failure();
    }

    // Change buffer return types to more precise layout maps.
    if (options.inferFunctionResultLayout)
      foldMemRefCasts(funcOp);
  }

  // Bufferize all other ops.
  for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
    // Functions were already bufferized.
    if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>())
      continue;
    if (failed(bufferizeOp(&op, options, state, statistics)))
      return failure();
  }

  // Post-pass cleanup of function argument attributes.
  removeBufferizationAttributesInModule(moduleOp);

  return success();
}

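// Example usage (hypothetical driver code; `moduleOp` is assumed to be the
// module to bufferize):
//
//   OneShotBufferizationOptions options;
//   options.bufferizeFunctionBoundaries = true;
//   BufferizationState bufferizationState;
//   BufferizationStatistics statistics;
//   if (failed(runOneShotModuleBufferize(moduleOp, options, bufferizationState,
//                                        &statistics)))
//     return failure();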
LogicalResult mlir::bufferization::runOneShotModuleBufferize(
    ModuleOp moduleOp, const OneShotBufferizationOptions &options,
    BufferizationState &state, BufferizationStatistics *statistics) {
  assert(options.bufferizeFunctionBoundaries &&
         "expected that function boundary bufferization is activated");
  assert(!(options.copyBeforeWrite && options.testAnalysisOnly) &&
         "invalid combination of bufferization flags");
  if (!options.copyBeforeWrite) {
    if (options.noAnalysisFuncFilter.empty()) {
      if (failed(insertTensorCopies(moduleOp, options, state, statistics)))
        return failure();
    } else {
      // FuncOps whose names are specified in options.noAnalysisFuncFilter will
      // not be analyzed. Ops within those FuncOps will not be analyzed either.
      OpFilter::Entry::FilterFn analysisFilterFn = [=](Operation *op) {
        auto func = dyn_cast<func::FuncOp>(op);
        if (!func)
          func = op->getParentOfType<func::FuncOp>();
        if (func)
          return llvm::is_contained(options.noAnalysisFuncFilter,
                                    func.getSymName());
        return false;
      };
      OneShotBufferizationOptions updatedOptions(options);
      updatedOptions.opFilter.denyOperation(analysisFilterFn);
      if (failed(
              insertTensorCopies(moduleOp, updatedOptions, state, statistics)))
        return failure();
    }
  }
  if (options.testAnalysisOnly)
    return success();
  if (failed(bufferizeModuleOp(moduleOp, options, state, statistics)))
    return failure();
  return success();
}