1//===- LLVMTypes.cpp - MLIR LLVM dialect types ------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the types for the LLVM dialect in MLIR. These MLIR types
10// correspond to the LLVM IR type system.
11//
12//===----------------------------------------------------------------------===//
13
14#include "TypeDetail.h"
15
16#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
18#include "mlir/IR/BuiltinTypes.h"
19#include "mlir/IR/DialectImplementation.h"
20#include "mlir/IR/TypeSupport.h"
21
22#include "llvm/ADT/ScopeExit.h"
23#include "llvm/ADT/TypeSwitch.h"
24#include "llvm/Support/TypeSize.h"
25#include <optional>
26
27using namespace mlir;
28using namespace mlir::LLVM;
29
// Number of bits in a byte; used to convert between bit- and byte-based sizes
// and alignments.
static constexpr uint64_t kBitsInByte = 8;
31
32//===----------------------------------------------------------------------===//
33// custom<FunctionTypes>
34//===----------------------------------------------------------------------===//
35
36static ParseResult parseFunctionTypes(AsmParser &p, SmallVector<Type> &params,
37 bool &isVarArg) {
38 isVarArg = false;
39 // `(` `)`
40 if (succeeded(result: p.parseOptionalRParen()))
41 return success();
42
43 // `(` `...` `)`
44 if (succeeded(result: p.parseOptionalEllipsis())) {
45 isVarArg = true;
46 return p.parseRParen();
47 }
48
49 // type (`,` type)* (`,` `...`)?
50 Type type;
51 if (parsePrettyLLVMType(p, type))
52 return failure();
53 params.push_back(Elt: type);
54 while (succeeded(result: p.parseOptionalComma())) {
55 if (succeeded(result: p.parseOptionalEllipsis())) {
56 isVarArg = true;
57 return p.parseRParen();
58 }
59 if (parsePrettyLLVMType(p, type))
60 return failure();
61 params.push_back(Elt: type);
62 }
63 return p.parseRParen();
64}
65
66static void printFunctionTypes(AsmPrinter &p, ArrayRef<Type> params,
67 bool isVarArg) {
68 llvm::interleaveComma(c: params, os&: p,
69 each_fn: [&](Type type) { printPrettyLLVMType(p, type); });
70 if (isVarArg) {
71 if (!params.empty())
72 p << ", ";
73 p << "...";
74 }
75 p << ')';
76}
77
78//===----------------------------------------------------------------------===//
79// custom<ExtTypeParams>
80//===----------------------------------------------------------------------===//
81
82/// Parses the parameter list for a target extension type. The parameter list
83/// contains an optional list of type parameters, followed by an optional list
84/// of integer parameters. Type and integer parameters cannot be interleaved in
85/// the list.
86/// extTypeParams ::= typeList? | intList? | (typeList "," intList)
87/// typeList ::= type ("," type)*
88/// intList ::= integer ("," integer)*
89static ParseResult
90parseExtTypeParams(AsmParser &p, SmallVectorImpl<Type> &typeParams,
91 SmallVectorImpl<unsigned int> &intParams) {
92 bool parseType = true;
93 auto typeOrIntParser = [&]() -> ParseResult {
94 unsigned int i;
95 auto intResult = p.parseOptionalInteger(result&: i);
96 if (intResult.has_value() && !failed(result: *intResult)) {
97 // Successfully parsed an integer.
98 intParams.push_back(Elt: i);
99 // After the first integer was successfully parsed, no
100 // more types can be parsed.
101 parseType = false;
102 return success();
103 }
104 if (parseType) {
105 Type t;
106 if (!parsePrettyLLVMType(p, type&: t)) {
107 // Successfully parsed a type.
108 typeParams.push_back(Elt: t);
109 return success();
110 }
111 }
112 return failure();
113 };
114 if (p.parseCommaSeparatedList(parseElementFn: typeOrIntParser)) {
115 p.emitError(loc: p.getCurrentLocation(),
116 message: "failed to parse parameter list for target extension type");
117 return failure();
118 }
119 return success();
120}
121
122static void printExtTypeParams(AsmPrinter &p, ArrayRef<Type> typeParams,
123 ArrayRef<unsigned int> intParams) {
124 p << typeParams;
125 if (!typeParams.empty() && !intParams.empty())
126 p << ", ";
127
128 p << intParams;
129}
130
131//===----------------------------------------------------------------------===//
132// ODS-Generated Definitions
133//===----------------------------------------------------------------------===//
134
/// Forward declarations of the ODS-generated type parser and printer, whose
/// definitions presumably come from the generated `.inc` files included below
/// (NOTE(review): confirm). These are unused for now, so they are marked
/// LLVM_ATTRIBUTE_UNUSED to silence unused-function warnings.
/// TODO: Move over to these once more types have been migrated to TypeDef.
LLVM_ATTRIBUTE_UNUSED static OptionalParseResult
generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value);
LLVM_ATTRIBUTE_UNUSED static LogicalResult
generatedTypePrinter(Type def, AsmPrinter &printer);
141
142#include "mlir/Dialect/LLVMIR/LLVMTypeInterfaces.cpp.inc"
143
144#define GET_TYPEDEF_CLASSES
145#include "mlir/Dialect/LLVMIR/LLVMTypes.cpp.inc"
146
147//===----------------------------------------------------------------------===//
148// LLVMArrayType
149//===----------------------------------------------------------------------===//
150
151bool LLVMArrayType::isValidElementType(Type type) {
152 return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
153 LLVMFunctionType, LLVMTokenType, LLVMScalableVectorType>(
154 type);
155}
156
157LLVMArrayType LLVMArrayType::get(Type elementType, unsigned numElements) {
158 assert(elementType && "expected non-null subtype");
159 return Base::get(elementType.getContext(), elementType, numElements);
160}
161
162LLVMArrayType
163LLVMArrayType::getChecked(function_ref<InFlightDiagnostic()> emitError,
164 Type elementType, unsigned numElements) {
165 assert(elementType && "expected non-null subtype");
166 return Base::getChecked(emitError, elementType.getContext(), elementType,
167 numElements);
168}
169
170LogicalResult
171LLVMArrayType::verify(function_ref<InFlightDiagnostic()> emitError,
172 Type elementType, unsigned numElements) {
173 if (!isValidElementType(elementType))
174 return emitError() << "invalid array element type: " << elementType;
175 return success();
176}
177
178//===----------------------------------------------------------------------===//
179// DataLayoutTypeInterface
180
181llvm::TypeSize
182LLVMArrayType::getTypeSizeInBits(const DataLayout &dataLayout,
183 DataLayoutEntryListRef params) const {
184 return llvm::TypeSize::getFixed(kBitsInByte *
185 getTypeSize(dataLayout, params));
186}
187
188llvm::TypeSize LLVMArrayType::getTypeSize(const DataLayout &dataLayout,
189 DataLayoutEntryListRef params) const {
190 return llvm::alignTo(dataLayout.getTypeSize(getElementType()),
191 dataLayout.getTypeABIAlignment(getElementType())) *
192 getNumElements();
193}
194
195uint64_t LLVMArrayType::getABIAlignment(const DataLayout &dataLayout,
196 DataLayoutEntryListRef params) const {
197 return dataLayout.getTypeABIAlignment(getElementType());
198}
199
200uint64_t
201LLVMArrayType::getPreferredAlignment(const DataLayout &dataLayout,
202 DataLayoutEntryListRef params) const {
203 return dataLayout.getTypePreferredAlignment(getElementType());
204}
205
206//===----------------------------------------------------------------------===//
207// Function type.
208//===----------------------------------------------------------------------===//
209
210bool LLVMFunctionType::isValidArgumentType(Type type) {
211 return !llvm::isa<LLVMVoidType, LLVMFunctionType>(type);
212}
213
214bool LLVMFunctionType::isValidResultType(Type type) {
215 return !llvm::isa<LLVMFunctionType, LLVMMetadataType, LLVMLabelType>(type);
216}
217
218LLVMFunctionType LLVMFunctionType::get(Type result, ArrayRef<Type> arguments,
219 bool isVarArg) {
220 assert(result && "expected non-null result");
221 return Base::get(result.getContext(), result, arguments, isVarArg);
222}
223
224LLVMFunctionType
225LLVMFunctionType::getChecked(function_ref<InFlightDiagnostic()> emitError,
226 Type result, ArrayRef<Type> arguments,
227 bool isVarArg) {
228 assert(result && "expected non-null result");
229 return Base::getChecked(emitError, result.getContext(), result, arguments,
230 isVarArg);
231}
232
233LLVMFunctionType LLVMFunctionType::clone(TypeRange inputs,
234 TypeRange results) const {
235 assert(results.size() == 1 && "expected a single result type");
236 return get(results[0], llvm::to_vector(inputs), isVarArg());
237}
238
239ArrayRef<Type> LLVMFunctionType::getReturnTypes() const {
240 return static_cast<detail::LLVMFunctionTypeStorage *>(getImpl())->returnType;
241}
242
243LogicalResult
244LLVMFunctionType::verify(function_ref<InFlightDiagnostic()> emitError,
245 Type result, ArrayRef<Type> arguments, bool) {
246 if (!isValidResultType(result))
247 return emitError() << "invalid function result type: " << result;
248
249 for (Type arg : arguments)
250 if (!isValidArgumentType(arg))
251 return emitError() << "invalid function argument type: " << arg;
252
253 return success();
254}
255
256//===----------------------------------------------------------------------===//
257// DataLayoutTypeInterface
258
// Layout assumed for pointers in the default address space when the data
// layout provides no explicit entry: the size is in bits, the alignment in
// bytes.
static constexpr uint64_t kDefaultPointerSizeBits = 64;
static constexpr uint64_t kDefaultPointerAlignment = 8;
261
262std::optional<uint64_t> mlir::LLVM::extractPointerSpecValue(Attribute attr,
263 PtrDLEntryPos pos) {
264 auto spec = cast<DenseIntElementsAttr>(Val&: attr);
265 auto idx = static_cast<int64_t>(pos);
266 if (idx >= spec.size())
267 return std::nullopt;
268 return spec.getValues<uint64_t>()[idx];
269}
270
271/// Returns the part of the data layout entry that corresponds to `pos` for the
272/// given `type` by interpreting the list of entries `params`. For the pointer
273/// type in the default address space, returns the default value if the entries
274/// do not provide a custom one, for other address spaces returns std::nullopt.
275static std::optional<uint64_t>
276getPointerDataLayoutEntry(DataLayoutEntryListRef params, LLVMPointerType type,
277 PtrDLEntryPos pos) {
278 // First, look for the entry for the pointer in the current address space.
279 Attribute currentEntry;
280 for (DataLayoutEntryInterface entry : params) {
281 if (!entry.isTypeEntry())
282 continue;
283 if (cast<LLVMPointerType>(entry.getKey().get<Type>()).getAddressSpace() ==
284 type.getAddressSpace()) {
285 currentEntry = entry.getValue();
286 break;
287 }
288 }
289 if (currentEntry) {
290 std::optional<uint64_t> value = extractPointerSpecValue(attr: currentEntry, pos);
291 // If the optional `PtrDLEntryPos::Index` entry is not available, use the
292 // pointer size as the index bitwidth.
293 if (!value && pos == PtrDLEntryPos::Index)
294 value = extractPointerSpecValue(attr: currentEntry, pos: PtrDLEntryPos::Size);
295 bool isSizeOrIndex =
296 pos == PtrDLEntryPos::Size || pos == PtrDLEntryPos::Index;
297 return *value / (isSizeOrIndex ? 1 : kBitsInByte);
298 }
299
300 // If not found, and this is the pointer to the default memory space, assume
301 // 64-bit pointers.
302 if (type.getAddressSpace() == 0) {
303 bool isSizeOrIndex =
304 pos == PtrDLEntryPos::Size || pos == PtrDLEntryPos::Index;
305 return isSizeOrIndex ? kDefaultPointerSizeBits : kDefaultPointerAlignment;
306 }
307
308 return std::nullopt;
309}
310
311llvm::TypeSize
312LLVMPointerType::getTypeSizeInBits(const DataLayout &dataLayout,
313 DataLayoutEntryListRef params) const {
314 if (std::optional<uint64_t> size =
315 getPointerDataLayoutEntry(params, *this, PtrDLEntryPos::Size))
316 return llvm::TypeSize::getFixed(*size);
317
318 // For other memory spaces, use the size of the pointer to the default memory
319 // space.
320 return dataLayout.getTypeSizeInBits(get(getContext()));
321}
322
323uint64_t LLVMPointerType::getABIAlignment(const DataLayout &dataLayout,
324 DataLayoutEntryListRef params) const {
325 if (std::optional<uint64_t> alignment =
326 getPointerDataLayoutEntry(params, *this, PtrDLEntryPos::Abi))
327 return *alignment;
328
329 return dataLayout.getTypeABIAlignment(get(getContext()));
330}
331
332uint64_t
333LLVMPointerType::getPreferredAlignment(const DataLayout &dataLayout,
334 DataLayoutEntryListRef params) const {
335 if (std::optional<uint64_t> alignment =
336 getPointerDataLayoutEntry(params, *this, PtrDLEntryPos::Preferred))
337 return *alignment;
338
339 return dataLayout.getTypePreferredAlignment(get(getContext()));
340}
341
342std::optional<uint64_t>
343LLVMPointerType::getIndexBitwidth(const DataLayout &dataLayout,
344 DataLayoutEntryListRef params) const {
345 if (std::optional<uint64_t> indexBitwidth =
346 getPointerDataLayoutEntry(params, *this, PtrDLEntryPos::Index))
347 return *indexBitwidth;
348
349 return dataLayout.getTypeIndexBitwidth(get(getContext()));
350}
351
352bool LLVMPointerType::areCompatible(DataLayoutEntryListRef oldLayout,
353 DataLayoutEntryListRef newLayout) const {
354 for (DataLayoutEntryInterface newEntry : newLayout) {
355 if (!newEntry.isTypeEntry())
356 continue;
357 uint64_t size = kDefaultPointerSizeBits;
358 uint64_t abi = kDefaultPointerAlignment;
359 auto newType = llvm::cast<LLVMPointerType>(newEntry.getKey().get<Type>());
360 const auto *it =
361 llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
362 if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
363 return llvm::cast<LLVMPointerType>(type).getAddressSpace() ==
364 newType.getAddressSpace();
365 }
366 return false;
367 });
368 if (it == oldLayout.end()) {
369 llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
370 if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
371 return llvm::cast<LLVMPointerType>(type).getAddressSpace() == 0;
372 }
373 return false;
374 });
375 }
376 if (it != oldLayout.end()) {
377 size = *extractPointerSpecValue(*it, PtrDLEntryPos::Size);
378 abi = *extractPointerSpecValue(*it, PtrDLEntryPos::Abi);
379 }
380
381 Attribute newSpec = llvm::cast<DenseIntElementsAttr>(newEntry.getValue());
382 uint64_t newSize = *extractPointerSpecValue(newSpec, PtrDLEntryPos::Size);
383 uint64_t newAbi = *extractPointerSpecValue(newSpec, PtrDLEntryPos::Abi);
384 if (size != newSize || abi < newAbi || abi % newAbi != 0)
385 return false;
386 }
387 return true;
388}
389
390LogicalResult LLVMPointerType::verifyEntries(DataLayoutEntryListRef entries,
391 Location loc) const {
392 for (DataLayoutEntryInterface entry : entries) {
393 if (!entry.isTypeEntry())
394 continue;
395 auto key = entry.getKey().get<Type>();
396 auto values = llvm::dyn_cast<DenseIntElementsAttr>(entry.getValue());
397 if (!values || (values.size() != 3 && values.size() != 4)) {
398 return emitError(loc)
399 << "expected layout attribute for " << key
400 << " to be a dense integer elements attribute with 3 or 4 "
401 "elements";
402 }
403 if (!values.getElementType().isInteger(64))
404 return emitError(loc) << "expected i64 parameters for " << key;
405
406 if (extractPointerSpecValue(values, PtrDLEntryPos::Abi) >
407 extractPointerSpecValue(values, PtrDLEntryPos::Preferred)) {
408 return emitError(loc) << "preferred alignment is expected to be at least "
409 "as large as ABI alignment";
410 }
411 }
412 return success();
413}
414
415//===----------------------------------------------------------------------===//
416// Struct type.
417//===----------------------------------------------------------------------===//
418
419bool LLVMStructType::isValidElementType(Type type) {
420 return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
421 LLVMFunctionType, LLVMTokenType, LLVMScalableVectorType>(
422 type);
423}
424
425LLVMStructType LLVMStructType::getIdentified(MLIRContext *context,
426 StringRef name) {
427 return Base::get(context, name, /*opaque=*/false);
428}
429
430LLVMStructType LLVMStructType::getIdentifiedChecked(
431 function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
432 StringRef name) {
433 return Base::getChecked(emitError, context, name, /*opaque=*/false);
434}
435
436LLVMStructType LLVMStructType::getNewIdentified(MLIRContext *context,
437 StringRef name,
438 ArrayRef<Type> elements,
439 bool isPacked) {
440 std::string stringName = name.str();
441 unsigned counter = 0;
442 do {
443 auto type = LLVMStructType::getIdentified(context, name: stringName);
444 if (type.isInitialized() || failed(result: type.setBody(types: elements, isPacked))) {
445 counter += 1;
446 stringName = (Twine(name) + "." + std::to_string(val: counter)).str();
447 continue;
448 }
449 return type;
450 } while (true);
451}
452
453LLVMStructType LLVMStructType::getLiteral(MLIRContext *context,
454 ArrayRef<Type> types, bool isPacked) {
455 return Base::get(context, types, isPacked);
456}
457
458LLVMStructType
459LLVMStructType::getLiteralChecked(function_ref<InFlightDiagnostic()> emitError,
460 MLIRContext *context, ArrayRef<Type> types,
461 bool isPacked) {
462 return Base::getChecked(emitError, context, types, isPacked);
463}
464
465LLVMStructType LLVMStructType::getOpaque(StringRef name, MLIRContext *context) {
466 return Base::get(context, name, /*opaque=*/true);
467}
468
469LLVMStructType
470LLVMStructType::getOpaqueChecked(function_ref<InFlightDiagnostic()> emitError,
471 MLIRContext *context, StringRef name) {
472 return Base::getChecked(emitError, context, name, /*opaque=*/true);
473}
474
475LogicalResult LLVMStructType::setBody(ArrayRef<Type> types, bool isPacked) {
476 assert(isIdentified() && "can only set bodies of identified structs");
477 assert(llvm::all_of(types, LLVMStructType::isValidElementType) &&
478 "expected valid body types");
479 return Base::mutate(types, isPacked);
480}
481
482bool LLVMStructType::isPacked() const { return getImpl()->isPacked(); }
483bool LLVMStructType::isIdentified() const { return getImpl()->isIdentified(); }
484bool LLVMStructType::isOpaque() {
485 return getImpl()->isIdentified() &&
486 (getImpl()->isOpaque() || !getImpl()->isInitialized());
487}
488bool LLVMStructType::isInitialized() { return getImpl()->isInitialized(); }
489StringRef LLVMStructType::getName() { return getImpl()->getIdentifier(); }
490ArrayRef<Type> LLVMStructType::getBody() const {
491 return isIdentified() ? getImpl()->getIdentifiedStructBody()
492 : getImpl()->getTypeList();
493}
494
495LogicalResult LLVMStructType::verify(function_ref<InFlightDiagnostic()>,
496 StringRef, bool) {
497 return success();
498}
499
500LogicalResult
501LLVMStructType::verify(function_ref<InFlightDiagnostic()> emitError,
502 ArrayRef<Type> types, bool) {
503 for (Type t : types)
504 if (!isValidElementType(type: t))
505 return emitError() << "invalid LLVM structure element type: " << t;
506
507 return success();
508}
509
510llvm::TypeSize
511LLVMStructType::getTypeSizeInBits(const DataLayout &dataLayout,
512 DataLayoutEntryListRef params) const {
513 auto structSize = llvm::TypeSize::getFixed(ExactSize: 0);
514 uint64_t structAlignment = 1;
515 for (Type element : getBody()) {
516 uint64_t elementAlignment =
517 isPacked() ? 1 : dataLayout.getTypeABIAlignment(t: element);
518 // Add padding to the struct size to align it to the abi alignment of the
519 // element type before than adding the size of the element.
520 structSize = llvm::alignTo(Size: structSize, Align: elementAlignment);
521 structSize += dataLayout.getTypeSize(t: element);
522
523 // The alignment requirement of a struct is equal to the strictest alignment
524 // requirement of its elements.
525 structAlignment = std::max(a: elementAlignment, b: structAlignment);
526 }
527 // At the end, add padding to the struct to satisfy its own alignment
528 // requirement. Otherwise structs inside of arrays would be misaligned.
529 structSize = llvm::alignTo(Size: structSize, Align: structAlignment);
530 return structSize * kBitsInByte;
531}
532
namespace {
/// Positions of the values in a struct data layout entry. The ABI alignment
/// always comes first; the preferred alignment is optional.
enum class StructDLEntryPos {
  Abi = 0,
  Preferred = 1,
};
} // namespace
536
537static std::optional<uint64_t>
538getStructDataLayoutEntry(DataLayoutEntryListRef params, LLVMStructType type,
539 StructDLEntryPos pos) {
540 const auto *currentEntry =
541 llvm::find_if(Range&: params, P: [](DataLayoutEntryInterface entry) {
542 return entry.isTypeEntry();
543 });
544 if (currentEntry == params.end())
545 return std::nullopt;
546
547 auto attr = llvm::cast<DenseIntElementsAttr>(currentEntry->getValue());
548 if (pos == StructDLEntryPos::Preferred &&
549 attr.size() <= static_cast<int64_t>(StructDLEntryPos::Preferred))
550 // If no preferred was specified, fall back to abi alignment
551 pos = StructDLEntryPos::Abi;
552
553 return attr.getValues<uint64_t>()[static_cast<size_t>(pos)];
554}
555
556static uint64_t calculateStructAlignment(const DataLayout &dataLayout,
557 DataLayoutEntryListRef params,
558 LLVMStructType type,
559 StructDLEntryPos pos) {
560 // Packed structs always have an abi alignment of 1
561 if (pos == StructDLEntryPos::Abi && type.isPacked()) {
562 return 1;
563 }
564
565 // The alignment requirement of a struct is equal to the strictest alignment
566 // requirement of its elements.
567 uint64_t structAlignment = 1;
568 for (Type iter : type.getBody()) {
569 structAlignment =
570 std::max(a: dataLayout.getTypeABIAlignment(t: iter), b: structAlignment);
571 }
572
573 // Entries are only allowed to be stricter than the required alignment
574 if (std::optional<uint64_t> entryResult =
575 getStructDataLayoutEntry(params, type, pos))
576 return std::max(a: *entryResult / kBitsInByte, b: structAlignment);
577
578 return structAlignment;
579}
580
581uint64_t LLVMStructType::getABIAlignment(const DataLayout &dataLayout,
582 DataLayoutEntryListRef params) const {
583 return calculateStructAlignment(dataLayout, params, type: *this,
584 pos: StructDLEntryPos::Abi);
585}
586
587uint64_t
588LLVMStructType::getPreferredAlignment(const DataLayout &dataLayout,
589 DataLayoutEntryListRef params) const {
590 return calculateStructAlignment(dataLayout, params, type: *this,
591 pos: StructDLEntryPos::Preferred);
592}
593
594static uint64_t extractStructSpecValue(Attribute attr, StructDLEntryPos pos) {
595 return llvm::cast<DenseIntElementsAttr>(attr)
596 .getValues<uint64_t>()[static_cast<size_t>(pos)];
597}
598
599bool LLVMStructType::areCompatible(DataLayoutEntryListRef oldLayout,
600 DataLayoutEntryListRef newLayout) const {
601 for (DataLayoutEntryInterface newEntry : newLayout) {
602 if (!newEntry.isTypeEntry())
603 continue;
604
605 const auto *previousEntry =
606 llvm::find_if(oldLayout, [](DataLayoutEntryInterface entry) {
607 return entry.isTypeEntry();
608 });
609 if (previousEntry == oldLayout.end())
610 continue;
611
612 uint64_t abi = extractStructSpecValue(previousEntry->getValue(),
613 StructDLEntryPos::Abi);
614 uint64_t newAbi =
615 extractStructSpecValue(newEntry.getValue(), StructDLEntryPos::Abi);
616 if (abi < newAbi || abi % newAbi != 0)
617 return false;
618 }
619 return true;
620}
621
622LogicalResult LLVMStructType::verifyEntries(DataLayoutEntryListRef entries,
623 Location loc) const {
624 for (DataLayoutEntryInterface entry : entries) {
625 if (!entry.isTypeEntry())
626 continue;
627
628 auto key = llvm::cast<LLVMStructType>(entry.getKey().get<Type>());
629 auto values = llvm::dyn_cast<DenseIntElementsAttr>(entry.getValue());
630 if (!values || (values.size() != 2 && values.size() != 1)) {
631 return emitError(loc)
632 << "expected layout attribute for " << entry.getKey().get<Type>()
633 << " to be a dense integer elements attribute of 1 or 2 elements";
634 }
635 if (!values.getElementType().isInteger(64))
636 return emitError(loc) << "expected i64 entries for " << key;
637
638 if (key.isIdentified() || !key.getBody().empty()) {
639 return emitError(loc) << "unexpected layout attribute for struct " << key;
640 }
641
642 if (values.size() == 1)
643 continue;
644
645 if (extractStructSpecValue(values, StructDLEntryPos::Abi) >
646 extractStructSpecValue(values, StructDLEntryPos::Preferred)) {
647 return emitError(loc) << "preferred alignment is expected to be at least "
648 "as large as ABI alignment";
649 }
650 }
651 return mlir::success();
652}
653
654//===----------------------------------------------------------------------===//
655// Vector types.
656//===----------------------------------------------------------------------===//
657
658/// Verifies that the type about to be constructed is well-formed.
659template <typename VecTy>
660static LogicalResult
661verifyVectorConstructionInvariants(function_ref<InFlightDiagnostic()> emitError,
662 Type elementType, unsigned numElements) {
663 if (numElements == 0)
664 return emitError() << "the number of vector elements must be positive";
665
666 if (!VecTy::isValidElementType(elementType))
667 return emitError() << "invalid vector element type";
668
669 return success();
670}
671
672LLVMFixedVectorType LLVMFixedVectorType::get(Type elementType,
673 unsigned numElements) {
674 assert(elementType && "expected non-null subtype");
675 return Base::get(elementType.getContext(), elementType, numElements);
676}
677
678LLVMFixedVectorType
679LLVMFixedVectorType::getChecked(function_ref<InFlightDiagnostic()> emitError,
680 Type elementType, unsigned numElements) {
681 assert(elementType && "expected non-null subtype");
682 return Base::getChecked(emitError, elementType.getContext(), elementType,
683 numElements);
684}
685
686bool LLVMFixedVectorType::isValidElementType(Type type) {
687 return llvm::isa<LLVMPointerType, LLVMPPCFP128Type>(type);
688}
689
690LogicalResult
691LLVMFixedVectorType::verify(function_ref<InFlightDiagnostic()> emitError,
692 Type elementType, unsigned numElements) {
693 return verifyVectorConstructionInvariants<LLVMFixedVectorType>(
694 emitError, elementType, numElements);
695}
696
697//===----------------------------------------------------------------------===//
698// LLVMScalableVectorType.
699//===----------------------------------------------------------------------===//
700
701LLVMScalableVectorType LLVMScalableVectorType::get(Type elementType,
702 unsigned minNumElements) {
703 assert(elementType && "expected non-null subtype");
704 return Base::get(elementType.getContext(), elementType, minNumElements);
705}
706
707LLVMScalableVectorType
708LLVMScalableVectorType::getChecked(function_ref<InFlightDiagnostic()> emitError,
709 Type elementType, unsigned minNumElements) {
710 assert(elementType && "expected non-null subtype");
711 return Base::getChecked(emitError, elementType.getContext(), elementType,
712 minNumElements);
713}
714
715bool LLVMScalableVectorType::isValidElementType(Type type) {
716 if (auto intType = llvm::dyn_cast<IntegerType>(type))
717 return intType.isSignless();
718
719 return isCompatibleFloatingPointType(type) ||
720 llvm::isa<LLVMPointerType>(type);
721}
722
723LogicalResult
724LLVMScalableVectorType::verify(function_ref<InFlightDiagnostic()> emitError,
725 Type elementType, unsigned numElements) {
726 return verifyVectorConstructionInvariants<LLVMScalableVectorType>(
727 emitError, elementType, numElements);
728}
729
730//===----------------------------------------------------------------------===//
731// LLVMTargetExtType.
732//===----------------------------------------------------------------------===//
733
734static constexpr llvm::StringRef kSpirvPrefix = "spirv.";
735static constexpr llvm::StringRef kArmSVCount = "aarch64.svcount";
736
737bool LLVM::LLVMTargetExtType::hasProperty(Property prop) const {
738 // See llvm/lib/IR/Type.cpp for reference.
739 uint64_t properties = 0;
740
741 if (getExtTypeName().starts_with(kSpirvPrefix))
742 properties |=
743 (LLVMTargetExtType::HasZeroInit | LLVM::LLVMTargetExtType::CanBeGlobal);
744
745 return (properties & prop) == prop;
746}
747
748bool LLVM::LLVMTargetExtType::supportsMemOps() const {
749 // See llvm/lib/IR/Type.cpp for reference.
750 if (getExtTypeName().starts_with(kSpirvPrefix))
751 return true;
752
753 if (getExtTypeName() == kArmSVCount)
754 return true;
755
756 return false;
757}
758
759//===----------------------------------------------------------------------===//
760// Utility functions.
761//===----------------------------------------------------------------------===//
762
763bool mlir::LLVM::isCompatibleOuterType(Type type) {
764 // clang-format off
765 if (llvm::isa<
766 BFloat16Type,
767 Float16Type,
768 Float32Type,
769 Float64Type,
770 Float80Type,
771 Float128Type,
772 LLVMArrayType,
773 LLVMFunctionType,
774 LLVMLabelType,
775 LLVMMetadataType,
776 LLVMPPCFP128Type,
777 LLVMPointerType,
778 LLVMStructType,
779 LLVMTokenType,
780 LLVMFixedVectorType,
781 LLVMScalableVectorType,
782 LLVMTargetExtType,
783 LLVMVoidType,
784 LLVMX86MMXType
785 >(type)) {
786 // clang-format on
787 return true;
788 }
789
790 // Only signless integers are compatible.
791 if (auto intType = llvm::dyn_cast<IntegerType>(type))
792 return intType.isSignless();
793
794 // 1D vector types are compatible.
795 if (auto vecType = llvm::dyn_cast<VectorType>(type))
796 return vecType.getRank() == 1;
797
798 return false;
799}
800
801static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) {
802 if (!compatibleTypes.insert(V: type).second)
803 return true;
804
805 auto isCompatible = [&](Type type) {
806 return isCompatibleImpl(type, compatibleTypes);
807 };
808
809 bool result =
810 llvm::TypeSwitch<Type, bool>(type)
811 .Case<LLVMStructType>([&](auto structType) {
812 return llvm::all_of(structType.getBody(), isCompatible);
813 })
814 .Case<LLVMFunctionType>([&](auto funcType) {
815 return isCompatible(funcType.getReturnType()) &&
816 llvm::all_of(funcType.getParams(), isCompatible);
817 })
818 .Case<IntegerType>([](auto intType) { return intType.isSignless(); })
819 .Case<VectorType>([&](auto vecType) {
820 return vecType.getRank() == 1 &&
821 isCompatible(vecType.getElementType());
822 })
823 .Case<LLVMPointerType>([&](auto pointerType) { return true; })
824 .Case<LLVMTargetExtType>([&](auto extType) {
825 return llvm::all_of(extType.getTypeParams(), isCompatible);
826 })
827 // clang-format off
828 .Case<
829 LLVMFixedVectorType,
830 LLVMScalableVectorType,
831 LLVMArrayType
832 >([&](auto containerType) {
833 return isCompatible(containerType.getElementType());
834 })
835 .Case<
836 BFloat16Type,
837 Float16Type,
838 Float32Type,
839 Float64Type,
840 Float80Type,
841 Float128Type,
842 LLVMLabelType,
843 LLVMMetadataType,
844 LLVMPPCFP128Type,
845 LLVMTokenType,
846 LLVMVoidType,
847 LLVMX86MMXType
848 >([](Type) { return true; })
849 // clang-format on
850 .Default([](Type) { return false; });
851
852 if (!result)
853 compatibleTypes.erase(V: type);
854
855 return result;
856}
857
858bool LLVMDialect::isCompatibleType(Type type) {
859 if (auto *llvmDialect =
860 type.getContext()->getLoadedDialect<LLVM::LLVMDialect>())
861 return isCompatibleImpl(type, llvmDialect->compatibleTypes.get());
862
863 DenseSet<Type> localCompatibleTypes;
864 return isCompatibleImpl(type, localCompatibleTypes);
865}
866
867bool mlir::LLVM::isCompatibleType(Type type) {
868 return LLVMDialect::isCompatibleType(type);
869}
870
871bool mlir::LLVM::isCompatibleFloatingPointType(Type type) {
872 return llvm::isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
873 Float80Type, Float128Type, LLVMPPCFP128Type>(type);
874}
875
876bool mlir::LLVM::isCompatibleVectorType(Type type) {
877 if (llvm::isa<LLVMFixedVectorType, LLVMScalableVectorType>(type))
878 return true;
879
880 if (auto vecType = llvm::dyn_cast<VectorType>(type)) {
881 if (vecType.getRank() != 1)
882 return false;
883 Type elementType = vecType.getElementType();
884 if (auto intType = llvm::dyn_cast<IntegerType>(elementType))
885 return intType.isSignless();
886 return llvm::isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
887 Float80Type, Float128Type>(elementType);
888 }
889 return false;
890}
891
892Type mlir::LLVM::getVectorElementType(Type type) {
893 return llvm::TypeSwitch<Type, Type>(type)
894 .Case<LLVMFixedVectorType, LLVMScalableVectorType, VectorType>(
895 [](auto ty) { return ty.getElementType(); })
896 .Default([](Type) -> Type {
897 llvm_unreachable("incompatible with LLVM vector type");
898 });
899}
900
901llvm::ElementCount mlir::LLVM::getVectorNumElements(Type type) {
902 return llvm::TypeSwitch<Type, llvm::ElementCount>(type)
903 .Case(caseFn: [](VectorType ty) {
904 if (ty.isScalable())
905 return llvm::ElementCount::getScalable(ty.getNumElements());
906 return llvm::ElementCount::getFixed(ty.getNumElements());
907 })
908 .Case(caseFn: [](LLVMFixedVectorType ty) {
909 return llvm::ElementCount::getFixed(ty.getNumElements());
910 })
911 .Case(caseFn: [](LLVMScalableVectorType ty) {
912 return llvm::ElementCount::getScalable(ty.getMinNumElements());
913 })
914 .Default(defaultFn: [](Type) -> llvm::ElementCount {
915 llvm_unreachable("incompatible with LLVM vector type");
916 });
917}
918
919bool mlir::LLVM::isScalableVectorType(Type vectorType) {
920 assert((llvm::isa<LLVMFixedVectorType, LLVMScalableVectorType, VectorType>(
921 vectorType)) &&
922 "expected LLVM-compatible vector type");
923 return !llvm::isa<LLVMFixedVectorType>(vectorType) &&
924 (llvm::isa<LLVMScalableVectorType>(vectorType) ||
925 llvm::cast<VectorType>(vectorType).isScalable());
926}
927
928Type mlir::LLVM::getVectorType(Type elementType, unsigned numElements,
929 bool isScalable) {
930 bool useLLVM = LLVMFixedVectorType::isValidElementType(elementType);
931 bool useBuiltIn = VectorType::isValidElementType(elementType);
932 (void)useBuiltIn;
933 assert((useLLVM ^ useBuiltIn) && "expected LLVM-compatible fixed-vector type "
934 "to be either builtin or LLVM dialect type");
935 if (useLLVM) {
936 if (isScalable)
937 return LLVMScalableVectorType::get(elementType, numElements);
938 return LLVMFixedVectorType::get(elementType, numElements);
939 }
940
941 // LLVM vectors are always 1-D, hence only 1 bool is required to mark it as
942 // scalable/non-scalable.
943 return VectorType::get(numElements, elementType, {isScalable});
944}
945
946Type mlir::LLVM::getVectorType(Type elementType,
947 const llvm::ElementCount &numElements) {
948 if (numElements.isScalable())
949 return getVectorType(elementType, numElements: numElements.getKnownMinValue(),
950 /*isScalable=*/true);
951 return getVectorType(elementType, numElements: numElements.getFixedValue(),
952 /*isScalable=*/false);
953}
954
955Type mlir::LLVM::getFixedVectorType(Type elementType, unsigned numElements) {
956 bool useLLVM = LLVMFixedVectorType::isValidElementType(elementType);
957 bool useBuiltIn = VectorType::isValidElementType(elementType);
958 (void)useBuiltIn;
959 assert((useLLVM ^ useBuiltIn) && "expected LLVM-compatible fixed-vector type "
960 "to be either builtin or LLVM dialect type");
961 if (useLLVM)
962 return LLVMFixedVectorType::get(elementType, numElements);
963 return VectorType::get(numElements, elementType);
964}
965
966Type mlir::LLVM::getScalableVectorType(Type elementType, unsigned numElements) {
967 bool useLLVM = LLVMScalableVectorType::isValidElementType(elementType);
968 bool useBuiltIn = VectorType::isValidElementType(elementType);
969 (void)useBuiltIn;
970 assert((useLLVM ^ useBuiltIn) && "expected LLVM-compatible scalable-vector "
971 "type to be either builtin or LLVM dialect "
972 "type");
973 if (useLLVM)
974 return LLVMScalableVectorType::get(elementType, numElements);
975
976 // LLVM vectors are always 1-D, hence only 1 bool is required to mark it as
977 // scalable/non-scalable.
978 return VectorType::get(numElements, elementType, /*scalableDims=*/true);
979}
980
981llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) {
982 assert(isCompatibleType(type) &&
983 "expected a type compatible with the LLVM dialect");
984
985 return llvm::TypeSwitch<Type, llvm::TypeSize>(type)
986 .Case<BFloat16Type, Float16Type>(
987 [](Type) { return llvm::TypeSize::getFixed(16); })
988 .Case<Float32Type>([](Type) { return llvm::TypeSize::getFixed(32); })
989 .Case<Float64Type, LLVMX86MMXType>(
990 [](Type) { return llvm::TypeSize::getFixed(64); })
991 .Case<Float80Type>([](Type) { return llvm::TypeSize::getFixed(80); })
992 .Case<Float128Type>([](Type) { return llvm::TypeSize::getFixed(128); })
993 .Case<IntegerType>([](IntegerType intTy) {
994 return llvm::TypeSize::getFixed(intTy.getWidth());
995 })
996 .Case<LLVMPPCFP128Type>(
997 [](Type) { return llvm::TypeSize::getFixed(128); })
998 .Case<LLVMFixedVectorType>([](LLVMFixedVectorType t) {
999 llvm::TypeSize elementSize =
1000 getPrimitiveTypeSizeInBits(t.getElementType());
1001 return llvm::TypeSize(elementSize.getFixedValue() * t.getNumElements(),
1002 elementSize.isScalable());
1003 })
1004 .Case<VectorType>([](VectorType t) {
1005 assert(isCompatibleVectorType(t) &&
1006 "unexpected incompatible with LLVM vector type");
1007 llvm::TypeSize elementSize =
1008 getPrimitiveTypeSizeInBits(t.getElementType());
1009 return llvm::TypeSize(elementSize.getFixedValue() * t.getNumElements(),
1010 elementSize.isScalable());
1011 })
1012 .Default([](Type ty) {
1013 assert((llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
1014 LLVMTokenType, LLVMStructType, LLVMArrayType,
1015 LLVMPointerType, LLVMFunctionType, LLVMTargetExtType>(
1016 ty)) &&
1017 "unexpected missing support for primitive type");
1018 return llvm::TypeSize::getFixed(0);
1019 });
1020}
1021
1022//===----------------------------------------------------------------------===//
1023// LLVMDialect
1024//===----------------------------------------------------------------------===//
1025
/// Registers the TableGen-generated LLVM dialect type classes with this
/// dialect.
void LLVMDialect::registerTypes() {
  // With GET_TYPEDEF_LIST defined, the generated .inc file expands to the list
  // of generated type classes, which forms the template argument list of
  // addTypes.
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/LLVMIR/LLVMTypes.cpp.inc"
      >();
}
1032
1033Type LLVMDialect::parseType(DialectAsmParser &parser) const {
1034 return detail::parseType(parser);
1035}
1036
1037void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const {
1038 return detail::printType(type, os);
1039}
1040

source code of mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp