| 1 | //===- TransformDialect.h - Transform Dialect Definition --------*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H |
| 10 | #define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H |
| 11 | |
#include "mlir/IR/Dialect.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/TypeID.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringMap.h"
#include <functional>
#include <optional>
#include <utility>
| 19 | |
| 20 | namespace mlir { |
| 21 | namespace transform { |
| 22 | |
| 23 | namespace detail { |
| 24 | /// Concrete base class for CRTP TransformDialectDataBase. Must not be used |
| 25 | /// directly. |
| 26 | class TransformDialectDataBase { |
| 27 | public: |
| 28 | virtual ~TransformDialectDataBase() = default; |
| 29 | |
| 30 | /// Returns the dynamic type ID of the subclass. |
| 31 | TypeID getTypeID() const { return typeID; } |
| 32 | |
| 33 | protected: |
| 34 | /// Must be called by the subclass with the appropriate type ID. |
| 35 | explicit TransformDialectDataBase(TypeID typeID, MLIRContext *ctx) |
| 36 | : typeID(typeID), ctx(ctx) {} |
| 37 | |
| 38 | /// Return the MLIR context. |
| 39 | MLIRContext *getContext() const { return ctx; } |
| 40 | |
| 41 | private: |
| 42 | /// The type ID of the subclass. |
| 43 | const TypeID typeID; |
| 44 | |
| 45 | /// The MLIR context. |
| 46 | MLIRContext *ctx; |
| 47 | }; |
| 48 | } // namespace detail |
| 49 | |
| 50 | /// Base class for additional data owned by the Transform dialect. Extensions |
| 51 | /// may communicate with each other using this data. The data object is |
| 52 | /// identified by the TypeID of the specific data subclass, querying the data of |
| 53 | /// the same subclass returns a reference to the same object. When a Transform |
| 54 | /// dialect extension is initialized, it can populate the data in the specific |
| 55 | /// subclass. When a Transform op is applied, it can read (but not mutate) the |
| 56 | /// data in the specific subclass, including the data provided by other |
| 57 | /// extensions. |
| 58 | /// |
| 59 | /// This follows CRTP: derived classes must list themselves as template |
| 60 | /// argument. |
| 61 | template <typename DerivedTy> |
| 62 | class TransformDialectData : public detail::TransformDialectDataBase { |
| 63 | protected: |
| 64 | /// Forward the TypeID of the derived class to the base. |
| 65 | TransformDialectData(MLIRContext *ctx) |
| 66 | : TransformDialectDataBase(TypeID::get<DerivedTy>(), ctx) {} |
| 67 | }; |
| 68 | |
// The checks below are only declared when assertions are enabled (NDEBUG is
// not defined); their call sites in this header are guarded the same way.
#ifndef NDEBUG
namespace detail {
/// Asserts that the operation with the given name implements the
/// TransformOpInterface and MemoryEffectsOpInterface. This must be a dynamic
/// assertion since interface implementations may be registered at runtime.
void checkImplementsTransformOpInterface(StringRef name, MLIRContext *context);

/// Asserts that the type identified by the given type ID implements the
/// TransformHandleTypeInterface. This must be a dynamic assertion since
/// interface implementations may be registered at runtime.
void checkImplementsTransformHandleTypeInterface(TypeID typeID,
                                                 MLIRContext *context);
} // namespace detail
#endif // NDEBUG
| 83 | } // namespace transform |
| 84 | } // namespace mlir |
| 85 | |
| 86 | #include "mlir/Dialect/Transform/IR/TransformDialect.h.inc" |
| 87 | |
| 88 | namespace mlir { |
| 89 | namespace transform { |
| 90 | |
| 91 | /// Base class for extensions of the Transform dialect that supports injecting |
| 92 | /// operations into the Transform dialect at load time. Concrete extensions are |
| 93 | /// expected to derive this class and register operations in the constructor. |
| 94 | /// They can be registered with the DialectRegistry and automatically applied |
| 95 | /// to the Transform dialect when it is loaded. |
| 96 | /// |
| 97 | /// Derived classes are expected to define a `void init()` function in which |
| 98 | /// they can call various protected methods of the base class to register |
| 99 | /// extension operations and declare their dependencies. |
| 100 | /// |
| 101 | /// By default, the extension is configured both for construction of the |
| 102 | /// Transform IR and for its application to some payload. If only the |
| 103 | /// construction is desired, the extension can be switched to "build-only" mode |
| 104 | /// that avoids loading the dialects that are only necessary for transforming |
| 105 | /// the payload. To perform the switch, the extension must be wrapped into the |
| 106 | /// `BuildOnly` class template (see below) when it is registered, as in: |
| 107 | /// |
| 108 | /// dialectRegistry.addExtension<BuildOnly<MyTransformDialectExt>>(); |
| 109 | /// |
| 110 | /// instead of: |
| 111 | /// |
| 112 | /// dialectRegistry.addExtension<MyTransformDialectExt>(); |
| 113 | /// |
| 114 | /// Derived classes must reexport the constructor of this class or otherwise |
| 115 | /// forward its boolean argument to support this behavior. |
| 116 | template <typename DerivedTy, typename... ExtraDialects> |
| 117 | class TransformDialectExtension |
| 118 | : public DialectExtension<DerivedTy, TransformDialect, ExtraDialects...> { |
| 119 | using Initializer = std::function<void(TransformDialect *)>; |
| 120 | using DialectLoader = std::function<void(MLIRContext *)>; |
| 121 | |
| 122 | public: |
| 123 | /// Extension application hook. Actually loads the dependent dialects and |
| 124 | /// registers the additional operations. Not expected to be called directly. |
| 125 | void apply(MLIRContext *context, TransformDialect *transformDialect, |
| 126 | ExtraDialects *...) const final { |
| 127 | for (const DialectLoader &loader : dialectLoaders) |
| 128 | loader(context); |
| 129 | |
| 130 | // Only load generated dialects if the user intends to apply |
| 131 | // transformations specified by the extension. |
| 132 | if (!buildOnly) |
| 133 | for (const DialectLoader &loader : generatedDialectLoaders) |
| 134 | loader(context); |
| 135 | |
| 136 | for (const Initializer &init : initializers) |
| 137 | init(transformDialect); |
| 138 | } |
| 139 | |
| 140 | protected: |
| 141 | using Base = TransformDialectExtension<DerivedTy, ExtraDialects...>; |
| 142 | |
| 143 | /// Extension constructor. The argument indicates whether to skip generated |
| 144 | /// dialects when applying the extension. |
| 145 | explicit TransformDialectExtension(bool buildOnly = false) |
| 146 | : buildOnly(buildOnly) { |
| 147 | static_cast<DerivedTy *>(this)->init(); |
| 148 | } |
| 149 | |
| 150 | /// Registers a custom initialization step to be performed when the extension |
| 151 | /// is applied to the dialect while loading. This is discouraged in favor of |
| 152 | /// more specific calls `declareGeneratedDialect`, `addDialectDataInitializer` |
| 153 | /// etc. `Func` must be convertible to the `void (MLIRContext *)` form. It |
| 154 | /// will be called during the extension initialization and given the current |
| 155 | /// MLIR context. This may be used to attach additional interfaces that cannot |
| 156 | /// be attached elsewhere. |
| 157 | template <typename Func> |
| 158 | void addCustomInitializationStep(Func &&func) { |
| 159 | std::function<void(MLIRContext *)> initializer = func; |
| 160 | dialectLoaders.push_back( |
| 161 | [init = std::move(initializer)](MLIRContext *ctx) { init(ctx); }); |
| 162 | } |
| 163 | |
| 164 | /// Registers the given function as one of the initializers for the |
| 165 | /// dialect-owned data of the kind specified as template argument. The |
| 166 | /// function must be convertible to the `void (DataTy &)` form. It will be |
| 167 | /// called during the extension initialization and will be given a mutable |
| 168 | /// reference to `DataTy`. The callback is expected to append data to the |
| 169 | /// given storage, and is not allowed to remove or destructively mutate the |
| 170 | /// existing data. The order in which callbacks from different extensions are |
| 171 | /// executed is unspecified so the callbacks may not rely on data being |
| 172 | /// already present. `DataTy` must be a class deriving `TransformDialectData`. |
| 173 | template <typename DataTy, typename Func> |
| 174 | void addDialectDataInitializer(Func &&func) { |
| 175 | static_assert(std::is_base_of_v<detail::TransformDialectDataBase, DataTy>, |
| 176 | "only classes deriving TransformDialectData are accepted" ); |
| 177 | |
| 178 | std::function<void(DataTy &)> initializer = func; |
| 179 | initializers.push_back( |
| 180 | [init = std::move(initializer)](TransformDialect *transformDialect) { |
| 181 | init(transformDialect->getOrCreateExtraData<DataTy>()); |
| 182 | }); |
| 183 | } |
| 184 | |
| 185 | /// Hook for derived classes to inject constructor behavior. |
| 186 | void init() {} |
| 187 | |
| 188 | /// Injects the operations into the Transform dialect. The operations must |
| 189 | /// implement the TransformOpInterface and MemoryEffectsOpInterface, and the |
| 190 | /// implementations must be already available when the operation is injected. |
| 191 | template <typename... OpTys> |
| 192 | void registerTransformOps() { |
| 193 | initializers.push_back([](TransformDialect *transformDialect) { |
| 194 | transformDialect->addOperationsChecked<OpTys...>(); |
| 195 | }); |
| 196 | } |
| 197 | |
| 198 | /// Injects the types into the Transform dialect. The types must implement |
| 199 | /// the TransformHandleTypeInterface and the implementation must be already |
| 200 | /// available when the type is injected. Furthermore, the types must provide |
| 201 | /// a `getMnemonic` static method returning an object convertible to |
| 202 | /// `StringRef` that is unique across all injected types. |
| 203 | template <typename... TypeTys> |
| 204 | void registerTypes() { |
| 205 | initializers.push_back([](TransformDialect *transformDialect) { |
| 206 | transformDialect->addTypesChecked<TypeTys...>(); |
| 207 | }); |
| 208 | } |
| 209 | |
| 210 | /// Declares that this Transform dialect extension depends on the dialect |
| 211 | /// provided as template parameter. When the Transform dialect is loaded, |
| 212 | /// dependent dialects will be loaded as well. This is intended for dialects |
| 213 | /// that contain attributes and types used in creation and canonicalization of |
| 214 | /// the injected operations, similarly to how the dialect definition may list |
| 215 | /// dependent dialects. This is *not* intended for dialects entities from |
| 216 | /// which may be produced when applying the transformations specified by ops |
| 217 | /// registered by this extension. |
| 218 | template <typename DialectTy> |
| 219 | void declareDependentDialect() { |
| 220 | dialectLoaders.push_back( |
| 221 | [](MLIRContext *context) { context->loadDialect<DialectTy>(); }); |
| 222 | } |
| 223 | |
| 224 | /// Declares that the transformations associated with the operations |
| 225 | /// registered by this dialect extension may produce operations from the |
| 226 | /// dialect provided as template parameter while processing payload IR that |
| 227 | /// does not contain the operations from said dialect. This is similar to |
| 228 | /// dependent dialects of a pass. These dialects will be loaded along with the |
| 229 | /// transform dialect unless the extension is in the build-only mode. |
| 230 | template <typename DialectTy> |
| 231 | void declareGeneratedDialect() { |
| 232 | generatedDialectLoaders.push_back( |
| 233 | [](MLIRContext *context) { context->loadDialect<DialectTy>(); }); |
| 234 | } |
| 235 | |
| 236 | private: |
| 237 | /// Callbacks performing extension initialization, e.g., registering ops, |
| 238 | /// types and defining the additional data. |
| 239 | SmallVector<Initializer> initializers; |
| 240 | |
| 241 | /// Callbacks loading the dependent dialects, i.e. the dialect needed for the |
| 242 | /// extension ops. |
| 243 | SmallVector<DialectLoader> dialectLoaders; |
| 244 | |
| 245 | /// Callbacks loading the generated dialects, i.e. the dialects produced when |
| 246 | /// applying the transformations. |
| 247 | SmallVector<DialectLoader> generatedDialectLoaders; |
| 248 | |
| 249 | /// Indicates that the extension is in build-only mode. |
| 250 | bool buildOnly; |
| 251 | }; |
| 252 | |
| 253 | template <typename OpTy> |
| 254 | void TransformDialect::addOperationIfNotRegistered() { |
| 255 | std::optional<RegisteredOperationName> opName = |
| 256 | RegisteredOperationName::lookup(TypeID::get<OpTy>(), getContext()); |
| 257 | if (!opName) { |
| 258 | addOperations<OpTy>(); |
| 259 | #ifndef NDEBUG |
| 260 | StringRef name = OpTy::getOperationName(); |
| 261 | detail::checkImplementsTransformOpInterface(name, getContext()); |
| 262 | #endif // NDEBUG |
| 263 | return; |
| 264 | } |
| 265 | |
| 266 | if (LLVM_LIKELY(opName->getTypeID() == TypeID::get<OpTy>())) |
| 267 | return; |
| 268 | |
| 269 | reportDuplicateOpRegistration(OpTy::getOperationName()); |
| 270 | } |
| 271 | |
| 272 | template <typename Type> |
| 273 | void TransformDialect::addTypeIfNotRegistered() { |
| 274 | // Use the address of the parse method as a proxy for identifying whether we |
| 275 | // are registering the same type class for the same mnemonic. |
| 276 | StringRef mnemonic = Type::getMnemonic(); |
| 277 | auto [it, inserted] = typeParsingHooks.try_emplace(mnemonic, Type::parse); |
| 278 | if (!inserted) { |
| 279 | const ExtensionTypeParsingHook &parsingHook = it->getValue(); |
| 280 | if (parsingHook != &Type::parse) |
| 281 | reportDuplicateTypeRegistration(mnemonic); |
| 282 | else |
| 283 | return; |
| 284 | } |
| 285 | typePrintingHooks.try_emplace( |
| 286 | TypeID::get<Type>(), +[](mlir::Type type, AsmPrinter &printer) { |
| 287 | printer << Type::getMnemonic(); |
| 288 | cast<Type>(type).print(printer); |
| 289 | }); |
| 290 | addTypes<Type>(); |
| 291 | |
| 292 | #ifndef NDEBUG |
| 293 | detail::checkImplementsTransformHandleTypeInterface(TypeID::get<Type>(), |
| 294 | getContext()); |
| 295 | #endif // NDEBUG |
| 296 | } |
| 297 | |
| 298 | template <typename DataTy> |
| 299 | DataTy &TransformDialect::getOrCreateExtraData() { |
| 300 | TypeID typeID = TypeID::get<DataTy>(); |
| 301 | auto [it, inserted] = extraData.try_emplace(typeID); |
| 302 | if (inserted) |
| 303 | it->getSecond() = std::make_unique<DataTy>(getContext()); |
| 304 | return static_cast<DataTy &>(*it->getSecond()); |
| 305 | } |
| 306 | |
/// Wrapper around a Transform dialect extension class that always constructs
/// the wrapped extension in the build-only mode. The wrapped class must
/// provide a constructor accepting the build-only flag as a boolean.
template <typename ExtensionTy>
class BuildOnly : public ExtensionTy {
public:
  /// Constructs the wrapped extension with the build-only flag set.
  BuildOnly() : ExtensionTy(/*buildOnly=*/true) {}
};
| 314 | |
| 315 | } // namespace transform |
| 316 | } // namespace mlir |
| 317 | |
| 318 | #endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H |
| 319 | |