1 | //===- Types.h - MLIR Type Classes ------------------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef MLIR_IR_TYPES_H |
10 | #define MLIR_IR_TYPES_H |
11 | |
12 | #include "mlir/IR/TypeSupport.h" |
13 | #include "llvm/ADT/ArrayRef.h" |
14 | #include "llvm/ADT/DenseMapInfo.h" |
15 | #include "llvm/Support/PointerLikeTypeTraits.h" |
16 | |
17 | namespace mlir { |
18 | class AsmState; |
19 | |
20 | /// Instances of the Type class are uniqued, have an immutable identifier and an |
21 | /// optional mutable component. They wrap a pointer to the storage object owned |
22 | /// by MLIRContext. Therefore, instances of Type are passed around by value. |
23 | /// |
24 | /// Some types are "primitives" meaning they do not have any parameters, for |
25 | /// example the Index type. Parametric types have additional information that |
26 | /// differentiates the types of the same class, for example the Integer type has |
27 | /// bitwidth, making i8 and i16 belong to the same kind by be different |
28 | /// instances of the IntegerType. Type parameters are part of the unique |
29 | /// immutable key. The mutable component of the type can be modified after the |
30 | /// type is created, but cannot affect the identity of the type. |
31 | /// |
32 | /// Types are constructed and uniqued via the 'detail::TypeUniquer' class. |
33 | /// |
34 | /// Derived type classes are expected to implement several required |
35 | /// implementation hooks: |
36 | /// * Optional: |
37 | /// - static LogicalResult verify( |
38 | /// function_ref<InFlightDiagnostic()> emitError, |
39 | /// Args... args) |
40 | /// * This method is invoked when calling the 'TypeBase::get/getChecked' |
41 | /// methods to ensure that the arguments passed in are valid to construct |
42 | /// a type instance with. |
43 | /// * This method is expected to return failure if a type cannot be |
44 | /// constructed with 'args', success otherwise. |
45 | /// * 'args' must correspond with the arguments passed into the |
46 | /// 'TypeBase::get' call. |
47 | /// |
48 | /// |
49 | /// Type storage objects inherit from TypeStorage and contain the following: |
50 | /// - The dialect that defined the type. |
51 | /// - Any parameters of the type. |
52 | /// - An optional mutable component. |
53 | /// For non-parametric types, a convenience DefaultTypeStorage is provided. |
54 | /// Parametric storage types must derive TypeStorage and respect the following: |
55 | /// - Define a type alias, KeyTy, to a type that uniquely identifies the |
56 | /// instance of the type. |
57 | /// * The key type must be constructible from the values passed into the |
58 | /// detail::TypeUniquer::get call. |
59 | /// * If the KeyTy does not have an llvm::DenseMapInfo specialization, the |
60 | /// storage class must define a hashing method: |
61 | /// 'static unsigned hashKey(const KeyTy &)' |
62 | /// |
63 | /// - Provide a method, 'bool operator==(const KeyTy &) const', to |
64 | /// compare the storage instance against an instance of the key type. |
65 | /// |
66 | /// - Provide a static construction method: |
67 | /// 'DerivedStorage *construct(TypeStorageAllocator &, const KeyTy &key)' |
68 | /// that builds a unique instance of the derived storage. The arguments to |
69 | /// this function are an allocator to store any uniqued data within the |
70 | /// context and the key type for this storage. |
71 | /// |
72 | /// - If they have a mutable component, this component must not be a part of |
73 | /// the key. |
74 | class Type { |
75 | public: |
76 | /// Utility class for implementing types. |
77 | template <typename ConcreteType, typename BaseType, typename StorageType, |
78 | template <typename T> class... Traits> |
79 | using TypeBase = detail::StorageUserBase<ConcreteType, BaseType, StorageType, |
80 | detail::TypeUniquer, Traits...>; |
81 | |
82 | using ImplType = TypeStorage; |
83 | |
84 | using AbstractTy = AbstractType; |
85 | |
86 | constexpr Type() = default; |
87 | /* implicit */ Type(const ImplType *impl) |
88 | : impl(const_cast<ImplType *>(impl)) {} |
89 | |
90 | Type(const Type &other) = default; |
91 | Type &operator=(const Type &other) = default; |
92 | |
93 | bool operator==(Type other) const { return impl == other.impl; } |
94 | bool operator!=(Type other) const { return !(*this == other); } |
95 | explicit operator bool() const { return impl; } |
96 | |
97 | bool operator!() const { return impl == nullptr; } |
98 | |
99 | template <typename... Tys> |
100 | bool isa() const; |
101 | template <typename... Tys> |
102 | bool isa_and_nonnull() const; |
103 | template <typename U> |
104 | U dyn_cast() const; |
105 | template <typename U> |
106 | U dyn_cast_or_null() const; |
107 | template <typename U> |
108 | U cast() const; |
109 | |
110 | /// Return a unique identifier for the concrete type. This is used to support |
111 | /// dynamic type casting. |
112 | TypeID getTypeID() { return impl->getAbstractType().getTypeID(); } |
113 | |
114 | /// Return the MLIRContext in which this type was uniqued. |
115 | MLIRContext *getContext() const; |
116 | |
117 | /// Get the dialect this type is registered to. |
118 | Dialect &getDialect() const { return impl->getAbstractType().getDialect(); } |
119 | |
120 | // Convenience predicates. This is only for floating point types, |
121 | // derived types should use isa/dyn_cast. |
122 | bool isIndex() const; |
123 | bool isFloat8E5M2() const; |
124 | bool isFloat8E4M3FN() const; |
125 | bool isFloat8E5M2FNUZ() const; |
126 | bool isFloat8E4M3FNUZ() const; |
127 | bool isFloat8E4M3B11FNUZ() const; |
128 | bool isBF16() const; |
129 | bool isF16() const; |
130 | bool isTF32() const; |
131 | bool isF32() const; |
132 | bool isF64() const; |
133 | bool isF80() const; |
134 | bool isF128() const; |
135 | |
136 | /// Return true if this is an integer type (with the specified width). |
137 | bool isInteger() const; |
138 | bool isInteger(unsigned width) const; |
139 | /// Return true if this is a signless integer type (with the specified width). |
140 | bool isSignlessInteger() const; |
141 | bool isSignlessInteger(unsigned width) const; |
142 | /// Return true if this is a signed integer type (with the specified width). |
143 | bool isSignedInteger() const; |
144 | bool isSignedInteger(unsigned width) const; |
145 | /// Return true if this is an unsigned integer type (with the specified |
146 | /// width). |
147 | bool isUnsignedInteger() const; |
148 | bool isUnsignedInteger(unsigned width) const; |
149 | |
150 | /// Return the bit width of an integer or a float type, assert failure on |
151 | /// other types. |
152 | unsigned getIntOrFloatBitWidth() const; |
153 | |
154 | /// Return true if this is a signless integer or index type. |
155 | bool isSignlessIntOrIndex() const; |
156 | /// Return true if this is a signless integer, index, or float type. |
157 | bool isSignlessIntOrIndexOrFloat() const; |
158 | /// Return true of this is a signless integer or a float type. |
159 | bool isSignlessIntOrFloat() const; |
160 | |
161 | /// Return true if this is an integer (of any signedness) or an index type. |
162 | bool isIntOrIndex() const; |
163 | /// Return true if this is an integer (of any signedness) or a float type. |
164 | bool isIntOrFloat() const; |
165 | /// Return true if this is an integer (of any signedness), index, or float |
166 | /// type. |
167 | bool isIntOrIndexOrFloat() const; |
168 | |
169 | /// Print the current type. |
170 | void print(raw_ostream &os) const; |
171 | void print(raw_ostream &os, AsmState &state) const; |
172 | void dump() const; |
173 | |
174 | friend ::llvm::hash_code hash_value(Type arg); |
175 | |
176 | /// Methods for supporting PointerLikeTypeTraits. |
177 | const void *getAsOpaquePointer() const { |
178 | return static_cast<const void *>(impl); |
179 | } |
180 | static Type getFromOpaquePointer(const void *pointer) { |
181 | return Type(reinterpret_cast<ImplType *>(const_cast<void *>(pointer))); |
182 | } |
183 | |
184 | /// Returns true if `InterfaceT` has been promised by the dialect or |
185 | /// implemented. |
186 | template <typename InterfaceT> |
187 | bool hasPromiseOrImplementsInterface() { |
188 | return dialect_extension_detail::hasPromisedInterface( |
189 | getDialect(), getTypeID(), InterfaceT::getInterfaceID()) || |
190 | mlir::isa<InterfaceT>(*this); |
191 | } |
192 | |
193 | /// Returns true if the type was registered with a particular trait. |
194 | template <template <typename T> class Trait> |
195 | bool hasTrait() { |
196 | return getAbstractType().hasTrait<Trait>(); |
197 | } |
198 | |
199 | /// Return the abstract type descriptor for this type. |
200 | const AbstractTy &getAbstractType() const { return impl->getAbstractType(); } |
201 | |
202 | /// Return the Type implementation. |
203 | ImplType *getImpl() const { return impl; } |
204 | |
205 | /// Walk all of the immediately nested sub-attributes and sub-types. This |
206 | /// method does not recurse into sub elements. |
207 | void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn, |
208 | function_ref<void(Type)> walkTypesFn) const { |
209 | getAbstractType().walkImmediateSubElements(type: *this, walkAttrsFn, walkTypesFn); |
210 | } |
211 | |
212 | /// Replace the immediately nested sub-attributes and sub-types with those |
213 | /// provided. The order of the provided elements is derived from the order of |
214 | /// the elements returned by the callbacks of `walkImmediateSubElements`. The |
215 | /// element at index 0 would replace the very first attribute given by |
216 | /// `walkImmediateSubElements`. On success, the new instance with the values |
217 | /// replaced is returned. If replacement fails, nullptr is returned. |
218 | auto replaceImmediateSubElements(ArrayRef<Attribute> replAttrs, |
219 | ArrayRef<Type> replTypes) const { |
220 | return getAbstractType().replaceImmediateSubElements(type: *this, replAttrs, |
221 | replTypes); |
222 | } |
223 | |
224 | /// Walk this type and all attibutes/types nested within using the |
225 | /// provided walk functions. See `AttrTypeWalker` for information on the |
226 | /// supported walk function types. |
227 | template <WalkOrder Order = WalkOrder::PostOrder, typename... WalkFns> |
228 | auto walk(WalkFns &&...walkFns) { |
229 | AttrTypeWalker walker; |
230 | (walker.addWalk(std::forward<WalkFns>(walkFns)), ...); |
231 | return walker.walk<Order>(*this); |
232 | } |
233 | |
234 | /// Recursively replace all of the nested sub-attributes and sub-types using |
235 | /// the provided map functions. Returns nullptr in the case of failure. See |
236 | /// `AttrTypeReplacer` for information on the support replacement function |
237 | /// types. |
238 | template <typename... ReplacementFns> |
239 | auto replace(ReplacementFns &&...replacementFns) { |
240 | AttrTypeReplacer replacer; |
241 | (replacer.addReplacement(std::forward<ReplacementFns>(replacementFns)), |
242 | ...); |
243 | return replacer.replace(type: *this); |
244 | } |
245 | |
246 | protected: |
247 | ImplType *impl{nullptr}; |
248 | }; |
249 | |
250 | inline raw_ostream &operator<<(raw_ostream &os, Type type) { |
251 | type.print(os); |
252 | return os; |
253 | } |
254 | |
255 | //===----------------------------------------------------------------------===// |
256 | // TypeTraitBase |
257 | //===----------------------------------------------------------------------===// |
258 | |
259 | namespace TypeTrait { |
260 | /// This class represents the base of a type trait. |
261 | template <typename ConcreteType, template <typename> class TraitType> |
262 | using TraitBase = detail::StorageUserTraitBase<ConcreteType, TraitType>; |
263 | } // namespace TypeTrait |
264 | |
265 | //===----------------------------------------------------------------------===// |
266 | // TypeInterface |
267 | //===----------------------------------------------------------------------===// |
268 | |
269 | /// This class represents the base of a type interface. See the definition of |
270 | /// `detail::Interface` for requirements on the `Traits` type. |
271 | template <typename ConcreteType, typename Traits> |
272 | class TypeInterface : public detail::Interface<ConcreteType, Type, Traits, Type, |
273 | TypeTrait::TraitBase> { |
274 | public: |
275 | using Base = TypeInterface<ConcreteType, Traits>; |
276 | using InterfaceBase = |
277 | detail::Interface<ConcreteType, Type, Traits, Type, TypeTrait::TraitBase>; |
278 | using InterfaceBase::InterfaceBase; |
279 | |
280 | protected: |
281 | /// Returns the impl interface instance for the given type. |
282 | static typename InterfaceBase::Concept *getInterfaceFor(Type type) { |
283 | #ifndef NDEBUG |
284 | // Check that the current interface isn't an unresolved promise for the |
285 | // given type. |
286 | dialect_extension_detail::handleUseOfUndefinedPromisedInterface( |
287 | dialect&: type.getDialect(), interfaceRequestorID: type.getTypeID(), interfaceID: ConcreteType::getInterfaceID(), |
288 | interfaceName: llvm::getTypeName<ConcreteType>()); |
289 | #endif |
290 | |
291 | return type.getAbstractType().getInterface<ConcreteType>(); |
292 | } |
293 | |
294 | /// Allow access to 'getInterfaceFor'. |
295 | friend InterfaceBase; |
296 | }; |
297 | |
298 | //===----------------------------------------------------------------------===// |
299 | // Core TypeTrait |
300 | //===----------------------------------------------------------------------===// |
301 | |
302 | /// This trait is used to determine if a type is mutable or not. It is attached |
303 | /// on a type if the corresponding ImplType defines a `mutate` function with |
304 | /// a proper signature. |
305 | namespace TypeTrait { |
306 | template <typename ConcreteType> |
307 | using IsMutable = detail::StorageUserTrait::IsMutable<ConcreteType>; |
308 | } // namespace TypeTrait |
309 | |
310 | //===----------------------------------------------------------------------===// |
311 | // Type Utils |
312 | //===----------------------------------------------------------------------===// |
313 | |
314 | // Make Type hashable. |
315 | inline ::llvm::hash_code hash_value(Type arg) { |
316 | return DenseMapInfo<const Type::ImplType *>::getHashValue(PtrVal: arg.impl); |
317 | } |
318 | |
319 | template <typename... Tys> |
320 | bool Type::isa() const { |
321 | return llvm::isa<Tys...>(*this); |
322 | } |
323 | |
324 | template <typename... Tys> |
325 | bool Type::isa_and_nonnull() const { |
326 | return llvm::isa_and_present<Tys...>(*this); |
327 | } |
328 | |
329 | template <typename U> |
330 | U Type::dyn_cast() const { |
331 | return llvm::dyn_cast<U>(*this); |
332 | } |
333 | |
334 | template <typename U> |
335 | U Type::dyn_cast_or_null() const { |
336 | return llvm::dyn_cast_or_null<U>(*this); |
337 | } |
338 | |
339 | template <typename U> |
340 | U Type::cast() const { |
341 | return llvm::cast<U>(*this); |
342 | } |
343 | |
344 | } // namespace mlir |
345 | |
346 | namespace llvm { |
347 | |
348 | // Type hash just like pointers. |
349 | template <> |
350 | struct DenseMapInfo<mlir::Type> { |
351 | static mlir::Type getEmptyKey() { |
352 | auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey(); |
353 | return mlir::Type(static_cast<mlir::Type::ImplType *>(pointer)); |
354 | } |
355 | static mlir::Type getTombstoneKey() { |
356 | auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey(); |
357 | return mlir::Type(static_cast<mlir::Type::ImplType *>(pointer)); |
358 | } |
359 | static unsigned getHashValue(mlir::Type val) { return mlir::hash_value(arg: val); } |
360 | static bool isEqual(mlir::Type LHS, mlir::Type RHS) { return LHS == RHS; } |
361 | }; |
362 | template <typename T> |
363 | struct DenseMapInfo<T, std::enable_if_t<std::is_base_of<mlir::Type, T>::value && |
364 | !mlir::detail::IsInterface<T>::value>> |
365 | : public DenseMapInfo<mlir::Type> { |
366 | static T getEmptyKey() { |
367 | const void *pointer = llvm::DenseMapInfo<const void *>::getEmptyKey(); |
368 | return T::getFromOpaquePointer(pointer); |
369 | } |
370 | static T getTombstoneKey() { |
371 | const void *pointer = llvm::DenseMapInfo<const void *>::getTombstoneKey(); |
372 | return T::getFromOpaquePointer(pointer); |
373 | } |
374 | }; |
375 | |
376 | /// We align TypeStorage by 8, so allow LLVM to steal the low bits. |
377 | template <> |
378 | struct PointerLikeTypeTraits<mlir::Type> { |
379 | public: |
380 | static inline void *getAsVoidPointer(mlir::Type I) { |
381 | return const_cast<void *>(I.getAsOpaquePointer()); |
382 | } |
383 | static inline mlir::Type getFromVoidPointer(void *P) { |
384 | return mlir::Type::getFromOpaquePointer(pointer: P); |
385 | } |
386 | static constexpr int NumLowBitsAvailable = 3; |
387 | }; |
388 | |
389 | /// Add support for llvm style casts. |
390 | /// We provide a cast between To and From if From is mlir::Type or derives from |
391 | /// it |
392 | template <typename To, typename From> |
393 | struct CastInfo< |
394 | To, From, |
395 | std::enable_if_t<std::is_same_v<mlir::Type, std::remove_const_t<From>> || |
396 | std::is_base_of_v<mlir::Type, From>>> |
397 | : NullableValueCastFailed<To>, |
398 | DefaultDoCastIfPossible<To, From, CastInfo<To, From>> { |
399 | /// Arguments are taken as mlir::Type here and not as `From`, because when |
400 | /// casting from an intermediate type of the hierarchy to one of its children, |
401 | /// the val.getTypeID() inside T::classof will use the static getTypeID of the |
402 | /// parent instead of the non-static Type::getTypeID that returns the dynamic |
403 | /// ID. This means that T::classof would end up comparing the static TypeID of |
404 | /// the children to the static TypeID of its parent, making it impossible to |
405 | /// downcast from the parent to the child. |
406 | static inline bool isPossible(mlir::Type ty) { |
407 | /// Return a constant true instead of a dynamic true when casting to self or |
408 | /// up the hierarchy. |
409 | if constexpr (std::is_base_of_v<To, From>) { |
410 | return true; |
411 | } else { |
412 | return To::classof(ty); |
413 | }; |
414 | } |
415 | static inline To doCast(mlir::Type ty) { return To(ty.getImpl()); } |
416 | }; |
417 | |
418 | } // namespace llvm |
419 | |
420 | #endif // MLIR_IR_TYPES_H |
421 | |