1 | //===- BufferizationToMemRef.cpp - Bufferization to MemRef conversion -----===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements patterns to convert Bufferization dialect to MemRef |
10 | // dialect. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h" |
15 | |
16 | #include "mlir/Dialect/Arith/IR/Arith.h" |
17 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
18 | #include "mlir/Dialect/Bufferization/Transforms/Passes.h" |
19 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
20 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
21 | #include "mlir/Dialect/SCF/IR/SCF.h" |
22 | #include "mlir/IR/BuiltinTypes.h" |
23 | #include "mlir/Pass/Pass.h" |
24 | #include "mlir/Support/LogicalResult.h" |
25 | #include "mlir/Transforms/DialectConversion.h" |
26 | |
27 | namespace mlir { |
28 | #define GEN_PASS_DEF_CONVERTBUFFERIZATIONTOMEMREF |
29 | #include "mlir/Conversion/Passes.h.inc" |
30 | } // namespace mlir |
31 | |
32 | using namespace mlir; |
33 | |
34 | namespace { |
35 | /// The CloneOpConversion transforms all bufferization clone operations into |
36 | /// memref alloc and memref copy operations. In the dynamic-shape case, it also |
37 | /// emits additional dim and constant operations to determine the shape. This |
38 | /// conversion does not resolve memory leaks if it is used alone. |
39 | struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> { |
40 | using OpConversionPattern<bufferization::CloneOp>::OpConversionPattern; |
41 | |
42 | LogicalResult |
43 | matchAndRewrite(bufferization::CloneOp op, OpAdaptor adaptor, |
44 | ConversionPatternRewriter &rewriter) const override { |
45 | // Check for unranked memref types which are currently not supported. |
46 | Type type = op.getType(); |
47 | if (isa<UnrankedMemRefType>(type)) { |
48 | return rewriter.notifyMatchFailure( |
49 | op, "UnrankedMemRefType is not supported." ); |
50 | } |
51 | MemRefType memrefType = cast<MemRefType>(type); |
52 | MemRefLayoutAttrInterface layout; |
53 | auto allocType = |
54 | MemRefType::get(memrefType.getShape(), memrefType.getElementType(), |
55 | layout, memrefType.getMemorySpace()); |
56 | // Since this implementation always allocates, certain result types of the |
57 | // clone op cannot be lowered. |
58 | if (!memref::CastOp::areCastCompatible({allocType}, {memrefType})) |
59 | return failure(); |
60 | |
61 | // Transform a clone operation into alloc + copy operation and pay |
62 | // attention to the shape dimensions. |
63 | Location loc = op->getLoc(); |
64 | SmallVector<Value, 4> dynamicOperands; |
65 | for (int i = 0; i < memrefType.getRank(); ++i) { |
66 | if (!memrefType.isDynamicDim(i)) |
67 | continue; |
68 | Value dim = rewriter.createOrFold<memref::DimOp>(loc, op.getInput(), i); |
69 | dynamicOperands.push_back(Elt: dim); |
70 | } |
71 | |
72 | // Allocate a memref with identity layout. |
73 | Value alloc = rewriter.create<memref::AllocOp>(op->getLoc(), allocType, |
74 | dynamicOperands); |
75 | // Cast the allocation to the specified type if needed. |
76 | if (memrefType != allocType) |
77 | alloc = rewriter.create<memref::CastOp>(op->getLoc(), memrefType, alloc); |
78 | rewriter.replaceOp(op, alloc); |
79 | rewriter.create<memref::CopyOp>(loc, op.getInput(), alloc); |
80 | return success(); |
81 | } |
82 | }; |
83 | |
84 | } // namespace |
85 | |
86 | namespace { |
87 | struct BufferizationToMemRefPass |
88 | : public impl::ConvertBufferizationToMemRefBase<BufferizationToMemRefPass> { |
89 | BufferizationToMemRefPass() = default; |
90 | |
91 | void runOnOperation() override { |
92 | if (!isa<ModuleOp, FunctionOpInterface>(getOperation())) { |
93 | emitError(getOperation()->getLoc(), |
94 | "root operation must be a builtin.module or a function" ); |
95 | signalPassFailure(); |
96 | return; |
97 | } |
98 | |
99 | func::FuncOp helperFuncOp; |
100 | if (auto module = dyn_cast<ModuleOp>(getOperation())) { |
101 | OpBuilder builder = |
102 | OpBuilder::atBlockBegin(block: &module.getBodyRegion().front()); |
103 | SymbolTable symbolTable(module); |
104 | |
105 | // Build dealloc helper function if there are deallocs. |
106 | getOperation()->walk([&](bufferization::DeallocOp deallocOp) { |
107 | if (deallocOp.getMemrefs().size() > 1) { |
108 | helperFuncOp = bufferization::buildDeallocationLibraryFunction( |
109 | builder, loc: getOperation()->getLoc(), symbolTable); |
110 | return WalkResult::interrupt(); |
111 | } |
112 | return WalkResult::advance(); |
113 | }); |
114 | } |
115 | |
116 | RewritePatternSet patterns(&getContext()); |
117 | patterns.add<CloneOpConversion>(arg: patterns.getContext()); |
118 | bufferization::populateBufferizationDeallocLoweringPattern(patterns, |
119 | deallocLibraryFunc: helperFuncOp); |
120 | |
121 | ConversionTarget target(getContext()); |
122 | target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect, |
123 | scf::SCFDialect, func::FuncDialect>(); |
124 | target.addIllegalDialect<bufferization::BufferizationDialect>(); |
125 | |
126 | if (failed(applyPartialConversion(getOperation(), target, |
127 | std::move(patterns)))) |
128 | signalPassFailure(); |
129 | } |
130 | }; |
131 | } // namespace |
132 | |
133 | std::unique_ptr<Pass> mlir::createBufferizationToMemRefPass() { |
134 | return std::make_unique<BufferizationToMemRefPass>(); |
135 | } |
136 | |