1//===- TypeSupport.h --------------------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines support types for registering dialect extended types.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef MLIR_IR_TYPESUPPORT_H
14#define MLIR_IR_TYPESUPPORT_H
15
16#include "mlir/IR/MLIRContext.h"
17#include "mlir/IR/StorageUniquerSupport.h"
18#include "llvm/ADT/Twine.h"
19
20namespace mlir {
// Forward declarations of classes referenced by the declarations below.
class Dialect;
class MLIRContext;
23
24//===----------------------------------------------------------------------===//
25// AbstractType
26//===----------------------------------------------------------------------===//
27
28/// This class contains all of the static information common to all instances of
29/// a registered Type.
30class AbstractType {
31public:
32 using HasTraitFn = llvm::unique_function<bool(TypeID) const>;
33 using WalkImmediateSubElementsFn = function_ref<void(
34 Type, function_ref<void(Attribute)>, function_ref<void(Type)>)>;
35 using ReplaceImmediateSubElementsFn =
36 function_ref<Type(Type, ArrayRef<Attribute>, ArrayRef<Type>)>;
37
38 /// Look up the specified abstract type in the MLIRContext and return a
39 /// reference to it.
40 static const AbstractType &lookup(TypeID typeID, MLIRContext *context);
41
42 /// Look up the specified abstract type in the MLIRContext and return a
43 /// reference to it if it exists.
44 static std::optional<std::reference_wrapper<const AbstractType>>
45 lookup(StringRef name, MLIRContext *context);
46
47 /// This method is used by Dialect objects when they register the list of
48 /// types they contain.
49 template <typename T>
50 static AbstractType get(Dialect &dialect) {
51 return AbstractType(dialect, T::getInterfaceMap(), T::getHasTraitFn(),
52 T::getWalkImmediateSubElementsFn(),
53 T::getReplaceImmediateSubElementsFn(), T::getTypeID(),
54 T::name);
55 }
56
57 /// This method is used by Dialect objects to register types with
58 /// custom TypeIDs.
59 /// The use of this method is in general discouraged in favor of
60 /// 'get<CustomType>(dialect)';
61 static AbstractType
62 get(Dialect &dialect, detail::InterfaceMap &&interfaceMap,
63 HasTraitFn &&hasTrait,
64 WalkImmediateSubElementsFn walkImmediateSubElementsFn,
65 ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn,
66 TypeID typeID, StringRef name) {
67 return AbstractType(dialect, std::move(interfaceMap), std::move(hasTrait),
68 walkImmediateSubElementsFn,
69 replaceImmediateSubElementsFn, typeID, name);
70 }
71
72 /// Return the dialect this type was registered to.
73 Dialect &getDialect() const { return const_cast<Dialect &>(dialect); }
74
75 /// Returns an instance of the concept object for the given interface if it
76 /// was registered to this type, null otherwise. This should not be used
77 /// directly.
78 template <typename T>
79 typename T::Concept *getInterface() const {
80 return interfaceMap.lookup<T>();
81 }
82
83 /// Returns true if the type has the interface with the given ID.
84 bool hasInterface(TypeID interfaceID) const {
85 return interfaceMap.contains(interfaceID);
86 }
87
88 /// Returns true if the type has a particular trait.
89 template <template <typename T> class Trait>
90 bool hasTrait() const {
91 return hasTraitFn(TypeID::get<Trait>());
92 }
93
94 /// Returns true if the type has a particular trait.
95 bool hasTrait(TypeID traitID) const { return hasTraitFn(traitID); }
96
97 /// Walk the immediate sub-elements of the given type.
98 void walkImmediateSubElements(Type type,
99 function_ref<void(Attribute)> walkAttrsFn,
100 function_ref<void(Type)> walkTypesFn) const;
101
102 /// Replace the immediate sub-elements of the given type.
103 Type replaceImmediateSubElements(Type type, ArrayRef<Attribute> replAttrs,
104 ArrayRef<Type> replTypes) const;
105
106 /// Return the unique identifier representing the concrete type class.
107 TypeID getTypeID() const { return typeID; }
108
109 /// Return the unique name representing the type.
110 StringRef getName() const { return name; }
111
112private:
113 AbstractType(Dialect &dialect, detail::InterfaceMap &&interfaceMap,
114 HasTraitFn &&hasTrait,
115 WalkImmediateSubElementsFn walkImmediateSubElementsFn,
116 ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn,
117 TypeID typeID, StringRef name)
118 : dialect(dialect), interfaceMap(std::move(interfaceMap)),
119 hasTraitFn(std::move(hasTrait)),
120 walkImmediateSubElementsFn(walkImmediateSubElementsFn),
121 replaceImmediateSubElementsFn(replaceImmediateSubElementsFn),
122 typeID(typeID), name(name) {}
123
124 /// Give StorageUserBase access to the mutable lookup.
125 template <typename ConcreteT, typename BaseT, typename StorageT,
126 typename UniquerT, template <typename T> class... Traits>
127 friend class detail::StorageUserBase;
128
129 /// Look up the specified abstract type in the MLIRContext and return a
130 /// (mutable) pointer to it. Return a null pointer if the type could not
131 /// be found in the context.
132 static AbstractType *lookupMutable(TypeID typeID, MLIRContext *context);
133
134 /// This is the dialect that this type was registered to.
135 const Dialect &dialect;
136
137 /// This is a collection of the interfaces registered to this type.
138 detail::InterfaceMap interfaceMap;
139
140 /// Function to check if the type has a particular trait.
141 HasTraitFn hasTraitFn;
142
143 /// Function to walk the immediate sub-elements of this type.
144 WalkImmediateSubElementsFn walkImmediateSubElementsFn;
145
146 /// Function to replace the immediate sub-elements of this type.
147 ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn;
148
149 /// The unique identifier of the derived Type class.
150 const TypeID typeID;
151
152 /// The unique name of this type. The string is not owned by the context, so
153 /// The lifetime of this string should outlive the MLIR context.
154 const StringRef name;
155};
156
157//===----------------------------------------------------------------------===//
158// TypeStorage
159//===----------------------------------------------------------------------===//
160
namespace detail {
// Forward declaration so that TypeStorage can name TypeUniquer as a friend
// before the struct itself is defined.
struct TypeUniquer;
} // namespace detail
164
165/// Base storage class appearing in a Type.
166class TypeStorage : public StorageUniquer::BaseStorage {
167 friend detail::TypeUniquer;
168 friend StorageUniquer;
169
170public:
171 /// Return the abstract type descriptor for this type.
172 const AbstractType &getAbstractType() {
173 assert(abstractType && "Malformed type storage object.");
174 return *abstractType;
175 }
176
177protected:
178 /// This constructor is used by derived classes as part of the TypeUniquer.
179 TypeStorage() {}
180
181private:
182 /// Set the abstract type for this storage instance. This is used by the
183 /// TypeUniquer when initializing a newly constructed type storage object.
184 void initialize(const AbstractType &abstractTy) {
185 abstractType = const_cast<AbstractType *>(&abstractTy);
186 }
187
188 /// The abstract description for this type.
189 AbstractType *abstractType{nullptr};
190};
191
192/// Default storage type for types that require no additional initialization or
193/// storage.
194using DefaultTypeStorage = TypeStorage;
195
196//===----------------------------------------------------------------------===//
197// TypeStorageAllocator
198//===----------------------------------------------------------------------===//
199
200/// This is a utility allocator used to allocate memory for instances of derived
201/// Types.
202using TypeStorageAllocator = StorageUniquer::StorageAllocator;
203
204//===----------------------------------------------------------------------===//
205// TypeUniquer
206//===----------------------------------------------------------------------===//
207namespace detail {
208/// A utility class to get, or create, unique instances of types within an
209/// MLIRContext. This class manages all creation and uniquing of types.
210struct TypeUniquer {
211 /// Get an uniqued instance of a type T.
212 template <typename T, typename... Args>
213 static T get(MLIRContext *ctx, Args &&...args) {
214 return getWithTypeID<T, Args...>(ctx, T::getTypeID(),
215 std::forward<Args>(args)...);
216 }
217
218 /// Get an uniqued instance of a parametric type T.
219 /// The use of this method is in general discouraged in favor of
220 /// 'get<T, Args>(ctx, args)'.
221 template <typename T, typename... Args>
222 static std::enable_if_t<
223 !std::is_same<typename T::ImplType, TypeStorage>::value, T>
224 getWithTypeID(MLIRContext *ctx, TypeID typeID, Args &&...args) {
225#ifndef NDEBUG
226 if (!ctx->getTypeUniquer().isParametricStorageInitialized(id: typeID))
227 llvm::report_fatal_error(
228 llvm::Twine("can't create type '") + llvm::getTypeName<T>() +
229 "' because storage uniquer isn't initialized: the dialect was likely "
230 "not loaded, or the type wasn't added with addTypes<...>() "
231 "in the Dialect::initialize() method.");
232#endif
233 return ctx->getTypeUniquer().get<typename T::ImplType>(
234 [&, typeID](TypeStorage *storage) {
235 storage->initialize(abstractTy: AbstractType::lookup(typeID, context: ctx));
236 },
237 typeID, std::forward<Args>(args)...);
238 }
239 /// Get an uniqued instance of a singleton type T.
240 /// The use of this method is in general discouraged in favor of
241 /// 'get<T, Args>(ctx, args)'.
242 template <typename T>
243 static std::enable_if_t<
244 std::is_same<typename T::ImplType, TypeStorage>::value, T>
245 getWithTypeID(MLIRContext *ctx, TypeID typeID) {
246#ifndef NDEBUG
247 if (!ctx->getTypeUniquer().isSingletonStorageInitialized(id: typeID))
248 llvm::report_fatal_error(
249 llvm::Twine("can't create type '") + llvm::getTypeName<T>() +
250 "' because storage uniquer isn't initialized: the dialect was likely "
251 "not loaded, or the type wasn't added with addTypes<...>() "
252 "in the Dialect::initialize() method.");
253#endif
254 return ctx->getTypeUniquer().get<typename T::ImplType>(typeID);
255 }
256
257 /// Change the mutable component of the given type instance in the provided
258 /// context.
259 template <typename T, typename... Args>
260 static LogicalResult mutate(MLIRContext *ctx, typename T::ImplType *impl,
261 Args &&...args) {
262 assert(impl && "cannot mutate null type");
263 return ctx->getTypeUniquer().mutate(T::getTypeID(), impl,
264 std::forward<Args>(args)...);
265 }
266
267 /// Register a type instance T with the uniquer.
268 template <typename T>
269 static void registerType(MLIRContext *ctx) {
270 registerType<T>(ctx, T::getTypeID());
271 }
272
273 /// Register a parametric type instance T with the uniquer.
274 /// The use of this method is in general discouraged in favor of
275 /// 'registerType<T>(ctx)'.
276 template <typename T>
277 static std::enable_if_t<
278 !std::is_same<typename T::ImplType, TypeStorage>::value>
279 registerType(MLIRContext *ctx, TypeID typeID) {
280 ctx->getTypeUniquer().registerParametricStorageType<typename T::ImplType>(
281 typeID);
282 }
283 /// Register a singleton type instance T with the uniquer.
284 /// The use of this method is in general discouraged in favor of
285 /// 'registerType<T>(ctx)'.
286 template <typename T>
287 static std::enable_if_t<
288 std::is_same<typename T::ImplType, TypeStorage>::value>
289 registerType(MLIRContext *ctx, TypeID typeID) {
290 ctx->getTypeUniquer().registerSingletonStorageType<TypeStorage>(
291 typeID, [&ctx, typeID](TypeStorage *storage) {
292 storage->initialize(abstractTy: AbstractType::lookup(typeID, context: ctx));
293 });
294 }
295};
296} // namespace detail
297
298} // namespace mlir
299
300#endif
301

// Source: mlir/include/mlir/IR/TypeSupport.h