| 1 | //===- MemRefToEmitC.cpp - MemRef to EmitC conversion ---------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements patterns to convert memref ops into emitc ops. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h" |
| 14 | |
| 15 | #include "mlir/Conversion/ConvertToEmitC/ToEmitCInterface.h" |
| 16 | #include "mlir/Dialect/EmitC/IR/EmitC.h" |
| 17 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 18 | #include "mlir/IR/Builders.h" |
| 19 | #include "mlir/IR/PatternMatch.h" |
| 20 | #include "mlir/Transforms/DialectConversion.h" |
| 21 | |
| 22 | using namespace mlir; |
| 23 | |
| 24 | namespace { |
| 25 | /// Implement the interface to convert MemRef to EmitC. |
| 26 | struct MemRefToEmitCDialectInterface : public ConvertToEmitCPatternInterface { |
| 27 | using ConvertToEmitCPatternInterface::ConvertToEmitCPatternInterface; |
| 28 | |
| 29 | /// Hook for derived dialect interface to provide conversion patterns |
| 30 | /// and mark dialect legal for the conversion target. |
| 31 | void populateConvertToEmitCConversionPatterns( |
| 32 | ConversionTarget &target, TypeConverter &typeConverter, |
| 33 | RewritePatternSet &patterns) const final { |
| 34 | populateMemRefToEmitCTypeConversion(typeConverter); |
| 35 | populateMemRefToEmitCConversionPatterns(patterns, converter: typeConverter); |
| 36 | } |
| 37 | }; |
| 38 | } // namespace |
| 39 | |
| 40 | void mlir::registerConvertMemRefToEmitCInterface(DialectRegistry ®istry) { |
| 41 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, memref::MemRefDialect *dialect) { |
| 42 | dialect->addInterfaces<MemRefToEmitCDialectInterface>(); |
| 43 | }); |
| 44 | } |
| 45 | |
| 46 | //===----------------------------------------------------------------------===// |
| 47 | // Conversion Patterns |
| 48 | //===----------------------------------------------------------------------===// |
| 49 | |
| 50 | namespace { |
| 51 | struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> { |
| 52 | using OpConversionPattern::OpConversionPattern; |
| 53 | |
| 54 | LogicalResult |
| 55 | matchAndRewrite(memref::AllocaOp op, OpAdaptor operands, |
| 56 | ConversionPatternRewriter &rewriter) const override { |
| 57 | |
| 58 | if (!op.getType().hasStaticShape()) { |
| 59 | return rewriter.notifyMatchFailure( |
| 60 | op.getLoc(), "cannot transform alloca with dynamic shape" ); |
| 61 | } |
| 62 | |
| 63 | if (op.getAlignment().value_or(1) > 1) { |
| 64 | // TODO: Allow alignment if it is not more than the natural alignment |
| 65 | // of the C array. |
| 66 | return rewriter.notifyMatchFailure( |
| 67 | op.getLoc(), "cannot transform alloca with alignment requirement" ); |
| 68 | } |
| 69 | |
| 70 | auto resultTy = getTypeConverter()->convertType(op.getType()); |
| 71 | if (!resultTy) { |
| 72 | return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type" ); |
| 73 | } |
| 74 | auto noInit = emitc::OpaqueAttr::get(getContext(), "" ); |
| 75 | rewriter.replaceOpWithNewOp<emitc::VariableOp>(op, resultTy, noInit); |
| 76 | return success(); |
| 77 | } |
| 78 | }; |
| 79 | |
| 80 | struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> { |
| 81 | using OpConversionPattern::OpConversionPattern; |
| 82 | |
| 83 | LogicalResult |
| 84 | matchAndRewrite(memref::GlobalOp op, OpAdaptor operands, |
| 85 | ConversionPatternRewriter &rewriter) const override { |
| 86 | |
| 87 | if (!op.getType().hasStaticShape()) { |
| 88 | return rewriter.notifyMatchFailure( |
| 89 | op.getLoc(), "cannot transform global with dynamic shape" ); |
| 90 | } |
| 91 | |
| 92 | if (op.getAlignment().value_or(1) > 1) { |
| 93 | // TODO: Extend GlobalOp to specify alignment via the `alignas` specifier. |
| 94 | return rewriter.notifyMatchFailure( |
| 95 | op.getLoc(), "global variable with alignment requirement is " |
| 96 | "currently not supported" ); |
| 97 | } |
| 98 | auto resultTy = getTypeConverter()->convertType(op.getType()); |
| 99 | if (!resultTy) { |
| 100 | return rewriter.notifyMatchFailure(op.getLoc(), |
| 101 | "cannot convert result type" ); |
| 102 | } |
| 103 | |
| 104 | SymbolTable::Visibility visibility = SymbolTable::getSymbolVisibility(symbol: op); |
| 105 | if (visibility != SymbolTable::Visibility::Public && |
| 106 | visibility != SymbolTable::Visibility::Private) { |
| 107 | return rewriter.notifyMatchFailure( |
| 108 | op.getLoc(), |
| 109 | "only public and private visibility is currently supported" ); |
| 110 | } |
| 111 | // We are explicit in specifing the linkage because the default linkage |
| 112 | // for constants is different in C and C++. |
| 113 | bool staticSpecifier = visibility == SymbolTable::Visibility::Private; |
| 114 | bool externSpecifier = !staticSpecifier; |
| 115 | |
| 116 | Attribute initialValue = operands.getInitialValueAttr(); |
| 117 | if (isa_and_present<UnitAttr>(Val: initialValue)) |
| 118 | initialValue = {}; |
| 119 | |
| 120 | rewriter.replaceOpWithNewOp<emitc::GlobalOp>( |
| 121 | op, operands.getSymName(), resultTy, initialValue, externSpecifier, |
| 122 | staticSpecifier, operands.getConstant()); |
| 123 | return success(); |
| 124 | } |
| 125 | }; |
| 126 | |
| 127 | struct ConvertGetGlobal final |
| 128 | : public OpConversionPattern<memref::GetGlobalOp> { |
| 129 | using OpConversionPattern::OpConversionPattern; |
| 130 | |
| 131 | LogicalResult |
| 132 | matchAndRewrite(memref::GetGlobalOp op, OpAdaptor operands, |
| 133 | ConversionPatternRewriter &rewriter) const override { |
| 134 | |
| 135 | auto resultTy = getTypeConverter()->convertType(op.getType()); |
| 136 | if (!resultTy) { |
| 137 | return rewriter.notifyMatchFailure(op.getLoc(), |
| 138 | "cannot convert result type" ); |
| 139 | } |
| 140 | rewriter.replaceOpWithNewOp<emitc::GetGlobalOp>(op, resultTy, |
| 141 | operands.getNameAttr()); |
| 142 | return success(); |
| 143 | } |
| 144 | }; |
| 145 | |
| 146 | struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> { |
| 147 | using OpConversionPattern::OpConversionPattern; |
| 148 | |
| 149 | LogicalResult |
| 150 | matchAndRewrite(memref::LoadOp op, OpAdaptor operands, |
| 151 | ConversionPatternRewriter &rewriter) const override { |
| 152 | |
| 153 | auto resultTy = getTypeConverter()->convertType(op.getType()); |
| 154 | if (!resultTy) { |
| 155 | return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type" ); |
| 156 | } |
| 157 | |
| 158 | auto arrayValue = |
| 159 | dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref()); |
| 160 | if (!arrayValue) { |
| 161 | return rewriter.notifyMatchFailure(op.getLoc(), "expected array type" ); |
| 162 | } |
| 163 | |
| 164 | auto subscript = rewriter.create<emitc::SubscriptOp>( |
| 165 | op.getLoc(), arrayValue, operands.getIndices()); |
| 166 | |
| 167 | rewriter.replaceOpWithNewOp<emitc::LoadOp>(op, resultTy, subscript); |
| 168 | return success(); |
| 169 | } |
| 170 | }; |
| 171 | |
| 172 | struct ConvertStore final : public OpConversionPattern<memref::StoreOp> { |
| 173 | using OpConversionPattern::OpConversionPattern; |
| 174 | |
| 175 | LogicalResult |
| 176 | matchAndRewrite(memref::StoreOp op, OpAdaptor operands, |
| 177 | ConversionPatternRewriter &rewriter) const override { |
| 178 | auto arrayValue = |
| 179 | dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref()); |
| 180 | if (!arrayValue) { |
| 181 | return rewriter.notifyMatchFailure(op.getLoc(), "expected array type" ); |
| 182 | } |
| 183 | |
| 184 | auto subscript = rewriter.create<emitc::SubscriptOp>( |
| 185 | op.getLoc(), arrayValue, operands.getIndices()); |
| 186 | rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript, |
| 187 | operands.getValue()); |
| 188 | return success(); |
| 189 | } |
| 190 | }; |
| 191 | } // namespace |
| 192 | |
| 193 | void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { |
| 194 | typeConverter.addConversion( |
| 195 | callback: [&](MemRefType memRefType) -> std::optional<Type> { |
| 196 | if (!memRefType.hasStaticShape() || |
| 197 | !memRefType.getLayout().isIdentity() || memRefType.getRank() == 0 || |
| 198 | llvm::is_contained(memRefType.getShape(), 0)) { |
| 199 | return {}; |
| 200 | } |
| 201 | Type convertedElementType = |
| 202 | typeConverter.convertType(memRefType.getElementType()); |
| 203 | if (!convertedElementType) |
| 204 | return {}; |
| 205 | return emitc::ArrayType::get(memRefType.getShape(), |
| 206 | convertedElementType); |
| 207 | }); |
| 208 | |
| 209 | auto materializeAsUnrealizedCast = [](OpBuilder &builder, Type resultType, |
| 210 | ValueRange inputs, |
| 211 | Location loc) -> Value { |
| 212 | if (inputs.size() != 1) |
| 213 | return Value(); |
| 214 | |
| 215 | return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs) |
| 216 | .getResult(0); |
| 217 | }; |
| 218 | |
| 219 | typeConverter.addSourceMaterialization(callback&: materializeAsUnrealizedCast); |
| 220 | typeConverter.addTargetMaterialization(callback&: materializeAsUnrealizedCast); |
| 221 | } |
| 222 | |
| 223 | void mlir::populateMemRefToEmitCConversionPatterns( |
| 224 | RewritePatternSet &patterns, const TypeConverter &converter) { |
| 225 | patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad, |
| 226 | ConvertStore>(arg: converter, args: patterns.getContext()); |
| 227 | } |
| 228 | |