1//===- Types.h - MLIR Type Classes ------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_IR_TYPES_H
10#define MLIR_IR_TYPES_H
11
12#include "mlir/IR/TypeSupport.h"
13#include "llvm/ADT/ArrayRef.h"
14#include "llvm/ADT/DenseMapInfo.h"
15#include "llvm/Support/PointerLikeTypeTraits.h"
16
17namespace mlir {
18class AsmState;
19
20/// Instances of the Type class are uniqued, have an immutable identifier and an
21/// optional mutable component. They wrap a pointer to the storage object owned
22/// by MLIRContext. Therefore, instances of Type are passed around by value.
23///
24/// Some types are "primitives" meaning they do not have any parameters, for
25/// example the Index type. Parametric types have additional information that
26/// differentiates the types of the same class, for example the Integer type has
27/// bitwidth, making i8 and i16 belong to the same kind by be different
28/// instances of the IntegerType. Type parameters are part of the unique
29/// immutable key. The mutable component of the type can be modified after the
30/// type is created, but cannot affect the identity of the type.
31///
32/// Types are constructed and uniqued via the 'detail::TypeUniquer' class.
33///
34/// Derived type classes are expected to implement several required
35/// implementation hooks:
36/// * Optional:
37/// - static LogicalResult verifyInvariants(
38/// function_ref<InFlightDiagnostic()> emitError,
39/// Args... args)
40/// * This method is invoked when calling the 'TypeBase::get/getChecked'
41/// methods to ensure that the arguments passed in are valid to construct
42/// a type instance with.
43/// * This method is expected to return failure if a type cannot be
44/// constructed with 'args', success otherwise.
45/// * 'args' must correspond with the arguments passed into the
46/// 'TypeBase::get' call.
47///
48///
49/// Type storage objects inherit from TypeStorage and contain the following:
50/// - The dialect that defined the type.
51/// - Any parameters of the type.
52/// - An optional mutable component.
53/// For non-parametric types, a convenience DefaultTypeStorage is provided.
54/// Parametric storage types must derive TypeStorage and respect the following:
55/// - Define a type alias, KeyTy, to a type that uniquely identifies the
56/// instance of the type.
57/// * The key type must be constructible from the values passed into the
58/// detail::TypeUniquer::get call.
59/// * If the KeyTy does not have an llvm::DenseMapInfo specialization, the
60/// storage class must define a hashing method:
61/// 'static unsigned hashKey(const KeyTy &)'
62///
63/// - Provide a method, 'bool operator==(const KeyTy &) const', to
64/// compare the storage instance against an instance of the key type.
65///
66/// - Provide a static construction method:
67/// 'DerivedStorage *construct(TypeStorageAllocator &, const KeyTy &key)'
68/// that builds a unique instance of the derived storage. The arguments to
69/// this function are an allocator to store any uniqued data within the
70/// context and the key type for this storage.
71///
72/// - If they have a mutable component, this component must not be a part of
73/// the key.
74class Type {
75public:
76 /// Utility class for implementing types.
77 template <typename ConcreteType, typename BaseType, typename StorageType,
78 template <typename T> class... Traits>
79 using TypeBase = detail::StorageUserBase<ConcreteType, BaseType, StorageType,
80 detail::TypeUniquer, Traits...>;
81
82 using ImplType = TypeStorage;
83
84 using AbstractTy = AbstractType;
85
86 constexpr Type() = default;
87 /* implicit */ Type(const ImplType *impl)
88 : impl(const_cast<ImplType *>(impl)) {}
89
90 Type(const Type &other) = default;
91 Type &operator=(const Type &other) = default;
92
93 bool operator==(Type other) const { return impl == other.impl; }
94 bool operator!=(Type other) const { return !(*this == other); }
95 explicit operator bool() const { return impl; }
96
97 bool operator!() const { return impl == nullptr; }
98
99 /// Return a unique identifier for the concrete type. This is used to support
100 /// dynamic type casting.
101 TypeID getTypeID() { return impl->getAbstractType().getTypeID(); }
102
103 /// Return the MLIRContext in which this type was uniqued.
104 MLIRContext *getContext() const;
105
106 /// Get the dialect this type is registered to.
107 Dialect &getDialect() const { return impl->getAbstractType().getDialect(); }
108
109 // Convenience predicates. This is only for floating point types,
110 // derived types should use isa/dyn_cast.
111 bool isIndex() const;
112 bool isBF16() const;
113 bool isF16() const;
114 bool isTF32() const;
115 bool isF32() const;
116 bool isF64() const;
117 bool isF80() const;
118 bool isF128() const;
119 /// Return true if this is an float type (with the specified width).
120 bool isFloat() const;
121 bool isFloat(unsigned width) const;
122
123 /// Return true if this is an integer type (with the specified width).
124 bool isInteger() const;
125 bool isInteger(unsigned width) const;
126 /// Return true if this is a signless integer type (with the specified width).
127 bool isSignlessInteger() const;
128 bool isSignlessInteger(unsigned width) const;
129 /// Return true if this is a signed integer type (with the specified width).
130 bool isSignedInteger() const;
131 bool isSignedInteger(unsigned width) const;
132 /// Return true if this is an unsigned integer type (with the specified
133 /// width).
134 bool isUnsignedInteger() const;
135 bool isUnsignedInteger(unsigned width) const;
136
137 /// Return the bit width of an integer or a float type, assert failure on
138 /// other types.
139 unsigned getIntOrFloatBitWidth() const;
140
141 /// Return true if this is a signless integer or index type.
142 bool isSignlessIntOrIndex() const;
143 /// Return true if this is a signless integer, index, or float type.
144 bool isSignlessIntOrIndexOrFloat() const;
145 /// Return true of this is a signless integer or a float type.
146 bool isSignlessIntOrFloat() const;
147
148 /// Return true if this is an integer (of any signedness) or an index type.
149 bool isIntOrIndex() const;
150 /// Return true if this is an integer (of any signedness) or a float type.
151 bool isIntOrFloat() const;
152 /// Return true if this is an integer (of any signedness), index, or float
153 /// type.
154 bool isIntOrIndexOrFloat() const;
155
156 /// Print the current type.
157 void print(raw_ostream &os) const;
158 void print(raw_ostream &os, AsmState &state) const;
159 void dump() const;
160
161 friend ::llvm::hash_code hash_value(Type arg);
162
163 /// Methods for supporting PointerLikeTypeTraits.
164 const void *getAsOpaquePointer() const {
165 return static_cast<const void *>(impl);
166 }
167 static Type getFromOpaquePointer(const void *pointer) {
168 return Type(reinterpret_cast<ImplType *>(const_cast<void *>(pointer)));
169 }
170
171 /// Returns true if `InterfaceT` has been promised by the dialect or
172 /// implemented.
173 template <typename InterfaceT>
174 bool hasPromiseOrImplementsInterface() {
175 return dialect_extension_detail::hasPromisedInterface(
176 getDialect(), getTypeID(), InterfaceT::getInterfaceID()) ||
177 mlir::isa<InterfaceT>(*this);
178 }
179
180 /// Returns true if the type was registered with a particular trait.
181 template <template <typename T> class Trait>
182 bool hasTrait() {
183 return getAbstractType().hasTrait<Trait>();
184 }
185
186 /// Return the abstract type descriptor for this type.
187 const AbstractTy &getAbstractType() const { return impl->getAbstractType(); }
188
189 /// Return the Type implementation.
190 ImplType *getImpl() const { return impl; }
191
192 /// Walk all of the immediately nested sub-attributes and sub-types. This
193 /// method does not recurse into sub elements.
194 void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
195 function_ref<void(Type)> walkTypesFn) const {
196 getAbstractType().walkImmediateSubElements(type: *this, walkAttrsFn, walkTypesFn);
197 }
198
199 /// Replace the immediately nested sub-attributes and sub-types with those
200 /// provided. The order of the provided elements is derived from the order of
201 /// the elements returned by the callbacks of `walkImmediateSubElements`. The
202 /// element at index 0 would replace the very first attribute given by
203 /// `walkImmediateSubElements`. On success, the new instance with the values
204 /// replaced is returned. If replacement fails, nullptr is returned.
205 auto replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
206 ArrayRef<Type> replTypes) const {
207 return getAbstractType().replaceImmediateSubElements(type: *this, replAttrs,
208 replTypes);
209 }
210
211 /// Walk this type and all attibutes/types nested within using the
212 /// provided walk functions. See `AttrTypeWalker` for information on the
213 /// supported walk function types.
214 template <WalkOrder Order = WalkOrder::PostOrder, typename... WalkFns>
215 auto walk(WalkFns &&...walkFns) {
216 AttrTypeWalker walker;
217 (walker.addWalk(std::forward<WalkFns>(walkFns)), ...);
218 return walker.walk<Order>(*this);
219 }
220
221 /// Recursively replace all of the nested sub-attributes and sub-types using
222 /// the provided map functions. Returns nullptr in the case of failure. See
223 /// `AttrTypeReplacer` for information on the support replacement function
224 /// types.
225 template <typename... ReplacementFns>
226 auto replace(ReplacementFns &&...replacementFns) {
227 AttrTypeReplacer replacer;
228 (replacer.addReplacement(std::forward<ReplacementFns>(replacementFns)),
229 ...);
230 return replacer.replace(type: *this);
231 }
232
233protected:
234 ImplType *impl{nullptr};
235};
236
237inline raw_ostream &operator<<(raw_ostream &os, Type type) {
238 type.print(os);
239 return os;
240}
241
242//===----------------------------------------------------------------------===//
243// TypeTraitBase
244//===----------------------------------------------------------------------===//
245
246namespace TypeTrait {
247/// This class represents the base of a type trait.
248template <typename ConcreteType, template <typename> class TraitType>
249using TraitBase = detail::StorageUserTraitBase<ConcreteType, TraitType>;
250} // namespace TypeTrait
251
252//===----------------------------------------------------------------------===//
253// TypeInterface
254//===----------------------------------------------------------------------===//
255
256/// This class represents the base of a type interface. See the definition of
257/// `detail::Interface` for requirements on the `Traits` type.
258template <typename ConcreteType, typename Traits>
259class TypeInterface : public detail::Interface<ConcreteType, Type, Traits, Type,
260 TypeTrait::TraitBase> {
261public:
262 using Base = TypeInterface<ConcreteType, Traits>;
263 using InterfaceBase =
264 detail::Interface<ConcreteType, Type, Traits, Type, TypeTrait::TraitBase>;
265 using InterfaceBase::InterfaceBase;
266
267protected:
268 /// Returns the impl interface instance for the given type.
269 static typename InterfaceBase::Concept *getInterfaceFor(Type type) {
270#ifndef NDEBUG
271 // Check that the current interface isn't an unresolved promise for the
272 // given type.
273 dialect_extension_detail::handleUseOfUndefinedPromisedInterface(
274 dialect&: type.getDialect(), interfaceRequestorID: type.getTypeID(), interfaceID: ConcreteType::getInterfaceID(),
275 interfaceName: llvm::getTypeName<ConcreteType>());
276#endif
277
278 return type.getAbstractType().getInterface<ConcreteType>();
279 }
280
281 /// Allow access to 'getInterfaceFor'.
282 friend InterfaceBase;
283};
284
285//===----------------------------------------------------------------------===//
286// Core TypeTrait
287//===----------------------------------------------------------------------===//
288
289/// This trait is used to determine if a type is mutable or not. It is attached
290/// on a type if the corresponding ImplType defines a `mutate` function with
291/// a proper signature.
292namespace TypeTrait {
293template <typename ConcreteType>
294using IsMutable = detail::StorageUserTrait::IsMutable<ConcreteType>;
295} // namespace TypeTrait
296
297//===----------------------------------------------------------------------===//
298// Type Utils
299//===----------------------------------------------------------------------===//
300
301// Make Type hashable.
302inline ::llvm::hash_code hash_value(Type arg) {
303 return DenseMapInfo<const Type::ImplType *>::getHashValue(PtrVal: arg.impl);
304}
305
306} // namespace mlir
307
308namespace llvm {
309
310// Type hash just like pointers.
311template <>
312struct DenseMapInfo<mlir::Type> {
313 static mlir::Type getEmptyKey() {
314 auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
315 return mlir::Type(static_cast<mlir::Type::ImplType *>(pointer));
316 }
317 static mlir::Type getTombstoneKey() {
318 auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
319 return mlir::Type(static_cast<mlir::Type::ImplType *>(pointer));
320 }
321 static unsigned getHashValue(mlir::Type val) { return mlir::hash_value(arg: val); }
322 static bool isEqual(mlir::Type LHS, mlir::Type RHS) { return LHS == RHS; }
323};
324template <typename T>
325struct DenseMapInfo<T, std::enable_if_t<std::is_base_of<mlir::Type, T>::value &&
326 !mlir::detail::IsInterface<T>::value>>
327 : public DenseMapInfo<mlir::Type> {
328 static T getEmptyKey() {
329 const void *pointer = llvm::DenseMapInfo<const void *>::getEmptyKey();
330 return T::getFromOpaquePointer(pointer);
331 }
332 static T getTombstoneKey() {
333 const void *pointer = llvm::DenseMapInfo<const void *>::getTombstoneKey();
334 return T::getFromOpaquePointer(pointer);
335 }
336};
337
338/// We align TypeStorage by 8, so allow LLVM to steal the low bits.
339template <>
340struct PointerLikeTypeTraits<mlir::Type> {
341public:
342 static inline void *getAsVoidPointer(mlir::Type I) {
343 return const_cast<void *>(I.getAsOpaquePointer());
344 }
345 static inline mlir::Type getFromVoidPointer(void *P) {
346 return mlir::Type::getFromOpaquePointer(pointer: P);
347 }
348 static constexpr int NumLowBitsAvailable = 3;
349};
350
351/// Add support for llvm style casts.
352/// We provide a cast between To and From if From is mlir::Type or derives from
353/// it
354template <typename To, typename From>
355struct CastInfo<
356 To, From,
357 std::enable_if_t<std::is_same_v<mlir::Type, std::remove_const_t<From>> ||
358 std::is_base_of_v<mlir::Type, From>>>
359 : NullableValueCastFailed<To>,
360 DefaultDoCastIfPossible<To, From, CastInfo<To, From>> {
361 /// Arguments are taken as mlir::Type here and not as `From`, because when
362 /// casting from an intermediate type of the hierarchy to one of its children,
363 /// the val.getTypeID() inside T::classof will use the static getTypeID of the
364 /// parent instead of the non-static Type::getTypeID that returns the dynamic
365 /// ID. This means that T::classof would end up comparing the static TypeID of
366 /// the children to the static TypeID of its parent, making it impossible to
367 /// downcast from the parent to the child.
368 static inline bool isPossible(mlir::Type ty) {
369 /// Return a constant true instead of a dynamic true when casting to self or
370 /// up the hierarchy.
371 if constexpr (std::is_base_of_v<To, From>) {
372 return true;
373 } else {
374 return To::classof(ty);
375 };
376 }
377 static inline To doCast(mlir::Type ty) { return To(ty.getImpl()); }
378};
379
380} // namespace llvm
381
382#endif // MLIR_IR_TYPES_H
383

// Source: mlir/include/mlir/IR/Types.h