1//===- TensorCopyInsertion.cpp - Resolve Bufferization Conflicts w/ Copies ===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
10
11#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
14#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
15#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
16#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
17#include "mlir/Dialect/Func/IR/FuncOps.h"
18
19namespace mlir {
20namespace bufferization {
21#define GEN_PASS_DEF_TENSORCOPYINSERTION
22#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
23} // namespace bufferization
24} // namespace mlir
25
26using namespace mlir;
27using namespace mlir::bufferization;
28
29LogicalResult mlir::bufferization::insertTensorCopies(
30 Operation *op, const OneShotBufferizationOptions &options,
31 BufferizationStatistics *statistics) {
32 OneShotAnalysisState state(op, options);
33 // Run normal One-Shot Bufferize analysis or One-Shot Module Bufferize
34 // analysis depending on whether function boundary bufferization is enabled or
35 // not.
36 if (options.bufferizeFunctionBoundaries) {
37 if (failed(analyzeModuleOp(cast<ModuleOp>(op), state, statistics)))
38 return failure();
39 } else {
40 if (failed(result: analyzeOp(op, state, statistics)))
41 return failure();
42 }
43
44 if (options.testAnalysisOnly)
45 return success();
46
47 return insertTensorCopies(op, state);
48}
49
50LogicalResult
51mlir::bufferization::insertTensorCopies(Operation *op,
52 const AnalysisState &state) {
53 IRRewriter rewriter(op->getContext());
54
55 WalkResult result = op->walk(callback: [&](Operation *op) {
56 auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op);
57 if (!bufferizableOp)
58 return WalkResult::skip();
59
60 // Find inplacability conflicts and resolve them. (Typically with explicit
61 // tensor copies in the form of AllocTensorOps.)
62 rewriter.setInsertionPoint(op);
63 if (failed(bufferizableOp.resolveConflicts(rewriter, state)))
64 return WalkResult::interrupt();
65
66 return WalkResult::advance();
67 });
68
69 return failure(isFailure: result.wasInterrupted());
70}
71

source code of mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp