1 | //===- TypeToLLVM.cpp - type translation from MLIR to LLVM IR -===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Target/LLVMIR/TypeToLLVM.h" |
10 | #include "mlir/Dialect/LLVMIR/LLVMTypes.h" |
11 | #include "mlir/IR/BuiltinTypes.h" |
12 | #include "mlir/IR/MLIRContext.h" |
13 | |
14 | #include "llvm/ADT/TypeSwitch.h" |
15 | #include "llvm/IR/DataLayout.h" |
16 | #include "llvm/IR/DerivedTypes.h" |
17 | #include "llvm/IR/Type.h" |
18 | |
19 | using namespace mlir; |
20 | |
21 | namespace mlir { |
22 | namespace LLVM { |
23 | namespace detail { |
24 | /// Support for translating MLIR LLVM dialect types to LLVM IR. |
25 | class TypeToLLVMIRTranslatorImpl { |
26 | public: |
27 | /// Constructs a class creating types in the given LLVM context. |
28 | TypeToLLVMIRTranslatorImpl(llvm::LLVMContext &context) : context(context) {} |
29 | |
30 | /// Translates a single type. |
31 | llvm::Type *translateType(Type type) { |
32 | // If the conversion is already known, just return it. |
33 | if (knownTranslations.count(Val: type)) |
34 | return knownTranslations.lookup(Val: type); |
35 | |
36 | // Dispatch to an appropriate function. |
37 | llvm::Type *translated = |
38 | llvm::TypeSwitch<Type, llvm::Type *>(type) |
39 | .Case([this](LLVM::LLVMVoidType) { |
40 | return llvm::Type::getVoidTy(context); |
41 | }) |
42 | .Case( |
43 | [this](Float16Type) { return llvm::Type::getHalfTy(context); }) |
44 | .Case([this](BFloat16Type) { |
45 | return llvm::Type::getBFloatTy(context); |
46 | }) |
47 | .Case( |
48 | [this](Float32Type) { return llvm::Type::getFloatTy(context); }) |
49 | .Case([this](Float64Type) { |
50 | return llvm::Type::getDoubleTy(context); |
51 | }) |
52 | .Case([this](Float80Type) { |
53 | return llvm::Type::getX86_FP80Ty(context); |
54 | }) |
55 | .Case([this](Float128Type) { |
56 | return llvm::Type::getFP128Ty(context); |
57 | }) |
58 | .Case([this](LLVM::LLVMPPCFP128Type) { |
59 | return llvm::Type::getPPC_FP128Ty(context); |
60 | }) |
61 | .Case([this](LLVM::LLVMX86MMXType) { |
62 | return llvm::Type::getX86_MMXTy(context); |
63 | }) |
64 | .Case([this](LLVM::LLVMTokenType) { |
65 | return llvm::Type::getTokenTy(context); |
66 | }) |
67 | .Case([this](LLVM::LLVMLabelType) { |
68 | return llvm::Type::getLabelTy(context); |
69 | }) |
70 | .Case([this](LLVM::LLVMMetadataType) { |
71 | return llvm::Type::getMetadataTy(context); |
72 | }) |
73 | .Case<LLVM::LLVMArrayType, IntegerType, LLVM::LLVMFunctionType, |
74 | LLVM::LLVMPointerType, LLVM::LLVMStructType, |
75 | LLVM::LLVMFixedVectorType, LLVM::LLVMScalableVectorType, |
76 | VectorType, LLVM::LLVMTargetExtType>( |
77 | [this](auto type) { return this->translate(type); }) |
78 | .Default([](Type t) -> llvm::Type * { |
79 | llvm_unreachable("unknown LLVM dialect type" ); |
80 | }); |
81 | |
82 | // Cache the result of the conversion and return. |
83 | knownTranslations.try_emplace(Key: type, Args&: translated); |
84 | return translated; |
85 | } |
86 | |
87 | private: |
88 | /// Translates the given array type. |
89 | llvm::Type *translate(LLVM::LLVMArrayType type) { |
90 | return llvm::ArrayType::get(ElementType: translateType(type: type.getElementType()), |
91 | NumElements: type.getNumElements()); |
92 | } |
93 | |
94 | /// Translates the given function type. |
95 | llvm::Type *translate(LLVM::LLVMFunctionType type) { |
96 | SmallVector<llvm::Type *, 8> paramTypes; |
97 | translateTypes(types: type.getParams(), result&: paramTypes); |
98 | return llvm::FunctionType::get(translateType(type: type.getReturnType()), |
99 | paramTypes, type.isVarArg()); |
100 | } |
101 | |
102 | /// Translates the given integer type. |
103 | llvm::Type *translate(IntegerType type) { |
104 | return llvm::IntegerType::get(C&: context, NumBits: type.getWidth()); |
105 | } |
106 | |
107 | /// Translates the given pointer type. |
108 | llvm::Type *translate(LLVM::LLVMPointerType type) { |
109 | return llvm::PointerType::get(context, type.getAddressSpace()); |
110 | } |
111 | |
112 | /// Translates the given structure type, supports both identified and literal |
113 | /// structs. This will _create_ a new identified structure every time, use |
114 | /// `convertType` if a structure with the same name must be looked up instead. |
115 | llvm::Type *translate(LLVM::LLVMStructType type) { |
116 | SmallVector<llvm::Type *, 8> subtypes; |
117 | if (!type.isIdentified()) { |
118 | translateTypes(types: type.getBody(), result&: subtypes); |
119 | return llvm::StructType::get(Context&: context, Elements: subtypes, isPacked: type.isPacked()); |
120 | } |
121 | |
122 | llvm::StructType *structType = |
123 | llvm::StructType::create(Context&: context, Name: type.getName()); |
124 | // Mark the type we just created as known so that recursive calls can pick |
125 | // it up and use directly. |
126 | knownTranslations.try_emplace(type, structType); |
127 | if (type.isOpaque()) |
128 | return structType; |
129 | |
130 | translateTypes(types: type.getBody(), result&: subtypes); |
131 | structType->setBody(Elements: subtypes, isPacked: type.isPacked()); |
132 | return structType; |
133 | } |
134 | |
135 | /// Translates the given built-in vector type compatible with LLVM. |
136 | llvm::Type *translate(VectorType type) { |
137 | assert(LLVM::isCompatibleVectorType(type) && |
138 | "expected compatible with LLVM vector type" ); |
139 | if (type.isScalable()) |
140 | return llvm::ScalableVectorType::get(translateType(type: type.getElementType()), |
141 | type.getNumElements()); |
142 | return llvm::FixedVectorType::get(translateType(type: type.getElementType()), |
143 | type.getNumElements()); |
144 | } |
145 | |
146 | /// Translates the given fixed-vector type. |
147 | llvm::Type *translate(LLVM::LLVMFixedVectorType type) { |
148 | return llvm::FixedVectorType::get(translateType(type: type.getElementType()), |
149 | type.getNumElements()); |
150 | } |
151 | |
152 | /// Translates the given scalable-vector type. |
153 | llvm::Type *translate(LLVM::LLVMScalableVectorType type) { |
154 | return llvm::ScalableVectorType::get(translateType(type: type.getElementType()), |
155 | type.getMinNumElements()); |
156 | } |
157 | |
158 | /// Translates the given target extension type. |
159 | llvm::Type *translate(LLVM::LLVMTargetExtType type) { |
160 | SmallVector<llvm::Type *> typeParams; |
161 | translateTypes(types: type.getTypeParams(), result&: typeParams); |
162 | return llvm::TargetExtType::get(Context&: context, Name: type.getExtTypeName(), Types: typeParams, |
163 | Ints: type.getIntParams()); |
164 | } |
165 | |
166 | /// Translates a list of types. |
167 | void translateTypes(ArrayRef<Type> types, |
168 | SmallVectorImpl<llvm::Type *> &result) { |
169 | result.reserve(N: result.size() + types.size()); |
170 | for (auto type : types) |
171 | result.push_back(Elt: translateType(type)); |
172 | } |
173 | |
174 | /// Reference to the context in which the LLVM IR types are created. |
175 | llvm::LLVMContext &context; |
176 | |
177 | /// Map of known translation. This serves a double purpose: caches translation |
178 | /// results to avoid repeated recursive calls and makes sure identified |
179 | /// structs with the same name (that is, equal) are resolved to an existing |
180 | /// type instead of creating a new type. |
181 | llvm::DenseMap<Type, llvm::Type *> knownTranslations; |
182 | }; |
183 | } // namespace detail |
184 | } // namespace LLVM |
185 | } // namespace mlir |
186 | |
187 | LLVM::TypeToLLVMIRTranslator::TypeToLLVMIRTranslator(llvm::LLVMContext &context) |
188 | : impl(new detail::TypeToLLVMIRTranslatorImpl(context)) {} |
189 | |
190 | LLVM::TypeToLLVMIRTranslator::~TypeToLLVMIRTranslator() = default; |
191 | |
192 | llvm::Type *LLVM::TypeToLLVMIRTranslator::translateType(Type type) { |
193 | return impl->translateType(type); |
194 | } |
195 | |
196 | unsigned LLVM::TypeToLLVMIRTranslator::getPreferredAlignment( |
197 | Type type, const llvm::DataLayout &layout) { |
198 | return layout.getPrefTypeAlign(Ty: translateType(type)).value(); |
199 | } |
200 | |