| 1 | //===-- MyExtension.cpp - Transform dialect tutorial ----------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
// This file defines Transform dialect extension operations used in
// Chapter 3 of the Transform dialect tutorial.
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "MyExtension.h" |
| 15 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 16 | #include "mlir/Dialect/SCF/IR/SCF.h" |
| 17 | #include "mlir/Dialect/Transform/IR/TransformDialect.h" |
| 18 | #include "mlir/Dialect/Transform/IR/TransformTypes.h" |
| 19 | #include "mlir/IR/DialectImplementation.h" |
| 20 | #include "mlir/Interfaces/CallInterfaces.h" |
| 21 | #include "llvm/ADT/TypeSwitch.h" |
| 22 | |
| 23 | #define GET_TYPEDEF_CLASSES |
| 24 | #include "MyExtensionTypes.cpp.inc" |
| 25 | |
| 26 | #define GET_OP_CLASSES |
| 27 | #include "MyExtension.cpp.inc" |
| 28 | |
| 29 | //===---------------------------------------------------------------------===// |
| 30 | // MyExtension |
| 31 | //===---------------------------------------------------------------------===// |
| 32 | |
| 33 | // Define a new transform dialect extension. This uses the CRTP idiom to |
| 34 | // identify extensions. |
| 35 | class MyExtension |
| 36 | : public ::mlir::transform::TransformDialectExtension<MyExtension> { |
| 37 | public: |
| 38 | // The TypeID of this extension. |
| 39 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MyExtension) |
| 40 | |
| 41 | // The extension must derive the base constructor. |
| 42 | using Base::Base; |
| 43 | |
| 44 | // This function initializes the extension, similarly to `initialize` in |
| 45 | // dialect definitions. List individual operations and dependent dialects |
| 46 | // here. |
| 47 | void init(); |
| 48 | }; |
| 49 | |
| 50 | void MyExtension::init() { |
| 51 | // Similarly to dialects, an extension can declare a dependent dialect. This |
| 52 | // dialect will be loaded along with the extension and, therefore, along with |
| 53 | // the Transform dialect. Only declare as dependent the dialects that contain |
| 54 | // the attributes or types used by transform operations. Do NOT declare as |
| 55 | // dependent the dialects produced during the transformation. |
| 56 | // declareDependentDialect<MyDialect>(); |
| 57 | |
| 58 | // When transformations are applied, they may produce new operations from |
| 59 | // previously unloaded dialects. Typically, a pass would need to declare |
| 60 | // itself dependent on the dialects containing such new operations. To avoid |
| 61 | // confusion with the dialects the extension itself depends on, the Transform |
| 62 | // dialects differentiates between: |
| 63 | // - dependent dialects, which are used by the transform operations, and |
| 64 | // - generated dialects, which contain the entities (attributes, operations, |
| 65 | // types) that may be produced by applying the transformation even when |
| 66 | // not present in the original payload IR. |
| 67 | // In the following chapter, we will be add operations that generate function |
| 68 | // calls and structured control flow operations, so let's declare the |
| 69 | // corresponding dialects as generated. |
| 70 | declareGeneratedDialect<::mlir::scf::SCFDialect>(); |
| 71 | declareGeneratedDialect<::mlir::func::FuncDialect>(); |
| 72 | |
| 73 | // Register the additional transform dialect types with the dialect. List all |
| 74 | // types generated from ODS. |
| 75 | registerTypes< |
| 76 | #define GET_TYPEDEF_LIST |
| 77 | #include "MyExtensionTypes.cpp.inc" |
| 78 | >(); |
| 79 | |
| 80 | // ODS generates these helpers for type printing and parsing, but the |
| 81 | // Transform dialect provides its own support for types supplied by the |
| 82 | // extension. Reference these functions to avoid a compiler warning. |
| 83 | (void)&generatedTypeParser; |
| 84 | (void)&generatedTypePrinter; |
| 85 | |
| 86 | // Finally, we register the additional transform operations with the dialect. |
| 87 | // List all operations generated from ODS. This call will perform additional |
| 88 | // checks that the operations implement the transform and memory effect |
| 89 | // interfaces required by the dialect interpreter and assert if they do not. |
| 90 | registerTransformOps< |
| 91 | #define GET_OP_LIST |
| 92 | #include "MyExtension.cpp.inc" |
| 93 | >(); |
| 94 | } |
| 95 | |
| 96 | //===---------------------------------------------------------------------===// |
| 97 | // ChangeCallTargetOp |
| 98 | //===---------------------------------------------------------------------===// |
| 99 | |
| 100 | static void updateCallee(mlir::func::CallOp call, llvm::StringRef newTarget) { |
| 101 | call.setCallee(newTarget); |
| 102 | } |
| 103 | |
| 104 | // Implementation of our transform dialect operation. |
| 105 | // This operation returns a tri-state result that can be one of: |
| 106 | // - success when the transformation succeeded; |
| 107 | // - definite failure when the transformation failed in such a way that |
| 108 | // following |
| 109 | // transformations are impossible or undesirable, typically it could have left |
| 110 | // payload IR in an invalid state; it is expected that a diagnostic is emitted |
| 111 | // immediately before returning the definite error; |
| 112 | // - silenceable failure when the transformation failed but following |
| 113 | // transformations |
| 114 | // are still applicable, typically this means a precondition for the |
| 115 | // transformation is not satisfied and the payload IR has not been modified. |
| 116 | // The silenceable failure additionally carries a Diagnostic that can be emitted |
| 117 | // to the user. |
| 118 | ::mlir::DiagnosedSilenceableFailure |
| 119 | mlir::transform::ChangeCallTargetOp::applyToOne( |
| 120 | // The rewriter that should be used when modifying IR. |
| 121 | ::mlir::transform::TransformRewriter &rewriter, |
| 122 | // The single payload operation to which the transformation is applied. |
| 123 | ::mlir::func::CallOp call, |
| 124 | // The payload IR entities that will be appended to lists associated with |
| 125 | // the results of this transform operation. This list contains one entry per |
| 126 | // result. |
| 127 | ::mlir::transform::ApplyToEachResultList &results, |
| 128 | // The transform application state. This object can be used to query the |
| 129 | // current associations between transform IR values and payload IR entities. |
| 130 | // It can also carry additional user-defined state. |
| 131 | ::mlir::transform::TransformState &state) { |
| 132 | |
| 133 | // Dispatch to the actual transformation. |
| 134 | updateCallee(call, getNewTarget()); |
| 135 | |
| 136 | // If everything went well, return success. |
| 137 | return DiagnosedSilenceableFailure::success(); |
| 138 | } |
| 139 | |
| 140 | void mlir::transform::ChangeCallTargetOp::getEffects( |
| 141 | ::llvm::SmallVectorImpl<::mlir::MemoryEffects::EffectInstance> &effects) { |
| 142 | // Indicate that the `call` handle is only read by this operation because the |
| 143 | // associated operation is not erased but rather modified in-place, so the |
| 144 | // reference to it remains valid. |
| 145 | onlyReadsHandle(getCallMutable(), effects); |
| 146 | |
| 147 | // Indicate that the payload is modified by this operation. |
| 148 | modifiesPayload(effects); |
| 149 | } |
| 150 | |
| 151 | //===---------------------------------------------------------------------===// |
| 152 | // CallToOp |
| 153 | //===---------------------------------------------------------------------===// |
| 154 | |
| 155 | static mlir::Operation *replaceCallWithOp(mlir::RewriterBase &rewriter, |
| 156 | mlir::CallOpInterface call) { |
| 157 | // Construct an operation from an unregistered dialect. This is discouraged |
| 158 | // and is only used here for brevity of the overall example. |
| 159 | mlir::OperationState state(call.getLoc(), "my.mm4" ); |
| 160 | state.types.assign(call->result_type_begin(), call->result_type_end()); |
| 161 | state.operands.assign(call->operand_begin(), call->operand_end()); |
| 162 | |
| 163 | mlir::Operation *replacement = rewriter.create(state); |
| 164 | rewriter.replaceOp(call, replacement->getResults()); |
| 165 | return replacement; |
| 166 | } |
| 167 | |
| 168 | // See above for the signature description. |
| 169 | mlir::DiagnosedSilenceableFailure mlir::transform::CallToOp::applyToOne( |
| 170 | mlir::transform::TransformRewriter &rewriter, mlir::CallOpInterface call, |
| 171 | mlir::transform::ApplyToEachResultList &results, |
| 172 | mlir::transform::TransformState &state) { |
| 173 | |
| 174 | // Dispatch to the actual transformation. |
| 175 | Operation *replacement = replaceCallWithOp(rewriter, call); |
| 176 | |
| 177 | // Associate the payload operation produced by the rewrite with the result |
| 178 | // handle of this transform operation. |
| 179 | results.push_back(replacement); |
| 180 | |
| 181 | // If everything went well, return success. |
| 182 | return DiagnosedSilenceableFailure::success(); |
| 183 | } |
| 184 | |
| 185 | //===---------------------------------------------------------------------===// |
| 186 | // CallOpInterfaceHandleType |
| 187 | //===---------------------------------------------------------------------===// |
| 188 | |
| 189 | // The interface declares this method to verify constraints this type has on |
| 190 | // payload operations. It returns the now familiar tri-state result. |
| 191 | mlir::DiagnosedSilenceableFailure |
| 192 | mlir::transform::CallOpInterfaceHandleType::checkPayload( |
| 193 | // Location at which diagnostics should be emitted. |
| 194 | mlir::Location loc, |
| 195 | // List of payload operations that are about to be associated with the |
| 196 | // handle that has this type. |
| 197 | llvm::ArrayRef<mlir::Operation *> payload) const { |
| 198 | |
| 199 | // All payload operations are expected to implement CallOpInterface, check |
| 200 | // this. |
| 201 | for (Operation *op : payload) { |
| 202 | if (llvm::isa<mlir::CallOpInterface>(op)) |
| 203 | continue; |
| 204 | |
| 205 | // By convention, these verifiers always emit a silenceable failure since |
| 206 | // they are checking a precondition. |
| 207 | DiagnosedSilenceableFailure diag = |
| 208 | emitSilenceableError(loc) |
| 209 | << "expected the payload operation to implement CallOpInterface" ; |
| 210 | diag.attachNote(op->getLoc()) << "offending operation" ; |
| 211 | return diag; |
| 212 | } |
| 213 | |
| 214 | // If everything is okay, return success. |
| 215 | return DiagnosedSilenceableFailure::success(); |
| 216 | } |
| 217 | |
| 218 | //===---------------------------------------------------------------------===// |
| 219 | // Extension registration |
| 220 | //===---------------------------------------------------------------------===// |
| 221 | |
| 222 | void registerMyExtension(::mlir::DialectRegistry ®istry) { |
| 223 | registry.addExtensions<MyExtension>(); |
| 224 | } |
| 225 | |