1 | //===- TestTransformDialectInterpreter.cpp --------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines a test pass that interprets Transform dialect operations in |
10 | // the module. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "TestTransformDialectExtension.h" |
15 | #include "mlir/Dialect/Transform/DebugExtension/DebugExtensionOps.h" |
16 | #include "mlir/Dialect/Transform/IR/TransformOps.h" |
17 | #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" |
18 | #include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h" |
19 | #include "mlir/IR/Builders.h" |
20 | #include "mlir/IR/BuiltinOps.h" |
21 | #include "mlir/Pass/Pass.h" |
22 | |
23 | using namespace mlir; |
24 | |
25 | namespace { |
// Test passes that drive the Transform dialect interpreter and that erase
// transform dialect operations from the IR.
28 | |
29 | template <typename Derived> |
30 | class OpPassWrapper : public PassWrapper<Derived, OperationPass<>> {}; |
31 | |
32 | class TestTransformDialectInterpreterPass |
33 | : public transform::TransformInterpreterPassBase< |
34 | TestTransformDialectInterpreterPass, OpPassWrapper> { |
35 | public: |
36 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( |
37 | TestTransformDialectInterpreterPass) |
38 | |
39 | TestTransformDialectInterpreterPass() = default; |
40 | TestTransformDialectInterpreterPass( |
41 | const TestTransformDialectInterpreterPass &pass) |
42 | : TransformInterpreterPassBase(pass) {} |
43 | |
44 | StringRef getArgument() const override { |
45 | return "test-transform-dialect-interpreter" ; |
46 | } |
47 | |
48 | StringRef getDescription() const override { |
49 | return "apply transform dialect operations one by one" ; |
50 | } |
51 | |
52 | void getDependentDialects(DialectRegistry ®istry) const override { |
53 | registry.insert<transform::TransformDialect>(); |
54 | } |
55 | |
56 | void findOperationsByName(Operation *root, StringRef name, |
57 | SmallVectorImpl<Operation *> &operations) { |
58 | root->walk(callback: [&](Operation *op) { |
59 | if (op->getName().getStringRef() == name) { |
60 | operations.push_back(Elt: op); |
61 | } |
62 | }); |
63 | } |
64 | |
65 | void createParameterMapping(MLIRContext &context, ArrayRef<int> values, |
66 | RaggedArray<transform::MappedValue> &result) { |
67 | SmallVector<transform::MappedValue> storage = |
68 | llvm::to_vector(llvm::map_range(values, [&](int v) { |
69 | Builder b(&context); |
70 | return transform::MappedValue(b.getI64IntegerAttr(v)); |
71 | })); |
72 | result.push_back(elements: std::move(storage)); |
73 | } |
74 | |
75 | void |
76 | createOpResultMapping(Operation *root, StringRef name, |
77 | RaggedArray<transform::MappedValue> &) { |
78 | SmallVector<Operation *> operations; |
79 | findOperationsByName(root, name, operations); |
80 | SmallVector<Value> results; |
81 | for (Operation *op : operations) |
82 | llvm::append_range(C&: results, R: op->getResults()); |
83 | extraMapping.push_back(elements&: results); |
84 | } |
85 | |
86 | unsigned numberOfSetOptions(const Option<std::string> &ops, |
87 | const ListOption<int> ¶ms, |
88 | const Option<std::string> &values) { |
89 | unsigned numSetValues = 0; |
90 | numSetValues += !ops.empty(); |
91 | numSetValues += !params.empty(); |
92 | numSetValues += !values.empty(); |
93 | return numSetValues; |
94 | } |
95 | |
96 | std::optional<LogicalResult> constructTransformModule(OpBuilder &builder, |
97 | Location loc) { |
98 | if (!testModuleGeneration) |
99 | return std::nullopt; |
100 | |
101 | builder.create<transform::SequenceOp>( |
102 | loc, TypeRange(), transform::FailurePropagationMode::Propagate, |
103 | builder.getType<transform::AnyOpType>(), |
104 | [](OpBuilder &b, Location nested, Value rootH) { |
105 | b.create<transform::DebugEmitRemarkAtOp>(nested, rootH, |
106 | "remark from generated" ); |
107 | b.create<transform::YieldOp>(nested, ValueRange()); |
108 | }); |
109 | return success(); |
110 | } |
111 | |
112 | void runOnOperation() override { |
113 | unsigned firstSetOptions = |
114 | numberOfSetOptions(ops: bindFirstExtraToOps, params: bindFirstExtraToParams, |
115 | values: bindFirstExtraToResultsOfOps); |
116 | unsigned secondSetOptions = |
117 | numberOfSetOptions(ops: bindSecondExtraToOps, params: bindSecondExtraToParams, |
118 | values: bindSecondExtraToResultsOfOps); |
119 | auto loc = UnknownLoc::get(&getContext()); |
120 | if (firstSetOptions > 1) { |
121 | emitError(loc) << "cannot bind the first extra top-level argument to " |
122 | "multiple entities" ; |
123 | return signalPassFailure(); |
124 | } |
125 | if (secondSetOptions > 1) { |
126 | emitError(loc) << "cannot bind the second extra top-level argument to " |
127 | "multiple entities" ; |
128 | return signalPassFailure(); |
129 | } |
130 | if (firstSetOptions == 0 && secondSetOptions != 0) { |
131 | emitError(loc) << "cannot bind the second extra top-level argument " |
132 | "without bindings the first" ; |
133 | } |
134 | |
135 | RaggedArray<transform::MappedValue> ; |
136 | if (!bindFirstExtraToOps.empty()) { |
137 | SmallVector<Operation *> operations; |
138 | findOperationsByName(root: getOperation(), name: bindFirstExtraToOps.getValue(), |
139 | operations); |
140 | extraMapping.push_back(elements&: operations); |
141 | } else if (!bindFirstExtraToParams.empty()) { |
142 | createParameterMapping(context&: getContext(), values: bindFirstExtraToParams, |
143 | result&: extraMapping); |
144 | } else if (!bindFirstExtraToResultsOfOps.empty()) { |
145 | createOpResultMapping(root: getOperation(), name: bindFirstExtraToResultsOfOps, |
146 | extraMapping); |
147 | } |
148 | |
149 | if (!bindSecondExtraToOps.empty()) { |
150 | SmallVector<Operation *> operations; |
151 | findOperationsByName(root: getOperation(), name: bindSecondExtraToOps, operations); |
152 | extraMapping.push_back(elements&: operations); |
153 | } else if (!bindSecondExtraToParams.empty()) { |
154 | createParameterMapping(context&: getContext(), values: bindSecondExtraToParams, |
155 | result&: extraMapping); |
156 | } else if (!bindSecondExtraToResultsOfOps.empty()) { |
157 | createOpResultMapping(root: getOperation(), name: bindSecondExtraToResultsOfOps, |
158 | extraMapping); |
159 | } |
160 | |
161 | options = options.enableExpensiveChecks(enable: enableExpensiveChecks); |
162 | options = options.enableEnforceSingleToplevelTransformOp( |
163 | enable: enforceSingleToplevelTransformOp); |
164 | if (failed(result: transform::detail::interpreterBaseRunOnOperationImpl( |
165 | target: getOperation(), passName: getArgument(), sharedTransformModule: getSharedTransformModule(), |
166 | libraryModule: getTransformLibraryModule(), extraMappings: extraMapping, options, |
167 | transformFileName, transformLibraryPaths, debugPayloadRootTag, |
168 | debugTransformRootTag, binaryName: getBinaryName()))) |
169 | return signalPassFailure(); |
170 | } |
171 | |
172 | Option<bool> enableExpensiveChecks{ |
173 | *this, "enable-expensive-checks" , llvm::cl::init(Val: false), |
174 | llvm::cl::desc("perform expensive checks to better report errors in the " |
175 | "transform IR" )}; |
176 | Option<bool> enforceSingleToplevelTransformOp{ |
177 | *this, "enforce-single-top-level-transform-op" , llvm::cl::init(Val: true), |
178 | llvm::cl::desc("Ensure that only a single top-level transform op is " |
179 | "present in the IR." )}; |
180 | |
181 | Option<std::string> { |
182 | *this, "bind-first-extra-to-ops" , |
183 | llvm::cl::desc("bind the first extra argument of the top-level op to " |
184 | "payload operations of the given kind" )}; |
185 | ListOption<int> { |
186 | *this, "bind-first-extra-to-params" , |
187 | llvm::cl::desc("bind the first extra argument of the top-level op to " |
188 | "the given integer parameters" )}; |
189 | Option<std::string> { |
190 | *this, "bind-first-extra-to-results-of-ops" , |
191 | llvm::cl::desc("bind the first extra argument of the top-level op to " |
192 | "results of payload operations of the given kind" )}; |
193 | |
194 | Option<std::string> { |
195 | *this, "bind-second-extra-to-ops" , |
196 | llvm::cl::desc("bind the second extra argument of the top-level op to " |
197 | "payload operations of the given kind" )}; |
198 | ListOption<int> { |
199 | *this, "bind-second-extra-to-params" , |
200 | llvm::cl::desc("bind the second extra argument of the top-level op to " |
201 | "the given integer parameters" )}; |
202 | Option<std::string> { |
203 | *this, "bind-second-extra-to-results-of-ops" , |
204 | llvm::cl::desc("bind the second extra argument of the top-level op to " |
205 | "results of payload operations of the given kind" )}; |
206 | |
207 | Option<std::string> transformFileName{ |
208 | *this, "transform-file-name" , llvm::cl::init(Val: "" ), |
209 | llvm::cl::desc( |
210 | "Optional filename containing a transform dialect specification to " |
211 | "apply. If left empty, the IR is assumed to contain one top-level " |
212 | "transform dialect operation somewhere in the module." )}; |
213 | Option<std::string> debugPayloadRootTag{ |
214 | *this, "debug-payload-root-tag" , llvm::cl::init(Val: "" ), |
215 | llvm::cl::desc( |
216 | "Select the operation with 'transform.target_tag' attribute having " |
217 | "the given value as payload IR root. If empty select the pass anchor " |
218 | "operation as the payload IR root." )}; |
219 | Option<std::string> debugTransformRootTag{ |
220 | *this, "debug-transform-root-tag" , llvm::cl::init(Val: "" ), |
221 | llvm::cl::desc( |
222 | "Select the operation with 'transform.target_tag' attribute having " |
223 | "the given value as container IR for top-level transform ops. This " |
224 | "allows user control on what transformation to apply. If empty, " |
225 | "select the container of the top-level transform op." )}; |
226 | ListOption<std::string> transformLibraryPaths{ |
227 | *this, "transform-library-paths" , llvm::cl::ZeroOrMore, |
228 | llvm::cl::desc("Optional paths to files with modules that should be " |
229 | "merged into the transform module to provide the " |
230 | "definitions of external named sequences." )}; |
231 | |
232 | Option<bool> testModuleGeneration{ |
233 | *this, "test-module-generation" , llvm::cl::init(Val: false), |
234 | llvm::cl::desc("test the generation of the transform module during pass " |
235 | "initialization, overridden by parsing" )}; |
236 | }; |
237 | |
238 | struct TestTransformDialectEraseSchedulePass |
239 | : public PassWrapper<TestTransformDialectEraseSchedulePass, |
240 | OperationPass<ModuleOp>> { |
241 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( |
242 | TestTransformDialectEraseSchedulePass) |
243 | |
244 | StringRef getArgument() const final { |
245 | return "test-transform-dialect-erase-schedule" ; |
246 | } |
247 | |
248 | StringRef getDescription() const final { |
249 | return "erase transform dialect schedule from the IR" ; |
250 | } |
251 | |
252 | void runOnOperation() override { |
253 | getOperation()->walk<WalkOrder::PreOrder>([&](Operation *nestedOp) { |
254 | if (isa<transform::TransformOpInterface>(nestedOp)) { |
255 | nestedOp->erase(); |
256 | return WalkResult::skip(); |
257 | } |
258 | return WalkResult::advance(); |
259 | }); |
260 | } |
261 | }; |
262 | } // namespace |
263 | |
264 | namespace mlir { |
265 | namespace test { |
266 | /// Registers the test pass for erasing transform dialect ops. |
267 | void registerTestTransformDialectEraseSchedulePass() { |
268 | PassRegistration<TestTransformDialectEraseSchedulePass> reg; |
269 | } |
270 | /// Registers the test pass for applying transform dialect ops. |
271 | void registerTestTransformDialectInterpreterPass() { |
272 | PassRegistration<TestTransformDialectInterpreterPass> reg; |
273 | } |
274 | } // namespace test |
275 | } // namespace mlir |
276 | |