1 | //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h" |
10 | |
11 | #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" |
12 | #include "mlir/Dialect/MLProgram/IR/MLProgram.h" |
13 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
14 | |
15 | using namespace mlir; |
16 | using namespace mlir::bufferization; |
17 | using namespace mlir::ml_program; |
18 | |
19 | namespace mlir { |
20 | namespace ml_program { |
21 | namespace { |
22 | |
23 | template <typename Interface, typename Op> |
24 | struct ExternalModelBase |
25 | : public BufferizableOpInterface::ExternalModel<Interface, Op> { |
26 | |
27 | AliasingValueList getAliasingValues(Operation *, OpOperand &, |
28 | const AnalysisState &) const { |
29 | return {}; |
30 | } |
31 | |
32 | BufferRelation bufferRelation(Operation *, OpResult, |
33 | const AnalysisState &) const { |
34 | return BufferRelation::Unknown; |
35 | } |
36 | }; |
37 | |
38 | /// Bufferization of ml_program.global into a memref.global |
39 | struct GlobalOpInterface |
40 | : public ExternalModelBase<GlobalOpInterface, GlobalOp> { |
41 | |
42 | bool bufferizesToMemoryRead(Operation *, OpOperand &, |
43 | const AnalysisState &) const { |
44 | return false; |
45 | } |
46 | |
47 | bool bufferizesToMemoryWrite(Operation *, OpOperand &, |
48 | const AnalysisState &) const { |
49 | return false; |
50 | } |
51 | |
52 | bool hasTensorSemantics(Operation *) const { return true; } |
53 | |
54 | LogicalResult bufferize(Operation *op, RewriterBase &rewriter, |
55 | const BufferizationOptions &) const { |
56 | auto globalOp = cast<GlobalOp>(op); |
57 | if (!globalOp.getValue().has_value()) |
58 | return globalOp.emitError("global op must have a value" ); |
59 | |
60 | auto tensorType = cast<TensorType>(globalOp.getType()); |
61 | auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType); |
62 | |
63 | replaceOpWithNewBufferizedOp<memref::GlobalOp>( |
64 | rewriter, globalOp, globalOp.getSymName(), |
65 | /*sym_visibility=*/globalOp.getSymVisibilityAttr(), |
66 | /*type=*/cast<MemRefType>(memrefType), |
67 | /*initial_value=*/globalOp.getValue().value(), |
68 | /*constant=*/!globalOp.getIsMutable(), |
69 | /*alignment=*/nullptr); |
70 | |
71 | return success(); |
72 | } |
73 | }; |
74 | |
75 | /// Bufferization of ml_program.global_load into a memref.get_global |
76 | struct GlobalLoadOpInterface |
77 | : public ExternalModelBase<GlobalLoadOpInterface, GlobalLoadOp> { |
78 | |
79 | bool bufferizesToMemoryRead(Operation *, OpOperand &, |
80 | const AnalysisState &) const { |
81 | return false; |
82 | } |
83 | |
84 | bool bufferizesToMemoryWrite(Operation *, OpOperand &, |
85 | const AnalysisState &) const { |
86 | return false; |
87 | } |
88 | |
89 | bool isWritable(Operation *, Value, const AnalysisState &) const { |
90 | return false; |
91 | } |
92 | |
93 | LogicalResult bufferize(Operation *op, RewriterBase &rewriter, |
94 | const BufferizationOptions &) const { |
95 | auto globalLoadOp = cast<GlobalLoadOp>(op); |
96 | |
97 | auto tensorType = cast<TensorType>(globalLoadOp.getType()); |
98 | auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType); |
99 | |
100 | replaceOpWithNewBufferizedOp<memref::GetGlobalOp>( |
101 | rewriter, globalLoadOp, memrefType, |
102 | globalLoadOp.getGlobalAttr().getLeafReference()); |
103 | |
104 | return success(); |
105 | } |
106 | }; |
107 | |
108 | /// Bufferization of ml_program.global_store into a memref.get_global and |
109 | /// memcpy |
110 | struct GlobalStoreOpInterface |
111 | : public ExternalModelBase<GlobalStoreOpInterface, GlobalStoreOp> { |
112 | |
113 | bool bufferizesToMemoryRead(Operation *, OpOperand &, |
114 | const AnalysisState &) const { |
115 | return false; |
116 | } |
117 | |
118 | bool bufferizesToMemoryWrite(Operation *, OpOperand &, |
119 | const AnalysisState &) const { |
120 | return true; |
121 | } |
122 | |
123 | LogicalResult bufferize(Operation *op, RewriterBase &rewriter, |
124 | const BufferizationOptions &options) const { |
125 | auto globalStoreOp = cast<GlobalStoreOp>(op); |
126 | |
127 | auto tensorType = cast<TensorType>(globalStoreOp.getValue().getType()); |
128 | auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType); |
129 | |
130 | auto loc = globalStoreOp.getLoc(); |
131 | auto targetMemref = rewriter.create<memref::GetGlobalOp>( |
132 | loc, memrefType, globalStoreOp.getGlobalAttr().getLeafReference()); |
133 | |
134 | auto sourceMemref = getBuffer(rewriter, globalStoreOp.getValue(), options); |
135 | if (failed(sourceMemref)) { |
136 | return failure(); |
137 | } |
138 | |
139 | auto memcpy = |
140 | options.createMemCpy(b&: rewriter, loc: loc, from: sourceMemref.value(), to: targetMemref); |
141 | if (failed(memcpy)) { |
142 | return failure(); |
143 | } |
144 | rewriter.eraseOp(op: globalStoreOp); |
145 | |
146 | return success(); |
147 | } |
148 | }; |
149 | } // namespace |
150 | |
151 | void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) { |
152 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, MLProgramDialect *) { |
153 | GlobalOp::attachInterface<GlobalOpInterface>(*ctx); |
154 | GlobalLoadOp::attachInterface<GlobalLoadOpInterface>(*ctx); |
155 | GlobalStoreOp::attachInterface<GlobalStoreOpInterface>(*ctx); |
156 | }); |
157 | } |
158 | } // namespace ml_program |
159 | } // namespace mlir |
160 | |