1 | //===- TestOpDefs.cpp - MLIR Test Dialect Operation Hooks -----------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "TestDialect.h" |
10 | #include "TestOps.h" |
11 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
12 | #include "mlir/IR/Verifier.h" |
13 | #include "mlir/Interfaces/FunctionImplementation.h" |
14 | #include "mlir/Interfaces/MemorySlotInterfaces.h" |
15 | |
16 | using namespace mlir; |
17 | using namespace test; |
18 | |
19 | //===----------------------------------------------------------------------===// |
20 | // TestBranchOp |
21 | //===----------------------------------------------------------------------===// |
22 | |
23 | SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) { |
24 | assert(index == 0 && "invalid successor index" ); |
25 | return SuccessorOperands(getTargetOperandsMutable()); |
26 | } |
27 | |
28 | //===----------------------------------------------------------------------===// |
29 | // TestProducingBranchOp |
30 | //===----------------------------------------------------------------------===// |
31 | |
32 | SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) { |
33 | assert(index <= 1 && "invalid successor index" ); |
34 | if (index == 1) |
35 | return SuccessorOperands(getFirstOperandsMutable()); |
36 | return SuccessorOperands(getSecondOperandsMutable()); |
37 | } |
38 | |
39 | //===----------------------------------------------------------------------===// |
40 | // TestInternalBranchOp |
41 | //===----------------------------------------------------------------------===// |
42 | |
43 | SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) { |
44 | assert(index <= 1 && "invalid successor index" ); |
45 | if (index == 0) |
46 | return SuccessorOperands(0, getSuccessOperandsMutable()); |
47 | return SuccessorOperands(1, getErrorOperandsMutable()); |
48 | } |
49 | |
50 | //===----------------------------------------------------------------------===// |
51 | // TestCallOp |
52 | //===----------------------------------------------------------------------===// |
53 | |
54 | LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
55 | // Check that the callee attribute was specified. |
56 | auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee" ); |
57 | if (!fnAttr) |
58 | return emitOpError("requires a 'callee' symbol reference attribute" ); |
59 | if (!symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(*this, fnAttr)) |
60 | return emitOpError() << "'" << fnAttr.getValue() |
61 | << "' does not reference a valid function" ; |
62 | return success(); |
63 | } |
64 | |
65 | //===----------------------------------------------------------------------===// |
66 | // FoldToCallOp |
67 | //===----------------------------------------------------------------------===// |
68 | |
69 | namespace { |
70 | struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { |
71 | using OpRewritePattern<FoldToCallOp>::OpRewritePattern; |
72 | |
73 | LogicalResult matchAndRewrite(FoldToCallOp op, |
74 | PatternRewriter &rewriter) const override { |
75 | rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(), |
76 | op.getCalleeAttr(), ValueRange()); |
77 | return success(); |
78 | } |
79 | }; |
80 | } // namespace |
81 | |
82 | void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results, |
83 | MLIRContext *context) { |
84 | results.add<FoldToCallOpPattern>(context); |
85 | } |
86 | |
87 | //===----------------------------------------------------------------------===// |
88 | // IsolatedRegionOp - test parsing passthrough operands |
89 | //===----------------------------------------------------------------------===// |
90 | |
91 | ParseResult IsolatedRegionOp::parse(OpAsmParser &parser, |
92 | OperationState &result) { |
93 | // Parse the input operand. |
94 | OpAsmParser::Argument argInfo; |
95 | argInfo.type = parser.getBuilder().getIndexType(); |
96 | if (parser.parseOperand(argInfo.ssaName) || |
97 | parser.resolveOperand(argInfo.ssaName, argInfo.type, result.operands)) |
98 | return failure(); |
99 | |
100 | // Parse the body region, and reuse the operand info as the argument info. |
101 | Region *body = result.addRegion(); |
102 | return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/true); |
103 | } |
104 | |
105 | void IsolatedRegionOp::print(OpAsmPrinter &p) { |
106 | p << ' '; |
107 | p.printOperand(getOperand()); |
108 | p.shadowRegionArgs(getRegion(), getOperand()); |
109 | p << ' '; |
110 | p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); |
111 | } |
112 | |
113 | //===----------------------------------------------------------------------===// |
114 | // SSACFGRegionOp |
115 | //===----------------------------------------------------------------------===// |
116 | |
117 | RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { |
118 | return RegionKind::SSACFG; |
119 | } |
120 | |
121 | //===----------------------------------------------------------------------===// |
122 | // GraphRegionOp |
123 | //===----------------------------------------------------------------------===// |
124 | |
125 | RegionKind GraphRegionOp::getRegionKind(unsigned index) { |
126 | return RegionKind::Graph; |
127 | } |
128 | |
129 | //===----------------------------------------------------------------------===// |
130 | // IsolatedGraphRegionOp |
131 | //===----------------------------------------------------------------------===// |
132 | |
133 | RegionKind IsolatedGraphRegionOp::getRegionKind(unsigned index) { |
134 | return RegionKind::Graph; |
135 | } |
136 | |
137 | //===----------------------------------------------------------------------===// |
138 | // AffineScopeOp |
139 | //===----------------------------------------------------------------------===// |
140 | |
141 | ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) { |
142 | // Parse the body region, and reuse the operand info as the argument info. |
143 | Region *body = result.addRegion(); |
144 | return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); |
145 | } |
146 | |
147 | void AffineScopeOp::print(OpAsmPrinter &p) { |
148 | p << " " ; |
149 | p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); |
150 | } |
151 | |
152 | //===----------------------------------------------------------------------===// |
153 | // TestRemoveOpWithInnerOps |
154 | //===----------------------------------------------------------------------===// |
155 | |
156 | namespace { |
157 | struct TestRemoveOpWithInnerOps |
158 | : public OpRewritePattern<TestOpWithRegionPattern> { |
159 | using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; |
160 | |
161 | void initialize() { setDebugName("TestRemoveOpWithInnerOps" ); } |
162 | |
163 | LogicalResult matchAndRewrite(TestOpWithRegionPattern op, |
164 | PatternRewriter &rewriter) const override { |
165 | rewriter.eraseOp(op: op); |
166 | return success(); |
167 | } |
168 | }; |
169 | } // namespace |
170 | |
171 | //===----------------------------------------------------------------------===// |
172 | // TestOpWithRegionPattern |
173 | //===----------------------------------------------------------------------===// |
174 | |
175 | void TestOpWithRegionPattern::getCanonicalizationPatterns( |
176 | RewritePatternSet &results, MLIRContext *context) { |
177 | results.add<TestRemoveOpWithInnerOps>(context); |
178 | } |
179 | |
180 | //===----------------------------------------------------------------------===// |
181 | // TestOpWithRegionFold |
182 | //===----------------------------------------------------------------------===// |
183 | |
184 | OpFoldResult TestOpWithRegionFold::fold(FoldAdaptor adaptor) { |
185 | return getOperand(); |
186 | } |
187 | |
188 | //===----------------------------------------------------------------------===// |
189 | // TestOpConstant |
190 | //===----------------------------------------------------------------------===// |
191 | |
192 | OpFoldResult TestOpConstant::fold(FoldAdaptor adaptor) { return getValue(); } |
193 | |
194 | //===----------------------------------------------------------------------===// |
195 | // TestOpWithVariadicResultsAndFolder |
196 | //===----------------------------------------------------------------------===// |
197 | |
198 | LogicalResult TestOpWithVariadicResultsAndFolder::fold( |
199 | FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult> &results) { |
200 | for (Value input : this->getOperands()) { |
201 | results.push_back(input); |
202 | } |
203 | return success(); |
204 | } |
205 | |
206 | //===----------------------------------------------------------------------===// |
207 | // TestOpInPlaceFold |
208 | //===----------------------------------------------------------------------===// |
209 | |
210 | OpFoldResult TestOpInPlaceFold::fold(FoldAdaptor adaptor) { |
211 | // Exercise the fact that an operation created with createOrFold should be |
212 | // allowed to access its parent block. |
213 | assert(getOperation()->getBlock() && |
214 | "expected that operation is not unlinked" ); |
215 | |
216 | if (adaptor.getOp() && !getProperties().attr) { |
217 | // The folder adds "attr" if not present. |
218 | getProperties().attr = dyn_cast_or_null<IntegerAttr>(adaptor.getOp()); |
219 | return getResult(); |
220 | } |
221 | return {}; |
222 | } |
223 | |
224 | //===----------------------------------------------------------------------===// |
225 | // OpWithInferTypeInterfaceOp |
226 | //===----------------------------------------------------------------------===// |
227 | |
228 | LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( |
229 | MLIRContext *, std::optional<Location> location, ValueRange operands, |
230 | DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, |
231 | SmallVectorImpl<Type> &inferredReturnTypes) { |
232 | if (operands[0].getType() != operands[1].getType()) { |
233 | return emitOptionalError(location, "operand type mismatch " , |
234 | operands[0].getType(), " vs " , |
235 | operands[1].getType()); |
236 | } |
237 | inferredReturnTypes.assign({operands[0].getType()}); |
238 | return success(); |
239 | } |
240 | |
241 | //===----------------------------------------------------------------------===// |
242 | // OpWithShapedTypeInferTypeInterfaceOp |
243 | //===----------------------------------------------------------------------===// |
244 | |
245 | LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( |
246 | MLIRContext *context, std::optional<Location> location, |
247 | ValueShapeRange operands, DictionaryAttr attributes, |
248 | OpaqueProperties properties, RegionRange regions, |
249 | SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { |
250 | // Create return type consisting of the last element of the first operand. |
251 | auto operandType = operands.front().getType(); |
252 | auto sval = dyn_cast<ShapedType>(operandType); |
253 | if (!sval) |
254 | return emitOptionalError(location, "only shaped type operands allowed" ); |
255 | int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic; |
256 | auto type = IntegerType::get(context, 17); |
257 | |
258 | Attribute encoding; |
259 | if (auto rankedTy = dyn_cast<RankedTensorType>(sval)) |
260 | encoding = rankedTy.getEncoding(); |
261 | inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding)); |
262 | return success(); |
263 | } |
264 | |
265 | LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( |
266 | OpBuilder &builder, ValueRange operands, |
267 | llvm::SmallVectorImpl<Value> &shapes) { |
268 | shapes = SmallVector<Value, 1>{ |
269 | builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)}; |
270 | return success(); |
271 | } |
272 | |
273 | //===----------------------------------------------------------------------===// |
274 | // OpWithResultShapeInterfaceOp |
275 | //===----------------------------------------------------------------------===// |
276 | |
277 | LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes( |
278 | OpBuilder &builder, ValueRange operands, |
279 | llvm::SmallVectorImpl<Value> &shapes) { |
280 | Location loc = getLoc(); |
281 | shapes.reserve(operands.size()); |
282 | for (Value operand : llvm::reverse(operands)) { |
283 | auto rank = cast<RankedTensorType>(operand.getType()).getRank(); |
284 | auto currShape = llvm::to_vector<4>( |
285 | llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value { |
286 | return builder.createOrFold<tensor::DimOp>(loc, operand, dim); |
287 | })); |
288 | shapes.push_back(builder.create<tensor::FromElementsOp>( |
289 | getLoc(), RankedTensorType::get({rank}, builder.getIndexType()), |
290 | currShape)); |
291 | } |
292 | return success(); |
293 | } |
294 | |
295 | //===----------------------------------------------------------------------===// |
296 | // OpWithResultShapePerDimInterfaceOp |
297 | //===----------------------------------------------------------------------===// |
298 | |
299 | LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes( |
300 | OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { |
301 | Location loc = getLoc(); |
302 | shapes.reserve(getNumOperands()); |
303 | for (Value operand : llvm::reverse(getOperands())) { |
304 | auto tensorType = cast<RankedTensorType>(operand.getType()); |
305 | auto currShape = llvm::to_vector<4>(llvm::map_range( |
306 | llvm::seq<int64_t>(0, tensorType.getRank()), |
307 | [&](int64_t dim) -> OpFoldResult { |
308 | return tensorType.isDynamicDim(dim) |
309 | ? static_cast<OpFoldResult>( |
310 | builder.createOrFold<tensor::DimOp>(loc, operand, |
311 | dim)) |
312 | : static_cast<OpFoldResult>( |
313 | builder.getIndexAttr(tensorType.getDimSize(dim))); |
314 | })); |
315 | shapes.emplace_back(std::move(currShape)); |
316 | } |
317 | return success(); |
318 | } |
319 | |
320 | //===----------------------------------------------------------------------===// |
321 | // SideEffectOp |
322 | //===----------------------------------------------------------------------===// |
323 | |
324 | namespace { |
325 | /// A test resource for side effects. |
326 | struct TestResource : public SideEffects::Resource::Base<TestResource> { |
327 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResource) |
328 | |
329 | StringRef getName() final { return "<Test>" ; } |
330 | }; |
331 | } // namespace |
332 | |
333 | void SideEffectOp::getEffects( |
334 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
335 | // Check for an effects attribute on the op instance. |
336 | ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects" ); |
337 | if (!effectsAttr) |
338 | return; |
339 | |
340 | for (Attribute element : effectsAttr) { |
341 | DictionaryAttr effectElement = cast<DictionaryAttr>(element); |
342 | |
343 | // Get the specific memory effect. |
344 | MemoryEffects::Effect *effect = |
345 | StringSwitch<MemoryEffects::Effect *>( |
346 | cast<StringAttr>(effectElement.get("effect" )).getValue()) |
347 | .Case("allocate" , MemoryEffects::Allocate::get()) |
348 | .Case("free" , MemoryEffects::Free::get()) |
349 | .Case("read" , MemoryEffects::Read::get()) |
350 | .Case("write" , MemoryEffects::Write::get()); |
351 | |
352 | // Check for a non-default resource to use. |
353 | SideEffects::Resource *resource = SideEffects::DefaultResource::get(); |
354 | if (effectElement.get("test_resource" )) |
355 | resource = TestResource::get(); |
356 | |
357 | // Check for a result to affect. |
358 | if (effectElement.get("on_result" )) |
359 | effects.emplace_back(effect, getOperation()->getOpResults()[0], resource); |
360 | else if (Attribute ref = effectElement.get("on_reference" )) |
361 | effects.emplace_back(effect, cast<SymbolRefAttr>(ref), resource); |
362 | else |
363 | effects.emplace_back(effect, resource); |
364 | } |
365 | } |
366 | |
367 | void SideEffectOp::getEffects( |
368 | SmallVectorImpl<TestEffects::EffectInstance> &effects) { |
369 | testSideEffectOpGetEffect(getOperation(), effects); |
370 | } |
371 | |
372 | void SideEffectWithRegionOp::getEffects( |
373 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
374 | // Check for an effects attribute on the op instance. |
375 | ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects" ); |
376 | if (!effectsAttr) |
377 | return; |
378 | |
379 | for (Attribute element : effectsAttr) { |
380 | DictionaryAttr effectElement = cast<DictionaryAttr>(element); |
381 | |
382 | // Get the specific memory effect. |
383 | MemoryEffects::Effect *effect = |
384 | StringSwitch<MemoryEffects::Effect *>( |
385 | cast<StringAttr>(effectElement.get("effect" )).getValue()) |
386 | .Case("allocate" , MemoryEffects::Allocate::get()) |
387 | .Case("free" , MemoryEffects::Free::get()) |
388 | .Case("read" , MemoryEffects::Read::get()) |
389 | .Case("write" , MemoryEffects::Write::get()); |
390 | |
391 | // Check for a non-default resource to use. |
392 | SideEffects::Resource *resource = SideEffects::DefaultResource::get(); |
393 | if (effectElement.get("test_resource" )) |
394 | resource = TestResource::get(); |
395 | |
396 | // Check for a result to affect. |
397 | if (effectElement.get("on_result" )) |
398 | effects.emplace_back(effect, getOperation()->getOpResults()[0], resource); |
399 | else if (effectElement.get("on_operand" )) |
400 | effects.emplace_back(effect, &getOperation()->getOpOperands()[0], |
401 | resource); |
402 | else if (effectElement.get("on_argument" )) |
403 | effects.emplace_back(effect, getOperation()->getRegion(0).getArgument(0), |
404 | resource); |
405 | else if (Attribute ref = effectElement.get("on_reference" )) |
406 | effects.emplace_back(effect, cast<SymbolRefAttr>(ref), resource); |
407 | else |
408 | effects.emplace_back(effect, resource); |
409 | } |
410 | } |
411 | |
412 | void SideEffectWithRegionOp::getEffects( |
413 | SmallVectorImpl<TestEffects::EffectInstance> &effects) { |
414 | testSideEffectOpGetEffect(getOperation(), effects); |
415 | } |
416 | |
417 | //===----------------------------------------------------------------------===// |
418 | // StringAttrPrettyNameOp |
419 | //===----------------------------------------------------------------------===// |
420 | |
421 | // This op has fancy handling of its SSA result name. |
422 | ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser, |
423 | OperationState &result) { |
424 | // Add the result types. |
425 | for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) |
426 | result.addTypes(parser.getBuilder().getIntegerType(32)); |
427 | |
428 | if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) |
429 | return failure(); |
430 | |
431 | // If the attribute dictionary contains no 'names' attribute, infer it from |
432 | // the SSA name (if specified). |
433 | bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { |
434 | return attr.getName() == "names" ; |
435 | }); |
436 | |
437 | // If there was no name specified, check to see if there was a useful name |
438 | // specified in the asm file. |
439 | if (hadNames || parser.getNumResults() == 0) |
440 | return success(); |
441 | |
442 | SmallVector<StringRef, 4> names; |
443 | auto *context = result.getContext(); |
444 | |
445 | for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { |
446 | auto resultName = parser.getResultName(i); |
447 | StringRef nameStr; |
448 | if (!resultName.first.empty() && !isdigit(resultName.first[0])) |
449 | nameStr = resultName.first; |
450 | |
451 | names.push_back(nameStr); |
452 | } |
453 | |
454 | auto namesAttr = parser.getBuilder().getStrArrayAttr(names); |
455 | result.attributes.push_back({StringAttr::get(context, "names" ), namesAttr}); |
456 | return success(); |
457 | } |
458 | |
459 | void StringAttrPrettyNameOp::print(OpAsmPrinter &p) { |
460 | // Note that we only need to print the "name" attribute if the asmprinter |
461 | // result name disagrees with it. This can happen in strange cases, e.g. |
462 | // when there are conflicts. |
463 | bool namesDisagree = getNames().size() != getNumResults(); |
464 | |
465 | SmallString<32> resultNameStr; |
466 | for (size_t i = 0, e = getNumResults(); i != e && !namesDisagree; ++i) { |
467 | resultNameStr.clear(); |
468 | llvm::raw_svector_ostream tmpStream(resultNameStr); |
469 | p.printOperand(getResult(i), tmpStream); |
470 | |
471 | auto expectedName = dyn_cast<StringAttr>(getNames()[i]); |
472 | if (!expectedName || |
473 | tmpStream.str().drop_front() != expectedName.getValue()) { |
474 | namesDisagree = true; |
475 | } |
476 | } |
477 | |
478 | if (namesDisagree) |
479 | p.printOptionalAttrDictWithKeyword((*this)->getAttrs()); |
480 | else |
481 | p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names" }); |
482 | } |
483 | |
484 | // We set the SSA name in the asm syntax to the contents of the name |
485 | // attribute. |
486 | void StringAttrPrettyNameOp::getAsmResultNames( |
487 | function_ref<void(Value, StringRef)> setNameFn) { |
488 | |
489 | auto value = getNames(); |
490 | for (size_t i = 0, e = value.size(); i != e; ++i) |
491 | if (auto str = dyn_cast<StringAttr>(value[i])) |
492 | if (!str.getValue().empty()) |
493 | setNameFn(getResult(i), str.getValue()); |
494 | } |
495 | |
496 | //===----------------------------------------------------------------------===// |
497 | // CustomResultsNameOp |
498 | //===----------------------------------------------------------------------===// |
499 | |
500 | void CustomResultsNameOp::getAsmResultNames( |
501 | function_ref<void(Value, StringRef)> setNameFn) { |
502 | ArrayAttr value = getNames(); |
503 | for (size_t i = 0, e = value.size(); i != e; ++i) |
504 | if (auto str = dyn_cast<StringAttr>(value[i])) |
505 | if (!str.empty()) |
506 | setNameFn(getResult(i), str.getValue()); |
507 | } |
508 | |
509 | //===----------------------------------------------------------------------===// |
510 | // ResultNameFromTypeOp |
511 | //===----------------------------------------------------------------------===// |
512 | |
513 | void ResultNameFromTypeOp::getAsmResultNames( |
514 | function_ref<void(Value, StringRef)> setNameFn) { |
515 | auto result = getResult(); |
516 | auto setResultNameFn = [&](::llvm::StringRef name) { |
517 | setNameFn(result, name); |
518 | }; |
519 | auto opAsmTypeInterface = |
520 | ::mlir::cast<::mlir::OpAsmTypeInterface>(result.getType()); |
521 | opAsmTypeInterface.getAsmName(setResultNameFn); |
522 | } |
523 | |
524 | //===----------------------------------------------------------------------===// |
525 | // BlockArgumentNameFromTypeOp |
526 | //===----------------------------------------------------------------------===// |
527 | |
528 | void BlockArgumentNameFromTypeOp::getAsmBlockArgumentNames( |
529 | ::mlir::Region ®ion, ::mlir::OpAsmSetValueNameFn setNameFn) { |
530 | for (auto &block : region) { |
531 | for (auto arg : block.getArguments()) { |
532 | if (auto opAsmTypeInterface = |
533 | ::mlir::dyn_cast<::mlir::OpAsmTypeInterface>(arg.getType())) { |
534 | auto setArgNameFn = [&](StringRef name) { setNameFn(arg, name); }; |
535 | opAsmTypeInterface.getAsmName(setArgNameFn); |
536 | } |
537 | } |
538 | } |
539 | } |
540 | |
541 | //===----------------------------------------------------------------------===// |
542 | // ResultTypeWithTraitOp |
543 | //===----------------------------------------------------------------------===// |
544 | |
545 | LogicalResult ResultTypeWithTraitOp::verify() { |
546 | if ((*this)->getResultTypes()[0].hasTrait<TypeTrait::TestTypeTrait>()) |
547 | return success(); |
548 | return emitError("result type should have trait 'TestTypeTrait'" ); |
549 | } |
550 | |
551 | //===----------------------------------------------------------------------===// |
552 | // AttrWithTraitOp |
553 | //===----------------------------------------------------------------------===// |
554 | |
555 | LogicalResult AttrWithTraitOp::verify() { |
556 | if (getAttr().hasTrait<AttributeTrait::TestAttrTrait>()) |
557 | return success(); |
558 | return emitError("'attr' attribute should have trait 'TestAttrTrait'" ); |
559 | } |
560 | |
561 | //===----------------------------------------------------------------------===// |
562 | // RegionIfOp |
563 | //===----------------------------------------------------------------------===// |
564 | |
565 | void RegionIfOp::print(OpAsmPrinter &p) { |
566 | p << " " ; |
567 | p.printOperands(getOperands()); |
568 | p << ": " << getOperandTypes(); |
569 | p.printArrowTypeList(getResultTypes()); |
570 | p << " then " ; |
571 | p.printRegion(getThenRegion(), |
572 | /*printEntryBlockArgs=*/true, |
573 | /*printBlockTerminators=*/true); |
574 | p << " else " ; |
575 | p.printRegion(getElseRegion(), |
576 | /*printEntryBlockArgs=*/true, |
577 | /*printBlockTerminators=*/true); |
578 | p << " join " ; |
579 | p.printRegion(getJoinRegion(), |
580 | /*printEntryBlockArgs=*/true, |
581 | /*printBlockTerminators=*/true); |
582 | } |
583 | |
584 | ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) { |
585 | SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfos; |
586 | SmallVector<Type, 2> operandTypes; |
587 | |
588 | result.regions.reserve(3); |
589 | Region *thenRegion = result.addRegion(); |
590 | Region *elseRegion = result.addRegion(); |
591 | Region *joinRegion = result.addRegion(); |
592 | |
593 | // Parse operand, type and arrow type lists. |
594 | if (parser.parseOperandList(operandInfos) || |
595 | parser.parseColonTypeList(operandTypes) || |
596 | parser.parseArrowTypeList(result.types)) |
597 | return failure(); |
598 | |
599 | // Parse all attached regions. |
600 | if (parser.parseKeyword("then" ) || parser.parseRegion(*thenRegion, {}, {}) || |
601 | parser.parseKeyword("else" ) || parser.parseRegion(*elseRegion, {}, {}) || |
602 | parser.parseKeyword("join" ) || parser.parseRegion(*joinRegion, {}, {})) |
603 | return failure(); |
604 | |
605 | return parser.resolveOperands(operandInfos, operandTypes, |
606 | parser.getCurrentLocation(), result.operands); |
607 | } |
608 | |
609 | OperandRange RegionIfOp::getEntrySuccessorOperands(RegionBranchPoint point) { |
610 | assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, point) && |
611 | "invalid region index" ); |
612 | return getOperands(); |
613 | } |
614 | |
615 | void RegionIfOp::getSuccessorRegions( |
616 | RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { |
617 | // We always branch to the join region. |
618 | if (!point.isParent()) { |
619 | if (point != getJoinRegion()) |
620 | regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs())); |
621 | else |
622 | regions.push_back(RegionSuccessor(getResults())); |
623 | return; |
624 | } |
625 | |
626 | // The then and else regions are the entry regions of this op. |
627 | regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs())); |
628 | regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs())); |
629 | } |
630 | |
631 | void RegionIfOp::getRegionInvocationBounds( |
632 | ArrayRef<Attribute> operands, |
633 | SmallVectorImpl<InvocationBounds> &invocationBounds) { |
634 | // Each region is invoked at most once. |
635 | invocationBounds.assign(/*NumElts=*/3, /*Elt=*/{0, 1}); |
636 | } |
637 | |
638 | //===----------------------------------------------------------------------===// |
639 | // AnyCondOp |
640 | //===----------------------------------------------------------------------===// |
641 | |
642 | void AnyCondOp::getSuccessorRegions(RegionBranchPoint point, |
643 | SmallVectorImpl<RegionSuccessor> ®ions) { |
644 | // The parent op branches into the only region, and the region branches back |
645 | // to the parent op. |
646 | if (point.isParent()) |
647 | regions.emplace_back(&getRegion()); |
648 | else |
649 | regions.emplace_back(getResults()); |
650 | } |
651 | |
652 | void AnyCondOp::getRegionInvocationBounds( |
653 | ArrayRef<Attribute> operands, |
654 | SmallVectorImpl<InvocationBounds> &invocationBounds) { |
655 | invocationBounds.emplace_back(1, 1); |
656 | } |
657 | |
658 | //===----------------------------------------------------------------------===// |
659 | // SingleBlockImplicitTerminatorOp |
660 | //===----------------------------------------------------------------------===// |
661 | |
662 | /// Testing the correctness of some traits. |
663 | static_assert( |
664 | llvm::is_detected<OpTrait::has_implicit_terminator_t, |
665 | SingleBlockImplicitTerminatorOp>::value, |
666 | "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp" ); |
667 | static_assert(OpTrait::hasSingleBlockImplicitTerminator< |
668 | SingleBlockImplicitTerminatorOp>::value, |
669 | "hasSingleBlockImplicitTerminator does not match " |
670 | "SingleBlockImplicitTerminatorOp" ); |
671 | |
672 | //===----------------------------------------------------------------------===// |
673 | // SingleNoTerminatorCustomAsmOp |
674 | //===----------------------------------------------------------------------===// |
675 | |
676 | ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser, |
677 | OperationState &state) { |
678 | Region *body = state.addRegion(); |
679 | if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) |
680 | return failure(); |
681 | return success(); |
682 | } |
683 | |
684 | void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) { |
685 | printer.printRegion( |
686 | getRegion(), /*printEntryBlockArgs=*/false, |
687 | // This op has a single block without terminators. But explicitly mark |
688 | // as not printing block terminators for testing. |
689 | /*printBlockTerminators=*/false); |
690 | } |
691 | |
692 | //===----------------------------------------------------------------------===// |
693 | // TestVerifiersOp |
694 | //===----------------------------------------------------------------------===// |
695 | |
696 | LogicalResult TestVerifiersOp::verify() { |
697 | if (!getRegion().hasOneBlock()) |
698 | return emitOpError("`hasOneBlock` trait hasn't been verified" ); |
699 | |
700 | Operation *definingOp = getInput().getDefiningOp(); |
701 | if (definingOp && failed(mlir::verify(definingOp))) |
702 | return emitOpError("operand hasn't been verified" ); |
703 | |
704 | // Avoid using `emitRemark(msg)` since that will trigger an infinite verifier |
705 | // loop. |
706 | mlir::emitRemark(getLoc(), "success run of verifier" ); |
707 | |
708 | return success(); |
709 | } |
710 | |
711 | LogicalResult TestVerifiersOp::verifyRegions() { |
712 | if (!getRegion().hasOneBlock()) |
713 | return emitOpError("`hasOneBlock` trait hasn't been verified" ); |
714 | |
715 | for (Block &block : getRegion()) |
716 | for (Operation &op : block) |
717 | if (failed(mlir::verify(&op))) |
718 | return emitOpError("nested op hasn't been verified" ); |
719 | |
720 | // Avoid using `emitRemark(msg)` since that will trigger an infinite verifier |
721 | // loop. |
722 | mlir::emitRemark(getLoc(), "success run of region verifier" ); |
723 | |
724 | return success(); |
725 | } |
726 | |
727 | //===----------------------------------------------------------------------===// |
728 | // Test InferIntRangeInterface |
729 | //===----------------------------------------------------------------------===// |
730 | |
731 | //===----------------------------------------------------------------------===// |
732 | // TestWithBoundsOp |
733 | //===----------------------------------------------------------------------===// |
734 | |
735 | void TestWithBoundsOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
736 | SetIntRangeFn setResultRanges) { |
737 | setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()}); |
738 | } |
739 | |
740 | //===----------------------------------------------------------------------===// |
741 | // TestWithBoundsRegionOp |
742 | //===----------------------------------------------------------------------===// |
743 | |
744 | ParseResult TestWithBoundsRegionOp::parse(OpAsmParser &parser, |
745 | OperationState &result) { |
746 | if (parser.parseOptionalAttrDict(result.attributes)) |
747 | return failure(); |
748 | |
749 | // Parse the input argument |
750 | OpAsmParser::Argument argInfo; |
751 | if (failed(parser.parseArgument(argInfo, true))) |
752 | return failure(); |
753 | |
754 | // Parse the body region, and reuse the operand info as the argument info. |
755 | Region *body = result.addRegion(); |
756 | return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/false); |
757 | } |
758 | |
759 | void TestWithBoundsRegionOp::print(OpAsmPrinter &p) { |
760 | p.printOptionalAttrDict((*this)->getAttrs()); |
761 | p << ' '; |
762 | p.printRegionArgument(getRegion().getArgument(0), /*argAttrs=*/{}, |
763 | /*omitType=*/false); |
764 | p << ' '; |
765 | p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); |
766 | } |
767 | |
768 | void TestWithBoundsRegionOp::inferResultRanges( |
769 | ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) { |
770 | Value arg = getRegion().getArgument(0); |
771 | setResultRanges(arg, {getUmin(), getUmax(), getSmin(), getSmax()}); |
772 | } |
773 | |
774 | //===----------------------------------------------------------------------===// |
775 | // TestIncrementOp |
776 | //===----------------------------------------------------------------------===// |
777 | |
778 | void TestIncrementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
779 | SetIntRangeFn setResultRanges) { |
780 | const ConstantIntRanges &range = argRanges[0]; |
781 | APInt one(range.umin().getBitWidth(), 1); |
782 | setResultRanges(getResult(), |
783 | {range.umin().uadd_sat(one), range.umax().uadd_sat(one), |
784 | range.smin().sadd_sat(one), range.smax().sadd_sat(one)}); |
785 | } |
786 | |
787 | //===----------------------------------------------------------------------===// |
788 | // TestReflectBoundsOp |
789 | //===----------------------------------------------------------------------===// |
790 | |
791 | void TestReflectBoundsOp::inferResultRanges( |
792 | ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) { |
793 | const ConstantIntRanges &range = argRanges[0]; |
794 | MLIRContext *ctx = getContext(); |
795 | Builder b(ctx); |
796 | Type sIntTy, uIntTy; |
797 | // For plain `IntegerType`s, we can derive the appropriate signed and unsigned |
798 | // Types for the Attributes. |
799 | Type type = getElementTypeOrSelf(getType()); |
800 | if (auto intTy = llvm::dyn_cast<IntegerType>(type)) { |
801 | unsigned bitwidth = intTy.getWidth(); |
802 | sIntTy = b.getIntegerType(bitwidth, /*isSigned=*/true); |
803 | uIntTy = b.getIntegerType(bitwidth, /*isSigned=*/false); |
804 | } else { |
805 | sIntTy = uIntTy = type; |
806 | } |
807 | |
808 | setUminAttr(b.getIntegerAttr(uIntTy, range.umin())); |
809 | setUmaxAttr(b.getIntegerAttr(uIntTy, range.umax())); |
810 | setSminAttr(b.getIntegerAttr(sIntTy, range.smin())); |
811 | setSmaxAttr(b.getIntegerAttr(sIntTy, range.smax())); |
812 | setResultRanges(getResult(), range); |
813 | } |
814 | |
815 | //===----------------------------------------------------------------------===// |
816 | // ConversionFuncOp |
817 | //===----------------------------------------------------------------------===// |
818 | |
819 | ParseResult ConversionFuncOp::parse(OpAsmParser &parser, |
820 | OperationState &result) { |
821 | auto buildFuncType = |
822 | [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, |
823 | function_interface_impl::VariadicFlag, |
824 | std::string &) { return builder.getFunctionType(argTypes, results); }; |
825 | |
826 | return function_interface_impl::parseFunctionOp( |
827 | parser, result, /*allowVariadic=*/false, |
828 | getFunctionTypeAttrName(result.name), buildFuncType, |
829 | getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); |
830 | } |
831 | |
832 | void ConversionFuncOp::print(OpAsmPrinter &p) { |
833 | function_interface_impl::printFunctionOp( |
834 | p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), |
835 | getArgAttrsAttrName(), getResAttrsAttrName()); |
836 | } |
837 | |
838 | //===----------------------------------------------------------------------===// |
839 | // ReifyBoundOp |
840 | //===----------------------------------------------------------------------===// |
841 | |
842 | mlir::presburger::BoundType ReifyBoundOp::getBoundType() { |
843 | if (getType() == "EQ" ) |
844 | return mlir::presburger::BoundType::EQ; |
845 | if (getType() == "LB" ) |
846 | return mlir::presburger::BoundType::LB; |
847 | if (getType() == "UB" ) |
848 | return mlir::presburger::BoundType::UB; |
849 | llvm_unreachable("invalid bound type" ); |
850 | } |
851 | |
852 | LogicalResult ReifyBoundOp::verify() { |
853 | if (isa<ShapedType>(getVar().getType())) { |
854 | if (!getDim().has_value()) |
855 | return emitOpError("expected 'dim' attribute for shaped type variable" ); |
856 | } else if (getVar().getType().isIndex()) { |
857 | if (getDim().has_value()) |
858 | return emitOpError("unexpected 'dim' attribute for index variable" ); |
859 | } else { |
860 | return emitOpError("expected index-typed variable or shape type variable" ); |
861 | } |
862 | if (getConstant() && getScalable()) |
863 | return emitOpError("'scalable' and 'constant' are mutually exlusive" ); |
864 | if (getScalable() != getVscaleMin().has_value()) |
865 | return emitOpError("expected 'vscale_min' if and only if 'scalable'" ); |
866 | if (getScalable() != getVscaleMax().has_value()) |
867 | return emitOpError("expected 'vscale_min' if and only if 'scalable'" ); |
868 | return success(); |
869 | } |
870 | |
871 | ValueBoundsConstraintSet::Variable ReifyBoundOp::getVariable() { |
872 | if (getDim().has_value()) |
873 | return ValueBoundsConstraintSet::Variable(getVar(), *getDim()); |
874 | return ValueBoundsConstraintSet::Variable(getVar()); |
875 | } |
876 | |
877 | //===----------------------------------------------------------------------===// |
878 | // CompareOp |
879 | //===----------------------------------------------------------------------===// |
880 | |
881 | ValueBoundsConstraintSet::ComparisonOperator |
882 | CompareOp::getComparisonOperator() { |
883 | if (getCmp() == "EQ" ) |
884 | return ValueBoundsConstraintSet::ComparisonOperator::EQ; |
885 | if (getCmp() == "LT" ) |
886 | return ValueBoundsConstraintSet::ComparisonOperator::LT; |
887 | if (getCmp() == "LE" ) |
888 | return ValueBoundsConstraintSet::ComparisonOperator::LE; |
889 | if (getCmp() == "GT" ) |
890 | return ValueBoundsConstraintSet::ComparisonOperator::GT; |
891 | if (getCmp() == "GE" ) |
892 | return ValueBoundsConstraintSet::ComparisonOperator::GE; |
893 | llvm_unreachable("invalid comparison operator" ); |
894 | } |
895 | |
896 | mlir::ValueBoundsConstraintSet::Variable CompareOp::getLhs() { |
897 | if (!getLhsMap()) |
898 | return ValueBoundsConstraintSet::Variable(getVarOperands()[0]); |
899 | SmallVector<Value> mapOperands( |
900 | getVarOperands().slice(0, getLhsMap()->getNumInputs())); |
901 | return ValueBoundsConstraintSet::Variable(*getLhsMap(), mapOperands); |
902 | } |
903 | |
904 | mlir::ValueBoundsConstraintSet::Variable CompareOp::getRhs() { |
905 | int64_t rhsOperandsBegin = getLhsMap() ? getLhsMap()->getNumInputs() : 1; |
906 | if (!getRhsMap()) |
907 | return ValueBoundsConstraintSet::Variable( |
908 | getVarOperands()[rhsOperandsBegin]); |
909 | SmallVector<Value> mapOperands( |
910 | getVarOperands().slice(rhsOperandsBegin, getRhsMap()->getNumInputs())); |
911 | return ValueBoundsConstraintSet::Variable(*getRhsMap(), mapOperands); |
912 | } |
913 | |
914 | LogicalResult CompareOp::verify() { |
915 | if (getCompose() && (getLhsMap() || getRhsMap())) |
916 | return emitOpError( |
917 | "'compose' not supported when 'lhs_map' or 'rhs_map' is present" ); |
918 | int64_t expectedNumOperands = getLhsMap() ? getLhsMap()->getNumInputs() : 1; |
919 | expectedNumOperands += getRhsMap() ? getRhsMap()->getNumInputs() : 1; |
920 | if (getVarOperands().size() != size_t(expectedNumOperands)) |
921 | return emitOpError("expected " ) |
922 | << expectedNumOperands << " operands, but got " |
923 | << getVarOperands().size(); |
924 | return success(); |
925 | } |
926 | |
927 | //===----------------------------------------------------------------------===// |
928 | // TestOpInPlaceSelfFold |
929 | //===----------------------------------------------------------------------===// |
930 | |
931 | OpFoldResult TestOpInPlaceSelfFold::fold(FoldAdaptor adaptor) { |
932 | if (!getFolded()) { |
933 | // The folder adds the "folded" if not present. |
934 | setFolded(true); |
935 | return getResult(); |
936 | } |
937 | return {}; |
938 | } |
939 | |
940 | //===----------------------------------------------------------------------===// |
941 | // TestOpFoldWithFoldAdaptor |
942 | //===----------------------------------------------------------------------===// |
943 | |
944 | OpFoldResult TestOpFoldWithFoldAdaptor::fold(FoldAdaptor adaptor) { |
945 | int64_t sum = 0; |
946 | if (auto value = dyn_cast_or_null<IntegerAttr>(adaptor.getOp())) |
947 | sum += value.getValue().getSExtValue(); |
948 | |
949 | for (Attribute attr : adaptor.getVariadic()) |
950 | if (auto value = dyn_cast_or_null<IntegerAttr>(attr)) |
951 | sum += 2 * value.getValue().getSExtValue(); |
952 | |
953 | for (ArrayRef<Attribute> attrs : adaptor.getVarOfVar()) |
954 | for (Attribute attr : attrs) |
955 | if (auto value = dyn_cast_or_null<IntegerAttr>(attr)) |
956 | sum += 3 * value.getValue().getSExtValue(); |
957 | |
958 | sum += 4 * std::distance(adaptor.getBody().begin(), adaptor.getBody().end()); |
959 | |
960 | return IntegerAttr::get(getType(), sum); |
961 | } |
962 | |
963 | //===----------------------------------------------------------------------===// |
964 | // OpWithInferTypeAdaptorInterfaceOp |
965 | //===----------------------------------------------------------------------===// |
966 | |
967 | LogicalResult OpWithInferTypeAdaptorInterfaceOp::inferReturnTypes( |
968 | MLIRContext *, std::optional<Location> location, |
969 | OpWithInferTypeAdaptorInterfaceOp::Adaptor adaptor, |
970 | SmallVectorImpl<Type> &inferredReturnTypes) { |
971 | if (adaptor.getX().getType() != adaptor.getY().getType()) { |
972 | return emitOptionalError(location, "operand type mismatch " , |
973 | adaptor.getX().getType(), " vs " , |
974 | adaptor.getY().getType()); |
975 | } |
976 | inferredReturnTypes.assign({adaptor.getX().getType()}); |
977 | return success(); |
978 | } |
979 | |
980 | //===----------------------------------------------------------------------===// |
981 | // OpWithRefineTypeInterfaceOp |
982 | //===----------------------------------------------------------------------===// |
983 | |
984 | // TODO: We should be able to only define either inferReturnType or |
985 | // refineReturnType, currently only refineReturnType can be omitted. |
986 | LogicalResult OpWithRefineTypeInterfaceOp::inferReturnTypes( |
987 | MLIRContext *context, std::optional<Location> location, ValueRange operands, |
988 | DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, |
989 | SmallVectorImpl<Type> &returnTypes) { |
990 | returnTypes.clear(); |
991 | return OpWithRefineTypeInterfaceOp::refineReturnTypes( |
992 | context, location, operands, attributes, properties, regions, |
993 | returnTypes); |
994 | } |
995 | |
996 | LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes( |
997 | MLIRContext *, std::optional<Location> location, ValueRange operands, |
998 | DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, |
999 | SmallVectorImpl<Type> &returnTypes) { |
1000 | if (operands[0].getType() != operands[1].getType()) { |
1001 | return emitOptionalError(location, "operand type mismatch " , |
1002 | operands[0].getType(), " vs " , |
1003 | operands[1].getType()); |
1004 | } |
1005 | // TODO: Add helper to make this more concise to write. |
1006 | if (returnTypes.empty()) |
1007 | returnTypes.resize(1, nullptr); |
1008 | if (returnTypes[0] && returnTypes[0] != operands[0].getType()) |
1009 | return emitOptionalError(location, |
1010 | "required first operand and result to match" ); |
1011 | returnTypes[0] = operands[0].getType(); |
1012 | return success(); |
1013 | } |
1014 | |
1015 | //===----------------------------------------------------------------------===// |
1016 | // OpWithShapedTypeInferTypeAdaptorInterfaceOp |
1017 | //===----------------------------------------------------------------------===// |
1018 | |
1019 | LogicalResult |
1020 | OpWithShapedTypeInferTypeAdaptorInterfaceOp::inferReturnTypeComponents( |
1021 | MLIRContext *context, std::optional<Location> location, |
1022 | OpWithShapedTypeInferTypeAdaptorInterfaceOp::Adaptor adaptor, |
1023 | SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { |
1024 | // Create return type consisting of the last element of the first operand. |
1025 | auto operandType = adaptor.getOperand1().getType(); |
1026 | auto sval = dyn_cast<ShapedType>(operandType); |
1027 | if (!sval) |
1028 | return emitOptionalError(location, "only shaped type operands allowed" ); |
1029 | int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic; |
1030 | auto type = IntegerType::get(context, 17); |
1031 | |
1032 | Attribute encoding; |
1033 | if (auto rankedTy = dyn_cast<RankedTensorType>(sval)) |
1034 | encoding = rankedTy.getEncoding(); |
1035 | inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding)); |
1036 | return success(); |
1037 | } |
1038 | |
1039 | LogicalResult |
1040 | OpWithShapedTypeInferTypeAdaptorInterfaceOp::reifyReturnTypeShapes( |
1041 | OpBuilder &builder, ValueRange operands, |
1042 | llvm::SmallVectorImpl<Value> &shapes) { |
1043 | shapes = SmallVector<Value, 1>{ |
1044 | builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)}; |
1045 | return success(); |
1046 | } |
1047 | |
1048 | //===----------------------------------------------------------------------===// |
1049 | // TestOpWithPropertiesAndInferredType |
1050 | //===----------------------------------------------------------------------===// |
1051 | |
1052 | LogicalResult TestOpWithPropertiesAndInferredType::inferReturnTypes( |
1053 | MLIRContext *context, std::optional<Location>, ValueRange operands, |
1054 | DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, |
1055 | SmallVectorImpl<Type> &inferredReturnTypes) { |
1056 | |
1057 | Adaptor adaptor(operands, attributes, properties, regions); |
1058 | inferredReturnTypes.push_back(IntegerType::get( |
1059 | context, adaptor.getLhs() + adaptor.getProperties().rhs)); |
1060 | return success(); |
1061 | } |
1062 | |
1063 | //===----------------------------------------------------------------------===// |
1064 | // LoopBlockOp |
1065 | //===----------------------------------------------------------------------===// |
1066 | |
1067 | void LoopBlockOp::getSuccessorRegions( |
1068 | RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { |
1069 | regions.emplace_back(&getBody(), getBody().getArguments()); |
1070 | if (point.isParent()) |
1071 | return; |
1072 | |
1073 | regions.emplace_back((*this)->getResults()); |
1074 | } |
1075 | |
1076 | OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) { |
1077 | assert(point == getBody()); |
1078 | return MutableOperandRange(getInitMutable()); |
1079 | } |
1080 | |
1081 | //===----------------------------------------------------------------------===// |
1082 | // LoopBlockTerminatorOp |
1083 | //===----------------------------------------------------------------------===// |
1084 | |
1085 | MutableOperandRange |
1086 | LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point) { |
1087 | if (point.isParent()) |
1088 | return getExitArgMutable(); |
1089 | return getNextIterArgMutable(); |
1090 | } |
1091 | |
1092 | //===----------------------------------------------------------------------===// |
// TestNoTerminatorOp
1094 | //===----------------------------------------------------------------------===// |
1095 | |
1096 | void TestNoTerminatorOp::getSuccessorRegions( |
1097 | RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {} |
1098 | |
1099 | //===----------------------------------------------------------------------===// |
// ManualCppOpWithFold
1101 | //===----------------------------------------------------------------------===// |
1102 | |
1103 | OpFoldResult ManualCppOpWithFold::fold(ArrayRef<Attribute> attributes) { |
1104 | // Just a simple fold for testing purposes that reads an operands constant |
1105 | // value and returns it. |
1106 | if (!attributes.empty()) |
1107 | return attributes.front(); |
1108 | return nullptr; |
1109 | } |
1110 | |
1111 | //===----------------------------------------------------------------------===// |
1112 | // Tensor/Buffer Ops |
1113 | //===----------------------------------------------------------------------===// |
1114 | |
1115 | void ReadBufferOp::getEffects( |
1116 | SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> |
1117 | &effects) { |
1118 | // The buffer operand is read. |
1119 | effects.emplace_back(MemoryEffects::Read::get(), &getBufferMutable(), |
1120 | SideEffects::DefaultResource::get()); |
1121 | // The buffer contents are dumped. |
1122 | effects.emplace_back(MemoryEffects::Write::get(), |
1123 | SideEffects::DefaultResource::get()); |
1124 | } |
1125 | |
1126 | //===----------------------------------------------------------------------===// |
1127 | // Test Dataflow |
1128 | //===----------------------------------------------------------------------===// |
1129 | |
1130 | //===----------------------------------------------------------------------===// |
1131 | // TestCallAndStoreOp |
1132 | //===----------------------------------------------------------------------===// |
1133 | |
1134 | CallInterfaceCallable TestCallAndStoreOp::getCallableForCallee() { |
1135 | return getCallee(); |
1136 | } |
1137 | |
1138 | void TestCallAndStoreOp::setCalleeFromCallable(CallInterfaceCallable callee) { |
1139 | setCalleeAttr(cast<SymbolRefAttr>(callee)); |
1140 | } |
1141 | |
1142 | Operation::operand_range TestCallAndStoreOp::getArgOperands() { |
1143 | return getCalleeOperands(); |
1144 | } |
1145 | |
1146 | MutableOperandRange TestCallAndStoreOp::getArgOperandsMutable() { |
1147 | return getCalleeOperandsMutable(); |
1148 | } |
1149 | |
1150 | //===----------------------------------------------------------------------===// |
1151 | // TestCallOnDeviceOp |
1152 | //===----------------------------------------------------------------------===// |
1153 | |
1154 | CallInterfaceCallable TestCallOnDeviceOp::getCallableForCallee() { |
1155 | return getCallee(); |
1156 | } |
1157 | |
1158 | void TestCallOnDeviceOp::setCalleeFromCallable(CallInterfaceCallable callee) { |
1159 | setCalleeAttr(cast<SymbolRefAttr>(callee)); |
1160 | } |
1161 | |
1162 | Operation::operand_range TestCallOnDeviceOp::getArgOperands() { |
1163 | return getForwardedOperands(); |
1164 | } |
1165 | |
1166 | MutableOperandRange TestCallOnDeviceOp::getArgOperandsMutable() { |
1167 | return getForwardedOperandsMutable(); |
1168 | } |
1169 | |
1170 | //===----------------------------------------------------------------------===// |
1171 | // TestStoreWithARegion |
1172 | //===----------------------------------------------------------------------===// |
1173 | |
1174 | void TestStoreWithARegion::getSuccessorRegions( |
1175 | RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { |
1176 | if (point.isParent()) |
1177 | regions.emplace_back(&getBody(), getBody().front().getArguments()); |
1178 | else |
1179 | regions.emplace_back(); |
1180 | } |
1181 | |
1182 | //===----------------------------------------------------------------------===// |
1183 | // TestStoreWithALoopRegion |
1184 | //===----------------------------------------------------------------------===// |
1185 | |
1186 | void TestStoreWithALoopRegion::getSuccessorRegions( |
1187 | RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { |
1188 | // Both the operation itself and the region may be branching into the body or |
1189 | // back into the operation itself. It is possible for the operation not to |
1190 | // enter the body. |
1191 | regions.emplace_back( |
1192 | RegionSuccessor(&getBody(), getBody().front().getArguments())); |
1193 | regions.emplace_back(); |
1194 | } |
1195 | |
1196 | //===----------------------------------------------------------------------===// |
1197 | // TestVersionedOpA |
1198 | //===----------------------------------------------------------------------===// |
1199 | |
1200 | LogicalResult |
1201 | TestVersionedOpA::readProperties(mlir::DialectBytecodeReader &reader, |
1202 | mlir::OperationState &state) { |
1203 | auto &prop = state.getOrAddProperties<Properties>(); |
1204 | if (mlir::failed(reader.readAttribute(prop.dims))) |
1205 | return mlir::failure(); |
1206 | |
1207 | // Check if we have a version. If not, assume we are parsing the current |
1208 | // version. |
1209 | auto maybeVersion = reader.getDialectVersion<test::TestDialect>(); |
1210 | if (succeeded(maybeVersion)) { |
1211 | // If version is less than 2.0, there is no additional attribute to parse. |
1212 | // We can materialize missing properties post parsing before verification. |
1213 | const auto *version = |
1214 | reinterpret_cast<const TestDialectVersion *>(*maybeVersion); |
1215 | if ((version->major_ < 2)) { |
1216 | return success(); |
1217 | } |
1218 | } |
1219 | |
1220 | if (mlir::failed(reader.readAttribute(prop.modifier))) |
1221 | return mlir::failure(); |
1222 | return mlir::success(); |
1223 | } |
1224 | |
1225 | void TestVersionedOpA::writeProperties(mlir::DialectBytecodeWriter &writer) { |
1226 | auto &prop = getProperties(); |
1227 | writer.writeAttribute(prop.dims); |
1228 | |
1229 | auto maybeVersion = writer.getDialectVersion<test::TestDialect>(); |
1230 | if (succeeded(maybeVersion)) { |
1231 | // If version is less than 2.0, there is no additional attribute to write. |
1232 | const auto *version = |
1233 | reinterpret_cast<const TestDialectVersion *>(*maybeVersion); |
1234 | if ((version->major_ < 2)) { |
1235 | llvm::outs() << "downgrading op properties...\n" ; |
1236 | return; |
1237 | } |
1238 | } |
1239 | writer.writeAttribute(prop.modifier); |
1240 | } |
1241 | |
1242 | //===----------------------------------------------------------------------===// |
1243 | // TestOpWithVersionedProperties |
1244 | //===----------------------------------------------------------------------===// |
1245 | |
1246 | llvm::LogicalResult TestOpWithVersionedProperties::readFromMlirBytecode( |
1247 | mlir::DialectBytecodeReader &reader, test::VersionedProperties &prop) { |
1248 | uint64_t value1, value2 = 0; |
1249 | if (failed(reader.readVarInt(value1))) |
1250 | return failure(); |
1251 | |
1252 | // Check if we have a version. If not, assume we are parsing the current |
1253 | // version. |
1254 | auto maybeVersion = reader.getDialectVersion<test::TestDialect>(); |
1255 | bool needToParseAnotherInt = true; |
1256 | if (succeeded(maybeVersion)) { |
1257 | // If version is less than 2.0, there is no additional attribute to parse. |
1258 | // We can materialize missing properties post parsing before verification. |
1259 | const auto *version = |
1260 | reinterpret_cast<const TestDialectVersion *>(*maybeVersion); |
1261 | if ((version->major_ < 2)) |
1262 | needToParseAnotherInt = false; |
1263 | } |
1264 | if (needToParseAnotherInt && failed(reader.readVarInt(value2))) |
1265 | return failure(); |
1266 | |
1267 | prop.value1 = value1; |
1268 | prop.value2 = value2; |
1269 | return success(); |
1270 | } |
1271 | |
1272 | void TestOpWithVersionedProperties::writeToMlirBytecode( |
1273 | mlir::DialectBytecodeWriter &writer, |
1274 | const test::VersionedProperties &prop) { |
1275 | writer.writeVarInt(prop.value1); |
1276 | writer.writeVarInt(prop.value2); |
1277 | } |
1278 | |
1279 | //===----------------------------------------------------------------------===// |
1280 | // TestMultiSlotAlloca |
1281 | //===----------------------------------------------------------------------===// |
1282 | |
1283 | llvm::SmallVector<MemorySlot> TestMultiSlotAlloca::getPromotableSlots() { |
1284 | SmallVector<MemorySlot> slots; |
1285 | for (Value result : getResults()) { |
1286 | slots.push_back(MemorySlot{ |
1287 | result, cast<MemRefType>(result.getType()).getElementType()}); |
1288 | } |
1289 | return slots; |
1290 | } |
1291 | |
1292 | Value TestMultiSlotAlloca::getDefaultValue(const MemorySlot &slot, |
1293 | OpBuilder &builder) { |
1294 | return builder.create<TestOpConstant>(getLoc(), slot.elemType, |
1295 | builder.getI32IntegerAttr(42)); |
1296 | } |
1297 | |
1298 | void TestMultiSlotAlloca::handleBlockArgument(const MemorySlot &slot, |
1299 | BlockArgument argument, |
1300 | OpBuilder &builder) { |
1301 | // Not relevant for testing. |
1302 | } |
1303 | |
1304 | /// Creates a new TestMultiSlotAlloca operation, just without the `slot`. |
1305 | static std::optional<TestMultiSlotAlloca> |
1306 | createNewMultiAllocaWithoutSlot(const MemorySlot &slot, OpBuilder &builder, |
1307 | TestMultiSlotAlloca oldOp) { |
1308 | |
1309 | if (oldOp.getNumResults() == 1) { |
1310 | oldOp.erase(); |
1311 | return std::nullopt; |
1312 | } |
1313 | |
1314 | SmallVector<Type> newTypes; |
1315 | SmallVector<Value> remainingValues; |
1316 | |
1317 | for (Value oldResult : oldOp.getResults()) { |
1318 | if (oldResult == slot.ptr) |
1319 | continue; |
1320 | remainingValues.push_back(oldResult); |
1321 | newTypes.push_back(oldResult.getType()); |
1322 | } |
1323 | |
1324 | OpBuilder::InsertionGuard guard(builder); |
1325 | builder.setInsertionPoint(oldOp); |
1326 | auto replacement = |
1327 | builder.create<TestMultiSlotAlloca>(oldOp->getLoc(), newTypes); |
1328 | for (auto [oldResult, newResult] : |
1329 | llvm::zip_equal(remainingValues, replacement.getResults())) |
1330 | oldResult.replaceAllUsesWith(newResult); |
1331 | |
1332 | oldOp.erase(); |
1333 | return replacement; |
1334 | } |
1335 | |
1336 | std::optional<PromotableAllocationOpInterface> |
1337 | TestMultiSlotAlloca::handlePromotionComplete(const MemorySlot &slot, |
1338 | Value defaultValue, |
1339 | OpBuilder &builder) { |
1340 | if (defaultValue && defaultValue.use_empty()) |
1341 | defaultValue.getDefiningOp()->erase(); |
1342 | return createNewMultiAllocaWithoutSlot(slot, builder, *this); |
1343 | } |
1344 | |
1345 | SmallVector<DestructurableMemorySlot> |
1346 | TestMultiSlotAlloca::getDestructurableSlots() { |
1347 | SmallVector<DestructurableMemorySlot> slots; |
1348 | for (Value result : getResults()) { |
1349 | auto memrefType = cast<MemRefType>(result.getType()); |
1350 | auto destructurable = dyn_cast<DestructurableTypeInterface>(memrefType); |
1351 | if (!destructurable) |
1352 | continue; |
1353 | |
1354 | std::optional<DenseMap<Attribute, Type>> destructuredType = |
1355 | destructurable.getSubelementIndexMap(); |
1356 | if (!destructuredType) |
1357 | continue; |
1358 | slots.emplace_back( |
1359 | DestructurableMemorySlot{{result, memrefType}, *destructuredType}); |
1360 | } |
1361 | return slots; |
1362 | } |
1363 | |
1364 | DenseMap<Attribute, MemorySlot> TestMultiSlotAlloca::destructure( |
1365 | const DestructurableMemorySlot &slot, |
1366 | const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder, |
1367 | SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) { |
1368 | OpBuilder::InsertionGuard guard(builder); |
1369 | builder.setInsertionPointAfter(*this); |
1370 | |
1371 | DenseMap<Attribute, MemorySlot> slotMap; |
1372 | |
1373 | for (Attribute usedIndex : usedIndices) { |
1374 | Type elemType = slot.subelementTypes.lookup(usedIndex); |
1375 | MemRefType elemPtr = MemRefType::get({}, elemType); |
1376 | auto subAlloca = builder.create<TestMultiSlotAlloca>(getLoc(), elemPtr); |
1377 | newAllocators.push_back(subAlloca); |
1378 | slotMap.try_emplace<MemorySlot>(usedIndex, |
1379 | {subAlloca.getResult(0), elemType}); |
1380 | } |
1381 | |
1382 | return slotMap; |
1383 | } |
1384 | |
1385 | std::optional<DestructurableAllocationOpInterface> |
1386 | TestMultiSlotAlloca::handleDestructuringComplete( |
1387 | const DestructurableMemorySlot &slot, OpBuilder &builder) { |
1388 | return createNewMultiAllocaWithoutSlot(slot, builder, *this); |
1389 | } |
1390 | |