1//===- Types.h - MLIR Type Classes ------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_IR_TYPES_H
10#define MLIR_IR_TYPES_H
11
12#include "mlir/IR/TypeSupport.h"
13#include "llvm/ADT/ArrayRef.h"
14#include "llvm/ADT/DenseMapInfo.h"
15#include "llvm/Support/PointerLikeTypeTraits.h"
16
17namespace mlir {
18class AsmState;
19
20/// Instances of the Type class are uniqued, have an immutable identifier and an
21/// optional mutable component. They wrap a pointer to the storage object owned
22/// by MLIRContext. Therefore, instances of Type are passed around by value.
23///
24/// Some types are "primitives" meaning they do not have any parameters, for
25/// example the Index type. Parametric types have additional information that
26/// differentiates the types of the same class, for example the Integer type has
/// bitwidth, making i8 and i16 belong to the same kind but be different
28/// instances of the IntegerType. Type parameters are part of the unique
29/// immutable key. The mutable component of the type can be modified after the
30/// type is created, but cannot affect the identity of the type.
31///
32/// Types are constructed and uniqued via the 'detail::TypeUniquer' class.
33///
/// Derived type classes are expected to implement several required
/// implementation hooks:
///  * Optional:
///    - static LogicalResult verify(
///                                function_ref<InFlightDiagnostic()> emitError,
///                                Args... args)
///      * This method is invoked when calling the 'TypeBase::get/getChecked'
///        methods to ensure that the arguments passed in are valid to construct
///        a type instance with.
///      * This method is expected to return failure if a type cannot be
///        constructed with 'args', success otherwise.
///      * 'args' must correspond with the arguments passed into the
///        'TypeBase::get' call.
47///
48///
/// Type storage objects inherit from TypeStorage and contain the following:
///    - The dialect that defined the type.
///    - Any parameters of the type.
///    - An optional mutable component.
/// For non-parametric types, a convenience DefaultTypeStorage is provided.
/// Parametric storage types must derive TypeStorage and respect the following:
///    - Define a type alias, KeyTy, to a type that uniquely identifies the
///      instance of the type.
///      * The key type must be constructible from the values passed into the
///        detail::TypeUniquer::get call.
///      * If the KeyTy does not have an llvm::DenseMapInfo specialization, the
///        storage class must define a hashing method:
///         'static unsigned hashKey(const KeyTy &)'
///
///    - Provide a method, 'bool operator==(const KeyTy &) const', to
///      compare the storage instance against an instance of the key type.
///
///    - Provide a static construction method:
///        'DerivedStorage *construct(TypeStorageAllocator &, const KeyTy &key)'
///      that builds a unique instance of the derived storage. The arguments to
///      this function are an allocator to store any uniqued data within the
///      context and the key type for this storage.
///
///    - If they have a mutable component, this component must not be a part of
///      the key.
74class Type {
75public:
76 /// Utility class for implementing types.
77 template <typename ConcreteType, typename BaseType, typename StorageType,
78 template <typename T> class... Traits>
79 using TypeBase = detail::StorageUserBase<ConcreteType, BaseType, StorageType,
80 detail::TypeUniquer, Traits...>;
81
82 using ImplType = TypeStorage;
83
84 using AbstractTy = AbstractType;
85
86 constexpr Type() = default;
87 /* implicit */ Type(const ImplType *impl)
88 : impl(const_cast<ImplType *>(impl)) {}
89
90 Type(const Type &other) = default;
91 Type &operator=(const Type &other) = default;
92
93 bool operator==(Type other) const { return impl == other.impl; }
94 bool operator!=(Type other) const { return !(*this == other); }
95 explicit operator bool() const { return impl; }
96
97 bool operator!() const { return impl == nullptr; }
98
99 template <typename... Tys>
100 bool isa() const;
101 template <typename... Tys>
102 bool isa_and_nonnull() const;
103 template <typename U>
104 U dyn_cast() const;
105 template <typename U>
106 U dyn_cast_or_null() const;
107 template <typename U>
108 U cast() const;
109
110 /// Return a unique identifier for the concrete type. This is used to support
111 /// dynamic type casting.
112 TypeID getTypeID() { return impl->getAbstractType().getTypeID(); }
113
114 /// Return the MLIRContext in which this type was uniqued.
115 MLIRContext *getContext() const;
116
117 /// Get the dialect this type is registered to.
118 Dialect &getDialect() const { return impl->getAbstractType().getDialect(); }
119
120 // Convenience predicates. This is only for floating point types,
121 // derived types should use isa/dyn_cast.
122 bool isIndex() const;
123 bool isFloat8E5M2() const;
124 bool isFloat8E4M3FN() const;
125 bool isFloat8E5M2FNUZ() const;
126 bool isFloat8E4M3FNUZ() const;
127 bool isFloat8E4M3B11FNUZ() const;
128 bool isBF16() const;
129 bool isF16() const;
130 bool isTF32() const;
131 bool isF32() const;
132 bool isF64() const;
133 bool isF80() const;
134 bool isF128() const;
135
136 /// Return true if this is an integer type (with the specified width).
137 bool isInteger() const;
138 bool isInteger(unsigned width) const;
139 /// Return true if this is a signless integer type (with the specified width).
140 bool isSignlessInteger() const;
141 bool isSignlessInteger(unsigned width) const;
142 /// Return true if this is a signed integer type (with the specified width).
143 bool isSignedInteger() const;
144 bool isSignedInteger(unsigned width) const;
145 /// Return true if this is an unsigned integer type (with the specified
146 /// width).
147 bool isUnsignedInteger() const;
148 bool isUnsignedInteger(unsigned width) const;
149
150 /// Return the bit width of an integer or a float type, assert failure on
151 /// other types.
152 unsigned getIntOrFloatBitWidth() const;
153
154 /// Return true if this is a signless integer or index type.
155 bool isSignlessIntOrIndex() const;
156 /// Return true if this is a signless integer, index, or float type.
157 bool isSignlessIntOrIndexOrFloat() const;
158 /// Return true of this is a signless integer or a float type.
159 bool isSignlessIntOrFloat() const;
160
161 /// Return true if this is an integer (of any signedness) or an index type.
162 bool isIntOrIndex() const;
163 /// Return true if this is an integer (of any signedness) or a float type.
164 bool isIntOrFloat() const;
165 /// Return true if this is an integer (of any signedness), index, or float
166 /// type.
167 bool isIntOrIndexOrFloat() const;
168
169 /// Print the current type.
170 void print(raw_ostream &os) const;
171 void print(raw_ostream &os, AsmState &state) const;
172 void dump() const;
173
174 friend ::llvm::hash_code hash_value(Type arg);
175
176 /// Methods for supporting PointerLikeTypeTraits.
177 const void *getAsOpaquePointer() const {
178 return static_cast<const void *>(impl);
179 }
180 static Type getFromOpaquePointer(const void *pointer) {
181 return Type(reinterpret_cast<ImplType *>(const_cast<void *>(pointer)));
182 }
183
184 /// Returns true if `InterfaceT` has been promised by the dialect or
185 /// implemented.
186 template <typename InterfaceT>
187 bool hasPromiseOrImplementsInterface() {
188 return dialect_extension_detail::hasPromisedInterface(
189 getDialect(), getTypeID(), InterfaceT::getInterfaceID()) ||
190 mlir::isa<InterfaceT>(*this);
191 }
192
193 /// Returns true if the type was registered with a particular trait.
194 template <template <typename T> class Trait>
195 bool hasTrait() {
196 return getAbstractType().hasTrait<Trait>();
197 }
198
199 /// Return the abstract type descriptor for this type.
200 const AbstractTy &getAbstractType() const { return impl->getAbstractType(); }
201
202 /// Return the Type implementation.
203 ImplType *getImpl() const { return impl; }
204
205 /// Walk all of the immediately nested sub-attributes and sub-types. This
206 /// method does not recurse into sub elements.
207 void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
208 function_ref<void(Type)> walkTypesFn) const {
209 getAbstractType().walkImmediateSubElements(type: *this, walkAttrsFn, walkTypesFn);
210 }
211
212 /// Replace the immediately nested sub-attributes and sub-types with those
213 /// provided. The order of the provided elements is derived from the order of
214 /// the elements returned by the callbacks of `walkImmediateSubElements`. The
215 /// element at index 0 would replace the very first attribute given by
216 /// `walkImmediateSubElements`. On success, the new instance with the values
217 /// replaced is returned. If replacement fails, nullptr is returned.
218 auto replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
219 ArrayRef<Type> replTypes) const {
220 return getAbstractType().replaceImmediateSubElements(type: *this, replAttrs,
221 replTypes);
222 }
223
224 /// Walk this type and all attibutes/types nested within using the
225 /// provided walk functions. See `AttrTypeWalker` for information on the
226 /// supported walk function types.
227 template <WalkOrder Order = WalkOrder::PostOrder, typename... WalkFns>
228 auto walk(WalkFns &&...walkFns) {
229 AttrTypeWalker walker;
230 (walker.addWalk(std::forward<WalkFns>(walkFns)), ...);
231 return walker.walk<Order>(*this);
232 }
233
234 /// Recursively replace all of the nested sub-attributes and sub-types using
235 /// the provided map functions. Returns nullptr in the case of failure. See
236 /// `AttrTypeReplacer` for information on the support replacement function
237 /// types.
238 template <typename... ReplacementFns>
239 auto replace(ReplacementFns &&...replacementFns) {
240 AttrTypeReplacer replacer;
241 (replacer.addReplacement(std::forward<ReplacementFns>(replacementFns)),
242 ...);
243 return replacer.replace(type: *this);
244 }
245
246protected:
247 ImplType *impl{nullptr};
248};
249
250inline raw_ostream &operator<<(raw_ostream &os, Type type) {
251 type.print(os);
252 return os;
253}
254
255//===----------------------------------------------------------------------===//
256// TypeTraitBase
257//===----------------------------------------------------------------------===//
258
259namespace TypeTrait {
260/// This class represents the base of a type trait.
261template <typename ConcreteType, template <typename> class TraitType>
262using TraitBase = detail::StorageUserTraitBase<ConcreteType, TraitType>;
263} // namespace TypeTrait
264
265//===----------------------------------------------------------------------===//
266// TypeInterface
267//===----------------------------------------------------------------------===//
268
269/// This class represents the base of a type interface. See the definition of
270/// `detail::Interface` for requirements on the `Traits` type.
271template <typename ConcreteType, typename Traits>
272class TypeInterface : public detail::Interface<ConcreteType, Type, Traits, Type,
273 TypeTrait::TraitBase> {
274public:
275 using Base = TypeInterface<ConcreteType, Traits>;
276 using InterfaceBase =
277 detail::Interface<ConcreteType, Type, Traits, Type, TypeTrait::TraitBase>;
278 using InterfaceBase::InterfaceBase;
279
280protected:
281 /// Returns the impl interface instance for the given type.
282 static typename InterfaceBase::Concept *getInterfaceFor(Type type) {
283#ifndef NDEBUG
284 // Check that the current interface isn't an unresolved promise for the
285 // given type.
286 dialect_extension_detail::handleUseOfUndefinedPromisedInterface(
287 dialect&: type.getDialect(), interfaceRequestorID: type.getTypeID(), interfaceID: ConcreteType::getInterfaceID(),
288 interfaceName: llvm::getTypeName<ConcreteType>());
289#endif
290
291 return type.getAbstractType().getInterface<ConcreteType>();
292 }
293
294 /// Allow access to 'getInterfaceFor'.
295 friend InterfaceBase;
296};
297
298//===----------------------------------------------------------------------===//
299// Core TypeTrait
300//===----------------------------------------------------------------------===//
301
302/// This trait is used to determine if a type is mutable or not. It is attached
303/// on a type if the corresponding ImplType defines a `mutate` function with
304/// a proper signature.
305namespace TypeTrait {
306template <typename ConcreteType>
307using IsMutable = detail::StorageUserTrait::IsMutable<ConcreteType>;
308} // namespace TypeTrait
309
310//===----------------------------------------------------------------------===//
311// Type Utils
312//===----------------------------------------------------------------------===//
313
314// Make Type hashable.
315inline ::llvm::hash_code hash_value(Type arg) {
316 return DenseMapInfo<const Type::ImplType *>::getHashValue(PtrVal: arg.impl);
317}
318
319template <typename... Tys>
320bool Type::isa() const {
321 return llvm::isa<Tys...>(*this);
322}
323
324template <typename... Tys>
325bool Type::isa_and_nonnull() const {
326 return llvm::isa_and_present<Tys...>(*this);
327}
328
329template <typename U>
330U Type::dyn_cast() const {
331 return llvm::dyn_cast<U>(*this);
332}
333
334template <typename U>
335U Type::dyn_cast_or_null() const {
336 return llvm::dyn_cast_or_null<U>(*this);
337}
338
339template <typename U>
340U Type::cast() const {
341 return llvm::cast<U>(*this);
342}
343
344} // namespace mlir
345
346namespace llvm {
347
348// Type hash just like pointers.
349template <>
350struct DenseMapInfo<mlir::Type> {
351 static mlir::Type getEmptyKey() {
352 auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
353 return mlir::Type(static_cast<mlir::Type::ImplType *>(pointer));
354 }
355 static mlir::Type getTombstoneKey() {
356 auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
357 return mlir::Type(static_cast<mlir::Type::ImplType *>(pointer));
358 }
359 static unsigned getHashValue(mlir::Type val) { return mlir::hash_value(arg: val); }
360 static bool isEqual(mlir::Type LHS, mlir::Type RHS) { return LHS == RHS; }
361};
362template <typename T>
363struct DenseMapInfo<T, std::enable_if_t<std::is_base_of<mlir::Type, T>::value &&
364 !mlir::detail::IsInterface<T>::value>>
365 : public DenseMapInfo<mlir::Type> {
366 static T getEmptyKey() {
367 const void *pointer = llvm::DenseMapInfo<const void *>::getEmptyKey();
368 return T::getFromOpaquePointer(pointer);
369 }
370 static T getTombstoneKey() {
371 const void *pointer = llvm::DenseMapInfo<const void *>::getTombstoneKey();
372 return T::getFromOpaquePointer(pointer);
373 }
374};
375
376/// We align TypeStorage by 8, so allow LLVM to steal the low bits.
377template <>
378struct PointerLikeTypeTraits<mlir::Type> {
379public:
380 static inline void *getAsVoidPointer(mlir::Type I) {
381 return const_cast<void *>(I.getAsOpaquePointer());
382 }
383 static inline mlir::Type getFromVoidPointer(void *P) {
384 return mlir::Type::getFromOpaquePointer(pointer: P);
385 }
386 static constexpr int NumLowBitsAvailable = 3;
387};
388
389/// Add support for llvm style casts.
390/// We provide a cast between To and From if From is mlir::Type or derives from
391/// it
392template <typename To, typename From>
393struct CastInfo<
394 To, From,
395 std::enable_if_t<std::is_same_v<mlir::Type, std::remove_const_t<From>> ||
396 std::is_base_of_v<mlir::Type, From>>>
397 : NullableValueCastFailed<To>,
398 DefaultDoCastIfPossible<To, From, CastInfo<To, From>> {
399 /// Arguments are taken as mlir::Type here and not as `From`, because when
400 /// casting from an intermediate type of the hierarchy to one of its children,
401 /// the val.getTypeID() inside T::classof will use the static getTypeID of the
402 /// parent instead of the non-static Type::getTypeID that returns the dynamic
403 /// ID. This means that T::classof would end up comparing the static TypeID of
404 /// the children to the static TypeID of its parent, making it impossible to
405 /// downcast from the parent to the child.
406 static inline bool isPossible(mlir::Type ty) {
407 /// Return a constant true instead of a dynamic true when casting to self or
408 /// up the hierarchy.
409 if constexpr (std::is_base_of_v<To, From>) {
410 return true;
411 } else {
412 return To::classof(ty);
413 };
414 }
415 static inline To doCast(mlir::Type ty) { return To(ty.getImpl()); }
416};
417
418} // namespace llvm
419
420#endif // MLIR_IR_TYPES_H
421

// Source: mlir/include/mlir/IR/Types.h