1 | //===- StorageUniquerSupport.h - MLIR Storage Uniquer Utilities -*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines utility classes for interfacing with StorageUniquer. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #ifndef MLIR_IR_STORAGEUNIQUERSUPPORT_H |
14 | #define MLIR_IR_STORAGEUNIQUERSUPPORT_H |
15 | |
16 | #include "mlir/IR/AttrTypeSubElements.h" |
17 | #include "mlir/IR/DialectRegistry.h" |
18 | #include "mlir/Support/InterfaceSupport.h" |
19 | #include "mlir/Support/LogicalResult.h" |
20 | #include "mlir/Support/StorageUniquer.h" |
21 | #include "mlir/Support/TypeID.h" |
22 | #include "llvm/ADT/FunctionExtras.h" |
23 | |
24 | namespace mlir { |
// Forward declarations of types named by the declarations in this header, so
// that the header does not need to include their defining headers.
class InFlightDiagnostic;
class Location;
class MLIRContext;
28 | |
29 | namespace detail { |
30 | /// Utility method to generate a callback that can be used to generate a |
31 | /// diagnostic when checking the construction invariants of a storage object. |
32 | /// This is defined out-of-line to avoid the need to include Location.h. |
33 | llvm::unique_function<InFlightDiagnostic()> |
34 | getDefaultDiagnosticEmitFn(MLIRContext *ctx); |
35 | llvm::unique_function<InFlightDiagnostic()> |
36 | getDefaultDiagnosticEmitFn(const Location &loc); |
37 | |
38 | //===----------------------------------------------------------------------===// |
39 | // StorageUserTraitBase |
40 | //===----------------------------------------------------------------------===// |
41 | |
/// Common base for traits attached to storage users (attributes, types, ...).
/// Its only member is protected, because clients are not expected to use this
/// class directly; trait implementations call `getInstance()` to recover the
/// concrete storage user they are mixed into.
template <typename ConcreteType, template <typename> class TraitType>
class StorageUserTraitBase {
protected:
  /// Return a copy of the concrete object that this trait is a base of.
  ConcreteType getInstance() const {
    // The concrete type typically inherits from several (empty) trait bases,
    // each of which derives from its own StorageUserTraitBase instantiation.
    // Going through TraitType<ConcreteType> first gives the compiler a single,
    // unambiguous path for the downcast.
    using TraitT = TraitType<ConcreteType>;
    const TraitT &asTrait = static_cast<const TraitT &>(*this);
    return static_cast<const ConcreteType &>(asTrait);
  }
};

namespace StorageUserTrait {
/// Marker trait stating that a storage user (for example a Type) is mutable,
/// i.e. that the ImplType of the derived class provides a suitable `mutate`
/// function. This is an implementation detail; user code should refer to
/// public aliases such as `TypeTrait::IsMutable` instead.
template <typename ConcreteType>
struct IsMutable : StorageUserTraitBase<ConcreteType, IsMutable> {};
} // namespace StorageUserTrait
67 | |
68 | //===----------------------------------------------------------------------===// |
69 | // StorageUserBase |
70 | //===----------------------------------------------------------------------===// |
71 | |
72 | namespace storage_user_base_impl { |
73 | /// Returns true if this given Trait ID matches the IDs of any of the provided |
74 | /// trait types `Traits`. |
75 | template <template <typename T> class... Traits> |
76 | bool hasTrait(TypeID traitID) { |
77 | TypeID traitIDs[] = {TypeID::get<Traits>()...}; |
78 | for (unsigned i = 0, e = sizeof...(Traits); i != e; ++i) |
79 | if (traitIDs[i] == traitID) |
80 | return true; |
81 | return false; |
82 | } |
83 | |
84 | // We specialize for the empty case to not define an empty array. |
85 | template <> |
86 | inline bool hasTrait(TypeID traitID) { |
87 | return false; |
88 | } |
89 | } // namespace storage_user_base_impl |
90 | |
91 | /// Utility class for implementing users of storage classes uniqued by a |
92 | /// StorageUniquer. Clients are not expected to interact with this class |
93 | /// directly. |
94 | template <typename ConcreteT, typename BaseT, typename StorageT, |
95 | typename UniquerT, template <typename T> class... Traits> |
96 | class StorageUserBase : public BaseT, public Traits<ConcreteT>... { |
97 | public: |
98 | using BaseT::BaseT; |
99 | |
100 | /// Utility declarations for the concrete attribute class. |
101 | using Base = StorageUserBase<ConcreteT, BaseT, StorageT, UniquerT, Traits...>; |
102 | using ImplType = StorageT; |
103 | using HasTraitFn = bool (*)(TypeID); |
104 | |
105 | /// Return a unique identifier for the concrete type. |
106 | static TypeID getTypeID() { return TypeID::get<ConcreteT>(); } |
107 | |
108 | /// Provide an implementation of 'classof' that compares the type id of the |
109 | /// provided value with that of the concrete type. |
110 | template <typename T> |
111 | static bool classof(T val) { |
112 | static_assert(std::is_convertible<ConcreteT, T>::value, |
113 | "casting from a non-convertible type" ); |
114 | return val.getTypeID() == getTypeID(); |
115 | } |
116 | |
117 | /// Returns an interface map for the interfaces registered to this storage |
118 | /// user. This should not be used directly. |
119 | static detail::InterfaceMap getInterfaceMap() { |
120 | return detail::InterfaceMap::template get<Traits<ConcreteT>...>(); |
121 | } |
122 | |
123 | /// Returns the function that returns true if the given Trait ID matches the |
124 | /// IDs of any of the traits defined by the storage user. |
125 | static HasTraitFn getHasTraitFn() { |
126 | return [](TypeID id) { |
127 | return storage_user_base_impl::hasTrait<Traits...>(id); |
128 | }; |
129 | } |
130 | |
131 | /// Returns a function that walks immediate sub elements of a given instance |
132 | /// of the storage user. |
133 | static auto getWalkImmediateSubElementsFn() { |
134 | return [](auto instance, function_ref<void(Attribute)> walkAttrsFn, |
135 | function_ref<void(Type)> walkTypesFn) { |
136 | ::mlir::detail::walkImmediateSubElementsImpl( |
137 | llvm::cast<ConcreteT>(instance), walkAttrsFn, walkTypesFn); |
138 | }; |
139 | } |
140 | |
141 | /// Returns a function that replaces immediate sub elements of a given |
142 | /// instance of the storage user. |
143 | static auto getReplaceImmediateSubElementsFn() { |
144 | return [](auto instance, ArrayRef<Attribute> replAttrs, |
145 | ArrayRef<Type> replTypes) { |
146 | return ::mlir::detail::replaceImmediateSubElementsImpl( |
147 | llvm::cast<ConcreteT>(instance), replAttrs, replTypes); |
148 | }; |
149 | } |
150 | |
151 | /// Attach the given models as implementations of the corresponding interfaces |
152 | /// for the concrete storage user class. The type must be registered with the |
153 | /// context, i.e. the dialect to which the type belongs must be loaded. The |
154 | /// call will abort otherwise. |
155 | template <typename... IfaceModels> |
156 | static void attachInterface(MLIRContext &context) { |
157 | typename ConcreteT::AbstractTy *abstract = |
158 | ConcreteT::AbstractTy::lookupMutable(TypeID::get<ConcreteT>(), |
159 | &context); |
160 | if (!abstract) |
161 | llvm::report_fatal_error(reason: "Registering an interface for an attribute/type " |
162 | "that is not itself registered." ); |
163 | |
164 | // Handle the case where the models resolve a promised interface. |
165 | (dialect_extension_detail::handleAdditionOfUndefinedPromisedInterface( |
166 | dialect&: abstract->getDialect(), interfaceRequestorID: abstract->getTypeID(), |
167 | interfaceID: IfaceModels::Interface::getInterfaceID()), |
168 | ...); |
169 | |
170 | (checkInterfaceTarget<IfaceModels>(), ...); |
171 | abstract->interfaceMap.template insertModels<IfaceModels...>(); |
172 | } |
173 | |
174 | /// Get or create a new ConcreteT instance within the ctx. This |
175 | /// function is guaranteed to return a non null object and will assert if |
176 | /// the arguments provided are invalid. |
177 | template <typename... Args> |
178 | static ConcreteT get(MLIRContext *ctx, Args &&...args) { |
179 | // Ensure that the invariants are correct for construction. |
180 | assert( |
181 | succeeded(ConcreteT::verify(getDefaultDiagnosticEmitFn(ctx), args...))); |
182 | return UniquerT::template get<ConcreteT>(ctx, std::forward<Args>(args)...); |
183 | } |
184 | |
185 | /// Get or create a new ConcreteT instance within the ctx, defined at |
186 | /// the given, potentially unknown, location. If the arguments provided are |
187 | /// invalid, errors are emitted using the provided location and a null object |
188 | /// is returned. |
189 | template <typename... Args> |
190 | static ConcreteT getChecked(const Location &loc, Args &&...args) { |
191 | return ConcreteT::getChecked(getDefaultDiagnosticEmitFn(loc), |
192 | std::forward<Args>(args)...); |
193 | } |
194 | |
195 | /// Get or create a new ConcreteT instance within the ctx. If the arguments |
196 | /// provided are invalid, errors are emitted using the provided `emitError` |
197 | /// and a null object is returned. |
198 | template <typename... Args> |
199 | static ConcreteT getChecked(function_ref<InFlightDiagnostic()> emitErrorFn, |
200 | MLIRContext *ctx, Args... args) { |
201 | // If the construction invariants fail then we return a null attribute. |
202 | if (failed(ConcreteT::verify(emitErrorFn, args...))) |
203 | return ConcreteT(); |
204 | return UniquerT::template get<ConcreteT>(ctx, args...); |
205 | } |
206 | |
207 | /// Get an instance of the concrete type from a void pointer. |
208 | static ConcreteT getFromOpaquePointer(const void *ptr) { |
209 | return ConcreteT((const typename BaseT::ImplType *)ptr); |
210 | } |
211 | |
212 | /// Utility for easy access to the storage instance. |
213 | ImplType *getImpl() const { return static_cast<ImplType *>(this->impl); } |
214 | |
215 | protected: |
216 | /// Mutate the current storage instance. This will not change the unique key. |
217 | /// The arguments are forwarded to 'ConcreteT::mutate'. |
218 | template <typename... Args> |
219 | LogicalResult mutate(Args &&...args) { |
220 | static_assert(std::is_base_of<StorageUserTrait::IsMutable<ConcreteT>, |
221 | ConcreteT>::value, |
222 | "The `mutate` function expects mutable trait " |
223 | "(e.g. TypeTrait::IsMutable) to be attached on parent." ); |
224 | return UniquerT::template mutate<ConcreteT>(this->getContext(), getImpl(), |
225 | std::forward<Args>(args)...); |
226 | } |
227 | |
228 | /// Default implementation that just returns success. |
229 | template <typename... Args> |
230 | static LogicalResult verify(Args... args) { |
231 | return success(); |
232 | } |
233 | |
234 | private: |
235 | /// Trait to check if T provides a 'ConcreteEntity' type alias. |
236 | template <typename T> |
237 | using has_concrete_entity_t = typename T::ConcreteEntity; |
238 | |
239 | /// A struct-wrapped type alias to T::ConcreteEntity if provided and to |
240 | /// ConcreteT otherwise. This is akin to std::conditional but doesn't fail on |
241 | /// the missing typedef. Useful for checking if the interface is targeting the |
242 | /// right class. |
243 | template <typename T, |
244 | bool = llvm::is_detected<has_concrete_entity_t, T>::value> |
245 | struct IfaceTargetOrConcreteT { |
246 | using type = typename T::ConcreteEntity; |
247 | }; |
248 | template <typename T> |
249 | struct IfaceTargetOrConcreteT<T, false> { |
250 | using type = ConcreteT; |
251 | }; |
252 | |
253 | /// A hook for static assertion that the external interface model T is |
254 | /// targeting a base class of the concrete attribute/type. The model can also |
255 | /// be a fallback model that works for every attribute/type. |
256 | template <typename T> |
257 | static void checkInterfaceTarget() { |
258 | static_assert(std::is_base_of<typename IfaceTargetOrConcreteT<T>::type, |
259 | ConcreteT>::value, |
260 | "attaching an interface to the wrong attribute/type kind" ); |
261 | } |
262 | }; |
263 | } // namespace detail |
264 | } // namespace mlir |
265 | |
266 | #endif |
267 | |