//===-- Utils.cpp -----------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
//
//===----------------------------------------------------------------------===//

#include "Utils.h"

#include "Clauses.h"

#include "ClauseFinder.h"
#include <flang/Lower/AbstractConverter.h>
#include <flang/Lower/ConvertType.h>
#include <flang/Lower/DirectivesCommon.h>
#include <flang/Lower/PFTBuilder.h>
#include <flang/Optimizer/Builder/FIRBuilder.h>
#include <flang/Optimizer/Builder/Todo.h>
#include <flang/Parser/parse-tree.h>
#include <flang/Parser/tools.h>
#include <flang/Semantics/tools.h>
#include <llvm/Support/CommandLine.h>

#include <iterator>

llvm::cl::opt<bool> treatIndexAsSection(
    "openmp-treat-index-as-section",
    llvm::cl::desc("In the OpenMP data clauses treat `a(N)` as `a(N:N)`."),
    llvm::cl::init(true));

namespace Fortran {
namespace lower {
namespace omp {
int64_t getCollapseValue(const List<Clause> &clauses) {
  auto iter = llvm::find_if(clauses, [](const Clause &clause) {
    return clause.id == llvm::omp::Clause::OMPC_collapse;
  });
  if (iter != clauses.end()) {
    const auto &collapse = std::get<clause::Collapse>(iter->u);
    return evaluate::ToInt64(collapse.v).value();
  }
  return 1;
}
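
// Populate `operands` with the addresses of the symbols in `objects`. For
// host-associated symbols, the host symbol's address is used and its binding
// is copied to the associated symbol.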
void genObjectList(const ObjectList &objects,
                   lower::AbstractConverter &converter,
                   llvm::SmallVectorImpl<mlir::Value> &operands) {
  for (const Object &object : objects) {
    const semantics::Symbol *sym = object.sym();
    assert(sym && "Expected Symbol");
    if (mlir::Value variable = converter.getSymbolAddress(*sym)) {
      operands.push_back(variable);
    } else if (const auto *details =
                   sym->detailsIf<semantics::HostAssocDetails>()) {
      operands.push_back(converter.getSymbolAddress(details->symbol()));
      converter.copySymbolBinding(details->symbol(), *sym);
    }
  }
}
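
// Return a 32-bit or 64-bit integer type suitable for an OpenMP loop
// variable, given the variable's size in bytes. Sizes below 32 bits are
// widened to 32 bits; sizes above 64 bits are narrowed to 64 bits with a
// warning.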
mlir::Type getLoopVarType(lower::AbstractConverter &converter,
                          std::size_t loopVarTypeSize) {
  // OpenMP runtime requires 32-bit or 64-bit loop variables.
  loopVarTypeSize = loopVarTypeSize * 8;
  if (loopVarTypeSize < 32) {
    loopVarTypeSize = 32;
  } else if (loopVarTypeSize > 64) {
    loopVarTypeSize = 64;
    mlir::emitWarning(converter.getCurrentLocation(),
                      "OpenMP loop iteration variable cannot have more than 64 "
                      "bits size and will be narrowed into 64 bits.");
  }
  assert((loopVarTypeSize == 32 || loopVarTypeSize == 64) &&
         "OpenMP loop iteration variable size must be transformed into 32-bit "
         "or 64-bit");
  return converter.getFirOpBuilder().getIntegerType(loopVarTypeSize);
}
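
// Return the symbol of the iteration variable of the DO construct in `eval`,
// or nullptr if the evaluation is not a DO loop with bounds-style loop
// control.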
semantics::Symbol *
getIterationVariableSymbol(const lower::pft::Evaluation &eval) {
  return eval.visit(common::visitors{
      [&](const parser::DoConstruct &doLoop) {
        if (const auto &maybeCtrl = doLoop.GetLoopControl()) {
          using LoopControl = parser::LoopControl;
          if (auto *bounds = std::get_if<LoopControl::Bounds>(&maybeCtrl->u)) {
            static_assert(std::is_same_v<decltype(bounds->name),
                                         parser::Scalar<parser::Name>>);
            return bounds->name.thing.symbol;
          }
        }
        return static_cast<semantics::Symbol *>(nullptr);
      },
      [](auto &&) { return static_cast<semantics::Symbol *>(nullptr); },
  });
}
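
// Append a (capture-clause, symbol) pair to `symbolAndClause` for each object
// in `objects`.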
void gatherFuncAndVarSyms(
    const ObjectList &objects, mlir::omp::DeclareTargetCaptureClause clause,
    llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) {
  for (const Object &object : objects)
    symbolAndClause.emplace_back(clause, *object.sym());
}
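
// Create an omp.map.info operation for the given base address. Boxed values
// are first unwrapped to the address stored in the box, and sequence types
// with dynamic extents are reduced to their element type since the dimensions
// are recovered from the bounds operands.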
mlir::omp::MapInfoOp
createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
                mlir::Value baseAddr, mlir::Value varPtrPtr,
                llvm::StringRef name, llvm::ArrayRef<mlir::Value> bounds,
                llvm::ArrayRef<mlir::Value> members,
                mlir::ArrayAttr membersIndex, uint64_t mapType,
                mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy,
                bool partialMap, mlir::FlatSymbolRefAttr mapperId) {
  if (auto boxTy = llvm::dyn_cast<fir::BaseBoxType>(baseAddr.getType())) {
    baseAddr = builder.create<fir::BoxAddrOp>(loc, baseAddr);
    retTy = baseAddr.getType();
  }

  mlir::TypeAttr varType = mlir::TypeAttr::get(
      llvm::cast<mlir::omp::PointerLikeType>(retTy).getElementType());

  // For types with unknown extents such as <2x?xi32> we discard the incomplete
  // type info and only retain the base type. The correct dimensions are later
  // recovered through the bounds info.
  if (auto seqType = llvm::dyn_cast<fir::SequenceType>(varType.getValue()))
    if (seqType.hasDynamicExtents())
      varType = mlir::TypeAttr::get(seqType.getEleTy());

  mlir::omp::MapInfoOp op = builder.create<mlir::omp::MapInfoOp>(
      loc, retTy, baseAddr, varType,
      builder.getIntegerAttr(builder.getIntegerType(64, false), mapType),
      builder.getAttr<mlir::omp::VariableCaptureKindAttr>(mapCaptureType),
      varPtrPtr, members, membersIndex, bounds, mapperId,
      builder.getStringAttr(name), builder.getBoolAttr(partialMap));
  return op;
}
// This function gathers the individual omp::Objects that make up a larger
// omp::Object symbol.
//
// For example, given the larger symbol "parent%child%member", this function
// breaks it up into its constituent components ("parent", "child", "member")
// so that each individual component can be accessed and introspected. Note
// that the function breaks the symbol up from RHS to LHS ("member" to
// "parent") and the result is then reversed, so the returned omp::ObjectList
// is ordered LHS to RHS, with "parent" at the beginning.
omp::ObjectList gatherObjectsOf(omp::Object derivedTypeMember,
                                semantics::SemanticsContext &semaCtx) {
  omp::ObjectList objList;
  std::optional<omp::Object> baseObj = derivedTypeMember;
  while (baseObj.has_value()) {
    objList.push_back(baseObj.value());
    baseObj = getBaseObject(baseObj.value(), semaCtx);
  }
  return omp::ObjectList{llvm::reverse(objList)};
}
// This function generates a series of indices from a provided omp::Object
// that resolves to an ArrayRef symbol. For example, given "array(2,3,4)" it
// generates the indices "[1][2][3]", offsetting each subscript by -1 to
// account for Fortran's 1-based indexing.
//
// These indices can then be provided to a coordinate operation or other
// GEP-like operation to access the relevant positional member of the
// array.
//
// Note that the function currently only supports integer subscripts, not
// triplets, i.e. array(1:2:3).
static void generateArrayIndices(lower::AbstractConverter &converter,
                                 fir::FirOpBuilder &firOpBuilder,
                                 lower::StatementContext &stmtCtx,
                                 mlir::Location clauseLocation,
                                 llvm::SmallVectorImpl<mlir::Value> &indices,
                                 omp::Object object) {
  auto maybeRef = evaluate::ExtractDataRef(*object.ref());
  if (!maybeRef)
    return;

  auto *arr = std::get_if<evaluate::ArrayRef>(&maybeRef->u);
  if (!arr)
    return;

  for (auto v : arr->subscript()) {
    if (std::holds_alternative<Triplet>(v.u))
      TODO(clauseLocation, "Triplet indexing in map clause is unsupported");

    auto expr = std::get<Fortran::evaluate::IndirectSubscriptIntegerExpr>(v.u);
    mlir::Value subscript =
        fir::getBase(converter.genExprValue(toEvExpr(expr.value()), stmtCtx));
    mlir::Value one = firOpBuilder.createIntegerConstant(
        clauseLocation, firOpBuilder.getIndexType(), 1);
    subscript = firOpBuilder.createConvert(
        clauseLocation, firOpBuilder.getIndexType(), subscript);
    indices.push_back(firOpBuilder.create<mlir::arith::SubIOp>(clauseLocation,
                                                               subscript, one));
  }
}
/// When mapping members of derived types, there is a chance that one of the
/// members along the way to a mapped member is a descriptor. In that case we
/// have to make sure we generate a map for those members along the way,
/// otherwise we will be missing a chunk of data required to actually map the
/// member type to device. This function generates these maps and the
/// appropriate data accesses required to generate them. It will avoid
/// creating duplicate maps, as duplicates are just as bad as unmapped
/// descriptor data in a lot of cases for the runtime (and unnecessary data
/// movement should be avoided where possible).
///
/// As an example for the following mapping:
///
/// type :: vertexes
///   integer(4), allocatable :: vertexx(:)
///   integer(4), allocatable :: vertexy(:)
/// end type vertexes
///
/// type :: dtype
///   real(4) :: i
///   type(vertexes), allocatable :: vertexes(:)
/// end type dtype
///
/// type(dtype), allocatable :: alloca_dtype
///
/// !$omp target map(tofrom: alloca_dtype%vertexes(N1)%vertexx)
///
/// The below HLFIR/FIR is generated (trimmed for conciseness):
///
/// On the first iteration we index into the record type alloca_dtype to
/// access "vertexes"; we then generate a map for this descriptor alongside
/// bounds to indicate we only need the one member, rather than the whole
/// array block in this case (in theory we could map its entirety at the cost
/// of data transfer bandwidth).
///
/// %13:2 = hlfir.declare ... "alloca_dtype" ...
/// %39 = fir.load %13#0 : ...
/// %40 = fir.coordinate_of %39, %c1 : ...
/// %51 = omp.map.info var_ptr(%40 : ...) map_clauses(to) capture(ByRef) ...
/// %52 = fir.load %40 : ...
///
/// Second iteration, generating access to "vertexes(N1)" utilising the N1
/// index:
/// %53 = load N1 ...
/// %54 = fir.convert %53 : (i32) -> i64
/// %55 = fir.convert %54 : (i64) -> index
/// %56 = arith.subi %55, %c1 : index
/// %57 = fir.coordinate_of %52, %56 : ...
///
/// Still in the second iteration we access the allocatable member "vertexx";
/// we return %58 from the function and provide it to the final and "main"
/// map of processMap (generated by the record type segment of the function
/// below). If this were not the final symbol in the list, i.e. if we had
/// accessed a member below vertexx, we would have generated the map below as
/// we did in the first iteration and then continued to generate further
/// coordinates to access further components as required.
///
/// %58 = fir.coordinate_of %57, %c0 : ...
/// %61 = omp.map.info var_ptr(%58 : ...) map_clauses(to) capture(ByRef) ...
///
/// Parent mapping containing the previously generated mapped members,
/// generated at a later step but shown here to showcase the "end" result:
///
/// omp.map.info var_ptr(%13#1 : ...) map_clauses(to) capture(ByRef)
///   members(%50, %61 : [0, 1, 0], [0, 1, 0] : ...
///
/// \param objectList - The list of omp::Object symbol data for each parent
///   to the mapped member (also includes the mapped member), generated via
///   gatherObjectsOf.
/// \param indices - List of index data associated with the mapped member
///   symbol, which identifies the placement of the member in its parent;
///   this helps generate the appropriate member accesses. These indices
///   can be generated via generateMemberPlacementIndices.
/// \param asFortran - A string generated from the mapped variable to be
///   associated with the main map, generally (but not restricted to)
///   generated via gatherDataOperandAddrAndBounds or other
///   DirectivesCommon.h utilities.
/// \param mapTypeBits - The map flags that will be associated with the
///   generated maps, minus alterations of the TO and FROM bits for the
///   intermediate components to prevent accidental overwriting on device
///   write back.
mlir::Value createParentSymAndGenIntermediateMaps(
    mlir::Location clauseLocation, lower::AbstractConverter &converter,
    semantics::SemanticsContext &semaCtx, lower::StatementContext &stmtCtx,
    omp::ObjectList &objectList, llvm::SmallVectorImpl<int64_t> &indices,
    OmpMapParentAndMemberData &parentMemberIndices, llvm::StringRef asFortran,
    llvm::omp::OpenMPOffloadMappingFlags mapTypeBits) {
  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

  /// Checks if an omp::Object is an array expression with a subscript, e.g.
  /// array(1,2).
  auto isArrayExprWithSubscript = [](omp::Object obj) {
    if (auto maybeRef = evaluate::ExtractDataRef(obj.ref())) {
      evaluate::DataRef ref = *maybeRef;
      if (auto *arr = std::get_if<evaluate::ArrayRef>(&ref.u))
        return !arr->subscript().empty();
    }
    return false;
  };
  // Generate the access to the original parent base address.
  fir::factory::AddrAndBoundsInfo parentBaseAddr =
      lower::getDataOperandBaseAddr(converter, firOpBuilder,
                                    *objectList[0].sym(), clauseLocation);
  mlir::Value curValue = parentBaseAddr.addr;

  // Iterate over all objects in the objectList; this should consist of all
  // record types between the parent and the member being mapped (including
  // the parent). The object list may also contain array objects, which can
  // occur when specifying bounds or a specific element access within a
  // member map; we skip these.
  size_t currentIndicesIdx = 0;
  for (size_t i = 0; i < objectList.size(); ++i) {
    // If we encounter a sequence type, i.e. an array, we must generate the
    // correct coordinate operation to index into the array to proceed
    // further; this is currently only relevant in cases where we encounter
    // subscripts.
    //
    // For example, in the following case:
    //
    //   map(tofrom: array_dtype(4)%internal_dtypes(3)%float_elements(4))
    //
    // we must generate coordinate operation accesses for each subscript
    // we encounter.
    if (fir::SequenceType arrType = mlir::dyn_cast<fir::SequenceType>(
            fir::unwrapPassByRefType(curValue.getType()))) {
      if (isArrayExprWithSubscript(objectList[i])) {
        llvm::SmallVector<mlir::Value> subscriptIndices;
        generateArrayIndices(converter, firOpBuilder, stmtCtx, clauseLocation,
                             subscriptIndices, objectList[i]);
        assert(!subscriptIndices.empty() &&
               "missing expected indices for map clause");
        curValue = firOpBuilder.create<fir::CoordinateOp>(
            clauseLocation, firOpBuilder.getRefType(arrType.getEleTy()),
            curValue, subscriptIndices);
      }
    }
    // If we encounter a record type, we must access the subsequent member
    // by indexing into it and creating a coordinate operation to do so. We
    // utilise the index information generated previously and passed in to
    // work out the correct member to access and the corresponding member
    // type.
    if (fir::RecordType recordType = mlir::dyn_cast<fir::RecordType>(
            fir::unwrapPassByRefType(curValue.getType()))) {
      fir::IntOrValue idxConst = mlir::IntegerAttr::get(
          firOpBuilder.getI32Type(), indices[currentIndicesIdx]);
      mlir::Type memberTy = recordType.getType(indices[currentIndicesIdx]);
      curValue = firOpBuilder.create<fir::CoordinateOp>(
          clauseLocation, firOpBuilder.getRefType(memberTy), curValue,
          llvm::SmallVector<fir::IntOrValue, 1>{idxConst});

      // If we're the final member, the map will be generated by the
      // processMap call that invoked this function.
      if (currentIndicesIdx == indices.size() - 1)
        break;
      // Skip the map and the subsequent load if we're not a type with a
      // descriptor, such as a pointer or allocatable. If we're not a type
      // with a descriptor then we have no need to generate an intermediate
      // map for it, as we only need to generate a map if a member is a
      // descriptor type (and thus obscures the members it contains via a
      // pointer whose data needs to be mapped).
      if (!fir::isTypeWithDescriptor(memberTy)) {
        currentIndicesIdx++;
        continue;
      }
      llvm::SmallVector<int64_t> interimIndices(
          indices.begin(), std::next(indices.begin(), currentIndicesIdx + 1));
      // Verify we haven't already created a map for this particular member by
      // checking the list of members already mapped for the current parent,
      // stored in the parentMemberIndices structure.
      if (!parentMemberIndices.isDuplicateMemberMapInfo(interimIndices)) {
        // Generate bounds operations using the standard lowering utility;
        // unfortunately this currently does a bit more than just generate
        // bounds, and we discard the other bits. It may be useful to extend
        // the utility to just provide bounds in the future.
        llvm::SmallVector<mlir::Value> interimBounds;
        if (i + 1 < objectList.size() &&
            objectList[i + 1].sym()->IsObjectArray()) {
          std::stringstream interimFortran;
          Fortran::lower::gatherDataOperandAddrAndBounds<
              mlir::omp::MapBoundsOp, mlir::omp::MapBoundsType>(
              converter, converter.getFirOpBuilder(), semaCtx,
              converter.getFctCtx(), *objectList[i + 1].sym(),
              objectList[i + 1].ref(), clauseLocation, interimFortran,
              interimBounds, treatIndexAsSection);
        }
        // Remove all map-type bits (e.g. TO, FROM, etc.) from the intermediate
        // allocatable maps, as we simply wish to alloc or release them. It may
        // be safer to just pass OMP_MAP_NONE as the map type, but we may still
        // need some of the other map types the mapped member utilises, so for
        // now it's good to keep an eye on this.
        llvm::omp::OpenMPOffloadMappingFlags interimMapType = mapTypeBits;
        interimMapType &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
        interimMapType &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
        interimMapType &=
            ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
        // Create a map for the intermediate member and insert it and its
        // indices into the parentMemberIndices list to track it.
        mlir::omp::MapInfoOp mapOp = createMapInfoOp(
            firOpBuilder, clauseLocation, curValue,
            /*varPtrPtr=*/mlir::Value{}, asFortran,
            /*bounds=*/interimBounds,
            /*members=*/{},
            /*membersIndex=*/mlir::ArrayAttr{},
            static_cast<
                std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
                interimMapType),
            mlir::omp::VariableCaptureKind::ByRef, curValue.getType());

        parentMemberIndices.memberPlacementIndices.push_back(interimIndices);
        parentMemberIndices.memberMap.push_back(mapOp);
      }
      // Load the currently accessed member, so we can continue to access
      // further segments.
      curValue = firOpBuilder.create<fir::LoadOp>(clauseLocation, curValue);
      currentIndicesIdx++;
    }
  }

  return curValue;
}
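
// Return the zero-based position of the component `componentSym` within its
// parent derived type's component list, or -1 if it is not found.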
static int64_t
getComponentPlacementInParent(const semantics::Symbol *componentSym) {
  const auto *derived = componentSym->owner()
                            .derivedTypeSpec()
                            ->typeSymbol()
                            .detailsIf<semantics::DerivedTypeDetails>();
  assert(derived &&
         "expected derived type details when processing component symbol");
  for (auto [placement, name] : llvm::enumerate(derived->componentNames()))
    if (name == componentSym->name())
      return placement;
  return -1;
}
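
// Walk up the designator of `object` and return the first object whose data
// reference is a derived-type component access; return std::nullopt if there
// is none.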
static std::optional<Object>
getComponentObject(std::optional<Object> object,
                   semantics::SemanticsContext &semaCtx) {
  if (!object)
    return std::nullopt;

  auto ref = evaluate::ExtractDataRef(object.value().ref());
  if (!ref)
    return std::nullopt;

  if (std::holds_alternative<evaluate::Component>(ref->u))
    return object;

  auto baseObj = getBaseObject(object.value(), semaCtx);
  if (!baseObj)
    return std::nullopt;

  return getComponentObject(baseObj.value(), semaCtx);
}
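
// Fill `indices` with the component placement indices of `object` within its
// enclosing derived type(s), ordered from the outermost parent down to the
// member itself.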
void generateMemberPlacementIndices(const Object &object,
                                    llvm::SmallVectorImpl<int64_t> &indices,
                                    semantics::SemanticsContext &semaCtx) {
  assert(indices.empty() && "indices vector passed to "
                            "generateMemberPlacementIndices should be empty");
  auto compObj = getComponentObject(object, semaCtx);

  while (compObj) {
    int64_t index = getComponentPlacementInParent(compObj->sym());
    assert(
        index >= 0 &&
        "unexpected index value returned from getComponentPlacementInParent");
    indices.push_back(index);
    compObj =
        getComponentObject(getBaseObject(compObj.value(), semaCtx), semaCtx);
  }

  indices = llvm::SmallVector<int64_t>{llvm::reverse(indices)};
}
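
// Record a member map against its parent: compute the member's placement
// indices and store them alongside the member's map operation.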
void OmpMapParentAndMemberData::addChildIndexAndMapToParent(
    const omp::Object &object, mlir::omp::MapInfoOp &mapOp,
    semantics::SemanticsContext &semaCtx) {
  llvm::SmallVector<int64_t> indices;
  generateMemberPlacementIndices(object, indices, semaCtx);
  memberPlacementIndices.push_back(indices);
  memberMap.push_back(mapOp);
}
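
// Return true if `object` or any of its parent objects is an allocatable or
// an object pointer.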
bool isMemberOrParentAllocatableOrPointer(
    const Object &object, semantics::SemanticsContext &semaCtx) {
  if (semantics::IsAllocatableOrObjectPointer(object.sym()))
    return true;

  auto compObj = getBaseObject(object, semaCtx);
  while (compObj) {
    if (semantics::IsAllocatableOrObjectPointer(compObj.value().sym()))
      return true;
    compObj = getBaseObject(compObj.value(), semaCtx);
  }

  return false;
}
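
// For each parent recorded in `parentMemberIndices`, attach the member maps
// to the parent's omp.map.info operation. If the parent was mapped explicitly
// its existing map is updated (and moved after the last member map to keep
// SSA ordering); otherwise a new parent map with partialMap set is created
// and appended to `mapOperands` and `mapSyms`.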
void insertChildMapInfoIntoParent(
    lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
    lower::StatementContext &stmtCtx,
    std::map<Object, OmpMapParentAndMemberData> &parentMemberIndices,
    llvm::SmallVectorImpl<mlir::Value> &mapOperands,
    llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) {
  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
  for (auto indices : parentMemberIndices) {
    auto *parentIter =
        llvm::find_if(mapSyms, [&indices](const semantics::Symbol *v) {
          return v == indices.first.sym();
        });
    if (parentIter != mapSyms.end()) {
      auto mapOp = llvm::cast<mlir::omp::MapInfoOp>(
          mapOperands[std::distance(mapSyms.begin(), parentIter)]
              .getDefiningOp());

      // NOTE: To maintain appropriate SSA ordering, we move the parent map,
      // which will now have references to its children, after the last of
      // its members to be generated. This is necessary when a user has
      // defined a series of parent and child maps where the parent precedes
      // the children. An alternative may be to delay generation of the map
      // info operations from the clauses and organize them before generation,
      // or to use the topologicalSort utility, which would enforce a stronger
      // SSA dominance ordering at the cost of efficiency/time.
      mapOp->moveAfter(indices.second.memberMap.back());

      for (mlir::omp::MapInfoOp memberMap : indices.second.memberMap)
        mapOp.getMembersMutable().append(memberMap.getResult());

      mapOp.setMembersIndexAttr(firOpBuilder.create2DI64ArrayAttr(
          indices.second.memberPlacementIndices));
    } else {
      // NOTE: We take the map type of the first child; this may not be the
      // correct thing to do, however, we shall see. For the moment it allows
      // this to work with enter and exit without causing MLIR verification
      // issues. The more appropriate thing may be to take the "main" map
      // type clause from the directive being used.
      uint64_t mapType = indices.second.memberMap[0].getMapType();

      llvm::SmallVector<mlir::Value> members;
      members.reserve(indices.second.memberMap.size());
      for (mlir::omp::MapInfoOp memberMap : indices.second.memberMap)
        members.push_back(memberMap.getResult());

      // Create parent to emplace and bind members.
      llvm::SmallVector<mlir::Value> bounds;
      std::stringstream asFortran;
      fir::factory::AddrAndBoundsInfo info =
          lower::gatherDataOperandAddrAndBounds<mlir::omp::MapBoundsOp,
                                                mlir::omp::MapBoundsType>(
              converter, firOpBuilder, semaCtx, converter.getFctCtx(),
              *indices.first.sym(), indices.first.ref(),
              converter.getCurrentLocation(), asFortran, bounds,
              treatIndexAsSection);

      mlir::omp::MapInfoOp mapOp = createMapInfoOp(
          firOpBuilder, info.rawInput.getLoc(), info.rawInput,
          /*varPtrPtr=*/mlir::Value(), asFortran.str(), bounds, members,
          firOpBuilder.create2DI64ArrayAttr(
              indices.second.memberPlacementIndices),
          mapType, mlir::omp::VariableCaptureKind::ByRef,
          info.rawInput.getType(),
          /*partialMap=*/true);

      mapOperands.push_back(mapOp);
      mapSyms.push_back(indices.first.sym());
    }
  }
}
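
// Emit a TODO if the lastprivate clause carries the CONDITIONAL modifier,
// which is not yet supported.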
void lastprivateModifierNotSupported(const omp::clause::Lastprivate &lastp,
                                     mlir::Location loc) {
  using Lastprivate = omp::clause::Lastprivate;
  auto &maybeMod =
      std::get<std::optional<Lastprivate::LastprivateModifier>>(lastp.t);
  if (maybeMod) {
    assert(*maybeMod == Lastprivate::LastprivateModifier::Conditional &&
           "Unexpected lastprivate modifier");
    TODO(loc, "lastprivate clause with CONDITIONAL modifier");
  }
}
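
// Convert the collected loop lower bounds, upper bounds and steps to the
// integer type that will be used for the loop variable.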
static void convertLoopBounds(lower::AbstractConverter &converter,
                              mlir::Location loc,
                              mlir::omp::LoopRelatedClauseOps &result,
                              std::size_t loopVarTypeSize) {
  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
  // The types of lower bound, upper bound, and step are converted into the
  // type of the loop variable if necessary.
  mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
  for (unsigned it = 0; it < (unsigned)result.loopLowerBounds.size(); it++) {
    result.loopLowerBounds[it] = firOpBuilder.createConvert(
        loc, loopVarType, result.loopLowerBounds[it]);
    result.loopUpperBounds[it] = firOpBuilder.createConvert(
        loc, loopVarType, result.loopUpperBounds[it]);
    result.loopSteps[it] =
        firOpBuilder.createConvert(loc, loopVarType, result.loopSteps[it]);
  }
}
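
// Collect the lower bounds, upper bounds, steps and iteration variable
// symbols of the loop nest under `eval`, walking as many nested loops as the
// `collapse` clause requires (one if the clause is absent). Returns true if a
// collapse clause was found.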
bool collectLoopRelatedInfo(
    lower::AbstractConverter &converter, mlir::Location currentLocation,
    lower::pft::Evaluation &eval, const omp::List<omp::Clause> &clauses,
    mlir::omp::LoopRelatedClauseOps &result,
    llvm::SmallVectorImpl<const semantics::Symbol *> &iv) {
  bool found = false;
  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

  // Collect the loops to collapse.
  lower::pft::Evaluation *doConstructEval = &eval.getFirstNestedEvaluation();
  if (doConstructEval->getIf<parser::DoConstruct>()->IsDoConcurrent()) {
    TODO(currentLocation, "Do Concurrent in Worksharing loop construct");
  }

  std::int64_t collapseValue = 1l;
  if (auto *clause =
          ClauseFinder::findUniqueClause<omp::clause::Collapse>(clauses)) {
    collapseValue = evaluate::ToInt64(clause->v).value();
    found = true;
  }

  std::size_t loopVarTypeSize = 0;
  do {
    lower::pft::Evaluation *doLoop =
        &doConstructEval->getFirstNestedEvaluation();
    auto *doStmt = doLoop->getIf<parser::NonLabelDoStmt>();
    assert(doStmt && "Expected do loop to be in the nested evaluation");
    const auto &loopControl =
        std::get<std::optional<parser::LoopControl>>(doStmt->t);
    const parser::LoopControl::Bounds *bounds =
        std::get_if<parser::LoopControl::Bounds>(&loopControl->u);
    assert(bounds && "Expected bounds for worksharing do loop");
    lower::StatementContext stmtCtx;
    result.loopLowerBounds.push_back(fir::getBase(
        converter.genExprValue(*semantics::GetExpr(bounds->lower), stmtCtx)));
    result.loopUpperBounds.push_back(fir::getBase(
        converter.genExprValue(*semantics::GetExpr(bounds->upper), stmtCtx)));
    if (bounds->step) {
      result.loopSteps.push_back(fir::getBase(
          converter.genExprValue(*semantics::GetExpr(bounds->step), stmtCtx)));
    } else { // If `step` is not present, assume it to be 1.
      result.loopSteps.push_back(firOpBuilder.createIntegerConstant(
          currentLocation, firOpBuilder.getIntegerType(32), 1));
    }
    iv.push_back(bounds->name.thing.symbol);
    loopVarTypeSize = std::max(loopVarTypeSize,
                               bounds->name.thing.symbol->GetUltimate().size());
    collapseValue--;
    doConstructEval =
        &*std::next(doConstructEval->getNestedEvaluations().begin());
  } while (collapseValue > 0);

  convertLoopBounds(converter, currentLocation, result, loopVarTypeSize);

  return found;
}
} // namespace omp
} // namespace lower
} // namespace Fortran