1 | //===- TypeFromLLVM.cpp - type translation from LLVM to MLIR IR -===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Target/LLVMIR/TypeFromLLVM.h" |
10 | #include "mlir/Dialect/LLVMIR/LLVMTypes.h" |
11 | #include "mlir/IR/BuiltinTypes.h" |
12 | #include "mlir/IR/MLIRContext.h" |
13 | |
14 | #include "llvm/ADT/TypeSwitch.h" |
15 | #include "llvm/IR/DataLayout.h" |
16 | #include "llvm/IR/DerivedTypes.h" |
17 | #include "llvm/IR/Type.h" |
18 | |
19 | using namespace mlir; |
20 | |
21 | namespace mlir { |
22 | namespace LLVM { |
23 | namespace detail { |
24 | /// Support for translating LLVM IR types to MLIR LLVM dialect types. |
25 | class TypeFromLLVMIRTranslatorImpl { |
26 | public: |
27 | /// Constructs a class creating types in the given MLIR context. |
28 | TypeFromLLVMIRTranslatorImpl(MLIRContext &context) : context(context) {} |
29 | |
30 | /// Translates the given type. |
31 | Type translateType(llvm::Type *type) { |
32 | if (knownTranslations.count(Val: type)) |
33 | return knownTranslations.lookup(Val: type); |
34 | |
35 | Type translated = |
36 | llvm::TypeSwitch<llvm::Type *, Type>(type) |
37 | .Case<llvm::ArrayType, llvm::FunctionType, llvm::IntegerType, |
38 | llvm::PointerType, llvm::StructType, llvm::FixedVectorType, |
39 | llvm::ScalableVectorType, llvm::TargetExtType>( |
40 | caseFn: [this](auto *type) { return this->translate(type); }) |
41 | .Default(defaultFn: [this](llvm::Type *type) { |
42 | return translatePrimitiveType(type); |
43 | }); |
44 | knownTranslations.try_emplace(Key: type, Args&: translated); |
45 | return translated; |
46 | } |
47 | |
48 | private: |
49 | /// Translates the given primitive, i.e. non-parametric in MLIR nomenclature, |
50 | /// type. |
51 | Type translatePrimitiveType(llvm::Type *type) { |
52 | if (type->isVoidTy()) |
53 | return LLVM::LLVMVoidType::get(ctx: &context); |
54 | if (type->isHalfTy()) |
55 | return Float16Type::get(&context); |
56 | if (type->isBFloatTy()) |
57 | return BFloat16Type::get(&context); |
58 | if (type->isFloatTy()) |
59 | return Float32Type::get(&context); |
60 | if (type->isDoubleTy()) |
61 | return Float64Type::get(&context); |
62 | if (type->isFP128Ty()) |
63 | return Float128Type::get(&context); |
64 | if (type->isX86_FP80Ty()) |
65 | return Float80Type::get(&context); |
66 | if (type->isPPC_FP128Ty()) |
67 | return LLVM::LLVMPPCFP128Type::get(ctx: &context); |
68 | if (type->isX86_MMXTy()) |
69 | return LLVM::LLVMX86MMXType::get(ctx: &context); |
70 | if (type->isLabelTy()) |
71 | return LLVM::LLVMLabelType::get(ctx: &context); |
72 | if (type->isMetadataTy()) |
73 | return LLVM::LLVMMetadataType::get(ctx: &context); |
74 | if (type->isTokenTy()) |
75 | return LLVM::LLVMTokenType::get(ctx: &context); |
76 | llvm_unreachable("not a primitive type" ); |
77 | } |
78 | |
79 | /// Translates the given array type. |
80 | Type translate(llvm::ArrayType *type) { |
81 | return LLVM::LLVMArrayType::get(translateType(type->getElementType()), |
82 | type->getNumElements()); |
83 | } |
84 | |
85 | /// Translates the given function type. |
86 | Type translate(llvm::FunctionType *type) { |
87 | SmallVector<Type, 8> paramTypes; |
88 | translateTypes(types: type->params(), result&: paramTypes); |
89 | return LLVM::LLVMFunctionType::get(translateType(type->getReturnType()), |
90 | paramTypes, type->isVarArg()); |
91 | } |
92 | |
93 | /// Translates the given integer type. |
94 | Type translate(llvm::IntegerType *type) { |
95 | return IntegerType::get(&context, type->getBitWidth()); |
96 | } |
97 | |
98 | /// Translates the given pointer type. |
99 | Type translate(llvm::PointerType *type) { |
100 | return LLVM::LLVMPointerType::get(&context, type->getAddressSpace()); |
101 | } |
102 | |
103 | /// Translates the given structure type. |
104 | Type translate(llvm::StructType *type) { |
105 | SmallVector<Type, 8> subtypes; |
106 | if (type->isLiteral()) { |
107 | translateTypes(types: type->subtypes(), result&: subtypes); |
108 | return LLVM::LLVMStructType::getLiteral(context: &context, types: subtypes, |
109 | isPacked: type->isPacked()); |
110 | } |
111 | |
112 | if (type->isOpaque()) |
113 | return LLVM::LLVMStructType::getOpaque(name: type->getName(), context: &context); |
114 | |
115 | // With opaque pointers, types in LLVM can't be recursive anymore. Note that |
116 | // using getIdentified is not possible, as type names in LLVM are not |
117 | // guaranteed to be unique. |
118 | translateTypes(types: type->subtypes(), result&: subtypes); |
119 | LLVM::LLVMStructType translated = LLVM::LLVMStructType::getNewIdentified( |
120 | context: &context, name: type->getName(), elements: subtypes, isPacked: type->isPacked()); |
121 | knownTranslations.try_emplace(Key: type, Args&: translated); |
122 | return translated; |
123 | } |
124 | |
125 | /// Translates the given fixed-vector type. |
126 | Type translate(llvm::FixedVectorType *type) { |
127 | return LLVM::getFixedVectorType(elementType: translateType(type: type->getElementType()), |
128 | numElements: type->getNumElements()); |
129 | } |
130 | |
131 | /// Translates the given scalable-vector type. |
132 | Type translate(llvm::ScalableVectorType *type) { |
133 | return LLVM::LLVMScalableVectorType::get( |
134 | translateType(type->getElementType()), type->getMinNumElements()); |
135 | } |
136 | |
137 | /// Translates the given target extension type. |
138 | Type translate(llvm::TargetExtType *type) { |
139 | SmallVector<Type> typeParams; |
140 | translateTypes(types: type->type_params(), result&: typeParams); |
141 | |
142 | return LLVM::LLVMTargetExtType::get(&context, type->getName(), typeParams, |
143 | type->int_params()); |
144 | } |
145 | |
146 | /// Translates a list of types. |
147 | void translateTypes(ArrayRef<llvm::Type *> types, |
148 | SmallVectorImpl<Type> &result) { |
149 | result.reserve(N: result.size() + types.size()); |
150 | for (llvm::Type *type : types) |
151 | result.push_back(Elt: translateType(type)); |
152 | } |
153 | |
154 | /// Map of known translations. Serves as a cache and as recursion stopper for |
155 | /// translating recursive structs. |
156 | llvm::DenseMap<llvm::Type *, Type> knownTranslations; |
157 | |
158 | /// The context in which MLIR types are created. |
159 | MLIRContext &context; |
160 | }; |
161 | |
162 | } // namespace detail |
163 | } // namespace LLVM |
164 | } // namespace mlir |
165 | |
/// Constructs the translator by allocating its implementation object, which
/// creates MLIR types in `context`.
// NOTE(review): `impl` is declared in the header; presumably an owning smart
// pointer that releases this allocation - confirm against TypeFromLLVM.h.
LLVM::TypeFromLLVMIRTranslator::TypeFromLLVMIRTranslator(MLIRContext &context)
    : impl(new detail::TypeFromLLVMIRTranslatorImpl(context)) {}
168 | |
/// Defaulted out of line in this file, where the definition of
/// detail::TypeFromLLVMIRTranslatorImpl is visible.
// NOTE(review): presumably required because `impl` owns an object whose type
// is only forward-declared in the header - confirm against TypeFromLLVM.h.
LLVM::TypeFromLLVMIRTranslator::~TypeFromLLVMIRTranslator() = default;
170 | |
171 | Type LLVM::TypeFromLLVMIRTranslator::translateType(llvm::Type *type) { |
172 | return impl->translateType(type); |
173 | } |
174 | |