//===- BufferizationTransformOps.cpp - Bufferization transform ops --------===//
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h" |
| 10 | |
| 11 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
| 12 | #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" |
| 13 | #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" |
| 14 | #include "mlir/Dialect/Bufferization/Transforms/Passes.h" |
| 15 | #include "mlir/Dialect/Bufferization/Transforms/Transforms.h" |
| 16 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| 17 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 18 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| 19 | #include "mlir/Dialect/Transform/IR/TransformDialect.h" |
| 20 | #include "mlir/Interfaces/FunctionInterfaces.h" |
| 21 | |
| 22 | using namespace mlir; |
| 23 | using namespace mlir::bufferization; |
| 24 | using namespace mlir::transform; |
| 25 | |
| 26 | //===----------------------------------------------------------------------===// |
| 27 | // BufferLoopHoistingOp |
| 28 | //===----------------------------------------------------------------------===// |
| 29 | |
| 30 | DiagnosedSilenceableFailure transform::BufferLoopHoistingOp::applyToOne( |
| 31 | TransformRewriter &rewriter, Operation *target, |
| 32 | ApplyToEachResultList &results, TransformState &state) { |
| 33 | bufferization::hoistBuffersFromLoops(target); |
| 34 | return DiagnosedSilenceableFailure::success(); |
| 35 | } |
| 36 | |
| 37 | void transform::BufferLoopHoistingOp::getEffects( |
| 38 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
| 39 | onlyReadsHandle(getTargetMutable(), effects); |
| 40 | modifiesPayload(effects); |
| 41 | } |
| 42 | |
| 43 | //===----------------------------------------------------------------------===// |
| 44 | // OneShotBufferizeOp |
| 45 | //===----------------------------------------------------------------------===// |
| 46 | |
| 47 | LogicalResult transform::OneShotBufferizeOp::verify() { |
| 48 | if (getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy" ) |
| 49 | return emitOpError() << "unsupported memcpy op" ; |
| 50 | if (getPrintConflicts() && !getTestAnalysisOnly()) |
| 51 | return emitOpError() << "'print_conflicts' requires 'test_analysis_only'" ; |
| 52 | if (getDumpAliasSets() && !getTestAnalysisOnly()) |
| 53 | return emitOpError() << "'dump_alias_sets' requires 'test_analysis_only'" ; |
| 54 | return success(); |
| 55 | } |
| 56 | |
| 57 | DiagnosedSilenceableFailure |
| 58 | transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter, |
| 59 | TransformResults &transformResults, |
| 60 | TransformState &state) { |
| 61 | OneShotBufferizationOptions options; |
| 62 | options.allowReturnAllocsFromLoops = getAllowReturnAllocsFromLoops(); |
| 63 | options.allowUnknownOps = getAllowUnknownOps(); |
| 64 | options.bufferizeFunctionBoundaries = getBufferizeFunctionBoundaries(); |
| 65 | options.dumpAliasSets = getDumpAliasSets(); |
| 66 | options.testAnalysisOnly = getTestAnalysisOnly(); |
| 67 | options.printConflicts = getPrintConflicts(); |
| 68 | if (getFunctionBoundaryTypeConversion().has_value()) |
| 69 | options.setFunctionBoundaryTypeConversion( |
| 70 | *getFunctionBoundaryTypeConversion()); |
| 71 | if (getMemcpyOp() == "memref.copy" ) { |
| 72 | options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) { |
| 73 | b.create<memref::CopyOp>(loc, from, to); |
| 74 | return success(); |
| 75 | }; |
| 76 | } else if (getMemcpyOp() == "linalg.copy" ) { |
| 77 | options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) { |
| 78 | b.create<linalg::CopyOp>(loc, from, to); |
| 79 | return success(); |
| 80 | }; |
| 81 | } else { |
| 82 | llvm_unreachable("invalid copy op" ); |
| 83 | } |
| 84 | |
| 85 | auto payloadOps = state.getPayloadOps(getTarget()); |
| 86 | BufferizationState bufferizationState; |
| 87 | |
| 88 | for (Operation *target : payloadOps) { |
| 89 | if (!isa<ModuleOp, FunctionOpInterface>(target)) |
| 90 | return emitSilenceableError() << "expected module or function target" ; |
| 91 | auto moduleOp = dyn_cast<ModuleOp>(target); |
| 92 | if (options.bufferizeFunctionBoundaries) { |
| 93 | if (!moduleOp) |
| 94 | return emitSilenceableError() << "expected module target" ; |
| 95 | if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options, |
| 96 | bufferizationState))) |
| 97 | return emitSilenceableError() << "bufferization failed" ; |
| 98 | } else { |
| 99 | if (failed(bufferization::runOneShotBufferize(target, options, |
| 100 | bufferizationState))) |
| 101 | return emitSilenceableError() << "bufferization failed" ; |
| 102 | } |
| 103 | } |
| 104 | |
| 105 | // This transform op is currently restricted to ModuleOps and function ops. |
| 106 | // Such ops are modified in-place. |
| 107 | transformResults.set(cast<OpResult>(getTransformed()), payloadOps); |
| 108 | return DiagnosedSilenceableFailure::success(); |
| 109 | } |
| 110 | |
| 111 | //===----------------------------------------------------------------------===// |
| 112 | // EliminateEmptyTensorsOp |
| 113 | //===----------------------------------------------------------------------===// |
| 114 | |
| 115 | void transform::EliminateEmptyTensorsOp::getEffects( |
| 116 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
| 117 | onlyReadsHandle(getTargetMutable(), effects); |
| 118 | modifiesPayload(effects); |
| 119 | } |
| 120 | |
| 121 | DiagnosedSilenceableFailure transform::EliminateEmptyTensorsOp::apply( |
| 122 | transform::TransformRewriter &rewriter, TransformResults &transformResults, |
| 123 | TransformState &state) { |
| 124 | for (Operation *target : state.getPayloadOps(getTarget())) { |
| 125 | if (failed(bufferization::eliminateEmptyTensors(rewriter, target))) |
| 126 | return mlir::emitSilenceableFailure(target->getLoc()) |
| 127 | << "empty tensor elimination failed" ; |
| 128 | } |
| 129 | return DiagnosedSilenceableFailure::success(); |
| 130 | } |
| 131 | |
| 132 | //===----------------------------------------------------------------------===// |
| 133 | // EmptyTensorToAllocTensorOp |
| 134 | //===----------------------------------------------------------------------===// |
| 135 | |
| 136 | DiagnosedSilenceableFailure EmptyTensorToAllocTensorOp::applyToOne( |
| 137 | transform::TransformRewriter &rewriter, tensor::EmptyOp target, |
| 138 | ApplyToEachResultList &results, transform::TransformState &state) { |
| 139 | rewriter.setInsertionPoint(target); |
| 140 | auto alloc = rewriter.replaceOpWithNewOp<bufferization::AllocTensorOp>( |
| 141 | target, target.getType(), target.getDynamicSizes()); |
| 142 | results.push_back(alloc); |
| 143 | return DiagnosedSilenceableFailure::success(); |
| 144 | } |
| 145 | |
| 146 | //===----------------------------------------------------------------------===// |
| 147 | // Transform op registration |
| 148 | //===----------------------------------------------------------------------===// |
| 149 | |
| 150 | namespace { |
| 151 | /// Registers new ops and declares PDL as dependent dialect since the additional |
| 152 | /// ops are using PDL types for operands and results. |
| 153 | class BufferizationTransformDialectExtension |
| 154 | : public transform::TransformDialectExtension< |
| 155 | BufferizationTransformDialectExtension> { |
| 156 | public: |
| 157 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( |
| 158 | BufferizationTransformDialectExtension) |
| 159 | |
| 160 | using Base::Base; |
| 161 | |
| 162 | void init() { |
| 163 | declareGeneratedDialect<bufferization::BufferizationDialect>(); |
| 164 | declareGeneratedDialect<memref::MemRefDialect>(); |
| 165 | |
| 166 | registerTransformOps< |
| 167 | #define GET_OP_LIST |
| 168 | #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc" |
| 169 | |
| 170 | >(); |
| 171 | } |
| 172 | }; |
| 173 | } // namespace |
| 174 | |
| 175 | #define GET_OP_CLASSES |
| 176 | #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc" |
| 177 | |
| 178 | #include "mlir/Dialect/Bufferization/IR/BufferizationEnums.cpp.inc" |
| 179 | |
| 180 | void mlir::bufferization::registerTransformDialectExtension( |
| 181 | DialectRegistry ®istry) { |
| 182 | registry.addExtensions<BufferizationTransformDialectExtension>(); |
| 183 | } |
| 184 | |