1//===-- MyExtension.cpp - Transform dialect tutorial ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file defines Transform dialect extension operations used in
// Chapter 2 of the Transform dialect tutorial.
11//
12//===----------------------------------------------------------------------===//
13
14#include "MyExtension.h"
15#include "mlir/Dialect/Func/IR/FuncOps.h"
16#include "mlir/Dialect/SCF/IR/SCF.h"
17#include "mlir/Dialect/Transform/IR/TransformDialect.h"
18#include "mlir/Dialect/Transform/IR/TransformTypes.h"
19#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
20#include "mlir/IR/DialectRegistry.h"
21#include "mlir/IR/Operation.h"
22#include "mlir/Interfaces/SideEffectInterfaces.h"
23#include "mlir/Support/LLVM.h"
24#include "llvm/ADT/SmallVector.h"
25#include "llvm/ADT/StringRef.h"
26
27// Define a new transform dialect extension. This uses the CRTP idiom to
28// identify extensions.
29class MyExtension
30 : public ::mlir::transform::TransformDialectExtension<MyExtension> {
31public:
32 // The TypeID of this extension.
33 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MyExtension)
34
35 // The extension must derive the base constructor.
36 using Base::Base;
37
38 // This function initializes the extension, similarly to `initialize` in
39 // dialect definitions. List individual operations and dependent dialects
40 // here.
41 void init();
42};
43
// Initializes the extension: declares the dialects whose entities may be
// produced by the transformations, then registers the ODS-generated transform
// operations with the Transform dialect.
void MyExtension::init() {
  // Similarly to dialects, an extension can declare a dependent dialect. This
  // dialect will be loaded along with the extension and, therefore, along with
  // the Transform dialect. Only declare as dependent the dialects that contain
  // the attributes or types used by transform operations. Do NOT declare as
  // dependent the dialects produced during the transformation.
  // declareDependentDialect<MyDialect>();

  // When transformations are applied, they may produce new operations from
  // previously unloaded dialects. Typically, a pass would need to declare
  // itself dependent on the dialects containing such new operations. To avoid
  // confusion with the dialects the extension itself depends on, the Transform
  // dialect differentiates between:
  //   - dependent dialects, which are used by the transform operations, and
  //   - generated dialects, which contain the entities (attributes, operations,
  //     types) that may be produced by applying the transformation even when
  //     not present in the original payload IR.
  // In the following chapter, we will add operations that generate function
  // calls and structured control flow operations, so let's declare the
  // corresponding dialects as generated.
  declareGeneratedDialect<::mlir::scf::SCFDialect>();
  declareGeneratedDialect<::mlir::func::FuncDialect>();

  // Finally, we register the additional transform operations with the dialect.
  // List all operations generated from ODS. This call will perform additional
  // checks that the operations implement the transform and memory effect
  // interfaces required by the dialect interpreter and assert if they do not.
  // The operation list is expanded in place from the generated .inc file.
  registerTransformOps<
#define GET_OP_LIST
#include "MyExtension.cpp.inc"
      >();
}
76
77#define GET_OP_CLASSES
78#include "MyExtension.cpp.inc"
79
80static void updateCallee(mlir::func::CallOp call, llvm::StringRef newTarget) {
81 call.setCallee(newTarget);
82}
83
84// Implementation of our transform dialect operation.
85// This operation returns a tri-state result that can be one of:
86// - success when the transformation succeeded;
87// - definite failure when the transformation failed in such a way that
88// following transformations are impossible or undesirable, typically it could
89// have left payload IR in an invalid state; it is expected that a diagnostic
90// is emitted immediately before returning the definite error;
91// - silenceable failure when the transformation failed but following
92// transformations are still applicable, typically this means a precondition
93// for the transformation is not satisfied and the payload IR has not been
94// modified. The silenceable failure additionally carries a Diagnostic that
95// can be emitted to the user.
96::mlir::DiagnosedSilenceableFailure mlir::transform::ChangeCallTargetOp::apply(
97 // The rewriter that should be used when modifying IR.
98 ::mlir::transform::TransformRewriter &rewriter,
99 // The list of payload IR entities that will be associated with the
100 // transform IR values defined by this transform operation. In this case, it
101 // can remain empty as there are no results.
102 ::mlir::transform::TransformResults &results,
103 // The transform application state. This object can be used to query the
104 // current associations between transform IR values and payload IR entities.
105 // It can also carry additional user-defined state.
106 ::mlir::transform::TransformState &state) {
107
108 // First, we need to obtain the list of payload operations that are associated
109 // with the operand handle.
110 auto payload = state.getPayloadOps(getCall());
111
112 // Then, we iterate over the list of operands and call the actual IR-mutating
113 // function. We also check the preconditions here.
114 for (Operation *payloadOp : payload) {
115 auto call = dyn_cast<::mlir::func::CallOp>(payloadOp);
116 if (!call) {
117 DiagnosedSilenceableFailure diag =
118 emitSilenceableError() << "only applies to func.call payloads";
119 diag.attachNote(payloadOp->getLoc()) << "offending payload";
120 return diag;
121 }
122
123 updateCallee(call, getNewTarget());
124 }
125
126 // If everything went well, return success.
127 return DiagnosedSilenceableFailure::success();
128}
129
130void mlir::transform::ChangeCallTargetOp::getEffects(
131 ::llvm::SmallVectorImpl<::mlir::MemoryEffects::EffectInstance> &effects) {
132 // Indicate that the `call` handle is only read by this operation because the
133 // associated operation is not erased but rather modified in-place, so the
134 // reference to it remains valid.
135 onlyReadsHandle(getCallMutable(), effects);
136
137 // Indicate that the payload is modified by this operation.
138 modifiesPayload(effects);
139}
140
141void registerMyExtension(::mlir::DialectRegistry &registry) {
142 registry.addExtensions<MyExtension>();
143}
144

source code of mlir/examples/transform/Ch2/lib/MyExtension.cpp