1 | //===-- MyExtension.cpp - Transform dialect tutorial ----------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines Transform dialect extension operations used in the |
10 | // Chapter 2 of the Transform dialect tutorial. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "MyExtension.h" |
15 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
16 | #include "mlir/Dialect/SCF/IR/SCF.h" |
17 | #include "mlir/Dialect/Transform/IR/TransformDialect.h" |
18 | #include "mlir/Dialect/Transform/IR/TransformTypes.h" |
19 | #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" |
20 | #include "mlir/IR/DialectRegistry.h" |
21 | #include "mlir/IR/Operation.h" |
22 | #include "mlir/Interfaces/SideEffectInterfaces.h" |
23 | #include "mlir/Support/LLVM.h" |
24 | #include "llvm/ADT/SmallVector.h" |
25 | #include "llvm/ADT/StringRef.h" |
26 | |
27 | // Define a new transform dialect extension. This uses the CRTP idiom to |
28 | // identify extensions. |
29 | class MyExtension |
30 | : public ::mlir::transform::TransformDialectExtension<MyExtension> { |
31 | public: |
32 | // The TypeID of this extension. |
33 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MyExtension) |
34 | |
35 | // The extension must derive the base constructor. |
36 | using Base::Base; |
37 | |
38 | // This function initializes the extension, similarly to `initialize` in |
39 | // dialect definitions. List individual operations and dependent dialects |
40 | // here. |
41 | void init(); |
42 | }; |
43 | |
44 | void MyExtension::init() { |
45 | // Similarly to dialects, an extension can declare a dependent dialect. This |
46 | // dialect will be loaded along with the extension and, therefore, along with |
47 | // the Transform dialect. Only declare as dependent the dialects that contain |
48 | // the attributes or types used by transform operations. Do NOT declare as |
49 | // dependent the dialects produced during the transformation. |
50 | // declareDependentDialect<MyDialect>(); |
51 | |
52 | // When transformations are applied, they may produce new operations from |
53 | // previously unloaded dialects. Typically, a pass would need to declare |
54 | // itself dependent on the dialects containing such new operations. To avoid |
55 | // confusion with the dialects the extension itself depends on, the Transform |
56 | // dialects differentiates between: |
57 | // - dependent dialects, which are used by the transform operations, and |
58 | // - generated dialects, which contain the entities (attributes, operations, |
59 | // types) that may be produced by applying the transformation even when |
60 | // not present in the original payload IR. |
61 | // In the following chapter, we will be add operations that generate function |
62 | // calls and structured control flow operations, so let's declare the |
63 | // corresponding dialects as generated. |
64 | declareGeneratedDialect<::mlir::scf::SCFDialect>(); |
65 | declareGeneratedDialect<::mlir::func::FuncDialect>(); |
66 | |
67 | // Finally, we register the additional transform operations with the dialect. |
68 | // List all operations generated from ODS. This call will perform additional |
69 | // checks that the operations implement the transform and memory effect |
70 | // interfaces required by the dialect interpreter and assert if they do not. |
71 | registerTransformOps< |
72 | #define GET_OP_LIST |
73 | #include "MyExtension.cpp.inc" |
74 | >(); |
75 | } |
76 | |
77 | #define GET_OP_CLASSES |
78 | #include "MyExtension.cpp.inc" |
79 | |
80 | static void updateCallee(mlir::func::CallOp call, llvm::StringRef newTarget) { |
81 | call.setCallee(newTarget); |
82 | } |
83 | |
84 | // Implementation of our transform dialect operation. |
85 | // This operation returns a tri-state result that can be one of: |
86 | // - success when the transformation succeeded; |
87 | // - definite failure when the transformation failed in such a way that |
88 | // following transformations are impossible or undesirable, typically it could |
89 | // have left payload IR in an invalid state; it is expected that a diagnostic |
90 | // is emitted immediately before returning the definite error; |
91 | // - silenceable failure when the transformation failed but following |
92 | // transformations are still applicable, typically this means a precondition |
93 | // for the transformation is not satisfied and the payload IR has not been |
94 | // modified. The silenceable failure additionally carries a Diagnostic that |
95 | // can be emitted to the user. |
96 | ::mlir::DiagnosedSilenceableFailure mlir::transform::ChangeCallTargetOp::apply( |
97 | // The rewriter that should be used when modifying IR. |
98 | ::mlir::transform::TransformRewriter &rewriter, |
99 | // The list of payload IR entities that will be associated with the |
100 | // transform IR values defined by this transform operation. In this case, it |
101 | // can remain empty as there are no results. |
102 | ::mlir::transform::TransformResults &results, |
103 | // The transform application state. This object can be used to query the |
104 | // current associations between transform IR values and payload IR entities. |
105 | // It can also carry additional user-defined state. |
106 | ::mlir::transform::TransformState &state) { |
107 | |
108 | // First, we need to obtain the list of payload operations that are associated |
109 | // with the operand handle. |
110 | auto payload = state.getPayloadOps(getCall()); |
111 | |
112 | // Then, we iterate over the list of operands and call the actual IR-mutating |
113 | // function. We also check the preconditions here. |
114 | for (Operation *payloadOp : payload) { |
115 | auto call = dyn_cast<::mlir::func::CallOp>(payloadOp); |
116 | if (!call) { |
117 | DiagnosedSilenceableFailure diag = |
118 | emitSilenceableError() << "only applies to func.call payloads" ; |
119 | diag.attachNote(payloadOp->getLoc()) << "offending payload" ; |
120 | return diag; |
121 | } |
122 | |
123 | updateCallee(call, getNewTarget()); |
124 | } |
125 | |
126 | // If everything went well, return success. |
127 | return DiagnosedSilenceableFailure::success(); |
128 | } |
129 | |
130 | void mlir::transform::ChangeCallTargetOp::getEffects( |
131 | ::llvm::SmallVectorImpl<::mlir::MemoryEffects::EffectInstance> &effects) { |
132 | // Indicate that the `call` handle is only read by this operation because the |
133 | // associated operation is not erased but rather modified in-place, so the |
134 | // reference to it remains valid. |
135 | onlyReadsHandle(getCallMutable(), effects); |
136 | |
137 | // Indicate that the payload is modified by this operation. |
138 | modifiesPayload(effects); |
139 | } |
140 | |
141 | void registerMyExtension(::mlir::DialectRegistry ®istry) { |
142 | registry.addExtensions<MyExtension>(); |
143 | } |
144 | |