//===- TypeToLLVM.cpp - type translation from MLIR to LLVM IR -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Target/LLVMIR/TypeToLLVM.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/ErrorHandling.h"

using namespace mlir;

21 | namespace mlir { |
22 | namespace LLVM { |
23 | namespace detail { |
24 | /// Support for translating MLIR LLVM dialect types to LLVM IR. |
25 | class TypeToLLVMIRTranslatorImpl { |
26 | public: |
27 | /// Constructs a class creating types in the given LLVM context. |
28 | TypeToLLVMIRTranslatorImpl(llvm::LLVMContext &context) : context(context) {} |
29 | |
30 | /// Translates a single type. |
31 | llvm::Type *translateType(Type type) { |
32 | // If the conversion is already known, just return it. |
33 | if (knownTranslations.count(Val: type)) |
34 | return knownTranslations.lookup(Val: type); |
35 | |
36 | // Dispatch to an appropriate function. |
37 | llvm::Type *translated = |
38 | llvm::TypeSwitch<Type, llvm::Type *>(type) |
39 | .Case([this](LLVM::LLVMVoidType) { |
40 | return llvm::Type::getVoidTy(context); |
41 | }) |
42 | .Case( |
43 | [this](Float16Type) { return llvm::Type::getHalfTy(context); }) |
44 | .Case([this](BFloat16Type) { |
45 | return llvm::Type::getBFloatTy(context); |
46 | }) |
47 | .Case( |
48 | [this](Float32Type) { return llvm::Type::getFloatTy(context); }) |
49 | .Case([this](Float64Type) { |
50 | return llvm::Type::getDoubleTy(context); |
51 | }) |
52 | .Case([this](Float80Type) { |
53 | return llvm::Type::getX86_FP80Ty(context); |
54 | }) |
55 | .Case([this](Float128Type) { |
56 | return llvm::Type::getFP128Ty(context); |
57 | }) |
58 | .Case([this](LLVM::LLVMPPCFP128Type) { |
59 | return llvm::Type::getPPC_FP128Ty(context); |
60 | }) |
61 | .Case([this](LLVM::LLVMTokenType) { |
62 | return llvm::Type::getTokenTy(context); |
63 | }) |
64 | .Case([this](LLVM::LLVMLabelType) { |
65 | return llvm::Type::getLabelTy(context); |
66 | }) |
67 | .Case([this](LLVM::LLVMMetadataType) { |
68 | return llvm::Type::getMetadataTy(context); |
69 | }) |
70 | .Case([this](LLVM::LLVMX86AMXType) { |
71 | return llvm::Type::getX86_AMXTy(context); |
72 | }) |
73 | .Case<LLVM::LLVMArrayType, IntegerType, LLVM::LLVMFunctionType, |
74 | LLVM::LLVMPointerType, LLVM::LLVMStructType, VectorType, |
75 | LLVM::LLVMTargetExtType>( |
76 | [this](auto type) { return this->translate(type); }) |
77 | .Default([](Type t) -> llvm::Type * { |
78 | llvm_unreachable("unknown LLVM dialect type" ); |
79 | }); |
80 | |
81 | // Cache the result of the conversion and return. |
82 | knownTranslations.try_emplace(Key: type, Args&: translated); |
83 | return translated; |
84 | } |
85 | |
86 | private: |
87 | /// Translates the given array type. |
88 | llvm::Type *translate(LLVM::LLVMArrayType type) { |
89 | return llvm::ArrayType::get(ElementType: translateType(type: type.getElementType()), |
90 | NumElements: type.getNumElements()); |
91 | } |
92 | |
93 | /// Translates the given function type. |
94 | llvm::Type *translate(LLVM::LLVMFunctionType type) { |
95 | SmallVector<llvm::Type *, 8> paramTypes; |
96 | translateTypes(types: type.getParams(), result&: paramTypes); |
97 | return llvm::FunctionType::get(translateType(type: type.getReturnType()), |
98 | paramTypes, type.isVarArg()); |
99 | } |
100 | |
101 | /// Translates the given integer type. |
102 | llvm::Type *translate(IntegerType type) { |
103 | return llvm::IntegerType::get(C&: context, NumBits: type.getWidth()); |
104 | } |
105 | |
106 | /// Translates the given pointer type. |
107 | llvm::Type *translate(LLVM::LLVMPointerType type) { |
108 | return llvm::PointerType::get(context, type.getAddressSpace()); |
109 | } |
110 | |
111 | /// Translates the given structure type, supports both identified and literal |
112 | /// structs. This will _create_ a new identified structure every time, use |
113 | /// `convertType` if a structure with the same name must be looked up instead. |
114 | llvm::Type *translate(LLVM::LLVMStructType type) { |
115 | SmallVector<llvm::Type *, 8> subtypes; |
116 | if (!type.isIdentified()) { |
117 | translateTypes(types: type.getBody(), result&: subtypes); |
118 | return llvm::StructType::get(context, subtypes, type.isPacked()); |
119 | } |
120 | |
121 | llvm::StructType *structType = |
122 | llvm::StructType::create(context, type.getName()); |
123 | // Mark the type we just created as known so that recursive calls can pick |
124 | // it up and use directly. |
125 | knownTranslations.try_emplace(type, structType); |
126 | if (type.isOpaque()) |
127 | return structType; |
128 | |
129 | translateTypes(types: type.getBody(), result&: subtypes); |
130 | structType->setBody(Elements: subtypes, isPacked: type.isPacked()); |
131 | return structType; |
132 | } |
133 | |
134 | /// Translates the given built-in vector type compatible with LLVM. |
135 | llvm::Type *translate(VectorType type) { |
136 | assert(LLVM::isCompatibleVectorType(type) && |
137 | "expected compatible with LLVM vector type" ); |
138 | if (type.isScalable()) |
139 | return llvm::ScalableVectorType::get(translateType(type: type.getElementType()), |
140 | type.getNumElements()); |
141 | return llvm::FixedVectorType::get(translateType(type: type.getElementType()), |
142 | type.getNumElements()); |
143 | } |
144 | |
145 | /// Translates the given target extension type. |
146 | llvm::Type *translate(LLVM::LLVMTargetExtType type) { |
147 | SmallVector<llvm::Type *> typeParams; |
148 | translateTypes(types: type.getTypeParams(), result&: typeParams); |
149 | return llvm::TargetExtType::get(Context&: context, Name: type.getExtTypeName(), Types: typeParams, |
150 | Ints: type.getIntParams()); |
151 | } |
152 | |
153 | /// Translates a list of types. |
154 | void translateTypes(ArrayRef<Type> types, |
155 | SmallVectorImpl<llvm::Type *> &result) { |
156 | result.reserve(N: result.size() + types.size()); |
157 | for (auto type : types) |
158 | result.push_back(Elt: translateType(type)); |
159 | } |
160 | |
161 | /// Reference to the context in which the LLVM IR types are created. |
162 | llvm::LLVMContext &context; |
163 | |
164 | /// Map of known translation. This serves a double purpose: caches translation |
165 | /// results to avoid repeated recursive calls and makes sure identified |
166 | /// structs with the same name (that is, equal) are resolved to an existing |
167 | /// type instead of creating a new type. |
168 | llvm::DenseMap<Type, llvm::Type *> knownTranslations; |
169 | }; |
170 | } // namespace detail |
171 | } // namespace LLVM |
172 | } // namespace mlir |
173 | |
174 | LLVM::TypeToLLVMIRTranslator::TypeToLLVMIRTranslator(llvm::LLVMContext &context) |
175 | : impl(new detail::TypeToLLVMIRTranslatorImpl(context)) {} |
176 | |
177 | LLVM::TypeToLLVMIRTranslator::~TypeToLLVMIRTranslator() = default; |
178 | |
179 | llvm::Type *LLVM::TypeToLLVMIRTranslator::translateType(Type type) { |
180 | return impl->translateType(type); |
181 | } |
182 | |
183 | unsigned LLVM::TypeToLLVMIRTranslator::getPreferredAlignment( |
184 | Type type, const llvm::DataLayout &layout) { |
185 | return layout.getPrefTypeAlign(Ty: translateType(type)).value(); |
186 | } |
187 | |