1//===- TransformDialect.h - Transform Dialect Definition --------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H
10#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H
11
#include "mlir/IR/Dialect.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/TypeID.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringMap.h"
#include <functional>
#include <memory>
#include <optional>
#include <type_traits>
#include <utility>
19
20namespace mlir {
21namespace transform {
22
23namespace detail {
24/// Concrete base class for CRTP TransformDialectDataBase. Must not be used
25/// directly.
26class TransformDialectDataBase {
27public:
28 virtual ~TransformDialectDataBase() = default;
29
30 /// Returns the dynamic type ID of the subclass.
31 TypeID getTypeID() const { return typeID; }
32
33protected:
34 /// Must be called by the subclass with the appropriate type ID.
35 explicit TransformDialectDataBase(TypeID typeID, MLIRContext *ctx)
36 : typeID(typeID), ctx(ctx) {}
37
38 /// Return the MLIR context.
39 MLIRContext *getContext() const { return ctx; }
40
41private:
42 /// The type ID of the subclass.
43 const TypeID typeID;
44
45 /// The MLIR context.
46 MLIRContext *ctx;
47};
48} // namespace detail
49
50/// Base class for additional data owned by the Transform dialect. Extensions
51/// may communicate with each other using this data. The data object is
52/// identified by the TypeID of the specific data subclass, querying the data of
53/// the same subclass returns a reference to the same object. When a Transform
54/// dialect extension is initialized, it can populate the data in the specific
55/// subclass. When a Transform op is applied, it can read (but not mutate) the
56/// data in the specific subclass, including the data provided by other
57/// extensions.
58///
59/// This follows CRTP: derived classes must list themselves as template
60/// argument.
61template <typename DerivedTy>
62class TransformDialectData : public detail::TransformDialectDataBase {
63protected:
64 /// Forward the TypeID of the derived class to the base.
65 TransformDialectData(MLIRContext *ctx)
66 : TransformDialectDataBase(TypeID::get<DerivedTy>(), ctx) {}
67};
68
69#ifndef NDEBUG
70namespace detail {
71/// Asserts that the operations provided as template arguments implement the
72/// TransformOpInterface and MemoryEffectsOpInterface. This must be a dynamic
73/// assertion since interface implementations may be registered at runtime.
74void checkImplementsTransformOpInterface(StringRef name, MLIRContext *context);
75
76/// Asserts that the type provided as template argument implements the
77/// TransformHandleTypeInterface. This must be a dynamic assertion since
78/// interface implementations may be registered at runtime.
79void checkImplementsTransformHandleTypeInterface(TypeID typeID,
80 MLIRContext *context);
81} // namespace detail
82#endif // NDEBUG
83} // namespace transform
84} // namespace mlir
85
86#include "mlir/Dialect/Transform/IR/TransformDialect.h.inc"
87
88namespace mlir {
89namespace transform {
90
91/// Base class for extensions of the Transform dialect that supports injecting
92/// operations into the Transform dialect at load time. Concrete extensions are
93/// expected to derive this class and register operations in the constructor.
94/// They can be registered with the DialectRegistry and automatically applied
95/// to the Transform dialect when it is loaded.
96///
97/// Derived classes are expected to define a `void init()` function in which
98/// they can call various protected methods of the base class to register
99/// extension operations and declare their dependencies.
100///
101/// By default, the extension is configured both for construction of the
102/// Transform IR and for its application to some payload. If only the
103/// construction is desired, the extension can be switched to "build-only" mode
104/// that avoids loading the dialects that are only necessary for transforming
105/// the payload. To perform the switch, the extension must be wrapped into the
106/// `BuildOnly` class template (see below) when it is registered, as in:
107///
108/// dialectRegistry.addExtension<BuildOnly<MyTransformDialectExt>>();
109///
110/// instead of:
111///
112/// dialectRegistry.addExtension<MyTransformDialectExt>();
113///
114/// Derived classes must reexport the constructor of this class or otherwise
115/// forward its boolean argument to support this behavior.
116template <typename DerivedTy, typename... ExtraDialects>
117class TransformDialectExtension
118 : public DialectExtension<DerivedTy, TransformDialect, ExtraDialects...> {
119 using Initializer = std::function<void(TransformDialect *)>;
120 using DialectLoader = std::function<void(MLIRContext *)>;
121
122public:
123 /// Extension application hook. Actually loads the dependent dialects and
124 /// registers the additional operations. Not expected to be called directly.
125 void apply(MLIRContext *context, TransformDialect *transformDialect,
126 ExtraDialects *...) const final {
127 for (const DialectLoader &loader : dialectLoaders)
128 loader(context);
129
130 // Only load generated dialects if the user intends to apply
131 // transformations specified by the extension.
132 if (!buildOnly)
133 for (const DialectLoader &loader : generatedDialectLoaders)
134 loader(context);
135
136 for (const Initializer &init : initializers)
137 init(transformDialect);
138 }
139
140protected:
141 using Base = TransformDialectExtension<DerivedTy, ExtraDialects...>;
142
143 /// Extension constructor. The argument indicates whether to skip generated
144 /// dialects when applying the extension.
145 explicit TransformDialectExtension(bool buildOnly = false)
146 : buildOnly(buildOnly) {
147 static_cast<DerivedTy *>(this)->init();
148 }
149
150 /// Registers a custom initialization step to be performed when the extension
151 /// is applied to the dialect while loading. This is discouraged in favor of
152 /// more specific calls `declareGeneratedDialect`, `addDialectDataInitializer`
153 /// etc. `Func` must be convertible to the `void (MLIRContext *)` form. It
154 /// will be called during the extension initialization and given the current
155 /// MLIR context. This may be used to attach additional interfaces that cannot
156 /// be attached elsewhere.
157 template <typename Func>
158 void addCustomInitializationStep(Func &&func) {
159 std::function<void(MLIRContext *)> initializer = func;
160 dialectLoaders.push_back(
161 [init = std::move(initializer)](MLIRContext *ctx) { init(ctx); });
162 }
163
164 /// Registers the given function as one of the initializers for the
165 /// dialect-owned data of the kind specified as template argument. The
166 /// function must be convertible to the `void (DataTy &)` form. It will be
167 /// called during the extension initialization and will be given a mutable
168 /// reference to `DataTy`. The callback is expected to append data to the
169 /// given storage, and is not allowed to remove or destructively mutate the
170 /// existing data. The order in which callbacks from different extensions are
171 /// executed is unspecified so the callbacks may not rely on data being
172 /// already present. `DataTy` must be a class deriving `TransformDialectData`.
173 template <typename DataTy, typename Func>
174 void addDialectDataInitializer(Func &&func) {
175 static_assert(std::is_base_of_v<detail::TransformDialectDataBase, DataTy>,
176 "only classes deriving TransformDialectData are accepted");
177
178 std::function<void(DataTy &)> initializer = func;
179 initializers.push_back(
180 [init = std::move(initializer)](TransformDialect *transformDialect) {
181 init(transformDialect->getOrCreateExtraData<DataTy>());
182 });
183 }
184
185 /// Hook for derived classes to inject constructor behavior.
186 void init() {}
187
188 /// Injects the operations into the Transform dialect. The operations must
189 /// implement the TransformOpInterface and MemoryEffectsOpInterface, and the
190 /// implementations must be already available when the operation is injected.
191 template <typename... OpTys>
192 void registerTransformOps() {
193 initializers.push_back([](TransformDialect *transformDialect) {
194 transformDialect->addOperationsChecked<OpTys...>();
195 });
196 }
197
198 /// Injects the types into the Transform dialect. The types must implement
199 /// the TransformHandleTypeInterface and the implementation must be already
200 /// available when the type is injected. Furthermore, the types must provide
201 /// a `getMnemonic` static method returning an object convertible to
202 /// `StringRef` that is unique across all injected types.
203 template <typename... TypeTys>
204 void registerTypes() {
205 initializers.push_back([](TransformDialect *transformDialect) {
206 transformDialect->addTypesChecked<TypeTys...>();
207 });
208 }
209
210 /// Declares that this Transform dialect extension depends on the dialect
211 /// provided as template parameter. When the Transform dialect is loaded,
212 /// dependent dialects will be loaded as well. This is intended for dialects
213 /// that contain attributes and types used in creation and canonicalization of
214 /// the injected operations, similarly to how the dialect definition may list
215 /// dependent dialects. This is *not* intended for dialects entities from
216 /// which may be produced when applying the transformations specified by ops
217 /// registered by this extension.
218 template <typename DialectTy>
219 void declareDependentDialect() {
220 dialectLoaders.push_back(
221 [](MLIRContext *context) { context->loadDialect<DialectTy>(); });
222 }
223
224 /// Declares that the transformations associated with the operations
225 /// registered by this dialect extension may produce operations from the
226 /// dialect provided as template parameter while processing payload IR that
227 /// does not contain the operations from said dialect. This is similar to
228 /// dependent dialects of a pass. These dialects will be loaded along with the
229 /// transform dialect unless the extension is in the build-only mode.
230 template <typename DialectTy>
231 void declareGeneratedDialect() {
232 generatedDialectLoaders.push_back(
233 [](MLIRContext *context) { context->loadDialect<DialectTy>(); });
234 }
235
236private:
237 /// Callbacks performing extension initialization, e.g., registering ops,
238 /// types and defining the additional data.
239 SmallVector<Initializer> initializers;
240
241 /// Callbacks loading the dependent dialects, i.e. the dialect needed for the
242 /// extension ops.
243 SmallVector<DialectLoader> dialectLoaders;
244
245 /// Callbacks loading the generated dialects, i.e. the dialects produced when
246 /// applying the transformations.
247 SmallVector<DialectLoader> generatedDialectLoaders;
248
249 /// Indicates that the extension is in build-only mode.
250 bool buildOnly;
251};
252
253template <typename OpTy>
254void TransformDialect::addOperationIfNotRegistered() {
255 std::optional<RegisteredOperationName> opName =
256 RegisteredOperationName::lookup(TypeID::get<OpTy>(), getContext());
257 if (!opName) {
258 addOperations<OpTy>();
259#ifndef NDEBUG
260 StringRef name = OpTy::getOperationName();
261 detail::checkImplementsTransformOpInterface(name, getContext());
262#endif // NDEBUG
263 return;
264 }
265
266 if (LLVM_LIKELY(opName->getTypeID() == TypeID::get<OpTy>()))
267 return;
268
269 reportDuplicateOpRegistration(OpTy::getOperationName());
270}
271
272template <typename Type>
273void TransformDialect::addTypeIfNotRegistered() {
274 // Use the address of the parse method as a proxy for identifying whether we
275 // are registering the same type class for the same mnemonic.
276 StringRef mnemonic = Type::getMnemonic();
277 auto [it, inserted] = typeParsingHooks.try_emplace(mnemonic, Type::parse);
278 if (!inserted) {
279 const ExtensionTypeParsingHook &parsingHook = it->getValue();
280 if (parsingHook != &Type::parse)
281 reportDuplicateTypeRegistration(mnemonic);
282 else
283 return;
284 }
285 typePrintingHooks.try_emplace(
286 TypeID::get<Type>(), +[](mlir::Type type, AsmPrinter &printer) {
287 printer << Type::getMnemonic();
288 cast<Type>(type).print(printer);
289 });
290 addTypes<Type>();
291
292#ifndef NDEBUG
293 detail::checkImplementsTransformHandleTypeInterface(TypeID::get<Type>(),
294 getContext());
295#endif // NDEBUG
296}
297
298template <typename DataTy>
299DataTy &TransformDialect::getOrCreateExtraData() {
300 TypeID typeID = TypeID::get<DataTy>();
301 auto it = extraData.find(typeID);
302 if (it != extraData.end())
303 return static_cast<DataTy &>(*it->getSecond());
304
305 auto emplaced =
306 extraData.try_emplace(typeID, std::make_unique<DataTy>(getContext()));
307 return static_cast<DataTy &>(*emplaced.first->getSecond());
308}
309
/// Adaptor that constructs the wrapped transform dialect extension in the
/// build-only mode by passing `true` to the `DerivedTy` constructor.
/// `DerivedTy` must therefore be constructible from a single boolean.
template <typename DerivedTy>
class BuildOnly : public DerivedTy {
public:
  /// Always requests the build-only mode from the wrapped extension.
  BuildOnly() : DerivedTy(/*buildOnly=*/true) {}
};
317
318} // namespace transform
319} // namespace mlir
320
321#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H
322

source code of mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h