1 | //===- TestTransformDialectExtension.cpp ----------------------------------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines an extension of the MLIR Transform dialect for testing |
10 | // purposes. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "TestTransformDialectExtension.h" |
15 | #include "TestTransformStateExtension.h" |
16 | #include "mlir/Dialect/PDL/IR/PDL.h" |
17 | #include "mlir/Dialect/Transform/IR/TransformDialect.h" |
18 | #include "mlir/Dialect/Transform/IR/TransformOps.h" |
19 | #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" |
20 | #include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h" |
21 | #include "mlir/IR/OpImplementation.h" |
22 | #include "mlir/IR/PatternMatch.h" |
23 | #include "llvm/ADT/STLExtras.h" |
24 | #include "llvm/ADT/TypeSwitch.h" |
25 | #include "llvm/Support/Compiler.h" |
26 | #include "llvm/Support/raw_ostream.h" |
27 | |
28 | using namespace mlir; |
29 | |
30 | namespace { |
31 | /// Simple transform op defined outside of the dialect. Just emits a remark when |
32 | /// applied. This op is defined in C++ to test that C++ definitions also work |
33 | /// for op injection into the Transform dialect. |
34 | class TestTransformOp |
35 | : public Op<TestTransformOp, transform::TransformOpInterface::Trait, |
36 | MemoryEffectOpInterface::Trait> { |
37 | public: |
38 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformOp) |
39 | |
40 | using Op::Op; |
41 | |
42 | static ArrayRef<StringRef> getAttributeNames() { return {}; } |
43 | |
44 | static constexpr llvm::StringLiteral getOperationName() { |
45 | return llvm::StringLiteral("transform.test_transform_op"); |
46 | } |
47 | |
48 | DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter, |
49 | transform::TransformResults &results, |
50 | transform::TransformState &state) { |
51 | InFlightDiagnostic remark = emitRemark() << "applying transformation"; |
52 | if (Attribute message = getMessage()) |
53 | remark << " "<< message; |
54 | |
55 | return DiagnosedSilenceableFailure::success(); |
56 | } |
57 | |
58 | Attribute getMessage() { |
59 | return getOperation()->getDiscardableAttr("message"); |
60 | } |
61 | |
62 | static ParseResult parse(OpAsmParser &parser, OperationState &state) { |
63 | StringAttr message; |
64 | OptionalParseResult result = parser.parseOptionalAttribute(message); |
65 | if (!result.has_value()) |
66 | return success(); |
67 | |
68 | if (result.value().succeeded()) |
69 | state.addAttribute("message", message); |
70 | return result.value(); |
71 | } |
72 | |
73 | void print(OpAsmPrinter &printer) { |
74 | if (getMessage()) |
75 | printer << " "<< getMessage(); |
76 | } |
77 | |
78 | // No side effects. |
79 | void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {} |
80 | }; |
81 | |
82 | /// A test op to exercise the verifier of the PossibleTopLevelTransformOpTrait |
83 | /// in cases where it is attached to ops that do not comply with the trait |
84 | /// requirements. This op cannot be defined in ODS because ODS generates strict |
85 | /// verifiers that overalp with those in the trait and run earlier. |
86 | class TestTransformUnrestrictedOpNoInterface |
87 | : public Op<TestTransformUnrestrictedOpNoInterface, |
88 | transform::PossibleTopLevelTransformOpTrait, |
89 | transform::TransformOpInterface::Trait, |
90 | MemoryEffectOpInterface::Trait> { |
91 | public: |
92 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( |
93 | TestTransformUnrestrictedOpNoInterface) |
94 | |
95 | using Op::Op; |
96 | |
97 | static ArrayRef<StringRef> getAttributeNames() { return {}; } |
98 | |
99 | static constexpr llvm::StringLiteral getOperationName() { |
100 | return llvm::StringLiteral( |
101 | "transform.test_transform_unrestricted_op_no_interface"); |
102 | } |
103 | |
104 | DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter, |
105 | transform::TransformResults &results, |
106 | transform::TransformState &state) { |
107 | return DiagnosedSilenceableFailure::success(); |
108 | } |
109 | |
110 | // No side effects. |
111 | void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {} |
112 | }; |
113 | } // namespace |
114 | |
115 | DiagnosedSilenceableFailure |
116 | mlir::test::TestProduceSelfHandleOrForwardOperandOp::apply( |
117 | transform::TransformRewriter &rewriter, |
118 | transform::TransformResults &results, transform::TransformState &state) { |
119 | if (getOperation()->getNumOperands() != 0) { |
120 | results.set(cast<OpResult>(getResult()), |
121 | {getOperation()->getOperand(0).getDefiningOp()}); |
122 | } else { |
123 | results.set(cast<OpResult>(getResult()), {getOperation()}); |
124 | } |
125 | return DiagnosedSilenceableFailure::success(); |
126 | } |
127 | |
128 | void mlir::test::TestProduceSelfHandleOrForwardOperandOp::getEffects( |
129 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
130 | if (getOperand()) |
131 | transform::onlyReadsHandle(getOperandMutable(), effects); |
132 | transform::producesHandle(getOperation()->getOpResults(), effects); |
133 | } |
134 | |
135 | DiagnosedSilenceableFailure |
136 | mlir::test::TestProduceValueHandleToSelfOperand::apply( |
137 | transform::TransformRewriter &rewriter, |
138 | transform::TransformResults &results, transform::TransformState &state) { |
139 | results.setValues(llvm::cast<OpResult>(getOut()), {getIn()}); |
140 | return DiagnosedSilenceableFailure::success(); |
141 | } |
142 | |
143 | void mlir::test::TestProduceValueHandleToSelfOperand::getEffects( |
144 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
145 | transform::onlyReadsHandle(getInMutable(), effects); |
146 | transform::producesHandle(getOperation()->getOpResults(), effects); |
147 | transform::onlyReadsPayload(effects); |
148 | } |
149 | |
150 | DiagnosedSilenceableFailure |
151 | mlir::test::TestProduceValueHandleToResult::applyToOne( |
152 | transform::TransformRewriter &rewriter, Operation *target, |
153 | transform::ApplyToEachResultList &results, |
154 | transform::TransformState &state) { |
155 | if (target->getNumResults() <= getNumber()) |
156 | return emitSilenceableError() << "payload has no result #"<< getNumber(); |
157 | results.push_back(target->getResult(getNumber())); |
158 | return DiagnosedSilenceableFailure::success(); |
159 | } |
160 | |
161 | void mlir::test::TestProduceValueHandleToResult::getEffects( |
162 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
163 | transform::onlyReadsHandle(getInMutable(), effects); |
164 | transform::producesHandle(getOperation()->getOpResults(), effects); |
165 | transform::onlyReadsPayload(effects); |
166 | } |
167 | |
168 | DiagnosedSilenceableFailure |
169 | mlir::test::TestProduceValueHandleToArgumentOfParentBlock::applyToOne( |
170 | transform::TransformRewriter &rewriter, Operation *target, |
171 | transform::ApplyToEachResultList &results, |
172 | transform::TransformState &state) { |
173 | if (!target->getBlock()) |
174 | return emitSilenceableError() << "payload has no parent block"; |
175 | if (target->getBlock()->getNumArguments() <= getNumber()) |
176 | return emitSilenceableError() |
177 | << "parent of the payload has no argument #"<< getNumber(); |
178 | results.push_back(target->getBlock()->getArgument(getNumber())); |
179 | return DiagnosedSilenceableFailure::success(); |
180 | } |
181 | |
182 | void mlir::test::TestProduceValueHandleToArgumentOfParentBlock::getEffects( |
183 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
184 | transform::onlyReadsHandle(getInMutable(), effects); |
185 | transform::producesHandle(getOperation()->getOpResults(), effects); |
186 | transform::onlyReadsPayload(effects); |
187 | } |
188 | |
189 | bool mlir::test::TestConsumeOperand::allowsRepeatedHandleOperands() { |
190 | return getAllowRepeatedHandles(); |
191 | } |
192 | |
193 | DiagnosedSilenceableFailure |
194 | mlir::test::TestConsumeOperand::apply(transform::TransformRewriter &rewriter, |
195 | transform::TransformResults &results, |
196 | transform::TransformState &state) { |
197 | return DiagnosedSilenceableFailure::success(); |
198 | } |
199 | |
200 | void mlir::test::TestConsumeOperand::getEffects( |
201 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
202 | transform::consumesHandle(getOperation()->getOpOperands(), effects); |
203 | if (getSecondOperand()) |
204 | transform::consumesHandle(getSecondOperandMutable(), effects); |
205 | transform::modifiesPayload(effects); |
206 | } |
207 | |
208 | DiagnosedSilenceableFailure mlir::test::TestConsumeOperandOfOpKindOrFail::apply( |
209 | transform::TransformRewriter &rewriter, |
210 | transform::TransformResults &results, transform::TransformState &state) { |
211 | auto payload = state.getPayloadOps(getOperand()); |
212 | assert(llvm::hasSingleElement(payload) && "expected a single target op"); |
213 | if ((*payload.begin())->getName().getStringRef() != getOpKind()) { |
214 | return emitSilenceableError() |
215 | << "op expected the operand to be associated a payload op of kind " |
216 | << getOpKind() << " got " |
217 | << (*payload.begin())->getName().getStringRef(); |
218 | } |
219 | |
220 | emitRemark() << "succeeded"; |
221 | return DiagnosedSilenceableFailure::success(); |
222 | } |
223 | |
224 | void mlir::test::TestConsumeOperandOfOpKindOrFail::getEffects( |
225 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
226 | transform::consumesHandle(getOperation()->getOpOperands(), effects); |
227 | transform::modifiesPayload(effects); |
228 | } |
229 | |
230 | DiagnosedSilenceableFailure |
231 | mlir::test::TestSucceedIfOperandOfOpKind::matchOperation( |
232 | Operation *op, transform::TransformResults &results, |
233 | transform::TransformState &state) { |
234 | if (op->getName().getStringRef() != getOpKind()) { |
235 | return emitSilenceableError() |
236 | << "op expected the operand to be associated with a payload op of " |
237 | "kind " |
238 | << getOpKind() << " got "<< op->getName().getStringRef(); |
239 | } |
240 | return DiagnosedSilenceableFailure::success(); |
241 | } |
242 | |
243 | void mlir::test::TestSucceedIfOperandOfOpKind::getEffects( |
244 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
245 | transform::onlyReadsHandle(getOperation()->getOpOperands(), effects); |
246 | transform::onlyReadsPayload(effects); |
247 | } |
248 | |
249 | DiagnosedSilenceableFailure mlir::test::TestAddTestExtensionOp::apply( |
250 | transform::TransformRewriter &rewriter, |
251 | transform::TransformResults &results, transform::TransformState &state) { |
252 | state.addExtension<TestTransformStateExtension>(getMessageAttr()); |
253 | return DiagnosedSilenceableFailure::success(); |
254 | } |
255 | |
256 | DiagnosedSilenceableFailure |
257 | mlir::test::TestCheckIfTestExtensionPresentOp::apply( |
258 | transform::TransformRewriter &rewriter, |
259 | transform::TransformResults &results, transform::TransformState &state) { |
260 | auto *extension = state.getExtension<TestTransformStateExtension>(); |
261 | if (!extension) { |
262 | emitRemark() << "extension absent"; |
263 | return DiagnosedSilenceableFailure::success(); |
264 | } |
265 | |
266 | InFlightDiagnostic diag = emitRemark() |
267 | << "extension present, "<< extension->getMessage(); |
268 | for (Operation *payload : state.getPayloadOps(getOperand())) { |
269 | diag.attachNote(payload->getLoc()) << "associated payload op"; |
270 | #ifndef NDEBUG |
271 | SmallVector<Value> handles; |
272 | assert(succeeded(state.getHandlesForPayloadOp(payload, handles))); |
273 | assert(llvm::is_contained(handles, getOperand()) && |
274 | "inconsistent mapping between transform IR handles and payload IR " |
275 | "operations"); |
276 | #endif // NDEBUG |
277 | } |
278 | |
279 | return DiagnosedSilenceableFailure::success(); |
280 | } |
281 | |
282 | void mlir::test::TestCheckIfTestExtensionPresentOp::getEffects( |
283 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
284 | transform::onlyReadsHandle(getOperation()->getOpOperands(), effects); |
285 | transform::onlyReadsPayload(effects); |
286 | } |
287 | |
288 | DiagnosedSilenceableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply( |
289 | transform::TransformRewriter &rewriter, |
290 | transform::TransformResults &results, transform::TransformState &state) { |
291 | auto *extension = state.getExtension<TestTransformStateExtension>(); |
292 | if (!extension) |
293 | return emitDefiniteFailure("TestTransformStateExtension missing"); |
294 | |
295 | if (failed(extension->updateMapping( |
296 | *state.getPayloadOps(getOperand()).begin(), getOperation()))) |
297 | return DiagnosedSilenceableFailure::definiteFailure(); |
298 | if (getNumResults() > 0) |
299 | results.set(cast<OpResult>(getResult(0)), {getOperation()}); |
300 | return DiagnosedSilenceableFailure::success(); |
301 | } |
302 | |
303 | void mlir::test::TestRemapOperandPayloadToSelfOp::getEffects( |
304 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
305 | transform::onlyReadsHandle(getOperation()->getOpOperands(), effects); |
306 | transform::producesHandle(getOperation()->getOpResults(), effects); |
307 | transform::onlyReadsPayload(effects); |
308 | } |
309 | |
310 | DiagnosedSilenceableFailure mlir::test::TestRemoveTestExtensionOp::apply( |
311 | transform::TransformRewriter &rewriter, |
312 | transform::TransformResults &results, transform::TransformState &state) { |
313 | state.removeExtension<TestTransformStateExtension>(); |
314 | return DiagnosedSilenceableFailure::success(); |
315 | } |
316 | |
317 | DiagnosedSilenceableFailure mlir::test::TestReversePayloadOpsOp::apply( |
318 | transform::TransformRewriter &rewriter, |
319 | transform::TransformResults &results, transform::TransformState &state) { |
320 | auto payloadOps = state.getPayloadOps(getTarget()); |
321 | auto reversedOps = llvm::to_vector(llvm::reverse(payloadOps)); |
322 | results.set(llvm::cast<OpResult>(getResult()), reversedOps); |
323 | return DiagnosedSilenceableFailure::success(); |
324 | } |
325 | |
326 | DiagnosedSilenceableFailure mlir::test::TestTransformOpWithRegions::apply( |
327 | transform::TransformRewriter &rewriter, |
328 | transform::TransformResults &results, transform::TransformState &state) { |
329 | return DiagnosedSilenceableFailure::success(); |
330 | } |
331 | |
332 | void mlir::test::TestTransformOpWithRegions::getEffects( |
333 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {} |
334 | |
335 | DiagnosedSilenceableFailure |
336 | mlir::test::TestBranchingTransformOpTerminator::apply( |
337 | transform::TransformRewriter &rewriter, |
338 | transform::TransformResults &results, transform::TransformState &state) { |
339 | return DiagnosedSilenceableFailure::success(); |
340 | } |
341 | |
342 | void mlir::test::TestBranchingTransformOpTerminator::getEffects( |
343 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {} |
344 | |
345 | DiagnosedSilenceableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply( |
346 | transform::TransformRewriter &rewriter, |
347 | transform::TransformResults &results, transform::TransformState &state) { |
348 | emitRemark() << getRemark(); |
349 | for (Operation *op : state.getPayloadOps(getTarget())) { |
350 | if (!op->getUses().empty()) |
351 | return emitSilenceableError() << "cannot erase an op that has uses"; |
352 | rewriter.eraseOp(op); |
353 | } |
354 | |
355 | if (getFailAfterErase()) |
356 | return emitSilenceableError() << "silenceable error"; |
357 | return DiagnosedSilenceableFailure::success(); |
358 | } |
359 | |
360 | void mlir::test::TestEmitRemarkAndEraseOperandOp::getEffects( |
361 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
362 | transform::consumesHandle(getTargetMutable(), effects); |
363 | transform::modifiesPayload(effects); |
364 | } |
365 | |
366 | DiagnosedSilenceableFailure mlir::test::TestWrongNumberOfResultsOp::applyToOne( |
367 | transform::TransformRewriter &rewriter, Operation *target, |
368 | transform::ApplyToEachResultList &results, |
369 | transform::TransformState &state) { |
370 | OperationState opState(target->getLoc(), "foo"); |
371 | results.push_back(OpBuilder(target).create(opState)); |
372 | return DiagnosedSilenceableFailure::success(); |
373 | } |
374 | |
375 | DiagnosedSilenceableFailure |
376 | mlir::test::TestWrongNumberOfMultiResultsOp::applyToOne( |
377 | transform::TransformRewriter &rewriter, Operation *target, |
378 | transform::ApplyToEachResultList &results, |
379 | transform::TransformState &state) { |
380 | static int count = 0; |
381 | if (count++ == 0) { |
382 | OperationState opState(target->getLoc(), "foo"); |
383 | results.push_back(OpBuilder(target).create(opState)); |
384 | } |
385 | return DiagnosedSilenceableFailure::success(); |
386 | } |
387 | |
388 | DiagnosedSilenceableFailure |
389 | mlir::test::TestCorrectNumberOfMultiResultsOp::applyToOne( |
390 | transform::TransformRewriter &rewriter, Operation *target, |
391 | transform::ApplyToEachResultList &results, |
392 | transform::TransformState &state) { |
393 | OperationState opState(target->getLoc(), "foo"); |
394 | results.push_back(OpBuilder(target).create(opState)); |
395 | results.push_back(OpBuilder(target).create(opState)); |
396 | return DiagnosedSilenceableFailure::success(); |
397 | } |
398 | |
399 | DiagnosedSilenceableFailure |
400 | mlir::test::TestMixedNullAndNonNullResultsOp::applyToOne( |
401 | transform::TransformRewriter &rewriter, Operation *target, |
402 | transform::ApplyToEachResultList &results, |
403 | transform::TransformState &state) { |
404 | OperationState opState(target->getLoc(), "foo"); |
405 | results.push_back(nullptr); |
406 | results.push_back(OpBuilder(target).create(opState)); |
407 | return DiagnosedSilenceableFailure::success(); |
408 | } |
409 | |
410 | DiagnosedSilenceableFailure |
411 | mlir::test::TestMixedSuccessAndSilenceableOp::applyToOne( |
412 | transform::TransformRewriter &rewriter, Operation *target, |
413 | transform::ApplyToEachResultList &results, |
414 | transform::TransformState &state) { |
415 | if (target->hasAttr("target_me")) |
416 | return DiagnosedSilenceableFailure::success(); |
417 | return emitDefaultSilenceableFailure(target); |
418 | } |
419 | |
420 | DiagnosedSilenceableFailure |
421 | mlir::test::TestCopyPayloadOp::apply(transform::TransformRewriter &rewriter, |
422 | transform::TransformResults &results, |
423 | transform::TransformState &state) { |
424 | results.set(llvm::cast<OpResult>(getCopy()), |
425 | state.getPayloadOps(getHandle())); |
426 | return DiagnosedSilenceableFailure::success(); |
427 | } |
428 | |
429 | void mlir::test::TestCopyPayloadOp::getEffects( |
430 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
431 | transform::onlyReadsHandle(getHandleMutable(), effects); |
432 | transform::producesHandle(getOperation()->getOpResults(), effects); |
433 | transform::onlyReadsPayload(effects); |
434 | } |
435 | |
436 | DiagnosedSilenceableFailure mlir::transform::TestDialectOpType::checkPayload( |
437 | Location loc, ArrayRef<Operation *> payload) const { |
438 | if (payload.empty()) |
439 | return DiagnosedSilenceableFailure::success(); |
440 | |
441 | for (Operation *op : payload) { |
442 | if (op->getName().getDialectNamespace() != "test") { |
443 | return emitSilenceableError(loc) << "expected the payload operation to " |
444 | "belong to the 'test' dialect"; |
445 | } |
446 | } |
447 | |
448 | return DiagnosedSilenceableFailure::success(); |
449 | } |
450 | |
451 | DiagnosedSilenceableFailure mlir::transform::TestDialectParamType::checkPayload( |
452 | Location loc, ArrayRef<Attribute> payload) const { |
453 | for (Attribute attr : payload) { |
454 | auto integerAttr = llvm::dyn_cast<IntegerAttr>(attr); |
455 | if (integerAttr && integerAttr.getType().isSignlessInteger(32)) |
456 | continue; |
457 | return emitSilenceableError(loc) |
458 | << "expected the parameter to be a i32 integer attribute"; |
459 | } |
460 | |
461 | return DiagnosedSilenceableFailure::success(); |
462 | } |
463 | |
464 | void mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::getEffects( |
465 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
466 | transform::onlyReadsHandle(getTargetMutable(), effects); |
467 | } |
468 | |
469 | DiagnosedSilenceableFailure |
470 | mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::apply( |
471 | transform::TransformRewriter &rewriter, |
472 | transform::TransformResults &results, transform::TransformState &state) { |
473 | int64_t count = 0; |
474 | for (Operation *op : state.getPayloadOps(getTarget())) { |
475 | op->walk([&](Operation *nested) { |
476 | SmallVector<Value> handles; |
477 | (void)state.getHandlesForPayloadOp(nested, handles); |
478 | count += handles.size(); |
479 | }); |
480 | } |
481 | emitRemark() << count << " handles nested under"; |
482 | return DiagnosedSilenceableFailure::success(); |
483 | } |
484 | |
485 | DiagnosedSilenceableFailure |
486 | mlir::test::TestAddToParamOp::apply(transform::TransformRewriter &rewriter, |
487 | transform::TransformResults &results, |
488 | transform::TransformState &state) { |
489 | SmallVector<uint32_t> values(/*Size=*/1, /*Value=*/0); |
490 | if (Value param = getParam()) { |
491 | values = llvm::to_vector( |
492 | llvm::map_range(state.getParams(param), [](Attribute attr) -> uint32_t { |
493 | return llvm::cast<IntegerAttr>(attr).getValue().getLimitedValue( |
494 | UINT32_MAX); |
495 | })); |
496 | } |
497 | |
498 | Builder builder(getContext()); |
499 | SmallVector<Attribute> result = llvm::to_vector( |
500 | llvm::map_range(values, [this, &builder](uint32_t value) -> Attribute { |
501 | return builder.getI32IntegerAttr(value + getAddendum()); |
502 | })); |
503 | results.setParams(llvm::cast<OpResult>(getResult()), result); |
504 | return DiagnosedSilenceableFailure::success(); |
505 | } |
506 | |
507 | DiagnosedSilenceableFailure |
508 | mlir::test::TestProduceParamWithNumberOfTestOps::apply( |
509 | transform::TransformRewriter &rewriter, |
510 | transform::TransformResults &results, transform::TransformState &state) { |
511 | Builder builder(getContext()); |
512 | SmallVector<Attribute> result = llvm::to_vector( |
513 | llvm::map_range(state.getPayloadOps(getHandle()), |
514 | [&builder](Operation *payload) -> Attribute { |
515 | int32_t count = 0; |
516 | payload->walk([&count](Operation *op) { |
517 | if (op->getName().getDialectNamespace() == "test") |
518 | ++count; |
519 | }); |
520 | return builder.getI32IntegerAttr(count); |
521 | })); |
522 | results.setParams(llvm::cast<OpResult>(getResult()), result); |
523 | return DiagnosedSilenceableFailure::success(); |
524 | } |
525 | |
526 | DiagnosedSilenceableFailure |
527 | mlir::test::TestProduceParamOp::apply(transform::TransformRewriter &rewriter, |
528 | transform::TransformResults &results, |
529 | transform::TransformState &state) { |
530 | results.setParams(llvm::cast<OpResult>(getResult()), getAttr()); |
531 | return DiagnosedSilenceableFailure::success(); |
532 | } |
533 | |
534 | void mlir::test::TestProduceTransformParamOrForwardOperandOp::getEffects( |
535 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
536 | transform::onlyReadsHandle(getInMutable(), effects); |
537 | transform::producesHandle(getOperation()->getOpResults(), effects); |
538 | } |
539 | |
540 | DiagnosedSilenceableFailure |
541 | mlir::test::TestProduceTransformParamOrForwardOperandOp::applyToOne( |
542 | transform::TransformRewriter &rewriter, Operation *target, |
543 | ::transform::ApplyToEachResultList &results, |
544 | ::transform::TransformState &state) { |
545 | Builder builder(getContext()); |
546 | if (getFirstResultIsParam()) { |
547 | results.push_back(builder.getI64IntegerAttr(0)); |
548 | } else if (getFirstResultIsNull()) { |
549 | results.push_back(nullptr); |
550 | } else { |
551 | results.push_back(*state.getPayloadOps(getIn()).begin()); |
552 | } |
553 | |
554 | if (getSecondResultIsHandle()) { |
555 | results.push_back(*state.getPayloadOps(getIn()).begin()); |
556 | } else { |
557 | results.push_back(builder.getI64IntegerAttr(42)); |
558 | } |
559 | |
560 | return DiagnosedSilenceableFailure::success(); |
561 | } |
562 | |
563 | void mlir::test::TestProduceNullPayloadOp::getEffects( |
564 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
565 | transform::producesHandle(getOperation()->getOpResults(), effects); |
566 | } |
567 | |
568 | DiagnosedSilenceableFailure mlir::test::TestProduceNullPayloadOp::apply( |
569 | transform::TransformRewriter &rewriter, |
570 | transform::TransformResults &results, transform::TransformState &state) { |
571 | SmallVector<Operation *, 1> null({nullptr}); |
572 | results.set(llvm::cast<OpResult>(getOut()), null); |
573 | return DiagnosedSilenceableFailure::success(); |
574 | } |
575 | |
576 | DiagnosedSilenceableFailure mlir::test::TestProduceEmptyPayloadOp::apply( |
577 | transform::TransformRewriter &rewriter, |
578 | transform::TransformResults &results, transform::TransformState &state) { |
579 | results.set(cast<OpResult>(getOut()), {}); |
580 | return DiagnosedSilenceableFailure::success(); |
581 | } |
582 | |
583 | void mlir::test::TestProduceNullParamOp::getEffects( |
584 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
585 | transform::producesHandle(getOperation()->getOpResults(), effects); |
586 | } |
587 | |
588 | DiagnosedSilenceableFailure mlir::test::TestProduceNullParamOp::apply( |
589 | transform::TransformRewriter &rewriter, |
590 | transform::TransformResults &results, transform::TransformState &state) { |
591 | results.setParams(llvm::cast<OpResult>(getOut()), Attribute()); |
592 | return DiagnosedSilenceableFailure::success(); |
593 | } |
594 | |
595 | void mlir::test::TestProduceNullValueOp::getEffects( |
596 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
597 | transform::producesHandle(getOperation()->getOpResults(), effects); |
598 | } |
599 | |
600 | DiagnosedSilenceableFailure mlir::test::TestProduceNullValueOp::apply( |
601 | transform::TransformRewriter &rewriter, |
602 | transform::TransformResults &results, transform::TransformState &state) { |
603 | results.setValues(llvm::cast<OpResult>(getOut()), {Value()}); |
604 | return DiagnosedSilenceableFailure::success(); |
605 | } |
606 | |
607 | void mlir::test::TestRequiredMemoryEffectsOp::getEffects( |
608 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
609 | if (getHasOperandEffect()) |
610 | transform::consumesHandle(getInMutable(), effects); |
611 | |
612 | if (getHasResultEffect()) { |
613 | transform::producesHandle(getOperation()->getOpResults(), effects); |
614 | } else { |
615 | effects.emplace_back(MemoryEffects::Read::get(), |
616 | llvm::cast<OpResult>(getOut()), |
617 | transform::TransformMappingResource::get()); |
618 | } |
619 | |
620 | if (getModifiesPayload()) |
621 | transform::modifiesPayload(effects); |
622 | } |
623 | |
624 | DiagnosedSilenceableFailure mlir::test::TestRequiredMemoryEffectsOp::apply( |
625 | transform::TransformRewriter &rewriter, |
626 | transform::TransformResults &results, transform::TransformState &state) { |
627 | results.set(llvm::cast<OpResult>(getOut()), state.getPayloadOps(getIn())); |
628 | return DiagnosedSilenceableFailure::success(); |
629 | } |
630 | |
631 | void mlir::test::TestTrackedRewriteOp::getEffects( |
632 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
633 | transform::onlyReadsHandle(getInMutable(), effects); |
634 | transform::modifiesPayload(effects); |
635 | } |
636 | |
637 | void mlir::test::TestDummyPayloadOp::getEffects( |
638 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
639 | transform::producesHandle(getOperation()->getOpResults(), effects); |
640 | } |
641 | |
642 | LogicalResult mlir::test::TestDummyPayloadOp::verify() { |
643 | if (getFailToVerify()) |
644 | return emitOpError() << "fail_to_verify is set"; |
645 | return success(); |
646 | } |
647 | |
648 | DiagnosedSilenceableFailure |
649 | mlir::test::TestTrackedRewriteOp::apply(transform::TransformRewriter &rewriter, |
650 | transform::TransformResults &results, |
651 | transform::TransformState &state) { |
652 | int64_t numIterations = 0; |
653 | |
654 | // `getPayloadOps` returns an iterator that skips ops that are erased in the |
655 | // loop body. Replacement ops are not enumerated. |
656 | for (Operation *op : state.getPayloadOps(getIn())) { |
657 | ++numIterations; |
658 | (void)op; |
659 | |
660 | // Erase all payload ops. The outer loop should have only one iteration. |
661 | for (Operation *op : state.getPayloadOps(getIn())) { |
662 | rewriter.setInsertionPoint(op); |
663 | if (op->hasAttr("erase_me")) { |
664 | rewriter.eraseOp(op); |
665 | continue; |
666 | } |
667 | if (!op->hasAttr("replace_me")) { |
668 | continue; |
669 | } |
670 | |
671 | SmallVector<NamedAttribute> attributes; |
672 | attributes.emplace_back(rewriter.getStringAttr("new_op"), |
673 | rewriter.getUnitAttr()); |
674 | OperationState opState(op->getLoc(), op->getName().getIdentifier(), |
675 | /*operands=*/ValueRange(), |
676 | /*types=*/op->getResultTypes(), attributes); |
677 | Operation *newOp = rewriter.create(opState); |
678 | rewriter.replaceOp(op, newOp->getResults()); |
679 | } |
680 | } |
681 | |
682 | emitRemark() << numIterations << " iterations"; |
683 | return DiagnosedSilenceableFailure::success(); |
684 | } |
685 | |
686 | namespace { |
687 | // Test pattern to replace an operation with a new op. |
688 | class ReplaceWithNewOp : public RewritePattern { |
689 | public: |
690 | ReplaceWithNewOp(MLIRContext *context) |
691 | : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} |
692 | |
693 | LogicalResult matchAndRewrite(Operation *op, |
694 | PatternRewriter &rewriter) const override { |
695 | auto newName = op->getAttrOfType<StringAttr>("replace_with_new_op"); |
696 | if (!newName) |
697 | return failure(); |
698 | Operation *newOp = rewriter.create( |
699 | op->getLoc(), OperationName(newName, op->getContext()).getIdentifier(), |
700 | op->getOperands(), op->getResultTypes()); |
701 | rewriter.replaceOp(op, newValues: newOp->getResults()); |
702 | return success(); |
703 | } |
704 | }; |
705 | |
706 | // Test pattern to erase an operation. |
707 | class EraseOp : public RewritePattern { |
708 | public: |
709 | EraseOp(MLIRContext *context) |
710 | : RewritePattern("test.erase_op", /*benefit=*/1, context) {} |
711 | LogicalResult matchAndRewrite(Operation *op, |
712 | PatternRewriter &rewriter) const override { |
713 | rewriter.eraseOp(op); |
714 | return success(); |
715 | } |
716 | }; |
717 | } // namespace |
718 | |
719 | void mlir::test::ApplyTestPatternsOp::populatePatterns( |
720 | RewritePatternSet &patterns) { |
721 | patterns.insert<ReplaceWithNewOp, EraseOp>(patterns.getContext()); |
722 | } |
723 | |
724 | void mlir::test::TestReEnterRegionOp::getEffects( |
725 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
726 | transform::consumesHandle(getOperation()->getOpOperands(), effects); |
727 | transform::modifiesPayload(effects); |
728 | } |
729 | |
730 | DiagnosedSilenceableFailure |
731 | mlir::test::TestReEnterRegionOp::apply(transform::TransformRewriter &rewriter, |
732 | transform::TransformResults &results, |
733 | transform::TransformState &state) { |
734 | |
735 | SmallVector<SmallVector<transform::MappedValue>> mappings; |
736 | for (BlockArgument arg : getBody().front().getArguments()) { |
737 | mappings.emplace_back(llvm::to_vector(llvm::map_range( |
738 | state.getPayloadOps(getOperand(arg.getArgNumber())), |
739 | [](Operation *op) -> transform::MappedValue { return op; }))); |
740 | } |
741 | |
742 | for (int i = 0; i < 4; ++i) { |
743 | auto scope = state.make_region_scope(getBody()); |
744 | for (BlockArgument arg : getBody().front().getArguments()) { |
745 | if (failed(state.mapBlockArgument(arg, mappings[arg.getArgNumber()]))) |
746 | return DiagnosedSilenceableFailure::definiteFailure(); |
747 | } |
748 | for (Operation &op : getBody().front().without_terminator()) { |
749 | DiagnosedSilenceableFailure diag = |
750 | state.applyTransform(cast<transform::TransformOpInterface>(op)); |
751 | if (!diag.succeeded()) |
752 | return diag; |
753 | } |
754 | } |
755 | return DiagnosedSilenceableFailure::success(); |
756 | } |
757 | |
758 | LogicalResult mlir::test::TestReEnterRegionOp::verify() { |
759 | if (getNumOperands() != getBody().front().getNumArguments()) { |
760 | return emitOpError() << "expects as many operands as block arguments"; |
761 | } |
762 | return success(); |
763 | } |
764 | |
765 | DiagnosedSilenceableFailure mlir::test::TestNotifyPayloadOpReplacedOp::apply( |
766 | transform::TransformRewriter &rewriter, |
767 | transform::TransformResults &results, transform::TransformState &state) { |
768 | auto originalOps = state.getPayloadOps(getOriginal()); |
769 | auto replacementOps = state.getPayloadOps(getReplacement()); |
770 | if (llvm::range_size(originalOps) != llvm::range_size(replacementOps)) |
771 | return emitSilenceableError() << "expected same number of original and " |
772 | "replacement payload operations"; |
773 | for (const auto &[original, replacement] : |
774 | llvm::zip(originalOps, replacementOps)) { |
775 | if (failed( |
776 | rewriter.notifyPayloadOperationReplaced(original, replacement))) { |
777 | auto diag = emitSilenceableError() |
778 | << "unable to replace payload op in transform mapping"; |
779 | diag.attachNote(original->getLoc()) << "original payload op"; |
780 | diag.attachNote(replacement->getLoc()) << "replacement payload op"; |
781 | return diag; |
782 | } |
783 | } |
784 | return DiagnosedSilenceableFailure::success(); |
785 | } |
786 | |
787 | void mlir::test::TestNotifyPayloadOpReplacedOp::getEffects( |
788 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
789 | transform::onlyReadsHandle(getOriginalMutable(), effects); |
790 | transform::onlyReadsHandle(getReplacementMutable(), effects); |
791 | } |
792 | |
793 | DiagnosedSilenceableFailure mlir::test::TestProduceInvalidIR::applyToOne( |
794 | transform::TransformRewriter &rewriter, Operation *target, |
795 | transform::ApplyToEachResultList &results, |
796 | transform::TransformState &state) { |
797 | // Provide some IR that does not verify. |
798 | rewriter.setInsertionPointToStart(&target->getRegion(0).front()); |
799 | rewriter.create<TestDummyPayloadOp>(target->getLoc(), TypeRange(), |
800 | ValueRange(), /*failToVerify=*/true); |
801 | return DiagnosedSilenceableFailure::success(); |
802 | } |
803 | |
804 | void mlir::test::TestProduceInvalidIR::getEffects( |
805 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
806 | transform::onlyReadsHandle(getTargetMutable(), effects); |
807 | transform::modifiesPayload(effects); |
808 | } |
809 | |
810 | DiagnosedSilenceableFailure mlir::test::TestInitializerExtensionOp::apply( |
811 | transform::TransformRewriter &rewriter, |
812 | transform::TransformResults &results, transform::TransformState &state) { |
813 | std::string opName = |
814 | this->getOperationName().str() + "_"+ getTypeAttr().str(); |
815 | TransformStateInitializerExtension *initExt = |
816 | state.getExtension<TransformStateInitializerExtension>(); |
817 | if (!initExt) { |
818 | emitRemark() << "\nSpecified extension not found, adding a new one!\n"; |
819 | SmallVector<std::string> opCollection = {opName}; |
820 | state.addExtension<TransformStateInitializerExtension>(1, opCollection); |
821 | } else { |
822 | initExt->setNumOp(initExt->getNumOp() + 1); |
823 | initExt->pushRegisteredOps(opName); |
824 | InFlightDiagnostic diag = emitRemark() |
825 | << "Number of currently registered op: " |
826 | << initExt->getNumOp() << "\n" |
827 | << initExt->printMessage() << "\n"; |
828 | } |
829 | return DiagnosedSilenceableFailure::success(); |
830 | } |
831 | |
832 | namespace { |
833 | /// Test conversion pattern that replaces ops with the "replace_with_new_op" |
834 | /// attribute with "test.new_op". |
835 | class ReplaceWithNewOpConversion : public ConversionPattern { |
836 | public: |
837 | ReplaceWithNewOpConversion(TypeConverter &typeConverter, MLIRContext *context) |
838 | : ConversionPattern(typeConverter, RewritePattern::MatchAnyOpTypeTag(), |
839 | /*benefit=*/1, context) {} |
840 | |
841 | LogicalResult |
842 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
843 | ConversionPatternRewriter &rewriter) const override { |
844 | if (!op->hasAttr(name: "replace_with_new_op")) |
845 | return failure(); |
846 | SmallVector<Type> newResultTypes; |
847 | if (failed(Result: getTypeConverter()->convertTypes(types: op->getResultTypes(), |
848 | results&: newResultTypes))) |
849 | return failure(); |
850 | Operation *newOp = rewriter.create( |
851 | op->getLoc(), |
852 | OperationName("test.new_op", op->getContext()).getIdentifier(), |
853 | operands, newResultTypes); |
854 | rewriter.replaceOp(op, newValues: newOp->getResults()); |
855 | return success(); |
856 | } |
857 | }; |
858 | } // namespace |
859 | |
860 | void mlir::test::ApplyTestConversionPatternsOp::populatePatterns( |
861 | TypeConverter &typeConverter, RewritePatternSet &patterns) { |
862 | patterns.insert<ReplaceWithNewOpConversion>(typeConverter, |
863 | patterns.getContext()); |
864 | } |
865 | |
866 | namespace { |
867 | /// Test type converter that converts tensor types to memref types. |
868 | class TestTypeConverter : public TypeConverter { |
869 | public: |
870 | TestTypeConverter() { |
871 | addConversion(callback: [](Type t) { return t; }); |
872 | addConversion(callback: [](RankedTensorType type) -> Type { |
873 | return MemRefType::get(type.getShape(), type.getElementType()); |
874 | }); |
875 | auto unrealizedCastConverter = [&](OpBuilder &builder, Type resultType, |
876 | ValueRange inputs, |
877 | Location loc) -> Value { |
878 | if (inputs.size() != 1) |
879 | return Value(); |
880 | return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs) |
881 | .getResult(0); |
882 | }; |
883 | addSourceMaterialization(callback&: unrealizedCastConverter); |
884 | addTargetMaterialization(callback&: unrealizedCastConverter); |
885 | } |
886 | }; |
887 | } // namespace |
888 | |
889 | std::unique_ptr<::mlir::TypeConverter> |
890 | mlir::test::TestTypeConverterOp::getTypeConverter() { |
891 | return std::make_unique<TestTypeConverter>(); |
892 | } |
893 | |
894 | namespace { |
895 | /// Test extension of the Transform dialect. Registers additional ops and |
896 | /// declares PDL as dependent dialect since the additional ops are using PDL |
897 | /// types for operands and results. |
898 | class TestTransformDialectExtension |
899 | : public transform::TransformDialectExtension< |
900 | TestTransformDialectExtension> { |
901 | public: |
902 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformDialectExtension) |
903 | |
904 | using Base::Base; |
905 | |
906 | void init() { |
907 | declareDependentDialect<pdl::PDLDialect>(); |
908 | registerTransformOps<TestTransformOp, |
909 | TestTransformUnrestrictedOpNoInterface, |
910 | #define GET_OP_LIST |
911 | #include "TestTransformDialectExtension.cpp.inc" |
912 | >(); |
913 | registerTypes< |
914 | #define GET_TYPEDEF_LIST |
915 | #include "TestTransformDialectExtensionTypes.cpp.inc" |
916 | >(); |
917 | |
918 | auto verboseConstraint = [](PatternRewriter &rewriter, PDLResultList &, |
919 | ArrayRef<PDLValue> pdlValues) { |
920 | for (const PDLValue &pdlValue : pdlValues) { |
921 | if (Operation *op = pdlValue.dyn_cast<Operation *>()) { |
922 | op->emitWarning() << "from PDL constraint"; |
923 | } |
924 | } |
925 | return success(); |
926 | }; |
927 | |
928 | addDialectDataInitializer<transform::PDLMatchHooks>( |
929 | [&](transform::PDLMatchHooks &hooks) { |
930 | llvm::StringMap<PDLConstraintFunction> constraints; |
931 | constraints.try_emplace("verbose_constraint", verboseConstraint); |
932 | hooks.mergeInPDLMatchHooks(constraintFns: std::move(constraints)); |
933 | }); |
934 | } |
935 | }; |
936 | } // namespace |
937 | |
// These are automatically generated by ODS but are not used as the Transform
// dialect uses a different dispatch mechanism to support dialect extensions.
// Declaring them up front with LLVM_ATTRIBUTE_UNUSED keeps the compiler from
// warning about the unused static definitions emitted by the .inc file below,
// so these declarations must stay before that #include.
LLVM_ATTRIBUTE_UNUSED static OptionalParseResult
generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value);
LLVM_ATTRIBUTE_UNUSED static LogicalResult
generatedTypePrinter(Type def, AsmPrinter &printer);

// Pull in the ODS-generated class definitions for the extension's types.
#define GET_TYPEDEF_CLASSES
#include "TestTransformDialectExtensionTypes.cpp.inc"

// Pull in the ODS-generated class definitions for the extension's ops.
#define GET_OP_CLASSES
#include "TestTransformDialectExtension.cpp.inc"
950 | |
951 | void ::test::registerTestTransformDialectExtension(DialectRegistry ®istry) { |
952 | registry.addExtensions<TestTransformDialectExtension>(); |
953 | } |
954 |
Definitions
- TestTransformOp
- getAttributeNames
- getOperationName
- apply
- getMessage
- parse
- getEffects
- TestTransformUnrestrictedOpNoInterface
- getAttributeNames
- getOperationName
- apply
- getEffects
- ReplaceWithNewOp
- ReplaceWithNewOp
- matchAndRewrite
- EraseOp
- EraseOp
- matchAndRewrite
- ReplaceWithNewOpConversion
- ReplaceWithNewOpConversion
- matchAndRewrite
- TestTypeConverter
- TestTypeConverter
- TestTransformDialectExtension
- init
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more