1//===-- MyExtension.cpp - Transform dialect tutorial ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file defines Transform dialect extension operations used in
// Chapter 3 of the Transform dialect tutorial.
11//
12//===----------------------------------------------------------------------===//
13
14#include "MyExtension.h"
15#include "mlir/Dialect/Func/IR/FuncOps.h"
16#include "mlir/Dialect/SCF/IR/SCF.h"
17#include "mlir/Dialect/Transform/IR/TransformDialect.h"
18#include "mlir/Dialect/Transform/IR/TransformTypes.h"
19#include "mlir/IR/DialectImplementation.h"
20#include "mlir/Interfaces/CallInterfaces.h"
21#include "llvm/ADT/TypeSwitch.h"
22
23#define GET_TYPEDEF_CLASSES
24#include "MyExtensionTypes.cpp.inc"
25
26#define GET_OP_CLASSES
27#include "MyExtension.cpp.inc"
28
29//===---------------------------------------------------------------------===//
30// MyExtension
31//===---------------------------------------------------------------------===//
32
33// Define a new transform dialect extension. This uses the CRTP idiom to
34// identify extensions.
35class MyExtension
36 : public ::mlir::transform::TransformDialectExtension<MyExtension> {
37public:
38 // The TypeID of this extension.
39 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MyExtension)
40
41 // The extension must derive the base constructor.
42 using Base::Base;
43
44 // This function initializes the extension, similarly to `initialize` in
45 // dialect definitions. List individual operations and dependent dialects
46 // here.
47 void init();
48};
49
// Populates the extension: declares the dialects whose entities may be
// generated by its transformations, then registers the transform types and
// operations produced by ODS.
void MyExtension::init() {
  // Similarly to dialects, an extension can declare a dependent dialect. This
  // dialect will be loaded along with the extension and, therefore, along with
  // the Transform dialect. Only declare as dependent the dialects that contain
  // the attributes or types used by transform operations. Do NOT declare as
  // dependent the dialects produced during the transformation.
  // declareDependentDialect<MyDialect>();

  // When transformations are applied, they may produce new operations from
  // previously unloaded dialects. Typically, a pass would need to declare
  // itself dependent on the dialects containing such new operations. To avoid
  // confusion with the dialects the extension itself depends on, the Transform
  // dialect differentiates between:
  //   - dependent dialects, which are used by the transform operations, and
  //   - generated dialects, which contain the entities (attributes, operations,
  //     types) that may be produced by applying the transformation even when
  //     not present in the original payload IR.
  // In the following chapter, we will add operations that generate function
  // calls and structured control flow operations, so let's declare the
  // corresponding dialects as generated.
  declareGeneratedDialect<::mlir::scf::SCFDialect>();
  declareGeneratedDialect<::mlir::func::FuncDialect>();

  // Register the additional transform dialect types with the dialect. List all
  // types generated from ODS. The list is spliced in by the preprocessor, so
  // the directives below must stay inside the template argument list.
  registerTypes<
#define GET_TYPEDEF_LIST
#include "MyExtensionTypes.cpp.inc"
      >();

  // ODS generates these helpers for type printing and parsing, but the
  // Transform dialect provides its own support for types supplied by the
  // extension. Reference these functions to avoid a compiler warning.
  (void)&generatedTypeParser;
  (void)&generatedTypePrinter;

  // Finally, we register the additional transform operations with the dialect.
  // List all operations generated from ODS. This call will perform additional
  // checks that the operations implement the transform and memory effect
  // interfaces required by the dialect interpreter and assert if they do not.
  registerTransformOps<
#define GET_OP_LIST
#include "MyExtension.cpp.inc"
      >();
}
95
96//===---------------------------------------------------------------------===//
97// ChangeCallTargetOp
98//===---------------------------------------------------------------------===//
99
100static void updateCallee(mlir::func::CallOp call, llvm::StringRef newTarget) {
101 call.setCallee(newTarget);
102}
103
104// Implementation of our transform dialect operation.
105// This operation returns a tri-state result that can be one of:
106// - success when the transformation succeeded;
107// - definite failure when the transformation failed in such a way that
108// following
109// transformations are impossible or undesirable, typically it could have left
110// payload IR in an invalid state; it is expected that a diagnostic is emitted
111// immediately before returning the definite error;
112// - silenceable failure when the transformation failed but following
113// transformations
114// are still applicable, typically this means a precondition for the
115// transformation is not satisfied and the payload IR has not been modified.
116// The silenceable failure additionally carries a Diagnostic that can be emitted
117// to the user.
118::mlir::DiagnosedSilenceableFailure
119mlir::transform::ChangeCallTargetOp::applyToOne(
120 // The rewriter that should be used when modifying IR.
121 ::mlir::transform::TransformRewriter &rewriter,
122 // The single payload operation to which the transformation is applied.
123 ::mlir::func::CallOp call,
124 // The payload IR entities that will be appended to lists associated with
125 // the results of this transform operation. This list contains one entry per
126 // result.
127 ::mlir::transform::ApplyToEachResultList &results,
128 // The transform application state. This object can be used to query the
129 // current associations between transform IR values and payload IR entities.
130 // It can also carry additional user-defined state.
131 ::mlir::transform::TransformState &state) {
132
133 // Dispatch to the actual transformation.
134 updateCallee(call, getNewTarget());
135
136 // If everything went well, return success.
137 return DiagnosedSilenceableFailure::success();
138}
139
140void mlir::transform::ChangeCallTargetOp::getEffects(
141 ::llvm::SmallVectorImpl<::mlir::MemoryEffects::EffectInstance> &effects) {
142 // Indicate that the `call` handle is only read by this operation because the
143 // associated operation is not erased but rather modified in-place, so the
144 // reference to it remains valid.
145 onlyReadsHandle(getCallMutable(), effects);
146
147 // Indicate that the payload is modified by this operation.
148 modifiesPayload(effects);
149}
150
151//===---------------------------------------------------------------------===//
152// CallToOp
153//===---------------------------------------------------------------------===//
154
155static mlir::Operation *replaceCallWithOp(mlir::RewriterBase &rewriter,
156 mlir::CallOpInterface call) {
157 // Construct an operation from an unregistered dialect. This is discouraged
158 // and is only used here for brevity of the overall example.
159 mlir::OperationState state(call.getLoc(), "my.mm4");
160 state.types.assign(call->result_type_begin(), call->result_type_end());
161 state.operands.assign(call->operand_begin(), call->operand_end());
162
163 mlir::Operation *replacement = rewriter.create(state);
164 rewriter.replaceOp(call, replacement->getResults());
165 return replacement;
166}
167
168// See above for the signature description.
169mlir::DiagnosedSilenceableFailure mlir::transform::CallToOp::applyToOne(
170 mlir::transform::TransformRewriter &rewriter, mlir::CallOpInterface call,
171 mlir::transform::ApplyToEachResultList &results,
172 mlir::transform::TransformState &state) {
173
174 // Dispatch to the actual transformation.
175 Operation *replacement = replaceCallWithOp(rewriter, call);
176
177 // Associate the payload operation produced by the rewrite with the result
178 // handle of this transform operation.
179 results.push_back(replacement);
180
181 // If everything went well, return success.
182 return DiagnosedSilenceableFailure::success();
183}
184
185//===---------------------------------------------------------------------===//
186// CallOpInterfaceHandleType
187//===---------------------------------------------------------------------===//
188
189// The interface declares this method to verify constraints this type has on
190// payload operations. It returns the now familiar tri-state result.
191mlir::DiagnosedSilenceableFailure
192mlir::transform::CallOpInterfaceHandleType::checkPayload(
193 // Location at which diagnostics should be emitted.
194 mlir::Location loc,
195 // List of payload operations that are about to be associated with the
196 // handle that has this type.
197 llvm::ArrayRef<mlir::Operation *> payload) const {
198
199 // All payload operations are expected to implement CallOpInterface, check
200 // this.
201 for (Operation *op : payload) {
202 if (llvm::isa<mlir::CallOpInterface>(op))
203 continue;
204
205 // By convention, these verifiers always emit a silenceable failure since
206 // they are checking a precondition.
207 DiagnosedSilenceableFailure diag =
208 emitSilenceableError(loc)
209 << "expected the payload operation to implement CallOpInterface";
210 diag.attachNote(op->getLoc()) << "offending operation";
211 return diag;
212 }
213
214 // If everything is okay, return success.
215 return DiagnosedSilenceableFailure::success();
216}
217
218//===---------------------------------------------------------------------===//
219// Extension registration
220//===---------------------------------------------------------------------===//
221
222void registerMyExtension(::mlir::DialectRegistry &registry) {
223 registry.addExtensions<MyExtension>();
224}
225

source code of mlir/examples/transform/Ch3/lib/MyExtension.cpp