//===- StorageUniquerSupport.h - MLIR Storage Uniquer Utilities -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines utility classes for interfacing with StorageUniquer.
//
//===----------------------------------------------------------------------===//
16#include "mlir/IR/AttrTypeSubElements.h"
17#include "mlir/IR/DialectRegistry.h"
18#include "mlir/Support/InterfaceSupport.h"
19#include "mlir/Support/LogicalResult.h"
20#include "mlir/Support/StorageUniquer.h"
21#include "mlir/Support/TypeID.h"
22#include "llvm/ADT/FunctionExtras.h"
24namespace mlir {
// Forward declarations of the MLIR classes referenced by the utilities in
// this header; their full definitions are not required here.
class MLIRContext;
class Location;
class InFlightDiagnostic;
29namespace detail {
30/// Utility method to generate a callback that can be used to generate a
31/// diagnostic when checking the construction invariants of a storage object.
32/// This is defined out-of-line to avoid the need to include Location.h.
34getDefaultDiagnosticEmitFn(MLIRContext *ctx);
36getDefaultDiagnosticEmitFn(const Location &loc);
//===----------------------------------------------------------------------===//
// StorageUserTraitBase
//===----------------------------------------------------------------------===//

/// Helper class for implementing traits for storage classes. Clients are not
/// expected to interact with this directly, so its members are all protected.
///
/// The class carries no data; it only provides `getInstance()` so that a trait
/// `TraitType<ConcreteType>` can recover the concrete storage user it is
/// attached to.
template <typename ConcreteType, template <typename> class TraitType>
class StorageUserTraitBase {
protected:
  /// Return the derived instance.
  ConcreteType getInstance() const {
    // We have to cast up to the trait type, then to the concrete type because
    // the concrete type will multiply derive from the (content free) TraitBase
    // class, and we need to be able to disambiguate the path for the C++
    // compiler.
    auto *trait = static_cast<const TraitType<ConcreteType> *>(this);
    return *static_cast<const ConcreteType *>(trait);
  }
};
58namespace StorageUserTrait {
59/// This trait is used to determine if a storage user, like Type, is mutable
60/// or not. A storage user is mutable if ImplType of the derived class defines
61/// a `mutate` function with a proper signature. Note that this trait is not
62/// supposed to be used publicly. Users should use alias names like
63/// `TypeTrait::IsMutable` instead.
64template <typename ConcreteType>
65struct IsMutable : public StorageUserTraitBase<ConcreteType, IsMutable> {};
66} // namespace StorageUserTrait
69// StorageUserBase
72namespace storage_user_base_impl {
73/// Returns true if this given Trait ID matches the IDs of any of the provided
74/// trait types `Traits`.
75template <template <typename T> class... Traits>
76bool hasTrait(TypeID traitID) {
77 TypeID traitIDs[] = {TypeID::get<Traits>()...};
78 for (unsigned i = 0, e = sizeof...(Traits); i != e; ++i)
79 if (traitIDs[i] == traitID)
80 return true;
81 return false;
84// We specialize for the empty case to not define an empty array.
85template <>
86inline bool hasTrait(TypeID traitID) {
87 return false;
89} // namespace storage_user_base_impl
91/// Utility class for implementing users of storage classes uniqued by a
92/// StorageUniquer. Clients are not expected to interact with this class
93/// directly.
94template <typename ConcreteT, typename BaseT, typename StorageT,
95 typename UniquerT, template <typename T> class... Traits>
96class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
98 using BaseT::BaseT;
100 /// Utility declarations for the concrete attribute class.
101 using Base = StorageUserBase<ConcreteT, BaseT, StorageT, UniquerT, Traits...>;
102 using ImplType = StorageT;
103 using HasTraitFn = bool (*)(TypeID);
105 /// Return a unique identifier for the concrete type.
106 static TypeID getTypeID() { return TypeID::get<ConcreteT>(); }
108 /// Provide an implementation of 'classof' that compares the type id of the
109 /// provided value with that of the concrete type.
110 template <typename T>
111 static bool classof(T val) {
112 static_assert(std::is_convertible<ConcreteT, T>::value,
113 "casting from a non-convertible type");
114 return val.getTypeID() == getTypeID();
115 }
117 /// Returns an interface map for the interfaces registered to this storage
118 /// user. This should not be used directly.
119 static detail::InterfaceMap getInterfaceMap() {
120 return detail::InterfaceMap::template get<Traits<ConcreteT>...>();
121 }
123 /// Returns the function that returns true if the given Trait ID matches the
124 /// IDs of any of the traits defined by the storage user.
125 static HasTraitFn getHasTraitFn() {
126 return [](TypeID id) {
127 return storage_user_base_impl::hasTrait<Traits...>(id);
128 };
129 }
131 /// Returns a function that walks immediate sub elements of a given instance
132 /// of the storage user.
133 static auto getWalkImmediateSubElementsFn() {
134 return [](auto instance, function_ref<void(Attribute)> walkAttrsFn,
135 function_ref<void(Type)> walkTypesFn) {
136 ::mlir::detail::walkImmediateSubElementsImpl(
137 llvm::cast<ConcreteT>(instance), walkAttrsFn, walkTypesFn);
138 };
139 }
141 /// Returns a function that replaces immediate sub elements of a given
142 /// instance of the storage user.
143 static auto getReplaceImmediateSubElementsFn() {
144 return [](auto instance, ArrayRef<Attribute> replAttrs,
145 ArrayRef<Type> replTypes) {
146 return ::mlir::detail::replaceImmediateSubElementsImpl(
147 llvm::cast<ConcreteT>(instance), replAttrs, replTypes);
148 };
149 }
151 /// Attach the given models as implementations of the corresponding interfaces
152 /// for the concrete storage user class. The type must be registered with the
153 /// context, i.e. the dialect to which the type belongs must be loaded. The
154 /// call will abort otherwise.
155 template <typename... IfaceModels>
156 static void attachInterface(MLIRContext &context) {
157 typename ConcreteT::AbstractTy *abstract =
158 ConcreteT::AbstractTy::lookupMutable(TypeID::get<ConcreteT>(),
159 &context);
160 if (!abstract)
161 llvm::report_fatal_error(reason: "Registering an interface for an attribute/type "
162 "that is not itself registered.");
164 // Handle the case where the models resolve a promised interface.
165 (dialect_extension_detail::handleAdditionOfUndefinedPromisedInterface(
166 dialect&: abstract->getDialect(), interfaceRequestorID: abstract->getTypeID(),
167 interfaceID: IfaceModels::Interface::getInterfaceID()),
168 ...);
170 (checkInterfaceTarget<IfaceModels>(), ...);
171 abstract->interfaceMap.template insertModels<IfaceModels...>();
172 }
174 /// Get or create a new ConcreteT instance within the ctx. This
175 /// function is guaranteed to return a non null object and will assert if
176 /// the arguments provided are invalid.
177 template <typename... Args>
178 static ConcreteT get(MLIRContext *ctx, Args &&...args) {
179 // Ensure that the invariants are correct for construction.
180 assert(
181 succeeded(ConcreteT::verify(getDefaultDiagnosticEmitFn(ctx), args...)));
182 return UniquerT::template get<ConcreteT>(ctx, std::forward<Args>(args)...);
183 }
185 /// Get or create a new ConcreteT instance within the ctx, defined at
186 /// the given, potentially unknown, location. If the arguments provided are
187 /// invalid, errors are emitted using the provided location and a null object
188 /// is returned.
189 template <typename... Args>
190 static ConcreteT getChecked(const Location &loc, Args &&...args) {
191 return ConcreteT::getChecked(getDefaultDiagnosticEmitFn(loc),
192 std::forward<Args>(args)...);
193 }
195 /// Get or create a new ConcreteT instance within the ctx. If the arguments
196 /// provided are invalid, errors are emitted using the provided `emitError`
197 /// and a null object is returned.
198 template <typename... Args>
199 static ConcreteT getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
200 MLIRContext *ctx, Args... args) {
201 // If the construction invariants fail then we return a null attribute.
202 if (failed(ConcreteT::verify(emitErrorFn, args...)))
203 return ConcreteT();
204 return UniquerT::template get<ConcreteT>(ctx, args...);
205 }
207 /// Get an instance of the concrete type from a void pointer.
208 static ConcreteT getFromOpaquePointer(const void *ptr) {
209 return ConcreteT((const typename BaseT::ImplType *)ptr);
210 }
212 /// Utility for easy access to the storage instance.
213 ImplType *getImpl() const { return static_cast<ImplType *>(this->impl); }
216 /// Mutate the current storage instance. This will not change the unique key.
217 /// The arguments are forwarded to 'ConcreteT::mutate'.
218 template <typename... Args>
219 LogicalResult mutate(Args &&...args) {
220 static_assert(std::is_base_of<StorageUserTrait::IsMutable<ConcreteT>,
221 ConcreteT>::value,
222 "The `mutate` function expects mutable trait "
223 "(e.g. TypeTrait::IsMutable) to be attached on parent.");
224 return UniquerT::template mutate<ConcreteT>(this->getContext(), getImpl(),
225 std::forward<Args>(args)...);
226 }
228 /// Default implementation that just returns success.
229 template <typename... Args>
230 static LogicalResult verify(Args... args) {
231 return success();
232 }
235 /// Trait to check if T provides a 'ConcreteEntity' type alias.
236 template <typename T>
237 using has_concrete_entity_t = typename T::ConcreteEntity;
239 /// A struct-wrapped type alias to T::ConcreteEntity if provided and to
240 /// ConcreteT otherwise. This is akin to std::conditional but doesn't fail on
241 /// the missing typedef. Useful for checking if the interface is targeting the
242 /// right class.
243 template <typename T,
244 bool = llvm::is_detected<has_concrete_entity_t, T>::value>
245 struct IfaceTargetOrConcreteT {
246 using type = typename T::ConcreteEntity;
247 };
248 template <typename T>
249 struct IfaceTargetOrConcreteT<T, false> {
250 using type = ConcreteT;
251 };
253 /// A hook for static assertion that the external interface model T is
254 /// targeting a base class of the concrete attribute/type. The model can also
255 /// be a fallback model that works for every attribute/type.
256 template <typename T>
257 static void checkInterfaceTarget() {
258 static_assert(std::is_base_of<typename IfaceTargetOrConcreteT<T>::type,
259 ConcreteT>::value,
260 "attaching an interface to the wrong attribute/type kind");
261 }
263} // namespace detail
264} // namespace mlir

source code of mlir/include/mlir/IR/StorageUniquerSupport.h