1 | //===- TypeSupport.h --------------------------------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines support types for registering dialect extended types. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #ifndef MLIR_IR_TYPESUPPORT_H |
14 | #define MLIR_IR_TYPESUPPORT_H |
15 | |
16 | #include "mlir/IR/MLIRContext.h" |
17 | #include "mlir/IR/StorageUniquerSupport.h" |
18 | #include "llvm/ADT/Twine.h" |
19 | |
20 | namespace mlir { |
21 | class Dialect; |
22 | class MLIRContext; |
23 | |
24 | //===----------------------------------------------------------------------===// |
25 | // AbstractType |
26 | //===----------------------------------------------------------------------===// |
27 | |
28 | /// This class contains all of the static information common to all instances of |
29 | /// a registered Type. |
30 | class AbstractType { |
31 | public: |
32 | using HasTraitFn = llvm::unique_function<bool(TypeID) const>; |
33 | using WalkImmediateSubElementsFn = function_ref<void( |
34 | Type, function_ref<void(Attribute)>, function_ref<void(Type)>)>; |
35 | using ReplaceImmediateSubElementsFn = |
36 | function_ref<Type(Type, ArrayRef<Attribute>, ArrayRef<Type>)>; |
37 | |
38 | /// Look up the specified abstract type in the MLIRContext and return a |
39 | /// reference to it. |
40 | static const AbstractType &lookup(TypeID typeID, MLIRContext *context); |
41 | |
42 | /// Look up the specified abstract type in the MLIRContext and return a |
43 | /// reference to it if it exists. |
44 | static std::optional<std::reference_wrapper<const AbstractType>> |
45 | lookup(StringRef name, MLIRContext *context); |
46 | |
47 | /// This method is used by Dialect objects when they register the list of |
48 | /// types they contain. |
49 | template <typename T> |
50 | static AbstractType get(Dialect &dialect) { |
51 | return AbstractType(dialect, T::getInterfaceMap(), T::getHasTraitFn(), |
52 | T::getWalkImmediateSubElementsFn(), |
53 | T::getReplaceImmediateSubElementsFn(), T::getTypeID(), |
54 | T::name); |
55 | } |
56 | |
57 | /// This method is used by Dialect objects to register types with |
58 | /// custom TypeIDs. |
59 | /// The use of this method is in general discouraged in favor of |
60 | /// 'get<CustomType>(dialect)'; |
61 | static AbstractType |
62 | get(Dialect &dialect, detail::InterfaceMap &&interfaceMap, |
63 | HasTraitFn &&hasTrait, |
64 | WalkImmediateSubElementsFn walkImmediateSubElementsFn, |
65 | ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn, |
66 | TypeID typeID, StringRef name) { |
67 | return AbstractType(dialect, std::move(interfaceMap), std::move(hasTrait), |
68 | walkImmediateSubElementsFn, |
69 | replaceImmediateSubElementsFn, typeID, name); |
70 | } |
71 | |
72 | /// Return the dialect this type was registered to. |
73 | Dialect &getDialect() const { return const_cast<Dialect &>(dialect); } |
74 | |
75 | /// Returns an instance of the concept object for the given interface if it |
76 | /// was registered to this type, null otherwise. This should not be used |
77 | /// directly. |
78 | template <typename T> |
79 | typename T::Concept *getInterface() const { |
80 | return interfaceMap.lookup<T>(); |
81 | } |
82 | |
83 | /// Returns true if the type has the interface with the given ID. |
84 | bool hasInterface(TypeID interfaceID) const { |
85 | return interfaceMap.contains(interfaceID); |
86 | } |
87 | |
88 | /// Returns true if the type has a particular trait. |
89 | template <template <typename T> class Trait> |
90 | bool hasTrait() const { |
91 | return hasTraitFn(TypeID::get<Trait>()); |
92 | } |
93 | |
94 | /// Returns true if the type has a particular trait. |
95 | bool hasTrait(TypeID traitID) const { return hasTraitFn(traitID); } |
96 | |
97 | /// Walk the immediate sub-elements of the given type. |
98 | void walkImmediateSubElements(Type type, |
99 | function_ref<void(Attribute)> walkAttrsFn, |
100 | function_ref<void(Type)> walkTypesFn) const; |
101 | |
102 | /// Replace the immediate sub-elements of the given type. |
103 | Type replaceImmediateSubElements(Type type, ArrayRef<Attribute> replAttrs, |
104 | ArrayRef<Type> replTypes) const; |
105 | |
106 | /// Return the unique identifier representing the concrete type class. |
107 | TypeID getTypeID() const { return typeID; } |
108 | |
109 | /// Return the unique name representing the type. |
110 | StringRef getName() const { return name; } |
111 | |
112 | private: |
113 | AbstractType(Dialect &dialect, detail::InterfaceMap &&interfaceMap, |
114 | HasTraitFn &&hasTrait, |
115 | WalkImmediateSubElementsFn walkImmediateSubElementsFn, |
116 | ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn, |
117 | TypeID typeID, StringRef name) |
118 | : dialect(dialect), interfaceMap(std::move(interfaceMap)), |
119 | hasTraitFn(std::move(hasTrait)), |
120 | walkImmediateSubElementsFn(walkImmediateSubElementsFn), |
121 | replaceImmediateSubElementsFn(replaceImmediateSubElementsFn), |
122 | typeID(typeID), name(name) {} |
123 | |
124 | /// Give StorageUserBase access to the mutable lookup. |
125 | template <typename ConcreteT, typename BaseT, typename StorageT, |
126 | typename UniquerT, template <typename T> class... Traits> |
127 | friend class detail::StorageUserBase; |
128 | |
129 | /// Look up the specified abstract type in the MLIRContext and return a |
130 | /// (mutable) pointer to it. Return a null pointer if the type could not |
131 | /// be found in the context. |
132 | static AbstractType *lookupMutable(TypeID typeID, MLIRContext *context); |
133 | |
134 | /// This is the dialect that this type was registered to. |
135 | const Dialect &dialect; |
136 | |
137 | /// This is a collection of the interfaces registered to this type. |
138 | detail::InterfaceMap interfaceMap; |
139 | |
140 | /// Function to check if the type has a particular trait. |
141 | HasTraitFn hasTraitFn; |
142 | |
143 | /// Function to walk the immediate sub-elements of this type. |
144 | WalkImmediateSubElementsFn walkImmediateSubElementsFn; |
145 | |
146 | /// Function to replace the immediate sub-elements of this type. |
147 | ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn; |
148 | |
149 | /// The unique identifier of the derived Type class. |
150 | const TypeID typeID; |
151 | |
152 | /// The unique name of this type. The string is not owned by the context, so |
153 | /// The lifetime of this string should outlive the MLIR context. |
154 | const StringRef name; |
155 | }; |
156 | |
157 | //===----------------------------------------------------------------------===// |
158 | // TypeStorage |
159 | //===----------------------------------------------------------------------===// |
160 | |
161 | namespace detail { |
162 | struct TypeUniquer; |
163 | } // namespace detail |
164 | |
165 | /// Base storage class appearing in a Type. |
166 | class TypeStorage : public StorageUniquer::BaseStorage { |
167 | friend detail::TypeUniquer; |
168 | friend StorageUniquer; |
169 | |
170 | public: |
171 | /// Return the abstract type descriptor for this type. |
172 | const AbstractType &getAbstractType() { |
173 | assert(abstractType && "Malformed type storage object." ); |
174 | return *abstractType; |
175 | } |
176 | |
177 | protected: |
178 | /// This constructor is used by derived classes as part of the TypeUniquer. |
179 | TypeStorage() {} |
180 | |
181 | private: |
182 | /// Set the abstract type for this storage instance. This is used by the |
183 | /// TypeUniquer when initializing a newly constructed type storage object. |
184 | void initialize(const AbstractType &abstractTy) { |
185 | abstractType = const_cast<AbstractType *>(&abstractTy); |
186 | } |
187 | |
188 | /// The abstract description for this type. |
189 | AbstractType *abstractType{nullptr}; |
190 | }; |
191 | |
192 | /// Default storage type for types that require no additional initialization or |
193 | /// storage. |
194 | using DefaultTypeStorage = TypeStorage; |
195 | |
196 | //===----------------------------------------------------------------------===// |
197 | // TypeStorageAllocator |
198 | //===----------------------------------------------------------------------===// |
199 | |
200 | /// This is a utility allocator used to allocate memory for instances of derived |
201 | /// Types. |
202 | using TypeStorageAllocator = StorageUniquer::StorageAllocator; |
203 | |
204 | //===----------------------------------------------------------------------===// |
205 | // TypeUniquer |
206 | //===----------------------------------------------------------------------===// |
207 | namespace detail { |
208 | /// A utility class to get, or create, unique instances of types within an |
209 | /// MLIRContext. This class manages all creation and uniquing of types. |
210 | struct TypeUniquer { |
211 | /// Get an uniqued instance of a type T. |
212 | template <typename T, typename... Args> |
213 | static T get(MLIRContext *ctx, Args &&...args) { |
214 | return getWithTypeID<T, Args...>(ctx, T::getTypeID(), |
215 | std::forward<Args>(args)...); |
216 | } |
217 | |
218 | /// Get an uniqued instance of a parametric type T. |
219 | /// The use of this method is in general discouraged in favor of |
220 | /// 'get<T, Args>(ctx, args)'. |
221 | template <typename T, typename... Args> |
222 | static std::enable_if_t< |
223 | !std::is_same<typename T::ImplType, TypeStorage>::value, T> |
224 | getWithTypeID(MLIRContext *ctx, TypeID typeID, Args &&...args) { |
225 | #ifndef NDEBUG |
226 | if (!ctx->getTypeUniquer().isParametricStorageInitialized(id: typeID)) |
227 | llvm::report_fatal_error( |
228 | llvm::Twine("can't create type '" ) + llvm::getTypeName<T>() + |
229 | "' because storage uniquer isn't initialized: the dialect was likely " |
230 | "not loaded, or the type wasn't added with addTypes<...>() " |
231 | "in the Dialect::initialize() method." ); |
232 | #endif |
233 | return ctx->getTypeUniquer().get<typename T::ImplType>( |
234 | [&, typeID](TypeStorage *storage) { |
235 | storage->initialize(abstractTy: AbstractType::lookup(typeID, context: ctx)); |
236 | }, |
237 | typeID, std::forward<Args>(args)...); |
238 | } |
239 | /// Get an uniqued instance of a singleton type T. |
240 | /// The use of this method is in general discouraged in favor of |
241 | /// 'get<T, Args>(ctx, args)'. |
242 | template <typename T> |
243 | static std::enable_if_t< |
244 | std::is_same<typename T::ImplType, TypeStorage>::value, T> |
245 | getWithTypeID(MLIRContext *ctx, TypeID typeID) { |
246 | #ifndef NDEBUG |
247 | if (!ctx->getTypeUniquer().isSingletonStorageInitialized(id: typeID)) |
248 | llvm::report_fatal_error( |
249 | llvm::Twine("can't create type '" ) + llvm::getTypeName<T>() + |
250 | "' because storage uniquer isn't initialized: the dialect was likely " |
251 | "not loaded, or the type wasn't added with addTypes<...>() " |
252 | "in the Dialect::initialize() method." ); |
253 | #endif |
254 | return ctx->getTypeUniquer().get<typename T::ImplType>(typeID); |
255 | } |
256 | |
257 | /// Change the mutable component of the given type instance in the provided |
258 | /// context. |
259 | template <typename T, typename... Args> |
260 | static LogicalResult mutate(MLIRContext *ctx, typename T::ImplType *impl, |
261 | Args &&...args) { |
262 | assert(impl && "cannot mutate null type" ); |
263 | return ctx->getTypeUniquer().mutate(T::getTypeID(), impl, |
264 | std::forward<Args>(args)...); |
265 | } |
266 | |
267 | /// Register a type instance T with the uniquer. |
268 | template <typename T> |
269 | static void registerType(MLIRContext *ctx) { |
270 | registerType<T>(ctx, T::getTypeID()); |
271 | } |
272 | |
273 | /// Register a parametric type instance T with the uniquer. |
274 | /// The use of this method is in general discouraged in favor of |
275 | /// 'registerType<T>(ctx)'. |
276 | template <typename T> |
277 | static std::enable_if_t< |
278 | !std::is_same<typename T::ImplType, TypeStorage>::value> |
279 | registerType(MLIRContext *ctx, TypeID typeID) { |
280 | ctx->getTypeUniquer().registerParametricStorageType<typename T::ImplType>( |
281 | typeID); |
282 | } |
283 | /// Register a singleton type instance T with the uniquer. |
284 | /// The use of this method is in general discouraged in favor of |
285 | /// 'registerType<T>(ctx)'. |
286 | template <typename T> |
287 | static std::enable_if_t< |
288 | std::is_same<typename T::ImplType, TypeStorage>::value> |
289 | registerType(MLIRContext *ctx, TypeID typeID) { |
290 | ctx->getTypeUniquer().registerSingletonStorageType<TypeStorage>( |
291 | typeID, [&ctx, typeID](TypeStorage *storage) { |
292 | storage->initialize(abstractTy: AbstractType::lookup(typeID, context: ctx)); |
293 | }); |
294 | } |
295 | }; |
296 | } // namespace detail |
297 | |
298 | } // namespace mlir |
299 | |
300 | #endif |
301 | |