//===- DialectInterface.h - IR Dialect Interfaces ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_IR_DIALECTINTERFACE_H
10#define MLIR_IR_DIALECTINTERFACE_H
11
12#include "mlir/Support/TypeID.h"
13#include "llvm/ADT/DenseSet.h"
14#include "llvm/ADT/STLExtras.h"
15
16namespace mlir {
17class Dialect;
18class MLIRContext;
19class Operation;
20
21//===----------------------------------------------------------------------===//
22// DialectInterface
23//===----------------------------------------------------------------------===//
24namespace detail {
25/// The base class used for all derived interface types. This class provides
26/// utilities necessary for registration.
27template <typename ConcreteType, typename BaseT>
28class DialectInterfaceBase : public BaseT {
29public:
30 using Base = DialectInterfaceBase<ConcreteType, BaseT>;
31
32 /// Get a unique id for the derived interface type.
33 static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); }
34
35protected:
36 DialectInterfaceBase(Dialect *dialect) : BaseT(dialect, getInterfaceID()) {}
37};
38} // namespace detail
39
40/// This class represents an interface overridden for a single dialect.
41class DialectInterface {
42public:
43 virtual ~DialectInterface();
44
45 /// The base class used for all derived interface types. This class provides
46 /// utilities necessary for registration.
47 template <typename ConcreteType>
48 using Base = detail::DialectInterfaceBase<ConcreteType, DialectInterface>;
49
50 /// Return the dialect that this interface represents.
51 Dialect *getDialect() const { return dialect; }
52
53 /// Return the context that holds the parent dialect of this interface.
54 MLIRContext *getContext() const;
55
56 /// Return the derived interface id.
57 TypeID getID() const { return interfaceID; }
58
59protected:
60 DialectInterface(Dialect *dialect, TypeID id)
61 : dialect(dialect), interfaceID(id) {}
62
63private:
64 /// The dialect that represents this interface.
65 Dialect *dialect;
66
67 /// The unique identifier for the derived interface type.
68 TypeID interfaceID;
69};
70
71//===----------------------------------------------------------------------===//
72// DialectInterfaceCollection
73//===----------------------------------------------------------------------===//
74
75namespace detail {
76/// This class is the base class for a collection of instances for a specific
77/// interface kind.
78class DialectInterfaceCollectionBase {
79 /// DenseMap info for dialect interfaces that allows lookup by the dialect.
80 struct InterfaceKeyInfo : public DenseMapInfo<const DialectInterface *> {
81 using DenseMapInfo<const DialectInterface *>::isEqual;
82
83 static unsigned getHashValue(Dialect *key) { return llvm::hash_value(ptr: key); }
84 static unsigned getHashValue(const DialectInterface *key) {
85 return getHashValue(key: key->getDialect());
86 }
87
88 static bool isEqual(Dialect *lhs, const DialectInterface *rhs) {
89 if (rhs == getEmptyKey() || rhs == getTombstoneKey())
90 return false;
91 return lhs == rhs->getDialect();
92 }
93 };
94
95 /// A set of registered dialect interface instances.
96 using InterfaceSetT = DenseSet<const DialectInterface *, InterfaceKeyInfo>;
97 using InterfaceVectorT = std::vector<const DialectInterface *>;
98
99public:
100 DialectInterfaceCollectionBase(MLIRContext *ctx, TypeID interfaceKind,
101 StringRef interfaceName);
102 virtual ~DialectInterfaceCollectionBase();
103
104protected:
105 /// Get the interface for the dialect of given operation, or null if one
106 /// is not registered.
107 const DialectInterface *getInterfaceFor(Operation *op) const;
108
109 /// Get the interface for the given dialect.
110 const DialectInterface *getInterfaceFor(Dialect *dialect) const {
111 auto it = interfaces.find_as(Val: dialect);
112 return it == interfaces.end() ? nullptr : *it;
113 }
114
115 /// An iterator class that iterates the held interface objects of the given
116 /// derived interface type.
117 template <typename InterfaceT>
118 struct iterator
119 : public llvm::mapped_iterator_base<iterator<InterfaceT>,
120 InterfaceVectorT::const_iterator,
121 const InterfaceT &> {
122 using llvm::mapped_iterator_base<iterator<InterfaceT>,
123 InterfaceVectorT::const_iterator,
124 const InterfaceT &>::mapped_iterator_base;
125
126 /// Map the element to the iterator result type.
127 const InterfaceT &mapElement(const DialectInterface *interface) const {
128 return *static_cast<const InterfaceT *>(interface);
129 }
130 };
131
132 /// Iterator access to the held interfaces.
133 template <typename InterfaceT>
134 iterator<InterfaceT> interface_begin() const {
135 return iterator<InterfaceT>(orderedInterfaces.begin());
136 }
137 template <typename InterfaceT>
138 iterator<InterfaceT> interface_end() const {
139 return iterator<InterfaceT>(orderedInterfaces.end());
140 }
141
142private:
143 /// A set of registered dialect interface instances.
144 InterfaceSetT interfaces;
145 /// An ordered list of the registered interface instances, necessary for
146 /// deterministic iteration.
147 // NOTE: SetVector does not provide find access, so it can't be used here.
148 InterfaceVectorT orderedInterfaces;
149};
150} // namespace detail
151
152/// A collection of dialect interfaces within a context, for a given concrete
153/// interface type.
154template <typename InterfaceType>
155class DialectInterfaceCollection
156 : public detail::DialectInterfaceCollectionBase {
157public:
158 using Base = DialectInterfaceCollection<InterfaceType>;
159
160 /// Collect the registered dialect interfaces within the provided context.
161 DialectInterfaceCollection(MLIRContext *ctx)
162 : detail::DialectInterfaceCollectionBase(
163 ctx, InterfaceType::getInterfaceID(),
164 llvm::getTypeName<InterfaceType>()) {}
165
166 /// Get the interface for a given object, or null if one is not registered.
167 /// The object may be a dialect or an operation instance.
168 template <typename Object>
169 const InterfaceType *getInterfaceFor(Object *obj) const {
170 return static_cast<const InterfaceType *>(
171 detail::DialectInterfaceCollectionBase::getInterfaceFor(obj));
172 }
173
174 /// Iterator access to the held interfaces.
175 using iterator =
176 detail::DialectInterfaceCollectionBase::iterator<InterfaceType>;
177 iterator begin() const { return interface_begin<InterfaceType>(); }
178 iterator end() const { return interface_end<InterfaceType>(); }
179
180private:
181 using detail::DialectInterfaceCollectionBase::interface_begin;
182 using detail::DialectInterfaceCollectionBase::interface_end;
183};
184
185} // namespace mlir
186
187#endif
188

source code of mlir/include/mlir/IR/DialectInterface.h