1//===- TestOpDefs.cpp - MLIR Test Dialect Operation Hooks -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "TestDialect.h"
10#include "TestOps.h"
11#include "mlir/Dialect/Tensor/IR/Tensor.h"
12#include "mlir/IR/Verifier.h"
13#include "mlir/Interfaces/FunctionImplementation.h"
14#include "mlir/Interfaces/MemorySlotInterfaces.h"
15
16using namespace mlir;
17using namespace test;
18
19//===----------------------------------------------------------------------===//
20// TestBranchOp
21//===----------------------------------------------------------------------===//
22
23SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) {
24 assert(index == 0 && "invalid successor index");
25 return SuccessorOperands(getTargetOperandsMutable());
26}
27
28//===----------------------------------------------------------------------===//
29// TestProducingBranchOp
30//===----------------------------------------------------------------------===//
31
32SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) {
33 assert(index <= 1 && "invalid successor index");
34 if (index == 1)
35 return SuccessorOperands(getFirstOperandsMutable());
36 return SuccessorOperands(getSecondOperandsMutable());
37}
38
39//===----------------------------------------------------------------------===//
40// TestInternalBranchOp
41//===----------------------------------------------------------------------===//
42
43SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) {
44 assert(index <= 1 && "invalid successor index");
45 if (index == 0)
46 return SuccessorOperands(0, getSuccessOperandsMutable());
47 return SuccessorOperands(1, getErrorOperandsMutable());
48}
49
50//===----------------------------------------------------------------------===//
51// TestCallOp
52//===----------------------------------------------------------------------===//
53
54LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
55 // Check that the callee attribute was specified.
56 auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
57 if (!fnAttr)
58 return emitOpError("requires a 'callee' symbol reference attribute");
59 if (!symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(*this, fnAttr))
60 return emitOpError() << "'" << fnAttr.getValue()
61 << "' does not reference a valid function";
62 return success();
63}
64
65//===----------------------------------------------------------------------===//
66// FoldToCallOp
67//===----------------------------------------------------------------------===//
68
69namespace {
70struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
71 using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
72
73 LogicalResult matchAndRewrite(FoldToCallOp op,
74 PatternRewriter &rewriter) const override {
75 rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(),
76 op.getCalleeAttr(), ValueRange());
77 return success();
78 }
79};
80} // namespace
81
82void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
83 MLIRContext *context) {
84 results.add<FoldToCallOpPattern>(context);
85}
86
87//===----------------------------------------------------------------------===//
88// IsolatedRegionOp - test parsing passthrough operands
89//===----------------------------------------------------------------------===//
90
91ParseResult IsolatedRegionOp::parse(OpAsmParser &parser,
92 OperationState &result) {
93 // Parse the input operand.
94 OpAsmParser::Argument argInfo;
95 argInfo.type = parser.getBuilder().getIndexType();
96 if (parser.parseOperand(argInfo.ssaName) ||
97 parser.resolveOperand(argInfo.ssaName, argInfo.type, result.operands))
98 return failure();
99
100 // Parse the body region, and reuse the operand info as the argument info.
101 Region *body = result.addRegion();
102 return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/true);
103}
104
105void IsolatedRegionOp::print(OpAsmPrinter &p) {
106 p << ' ';
107 p.printOperand(getOperand());
108 p.shadowRegionArgs(getRegion(), getOperand());
109 p << ' ';
110 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
111}
112
113//===----------------------------------------------------------------------===//
114// SSACFGRegionOp
115//===----------------------------------------------------------------------===//
116
117RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
118 return RegionKind::SSACFG;
119}
120
121//===----------------------------------------------------------------------===//
122// GraphRegionOp
123//===----------------------------------------------------------------------===//
124
125RegionKind GraphRegionOp::getRegionKind(unsigned index) {
126 return RegionKind::Graph;
127}
128
129//===----------------------------------------------------------------------===//
130// IsolatedGraphRegionOp
131//===----------------------------------------------------------------------===//
132
133RegionKind IsolatedGraphRegionOp::getRegionKind(unsigned index) {
134 return RegionKind::Graph;
135}
136
137//===----------------------------------------------------------------------===//
138// AffineScopeOp
139//===----------------------------------------------------------------------===//
140
141ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) {
142 // Parse the body region, and reuse the operand info as the argument info.
143 Region *body = result.addRegion();
144 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
145}
146
147void AffineScopeOp::print(OpAsmPrinter &p) {
148 p << " ";
149 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
150}
151
152//===----------------------------------------------------------------------===//
153// TestRemoveOpWithInnerOps
154//===----------------------------------------------------------------------===//
155
156namespace {
157struct TestRemoveOpWithInnerOps
158 : public OpRewritePattern<TestOpWithRegionPattern> {
159 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
160
161 void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
162
163 LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
164 PatternRewriter &rewriter) const override {
165 rewriter.eraseOp(op: op);
166 return success();
167 }
168};
169} // namespace
170
171//===----------------------------------------------------------------------===//
172// TestOpWithRegionPattern
173//===----------------------------------------------------------------------===//
174
175void TestOpWithRegionPattern::getCanonicalizationPatterns(
176 RewritePatternSet &results, MLIRContext *context) {
177 results.add<TestRemoveOpWithInnerOps>(context);
178}
179
180//===----------------------------------------------------------------------===//
181// TestOpWithRegionFold
182//===----------------------------------------------------------------------===//
183
184OpFoldResult TestOpWithRegionFold::fold(FoldAdaptor adaptor) {
185 return getOperand();
186}
187
188//===----------------------------------------------------------------------===//
189// TestOpConstant
190//===----------------------------------------------------------------------===//
191
192OpFoldResult TestOpConstant::fold(FoldAdaptor adaptor) { return getValue(); }
193
194//===----------------------------------------------------------------------===//
195// TestOpWithVariadicResultsAndFolder
196//===----------------------------------------------------------------------===//
197
198LogicalResult TestOpWithVariadicResultsAndFolder::fold(
199 FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult> &results) {
200 for (Value input : this->getOperands()) {
201 results.push_back(input);
202 }
203 return success();
204}
205
206//===----------------------------------------------------------------------===//
207// TestOpInPlaceFold
208//===----------------------------------------------------------------------===//
209
210OpFoldResult TestOpInPlaceFold::fold(FoldAdaptor adaptor) {
211 // Exercise the fact that an operation created with createOrFold should be
212 // allowed to access its parent block.
213 assert(getOperation()->getBlock() &&
214 "expected that operation is not unlinked");
215
216 if (adaptor.getOp() && !getProperties().attr) {
217 // The folder adds "attr" if not present.
218 getProperties().attr = dyn_cast_or_null<IntegerAttr>(adaptor.getOp());
219 return getResult();
220 }
221 return {};
222}
223
224//===----------------------------------------------------------------------===//
225// OpWithInferTypeInterfaceOp
226//===----------------------------------------------------------------------===//
227
228LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
229 MLIRContext *, std::optional<Location> location, ValueRange operands,
230 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
231 SmallVectorImpl<Type> &inferredReturnTypes) {
232 if (operands[0].getType() != operands[1].getType()) {
233 return emitOptionalError(location, "operand type mismatch ",
234 operands[0].getType(), " vs ",
235 operands[1].getType());
236 }
237 inferredReturnTypes.assign({operands[0].getType()});
238 return success();
239}
240
241//===----------------------------------------------------------------------===//
242// OpWithShapedTypeInferTypeInterfaceOp
243//===----------------------------------------------------------------------===//
244
245LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
246 MLIRContext *context, std::optional<Location> location,
247 ValueShapeRange operands, DictionaryAttr attributes,
248 OpaqueProperties properties, RegionRange regions,
249 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
250 // Create return type consisting of the last element of the first operand.
251 auto operandType = operands.front().getType();
252 auto sval = dyn_cast<ShapedType>(operandType);
253 if (!sval)
254 return emitOptionalError(location, "only shaped type operands allowed");
255 int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic;
256 auto type = IntegerType::get(context, 17);
257
258 Attribute encoding;
259 if (auto rankedTy = dyn_cast<RankedTensorType>(sval))
260 encoding = rankedTy.getEncoding();
261 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding));
262 return success();
263}
264
265LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
266 OpBuilder &builder, ValueRange operands,
267 llvm::SmallVectorImpl<Value> &shapes) {
268 shapes = SmallVector<Value, 1>{
269 builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
270 return success();
271}
272
273//===----------------------------------------------------------------------===//
274// OpWithResultShapeInterfaceOp
275//===----------------------------------------------------------------------===//
276
277LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
278 OpBuilder &builder, ValueRange operands,
279 llvm::SmallVectorImpl<Value> &shapes) {
280 Location loc = getLoc();
281 shapes.reserve(operands.size());
282 for (Value operand : llvm::reverse(operands)) {
283 auto rank = cast<RankedTensorType>(operand.getType()).getRank();
284 auto currShape = llvm::to_vector<4>(
285 llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value {
286 return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
287 }));
288 shapes.push_back(builder.create<tensor::FromElementsOp>(
289 getLoc(), RankedTensorType::get({rank}, builder.getIndexType()),
290 currShape));
291 }
292 return success();
293}
294
295//===----------------------------------------------------------------------===//
296// OpWithResultShapePerDimInterfaceOp
297//===----------------------------------------------------------------------===//
298
299LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
300 OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
301 Location loc = getLoc();
302 shapes.reserve(getNumOperands());
303 for (Value operand : llvm::reverse(getOperands())) {
304 auto tensorType = cast<RankedTensorType>(operand.getType());
305 auto currShape = llvm::to_vector<4>(llvm::map_range(
306 llvm::seq<int64_t>(0, tensorType.getRank()),
307 [&](int64_t dim) -> OpFoldResult {
308 return tensorType.isDynamicDim(dim)
309 ? static_cast<OpFoldResult>(
310 builder.createOrFold<tensor::DimOp>(loc, operand,
311 dim))
312 : static_cast<OpFoldResult>(
313 builder.getIndexAttr(tensorType.getDimSize(dim)));
314 }));
315 shapes.emplace_back(std::move(currShape));
316 }
317 return success();
318}
319
320//===----------------------------------------------------------------------===//
321// SideEffectOp
322//===----------------------------------------------------------------------===//
323
324namespace {
325/// A test resource for side effects.
326struct TestResource : public SideEffects::Resource::Base<TestResource> {
327 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResource)
328
329 StringRef getName() final { return "<Test>"; }
330};
331} // namespace
332
333void SideEffectOp::getEffects(
334 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
335 // Check for an effects attribute on the op instance.
336 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
337 if (!effectsAttr)
338 return;
339
340 for (Attribute element : effectsAttr) {
341 DictionaryAttr effectElement = cast<DictionaryAttr>(element);
342
343 // Get the specific memory effect.
344 MemoryEffects::Effect *effect =
345 StringSwitch<MemoryEffects::Effect *>(
346 cast<StringAttr>(effectElement.get("effect")).getValue())
347 .Case("allocate", MemoryEffects::Allocate::get())
348 .Case("free", MemoryEffects::Free::get())
349 .Case("read", MemoryEffects::Read::get())
350 .Case("write", MemoryEffects::Write::get());
351
352 // Check for a non-default resource to use.
353 SideEffects::Resource *resource = SideEffects::DefaultResource::get();
354 if (effectElement.get("test_resource"))
355 resource = TestResource::get();
356
357 // Check for a result to affect.
358 if (effectElement.get("on_result"))
359 effects.emplace_back(effect, getOperation()->getOpResults()[0], resource);
360 else if (Attribute ref = effectElement.get("on_reference"))
361 effects.emplace_back(effect, cast<SymbolRefAttr>(ref), resource);
362 else
363 effects.emplace_back(effect, resource);
364 }
365}
366
367void SideEffectOp::getEffects(
368 SmallVectorImpl<TestEffects::EffectInstance> &effects) {
369 testSideEffectOpGetEffect(getOperation(), effects);
370}
371
372void SideEffectWithRegionOp::getEffects(
373 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
374 // Check for an effects attribute on the op instance.
375 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
376 if (!effectsAttr)
377 return;
378
379 for (Attribute element : effectsAttr) {
380 DictionaryAttr effectElement = cast<DictionaryAttr>(element);
381
382 // Get the specific memory effect.
383 MemoryEffects::Effect *effect =
384 StringSwitch<MemoryEffects::Effect *>(
385 cast<StringAttr>(effectElement.get("effect")).getValue())
386 .Case("allocate", MemoryEffects::Allocate::get())
387 .Case("free", MemoryEffects::Free::get())
388 .Case("read", MemoryEffects::Read::get())
389 .Case("write", MemoryEffects::Write::get());
390
391 // Check for a non-default resource to use.
392 SideEffects::Resource *resource = SideEffects::DefaultResource::get();
393 if (effectElement.get("test_resource"))
394 resource = TestResource::get();
395
396 // Check for a result to affect.
397 if (effectElement.get("on_result"))
398 effects.emplace_back(effect, getOperation()->getOpResults()[0], resource);
399 else if (effectElement.get("on_operand"))
400 effects.emplace_back(effect, &getOperation()->getOpOperands()[0],
401 resource);
402 else if (effectElement.get("on_argument"))
403 effects.emplace_back(effect, getOperation()->getRegion(0).getArgument(0),
404 resource);
405 else if (Attribute ref = effectElement.get("on_reference"))
406 effects.emplace_back(effect, cast<SymbolRefAttr>(ref), resource);
407 else
408 effects.emplace_back(effect, resource);
409 }
410}
411
412void SideEffectWithRegionOp::getEffects(
413 SmallVectorImpl<TestEffects::EffectInstance> &effects) {
414 testSideEffectOpGetEffect(getOperation(), effects);
415}
416
417//===----------------------------------------------------------------------===//
418// StringAttrPrettyNameOp
419//===----------------------------------------------------------------------===//
420
421// This op has fancy handling of its SSA result name.
422ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser,
423 OperationState &result) {
424 // Add the result types.
425 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
426 result.addTypes(parser.getBuilder().getIntegerType(32));
427
428 if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
429 return failure();
430
431 // If the attribute dictionary contains no 'names' attribute, infer it from
432 // the SSA name (if specified).
433 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
434 return attr.getName() == "names";
435 });
436
437 // If there was no name specified, check to see if there was a useful name
438 // specified in the asm file.
439 if (hadNames || parser.getNumResults() == 0)
440 return success();
441
442 SmallVector<StringRef, 4> names;
443 auto *context = result.getContext();
444
445 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
446 auto resultName = parser.getResultName(i);
447 StringRef nameStr;
448 if (!resultName.first.empty() && !isdigit(resultName.first[0]))
449 nameStr = resultName.first;
450
451 names.push_back(nameStr);
452 }
453
454 auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
455 result.attributes.push_back({StringAttr::get(context, "names"), namesAttr});
456 return success();
457}
458
459void StringAttrPrettyNameOp::print(OpAsmPrinter &p) {
460 // Note that we only need to print the "name" attribute if the asmprinter
461 // result name disagrees with it. This can happen in strange cases, e.g.
462 // when there are conflicts.
463 bool namesDisagree = getNames().size() != getNumResults();
464
465 SmallString<32> resultNameStr;
466 for (size_t i = 0, e = getNumResults(); i != e && !namesDisagree; ++i) {
467 resultNameStr.clear();
468 llvm::raw_svector_ostream tmpStream(resultNameStr);
469 p.printOperand(getResult(i), tmpStream);
470
471 auto expectedName = dyn_cast<StringAttr>(getNames()[i]);
472 if (!expectedName ||
473 tmpStream.str().drop_front() != expectedName.getValue()) {
474 namesDisagree = true;
475 }
476 }
477
478 if (namesDisagree)
479 p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
480 else
481 p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names"});
482}
483
484// We set the SSA name in the asm syntax to the contents of the name
485// attribute.
486void StringAttrPrettyNameOp::getAsmResultNames(
487 function_ref<void(Value, StringRef)> setNameFn) {
488
489 auto value = getNames();
490 for (size_t i = 0, e = value.size(); i != e; ++i)
491 if (auto str = dyn_cast<StringAttr>(value[i]))
492 if (!str.getValue().empty())
493 setNameFn(getResult(i), str.getValue());
494}
495
496//===----------------------------------------------------------------------===//
497// CustomResultsNameOp
498//===----------------------------------------------------------------------===//
499
500void CustomResultsNameOp::getAsmResultNames(
501 function_ref<void(Value, StringRef)> setNameFn) {
502 ArrayAttr value = getNames();
503 for (size_t i = 0, e = value.size(); i != e; ++i)
504 if (auto str = dyn_cast<StringAttr>(value[i]))
505 if (!str.empty())
506 setNameFn(getResult(i), str.getValue());
507}
508
509//===----------------------------------------------------------------------===//
510// ResultNameFromTypeOp
511//===----------------------------------------------------------------------===//
512
513void ResultNameFromTypeOp::getAsmResultNames(
514 function_ref<void(Value, StringRef)> setNameFn) {
515 auto result = getResult();
516 auto setResultNameFn = [&](::llvm::StringRef name) {
517 setNameFn(result, name);
518 };
519 auto opAsmTypeInterface =
520 ::mlir::cast<::mlir::OpAsmTypeInterface>(result.getType());
521 opAsmTypeInterface.getAsmName(setResultNameFn);
522}
523
524//===----------------------------------------------------------------------===//
525// BlockArgumentNameFromTypeOp
526//===----------------------------------------------------------------------===//
527
528void BlockArgumentNameFromTypeOp::getAsmBlockArgumentNames(
529 ::mlir::Region &region, ::mlir::OpAsmSetValueNameFn setNameFn) {
530 for (auto &block : region) {
531 for (auto arg : block.getArguments()) {
532 if (auto opAsmTypeInterface =
533 ::mlir::dyn_cast<::mlir::OpAsmTypeInterface>(arg.getType())) {
534 auto setArgNameFn = [&](StringRef name) { setNameFn(arg, name); };
535 opAsmTypeInterface.getAsmName(setArgNameFn);
536 }
537 }
538 }
539}
540
541//===----------------------------------------------------------------------===//
542// ResultTypeWithTraitOp
543//===----------------------------------------------------------------------===//
544
545LogicalResult ResultTypeWithTraitOp::verify() {
546 if ((*this)->getResultTypes()[0].hasTrait<TypeTrait::TestTypeTrait>())
547 return success();
548 return emitError("result type should have trait 'TestTypeTrait'");
549}
550
551//===----------------------------------------------------------------------===//
552// AttrWithTraitOp
553//===----------------------------------------------------------------------===//
554
555LogicalResult AttrWithTraitOp::verify() {
556 if (getAttr().hasTrait<AttributeTrait::TestAttrTrait>())
557 return success();
558 return emitError("'attr' attribute should have trait 'TestAttrTrait'");
559}
560
561//===----------------------------------------------------------------------===//
562// RegionIfOp
563//===----------------------------------------------------------------------===//
564
565void RegionIfOp::print(OpAsmPrinter &p) {
566 p << " ";
567 p.printOperands(getOperands());
568 p << ": " << getOperandTypes();
569 p.printArrowTypeList(getResultTypes());
570 p << " then ";
571 p.printRegion(getThenRegion(),
572 /*printEntryBlockArgs=*/true,
573 /*printBlockTerminators=*/true);
574 p << " else ";
575 p.printRegion(getElseRegion(),
576 /*printEntryBlockArgs=*/true,
577 /*printBlockTerminators=*/true);
578 p << " join ";
579 p.printRegion(getJoinRegion(),
580 /*printEntryBlockArgs=*/true,
581 /*printBlockTerminators=*/true);
582}
583
584ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) {
585 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfos;
586 SmallVector<Type, 2> operandTypes;
587
588 result.regions.reserve(3);
589 Region *thenRegion = result.addRegion();
590 Region *elseRegion = result.addRegion();
591 Region *joinRegion = result.addRegion();
592
593 // Parse operand, type and arrow type lists.
594 if (parser.parseOperandList(operandInfos) ||
595 parser.parseColonTypeList(operandTypes) ||
596 parser.parseArrowTypeList(result.types))
597 return failure();
598
599 // Parse all attached regions.
600 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
601 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
602 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
603 return failure();
604
605 return parser.resolveOperands(operandInfos, operandTypes,
606 parser.getCurrentLocation(), result.operands);
607}
608
609OperandRange RegionIfOp::getEntrySuccessorOperands(RegionBranchPoint point) {
610 assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, point) &&
611 "invalid region index");
612 return getOperands();
613}
614
615void RegionIfOp::getSuccessorRegions(
616 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
617 // We always branch to the join region.
618 if (!point.isParent()) {
619 if (point != getJoinRegion())
620 regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
621 else
622 regions.push_back(RegionSuccessor(getResults()));
623 return;
624 }
625
626 // The then and else regions are the entry regions of this op.
627 regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs()));
628 regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs()));
629}
630
631void RegionIfOp::getRegionInvocationBounds(
632 ArrayRef<Attribute> operands,
633 SmallVectorImpl<InvocationBounds> &invocationBounds) {
634 // Each region is invoked at most once.
635 invocationBounds.assign(/*NumElts=*/3, /*Elt=*/{0, 1});
636}
637
638//===----------------------------------------------------------------------===//
639// AnyCondOp
640//===----------------------------------------------------------------------===//
641
642void AnyCondOp::getSuccessorRegions(RegionBranchPoint point,
643 SmallVectorImpl<RegionSuccessor> &regions) {
644 // The parent op branches into the only region, and the region branches back
645 // to the parent op.
646 if (point.isParent())
647 regions.emplace_back(&getRegion());
648 else
649 regions.emplace_back(getResults());
650}
651
652void AnyCondOp::getRegionInvocationBounds(
653 ArrayRef<Attribute> operands,
654 SmallVectorImpl<InvocationBounds> &invocationBounds) {
655 invocationBounds.emplace_back(1, 1);
656}
657
658//===----------------------------------------------------------------------===//
659// SingleBlockImplicitTerminatorOp
660//===----------------------------------------------------------------------===//
661
662/// Testing the correctness of some traits.
663static_assert(
664 llvm::is_detected<OpTrait::has_implicit_terminator_t,
665 SingleBlockImplicitTerminatorOp>::value,
666 "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
667static_assert(OpTrait::hasSingleBlockImplicitTerminator<
668 SingleBlockImplicitTerminatorOp>::value,
669 "hasSingleBlockImplicitTerminator does not match "
670 "SingleBlockImplicitTerminatorOp");
671
672//===----------------------------------------------------------------------===//
673// SingleNoTerminatorCustomAsmOp
674//===----------------------------------------------------------------------===//
675
676ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser,
677 OperationState &state) {
678 Region *body = state.addRegion();
679 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
680 return failure();
681 return success();
682}
683
684void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) {
685 printer.printRegion(
686 getRegion(), /*printEntryBlockArgs=*/false,
687 // This op has a single block without terminators. But explicitly mark
688 // as not printing block terminators for testing.
689 /*printBlockTerminators=*/false);
690}
691
692//===----------------------------------------------------------------------===//
693// TestVerifiersOp
694//===----------------------------------------------------------------------===//
695
696LogicalResult TestVerifiersOp::verify() {
697 if (!getRegion().hasOneBlock())
698 return emitOpError("`hasOneBlock` trait hasn't been verified");
699
700 Operation *definingOp = getInput().getDefiningOp();
701 if (definingOp && failed(mlir::verify(definingOp)))
702 return emitOpError("operand hasn't been verified");
703
704 // Avoid using `emitRemark(msg)` since that will trigger an infinite verifier
705 // loop.
706 mlir::emitRemark(getLoc(), "success run of verifier");
707
708 return success();
709}
710
711LogicalResult TestVerifiersOp::verifyRegions() {
712 if (!getRegion().hasOneBlock())
713 return emitOpError("`hasOneBlock` trait hasn't been verified");
714
715 for (Block &block : getRegion())
716 for (Operation &op : block)
717 if (failed(mlir::verify(&op)))
718 return emitOpError("nested op hasn't been verified");
719
720 // Avoid using `emitRemark(msg)` since that will trigger an infinite verifier
721 // loop.
722 mlir::emitRemark(getLoc(), "success run of region verifier");
723
724 return success();
725}
726
727//===----------------------------------------------------------------------===//
728// Test InferIntRangeInterface
729//===----------------------------------------------------------------------===//
730
731//===----------------------------------------------------------------------===//
732// TestWithBoundsOp
733//===----------------------------------------------------------------------===//
734
735void TestWithBoundsOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
736 SetIntRangeFn setResultRanges) {
737 setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()});
738}
739
740//===----------------------------------------------------------------------===//
741// TestWithBoundsRegionOp
742//===----------------------------------------------------------------------===//
743
744ParseResult TestWithBoundsRegionOp::parse(OpAsmParser &parser,
745 OperationState &result) {
746 if (parser.parseOptionalAttrDict(result.attributes))
747 return failure();
748
749 // Parse the input argument
750 OpAsmParser::Argument argInfo;
751 if (failed(parser.parseArgument(argInfo, true)))
752 return failure();
753
754 // Parse the body region, and reuse the operand info as the argument info.
755 Region *body = result.addRegion();
756 return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/false);
757}
758
759void TestWithBoundsRegionOp::print(OpAsmPrinter &p) {
760 p.printOptionalAttrDict((*this)->getAttrs());
761 p << ' ';
762 p.printRegionArgument(getRegion().getArgument(0), /*argAttrs=*/{},
763 /*omitType=*/false);
764 p << ' ';
765 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
766}
767
768void TestWithBoundsRegionOp::inferResultRanges(
769 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
770 Value arg = getRegion().getArgument(0);
771 setResultRanges(arg, {getUmin(), getUmax(), getSmin(), getSmax()});
772}
773
774//===----------------------------------------------------------------------===//
775// TestIncrementOp
776//===----------------------------------------------------------------------===//
777
778void TestIncrementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
779 SetIntRangeFn setResultRanges) {
780 const ConstantIntRanges &range = argRanges[0];
781 APInt one(range.umin().getBitWidth(), 1);
782 setResultRanges(getResult(),
783 {range.umin().uadd_sat(one), range.umax().uadd_sat(one),
784 range.smin().sadd_sat(one), range.smax().sadd_sat(one)});
785}
786
787//===----------------------------------------------------------------------===//
788// TestReflectBoundsOp
789//===----------------------------------------------------------------------===//
790
791void TestReflectBoundsOp::inferResultRanges(
792 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
793 const ConstantIntRanges &range = argRanges[0];
794 MLIRContext *ctx = getContext();
795 Builder b(ctx);
796 Type sIntTy, uIntTy;
797 // For plain `IntegerType`s, we can derive the appropriate signed and unsigned
798 // Types for the Attributes.
799 Type type = getElementTypeOrSelf(getType());
800 if (auto intTy = llvm::dyn_cast<IntegerType>(type)) {
801 unsigned bitwidth = intTy.getWidth();
802 sIntTy = b.getIntegerType(bitwidth, /*isSigned=*/true);
803 uIntTy = b.getIntegerType(bitwidth, /*isSigned=*/false);
804 } else {
805 sIntTy = uIntTy = type;
806 }
807
808 setUminAttr(b.getIntegerAttr(uIntTy, range.umin()));
809 setUmaxAttr(b.getIntegerAttr(uIntTy, range.umax()));
810 setSminAttr(b.getIntegerAttr(sIntTy, range.smin()));
811 setSmaxAttr(b.getIntegerAttr(sIntTy, range.smax()));
812 setResultRanges(getResult(), range);
813}
814
815//===----------------------------------------------------------------------===//
816// ConversionFuncOp
817//===----------------------------------------------------------------------===//
818
819ParseResult ConversionFuncOp::parse(OpAsmParser &parser,
820 OperationState &result) {
821 auto buildFuncType =
822 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
823 function_interface_impl::VariadicFlag,
824 std::string &) { return builder.getFunctionType(argTypes, results); };
825
826 return function_interface_impl::parseFunctionOp(
827 parser, result, /*allowVariadic=*/false,
828 getFunctionTypeAttrName(result.name), buildFuncType,
829 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
830}
831
832void ConversionFuncOp::print(OpAsmPrinter &p) {
833 function_interface_impl::printFunctionOp(
834 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
835 getArgAttrsAttrName(), getResAttrsAttrName());
836}
837
838//===----------------------------------------------------------------------===//
839// ReifyBoundOp
840//===----------------------------------------------------------------------===//
841
842mlir::presburger::BoundType ReifyBoundOp::getBoundType() {
843 if (getType() == "EQ")
844 return mlir::presburger::BoundType::EQ;
845 if (getType() == "LB")
846 return mlir::presburger::BoundType::LB;
847 if (getType() == "UB")
848 return mlir::presburger::BoundType::UB;
849 llvm_unreachable("invalid bound type");
850}
851
852LogicalResult ReifyBoundOp::verify() {
853 if (isa<ShapedType>(getVar().getType())) {
854 if (!getDim().has_value())
855 return emitOpError("expected 'dim' attribute for shaped type variable");
856 } else if (getVar().getType().isIndex()) {
857 if (getDim().has_value())
858 return emitOpError("unexpected 'dim' attribute for index variable");
859 } else {
860 return emitOpError("expected index-typed variable or shape type variable");
861 }
862 if (getConstant() && getScalable())
863 return emitOpError("'scalable' and 'constant' are mutually exlusive");
864 if (getScalable() != getVscaleMin().has_value())
865 return emitOpError("expected 'vscale_min' if and only if 'scalable'");
866 if (getScalable() != getVscaleMax().has_value())
867 return emitOpError("expected 'vscale_min' if and only if 'scalable'");
868 return success();
869}
870
871ValueBoundsConstraintSet::Variable ReifyBoundOp::getVariable() {
872 if (getDim().has_value())
873 return ValueBoundsConstraintSet::Variable(getVar(), *getDim());
874 return ValueBoundsConstraintSet::Variable(getVar());
875}
876
877//===----------------------------------------------------------------------===//
878// CompareOp
879//===----------------------------------------------------------------------===//
880
881ValueBoundsConstraintSet::ComparisonOperator
882CompareOp::getComparisonOperator() {
883 if (getCmp() == "EQ")
884 return ValueBoundsConstraintSet::ComparisonOperator::EQ;
885 if (getCmp() == "LT")
886 return ValueBoundsConstraintSet::ComparisonOperator::LT;
887 if (getCmp() == "LE")
888 return ValueBoundsConstraintSet::ComparisonOperator::LE;
889 if (getCmp() == "GT")
890 return ValueBoundsConstraintSet::ComparisonOperator::GT;
891 if (getCmp() == "GE")
892 return ValueBoundsConstraintSet::ComparisonOperator::GE;
893 llvm_unreachable("invalid comparison operator");
894}
895
896mlir::ValueBoundsConstraintSet::Variable CompareOp::getLhs() {
897 if (!getLhsMap())
898 return ValueBoundsConstraintSet::Variable(getVarOperands()[0]);
899 SmallVector<Value> mapOperands(
900 getVarOperands().slice(0, getLhsMap()->getNumInputs()));
901 return ValueBoundsConstraintSet::Variable(*getLhsMap(), mapOperands);
902}
903
904mlir::ValueBoundsConstraintSet::Variable CompareOp::getRhs() {
905 int64_t rhsOperandsBegin = getLhsMap() ? getLhsMap()->getNumInputs() : 1;
906 if (!getRhsMap())
907 return ValueBoundsConstraintSet::Variable(
908 getVarOperands()[rhsOperandsBegin]);
909 SmallVector<Value> mapOperands(
910 getVarOperands().slice(rhsOperandsBegin, getRhsMap()->getNumInputs()));
911 return ValueBoundsConstraintSet::Variable(*getRhsMap(), mapOperands);
912}
913
914LogicalResult CompareOp::verify() {
915 if (getCompose() && (getLhsMap() || getRhsMap()))
916 return emitOpError(
917 "'compose' not supported when 'lhs_map' or 'rhs_map' is present");
918 int64_t expectedNumOperands = getLhsMap() ? getLhsMap()->getNumInputs() : 1;
919 expectedNumOperands += getRhsMap() ? getRhsMap()->getNumInputs() : 1;
920 if (getVarOperands().size() != size_t(expectedNumOperands))
921 return emitOpError("expected ")
922 << expectedNumOperands << " operands, but got "
923 << getVarOperands().size();
924 return success();
925}
926
927//===----------------------------------------------------------------------===//
928// TestOpInPlaceSelfFold
929//===----------------------------------------------------------------------===//
930
931OpFoldResult TestOpInPlaceSelfFold::fold(FoldAdaptor adaptor) {
932 if (!getFolded()) {
933 // The folder adds the "folded" if not present.
934 setFolded(true);
935 return getResult();
936 }
937 return {};
938}
939
940//===----------------------------------------------------------------------===//
941// TestOpFoldWithFoldAdaptor
942//===----------------------------------------------------------------------===//
943
944OpFoldResult TestOpFoldWithFoldAdaptor::fold(FoldAdaptor adaptor) {
945 int64_t sum = 0;
946 if (auto value = dyn_cast_or_null<IntegerAttr>(adaptor.getOp()))
947 sum += value.getValue().getSExtValue();
948
949 for (Attribute attr : adaptor.getVariadic())
950 if (auto value = dyn_cast_or_null<IntegerAttr>(attr))
951 sum += 2 * value.getValue().getSExtValue();
952
953 for (ArrayRef<Attribute> attrs : adaptor.getVarOfVar())
954 for (Attribute attr : attrs)
955 if (auto value = dyn_cast_or_null<IntegerAttr>(attr))
956 sum += 3 * value.getValue().getSExtValue();
957
958 sum += 4 * std::distance(adaptor.getBody().begin(), adaptor.getBody().end());
959
960 return IntegerAttr::get(getType(), sum);
961}
962
963//===----------------------------------------------------------------------===//
964// OpWithInferTypeAdaptorInterfaceOp
965//===----------------------------------------------------------------------===//
966
967LogicalResult OpWithInferTypeAdaptorInterfaceOp::inferReturnTypes(
968 MLIRContext *, std::optional<Location> location,
969 OpWithInferTypeAdaptorInterfaceOp::Adaptor adaptor,
970 SmallVectorImpl<Type> &inferredReturnTypes) {
971 if (adaptor.getX().getType() != adaptor.getY().getType()) {
972 return emitOptionalError(location, "operand type mismatch ",
973 adaptor.getX().getType(), " vs ",
974 adaptor.getY().getType());
975 }
976 inferredReturnTypes.assign({adaptor.getX().getType()});
977 return success();
978}
979
980//===----------------------------------------------------------------------===//
981// OpWithRefineTypeInterfaceOp
982//===----------------------------------------------------------------------===//
983
984// TODO: We should be able to only define either inferReturnType or
985// refineReturnType, currently only refineReturnType can be omitted.
986LogicalResult OpWithRefineTypeInterfaceOp::inferReturnTypes(
987 MLIRContext *context, std::optional<Location> location, ValueRange operands,
988 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
989 SmallVectorImpl<Type> &returnTypes) {
990 returnTypes.clear();
991 return OpWithRefineTypeInterfaceOp::refineReturnTypes(
992 context, location, operands, attributes, properties, regions,
993 returnTypes);
994}
995
996LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes(
997 MLIRContext *, std::optional<Location> location, ValueRange operands,
998 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
999 SmallVectorImpl<Type> &returnTypes) {
1000 if (operands[0].getType() != operands[1].getType()) {
1001 return emitOptionalError(location, "operand type mismatch ",
1002 operands[0].getType(), " vs ",
1003 operands[1].getType());
1004 }
1005 // TODO: Add helper to make this more concise to write.
1006 if (returnTypes.empty())
1007 returnTypes.resize(1, nullptr);
1008 if (returnTypes[0] && returnTypes[0] != operands[0].getType())
1009 return emitOptionalError(location,
1010 "required first operand and result to match");
1011 returnTypes[0] = operands[0].getType();
1012 return success();
1013}
1014
1015//===----------------------------------------------------------------------===//
1016// OpWithShapedTypeInferTypeAdaptorInterfaceOp
1017//===----------------------------------------------------------------------===//
1018
1019LogicalResult
1020OpWithShapedTypeInferTypeAdaptorInterfaceOp::inferReturnTypeComponents(
1021 MLIRContext *context, std::optional<Location> location,
1022 OpWithShapedTypeInferTypeAdaptorInterfaceOp::Adaptor adaptor,
1023 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1024 // Create return type consisting of the last element of the first operand.
1025 auto operandType = adaptor.getOperand1().getType();
1026 auto sval = dyn_cast<ShapedType>(operandType);
1027 if (!sval)
1028 return emitOptionalError(location, "only shaped type operands allowed");
1029 int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic;
1030 auto type = IntegerType::get(context, 17);
1031
1032 Attribute encoding;
1033 if (auto rankedTy = dyn_cast<RankedTensorType>(sval))
1034 encoding = rankedTy.getEncoding();
1035 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding));
1036 return success();
1037}
1038
1039LogicalResult
1040OpWithShapedTypeInferTypeAdaptorInterfaceOp::reifyReturnTypeShapes(
1041 OpBuilder &builder, ValueRange operands,
1042 llvm::SmallVectorImpl<Value> &shapes) {
1043 shapes = SmallVector<Value, 1>{
1044 builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
1045 return success();
1046}
1047
1048//===----------------------------------------------------------------------===//
1049// TestOpWithPropertiesAndInferredType
1050//===----------------------------------------------------------------------===//
1051
1052LogicalResult TestOpWithPropertiesAndInferredType::inferReturnTypes(
1053 MLIRContext *context, std::optional<Location>, ValueRange operands,
1054 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
1055 SmallVectorImpl<Type> &inferredReturnTypes) {
1056
1057 Adaptor adaptor(operands, attributes, properties, regions);
1058 inferredReturnTypes.push_back(IntegerType::get(
1059 context, adaptor.getLhs() + adaptor.getProperties().rhs));
1060 return success();
1061}
1062
1063//===----------------------------------------------------------------------===//
1064// LoopBlockOp
1065//===----------------------------------------------------------------------===//
1066
1067void LoopBlockOp::getSuccessorRegions(
1068 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1069 regions.emplace_back(&getBody(), getBody().getArguments());
1070 if (point.isParent())
1071 return;
1072
1073 regions.emplace_back((*this)->getResults());
1074}
1075
1076OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) {
1077 assert(point == getBody());
1078 return MutableOperandRange(getInitMutable());
1079}
1080
1081//===----------------------------------------------------------------------===//
1082// LoopBlockTerminatorOp
1083//===----------------------------------------------------------------------===//
1084
1085MutableOperandRange
1086LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point) {
1087 if (point.isParent())
1088 return getExitArgMutable();
1089 return getNextIterArgMutable();
1090}
1091
1092//===----------------------------------------------------------------------===//
// TestNoTerminatorOp
1094//===----------------------------------------------------------------------===//
1095
1096void TestNoTerminatorOp::getSuccessorRegions(
1097 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {}
1098
1099//===----------------------------------------------------------------------===//
// ManualCppOpWithFold
1101//===----------------------------------------------------------------------===//
1102
1103OpFoldResult ManualCppOpWithFold::fold(ArrayRef<Attribute> attributes) {
1104 // Just a simple fold for testing purposes that reads an operands constant
1105 // value and returns it.
1106 if (!attributes.empty())
1107 return attributes.front();
1108 return nullptr;
1109}
1110
1111//===----------------------------------------------------------------------===//
1112// Tensor/Buffer Ops
1113//===----------------------------------------------------------------------===//
1114
1115void ReadBufferOp::getEffects(
1116 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1117 &effects) {
1118 // The buffer operand is read.
1119 effects.emplace_back(MemoryEffects::Read::get(), &getBufferMutable(),
1120 SideEffects::DefaultResource::get());
1121 // The buffer contents are dumped.
1122 effects.emplace_back(MemoryEffects::Write::get(),
1123 SideEffects::DefaultResource::get());
1124}
1125
1126//===----------------------------------------------------------------------===//
1127// Test Dataflow
1128//===----------------------------------------------------------------------===//
1129
1130//===----------------------------------------------------------------------===//
1131// TestCallAndStoreOp
1132//===----------------------------------------------------------------------===//
1133
1134CallInterfaceCallable TestCallAndStoreOp::getCallableForCallee() {
1135 return getCallee();
1136}
1137
1138void TestCallAndStoreOp::setCalleeFromCallable(CallInterfaceCallable callee) {
1139 setCalleeAttr(cast<SymbolRefAttr>(callee));
1140}
1141
1142Operation::operand_range TestCallAndStoreOp::getArgOperands() {
1143 return getCalleeOperands();
1144}
1145
1146MutableOperandRange TestCallAndStoreOp::getArgOperandsMutable() {
1147 return getCalleeOperandsMutable();
1148}
1149
1150//===----------------------------------------------------------------------===//
1151// TestCallOnDeviceOp
1152//===----------------------------------------------------------------------===//
1153
1154CallInterfaceCallable TestCallOnDeviceOp::getCallableForCallee() {
1155 return getCallee();
1156}
1157
1158void TestCallOnDeviceOp::setCalleeFromCallable(CallInterfaceCallable callee) {
1159 setCalleeAttr(cast<SymbolRefAttr>(callee));
1160}
1161
1162Operation::operand_range TestCallOnDeviceOp::getArgOperands() {
1163 return getForwardedOperands();
1164}
1165
1166MutableOperandRange TestCallOnDeviceOp::getArgOperandsMutable() {
1167 return getForwardedOperandsMutable();
1168}
1169
1170//===----------------------------------------------------------------------===//
1171// TestStoreWithARegion
1172//===----------------------------------------------------------------------===//
1173
1174void TestStoreWithARegion::getSuccessorRegions(
1175 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1176 if (point.isParent())
1177 regions.emplace_back(&getBody(), getBody().front().getArguments());
1178 else
1179 regions.emplace_back();
1180}
1181
1182//===----------------------------------------------------------------------===//
1183// TestStoreWithALoopRegion
1184//===----------------------------------------------------------------------===//
1185
1186void TestStoreWithALoopRegion::getSuccessorRegions(
1187 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1188 // Both the operation itself and the region may be branching into the body or
1189 // back into the operation itself. It is possible for the operation not to
1190 // enter the body.
1191 regions.emplace_back(
1192 RegionSuccessor(&getBody(), getBody().front().getArguments()));
1193 regions.emplace_back();
1194}
1195
1196//===----------------------------------------------------------------------===//
1197// TestVersionedOpA
1198//===----------------------------------------------------------------------===//
1199
1200LogicalResult
1201TestVersionedOpA::readProperties(mlir::DialectBytecodeReader &reader,
1202 mlir::OperationState &state) {
1203 auto &prop = state.getOrAddProperties<Properties>();
1204 if (mlir::failed(reader.readAttribute(prop.dims)))
1205 return mlir::failure();
1206
1207 // Check if we have a version. If not, assume we are parsing the current
1208 // version.
1209 auto maybeVersion = reader.getDialectVersion<test::TestDialect>();
1210 if (succeeded(maybeVersion)) {
1211 // If version is less than 2.0, there is no additional attribute to parse.
1212 // We can materialize missing properties post parsing before verification.
1213 const auto *version =
1214 reinterpret_cast<const TestDialectVersion *>(*maybeVersion);
1215 if ((version->major_ < 2)) {
1216 return success();
1217 }
1218 }
1219
1220 if (mlir::failed(reader.readAttribute(prop.modifier)))
1221 return mlir::failure();
1222 return mlir::success();
1223}
1224
1225void TestVersionedOpA::writeProperties(mlir::DialectBytecodeWriter &writer) {
1226 auto &prop = getProperties();
1227 writer.writeAttribute(prop.dims);
1228
1229 auto maybeVersion = writer.getDialectVersion<test::TestDialect>();
1230 if (succeeded(maybeVersion)) {
1231 // If version is less than 2.0, there is no additional attribute to write.
1232 const auto *version =
1233 reinterpret_cast<const TestDialectVersion *>(*maybeVersion);
1234 if ((version->major_ < 2)) {
1235 llvm::outs() << "downgrading op properties...\n";
1236 return;
1237 }
1238 }
1239 writer.writeAttribute(prop.modifier);
1240}
1241
1242//===----------------------------------------------------------------------===//
1243// TestOpWithVersionedProperties
1244//===----------------------------------------------------------------------===//
1245
1246llvm::LogicalResult TestOpWithVersionedProperties::readFromMlirBytecode(
1247 mlir::DialectBytecodeReader &reader, test::VersionedProperties &prop) {
1248 uint64_t value1, value2 = 0;
1249 if (failed(reader.readVarInt(value1)))
1250 return failure();
1251
1252 // Check if we have a version. If not, assume we are parsing the current
1253 // version.
1254 auto maybeVersion = reader.getDialectVersion<test::TestDialect>();
1255 bool needToParseAnotherInt = true;
1256 if (succeeded(maybeVersion)) {
1257 // If version is less than 2.0, there is no additional attribute to parse.
1258 // We can materialize missing properties post parsing before verification.
1259 const auto *version =
1260 reinterpret_cast<const TestDialectVersion *>(*maybeVersion);
1261 if ((version->major_ < 2))
1262 needToParseAnotherInt = false;
1263 }
1264 if (needToParseAnotherInt && failed(reader.readVarInt(value2)))
1265 return failure();
1266
1267 prop.value1 = value1;
1268 prop.value2 = value2;
1269 return success();
1270}
1271
1272void TestOpWithVersionedProperties::writeToMlirBytecode(
1273 mlir::DialectBytecodeWriter &writer,
1274 const test::VersionedProperties &prop) {
1275 writer.writeVarInt(prop.value1);
1276 writer.writeVarInt(prop.value2);
1277}
1278
1279//===----------------------------------------------------------------------===//
1280// TestMultiSlotAlloca
1281//===----------------------------------------------------------------------===//
1282
1283llvm::SmallVector<MemorySlot> TestMultiSlotAlloca::getPromotableSlots() {
1284 SmallVector<MemorySlot> slots;
1285 for (Value result : getResults()) {
1286 slots.push_back(MemorySlot{
1287 result, cast<MemRefType>(result.getType()).getElementType()});
1288 }
1289 return slots;
1290}
1291
1292Value TestMultiSlotAlloca::getDefaultValue(const MemorySlot &slot,
1293 OpBuilder &builder) {
1294 return builder.create<TestOpConstant>(getLoc(), slot.elemType,
1295 builder.getI32IntegerAttr(42));
1296}
1297
1298void TestMultiSlotAlloca::handleBlockArgument(const MemorySlot &slot,
1299 BlockArgument argument,
1300 OpBuilder &builder) {
1301 // Not relevant for testing.
1302}
1303
1304/// Creates a new TestMultiSlotAlloca operation, just without the `slot`.
1305static std::optional<TestMultiSlotAlloca>
1306createNewMultiAllocaWithoutSlot(const MemorySlot &slot, OpBuilder &builder,
1307 TestMultiSlotAlloca oldOp) {
1308
1309 if (oldOp.getNumResults() == 1) {
1310 oldOp.erase();
1311 return std::nullopt;
1312 }
1313
1314 SmallVector<Type> newTypes;
1315 SmallVector<Value> remainingValues;
1316
1317 for (Value oldResult : oldOp.getResults()) {
1318 if (oldResult == slot.ptr)
1319 continue;
1320 remainingValues.push_back(oldResult);
1321 newTypes.push_back(oldResult.getType());
1322 }
1323
1324 OpBuilder::InsertionGuard guard(builder);
1325 builder.setInsertionPoint(oldOp);
1326 auto replacement =
1327 builder.create<TestMultiSlotAlloca>(oldOp->getLoc(), newTypes);
1328 for (auto [oldResult, newResult] :
1329 llvm::zip_equal(remainingValues, replacement.getResults()))
1330 oldResult.replaceAllUsesWith(newResult);
1331
1332 oldOp.erase();
1333 return replacement;
1334}
1335
1336std::optional<PromotableAllocationOpInterface>
1337TestMultiSlotAlloca::handlePromotionComplete(const MemorySlot &slot,
1338 Value defaultValue,
1339 OpBuilder &builder) {
1340 if (defaultValue && defaultValue.use_empty())
1341 defaultValue.getDefiningOp()->erase();
1342 return createNewMultiAllocaWithoutSlot(slot, builder, *this);
1343}
1344
1345SmallVector<DestructurableMemorySlot>
1346TestMultiSlotAlloca::getDestructurableSlots() {
1347 SmallVector<DestructurableMemorySlot> slots;
1348 for (Value result : getResults()) {
1349 auto memrefType = cast<MemRefType>(result.getType());
1350 auto destructurable = dyn_cast<DestructurableTypeInterface>(memrefType);
1351 if (!destructurable)
1352 continue;
1353
1354 std::optional<DenseMap<Attribute, Type>> destructuredType =
1355 destructurable.getSubelementIndexMap();
1356 if (!destructuredType)
1357 continue;
1358 slots.emplace_back(
1359 DestructurableMemorySlot{{result, memrefType}, *destructuredType});
1360 }
1361 return slots;
1362}
1363
1364DenseMap<Attribute, MemorySlot> TestMultiSlotAlloca::destructure(
1365 const DestructurableMemorySlot &slot,
1366 const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
1367 SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) {
1368 OpBuilder::InsertionGuard guard(builder);
1369 builder.setInsertionPointAfter(*this);
1370
1371 DenseMap<Attribute, MemorySlot> slotMap;
1372
1373 for (Attribute usedIndex : usedIndices) {
1374 Type elemType = slot.subelementTypes.lookup(usedIndex);
1375 MemRefType elemPtr = MemRefType::get({}, elemType);
1376 auto subAlloca = builder.create<TestMultiSlotAlloca>(getLoc(), elemPtr);
1377 newAllocators.push_back(subAlloca);
1378 slotMap.try_emplace<MemorySlot>(usedIndex,
1379 {subAlloca.getResult(0), elemType});
1380 }
1381
1382 return slotMap;
1383}
1384
1385std::optional<DestructurableAllocationOpInterface>
1386TestMultiSlotAlloca::handleDestructuringComplete(
1387 const DestructurableMemorySlot &slot, OpBuilder &builder) {
1388 return createNewMultiAllocaWithoutSlot(slot, builder, *this);
1389}
1390

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/test/lib/Dialect/Test/TestOpDefs.cpp