1//===- TestTransformDialectInterpreter.cpp --------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines a test pass that interprets Transform dialect operations in
10// the module.
11//
12//===----------------------------------------------------------------------===//
13
14#include "TestTransformDialectExtension.h"
15#include "mlir/Dialect/Transform/DebugExtension/DebugExtensionOps.h"
16#include "mlir/Dialect/Transform/IR/TransformOps.h"
17#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
18#include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h"
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/BuiltinOps.h"
21#include "mlir/Pass/Pass.h"
22
23using namespace mlir;
24
25namespace {
26/// Simple pass that applies transform dialect ops directly contained in a
27/// module.
28
29template <typename Derived>
30class OpPassWrapper : public PassWrapper<Derived, OperationPass<>> {};
31
32class TestTransformDialectInterpreterPass
33 : public transform::TransformInterpreterPassBase<
34 TestTransformDialectInterpreterPass, OpPassWrapper> {
35public:
36 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
37 TestTransformDialectInterpreterPass)
38
39 TestTransformDialectInterpreterPass() = default;
40 TestTransformDialectInterpreterPass(
41 const TestTransformDialectInterpreterPass &pass)
42 : TransformInterpreterPassBase(pass) {}
43
44 StringRef getArgument() const override {
45 return "test-transform-dialect-interpreter";
46 }
47
48 StringRef getDescription() const override {
49 return "apply transform dialect operations one by one";
50 }
51
52 void getDependentDialects(DialectRegistry &registry) const override {
53 registry.insert<transform::TransformDialect>();
54 }
55
56 void findOperationsByName(Operation *root, StringRef name,
57 SmallVectorImpl<Operation *> &operations) {
58 root->walk(callback: [&](Operation *op) {
59 if (op->getName().getStringRef() == name) {
60 operations.push_back(Elt: op);
61 }
62 });
63 }
64
65 void createParameterMapping(MLIRContext &context, ArrayRef<int> values,
66 RaggedArray<transform::MappedValue> &result) {
67 SmallVector<transform::MappedValue> storage =
68 llvm::to_vector(llvm::map_range(values, [&](int v) {
69 Builder b(&context);
70 return transform::MappedValue(b.getI64IntegerAttr(v));
71 }));
72 result.push_back(elements: std::move(storage));
73 }
74
75 void
76 createOpResultMapping(Operation *root, StringRef name,
77 RaggedArray<transform::MappedValue> &extraMapping) {
78 SmallVector<Operation *> operations;
79 findOperationsByName(root, name, operations);
80 SmallVector<Value> results;
81 for (Operation *op : operations)
82 llvm::append_range(C&: results, R: op->getResults());
83 extraMapping.push_back(elements&: results);
84 }
85
86 unsigned numberOfSetOptions(const Option<std::string> &ops,
87 const ListOption<int> &params,
88 const Option<std::string> &values) {
89 unsigned numSetValues = 0;
90 numSetValues += !ops.empty();
91 numSetValues += !params.empty();
92 numSetValues += !values.empty();
93 return numSetValues;
94 }
95
96 std::optional<LogicalResult> constructTransformModule(OpBuilder &builder,
97 Location loc) {
98 if (!testModuleGeneration)
99 return std::nullopt;
100
101 builder.create<transform::SequenceOp>(
102 loc, TypeRange(), transform::FailurePropagationMode::Propagate,
103 builder.getType<transform::AnyOpType>(),
104 [](OpBuilder &b, Location nested, Value rootH) {
105 b.create<transform::DebugEmitRemarkAtOp>(nested, rootH,
106 "remark from generated");
107 b.create<transform::YieldOp>(nested, ValueRange());
108 });
109 return success();
110 }
111
112 void runOnOperation() override {
113 unsigned firstSetOptions =
114 numberOfSetOptions(ops: bindFirstExtraToOps, params: bindFirstExtraToParams,
115 values: bindFirstExtraToResultsOfOps);
116 unsigned secondSetOptions =
117 numberOfSetOptions(ops: bindSecondExtraToOps, params: bindSecondExtraToParams,
118 values: bindSecondExtraToResultsOfOps);
119 auto loc = UnknownLoc::get(&getContext());
120 if (firstSetOptions > 1) {
121 emitError(loc) << "cannot bind the first extra top-level argument to "
122 "multiple entities";
123 return signalPassFailure();
124 }
125 if (secondSetOptions > 1) {
126 emitError(loc) << "cannot bind the second extra top-level argument to "
127 "multiple entities";
128 return signalPassFailure();
129 }
130 if (firstSetOptions == 0 && secondSetOptions != 0) {
131 emitError(loc) << "cannot bind the second extra top-level argument "
132 "without bindings the first";
133 }
134
135 RaggedArray<transform::MappedValue> extraMapping;
136 if (!bindFirstExtraToOps.empty()) {
137 SmallVector<Operation *> operations;
138 findOperationsByName(root: getOperation(), name: bindFirstExtraToOps.getValue(),
139 operations);
140 extraMapping.push_back(elements&: operations);
141 } else if (!bindFirstExtraToParams.empty()) {
142 createParameterMapping(context&: getContext(), values: bindFirstExtraToParams,
143 result&: extraMapping);
144 } else if (!bindFirstExtraToResultsOfOps.empty()) {
145 createOpResultMapping(root: getOperation(), name: bindFirstExtraToResultsOfOps,
146 extraMapping);
147 }
148
149 if (!bindSecondExtraToOps.empty()) {
150 SmallVector<Operation *> operations;
151 findOperationsByName(root: getOperation(), name: bindSecondExtraToOps, operations);
152 extraMapping.push_back(elements&: operations);
153 } else if (!bindSecondExtraToParams.empty()) {
154 createParameterMapping(context&: getContext(), values: bindSecondExtraToParams,
155 result&: extraMapping);
156 } else if (!bindSecondExtraToResultsOfOps.empty()) {
157 createOpResultMapping(root: getOperation(), name: bindSecondExtraToResultsOfOps,
158 extraMapping);
159 }
160
161 options = options.enableExpensiveChecks(enable: enableExpensiveChecks);
162 options = options.enableEnforceSingleToplevelTransformOp(
163 enable: enforceSingleToplevelTransformOp);
164 if (failed(result: transform::detail::interpreterBaseRunOnOperationImpl(
165 target: getOperation(), passName: getArgument(), sharedTransformModule: getSharedTransformModule(),
166 libraryModule: getTransformLibraryModule(), extraMappings: extraMapping, options,
167 transformFileName, transformLibraryPaths, debugPayloadRootTag,
168 debugTransformRootTag, binaryName: getBinaryName())))
169 return signalPassFailure();
170 }
171
172 Option<bool> enableExpensiveChecks{
173 *this, "enable-expensive-checks", llvm::cl::init(Val: false),
174 llvm::cl::desc("perform expensive checks to better report errors in the "
175 "transform IR")};
176 Option<bool> enforceSingleToplevelTransformOp{
177 *this, "enforce-single-top-level-transform-op", llvm::cl::init(Val: true),
178 llvm::cl::desc("Ensure that only a single top-level transform op is "
179 "present in the IR.")};
180
181 Option<std::string> bindFirstExtraToOps{
182 *this, "bind-first-extra-to-ops",
183 llvm::cl::desc("bind the first extra argument of the top-level op to "
184 "payload operations of the given kind")};
185 ListOption<int> bindFirstExtraToParams{
186 *this, "bind-first-extra-to-params",
187 llvm::cl::desc("bind the first extra argument of the top-level op to "
188 "the given integer parameters")};
189 Option<std::string> bindFirstExtraToResultsOfOps{
190 *this, "bind-first-extra-to-results-of-ops",
191 llvm::cl::desc("bind the first extra argument of the top-level op to "
192 "results of payload operations of the given kind")};
193
194 Option<std::string> bindSecondExtraToOps{
195 *this, "bind-second-extra-to-ops",
196 llvm::cl::desc("bind the second extra argument of the top-level op to "
197 "payload operations of the given kind")};
198 ListOption<int> bindSecondExtraToParams{
199 *this, "bind-second-extra-to-params",
200 llvm::cl::desc("bind the second extra argument of the top-level op to "
201 "the given integer parameters")};
202 Option<std::string> bindSecondExtraToResultsOfOps{
203 *this, "bind-second-extra-to-results-of-ops",
204 llvm::cl::desc("bind the second extra argument of the top-level op to "
205 "results of payload operations of the given kind")};
206
207 Option<std::string> transformFileName{
208 *this, "transform-file-name", llvm::cl::init(Val: ""),
209 llvm::cl::desc(
210 "Optional filename containing a transform dialect specification to "
211 "apply. If left empty, the IR is assumed to contain one top-level "
212 "transform dialect operation somewhere in the module.")};
213 Option<std::string> debugPayloadRootTag{
214 *this, "debug-payload-root-tag", llvm::cl::init(Val: ""),
215 llvm::cl::desc(
216 "Select the operation with 'transform.target_tag' attribute having "
217 "the given value as payload IR root. If empty select the pass anchor "
218 "operation as the payload IR root.")};
219 Option<std::string> debugTransformRootTag{
220 *this, "debug-transform-root-tag", llvm::cl::init(Val: ""),
221 llvm::cl::desc(
222 "Select the operation with 'transform.target_tag' attribute having "
223 "the given value as container IR for top-level transform ops. This "
224 "allows user control on what transformation to apply. If empty, "
225 "select the container of the top-level transform op.")};
226 ListOption<std::string> transformLibraryPaths{
227 *this, "transform-library-paths", llvm::cl::ZeroOrMore,
228 llvm::cl::desc("Optional paths to files with modules that should be "
229 "merged into the transform module to provide the "
230 "definitions of external named sequences.")};
231
232 Option<bool> testModuleGeneration{
233 *this, "test-module-generation", llvm::cl::init(Val: false),
234 llvm::cl::desc("test the generation of the transform module during pass "
235 "initialization, overridden by parsing")};
236};
237
238struct TestTransformDialectEraseSchedulePass
239 : public PassWrapper<TestTransformDialectEraseSchedulePass,
240 OperationPass<ModuleOp>> {
241 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
242 TestTransformDialectEraseSchedulePass)
243
244 StringRef getArgument() const final {
245 return "test-transform-dialect-erase-schedule";
246 }
247
248 StringRef getDescription() const final {
249 return "erase transform dialect schedule from the IR";
250 }
251
252 void runOnOperation() override {
253 getOperation()->walk<WalkOrder::PreOrder>([&](Operation *nestedOp) {
254 if (isa<transform::TransformOpInterface>(nestedOp)) {
255 nestedOp->erase();
256 return WalkResult::skip();
257 }
258 return WalkResult::advance();
259 });
260 }
261};
262} // namespace
263
264namespace mlir {
265namespace test {
266/// Registers the test pass for erasing transform dialect ops.
267void registerTestTransformDialectEraseSchedulePass() {
268 PassRegistration<TestTransformDialectEraseSchedulePass> reg;
269}
270/// Registers the test pass for applying transform dialect ops.
271void registerTestTransformDialectInterpreterPass() {
272 PassRegistration<TestTransformDialectInterpreterPass> reg;
273}
274} // namespace test
275} // namespace mlir
276

source code of mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp