1 | //===- InterfaceSupport.h - MLIR Interface Support Classes ------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines several support classes for defining interfaces. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #ifndef MLIR_SUPPORT_INTERFACESUPPORT_H |
14 | #define MLIR_SUPPORT_INTERFACESUPPORT_H |
15 | |
16 | #include "mlir/Support/TypeID.h" |
17 | #include "llvm/ADT/ArrayRef.h" |
18 | #include "llvm/ADT/DenseMap.h" |
19 | #include "llvm/Support/TypeName.h" |
20 | |
21 | namespace mlir { |
22 | namespace detail { |
23 | //===----------------------------------------------------------------------===// |
24 | // Interface |
25 | //===----------------------------------------------------------------------===// |
26 | |
27 | /// This class represents an abstract interface. An interface is a simplified |
28 | /// mechanism for attaching concept based polymorphism to a class hierarchy. An |
29 | /// interface is comprised of two components: |
30 | /// * The derived interface class: This is what users interact with, and invoke |
31 | /// methods on. |
32 | /// * An interface `Trait` class: This is the class that is attached to the |
33 | /// object implementing the interface. It is the mechanism with which models |
34 | /// are specialized. |
35 | /// |
36 | /// Derived interfaces types must provide the following template types: |
37 | /// * ConcreteType: The CRTP derived type. |
38 | /// * ValueT: The opaque type the derived interface operates on. For example |
39 | /// `Operation*` for operation interfaces, or `Attribute` for |
40 | /// attribute interfaces. |
41 | /// * Traits: A class that contains definitions for a 'Concept' and a 'Model' |
42 | /// class. The 'Concept' class defines an abstract virtual interface, |
43 | /// where as the 'Model' class implements this interface for a |
44 | /// specific derived T type. Both of these classes *must* not contain |
45 | /// non-static data. A simple example is shown below: |
46 | /// |
47 | /// ```c++ |
48 | /// struct ExampleInterfaceTraits { |
49 | /// struct Concept { |
50 | /// virtual unsigned getNumInputs(T t) const = 0; |
51 | /// }; |
52 | /// template <typename DerivedT> class Model { |
53 | /// unsigned getNumInputs(T t) const final { |
54 | /// return cast<DerivedT>(t).getNumInputs(); |
55 | /// } |
56 | /// }; |
57 | /// }; |
58 | /// ``` |
59 | /// |
60 | /// * BaseType: A desired base type for the interface. This is a class |
61 | /// that provides specific functionality for the `ValueT` |
62 | /// value. For instance the specific `Op` that will wrap the |
63 | /// `Operation*` for an `OpInterface`. |
64 | /// * BaseTrait: The base type for the interface trait. This is the base class |
65 | /// to use for the interface trait that will be attached to each |
66 | /// instance of `ValueT` that implements this interface. |
67 | /// |
68 | template <typename ConcreteType, typename ValueT, typename Traits, |
69 | typename BaseType, |
70 | template <typename, template <typename> class> class BaseTrait> |
71 | class Interface : public BaseType { |
72 | public: |
73 | using Concept = typename Traits::Concept; |
74 | template <typename T> |
75 | using Model = typename Traits::template Model<T>; |
76 | template <typename T> |
77 | using FallbackModel = typename Traits::template FallbackModel<T>; |
78 | using InterfaceBase = |
79 | Interface<ConcreteType, ValueT, Traits, BaseType, BaseTrait>; |
80 | template <typename T, typename U> |
81 | using ExternalModel = typename Traits::template ExternalModel<T, U>; |
82 | using ValueType = ValueT; |
83 | |
84 | /// This is a special trait that registers a given interface with an object. |
85 | template <typename ConcreteT> |
86 | struct Trait : public BaseTrait<ConcreteT, Trait> { |
87 | using ModelT = Model<ConcreteT>; |
88 | |
89 | /// Define an accessor for the ID of this interface. |
90 | static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); } |
91 | }; |
92 | |
93 | /// Construct an interface from an instance of the value type. |
94 | explicit Interface(ValueT t = ValueT()) |
95 | : BaseType(t), |
96 | conceptImpl(t ? ConcreteType::getInterfaceFor(t) : nullptr) { |
97 | assert((!t || conceptImpl) && |
98 | "expected value to provide interface instance" ); |
99 | } |
100 | Interface(std::nullptr_t) : BaseType(ValueT()), conceptImpl(nullptr) {} |
101 | |
102 | /// Construct an interface instance from a type that implements this |
103 | /// interface's trait. |
104 | template <typename T, |
105 | std::enable_if_t<std::is_base_of<Trait<T>, T>::value> * = nullptr> |
106 | Interface(T t) |
107 | : BaseType(t), |
108 | conceptImpl(t ? ConcreteType::getInterfaceFor(t) : nullptr) { |
109 | assert((!t || conceptImpl) && |
110 | "expected value to provide interface instance" ); |
111 | } |
112 | |
113 | /// Constructor for a known concept. |
114 | Interface(ValueT t, const Concept *conceptImpl) |
115 | : BaseType(t), conceptImpl(const_cast<Concept *>(conceptImpl)) { |
116 | assert(!t || ConcreteType::getInterfaceFor(t) == conceptImpl); |
117 | } |
118 | |
119 | /// Constructor for DenseMapInfo's empty key and tombstone key. |
120 | Interface(ValueT t, std::nullptr_t) : BaseType(t), conceptImpl(nullptr) {} |
121 | |
122 | /// Support 'classof' by checking if the given object defines the concrete |
123 | /// interface. |
124 | static bool classof(ValueT t) { return ConcreteType::getInterfaceFor(t); } |
125 | |
126 | /// Define an accessor for the ID of this interface. |
127 | static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); } |
128 | |
129 | protected: |
130 | /// Get the raw concept in the correct derived concept type. |
131 | const Concept *getImpl() const { return conceptImpl; } |
132 | Concept *getImpl() { return conceptImpl; } |
133 | |
134 | private: |
135 | /// A pointer to the impl concept object. |
136 | Concept *conceptImpl; |
137 | }; |
138 | |
139 | //===----------------------------------------------------------------------===// |
140 | // InterfaceMap |
141 | //===----------------------------------------------------------------------===// |
142 | |
/// Compile-time count of how many types in the pack `Ts` satisfy `Pred`
/// (i.e. have `Pred<T>::value` true), added to the starting offset `N`.
/// The result is available as the `value` member inherited from
/// `std::integral_constant<size_t, ...>`. A C++17 fold expression does the
/// counting, so no recursive instantiation is needed.
template <template <class> class Pred, size_t N, typename... Ts>
struct count_if_t_impl
    : public std::integral_constant<
          size_t, N + (size_t(0) + ... + size_t(Pred<Ts>::value ? 1 : 0))> {};

/// Counts the types in `Ts` that satisfy `Pred`, starting from zero.
template <template <class> class Pred, typename... Ts>
using count_if_t = count_if_t_impl<Pred, 0, Ts...>;
154 | |
155 | /// This class provides an efficient mapping between a given `Interface` type, |
156 | /// and a particular implementation of its concept. |
157 | class InterfaceMap { |
158 | /// Trait to check if T provides a static 'getInterfaceID' method. |
159 | template <typename T, typename... Args> |
160 | using has_get_interface_id = decltype(T::getInterfaceID()); |
161 | template <typename T> |
162 | using detect_get_interface_id = llvm::is_detected<has_get_interface_id, T>; |
163 | template <typename... Types> |
164 | using num_interface_types_t = count_if_t<detect_get_interface_id, Types...>; |
165 | |
166 | /// Trait to check if T provides a 'initializeInterfaceConcept' method. |
167 | template <typename T, typename... Args> |
168 | using has_initialize_method = |
169 | decltype(std::declval<T>().initializeInterfaceConcept( |
170 | std::declval<InterfaceMap &>())); |
171 | template <typename T> |
172 | using detect_initialize_method = llvm::is_detected<has_initialize_method, T>; |
173 | |
174 | public: |
175 | InterfaceMap() = default; |
176 | InterfaceMap(InterfaceMap &&) = default; |
177 | InterfaceMap &operator=(InterfaceMap &&rhs) { |
178 | for (auto &it : interfaces) |
179 | free(ptr: it.second); |
180 | interfaces = std::move(rhs.interfaces); |
181 | return *this; |
182 | } |
183 | ~InterfaceMap() { |
184 | for (auto &it : interfaces) |
185 | free(ptr: it.second); |
186 | } |
187 | |
188 | /// Construct an InterfaceMap with the given set of template types. For |
189 | /// convenience given that object trait lists may contain other non-interface |
190 | /// types, not all of the types need to be interfaces. The provided types that |
191 | /// do not represent interfaces are not added to the interface map. |
192 | template <typename... Types> |
193 | static InterfaceMap get() { |
194 | constexpr size_t numInterfaces = num_interface_types_t<Types...>::value; |
195 | if constexpr (numInterfaces == 0) |
196 | return InterfaceMap(); |
197 | |
198 | InterfaceMap map; |
199 | (map.insertPotentialInterface<Types>(), ...); |
200 | return map; |
201 | } |
202 | |
203 | /// Returns an instance of the concept object for the given interface if it |
204 | /// was registered to this map, null otherwise. |
205 | template <typename T> |
206 | typename T::Concept *lookup() const { |
207 | return reinterpret_cast<typename T::Concept *>(lookup(T::getInterfaceID())); |
208 | } |
209 | |
210 | /// Returns true if the interface map contains an interface for the given id. |
211 | bool contains(TypeID interfaceID) const { return lookup(id: interfaceID); } |
212 | |
213 | /// Insert the given interface models. |
214 | template <typename... IfaceModels> |
215 | void insertModels() { |
216 | (insertModel<IfaceModels>(), ...); |
217 | } |
218 | |
219 | private: |
220 | /// Insert the given interface type into the map, ignoring it if it doesn't |
221 | /// actually represent an interface. |
222 | template <typename T> |
223 | inline void insertPotentialInterface() { |
224 | if constexpr (detect_get_interface_id<T>::value) |
225 | insertModel<typename T::ModelT>(); |
226 | } |
227 | |
228 | /// Insert the given interface model into the map. |
229 | template <typename InterfaceModel> |
230 | void insertModel() { |
231 | // FIXME(#59975): Uncomment this when SPIRV no longer awkwardly reimplements |
232 | // interfaces in a way that isn't clean/compatible. |
233 | // static_assert(std::is_trivially_destructible_v<InterfaceModel>, |
234 | // "interface models must be trivially destructible"); |
235 | |
236 | // Build the interface model, optionally initializing if necessary. |
237 | InterfaceModel *model = |
238 | new (malloc(size: sizeof(InterfaceModel))) InterfaceModel(); |
239 | if constexpr (detect_initialize_method<InterfaceModel>::value) |
240 | model->initializeInterfaceConcept(*this); |
241 | |
242 | insert(interfaceId: InterfaceModel::Interface::getInterfaceID(), conceptImpl: model); |
243 | } |
244 | /// Insert the given set of interface id and concept implementation into the |
245 | /// interface map. |
246 | void insert(TypeID interfaceId, void *conceptImpl); |
247 | |
248 | /// Compare two TypeID instances by comparing the underlying pointer. |
249 | static bool compare(TypeID lhs, TypeID rhs) { |
250 | return lhs.getAsOpaquePointer() < rhs.getAsOpaquePointer(); |
251 | } |
252 | |
253 | /// Returns an instance of the concept object for the given interface id if it |
254 | /// was registered to this map, null otherwise. |
255 | void *lookup(TypeID id) const { |
256 | const auto *it = |
257 | llvm::lower_bound(Range: interfaces, Value&: id, C: [](const auto &it, TypeID id) { |
258 | return compare(lhs: it.first, rhs: id); |
259 | }); |
260 | return (it != interfaces.end() && it->first == id) ? it->second : nullptr; |
261 | } |
262 | |
263 | /// A list of interface instances, sorted by TypeID. |
264 | SmallVector<std::pair<TypeID, void *>> interfaces; |
265 | }; |
266 | |
267 | template <typename ConcreteType, typename ValueT, typename Traits, |
268 | typename BaseType, |
269 | template <typename, template <typename> class> class BaseTrait> |
270 | void isInterfaceImpl( |
271 | Interface<ConcreteType, ValueT, Traits, BaseType, BaseTrait> &); |
272 | |
273 | template <typename T> |
274 | using is_interface_t = decltype(isInterfaceImpl(std::declval<T &>())); |
275 | |
276 | template <typename T> |
277 | using IsInterface = llvm::is_detected<is_interface_t, T>; |
278 | |
279 | } // namespace detail |
280 | } // namespace mlir |
281 | |
282 | namespace llvm { |
283 | |
284 | template <typename T> |
285 | struct DenseMapInfo<T, std::enable_if_t<mlir::detail::IsInterface<T>::value>> { |
286 | using ValueTypeInfo = llvm::DenseMapInfo<typename T::ValueType>; |
287 | |
288 | static T getEmptyKey() { return T(ValueTypeInfo::getEmptyKey(), nullptr); } |
289 | |
290 | static T getTombstoneKey() { |
291 | return T(ValueTypeInfo::getTombstoneKey(), nullptr); |
292 | } |
293 | |
294 | static unsigned getHashValue(T val) { |
295 | return ValueTypeInfo::getHashValue(val); |
296 | } |
297 | |
298 | static bool isEqual(T lhs, T rhs) { return ValueTypeInfo::isEqual(lhs, rhs); } |
299 | }; |
300 | |
301 | } // namespace llvm |
302 | |
303 | #endif |
304 | |