| 1 | //===- TestOpDefs.cpp - MLIR Test Dialect Operation Hooks -----------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "TestDialect.h" |
| 10 | #include "TestOps.h" |
| 11 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
| 12 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| 13 | #include "mlir/IR/Verifier.h" |
| 14 | #include "mlir/Interfaces/FunctionImplementation.h" |
| 15 | #include "mlir/Interfaces/MemorySlotInterfaces.h" |
| 16 | |
| 17 | using namespace mlir; |
| 18 | using namespace test; |
| 19 | |
| 20 | //===----------------------------------------------------------------------===// |
| 21 | // TestBranchOp |
| 22 | //===----------------------------------------------------------------------===// |
| 23 | |
| 24 | SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) { |
| 25 | assert(index == 0 && "invalid successor index" ); |
| 26 | return SuccessorOperands(getTargetOperandsMutable()); |
| 27 | } |
| 28 | |
| 29 | //===----------------------------------------------------------------------===// |
| 30 | // TestProducingBranchOp |
| 31 | //===----------------------------------------------------------------------===// |
| 32 | |
| 33 | SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) { |
| 34 | assert(index <= 1 && "invalid successor index" ); |
| 35 | if (index == 1) |
| 36 | return SuccessorOperands(getFirstOperandsMutable()); |
| 37 | return SuccessorOperands(getSecondOperandsMutable()); |
| 38 | } |
| 39 | |
| 40 | //===----------------------------------------------------------------------===// |
| 41 | // TestInternalBranchOp |
| 42 | //===----------------------------------------------------------------------===// |
| 43 | |
| 44 | SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) { |
| 45 | assert(index <= 1 && "invalid successor index" ); |
| 46 | if (index == 0) |
| 47 | return SuccessorOperands(0, getSuccessOperandsMutable()); |
| 48 | return SuccessorOperands(1, getErrorOperandsMutable()); |
| 49 | } |
| 50 | |
| 51 | //===----------------------------------------------------------------------===// |
| 52 | // TestCallOp |
| 53 | //===----------------------------------------------------------------------===// |
| 54 | |
| 55 | LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
| 56 | // Check that the callee attribute was specified. |
| 57 | auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>(name: "callee" ); |
| 58 | if (!fnAttr) |
| 59 | return emitOpError(message: "requires a 'callee' symbol reference attribute" ); |
| 60 | if (!symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(from: *this, symbol: fnAttr)) |
| 61 | return emitOpError() << "'" << fnAttr.getValue() |
| 62 | << "' does not reference a valid function" ; |
| 63 | return success(); |
| 64 | } |
| 65 | |
| 66 | //===----------------------------------------------------------------------===// |
| 67 | // FoldToCallOp |
| 68 | //===----------------------------------------------------------------------===// |
| 69 | |
| 70 | namespace { |
| 71 | struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { |
| 72 | using OpRewritePattern<FoldToCallOp>::OpRewritePattern; |
| 73 | |
| 74 | LogicalResult matchAndRewrite(FoldToCallOp op, |
| 75 | PatternRewriter &rewriter) const override { |
| 76 | rewriter.replaceOpWithNewOp<func::CallOp>(op, args: TypeRange(), |
| 77 | args: op.getCalleeAttr(), args: ValueRange()); |
| 78 | return success(); |
| 79 | } |
| 80 | }; |
| 81 | } // namespace |
| 82 | |
| 83 | void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| 84 | MLIRContext *context) { |
| 85 | results.add<FoldToCallOpPattern>(arg&: context); |
| 86 | } |
| 87 | |
| 88 | //===----------------------------------------------------------------------===// |
| 89 | // IsolatedRegionOp - test parsing passthrough operands |
| 90 | //===----------------------------------------------------------------------===// |
| 91 | |
| 92 | ParseResult IsolatedRegionOp::parse(OpAsmParser &parser, |
| 93 | OperationState &result) { |
| 94 | // Parse the input operand. |
| 95 | OpAsmParser::Argument argInfo; |
| 96 | argInfo.type = parser.getBuilder().getIndexType(); |
| 97 | if (parser.parseOperand(result&: argInfo.ssaName) || |
| 98 | parser.resolveOperand(operand: argInfo.ssaName, type: argInfo.type, result&: result.operands)) |
| 99 | return failure(); |
| 100 | |
| 101 | // Parse the body region, and reuse the operand info as the argument info. |
| 102 | Region *body = result.addRegion(); |
| 103 | return parser.parseRegion(region&: *body, arguments: argInfo, /*enableNameShadowing=*/true); |
| 104 | } |
| 105 | |
| 106 | void IsolatedRegionOp::print(OpAsmPrinter &p) { |
| 107 | p << ' '; |
| 108 | p.printOperand(value: getOperand()); |
| 109 | p.shadowRegionArgs(region&: getRegion(), namesToUse: getOperand()); |
| 110 | p << ' '; |
| 111 | p.printRegion(blocks&: getRegion(), /*printEntryBlockArgs=*/false); |
| 112 | } |
| 113 | |
| 114 | //===----------------------------------------------------------------------===// |
| 115 | // SSACFGRegionOp |
| 116 | //===----------------------------------------------------------------------===// |
| 117 | |
| 118 | RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { |
| 119 | return RegionKind::SSACFG; |
| 120 | } |
| 121 | |
| 122 | //===----------------------------------------------------------------------===// |
| 123 | // GraphRegionOp |
| 124 | //===----------------------------------------------------------------------===// |
| 125 | |
| 126 | RegionKind GraphRegionOp::getRegionKind(unsigned index) { |
| 127 | return RegionKind::Graph; |
| 128 | } |
| 129 | |
| 130 | //===----------------------------------------------------------------------===// |
| 131 | // IsolatedGraphRegionOp |
| 132 | //===----------------------------------------------------------------------===// |
| 133 | |
| 134 | RegionKind IsolatedGraphRegionOp::getRegionKind(unsigned index) { |
| 135 | return RegionKind::Graph; |
| 136 | } |
| 137 | |
| 138 | //===----------------------------------------------------------------------===// |
| 139 | // AffineScopeOp |
| 140 | //===----------------------------------------------------------------------===// |
| 141 | |
| 142 | ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) { |
| 143 | // Parse the body region, and reuse the operand info as the argument info. |
| 144 | Region *body = result.addRegion(); |
| 145 | return parser.parseRegion(region&: *body, /*arguments=*/{}, /*argTypes=*/enableNameShadowing: {}); |
| 146 | } |
| 147 | |
| 148 | void AffineScopeOp::print(OpAsmPrinter &p) { |
| 149 | p << " " ; |
| 150 | p.printRegion(blocks&: getRegion(), /*printEntryBlockArgs=*/false); |
| 151 | } |
| 152 | |
| 153 | //===----------------------------------------------------------------------===// |
| 154 | // TestRemoveOpWithInnerOps |
| 155 | //===----------------------------------------------------------------------===// |
| 156 | |
| 157 | namespace { |
| 158 | struct TestRemoveOpWithInnerOps |
| 159 | : public OpRewritePattern<TestOpWithRegionPattern> { |
| 160 | using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; |
| 161 | |
| 162 | void initialize() { setDebugName("TestRemoveOpWithInnerOps" ); } |
| 163 | |
| 164 | LogicalResult matchAndRewrite(TestOpWithRegionPattern op, |
| 165 | PatternRewriter &rewriter) const override { |
| 166 | rewriter.eraseOp(op); |
| 167 | return success(); |
| 168 | } |
| 169 | }; |
| 170 | } // namespace |
| 171 | |
| 172 | //===----------------------------------------------------------------------===// |
| 173 | // TestOpWithRegionPattern |
| 174 | //===----------------------------------------------------------------------===// |
| 175 | |
| 176 | void TestOpWithRegionPattern::getCanonicalizationPatterns( |
| 177 | RewritePatternSet &results, MLIRContext *context) { |
| 178 | results.add<TestRemoveOpWithInnerOps>(arg&: context); |
| 179 | } |
| 180 | |
| 181 | //===----------------------------------------------------------------------===// |
| 182 | // TestOpWithRegionFold |
| 183 | //===----------------------------------------------------------------------===// |
| 184 | |
| 185 | OpFoldResult TestOpWithRegionFold::fold(FoldAdaptor adaptor) { |
| 186 | return getOperand(); |
| 187 | } |
| 188 | |
| 189 | //===----------------------------------------------------------------------===// |
| 190 | // TestOpConstant |
| 191 | //===----------------------------------------------------------------------===// |
| 192 | |
| 193 | OpFoldResult TestOpConstant::fold(FoldAdaptor adaptor) { return getValue(); } |
| 194 | |
| 195 | //===----------------------------------------------------------------------===// |
| 196 | // TestOpWithVariadicResultsAndFolder |
| 197 | //===----------------------------------------------------------------------===// |
| 198 | |
| 199 | LogicalResult TestOpWithVariadicResultsAndFolder::fold( |
| 200 | FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult> &results) { |
| 201 | for (Value input : this->getOperands()) { |
| 202 | results.push_back(Elt: input); |
| 203 | } |
| 204 | return success(); |
| 205 | } |
| 206 | |
| 207 | //===----------------------------------------------------------------------===// |
| 208 | // TestOpInPlaceFold |
| 209 | //===----------------------------------------------------------------------===// |
| 210 | |
| 211 | OpFoldResult TestOpInPlaceFold::fold(FoldAdaptor adaptor) { |
| 212 | // Exercise the fact that an operation created with createOrFold should be |
| 213 | // allowed to access its parent block. |
| 214 | assert(getOperation()->getBlock() && |
| 215 | "expected that operation is not unlinked" ); |
| 216 | |
| 217 | if (adaptor.getOp() && !getProperties().attr) { |
| 218 | // The folder adds "attr" if not present. |
| 219 | getProperties().attr = dyn_cast_or_null<IntegerAttr>(Val: adaptor.getOp()); |
| 220 | return getResult(); |
| 221 | } |
| 222 | return {}; |
| 223 | } |
| 224 | |
| 225 | //===----------------------------------------------------------------------===// |
| 226 | // OpWithInferTypeInterfaceOp |
| 227 | //===----------------------------------------------------------------------===// |
| 228 | |
| 229 | LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( |
| 230 | MLIRContext *, std::optional<Location> location, ValueRange operands, |
| 231 | DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, |
| 232 | SmallVectorImpl<Type> &inferredReturnTypes) { |
| 233 | if (operands[0].getType() != operands[1].getType()) { |
| 234 | return emitOptionalError(loc: location, args: "operand type mismatch " , |
| 235 | args: operands[0].getType(), args: " vs " , |
| 236 | args: operands[1].getType()); |
| 237 | } |
| 238 | inferredReturnTypes.assign(IL: {operands[0].getType()}); |
| 239 | return success(); |
| 240 | } |
| 241 | |
| 242 | //===----------------------------------------------------------------------===// |
| 243 | // OpWithShapedTypeInferTypeInterfaceOp |
| 244 | //===----------------------------------------------------------------------===// |
| 245 | |
| 246 | LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( |
| 247 | MLIRContext *context, std::optional<Location> location, |
| 248 | ValueShapeRange operands, DictionaryAttr attributes, |
| 249 | OpaqueProperties properties, RegionRange regions, |
| 250 | SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { |
| 251 | // Create return type consisting of the last element of the first operand. |
| 252 | auto operandType = operands.front().getType(); |
| 253 | auto sval = dyn_cast<ShapedType>(Val&: operandType); |
| 254 | if (!sval) |
| 255 | return emitOptionalError(loc: location, args: "only shaped type operands allowed" ); |
| 256 | int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic; |
| 257 | auto type = IntegerType::get(context, width: 17); |
| 258 | |
| 259 | Attribute encoding; |
| 260 | if (auto rankedTy = dyn_cast<RankedTensorType>(Val&: sval)) |
| 261 | encoding = rankedTy.getEncoding(); |
| 262 | inferredReturnShapes.push_back(Elt: ShapedTypeComponents({dim}, type, encoding)); |
| 263 | return success(); |
| 264 | } |
| 265 | |
| 266 | LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( |
| 267 | OpBuilder &builder, ValueRange operands, |
| 268 | llvm::SmallVectorImpl<Value> &shapes) { |
| 269 | shapes = SmallVector<Value, 1>{ |
| 270 | builder.createOrFold<tensor::DimOp>(location: getLoc(), args: operands.front(), args: 0)}; |
| 271 | return success(); |
| 272 | } |
| 273 | |
| 274 | //===----------------------------------------------------------------------===// |
| 275 | // OpWithResultShapeInterfaceOp |
| 276 | //===----------------------------------------------------------------------===// |
| 277 | |
| 278 | LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes( |
| 279 | OpBuilder &builder, ValueRange operands, |
| 280 | llvm::SmallVectorImpl<Value> &shapes) { |
| 281 | Location loc = getLoc(); |
| 282 | shapes.reserve(N: operands.size()); |
| 283 | for (Value operand : llvm::reverse(C&: operands)) { |
| 284 | auto rank = cast<RankedTensorType>(Val: operand.getType()).getRank(); |
| 285 | auto currShape = llvm::to_vector<4>( |
| 286 | Range: llvm::map_range(C: llvm::seq<int64_t>(Begin: 0, End: rank), F: [&](int64_t dim) -> Value { |
| 287 | return builder.createOrFold<tensor::DimOp>(location: loc, args&: operand, args&: dim); |
| 288 | })); |
| 289 | shapes.push_back(Elt: builder.create<tensor::FromElementsOp>( |
| 290 | location: getLoc(), args: RankedTensorType::get(shape: {rank}, elementType: builder.getIndexType()), |
| 291 | args&: currShape)); |
| 292 | } |
| 293 | return success(); |
| 294 | } |
| 295 | |
| 296 | //===----------------------------------------------------------------------===// |
| 297 | // OpWithResultShapePerDimInterfaceOp |
| 298 | //===----------------------------------------------------------------------===// |
| 299 | |
| 300 | LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes( |
| 301 | OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { |
| 302 | Location loc = getLoc(); |
| 303 | shapes.reserve(N: getNumOperands()); |
| 304 | for (Value operand : llvm::reverse(C: getOperands())) { |
| 305 | auto tensorType = cast<RankedTensorType>(Val: operand.getType()); |
| 306 | auto currShape = llvm::to_vector<4>(Range: llvm::map_range( |
| 307 | C: llvm::seq<int64_t>(Begin: 0, End: tensorType.getRank()), |
| 308 | F: [&](int64_t dim) -> OpFoldResult { |
| 309 | return tensorType.isDynamicDim(idx: dim) |
| 310 | ? static_cast<OpFoldResult>( |
| 311 | builder.createOrFold<tensor::DimOp>(location: loc, args&: operand, |
| 312 | args&: dim)) |
| 313 | : static_cast<OpFoldResult>( |
| 314 | builder.getIndexAttr(value: tensorType.getDimSize(idx: dim))); |
| 315 | })); |
| 316 | shapes.emplace_back(Args: std::move(currShape)); |
| 317 | } |
| 318 | return success(); |
| 319 | } |
| 320 | |
| 321 | //===----------------------------------------------------------------------===// |
| 322 | // SideEffectOp |
| 323 | //===----------------------------------------------------------------------===// |
| 324 | |
| 325 | namespace { |
| 326 | /// A test resource for side effects. |
| 327 | struct TestResource : public SideEffects::Resource::Base<TestResource> { |
| 328 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResource) |
| 329 | |
| 330 | StringRef getName() final { return "<Test>" ; } |
| 331 | }; |
| 332 | } // namespace |
| 333 | |
| 334 | void SideEffectOp::getEffects( |
| 335 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
| 336 | // Check for an effects attribute on the op instance. |
| 337 | ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>(name: "effects" ); |
| 338 | if (!effectsAttr) |
| 339 | return; |
| 340 | |
| 341 | for (Attribute element : effectsAttr) { |
| 342 | DictionaryAttr effectElement = cast<DictionaryAttr>(Val&: element); |
| 343 | |
| 344 | // Get the specific memory effect. |
| 345 | MemoryEffects::Effect *effect = |
| 346 | StringSwitch<MemoryEffects::Effect *>( |
| 347 | cast<StringAttr>(Val: effectElement.get(name: "effect" )).getValue()) |
| 348 | .Case(S: "allocate" , Value: MemoryEffects::Allocate::get()) |
| 349 | .Case(S: "free" , Value: MemoryEffects::Free::get()) |
| 350 | .Case(S: "read" , Value: MemoryEffects::Read::get()) |
| 351 | .Case(S: "write" , Value: MemoryEffects::Write::get()); |
| 352 | |
| 353 | // Check for a non-default resource to use. |
| 354 | SideEffects::Resource *resource = SideEffects::DefaultResource::get(); |
| 355 | if (effectElement.get(name: "test_resource" )) |
| 356 | resource = TestResource::get(); |
| 357 | |
| 358 | // Check for a result to affect. |
| 359 | if (effectElement.get(name: "on_result" )) |
| 360 | effects.emplace_back(Args&: effect, Args: getOperation()->getOpResults()[0], Args&: resource); |
| 361 | else if (Attribute ref = effectElement.get(name: "on_reference" )) |
| 362 | effects.emplace_back(Args&: effect, Args: cast<SymbolRefAttr>(Val&: ref), Args&: resource); |
| 363 | else |
| 364 | effects.emplace_back(Args&: effect, Args&: resource); |
| 365 | } |
| 366 | } |
| 367 | |
| 368 | void SideEffectOp::getEffects( |
| 369 | SmallVectorImpl<TestEffects::EffectInstance> &effects) { |
| 370 | testSideEffectOpGetEffect(op: getOperation(), effects); |
| 371 | } |
| 372 | |
| 373 | void SideEffectWithRegionOp::getEffects( |
| 374 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
| 375 | // Check for an effects attribute on the op instance. |
| 376 | ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>(name: "effects" ); |
| 377 | if (!effectsAttr) |
| 378 | return; |
| 379 | |
| 380 | for (Attribute element : effectsAttr) { |
| 381 | DictionaryAttr effectElement = cast<DictionaryAttr>(Val&: element); |
| 382 | |
| 383 | // Get the specific memory effect. |
| 384 | MemoryEffects::Effect *effect = |
| 385 | StringSwitch<MemoryEffects::Effect *>( |
| 386 | cast<StringAttr>(Val: effectElement.get(name: "effect" )).getValue()) |
| 387 | .Case(S: "allocate" , Value: MemoryEffects::Allocate::get()) |
| 388 | .Case(S: "free" , Value: MemoryEffects::Free::get()) |
| 389 | .Case(S: "read" , Value: MemoryEffects::Read::get()) |
| 390 | .Case(S: "write" , Value: MemoryEffects::Write::get()); |
| 391 | |
| 392 | // Check for a non-default resource to use. |
| 393 | SideEffects::Resource *resource = SideEffects::DefaultResource::get(); |
| 394 | if (effectElement.get(name: "test_resource" )) |
| 395 | resource = TestResource::get(); |
| 396 | |
| 397 | // Check for a result to affect. |
| 398 | if (effectElement.get(name: "on_result" )) |
| 399 | effects.emplace_back(Args&: effect, Args: getOperation()->getOpResults()[0], Args&: resource); |
| 400 | else if (effectElement.get(name: "on_operand" )) |
| 401 | effects.emplace_back(Args&: effect, Args: &getOperation()->getOpOperands()[0], |
| 402 | Args&: resource); |
| 403 | else if (effectElement.get(name: "on_argument" )) |
| 404 | effects.emplace_back(Args&: effect, Args: getOperation()->getRegion(index: 0).getArgument(i: 0), |
| 405 | Args&: resource); |
| 406 | else if (Attribute ref = effectElement.get(name: "on_reference" )) |
| 407 | effects.emplace_back(Args&: effect, Args: cast<SymbolRefAttr>(Val&: ref), Args&: resource); |
| 408 | else |
| 409 | effects.emplace_back(Args&: effect, Args&: resource); |
| 410 | } |
| 411 | } |
| 412 | |
| 413 | void SideEffectWithRegionOp::getEffects( |
| 414 | SmallVectorImpl<TestEffects::EffectInstance> &effects) { |
| 415 | testSideEffectOpGetEffect(op: getOperation(), effects); |
| 416 | } |
| 417 | |
| 418 | //===----------------------------------------------------------------------===// |
| 419 | // StringAttrPrettyNameOp |
| 420 | //===----------------------------------------------------------------------===// |
| 421 | |
| 422 | // This op has fancy handling of its SSA result name. |
| 423 | ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser, |
| 424 | OperationState &result) { |
| 425 | // Add the result types. |
| 426 | for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) |
| 427 | result.addTypes(newTypes: parser.getBuilder().getIntegerType(width: 32)); |
| 428 | |
| 429 | if (parser.parseOptionalAttrDictWithKeyword(result&: result.attributes)) |
| 430 | return failure(); |
| 431 | |
| 432 | // If the attribute dictionary contains no 'names' attribute, infer it from |
| 433 | // the SSA name (if specified). |
| 434 | bool hadNames = llvm::any_of(Range&: result.attributes, P: [](NamedAttribute attr) { |
| 435 | return attr.getName() == "names" ; |
| 436 | }); |
| 437 | |
| 438 | // If there was no name specified, check to see if there was a useful name |
| 439 | // specified in the asm file. |
| 440 | if (hadNames || parser.getNumResults() == 0) |
| 441 | return success(); |
| 442 | |
| 443 | SmallVector<StringRef, 4> names; |
| 444 | auto *context = result.getContext(); |
| 445 | |
| 446 | for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { |
| 447 | auto resultName = parser.getResultName(resultNo: i); |
| 448 | StringRef nameStr; |
| 449 | if (!resultName.first.empty() && !isdigit(resultName.first[0])) |
| 450 | nameStr = resultName.first; |
| 451 | |
| 452 | names.push_back(Elt: nameStr); |
| 453 | } |
| 454 | |
| 455 | auto namesAttr = parser.getBuilder().getStrArrayAttr(values: names); |
| 456 | result.attributes.push_back(newAttribute: {StringAttr::get(context, bytes: "names" ), namesAttr}); |
| 457 | return success(); |
| 458 | } |
| 459 | |
| 460 | void StringAttrPrettyNameOp::print(OpAsmPrinter &p) { |
| 461 | // Note that we only need to print the "name" attribute if the asmprinter |
| 462 | // result name disagrees with it. This can happen in strange cases, e.g. |
| 463 | // when there are conflicts. |
| 464 | bool namesDisagree = getNames().size() != getNumResults(); |
| 465 | |
| 466 | SmallString<32> resultNameStr; |
| 467 | for (size_t i = 0, e = getNumResults(); i != e && !namesDisagree; ++i) { |
| 468 | resultNameStr.clear(); |
| 469 | llvm::raw_svector_ostream tmpStream(resultNameStr); |
| 470 | p.printOperand(value: getResult(i), os&: tmpStream); |
| 471 | |
| 472 | auto expectedName = dyn_cast<StringAttr>(Val: getNames()[i]); |
| 473 | if (!expectedName || |
| 474 | tmpStream.str().drop_front() != expectedName.getValue()) { |
| 475 | namesDisagree = true; |
| 476 | } |
| 477 | } |
| 478 | |
| 479 | if (namesDisagree) |
| 480 | p.printOptionalAttrDictWithKeyword(attrs: (*this)->getAttrs()); |
| 481 | else |
| 482 | p.printOptionalAttrDictWithKeyword(attrs: (*this)->getAttrs(), elidedAttrs: {"names" }); |
| 483 | } |
| 484 | |
| 485 | // We set the SSA name in the asm syntax to the contents of the name |
| 486 | // attribute. |
| 487 | void StringAttrPrettyNameOp::getAsmResultNames( |
| 488 | function_ref<void(Value, StringRef)> setNameFn) { |
| 489 | |
| 490 | auto value = getNames(); |
| 491 | for (size_t i = 0, e = value.size(); i != e; ++i) |
| 492 | if (auto str = dyn_cast<StringAttr>(Val: value[i])) |
| 493 | if (!str.getValue().empty()) |
| 494 | setNameFn(getResult(i), str.getValue()); |
| 495 | } |
| 496 | |
| 497 | //===----------------------------------------------------------------------===// |
| 498 | // CustomResultsNameOp |
| 499 | //===----------------------------------------------------------------------===// |
| 500 | |
| 501 | void CustomResultsNameOp::getAsmResultNames( |
| 502 | function_ref<void(Value, StringRef)> setNameFn) { |
| 503 | ArrayAttr value = getNames(); |
| 504 | for (size_t i = 0, e = value.size(); i != e; ++i) |
| 505 | if (auto str = dyn_cast<StringAttr>(Val: value[i])) |
| 506 | if (!str.empty()) |
| 507 | setNameFn(getResult(i), str.getValue()); |
| 508 | } |
| 509 | |
| 510 | //===----------------------------------------------------------------------===// |
| 511 | // ResultNameFromTypeOp |
| 512 | //===----------------------------------------------------------------------===// |
| 513 | |
| 514 | void ResultNameFromTypeOp::getAsmResultNames( |
| 515 | function_ref<void(Value, StringRef)> setNameFn) { |
| 516 | auto result = getResult(); |
| 517 | auto setResultNameFn = [&](::llvm::StringRef name) { |
| 518 | setNameFn(result, name); |
| 519 | }; |
| 520 | auto opAsmTypeInterface = |
| 521 | ::mlir::cast<::mlir::OpAsmTypeInterface>(Val: result.getType()); |
| 522 | opAsmTypeInterface.getAsmName(setNameFn: setResultNameFn); |
| 523 | } |
| 524 | |
| 525 | //===----------------------------------------------------------------------===// |
| 526 | // BlockArgumentNameFromTypeOp |
| 527 | //===----------------------------------------------------------------------===// |
| 528 | |
| 529 | void BlockArgumentNameFromTypeOp::getAsmBlockArgumentNames( |
| 530 | ::mlir::Region ®ion, ::mlir::OpAsmSetValueNameFn setNameFn) { |
| 531 | for (auto &block : region) { |
| 532 | for (auto arg : block.getArguments()) { |
| 533 | if (auto opAsmTypeInterface = |
| 534 | ::mlir::dyn_cast<::mlir::OpAsmTypeInterface>(Val: arg.getType())) { |
| 535 | auto setArgNameFn = [&](StringRef name) { setNameFn(arg, name); }; |
| 536 | opAsmTypeInterface.getAsmName(setNameFn: setArgNameFn); |
| 537 | } |
| 538 | } |
| 539 | } |
| 540 | } |
| 541 | |
| 542 | //===----------------------------------------------------------------------===// |
| 543 | // ResultTypeWithTraitOp |
| 544 | //===----------------------------------------------------------------------===// |
| 545 | |
| 546 | LogicalResult ResultTypeWithTraitOp::verify() { |
| 547 | if ((*this)->getResultTypes()[0].hasTrait<TypeTrait::TestTypeTrait>()) |
| 548 | return success(); |
| 549 | return emitError(message: "result type should have trait 'TestTypeTrait'" ); |
| 550 | } |
| 551 | |
| 552 | //===----------------------------------------------------------------------===// |
| 553 | // AttrWithTraitOp |
| 554 | //===----------------------------------------------------------------------===// |
| 555 | |
| 556 | LogicalResult AttrWithTraitOp::verify() { |
| 557 | if (getAttr().hasTrait<AttributeTrait::TestAttrTrait>()) |
| 558 | return success(); |
| 559 | return emitError(message: "'attr' attribute should have trait 'TestAttrTrait'" ); |
| 560 | } |
| 561 | |
| 562 | //===----------------------------------------------------------------------===// |
| 563 | // RegionIfOp |
| 564 | //===----------------------------------------------------------------------===// |
| 565 | |
| 566 | void RegionIfOp::print(OpAsmPrinter &p) { |
| 567 | p << " " ; |
| 568 | p.printOperands(container: getOperands()); |
| 569 | p << ": " << getOperandTypes(); |
| 570 | p.printArrowTypeList(types: getResultTypes()); |
| 571 | p << " then " ; |
| 572 | p.printRegion(blocks&: getThenRegion(), |
| 573 | /*printEntryBlockArgs=*/true, |
| 574 | /*printBlockTerminators=*/true); |
| 575 | p << " else " ; |
| 576 | p.printRegion(blocks&: getElseRegion(), |
| 577 | /*printEntryBlockArgs=*/true, |
| 578 | /*printBlockTerminators=*/true); |
| 579 | p << " join " ; |
| 580 | p.printRegion(blocks&: getJoinRegion(), |
| 581 | /*printEntryBlockArgs=*/true, |
| 582 | /*printBlockTerminators=*/true); |
| 583 | } |
| 584 | |
| 585 | ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) { |
| 586 | SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfos; |
| 587 | SmallVector<Type, 2> operandTypes; |
| 588 | |
| 589 | result.regions.reserve(N: 3); |
| 590 | Region *thenRegion = result.addRegion(); |
| 591 | Region *elseRegion = result.addRegion(); |
| 592 | Region *joinRegion = result.addRegion(); |
| 593 | |
| 594 | // Parse operand, type and arrow type lists. |
| 595 | if (parser.parseOperandList(result&: operandInfos) || |
| 596 | parser.parseColonTypeList(result&: operandTypes) || |
| 597 | parser.parseArrowTypeList(result&: result.types)) |
| 598 | return failure(); |
| 599 | |
| 600 | // Parse all attached regions. |
| 601 | if (parser.parseKeyword(keyword: "then" ) || parser.parseRegion(region&: *thenRegion, arguments: {}, enableNameShadowing: {}) || |
| 602 | parser.parseKeyword(keyword: "else" ) || parser.parseRegion(region&: *elseRegion, arguments: {}, enableNameShadowing: {}) || |
| 603 | parser.parseKeyword(keyword: "join" ) || parser.parseRegion(region&: *joinRegion, arguments: {}, enableNameShadowing: {})) |
| 604 | return failure(); |
| 605 | |
| 606 | return parser.resolveOperands(operands&: operandInfos, types&: operandTypes, |
| 607 | loc: parser.getCurrentLocation(), result&: result.operands); |
| 608 | } |
| 609 | |
| 610 | OperandRange RegionIfOp::getEntrySuccessorOperands(RegionBranchPoint point) { |
| 611 | assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, point) && |
| 612 | "invalid region index" ); |
| 613 | return getOperands(); |
| 614 | } |
| 615 | |
| 616 | void RegionIfOp::getSuccessorRegions( |
| 617 | RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { |
| 618 | // We always branch to the join region. |
| 619 | if (!point.isParent()) { |
| 620 | if (point != getJoinRegion()) |
| 621 | regions.push_back(Elt: RegionSuccessor(&getJoinRegion(), getJoinArgs())); |
| 622 | else |
| 623 | regions.push_back(Elt: RegionSuccessor(getResults())); |
| 624 | return; |
| 625 | } |
| 626 | |
| 627 | // The then and else regions are the entry regions of this op. |
| 628 | regions.push_back(Elt: RegionSuccessor(&getThenRegion(), getThenArgs())); |
| 629 | regions.push_back(Elt: RegionSuccessor(&getElseRegion(), getElseArgs())); |
| 630 | } |
| 631 | |
| 632 | void RegionIfOp::getRegionInvocationBounds( |
| 633 | ArrayRef<Attribute> operands, |
| 634 | SmallVectorImpl<InvocationBounds> &invocationBounds) { |
| 635 | // Each region is invoked at most once. |
| 636 | invocationBounds.assign(/*NumElts=*/3, /*Elt=*/{0, 1}); |
| 637 | } |
| 638 | |
| 639 | //===----------------------------------------------------------------------===// |
| 640 | // AnyCondOp |
| 641 | //===----------------------------------------------------------------------===// |
| 642 | |
| 643 | void AnyCondOp::getSuccessorRegions(RegionBranchPoint point, |
| 644 | SmallVectorImpl<RegionSuccessor> ®ions) { |
| 645 | // The parent op branches into the only region, and the region branches back |
| 646 | // to the parent op. |
| 647 | if (point.isParent()) |
| 648 | regions.emplace_back(Args: &getRegion()); |
| 649 | else |
| 650 | regions.emplace_back(Args: getResults()); |
| 651 | } |
| 652 | |
| 653 | void AnyCondOp::getRegionInvocationBounds( |
| 654 | ArrayRef<Attribute> operands, |
| 655 | SmallVectorImpl<InvocationBounds> &invocationBounds) { |
| 656 | invocationBounds.emplace_back(Args: 1, Args: 1); |
| 657 | } |
| 658 | |
| 659 | //===----------------------------------------------------------------------===// |
| 660 | // SingleBlockImplicitTerminatorOp |
| 661 | //===----------------------------------------------------------------------===// |
| 662 | |
| 663 | /// Testing the correctness of some traits. |
| 664 | static_assert( |
| 665 | llvm::is_detected<OpTrait::has_implicit_terminator_t, |
| 666 | SingleBlockImplicitTerminatorOp>::value, |
| 667 | "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp" ); |
| 668 | static_assert(OpTrait::hasSingleBlockImplicitTerminator< |
| 669 | SingleBlockImplicitTerminatorOp>::value, |
| 670 | "hasSingleBlockImplicitTerminator does not match " |
| 671 | "SingleBlockImplicitTerminatorOp" ); |
| 672 | |
| 673 | //===----------------------------------------------------------------------===// |
| 674 | // SingleNoTerminatorCustomAsmOp |
| 675 | //===----------------------------------------------------------------------===// |
| 676 | |
| 677 | ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser, |
| 678 | OperationState &state) { |
| 679 | Region *body = state.addRegion(); |
| 680 | if (parser.parseRegion(region&: *body, /*arguments=*/{}, /*argTypes=*/enableNameShadowing: {})) |
| 681 | return failure(); |
| 682 | return success(); |
| 683 | } |
| 684 | |
| 685 | void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) { |
| 686 | printer.printRegion( |
| 687 | blocks&: getRegion(), /*printEntryBlockArgs=*/false, |
| 688 | // This op has a single block without terminators. But explicitly mark |
| 689 | // as not printing block terminators for testing. |
| 690 | /*printBlockTerminators=*/false); |
| 691 | } |
| 692 | |
| 693 | //===----------------------------------------------------------------------===// |
| 694 | // TestVerifiersOp |
| 695 | //===----------------------------------------------------------------------===// |
| 696 | |
| 697 | LogicalResult TestVerifiersOp::verify() { |
| 698 | if (!getRegion().hasOneBlock()) |
| 699 | return emitOpError(message: "`hasOneBlock` trait hasn't been verified" ); |
| 700 | |
| 701 | Operation *definingOp = getInput().getDefiningOp(); |
| 702 | if (definingOp && failed(Result: mlir::verify(op: definingOp))) |
| 703 | return emitOpError(message: "operand hasn't been verified" ); |
| 704 | |
| 705 | // Avoid using `emitRemark(msg)` since that will trigger an infinite verifier |
| 706 | // loop. |
| 707 | mlir::emitRemark(loc: getLoc(), message: "success run of verifier" ); |
| 708 | |
| 709 | return success(); |
| 710 | } |
| 711 | |
| 712 | LogicalResult TestVerifiersOp::verifyRegions() { |
| 713 | if (!getRegion().hasOneBlock()) |
| 714 | return emitOpError(message: "`hasOneBlock` trait hasn't been verified" ); |
| 715 | |
| 716 | for (Block &block : getRegion()) |
| 717 | for (Operation &op : block) |
| 718 | if (failed(Result: mlir::verify(op: &op))) |
| 719 | return emitOpError(message: "nested op hasn't been verified" ); |
| 720 | |
| 721 | // Avoid using `emitRemark(msg)` since that will trigger an infinite verifier |
| 722 | // loop. |
| 723 | mlir::emitRemark(loc: getLoc(), message: "success run of region verifier" ); |
| 724 | |
| 725 | return success(); |
| 726 | } |
| 727 | |
| 728 | //===----------------------------------------------------------------------===// |
| 729 | // Test InferIntRangeInterface |
| 730 | //===----------------------------------------------------------------------===// |
| 731 | |
| 732 | //===----------------------------------------------------------------------===// |
| 733 | // TestWithBoundsOp |
| 734 | //===----------------------------------------------------------------------===// |
| 735 | |
| 736 | void TestWithBoundsOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
| 737 | SetIntRangeFn setResultRanges) { |
| 738 | setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()}); |
| 739 | } |
| 740 | |
| 741 | //===----------------------------------------------------------------------===// |
| 742 | // TestWithBoundsRegionOp |
| 743 | //===----------------------------------------------------------------------===// |
| 744 | |
| 745 | ParseResult TestWithBoundsRegionOp::parse(OpAsmParser &parser, |
| 746 | OperationState &result) { |
| 747 | if (parser.parseOptionalAttrDict(result&: result.attributes)) |
| 748 | return failure(); |
| 749 | |
| 750 | // Parse the input argument |
| 751 | OpAsmParser::Argument argInfo; |
| 752 | if (failed(Result: parser.parseArgument(result&: argInfo, allowType: true))) |
| 753 | return failure(); |
| 754 | |
| 755 | // Parse the body region, and reuse the operand info as the argument info. |
| 756 | Region *body = result.addRegion(); |
| 757 | return parser.parseRegion(region&: *body, arguments: argInfo, /*enableNameShadowing=*/false); |
| 758 | } |
| 759 | |
| 760 | void TestWithBoundsRegionOp::print(OpAsmPrinter &p) { |
| 761 | p.printOptionalAttrDict(attrs: (*this)->getAttrs()); |
| 762 | p << ' '; |
| 763 | p.printRegionArgument(arg: getRegion().getArgument(i: 0), /*argAttrs=*/{}, |
| 764 | /*omitType=*/false); |
| 765 | p << ' '; |
| 766 | p.printRegion(blocks&: getRegion(), /*printEntryBlockArgs=*/false); |
| 767 | } |
| 768 | |
| 769 | void TestWithBoundsRegionOp::inferResultRanges( |
| 770 | ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) { |
| 771 | Value arg = getRegion().getArgument(i: 0); |
| 772 | setResultRanges(arg, {getUmin(), getUmax(), getSmin(), getSmax()}); |
| 773 | } |
| 774 | |
| 775 | //===----------------------------------------------------------------------===// |
| 776 | // TestIncrementOp |
| 777 | //===----------------------------------------------------------------------===// |
| 778 | |
| 779 | void TestIncrementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
| 780 | SetIntRangeFn setResultRanges) { |
| 781 | const ConstantIntRanges &range = argRanges[0]; |
| 782 | APInt one(range.umin().getBitWidth(), 1); |
| 783 | setResultRanges(getResult(), |
| 784 | {range.umin().uadd_sat(RHS: one), range.umax().uadd_sat(RHS: one), |
| 785 | range.smin().sadd_sat(RHS: one), range.smax().sadd_sat(RHS: one)}); |
| 786 | } |
| 787 | |
| 788 | //===----------------------------------------------------------------------===// |
| 789 | // TestReflectBoundsOp |
| 790 | //===----------------------------------------------------------------------===// |
| 791 | |
| 792 | void TestReflectBoundsOp::inferResultRanges( |
| 793 | ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) { |
| 794 | const ConstantIntRanges &range = argRanges[0]; |
| 795 | MLIRContext *ctx = getContext(); |
| 796 | Builder b(ctx); |
| 797 | Type sIntTy, uIntTy; |
| 798 | // For plain `IntegerType`s, we can derive the appropriate signed and unsigned |
| 799 | // Types for the Attributes. |
| 800 | Type type = getElementTypeOrSelf(type: getType()); |
| 801 | if (auto intTy = llvm::dyn_cast<IntegerType>(Val&: type)) { |
| 802 | unsigned bitwidth = intTy.getWidth(); |
| 803 | sIntTy = b.getIntegerType(width: bitwidth, /*isSigned=*/true); |
| 804 | uIntTy = b.getIntegerType(width: bitwidth, /*isSigned=*/false); |
| 805 | } else { |
| 806 | sIntTy = uIntTy = type; |
| 807 | } |
| 808 | |
| 809 | setUminAttr(b.getIntegerAttr(type: uIntTy, value: range.umin())); |
| 810 | setUmaxAttr(b.getIntegerAttr(type: uIntTy, value: range.umax())); |
| 811 | setSminAttr(b.getIntegerAttr(type: sIntTy, value: range.smin())); |
| 812 | setSmaxAttr(b.getIntegerAttr(type: sIntTy, value: range.smax())); |
| 813 | setResultRanges(getResult(), range); |
| 814 | } |
| 815 | |
| 816 | //===----------------------------------------------------------------------===// |
| 817 | // ConversionFuncOp |
| 818 | //===----------------------------------------------------------------------===// |
| 819 | |
| 820 | ParseResult ConversionFuncOp::parse(OpAsmParser &parser, |
| 821 | OperationState &result) { |
| 822 | auto buildFuncType = |
| 823 | [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, |
| 824 | function_interface_impl::VariadicFlag, |
| 825 | std::string &) { return builder.getFunctionType(inputs: argTypes, results); }; |
| 826 | |
| 827 | return function_interface_impl::parseFunctionOp( |
| 828 | parser, result, /*allowVariadic=*/false, |
| 829 | typeAttrName: getFunctionTypeAttrName(name: result.name), funcTypeBuilder: buildFuncType, |
| 830 | argAttrsName: getArgAttrsAttrName(name: result.name), resAttrsName: getResAttrsAttrName(name: result.name)); |
| 831 | } |
| 832 | |
| 833 | void ConversionFuncOp::print(OpAsmPrinter &p) { |
| 834 | function_interface_impl::printFunctionOp( |
| 835 | p, op: *this, /*isVariadic=*/false, typeAttrName: getFunctionTypeAttrName(), |
| 836 | argAttrsName: getArgAttrsAttrName(), resAttrsName: getResAttrsAttrName()); |
| 837 | } |
| 838 | |
| 839 | //===----------------------------------------------------------------------===// |
| 840 | // TestValueWithBoundsOp |
| 841 | //===----------------------------------------------------------------------===// |
| 842 | |
| 843 | void TestValueWithBoundsOp::populateBoundsForIndexValue( |
| 844 | Value v, ValueBoundsConstraintSet &cstr) { |
| 845 | cstr.bound(value: v) >= getMin().getSExtValue(); |
| 846 | cstr.bound(value: v) <= getMax().getSExtValue(); |
| 847 | } |
| 848 | |
| 849 | //===----------------------------------------------------------------------===// |
| 850 | // ReifyBoundOp |
| 851 | //===----------------------------------------------------------------------===// |
| 852 | |
| 853 | mlir::presburger::BoundType ReifyBoundOp::getBoundType() { |
| 854 | if (getType() == "EQ" ) |
| 855 | return mlir::presburger::BoundType::EQ; |
| 856 | if (getType() == "LB" ) |
| 857 | return mlir::presburger::BoundType::LB; |
| 858 | if (getType() == "UB" ) |
| 859 | return mlir::presburger::BoundType::UB; |
| 860 | llvm_unreachable("invalid bound type" ); |
| 861 | } |
| 862 | |
| 863 | LogicalResult ReifyBoundOp::verify() { |
| 864 | if (isa<ShapedType>(Val: getVar().getType())) { |
| 865 | if (!getDim().has_value()) |
| 866 | return emitOpError(message: "expected 'dim' attribute for shaped type variable" ); |
| 867 | } else if (getVar().getType().isIndex()) { |
| 868 | if (getDim().has_value()) |
| 869 | return emitOpError(message: "unexpected 'dim' attribute for index variable" ); |
| 870 | } else { |
| 871 | return emitOpError(message: "expected index-typed variable or shape type variable" ); |
| 872 | } |
| 873 | if (getConstant() && getScalable()) |
| 874 | return emitOpError(message: "'scalable' and 'constant' are mutually exlusive" ); |
| 875 | if (getScalable() != getVscaleMin().has_value()) |
| 876 | return emitOpError(message: "expected 'vscale_min' if and only if 'scalable'" ); |
| 877 | if (getScalable() != getVscaleMax().has_value()) |
| 878 | return emitOpError(message: "expected 'vscale_min' if and only if 'scalable'" ); |
| 879 | return success(); |
| 880 | } |
| 881 | |
| 882 | ValueBoundsConstraintSet::Variable ReifyBoundOp::getVariable() { |
| 883 | if (getDim().has_value()) |
| 884 | return ValueBoundsConstraintSet::Variable(getVar(), *getDim()); |
| 885 | return ValueBoundsConstraintSet::Variable(getVar()); |
| 886 | } |
| 887 | |
| 888 | //===----------------------------------------------------------------------===// |
| 889 | // CompareOp |
| 890 | //===----------------------------------------------------------------------===// |
| 891 | |
| 892 | ValueBoundsConstraintSet::ComparisonOperator |
| 893 | CompareOp::getComparisonOperator() { |
| 894 | if (getCmp() == "EQ" ) |
| 895 | return ValueBoundsConstraintSet::ComparisonOperator::EQ; |
| 896 | if (getCmp() == "LT" ) |
| 897 | return ValueBoundsConstraintSet::ComparisonOperator::LT; |
| 898 | if (getCmp() == "LE" ) |
| 899 | return ValueBoundsConstraintSet::ComparisonOperator::LE; |
| 900 | if (getCmp() == "GT" ) |
| 901 | return ValueBoundsConstraintSet::ComparisonOperator::GT; |
| 902 | if (getCmp() == "GE" ) |
| 903 | return ValueBoundsConstraintSet::ComparisonOperator::GE; |
| 904 | llvm_unreachable("invalid comparison operator" ); |
| 905 | } |
| 906 | |
| 907 | mlir::ValueBoundsConstraintSet::Variable CompareOp::getLhs() { |
| 908 | if (!getLhsMap()) |
| 909 | return ValueBoundsConstraintSet::Variable(getVarOperands()[0]); |
| 910 | SmallVector<Value> mapOperands( |
| 911 | getVarOperands().slice(n: 0, m: getLhsMap()->getNumInputs())); |
| 912 | return ValueBoundsConstraintSet::Variable(*getLhsMap(), mapOperands); |
| 913 | } |
| 914 | |
| 915 | mlir::ValueBoundsConstraintSet::Variable CompareOp::getRhs() { |
| 916 | int64_t rhsOperandsBegin = getLhsMap() ? getLhsMap()->getNumInputs() : 1; |
| 917 | if (!getRhsMap()) |
| 918 | return ValueBoundsConstraintSet::Variable( |
| 919 | getVarOperands()[rhsOperandsBegin]); |
| 920 | SmallVector<Value> mapOperands( |
| 921 | getVarOperands().slice(n: rhsOperandsBegin, m: getRhsMap()->getNumInputs())); |
| 922 | return ValueBoundsConstraintSet::Variable(*getRhsMap(), mapOperands); |
| 923 | } |
| 924 | |
| 925 | LogicalResult CompareOp::verify() { |
| 926 | if (getCompose() && (getLhsMap() || getRhsMap())) |
| 927 | return emitOpError( |
| 928 | message: "'compose' not supported when 'lhs_map' or 'rhs_map' is present" ); |
| 929 | int64_t expectedNumOperands = getLhsMap() ? getLhsMap()->getNumInputs() : 1; |
| 930 | expectedNumOperands += getRhsMap() ? getRhsMap()->getNumInputs() : 1; |
| 931 | if (getVarOperands().size() != size_t(expectedNumOperands)) |
| 932 | return emitOpError(message: "expected " ) |
| 933 | << expectedNumOperands << " operands, but got " |
| 934 | << getVarOperands().size(); |
| 935 | return success(); |
| 936 | } |
| 937 | |
| 938 | //===----------------------------------------------------------------------===// |
| 939 | // TestOpInPlaceSelfFold |
| 940 | //===----------------------------------------------------------------------===// |
| 941 | |
| 942 | OpFoldResult TestOpInPlaceSelfFold::fold(FoldAdaptor adaptor) { |
| 943 | if (!getFolded()) { |
| 944 | // The folder adds the "folded" if not present. |
| 945 | setFolded(true); |
| 946 | return getResult(); |
| 947 | } |
| 948 | return {}; |
| 949 | } |
| 950 | |
| 951 | //===----------------------------------------------------------------------===// |
| 952 | // TestOpFoldWithFoldAdaptor |
| 953 | //===----------------------------------------------------------------------===// |
| 954 | |
| 955 | OpFoldResult TestOpFoldWithFoldAdaptor::fold(FoldAdaptor adaptor) { |
| 956 | int64_t sum = 0; |
| 957 | if (auto value = dyn_cast_or_null<IntegerAttr>(Val: adaptor.getOp())) |
| 958 | sum += value.getValue().getSExtValue(); |
| 959 | |
| 960 | for (Attribute attr : adaptor.getVariadic()) |
| 961 | if (auto value = dyn_cast_or_null<IntegerAttr>(Val&: attr)) |
| 962 | sum += 2 * value.getValue().getSExtValue(); |
| 963 | |
| 964 | for (ArrayRef<Attribute> attrs : adaptor.getVarOfVar()) |
| 965 | for (Attribute attr : attrs) |
| 966 | if (auto value = dyn_cast_or_null<IntegerAttr>(Val&: attr)) |
| 967 | sum += 3 * value.getValue().getSExtValue(); |
| 968 | |
| 969 | sum += 4 * std::distance(first: adaptor.getBody().begin(), last: adaptor.getBody().end()); |
| 970 | |
| 971 | return IntegerAttr::get(type: getType(), value: sum); |
| 972 | } |
| 973 | |
| 974 | //===----------------------------------------------------------------------===// |
| 975 | // OpWithInferTypeAdaptorInterfaceOp |
| 976 | //===----------------------------------------------------------------------===// |
| 977 | |
| 978 | LogicalResult OpWithInferTypeAdaptorInterfaceOp::inferReturnTypes( |
| 979 | MLIRContext *, std::optional<Location> location, |
| 980 | OpWithInferTypeAdaptorInterfaceOp::Adaptor adaptor, |
| 981 | SmallVectorImpl<Type> &inferredReturnTypes) { |
| 982 | if (adaptor.getX().getType() != adaptor.getY().getType()) { |
| 983 | return emitOptionalError(loc: location, args: "operand type mismatch " , |
| 984 | args: adaptor.getX().getType(), args: " vs " , |
| 985 | args: adaptor.getY().getType()); |
| 986 | } |
| 987 | inferredReturnTypes.assign(IL: {adaptor.getX().getType()}); |
| 988 | return success(); |
| 989 | } |
| 990 | |
| 991 | //===----------------------------------------------------------------------===// |
| 992 | // OpWithRefineTypeInterfaceOp |
| 993 | //===----------------------------------------------------------------------===// |
| 994 | |
| 995 | // TODO: We should be able to only define either inferReturnType or |
| 996 | // refineReturnType, currently only refineReturnType can be omitted. |
| 997 | LogicalResult OpWithRefineTypeInterfaceOp::inferReturnTypes( |
| 998 | MLIRContext *context, std::optional<Location> location, ValueRange operands, |
| 999 | DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, |
| 1000 | SmallVectorImpl<Type> &returnTypes) { |
| 1001 | returnTypes.clear(); |
| 1002 | return OpWithRefineTypeInterfaceOp::refineReturnTypes( |
| 1003 | context, location, operands, attributes, properties, regions, |
| 1004 | returnTypes); |
| 1005 | } |
| 1006 | |
| 1007 | LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes( |
| 1008 | MLIRContext *, std::optional<Location> location, ValueRange operands, |
| 1009 | DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, |
| 1010 | SmallVectorImpl<Type> &returnTypes) { |
| 1011 | if (operands[0].getType() != operands[1].getType()) { |
| 1012 | return emitOptionalError(loc: location, args: "operand type mismatch " , |
| 1013 | args: operands[0].getType(), args: " vs " , |
| 1014 | args: operands[1].getType()); |
| 1015 | } |
| 1016 | // TODO: Add helper to make this more concise to write. |
| 1017 | if (returnTypes.empty()) |
| 1018 | returnTypes.resize(N: 1, NV: nullptr); |
| 1019 | if (returnTypes[0] && returnTypes[0] != operands[0].getType()) |
| 1020 | return emitOptionalError(loc: location, |
| 1021 | args: "required first operand and result to match" ); |
| 1022 | returnTypes[0] = operands[0].getType(); |
| 1023 | return success(); |
| 1024 | } |
| 1025 | |
| 1026 | //===----------------------------------------------------------------------===// |
| 1027 | // OpWithShapedTypeInferTypeAdaptorInterfaceOp |
| 1028 | //===----------------------------------------------------------------------===// |
| 1029 | |
| 1030 | LogicalResult |
| 1031 | OpWithShapedTypeInferTypeAdaptorInterfaceOp::inferReturnTypeComponents( |
| 1032 | MLIRContext *context, std::optional<Location> location, |
| 1033 | OpWithShapedTypeInferTypeAdaptorInterfaceOp::Adaptor adaptor, |
| 1034 | SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { |
| 1035 | // Create return type consisting of the last element of the first operand. |
| 1036 | auto operandType = adaptor.getOperand1().getType(); |
| 1037 | auto sval = dyn_cast<ShapedType>(Val&: operandType); |
| 1038 | if (!sval) |
| 1039 | return emitOptionalError(loc: location, args: "only shaped type operands allowed" ); |
| 1040 | int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic; |
| 1041 | auto type = IntegerType::get(context, width: 17); |
| 1042 | |
| 1043 | Attribute encoding; |
| 1044 | if (auto rankedTy = dyn_cast<RankedTensorType>(Val&: sval)) |
| 1045 | encoding = rankedTy.getEncoding(); |
| 1046 | inferredReturnShapes.push_back(Elt: ShapedTypeComponents({dim}, type, encoding)); |
| 1047 | return success(); |
| 1048 | } |
| 1049 | |
| 1050 | LogicalResult |
| 1051 | OpWithShapedTypeInferTypeAdaptorInterfaceOp::reifyReturnTypeShapes( |
| 1052 | OpBuilder &builder, ValueRange operands, |
| 1053 | llvm::SmallVectorImpl<Value> &shapes) { |
| 1054 | shapes = SmallVector<Value, 1>{ |
| 1055 | builder.createOrFold<tensor::DimOp>(location: getLoc(), args: operands.front(), args: 0)}; |
| 1056 | return success(); |
| 1057 | } |
| 1058 | |
| 1059 | //===----------------------------------------------------------------------===// |
| 1060 | // TestOpWithPropertiesAndInferredType |
| 1061 | //===----------------------------------------------------------------------===// |
| 1062 | |
| 1063 | LogicalResult TestOpWithPropertiesAndInferredType::inferReturnTypes( |
| 1064 | MLIRContext *context, std::optional<Location>, ValueRange operands, |
| 1065 | DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, |
| 1066 | SmallVectorImpl<Type> &inferredReturnTypes) { |
| 1067 | |
| 1068 | Adaptor adaptor(operands, attributes, properties, regions); |
| 1069 | inferredReturnTypes.push_back(Elt: IntegerType::get( |
| 1070 | context, width: adaptor.getLhs() + adaptor.getProperties().rhs)); |
| 1071 | return success(); |
| 1072 | } |
| 1073 | |
| 1074 | //===----------------------------------------------------------------------===// |
| 1075 | // LoopBlockOp |
| 1076 | //===----------------------------------------------------------------------===// |
| 1077 | |
| 1078 | void LoopBlockOp::getSuccessorRegions( |
| 1079 | RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { |
| 1080 | regions.emplace_back(Args: &getBody(), Args: getBody().getArguments()); |
| 1081 | if (point.isParent()) |
| 1082 | return; |
| 1083 | |
| 1084 | regions.emplace_back(Args: (*this)->getResults()); |
| 1085 | } |
| 1086 | |
| 1087 | OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) { |
| 1088 | assert(point == getBody()); |
| 1089 | return MutableOperandRange(getInitMutable()); |
| 1090 | } |
| 1091 | |
| 1092 | //===----------------------------------------------------------------------===// |
| 1093 | // LoopBlockTerminatorOp |
| 1094 | //===----------------------------------------------------------------------===// |
| 1095 | |
| 1096 | MutableOperandRange |
| 1097 | LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point) { |
| 1098 | if (point.isParent()) |
| 1099 | return getExitArgMutable(); |
| 1100 | return getNextIterArgMutable(); |
| 1101 | } |
| 1102 | |
| 1103 | //===----------------------------------------------------------------------===// |
// TestNoTerminatorOp
| 1105 | //===----------------------------------------------------------------------===// |
| 1106 | |
| 1107 | void TestNoTerminatorOp::getSuccessorRegions( |
| 1108 | RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {} |
| 1109 | |
| 1110 | //===----------------------------------------------------------------------===// |
// ManualCppOpWithFold
| 1112 | //===----------------------------------------------------------------------===// |
| 1113 | |
| 1114 | OpFoldResult ManualCppOpWithFold::fold(ArrayRef<Attribute> attributes) { |
| 1115 | // Just a simple fold for testing purposes that reads an operands constant |
| 1116 | // value and returns it. |
| 1117 | if (!attributes.empty()) |
| 1118 | return attributes.front(); |
| 1119 | return nullptr; |
| 1120 | } |
| 1121 | |
| 1122 | //===----------------------------------------------------------------------===// |
| 1123 | // Tensor/Buffer Ops |
| 1124 | //===----------------------------------------------------------------------===// |
| 1125 | |
| 1126 | void ReadBufferOp::getEffects( |
| 1127 | SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> |
| 1128 | &effects) { |
| 1129 | // The buffer operand is read. |
| 1130 | effects.emplace_back(Args: MemoryEffects::Read::get(), Args: &getBufferMutable(), |
| 1131 | Args: SideEffects::DefaultResource::get()); |
| 1132 | // The buffer contents are dumped. |
| 1133 | effects.emplace_back(Args: MemoryEffects::Write::get(), |
| 1134 | Args: SideEffects::DefaultResource::get()); |
| 1135 | } |
| 1136 | |
| 1137 | //===----------------------------------------------------------------------===// |
| 1138 | // Test Dataflow |
| 1139 | //===----------------------------------------------------------------------===// |
| 1140 | |
| 1141 | //===----------------------------------------------------------------------===// |
| 1142 | // TestCallAndStoreOp |
| 1143 | //===----------------------------------------------------------------------===// |
| 1144 | |
| 1145 | CallInterfaceCallable TestCallAndStoreOp::getCallableForCallee() { |
| 1146 | return getCallee(); |
| 1147 | } |
| 1148 | |
| 1149 | void TestCallAndStoreOp::setCalleeFromCallable(CallInterfaceCallable callee) { |
| 1150 | setCalleeAttr(cast<SymbolRefAttr>(Val&: callee)); |
| 1151 | } |
| 1152 | |
| 1153 | Operation::operand_range TestCallAndStoreOp::getArgOperands() { |
| 1154 | return getCalleeOperands(); |
| 1155 | } |
| 1156 | |
| 1157 | MutableOperandRange TestCallAndStoreOp::getArgOperandsMutable() { |
| 1158 | return getCalleeOperandsMutable(); |
| 1159 | } |
| 1160 | |
| 1161 | //===----------------------------------------------------------------------===// |
| 1162 | // TestCallOnDeviceOp |
| 1163 | //===----------------------------------------------------------------------===// |
| 1164 | |
| 1165 | CallInterfaceCallable TestCallOnDeviceOp::getCallableForCallee() { |
| 1166 | return getCallee(); |
| 1167 | } |
| 1168 | |
| 1169 | void TestCallOnDeviceOp::setCalleeFromCallable(CallInterfaceCallable callee) { |
| 1170 | setCalleeAttr(cast<SymbolRefAttr>(Val&: callee)); |
| 1171 | } |
| 1172 | |
| 1173 | Operation::operand_range TestCallOnDeviceOp::getArgOperands() { |
| 1174 | return getForwardedOperands(); |
| 1175 | } |
| 1176 | |
| 1177 | MutableOperandRange TestCallOnDeviceOp::getArgOperandsMutable() { |
| 1178 | return getForwardedOperandsMutable(); |
| 1179 | } |
| 1180 | |
| 1181 | //===----------------------------------------------------------------------===// |
| 1182 | // TestStoreWithARegion |
| 1183 | //===----------------------------------------------------------------------===// |
| 1184 | |
| 1185 | void TestStoreWithARegion::getSuccessorRegions( |
| 1186 | RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { |
| 1187 | if (point.isParent()) |
| 1188 | regions.emplace_back(Args: &getBody(), Args: getBody().front().getArguments()); |
| 1189 | else |
| 1190 | regions.emplace_back(); |
| 1191 | } |
| 1192 | |
| 1193 | //===----------------------------------------------------------------------===// |
| 1194 | // TestStoreWithALoopRegion |
| 1195 | //===----------------------------------------------------------------------===// |
| 1196 | |
| 1197 | void TestStoreWithALoopRegion::getSuccessorRegions( |
| 1198 | RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { |
| 1199 | // Both the operation itself and the region may be branching into the body or |
| 1200 | // back into the operation itself. It is possible for the operation not to |
| 1201 | // enter the body. |
| 1202 | regions.emplace_back( |
| 1203 | Args: RegionSuccessor(&getBody(), getBody().front().getArguments())); |
| 1204 | regions.emplace_back(); |
| 1205 | } |
| 1206 | |
| 1207 | //===----------------------------------------------------------------------===// |
| 1208 | // TestVersionedOpA |
| 1209 | //===----------------------------------------------------------------------===// |
| 1210 | |
| 1211 | LogicalResult |
| 1212 | TestVersionedOpA::readProperties(mlir::DialectBytecodeReader &reader, |
| 1213 | mlir::OperationState &state) { |
| 1214 | auto &prop = state.getOrAddProperties<Properties>(); |
| 1215 | if (mlir::failed(Result: reader.readAttribute(result&: prop.dims))) |
| 1216 | return mlir::failure(); |
| 1217 | |
| 1218 | // Check if we have a version. If not, assume we are parsing the current |
| 1219 | // version. |
| 1220 | auto maybeVersion = reader.getDialectVersion<test::TestDialect>(); |
| 1221 | if (succeeded(Result: maybeVersion)) { |
| 1222 | // If version is less than 2.0, there is no additional attribute to parse. |
| 1223 | // We can materialize missing properties post parsing before verification. |
| 1224 | const auto *version = |
| 1225 | reinterpret_cast<const TestDialectVersion *>(*maybeVersion); |
| 1226 | if ((version->major_ < 2)) { |
| 1227 | return success(); |
| 1228 | } |
| 1229 | } |
| 1230 | |
| 1231 | if (mlir::failed(Result: reader.readAttribute(result&: prop.modifier))) |
| 1232 | return mlir::failure(); |
| 1233 | return mlir::success(); |
| 1234 | } |
| 1235 | |
| 1236 | void TestVersionedOpA::writeProperties(mlir::DialectBytecodeWriter &writer) { |
| 1237 | auto &prop = getProperties(); |
| 1238 | writer.writeAttribute(attr: prop.dims); |
| 1239 | |
| 1240 | auto maybeVersion = writer.getDialectVersion<test::TestDialect>(); |
| 1241 | if (succeeded(Result: maybeVersion)) { |
| 1242 | // If version is less than 2.0, there is no additional attribute to write. |
| 1243 | const auto *version = |
| 1244 | reinterpret_cast<const TestDialectVersion *>(*maybeVersion); |
| 1245 | if ((version->major_ < 2)) { |
| 1246 | llvm::outs() << "downgrading op properties...\n" ; |
| 1247 | return; |
| 1248 | } |
| 1249 | } |
| 1250 | writer.writeAttribute(attr: prop.modifier); |
| 1251 | } |
| 1252 | |
| 1253 | //===----------------------------------------------------------------------===// |
| 1254 | // TestOpWithVersionedProperties |
| 1255 | //===----------------------------------------------------------------------===// |
| 1256 | |
| 1257 | llvm::LogicalResult TestOpWithVersionedProperties::readFromMlirBytecode( |
| 1258 | mlir::DialectBytecodeReader &reader, test::VersionedProperties &prop) { |
| 1259 | uint64_t value1, value2 = 0; |
| 1260 | if (failed(Result: reader.readVarInt(result&: value1))) |
| 1261 | return failure(); |
| 1262 | |
| 1263 | // Check if we have a version. If not, assume we are parsing the current |
| 1264 | // version. |
| 1265 | auto maybeVersion = reader.getDialectVersion<test::TestDialect>(); |
| 1266 | bool needToParseAnotherInt = true; |
| 1267 | if (succeeded(Result: maybeVersion)) { |
| 1268 | // If version is less than 2.0, there is no additional attribute to parse. |
| 1269 | // We can materialize missing properties post parsing before verification. |
| 1270 | const auto *version = |
| 1271 | reinterpret_cast<const TestDialectVersion *>(*maybeVersion); |
| 1272 | if ((version->major_ < 2)) |
| 1273 | needToParseAnotherInt = false; |
| 1274 | } |
| 1275 | if (needToParseAnotherInt && failed(Result: reader.readVarInt(result&: value2))) |
| 1276 | return failure(); |
| 1277 | |
| 1278 | prop.value1 = value1; |
| 1279 | prop.value2 = value2; |
| 1280 | return success(); |
| 1281 | } |
| 1282 | |
| 1283 | void TestOpWithVersionedProperties::writeToMlirBytecode( |
| 1284 | mlir::DialectBytecodeWriter &writer, |
| 1285 | const test::VersionedProperties &prop) { |
| 1286 | writer.writeVarInt(value: prop.value1); |
| 1287 | writer.writeVarInt(value: prop.value2); |
| 1288 | } |
| 1289 | |
| 1290 | //===----------------------------------------------------------------------===// |
| 1291 | // TestMultiSlotAlloca |
| 1292 | //===----------------------------------------------------------------------===// |
| 1293 | |
| 1294 | llvm::SmallVector<MemorySlot> TestMultiSlotAlloca::getPromotableSlots() { |
| 1295 | SmallVector<MemorySlot> slots; |
| 1296 | for (Value result : getResults()) { |
| 1297 | slots.push_back(Elt: MemorySlot{ |
| 1298 | .ptr: result, .elemType: cast<MemRefType>(Val: result.getType()).getElementType()}); |
| 1299 | } |
| 1300 | return slots; |
| 1301 | } |
| 1302 | |
| 1303 | Value TestMultiSlotAlloca::getDefaultValue(const MemorySlot &slot, |
| 1304 | OpBuilder &builder) { |
| 1305 | return builder.create<TestOpConstant>(location: getLoc(), args: slot.elemType, |
| 1306 | args: builder.getI32IntegerAttr(value: 42)); |
| 1307 | } |
| 1308 | |
| 1309 | void TestMultiSlotAlloca::handleBlockArgument(const MemorySlot &slot, |
| 1310 | BlockArgument argument, |
| 1311 | OpBuilder &builder) { |
| 1312 | // Not relevant for testing. |
| 1313 | } |
| 1314 | |
| 1315 | /// Creates a new TestMultiSlotAlloca operation, just without the `slot`. |
| 1316 | static std::optional<TestMultiSlotAlloca> |
| 1317 | createNewMultiAllocaWithoutSlot(const MemorySlot &slot, OpBuilder &builder, |
| 1318 | TestMultiSlotAlloca oldOp) { |
| 1319 | |
| 1320 | if (oldOp.getNumResults() == 1) { |
| 1321 | oldOp.erase(); |
| 1322 | return std::nullopt; |
| 1323 | } |
| 1324 | |
| 1325 | SmallVector<Type> newTypes; |
| 1326 | SmallVector<Value> remainingValues; |
| 1327 | |
| 1328 | for (Value oldResult : oldOp.getResults()) { |
| 1329 | if (oldResult == slot.ptr) |
| 1330 | continue; |
| 1331 | remainingValues.push_back(Elt: oldResult); |
| 1332 | newTypes.push_back(Elt: oldResult.getType()); |
| 1333 | } |
| 1334 | |
| 1335 | OpBuilder::InsertionGuard guard(builder); |
| 1336 | builder.setInsertionPoint(oldOp); |
| 1337 | auto replacement = |
| 1338 | builder.create<TestMultiSlotAlloca>(location: oldOp->getLoc(), args&: newTypes); |
| 1339 | for (auto [oldResult, newResult] : |
| 1340 | llvm::zip_equal(t&: remainingValues, u: replacement.getResults())) |
| 1341 | oldResult.replaceAllUsesWith(newValue: newResult); |
| 1342 | |
| 1343 | oldOp.erase(); |
| 1344 | return replacement; |
| 1345 | } |
| 1346 | |
| 1347 | std::optional<PromotableAllocationOpInterface> |
| 1348 | TestMultiSlotAlloca::handlePromotionComplete(const MemorySlot &slot, |
| 1349 | Value defaultValue, |
| 1350 | OpBuilder &builder) { |
| 1351 | if (defaultValue && defaultValue.use_empty()) |
| 1352 | defaultValue.getDefiningOp()->erase(); |
| 1353 | return createNewMultiAllocaWithoutSlot(slot, builder, oldOp: *this); |
| 1354 | } |
| 1355 | |
| 1356 | SmallVector<DestructurableMemorySlot> |
| 1357 | TestMultiSlotAlloca::getDestructurableSlots() { |
| 1358 | SmallVector<DestructurableMemorySlot> slots; |
| 1359 | for (Value result : getResults()) { |
| 1360 | auto memrefType = cast<MemRefType>(Val: result.getType()); |
| 1361 | auto destructurable = dyn_cast<DestructurableTypeInterface>(Val&: memrefType); |
| 1362 | if (!destructurable) |
| 1363 | continue; |
| 1364 | |
| 1365 | std::optional<DenseMap<Attribute, Type>> destructuredType = |
| 1366 | destructurable.getSubelementIndexMap(); |
| 1367 | if (!destructuredType) |
| 1368 | continue; |
| 1369 | slots.emplace_back( |
| 1370 | Args: DestructurableMemorySlot{{.ptr: result, .elemType: memrefType}, .subelementTypes: *destructuredType}); |
| 1371 | } |
| 1372 | return slots; |
| 1373 | } |
| 1374 | |
| 1375 | DenseMap<Attribute, MemorySlot> TestMultiSlotAlloca::destructure( |
| 1376 | const DestructurableMemorySlot &slot, |
| 1377 | const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder, |
| 1378 | SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) { |
| 1379 | OpBuilder::InsertionGuard guard(builder); |
| 1380 | builder.setInsertionPointAfter(*this); |
| 1381 | |
| 1382 | DenseMap<Attribute, MemorySlot> slotMap; |
| 1383 | |
| 1384 | for (Attribute usedIndex : usedIndices) { |
| 1385 | Type elemType = slot.subelementTypes.lookup(Val: usedIndex); |
| 1386 | MemRefType elemPtr = MemRefType::get(shape: {}, elementType: elemType); |
| 1387 | auto subAlloca = builder.create<TestMultiSlotAlloca>(location: getLoc(), args&: elemPtr); |
| 1388 | newAllocators.push_back(Elt: subAlloca); |
| 1389 | slotMap.try_emplace<MemorySlot>(Key: usedIndex, |
| 1390 | Args: {.ptr: subAlloca.getResult(i: 0), .elemType: elemType}); |
| 1391 | } |
| 1392 | |
| 1393 | return slotMap; |
| 1394 | } |
| 1395 | |
| 1396 | std::optional<DestructurableAllocationOpInterface> |
| 1397 | TestMultiSlotAlloca::handleDestructuringComplete( |
| 1398 | const DestructurableMemorySlot &slot, OpBuilder &builder) { |
| 1399 | return createNewMultiAllocaWithoutSlot(slot, builder, oldOp: *this); |
| 1400 | } |
| 1401 | |
| 1402 | ::mlir::LogicalResult test::TestDummyTensorOp::bufferize( |
| 1403 | ::mlir::RewriterBase &rewriter, |
| 1404 | const ::mlir::bufferization::BufferizationOptions &options, |
| 1405 | ::mlir::bufferization::BufferizationState &state) { |
| 1406 | auto buffer = |
| 1407 | mlir::bufferization::getBuffer(rewriter, value: getInput(), options, state); |
| 1408 | if (mlir::failed(Result: buffer)) |
| 1409 | return failure(); |
| 1410 | |
| 1411 | const auto outType = getOutput().getType(); |
| 1412 | const auto bufferizedOutType = test::TestMemrefType::get( |
| 1413 | context: getContext(), shape: outType.getShape(), elementType: outType.getElementType(), memSpace: nullptr); |
| 1414 | // replace op with memref analogy |
| 1415 | auto dummyMemrefOp = rewriter.create<test::TestDummyMemrefOp>( |
| 1416 | location: getLoc(), args: bufferizedOutType, args&: *buffer); |
| 1417 | |
| 1418 | mlir::bufferization::replaceOpWithBufferizedValues(rewriter, op: getOperation(), |
| 1419 | values: dummyMemrefOp.getResult()); |
| 1420 | |
| 1421 | return mlir::success(); |
| 1422 | } |
| 1423 | |
| 1424 | ::mlir::LogicalResult test::TestCreateTensorOp::bufferize( |
| 1425 | ::mlir::RewriterBase &rewriter, |
| 1426 | const ::mlir::bufferization::BufferizationOptions &options, |
| 1427 | ::mlir::bufferization::BufferizationState &state) { |
| 1428 | // Note: mlir::bufferization::getBufferType() would internally call |
| 1429 | // TestCreateTensorOp::getBufferType() |
| 1430 | const auto bufferizedOutType = |
| 1431 | mlir::bufferization::getBufferType(value: getOutput(), options, state); |
| 1432 | if (mlir::failed(Result: bufferizedOutType)) |
| 1433 | return failure(); |
| 1434 | |
| 1435 | // replace op with memref analogy |
| 1436 | auto createMemrefOp = |
| 1437 | rewriter.create<test::TestCreateMemrefOp>(location: getLoc(), args: *bufferizedOutType); |
| 1438 | |
| 1439 | mlir::bufferization::replaceOpWithBufferizedValues( |
| 1440 | rewriter, op: getOperation(), values: createMemrefOp.getResult()); |
| 1441 | |
| 1442 | return mlir::success(); |
| 1443 | } |
| 1444 | |
| 1445 | mlir::FailureOr<mlir::bufferization::BufferLikeType> |
| 1446 | test::TestCreateTensorOp::getBufferType( |
| 1447 | mlir::Value value, const mlir::bufferization::BufferizationOptions &, |
| 1448 | const mlir::bufferization::BufferizationState &, |
| 1449 | llvm::SmallVector<::mlir::Value> &) { |
| 1450 | const auto type = dyn_cast<test::TestTensorType>(Val: value.getType()); |
| 1451 | if (type == nullptr) |
| 1452 | return failure(); |
| 1453 | |
| 1454 | return cast<mlir::bufferization::BufferLikeType>(Val: test::TestMemrefType::get( |
| 1455 | context: getContext(), shape: type.getShape(), elementType: type.getElementType(), memSpace: nullptr)); |
| 1456 | } |
| 1457 | |