1 | //===- TestOpDefs.cpp - MLIR Test Dialect Operation Hooks -----------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "TestDialect.h" |
10 | #include "TestOps.h" |
11 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
12 | #include "mlir/IR/Verifier.h" |
13 | #include "mlir/Interfaces/FunctionImplementation.h" |
14 | |
15 | using namespace mlir; |
16 | using namespace test; |
17 | |
18 | //===----------------------------------------------------------------------===// |
19 | // TestBranchOp |
20 | //===----------------------------------------------------------------------===// |
21 | |
22 | SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) { |
23 | assert(index == 0 && "invalid successor index" ); |
24 | return SuccessorOperands(getTargetOperandsMutable()); |
25 | } |
26 | |
27 | //===----------------------------------------------------------------------===// |
28 | // TestProducingBranchOp |
29 | //===----------------------------------------------------------------------===// |
30 | |
31 | SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) { |
32 | assert(index <= 1 && "invalid successor index" ); |
33 | if (index == 1) |
34 | return SuccessorOperands(getFirstOperandsMutable()); |
35 | return SuccessorOperands(getSecondOperandsMutable()); |
36 | } |
37 | |
38 | //===----------------------------------------------------------------------===// |
39 | // TestInternalBranchOp |
40 | //===----------------------------------------------------------------------===// |
41 | |
42 | SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) { |
43 | assert(index <= 1 && "invalid successor index" ); |
44 | if (index == 0) |
45 | return SuccessorOperands(0, getSuccessOperandsMutable()); |
46 | return SuccessorOperands(1, getErrorOperandsMutable()); |
47 | } |
48 | |
49 | //===----------------------------------------------------------------------===// |
50 | // TestCallOp |
51 | //===----------------------------------------------------------------------===// |
52 | |
53 | LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
54 | // Check that the callee attribute was specified. |
55 | auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee" ); |
56 | if (!fnAttr) |
57 | return emitOpError("requires a 'callee' symbol reference attribute" ); |
58 | if (!symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(*this, fnAttr)) |
59 | return emitOpError() << "'" << fnAttr.getValue() |
60 | << "' does not reference a valid function" ; |
61 | return success(); |
62 | } |
63 | |
64 | //===----------------------------------------------------------------------===// |
65 | // FoldToCallOp |
66 | //===----------------------------------------------------------------------===// |
67 | |
68 | namespace { |
69 | struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { |
70 | using OpRewritePattern<FoldToCallOp>::OpRewritePattern; |
71 | |
72 | LogicalResult matchAndRewrite(FoldToCallOp op, |
73 | PatternRewriter &rewriter) const override { |
74 | rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(), |
75 | op.getCalleeAttr(), ValueRange()); |
76 | return success(); |
77 | } |
78 | }; |
79 | } // namespace |
80 | |
81 | void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results, |
82 | MLIRContext *context) { |
83 | results.add<FoldToCallOpPattern>(context); |
84 | } |
85 | |
86 | //===----------------------------------------------------------------------===// |
87 | // IsolatedRegionOp - test parsing passthrough operands |
88 | //===----------------------------------------------------------------------===// |
89 | |
90 | ParseResult IsolatedRegionOp::parse(OpAsmParser &parser, |
91 | OperationState &result) { |
92 | // Parse the input operand. |
93 | OpAsmParser::Argument argInfo; |
94 | argInfo.type = parser.getBuilder().getIndexType(); |
95 | if (parser.parseOperand(argInfo.ssaName) || |
96 | parser.resolveOperand(argInfo.ssaName, argInfo.type, result.operands)) |
97 | return failure(); |
98 | |
99 | // Parse the body region, and reuse the operand info as the argument info. |
100 | Region *body = result.addRegion(); |
101 | return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/true); |
102 | } |
103 | |
104 | void IsolatedRegionOp::print(OpAsmPrinter &p) { |
105 | p << ' '; |
106 | p.printOperand(getOperand()); |
107 | p.shadowRegionArgs(getRegion(), getOperand()); |
108 | p << ' '; |
109 | p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); |
110 | } |
111 | |
112 | //===----------------------------------------------------------------------===// |
113 | // SSACFGRegionOp |
114 | //===----------------------------------------------------------------------===// |
115 | |
116 | RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { |
117 | return RegionKind::SSACFG; |
118 | } |
119 | |
120 | //===----------------------------------------------------------------------===// |
121 | // GraphRegionOp |
122 | //===----------------------------------------------------------------------===// |
123 | |
124 | RegionKind GraphRegionOp::getRegionKind(unsigned index) { |
125 | return RegionKind::Graph; |
126 | } |
127 | |
128 | //===----------------------------------------------------------------------===// |
129 | // AffineScopeOp |
130 | //===----------------------------------------------------------------------===// |
131 | |
132 | ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) { |
133 | // Parse the body region, and reuse the operand info as the argument info. |
134 | Region *body = result.addRegion(); |
135 | return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); |
136 | } |
137 | |
138 | void AffineScopeOp::print(OpAsmPrinter &p) { |
139 | p << " " ; |
140 | p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); |
141 | } |
142 | |
143 | //===----------------------------------------------------------------------===// |
144 | // TestRemoveOpWithInnerOps |
145 | //===----------------------------------------------------------------------===// |
146 | |
147 | namespace { |
148 | struct TestRemoveOpWithInnerOps |
149 | : public OpRewritePattern<TestOpWithRegionPattern> { |
150 | using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; |
151 | |
152 | void initialize() { setDebugName("TestRemoveOpWithInnerOps" ); } |
153 | |
154 | LogicalResult matchAndRewrite(TestOpWithRegionPattern op, |
155 | PatternRewriter &rewriter) const override { |
156 | rewriter.eraseOp(op: op); |
157 | return success(); |
158 | } |
159 | }; |
160 | } // namespace |
161 | |
162 | //===----------------------------------------------------------------------===// |
163 | // TestOpWithRegionPattern |
164 | //===----------------------------------------------------------------------===// |
165 | |
166 | void TestOpWithRegionPattern::getCanonicalizationPatterns( |
167 | RewritePatternSet &results, MLIRContext *context) { |
168 | results.add<TestRemoveOpWithInnerOps>(context); |
169 | } |
170 | |
171 | //===----------------------------------------------------------------------===// |
172 | // TestOpWithRegionFold |
173 | //===----------------------------------------------------------------------===// |
174 | |
175 | OpFoldResult TestOpWithRegionFold::fold(FoldAdaptor adaptor) { |
176 | return getOperand(); |
177 | } |
178 | |
179 | //===----------------------------------------------------------------------===// |
180 | // TestOpConstant |
181 | //===----------------------------------------------------------------------===// |
182 | |
183 | OpFoldResult TestOpConstant::fold(FoldAdaptor adaptor) { return getValue(); } |
184 | |
185 | //===----------------------------------------------------------------------===// |
186 | // TestOpWithVariadicResultsAndFolder |
187 | //===----------------------------------------------------------------------===// |
188 | |
189 | LogicalResult TestOpWithVariadicResultsAndFolder::fold( |
190 | FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult> &results) { |
191 | for (Value input : this->getOperands()) { |
192 | results.push_back(input); |
193 | } |
194 | return success(); |
195 | } |
196 | |
197 | //===----------------------------------------------------------------------===// |
198 | // TestOpInPlaceFold |
199 | //===----------------------------------------------------------------------===// |
200 | |
201 | OpFoldResult TestOpInPlaceFold::fold(FoldAdaptor adaptor) { |
202 | // Exercise the fact that an operation created with createOrFold should be |
203 | // allowed to access its parent block. |
204 | assert(getOperation()->getBlock() && |
205 | "expected that operation is not unlinked" ); |
206 | |
207 | if (adaptor.getOp() && !getProperties().attr) { |
208 | // The folder adds "attr" if not present. |
209 | getProperties().attr = dyn_cast_or_null<IntegerAttr>(adaptor.getOp()); |
210 | return getResult(); |
211 | } |
212 | return {}; |
213 | } |
214 | |
215 | //===----------------------------------------------------------------------===// |
216 | // OpWithInferTypeInterfaceOp |
217 | //===----------------------------------------------------------------------===// |
218 | |
219 | LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( |
220 | MLIRContext *, std::optional<Location> location, ValueRange operands, |
221 | DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, |
222 | SmallVectorImpl<Type> &inferredReturnTypes) { |
223 | if (operands[0].getType() != operands[1].getType()) { |
224 | return emitOptionalError(location, "operand type mismatch " , |
225 | operands[0].getType(), " vs " , |
226 | operands[1].getType()); |
227 | } |
228 | inferredReturnTypes.assign({operands[0].getType()}); |
229 | return success(); |
230 | } |
231 | |
232 | //===----------------------------------------------------------------------===// |
233 | // OpWithShapedTypeInferTypeInterfaceOp |
234 | //===----------------------------------------------------------------------===// |
235 | |
236 | LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( |
237 | MLIRContext *context, std::optional<Location> location, |
238 | ValueShapeRange operands, DictionaryAttr attributes, |
239 | OpaqueProperties properties, RegionRange regions, |
240 | SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { |
241 | // Create return type consisting of the last element of the first operand. |
242 | auto operandType = operands.front().getType(); |
243 | auto sval = dyn_cast<ShapedType>(operandType); |
244 | if (!sval) |
245 | return emitOptionalError(location, "only shaped type operands allowed" ); |
246 | int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic; |
247 | auto type = IntegerType::get(context, 17); |
248 | |
249 | Attribute encoding; |
250 | if (auto rankedTy = dyn_cast<RankedTensorType>(sval)) |
251 | encoding = rankedTy.getEncoding(); |
252 | inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding)); |
253 | return success(); |
254 | } |
255 | |
256 | LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( |
257 | OpBuilder &builder, ValueRange operands, |
258 | llvm::SmallVectorImpl<Value> &shapes) { |
259 | shapes = SmallVector<Value, 1>{ |
260 | builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)}; |
261 | return success(); |
262 | } |
263 | |
264 | //===----------------------------------------------------------------------===// |
265 | // OpWithResultShapeInterfaceOp |
266 | //===----------------------------------------------------------------------===// |
267 | |
268 | LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes( |
269 | OpBuilder &builder, ValueRange operands, |
270 | llvm::SmallVectorImpl<Value> &shapes) { |
271 | Location loc = getLoc(); |
272 | shapes.reserve(operands.size()); |
273 | for (Value operand : llvm::reverse(operands)) { |
274 | auto rank = cast<RankedTensorType>(operand.getType()).getRank(); |
275 | auto currShape = llvm::to_vector<4>( |
276 | llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value { |
277 | return builder.createOrFold<tensor::DimOp>(loc, operand, dim); |
278 | })); |
279 | shapes.push_back(builder.create<tensor::FromElementsOp>( |
280 | getLoc(), RankedTensorType::get({rank}, builder.getIndexType()), |
281 | currShape)); |
282 | } |
283 | return success(); |
284 | } |
285 | |
286 | //===----------------------------------------------------------------------===// |
287 | // OpWithResultShapePerDimInterfaceOp |
288 | //===----------------------------------------------------------------------===// |
289 | |
290 | LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes( |
291 | OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { |
292 | Location loc = getLoc(); |
293 | shapes.reserve(getNumOperands()); |
294 | for (Value operand : llvm::reverse(getOperands())) { |
295 | auto tensorType = cast<RankedTensorType>(operand.getType()); |
296 | auto currShape = llvm::to_vector<4>(llvm::map_range( |
297 | llvm::seq<int64_t>(0, tensorType.getRank()), |
298 | [&](int64_t dim) -> OpFoldResult { |
299 | return tensorType.isDynamicDim(dim) |
300 | ? static_cast<OpFoldResult>( |
301 | builder.createOrFold<tensor::DimOp>(loc, operand, |
302 | dim)) |
303 | : static_cast<OpFoldResult>( |
304 | builder.getIndexAttr(tensorType.getDimSize(dim))); |
305 | })); |
306 | shapes.emplace_back(std::move(currShape)); |
307 | } |
308 | return success(); |
309 | } |
310 | |
311 | //===----------------------------------------------------------------------===// |
312 | // SideEffectOp |
313 | //===----------------------------------------------------------------------===// |
314 | |
315 | namespace { |
316 | /// A test resource for side effects. |
317 | struct TestResource : public SideEffects::Resource::Base<TestResource> { |
318 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResource) |
319 | |
320 | StringRef getName() final { return "<Test>" ; } |
321 | }; |
322 | } // namespace |
323 | |
324 | void SideEffectOp::getEffects( |
325 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
326 | // Check for an effects attribute on the op instance. |
327 | ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects" ); |
328 | if (!effectsAttr) |
329 | return; |
330 | |
331 | // If there is one, it is an array of dictionary attributes that hold |
332 | // information on the effects of this operation. |
333 | for (Attribute element : effectsAttr) { |
334 | DictionaryAttr effectElement = cast<DictionaryAttr>(element); |
335 | |
336 | // Get the specific memory effect. |
337 | MemoryEffects::Effect *effect = |
338 | StringSwitch<MemoryEffects::Effect *>( |
339 | cast<StringAttr>(effectElement.get("effect" )).getValue()) |
340 | .Case("allocate" , MemoryEffects::Allocate::get()) |
341 | .Case("free" , MemoryEffects::Free::get()) |
342 | .Case("read" , MemoryEffects::Read::get()) |
343 | .Case("write" , MemoryEffects::Write::get()); |
344 | |
345 | // Check for a non-default resource to use. |
346 | SideEffects::Resource *resource = SideEffects::DefaultResource::get(); |
347 | if (effectElement.get("test_resource" )) |
348 | resource = TestResource::get(); |
349 | |
350 | // Check for a result to affect. |
351 | if (effectElement.get("on_result" )) |
352 | effects.emplace_back(effect, getResult(), resource); |
353 | else if (Attribute ref = effectElement.get("on_reference" )) |
354 | effects.emplace_back(effect, cast<SymbolRefAttr>(ref), resource); |
355 | else |
356 | effects.emplace_back(effect, resource); |
357 | } |
358 | } |
359 | |
360 | void SideEffectOp::getEffects( |
361 | SmallVectorImpl<TestEffects::EffectInstance> &effects) { |
362 | testSideEffectOpGetEffect(getOperation(), effects); |
363 | } |
364 | |
365 | //===----------------------------------------------------------------------===// |
366 | // StringAttrPrettyNameOp |
367 | //===----------------------------------------------------------------------===// |
368 | |
369 | // This op has fancy handling of its SSA result name. |
370 | ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser, |
371 | OperationState &result) { |
372 | // Add the result types. |
373 | for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) |
374 | result.addTypes(parser.getBuilder().getIntegerType(32)); |
375 | |
376 | if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) |
377 | return failure(); |
378 | |
379 | // If the attribute dictionary contains no 'names' attribute, infer it from |
380 | // the SSA name (if specified). |
381 | bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { |
382 | return attr.getName() == "names" ; |
383 | }); |
384 | |
385 | // If there was no name specified, check to see if there was a useful name |
386 | // specified in the asm file. |
387 | if (hadNames || parser.getNumResults() == 0) |
388 | return success(); |
389 | |
390 | SmallVector<StringRef, 4> names; |
391 | auto *context = result.getContext(); |
392 | |
393 | for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { |
394 | auto resultName = parser.getResultName(i); |
395 | StringRef nameStr; |
396 | if (!resultName.first.empty() && !isdigit(resultName.first[0])) |
397 | nameStr = resultName.first; |
398 | |
399 | names.push_back(nameStr); |
400 | } |
401 | |
402 | auto namesAttr = parser.getBuilder().getStrArrayAttr(names); |
403 | result.attributes.push_back({StringAttr::get(context, "names" ), namesAttr}); |
404 | return success(); |
405 | } |
406 | |
407 | void StringAttrPrettyNameOp::print(OpAsmPrinter &p) { |
408 | // Note that we only need to print the "name" attribute if the asmprinter |
409 | // result name disagrees with it. This can happen in strange cases, e.g. |
410 | // when there are conflicts. |
411 | bool namesDisagree = getNames().size() != getNumResults(); |
412 | |
413 | SmallString<32> resultNameStr; |
414 | for (size_t i = 0, e = getNumResults(); i != e && !namesDisagree; ++i) { |
415 | resultNameStr.clear(); |
416 | llvm::raw_svector_ostream tmpStream(resultNameStr); |
417 | p.printOperand(getResult(i), tmpStream); |
418 | |
419 | auto expectedName = dyn_cast<StringAttr>(getNames()[i]); |
420 | if (!expectedName || |
421 | tmpStream.str().drop_front() != expectedName.getValue()) { |
422 | namesDisagree = true; |
423 | } |
424 | } |
425 | |
426 | if (namesDisagree) |
427 | p.printOptionalAttrDictWithKeyword((*this)->getAttrs()); |
428 | else |
429 | p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names" }); |
430 | } |
431 | |
432 | // We set the SSA name in the asm syntax to the contents of the name |
433 | // attribute. |
434 | void StringAttrPrettyNameOp::getAsmResultNames( |
435 | function_ref<void(Value, StringRef)> setNameFn) { |
436 | |
437 | auto value = getNames(); |
438 | for (size_t i = 0, e = value.size(); i != e; ++i) |
439 | if (auto str = dyn_cast<StringAttr>(value[i])) |
440 | if (!str.getValue().empty()) |
441 | setNameFn(getResult(i), str.getValue()); |
442 | } |
443 | |
444 | //===----------------------------------------------------------------------===// |
445 | // CustomResultsNameOp |
446 | //===----------------------------------------------------------------------===// |
447 | |
448 | void CustomResultsNameOp::getAsmResultNames( |
449 | function_ref<void(Value, StringRef)> setNameFn) { |
450 | ArrayAttr value = getNames(); |
451 | for (size_t i = 0, e = value.size(); i != e; ++i) |
452 | if (auto str = dyn_cast<StringAttr>(value[i])) |
453 | if (!str.empty()) |
454 | setNameFn(getResult(i), str.getValue()); |
455 | } |
456 | |
457 | //===----------------------------------------------------------------------===// |
458 | // ResultTypeWithTraitOp |
459 | //===----------------------------------------------------------------------===// |
460 | |
461 | LogicalResult ResultTypeWithTraitOp::verify() { |
462 | if ((*this)->getResultTypes()[0].hasTrait<TypeTrait::TestTypeTrait>()) |
463 | return success(); |
464 | return emitError("result type should have trait 'TestTypeTrait'" ); |
465 | } |
466 | |
467 | //===----------------------------------------------------------------------===// |
468 | // AttrWithTraitOp |
469 | //===----------------------------------------------------------------------===// |
470 | |
471 | LogicalResult AttrWithTraitOp::verify() { |
472 | if (getAttr().hasTrait<AttributeTrait::TestAttrTrait>()) |
473 | return success(); |
474 | return emitError("'attr' attribute should have trait 'TestAttrTrait'" ); |
475 | } |
476 | |
477 | //===----------------------------------------------------------------------===// |
478 | // RegionIfOp |
479 | //===----------------------------------------------------------------------===// |
480 | |
481 | void RegionIfOp::print(OpAsmPrinter &p) { |
482 | p << " " ; |
483 | p.printOperands(getOperands()); |
484 | p << ": " << getOperandTypes(); |
485 | p.printArrowTypeList(getResultTypes()); |
486 | p << " then " ; |
487 | p.printRegion(getThenRegion(), |
488 | /*printEntryBlockArgs=*/true, |
489 | /*printBlockTerminators=*/true); |
490 | p << " else " ; |
491 | p.printRegion(getElseRegion(), |
492 | /*printEntryBlockArgs=*/true, |
493 | /*printBlockTerminators=*/true); |
494 | p << " join " ; |
495 | p.printRegion(getJoinRegion(), |
496 | /*printEntryBlockArgs=*/true, |
497 | /*printBlockTerminators=*/true); |
498 | } |
499 | |
500 | ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) { |
501 | SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfos; |
502 | SmallVector<Type, 2> operandTypes; |
503 | |
504 | result.regions.reserve(3); |
505 | Region *thenRegion = result.addRegion(); |
506 | Region *elseRegion = result.addRegion(); |
507 | Region *joinRegion = result.addRegion(); |
508 | |
509 | // Parse operand, type and arrow type lists. |
510 | if (parser.parseOperandList(operandInfos) || |
511 | parser.parseColonTypeList(operandTypes) || |
512 | parser.parseArrowTypeList(result.types)) |
513 | return failure(); |
514 | |
515 | // Parse all attached regions. |
516 | if (parser.parseKeyword("then" ) || parser.parseRegion(*thenRegion, {}, {}) || |
517 | parser.parseKeyword("else" ) || parser.parseRegion(*elseRegion, {}, {}) || |
518 | parser.parseKeyword("join" ) || parser.parseRegion(*joinRegion, {}, {})) |
519 | return failure(); |
520 | |
521 | return parser.resolveOperands(operandInfos, operandTypes, |
522 | parser.getCurrentLocation(), result.operands); |
523 | } |
524 | |
525 | OperandRange RegionIfOp::getEntrySuccessorOperands(RegionBranchPoint point) { |
526 | assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, point) && |
527 | "invalid region index" ); |
528 | return getOperands(); |
529 | } |
530 | |
531 | void RegionIfOp::getSuccessorRegions( |
532 | RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { |
533 | // We always branch to the join region. |
534 | if (!point.isParent()) { |
535 | if (point != getJoinRegion()) |
536 | regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs())); |
537 | else |
538 | regions.push_back(RegionSuccessor(getResults())); |
539 | return; |
540 | } |
541 | |
542 | // The then and else regions are the entry regions of this op. |
543 | regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs())); |
544 | regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs())); |
545 | } |
546 | |
547 | void RegionIfOp::getRegionInvocationBounds( |
548 | ArrayRef<Attribute> operands, |
549 | SmallVectorImpl<InvocationBounds> &invocationBounds) { |
550 | // Each region is invoked at most once. |
551 | invocationBounds.assign(/*NumElts=*/3, /*Elt=*/{0, 1}); |
552 | } |
553 | |
554 | //===----------------------------------------------------------------------===// |
555 | // AnyCondOp |
556 | //===----------------------------------------------------------------------===// |
557 | |
558 | void AnyCondOp::getSuccessorRegions(RegionBranchPoint point, |
559 | SmallVectorImpl<RegionSuccessor> ®ions) { |
560 | // The parent op branches into the only region, and the region branches back |
561 | // to the parent op. |
562 | if (point.isParent()) |
563 | regions.emplace_back(&getRegion()); |
564 | else |
565 | regions.emplace_back(getResults()); |
566 | } |
567 | |
568 | void AnyCondOp::getRegionInvocationBounds( |
569 | ArrayRef<Attribute> operands, |
570 | SmallVectorImpl<InvocationBounds> &invocationBounds) { |
571 | invocationBounds.emplace_back(1, 1); |
572 | } |
573 | |
574 | //===----------------------------------------------------------------------===// |
575 | // SingleBlockImplicitTerminatorOp |
576 | //===----------------------------------------------------------------------===// |
577 | |
578 | /// Testing the correctness of some traits. |
579 | static_assert( |
580 | llvm::is_detected<OpTrait::has_implicit_terminator_t, |
581 | SingleBlockImplicitTerminatorOp>::value, |
582 | "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp" ); |
583 | static_assert(OpTrait::hasSingleBlockImplicitTerminator< |
584 | SingleBlockImplicitTerminatorOp>::value, |
585 | "hasSingleBlockImplicitTerminator does not match " |
586 | "SingleBlockImplicitTerminatorOp" ); |
587 | |
588 | //===----------------------------------------------------------------------===// |
589 | // SingleNoTerminatorCustomAsmOp |
590 | //===----------------------------------------------------------------------===// |
591 | |
592 | ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser, |
593 | OperationState &state) { |
594 | Region *body = state.addRegion(); |
595 | if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) |
596 | return failure(); |
597 | return success(); |
598 | } |
599 | |
600 | void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) { |
601 | printer.printRegion( |
602 | getRegion(), /*printEntryBlockArgs=*/false, |
603 | // This op has a single block without terminators. But explicitly mark |
604 | // as not printing block terminators for testing. |
605 | /*printBlockTerminators=*/false); |
606 | } |
607 | |
608 | //===----------------------------------------------------------------------===// |
609 | // TestVerifiersOp |
610 | //===----------------------------------------------------------------------===// |
611 | |
612 | LogicalResult TestVerifiersOp::verify() { |
613 | if (!getRegion().hasOneBlock()) |
614 | return emitOpError("`hasOneBlock` trait hasn't been verified" ); |
615 | |
616 | Operation *definingOp = getInput().getDefiningOp(); |
617 | if (definingOp && failed(mlir::verify(definingOp))) |
618 | return emitOpError("operand hasn't been verified" ); |
619 | |
620 | // Avoid using `emitRemark(msg)` since that will trigger an infinite verifier |
621 | // loop. |
622 | mlir::emitRemark(getLoc(), "success run of verifier" ); |
623 | |
624 | return success(); |
625 | } |
626 | |
627 | LogicalResult TestVerifiersOp::verifyRegions() { |
628 | if (!getRegion().hasOneBlock()) |
629 | return emitOpError("`hasOneBlock` trait hasn't been verified" ); |
630 | |
631 | for (Block &block : getRegion()) |
632 | for (Operation &op : block) |
633 | if (failed(mlir::verify(&op))) |
634 | return emitOpError("nested op hasn't been verified" ); |
635 | |
636 | // Avoid using `emitRemark(msg)` since that will trigger an infinite verifier |
637 | // loop. |
638 | mlir::emitRemark(getLoc(), "success run of region verifier" ); |
639 | |
640 | return success(); |
641 | } |
642 | |
643 | //===----------------------------------------------------------------------===// |
644 | // Test InferIntRangeInterface |
645 | //===----------------------------------------------------------------------===// |
646 | |
647 | //===----------------------------------------------------------------------===// |
648 | // TestWithBoundsOp |
649 | |
650 | void TestWithBoundsOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
651 | SetIntRangeFn setResultRanges) { |
652 | setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()}); |
653 | } |
654 | |
655 | //===----------------------------------------------------------------------===// |
656 | // TestWithBoundsRegionOp |
657 | |
658 | ParseResult TestWithBoundsRegionOp::parse(OpAsmParser &parser, |
659 | OperationState &result) { |
660 | if (parser.parseOptionalAttrDict(result.attributes)) |
661 | return failure(); |
662 | |
663 | // Parse the input argument |
664 | OpAsmParser::Argument argInfo; |
665 | argInfo.type = parser.getBuilder().getIndexType(); |
666 | if (failed(parser.parseArgument(argInfo))) |
667 | return failure(); |
668 | |
669 | // Parse the body region, and reuse the operand info as the argument info. |
670 | Region *body = result.addRegion(); |
671 | return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/false); |
672 | } |
673 | |
674 | void TestWithBoundsRegionOp::print(OpAsmPrinter &p) { |
675 | p.printOptionalAttrDict((*this)->getAttrs()); |
676 | p << ' '; |
677 | p.printRegionArgument(getRegion().getArgument(0), /*argAttrs=*/{}, |
678 | /*omitType=*/true); |
679 | p << ' '; |
680 | p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); |
681 | } |
682 | |
683 | void TestWithBoundsRegionOp::inferResultRanges( |
684 | ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) { |
685 | Value arg = getRegion().getArgument(0); |
686 | setResultRanges(arg, {getUmin(), getUmax(), getSmin(), getSmax()}); |
687 | } |
688 | |
689 | //===----------------------------------------------------------------------===// |
690 | // TestIncrementOp |
691 | |
692 | void TestIncrementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
693 | SetIntRangeFn setResultRanges) { |
694 | const ConstantIntRanges &range = argRanges[0]; |
695 | APInt one(range.umin().getBitWidth(), 1); |
696 | setResultRanges(getResult(), |
697 | {range.umin().uadd_sat(one), range.umax().uadd_sat(one), |
698 | range.smin().sadd_sat(one), range.smax().sadd_sat(one)}); |
699 | } |
700 | |
701 | //===----------------------------------------------------------------------===// |
702 | // TestReflectBoundsOp |
703 | |
704 | void TestReflectBoundsOp::inferResultRanges( |
705 | ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) { |
706 | const ConstantIntRanges &range = argRanges[0]; |
707 | MLIRContext *ctx = getContext(); |
708 | Builder b(ctx); |
709 | setUminAttr(b.getIndexAttr(range.umin().getZExtValue())); |
710 | setUmaxAttr(b.getIndexAttr(range.umax().getZExtValue())); |
711 | setSminAttr(b.getIndexAttr(range.smin().getSExtValue())); |
712 | setSmaxAttr(b.getIndexAttr(range.smax().getSExtValue())); |
713 | setResultRanges(getResult(), range); |
714 | } |
715 | |
716 | //===----------------------------------------------------------------------===// |
717 | // ConversionFuncOp |
718 | //===----------------------------------------------------------------------===// |
719 | |
720 | ParseResult ConversionFuncOp::parse(OpAsmParser &parser, |
721 | OperationState &result) { |
722 | auto buildFuncType = |
723 | [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, |
724 | function_interface_impl::VariadicFlag, |
725 | std::string &) { return builder.getFunctionType(argTypes, results); }; |
726 | |
727 | return function_interface_impl::parseFunctionOp( |
728 | parser, result, /*allowVariadic=*/false, |
729 | getFunctionTypeAttrName(result.name), buildFuncType, |
730 | getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); |
731 | } |
732 | |
733 | void ConversionFuncOp::print(OpAsmPrinter &p) { |
734 | function_interface_impl::printFunctionOp( |
735 | p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), |
736 | getArgAttrsAttrName(), getResAttrsAttrName()); |
737 | } |
738 | |
739 | //===----------------------------------------------------------------------===// |
740 | // ReifyBoundOp |
741 | //===----------------------------------------------------------------------===// |
742 | |
743 | mlir::presburger::BoundType ReifyBoundOp::getBoundType() { |
744 | if (getType() == "EQ" ) |
745 | return mlir::presburger::BoundType::EQ; |
746 | if (getType() == "LB" ) |
747 | return mlir::presburger::BoundType::LB; |
748 | if (getType() == "UB" ) |
749 | return mlir::presburger::BoundType::UB; |
750 | llvm_unreachable("invalid bound type" ); |
751 | } |
752 | |
753 | LogicalResult ReifyBoundOp::verify() { |
754 | if (isa<ShapedType>(getVar().getType())) { |
755 | if (!getDim().has_value()) |
756 | return emitOpError("expected 'dim' attribute for shaped type variable" ); |
757 | } else if (getVar().getType().isIndex()) { |
758 | if (getDim().has_value()) |
759 | return emitOpError("unexpected 'dim' attribute for index variable" ); |
760 | } else { |
761 | return emitOpError("expected index-typed variable or shape type variable" ); |
762 | } |
763 | if (getConstant() && getScalable()) |
764 | return emitOpError("'scalable' and 'constant' are mutually exlusive" ); |
765 | if (getScalable() != getVscaleMin().has_value()) |
766 | return emitOpError("expected 'vscale_min' if and only if 'scalable'" ); |
767 | if (getScalable() != getVscaleMax().has_value()) |
768 | return emitOpError("expected 'vscale_min' if and only if 'scalable'" ); |
769 | return success(); |
770 | } |
771 | |
772 | ValueBoundsConstraintSet::Variable ReifyBoundOp::getVariable() { |
773 | if (getDim().has_value()) |
774 | return ValueBoundsConstraintSet::Variable(getVar(), *getDim()); |
775 | return ValueBoundsConstraintSet::Variable(getVar()); |
776 | } |
777 | |
778 | //===----------------------------------------------------------------------===// |
779 | // CompareOp |
780 | //===----------------------------------------------------------------------===// |
781 | |
782 | ValueBoundsConstraintSet::ComparisonOperator |
783 | CompareOp::getComparisonOperator() { |
784 | if (getCmp() == "EQ" ) |
785 | return ValueBoundsConstraintSet::ComparisonOperator::EQ; |
786 | if (getCmp() == "LT" ) |
787 | return ValueBoundsConstraintSet::ComparisonOperator::LT; |
788 | if (getCmp() == "LE" ) |
789 | return ValueBoundsConstraintSet::ComparisonOperator::LE; |
790 | if (getCmp() == "GT" ) |
791 | return ValueBoundsConstraintSet::ComparisonOperator::GT; |
792 | if (getCmp() == "GE" ) |
793 | return ValueBoundsConstraintSet::ComparisonOperator::GE; |
794 | llvm_unreachable("invalid comparison operator" ); |
795 | } |
796 | |
797 | mlir::ValueBoundsConstraintSet::Variable CompareOp::getLhs() { |
798 | if (!getLhsMap()) |
799 | return ValueBoundsConstraintSet::Variable(getVarOperands()[0]); |
800 | SmallVector<Value> mapOperands( |
801 | getVarOperands().slice(0, getLhsMap()->getNumInputs())); |
802 | return ValueBoundsConstraintSet::Variable(*getLhsMap(), mapOperands); |
803 | } |
804 | |
805 | mlir::ValueBoundsConstraintSet::Variable CompareOp::getRhs() { |
806 | int64_t rhsOperandsBegin = getLhsMap() ? getLhsMap()->getNumInputs() : 1; |
807 | if (!getRhsMap()) |
808 | return ValueBoundsConstraintSet::Variable( |
809 | getVarOperands()[rhsOperandsBegin]); |
810 | SmallVector<Value> mapOperands( |
811 | getVarOperands().slice(rhsOperandsBegin, getRhsMap()->getNumInputs())); |
812 | return ValueBoundsConstraintSet::Variable(*getRhsMap(), mapOperands); |
813 | } |
814 | |
815 | LogicalResult CompareOp::verify() { |
816 | if (getCompose() && (getLhsMap() || getRhsMap())) |
817 | return emitOpError( |
818 | "'compose' not supported when 'lhs_map' or 'rhs_map' is present" ); |
819 | int64_t expectedNumOperands = getLhsMap() ? getLhsMap()->getNumInputs() : 1; |
820 | expectedNumOperands += getRhsMap() ? getRhsMap()->getNumInputs() : 1; |
821 | if (getVarOperands().size() != size_t(expectedNumOperands)) |
822 | return emitOpError("expected " ) |
823 | << expectedNumOperands << " operands, but got " |
824 | << getVarOperands().size(); |
825 | return success(); |
826 | } |
827 | |
828 | //===----------------------------------------------------------------------===// |
829 | // TestOpInPlaceSelfFold |
830 | //===----------------------------------------------------------------------===// |
831 | |
832 | OpFoldResult TestOpInPlaceSelfFold::fold(FoldAdaptor adaptor) { |
833 | if (!getFolded()) { |
834 | // The folder adds the "folded" if not present. |
835 | setFolded(true); |
836 | return getResult(); |
837 | } |
838 | return {}; |
839 | } |
840 | |
841 | //===----------------------------------------------------------------------===// |
842 | // TestOpFoldWithFoldAdaptor |
843 | //===----------------------------------------------------------------------===// |
844 | |
845 | OpFoldResult TestOpFoldWithFoldAdaptor::fold(FoldAdaptor adaptor) { |
846 | int64_t sum = 0; |
847 | if (auto value = dyn_cast_or_null<IntegerAttr>(adaptor.getOp())) |
848 | sum += value.getValue().getSExtValue(); |
849 | |
850 | for (Attribute attr : adaptor.getVariadic()) |
851 | if (auto value = dyn_cast_or_null<IntegerAttr>(attr)) |
852 | sum += 2 * value.getValue().getSExtValue(); |
853 | |
854 | for (ArrayRef<Attribute> attrs : adaptor.getVarOfVar()) |
855 | for (Attribute attr : attrs) |
856 | if (auto value = dyn_cast_or_null<IntegerAttr>(attr)) |
857 | sum += 3 * value.getValue().getSExtValue(); |
858 | |
859 | sum += 4 * std::distance(adaptor.getBody().begin(), adaptor.getBody().end()); |
860 | |
861 | return IntegerAttr::get(getType(), sum); |
862 | } |
863 | |
864 | //===----------------------------------------------------------------------===// |
865 | // OpWithInferTypeAdaptorInterfaceOp |
866 | //===----------------------------------------------------------------------===// |
867 | |
868 | LogicalResult OpWithInferTypeAdaptorInterfaceOp::inferReturnTypes( |
869 | MLIRContext *, std::optional<Location> location, |
870 | OpWithInferTypeAdaptorInterfaceOp::Adaptor adaptor, |
871 | SmallVectorImpl<Type> &inferredReturnTypes) { |
872 | if (adaptor.getX().getType() != adaptor.getY().getType()) { |
873 | return emitOptionalError(location, "operand type mismatch " , |
874 | adaptor.getX().getType(), " vs " , |
875 | adaptor.getY().getType()); |
876 | } |
877 | inferredReturnTypes.assign({adaptor.getX().getType()}); |
878 | return success(); |
879 | } |
880 | |
881 | //===----------------------------------------------------------------------===// |
882 | // OpWithRefineTypeInterfaceOp |
883 | //===----------------------------------------------------------------------===// |
884 | |
885 | // TODO: We should be able to only define either inferReturnType or |
886 | // refineReturnType, currently only refineReturnType can be omitted. |
887 | LogicalResult OpWithRefineTypeInterfaceOp::inferReturnTypes( |
888 | MLIRContext *context, std::optional<Location> location, ValueRange operands, |
889 | DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, |
890 | SmallVectorImpl<Type> &returnTypes) { |
891 | returnTypes.clear(); |
892 | return OpWithRefineTypeInterfaceOp::refineReturnTypes( |
893 | context, location, operands, attributes, properties, regions, |
894 | returnTypes); |
895 | } |
896 | |
897 | LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes( |
898 | MLIRContext *, std::optional<Location> location, ValueRange operands, |
899 | DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, |
900 | SmallVectorImpl<Type> &returnTypes) { |
901 | if (operands[0].getType() != operands[1].getType()) { |
902 | return emitOptionalError(location, "operand type mismatch " , |
903 | operands[0].getType(), " vs " , |
904 | operands[1].getType()); |
905 | } |
906 | // TODO: Add helper to make this more concise to write. |
907 | if (returnTypes.empty()) |
908 | returnTypes.resize(1, nullptr); |
909 | if (returnTypes[0] && returnTypes[0] != operands[0].getType()) |
910 | return emitOptionalError(location, |
911 | "required first operand and result to match" ); |
912 | returnTypes[0] = operands[0].getType(); |
913 | return success(); |
914 | } |
915 | |
916 | //===----------------------------------------------------------------------===// |
917 | // OpWithShapedTypeInferTypeAdaptorInterfaceOp |
918 | //===----------------------------------------------------------------------===// |
919 | |
920 | LogicalResult |
921 | OpWithShapedTypeInferTypeAdaptorInterfaceOp::inferReturnTypeComponents( |
922 | MLIRContext *context, std::optional<Location> location, |
923 | OpWithShapedTypeInferTypeAdaptorInterfaceOp::Adaptor adaptor, |
924 | SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { |
925 | // Create return type consisting of the last element of the first operand. |
926 | auto operandType = adaptor.getOperand1().getType(); |
927 | auto sval = dyn_cast<ShapedType>(operandType); |
928 | if (!sval) |
929 | return emitOptionalError(location, "only shaped type operands allowed" ); |
930 | int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic; |
931 | auto type = IntegerType::get(context, 17); |
932 | |
933 | Attribute encoding; |
934 | if (auto rankedTy = dyn_cast<RankedTensorType>(sval)) |
935 | encoding = rankedTy.getEncoding(); |
936 | inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding)); |
937 | return success(); |
938 | } |
939 | |
940 | LogicalResult |
941 | OpWithShapedTypeInferTypeAdaptorInterfaceOp::reifyReturnTypeShapes( |
942 | OpBuilder &builder, ValueRange operands, |
943 | llvm::SmallVectorImpl<Value> &shapes) { |
944 | shapes = SmallVector<Value, 1>{ |
945 | builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)}; |
946 | return success(); |
947 | } |
948 | |
949 | //===----------------------------------------------------------------------===// |
950 | // TestOpWithPropertiesAndInferredType |
951 | //===----------------------------------------------------------------------===// |
952 | |
953 | LogicalResult TestOpWithPropertiesAndInferredType::inferReturnTypes( |
954 | MLIRContext *context, std::optional<Location>, ValueRange operands, |
955 | DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, |
956 | SmallVectorImpl<Type> &inferredReturnTypes) { |
957 | |
958 | Adaptor adaptor(operands, attributes, properties, regions); |
959 | inferredReturnTypes.push_back(IntegerType::get( |
960 | context, adaptor.getLhs() + adaptor.getProperties().rhs)); |
961 | return success(); |
962 | } |
963 | |
964 | //===----------------------------------------------------------------------===// |
965 | // LoopBlockOp |
966 | //===----------------------------------------------------------------------===// |
967 | |
968 | void LoopBlockOp::getSuccessorRegions( |
969 | RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { |
970 | regions.emplace_back(&getBody(), getBody().getArguments()); |
971 | if (point.isParent()) |
972 | return; |
973 | |
974 | regions.emplace_back((*this)->getResults()); |
975 | } |
976 | |
977 | OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) { |
978 | assert(point == getBody()); |
979 | return MutableOperandRange(getInitMutable()); |
980 | } |
981 | |
982 | //===----------------------------------------------------------------------===// |
983 | // LoopBlockTerminatorOp |
984 | //===----------------------------------------------------------------------===// |
985 | |
986 | MutableOperandRange |
987 | LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point) { |
988 | if (point.isParent()) |
989 | return getExitArgMutable(); |
990 | return getNextIterArgMutable(); |
991 | } |
992 | |
993 | //===----------------------------------------------------------------------===// |
// TestNoTerminatorOp
995 | //===----------------------------------------------------------------------===// |
996 | |
997 | void TestNoTerminatorOp::getSuccessorRegions( |
998 | RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {} |
999 | |
1000 | //===----------------------------------------------------------------------===// |
// ManualCppOpWithFold
1002 | //===----------------------------------------------------------------------===// |
1003 | |
1004 | OpFoldResult ManualCppOpWithFold::fold(ArrayRef<Attribute> attributes) { |
1005 | // Just a simple fold for testing purposes that reads an operands constant |
1006 | // value and returns it. |
1007 | if (!attributes.empty()) |
1008 | return attributes.front(); |
1009 | return nullptr; |
1010 | } |
1011 | |
1012 | //===----------------------------------------------------------------------===// |
1013 | // Tensor/Buffer Ops |
1014 | //===----------------------------------------------------------------------===// |
1015 | |
1016 | void ReadBufferOp::getEffects( |
1017 | SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> |
1018 | &effects) { |
1019 | // The buffer operand is read. |
1020 | effects.emplace_back(MemoryEffects::Read::get(), getBuffer(), |
1021 | SideEffects::DefaultResource::get()); |
1022 | // The buffer contents are dumped. |
1023 | effects.emplace_back(MemoryEffects::Write::get(), |
1024 | SideEffects::DefaultResource::get()); |
1025 | } |
1026 | |
1027 | //===----------------------------------------------------------------------===// |
1028 | // Test Dataflow |
1029 | //===----------------------------------------------------------------------===// |
1030 | |
1031 | //===----------------------------------------------------------------------===// |
1032 | // TestCallAndStoreOp |
1033 | |
1034 | CallInterfaceCallable TestCallAndStoreOp::getCallableForCallee() { |
1035 | return getCallee(); |
1036 | } |
1037 | |
1038 | void TestCallAndStoreOp::setCalleeFromCallable(CallInterfaceCallable callee) { |
1039 | setCalleeAttr(callee.get<SymbolRefAttr>()); |
1040 | } |
1041 | |
1042 | Operation::operand_range TestCallAndStoreOp::getArgOperands() { |
1043 | return getCalleeOperands(); |
1044 | } |
1045 | |
1046 | MutableOperandRange TestCallAndStoreOp::getArgOperandsMutable() { |
1047 | return getCalleeOperandsMutable(); |
1048 | } |
1049 | |
1050 | //===----------------------------------------------------------------------===// |
1051 | // TestCallOnDeviceOp |
1052 | |
1053 | CallInterfaceCallable TestCallOnDeviceOp::getCallableForCallee() { |
1054 | return getCallee(); |
1055 | } |
1056 | |
1057 | void TestCallOnDeviceOp::setCalleeFromCallable(CallInterfaceCallable callee) { |
1058 | setCalleeAttr(callee.get<SymbolRefAttr>()); |
1059 | } |
1060 | |
1061 | Operation::operand_range TestCallOnDeviceOp::getArgOperands() { |
1062 | return getForwardedOperands(); |
1063 | } |
1064 | |
1065 | MutableOperandRange TestCallOnDeviceOp::getArgOperandsMutable() { |
1066 | return getForwardedOperandsMutable(); |
1067 | } |
1068 | |
1069 | //===----------------------------------------------------------------------===// |
1070 | // TestStoreWithARegion |
1071 | |
1072 | void TestStoreWithARegion::getSuccessorRegions( |
1073 | RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { |
1074 | if (point.isParent()) |
1075 | regions.emplace_back(&getBody(), getBody().front().getArguments()); |
1076 | else |
1077 | regions.emplace_back(); |
1078 | } |
1079 | |
1080 | //===----------------------------------------------------------------------===// |
1081 | // TestStoreWithALoopRegion |
1082 | |
1083 | void TestStoreWithALoopRegion::getSuccessorRegions( |
1084 | RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { |
1085 | // Both the operation itself and the region may be branching into the body or |
1086 | // back into the operation itself. It is possible for the operation not to |
1087 | // enter the body. |
1088 | regions.emplace_back( |
1089 | RegionSuccessor(&getBody(), getBody().front().getArguments())); |
1090 | regions.emplace_back(); |
1091 | } |
1092 | |
1093 | //===----------------------------------------------------------------------===// |
1094 | // TestVersionedOpA |
1095 | //===----------------------------------------------------------------------===// |
1096 | |
1097 | LogicalResult |
1098 | TestVersionedOpA::readProperties(mlir::DialectBytecodeReader &reader, |
1099 | mlir::OperationState &state) { |
1100 | auto &prop = state.getOrAddProperties<Properties>(); |
1101 | if (mlir::failed(reader.readAttribute(prop.dims))) |
1102 | return mlir::failure(); |
1103 | |
1104 | // Check if we have a version. If not, assume we are parsing the current |
1105 | // version. |
1106 | auto maybeVersion = reader.getDialectVersion<test::TestDialect>(); |
1107 | if (succeeded(maybeVersion)) { |
1108 | // If version is less than 2.0, there is no additional attribute to parse. |
1109 | // We can materialize missing properties post parsing before verification. |
1110 | const auto *version = |
1111 | reinterpret_cast<const TestDialectVersion *>(*maybeVersion); |
1112 | if ((version->major_ < 2)) { |
1113 | return success(); |
1114 | } |
1115 | } |
1116 | |
1117 | if (mlir::failed(reader.readAttribute(prop.modifier))) |
1118 | return mlir::failure(); |
1119 | return mlir::success(); |
1120 | } |
1121 | |
1122 | void TestVersionedOpA::writeProperties(mlir::DialectBytecodeWriter &writer) { |
1123 | auto &prop = getProperties(); |
1124 | writer.writeAttribute(prop.dims); |
1125 | |
1126 | auto maybeVersion = writer.getDialectVersion<test::TestDialect>(); |
1127 | if (succeeded(maybeVersion)) { |
1128 | // If version is less than 2.0, there is no additional attribute to write. |
1129 | const auto *version = |
1130 | reinterpret_cast<const TestDialectVersion *>(*maybeVersion); |
1131 | if ((version->major_ < 2)) { |
1132 | llvm::outs() << "downgrading op properties...\n" ; |
1133 | return; |
1134 | } |
1135 | } |
1136 | writer.writeAttribute(prop.modifier); |
1137 | } |
1138 | |
1139 | //===----------------------------------------------------------------------===// |
1140 | // TestOpWithVersionedProperties |
1141 | //===----------------------------------------------------------------------===// |
1142 | |
1143 | mlir::LogicalResult TestOpWithVersionedProperties::readFromMlirBytecode( |
1144 | mlir::DialectBytecodeReader &reader, test::VersionedProperties &prop) { |
1145 | uint64_t value1, value2 = 0; |
1146 | if (failed(reader.readVarInt(value1))) |
1147 | return failure(); |
1148 | |
1149 | // Check if we have a version. If not, assume we are parsing the current |
1150 | // version. |
1151 | auto maybeVersion = reader.getDialectVersion<test::TestDialect>(); |
1152 | bool needToParseAnotherInt = true; |
1153 | if (succeeded(maybeVersion)) { |
1154 | // If version is less than 2.0, there is no additional attribute to parse. |
1155 | // We can materialize missing properties post parsing before verification. |
1156 | const auto *version = |
1157 | reinterpret_cast<const TestDialectVersion *>(*maybeVersion); |
1158 | if ((version->major_ < 2)) |
1159 | needToParseAnotherInt = false; |
1160 | } |
1161 | if (needToParseAnotherInt && failed(reader.readVarInt(value2))) |
1162 | return failure(); |
1163 | |
1164 | prop.value1 = value1; |
1165 | prop.value2 = value2; |
1166 | return success(); |
1167 | } |
1168 | |
1169 | void TestOpWithVersionedProperties::writeToMlirBytecode( |
1170 | mlir::DialectBytecodeWriter &writer, |
1171 | const test::VersionedProperties &prop) { |
1172 | writer.writeVarInt(prop.value1); |
1173 | writer.writeVarInt(prop.value2); |
1174 | } |
1175 | |