//===-- Utils.cpp -----------------------------------------------*- C++ -*-===//
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "Utils.h" |
| 14 | |
| 15 | #include "Clauses.h" |
| 16 | |
| 17 | #include "ClauseFinder.h" |
| 18 | #include <flang/Lower/AbstractConverter.h> |
| 19 | #include <flang/Lower/ConvertType.h> |
| 20 | #include <flang/Lower/DirectivesCommon.h> |
| 21 | #include <flang/Lower/PFTBuilder.h> |
| 22 | #include <flang/Optimizer/Builder/FIRBuilder.h> |
| 23 | #include <flang/Optimizer/Builder/Todo.h> |
| 24 | #include <flang/Parser/parse-tree.h> |
| 25 | #include <flang/Parser/tools.h> |
| 26 | #include <flang/Semantics/tools.h> |
| 27 | #include <llvm/Support/CommandLine.h> |
| 28 | |
| 29 | #include <iterator> |
| 30 | |
| 31 | llvm::cl::opt<bool> treatIndexAsSection( |
| 32 | "openmp-treat-index-as-section" , |
| 33 | llvm::cl::desc("In the OpenMP data clauses treat `a(N)` as `a(N:N)`." ), |
| 34 | llvm::cl::init(Val: true)); |
| 35 | |
| 36 | namespace Fortran { |
| 37 | namespace lower { |
| 38 | namespace omp { |
| 39 | |
| 40 | int64_t getCollapseValue(const List<Clause> &clauses) { |
| 41 | auto iter = llvm::find_if(Range: clauses, P: [](const Clause &clause) { |
| 42 | return clause.id == llvm::omp::Clause::OMPC_collapse; |
| 43 | }); |
| 44 | if (iter != clauses.end()) { |
| 45 | const auto &collapse = std::get<clause::Collapse>(iter->u); |
| 46 | return evaluate::ToInt64(collapse.v).value(); |
| 47 | } |
| 48 | return 1; |
| 49 | } |
| 50 | |
| 51 | void genObjectList(const ObjectList &objects, |
| 52 | lower::AbstractConverter &converter, |
| 53 | llvm::SmallVectorImpl<mlir::Value> &operands) { |
| 54 | for (const Object &object : objects) { |
| 55 | const semantics::Symbol *sym = object.sym(); |
| 56 | assert(sym && "Expected Symbol" ); |
| 57 | if (mlir::Value variable = converter.getSymbolAddress(*sym)) { |
| 58 | operands.push_back(variable); |
| 59 | } else if (const auto *details = |
| 60 | sym->detailsIf<semantics::HostAssocDetails>()) { |
| 61 | operands.push_back(converter.getSymbolAddress(details->symbol())); |
| 62 | converter.copySymbolBinding(details->symbol(), *sym); |
| 63 | } |
| 64 | } |
| 65 | } |
| 66 | |
| 67 | mlir::Type getLoopVarType(lower::AbstractConverter &converter, |
| 68 | std::size_t loopVarTypeSize) { |
| 69 | // OpenMP runtime requires 32-bit or 64-bit loop variables. |
| 70 | loopVarTypeSize = loopVarTypeSize * 8; |
| 71 | if (loopVarTypeSize < 32) { |
| 72 | loopVarTypeSize = 32; |
| 73 | } else if (loopVarTypeSize > 64) { |
| 74 | loopVarTypeSize = 64; |
| 75 | mlir::emitWarning(converter.getCurrentLocation(), |
| 76 | "OpenMP loop iteration variable cannot have more than 64 " |
| 77 | "bits size and will be narrowed into 64 bits." ); |
| 78 | } |
| 79 | assert((loopVarTypeSize == 32 || loopVarTypeSize == 64) && |
| 80 | "OpenMP loop iteration variable size must be transformed into 32-bit " |
| 81 | "or 64-bit" ); |
| 82 | return converter.getFirOpBuilder().getIntegerType(loopVarTypeSize); |
| 83 | } |
| 84 | |
| 85 | semantics::Symbol * |
| 86 | getIterationVariableSymbol(const lower::pft::Evaluation &eval) { |
| 87 | return eval.visit(common::visitors{ |
| 88 | [&](const parser::DoConstruct &doLoop) { |
| 89 | if (const auto &maybeCtrl = doLoop.GetLoopControl()) { |
| 90 | using LoopControl = parser::LoopControl; |
| 91 | if (auto *bounds = std::get_if<LoopControl::Bounds>(&maybeCtrl->u)) { |
| 92 | static_assert(std::is_same_v<decltype(bounds->name), |
| 93 | parser::Scalar<parser::Name>>); |
| 94 | return bounds->name.thing.symbol; |
| 95 | } |
| 96 | } |
| 97 | return static_cast<semantics::Symbol *>(nullptr); |
| 98 | }, |
| 99 | [](auto &&) { return static_cast<semantics::Symbol *>(nullptr); }, |
| 100 | }); |
| 101 | } |
| 102 | |
| 103 | void gatherFuncAndVarSyms( |
| 104 | const ObjectList &objects, mlir::omp::DeclareTargetCaptureClause clause, |
| 105 | llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) { |
| 106 | for (const Object &object : objects) |
| 107 | symbolAndClause.emplace_back(clause, *object.sym()); |
| 108 | } |
| 109 | |
| 110 | mlir::omp::MapInfoOp |
| 111 | createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc, |
| 112 | mlir::Value baseAddr, mlir::Value varPtrPtr, |
| 113 | llvm::StringRef name, llvm::ArrayRef<mlir::Value> bounds, |
| 114 | llvm::ArrayRef<mlir::Value> members, |
| 115 | mlir::ArrayAttr membersIndex, uint64_t mapType, |
| 116 | mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy, |
| 117 | bool partialMap, mlir::FlatSymbolRefAttr mapperId) { |
| 118 | if (auto boxTy = llvm::dyn_cast<fir::BaseBoxType>(baseAddr.getType())) { |
| 119 | baseAddr = builder.create<fir::BoxAddrOp>(loc, baseAddr); |
| 120 | retTy = baseAddr.getType(); |
| 121 | } |
| 122 | |
| 123 | mlir::TypeAttr varType = mlir::TypeAttr::get( |
| 124 | llvm::cast<mlir::omp::PointerLikeType>(retTy).getElementType()); |
| 125 | |
| 126 | // For types with unknown extents such as <2x?xi32> we discard the incomplete |
| 127 | // type info and only retain the base type. The correct dimensions are later |
| 128 | // recovered through the bounds info. |
| 129 | if (auto seqType = llvm::dyn_cast<fir::SequenceType>(varType.getValue())) |
| 130 | if (seqType.hasDynamicExtents()) |
| 131 | varType = mlir::TypeAttr::get(seqType.getEleTy()); |
| 132 | |
| 133 | mlir::omp::MapInfoOp op = builder.create<mlir::omp::MapInfoOp>( |
| 134 | loc, retTy, baseAddr, varType, |
| 135 | builder.getIntegerAttr(builder.getIntegerType(64, false), mapType), |
| 136 | builder.getAttr<mlir::omp::VariableCaptureKindAttr>(mapCaptureType), |
| 137 | varPtrPtr, members, membersIndex, bounds, mapperId, |
| 138 | builder.getStringAttr(name), builder.getBoolAttr(partialMap)); |
| 139 | return op; |
| 140 | } |
| 141 | |
| 142 | // This function gathers the individual omp::Object's that make up a |
| 143 | // larger omp::Object symbol. |
| 144 | // |
| 145 | // For example, provided the larger symbol: "parent%child%member", this |
| 146 | // function breaks it up into its constituent components ("parent", |
| 147 | // "child", "member"), so we can access each individual component and |
| 148 | // introspect details. Important to note is this function breaks it up from |
| 149 | // RHS to LHS ("member" to "parent") and then we reverse it so that the |
| 150 | // returned omp::ObjectList is LHS to RHS, with the "parent" at the |
| 151 | // beginning. |
| 152 | omp::ObjectList gatherObjectsOf(omp::Object derivedTypeMember, |
| 153 | semantics::SemanticsContext &semaCtx) { |
| 154 | omp::ObjectList objList; |
| 155 | std::optional<omp::Object> baseObj = derivedTypeMember; |
| 156 | while (baseObj.has_value()) { |
| 157 | objList.push_back(baseObj.value()); |
| 158 | baseObj = getBaseObject(baseObj.value(), semaCtx); |
| 159 | } |
| 160 | return omp::ObjectList{llvm::reverse(objList)}; |
| 161 | } |
| 162 | |
| 163 | // This function generates a series of indices from a provided omp::Object, |
| 164 | // that devolves to an ArrayRef symbol, e.g. "array(2,3,4)", this function |
| 165 | // would generate a series of indices of "[1][2][3]" for the above example, |
| 166 | // offsetting by -1 to account for the non-zero fortran indexes. |
| 167 | // |
| 168 | // These indices can then be provided to a coordinate operation or other |
| 169 | // GEP-like operation to access the relevant positional member of the |
| 170 | // array. |
| 171 | // |
| 172 | // It is of note that the function only supports subscript integers currently |
| 173 | // and not Triplets i.e. Array(1:2:3). |
| 174 | static void generateArrayIndices(lower::AbstractConverter &converter, |
| 175 | fir::FirOpBuilder &firOpBuilder, |
| 176 | lower::StatementContext &stmtCtx, |
| 177 | mlir::Location clauseLocation, |
| 178 | llvm::SmallVectorImpl<mlir::Value> &indices, |
| 179 | omp::Object object) { |
| 180 | auto maybeRef = evaluate::ExtractDataRef(*object.ref()); |
| 181 | if (!maybeRef) |
| 182 | return; |
| 183 | |
| 184 | auto *arr = std::get_if<evaluate::ArrayRef>(&maybeRef->u); |
| 185 | if (!arr) |
| 186 | return; |
| 187 | |
| 188 | for (auto v : arr->subscript()) { |
| 189 | if (std::holds_alternative<Triplet>(v.u)) |
| 190 | TODO(clauseLocation, "Triplet indexing in map clause is unsupported" ); |
| 191 | |
| 192 | auto expr = std::get<Fortran::evaluate::IndirectSubscriptIntegerExpr>(v.u); |
| 193 | mlir::Value subscript = |
| 194 | fir::getBase(converter.genExprValue(toEvExpr(expr.value()), stmtCtx)); |
| 195 | mlir::Value one = firOpBuilder.createIntegerConstant( |
| 196 | clauseLocation, firOpBuilder.getIndexType(), 1); |
| 197 | subscript = firOpBuilder.createConvert( |
| 198 | clauseLocation, firOpBuilder.getIndexType(), subscript); |
| 199 | indices.push_back(firOpBuilder.create<mlir::arith::SubIOp>(clauseLocation, |
| 200 | subscript, one)); |
| 201 | } |
| 202 | } |
| 203 | |
| 204 | /// When mapping members of derived types, there is a chance that one of the |
| 205 | /// members along the way to a mapped member is an descriptor. In which case |
| 206 | /// we have to make sure we generate a map for those along the way otherwise |
| 207 | /// we will be missing a chunk of data required to actually map the member |
| 208 | /// type to device. This function effectively generates these maps and the |
| 209 | /// appropriate data accesses required to generate these maps. It will avoid |
| 210 | /// creating duplicate maps, as duplicates are just as bad as unmapped |
| 211 | /// descriptor data in a lot of cases for the runtime (and unnecessary |
| 212 | /// data movement should be avoided where possible). |
| 213 | /// |
| 214 | /// As an example for the following mapping: |
| 215 | /// |
| 216 | /// type :: vertexes |
| 217 | /// integer(4), allocatable :: vertexx(:) |
| 218 | /// integer(4), allocatable :: vertexy(:) |
| 219 | /// end type vertexes |
| 220 | /// |
| 221 | /// type :: dtype |
| 222 | /// real(4) :: i |
| 223 | /// type(vertexes), allocatable :: vertexes(:) |
| 224 | /// end type dtype |
| 225 | /// |
| 226 | /// type(dtype), allocatable :: alloca_dtype |
| 227 | /// |
| 228 | /// !$omp target map(tofrom: alloca_dtype%vertexes(N1)%vertexx) |
| 229 | /// |
| 230 | /// The below HLFIR/FIR is generated (trimmed for conciseness): |
| 231 | /// |
| 232 | /// On the first iteration we index into the record type alloca_dtype |
| 233 | /// to access "vertexes", we then generate a map for this descriptor |
| 234 | /// alongside bounds to indicate we only need the 1 member, rather than |
| 235 | /// the whole array block in this case (In theory we could map its |
| 236 | /// entirety at the cost of data transfer bandwidth). |
| 237 | /// |
| 238 | /// %13:2 = hlfir.declare ... "alloca_dtype" ... |
| 239 | /// %39 = fir.load %13#0 : ... |
| 240 | /// %40 = fir.coordinate_of %39, %c1 : ... |
| 241 | /// %51 = omp.map.info var_ptr(%40 : ...) map_clauses(to) capture(ByRef) ... |
| 242 | /// %52 = fir.load %40 : ... |
| 243 | /// |
| 244 | /// Second iteration generating access to "vertexes(N1) utilising the N1 index |
| 245 | /// %53 = load N1 ... |
| 246 | /// %54 = fir.convert %53 : (i32) -> i64 |
| 247 | /// %55 = fir.convert %54 : (i64) -> index |
| 248 | /// %56 = arith.subi %55, %c1 : index |
| 249 | /// %57 = fir.coordinate_of %52, %56 : ... |
| 250 | /// |
| 251 | /// Still in the second iteration we access the allocatable member "vertexx", |
| 252 | /// we return %58 from the function and provide it to the final and "main" |
| 253 | /// map of processMap (generated by the record type segment of the below |
| 254 | /// function), if this were not the final symbol in the list, i.e. we accessed |
| 255 | /// a member below vertexx, we would have generated the map below as we did in |
| 256 | /// the first iteration and then continue to generate further coordinates to |
| 257 | /// access further components as required. |
| 258 | /// |
| 259 | /// %58 = fir.coordinate_of %57, %c0 : ... |
| 260 | /// %61 = omp.map.info var_ptr(%58 : ...) map_clauses(to) capture(ByRef) ... |
| 261 | /// |
| 262 | /// Parent mapping containing prior generated mapped members, generated at |
| 263 | /// a later step but here to showcase the "end" result |
| 264 | /// |
| 265 | /// omp.map.info var_ptr(%13#1 : ...) map_clauses(to) capture(ByRef) |
| 266 | /// members(%50, %61 : [0, 1, 0], [0, 1, 0] : ... |
| 267 | /// |
| 268 | /// \param objectList - The list of omp::Object symbol data for each parent |
| 269 | /// to the mapped member (also includes the mapped member), generated via |
| 270 | /// gatherObjectsOf. |
| 271 | /// \param indices - List of index data associated with the mapped member |
| 272 | /// symbol, which identifies the placement of the member in its parent, |
| 273 | /// this helps generate the appropriate member accesses. These indices |
| 274 | /// can be generated via generateMemberPlacementIndices. |
| 275 | /// \param asFortran - A string generated from the mapped variable to be |
| 276 | /// associated with the main map, generally (but not restricted to) |
| 277 | /// generated via gatherDataOperandAddrAndBounds or other |
| 278 | /// DirectiveCommons.hpp utilities. |
| 279 | /// \param mapTypeBits - The map flags that will be associated with the |
| 280 | /// generated maps, minus alterations of the TO and FROM bits for the |
| 281 | /// intermediate components to prevent accidental overwriting on device |
| 282 | /// write back. |
| 283 | mlir::Value createParentSymAndGenIntermediateMaps( |
| 284 | mlir::Location clauseLocation, lower::AbstractConverter &converter, |
| 285 | semantics::SemanticsContext &semaCtx, lower::StatementContext &stmtCtx, |
| 286 | omp::ObjectList &objectList, llvm::SmallVectorImpl<int64_t> &indices, |
| 287 | OmpMapParentAndMemberData &parentMemberIndices, llvm::StringRef asFortran, |
| 288 | llvm::omp::OpenMPOffloadMappingFlags mapTypeBits) { |
| 289 | fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); |
| 290 | |
| 291 | /// Checks if an omp::Object is an array expression with a subscript, e.g. |
| 292 | /// array(1,2). |
| 293 | auto isArrayExprWithSubscript = [](omp::Object obj) { |
| 294 | if (auto maybeRef = evaluate::ExtractDataRef(obj.ref())) { |
| 295 | evaluate::DataRef ref = *maybeRef; |
| 296 | if (auto *arr = std::get_if<evaluate::ArrayRef>(&ref.u)) |
| 297 | return !arr->subscript().empty(); |
| 298 | } |
| 299 | return false; |
| 300 | }; |
| 301 | |
| 302 | // Generate the access to the original parent base address. |
| 303 | fir::factory::AddrAndBoundsInfo parentBaseAddr = |
| 304 | lower::getDataOperandBaseAddr(converter, firOpBuilder, |
| 305 | *objectList[0].sym(), clauseLocation); |
| 306 | mlir::Value curValue = parentBaseAddr.addr; |
| 307 | |
| 308 | // Iterate over all objects in the objectList, this should consist of all |
| 309 | // record types between the parent and the member being mapped (including |
| 310 | // the parent). The object list may also contain array objects as well, |
| 311 | // this can occur when specifying bounds or a specific element access |
| 312 | // within a member map, we skip these. |
| 313 | size_t currentIndicesIdx = 0; |
| 314 | for (size_t i = 0; i < objectList.size(); ++i) { |
| 315 | // If we encounter a sequence type, i.e. an array, we must generate the |
| 316 | // correct coordinate operation to index into the array to proceed further, |
| 317 | // this is only relevant in cases where we encounter subscripts currently. |
| 318 | // |
| 319 | // For example in the following case: |
| 320 | // |
| 321 | // map(tofrom: array_dtype(4)%internal_dtypes(3)%float_elements(4)) |
| 322 | // |
| 323 | // We must generate coordinate operation accesses for each subscript |
| 324 | // we encounter. |
| 325 | if (fir::SequenceType arrType = mlir::dyn_cast<fir::SequenceType>( |
| 326 | fir::unwrapPassByRefType(curValue.getType()))) { |
| 327 | if (isArrayExprWithSubscript(objectList[i])) { |
| 328 | llvm::SmallVector<mlir::Value> subscriptIndices; |
| 329 | generateArrayIndices(converter, firOpBuilder, stmtCtx, clauseLocation, |
| 330 | subscriptIndices, objectList[i]); |
| 331 | assert(!subscriptIndices.empty() && |
| 332 | "missing expected indices for map clause" ); |
| 333 | curValue = firOpBuilder.create<fir::CoordinateOp>( |
| 334 | clauseLocation, firOpBuilder.getRefType(arrType.getEleTy()), |
| 335 | curValue, subscriptIndices); |
| 336 | } |
| 337 | } |
| 338 | |
| 339 | // If we encounter a record type, we must access the subsequent member |
| 340 | // by indexing into it and creating a coordinate operation to do so, we |
| 341 | // utilise the index information generated previously and passed in to |
| 342 | // work out the correct member to access and the corresponding member |
| 343 | // type. |
| 344 | if (fir::RecordType recordType = mlir::dyn_cast<fir::RecordType>( |
| 345 | fir::unwrapPassByRefType(curValue.getType()))) { |
| 346 | fir::IntOrValue idxConst = mlir::IntegerAttr::get( |
| 347 | firOpBuilder.getI32Type(), indices[currentIndicesIdx]); |
| 348 | mlir::Type memberTy = recordType.getType(indices[currentIndicesIdx]); |
| 349 | curValue = firOpBuilder.create<fir::CoordinateOp>( |
| 350 | clauseLocation, firOpBuilder.getRefType(memberTy), curValue, |
| 351 | llvm::SmallVector<fir::IntOrValue, 1>{idxConst}); |
| 352 | |
| 353 | // If we're a final member, the map will be generated by the processMap |
| 354 | // call that invoked this function. |
| 355 | if (currentIndicesIdx == indices.size() - 1) |
| 356 | break; |
| 357 | |
| 358 | // Skip mapping and the subsequent load if we're not |
| 359 | // a type with a descriptor such as a pointer/allocatable. If we're not a |
| 360 | // type with a descriptor then we have no need of generating an |
| 361 | // intermediate map for it, as we only need to generate a map if a member |
| 362 | // is a descriptor type (and thus obscures the members it contains via a |
| 363 | // pointer in which it's data needs mapped). |
| 364 | if (!fir::isTypeWithDescriptor(memberTy)) { |
| 365 | currentIndicesIdx++; |
| 366 | continue; |
| 367 | } |
| 368 | |
| 369 | llvm::SmallVector<int64_t> interimIndices( |
| 370 | indices.begin(), std::next(x: indices.begin(), n: currentIndicesIdx + 1)); |
| 371 | // Verify we haven't already created a map for this particular member, by |
| 372 | // checking the list of members already mapped for the current parent, |
| 373 | // stored in the parentMemberIndices structure |
| 374 | if (!parentMemberIndices.isDuplicateMemberMapInfo(memberIndices&: interimIndices)) { |
| 375 | // Generate bounds operations using the standard lowering utility, |
| 376 | // unfortunately this currently does a bit more than just generate |
| 377 | // bounds and we discard the other bits. May be useful to extend the |
| 378 | // utility to just provide bounds in the future. |
| 379 | llvm::SmallVector<mlir::Value> interimBounds; |
| 380 | if (i + 1 < objectList.size() && |
| 381 | objectList[i + 1].sym()->IsObjectArray()) { |
| 382 | std::stringstream interimFortran; |
| 383 | Fortran::lower::gatherDataOperandAddrAndBounds< |
| 384 | mlir::omp::MapBoundsOp, mlir::omp::MapBoundsType>( |
| 385 | converter, converter.getFirOpBuilder(), semaCtx, |
| 386 | converter.getFctCtx(), *objectList[i + 1].sym(), |
| 387 | objectList[i + 1].ref(), clauseLocation, interimFortran, |
| 388 | interimBounds, treatIndexAsSection); |
| 389 | } |
| 390 | |
| 391 | // Remove all map-type bits (e.g. TO, FROM, etc.) from the intermediate |
| 392 | // allocatable maps, as we simply wish to alloc or release them. It may |
| 393 | // be safer to just pass OMP_MAP_NONE as the map type, but we may still |
| 394 | // need some of the other map types the mapped member utilises, so for |
| 395 | // now it's good to keep an eye on this. |
| 396 | llvm::omp::OpenMPOffloadMappingFlags interimMapType = mapTypeBits; |
| 397 | interimMapType &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO; |
| 398 | interimMapType &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; |
| 399 | interimMapType &= |
| 400 | ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM; |
| 401 | |
| 402 | // Create a map for the intermediate member and insert it and it's |
| 403 | // indices into the parentMemberIndices list to track it. |
| 404 | mlir::omp::MapInfoOp mapOp = createMapInfoOp( |
| 405 | firOpBuilder, clauseLocation, curValue, |
| 406 | /*varPtrPtr=*/mlir::Value{}, asFortran, |
| 407 | /*bounds=*/interimBounds, |
| 408 | /*members=*/{}, |
| 409 | /*membersIndex=*/mlir::ArrayAttr{}, |
| 410 | static_cast< |
| 411 | std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>( |
| 412 | interimMapType), |
| 413 | mlir::omp::VariableCaptureKind::ByRef, curValue.getType()); |
| 414 | |
| 415 | parentMemberIndices.memberPlacementIndices.push_back(Elt: interimIndices); |
| 416 | parentMemberIndices.memberMap.push_back(mapOp); |
| 417 | } |
| 418 | |
| 419 | // Load the currently accessed member, so we can continue to access |
| 420 | // further segments. |
| 421 | curValue = firOpBuilder.create<fir::LoadOp>(clauseLocation, curValue); |
| 422 | currentIndicesIdx++; |
| 423 | } |
| 424 | } |
| 425 | |
| 426 | return curValue; |
| 427 | } |
| 428 | |
| 429 | static int64_t |
| 430 | getComponentPlacementInParent(const semantics::Symbol *componentSym) { |
| 431 | const auto *derived = componentSym->owner() |
| 432 | .derivedTypeSpec() |
| 433 | ->typeSymbol() |
| 434 | .detailsIf<semantics::DerivedTypeDetails>(); |
| 435 | assert(derived && |
| 436 | "expected derived type details when processing component symbol" ); |
| 437 | for (auto [placement, name] : llvm::enumerate(derived->componentNames())) |
| 438 | if (name == componentSym->name()) |
| 439 | return placement; |
| 440 | return -1; |
| 441 | } |
| 442 | |
| 443 | static std::optional<Object> |
| 444 | getComponentObject(std::optional<Object> object, |
| 445 | semantics::SemanticsContext &semaCtx) { |
| 446 | if (!object) |
| 447 | return std::nullopt; |
| 448 | |
| 449 | auto ref = evaluate::ExtractDataRef(object.value().ref()); |
| 450 | if (!ref) |
| 451 | return std::nullopt; |
| 452 | |
| 453 | if (std::holds_alternative<evaluate::Component>(ref->u)) |
| 454 | return object; |
| 455 | |
| 456 | auto baseObj = getBaseObject(object.value(), semaCtx); |
| 457 | if (!baseObj) |
| 458 | return std::nullopt; |
| 459 | |
| 460 | return getComponentObject(baseObj.value(), semaCtx); |
| 461 | } |
| 462 | |
| 463 | void generateMemberPlacementIndices(const Object &object, |
| 464 | llvm::SmallVectorImpl<int64_t> &indices, |
| 465 | semantics::SemanticsContext &semaCtx) { |
| 466 | assert(indices.empty() && "indices vector passed to " |
| 467 | "generateMemberPlacementIndices should be empty" ); |
| 468 | auto compObj = getComponentObject(object, semaCtx); |
| 469 | |
| 470 | while (compObj) { |
| 471 | int64_t index = getComponentPlacementInParent(compObj->sym()); |
| 472 | assert( |
| 473 | index >= 0 && |
| 474 | "unexpected index value returned from getComponentPlacementInParent" ); |
| 475 | indices.push_back(Elt: index); |
| 476 | compObj = |
| 477 | getComponentObject(getBaseObject(compObj.value(), semaCtx), semaCtx); |
| 478 | } |
| 479 | |
| 480 | indices = llvm::SmallVector<int64_t>{llvm::reverse(C&: indices)}; |
| 481 | } |
| 482 | |
| 483 | void OmpMapParentAndMemberData::addChildIndexAndMapToParent( |
| 484 | const omp::Object &object, mlir::omp::MapInfoOp &mapOp, |
| 485 | semantics::SemanticsContext &semaCtx) { |
| 486 | llvm::SmallVector<int64_t> indices; |
| 487 | generateMemberPlacementIndices(object, indices, semaCtx); |
| 488 | memberPlacementIndices.push_back(Elt: indices); |
| 489 | memberMap.push_back(mapOp); |
| 490 | } |
| 491 | |
| 492 | bool isMemberOrParentAllocatableOrPointer( |
| 493 | const Object &object, semantics::SemanticsContext &semaCtx) { |
| 494 | if (semantics::IsAllocatableOrObjectPointer(object.sym())) |
| 495 | return true; |
| 496 | |
| 497 | auto compObj = getBaseObject(object, semaCtx); |
| 498 | while (compObj) { |
| 499 | if (semantics::IsAllocatableOrObjectPointer(compObj.value().sym())) |
| 500 | return true; |
| 501 | compObj = getBaseObject(compObj.value(), semaCtx); |
| 502 | } |
| 503 | |
| 504 | return false; |
| 505 | } |
| 506 | |
| 507 | void insertChildMapInfoIntoParent( |
| 508 | lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx, |
| 509 | lower::StatementContext &stmtCtx, |
| 510 | std::map<Object, OmpMapParentAndMemberData> &parentMemberIndices, |
| 511 | llvm::SmallVectorImpl<mlir::Value> &mapOperands, |
| 512 | llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) { |
| 513 | fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); |
| 514 | for (auto indices : parentMemberIndices) { |
| 515 | auto *parentIter = |
| 516 | llvm::find_if(mapSyms, [&indices](const semantics::Symbol *v) { |
| 517 | return v == indices.first.sym(); |
| 518 | }); |
| 519 | if (parentIter != mapSyms.end()) { |
| 520 | auto mapOp = llvm::cast<mlir::omp::MapInfoOp>( |
| 521 | mapOperands[std::distance(mapSyms.begin(), parentIter)] |
| 522 | .getDefiningOp()); |
| 523 | |
| 524 | // NOTE: To maintain appropriate SSA ordering, we move the parent map |
| 525 | // which will now have references to its children after the last |
| 526 | // of its members to be generated. This is necessary when a user |
| 527 | // has defined a series of parent and children maps where the parent |
| 528 | // precedes the children. An alternative, may be to do |
| 529 | // delayed generation of map info operations from the clauses and |
| 530 | // organize them first before generation. Or to use the |
| 531 | // topologicalSort utility which will enforce a stronger SSA |
| 532 | // dominance ordering at the cost of efficiency/time. |
| 533 | mapOp->moveAfter(indices.second.memberMap.back()); |
| 534 | |
| 535 | for (mlir::omp::MapInfoOp memberMap : indices.second.memberMap) |
| 536 | mapOp.getMembersMutable().append(memberMap.getResult()); |
| 537 | |
| 538 | mapOp.setMembersIndexAttr(firOpBuilder.create2DI64ArrayAttr( |
| 539 | indices.second.memberPlacementIndices)); |
| 540 | } else { |
| 541 | // NOTE: We take the map type of the first child, this may not |
| 542 | // be the correct thing to do, however, we shall see. For the moment |
| 543 | // it allows this to work with enter and exit without causing MLIR |
| 544 | // verification issues. The more appropriate thing may be to take |
| 545 | // the "main" map type clause from the directive being used. |
| 546 | uint64_t mapType = indices.second.memberMap[0].getMapType(); |
| 547 | |
| 548 | llvm::SmallVector<mlir::Value> members; |
| 549 | members.reserve(indices.second.memberMap.size()); |
| 550 | for (mlir::omp::MapInfoOp memberMap : indices.second.memberMap) |
| 551 | members.push_back(memberMap.getResult()); |
| 552 | |
| 553 | // Create parent to emplace and bind members |
| 554 | llvm::SmallVector<mlir::Value> bounds; |
| 555 | std::stringstream asFortran; |
| 556 | fir::factory::AddrAndBoundsInfo info = |
| 557 | lower::gatherDataOperandAddrAndBounds<mlir::omp::MapBoundsOp, |
| 558 | mlir::omp::MapBoundsType>( |
| 559 | converter, firOpBuilder, semaCtx, converter.getFctCtx(), |
| 560 | *indices.first.sym(), indices.first.ref(), |
| 561 | converter.getCurrentLocation(), asFortran, bounds, |
| 562 | treatIndexAsSection); |
| 563 | |
| 564 | mlir::omp::MapInfoOp mapOp = createMapInfoOp( |
| 565 | firOpBuilder, info.rawInput.getLoc(), info.rawInput, |
| 566 | /*varPtrPtr=*/mlir::Value(), asFortran.str(), bounds, members, |
| 567 | firOpBuilder.create2DI64ArrayAttr( |
| 568 | indices.second.memberPlacementIndices), |
| 569 | mapType, mlir::omp::VariableCaptureKind::ByRef, |
| 570 | info.rawInput.getType(), |
| 571 | /*partialMap=*/true); |
| 572 | |
| 573 | mapOperands.push_back(mapOp); |
| 574 | mapSyms.push_back(indices.first.sym()); |
| 575 | } |
| 576 | } |
| 577 | } |
| 578 | |
| 579 | void lastprivateModifierNotSupported(const omp::clause::Lastprivate &lastp, |
| 580 | mlir::Location loc) { |
| 581 | using Lastprivate = omp::clause::Lastprivate; |
| 582 | auto &maybeMod = |
| 583 | std::get<std::optional<Lastprivate::LastprivateModifier>>(lastp.t); |
| 584 | if (maybeMod) { |
| 585 | assert(*maybeMod == Lastprivate::LastprivateModifier::Conditional && |
| 586 | "Unexpected lastprivate modifier" ); |
| 587 | TODO(loc, "lastprivate clause with CONDITIONAL modifier" ); |
| 588 | } |
| 589 | } |
| 590 | |
| 591 | static void convertLoopBounds(lower::AbstractConverter &converter, |
| 592 | mlir::Location loc, |
| 593 | mlir::omp::LoopRelatedClauseOps &result, |
| 594 | std::size_t loopVarTypeSize) { |
| 595 | fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); |
| 596 | // The types of lower bound, upper bound, and step are converted into the |
| 597 | // type of the loop variable if necessary. |
| 598 | mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize); |
| 599 | for (unsigned it = 0; it < (unsigned)result.loopLowerBounds.size(); it++) { |
| 600 | result.loopLowerBounds[it] = firOpBuilder.createConvert( |
| 601 | loc, loopVarType, result.loopLowerBounds[it]); |
| 602 | result.loopUpperBounds[it] = firOpBuilder.createConvert( |
| 603 | loc, loopVarType, result.loopUpperBounds[it]); |
| 604 | result.loopSteps[it] = |
| 605 | firOpBuilder.createConvert(loc, loopVarType, result.loopSteps[it]); |
| 606 | } |
| 607 | } |
| 608 | |
| 609 | bool collectLoopRelatedInfo( |
| 610 | lower::AbstractConverter &converter, mlir::Location currentLocation, |
| 611 | lower::pft::Evaluation &eval, const omp::List<omp::Clause> &clauses, |
| 612 | mlir::omp::LoopRelatedClauseOps &result, |
| 613 | llvm::SmallVectorImpl<const semantics::Symbol *> &iv) { |
| 614 | bool found = false; |
| 615 | fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); |
| 616 | |
| 617 | // Collect the loops to collapse. |
| 618 | lower::pft::Evaluation *doConstructEval = &eval.getFirstNestedEvaluation(); |
| 619 | if (doConstructEval->getIf<parser::DoConstruct>()->IsDoConcurrent()) { |
| 620 | TODO(currentLocation, "Do Concurrent in Worksharing loop construct" ); |
| 621 | } |
| 622 | |
| 623 | std::int64_t collapseValue = 1l; |
| 624 | if (auto *clause = |
| 625 | ClauseFinder::findUniqueClause<omp::clause::Collapse>(clauses)) { |
| 626 | collapseValue = evaluate::ToInt64(clause->v).value(); |
| 627 | found = true; |
| 628 | } |
| 629 | |
| 630 | std::size_t loopVarTypeSize = 0; |
| 631 | do { |
| 632 | lower::pft::Evaluation *doLoop = |
| 633 | &doConstructEval->getFirstNestedEvaluation(); |
| 634 | auto *doStmt = doLoop->getIf<parser::NonLabelDoStmt>(); |
| 635 | assert(doStmt && "Expected do loop to be in the nested evaluation" ); |
| 636 | const auto &loopControl = |
| 637 | std::get<std::optional<parser::LoopControl>>(doStmt->t); |
| 638 | const parser::LoopControl::Bounds *bounds = |
| 639 | std::get_if<parser::LoopControl::Bounds>(&loopControl->u); |
| 640 | assert(bounds && "Expected bounds for worksharing do loop" ); |
| 641 | lower::StatementContext stmtCtx; |
| 642 | result.loopLowerBounds.push_back(fir::getBase( |
| 643 | converter.genExprValue(*semantics::GetExpr(bounds->lower), stmtCtx))); |
| 644 | result.loopUpperBounds.push_back(fir::getBase( |
| 645 | converter.genExprValue(*semantics::GetExpr(bounds->upper), stmtCtx))); |
| 646 | if (bounds->step) { |
| 647 | result.loopSteps.push_back(fir::getBase( |
| 648 | converter.genExprValue(*semantics::GetExpr(bounds->step), stmtCtx))); |
| 649 | } else { // If `step` is not present, assume it as `1`. |
| 650 | result.loopSteps.push_back(firOpBuilder.createIntegerConstant( |
| 651 | currentLocation, firOpBuilder.getIntegerType(32), 1)); |
| 652 | } |
| 653 | iv.push_back(Elt: bounds->name.thing.symbol); |
| 654 | loopVarTypeSize = std::max(loopVarTypeSize, |
| 655 | bounds->name.thing.symbol->GetUltimate().size()); |
| 656 | collapseValue--; |
| 657 | doConstructEval = |
| 658 | &*std::next(doConstructEval->getNestedEvaluations().begin()); |
| 659 | } while (collapseValue > 0); |
| 660 | |
| 661 | convertLoopBounds(converter, currentLocation, result, loopVarTypeSize); |
| 662 | |
| 663 | return found; |
| 664 | } |
| 665 | } // namespace omp |
| 666 | } // namespace lower |
| 667 | } // namespace Fortran |
| 668 | |