1//===- TestTransformDialectExtension.cpp ----------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines an extension of the MLIR Transform dialect for testing
10// purposes.
11//
12//===----------------------------------------------------------------------===//
13
14#include "TestTransformDialectExtension.h"
15#include "TestTransformStateExtension.h"
16#include "mlir/Dialect/PDL/IR/PDL.h"
17#include "mlir/Dialect/Transform/IR/TransformDialect.h"
18#include "mlir/Dialect/Transform/IR/TransformOps.h"
19#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
20#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h"
21#include "mlir/IR/OpImplementation.h"
22#include "mlir/IR/PatternMatch.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/TypeSwitch.h"
25#include "llvm/Support/Compiler.h"
26#include "llvm/Support/raw_ostream.h"
27
28using namespace mlir;
29
30namespace {
31/// Simple transform op defined outside of the dialect. Just emits a remark when
32/// applied. This op is defined in C++ to test that C++ definitions also work
33/// for op injection into the Transform dialect.
34class TestTransformOp
35 : public Op<TestTransformOp, transform::TransformOpInterface::Trait,
36 MemoryEffectOpInterface::Trait> {
37public:
38 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformOp)
39
40 using Op::Op;
41
42 static ArrayRef<StringRef> getAttributeNames() { return {}; }
43
44 static constexpr llvm::StringLiteral getOperationName() {
45 return llvm::StringLiteral("transform.test_transform_op");
46 }
47
48 DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter,
49 transform::TransformResults &results,
50 transform::TransformState &state) {
51 InFlightDiagnostic remark = emitRemark() << "applying transformation";
52 if (Attribute message = getMessage())
53 remark << " " << message;
54
55 return DiagnosedSilenceableFailure::success();
56 }
57
58 Attribute getMessage() {
59 return getOperation()->getDiscardableAttr("message");
60 }
61
62 static ParseResult parse(OpAsmParser &parser, OperationState &state) {
63 StringAttr message;
64 OptionalParseResult result = parser.parseOptionalAttribute(message);
65 if (!result.has_value())
66 return success();
67
68 if (result.value().succeeded())
69 state.addAttribute("message", message);
70 return result.value();
71 }
72
73 void print(OpAsmPrinter &printer) {
74 if (getMessage())
75 printer << " " << getMessage();
76 }
77
78 // No side effects.
79 void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
80};
81
82/// A test op to exercise the verifier of the PossibleTopLevelTransformOpTrait
83/// in cases where it is attached to ops that do not comply with the trait
/// requirements. This op cannot be defined in ODS because ODS generates strict
/// verifiers that overlap with those in the trait and run earlier.
86class TestTransformUnrestrictedOpNoInterface
87 : public Op<TestTransformUnrestrictedOpNoInterface,
88 transform::PossibleTopLevelTransformOpTrait,
89 transform::TransformOpInterface::Trait,
90 MemoryEffectOpInterface::Trait> {
91public:
92 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
93 TestTransformUnrestrictedOpNoInterface)
94
95 using Op::Op;
96
97 static ArrayRef<StringRef> getAttributeNames() { return {}; }
98
99 static constexpr llvm::StringLiteral getOperationName() {
100 return llvm::StringLiteral(
101 "transform.test_transform_unrestricted_op_no_interface");
102 }
103
104 DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter,
105 transform::TransformResults &results,
106 transform::TransformState &state) {
107 return DiagnosedSilenceableFailure::success();
108 }
109
110 // No side effects.
111 void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
112};
113} // namespace
114
115DiagnosedSilenceableFailure
116mlir::test::TestProduceSelfHandleOrForwardOperandOp::apply(
117 transform::TransformRewriter &rewriter,
118 transform::TransformResults &results, transform::TransformState &state) {
119 if (getOperation()->getNumOperands() != 0) {
120 results.set(cast<OpResult>(getResult()),
121 {getOperation()->getOperand(0).getDefiningOp()});
122 } else {
123 results.set(cast<OpResult>(getResult()), {getOperation()});
124 }
125 return DiagnosedSilenceableFailure::success();
126}
127
128void mlir::test::TestProduceSelfHandleOrForwardOperandOp::getEffects(
129 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
130 if (getOperand())
131 transform::onlyReadsHandle(getOperandMutable(), effects);
132 transform::producesHandle(getOperation()->getOpResults(), effects);
133}
134
135DiagnosedSilenceableFailure
136mlir::test::TestProduceValueHandleToSelfOperand::apply(
137 transform::TransformRewriter &rewriter,
138 transform::TransformResults &results, transform::TransformState &state) {
139 results.setValues(llvm::cast<OpResult>(getOut()), {getIn()});
140 return DiagnosedSilenceableFailure::success();
141}
142
143void mlir::test::TestProduceValueHandleToSelfOperand::getEffects(
144 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
145 transform::onlyReadsHandle(getInMutable(), effects);
146 transform::producesHandle(getOperation()->getOpResults(), effects);
147 transform::onlyReadsPayload(effects);
148}
149
150DiagnosedSilenceableFailure
151mlir::test::TestProduceValueHandleToResult::applyToOne(
152 transform::TransformRewriter &rewriter, Operation *target,
153 transform::ApplyToEachResultList &results,
154 transform::TransformState &state) {
155 if (target->getNumResults() <= getNumber())
156 return emitSilenceableError() << "payload has no result #" << getNumber();
157 results.push_back(target->getResult(getNumber()));
158 return DiagnosedSilenceableFailure::success();
159}
160
161void mlir::test::TestProduceValueHandleToResult::getEffects(
162 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
163 transform::onlyReadsHandle(getInMutable(), effects);
164 transform::producesHandle(getOperation()->getOpResults(), effects);
165 transform::onlyReadsPayload(effects);
166}
167
168DiagnosedSilenceableFailure
169mlir::test::TestProduceValueHandleToArgumentOfParentBlock::applyToOne(
170 transform::TransformRewriter &rewriter, Operation *target,
171 transform::ApplyToEachResultList &results,
172 transform::TransformState &state) {
173 if (!target->getBlock())
174 return emitSilenceableError() << "payload has no parent block";
175 if (target->getBlock()->getNumArguments() <= getNumber())
176 return emitSilenceableError()
177 << "parent of the payload has no argument #" << getNumber();
178 results.push_back(target->getBlock()->getArgument(getNumber()));
179 return DiagnosedSilenceableFailure::success();
180}
181
182void mlir::test::TestProduceValueHandleToArgumentOfParentBlock::getEffects(
183 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
184 transform::onlyReadsHandle(getInMutable(), effects);
185 transform::producesHandle(getOperation()->getOpResults(), effects);
186 transform::onlyReadsPayload(effects);
187}
188
189bool mlir::test::TestConsumeOperand::allowsRepeatedHandleOperands() {
190 return getAllowRepeatedHandles();
191}
192
193DiagnosedSilenceableFailure
194mlir::test::TestConsumeOperand::apply(transform::TransformRewriter &rewriter,
195 transform::TransformResults &results,
196 transform::TransformState &state) {
197 return DiagnosedSilenceableFailure::success();
198}
199
200void mlir::test::TestConsumeOperand::getEffects(
201 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
202 transform::consumesHandle(getOperation()->getOpOperands(), effects);
203 if (getSecondOperand())
204 transform::consumesHandle(getSecondOperandMutable(), effects);
205 transform::modifiesPayload(effects);
206}
207
208DiagnosedSilenceableFailure mlir::test::TestConsumeOperandOfOpKindOrFail::apply(
209 transform::TransformRewriter &rewriter,
210 transform::TransformResults &results, transform::TransformState &state) {
211 auto payload = state.getPayloadOps(getOperand());
212 assert(llvm::hasSingleElement(payload) && "expected a single target op");
213 if ((*payload.begin())->getName().getStringRef() != getOpKind()) {
214 return emitSilenceableError()
215 << "op expected the operand to be associated a payload op of kind "
216 << getOpKind() << " got "
217 << (*payload.begin())->getName().getStringRef();
218 }
219
220 emitRemark() << "succeeded";
221 return DiagnosedSilenceableFailure::success();
222}
223
224void mlir::test::TestConsumeOperandOfOpKindOrFail::getEffects(
225 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
226 transform::consumesHandle(getOperation()->getOpOperands(), effects);
227 transform::modifiesPayload(effects);
228}
229
230DiagnosedSilenceableFailure
231mlir::test::TestSucceedIfOperandOfOpKind::matchOperation(
232 Operation *op, transform::TransformResults &results,
233 transform::TransformState &state) {
234 if (op->getName().getStringRef() != getOpKind()) {
235 return emitSilenceableError()
236 << "op expected the operand to be associated with a payload op of "
237 "kind "
238 << getOpKind() << " got " << op->getName().getStringRef();
239 }
240 return DiagnosedSilenceableFailure::success();
241}
242
243void mlir::test::TestSucceedIfOperandOfOpKind::getEffects(
244 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
245 transform::onlyReadsHandle(getOperation()->getOpOperands(), effects);
246 transform::onlyReadsPayload(effects);
247}
248
249DiagnosedSilenceableFailure mlir::test::TestAddTestExtensionOp::apply(
250 transform::TransformRewriter &rewriter,
251 transform::TransformResults &results, transform::TransformState &state) {
252 state.addExtension<TestTransformStateExtension>(getMessageAttr());
253 return DiagnosedSilenceableFailure::success();
254}
255
256DiagnosedSilenceableFailure
257mlir::test::TestCheckIfTestExtensionPresentOp::apply(
258 transform::TransformRewriter &rewriter,
259 transform::TransformResults &results, transform::TransformState &state) {
260 auto *extension = state.getExtension<TestTransformStateExtension>();
261 if (!extension) {
262 emitRemark() << "extension absent";
263 return DiagnosedSilenceableFailure::success();
264 }
265
266 InFlightDiagnostic diag = emitRemark()
267 << "extension present, " << extension->getMessage();
268 for (Operation *payload : state.getPayloadOps(getOperand())) {
269 diag.attachNote(payload->getLoc()) << "associated payload op";
270#ifndef NDEBUG
271 SmallVector<Value> handles;
272 assert(succeeded(state.getHandlesForPayloadOp(payload, handles)));
273 assert(llvm::is_contained(handles, getOperand()) &&
274 "inconsistent mapping between transform IR handles and payload IR "
275 "operations");
276#endif // NDEBUG
277 }
278
279 return DiagnosedSilenceableFailure::success();
280}
281
282void mlir::test::TestCheckIfTestExtensionPresentOp::getEffects(
283 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
284 transform::onlyReadsHandle(getOperation()->getOpOperands(), effects);
285 transform::onlyReadsPayload(effects);
286}
287
288DiagnosedSilenceableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply(
289 transform::TransformRewriter &rewriter,
290 transform::TransformResults &results, transform::TransformState &state) {
291 auto *extension = state.getExtension<TestTransformStateExtension>();
292 if (!extension)
293 return emitDefiniteFailure("TestTransformStateExtension missing");
294
295 if (failed(extension->updateMapping(
296 *state.getPayloadOps(getOperand()).begin(), getOperation())))
297 return DiagnosedSilenceableFailure::definiteFailure();
298 if (getNumResults() > 0)
299 results.set(cast<OpResult>(getResult(0)), {getOperation()});
300 return DiagnosedSilenceableFailure::success();
301}
302
303void mlir::test::TestRemapOperandPayloadToSelfOp::getEffects(
304 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
305 transform::onlyReadsHandle(getOperation()->getOpOperands(), effects);
306 transform::producesHandle(getOperation()->getOpResults(), effects);
307 transform::onlyReadsPayload(effects);
308}
309
310DiagnosedSilenceableFailure mlir::test::TestRemoveTestExtensionOp::apply(
311 transform::TransformRewriter &rewriter,
312 transform::TransformResults &results, transform::TransformState &state) {
313 state.removeExtension<TestTransformStateExtension>();
314 return DiagnosedSilenceableFailure::success();
315}
316
317DiagnosedSilenceableFailure mlir::test::TestReversePayloadOpsOp::apply(
318 transform::TransformRewriter &rewriter,
319 transform::TransformResults &results, transform::TransformState &state) {
320 auto payloadOps = state.getPayloadOps(getTarget());
321 auto reversedOps = llvm::to_vector(llvm::reverse(payloadOps));
322 results.set(llvm::cast<OpResult>(getResult()), reversedOps);
323 return DiagnosedSilenceableFailure::success();
324}
325
326DiagnosedSilenceableFailure mlir::test::TestTransformOpWithRegions::apply(
327 transform::TransformRewriter &rewriter,
328 transform::TransformResults &results, transform::TransformState &state) {
329 return DiagnosedSilenceableFailure::success();
330}
331
332void mlir::test::TestTransformOpWithRegions::getEffects(
333 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
334
335DiagnosedSilenceableFailure
336mlir::test::TestBranchingTransformOpTerminator::apply(
337 transform::TransformRewriter &rewriter,
338 transform::TransformResults &results, transform::TransformState &state) {
339 return DiagnosedSilenceableFailure::success();
340}
341
342void mlir::test::TestBranchingTransformOpTerminator::getEffects(
343 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
344
345DiagnosedSilenceableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply(
346 transform::TransformRewriter &rewriter,
347 transform::TransformResults &results, transform::TransformState &state) {
348 emitRemark() << getRemark();
349 for (Operation *op : state.getPayloadOps(getTarget())) {
350 if (!op->getUses().empty())
351 return emitSilenceableError() << "cannot erase an op that has uses";
352 rewriter.eraseOp(op);
353 }
354
355 if (getFailAfterErase())
356 return emitSilenceableError() << "silenceable error";
357 return DiagnosedSilenceableFailure::success();
358}
359
360void mlir::test::TestEmitRemarkAndEraseOperandOp::getEffects(
361 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
362 transform::consumesHandle(getTargetMutable(), effects);
363 transform::modifiesPayload(effects);
364}
365
366DiagnosedSilenceableFailure mlir::test::TestWrongNumberOfResultsOp::applyToOne(
367 transform::TransformRewriter &rewriter, Operation *target,
368 transform::ApplyToEachResultList &results,
369 transform::TransformState &state) {
370 OperationState opState(target->getLoc(), "foo");
371 results.push_back(OpBuilder(target).create(opState));
372 return DiagnosedSilenceableFailure::success();
373}
374
375DiagnosedSilenceableFailure
376mlir::test::TestWrongNumberOfMultiResultsOp::applyToOne(
377 transform::TransformRewriter &rewriter, Operation *target,
378 transform::ApplyToEachResultList &results,
379 transform::TransformState &state) {
380 static int count = 0;
381 if (count++ == 0) {
382 OperationState opState(target->getLoc(), "foo");
383 results.push_back(OpBuilder(target).create(opState));
384 }
385 return DiagnosedSilenceableFailure::success();
386}
387
388DiagnosedSilenceableFailure
389mlir::test::TestCorrectNumberOfMultiResultsOp::applyToOne(
390 transform::TransformRewriter &rewriter, Operation *target,
391 transform::ApplyToEachResultList &results,
392 transform::TransformState &state) {
393 OperationState opState(target->getLoc(), "foo");
394 results.push_back(OpBuilder(target).create(opState));
395 results.push_back(OpBuilder(target).create(opState));
396 return DiagnosedSilenceableFailure::success();
397}
398
399DiagnosedSilenceableFailure
400mlir::test::TestMixedNullAndNonNullResultsOp::applyToOne(
401 transform::TransformRewriter &rewriter, Operation *target,
402 transform::ApplyToEachResultList &results,
403 transform::TransformState &state) {
404 OperationState opState(target->getLoc(), "foo");
405 results.push_back(nullptr);
406 results.push_back(OpBuilder(target).create(opState));
407 return DiagnosedSilenceableFailure::success();
408}
409
410DiagnosedSilenceableFailure
411mlir::test::TestMixedSuccessAndSilenceableOp::applyToOne(
412 transform::TransformRewriter &rewriter, Operation *target,
413 transform::ApplyToEachResultList &results,
414 transform::TransformState &state) {
415 if (target->hasAttr("target_me"))
416 return DiagnosedSilenceableFailure::success();
417 return emitDefaultSilenceableFailure(target);
418}
419
420DiagnosedSilenceableFailure
421mlir::test::TestCopyPayloadOp::apply(transform::TransformRewriter &rewriter,
422 transform::TransformResults &results,
423 transform::TransformState &state) {
424 results.set(llvm::cast<OpResult>(getCopy()),
425 state.getPayloadOps(getHandle()));
426 return DiagnosedSilenceableFailure::success();
427}
428
429void mlir::test::TestCopyPayloadOp::getEffects(
430 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
431 transform::onlyReadsHandle(getHandleMutable(), effects);
432 transform::producesHandle(getOperation()->getOpResults(), effects);
433 transform::onlyReadsPayload(effects);
434}
435
436DiagnosedSilenceableFailure mlir::transform::TestDialectOpType::checkPayload(
437 Location loc, ArrayRef<Operation *> payload) const {
438 if (payload.empty())
439 return DiagnosedSilenceableFailure::success();
440
441 for (Operation *op : payload) {
442 if (op->getName().getDialectNamespace() != "test") {
443 return emitSilenceableError(loc) << "expected the payload operation to "
444 "belong to the 'test' dialect";
445 }
446 }
447
448 return DiagnosedSilenceableFailure::success();
449}
450
451DiagnosedSilenceableFailure mlir::transform::TestDialectParamType::checkPayload(
452 Location loc, ArrayRef<Attribute> payload) const {
453 for (Attribute attr : payload) {
454 auto integerAttr = llvm::dyn_cast<IntegerAttr>(attr);
455 if (integerAttr && integerAttr.getType().isSignlessInteger(32))
456 continue;
457 return emitSilenceableError(loc)
458 << "expected the parameter to be a i32 integer attribute";
459 }
460
461 return DiagnosedSilenceableFailure::success();
462}
463
464void mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::getEffects(
465 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
466 transform::onlyReadsHandle(getTargetMutable(), effects);
467}
468
469DiagnosedSilenceableFailure
470mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::apply(
471 transform::TransformRewriter &rewriter,
472 transform::TransformResults &results, transform::TransformState &state) {
473 int64_t count = 0;
474 for (Operation *op : state.getPayloadOps(getTarget())) {
475 op->walk([&](Operation *nested) {
476 SmallVector<Value> handles;
477 (void)state.getHandlesForPayloadOp(nested, handles);
478 count += handles.size();
479 });
480 }
481 emitRemark() << count << " handles nested under";
482 return DiagnosedSilenceableFailure::success();
483}
484
485DiagnosedSilenceableFailure
486mlir::test::TestAddToParamOp::apply(transform::TransformRewriter &rewriter,
487 transform::TransformResults &results,
488 transform::TransformState &state) {
489 SmallVector<uint32_t> values(/*Size=*/1, /*Value=*/0);
490 if (Value param = getParam()) {
491 values = llvm::to_vector(
492 llvm::map_range(state.getParams(param), [](Attribute attr) -> uint32_t {
493 return llvm::cast<IntegerAttr>(attr).getValue().getLimitedValue(
494 UINT32_MAX);
495 }));
496 }
497
498 Builder builder(getContext());
499 SmallVector<Attribute> result = llvm::to_vector(
500 llvm::map_range(values, [this, &builder](uint32_t value) -> Attribute {
501 return builder.getI32IntegerAttr(value + getAddendum());
502 }));
503 results.setParams(llvm::cast<OpResult>(getResult()), result);
504 return DiagnosedSilenceableFailure::success();
505}
506
507DiagnosedSilenceableFailure
508mlir::test::TestProduceParamWithNumberOfTestOps::apply(
509 transform::TransformRewriter &rewriter,
510 transform::TransformResults &results, transform::TransformState &state) {
511 Builder builder(getContext());
512 SmallVector<Attribute> result = llvm::to_vector(
513 llvm::map_range(state.getPayloadOps(getHandle()),
514 [&builder](Operation *payload) -> Attribute {
515 int32_t count = 0;
516 payload->walk([&count](Operation *op) {
517 if (op->getName().getDialectNamespace() == "test")
518 ++count;
519 });
520 return builder.getI32IntegerAttr(count);
521 }));
522 results.setParams(llvm::cast<OpResult>(getResult()), result);
523 return DiagnosedSilenceableFailure::success();
524}
525
526DiagnosedSilenceableFailure
527mlir::test::TestProduceParamOp::apply(transform::TransformRewriter &rewriter,
528 transform::TransformResults &results,
529 transform::TransformState &state) {
530 results.setParams(llvm::cast<OpResult>(getResult()), getAttr());
531 return DiagnosedSilenceableFailure::success();
532}
533
534void mlir::test::TestProduceTransformParamOrForwardOperandOp::getEffects(
535 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
536 transform::onlyReadsHandle(getInMutable(), effects);
537 transform::producesHandle(getOperation()->getOpResults(), effects);
538}
539
540DiagnosedSilenceableFailure
541mlir::test::TestProduceTransformParamOrForwardOperandOp::applyToOne(
542 transform::TransformRewriter &rewriter, Operation *target,
543 ::transform::ApplyToEachResultList &results,
544 ::transform::TransformState &state) {
545 Builder builder(getContext());
546 if (getFirstResultIsParam()) {
547 results.push_back(builder.getI64IntegerAttr(0));
548 } else if (getFirstResultIsNull()) {
549 results.push_back(nullptr);
550 } else {
551 results.push_back(*state.getPayloadOps(getIn()).begin());
552 }
553
554 if (getSecondResultIsHandle()) {
555 results.push_back(*state.getPayloadOps(getIn()).begin());
556 } else {
557 results.push_back(builder.getI64IntegerAttr(42));
558 }
559
560 return DiagnosedSilenceableFailure::success();
561}
562
563void mlir::test::TestProduceNullPayloadOp::getEffects(
564 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
565 transform::producesHandle(getOperation()->getOpResults(), effects);
566}
567
568DiagnosedSilenceableFailure mlir::test::TestProduceNullPayloadOp::apply(
569 transform::TransformRewriter &rewriter,
570 transform::TransformResults &results, transform::TransformState &state) {
571 SmallVector<Operation *, 1> null({nullptr});
572 results.set(llvm::cast<OpResult>(getOut()), null);
573 return DiagnosedSilenceableFailure::success();
574}
575
576DiagnosedSilenceableFailure mlir::test::TestProduceEmptyPayloadOp::apply(
577 transform::TransformRewriter &rewriter,
578 transform::TransformResults &results, transform::TransformState &state) {
579 results.set(cast<OpResult>(getOut()), {});
580 return DiagnosedSilenceableFailure::success();
581}
582
583void mlir::test::TestProduceNullParamOp::getEffects(
584 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
585 transform::producesHandle(getOperation()->getOpResults(), effects);
586}
587
588DiagnosedSilenceableFailure mlir::test::TestProduceNullParamOp::apply(
589 transform::TransformRewriter &rewriter,
590 transform::TransformResults &results, transform::TransformState &state) {
591 results.setParams(llvm::cast<OpResult>(getOut()), Attribute());
592 return DiagnosedSilenceableFailure::success();
593}
594
595void mlir::test::TestProduceNullValueOp::getEffects(
596 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
597 transform::producesHandle(getOperation()->getOpResults(), effects);
598}
599
600DiagnosedSilenceableFailure mlir::test::TestProduceNullValueOp::apply(
601 transform::TransformRewriter &rewriter,
602 transform::TransformResults &results, transform::TransformState &state) {
603 results.setValues(llvm::cast<OpResult>(getOut()), {Value()});
604 return DiagnosedSilenceableFailure::success();
605}
606
607void mlir::test::TestRequiredMemoryEffectsOp::getEffects(
608 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
609 if (getHasOperandEffect())
610 transform::consumesHandle(getInMutable(), effects);
611
612 if (getHasResultEffect()) {
613 transform::producesHandle(getOperation()->getOpResults(), effects);
614 } else {
615 effects.emplace_back(MemoryEffects::Read::get(),
616 llvm::cast<OpResult>(getOut()),
617 transform::TransformMappingResource::get());
618 }
619
620 if (getModifiesPayload())
621 transform::modifiesPayload(effects);
622}
623
624DiagnosedSilenceableFailure mlir::test::TestRequiredMemoryEffectsOp::apply(
625 transform::TransformRewriter &rewriter,
626 transform::TransformResults &results, transform::TransformState &state) {
627 results.set(llvm::cast<OpResult>(getOut()), state.getPayloadOps(getIn()));
628 return DiagnosedSilenceableFailure::success();
629}
630
631void mlir::test::TestTrackedRewriteOp::getEffects(
632 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
633 transform::onlyReadsHandle(getInMutable(), effects);
634 transform::modifiesPayload(effects);
635}
636
637void mlir::test::TestDummyPayloadOp::getEffects(
638 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
639 transform::producesHandle(getOperation()->getOpResults(), effects);
640}
641
642LogicalResult mlir::test::TestDummyPayloadOp::verify() {
643 if (getFailToVerify())
644 return emitOpError() << "fail_to_verify is set";
645 return success();
646}
647
648DiagnosedSilenceableFailure
649mlir::test::TestTrackedRewriteOp::apply(transform::TransformRewriter &rewriter,
650 transform::TransformResults &results,
651 transform::TransformState &state) {
652 int64_t numIterations = 0;
653
654 // `getPayloadOps` returns an iterator that skips ops that are erased in the
655 // loop body. Replacement ops are not enumerated.
656 for (Operation *op : state.getPayloadOps(getIn())) {
657 ++numIterations;
658 (void)op;
659
660 // Erase all payload ops. The outer loop should have only one iteration.
661 for (Operation *op : state.getPayloadOps(getIn())) {
662 rewriter.setInsertionPoint(op);
663 if (op->hasAttr("erase_me")) {
664 rewriter.eraseOp(op);
665 continue;
666 }
667 if (!op->hasAttr("replace_me")) {
668 continue;
669 }
670
671 SmallVector<NamedAttribute> attributes;
672 attributes.emplace_back(rewriter.getStringAttr("new_op"),
673 rewriter.getUnitAttr());
674 OperationState opState(op->getLoc(), op->getName().getIdentifier(),
675 /*operands=*/ValueRange(),
676 /*types=*/op->getResultTypes(), attributes);
677 Operation *newOp = rewriter.create(opState);
678 rewriter.replaceOp(op, newOp->getResults());
679 }
680 }
681
682 emitRemark() << numIterations << " iterations";
683 return DiagnosedSilenceableFailure::success();
684}
685
686namespace {
687// Test pattern to replace an operation with a new op.
688class ReplaceWithNewOp : public RewritePattern {
689public:
690 ReplaceWithNewOp(MLIRContext *context)
691 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
692
693 LogicalResult matchAndRewrite(Operation *op,
694 PatternRewriter &rewriter) const override {
695 auto newName = op->getAttrOfType<StringAttr>("replace_with_new_op");
696 if (!newName)
697 return failure();
698 Operation *newOp = rewriter.create(
699 op->getLoc(), OperationName(newName, op->getContext()).getIdentifier(),
700 op->getOperands(), op->getResultTypes());
701 rewriter.replaceOp(op, newValues: newOp->getResults());
702 return success();
703 }
704};
705
706// Test pattern to erase an operation.
707class EraseOp : public RewritePattern {
708public:
709 EraseOp(MLIRContext *context)
710 : RewritePattern("test.erase_op", /*benefit=*/1, context) {}
711 LogicalResult matchAndRewrite(Operation *op,
712 PatternRewriter &rewriter) const override {
713 rewriter.eraseOp(op);
714 return success();
715 }
716};
717} // namespace
718
719void mlir::test::ApplyTestPatternsOp::populatePatterns(
720 RewritePatternSet &patterns) {
721 patterns.insert<ReplaceWithNewOp, EraseOp>(patterns.getContext());
722}
723
724void mlir::test::TestReEnterRegionOp::getEffects(
725 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
726 transform::consumesHandle(getOperation()->getOpOperands(), effects);
727 transform::modifiesPayload(effects);
728}
729
730DiagnosedSilenceableFailure
731mlir::test::TestReEnterRegionOp::apply(transform::TransformRewriter &rewriter,
732 transform::TransformResults &results,
733 transform::TransformState &state) {
734
735 SmallVector<SmallVector<transform::MappedValue>> mappings;
736 for (BlockArgument arg : getBody().front().getArguments()) {
737 mappings.emplace_back(llvm::to_vector(llvm::map_range(
738 state.getPayloadOps(getOperand(arg.getArgNumber())),
739 [](Operation *op) -> transform::MappedValue { return op; })));
740 }
741
742 for (int i = 0; i < 4; ++i) {
743 auto scope = state.make_region_scope(getBody());
744 for (BlockArgument arg : getBody().front().getArguments()) {
745 if (failed(state.mapBlockArgument(arg, mappings[arg.getArgNumber()])))
746 return DiagnosedSilenceableFailure::definiteFailure();
747 }
748 for (Operation &op : getBody().front().without_terminator()) {
749 DiagnosedSilenceableFailure diag =
750 state.applyTransform(cast<transform::TransformOpInterface>(op));
751 if (!diag.succeeded())
752 return diag;
753 }
754 }
755 return DiagnosedSilenceableFailure::success();
756}
757
758LogicalResult mlir::test::TestReEnterRegionOp::verify() {
759 if (getNumOperands() != getBody().front().getNumArguments()) {
760 return emitOpError() << "expects as many operands as block arguments";
761 }
762 return success();
763}
764
765DiagnosedSilenceableFailure mlir::test::TestNotifyPayloadOpReplacedOp::apply(
766 transform::TransformRewriter &rewriter,
767 transform::TransformResults &results, transform::TransformState &state) {
768 auto originalOps = state.getPayloadOps(getOriginal());
769 auto replacementOps = state.getPayloadOps(getReplacement());
770 if (llvm::range_size(originalOps) != llvm::range_size(replacementOps))
771 return emitSilenceableError() << "expected same number of original and "
772 "replacement payload operations";
773 for (const auto &[original, replacement] :
774 llvm::zip(originalOps, replacementOps)) {
775 if (failed(
776 rewriter.notifyPayloadOperationReplaced(original, replacement))) {
777 auto diag = emitSilenceableError()
778 << "unable to replace payload op in transform mapping";
779 diag.attachNote(original->getLoc()) << "original payload op";
780 diag.attachNote(replacement->getLoc()) << "replacement payload op";
781 return diag;
782 }
783 }
784 return DiagnosedSilenceableFailure::success();
785}
786
787void mlir::test::TestNotifyPayloadOpReplacedOp::getEffects(
788 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
789 transform::onlyReadsHandle(getOriginalMutable(), effects);
790 transform::onlyReadsHandle(getReplacementMutable(), effects);
791}
792
793DiagnosedSilenceableFailure mlir::test::TestProduceInvalidIR::applyToOne(
794 transform::TransformRewriter &rewriter, Operation *target,
795 transform::ApplyToEachResultList &results,
796 transform::TransformState &state) {
797 // Provide some IR that does not verify.
798 rewriter.setInsertionPointToStart(&target->getRegion(0).front());
799 rewriter.create<TestDummyPayloadOp>(target->getLoc(), TypeRange(),
800 ValueRange(), /*failToVerify=*/true);
801 return DiagnosedSilenceableFailure::success();
802}
803
804void mlir::test::TestProduceInvalidIR::getEffects(
805 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
806 transform::onlyReadsHandle(getTargetMutable(), effects);
807 transform::modifiesPayload(effects);
808}
809
810DiagnosedSilenceableFailure mlir::test::TestInitializerExtensionOp::apply(
811 transform::TransformRewriter &rewriter,
812 transform::TransformResults &results, transform::TransformState &state) {
813 std::string opName =
814 this->getOperationName().str() + "_" + getTypeAttr().str();
815 TransformStateInitializerExtension *initExt =
816 state.getExtension<TransformStateInitializerExtension>();
817 if (!initExt) {
818 emitRemark() << "\nSpecified extension not found, adding a new one!\n";
819 SmallVector<std::string> opCollection = {opName};
820 state.addExtension<TransformStateInitializerExtension>(1, opCollection);
821 } else {
822 initExt->setNumOp(initExt->getNumOp() + 1);
823 initExt->pushRegisteredOps(opName);
824 InFlightDiagnostic diag = emitRemark()
825 << "Number of currently registered op: "
826 << initExt->getNumOp() << "\n"
827 << initExt->printMessage() << "\n";
828 }
829 return DiagnosedSilenceableFailure::success();
830}
831
832namespace {
833/// Test conversion pattern that replaces ops with the "replace_with_new_op"
834/// attribute with "test.new_op".
835class ReplaceWithNewOpConversion : public ConversionPattern {
836public:
837 ReplaceWithNewOpConversion(TypeConverter &typeConverter, MLIRContext *context)
838 : ConversionPattern(typeConverter, RewritePattern::MatchAnyOpTypeTag(),
839 /*benefit=*/1, context) {}
840
841 LogicalResult
842 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
843 ConversionPatternRewriter &rewriter) const override {
844 if (!op->hasAttr(name: "replace_with_new_op"))
845 return failure();
846 SmallVector<Type> newResultTypes;
847 if (failed(Result: getTypeConverter()->convertTypes(types: op->getResultTypes(),
848 results&: newResultTypes)))
849 return failure();
850 Operation *newOp = rewriter.create(
851 op->getLoc(),
852 OperationName("test.new_op", op->getContext()).getIdentifier(),
853 operands, newResultTypes);
854 rewriter.replaceOp(op, newValues: newOp->getResults());
855 return success();
856 }
857};
858} // namespace
859
860void mlir::test::ApplyTestConversionPatternsOp::populatePatterns(
861 TypeConverter &typeConverter, RewritePatternSet &patterns) {
862 patterns.insert<ReplaceWithNewOpConversion>(typeConverter,
863 patterns.getContext());
864}
865
866namespace {
867/// Test type converter that converts tensor types to memref types.
868class TestTypeConverter : public TypeConverter {
869public:
870 TestTypeConverter() {
871 addConversion(callback: [](Type t) { return t; });
872 addConversion(callback: [](RankedTensorType type) -> Type {
873 return MemRefType::get(type.getShape(), type.getElementType());
874 });
875 auto unrealizedCastConverter = [&](OpBuilder &builder, Type resultType,
876 ValueRange inputs,
877 Location loc) -> Value {
878 if (inputs.size() != 1)
879 return Value();
880 return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
881 .getResult(0);
882 };
883 addSourceMaterialization(callback&: unrealizedCastConverter);
884 addTargetMaterialization(callback&: unrealizedCastConverter);
885 }
886};
887} // namespace
888
889std::unique_ptr<::mlir::TypeConverter>
890mlir::test::TestTypeConverterOp::getTypeConverter() {
891 return std::make_unique<TestTypeConverter>();
892}
893
894namespace {
895/// Test extension of the Transform dialect. Registers additional ops and
896/// declares PDL as dependent dialect since the additional ops are using PDL
897/// types for operands and results.
898class TestTransformDialectExtension
899 : public transform::TransformDialectExtension<
900 TestTransformDialectExtension> {
901public:
902 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformDialectExtension)
903
904 using Base::Base;
905
906 void init() {
907 declareDependentDialect<pdl::PDLDialect>();
908 registerTransformOps<TestTransformOp,
909 TestTransformUnrestrictedOpNoInterface,
910#define GET_OP_LIST
911#include "TestTransformDialectExtension.cpp.inc"
912 >();
913 registerTypes<
914#define GET_TYPEDEF_LIST
915#include "TestTransformDialectExtensionTypes.cpp.inc"
916 >();
917
918 auto verboseConstraint = [](PatternRewriter &rewriter, PDLResultList &,
919 ArrayRef<PDLValue> pdlValues) {
920 for (const PDLValue &pdlValue : pdlValues) {
921 if (Operation *op = pdlValue.dyn_cast<Operation *>()) {
922 op->emitWarning() << "from PDL constraint";
923 }
924 }
925 return success();
926 };
927
928 addDialectDataInitializer<transform::PDLMatchHooks>(
929 [&](transform::PDLMatchHooks &hooks) {
930 llvm::StringMap<PDLConstraintFunction> constraints;
931 constraints.try_emplace("verbose_constraint", verboseConstraint);
932 hooks.mergeInPDLMatchHooks(constraintFns: std::move(constraints));
933 });
934 }
935};
936} // namespace
937
938// These are automatically generated by ODS but are not used as the Transform
939// dialect uses a different dispatch mechanism to support dialect extensions.
940LLVM_ATTRIBUTE_UNUSED static OptionalParseResult
941generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value);
942LLVM_ATTRIBUTE_UNUSED static LogicalResult
943generatedTypePrinter(Type def, AsmPrinter &printer);
944
945#define GET_TYPEDEF_CLASSES
946#include "TestTransformDialectExtensionTypes.cpp.inc"
947
948#define GET_OP_CLASSES
949#include "TestTransformDialectExtension.cpp.inc"
950
951void ::test::registerTestTransformDialectExtension(DialectRegistry &registry) {
952 registry.addExtensions<TestTransformDialectExtension>();
953}
954

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp