1 | //===- TensorCopyInsertion.cpp - Resolve Bufferization Conflicts w/ Copies ===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Bufferization/Transforms/Passes.h" |
10 | |
11 | #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" |
12 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
13 | #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" |
14 | #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" |
15 | #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" |
16 | #include "mlir/Dialect/Bufferization/Transforms/Transforms.h" |
17 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
18 | |
19 | namespace mlir { |
20 | namespace bufferization { |
21 | #define GEN_PASS_DEF_TENSORCOPYINSERTION |
22 | #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc" |
23 | } // namespace bufferization |
24 | } // namespace mlir |
25 | |
26 | using namespace mlir; |
27 | using namespace mlir::bufferization; |
28 | |
29 | LogicalResult mlir::bufferization::insertTensorCopies( |
30 | Operation *op, const OneShotBufferizationOptions &options, |
31 | BufferizationStatistics *statistics) { |
32 | OneShotAnalysisState state(op, options); |
33 | // Run normal One-Shot Bufferize analysis or One-Shot Module Bufferize |
34 | // analysis depending on whether function boundary bufferization is enabled or |
35 | // not. |
36 | if (options.bufferizeFunctionBoundaries) { |
37 | if (failed(analyzeModuleOp(cast<ModuleOp>(op), state, statistics))) |
38 | return failure(); |
39 | } else { |
40 | if (failed(result: analyzeOp(op, state, statistics))) |
41 | return failure(); |
42 | } |
43 | |
44 | if (options.testAnalysisOnly) |
45 | return success(); |
46 | |
47 | return insertTensorCopies(op, state); |
48 | } |
49 | |
50 | LogicalResult |
51 | mlir::bufferization::insertTensorCopies(Operation *op, |
52 | const AnalysisState &state) { |
53 | IRRewriter rewriter(op->getContext()); |
54 | |
55 | WalkResult result = op->walk(callback: [&](Operation *op) { |
56 | auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op); |
57 | if (!bufferizableOp) |
58 | return WalkResult::skip(); |
59 | |
60 | // Find inplacability conflicts and resolve them. (Typically with explicit |
61 | // tensor copies in the form of AllocTensorOps.) |
62 | rewriter.setInsertionPoint(op); |
63 | if (failed(bufferizableOp.resolveConflicts(rewriter, state))) |
64 | return WalkResult::interrupt(); |
65 | |
66 | return WalkResult::advance(); |
67 | }); |
68 | |
69 | return failure(isFailure: result.wasInterrupted()); |
70 | } |
71 | |