1 | //===- TestTransformDialectExtension.cpp ----------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines an extension of the MLIR Transform dialect for testing |
10 | // purposes. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "TestTransformDialectExtension.h" |
15 | #include "TestTransformStateExtension.h" |
16 | #include "mlir/Dialect/PDL/IR/PDL.h" |
17 | #include "mlir/Dialect/Transform/IR/TransformDialect.h" |
18 | #include "mlir/Dialect/Transform/IR/TransformOps.h" |
19 | #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" |
20 | #include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h" |
21 | #include "mlir/IR/OpImplementation.h" |
22 | #include "mlir/IR/PatternMatch.h" |
23 | #include "llvm/ADT/STLExtras.h" |
24 | #include "llvm/ADT/TypeSwitch.h" |
25 | #include "llvm/Support/Compiler.h" |
26 | #include "llvm/Support/raw_ostream.h" |
27 | |
28 | using namespace mlir; |
29 | |
30 | namespace { |
31 | /// Simple transform op defined outside of the dialect. Just emits a remark when |
32 | /// applied. This op is defined in C++ to test that C++ definitions also work |
33 | /// for op injection into the Transform dialect. |
34 | class TestTransformOp |
35 | : public Op<TestTransformOp, transform::TransformOpInterface::Trait, |
36 | MemoryEffectOpInterface::Trait> { |
37 | public: |
38 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformOp) |
39 | |
40 | using Op::Op; |
41 | |
42 | static ArrayRef<StringRef> getAttributeNames() { return {}; } |
43 | |
44 | static constexpr llvm::StringLiteral getOperationName() { |
45 | return llvm::StringLiteral("transform.test_transform_op" ); |
46 | } |
47 | |
48 | DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter, |
49 | transform::TransformResults &results, |
50 | transform::TransformState &state) { |
51 | InFlightDiagnostic = emitRemark() << "applying transformation" ; |
52 | if (Attribute message = getMessage()) |
53 | remark << " " << message; |
54 | |
55 | return DiagnosedSilenceableFailure::success(); |
56 | } |
57 | |
58 | Attribute getMessage() { |
59 | return getOperation()->getDiscardableAttr("message" ); |
60 | } |
61 | |
62 | static ParseResult parse(OpAsmParser &parser, OperationState &state) { |
63 | StringAttr message; |
64 | OptionalParseResult result = parser.parseOptionalAttribute(message); |
65 | if (!result.has_value()) |
66 | return success(); |
67 | |
68 | if (result.value().succeeded()) |
69 | state.addAttribute("message" , message); |
70 | return result.value(); |
71 | } |
72 | |
73 | void print(OpAsmPrinter &printer) { |
74 | if (getMessage()) |
75 | printer << " " << getMessage(); |
76 | } |
77 | |
78 | // No side effects. |
79 | void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {} |
80 | }; |
81 | |
82 | /// A test op to exercise the verifier of the PossibleTopLevelTransformOpTrait |
83 | /// in cases where it is attached to ops that do not comply with the trait |
84 | /// requirements. This op cannot be defined in ODS because ODS generates strict |
85 | /// verifiers that overalp with those in the trait and run earlier. |
86 | class TestTransformUnrestrictedOpNoInterface |
87 | : public Op<TestTransformUnrestrictedOpNoInterface, |
88 | transform::PossibleTopLevelTransformOpTrait, |
89 | transform::TransformOpInterface::Trait, |
90 | MemoryEffectOpInterface::Trait> { |
91 | public: |
92 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( |
93 | TestTransformUnrestrictedOpNoInterface) |
94 | |
95 | using Op::Op; |
96 | |
97 | static ArrayRef<StringRef> getAttributeNames() { return {}; } |
98 | |
99 | static constexpr llvm::StringLiteral getOperationName() { |
100 | return llvm::StringLiteral( |
101 | "transform.test_transform_unrestricted_op_no_interface" ); |
102 | } |
103 | |
104 | DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter, |
105 | transform::TransformResults &results, |
106 | transform::TransformState &state) { |
107 | return DiagnosedSilenceableFailure::success(); |
108 | } |
109 | |
110 | // No side effects. |
111 | void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {} |
112 | }; |
113 | } // namespace |
114 | |
115 | DiagnosedSilenceableFailure |
116 | mlir::test::TestProduceSelfHandleOrForwardOperandOp::apply( |
117 | transform::TransformRewriter &rewriter, |
118 | transform::TransformResults &results, transform::TransformState &state) { |
119 | if (getOperation()->getNumOperands() != 0) { |
120 | results.set(cast<OpResult>(getResult()), |
121 | {getOperation()->getOperand(0).getDefiningOp()}); |
122 | } else { |
123 | results.set(cast<OpResult>(getResult()), {getOperation()}); |
124 | } |
125 | return DiagnosedSilenceableFailure::success(); |
126 | } |
127 | |
128 | void mlir::test::TestProduceSelfHandleOrForwardOperandOp::getEffects( |
129 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
130 | if (getOperand()) |
131 | transform::onlyReadsHandle(getOperand(), effects); |
132 | transform::producesHandle(getRes(), effects); |
133 | } |
134 | |
135 | DiagnosedSilenceableFailure |
136 | mlir::test::TestProduceValueHandleToSelfOperand::apply( |
137 | transform::TransformRewriter &rewriter, |
138 | transform::TransformResults &results, transform::TransformState &state) { |
139 | results.setValues(llvm::cast<OpResult>(getOut()), {getIn()}); |
140 | return DiagnosedSilenceableFailure::success(); |
141 | } |
142 | |
143 | void mlir::test::TestProduceValueHandleToSelfOperand::getEffects( |
144 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
145 | transform::onlyReadsHandle(getIn(), effects); |
146 | transform::producesHandle(getOut(), effects); |
147 | transform::onlyReadsPayload(effects); |
148 | } |
149 | |
150 | DiagnosedSilenceableFailure |
151 | mlir::test::TestProduceValueHandleToResult::applyToOne( |
152 | transform::TransformRewriter &rewriter, Operation *target, |
153 | transform::ApplyToEachResultList &results, |
154 | transform::TransformState &state) { |
155 | if (target->getNumResults() <= getNumber()) |
156 | return emitSilenceableError() << "payload has no result #" << getNumber(); |
157 | results.push_back(target->getResult(getNumber())); |
158 | return DiagnosedSilenceableFailure::success(); |
159 | } |
160 | |
161 | void mlir::test::TestProduceValueHandleToResult::getEffects( |
162 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
163 | transform::onlyReadsHandle(getIn(), effects); |
164 | transform::producesHandle(getOut(), effects); |
165 | transform::onlyReadsPayload(effects); |
166 | } |
167 | |
168 | DiagnosedSilenceableFailure |
169 | mlir::test::TestProduceValueHandleToArgumentOfParentBlock::applyToOne( |
170 | transform::TransformRewriter &rewriter, Operation *target, |
171 | transform::ApplyToEachResultList &results, |
172 | transform::TransformState &state) { |
173 | if (!target->getBlock()) |
174 | return emitSilenceableError() << "payload has no parent block" ; |
175 | if (target->getBlock()->getNumArguments() <= getNumber()) |
176 | return emitSilenceableError() |
177 | << "parent of the payload has no argument #" << getNumber(); |
178 | results.push_back(target->getBlock()->getArgument(getNumber())); |
179 | return DiagnosedSilenceableFailure::success(); |
180 | } |
181 | |
182 | void mlir::test::TestProduceValueHandleToArgumentOfParentBlock::getEffects( |
183 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
184 | transform::onlyReadsHandle(getIn(), effects); |
185 | transform::producesHandle(getOut(), effects); |
186 | transform::onlyReadsPayload(effects); |
187 | } |
188 | |
189 | bool mlir::test::TestConsumeOperand::allowsRepeatedHandleOperands() { |
190 | return getAllowRepeatedHandles(); |
191 | } |
192 | |
193 | DiagnosedSilenceableFailure |
194 | mlir::test::TestConsumeOperand::apply(transform::TransformRewriter &rewriter, |
195 | transform::TransformResults &results, |
196 | transform::TransformState &state) { |
197 | return DiagnosedSilenceableFailure::success(); |
198 | } |
199 | |
200 | void mlir::test::TestConsumeOperand::getEffects( |
201 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
202 | transform::consumesHandle(getOperand(), effects); |
203 | if (getSecondOperand()) |
204 | transform::consumesHandle(getSecondOperand(), effects); |
205 | transform::modifiesPayload(effects); |
206 | } |
207 | |
208 | DiagnosedSilenceableFailure mlir::test::TestConsumeOperandOfOpKindOrFail::apply( |
209 | transform::TransformRewriter &rewriter, |
210 | transform::TransformResults &results, transform::TransformState &state) { |
211 | auto payload = state.getPayloadOps(getOperand()); |
212 | assert(llvm::hasSingleElement(payload) && "expected a single target op" ); |
213 | if ((*payload.begin())->getName().getStringRef() != getOpKind()) { |
214 | return emitSilenceableError() |
215 | << "op expected the operand to be associated a payload op of kind " |
216 | << getOpKind() << " got " |
217 | << (*payload.begin())->getName().getStringRef(); |
218 | } |
219 | |
220 | emitRemark() << "succeeded" ; |
221 | return DiagnosedSilenceableFailure::success(); |
222 | } |
223 | |
224 | void mlir::test::TestConsumeOperandOfOpKindOrFail::getEffects( |
225 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
226 | transform::consumesHandle(getOperand(), effects); |
227 | transform::modifiesPayload(effects); |
228 | } |
229 | |
230 | DiagnosedSilenceableFailure |
231 | mlir::test::TestSucceedIfOperandOfOpKind::matchOperation( |
232 | Operation *op, transform::TransformResults &results, |
233 | transform::TransformState &state) { |
234 | if (op->getName().getStringRef() != getOpKind()) { |
235 | return emitSilenceableError() |
236 | << "op expected the operand to be associated with a payload op of " |
237 | "kind " |
238 | << getOpKind() << " got " << op->getName().getStringRef(); |
239 | } |
240 | return DiagnosedSilenceableFailure::success(); |
241 | } |
242 | |
243 | void mlir::test::TestSucceedIfOperandOfOpKind::getEffects( |
244 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
245 | transform::onlyReadsHandle(getOperand(), effects); |
246 | transform::onlyReadsPayload(effects); |
247 | } |
248 | |
249 | DiagnosedSilenceableFailure mlir::test::TestAddTestExtensionOp::apply( |
250 | transform::TransformRewriter &rewriter, |
251 | transform::TransformResults &results, transform::TransformState &state) { |
252 | state.addExtension<TestTransformStateExtension>(getMessageAttr()); |
253 | return DiagnosedSilenceableFailure::success(); |
254 | } |
255 | |
256 | DiagnosedSilenceableFailure |
257 | mlir::test::TestCheckIfTestExtensionPresentOp::apply( |
258 | transform::TransformRewriter &rewriter, |
259 | transform::TransformResults &results, transform::TransformState &state) { |
260 | auto *extension = state.getExtension<TestTransformStateExtension>(); |
261 | if (!extension) { |
262 | emitRemark() << "extension absent" ; |
263 | return DiagnosedSilenceableFailure::success(); |
264 | } |
265 | |
266 | InFlightDiagnostic diag = emitRemark() |
267 | << "extension present, " << extension->getMessage(); |
268 | for (Operation *payload : state.getPayloadOps(getOperand())) { |
269 | diag.attachNote(payload->getLoc()) << "associated payload op" ; |
270 | #ifndef NDEBUG |
271 | SmallVector<Value> handles; |
272 | assert(succeeded(state.getHandlesForPayloadOp(payload, handles))); |
273 | assert(llvm::is_contained(handles, getOperand()) && |
274 | "inconsistent mapping between transform IR handles and payload IR " |
275 | "operations" ); |
276 | #endif // NDEBUG |
277 | } |
278 | |
279 | return DiagnosedSilenceableFailure::success(); |
280 | } |
281 | |
282 | void mlir::test::TestCheckIfTestExtensionPresentOp::getEffects( |
283 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
284 | transform::onlyReadsHandle(getOperand(), effects); |
285 | transform::onlyReadsPayload(effects); |
286 | } |
287 | |
288 | DiagnosedSilenceableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply( |
289 | transform::TransformRewriter &rewriter, |
290 | transform::TransformResults &results, transform::TransformState &state) { |
291 | auto *extension = state.getExtension<TestTransformStateExtension>(); |
292 | if (!extension) |
293 | return emitDefiniteFailure("TestTransformStateExtension missing" ); |
294 | |
295 | if (failed(extension->updateMapping( |
296 | *state.getPayloadOps(getOperand()).begin(), getOperation()))) |
297 | return DiagnosedSilenceableFailure::definiteFailure(); |
298 | if (getNumResults() > 0) |
299 | results.set(cast<OpResult>(getResult(0)), {getOperation()}); |
300 | return DiagnosedSilenceableFailure::success(); |
301 | } |
302 | |
303 | void mlir::test::TestRemapOperandPayloadToSelfOp::getEffects( |
304 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
305 | transform::onlyReadsHandle(getOperand(), effects); |
306 | transform::producesHandle(getOut(), effects); |
307 | transform::onlyReadsPayload(effects); |
308 | } |
309 | |
310 | DiagnosedSilenceableFailure mlir::test::TestRemoveTestExtensionOp::apply( |
311 | transform::TransformRewriter &rewriter, |
312 | transform::TransformResults &results, transform::TransformState &state) { |
313 | state.removeExtension<TestTransformStateExtension>(); |
314 | return DiagnosedSilenceableFailure::success(); |
315 | } |
316 | |
317 | DiagnosedSilenceableFailure mlir::test::TestReversePayloadOpsOp::apply( |
318 | transform::TransformRewriter &rewriter, |
319 | transform::TransformResults &results, transform::TransformState &state) { |
320 | auto payloadOps = state.getPayloadOps(getTarget()); |
321 | auto reversedOps = llvm::to_vector(llvm::reverse(payloadOps)); |
322 | results.set(llvm::cast<OpResult>(getResult()), reversedOps); |
323 | return DiagnosedSilenceableFailure::success(); |
324 | } |
325 | |
326 | DiagnosedSilenceableFailure mlir::test::TestTransformOpWithRegions::apply( |
327 | transform::TransformRewriter &rewriter, |
328 | transform::TransformResults &results, transform::TransformState &state) { |
329 | return DiagnosedSilenceableFailure::success(); |
330 | } |
331 | |
332 | void mlir::test::TestTransformOpWithRegions::getEffects( |
333 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {} |
334 | |
335 | DiagnosedSilenceableFailure |
336 | mlir::test::TestBranchingTransformOpTerminator::apply( |
337 | transform::TransformRewriter &rewriter, |
338 | transform::TransformResults &results, transform::TransformState &state) { |
339 | return DiagnosedSilenceableFailure::success(); |
340 | } |
341 | |
342 | void mlir::test::TestBranchingTransformOpTerminator::getEffects( |
343 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {} |
344 | |
345 | DiagnosedSilenceableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply( |
346 | transform::TransformRewriter &rewriter, |
347 | transform::TransformResults &results, transform::TransformState &state) { |
348 | emitRemark() << getRemark(); |
349 | for (Operation *op : state.getPayloadOps(getTarget())) |
350 | rewriter.eraseOp(op); |
351 | |
352 | if (getFailAfterErase()) |
353 | return emitSilenceableError() << "silenceable error" ; |
354 | return DiagnosedSilenceableFailure::success(); |
355 | } |
356 | |
357 | void mlir::test::TestEmitRemarkAndEraseOperandOp::getEffects( |
358 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
359 | transform::consumesHandle(getTarget(), effects); |
360 | transform::modifiesPayload(effects); |
361 | } |
362 | |
363 | DiagnosedSilenceableFailure mlir::test::TestWrongNumberOfResultsOp::applyToOne( |
364 | transform::TransformRewriter &rewriter, Operation *target, |
365 | transform::ApplyToEachResultList &results, |
366 | transform::TransformState &state) { |
367 | OperationState opState(target->getLoc(), "foo" ); |
368 | results.push_back(OpBuilder(target).create(opState)); |
369 | return DiagnosedSilenceableFailure::success(); |
370 | } |
371 | |
372 | DiagnosedSilenceableFailure |
373 | mlir::test::TestWrongNumberOfMultiResultsOp::applyToOne( |
374 | transform::TransformRewriter &rewriter, Operation *target, |
375 | transform::ApplyToEachResultList &results, |
376 | transform::TransformState &state) { |
377 | static int count = 0; |
378 | if (count++ == 0) { |
379 | OperationState opState(target->getLoc(), "foo" ); |
380 | results.push_back(OpBuilder(target).create(opState)); |
381 | } |
382 | return DiagnosedSilenceableFailure::success(); |
383 | } |
384 | |
385 | DiagnosedSilenceableFailure |
386 | mlir::test::TestCorrectNumberOfMultiResultsOp::applyToOne( |
387 | transform::TransformRewriter &rewriter, Operation *target, |
388 | transform::ApplyToEachResultList &results, |
389 | transform::TransformState &state) { |
390 | OperationState opState(target->getLoc(), "foo" ); |
391 | results.push_back(OpBuilder(target).create(opState)); |
392 | results.push_back(OpBuilder(target).create(opState)); |
393 | return DiagnosedSilenceableFailure::success(); |
394 | } |
395 | |
396 | DiagnosedSilenceableFailure |
397 | mlir::test::TestMixedNullAndNonNullResultsOp::applyToOne( |
398 | transform::TransformRewriter &rewriter, Operation *target, |
399 | transform::ApplyToEachResultList &results, |
400 | transform::TransformState &state) { |
401 | OperationState opState(target->getLoc(), "foo" ); |
402 | results.push_back(nullptr); |
403 | results.push_back(OpBuilder(target).create(opState)); |
404 | return DiagnosedSilenceableFailure::success(); |
405 | } |
406 | |
407 | DiagnosedSilenceableFailure |
408 | mlir::test::TestMixedSuccessAndSilenceableOp::applyToOne( |
409 | transform::TransformRewriter &rewriter, Operation *target, |
410 | transform::ApplyToEachResultList &results, |
411 | transform::TransformState &state) { |
412 | if (target->hasAttr("target_me" )) |
413 | return DiagnosedSilenceableFailure::success(); |
414 | return emitDefaultSilenceableFailure(target); |
415 | } |
416 | |
417 | DiagnosedSilenceableFailure |
418 | mlir::test::TestCopyPayloadOp::apply(transform::TransformRewriter &rewriter, |
419 | transform::TransformResults &results, |
420 | transform::TransformState &state) { |
421 | results.set(llvm::cast<OpResult>(getCopy()), |
422 | state.getPayloadOps(getHandle())); |
423 | return DiagnosedSilenceableFailure::success(); |
424 | } |
425 | |
426 | void mlir::test::TestCopyPayloadOp::getEffects( |
427 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
428 | transform::onlyReadsHandle(getHandle(), effects); |
429 | transform::producesHandle(getCopy(), effects); |
430 | transform::onlyReadsPayload(effects); |
431 | } |
432 | |
433 | DiagnosedSilenceableFailure mlir::transform::TestDialectOpType::checkPayload( |
434 | Location loc, ArrayRef<Operation *> payload) const { |
435 | if (payload.empty()) |
436 | return DiagnosedSilenceableFailure::success(); |
437 | |
438 | for (Operation *op : payload) { |
439 | if (op->getName().getDialectNamespace() != "test" ) { |
440 | return emitSilenceableError(loc) << "expected the payload operation to " |
441 | "belong to the 'test' dialect" ; |
442 | } |
443 | } |
444 | |
445 | return DiagnosedSilenceableFailure::success(); |
446 | } |
447 | |
448 | DiagnosedSilenceableFailure mlir::transform::TestDialectParamType::checkPayload( |
449 | Location loc, ArrayRef<Attribute> payload) const { |
450 | for (Attribute attr : payload) { |
451 | auto integerAttr = llvm::dyn_cast<IntegerAttr>(attr); |
452 | if (integerAttr && integerAttr.getType().isSignlessInteger(32)) |
453 | continue; |
454 | return emitSilenceableError(loc) |
455 | << "expected the parameter to be a i32 integer attribute" ; |
456 | } |
457 | |
458 | return DiagnosedSilenceableFailure::success(); |
459 | } |
460 | |
461 | void mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::getEffects( |
462 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
463 | transform::onlyReadsHandle(getTarget(), effects); |
464 | } |
465 | |
466 | DiagnosedSilenceableFailure |
467 | mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::apply( |
468 | transform::TransformRewriter &rewriter, |
469 | transform::TransformResults &results, transform::TransformState &state) { |
470 | int64_t count = 0; |
471 | for (Operation *op : state.getPayloadOps(getTarget())) { |
472 | op->walk([&](Operation *nested) { |
473 | SmallVector<Value> handles; |
474 | (void)state.getHandlesForPayloadOp(nested, handles); |
475 | count += handles.size(); |
476 | }); |
477 | } |
478 | emitRemark() << count << " handles nested under" ; |
479 | return DiagnosedSilenceableFailure::success(); |
480 | } |
481 | |
482 | DiagnosedSilenceableFailure |
483 | mlir::test::TestAddToParamOp::apply(transform::TransformRewriter &rewriter, |
484 | transform::TransformResults &results, |
485 | transform::TransformState &state) { |
486 | SmallVector<uint32_t> values(/*Size=*/1, /*Value=*/0); |
487 | if (Value param = getParam()) { |
488 | values = llvm::to_vector( |
489 | llvm::map_range(state.getParams(param), [](Attribute attr) -> uint32_t { |
490 | return llvm::cast<IntegerAttr>(attr).getValue().getLimitedValue( |
491 | UINT32_MAX); |
492 | })); |
493 | } |
494 | |
495 | Builder builder(getContext()); |
496 | SmallVector<Attribute> result = llvm::to_vector( |
497 | llvm::map_range(values, [this, &builder](uint32_t value) -> Attribute { |
498 | return builder.getI32IntegerAttr(value + getAddendum()); |
499 | })); |
500 | results.setParams(llvm::cast<OpResult>(getResult()), result); |
501 | return DiagnosedSilenceableFailure::success(); |
502 | } |
503 | |
504 | DiagnosedSilenceableFailure |
505 | mlir::test::TestProduceParamWithNumberOfTestOps::apply( |
506 | transform::TransformRewriter &rewriter, |
507 | transform::TransformResults &results, transform::TransformState &state) { |
508 | Builder builder(getContext()); |
509 | SmallVector<Attribute> result = llvm::to_vector( |
510 | llvm::map_range(state.getPayloadOps(getHandle()), |
511 | [&builder](Operation *payload) -> Attribute { |
512 | int32_t count = 0; |
513 | payload->walk([&count](Operation *op) { |
514 | if (op->getName().getDialectNamespace() == "test" ) |
515 | ++count; |
516 | }); |
517 | return builder.getI32IntegerAttr(count); |
518 | })); |
519 | results.setParams(llvm::cast<OpResult>(getResult()), result); |
520 | return DiagnosedSilenceableFailure::success(); |
521 | } |
522 | |
523 | DiagnosedSilenceableFailure |
524 | mlir::test::TestProduceParamOp::apply(transform::TransformRewriter &rewriter, |
525 | transform::TransformResults &results, |
526 | transform::TransformState &state) { |
527 | results.setParams(llvm::cast<OpResult>(getResult()), getAttr()); |
528 | return DiagnosedSilenceableFailure::success(); |
529 | } |
530 | |
531 | void mlir::test::TestProduceTransformParamOrForwardOperandOp::getEffects( |
532 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
533 | transform::onlyReadsHandle(getIn(), effects); |
534 | transform::producesHandle(getOut(), effects); |
535 | transform::producesHandle(getParam(), effects); |
536 | } |
537 | |
538 | DiagnosedSilenceableFailure |
539 | mlir::test::TestProduceTransformParamOrForwardOperandOp::applyToOne( |
540 | transform::TransformRewriter &rewriter, Operation *target, |
541 | ::transform::ApplyToEachResultList &results, |
542 | ::transform::TransformState &state) { |
543 | Builder builder(getContext()); |
544 | if (getFirstResultIsParam()) { |
545 | results.push_back(builder.getI64IntegerAttr(0)); |
546 | } else if (getFirstResultIsNull()) { |
547 | results.push_back(nullptr); |
548 | } else { |
549 | results.push_back(*state.getPayloadOps(getIn()).begin()); |
550 | } |
551 | |
552 | if (getSecondResultIsHandle()) { |
553 | results.push_back(*state.getPayloadOps(getIn()).begin()); |
554 | } else { |
555 | results.push_back(builder.getI64IntegerAttr(42)); |
556 | } |
557 | |
558 | return DiagnosedSilenceableFailure::success(); |
559 | } |
560 | |
561 | void mlir::test::TestProduceNullPayloadOp::getEffects( |
562 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
563 | transform::producesHandle(getOut(), effects); |
564 | } |
565 | |
566 | DiagnosedSilenceableFailure mlir::test::TestProduceNullPayloadOp::apply( |
567 | transform::TransformRewriter &rewriter, |
568 | transform::TransformResults &results, transform::TransformState &state) { |
569 | SmallVector<Operation *, 1> null({nullptr}); |
570 | results.set(llvm::cast<OpResult>(getOut()), null); |
571 | return DiagnosedSilenceableFailure::success(); |
572 | } |
573 | |
574 | DiagnosedSilenceableFailure mlir::test::TestProduceEmptyPayloadOp::apply( |
575 | transform::TransformRewriter &rewriter, |
576 | transform::TransformResults &results, transform::TransformState &state) { |
577 | results.set(cast<OpResult>(getOut()), {}); |
578 | return DiagnosedSilenceableFailure::success(); |
579 | } |
580 | |
581 | void mlir::test::TestProduceNullParamOp::getEffects( |
582 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
583 | transform::producesHandle(getOut(), effects); |
584 | } |
585 | |
586 | DiagnosedSilenceableFailure mlir::test::TestProduceNullParamOp::apply( |
587 | transform::TransformRewriter &rewriter, |
588 | transform::TransformResults &results, transform::TransformState &state) { |
589 | results.setParams(llvm::cast<OpResult>(getOut()), Attribute()); |
590 | return DiagnosedSilenceableFailure::success(); |
591 | } |
592 | |
593 | void mlir::test::TestProduceNullValueOp::getEffects( |
594 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
595 | transform::producesHandle(getOut(), effects); |
596 | } |
597 | |
598 | DiagnosedSilenceableFailure mlir::test::TestProduceNullValueOp::apply( |
599 | transform::TransformRewriter &rewriter, |
600 | transform::TransformResults &results, transform::TransformState &state) { |
601 | results.setValues(llvm::cast<OpResult>(getOut()), {Value()}); |
602 | return DiagnosedSilenceableFailure::success(); |
603 | } |
604 | |
605 | void mlir::test::TestRequiredMemoryEffectsOp::getEffects( |
606 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
607 | if (getHasOperandEffect()) |
608 | transform::consumesHandle(getIn(), effects); |
609 | |
610 | if (getHasResultEffect()) |
611 | transform::producesHandle(getOut(), effects); |
612 | else |
613 | transform::onlyReadsHandle(getOut(), effects); |
614 | |
615 | if (getModifiesPayload()) |
616 | transform::modifiesPayload(effects); |
617 | } |
618 | |
619 | DiagnosedSilenceableFailure mlir::test::TestRequiredMemoryEffectsOp::apply( |
620 | transform::TransformRewriter &rewriter, |
621 | transform::TransformResults &results, transform::TransformState &state) { |
622 | results.set(llvm::cast<OpResult>(getOut()), state.getPayloadOps(getIn())); |
623 | return DiagnosedSilenceableFailure::success(); |
624 | } |
625 | |
626 | void mlir::test::TestTrackedRewriteOp::getEffects( |
627 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
628 | transform::onlyReadsHandle(getIn(), effects); |
629 | transform::modifiesPayload(effects); |
630 | } |
631 | |
632 | void mlir::test::TestDummyPayloadOp::getEffects( |
633 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
634 | for (OpResult result : getResults()) |
635 | transform::producesHandle(result, effects); |
636 | } |
637 | |
638 | LogicalResult mlir::test::TestDummyPayloadOp::verify() { |
639 | if (getFailToVerify()) |
640 | return emitOpError() << "fail_to_verify is set" ; |
641 | return success(); |
642 | } |
643 | |
644 | DiagnosedSilenceableFailure |
645 | mlir::test::TestTrackedRewriteOp::apply(transform::TransformRewriter &rewriter, |
646 | transform::TransformResults &results, |
647 | transform::TransformState &state) { |
648 | int64_t numIterations = 0; |
649 | |
650 | // `getPayloadOps` returns an iterator that skips ops that are erased in the |
651 | // loop body. Replacement ops are not enumerated. |
652 | for (Operation *op : state.getPayloadOps(getIn())) { |
653 | ++numIterations; |
654 | (void)op; |
655 | |
656 | // Erase all payload ops. The outer loop should have only one iteration. |
657 | for (Operation *op : state.getPayloadOps(getIn())) { |
658 | rewriter.setInsertionPoint(op); |
659 | if (op->hasAttr("erase_me" )) { |
660 | rewriter.eraseOp(op); |
661 | continue; |
662 | } |
663 | if (!op->hasAttr("replace_me" )) { |
664 | continue; |
665 | } |
666 | |
667 | SmallVector<NamedAttribute> attributes; |
668 | attributes.emplace_back(rewriter.getStringAttr("new_op" ), |
669 | rewriter.getUnitAttr()); |
670 | OperationState opState(op->getLoc(), op->getName().getIdentifier(), |
671 | /*operands=*/ValueRange(), |
672 | /*types=*/op->getResultTypes(), attributes); |
673 | Operation *newOp = rewriter.create(opState); |
674 | rewriter.replaceOp(op, newOp->getResults()); |
675 | } |
676 | } |
677 | |
678 | emitRemark() << numIterations << " iterations" ; |
679 | return DiagnosedSilenceableFailure::success(); |
680 | } |
681 | |
682 | namespace { |
683 | // Test pattern to replace an operation with a new op. |
684 | class ReplaceWithNewOp : public RewritePattern { |
685 | public: |
686 | ReplaceWithNewOp(MLIRContext *context) |
687 | : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} |
688 | |
689 | LogicalResult matchAndRewrite(Operation *op, |
690 | PatternRewriter &rewriter) const override { |
691 | auto newName = op->getAttrOfType<StringAttr>("replace_with_new_op" ); |
692 | if (!newName) |
693 | return failure(); |
694 | Operation *newOp = rewriter.create( |
695 | op->getLoc(), OperationName(newName, op->getContext()).getIdentifier(), |
696 | op->getOperands(), op->getResultTypes()); |
697 | rewriter.replaceOp(op, newValues: newOp->getResults()); |
698 | return success(); |
699 | } |
700 | }; |
701 | |
702 | // Test pattern to erase an operation. |
703 | class EraseOp : public RewritePattern { |
704 | public: |
705 | EraseOp(MLIRContext *context) |
706 | : RewritePattern("test.erase_op" , /*benefit=*/1, context) {} |
707 | LogicalResult matchAndRewrite(Operation *op, |
708 | PatternRewriter &rewriter) const override { |
709 | rewriter.eraseOp(op); |
710 | return success(); |
711 | } |
712 | }; |
713 | } // namespace |
714 | |
715 | void mlir::test::ApplyTestPatternsOp::populatePatterns( |
716 | RewritePatternSet &patterns) { |
717 | patterns.insert<ReplaceWithNewOp, EraseOp>(patterns.getContext()); |
718 | } |
719 | |
720 | void mlir::test::TestReEnterRegionOp::getEffects( |
721 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
722 | transform::consumesHandle(getOperands(), effects); |
723 | transform::modifiesPayload(effects); |
724 | } |
725 | |
726 | DiagnosedSilenceableFailure |
727 | mlir::test::TestReEnterRegionOp::apply(transform::TransformRewriter &rewriter, |
728 | transform::TransformResults &results, |
729 | transform::TransformState &state) { |
730 | |
731 | SmallVector<SmallVector<transform::MappedValue>> mappings; |
732 | for (BlockArgument arg : getBody().front().getArguments()) { |
733 | mappings.emplace_back(llvm::to_vector(llvm::map_range( |
734 | state.getPayloadOps(getOperand(arg.getArgNumber())), |
735 | [](Operation *op) -> transform::MappedValue { return op; }))); |
736 | } |
737 | |
738 | for (int i = 0; i < 4; ++i) { |
739 | auto scope = state.make_region_scope(getBody()); |
740 | for (BlockArgument arg : getBody().front().getArguments()) { |
741 | if (failed(state.mapBlockArgument(arg, mappings[arg.getArgNumber()]))) |
742 | return DiagnosedSilenceableFailure::definiteFailure(); |
743 | } |
744 | for (Operation &op : getBody().front().without_terminator()) { |
745 | DiagnosedSilenceableFailure diag = |
746 | state.applyTransform(cast<transform::TransformOpInterface>(op)); |
747 | if (!diag.succeeded()) |
748 | return diag; |
749 | } |
750 | } |
751 | return DiagnosedSilenceableFailure::success(); |
752 | } |
753 | |
754 | LogicalResult mlir::test::TestReEnterRegionOp::verify() { |
755 | if (getNumOperands() != getBody().front().getNumArguments()) { |
756 | return emitOpError() << "expects as many operands as block arguments" ; |
757 | } |
758 | return success(); |
759 | } |
760 | |
761 | DiagnosedSilenceableFailure mlir::test::TestNotifyPayloadOpReplacedOp::apply( |
762 | transform::TransformRewriter &rewriter, |
763 | transform::TransformResults &results, transform::TransformState &state) { |
764 | auto originalOps = state.getPayloadOps(getOriginal()); |
765 | auto replacementOps = state.getPayloadOps(getReplacement()); |
766 | if (llvm::range_size(originalOps) != llvm::range_size(replacementOps)) |
767 | return emitSilenceableError() << "expected same number of original and " |
768 | "replacement payload operations" ; |
769 | for (const auto &[original, replacement] : |
770 | llvm::zip(originalOps, replacementOps)) { |
771 | if (failed( |
772 | rewriter.notifyPayloadOperationReplaced(original, replacement))) { |
773 | auto diag = emitSilenceableError() |
774 | << "unable to replace payload op in transform mapping" ; |
775 | diag.attachNote(original->getLoc()) << "original payload op" ; |
776 | diag.attachNote(replacement->getLoc()) << "replacement payload op" ; |
777 | return diag; |
778 | } |
779 | } |
780 | return DiagnosedSilenceableFailure::success(); |
781 | } |
782 | |
783 | void mlir::test::TestNotifyPayloadOpReplacedOp::getEffects( |
784 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
785 | transform::onlyReadsHandle(getOriginal(), effects); |
786 | transform::onlyReadsHandle(getReplacement(), effects); |
787 | } |
788 | |
789 | DiagnosedSilenceableFailure mlir::test::TestProduceInvalidIR::applyToOne( |
790 | transform::TransformRewriter &rewriter, Operation *target, |
791 | transform::ApplyToEachResultList &results, |
792 | transform::TransformState &state) { |
793 | // Provide some IR that does not verify. |
794 | rewriter.setInsertionPointToStart(&target->getRegion(0).front()); |
795 | rewriter.create<TestDummyPayloadOp>(target->getLoc(), TypeRange(), |
796 | ValueRange(), /*failToVerify=*/true); |
797 | return DiagnosedSilenceableFailure::success(); |
798 | } |
799 | |
800 | void mlir::test::TestProduceInvalidIR::getEffects( |
801 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
802 | transform::onlyReadsHandle(getTarget(), effects); |
803 | transform::modifiesPayload(effects); |
804 | } |
805 | |
806 | namespace { |
807 | /// Test conversion pattern that replaces ops with the "replace_with_new_op" |
808 | /// attribute with "test.new_op". |
809 | class ReplaceWithNewOpConversion : public ConversionPattern { |
810 | public: |
811 | ReplaceWithNewOpConversion(TypeConverter &typeConverter, MLIRContext *context) |
812 | : ConversionPattern(typeConverter, RewritePattern::MatchAnyOpTypeTag(), |
813 | /*benefit=*/1, context) {} |
814 | |
815 | LogicalResult |
816 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
817 | ConversionPatternRewriter &rewriter) const override { |
818 | if (!op->hasAttr(name: "replace_with_new_op" )) |
819 | return failure(); |
820 | SmallVector<Type> newResultTypes; |
821 | if (failed(result: getTypeConverter()->convertTypes(types: op->getResultTypes(), |
822 | results&: newResultTypes))) |
823 | return failure(); |
824 | Operation *newOp = rewriter.create( |
825 | op->getLoc(), |
826 | OperationName("test.new_op" , op->getContext()).getIdentifier(), |
827 | operands, newResultTypes); |
828 | rewriter.replaceOp(op, newValues: newOp->getResults()); |
829 | return success(); |
830 | } |
831 | }; |
832 | } // namespace |
833 | |
834 | void mlir::test::ApplyTestConversionPatternsOp::populatePatterns( |
835 | TypeConverter &typeConverter, RewritePatternSet &patterns) { |
836 | patterns.insert<ReplaceWithNewOpConversion>(typeConverter, |
837 | patterns.getContext()); |
838 | } |
839 | |
840 | namespace { |
841 | /// Test type converter that converts tensor types to memref types. |
842 | class TestTypeConverter : public TypeConverter { |
843 | public: |
844 | TestTypeConverter() { |
845 | addConversion(callback: [](Type t) { return t; }); |
846 | addConversion(callback: [](RankedTensorType type) -> Type { |
847 | return MemRefType::get(type.getShape(), type.getElementType()); |
848 | }); |
849 | auto unrealizedCastConverter = [&](OpBuilder &builder, Type resultType, |
850 | ValueRange inputs, |
851 | Location loc) -> std::optional<Value> { |
852 | if (inputs.size() != 1) |
853 | return std::nullopt; |
854 | return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs) |
855 | .getResult(0); |
856 | }; |
857 | addSourceMaterialization(callback&: unrealizedCastConverter); |
858 | addTargetMaterialization(callback&: unrealizedCastConverter); |
859 | } |
860 | }; |
861 | } // namespace |
862 | |
863 | std::unique_ptr<::mlir::TypeConverter> |
864 | mlir::test::TestTypeConverterOp::getTypeConverter() { |
865 | return std::make_unique<TestTypeConverter>(); |
866 | } |
867 | |
868 | namespace { |
869 | /// Test extension of the Transform dialect. Registers additional ops and |
870 | /// declares PDL as dependent dialect since the additional ops are using PDL |
871 | /// types for operands and results. |
872 | class TestTransformDialectExtension |
873 | : public transform::TransformDialectExtension< |
874 | TestTransformDialectExtension> { |
875 | public: |
876 | using Base::Base; |
877 | |
878 | void init() { |
879 | declareDependentDialect<pdl::PDLDialect>(); |
880 | registerTransformOps<TestTransformOp, |
881 | TestTransformUnrestrictedOpNoInterface, |
882 | #define GET_OP_LIST |
883 | #include "TestTransformDialectExtension.cpp.inc" |
884 | >(); |
885 | registerTypes< |
886 | #define GET_TYPEDEF_LIST |
887 | #include "TestTransformDialectExtensionTypes.cpp.inc" |
888 | >(); |
889 | |
890 | auto verboseConstraint = [](PatternRewriter &rewriter, PDLResultList &, |
891 | ArrayRef<PDLValue> pdlValues) { |
892 | for (const PDLValue &pdlValue : pdlValues) { |
893 | if (Operation *op = pdlValue.dyn_cast<Operation *>()) { |
894 | op->emitWarning() << "from PDL constraint" ; |
895 | } |
896 | } |
897 | return success(); |
898 | }; |
899 | |
900 | addDialectDataInitializer<transform::PDLMatchHooks>( |
901 | [&](transform::PDLMatchHooks &hooks) { |
902 | llvm::StringMap<PDLConstraintFunction> constraints; |
903 | constraints.try_emplace("verbose_constraint" , verboseConstraint); |
904 | hooks.mergeInPDLMatchHooks(constraintFns: std::move(constraints)); |
905 | }); |
906 | } |
907 | }; |
908 | } // namespace |
909 | |
// These are automatically generated by ODS but are not used as the Transform
// dialect uses a different dispatch mechanism to support dialect extensions.
// They are forward-declared here with LLVM_ATTRIBUTE_UNUSED, ahead of the
// generated definitions included below, presumably so the compiler does not
// warn about these unused static functions.
LLVM_ATTRIBUTE_UNUSED static OptionalParseResult
generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value);
LLVM_ATTRIBUTE_UNUSED static LogicalResult
generatedTypePrinter(Type def, AsmPrinter &printer);

// ODS-generated definitions of the extension's types.
#define GET_TYPEDEF_CLASSES
#include "TestTransformDialectExtensionTypes.cpp.inc"

// ODS-generated definitions of the extension's ops.
#define GET_OP_CLASSES
#include "TestTransformDialectExtension.cpp.inc"
922 | |
923 | void ::test::registerTestTransformDialectExtension(DialectRegistry ®istry) { |
924 | registry.addExtensions<TestTransformDialectExtension>(); |
925 | } |
926 | |