1//===-- ClauseProcessor.cpp -------------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10//
11//===----------------------------------------------------------------------===//
12
13#include "ClauseProcessor.h"
14#include "Clauses.h"
15
16#include "flang/Lower/PFTBuilder.h"
17#include "flang/Parser/tools.h"
18#include "flang/Semantics/tools.h"
19
20namespace Fortran {
21namespace lower {
22namespace omp {
23
24/// Check for unsupported map operand types.
25static void checkMapType(mlir::Location location, mlir::Type type) {
26 if (auto refType = type.dyn_cast<fir::ReferenceType>())
27 type = refType.getElementType();
28 if (auto boxType = type.dyn_cast_or_null<fir::BoxType>())
29 if (!boxType.getElementType().isa<fir::PointerType>())
30 TODO(location, "OMPD_target_data MapOperand BoxType");
31}
32
33static mlir::omp::ScheduleModifier
34translateScheduleModifier(const omp::clause::Schedule::OrderingModifier &m) {
35 switch (m) {
36 case omp::clause::Schedule::OrderingModifier::Monotonic:
37 return mlir::omp::ScheduleModifier::monotonic;
38 case omp::clause::Schedule::OrderingModifier::Nonmonotonic:
39 return mlir::omp::ScheduleModifier::nonmonotonic;
40 }
41 return mlir::omp::ScheduleModifier::none;
42}
43
44static mlir::omp::ScheduleModifier
45getScheduleModifier(const omp::clause::Schedule &clause) {
46 using Schedule = omp::clause::Schedule;
47 const auto &modifier =
48 std::get<std::optional<Schedule::OrderingModifier>>(clause.t);
49 if (modifier)
50 return translateScheduleModifier(*modifier);
51 return mlir::omp::ScheduleModifier::none;
52}
53
54static mlir::omp::ScheduleModifier
55getSimdModifier(const omp::clause::Schedule &clause) {
56 using Schedule = omp::clause::Schedule;
57 const auto &modifier =
58 std::get<std::optional<Schedule::ChunkModifier>>(clause.t);
59 if (modifier && *modifier == Schedule::ChunkModifier::Simd)
60 return mlir::omp::ScheduleModifier::simd;
61 return mlir::omp::ScheduleModifier::none;
62}
63
64static void
65genAllocateClause(Fortran::lower::AbstractConverter &converter,
66 const omp::clause::Allocate &clause,
67 llvm::SmallVectorImpl<mlir::Value> &allocatorOperands,
68 llvm::SmallVectorImpl<mlir::Value> &allocateOperands) {
69 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
70 mlir::Location currentLocation = converter.getCurrentLocation();
71 Fortran::lower::StatementContext stmtCtx;
72
73 auto &objects = std::get<omp::ObjectList>(clause.t);
74
75 using Allocate = omp::clause::Allocate;
76 // ALIGN in this context is unimplemented
77 if (std::get<std::optional<Allocate::AlignModifier>>(clause.t))
78 TODO(currentLocation, "OmpAllocateClause ALIGN modifier");
79
80 // Check if allocate clause has allocator specified. If so, add it
81 // to list of allocators, otherwise, add default allocator to
82 // list of allocators.
83 using SimpleModifier = Allocate::AllocatorSimpleModifier;
84 using ComplexModifier = Allocate::AllocatorComplexModifier;
85 if (auto &mod = std::get<std::optional<SimpleModifier>>(clause.t)) {
86 mlir::Value operand = fir::getBase(converter.genExprValue(*mod, stmtCtx));
87 allocatorOperands.append(objects.size(), operand);
88 } else if (auto &mod = std::get<std::optional<ComplexModifier>>(clause.t)) {
89 mlir::Value operand = fir::getBase(converter.genExprValue(mod->v, stmtCtx));
90 allocatorOperands.append(objects.size(), operand);
91 } else {
92 mlir::Value operand = firOpBuilder.createIntegerConstant(
93 currentLocation, firOpBuilder.getI32Type(), 1);
94 allocatorOperands.append(objects.size(), operand);
95 }
96
97 genObjectList(objects, converter, allocateOperands);
98}
99
100static mlir::omp::ClauseProcBindKindAttr
101genProcBindKindAttr(fir::FirOpBuilder &firOpBuilder,
102 const omp::clause::ProcBind &clause) {
103 mlir::omp::ClauseProcBindKind procBindKind;
104 switch (clause.v) {
105 case omp::clause::ProcBind::AffinityPolicy::Master:
106 procBindKind = mlir::omp::ClauseProcBindKind::Master;
107 break;
108 case omp::clause::ProcBind::AffinityPolicy::Close:
109 procBindKind = mlir::omp::ClauseProcBindKind::Close;
110 break;
111 case omp::clause::ProcBind::AffinityPolicy::Spread:
112 procBindKind = mlir::omp::ClauseProcBindKind::Spread;
113 break;
114 case omp::clause::ProcBind::AffinityPolicy::Primary:
115 procBindKind = mlir::omp::ClauseProcBindKind::Primary;
116 break;
117 }
118 return mlir::omp::ClauseProcBindKindAttr::get(firOpBuilder.getContext(),
119 procBindKind);
120}
121
122static mlir::omp::ClauseTaskDependAttr
123genDependKindAttr(fir::FirOpBuilder &firOpBuilder,
124 const omp::clause::Depend::TaskDependenceType kind) {
125 mlir::omp::ClauseTaskDepend pbKind;
126 switch (kind) {
127 case omp::clause::Depend::TaskDependenceType::In:
128 pbKind = mlir::omp::ClauseTaskDepend::taskdependin;
129 break;
130 case omp::clause::Depend::TaskDependenceType::Out:
131 pbKind = mlir::omp::ClauseTaskDepend::taskdependout;
132 break;
133 case omp::clause::Depend::TaskDependenceType::Inout:
134 pbKind = mlir::omp::ClauseTaskDepend::taskdependinout;
135 break;
136 case omp::clause::Depend::TaskDependenceType::Mutexinoutset:
137 case omp::clause::Depend::TaskDependenceType::Inoutset:
138 case omp::clause::Depend::TaskDependenceType::Depobj:
139 llvm_unreachable("unhandled parser task dependence type");
140 break;
141 }
142 return mlir::omp::ClauseTaskDependAttr::get(firOpBuilder.getContext(),
143 pbKind);
144}
145
146static mlir::Value
147getIfClauseOperand(Fortran::lower::AbstractConverter &converter,
148 const omp::clause::If &clause,
149 omp::clause::If::DirectiveNameModifier directiveName,
150 mlir::Location clauseLocation) {
151 // Only consider the clause if it's intended for the given directive.
152 auto &directive =
153 std::get<std::optional<omp::clause::If::DirectiveNameModifier>>(clause.t);
154 if (directive && directive.value() != directiveName)
155 return nullptr;
156
157 Fortran::lower::StatementContext stmtCtx;
158 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
159 mlir::Value ifVal = fir::getBase(
160 converter.genExprValue(std::get<omp::SomeExpr>(clause.t), stmtCtx));
161 return firOpBuilder.createConvert(clauseLocation, firOpBuilder.getI1Type(),
162 ifVal);
163}
164
165static void addUseDeviceClause(
166 Fortran::lower::AbstractConverter &converter,
167 const omp::ObjectList &objects,
168 llvm::SmallVectorImpl<mlir::Value> &operands,
169 llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
170 llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
171 llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSyms) {
172 genObjectList(objects, converter, operands);
173 for (mlir::Value &operand : operands) {
174 checkMapType(operand.getLoc(), operand.getType());
175 useDeviceTypes.push_back(operand.getType());
176 useDeviceLocs.push_back(operand.getLoc());
177 }
178 for (const omp::Object &object : objects)
179 useDeviceSyms.push_back(object.id());
180}
181
182static void convertLoopBounds(Fortran::lower::AbstractConverter &converter,
183 mlir::Location loc,
184 mlir::omp::CollapseClauseOps &result,
185 std::size_t loopVarTypeSize) {
186 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
187 // The types of lower bound, upper bound, and step are converted into the
188 // type of the loop variable if necessary.
189 mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
190 for (unsigned it = 0; it < (unsigned)result.loopLBVar.size(); it++) {
191 result.loopLBVar[it] =
192 firOpBuilder.createConvert(loc, loopVarType, result.loopLBVar[it]);
193 result.loopUBVar[it] =
194 firOpBuilder.createConvert(loc, loopVarType, result.loopUBVar[it]);
195 result.loopStepVar[it] =
196 firOpBuilder.createConvert(loc, loopVarType, result.loopStepVar[it]);
197 }
198}
199
200//===----------------------------------------------------------------------===//
201// ClauseProcessor unique clauses
202//===----------------------------------------------------------------------===//
203
204bool ClauseProcessor::processCollapse(
205 mlir::Location currentLocation, Fortran::lower::pft::Evaluation &eval,
206 mlir::omp::CollapseClauseOps &result,
207 llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv) const {
208 bool found = false;
209 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
210
211 // Collect the loops to collapse.
212 Fortran::lower::pft::Evaluation *doConstructEval =
213 &eval.getFirstNestedEvaluation();
214 if (doConstructEval->getIf<Fortran::parser::DoConstruct>()
215 ->IsDoConcurrent()) {
216 TODO(currentLocation, "Do Concurrent in Worksharing loop construct");
217 }
218
219 std::int64_t collapseValue = 1l;
220 if (auto *clause = findUniqueClause<omp::clause::Collapse>()) {
221 collapseValue = Fortran::evaluate::ToInt64(clause->v).value();
222 found = true;
223 }
224
225 std::size_t loopVarTypeSize = 0;
226 do {
227 Fortran::lower::pft::Evaluation *doLoop =
228 &doConstructEval->getFirstNestedEvaluation();
229 auto *doStmt = doLoop->getIf<Fortran::parser::NonLabelDoStmt>();
230 assert(doStmt && "Expected do loop to be in the nested evaluation");
231 const auto &loopControl =
232 std::get<std::optional<Fortran::parser::LoopControl>>(doStmt->t);
233 const Fortran::parser::LoopControl::Bounds *bounds =
234 std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
235 assert(bounds && "Expected bounds for worksharing do loop");
236 Fortran::lower::StatementContext stmtCtx;
237 result.loopLBVar.push_back(fir::getBase(converter.genExprValue(
238 *Fortran::semantics::GetExpr(bounds->lower), stmtCtx)));
239 result.loopUBVar.push_back(fir::getBase(converter.genExprValue(
240 *Fortran::semantics::GetExpr(bounds->upper), stmtCtx)));
241 if (bounds->step) {
242 result.loopStepVar.push_back(fir::getBase(converter.genExprValue(
243 *Fortran::semantics::GetExpr(bounds->step), stmtCtx)));
244 } else { // If `step` is not present, assume it as `1`.
245 result.loopStepVar.push_back(firOpBuilder.createIntegerConstant(
246 currentLocation, firOpBuilder.getIntegerType(32), 1));
247 }
248 iv.push_back(Elt: bounds->name.thing.symbol);
249 loopVarTypeSize = std::max(loopVarTypeSize,
250 bounds->name.thing.symbol->GetUltimate().size());
251 collapseValue--;
252 doConstructEval =
253 &*std::next(doConstructEval->getNestedEvaluations().begin());
254 } while (collapseValue > 0);
255
256 convertLoopBounds(converter, currentLocation, result, loopVarTypeSize);
257
258 return found;
259}
260
261bool ClauseProcessor::processDefault() const {
262 if (auto *clause = findUniqueClause<omp::clause::Default>()) {
263 // Private, Firstprivate, Shared, None
264 switch (clause->v) {
265 case omp::clause::Default::DataSharingAttribute::Shared:
266 case omp::clause::Default::DataSharingAttribute::None:
267 // Default clause with shared or none do not require any handling since
268 // Shared is the default behavior in the IR and None is only required
269 // for semantic checks.
270 break;
271 case omp::clause::Default::DataSharingAttribute::Private:
272 // TODO Support default(private)
273 break;
274 case omp::clause::Default::DataSharingAttribute::Firstprivate:
275 // TODO Support default(firstprivate)
276 break;
277 }
278 return true;
279 }
280 return false;
281}
282
283bool ClauseProcessor::processDevice(Fortran::lower::StatementContext &stmtCtx,
284 mlir::omp::DeviceClauseOps &result) const {
285 const Fortran::parser::CharBlock *source = nullptr;
286 if (auto *clause = findUniqueClause<omp::clause::Device>(&source)) {
287 mlir::Location clauseLocation = converter.genLocation(*source);
288 if (auto deviceModifier =
289 std::get<std::optional<omp::clause::Device::DeviceModifier>>(
290 clause->t)) {
291 if (deviceModifier == omp::clause::Device::DeviceModifier::Ancestor) {
292 TODO(clauseLocation, "OMPD_target Device Modifier Ancestor");
293 }
294 }
295 const auto &deviceExpr = std::get<omp::SomeExpr>(clause->t);
296 result.deviceVar =
297 fir::getBase(converter.genExprValue(deviceExpr, stmtCtx));
298 return true;
299 }
300 return false;
301}
302
303bool ClauseProcessor::processDeviceType(
304 mlir::omp::DeviceTypeClauseOps &result) const {
305 if (auto *clause = findUniqueClause<omp::clause::DeviceType>()) {
306 // Case: declare target ... device_type(any | host | nohost)
307 switch (clause->v) {
308 case omp::clause::DeviceType::DeviceTypeDescription::Nohost:
309 result.deviceType = mlir::omp::DeclareTargetDeviceType::nohost;
310 break;
311 case omp::clause::DeviceType::DeviceTypeDescription::Host:
312 result.deviceType = mlir::omp::DeclareTargetDeviceType::host;
313 break;
314 case omp::clause::DeviceType::DeviceTypeDescription::Any:
315 result.deviceType = mlir::omp::DeclareTargetDeviceType::any;
316 break;
317 }
318 return true;
319 }
320 return false;
321}
322
323bool ClauseProcessor::processFinal(Fortran::lower::StatementContext &stmtCtx,
324 mlir::omp::FinalClauseOps &result) const {
325 const Fortran::parser::CharBlock *source = nullptr;
326 if (auto *clause = findUniqueClause<omp::clause::Final>(&source)) {
327 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
328 mlir::Location clauseLocation = converter.genLocation(*source);
329
330 mlir::Value finalVal =
331 fir::getBase(converter.genExprValue(clause->v, stmtCtx));
332 result.finalVar = firOpBuilder.createConvert(
333 clauseLocation, firOpBuilder.getI1Type(), finalVal);
334 return true;
335 }
336 return false;
337}
338
339bool ClauseProcessor::processHint(mlir::omp::HintClauseOps &result) const {
340 if (auto *clause = findUniqueClause<omp::clause::Hint>()) {
341 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
342 int64_t hintValue = *Fortran::evaluate::ToInt64(clause->v);
343 result.hintAttr = firOpBuilder.getI64IntegerAttr(hintValue);
344 return true;
345 }
346 return false;
347}
348
349bool ClauseProcessor::processMergeable(
350 mlir::omp::MergeableClauseOps &result) const {
351 return markClauseOccurrence<omp::clause::Mergeable>(result.mergeableAttr);
352}
353
354bool ClauseProcessor::processNowait(mlir::omp::NowaitClauseOps &result) const {
355 return markClauseOccurrence<omp::clause::Nowait>(result.nowaitAttr);
356}
357
358bool ClauseProcessor::processNumTeams(
359 Fortran::lower::StatementContext &stmtCtx,
360 mlir::omp::NumTeamsClauseOps &result) const {
361 // TODO Get lower and upper bounds for num_teams when parser is updated to
362 // accept both.
363 if (auto *clause = findUniqueClause<omp::clause::NumTeams>()) {
364 // auto lowerBound = std::get<std::optional<ExprTy>>(clause->t);
365 auto &upperBound = std::get<ExprTy>(clause->t);
366 result.numTeamsUpperVar =
367 fir::getBase(converter.genExprValue(upperBound, stmtCtx));
368 return true;
369 }
370 return false;
371}
372
373bool ClauseProcessor::processNumThreads(
374 Fortran::lower::StatementContext &stmtCtx,
375 mlir::omp::NumThreadsClauseOps &result) const {
376 if (auto *clause = findUniqueClause<omp::clause::NumThreads>()) {
377 // OMPIRBuilder expects `NUM_THREADS` clause as a `Value`.
378 result.numThreadsVar =
379 fir::getBase(converter.genExprValue(clause->v, stmtCtx));
380 return true;
381 }
382 return false;
383}
384
385bool ClauseProcessor::processOrdered(
386 mlir::omp::OrderedClauseOps &result) const {
387 if (auto *clause = findUniqueClause<omp::clause::Ordered>()) {
388 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
389 int64_t orderedClauseValue = 0l;
390 if (clause->v.has_value())
391 orderedClauseValue = *Fortran::evaluate::ToInt64(*clause->v);
392 result.orderedAttr = firOpBuilder.getI64IntegerAttr(orderedClauseValue);
393 return true;
394 }
395 return false;
396}
397
398bool ClauseProcessor::processPriority(
399 Fortran::lower::StatementContext &stmtCtx,
400 mlir::omp::PriorityClauseOps &result) const {
401 if (auto *clause = findUniqueClause<omp::clause::Priority>()) {
402 result.priorityVar =
403 fir::getBase(converter.genExprValue(clause->v, stmtCtx));
404 return true;
405 }
406 return false;
407}
408
409bool ClauseProcessor::processProcBind(
410 mlir::omp::ProcBindClauseOps &result) const {
411 if (auto *clause = findUniqueClause<omp::clause::ProcBind>()) {
412 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
413 result.procBindKindAttr = genProcBindKindAttr(firOpBuilder, *clause);
414 return true;
415 }
416 return false;
417}
418
419bool ClauseProcessor::processSafelen(
420 mlir::omp::SafelenClauseOps &result) const {
421 if (auto *clause = findUniqueClause<omp::clause::Safelen>()) {
422 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
423 const std::optional<std::int64_t> safelenVal =
424 Fortran::evaluate::ToInt64(clause->v);
425 result.safelenAttr = firOpBuilder.getI64IntegerAttr(*safelenVal);
426 return true;
427 }
428 return false;
429}
430
431bool ClauseProcessor::processSchedule(
432 Fortran::lower::StatementContext &stmtCtx,
433 mlir::omp::ScheduleClauseOps &result) const {
434 if (auto *clause = findUniqueClause<omp::clause::Schedule>()) {
435 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
436 mlir::MLIRContext *context = firOpBuilder.getContext();
437 const auto &scheduleType = std::get<omp::clause::Schedule::Kind>(clause->t);
438
439 mlir::omp::ClauseScheduleKind scheduleKind;
440 switch (scheduleType) {
441 case omp::clause::Schedule::Kind::Static:
442 scheduleKind = mlir::omp::ClauseScheduleKind::Static;
443 break;
444 case omp::clause::Schedule::Kind::Dynamic:
445 scheduleKind = mlir::omp::ClauseScheduleKind::Dynamic;
446 break;
447 case omp::clause::Schedule::Kind::Guided:
448 scheduleKind = mlir::omp::ClauseScheduleKind::Guided;
449 break;
450 case omp::clause::Schedule::Kind::Auto:
451 scheduleKind = mlir::omp::ClauseScheduleKind::Auto;
452 break;
453 case omp::clause::Schedule::Kind::Runtime:
454 scheduleKind = mlir::omp::ClauseScheduleKind::Runtime;
455 break;
456 }
457
458 result.scheduleValAttr =
459 mlir::omp::ClauseScheduleKindAttr::get(context, scheduleKind);
460
461 mlir::omp::ScheduleModifier scheduleModifier = getScheduleModifier(*clause);
462 if (scheduleModifier != mlir::omp::ScheduleModifier::none)
463 result.scheduleModAttr =
464 mlir::omp::ScheduleModifierAttr::get(context, scheduleModifier);
465
466 if (getSimdModifier(*clause) != mlir::omp::ScheduleModifier::none)
467 result.scheduleSimdAttr = firOpBuilder.getUnitAttr();
468
469 if (const auto &chunkExpr = std::get<omp::MaybeExpr>(clause->t))
470 result.scheduleChunkVar =
471 fir::getBase(converter.genExprValue(*chunkExpr, stmtCtx));
472
473 return true;
474 }
475 return false;
476}
477
478bool ClauseProcessor::processSimdlen(
479 mlir::omp::SimdlenClauseOps &result) const {
480 if (auto *clause = findUniqueClause<omp::clause::Simdlen>()) {
481 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
482 const std::optional<std::int64_t> simdlenVal =
483 Fortran::evaluate::ToInt64(clause->v);
484 result.simdlenAttr = firOpBuilder.getI64IntegerAttr(*simdlenVal);
485 return true;
486 }
487 return false;
488}
489
490bool ClauseProcessor::processThreadLimit(
491 Fortran::lower::StatementContext &stmtCtx,
492 mlir::omp::ThreadLimitClauseOps &result) const {
493 if (auto *clause = findUniqueClause<omp::clause::ThreadLimit>()) {
494 result.threadLimitVar =
495 fir::getBase(converter.genExprValue(clause->v, stmtCtx));
496 return true;
497 }
498 return false;
499}
500
501bool ClauseProcessor::processUntied(mlir::omp::UntiedClauseOps &result) const {
502 return markClauseOccurrence<omp::clause::Untied>(result.untiedAttr);
503}
504
505//===----------------------------------------------------------------------===//
506// ClauseProcessor repeatable clauses
507//===----------------------------------------------------------------------===//
508
509bool ClauseProcessor::processAllocate(
510 mlir::omp::AllocateClauseOps &result) const {
511 return findRepeatableClause<omp::clause::Allocate>(
512 [&](const omp::clause::Allocate &clause,
513 const Fortran::parser::CharBlock &) {
514 genAllocateClause(converter, clause, result.allocatorVars,
515 result.allocateVars);
516 });
517}
518
519bool ClauseProcessor::processCopyin() const {
520 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
521 mlir::OpBuilder::InsertPoint insPt = firOpBuilder.saveInsertionPoint();
522 firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock());
523 auto checkAndCopyHostAssociateVar =
524 [&](Fortran::semantics::Symbol *sym,
525 mlir::OpBuilder::InsertPoint *copyAssignIP = nullptr) {
526 assert(sym->has<Fortran::semantics::HostAssocDetails>() &&
527 "No host-association found");
528 if (converter.isPresentShallowLookup(*sym))
529 converter.copyHostAssociateVar(*sym, copyAssignIP);
530 };
531 bool hasCopyin = findRepeatableClause<omp::clause::Copyin>(
532 [&](const omp::clause::Copyin &clause,
533 const Fortran::parser::CharBlock &) {
534 for (const omp::Object &object : clause.v) {
535 Fortran::semantics::Symbol *sym = object.id();
536 assert(sym && "Expecting symbol");
537 if (const auto *commonDetails =
538 sym->detailsIf<Fortran::semantics::CommonBlockDetails>()) {
539 for (const auto &mem : commonDetails->objects())
540 checkAndCopyHostAssociateVar(&*mem, &insPt);
541 break;
542 }
543 if (Fortran::semantics::IsAllocatableOrObjectPointer(
544 &sym->GetUltimate()))
545 TODO(converter.getCurrentLocation(),
546 "pointer or allocatable variables in Copyin clause");
547 assert(sym->has<Fortran::semantics::HostAssocDetails>() &&
548 "No host-association found");
549 checkAndCopyHostAssociateVar(sym);
550 }
551 });
552
553 // [OMP 5.0, 2.19.6.1] The copy is done after the team is formed and prior to
554 // the execution of the associated structured block. Emit implicit barrier to
555 // synchronize threads and avoid data races on propagation master's thread
556 // values of threadprivate variables to local instances of that variables of
557 // all other implicit threads.
558 if (hasCopyin)
559 firOpBuilder.create<mlir::omp::BarrierOp>(converter.getCurrentLocation());
560 firOpBuilder.restoreInsertionPoint(insPt);
561 return hasCopyin;
562}
563
564/// Class that extracts information from the specified type.
565class TypeInfo {
566public:
567 TypeInfo(mlir::Type ty) { typeScan(ty); }
568
569 // Returns the length of character types.
570 std::optional<fir::CharacterType::LenType> getCharLength() const {
571 return charLen;
572 }
573
574 // Returns the shape of array types.
575 llvm::ArrayRef<int64_t> getShape() const { return shape; }
576
577 // Is the type inside a box?
578 bool isBox() const { return inBox; }
579
580private:
581 void typeScan(mlir::Type type);
582
583 std::optional<fir::CharacterType::LenType> charLen;
584 llvm::SmallVector<int64_t> shape;
585 bool inBox = false;
586};
587
588void TypeInfo::typeScan(mlir::Type ty) {
589 if (auto sty = mlir::dyn_cast<fir::SequenceType>(ty)) {
590 assert(shape.empty() && !sty.getShape().empty());
591 shape = llvm::SmallVector<int64_t>(sty.getShape());
592 typeScan(sty.getEleTy());
593 } else if (auto bty = mlir::dyn_cast<fir::BoxType>(ty)) {
594 inBox = true;
595 typeScan(bty.getEleTy());
596 } else if (auto cty = mlir::dyn_cast<fir::CharacterType>(ty)) {
597 charLen = cty.getLen();
598 } else if (auto hty = mlir::dyn_cast<fir::HeapType>(ty)) {
599 typeScan(hty.getEleTy());
600 } else if (auto pty = mlir::dyn_cast<fir::PointerType>(ty)) {
601 typeScan(pty.getEleTy());
602 } else {
603 // The scan ends when reaching any built-in or record type.
604 assert(ty.isIntOrIndexOrFloat() || mlir::isa<fir::ComplexType>(ty) ||
605 mlir::isa<fir::LogicalType>(ty) || mlir::isa<fir::RecordType>(ty));
606 }
607}
608
609// Create a function that performs a copy between two variables, compatible
610// with their types and attributes.
611static mlir::func::FuncOp
612createCopyFunc(mlir::Location loc, Fortran::lower::AbstractConverter &converter,
613 mlir::Type varType, fir::FortranVariableFlagsEnum varAttrs) {
614 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
615 mlir::ModuleOp module = builder.getModule();
616 mlir::Type eleTy = mlir::cast<fir::ReferenceType>(varType).getEleTy();
617 TypeInfo typeInfo(eleTy);
618 std::string copyFuncName =
619 fir::getTypeAsString(eleTy, builder.getKindMap(), "_copy");
620
621 if (auto decl = module.lookupSymbol<mlir::func::FuncOp>(copyFuncName))
622 return decl;
623
624 // create function
625 mlir::OpBuilder::InsertionGuard guard(builder);
626 mlir::OpBuilder modBuilder(module.getBodyRegion());
627 llvm::SmallVector<mlir::Type> argsTy = {varType, varType};
628 auto funcType = mlir::FunctionType::get(builder.getContext(), argsTy, {});
629 mlir::func::FuncOp funcOp =
630 modBuilder.create<mlir::func::FuncOp>(loc, copyFuncName, funcType);
631 funcOp.setVisibility(mlir::SymbolTable::Visibility::Private);
632 builder.createBlock(&funcOp.getRegion(), funcOp.getRegion().end(), argsTy,
633 {loc, loc});
634 builder.setInsertionPointToStart(&funcOp.getRegion().back());
635 // generate body
636 fir::FortranVariableFlagsAttr attrs;
637 if (varAttrs != fir::FortranVariableFlagsEnum::None)
638 attrs = fir::FortranVariableFlagsAttr::get(builder.getContext(), varAttrs);
639 llvm::SmallVector<mlir::Value> typeparams;
640 if (typeInfo.getCharLength().has_value()) {
641 mlir::Value charLen = builder.createIntegerConstant(
642 loc, builder.getCharacterLengthType(), *typeInfo.getCharLength());
643 typeparams.push_back(charLen);
644 }
645 mlir::Value shape;
646 if (!typeInfo.isBox() && !typeInfo.getShape().empty()) {
647 llvm::SmallVector<mlir::Value> extents;
648 for (auto extent : typeInfo.getShape())
649 extents.push_back(
650 builder.createIntegerConstant(loc, builder.getIndexType(), extent));
651 shape = builder.create<fir::ShapeOp>(loc, extents);
652 }
653 auto declDst = builder.create<hlfir::DeclareOp>(loc, funcOp.getArgument(0),
654 copyFuncName + "_dst", shape,
655 typeparams, attrs);
656 auto declSrc = builder.create<hlfir::DeclareOp>(loc, funcOp.getArgument(1),
657 copyFuncName + "_src", shape,
658 typeparams, attrs);
659 converter.copyVar(loc, declDst.getBase(), declSrc.getBase());
660 builder.create<mlir::func::ReturnOp>(loc);
661 return funcOp;
662}
663
664bool ClauseProcessor::processCopyprivate(
665 mlir::Location currentLocation,
666 mlir::omp::CopyprivateClauseOps &result) const {
667 auto addCopyPrivateVar = [&](Fortran::semantics::Symbol *sym) {
668 mlir::Value symVal = converter.getSymbolAddress(*sym);
669 auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>();
670 if (!declOp)
671 fir::emitFatalError(currentLocation,
672 "COPYPRIVATE is supported only in HLFIR mode");
673 symVal = declOp.getBase();
674 mlir::Type symType = symVal.getType();
675 fir::FortranVariableFlagsEnum attrs =
676 declOp.getFortranAttrs().has_value()
677 ? *declOp.getFortranAttrs()
678 : fir::FortranVariableFlagsEnum::None;
679 mlir::Value cpVar = symVal;
680
681 // CopyPrivate variables must be passed by reference. However, in the case
682 // of assumed shapes/vla the type is not a !fir.ref, but a !fir.box.
683 // In these cases to retrieve the appropriate !fir.ref<!fir.box<...>> to
684 // access the data we need we must perform an alloca and then store to it
685 // and retrieve the data from the new alloca.
686 if (mlir::isa<fir::BaseBoxType>(symType)) {
687 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
688 auto alloca = builder.create<fir::AllocaOp>(currentLocation, symType);
689 builder.create<fir::StoreOp>(currentLocation, symVal, alloca);
690 cpVar = alloca;
691 }
692
693 result.copyprivateVars.push_back(cpVar);
694 mlir::func::FuncOp funcOp =
695 createCopyFunc(currentLocation, converter, cpVar.getType(), attrs);
696 result.copyprivateFuncs.push_back(mlir::SymbolRefAttr::get(funcOp));
697 };
698
699 bool hasCopyPrivate = findRepeatableClause<clause::Copyprivate>(
700 [&](const clause::Copyprivate &clause,
701 const Fortran::parser::CharBlock &) {
702 for (const Object &object : clause.v) {
703 Fortran::semantics::Symbol *sym = object.id();
704 if (const auto *commonDetails =
705 sym->detailsIf<Fortran::semantics::CommonBlockDetails>()) {
706 for (const auto &mem : commonDetails->objects())
707 addCopyPrivateVar(&*mem);
708 break;
709 }
710 addCopyPrivateVar(sym);
711 }
712 });
713
714 return hasCopyPrivate;
715}
716
717bool ClauseProcessor::processDepend(mlir::omp::DependClauseOps &result) const {
718 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
719
720 return findRepeatableClause<omp::clause::Depend>(
721 [&](const omp::clause::Depend &clause,
722 const Fortran::parser::CharBlock &) {
723 using Depend = omp::clause::Depend;
724 assert(std::holds_alternative<Depend::WithLocators>(clause.u) &&
725 "Only the modern form is handled at the moment");
726 auto &modern = std::get<Depend::WithLocators>(clause.u);
727 auto kind = std::get<Depend::TaskDependenceType>(modern.t);
728 auto &objects = std::get<omp::ObjectList>(modern.t);
729
730 mlir::omp::ClauseTaskDependAttr dependTypeOperand =
731 genDependKindAttr(firOpBuilder, kind);
732 result.dependTypeAttrs.append(objects.size(), dependTypeOperand);
733
734 for (const omp::Object &object : objects) {
735 assert(object.ref() && "Expecting designator");
736
737 if (Fortran::evaluate::ExtractSubstring(*object.ref())) {
738 TODO(converter.getCurrentLocation(),
739 "substring not supported for task depend");
740 } else if (Fortran::evaluate::IsArrayElement(*object.ref())) {
741 TODO(converter.getCurrentLocation(),
742 "array sections not supported for task depend");
743 }
744
745 Fortran::semantics::Symbol *sym = object.id();
746 const mlir::Value variable = converter.getSymbolAddress(*sym);
747 result.dependVars.push_back(variable);
748 }
749 });
750}
751
752bool ClauseProcessor::processHasDeviceAddr(
753 mlir::omp::HasDeviceAddrClauseOps &result,
754 llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
755 llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
756 llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &isDeviceSymbols)
757 const {
758 return findRepeatableClause<omp::clause::HasDeviceAddr>(
759 [&](const omp::clause::HasDeviceAddr &devAddrClause,
760 const Fortran::parser::CharBlock &) {
761 addUseDeviceClause(converter, devAddrClause.v, result.hasDeviceAddrVars,
762 isDeviceTypes, isDeviceLocs, isDeviceSymbols);
763 });
764}
765
766bool ClauseProcessor::processIf(
767 omp::clause::If::DirectiveNameModifier directiveName,
768 mlir::omp::IfClauseOps &result) const {
769 bool found = false;
770 findRepeatableClause<omp::clause::If>(
771 [&](const omp::clause::If &clause,
772 const Fortran::parser::CharBlock &source) {
773 mlir::Location clauseLocation = converter.genLocation(source);
774 mlir::Value operand = getIfClauseOperand(converter, clause,
775 directiveName, clauseLocation);
776 // Assume that, at most, a single 'if' clause will be applicable to the
777 // given directive.
778 if (operand) {
779 result.ifVar = operand;
780 found = true;
781 }
782 });
783 return found;
784}
785
786bool ClauseProcessor::processIsDevicePtr(
787 mlir::omp::IsDevicePtrClauseOps &result,
788 llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
789 llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
790 llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &isDeviceSymbols)
791 const {
792 return findRepeatableClause<omp::clause::IsDevicePtr>(
793 [&](const omp::clause::IsDevicePtr &devPtrClause,
794 const Fortran::parser::CharBlock &) {
795 addUseDeviceClause(converter, devPtrClause.v, result.isDevicePtrVars,
796 isDeviceTypes, isDeviceLocs, isDeviceSymbols);
797 });
798}
799
800bool ClauseProcessor::processLink(
801 llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
802 return findRepeatableClause<omp::clause::Link>(
803 [&](const omp::clause::Link &clause, const Fortran::parser::CharBlock &) {
804 // Case: declare target link(var1, var2)...
805 gatherFuncAndVarSyms(
806 clause.v, mlir::omp::DeclareTargetCaptureClause::link, result);
807 });
808}
809
810mlir::omp::MapInfoOp
811createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
812 mlir::Value baseAddr, mlir::Value varPtrPtr, std::string name,
813 llvm::ArrayRef<mlir::Value> bounds,
814 llvm::ArrayRef<mlir::Value> members, uint64_t mapType,
815 mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy,
816 bool isVal) {
817 if (auto boxTy = baseAddr.getType().dyn_cast<fir::BaseBoxType>()) {
818 baseAddr = builder.create<fir::BoxAddrOp>(loc, baseAddr);
819 retTy = baseAddr.getType();
820 }
821
822 mlir::TypeAttr varType = mlir::TypeAttr::get(
823 llvm::cast<mlir::omp::PointerLikeType>(retTy).getElementType());
824
825 mlir::omp::MapInfoOp op = builder.create<mlir::omp::MapInfoOp>(
826 loc, retTy, baseAddr, varType, varPtrPtr, members, bounds,
827 builder.getIntegerAttr(builder.getIntegerType(64, false), mapType),
828 builder.getAttr<mlir::omp::VariableCaptureKindAttr>(mapCaptureType),
829 builder.getStringAttr(name));
830
831 return op;
832}
833
834bool ClauseProcessor::processMap(
835 mlir::Location currentLocation, Fortran::lower::StatementContext &stmtCtx,
836 mlir::omp::MapClauseOps &result,
837 llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSyms,
838 llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
839 llvm::SmallVectorImpl<mlir::Type> *mapSymTypes) const {
840 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
841 return findRepeatableClause<omp::clause::Map>(
842 [&](const omp::clause::Map &clause,
843 const Fortran::parser::CharBlock &source) {
844 using Map = omp::clause::Map;
845 mlir::Location clauseLocation = converter.genLocation(source);
846 const auto &mapType = std::get<std::optional<Map::MapType>>(clause.t);
847 llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
848 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
849 // If the map type is specified, then process it else Tofrom is the
850 // default.
851 if (mapType) {
852 switch (*mapType) {
853 case Map::MapType::To:
854 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
855 break;
856 case Map::MapType::From:
857 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
858 break;
859 case Map::MapType::Tofrom:
860 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
861 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
862 break;
863 case Map::MapType::Alloc:
864 case Map::MapType::Release:
865 // alloc and release is the default map_type for the Target Data
866 // Ops, i.e. if no bits for map_type is supplied then alloc/release
867 // is implicitly assumed based on the target directive. Default
868 // value for Target Data and Enter Data is alloc and for Exit Data
869 // it is release.
870 break;
871 case Map::MapType::Delete:
872 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
873 }
874
875 auto &modTypeMods =
876 std::get<std::optional<Map::MapTypeModifiers>>(clause.t);
877 if (modTypeMods) {
878 if (llvm::is_contained(*modTypeMods, Map::MapTypeModifier::Always))
879 mapTypeBits |=
880 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
881 }
882 } else {
883 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
884 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
885 }
886
887 for (const omp::Object &object : std::get<omp::ObjectList>(clause.t)) {
888 llvm::SmallVector<mlir::Value> bounds;
889 std::stringstream asFortran;
890
891 Fortran::lower::AddrAndBoundsInfo info =
892 Fortran::lower::gatherDataOperandAddrAndBounds<
893 mlir::omp::MapBoundsOp, mlir::omp::MapBoundsType>(
894 converter, firOpBuilder, semaCtx, stmtCtx, *object.id(),
895 object.ref(), clauseLocation, asFortran, bounds,
896 treatIndexAsSection);
897
898 auto origSymbol = converter.getSymbolAddress(*object.id());
899 mlir::Value symAddr = info.addr;
900 if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType()))
901 symAddr = origSymbol;
902
903 // Explicit map captures are captured ByRef by default,
904 // optimisation passes may alter this to ByCopy or other capture
905 // types to optimise
906 mlir::Value mapOp = createMapInfoOp(
907 firOpBuilder, clauseLocation, symAddr, mlir::Value{},
908 asFortran.str(), bounds, {},
909 static_cast<
910 std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
911 mapTypeBits),
912 mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
913
914 result.mapVars.push_back(mapOp);
915
916 if (mapSyms)
917 mapSyms->push_back(object.id());
918 if (mapSymLocs)
919 mapSymLocs->push_back(symAddr.getLoc());
920 if (mapSymTypes)
921 mapSymTypes->push_back(symAddr.getType());
922 }
923 });
924}
925
926bool ClauseProcessor::processReduction(
927 mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result,
928 llvm::SmallVectorImpl<mlir::Type> *outReductionTypes,
929 llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *outReductionSyms)
930 const {
931 return findRepeatableClause<omp::clause::Reduction>(
932 [&](const omp::clause::Reduction &clause,
933 const Fortran::parser::CharBlock &) {
934 // Use local lists of reductions to prevent variables from other
935 // already-processed reduction clauses from impacting this reduction.
936 // For example, the whole `reductionVars` array is queried to decide
937 // whether to do the reduction byref.
938 llvm::SmallVector<mlir::Value> reductionVars;
939 llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
940 llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSyms;
941 ReductionProcessor rp;
942 rp.addDeclareReduction(currentLocation, converter, clause,
943 reductionVars, reductionDeclSymbols,
944 outReductionSyms ? &reductionSyms : nullptr);
945
946 // Copy local lists into the output.
947 llvm::copy(reductionVars, std::back_inserter(result.reductionVars));
948 llvm::copy(reductionDeclSymbols,
949 std::back_inserter(result.reductionDeclSymbols));
950
951 if (outReductionTypes) {
952 outReductionTypes->reserve(outReductionTypes->size() +
953 reductionVars.size());
954 llvm::transform(reductionVars, std::back_inserter(*outReductionTypes),
955 [](mlir::Value v) { return v.getType(); });
956 }
957
958 if (outReductionSyms)
959 llvm::copy(reductionSyms, std::back_inserter(*outReductionSyms));
960 });
961}
962
963bool ClauseProcessor::processSectionsReduction(
964 mlir::Location currentLocation, mlir::omp::ReductionClauseOps &) const {
965 return findRepeatableClause<omp::clause::Reduction>(
966 [&](const omp::clause::Reduction &, const Fortran::parser::CharBlock &) {
967 TODO(currentLocation, "OMPC_Reduction");
968 });
969}
970
971bool ClauseProcessor::processTo(
972 llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
973 return findRepeatableClause<omp::clause::To>(
974 [&](const omp::clause::To &clause, const Fortran::parser::CharBlock &) {
975 // Case: declare target to(func, var1, var2)...
976 gatherFuncAndVarSyms(std::get<ObjectList>(clause.t),
977 mlir::omp::DeclareTargetCaptureClause::to, result);
978 });
979}
980
981bool ClauseProcessor::processEnter(
982 llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
983 return findRepeatableClause<omp::clause::Enter>(
984 [&](const omp::clause::Enter &clause,
985 const Fortran::parser::CharBlock &) {
986 // Case: declare target enter(func, var1, var2)...
987 gatherFuncAndVarSyms(
988 clause.v, mlir::omp::DeclareTargetCaptureClause::enter, result);
989 });
990}
991
992bool ClauseProcessor::processUseDeviceAddr(
993 mlir::omp::UseDeviceClauseOps &result,
994 llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
995 llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
996 llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSyms)
997 const {
998 return findRepeatableClause<omp::clause::UseDeviceAddr>(
999 [&](const omp::clause::UseDeviceAddr &clause,
1000 const Fortran::parser::CharBlock &) {
1001 addUseDeviceClause(converter, clause.v, result.useDeviceAddrVars,
1002 useDeviceTypes, useDeviceLocs, useDeviceSyms);
1003 });
1004}
1005
1006bool ClauseProcessor::processUseDevicePtr(
1007 mlir::omp::UseDeviceClauseOps &result,
1008 llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
1009 llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
1010 llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSyms)
1011 const {
1012 return findRepeatableClause<omp::clause::UseDevicePtr>(
1013 [&](const omp::clause::UseDevicePtr &clause,
1014 const Fortran::parser::CharBlock &) {
1015 addUseDeviceClause(converter, clause.v, result.useDevicePtrVars,
1016 useDeviceTypes, useDeviceLocs, useDeviceSyms);
1017 });
1018}
1019
1020} // namespace omp
1021} // namespace lower
1022} // namespace Fortran
1023

source code of flang/lib/Lower/OpenMP/ClauseProcessor.cpp