1//===- TestOpDefs.cpp - MLIR Test Dialect Operation Hooks -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "TestDialect.h"
10#include "TestOps.h"
11#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12#include "mlir/Dialect/Tensor/IR/Tensor.h"
13#include "mlir/IR/Verifier.h"
14#include "mlir/Interfaces/FunctionImplementation.h"
15#include "mlir/Interfaces/MemorySlotInterfaces.h"
16
17using namespace mlir;
18using namespace test;
19
20//===----------------------------------------------------------------------===//
21// TestBranchOp
22//===----------------------------------------------------------------------===//
23
24SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) {
25 assert(index == 0 && "invalid successor index");
26 return SuccessorOperands(getTargetOperandsMutable());
27}
28
29//===----------------------------------------------------------------------===//
30// TestProducingBranchOp
31//===----------------------------------------------------------------------===//
32
33SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) {
34 assert(index <= 1 && "invalid successor index");
35 if (index == 1)
36 return SuccessorOperands(getFirstOperandsMutable());
37 return SuccessorOperands(getSecondOperandsMutable());
38}
39
40//===----------------------------------------------------------------------===//
41// TestInternalBranchOp
42//===----------------------------------------------------------------------===//
43
44SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) {
45 assert(index <= 1 && "invalid successor index");
46 if (index == 0)
47 return SuccessorOperands(0, getSuccessOperandsMutable());
48 return SuccessorOperands(1, getErrorOperandsMutable());
49}
50
51//===----------------------------------------------------------------------===//
52// TestCallOp
53//===----------------------------------------------------------------------===//
54
55LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
56 // Check that the callee attribute was specified.
57 auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
58 if (!fnAttr)
59 return emitOpError("requires a 'callee' symbol reference attribute");
60 if (!symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(*this, fnAttr))
61 return emitOpError() << "'" << fnAttr.getValue()
62 << "' does not reference a valid function";
63 return success();
64}
65
66//===----------------------------------------------------------------------===//
67// FoldToCallOp
68//===----------------------------------------------------------------------===//
69
70namespace {
71struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
72 using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
73
74 LogicalResult matchAndRewrite(FoldToCallOp op,
75 PatternRewriter &rewriter) const override {
76 rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(),
77 op.getCalleeAttr(), ValueRange());
78 return success();
79 }
80};
81} // namespace
82
83void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
84 MLIRContext *context) {
85 results.add<FoldToCallOpPattern>(context);
86}
87
88//===----------------------------------------------------------------------===//
89// IsolatedRegionOp - test parsing passthrough operands
90//===----------------------------------------------------------------------===//
91
92ParseResult IsolatedRegionOp::parse(OpAsmParser &parser,
93 OperationState &result) {
94 // Parse the input operand.
95 OpAsmParser::Argument argInfo;
96 argInfo.type = parser.getBuilder().getIndexType();
97 if (parser.parseOperand(argInfo.ssaName) ||
98 parser.resolveOperand(argInfo.ssaName, argInfo.type, result.operands))
99 return failure();
100
101 // Parse the body region, and reuse the operand info as the argument info.
102 Region *body = result.addRegion();
103 return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/true);
104}
105
106void IsolatedRegionOp::print(OpAsmPrinter &p) {
107 p << ' ';
108 p.printOperand(getOperand());
109 p.shadowRegionArgs(getRegion(), getOperand());
110 p << ' ';
111 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
112}
113
114//===----------------------------------------------------------------------===//
115// SSACFGRegionOp
116//===----------------------------------------------------------------------===//
117
118RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
119 return RegionKind::SSACFG;
120}
121
122//===----------------------------------------------------------------------===//
123// GraphRegionOp
124//===----------------------------------------------------------------------===//
125
126RegionKind GraphRegionOp::getRegionKind(unsigned index) {
127 return RegionKind::Graph;
128}
129
130//===----------------------------------------------------------------------===//
131// IsolatedGraphRegionOp
132//===----------------------------------------------------------------------===//
133
134RegionKind IsolatedGraphRegionOp::getRegionKind(unsigned index) {
135 return RegionKind::Graph;
136}
137
138//===----------------------------------------------------------------------===//
139// AffineScopeOp
140//===----------------------------------------------------------------------===//
141
142ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) {
143 // Parse the body region, and reuse the operand info as the argument info.
144 Region *body = result.addRegion();
145 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
146}
147
148void AffineScopeOp::print(OpAsmPrinter &p) {
149 p << " ";
150 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
151}
152
153//===----------------------------------------------------------------------===//
154// TestRemoveOpWithInnerOps
155//===----------------------------------------------------------------------===//
156
157namespace {
158struct TestRemoveOpWithInnerOps
159 : public OpRewritePattern<TestOpWithRegionPattern> {
160 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
161
162 void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
163
164 LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
165 PatternRewriter &rewriter) const override {
166 rewriter.eraseOp(op: op);
167 return success();
168 }
169};
170} // namespace
171
172//===----------------------------------------------------------------------===//
173// TestOpWithRegionPattern
174//===----------------------------------------------------------------------===//
175
176void TestOpWithRegionPattern::getCanonicalizationPatterns(
177 RewritePatternSet &results, MLIRContext *context) {
178 results.add<TestRemoveOpWithInnerOps>(context);
179}
180
181//===----------------------------------------------------------------------===//
182// TestOpWithRegionFold
183//===----------------------------------------------------------------------===//
184
185OpFoldResult TestOpWithRegionFold::fold(FoldAdaptor adaptor) {
186 return getOperand();
187}
188
189//===----------------------------------------------------------------------===//
190// TestOpConstant
191//===----------------------------------------------------------------------===//
192
193OpFoldResult TestOpConstant::fold(FoldAdaptor adaptor) { return getValue(); }
194
195//===----------------------------------------------------------------------===//
196// TestOpWithVariadicResultsAndFolder
197//===----------------------------------------------------------------------===//
198
199LogicalResult TestOpWithVariadicResultsAndFolder::fold(
200 FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult> &results) {
201 for (Value input : this->getOperands()) {
202 results.push_back(input);
203 }
204 return success();
205}
206
207//===----------------------------------------------------------------------===//
208// TestOpInPlaceFold
209//===----------------------------------------------------------------------===//
210
211OpFoldResult TestOpInPlaceFold::fold(FoldAdaptor adaptor) {
212 // Exercise the fact that an operation created with createOrFold should be
213 // allowed to access its parent block.
214 assert(getOperation()->getBlock() &&
215 "expected that operation is not unlinked");
216
217 if (adaptor.getOp() && !getProperties().attr) {
218 // The folder adds "attr" if not present.
219 getProperties().attr = dyn_cast_or_null<IntegerAttr>(adaptor.getOp());
220 return getResult();
221 }
222 return {};
223}
224
225//===----------------------------------------------------------------------===//
226// OpWithInferTypeInterfaceOp
227//===----------------------------------------------------------------------===//
228
229LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
230 MLIRContext *, std::optional<Location> location, ValueRange operands,
231 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
232 SmallVectorImpl<Type> &inferredReturnTypes) {
233 if (operands[0].getType() != operands[1].getType()) {
234 return emitOptionalError(location, "operand type mismatch ",
235 operands[0].getType(), " vs ",
236 operands[1].getType());
237 }
238 inferredReturnTypes.assign({operands[0].getType()});
239 return success();
240}
241
242//===----------------------------------------------------------------------===//
243// OpWithShapedTypeInferTypeInterfaceOp
244//===----------------------------------------------------------------------===//
245
246LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
247 MLIRContext *context, std::optional<Location> location,
248 ValueShapeRange operands, DictionaryAttr attributes,
249 OpaqueProperties properties, RegionRange regions,
250 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
251 // Create return type consisting of the last element of the first operand.
252 auto operandType = operands.front().getType();
253 auto sval = dyn_cast<ShapedType>(operandType);
254 if (!sval)
255 return emitOptionalError(location, "only shaped type operands allowed");
256 int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic;
257 auto type = IntegerType::get(context, 17);
258
259 Attribute encoding;
260 if (auto rankedTy = dyn_cast<RankedTensorType>(sval))
261 encoding = rankedTy.getEncoding();
262 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding));
263 return success();
264}
265
266LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
267 OpBuilder &builder, ValueRange operands,
268 llvm::SmallVectorImpl<Value> &shapes) {
269 shapes = SmallVector<Value, 1>{
270 builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
271 return success();
272}
273
274//===----------------------------------------------------------------------===//
275// OpWithResultShapeInterfaceOp
276//===----------------------------------------------------------------------===//
277
278LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
279 OpBuilder &builder, ValueRange operands,
280 llvm::SmallVectorImpl<Value> &shapes) {
281 Location loc = getLoc();
282 shapes.reserve(operands.size());
283 for (Value operand : llvm::reverse(operands)) {
284 auto rank = cast<RankedTensorType>(operand.getType()).getRank();
285 auto currShape = llvm::to_vector<4>(
286 llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value {
287 return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
288 }));
289 shapes.push_back(builder.create<tensor::FromElementsOp>(
290 getLoc(), RankedTensorType::get({rank}, builder.getIndexType()),
291 currShape));
292 }
293 return success();
294}
295
296//===----------------------------------------------------------------------===//
297// OpWithResultShapePerDimInterfaceOp
298//===----------------------------------------------------------------------===//
299
300LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
301 OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
302 Location loc = getLoc();
303 shapes.reserve(getNumOperands());
304 for (Value operand : llvm::reverse(getOperands())) {
305 auto tensorType = cast<RankedTensorType>(operand.getType());
306 auto currShape = llvm::to_vector<4>(llvm::map_range(
307 llvm::seq<int64_t>(0, tensorType.getRank()),
308 [&](int64_t dim) -> OpFoldResult {
309 return tensorType.isDynamicDim(dim)
310 ? static_cast<OpFoldResult>(
311 builder.createOrFold<tensor::DimOp>(loc, operand,
312 dim))
313 : static_cast<OpFoldResult>(
314 builder.getIndexAttr(tensorType.getDimSize(dim)));
315 }));
316 shapes.emplace_back(std::move(currShape));
317 }
318 return success();
319}
320
321//===----------------------------------------------------------------------===//
322// SideEffectOp
323//===----------------------------------------------------------------------===//
324
325namespace {
326/// A test resource for side effects.
327struct TestResource : public SideEffects::Resource::Base<TestResource> {
328 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResource)
329
330 StringRef getName() final { return "<Test>"; }
331};
332} // namespace
333
334void SideEffectOp::getEffects(
335 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
336 // Check for an effects attribute on the op instance.
337 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
338 if (!effectsAttr)
339 return;
340
341 for (Attribute element : effectsAttr) {
342 DictionaryAttr effectElement = cast<DictionaryAttr>(Val&: element);
343
344 // Get the specific memory effect.
345 MemoryEffects::Effect *effect =
346 StringSwitch<MemoryEffects::Effect *>(
347 cast<StringAttr>(Val: effectElement.get(name: "effect")).getValue())
348 .Case(S: "allocate", Value: MemoryEffects::Allocate::get())
349 .Case(S: "free", Value: MemoryEffects::Free::get())
350 .Case(S: "read", Value: MemoryEffects::Read::get())
351 .Case(S: "write", Value: MemoryEffects::Write::get());
352
353 // Check for a non-default resource to use.
354 SideEffects::Resource *resource = SideEffects::DefaultResource::get();
355 if (effectElement.get("test_resource"))
356 resource = TestResource::get();
357
358 // Check for a result to affect.
359 if (effectElement.get(name: "on_result"))
360 effects.emplace_back(effect, getOperation()->getOpResults()[0], resource);
361 else if (Attribute ref = effectElement.get(name: "on_reference"))
362 effects.emplace_back(Args&: effect, Args: cast<SymbolRefAttr>(Val&: ref), Args&: resource);
363 else
364 effects.emplace_back(Args&: effect, Args&: resource);
365 }
366}
367
368void SideEffectOp::getEffects(
369 SmallVectorImpl<TestEffects::EffectInstance> &effects) {
370 testSideEffectOpGetEffect(getOperation(), effects);
371}
372
373void SideEffectWithRegionOp::getEffects(
374 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
375 // Check for an effects attribute on the op instance.
376 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
377 if (!effectsAttr)
378 return;
379
380 for (Attribute element : effectsAttr) {
381 DictionaryAttr effectElement = cast<DictionaryAttr>(element);
382
383 // Get the specific memory effect.
384 MemoryEffects::Effect *effect =
385 StringSwitch<MemoryEffects::Effect *>(
386 cast<StringAttr>(effectElement.get("effect")).getValue())
387 .Case("allocate", MemoryEffects::Allocate::get())
388 .Case("free", MemoryEffects::Free::get())
389 .Case("read", MemoryEffects::Read::get())
390 .Case("write", MemoryEffects::Write::get());
391
392 // Check for a non-default resource to use.
393 SideEffects::Resource *resource = SideEffects::DefaultResource::get();
394 if (effectElement.get("test_resource"))
395 resource = TestResource::get();
396
397 // Check for a result to affect.
398 if (effectElement.get("on_result"))
399 effects.emplace_back(effect, getOperation()->getOpResults()[0], resource);
400 else if (effectElement.get("on_operand"))
401 effects.emplace_back(effect, &getOperation()->getOpOperands()[0],
402 resource);
403 else if (effectElement.get("on_argument"))
404 effects.emplace_back(effect, getOperation()->getRegion(0).getArgument(0),
405 resource);
406 else if (Attribute ref = effectElement.get("on_reference"))
407 effects.emplace_back(effect, cast<SymbolRefAttr>(ref), resource);
408 else
409 effects.emplace_back(effect, resource);
410 }
411}
412
413void SideEffectWithRegionOp::getEffects(
414 SmallVectorImpl<TestEffects::EffectInstance> &effects) {
415 testSideEffectOpGetEffect(getOperation(), effects);
416}
417
418//===----------------------------------------------------------------------===//
419// StringAttrPrettyNameOp
420//===----------------------------------------------------------------------===//
421
422// This op has fancy handling of its SSA result name.
423ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser,
424 OperationState &result) {
425 // Add the result types.
426 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
427 result.addTypes(parser.getBuilder().getIntegerType(32));
428
429 if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
430 return failure();
431
432 // If the attribute dictionary contains no 'names' attribute, infer it from
433 // the SSA name (if specified).
434 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
435 return attr.getName() == "names";
436 });
437
438 // If there was no name specified, check to see if there was a useful name
439 // specified in the asm file.
440 if (hadNames || parser.getNumResults() == 0)
441 return success();
442
443 SmallVector<StringRef, 4> names;
444 auto *context = result.getContext();
445
446 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
447 auto resultName = parser.getResultName(i);
448 StringRef nameStr;
449 if (!resultName.first.empty() && !isdigit(resultName.first[0]))
450 nameStr = resultName.first;
451
452 names.push_back(nameStr);
453 }
454
455 auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
456 result.attributes.push_back({StringAttr::get(context, "names"), namesAttr});
457 return success();
458}
459
460void StringAttrPrettyNameOp::print(OpAsmPrinter &p) {
461 // Note that we only need to print the "name" attribute if the asmprinter
462 // result name disagrees with it. This can happen in strange cases, e.g.
463 // when there are conflicts.
464 bool namesDisagree = getNames().size() != getNumResults();
465
466 SmallString<32> resultNameStr;
467 for (size_t i = 0, e = getNumResults(); i != e && !namesDisagree; ++i) {
468 resultNameStr.clear();
469 llvm::raw_svector_ostream tmpStream(resultNameStr);
470 p.printOperand(getResult(i), tmpStream);
471
472 auto expectedName = dyn_cast<StringAttr>(getNames()[i]);
473 if (!expectedName ||
474 tmpStream.str().drop_front() != expectedName.getValue()) {
475 namesDisagree = true;
476 }
477 }
478
479 if (namesDisagree)
480 p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
481 else
482 p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names"});
483}
484
485// We set the SSA name in the asm syntax to the contents of the name
486// attribute.
487void StringAttrPrettyNameOp::getAsmResultNames(
488 function_ref<void(Value, StringRef)> setNameFn) {
489
490 auto value = getNames();
491 for (size_t i = 0, e = value.size(); i != e; ++i)
492 if (auto str = dyn_cast<StringAttr>(value[i]))
493 if (!str.getValue().empty())
494 setNameFn(getResult(i), str.getValue());
495}
496
497//===----------------------------------------------------------------------===//
498// CustomResultsNameOp
499//===----------------------------------------------------------------------===//
500
501void CustomResultsNameOp::getAsmResultNames(
502 function_ref<void(Value, StringRef)> setNameFn) {
503 ArrayAttr value = getNames();
504 for (size_t i = 0, e = value.size(); i != e; ++i)
505 if (auto str = dyn_cast<StringAttr>(value[i]))
506 if (!str.empty())
507 setNameFn(getResult(i), str.getValue());
508}
509
510//===----------------------------------------------------------------------===//
511// ResultNameFromTypeOp
512//===----------------------------------------------------------------------===//
513
514void ResultNameFromTypeOp::getAsmResultNames(
515 function_ref<void(Value, StringRef)> setNameFn) {
516 auto result = getResult();
517 auto setResultNameFn = [&](::llvm::StringRef name) {
518 setNameFn(result, name);
519 };
520 auto opAsmTypeInterface =
521 ::mlir::cast<::mlir::OpAsmTypeInterface>(result.getType());
522 opAsmTypeInterface.getAsmName(setResultNameFn);
523}
524
525//===----------------------------------------------------------------------===//
526// BlockArgumentNameFromTypeOp
527//===----------------------------------------------------------------------===//
528
529void BlockArgumentNameFromTypeOp::getAsmBlockArgumentNames(
530 ::mlir::Region &region, ::mlir::OpAsmSetValueNameFn setNameFn) {
531 for (auto &block : region) {
532 for (auto arg : block.getArguments()) {
533 if (auto opAsmTypeInterface =
534 ::mlir::dyn_cast<::mlir::OpAsmTypeInterface>(arg.getType())) {
535 auto setArgNameFn = [&](StringRef name) { setNameFn(arg, name); };
536 opAsmTypeInterface.getAsmName(setArgNameFn);
537 }
538 }
539 }
540}
541
542//===----------------------------------------------------------------------===//
543// ResultTypeWithTraitOp
544//===----------------------------------------------------------------------===//
545
546LogicalResult ResultTypeWithTraitOp::verify() {
547 if ((*this)->getResultTypes()[0].hasTrait<TypeTrait::TestTypeTrait>())
548 return success();
549 return emitError("result type should have trait 'TestTypeTrait'");
550}
551
552//===----------------------------------------------------------------------===//
553// AttrWithTraitOp
554//===----------------------------------------------------------------------===//
555
556LogicalResult AttrWithTraitOp::verify() {
557 if (getAttr().hasTrait<AttributeTrait::TestAttrTrait>())
558 return success();
559 return emitError("'attr' attribute should have trait 'TestAttrTrait'");
560}
561
562//===----------------------------------------------------------------------===//
563// RegionIfOp
564//===----------------------------------------------------------------------===//
565
566void RegionIfOp::print(OpAsmPrinter &p) {
567 p << " ";
568 p.printOperands(getOperands());
569 p << ": " << getOperandTypes();
570 p.printArrowTypeList(getResultTypes());
571 p << " then ";
572 p.printRegion(getThenRegion(),
573 /*printEntryBlockArgs=*/true,
574 /*printBlockTerminators=*/true);
575 p << " else ";
576 p.printRegion(getElseRegion(),
577 /*printEntryBlockArgs=*/true,
578 /*printBlockTerminators=*/true);
579 p << " join ";
580 p.printRegion(getJoinRegion(),
581 /*printEntryBlockArgs=*/true,
582 /*printBlockTerminators=*/true);
583}
584
585ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) {
586 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfos;
587 SmallVector<Type, 2> operandTypes;
588
589 result.regions.reserve(3);
590 Region *thenRegion = result.addRegion();
591 Region *elseRegion = result.addRegion();
592 Region *joinRegion = result.addRegion();
593
594 // Parse operand, type and arrow type lists.
595 if (parser.parseOperandList(operandInfos) ||
596 parser.parseColonTypeList(operandTypes) ||
597 parser.parseArrowTypeList(result.types))
598 return failure();
599
600 // Parse all attached regions.
601 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
602 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
603 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
604 return failure();
605
606 return parser.resolveOperands(operandInfos, operandTypes,
607 parser.getCurrentLocation(), result.operands);
608}
609
610OperandRange RegionIfOp::getEntrySuccessorOperands(RegionBranchPoint point) {
611 assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, point) &&
612 "invalid region index");
613 return getOperands();
614}
615
616void RegionIfOp::getSuccessorRegions(
617 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
618 // We always branch to the join region.
619 if (!point.isParent()) {
620 if (point != getJoinRegion())
621 regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
622 else
623 regions.push_back(RegionSuccessor(getResults()));
624 return;
625 }
626
627 // The then and else regions are the entry regions of this op.
628 regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs()));
629 regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs()));
630}
631
632void RegionIfOp::getRegionInvocationBounds(
633 ArrayRef<Attribute> operands,
634 SmallVectorImpl<InvocationBounds> &invocationBounds) {
635 // Each region is invoked at most once.
636 invocationBounds.assign(/*NumElts=*/3, /*Elt=*/{0, 1});
637}
638
639//===----------------------------------------------------------------------===//
640// AnyCondOp
641//===----------------------------------------------------------------------===//
642
643void AnyCondOp::getSuccessorRegions(RegionBranchPoint point,
644 SmallVectorImpl<RegionSuccessor> &regions) {
645 // The parent op branches into the only region, and the region branches back
646 // to the parent op.
647 if (point.isParent())
648 regions.emplace_back(&getRegion());
649 else
650 regions.emplace_back(getResults());
651}
652
653void AnyCondOp::getRegionInvocationBounds(
654 ArrayRef<Attribute> operands,
655 SmallVectorImpl<InvocationBounds> &invocationBounds) {
656 invocationBounds.emplace_back(1, 1);
657}
658
659//===----------------------------------------------------------------------===//
660// SingleBlockImplicitTerminatorOp
661//===----------------------------------------------------------------------===//
662
/// Compile-time checks on the implicit-terminator trait machinery as applied
/// to SingleBlockImplicitTerminatorOp:
///   - the `has_implicit_terminator_t` detector must detect the op;
///   - `hasSingleBlockImplicitTerminator` must report true for it.
/// A failure here is a build error, not a runtime diagnostic.
static_assert(
    llvm::is_detected<OpTrait::has_implicit_terminator_t,
                      SingleBlockImplicitTerminatorOp>::value,
    "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
static_assert(OpTrait::hasSingleBlockImplicitTerminator<
                  SingleBlockImplicitTerminatorOp>::value,
              "hasSingleBlockImplicitTerminator does not match "
              "SingleBlockImplicitTerminatorOp");
672
673//===----------------------------------------------------------------------===//
674// SingleNoTerminatorCustomAsmOp
675//===----------------------------------------------------------------------===//
676
677ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser,
678 OperationState &state) {
679 Region *body = state.addRegion();
680 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
681 return failure();
682 return success();
683}
684
685void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) {
686 printer.printRegion(
687 getRegion(), /*printEntryBlockArgs=*/false,
688 // This op has a single block without terminators. But explicitly mark
689 // as not printing block terminators for testing.
690 /*printBlockTerminators=*/false);
691}
692
693//===----------------------------------------------------------------------===//
694// TestVerifiersOp
695//===----------------------------------------------------------------------===//
696
697LogicalResult TestVerifiersOp::verify() {
698 if (!getRegion().hasOneBlock())
699 return emitOpError("`hasOneBlock` trait hasn't been verified");
700
701 Operation *definingOp = getInput().getDefiningOp();
702 if (definingOp && failed(mlir::verify(definingOp)))
703 return emitOpError("operand hasn't been verified");
704
705 // Avoid using `emitRemark(msg)` since that will trigger an infinite verifier
706 // loop.
707 mlir::emitRemark(getLoc(), "success run of verifier");
708
709 return success();
710}
711
712LogicalResult TestVerifiersOp::verifyRegions() {
713 if (!getRegion().hasOneBlock())
714 return emitOpError("`hasOneBlock` trait hasn't been verified");
715
716 for (Block &block : getRegion())
717 for (Operation &op : block)
718 if (failed(mlir::verify(&op)))
719 return emitOpError("nested op hasn't been verified");
720
721 // Avoid using `emitRemark(msg)` since that will trigger an infinite verifier
722 // loop.
723 mlir::emitRemark(getLoc(), "success run of region verifier");
724
725 return success();
726}
727
728//===----------------------------------------------------------------------===//
729// Test InferIntRangeInterface
730//===----------------------------------------------------------------------===//
731
732//===----------------------------------------------------------------------===//
733// TestWithBoundsOp
734//===----------------------------------------------------------------------===//
735
736void TestWithBoundsOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
737 SetIntRangeFn setResultRanges) {
738 setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()});
739}
740
741//===----------------------------------------------------------------------===//
742// TestWithBoundsRegionOp
743//===----------------------------------------------------------------------===//
744
745ParseResult TestWithBoundsRegionOp::parse(OpAsmParser &parser,
746 OperationState &result) {
747 if (parser.parseOptionalAttrDict(result.attributes))
748 return failure();
749
750 // Parse the input argument
751 OpAsmParser::Argument argInfo;
752 if (failed(parser.parseArgument(argInfo, true)))
753 return failure();
754
755 // Parse the body region, and reuse the operand info as the argument info.
756 Region *body = result.addRegion();
757 return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/false);
758}
759
760void TestWithBoundsRegionOp::print(OpAsmPrinter &p) {
761 p.printOptionalAttrDict((*this)->getAttrs());
762 p << ' ';
763 p.printRegionArgument(getRegion().getArgument(0), /*argAttrs=*/{},
764 /*omitType=*/false);
765 p << ' ';
766 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
767}
768
769void TestWithBoundsRegionOp::inferResultRanges(
770 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
771 Value arg = getRegion().getArgument(0);
772 setResultRanges(arg, {getUmin(), getUmax(), getSmin(), getSmax()});
773}
774
775//===----------------------------------------------------------------------===//
776// TestIncrementOp
777//===----------------------------------------------------------------------===//
778
779void TestIncrementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
780 SetIntRangeFn setResultRanges) {
781 const ConstantIntRanges &range = argRanges[0];
782 APInt one(range.umin().getBitWidth(), 1);
783 setResultRanges(getResult(),
784 {range.umin().uadd_sat(one), range.umax().uadd_sat(one),
785 range.smin().sadd_sat(one), range.smax().sadd_sat(one)});
786}
787
788//===----------------------------------------------------------------------===//
789// TestReflectBoundsOp
790//===----------------------------------------------------------------------===//
791
792void TestReflectBoundsOp::inferResultRanges(
793 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
794 const ConstantIntRanges &range = argRanges[0];
795 MLIRContext *ctx = getContext();
796 Builder b(ctx);
797 Type sIntTy, uIntTy;
798 // For plain `IntegerType`s, we can derive the appropriate signed and unsigned
799 // Types for the Attributes.
800 Type type = getElementTypeOrSelf(getType());
801 if (auto intTy = llvm::dyn_cast<IntegerType>(type)) {
802 unsigned bitwidth = intTy.getWidth();
803 sIntTy = b.getIntegerType(bitwidth, /*isSigned=*/true);
804 uIntTy = b.getIntegerType(bitwidth, /*isSigned=*/false);
805 } else {
806 sIntTy = uIntTy = type;
807 }
808
809 setUminAttr(b.getIntegerAttr(uIntTy, range.umin()));
810 setUmaxAttr(b.getIntegerAttr(uIntTy, range.umax()));
811 setSminAttr(b.getIntegerAttr(sIntTy, range.smin()));
812 setSmaxAttr(b.getIntegerAttr(sIntTy, range.smax()));
813 setResultRanges(getResult(), range);
814}
815
816//===----------------------------------------------------------------------===//
817// ConversionFuncOp
818//===----------------------------------------------------------------------===//
819
820ParseResult ConversionFuncOp::parse(OpAsmParser &parser,
821 OperationState &result) {
822 auto buildFuncType =
823 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
824 function_interface_impl::VariadicFlag,
825 std::string &) { return builder.getFunctionType(argTypes, results); };
826
827 return function_interface_impl::parseFunctionOp(
828 parser, result, /*allowVariadic=*/false,
829 getFunctionTypeAttrName(result.name), buildFuncType,
830 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
831}
832
833void ConversionFuncOp::print(OpAsmPrinter &p) {
834 function_interface_impl::printFunctionOp(
835 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
836 getArgAttrsAttrName(), getResAttrsAttrName());
837}
838
839//===----------------------------------------------------------------------===//
840// TestValueWithBoundsOp
841//===----------------------------------------------------------------------===//
842
843void TestValueWithBoundsOp::populateBoundsForIndexValue(
844 Value v, ValueBoundsConstraintSet &cstr) {
845 cstr.bound(v) >= getMin().getSExtValue();
846 cstr.bound(v) <= getMax().getSExtValue();
847}
848
849//===----------------------------------------------------------------------===//
850// ReifyBoundOp
851//===----------------------------------------------------------------------===//
852
853mlir::presburger::BoundType ReifyBoundOp::getBoundType() {
854 if (getType() == "EQ")
855 return mlir::presburger::BoundType::EQ;
856 if (getType() == "LB")
857 return mlir::presburger::BoundType::LB;
858 if (getType() == "UB")
859 return mlir::presburger::BoundType::UB;
860 llvm_unreachable("invalid bound type");
861}
862
863LogicalResult ReifyBoundOp::verify() {
864 if (isa<ShapedType>(getVar().getType())) {
865 if (!getDim().has_value())
866 return emitOpError("expected 'dim' attribute for shaped type variable");
867 } else if (getVar().getType().isIndex()) {
868 if (getDim().has_value())
869 return emitOpError("unexpected 'dim' attribute for index variable");
870 } else {
871 return emitOpError("expected index-typed variable or shape type variable");
872 }
873 if (getConstant() && getScalable())
874 return emitOpError("'scalable' and 'constant' are mutually exlusive");
875 if (getScalable() != getVscaleMin().has_value())
876 return emitOpError("expected 'vscale_min' if and only if 'scalable'");
877 if (getScalable() != getVscaleMax().has_value())
878 return emitOpError("expected 'vscale_min' if and only if 'scalable'");
879 return success();
880}
881
882ValueBoundsConstraintSet::Variable ReifyBoundOp::getVariable() {
883 if (getDim().has_value())
884 return ValueBoundsConstraintSet::Variable(getVar(), *getDim());
885 return ValueBoundsConstraintSet::Variable(getVar());
886}
887
888//===----------------------------------------------------------------------===//
889// CompareOp
890//===----------------------------------------------------------------------===//
891
892ValueBoundsConstraintSet::ComparisonOperator
893CompareOp::getComparisonOperator() {
894 if (getCmp() == "EQ")
895 return ValueBoundsConstraintSet::ComparisonOperator::EQ;
896 if (getCmp() == "LT")
897 return ValueBoundsConstraintSet::ComparisonOperator::LT;
898 if (getCmp() == "LE")
899 return ValueBoundsConstraintSet::ComparisonOperator::LE;
900 if (getCmp() == "GT")
901 return ValueBoundsConstraintSet::ComparisonOperator::GT;
902 if (getCmp() == "GE")
903 return ValueBoundsConstraintSet::ComparisonOperator::GE;
904 llvm_unreachable("invalid comparison operator");
905}
906
907mlir::ValueBoundsConstraintSet::Variable CompareOp::getLhs() {
908 if (!getLhsMap())
909 return ValueBoundsConstraintSet::Variable(getVarOperands()[0]);
910 SmallVector<Value> mapOperands(
911 getVarOperands().slice(0, getLhsMap()->getNumInputs()));
912 return ValueBoundsConstraintSet::Variable(*getLhsMap(), mapOperands);
913}
914
915mlir::ValueBoundsConstraintSet::Variable CompareOp::getRhs() {
916 int64_t rhsOperandsBegin = getLhsMap() ? getLhsMap()->getNumInputs() : 1;
917 if (!getRhsMap())
918 return ValueBoundsConstraintSet::Variable(
919 getVarOperands()[rhsOperandsBegin]);
920 SmallVector<Value> mapOperands(
921 getVarOperands().slice(rhsOperandsBegin, getRhsMap()->getNumInputs()));
922 return ValueBoundsConstraintSet::Variable(*getRhsMap(), mapOperands);
923}
924
925LogicalResult CompareOp::verify() {
926 if (getCompose() && (getLhsMap() || getRhsMap()))
927 return emitOpError(
928 "'compose' not supported when 'lhs_map' or 'rhs_map' is present");
929 int64_t expectedNumOperands = getLhsMap() ? getLhsMap()->getNumInputs() : 1;
930 expectedNumOperands += getRhsMap() ? getRhsMap()->getNumInputs() : 1;
931 if (getVarOperands().size() != size_t(expectedNumOperands))
932 return emitOpError("expected ")
933 << expectedNumOperands << " operands, but got "
934 << getVarOperands().size();
935 return success();
936}
937
938//===----------------------------------------------------------------------===//
939// TestOpInPlaceSelfFold
940//===----------------------------------------------------------------------===//
941
942OpFoldResult TestOpInPlaceSelfFold::fold(FoldAdaptor adaptor) {
943 if (!getFolded()) {
944 // The folder adds the "folded" if not present.
945 setFolded(true);
946 return getResult();
947 }
948 return {};
949}
950
951//===----------------------------------------------------------------------===//
952// TestOpFoldWithFoldAdaptor
953//===----------------------------------------------------------------------===//
954
955OpFoldResult TestOpFoldWithFoldAdaptor::fold(FoldAdaptor adaptor) {
956 int64_t sum = 0;
957 if (auto value = dyn_cast_or_null<IntegerAttr>(adaptor.getOp()))
958 sum += value.getValue().getSExtValue();
959
960 for (Attribute attr : adaptor.getVariadic())
961 if (auto value = dyn_cast_or_null<IntegerAttr>(attr))
962 sum += 2 * value.getValue().getSExtValue();
963
964 for (ArrayRef<Attribute> attrs : adaptor.getVarOfVar())
965 for (Attribute attr : attrs)
966 if (auto value = dyn_cast_or_null<IntegerAttr>(attr))
967 sum += 3 * value.getValue().getSExtValue();
968
969 sum += 4 * std::distance(adaptor.getBody().begin(), adaptor.getBody().end());
970
971 return IntegerAttr::get(getType(), sum);
972}
973
974//===----------------------------------------------------------------------===//
975// OpWithInferTypeAdaptorInterfaceOp
976//===----------------------------------------------------------------------===//
977
978LogicalResult OpWithInferTypeAdaptorInterfaceOp::inferReturnTypes(
979 MLIRContext *, std::optional<Location> location,
980 OpWithInferTypeAdaptorInterfaceOp::Adaptor adaptor,
981 SmallVectorImpl<Type> &inferredReturnTypes) {
982 if (adaptor.getX().getType() != adaptor.getY().getType()) {
983 return emitOptionalError(location, "operand type mismatch ",
984 adaptor.getX().getType(), " vs ",
985 adaptor.getY().getType());
986 }
987 inferredReturnTypes.assign({adaptor.getX().getType()});
988 return success();
989}
990
991//===----------------------------------------------------------------------===//
992// OpWithRefineTypeInterfaceOp
993//===----------------------------------------------------------------------===//
994
995// TODO: We should be able to only define either inferReturnType or
996// refineReturnType, currently only refineReturnType can be omitted.
997LogicalResult OpWithRefineTypeInterfaceOp::inferReturnTypes(
998 MLIRContext *context, std::optional<Location> location, ValueRange operands,
999 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
1000 SmallVectorImpl<Type> &returnTypes) {
1001 returnTypes.clear();
1002 return OpWithRefineTypeInterfaceOp::refineReturnTypes(
1003 context, location, operands, attributes, properties, regions,
1004 returnTypes);
1005}
1006
1007LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes(
1008 MLIRContext *, std::optional<Location> location, ValueRange operands,
1009 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
1010 SmallVectorImpl<Type> &returnTypes) {
1011 if (operands[0].getType() != operands[1].getType()) {
1012 return emitOptionalError(location, "operand type mismatch ",
1013 operands[0].getType(), " vs ",
1014 operands[1].getType());
1015 }
1016 // TODO: Add helper to make this more concise to write.
1017 if (returnTypes.empty())
1018 returnTypes.resize(1, nullptr);
1019 if (returnTypes[0] && returnTypes[0] != operands[0].getType())
1020 return emitOptionalError(location,
1021 "required first operand and result to match");
1022 returnTypes[0] = operands[0].getType();
1023 return success();
1024}
1025
1026//===----------------------------------------------------------------------===//
1027// OpWithShapedTypeInferTypeAdaptorInterfaceOp
1028//===----------------------------------------------------------------------===//
1029
1030LogicalResult
1031OpWithShapedTypeInferTypeAdaptorInterfaceOp::inferReturnTypeComponents(
1032 MLIRContext *context, std::optional<Location> location,
1033 OpWithShapedTypeInferTypeAdaptorInterfaceOp::Adaptor adaptor,
1034 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1035 // Create return type consisting of the last element of the first operand.
1036 auto operandType = adaptor.getOperand1().getType();
1037 auto sval = dyn_cast<ShapedType>(operandType);
1038 if (!sval)
1039 return emitOptionalError(location, "only shaped type operands allowed");
1040 int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic;
1041 auto type = IntegerType::get(context, 17);
1042
1043 Attribute encoding;
1044 if (auto rankedTy = dyn_cast<RankedTensorType>(sval))
1045 encoding = rankedTy.getEncoding();
1046 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding));
1047 return success();
1048}
1049
1050LogicalResult
1051OpWithShapedTypeInferTypeAdaptorInterfaceOp::reifyReturnTypeShapes(
1052 OpBuilder &builder, ValueRange operands,
1053 llvm::SmallVectorImpl<Value> &shapes) {
1054 shapes = SmallVector<Value, 1>{
1055 builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
1056 return success();
1057}
1058
1059//===----------------------------------------------------------------------===//
1060// TestOpWithPropertiesAndInferredType
1061//===----------------------------------------------------------------------===//
1062
1063LogicalResult TestOpWithPropertiesAndInferredType::inferReturnTypes(
1064 MLIRContext *context, std::optional<Location>, ValueRange operands,
1065 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
1066 SmallVectorImpl<Type> &inferredReturnTypes) {
1067
1068 Adaptor adaptor(operands, attributes, properties, regions);
1069 inferredReturnTypes.push_back(IntegerType::get(
1070 context, adaptor.getLhs() + adaptor.getProperties().rhs));
1071 return success();
1072}
1073
1074//===----------------------------------------------------------------------===//
1075// LoopBlockOp
1076//===----------------------------------------------------------------------===//
1077
1078void LoopBlockOp::getSuccessorRegions(
1079 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1080 regions.emplace_back(&getBody(), getBody().getArguments());
1081 if (point.isParent())
1082 return;
1083
1084 regions.emplace_back((*this)->getResults());
1085}
1086
1087OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) {
1088 assert(point == getBody());
1089 return MutableOperandRange(getInitMutable());
1090}
1091
1092//===----------------------------------------------------------------------===//
1093// LoopBlockTerminatorOp
1094//===----------------------------------------------------------------------===//
1095
1096MutableOperandRange
1097LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point) {
1098 if (point.isParent())
1099 return getExitArgMutable();
1100 return getNextIterArgMutable();
1101}
1102
1103//===----------------------------------------------------------------------===//
// TestNoTerminatorOp
1105//===----------------------------------------------------------------------===//
1106
1107void TestNoTerminatorOp::getSuccessorRegions(
1108 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {}
1109
1110//===----------------------------------------------------------------------===//
// ManualCppOpWithFold
1112//===----------------------------------------------------------------------===//
1113
1114OpFoldResult ManualCppOpWithFold::fold(ArrayRef<Attribute> attributes) {
1115 // Just a simple fold for testing purposes that reads an operands constant
1116 // value and returns it.
1117 if (!attributes.empty())
1118 return attributes.front();
1119 return nullptr;
1120}
1121
1122//===----------------------------------------------------------------------===//
1123// Tensor/Buffer Ops
1124//===----------------------------------------------------------------------===//
1125
1126void ReadBufferOp::getEffects(
1127 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1128 &effects) {
1129 // The buffer operand is read.
1130 effects.emplace_back(MemoryEffects::Read::get(), &getBufferMutable(),
1131 SideEffects::DefaultResource::get());
1132 // The buffer contents are dumped.
1133 effects.emplace_back(MemoryEffects::Write::get(),
1134 SideEffects::DefaultResource::get());
1135}
1136
1137//===----------------------------------------------------------------------===//
1138// Test Dataflow
1139//===----------------------------------------------------------------------===//
1140
1141//===----------------------------------------------------------------------===//
1142// TestCallAndStoreOp
1143//===----------------------------------------------------------------------===//
1144
1145CallInterfaceCallable TestCallAndStoreOp::getCallableForCallee() {
1146 return getCallee();
1147}
1148
1149void TestCallAndStoreOp::setCalleeFromCallable(CallInterfaceCallable callee) {
1150 setCalleeAttr(cast<SymbolRefAttr>(callee));
1151}
1152
1153Operation::operand_range TestCallAndStoreOp::getArgOperands() {
1154 return getCalleeOperands();
1155}
1156
1157MutableOperandRange TestCallAndStoreOp::getArgOperandsMutable() {
1158 return getCalleeOperandsMutable();
1159}
1160
1161//===----------------------------------------------------------------------===//
1162// TestCallOnDeviceOp
1163//===----------------------------------------------------------------------===//
1164
1165CallInterfaceCallable TestCallOnDeviceOp::getCallableForCallee() {
1166 return getCallee();
1167}
1168
1169void TestCallOnDeviceOp::setCalleeFromCallable(CallInterfaceCallable callee) {
1170 setCalleeAttr(cast<SymbolRefAttr>(callee));
1171}
1172
1173Operation::operand_range TestCallOnDeviceOp::getArgOperands() {
1174 return getForwardedOperands();
1175}
1176
1177MutableOperandRange TestCallOnDeviceOp::getArgOperandsMutable() {
1178 return getForwardedOperandsMutable();
1179}
1180
1181//===----------------------------------------------------------------------===//
1182// TestStoreWithARegion
1183//===----------------------------------------------------------------------===//
1184
1185void TestStoreWithARegion::getSuccessorRegions(
1186 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1187 if (point.isParent())
1188 regions.emplace_back(&getBody(), getBody().front().getArguments());
1189 else
1190 regions.emplace_back();
1191}
1192
1193//===----------------------------------------------------------------------===//
1194// TestStoreWithALoopRegion
1195//===----------------------------------------------------------------------===//
1196
1197void TestStoreWithALoopRegion::getSuccessorRegions(
1198 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1199 // Both the operation itself and the region may be branching into the body or
1200 // back into the operation itself. It is possible for the operation not to
1201 // enter the body.
1202 regions.emplace_back(
1203 RegionSuccessor(&getBody(), getBody().front().getArguments()));
1204 regions.emplace_back();
1205}
1206
1207//===----------------------------------------------------------------------===//
1208// TestVersionedOpA
1209//===----------------------------------------------------------------------===//
1210
1211LogicalResult
1212TestVersionedOpA::readProperties(mlir::DialectBytecodeReader &reader,
1213 mlir::OperationState &state) {
1214 auto &prop = state.getOrAddProperties<Properties>();
1215 if (mlir::failed(reader.readAttribute(prop.dims)))
1216 return mlir::failure();
1217
1218 // Check if we have a version. If not, assume we are parsing the current
1219 // version.
1220 auto maybeVersion = reader.getDialectVersion<test::TestDialect>();
1221 if (succeeded(maybeVersion)) {
1222 // If version is less than 2.0, there is no additional attribute to parse.
1223 // We can materialize missing properties post parsing before verification.
1224 const auto *version =
1225 reinterpret_cast<const TestDialectVersion *>(*maybeVersion);
1226 if ((version->major_ < 2)) {
1227 return success();
1228 }
1229 }
1230
1231 if (mlir::failed(reader.readAttribute(prop.modifier)))
1232 return mlir::failure();
1233 return mlir::success();
1234}
1235
1236void TestVersionedOpA::writeProperties(mlir::DialectBytecodeWriter &writer) {
1237 auto &prop = getProperties();
1238 writer.writeAttribute(prop.dims);
1239
1240 auto maybeVersion = writer.getDialectVersion<test::TestDialect>();
1241 if (succeeded(maybeVersion)) {
1242 // If version is less than 2.0, there is no additional attribute to write.
1243 const auto *version =
1244 reinterpret_cast<const TestDialectVersion *>(*maybeVersion);
1245 if ((version->major_ < 2)) {
1246 llvm::outs() << "downgrading op properties...\n";
1247 return;
1248 }
1249 }
1250 writer.writeAttribute(prop.modifier);
1251}
1252
1253//===----------------------------------------------------------------------===//
1254// TestOpWithVersionedProperties
1255//===----------------------------------------------------------------------===//
1256
1257llvm::LogicalResult TestOpWithVersionedProperties::readFromMlirBytecode(
1258 mlir::DialectBytecodeReader &reader, test::VersionedProperties &prop) {
1259 uint64_t value1, value2 = 0;
1260 if (failed(reader.readVarInt(value1)))
1261 return failure();
1262
1263 // Check if we have a version. If not, assume we are parsing the current
1264 // version.
1265 auto maybeVersion = reader.getDialectVersion<test::TestDialect>();
1266 bool needToParseAnotherInt = true;
1267 if (succeeded(maybeVersion)) {
1268 // If version is less than 2.0, there is no additional attribute to parse.
1269 // We can materialize missing properties post parsing before verification.
1270 const auto *version =
1271 reinterpret_cast<const TestDialectVersion *>(*maybeVersion);
1272 if ((version->major_ < 2))
1273 needToParseAnotherInt = false;
1274 }
1275 if (needToParseAnotherInt && failed(reader.readVarInt(value2)))
1276 return failure();
1277
1278 prop.value1 = value1;
1279 prop.value2 = value2;
1280 return success();
1281}
1282
1283void TestOpWithVersionedProperties::writeToMlirBytecode(
1284 mlir::DialectBytecodeWriter &writer,
1285 const test::VersionedProperties &prop) {
1286 writer.writeVarInt(prop.value1);
1287 writer.writeVarInt(prop.value2);
1288}
1289
1290//===----------------------------------------------------------------------===//
1291// TestMultiSlotAlloca
1292//===----------------------------------------------------------------------===//
1293
1294llvm::SmallVector<MemorySlot> TestMultiSlotAlloca::getPromotableSlots() {
1295 SmallVector<MemorySlot> slots;
1296 for (Value result : getResults()) {
1297 slots.push_back(MemorySlot{
1298 result, cast<MemRefType>(result.getType()).getElementType()});
1299 }
1300 return slots;
1301}
1302
1303Value TestMultiSlotAlloca::getDefaultValue(const MemorySlot &slot,
1304 OpBuilder &builder) {
1305 return builder.create<TestOpConstant>(getLoc(), slot.elemType,
1306 builder.getI32IntegerAttr(42));
1307}
1308
1309void TestMultiSlotAlloca::handleBlockArgument(const MemorySlot &slot,
1310 BlockArgument argument,
1311 OpBuilder &builder) {
1312 // Not relevant for testing.
1313}
1314
1315/// Creates a new TestMultiSlotAlloca operation, just without the `slot`.
1316static std::optional<TestMultiSlotAlloca>
1317createNewMultiAllocaWithoutSlot(const MemorySlot &slot, OpBuilder &builder,
1318 TestMultiSlotAlloca oldOp) {
1319
1320 if (oldOp.getNumResults() == 1) {
1321 oldOp.erase();
1322 return std::nullopt;
1323 }
1324
1325 SmallVector<Type> newTypes;
1326 SmallVector<Value> remainingValues;
1327
1328 for (Value oldResult : oldOp.getResults()) {
1329 if (oldResult == slot.ptr)
1330 continue;
1331 remainingValues.push_back(oldResult);
1332 newTypes.push_back(oldResult.getType());
1333 }
1334
1335 OpBuilder::InsertionGuard guard(builder);
1336 builder.setInsertionPoint(oldOp);
1337 auto replacement =
1338 builder.create<TestMultiSlotAlloca>(oldOp->getLoc(), newTypes);
1339 for (auto [oldResult, newResult] :
1340 llvm::zip_equal(remainingValues, replacement.getResults()))
1341 oldResult.replaceAllUsesWith(newResult);
1342
1343 oldOp.erase();
1344 return replacement;
1345}
1346
1347std::optional<PromotableAllocationOpInterface>
1348TestMultiSlotAlloca::handlePromotionComplete(const MemorySlot &slot,
1349 Value defaultValue,
1350 OpBuilder &builder) {
1351 if (defaultValue && defaultValue.use_empty())
1352 defaultValue.getDefiningOp()->erase();
1353 return createNewMultiAllocaWithoutSlot(slot, builder, *this);
1354}
1355
1356SmallVector<DestructurableMemorySlot>
1357TestMultiSlotAlloca::getDestructurableSlots() {
1358 SmallVector<DestructurableMemorySlot> slots;
1359 for (Value result : getResults()) {
1360 auto memrefType = cast<MemRefType>(result.getType());
1361 auto destructurable = dyn_cast<DestructurableTypeInterface>(memrefType);
1362 if (!destructurable)
1363 continue;
1364
1365 std::optional<DenseMap<Attribute, Type>> destructuredType =
1366 destructurable.getSubelementIndexMap();
1367 if (!destructuredType)
1368 continue;
1369 slots.emplace_back(
1370 DestructurableMemorySlot{{result, memrefType}, *destructuredType});
1371 }
1372 return slots;
1373}
1374
1375DenseMap<Attribute, MemorySlot> TestMultiSlotAlloca::destructure(
1376 const DestructurableMemorySlot &slot,
1377 const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
1378 SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) {
1379 OpBuilder::InsertionGuard guard(builder);
1380 builder.setInsertionPointAfter(*this);
1381
1382 DenseMap<Attribute, MemorySlot> slotMap;
1383
1384 for (Attribute usedIndex : usedIndices) {
1385 Type elemType = slot.subelementTypes.lookup(usedIndex);
1386 MemRefType elemPtr = MemRefType::get({}, elemType);
1387 auto subAlloca = builder.create<TestMultiSlotAlloca>(getLoc(), elemPtr);
1388 newAllocators.push_back(subAlloca);
1389 slotMap.try_emplace<MemorySlot>(usedIndex,
1390 {subAlloca.getResult(0), elemType});
1391 }
1392
1393 return slotMap;
1394}
1395
1396std::optional<DestructurableAllocationOpInterface>
1397TestMultiSlotAlloca::handleDestructuringComplete(
1398 const DestructurableMemorySlot &slot, OpBuilder &builder) {
1399 return createNewMultiAllocaWithoutSlot(slot, builder, *this);
1400}
1401
1402::mlir::LogicalResult test::TestDummyTensorOp::bufferize(
1403 ::mlir::RewriterBase &rewriter,
1404 const ::mlir::bufferization::BufferizationOptions &options,
1405 ::mlir::bufferization::BufferizationState &state) {
1406 auto buffer =
1407 mlir::bufferization::getBuffer(rewriter, getInput(), options, state);
1408 if (mlir::failed(buffer))
1409 return failure();
1410
1411 const auto outType = getOutput().getType();
1412 const auto bufferizedOutType = test::TestMemrefType::get(
1413 getContext(), outType.getShape(), outType.getElementType(), nullptr);
1414 // replace op with memref analogy
1415 auto dummyMemrefOp = rewriter.create<test::TestDummyMemrefOp>(
1416 getLoc(), bufferizedOutType, *buffer);
1417
1418 mlir::bufferization::replaceOpWithBufferizedValues(rewriter, getOperation(),
1419 dummyMemrefOp.getResult());
1420
1421 return mlir::success();
1422}
1423
1424::mlir::LogicalResult test::TestCreateTensorOp::bufferize(
1425 ::mlir::RewriterBase &rewriter,
1426 const ::mlir::bufferization::BufferizationOptions &options,
1427 ::mlir::bufferization::BufferizationState &state) {
1428 // Note: mlir::bufferization::getBufferType() would internally call
1429 // TestCreateTensorOp::getBufferType()
1430 const auto bufferizedOutType =
1431 mlir::bufferization::getBufferType(getOutput(), options, state);
1432 if (mlir::failed(bufferizedOutType))
1433 return failure();
1434
1435 // replace op with memref analogy
1436 auto createMemrefOp =
1437 rewriter.create<test::TestCreateMemrefOp>(getLoc(), *bufferizedOutType);
1438
1439 mlir::bufferization::replaceOpWithBufferizedValues(
1440 rewriter, getOperation(), createMemrefOp.getResult());
1441
1442 return mlir::success();
1443}
1444
1445mlir::FailureOr<mlir::bufferization::BufferLikeType>
1446test::TestCreateTensorOp::getBufferType(
1447 mlir::Value value, const mlir::bufferization::BufferizationOptions &,
1448 const mlir::bufferization::BufferizationState &,
1449 llvm::SmallVector<::mlir::Value> &) {
1450 const auto type = dyn_cast<test::TestTensorType>(value.getType());
1451 if (type == nullptr)
1452 return failure();
1453
1454 return cast<mlir::bufferization::BufferLikeType>(test::TestMemrefType::get(
1455 getContext(), type.getShape(), type.getElementType(), nullptr));
1456}
1457

source code of mlir/test/lib/Dialect/Test/TestOpDefs.cpp