//===- InterfaceSupport.h - MLIR Interface Support Classes ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines several support classes for defining interfaces.
//
//===----------------------------------------------------------------------===//
12
13#ifndef MLIR_SUPPORT_INTERFACESUPPORT_H
14#define MLIR_SUPPORT_INTERFACESUPPORT_H
15
16#include "mlir/Support/TypeID.h"
17#include "llvm/ADT/ArrayRef.h"
18#include "llvm/ADT/DenseMap.h"
19#include "llvm/Support/TypeName.h"
20
21namespace mlir {
22namespace detail {
//===----------------------------------------------------------------------===//
// Interface
//===----------------------------------------------------------------------===//
26
27/// This class represents an abstract interface. An interface is a simplified
28/// mechanism for attaching concept based polymorphism to a class hierarchy. An
29/// interface is comprised of two components:
30/// * The derived interface class: This is what users interact with, and invoke
31/// methods on.
32/// * An interface `Trait` class: This is the class that is attached to the
33/// object implementing the interface. It is the mechanism with which models
34/// are specialized.
35///
36/// Derived interfaces types must provide the following template types:
37/// * ConcreteType: The CRTP derived type.
38/// * ValueT: The opaque type the derived interface operates on. For example
39/// `Operation*` for operation interfaces, or `Attribute` for
40/// attribute interfaces.
41/// * Traits: A class that contains definitions for a 'Concept' and a 'Model'
42/// class. The 'Concept' class defines an abstract virtual interface,
43/// where as the 'Model' class implements this interface for a
44/// specific derived T type. Both of these classes *must* not contain
45/// non-static data. A simple example is shown below:
46///
47/// ```c++
48/// struct ExampleInterfaceTraits {
49/// struct Concept {
50/// virtual unsigned getNumInputs(T t) const = 0;
51/// };
52/// template <typename DerivedT> class Model {
53/// unsigned getNumInputs(T t) const final {
54/// return cast<DerivedT>(t).getNumInputs();
55/// }
56/// };
57/// };
58/// ```
59///
60/// * BaseType: A desired base type for the interface. This is a class
61/// that provides specific functionality for the `ValueT`
62/// value. For instance the specific `Op` that will wrap the
63/// `Operation*` for an `OpInterface`.
64/// * BaseTrait: The base type for the interface trait. This is the base class
65/// to use for the interface trait that will be attached to each
66/// instance of `ValueT` that implements this interface.
67///
68template <typename ConcreteType, typename ValueT, typename Traits,
69 typename BaseType,
70 template <typename, template <typename> class> class BaseTrait>
71class Interface : public BaseType {
72public:
73 using Concept = typename Traits::Concept;
74 template <typename T>
75 using Model = typename Traits::template Model<T>;
76 template <typename T>
77 using FallbackModel = typename Traits::template FallbackModel<T>;
78 using InterfaceBase =
79 Interface<ConcreteType, ValueT, Traits, BaseType, BaseTrait>;
80 template <typename T, typename U>
81 using ExternalModel = typename Traits::template ExternalModel<T, U>;
82 using ValueType = ValueT;
83
84 /// This is a special trait that registers a given interface with an object.
85 template <typename ConcreteT>
86 struct Trait : public BaseTrait<ConcreteT, Trait> {
87 using ModelT = Model<ConcreteT>;
88
89 /// Define an accessor for the ID of this interface.
90 static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); }
91 };
92
93 /// Construct an interface from an instance of the value type.
94 explicit Interface(ValueT t = ValueT())
95 : BaseType(t),
96 conceptImpl(t ? ConcreteType::getInterfaceFor(t) : nullptr) {
97 assert((!t || conceptImpl) &&
98 "expected value to provide interface instance");
99 }
100 Interface(std::nullptr_t) : BaseType(ValueT()), conceptImpl(nullptr) {}
101
102 /// Construct an interface instance from a type that implements this
103 /// interface's trait.
104 template <typename T,
105 std::enable_if_t<std::is_base_of<Trait<T>, T>::value> * = nullptr>
106 Interface(T t)
107 : BaseType(t),
108 conceptImpl(t ? ConcreteType::getInterfaceFor(t) : nullptr) {
109 assert((!t || conceptImpl) &&
110 "expected value to provide interface instance");
111 }
112
113 /// Constructor for a known concept.
114 Interface(ValueT t, const Concept *conceptImpl)
115 : BaseType(t), conceptImpl(const_cast<Concept *>(conceptImpl)) {
116 assert(!t || ConcreteType::getInterfaceFor(t) == conceptImpl);
117 }
118
119 /// Constructor for DenseMapInfo's empty key and tombstone key.
120 Interface(ValueT t, std::nullptr_t) : BaseType(t), conceptImpl(nullptr) {}
121
122 /// Support 'classof' by checking if the given object defines the concrete
123 /// interface.
124 static bool classof(ValueT t) { return ConcreteType::getInterfaceFor(t); }
125
126 /// Define an accessor for the ID of this interface.
127 static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); }
128
129protected:
130 /// Get the raw concept in the correct derived concept type.
131 const Concept *getImpl() const { return conceptImpl; }
132 Concept *getImpl() { return conceptImpl; }
133
134private:
135 /// A pointer to the impl concept object.
136 Concept *conceptImpl;
137};
138
//===----------------------------------------------------------------------===//
// InterfaceMap
//===----------------------------------------------------------------------===//
142
/// Compile-time count of how many of the types `Ts...` satisfy `Pred` (that
/// is, for which `Pred<T>::value` is true), added on top of the starting
/// value `N`. Implemented as a left fold, so an empty pack yields `N`.
template <template <class> class Pred, size_t N, typename... Ts>
struct count_if_t_impl
    : public std::integral_constant<
          size_t, (N + ... + size_t(Pred<Ts>::value ? 1 : 0))> {};

/// Number of types in `Ts...` that satisfy the predicate `Pred`.
template <template <class> class Pred, typename... Ts>
using count_if_t = count_if_t_impl<Pred, 0, Ts...>;
154
155/// This class provides an efficient mapping between a given `Interface` type,
156/// and a particular implementation of its concept.
157class InterfaceMap {
158 /// Trait to check if T provides a static 'getInterfaceID' method.
159 template <typename T, typename... Args>
160 using has_get_interface_id = decltype(T::getInterfaceID());
161 template <typename T>
162 using detect_get_interface_id = llvm::is_detected<has_get_interface_id, T>;
163 template <typename... Types>
164 using num_interface_types_t = count_if_t<detect_get_interface_id, Types...>;
165
166 /// Trait to check if T provides a 'initializeInterfaceConcept' method.
167 template <typename T, typename... Args>
168 using has_initialize_method =
169 decltype(std::declval<T>().initializeInterfaceConcept(
170 std::declval<InterfaceMap &>()));
171 template <typename T>
172 using detect_initialize_method = llvm::is_detected<has_initialize_method, T>;
173
174public:
175 InterfaceMap() = default;
176 InterfaceMap(InterfaceMap &&) = default;
177 InterfaceMap &operator=(InterfaceMap &&rhs) {
178 for (auto &it : interfaces)
179 free(ptr: it.second);
180 interfaces = std::move(rhs.interfaces);
181 return *this;
182 }
183 ~InterfaceMap() {
184 for (auto &it : interfaces)
185 free(ptr: it.second);
186 }
187
188 /// Construct an InterfaceMap with the given set of template types. For
189 /// convenience given that object trait lists may contain other non-interface
190 /// types, not all of the types need to be interfaces. The provided types that
191 /// do not represent interfaces are not added to the interface map.
192 template <typename... Types>
193 static InterfaceMap get() {
194 constexpr size_t numInterfaces = num_interface_types_t<Types...>::value;
195 if constexpr (numInterfaces == 0)
196 return InterfaceMap();
197
198 InterfaceMap map;
199 (map.insertPotentialInterface<Types>(), ...);
200 return map;
201 }
202
203 /// Returns an instance of the concept object for the given interface if it
204 /// was registered to this map, null otherwise.
205 template <typename T>
206 typename T::Concept *lookup() const {
207 return reinterpret_cast<typename T::Concept *>(lookup(T::getInterfaceID()));
208 }
209
210 /// Returns true if the interface map contains an interface for the given id.
211 bool contains(TypeID interfaceID) const { return lookup(id: interfaceID); }
212
213 /// Insert the given interface models.
214 template <typename... IfaceModels>
215 void insertModels() {
216 (insertModel<IfaceModels>(), ...);
217 }
218
219private:
220 /// Insert the given interface type into the map, ignoring it if it doesn't
221 /// actually represent an interface.
222 template <typename T>
223 inline void insertPotentialInterface() {
224 if constexpr (detect_get_interface_id<T>::value)
225 insertModel<typename T::ModelT>();
226 }
227
228 /// Insert the given interface model into the map.
229 template <typename InterfaceModel>
230 void insertModel() {
231 // FIXME(#59975): Uncomment this when SPIRV no longer awkwardly reimplements
232 // interfaces in a way that isn't clean/compatible.
233 // static_assert(std::is_trivially_destructible_v<InterfaceModel>,
234 // "interface models must be trivially destructible");
235
236 // Build the interface model, optionally initializing if necessary.
237 InterfaceModel *model =
238 new (malloc(size: sizeof(InterfaceModel))) InterfaceModel();
239 if constexpr (detect_initialize_method<InterfaceModel>::value)
240 model->initializeInterfaceConcept(*this);
241
242 insert(interfaceId: InterfaceModel::Interface::getInterfaceID(), conceptImpl: model);
243 }
244 /// Insert the given set of interface id and concept implementation into the
245 /// interface map.
246 void insert(TypeID interfaceId, void *conceptImpl);
247
248 /// Compare two TypeID instances by comparing the underlying pointer.
249 static bool compare(TypeID lhs, TypeID rhs) {
250 return lhs.getAsOpaquePointer() < rhs.getAsOpaquePointer();
251 }
252
253 /// Returns an instance of the concept object for the given interface id if it
254 /// was registered to this map, null otherwise.
255 void *lookup(TypeID id) const {
256 const auto *it =
257 llvm::lower_bound(Range: interfaces, Value&: id, C: [](const auto &it, TypeID id) {
258 return compare(lhs: it.first, rhs: id);
259 });
260 return (it != interfaces.end() && it->first == id) ? it->second : nullptr;
261 }
262
263 /// A list of interface instances, sorted by TypeID.
264 SmallVector<std::pair<TypeID, void *>> interfaces;
265};
266
267template <typename ConcreteType, typename ValueT, typename Traits,
268 typename BaseType,
269 template <typename, template <typename> class> class BaseTrait>
270void isInterfaceImpl(
271 Interface<ConcreteType, ValueT, Traits, BaseType, BaseTrait> &);
272
273template <typename T>
274using is_interface_t = decltype(isInterfaceImpl(std::declval<T &>()));
275
276template <typename T>
277using IsInterface = llvm::is_detected<is_interface_t, T>;
278
279} // namespace detail
280} // namespace mlir
281
282namespace llvm {
283
284template <typename T>
285struct DenseMapInfo<T, std::enable_if_t<mlir::detail::IsInterface<T>::value>> {
286 using ValueTypeInfo = llvm::DenseMapInfo<typename T::ValueType>;
287
288 static T getEmptyKey() { return T(ValueTypeInfo::getEmptyKey(), nullptr); }
289
290 static T getTombstoneKey() {
291 return T(ValueTypeInfo::getTombstoneKey(), nullptr);
292 }
293
294 static unsigned getHashValue(T val) {
295 return ValueTypeInfo::getHashValue(val);
296 }
297
298 static bool isEqual(T lhs, T rhs) { return ValueTypeInfo::isEqual(lhs, rhs); }
299};
300
301} // namespace llvm
302
303#endif
304

source code of mlir/include/mlir/Support/InterfaceSupport.h