//===- BufferizationTransformOps.cpp - Bufferization transform ops --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
10
11#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
13#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
14#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
15#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
16#include "mlir/Dialect/Linalg/IR/Linalg.h"
17#include "mlir/Dialect/MemRef/IR/MemRef.h"
18#include "mlir/Dialect/Tensor/IR/Tensor.h"
19#include "mlir/Dialect/Transform/IR/TransformDialect.h"
20#include "mlir/Interfaces/FunctionInterfaces.h"
21
22using namespace mlir;
23using namespace mlir::bufferization;
24using namespace mlir::transform;
25
26//===----------------------------------------------------------------------===//
27// BufferLoopHoistingOp
28//===----------------------------------------------------------------------===//
29
30DiagnosedSilenceableFailure transform::BufferLoopHoistingOp::applyToOne(
31 TransformRewriter &rewriter, Operation *target,
32 ApplyToEachResultList &results, TransformState &state) {
33 bufferization::hoistBuffersFromLoops(target);
34 return DiagnosedSilenceableFailure::success();
35}
36
37void transform::BufferLoopHoistingOp::getEffects(
38 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
39 onlyReadsHandle(getTarget(), effects);
40 modifiesPayload(effects);
41}
42
43//===----------------------------------------------------------------------===//
44// OneShotBufferizeOp
45//===----------------------------------------------------------------------===//
46
47LogicalResult transform::OneShotBufferizeOp::verify() {
48 if (getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy")
49 return emitOpError() << "unsupported memcpy op";
50 if (getPrintConflicts() && !getTestAnalysisOnly())
51 return emitOpError() << "'print_conflicts' requires 'test_analysis_only'";
52 if (getDumpAliasSets() && !getTestAnalysisOnly())
53 return emitOpError() << "'dump_alias_sets' requires 'test_analysis_only'";
54 return success();
55}
56
57DiagnosedSilenceableFailure
58transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
59 TransformResults &transformResults,
60 TransformState &state) {
61 OneShotBufferizationOptions options;
62 options.allowReturnAllocsFromLoops = getAllowReturnAllocsFromLoops();
63 options.allowUnknownOps = getAllowUnknownOps();
64 options.bufferizeFunctionBoundaries = getBufferizeFunctionBoundaries();
65 options.dumpAliasSets = getDumpAliasSets();
66 options.testAnalysisOnly = getTestAnalysisOnly();
67 options.printConflicts = getPrintConflicts();
68 if (getFunctionBoundaryTypeConversion().has_value())
69 options.setFunctionBoundaryTypeConversion(
70 *getFunctionBoundaryTypeConversion());
71 if (getMemcpyOp() == "memref.copy") {
72 options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) {
73 b.create<memref::CopyOp>(loc, from, to);
74 return success();
75 };
76 } else if (getMemcpyOp() == "linalg.copy") {
77 options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) {
78 b.create<linalg::CopyOp>(loc, from, to);
79 return success();
80 };
81 } else {
82 llvm_unreachable("invalid copy op");
83 }
84
85 auto payloadOps = state.getPayloadOps(getTarget());
86 for (Operation *target : payloadOps) {
87 if (!isa<ModuleOp, FunctionOpInterface>(target))
88 return emitSilenceableError() << "expected module or function target";
89 auto moduleOp = dyn_cast<ModuleOp>(target);
90 if (options.bufferizeFunctionBoundaries) {
91 if (!moduleOp)
92 return emitSilenceableError() << "expected module target";
93 if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options)))
94 return emitSilenceableError() << "bufferization failed";
95 } else {
96 if (failed(bufferization::runOneShotBufferize(target, options)))
97 return emitSilenceableError() << "bufferization failed";
98 }
99 }
100
101 // This transform op is currently restricted to ModuleOps and function ops.
102 // Such ops are modified in-place.
103 transformResults.set(cast<OpResult>(getTransformed()), payloadOps);
104 return DiagnosedSilenceableFailure::success();
105}
106
107//===----------------------------------------------------------------------===//
108// EliminateEmptyTensorsOp
109//===----------------------------------------------------------------------===//
110
111void transform::EliminateEmptyTensorsOp::getEffects(
112 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
113 onlyReadsHandle(getTarget(), effects);
114 modifiesPayload(effects);
115}
116
117DiagnosedSilenceableFailure transform::EliminateEmptyTensorsOp::apply(
118 transform::TransformRewriter &rewriter, TransformResults &transformResults,
119 TransformState &state) {
120 for (Operation *target : state.getPayloadOps(getTarget())) {
121 if (failed(bufferization::eliminateEmptyTensors(rewriter, target)))
122 return mlir::emitSilenceableFailure(target->getLoc())
123 << "empty tensor elimination failed";
124 }
125 return DiagnosedSilenceableFailure::success();
126}
127
128//===----------------------------------------------------------------------===//
129// EmptyTensorToAllocTensorOp
130//===----------------------------------------------------------------------===//
131
132DiagnosedSilenceableFailure EmptyTensorToAllocTensorOp::applyToOne(
133 transform::TransformRewriter &rewriter, tensor::EmptyOp target,
134 ApplyToEachResultList &results, transform::TransformState &state) {
135 rewriter.setInsertionPoint(target);
136 auto alloc = rewriter.replaceOpWithNewOp<bufferization::AllocTensorOp>(
137 target, target.getType(), target.getDynamicSizes());
138 results.push_back(alloc);
139 return DiagnosedSilenceableFailure::success();
140}
141
142//===----------------------------------------------------------------------===//
143// Transform op registration
144//===----------------------------------------------------------------------===//
145
146namespace {
147/// Registers new ops and declares PDL as dependent dialect since the additional
148/// ops are using PDL types for operands and results.
149class BufferizationTransformDialectExtension
150 : public transform::TransformDialectExtension<
151 BufferizationTransformDialectExtension> {
152public:
153 using Base::Base;
154
155 void init() {
156 declareGeneratedDialect<bufferization::BufferizationDialect>();
157 declareGeneratedDialect<memref::MemRefDialect>();
158
159 registerTransformOps<
160#define GET_OP_LIST
161#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
162 >();
163 }
164};
165} // namespace
166
167#define GET_OP_CLASSES
168#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
169
170#include "mlir/Dialect/Bufferization/IR/BufferizationEnums.cpp.inc"
171
172void mlir::bufferization::registerTransformDialectExtension(
173 DialectRegistry &registry) {
174 registry.addExtensions<BufferizationTransformDialectExtension>();
175}
176

// Source file: mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp