| 1 | //===- Types.h - MLIR Type Classes ------------------------------*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #ifndef MLIR_IR_TYPES_H |
| 10 | #define MLIR_IR_TYPES_H |
| 11 | |
| 12 | #include "mlir/IR/TypeSupport.h" |
| 13 | #include "llvm/ADT/ArrayRef.h" |
| 14 | #include "llvm/ADT/DenseMapInfo.h" |
| 15 | #include "llvm/Support/PointerLikeTypeTraits.h" |
| 16 | |
| 17 | namespace mlir { |
| 18 | class AsmState; |
| 19 | |
| 20 | /// Instances of the Type class are uniqued, have an immutable identifier and an |
| 21 | /// optional mutable component. They wrap a pointer to the storage object owned |
| 22 | /// by MLIRContext. Therefore, instances of Type are passed around by value. |
| 23 | /// |
| 24 | /// Some types are "primitives" meaning they do not have any parameters, for |
| 25 | /// example the Index type. Parametric types have additional information that |
| 26 | /// differentiates the types of the same class, for example the Integer type has |
| 27 | /// bitwidth, making i8 and i16 belong to the same kind by be different |
| 28 | /// instances of the IntegerType. Type parameters are part of the unique |
| 29 | /// immutable key. The mutable component of the type can be modified after the |
| 30 | /// type is created, but cannot affect the identity of the type. |
| 31 | /// |
| 32 | /// Types are constructed and uniqued via the 'detail::TypeUniquer' class. |
| 33 | /// |
| 34 | /// Derived type classes are expected to implement several required |
| 35 | /// implementation hooks: |
| 36 | /// * Optional: |
| 37 | /// - static LogicalResult verifyInvariants( |
| 38 | /// function_ref<InFlightDiagnostic()> emitError, |
| 39 | /// Args... args) |
| 40 | /// * This method is invoked when calling the 'TypeBase::get/getChecked' |
| 41 | /// methods to ensure that the arguments passed in are valid to construct |
| 42 | /// a type instance with. |
| 43 | /// * This method is expected to return failure if a type cannot be |
| 44 | /// constructed with 'args', success otherwise. |
| 45 | /// * 'args' must correspond with the arguments passed into the |
| 46 | /// 'TypeBase::get' call. |
| 47 | /// |
| 48 | /// |
| 49 | /// Type storage objects inherit from TypeStorage and contain the following: |
| 50 | /// - The dialect that defined the type. |
| 51 | /// - Any parameters of the type. |
| 52 | /// - An optional mutable component. |
| 53 | /// For non-parametric types, a convenience DefaultTypeStorage is provided. |
| 54 | /// Parametric storage types must derive TypeStorage and respect the following: |
| 55 | /// - Define a type alias, KeyTy, to a type that uniquely identifies the |
| 56 | /// instance of the type. |
| 57 | /// * The key type must be constructible from the values passed into the |
| 58 | /// detail::TypeUniquer::get call. |
| 59 | /// * If the KeyTy does not have an llvm::DenseMapInfo specialization, the |
| 60 | /// storage class must define a hashing method: |
| 61 | /// 'static unsigned hashKey(const KeyTy &)' |
| 62 | /// |
| 63 | /// - Provide a method, 'bool operator==(const KeyTy &) const', to |
| 64 | /// compare the storage instance against an instance of the key type. |
| 65 | /// |
| 66 | /// - Provide a static construction method: |
| 67 | /// 'DerivedStorage *construct(TypeStorageAllocator &, const KeyTy &key)' |
| 68 | /// that builds a unique instance of the derived storage. The arguments to |
| 69 | /// this function are an allocator to store any uniqued data within the |
| 70 | /// context and the key type for this storage. |
| 71 | /// |
| 72 | /// - If they have a mutable component, this component must not be a part of |
| 73 | /// the key. |
| 74 | class Type { |
| 75 | public: |
| 76 | /// Utility class for implementing types. |
| 77 | template <typename ConcreteType, typename BaseType, typename StorageType, |
| 78 | template <typename T> class... Traits> |
| 79 | using TypeBase = detail::StorageUserBase<ConcreteType, BaseType, StorageType, |
| 80 | detail::TypeUniquer, Traits...>; |
| 81 | |
| 82 | using ImplType = TypeStorage; |
| 83 | |
| 84 | using AbstractTy = AbstractType; |
| 85 | |
| 86 | constexpr Type() = default; |
| 87 | /* implicit */ Type(const ImplType *impl) |
| 88 | : impl(const_cast<ImplType *>(impl)) {} |
| 89 | |
| 90 | Type(const Type &other) = default; |
| 91 | Type &operator=(const Type &other) = default; |
| 92 | |
| 93 | bool operator==(Type other) const { return impl == other.impl; } |
| 94 | bool operator!=(Type other) const { return !(*this == other); } |
| 95 | explicit operator bool() const { return impl; } |
| 96 | |
| 97 | bool operator!() const { return impl == nullptr; } |
| 98 | |
| 99 | /// Return a unique identifier for the concrete type. This is used to support |
| 100 | /// dynamic type casting. |
| 101 | TypeID getTypeID() { return impl->getAbstractType().getTypeID(); } |
| 102 | |
| 103 | /// Return the MLIRContext in which this type was uniqued. |
| 104 | MLIRContext *getContext() const; |
| 105 | |
| 106 | /// Get the dialect this type is registered to. |
| 107 | Dialect &getDialect() const { return impl->getAbstractType().getDialect(); } |
| 108 | |
| 109 | // Convenience predicates. This is only for floating point types, |
| 110 | // derived types should use isa/dyn_cast. |
| 111 | bool isIndex() const; |
| 112 | bool isBF16() const; |
| 113 | bool isF16() const; |
| 114 | bool isTF32() const; |
| 115 | bool isF32() const; |
| 116 | bool isF64() const; |
| 117 | bool isF80() const; |
| 118 | bool isF128() const; |
| 119 | /// Return true if this is an float type (with the specified width). |
| 120 | bool isFloat() const; |
| 121 | bool isFloat(unsigned width) const; |
| 122 | |
| 123 | /// Return true if this is an integer type (with the specified width). |
| 124 | bool isInteger() const; |
| 125 | bool isInteger(unsigned width) const; |
| 126 | /// Return true if this is a signless integer type (with the specified width). |
| 127 | bool isSignlessInteger() const; |
| 128 | bool isSignlessInteger(unsigned width) const; |
| 129 | /// Return true if this is a signed integer type (with the specified width). |
| 130 | bool isSignedInteger() const; |
| 131 | bool isSignedInteger(unsigned width) const; |
| 132 | /// Return true if this is an unsigned integer type (with the specified |
| 133 | /// width). |
| 134 | bool isUnsignedInteger() const; |
| 135 | bool isUnsignedInteger(unsigned width) const; |
| 136 | |
| 137 | /// Return the bit width of an integer or a float type, assert failure on |
| 138 | /// other types. |
| 139 | unsigned getIntOrFloatBitWidth() const; |
| 140 | |
| 141 | /// Return true if this is a signless integer or index type. |
| 142 | bool isSignlessIntOrIndex() const; |
| 143 | /// Return true if this is a signless integer, index, or float type. |
| 144 | bool isSignlessIntOrIndexOrFloat() const; |
| 145 | /// Return true of this is a signless integer or a float type. |
| 146 | bool isSignlessIntOrFloat() const; |
| 147 | |
| 148 | /// Return true if this is an integer (of any signedness) or an index type. |
| 149 | bool isIntOrIndex() const; |
| 150 | /// Return true if this is an integer (of any signedness) or a float type. |
| 151 | bool isIntOrFloat() const; |
| 152 | /// Return true if this is an integer (of any signedness), index, or float |
| 153 | /// type. |
| 154 | bool isIntOrIndexOrFloat() const; |
| 155 | |
| 156 | /// Print the current type. |
| 157 | void print(raw_ostream &os) const; |
| 158 | void print(raw_ostream &os, AsmState &state) const; |
| 159 | void dump() const; |
| 160 | |
| 161 | friend ::llvm::hash_code hash_value(Type arg); |
| 162 | |
| 163 | /// Methods for supporting PointerLikeTypeTraits. |
| 164 | const void *getAsOpaquePointer() const { |
| 165 | return static_cast<const void *>(impl); |
| 166 | } |
| 167 | static Type getFromOpaquePointer(const void *pointer) { |
| 168 | return Type(reinterpret_cast<ImplType *>(const_cast<void *>(pointer))); |
| 169 | } |
| 170 | |
| 171 | /// Returns true if `InterfaceT` has been promised by the dialect or |
| 172 | /// implemented. |
| 173 | template <typename InterfaceT> |
| 174 | bool hasPromiseOrImplementsInterface() { |
| 175 | return dialect_extension_detail::hasPromisedInterface( |
| 176 | getDialect(), getTypeID(), InterfaceT::getInterfaceID()) || |
| 177 | mlir::isa<InterfaceT>(*this); |
| 178 | } |
| 179 | |
| 180 | /// Returns true if the type was registered with a particular trait. |
| 181 | template <template <typename T> class Trait> |
| 182 | bool hasTrait() { |
| 183 | return getAbstractType().hasTrait<Trait>(); |
| 184 | } |
| 185 | |
| 186 | /// Return the abstract type descriptor for this type. |
| 187 | const AbstractTy &getAbstractType() const { return impl->getAbstractType(); } |
| 188 | |
| 189 | /// Return the Type implementation. |
| 190 | ImplType *getImpl() const { return impl; } |
| 191 | |
| 192 | /// Walk all of the immediately nested sub-attributes and sub-types. This |
| 193 | /// method does not recurse into sub elements. |
| 194 | void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn, |
| 195 | function_ref<void(Type)> walkTypesFn) const { |
| 196 | getAbstractType().walkImmediateSubElements(type: *this, walkAttrsFn, walkTypesFn); |
| 197 | } |
| 198 | |
| 199 | /// Replace the immediately nested sub-attributes and sub-types with those |
| 200 | /// provided. The order of the provided elements is derived from the order of |
| 201 | /// the elements returned by the callbacks of `walkImmediateSubElements`. The |
| 202 | /// element at index 0 would replace the very first attribute given by |
| 203 | /// `walkImmediateSubElements`. On success, the new instance with the values |
| 204 | /// replaced is returned. If replacement fails, nullptr is returned. |
| 205 | auto replaceImmediateSubElements(ArrayRef<Attribute> replAttrs, |
| 206 | ArrayRef<Type> replTypes) const { |
| 207 | return getAbstractType().replaceImmediateSubElements(type: *this, replAttrs, |
| 208 | replTypes); |
| 209 | } |
| 210 | |
| 211 | /// Walk this type and all attibutes/types nested within using the |
| 212 | /// provided walk functions. See `AttrTypeWalker` for information on the |
| 213 | /// supported walk function types. |
| 214 | template <WalkOrder Order = WalkOrder::PostOrder, typename... WalkFns> |
| 215 | auto walk(WalkFns &&...walkFns) { |
| 216 | AttrTypeWalker walker; |
| 217 | (walker.addWalk(std::forward<WalkFns>(walkFns)), ...); |
| 218 | return walker.walk<Order>(*this); |
| 219 | } |
| 220 | |
| 221 | /// Recursively replace all of the nested sub-attributes and sub-types using |
| 222 | /// the provided map functions. Returns nullptr in the case of failure. See |
| 223 | /// `AttrTypeReplacer` for information on the support replacement function |
| 224 | /// types. |
| 225 | template <typename... ReplacementFns> |
| 226 | auto replace(ReplacementFns &&...replacementFns) { |
| 227 | AttrTypeReplacer replacer; |
| 228 | (replacer.addReplacement(std::forward<ReplacementFns>(replacementFns)), |
| 229 | ...); |
| 230 | return replacer.replace(type: *this); |
| 231 | } |
| 232 | |
| 233 | protected: |
| 234 | ImplType *impl{nullptr}; |
| 235 | }; |
| 236 | |
| 237 | inline raw_ostream &operator<<(raw_ostream &os, Type type) { |
| 238 | type.print(os); |
| 239 | return os; |
| 240 | } |
| 241 | |
| 242 | //===----------------------------------------------------------------------===// |
| 243 | // TypeTraitBase |
| 244 | //===----------------------------------------------------------------------===// |
| 245 | |
| 246 | namespace TypeTrait { |
| 247 | /// This class represents the base of a type trait. |
| 248 | template <typename ConcreteType, template <typename> class TraitType> |
| 249 | using TraitBase = detail::StorageUserTraitBase<ConcreteType, TraitType>; |
| 250 | } // namespace TypeTrait |
| 251 | |
| 252 | //===----------------------------------------------------------------------===// |
| 253 | // TypeInterface |
| 254 | //===----------------------------------------------------------------------===// |
| 255 | |
| 256 | /// This class represents the base of a type interface. See the definition of |
| 257 | /// `detail::Interface` for requirements on the `Traits` type. |
| 258 | template <typename ConcreteType, typename Traits> |
| 259 | class TypeInterface : public detail::Interface<ConcreteType, Type, Traits, Type, |
| 260 | TypeTrait::TraitBase> { |
| 261 | public: |
| 262 | using Base = TypeInterface<ConcreteType, Traits>; |
| 263 | using InterfaceBase = |
| 264 | detail::Interface<ConcreteType, Type, Traits, Type, TypeTrait::TraitBase>; |
| 265 | using InterfaceBase::InterfaceBase; |
| 266 | |
| 267 | protected: |
| 268 | /// Returns the impl interface instance for the given type. |
| 269 | static typename InterfaceBase::Concept *getInterfaceFor(Type type) { |
| 270 | #ifndef NDEBUG |
| 271 | // Check that the current interface isn't an unresolved promise for the |
| 272 | // given type. |
| 273 | dialect_extension_detail::handleUseOfUndefinedPromisedInterface( |
| 274 | dialect&: type.getDialect(), interfaceRequestorID: type.getTypeID(), interfaceID: ConcreteType::getInterfaceID(), |
| 275 | interfaceName: llvm::getTypeName<ConcreteType>()); |
| 276 | #endif |
| 277 | |
| 278 | return type.getAbstractType().getInterface<ConcreteType>(); |
| 279 | } |
| 280 | |
| 281 | /// Allow access to 'getInterfaceFor'. |
| 282 | friend InterfaceBase; |
| 283 | }; |
| 284 | |
| 285 | //===----------------------------------------------------------------------===// |
| 286 | // Core TypeTrait |
| 287 | //===----------------------------------------------------------------------===// |
| 288 | |
| 289 | /// This trait is used to determine if a type is mutable or not. It is attached |
| 290 | /// on a type if the corresponding ImplType defines a `mutate` function with |
| 291 | /// a proper signature. |
| 292 | namespace TypeTrait { |
| 293 | template <typename ConcreteType> |
| 294 | using IsMutable = detail::StorageUserTrait::IsMutable<ConcreteType>; |
| 295 | } // namespace TypeTrait |
| 296 | |
| 297 | //===----------------------------------------------------------------------===// |
| 298 | // Type Utils |
| 299 | //===----------------------------------------------------------------------===// |
| 300 | |
| 301 | // Make Type hashable. |
| 302 | inline ::llvm::hash_code hash_value(Type arg) { |
| 303 | return DenseMapInfo<const Type::ImplType *>::getHashValue(PtrVal: arg.impl); |
| 304 | } |
| 305 | |
| 306 | } // namespace mlir |
| 307 | |
| 308 | namespace llvm { |
| 309 | |
| 310 | // Type hash just like pointers. |
| 311 | template <> |
| 312 | struct DenseMapInfo<mlir::Type> { |
| 313 | static mlir::Type getEmptyKey() { |
| 314 | auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey(); |
| 315 | return mlir::Type(static_cast<mlir::Type::ImplType *>(pointer)); |
| 316 | } |
| 317 | static mlir::Type getTombstoneKey() { |
| 318 | auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey(); |
| 319 | return mlir::Type(static_cast<mlir::Type::ImplType *>(pointer)); |
| 320 | } |
| 321 | static unsigned getHashValue(mlir::Type val) { return mlir::hash_value(arg: val); } |
| 322 | static bool isEqual(mlir::Type LHS, mlir::Type RHS) { return LHS == RHS; } |
| 323 | }; |
| 324 | template <typename T> |
| 325 | struct DenseMapInfo<T, std::enable_if_t<std::is_base_of<mlir::Type, T>::value && |
| 326 | !mlir::detail::IsInterface<T>::value>> |
| 327 | : public DenseMapInfo<mlir::Type> { |
| 328 | static T getEmptyKey() { |
| 329 | const void *pointer = llvm::DenseMapInfo<const void *>::getEmptyKey(); |
| 330 | return T::getFromOpaquePointer(pointer); |
| 331 | } |
| 332 | static T getTombstoneKey() { |
| 333 | const void *pointer = llvm::DenseMapInfo<const void *>::getTombstoneKey(); |
| 334 | return T::getFromOpaquePointer(pointer); |
| 335 | } |
| 336 | }; |
| 337 | |
| 338 | /// We align TypeStorage by 8, so allow LLVM to steal the low bits. |
| 339 | template <> |
| 340 | struct PointerLikeTypeTraits<mlir::Type> { |
| 341 | public: |
| 342 | static inline void *getAsVoidPointer(mlir::Type I) { |
| 343 | return const_cast<void *>(I.getAsOpaquePointer()); |
| 344 | } |
| 345 | static inline mlir::Type getFromVoidPointer(void *P) { |
| 346 | return mlir::Type::getFromOpaquePointer(pointer: P); |
| 347 | } |
| 348 | static constexpr int NumLowBitsAvailable = 3; |
| 349 | }; |
| 350 | |
| 351 | /// Add support for llvm style casts. |
| 352 | /// We provide a cast between To and From if From is mlir::Type or derives from |
| 353 | /// it |
| 354 | template <typename To, typename From> |
| 355 | struct CastInfo< |
| 356 | To, From, |
| 357 | std::enable_if_t<std::is_same_v<mlir::Type, std::remove_const_t<From>> || |
| 358 | std::is_base_of_v<mlir::Type, From>>> |
| 359 | : NullableValueCastFailed<To>, |
| 360 | DefaultDoCastIfPossible<To, From, CastInfo<To, From>> { |
| 361 | /// Arguments are taken as mlir::Type here and not as `From`, because when |
| 362 | /// casting from an intermediate type of the hierarchy to one of its children, |
| 363 | /// the val.getTypeID() inside T::classof will use the static getTypeID of the |
| 364 | /// parent instead of the non-static Type::getTypeID that returns the dynamic |
| 365 | /// ID. This means that T::classof would end up comparing the static TypeID of |
| 366 | /// the children to the static TypeID of its parent, making it impossible to |
| 367 | /// downcast from the parent to the child. |
| 368 | static inline bool isPossible(mlir::Type ty) { |
| 369 | /// Return a constant true instead of a dynamic true when casting to self or |
| 370 | /// up the hierarchy. |
| 371 | if constexpr (std::is_base_of_v<To, From>) { |
| 372 | return true; |
| 373 | } else { |
| 374 | return To::classof(ty); |
| 375 | }; |
| 376 | } |
| 377 | static inline To doCast(mlir::Type ty) { return To(ty.getImpl()); } |
| 378 | }; |
| 379 | |
| 380 | } // namespace llvm |
| 381 | |
| 382 | #endif // MLIR_IR_TYPES_H |
| 383 | |