1//===- TestOpDefs.cpp - MLIR Test Dialect Operation Hooks -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "TestDialect.h"
10#include "TestOps.h"
11#include "mlir/Dialect/Tensor/IR/Tensor.h"
12#include "mlir/IR/Verifier.h"
13#include "mlir/Interfaces/FunctionImplementation.h"
14
15using namespace mlir;
16using namespace test;
17
18//===----------------------------------------------------------------------===//
19// TestBranchOp
20//===----------------------------------------------------------------------===//
21
22SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) {
23 assert(index == 0 && "invalid successor index");
24 return SuccessorOperands(getTargetOperandsMutable());
25}
26
27//===----------------------------------------------------------------------===//
28// TestProducingBranchOp
29//===----------------------------------------------------------------------===//
30
31SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) {
32 assert(index <= 1 && "invalid successor index");
33 if (index == 1)
34 return SuccessorOperands(getFirstOperandsMutable());
35 return SuccessorOperands(getSecondOperandsMutable());
36}
37
38//===----------------------------------------------------------------------===//
39// TestInternalBranchOp
40//===----------------------------------------------------------------------===//
41
42SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) {
43 assert(index <= 1 && "invalid successor index");
44 if (index == 0)
45 return SuccessorOperands(0, getSuccessOperandsMutable());
46 return SuccessorOperands(1, getErrorOperandsMutable());
47}
48
49//===----------------------------------------------------------------------===//
50// TestCallOp
51//===----------------------------------------------------------------------===//
52
53LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
54 // Check that the callee attribute was specified.
55 auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
56 if (!fnAttr)
57 return emitOpError("requires a 'callee' symbol reference attribute");
58 if (!symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(*this, fnAttr))
59 return emitOpError() << "'" << fnAttr.getValue()
60 << "' does not reference a valid function";
61 return success();
62}
63
64//===----------------------------------------------------------------------===//
65// FoldToCallOp
66//===----------------------------------------------------------------------===//
67
68namespace {
69struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
70 using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
71
72 LogicalResult matchAndRewrite(FoldToCallOp op,
73 PatternRewriter &rewriter) const override {
74 rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(),
75 op.getCalleeAttr(), ValueRange());
76 return success();
77 }
78};
79} // namespace
80
81void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
82 MLIRContext *context) {
83 results.add<FoldToCallOpPattern>(context);
84}
85
86//===----------------------------------------------------------------------===//
87// IsolatedRegionOp - test parsing passthrough operands
88//===----------------------------------------------------------------------===//
89
90ParseResult IsolatedRegionOp::parse(OpAsmParser &parser,
91 OperationState &result) {
92 // Parse the input operand.
93 OpAsmParser::Argument argInfo;
94 argInfo.type = parser.getBuilder().getIndexType();
95 if (parser.parseOperand(argInfo.ssaName) ||
96 parser.resolveOperand(argInfo.ssaName, argInfo.type, result.operands))
97 return failure();
98
99 // Parse the body region, and reuse the operand info as the argument info.
100 Region *body = result.addRegion();
101 return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/true);
102}
103
104void IsolatedRegionOp::print(OpAsmPrinter &p) {
105 p << ' ';
106 p.printOperand(getOperand());
107 p.shadowRegionArgs(getRegion(), getOperand());
108 p << ' ';
109 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
110}
111
112//===----------------------------------------------------------------------===//
113// SSACFGRegionOp
114//===----------------------------------------------------------------------===//
115
116RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
117 return RegionKind::SSACFG;
118}
119
120//===----------------------------------------------------------------------===//
121// GraphRegionOp
122//===----------------------------------------------------------------------===//
123
124RegionKind GraphRegionOp::getRegionKind(unsigned index) {
125 return RegionKind::Graph;
126}
127
128//===----------------------------------------------------------------------===//
129// AffineScopeOp
130//===----------------------------------------------------------------------===//
131
132ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) {
133 // Parse the body region, and reuse the operand info as the argument info.
134 Region *body = result.addRegion();
135 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
136}
137
138void AffineScopeOp::print(OpAsmPrinter &p) {
139 p << " ";
140 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
141}
142
143//===----------------------------------------------------------------------===//
144// TestRemoveOpWithInnerOps
145//===----------------------------------------------------------------------===//
146
147namespace {
148struct TestRemoveOpWithInnerOps
149 : public OpRewritePattern<TestOpWithRegionPattern> {
150 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
151
152 void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
153
154 LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
155 PatternRewriter &rewriter) const override {
156 rewriter.eraseOp(op: op);
157 return success();
158 }
159};
160} // namespace
161
162//===----------------------------------------------------------------------===//
163// TestOpWithRegionPattern
164//===----------------------------------------------------------------------===//
165
166void TestOpWithRegionPattern::getCanonicalizationPatterns(
167 RewritePatternSet &results, MLIRContext *context) {
168 results.add<TestRemoveOpWithInnerOps>(context);
169}
170
171//===----------------------------------------------------------------------===//
172// TestOpWithRegionFold
173//===----------------------------------------------------------------------===//
174
175OpFoldResult TestOpWithRegionFold::fold(FoldAdaptor adaptor) {
176 return getOperand();
177}
178
179//===----------------------------------------------------------------------===//
180// TestOpConstant
181//===----------------------------------------------------------------------===//
182
183OpFoldResult TestOpConstant::fold(FoldAdaptor adaptor) { return getValue(); }
184
185//===----------------------------------------------------------------------===//
186// TestOpWithVariadicResultsAndFolder
187//===----------------------------------------------------------------------===//
188
189LogicalResult TestOpWithVariadicResultsAndFolder::fold(
190 FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult> &results) {
191 for (Value input : this->getOperands()) {
192 results.push_back(input);
193 }
194 return success();
195}
196
197//===----------------------------------------------------------------------===//
198// TestOpInPlaceFold
199//===----------------------------------------------------------------------===//
200
201OpFoldResult TestOpInPlaceFold::fold(FoldAdaptor adaptor) {
202 // Exercise the fact that an operation created with createOrFold should be
203 // allowed to access its parent block.
204 assert(getOperation()->getBlock() &&
205 "expected that operation is not unlinked");
206
207 if (adaptor.getOp() && !getProperties().attr) {
208 // The folder adds "attr" if not present.
209 getProperties().attr = dyn_cast_or_null<IntegerAttr>(adaptor.getOp());
210 return getResult();
211 }
212 return {};
213}
214
215//===----------------------------------------------------------------------===//
216// OpWithInferTypeInterfaceOp
217//===----------------------------------------------------------------------===//
218
219LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
220 MLIRContext *, std::optional<Location> location, ValueRange operands,
221 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
222 SmallVectorImpl<Type> &inferredReturnTypes) {
223 if (operands[0].getType() != operands[1].getType()) {
224 return emitOptionalError(location, "operand type mismatch ",
225 operands[0].getType(), " vs ",
226 operands[1].getType());
227 }
228 inferredReturnTypes.assign({operands[0].getType()});
229 return success();
230}
231
232//===----------------------------------------------------------------------===//
233// OpWithShapedTypeInferTypeInterfaceOp
234//===----------------------------------------------------------------------===//
235
236LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
237 MLIRContext *context, std::optional<Location> location,
238 ValueShapeRange operands, DictionaryAttr attributes,
239 OpaqueProperties properties, RegionRange regions,
240 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
241 // Create return type consisting of the last element of the first operand.
242 auto operandType = operands.front().getType();
243 auto sval = dyn_cast<ShapedType>(operandType);
244 if (!sval)
245 return emitOptionalError(location, "only shaped type operands allowed");
246 int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic;
247 auto type = IntegerType::get(context, 17);
248
249 Attribute encoding;
250 if (auto rankedTy = dyn_cast<RankedTensorType>(sval))
251 encoding = rankedTy.getEncoding();
252 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding));
253 return success();
254}
255
256LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
257 OpBuilder &builder, ValueRange operands,
258 llvm::SmallVectorImpl<Value> &shapes) {
259 shapes = SmallVector<Value, 1>{
260 builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
261 return success();
262}
263
264//===----------------------------------------------------------------------===//
265// OpWithResultShapeInterfaceOp
266//===----------------------------------------------------------------------===//
267
268LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
269 OpBuilder &builder, ValueRange operands,
270 llvm::SmallVectorImpl<Value> &shapes) {
271 Location loc = getLoc();
272 shapes.reserve(operands.size());
273 for (Value operand : llvm::reverse(operands)) {
274 auto rank = cast<RankedTensorType>(operand.getType()).getRank();
275 auto currShape = llvm::to_vector<4>(
276 llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value {
277 return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
278 }));
279 shapes.push_back(builder.create<tensor::FromElementsOp>(
280 getLoc(), RankedTensorType::get({rank}, builder.getIndexType()),
281 currShape));
282 }
283 return success();
284}
285
286//===----------------------------------------------------------------------===//
287// OpWithResultShapePerDimInterfaceOp
288//===----------------------------------------------------------------------===//
289
290LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
291 OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
292 Location loc = getLoc();
293 shapes.reserve(getNumOperands());
294 for (Value operand : llvm::reverse(getOperands())) {
295 auto tensorType = cast<RankedTensorType>(operand.getType());
296 auto currShape = llvm::to_vector<4>(llvm::map_range(
297 llvm::seq<int64_t>(0, tensorType.getRank()),
298 [&](int64_t dim) -> OpFoldResult {
299 return tensorType.isDynamicDim(dim)
300 ? static_cast<OpFoldResult>(
301 builder.createOrFold<tensor::DimOp>(loc, operand,
302 dim))
303 : static_cast<OpFoldResult>(
304 builder.getIndexAttr(tensorType.getDimSize(dim)));
305 }));
306 shapes.emplace_back(std::move(currShape));
307 }
308 return success();
309}
310
311//===----------------------------------------------------------------------===//
312// SideEffectOp
313//===----------------------------------------------------------------------===//
314
315namespace {
316/// A test resource for side effects.
317struct TestResource : public SideEffects::Resource::Base<TestResource> {
318 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResource)
319
320 StringRef getName() final { return "<Test>"; }
321};
322} // namespace
323
324void SideEffectOp::getEffects(
325 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
326 // Check for an effects attribute on the op instance.
327 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
328 if (!effectsAttr)
329 return;
330
331 // If there is one, it is an array of dictionary attributes that hold
332 // information on the effects of this operation.
333 for (Attribute element : effectsAttr) {
334 DictionaryAttr effectElement = cast<DictionaryAttr>(element);
335
336 // Get the specific memory effect.
337 MemoryEffects::Effect *effect =
338 StringSwitch<MemoryEffects::Effect *>(
339 cast<StringAttr>(effectElement.get("effect")).getValue())
340 .Case("allocate", MemoryEffects::Allocate::get())
341 .Case("free", MemoryEffects::Free::get())
342 .Case("read", MemoryEffects::Read::get())
343 .Case("write", MemoryEffects::Write::get());
344
345 // Check for a non-default resource to use.
346 SideEffects::Resource *resource = SideEffects::DefaultResource::get();
347 if (effectElement.get("test_resource"))
348 resource = TestResource::get();
349
350 // Check for a result to affect.
351 if (effectElement.get("on_result"))
352 effects.emplace_back(effect, getResult(), resource);
353 else if (Attribute ref = effectElement.get("on_reference"))
354 effects.emplace_back(effect, cast<SymbolRefAttr>(ref), resource);
355 else
356 effects.emplace_back(effect, resource);
357 }
358}
359
360void SideEffectOp::getEffects(
361 SmallVectorImpl<TestEffects::EffectInstance> &effects) {
362 testSideEffectOpGetEffect(getOperation(), effects);
363}
364
365//===----------------------------------------------------------------------===//
366// StringAttrPrettyNameOp
367//===----------------------------------------------------------------------===//
368
369// This op has fancy handling of its SSA result name.
370ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser,
371 OperationState &result) {
372 // Add the result types.
373 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
374 result.addTypes(parser.getBuilder().getIntegerType(32));
375
376 if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
377 return failure();
378
379 // If the attribute dictionary contains no 'names' attribute, infer it from
380 // the SSA name (if specified).
381 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
382 return attr.getName() == "names";
383 });
384
385 // If there was no name specified, check to see if there was a useful name
386 // specified in the asm file.
387 if (hadNames || parser.getNumResults() == 0)
388 return success();
389
390 SmallVector<StringRef, 4> names;
391 auto *context = result.getContext();
392
393 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
394 auto resultName = parser.getResultName(i);
395 StringRef nameStr;
396 if (!resultName.first.empty() && !isdigit(resultName.first[0]))
397 nameStr = resultName.first;
398
399 names.push_back(nameStr);
400 }
401
402 auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
403 result.attributes.push_back({StringAttr::get(context, "names"), namesAttr});
404 return success();
405}
406
407void StringAttrPrettyNameOp::print(OpAsmPrinter &p) {
408 // Note that we only need to print the "name" attribute if the asmprinter
409 // result name disagrees with it. This can happen in strange cases, e.g.
410 // when there are conflicts.
411 bool namesDisagree = getNames().size() != getNumResults();
412
413 SmallString<32> resultNameStr;
414 for (size_t i = 0, e = getNumResults(); i != e && !namesDisagree; ++i) {
415 resultNameStr.clear();
416 llvm::raw_svector_ostream tmpStream(resultNameStr);
417 p.printOperand(getResult(i), tmpStream);
418
419 auto expectedName = dyn_cast<StringAttr>(getNames()[i]);
420 if (!expectedName ||
421 tmpStream.str().drop_front() != expectedName.getValue()) {
422 namesDisagree = true;
423 }
424 }
425
426 if (namesDisagree)
427 p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
428 else
429 p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names"});
430}
431
432// We set the SSA name in the asm syntax to the contents of the name
433// attribute.
434void StringAttrPrettyNameOp::getAsmResultNames(
435 function_ref<void(Value, StringRef)> setNameFn) {
436
437 auto value = getNames();
438 for (size_t i = 0, e = value.size(); i != e; ++i)
439 if (auto str = dyn_cast<StringAttr>(value[i]))
440 if (!str.getValue().empty())
441 setNameFn(getResult(i), str.getValue());
442}
443
444//===----------------------------------------------------------------------===//
445// CustomResultsNameOp
446//===----------------------------------------------------------------------===//
447
448void CustomResultsNameOp::getAsmResultNames(
449 function_ref<void(Value, StringRef)> setNameFn) {
450 ArrayAttr value = getNames();
451 for (size_t i = 0, e = value.size(); i != e; ++i)
452 if (auto str = dyn_cast<StringAttr>(value[i]))
453 if (!str.empty())
454 setNameFn(getResult(i), str.getValue());
455}
456
457//===----------------------------------------------------------------------===//
458// ResultTypeWithTraitOp
459//===----------------------------------------------------------------------===//
460
461LogicalResult ResultTypeWithTraitOp::verify() {
462 if ((*this)->getResultTypes()[0].hasTrait<TypeTrait::TestTypeTrait>())
463 return success();
464 return emitError("result type should have trait 'TestTypeTrait'");
465}
466
467//===----------------------------------------------------------------------===//
468// AttrWithTraitOp
469//===----------------------------------------------------------------------===//
470
471LogicalResult AttrWithTraitOp::verify() {
472 if (getAttr().hasTrait<AttributeTrait::TestAttrTrait>())
473 return success();
474 return emitError("'attr' attribute should have trait 'TestAttrTrait'");
475}
476
477//===----------------------------------------------------------------------===//
478// RegionIfOp
479//===----------------------------------------------------------------------===//
480
481void RegionIfOp::print(OpAsmPrinter &p) {
482 p << " ";
483 p.printOperands(getOperands());
484 p << ": " << getOperandTypes();
485 p.printArrowTypeList(getResultTypes());
486 p << " then ";
487 p.printRegion(getThenRegion(),
488 /*printEntryBlockArgs=*/true,
489 /*printBlockTerminators=*/true);
490 p << " else ";
491 p.printRegion(getElseRegion(),
492 /*printEntryBlockArgs=*/true,
493 /*printBlockTerminators=*/true);
494 p << " join ";
495 p.printRegion(getJoinRegion(),
496 /*printEntryBlockArgs=*/true,
497 /*printBlockTerminators=*/true);
498}
499
500ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) {
501 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfos;
502 SmallVector<Type, 2> operandTypes;
503
504 result.regions.reserve(3);
505 Region *thenRegion = result.addRegion();
506 Region *elseRegion = result.addRegion();
507 Region *joinRegion = result.addRegion();
508
509 // Parse operand, type and arrow type lists.
510 if (parser.parseOperandList(operandInfos) ||
511 parser.parseColonTypeList(operandTypes) ||
512 parser.parseArrowTypeList(result.types))
513 return failure();
514
515 // Parse all attached regions.
516 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
517 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
518 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
519 return failure();
520
521 return parser.resolveOperands(operandInfos, operandTypes,
522 parser.getCurrentLocation(), result.operands);
523}
524
525OperandRange RegionIfOp::getEntrySuccessorOperands(RegionBranchPoint point) {
526 assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, point) &&
527 "invalid region index");
528 return getOperands();
529}
530
531void RegionIfOp::getSuccessorRegions(
532 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
533 // We always branch to the join region.
534 if (!point.isParent()) {
535 if (point != getJoinRegion())
536 regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
537 else
538 regions.push_back(RegionSuccessor(getResults()));
539 return;
540 }
541
542 // The then and else regions are the entry regions of this op.
543 regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs()));
544 regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs()));
545}
546
547void RegionIfOp::getRegionInvocationBounds(
548 ArrayRef<Attribute> operands,
549 SmallVectorImpl<InvocationBounds> &invocationBounds) {
550 // Each region is invoked at most once.
551 invocationBounds.assign(/*NumElts=*/3, /*Elt=*/{0, 1});
552}
553
554//===----------------------------------------------------------------------===//
555// AnyCondOp
556//===----------------------------------------------------------------------===//
557
558void AnyCondOp::getSuccessorRegions(RegionBranchPoint point,
559 SmallVectorImpl<RegionSuccessor> &regions) {
560 // The parent op branches into the only region, and the region branches back
561 // to the parent op.
562 if (point.isParent())
563 regions.emplace_back(&getRegion());
564 else
565 regions.emplace_back(getResults());
566}
567
568void AnyCondOp::getRegionInvocationBounds(
569 ArrayRef<Attribute> operands,
570 SmallVectorImpl<InvocationBounds> &invocationBounds) {
571 invocationBounds.emplace_back(1, 1);
572}
573
574//===----------------------------------------------------------------------===//
575// SingleBlockImplicitTerminatorOp
576//===----------------------------------------------------------------------===//
577
578/// Testing the correctness of some traits.
579static_assert(
580 llvm::is_detected<OpTrait::has_implicit_terminator_t,
581 SingleBlockImplicitTerminatorOp>::value,
582 "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
583static_assert(OpTrait::hasSingleBlockImplicitTerminator<
584 SingleBlockImplicitTerminatorOp>::value,
585 "hasSingleBlockImplicitTerminator does not match "
586 "SingleBlockImplicitTerminatorOp");
587
588//===----------------------------------------------------------------------===//
589// SingleNoTerminatorCustomAsmOp
590//===----------------------------------------------------------------------===//
591
592ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser,
593 OperationState &state) {
594 Region *body = state.addRegion();
595 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
596 return failure();
597 return success();
598}
599
600void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) {
601 printer.printRegion(
602 getRegion(), /*printEntryBlockArgs=*/false,
603 // This op has a single block without terminators. But explicitly mark
604 // as not printing block terminators for testing.
605 /*printBlockTerminators=*/false);
606}
607
608//===----------------------------------------------------------------------===//
609// TestVerifiersOp
610//===----------------------------------------------------------------------===//
611
612LogicalResult TestVerifiersOp::verify() {
613 if (!getRegion().hasOneBlock())
614 return emitOpError("`hasOneBlock` trait hasn't been verified");
615
616 Operation *definingOp = getInput().getDefiningOp();
617 if (definingOp && failed(mlir::verify(definingOp)))
618 return emitOpError("operand hasn't been verified");
619
620 // Avoid using `emitRemark(msg)` since that will trigger an infinite verifier
621 // loop.
622 mlir::emitRemark(getLoc(), "success run of verifier");
623
624 return success();
625}
626
627LogicalResult TestVerifiersOp::verifyRegions() {
628 if (!getRegion().hasOneBlock())
629 return emitOpError("`hasOneBlock` trait hasn't been verified");
630
631 for (Block &block : getRegion())
632 for (Operation &op : block)
633 if (failed(mlir::verify(&op)))
634 return emitOpError("nested op hasn't been verified");
635
636 // Avoid using `emitRemark(msg)` since that will trigger an infinite verifier
637 // loop.
638 mlir::emitRemark(getLoc(), "success run of region verifier");
639
640 return success();
641}
642
643//===----------------------------------------------------------------------===//
644// Test InferIntRangeInterface
645//===----------------------------------------------------------------------===//
646
647//===----------------------------------------------------------------------===//
648// TestWithBoundsOp
649
650void TestWithBoundsOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
651 SetIntRangeFn setResultRanges) {
652 setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()});
653}
654
655//===----------------------------------------------------------------------===//
656// TestWithBoundsRegionOp
657
658ParseResult TestWithBoundsRegionOp::parse(OpAsmParser &parser,
659 OperationState &result) {
660 if (parser.parseOptionalAttrDict(result.attributes))
661 return failure();
662
663 // Parse the input argument
664 OpAsmParser::Argument argInfo;
665 argInfo.type = parser.getBuilder().getIndexType();
666 if (failed(parser.parseArgument(argInfo)))
667 return failure();
668
669 // Parse the body region, and reuse the operand info as the argument info.
670 Region *body = result.addRegion();
671 return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/false);
672}
673
674void TestWithBoundsRegionOp::print(OpAsmPrinter &p) {
675 p.printOptionalAttrDict((*this)->getAttrs());
676 p << ' ';
677 p.printRegionArgument(getRegion().getArgument(0), /*argAttrs=*/{},
678 /*omitType=*/true);
679 p << ' ';
680 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
681}
682
683void TestWithBoundsRegionOp::inferResultRanges(
684 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
685 Value arg = getRegion().getArgument(0);
686 setResultRanges(arg, {getUmin(), getUmax(), getSmin(), getSmax()});
687}
688
689//===----------------------------------------------------------------------===//
690// TestIncrementOp
691
692void TestIncrementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
693 SetIntRangeFn setResultRanges) {
694 const ConstantIntRanges &range = argRanges[0];
695 APInt one(range.umin().getBitWidth(), 1);
696 setResultRanges(getResult(),
697 {range.umin().uadd_sat(one), range.umax().uadd_sat(one),
698 range.smin().sadd_sat(one), range.smax().sadd_sat(one)});
699}
700
701//===----------------------------------------------------------------------===//
702// TestReflectBoundsOp
703
704void TestReflectBoundsOp::inferResultRanges(
705 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
706 const ConstantIntRanges &range = argRanges[0];
707 MLIRContext *ctx = getContext();
708 Builder b(ctx);
709 setUminAttr(b.getIndexAttr(range.umin().getZExtValue()));
710 setUmaxAttr(b.getIndexAttr(range.umax().getZExtValue()));
711 setSminAttr(b.getIndexAttr(range.smin().getSExtValue()));
712 setSmaxAttr(b.getIndexAttr(range.smax().getSExtValue()));
713 setResultRanges(getResult(), range);
714}
715
716//===----------------------------------------------------------------------===//
717// ConversionFuncOp
718//===----------------------------------------------------------------------===//
719
720ParseResult ConversionFuncOp::parse(OpAsmParser &parser,
721 OperationState &result) {
722 auto buildFuncType =
723 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
724 function_interface_impl::VariadicFlag,
725 std::string &) { return builder.getFunctionType(argTypes, results); };
726
727 return function_interface_impl::parseFunctionOp(
728 parser, result, /*allowVariadic=*/false,
729 getFunctionTypeAttrName(result.name), buildFuncType,
730 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
731}
732
733void ConversionFuncOp::print(OpAsmPrinter &p) {
734 function_interface_impl::printFunctionOp(
735 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
736 getArgAttrsAttrName(), getResAttrsAttrName());
737}
738
739//===----------------------------------------------------------------------===//
740// ReifyBoundOp
741//===----------------------------------------------------------------------===//
742
743mlir::presburger::BoundType ReifyBoundOp::getBoundType() {
744 if (getType() == "EQ")
745 return mlir::presburger::BoundType::EQ;
746 if (getType() == "LB")
747 return mlir::presburger::BoundType::LB;
748 if (getType() == "UB")
749 return mlir::presburger::BoundType::UB;
750 llvm_unreachable("invalid bound type");
751}
752
753LogicalResult ReifyBoundOp::verify() {
754 if (isa<ShapedType>(getVar().getType())) {
755 if (!getDim().has_value())
756 return emitOpError("expected 'dim' attribute for shaped type variable");
757 } else if (getVar().getType().isIndex()) {
758 if (getDim().has_value())
759 return emitOpError("unexpected 'dim' attribute for index variable");
760 } else {
761 return emitOpError("expected index-typed variable or shape type variable");
762 }
763 if (getConstant() && getScalable())
764 return emitOpError("'scalable' and 'constant' are mutually exlusive");
765 if (getScalable() != getVscaleMin().has_value())
766 return emitOpError("expected 'vscale_min' if and only if 'scalable'");
767 if (getScalable() != getVscaleMax().has_value())
768 return emitOpError("expected 'vscale_min' if and only if 'scalable'");
769 return success();
770}
771
772ValueBoundsConstraintSet::Variable ReifyBoundOp::getVariable() {
773 if (getDim().has_value())
774 return ValueBoundsConstraintSet::Variable(getVar(), *getDim());
775 return ValueBoundsConstraintSet::Variable(getVar());
776}
777
778//===----------------------------------------------------------------------===//
779// CompareOp
780//===----------------------------------------------------------------------===//
781
782ValueBoundsConstraintSet::ComparisonOperator
783CompareOp::getComparisonOperator() {
784 if (getCmp() == "EQ")
785 return ValueBoundsConstraintSet::ComparisonOperator::EQ;
786 if (getCmp() == "LT")
787 return ValueBoundsConstraintSet::ComparisonOperator::LT;
788 if (getCmp() == "LE")
789 return ValueBoundsConstraintSet::ComparisonOperator::LE;
790 if (getCmp() == "GT")
791 return ValueBoundsConstraintSet::ComparisonOperator::GT;
792 if (getCmp() == "GE")
793 return ValueBoundsConstraintSet::ComparisonOperator::GE;
794 llvm_unreachable("invalid comparison operator");
795}
796
797mlir::ValueBoundsConstraintSet::Variable CompareOp::getLhs() {
798 if (!getLhsMap())
799 return ValueBoundsConstraintSet::Variable(getVarOperands()[0]);
800 SmallVector<Value> mapOperands(
801 getVarOperands().slice(0, getLhsMap()->getNumInputs()));
802 return ValueBoundsConstraintSet::Variable(*getLhsMap(), mapOperands);
803}
804
805mlir::ValueBoundsConstraintSet::Variable CompareOp::getRhs() {
806 int64_t rhsOperandsBegin = getLhsMap() ? getLhsMap()->getNumInputs() : 1;
807 if (!getRhsMap())
808 return ValueBoundsConstraintSet::Variable(
809 getVarOperands()[rhsOperandsBegin]);
810 SmallVector<Value> mapOperands(
811 getVarOperands().slice(rhsOperandsBegin, getRhsMap()->getNumInputs()));
812 return ValueBoundsConstraintSet::Variable(*getRhsMap(), mapOperands);
813}
814
815LogicalResult CompareOp::verify() {
816 if (getCompose() && (getLhsMap() || getRhsMap()))
817 return emitOpError(
818 "'compose' not supported when 'lhs_map' or 'rhs_map' is present");
819 int64_t expectedNumOperands = getLhsMap() ? getLhsMap()->getNumInputs() : 1;
820 expectedNumOperands += getRhsMap() ? getRhsMap()->getNumInputs() : 1;
821 if (getVarOperands().size() != size_t(expectedNumOperands))
822 return emitOpError("expected ")
823 << expectedNumOperands << " operands, but got "
824 << getVarOperands().size();
825 return success();
826}
827
828//===----------------------------------------------------------------------===//
829// TestOpInPlaceSelfFold
830//===----------------------------------------------------------------------===//
831
832OpFoldResult TestOpInPlaceSelfFold::fold(FoldAdaptor adaptor) {
833 if (!getFolded()) {
834 // The folder adds the "folded" if not present.
835 setFolded(true);
836 return getResult();
837 }
838 return {};
839}
840
841//===----------------------------------------------------------------------===//
842// TestOpFoldWithFoldAdaptor
843//===----------------------------------------------------------------------===//
844
845OpFoldResult TestOpFoldWithFoldAdaptor::fold(FoldAdaptor adaptor) {
846 int64_t sum = 0;
847 if (auto value = dyn_cast_or_null<IntegerAttr>(adaptor.getOp()))
848 sum += value.getValue().getSExtValue();
849
850 for (Attribute attr : adaptor.getVariadic())
851 if (auto value = dyn_cast_or_null<IntegerAttr>(attr))
852 sum += 2 * value.getValue().getSExtValue();
853
854 for (ArrayRef<Attribute> attrs : adaptor.getVarOfVar())
855 for (Attribute attr : attrs)
856 if (auto value = dyn_cast_or_null<IntegerAttr>(attr))
857 sum += 3 * value.getValue().getSExtValue();
858
859 sum += 4 * std::distance(adaptor.getBody().begin(), adaptor.getBody().end());
860
861 return IntegerAttr::get(getType(), sum);
862}
863
864//===----------------------------------------------------------------------===//
865// OpWithInferTypeAdaptorInterfaceOp
866//===----------------------------------------------------------------------===//
867
868LogicalResult OpWithInferTypeAdaptorInterfaceOp::inferReturnTypes(
869 MLIRContext *, std::optional<Location> location,
870 OpWithInferTypeAdaptorInterfaceOp::Adaptor adaptor,
871 SmallVectorImpl<Type> &inferredReturnTypes) {
872 if (adaptor.getX().getType() != adaptor.getY().getType()) {
873 return emitOptionalError(location, "operand type mismatch ",
874 adaptor.getX().getType(), " vs ",
875 adaptor.getY().getType());
876 }
877 inferredReturnTypes.assign({adaptor.getX().getType()});
878 return success();
879}
880
881//===----------------------------------------------------------------------===//
882// OpWithRefineTypeInterfaceOp
883//===----------------------------------------------------------------------===//
884
885// TODO: We should be able to only define either inferReturnType or
886// refineReturnType, currently only refineReturnType can be omitted.
887LogicalResult OpWithRefineTypeInterfaceOp::inferReturnTypes(
888 MLIRContext *context, std::optional<Location> location, ValueRange operands,
889 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
890 SmallVectorImpl<Type> &returnTypes) {
891 returnTypes.clear();
892 return OpWithRefineTypeInterfaceOp::refineReturnTypes(
893 context, location, operands, attributes, properties, regions,
894 returnTypes);
895}
896
897LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes(
898 MLIRContext *, std::optional<Location> location, ValueRange operands,
899 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
900 SmallVectorImpl<Type> &returnTypes) {
901 if (operands[0].getType() != operands[1].getType()) {
902 return emitOptionalError(location, "operand type mismatch ",
903 operands[0].getType(), " vs ",
904 operands[1].getType());
905 }
906 // TODO: Add helper to make this more concise to write.
907 if (returnTypes.empty())
908 returnTypes.resize(1, nullptr);
909 if (returnTypes[0] && returnTypes[0] != operands[0].getType())
910 return emitOptionalError(location,
911 "required first operand and result to match");
912 returnTypes[0] = operands[0].getType();
913 return success();
914}
915
916//===----------------------------------------------------------------------===//
917// OpWithShapedTypeInferTypeAdaptorInterfaceOp
918//===----------------------------------------------------------------------===//
919
920LogicalResult
921OpWithShapedTypeInferTypeAdaptorInterfaceOp::inferReturnTypeComponents(
922 MLIRContext *context, std::optional<Location> location,
923 OpWithShapedTypeInferTypeAdaptorInterfaceOp::Adaptor adaptor,
924 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
925 // Create return type consisting of the last element of the first operand.
926 auto operandType = adaptor.getOperand1().getType();
927 auto sval = dyn_cast<ShapedType>(operandType);
928 if (!sval)
929 return emitOptionalError(location, "only shaped type operands allowed");
930 int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic;
931 auto type = IntegerType::get(context, 17);
932
933 Attribute encoding;
934 if (auto rankedTy = dyn_cast<RankedTensorType>(sval))
935 encoding = rankedTy.getEncoding();
936 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding));
937 return success();
938}
939
940LogicalResult
941OpWithShapedTypeInferTypeAdaptorInterfaceOp::reifyReturnTypeShapes(
942 OpBuilder &builder, ValueRange operands,
943 llvm::SmallVectorImpl<Value> &shapes) {
944 shapes = SmallVector<Value, 1>{
945 builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
946 return success();
947}
948
949//===----------------------------------------------------------------------===//
950// TestOpWithPropertiesAndInferredType
951//===----------------------------------------------------------------------===//
952
953LogicalResult TestOpWithPropertiesAndInferredType::inferReturnTypes(
954 MLIRContext *context, std::optional<Location>, ValueRange operands,
955 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
956 SmallVectorImpl<Type> &inferredReturnTypes) {
957
958 Adaptor adaptor(operands, attributes, properties, regions);
959 inferredReturnTypes.push_back(IntegerType::get(
960 context, adaptor.getLhs() + adaptor.getProperties().rhs));
961 return success();
962}
963
964//===----------------------------------------------------------------------===//
965// LoopBlockOp
966//===----------------------------------------------------------------------===//
967
968void LoopBlockOp::getSuccessorRegions(
969 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
970 regions.emplace_back(&getBody(), getBody().getArguments());
971 if (point.isParent())
972 return;
973
974 regions.emplace_back((*this)->getResults());
975}
976
977OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) {
978 assert(point == getBody());
979 return MutableOperandRange(getInitMutable());
980}
981
982//===----------------------------------------------------------------------===//
983// LoopBlockTerminatorOp
984//===----------------------------------------------------------------------===//
985
986MutableOperandRange
987LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point) {
988 if (point.isParent())
989 return getExitArgMutable();
990 return getNextIterArgMutable();
991}
992
993//===----------------------------------------------------------------------===//
994// SwitchWithNoBreakOp
995//===----------------------------------------------------------------------===//
996
997void TestNoTerminatorOp::getSuccessorRegions(
998 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {}
999
1000//===----------------------------------------------------------------------===//
// ManualCppOpWithFold
1002//===----------------------------------------------------------------------===//
1003
1004OpFoldResult ManualCppOpWithFold::fold(ArrayRef<Attribute> attributes) {
1005 // Just a simple fold for testing purposes that reads an operands constant
1006 // value and returns it.
1007 if (!attributes.empty())
1008 return attributes.front();
1009 return nullptr;
1010}
1011
1012//===----------------------------------------------------------------------===//
1013// Tensor/Buffer Ops
1014//===----------------------------------------------------------------------===//
1015
1016void ReadBufferOp::getEffects(
1017 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1018 &effects) {
1019 // The buffer operand is read.
1020 effects.emplace_back(MemoryEffects::Read::get(), getBuffer(),
1021 SideEffects::DefaultResource::get());
1022 // The buffer contents are dumped.
1023 effects.emplace_back(MemoryEffects::Write::get(),
1024 SideEffects::DefaultResource::get());
1025}
1026
1027//===----------------------------------------------------------------------===//
1028// Test Dataflow
1029//===----------------------------------------------------------------------===//
1030
1031//===----------------------------------------------------------------------===//
1032// TestCallAndStoreOp
1033
1034CallInterfaceCallable TestCallAndStoreOp::getCallableForCallee() {
1035 return getCallee();
1036}
1037
1038void TestCallAndStoreOp::setCalleeFromCallable(CallInterfaceCallable callee) {
1039 setCalleeAttr(callee.get<SymbolRefAttr>());
1040}
1041
1042Operation::operand_range TestCallAndStoreOp::getArgOperands() {
1043 return getCalleeOperands();
1044}
1045
1046MutableOperandRange TestCallAndStoreOp::getArgOperandsMutable() {
1047 return getCalleeOperandsMutable();
1048}
1049
1050//===----------------------------------------------------------------------===//
1051// TestCallOnDeviceOp
1052
1053CallInterfaceCallable TestCallOnDeviceOp::getCallableForCallee() {
1054 return getCallee();
1055}
1056
1057void TestCallOnDeviceOp::setCalleeFromCallable(CallInterfaceCallable callee) {
1058 setCalleeAttr(callee.get<SymbolRefAttr>());
1059}
1060
1061Operation::operand_range TestCallOnDeviceOp::getArgOperands() {
1062 return getForwardedOperands();
1063}
1064
1065MutableOperandRange TestCallOnDeviceOp::getArgOperandsMutable() {
1066 return getForwardedOperandsMutable();
1067}
1068
1069//===----------------------------------------------------------------------===//
1070// TestStoreWithARegion
1071
1072void TestStoreWithARegion::getSuccessorRegions(
1073 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1074 if (point.isParent())
1075 regions.emplace_back(&getBody(), getBody().front().getArguments());
1076 else
1077 regions.emplace_back();
1078}
1079
1080//===----------------------------------------------------------------------===//
1081// TestStoreWithALoopRegion
1082
1083void TestStoreWithALoopRegion::getSuccessorRegions(
1084 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1085 // Both the operation itself and the region may be branching into the body or
1086 // back into the operation itself. It is possible for the operation not to
1087 // enter the body.
1088 regions.emplace_back(
1089 RegionSuccessor(&getBody(), getBody().front().getArguments()));
1090 regions.emplace_back();
1091}
1092
1093//===----------------------------------------------------------------------===//
1094// TestVersionedOpA
1095//===----------------------------------------------------------------------===//
1096
1097LogicalResult
1098TestVersionedOpA::readProperties(mlir::DialectBytecodeReader &reader,
1099 mlir::OperationState &state) {
1100 auto &prop = state.getOrAddProperties<Properties>();
1101 if (mlir::failed(reader.readAttribute(prop.dims)))
1102 return mlir::failure();
1103
1104 // Check if we have a version. If not, assume we are parsing the current
1105 // version.
1106 auto maybeVersion = reader.getDialectVersion<test::TestDialect>();
1107 if (succeeded(maybeVersion)) {
1108 // If version is less than 2.0, there is no additional attribute to parse.
1109 // We can materialize missing properties post parsing before verification.
1110 const auto *version =
1111 reinterpret_cast<const TestDialectVersion *>(*maybeVersion);
1112 if ((version->major_ < 2)) {
1113 return success();
1114 }
1115 }
1116
1117 if (mlir::failed(reader.readAttribute(prop.modifier)))
1118 return mlir::failure();
1119 return mlir::success();
1120}
1121
1122void TestVersionedOpA::writeProperties(mlir::DialectBytecodeWriter &writer) {
1123 auto &prop = getProperties();
1124 writer.writeAttribute(prop.dims);
1125
1126 auto maybeVersion = writer.getDialectVersion<test::TestDialect>();
1127 if (succeeded(maybeVersion)) {
1128 // If version is less than 2.0, there is no additional attribute to write.
1129 const auto *version =
1130 reinterpret_cast<const TestDialectVersion *>(*maybeVersion);
1131 if ((version->major_ < 2)) {
1132 llvm::outs() << "downgrading op properties...\n";
1133 return;
1134 }
1135 }
1136 writer.writeAttribute(prop.modifier);
1137}
1138
1139//===----------------------------------------------------------------------===//
1140// TestOpWithVersionedProperties
1141//===----------------------------------------------------------------------===//
1142
1143mlir::LogicalResult TestOpWithVersionedProperties::readFromMlirBytecode(
1144 mlir::DialectBytecodeReader &reader, test::VersionedProperties &prop) {
1145 uint64_t value1, value2 = 0;
1146 if (failed(reader.readVarInt(value1)))
1147 return failure();
1148
1149 // Check if we have a version. If not, assume we are parsing the current
1150 // version.
1151 auto maybeVersion = reader.getDialectVersion<test::TestDialect>();
1152 bool needToParseAnotherInt = true;
1153 if (succeeded(maybeVersion)) {
1154 // If version is less than 2.0, there is no additional attribute to parse.
1155 // We can materialize missing properties post parsing before verification.
1156 const auto *version =
1157 reinterpret_cast<const TestDialectVersion *>(*maybeVersion);
1158 if ((version->major_ < 2))
1159 needToParseAnotherInt = false;
1160 }
1161 if (needToParseAnotherInt && failed(reader.readVarInt(value2)))
1162 return failure();
1163
1164 prop.value1 = value1;
1165 prop.value2 = value2;
1166 return success();
1167}
1168
1169void TestOpWithVersionedProperties::writeToMlirBytecode(
1170 mlir::DialectBytecodeWriter &writer,
1171 const test::VersionedProperties &prop) {
1172 writer.writeVarInt(prop.value1);
1173 writer.writeVarInt(prop.value2);
1174}
1175

source code of mlir/test/lib/Dialect/Test/TestOpDefs.cpp