//===- BufferizationTransformOps.cpp - Bufferization transform ops --------===//
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h" |
10 | |
11 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
12 | #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" |
13 | #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" |
14 | #include "mlir/Dialect/Bufferization/Transforms/Passes.h" |
15 | #include "mlir/Dialect/Bufferization/Transforms/Transforms.h" |
16 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
17 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
18 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
19 | #include "mlir/Dialect/Transform/IR/TransformDialect.h" |
20 | #include "mlir/Interfaces/FunctionInterfaces.h" |
21 | |
22 | using namespace mlir; |
23 | using namespace mlir::bufferization; |
24 | using namespace mlir::transform; |
25 | |
26 | //===----------------------------------------------------------------------===// |
27 | // BufferLoopHoistingOp |
28 | //===----------------------------------------------------------------------===// |
29 | |
30 | DiagnosedSilenceableFailure transform::BufferLoopHoistingOp::applyToOne( |
31 | TransformRewriter &rewriter, Operation *target, |
32 | ApplyToEachResultList &results, TransformState &state) { |
33 | bufferization::hoistBuffersFromLoops(target); |
34 | return DiagnosedSilenceableFailure::success(); |
35 | } |
36 | |
37 | void transform::BufferLoopHoistingOp::getEffects( |
38 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
39 | onlyReadsHandle(getTargetMutable(), effects); |
40 | modifiesPayload(effects); |
41 | } |
42 | |
43 | //===----------------------------------------------------------------------===// |
44 | // OneShotBufferizeOp |
45 | //===----------------------------------------------------------------------===// |
46 | |
47 | LogicalResult transform::OneShotBufferizeOp::verify() { |
48 | if (getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy" ) |
49 | return emitOpError() << "unsupported memcpy op" ; |
50 | if (getPrintConflicts() && !getTestAnalysisOnly()) |
51 | return emitOpError() << "'print_conflicts' requires 'test_analysis_only'" ; |
52 | if (getDumpAliasSets() && !getTestAnalysisOnly()) |
53 | return emitOpError() << "'dump_alias_sets' requires 'test_analysis_only'" ; |
54 | return success(); |
55 | } |
56 | |
57 | DiagnosedSilenceableFailure |
58 | transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter, |
59 | TransformResults &transformResults, |
60 | TransformState &state) { |
61 | OneShotBufferizationOptions options; |
62 | options.allowReturnAllocsFromLoops = getAllowReturnAllocsFromLoops(); |
63 | options.allowUnknownOps = getAllowUnknownOps(); |
64 | options.bufferizeFunctionBoundaries = getBufferizeFunctionBoundaries(); |
65 | options.dumpAliasSets = getDumpAliasSets(); |
66 | options.testAnalysisOnly = getTestAnalysisOnly(); |
67 | options.printConflicts = getPrintConflicts(); |
68 | if (getFunctionBoundaryTypeConversion().has_value()) |
69 | options.setFunctionBoundaryTypeConversion( |
70 | *getFunctionBoundaryTypeConversion()); |
71 | if (getMemcpyOp() == "memref.copy" ) { |
72 | options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) { |
73 | b.create<memref::CopyOp>(loc, from, to); |
74 | return success(); |
75 | }; |
76 | } else if (getMemcpyOp() == "linalg.copy" ) { |
77 | options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) { |
78 | b.create<linalg::CopyOp>(loc, from, to); |
79 | return success(); |
80 | }; |
81 | } else { |
82 | llvm_unreachable("invalid copy op" ); |
83 | } |
84 | |
85 | auto payloadOps = state.getPayloadOps(getTarget()); |
86 | BufferizationState bufferizationState; |
87 | |
88 | for (Operation *target : payloadOps) { |
89 | if (!isa<ModuleOp, FunctionOpInterface>(target)) |
90 | return emitSilenceableError() << "expected module or function target" ; |
91 | auto moduleOp = dyn_cast<ModuleOp>(target); |
92 | if (options.bufferizeFunctionBoundaries) { |
93 | if (!moduleOp) |
94 | return emitSilenceableError() << "expected module target" ; |
95 | if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options, |
96 | bufferizationState))) |
97 | return emitSilenceableError() << "bufferization failed" ; |
98 | } else { |
99 | if (failed(bufferization::runOneShotBufferize(target, options, |
100 | bufferizationState))) |
101 | return emitSilenceableError() << "bufferization failed" ; |
102 | } |
103 | } |
104 | |
105 | // This transform op is currently restricted to ModuleOps and function ops. |
106 | // Such ops are modified in-place. |
107 | transformResults.set(cast<OpResult>(getTransformed()), payloadOps); |
108 | return DiagnosedSilenceableFailure::success(); |
109 | } |
110 | |
111 | //===----------------------------------------------------------------------===// |
112 | // EliminateEmptyTensorsOp |
113 | //===----------------------------------------------------------------------===// |
114 | |
115 | void transform::EliminateEmptyTensorsOp::getEffects( |
116 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
117 | onlyReadsHandle(getTargetMutable(), effects); |
118 | modifiesPayload(effects); |
119 | } |
120 | |
121 | DiagnosedSilenceableFailure transform::EliminateEmptyTensorsOp::apply( |
122 | transform::TransformRewriter &rewriter, TransformResults &transformResults, |
123 | TransformState &state) { |
124 | for (Operation *target : state.getPayloadOps(getTarget())) { |
125 | if (failed(bufferization::eliminateEmptyTensors(rewriter, target))) |
126 | return mlir::emitSilenceableFailure(target->getLoc()) |
127 | << "empty tensor elimination failed" ; |
128 | } |
129 | return DiagnosedSilenceableFailure::success(); |
130 | } |
131 | |
132 | //===----------------------------------------------------------------------===// |
133 | // EmptyTensorToAllocTensorOp |
134 | //===----------------------------------------------------------------------===// |
135 | |
136 | DiagnosedSilenceableFailure EmptyTensorToAllocTensorOp::applyToOne( |
137 | transform::TransformRewriter &rewriter, tensor::EmptyOp target, |
138 | ApplyToEachResultList &results, transform::TransformState &state) { |
139 | rewriter.setInsertionPoint(target); |
140 | auto alloc = rewriter.replaceOpWithNewOp<bufferization::AllocTensorOp>( |
141 | target, target.getType(), target.getDynamicSizes()); |
142 | results.push_back(alloc); |
143 | return DiagnosedSilenceableFailure::success(); |
144 | } |
145 | |
146 | //===----------------------------------------------------------------------===// |
147 | // Transform op registration |
148 | //===----------------------------------------------------------------------===// |
149 | |
150 | namespace { |
151 | /// Registers new ops and declares PDL as dependent dialect since the additional |
152 | /// ops are using PDL types for operands and results. |
153 | class BufferizationTransformDialectExtension |
154 | : public transform::TransformDialectExtension< |
155 | BufferizationTransformDialectExtension> { |
156 | public: |
157 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( |
158 | BufferizationTransformDialectExtension) |
159 | |
160 | using Base::Base; |
161 | |
162 | void init() { |
163 | declareGeneratedDialect<bufferization::BufferizationDialect>(); |
164 | declareGeneratedDialect<memref::MemRefDialect>(); |
165 | |
166 | registerTransformOps< |
167 | #define GET_OP_LIST |
168 | #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc" |
169 | |
170 | >(); |
171 | } |
172 | }; |
173 | } // namespace |
174 | |
175 | #define GET_OP_CLASSES |
176 | #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc" |
177 | |
178 | #include "mlir/Dialect/Bufferization/IR/BufferizationEnums.cpp.inc" |
179 | |
180 | void mlir::bufferization::registerTransformDialectExtension( |
181 | DialectRegistry ®istry) { |
182 | registry.addExtensions<BufferizationTransformDialectExtension>(); |
183 | } |
184 | |