//===- BufferizationTransformOps.cpp - Bufferization transform ops --------===//
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h" |
10 | |
11 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
12 | #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" |
13 | #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" |
14 | #include "mlir/Dialect/Bufferization/Transforms/Passes.h" |
15 | #include "mlir/Dialect/Bufferization/Transforms/Transforms.h" |
16 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
17 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
18 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
19 | #include "mlir/Dialect/Transform/IR/TransformDialect.h" |
20 | #include "mlir/Interfaces/FunctionInterfaces.h" |
21 | |
22 | using namespace mlir; |
23 | using namespace mlir::bufferization; |
24 | using namespace mlir::transform; |
25 | |
26 | //===----------------------------------------------------------------------===// |
27 | // BufferLoopHoistingOp |
28 | //===----------------------------------------------------------------------===// |
29 | |
30 | DiagnosedSilenceableFailure transform::BufferLoopHoistingOp::applyToOne( |
31 | TransformRewriter &rewriter, Operation *target, |
32 | ApplyToEachResultList &results, TransformState &state) { |
33 | bufferization::hoistBuffersFromLoops(target); |
34 | return DiagnosedSilenceableFailure::success(); |
35 | } |
36 | |
37 | void transform::BufferLoopHoistingOp::getEffects( |
38 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
39 | onlyReadsHandle(getTarget(), effects); |
40 | modifiesPayload(effects); |
41 | } |
42 | |
43 | //===----------------------------------------------------------------------===// |
44 | // OneShotBufferizeOp |
45 | //===----------------------------------------------------------------------===// |
46 | |
47 | LogicalResult transform::OneShotBufferizeOp::verify() { |
48 | if (getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy" ) |
49 | return emitOpError() << "unsupported memcpy op" ; |
50 | if (getPrintConflicts() && !getTestAnalysisOnly()) |
51 | return emitOpError() << "'print_conflicts' requires 'test_analysis_only'" ; |
52 | if (getDumpAliasSets() && !getTestAnalysisOnly()) |
53 | return emitOpError() << "'dump_alias_sets' requires 'test_analysis_only'" ; |
54 | return success(); |
55 | } |
56 | |
57 | DiagnosedSilenceableFailure |
58 | transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter, |
59 | TransformResults &transformResults, |
60 | TransformState &state) { |
61 | OneShotBufferizationOptions options; |
62 | options.allowReturnAllocsFromLoops = getAllowReturnAllocsFromLoops(); |
63 | options.allowUnknownOps = getAllowUnknownOps(); |
64 | options.bufferizeFunctionBoundaries = getBufferizeFunctionBoundaries(); |
65 | options.dumpAliasSets = getDumpAliasSets(); |
66 | options.testAnalysisOnly = getTestAnalysisOnly(); |
67 | options.printConflicts = getPrintConflicts(); |
68 | if (getFunctionBoundaryTypeConversion().has_value()) |
69 | options.setFunctionBoundaryTypeConversion( |
70 | *getFunctionBoundaryTypeConversion()); |
71 | if (getMemcpyOp() == "memref.copy" ) { |
72 | options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) { |
73 | b.create<memref::CopyOp>(loc, from, to); |
74 | return success(); |
75 | }; |
76 | } else if (getMemcpyOp() == "linalg.copy" ) { |
77 | options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) { |
78 | b.create<linalg::CopyOp>(loc, from, to); |
79 | return success(); |
80 | }; |
81 | } else { |
82 | llvm_unreachable("invalid copy op" ); |
83 | } |
84 | |
85 | auto payloadOps = state.getPayloadOps(getTarget()); |
86 | for (Operation *target : payloadOps) { |
87 | if (!isa<ModuleOp, FunctionOpInterface>(target)) |
88 | return emitSilenceableError() << "expected module or function target" ; |
89 | auto moduleOp = dyn_cast<ModuleOp>(target); |
90 | if (options.bufferizeFunctionBoundaries) { |
91 | if (!moduleOp) |
92 | return emitSilenceableError() << "expected module target" ; |
93 | if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options))) |
94 | return emitSilenceableError() << "bufferization failed" ; |
95 | } else { |
96 | if (failed(bufferization::runOneShotBufferize(target, options))) |
97 | return emitSilenceableError() << "bufferization failed" ; |
98 | } |
99 | } |
100 | |
101 | // This transform op is currently restricted to ModuleOps and function ops. |
102 | // Such ops are modified in-place. |
103 | transformResults.set(cast<OpResult>(getTransformed()), payloadOps); |
104 | return DiagnosedSilenceableFailure::success(); |
105 | } |
106 | |
107 | //===----------------------------------------------------------------------===// |
108 | // EliminateEmptyTensorsOp |
109 | //===----------------------------------------------------------------------===// |
110 | |
111 | void transform::EliminateEmptyTensorsOp::getEffects( |
112 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
113 | onlyReadsHandle(getTarget(), effects); |
114 | modifiesPayload(effects); |
115 | } |
116 | |
117 | DiagnosedSilenceableFailure transform::EliminateEmptyTensorsOp::apply( |
118 | transform::TransformRewriter &rewriter, TransformResults &transformResults, |
119 | TransformState &state) { |
120 | for (Operation *target : state.getPayloadOps(getTarget())) { |
121 | if (failed(bufferization::eliminateEmptyTensors(rewriter, target))) |
122 | return mlir::emitSilenceableFailure(target->getLoc()) |
123 | << "empty tensor elimination failed" ; |
124 | } |
125 | return DiagnosedSilenceableFailure::success(); |
126 | } |
127 | |
128 | //===----------------------------------------------------------------------===// |
129 | // EmptyTensorToAllocTensorOp |
130 | //===----------------------------------------------------------------------===// |
131 | |
132 | DiagnosedSilenceableFailure EmptyTensorToAllocTensorOp::applyToOne( |
133 | transform::TransformRewriter &rewriter, tensor::EmptyOp target, |
134 | ApplyToEachResultList &results, transform::TransformState &state) { |
135 | rewriter.setInsertionPoint(target); |
136 | auto alloc = rewriter.replaceOpWithNewOp<bufferization::AllocTensorOp>( |
137 | target, target.getType(), target.getDynamicSizes()); |
138 | results.push_back(alloc); |
139 | return DiagnosedSilenceableFailure::success(); |
140 | } |
141 | |
142 | //===----------------------------------------------------------------------===// |
143 | // Transform op registration |
144 | //===----------------------------------------------------------------------===// |
145 | |
146 | namespace { |
147 | /// Registers new ops and declares PDL as dependent dialect since the additional |
148 | /// ops are using PDL types for operands and results. |
149 | class BufferizationTransformDialectExtension |
150 | : public transform::TransformDialectExtension< |
151 | BufferizationTransformDialectExtension> { |
152 | public: |
153 | using Base::Base; |
154 | |
155 | void init() { |
156 | declareGeneratedDialect<bufferization::BufferizationDialect>(); |
157 | declareGeneratedDialect<memref::MemRefDialect>(); |
158 | |
159 | registerTransformOps< |
160 | #define GET_OP_LIST |
161 | #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc" |
162 | >(); |
163 | } |
164 | }; |
165 | } // namespace |
166 | |
167 | #define GET_OP_CLASSES |
168 | #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc" |
169 | |
170 | #include "mlir/Dialect/Bufferization/IR/BufferizationEnums.cpp.inc" |
171 | |
172 | void mlir::bufferization::registerTransformDialectExtension( |
173 | DialectRegistry ®istry) { |
174 | registry.addExtensions<BufferizationTransformDialectExtension>(); |
175 | } |
176 | |