1//===- TensorCopyInsertion.cpp - Resolve Bufferization Conflicts w/ Copies ===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
10
11#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
14#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
15#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
16#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
17#include "mlir/Dialect/Func/IR/FuncOps.h"
18
19namespace mlir {
20namespace bufferization {
21#define GEN_PASS_DEF_TENSORCOPYINSERTION
22#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
23} // namespace bufferization
24} // namespace mlir
25
26using namespace mlir;
27using namespace mlir::bufferization;
28
29LogicalResult mlir::bufferization::insertTensorCopies(
30 Operation *op, const OneShotBufferizationOptions &options,
31 const BufferizationState &bufferizationState,
32 BufferizationStatistics *statistics) {
33 OneShotAnalysisState analysisState(op, options);
34 // Run normal One-Shot Bufferize analysis or One-Shot Module Bufferize
35 // analysis depending on whether function boundary bufferization is enabled or
36 // not.
37 if (options.bufferizeFunctionBoundaries) {
38 if (failed(analyzeModuleOp(cast<ModuleOp>(op), analysisState, statistics)))
39 return failure();
40 } else {
41 if (failed(Result: analyzeOp(op, state&: analysisState, statistics)))
42 return failure();
43 }
44
45 if (options.testAnalysisOnly)
46 return success();
47
48 return insertTensorCopies(op, analysisState, bufferizationState);
49}
50
51LogicalResult mlir::bufferization::insertTensorCopies(
52 Operation *op, const AnalysisState &analysisState,
53 const BufferizationState &bufferizationState) {
54 IRRewriter rewriter(op->getContext());
55
56 // It may be more efficient to walk in pre-order here, but the current
57 // implementation visits regions of ops even if they are not allowed or
58 // bufferizable, and existing tests rely on this behavior.
59 // For now, only exclude nested operations if they are in a different symbol
60 // table scope.
61 WalkResult result = op->walk(callback: [&](Operation *nestedOp) {
62 if (op->hasTrait<OpTrait::SymbolTable>() &&
63 nestedOp->getParentWithTrait<OpTrait::SymbolTable>() != op)
64 return WalkResult::skip();
65
66 auto bufferizableOp =
67 analysisState.getOptions().dynCastBufferizableOp(nestedOp);
68 if (!bufferizableOp)
69 return WalkResult::skip();
70
71 // Find inplacability conflicts and resolve them. (Typically with explicit
72 // tensor copies in the form of AllocTensorOps.)
73 rewriter.setInsertionPoint(nestedOp);
74 if (failed(bufferizableOp.resolveConflicts(rewriter, analysisState,
75 bufferizationState)))
76 return WalkResult::interrupt();
77
78 return WalkResult::advance();
79 });
80
81 return failure(IsFailure: result.wasInterrupted());
82}
83

source code of mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp