1 | //===- StorageUniquer.h - Common Storage Class Uniquer ----------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef MLIR_SUPPORT_STORAGEUNIQUER_H |
10 | #define MLIR_SUPPORT_STORAGEUNIQUER_H |
11 | |
12 | #include "mlir/Support/LLVM.h" |
13 | #include "mlir/Support/LogicalResult.h" |
14 | #include "mlir/Support/TypeID.h" |
15 | #include "llvm/ADT/ArrayRef.h" |
16 | #include "llvm/ADT/DenseSet.h" |
17 | #include "llvm/ADT/StringRef.h" |
18 | #include "llvm/Support/Allocator.h" |
19 | #include <utility> |
20 | |
21 | namespace mlir { |
namespace detail {
/// Opaque implementation object owned by StorageUniquer; its definition lives
/// outside this header.
struct StorageUniquerImpl;

/// Detection alias naming the result type of `ImplTy::getKey(ArgTs...)`.
/// Substitution fails (so `llvm::is_detected` reports false) when ImplTy has
/// no static `getKey` callable with arguments of types `ArgTs`.
template <class ImplTy, class... ArgTs>
using has_impltype_getkey_t =
    decltype(ImplTy::getKey(std::declval<ArgTs>()...));

/// Detection alias naming the result type of `ImplTy::hashKey(KeyT)`.
/// Substitution fails when ImplTy has no static `hashKey` callable with a
/// `KeyT` argument.
template <class ImplTy, class KeyT>
using has_impltype_hash_t =
    decltype(ImplTy::hashKey(std::declval<KeyT>()));
} // namespace detail
33 | |
34 | /// A utility class to get or create instances of "storage classes". These |
35 | /// storage classes must derive from 'StorageUniquer::BaseStorage'. |
36 | /// |
37 | /// For non-parametric storage classes, i.e. singleton classes, nothing else is |
38 | /// needed. Instances of these classes can be created by calling `get` without |
39 | /// trailing arguments. |
40 | /// |
41 | /// Otherwise, the parametric storage classes may be created with `get`, |
42 | /// and must respect the following: |
43 | /// - Define a type alias, KeyTy, to a type that uniquely identifies the |
44 | /// instance of the storage class. |
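///   * The key type must be constructible from the values passed into the
///     `get` call, unless the storage class provides a static `getKey`
///     method accepting those values, in which case `getKey` builds the key.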
47 | /// * If the KeyTy does not have an llvm::DenseMapInfo specialization, the |
48 | /// storage class must define a hashing method: |
49 | /// 'static unsigned hashKey(const KeyTy &)' |
50 | /// |
51 | /// - Provide a method, 'bool operator==(const KeyTy &) const', to |
52 | /// compare the storage instance against an instance of the key type. |
53 | /// |
54 | /// - Provide a static construction method: |
55 | /// 'DerivedStorage *construct(StorageAllocator &, const KeyTy &key)' |
56 | /// that builds a unique instance of the derived storage. The arguments to |
57 | /// this function are an allocator to store any uniqued data and the key |
58 | /// type for this storage. |
59 | /// |
60 | /// - Provide a cleanup method: |
61 | /// 'void cleanup()' |
62 | /// that is called when erasing a storage instance. This should cleanup any |
63 | /// fields of the storage as necessary and not attempt to free the memory |
64 | /// of the storage itself. |
65 | /// |
66 | /// Storage classes may have an optional mutable component, which must not take |
67 | /// part in the unique immutable key. In this case, storage classes may be |
68 | /// mutated with `mutate` and must additionally respect the following: |
69 | /// - Provide a mutation method: |
70 | /// 'LogicalResult mutate(StorageAllocator &, <...>)' |
71 | /// that is called when mutating a storage instance. The first argument is |
72 | /// an allocator to store any mutable data, and the remaining arguments are |
73 | /// forwarded from the call site. The storage can be mutated at any time |
74 | /// after creation. Care must be taken to avoid excessive mutation since |
75 | /// the allocated storage can keep containing previous states. The return |
76 | /// value of the function is used to indicate whether the mutation was |
77 | /// successful, e.g., to limit the number of mutations or enable deferred |
78 | /// one-time assignment of the mutable component. |
79 | /// |
80 | /// All storage classes must be registered with the uniquer via |
81 | /// `registerParametricStorageType` or `registerSingletonStorageType` |
82 | /// using an appropriate unique `TypeID` for the storage class. |
83 | class StorageUniquer { |
84 | public: |
85 | /// This class acts as the base storage that all storage classes must derived |
86 | /// from. |
87 | class alignas(8) BaseStorage { |
88 | protected: |
89 | BaseStorage() = default; |
90 | }; |
91 | |
92 | /// This is a utility allocator used to allocate memory for instances of |
93 | /// derived types. |
94 | class StorageAllocator { |
95 | public: |
96 | /// Copy the specified array of elements into memory managed by our bump |
97 | /// pointer allocator. This assumes the elements are all PODs. |
98 | template <typename T> |
99 | ArrayRef<T> copyInto(ArrayRef<T> elements) { |
100 | if (elements.empty()) |
101 | return std::nullopt; |
102 | auto result = allocator.Allocate<T>(elements.size()); |
103 | std::uninitialized_copy(elements.begin(), elements.end(), result); |
104 | return ArrayRef<T>(result, elements.size()); |
105 | } |
106 | |
107 | /// Copy the provided string into memory managed by our bump pointer |
108 | /// allocator. |
109 | StringRef copyInto(StringRef str) { |
110 | if (str.empty()) |
111 | return StringRef(); |
112 | |
113 | char *result = allocator.Allocate<char>(Num: str.size() + 1); |
114 | std::uninitialized_copy(first: str.begin(), last: str.end(), result: result); |
115 | result[str.size()] = 0; |
116 | return StringRef(result, str.size()); |
117 | } |
118 | |
119 | /// Allocate an instance of the provided type. |
120 | template <typename T> |
121 | T *allocate() { |
122 | return allocator.Allocate<T>(); |
123 | } |
124 | |
125 | /// Allocate 'size' bytes of 'alignment' aligned memory. |
126 | void *allocate(size_t size, size_t alignment) { |
127 | return allocator.Allocate(Size: size, Alignment: alignment); |
128 | } |
129 | |
130 | /// Returns true if this allocator allocated the provided object pointer. |
131 | bool allocated(const void *ptr) { |
132 | return allocator.identifyObject(Ptr: ptr).has_value(); |
133 | } |
134 | |
135 | private: |
136 | /// The raw allocator for type storage objects. |
137 | llvm::BumpPtrAllocator allocator; |
138 | }; |
139 | |
140 | StorageUniquer(); |
141 | ~StorageUniquer(); |
142 | |
143 | /// Set the flag specifying if multi-threading is disabled within the uniquer. |
144 | void disableMultithreading(bool disable = true); |
145 | |
146 | /// Register a new parametric storage class, this is necessary to create |
147 | /// instances of this class type. `id` is the type identifier that will be |
148 | /// used to identify this type when creating instances of it via 'get'. |
149 | template <typename Storage> |
150 | void registerParametricStorageType(TypeID id) { |
151 | // If the storage is trivially destructible, we don't need a destructor |
152 | // function. |
153 | if constexpr (std::is_trivially_destructible_v<Storage>) |
154 | return registerParametricStorageTypeImpl(id, destructorFn: nullptr); |
155 | registerParametricStorageTypeImpl(id, destructorFn: [](BaseStorage *storage) { |
156 | static_cast<Storage *>(storage)->~Storage(); |
157 | }); |
158 | } |
159 | /// Utility override when the storage type represents the type id. |
160 | template <typename Storage> |
161 | void registerParametricStorageType() { |
162 | registerParametricStorageType<Storage>(TypeID::get<Storage>()); |
163 | } |
164 | /// Register a new singleton storage class, this is necessary to get the |
165 | /// singletone instance. `id` is the type identifier that will be used to |
166 | /// access the singleton instance via 'get'. An optional initialization |
167 | /// function may also be provided to initialize the newly created storage |
168 | /// instance, and used when the singleton instance is created. |
169 | template <typename Storage> |
170 | void registerSingletonStorageType(TypeID id, |
171 | function_ref<void(Storage *)> initFn) { |
172 | auto ctorFn = [&](StorageAllocator &allocator) { |
173 | auto *storage = new (allocator.allocate<Storage>()) Storage(); |
174 | if (initFn) |
175 | initFn(storage); |
176 | return storage; |
177 | }; |
178 | registerSingletonImpl(id, ctorFn); |
179 | } |
180 | template <typename Storage> |
181 | void registerSingletonStorageType(TypeID id) { |
182 | registerSingletonStorageType<Storage>(id, std::nullopt); |
183 | } |
184 | /// Utility override when the storage type represents the type id. |
185 | template <typename Storage> |
186 | void registerSingletonStorageType(function_ref<void(Storage *)> initFn = {}) { |
187 | registerSingletonStorageType<Storage>(TypeID::get<Storage>(), initFn); |
188 | } |
189 | |
190 | /// Gets a uniqued instance of 'Storage'. 'id' is the type id used when |
191 | /// registering the storage instance. 'initFn' is an optional parameter that |
192 | /// can be used to initialize a newly inserted storage instance. This function |
193 | /// is used for derived types that have complex storage or uniquing |
194 | /// constraints. |
195 | template <typename Storage, typename... Args> |
196 | Storage *get(function_ref<void(Storage *)> initFn, TypeID id, |
197 | Args &&...args) { |
198 | // Construct a value of the derived key type. |
199 | auto derivedKey = getKey<Storage>(std::forward<Args>(args)...); |
200 | |
201 | // Create a hash of the derived key. |
202 | unsigned hashValue = getHash<Storage>(derivedKey); |
203 | |
204 | // Generate an equality function for the derived storage. |
205 | auto isEqual = [&derivedKey](const BaseStorage *existing) { |
206 | return static_cast<const Storage &>(*existing) == derivedKey; |
207 | }; |
208 | |
209 | // Generate a constructor function for the derived storage. |
210 | auto ctorFn = [&](StorageAllocator &allocator) { |
211 | auto *storage = Storage::construct(allocator, std::move(derivedKey)); |
212 | if (initFn) |
213 | initFn(storage); |
214 | return storage; |
215 | }; |
216 | |
217 | // Get an instance for the derived storage. |
218 | return static_cast<Storage *>( |
219 | getParametricStorageTypeImpl(id, hashValue, isEqual, ctorFn)); |
220 | } |
221 | /// Utility override when the storage type represents the type id. |
222 | template <typename Storage, typename... Args> |
223 | Storage *get(function_ref<void(Storage *)> initFn, Args &&...args) { |
224 | return get<Storage>(initFn, TypeID::get<Storage>(), |
225 | std::forward<Args>(args)...); |
226 | } |
227 | |
228 | /// Gets a uniqued instance of 'Storage' which is a singleton storage type. |
229 | /// 'id' is the type id used when registering the storage instance. |
230 | template <typename Storage> |
231 | Storage *get(TypeID id) { |
232 | return static_cast<Storage *>(getSingletonImpl(id)); |
233 | } |
234 | /// Utility override when the storage type represents the type id. |
235 | template <typename Storage> |
236 | Storage *get() { |
237 | return get<Storage>(TypeID::get<Storage>()); |
238 | } |
239 | |
240 | /// Test if there is a singleton storage uniquer initialized for the provided |
241 | /// TypeID. This is only useful for debugging/diagnostic purpose: the uniquer |
242 | /// is initialized when a dialect is loaded. |
243 | bool isSingletonStorageInitialized(TypeID id); |
244 | |
245 | /// Test if there is a parametric storage uniquer initialized for the provided |
246 | /// TypeID. This is only useful for debugging/diagnostic purpose: the uniquer |
247 | /// is initialized when a dialect is loaded. |
248 | bool isParametricStorageInitialized(TypeID id); |
249 | |
250 | /// Changes the mutable component of 'storage' by forwarding the trailing |
251 | /// arguments to the 'mutate' function of the derived class. |
252 | template <typename Storage, typename... Args> |
253 | LogicalResult mutate(TypeID id, Storage *storage, Args &&...args) { |
254 | auto mutationFn = [&](StorageAllocator &allocator) -> LogicalResult { |
255 | return static_cast<Storage &>(*storage).mutate( |
256 | allocator, std::forward<Args>(args)...); |
257 | }; |
258 | return mutateImpl(id, storage, mutationFn); |
259 | } |
260 | |
261 | private: |
262 | /// Implementation for getting/creating an instance of a derived type with |
263 | /// parametric storage. |
264 | BaseStorage *getParametricStorageTypeImpl( |
265 | TypeID id, unsigned hashValue, |
266 | function_ref<bool(const BaseStorage *)> isEqual, |
267 | function_ref<BaseStorage *(StorageAllocator &)> ctorFn); |
268 | |
269 | /// Implementation for registering an instance of a derived type with |
270 | /// parametric storage. This method takes an optional destructor function that |
271 | /// destructs storage instances when necessary. |
272 | void registerParametricStorageTypeImpl( |
273 | TypeID id, function_ref<void(BaseStorage *)> destructorFn); |
274 | |
275 | /// Implementation for getting an instance of a derived type with default |
276 | /// storage. |
277 | BaseStorage *getSingletonImpl(TypeID id); |
278 | |
279 | /// Implementation for registering an instance of a derived type with default |
280 | /// storage. |
281 | void |
282 | registerSingletonImpl(TypeID id, |
283 | function_ref<BaseStorage *(StorageAllocator &)> ctorFn); |
284 | |
285 | /// Implementation for mutating an instance of a derived storage. |
286 | LogicalResult |
287 | mutateImpl(TypeID id, BaseStorage *storage, |
288 | function_ref<LogicalResult(StorageAllocator &)> mutationFn); |
289 | |
290 | /// The internal implementation class. |
291 | std::unique_ptr<detail::StorageUniquerImpl> impl; |
292 | |
293 | //===--------------------------------------------------------------------===// |
294 | // Key Construction |
295 | //===--------------------------------------------------------------------===// |
296 | |
297 | /// Used to construct an instance of 'ImplTy::KeyTy' if there is an |
298 | /// 'ImplTy::getKey' function for the provided arguments. Otherwise, then we |
299 | /// try to directly construct the 'ImplTy::KeyTy' with the provided arguments. |
300 | template <typename ImplTy, typename... Args> |
301 | static typename ImplTy::KeyTy getKey(Args &&...args) { |
302 | if constexpr (llvm::is_detected<detail::has_impltype_getkey_t, ImplTy, |
303 | Args...>::value) |
304 | return ImplTy::getKey(std::forward<Args>(args)...); |
305 | else |
306 | return typename ImplTy::KeyTy(std::forward<Args>(args)...); |
307 | } |
308 | |
309 | //===--------------------------------------------------------------------===// |
310 | // Key Hashing |
311 | //===--------------------------------------------------------------------===// |
312 | |
313 | /// Used to generate a hash for the `ImplTy` of a storage instance if |
314 | /// there is a `ImplTy::hashKey. Otherwise, if there is no `ImplTy::hashKey` |
315 | /// then default to using the 'llvm::DenseMapInfo' definition for |
316 | /// 'DerivedKey' for generating a hash. |
317 | template <typename ImplTy, typename DerivedKey> |
318 | static ::llvm::hash_code getHash(const DerivedKey &derivedKey) { |
319 | if constexpr (llvm::is_detected<detail::has_impltype_hash_t, ImplTy, |
320 | DerivedKey>::value) |
321 | return ImplTy::hashKey(derivedKey); |
322 | else |
323 | return DenseMapInfo<DerivedKey>::getHashValue(derivedKey); |
324 | } |
325 | }; |
326 | } // namespace mlir |
327 | |
328 | #endif |
329 | |