| 1 | //===- TestOpDefs.cpp - MLIR Test Dialect Operation Hooks -----------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "TestDialect.h" |
| 10 | #include "TestOps.h" |
| 11 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| 12 | #include "mlir/IR/Verifier.h" |
| 13 | #include "mlir/Interfaces/FunctionImplementation.h" |
| 14 | #include "mlir/Interfaces/MemorySlotInterfaces.h" |
| 15 | |
| 16 | using namespace mlir; |
| 17 | using namespace test; |
| 18 | |
| 19 | //===----------------------------------------------------------------------===// |
| 20 | // TestBranchOp |
| 21 | //===----------------------------------------------------------------------===// |
| 22 | |
| 23 | SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) { |
| 24 | assert(index == 0 && "invalid successor index" ); |
| 25 | return SuccessorOperands(getTargetOperandsMutable()); |
| 26 | } |
| 27 | |
| 28 | //===----------------------------------------------------------------------===// |
| 29 | // TestProducingBranchOp |
| 30 | //===----------------------------------------------------------------------===// |
| 31 | |
| 32 | SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) { |
| 33 | assert(index <= 1 && "invalid successor index" ); |
| 34 | if (index == 1) |
| 35 | return SuccessorOperands(getFirstOperandsMutable()); |
| 36 | return SuccessorOperands(getSecondOperandsMutable()); |
| 37 | } |
| 38 | |
| 39 | //===----------------------------------------------------------------------===// |
| 40 | // TestInternalBranchOp |
| 41 | //===----------------------------------------------------------------------===// |
| 42 | |
| 43 | SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) { |
| 44 | assert(index <= 1 && "invalid successor index" ); |
| 45 | if (index == 0) |
| 46 | return SuccessorOperands(0, getSuccessOperandsMutable()); |
| 47 | return SuccessorOperands(1, getErrorOperandsMutable()); |
| 48 | } |
| 49 | |
| 50 | //===----------------------------------------------------------------------===// |
| 51 | // TestCallOp |
| 52 | //===----------------------------------------------------------------------===// |
| 53 | |
| 54 | LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
| 55 | // Check that the callee attribute was specified. |
| 56 | auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee" ); |
| 57 | if (!fnAttr) |
| 58 | return emitOpError("requires a 'callee' symbol reference attribute" ); |
| 59 | if (!symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(*this, fnAttr)) |
| 60 | return emitOpError() << "'" << fnAttr.getValue() |
| 61 | << "' does not reference a valid function" ; |
| 62 | return success(); |
| 63 | } |
| 64 | |
| 65 | //===----------------------------------------------------------------------===// |
| 66 | // FoldToCallOp |
| 67 | //===----------------------------------------------------------------------===// |
| 68 | |
| 69 | namespace { |
| 70 | struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { |
| 71 | using OpRewritePattern<FoldToCallOp>::OpRewritePattern; |
| 72 | |
| 73 | LogicalResult matchAndRewrite(FoldToCallOp op, |
| 74 | PatternRewriter &rewriter) const override { |
| 75 | rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(), |
| 76 | op.getCalleeAttr(), ValueRange()); |
| 77 | return success(); |
| 78 | } |
| 79 | }; |
| 80 | } // namespace |
| 81 | |
| 82 | void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| 83 | MLIRContext *context) { |
| 84 | results.add<FoldToCallOpPattern>(context); |
| 85 | } |
| 86 | |
| 87 | //===----------------------------------------------------------------------===// |
| 88 | // IsolatedRegionOp - test parsing passthrough operands |
| 89 | //===----------------------------------------------------------------------===// |
| 90 | |
| 91 | ParseResult IsolatedRegionOp::parse(OpAsmParser &parser, |
| 92 | OperationState &result) { |
| 93 | // Parse the input operand. |
| 94 | OpAsmParser::Argument argInfo; |
| 95 | argInfo.type = parser.getBuilder().getIndexType(); |
| 96 | if (parser.parseOperand(argInfo.ssaName) || |
| 97 | parser.resolveOperand(argInfo.ssaName, argInfo.type, result.operands)) |
| 98 | return failure(); |
| 99 | |
| 100 | // Parse the body region, and reuse the operand info as the argument info. |
| 101 | Region *body = result.addRegion(); |
| 102 | return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/true); |
| 103 | } |
| 104 | |
| 105 | void IsolatedRegionOp::print(OpAsmPrinter &p) { |
| 106 | p << ' '; |
| 107 | p.printOperand(getOperand()); |
| 108 | p.shadowRegionArgs(getRegion(), getOperand()); |
| 109 | p << ' '; |
| 110 | p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); |
| 111 | } |
| 112 | |
| 113 | //===----------------------------------------------------------------------===// |
| 114 | // SSACFGRegionOp |
| 115 | //===----------------------------------------------------------------------===// |
| 116 | |
| 117 | RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { |
| 118 | return RegionKind::SSACFG; |
| 119 | } |
| 120 | |
| 121 | //===----------------------------------------------------------------------===// |
| 122 | // GraphRegionOp |
| 123 | //===----------------------------------------------------------------------===// |
| 124 | |
| 125 | RegionKind GraphRegionOp::getRegionKind(unsigned index) { |
| 126 | return RegionKind::Graph; |
| 127 | } |
| 128 | |
| 129 | //===----------------------------------------------------------------------===// |
| 130 | // IsolatedGraphRegionOp |
| 131 | //===----------------------------------------------------------------------===// |
| 132 | |
| 133 | RegionKind IsolatedGraphRegionOp::getRegionKind(unsigned index) { |
| 134 | return RegionKind::Graph; |
| 135 | } |
| 136 | |
| 137 | //===----------------------------------------------------------------------===// |
| 138 | // AffineScopeOp |
| 139 | //===----------------------------------------------------------------------===// |
| 140 | |
| 141 | ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) { |
| 142 | // Parse the body region, and reuse the operand info as the argument info. |
| 143 | Region *body = result.addRegion(); |
| 144 | return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); |
| 145 | } |
| 146 | |
| 147 | void AffineScopeOp::print(OpAsmPrinter &p) { |
| 148 | p << " " ; |
| 149 | p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); |
| 150 | } |
| 151 | |
| 152 | //===----------------------------------------------------------------------===// |
| 153 | // TestRemoveOpWithInnerOps |
| 154 | //===----------------------------------------------------------------------===// |
| 155 | |
| 156 | namespace { |
| 157 | struct TestRemoveOpWithInnerOps |
| 158 | : public OpRewritePattern<TestOpWithRegionPattern> { |
| 159 | using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; |
| 160 | |
| 161 | void initialize() { setDebugName("TestRemoveOpWithInnerOps" ); } |
| 162 | |
| 163 | LogicalResult matchAndRewrite(TestOpWithRegionPattern op, |
| 164 | PatternRewriter &rewriter) const override { |
| 165 | rewriter.eraseOp(op: op); |
| 166 | return success(); |
| 167 | } |
| 168 | }; |
| 169 | } // namespace |
| 170 | |
| 171 | //===----------------------------------------------------------------------===// |
| 172 | // TestOpWithRegionPattern |
| 173 | //===----------------------------------------------------------------------===// |
| 174 | |
| 175 | void TestOpWithRegionPattern::getCanonicalizationPatterns( |
| 176 | RewritePatternSet &results, MLIRContext *context) { |
| 177 | results.add<TestRemoveOpWithInnerOps>(context); |
| 178 | } |
| 179 | |
| 180 | //===----------------------------------------------------------------------===// |
| 181 | // TestOpWithRegionFold |
| 182 | //===----------------------------------------------------------------------===// |
| 183 | |
| 184 | OpFoldResult TestOpWithRegionFold::fold(FoldAdaptor adaptor) { |
| 185 | return getOperand(); |
| 186 | } |
| 187 | |
| 188 | //===----------------------------------------------------------------------===// |
| 189 | // TestOpConstant |
| 190 | //===----------------------------------------------------------------------===// |
| 191 | |
| 192 | OpFoldResult TestOpConstant::fold(FoldAdaptor adaptor) { return getValue(); } |
| 193 | |
| 194 | //===----------------------------------------------------------------------===// |
| 195 | // TestOpWithVariadicResultsAndFolder |
| 196 | //===----------------------------------------------------------------------===// |
| 197 | |
| 198 | LogicalResult TestOpWithVariadicResultsAndFolder::fold( |
| 199 | FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult> &results) { |
| 200 | for (Value input : this->getOperands()) { |
| 201 | results.push_back(input); |
| 202 | } |
| 203 | return success(); |
| 204 | } |
| 205 | |
| 206 | //===----------------------------------------------------------------------===// |
| 207 | // TestOpInPlaceFold |
| 208 | //===----------------------------------------------------------------------===// |
| 209 | |
| 210 | OpFoldResult TestOpInPlaceFold::fold(FoldAdaptor adaptor) { |
| 211 | // Exercise the fact that an operation created with createOrFold should be |
| 212 | // allowed to access its parent block. |
| 213 | assert(getOperation()->getBlock() && |
| 214 | "expected that operation is not unlinked" ); |
| 215 | |
| 216 | if (adaptor.getOp() && !getProperties().attr) { |
| 217 | // The folder adds "attr" if not present. |
| 218 | getProperties().attr = dyn_cast_or_null<IntegerAttr>(adaptor.getOp()); |
| 219 | return getResult(); |
| 220 | } |
| 221 | return {}; |
| 222 | } |
| 223 | |
| 224 | //===----------------------------------------------------------------------===// |
| 225 | // OpWithInferTypeInterfaceOp |
| 226 | //===----------------------------------------------------------------------===// |
| 227 | |
| 228 | LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( |
| 229 | MLIRContext *, std::optional<Location> location, ValueRange operands, |
| 230 | DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, |
| 231 | SmallVectorImpl<Type> &inferredReturnTypes) { |
| 232 | if (operands[0].getType() != operands[1].getType()) { |
| 233 | return emitOptionalError(location, "operand type mismatch " , |
| 234 | operands[0].getType(), " vs " , |
| 235 | operands[1].getType()); |
| 236 | } |
| 237 | inferredReturnTypes.assign({operands[0].getType()}); |
| 238 | return success(); |
| 239 | } |
| 240 | |
| 241 | //===----------------------------------------------------------------------===// |
| 242 | // OpWithShapedTypeInferTypeInterfaceOp |
| 243 | //===----------------------------------------------------------------------===// |
| 244 | |
| 245 | LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( |
| 246 | MLIRContext *context, std::optional<Location> location, |
| 247 | ValueShapeRange operands, DictionaryAttr attributes, |
| 248 | OpaqueProperties properties, RegionRange regions, |
| 249 | SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { |
| 250 | // Create return type consisting of the last element of the first operand. |
| 251 | auto operandType = operands.front().getType(); |
| 252 | auto sval = dyn_cast<ShapedType>(operandType); |
| 253 | if (!sval) |
| 254 | return emitOptionalError(location, "only shaped type operands allowed" ); |
| 255 | int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic; |
| 256 | auto type = IntegerType::get(context, 17); |
| 257 | |
| 258 | Attribute encoding; |
| 259 | if (auto rankedTy = dyn_cast<RankedTensorType>(sval)) |
| 260 | encoding = rankedTy.getEncoding(); |
| 261 | inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding)); |
| 262 | return success(); |
| 263 | } |
| 264 | |
| 265 | LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( |
| 266 | OpBuilder &builder, ValueRange operands, |
| 267 | llvm::SmallVectorImpl<Value> &shapes) { |
| 268 | shapes = SmallVector<Value, 1>{ |
| 269 | builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)}; |
| 270 | return success(); |
| 271 | } |
| 272 | |
| 273 | //===----------------------------------------------------------------------===// |
| 274 | // OpWithResultShapeInterfaceOp |
| 275 | //===----------------------------------------------------------------------===// |
| 276 | |
| 277 | LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes( |
| 278 | OpBuilder &builder, ValueRange operands, |
| 279 | llvm::SmallVectorImpl<Value> &shapes) { |
| 280 | Location loc = getLoc(); |
| 281 | shapes.reserve(operands.size()); |
| 282 | for (Value operand : llvm::reverse(operands)) { |
| 283 | auto rank = cast<RankedTensorType>(operand.getType()).getRank(); |
| 284 | auto currShape = llvm::to_vector<4>( |
| 285 | llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value { |
| 286 | return builder.createOrFold<tensor::DimOp>(loc, operand, dim); |
| 287 | })); |
| 288 | shapes.push_back(builder.create<tensor::FromElementsOp>( |
| 289 | getLoc(), RankedTensorType::get({rank}, builder.getIndexType()), |
| 290 | currShape)); |
| 291 | } |
| 292 | return success(); |
| 293 | } |
| 294 | |
| 295 | //===----------------------------------------------------------------------===// |
| 296 | // OpWithResultShapePerDimInterfaceOp |
| 297 | //===----------------------------------------------------------------------===// |
| 298 | |
| 299 | LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes( |
| 300 | OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { |
| 301 | Location loc = getLoc(); |
| 302 | shapes.reserve(getNumOperands()); |
| 303 | for (Value operand : llvm::reverse(getOperands())) { |
| 304 | auto tensorType = cast<RankedTensorType>(operand.getType()); |
| 305 | auto currShape = llvm::to_vector<4>(llvm::map_range( |
| 306 | llvm::seq<int64_t>(0, tensorType.getRank()), |
| 307 | [&](int64_t dim) -> OpFoldResult { |
| 308 | return tensorType.isDynamicDim(dim) |
| 309 | ? static_cast<OpFoldResult>( |
| 310 | builder.createOrFold<tensor::DimOp>(loc, operand, |
| 311 | dim)) |
| 312 | : static_cast<OpFoldResult>( |
| 313 | builder.getIndexAttr(tensorType.getDimSize(dim))); |
| 314 | })); |
| 315 | shapes.emplace_back(std::move(currShape)); |
| 316 | } |
| 317 | return success(); |
| 318 | } |
| 319 | |
| 320 | //===----------------------------------------------------------------------===// |
| 321 | // SideEffectOp |
| 322 | //===----------------------------------------------------------------------===// |
| 323 | |
| 324 | namespace { |
| 325 | /// A test resource for side effects. |
| 326 | struct TestResource : public SideEffects::Resource::Base<TestResource> { |
| 327 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResource) |
| 328 | |
| 329 | StringRef getName() final { return "<Test>" ; } |
| 330 | }; |
| 331 | } // namespace |
| 332 | |
| 333 | void SideEffectOp::getEffects( |
| 334 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
| 335 | // Check for an effects attribute on the op instance. |
| 336 | ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects" ); |
| 337 | if (!effectsAttr) |
| 338 | return; |
| 339 | |
| 340 | for (Attribute element : effectsAttr) { |
| 341 | DictionaryAttr effectElement = cast<DictionaryAttr>(element); |
| 342 | |
| 343 | // Get the specific memory effect. |
| 344 | MemoryEffects::Effect *effect = |
| 345 | StringSwitch<MemoryEffects::Effect *>( |
| 346 | cast<StringAttr>(effectElement.get("effect" )).getValue()) |
| 347 | .Case("allocate" , MemoryEffects::Allocate::get()) |
| 348 | .Case("free" , MemoryEffects::Free::get()) |
| 349 | .Case("read" , MemoryEffects::Read::get()) |
| 350 | .Case("write" , MemoryEffects::Write::get()); |
| 351 | |
| 352 | // Check for a non-default resource to use. |
| 353 | SideEffects::Resource *resource = SideEffects::DefaultResource::get(); |
| 354 | if (effectElement.get("test_resource" )) |
| 355 | resource = TestResource::get(); |
| 356 | |
| 357 | // Check for a result to affect. |
| 358 | if (effectElement.get("on_result" )) |
| 359 | effects.emplace_back(effect, getOperation()->getOpResults()[0], resource); |
| 360 | else if (Attribute ref = effectElement.get("on_reference" )) |
| 361 | effects.emplace_back(effect, cast<SymbolRefAttr>(ref), resource); |
| 362 | else |
| 363 | effects.emplace_back(effect, resource); |
| 364 | } |
| 365 | } |
| 366 | |
| 367 | void SideEffectOp::getEffects( |
| 368 | SmallVectorImpl<TestEffects::EffectInstance> &effects) { |
| 369 | testSideEffectOpGetEffect(getOperation(), effects); |
| 370 | } |
| 371 | |
| 372 | void SideEffectWithRegionOp::getEffects( |
| 373 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
| 374 | // Check for an effects attribute on the op instance. |
| 375 | ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects" ); |
| 376 | if (!effectsAttr) |
| 377 | return; |
| 378 | |
| 379 | for (Attribute element : effectsAttr) { |
| 380 | DictionaryAttr effectElement = cast<DictionaryAttr>(element); |
| 381 | |
| 382 | // Get the specific memory effect. |
| 383 | MemoryEffects::Effect *effect = |
| 384 | StringSwitch<MemoryEffects::Effect *>( |
| 385 | cast<StringAttr>(effectElement.get("effect" )).getValue()) |
| 386 | .Case("allocate" , MemoryEffects::Allocate::get()) |
| 387 | .Case("free" , MemoryEffects::Free::get()) |
| 388 | .Case("read" , MemoryEffects::Read::get()) |
| 389 | .Case("write" , MemoryEffects::Write::get()); |
| 390 | |
| 391 | // Check for a non-default resource to use. |
| 392 | SideEffects::Resource *resource = SideEffects::DefaultResource::get(); |
| 393 | if (effectElement.get("test_resource" )) |
| 394 | resource = TestResource::get(); |
| 395 | |
| 396 | // Check for a result to affect. |
| 397 | if (effectElement.get("on_result" )) |
| 398 | effects.emplace_back(effect, getOperation()->getOpResults()[0], resource); |
| 399 | else if (effectElement.get("on_operand" )) |
| 400 | effects.emplace_back(effect, &getOperation()->getOpOperands()[0], |
| 401 | resource); |
| 402 | else if (effectElement.get("on_argument" )) |
| 403 | effects.emplace_back(effect, getOperation()->getRegion(0).getArgument(0), |
| 404 | resource); |
| 405 | else if (Attribute ref = effectElement.get("on_reference" )) |
| 406 | effects.emplace_back(effect, cast<SymbolRefAttr>(ref), resource); |
| 407 | else |
| 408 | effects.emplace_back(effect, resource); |
| 409 | } |
| 410 | } |
| 411 | |
| 412 | void SideEffectWithRegionOp::getEffects( |
| 413 | SmallVectorImpl<TestEffects::EffectInstance> &effects) { |
| 414 | testSideEffectOpGetEffect(getOperation(), effects); |
| 415 | } |
| 416 | |
| 417 | //===----------------------------------------------------------------------===// |
| 418 | // StringAttrPrettyNameOp |
| 419 | //===----------------------------------------------------------------------===// |
| 420 | |
| 421 | // This op has fancy handling of its SSA result name. |
| 422 | ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser, |
| 423 | OperationState &result) { |
| 424 | // Add the result types. |
| 425 | for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) |
| 426 | result.addTypes(parser.getBuilder().getIntegerType(32)); |
| 427 | |
| 428 | if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) |
| 429 | return failure(); |
| 430 | |
| 431 | // If the attribute dictionary contains no 'names' attribute, infer it from |
| 432 | // the SSA name (if specified). |
| 433 | bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { |
| 434 | return attr.getName() == "names" ; |
| 435 | }); |
| 436 | |
| 437 | // If there was no name specified, check to see if there was a useful name |
| 438 | // specified in the asm file. |
| 439 | if (hadNames || parser.getNumResults() == 0) |
| 440 | return success(); |
| 441 | |
| 442 | SmallVector<StringRef, 4> names; |
| 443 | auto *context = result.getContext(); |
| 444 | |
| 445 | for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { |
| 446 | auto resultName = parser.getResultName(i); |
| 447 | StringRef nameStr; |
| 448 | if (!resultName.first.empty() && !isdigit(resultName.first[0])) |
| 449 | nameStr = resultName.first; |
| 450 | |
| 451 | names.push_back(nameStr); |
| 452 | } |
| 453 | |
| 454 | auto namesAttr = parser.getBuilder().getStrArrayAttr(names); |
| 455 | result.attributes.push_back({StringAttr::get(context, "names" ), namesAttr}); |
| 456 | return success(); |
| 457 | } |
| 458 | |
| 459 | void StringAttrPrettyNameOp::print(OpAsmPrinter &p) { |
| 460 | // Note that we only need to print the "name" attribute if the asmprinter |
| 461 | // result name disagrees with it. This can happen in strange cases, e.g. |
| 462 | // when there are conflicts. |
| 463 | bool namesDisagree = getNames().size() != getNumResults(); |
| 464 | |
| 465 | SmallString<32> resultNameStr; |
| 466 | for (size_t i = 0, e = getNumResults(); i != e && !namesDisagree; ++i) { |
| 467 | resultNameStr.clear(); |
| 468 | llvm::raw_svector_ostream tmpStream(resultNameStr); |
| 469 | p.printOperand(getResult(i), tmpStream); |
| 470 | |
| 471 | auto expectedName = dyn_cast<StringAttr>(getNames()[i]); |
| 472 | if (!expectedName || |
| 473 | tmpStream.str().drop_front() != expectedName.getValue()) { |
| 474 | namesDisagree = true; |
| 475 | } |
| 476 | } |
| 477 | |
| 478 | if (namesDisagree) |
| 479 | p.printOptionalAttrDictWithKeyword((*this)->getAttrs()); |
| 480 | else |
| 481 | p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names" }); |
| 482 | } |
| 483 | |
| 484 | // We set the SSA name in the asm syntax to the contents of the name |
| 485 | // attribute. |
| 486 | void StringAttrPrettyNameOp::getAsmResultNames( |
| 487 | function_ref<void(Value, StringRef)> setNameFn) { |
| 488 | |
| 489 | auto value = getNames(); |
| 490 | for (size_t i = 0, e = value.size(); i != e; ++i) |
| 491 | if (auto str = dyn_cast<StringAttr>(value[i])) |
| 492 | if (!str.getValue().empty()) |
| 493 | setNameFn(getResult(i), str.getValue()); |
| 494 | } |
| 495 | |
| 496 | //===----------------------------------------------------------------------===// |
| 497 | // CustomResultsNameOp |
| 498 | //===----------------------------------------------------------------------===// |
| 499 | |
| 500 | void CustomResultsNameOp::getAsmResultNames( |
| 501 | function_ref<void(Value, StringRef)> setNameFn) { |
| 502 | ArrayAttr value = getNames(); |
| 503 | for (size_t i = 0, e = value.size(); i != e; ++i) |
| 504 | if (auto str = dyn_cast<StringAttr>(value[i])) |
| 505 | if (!str.empty()) |
| 506 | setNameFn(getResult(i), str.getValue()); |
| 507 | } |
| 508 | |
| 509 | //===----------------------------------------------------------------------===// |
| 510 | // ResultNameFromTypeOp |
| 511 | //===----------------------------------------------------------------------===// |
| 512 | |
| 513 | void ResultNameFromTypeOp::getAsmResultNames( |
| 514 | function_ref<void(Value, StringRef)> setNameFn) { |
| 515 | auto result = getResult(); |
| 516 | auto setResultNameFn = [&](::llvm::StringRef name) { |
| 517 | setNameFn(result, name); |
| 518 | }; |
| 519 | auto opAsmTypeInterface = |
| 520 | ::mlir::cast<::mlir::OpAsmTypeInterface>(result.getType()); |
| 521 | opAsmTypeInterface.getAsmName(setResultNameFn); |
| 522 | } |
| 523 | |
| 524 | //===----------------------------------------------------------------------===// |
| 525 | // BlockArgumentNameFromTypeOp |
| 526 | //===----------------------------------------------------------------------===// |
| 527 | |
| 528 | void BlockArgumentNameFromTypeOp::getAsmBlockArgumentNames( |
| 529 | ::mlir::Region ®ion, ::mlir::OpAsmSetValueNameFn setNameFn) { |
| 530 | for (auto &block : region) { |
| 531 | for (auto arg : block.getArguments()) { |
| 532 | if (auto opAsmTypeInterface = |
| 533 | ::mlir::dyn_cast<::mlir::OpAsmTypeInterface>(arg.getType())) { |
| 534 | auto setArgNameFn = [&](StringRef name) { setNameFn(arg, name); }; |
| 535 | opAsmTypeInterface.getAsmName(setArgNameFn); |
| 536 | } |
| 537 | } |
| 538 | } |
| 539 | } |
| 540 | |
| 541 | //===----------------------------------------------------------------------===// |
| 542 | // ResultTypeWithTraitOp |
| 543 | //===----------------------------------------------------------------------===// |
| 544 | |
| 545 | LogicalResult ResultTypeWithTraitOp::verify() { |
| 546 | if ((*this)->getResultTypes()[0].hasTrait<TypeTrait::TestTypeTrait>()) |
| 547 | return success(); |
| 548 | return emitError("result type should have trait 'TestTypeTrait'" ); |
| 549 | } |
| 550 | |
| 551 | //===----------------------------------------------------------------------===// |
| 552 | // AttrWithTraitOp |
| 553 | //===----------------------------------------------------------------------===// |
| 554 | |
| 555 | LogicalResult AttrWithTraitOp::verify() { |
| 556 | if (getAttr().hasTrait<AttributeTrait::TestAttrTrait>()) |
| 557 | return success(); |
| 558 | return emitError("'attr' attribute should have trait 'TestAttrTrait'" ); |
| 559 | } |
| 560 | |
| 561 | //===----------------------------------------------------------------------===// |
| 562 | // RegionIfOp |
| 563 | //===----------------------------------------------------------------------===// |
| 564 | |
| 565 | void RegionIfOp::print(OpAsmPrinter &p) { |
| 566 | p << " " ; |
| 567 | p.printOperands(getOperands()); |
| 568 | p << ": " << getOperandTypes(); |
| 569 | p.printArrowTypeList(getResultTypes()); |
| 570 | p << " then " ; |
| 571 | p.printRegion(getThenRegion(), |
| 572 | /*printEntryBlockArgs=*/true, |
| 573 | /*printBlockTerminators=*/true); |
| 574 | p << " else " ; |
| 575 | p.printRegion(getElseRegion(), |
| 576 | /*printEntryBlockArgs=*/true, |
| 577 | /*printBlockTerminators=*/true); |
| 578 | p << " join " ; |
| 579 | p.printRegion(getJoinRegion(), |
| 580 | /*printEntryBlockArgs=*/true, |
| 581 | /*printBlockTerminators=*/true); |
| 582 | } |
| 583 | |
| 584 | ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) { |
| 585 | SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfos; |
| 586 | SmallVector<Type, 2> operandTypes; |
| 587 | |
| 588 | result.regions.reserve(3); |
| 589 | Region *thenRegion = result.addRegion(); |
| 590 | Region *elseRegion = result.addRegion(); |
| 591 | Region *joinRegion = result.addRegion(); |
| 592 | |
| 593 | // Parse operand, type and arrow type lists. |
| 594 | if (parser.parseOperandList(operandInfos) || |
| 595 | parser.parseColonTypeList(operandTypes) || |
| 596 | parser.parseArrowTypeList(result.types)) |
| 597 | return failure(); |
| 598 | |
| 599 | // Parse all attached regions. |
| 600 | if (parser.parseKeyword("then" ) || parser.parseRegion(*thenRegion, {}, {}) || |
| 601 | parser.parseKeyword("else" ) || parser.parseRegion(*elseRegion, {}, {}) || |
| 602 | parser.parseKeyword("join" ) || parser.parseRegion(*joinRegion, {}, {})) |
| 603 | return failure(); |
| 604 | |
| 605 | return parser.resolveOperands(operandInfos, operandTypes, |
| 606 | parser.getCurrentLocation(), result.operands); |
| 607 | } |
| 608 | |
| 609 | OperandRange RegionIfOp::getEntrySuccessorOperands(RegionBranchPoint point) { |
| 610 | assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, point) && |
| 611 | "invalid region index" ); |
| 612 | return getOperands(); |
| 613 | } |
| 614 | |
| 615 | void RegionIfOp::getSuccessorRegions( |
| 616 | RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { |
| 617 | // We always branch to the join region. |
| 618 | if (!point.isParent()) { |
| 619 | if (point != getJoinRegion()) |
| 620 | regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs())); |
| 621 | else |
| 622 | regions.push_back(RegionSuccessor(getResults())); |
| 623 | return; |
| 624 | } |
| 625 | |
| 626 | // The then and else regions are the entry regions of this op. |
| 627 | regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs())); |
| 628 | regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs())); |
| 629 | } |
| 630 | |
| 631 | void RegionIfOp::getRegionInvocationBounds( |
| 632 | ArrayRef<Attribute> operands, |
| 633 | SmallVectorImpl<InvocationBounds> &invocationBounds) { |
| 634 | // Each region is invoked at most once. |
| 635 | invocationBounds.assign(/*NumElts=*/3, /*Elt=*/{0, 1}); |
| 636 | } |
| 637 | |
| 638 | //===----------------------------------------------------------------------===// |
| 639 | // AnyCondOp |
| 640 | //===----------------------------------------------------------------------===// |
| 641 | |
| 642 | void AnyCondOp::getSuccessorRegions(RegionBranchPoint point, |
| 643 | SmallVectorImpl<RegionSuccessor> ®ions) { |
| 644 | // The parent op branches into the only region, and the region branches back |
| 645 | // to the parent op. |
| 646 | if (point.isParent()) |
| 647 | regions.emplace_back(&getRegion()); |
| 648 | else |
| 649 | regions.emplace_back(getResults()); |
| 650 | } |
| 651 | |
| 652 | void AnyCondOp::getRegionInvocationBounds( |
| 653 | ArrayRef<Attribute> operands, |
| 654 | SmallVectorImpl<InvocationBounds> &invocationBounds) { |
| 655 | invocationBounds.emplace_back(1, 1); |
| 656 | } |
| 657 | |
| 658 | //===----------------------------------------------------------------------===// |
| 659 | // SingleBlockImplicitTerminatorOp |
| 660 | //===----------------------------------------------------------------------===// |
| 661 | |
| 662 | /// Testing the correctness of some traits. |
| 663 | static_assert( |
| 664 | llvm::is_detected<OpTrait::has_implicit_terminator_t, |
| 665 | SingleBlockImplicitTerminatorOp>::value, |
| 666 | "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp" ); |
| 667 | static_assert(OpTrait::hasSingleBlockImplicitTerminator< |
| 668 | SingleBlockImplicitTerminatorOp>::value, |
| 669 | "hasSingleBlockImplicitTerminator does not match " |
| 670 | "SingleBlockImplicitTerminatorOp" ); |
| 671 | |
| 672 | //===----------------------------------------------------------------------===// |
| 673 | // SingleNoTerminatorCustomAsmOp |
| 674 | //===----------------------------------------------------------------------===// |
| 675 | |
| 676 | ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser, |
| 677 | OperationState &state) { |
| 678 | Region *body = state.addRegion(); |
| 679 | if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) |
| 680 | return failure(); |
| 681 | return success(); |
| 682 | } |
| 683 | |
| 684 | void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) { |
| 685 | printer.printRegion( |
| 686 | getRegion(), /*printEntryBlockArgs=*/false, |
| 687 | // This op has a single block without terminators. But explicitly mark |
| 688 | // as not printing block terminators for testing. |
| 689 | /*printBlockTerminators=*/false); |
| 690 | } |
| 691 | |
| 692 | //===----------------------------------------------------------------------===// |
| 693 | // TestVerifiersOp |
| 694 | //===----------------------------------------------------------------------===// |
| 695 | |
| 696 | LogicalResult TestVerifiersOp::verify() { |
| 697 | if (!getRegion().hasOneBlock()) |
| 698 | return emitOpError("`hasOneBlock` trait hasn't been verified" ); |
| 699 | |
| 700 | Operation *definingOp = getInput().getDefiningOp(); |
| 701 | if (definingOp && failed(mlir::verify(definingOp))) |
| 702 | return emitOpError("operand hasn't been verified" ); |
| 703 | |
| 704 | // Avoid using `emitRemark(msg)` since that will trigger an infinite verifier |
| 705 | // loop. |
| 706 | mlir::emitRemark(getLoc(), "success run of verifier" ); |
| 707 | |
| 708 | return success(); |
| 709 | } |
| 710 | |
| 711 | LogicalResult TestVerifiersOp::verifyRegions() { |
| 712 | if (!getRegion().hasOneBlock()) |
| 713 | return emitOpError("`hasOneBlock` trait hasn't been verified" ); |
| 714 | |
| 715 | for (Block &block : getRegion()) |
| 716 | for (Operation &op : block) |
| 717 | if (failed(mlir::verify(&op))) |
| 718 | return emitOpError("nested op hasn't been verified" ); |
| 719 | |
| 720 | // Avoid using `emitRemark(msg)` since that will trigger an infinite verifier |
| 721 | // loop. |
| 722 | mlir::emitRemark(getLoc(), "success run of region verifier" ); |
| 723 | |
| 724 | return success(); |
| 725 | } |
| 726 | |
| 727 | //===----------------------------------------------------------------------===// |
| 728 | // Test InferIntRangeInterface |
| 729 | //===----------------------------------------------------------------------===// |
| 730 | |
| 731 | //===----------------------------------------------------------------------===// |
| 732 | // TestWithBoundsOp |
| 733 | //===----------------------------------------------------------------------===// |
| 734 | |
| 735 | void TestWithBoundsOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
| 736 | SetIntRangeFn setResultRanges) { |
| 737 | setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()}); |
| 738 | } |
| 739 | |
| 740 | //===----------------------------------------------------------------------===// |
| 741 | // TestWithBoundsRegionOp |
| 742 | //===----------------------------------------------------------------------===// |
| 743 | |
| 744 | ParseResult TestWithBoundsRegionOp::parse(OpAsmParser &parser, |
| 745 | OperationState &result) { |
| 746 | if (parser.parseOptionalAttrDict(result.attributes)) |
| 747 | return failure(); |
| 748 | |
| 749 | // Parse the input argument |
| 750 | OpAsmParser::Argument argInfo; |
| 751 | if (failed(parser.parseArgument(argInfo, true))) |
| 752 | return failure(); |
| 753 | |
| 754 | // Parse the body region, and reuse the operand info as the argument info. |
| 755 | Region *body = result.addRegion(); |
| 756 | return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/false); |
| 757 | } |
| 758 | |
| 759 | void TestWithBoundsRegionOp::print(OpAsmPrinter &p) { |
| 760 | p.printOptionalAttrDict((*this)->getAttrs()); |
| 761 | p << ' '; |
| 762 | p.printRegionArgument(getRegion().getArgument(0), /*argAttrs=*/{}, |
| 763 | /*omitType=*/false); |
| 764 | p << ' '; |
| 765 | p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); |
| 766 | } |
| 767 | |
| 768 | void TestWithBoundsRegionOp::inferResultRanges( |
| 769 | ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) { |
| 770 | Value arg = getRegion().getArgument(0); |
| 771 | setResultRanges(arg, {getUmin(), getUmax(), getSmin(), getSmax()}); |
| 772 | } |
| 773 | |
| 774 | //===----------------------------------------------------------------------===// |
| 775 | // TestIncrementOp |
| 776 | //===----------------------------------------------------------------------===// |
| 777 | |
| 778 | void TestIncrementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
| 779 | SetIntRangeFn setResultRanges) { |
| 780 | const ConstantIntRanges &range = argRanges[0]; |
| 781 | APInt one(range.umin().getBitWidth(), 1); |
| 782 | setResultRanges(getResult(), |
| 783 | {range.umin().uadd_sat(one), range.umax().uadd_sat(one), |
| 784 | range.smin().sadd_sat(one), range.smax().sadd_sat(one)}); |
| 785 | } |
| 786 | |
| 787 | //===----------------------------------------------------------------------===// |
| 788 | // TestReflectBoundsOp |
| 789 | //===----------------------------------------------------------------------===// |
| 790 | |
| 791 | void TestReflectBoundsOp::inferResultRanges( |
| 792 | ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) { |
| 793 | const ConstantIntRanges &range = argRanges[0]; |
| 794 | MLIRContext *ctx = getContext(); |
| 795 | Builder b(ctx); |
| 796 | Type sIntTy, uIntTy; |
| 797 | // For plain `IntegerType`s, we can derive the appropriate signed and unsigned |
| 798 | // Types for the Attributes. |
| 799 | Type type = getElementTypeOrSelf(getType()); |
| 800 | if (auto intTy = llvm::dyn_cast<IntegerType>(type)) { |
| 801 | unsigned bitwidth = intTy.getWidth(); |
| 802 | sIntTy = b.getIntegerType(bitwidth, /*isSigned=*/true); |
| 803 | uIntTy = b.getIntegerType(bitwidth, /*isSigned=*/false); |
| 804 | } else { |
| 805 | sIntTy = uIntTy = type; |
| 806 | } |
| 807 | |
| 808 | setUminAttr(b.getIntegerAttr(uIntTy, range.umin())); |
| 809 | setUmaxAttr(b.getIntegerAttr(uIntTy, range.umax())); |
| 810 | setSminAttr(b.getIntegerAttr(sIntTy, range.smin())); |
| 811 | setSmaxAttr(b.getIntegerAttr(sIntTy, range.smax())); |
| 812 | setResultRanges(getResult(), range); |
| 813 | } |
| 814 | |
| 815 | //===----------------------------------------------------------------------===// |
| 816 | // ConversionFuncOp |
| 817 | //===----------------------------------------------------------------------===// |
| 818 | |
| 819 | ParseResult ConversionFuncOp::parse(OpAsmParser &parser, |
| 820 | OperationState &result) { |
| 821 | auto buildFuncType = |
| 822 | [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, |
| 823 | function_interface_impl::VariadicFlag, |
| 824 | std::string &) { return builder.getFunctionType(argTypes, results); }; |
| 825 | |
| 826 | return function_interface_impl::parseFunctionOp( |
| 827 | parser, result, /*allowVariadic=*/false, |
| 828 | getFunctionTypeAttrName(result.name), buildFuncType, |
| 829 | getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); |
| 830 | } |
| 831 | |
| 832 | void ConversionFuncOp::print(OpAsmPrinter &p) { |
| 833 | function_interface_impl::printFunctionOp( |
| 834 | p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), |
| 835 | getArgAttrsAttrName(), getResAttrsAttrName()); |
| 836 | } |
| 837 | |
| 838 | //===----------------------------------------------------------------------===// |
| 839 | // ReifyBoundOp |
| 840 | //===----------------------------------------------------------------------===// |
| 841 | |
| 842 | mlir::presburger::BoundType ReifyBoundOp::getBoundType() { |
| 843 | if (getType() == "EQ" ) |
| 844 | return mlir::presburger::BoundType::EQ; |
| 845 | if (getType() == "LB" ) |
| 846 | return mlir::presburger::BoundType::LB; |
| 847 | if (getType() == "UB" ) |
| 848 | return mlir::presburger::BoundType::UB; |
| 849 | llvm_unreachable("invalid bound type" ); |
| 850 | } |
| 851 | |
| 852 | LogicalResult ReifyBoundOp::verify() { |
| 853 | if (isa<ShapedType>(getVar().getType())) { |
| 854 | if (!getDim().has_value()) |
| 855 | return emitOpError("expected 'dim' attribute for shaped type variable" ); |
| 856 | } else if (getVar().getType().isIndex()) { |
| 857 | if (getDim().has_value()) |
| 858 | return emitOpError("unexpected 'dim' attribute for index variable" ); |
| 859 | } else { |
| 860 | return emitOpError("expected index-typed variable or shape type variable" ); |
| 861 | } |
| 862 | if (getConstant() && getScalable()) |
| 863 | return emitOpError("'scalable' and 'constant' are mutually exlusive" ); |
| 864 | if (getScalable() != getVscaleMin().has_value()) |
| 865 | return emitOpError("expected 'vscale_min' if and only if 'scalable'" ); |
| 866 | if (getScalable() != getVscaleMax().has_value()) |
| 867 | return emitOpError("expected 'vscale_min' if and only if 'scalable'" ); |
| 868 | return success(); |
| 869 | } |
| 870 | |
| 871 | ValueBoundsConstraintSet::Variable ReifyBoundOp::getVariable() { |
| 872 | if (getDim().has_value()) |
| 873 | return ValueBoundsConstraintSet::Variable(getVar(), *getDim()); |
| 874 | return ValueBoundsConstraintSet::Variable(getVar()); |
| 875 | } |
| 876 | |
| 877 | //===----------------------------------------------------------------------===// |
| 878 | // CompareOp |
| 879 | //===----------------------------------------------------------------------===// |
| 880 | |
| 881 | ValueBoundsConstraintSet::ComparisonOperator |
| 882 | CompareOp::getComparisonOperator() { |
| 883 | if (getCmp() == "EQ" ) |
| 884 | return ValueBoundsConstraintSet::ComparisonOperator::EQ; |
| 885 | if (getCmp() == "LT" ) |
| 886 | return ValueBoundsConstraintSet::ComparisonOperator::LT; |
| 887 | if (getCmp() == "LE" ) |
| 888 | return ValueBoundsConstraintSet::ComparisonOperator::LE; |
| 889 | if (getCmp() == "GT" ) |
| 890 | return ValueBoundsConstraintSet::ComparisonOperator::GT; |
| 891 | if (getCmp() == "GE" ) |
| 892 | return ValueBoundsConstraintSet::ComparisonOperator::GE; |
| 893 | llvm_unreachable("invalid comparison operator" ); |
| 894 | } |
| 895 | |
| 896 | mlir::ValueBoundsConstraintSet::Variable CompareOp::getLhs() { |
| 897 | if (!getLhsMap()) |
| 898 | return ValueBoundsConstraintSet::Variable(getVarOperands()[0]); |
| 899 | SmallVector<Value> mapOperands( |
| 900 | getVarOperands().slice(0, getLhsMap()->getNumInputs())); |
| 901 | return ValueBoundsConstraintSet::Variable(*getLhsMap(), mapOperands); |
| 902 | } |
| 903 | |
| 904 | mlir::ValueBoundsConstraintSet::Variable CompareOp::getRhs() { |
| 905 | int64_t rhsOperandsBegin = getLhsMap() ? getLhsMap()->getNumInputs() : 1; |
| 906 | if (!getRhsMap()) |
| 907 | return ValueBoundsConstraintSet::Variable( |
| 908 | getVarOperands()[rhsOperandsBegin]); |
| 909 | SmallVector<Value> mapOperands( |
| 910 | getVarOperands().slice(rhsOperandsBegin, getRhsMap()->getNumInputs())); |
| 911 | return ValueBoundsConstraintSet::Variable(*getRhsMap(), mapOperands); |
| 912 | } |
| 913 | |
| 914 | LogicalResult CompareOp::verify() { |
| 915 | if (getCompose() && (getLhsMap() || getRhsMap())) |
| 916 | return emitOpError( |
| 917 | "'compose' not supported when 'lhs_map' or 'rhs_map' is present" ); |
| 918 | int64_t expectedNumOperands = getLhsMap() ? getLhsMap()->getNumInputs() : 1; |
| 919 | expectedNumOperands += getRhsMap() ? getRhsMap()->getNumInputs() : 1; |
| 920 | if (getVarOperands().size() != size_t(expectedNumOperands)) |
| 921 | return emitOpError("expected " ) |
| 922 | << expectedNumOperands << " operands, but got " |
| 923 | << getVarOperands().size(); |
| 924 | return success(); |
| 925 | } |
| 926 | |
| 927 | //===----------------------------------------------------------------------===// |
| 928 | // TestOpInPlaceSelfFold |
| 929 | //===----------------------------------------------------------------------===// |
| 930 | |
| 931 | OpFoldResult TestOpInPlaceSelfFold::fold(FoldAdaptor adaptor) { |
| 932 | if (!getFolded()) { |
| 933 | // The folder adds the "folded" if not present. |
| 934 | setFolded(true); |
| 935 | return getResult(); |
| 936 | } |
| 937 | return {}; |
| 938 | } |
| 939 | |
| 940 | //===----------------------------------------------------------------------===// |
| 941 | // TestOpFoldWithFoldAdaptor |
| 942 | //===----------------------------------------------------------------------===// |
| 943 | |
| 944 | OpFoldResult TestOpFoldWithFoldAdaptor::fold(FoldAdaptor adaptor) { |
| 945 | int64_t sum = 0; |
| 946 | if (auto value = dyn_cast_or_null<IntegerAttr>(adaptor.getOp())) |
| 947 | sum += value.getValue().getSExtValue(); |
| 948 | |
| 949 | for (Attribute attr : adaptor.getVariadic()) |
| 950 | if (auto value = dyn_cast_or_null<IntegerAttr>(attr)) |
| 951 | sum += 2 * value.getValue().getSExtValue(); |
| 952 | |
| 953 | for (ArrayRef<Attribute> attrs : adaptor.getVarOfVar()) |
| 954 | for (Attribute attr : attrs) |
| 955 | if (auto value = dyn_cast_or_null<IntegerAttr>(attr)) |
| 956 | sum += 3 * value.getValue().getSExtValue(); |
| 957 | |
| 958 | sum += 4 * std::distance(adaptor.getBody().begin(), adaptor.getBody().end()); |
| 959 | |
| 960 | return IntegerAttr::get(getType(), sum); |
| 961 | } |
| 962 | |
| 963 | //===----------------------------------------------------------------------===// |
| 964 | // OpWithInferTypeAdaptorInterfaceOp |
| 965 | //===----------------------------------------------------------------------===// |
| 966 | |
| 967 | LogicalResult OpWithInferTypeAdaptorInterfaceOp::inferReturnTypes( |
| 968 | MLIRContext *, std::optional<Location> location, |
| 969 | OpWithInferTypeAdaptorInterfaceOp::Adaptor adaptor, |
| 970 | SmallVectorImpl<Type> &inferredReturnTypes) { |
| 971 | if (adaptor.getX().getType() != adaptor.getY().getType()) { |
| 972 | return emitOptionalError(location, "operand type mismatch " , |
| 973 | adaptor.getX().getType(), " vs " , |
| 974 | adaptor.getY().getType()); |
| 975 | } |
| 976 | inferredReturnTypes.assign({adaptor.getX().getType()}); |
| 977 | return success(); |
| 978 | } |
| 979 | |
| 980 | //===----------------------------------------------------------------------===// |
| 981 | // OpWithRefineTypeInterfaceOp |
| 982 | //===----------------------------------------------------------------------===// |
| 983 | |
| 984 | // TODO: We should be able to only define either inferReturnType or |
| 985 | // refineReturnType, currently only refineReturnType can be omitted. |
| 986 | LogicalResult OpWithRefineTypeInterfaceOp::inferReturnTypes( |
| 987 | MLIRContext *context, std::optional<Location> location, ValueRange operands, |
| 988 | DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, |
| 989 | SmallVectorImpl<Type> &returnTypes) { |
| 990 | returnTypes.clear(); |
| 991 | return OpWithRefineTypeInterfaceOp::refineReturnTypes( |
| 992 | context, location, operands, attributes, properties, regions, |
| 993 | returnTypes); |
| 994 | } |
| 995 | |
| 996 | LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes( |
| 997 | MLIRContext *, std::optional<Location> location, ValueRange operands, |
| 998 | DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, |
| 999 | SmallVectorImpl<Type> &returnTypes) { |
| 1000 | if (operands[0].getType() != operands[1].getType()) { |
| 1001 | return emitOptionalError(location, "operand type mismatch " , |
| 1002 | operands[0].getType(), " vs " , |
| 1003 | operands[1].getType()); |
| 1004 | } |
| 1005 | // TODO: Add helper to make this more concise to write. |
| 1006 | if (returnTypes.empty()) |
| 1007 | returnTypes.resize(1, nullptr); |
| 1008 | if (returnTypes[0] && returnTypes[0] != operands[0].getType()) |
| 1009 | return emitOptionalError(location, |
| 1010 | "required first operand and result to match" ); |
| 1011 | returnTypes[0] = operands[0].getType(); |
| 1012 | return success(); |
| 1013 | } |
| 1014 | |
| 1015 | //===----------------------------------------------------------------------===// |
| 1016 | // OpWithShapedTypeInferTypeAdaptorInterfaceOp |
| 1017 | //===----------------------------------------------------------------------===// |
| 1018 | |
| 1019 | LogicalResult |
| 1020 | OpWithShapedTypeInferTypeAdaptorInterfaceOp::inferReturnTypeComponents( |
| 1021 | MLIRContext *context, std::optional<Location> location, |
| 1022 | OpWithShapedTypeInferTypeAdaptorInterfaceOp::Adaptor adaptor, |
| 1023 | SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { |
| 1024 | // Create return type consisting of the last element of the first operand. |
| 1025 | auto operandType = adaptor.getOperand1().getType(); |
| 1026 | auto sval = dyn_cast<ShapedType>(operandType); |
| 1027 | if (!sval) |
| 1028 | return emitOptionalError(location, "only shaped type operands allowed" ); |
| 1029 | int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic; |
| 1030 | auto type = IntegerType::get(context, 17); |
| 1031 | |
| 1032 | Attribute encoding; |
| 1033 | if (auto rankedTy = dyn_cast<RankedTensorType>(sval)) |
| 1034 | encoding = rankedTy.getEncoding(); |
| 1035 | inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding)); |
| 1036 | return success(); |
| 1037 | } |
| 1038 | |
| 1039 | LogicalResult |
| 1040 | OpWithShapedTypeInferTypeAdaptorInterfaceOp::reifyReturnTypeShapes( |
| 1041 | OpBuilder &builder, ValueRange operands, |
| 1042 | llvm::SmallVectorImpl<Value> &shapes) { |
| 1043 | shapes = SmallVector<Value, 1>{ |
| 1044 | builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)}; |
| 1045 | return success(); |
| 1046 | } |
| 1047 | |
| 1048 | //===----------------------------------------------------------------------===// |
| 1049 | // TestOpWithPropertiesAndInferredType |
| 1050 | //===----------------------------------------------------------------------===// |
| 1051 | |
| 1052 | LogicalResult TestOpWithPropertiesAndInferredType::inferReturnTypes( |
| 1053 | MLIRContext *context, std::optional<Location>, ValueRange operands, |
| 1054 | DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, |
| 1055 | SmallVectorImpl<Type> &inferredReturnTypes) { |
| 1056 | |
| 1057 | Adaptor adaptor(operands, attributes, properties, regions); |
| 1058 | inferredReturnTypes.push_back(IntegerType::get( |
| 1059 | context, adaptor.getLhs() + adaptor.getProperties().rhs)); |
| 1060 | return success(); |
| 1061 | } |
| 1062 | |
| 1063 | //===----------------------------------------------------------------------===// |
| 1064 | // LoopBlockOp |
| 1065 | //===----------------------------------------------------------------------===// |
| 1066 | |
| 1067 | void LoopBlockOp::getSuccessorRegions( |
| 1068 | RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { |
| 1069 | regions.emplace_back(&getBody(), getBody().getArguments()); |
| 1070 | if (point.isParent()) |
| 1071 | return; |
| 1072 | |
| 1073 | regions.emplace_back((*this)->getResults()); |
| 1074 | } |
| 1075 | |
| 1076 | OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) { |
| 1077 | assert(point == getBody()); |
| 1078 | return MutableOperandRange(getInitMutable()); |
| 1079 | } |
| 1080 | |
| 1081 | //===----------------------------------------------------------------------===// |
| 1082 | // LoopBlockTerminatorOp |
| 1083 | //===----------------------------------------------------------------------===// |
| 1084 | |
| 1085 | MutableOperandRange |
| 1086 | LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point) { |
| 1087 | if (point.isParent()) |
| 1088 | return getExitArgMutable(); |
| 1089 | return getNextIterArgMutable(); |
| 1090 | } |
| 1091 | |
| 1092 | //===----------------------------------------------------------------------===// |
// TestNoTerminatorOp
| 1094 | //===----------------------------------------------------------------------===// |
| 1095 | |
| 1096 | void TestNoTerminatorOp::getSuccessorRegions( |
| 1097 | RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {} |
| 1098 | |
| 1099 | //===----------------------------------------------------------------------===// |
// ManualCppOpWithFold
| 1101 | //===----------------------------------------------------------------------===// |
| 1102 | |
| 1103 | OpFoldResult ManualCppOpWithFold::fold(ArrayRef<Attribute> attributes) { |
| 1104 | // Just a simple fold for testing purposes that reads an operands constant |
| 1105 | // value and returns it. |
| 1106 | if (!attributes.empty()) |
| 1107 | return attributes.front(); |
| 1108 | return nullptr; |
| 1109 | } |
| 1110 | |
| 1111 | //===----------------------------------------------------------------------===// |
| 1112 | // Tensor/Buffer Ops |
| 1113 | //===----------------------------------------------------------------------===// |
| 1114 | |
| 1115 | void ReadBufferOp::getEffects( |
| 1116 | SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> |
| 1117 | &effects) { |
| 1118 | // The buffer operand is read. |
| 1119 | effects.emplace_back(MemoryEffects::Read::get(), &getBufferMutable(), |
| 1120 | SideEffects::DefaultResource::get()); |
| 1121 | // The buffer contents are dumped. |
| 1122 | effects.emplace_back(MemoryEffects::Write::get(), |
| 1123 | SideEffects::DefaultResource::get()); |
| 1124 | } |
| 1125 | |
| 1126 | //===----------------------------------------------------------------------===// |
| 1127 | // Test Dataflow |
| 1128 | //===----------------------------------------------------------------------===// |
| 1129 | |
| 1130 | //===----------------------------------------------------------------------===// |
| 1131 | // TestCallAndStoreOp |
| 1132 | //===----------------------------------------------------------------------===// |
| 1133 | |
| 1134 | CallInterfaceCallable TestCallAndStoreOp::getCallableForCallee() { |
| 1135 | return getCallee(); |
| 1136 | } |
| 1137 | |
| 1138 | void TestCallAndStoreOp::setCalleeFromCallable(CallInterfaceCallable callee) { |
| 1139 | setCalleeAttr(cast<SymbolRefAttr>(callee)); |
| 1140 | } |
| 1141 | |
| 1142 | Operation::operand_range TestCallAndStoreOp::getArgOperands() { |
| 1143 | return getCalleeOperands(); |
| 1144 | } |
| 1145 | |
| 1146 | MutableOperandRange TestCallAndStoreOp::getArgOperandsMutable() { |
| 1147 | return getCalleeOperandsMutable(); |
| 1148 | } |
| 1149 | |
| 1150 | //===----------------------------------------------------------------------===// |
| 1151 | // TestCallOnDeviceOp |
| 1152 | //===----------------------------------------------------------------------===// |
| 1153 | |
| 1154 | CallInterfaceCallable TestCallOnDeviceOp::getCallableForCallee() { |
| 1155 | return getCallee(); |
| 1156 | } |
| 1157 | |
| 1158 | void TestCallOnDeviceOp::setCalleeFromCallable(CallInterfaceCallable callee) { |
| 1159 | setCalleeAttr(cast<SymbolRefAttr>(callee)); |
| 1160 | } |
| 1161 | |
| 1162 | Operation::operand_range TestCallOnDeviceOp::getArgOperands() { |
| 1163 | return getForwardedOperands(); |
| 1164 | } |
| 1165 | |
| 1166 | MutableOperandRange TestCallOnDeviceOp::getArgOperandsMutable() { |
| 1167 | return getForwardedOperandsMutable(); |
| 1168 | } |
| 1169 | |
| 1170 | //===----------------------------------------------------------------------===// |
| 1171 | // TestStoreWithARegion |
| 1172 | //===----------------------------------------------------------------------===// |
| 1173 | |
| 1174 | void TestStoreWithARegion::getSuccessorRegions( |
| 1175 | RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { |
| 1176 | if (point.isParent()) |
| 1177 | regions.emplace_back(&getBody(), getBody().front().getArguments()); |
| 1178 | else |
| 1179 | regions.emplace_back(); |
| 1180 | } |
| 1181 | |
| 1182 | //===----------------------------------------------------------------------===// |
| 1183 | // TestStoreWithALoopRegion |
| 1184 | //===----------------------------------------------------------------------===// |
| 1185 | |
| 1186 | void TestStoreWithALoopRegion::getSuccessorRegions( |
| 1187 | RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { |
| 1188 | // Both the operation itself and the region may be branching into the body or |
| 1189 | // back into the operation itself. It is possible for the operation not to |
| 1190 | // enter the body. |
| 1191 | regions.emplace_back( |
| 1192 | RegionSuccessor(&getBody(), getBody().front().getArguments())); |
| 1193 | regions.emplace_back(); |
| 1194 | } |
| 1195 | |
| 1196 | //===----------------------------------------------------------------------===// |
| 1197 | // TestVersionedOpA |
| 1198 | //===----------------------------------------------------------------------===// |
| 1199 | |
| 1200 | LogicalResult |
| 1201 | TestVersionedOpA::readProperties(mlir::DialectBytecodeReader &reader, |
| 1202 | mlir::OperationState &state) { |
| 1203 | auto &prop = state.getOrAddProperties<Properties>(); |
| 1204 | if (mlir::failed(reader.readAttribute(prop.dims))) |
| 1205 | return mlir::failure(); |
| 1206 | |
| 1207 | // Check if we have a version. If not, assume we are parsing the current |
| 1208 | // version. |
| 1209 | auto maybeVersion = reader.getDialectVersion<test::TestDialect>(); |
| 1210 | if (succeeded(maybeVersion)) { |
| 1211 | // If version is less than 2.0, there is no additional attribute to parse. |
| 1212 | // We can materialize missing properties post parsing before verification. |
| 1213 | const auto *version = |
| 1214 | reinterpret_cast<const TestDialectVersion *>(*maybeVersion); |
| 1215 | if ((version->major_ < 2)) { |
| 1216 | return success(); |
| 1217 | } |
| 1218 | } |
| 1219 | |
| 1220 | if (mlir::failed(reader.readAttribute(prop.modifier))) |
| 1221 | return mlir::failure(); |
| 1222 | return mlir::success(); |
| 1223 | } |
| 1224 | |
| 1225 | void TestVersionedOpA::writeProperties(mlir::DialectBytecodeWriter &writer) { |
| 1226 | auto &prop = getProperties(); |
| 1227 | writer.writeAttribute(prop.dims); |
| 1228 | |
| 1229 | auto maybeVersion = writer.getDialectVersion<test::TestDialect>(); |
| 1230 | if (succeeded(maybeVersion)) { |
| 1231 | // If version is less than 2.0, there is no additional attribute to write. |
| 1232 | const auto *version = |
| 1233 | reinterpret_cast<const TestDialectVersion *>(*maybeVersion); |
| 1234 | if ((version->major_ < 2)) { |
| 1235 | llvm::outs() << "downgrading op properties...\n" ; |
| 1236 | return; |
| 1237 | } |
| 1238 | } |
| 1239 | writer.writeAttribute(prop.modifier); |
| 1240 | } |
| 1241 | |
| 1242 | //===----------------------------------------------------------------------===// |
| 1243 | // TestOpWithVersionedProperties |
| 1244 | //===----------------------------------------------------------------------===// |
| 1245 | |
| 1246 | llvm::LogicalResult TestOpWithVersionedProperties::readFromMlirBytecode( |
| 1247 | mlir::DialectBytecodeReader &reader, test::VersionedProperties &prop) { |
| 1248 | uint64_t value1, value2 = 0; |
| 1249 | if (failed(reader.readVarInt(value1))) |
| 1250 | return failure(); |
| 1251 | |
| 1252 | // Check if we have a version. If not, assume we are parsing the current |
| 1253 | // version. |
| 1254 | auto maybeVersion = reader.getDialectVersion<test::TestDialect>(); |
| 1255 | bool needToParseAnotherInt = true; |
| 1256 | if (succeeded(maybeVersion)) { |
| 1257 | // If version is less than 2.0, there is no additional attribute to parse. |
| 1258 | // We can materialize missing properties post parsing before verification. |
| 1259 | const auto *version = |
| 1260 | reinterpret_cast<const TestDialectVersion *>(*maybeVersion); |
| 1261 | if ((version->major_ < 2)) |
| 1262 | needToParseAnotherInt = false; |
| 1263 | } |
| 1264 | if (needToParseAnotherInt && failed(reader.readVarInt(value2))) |
| 1265 | return failure(); |
| 1266 | |
| 1267 | prop.value1 = value1; |
| 1268 | prop.value2 = value2; |
| 1269 | return success(); |
| 1270 | } |
| 1271 | |
| 1272 | void TestOpWithVersionedProperties::writeToMlirBytecode( |
| 1273 | mlir::DialectBytecodeWriter &writer, |
| 1274 | const test::VersionedProperties &prop) { |
| 1275 | writer.writeVarInt(prop.value1); |
| 1276 | writer.writeVarInt(prop.value2); |
| 1277 | } |
| 1278 | |
| 1279 | //===----------------------------------------------------------------------===// |
| 1280 | // TestMultiSlotAlloca |
| 1281 | //===----------------------------------------------------------------------===// |
| 1282 | |
| 1283 | llvm::SmallVector<MemorySlot> TestMultiSlotAlloca::getPromotableSlots() { |
| 1284 | SmallVector<MemorySlot> slots; |
| 1285 | for (Value result : getResults()) { |
| 1286 | slots.push_back(MemorySlot{ |
| 1287 | result, cast<MemRefType>(result.getType()).getElementType()}); |
| 1288 | } |
| 1289 | return slots; |
| 1290 | } |
| 1291 | |
| 1292 | Value TestMultiSlotAlloca::getDefaultValue(const MemorySlot &slot, |
| 1293 | OpBuilder &builder) { |
| 1294 | return builder.create<TestOpConstant>(getLoc(), slot.elemType, |
| 1295 | builder.getI32IntegerAttr(42)); |
| 1296 | } |
| 1297 | |
| 1298 | void TestMultiSlotAlloca::handleBlockArgument(const MemorySlot &slot, |
| 1299 | BlockArgument argument, |
| 1300 | OpBuilder &builder) { |
| 1301 | // Not relevant for testing. |
| 1302 | } |
| 1303 | |
| 1304 | /// Creates a new TestMultiSlotAlloca operation, just without the `slot`. |
| 1305 | static std::optional<TestMultiSlotAlloca> |
| 1306 | createNewMultiAllocaWithoutSlot(const MemorySlot &slot, OpBuilder &builder, |
| 1307 | TestMultiSlotAlloca oldOp) { |
| 1308 | |
| 1309 | if (oldOp.getNumResults() == 1) { |
| 1310 | oldOp.erase(); |
| 1311 | return std::nullopt; |
| 1312 | } |
| 1313 | |
| 1314 | SmallVector<Type> newTypes; |
| 1315 | SmallVector<Value> remainingValues; |
| 1316 | |
| 1317 | for (Value oldResult : oldOp.getResults()) { |
| 1318 | if (oldResult == slot.ptr) |
| 1319 | continue; |
| 1320 | remainingValues.push_back(oldResult); |
| 1321 | newTypes.push_back(oldResult.getType()); |
| 1322 | } |
| 1323 | |
| 1324 | OpBuilder::InsertionGuard guard(builder); |
| 1325 | builder.setInsertionPoint(oldOp); |
| 1326 | auto replacement = |
| 1327 | builder.create<TestMultiSlotAlloca>(oldOp->getLoc(), newTypes); |
| 1328 | for (auto [oldResult, newResult] : |
| 1329 | llvm::zip_equal(remainingValues, replacement.getResults())) |
| 1330 | oldResult.replaceAllUsesWith(newResult); |
| 1331 | |
| 1332 | oldOp.erase(); |
| 1333 | return replacement; |
| 1334 | } |
| 1335 | |
| 1336 | std::optional<PromotableAllocationOpInterface> |
| 1337 | TestMultiSlotAlloca::handlePromotionComplete(const MemorySlot &slot, |
| 1338 | Value defaultValue, |
| 1339 | OpBuilder &builder) { |
| 1340 | if (defaultValue && defaultValue.use_empty()) |
| 1341 | defaultValue.getDefiningOp()->erase(); |
| 1342 | return createNewMultiAllocaWithoutSlot(slot, builder, *this); |
| 1343 | } |
| 1344 | |
| 1345 | SmallVector<DestructurableMemorySlot> |
| 1346 | TestMultiSlotAlloca::getDestructurableSlots() { |
| 1347 | SmallVector<DestructurableMemorySlot> slots; |
| 1348 | for (Value result : getResults()) { |
| 1349 | auto memrefType = cast<MemRefType>(result.getType()); |
| 1350 | auto destructurable = dyn_cast<DestructurableTypeInterface>(memrefType); |
| 1351 | if (!destructurable) |
| 1352 | continue; |
| 1353 | |
| 1354 | std::optional<DenseMap<Attribute, Type>> destructuredType = |
| 1355 | destructurable.getSubelementIndexMap(); |
| 1356 | if (!destructuredType) |
| 1357 | continue; |
| 1358 | slots.emplace_back( |
| 1359 | DestructurableMemorySlot{{result, memrefType}, *destructuredType}); |
| 1360 | } |
| 1361 | return slots; |
| 1362 | } |
| 1363 | |
| 1364 | DenseMap<Attribute, MemorySlot> TestMultiSlotAlloca::destructure( |
| 1365 | const DestructurableMemorySlot &slot, |
| 1366 | const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder, |
| 1367 | SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) { |
| 1368 | OpBuilder::InsertionGuard guard(builder); |
| 1369 | builder.setInsertionPointAfter(*this); |
| 1370 | |
| 1371 | DenseMap<Attribute, MemorySlot> slotMap; |
| 1372 | |
| 1373 | for (Attribute usedIndex : usedIndices) { |
| 1374 | Type elemType = slot.subelementTypes.lookup(usedIndex); |
| 1375 | MemRefType elemPtr = MemRefType::get({}, elemType); |
| 1376 | auto subAlloca = builder.create<TestMultiSlotAlloca>(getLoc(), elemPtr); |
| 1377 | newAllocators.push_back(subAlloca); |
| 1378 | slotMap.try_emplace<MemorySlot>(usedIndex, |
| 1379 | {subAlloca.getResult(0), elemType}); |
| 1380 | } |
| 1381 | |
| 1382 | return slotMap; |
| 1383 | } |
| 1384 | |
| 1385 | std::optional<DestructurableAllocationOpInterface> |
| 1386 | TestMultiSlotAlloca::handleDestructuringComplete( |
| 1387 | const DestructurableMemorySlot &slot, OpBuilder &builder) { |
| 1388 | return createNewMultiAllocaWithoutSlot(slot, builder, *this); |
| 1389 | } |
| 1390 | |