1 | //===-- MyExtension.cpp - Transform dialect tutorial ----------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This file defines the Transform dialect extension operations used in
// Chapter 3 of the Transform dialect tutorial.
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "MyExtension.h" |
15 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
16 | #include "mlir/Dialect/SCF/IR/SCF.h" |
17 | #include "mlir/Dialect/Transform/IR/TransformDialect.h" |
18 | #include "mlir/Dialect/Transform/IR/TransformTypes.h" |
19 | #include "mlir/IR/DialectImplementation.h" |
20 | #include "mlir/Interfaces/CallInterfaces.h" |
21 | #include "llvm/ADT/TypeSwitch.h" |
22 | |
23 | #define GET_TYPEDEF_CLASSES |
24 | #include "MyExtensionTypes.cpp.inc" |
25 | |
26 | #define GET_OP_CLASSES |
27 | #include "MyExtension.cpp.inc" |
28 | |
29 | //===---------------------------------------------------------------------===// |
30 | // MyExtension |
31 | //===---------------------------------------------------------------------===// |
32 | |
33 | // Define a new transform dialect extension. This uses the CRTP idiom to |
34 | // identify extensions. |
35 | class MyExtension |
36 | : public ::mlir::transform::TransformDialectExtension<MyExtension> { |
37 | public: |
38 | // The extension must derive the base constructor. |
39 | using Base::Base; |
40 | |
41 | // This function initializes the extension, similarly to `initialize` in |
42 | // dialect definitions. List individual operations and dependent dialects |
43 | // here. |
44 | void init(); |
45 | }; |
46 | |
47 | void MyExtension::init() { |
48 | // Similarly to dialects, an extension can declare a dependent dialect. This |
49 | // dialect will be loaded along with the extension and, therefore, along with |
50 | // the Transform dialect. Only declare as dependent the dialects that contain |
51 | // the attributes or types used by transform operations. Do NOT declare as |
52 | // dependent the dialects produced during the transformation. |
53 | // declareDependentDialect<MyDialect>(); |
54 | |
55 | // When transformations are applied, they may produce new operations from |
56 | // previously unloaded dialects. Typically, a pass would need to declare |
57 | // itself dependent on the dialects containing such new operations. To avoid |
58 | // confusion with the dialects the extension itself depends on, the Transform |
59 | // dialects differentiates between: |
60 | // - dependent dialects, which are used by the transform operations, and |
61 | // - generated dialects, which contain the entities (attributes, operations, |
62 | // types) that may be produced by applying the transformation even when |
63 | // not present in the original payload IR. |
64 | // In the following chapter, we will be add operations that generate function |
65 | // calls and structured control flow operations, so let's declare the |
66 | // corresponding dialects as generated. |
67 | declareGeneratedDialect<::mlir::scf::SCFDialect>(); |
68 | declareGeneratedDialect<::mlir::func::FuncDialect>(); |
69 | |
70 | // Register the additional transform dialect types with the dialect. List all |
71 | // types generated from ODS. |
72 | registerTypes< |
73 | #define GET_TYPEDEF_LIST |
74 | #include "MyExtensionTypes.cpp.inc" |
75 | >(); |
76 | |
77 | // ODS generates these helpers for type printing and parsing, but the |
78 | // Transform dialect provides its own support for types supplied by the |
79 | // extension. Reference these functions to avoid a compiler warning. |
80 | (void)&generatedTypeParser; |
81 | (void)&generatedTypePrinter; |
82 | |
83 | // Finally, we register the additional transform operations with the dialect. |
84 | // List all operations generated from ODS. This call will perform additional |
85 | // checks that the operations implement the transform and memory effect |
86 | // interfaces required by the dialect interpreter and assert if they do not. |
87 | registerTransformOps< |
88 | #define GET_OP_LIST |
89 | #include "MyExtension.cpp.inc" |
90 | >(); |
91 | } |
92 | |
93 | //===---------------------------------------------------------------------===// |
94 | // ChangeCallTargetOp |
95 | //===---------------------------------------------------------------------===// |
96 | |
97 | static void updateCallee(mlir::func::CallOp call, llvm::StringRef newTarget) { |
98 | call.setCallee(newTarget); |
99 | } |
100 | |
101 | // Implementation of our transform dialect operation. |
102 | // This operation returns a tri-state result that can be one of: |
103 | // - success when the transformation succeeded; |
104 | // - definite failure when the transformation failed in such a way that |
105 | // following |
106 | // transformations are impossible or undesirable, typically it could have left |
107 | // payload IR in an invalid state; it is expected that a diagnostic is emitted |
108 | // immediately before returning the definite error; |
109 | // - silenceable failure when the transformation failed but following |
110 | // transformations |
111 | // are still applicable, typically this means a precondition for the |
112 | // transformation is not satisfied and the payload IR has not been modified. |
113 | // The silenceable failure additionally carries a Diagnostic that can be emitted |
114 | // to the user. |
115 | ::mlir::DiagnosedSilenceableFailure |
116 | mlir::transform::ChangeCallTargetOp::applyToOne( |
117 | // The rewriter that should be used when modifying IR. |
118 | ::mlir::transform::TransformRewriter &rewriter, |
119 | // The single payload operation to which the transformation is applied. |
120 | ::mlir::func::CallOp call, |
121 | // The payload IR entities that will be appended to lists associated with |
122 | // the results of this transform operation. This list contains one entry per |
123 | // result. |
124 | ::mlir::transform::ApplyToEachResultList &results, |
125 | // The transform application state. This object can be used to query the |
126 | // current associations between transform IR values and payload IR entities. |
127 | // It can also carry additional user-defined state. |
128 | ::mlir::transform::TransformState &state) { |
129 | |
130 | // Dispatch to the actual transformation. |
131 | updateCallee(call, getNewTarget()); |
132 | |
133 | // If everything went well, return success. |
134 | return DiagnosedSilenceableFailure::success(); |
135 | } |
136 | |
137 | void mlir::transform::ChangeCallTargetOp::getEffects( |
138 | ::llvm::SmallVectorImpl<::mlir::MemoryEffects::EffectInstance> &effects) { |
139 | // Indicate that the `call` handle is only read by this operation because the |
140 | // associated operation is not erased but rather modified in-place, so the |
141 | // reference to it remains valid. |
142 | onlyReadsHandle(getCall(), effects); |
143 | |
144 | // Indicate that the payload is modified by this operation. |
145 | modifiesPayload(effects); |
146 | } |
147 | |
148 | //===---------------------------------------------------------------------===// |
149 | // CallToOp |
150 | //===---------------------------------------------------------------------===// |
151 | |
152 | static mlir::Operation *replaceCallWithOp(mlir::RewriterBase &rewriter, |
153 | mlir::CallOpInterface call) { |
154 | // Construct an operation from an unregistered dialect. This is discouraged |
155 | // and is only used here for brevity of the overall example. |
156 | mlir::OperationState state(call.getLoc(), "my.mm4" ); |
157 | state.types.assign(call->result_type_begin(), call->result_type_end()); |
158 | state.operands.assign(call->operand_begin(), call->operand_end()); |
159 | |
160 | mlir::Operation *replacement = rewriter.create(state); |
161 | rewriter.replaceOp(call, replacement->getResults()); |
162 | return replacement; |
163 | } |
164 | |
165 | // See above for the signature description. |
166 | mlir::DiagnosedSilenceableFailure mlir::transform::CallToOp::applyToOne( |
167 | mlir::transform::TransformRewriter &rewriter, mlir::CallOpInterface call, |
168 | mlir::transform::ApplyToEachResultList &results, |
169 | mlir::transform::TransformState &state) { |
170 | |
171 | // Dispatch to the actual transformation. |
172 | Operation *replacement = replaceCallWithOp(rewriter, call); |
173 | |
174 | // Associate the payload operation produced by the rewrite with the result |
175 | // handle of this transform operation. |
176 | results.push_back(replacement); |
177 | |
178 | // If everything went well, return success. |
179 | return DiagnosedSilenceableFailure::success(); |
180 | } |
181 | |
182 | //===---------------------------------------------------------------------===// |
183 | // CallOpInterfaceHandleType |
184 | //===---------------------------------------------------------------------===// |
185 | |
186 | // The interface declares this method to verify constraints this type has on |
187 | // payload operations. It returns the now familiar tri-state result. |
188 | mlir::DiagnosedSilenceableFailure |
189 | mlir::transform::CallOpInterfaceHandleType::checkPayload( |
190 | // Location at which diagnostics should be emitted. |
191 | mlir::Location loc, |
192 | // List of payload operations that are about to be associated with the |
193 | // handle that has this type. |
194 | llvm::ArrayRef<mlir::Operation *> payload) const { |
195 | |
196 | // All payload operations are expected to implement CallOpInterface, check |
197 | // this. |
198 | for (Operation *op : payload) { |
199 | if (llvm::isa<mlir::CallOpInterface>(op)) |
200 | continue; |
201 | |
202 | // By convention, these verifiers always emit a silenceable failure since |
203 | // they are checking a precondition. |
204 | DiagnosedSilenceableFailure diag = |
205 | emitSilenceableError(loc) |
206 | << "expected the payload operation to implement CallOpInterface" ; |
207 | diag.attachNote(op->getLoc()) << "offending operation" ; |
208 | return diag; |
209 | } |
210 | |
211 | // If everything is okay, return success. |
212 | return DiagnosedSilenceableFailure::success(); |
213 | } |
214 | |
215 | //===---------------------------------------------------------------------===// |
216 | // Extension registration |
217 | //===---------------------------------------------------------------------===// |
218 | |
219 | void registerMyExtension(::mlir::DialectRegistry ®istry) { |
220 | registry.addExtensions<MyExtension>(); |
221 | } |
222 | |