1//===- TestOpDefs.cpp - MLIR Test Dialect Operation Hooks -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "TestDialect.h"
10#include "TestOps.h"
11#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12#include "mlir/Dialect/Tensor/IR/Tensor.h"
13#include "mlir/IR/Verifier.h"
14#include "mlir/Interfaces/FunctionImplementation.h"
15#include "mlir/Interfaces/MemorySlotInterfaces.h"
16
17using namespace mlir;
18using namespace test;
19
20//===----------------------------------------------------------------------===//
21// TestBranchOp
22//===----------------------------------------------------------------------===//
23
24SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) {
25 assert(index == 0 && "invalid successor index");
26 return SuccessorOperands(getTargetOperandsMutable());
27}
28
29//===----------------------------------------------------------------------===//
30// TestProducingBranchOp
31//===----------------------------------------------------------------------===//
32
33SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) {
34 assert(index <= 1 && "invalid successor index");
35 if (index == 1)
36 return SuccessorOperands(getFirstOperandsMutable());
37 return SuccessorOperands(getSecondOperandsMutable());
38}
39
40//===----------------------------------------------------------------------===//
41// TestInternalBranchOp
42//===----------------------------------------------------------------------===//
43
44SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) {
45 assert(index <= 1 && "invalid successor index");
46 if (index == 0)
47 return SuccessorOperands(0, getSuccessOperandsMutable());
48 return SuccessorOperands(1, getErrorOperandsMutable());
49}
50
51//===----------------------------------------------------------------------===//
52// TestCallOp
53//===----------------------------------------------------------------------===//
54
55LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
56 // Check that the callee attribute was specified.
57 auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>(name: "callee");
58 if (!fnAttr)
59 return emitOpError(message: "requires a 'callee' symbol reference attribute");
60 if (!symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(from: *this, symbol: fnAttr))
61 return emitOpError() << "'" << fnAttr.getValue()
62 << "' does not reference a valid function";
63 return success();
64}
65
66//===----------------------------------------------------------------------===//
67// FoldToCallOp
68//===----------------------------------------------------------------------===//
69
70namespace {
71struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
72 using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
73
74 LogicalResult matchAndRewrite(FoldToCallOp op,
75 PatternRewriter &rewriter) const override {
76 rewriter.replaceOpWithNewOp<func::CallOp>(op, args: TypeRange(),
77 args: op.getCalleeAttr(), args: ValueRange());
78 return success();
79 }
80};
81} // namespace
82
83void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
84 MLIRContext *context) {
85 results.add<FoldToCallOpPattern>(arg&: context);
86}
87
88//===----------------------------------------------------------------------===//
89// IsolatedRegionOp - test parsing passthrough operands
90//===----------------------------------------------------------------------===//
91
92ParseResult IsolatedRegionOp::parse(OpAsmParser &parser,
93 OperationState &result) {
94 // Parse the input operand.
95 OpAsmParser::Argument argInfo;
96 argInfo.type = parser.getBuilder().getIndexType();
97 if (parser.parseOperand(result&: argInfo.ssaName) ||
98 parser.resolveOperand(operand: argInfo.ssaName, type: argInfo.type, result&: result.operands))
99 return failure();
100
101 // Parse the body region, and reuse the operand info as the argument info.
102 Region *body = result.addRegion();
103 return parser.parseRegion(region&: *body, arguments: argInfo, /*enableNameShadowing=*/true);
104}
105
106void IsolatedRegionOp::print(OpAsmPrinter &p) {
107 p << ' ';
108 p.printOperand(value: getOperand());
109 p.shadowRegionArgs(region&: getRegion(), namesToUse: getOperand());
110 p << ' ';
111 p.printRegion(blocks&: getRegion(), /*printEntryBlockArgs=*/false);
112}
113
114//===----------------------------------------------------------------------===//
115// SSACFGRegionOp
116//===----------------------------------------------------------------------===//
117
118RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
119 return RegionKind::SSACFG;
120}
121
122//===----------------------------------------------------------------------===//
123// GraphRegionOp
124//===----------------------------------------------------------------------===//
125
126RegionKind GraphRegionOp::getRegionKind(unsigned index) {
127 return RegionKind::Graph;
128}
129
130//===----------------------------------------------------------------------===//
131// IsolatedGraphRegionOp
132//===----------------------------------------------------------------------===//
133
134RegionKind IsolatedGraphRegionOp::getRegionKind(unsigned index) {
135 return RegionKind::Graph;
136}
137
138//===----------------------------------------------------------------------===//
139// AffineScopeOp
140//===----------------------------------------------------------------------===//
141
142ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) {
143 // Parse the body region, and reuse the operand info as the argument info.
144 Region *body = result.addRegion();
145 return parser.parseRegion(region&: *body, /*arguments=*/{}, /*argTypes=*/enableNameShadowing: {});
146}
147
148void AffineScopeOp::print(OpAsmPrinter &p) {
149 p << " ";
150 p.printRegion(blocks&: getRegion(), /*printEntryBlockArgs=*/false);
151}
152
153//===----------------------------------------------------------------------===//
154// TestRemoveOpWithInnerOps
155//===----------------------------------------------------------------------===//
156
157namespace {
158struct TestRemoveOpWithInnerOps
159 : public OpRewritePattern<TestOpWithRegionPattern> {
160 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
161
162 void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
163
164 LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
165 PatternRewriter &rewriter) const override {
166 rewriter.eraseOp(op);
167 return success();
168 }
169};
170} // namespace
171
172//===----------------------------------------------------------------------===//
173// TestOpWithRegionPattern
174//===----------------------------------------------------------------------===//
175
176void TestOpWithRegionPattern::getCanonicalizationPatterns(
177 RewritePatternSet &results, MLIRContext *context) {
178 results.add<TestRemoveOpWithInnerOps>(arg&: context);
179}
180
181//===----------------------------------------------------------------------===//
182// TestOpWithRegionFold
183//===----------------------------------------------------------------------===//
184
185OpFoldResult TestOpWithRegionFold::fold(FoldAdaptor adaptor) {
186 return getOperand();
187}
188
189//===----------------------------------------------------------------------===//
190// TestOpConstant
191//===----------------------------------------------------------------------===//
192
193OpFoldResult TestOpConstant::fold(FoldAdaptor adaptor) { return getValue(); }
194
195//===----------------------------------------------------------------------===//
196// TestOpWithVariadicResultsAndFolder
197//===----------------------------------------------------------------------===//
198
199LogicalResult TestOpWithVariadicResultsAndFolder::fold(
200 FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult> &results) {
201 for (Value input : this->getOperands()) {
202 results.push_back(Elt: input);
203 }
204 return success();
205}
206
207//===----------------------------------------------------------------------===//
208// TestOpInPlaceFold
209//===----------------------------------------------------------------------===//
210
211OpFoldResult TestOpInPlaceFold::fold(FoldAdaptor adaptor) {
212 // Exercise the fact that an operation created with createOrFold should be
213 // allowed to access its parent block.
214 assert(getOperation()->getBlock() &&
215 "expected that operation is not unlinked");
216
217 if (adaptor.getOp() && !getProperties().attr) {
218 // The folder adds "attr" if not present.
219 getProperties().attr = dyn_cast_or_null<IntegerAttr>(Val: adaptor.getOp());
220 return getResult();
221 }
222 return {};
223}
224
225//===----------------------------------------------------------------------===//
226// OpWithInferTypeInterfaceOp
227//===----------------------------------------------------------------------===//
228
229LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
230 MLIRContext *, std::optional<Location> location, ValueRange operands,
231 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
232 SmallVectorImpl<Type> &inferredReturnTypes) {
233 if (operands[0].getType() != operands[1].getType()) {
234 return emitOptionalError(loc: location, args: "operand type mismatch ",
235 args: operands[0].getType(), args: " vs ",
236 args: operands[1].getType());
237 }
238 inferredReturnTypes.assign(IL: {operands[0].getType()});
239 return success();
240}
241
242//===----------------------------------------------------------------------===//
243// OpWithShapedTypeInferTypeInterfaceOp
244//===----------------------------------------------------------------------===//
245
246LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
247 MLIRContext *context, std::optional<Location> location,
248 ValueShapeRange operands, DictionaryAttr attributes,
249 OpaqueProperties properties, RegionRange regions,
250 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
251 // Create return type consisting of the last element of the first operand.
252 auto operandType = operands.front().getType();
253 auto sval = dyn_cast<ShapedType>(Val&: operandType);
254 if (!sval)
255 return emitOptionalError(loc: location, args: "only shaped type operands allowed");
256 int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic;
257 auto type = IntegerType::get(context, width: 17);
258
259 Attribute encoding;
260 if (auto rankedTy = dyn_cast<RankedTensorType>(Val&: sval))
261 encoding = rankedTy.getEncoding();
262 inferredReturnShapes.push_back(Elt: ShapedTypeComponents({dim}, type, encoding));
263 return success();
264}
265
266LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
267 OpBuilder &builder, ValueRange operands,
268 llvm::SmallVectorImpl<Value> &shapes) {
269 shapes = SmallVector<Value, 1>{
270 builder.createOrFold<tensor::DimOp>(location: getLoc(), args: operands.front(), args: 0)};
271 return success();
272}
273
274//===----------------------------------------------------------------------===//
275// OpWithResultShapeInterfaceOp
276//===----------------------------------------------------------------------===//
277
278LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
279 OpBuilder &builder, ValueRange operands,
280 llvm::SmallVectorImpl<Value> &shapes) {
281 Location loc = getLoc();
282 shapes.reserve(N: operands.size());
283 for (Value operand : llvm::reverse(C&: operands)) {
284 auto rank = cast<RankedTensorType>(Val: operand.getType()).getRank();
285 auto currShape = llvm::to_vector<4>(
286 Range: llvm::map_range(C: llvm::seq<int64_t>(Begin: 0, End: rank), F: [&](int64_t dim) -> Value {
287 return builder.createOrFold<tensor::DimOp>(location: loc, args&: operand, args&: dim);
288 }));
289 shapes.push_back(Elt: builder.create<tensor::FromElementsOp>(
290 location: getLoc(), args: RankedTensorType::get(shape: {rank}, elementType: builder.getIndexType()),
291 args&: currShape));
292 }
293 return success();
294}
295
296//===----------------------------------------------------------------------===//
297// OpWithResultShapePerDimInterfaceOp
298//===----------------------------------------------------------------------===//
299
300LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
301 OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
302 Location loc = getLoc();
303 shapes.reserve(N: getNumOperands());
304 for (Value operand : llvm::reverse(C: getOperands())) {
305 auto tensorType = cast<RankedTensorType>(Val: operand.getType());
306 auto currShape = llvm::to_vector<4>(Range: llvm::map_range(
307 C: llvm::seq<int64_t>(Begin: 0, End: tensorType.getRank()),
308 F: [&](int64_t dim) -> OpFoldResult {
309 return tensorType.isDynamicDim(idx: dim)
310 ? static_cast<OpFoldResult>(
311 builder.createOrFold<tensor::DimOp>(location: loc, args&: operand,
312 args&: dim))
313 : static_cast<OpFoldResult>(
314 builder.getIndexAttr(value: tensorType.getDimSize(idx: dim)));
315 }));
316 shapes.emplace_back(Args: std::move(currShape));
317 }
318 return success();
319}
320
321//===----------------------------------------------------------------------===//
322// SideEffectOp
323//===----------------------------------------------------------------------===//
324
325namespace {
326/// A test resource for side effects.
327struct TestResource : public SideEffects::Resource::Base<TestResource> {
328 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResource)
329
330 StringRef getName() final { return "<Test>"; }
331};
332} // namespace
333
334void SideEffectOp::getEffects(
335 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
336 // Check for an effects attribute on the op instance.
337 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>(name: "effects");
338 if (!effectsAttr)
339 return;
340
341 for (Attribute element : effectsAttr) {
342 DictionaryAttr effectElement = cast<DictionaryAttr>(Val&: element);
343
344 // Get the specific memory effect.
345 MemoryEffects::Effect *effect =
346 StringSwitch<MemoryEffects::Effect *>(
347 cast<StringAttr>(Val: effectElement.get(name: "effect")).getValue())
348 .Case(S: "allocate", Value: MemoryEffects::Allocate::get())
349 .Case(S: "free", Value: MemoryEffects::Free::get())
350 .Case(S: "read", Value: MemoryEffects::Read::get())
351 .Case(S: "write", Value: MemoryEffects::Write::get());
352
353 // Check for a non-default resource to use.
354 SideEffects::Resource *resource = SideEffects::DefaultResource::get();
355 if (effectElement.get(name: "test_resource"))
356 resource = TestResource::get();
357
358 // Check for a result to affect.
359 if (effectElement.get(name: "on_result"))
360 effects.emplace_back(Args&: effect, Args: getOperation()->getOpResults()[0], Args&: resource);
361 else if (Attribute ref = effectElement.get(name: "on_reference"))
362 effects.emplace_back(Args&: effect, Args: cast<SymbolRefAttr>(Val&: ref), Args&: resource);
363 else
364 effects.emplace_back(Args&: effect, Args&: resource);
365 }
366}
367
368void SideEffectOp::getEffects(
369 SmallVectorImpl<TestEffects::EffectInstance> &effects) {
370 testSideEffectOpGetEffect(op: getOperation(), effects);
371}
372
373void SideEffectWithRegionOp::getEffects(
374 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
375 // Check for an effects attribute on the op instance.
376 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>(name: "effects");
377 if (!effectsAttr)
378 return;
379
380 for (Attribute element : effectsAttr) {
381 DictionaryAttr effectElement = cast<DictionaryAttr>(Val&: element);
382
383 // Get the specific memory effect.
384 MemoryEffects::Effect *effect =
385 StringSwitch<MemoryEffects::Effect *>(
386 cast<StringAttr>(Val: effectElement.get(name: "effect")).getValue())
387 .Case(S: "allocate", Value: MemoryEffects::Allocate::get())
388 .Case(S: "free", Value: MemoryEffects::Free::get())
389 .Case(S: "read", Value: MemoryEffects::Read::get())
390 .Case(S: "write", Value: MemoryEffects::Write::get());
391
392 // Check for a non-default resource to use.
393 SideEffects::Resource *resource = SideEffects::DefaultResource::get();
394 if (effectElement.get(name: "test_resource"))
395 resource = TestResource::get();
396
397 // Check for a result to affect.
398 if (effectElement.get(name: "on_result"))
399 effects.emplace_back(Args&: effect, Args: getOperation()->getOpResults()[0], Args&: resource);
400 else if (effectElement.get(name: "on_operand"))
401 effects.emplace_back(Args&: effect, Args: &getOperation()->getOpOperands()[0],
402 Args&: resource);
403 else if (effectElement.get(name: "on_argument"))
404 effects.emplace_back(Args&: effect, Args: getOperation()->getRegion(index: 0).getArgument(i: 0),
405 Args&: resource);
406 else if (Attribute ref = effectElement.get(name: "on_reference"))
407 effects.emplace_back(Args&: effect, Args: cast<SymbolRefAttr>(Val&: ref), Args&: resource);
408 else
409 effects.emplace_back(Args&: effect, Args&: resource);
410 }
411}
412
413void SideEffectWithRegionOp::getEffects(
414 SmallVectorImpl<TestEffects::EffectInstance> &effects) {
415 testSideEffectOpGetEffect(op: getOperation(), effects);
416}
417
418//===----------------------------------------------------------------------===//
419// StringAttrPrettyNameOp
420//===----------------------------------------------------------------------===//
421
422// This op has fancy handling of its SSA result name.
423ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser,
424 OperationState &result) {
425 // Add the result types.
426 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
427 result.addTypes(newTypes: parser.getBuilder().getIntegerType(width: 32));
428
429 if (parser.parseOptionalAttrDictWithKeyword(result&: result.attributes))
430 return failure();
431
432 // If the attribute dictionary contains no 'names' attribute, infer it from
433 // the SSA name (if specified).
434 bool hadNames = llvm::any_of(Range&: result.attributes, P: [](NamedAttribute attr) {
435 return attr.getName() == "names";
436 });
437
438 // If there was no name specified, check to see if there was a useful name
439 // specified in the asm file.
440 if (hadNames || parser.getNumResults() == 0)
441 return success();
442
443 SmallVector<StringRef, 4> names;
444 auto *context = result.getContext();
445
446 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
447 auto resultName = parser.getResultName(resultNo: i);
448 StringRef nameStr;
449 if (!resultName.first.empty() && !isdigit(resultName.first[0]))
450 nameStr = resultName.first;
451
452 names.push_back(Elt: nameStr);
453 }
454
455 auto namesAttr = parser.getBuilder().getStrArrayAttr(values: names);
456 result.attributes.push_back(newAttribute: {StringAttr::get(context, bytes: "names"), namesAttr});
457 return success();
458}
459
460void StringAttrPrettyNameOp::print(OpAsmPrinter &p) {
461 // Note that we only need to print the "name" attribute if the asmprinter
462 // result name disagrees with it. This can happen in strange cases, e.g.
463 // when there are conflicts.
464 bool namesDisagree = getNames().size() != getNumResults();
465
466 SmallString<32> resultNameStr;
467 for (size_t i = 0, e = getNumResults(); i != e && !namesDisagree; ++i) {
468 resultNameStr.clear();
469 llvm::raw_svector_ostream tmpStream(resultNameStr);
470 p.printOperand(value: getResult(i), os&: tmpStream);
471
472 auto expectedName = dyn_cast<StringAttr>(Val: getNames()[i]);
473 if (!expectedName ||
474 tmpStream.str().drop_front() != expectedName.getValue()) {
475 namesDisagree = true;
476 }
477 }
478
479 if (namesDisagree)
480 p.printOptionalAttrDictWithKeyword(attrs: (*this)->getAttrs());
481 else
482 p.printOptionalAttrDictWithKeyword(attrs: (*this)->getAttrs(), elidedAttrs: {"names"});
483}
484
485// We set the SSA name in the asm syntax to the contents of the name
486// attribute.
487void StringAttrPrettyNameOp::getAsmResultNames(
488 function_ref<void(Value, StringRef)> setNameFn) {
489
490 auto value = getNames();
491 for (size_t i = 0, e = value.size(); i != e; ++i)
492 if (auto str = dyn_cast<StringAttr>(Val: value[i]))
493 if (!str.getValue().empty())
494 setNameFn(getResult(i), str.getValue());
495}
496
497//===----------------------------------------------------------------------===//
498// CustomResultsNameOp
499//===----------------------------------------------------------------------===//
500
501void CustomResultsNameOp::getAsmResultNames(
502 function_ref<void(Value, StringRef)> setNameFn) {
503 ArrayAttr value = getNames();
504 for (size_t i = 0, e = value.size(); i != e; ++i)
505 if (auto str = dyn_cast<StringAttr>(Val: value[i]))
506 if (!str.empty())
507 setNameFn(getResult(i), str.getValue());
508}
509
510//===----------------------------------------------------------------------===//
511// ResultNameFromTypeOp
512//===----------------------------------------------------------------------===//
513
514void ResultNameFromTypeOp::getAsmResultNames(
515 function_ref<void(Value, StringRef)> setNameFn) {
516 auto result = getResult();
517 auto setResultNameFn = [&](::llvm::StringRef name) {
518 setNameFn(result, name);
519 };
520 auto opAsmTypeInterface =
521 ::mlir::cast<::mlir::OpAsmTypeInterface>(Val: result.getType());
522 opAsmTypeInterface.getAsmName(setNameFn: setResultNameFn);
523}
524
525//===----------------------------------------------------------------------===//
526// BlockArgumentNameFromTypeOp
527//===----------------------------------------------------------------------===//
528
529void BlockArgumentNameFromTypeOp::getAsmBlockArgumentNames(
530 ::mlir::Region &region, ::mlir::OpAsmSetValueNameFn setNameFn) {
531 for (auto &block : region) {
532 for (auto arg : block.getArguments()) {
533 if (auto opAsmTypeInterface =
534 ::mlir::dyn_cast<::mlir::OpAsmTypeInterface>(Val: arg.getType())) {
535 auto setArgNameFn = [&](StringRef name) { setNameFn(arg, name); };
536 opAsmTypeInterface.getAsmName(setNameFn: setArgNameFn);
537 }
538 }
539 }
540}
541
542//===----------------------------------------------------------------------===//
543// ResultTypeWithTraitOp
544//===----------------------------------------------------------------------===//
545
546LogicalResult ResultTypeWithTraitOp::verify() {
547 if ((*this)->getResultTypes()[0].hasTrait<TypeTrait::TestTypeTrait>())
548 return success();
549 return emitError(message: "result type should have trait 'TestTypeTrait'");
550}
551
552//===----------------------------------------------------------------------===//
553// AttrWithTraitOp
554//===----------------------------------------------------------------------===//
555
556LogicalResult AttrWithTraitOp::verify() {
557 if (getAttr().hasTrait<AttributeTrait::TestAttrTrait>())
558 return success();
559 return emitError(message: "'attr' attribute should have trait 'TestAttrTrait'");
560}
561
562//===----------------------------------------------------------------------===//
563// RegionIfOp
564//===----------------------------------------------------------------------===//
565
566void RegionIfOp::print(OpAsmPrinter &p) {
567 p << " ";
568 p.printOperands(container: getOperands());
569 p << ": " << getOperandTypes();
570 p.printArrowTypeList(types: getResultTypes());
571 p << " then ";
572 p.printRegion(blocks&: getThenRegion(),
573 /*printEntryBlockArgs=*/true,
574 /*printBlockTerminators=*/true);
575 p << " else ";
576 p.printRegion(blocks&: getElseRegion(),
577 /*printEntryBlockArgs=*/true,
578 /*printBlockTerminators=*/true);
579 p << " join ";
580 p.printRegion(blocks&: getJoinRegion(),
581 /*printEntryBlockArgs=*/true,
582 /*printBlockTerminators=*/true);
583}
584
585ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) {
586 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfos;
587 SmallVector<Type, 2> operandTypes;
588
589 result.regions.reserve(N: 3);
590 Region *thenRegion = result.addRegion();
591 Region *elseRegion = result.addRegion();
592 Region *joinRegion = result.addRegion();
593
594 // Parse operand, type and arrow type lists.
595 if (parser.parseOperandList(result&: operandInfos) ||
596 parser.parseColonTypeList(result&: operandTypes) ||
597 parser.parseArrowTypeList(result&: result.types))
598 return failure();
599
600 // Parse all attached regions.
601 if (parser.parseKeyword(keyword: "then") || parser.parseRegion(region&: *thenRegion, arguments: {}, enableNameShadowing: {}) ||
602 parser.parseKeyword(keyword: "else") || parser.parseRegion(region&: *elseRegion, arguments: {}, enableNameShadowing: {}) ||
603 parser.parseKeyword(keyword: "join") || parser.parseRegion(region&: *joinRegion, arguments: {}, enableNameShadowing: {}))
604 return failure();
605
606 return parser.resolveOperands(operands&: operandInfos, types&: operandTypes,
607 loc: parser.getCurrentLocation(), result&: result.operands);
608}
609
610OperandRange RegionIfOp::getEntrySuccessorOperands(RegionBranchPoint point) {
611 assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, point) &&
612 "invalid region index");
613 return getOperands();
614}
615
616void RegionIfOp::getSuccessorRegions(
617 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
618 // We always branch to the join region.
619 if (!point.isParent()) {
620 if (point != getJoinRegion())
621 regions.push_back(Elt: RegionSuccessor(&getJoinRegion(), getJoinArgs()));
622 else
623 regions.push_back(Elt: RegionSuccessor(getResults()));
624 return;
625 }
626
627 // The then and else regions are the entry regions of this op.
628 regions.push_back(Elt: RegionSuccessor(&getThenRegion(), getThenArgs()));
629 regions.push_back(Elt: RegionSuccessor(&getElseRegion(), getElseArgs()));
630}
631
632void RegionIfOp::getRegionInvocationBounds(
633 ArrayRef<Attribute> operands,
634 SmallVectorImpl<InvocationBounds> &invocationBounds) {
635 // Each region is invoked at most once.
636 invocationBounds.assign(/*NumElts=*/3, /*Elt=*/{0, 1});
637}
638
639//===----------------------------------------------------------------------===//
640// AnyCondOp
641//===----------------------------------------------------------------------===//
642
643void AnyCondOp::getSuccessorRegions(RegionBranchPoint point,
644 SmallVectorImpl<RegionSuccessor> &regions) {
645 // The parent op branches into the only region, and the region branches back
646 // to the parent op.
647 if (point.isParent())
648 regions.emplace_back(Args: &getRegion());
649 else
650 regions.emplace_back(Args: getResults());
651}
652
653void AnyCondOp::getRegionInvocationBounds(
654 ArrayRef<Attribute> operands,
655 SmallVectorImpl<InvocationBounds> &invocationBounds) {
656 invocationBounds.emplace_back(Args: 1, Args: 1);
657}
658
659//===----------------------------------------------------------------------===//
660// SingleBlockImplicitTerminatorOp
661//===----------------------------------------------------------------------===//
662
663/// Testing the correctness of some traits.
664static_assert(
665 llvm::is_detected<OpTrait::has_implicit_terminator_t,
666 SingleBlockImplicitTerminatorOp>::value,
667 "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
668static_assert(OpTrait::hasSingleBlockImplicitTerminator<
669 SingleBlockImplicitTerminatorOp>::value,
670 "hasSingleBlockImplicitTerminator does not match "
671 "SingleBlockImplicitTerminatorOp");
672
673//===----------------------------------------------------------------------===//
674// SingleNoTerminatorCustomAsmOp
675//===----------------------------------------------------------------------===//
676
677ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser,
678 OperationState &state) {
679 Region *body = state.addRegion();
680 if (parser.parseRegion(region&: *body, /*arguments=*/{}, /*argTypes=*/enableNameShadowing: {}))
681 return failure();
682 return success();
683}
684
685void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) {
686 printer.printRegion(
687 blocks&: getRegion(), /*printEntryBlockArgs=*/false,
688 // This op has a single block without terminators. But explicitly mark
689 // as not printing block terminators for testing.
690 /*printBlockTerminators=*/false);
691}
692
693//===----------------------------------------------------------------------===//
694// TestVerifiersOp
695//===----------------------------------------------------------------------===//
696
697LogicalResult TestVerifiersOp::verify() {
698 if (!getRegion().hasOneBlock())
699 return emitOpError(message: "`hasOneBlock` trait hasn't been verified");
700
701 Operation *definingOp = getInput().getDefiningOp();
702 if (definingOp && failed(Result: mlir::verify(op: definingOp)))
703 return emitOpError(message: "operand hasn't been verified");
704
705 // Avoid using `emitRemark(msg)` since that will trigger an infinite verifier
706 // loop.
707 mlir::emitRemark(loc: getLoc(), message: "success run of verifier");
708
709 return success();
710}
711
712LogicalResult TestVerifiersOp::verifyRegions() {
713 if (!getRegion().hasOneBlock())
714 return emitOpError(message: "`hasOneBlock` trait hasn't been verified");
715
716 for (Block &block : getRegion())
717 for (Operation &op : block)
718 if (failed(Result: mlir::verify(op: &op)))
719 return emitOpError(message: "nested op hasn't been verified");
720
721 // Avoid using `emitRemark(msg)` since that will trigger an infinite verifier
722 // loop.
723 mlir::emitRemark(loc: getLoc(), message: "success run of region verifier");
724
725 return success();
726}
727
728//===----------------------------------------------------------------------===//
729// Test InferIntRangeInterface
730//===----------------------------------------------------------------------===//
731
732//===----------------------------------------------------------------------===//
733// TestWithBoundsOp
734//===----------------------------------------------------------------------===//
735
736void TestWithBoundsOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
737 SetIntRangeFn setResultRanges) {
738 setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()});
739}
740
741//===----------------------------------------------------------------------===//
742// TestWithBoundsRegionOp
743//===----------------------------------------------------------------------===//
744
745ParseResult TestWithBoundsRegionOp::parse(OpAsmParser &parser,
746 OperationState &result) {
747 if (parser.parseOptionalAttrDict(result&: result.attributes))
748 return failure();
749
750 // Parse the input argument
751 OpAsmParser::Argument argInfo;
752 if (failed(Result: parser.parseArgument(result&: argInfo, allowType: true)))
753 return failure();
754
755 // Parse the body region, and reuse the operand info as the argument info.
756 Region *body = result.addRegion();
757 return parser.parseRegion(region&: *body, arguments: argInfo, /*enableNameShadowing=*/false);
758}
759
760void TestWithBoundsRegionOp::print(OpAsmPrinter &p) {
761 p.printOptionalAttrDict(attrs: (*this)->getAttrs());
762 p << ' ';
763 p.printRegionArgument(arg: getRegion().getArgument(i: 0), /*argAttrs=*/{},
764 /*omitType=*/false);
765 p << ' ';
766 p.printRegion(blocks&: getRegion(), /*printEntryBlockArgs=*/false);
767}
768
769void TestWithBoundsRegionOp::inferResultRanges(
770 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
771 Value arg = getRegion().getArgument(i: 0);
772 setResultRanges(arg, {getUmin(), getUmax(), getSmin(), getSmax()});
773}
774
775//===----------------------------------------------------------------------===//
776// TestIncrementOp
777//===----------------------------------------------------------------------===//
778
779void TestIncrementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
780 SetIntRangeFn setResultRanges) {
781 const ConstantIntRanges &range = argRanges[0];
782 APInt one(range.umin().getBitWidth(), 1);
783 setResultRanges(getResult(),
784 {range.umin().uadd_sat(RHS: one), range.umax().uadd_sat(RHS: one),
785 range.smin().sadd_sat(RHS: one), range.smax().sadd_sat(RHS: one)});
786}
787
788//===----------------------------------------------------------------------===//
789// TestReflectBoundsOp
790//===----------------------------------------------------------------------===//
791
792void TestReflectBoundsOp::inferResultRanges(
793 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
794 const ConstantIntRanges &range = argRanges[0];
795 MLIRContext *ctx = getContext();
796 Builder b(ctx);
797 Type sIntTy, uIntTy;
798 // For plain `IntegerType`s, we can derive the appropriate signed and unsigned
799 // Types for the Attributes.
800 Type type = getElementTypeOrSelf(type: getType());
801 if (auto intTy = llvm::dyn_cast<IntegerType>(Val&: type)) {
802 unsigned bitwidth = intTy.getWidth();
803 sIntTy = b.getIntegerType(width: bitwidth, /*isSigned=*/true);
804 uIntTy = b.getIntegerType(width: bitwidth, /*isSigned=*/false);
805 } else {
806 sIntTy = uIntTy = type;
807 }
808
809 setUminAttr(b.getIntegerAttr(type: uIntTy, value: range.umin()));
810 setUmaxAttr(b.getIntegerAttr(type: uIntTy, value: range.umax()));
811 setSminAttr(b.getIntegerAttr(type: sIntTy, value: range.smin()));
812 setSmaxAttr(b.getIntegerAttr(type: sIntTy, value: range.smax()));
813 setResultRanges(getResult(), range);
814}
815
816//===----------------------------------------------------------------------===//
817// ConversionFuncOp
818//===----------------------------------------------------------------------===//
819
820ParseResult ConversionFuncOp::parse(OpAsmParser &parser,
821 OperationState &result) {
822 auto buildFuncType =
823 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
824 function_interface_impl::VariadicFlag,
825 std::string &) { return builder.getFunctionType(inputs: argTypes, results); };
826
827 return function_interface_impl::parseFunctionOp(
828 parser, result, /*allowVariadic=*/false,
829 typeAttrName: getFunctionTypeAttrName(name: result.name), funcTypeBuilder: buildFuncType,
830 argAttrsName: getArgAttrsAttrName(name: result.name), resAttrsName: getResAttrsAttrName(name: result.name));
831}
832
833void ConversionFuncOp::print(OpAsmPrinter &p) {
834 function_interface_impl::printFunctionOp(
835 p, op: *this, /*isVariadic=*/false, typeAttrName: getFunctionTypeAttrName(),
836 argAttrsName: getArgAttrsAttrName(), resAttrsName: getResAttrsAttrName());
837}
838
839//===----------------------------------------------------------------------===//
840// TestValueWithBoundsOp
841//===----------------------------------------------------------------------===//
842
843void TestValueWithBoundsOp::populateBoundsForIndexValue(
844 Value v, ValueBoundsConstraintSet &cstr) {
845 cstr.bound(value: v) >= getMin().getSExtValue();
846 cstr.bound(value: v) <= getMax().getSExtValue();
847}
848
849//===----------------------------------------------------------------------===//
850// ReifyBoundOp
851//===----------------------------------------------------------------------===//
852
853mlir::presburger::BoundType ReifyBoundOp::getBoundType() {
854 if (getType() == "EQ")
855 return mlir::presburger::BoundType::EQ;
856 if (getType() == "LB")
857 return mlir::presburger::BoundType::LB;
858 if (getType() == "UB")
859 return mlir::presburger::BoundType::UB;
860 llvm_unreachable("invalid bound type");
861}
862
863LogicalResult ReifyBoundOp::verify() {
864 if (isa<ShapedType>(Val: getVar().getType())) {
865 if (!getDim().has_value())
866 return emitOpError(message: "expected 'dim' attribute for shaped type variable");
867 } else if (getVar().getType().isIndex()) {
868 if (getDim().has_value())
869 return emitOpError(message: "unexpected 'dim' attribute for index variable");
870 } else {
871 return emitOpError(message: "expected index-typed variable or shape type variable");
872 }
873 if (getConstant() && getScalable())
874 return emitOpError(message: "'scalable' and 'constant' are mutually exlusive");
875 if (getScalable() != getVscaleMin().has_value())
876 return emitOpError(message: "expected 'vscale_min' if and only if 'scalable'");
877 if (getScalable() != getVscaleMax().has_value())
878 return emitOpError(message: "expected 'vscale_min' if and only if 'scalable'");
879 return success();
880}
881
882ValueBoundsConstraintSet::Variable ReifyBoundOp::getVariable() {
883 if (getDim().has_value())
884 return ValueBoundsConstraintSet::Variable(getVar(), *getDim());
885 return ValueBoundsConstraintSet::Variable(getVar());
886}
887
888//===----------------------------------------------------------------------===//
889// CompareOp
890//===----------------------------------------------------------------------===//
891
892ValueBoundsConstraintSet::ComparisonOperator
893CompareOp::getComparisonOperator() {
894 if (getCmp() == "EQ")
895 return ValueBoundsConstraintSet::ComparisonOperator::EQ;
896 if (getCmp() == "LT")
897 return ValueBoundsConstraintSet::ComparisonOperator::LT;
898 if (getCmp() == "LE")
899 return ValueBoundsConstraintSet::ComparisonOperator::LE;
900 if (getCmp() == "GT")
901 return ValueBoundsConstraintSet::ComparisonOperator::GT;
902 if (getCmp() == "GE")
903 return ValueBoundsConstraintSet::ComparisonOperator::GE;
904 llvm_unreachable("invalid comparison operator");
905}
906
907mlir::ValueBoundsConstraintSet::Variable CompareOp::getLhs() {
908 if (!getLhsMap())
909 return ValueBoundsConstraintSet::Variable(getVarOperands()[0]);
910 SmallVector<Value> mapOperands(
911 getVarOperands().slice(n: 0, m: getLhsMap()->getNumInputs()));
912 return ValueBoundsConstraintSet::Variable(*getLhsMap(), mapOperands);
913}
914
915mlir::ValueBoundsConstraintSet::Variable CompareOp::getRhs() {
916 int64_t rhsOperandsBegin = getLhsMap() ? getLhsMap()->getNumInputs() : 1;
917 if (!getRhsMap())
918 return ValueBoundsConstraintSet::Variable(
919 getVarOperands()[rhsOperandsBegin]);
920 SmallVector<Value> mapOperands(
921 getVarOperands().slice(n: rhsOperandsBegin, m: getRhsMap()->getNumInputs()));
922 return ValueBoundsConstraintSet::Variable(*getRhsMap(), mapOperands);
923}
924
925LogicalResult CompareOp::verify() {
926 if (getCompose() && (getLhsMap() || getRhsMap()))
927 return emitOpError(
928 message: "'compose' not supported when 'lhs_map' or 'rhs_map' is present");
929 int64_t expectedNumOperands = getLhsMap() ? getLhsMap()->getNumInputs() : 1;
930 expectedNumOperands += getRhsMap() ? getRhsMap()->getNumInputs() : 1;
931 if (getVarOperands().size() != size_t(expectedNumOperands))
932 return emitOpError(message: "expected ")
933 << expectedNumOperands << " operands, but got "
934 << getVarOperands().size();
935 return success();
936}
937
938//===----------------------------------------------------------------------===//
939// TestOpInPlaceSelfFold
940//===----------------------------------------------------------------------===//
941
942OpFoldResult TestOpInPlaceSelfFold::fold(FoldAdaptor adaptor) {
943 if (!getFolded()) {
944 // The folder adds the "folded" if not present.
945 setFolded(true);
946 return getResult();
947 }
948 return {};
949}
950
951//===----------------------------------------------------------------------===//
952// TestOpFoldWithFoldAdaptor
953//===----------------------------------------------------------------------===//
954
955OpFoldResult TestOpFoldWithFoldAdaptor::fold(FoldAdaptor adaptor) {
956 int64_t sum = 0;
957 if (auto value = dyn_cast_or_null<IntegerAttr>(Val: adaptor.getOp()))
958 sum += value.getValue().getSExtValue();
959
960 for (Attribute attr : adaptor.getVariadic())
961 if (auto value = dyn_cast_or_null<IntegerAttr>(Val&: attr))
962 sum += 2 * value.getValue().getSExtValue();
963
964 for (ArrayRef<Attribute> attrs : adaptor.getVarOfVar())
965 for (Attribute attr : attrs)
966 if (auto value = dyn_cast_or_null<IntegerAttr>(Val&: attr))
967 sum += 3 * value.getValue().getSExtValue();
968
969 sum += 4 * std::distance(first: adaptor.getBody().begin(), last: adaptor.getBody().end());
970
971 return IntegerAttr::get(type: getType(), value: sum);
972}
973
974//===----------------------------------------------------------------------===//
975// OpWithInferTypeAdaptorInterfaceOp
976//===----------------------------------------------------------------------===//
977
978LogicalResult OpWithInferTypeAdaptorInterfaceOp::inferReturnTypes(
979 MLIRContext *, std::optional<Location> location,
980 OpWithInferTypeAdaptorInterfaceOp::Adaptor adaptor,
981 SmallVectorImpl<Type> &inferredReturnTypes) {
982 if (adaptor.getX().getType() != adaptor.getY().getType()) {
983 return emitOptionalError(loc: location, args: "operand type mismatch ",
984 args: adaptor.getX().getType(), args: " vs ",
985 args: adaptor.getY().getType());
986 }
987 inferredReturnTypes.assign(IL: {adaptor.getX().getType()});
988 return success();
989}
990
991//===----------------------------------------------------------------------===//
992// OpWithRefineTypeInterfaceOp
993//===----------------------------------------------------------------------===//
994
995// TODO: We should be able to only define either inferReturnType or
996// refineReturnType, currently only refineReturnType can be omitted.
997LogicalResult OpWithRefineTypeInterfaceOp::inferReturnTypes(
998 MLIRContext *context, std::optional<Location> location, ValueRange operands,
999 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
1000 SmallVectorImpl<Type> &returnTypes) {
1001 returnTypes.clear();
1002 return OpWithRefineTypeInterfaceOp::refineReturnTypes(
1003 context, location, operands, attributes, properties, regions,
1004 returnTypes);
1005}
1006
1007LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes(
1008 MLIRContext *, std::optional<Location> location, ValueRange operands,
1009 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
1010 SmallVectorImpl<Type> &returnTypes) {
1011 if (operands[0].getType() != operands[1].getType()) {
1012 return emitOptionalError(loc: location, args: "operand type mismatch ",
1013 args: operands[0].getType(), args: " vs ",
1014 args: operands[1].getType());
1015 }
1016 // TODO: Add helper to make this more concise to write.
1017 if (returnTypes.empty())
1018 returnTypes.resize(N: 1, NV: nullptr);
1019 if (returnTypes[0] && returnTypes[0] != operands[0].getType())
1020 return emitOptionalError(loc: location,
1021 args: "required first operand and result to match");
1022 returnTypes[0] = operands[0].getType();
1023 return success();
1024}
1025
1026//===----------------------------------------------------------------------===//
1027// OpWithShapedTypeInferTypeAdaptorInterfaceOp
1028//===----------------------------------------------------------------------===//
1029
1030LogicalResult
1031OpWithShapedTypeInferTypeAdaptorInterfaceOp::inferReturnTypeComponents(
1032 MLIRContext *context, std::optional<Location> location,
1033 OpWithShapedTypeInferTypeAdaptorInterfaceOp::Adaptor adaptor,
1034 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1035 // Create return type consisting of the last element of the first operand.
1036 auto operandType = adaptor.getOperand1().getType();
1037 auto sval = dyn_cast<ShapedType>(Val&: operandType);
1038 if (!sval)
1039 return emitOptionalError(loc: location, args: "only shaped type operands allowed");
1040 int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic;
1041 auto type = IntegerType::get(context, width: 17);
1042
1043 Attribute encoding;
1044 if (auto rankedTy = dyn_cast<RankedTensorType>(Val&: sval))
1045 encoding = rankedTy.getEncoding();
1046 inferredReturnShapes.push_back(Elt: ShapedTypeComponents({dim}, type, encoding));
1047 return success();
1048}
1049
1050LogicalResult
1051OpWithShapedTypeInferTypeAdaptorInterfaceOp::reifyReturnTypeShapes(
1052 OpBuilder &builder, ValueRange operands,
1053 llvm::SmallVectorImpl<Value> &shapes) {
1054 shapes = SmallVector<Value, 1>{
1055 builder.createOrFold<tensor::DimOp>(location: getLoc(), args: operands.front(), args: 0)};
1056 return success();
1057}
1058
1059//===----------------------------------------------------------------------===//
1060// TestOpWithPropertiesAndInferredType
1061//===----------------------------------------------------------------------===//
1062
1063LogicalResult TestOpWithPropertiesAndInferredType::inferReturnTypes(
1064 MLIRContext *context, std::optional<Location>, ValueRange operands,
1065 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
1066 SmallVectorImpl<Type> &inferredReturnTypes) {
1067
1068 Adaptor adaptor(operands, attributes, properties, regions);
1069 inferredReturnTypes.push_back(Elt: IntegerType::get(
1070 context, width: adaptor.getLhs() + adaptor.getProperties().rhs));
1071 return success();
1072}
1073
1074//===----------------------------------------------------------------------===//
1075// LoopBlockOp
1076//===----------------------------------------------------------------------===//
1077
1078void LoopBlockOp::getSuccessorRegions(
1079 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1080 regions.emplace_back(Args: &getBody(), Args: getBody().getArguments());
1081 if (point.isParent())
1082 return;
1083
1084 regions.emplace_back(Args: (*this)->getResults());
1085}
1086
1087OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) {
1088 assert(point == getBody());
1089 return MutableOperandRange(getInitMutable());
1090}
1091
1092//===----------------------------------------------------------------------===//
1093// LoopBlockTerminatorOp
1094//===----------------------------------------------------------------------===//
1095
1096MutableOperandRange
1097LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point) {
1098 if (point.isParent())
1099 return getExitArgMutable();
1100 return getNextIterArgMutable();
1101}
1102
1103//===----------------------------------------------------------------------===//
// TestNoTerminatorOp
1105//===----------------------------------------------------------------------===//
1106
1107void TestNoTerminatorOp::getSuccessorRegions(
1108 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {}
1109
1110//===----------------------------------------------------------------------===//
// ManualCppOpWithFold
1112//===----------------------------------------------------------------------===//
1113
1114OpFoldResult ManualCppOpWithFold::fold(ArrayRef<Attribute> attributes) {
1115 // Just a simple fold for testing purposes that reads an operands constant
1116 // value and returns it.
1117 if (!attributes.empty())
1118 return attributes.front();
1119 return nullptr;
1120}
1121
1122//===----------------------------------------------------------------------===//
1123// Tensor/Buffer Ops
1124//===----------------------------------------------------------------------===//
1125
1126void ReadBufferOp::getEffects(
1127 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1128 &effects) {
1129 // The buffer operand is read.
1130 effects.emplace_back(Args: MemoryEffects::Read::get(), Args: &getBufferMutable(),
1131 Args: SideEffects::DefaultResource::get());
1132 // The buffer contents are dumped.
1133 effects.emplace_back(Args: MemoryEffects::Write::get(),
1134 Args: SideEffects::DefaultResource::get());
1135}
1136
1137//===----------------------------------------------------------------------===//
1138// Test Dataflow
1139//===----------------------------------------------------------------------===//
1140
1141//===----------------------------------------------------------------------===//
1142// TestCallAndStoreOp
1143//===----------------------------------------------------------------------===//
1144
1145CallInterfaceCallable TestCallAndStoreOp::getCallableForCallee() {
1146 return getCallee();
1147}
1148
1149void TestCallAndStoreOp::setCalleeFromCallable(CallInterfaceCallable callee) {
1150 setCalleeAttr(cast<SymbolRefAttr>(Val&: callee));
1151}
1152
1153Operation::operand_range TestCallAndStoreOp::getArgOperands() {
1154 return getCalleeOperands();
1155}
1156
1157MutableOperandRange TestCallAndStoreOp::getArgOperandsMutable() {
1158 return getCalleeOperandsMutable();
1159}
1160
1161//===----------------------------------------------------------------------===//
1162// TestCallOnDeviceOp
1163//===----------------------------------------------------------------------===//
1164
1165CallInterfaceCallable TestCallOnDeviceOp::getCallableForCallee() {
1166 return getCallee();
1167}
1168
1169void TestCallOnDeviceOp::setCalleeFromCallable(CallInterfaceCallable callee) {
1170 setCalleeAttr(cast<SymbolRefAttr>(Val&: callee));
1171}
1172
1173Operation::operand_range TestCallOnDeviceOp::getArgOperands() {
1174 return getForwardedOperands();
1175}
1176
1177MutableOperandRange TestCallOnDeviceOp::getArgOperandsMutable() {
1178 return getForwardedOperandsMutable();
1179}
1180
1181//===----------------------------------------------------------------------===//
1182// TestStoreWithARegion
1183//===----------------------------------------------------------------------===//
1184
1185void TestStoreWithARegion::getSuccessorRegions(
1186 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1187 if (point.isParent())
1188 regions.emplace_back(Args: &getBody(), Args: getBody().front().getArguments());
1189 else
1190 regions.emplace_back();
1191}
1192
1193//===----------------------------------------------------------------------===//
1194// TestStoreWithALoopRegion
1195//===----------------------------------------------------------------------===//
1196
1197void TestStoreWithALoopRegion::getSuccessorRegions(
1198 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1199 // Both the operation itself and the region may be branching into the body or
1200 // back into the operation itself. It is possible for the operation not to
1201 // enter the body.
1202 regions.emplace_back(
1203 Args: RegionSuccessor(&getBody(), getBody().front().getArguments()));
1204 regions.emplace_back();
1205}
1206
1207//===----------------------------------------------------------------------===//
1208// TestVersionedOpA
1209//===----------------------------------------------------------------------===//
1210
1211LogicalResult
1212TestVersionedOpA::readProperties(mlir::DialectBytecodeReader &reader,
1213 mlir::OperationState &state) {
1214 auto &prop = state.getOrAddProperties<Properties>();
1215 if (mlir::failed(Result: reader.readAttribute(result&: prop.dims)))
1216 return mlir::failure();
1217
1218 // Check if we have a version. If not, assume we are parsing the current
1219 // version.
1220 auto maybeVersion = reader.getDialectVersion<test::TestDialect>();
1221 if (succeeded(Result: maybeVersion)) {
1222 // If version is less than 2.0, there is no additional attribute to parse.
1223 // We can materialize missing properties post parsing before verification.
1224 const auto *version =
1225 reinterpret_cast<const TestDialectVersion *>(*maybeVersion);
1226 if ((version->major_ < 2)) {
1227 return success();
1228 }
1229 }
1230
1231 if (mlir::failed(Result: reader.readAttribute(result&: prop.modifier)))
1232 return mlir::failure();
1233 return mlir::success();
1234}
1235
1236void TestVersionedOpA::writeProperties(mlir::DialectBytecodeWriter &writer) {
1237 auto &prop = getProperties();
1238 writer.writeAttribute(attr: prop.dims);
1239
1240 auto maybeVersion = writer.getDialectVersion<test::TestDialect>();
1241 if (succeeded(Result: maybeVersion)) {
1242 // If version is less than 2.0, there is no additional attribute to write.
1243 const auto *version =
1244 reinterpret_cast<const TestDialectVersion *>(*maybeVersion);
1245 if ((version->major_ < 2)) {
1246 llvm::outs() << "downgrading op properties...\n";
1247 return;
1248 }
1249 }
1250 writer.writeAttribute(attr: prop.modifier);
1251}
1252
1253//===----------------------------------------------------------------------===//
1254// TestOpWithVersionedProperties
1255//===----------------------------------------------------------------------===//
1256
1257llvm::LogicalResult TestOpWithVersionedProperties::readFromMlirBytecode(
1258 mlir::DialectBytecodeReader &reader, test::VersionedProperties &prop) {
1259 uint64_t value1, value2 = 0;
1260 if (failed(Result: reader.readVarInt(result&: value1)))
1261 return failure();
1262
1263 // Check if we have a version. If not, assume we are parsing the current
1264 // version.
1265 auto maybeVersion = reader.getDialectVersion<test::TestDialect>();
1266 bool needToParseAnotherInt = true;
1267 if (succeeded(Result: maybeVersion)) {
1268 // If version is less than 2.0, there is no additional attribute to parse.
1269 // We can materialize missing properties post parsing before verification.
1270 const auto *version =
1271 reinterpret_cast<const TestDialectVersion *>(*maybeVersion);
1272 if ((version->major_ < 2))
1273 needToParseAnotherInt = false;
1274 }
1275 if (needToParseAnotherInt && failed(Result: reader.readVarInt(result&: value2)))
1276 return failure();
1277
1278 prop.value1 = value1;
1279 prop.value2 = value2;
1280 return success();
1281}
1282
1283void TestOpWithVersionedProperties::writeToMlirBytecode(
1284 mlir::DialectBytecodeWriter &writer,
1285 const test::VersionedProperties &prop) {
1286 writer.writeVarInt(value: prop.value1);
1287 writer.writeVarInt(value: prop.value2);
1288}
1289
1290//===----------------------------------------------------------------------===//
1291// TestMultiSlotAlloca
1292//===----------------------------------------------------------------------===//
1293
1294llvm::SmallVector<MemorySlot> TestMultiSlotAlloca::getPromotableSlots() {
1295 SmallVector<MemorySlot> slots;
1296 for (Value result : getResults()) {
1297 slots.push_back(Elt: MemorySlot{
1298 .ptr: result, .elemType: cast<MemRefType>(Val: result.getType()).getElementType()});
1299 }
1300 return slots;
1301}
1302
1303Value TestMultiSlotAlloca::getDefaultValue(const MemorySlot &slot,
1304 OpBuilder &builder) {
1305 return builder.create<TestOpConstant>(location: getLoc(), args: slot.elemType,
1306 args: builder.getI32IntegerAttr(value: 42));
1307}
1308
1309void TestMultiSlotAlloca::handleBlockArgument(const MemorySlot &slot,
1310 BlockArgument argument,
1311 OpBuilder &builder) {
1312 // Not relevant for testing.
1313}
1314
1315/// Creates a new TestMultiSlotAlloca operation, just without the `slot`.
1316static std::optional<TestMultiSlotAlloca>
1317createNewMultiAllocaWithoutSlot(const MemorySlot &slot, OpBuilder &builder,
1318 TestMultiSlotAlloca oldOp) {
1319
1320 if (oldOp.getNumResults() == 1) {
1321 oldOp.erase();
1322 return std::nullopt;
1323 }
1324
1325 SmallVector<Type> newTypes;
1326 SmallVector<Value> remainingValues;
1327
1328 for (Value oldResult : oldOp.getResults()) {
1329 if (oldResult == slot.ptr)
1330 continue;
1331 remainingValues.push_back(Elt: oldResult);
1332 newTypes.push_back(Elt: oldResult.getType());
1333 }
1334
1335 OpBuilder::InsertionGuard guard(builder);
1336 builder.setInsertionPoint(oldOp);
1337 auto replacement =
1338 builder.create<TestMultiSlotAlloca>(location: oldOp->getLoc(), args&: newTypes);
1339 for (auto [oldResult, newResult] :
1340 llvm::zip_equal(t&: remainingValues, u: replacement.getResults()))
1341 oldResult.replaceAllUsesWith(newValue: newResult);
1342
1343 oldOp.erase();
1344 return replacement;
1345}
1346
1347std::optional<PromotableAllocationOpInterface>
1348TestMultiSlotAlloca::handlePromotionComplete(const MemorySlot &slot,
1349 Value defaultValue,
1350 OpBuilder &builder) {
1351 if (defaultValue && defaultValue.use_empty())
1352 defaultValue.getDefiningOp()->erase();
1353 return createNewMultiAllocaWithoutSlot(slot, builder, oldOp: *this);
1354}
1355
1356SmallVector<DestructurableMemorySlot>
1357TestMultiSlotAlloca::getDestructurableSlots() {
1358 SmallVector<DestructurableMemorySlot> slots;
1359 for (Value result : getResults()) {
1360 auto memrefType = cast<MemRefType>(Val: result.getType());
1361 auto destructurable = dyn_cast<DestructurableTypeInterface>(Val&: memrefType);
1362 if (!destructurable)
1363 continue;
1364
1365 std::optional<DenseMap<Attribute, Type>> destructuredType =
1366 destructurable.getSubelementIndexMap();
1367 if (!destructuredType)
1368 continue;
1369 slots.emplace_back(
1370 Args: DestructurableMemorySlot{{.ptr: result, .elemType: memrefType}, .subelementTypes: *destructuredType});
1371 }
1372 return slots;
1373}
1374
1375DenseMap<Attribute, MemorySlot> TestMultiSlotAlloca::destructure(
1376 const DestructurableMemorySlot &slot,
1377 const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
1378 SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) {
1379 OpBuilder::InsertionGuard guard(builder);
1380 builder.setInsertionPointAfter(*this);
1381
1382 DenseMap<Attribute, MemorySlot> slotMap;
1383
1384 for (Attribute usedIndex : usedIndices) {
1385 Type elemType = slot.subelementTypes.lookup(Val: usedIndex);
1386 MemRefType elemPtr = MemRefType::get(shape: {}, elementType: elemType);
1387 auto subAlloca = builder.create<TestMultiSlotAlloca>(location: getLoc(), args&: elemPtr);
1388 newAllocators.push_back(Elt: subAlloca);
1389 slotMap.try_emplace<MemorySlot>(Key: usedIndex,
1390 Args: {.ptr: subAlloca.getResult(i: 0), .elemType: elemType});
1391 }
1392
1393 return slotMap;
1394}
1395
1396std::optional<DestructurableAllocationOpInterface>
1397TestMultiSlotAlloca::handleDestructuringComplete(
1398 const DestructurableMemorySlot &slot, OpBuilder &builder) {
1399 return createNewMultiAllocaWithoutSlot(slot, builder, oldOp: *this);
1400}
1401
1402::mlir::LogicalResult test::TestDummyTensorOp::bufferize(
1403 ::mlir::RewriterBase &rewriter,
1404 const ::mlir::bufferization::BufferizationOptions &options,
1405 ::mlir::bufferization::BufferizationState &state) {
1406 auto buffer =
1407 mlir::bufferization::getBuffer(rewriter, value: getInput(), options, state);
1408 if (mlir::failed(Result: buffer))
1409 return failure();
1410
1411 const auto outType = getOutput().getType();
1412 const auto bufferizedOutType = test::TestMemrefType::get(
1413 context: getContext(), shape: outType.getShape(), elementType: outType.getElementType(), memSpace: nullptr);
1414 // replace op with memref analogy
1415 auto dummyMemrefOp = rewriter.create<test::TestDummyMemrefOp>(
1416 location: getLoc(), args: bufferizedOutType, args&: *buffer);
1417
1418 mlir::bufferization::replaceOpWithBufferizedValues(rewriter, op: getOperation(),
1419 values: dummyMemrefOp.getResult());
1420
1421 return mlir::success();
1422}
1423
1424::mlir::LogicalResult test::TestCreateTensorOp::bufferize(
1425 ::mlir::RewriterBase &rewriter,
1426 const ::mlir::bufferization::BufferizationOptions &options,
1427 ::mlir::bufferization::BufferizationState &state) {
1428 // Note: mlir::bufferization::getBufferType() would internally call
1429 // TestCreateTensorOp::getBufferType()
1430 const auto bufferizedOutType =
1431 mlir::bufferization::getBufferType(value: getOutput(), options, state);
1432 if (mlir::failed(Result: bufferizedOutType))
1433 return failure();
1434
1435 // replace op with memref analogy
1436 auto createMemrefOp =
1437 rewriter.create<test::TestCreateMemrefOp>(location: getLoc(), args: *bufferizedOutType);
1438
1439 mlir::bufferization::replaceOpWithBufferizedValues(
1440 rewriter, op: getOperation(), values: createMemrefOp.getResult());
1441
1442 return mlir::success();
1443}
1444
1445mlir::FailureOr<mlir::bufferization::BufferLikeType>
1446test::TestCreateTensorOp::getBufferType(
1447 mlir::Value value, const mlir::bufferization::BufferizationOptions &,
1448 const mlir::bufferization::BufferizationState &,
1449 llvm::SmallVector<::mlir::Value> &) {
1450 const auto type = dyn_cast<test::TestTensorType>(Val: value.getType());
1451 if (type == nullptr)
1452 return failure();
1453
1454 return cast<mlir::bufferization::BufferLikeType>(Val: test::TestMemrefType::get(
1455 context: getContext(), shape: type.getShape(), elementType: type.getElementType(), memSpace: nullptr));
1456}
1457

source code of mlir/test/lib/Dialect/Test/TestOpDefs.cpp