1//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h"
10
11#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
13#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
14#include "mlir/Dialect/MemRef/IR/MemRef.h"
15
16using namespace mlir;
17using namespace mlir::bufferization;
18using namespace mlir::ml_program;
19
20namespace mlir {
21namespace ml_program {
22namespace {
23
24template <typename Interface, typename Op>
25struct ExternalModelBase
26 : public BufferizableOpInterface::ExternalModel<Interface, Op> {
27
28 AliasingValueList getAliasingValues(Operation *, OpOperand &,
29 const AnalysisState &) const {
30 return {};
31 }
32
33 BufferRelation bufferRelation(Operation *, OpResult,
34 const AnalysisState &) const {
35 return BufferRelation::Unknown;
36 }
37};
38
39/// Bufferization of ml_program.global into a memref.global
40struct GlobalOpInterface
41 : public ExternalModelBase<GlobalOpInterface, GlobalOp> {
42
43 bool bufferizesToMemoryRead(Operation *, OpOperand &,
44 const AnalysisState &) const {
45 return false;
46 }
47
48 bool bufferizesToMemoryWrite(Operation *, OpOperand &,
49 const AnalysisState &) const {
50 return false;
51 }
52
53 bool hasTensorSemantics(Operation *) const { return true; }
54
55 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
56 const BufferizationOptions &,
57 BufferizationState &state) const {
58 auto globalOp = cast<GlobalOp>(op);
59 if (!globalOp.getValue().has_value())
60 return globalOp.emitError("global op must have a value");
61
62 bufferization::removeSymbol(op: globalOp, state);
63
64 auto tensorType = cast<TensorType>(globalOp.getType());
65 auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
66
67 auto replacement = replaceOpWithNewBufferizedOp<memref::GlobalOp>(
68 rewriter, globalOp, globalOp.getSymName(),
69 /*sym_visibility=*/globalOp.getSymVisibilityAttr(),
70 /*type=*/cast<MemRefType>(memrefType),
71 /*initial_value=*/globalOp.getValue().value(),
72 /*constant=*/!globalOp.getIsMutable(),
73 /*alignment=*/nullptr);
74
75 bufferization::insertSymbol(op: replacement, state);
76 return success();
77 }
78};
79
80/// Bufferization of ml_program.global_load into a memref.get_global
81struct GlobalLoadOpInterface
82 : public ExternalModelBase<GlobalLoadOpInterface, GlobalLoadOp> {
83
84 bool bufferizesToMemoryRead(Operation *, OpOperand &,
85 const AnalysisState &) const {
86 return false;
87 }
88
89 bool bufferizesToMemoryWrite(Operation *, OpOperand &,
90 const AnalysisState &) const {
91 return false;
92 }
93
94 bool isWritable(Operation *, Value, const AnalysisState &) const {
95 return false;
96 }
97
98 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
99 const BufferizationOptions &,
100 BufferizationState &state) const {
101 auto globalLoadOp = cast<GlobalLoadOp>(op);
102
103 auto tensorType = cast<TensorType>(globalLoadOp.getType());
104 auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
105
106 replaceOpWithNewBufferizedOp<memref::GetGlobalOp>(
107 rewriter, globalLoadOp, memrefType,
108 globalLoadOp.getGlobalAttr().getLeafReference());
109
110 return success();
111 }
112};
113
114/// Bufferization of ml_program.global_store into a memref.get_global and
115/// memcpy
116struct GlobalStoreOpInterface
117 : public ExternalModelBase<GlobalStoreOpInterface, GlobalStoreOp> {
118
119 bool bufferizesToMemoryRead(Operation *, OpOperand &,
120 const AnalysisState &) const {
121 return false;
122 }
123
124 bool bufferizesToMemoryWrite(Operation *, OpOperand &,
125 const AnalysisState &) const {
126 return true;
127 }
128
129 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
130 const BufferizationOptions &options,
131 BufferizationState &state) const {
132 auto globalStoreOp = cast<GlobalStoreOp>(op);
133
134 auto tensorType = cast<TensorType>(globalStoreOp.getValue().getType());
135 auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
136
137 auto loc = globalStoreOp.getLoc();
138 auto targetMemref = rewriter.create<memref::GetGlobalOp>(
139 loc, memrefType, globalStoreOp.getGlobalAttr().getLeafReference());
140
141 auto sourceMemref =
142 getBuffer(rewriter, globalStoreOp.getValue(), options, state);
143 if (failed(sourceMemref)) {
144 return failure();
145 }
146
147 auto memcpy =
148 options.createMemCpy(b&: rewriter, loc: loc, from: sourceMemref.value(), to: targetMemref);
149 if (failed(memcpy)) {
150 return failure();
151 }
152 rewriter.eraseOp(op: globalStoreOp);
153
154 return success();
155 }
156};
157} // namespace
158
159void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
160 registry.addExtension(extensionFn: +[](MLIRContext *ctx, MLProgramDialect *) {
161 GlobalOp::attachInterface<GlobalOpInterface>(*ctx);
162 GlobalLoadOp::attachInterface<GlobalLoadOpInterface>(*ctx);
163 GlobalStoreOp::attachInterface<GlobalStoreOpInterface>(*ctx);
164 });
165}
166} // namespace ml_program
167} // namespace mlir
168

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp