1//===- TestTransformDialectExtension.cpp ----------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines an extension of the MLIR Transform dialect for testing
10// purposes.
11//
12//===----------------------------------------------------------------------===//
13
14#include "TestTransformDialectExtension.h"
15#include "TestTransformStateExtension.h"
16#include "mlir/Dialect/PDL/IR/PDL.h"
17#include "mlir/Dialect/Transform/IR/TransformDialect.h"
18#include "mlir/Dialect/Transform/IR/TransformOps.h"
19#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
20#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h"
21#include "mlir/IR/OpImplementation.h"
22#include "mlir/IR/PatternMatch.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/TypeSwitch.h"
25#include "llvm/Support/Compiler.h"
26#include "llvm/Support/raw_ostream.h"
27
28using namespace mlir;
29
30namespace {
31/// Simple transform op defined outside of the dialect. Just emits a remark when
32/// applied. This op is defined in C++ to test that C++ definitions also work
33/// for op injection into the Transform dialect.
34class TestTransformOp
35 : public Op<TestTransformOp, transform::TransformOpInterface::Trait,
36 MemoryEffectOpInterface::Trait> {
37public:
38 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformOp)
39
40 using Op::Op;
41
42 static ArrayRef<StringRef> getAttributeNames() { return {}; }
43
44 static constexpr llvm::StringLiteral getOperationName() {
45 return llvm::StringLiteral("transform.test_transform_op");
46 }
47
48 DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter,
49 transform::TransformResults &results,
50 transform::TransformState &state) {
51 InFlightDiagnostic remark = emitRemark() << "applying transformation";
52 if (Attribute message = getMessage())
53 remark << " " << message;
54
55 return DiagnosedSilenceableFailure::success();
56 }
57
58 Attribute getMessage() {
59 return getOperation()->getDiscardableAttr("message");
60 }
61
62 static ParseResult parse(OpAsmParser &parser, OperationState &state) {
63 StringAttr message;
64 OptionalParseResult result = parser.parseOptionalAttribute(message);
65 if (!result.has_value())
66 return success();
67
68 if (result.value().succeeded())
69 state.addAttribute("message", message);
70 return result.value();
71 }
72
73 void print(OpAsmPrinter &printer) {
74 if (getMessage())
75 printer << " " << getMessage();
76 }
77
78 // No side effects.
79 void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
80};
81
82/// A test op to exercise the verifier of the PossibleTopLevelTransformOpTrait
83/// in cases where it is attached to ops that do not comply with the trait
84/// requirements. This op cannot be defined in ODS because ODS generates strict
85/// verifiers that overalp with those in the trait and run earlier.
86class TestTransformUnrestrictedOpNoInterface
87 : public Op<TestTransformUnrestrictedOpNoInterface,
88 transform::PossibleTopLevelTransformOpTrait,
89 transform::TransformOpInterface::Trait,
90 MemoryEffectOpInterface::Trait> {
91public:
92 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
93 TestTransformUnrestrictedOpNoInterface)
94
95 using Op::Op;
96
97 static ArrayRef<StringRef> getAttributeNames() { return {}; }
98
99 static constexpr llvm::StringLiteral getOperationName() {
100 return llvm::StringLiteral(
101 "transform.test_transform_unrestricted_op_no_interface");
102 }
103
104 DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter,
105 transform::TransformResults &results,
106 transform::TransformState &state) {
107 return DiagnosedSilenceableFailure::success();
108 }
109
110 // No side effects.
111 void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
112};
113} // namespace
114
115DiagnosedSilenceableFailure
116mlir::test::TestProduceSelfHandleOrForwardOperandOp::apply(
117 transform::TransformRewriter &rewriter,
118 transform::TransformResults &results, transform::TransformState &state) {
119 if (getOperation()->getNumOperands() != 0) {
120 results.set(cast<OpResult>(getResult()),
121 {getOperation()->getOperand(0).getDefiningOp()});
122 } else {
123 results.set(cast<OpResult>(getResult()), {getOperation()});
124 }
125 return DiagnosedSilenceableFailure::success();
126}
127
128void mlir::test::TestProduceSelfHandleOrForwardOperandOp::getEffects(
129 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
130 if (getOperand())
131 transform::onlyReadsHandle(getOperand(), effects);
132 transform::producesHandle(getRes(), effects);
133}
134
135DiagnosedSilenceableFailure
136mlir::test::TestProduceValueHandleToSelfOperand::apply(
137 transform::TransformRewriter &rewriter,
138 transform::TransformResults &results, transform::TransformState &state) {
139 results.setValues(llvm::cast<OpResult>(getOut()), {getIn()});
140 return DiagnosedSilenceableFailure::success();
141}
142
143void mlir::test::TestProduceValueHandleToSelfOperand::getEffects(
144 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
145 transform::onlyReadsHandle(getIn(), effects);
146 transform::producesHandle(getOut(), effects);
147 transform::onlyReadsPayload(effects);
148}
149
150DiagnosedSilenceableFailure
151mlir::test::TestProduceValueHandleToResult::applyToOne(
152 transform::TransformRewriter &rewriter, Operation *target,
153 transform::ApplyToEachResultList &results,
154 transform::TransformState &state) {
155 if (target->getNumResults() <= getNumber())
156 return emitSilenceableError() << "payload has no result #" << getNumber();
157 results.push_back(target->getResult(getNumber()));
158 return DiagnosedSilenceableFailure::success();
159}
160
161void mlir::test::TestProduceValueHandleToResult::getEffects(
162 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
163 transform::onlyReadsHandle(getIn(), effects);
164 transform::producesHandle(getOut(), effects);
165 transform::onlyReadsPayload(effects);
166}
167
168DiagnosedSilenceableFailure
169mlir::test::TestProduceValueHandleToArgumentOfParentBlock::applyToOne(
170 transform::TransformRewriter &rewriter, Operation *target,
171 transform::ApplyToEachResultList &results,
172 transform::TransformState &state) {
173 if (!target->getBlock())
174 return emitSilenceableError() << "payload has no parent block";
175 if (target->getBlock()->getNumArguments() <= getNumber())
176 return emitSilenceableError()
177 << "parent of the payload has no argument #" << getNumber();
178 results.push_back(target->getBlock()->getArgument(getNumber()));
179 return DiagnosedSilenceableFailure::success();
180}
181
182void mlir::test::TestProduceValueHandleToArgumentOfParentBlock::getEffects(
183 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
184 transform::onlyReadsHandle(getIn(), effects);
185 transform::producesHandle(getOut(), effects);
186 transform::onlyReadsPayload(effects);
187}
188
189bool mlir::test::TestConsumeOperand::allowsRepeatedHandleOperands() {
190 return getAllowRepeatedHandles();
191}
192
193DiagnosedSilenceableFailure
194mlir::test::TestConsumeOperand::apply(transform::TransformRewriter &rewriter,
195 transform::TransformResults &results,
196 transform::TransformState &state) {
197 return DiagnosedSilenceableFailure::success();
198}
199
200void mlir::test::TestConsumeOperand::getEffects(
201 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
202 transform::consumesHandle(getOperand(), effects);
203 if (getSecondOperand())
204 transform::consumesHandle(getSecondOperand(), effects);
205 transform::modifiesPayload(effects);
206}
207
208DiagnosedSilenceableFailure mlir::test::TestConsumeOperandOfOpKindOrFail::apply(
209 transform::TransformRewriter &rewriter,
210 transform::TransformResults &results, transform::TransformState &state) {
211 auto payload = state.getPayloadOps(getOperand());
212 assert(llvm::hasSingleElement(payload) && "expected a single target op");
213 if ((*payload.begin())->getName().getStringRef() != getOpKind()) {
214 return emitSilenceableError()
215 << "op expected the operand to be associated a payload op of kind "
216 << getOpKind() << " got "
217 << (*payload.begin())->getName().getStringRef();
218 }
219
220 emitRemark() << "succeeded";
221 return DiagnosedSilenceableFailure::success();
222}
223
224void mlir::test::TestConsumeOperandOfOpKindOrFail::getEffects(
225 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
226 transform::consumesHandle(getOperand(), effects);
227 transform::modifiesPayload(effects);
228}
229
230DiagnosedSilenceableFailure
231mlir::test::TestSucceedIfOperandOfOpKind::matchOperation(
232 Operation *op, transform::TransformResults &results,
233 transform::TransformState &state) {
234 if (op->getName().getStringRef() != getOpKind()) {
235 return emitSilenceableError()
236 << "op expected the operand to be associated with a payload op of "
237 "kind "
238 << getOpKind() << " got " << op->getName().getStringRef();
239 }
240 return DiagnosedSilenceableFailure::success();
241}
242
243void mlir::test::TestSucceedIfOperandOfOpKind::getEffects(
244 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
245 transform::onlyReadsHandle(getOperand(), effects);
246 transform::onlyReadsPayload(effects);
247}
248
249DiagnosedSilenceableFailure mlir::test::TestAddTestExtensionOp::apply(
250 transform::TransformRewriter &rewriter,
251 transform::TransformResults &results, transform::TransformState &state) {
252 state.addExtension<TestTransformStateExtension>(getMessageAttr());
253 return DiagnosedSilenceableFailure::success();
254}
255
256DiagnosedSilenceableFailure
257mlir::test::TestCheckIfTestExtensionPresentOp::apply(
258 transform::TransformRewriter &rewriter,
259 transform::TransformResults &results, transform::TransformState &state) {
260 auto *extension = state.getExtension<TestTransformStateExtension>();
261 if (!extension) {
262 emitRemark() << "extension absent";
263 return DiagnosedSilenceableFailure::success();
264 }
265
266 InFlightDiagnostic diag = emitRemark()
267 << "extension present, " << extension->getMessage();
268 for (Operation *payload : state.getPayloadOps(getOperand())) {
269 diag.attachNote(payload->getLoc()) << "associated payload op";
270#ifndef NDEBUG
271 SmallVector<Value> handles;
272 assert(succeeded(state.getHandlesForPayloadOp(payload, handles)));
273 assert(llvm::is_contained(handles, getOperand()) &&
274 "inconsistent mapping between transform IR handles and payload IR "
275 "operations");
276#endif // NDEBUG
277 }
278
279 return DiagnosedSilenceableFailure::success();
280}
281
282void mlir::test::TestCheckIfTestExtensionPresentOp::getEffects(
283 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
284 transform::onlyReadsHandle(getOperand(), effects);
285 transform::onlyReadsPayload(effects);
286}
287
288DiagnosedSilenceableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply(
289 transform::TransformRewriter &rewriter,
290 transform::TransformResults &results, transform::TransformState &state) {
291 auto *extension = state.getExtension<TestTransformStateExtension>();
292 if (!extension)
293 return emitDefiniteFailure("TestTransformStateExtension missing");
294
295 if (failed(extension->updateMapping(
296 *state.getPayloadOps(getOperand()).begin(), getOperation())))
297 return DiagnosedSilenceableFailure::definiteFailure();
298 if (getNumResults() > 0)
299 results.set(cast<OpResult>(getResult(0)), {getOperation()});
300 return DiagnosedSilenceableFailure::success();
301}
302
303void mlir::test::TestRemapOperandPayloadToSelfOp::getEffects(
304 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
305 transform::onlyReadsHandle(getOperand(), effects);
306 transform::producesHandle(getOut(), effects);
307 transform::onlyReadsPayload(effects);
308}
309
310DiagnosedSilenceableFailure mlir::test::TestRemoveTestExtensionOp::apply(
311 transform::TransformRewriter &rewriter,
312 transform::TransformResults &results, transform::TransformState &state) {
313 state.removeExtension<TestTransformStateExtension>();
314 return DiagnosedSilenceableFailure::success();
315}
316
317DiagnosedSilenceableFailure mlir::test::TestReversePayloadOpsOp::apply(
318 transform::TransformRewriter &rewriter,
319 transform::TransformResults &results, transform::TransformState &state) {
320 auto payloadOps = state.getPayloadOps(getTarget());
321 auto reversedOps = llvm::to_vector(llvm::reverse(payloadOps));
322 results.set(llvm::cast<OpResult>(getResult()), reversedOps);
323 return DiagnosedSilenceableFailure::success();
324}
325
326DiagnosedSilenceableFailure mlir::test::TestTransformOpWithRegions::apply(
327 transform::TransformRewriter &rewriter,
328 transform::TransformResults &results, transform::TransformState &state) {
329 return DiagnosedSilenceableFailure::success();
330}
331
332void mlir::test::TestTransformOpWithRegions::getEffects(
333 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
334
335DiagnosedSilenceableFailure
336mlir::test::TestBranchingTransformOpTerminator::apply(
337 transform::TransformRewriter &rewriter,
338 transform::TransformResults &results, transform::TransformState &state) {
339 return DiagnosedSilenceableFailure::success();
340}
341
342void mlir::test::TestBranchingTransformOpTerminator::getEffects(
343 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
344
345DiagnosedSilenceableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply(
346 transform::TransformRewriter &rewriter,
347 transform::TransformResults &results, transform::TransformState &state) {
348 emitRemark() << getRemark();
349 for (Operation *op : state.getPayloadOps(getTarget()))
350 rewriter.eraseOp(op);
351
352 if (getFailAfterErase())
353 return emitSilenceableError() << "silenceable error";
354 return DiagnosedSilenceableFailure::success();
355}
356
357void mlir::test::TestEmitRemarkAndEraseOperandOp::getEffects(
358 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
359 transform::consumesHandle(getTarget(), effects);
360 transform::modifiesPayload(effects);
361}
362
363DiagnosedSilenceableFailure mlir::test::TestWrongNumberOfResultsOp::applyToOne(
364 transform::TransformRewriter &rewriter, Operation *target,
365 transform::ApplyToEachResultList &results,
366 transform::TransformState &state) {
367 OperationState opState(target->getLoc(), "foo");
368 results.push_back(OpBuilder(target).create(opState));
369 return DiagnosedSilenceableFailure::success();
370}
371
372DiagnosedSilenceableFailure
373mlir::test::TestWrongNumberOfMultiResultsOp::applyToOne(
374 transform::TransformRewriter &rewriter, Operation *target,
375 transform::ApplyToEachResultList &results,
376 transform::TransformState &state) {
377 static int count = 0;
378 if (count++ == 0) {
379 OperationState opState(target->getLoc(), "foo");
380 results.push_back(OpBuilder(target).create(opState));
381 }
382 return DiagnosedSilenceableFailure::success();
383}
384
385DiagnosedSilenceableFailure
386mlir::test::TestCorrectNumberOfMultiResultsOp::applyToOne(
387 transform::TransformRewriter &rewriter, Operation *target,
388 transform::ApplyToEachResultList &results,
389 transform::TransformState &state) {
390 OperationState opState(target->getLoc(), "foo");
391 results.push_back(OpBuilder(target).create(opState));
392 results.push_back(OpBuilder(target).create(opState));
393 return DiagnosedSilenceableFailure::success();
394}
395
396DiagnosedSilenceableFailure
397mlir::test::TestMixedNullAndNonNullResultsOp::applyToOne(
398 transform::TransformRewriter &rewriter, Operation *target,
399 transform::ApplyToEachResultList &results,
400 transform::TransformState &state) {
401 OperationState opState(target->getLoc(), "foo");
402 results.push_back(nullptr);
403 results.push_back(OpBuilder(target).create(opState));
404 return DiagnosedSilenceableFailure::success();
405}
406
407DiagnosedSilenceableFailure
408mlir::test::TestMixedSuccessAndSilenceableOp::applyToOne(
409 transform::TransformRewriter &rewriter, Operation *target,
410 transform::ApplyToEachResultList &results,
411 transform::TransformState &state) {
412 if (target->hasAttr("target_me"))
413 return DiagnosedSilenceableFailure::success();
414 return emitDefaultSilenceableFailure(target);
415}
416
417DiagnosedSilenceableFailure
418mlir::test::TestCopyPayloadOp::apply(transform::TransformRewriter &rewriter,
419 transform::TransformResults &results,
420 transform::TransformState &state) {
421 results.set(llvm::cast<OpResult>(getCopy()),
422 state.getPayloadOps(getHandle()));
423 return DiagnosedSilenceableFailure::success();
424}
425
426void mlir::test::TestCopyPayloadOp::getEffects(
427 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
428 transform::onlyReadsHandle(getHandle(), effects);
429 transform::producesHandle(getCopy(), effects);
430 transform::onlyReadsPayload(effects);
431}
432
433DiagnosedSilenceableFailure mlir::transform::TestDialectOpType::checkPayload(
434 Location loc, ArrayRef<Operation *> payload) const {
435 if (payload.empty())
436 return DiagnosedSilenceableFailure::success();
437
438 for (Operation *op : payload) {
439 if (op->getName().getDialectNamespace() != "test") {
440 return emitSilenceableError(loc) << "expected the payload operation to "
441 "belong to the 'test' dialect";
442 }
443 }
444
445 return DiagnosedSilenceableFailure::success();
446}
447
448DiagnosedSilenceableFailure mlir::transform::TestDialectParamType::checkPayload(
449 Location loc, ArrayRef<Attribute> payload) const {
450 for (Attribute attr : payload) {
451 auto integerAttr = llvm::dyn_cast<IntegerAttr>(attr);
452 if (integerAttr && integerAttr.getType().isSignlessInteger(32))
453 continue;
454 return emitSilenceableError(loc)
455 << "expected the parameter to be a i32 integer attribute";
456 }
457
458 return DiagnosedSilenceableFailure::success();
459}
460
461void mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::getEffects(
462 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
463 transform::onlyReadsHandle(getTarget(), effects);
464}
465
466DiagnosedSilenceableFailure
467mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::apply(
468 transform::TransformRewriter &rewriter,
469 transform::TransformResults &results, transform::TransformState &state) {
470 int64_t count = 0;
471 for (Operation *op : state.getPayloadOps(getTarget())) {
472 op->walk([&](Operation *nested) {
473 SmallVector<Value> handles;
474 (void)state.getHandlesForPayloadOp(nested, handles);
475 count += handles.size();
476 });
477 }
478 emitRemark() << count << " handles nested under";
479 return DiagnosedSilenceableFailure::success();
480}
481
482DiagnosedSilenceableFailure
483mlir::test::TestAddToParamOp::apply(transform::TransformRewriter &rewriter,
484 transform::TransformResults &results,
485 transform::TransformState &state) {
486 SmallVector<uint32_t> values(/*Size=*/1, /*Value=*/0);
487 if (Value param = getParam()) {
488 values = llvm::to_vector(
489 llvm::map_range(state.getParams(param), [](Attribute attr) -> uint32_t {
490 return llvm::cast<IntegerAttr>(attr).getValue().getLimitedValue(
491 UINT32_MAX);
492 }));
493 }
494
495 Builder builder(getContext());
496 SmallVector<Attribute> result = llvm::to_vector(
497 llvm::map_range(values, [this, &builder](uint32_t value) -> Attribute {
498 return builder.getI32IntegerAttr(value + getAddendum());
499 }));
500 results.setParams(llvm::cast<OpResult>(getResult()), result);
501 return DiagnosedSilenceableFailure::success();
502}
503
504DiagnosedSilenceableFailure
505mlir::test::TestProduceParamWithNumberOfTestOps::apply(
506 transform::TransformRewriter &rewriter,
507 transform::TransformResults &results, transform::TransformState &state) {
508 Builder builder(getContext());
509 SmallVector<Attribute> result = llvm::to_vector(
510 llvm::map_range(state.getPayloadOps(getHandle()),
511 [&builder](Operation *payload) -> Attribute {
512 int32_t count = 0;
513 payload->walk([&count](Operation *op) {
514 if (op->getName().getDialectNamespace() == "test")
515 ++count;
516 });
517 return builder.getI32IntegerAttr(count);
518 }));
519 results.setParams(llvm::cast<OpResult>(getResult()), result);
520 return DiagnosedSilenceableFailure::success();
521}
522
523DiagnosedSilenceableFailure
524mlir::test::TestProduceParamOp::apply(transform::TransformRewriter &rewriter,
525 transform::TransformResults &results,
526 transform::TransformState &state) {
527 results.setParams(llvm::cast<OpResult>(getResult()), getAttr());
528 return DiagnosedSilenceableFailure::success();
529}
530
531void mlir::test::TestProduceTransformParamOrForwardOperandOp::getEffects(
532 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
533 transform::onlyReadsHandle(getIn(), effects);
534 transform::producesHandle(getOut(), effects);
535 transform::producesHandle(getParam(), effects);
536}
537
538DiagnosedSilenceableFailure
539mlir::test::TestProduceTransformParamOrForwardOperandOp::applyToOne(
540 transform::TransformRewriter &rewriter, Operation *target,
541 ::transform::ApplyToEachResultList &results,
542 ::transform::TransformState &state) {
543 Builder builder(getContext());
544 if (getFirstResultIsParam()) {
545 results.push_back(builder.getI64IntegerAttr(0));
546 } else if (getFirstResultIsNull()) {
547 results.push_back(nullptr);
548 } else {
549 results.push_back(*state.getPayloadOps(getIn()).begin());
550 }
551
552 if (getSecondResultIsHandle()) {
553 results.push_back(*state.getPayloadOps(getIn()).begin());
554 } else {
555 results.push_back(builder.getI64IntegerAttr(42));
556 }
557
558 return DiagnosedSilenceableFailure::success();
559}
560
561void mlir::test::TestProduceNullPayloadOp::getEffects(
562 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
563 transform::producesHandle(getOut(), effects);
564}
565
566DiagnosedSilenceableFailure mlir::test::TestProduceNullPayloadOp::apply(
567 transform::TransformRewriter &rewriter,
568 transform::TransformResults &results, transform::TransformState &state) {
569 SmallVector<Operation *, 1> null({nullptr});
570 results.set(llvm::cast<OpResult>(getOut()), null);
571 return DiagnosedSilenceableFailure::success();
572}
573
574DiagnosedSilenceableFailure mlir::test::TestProduceEmptyPayloadOp::apply(
575 transform::TransformRewriter &rewriter,
576 transform::TransformResults &results, transform::TransformState &state) {
577 results.set(cast<OpResult>(getOut()), {});
578 return DiagnosedSilenceableFailure::success();
579}
580
581void mlir::test::TestProduceNullParamOp::getEffects(
582 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
583 transform::producesHandle(getOut(), effects);
584}
585
586DiagnosedSilenceableFailure mlir::test::TestProduceNullParamOp::apply(
587 transform::TransformRewriter &rewriter,
588 transform::TransformResults &results, transform::TransformState &state) {
589 results.setParams(llvm::cast<OpResult>(getOut()), Attribute());
590 return DiagnosedSilenceableFailure::success();
591}
592
593void mlir::test::TestProduceNullValueOp::getEffects(
594 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
595 transform::producesHandle(getOut(), effects);
596}
597
598DiagnosedSilenceableFailure mlir::test::TestProduceNullValueOp::apply(
599 transform::TransformRewriter &rewriter,
600 transform::TransformResults &results, transform::TransformState &state) {
601 results.setValues(llvm::cast<OpResult>(getOut()), {Value()});
602 return DiagnosedSilenceableFailure::success();
603}
604
605void mlir::test::TestRequiredMemoryEffectsOp::getEffects(
606 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
607 if (getHasOperandEffect())
608 transform::consumesHandle(getIn(), effects);
609
610 if (getHasResultEffect())
611 transform::producesHandle(getOut(), effects);
612 else
613 transform::onlyReadsHandle(getOut(), effects);
614
615 if (getModifiesPayload())
616 transform::modifiesPayload(effects);
617}
618
619DiagnosedSilenceableFailure mlir::test::TestRequiredMemoryEffectsOp::apply(
620 transform::TransformRewriter &rewriter,
621 transform::TransformResults &results, transform::TransformState &state) {
622 results.set(llvm::cast<OpResult>(getOut()), state.getPayloadOps(getIn()));
623 return DiagnosedSilenceableFailure::success();
624}
625
626void mlir::test::TestTrackedRewriteOp::getEffects(
627 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
628 transform::onlyReadsHandle(getIn(), effects);
629 transform::modifiesPayload(effects);
630}
631
632void mlir::test::TestDummyPayloadOp::getEffects(
633 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
634 for (OpResult result : getResults())
635 transform::producesHandle(result, effects);
636}
637
638LogicalResult mlir::test::TestDummyPayloadOp::verify() {
639 if (getFailToVerify())
640 return emitOpError() << "fail_to_verify is set";
641 return success();
642}
643
644DiagnosedSilenceableFailure
645mlir::test::TestTrackedRewriteOp::apply(transform::TransformRewriter &rewriter,
646 transform::TransformResults &results,
647 transform::TransformState &state) {
648 int64_t numIterations = 0;
649
650 // `getPayloadOps` returns an iterator that skips ops that are erased in the
651 // loop body. Replacement ops are not enumerated.
652 for (Operation *op : state.getPayloadOps(getIn())) {
653 ++numIterations;
654 (void)op;
655
656 // Erase all payload ops. The outer loop should have only one iteration.
657 for (Operation *op : state.getPayloadOps(getIn())) {
658 rewriter.setInsertionPoint(op);
659 if (op->hasAttr("erase_me")) {
660 rewriter.eraseOp(op);
661 continue;
662 }
663 if (!op->hasAttr("replace_me")) {
664 continue;
665 }
666
667 SmallVector<NamedAttribute> attributes;
668 attributes.emplace_back(rewriter.getStringAttr("new_op"),
669 rewriter.getUnitAttr());
670 OperationState opState(op->getLoc(), op->getName().getIdentifier(),
671 /*operands=*/ValueRange(),
672 /*types=*/op->getResultTypes(), attributes);
673 Operation *newOp = rewriter.create(opState);
674 rewriter.replaceOp(op, newOp->getResults());
675 }
676 }
677
678 emitRemark() << numIterations << " iterations";
679 return DiagnosedSilenceableFailure::success();
680}
681
682namespace {
683// Test pattern to replace an operation with a new op.
684class ReplaceWithNewOp : public RewritePattern {
685public:
686 ReplaceWithNewOp(MLIRContext *context)
687 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
688
689 LogicalResult matchAndRewrite(Operation *op,
690 PatternRewriter &rewriter) const override {
691 auto newName = op->getAttrOfType<StringAttr>("replace_with_new_op");
692 if (!newName)
693 return failure();
694 Operation *newOp = rewriter.create(
695 op->getLoc(), OperationName(newName, op->getContext()).getIdentifier(),
696 op->getOperands(), op->getResultTypes());
697 rewriter.replaceOp(op, newValues: newOp->getResults());
698 return success();
699 }
700};
701
702// Test pattern to erase an operation.
703class EraseOp : public RewritePattern {
704public:
705 EraseOp(MLIRContext *context)
706 : RewritePattern("test.erase_op", /*benefit=*/1, context) {}
707 LogicalResult matchAndRewrite(Operation *op,
708 PatternRewriter &rewriter) const override {
709 rewriter.eraseOp(op);
710 return success();
711 }
712};
713} // namespace
714
715void mlir::test::ApplyTestPatternsOp::populatePatterns(
716 RewritePatternSet &patterns) {
717 patterns.insert<ReplaceWithNewOp, EraseOp>(patterns.getContext());
718}
719
720void mlir::test::TestReEnterRegionOp::getEffects(
721 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
722 transform::consumesHandle(getOperands(), effects);
723 transform::modifiesPayload(effects);
724}
725
726DiagnosedSilenceableFailure
727mlir::test::TestReEnterRegionOp::apply(transform::TransformRewriter &rewriter,
728 transform::TransformResults &results,
729 transform::TransformState &state) {
730
731 SmallVector<SmallVector<transform::MappedValue>> mappings;
732 for (BlockArgument arg : getBody().front().getArguments()) {
733 mappings.emplace_back(llvm::to_vector(llvm::map_range(
734 state.getPayloadOps(getOperand(arg.getArgNumber())),
735 [](Operation *op) -> transform::MappedValue { return op; })));
736 }
737
738 for (int i = 0; i < 4; ++i) {
739 auto scope = state.make_region_scope(getBody());
740 for (BlockArgument arg : getBody().front().getArguments()) {
741 if (failed(state.mapBlockArgument(arg, mappings[arg.getArgNumber()])))
742 return DiagnosedSilenceableFailure::definiteFailure();
743 }
744 for (Operation &op : getBody().front().without_terminator()) {
745 DiagnosedSilenceableFailure diag =
746 state.applyTransform(cast<transform::TransformOpInterface>(op));
747 if (!diag.succeeded())
748 return diag;
749 }
750 }
751 return DiagnosedSilenceableFailure::success();
752}
753
754LogicalResult mlir::test::TestReEnterRegionOp::verify() {
755 if (getNumOperands() != getBody().front().getNumArguments()) {
756 return emitOpError() << "expects as many operands as block arguments";
757 }
758 return success();
759}
760
761DiagnosedSilenceableFailure mlir::test::TestNotifyPayloadOpReplacedOp::apply(
762 transform::TransformRewriter &rewriter,
763 transform::TransformResults &results, transform::TransformState &state) {
764 auto originalOps = state.getPayloadOps(getOriginal());
765 auto replacementOps = state.getPayloadOps(getReplacement());
766 if (llvm::range_size(originalOps) != llvm::range_size(replacementOps))
767 return emitSilenceableError() << "expected same number of original and "
768 "replacement payload operations";
769 for (const auto &[original, replacement] :
770 llvm::zip(originalOps, replacementOps)) {
771 if (failed(
772 rewriter.notifyPayloadOperationReplaced(original, replacement))) {
773 auto diag = emitSilenceableError()
774 << "unable to replace payload op in transform mapping";
775 diag.attachNote(original->getLoc()) << "original payload op";
776 diag.attachNote(replacement->getLoc()) << "replacement payload op";
777 return diag;
778 }
779 }
780 return DiagnosedSilenceableFailure::success();
781}
782
783void mlir::test::TestNotifyPayloadOpReplacedOp::getEffects(
784 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
785 transform::onlyReadsHandle(getOriginal(), effects);
786 transform::onlyReadsHandle(getReplacement(), effects);
787}
788
789DiagnosedSilenceableFailure mlir::test::TestProduceInvalidIR::applyToOne(
790 transform::TransformRewriter &rewriter, Operation *target,
791 transform::ApplyToEachResultList &results,
792 transform::TransformState &state) {
793 // Provide some IR that does not verify.
794 rewriter.setInsertionPointToStart(&target->getRegion(0).front());
795 rewriter.create<TestDummyPayloadOp>(target->getLoc(), TypeRange(),
796 ValueRange(), /*failToVerify=*/true);
797 return DiagnosedSilenceableFailure::success();
798}
799
800void mlir::test::TestProduceInvalidIR::getEffects(
801 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
802 transform::onlyReadsHandle(getTarget(), effects);
803 transform::modifiesPayload(effects);
804}
805
806namespace {
807/// Test conversion pattern that replaces ops with the "replace_with_new_op"
808/// attribute with "test.new_op".
809class ReplaceWithNewOpConversion : public ConversionPattern {
810public:
811 ReplaceWithNewOpConversion(TypeConverter &typeConverter, MLIRContext *context)
812 : ConversionPattern(typeConverter, RewritePattern::MatchAnyOpTypeTag(),
813 /*benefit=*/1, context) {}
814
815 LogicalResult
816 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
817 ConversionPatternRewriter &rewriter) const override {
818 if (!op->hasAttr(name: "replace_with_new_op"))
819 return failure();
820 SmallVector<Type> newResultTypes;
821 if (failed(result: getTypeConverter()->convertTypes(types: op->getResultTypes(),
822 results&: newResultTypes)))
823 return failure();
824 Operation *newOp = rewriter.create(
825 op->getLoc(),
826 OperationName("test.new_op", op->getContext()).getIdentifier(),
827 operands, newResultTypes);
828 rewriter.replaceOp(op, newValues: newOp->getResults());
829 return success();
830 }
831};
832} // namespace
833
834void mlir::test::ApplyTestConversionPatternsOp::populatePatterns(
835 TypeConverter &typeConverter, RewritePatternSet &patterns) {
836 patterns.insert<ReplaceWithNewOpConversion>(typeConverter,
837 patterns.getContext());
838}
839
840namespace {
841/// Test type converter that converts tensor types to memref types.
842class TestTypeConverter : public TypeConverter {
843public:
844 TestTypeConverter() {
845 addConversion(callback: [](Type t) { return t; });
846 addConversion(callback: [](RankedTensorType type) -> Type {
847 return MemRefType::get(type.getShape(), type.getElementType());
848 });
849 auto unrealizedCastConverter = [&](OpBuilder &builder, Type resultType,
850 ValueRange inputs,
851 Location loc) -> std::optional<Value> {
852 if (inputs.size() != 1)
853 return std::nullopt;
854 return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
855 .getResult(0);
856 };
857 addSourceMaterialization(callback&: unrealizedCastConverter);
858 addTargetMaterialization(callback&: unrealizedCastConverter);
859 }
860};
861} // namespace
862
863std::unique_ptr<::mlir::TypeConverter>
864mlir::test::TestTypeConverterOp::getTypeConverter() {
865 return std::make_unique<TestTypeConverter>();
866}
867
868namespace {
869/// Test extension of the Transform dialect. Registers additional ops and
870/// declares PDL as dependent dialect since the additional ops are using PDL
871/// types for operands and results.
872class TestTransformDialectExtension
873 : public transform::TransformDialectExtension<
874 TestTransformDialectExtension> {
875public:
876 using Base::Base;
877
878 void init() {
879 declareDependentDialect<pdl::PDLDialect>();
880 registerTransformOps<TestTransformOp,
881 TestTransformUnrestrictedOpNoInterface,
882#define GET_OP_LIST
883#include "TestTransformDialectExtension.cpp.inc"
884 >();
885 registerTypes<
886#define GET_TYPEDEF_LIST
887#include "TestTransformDialectExtensionTypes.cpp.inc"
888 >();
889
890 auto verboseConstraint = [](PatternRewriter &rewriter, PDLResultList &,
891 ArrayRef<PDLValue> pdlValues) {
892 for (const PDLValue &pdlValue : pdlValues) {
893 if (Operation *op = pdlValue.dyn_cast<Operation *>()) {
894 op->emitWarning() << "from PDL constraint";
895 }
896 }
897 return success();
898 };
899
900 addDialectDataInitializer<transform::PDLMatchHooks>(
901 [&](transform::PDLMatchHooks &hooks) {
902 llvm::StringMap<PDLConstraintFunction> constraints;
903 constraints.try_emplace("verbose_constraint", verboseConstraint);
904 hooks.mergeInPDLMatchHooks(constraintFns: std::move(constraints));
905 });
906 }
907};
908} // namespace
909
// These are automatically generated by ODS but are not used as the Transform
// dialect uses a different dispatch mechanism to support dialect extensions.
// They are forward-declared with LLVM_ATTRIBUTE_UNUSED before the generated
// type definitions are included below, presumably to silence unused-function
// warnings for the ODS-emitted static definitions.
LLVM_ATTRIBUTE_UNUSED static OptionalParseResult
generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value);
LLVM_ATTRIBUTE_UNUSED static LogicalResult
generatedTypePrinter(Type def, AsmPrinter &printer);

// ODS-generated class definitions for the extension's types.
#define GET_TYPEDEF_CLASSES
#include "TestTransformDialectExtensionTypes.cpp.inc"

// ODS-generated class definitions for the extension's ops.
#define GET_OP_CLASSES
#include "TestTransformDialectExtension.cpp.inc"
922
923void ::test::registerTestTransformDialectExtension(DialectRegistry &registry) {
924 registry.addExtensions<TestTransformDialectExtension>();
925}
926

source code of mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp