1 | //===- MemRefToEmitC.cpp - MemRef to EmitC conversion ---------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements patterns to convert memref ops into emitc ops. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h" |
14 | |
15 | #include "mlir/Dialect/EmitC/IR/EmitC.h" |
16 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
17 | #include "mlir/IR/Builders.h" |
18 | #include "mlir/IR/PatternMatch.h" |
19 | #include "mlir/Transforms/DialectConversion.h" |
20 | |
21 | using namespace mlir; |
22 | |
23 | namespace { |
24 | struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> { |
25 | using OpConversionPattern::OpConversionPattern; |
26 | |
27 | LogicalResult |
28 | matchAndRewrite(memref::AllocaOp op, OpAdaptor operands, |
29 | ConversionPatternRewriter &rewriter) const override { |
30 | |
31 | if (!op.getType().hasStaticShape()) { |
32 | return rewriter.notifyMatchFailure( |
33 | op.getLoc(), "cannot transform alloca with dynamic shape" ); |
34 | } |
35 | |
36 | if (op.getAlignment().value_or(1) > 1) { |
37 | // TODO: Allow alignment if it is not more than the natural alignment |
38 | // of the C array. |
39 | return rewriter.notifyMatchFailure( |
40 | op.getLoc(), "cannot transform alloca with alignment requirement" ); |
41 | } |
42 | |
43 | auto resultTy = getTypeConverter()->convertType(op.getType()); |
44 | if (!resultTy) { |
45 | return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type" ); |
46 | } |
47 | auto noInit = emitc::OpaqueAttr::get(getContext(), "" ); |
48 | rewriter.replaceOpWithNewOp<emitc::VariableOp>(op, resultTy, noInit); |
49 | return success(); |
50 | } |
51 | }; |
52 | |
53 | struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> { |
54 | using OpConversionPattern::OpConversionPattern; |
55 | |
56 | LogicalResult |
57 | matchAndRewrite(memref::GlobalOp op, OpAdaptor operands, |
58 | ConversionPatternRewriter &rewriter) const override { |
59 | |
60 | if (!op.getType().hasStaticShape()) { |
61 | return rewriter.notifyMatchFailure( |
62 | op.getLoc(), "cannot transform global with dynamic shape" ); |
63 | } |
64 | |
65 | if (op.getAlignment().value_or(1) > 1) { |
66 | // TODO: Extend GlobalOp to specify alignment via the `alignas` specifier. |
67 | return rewriter.notifyMatchFailure( |
68 | op.getLoc(), "global variable with alignment requirement is " |
69 | "currently not supported" ); |
70 | } |
71 | auto resultTy = getTypeConverter()->convertType(op.getType()); |
72 | if (!resultTy) { |
73 | return rewriter.notifyMatchFailure(op.getLoc(), |
74 | "cannot convert result type" ); |
75 | } |
76 | |
77 | SymbolTable::Visibility visibility = SymbolTable::getSymbolVisibility(symbol: op); |
78 | if (visibility != SymbolTable::Visibility::Public && |
79 | visibility != SymbolTable::Visibility::Private) { |
80 | return rewriter.notifyMatchFailure( |
81 | op.getLoc(), |
82 | "only public and private visibility is currently supported" ); |
83 | } |
84 | // We are explicit in specifing the linkage because the default linkage |
85 | // for constants is different in C and C++. |
86 | bool staticSpecifier = visibility == SymbolTable::Visibility::Private; |
87 | bool externSpecifier = !staticSpecifier; |
88 | |
89 | Attribute initialValue = operands.getInitialValueAttr(); |
90 | if (isa_and_present<UnitAttr>(Val: initialValue)) |
91 | initialValue = {}; |
92 | |
93 | rewriter.replaceOpWithNewOp<emitc::GlobalOp>( |
94 | op, operands.getSymName(), resultTy, initialValue, externSpecifier, |
95 | staticSpecifier, operands.getConstant()); |
96 | return success(); |
97 | } |
98 | }; |
99 | |
100 | struct ConvertGetGlobal final |
101 | : public OpConversionPattern<memref::GetGlobalOp> { |
102 | using OpConversionPattern::OpConversionPattern; |
103 | |
104 | LogicalResult |
105 | matchAndRewrite(memref::GetGlobalOp op, OpAdaptor operands, |
106 | ConversionPatternRewriter &rewriter) const override { |
107 | |
108 | auto resultTy = getTypeConverter()->convertType(op.getType()); |
109 | if (!resultTy) { |
110 | return rewriter.notifyMatchFailure(op.getLoc(), |
111 | "cannot convert result type" ); |
112 | } |
113 | rewriter.replaceOpWithNewOp<emitc::GetGlobalOp>(op, resultTy, |
114 | operands.getNameAttr()); |
115 | return success(); |
116 | } |
117 | }; |
118 | |
119 | struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> { |
120 | using OpConversionPattern::OpConversionPattern; |
121 | |
122 | LogicalResult |
123 | matchAndRewrite(memref::LoadOp op, OpAdaptor operands, |
124 | ConversionPatternRewriter &rewriter) const override { |
125 | |
126 | auto resultTy = getTypeConverter()->convertType(op.getType()); |
127 | if (!resultTy) { |
128 | return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type" ); |
129 | } |
130 | |
131 | auto arrayValue = |
132 | dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref()); |
133 | if (!arrayValue) { |
134 | return rewriter.notifyMatchFailure(op.getLoc(), "expected array type" ); |
135 | } |
136 | |
137 | auto subscript = rewriter.create<emitc::SubscriptOp>( |
138 | op.getLoc(), arrayValue, operands.getIndices()); |
139 | |
140 | auto noInit = emitc::OpaqueAttr::get(getContext(), "" ); |
141 | auto var = |
142 | rewriter.create<emitc::VariableOp>(op.getLoc(), resultTy, noInit); |
143 | |
144 | rewriter.create<emitc::AssignOp>(op.getLoc(), var, subscript); |
145 | rewriter.replaceOp(op, var); |
146 | return success(); |
147 | } |
148 | }; |
149 | |
150 | struct ConvertStore final : public OpConversionPattern<memref::StoreOp> { |
151 | using OpConversionPattern::OpConversionPattern; |
152 | |
153 | LogicalResult |
154 | matchAndRewrite(memref::StoreOp op, OpAdaptor operands, |
155 | ConversionPatternRewriter &rewriter) const override { |
156 | auto arrayValue = |
157 | dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref()); |
158 | if (!arrayValue) { |
159 | return rewriter.notifyMatchFailure(op.getLoc(), "expected array type" ); |
160 | } |
161 | |
162 | auto subscript = rewriter.create<emitc::SubscriptOp>( |
163 | op.getLoc(), arrayValue, operands.getIndices()); |
164 | rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript, |
165 | operands.getValue()); |
166 | return success(); |
167 | } |
168 | }; |
169 | } // namespace |
170 | |
171 | void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { |
172 | typeConverter.addConversion( |
173 | callback: [&](MemRefType memRefType) -> std::optional<Type> { |
174 | if (!memRefType.hasStaticShape() || |
175 | !memRefType.getLayout().isIdentity() || memRefType.getRank() == 0) { |
176 | return {}; |
177 | } |
178 | Type convertedElementType = |
179 | typeConverter.convertType(memRefType.getElementType()); |
180 | if (!convertedElementType) |
181 | return {}; |
182 | return emitc::ArrayType::get(memRefType.getShape(), |
183 | convertedElementType); |
184 | }); |
185 | } |
186 | |
187 | void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns, |
188 | TypeConverter &converter) { |
189 | patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad, |
190 | ConvertStore>(arg&: converter, args: patterns.getContext()); |
191 | } |
192 | |