1 | //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h" |
10 | |
11 | #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" |
12 | #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" |
13 | #include "mlir/Dialect/MLProgram/IR/MLProgram.h" |
14 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
15 | |
16 | using namespace mlir; |
17 | using namespace mlir::bufferization; |
18 | using namespace mlir::ml_program; |
19 | |
20 | namespace mlir { |
21 | namespace ml_program { |
22 | namespace { |
23 | |
24 | template <typename Interface, typename Op> |
25 | struct ExternalModelBase |
26 | : public BufferizableOpInterface::ExternalModel<Interface, Op> { |
27 | |
28 | AliasingValueList getAliasingValues(Operation *, OpOperand &, |
29 | const AnalysisState &) const { |
30 | return {}; |
31 | } |
32 | |
33 | BufferRelation bufferRelation(Operation *, OpResult, |
34 | const AnalysisState &) const { |
35 | return BufferRelation::Unknown; |
36 | } |
37 | }; |
38 | |
39 | /// Bufferization of ml_program.global into a memref.global |
40 | struct GlobalOpInterface |
41 | : public ExternalModelBase<GlobalOpInterface, GlobalOp> { |
42 | |
43 | bool bufferizesToMemoryRead(Operation *, OpOperand &, |
44 | const AnalysisState &) const { |
45 | return false; |
46 | } |
47 | |
48 | bool bufferizesToMemoryWrite(Operation *, OpOperand &, |
49 | const AnalysisState &) const { |
50 | return false; |
51 | } |
52 | |
53 | bool hasTensorSemantics(Operation *) const { return true; } |
54 | |
55 | LogicalResult bufferize(Operation *op, RewriterBase &rewriter, |
56 | const BufferizationOptions &, |
57 | BufferizationState &state) const { |
58 | auto globalOp = cast<GlobalOp>(op); |
59 | if (!globalOp.getValue().has_value()) |
60 | return globalOp.emitError("global op must have a value" ); |
61 | |
62 | bufferization::removeSymbol(op: globalOp, state); |
63 | |
64 | auto tensorType = cast<TensorType>(globalOp.getType()); |
65 | auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType); |
66 | |
67 | auto replacement = replaceOpWithNewBufferizedOp<memref::GlobalOp>( |
68 | rewriter, globalOp, globalOp.getSymName(), |
69 | /*sym_visibility=*/globalOp.getSymVisibilityAttr(), |
70 | /*type=*/cast<MemRefType>(memrefType), |
71 | /*initial_value=*/globalOp.getValue().value(), |
72 | /*constant=*/!globalOp.getIsMutable(), |
73 | /*alignment=*/nullptr); |
74 | |
75 | bufferization::insertSymbol(op: replacement, state); |
76 | return success(); |
77 | } |
78 | }; |
79 | |
80 | /// Bufferization of ml_program.global_load into a memref.get_global |
81 | struct GlobalLoadOpInterface |
82 | : public ExternalModelBase<GlobalLoadOpInterface, GlobalLoadOp> { |
83 | |
84 | bool bufferizesToMemoryRead(Operation *, OpOperand &, |
85 | const AnalysisState &) const { |
86 | return false; |
87 | } |
88 | |
89 | bool bufferizesToMemoryWrite(Operation *, OpOperand &, |
90 | const AnalysisState &) const { |
91 | return false; |
92 | } |
93 | |
94 | bool isWritable(Operation *, Value, const AnalysisState &) const { |
95 | return false; |
96 | } |
97 | |
98 | LogicalResult bufferize(Operation *op, RewriterBase &rewriter, |
99 | const BufferizationOptions &, |
100 | BufferizationState &state) const { |
101 | auto globalLoadOp = cast<GlobalLoadOp>(op); |
102 | |
103 | auto tensorType = cast<TensorType>(globalLoadOp.getType()); |
104 | auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType); |
105 | |
106 | replaceOpWithNewBufferizedOp<memref::GetGlobalOp>( |
107 | rewriter, globalLoadOp, memrefType, |
108 | globalLoadOp.getGlobalAttr().getLeafReference()); |
109 | |
110 | return success(); |
111 | } |
112 | }; |
113 | |
114 | /// Bufferization of ml_program.global_store into a memref.get_global and |
115 | /// memcpy |
116 | struct GlobalStoreOpInterface |
117 | : public ExternalModelBase<GlobalStoreOpInterface, GlobalStoreOp> { |
118 | |
119 | bool bufferizesToMemoryRead(Operation *, OpOperand &, |
120 | const AnalysisState &) const { |
121 | return false; |
122 | } |
123 | |
124 | bool bufferizesToMemoryWrite(Operation *, OpOperand &, |
125 | const AnalysisState &) const { |
126 | return true; |
127 | } |
128 | |
129 | LogicalResult bufferize(Operation *op, RewriterBase &rewriter, |
130 | const BufferizationOptions &options, |
131 | BufferizationState &state) const { |
132 | auto globalStoreOp = cast<GlobalStoreOp>(op); |
133 | |
134 | auto tensorType = cast<TensorType>(globalStoreOp.getValue().getType()); |
135 | auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType); |
136 | |
137 | auto loc = globalStoreOp.getLoc(); |
138 | auto targetMemref = rewriter.create<memref::GetGlobalOp>( |
139 | loc, memrefType, globalStoreOp.getGlobalAttr().getLeafReference()); |
140 | |
141 | auto sourceMemref = |
142 | getBuffer(rewriter, globalStoreOp.getValue(), options, state); |
143 | if (failed(sourceMemref)) { |
144 | return failure(); |
145 | } |
146 | |
147 | auto memcpy = |
148 | options.createMemCpy(b&: rewriter, loc: loc, from: sourceMemref.value(), to: targetMemref); |
149 | if (failed(memcpy)) { |
150 | return failure(); |
151 | } |
152 | rewriter.eraseOp(op: globalStoreOp); |
153 | |
154 | return success(); |
155 | } |
156 | }; |
157 | } // namespace |
158 | |
159 | void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) { |
160 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, MLProgramDialect *) { |
161 | GlobalOp::attachInterface<GlobalOpInterface>(*ctx); |
162 | GlobalLoadOp::attachInterface<GlobalLoadOpInterface>(*ctx); |
163 | GlobalStoreOp::attachInterface<GlobalStoreOpInterface>(*ctx); |
164 | }); |
165 | } |
166 | } // namespace ml_program |
167 | } // namespace mlir |
168 | |