1 | //===- MemRefToEmitC.cpp - MemRef to EmitC conversion ---------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements patterns to convert memref ops into emitc ops. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h" |
14 | |
15 | #include "mlir/Conversion/ConvertToEmitC/ToEmitCInterface.h" |
16 | #include "mlir/Dialect/EmitC/IR/EmitC.h" |
17 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
18 | #include "mlir/IR/Builders.h" |
19 | #include "mlir/IR/PatternMatch.h" |
20 | #include "mlir/Transforms/DialectConversion.h" |
21 | |
22 | using namespace mlir; |
23 | |
24 | namespace { |
25 | /// Implement the interface to convert MemRef to EmitC. |
26 | struct MemRefToEmitCDialectInterface : public ConvertToEmitCPatternInterface { |
27 | using ConvertToEmitCPatternInterface::ConvertToEmitCPatternInterface; |
28 | |
29 | /// Hook for derived dialect interface to provide conversion patterns |
30 | /// and mark dialect legal for the conversion target. |
31 | void populateConvertToEmitCConversionPatterns( |
32 | ConversionTarget &target, TypeConverter &typeConverter, |
33 | RewritePatternSet &patterns) const final { |
34 | populateMemRefToEmitCTypeConversion(typeConverter); |
35 | populateMemRefToEmitCConversionPatterns(patterns, converter: typeConverter); |
36 | } |
37 | }; |
38 | } // namespace |
39 | |
40 | void mlir::registerConvertMemRefToEmitCInterface(DialectRegistry ®istry) { |
41 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, memref::MemRefDialect *dialect) { |
42 | dialect->addInterfaces<MemRefToEmitCDialectInterface>(); |
43 | }); |
44 | } |
45 | |
46 | //===----------------------------------------------------------------------===// |
47 | // Conversion Patterns |
48 | //===----------------------------------------------------------------------===// |
49 | |
50 | namespace { |
51 | struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> { |
52 | using OpConversionPattern::OpConversionPattern; |
53 | |
54 | LogicalResult |
55 | matchAndRewrite(memref::AllocaOp op, OpAdaptor operands, |
56 | ConversionPatternRewriter &rewriter) const override { |
57 | |
58 | if (!op.getType().hasStaticShape()) { |
59 | return rewriter.notifyMatchFailure( |
60 | op.getLoc(), "cannot transform alloca with dynamic shape" ); |
61 | } |
62 | |
63 | if (op.getAlignment().value_or(1) > 1) { |
64 | // TODO: Allow alignment if it is not more than the natural alignment |
65 | // of the C array. |
66 | return rewriter.notifyMatchFailure( |
67 | op.getLoc(), "cannot transform alloca with alignment requirement" ); |
68 | } |
69 | |
70 | auto resultTy = getTypeConverter()->convertType(op.getType()); |
71 | if (!resultTy) { |
72 | return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type" ); |
73 | } |
74 | auto noInit = emitc::OpaqueAttr::get(getContext(), "" ); |
75 | rewriter.replaceOpWithNewOp<emitc::VariableOp>(op, resultTy, noInit); |
76 | return success(); |
77 | } |
78 | }; |
79 | |
80 | struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> { |
81 | using OpConversionPattern::OpConversionPattern; |
82 | |
83 | LogicalResult |
84 | matchAndRewrite(memref::GlobalOp op, OpAdaptor operands, |
85 | ConversionPatternRewriter &rewriter) const override { |
86 | |
87 | if (!op.getType().hasStaticShape()) { |
88 | return rewriter.notifyMatchFailure( |
89 | op.getLoc(), "cannot transform global with dynamic shape" ); |
90 | } |
91 | |
92 | if (op.getAlignment().value_or(1) > 1) { |
93 | // TODO: Extend GlobalOp to specify alignment via the `alignas` specifier. |
94 | return rewriter.notifyMatchFailure( |
95 | op.getLoc(), "global variable with alignment requirement is " |
96 | "currently not supported" ); |
97 | } |
98 | auto resultTy = getTypeConverter()->convertType(op.getType()); |
99 | if (!resultTy) { |
100 | return rewriter.notifyMatchFailure(op.getLoc(), |
101 | "cannot convert result type" ); |
102 | } |
103 | |
104 | SymbolTable::Visibility visibility = SymbolTable::getSymbolVisibility(symbol: op); |
105 | if (visibility != SymbolTable::Visibility::Public && |
106 | visibility != SymbolTable::Visibility::Private) { |
107 | return rewriter.notifyMatchFailure( |
108 | op.getLoc(), |
109 | "only public and private visibility is currently supported" ); |
110 | } |
111 | // We are explicit in specifing the linkage because the default linkage |
112 | // for constants is different in C and C++. |
113 | bool staticSpecifier = visibility == SymbolTable::Visibility::Private; |
114 | bool externSpecifier = !staticSpecifier; |
115 | |
116 | Attribute initialValue = operands.getInitialValueAttr(); |
117 | if (isa_and_present<UnitAttr>(Val: initialValue)) |
118 | initialValue = {}; |
119 | |
120 | rewriter.replaceOpWithNewOp<emitc::GlobalOp>( |
121 | op, operands.getSymName(), resultTy, initialValue, externSpecifier, |
122 | staticSpecifier, operands.getConstant()); |
123 | return success(); |
124 | } |
125 | }; |
126 | |
127 | struct ConvertGetGlobal final |
128 | : public OpConversionPattern<memref::GetGlobalOp> { |
129 | using OpConversionPattern::OpConversionPattern; |
130 | |
131 | LogicalResult |
132 | matchAndRewrite(memref::GetGlobalOp op, OpAdaptor operands, |
133 | ConversionPatternRewriter &rewriter) const override { |
134 | |
135 | auto resultTy = getTypeConverter()->convertType(op.getType()); |
136 | if (!resultTy) { |
137 | return rewriter.notifyMatchFailure(op.getLoc(), |
138 | "cannot convert result type" ); |
139 | } |
140 | rewriter.replaceOpWithNewOp<emitc::GetGlobalOp>(op, resultTy, |
141 | operands.getNameAttr()); |
142 | return success(); |
143 | } |
144 | }; |
145 | |
146 | struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> { |
147 | using OpConversionPattern::OpConversionPattern; |
148 | |
149 | LogicalResult |
150 | matchAndRewrite(memref::LoadOp op, OpAdaptor operands, |
151 | ConversionPatternRewriter &rewriter) const override { |
152 | |
153 | auto resultTy = getTypeConverter()->convertType(op.getType()); |
154 | if (!resultTy) { |
155 | return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type" ); |
156 | } |
157 | |
158 | auto arrayValue = |
159 | dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref()); |
160 | if (!arrayValue) { |
161 | return rewriter.notifyMatchFailure(op.getLoc(), "expected array type" ); |
162 | } |
163 | |
164 | auto subscript = rewriter.create<emitc::SubscriptOp>( |
165 | op.getLoc(), arrayValue, operands.getIndices()); |
166 | |
167 | rewriter.replaceOpWithNewOp<emitc::LoadOp>(op, resultTy, subscript); |
168 | return success(); |
169 | } |
170 | }; |
171 | |
172 | struct ConvertStore final : public OpConversionPattern<memref::StoreOp> { |
173 | using OpConversionPattern::OpConversionPattern; |
174 | |
175 | LogicalResult |
176 | matchAndRewrite(memref::StoreOp op, OpAdaptor operands, |
177 | ConversionPatternRewriter &rewriter) const override { |
178 | auto arrayValue = |
179 | dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref()); |
180 | if (!arrayValue) { |
181 | return rewriter.notifyMatchFailure(op.getLoc(), "expected array type" ); |
182 | } |
183 | |
184 | auto subscript = rewriter.create<emitc::SubscriptOp>( |
185 | op.getLoc(), arrayValue, operands.getIndices()); |
186 | rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript, |
187 | operands.getValue()); |
188 | return success(); |
189 | } |
190 | }; |
191 | } // namespace |
192 | |
193 | void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { |
194 | typeConverter.addConversion( |
195 | callback: [&](MemRefType memRefType) -> std::optional<Type> { |
196 | if (!memRefType.hasStaticShape() || |
197 | !memRefType.getLayout().isIdentity() || memRefType.getRank() == 0 || |
198 | llvm::is_contained(memRefType.getShape(), 0)) { |
199 | return {}; |
200 | } |
201 | Type convertedElementType = |
202 | typeConverter.convertType(memRefType.getElementType()); |
203 | if (!convertedElementType) |
204 | return {}; |
205 | return emitc::ArrayType::get(memRefType.getShape(), |
206 | convertedElementType); |
207 | }); |
208 | |
209 | auto materializeAsUnrealizedCast = [](OpBuilder &builder, Type resultType, |
210 | ValueRange inputs, |
211 | Location loc) -> Value { |
212 | if (inputs.size() != 1) |
213 | return Value(); |
214 | |
215 | return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs) |
216 | .getResult(0); |
217 | }; |
218 | |
219 | typeConverter.addSourceMaterialization(callback&: materializeAsUnrealizedCast); |
220 | typeConverter.addTargetMaterialization(callback&: materializeAsUnrealizedCast); |
221 | } |
222 | |
223 | void mlir::populateMemRefToEmitCConversionPatterns( |
224 | RewritePatternSet &patterns, const TypeConverter &converter) { |
225 | patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad, |
226 | ConvertStore>(arg: converter, args: patterns.getContext()); |
227 | } |
228 | |