| 1 | //===-- FIROps.cpp --------------------------------------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "flang/Optimizer/Dialect/FIROps.h" |
| 14 | #include "flang/Optimizer/Dialect/FIRAttr.h" |
| 15 | #include "flang/Optimizer/Dialect/FIRDialect.h" |
| 16 | #include "flang/Optimizer/Dialect/FIROpsSupport.h" |
| 17 | #include "flang/Optimizer/Dialect/FIRType.h" |
| 18 | #include "flang/Optimizer/Dialect/Support/FIRContext.h" |
| 19 | #include "flang/Optimizer/Dialect/Support/KindMapping.h" |
| 20 | #include "flang/Optimizer/Support/Utils.h" |
| 21 | #include "mlir/Dialect/CommonFolders.h" |
| 22 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 23 | #include "mlir/Dialect/OpenACC/OpenACC.h" |
| 24 | #include "mlir/Dialect/OpenMP/OpenMPDialect.h" |
| 25 | #include "mlir/IR/Attributes.h" |
| 26 | #include "mlir/IR/BuiltinAttributes.h" |
| 27 | #include "mlir/IR/BuiltinOps.h" |
| 28 | #include "mlir/IR/Diagnostics.h" |
| 29 | #include "mlir/IR/Matchers.h" |
| 30 | #include "mlir/IR/OpDefinition.h" |
| 31 | #include "mlir/IR/PatternMatch.h" |
| 32 | #include "mlir/IR/TypeRange.h" |
| 33 | #include "llvm/ADT/STLExtras.h" |
| 34 | #include "llvm/ADT/SmallVector.h" |
| 35 | #include "llvm/ADT/TypeSwitch.h" |
| 36 | #include "llvm/Support/CommandLine.h" |
| 37 | |
| 38 | namespace { |
| 39 | #include "flang/Optimizer/Dialect/CanonicalizationPatterns.inc" |
| 40 | } // namespace |
| 41 | |
| 42 | static llvm::cl::opt<bool> clUseStrictVolatileVerification( |
| 43 | "strict-fir-volatile-verifier" , llvm::cl::init(false), |
| 44 | llvm::cl::desc( |
| 45 | "use stricter verifier for FIR operations with volatile types" )); |
| 46 | |
| 47 | bool fir::useStrictVolatileVerification() { |
| 48 | return clUseStrictVolatileVerification; |
| 49 | } |
| 50 | |
| 51 | static void propagateAttributes(mlir::Operation *fromOp, |
| 52 | mlir::Operation *toOp) { |
| 53 | if (!fromOp || !toOp) |
| 54 | return; |
| 55 | |
| 56 | for (mlir::NamedAttribute attr : fromOp->getAttrs()) { |
| 57 | if (attr.getName().getValue().starts_with( |
| 58 | mlir::acc::OpenACCDialect::getDialectNamespace())) |
| 59 | toOp->setAttr(attr.getName(), attr.getValue()); |
| 60 | } |
| 61 | } |
| 62 | |
| 63 | /// Return true if a sequence type is of some incomplete size or a record type |
| 64 | /// is malformed or contains an incomplete sequence type. An incomplete sequence |
| 65 | /// type is one with more unknown extents in the type than have been provided |
| 66 | /// via `dynamicExtents`. Sequence types with an unknown rank are incomplete by |
| 67 | /// definition. |
| 68 | static bool verifyInType(mlir::Type inType, |
| 69 | llvm::SmallVectorImpl<llvm::StringRef> &visited, |
| 70 | unsigned dynamicExtents = 0) { |
| 71 | if (auto st = mlir::dyn_cast<fir::SequenceType>(inType)) { |
| 72 | auto shape = st.getShape(); |
| 73 | if (shape.size() == 0) |
| 74 | return true; |
| 75 | for (std::size_t i = 0, end = shape.size(); i < end; ++i) { |
| 76 | if (shape[i] != fir::SequenceType::getUnknownExtent()) |
| 77 | continue; |
| 78 | if (dynamicExtents-- == 0) |
| 79 | return true; |
| 80 | } |
| 81 | } else if (auto rt = mlir::dyn_cast<fir::RecordType>(inType)) { |
| 82 | // don't recurse if we're already visiting this one |
| 83 | if (llvm::is_contained(visited, rt.getName())) |
| 84 | return false; |
| 85 | // keep track of record types currently being visited |
| 86 | visited.push_back(Elt: rt.getName()); |
| 87 | for (auto &field : rt.getTypeList()) |
| 88 | if (verifyInType(field.second, visited)) |
| 89 | return true; |
| 90 | visited.pop_back(); |
| 91 | } |
| 92 | return false; |
| 93 | } |
| 94 | |
| 95 | static bool verifyTypeParamCount(mlir::Type inType, unsigned numParams) { |
| 96 | auto ty = fir::unwrapSequenceType(inType); |
| 97 | if (numParams > 0) { |
| 98 | if (auto recTy = mlir::dyn_cast<fir::RecordType>(ty)) |
| 99 | return numParams != recTy.getNumLenParams(); |
| 100 | if (auto chrTy = mlir::dyn_cast<fir::CharacterType>(ty)) |
| 101 | return !(numParams == 1 && chrTy.hasDynamicLen()); |
| 102 | return true; |
| 103 | } |
| 104 | if (auto chrTy = mlir::dyn_cast<fir::CharacterType>(ty)) |
| 105 | return !chrTy.hasConstantLen(); |
| 106 | return false; |
| 107 | } |
| 108 | |
| 109 | /// Parser shared by Alloca and Allocmem |
| 110 | /// |
| 111 | /// operation ::= %res = (`fir.alloca` | `fir.allocmem`) $in_type |
| 112 | /// ( `(` $typeparams `)` )? ( `,` $shape )? |
| 113 | /// attr-dict-without-keyword |
| 114 | template <typename FN> |
| 115 | static mlir::ParseResult parseAllocatableOp(FN wrapResultType, |
| 116 | mlir::OpAsmParser &parser, |
| 117 | mlir::OperationState &result) { |
| 118 | mlir::Type intype; |
| 119 | if (parser.parseType(result&: intype)) |
| 120 | return mlir::failure(); |
| 121 | auto &builder = parser.getBuilder(); |
| 122 | result.addAttribute("in_type" , mlir::TypeAttr::get(intype)); |
| 123 | llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands; |
| 124 | llvm::SmallVector<mlir::Type> typeVec; |
| 125 | bool hasOperands = false; |
| 126 | std::int32_t typeparamsSize = 0; |
| 127 | if (!parser.parseOptionalLParen()) { |
| 128 | // parse the LEN params of the derived type. (<params> : <types>) |
| 129 | if (parser.parseOperandList(result&: operands, delimiter: mlir::OpAsmParser::Delimiter::None) || |
| 130 | parser.parseColonTypeList(result&: typeVec) || parser.parseRParen()) |
| 131 | return mlir::failure(); |
| 132 | typeparamsSize = operands.size(); |
| 133 | hasOperands = true; |
| 134 | } |
| 135 | std::int32_t shapeSize = 0; |
| 136 | if (!parser.parseOptionalComma()) { |
| 137 | // parse size to scale by, vector of n dimensions of type index |
| 138 | if (parser.parseOperandList(result&: operands, delimiter: mlir::OpAsmParser::Delimiter::None)) |
| 139 | return mlir::failure(); |
| 140 | shapeSize = operands.size() - typeparamsSize; |
| 141 | auto idxTy = builder.getIndexType(); |
| 142 | for (std::int32_t i = typeparamsSize, end = operands.size(); i != end; ++i) |
| 143 | typeVec.push_back(Elt: idxTy); |
| 144 | hasOperands = true; |
| 145 | } |
| 146 | if (hasOperands && |
| 147 | parser.resolveOperands(operands, types&: typeVec, loc: parser.getNameLoc(), |
| 148 | result&: result.operands)) |
| 149 | return mlir::failure(); |
| 150 | mlir::Type restype = wrapResultType(intype); |
| 151 | if (!restype) { |
| 152 | parser.emitError(loc: parser.getNameLoc(), message: "invalid allocate type: " ) << intype; |
| 153 | return mlir::failure(); |
| 154 | } |
| 155 | result.addAttribute("operandSegmentSizes" , builder.getDenseI32ArrayAttr( |
| 156 | {typeparamsSize, shapeSize})); |
| 157 | if (parser.parseOptionalAttrDict(result&: result.attributes) || |
| 158 | parser.addTypeToList(type: restype, result&: result.types)) |
| 159 | return mlir::failure(); |
| 160 | return mlir::success(); |
| 161 | } |
| 162 | |
| 163 | template <typename OP> |
| 164 | static void printAllocatableOp(mlir::OpAsmPrinter &p, OP &op) { |
| 165 | p << ' ' << op.getInType(); |
| 166 | if (!op.getTypeparams().empty()) { |
| 167 | p << '(' << op.getTypeparams() << " : " << op.getTypeparams().getTypes() |
| 168 | << ')'; |
| 169 | } |
| 170 | // print the shape of the allocation (if any); all must be index type |
| 171 | for (auto sh : op.getShape()) { |
| 172 | p << ", " ; |
| 173 | p.printOperand(sh); |
| 174 | } |
| 175 | p.printOptionalAttrDict(attrs: op->getAttrs(), elidedAttrs: {"in_type" , "operandSegmentSizes" }); |
| 176 | } |
| 177 | |
| 178 | //===----------------------------------------------------------------------===// |
| 179 | // AllocaOp |
| 180 | //===----------------------------------------------------------------------===// |
| 181 | |
| 182 | /// Create a legal memory reference as return type |
| 183 | static mlir::Type wrapAllocaResultType(mlir::Type intype) { |
| 184 | // FIR semantics: memory references to memory references are disallowed |
| 185 | if (mlir::isa<fir::ReferenceType>(intype)) |
| 186 | return {}; |
| 187 | return fir::ReferenceType::get(intype); |
| 188 | } |
| 189 | |
| 190 | mlir::Type fir::AllocaOp::getAllocatedType() { |
| 191 | return mlir::cast<fir::ReferenceType>(getType()).getEleTy(); |
| 192 | } |
| 193 | |
| 194 | mlir::Type fir::AllocaOp::getRefTy(mlir::Type ty) { |
| 195 | return fir::ReferenceType::get(ty); |
| 196 | } |
| 197 | |
| 198 | void fir::AllocaOp::build(mlir::OpBuilder &builder, |
| 199 | mlir::OperationState &result, mlir::Type inType, |
| 200 | llvm::StringRef uniqName, mlir::ValueRange typeparams, |
| 201 | mlir::ValueRange shape, |
| 202 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
| 203 | auto nameAttr = builder.getStringAttr(uniqName); |
| 204 | build(builder, result, wrapAllocaResultType(inType), inType, nameAttr, {}, |
| 205 | /*pinned=*/false, typeparams, shape); |
| 206 | result.addAttributes(attributes); |
| 207 | } |
| 208 | |
| 209 | void fir::AllocaOp::build(mlir::OpBuilder &builder, |
| 210 | mlir::OperationState &result, mlir::Type inType, |
| 211 | llvm::StringRef uniqName, bool pinned, |
| 212 | mlir::ValueRange typeparams, mlir::ValueRange shape, |
| 213 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
| 214 | auto nameAttr = builder.getStringAttr(uniqName); |
| 215 | build(builder, result, wrapAllocaResultType(inType), inType, nameAttr, {}, |
| 216 | pinned, typeparams, shape); |
| 217 | result.addAttributes(attributes); |
| 218 | } |
| 219 | |
| 220 | void fir::AllocaOp::build(mlir::OpBuilder &builder, |
| 221 | mlir::OperationState &result, mlir::Type inType, |
| 222 | llvm::StringRef uniqName, llvm::StringRef bindcName, |
| 223 | mlir::ValueRange typeparams, mlir::ValueRange shape, |
| 224 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
| 225 | auto nameAttr = |
| 226 | uniqName.empty() ? mlir::StringAttr{} : builder.getStringAttr(uniqName); |
| 227 | auto bindcAttr = |
| 228 | bindcName.empty() ? mlir::StringAttr{} : builder.getStringAttr(bindcName); |
| 229 | build(builder, result, wrapAllocaResultType(inType), inType, nameAttr, |
| 230 | bindcAttr, /*pinned=*/false, typeparams, shape); |
| 231 | result.addAttributes(attributes); |
| 232 | } |
| 233 | |
| 234 | void fir::AllocaOp::build(mlir::OpBuilder &builder, |
| 235 | mlir::OperationState &result, mlir::Type inType, |
| 236 | llvm::StringRef uniqName, llvm::StringRef bindcName, |
| 237 | bool pinned, mlir::ValueRange typeparams, |
| 238 | mlir::ValueRange shape, |
| 239 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
| 240 | auto nameAttr = |
| 241 | uniqName.empty() ? mlir::StringAttr{} : builder.getStringAttr(uniqName); |
| 242 | auto bindcAttr = |
| 243 | bindcName.empty() ? mlir::StringAttr{} : builder.getStringAttr(bindcName); |
| 244 | build(builder, result, wrapAllocaResultType(inType), inType, nameAttr, |
| 245 | bindcAttr, pinned, typeparams, shape); |
| 246 | result.addAttributes(attributes); |
| 247 | } |
| 248 | |
| 249 | void fir::AllocaOp::build(mlir::OpBuilder &builder, |
| 250 | mlir::OperationState &result, mlir::Type inType, |
| 251 | mlir::ValueRange typeparams, mlir::ValueRange shape, |
| 252 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
| 253 | build(builder, result, wrapAllocaResultType(inType), inType, {}, {}, |
| 254 | /*pinned=*/false, typeparams, shape); |
| 255 | result.addAttributes(attributes); |
| 256 | } |
| 257 | |
| 258 | void fir::AllocaOp::build(mlir::OpBuilder &builder, |
| 259 | mlir::OperationState &result, mlir::Type inType, |
| 260 | bool pinned, mlir::ValueRange typeparams, |
| 261 | mlir::ValueRange shape, |
| 262 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
| 263 | build(builder, result, wrapAllocaResultType(inType), inType, {}, {}, pinned, |
| 264 | typeparams, shape); |
| 265 | result.addAttributes(attributes); |
| 266 | } |
| 267 | |
| 268 | mlir::ParseResult fir::AllocaOp::parse(mlir::OpAsmParser &parser, |
| 269 | mlir::OperationState &result) { |
| 270 | return parseAllocatableOp(wrapAllocaResultType, parser, result); |
| 271 | } |
| 272 | |
| 273 | void fir::AllocaOp::print(mlir::OpAsmPrinter &p) { |
| 274 | printAllocatableOp(p, *this); |
| 275 | } |
| 276 | |
| 277 | llvm::LogicalResult fir::AllocaOp::verify() { |
| 278 | llvm::SmallVector<llvm::StringRef> visited; |
| 279 | if (verifyInType(getInType(), visited, numShapeOperands())) |
| 280 | return emitOpError("invalid type for allocation" ); |
| 281 | if (verifyTypeParamCount(getInType(), numLenParams())) |
| 282 | return emitOpError("LEN params do not correspond to type" ); |
| 283 | mlir::Type outType = getType(); |
| 284 | if (!mlir::isa<fir::ReferenceType>(outType)) |
| 285 | return emitOpError("must be a !fir.ref type" ); |
| 286 | return mlir::success(); |
| 287 | } |
| 288 | |
| 289 | bool fir::AllocaOp::ownsNestedAlloca(mlir::Operation *op) { |
| 290 | return op->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>() || |
| 291 | op->hasTrait<mlir::OpTrait::AutomaticAllocationScope>() || |
| 292 | mlir::isa<mlir::LoopLikeOpInterface>(*op); |
| 293 | } |
| 294 | |
| 295 | mlir::Region *fir::AllocaOp::getOwnerRegion() { |
| 296 | mlir::Operation *currentOp = getOperation(); |
| 297 | while (mlir::Operation *parentOp = currentOp->getParentOp()) { |
| 298 | // If the operation was not registered, inquiries about its traits will be |
| 299 | // incorrect and it is not possible to reason about the operation. This |
| 300 | // should not happen in a normal Fortran compilation flow, but be foolproof. |
| 301 | if (!parentOp->isRegistered()) |
| 302 | return nullptr; |
| 303 | if (fir::AllocaOp::ownsNestedAlloca(parentOp)) |
| 304 | return currentOp->getParentRegion(); |
| 305 | currentOp = parentOp; |
| 306 | } |
| 307 | return nullptr; |
| 308 | } |
| 309 | |
| 310 | //===----------------------------------------------------------------------===// |
| 311 | // AllocMemOp |
| 312 | //===----------------------------------------------------------------------===// |
| 313 | |
| 314 | /// Create a legal heap reference as return type |
| 315 | static mlir::Type wrapAllocMemResultType(mlir::Type intype) { |
| 316 | // Fortran semantics: C852 an entity cannot be both ALLOCATABLE and POINTER |
| 317 | // 8.5.3 note 1 prohibits ALLOCATABLE procedures as well |
| 318 | // FIR semantics: one may not allocate a memory reference value |
| 319 | if (mlir::isa<fir::ReferenceType, fir::HeapType, fir::PointerType, |
| 320 | mlir::FunctionType>(intype)) |
| 321 | return {}; |
| 322 | return fir::HeapType::get(intype); |
| 323 | } |
| 324 | |
| 325 | mlir::Type fir::AllocMemOp::getAllocatedType() { |
| 326 | return mlir::cast<fir::HeapType>(getType()).getEleTy(); |
| 327 | } |
| 328 | |
| 329 | mlir::Type fir::AllocMemOp::getRefTy(mlir::Type ty) { |
| 330 | return fir::HeapType::get(ty); |
| 331 | } |
| 332 | |
| 333 | void fir::AllocMemOp::build(mlir::OpBuilder &builder, |
| 334 | mlir::OperationState &result, mlir::Type inType, |
| 335 | llvm::StringRef uniqName, |
| 336 | mlir::ValueRange typeparams, mlir::ValueRange shape, |
| 337 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
| 338 | auto nameAttr = builder.getStringAttr(uniqName); |
| 339 | build(builder, result, wrapAllocMemResultType(inType), inType, nameAttr, {}, |
| 340 | typeparams, shape); |
| 341 | result.addAttributes(attributes); |
| 342 | } |
| 343 | |
| 344 | void fir::AllocMemOp::build(mlir::OpBuilder &builder, |
| 345 | mlir::OperationState &result, mlir::Type inType, |
| 346 | llvm::StringRef uniqName, llvm::StringRef bindcName, |
| 347 | mlir::ValueRange typeparams, mlir::ValueRange shape, |
| 348 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
| 349 | auto nameAttr = builder.getStringAttr(uniqName); |
| 350 | auto bindcAttr = builder.getStringAttr(bindcName); |
| 351 | build(builder, result, wrapAllocMemResultType(inType), inType, nameAttr, |
| 352 | bindcAttr, typeparams, shape); |
| 353 | result.addAttributes(attributes); |
| 354 | } |
| 355 | |
| 356 | void fir::AllocMemOp::build(mlir::OpBuilder &builder, |
| 357 | mlir::OperationState &result, mlir::Type inType, |
| 358 | mlir::ValueRange typeparams, mlir::ValueRange shape, |
| 359 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
| 360 | build(builder, result, wrapAllocMemResultType(inType), inType, {}, {}, |
| 361 | typeparams, shape); |
| 362 | result.addAttributes(attributes); |
| 363 | } |
| 364 | |
| 365 | mlir::ParseResult fir::AllocMemOp::parse(mlir::OpAsmParser &parser, |
| 366 | mlir::OperationState &result) { |
| 367 | return parseAllocatableOp(wrapAllocMemResultType, parser, result); |
| 368 | } |
| 369 | |
| 370 | void fir::AllocMemOp::print(mlir::OpAsmPrinter &p) { |
| 371 | printAllocatableOp(p, *this); |
| 372 | } |
| 373 | |
| 374 | llvm::LogicalResult fir::AllocMemOp::verify() { |
| 375 | llvm::SmallVector<llvm::StringRef> visited; |
| 376 | if (verifyInType(getInType(), visited, numShapeOperands())) |
| 377 | return emitOpError("invalid type for allocation" ); |
| 378 | if (verifyTypeParamCount(getInType(), numLenParams())) |
| 379 | return emitOpError("LEN params do not correspond to type" ); |
| 380 | mlir::Type outType = getType(); |
| 381 | if (!mlir::dyn_cast<fir::HeapType>(outType)) |
| 382 | return emitOpError("must be a !fir.heap type" ); |
| 383 | if (fir::isa_unknown_size_box(fir::dyn_cast_ptrEleTy(outType))) |
| 384 | return emitOpError("cannot allocate !fir.box of unknown rank or type" ); |
| 385 | return mlir::success(); |
| 386 | } |
| 387 | |
| 388 | //===----------------------------------------------------------------------===// |
| 389 | // ArrayCoorOp |
| 390 | //===----------------------------------------------------------------------===// |
| 391 | |
| 392 | // CHARACTERs and derived types with LEN PARAMETERs are dependent types that |
| 393 | // require runtime values to fully define the type of an object. |
| 394 | static bool validTypeParams(mlir::Type dynTy, mlir::ValueRange typeParams, |
| 395 | bool allowParamsForBox = false) { |
| 396 | dynTy = fir::unwrapAllRefAndSeqType(dynTy); |
| 397 | if (mlir::isa<fir::BaseBoxType>(dynTy)) { |
| 398 | // A box value will contain type parameter values itself. |
| 399 | if (!allowParamsForBox) |
| 400 | return typeParams.size() == 0; |
| 401 | |
| 402 | // A boxed value may have no length parameters, when the lengths |
| 403 | // are assumed. If dynamic lengths are used, then proceed |
| 404 | // to the verification below. |
| 405 | if (typeParams.size() == 0) |
| 406 | return true; |
| 407 | |
| 408 | dynTy = fir::getFortranElementType(dynTy); |
| 409 | } |
| 410 | // Derived type must have all type parameters satisfied. |
| 411 | if (auto recTy = mlir::dyn_cast<fir::RecordType>(dynTy)) |
| 412 | return typeParams.size() == recTy.getNumLenParams(); |
| 413 | // Characters with non-constant LEN must have a type parameter value. |
| 414 | if (auto charTy = mlir::dyn_cast<fir::CharacterType>(dynTy)) |
| 415 | if (charTy.hasDynamicLen()) |
| 416 | return typeParams.size() == 1; |
| 417 | // Otherwise, any type parameters are invalid. |
| 418 | return typeParams.size() == 0; |
| 419 | } |
| 420 | |
| 421 | llvm::LogicalResult fir::ArrayCoorOp::verify() { |
| 422 | auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(getMemref().getType()); |
| 423 | auto arrTy = mlir::dyn_cast<fir::SequenceType>(eleTy); |
| 424 | if (!arrTy) |
| 425 | return emitOpError("must be a reference to an array" ); |
| 426 | auto arrDim = arrTy.getDimension(); |
| 427 | |
| 428 | if (auto shapeOp = getShape()) { |
| 429 | auto shapeTy = shapeOp.getType(); |
| 430 | unsigned shapeTyRank = 0; |
| 431 | if (auto s = mlir::dyn_cast<fir::ShapeType>(shapeTy)) { |
| 432 | shapeTyRank = s.getRank(); |
| 433 | } else if (auto ss = mlir::dyn_cast<fir::ShapeShiftType>(shapeTy)) { |
| 434 | shapeTyRank = ss.getRank(); |
| 435 | } else { |
| 436 | auto s = mlir::cast<fir::ShiftType>(shapeTy); |
| 437 | shapeTyRank = s.getRank(); |
| 438 | // TODO: it looks like PreCGRewrite and CodeGen can support |
| 439 | // fir.shift with plain array reference, so we may consider |
| 440 | // removing this check. |
| 441 | if (!mlir::isa<fir::BaseBoxType>(getMemref().getType())) |
| 442 | return emitOpError("shift can only be provided with fir.box memref" ); |
| 443 | } |
| 444 | if (arrDim && arrDim != shapeTyRank) |
| 445 | return emitOpError("rank of dimension mismatched" ); |
| 446 | // TODO: support slicing with changing the number of dimensions, |
| 447 | // e.g. when array_coor represents an element access to array(:,1,:) |
| 448 | // slice: the shape is 3D and the number of indices is 2 in this case. |
| 449 | if (shapeTyRank != getIndices().size()) |
| 450 | return emitOpError("number of indices do not match dim rank" ); |
| 451 | } |
| 452 | |
| 453 | if (auto sliceOp = getSlice()) { |
| 454 | if (auto sl = mlir::dyn_cast_or_null<fir::SliceOp>(sliceOp.getDefiningOp())) |
| 455 | if (!sl.getSubstr().empty()) |
| 456 | return emitOpError("array_coor cannot take a slice with substring" ); |
| 457 | if (auto sliceTy = mlir::dyn_cast<fir::SliceType>(sliceOp.getType())) |
| 458 | if (sliceTy.getRank() != arrDim) |
| 459 | return emitOpError("rank of dimension in slice mismatched" ); |
| 460 | } |
| 461 | if (!validTypeParams(getMemref().getType(), getTypeparams())) |
| 462 | return emitOpError("invalid type parameters" ); |
| 463 | |
| 464 | return mlir::success(); |
| 465 | } |
| 466 | |
| 467 | // Pull in fir.embox and fir.rebox into fir.array_coor when possible. |
| 468 | struct SimplifyArrayCoorOp : public mlir::OpRewritePattern<fir::ArrayCoorOp> { |
| 469 | using mlir::OpRewritePattern<fir::ArrayCoorOp>::OpRewritePattern; |
| 470 | llvm::LogicalResult |
| 471 | matchAndRewrite(fir::ArrayCoorOp op, |
| 472 | mlir::PatternRewriter &rewriter) const override { |
| 473 | mlir::Value memref = op.getMemref(); |
| 474 | if (!mlir::isa<fir::BaseBoxType>(memref.getType())) |
| 475 | return mlir::failure(); |
| 476 | |
| 477 | mlir::Value boxedMemref, boxedShape, boxedSlice; |
| 478 | if (auto emboxOp = |
| 479 | mlir::dyn_cast_or_null<fir::EmboxOp>(memref.getDefiningOp())) { |
| 480 | boxedMemref = emboxOp.getMemref(); |
| 481 | boxedShape = emboxOp.getShape(); |
| 482 | boxedSlice = emboxOp.getSlice(); |
| 483 | // If any of operands, that are not currently supported for migration |
| 484 | // to ArrayCoorOp, is present, don't rewrite. |
| 485 | if (!emboxOp.getTypeparams().empty() || emboxOp.getSourceBox() || |
| 486 | emboxOp.getAccessMap()) |
| 487 | return mlir::failure(); |
| 488 | } else if (auto reboxOp = mlir::dyn_cast_or_null<fir::ReboxOp>( |
| 489 | memref.getDefiningOp())) { |
| 490 | boxedMemref = reboxOp.getBox(); |
| 491 | boxedShape = reboxOp.getShape(); |
| 492 | // Avoid pulling in rebox that performs reshaping. |
| 493 | // There is no way to represent box reshaping with array_coor. |
| 494 | if (boxedShape && !mlir::isa<fir::ShiftType>(boxedShape.getType())) |
| 495 | return mlir::failure(); |
| 496 | boxedSlice = reboxOp.getSlice(); |
| 497 | } else { |
| 498 | return mlir::failure(); |
| 499 | } |
| 500 | |
| 501 | bool boxedShapeIsShift = |
| 502 | boxedShape && mlir::isa<fir::ShiftType>(boxedShape.getType()); |
| 503 | bool boxedShapeIsShape = |
| 504 | boxedShape && mlir::isa<fir::ShapeType>(boxedShape.getType()); |
| 505 | bool boxedShapeIsShapeShift = |
| 506 | boxedShape && mlir::isa<fir::ShapeShiftType>(boxedShape.getType()); |
| 507 | |
| 508 | // Slices changing the number of dimensions are not supported |
| 509 | // for array_coor yet. |
| 510 | unsigned origBoxRank; |
| 511 | if (mlir::isa<fir::BaseBoxType>(boxedMemref.getType())) |
| 512 | origBoxRank = fir::getBoxRank(boxedMemref.getType()); |
| 513 | else if (auto arrTy = mlir::dyn_cast<fir::SequenceType>( |
| 514 | fir::unwrapRefType(boxedMemref.getType()))) |
| 515 | origBoxRank = arrTy.getDimension(); |
| 516 | else |
| 517 | return mlir::failure(); |
| 518 | |
| 519 | if (fir::getBoxRank(memref.getType()) != origBoxRank) |
| 520 | return mlir::failure(); |
| 521 | |
| 522 | // Slices with substring are not supported by array_coor. |
| 523 | if (boxedSlice) |
| 524 | if (auto sliceOp = |
| 525 | mlir::dyn_cast_or_null<fir::SliceOp>(boxedSlice.getDefiningOp())) |
| 526 | if (!sliceOp.getSubstr().empty()) |
| 527 | return mlir::failure(); |
| 528 | |
| 529 | // If embox/rebox and array_coor have conflicting shapes or slices, |
| 530 | // do nothing. |
| 531 | if (op.getShape() && boxedShape && boxedShape != op.getShape()) |
| 532 | return mlir::failure(); |
| 533 | if (op.getSlice() && boxedSlice && boxedSlice != op.getSlice()) |
| 534 | return mlir::failure(); |
| 535 | |
| 536 | std::optional<IndicesVectorTy> shiftedIndices; |
| 537 | // The embox/rebox and array_coor either have compatible |
| 538 | // shape/slice at this point or shape/slice is null |
| 539 | // in one of them but not in the other. |
| 540 | // The compatibility means they are equal or both null. |
| 541 | if (!op.getShape()) { |
| 542 | if (boxedShape) { |
| 543 | if (op.getSlice()) { |
| 544 | if (!boxedSlice) { |
| 545 | if (boxedShapeIsShift) { |
| 546 | // %0 = fir.rebox %arg(%shift) |
| 547 | // %1 = fir.array_coor %0 [%slice] %idx |
| 548 | // Both the slice indices and %idx are 1-based, so the rebox |
| 549 | // may be pulled in as: |
| 550 | // %1 = fir.array_coor %arg [%slice] %idx |
| 551 | boxedShape = nullptr; |
| 552 | } else if (boxedShapeIsShape) { |
| 553 | // %0 = fir.embox %arg(%shape) |
| 554 | // %1 = fir.array_coor %0 [%slice] %idx |
| 555 | // Pull in as: |
| 556 | // %1 = fir.array_coor %arg(%shape) [%slice] %idx |
| 557 | } else if (boxedShapeIsShapeShift) { |
| 558 | // %0 = fir.embox %arg(%shapeshift) |
| 559 | // %1 = fir.array_coor %0 [%slice] %idx |
| 560 | // Pull in as: |
| 561 | // %shape = fir.shape <extents from the %shapeshift> |
| 562 | // %1 = fir.array_coor %arg(%shape) [%slice] %idx |
| 563 | boxedShape = getShapeFromShapeShift(v: boxedShape, rewriter); |
| 564 | if (!boxedShape) |
| 565 | return mlir::failure(); |
| 566 | } else { |
| 567 | return mlir::failure(); |
| 568 | } |
| 569 | } else { |
| 570 | if (boxedShapeIsShift) { |
| 571 | // %0 = fir.rebox %arg(%shift) [%slice] |
| 572 | // %1 = fir.array_coor %0 [%slice] %idx |
| 573 | // This FIR may only be valid if the shape specifies |
| 574 | // that all lower bounds are 1s and the slice's start indices |
| 575 | // and strides are all 1s. |
| 576 | // We could pull in the rebox as: |
| 577 | // %1 = fir.array_coor %arg [%slice] %idx |
| 578 | // Do not do anything for the time being. |
| 579 | return mlir::failure(); |
| 580 | } else if (boxedShapeIsShape) { |
| 581 | // %0 = fir.embox %arg(%shape) [%slice] |
| 582 | // %1 = fir.array_coor %0 [%slice] %idx |
| 583 | // This FIR may only be valid if the slice's start indices |
| 584 | // and strides are all 1s. |
| 585 | // We could pull in the embox as: |
| 586 | // %1 = fir.array_coor %arg(%shape) [%slice] %idx |
| 587 | return mlir::failure(); |
| 588 | } else if (boxedShapeIsShapeShift) { |
| 589 | // %0 = fir.embox %arg(%shapeshift) [%slice] |
| 590 | // %1 = fir.array_coor %0 [%slice] %idx |
| 591 | // This FIR may only be valid if the shape specifies |
| 592 | // that all lower bounds are 1s and the slice's start indices |
| 593 | // and strides are all 1s. |
| 594 | // We could pull in the embox as: |
| 595 | // %shape = fir.shape <extents from the %shapeshift> |
| 596 | // %1 = fir.array_coor %arg(%shape) [%slice] %idx |
| 597 | return mlir::failure(); |
| 598 | } else { |
| 599 | return mlir::failure(); |
| 600 | } |
| 601 | } |
| 602 | } else { // !op.getSlice() |
| 603 | if (!boxedSlice) { |
| 604 | if (boxedShapeIsShift) { |
| 605 | // %0 = fir.rebox %arg(%shift) |
| 606 | // %1 = fir.array_coor %0 %idx |
| 607 | // Pull in as: |
| 608 | // %1 = fir.array_coor %arg %idx |
| 609 | boxedShape = nullptr; |
| 610 | } else if (boxedShapeIsShape) { |
| 611 | // %0 = fir.embox %arg(%shape) |
| 612 | // %1 = fir.array_coor %0 %idx |
| 613 | // Pull in as: |
| 614 | // %1 = fir.array_coor %arg(%shape) %idx |
| 615 | } else if (boxedShapeIsShapeShift) { |
| 616 | // %0 = fir.embox %arg(%shapeshift) |
| 617 | // %1 = fir.array_coor %0 %idx |
| 618 | // Pull in as: |
| 619 | // %shape = fir.shape <extents from the %shapeshift> |
| 620 | // %1 = fir.array_coor %arg(%shape) %idx |
| 621 | boxedShape = getShapeFromShapeShift(v: boxedShape, rewriter); |
| 622 | if (!boxedShape) |
| 623 | return mlir::failure(); |
| 624 | } else { |
| 625 | return mlir::failure(); |
| 626 | } |
| 627 | } else { |
| 628 | if (boxedShapeIsShift) { |
| 629 | // %0 = fir.embox %arg(%shift) [%slice] |
| 630 | // %1 = fir.array_coor %0 %idx |
| 631 | // Pull in as: |
| 632 | // %tmp = arith.addi %idx, %shift.origin |
| 633 | // %idx_shifted = arith.subi %tmp, 1 |
| 634 | // %1 = fir.array_coor %arg(%shift) %[slice] %idx_shifted |
| 635 | shiftedIndices = |
| 636 | getShiftedIndices(v: boxedShape, indices: op.getIndices(), rewriter); |
| 637 | if (!shiftedIndices) |
| 638 | return mlir::failure(); |
| 639 | } else if (boxedShapeIsShape) { |
| 640 | // %0 = fir.embox %arg(%shape) [%slice] |
| 641 | // %1 = fir.array_coor %0 %idx |
| 642 | // Pull in as: |
| 643 | // %1 = fir.array_coor %arg(%shape) %[slice] %idx |
| 644 | } else if (boxedShapeIsShapeShift) { |
| 645 | // %0 = fir.embox %arg(%shapeshift) [%slice] |
| 646 | // %1 = fir.array_coor %0 %idx |
| 647 | // Pull in as: |
| 648 | // %tmp = arith.addi %idx, %shapeshift.lb |
| 649 | // %idx_shifted = arith.subi %tmp, 1 |
| 650 | // %1 = fir.array_coor %arg(%shapeshift) %[slice] %idx_shifted |
| 651 | shiftedIndices = |
| 652 | getShiftedIndices(v: boxedShape, indices: op.getIndices(), rewriter); |
| 653 | if (!shiftedIndices) |
| 654 | return mlir::failure(); |
| 655 | } else { |
| 656 | return mlir::failure(); |
| 657 | } |
| 658 | } |
| 659 | } |
| 660 | } else { // !boxedShape |
| 661 | if (op.getSlice()) { |
| 662 | if (!boxedSlice) { |
| 663 | // %0 = fir.rebox %arg |
| 664 | // %1 = fir.array_coor %0 [%slice] %idx |
| 665 | // Pull in as: |
| 666 | // %1 = fir.array_coor %arg [%slice] %idx |
| 667 | } else { |
| 668 | // %0 = fir.rebox %arg [%slice] |
| 669 | // %1 = fir.array_coor %0 [%slice] %idx |
| 670 | // This is a valid FIR iff the slice's lower bounds |
| 671 | // and strides are all 1s. |
| 672 | // Pull in as: |
| 673 | // %1 = fir.array_coor %arg [%slice] %idx |
| 674 | } |
| 675 | } else { // !op.getSlice() |
| 676 | if (!boxedSlice) { |
| 677 | // %0 = fir.rebox %arg |
| 678 | // %1 = fir.array_coor %0 %idx |
| 679 | // Pull in as: |
| 680 | // %1 = fir.array_coor %arg %idx |
| 681 | } else { |
| 682 | // %0 = fir.rebox %arg [%slice] |
| 683 | // %1 = fir.array_coor %0 %idx |
| 684 | // Pull in as: |
| 685 | // %1 = fir.array_coor %arg [%slice] %idx |
| 686 | } |
| 687 | } |
| 688 | } |
| 689 | } else { // op.getShape() |
| 690 | if (boxedShape) { |
| 691 | // Check if pulling in non-default shape is correct. |
| 692 | if (op.getSlice()) { |
| 693 | if (!boxedSlice) { |
| 694 | // %0 = fir.embox %arg(%shape) |
| 695 | // %1 = fir.array_coor %0(%shape) [%slice] %idx |
| 696 | // Pull in as: |
| 697 | // %1 = fir.array_coor %arg(%shape) [%slice] %idx |
| 698 | } else { |
| 699 | // %0 = fir.embox %arg(%shape) [%slice] |
| 700 | // %1 = fir.array_coor %0(%shape) [%slice] %idx |
| 701 | // Pull in as: |
| 702 | // %1 = fir.array_coor %arg(%shape) [%slice] %idx |
| 703 | } |
| 704 | } else { // !op.getSlice() |
| 705 | if (!boxedSlice) { |
| 706 | // %0 = fir.embox %arg(%shape) |
| 707 | // %1 = fir.array_coor %0(%shape) %idx |
| 708 | // Pull in as: |
| 709 | // %1 = fir.array_coor %arg(%shape) %idx |
| 710 | } else { |
| 711 | // %0 = fir.embox %arg(%shape) [%slice] |
| 712 | // %1 = fir.array_coor %0(%shape) %idx |
| 713 | // Pull in as: |
| 714 | // %1 = fir.array_coor %arg(%shape) [%slice] %idx |
| 715 | } |
| 716 | } |
| 717 | } else { // !boxedShape |
| 718 | if (op.getSlice()) { |
| 719 | if (!boxedSlice) { |
| 720 | // %0 = fir.rebox %arg |
| 721 | // %1 = fir.array_coor %0(%shape) [%slice] %idx |
| 722 | // Pull in as: |
| 723 | // %1 = fir.array_coor %arg(%shape) [%slice] %idx |
| 724 | } else { |
| 725 | // %0 = fir.rebox %arg [%slice] |
| 726 | // %1 = fir.array_coor %0(%shape) [%slice] %idx |
| 727 | return mlir::failure(); |
| 728 | } |
| 729 | } else { // !op.getSlice() |
| 730 | if (!boxedSlice) { |
| 731 | // %0 = fir.rebox %arg |
| 732 | // %1 = fir.array_coor %0(%shape) %idx |
| 733 | // Pull in as: |
| 734 | // %1 = fir.array_coor %arg(%shape) %idx |
| 735 | } else { |
| 736 | // %0 = fir.rebox %arg [%slice] |
| 737 | // %1 = fir.array_coor %0(%shape) %idx |
| 738 | // Cannot pull in without adjusting the slice indices. |
| 739 | return mlir::failure(); |
| 740 | } |
| 741 | } |
| 742 | } |
| 743 | } |
| 744 | |
| 745 | // TODO: temporarily avoid producing array_coor with the shape shift |
| 746 | // and plain array reference (it seems to be a limitation of |
| 747 | // ArrayCoorOp verifier). |
| 748 | if (!mlir::isa<fir::BaseBoxType>(boxedMemref.getType())) { |
| 749 | if (boxedShape) { |
| 750 | if (mlir::isa<fir::ShiftType>(boxedShape.getType())) |
| 751 | return mlir::failure(); |
| 752 | } else if (op.getShape() && |
| 753 | mlir::isa<fir::ShiftType>(op.getShape().getType())) { |
| 754 | return mlir::failure(); |
| 755 | } |
| 756 | } |
| 757 | |
| 758 | rewriter.modifyOpInPlace(op, [&]() { |
| 759 | op.getMemrefMutable().assign(boxedMemref); |
| 760 | if (boxedShape) |
| 761 | op.getShapeMutable().assign(boxedShape); |
| 762 | if (boxedSlice) |
| 763 | op.getSliceMutable().assign(boxedSlice); |
| 764 | if (shiftedIndices) |
| 765 | op.getIndicesMutable().assign(*shiftedIndices); |
| 766 | }); |
| 767 | return mlir::success(); |
| 768 | } |
| 769 | |
| 770 | private: |
| 771 | using IndicesVectorTy = std::vector<mlir::Value>; |
| 772 | |
| 773 | // If v is a shape_shift operation: |
| 774 | // fir.shape_shift %l1, %e1, %l2, %e2, ... |
| 775 | // create: |
| 776 | // fir.shape %e1, %e2, ... |
| 777 | static mlir::Value getShapeFromShapeShift(mlir::Value v, |
| 778 | mlir::PatternRewriter &rewriter) { |
| 779 | auto shapeShiftOp = |
| 780 | mlir::dyn_cast_or_null<fir::ShapeShiftOp>(v.getDefiningOp()); |
| 781 | if (!shapeShiftOp) |
| 782 | return nullptr; |
| 783 | mlir::OpBuilder::InsertionGuard guard(rewriter); |
| 784 | rewriter.setInsertionPoint(shapeShiftOp); |
| 785 | return rewriter.create<fir::ShapeOp>(shapeShiftOp.getLoc(), |
| 786 | shapeShiftOp.getExtents()); |
| 787 | } |
| 788 | |
| 789 | static std::optional<IndicesVectorTy> |
| 790 | getShiftedIndices(mlir::Value v, mlir::ValueRange indices, |
| 791 | mlir::PatternRewriter &rewriter) { |
| 792 | auto insertAdjustments = [&](mlir::Operation *op, mlir::ValueRange lbs) { |
| 793 | // Compute the shifted indices using the extended type. |
| 794 | // Note that this can probably result in less efficient |
| 795 | // MLIR and further LLVM IR due to the extra conversions. |
| 796 | mlir::OpBuilder::InsertPoint savedIP = rewriter.saveInsertionPoint(); |
| 797 | rewriter.setInsertionPoint(op); |
| 798 | mlir::Location loc = op->getLoc(); |
| 799 | mlir::Type idxTy = rewriter.getIndexType(); |
| 800 | mlir::Value one = rewriter.create<mlir::arith::ConstantOp>( |
| 801 | loc, idxTy, rewriter.getIndexAttr(1)); |
| 802 | rewriter.restoreInsertionPoint(ip: savedIP); |
| 803 | auto nsw = mlir::arith::IntegerOverflowFlags::nsw; |
| 804 | |
| 805 | IndicesVectorTy shiftedIndices; |
| 806 | for (auto [lb, idx] : llvm::zip(t&: lbs, u&: indices)) { |
| 807 | mlir::Value extLb = rewriter.create<fir::ConvertOp>(loc, idxTy, lb); |
| 808 | mlir::Value extIdx = rewriter.create<fir::ConvertOp>(loc, idxTy, idx); |
| 809 | mlir::Value add = |
| 810 | rewriter.create<mlir::arith::AddIOp>(loc, extIdx, extLb, nsw); |
| 811 | mlir::Value sub = |
| 812 | rewriter.create<mlir::arith::SubIOp>(loc, add, one, nsw); |
| 813 | shiftedIndices.push_back(x: sub); |
| 814 | } |
| 815 | |
| 816 | return shiftedIndices; |
| 817 | }; |
| 818 | |
| 819 | if (auto shiftOp = |
| 820 | mlir::dyn_cast_or_null<fir::ShiftOp>(v.getDefiningOp())) { |
| 821 | return insertAdjustments(shiftOp.getOperation(), shiftOp.getOrigins()); |
| 822 | } else if (auto shapeShiftOp = mlir::dyn_cast_or_null<fir::ShapeShiftOp>( |
| 823 | v.getDefiningOp())) { |
| 824 | return insertAdjustments(shapeShiftOp.getOperation(), |
| 825 | shapeShiftOp.getOrigins()); |
| 826 | } |
| 827 | |
| 828 | return std::nullopt; |
| 829 | } |
| 830 | }; |
| 831 | |
| 832 | void fir::ArrayCoorOp::getCanonicalizationPatterns( |
| 833 | mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) { |
| 834 | // TODO: !fir.shape<1> operand may be removed from array_coor always. |
| 835 | patterns.add<SimplifyArrayCoorOp>(context); |
| 836 | } |
| 837 | |
| 838 | //===----------------------------------------------------------------------===// |
| 839 | // ArrayLoadOp |
| 840 | //===----------------------------------------------------------------------===// |
| 841 | |
| 842 | static mlir::Type adjustedElementType(mlir::Type t) { |
| 843 | if (auto ty = mlir::dyn_cast<fir::ReferenceType>(t)) { |
| 844 | auto eleTy = ty.getEleTy(); |
| 845 | if (fir::isa_char(eleTy)) |
| 846 | return eleTy; |
| 847 | if (fir::isa_derived(eleTy)) |
| 848 | return eleTy; |
| 849 | if (mlir::isa<fir::SequenceType>(eleTy)) |
| 850 | return eleTy; |
| 851 | } |
| 852 | return t; |
| 853 | } |
| 854 | |
| 855 | std::vector<mlir::Value> fir::ArrayLoadOp::getExtents() { |
| 856 | if (auto sh = getShape()) |
| 857 | if (auto *op = sh.getDefiningOp()) { |
| 858 | if (auto shOp = mlir::dyn_cast<fir::ShapeOp>(op)) { |
| 859 | auto extents = shOp.getExtents(); |
| 860 | return {extents.begin(), extents.end()}; |
| 861 | } |
| 862 | return mlir::cast<fir::ShapeShiftOp>(op).getExtents(); |
| 863 | } |
| 864 | return {}; |
| 865 | } |
| 866 | |
| 867 | void fir::ArrayLoadOp::getEffects( |
| 868 | llvm::SmallVectorImpl< |
| 869 | mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>> |
| 870 | &effects) { |
| 871 | effects.emplace_back(mlir::MemoryEffects::Read::get(), &getMemrefMutable(), |
| 872 | mlir::SideEffects::DefaultResource::get()); |
| 873 | addVolatileMemoryEffects({getMemref().getType()}, effects); |
| 874 | } |
| 875 | |
| 876 | llvm::LogicalResult fir::ArrayLoadOp::verify() { |
| 877 | auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(getMemref().getType()); |
| 878 | auto arrTy = mlir::dyn_cast<fir::SequenceType>(eleTy); |
| 879 | if (!arrTy) |
| 880 | return emitOpError("must be a reference to an array" ); |
| 881 | auto arrDim = arrTy.getDimension(); |
| 882 | |
| 883 | if (auto shapeOp = getShape()) { |
| 884 | auto shapeTy = shapeOp.getType(); |
| 885 | unsigned shapeTyRank = 0u; |
| 886 | if (auto s = mlir::dyn_cast<fir::ShapeType>(shapeTy)) { |
| 887 | shapeTyRank = s.getRank(); |
| 888 | } else if (auto ss = mlir::dyn_cast<fir::ShapeShiftType>(shapeTy)) { |
| 889 | shapeTyRank = ss.getRank(); |
| 890 | } else { |
| 891 | auto s = mlir::cast<fir::ShiftType>(shapeTy); |
| 892 | shapeTyRank = s.getRank(); |
| 893 | if (!mlir::isa<fir::BaseBoxType>(getMemref().getType())) |
| 894 | return emitOpError("shift can only be provided with fir.box memref" ); |
| 895 | } |
| 896 | if (arrDim && arrDim != shapeTyRank) |
| 897 | return emitOpError("rank of dimension mismatched" ); |
| 898 | } |
| 899 | |
| 900 | if (auto sliceOp = getSlice()) { |
| 901 | if (auto sl = mlir::dyn_cast_or_null<fir::SliceOp>(sliceOp.getDefiningOp())) |
| 902 | if (!sl.getSubstr().empty()) |
| 903 | return emitOpError("array_load cannot take a slice with substring" ); |
| 904 | if (auto sliceTy = mlir::dyn_cast<fir::SliceType>(sliceOp.getType())) |
| 905 | if (sliceTy.getRank() != arrDim) |
| 906 | return emitOpError("rank of dimension in slice mismatched" ); |
| 907 | } |
| 908 | |
| 909 | if (!validTypeParams(getMemref().getType(), getTypeparams())) |
| 910 | return emitOpError("invalid type parameters" ); |
| 911 | |
| 912 | return mlir::success(); |
| 913 | } |
| 914 | |
| 915 | //===----------------------------------------------------------------------===// |
| 916 | // ArrayMergeStoreOp |
| 917 | //===----------------------------------------------------------------------===// |
| 918 | |
| 919 | llvm::LogicalResult fir::ArrayMergeStoreOp::verify() { |
| 920 | if (!mlir::isa<fir::ArrayLoadOp>(getOriginal().getDefiningOp())) |
| 921 | return emitOpError("operand #0 must be result of a fir.array_load op" ); |
| 922 | if (auto sl = getSlice()) { |
| 923 | if (auto sliceOp = |
| 924 | mlir::dyn_cast_or_null<fir::SliceOp>(sl.getDefiningOp())) { |
| 925 | if (!sliceOp.getSubstr().empty()) |
| 926 | return emitOpError( |
| 927 | "array_merge_store cannot take a slice with substring" ); |
| 928 | if (!sliceOp.getFields().empty()) { |
| 929 | // This is an intra-object merge, where the slice is projecting the |
| 930 | // subfields that are to be overwritten by the merge operation. |
| 931 | auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(getMemref().getType()); |
| 932 | if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(eleTy)) { |
| 933 | auto projTy = |
| 934 | fir::applyPathToType(seqTy.getEleTy(), sliceOp.getFields()); |
| 935 | if (fir::unwrapSequenceType(getOriginal().getType()) != projTy) |
| 936 | return emitOpError( |
| 937 | "type of origin does not match sliced memref type" ); |
| 938 | if (fir::unwrapSequenceType(getSequence().getType()) != projTy) |
| 939 | return emitOpError( |
| 940 | "type of sequence does not match sliced memref type" ); |
| 941 | return mlir::success(); |
| 942 | } |
| 943 | return emitOpError("referenced type is not an array" ); |
| 944 | } |
| 945 | } |
| 946 | return mlir::success(); |
| 947 | } |
| 948 | auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(getMemref().getType()); |
| 949 | if (getOriginal().getType() != eleTy) |
| 950 | return emitOpError("type of origin does not match memref element type" ); |
| 951 | if (getSequence().getType() != eleTy) |
| 952 | return emitOpError("type of sequence does not match memref element type" ); |
| 953 | if (!validTypeParams(getMemref().getType(), getTypeparams())) |
| 954 | return emitOpError("invalid type parameters" ); |
| 955 | return mlir::success(); |
| 956 | } |
| 957 | |
| 958 | void fir::ArrayMergeStoreOp::getEffects( |
| 959 | llvm::SmallVectorImpl< |
| 960 | mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>> |
| 961 | &effects) { |
| 962 | effects.emplace_back(mlir::MemoryEffects::Write::get(), &getMemrefMutable(), |
| 963 | mlir::SideEffects::DefaultResource::get()); |
| 964 | addVolatileMemoryEffects({getMemref().getType()}, effects); |
| 965 | } |
| 966 | |
| 967 | //===----------------------------------------------------------------------===// |
| 968 | // ArrayFetchOp |
| 969 | //===----------------------------------------------------------------------===// |
| 970 | |
| 971 | // Template function used for both array_fetch and array_update verification. |
| 972 | template <typename A> |
| 973 | mlir::Type validArraySubobject(A op) { |
| 974 | auto ty = op.getSequence().getType(); |
| 975 | return fir::applyPathToType(ty, op.getIndices()); |
| 976 | } |
| 977 | |
| 978 | llvm::LogicalResult fir::ArrayFetchOp::verify() { |
| 979 | auto arrTy = mlir::cast<fir::SequenceType>(getSequence().getType()); |
| 980 | auto indSize = getIndices().size(); |
| 981 | if (indSize < arrTy.getDimension()) |
| 982 | return emitOpError("number of indices != dimension of array" ); |
| 983 | if (indSize == arrTy.getDimension() && |
| 984 | ::adjustedElementType(getElement().getType()) != arrTy.getEleTy()) |
| 985 | return emitOpError("return type does not match array" ); |
| 986 | auto ty = validArraySubobject(*this); |
| 987 | if (!ty || ty != ::adjustedElementType(getType())) |
| 988 | return emitOpError("return type and/or indices do not type check" ); |
| 989 | if (!mlir::isa<fir::ArrayLoadOp>(getSequence().getDefiningOp())) |
| 990 | return emitOpError("argument #0 must be result of fir.array_load" ); |
| 991 | if (!validTypeParams(arrTy, getTypeparams())) |
| 992 | return emitOpError("invalid type parameters" ); |
| 993 | return mlir::success(); |
| 994 | } |
| 995 | |
| 996 | //===----------------------------------------------------------------------===// |
| 997 | // ArrayAccessOp |
| 998 | //===----------------------------------------------------------------------===// |
| 999 | |
| 1000 | llvm::LogicalResult fir::ArrayAccessOp::verify() { |
| 1001 | auto arrTy = mlir::cast<fir::SequenceType>(getSequence().getType()); |
| 1002 | std::size_t indSize = getIndices().size(); |
| 1003 | if (indSize < arrTy.getDimension()) |
| 1004 | return emitOpError("number of indices != dimension of array" ); |
| 1005 | if (indSize == arrTy.getDimension() && |
| 1006 | getElement().getType() != fir::ReferenceType::get(arrTy.getEleTy())) |
| 1007 | return emitOpError("return type does not match array" ); |
| 1008 | mlir::Type ty = validArraySubobject(*this); |
| 1009 | if (!ty || fir::ReferenceType::get(ty) != getType()) |
| 1010 | return emitOpError("return type and/or indices do not type check" ); |
| 1011 | if (!validTypeParams(arrTy, getTypeparams())) |
| 1012 | return emitOpError("invalid type parameters" ); |
| 1013 | return mlir::success(); |
| 1014 | } |
| 1015 | |
| 1016 | //===----------------------------------------------------------------------===// |
| 1017 | // ArrayUpdateOp |
| 1018 | //===----------------------------------------------------------------------===// |
| 1019 | |
| 1020 | llvm::LogicalResult fir::ArrayUpdateOp::verify() { |
| 1021 | if (fir::isa_ref_type(getMerge().getType())) |
| 1022 | return emitOpError("does not support reference type for merge" ); |
| 1023 | auto arrTy = mlir::cast<fir::SequenceType>(getSequence().getType()); |
| 1024 | auto indSize = getIndices().size(); |
| 1025 | if (indSize < arrTy.getDimension()) |
| 1026 | return emitOpError("number of indices != dimension of array" ); |
| 1027 | if (indSize == arrTy.getDimension() && |
| 1028 | ::adjustedElementType(getMerge().getType()) != arrTy.getEleTy()) |
| 1029 | return emitOpError("merged value does not have element type" ); |
| 1030 | auto ty = validArraySubobject(*this); |
| 1031 | if (!ty || ty != ::adjustedElementType(getMerge().getType())) |
| 1032 | return emitOpError("merged value and/or indices do not type check" ); |
| 1033 | if (!validTypeParams(arrTy, getTypeparams())) |
| 1034 | return emitOpError("invalid type parameters" ); |
| 1035 | return mlir::success(); |
| 1036 | } |
| 1037 | |
| 1038 | //===----------------------------------------------------------------------===// |
| 1039 | // ArrayModifyOp |
| 1040 | //===----------------------------------------------------------------------===// |
| 1041 | |
| 1042 | llvm::LogicalResult fir::ArrayModifyOp::verify() { |
| 1043 | auto arrTy = mlir::cast<fir::SequenceType>(getSequence().getType()); |
| 1044 | auto indSize = getIndices().size(); |
| 1045 | if (indSize < arrTy.getDimension()) |
| 1046 | return emitOpError("number of indices must match array dimension" ); |
| 1047 | return mlir::success(); |
| 1048 | } |
| 1049 | |
| 1050 | //===----------------------------------------------------------------------===// |
| 1051 | // BoxAddrOp |
| 1052 | //===----------------------------------------------------------------------===// |
| 1053 | |
| 1054 | void fir::BoxAddrOp::build(mlir::OpBuilder &builder, |
| 1055 | mlir::OperationState &result, mlir::Value val) { |
| 1056 | mlir::Type type = |
| 1057 | llvm::TypeSwitch<mlir::Type, mlir::Type>(val.getType()) |
| 1058 | .Case<fir::BaseBoxType>([&](fir::BaseBoxType ty) -> mlir::Type { |
| 1059 | mlir::Type eleTy = ty.getEleTy(); |
| 1060 | if (fir::isa_ref_type(eleTy)) |
| 1061 | return eleTy; |
| 1062 | return fir::ReferenceType::get(eleTy); |
| 1063 | }) |
| 1064 | .Case<fir::BoxCharType>([&](fir::BoxCharType ty) -> mlir::Type { |
| 1065 | return fir::ReferenceType::get(ty.getEleTy()); |
| 1066 | }) |
| 1067 | .Case<fir::BoxProcType>( |
| 1068 | [&](fir::BoxProcType ty) { return ty.getEleTy(); }) |
| 1069 | .Default([&](const auto &) { return mlir::Type{}; }); |
| 1070 | assert(type && "bad val type" ); |
| 1071 | build(builder, result, type, val); |
| 1072 | } |
| 1073 | |
| 1074 | mlir::OpFoldResult fir::BoxAddrOp::fold(FoldAdaptor adaptor) { |
| 1075 | if (auto *v = getVal().getDefiningOp()) { |
| 1076 | if (auto box = mlir::dyn_cast<fir::EmboxOp>(v)) { |
| 1077 | // Fold only if not sliced |
| 1078 | if (!box.getSlice() && box.getMemref().getType() == getType()) { |
| 1079 | propagateAttributes(getOperation(), box.getMemref().getDefiningOp()); |
| 1080 | return box.getMemref(); |
| 1081 | } |
| 1082 | } |
| 1083 | if (auto box = mlir::dyn_cast<fir::EmboxCharOp>(v)) |
| 1084 | if (box.getMemref().getType() == getType()) |
| 1085 | return box.getMemref(); |
| 1086 | } |
| 1087 | return {}; |
| 1088 | } |
| 1089 | |
| 1090 | //===----------------------------------------------------------------------===// |
| 1091 | // BoxCharLenOp |
| 1092 | //===----------------------------------------------------------------------===// |
| 1093 | |
| 1094 | mlir::OpFoldResult fir::BoxCharLenOp::fold(FoldAdaptor adaptor) { |
| 1095 | if (auto v = getVal().getDefiningOp()) { |
| 1096 | if (auto box = mlir::dyn_cast<fir::EmboxCharOp>(v)) |
| 1097 | return box.getLen(); |
| 1098 | } |
| 1099 | return {}; |
| 1100 | } |
| 1101 | |
| 1102 | //===----------------------------------------------------------------------===// |
| 1103 | // BoxDimsOp |
| 1104 | //===----------------------------------------------------------------------===// |
| 1105 | |
| 1106 | /// Get the result types packed in a tuple tuple |
| 1107 | mlir::Type fir::BoxDimsOp::getTupleType() { |
| 1108 | // note: triple, but 4 is nearest power of 2 |
| 1109 | llvm::SmallVector<mlir::Type> triple{ |
| 1110 | getResult(0).getType(), getResult(1).getType(), getResult(2).getType()}; |
| 1111 | return mlir::TupleType::get(getContext(), triple); |
| 1112 | } |
| 1113 | |
| 1114 | //===----------------------------------------------------------------------===// |
| 1115 | // BoxRankOp |
| 1116 | //===----------------------------------------------------------------------===// |
| 1117 | |
| 1118 | void fir::BoxRankOp::getEffects( |
| 1119 | llvm::SmallVectorImpl< |
| 1120 | mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>> |
| 1121 | &effects) { |
| 1122 | mlir::OpOperand &inputBox = getBoxMutable(); |
| 1123 | if (fir::isBoxAddress(inputBox.get().getType())) |
| 1124 | effects.emplace_back(mlir::MemoryEffects::Read::get(), &inputBox, |
| 1125 | mlir::SideEffects::DefaultResource::get()); |
| 1126 | } |
| 1127 | |
| 1128 | //===----------------------------------------------------------------------===// |
| 1129 | // CallOp |
| 1130 | //===----------------------------------------------------------------------===// |
| 1131 | |
| 1132 | mlir::FunctionType fir::CallOp::getFunctionType() { |
| 1133 | return mlir::FunctionType::get(getContext(), getOperandTypes(), |
| 1134 | getResultTypes()); |
| 1135 | } |
| 1136 | |
/// Print a fir.call in its custom assembly form:
///   (@callee | %callee_value) `(` args `)` [proc_attrs<...>]
///   [fastmath<...>] [attr-dict] `:` function-signature
void fir::CallOp::print(mlir::OpAsmPrinter &p) {
  // A direct call carries its callee as a symbol attribute. An indirect call
  // passes the callee value as operand #0.
  bool isDirect = getCallee().has_value();
  p << ' ';
  if (isDirect)
    p << *getCallee();
  else
    p << getOperand(0);
  // For indirect calls, skip operand #0 (already printed as the callee).
  p << '(' << (*this)->getOperands().drop_front(isDirect ? 0 : 1) << ')';

  // Print `proc_attrs<...>`, if present.
  fir::FortranProcedureFlagsEnumAttr procAttrs = getProcedureAttrsAttr();
  if (procAttrs &&
      procAttrs.getValue() != fir::FortranProcedureFlagsEnum::none) {
    p << ' ' << fir::FortranProcedureFlagsEnumAttr::getMnemonic();
    p.printStrippedAttrOrType(procAttrs);
  }

  // Print 'fastmath<...>' (if it has non-default value) before
  // any other attributes.
  mlir::arith::FastMathFlagsAttr fmfAttr = getFastmathAttr();
  if (fmfAttr.getValue() != mlir::arith::FastMathFlags::none) {
    p << ' ' << mlir::arith::FastMathFlagsAttr::getMnemonic();
    p.printStrippedAttrOrType(fmfAttr);
  }

  // Attributes printed in custom form above, or folded into the signature
  // below, are elided from the generic attribute dictionary.
  p.printOptionalAttrDict((*this)->getAttrs(),
                          {fir::CallOp::getCalleeAttrNameStr(),
                           getFastmathAttrName(), getProcedureAttrsAttrName(),
                           getArgAttrsAttrName(), getResAttrsAttrName()});
  p << " : ";
  // The signature lists the argument types (callee value excluded for
  // indirect calls) with their arg/result attributes.
  mlir::call_interface_impl::printFunctionSignature(
      p, getArgs().drop_front(isDirect ? 0 : 1).getTypes(), getArgAttrsAttr(),
      /*isVariadic=*/false, getResultTypes(), getResAttrsAttr());
}
| 1171 | |
/// Parse the custom assembly form emitted by fir::CallOp::print:
///   (@callee | %callee_value) `(` args `)` [proc_attrs<...>]
///   [fastmath<...>] [attr-dict] `:` function-signature
mlir::ParseResult fir::CallOp::parse(mlir::OpAsmParser &parser,
                                     mlir::OperationState &result) {
  // An indirect call begins with the callee value. If no operand is parsed
  // here, the call is direct and a callee symbol must follow.
  llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands;
  if (parser.parseOperandList(operands))
    return mlir::failure();

  mlir::NamedAttrList attrs;
  mlir::SymbolRefAttr funcAttr;
  bool isDirect = operands.empty();
  if (isDirect)
    if (parser.parseAttribute(funcAttr, fir::CallOp::getCalleeAttrNameStr(),
                              attrs))
      return mlir::failure();

  // Parenthesized call arguments go into the same list, so for an indirect
  // call the callee value stays at index 0 (relied on below).
  if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::Paren))
    return mlir::failure();

  // Parse `proc_attrs<...>`, if present.
  fir::FortranProcedureFlagsEnumAttr procAttr;
  if (mlir::succeeded(parser.parseOptionalKeyword(
          fir::FortranProcedureFlagsEnumAttr::getMnemonic())))
    if (parser.parseCustomAttributeWithFallback(
            procAttr, mlir::Type{}, getProcedureAttrsAttrName(result.name),
            attrs))
      return mlir::failure();

  // Parse 'fastmath<...>', if present.
  mlir::arith::FastMathFlagsAttr fmfAttr;
  llvm::StringRef fmfAttrName = getFastmathAttrName(result.name);
  if (mlir::succeeded(parser.parseOptionalKeyword(fmfAttrName)))
    if (parser.parseCustomAttributeWithFallback(fmfAttr, mlir::Type{},
                                                fmfAttrName, attrs))
      return mlir::failure();

  if (parser.parseOptionalAttrDict(attrs) || parser.parseColon())
    return mlir::failure();
  // Parse the signature (argument/result types plus their attributes).
  llvm::SmallVector<mlir::Type> argTypes;
  llvm::SmallVector<mlir::Type> resTypes;
  llvm::SmallVector<mlir::DictionaryAttr> argAttrs;
  llvm::SmallVector<mlir::DictionaryAttr> resultAttrs;
  if (mlir::call_interface_impl::parseFunctionSignature(
          parser, argTypes, argAttrs, resTypes, resultAttrs))
    return parser.emitError(parser.getNameLoc(), "expected function type");
  mlir::FunctionType funcType =
      mlir::FunctionType::get(parser.getContext(), argTypes, resTypes);
  if (isDirect) {
    // Every parsed operand is a call argument.
    if (parser.resolveOperands(operands, funcType.getInputs(),
                               parser.getNameLoc(), result.operands))
      return mlir::failure();
  } else {
    // Operand #0 (the callee value) is typed with the whole function type.
    // The rest are the call arguments.
    auto funcArgs =
        llvm::ArrayRef<mlir::OpAsmParser::UnresolvedOperand>(operands)
            .drop_front();
    if (parser.resolveOperand(operands[0], funcType, result.operands) ||
        parser.resolveOperands(funcArgs, funcType.getInputs(),
                               parser.getNameLoc(), result.operands))
      return mlir::failure();
  }
  result.attributes = attrs;
  mlir::call_interface_impl::addArgAndResultAttrs(
      parser.getBuilder(), result, argAttrs, resultAttrs,
      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
  result.addTypes(funcType.getResults());
  return mlir::success();
}
| 1237 | |
| 1238 | void fir::CallOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, |
| 1239 | mlir::func::FuncOp callee, mlir::ValueRange operands) { |
| 1240 | result.addOperands(operands); |
| 1241 | result.addAttribute(getCalleeAttrNameStr(), mlir::SymbolRefAttr::get(callee)); |
| 1242 | result.addTypes(callee.getFunctionType().getResults()); |
| 1243 | } |
| 1244 | |
| 1245 | void fir::CallOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, |
| 1246 | mlir::SymbolRefAttr callee, |
| 1247 | llvm::ArrayRef<mlir::Type> results, |
| 1248 | mlir::ValueRange operands) { |
| 1249 | result.addOperands(operands); |
| 1250 | if (callee) |
| 1251 | result.addAttribute(getCalleeAttrNameStr(), callee); |
| 1252 | result.addTypes(results); |
| 1253 | } |
| 1254 | |
| 1255 | //===----------------------------------------------------------------------===// |
| 1256 | // CharConvertOp |
| 1257 | //===----------------------------------------------------------------------===// |
| 1258 | |
| 1259 | llvm::LogicalResult fir::CharConvertOp::verify() { |
| 1260 | auto unwrap = [&](mlir::Type t) { |
| 1261 | t = fir::unwrapSequenceType(fir::dyn_cast_ptrEleTy(t)); |
| 1262 | return mlir::dyn_cast<fir::CharacterType>(t); |
| 1263 | }; |
| 1264 | auto inTy = unwrap(getFrom().getType()); |
| 1265 | auto outTy = unwrap(getTo().getType()); |
| 1266 | if (!(inTy && outTy)) |
| 1267 | return emitOpError("not a reference to a character" ); |
| 1268 | if (inTy.getFKind() == outTy.getFKind()) |
| 1269 | return emitOpError("buffers must have different KIND values" ); |
| 1270 | return mlir::success(); |
| 1271 | } |
| 1272 | |
| 1273 | //===----------------------------------------------------------------------===// |
| 1274 | // CmpOp |
| 1275 | //===----------------------------------------------------------------------===// |
| 1276 | |
| 1277 | template <typename OPTY> |
| 1278 | static void printCmpOp(mlir::OpAsmPrinter &p, OPTY op) { |
| 1279 | p << ' '; |
| 1280 | auto predSym = mlir::arith::symbolizeCmpFPredicate( |
| 1281 | op->template getAttrOfType<mlir::IntegerAttr>( |
| 1282 | OPTY::getPredicateAttrName()) |
| 1283 | .getInt()); |
| 1284 | assert(predSym.has_value() && "invalid symbol value for predicate" ); |
| 1285 | p << '"' << mlir::arith::stringifyCmpFPredicate(predSym.value()) << '"' |
| 1286 | << ", " ; |
| 1287 | p.printOperand(op.getLhs()); |
| 1288 | p << ", " ; |
| 1289 | p.printOperand(op.getRhs()); |
| 1290 | p.printOptionalAttrDict(attrs: op->getAttrs(), |
| 1291 | /*elidedAttrs=*/{OPTY::getPredicateAttrName()}); |
| 1292 | p << " : " << op.getLhs().getType(); |
| 1293 | } |
| 1294 | |
| 1295 | template <typename OPTY> |
| 1296 | static mlir::ParseResult parseCmpOp(mlir::OpAsmParser &parser, |
| 1297 | mlir::OperationState &result) { |
| 1298 | llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> ops; |
| 1299 | mlir::NamedAttrList attrs; |
| 1300 | mlir::Attribute predicateNameAttr; |
| 1301 | mlir::Type type; |
| 1302 | if (parser.parseAttribute(predicateNameAttr, OPTY::getPredicateAttrName(), |
| 1303 | attrs) || |
| 1304 | parser.parseComma() || parser.parseOperandList(result&: ops, requiredOperandCount: 2) || |
| 1305 | parser.parseOptionalAttrDict(result&: attrs) || parser.parseColonType(result&: type) || |
| 1306 | parser.resolveOperands(operands&: ops, type, result&: result.operands)) |
| 1307 | return mlir::failure(); |
| 1308 | |
| 1309 | if (!mlir::isa<mlir::StringAttr>(Val: predicateNameAttr)) |
| 1310 | return parser.emitError(loc: parser.getNameLoc(), |
| 1311 | message: "expected string comparison predicate attribute" ); |
| 1312 | |
| 1313 | // Rewrite string attribute to an enum value. |
| 1314 | llvm::StringRef predicateName = |
| 1315 | mlir::cast<mlir::StringAttr>(predicateNameAttr).getValue(); |
| 1316 | auto predicate = fir::CmpcOp::getPredicateByName(predicateName); |
| 1317 | auto builder = parser.getBuilder(); |
| 1318 | mlir::Type i1Type = builder.getI1Type(); |
| 1319 | attrs.set(OPTY::getPredicateAttrName(), |
| 1320 | builder.getI64IntegerAttr(static_cast<std::int64_t>(predicate))); |
| 1321 | result.attributes = attrs; |
| 1322 | result.addTypes(newTypes: {i1Type}); |
| 1323 | return mlir::success(); |
| 1324 | } |
| 1325 | |
| 1326 | //===----------------------------------------------------------------------===// |
| 1327 | // CmpcOp |
| 1328 | //===----------------------------------------------------------------------===// |
| 1329 | |
| 1330 | void fir::buildCmpCOp(mlir::OpBuilder &builder, mlir::OperationState &result, |
| 1331 | mlir::arith::CmpFPredicate predicate, mlir::Value lhs, |
| 1332 | mlir::Value rhs) { |
| 1333 | result.addOperands({lhs, rhs}); |
| 1334 | result.types.push_back(builder.getI1Type()); |
| 1335 | result.addAttribute( |
| 1336 | fir::CmpcOp::getPredicateAttrName(), |
| 1337 | builder.getI64IntegerAttr(static_cast<std::int64_t>(predicate))); |
| 1338 | } |
| 1339 | |
| 1340 | mlir::arith::CmpFPredicate |
| 1341 | fir::CmpcOp::getPredicateByName(llvm::StringRef name) { |
| 1342 | auto pred = mlir::arith::symbolizeCmpFPredicate(name); |
| 1343 | assert(pred.has_value() && "invalid predicate name" ); |
| 1344 | return pred.value(); |
| 1345 | } |
| 1346 | |
| 1347 | void fir::CmpcOp::print(mlir::OpAsmPrinter &p) { printCmpOp(p, *this); } |
| 1348 | |
| 1349 | mlir::ParseResult fir::CmpcOp::parse(mlir::OpAsmParser &parser, |
| 1350 | mlir::OperationState &result) { |
| 1351 | return parseCmpOp<fir::CmpcOp>(parser, result); |
| 1352 | } |
| 1353 | |
| 1354 | //===----------------------------------------------------------------------===// |
| 1355 | // VolatileCastOp |
| 1356 | //===----------------------------------------------------------------------===// |
| 1357 | |
| 1358 | static bool typesMatchExceptForVolatility(mlir::Type fromType, |
| 1359 | mlir::Type toType) { |
| 1360 | // If we can change only the volatility and get identical types, then we |
| 1361 | // match. |
| 1362 | if (fir::updateTypeWithVolatility(fromType, fir::isa_volatile_type(toType)) == |
| 1363 | toType) |
| 1364 | return true; |
| 1365 | |
| 1366 | // Otherwise, recurse on the element types if the base classes are the same. |
| 1367 | const bool match = |
| 1368 | llvm::TypeSwitch<mlir::Type, bool>(fromType) |
| 1369 | .Case<fir::BoxType, fir::ReferenceType, fir::ClassType>( |
| 1370 | [&](auto type) { |
| 1371 | using TYPE = decltype(type); |
| 1372 | // If we are not the same base class, then we don't match. |
| 1373 | auto castedToType = mlir::dyn_cast<TYPE>(toType); |
| 1374 | if (!castedToType) |
| 1375 | return false; |
| 1376 | // If we are the same base class, we match if the element types |
| 1377 | // match. |
| 1378 | return typesMatchExceptForVolatility(type.getEleTy(), |
| 1379 | castedToType.getEleTy()); |
| 1380 | }) |
| 1381 | .Default([](mlir::Type) { return false; }); |
| 1382 | |
| 1383 | return match; |
| 1384 | } |
| 1385 | |
| 1386 | llvm::LogicalResult fir::VolatileCastOp::verify() { |
| 1387 | mlir::Type fromType = getValue().getType(); |
| 1388 | mlir::Type toType = getType(); |
| 1389 | if (!typesMatchExceptForVolatility(fromType, toType)) |
| 1390 | return emitOpError("types must be identical except for volatility " ) |
| 1391 | << fromType << " / " << toType; |
| 1392 | return mlir::success(); |
| 1393 | } |
| 1394 | |
| 1395 | mlir::OpFoldResult fir::VolatileCastOp::fold(FoldAdaptor adaptor) { |
| 1396 | if (getValue().getType() == getType()) |
| 1397 | return getValue(); |
| 1398 | return {}; |
| 1399 | } |
| 1400 | |
| 1401 | //===----------------------------------------------------------------------===// |
| 1402 | // ConvertOp |
| 1403 | //===----------------------------------------------------------------------===// |
| 1404 | |
| 1405 | void fir::ConvertOp::getCanonicalizationPatterns( |
| 1406 | mlir::RewritePatternSet &results, mlir::MLIRContext *context) { |
| 1407 | results.insert<ConvertConvertOptPattern, ConvertAscendingIndexOptPattern, |
| 1408 | ConvertDescendingIndexOptPattern, RedundantConvertOptPattern, |
| 1409 | CombineConvertOptPattern, CombineConvertTruncOptPattern, |
| 1410 | ForwardConstantConvertPattern, ChainedPointerConvertsPattern>( |
| 1411 | context); |
| 1412 | } |
| 1413 | |
| 1414 | mlir::OpFoldResult fir::ConvertOp::fold(FoldAdaptor adaptor) { |
| 1415 | if (getValue().getType() == getType()) |
| 1416 | return getValue(); |
| 1417 | if (matchPattern(getValue(), mlir::m_Op<fir::ConvertOp>())) { |
| 1418 | auto inner = mlir::cast<fir::ConvertOp>(getValue().getDefiningOp()); |
| 1419 | // (convert (convert 'a : logical -> i1) : i1 -> logical) ==> forward 'a |
| 1420 | if (auto toTy = mlir::dyn_cast<fir::LogicalType>(getType())) |
| 1421 | if (auto fromTy = |
| 1422 | mlir::dyn_cast<fir::LogicalType>(inner.getValue().getType())) |
| 1423 | if (mlir::isa<mlir::IntegerType>(inner.getType()) && (toTy == fromTy)) |
| 1424 | return inner.getValue(); |
| 1425 | // (convert (convert 'a : i1 -> logical) : logical -> i1) ==> forward 'a |
| 1426 | if (auto toTy = mlir::dyn_cast<mlir::IntegerType>(getType())) |
| 1427 | if (auto fromTy = |
| 1428 | mlir::dyn_cast<mlir::IntegerType>(inner.getValue().getType())) |
| 1429 | if (mlir::isa<fir::LogicalType>(inner.getType()) && (toTy == fromTy) && |
| 1430 | (fromTy.getWidth() == 1)) |
| 1431 | return inner.getValue(); |
| 1432 | } |
| 1433 | return {}; |
| 1434 | } |
| 1435 | |
| 1436 | bool fir::ConvertOp::isInteger(mlir::Type ty) { |
| 1437 | return mlir::isa<mlir::IntegerType, mlir::IndexType, fir::IntegerType>(ty); |
| 1438 | } |
| 1439 | |
| 1440 | bool fir::ConvertOp::isIntegerCompatible(mlir::Type ty) { |
| 1441 | return isInteger(ty) || mlir::isa<fir::LogicalType>(ty); |
| 1442 | } |
| 1443 | |
| 1444 | bool fir::ConvertOp::isFloatCompatible(mlir::Type ty) { |
| 1445 | return mlir::isa<mlir::FloatType>(ty); |
| 1446 | } |
| 1447 | |
| 1448 | bool fir::ConvertOp::isPointerCompatible(mlir::Type ty) { |
| 1449 | return mlir::isa<fir::ReferenceType, fir::PointerType, fir::HeapType, |
| 1450 | fir::LLVMPointerType, mlir::MemRefType, mlir::FunctionType, |
| 1451 | fir::TypeDescType, mlir::LLVM::LLVMPointerType>(ty); |
| 1452 | } |
| 1453 | |
| 1454 | static std::optional<mlir::Type> getVectorElementType(mlir::Type ty) { |
| 1455 | mlir::Type elemTy; |
| 1456 | if (mlir::isa<fir::VectorType>(ty)) |
| 1457 | elemTy = mlir::dyn_cast<fir::VectorType>(ty).getElementType(); |
| 1458 | else if (mlir::isa<mlir::VectorType>(Val: ty)) |
| 1459 | elemTy = mlir::dyn_cast<mlir::VectorType>(ty).getElementType(); |
| 1460 | else |
| 1461 | return std::nullopt; |
| 1462 | |
| 1463 | // e.g. fir.vector<4:ui32> => mlir.vector<4xi32> |
| 1464 | // e.g. mlir.vector<4xui32> => mlir.vector<4xi32> |
| 1465 | if (elemTy.isUnsignedInteger()) { |
| 1466 | elemTy = mlir::IntegerType::get( |
| 1467 | ty.getContext(), mlir::dyn_cast<mlir::IntegerType>(elemTy).getWidth()); |
| 1468 | } |
| 1469 | return elemTy; |
| 1470 | } |
| 1471 | |
| 1472 | static std::optional<uint64_t> getVectorLen(mlir::Type ty) { |
| 1473 | if (mlir::isa<fir::VectorType>(ty)) |
| 1474 | return mlir::dyn_cast<fir::VectorType>(ty).getLen(); |
| 1475 | else if (mlir::isa<mlir::VectorType>(Val: ty)) { |
| 1476 | // fir.vector only supports 1-D vector |
| 1477 | if (!(mlir::dyn_cast<mlir::VectorType>(ty).isScalable())) |
| 1478 | return mlir::dyn_cast<mlir::VectorType>(ty).getShape()[0]; |
| 1479 | } |
| 1480 | |
| 1481 | return std::nullopt; |
| 1482 | } |
| 1483 | |
| 1484 | bool fir::ConvertOp::areVectorsCompatible(mlir::Type inTy, mlir::Type outTy) { |
| 1485 | if (!(mlir::isa<fir::VectorType>(inTy) && |
| 1486 | mlir::isa<mlir::VectorType>(outTy)) && |
| 1487 | !(mlir::isa<mlir::VectorType>(inTy) && mlir::isa<fir::VectorType>(outTy))) |
| 1488 | return false; |
| 1489 | |
| 1490 | // Only support integer, unsigned and real vector |
| 1491 | // Both vectors must have the same element type |
| 1492 | std::optional<mlir::Type> inElemTy = getVectorElementType(inTy); |
| 1493 | std::optional<mlir::Type> outElemTy = getVectorElementType(outTy); |
| 1494 | if (!inElemTy.has_value() || !outElemTy.has_value() || |
| 1495 | inElemTy.value() != outElemTy.value()) |
| 1496 | return false; |
| 1497 | |
| 1498 | // Both vectors must have the same number of elements |
| 1499 | std::optional<uint64_t> inLen = getVectorLen(inTy); |
| 1500 | std::optional<uint64_t> outLen = getVectorLen(outTy); |
| 1501 | if (!inLen.has_value() || !outLen.has_value() || |
| 1502 | inLen.value() != outLen.value()) |
| 1503 | return false; |
| 1504 | |
| 1505 | return true; |
| 1506 | } |
| 1507 | |
| 1508 | static bool areRecordsCompatible(mlir::Type inTy, mlir::Type outTy) { |
| 1509 | // Both records must have the same field types. |
| 1510 | // Trust frontend semantics for in-depth checks, such as if both records |
| 1511 | // have the BIND(C) attribute. |
| 1512 | auto inRecTy = mlir::dyn_cast<fir::RecordType>(inTy); |
| 1513 | auto outRecTy = mlir::dyn_cast<fir::RecordType>(outTy); |
| 1514 | return inRecTy && outRecTy && inRecTy.getTypeList() == outRecTy.getTypeList(); |
| 1515 | } |
| 1516 | |
| 1517 | bool fir::ConvertOp::canBeConverted(mlir::Type inType, mlir::Type outType) { |
| 1518 | if (inType == outType) |
| 1519 | return true; |
| 1520 | return (isPointerCompatible(inType) && isPointerCompatible(outType)) || |
| 1521 | (isIntegerCompatible(inType) && isIntegerCompatible(outType)) || |
| 1522 | (isInteger(inType) && isFloatCompatible(outType)) || |
| 1523 | (isFloatCompatible(inType) && isInteger(outType)) || |
| 1524 | (isFloatCompatible(inType) && isFloatCompatible(outType)) || |
| 1525 | (isIntegerCompatible(inType) && isPointerCompatible(outType)) || |
| 1526 | (isPointerCompatible(inType) && isIntegerCompatible(outType)) || |
| 1527 | (mlir::isa<fir::BoxType>(inType) && |
| 1528 | mlir::isa<fir::BoxType>(outType)) || |
| 1529 | (mlir::isa<fir::BoxProcType>(inType) && |
| 1530 | mlir::isa<fir::BoxProcType>(outType)) || |
| 1531 | (fir::isa_complex(inType) && fir::isa_complex(outType)) || |
| 1532 | (fir::isBoxedRecordType(inType) && fir::isPolymorphicType(outType)) || |
| 1533 | (fir::isPolymorphicType(inType) && fir::isPolymorphicType(outType)) || |
| 1534 | (fir::isPolymorphicType(inType) && mlir::isa<BoxType>(outType)) || |
| 1535 | areVectorsCompatible(inType, outType) || |
| 1536 | areRecordsCompatible(inType, outType); |
| 1537 | } |
| 1538 | |
| 1539 | // In general, ptrtoint-like conversions are allowed to lose volatility |
| 1540 | // information because they are either: |
| 1541 | // |
| 1542 | // 1. passing an entity to an external function and there's nothing we can do |
| 1543 | // about volatility after that happens, or |
| 1544 | // 2. for code generation, at which point we represent volatility with |
| 1545 | // attributes on the LLVM instructions and intrinsics. |
| 1546 | // |
| 1547 | // For all other cases, volatility ought to match exactly. |
| 1548 | static mlir::LogicalResult verifyVolatility(mlir::Type inType, |
| 1549 | mlir::Type outType) { |
| 1550 | const bool toLLVMPointer = mlir::isa<mlir::LLVM::LLVMPointerType>(outType); |
| 1551 | const bool toInteger = fir::isa_integer(outType); |
| 1552 | |
| 1553 | // When converting references to classes or allocatables into boxes for |
| 1554 | // runtime arguments, we cast away all the volatility information and pass a |
| 1555 | // box<none>. This is allowed. |
| 1556 | const bool isBoxNoneLike = [&]() { |
| 1557 | if (fir::isBoxNone(outType)) |
| 1558 | return true; |
| 1559 | if (auto referenceType = mlir::dyn_cast<fir::ReferenceType>(outType)) { |
| 1560 | if (fir::isBoxNone(referenceType.getElementType())) { |
| 1561 | return true; |
| 1562 | } |
| 1563 | } |
| 1564 | return false; |
| 1565 | }(); |
| 1566 | |
| 1567 | const bool isPtrToIntLike = toLLVMPointer || toInteger || isBoxNoneLike; |
| 1568 | if (isPtrToIntLike) { |
| 1569 | return mlir::success(); |
| 1570 | } |
| 1571 | |
| 1572 | // In all other cases, we need to check for an exact volatility match. |
| 1573 | return mlir::success(fir::isa_volatile_type(inType) == |
| 1574 | fir::isa_volatile_type(outType)); |
| 1575 | } |
| 1576 | |
| 1577 | llvm::LogicalResult fir::ConvertOp::verify() { |
| 1578 | mlir::Type inType = getValue().getType(); |
| 1579 | mlir::Type outType = getType(); |
| 1580 | if (fir::useStrictVolatileVerification()) { |
| 1581 | if (failed(verifyVolatility(inType, outType))) { |
| 1582 | return emitOpError("this conversion does not preserve volatility: " ) |
| 1583 | << inType << " / " << outType; |
| 1584 | } |
| 1585 | } |
| 1586 | if (canBeConverted(inType, outType)) |
| 1587 | return mlir::success(); |
| 1588 | return emitOpError("invalid type conversion" ) |
| 1589 | << getValue().getType() << " / " << getType(); |
| 1590 | } |
| 1591 | |
| 1592 | //===----------------------------------------------------------------------===// |
| 1593 | // CoordinateOp |
| 1594 | //===----------------------------------------------------------------------===// |
| 1595 | |
| 1596 | void fir::CoordinateOp::build(mlir::OpBuilder &builder, |
| 1597 | mlir::OperationState &result, |
| 1598 | mlir::Type resultType, mlir::Value ref, |
| 1599 | mlir::ValueRange coor) { |
| 1600 | llvm::SmallVector<int32_t> fieldIndices; |
| 1601 | llvm::SmallVector<mlir::Value> dynamicIndices; |
| 1602 | bool anyField = false; |
| 1603 | for (mlir::Value index : coor) { |
| 1604 | if (auto field = index.getDefiningOp<fir::FieldIndexOp>()) { |
| 1605 | auto recTy = mlir::cast<fir::RecordType>(field.getOnType()); |
| 1606 | fieldIndices.push_back(recTy.getFieldIndex(field.getFieldId())); |
| 1607 | anyField = true; |
| 1608 | } else { |
| 1609 | fieldIndices.push_back(fir::CoordinateOp::kDynamicIndex); |
| 1610 | dynamicIndices.push_back(index); |
| 1611 | } |
| 1612 | } |
| 1613 | auto typeAttr = mlir::TypeAttr::get(ref.getType()); |
| 1614 | if (anyField) { |
| 1615 | build(builder, result, resultType, ref, dynamicIndices, typeAttr, |
| 1616 | builder.getDenseI32ArrayAttr(fieldIndices)); |
| 1617 | } else { |
| 1618 | build(builder, result, resultType, ref, dynamicIndices, typeAttr, nullptr); |
| 1619 | } |
| 1620 | } |
| 1621 | |
| 1622 | void fir::CoordinateOp::build(mlir::OpBuilder &builder, |
| 1623 | mlir::OperationState &result, |
| 1624 | mlir::Type resultType, mlir::Value ref, |
| 1625 | llvm::ArrayRef<fir::IntOrValue> coor) { |
| 1626 | llvm::SmallVector<int32_t> fieldIndices; |
| 1627 | llvm::SmallVector<mlir::Value> dynamicIndices; |
| 1628 | bool anyField = false; |
| 1629 | for (fir::IntOrValue index : coor) { |
| 1630 | llvm::TypeSwitch<fir::IntOrValue>(index) |
| 1631 | .Case<mlir::IntegerAttr>([&](mlir::IntegerAttr intAttr) { |
| 1632 | fieldIndices.push_back(intAttr.getInt()); |
| 1633 | anyField = true; |
| 1634 | }) |
| 1635 | .Case<mlir::Value>([&](mlir::Value value) { |
| 1636 | dynamicIndices.push_back(value); |
| 1637 | fieldIndices.push_back(fir::CoordinateOp::kDynamicIndex); |
| 1638 | }); |
| 1639 | } |
| 1640 | auto typeAttr = mlir::TypeAttr::get(ref.getType()); |
| 1641 | if (anyField) { |
| 1642 | build(builder, result, resultType, ref, dynamicIndices, typeAttr, |
| 1643 | builder.getDenseI32ArrayAttr(fieldIndices)); |
| 1644 | } else { |
| 1645 | build(builder, result, resultType, ref, dynamicIndices, typeAttr, nullptr); |
| 1646 | } |
| 1647 | } |
| 1648 | |
| 1649 | void fir::CoordinateOp::print(mlir::OpAsmPrinter &p) { |
| 1650 | p << ' ' << getRef(); |
| 1651 | if (!getFieldIndicesAttr()) { |
| 1652 | p << ", " << getCoor(); |
| 1653 | } else { |
| 1654 | mlir::Type eleTy = fir::getFortranElementType(getRef().getType()); |
| 1655 | for (auto index : getIndices()) { |
| 1656 | p << ", " ; |
| 1657 | llvm::TypeSwitch<fir::IntOrValue>(index) |
| 1658 | .Case<mlir::IntegerAttr>([&](mlir::IntegerAttr intAttr) { |
| 1659 | if (auto recordType = llvm::dyn_cast<fir::RecordType>(eleTy)) { |
| 1660 | int fieldId = intAttr.getInt(); |
| 1661 | if (fieldId < static_cast<int>(recordType.getNumFields())) { |
| 1662 | auto nameAndType = recordType.getTypeList()[fieldId]; |
| 1663 | p << std::get<std::string>(nameAndType); |
| 1664 | eleTy = fir::getFortranElementType( |
| 1665 | std::get<mlir::Type>(nameAndType)); |
| 1666 | return; |
| 1667 | } |
| 1668 | } |
| 1669 | // Invalid index, still print it so that invalid IR can be |
| 1670 | // investigated. |
| 1671 | p << intAttr; |
| 1672 | }) |
| 1673 | .Case<mlir::Value>([&](mlir::Value value) { p << value; }); |
| 1674 | } |
| 1675 | } |
| 1676 | p.printOptionalAttrDict( |
| 1677 | (*this)->getAttrs(), |
| 1678 | /*elideAttrs=*/{getBaseTypeAttrName(), getFieldIndicesAttrName()}); |
| 1679 | p << " : " ; |
| 1680 | p.printFunctionalType(getOperandTypes(), (*this)->getResultTypes()); |
| 1681 | } |
| 1682 | |
| 1683 | mlir::ParseResult fir::CoordinateOp::parse(mlir::OpAsmParser &parser, |
| 1684 | mlir::OperationState &result) { |
| 1685 | mlir::OpAsmParser::UnresolvedOperand memref; |
| 1686 | if (parser.parseOperand(memref) || parser.parseComma()) |
| 1687 | return mlir::failure(); |
| 1688 | llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> coorOperands; |
| 1689 | llvm::SmallVector<std::pair<llvm::StringRef, int>> fieldNames; |
| 1690 | llvm::SmallVector<int32_t> fieldIndices; |
| 1691 | while (true) { |
| 1692 | llvm::StringRef fieldName; |
| 1693 | if (mlir::succeeded(parser.parseOptionalKeyword(&fieldName))) { |
| 1694 | fieldNames.push_back({fieldName, static_cast<int>(fieldIndices.size())}); |
| 1695 | // Actual value will be computed later when base type has been parsed. |
| 1696 | fieldIndices.push_back(0); |
| 1697 | } else { |
| 1698 | mlir::OpAsmParser::UnresolvedOperand index; |
| 1699 | if (parser.parseOperand(index)) |
| 1700 | return mlir::failure(); |
| 1701 | fieldIndices.push_back(fir::CoordinateOp::kDynamicIndex); |
| 1702 | coorOperands.push_back(index); |
| 1703 | } |
| 1704 | if (mlir::failed(parser.parseOptionalComma())) |
| 1705 | break; |
| 1706 | } |
| 1707 | llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> allOperands; |
| 1708 | allOperands.push_back(memref); |
| 1709 | allOperands.append(coorOperands.begin(), coorOperands.end()); |
| 1710 | mlir::FunctionType funcTy; |
| 1711 | auto loc = parser.getCurrentLocation(); |
| 1712 | if (parser.parseOptionalAttrDict(result.attributes) || |
| 1713 | parser.parseColonType(funcTy) || |
| 1714 | parser.resolveOperands(allOperands, funcTy.getInputs(), loc, |
| 1715 | result.operands) || |
| 1716 | parser.addTypesToList(funcTy.getResults(), result.types)) |
| 1717 | return mlir::failure(); |
| 1718 | result.addAttribute(getBaseTypeAttrName(result.name), |
| 1719 | mlir::TypeAttr::get(funcTy.getInput(0))); |
| 1720 | if (!fieldNames.empty()) { |
| 1721 | mlir::Type eleTy = fir::getFortranElementType(funcTy.getInput(0)); |
| 1722 | for (auto [fieldName, operandPosition] : fieldNames) { |
| 1723 | auto recTy = llvm::dyn_cast<fir::RecordType>(eleTy); |
| 1724 | if (!recTy) |
| 1725 | return parser.emitError( |
| 1726 | loc, "base must be a derived type when field name appears" ); |
| 1727 | unsigned fieldNum = recTy.getFieldIndex(fieldName); |
| 1728 | if (fieldNum > recTy.getNumFields()) |
| 1729 | return parser.emitError(loc) |
| 1730 | << "field '" << fieldName |
| 1731 | << "' is not a component or subcomponent of the base type" ; |
| 1732 | fieldIndices[operandPosition] = fieldNum; |
| 1733 | eleTy = fir::getFortranElementType( |
| 1734 | std::get<mlir::Type>(recTy.getTypeList()[fieldNum])); |
| 1735 | } |
| 1736 | result.addAttribute(getFieldIndicesAttrName(result.name), |
| 1737 | parser.getBuilder().getDenseI32ArrayAttr(fieldIndices)); |
| 1738 | } |
| 1739 | return mlir::success(); |
| 1740 | } |
| 1741 | |
| 1742 | llvm::LogicalResult fir::CoordinateOp::verify() { |
| 1743 | const mlir::Type refTy = getRef().getType(); |
| 1744 | if (fir::isa_ref_type(refTy)) { |
| 1745 | auto eleTy = fir::dyn_cast_ptrEleTy(refTy); |
| 1746 | if (auto arrTy = mlir::dyn_cast<fir::SequenceType>(eleTy)) { |
| 1747 | if (arrTy.hasUnknownShape()) |
| 1748 | return emitOpError("cannot find coordinate in unknown shape" ); |
| 1749 | if (arrTy.getConstantRows() < arrTy.getDimension() - 1) |
| 1750 | return emitOpError("cannot find coordinate with unknown extents" ); |
| 1751 | } |
| 1752 | if (!(fir::isa_aggregate(eleTy) || fir::isa_complex(eleTy) || |
| 1753 | fir::isa_char_string(eleTy))) |
| 1754 | return emitOpError("cannot apply to this element type" ); |
| 1755 | } |
| 1756 | auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(refTy); |
| 1757 | unsigned dimension = 0; |
| 1758 | const unsigned numCoors = getCoor().size(); |
| 1759 | for (auto coorOperand : llvm::enumerate(getCoor())) { |
| 1760 | auto co = coorOperand.value(); |
| 1761 | if (dimension == 0 && mlir::isa<fir::SequenceType>(eleTy)) { |
| 1762 | dimension = mlir::cast<fir::SequenceType>(eleTy).getDimension(); |
| 1763 | if (dimension == 0) |
| 1764 | return emitOpError("cannot apply to array of unknown rank" ); |
| 1765 | } |
| 1766 | if (auto *defOp = co.getDefiningOp()) { |
| 1767 | if (auto index = mlir::dyn_cast<fir::LenParamIndexOp>(defOp)) { |
| 1768 | // Recovering a LEN type parameter only makes sense from a boxed |
| 1769 | // value. For a bare reference, the LEN type parameters must be |
| 1770 | // passed as additional arguments to `index`. |
| 1771 | if (mlir::isa<fir::BoxType>(refTy)) { |
| 1772 | if (coorOperand.index() != numCoors - 1) |
| 1773 | return emitOpError("len_param_index must be last argument" ); |
| 1774 | if (getNumOperands() != 2) |
| 1775 | return emitOpError("too many operands for len_param_index case" ); |
| 1776 | } |
| 1777 | if (eleTy != index.getOnType()) |
| 1778 | emitOpError( |
| 1779 | "len_param_index type not compatible with reference type" ); |
| 1780 | return mlir::success(); |
| 1781 | } else if (auto index = mlir::dyn_cast<fir::FieldIndexOp>(defOp)) { |
| 1782 | if (eleTy != index.getOnType()) |
| 1783 | emitOpError("field_index type not compatible with reference type" ); |
| 1784 | if (auto recTy = mlir::dyn_cast<fir::RecordType>(eleTy)) { |
| 1785 | eleTy = recTy.getType(index.getFieldName()); |
| 1786 | continue; |
| 1787 | } |
| 1788 | return emitOpError("field_index not applied to !fir.type" ); |
| 1789 | } |
| 1790 | } |
| 1791 | if (dimension) { |
| 1792 | if (--dimension == 0) |
| 1793 | eleTy = mlir::cast<fir::SequenceType>(eleTy).getElementType(); |
| 1794 | } else { |
| 1795 | if (auto t = mlir::dyn_cast<mlir::TupleType>(eleTy)) { |
| 1796 | // FIXME: Generally, we don't know which field of the tuple is being |
| 1797 | // referred to unless the operand is a constant. Just assume everything |
| 1798 | // is good in the tuple case for now. |
| 1799 | return mlir::success(); |
| 1800 | } else if (auto t = mlir::dyn_cast<fir::RecordType>(eleTy)) { |
| 1801 | // FIXME: This is the same as the tuple case. |
| 1802 | return mlir::success(); |
| 1803 | } else if (auto t = mlir::dyn_cast<mlir::ComplexType>(eleTy)) { |
| 1804 | eleTy = t.getElementType(); |
| 1805 | } else if (auto t = mlir::dyn_cast<fir::CharacterType>(eleTy)) { |
| 1806 | if (t.getLen() == fir::CharacterType::singleton()) |
| 1807 | return emitOpError("cannot apply to character singleton" ); |
| 1808 | eleTy = fir::CharacterType::getSingleton(t.getContext(), t.getFKind()); |
| 1809 | if (fir::unwrapRefType(getType()) != eleTy) |
| 1810 | return emitOpError("character type mismatch" ); |
| 1811 | } else { |
| 1812 | return emitOpError("invalid parameters (too many)" ); |
| 1813 | } |
| 1814 | } |
| 1815 | } |
| 1816 | return mlir::success(); |
| 1817 | } |
| 1818 | |
| 1819 | fir::CoordinateIndicesAdaptor fir::CoordinateOp::getIndices() { |
| 1820 | return CoordinateIndicesAdaptor(getFieldIndicesAttr(), getCoor()); |
| 1821 | } |
| 1822 | |
| 1823 | //===----------------------------------------------------------------------===// |
| 1824 | // DispatchOp |
| 1825 | //===----------------------------------------------------------------------===// |
| 1826 | |
| 1827 | llvm::LogicalResult fir::DispatchOp::verify() { |
| 1828 | // Check that pass_arg_pos is in range of actual operands. pass_arg_pos is |
| 1829 | // unsigned so check for less than zero is not needed. |
| 1830 | if (getPassArgPos() && *getPassArgPos() > (getArgOperands().size() - 1)) |
| 1831 | return emitOpError( |
| 1832 | "pass_arg_pos must be smaller than the number of operands" ); |
| 1833 | |
| 1834 | // Operand pointed by pass_arg_pos must have polymorphic type. |
| 1835 | if (getPassArgPos() && |
| 1836 | !fir::isPolymorphicType(getArgOperands()[*getPassArgPos()].getType())) |
| 1837 | return emitOpError("pass_arg_pos must be a polymorphic operand" ); |
| 1838 | return mlir::success(); |
| 1839 | } |
| 1840 | |
| 1841 | mlir::FunctionType fir::DispatchOp::getFunctionType() { |
| 1842 | return mlir::FunctionType::get(getContext(), getOperandTypes(), |
| 1843 | getResultTypes()); |
| 1844 | } |
| 1845 | |
| 1846 | //===----------------------------------------------------------------------===// |
| 1847 | // TypeInfoOp |
| 1848 | //===----------------------------------------------------------------------===// |
| 1849 | |
| 1850 | void fir::TypeInfoOp::build(mlir::OpBuilder &builder, |
| 1851 | mlir::OperationState &result, fir::RecordType type, |
| 1852 | fir::RecordType parentType, |
| 1853 | llvm::ArrayRef<mlir::NamedAttribute> attrs) { |
| 1854 | result.addRegion(); |
| 1855 | result.addRegion(); |
| 1856 | result.addAttribute(mlir::SymbolTable::getSymbolAttrName(), |
| 1857 | builder.getStringAttr(type.getName())); |
| 1858 | result.addAttribute(getTypeAttrName(result.name), mlir::TypeAttr::get(type)); |
| 1859 | if (parentType) |
| 1860 | result.addAttribute(getParentTypeAttrName(result.name), |
| 1861 | mlir::TypeAttr::get(parentType)); |
| 1862 | result.addAttributes(attrs); |
| 1863 | } |
| 1864 | |
| 1865 | llvm::LogicalResult fir::TypeInfoOp::verify() { |
| 1866 | if (!getDispatchTable().empty()) |
| 1867 | for (auto &op : getDispatchTable().front().without_terminator()) |
| 1868 | if (!mlir::isa<fir::DTEntryOp>(op)) |
| 1869 | return op.emitOpError("dispatch table must contain dt_entry" ); |
| 1870 | |
| 1871 | if (!mlir::isa<fir::RecordType>(getType())) |
| 1872 | return emitOpError("type must be a fir.type" ); |
| 1873 | |
| 1874 | if (getParentType() && !mlir::isa<fir::RecordType>(*getParentType())) |
| 1875 | return emitOpError("parent_type must be a fir.type" ); |
| 1876 | return mlir::success(); |
| 1877 | } |
| 1878 | |
| 1879 | //===----------------------------------------------------------------------===// |
| 1880 | // EmboxOp |
| 1881 | //===----------------------------------------------------------------------===// |
| 1882 | |
| 1883 | // Conversions from reference types to box types must preserve volatility. |
| 1884 | static llvm::LogicalResult |
| 1885 | verifyEmboxOpVolatilityInvariants(mlir::Type memrefType, |
| 1886 | mlir::Type resultType) { |
| 1887 | |
| 1888 | if (!fir::useStrictVolatileVerification()) |
| 1889 | return mlir::success(); |
| 1890 | |
| 1891 | mlir::Type boxElementType = |
| 1892 | llvm::TypeSwitch<mlir::Type, mlir::Type>(resultType) |
| 1893 | .Case<fir::BoxType, fir::ClassType>( |
| 1894 | [&](auto type) { return type.getEleTy(); }) |
| 1895 | .Default([&](mlir::Type type) { return type; }); |
| 1896 | |
| 1897 | // If the embox is simply wrapping a non-volatile type into a volatile box, |
| 1898 | // we're not losing any volatility information. |
| 1899 | if (boxElementType == memrefType) { |
| 1900 | return mlir::success(); |
| 1901 | } |
| 1902 | |
| 1903 | // Otherwise, the volatility of the input and result must match. |
| 1904 | const bool volatilityMatches = |
| 1905 | fir::isa_volatile_type(memrefType) == fir::isa_volatile_type(resultType); |
| 1906 | |
| 1907 | return mlir::success(IsSuccess: volatilityMatches); |
| 1908 | } |
| 1909 | |
| 1910 | llvm::LogicalResult fir::EmboxOp::verify() { |
| 1911 | auto eleTy = fir::dyn_cast_ptrEleTy(getMemref().getType()); |
| 1912 | bool isArray = false; |
| 1913 | if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(eleTy)) { |
| 1914 | eleTy = seqTy.getEleTy(); |
| 1915 | isArray = true; |
| 1916 | } |
| 1917 | if (hasLenParams()) { |
| 1918 | auto lenPs = numLenParams(); |
| 1919 | if (auto rt = mlir::dyn_cast<fir::RecordType>(eleTy)) { |
| 1920 | if (lenPs != rt.getNumLenParams()) |
| 1921 | return emitOpError("number of LEN params does not correspond" |
| 1922 | " to the !fir.type type" ); |
| 1923 | } else if (auto strTy = mlir::dyn_cast<fir::CharacterType>(eleTy)) { |
| 1924 | if (strTy.getLen() != fir::CharacterType::unknownLen()) |
| 1925 | return emitOpError("CHARACTER already has static LEN" ); |
| 1926 | } else { |
| 1927 | return emitOpError("LEN parameters require CHARACTER or derived type" ); |
| 1928 | } |
| 1929 | for (auto lp : getTypeparams()) |
| 1930 | if (!fir::isa_integer(lp.getType())) |
| 1931 | return emitOpError("LEN parameters must be integral type" ); |
| 1932 | } |
| 1933 | if (getShape() && !isArray) |
| 1934 | return emitOpError("shape must not be provided for a scalar" ); |
| 1935 | if (getSlice() && !isArray) |
| 1936 | return emitOpError("slice must not be provided for a scalar" ); |
| 1937 | if (getSourceBox() && !mlir::isa<fir::ClassType>(getResult().getType())) |
| 1938 | return emitOpError("source_box must be used with fir.class result type" ); |
| 1939 | if (failed(verifyEmboxOpVolatilityInvariants(getMemref().getType(), |
| 1940 | getResult().getType()))) |
| 1941 | return emitOpError( |
| 1942 | "cannot convert between volatile and non-volatile types:" ) |
| 1943 | << " " << getMemref().getType() << " " << getResult().getType(); |
| 1944 | return mlir::success(); |
| 1945 | } |
| 1946 | |
| 1947 | //===----------------------------------------------------------------------===// |
| 1948 | // EmboxCharOp |
| 1949 | //===----------------------------------------------------------------------===// |
| 1950 | |
| 1951 | llvm::LogicalResult fir::EmboxCharOp::verify() { |
| 1952 | auto eleTy = fir::dyn_cast_ptrEleTy(getMemref().getType()); |
| 1953 | if (!mlir::dyn_cast_or_null<fir::CharacterType>(eleTy)) |
| 1954 | return mlir::failure(); |
| 1955 | return mlir::success(); |
| 1956 | } |
| 1957 | |
| 1958 | //===----------------------------------------------------------------------===// |
| 1959 | // EmboxProcOp |
| 1960 | //===----------------------------------------------------------------------===// |
| 1961 | |
| 1962 | llvm::LogicalResult fir::EmboxProcOp::verify() { |
| 1963 | // host bindings (optional) must be a reference to a tuple |
| 1964 | if (auto h = getHost()) { |
| 1965 | if (auto r = mlir::dyn_cast<fir::ReferenceType>(h.getType())) |
| 1966 | if (mlir::isa<mlir::TupleType>(r.getEleTy())) |
| 1967 | return mlir::success(); |
| 1968 | return mlir::failure(); |
| 1969 | } |
| 1970 | return mlir::success(); |
| 1971 | } |
| 1972 | |
| 1973 | //===----------------------------------------------------------------------===// |
| 1974 | // TypeDescOp |
| 1975 | //===----------------------------------------------------------------------===// |
| 1976 | |
| 1977 | void fir::TypeDescOp::build(mlir::OpBuilder &, mlir::OperationState &result, |
| 1978 | mlir::TypeAttr inty) { |
| 1979 | result.addAttribute("in_type" , inty); |
| 1980 | result.addTypes(TypeDescType::get(inty.getValue())); |
| 1981 | } |
| 1982 | |
| 1983 | mlir::ParseResult fir::TypeDescOp::parse(mlir::OpAsmParser &parser, |
| 1984 | mlir::OperationState &result) { |
| 1985 | mlir::Type intype; |
| 1986 | if (parser.parseType(intype)) |
| 1987 | return mlir::failure(); |
| 1988 | result.addAttribute("in_type" , mlir::TypeAttr::get(intype)); |
| 1989 | mlir::Type restype = fir::TypeDescType::get(intype); |
| 1990 | if (parser.addTypeToList(restype, result.types)) |
| 1991 | return mlir::failure(); |
| 1992 | return mlir::success(); |
| 1993 | } |
| 1994 | |
| 1995 | void fir::TypeDescOp::print(mlir::OpAsmPrinter &p) { |
| 1996 | p << ' ' << getOperation()->getAttr("in_type" ); |
| 1997 | p.printOptionalAttrDict(getOperation()->getAttrs(), {"in_type" }); |
| 1998 | } |
| 1999 | |
| 2000 | llvm::LogicalResult fir::TypeDescOp::verify() { |
| 2001 | mlir::Type resultTy = getType(); |
| 2002 | if (auto tdesc = mlir::dyn_cast<fir::TypeDescType>(resultTy)) { |
| 2003 | if (tdesc.getOfTy() != getInType()) |
| 2004 | return emitOpError("wrapped type mismatched" ); |
| 2005 | return mlir::success(); |
| 2006 | } |
| 2007 | return emitOpError("must be !fir.tdesc type" ); |
| 2008 | } |
| 2009 | |
| 2010 | //===----------------------------------------------------------------------===// |
| 2011 | // GlobalOp |
| 2012 | //===----------------------------------------------------------------------===// |
| 2013 | |
| 2014 | mlir::Type fir::GlobalOp::resultType() { |
| 2015 | return wrapAllocaResultType(getType()); |
| 2016 | } |
| 2017 | |
| 2018 | mlir::ParseResult fir::GlobalOp::parse(mlir::OpAsmParser &parser, |
| 2019 | mlir::OperationState &result) { |
| 2020 | // Parse the optional linkage |
| 2021 | llvm::StringRef linkage; |
| 2022 | auto &builder = parser.getBuilder(); |
| 2023 | if (mlir::succeeded(parser.parseOptionalKeyword(&linkage))) { |
| 2024 | if (fir::GlobalOp::verifyValidLinkage(linkage)) |
| 2025 | return mlir::failure(); |
| 2026 | mlir::StringAttr linkAttr = builder.getStringAttr(linkage); |
| 2027 | result.addAttribute(fir::GlobalOp::getLinkNameAttrName(result.name), |
| 2028 | linkAttr); |
| 2029 | } |
| 2030 | |
| 2031 | // Parse the name as a symbol reference attribute. |
| 2032 | mlir::SymbolRefAttr nameAttr; |
| 2033 | if (parser.parseAttribute(nameAttr, |
| 2034 | fir::GlobalOp::getSymrefAttrName(result.name), |
| 2035 | result.attributes)) |
| 2036 | return mlir::failure(); |
| 2037 | result.addAttribute(mlir::SymbolTable::getSymbolAttrName(), |
| 2038 | nameAttr.getRootReference()); |
| 2039 | |
| 2040 | bool simpleInitializer = false; |
| 2041 | if (mlir::succeeded(parser.parseOptionalLParen())) { |
| 2042 | mlir::Attribute attr; |
| 2043 | if (parser.parseAttribute(attr, getInitValAttrName(result.name), |
| 2044 | result.attributes) || |
| 2045 | parser.parseRParen()) |
| 2046 | return mlir::failure(); |
| 2047 | simpleInitializer = true; |
| 2048 | } |
| 2049 | |
| 2050 | if (parser.parseOptionalAttrDict(result.attributes)) |
| 2051 | return mlir::failure(); |
| 2052 | |
| 2053 | if (succeeded( |
| 2054 | parser.parseOptionalKeyword(getConstantAttrName(result.name)))) { |
| 2055 | // if "constant" keyword then mark this as a constant, not a variable |
| 2056 | result.addAttribute(getConstantAttrName(result.name), |
| 2057 | builder.getUnitAttr()); |
| 2058 | } |
| 2059 | |
| 2060 | if (succeeded(parser.parseOptionalKeyword(getTargetAttrName(result.name)))) |
| 2061 | result.addAttribute(getTargetAttrName(result.name), builder.getUnitAttr()); |
| 2062 | |
| 2063 | mlir::Type globalType; |
| 2064 | if (parser.parseColonType(globalType)) |
| 2065 | return mlir::failure(); |
| 2066 | |
| 2067 | result.addAttribute(fir::GlobalOp::getTypeAttrName(result.name), |
| 2068 | mlir::TypeAttr::get(globalType)); |
| 2069 | |
| 2070 | if (simpleInitializer) { |
| 2071 | result.addRegion(); |
| 2072 | } else { |
| 2073 | // Parse the optional initializer body. |
| 2074 | auto parseResult = |
| 2075 | parser.parseOptionalRegion(*result.addRegion(), /*arguments=*/{}); |
| 2076 | if (parseResult.has_value() && mlir::failed(*parseResult)) |
| 2077 | return mlir::failure(); |
| 2078 | } |
| 2079 | return mlir::success(); |
| 2080 | } |
| 2081 | |
| 2082 | void fir::GlobalOp::print(mlir::OpAsmPrinter &p) { |
| 2083 | if (getLinkName()) |
| 2084 | p << ' ' << *getLinkName(); |
| 2085 | p << ' '; |
| 2086 | p.printAttributeWithoutType(getSymrefAttr()); |
| 2087 | if (auto val = getValueOrNull()) |
| 2088 | p << '(' << val << ')'; |
| 2089 | // Print all other attributes that are not pretty printed here. |
| 2090 | p.printOptionalAttrDict((*this)->getAttrs(), /*elideAttrs=*/{ |
| 2091 | getSymNameAttrName(), getSymrefAttrName(), |
| 2092 | getTypeAttrName(), getConstantAttrName(), |
| 2093 | getTargetAttrName(), getLinkNameAttrName(), |
| 2094 | getInitValAttrName()}); |
| 2095 | if (getOperation()->getAttr(getConstantAttrName())) |
| 2096 | p << " " << getConstantAttrName().strref(); |
| 2097 | if (getOperation()->getAttr(getTargetAttrName())) |
| 2098 | p << " " << getTargetAttrName().strref(); |
| 2099 | p << " : " ; |
| 2100 | p.printType(getType()); |
| 2101 | if (hasInitializationBody()) { |
| 2102 | p << ' '; |
| 2103 | p.printRegion(getOperation()->getRegion(0), |
| 2104 | /*printEntryBlockArgs=*/false, |
| 2105 | /*printBlockTerminators=*/true); |
| 2106 | } |
| 2107 | } |
| 2108 | |
| 2109 | void fir::GlobalOp::appendInitialValue(mlir::Operation *op) { |
| 2110 | getBlock().getOperations().push_back(op); |
| 2111 | } |
| 2112 | |
| 2113 | void fir::GlobalOp::build(mlir::OpBuilder &builder, |
| 2114 | mlir::OperationState &result, llvm::StringRef name, |
| 2115 | bool isConstant, bool isTarget, mlir::Type type, |
| 2116 | mlir::Attribute initialVal, mlir::StringAttr linkage, |
| 2117 | llvm::ArrayRef<mlir::NamedAttribute> attrs) { |
| 2118 | result.addRegion(); |
| 2119 | result.addAttribute(getTypeAttrName(result.name), mlir::TypeAttr::get(type)); |
| 2120 | result.addAttribute(mlir::SymbolTable::getSymbolAttrName(), |
| 2121 | builder.getStringAttr(name)); |
| 2122 | result.addAttribute(getSymrefAttrName(result.name), |
| 2123 | mlir::SymbolRefAttr::get(builder.getContext(), name)); |
| 2124 | if (isConstant) |
| 2125 | result.addAttribute(getConstantAttrName(result.name), |
| 2126 | builder.getUnitAttr()); |
| 2127 | if (isTarget) |
| 2128 | result.addAttribute(getTargetAttrName(result.name), builder.getUnitAttr()); |
| 2129 | if (initialVal) |
| 2130 | result.addAttribute(getInitValAttrName(result.name), initialVal); |
| 2131 | if (linkage) |
| 2132 | result.addAttribute(getLinkNameAttrName(result.name), linkage); |
| 2133 | result.attributes.append(attrs.begin(), attrs.end()); |
| 2134 | } |
| 2135 | |
| 2136 | void fir::GlobalOp::build(mlir::OpBuilder &builder, |
| 2137 | mlir::OperationState &result, llvm::StringRef name, |
| 2138 | mlir::Type type, mlir::Attribute initialVal, |
| 2139 | mlir::StringAttr linkage, |
| 2140 | llvm::ArrayRef<mlir::NamedAttribute> attrs) { |
| 2141 | build(builder, result, name, /*isConstant=*/false, /*isTarget=*/false, type, |
| 2142 | {}, linkage, attrs); |
| 2143 | } |
| 2144 | |
| 2145 | void fir::GlobalOp::build(mlir::OpBuilder &builder, |
| 2146 | mlir::OperationState &result, llvm::StringRef name, |
| 2147 | bool isConstant, bool isTarget, mlir::Type type, |
| 2148 | mlir::StringAttr linkage, |
| 2149 | llvm::ArrayRef<mlir::NamedAttribute> attrs) { |
| 2150 | build(builder, result, name, isConstant, isTarget, type, {}, linkage, attrs); |
| 2151 | } |
| 2152 | |
| 2153 | void fir::GlobalOp::build(mlir::OpBuilder &builder, |
| 2154 | mlir::OperationState &result, llvm::StringRef name, |
| 2155 | mlir::Type type, mlir::StringAttr linkage, |
| 2156 | llvm::ArrayRef<mlir::NamedAttribute> attrs) { |
| 2157 | build(builder, result, name, /*isConstant=*/false, /*isTarget=*/false, type, |
| 2158 | {}, linkage, attrs); |
| 2159 | } |
| 2160 | |
| 2161 | void fir::GlobalOp::build(mlir::OpBuilder &builder, |
| 2162 | mlir::OperationState &result, llvm::StringRef name, |
| 2163 | bool isConstant, bool isTarget, mlir::Type type, |
| 2164 | llvm::ArrayRef<mlir::NamedAttribute> attrs) { |
| 2165 | build(builder, result, name, isConstant, isTarget, type, mlir::StringAttr{}, |
| 2166 | attrs); |
| 2167 | } |
| 2168 | |
| 2169 | void fir::GlobalOp::build(mlir::OpBuilder &builder, |
| 2170 | mlir::OperationState &result, llvm::StringRef name, |
| 2171 | mlir::Type type, |
| 2172 | llvm::ArrayRef<mlir::NamedAttribute> attrs) { |
| 2173 | build(builder, result, name, /*isConstant=*/false, /*isTarget=*/false, type, |
| 2174 | attrs); |
| 2175 | } |
| 2176 | |
| 2177 | mlir::ParseResult fir::GlobalOp::verifyValidLinkage(llvm::StringRef linkage) { |
| 2178 | // Supporting only a subset of the LLVM linkage types for now |
| 2179 | static const char *validNames[] = {"common" , "internal" , "linkonce" , |
| 2180 | "linkonce_odr" , "weak" }; |
| 2181 | return mlir::success(llvm::is_contained(validNames, linkage)); |
| 2182 | } |
| 2183 | |
| 2184 | //===----------------------------------------------------------------------===// |
| 2185 | // GlobalLenOp |
| 2186 | //===----------------------------------------------------------------------===// |
| 2187 | |
| 2188 | mlir::ParseResult fir::GlobalLenOp::parse(mlir::OpAsmParser &parser, |
| 2189 | mlir::OperationState &result) { |
| 2190 | llvm::StringRef fieldName; |
| 2191 | if (failed(parser.parseOptionalKeyword(&fieldName))) { |
| 2192 | mlir::StringAttr fieldAttr; |
| 2193 | if (parser.parseAttribute(fieldAttr, |
| 2194 | fir::GlobalLenOp::getLenParamAttrName(), |
| 2195 | result.attributes)) |
| 2196 | return mlir::failure(); |
| 2197 | } else { |
| 2198 | result.addAttribute(fir::GlobalLenOp::getLenParamAttrName(), |
| 2199 | parser.getBuilder().getStringAttr(fieldName)); |
| 2200 | } |
| 2201 | mlir::IntegerAttr constant; |
| 2202 | if (parser.parseComma() || |
| 2203 | parser.parseAttribute(constant, fir::GlobalLenOp::getIntAttrName(), |
| 2204 | result.attributes)) |
| 2205 | return mlir::failure(); |
| 2206 | return mlir::success(); |
| 2207 | } |
| 2208 | |
| 2209 | void fir::GlobalLenOp::print(mlir::OpAsmPrinter &p) { |
| 2210 | p << ' ' << getOperation()->getAttr(fir::GlobalLenOp::getLenParamAttrName()) |
| 2211 | << ", " << getOperation()->getAttr(fir::GlobalLenOp::getIntAttrName()); |
| 2212 | } |
| 2213 | |
| 2214 | //===----------------------------------------------------------------------===// |
| 2215 | // FieldIndexOp |
| 2216 | //===----------------------------------------------------------------------===// |
| 2217 | |
| 2218 | template <typename TY> |
| 2219 | mlir::ParseResult parseFieldLikeOp(mlir::OpAsmParser &parser, |
| 2220 | mlir::OperationState &result) { |
| 2221 | llvm::StringRef fieldName; |
| 2222 | auto &builder = parser.getBuilder(); |
| 2223 | mlir::Type recty; |
| 2224 | if (parser.parseOptionalKeyword(keyword: &fieldName) || parser.parseComma() || |
| 2225 | parser.parseType(result&: recty)) |
| 2226 | return mlir::failure(); |
| 2227 | result.addAttribute(fir::FieldIndexOp::getFieldAttrName(), |
| 2228 | builder.getStringAttr(fieldName)); |
| 2229 | if (!mlir::dyn_cast<fir::RecordType>(recty)) |
| 2230 | return mlir::failure(); |
| 2231 | result.addAttribute(fir::FieldIndexOp::getTypeAttrName(), |
| 2232 | mlir::TypeAttr::get(recty)); |
| 2233 | if (!parser.parseOptionalLParen()) { |
| 2234 | llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands; |
| 2235 | llvm::SmallVector<mlir::Type> types; |
| 2236 | auto loc = parser.getNameLoc(); |
| 2237 | if (parser.parseOperandList(result&: operands, delimiter: mlir::OpAsmParser::Delimiter::None) || |
| 2238 | parser.parseColonTypeList(result&: types) || parser.parseRParen() || |
| 2239 | parser.resolveOperands(operands, types, loc, result&: result.operands)) |
| 2240 | return mlir::failure(); |
| 2241 | } |
| 2242 | mlir::Type fieldType = TY::get(builder.getContext()); |
| 2243 | if (parser.addTypeToList(type: fieldType, result&: result.types)) |
| 2244 | return mlir::failure(); |
| 2245 | return mlir::success(); |
| 2246 | } |
| 2247 | |
| 2248 | mlir::ParseResult fir::FieldIndexOp::parse(mlir::OpAsmParser &parser, |
| 2249 | mlir::OperationState &result) { |
| 2250 | return parseFieldLikeOp<fir::FieldType>(parser, result); |
| 2251 | } |
| 2252 | |
| 2253 | template <typename OP> |
| 2254 | void printFieldLikeOp(mlir::OpAsmPrinter &p, OP &op) { |
| 2255 | p << ' ' |
| 2256 | << op.getOperation() |
| 2257 | ->template getAttrOfType<mlir::StringAttr>( |
| 2258 | fir::FieldIndexOp::getFieldAttrName()) |
| 2259 | .getValue() |
| 2260 | << ", " << op.getOperation()->getAttr(fir::FieldIndexOp::getTypeAttrName()); |
| 2261 | if (op.getNumOperands()) { |
| 2262 | p << '('; |
| 2263 | p.printOperands(op.getTypeparams()); |
| 2264 | auto sep = ") : " ; |
| 2265 | for (auto op : op.getTypeparams()) { |
| 2266 | p << sep; |
| 2267 | if (op) |
| 2268 | p.printType(type: op.getType()); |
| 2269 | else |
| 2270 | p << "()" ; |
| 2271 | sep = ", " ; |
| 2272 | } |
| 2273 | } |
| 2274 | } |
| 2275 | |
| 2276 | void fir::FieldIndexOp::print(mlir::OpAsmPrinter &p) { |
| 2277 | printFieldLikeOp(p, *this); |
| 2278 | } |
| 2279 | |
| 2280 | void fir::FieldIndexOp::build(mlir::OpBuilder &builder, |
| 2281 | mlir::OperationState &result, |
| 2282 | llvm::StringRef fieldName, mlir::Type recTy, |
| 2283 | mlir::ValueRange operands) { |
| 2284 | result.addAttribute(getFieldAttrName(), builder.getStringAttr(fieldName)); |
| 2285 | result.addAttribute(getTypeAttrName(), mlir::TypeAttr::get(recTy)); |
| 2286 | result.addOperands(operands); |
| 2287 | } |
| 2288 | |
| 2289 | llvm::SmallVector<mlir::Attribute> fir::FieldIndexOp::getAttributes() { |
| 2290 | llvm::SmallVector<mlir::Attribute> attrs; |
| 2291 | attrs.push_back(getFieldIdAttr()); |
| 2292 | attrs.push_back(getOnTypeAttr()); |
| 2293 | return attrs; |
| 2294 | } |
| 2295 | |
| 2296 | //===----------------------------------------------------------------------===// |
| 2297 | // InsertOnRangeOp |
| 2298 | //===----------------------------------------------------------------------===// |
| 2299 | |
| 2300 | static mlir::ParseResult |
| 2301 | parseCustomRangeSubscript(mlir::OpAsmParser &parser, |
| 2302 | mlir::DenseIntElementsAttr &coord) { |
| 2303 | llvm::SmallVector<std::int64_t> lbounds; |
| 2304 | llvm::SmallVector<std::int64_t> ubounds; |
| 2305 | if (parser.parseKeyword(keyword: "from" ) || |
| 2306 | parser.parseCommaSeparatedList( |
| 2307 | delimiter: mlir::AsmParser::Delimiter::Paren, |
| 2308 | parseElementFn: [&] { return parser.parseInteger(result&: lbounds.emplace_back(Args: 0)); }) || |
| 2309 | parser.parseKeyword(keyword: "to" ) || |
| 2310 | parser.parseCommaSeparatedList(delimiter: mlir::AsmParser::Delimiter::Paren, parseElementFn: [&] { |
| 2311 | return parser.parseInteger(result&: ubounds.emplace_back(Args: 0)); |
| 2312 | })) |
| 2313 | return mlir::failure(); |
| 2314 | llvm::SmallVector<std::int64_t> zippedBounds; |
| 2315 | for (auto zip : llvm::zip(t&: lbounds, u&: ubounds)) { |
| 2316 | zippedBounds.push_back(Elt: std::get<0>(t&: zip)); |
| 2317 | zippedBounds.push_back(Elt: std::get<1>(t&: zip)); |
| 2318 | } |
| 2319 | coord = mlir::Builder(parser.getContext()).getIndexTensorAttr(values: zippedBounds); |
| 2320 | return mlir::success(); |
| 2321 | } |
| 2322 | |
| 2323 | static void printCustomRangeSubscript(mlir::OpAsmPrinter &printer, |
| 2324 | fir::InsertOnRangeOp op, |
| 2325 | mlir::DenseIntElementsAttr coord) { |
| 2326 | printer << "from (" ; |
| 2327 | auto enumerate = llvm::enumerate(coord.getValues<std::int64_t>()); |
| 2328 | // Even entries are the lower bounds. |
| 2329 | llvm::interleaveComma( |
| 2330 | make_filter_range( |
| 2331 | enumerate, |
| 2332 | [](auto indexed_value) { return indexed_value.index() % 2 == 0; }), |
| 2333 | printer, [&](auto indexed_value) { printer << indexed_value.value(); }); |
| 2334 | printer << ") to (" ; |
| 2335 | // Odd entries are the upper bounds. |
| 2336 | llvm::interleaveComma( |
| 2337 | make_filter_range( |
| 2338 | enumerate, |
| 2339 | [](auto indexed_value) { return indexed_value.index() % 2 != 0; }), |
| 2340 | printer, [&](auto indexed_value) { printer << indexed_value.value(); }); |
| 2341 | printer << ")" ; |
| 2342 | } |
| 2343 | |
/// Verifies a `fir.insert_on_range` op.
///
/// The sequence operand must have a constant shape and size. `coor` must be
/// a non-empty, even-length flat list of (lower, upper) bound pairs. Range
/// bounds must be nonnegative, and the range must not be empty.
llvm::LogicalResult fir::InsertOnRangeOp::verify() {
  if (fir::hasDynamicSize(getSeq().getType()))
    return emitOpError("must have constant shape and size");
  mlir::DenseIntElementsAttr coorAttr = getCoor();
  // Rejects both odd-length lists and lists with fewer than one pair.
  if (coorAttr.size() < 2 || coorAttr.size() % 2 != 0)
    return emitOpError("has uneven number of values in ranges");
  // Emptiness is a lexicographic comparison of the lower-bound tuple against
  // the upper-bound tuple, starting from the LAST pair and walking backwards:
  //  - the first visited pair with lb < ub proves the range nonempty, and
  //    pairs visited after it are only checked for negative bounds;
  //  - a pair with lb > ub visited before such a pair makes the range empty;
  //  - pairs with lb == ub defer the decision to the next (earlier) pair.
  // NOTE(review): this presumably models a column-major linearized range
  // (last dimension most significant) - confirm against the lowering.
  bool rangeIsKnownToBeNonempty = false;
  for (auto i = coorAttr.getValues<std::int64_t>().end(),
            b = coorAttr.getValues<std::int64_t>().begin();
       i != b;) {
    // Step backwards over one (lb, ub) pair.
    int64_t ub = (*--i);
    int64_t lb = (*--i);
    if (lb < 0 || ub < 0)
      return emitOpError("negative range bound");
    if (rangeIsKnownToBeNonempty)
      continue;
    if (lb > ub)
      return emitOpError("empty range");
    rangeIsKnownToBeNonempty = lb < ub;
  }
  return mlir::success();
}
| 2367 | |
| 2368 | bool fir::InsertOnRangeOp::isFullRange() { |
| 2369 | auto extents = getType().getShape(); |
| 2370 | mlir::DenseIntElementsAttr indexes = getCoor(); |
| 2371 | if (indexes.size() / 2 != static_cast<int64_t>(extents.size())) |
| 2372 | return false; |
| 2373 | auto cur_index = indexes.value_begin<int64_t>(); |
| 2374 | for (unsigned i = 0; i < indexes.size(); i += 2) { |
| 2375 | if (*(cur_index++) != 0) |
| 2376 | return false; |
| 2377 | if (*(cur_index++) != extents[i / 2] - 1) |
| 2378 | return false; |
| 2379 | } |
| 2380 | return true; |
| 2381 | } |
| 2382 | |
| 2383 | //===----------------------------------------------------------------------===// |
| 2384 | // InsertValueOp |
| 2385 | //===----------------------------------------------------------------------===// |
| 2386 | |
| 2387 | static bool checkIsIntegerConstant(mlir::Attribute attr, std::int64_t conVal) { |
| 2388 | if (auto iattr = mlir::dyn_cast<mlir::IntegerAttr>(attr)) |
| 2389 | return iattr.getInt() == conVal; |
| 2390 | return false; |
| 2391 | } |
| 2392 | |
| 2393 | static bool isZero(mlir::Attribute a) { return checkIsIntegerConstant(attr: a, conVal: 0); } |
| 2394 | static bool isOne(mlir::Attribute a) { return checkIsIntegerConstant(attr: a, conVal: 1); } |
| 2395 | |
| 2396 | // Undo some complex patterns created in the front-end and turn them back into |
| 2397 | // complex ops. |
| 2398 | template <typename FltOp, typename CpxOp> |
| 2399 | struct UndoComplexPattern : public mlir::RewritePattern { |
| 2400 | UndoComplexPattern(mlir::MLIRContext *ctx) |
| 2401 | : mlir::RewritePattern("fir.insert_value" , 2, ctx) {} |
| 2402 | |
| 2403 | llvm::LogicalResult |
| 2404 | matchAndRewrite(mlir::Operation *op, |
| 2405 | mlir::PatternRewriter &rewriter) const override { |
| 2406 | auto insval = mlir::dyn_cast_or_null<fir::InsertValueOp>(op); |
| 2407 | if (!insval || !mlir::isa<mlir::ComplexType>(insval.getType())) |
| 2408 | return mlir::failure(); |
| 2409 | auto insval2 = mlir::dyn_cast_or_null<fir::InsertValueOp>( |
| 2410 | insval.getAdt().getDefiningOp()); |
| 2411 | if (!insval2) |
| 2412 | return mlir::failure(); |
| 2413 | auto binf = mlir::dyn_cast_or_null<FltOp>(insval.getVal().getDefiningOp()); |
| 2414 | auto binf2 = |
| 2415 | mlir::dyn_cast_or_null<FltOp>(insval2.getVal().getDefiningOp()); |
| 2416 | if (!binf || !binf2 || insval.getCoor().size() != 1 || |
| 2417 | !isOne(insval.getCoor()[0]) || insval2.getCoor().size() != 1 || |
| 2418 | !isZero(insval2.getCoor()[0])) |
| 2419 | return mlir::failure(); |
| 2420 | auto eai = mlir::dyn_cast_or_null<fir::ExtractValueOp>( |
| 2421 | binf.getLhs().getDefiningOp()); |
| 2422 | auto ebi = mlir::dyn_cast_or_null<fir::ExtractValueOp>( |
| 2423 | binf.getRhs().getDefiningOp()); |
| 2424 | auto ear = mlir::dyn_cast_or_null<fir::ExtractValueOp>( |
| 2425 | binf2.getLhs().getDefiningOp()); |
| 2426 | auto ebr = mlir::dyn_cast_or_null<fir::ExtractValueOp>( |
| 2427 | binf2.getRhs().getDefiningOp()); |
| 2428 | if (!eai || !ebi || !ear || !ebr || ear.getAdt() != eai.getAdt() || |
| 2429 | ebr.getAdt() != ebi.getAdt() || eai.getCoor().size() != 1 || |
| 2430 | !isOne(eai.getCoor()[0]) || ebi.getCoor().size() != 1 || |
| 2431 | !isOne(ebi.getCoor()[0]) || ear.getCoor().size() != 1 || |
| 2432 | !isZero(ear.getCoor()[0]) || ebr.getCoor().size() != 1 || |
| 2433 | !isZero(ebr.getCoor()[0])) |
| 2434 | return mlir::failure(); |
| 2435 | rewriter.replaceOpWithNewOp<CpxOp>(op, ear.getAdt(), ebr.getAdt()); |
| 2436 | return mlir::success(); |
| 2437 | } |
| 2438 | }; |
| 2439 | |
| 2440 | void fir::InsertValueOp::getCanonicalizationPatterns( |
| 2441 | mlir::RewritePatternSet &results, mlir::MLIRContext *context) { |
| 2442 | results.insert<UndoComplexPattern<mlir::arith::AddFOp, fir::AddcOp>, |
| 2443 | UndoComplexPattern<mlir::arith::SubFOp, fir::SubcOp>>(context); |
| 2444 | } |
| 2445 | |
| 2446 | //===----------------------------------------------------------------------===// |
| 2447 | // IterWhileOp |
| 2448 | //===----------------------------------------------------------------------===// |
| 2449 | |
| 2450 | void fir::IterWhileOp::build(mlir::OpBuilder &builder, |
| 2451 | mlir::OperationState &result, mlir::Value lb, |
| 2452 | mlir::Value ub, mlir::Value step, |
| 2453 | mlir::Value iterate, bool finalCountValue, |
| 2454 | mlir::ValueRange iterArgs, |
| 2455 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
| 2456 | result.addOperands({lb, ub, step, iterate}); |
| 2457 | if (finalCountValue) { |
| 2458 | result.addTypes(builder.getIndexType()); |
| 2459 | result.addAttribute(getFinalValueAttrNameStr(), builder.getUnitAttr()); |
| 2460 | } |
| 2461 | result.addTypes(iterate.getType()); |
| 2462 | result.addOperands(iterArgs); |
| 2463 | for (auto v : iterArgs) |
| 2464 | result.addTypes(v.getType()); |
| 2465 | mlir::Region *bodyRegion = result.addRegion(); |
| 2466 | bodyRegion->push_back(new mlir::Block{}); |
| 2467 | bodyRegion->front().addArgument(builder.getIndexType(), result.location); |
| 2468 | bodyRegion->front().addArgument(iterate.getType(), result.location); |
| 2469 | bodyRegion->front().addArguments( |
| 2470 | iterArgs.getTypes(), |
| 2471 | llvm::SmallVector<mlir::Location>(iterArgs.size(), result.location)); |
| 2472 | result.addAttributes(attributes); |
| 2473 | } |
| 2474 | |
/// Parses `fir.iterate_while`:
///
///   (%iv = %lb to %ub step %step) and (%ok = %init)
///   [iter_args(%a = %x, ...) -> (result-types)] | [-> (index, i1)]
///   [attributes {...}] region
///
/// The `final_value` unit attribute is set when the result type list also
/// contains the leading `index` result, i.e. it has two more entries than
/// the iter_args operands, or when the `-> (index, i1)` form is used.
/// Otherwise an `i1` result is prepended implicitly.
mlir::ParseResult fir::IterWhileOp::parse(mlir::OpAsmParser &parser,
                                          mlir::OperationState &result) {
  auto &builder = parser.getBuilder();
  mlir::OpAsmParser::Argument inductionVariable, iterateVar;
  mlir::OpAsmParser::UnresolvedOperand lb, ub, step, iterateInput;
  if (parser.parseLParen() || parser.parseArgument(inductionVariable) ||
      parser.parseEqual())
    return mlir::failure();

  // Parse loop bounds, then the `and (%ok = %init)` iterate-flag clause.
  auto indexType = builder.getIndexType();
  auto i1Type = builder.getIntegerType(1);
  if (parser.parseOperand(lb) ||
      parser.resolveOperand(lb, indexType, result.operands) ||
      parser.parseKeyword("to") || parser.parseOperand(ub) ||
      parser.resolveOperand(ub, indexType, result.operands) ||
      parser.parseKeyword("step") || parser.parseOperand(step) ||
      parser.parseRParen() ||
      parser.resolveOperand(step, indexType, result.operands) ||
      parser.parseKeyword("and") || parser.parseLParen() ||
      parser.parseArgument(iterateVar) || parser.parseEqual() ||
      parser.parseOperand(iterateInput) || parser.parseRParen() ||
      parser.resolveOperand(iterateInput, i1Type, result.operands))
    return mlir::failure();

  // True when the result list starts with the final induction value (index).
  auto prependCount = false;

  // Region arguments: the induction variable and the iterate flag first; any
  // iter_args are appended by parseAssignmentList below.
  llvm::SmallVector<mlir::OpAsmParser::Argument> regionArgs;
  regionArgs.push_back(inductionVariable);
  regionArgs.push_back(iterateVar);

  if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
    llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands;
    llvm::SmallVector<mlir::Type> regionTypes;
    // Parse assignment list and results type list.
    if (parser.parseAssignmentList(regionArgs, operands) ||
        parser.parseArrowTypeList(regionTypes))
      return mlir::failure();
    // Two extra types mean the list spells out the leading (index, i1).
    if (regionTypes.size() == operands.size() + 2)
      prependCount = true;
    llvm::ArrayRef<mlir::Type> resTypes = regionTypes;
    resTypes = prependCount ? resTypes.drop_front(2) : resTypes;
    // Resolve input operands.
    for (auto operandType : llvm::zip(operands, resTypes))
      if (parser.resolveOperand(std::get<0>(operandType),
                                std::get<1>(operandType), result.operands))
        return mlir::failure();
    if (prependCount) {
      result.addTypes(regionTypes);
    } else {
      // The i1 iterate result is implicit in the textual form.
      result.addTypes(i1Type);
      result.addTypes(resTypes);
    }
  } else if (succeeded(parser.parseOptionalArrow())) {
    llvm::SmallVector<mlir::Type> typeList;
    if (parser.parseLParen() || parser.parseTypeList(typeList) ||
        parser.parseRParen())
      return mlir::failure();
    // Type list must be "(index, i1)".
    if (typeList.size() != 2 || !mlir::isa<mlir::IndexType>(typeList[0]) ||
        !typeList[1].isSignlessInteger(1))
      return mlir::failure();
    result.addTypes(typeList);
    prependCount = true;
  } else {
    // No iter_args and no explicit types: the only result is the i1 flag.
    result.addTypes(i1Type);
  }

  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return mlir::failure();

  llvm::SmallVector<mlir::Type> argTypes;
  // The induction variable's block-argument type is index. When the final
  // value is returned, the leading index result type already supplies it.
  if (prependCount)
    result.addAttribute(IterWhileOp::getFinalValueAttrNameStr(),
                        builder.getUnitAttr());
  else
    argTypes.push_back(indexType);
  // Loop carried variables (including iterate)
  argTypes.append(result.types.begin(), result.types.end());
  // Parse the body region.
  auto *body = result.addRegion();
  if (regionArgs.size() != argTypes.size())
    return parser.emitError(
        parser.getNameLoc(),
        "mismatch in number of loop-carried values and defined values");

  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
    regionArgs[i].type = argTypes[i];

  if (parser.parseRegion(*body, regionArgs))
    return mlir::failure();

  fir::IterWhileOp::ensureTerminator(*body, builder, result.location);
  return mlir::success();
}
| 2573 | |
| 2574 | llvm::LogicalResult fir::IterWhileOp::verify() { |
| 2575 | // Check that the body defines as single block argument for the induction |
| 2576 | // variable. |
| 2577 | auto *body = getBody(); |
| 2578 | if (!body->getArgument(1).getType().isInteger(1)) |
| 2579 | return emitOpError( |
| 2580 | "expected body second argument to be an index argument for " |
| 2581 | "the induction variable" ); |
| 2582 | if (!body->getArgument(0).getType().isIndex()) |
| 2583 | return emitOpError( |
| 2584 | "expected body first argument to be an index argument for " |
| 2585 | "the induction variable" ); |
| 2586 | |
| 2587 | auto opNumResults = getNumResults(); |
| 2588 | if (getFinalValue()) { |
| 2589 | // Result type must be "(index, i1, ...)". |
| 2590 | if (!mlir::isa<mlir::IndexType>(getResult(0).getType())) |
| 2591 | return emitOpError("result #0 expected to be index" ); |
| 2592 | if (!getResult(1).getType().isSignlessInteger(1)) |
| 2593 | return emitOpError("result #1 expected to be i1" ); |
| 2594 | opNumResults--; |
| 2595 | } else { |
| 2596 | // iterate_while always returns the early exit induction value. |
| 2597 | // Result type must be "(i1, ...)" |
| 2598 | if (!getResult(0).getType().isSignlessInteger(1)) |
| 2599 | return emitOpError("result #0 expected to be i1" ); |
| 2600 | } |
| 2601 | if (opNumResults == 0) |
| 2602 | return mlir::failure(); |
| 2603 | if (getNumIterOperands() != opNumResults) |
| 2604 | return emitOpError( |
| 2605 | "mismatch in number of loop-carried values and defined values" ); |
| 2606 | if (getNumRegionIterArgs() != opNumResults) |
| 2607 | return emitOpError( |
| 2608 | "mismatch in number of basic block args and defined values" ); |
| 2609 | auto iterOperands = getIterOperands(); |
| 2610 | auto iterArgs = getRegionIterArgs(); |
| 2611 | auto opResults = getFinalValue() ? getResults().drop_front() : getResults(); |
| 2612 | unsigned i = 0u; |
| 2613 | for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) { |
| 2614 | if (std::get<0>(e).getType() != std::get<2>(e).getType()) |
| 2615 | return emitOpError() << "types mismatch between " << i |
| 2616 | << "th iter operand and defined value" ; |
| 2617 | if (std::get<1>(e).getType() != std::get<2>(e).getType()) |
| 2618 | return emitOpError() << "types mismatch between " << i |
| 2619 | << "th iter region arg and defined value" ; |
| 2620 | |
| 2621 | i++; |
| 2622 | } |
| 2623 | return mlir::success(); |
| 2624 | } |
| 2625 | |
| 2626 | void fir::IterWhileOp::print(mlir::OpAsmPrinter &p) { |
| 2627 | p << " (" << getInductionVar() << " = " << getLowerBound() << " to " |
| 2628 | << getUpperBound() << " step " << getStep() << ") and (" ; |
| 2629 | assert(hasIterOperands()); |
| 2630 | auto regionArgs = getRegionIterArgs(); |
| 2631 | auto operands = getIterOperands(); |
| 2632 | p << regionArgs.front() << " = " << *operands.begin() << ")" ; |
| 2633 | if (regionArgs.size() > 1) { |
| 2634 | p << " iter_args(" ; |
| 2635 | llvm::interleaveComma( |
| 2636 | llvm::zip(regionArgs.drop_front(), operands.drop_front()), p, |
| 2637 | [&](auto it) { p << std::get<0>(it) << " = " << std::get<1>(it); }); |
| 2638 | p << ") -> (" ; |
| 2639 | llvm::interleaveComma( |
| 2640 | llvm::drop_begin(getResultTypes(), getFinalValue() ? 0 : 1), p); |
| 2641 | p << ")" ; |
| 2642 | } else if (getFinalValue()) { |
| 2643 | p << " -> (" << getResultTypes() << ')'; |
| 2644 | } |
| 2645 | p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), |
| 2646 | {getFinalValueAttrNameStr()}); |
| 2647 | p << ' '; |
| 2648 | p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, |
| 2649 | /*printBlockTerminators=*/true); |
| 2650 | } |
| 2651 | |
| 2652 | llvm::SmallVector<mlir::Region *> fir::IterWhileOp::getLoopRegions() { |
| 2653 | return {&getRegion()}; |
| 2654 | } |
| 2655 | |
| 2656 | mlir::BlockArgument fir::IterWhileOp::iterArgToBlockArg(mlir::Value iterArg) { |
| 2657 | for (auto i : llvm::enumerate(getInitArgs())) |
| 2658 | if (iterArg == i.value()) |
| 2659 | return getRegion().front().getArgument(i.index() + 1); |
| 2660 | return {}; |
| 2661 | } |
| 2662 | |
| 2663 | void fir::IterWhileOp::resultToSourceOps( |
| 2664 | llvm::SmallVectorImpl<mlir::Value> &results, unsigned resultNum) { |
| 2665 | auto oper = getFinalValue() ? resultNum + 1 : resultNum; |
| 2666 | auto *term = getRegion().front().getTerminator(); |
| 2667 | if (oper < term->getNumOperands()) |
| 2668 | results.push_back(term->getOperand(oper)); |
| 2669 | } |
| 2670 | |
| 2671 | mlir::Value fir::IterWhileOp::blockArgToSourceOp(unsigned blockArgNum) { |
| 2672 | if (blockArgNum > 0 && blockArgNum <= getInitArgs().size()) |
| 2673 | return getInitArgs()[blockArgNum - 1]; |
| 2674 | return {}; |
| 2675 | } |
| 2676 | |
| 2677 | std::optional<llvm::MutableArrayRef<mlir::OpOperand>> |
| 2678 | fir::IterWhileOp::getYieldedValuesMutable() { |
| 2679 | auto *term = getRegion().front().getTerminator(); |
| 2680 | return getFinalValue() ? term->getOpOperands().drop_front() |
| 2681 | : term->getOpOperands(); |
| 2682 | } |
| 2683 | |
| 2684 | //===----------------------------------------------------------------------===// |
| 2685 | // LenParamIndexOp |
| 2686 | //===----------------------------------------------------------------------===// |
| 2687 | |
| 2688 | mlir::ParseResult fir::LenParamIndexOp::parse(mlir::OpAsmParser &parser, |
| 2689 | mlir::OperationState &result) { |
| 2690 | return parseFieldLikeOp<fir::LenType>(parser, result); |
| 2691 | } |
| 2692 | |
| 2693 | void fir::LenParamIndexOp::print(mlir::OpAsmPrinter &p) { |
| 2694 | printFieldLikeOp(p, *this); |
| 2695 | } |
| 2696 | |
| 2697 | void fir::LenParamIndexOp::build(mlir::OpBuilder &builder, |
| 2698 | mlir::OperationState &result, |
| 2699 | llvm::StringRef fieldName, mlir::Type recTy, |
| 2700 | mlir::ValueRange operands) { |
| 2701 | result.addAttribute(getFieldAttrName(), builder.getStringAttr(fieldName)); |
| 2702 | result.addAttribute(getTypeAttrName(), mlir::TypeAttr::get(recTy)); |
| 2703 | result.addOperands(operands); |
| 2704 | } |
| 2705 | |
| 2706 | llvm::SmallVector<mlir::Attribute> fir::LenParamIndexOp::getAttributes() { |
| 2707 | llvm::SmallVector<mlir::Attribute> attrs; |
| 2708 | attrs.push_back(getFieldIdAttr()); |
| 2709 | attrs.push_back(getOnTypeAttr()); |
| 2710 | return attrs; |
| 2711 | } |
| 2712 | |
| 2713 | //===----------------------------------------------------------------------===// |
| 2714 | // LoadOp |
| 2715 | //===----------------------------------------------------------------------===// |
| 2716 | |
| 2717 | void fir::LoadOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, |
| 2718 | mlir::Value refVal) { |
| 2719 | if (!refVal) { |
| 2720 | mlir::emitError(result.location, "LoadOp has null argument" ); |
| 2721 | return; |
| 2722 | } |
| 2723 | auto eleTy = fir::dyn_cast_ptrEleTy(refVal.getType()); |
| 2724 | if (!eleTy) { |
| 2725 | mlir::emitError(result.location, "not a memory reference type" ); |
| 2726 | return; |
| 2727 | } |
| 2728 | build(builder, result, eleTy, refVal); |
| 2729 | } |
| 2730 | |
| 2731 | void fir::LoadOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, |
| 2732 | mlir::Type resTy, mlir::Value refVal) { |
| 2733 | |
| 2734 | if (!refVal) { |
| 2735 | mlir::emitError(result.location, "LoadOp has null argument" ); |
| 2736 | return; |
| 2737 | } |
| 2738 | result.addOperands(refVal); |
| 2739 | result.addTypes(resTy); |
| 2740 | } |
| 2741 | |
| 2742 | mlir::ParseResult fir::LoadOp::getElementOf(mlir::Type &ele, mlir::Type ref) { |
| 2743 | if ((ele = fir::dyn_cast_ptrEleTy(ref))) |
| 2744 | return mlir::success(); |
| 2745 | return mlir::failure(); |
| 2746 | } |
| 2747 | |
| 2748 | mlir::ParseResult fir::LoadOp::parse(mlir::OpAsmParser &parser, |
| 2749 | mlir::OperationState &result) { |
| 2750 | mlir::Type type; |
| 2751 | mlir::OpAsmParser::UnresolvedOperand oper; |
| 2752 | if (parser.parseOperand(oper) || |
| 2753 | parser.parseOptionalAttrDict(result.attributes) || |
| 2754 | parser.parseColonType(type) || |
| 2755 | parser.resolveOperand(oper, type, result.operands)) |
| 2756 | return mlir::failure(); |
| 2757 | mlir::Type eleTy; |
| 2758 | if (fir::LoadOp::getElementOf(eleTy, type) || |
| 2759 | parser.addTypeToList(eleTy, result.types)) |
| 2760 | return mlir::failure(); |
| 2761 | return mlir::success(); |
| 2762 | } |
| 2763 | |
| 2764 | void fir::LoadOp::print(mlir::OpAsmPrinter &p) { |
| 2765 | p << ' '; |
| 2766 | p.printOperand(getMemref()); |
| 2767 | p.printOptionalAttrDict(getOperation()->getAttrs(), {}); |
| 2768 | p << " : " << getMemref().getType(); |
| 2769 | } |
| 2770 | |
| 2771 | void fir::LoadOp::getEffects( |
| 2772 | llvm::SmallVectorImpl< |
| 2773 | mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>> |
| 2774 | &effects) { |
| 2775 | effects.emplace_back(mlir::MemoryEffects::Read::get(), &getMemrefMutable(), |
| 2776 | mlir::SideEffects::DefaultResource::get()); |
| 2777 | addVolatileMemoryEffects({getMemref().getType()}, effects); |
| 2778 | } |
| 2779 | |
| 2780 | //===----------------------------------------------------------------------===// |
| 2781 | // DoLoopOp |
| 2782 | //===----------------------------------------------------------------------===// |
| 2783 | |
| 2784 | void fir::DoLoopOp::build(mlir::OpBuilder &builder, |
| 2785 | mlir::OperationState &result, mlir::Value lb, |
| 2786 | mlir::Value ub, mlir::Value step, bool unordered, |
| 2787 | bool finalCountValue, mlir::ValueRange iterArgs, |
| 2788 | mlir::ValueRange reduceOperands, |
| 2789 | llvm::ArrayRef<mlir::Attribute> reduceAttrs, |
| 2790 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
| 2791 | result.addOperands({lb, ub, step}); |
| 2792 | result.addOperands(reduceOperands); |
| 2793 | result.addOperands(iterArgs); |
| 2794 | result.addAttribute(getOperandSegmentSizeAttr(), |
| 2795 | builder.getDenseI32ArrayAttr( |
| 2796 | {1, 1, 1, static_cast<int32_t>(reduceOperands.size()), |
| 2797 | static_cast<int32_t>(iterArgs.size())})); |
| 2798 | if (finalCountValue) { |
| 2799 | result.addTypes(builder.getIndexType()); |
| 2800 | result.addAttribute(getFinalValueAttrName(result.name), |
| 2801 | builder.getUnitAttr()); |
| 2802 | } |
| 2803 | for (auto v : iterArgs) |
| 2804 | result.addTypes(v.getType()); |
| 2805 | mlir::Region *bodyRegion = result.addRegion(); |
| 2806 | bodyRegion->push_back(new mlir::Block{}); |
| 2807 | if (iterArgs.empty() && !finalCountValue) |
| 2808 | fir::DoLoopOp::ensureTerminator(*bodyRegion, builder, result.location); |
| 2809 | bodyRegion->front().addArgument(builder.getIndexType(), result.location); |
| 2810 | bodyRegion->front().addArguments( |
| 2811 | iterArgs.getTypes(), |
| 2812 | llvm::SmallVector<mlir::Location>(iterArgs.size(), result.location)); |
| 2813 | if (unordered) |
| 2814 | result.addAttribute(getUnorderedAttrName(result.name), |
| 2815 | builder.getUnitAttr()); |
| 2816 | if (!reduceAttrs.empty()) |
| 2817 | result.addAttribute(getReduceAttrsAttrName(result.name), |
| 2818 | builder.getArrayAttr(reduceAttrs)); |
| 2819 | result.addAttributes(attributes); |
| 2820 | } |
| 2821 | |
/// Parses the custom assembly form of fir.do_loop:
///
///   %iv = %lb to %ub step %step [unordered]
///       [reduce(#attr -> %var : type, ...)]
///       [iter_args(%arg = %init, ...) -> (result-types) | -> index]
///       [attributes {...}] region
///
/// Bounds and step are resolved as `index`. Operands are appended in the
/// order lb, ub, step, reduction variables, iter_args initial values. That
/// is the layout recorded in operandSegmentSizes.
mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
                                       mlir::OperationState &result) {
  auto &builder = parser.getBuilder();
  mlir::OpAsmParser::Argument inductionVariable;
  mlir::OpAsmParser::UnresolvedOperand lb, ub, step;
  // Parse the induction variable followed by '='.
  if (parser.parseArgument(inductionVariable) || parser.parseEqual())
    return mlir::failure();

  // Parse loop bounds.
  auto indexType = builder.getIndexType();
  if (parser.parseOperand(lb) ||
      parser.resolveOperand(lb, indexType, result.operands) ||
      parser.parseKeyword("to") || parser.parseOperand(ub) ||
      parser.resolveOperand(ub, indexType, result.operands) ||
      parser.parseKeyword("step") || parser.parseOperand(step) ||
      parser.resolveOperand(step, indexType, result.operands))
    return mlir::failure();

  if (mlir::succeeded(parser.parseOptionalKeyword("unordered")))
    result.addAttribute("unordered", builder.getUnitAttr());

  // Parse the reduction arguments.
  llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> reduceOperands;
  llvm::SmallVector<mlir::Type> reduceArgTypes;
  if (succeeded(parser.parseOptionalKeyword("reduce"))) {
    // Parse reduction attributes and variables.
    // Each parenthesized entry has the form `#attr -> %operand : type`.
    llvm::SmallVector<ReduceAttr> attributes;
    if (failed(parser.parseCommaSeparatedList(
            mlir::AsmParser::Delimiter::Paren, [&]() {
              if (parser.parseAttribute(attributes.emplace_back()) ||
                  parser.parseArrow() ||
                  parser.parseOperand(reduceOperands.emplace_back()) ||
                  parser.parseColonType(reduceArgTypes.emplace_back()))
                return mlir::failure();
              return mlir::success();
            })))
      return mlir::failure();
    // Resolve input operands.
    for (auto operand_type : llvm::zip(reduceOperands, reduceArgTypes))
      if (parser.resolveOperand(std::get<0>(operand_type),
                                std::get<1>(operand_type), result.operands))
        return mlir::failure();
    llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
                                                 attributes.end());
    result.addAttribute(getReduceAttrsAttrName(result.name),
                        builder.getArrayAttr(arrayAttr));
  }

  // Parse the optional initial iteration arguments.
  llvm::SmallVector<mlir::OpAsmParser::Argument> regionArgs;
  llvm::SmallVector<mlir::UnresolvedOperand> iterOperands;
  llvm::SmallVector<mlir::Type> argTypes;
  // Set when the first result is the loop's final count value rather than
  // the result of an iter_arg.
  bool prependCount = false;
  regionArgs.push_back(inductionVariable);

  if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
    // Parse assignment list and results type list.
    if (parser.parseAssignmentList(regionArgs, iterOperands) ||
        parser.parseArrowTypeList(result.types))
      return mlir::failure();
    // One result type more than there are iter_args means the first result
    // is the final count value. Counts that fit neither layout are rejected
    // by the region-argument count check below.
    if (result.types.size() == iterOperands.size() + 1)
      prependCount = true;
    // Resolve input operands.
    llvm::ArrayRef<mlir::Type> resTypes = result.types;
    for (auto operand_type : llvm::zip(
             iterOperands, prependCount ? resTypes.drop_front() : resTypes))
      if (parser.resolveOperand(std::get<0>(operand_type),
                                std::get<1>(operand_type), result.operands))
        return mlir::failure();
  } else if (succeeded(parser.parseOptionalArrow())) {
    // A bare `-> index` requests only the final count value.
    if (parser.parseKeyword("index"))
      return mlir::failure();
    result.types.push_back(indexType);
    prependCount = true;
  }

  // Set the operandSegmentSizes attribute
  result.addAttribute(getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(
                          {1, 1, 1, static_cast<int32_t>(reduceOperands.size()),
                           static_cast<int32_t>(iterOperands.size())}));

  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return mlir::failure();

  // Induction variable.
  // With a final count value, result.types already starts with that value's
  // type, and the append below makes it the induction variable's block
  // argument type. Otherwise an index type is pushed for the induction
  // variable here.
  if (prependCount)
    result.addAttribute(DoLoopOp::getFinalValueAttrName(result.name),
                        builder.getUnitAttr());
  else
    argTypes.push_back(indexType);
  // Loop carried variables
  argTypes.append(result.types.begin(), result.types.end());
  // Parse the body region.
  auto *body = result.addRegion();
  if (regionArgs.size() != argTypes.size())
    return parser.emitError(
        parser.getNameLoc(),
        "mismatch in number of loop-carried values and defined values");
  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
    regionArgs[i].type = argTypes[i];

  if (parser.parseRegion(*body, regionArgs))
    return mlir::failure();

  // The printer may elide the block terminator (see printBlockTerminators in
  // DoLoopOp::print), so make sure the body has one.
  DoLoopOp::ensureTerminator(*body, builder, result.location);

  return mlir::success();
}
| 2932 | |
| 2933 | fir::DoLoopOp fir::getForInductionVarOwner(mlir::Value val) { |
| 2934 | auto ivArg = mlir::dyn_cast<mlir::BlockArgument>(val); |
| 2935 | if (!ivArg) |
| 2936 | return {}; |
| 2937 | assert(ivArg.getOwner() && "unlinked block argument" ); |
| 2938 | auto *containingInst = ivArg.getOwner()->getParentOp(); |
| 2939 | return mlir::dyn_cast_or_null<fir::DoLoopOp>(containingInst); |
| 2940 | } |
| 2941 | |
| 2942 | // Lifted from loop.loop |
| 2943 | llvm::LogicalResult fir::DoLoopOp::verify() { |
| 2944 | // Check that the body defines as single block argument for the induction |
| 2945 | // variable. |
| 2946 | auto *body = getBody(); |
| 2947 | if (!body->getArgument(0).getType().isIndex()) |
| 2948 | return emitOpError( |
| 2949 | "expected body first argument to be an index argument for " |
| 2950 | "the induction variable" ); |
| 2951 | |
| 2952 | auto opNumResults = getNumResults(); |
| 2953 | if (opNumResults == 0) |
| 2954 | return mlir::success(); |
| 2955 | |
| 2956 | if (getFinalValue()) { |
| 2957 | if (getUnordered()) |
| 2958 | return emitOpError("unordered loop has no final value" ); |
| 2959 | opNumResults--; |
| 2960 | } |
| 2961 | if (getNumIterOperands() != opNumResults) |
| 2962 | return emitOpError( |
| 2963 | "mismatch in number of loop-carried values and defined values" ); |
| 2964 | if (getNumRegionIterArgs() != opNumResults) |
| 2965 | return emitOpError( |
| 2966 | "mismatch in number of basic block args and defined values" ); |
| 2967 | auto iterOperands = getIterOperands(); |
| 2968 | auto iterArgs = getRegionIterArgs(); |
| 2969 | auto opResults = getFinalValue() ? getResults().drop_front() : getResults(); |
| 2970 | unsigned i = 0u; |
| 2971 | for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) { |
| 2972 | if (std::get<0>(e).getType() != std::get<2>(e).getType()) |
| 2973 | return emitOpError() << "types mismatch between " << i |
| 2974 | << "th iter operand and defined value" ; |
| 2975 | if (std::get<1>(e).getType() != std::get<2>(e).getType()) |
| 2976 | return emitOpError() << "types mismatch between " << i |
| 2977 | << "th iter region arg and defined value" ; |
| 2978 | |
| 2979 | i++; |
| 2980 | } |
| 2981 | auto reduceAttrs = getReduceAttrsAttr(); |
| 2982 | if (getNumReduceOperands() != (reduceAttrs ? reduceAttrs.size() : 0)) |
| 2983 | return emitOpError( |
| 2984 | "mismatch in number of reduction variables and reduction attributes" ); |
| 2985 | return mlir::success(); |
| 2986 | } |
| 2987 | |
| 2988 | void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) { |
| 2989 | bool printBlockTerminators = false; |
| 2990 | p << ' ' << getInductionVar() << " = " << getLowerBound() << " to " |
| 2991 | << getUpperBound() << " step " << getStep(); |
| 2992 | if (getUnordered()) |
| 2993 | p << " unordered" ; |
| 2994 | if (hasReduceOperands()) { |
| 2995 | p << " reduce(" ; |
| 2996 | auto attrs = getReduceAttrsAttr(); |
| 2997 | auto operands = getReduceOperands(); |
| 2998 | llvm::interleaveComma(llvm::zip(attrs, operands), p, [&](auto it) { |
| 2999 | p << std::get<0>(it) << " -> " << std::get<1>(it) << " : " |
| 3000 | << std::get<1>(it).getType(); |
| 3001 | }); |
| 3002 | p << ')'; |
| 3003 | printBlockTerminators = true; |
| 3004 | } |
| 3005 | if (hasIterOperands()) { |
| 3006 | p << " iter_args(" ; |
| 3007 | auto regionArgs = getRegionIterArgs(); |
| 3008 | auto operands = getIterOperands(); |
| 3009 | llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) { |
| 3010 | p << std::get<0>(it) << " = " << std::get<1>(it); |
| 3011 | }); |
| 3012 | p << ") -> (" << getResultTypes() << ')'; |
| 3013 | printBlockTerminators = true; |
| 3014 | } else if (getFinalValue()) { |
| 3015 | p << " -> " << getResultTypes(); |
| 3016 | printBlockTerminators = true; |
| 3017 | } |
| 3018 | p.printOptionalAttrDictWithKeyword( |
| 3019 | (*this)->getAttrs(), |
| 3020 | {"unordered" , "finalValue" , "reduceAttrs" , "operandSegmentSizes" }); |
| 3021 | p << ' '; |
| 3022 | p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, |
| 3023 | printBlockTerminators); |
| 3024 | } |
| 3025 | |
| 3026 | llvm::SmallVector<mlir::Region *> fir::DoLoopOp::getLoopRegions() { |
| 3027 | return {&getRegion()}; |
| 3028 | } |
| 3029 | |
| 3030 | /// Translate a value passed as an iter_arg to the corresponding block |
| 3031 | /// argument in the body of the loop. |
| 3032 | mlir::BlockArgument fir::DoLoopOp::iterArgToBlockArg(mlir::Value iterArg) { |
| 3033 | for (auto i : llvm::enumerate(getInitArgs())) |
| 3034 | if (iterArg == i.value()) |
| 3035 | return getRegion().front().getArgument(i.index() + 1); |
| 3036 | return {}; |
| 3037 | } |
| 3038 | |
| 3039 | /// Translate the result vector (by index number) to the corresponding value |
| 3040 | /// to the `fir.result` Op. |
| 3041 | void fir::DoLoopOp::resultToSourceOps( |
| 3042 | llvm::SmallVectorImpl<mlir::Value> &results, unsigned resultNum) { |
| 3043 | auto oper = getFinalValue() ? resultNum + 1 : resultNum; |
| 3044 | auto *term = getRegion().front().getTerminator(); |
| 3045 | if (oper < term->getNumOperands()) |
| 3046 | results.push_back(term->getOperand(oper)); |
| 3047 | } |
| 3048 | |
| 3049 | /// Translate the block argument (by index number) to the corresponding value |
| 3050 | /// passed as an iter_arg to the parent DoLoopOp. |
| 3051 | mlir::Value fir::DoLoopOp::blockArgToSourceOp(unsigned blockArgNum) { |
| 3052 | if (blockArgNum > 0 && blockArgNum <= getInitArgs().size()) |
| 3053 | return getInitArgs()[blockArgNum - 1]; |
| 3054 | return {}; |
| 3055 | } |
| 3056 | |
| 3057 | std::optional<llvm::MutableArrayRef<mlir::OpOperand>> |
| 3058 | fir::DoLoopOp::getYieldedValuesMutable() { |
| 3059 | auto *term = getRegion().front().getTerminator(); |
| 3060 | return getFinalValue() ? term->getOpOperands().drop_front() |
| 3061 | : term->getOpOperands(); |
| 3062 | } |
| 3063 | |
| 3064 | //===----------------------------------------------------------------------===// |
| 3065 | // DTEntryOp |
| 3066 | //===----------------------------------------------------------------------===// |
| 3067 | |
| 3068 | mlir::ParseResult fir::DTEntryOp::parse(mlir::OpAsmParser &parser, |
| 3069 | mlir::OperationState &result) { |
| 3070 | llvm::StringRef methodName; |
| 3071 | // allow `methodName` or `"methodName"` |
| 3072 | if (failed(parser.parseOptionalKeyword(&methodName))) { |
| 3073 | mlir::StringAttr methodAttr; |
| 3074 | if (parser.parseAttribute(methodAttr, getMethodAttrName(result.name), |
| 3075 | result.attributes)) |
| 3076 | return mlir::failure(); |
| 3077 | } else { |
| 3078 | result.addAttribute(getMethodAttrName(result.name), |
| 3079 | parser.getBuilder().getStringAttr(methodName)); |
| 3080 | } |
| 3081 | mlir::SymbolRefAttr calleeAttr; |
| 3082 | if (parser.parseComma() || |
| 3083 | parser.parseAttribute(calleeAttr, fir::DTEntryOp::getProcAttrNameStr(), |
| 3084 | result.attributes)) |
| 3085 | return mlir::failure(); |
| 3086 | return mlir::success(); |
| 3087 | } |
| 3088 | |
| 3089 | void fir::DTEntryOp::print(mlir::OpAsmPrinter &p) { |
| 3090 | p << ' ' << getMethodAttr() << ", " << getProcAttr(); |
| 3091 | } |
| 3092 | |
| 3093 | //===----------------------------------------------------------------------===// |
| 3094 | // ReboxOp |
| 3095 | //===----------------------------------------------------------------------===// |
| 3096 | |
| 3097 | /// Get the scalar type related to a fir.box type. |
| 3098 | /// Example: return f32 for !fir.box<!fir.heap<!fir.array<?x?xf32>>. |
| 3099 | static mlir::Type getBoxScalarEleTy(mlir::Type boxTy) { |
| 3100 | auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(boxTy); |
| 3101 | if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(eleTy)) |
| 3102 | return seqTy.getEleTy(); |
| 3103 | return eleTy; |
| 3104 | } |
| 3105 | |
| 3106 | /// Test if \p t1 and \p t2 are compatible character types (if they can |
| 3107 | /// represent the same type at runtime). |
| 3108 | static bool areCompatibleCharacterTypes(mlir::Type t1, mlir::Type t2) { |
| 3109 | auto c1 = mlir::dyn_cast<fir::CharacterType>(t1); |
| 3110 | auto c2 = mlir::dyn_cast<fir::CharacterType>(t2); |
| 3111 | if (!c1 || !c2) |
| 3112 | return false; |
| 3113 | if (c1.hasDynamicLen() || c2.hasDynamicLen()) |
| 3114 | return true; |
| 3115 | return c1.getLen() == c2.getLen(); |
| 3116 | } |
| 3117 | |
| 3118 | llvm::LogicalResult fir::ReboxOp::verify() { |
| 3119 | auto inputBoxTy = getBox().getType(); |
| 3120 | if (fir::isa_unknown_size_box(inputBoxTy)) |
| 3121 | return emitOpError("box operand must not have unknown rank or type" ); |
| 3122 | auto outBoxTy = getType(); |
| 3123 | if (fir::isa_unknown_size_box(outBoxTy)) |
| 3124 | return emitOpError("result type must not have unknown rank or type" ); |
| 3125 | auto inputRank = fir::getBoxRank(inputBoxTy); |
| 3126 | auto inputEleTy = getBoxScalarEleTy(inputBoxTy); |
| 3127 | auto outRank = fir::getBoxRank(outBoxTy); |
| 3128 | auto outEleTy = getBoxScalarEleTy(outBoxTy); |
| 3129 | |
| 3130 | if (auto sliceVal = getSlice()) { |
| 3131 | // Slicing case |
| 3132 | if (mlir::cast<fir::SliceType>(sliceVal.getType()).getRank() != inputRank) |
| 3133 | return emitOpError("slice operand rank must match box operand rank" ); |
| 3134 | if (auto shapeVal = getShape()) { |
| 3135 | if (auto shiftTy = mlir::dyn_cast<fir::ShiftType>(shapeVal.getType())) { |
| 3136 | if (shiftTy.getRank() != inputRank) |
| 3137 | return emitOpError("shape operand and input box ranks must match " |
| 3138 | "when there is a slice" ); |
| 3139 | } else { |
| 3140 | return emitOpError("shape operand must absent or be a fir.shift " |
| 3141 | "when there is a slice" ); |
| 3142 | } |
| 3143 | } |
| 3144 | if (auto sliceOp = sliceVal.getDefiningOp()) { |
| 3145 | auto slicedRank = mlir::cast<fir::SliceOp>(sliceOp).getOutRank(); |
| 3146 | if (slicedRank != outRank) |
| 3147 | return emitOpError("result type rank and rank after applying slice " |
| 3148 | "operand must match" ); |
| 3149 | } |
| 3150 | } else { |
| 3151 | // Reshaping case |
| 3152 | unsigned shapeRank = inputRank; |
| 3153 | if (auto shapeVal = getShape()) { |
| 3154 | auto ty = shapeVal.getType(); |
| 3155 | if (auto shapeTy = mlir::dyn_cast<fir::ShapeType>(ty)) { |
| 3156 | shapeRank = shapeTy.getRank(); |
| 3157 | } else if (auto shapeShiftTy = mlir::dyn_cast<fir::ShapeShiftType>(ty)) { |
| 3158 | shapeRank = shapeShiftTy.getRank(); |
| 3159 | } else { |
| 3160 | auto shiftTy = mlir::cast<fir::ShiftType>(ty); |
| 3161 | shapeRank = shiftTy.getRank(); |
| 3162 | if (shapeRank != inputRank) |
| 3163 | return emitOpError("shape operand and input box ranks must match " |
| 3164 | "when the shape is a fir.shift" ); |
| 3165 | } |
| 3166 | } |
| 3167 | if (shapeRank != outRank) |
| 3168 | return emitOpError("result type and shape operand ranks must match" ); |
| 3169 | } |
| 3170 | |
| 3171 | if (inputEleTy != outEleTy) { |
| 3172 | // TODO: check that outBoxTy is a parent type of inputBoxTy for derived |
| 3173 | // types. |
| 3174 | // Character input and output types with constant length may be different if |
| 3175 | // there is a substring in the slice, otherwise, they must match. If any of |
| 3176 | // the types is a character with dynamic length, the other type can be any |
| 3177 | // character type. |
| 3178 | const bool typeCanMismatch = |
| 3179 | mlir::isa<fir::RecordType>(inputEleTy) || |
| 3180 | mlir::isa<mlir::NoneType>(outEleTy) || |
| 3181 | (mlir::isa<mlir::NoneType>(inputEleTy) && |
| 3182 | mlir::isa<fir::RecordType>(outEleTy)) || |
| 3183 | (getSlice() && mlir::isa<fir::CharacterType>(inputEleTy)) || |
| 3184 | (getSlice() && fir::isa_complex(inputEleTy) && |
| 3185 | mlir::isa<mlir::FloatType>(outEleTy)) || |
| 3186 | areCompatibleCharacterTypes(inputEleTy, outEleTy); |
| 3187 | if (!typeCanMismatch) |
| 3188 | return emitOpError( |
| 3189 | "op input and output element types must match for intrinsic types" ); |
| 3190 | } |
| 3191 | return mlir::success(); |
| 3192 | } |
| 3193 | |
| 3194 | //===----------------------------------------------------------------------===// |
| 3195 | // ReboxAssumedRankOp |
| 3196 | //===----------------------------------------------------------------------===// |
| 3197 | |
| 3198 | static bool areCompatibleAssumedRankElementType(mlir::Type inputEleTy, |
| 3199 | mlir::Type outEleTy) { |
| 3200 | if (inputEleTy == outEleTy) |
| 3201 | return true; |
| 3202 | // Output is unlimited polymorphic -> output dynamic type is the same as input |
| 3203 | // type. |
| 3204 | if (mlir::isa<mlir::NoneType>(Val: outEleTy)) |
| 3205 | return true; |
| 3206 | // Output/Input are derived types. Assuming input extends output type, output |
| 3207 | // dynamic type is the output static type, unless output is polymorphic. |
| 3208 | if (mlir::isa<fir::RecordType>(inputEleTy) && |
| 3209 | mlir::isa<fir::RecordType>(outEleTy)) |
| 3210 | return true; |
| 3211 | if (areCompatibleCharacterTypes(t1: inputEleTy, t2: outEleTy)) |
| 3212 | return true; |
| 3213 | return false; |
| 3214 | } |
| 3215 | |
| 3216 | llvm::LogicalResult fir::ReboxAssumedRankOp::verify() { |
| 3217 | mlir::Type inputType = getBox().getType(); |
| 3218 | if (!mlir::isa<fir::BaseBoxType>(inputType) && !fir::isBoxAddress(inputType)) |
| 3219 | return emitOpError("input must be a box or box address" ); |
| 3220 | mlir::Type inputEleTy = |
| 3221 | mlir::cast<fir::BaseBoxType>(fir::unwrapRefType(inputType)) |
| 3222 | .unwrapInnerType(); |
| 3223 | mlir::Type outEleTy = |
| 3224 | mlir::cast<fir::BaseBoxType>(getType()).unwrapInnerType(); |
| 3225 | if (!areCompatibleAssumedRankElementType(inputEleTy, outEleTy)) |
| 3226 | return emitOpError("input and output element types are incompatible" ); |
| 3227 | return mlir::success(); |
| 3228 | } |
| 3229 | |
| 3230 | void fir::ReboxAssumedRankOp::getEffects( |
| 3231 | llvm::SmallVectorImpl< |
| 3232 | mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>> |
| 3233 | &effects) { |
| 3234 | mlir::OpOperand &inputBox = getBoxMutable(); |
| 3235 | if (fir::isBoxAddress(inputBox.get().getType())) |
| 3236 | effects.emplace_back(mlir::MemoryEffects::Read::get(), &inputBox, |
| 3237 | mlir::SideEffects::DefaultResource::get()); |
| 3238 | } |
| 3239 | |
| 3240 | //===----------------------------------------------------------------------===// |
| 3241 | // ResultOp |
| 3242 | //===----------------------------------------------------------------------===// |
| 3243 | |
| 3244 | llvm::LogicalResult fir::ResultOp::verify() { |
| 3245 | auto *parentOp = (*this)->getParentOp(); |
| 3246 | auto results = parentOp->getResults(); |
| 3247 | auto operands = (*this)->getOperands(); |
| 3248 | |
| 3249 | if (parentOp->getNumResults() != getNumOperands()) |
| 3250 | return emitOpError() << "parent of result must have same arity" ; |
| 3251 | for (auto e : llvm::zip(results, operands)) |
| 3252 | if (std::get<0>(e).getType() != std::get<1>(e).getType()) |
| 3253 | return emitOpError() << "types mismatch between result op and its parent" ; |
| 3254 | return mlir::success(); |
| 3255 | } |
| 3256 | |
| 3257 | //===----------------------------------------------------------------------===// |
| 3258 | // SaveResultOp |
| 3259 | //===----------------------------------------------------------------------===// |
| 3260 | |
| 3261 | llvm::LogicalResult fir::SaveResultOp::verify() { |
| 3262 | auto resultType = getValue().getType(); |
| 3263 | if (resultType != fir::dyn_cast_ptrEleTy(getMemref().getType())) |
| 3264 | return emitOpError("value type must match memory reference type" ); |
| 3265 | if (fir::isa_unknown_size_box(resultType)) |
| 3266 | return emitOpError("cannot save !fir.box of unknown rank or type" ); |
| 3267 | |
| 3268 | if (mlir::isa<fir::BoxType>(resultType)) { |
| 3269 | if (getShape() || !getTypeparams().empty()) |
| 3270 | return emitOpError( |
| 3271 | "must not have shape or length operands if the value is a fir.box" ); |
| 3272 | return mlir::success(); |
| 3273 | } |
| 3274 | |
| 3275 | // fir.record or fir.array case. |
| 3276 | unsigned shapeTyRank = 0; |
| 3277 | if (auto shapeVal = getShape()) { |
| 3278 | auto shapeTy = shapeVal.getType(); |
| 3279 | if (auto s = mlir::dyn_cast<fir::ShapeType>(shapeTy)) |
| 3280 | shapeTyRank = s.getRank(); |
| 3281 | else |
| 3282 | shapeTyRank = mlir::cast<fir::ShapeShiftType>(shapeTy).getRank(); |
| 3283 | } |
| 3284 | |
| 3285 | auto eleTy = resultType; |
| 3286 | if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(resultType)) { |
| 3287 | if (seqTy.getDimension() != shapeTyRank) |
| 3288 | emitOpError("shape operand must be provided and have the value rank " |
| 3289 | "when the value is a fir.array" ); |
| 3290 | eleTy = seqTy.getEleTy(); |
| 3291 | } else { |
| 3292 | if (shapeTyRank != 0) |
| 3293 | emitOpError( |
| 3294 | "shape operand should only be provided if the value is a fir.array" ); |
| 3295 | } |
| 3296 | |
| 3297 | if (auto recTy = mlir::dyn_cast<fir::RecordType>(eleTy)) { |
| 3298 | if (recTy.getNumLenParams() != getTypeparams().size()) |
| 3299 | emitOpError("length parameters number must match with the value type " |
| 3300 | "length parameters" ); |
| 3301 | } else if (auto charTy = mlir::dyn_cast<fir::CharacterType>(eleTy)) { |
| 3302 | if (getTypeparams().size() > 1) |
| 3303 | emitOpError("no more than one length parameter must be provided for " |
| 3304 | "character value" ); |
| 3305 | } else { |
| 3306 | if (!getTypeparams().empty()) |
| 3307 | emitOpError("length parameters must not be provided for this value type" ); |
| 3308 | } |
| 3309 | |
| 3310 | return mlir::success(); |
| 3311 | } |
| 3312 | |
| 3313 | //===----------------------------------------------------------------------===// |
| 3314 | // IntegralSwitchTerminator |
| 3315 | //===----------------------------------------------------------------------===// |
| 3316 | static constexpr llvm::StringRef getCompareOffsetAttr() { |
| 3317 | return "compare_operand_offsets" ; |
| 3318 | } |
| 3319 | |
| 3320 | static constexpr llvm::StringRef getTargetOffsetAttr() { |
| 3321 | return "target_operand_offsets" ; |
| 3322 | } |
| 3323 | |
| 3324 | template <typename OpT> |
| 3325 | static llvm::LogicalResult verifyIntegralSwitchTerminator(OpT op) { |
| 3326 | if (!mlir::isa<mlir::IntegerType, mlir::IndexType, fir::IntegerType>( |
| 3327 | op.getSelector().getType())) |
| 3328 | return op.emitOpError("must be an integer" ); |
| 3329 | auto cases = |
| 3330 | op->template getAttrOfType<mlir::ArrayAttr>(op.getCasesAttr()).getValue(); |
| 3331 | auto count = op.getNumDest(); |
| 3332 | if (count == 0) |
| 3333 | return op.emitOpError("must have at least one successor" ); |
| 3334 | if (op.getNumConditions() != count) |
| 3335 | return op.emitOpError("number of cases and targets don't match" ); |
| 3336 | if (op.targetOffsetSize() != count) |
| 3337 | return op.emitOpError("incorrect number of successor operand groups" ); |
| 3338 | for (decltype(count) i = 0; i != count; ++i) { |
| 3339 | if (!mlir::isa<mlir::IntegerAttr, mlir::UnitAttr>(cases[i])) |
| 3340 | return op.emitOpError("invalid case alternative" ); |
| 3341 | } |
| 3342 | return mlir::success(); |
| 3343 | } |
| 3344 | |
| 3345 | static mlir::ParseResult parseIntegralSwitchTerminator( |
| 3346 | mlir::OpAsmParser &parser, mlir::OperationState &result, |
| 3347 | llvm::StringRef casesAttr, llvm::StringRef operandSegmentAttr) { |
| 3348 | mlir::OpAsmParser::UnresolvedOperand selector; |
| 3349 | mlir::Type type; |
| 3350 | if (fir::parseSelector(parser, result, selector, type)) |
| 3351 | return mlir::failure(); |
| 3352 | |
| 3353 | llvm::SmallVector<mlir::Attribute> ivalues; |
| 3354 | llvm::SmallVector<mlir::Block *> dests; |
| 3355 | llvm::SmallVector<llvm::SmallVector<mlir::Value>> destArgs; |
| 3356 | while (true) { |
| 3357 | mlir::Attribute ivalue; // Integer or Unit |
| 3358 | mlir::Block *dest; |
| 3359 | llvm::SmallVector<mlir::Value> destArg; |
| 3360 | mlir::NamedAttrList temp; |
| 3361 | if (parser.parseAttribute(result&: ivalue, attrName: "i" , attrs&: temp) || parser.parseComma() || |
| 3362 | parser.parseSuccessorAndUseList(dest, operands&: destArg)) |
| 3363 | return mlir::failure(); |
| 3364 | ivalues.push_back(Elt: ivalue); |
| 3365 | dests.push_back(Elt: dest); |
| 3366 | destArgs.push_back(Elt: destArg); |
| 3367 | if (!parser.parseOptionalRSquare()) |
| 3368 | break; |
| 3369 | if (parser.parseComma()) |
| 3370 | return mlir::failure(); |
| 3371 | } |
| 3372 | auto &bld = parser.getBuilder(); |
| 3373 | result.addAttribute(casesAttr, bld.getArrayAttr(ivalues)); |
| 3374 | llvm::SmallVector<int32_t> argOffs; |
| 3375 | int32_t sumArgs = 0; |
| 3376 | const auto count = dests.size(); |
| 3377 | for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { |
| 3378 | result.addSuccessors(successor: dests[i]); |
| 3379 | result.addOperands(newOperands: destArgs[i]); |
| 3380 | auto argSize = destArgs[i].size(); |
| 3381 | argOffs.push_back(Elt: argSize); |
| 3382 | sumArgs += argSize; |
| 3383 | } |
| 3384 | result.addAttribute(operandSegmentAttr, |
| 3385 | bld.getDenseI32ArrayAttr({1, 0, sumArgs})); |
| 3386 | result.addAttribute(getTargetOffsetAttr(), bld.getDenseI32ArrayAttr(argOffs)); |
| 3387 | return mlir::success(); |
| 3388 | } |
| 3389 | |
| 3390 | template <typename OpT> |
| 3391 | static void printIntegralSwitchTerminator(OpT op, mlir::OpAsmPrinter &p) { |
| 3392 | p << ' '; |
| 3393 | p.printOperand(op.getSelector()); |
| 3394 | p << " : " << op.getSelector().getType() << " [" ; |
| 3395 | auto cases = |
| 3396 | op->template getAttrOfType<mlir::ArrayAttr>(op.getCasesAttr()).getValue(); |
| 3397 | auto count = op.getNumConditions(); |
| 3398 | for (decltype(count) i = 0; i != count; ++i) { |
| 3399 | if (i) |
| 3400 | p << ", " ; |
| 3401 | auto &attr = cases[i]; |
| 3402 | if (auto intAttr = mlir::dyn_cast_or_null<mlir::IntegerAttr>(attr)) |
| 3403 | p << intAttr.getValue(); |
| 3404 | else |
| 3405 | p.printAttribute(attr); |
| 3406 | p << ", " ; |
| 3407 | op.printSuccessorAtIndex(p, i); |
| 3408 | } |
| 3409 | p << ']'; |
| 3410 | p.printOptionalAttrDict( |
| 3411 | attrs: op->getAttrs(), elidedAttrs: {op.getCasesAttr(), getCompareOffsetAttr(), |
| 3412 | getTargetOffsetAttr(), op.getOperandSegmentSizeAttr()}); |
| 3413 | } |
| 3414 | |
| 3415 | //===----------------------------------------------------------------------===// |
| 3416 | // SelectOp |
| 3417 | //===----------------------------------------------------------------------===// |
| 3418 | |
| 3419 | llvm::LogicalResult fir::SelectOp::verify() { |
| 3420 | return verifyIntegralSwitchTerminator(*this); |
| 3421 | } |
| 3422 | |
| 3423 | mlir::ParseResult fir::SelectOp::parse(mlir::OpAsmParser &parser, |
| 3424 | mlir::OperationState &result) { |
| 3425 | return parseIntegralSwitchTerminator(parser, result, getCasesAttr(), |
| 3426 | getOperandSegmentSizeAttr()); |
| 3427 | } |
| 3428 | |
| 3429 | void fir::SelectOp::print(mlir::OpAsmPrinter &p) { |
| 3430 | printIntegralSwitchTerminator(*this, p); |
| 3431 | } |
| 3432 | |
| 3433 | template <typename A, typename... AdditionalArgs> |
| 3434 | static A getSubOperands(unsigned pos, A allArgs, mlir::DenseI32ArrayAttr ranges, |
| 3435 | AdditionalArgs &&...additionalArgs) { |
| 3436 | unsigned start = 0; |
| 3437 | for (unsigned i = 0; i < pos; ++i) |
| 3438 | start += ranges[i]; |
| 3439 | return allArgs.slice(start, ranges[pos], |
| 3440 | std::forward<AdditionalArgs>(additionalArgs)...); |
| 3441 | } |
| 3442 | |
| 3443 | static mlir::MutableOperandRange |
| 3444 | getMutableSuccessorOperands(unsigned pos, mlir::MutableOperandRange operands, |
| 3445 | llvm::StringRef offsetAttr) { |
| 3446 | mlir::Operation *owner = operands.getOwner(); |
| 3447 | mlir::NamedAttribute targetOffsetAttr = |
| 3448 | *owner->getAttrDictionary().getNamed(offsetAttr); |
| 3449 | return getSubOperands( |
| 3450 | pos, operands, |
| 3451 | mlir::cast<mlir::DenseI32ArrayAttr>(targetOffsetAttr.getValue()), |
| 3452 | mlir::MutableOperandRange::OperandSegment(pos, targetOffsetAttr)); |
| 3453 | } |
| 3454 | |
| 3455 | std::optional<mlir::OperandRange> fir::SelectOp::getCompareOperands(unsigned) { |
| 3456 | return {}; |
| 3457 | } |
| 3458 | |
| 3459 | std::optional<llvm::ArrayRef<mlir::Value>> |
| 3460 | fir::SelectOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) { |
| 3461 | return {}; |
| 3462 | } |
| 3463 | |
| 3464 | mlir::SuccessorOperands fir::SelectOp::getSuccessorOperands(unsigned oper) { |
| 3465 | return mlir::SuccessorOperands(::getMutableSuccessorOperands( |
| 3466 | oper, getTargetArgsMutable(), getTargetOffsetAttr())); |
| 3467 | } |
| 3468 | |
| 3469 | std::optional<llvm::ArrayRef<mlir::Value>> |
| 3470 | fir::SelectOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, |
| 3471 | unsigned oper) { |
| 3472 | auto a = |
| 3473 | (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()); |
| 3474 | auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>( |
| 3475 | getOperandSegmentSizeAttr()); |
| 3476 | return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; |
| 3477 | } |
| 3478 | |
| 3479 | std::optional<mlir::ValueRange> |
| 3480 | fir::SelectOp::getSuccessorOperands(mlir::ValueRange operands, unsigned oper) { |
| 3481 | auto a = |
| 3482 | (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()); |
| 3483 | auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>( |
| 3484 | getOperandSegmentSizeAttr()); |
| 3485 | return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; |
| 3486 | } |
| 3487 | |
| 3488 | unsigned fir::SelectOp::targetOffsetSize() { |
| 3489 | return (*this) |
| 3490 | ->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()) |
| 3491 | .size(); |
| 3492 | } |
| 3493 | |
| 3494 | //===----------------------------------------------------------------------===// |
| 3495 | // SelectCaseOp |
| 3496 | //===----------------------------------------------------------------------===// |
| 3497 | |
| 3498 | std::optional<mlir::OperandRange> |
| 3499 | fir::SelectCaseOp::getCompareOperands(unsigned cond) { |
| 3500 | auto a = |
| 3501 | (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getCompareOffsetAttr()); |
| 3502 | return {getSubOperands(cond, getCompareArgs(), a)}; |
| 3503 | } |
| 3504 | |
| 3505 | std::optional<llvm::ArrayRef<mlir::Value>> |
| 3506 | fir::SelectCaseOp::getCompareOperands(llvm::ArrayRef<mlir::Value> operands, |
| 3507 | unsigned cond) { |
| 3508 | auto a = |
| 3509 | (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getCompareOffsetAttr()); |
| 3510 | auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>( |
| 3511 | getOperandSegmentSizeAttr()); |
| 3512 | return {getSubOperands(cond, getSubOperands(1, operands, segments), a)}; |
| 3513 | } |
| 3514 | |
| 3515 | std::optional<mlir::ValueRange> |
| 3516 | fir::SelectCaseOp::getCompareOperands(mlir::ValueRange operands, |
| 3517 | unsigned cond) { |
| 3518 | auto a = |
| 3519 | (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getCompareOffsetAttr()); |
| 3520 | auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>( |
| 3521 | getOperandSegmentSizeAttr()); |
| 3522 | return {getSubOperands(cond, getSubOperands(1, operands, segments), a)}; |
| 3523 | } |
| 3524 | |
| 3525 | mlir::SuccessorOperands fir::SelectCaseOp::getSuccessorOperands(unsigned oper) { |
| 3526 | return mlir::SuccessorOperands(::getMutableSuccessorOperands( |
| 3527 | oper, getTargetArgsMutable(), getTargetOffsetAttr())); |
| 3528 | } |
| 3529 | |
| 3530 | std::optional<llvm::ArrayRef<mlir::Value>> |
| 3531 | fir::SelectCaseOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, |
| 3532 | unsigned oper) { |
| 3533 | auto a = |
| 3534 | (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()); |
| 3535 | auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>( |
| 3536 | getOperandSegmentSizeAttr()); |
| 3537 | return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; |
| 3538 | } |
| 3539 | |
| 3540 | std::optional<mlir::ValueRange> |
| 3541 | fir::SelectCaseOp::getSuccessorOperands(mlir::ValueRange operands, |
| 3542 | unsigned oper) { |
| 3543 | auto a = |
| 3544 | (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()); |
| 3545 | auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>( |
| 3546 | getOperandSegmentSizeAttr()); |
| 3547 | return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; |
| 3548 | } |
| 3549 | |
| 3550 | // parser for fir.select_case Op |
| 3551 | mlir::ParseResult fir::SelectCaseOp::parse(mlir::OpAsmParser &parser, |
| 3552 | mlir::OperationState &result) { |
| 3553 | mlir::OpAsmParser::UnresolvedOperand selector; |
| 3554 | mlir::Type type; |
| 3555 | if (fir::parseSelector(parser, result, selector, type)) |
| 3556 | return mlir::failure(); |
| 3557 | |
| 3558 | llvm::SmallVector<mlir::Attribute> attrs; |
| 3559 | llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> opers; |
| 3560 | llvm::SmallVector<mlir::Block *> dests; |
| 3561 | llvm::SmallVector<llvm::SmallVector<mlir::Value>> destArgs; |
| 3562 | llvm::SmallVector<std::int32_t> argOffs; |
| 3563 | std::int32_t offSize = 0; |
| 3564 | while (true) { |
| 3565 | mlir::Attribute attr; |
| 3566 | mlir::Block *dest; |
| 3567 | llvm::SmallVector<mlir::Value> destArg; |
| 3568 | mlir::NamedAttrList temp; |
| 3569 | if (parser.parseAttribute(attr, "a" , temp) || isValidCaseAttr(attr) || |
| 3570 | parser.parseComma()) |
| 3571 | return mlir::failure(); |
| 3572 | attrs.push_back(attr); |
| 3573 | if (mlir::dyn_cast_or_null<mlir::UnitAttr>(attr)) { |
| 3574 | argOffs.push_back(0); |
| 3575 | } else if (mlir::dyn_cast_or_null<fir::ClosedIntervalAttr>(attr)) { |
| 3576 | mlir::OpAsmParser::UnresolvedOperand oper1; |
| 3577 | mlir::OpAsmParser::UnresolvedOperand oper2; |
| 3578 | if (parser.parseOperand(oper1) || parser.parseComma() || |
| 3579 | parser.parseOperand(oper2) || parser.parseComma()) |
| 3580 | return mlir::failure(); |
| 3581 | opers.push_back(oper1); |
| 3582 | opers.push_back(oper2); |
| 3583 | argOffs.push_back(2); |
| 3584 | offSize += 2; |
| 3585 | } else { |
| 3586 | mlir::OpAsmParser::UnresolvedOperand oper; |
| 3587 | if (parser.parseOperand(oper) || parser.parseComma()) |
| 3588 | return mlir::failure(); |
| 3589 | opers.push_back(oper); |
| 3590 | argOffs.push_back(1); |
| 3591 | ++offSize; |
| 3592 | } |
| 3593 | if (parser.parseSuccessorAndUseList(dest, destArg)) |
| 3594 | return mlir::failure(); |
| 3595 | dests.push_back(dest); |
| 3596 | destArgs.push_back(destArg); |
| 3597 | if (mlir::succeeded(parser.parseOptionalRSquare())) |
| 3598 | break; |
| 3599 | if (parser.parseComma()) |
| 3600 | return mlir::failure(); |
| 3601 | } |
| 3602 | result.addAttribute(fir::SelectCaseOp::getCasesAttr(), |
| 3603 | parser.getBuilder().getArrayAttr(attrs)); |
| 3604 | if (parser.resolveOperands(opers, type, result.operands)) |
| 3605 | return mlir::failure(); |
| 3606 | llvm::SmallVector<int32_t> targOffs; |
| 3607 | int32_t toffSize = 0; |
| 3608 | const auto count = dests.size(); |
| 3609 | for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { |
| 3610 | result.addSuccessors(dests[i]); |
| 3611 | result.addOperands(destArgs[i]); |
| 3612 | auto argSize = destArgs[i].size(); |
| 3613 | targOffs.push_back(argSize); |
| 3614 | toffSize += argSize; |
| 3615 | } |
| 3616 | auto &bld = parser.getBuilder(); |
| 3617 | result.addAttribute(fir::SelectCaseOp::getOperandSegmentSizeAttr(), |
| 3618 | bld.getDenseI32ArrayAttr({1, offSize, toffSize})); |
| 3619 | result.addAttribute(getCompareOffsetAttr(), |
| 3620 | bld.getDenseI32ArrayAttr(argOffs)); |
| 3621 | result.addAttribute(getTargetOffsetAttr(), |
| 3622 | bld.getDenseI32ArrayAttr(targOffs)); |
| 3623 | return mlir::success(); |
| 3624 | } |
| 3625 | |
| 3626 | void fir::SelectCaseOp::print(mlir::OpAsmPrinter &p) { |
| 3627 | p << ' '; |
| 3628 | p.printOperand(getSelector()); |
| 3629 | p << " : " << getSelector().getType() << " [" ; |
| 3630 | auto cases = |
| 3631 | getOperation()->getAttrOfType<mlir::ArrayAttr>(getCasesAttr()).getValue(); |
| 3632 | auto count = getNumConditions(); |
| 3633 | for (decltype(count) i = 0; i != count; ++i) { |
| 3634 | if (i) |
| 3635 | p << ", " ; |
| 3636 | p << cases[i] << ", " ; |
| 3637 | if (!mlir::isa<mlir::UnitAttr>(cases[i])) { |
| 3638 | auto caseArgs = *getCompareOperands(i); |
| 3639 | p.printOperand(*caseArgs.begin()); |
| 3640 | p << ", " ; |
| 3641 | if (mlir::isa<fir::ClosedIntervalAttr>(cases[i])) { |
| 3642 | p.printOperand(*(++caseArgs.begin())); |
| 3643 | p << ", " ; |
| 3644 | } |
| 3645 | } |
| 3646 | printSuccessorAtIndex(p, i); |
| 3647 | } |
| 3648 | p << ']'; |
| 3649 | p.printOptionalAttrDict(getOperation()->getAttrs(), |
| 3650 | {getCasesAttr(), getCompareOffsetAttr(), |
| 3651 | getTargetOffsetAttr(), getOperandSegmentSizeAttr()}); |
| 3652 | } |
| 3653 | |
| 3654 | unsigned fir::SelectCaseOp::compareOffsetSize() { |
| 3655 | return (*this) |
| 3656 | ->getAttrOfType<mlir::DenseI32ArrayAttr>(getCompareOffsetAttr()) |
| 3657 | .size(); |
| 3658 | } |
| 3659 | |
| 3660 | unsigned fir::SelectCaseOp::targetOffsetSize() { |
| 3661 | return (*this) |
| 3662 | ->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()) |
| 3663 | .size(); |
| 3664 | } |
| 3665 | |
| 3666 | void fir::SelectCaseOp::build(mlir::OpBuilder &builder, |
| 3667 | mlir::OperationState &result, |
| 3668 | mlir::Value selector, |
| 3669 | llvm::ArrayRef<mlir::Attribute> compareAttrs, |
| 3670 | llvm::ArrayRef<mlir::ValueRange> cmpOperands, |
| 3671 | llvm::ArrayRef<mlir::Block *> destinations, |
| 3672 | llvm::ArrayRef<mlir::ValueRange> destOperands, |
| 3673 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
| 3674 | result.addOperands(selector); |
| 3675 | result.addAttribute(getCasesAttr(), builder.getArrayAttr(compareAttrs)); |
| 3676 | llvm::SmallVector<int32_t> operOffs; |
| 3677 | int32_t operSize = 0; |
| 3678 | for (auto attr : compareAttrs) { |
| 3679 | if (mlir::isa<fir::ClosedIntervalAttr>(attr)) { |
| 3680 | operOffs.push_back(2); |
| 3681 | operSize += 2; |
| 3682 | } else if (mlir::isa<mlir::UnitAttr>(attr)) { |
| 3683 | operOffs.push_back(0); |
| 3684 | } else { |
| 3685 | operOffs.push_back(1); |
| 3686 | ++operSize; |
| 3687 | } |
| 3688 | } |
| 3689 | for (auto ops : cmpOperands) |
| 3690 | result.addOperands(ops); |
| 3691 | result.addAttribute(getCompareOffsetAttr(), |
| 3692 | builder.getDenseI32ArrayAttr(operOffs)); |
| 3693 | const auto count = destinations.size(); |
| 3694 | for (auto d : destinations) |
| 3695 | result.addSuccessors(d); |
| 3696 | const auto opCount = destOperands.size(); |
| 3697 | llvm::SmallVector<std::int32_t> argOffs; |
| 3698 | std::int32_t sumArgs = 0; |
| 3699 | for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { |
| 3700 | if (i < opCount) { |
| 3701 | result.addOperands(destOperands[i]); |
| 3702 | const auto argSz = destOperands[i].size(); |
| 3703 | argOffs.push_back(argSz); |
| 3704 | sumArgs += argSz; |
| 3705 | } else { |
| 3706 | argOffs.push_back(0); |
| 3707 | } |
| 3708 | } |
| 3709 | result.addAttribute(getOperandSegmentSizeAttr(), |
| 3710 | builder.getDenseI32ArrayAttr({1, operSize, sumArgs})); |
| 3711 | result.addAttribute(getTargetOffsetAttr(), |
| 3712 | builder.getDenseI32ArrayAttr(argOffs)); |
| 3713 | result.addAttributes(attributes); |
| 3714 | } |
| 3715 | |
| 3716 | /// This builder has a slightly simplified interface in that the list of |
| 3717 | /// operands need not be partitioned by the builder. Instead the operands are |
| 3718 | /// partitioned here, before being passed to the default builder. This |
| 3719 | /// partitioning is unchecked, so can go awry on bad input. |
| 3720 | void fir::SelectCaseOp::build(mlir::OpBuilder &builder, |
| 3721 | mlir::OperationState &result, |
| 3722 | mlir::Value selector, |
| 3723 | llvm::ArrayRef<mlir::Attribute> compareAttrs, |
| 3724 | llvm::ArrayRef<mlir::Value> cmpOpList, |
| 3725 | llvm::ArrayRef<mlir::Block *> destinations, |
| 3726 | llvm::ArrayRef<mlir::ValueRange> destOperands, |
| 3727 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
| 3728 | llvm::SmallVector<mlir::ValueRange> cmpOpers; |
| 3729 | auto iter = cmpOpList.begin(); |
| 3730 | for (auto &attr : compareAttrs) { |
| 3731 | if (mlir::isa<fir::ClosedIntervalAttr>(attr)) { |
| 3732 | cmpOpers.push_back(mlir::ValueRange({iter, iter + 2})); |
| 3733 | iter += 2; |
| 3734 | } else if (mlir::isa<mlir::UnitAttr>(attr)) { |
| 3735 | cmpOpers.push_back(mlir::ValueRange{}); |
| 3736 | } else { |
| 3737 | cmpOpers.push_back(mlir::ValueRange({iter, iter + 1})); |
| 3738 | ++iter; |
| 3739 | } |
| 3740 | } |
| 3741 | build(builder, result, selector, compareAttrs, cmpOpers, destinations, |
| 3742 | destOperands, attributes); |
| 3743 | } |
| 3744 | |
| 3745 | llvm::LogicalResult fir::SelectCaseOp::verify() { |
| 3746 | if (!mlir::isa<mlir::IntegerType, mlir::IndexType, fir::IntegerType, |
| 3747 | fir::LogicalType, fir::CharacterType>(getSelector().getType())) |
| 3748 | return emitOpError("must be an integer, character, or logical" ); |
| 3749 | auto cases = |
| 3750 | getOperation()->getAttrOfType<mlir::ArrayAttr>(getCasesAttr()).getValue(); |
| 3751 | auto count = getNumDest(); |
| 3752 | if (count == 0) |
| 3753 | return emitOpError("must have at least one successor" ); |
| 3754 | if (getNumConditions() != count) |
| 3755 | return emitOpError("number of conditions and successors don't match" ); |
| 3756 | if (compareOffsetSize() != count) |
| 3757 | return emitOpError("incorrect number of compare operand groups" ); |
| 3758 | if (targetOffsetSize() != count) |
| 3759 | return emitOpError("incorrect number of successor operand groups" ); |
| 3760 | for (decltype(count) i = 0; i != count; ++i) { |
| 3761 | auto &attr = cases[i]; |
| 3762 | if (!(mlir::isa<fir::PointIntervalAttr>(attr) || |
| 3763 | mlir::isa<fir::LowerBoundAttr>(attr) || |
| 3764 | mlir::isa<fir::UpperBoundAttr>(attr) || |
| 3765 | mlir::isa<fir::ClosedIntervalAttr>(attr) || |
| 3766 | mlir::isa<mlir::UnitAttr>(attr))) |
| 3767 | return emitOpError("incorrect select case attribute type" ); |
| 3768 | } |
| 3769 | return mlir::success(); |
| 3770 | } |
| 3771 | |
| 3772 | //===----------------------------------------------------------------------===// |
| 3773 | // SelectRankOp |
| 3774 | //===----------------------------------------------------------------------===// |
| 3775 | |
| 3776 | llvm::LogicalResult fir::SelectRankOp::verify() { |
| 3777 | return verifyIntegralSwitchTerminator(*this); |
| 3778 | } |
| 3779 | |
| 3780 | mlir::ParseResult fir::SelectRankOp::parse(mlir::OpAsmParser &parser, |
| 3781 | mlir::OperationState &result) { |
| 3782 | return parseIntegralSwitchTerminator(parser, result, getCasesAttr(), |
| 3783 | getOperandSegmentSizeAttr()); |
| 3784 | } |
| 3785 | |
| 3786 | void fir::SelectRankOp::print(mlir::OpAsmPrinter &p) { |
| 3787 | printIntegralSwitchTerminator(*this, p); |
| 3788 | } |
| 3789 | |
| 3790 | std::optional<mlir::OperandRange> |
| 3791 | fir::SelectRankOp::getCompareOperands(unsigned) { |
| 3792 | return {}; |
| 3793 | } |
| 3794 | |
| 3795 | std::optional<llvm::ArrayRef<mlir::Value>> |
| 3796 | fir::SelectRankOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) { |
| 3797 | return {}; |
| 3798 | } |
| 3799 | |
| 3800 | mlir::SuccessorOperands fir::SelectRankOp::getSuccessorOperands(unsigned oper) { |
| 3801 | return mlir::SuccessorOperands(::getMutableSuccessorOperands( |
| 3802 | oper, getTargetArgsMutable(), getTargetOffsetAttr())); |
| 3803 | } |
| 3804 | |
| 3805 | std::optional<llvm::ArrayRef<mlir::Value>> |
| 3806 | fir::SelectRankOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, |
| 3807 | unsigned oper) { |
| 3808 | auto a = |
| 3809 | (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()); |
| 3810 | auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>( |
| 3811 | getOperandSegmentSizeAttr()); |
| 3812 | return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; |
| 3813 | } |
| 3814 | |
| 3815 | std::optional<mlir::ValueRange> |
| 3816 | fir::SelectRankOp::getSuccessorOperands(mlir::ValueRange operands, |
| 3817 | unsigned oper) { |
| 3818 | auto a = |
| 3819 | (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()); |
| 3820 | auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>( |
| 3821 | getOperandSegmentSizeAttr()); |
| 3822 | return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; |
| 3823 | } |
| 3824 | |
| 3825 | unsigned fir::SelectRankOp::targetOffsetSize() { |
| 3826 | return (*this) |
| 3827 | ->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()) |
| 3828 | .size(); |
| 3829 | } |
| 3830 | |
| 3831 | //===----------------------------------------------------------------------===// |
| 3832 | // SelectTypeOp |
| 3833 | //===----------------------------------------------------------------------===// |
| 3834 | |
| 3835 | std::optional<mlir::OperandRange> |
| 3836 | fir::SelectTypeOp::getCompareOperands(unsigned) { |
| 3837 | return {}; |
| 3838 | } |
| 3839 | |
| 3840 | std::optional<llvm::ArrayRef<mlir::Value>> |
| 3841 | fir::SelectTypeOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) { |
| 3842 | return {}; |
| 3843 | } |
| 3844 | |
| 3845 | mlir::SuccessorOperands fir::SelectTypeOp::getSuccessorOperands(unsigned oper) { |
| 3846 | return mlir::SuccessorOperands(::getMutableSuccessorOperands( |
| 3847 | oper, getTargetArgsMutable(), getTargetOffsetAttr())); |
| 3848 | } |
| 3849 | |
| 3850 | std::optional<llvm::ArrayRef<mlir::Value>> |
| 3851 | fir::SelectTypeOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, |
| 3852 | unsigned oper) { |
| 3853 | auto a = |
| 3854 | (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()); |
| 3855 | auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>( |
| 3856 | getOperandSegmentSizeAttr()); |
| 3857 | return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; |
| 3858 | } |
| 3859 | |
| 3860 | std::optional<mlir::ValueRange> |
| 3861 | fir::SelectTypeOp::getSuccessorOperands(mlir::ValueRange operands, |
| 3862 | unsigned oper) { |
| 3863 | auto a = |
| 3864 | (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()); |
| 3865 | auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>( |
| 3866 | getOperandSegmentSizeAttr()); |
| 3867 | return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; |
| 3868 | } |
| 3869 | |
| 3870 | mlir::ParseResult fir::SelectTypeOp::parse(mlir::OpAsmParser &parser, |
| 3871 | mlir::OperationState &result) { |
| 3872 | mlir::OpAsmParser::UnresolvedOperand selector; |
| 3873 | mlir::Type type; |
| 3874 | if (fir::parseSelector(parser, result, selector, type)) |
| 3875 | return mlir::failure(); |
| 3876 | |
| 3877 | llvm::SmallVector<mlir::Attribute> attrs; |
| 3878 | llvm::SmallVector<mlir::Block *> dests; |
| 3879 | llvm::SmallVector<llvm::SmallVector<mlir::Value>> destArgs; |
| 3880 | while (true) { |
| 3881 | mlir::Attribute attr; |
| 3882 | mlir::Block *dest; |
| 3883 | llvm::SmallVector<mlir::Value> destArg; |
| 3884 | mlir::NamedAttrList temp; |
| 3885 | if (parser.parseAttribute(attr, "a" , temp) || parser.parseComma() || |
| 3886 | parser.parseSuccessorAndUseList(dest, destArg)) |
| 3887 | return mlir::failure(); |
| 3888 | attrs.push_back(attr); |
| 3889 | dests.push_back(dest); |
| 3890 | destArgs.push_back(destArg); |
| 3891 | if (mlir::succeeded(parser.parseOptionalRSquare())) |
| 3892 | break; |
| 3893 | if (parser.parseComma()) |
| 3894 | return mlir::failure(); |
| 3895 | } |
| 3896 | auto &bld = parser.getBuilder(); |
| 3897 | result.addAttribute(fir::SelectTypeOp::getCasesAttr(), |
| 3898 | bld.getArrayAttr(attrs)); |
| 3899 | llvm::SmallVector<int32_t> argOffs; |
| 3900 | int32_t offSize = 0; |
| 3901 | const auto count = dests.size(); |
| 3902 | for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { |
| 3903 | result.addSuccessors(dests[i]); |
| 3904 | result.addOperands(destArgs[i]); |
| 3905 | auto argSize = destArgs[i].size(); |
| 3906 | argOffs.push_back(argSize); |
| 3907 | offSize += argSize; |
| 3908 | } |
| 3909 | result.addAttribute(fir::SelectTypeOp::getOperandSegmentSizeAttr(), |
| 3910 | bld.getDenseI32ArrayAttr({1, 0, offSize})); |
| 3911 | result.addAttribute(getTargetOffsetAttr(), bld.getDenseI32ArrayAttr(argOffs)); |
| 3912 | return mlir::success(); |
| 3913 | } |
| 3914 | |
| 3915 | unsigned fir::SelectTypeOp::targetOffsetSize() { |
| 3916 | return (*this) |
| 3917 | ->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()) |
| 3918 | .size(); |
| 3919 | } |
| 3920 | |
| 3921 | void fir::SelectTypeOp::print(mlir::OpAsmPrinter &p) { |
| 3922 | p << ' '; |
| 3923 | p.printOperand(getSelector()); |
| 3924 | p << " : " << getSelector().getType() << " [" ; |
| 3925 | auto cases = |
| 3926 | getOperation()->getAttrOfType<mlir::ArrayAttr>(getCasesAttr()).getValue(); |
| 3927 | auto count = getNumConditions(); |
| 3928 | for (decltype(count) i = 0; i != count; ++i) { |
| 3929 | if (i) |
| 3930 | p << ", " ; |
| 3931 | p << cases[i] << ", " ; |
| 3932 | printSuccessorAtIndex(p, i); |
| 3933 | } |
| 3934 | p << ']'; |
| 3935 | p.printOptionalAttrDict(getOperation()->getAttrs(), |
| 3936 | {getCasesAttr(), getCompareOffsetAttr(), |
| 3937 | getTargetOffsetAttr(), |
| 3938 | fir::SelectTypeOp::getOperandSegmentSizeAttr()}); |
| 3939 | } |
| 3940 | |
| 3941 | llvm::LogicalResult fir::SelectTypeOp::verify() { |
| 3942 | if (!mlir::isa<fir::BaseBoxType>(getSelector().getType())) |
| 3943 | return emitOpError("must be a fir.class or fir.box type" ); |
| 3944 | if (auto boxType = mlir::dyn_cast<fir::BoxType>(getSelector().getType())) |
| 3945 | if (!mlir::isa<mlir::NoneType>(boxType.getEleTy())) |
| 3946 | return emitOpError("selector must be polymorphic" ); |
| 3947 | auto typeGuardAttr = getCases(); |
| 3948 | for (unsigned idx = 0; idx < typeGuardAttr.size(); ++idx) |
| 3949 | if (mlir::isa<mlir::UnitAttr>(typeGuardAttr[idx]) && |
| 3950 | idx != typeGuardAttr.size() - 1) |
| 3951 | return emitOpError("default must be the last attribute" ); |
| 3952 | auto count = getNumDest(); |
| 3953 | if (count == 0) |
| 3954 | return emitOpError("must have at least one successor" ); |
| 3955 | if (getNumConditions() != count) |
| 3956 | return emitOpError("number of conditions and successors don't match" ); |
| 3957 | if (targetOffsetSize() != count) |
| 3958 | return emitOpError("incorrect number of successor operand groups" ); |
| 3959 | for (unsigned i = 0; i != count; ++i) { |
| 3960 | if (!mlir::isa<fir::ExactTypeAttr, fir::SubclassAttr, mlir::UnitAttr>( |
| 3961 | typeGuardAttr[i])) |
| 3962 | return emitOpError("invalid type-case alternative" ); |
| 3963 | } |
| 3964 | return mlir::success(); |
| 3965 | } |
| 3966 | |
| 3967 | void fir::SelectTypeOp::build(mlir::OpBuilder &builder, |
| 3968 | mlir::OperationState &result, |
| 3969 | mlir::Value selector, |
| 3970 | llvm::ArrayRef<mlir::Attribute> typeOperands, |
| 3971 | llvm::ArrayRef<mlir::Block *> destinations, |
| 3972 | llvm::ArrayRef<mlir::ValueRange> destOperands, |
| 3973 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
| 3974 | result.addOperands(selector); |
| 3975 | result.addAttribute(getCasesAttr(), builder.getArrayAttr(typeOperands)); |
| 3976 | const auto count = destinations.size(); |
| 3977 | for (mlir::Block *dest : destinations) |
| 3978 | result.addSuccessors(dest); |
| 3979 | const auto opCount = destOperands.size(); |
| 3980 | llvm::SmallVector<int32_t> argOffs; |
| 3981 | int32_t sumArgs = 0; |
| 3982 | for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { |
| 3983 | if (i < opCount) { |
| 3984 | result.addOperands(destOperands[i]); |
| 3985 | const auto argSz = destOperands[i].size(); |
| 3986 | argOffs.push_back(argSz); |
| 3987 | sumArgs += argSz; |
| 3988 | } else { |
| 3989 | argOffs.push_back(0); |
| 3990 | } |
| 3991 | } |
| 3992 | result.addAttribute(getOperandSegmentSizeAttr(), |
| 3993 | builder.getDenseI32ArrayAttr({1, 0, sumArgs})); |
| 3994 | result.addAttribute(getTargetOffsetAttr(), |
| 3995 | builder.getDenseI32ArrayAttr(argOffs)); |
| 3996 | result.addAttributes(attributes); |
| 3997 | } |
| 3998 | |
| 3999 | //===----------------------------------------------------------------------===// |
| 4000 | // ShapeOp |
| 4001 | //===----------------------------------------------------------------------===// |
| 4002 | |
| 4003 | llvm::LogicalResult fir::ShapeOp::verify() { |
| 4004 | auto size = getExtents().size(); |
| 4005 | auto shapeTy = mlir::dyn_cast<fir::ShapeType>(getType()); |
| 4006 | assert(shapeTy && "must be a shape type" ); |
| 4007 | if (shapeTy.getRank() != size) |
| 4008 | return emitOpError("shape type rank mismatch" ); |
| 4009 | return mlir::success(); |
| 4010 | } |
| 4011 | |
| 4012 | void fir::ShapeOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, |
| 4013 | mlir::ValueRange extents) { |
| 4014 | auto type = fir::ShapeType::get(builder.getContext(), extents.size()); |
| 4015 | build(builder, result, type, extents); |
| 4016 | } |
| 4017 | |
| 4018 | //===----------------------------------------------------------------------===// |
| 4019 | // ShapeShiftOp |
| 4020 | //===----------------------------------------------------------------------===// |
| 4021 | |
| 4022 | llvm::LogicalResult fir::ShapeShiftOp::verify() { |
| 4023 | auto size = getPairs().size(); |
| 4024 | if (size < 2 || size > 16 * 2) |
| 4025 | return emitOpError("incorrect number of args" ); |
| 4026 | if (size % 2 != 0) |
| 4027 | return emitOpError("requires a multiple of 2 args" ); |
| 4028 | auto shapeTy = mlir::dyn_cast<fir::ShapeShiftType>(getType()); |
| 4029 | assert(shapeTy && "must be a shape shift type" ); |
| 4030 | if (shapeTy.getRank() * 2 != size) |
| 4031 | return emitOpError("shape type rank mismatch" ); |
| 4032 | return mlir::success(); |
| 4033 | } |
| 4034 | |
| 4035 | //===----------------------------------------------------------------------===// |
| 4036 | // ShiftOp |
| 4037 | //===----------------------------------------------------------------------===// |
| 4038 | |
| 4039 | llvm::LogicalResult fir::ShiftOp::verify() { |
| 4040 | auto size = getOrigins().size(); |
| 4041 | auto shiftTy = mlir::dyn_cast<fir::ShiftType>(getType()); |
| 4042 | assert(shiftTy && "must be a shift type" ); |
| 4043 | if (shiftTy.getRank() != size) |
| 4044 | return emitOpError("shift type rank mismatch" ); |
| 4045 | return mlir::success(); |
| 4046 | } |
| 4047 | |
| 4048 | //===----------------------------------------------------------------------===// |
| 4049 | // SliceOp |
| 4050 | //===----------------------------------------------------------------------===// |
| 4051 | |
| 4052 | void fir::SliceOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, |
| 4053 | mlir::ValueRange trips, mlir::ValueRange path, |
| 4054 | mlir::ValueRange substr) { |
| 4055 | const auto rank = trips.size() / 3; |
| 4056 | auto sliceTy = fir::SliceType::get(builder.getContext(), rank); |
| 4057 | build(builder, result, sliceTy, trips, path, substr); |
| 4058 | } |
| 4059 | |
| 4060 | /// Return the output rank of a slice op. The output rank must be between 1 and |
| 4061 | /// the rank of the array being sliced (inclusive). |
| 4062 | unsigned fir::SliceOp::getOutputRank(mlir::ValueRange triples) { |
| 4063 | unsigned rank = 0; |
| 4064 | if (!triples.empty()) { |
| 4065 | for (unsigned i = 1, end = triples.size(); i < end; i += 3) { |
| 4066 | auto *op = triples[i].getDefiningOp(); |
| 4067 | if (!mlir::isa_and_nonnull<fir::UndefOp>(op)) |
| 4068 | ++rank; |
| 4069 | } |
| 4070 | assert(rank > 0); |
| 4071 | } |
| 4072 | return rank; |
| 4073 | } |
| 4074 | |
| 4075 | llvm::LogicalResult fir::SliceOp::verify() { |
| 4076 | auto size = getTriples().size(); |
| 4077 | if (size < 3 || size > 16 * 3) |
| 4078 | return emitOpError("incorrect number of args for triple" ); |
| 4079 | if (size % 3 != 0) |
| 4080 | return emitOpError("requires a multiple of 3 args" ); |
| 4081 | auto sliceTy = mlir::dyn_cast<fir::SliceType>(getType()); |
| 4082 | assert(sliceTy && "must be a slice type" ); |
| 4083 | if (sliceTy.getRank() * 3 != size) |
| 4084 | return emitOpError("slice type rank mismatch" ); |
| 4085 | return mlir::success(); |
| 4086 | } |
| 4087 | |
| 4088 | //===----------------------------------------------------------------------===// |
| 4089 | // StoreOp |
| 4090 | //===----------------------------------------------------------------------===// |
| 4091 | |
| 4092 | mlir::Type fir::StoreOp::elementType(mlir::Type refType) { |
| 4093 | return fir::dyn_cast_ptrEleTy(refType); |
| 4094 | } |
| 4095 | |
| 4096 | mlir::ParseResult fir::StoreOp::parse(mlir::OpAsmParser &parser, |
| 4097 | mlir::OperationState &result) { |
| 4098 | mlir::Type type; |
| 4099 | mlir::OpAsmParser::UnresolvedOperand oper; |
| 4100 | mlir::OpAsmParser::UnresolvedOperand store; |
| 4101 | if (parser.parseOperand(oper) || parser.parseKeyword("to" ) || |
| 4102 | parser.parseOperand(store) || |
| 4103 | parser.parseOptionalAttrDict(result.attributes) || |
| 4104 | parser.parseColonType(type) || |
| 4105 | parser.resolveOperand(oper, fir::StoreOp::elementType(type), |
| 4106 | result.operands) || |
| 4107 | parser.resolveOperand(store, type, result.operands)) |
| 4108 | return mlir::failure(); |
| 4109 | return mlir::success(); |
| 4110 | } |
| 4111 | |
| 4112 | void fir::StoreOp::print(mlir::OpAsmPrinter &p) { |
| 4113 | p << ' '; |
| 4114 | p.printOperand(getValue()); |
| 4115 | p << " to " ; |
| 4116 | p.printOperand(getMemref()); |
| 4117 | p.printOptionalAttrDict(getOperation()->getAttrs(), {}); |
| 4118 | p << " : " << getMemref().getType(); |
| 4119 | } |
| 4120 | |
| 4121 | llvm::LogicalResult fir::StoreOp::verify() { |
| 4122 | if (getValue().getType() != fir::dyn_cast_ptrEleTy(getMemref().getType())) |
| 4123 | return emitOpError("store value type must match memory reference type" ); |
| 4124 | return mlir::success(); |
| 4125 | } |
| 4126 | |
| 4127 | void fir::StoreOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, |
| 4128 | mlir::Value value, mlir::Value memref) { |
| 4129 | build(builder, result, value, memref, {}); |
| 4130 | } |
| 4131 | |
| 4132 | void fir::StoreOp::getEffects( |
| 4133 | llvm::SmallVectorImpl< |
| 4134 | mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>> |
| 4135 | &effects) { |
| 4136 | effects.emplace_back(mlir::MemoryEffects::Write::get(), &getMemrefMutable(), |
| 4137 | mlir::SideEffects::DefaultResource::get()); |
| 4138 | addVolatileMemoryEffects({getMemref().getType()}, effects); |
| 4139 | } |
| 4140 | |
| 4141 | //===----------------------------------------------------------------------===// |
| 4142 | // CopyOp |
| 4143 | //===----------------------------------------------------------------------===// |
| 4144 | |
| 4145 | void fir::CopyOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, |
| 4146 | mlir::Value source, mlir::Value destination, |
| 4147 | bool noOverlap) { |
| 4148 | mlir::UnitAttr noOverlapAttr = |
| 4149 | noOverlap ? builder.getUnitAttr() : mlir::UnitAttr{}; |
| 4150 | build(builder, result, source, destination, noOverlapAttr); |
| 4151 | } |
| 4152 | |
| 4153 | llvm::LogicalResult fir::CopyOp::verify() { |
| 4154 | mlir::Type sourceType = fir::unwrapRefType(getSource().getType()); |
| 4155 | mlir::Type destinationType = fir::unwrapRefType(getDestination().getType()); |
| 4156 | if (sourceType != destinationType) |
| 4157 | return emitOpError("source and destination must have the same value type" ); |
| 4158 | return mlir::success(); |
| 4159 | } |
| 4160 | |
| 4161 | void fir::CopyOp::getEffects( |
| 4162 | llvm::SmallVectorImpl< |
| 4163 | mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>> |
| 4164 | &effects) { |
| 4165 | effects.emplace_back(mlir::MemoryEffects::Read::get(), &getSourceMutable(), |
| 4166 | mlir::SideEffects::DefaultResource::get()); |
| 4167 | effects.emplace_back(mlir::MemoryEffects::Write::get(), |
| 4168 | &getDestinationMutable(), |
| 4169 | mlir::SideEffects::DefaultResource::get()); |
| 4170 | addVolatileMemoryEffects({getDestination().getType(), getSource().getType()}, |
| 4171 | effects); |
| 4172 | } |
| 4173 | |
| 4174 | //===----------------------------------------------------------------------===// |
| 4175 | // StringLitOp |
| 4176 | //===----------------------------------------------------------------------===// |
| 4177 | |
| 4178 | inline fir::CharacterType::KindTy stringLitOpGetKind(fir::StringLitOp op) { |
| 4179 | auto eleTy = mlir::cast<fir::SequenceType>(op.getType()).getElementType(); |
| 4180 | return mlir::cast<fir::CharacterType>(eleTy).getFKind(); |
| 4181 | } |
| 4182 | |
| 4183 | bool fir::StringLitOp::isWideValue() { return stringLitOpGetKind(*this) != 1; } |
| 4184 | |
| 4185 | static mlir::NamedAttribute |
| 4186 | mkNamedIntegerAttr(mlir::OpBuilder &builder, llvm::StringRef name, int64_t v) { |
| 4187 | assert(v > 0); |
| 4188 | return builder.getNamedAttr( |
| 4189 | name, builder.getIntegerAttr(builder.getIntegerType(64), v)); |
| 4190 | } |
| 4191 | |
| 4192 | void fir::StringLitOp::build(mlir::OpBuilder &builder, |
| 4193 | mlir::OperationState &result, |
| 4194 | fir::CharacterType inType, llvm::StringRef val, |
| 4195 | std::optional<int64_t> len) { |
| 4196 | auto valAttr = builder.getNamedAttr(value(), builder.getStringAttr(val)); |
| 4197 | int64_t length = len ? *len : inType.getLen(); |
| 4198 | auto lenAttr = mkNamedIntegerAttr(builder, size(), length); |
| 4199 | result.addAttributes({valAttr, lenAttr}); |
| 4200 | result.addTypes(inType); |
| 4201 | } |
| 4202 | |
| 4203 | template <typename C> |
| 4204 | static mlir::ArrayAttr convertToArrayAttr(mlir::OpBuilder &builder, |
| 4205 | llvm::ArrayRef<C> xlist) { |
| 4206 | llvm::SmallVector<mlir::Attribute> attrs; |
| 4207 | auto ty = builder.getIntegerType(8 * sizeof(C)); |
| 4208 | for (auto ch : xlist) |
| 4209 | attrs.push_back(Elt: builder.getIntegerAttr(ty, ch)); |
| 4210 | return builder.getArrayAttr(attrs); |
| 4211 | } |
| 4212 | |
| 4213 | void fir::StringLitOp::build(mlir::OpBuilder &builder, |
| 4214 | mlir::OperationState &result, |
| 4215 | fir::CharacterType inType, |
| 4216 | llvm::ArrayRef<char> vlist, |
| 4217 | std::optional<std::int64_t> len) { |
| 4218 | auto valAttr = |
| 4219 | builder.getNamedAttr(xlist(), convertToArrayAttr(builder, vlist)); |
| 4220 | std::int64_t length = len ? *len : inType.getLen(); |
| 4221 | auto lenAttr = mkNamedIntegerAttr(builder, size(), length); |
| 4222 | result.addAttributes({valAttr, lenAttr}); |
| 4223 | result.addTypes(inType); |
| 4224 | } |
| 4225 | |
| 4226 | void fir::StringLitOp::build(mlir::OpBuilder &builder, |
| 4227 | mlir::OperationState &result, |
| 4228 | fir::CharacterType inType, |
| 4229 | llvm::ArrayRef<char16_t> vlist, |
| 4230 | std::optional<std::int64_t> len) { |
| 4231 | auto valAttr = |
| 4232 | builder.getNamedAttr(xlist(), convertToArrayAttr(builder, vlist)); |
| 4233 | std::int64_t length = len ? *len : inType.getLen(); |
| 4234 | auto lenAttr = mkNamedIntegerAttr(builder, size(), length); |
| 4235 | result.addAttributes({valAttr, lenAttr}); |
| 4236 | result.addTypes(inType); |
| 4237 | } |
| 4238 | |
| 4239 | void fir::StringLitOp::build(mlir::OpBuilder &builder, |
| 4240 | mlir::OperationState &result, |
| 4241 | fir::CharacterType inType, |
| 4242 | llvm::ArrayRef<char32_t> vlist, |
| 4243 | std::optional<std::int64_t> len) { |
| 4244 | auto valAttr = |
| 4245 | builder.getNamedAttr(xlist(), convertToArrayAttr(builder, vlist)); |
| 4246 | std::int64_t length = len ? *len : inType.getLen(); |
| 4247 | auto lenAttr = mkNamedIntegerAttr(builder, size(), length); |
| 4248 | result.addAttributes({valAttr, lenAttr}); |
| 4249 | result.addTypes(inType); |
| 4250 | } |
| 4251 | |
| 4252 | mlir::ParseResult fir::StringLitOp::parse(mlir::OpAsmParser &parser, |
| 4253 | mlir::OperationState &result) { |
| 4254 | auto &builder = parser.getBuilder(); |
| 4255 | mlir::Attribute val; |
| 4256 | mlir::NamedAttrList attrs; |
| 4257 | llvm::SMLoc trailingTypeLoc; |
| 4258 | if (parser.parseAttribute(val, "fake" , attrs)) |
| 4259 | return mlir::failure(); |
| 4260 | if (auto v = mlir::dyn_cast<mlir::StringAttr>(val)) |
| 4261 | result.attributes.push_back( |
| 4262 | builder.getNamedAttr(fir::StringLitOp::value(), v)); |
| 4263 | else if (auto v = mlir::dyn_cast<mlir::DenseElementsAttr>(val)) |
| 4264 | result.attributes.push_back( |
| 4265 | builder.getNamedAttr(fir::StringLitOp::xlist(), v)); |
| 4266 | else if (auto v = mlir::dyn_cast<mlir::ArrayAttr>(val)) |
| 4267 | result.attributes.push_back( |
| 4268 | builder.getNamedAttr(fir::StringLitOp::xlist(), v)); |
| 4269 | else |
| 4270 | return parser.emitError(parser.getCurrentLocation(), |
| 4271 | "found an invalid constant" ); |
| 4272 | mlir::IntegerAttr sz; |
| 4273 | mlir::Type type; |
| 4274 | if (parser.parseLParen() || |
| 4275 | parser.parseAttribute(sz, fir::StringLitOp::size(), result.attributes) || |
| 4276 | parser.parseRParen() || parser.getCurrentLocation(&trailingTypeLoc) || |
| 4277 | parser.parseColonType(type)) |
| 4278 | return mlir::failure(); |
| 4279 | auto charTy = mlir::dyn_cast<fir::CharacterType>(type); |
| 4280 | if (!charTy) |
| 4281 | return parser.emitError(trailingTypeLoc, "must have character type" ); |
| 4282 | type = fir::CharacterType::get(builder.getContext(), charTy.getFKind(), |
| 4283 | sz.getInt()); |
| 4284 | if (!type || parser.addTypesToList(type, result.types)) |
| 4285 | return mlir::failure(); |
| 4286 | return mlir::success(); |
| 4287 | } |
| 4288 | |
| 4289 | void fir::StringLitOp::print(mlir::OpAsmPrinter &p) { |
| 4290 | p << ' ' << getValue() << '('; |
| 4291 | p << mlir::cast<mlir::IntegerAttr>(getSize()).getValue() << ") : " ; |
| 4292 | p.printType(getType()); |
| 4293 | } |
| 4294 | |
| 4295 | llvm::LogicalResult fir::StringLitOp::verify() { |
| 4296 | if (mlir::cast<mlir::IntegerAttr>(getSize()).getValue().isNegative()) |
| 4297 | return emitOpError("size must be non-negative" ); |
| 4298 | if (auto xl = getOperation()->getAttr(fir::StringLitOp::xlist())) { |
| 4299 | if (auto xList = mlir::dyn_cast<mlir::ArrayAttr>(xl)) { |
| 4300 | for (auto a : xList) |
| 4301 | if (!mlir::isa<mlir::IntegerAttr>(a)) |
| 4302 | return emitOpError("values in initializer must be integers" ); |
| 4303 | } else if (mlir::isa<mlir::DenseElementsAttr>(xl)) { |
| 4304 | // do nothing |
| 4305 | } else { |
| 4306 | return emitOpError("has unexpected attribute" ); |
| 4307 | } |
| 4308 | } |
| 4309 | return mlir::success(); |
| 4310 | } |
| 4311 | |
| 4312 | //===----------------------------------------------------------------------===// |
| 4313 | // UnboxProcOp |
| 4314 | //===----------------------------------------------------------------------===// |
| 4315 | |
| 4316 | llvm::LogicalResult fir::UnboxProcOp::verify() { |
| 4317 | if (auto eleTy = fir::dyn_cast_ptrEleTy(getRefTuple().getType())) |
| 4318 | if (mlir::isa<mlir::TupleType>(eleTy)) |
| 4319 | return mlir::success(); |
| 4320 | return emitOpError("second output argument has bad type" ); |
| 4321 | } |
| 4322 | |
| 4323 | //===----------------------------------------------------------------------===// |
| 4324 | // IfOp |
| 4325 | //===----------------------------------------------------------------------===// |
| 4326 | |
| 4327 | void fir::IfOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, |
| 4328 | mlir::Value cond, bool withElseRegion) { |
| 4329 | build(builder, result, std::nullopt, cond, withElseRegion); |
| 4330 | } |
| 4331 | |
| 4332 | void fir::IfOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, |
| 4333 | mlir::TypeRange resultTypes, mlir::Value cond, |
| 4334 | bool withElseRegion) { |
| 4335 | result.addOperands(cond); |
| 4336 | result.addTypes(resultTypes); |
| 4337 | |
| 4338 | mlir::Region *thenRegion = result.addRegion(); |
| 4339 | thenRegion->push_back(new mlir::Block()); |
| 4340 | if (resultTypes.empty()) |
| 4341 | IfOp::ensureTerminator(*thenRegion, builder, result.location); |
| 4342 | |
| 4343 | mlir::Region *elseRegion = result.addRegion(); |
| 4344 | if (withElseRegion) { |
| 4345 | elseRegion->push_back(new mlir::Block()); |
| 4346 | if (resultTypes.empty()) |
| 4347 | IfOp::ensureTerminator(*elseRegion, builder, result.location); |
| 4348 | } |
| 4349 | } |
| 4350 | |
| 4351 | // These 3 functions copied from scf.if implementation. |
| 4352 | |
| 4353 | /// Given the region at `index`, or the parent operation if `index` is None, |
| 4354 | /// return the successor regions. These are the regions that may be selected |
| 4355 | /// during the flow of control. |
| 4356 | void fir::IfOp::getSuccessorRegions( |
| 4357 | mlir::RegionBranchPoint point, |
| 4358 | llvm::SmallVectorImpl<mlir::RegionSuccessor> ®ions) { |
| 4359 | // The `then` and the `else` region branch back to the parent operation. |
| 4360 | if (!point.isParent()) { |
| 4361 | regions.push_back(mlir::RegionSuccessor(getResults())); |
| 4362 | return; |
| 4363 | } |
| 4364 | |
| 4365 | // Don't consider the else region if it is empty. |
| 4366 | regions.push_back(mlir::RegionSuccessor(&getThenRegion())); |
| 4367 | |
| 4368 | // Don't consider the else region if it is empty. |
| 4369 | mlir::Region *elseRegion = &this->getElseRegion(); |
| 4370 | if (elseRegion->empty()) |
| 4371 | regions.push_back(mlir::RegionSuccessor()); |
| 4372 | else |
| 4373 | regions.push_back(mlir::RegionSuccessor(elseRegion)); |
| 4374 | } |
| 4375 | |
| 4376 | void fir::IfOp::getEntrySuccessorRegions( |
| 4377 | llvm::ArrayRef<mlir::Attribute> operands, |
| 4378 | llvm::SmallVectorImpl<mlir::RegionSuccessor> ®ions) { |
| 4379 | FoldAdaptor adaptor(operands); |
| 4380 | auto boolAttr = |
| 4381 | mlir::dyn_cast_or_null<mlir::BoolAttr>(adaptor.getCondition()); |
| 4382 | if (!boolAttr || boolAttr.getValue()) |
| 4383 | regions.emplace_back(&getThenRegion()); |
| 4384 | |
| 4385 | // If the else region is empty, execution continues after the parent op. |
| 4386 | if (!boolAttr || !boolAttr.getValue()) { |
| 4387 | if (!getElseRegion().empty()) |
| 4388 | regions.emplace_back(&getElseRegion()); |
| 4389 | else |
| 4390 | regions.emplace_back(getResults()); |
| 4391 | } |
| 4392 | } |
| 4393 | |
| 4394 | void fir::IfOp::getRegionInvocationBounds( |
| 4395 | llvm::ArrayRef<mlir::Attribute> operands, |
| 4396 | llvm::SmallVectorImpl<mlir::InvocationBounds> &invocationBounds) { |
| 4397 | if (auto cond = mlir::dyn_cast_or_null<mlir::BoolAttr>(operands[0])) { |
| 4398 | // If the condition is known, then one region is known to be executed once |
| 4399 | // and the other zero times. |
| 4400 | invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0); |
| 4401 | invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1); |
| 4402 | } else { |
| 4403 | // Non-constant condition. Each region may be executed 0 or 1 times. |
| 4404 | invocationBounds.assign(2, {0, 1}); |
| 4405 | } |
| 4406 | } |
| 4407 | |
| 4408 | mlir::ParseResult fir::IfOp::parse(mlir::OpAsmParser &parser, |
| 4409 | mlir::OperationState &result) { |
| 4410 | result.regions.reserve(2); |
| 4411 | mlir::Region *thenRegion = result.addRegion(); |
| 4412 | mlir::Region *elseRegion = result.addRegion(); |
| 4413 | |
| 4414 | auto &builder = parser.getBuilder(); |
| 4415 | mlir::OpAsmParser::UnresolvedOperand cond; |
| 4416 | mlir::Type i1Type = builder.getIntegerType(1); |
| 4417 | if (parser.parseOperand(cond) || |
| 4418 | parser.resolveOperand(cond, i1Type, result.operands)) |
| 4419 | return mlir::failure(); |
| 4420 | |
| 4421 | if (parser.parseOptionalArrowTypeList(result.types)) |
| 4422 | return mlir::failure(); |
| 4423 | |
| 4424 | if (parser.parseRegion(*thenRegion, {}, {})) |
| 4425 | return mlir::failure(); |
| 4426 | fir::IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), |
| 4427 | result.location); |
| 4428 | |
| 4429 | if (mlir::succeeded(parser.parseOptionalKeyword("else" ))) { |
| 4430 | if (parser.parseRegion(*elseRegion, {}, {})) |
| 4431 | return mlir::failure(); |
| 4432 | fir::IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), |
| 4433 | result.location); |
| 4434 | } |
| 4435 | |
| 4436 | // Parse the optional attribute list. |
| 4437 | if (parser.parseOptionalAttrDict(result.attributes)) |
| 4438 | return mlir::failure(); |
| 4439 | return mlir::success(); |
| 4440 | } |
| 4441 | |
| 4442 | llvm::LogicalResult fir::IfOp::verify() { |
| 4443 | if (getNumResults() != 0 && getElseRegion().empty()) |
| 4444 | return emitOpError("must have an else block if defining values" ); |
| 4445 | |
| 4446 | return mlir::success(); |
| 4447 | } |
| 4448 | |
| 4449 | void fir::IfOp::print(mlir::OpAsmPrinter &p) { |
| 4450 | bool printBlockTerminators = false; |
| 4451 | p << ' ' << getCondition(); |
| 4452 | if (!getResults().empty()) { |
| 4453 | p << " -> (" << getResultTypes() << ')'; |
| 4454 | printBlockTerminators = true; |
| 4455 | } |
| 4456 | p << ' '; |
| 4457 | p.printRegion(getThenRegion(), /*printEntryBlockArgs=*/false, |
| 4458 | printBlockTerminators); |
| 4459 | |
| 4460 | // Print the 'else' regions if it exists and has a block. |
| 4461 | auto &otherReg = getElseRegion(); |
| 4462 | if (!otherReg.empty()) { |
| 4463 | p << " else " ; |
| 4464 | p.printRegion(otherReg, /*printEntryBlockArgs=*/false, |
| 4465 | printBlockTerminators); |
| 4466 | } |
| 4467 | p.printOptionalAttrDict((*this)->getAttrs()); |
| 4468 | } |
| 4469 | |
| 4470 | void fir::IfOp::resultToSourceOps(llvm::SmallVectorImpl<mlir::Value> &results, |
| 4471 | unsigned resultNum) { |
| 4472 | auto *term = getThenRegion().front().getTerminator(); |
| 4473 | if (resultNum < term->getNumOperands()) |
| 4474 | results.push_back(term->getOperand(resultNum)); |
| 4475 | term = getElseRegion().front().getTerminator(); |
| 4476 | if (resultNum < term->getNumOperands()) |
| 4477 | results.push_back(term->getOperand(resultNum)); |
| 4478 | } |
| 4479 | |
| 4480 | //===----------------------------------------------------------------------===// |
| 4481 | // BoxOffsetOp |
| 4482 | //===----------------------------------------------------------------------===// |
| 4483 | |
| 4484 | llvm::LogicalResult fir::BoxOffsetOp::verify() { |
| 4485 | auto boxType = mlir::dyn_cast_or_null<fir::BaseBoxType>( |
| 4486 | fir::dyn_cast_ptrEleTy(getBoxRef().getType())); |
| 4487 | mlir::Type boxCharType; |
| 4488 | if (!boxType) { |
| 4489 | boxCharType = mlir::dyn_cast_or_null<fir::BoxCharType>( |
| 4490 | fir::dyn_cast_ptrEleTy(getBoxRef().getType())); |
| 4491 | if (!boxCharType) |
| 4492 | return emitOpError("box_ref operand must have !fir.ref<!fir.box<T>> or " |
| 4493 | "!fir.ref<!fir.boxchar<k>> type" ); |
| 4494 | if (getField() == fir::BoxFieldAttr::derived_type) |
| 4495 | return emitOpError("cannot address derived_type field of a fir.boxchar" ); |
| 4496 | } |
| 4497 | if (getField() != fir::BoxFieldAttr::base_addr && |
| 4498 | getField() != fir::BoxFieldAttr::derived_type) |
| 4499 | return emitOpError("cannot address provided field" ); |
| 4500 | if (getField() == fir::BoxFieldAttr::derived_type) { |
| 4501 | if (!fir::boxHasAddendum(boxType)) |
| 4502 | return emitOpError("can only address derived_type field of derived type " |
| 4503 | "or unlimited polymorphic fir.box" ); |
| 4504 | } |
| 4505 | return mlir::success(); |
| 4506 | } |
| 4507 | |
| 4508 | void fir::BoxOffsetOp::build(mlir::OpBuilder &builder, |
| 4509 | mlir::OperationState &result, mlir::Value boxRef, |
| 4510 | fir::BoxFieldAttr field) { |
| 4511 | mlir::Type valueType = |
| 4512 | fir::unwrapPassByRefType(fir::unwrapRefType(boxRef.getType())); |
| 4513 | mlir::Type resultType = valueType; |
| 4514 | if (field == fir::BoxFieldAttr::base_addr) |
| 4515 | resultType = fir::LLVMPointerType::get(fir::ReferenceType::get(valueType)); |
| 4516 | else if (field == fir::BoxFieldAttr::derived_type) |
| 4517 | resultType = fir::LLVMPointerType::get( |
| 4518 | fir::TypeDescType::get(fir::unwrapSequenceType(valueType))); |
| 4519 | build(builder, result, {resultType}, boxRef, field); |
| 4520 | } |
| 4521 | |
| 4522 | //===----------------------------------------------------------------------===// |
| 4523 | |
| 4524 | mlir::ParseResult fir::isValidCaseAttr(mlir::Attribute attr) { |
| 4525 | if (mlir::isa<mlir::UnitAttr, fir::ClosedIntervalAttr, fir::PointIntervalAttr, |
| 4526 | fir::LowerBoundAttr, fir::UpperBoundAttr>(attr)) |
| 4527 | return mlir::success(); |
| 4528 | return mlir::failure(); |
| 4529 | } |
| 4530 | |
| 4531 | unsigned fir::getCaseArgumentOffset(llvm::ArrayRef<mlir::Attribute> cases, |
| 4532 | unsigned dest) { |
| 4533 | unsigned o = 0; |
| 4534 | for (unsigned i = 0; i < dest; ++i) { |
| 4535 | auto &attr = cases[i]; |
| 4536 | if (!mlir::dyn_cast_or_null<mlir::UnitAttr>(attr)) { |
| 4537 | ++o; |
| 4538 | if (mlir::dyn_cast_or_null<fir::ClosedIntervalAttr>(attr)) |
| 4539 | ++o; |
| 4540 | } |
| 4541 | } |
| 4542 | return o; |
| 4543 | } |
| 4544 | |
| 4545 | mlir::ParseResult |
| 4546 | fir::parseSelector(mlir::OpAsmParser &parser, mlir::OperationState &result, |
| 4547 | mlir::OpAsmParser::UnresolvedOperand &selector, |
| 4548 | mlir::Type &type) { |
| 4549 | if (parser.parseOperand(selector) || parser.parseColonType(type) || |
| 4550 | parser.resolveOperand(selector, type, result.operands) || |
| 4551 | parser.parseLSquare()) |
| 4552 | return mlir::failure(); |
| 4553 | return mlir::success(); |
| 4554 | } |
| 4555 | |
| 4556 | mlir::func::FuncOp fir::createFuncOp(mlir::Location loc, mlir::ModuleOp module, |
| 4557 | llvm::StringRef name, |
| 4558 | mlir::FunctionType type, |
| 4559 | llvm::ArrayRef<mlir::NamedAttribute> attrs, |
| 4560 | const mlir::SymbolTable *symbolTable) { |
| 4561 | if (symbolTable) |
| 4562 | if (auto f = symbolTable->lookup<mlir::func::FuncOp>(name)) { |
| 4563 | #ifdef EXPENSIVE_CHECKS |
| 4564 | assert(f == module.lookupSymbol<mlir::func::FuncOp>(name) && |
| 4565 | "symbolTable and module out of sync" ); |
| 4566 | #endif |
| 4567 | return f; |
| 4568 | } |
| 4569 | if (auto f = module.lookupSymbol<mlir::func::FuncOp>(name)) |
| 4570 | return f; |
| 4571 | mlir::OpBuilder modBuilder(module.getBodyRegion()); |
| 4572 | modBuilder.setInsertionPointToEnd(module.getBody()); |
| 4573 | auto result = modBuilder.create<mlir::func::FuncOp>(loc, name, type, attrs); |
| 4574 | result.setVisibility(mlir::SymbolTable::Visibility::Private); |
| 4575 | return result; |
| 4576 | } |
| 4577 | |
| 4578 | fir::GlobalOp fir::createGlobalOp(mlir::Location loc, mlir::ModuleOp module, |
| 4579 | llvm::StringRef name, mlir::Type type, |
| 4580 | llvm::ArrayRef<mlir::NamedAttribute> attrs, |
| 4581 | const mlir::SymbolTable *symbolTable) { |
| 4582 | if (symbolTable) |
| 4583 | if (auto g = symbolTable->lookup<fir::GlobalOp>(name)) { |
| 4584 | #ifdef EXPENSIVE_CHECKS |
| 4585 | assert(g == module.lookupSymbol<fir::GlobalOp>(name) && |
| 4586 | "symbolTable and module out of sync" ); |
| 4587 | #endif |
| 4588 | return g; |
| 4589 | } |
| 4590 | if (auto g = module.lookupSymbol<fir::GlobalOp>(name)) |
| 4591 | return g; |
| 4592 | mlir::OpBuilder modBuilder(module.getBodyRegion()); |
| 4593 | auto result = modBuilder.create<fir::GlobalOp>(loc, name, type, attrs); |
| 4594 | result.setVisibility(mlir::SymbolTable::Visibility::Private); |
| 4595 | return result; |
| 4596 | } |
| 4597 | |
| 4598 | bool fir::hasHostAssociationArgument(mlir::func::FuncOp func) { |
| 4599 | if (auto allArgAttrs = func.getAllArgAttrs()) |
| 4600 | for (auto attr : allArgAttrs) |
| 4601 | if (auto dict = mlir::dyn_cast_or_null<mlir::DictionaryAttr>(attr)) |
| 4602 | if (dict.get(fir::getHostAssocAttrName())) |
| 4603 | return true; |
| 4604 | return false; |
| 4605 | } |
| 4606 | |
| 4607 | // Test if value's definition has the specified set of |
| 4608 | // attributeNames. The value's definition is one of the operations |
| 4609 | // that are able to carry the Fortran variable attributes, e.g. |
| 4610 | // fir.alloca or fir.allocmem. Function arguments may also represent |
| 4611 | // value definitions and carry relevant attributes. |
| 4612 | // |
| 4613 | // If it is not possible to reach the limited set of definition |
| 4614 | // entities from the given value, then the function will return |
| 4615 | // std::nullopt. Otherwise, the definition is known and the return |
| 4616 | // value is computed as: |
| 4617 | // * if checkAny is true, then the function will return true |
| 4618 | // iff any of the attributeNames attributes is set on the definition. |
| 4619 | // * if checkAny is false, then the function will return true |
| 4620 | // iff all of the attributeNames attributes are set on the definition. |
| 4621 | static std::optional<bool> |
| 4622 | valueCheckFirAttributes(mlir::Value value, |
| 4623 | llvm::ArrayRef<llvm::StringRef> attributeNames, |
| 4624 | bool checkAny) { |
| 4625 | auto testAttributeSets = [&](llvm::ArrayRef<mlir::NamedAttribute> setAttrs, |
| 4626 | llvm::ArrayRef<llvm::StringRef> checkAttrs) { |
| 4627 | if (checkAny) { |
| 4628 | // Return true iff any of checkAttrs attributes is present |
| 4629 | // in setAttrs set. |
| 4630 | for (llvm::StringRef checkAttrName : checkAttrs) |
| 4631 | if (llvm::any_of(Range&: setAttrs, P: [&](mlir::NamedAttribute setAttr) { |
| 4632 | return setAttr.getName() == checkAttrName; |
| 4633 | })) |
| 4634 | return true; |
| 4635 | |
| 4636 | return false; |
| 4637 | } |
| 4638 | |
| 4639 | // Return true iff all attributes from checkAttrs are present |
| 4640 | // in setAttrs set. |
| 4641 | for (mlir::StringRef checkAttrName : checkAttrs) |
| 4642 | if (llvm::none_of(Range&: setAttrs, P: [&](mlir::NamedAttribute setAttr) { |
| 4643 | return setAttr.getName() == checkAttrName; |
| 4644 | })) |
| 4645 | return false; |
| 4646 | |
| 4647 | return true; |
| 4648 | }; |
| 4649 | // If this is a fir.box that was loaded, the fir attributes will be on the |
| 4650 | // related fir.ref<fir.box> creation. |
| 4651 | if (mlir::isa<fir::BoxType>(value.getType())) |
| 4652 | if (auto definingOp = value.getDefiningOp()) |
| 4653 | if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(definingOp)) |
| 4654 | value = loadOp.getMemref(); |
| 4655 | // If this is a function argument, look in the argument attributes. |
| 4656 | if (auto blockArg = mlir::dyn_cast<mlir::BlockArgument>(Val&: value)) { |
| 4657 | if (blockArg.getOwner() && blockArg.getOwner()->isEntryBlock()) |
| 4658 | if (auto funcOp = mlir::dyn_cast<mlir::func::FuncOp>( |
| 4659 | blockArg.getOwner()->getParentOp())) |
| 4660 | return testAttributeSets( |
| 4661 | mlir::cast<mlir::FunctionOpInterface>(*funcOp).getArgAttrs( |
| 4662 | blockArg.getArgNumber()), |
| 4663 | attributeNames); |
| 4664 | |
| 4665 | // If it is not a function argument, the attributes are unknown. |
| 4666 | return std::nullopt; |
| 4667 | } |
| 4668 | |
| 4669 | if (auto definingOp = value.getDefiningOp()) { |
| 4670 | // If this is an allocated value, look at the allocation attributes. |
| 4671 | if (mlir::isa<fir::AllocMemOp>(definingOp) || |
| 4672 | mlir::isa<fir::AllocaOp>(definingOp)) |
| 4673 | return testAttributeSets(definingOp->getAttrs(), attributeNames); |
| 4674 | // If this is an imported global, look at AddrOfOp and GlobalOp attributes. |
| 4675 | // Both operations are looked at because use/host associated variable (the |
| 4676 | // AddrOfOp) can have ASYNCHRONOUS/VOLATILE attributes even if the ultimate |
| 4677 | // entity (the globalOp) does not have them. |
| 4678 | if (auto addressOfOp = mlir::dyn_cast<fir::AddrOfOp>(definingOp)) { |
| 4679 | if (testAttributeSets(addressOfOp->getAttrs(), attributeNames)) |
| 4680 | return true; |
| 4681 | if (auto module = definingOp->getParentOfType<mlir::ModuleOp>()) |
| 4682 | if (auto globalOp = |
| 4683 | module.lookupSymbol<fir::GlobalOp>(addressOfOp.getSymbol())) |
| 4684 | return testAttributeSets(globalOp->getAttrs(), attributeNames); |
| 4685 | } |
| 4686 | } |
| 4687 | // TODO: Construct associated entities attributes. Decide where the fir |
| 4688 | // attributes must be placed/looked for in this case. |
| 4689 | return std::nullopt; |
| 4690 | } |
| 4691 | |
| 4692 | bool fir::valueMayHaveFirAttributes( |
| 4693 | mlir::Value value, llvm::ArrayRef<llvm::StringRef> attributeNames) { |
| 4694 | std::optional<bool> mayHaveAttr = |
| 4695 | valueCheckFirAttributes(value, attributeNames, /*checkAny=*/true); |
| 4696 | return mayHaveAttr.value_or(true); |
| 4697 | } |
| 4698 | |
| 4699 | bool fir::valueHasFirAttribute(mlir::Value value, |
| 4700 | llvm::StringRef attributeName) { |
| 4701 | std::optional<bool> mayHaveAttr = |
| 4702 | valueCheckFirAttributes(value, {attributeName}, /*checkAny=*/false); |
| 4703 | return mayHaveAttr.value_or(false); |
| 4704 | } |
| 4705 | |
| 4706 | bool fir::anyFuncArgsHaveAttr(mlir::func::FuncOp func, llvm::StringRef attr) { |
| 4707 | for (unsigned i = 0, end = func.getNumArguments(); i < end; ++i) |
| 4708 | if (func.getArgAttr(i, attr)) |
| 4709 | return true; |
| 4710 | return false; |
| 4711 | } |
| 4712 | |
| 4713 | std::optional<std::int64_t> fir::getIntIfConstant(mlir::Value value) { |
| 4714 | if (auto *definingOp = value.getDefiningOp()) { |
| 4715 | if (auto cst = mlir::dyn_cast<mlir::arith::ConstantOp>(definingOp)) |
| 4716 | if (auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(cst.getValue())) |
| 4717 | return intAttr.getInt(); |
| 4718 | if (auto llConstOp = mlir::dyn_cast<mlir::LLVM::ConstantOp>(definingOp)) |
| 4719 | if (auto attr = mlir::dyn_cast<mlir::IntegerAttr>(llConstOp.getValue())) |
| 4720 | return attr.getValue().getSExtValue(); |
| 4721 | } |
| 4722 | return {}; |
| 4723 | } |
| 4724 | |
| 4725 | bool fir::isDummyArgument(mlir::Value v) { |
| 4726 | auto blockArg{mlir::dyn_cast<mlir::BlockArgument>(v)}; |
| 4727 | if (!blockArg) { |
| 4728 | auto defOp = v.getDefiningOp(); |
| 4729 | if (defOp) { |
| 4730 | if (auto declareOp = mlir::dyn_cast<fir::DeclareOp>(defOp)) |
| 4731 | if (declareOp.getDummyScope()) |
| 4732 | return true; |
| 4733 | } |
| 4734 | return false; |
| 4735 | } |
| 4736 | |
| 4737 | auto *owner{blockArg.getOwner()}; |
| 4738 | return owner->isEntryBlock() && |
| 4739 | mlir::isa<mlir::FunctionOpInterface>(owner->getParentOp()); |
| 4740 | } |
| 4741 | |
| 4742 | mlir::Type fir::applyPathToType(mlir::Type eleTy, mlir::ValueRange path) { |
| 4743 | for (auto i = path.begin(), end = path.end(); eleTy && i < end;) { |
| 4744 | eleTy = llvm::TypeSwitch<mlir::Type, mlir::Type>(eleTy) |
| 4745 | .Case<fir::RecordType>([&](fir::RecordType ty) { |
| 4746 | if (auto *op = (*i++).getDefiningOp()) { |
| 4747 | if (auto off = mlir::dyn_cast<fir::FieldIndexOp>(op)) |
| 4748 | return ty.getType(off.getFieldName()); |
| 4749 | if (auto off = mlir::dyn_cast<mlir::arith::ConstantOp>(op)) |
| 4750 | return ty.getType(fir::toInt(off)); |
| 4751 | } |
| 4752 | return mlir::Type{}; |
| 4753 | }) |
| 4754 | .Case<fir::SequenceType>([&](fir::SequenceType ty) { |
| 4755 | bool valid = true; |
| 4756 | const auto rank = ty.getDimension(); |
| 4757 | for (std::remove_const_t<decltype(rank)> ii = 0; |
| 4758 | valid && ii < rank; ++ii) |
| 4759 | valid = i < end && fir::isa_integer((*i++).getType()); |
| 4760 | return valid ? ty.getEleTy() : mlir::Type{}; |
| 4761 | }) |
| 4762 | .Case<mlir::TupleType>([&](mlir::TupleType ty) { |
| 4763 | if (auto *op = (*i++).getDefiningOp()) |
| 4764 | if (auto off = mlir::dyn_cast<mlir::arith::ConstantOp>(op)) |
| 4765 | return ty.getType(fir::toInt(off)); |
| 4766 | return mlir::Type{}; |
| 4767 | }) |
| 4768 | .Case<mlir::ComplexType>([&](mlir::ComplexType ty) { |
| 4769 | if (fir::isa_integer((*i++).getType())) |
| 4770 | return ty.getElementType(); |
| 4771 | return mlir::Type{}; |
| 4772 | }) |
| 4773 | .Default([&](const auto &) { return mlir::Type{}; }); |
| 4774 | } |
| 4775 | return eleTy; |
| 4776 | } |
| 4777 | |
| 4778 | bool fir::reboxPreservesContinuity(fir::ReboxOp rebox, bool checkWhole) { |
| 4779 | // If slicing is not involved, then the rebox does not affect |
| 4780 | // the continuity of the array. |
| 4781 | auto sliceArg = rebox.getSlice(); |
| 4782 | if (!sliceArg) |
| 4783 | return true; |
| 4784 | |
| 4785 | if (auto sliceOp = |
| 4786 | mlir::dyn_cast_or_null<fir::SliceOp>(sliceArg.getDefiningOp())) { |
| 4787 | if (sliceOp.getFields().empty() && sliceOp.getSubstr().empty()) { |
| 4788 | // TODO: generalize code for the triples analysis with |
| 4789 | // hlfir::designatePreservesContinuity, especially when |
| 4790 | // recognition of the whole dimension slices is added. |
| 4791 | auto triples = sliceOp.getTriples(); |
| 4792 | assert((triples.size() % 3) == 0 && "invalid triples size" ); |
| 4793 | |
| 4794 | // A slice with step=1 in the innermost dimension preserves |
| 4795 | // the continuity of the array in the innermost dimension. |
| 4796 | // If checkWhole is false, then check only the innermost slice triples. |
| 4797 | std::size_t checkUpTo = checkWhole ? triples.size() : 3; |
| 4798 | checkUpTo = std::min(checkUpTo, triples.size()); |
| 4799 | for (std::size_t i = 0; i < checkUpTo; i += 3) { |
| 4800 | if (triples[i] != triples[i + 1]) { |
| 4801 | // This is a section of the dimension. Only allow it |
| 4802 | // to be the first triple. |
| 4803 | if (i != 0) |
| 4804 | return false; |
| 4805 | auto constantStep = fir::getIntIfConstant(triples[i + 2]); |
| 4806 | if (!constantStep || *constantStep != 1) |
| 4807 | return false; |
| 4808 | } |
| 4809 | } |
| 4810 | return true; |
| 4811 | } |
| 4812 | } |
| 4813 | return false; |
| 4814 | } |
| 4815 | |
| 4816 | std::optional<int64_t> fir::getAllocaByteSize(fir::AllocaOp alloca, |
| 4817 | const mlir::DataLayout &dl, |
| 4818 | const fir::KindMapping &kindMap) { |
| 4819 | mlir::Type type = alloca.getInType(); |
| 4820 | // TODO: should use the constant operands when all info is not available in |
| 4821 | // the type. |
| 4822 | if (!alloca.isDynamic()) |
| 4823 | if (auto sizeAndAlignment = |
| 4824 | getTypeSizeAndAlignment(alloca.getLoc(), type, dl, kindMap)) |
| 4825 | return sizeAndAlignment->first; |
| 4826 | return std::nullopt; |
| 4827 | } |
| 4828 | |
| 4829 | //===----------------------------------------------------------------------===// |
| 4830 | // DeclareOp |
| 4831 | //===----------------------------------------------------------------------===// |
| 4832 | |
| 4833 | llvm::LogicalResult fir::DeclareOp::verify() { |
| 4834 | auto fortranVar = |
| 4835 | mlir::cast<fir::FortranVariableOpInterface>(this->getOperation()); |
| 4836 | return fortranVar.verifyDeclareLikeOpImpl(getMemref()); |
| 4837 | } |
| 4838 | |
| 4839 | //===----------------------------------------------------------------------===// |
| 4840 | // PackArrayOp |
| 4841 | //===----------------------------------------------------------------------===// |
| 4842 | |
| 4843 | llvm::LogicalResult fir::PackArrayOp::verify() { |
| 4844 | mlir::Type arrayType = getArray().getType(); |
| 4845 | if (!validTypeParams(arrayType, getTypeparams(), /*allowParamsForBox=*/true)) |
| 4846 | return emitOpError("invalid type parameters" ); |
| 4847 | |
| 4848 | if (getInnermost() && fir::getBoxRank(arrayType) == 1) |
| 4849 | return emitOpError( |
| 4850 | "'innermost' is invalid for 1D arrays, use 'whole' instead" ); |
| 4851 | return mlir::success(); |
| 4852 | } |
| 4853 | |
| 4854 | void fir::PackArrayOp::getEffects( |
| 4855 | llvm::SmallVectorImpl< |
| 4856 | mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>> |
| 4857 | &effects) { |
| 4858 | if (getStack()) |
| 4859 | effects.emplace_back( |
| 4860 | mlir::MemoryEffects::Allocate::get(), |
| 4861 | mlir::SideEffects::AutomaticAllocationScopeResource::get()); |
| 4862 | else |
| 4863 | effects.emplace_back(mlir::MemoryEffects::Allocate::get(), |
| 4864 | mlir::SideEffects::DefaultResource::get()); |
| 4865 | |
| 4866 | if (!getNoCopy()) |
| 4867 | effects.emplace_back(mlir::MemoryEffects::Read::get(), |
| 4868 | mlir::SideEffects::DefaultResource::get()); |
| 4869 | } |
| 4870 | |
| 4871 | static mlir::ParseResult |
| 4872 | parsePackArrayConstraints(mlir::OpAsmParser &parser, mlir::IntegerAttr &maxSize, |
| 4873 | mlir::IntegerAttr &maxElementSize, |
| 4874 | mlir::IntegerAttr &minStride) { |
| 4875 | mlir::OperationName opName = mlir::OperationName( |
| 4876 | fir::PackArrayOp::getOperationName(), parser.getContext()); |
| 4877 | struct { |
| 4878 | llvm::StringRef name; |
| 4879 | mlir::IntegerAttr &ref; |
| 4880 | } attributes[] = { |
| 4881 | {fir::PackArrayOp::getMaxSizeAttrName(opName), maxSize}, |
| 4882 | {fir::PackArrayOp::getMaxElementSizeAttrName(opName), maxElementSize}, |
| 4883 | {fir::PackArrayOp::getMinStrideAttrName(opName), minStride}}; |
| 4884 | |
| 4885 | mlir::NamedAttrList parsedAttrs; |
| 4886 | if (succeeded(Result: parser.parseOptionalAttrDict(result&: parsedAttrs))) { |
| 4887 | for (auto parsedAttr : parsedAttrs) { |
| 4888 | for (auto opAttr : attributes) { |
| 4889 | if (parsedAttr.getName() == opAttr.name) |
| 4890 | opAttr.ref = mlir::cast<mlir::IntegerAttr>(parsedAttr.getValue()); |
| 4891 | } |
| 4892 | } |
| 4893 | return mlir::success(); |
| 4894 | } |
| 4895 | return mlir::failure(); |
| 4896 | } |
| 4897 | |
| 4898 | static void printPackArrayConstraints(mlir::OpAsmPrinter &p, |
| 4899 | fir::PackArrayOp &op, |
| 4900 | const mlir::IntegerAttr &maxSize, |
| 4901 | const mlir::IntegerAttr &maxElementSize, |
| 4902 | const mlir::IntegerAttr &minStride) { |
| 4903 | llvm::SmallVector<mlir::NamedAttribute> attributes; |
| 4904 | if (maxSize) |
| 4905 | attributes.emplace_back(op.getMaxSizeAttrName(), maxSize); |
| 4906 | if (maxElementSize) |
| 4907 | attributes.emplace_back(op.getMaxElementSizeAttrName(), maxElementSize); |
| 4908 | if (minStride) |
| 4909 | attributes.emplace_back(op.getMinStrideAttrName(), minStride); |
| 4910 | |
| 4911 | p.printOptionalAttrDict(attrs: attributes); |
| 4912 | } |
| 4913 | |
| 4914 | //===----------------------------------------------------------------------===// |
| 4915 | // UnpackArrayOp |
| 4916 | //===----------------------------------------------------------------------===// |
| 4917 | |
| 4918 | llvm::LogicalResult fir::UnpackArrayOp::verify() { |
| 4919 | if (auto packOp = getTemp().getDefiningOp<fir::PackArrayOp>()) |
| 4920 | if (getStack() != packOp.getStack()) |
| 4921 | return emitOpError() << "the pack operation uses different memory for " |
| 4922 | "the temporary (stack vs heap): " |
| 4923 | << *packOp.getOperation() << "\n" ; |
| 4924 | return mlir::success(); |
| 4925 | } |
| 4926 | |
| 4927 | void fir::UnpackArrayOp::getEffects( |
| 4928 | llvm::SmallVectorImpl< |
| 4929 | mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>> |
| 4930 | &effects) { |
| 4931 | if (getStack()) |
| 4932 | effects.emplace_back( |
| 4933 | mlir::MemoryEffects::Free::get(), |
| 4934 | mlir::SideEffects::AutomaticAllocationScopeResource::get()); |
| 4935 | else |
| 4936 | effects.emplace_back(mlir::MemoryEffects::Free::get(), |
| 4937 | mlir::SideEffects::DefaultResource::get()); |
| 4938 | |
| 4939 | if (!getNoCopy()) |
| 4940 | effects.emplace_back(mlir::MemoryEffects::Write::get(), |
| 4941 | mlir::SideEffects::DefaultResource::get()); |
| 4942 | } |
| 4943 | |
| 4944 | //===----------------------------------------------------------------------===// |
| 4945 | // IsContiguousBoxOp |
| 4946 | //===----------------------------------------------------------------------===// |
| 4947 | |
| 4948 | namespace { |
| 4949 | struct SimplifyIsContiguousBoxOp |
| 4950 | : public mlir::OpRewritePattern<fir::IsContiguousBoxOp> { |
| 4951 | using mlir::OpRewritePattern<fir::IsContiguousBoxOp>::OpRewritePattern; |
| 4952 | mlir::LogicalResult |
| 4953 | matchAndRewrite(fir::IsContiguousBoxOp op, |
| 4954 | mlir::PatternRewriter &rewriter) const override; |
| 4955 | }; |
| 4956 | } // namespace |
| 4957 | |
| 4958 | mlir::LogicalResult SimplifyIsContiguousBoxOp::matchAndRewrite( |
| 4959 | fir::IsContiguousBoxOp op, mlir::PatternRewriter &rewriter) const { |
| 4960 | auto boxType = mlir::cast<fir::BaseBoxType>(op.getBox().getType()); |
| 4961 | // Nothing to do for assumed-rank arrays and !fir.box<none>. |
| 4962 | if (boxType.isAssumedRank() || fir::isBoxNone(boxType)) |
| 4963 | return mlir::failure(); |
| 4964 | |
| 4965 | if (fir::getBoxRank(boxType) == 0) { |
| 4966 | // Scalars are always contiguous. |
| 4967 | mlir::Type i1Type = rewriter.getI1Type(); |
| 4968 | rewriter.replaceOpWithNewOp<mlir::arith::ConstantOp>( |
| 4969 | op, i1Type, rewriter.getIntegerAttr(i1Type, 1)); |
| 4970 | return mlir::success(); |
| 4971 | } |
| 4972 | |
| 4973 | // TODO: support more patterns, e.g. a result of fir.embox without |
| 4974 | // the slice is contiguous. We can add fir::isSimplyContiguous(box) |
| 4975 | // that walks def-use to figure it out. |
| 4976 | return mlir::failure(); |
| 4977 | } |
| 4978 | |
| 4979 | void fir::IsContiguousBoxOp::getCanonicalizationPatterns( |
| 4980 | mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) { |
| 4981 | patterns.add<SimplifyIsContiguousBoxOp>(context); |
| 4982 | } |
| 4983 | |
| 4984 | //===----------------------------------------------------------------------===// |
| 4985 | // BoxTotalElementsOp |
| 4986 | //===----------------------------------------------------------------------===// |
| 4987 | |
| 4988 | namespace { |
| 4989 | struct SimplifyBoxTotalElementsOp |
| 4990 | : public mlir::OpRewritePattern<fir::BoxTotalElementsOp> { |
| 4991 | using mlir::OpRewritePattern<fir::BoxTotalElementsOp>::OpRewritePattern; |
| 4992 | mlir::LogicalResult |
| 4993 | matchAndRewrite(fir::BoxTotalElementsOp op, |
| 4994 | mlir::PatternRewriter &rewriter) const override; |
| 4995 | }; |
| 4996 | } // namespace |
| 4997 | |
| 4998 | mlir::LogicalResult SimplifyBoxTotalElementsOp::matchAndRewrite( |
| 4999 | fir::BoxTotalElementsOp op, mlir::PatternRewriter &rewriter) const { |
| 5000 | auto boxType = mlir::cast<fir::BaseBoxType>(op.getBox().getType()); |
| 5001 | // Nothing to do for assumed-rank arrays and !fir.box<none>. |
| 5002 | if (boxType.isAssumedRank() || fir::isBoxNone(boxType)) |
| 5003 | return mlir::failure(); |
| 5004 | |
| 5005 | if (fir::getBoxRank(boxType) == 0) { |
| 5006 | // Scalar: 1 element. |
| 5007 | rewriter.replaceOpWithNewOp<mlir::arith::ConstantOp>( |
| 5008 | op, op.getType(), rewriter.getIntegerAttr(op.getType(), 1)); |
| 5009 | return mlir::success(); |
| 5010 | } |
| 5011 | |
| 5012 | // TODO: support more cases, e.g. !fir.box<!fir.array<10xi32>>. |
| 5013 | return mlir::failure(); |
| 5014 | } |
| 5015 | |
| 5016 | void fir::BoxTotalElementsOp::getCanonicalizationPatterns( |
| 5017 | mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) { |
| 5018 | patterns.add<SimplifyBoxTotalElementsOp>(context); |
| 5019 | } |
| 5020 | |
| 5021 | //===----------------------------------------------------------------------===// |
| 5022 | // LocalitySpecifierOp |
| 5023 | //===----------------------------------------------------------------------===// |
| 5024 | |
| 5025 | llvm::LogicalResult fir::LocalitySpecifierOp::verifyRegions() { |
| 5026 | mlir::Type argType = getArgType(); |
| 5027 | auto verifyTerminator = [&](mlir::Operation *terminator, |
| 5028 | bool yieldsValue) -> llvm::LogicalResult { |
| 5029 | if (!terminator->getBlock()->getSuccessors().empty()) |
| 5030 | return llvm::success(); |
| 5031 | |
| 5032 | if (!llvm::isa<fir::YieldOp>(terminator)) |
| 5033 | return mlir::emitError(terminator->getLoc()) |
| 5034 | << "expected exit block terminator to be an `fir.yield` op." ; |
| 5035 | |
| 5036 | YieldOp yieldOp = llvm::cast<YieldOp>(terminator); |
| 5037 | mlir::TypeRange yieldedTypes = yieldOp.getResults().getTypes(); |
| 5038 | |
| 5039 | if (!yieldsValue) { |
| 5040 | if (yieldedTypes.empty()) |
| 5041 | return llvm::success(); |
| 5042 | |
| 5043 | return mlir::emitError(terminator->getLoc()) |
| 5044 | << "Did not expect any values to be yielded." ; |
| 5045 | } |
| 5046 | |
| 5047 | if (yieldedTypes.size() == 1 && yieldedTypes.front() == argType) |
| 5048 | return llvm::success(); |
| 5049 | |
| 5050 | auto error = mlir::emitError(yieldOp.getLoc()) |
| 5051 | << "Invalid yielded value. Expected type: " << argType |
| 5052 | << ", got: " ; |
| 5053 | |
| 5054 | if (yieldedTypes.empty()) |
| 5055 | error << "None" ; |
| 5056 | else |
| 5057 | error << yieldedTypes; |
| 5058 | |
| 5059 | return error; |
| 5060 | }; |
| 5061 | |
| 5062 | auto verifyRegion = [&](mlir::Region ®ion, unsigned expectedNumArgs, |
| 5063 | llvm::StringRef regionName, |
| 5064 | bool yieldsValue) -> llvm::LogicalResult { |
| 5065 | assert(!region.empty()); |
| 5066 | |
| 5067 | if (region.getNumArguments() != expectedNumArgs) |
| 5068 | return mlir::emitError(region.getLoc()) |
| 5069 | << "`" << regionName << "`: " |
| 5070 | << "expected " << expectedNumArgs |
| 5071 | << " region arguments, got: " << region.getNumArguments(); |
| 5072 | |
| 5073 | for (mlir::Block &block : region) { |
| 5074 | // MLIR will verify the absence of the terminator for us. |
| 5075 | if (!block.mightHaveTerminator()) |
| 5076 | continue; |
| 5077 | |
| 5078 | if (failed(verifyTerminator(block.getTerminator(), yieldsValue))) |
| 5079 | return llvm::failure(); |
| 5080 | } |
| 5081 | |
| 5082 | return llvm::success(); |
| 5083 | }; |
| 5084 | |
| 5085 | // Ensure all of the region arguments have the same type |
| 5086 | for (mlir::Region *region : getRegions()) |
| 5087 | for (mlir::Type ty : region->getArgumentTypes()) |
| 5088 | if (ty != argType) |
| 5089 | return emitError() << "Region argument type mismatch: got " << ty |
| 5090 | << " expected " << argType << "." ; |
| 5091 | |
| 5092 | mlir::Region &initRegion = getInitRegion(); |
| 5093 | if (!initRegion.empty() && |
| 5094 | failed(verifyRegion(getInitRegion(), /*expectedNumArgs=*/2, "init" , |
| 5095 | /*yieldsValue=*/true))) |
| 5096 | return llvm::failure(); |
| 5097 | |
| 5098 | LocalitySpecifierType dsType = getLocalitySpecifierType(); |
| 5099 | |
| 5100 | if (dsType == LocalitySpecifierType::Local && !getCopyRegion().empty()) |
| 5101 | return emitError("`local` specifiers do not require a `copy` region." ); |
| 5102 | |
| 5103 | if (dsType == LocalitySpecifierType::LocalInit && getCopyRegion().empty()) |
| 5104 | return emitError( |
| 5105 | "`local_init` specifiers require at least a `copy` region." ); |
| 5106 | |
| 5107 | if (dsType == LocalitySpecifierType::LocalInit && |
| 5108 | failed(verifyRegion(getCopyRegion(), /*expectedNumArgs=*/2, "copy" , |
| 5109 | /*yieldsValue=*/true))) |
| 5110 | return llvm::failure(); |
| 5111 | |
| 5112 | if (!getDeallocRegion().empty() && |
| 5113 | failed(verifyRegion(getDeallocRegion(), /*expectedNumArgs=*/1, "dealloc" , |
| 5114 | /*yieldsValue=*/false))) |
| 5115 | return llvm::failure(); |
| 5116 | |
| 5117 | return llvm::success(); |
| 5118 | } |
| 5119 | |
| 5120 | //===----------------------------------------------------------------------===// |
| 5121 | // DoConcurrentOp |
| 5122 | //===----------------------------------------------------------------------===// |
| 5123 | |
| 5124 | llvm::LogicalResult fir::DoConcurrentOp::verify() { |
| 5125 | mlir::Block *body = getBody(); |
| 5126 | |
| 5127 | if (body->empty()) |
| 5128 | return emitOpError("body cannot be empty" ); |
| 5129 | |
| 5130 | if (!body->mightHaveTerminator() || |
| 5131 | !mlir::isa<fir::DoConcurrentLoopOp>(body->getTerminator())) |
| 5132 | return emitOpError("must be terminated by 'fir.do_concurrent.loop'" ); |
| 5133 | |
| 5134 | return mlir::success(); |
| 5135 | } |
| 5136 | |
| 5137 | //===----------------------------------------------------------------------===// |
| 5138 | // DoConcurrentLoopOp |
| 5139 | //===----------------------------------------------------------------------===// |
| 5140 | |
| 5141 | mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser, |
| 5142 | mlir::OperationState &result) { |
| 5143 | auto &builder = parser.getBuilder(); |
| 5144 | // Parse an opening `(` followed by induction variables followed by `)` |
| 5145 | llvm::SmallVector<mlir::OpAsmParser::Argument, 4> regionArgs; |
| 5146 | |
| 5147 | if (parser.parseArgumentList(regionArgs, mlir::OpAsmParser::Delimiter::Paren)) |
| 5148 | return mlir::failure(); |
| 5149 | |
| 5150 | llvm::SmallVector<mlir::Type> argTypes(regionArgs.size(), |
| 5151 | builder.getIndexType()); |
| 5152 | |
| 5153 | // Parse loop bounds. |
| 5154 | llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> lower; |
| 5155 | if (parser.parseEqual() || |
| 5156 | parser.parseOperandList(lower, regionArgs.size(), |
| 5157 | mlir::OpAsmParser::Delimiter::Paren) || |
| 5158 | parser.resolveOperands(lower, builder.getIndexType(), result.operands)) |
| 5159 | return mlir::failure(); |
| 5160 | |
| 5161 | llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> upper; |
| 5162 | if (parser.parseKeyword("to" ) || |
| 5163 | parser.parseOperandList(upper, regionArgs.size(), |
| 5164 | mlir::OpAsmParser::Delimiter::Paren) || |
| 5165 | parser.resolveOperands(upper, builder.getIndexType(), result.operands)) |
| 5166 | return mlir::failure(); |
| 5167 | |
| 5168 | // Parse step values. |
| 5169 | llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> steps; |
| 5170 | if (parser.parseKeyword("step" ) || |
| 5171 | parser.parseOperandList(steps, regionArgs.size(), |
| 5172 | mlir::OpAsmParser::Delimiter::Paren) || |
| 5173 | parser.resolveOperands(steps, builder.getIndexType(), result.operands)) |
| 5174 | return mlir::failure(); |
| 5175 | |
| 5176 | llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> reduceOperands; |
| 5177 | llvm::SmallVector<mlir::Type> reduceArgTypes; |
| 5178 | if (succeeded(parser.parseOptionalKeyword("reduce" ))) { |
| 5179 | // Parse reduction attributes and variables. |
| 5180 | llvm::SmallVector<fir::ReduceAttr> attributes; |
| 5181 | if (failed(parser.parseCommaSeparatedList( |
| 5182 | mlir::AsmParser::Delimiter::Paren, [&]() { |
| 5183 | if (parser.parseAttribute(attributes.emplace_back()) || |
| 5184 | parser.parseArrow() || |
| 5185 | parser.parseOperand(reduceOperands.emplace_back()) || |
| 5186 | parser.parseColonType(reduceArgTypes.emplace_back())) |
| 5187 | return mlir::failure(); |
| 5188 | return mlir::success(); |
| 5189 | }))) |
| 5190 | return mlir::failure(); |
| 5191 | // Resolve input operands. |
| 5192 | for (auto operand_type : llvm::zip(reduceOperands, reduceArgTypes)) |
| 5193 | if (parser.resolveOperand(std::get<0>(operand_type), |
| 5194 | std::get<1>(operand_type), result.operands)) |
| 5195 | return mlir::failure(); |
| 5196 | llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(), |
| 5197 | attributes.end()); |
| 5198 | result.addAttribute(getReduceAttrsAttrName(result.name), |
| 5199 | builder.getArrayAttr(arrayAttr)); |
| 5200 | } |
| 5201 | |
| 5202 | llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> localOperands; |
| 5203 | if (succeeded(parser.parseOptionalKeyword("local" ))) { |
| 5204 | std::size_t oldArgTypesSize = argTypes.size(); |
| 5205 | if (failed(parser.parseLParen())) |
| 5206 | return mlir::failure(); |
| 5207 | |
| 5208 | llvm::SmallVector<mlir::SymbolRefAttr> localSymbolVec; |
| 5209 | if (failed(parser.parseCommaSeparatedList([&]() { |
| 5210 | if (failed(parser.parseAttribute(localSymbolVec.emplace_back()))) |
| 5211 | return mlir::failure(); |
| 5212 | |
| 5213 | if (parser.parseOperand(localOperands.emplace_back()) || |
| 5214 | parser.parseArrow() || |
| 5215 | parser.parseArgument(regionArgs.emplace_back())) |
| 5216 | return mlir::failure(); |
| 5217 | |
| 5218 | return mlir::success(); |
| 5219 | }))) |
| 5220 | return mlir::failure(); |
| 5221 | |
| 5222 | if (failed(parser.parseColon())) |
| 5223 | return mlir::failure(); |
| 5224 | |
| 5225 | if (failed(parser.parseCommaSeparatedList([&]() { |
| 5226 | if (failed(parser.parseType(argTypes.emplace_back()))) |
| 5227 | return mlir::failure(); |
| 5228 | |
| 5229 | return mlir::success(); |
| 5230 | }))) |
| 5231 | return mlir::failure(); |
| 5232 | |
| 5233 | if (regionArgs.size() != argTypes.size()) |
| 5234 | return parser.emitError(parser.getNameLoc(), |
| 5235 | "mismatch in number of local arg and types" ); |
| 5236 | |
| 5237 | if (failed(parser.parseRParen())) |
| 5238 | return mlir::failure(); |
| 5239 | |
| 5240 | for (auto operandType : llvm::zip_equal( |
| 5241 | localOperands, llvm::drop_begin(argTypes, oldArgTypesSize))) |
| 5242 | if (parser.resolveOperand(std::get<0>(operandType), |
| 5243 | std::get<1>(operandType), result.operands)) |
| 5244 | return mlir::failure(); |
| 5245 | |
| 5246 | llvm::SmallVector<mlir::Attribute> symbolAttrs(localSymbolVec.begin(), |
| 5247 | localSymbolVec.end()); |
| 5248 | result.addAttribute(getLocalSymsAttrName(result.name), |
| 5249 | builder.getArrayAttr(symbolAttrs)); |
| 5250 | } |
| 5251 | |
| 5252 | // Set `operandSegmentSizes` attribute. |
| 5253 | result.addAttribute(DoConcurrentLoopOp::getOperandSegmentSizeAttr(), |
| 5254 | builder.getDenseI32ArrayAttr( |
| 5255 | {static_cast<int32_t>(lower.size()), |
| 5256 | static_cast<int32_t>(upper.size()), |
| 5257 | static_cast<int32_t>(steps.size()), |
| 5258 | static_cast<int32_t>(reduceOperands.size()), |
| 5259 | static_cast<int32_t>(localOperands.size())})); |
| 5260 | |
| 5261 | // Now parse the body. |
| 5262 | for (auto [arg, type] : llvm::zip_equal(regionArgs, argTypes)) |
| 5263 | arg.type = type; |
| 5264 | |
| 5265 | mlir::Region *body = result.addRegion(); |
| 5266 | if (parser.parseRegion(*body, regionArgs)) |
| 5267 | return mlir::failure(); |
| 5268 | |
| 5269 | // Parse attributes. |
| 5270 | if (parser.parseOptionalAttrDict(result.attributes)) |
| 5271 | return mlir::failure(); |
| 5272 | |
| 5273 | return mlir::success(); |
| 5274 | } |
| 5275 | |
| 5276 | void fir::DoConcurrentLoopOp::print(mlir::OpAsmPrinter &p) { |
| 5277 | p << " (" << getBody()->getArguments().slice(0, getNumInductionVars()) |
| 5278 | << ") = (" << getLowerBound() << ") to (" << getUpperBound() << ") step (" |
| 5279 | << getStep() << ")" ; |
| 5280 | |
| 5281 | if (!getReduceOperands().empty()) { |
| 5282 | p << " reduce(" ; |
| 5283 | auto attrs = getReduceAttrsAttr(); |
| 5284 | auto operands = getReduceOperands(); |
| 5285 | llvm::interleaveComma(llvm::zip(attrs, operands), p, [&](auto it) { |
| 5286 | p << std::get<0>(it) << " -> " << std::get<1>(it) << " : " |
| 5287 | << std::get<1>(it).getType(); |
| 5288 | }); |
| 5289 | p << ')'; |
| 5290 | } |
| 5291 | |
| 5292 | if (!getLocalVars().empty()) { |
| 5293 | p << " local(" ; |
| 5294 | llvm::interleaveComma(llvm::zip_equal(getLocalSymsAttr(), getLocalVars(), |
| 5295 | getRegionLocalArgs()), |
| 5296 | p, [&](auto it) { |
| 5297 | p << std::get<0>(it) << " " << std::get<1>(it) |
| 5298 | << " -> " << std::get<2>(it); |
| 5299 | }); |
| 5300 | p << " : " ; |
| 5301 | llvm::interleaveComma(getLocalVars(), p, |
| 5302 | [&](auto it) { p << it.getType(); }); |
| 5303 | p << ")" ; |
| 5304 | } |
| 5305 | |
| 5306 | p << ' '; |
| 5307 | p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); |
| 5308 | p.printOptionalAttrDict( |
| 5309 | (*this)->getAttrs(), |
| 5310 | /*elidedAttrs=*/{DoConcurrentLoopOp::getOperandSegmentSizeAttr(), |
| 5311 | DoConcurrentLoopOp::getReduceAttrsAttrName(), |
| 5312 | DoConcurrentLoopOp::getLocalSymsAttrName()}); |
| 5313 | } |
| 5314 | |
| 5315 | llvm::SmallVector<mlir::Region *> fir::DoConcurrentLoopOp::getLoopRegions() { |
| 5316 | return {&getRegion()}; |
| 5317 | } |
| 5318 | |
| 5319 | llvm::LogicalResult fir::DoConcurrentLoopOp::verify() { |
| 5320 | mlir::Operation::operand_range lbValues = getLowerBound(); |
| 5321 | mlir::Operation::operand_range ubValues = getUpperBound(); |
| 5322 | mlir::Operation::operand_range stepValues = getStep(); |
| 5323 | mlir::Operation::operand_range localVars = getLocalVars(); |
| 5324 | |
| 5325 | if (lbValues.empty()) |
| 5326 | return emitOpError( |
| 5327 | "needs at least one tuple element for lowerBound, upperBound and step" ); |
| 5328 | |
| 5329 | if (lbValues.size() != ubValues.size() || |
| 5330 | ubValues.size() != stepValues.size()) |
| 5331 | return emitOpError("different number of tuple elements for lowerBound, " |
| 5332 | "upperBound or step" ); |
| 5333 | |
| 5334 | // Check that the body defines the same number of block arguments as the |
| 5335 | // number of tuple elements in step. |
| 5336 | mlir::Block *body = getBody(); |
| 5337 | unsigned numIndVarArgs = body->getNumArguments() - localVars.size(); |
| 5338 | |
| 5339 | if (numIndVarArgs != stepValues.size()) |
| 5340 | return emitOpError() << "expects the same number of induction variables: " |
| 5341 | << body->getNumArguments() |
| 5342 | << " as bound and step values: " << stepValues.size(); |
| 5343 | for (auto arg : body->getArguments().slice(0, numIndVarArgs)) |
| 5344 | if (!arg.getType().isIndex()) |
| 5345 | return emitOpError( |
| 5346 | "expects arguments for the induction variable to be of index type" ); |
| 5347 | |
| 5348 | auto reduceAttrs = getReduceAttrsAttr(); |
| 5349 | if (getNumReduceOperands() != (reduceAttrs ? reduceAttrs.size() : 0)) |
| 5350 | return emitOpError( |
| 5351 | "mismatch in number of reduction variables and reduction attributes" ); |
| 5352 | |
| 5353 | return mlir::success(); |
| 5354 | } |
| 5355 | |
| 5356 | std::optional<llvm::SmallVector<mlir::Value>> |
| 5357 | fir::DoConcurrentLoopOp::getLoopInductionVars() { |
| 5358 | return llvm::SmallVector<mlir::Value>{ |
| 5359 | getBody()->getArguments().slice(0, getLowerBound().size())}; |
| 5360 | } |
| 5361 | |
| 5362 | //===----------------------------------------------------------------------===// |
| 5363 | // FIROpsDialect |
| 5364 | //===----------------------------------------------------------------------===// |
| 5365 | |
| 5366 | void fir::FIROpsDialect::registerOpExternalInterfaces() { |
| 5367 | // Attach default declare target interfaces to operations which can be marked |
| 5368 | // as declare target. |
| 5369 | fir::GlobalOp::attachInterface< |
| 5370 | mlir::omp::DeclareTargetDefaultModel<fir::GlobalOp>>(*getContext()); |
| 5371 | } |
| 5372 | |
| 5373 | // Tablegen operators |
| 5374 | |
| 5375 | #define GET_OP_CLASSES |
| 5376 | #include "flang/Optimizer/Dialect/FIROps.cpp.inc" |
| 5377 | |