1//===- LLVMTypes.cpp - MLIR LLVM dialect types ------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the types for the LLVM dialect in MLIR. These MLIR types
10// correspond to the LLVM IR type system.
11//
12//===----------------------------------------------------------------------===//
13
14#include "TypeDetail.h"
15
16#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
18#include "mlir/IR/BuiltinTypes.h"
19#include "mlir/IR/DialectImplementation.h"
20#include "mlir/IR/TypeSupport.h"
21
22#include "llvm/ADT/TypeSwitch.h"
23#include "llvm/Support/TypeSize.h"
24#include <optional>
25
26using namespace mlir;
27using namespace mlir::LLVM;
28
/// Number of bits per byte, used to convert byte sizes/alignments to bits and
/// back.
static constexpr uint64_t kBitsInByte = 8;
30
31//===----------------------------------------------------------------------===//
32// custom<FunctionTypes>
33//===----------------------------------------------------------------------===//
34
35static ParseResult parseFunctionTypes(AsmParser &p, SmallVector<Type> &params,
36 bool &isVarArg) {
37 isVarArg = false;
38 // `(` `)`
39 if (succeeded(Result: p.parseOptionalRParen()))
40 return success();
41
42 // `(` `...` `)`
43 if (succeeded(Result: p.parseOptionalEllipsis())) {
44 isVarArg = true;
45 return p.parseRParen();
46 }
47
48 // type (`,` type)* (`,` `...`)?
49 Type type;
50 if (parsePrettyLLVMType(p, type))
51 return failure();
52 params.push_back(Elt: type);
53 while (succeeded(Result: p.parseOptionalComma())) {
54 if (succeeded(Result: p.parseOptionalEllipsis())) {
55 isVarArg = true;
56 return p.parseRParen();
57 }
58 if (parsePrettyLLVMType(p, type))
59 return failure();
60 params.push_back(Elt: type);
61 }
62 return p.parseRParen();
63}
64
65static void printFunctionTypes(AsmPrinter &p, ArrayRef<Type> params,
66 bool isVarArg) {
67 llvm::interleaveComma(c: params, os&: p,
68 each_fn: [&](Type type) { printPrettyLLVMType(p, type); });
69 if (isVarArg) {
70 if (!params.empty())
71 p << ", ";
72 p << "...";
73 }
74 p << ')';
75}
76
77//===----------------------------------------------------------------------===//
78// custom<ExtTypeParams>
79//===----------------------------------------------------------------------===//
80
81/// Parses the parameter list for a target extension type. The parameter list
82/// contains an optional list of type parameters, followed by an optional list
83/// of integer parameters. Type and integer parameters cannot be interleaved in
84/// the list.
85/// extTypeParams ::= typeList? | intList? | (typeList "," intList)
86/// typeList ::= type ("," type)*
87/// intList ::= integer ("," integer)*
88static ParseResult
89parseExtTypeParams(AsmParser &p, SmallVectorImpl<Type> &typeParams,
90 SmallVectorImpl<unsigned int> &intParams) {
91 bool parseType = true;
92 auto typeOrIntParser = [&]() -> ParseResult {
93 unsigned int i;
94 auto intResult = p.parseOptionalInteger(result&: i);
95 if (intResult.has_value() && !failed(Result: *intResult)) {
96 // Successfully parsed an integer.
97 intParams.push_back(Elt: i);
98 // After the first integer was successfully parsed, no
99 // more types can be parsed.
100 parseType = false;
101 return success();
102 }
103 if (parseType) {
104 Type t;
105 if (!parsePrettyLLVMType(p, type&: t)) {
106 // Successfully parsed a type.
107 typeParams.push_back(Elt: t);
108 return success();
109 }
110 }
111 return failure();
112 };
113 if (p.parseCommaSeparatedList(parseElementFn: typeOrIntParser)) {
114 p.emitError(loc: p.getCurrentLocation(),
115 message: "failed to parse parameter list for target extension type");
116 return failure();
117 }
118 return success();
119}
120
121static void printExtTypeParams(AsmPrinter &p, ArrayRef<Type> typeParams,
122 ArrayRef<unsigned int> intParams) {
123 p << typeParams;
124 if (!typeParams.empty() && !intParams.empty())
125 p << ", ";
126
127 p << intParams;
128}
129
130//===----------------------------------------------------------------------===//
131// ODS-Generated Definitions
132//===----------------------------------------------------------------------===//
133
/// These are unused for now.
/// TODO: Move over to these once more types have been migrated to TypeDef.
/// They are declared (and marked unused to silence warnings) because the
/// generated `.cpp.inc` below refers to them; the dialect currently parses and
/// prints types through `detail::parseType` / `detail::printType` instead.
LLVM_ATTRIBUTE_UNUSED static OptionalParseResult
generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value);
LLVM_ATTRIBUTE_UNUSED static LogicalResult
generatedTypePrinter(Type def, AsmPrinter &printer);

// ODS-generated definitions of the LLVM dialect type interfaces.
#include "mlir/Dialect/LLVMIR/LLVMTypeInterfaces.cpp.inc"

// ODS-generated class definitions for the types declared via TypeDef.
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/LLVMIR/LLVMTypes.cpp.inc"
145
146//===----------------------------------------------------------------------===//
147// LLVMArrayType
148//===----------------------------------------------------------------------===//
149
150bool LLVMArrayType::isValidElementType(Type type) {
151 return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
152 LLVMFunctionType, LLVMTokenType>(Val: type);
153}
154
155LLVMArrayType LLVMArrayType::get(Type elementType, uint64_t numElements) {
156 assert(elementType && "expected non-null subtype");
157 return Base::get(ctx: elementType.getContext(), args&: elementType, args&: numElements);
158}
159
160LLVMArrayType
161LLVMArrayType::getChecked(function_ref<InFlightDiagnostic()> emitError,
162 Type elementType, uint64_t numElements) {
163 assert(elementType && "expected non-null subtype");
164 return Base::getChecked(emitErrorFn: emitError, ctx: elementType.getContext(), args: elementType,
165 args: numElements);
166}
167
168LogicalResult
169LLVMArrayType::verify(function_ref<InFlightDiagnostic()> emitError,
170 Type elementType, uint64_t numElements) {
171 if (!isValidElementType(type: elementType))
172 return emitError() << "invalid array element type: " << elementType;
173 return success();
174}
175
176//===----------------------------------------------------------------------===//
177// DataLayoutTypeInterface
178//===----------------------------------------------------------------------===//
179
180llvm::TypeSize
181LLVMArrayType::getTypeSizeInBits(const DataLayout &dataLayout,
182 DataLayoutEntryListRef params) const {
183 return llvm::TypeSize::getFixed(ExactSize: kBitsInByte *
184 getTypeSize(dataLayout, params));
185}
186
187llvm::TypeSize LLVMArrayType::getTypeSize(const DataLayout &dataLayout,
188 DataLayoutEntryListRef params) const {
189 return llvm::alignTo(Size: dataLayout.getTypeSize(t: getElementType()),
190 Align: dataLayout.getTypeABIAlignment(t: getElementType())) *
191 getNumElements();
192}
193
194uint64_t LLVMArrayType::getABIAlignment(const DataLayout &dataLayout,
195 DataLayoutEntryListRef params) const {
196 return dataLayout.getTypeABIAlignment(t: getElementType());
197}
198
199uint64_t
200LLVMArrayType::getPreferredAlignment(const DataLayout &dataLayout,
201 DataLayoutEntryListRef params) const {
202 return dataLayout.getTypePreferredAlignment(t: getElementType());
203}
204
205//===----------------------------------------------------------------------===//
206// Function type.
207//===----------------------------------------------------------------------===//
208
209bool LLVMFunctionType::isValidArgumentType(Type type) {
210 return !llvm::isa<LLVMVoidType, LLVMFunctionType>(Val: type);
211}
212
213bool LLVMFunctionType::isValidResultType(Type type) {
214 return !llvm::isa<LLVMFunctionType, LLVMMetadataType, LLVMLabelType>(Val: type);
215}
216
217LLVMFunctionType LLVMFunctionType::get(Type result, ArrayRef<Type> arguments,
218 bool isVarArg) {
219 assert(result && "expected non-null result");
220 return Base::get(ctx: result.getContext(), args&: result, args&: arguments, args&: isVarArg);
221}
222
223LLVMFunctionType
224LLVMFunctionType::getChecked(function_ref<InFlightDiagnostic()> emitError,
225 Type result, ArrayRef<Type> arguments,
226 bool isVarArg) {
227 assert(result && "expected non-null result");
228 return Base::getChecked(emitErrorFn: emitError, ctx: result.getContext(), args: result, args: arguments,
229 args: isVarArg);
230}
231
232LLVMFunctionType LLVMFunctionType::clone(TypeRange inputs,
233 TypeRange results) const {
234 if (results.size() != 1 || !isValidResultType(type: results[0]))
235 return {};
236 if (!llvm::all_of(Range&: inputs, P: isValidArgumentType))
237 return {};
238 return get(result: results[0], arguments: llvm::to_vector(Range&: inputs), isVarArg: isVarArg());
239}
240
241ArrayRef<Type> LLVMFunctionType::getReturnTypes() const {
242 return static_cast<detail::LLVMFunctionTypeStorage *>(getImpl())->returnType;
243}
244
245LogicalResult
246LLVMFunctionType::verify(function_ref<InFlightDiagnostic()> emitError,
247 Type result, ArrayRef<Type> arguments, bool) {
248 if (!isValidResultType(type: result))
249 return emitError() << "invalid function result type: " << result;
250
251 for (Type arg : arguments)
252 if (!isValidArgumentType(type: arg))
253 return emitError() << "invalid function argument type: " << arg;
254
255 return success();
256}
257
258//===----------------------------------------------------------------------===//
259// DataLayoutTypeInterface
260//===----------------------------------------------------------------------===//
261
/// Pointer size, in bits, assumed for the default address space when the data
/// layout provides no entry for it.
static constexpr uint64_t kDefaultPointerSizeBits = 64;
/// Pointer alignment, in bytes, assumed for the default address space when the
/// data layout provides no entry for it.
static constexpr uint64_t kDefaultPointerAlignment = 8;
264
265std::optional<uint64_t> mlir::LLVM::extractPointerSpecValue(Attribute attr,
266 PtrDLEntryPos pos) {
267 auto spec = cast<DenseIntElementsAttr>(Val&: attr);
268 auto idx = static_cast<int64_t>(pos);
269 if (idx >= spec.size())
270 return std::nullopt;
271 return spec.getValues<uint64_t>()[idx];
272}
273
274/// Returns the part of the data layout entry that corresponds to `pos` for the
275/// given `type` by interpreting the list of entries `params`. For the pointer
276/// type in the default address space, returns the default value if the entries
277/// do not provide a custom one, for other address spaces returns std::nullopt.
278static std::optional<uint64_t>
279getPointerDataLayoutEntry(DataLayoutEntryListRef params, LLVMPointerType type,
280 PtrDLEntryPos pos) {
281 // First, look for the entry for the pointer in the current address space.
282 Attribute currentEntry;
283 for (DataLayoutEntryInterface entry : params) {
284 if (!entry.isTypeEntry())
285 continue;
286 if (cast<LLVMPointerType>(Val: cast<Type>(Val: entry.getKey())).getAddressSpace() ==
287 type.getAddressSpace()) {
288 currentEntry = entry.getValue();
289 break;
290 }
291 }
292 if (currentEntry) {
293 std::optional<uint64_t> value = extractPointerSpecValue(attr: currentEntry, pos);
294 // If the optional `PtrDLEntryPos::Index` entry is not available, use the
295 // pointer size as the index bitwidth.
296 if (!value && pos == PtrDLEntryPos::Index)
297 value = extractPointerSpecValue(attr: currentEntry, pos: PtrDLEntryPos::Size);
298 bool isSizeOrIndex =
299 pos == PtrDLEntryPos::Size || pos == PtrDLEntryPos::Index;
300 return *value / (isSizeOrIndex ? 1 : kBitsInByte);
301 }
302
303 // If not found, and this is the pointer to the default memory space, assume
304 // 64-bit pointers.
305 if (type.getAddressSpace() == 0) {
306 bool isSizeOrIndex =
307 pos == PtrDLEntryPos::Size || pos == PtrDLEntryPos::Index;
308 return isSizeOrIndex ? kDefaultPointerSizeBits : kDefaultPointerAlignment;
309 }
310
311 return std::nullopt;
312}
313
314llvm::TypeSize
315LLVMPointerType::getTypeSizeInBits(const DataLayout &dataLayout,
316 DataLayoutEntryListRef params) const {
317 if (std::optional<uint64_t> size =
318 getPointerDataLayoutEntry(params, type: *this, pos: PtrDLEntryPos::Size))
319 return llvm::TypeSize::getFixed(ExactSize: *size);
320
321 // For other memory spaces, use the size of the pointer to the default memory
322 // space.
323 return dataLayout.getTypeSizeInBits(t: get(context: getContext()));
324}
325
326uint64_t LLVMPointerType::getABIAlignment(const DataLayout &dataLayout,
327 DataLayoutEntryListRef params) const {
328 if (std::optional<uint64_t> alignment =
329 getPointerDataLayoutEntry(params, type: *this, pos: PtrDLEntryPos::Abi))
330 return *alignment;
331
332 return dataLayout.getTypeABIAlignment(t: get(context: getContext()));
333}
334
335uint64_t
336LLVMPointerType::getPreferredAlignment(const DataLayout &dataLayout,
337 DataLayoutEntryListRef params) const {
338 if (std::optional<uint64_t> alignment =
339 getPointerDataLayoutEntry(params, type: *this, pos: PtrDLEntryPos::Preferred))
340 return *alignment;
341
342 return dataLayout.getTypePreferredAlignment(t: get(context: getContext()));
343}
344
345std::optional<uint64_t>
346LLVMPointerType::getIndexBitwidth(const DataLayout &dataLayout,
347 DataLayoutEntryListRef params) const {
348 if (std::optional<uint64_t> indexBitwidth =
349 getPointerDataLayoutEntry(params, type: *this, pos: PtrDLEntryPos::Index))
350 return *indexBitwidth;
351
352 return dataLayout.getTypeIndexBitwidth(t: get(context: getContext()));
353}
354
355bool LLVMPointerType::areCompatible(
356 DataLayoutEntryListRef oldLayout, DataLayoutEntryListRef newLayout,
357 DataLayoutSpecInterface newSpec,
358 const DataLayoutIdentifiedEntryMap &map) const {
359 for (DataLayoutEntryInterface newEntry : newLayout) {
360 if (!newEntry.isTypeEntry())
361 continue;
362 uint64_t size = kDefaultPointerSizeBits;
363 uint64_t abi = kDefaultPointerAlignment;
364 auto newType =
365 llvm::cast<LLVMPointerType>(Val: llvm::cast<Type>(Val: newEntry.getKey()));
366 const auto *it =
367 llvm::find_if(Range&: oldLayout, P: [&](DataLayoutEntryInterface entry) {
368 if (auto type = llvm::dyn_cast_if_present<Type>(Val: entry.getKey())) {
369 return llvm::cast<LLVMPointerType>(Val&: type).getAddressSpace() ==
370 newType.getAddressSpace();
371 }
372 return false;
373 });
374 if (it == oldLayout.end()) {
375 llvm::find_if(Range&: oldLayout, P: [&](DataLayoutEntryInterface entry) {
376 if (auto type = llvm::dyn_cast_if_present<Type>(Val: entry.getKey())) {
377 return llvm::cast<LLVMPointerType>(Val&: type).getAddressSpace() == 0;
378 }
379 return false;
380 });
381 }
382 if (it != oldLayout.end()) {
383 size = *extractPointerSpecValue(attr: *it, pos: PtrDLEntryPos::Size);
384 abi = *extractPointerSpecValue(attr: *it, pos: PtrDLEntryPos::Abi);
385 }
386
387 Attribute newSpec = llvm::cast<DenseIntElementsAttr>(Val: newEntry.getValue());
388 uint64_t newSize = *extractPointerSpecValue(attr: newSpec, pos: PtrDLEntryPos::Size);
389 uint64_t newAbi = *extractPointerSpecValue(attr: newSpec, pos: PtrDLEntryPos::Abi);
390 if (size != newSize || abi < newAbi || abi % newAbi != 0)
391 return false;
392 }
393 return true;
394}
395
396LogicalResult LLVMPointerType::verifyEntries(DataLayoutEntryListRef entries,
397 Location loc) const {
398 for (DataLayoutEntryInterface entry : entries) {
399 if (!entry.isTypeEntry())
400 continue;
401 auto key = llvm::cast<Type>(Val: entry.getKey());
402 auto values = llvm::dyn_cast<DenseIntElementsAttr>(Val: entry.getValue());
403 if (!values || (values.size() != 3 && values.size() != 4)) {
404 return emitError(loc)
405 << "expected layout attribute for " << key
406 << " to be a dense integer elements attribute with 3 or 4 "
407 "elements";
408 }
409 if (!values.getElementType().isInteger(width: 64))
410 return emitError(loc) << "expected i64 parameters for " << key;
411
412 if (extractPointerSpecValue(attr: values, pos: PtrDLEntryPos::Abi) >
413 extractPointerSpecValue(attr: values, pos: PtrDLEntryPos::Preferred)) {
414 return emitError(loc) << "preferred alignment is expected to be at least "
415 "as large as ABI alignment";
416 }
417 }
418 return success();
419}
420
421//===----------------------------------------------------------------------===//
422// Struct type.
423//===----------------------------------------------------------------------===//
424
425bool LLVMStructType::isValidElementType(Type type) {
426 return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
427 LLVMFunctionType, LLVMTokenType>(Val: type);
428}
429
430LLVMStructType LLVMStructType::getIdentified(MLIRContext *context,
431 StringRef name) {
432 return Base::get(ctx: context, args&: name, /*opaque=*/args: false);
433}
434
435LLVMStructType LLVMStructType::getIdentifiedChecked(
436 function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
437 StringRef name) {
438 return Base::getChecked(emitErrorFn: emitError, ctx: context, args: name, /*opaque=*/args: false);
439}
440
441LLVMStructType LLVMStructType::getNewIdentified(MLIRContext *context,
442 StringRef name,
443 ArrayRef<Type> elements,
444 bool isPacked) {
445 std::string stringName = name.str();
446 unsigned counter = 0;
447 do {
448 auto type = LLVMStructType::getIdentified(context, name: stringName);
449 if (type.isInitialized() || failed(Result: type.setBody(types: elements, isPacked))) {
450 counter += 1;
451 stringName = (Twine(name) + "." + std::to_string(val: counter)).str();
452 continue;
453 }
454 return type;
455 } while (true);
456}
457
458LLVMStructType LLVMStructType::getLiteral(MLIRContext *context,
459 ArrayRef<Type> types, bool isPacked) {
460 return Base::get(ctx: context, args&: types, args&: isPacked);
461}
462
463LLVMStructType
464LLVMStructType::getLiteralChecked(function_ref<InFlightDiagnostic()> emitError,
465 MLIRContext *context, ArrayRef<Type> types,
466 bool isPacked) {
467 return Base::getChecked(emitErrorFn: emitError, ctx: context, args: types, args: isPacked);
468}
469
470LLVMStructType LLVMStructType::getOpaque(StringRef name, MLIRContext *context) {
471 return Base::get(ctx: context, args&: name, /*opaque=*/args: true);
472}
473
474LLVMStructType
475LLVMStructType::getOpaqueChecked(function_ref<InFlightDiagnostic()> emitError,
476 MLIRContext *context, StringRef name) {
477 return Base::getChecked(emitErrorFn: emitError, ctx: context, args: name, /*opaque=*/args: true);
478}
479
480LogicalResult LLVMStructType::setBody(ArrayRef<Type> types, bool isPacked) {
481 assert(isIdentified() && "can only set bodies of identified structs");
482 assert(llvm::all_of(types, LLVMStructType::isValidElementType) &&
483 "expected valid body types");
484 return Base::mutate(args&: types, args&: isPacked);
485}
486
487bool LLVMStructType::isPacked() const { return getImpl()->isPacked(); }
488bool LLVMStructType::isIdentified() const { return getImpl()->isIdentified(); }
489bool LLVMStructType::isOpaque() const {
490 return getImpl()->isIdentified() &&
491 (getImpl()->isOpaque() || !getImpl()->isInitialized());
492}
493bool LLVMStructType::isInitialized() { return getImpl()->isInitialized(); }
494StringRef LLVMStructType::getName() const { return getImpl()->getIdentifier(); }
495ArrayRef<Type> LLVMStructType::getBody() const {
496 return isIdentified() ? getImpl()->getIdentifiedStructBody()
497 : getImpl()->getTypeList();
498}
499
500LogicalResult
501LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()>, StringRef,
502 bool) {
503 return success();
504}
505
506LogicalResult
507LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
508 ArrayRef<Type> types, bool) {
509 for (Type t : types)
510 if (!isValidElementType(type: t))
511 return emitError() << "invalid LLVM structure element type: " << t;
512
513 return success();
514}
515
516llvm::TypeSize
517LLVMStructType::getTypeSizeInBits(const DataLayout &dataLayout,
518 DataLayoutEntryListRef params) const {
519 auto structSize = llvm::TypeSize::getFixed(ExactSize: 0);
520 uint64_t structAlignment = 1;
521 for (Type element : getBody()) {
522 uint64_t elementAlignment =
523 isPacked() ? 1 : dataLayout.getTypeABIAlignment(t: element);
524 // Add padding to the struct size to align it to the abi alignment of the
525 // element type before than adding the size of the element.
526 structSize = llvm::alignTo(Size: structSize, Align: elementAlignment);
527 structSize += dataLayout.getTypeSize(t: element);
528
529 // The alignment requirement of a struct is equal to the strictest alignment
530 // requirement of its elements.
531 structAlignment = std::max(a: elementAlignment, b: structAlignment);
532 }
533 // At the end, add padding to the struct to satisfy its own alignment
534 // requirement. Otherwise structs inside of arrays would be misaligned.
535 structSize = llvm::alignTo(Size: structSize, Align: structAlignment);
536 return structSize * kBitsInByte;
537}
538
namespace {
/// Position of each value in a struct data layout entry: the ABI alignment
/// first and the (optional) preferred alignment second.
enum class StructDLEntryPos {
  Abi = 0,
  Preferred = 1,
};
} // namespace
542
543static std::optional<uint64_t>
544getStructDataLayoutEntry(DataLayoutEntryListRef params, LLVMStructType type,
545 StructDLEntryPos pos) {
546 const auto *currentEntry =
547 llvm::find_if(Range&: params, P: [](DataLayoutEntryInterface entry) {
548 return entry.isTypeEntry();
549 });
550 if (currentEntry == params.end())
551 return std::nullopt;
552
553 auto attr = llvm::cast<DenseIntElementsAttr>(Val: currentEntry->getValue());
554 if (pos == StructDLEntryPos::Preferred &&
555 attr.size() <= static_cast<int64_t>(StructDLEntryPos::Preferred))
556 // If no preferred was specified, fall back to abi alignment
557 pos = StructDLEntryPos::Abi;
558
559 return attr.getValues<uint64_t>()[static_cast<size_t>(pos)];
560}
561
562static uint64_t calculateStructAlignment(const DataLayout &dataLayout,
563 DataLayoutEntryListRef params,
564 LLVMStructType type,
565 StructDLEntryPos pos) {
566 // Packed structs always have an abi alignment of 1
567 if (pos == StructDLEntryPos::Abi && type.isPacked()) {
568 return 1;
569 }
570
571 // The alignment requirement of a struct is equal to the strictest alignment
572 // requirement of its elements.
573 uint64_t structAlignment = 1;
574 for (Type iter : type.getBody()) {
575 structAlignment =
576 std::max(a: dataLayout.getTypeABIAlignment(t: iter), b: structAlignment);
577 }
578
579 // Entries are only allowed to be stricter than the required alignment
580 if (std::optional<uint64_t> entryResult =
581 getStructDataLayoutEntry(params, type, pos))
582 return std::max(a: *entryResult / kBitsInByte, b: structAlignment);
583
584 return structAlignment;
585}
586
587uint64_t LLVMStructType::getABIAlignment(const DataLayout &dataLayout,
588 DataLayoutEntryListRef params) const {
589 return calculateStructAlignment(dataLayout, params, type: *this,
590 pos: StructDLEntryPos::Abi);
591}
592
593uint64_t
594LLVMStructType::getPreferredAlignment(const DataLayout &dataLayout,
595 DataLayoutEntryListRef params) const {
596 return calculateStructAlignment(dataLayout, params, type: *this,
597 pos: StructDLEntryPos::Preferred);
598}
599
600static uint64_t extractStructSpecValue(Attribute attr, StructDLEntryPos pos) {
601 return llvm::cast<DenseIntElementsAttr>(Val&: attr)
602 .getValues<uint64_t>()[static_cast<size_t>(pos)];
603}
604
605bool LLVMStructType::areCompatible(
606 DataLayoutEntryListRef oldLayout, DataLayoutEntryListRef newLayout,
607 DataLayoutSpecInterface newSpec,
608 const DataLayoutIdentifiedEntryMap &map) const {
609 for (DataLayoutEntryInterface newEntry : newLayout) {
610 if (!newEntry.isTypeEntry())
611 continue;
612
613 const auto *previousEntry =
614 llvm::find_if(Range&: oldLayout, P: [](DataLayoutEntryInterface entry) {
615 return entry.isTypeEntry();
616 });
617 if (previousEntry == oldLayout.end())
618 continue;
619
620 uint64_t abi = extractStructSpecValue(attr: previousEntry->getValue(),
621 pos: StructDLEntryPos::Abi);
622 uint64_t newAbi =
623 extractStructSpecValue(attr: newEntry.getValue(), pos: StructDLEntryPos::Abi);
624 if (abi < newAbi || abi % newAbi != 0)
625 return false;
626 }
627 return true;
628}
629
630LogicalResult LLVMStructType::verifyEntries(DataLayoutEntryListRef entries,
631 Location loc) const {
632 for (DataLayoutEntryInterface entry : entries) {
633 if (!entry.isTypeEntry())
634 continue;
635
636 auto key = llvm::cast<LLVMStructType>(Val: llvm::cast<Type>(Val: entry.getKey()));
637 auto values = llvm::dyn_cast<DenseIntElementsAttr>(Val: entry.getValue());
638 if (!values || (values.size() != 2 && values.size() != 1)) {
639 return emitError(loc)
640 << "expected layout attribute for "
641 << llvm::cast<Type>(Val: entry.getKey())
642 << " to be a dense integer elements attribute of 1 or 2 elements";
643 }
644 if (!values.getElementType().isInteger(width: 64))
645 return emitError(loc) << "expected i64 entries for " << key;
646
647 if (key.isIdentified() || !key.getBody().empty()) {
648 return emitError(loc) << "unexpected layout attribute for struct " << key;
649 }
650
651 if (values.size() == 1)
652 continue;
653
654 if (extractStructSpecValue(attr: values, pos: StructDLEntryPos::Abi) >
655 extractStructSpecValue(attr: values, pos: StructDLEntryPos::Preferred)) {
656 return emitError(loc) << "preferred alignment is expected to be at least "
657 "as large as ABI alignment";
658 }
659 }
660 return mlir::success();
661}
662
663//===----------------------------------------------------------------------===//
664// LLVMTargetExtType.
665//===----------------------------------------------------------------------===//
666
667static constexpr llvm::StringRef kSpirvPrefix = "spirv.";
668static constexpr llvm::StringRef kArmSVCount = "aarch64.svcount";
669
670bool LLVM::LLVMTargetExtType::hasProperty(Property prop) const {
671 // See llvm/lib/IR/Type.cpp for reference.
672 uint64_t properties = 0;
673
674 if (getExtTypeName().starts_with(Prefix: kSpirvPrefix))
675 properties |=
676 (LLVMTargetExtType::HasZeroInit | LLVM::LLVMTargetExtType::CanBeGlobal);
677
678 return (properties & prop) == prop;
679}
680
681bool LLVM::LLVMTargetExtType::supportsMemOps() const {
682 // See llvm/lib/IR/Type.cpp for reference.
683 if (getExtTypeName().starts_with(Prefix: kSpirvPrefix))
684 return true;
685
686 if (getExtTypeName() == kArmSVCount)
687 return true;
688
689 return false;
690}
691
692//===----------------------------------------------------------------------===//
693// LLVMPPCFP128Type
694//===----------------------------------------------------------------------===//
695
696const llvm::fltSemantics &LLVMPPCFP128Type::getFloatSemantics() const {
697 return APFloat::PPCDoubleDouble();
698}
699
700//===----------------------------------------------------------------------===//
701// Utility functions.
702//===----------------------------------------------------------------------===//
703
704bool mlir::LLVM::isCompatibleOuterType(Type type) {
705 // clang-format off
706 if (llvm::isa<
707 BFloat16Type,
708 Float16Type,
709 Float32Type,
710 Float64Type,
711 Float80Type,
712 Float128Type,
713 LLVMArrayType,
714 LLVMFunctionType,
715 LLVMLabelType,
716 LLVMMetadataType,
717 LLVMPPCFP128Type,
718 LLVMPointerType,
719 LLVMStructType,
720 LLVMTokenType,
721 LLVMTargetExtType,
722 LLVMVoidType,
723 LLVMX86AMXType
724 >(Val: type)) {
725 // clang-format on
726 return true;
727 }
728
729 // Only signless integers are compatible.
730 if (auto intType = llvm::dyn_cast<IntegerType>(Val&: type))
731 return intType.isSignless();
732
733 // 1D vector types are compatible.
734 if (auto vecType = llvm::dyn_cast<VectorType>(Val&: type))
735 return vecType.getRank() == 1;
736
737 return false;
738}
739
740static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) {
741 if (!compatibleTypes.insert(V: type).second)
742 return true;
743
744 auto isCompatible = [&](Type type) {
745 return isCompatibleImpl(type, compatibleTypes);
746 };
747
748 bool result =
749 llvm::TypeSwitch<Type, bool>(type)
750 .Case<LLVMStructType>(caseFn: [&](auto structType) {
751 return llvm::all_of(structType.getBody(), isCompatible);
752 })
753 .Case<LLVMFunctionType>(caseFn: [&](auto funcType) {
754 return isCompatible(funcType.getReturnType()) &&
755 llvm::all_of(funcType.getParams(), isCompatible);
756 })
757 .Case<IntegerType>(caseFn: [](auto intType) { return intType.isSignless(); })
758 .Case<VectorType>(caseFn: [&](auto vecType) {
759 return vecType.getRank() == 1 &&
760 isCompatible(vecType.getElementType());
761 })
762 .Case<LLVMPointerType>(caseFn: [&](auto pointerType) { return true; })
763 .Case<LLVMTargetExtType>(caseFn: [&](auto extType) {
764 return llvm::all_of(extType.getTypeParams(), isCompatible);
765 })
766 // clang-format off
767 .Case<
768 LLVMArrayType
769 >(caseFn: [&](auto containerType) {
770 return isCompatible(containerType.getElementType());
771 })
772 .Case<
773 BFloat16Type,
774 Float16Type,
775 Float32Type,
776 Float64Type,
777 Float80Type,
778 Float128Type,
779 LLVMLabelType,
780 LLVMMetadataType,
781 LLVMPPCFP128Type,
782 LLVMTokenType,
783 LLVMVoidType,
784 LLVMX86AMXType
785 >(caseFn: [](Type) { return true; })
786 // clang-format on
787 .Default(defaultFn: [](Type) { return false; });
788
789 if (!result)
790 compatibleTypes.erase(V: type);
791
792 return result;
793}
794
795bool LLVMDialect::isCompatibleType(Type type) {
796 if (auto *llvmDialect =
797 type.getContext()->getLoadedDialect<LLVM::LLVMDialect>())
798 return isCompatibleImpl(type, compatibleTypes&: llvmDialect->compatibleTypes.get());
799
800 DenseSet<Type> localCompatibleTypes;
801 return isCompatibleImpl(type, compatibleTypes&: localCompatibleTypes);
802}
803
804bool mlir::LLVM::isCompatibleType(Type type) {
805 return LLVMDialect::isCompatibleType(type);
806}
807
808bool mlir::LLVM::isCompatibleFloatingPointType(Type type) {
809 return llvm::isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
810 Float80Type, Float128Type, LLVMPPCFP128Type>(Val: type);
811}
812
813bool mlir::LLVM::isCompatibleVectorType(Type type) {
814 if (auto vecType = llvm::dyn_cast<VectorType>(Val&: type)) {
815 if (vecType.getRank() != 1)
816 return false;
817 Type elementType = vecType.getElementType();
818 if (auto intType = llvm::dyn_cast<IntegerType>(Val&: elementType))
819 return intType.isSignless();
820 return llvm::isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
821 Float80Type, Float128Type, LLVMPointerType>(Val: elementType);
822 }
823 return false;
824}
825
826llvm::ElementCount mlir::LLVM::getVectorNumElements(Type type) {
827 auto vecTy = dyn_cast<VectorType>(Val&: type);
828 assert(vecTy && "incompatible with LLVM vector type");
829 if (vecTy.isScalable())
830 return llvm::ElementCount::getScalable(MinVal: vecTy.getNumElements());
831 return llvm::ElementCount::getFixed(MinVal: vecTy.getNumElements());
832}
833
834bool mlir::LLVM::isScalableVectorType(Type vectorType) {
835 assert(llvm::isa<VectorType>(vectorType) &&
836 "expected LLVM-compatible vector type");
837 return llvm::cast<VectorType>(Val&: vectorType).isScalable();
838}
839
840Type mlir::LLVM::getVectorType(Type elementType, unsigned numElements,
841 bool isScalable) {
842 assert(VectorType::isValidElementType(elementType) &&
843 "incompatible element type");
844 return VectorType::get(shape: numElements, elementType, scalableDims: {isScalable});
845}
846
847Type mlir::LLVM::getVectorType(Type elementType,
848 const llvm::ElementCount &numElements) {
849 if (numElements.isScalable())
850 return getVectorType(elementType, numElements: numElements.getKnownMinValue(),
851 /*isScalable=*/true);
852 return getVectorType(elementType, numElements: numElements.getFixedValue(),
853 /*isScalable=*/false);
854}
855
856llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) {
857 assert(isCompatibleType(type) &&
858 "expected a type compatible with the LLVM dialect");
859
860 return llvm::TypeSwitch<Type, llvm::TypeSize>(type)
861 .Case<BFloat16Type, Float16Type>(
862 caseFn: [](Type) { return llvm::TypeSize::getFixed(ExactSize: 16); })
863 .Case<Float32Type>(caseFn: [](Type) { return llvm::TypeSize::getFixed(ExactSize: 32); })
864 .Case<Float64Type>(caseFn: [](Type) { return llvm::TypeSize::getFixed(ExactSize: 64); })
865 .Case<Float80Type>(caseFn: [](Type) { return llvm::TypeSize::getFixed(ExactSize: 80); })
866 .Case<Float128Type>(caseFn: [](Type) { return llvm::TypeSize::getFixed(ExactSize: 128); })
867 .Case<IntegerType>(caseFn: [](IntegerType intTy) {
868 return llvm::TypeSize::getFixed(ExactSize: intTy.getWidth());
869 })
870 .Case<LLVMPPCFP128Type>(
871 caseFn: [](Type) { return llvm::TypeSize::getFixed(ExactSize: 128); })
872 .Case<VectorType>(caseFn: [](VectorType t) {
873 assert(isCompatibleVectorType(t) &&
874 "unexpected incompatible with LLVM vector type");
875 llvm::TypeSize elementSize =
876 getPrimitiveTypeSizeInBits(type: t.getElementType());
877 return llvm::TypeSize(elementSize.getFixedValue() * t.getNumElements(),
878 elementSize.isScalable());
879 })
880 .Default(defaultFn: [](Type ty) {
881 assert((llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
882 LLVMTokenType, LLVMStructType, LLVMArrayType,
883 LLVMPointerType, LLVMFunctionType, LLVMTargetExtType>(
884 ty)) &&
885 "unexpected missing support for primitive type");
886 return llvm::TypeSize::getFixed(ExactSize: 0);
887 });
888}
889
890//===----------------------------------------------------------------------===//
891// LLVMDialect
892//===----------------------------------------------------------------------===//
893
/// Registers with the dialect every type listed in the ODS-generated
/// `GET_TYPEDEF_LIST` section of LLVMTypes.cpp.inc.
void LLVMDialect::registerTypes() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/LLVMIR/LLVMTypes.cpp.inc"
      >();
}
900
901Type LLVMDialect::parseType(DialectAsmParser &parser) const {
902 return detail::parseType(parser);
903}
904
905void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const {
906 return detail::printType(type, printer&: os);
907}
908

source code of mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp