//===- BufferizationTransformOps.cpp - Bufferization transform ops --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
10
11#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
13#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
14#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
15#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
16#include "mlir/Dialect/Linalg/IR/Linalg.h"
17#include "mlir/Dialect/MemRef/IR/MemRef.h"
18#include "mlir/Dialect/Tensor/IR/Tensor.h"
19#include "mlir/Dialect/Transform/IR/TransformDialect.h"
20#include "mlir/Interfaces/FunctionInterfaces.h"
21
22using namespace mlir;
23using namespace mlir::bufferization;
24using namespace mlir::transform;
25
26//===----------------------------------------------------------------------===//
27// BufferLoopHoistingOp
28//===----------------------------------------------------------------------===//
29
30DiagnosedSilenceableFailure transform::BufferLoopHoistingOp::applyToOne(
31 TransformRewriter &rewriter, Operation *target,
32 ApplyToEachResultList &results, TransformState &state) {
33 bufferization::hoistBuffersFromLoops(target);
34 return DiagnosedSilenceableFailure::success();
35}
36
37void transform::BufferLoopHoistingOp::getEffects(
38 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
39 onlyReadsHandle(getTargetMutable(), effects);
40 modifiesPayload(effects);
41}
42
43//===----------------------------------------------------------------------===//
44// OneShotBufferizeOp
45//===----------------------------------------------------------------------===//
46
47LogicalResult transform::OneShotBufferizeOp::verify() {
48 if (getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy")
49 return emitOpError() << "unsupported memcpy op";
50 if (getPrintConflicts() && !getTestAnalysisOnly())
51 return emitOpError() << "'print_conflicts' requires 'test_analysis_only'";
52 if (getDumpAliasSets() && !getTestAnalysisOnly())
53 return emitOpError() << "'dump_alias_sets' requires 'test_analysis_only'";
54 return success();
55}
56
57DiagnosedSilenceableFailure
58transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
59 TransformResults &transformResults,
60 TransformState &state) {
61 OneShotBufferizationOptions options;
62 options.allowReturnAllocsFromLoops = getAllowReturnAllocsFromLoops();
63 options.allowUnknownOps = getAllowUnknownOps();
64 options.bufferizeFunctionBoundaries = getBufferizeFunctionBoundaries();
65 options.dumpAliasSets = getDumpAliasSets();
66 options.testAnalysisOnly = getTestAnalysisOnly();
67 options.printConflicts = getPrintConflicts();
68 if (getFunctionBoundaryTypeConversion().has_value())
69 options.setFunctionBoundaryTypeConversion(
70 *getFunctionBoundaryTypeConversion());
71 if (getMemcpyOp() == "memref.copy") {
72 options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) {
73 b.create<memref::CopyOp>(loc, from, to);
74 return success();
75 };
76 } else if (getMemcpyOp() == "linalg.copy") {
77 options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) {
78 b.create<linalg::CopyOp>(loc, from, to);
79 return success();
80 };
81 } else {
82 llvm_unreachable("invalid copy op");
83 }
84
85 auto payloadOps = state.getPayloadOps(getTarget());
86 BufferizationState bufferizationState;
87
88 for (Operation *target : payloadOps) {
89 if (!isa<ModuleOp, FunctionOpInterface>(target))
90 return emitSilenceableError() << "expected module or function target";
91 auto moduleOp = dyn_cast<ModuleOp>(target);
92 if (options.bufferizeFunctionBoundaries) {
93 if (!moduleOp)
94 return emitSilenceableError() << "expected module target";
95 if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options,
96 bufferizationState)))
97 return emitSilenceableError() << "bufferization failed";
98 } else {
99 if (failed(bufferization::runOneShotBufferize(target, options,
100 bufferizationState)))
101 return emitSilenceableError() << "bufferization failed";
102 }
103 }
104
105 // This transform op is currently restricted to ModuleOps and function ops.
106 // Such ops are modified in-place.
107 transformResults.set(cast<OpResult>(getTransformed()), payloadOps);
108 return DiagnosedSilenceableFailure::success();
109}
110
111//===----------------------------------------------------------------------===//
112// EliminateEmptyTensorsOp
113//===----------------------------------------------------------------------===//
114
115void transform::EliminateEmptyTensorsOp::getEffects(
116 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
117 onlyReadsHandle(getTargetMutable(), effects);
118 modifiesPayload(effects);
119}
120
121DiagnosedSilenceableFailure transform::EliminateEmptyTensorsOp::apply(
122 transform::TransformRewriter &rewriter, TransformResults &transformResults,
123 TransformState &state) {
124 for (Operation *target : state.getPayloadOps(getTarget())) {
125 if (failed(bufferization::eliminateEmptyTensors(rewriter, target)))
126 return mlir::emitSilenceableFailure(target->getLoc())
127 << "empty tensor elimination failed";
128 }
129 return DiagnosedSilenceableFailure::success();
130}
131
132//===----------------------------------------------------------------------===//
133// EmptyTensorToAllocTensorOp
134//===----------------------------------------------------------------------===//
135
136DiagnosedSilenceableFailure EmptyTensorToAllocTensorOp::applyToOne(
137 transform::TransformRewriter &rewriter, tensor::EmptyOp target,
138 ApplyToEachResultList &results, transform::TransformState &state) {
139 rewriter.setInsertionPoint(target);
140 auto alloc = rewriter.replaceOpWithNewOp<bufferization::AllocTensorOp>(
141 target, target.getType(), target.getDynamicSizes());
142 results.push_back(alloc);
143 return DiagnosedSilenceableFailure::success();
144}
145
146//===----------------------------------------------------------------------===//
147// Transform op registration
148//===----------------------------------------------------------------------===//
149
150namespace {
151/// Registers new ops and declares PDL as dependent dialect since the additional
152/// ops are using PDL types for operands and results.
153class BufferizationTransformDialectExtension
154 : public transform::TransformDialectExtension<
155 BufferizationTransformDialectExtension> {
156public:
157 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
158 BufferizationTransformDialectExtension)
159
160 using Base::Base;
161
162 void init() {
163 declareGeneratedDialect<bufferization::BufferizationDialect>();
164 declareGeneratedDialect<memref::MemRefDialect>();
165
166 registerTransformOps<
167#define GET_OP_LIST
168#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
169
170 >();
171 }
172};
173} // namespace
174
175#define GET_OP_CLASSES
176#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
177
178#include "mlir/Dialect/Bufferization/IR/BufferizationEnums.cpp.inc"
179
180void mlir::bufferization::registerTransformDialectExtension(
181 DialectRegistry &registry) {
182 registry.addExtensions<BufferizationTransformDialectExtension>();
183}
184

source code of mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp