1 | //===- DialectInterface.h - IR Dialect Interfaces ---------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef MLIR_IR_DIALECTINTERFACE_H |
10 | #define MLIR_IR_DIALECTINTERFACE_H |
11 | |
#include "mlir/Support/TypeID.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include <vector>
15 | |
16 | namespace mlir { |
17 | class Dialect; |
18 | class MLIRContext; |
19 | class Operation; |
20 | |
21 | //===----------------------------------------------------------------------===// |
22 | // DialectInterface |
23 | //===----------------------------------------------------------------------===// |
24 | namespace detail { |
25 | /// The base class used for all derived interface types. This class provides |
26 | /// utilities necessary for registration. |
27 | template <typename ConcreteType, typename BaseT> |
28 | class DialectInterfaceBase : public BaseT { |
29 | public: |
30 | using Base = DialectInterfaceBase<ConcreteType, BaseT>; |
31 | |
32 | /// Get a unique id for the derived interface type. |
33 | static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); } |
34 | |
35 | protected: |
36 | DialectInterfaceBase(Dialect *dialect) : BaseT(dialect, getInterfaceID()) {} |
37 | }; |
38 | } // namespace detail |
39 | |
40 | /// This class represents an interface overridden for a single dialect. |
41 | class DialectInterface { |
42 | public: |
43 | virtual ~DialectInterface(); |
44 | |
45 | /// The base class used for all derived interface types. This class provides |
46 | /// utilities necessary for registration. |
47 | template <typename ConcreteType> |
48 | using Base = detail::DialectInterfaceBase<ConcreteType, DialectInterface>; |
49 | |
50 | /// Return the dialect that this interface represents. |
51 | Dialect *getDialect() const { return dialect; } |
52 | |
53 | /// Return the context that holds the parent dialect of this interface. |
54 | MLIRContext *getContext() const; |
55 | |
56 | /// Return the derived interface id. |
57 | TypeID getID() const { return interfaceID; } |
58 | |
59 | protected: |
60 | DialectInterface(Dialect *dialect, TypeID id) |
61 | : dialect(dialect), interfaceID(id) {} |
62 | |
63 | private: |
64 | /// The dialect that represents this interface. |
65 | Dialect *dialect; |
66 | |
67 | /// The unique identifier for the derived interface type. |
68 | TypeID interfaceID; |
69 | }; |
70 | |
71 | //===----------------------------------------------------------------------===// |
72 | // DialectInterfaceCollection |
73 | //===----------------------------------------------------------------------===// |
74 | |
75 | namespace detail { |
76 | /// This class is the base class for a collection of instances for a specific |
77 | /// interface kind. |
78 | class DialectInterfaceCollectionBase { |
79 | /// DenseMap info for dialect interfaces that allows lookup by the dialect. |
80 | struct InterfaceKeyInfo : public DenseMapInfo<const DialectInterface *> { |
81 | using DenseMapInfo<const DialectInterface *>::isEqual; |
82 | |
83 | static unsigned getHashValue(Dialect *key) { return llvm::hash_value(ptr: key); } |
84 | static unsigned getHashValue(const DialectInterface *key) { |
85 | return getHashValue(key: key->getDialect()); |
86 | } |
87 | |
88 | static bool isEqual(Dialect *lhs, const DialectInterface *rhs) { |
89 | if (rhs == getEmptyKey() || rhs == getTombstoneKey()) |
90 | return false; |
91 | return lhs == rhs->getDialect(); |
92 | } |
93 | }; |
94 | |
95 | /// A set of registered dialect interface instances. |
96 | using InterfaceSetT = DenseSet<const DialectInterface *, InterfaceKeyInfo>; |
97 | using InterfaceVectorT = std::vector<const DialectInterface *>; |
98 | |
99 | public: |
100 | DialectInterfaceCollectionBase(MLIRContext *ctx, TypeID interfaceKind, |
101 | StringRef interfaceName); |
102 | virtual ~DialectInterfaceCollectionBase(); |
103 | |
104 | protected: |
105 | /// Get the interface for the dialect of given operation, or null if one |
106 | /// is not registered. |
107 | const DialectInterface *getInterfaceFor(Operation *op) const; |
108 | |
109 | /// Get the interface for the given dialect. |
110 | const DialectInterface *getInterfaceFor(Dialect *dialect) const { |
111 | auto it = interfaces.find_as(Val: dialect); |
112 | return it == interfaces.end() ? nullptr : *it; |
113 | } |
114 | |
115 | /// An iterator class that iterates the held interface objects of the given |
116 | /// derived interface type. |
117 | template <typename InterfaceT> |
118 | struct iterator |
119 | : public llvm::mapped_iterator_base<iterator<InterfaceT>, |
120 | InterfaceVectorT::const_iterator, |
121 | const InterfaceT &> { |
122 | using llvm::mapped_iterator_base<iterator<InterfaceT>, |
123 | InterfaceVectorT::const_iterator, |
124 | const InterfaceT &>::mapped_iterator_base; |
125 | |
126 | /// Map the element to the iterator result type. |
127 | const InterfaceT &mapElement(const DialectInterface *interface) const { |
128 | return *static_cast<const InterfaceT *>(interface); |
129 | } |
130 | }; |
131 | |
132 | /// Iterator access to the held interfaces. |
133 | template <typename InterfaceT> |
134 | iterator<InterfaceT> interface_begin() const { |
135 | return iterator<InterfaceT>(orderedInterfaces.begin()); |
136 | } |
137 | template <typename InterfaceT> |
138 | iterator<InterfaceT> interface_end() const { |
139 | return iterator<InterfaceT>(orderedInterfaces.end()); |
140 | } |
141 | |
142 | private: |
143 | /// A set of registered dialect interface instances. |
144 | InterfaceSetT interfaces; |
145 | /// An ordered list of the registered interface instances, necessary for |
146 | /// deterministic iteration. |
147 | // NOTE: SetVector does not provide find access, so it can't be used here. |
148 | InterfaceVectorT orderedInterfaces; |
149 | }; |
150 | } // namespace detail |
151 | |
152 | /// A collection of dialect interfaces within a context, for a given concrete |
153 | /// interface type. |
154 | template <typename InterfaceType> |
155 | class DialectInterfaceCollection |
156 | : public detail::DialectInterfaceCollectionBase { |
157 | public: |
158 | using Base = DialectInterfaceCollection<InterfaceType>; |
159 | |
160 | /// Collect the registered dialect interfaces within the provided context. |
161 | DialectInterfaceCollection(MLIRContext *ctx) |
162 | : detail::DialectInterfaceCollectionBase( |
163 | ctx, InterfaceType::getInterfaceID(), |
164 | llvm::getTypeName<InterfaceType>()) {} |
165 | |
166 | /// Get the interface for a given object, or null if one is not registered. |
167 | /// The object may be a dialect or an operation instance. |
168 | template <typename Object> |
169 | const InterfaceType *getInterfaceFor(Object *obj) const { |
170 | return static_cast<const InterfaceType *>( |
171 | detail::DialectInterfaceCollectionBase::getInterfaceFor(obj)); |
172 | } |
173 | |
174 | /// Iterator access to the held interfaces. |
175 | using iterator = |
176 | detail::DialectInterfaceCollectionBase::iterator<InterfaceType>; |
177 | iterator begin() const { return interface_begin<InterfaceType>(); } |
178 | iterator end() const { return interface_end<InterfaceType>(); } |
179 | |
180 | private: |
181 | using detail::DialectInterfaceCollectionBase::interface_begin; |
182 | using detail::DialectInterfaceCollectionBase::interface_end; |
183 | }; |
184 | |
185 | } // namespace mlir |
186 | |
187 | #endif |
188 | |