1 | //===- TransformDialect.h - Transform Dialect Definition --------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H |
10 | #define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H |
11 | |
#include "mlir/IR/Dialect.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/TypeID.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringMap.h"
#include <functional>
#include <optional>
#include <type_traits>
#include <utility>
19 | |
20 | namespace mlir { |
21 | namespace transform { |
22 | |
23 | namespace detail { |
24 | /// Concrete base class for CRTP TransformDialectDataBase. Must not be used |
25 | /// directly. |
26 | class TransformDialectDataBase { |
27 | public: |
28 | virtual ~TransformDialectDataBase() = default; |
29 | |
30 | /// Returns the dynamic type ID of the subclass. |
31 | TypeID getTypeID() const { return typeID; } |
32 | |
33 | protected: |
34 | /// Must be called by the subclass with the appropriate type ID. |
35 | explicit TransformDialectDataBase(TypeID typeID, MLIRContext *ctx) |
36 | : typeID(typeID), ctx(ctx) {} |
37 | |
38 | /// Return the MLIR context. |
39 | MLIRContext *getContext() const { return ctx; } |
40 | |
41 | private: |
42 | /// The type ID of the subclass. |
43 | const TypeID typeID; |
44 | |
45 | /// The MLIR context. |
46 | MLIRContext *ctx; |
47 | }; |
48 | } // namespace detail |
49 | |
50 | /// Base class for additional data owned by the Transform dialect. Extensions |
51 | /// may communicate with each other using this data. The data object is |
52 | /// identified by the TypeID of the specific data subclass, querying the data of |
53 | /// the same subclass returns a reference to the same object. When a Transform |
54 | /// dialect extension is initialized, it can populate the data in the specific |
55 | /// subclass. When a Transform op is applied, it can read (but not mutate) the |
56 | /// data in the specific subclass, including the data provided by other |
57 | /// extensions. |
58 | /// |
59 | /// This follows CRTP: derived classes must list themselves as template |
60 | /// argument. |
61 | template <typename DerivedTy> |
62 | class TransformDialectData : public detail::TransformDialectDataBase { |
63 | protected: |
64 | /// Forward the TypeID of the derived class to the base. |
65 | TransformDialectData(MLIRContext *ctx) |
66 | : TransformDialectDataBase(TypeID::get<DerivedTy>(), ctx) {} |
67 | }; |
68 | |
69 | #ifndef NDEBUG |
70 | namespace detail { |
71 | /// Asserts that the operations provided as template arguments implement the |
72 | /// TransformOpInterface and MemoryEffectsOpInterface. This must be a dynamic |
73 | /// assertion since interface implementations may be registered at runtime. |
74 | void checkImplementsTransformOpInterface(StringRef name, MLIRContext *context); |
75 | |
76 | /// Asserts that the type provided as template argument implements the |
77 | /// TransformHandleTypeInterface. This must be a dynamic assertion since |
78 | /// interface implementations may be registered at runtime. |
79 | void checkImplementsTransformHandleTypeInterface(TypeID typeID, |
80 | MLIRContext *context); |
81 | } // namespace detail |
82 | #endif // NDEBUG |
83 | } // namespace transform |
84 | } // namespace mlir |
85 | |
86 | #include "mlir/Dialect/Transform/IR/TransformDialect.h.inc" |
87 | |
88 | namespace mlir { |
89 | namespace transform { |
90 | |
91 | /// Base class for extensions of the Transform dialect that supports injecting |
92 | /// operations into the Transform dialect at load time. Concrete extensions are |
93 | /// expected to derive this class and register operations in the constructor. |
94 | /// They can be registered with the DialectRegistry and automatically applied |
95 | /// to the Transform dialect when it is loaded. |
96 | /// |
97 | /// Derived classes are expected to define a `void init()` function in which |
98 | /// they can call various protected methods of the base class to register |
99 | /// extension operations and declare their dependencies. |
100 | /// |
101 | /// By default, the extension is configured both for construction of the |
102 | /// Transform IR and for its application to some payload. If only the |
103 | /// construction is desired, the extension can be switched to "build-only" mode |
104 | /// that avoids loading the dialects that are only necessary for transforming |
105 | /// the payload. To perform the switch, the extension must be wrapped into the |
106 | /// `BuildOnly` class template (see below) when it is registered, as in: |
107 | /// |
108 | /// dialectRegistry.addExtension<BuildOnly<MyTransformDialectExt>>(); |
109 | /// |
110 | /// instead of: |
111 | /// |
112 | /// dialectRegistry.addExtension<MyTransformDialectExt>(); |
113 | /// |
114 | /// Derived classes must reexport the constructor of this class or otherwise |
115 | /// forward its boolean argument to support this behavior. |
116 | template <typename DerivedTy, typename... ExtraDialects> |
117 | class TransformDialectExtension |
118 | : public DialectExtension<DerivedTy, TransformDialect, ExtraDialects...> { |
119 | using Initializer = std::function<void(TransformDialect *)>; |
120 | using DialectLoader = std::function<void(MLIRContext *)>; |
121 | |
122 | public: |
123 | /// Extension application hook. Actually loads the dependent dialects and |
124 | /// registers the additional operations. Not expected to be called directly. |
125 | void apply(MLIRContext *context, TransformDialect *transformDialect, |
126 | ExtraDialects *...) const final { |
127 | for (const DialectLoader &loader : dialectLoaders) |
128 | loader(context); |
129 | |
130 | // Only load generated dialects if the user intends to apply |
131 | // transformations specified by the extension. |
132 | if (!buildOnly) |
133 | for (const DialectLoader &loader : generatedDialectLoaders) |
134 | loader(context); |
135 | |
136 | for (const Initializer &init : initializers) |
137 | init(transformDialect); |
138 | } |
139 | |
140 | protected: |
141 | using Base = TransformDialectExtension<DerivedTy, ExtraDialects...>; |
142 | |
143 | /// Extension constructor. The argument indicates whether to skip generated |
144 | /// dialects when applying the extension. |
145 | explicit TransformDialectExtension(bool buildOnly = false) |
146 | : buildOnly(buildOnly) { |
147 | static_cast<DerivedTy *>(this)->init(); |
148 | } |
149 | |
150 | /// Registers a custom initialization step to be performed when the extension |
151 | /// is applied to the dialect while loading. This is discouraged in favor of |
152 | /// more specific calls `declareGeneratedDialect`, `addDialectDataInitializer` |
153 | /// etc. `Func` must be convertible to the `void (MLIRContext *)` form. It |
154 | /// will be called during the extension initialization and given the current |
155 | /// MLIR context. This may be used to attach additional interfaces that cannot |
156 | /// be attached elsewhere. |
157 | template <typename Func> |
158 | void addCustomInitializationStep(Func &&func) { |
159 | std::function<void(MLIRContext *)> initializer = func; |
160 | dialectLoaders.push_back( |
161 | [init = std::move(initializer)](MLIRContext *ctx) { init(ctx); }); |
162 | } |
163 | |
164 | /// Registers the given function as one of the initializers for the |
165 | /// dialect-owned data of the kind specified as template argument. The |
166 | /// function must be convertible to the `void (DataTy &)` form. It will be |
167 | /// called during the extension initialization and will be given a mutable |
168 | /// reference to `DataTy`. The callback is expected to append data to the |
169 | /// given storage, and is not allowed to remove or destructively mutate the |
170 | /// existing data. The order in which callbacks from different extensions are |
171 | /// executed is unspecified so the callbacks may not rely on data being |
172 | /// already present. `DataTy` must be a class deriving `TransformDialectData`. |
173 | template <typename DataTy, typename Func> |
174 | void addDialectDataInitializer(Func &&func) { |
175 | static_assert(std::is_base_of_v<detail::TransformDialectDataBase, DataTy>, |
176 | "only classes deriving TransformDialectData are accepted" ); |
177 | |
178 | std::function<void(DataTy &)> initializer = func; |
179 | initializers.push_back( |
180 | [init = std::move(initializer)](TransformDialect *transformDialect) { |
181 | init(transformDialect->getOrCreateExtraData<DataTy>()); |
182 | }); |
183 | } |
184 | |
185 | /// Hook for derived classes to inject constructor behavior. |
186 | void init() {} |
187 | |
188 | /// Injects the operations into the Transform dialect. The operations must |
189 | /// implement the TransformOpInterface and MemoryEffectsOpInterface, and the |
190 | /// implementations must be already available when the operation is injected. |
191 | template <typename... OpTys> |
192 | void registerTransformOps() { |
193 | initializers.push_back([](TransformDialect *transformDialect) { |
194 | transformDialect->addOperationsChecked<OpTys...>(); |
195 | }); |
196 | } |
197 | |
198 | /// Injects the types into the Transform dialect. The types must implement |
199 | /// the TransformHandleTypeInterface and the implementation must be already |
200 | /// available when the type is injected. Furthermore, the types must provide |
201 | /// a `getMnemonic` static method returning an object convertible to |
202 | /// `StringRef` that is unique across all injected types. |
203 | template <typename... TypeTys> |
204 | void registerTypes() { |
205 | initializers.push_back([](TransformDialect *transformDialect) { |
206 | transformDialect->addTypesChecked<TypeTys...>(); |
207 | }); |
208 | } |
209 | |
210 | /// Declares that this Transform dialect extension depends on the dialect |
211 | /// provided as template parameter. When the Transform dialect is loaded, |
212 | /// dependent dialects will be loaded as well. This is intended for dialects |
213 | /// that contain attributes and types used in creation and canonicalization of |
214 | /// the injected operations, similarly to how the dialect definition may list |
215 | /// dependent dialects. This is *not* intended for dialects entities from |
216 | /// which may be produced when applying the transformations specified by ops |
217 | /// registered by this extension. |
218 | template <typename DialectTy> |
219 | void declareDependentDialect() { |
220 | dialectLoaders.push_back( |
221 | [](MLIRContext *context) { context->loadDialect<DialectTy>(); }); |
222 | } |
223 | |
224 | /// Declares that the transformations associated with the operations |
225 | /// registered by this dialect extension may produce operations from the |
226 | /// dialect provided as template parameter while processing payload IR that |
227 | /// does not contain the operations from said dialect. This is similar to |
228 | /// dependent dialects of a pass. These dialects will be loaded along with the |
229 | /// transform dialect unless the extension is in the build-only mode. |
230 | template <typename DialectTy> |
231 | void declareGeneratedDialect() { |
232 | generatedDialectLoaders.push_back( |
233 | [](MLIRContext *context) { context->loadDialect<DialectTy>(); }); |
234 | } |
235 | |
236 | private: |
237 | /// Callbacks performing extension initialization, e.g., registering ops, |
238 | /// types and defining the additional data. |
239 | SmallVector<Initializer> initializers; |
240 | |
241 | /// Callbacks loading the dependent dialects, i.e. the dialect needed for the |
242 | /// extension ops. |
243 | SmallVector<DialectLoader> dialectLoaders; |
244 | |
245 | /// Callbacks loading the generated dialects, i.e. the dialects produced when |
246 | /// applying the transformations. |
247 | SmallVector<DialectLoader> generatedDialectLoaders; |
248 | |
249 | /// Indicates that the extension is in build-only mode. |
250 | bool buildOnly; |
251 | }; |
252 | |
253 | template <typename OpTy> |
254 | void TransformDialect::addOperationIfNotRegistered() { |
255 | std::optional<RegisteredOperationName> opName = |
256 | RegisteredOperationName::lookup(TypeID::get<OpTy>(), getContext()); |
257 | if (!opName) { |
258 | addOperations<OpTy>(); |
259 | #ifndef NDEBUG |
260 | StringRef name = OpTy::getOperationName(); |
261 | detail::checkImplementsTransformOpInterface(name, getContext()); |
262 | #endif // NDEBUG |
263 | return; |
264 | } |
265 | |
266 | if (LLVM_LIKELY(opName->getTypeID() == TypeID::get<OpTy>())) |
267 | return; |
268 | |
269 | reportDuplicateOpRegistration(OpTy::getOperationName()); |
270 | } |
271 | |
272 | template <typename Type> |
273 | void TransformDialect::addTypeIfNotRegistered() { |
274 | // Use the address of the parse method as a proxy for identifying whether we |
275 | // are registering the same type class for the same mnemonic. |
276 | StringRef mnemonic = Type::getMnemonic(); |
277 | auto [it, inserted] = typeParsingHooks.try_emplace(mnemonic, Type::parse); |
278 | if (!inserted) { |
279 | const ExtensionTypeParsingHook &parsingHook = it->getValue(); |
280 | if (parsingHook != &Type::parse) |
281 | reportDuplicateTypeRegistration(mnemonic); |
282 | else |
283 | return; |
284 | } |
285 | typePrintingHooks.try_emplace( |
286 | TypeID::get<Type>(), +[](mlir::Type type, AsmPrinter &printer) { |
287 | printer << Type::getMnemonic(); |
288 | cast<Type>(type).print(printer); |
289 | }); |
290 | addTypes<Type>(); |
291 | |
292 | #ifndef NDEBUG |
293 | detail::checkImplementsTransformHandleTypeInterface(TypeID::get<Type>(), |
294 | getContext()); |
295 | #endif // NDEBUG |
296 | } |
297 | |
298 | template <typename DataTy> |
299 | DataTy &TransformDialect::getOrCreateExtraData() { |
300 | TypeID typeID = TypeID::get<DataTy>(); |
301 | auto it = extraData.find(typeID); |
302 | if (it != extraData.end()) |
303 | return static_cast<DataTy &>(*it->getSecond()); |
304 | |
305 | auto emplaced = |
306 | extraData.try_emplace(typeID, std::make_unique<DataTy>(getContext())); |
307 | return static_cast<DataTy &>(*emplaced.first->getSecond()); |
308 | } |
309 | |
/// Wraps a Transform dialect extension so that it is always constructed in
/// build-only mode. The constructor passes `true` to the wrapped extension's
/// constructor, which TransformDialectExtension uses to skip loading generated
/// dialects. Typical use:
///
///    registry.addExtension<BuildOnly<MyTransformDialectExt>>();
template <typename DerivedTy>
class BuildOnly : public DerivedTy {
public:
  /// Constructs the wrapped extension with its build-only flag raised.
  BuildOnly() : DerivedTy(/*buildOnly=*/true) {}
};
317 | |
318 | } // namespace transform |
319 | } // namespace mlir |
320 | |
321 | #endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H |
322 | |