1 | //===- TypeFromLLVM.cpp - type translation from LLVM to MLIR IR -===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Target/LLVMIR/TypeFromLLVM.h" |
10 | #include "mlir/Dialect/LLVMIR/LLVMTypes.h" |
11 | #include "mlir/IR/BuiltinTypes.h" |
12 | #include "mlir/IR/MLIRContext.h" |
13 | |
14 | #include "llvm/ADT/TypeSwitch.h" |
15 | #include "llvm/IR/DerivedTypes.h" |
16 | #include "llvm/IR/Type.h" |
17 | |
18 | using namespace mlir; |
19 | |
20 | namespace mlir { |
21 | namespace LLVM { |
22 | namespace detail { |
23 | /// Support for translating LLVM IR types to MLIR LLVM dialect types. |
24 | class TypeFromLLVMIRTranslatorImpl { |
25 | public: |
26 | /// Constructs a class creating types in the given MLIR context. |
27 | TypeFromLLVMIRTranslatorImpl(MLIRContext &context, |
28 | bool importStructsAsLiterals) |
29 | : context(context), importStructsAsLiterals(importStructsAsLiterals) {} |
30 | |
31 | /// Translates the given type. |
32 | Type translateType(llvm::Type *type) { |
33 | if (knownTranslations.count(Val: type)) |
34 | return knownTranslations.lookup(Val: type); |
35 | |
36 | Type translated = |
37 | llvm::TypeSwitch<llvm::Type *, Type>(type) |
38 | .Case<llvm::ArrayType, llvm::FunctionType, llvm::IntegerType, |
39 | llvm::PointerType, llvm::StructType, llvm::FixedVectorType, |
40 | llvm::ScalableVectorType, llvm::TargetExtType>( |
41 | caseFn: [this](auto *type) { return this->translate(type); }) |
42 | .Default(defaultFn: [this](llvm::Type *type) { |
43 | return translatePrimitiveType(type); |
44 | }); |
45 | knownTranslations.try_emplace(Key: type, Args&: translated); |
46 | return translated; |
47 | } |
48 | |
49 | private: |
50 | /// Translates the given primitive, i.e. non-parametric in MLIR nomenclature, |
51 | /// type. |
52 | Type translatePrimitiveType(llvm::Type *type) { |
53 | if (type->isVoidTy()) |
54 | return LLVM::LLVMVoidType::get(ctx: &context); |
55 | if (type->isHalfTy()) |
56 | return Float16Type::get(&context); |
57 | if (type->isBFloatTy()) |
58 | return BFloat16Type::get(&context); |
59 | if (type->isFloatTy()) |
60 | return Float32Type::get(&context); |
61 | if (type->isDoubleTy()) |
62 | return Float64Type::get(&context); |
63 | if (type->isFP128Ty()) |
64 | return Float128Type::get(&context); |
65 | if (type->isX86_FP80Ty()) |
66 | return Float80Type::get(&context); |
67 | if (type->isX86_AMXTy()) |
68 | return LLVM::LLVMX86AMXType::get(&context); |
69 | if (type->isPPC_FP128Ty()) |
70 | return LLVM::LLVMPPCFP128Type::get(&context); |
71 | if (type->isLabelTy()) |
72 | return LLVM::LLVMLabelType::get(ctx: &context); |
73 | if (type->isMetadataTy()) |
74 | return LLVM::LLVMMetadataType::get(ctx: &context); |
75 | if (type->isTokenTy()) |
76 | return LLVM::LLVMTokenType::get(ctx: &context); |
77 | llvm_unreachable("not a primitive type" ); |
78 | } |
79 | |
80 | /// Translates the given array type. |
81 | Type translate(llvm::ArrayType *type) { |
82 | return LLVM::LLVMArrayType::get(translateType(type->getElementType()), |
83 | type->getNumElements()); |
84 | } |
85 | |
86 | /// Translates the given function type. |
87 | Type translate(llvm::FunctionType *type) { |
88 | SmallVector<Type, 8> paramTypes; |
89 | translateTypes(types: type->params(), result&: paramTypes); |
90 | return LLVM::LLVMFunctionType::get(translateType(type->getReturnType()), |
91 | paramTypes, type->isVarArg()); |
92 | } |
93 | |
94 | /// Translates the given integer type. |
95 | Type translate(llvm::IntegerType *type) { |
96 | return IntegerType::get(&context, type->getBitWidth()); |
97 | } |
98 | |
99 | /// Translates the given pointer type. |
100 | Type translate(llvm::PointerType *type) { |
101 | return LLVM::LLVMPointerType::get(&context, type->getAddressSpace()); |
102 | } |
103 | |
104 | /// Translates the given structure type. |
105 | Type translate(llvm::StructType *type) { |
106 | SmallVector<Type, 8> subtypes; |
107 | if (type->isLiteral() || importStructsAsLiterals) { |
108 | translateTypes(types: type->subtypes(), result&: subtypes); |
109 | return LLVM::LLVMStructType::getLiteral(&context, subtypes, |
110 | type->isPacked()); |
111 | } |
112 | |
113 | if (type->isOpaque()) |
114 | return LLVM::LLVMStructType::getOpaque(type->getName(), &context); |
115 | |
116 | // With opaque pointers, types in LLVM can't be recursive anymore. Note that |
117 | // using getIdentified is not possible, as type names in LLVM are not |
118 | // guaranteed to be unique. |
119 | translateTypes(types: type->subtypes(), result&: subtypes); |
120 | LLVM::LLVMStructType translated = LLVM::LLVMStructType::getNewIdentified( |
121 | &context, type->getName(), subtypes, type->isPacked()); |
122 | knownTranslations.try_emplace(type, translated); |
123 | return translated; |
124 | } |
125 | |
126 | /// Translates the given fixed-vector type. |
127 | Type translate(llvm::FixedVectorType *type) { |
128 | return VectorType::get(type->getNumElements(), |
129 | translateType(type->getElementType())); |
130 | } |
131 | |
132 | /// Translates the given scalable-vector type. |
133 | Type translate(llvm::ScalableVectorType *type) { |
134 | return VectorType::get(type->getMinNumElements(), |
135 | translateType(type->getElementType()), |
136 | /*scalableDims=*/true); |
137 | } |
138 | |
139 | /// Translates the given target extension type. |
140 | Type translate(llvm::TargetExtType *type) { |
141 | SmallVector<Type> typeParams; |
142 | translateTypes(types: type->type_params(), result&: typeParams); |
143 | |
144 | return LLVM::LLVMTargetExtType::get(&context, type->getName(), typeParams, |
145 | type->int_params()); |
146 | } |
147 | |
148 | /// Translates a list of types. |
149 | void translateTypes(ArrayRef<llvm::Type *> types, |
150 | SmallVectorImpl<Type> &result) { |
151 | result.reserve(N: result.size() + types.size()); |
152 | for (llvm::Type *type : types) |
153 | result.push_back(Elt: translateType(type)); |
154 | } |
155 | |
156 | /// Map of known translations. Serves as a cache and as recursion stopper for |
157 | /// translating recursive structs. |
158 | llvm::DenseMap<llvm::Type *, Type> knownTranslations; |
159 | |
160 | /// The context in which MLIR types are created. |
161 | MLIRContext &context; |
162 | |
163 | /// Controls if structs should be imported as literal structs, i.e., nameless |
164 | /// structs. |
165 | bool importStructsAsLiterals; |
166 | }; |
167 | |
168 | } // namespace detail |
169 | } // namespace LLVM |
170 | } // namespace mlir |
171 | |
172 | LLVM::TypeFromLLVMIRTranslator::TypeFromLLVMIRTranslator( |
173 | MLIRContext &context, bool importStructsAsLiterals) |
174 | : impl(std::make_unique<detail::TypeFromLLVMIRTranslatorImpl>( |
175 | args&: context, args&: importStructsAsLiterals)) {} |
176 | |
/// Defaulted here, in the translation unit that defines
/// detail::TypeFromLLVMIRTranslatorImpl, rather than in the header.
/// NOTE(review): presumably `impl` is a std::unique_ptr to a type that is only
/// forward-declared in TypeFromLLVM.h, which requires the destructor to be
/// emitted where the implementation type is complete - confirm against the
/// header.
LLVM::TypeFromLLVMIRTranslator::~TypeFromLLVMIRTranslator() = default;
178 | |
179 | Type LLVM::TypeFromLLVMIRTranslator::translateType(llvm::Type *type) { |
180 | return impl->translateType(type); |
181 | } |
182 | |