1 | //===- LLVMTypes.cpp - MLIR LLVM dialect types ------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements the types for the LLVM dialect in MLIR. These MLIR types |
10 | // correspond to the LLVM IR type system. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "TypeDetail.h" |
15 | |
16 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
17 | #include "mlir/Dialect/LLVMIR/LLVMTypes.h" |
18 | #include "mlir/IR/BuiltinTypes.h" |
19 | #include "mlir/IR/DialectImplementation.h" |
20 | #include "mlir/IR/TypeSupport.h" |
21 | |
22 | #include "llvm/ADT/ScopeExit.h" |
23 | #include "llvm/ADT/TypeSwitch.h" |
24 | #include "llvm/Support/TypeSize.h" |
25 | #include <optional> |
26 | |
27 | using namespace mlir; |
28 | using namespace mlir::LLVM; |
29 | |
/// Number of bits per byte. Used in this file to convert between byte-based
/// and bit-based sizes and alignments.
constexpr const static uint64_t kBitsInByte = 8;
31 | |
32 | //===----------------------------------------------------------------------===// |
33 | // custom<FunctionTypes> |
34 | //===----------------------------------------------------------------------===// |
35 | |
36 | static ParseResult parseFunctionTypes(AsmParser &p, SmallVector<Type> ¶ms, |
37 | bool &isVarArg) { |
38 | isVarArg = false; |
39 | // `(` `)` |
40 | if (succeeded(result: p.parseOptionalRParen())) |
41 | return success(); |
42 | |
43 | // `(` `...` `)` |
44 | if (succeeded(result: p.parseOptionalEllipsis())) { |
45 | isVarArg = true; |
46 | return p.parseRParen(); |
47 | } |
48 | |
49 | // type (`,` type)* (`,` `...`)? |
50 | Type type; |
51 | if (parsePrettyLLVMType(p, type)) |
52 | return failure(); |
53 | params.push_back(Elt: type); |
54 | while (succeeded(result: p.parseOptionalComma())) { |
55 | if (succeeded(result: p.parseOptionalEllipsis())) { |
56 | isVarArg = true; |
57 | return p.parseRParen(); |
58 | } |
59 | if (parsePrettyLLVMType(p, type)) |
60 | return failure(); |
61 | params.push_back(Elt: type); |
62 | } |
63 | return p.parseRParen(); |
64 | } |
65 | |
66 | static void printFunctionTypes(AsmPrinter &p, ArrayRef<Type> params, |
67 | bool isVarArg) { |
68 | llvm::interleaveComma(c: params, os&: p, |
69 | each_fn: [&](Type type) { printPrettyLLVMType(p, type); }); |
70 | if (isVarArg) { |
71 | if (!params.empty()) |
72 | p << ", " ; |
73 | p << "..." ; |
74 | } |
75 | p << ')'; |
76 | } |
77 | |
78 | //===----------------------------------------------------------------------===// |
79 | // custom<ExtTypeParams> |
80 | //===----------------------------------------------------------------------===// |
81 | |
82 | /// Parses the parameter list for a target extension type. The parameter list |
83 | /// contains an optional list of type parameters, followed by an optional list |
84 | /// of integer parameters. Type and integer parameters cannot be interleaved in |
85 | /// the list. |
86 | /// extTypeParams ::= typeList? | intList? | (typeList "," intList) |
87 | /// typeList ::= type ("," type)* |
88 | /// intList ::= integer ("," integer)* |
89 | static ParseResult |
90 | parseExtTypeParams(AsmParser &p, SmallVectorImpl<Type> &typeParams, |
91 | SmallVectorImpl<unsigned int> &intParams) { |
92 | bool parseType = true; |
93 | auto typeOrIntParser = [&]() -> ParseResult { |
94 | unsigned int i; |
95 | auto intResult = p.parseOptionalInteger(result&: i); |
96 | if (intResult.has_value() && !failed(result: *intResult)) { |
97 | // Successfully parsed an integer. |
98 | intParams.push_back(Elt: i); |
99 | // After the first integer was successfully parsed, no |
100 | // more types can be parsed. |
101 | parseType = false; |
102 | return success(); |
103 | } |
104 | if (parseType) { |
105 | Type t; |
106 | if (!parsePrettyLLVMType(p, type&: t)) { |
107 | // Successfully parsed a type. |
108 | typeParams.push_back(Elt: t); |
109 | return success(); |
110 | } |
111 | } |
112 | return failure(); |
113 | }; |
114 | if (p.parseCommaSeparatedList(parseElementFn: typeOrIntParser)) { |
115 | p.emitError(loc: p.getCurrentLocation(), |
116 | message: "failed to parse parameter list for target extension type" ); |
117 | return failure(); |
118 | } |
119 | return success(); |
120 | } |
121 | |
122 | static void printExtTypeParams(AsmPrinter &p, ArrayRef<Type> typeParams, |
123 | ArrayRef<unsigned int> intParams) { |
124 | p << typeParams; |
125 | if (!typeParams.empty() && !intParams.empty()) |
126 | p << ", " ; |
127 | |
128 | p << intParams; |
129 | } |
130 | |
131 | //===----------------------------------------------------------------------===// |
132 | // ODS-Generated Definitions |
133 | //===----------------------------------------------------------------------===// |
134 | |
/// Forward declarations of the generated type parser and printer. They are
/// unused for now, hence the LLVM_ATTRIBUTE_UNUSED markers that silence
/// unused-function warnings.
/// TODO: Move over to these once more types have been migrated to TypeDef.
LLVM_ATTRIBUTE_UNUSED static OptionalParseResult
generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value);
LLVM_ATTRIBUTE_UNUSED static LogicalResult
generatedTypePrinter(Type def, AsmPrinter &printer);
141 | |
142 | #include "mlir/Dialect/LLVMIR/LLVMTypeInterfaces.cpp.inc" |
143 | |
144 | #define GET_TYPEDEF_CLASSES |
145 | #include "mlir/Dialect/LLVMIR/LLVMTypes.cpp.inc" |
146 | |
147 | //===----------------------------------------------------------------------===// |
148 | // LLVMArrayType |
149 | //===----------------------------------------------------------------------===// |
150 | |
151 | bool LLVMArrayType::isValidElementType(Type type) { |
152 | return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType, |
153 | LLVMFunctionType, LLVMTokenType, LLVMScalableVectorType>( |
154 | type); |
155 | } |
156 | |
157 | LLVMArrayType LLVMArrayType::get(Type elementType, unsigned numElements) { |
158 | assert(elementType && "expected non-null subtype" ); |
159 | return Base::get(elementType.getContext(), elementType, numElements); |
160 | } |
161 | |
162 | LLVMArrayType |
163 | LLVMArrayType::getChecked(function_ref<InFlightDiagnostic()> emitError, |
164 | Type elementType, unsigned numElements) { |
165 | assert(elementType && "expected non-null subtype" ); |
166 | return Base::getChecked(emitError, elementType.getContext(), elementType, |
167 | numElements); |
168 | } |
169 | |
170 | LogicalResult |
171 | LLVMArrayType::verify(function_ref<InFlightDiagnostic()> emitError, |
172 | Type elementType, unsigned numElements) { |
173 | if (!isValidElementType(elementType)) |
174 | return emitError() << "invalid array element type: " << elementType; |
175 | return success(); |
176 | } |
177 | |
178 | //===----------------------------------------------------------------------===// |
179 | // DataLayoutTypeInterface |
180 | |
181 | llvm::TypeSize |
182 | LLVMArrayType::getTypeSizeInBits(const DataLayout &dataLayout, |
183 | DataLayoutEntryListRef params) const { |
184 | return llvm::TypeSize::getFixed(kBitsInByte * |
185 | getTypeSize(dataLayout, params)); |
186 | } |
187 | |
188 | llvm::TypeSize LLVMArrayType::getTypeSize(const DataLayout &dataLayout, |
189 | DataLayoutEntryListRef params) const { |
190 | return llvm::alignTo(dataLayout.getTypeSize(getElementType()), |
191 | dataLayout.getTypeABIAlignment(getElementType())) * |
192 | getNumElements(); |
193 | } |
194 | |
195 | uint64_t LLVMArrayType::getABIAlignment(const DataLayout &dataLayout, |
196 | DataLayoutEntryListRef params) const { |
197 | return dataLayout.getTypeABIAlignment(getElementType()); |
198 | } |
199 | |
200 | uint64_t |
201 | LLVMArrayType::getPreferredAlignment(const DataLayout &dataLayout, |
202 | DataLayoutEntryListRef params) const { |
203 | return dataLayout.getTypePreferredAlignment(getElementType()); |
204 | } |
205 | |
206 | //===----------------------------------------------------------------------===// |
207 | // Function type. |
208 | //===----------------------------------------------------------------------===// |
209 | |
210 | bool LLVMFunctionType::isValidArgumentType(Type type) { |
211 | return !llvm::isa<LLVMVoidType, LLVMFunctionType>(type); |
212 | } |
213 | |
214 | bool LLVMFunctionType::isValidResultType(Type type) { |
215 | return !llvm::isa<LLVMFunctionType, LLVMMetadataType, LLVMLabelType>(type); |
216 | } |
217 | |
218 | LLVMFunctionType LLVMFunctionType::get(Type result, ArrayRef<Type> arguments, |
219 | bool isVarArg) { |
220 | assert(result && "expected non-null result" ); |
221 | return Base::get(result.getContext(), result, arguments, isVarArg); |
222 | } |
223 | |
224 | LLVMFunctionType |
225 | LLVMFunctionType::getChecked(function_ref<InFlightDiagnostic()> emitError, |
226 | Type result, ArrayRef<Type> arguments, |
227 | bool isVarArg) { |
228 | assert(result && "expected non-null result" ); |
229 | return Base::getChecked(emitError, result.getContext(), result, arguments, |
230 | isVarArg); |
231 | } |
232 | |
233 | LLVMFunctionType LLVMFunctionType::clone(TypeRange inputs, |
234 | TypeRange results) const { |
235 | assert(results.size() == 1 && "expected a single result type" ); |
236 | return get(results[0], llvm::to_vector(inputs), isVarArg()); |
237 | } |
238 | |
239 | ArrayRef<Type> LLVMFunctionType::getReturnTypes() const { |
240 | return static_cast<detail::LLVMFunctionTypeStorage *>(getImpl())->returnType; |
241 | } |
242 | |
243 | LogicalResult |
244 | LLVMFunctionType::verify(function_ref<InFlightDiagnostic()> emitError, |
245 | Type result, ArrayRef<Type> arguments, bool) { |
246 | if (!isValidResultType(result)) |
247 | return emitError() << "invalid function result type: " << result; |
248 | |
249 | for (Type arg : arguments) |
250 | if (!isValidArgumentType(arg)) |
251 | return emitError() << "invalid function argument type: " << arg; |
252 | |
253 | return success(); |
254 | } |
255 | |
256 | //===----------------------------------------------------------------------===// |
257 | // DataLayoutTypeInterface |
258 | |
/// Defaults used for pointers in address space 0 when the data layout entries
/// do not specify them: the value returned for size/index queries and the
/// value returned for alignment queries, respectively.
constexpr const static uint64_t kDefaultPointerSizeBits = 64;
constexpr const static uint64_t kDefaultPointerAlignment = 8;
261 | |
262 | std::optional<uint64_t> mlir::LLVM::(Attribute attr, |
263 | PtrDLEntryPos pos) { |
264 | auto spec = cast<DenseIntElementsAttr>(Val&: attr); |
265 | auto idx = static_cast<int64_t>(pos); |
266 | if (idx >= spec.size()) |
267 | return std::nullopt; |
268 | return spec.getValues<uint64_t>()[idx]; |
269 | } |
270 | |
271 | /// Returns the part of the data layout entry that corresponds to `pos` for the |
272 | /// given `type` by interpreting the list of entries `params`. For the pointer |
273 | /// type in the default address space, returns the default value if the entries |
274 | /// do not provide a custom one, for other address spaces returns std::nullopt. |
275 | static std::optional<uint64_t> |
276 | getPointerDataLayoutEntry(DataLayoutEntryListRef params, LLVMPointerType type, |
277 | PtrDLEntryPos pos) { |
278 | // First, look for the entry for the pointer in the current address space. |
279 | Attribute currentEntry; |
280 | for (DataLayoutEntryInterface entry : params) { |
281 | if (!entry.isTypeEntry()) |
282 | continue; |
283 | if (cast<LLVMPointerType>(entry.getKey().get<Type>()).getAddressSpace() == |
284 | type.getAddressSpace()) { |
285 | currentEntry = entry.getValue(); |
286 | break; |
287 | } |
288 | } |
289 | if (currentEntry) { |
290 | std::optional<uint64_t> value = extractPointerSpecValue(attr: currentEntry, pos); |
291 | // If the optional `PtrDLEntryPos::Index` entry is not available, use the |
292 | // pointer size as the index bitwidth. |
293 | if (!value && pos == PtrDLEntryPos::Index) |
294 | value = extractPointerSpecValue(attr: currentEntry, pos: PtrDLEntryPos::Size); |
295 | bool isSizeOrIndex = |
296 | pos == PtrDLEntryPos::Size || pos == PtrDLEntryPos::Index; |
297 | return *value / (isSizeOrIndex ? 1 : kBitsInByte); |
298 | } |
299 | |
300 | // If not found, and this is the pointer to the default memory space, assume |
301 | // 64-bit pointers. |
302 | if (type.getAddressSpace() == 0) { |
303 | bool isSizeOrIndex = |
304 | pos == PtrDLEntryPos::Size || pos == PtrDLEntryPos::Index; |
305 | return isSizeOrIndex ? kDefaultPointerSizeBits : kDefaultPointerAlignment; |
306 | } |
307 | |
308 | return std::nullopt; |
309 | } |
310 | |
311 | llvm::TypeSize |
312 | LLVMPointerType::getTypeSizeInBits(const DataLayout &dataLayout, |
313 | DataLayoutEntryListRef params) const { |
314 | if (std::optional<uint64_t> size = |
315 | getPointerDataLayoutEntry(params, *this, PtrDLEntryPos::Size)) |
316 | return llvm::TypeSize::getFixed(*size); |
317 | |
318 | // For other memory spaces, use the size of the pointer to the default memory |
319 | // space. |
320 | return dataLayout.getTypeSizeInBits(get(getContext())); |
321 | } |
322 | |
323 | uint64_t LLVMPointerType::getABIAlignment(const DataLayout &dataLayout, |
324 | DataLayoutEntryListRef params) const { |
325 | if (std::optional<uint64_t> alignment = |
326 | getPointerDataLayoutEntry(params, *this, PtrDLEntryPos::Abi)) |
327 | return *alignment; |
328 | |
329 | return dataLayout.getTypeABIAlignment(get(getContext())); |
330 | } |
331 | |
332 | uint64_t |
333 | LLVMPointerType::getPreferredAlignment(const DataLayout &dataLayout, |
334 | DataLayoutEntryListRef params) const { |
335 | if (std::optional<uint64_t> alignment = |
336 | getPointerDataLayoutEntry(params, *this, PtrDLEntryPos::Preferred)) |
337 | return *alignment; |
338 | |
339 | return dataLayout.getTypePreferredAlignment(get(getContext())); |
340 | } |
341 | |
342 | std::optional<uint64_t> |
343 | LLVMPointerType::getIndexBitwidth(const DataLayout &dataLayout, |
344 | DataLayoutEntryListRef params) const { |
345 | if (std::optional<uint64_t> indexBitwidth = |
346 | getPointerDataLayoutEntry(params, *this, PtrDLEntryPos::Index)) |
347 | return *indexBitwidth; |
348 | |
349 | return dataLayout.getTypeIndexBitwidth(get(getContext())); |
350 | } |
351 | |
352 | bool LLVMPointerType::areCompatible(DataLayoutEntryListRef oldLayout, |
353 | DataLayoutEntryListRef newLayout) const { |
354 | for (DataLayoutEntryInterface newEntry : newLayout) { |
355 | if (!newEntry.isTypeEntry()) |
356 | continue; |
357 | uint64_t size = kDefaultPointerSizeBits; |
358 | uint64_t abi = kDefaultPointerAlignment; |
359 | auto newType = llvm::cast<LLVMPointerType>(newEntry.getKey().get<Type>()); |
360 | const auto *it = |
361 | llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) { |
362 | if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) { |
363 | return llvm::cast<LLVMPointerType>(type).getAddressSpace() == |
364 | newType.getAddressSpace(); |
365 | } |
366 | return false; |
367 | }); |
368 | if (it == oldLayout.end()) { |
369 | llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) { |
370 | if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) { |
371 | return llvm::cast<LLVMPointerType>(type).getAddressSpace() == 0; |
372 | } |
373 | return false; |
374 | }); |
375 | } |
376 | if (it != oldLayout.end()) { |
377 | size = *extractPointerSpecValue(*it, PtrDLEntryPos::Size); |
378 | abi = *extractPointerSpecValue(*it, PtrDLEntryPos::Abi); |
379 | } |
380 | |
381 | Attribute newSpec = llvm::cast<DenseIntElementsAttr>(newEntry.getValue()); |
382 | uint64_t newSize = *extractPointerSpecValue(newSpec, PtrDLEntryPos::Size); |
383 | uint64_t newAbi = *extractPointerSpecValue(newSpec, PtrDLEntryPos::Abi); |
384 | if (size != newSize || abi < newAbi || abi % newAbi != 0) |
385 | return false; |
386 | } |
387 | return true; |
388 | } |
389 | |
390 | LogicalResult LLVMPointerType::verifyEntries(DataLayoutEntryListRef entries, |
391 | Location loc) const { |
392 | for (DataLayoutEntryInterface entry : entries) { |
393 | if (!entry.isTypeEntry()) |
394 | continue; |
395 | auto key = entry.getKey().get<Type>(); |
396 | auto values = llvm::dyn_cast<DenseIntElementsAttr>(entry.getValue()); |
397 | if (!values || (values.size() != 3 && values.size() != 4)) { |
398 | return emitError(loc) |
399 | << "expected layout attribute for " << key |
400 | << " to be a dense integer elements attribute with 3 or 4 " |
401 | "elements" ; |
402 | } |
403 | if (!values.getElementType().isInteger(64)) |
404 | return emitError(loc) << "expected i64 parameters for " << key; |
405 | |
406 | if (extractPointerSpecValue(values, PtrDLEntryPos::Abi) > |
407 | extractPointerSpecValue(values, PtrDLEntryPos::Preferred)) { |
408 | return emitError(loc) << "preferred alignment is expected to be at least " |
409 | "as large as ABI alignment" ; |
410 | } |
411 | } |
412 | return success(); |
413 | } |
414 | |
415 | //===----------------------------------------------------------------------===// |
416 | // Struct type. |
417 | //===----------------------------------------------------------------------===// |
418 | |
419 | bool LLVMStructType::isValidElementType(Type type) { |
420 | return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType, |
421 | LLVMFunctionType, LLVMTokenType, LLVMScalableVectorType>( |
422 | type); |
423 | } |
424 | |
425 | LLVMStructType LLVMStructType::getIdentified(MLIRContext *context, |
426 | StringRef name) { |
427 | return Base::get(context, name, /*opaque=*/false); |
428 | } |
429 | |
430 | LLVMStructType LLVMStructType::getIdentifiedChecked( |
431 | function_ref<InFlightDiagnostic()> emitError, MLIRContext *context, |
432 | StringRef name) { |
433 | return Base::getChecked(emitError, context, name, /*opaque=*/false); |
434 | } |
435 | |
436 | LLVMStructType LLVMStructType::getNewIdentified(MLIRContext *context, |
437 | StringRef name, |
438 | ArrayRef<Type> elements, |
439 | bool isPacked) { |
440 | std::string stringName = name.str(); |
441 | unsigned counter = 0; |
442 | do { |
443 | auto type = LLVMStructType::getIdentified(context, name: stringName); |
444 | if (type.isInitialized() || failed(result: type.setBody(types: elements, isPacked))) { |
445 | counter += 1; |
446 | stringName = (Twine(name) + "." + std::to_string(val: counter)).str(); |
447 | continue; |
448 | } |
449 | return type; |
450 | } while (true); |
451 | } |
452 | |
453 | LLVMStructType LLVMStructType::getLiteral(MLIRContext *context, |
454 | ArrayRef<Type> types, bool isPacked) { |
455 | return Base::get(context, types, isPacked); |
456 | } |
457 | |
458 | LLVMStructType |
459 | LLVMStructType::getLiteralChecked(function_ref<InFlightDiagnostic()> emitError, |
460 | MLIRContext *context, ArrayRef<Type> types, |
461 | bool isPacked) { |
462 | return Base::getChecked(emitError, context, types, isPacked); |
463 | } |
464 | |
465 | LLVMStructType LLVMStructType::getOpaque(StringRef name, MLIRContext *context) { |
466 | return Base::get(context, name, /*opaque=*/true); |
467 | } |
468 | |
469 | LLVMStructType |
470 | LLVMStructType::getOpaqueChecked(function_ref<InFlightDiagnostic()> emitError, |
471 | MLIRContext *context, StringRef name) { |
472 | return Base::getChecked(emitError, context, name, /*opaque=*/true); |
473 | } |
474 | |
475 | LogicalResult LLVMStructType::setBody(ArrayRef<Type> types, bool isPacked) { |
476 | assert(isIdentified() && "can only set bodies of identified structs" ); |
477 | assert(llvm::all_of(types, LLVMStructType::isValidElementType) && |
478 | "expected valid body types" ); |
479 | return Base::mutate(types, isPacked); |
480 | } |
481 | |
482 | bool LLVMStructType::isPacked() const { return getImpl()->isPacked(); } |
483 | bool LLVMStructType::isIdentified() const { return getImpl()->isIdentified(); } |
484 | bool LLVMStructType::isOpaque() { |
485 | return getImpl()->isIdentified() && |
486 | (getImpl()->isOpaque() || !getImpl()->isInitialized()); |
487 | } |
488 | bool LLVMStructType::isInitialized() { return getImpl()->isInitialized(); } |
489 | StringRef LLVMStructType::getName() { return getImpl()->getIdentifier(); } |
490 | ArrayRef<Type> LLVMStructType::getBody() const { |
491 | return isIdentified() ? getImpl()->getIdentifiedStructBody() |
492 | : getImpl()->getTypeList(); |
493 | } |
494 | |
495 | LogicalResult LLVMStructType::verify(function_ref<InFlightDiagnostic()>, |
496 | StringRef, bool) { |
497 | return success(); |
498 | } |
499 | |
500 | LogicalResult |
501 | LLVMStructType::verify(function_ref<InFlightDiagnostic()> emitError, |
502 | ArrayRef<Type> types, bool) { |
503 | for (Type t : types) |
504 | if (!isValidElementType(type: t)) |
505 | return emitError() << "invalid LLVM structure element type: " << t; |
506 | |
507 | return success(); |
508 | } |
509 | |
510 | llvm::TypeSize |
511 | LLVMStructType::getTypeSizeInBits(const DataLayout &dataLayout, |
512 | DataLayoutEntryListRef params) const { |
513 | auto structSize = llvm::TypeSize::getFixed(ExactSize: 0); |
514 | uint64_t structAlignment = 1; |
515 | for (Type element : getBody()) { |
516 | uint64_t elementAlignment = |
517 | isPacked() ? 1 : dataLayout.getTypeABIAlignment(t: element); |
518 | // Add padding to the struct size to align it to the abi alignment of the |
519 | // element type before than adding the size of the element. |
520 | structSize = llvm::alignTo(Size: structSize, Align: elementAlignment); |
521 | structSize += dataLayout.getTypeSize(t: element); |
522 | |
523 | // The alignment requirement of a struct is equal to the strictest alignment |
524 | // requirement of its elements. |
525 | structAlignment = std::max(a: elementAlignment, b: structAlignment); |
526 | } |
527 | // At the end, add padding to the struct to satisfy its own alignment |
528 | // requirement. Otherwise structs inside of arrays would be misaligned. |
529 | structSize = llvm::alignTo(Size: structSize, Align: structAlignment); |
530 | return structSize * kBitsInByte; |
531 | } |
532 | |
namespace {
/// Positions of the components inside a struct data layout specification:
/// the ABI alignment comes first, the preferred alignment second.
enum class StructDLEntryPos { Abi = 0, Preferred = 1 };
} // namespace
536 | |
537 | static std::optional<uint64_t> |
538 | getStructDataLayoutEntry(DataLayoutEntryListRef params, LLVMStructType type, |
539 | StructDLEntryPos pos) { |
540 | const auto *currentEntry = |
541 | llvm::find_if(Range&: params, P: [](DataLayoutEntryInterface entry) { |
542 | return entry.isTypeEntry(); |
543 | }); |
544 | if (currentEntry == params.end()) |
545 | return std::nullopt; |
546 | |
547 | auto attr = llvm::cast<DenseIntElementsAttr>(currentEntry->getValue()); |
548 | if (pos == StructDLEntryPos::Preferred && |
549 | attr.size() <= static_cast<int64_t>(StructDLEntryPos::Preferred)) |
550 | // If no preferred was specified, fall back to abi alignment |
551 | pos = StructDLEntryPos::Abi; |
552 | |
553 | return attr.getValues<uint64_t>()[static_cast<size_t>(pos)]; |
554 | } |
555 | |
556 | static uint64_t calculateStructAlignment(const DataLayout &dataLayout, |
557 | DataLayoutEntryListRef params, |
558 | LLVMStructType type, |
559 | StructDLEntryPos pos) { |
560 | // Packed structs always have an abi alignment of 1 |
561 | if (pos == StructDLEntryPos::Abi && type.isPacked()) { |
562 | return 1; |
563 | } |
564 | |
565 | // The alignment requirement of a struct is equal to the strictest alignment |
566 | // requirement of its elements. |
567 | uint64_t structAlignment = 1; |
568 | for (Type iter : type.getBody()) { |
569 | structAlignment = |
570 | std::max(a: dataLayout.getTypeABIAlignment(t: iter), b: structAlignment); |
571 | } |
572 | |
573 | // Entries are only allowed to be stricter than the required alignment |
574 | if (std::optional<uint64_t> entryResult = |
575 | getStructDataLayoutEntry(params, type, pos)) |
576 | return std::max(a: *entryResult / kBitsInByte, b: structAlignment); |
577 | |
578 | return structAlignment; |
579 | } |
580 | |
581 | uint64_t LLVMStructType::getABIAlignment(const DataLayout &dataLayout, |
582 | DataLayoutEntryListRef params) const { |
583 | return calculateStructAlignment(dataLayout, params, type: *this, |
584 | pos: StructDLEntryPos::Abi); |
585 | } |
586 | |
587 | uint64_t |
588 | LLVMStructType::getPreferredAlignment(const DataLayout &dataLayout, |
589 | DataLayoutEntryListRef params) const { |
590 | return calculateStructAlignment(dataLayout, params, type: *this, |
591 | pos: StructDLEntryPos::Preferred); |
592 | } |
593 | |
594 | static uint64_t (Attribute attr, StructDLEntryPos pos) { |
595 | return llvm::cast<DenseIntElementsAttr>(attr) |
596 | .getValues<uint64_t>()[static_cast<size_t>(pos)]; |
597 | } |
598 | |
599 | bool LLVMStructType::areCompatible(DataLayoutEntryListRef oldLayout, |
600 | DataLayoutEntryListRef newLayout) const { |
601 | for (DataLayoutEntryInterface newEntry : newLayout) { |
602 | if (!newEntry.isTypeEntry()) |
603 | continue; |
604 | |
605 | const auto *previousEntry = |
606 | llvm::find_if(oldLayout, [](DataLayoutEntryInterface entry) { |
607 | return entry.isTypeEntry(); |
608 | }); |
609 | if (previousEntry == oldLayout.end()) |
610 | continue; |
611 | |
612 | uint64_t abi = extractStructSpecValue(previousEntry->getValue(), |
613 | StructDLEntryPos::Abi); |
614 | uint64_t newAbi = |
615 | extractStructSpecValue(newEntry.getValue(), StructDLEntryPos::Abi); |
616 | if (abi < newAbi || abi % newAbi != 0) |
617 | return false; |
618 | } |
619 | return true; |
620 | } |
621 | |
622 | LogicalResult LLVMStructType::verifyEntries(DataLayoutEntryListRef entries, |
623 | Location loc) const { |
624 | for (DataLayoutEntryInterface entry : entries) { |
625 | if (!entry.isTypeEntry()) |
626 | continue; |
627 | |
628 | auto key = llvm::cast<LLVMStructType>(entry.getKey().get<Type>()); |
629 | auto values = llvm::dyn_cast<DenseIntElementsAttr>(entry.getValue()); |
630 | if (!values || (values.size() != 2 && values.size() != 1)) { |
631 | return emitError(loc) |
632 | << "expected layout attribute for " << entry.getKey().get<Type>() |
633 | << " to be a dense integer elements attribute of 1 or 2 elements" ; |
634 | } |
635 | if (!values.getElementType().isInteger(64)) |
636 | return emitError(loc) << "expected i64 entries for " << key; |
637 | |
638 | if (key.isIdentified() || !key.getBody().empty()) { |
639 | return emitError(loc) << "unexpected layout attribute for struct " << key; |
640 | } |
641 | |
642 | if (values.size() == 1) |
643 | continue; |
644 | |
645 | if (extractStructSpecValue(values, StructDLEntryPos::Abi) > |
646 | extractStructSpecValue(values, StructDLEntryPos::Preferred)) { |
647 | return emitError(loc) << "preferred alignment is expected to be at least " |
648 | "as large as ABI alignment" ; |
649 | } |
650 | } |
651 | return mlir::success(); |
652 | } |
653 | |
654 | //===----------------------------------------------------------------------===// |
655 | // Vector types. |
656 | //===----------------------------------------------------------------------===// |
657 | |
658 | /// Verifies that the type about to be constructed is well-formed. |
659 | template <typename VecTy> |
660 | static LogicalResult |
661 | verifyVectorConstructionInvariants(function_ref<InFlightDiagnostic()> emitError, |
662 | Type elementType, unsigned numElements) { |
663 | if (numElements == 0) |
664 | return emitError() << "the number of vector elements must be positive" ; |
665 | |
666 | if (!VecTy::isValidElementType(elementType)) |
667 | return emitError() << "invalid vector element type" ; |
668 | |
669 | return success(); |
670 | } |
671 | |
672 | LLVMFixedVectorType LLVMFixedVectorType::get(Type elementType, |
673 | unsigned numElements) { |
674 | assert(elementType && "expected non-null subtype" ); |
675 | return Base::get(elementType.getContext(), elementType, numElements); |
676 | } |
677 | |
678 | LLVMFixedVectorType |
679 | LLVMFixedVectorType::getChecked(function_ref<InFlightDiagnostic()> emitError, |
680 | Type elementType, unsigned numElements) { |
681 | assert(elementType && "expected non-null subtype" ); |
682 | return Base::getChecked(emitError, elementType.getContext(), elementType, |
683 | numElements); |
684 | } |
685 | |
686 | bool LLVMFixedVectorType::isValidElementType(Type type) { |
687 | return llvm::isa<LLVMPointerType, LLVMPPCFP128Type>(type); |
688 | } |
689 | |
690 | LogicalResult |
691 | LLVMFixedVectorType::verify(function_ref<InFlightDiagnostic()> emitError, |
692 | Type elementType, unsigned numElements) { |
693 | return verifyVectorConstructionInvariants<LLVMFixedVectorType>( |
694 | emitError, elementType, numElements); |
695 | } |
696 | |
697 | //===----------------------------------------------------------------------===// |
698 | // LLVMScalableVectorType. |
699 | //===----------------------------------------------------------------------===// |
700 | |
701 | LLVMScalableVectorType LLVMScalableVectorType::get(Type elementType, |
702 | unsigned minNumElements) { |
703 | assert(elementType && "expected non-null subtype" ); |
704 | return Base::get(elementType.getContext(), elementType, minNumElements); |
705 | } |
706 | |
707 | LLVMScalableVectorType |
708 | LLVMScalableVectorType::getChecked(function_ref<InFlightDiagnostic()> emitError, |
709 | Type elementType, unsigned minNumElements) { |
710 | assert(elementType && "expected non-null subtype" ); |
711 | return Base::getChecked(emitError, elementType.getContext(), elementType, |
712 | minNumElements); |
713 | } |
714 | |
715 | bool LLVMScalableVectorType::isValidElementType(Type type) { |
716 | if (auto intType = llvm::dyn_cast<IntegerType>(type)) |
717 | return intType.isSignless(); |
718 | |
719 | return isCompatibleFloatingPointType(type) || |
720 | llvm::isa<LLVMPointerType>(type); |
721 | } |
722 | |
723 | LogicalResult |
724 | LLVMScalableVectorType::verify(function_ref<InFlightDiagnostic()> emitError, |
725 | Type elementType, unsigned numElements) { |
726 | return verifyVectorConstructionInvariants<LLVMScalableVectorType>( |
727 | emitError, elementType, numElements); |
728 | } |
729 | |
730 | //===----------------------------------------------------------------------===// |
731 | // LLVMTargetExtType. |
732 | //===----------------------------------------------------------------------===// |
733 | |
/// Target extension type names that the queries below treat specially: any
/// name starting with `kSpirvPrefix`, and the exact name `kArmSVCount`.
static constexpr llvm::StringRef kSpirvPrefix = "spirv.";
static constexpr llvm::StringRef kArmSVCount = "aarch64.svcount";
736 | |
737 | bool LLVM::LLVMTargetExtType::hasProperty(Property prop) const { |
738 | // See llvm/lib/IR/Type.cpp for reference. |
739 | uint64_t properties = 0; |
740 | |
741 | if (getExtTypeName().starts_with(kSpirvPrefix)) |
742 | properties |= |
743 | (LLVMTargetExtType::HasZeroInit | LLVM::LLVMTargetExtType::CanBeGlobal); |
744 | |
745 | return (properties & prop) == prop; |
746 | } |
747 | |
748 | bool LLVM::LLVMTargetExtType::supportsMemOps() const { |
749 | // See llvm/lib/IR/Type.cpp for reference. |
750 | if (getExtTypeName().starts_with(kSpirvPrefix)) |
751 | return true; |
752 | |
753 | if (getExtTypeName() == kArmSVCount) |
754 | return true; |
755 | |
756 | return false; |
757 | } |
758 | |
759 | //===----------------------------------------------------------------------===// |
760 | // Utility functions. |
761 | //===----------------------------------------------------------------------===// |
762 | |
763 | bool mlir::LLVM::isCompatibleOuterType(Type type) { |
764 | // clang-format off |
765 | if (llvm::isa< |
766 | BFloat16Type, |
767 | Float16Type, |
768 | Float32Type, |
769 | Float64Type, |
770 | Float80Type, |
771 | Float128Type, |
772 | LLVMArrayType, |
773 | LLVMFunctionType, |
774 | LLVMLabelType, |
775 | LLVMMetadataType, |
776 | LLVMPPCFP128Type, |
777 | LLVMPointerType, |
778 | LLVMStructType, |
779 | LLVMTokenType, |
780 | LLVMFixedVectorType, |
781 | LLVMScalableVectorType, |
782 | LLVMTargetExtType, |
783 | LLVMVoidType, |
784 | LLVMX86MMXType |
785 | >(type)) { |
786 | // clang-format on |
787 | return true; |
788 | } |
789 | |
790 | // Only signless integers are compatible. |
791 | if (auto intType = llvm::dyn_cast<IntegerType>(type)) |
792 | return intType.isSignless(); |
793 | |
794 | // 1D vector types are compatible. |
795 | if (auto vecType = llvm::dyn_cast<VectorType>(type)) |
796 | return vecType.getRank() == 1; |
797 | |
798 | return false; |
799 | } |
800 | |
801 | static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) { |
802 | if (!compatibleTypes.insert(V: type).second) |
803 | return true; |
804 | |
805 | auto isCompatible = [&](Type type) { |
806 | return isCompatibleImpl(type, compatibleTypes); |
807 | }; |
808 | |
809 | bool result = |
810 | llvm::TypeSwitch<Type, bool>(type) |
811 | .Case<LLVMStructType>([&](auto structType) { |
812 | return llvm::all_of(structType.getBody(), isCompatible); |
813 | }) |
814 | .Case<LLVMFunctionType>([&](auto funcType) { |
815 | return isCompatible(funcType.getReturnType()) && |
816 | llvm::all_of(funcType.getParams(), isCompatible); |
817 | }) |
818 | .Case<IntegerType>([](auto intType) { return intType.isSignless(); }) |
819 | .Case<VectorType>([&](auto vecType) { |
820 | return vecType.getRank() == 1 && |
821 | isCompatible(vecType.getElementType()); |
822 | }) |
823 | .Case<LLVMPointerType>([&](auto pointerType) { return true; }) |
824 | .Case<LLVMTargetExtType>([&](auto extType) { |
825 | return llvm::all_of(extType.getTypeParams(), isCompatible); |
826 | }) |
827 | // clang-format off |
828 | .Case< |
829 | LLVMFixedVectorType, |
830 | LLVMScalableVectorType, |
831 | LLVMArrayType |
832 | >([&](auto containerType) { |
833 | return isCompatible(containerType.getElementType()); |
834 | }) |
835 | .Case< |
836 | BFloat16Type, |
837 | Float16Type, |
838 | Float32Type, |
839 | Float64Type, |
840 | Float80Type, |
841 | Float128Type, |
842 | LLVMLabelType, |
843 | LLVMMetadataType, |
844 | LLVMPPCFP128Type, |
845 | LLVMTokenType, |
846 | LLVMVoidType, |
847 | LLVMX86MMXType |
848 | >([](Type) { return true; }) |
849 | // clang-format on |
850 | .Default([](Type) { return false; }); |
851 | |
852 | if (!result) |
853 | compatibleTypes.erase(V: type); |
854 | |
855 | return result; |
856 | } |
857 | |
858 | bool LLVMDialect::isCompatibleType(Type type) { |
859 | if (auto *llvmDialect = |
860 | type.getContext()->getLoadedDialect<LLVM::LLVMDialect>()) |
861 | return isCompatibleImpl(type, llvmDialect->compatibleTypes.get()); |
862 | |
863 | DenseSet<Type> localCompatibleTypes; |
864 | return isCompatibleImpl(type, localCompatibleTypes); |
865 | } |
866 | |
867 | bool mlir::LLVM::isCompatibleType(Type type) { |
868 | return LLVMDialect::isCompatibleType(type); |
869 | } |
870 | |
871 | bool mlir::LLVM::isCompatibleFloatingPointType(Type type) { |
872 | return llvm::isa<BFloat16Type, Float16Type, Float32Type, Float64Type, |
873 | Float80Type, Float128Type, LLVMPPCFP128Type>(type); |
874 | } |
875 | |
876 | bool mlir::LLVM::isCompatibleVectorType(Type type) { |
877 | if (llvm::isa<LLVMFixedVectorType, LLVMScalableVectorType>(type)) |
878 | return true; |
879 | |
880 | if (auto vecType = llvm::dyn_cast<VectorType>(type)) { |
881 | if (vecType.getRank() != 1) |
882 | return false; |
883 | Type elementType = vecType.getElementType(); |
884 | if (auto intType = llvm::dyn_cast<IntegerType>(elementType)) |
885 | return intType.isSignless(); |
886 | return llvm::isa<BFloat16Type, Float16Type, Float32Type, Float64Type, |
887 | Float80Type, Float128Type>(elementType); |
888 | } |
889 | return false; |
890 | } |
891 | |
892 | Type mlir::LLVM::getVectorElementType(Type type) { |
893 | return llvm::TypeSwitch<Type, Type>(type) |
894 | .Case<LLVMFixedVectorType, LLVMScalableVectorType, VectorType>( |
895 | [](auto ty) { return ty.getElementType(); }) |
896 | .Default([](Type) -> Type { |
897 | llvm_unreachable("incompatible with LLVM vector type" ); |
898 | }); |
899 | } |
900 | |
901 | llvm::ElementCount mlir::LLVM::getVectorNumElements(Type type) { |
902 | return llvm::TypeSwitch<Type, llvm::ElementCount>(type) |
903 | .Case(caseFn: [](VectorType ty) { |
904 | if (ty.isScalable()) |
905 | return llvm::ElementCount::getScalable(ty.getNumElements()); |
906 | return llvm::ElementCount::getFixed(ty.getNumElements()); |
907 | }) |
908 | .Case(caseFn: [](LLVMFixedVectorType ty) { |
909 | return llvm::ElementCount::getFixed(ty.getNumElements()); |
910 | }) |
911 | .Case(caseFn: [](LLVMScalableVectorType ty) { |
912 | return llvm::ElementCount::getScalable(ty.getMinNumElements()); |
913 | }) |
914 | .Default(defaultFn: [](Type) -> llvm::ElementCount { |
915 | llvm_unreachable("incompatible with LLVM vector type" ); |
916 | }); |
917 | } |
918 | |
919 | bool mlir::LLVM::isScalableVectorType(Type vectorType) { |
920 | assert((llvm::isa<LLVMFixedVectorType, LLVMScalableVectorType, VectorType>( |
921 | vectorType)) && |
922 | "expected LLVM-compatible vector type" ); |
923 | return !llvm::isa<LLVMFixedVectorType>(vectorType) && |
924 | (llvm::isa<LLVMScalableVectorType>(vectorType) || |
925 | llvm::cast<VectorType>(vectorType).isScalable()); |
926 | } |
927 | |
928 | Type mlir::LLVM::getVectorType(Type elementType, unsigned numElements, |
929 | bool isScalable) { |
930 | bool useLLVM = LLVMFixedVectorType::isValidElementType(elementType); |
931 | bool useBuiltIn = VectorType::isValidElementType(elementType); |
932 | (void)useBuiltIn; |
933 | assert((useLLVM ^ useBuiltIn) && "expected LLVM-compatible fixed-vector type " |
934 | "to be either builtin or LLVM dialect type" ); |
935 | if (useLLVM) { |
936 | if (isScalable) |
937 | return LLVMScalableVectorType::get(elementType, numElements); |
938 | return LLVMFixedVectorType::get(elementType, numElements); |
939 | } |
940 | |
941 | // LLVM vectors are always 1-D, hence only 1 bool is required to mark it as |
942 | // scalable/non-scalable. |
943 | return VectorType::get(numElements, elementType, {isScalable}); |
944 | } |
945 | |
946 | Type mlir::LLVM::getVectorType(Type elementType, |
947 | const llvm::ElementCount &numElements) { |
948 | if (numElements.isScalable()) |
949 | return getVectorType(elementType, numElements: numElements.getKnownMinValue(), |
950 | /*isScalable=*/true); |
951 | return getVectorType(elementType, numElements: numElements.getFixedValue(), |
952 | /*isScalable=*/false); |
953 | } |
954 | |
955 | Type mlir::LLVM::getFixedVectorType(Type elementType, unsigned numElements) { |
956 | bool useLLVM = LLVMFixedVectorType::isValidElementType(elementType); |
957 | bool useBuiltIn = VectorType::isValidElementType(elementType); |
958 | (void)useBuiltIn; |
959 | assert((useLLVM ^ useBuiltIn) && "expected LLVM-compatible fixed-vector type " |
960 | "to be either builtin or LLVM dialect type" ); |
961 | if (useLLVM) |
962 | return LLVMFixedVectorType::get(elementType, numElements); |
963 | return VectorType::get(numElements, elementType); |
964 | } |
965 | |
966 | Type mlir::LLVM::getScalableVectorType(Type elementType, unsigned numElements) { |
967 | bool useLLVM = LLVMScalableVectorType::isValidElementType(elementType); |
968 | bool useBuiltIn = VectorType::isValidElementType(elementType); |
969 | (void)useBuiltIn; |
970 | assert((useLLVM ^ useBuiltIn) && "expected LLVM-compatible scalable-vector " |
971 | "type to be either builtin or LLVM dialect " |
972 | "type" ); |
973 | if (useLLVM) |
974 | return LLVMScalableVectorType::get(elementType, numElements); |
975 | |
976 | // LLVM vectors are always 1-D, hence only 1 bool is required to mark it as |
977 | // scalable/non-scalable. |
978 | return VectorType::get(numElements, elementType, /*scalableDims=*/true); |
979 | } |
980 | |
981 | llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) { |
982 | assert(isCompatibleType(type) && |
983 | "expected a type compatible with the LLVM dialect" ); |
984 | |
985 | return llvm::TypeSwitch<Type, llvm::TypeSize>(type) |
986 | .Case<BFloat16Type, Float16Type>( |
987 | [](Type) { return llvm::TypeSize::getFixed(16); }) |
988 | .Case<Float32Type>([](Type) { return llvm::TypeSize::getFixed(32); }) |
989 | .Case<Float64Type, LLVMX86MMXType>( |
990 | [](Type) { return llvm::TypeSize::getFixed(64); }) |
991 | .Case<Float80Type>([](Type) { return llvm::TypeSize::getFixed(80); }) |
992 | .Case<Float128Type>([](Type) { return llvm::TypeSize::getFixed(128); }) |
993 | .Case<IntegerType>([](IntegerType intTy) { |
994 | return llvm::TypeSize::getFixed(intTy.getWidth()); |
995 | }) |
996 | .Case<LLVMPPCFP128Type>( |
997 | [](Type) { return llvm::TypeSize::getFixed(128); }) |
998 | .Case<LLVMFixedVectorType>([](LLVMFixedVectorType t) { |
999 | llvm::TypeSize elementSize = |
1000 | getPrimitiveTypeSizeInBits(t.getElementType()); |
1001 | return llvm::TypeSize(elementSize.getFixedValue() * t.getNumElements(), |
1002 | elementSize.isScalable()); |
1003 | }) |
1004 | .Case<VectorType>([](VectorType t) { |
1005 | assert(isCompatibleVectorType(t) && |
1006 | "unexpected incompatible with LLVM vector type" ); |
1007 | llvm::TypeSize elementSize = |
1008 | getPrimitiveTypeSizeInBits(t.getElementType()); |
1009 | return llvm::TypeSize(elementSize.getFixedValue() * t.getNumElements(), |
1010 | elementSize.isScalable()); |
1011 | }) |
1012 | .Default([](Type ty) { |
1013 | assert((llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType, |
1014 | LLVMTokenType, LLVMStructType, LLVMArrayType, |
1015 | LLVMPointerType, LLVMFunctionType, LLVMTargetExtType>( |
1016 | ty)) && |
1017 | "unexpected missing support for primitive type" ); |
1018 | return llvm::TypeSize::getFixed(0); |
1019 | }); |
1020 | } |
1021 | |
1022 | //===----------------------------------------------------------------------===// |
1023 | // LLVMDialect |
1024 | //===----------------------------------------------------------------------===// |
1025 | |
/// Registers the dialect's types by passing `addTypes` the type list that the
/// included LLVMTypes.cpp.inc file provides when `GET_TYPEDEF_LIST` is
/// defined.
void LLVMDialect::registerTypes() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/LLVMIR/LLVMTypes.cpp.inc"
      >();
}
1032 | |
1033 | Type LLVMDialect::parseType(DialectAsmParser &parser) const { |
1034 | return detail::parseType(parser); |
1035 | } |
1036 | |
1037 | void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const { |
1038 | return detail::printType(type, os); |
1039 | } |
1040 | |