1//===-- ClauseProcessor.cpp -------------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10//
11//===----------------------------------------------------------------------===//
12
13#include "ClauseProcessor.h"
14#include "Clauses.h"
15#include "Utils.h"
16
17#include "flang/Lower/ConvertExprToHLFIR.h"
18#include "flang/Lower/PFTBuilder.h"
19#include "flang/Parser/tools.h"
20#include "flang/Semantics/tools.h"
21#include "llvm/Frontend/OpenMP/OMP.h.inc"
22#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
23
24namespace Fortran {
25namespace lower {
26namespace omp {
27
28/// Check for unsupported map operand types.
29static void checkMapType(mlir::Location location, mlir::Type type) {
30 if (auto refType = mlir::dyn_cast<fir::ReferenceType>(type))
31 type = refType.getElementType();
32 if (auto boxType = mlir::dyn_cast_or_null<fir::BoxType>(type))
33 if (!mlir::isa<fir::PointerType>(boxType.getElementType()))
34 TODO(location, "OMPD_target_data MapOperand BoxType");
35}
36
37static mlir::omp::ScheduleModifier
38translateScheduleModifier(const omp::clause::Schedule::OrderingModifier &m) {
39 switch (m) {
40 case omp::clause::Schedule::OrderingModifier::Monotonic:
41 return mlir::omp::ScheduleModifier::monotonic;
42 case omp::clause::Schedule::OrderingModifier::Nonmonotonic:
43 return mlir::omp::ScheduleModifier::nonmonotonic;
44 }
45 return mlir::omp::ScheduleModifier::none;
46}
47
48static mlir::omp::ScheduleModifier
49getScheduleModifier(const omp::clause::Schedule &clause) {
50 using Schedule = omp::clause::Schedule;
51 const auto &modifier =
52 std::get<std::optional<Schedule::OrderingModifier>>(clause.t);
53 if (modifier)
54 return translateScheduleModifier(*modifier);
55 return mlir::omp::ScheduleModifier::none;
56}
57
58static mlir::omp::ScheduleModifier
59getSimdModifier(const omp::clause::Schedule &clause) {
60 using Schedule = omp::clause::Schedule;
61 const auto &modifier =
62 std::get<std::optional<Schedule::ChunkModifier>>(clause.t);
63 if (modifier && *modifier == Schedule::ChunkModifier::Simd)
64 return mlir::omp::ScheduleModifier::simd;
65 return mlir::omp::ScheduleModifier::none;
66}
67
68static void
69genAllocateClause(lower::AbstractConverter &converter,
70 const omp::clause::Allocate &clause,
71 llvm::SmallVectorImpl<mlir::Value> &allocatorOperands,
72 llvm::SmallVectorImpl<mlir::Value> &allocateOperands) {
73 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
74 mlir::Location currentLocation = converter.getCurrentLocation();
75 lower::StatementContext stmtCtx;
76
77 auto &objects = std::get<omp::ObjectList>(clause.t);
78
79 using Allocate = omp::clause::Allocate;
80 // ALIGN in this context is unimplemented
81 if (std::get<std::optional<Allocate::AlignModifier>>(clause.t))
82 TODO(currentLocation, "OmpAllocateClause ALIGN modifier");
83
84 // Check if allocate clause has allocator specified. If so, add it
85 // to list of allocators, otherwise, add default allocator to
86 // list of allocators.
87 using ComplexModifier = Allocate::AllocatorComplexModifier;
88 if (auto &mod = std::get<std::optional<ComplexModifier>>(clause.t)) {
89 mlir::Value operand = fir::getBase(converter.genExprValue(mod->v, stmtCtx));
90 allocatorOperands.append(objects.size(), operand);
91 } else {
92 mlir::Value operand = firOpBuilder.createIntegerConstant(
93 currentLocation, firOpBuilder.getI32Type(), 1);
94 allocatorOperands.append(objects.size(), operand);
95 }
96
97 genObjectList(objects, converter, allocateOperands);
98}
99
100static mlir::omp::ClauseBindKindAttr
101genBindKindAttr(fir::FirOpBuilder &firOpBuilder,
102 const omp::clause::Bind &clause) {
103 mlir::omp::ClauseBindKind bindKind;
104 switch (clause.v) {
105 case omp::clause::Bind::Binding::Teams:
106 bindKind = mlir::omp::ClauseBindKind::Teams;
107 break;
108 case omp::clause::Bind::Binding::Parallel:
109 bindKind = mlir::omp::ClauseBindKind::Parallel;
110 break;
111 case omp::clause::Bind::Binding::Thread:
112 bindKind = mlir::omp::ClauseBindKind::Thread;
113 break;
114 }
115 return mlir::omp::ClauseBindKindAttr::get(firOpBuilder.getContext(),
116 bindKind);
117}
118
119static mlir::omp::ClauseProcBindKindAttr
120genProcBindKindAttr(fir::FirOpBuilder &firOpBuilder,
121 const omp::clause::ProcBind &clause) {
122 mlir::omp::ClauseProcBindKind procBindKind;
123 switch (clause.v) {
124 case omp::clause::ProcBind::AffinityPolicy::Master:
125 procBindKind = mlir::omp::ClauseProcBindKind::Master;
126 break;
127 case omp::clause::ProcBind::AffinityPolicy::Close:
128 procBindKind = mlir::omp::ClauseProcBindKind::Close;
129 break;
130 case omp::clause::ProcBind::AffinityPolicy::Spread:
131 procBindKind = mlir::omp::ClauseProcBindKind::Spread;
132 break;
133 case omp::clause::ProcBind::AffinityPolicy::Primary:
134 procBindKind = mlir::omp::ClauseProcBindKind::Primary;
135 break;
136 }
137 return mlir::omp::ClauseProcBindKindAttr::get(firOpBuilder.getContext(),
138 procBindKind);
139}
140
141static mlir::omp::ClauseTaskDependAttr
142genDependKindAttr(lower::AbstractConverter &converter,
143 const omp::clause::DependenceType kind) {
144 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
145 mlir::Location currentLocation = converter.getCurrentLocation();
146
147 mlir::omp::ClauseTaskDepend pbKind;
148 switch (kind) {
149 case omp::clause::DependenceType::In:
150 pbKind = mlir::omp::ClauseTaskDepend::taskdependin;
151 break;
152 case omp::clause::DependenceType::Out:
153 pbKind = mlir::omp::ClauseTaskDepend::taskdependout;
154 break;
155 case omp::clause::DependenceType::Inout:
156 pbKind = mlir::omp::ClauseTaskDepend::taskdependinout;
157 break;
158 case omp::clause::DependenceType::Mutexinoutset:
159 pbKind = mlir::omp::ClauseTaskDepend::taskdependmutexinoutset;
160 break;
161 case omp::clause::DependenceType::Inoutset:
162 pbKind = mlir::omp::ClauseTaskDepend::taskdependinoutset;
163 break;
164 case omp::clause::DependenceType::Depobj:
165 TODO(currentLocation, "DEPOBJ dependence-type");
166 break;
167 case omp::clause::DependenceType::Sink:
168 case omp::clause::DependenceType::Source:
169 llvm_unreachable("unhandled parser task dependence type");
170 break;
171 }
172 return mlir::omp::ClauseTaskDependAttr::get(firOpBuilder.getContext(),
173 pbKind);
174}
175
176static mlir::Value
177getIfClauseOperand(lower::AbstractConverter &converter,
178 const omp::clause::If &clause,
179 omp::clause::If::DirectiveNameModifier directiveName,
180 mlir::Location clauseLocation) {
181 // Only consider the clause if it's intended for the given directive.
182 auto &directive =
183 std::get<std::optional<omp::clause::If::DirectiveNameModifier>>(clause.t);
184 if (directive && directive.value() != directiveName)
185 return nullptr;
186
187 lower::StatementContext stmtCtx;
188 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
189 mlir::Value ifVal = fir::getBase(
190 converter.genExprValue(std::get<omp::SomeExpr>(clause.t), stmtCtx));
191 return firOpBuilder.createConvert(clauseLocation, firOpBuilder.getI1Type(),
192 ifVal);
193}
194
195static void addUseDeviceClause(
196 lower::AbstractConverter &converter, const omp::ObjectList &objects,
197 llvm::SmallVectorImpl<mlir::Value> &operands,
198 llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) {
199 genObjectList(objects, converter, operands);
200 for (mlir::Value &operand : operands)
201 checkMapType(operand.getLoc(), operand.getType());
202
203 for (const omp::Object &object : objects)
204 useDeviceSyms.push_back(object.sym());
205}
206
207//===----------------------------------------------------------------------===//
208// ClauseProcessor unique clauses
209//===----------------------------------------------------------------------===//
210
211bool ClauseProcessor::processBare(mlir::omp::BareClauseOps &result) const {
212 return markClauseOccurrence<omp::clause::OmpxBare>(result.bare);
213}
214
215bool ClauseProcessor::processBind(mlir::omp::BindClauseOps &result) const {
216 if (auto *clause = findUniqueClause<omp::clause::Bind>()) {
217 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
218 result.bindKind = genBindKindAttr(firOpBuilder, *clause);
219 return true;
220 }
221 return false;
222}
223
224bool ClauseProcessor::processCancelDirectiveName(
225 mlir::omp::CancelDirectiveNameClauseOps &result) const {
226 using ConstructType = mlir::omp::ClauseCancellationConstructType;
227 mlir::MLIRContext *context = &converter.getMLIRContext();
228
229 ConstructType directive;
230 if (auto *clause = findUniqueClause<omp::CancellationConstructType>()) {
231 switch (clause->v) {
232 case llvm::omp::OMP_CANCELLATION_CONSTRUCT_Parallel:
233 directive = mlir::omp::ClauseCancellationConstructType::Parallel;
234 break;
235 case llvm::omp::OMP_CANCELLATION_CONSTRUCT_Loop:
236 directive = mlir::omp::ClauseCancellationConstructType::Loop;
237 break;
238 case llvm::omp::OMP_CANCELLATION_CONSTRUCT_Sections:
239 directive = mlir::omp::ClauseCancellationConstructType::Sections;
240 break;
241 case llvm::omp::OMP_CANCELLATION_CONSTRUCT_Taskgroup:
242 directive = mlir::omp::ClauseCancellationConstructType::Taskgroup;
243 break;
244 case llvm::omp::OMP_CANCELLATION_CONSTRUCT_None:
245 llvm_unreachable("OMP_CANCELLATION_CONSTRUCT_None");
246 break;
247 }
248 } else {
249 llvm_unreachable("cancel construct missing cancellation construct type");
250 }
251
252 result.cancelDirective =
253 mlir::omp::ClauseCancellationConstructTypeAttr::get(context, directive);
254 return true;
255}
256
257bool ClauseProcessor::processCollapse(
258 mlir::Location currentLocation, lower::pft::Evaluation &eval,
259 mlir::omp::LoopRelatedClauseOps &result,
260 llvm::SmallVectorImpl<const semantics::Symbol *> &iv) const {
261 return collectLoopRelatedInfo(converter, currentLocation, eval, clauses,
262 result, iv);
263}
264
265bool ClauseProcessor::processDevice(lower::StatementContext &stmtCtx,
266 mlir::omp::DeviceClauseOps &result) const {
267 const parser::CharBlock *source = nullptr;
268 if (auto *clause = findUniqueClause<omp::clause::Device>(&source)) {
269 mlir::Location clauseLocation = converter.genLocation(*source);
270 if (auto deviceModifier =
271 std::get<std::optional<omp::clause::Device::DeviceModifier>>(
272 clause->t)) {
273 if (deviceModifier == omp::clause::Device::DeviceModifier::Ancestor) {
274 TODO(clauseLocation, "OMPD_target Device Modifier Ancestor");
275 }
276 }
277 const auto &deviceExpr = std::get<omp::SomeExpr>(clause->t);
278 result.device = fir::getBase(converter.genExprValue(deviceExpr, stmtCtx));
279 return true;
280 }
281 return false;
282}
283
284bool ClauseProcessor::processDeviceType(
285 mlir::omp::DeviceTypeClauseOps &result) const {
286 if (auto *clause = findUniqueClause<omp::clause::DeviceType>()) {
287 // Case: declare target ... device_type(any | host | nohost)
288 switch (clause->v) {
289 case omp::clause::DeviceType::DeviceTypeDescription::Nohost:
290 result.deviceType = mlir::omp::DeclareTargetDeviceType::nohost;
291 break;
292 case omp::clause::DeviceType::DeviceTypeDescription::Host:
293 result.deviceType = mlir::omp::DeclareTargetDeviceType::host;
294 break;
295 case omp::clause::DeviceType::DeviceTypeDescription::Any:
296 result.deviceType = mlir::omp::DeclareTargetDeviceType::any;
297 break;
298 }
299 return true;
300 }
301 return false;
302}
303
304bool ClauseProcessor::processDistSchedule(
305 lower::StatementContext &stmtCtx,
306 mlir::omp::DistScheduleClauseOps &result) const {
307 if (auto *clause = findUniqueClause<omp::clause::DistSchedule>()) {
308 result.distScheduleStatic = converter.getFirOpBuilder().getUnitAttr();
309 const auto &chunkSize = std::get<std::optional<ExprTy>>(clause->t);
310 if (chunkSize)
311 result.distScheduleChunkSize =
312 fir::getBase(converter.genExprValue(*chunkSize, stmtCtx));
313 return true;
314 }
315 return false;
316}
317
318bool ClauseProcessor::processExclusive(
319 mlir::Location currentLocation,
320 mlir::omp::ExclusiveClauseOps &result) const {
321 if (auto *clause = findUniqueClause<omp::clause::Exclusive>()) {
322 for (const Object &object : clause->v) {
323 const semantics::Symbol *symbol = object.sym();
324 mlir::Value symVal = converter.getSymbolAddress(*symbol);
325 result.exclusiveVars.push_back(symVal);
326 }
327 return true;
328 }
329 return false;
330}
331
332bool ClauseProcessor::processFilter(lower::StatementContext &stmtCtx,
333 mlir::omp::FilterClauseOps &result) const {
334 if (auto *clause = findUniqueClause<omp::clause::Filter>()) {
335 result.filteredThreadId =
336 fir::getBase(converter.genExprValue(clause->v, stmtCtx));
337 return true;
338 }
339 return false;
340}
341
342bool ClauseProcessor::processFinal(lower::StatementContext &stmtCtx,
343 mlir::omp::FinalClauseOps &result) const {
344 const parser::CharBlock *source = nullptr;
345 if (auto *clause = findUniqueClause<omp::clause::Final>(&source)) {
346 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
347 mlir::Location clauseLocation = converter.genLocation(*source);
348
349 mlir::Value finalVal =
350 fir::getBase(converter.genExprValue(clause->v, stmtCtx));
351 result.final = firOpBuilder.createConvert(
352 clauseLocation, firOpBuilder.getI1Type(), finalVal);
353 return true;
354 }
355 return false;
356}
357
358bool ClauseProcessor::processHint(mlir::omp::HintClauseOps &result) const {
359 if (auto *clause = findUniqueClause<omp::clause::Hint>()) {
360 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
361 int64_t hintValue = *evaluate::ToInt64(clause->v);
362 result.hint = firOpBuilder.getI64IntegerAttr(hintValue);
363 return true;
364 }
365 return false;
366}
367
368bool ClauseProcessor::processInclusive(
369 mlir::Location currentLocation,
370 mlir::omp::InclusiveClauseOps &result) const {
371 if (auto *clause = findUniqueClause<omp::clause::Inclusive>()) {
372 for (const Object &object : clause->v) {
373 const semantics::Symbol *symbol = object.sym();
374 mlir::Value symVal = converter.getSymbolAddress(*symbol);
375 result.inclusiveVars.push_back(symVal);
376 }
377 return true;
378 }
379 return false;
380}
381
382bool ClauseProcessor::processMergeable(
383 mlir::omp::MergeableClauseOps &result) const {
384 return markClauseOccurrence<omp::clause::Mergeable>(result.mergeable);
385}
386
387bool ClauseProcessor::processNowait(mlir::omp::NowaitClauseOps &result) const {
388 return markClauseOccurrence<omp::clause::Nowait>(result.nowait);
389}
390
391bool ClauseProcessor::processNumTasks(
392 lower::StatementContext &stmtCtx,
393 mlir::omp::NumTasksClauseOps &result) const {
394 using NumTasks = omp::clause::NumTasks;
395 if (auto *clause = findUniqueClause<NumTasks>()) {
396 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
397 mlir::MLIRContext *context = firOpBuilder.getContext();
398 const auto &modifier =
399 std::get<std::optional<NumTasks::Prescriptiveness>>(clause->t);
400 if (modifier && *modifier == NumTasks::Prescriptiveness::Strict) {
401 result.numTasksMod = mlir::omp::ClauseNumTasksTypeAttr::get(
402 context, mlir::omp::ClauseNumTasksType::Strict);
403 }
404 const auto &numtasksExpr = std::get<omp::SomeExpr>(clause->t);
405 result.numTasks =
406 fir::getBase(converter.genExprValue(numtasksExpr, stmtCtx));
407 return true;
408 }
409 return false;
410}
411
412bool ClauseProcessor::processNumTeams(
413 lower::StatementContext &stmtCtx,
414 mlir::omp::NumTeamsClauseOps &result) const {
415 // TODO Get lower and upper bounds for num_teams when parser is updated to
416 // accept both.
417 if (auto *clause = findUniqueClause<omp::clause::NumTeams>()) {
418 // The num_teams directive accepts a list of team lower/upper bounds.
419 // This is an extension to support grid specification for ompx_bare.
420 // Here, only expect a single element in the list.
421 assert(clause->v.size() == 1);
422 // auto lowerBound = std::get<std::optional<ExprTy>>(clause->v[0]->t);
423 auto &upperBound = std::get<ExprTy>(clause->v[0].t);
424 result.numTeamsUpper =
425 fir::getBase(converter.genExprValue(upperBound, stmtCtx));
426 return true;
427 }
428 return false;
429}
430
431bool ClauseProcessor::processNumThreads(
432 lower::StatementContext &stmtCtx,
433 mlir::omp::NumThreadsClauseOps &result) const {
434 if (auto *clause = findUniqueClause<omp::clause::NumThreads>()) {
435 // OMPIRBuilder expects `NUM_THREADS` clause as a `Value`.
436 result.numThreads =
437 fir::getBase(converter.genExprValue(clause->v, stmtCtx));
438 return true;
439 }
440 return false;
441}
442
443bool ClauseProcessor::processOrder(mlir::omp::OrderClauseOps &result) const {
444 using Order = omp::clause::Order;
445 if (auto *clause = findUniqueClause<Order>()) {
446 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
447 result.order = mlir::omp::ClauseOrderKindAttr::get(
448 firOpBuilder.getContext(), mlir::omp::ClauseOrderKind::Concurrent);
449 const auto &modifier =
450 std::get<std::optional<Order::OrderModifier>>(clause->t);
451 if (modifier && *modifier == Order::OrderModifier::Unconstrained) {
452 result.orderMod = mlir::omp::OrderModifierAttr::get(
453 firOpBuilder.getContext(), mlir::omp::OrderModifier::unconstrained);
454 } else {
455 // "If order-modifier is not unconstrained, the behavior is as if the
456 // reproducible modifier is present."
457 result.orderMod = mlir::omp::OrderModifierAttr::get(
458 firOpBuilder.getContext(), mlir::omp::OrderModifier::reproducible);
459 }
460 return true;
461 }
462 return false;
463}
464
465bool ClauseProcessor::processOrdered(
466 mlir::omp::OrderedClauseOps &result) const {
467 if (auto *clause = findUniqueClause<omp::clause::Ordered>()) {
468 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
469 int64_t orderedClauseValue = 0l;
470 if (clause->v.has_value())
471 orderedClauseValue = *evaluate::ToInt64(*clause->v);
472 result.ordered = firOpBuilder.getI64IntegerAttr(orderedClauseValue);
473 return true;
474 }
475 return false;
476}
477
478bool ClauseProcessor::processPriority(
479 lower::StatementContext &stmtCtx,
480 mlir::omp::PriorityClauseOps &result) const {
481 if (auto *clause = findUniqueClause<omp::clause::Priority>()) {
482 result.priority = fir::getBase(converter.genExprValue(clause->v, stmtCtx));
483 return true;
484 }
485 return false;
486}
487
488bool ClauseProcessor::processDetach(mlir::omp::DetachClauseOps &result) const {
489 if (auto *clause = findUniqueClause<omp::clause::Detach>()) {
490 semantics::Symbol *sym = clause->v.sym();
491 mlir::Value symVal = converter.getSymbolAddress(*sym);
492 result.eventHandle = symVal;
493 return true;
494 }
495 return false;
496}
497
498bool ClauseProcessor::processProcBind(
499 mlir::omp::ProcBindClauseOps &result) const {
500 if (auto *clause = findUniqueClause<omp::clause::ProcBind>()) {
501 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
502 result.procBindKind = genProcBindKindAttr(firOpBuilder, *clause);
503 return true;
504 }
505 return false;
506}
507
508bool ClauseProcessor::processSafelen(
509 mlir::omp::SafelenClauseOps &result) const {
510 if (auto *clause = findUniqueClause<omp::clause::Safelen>()) {
511 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
512 const std::optional<std::int64_t> safelenVal = evaluate::ToInt64(clause->v);
513 result.safelen = firOpBuilder.getI64IntegerAttr(*safelenVal);
514 return true;
515 }
516 return false;
517}
518
519bool ClauseProcessor::processSchedule(
520 lower::StatementContext &stmtCtx,
521 mlir::omp::ScheduleClauseOps &result) const {
522 if (auto *clause = findUniqueClause<omp::clause::Schedule>()) {
523 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
524 mlir::MLIRContext *context = firOpBuilder.getContext();
525 const auto &scheduleType = std::get<omp::clause::Schedule::Kind>(clause->t);
526
527 mlir::omp::ClauseScheduleKind scheduleKind;
528 switch (scheduleType) {
529 case omp::clause::Schedule::Kind::Static:
530 scheduleKind = mlir::omp::ClauseScheduleKind::Static;
531 break;
532 case omp::clause::Schedule::Kind::Dynamic:
533 scheduleKind = mlir::omp::ClauseScheduleKind::Dynamic;
534 break;
535 case omp::clause::Schedule::Kind::Guided:
536 scheduleKind = mlir::omp::ClauseScheduleKind::Guided;
537 break;
538 case omp::clause::Schedule::Kind::Auto:
539 scheduleKind = mlir::omp::ClauseScheduleKind::Auto;
540 break;
541 case omp::clause::Schedule::Kind::Runtime:
542 scheduleKind = mlir::omp::ClauseScheduleKind::Runtime;
543 break;
544 }
545
546 result.scheduleKind =
547 mlir::omp::ClauseScheduleKindAttr::get(context, scheduleKind);
548
549 mlir::omp::ScheduleModifier scheduleMod = getScheduleModifier(*clause);
550 if (scheduleMod != mlir::omp::ScheduleModifier::none)
551 result.scheduleMod =
552 mlir::omp::ScheduleModifierAttr::get(context, scheduleMod);
553
554 if (getSimdModifier(*clause) != mlir::omp::ScheduleModifier::none)
555 result.scheduleSimd = firOpBuilder.getUnitAttr();
556
557 if (const auto &chunkExpr = std::get<omp::MaybeExpr>(clause->t))
558 result.scheduleChunk =
559 fir::getBase(converter.genExprValue(*chunkExpr, stmtCtx));
560
561 return true;
562 }
563 return false;
564}
565
566bool ClauseProcessor::processSimdlen(
567 mlir::omp::SimdlenClauseOps &result) const {
568 if (auto *clause = findUniqueClause<omp::clause::Simdlen>()) {
569 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
570 const std::optional<std::int64_t> simdlenVal = evaluate::ToInt64(clause->v);
571 result.simdlen = firOpBuilder.getI64IntegerAttr(*simdlenVal);
572 return true;
573 }
574 return false;
575}
576
577bool ClauseProcessor::processThreadLimit(
578 lower::StatementContext &stmtCtx,
579 mlir::omp::ThreadLimitClauseOps &result) const {
580 if (auto *clause = findUniqueClause<omp::clause::ThreadLimit>()) {
581 result.threadLimit =
582 fir::getBase(converter.genExprValue(clause->v, stmtCtx));
583 return true;
584 }
585 return false;
586}
587
588bool ClauseProcessor::processUntied(mlir::omp::UntiedClauseOps &result) const {
589 return markClauseOccurrence<omp::clause::Untied>(result.untied);
590}
591
592//===----------------------------------------------------------------------===//
593// ClauseProcessor repeatable clauses
594//===----------------------------------------------------------------------===//
595static llvm::StringMap<bool> getTargetFeatures(mlir::ModuleOp module) {
596 llvm::StringMap<bool> featuresMap;
597 llvm::SmallVector<llvm::StringRef> targetFeaturesVec;
598 if (mlir::LLVM::TargetFeaturesAttr features =
599 fir::getTargetFeatures(module)) {
600 llvm::ArrayRef<mlir::StringAttr> featureAttrs = features.getFeatures();
601 for (auto &featureAttr : featureAttrs) {
602 llvm::StringRef featureKeyString = featureAttr.strref();
603 featuresMap[featureKeyString.substr(1)] = (featureKeyString[0] == '+');
604 }
605 }
606 return featuresMap;
607}
608
609static void
610addAlignedClause(lower::AbstractConverter &converter,
611 const omp::clause::Aligned &clause,
612 llvm::SmallVectorImpl<mlir::Value> &alignedVars,
613 llvm::SmallVectorImpl<mlir::Attribute> &alignments) {
614 using Aligned = omp::clause::Aligned;
615 lower::StatementContext stmtCtx;
616 mlir::IntegerAttr alignmentValueAttr;
617 int64_t alignment = 0;
618 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
619
620 if (auto &alignmentValueParserExpr =
621 std::get<std::optional<Aligned::Alignment>>(clause.t)) {
622 mlir::Value operand = fir::getBase(
623 converter.genExprValue(*alignmentValueParserExpr, stmtCtx));
624 alignment = *fir::getIntIfConstant(operand);
625 } else {
626 llvm::StringMap<bool> featuresMap = getTargetFeatures(builder.getModule());
627 llvm::Triple triple = fir::getTargetTriple(builder.getModule());
628 alignment =
629 llvm::OpenMPIRBuilder::getOpenMPDefaultSimdAlign(TargetTriple: triple, Features: featuresMap);
630 }
631
632 // The default alignment for some targets is equal to 0.
633 // Do not generate alignment assumption if alignment is less than or equal to
634 // 0.
635 if (alignment > 0) {
636 // alignment value must be power of 2
637 assert((alignment & (alignment - 1)) == 0 && "alignment is not power of 2");
638 auto &objects = std::get<omp::ObjectList>(clause.t);
639 if (!objects.empty())
640 genObjectList(objects, converter, alignedVars);
641 alignmentValueAttr = builder.getI64IntegerAttr(alignment);
642 // All the list items in a aligned clause will have same alignment
643 for (std::size_t i = 0; i < objects.size(); i++)
644 alignments.push_back(alignmentValueAttr);
645 }
646}
647
648bool ClauseProcessor::processAligned(
649 mlir::omp::AlignedClauseOps &result) const {
650 return findRepeatableClause<omp::clause::Aligned>(
651 [&](const omp::clause::Aligned &clause, const parser::CharBlock &) {
652 addAlignedClause(converter, clause, result.alignedVars,
653 result.alignments);
654 });
655}
656
657bool ClauseProcessor::processAllocate(
658 mlir::omp::AllocateClauseOps &result) const {
659 return findRepeatableClause<omp::clause::Allocate>(
660 [&](const omp::clause::Allocate &clause, const parser::CharBlock &) {
661 genAllocateClause(converter, clause, result.allocatorVars,
662 result.allocateVars);
663 });
664}
665
666bool ClauseProcessor::processCopyin() const {
667 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
668 mlir::OpBuilder::InsertPoint insPt = firOpBuilder.saveInsertionPoint();
669 firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock());
670 auto checkAndCopyHostAssociateVar =
671 [&](semantics::Symbol *sym,
672 mlir::OpBuilder::InsertPoint *copyAssignIP = nullptr) {
673 assert(sym->has<semantics::HostAssocDetails>() &&
674 "No host-association found");
675 if (converter.isPresentShallowLookup(*sym))
676 converter.copyHostAssociateVar(*sym, copyAssignIP);
677 };
678 bool hasCopyin = findRepeatableClause<omp::clause::Copyin>(
679 [&](const omp::clause::Copyin &clause, const parser::CharBlock &) {
680 for (const omp::Object &object : clause.v) {
681 semantics::Symbol *sym = object.sym();
682 assert(sym && "Expecting symbol");
683 if (const auto *commonDetails =
684 sym->detailsIf<semantics::CommonBlockDetails>()) {
685 for (const auto &mem : commonDetails->objects())
686 checkAndCopyHostAssociateVar(&*mem, &insPt);
687 break;
688 }
689
690 assert(sym->has<semantics::HostAssocDetails>() &&
691 "No host-association found");
692 checkAndCopyHostAssociateVar(sym);
693 }
694 });
695
696 // [OMP 5.0, 2.19.6.1] The copy is done after the team is formed and prior to
697 // the execution of the associated structured block. Emit implicit barrier to
698 // synchronize threads and avoid data races on propagation master's thread
699 // values of threadprivate variables to local instances of that variables of
700 // all other implicit threads.
701
702 // All copies are inserted at either "insPt" (i.e. immediately before it),
703 // or at some earlier point (as determined by "copyHostAssociateVar").
704 // Unless the insertion point is given to "copyHostAssociateVar" explicitly,
705 // it will not restore the builder's insertion point. Since the copies may be
706 // inserted in any order (not following the execution order), make sure the
707 // barrier is inserted following all of them.
708 firOpBuilder.restoreInsertionPoint(insPt);
709 if (hasCopyin)
710 firOpBuilder.create<mlir::omp::BarrierOp>(converter.getCurrentLocation());
711 return hasCopyin;
712}
713
714/// Class that extracts information from the specified type.
715class TypeInfo {
716public:
717 TypeInfo(mlir::Type ty) { typeScan(ty); }
718
719 // Returns the length of character types.
720 std::optional<fir::CharacterType::LenType> getCharLength() const {
721 return charLen;
722 }
723
724 // Returns the shape of array types.
725 llvm::ArrayRef<int64_t> getShape() const { return shape; }
726
727 // Is the type inside a box?
728 bool isBox() const { return inBox; }
729
730private:
731 void typeScan(mlir::Type type);
732
733 std::optional<fir::CharacterType::LenType> charLen;
734 llvm::SmallVector<int64_t> shape;
735 bool inBox = false;
736};
737
738void TypeInfo::typeScan(mlir::Type ty) {
739 if (auto sty = mlir::dyn_cast<fir::SequenceType>(ty)) {
740 assert(shape.empty() && !sty.getShape().empty());
741 shape = llvm::SmallVector<int64_t>(sty.getShape());
742 typeScan(sty.getEleTy());
743 } else if (auto bty = mlir::dyn_cast<fir::BoxType>(ty)) {
744 inBox = true;
745 typeScan(bty.getEleTy());
746 } else if (auto cty = mlir::dyn_cast<fir::ClassType>(ty)) {
747 inBox = true;
748 typeScan(cty.getEleTy());
749 } else if (auto cty = mlir::dyn_cast<fir::CharacterType>(ty)) {
750 charLen = cty.getLen();
751 } else if (auto hty = mlir::dyn_cast<fir::HeapType>(ty)) {
752 typeScan(hty.getEleTy());
753 } else if (auto pty = mlir::dyn_cast<fir::PointerType>(ty)) {
754 typeScan(pty.getEleTy());
755 } else {
756 // The scan ends when reaching any built-in, record or boxproc type.
757 assert(ty.isIntOrIndexOrFloat() || mlir::isa<mlir::ComplexType>(ty) ||
758 mlir::isa<fir::LogicalType>(ty) || mlir::isa<fir::RecordType>(ty) ||
759 mlir::isa<fir::BoxProcType>(ty));
760 }
761}
762
763// Create a function that performs a copy between two variables, compatible
764// with their types and attributes.
765static mlir::func::FuncOp
766createCopyFunc(mlir::Location loc, lower::AbstractConverter &converter,
767 mlir::Type varType, fir::FortranVariableFlagsEnum varAttrs) {
768 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
769 mlir::ModuleOp module = builder.getModule();
770 mlir::Type eleTy = mlir::cast<fir::ReferenceType>(varType).getEleTy();
771 TypeInfo typeInfo(eleTy);
772 std::string copyFuncName =
773 fir::getTypeAsString(eleTy, builder.getKindMap(), "_copy");
774
775 if (auto decl = module.lookupSymbol<mlir::func::FuncOp>(copyFuncName))
776 return decl;
777
778 // create function
779 mlir::OpBuilder::InsertionGuard guard(builder);
780 mlir::OpBuilder modBuilder(module.getBodyRegion());
781 llvm::SmallVector<mlir::Type> argsTy = {varType, varType};
782 auto funcType = mlir::FunctionType::get(builder.getContext(), argsTy, {});
783 mlir::func::FuncOp funcOp =
784 modBuilder.create<mlir::func::FuncOp>(loc, copyFuncName, funcType);
785 funcOp.setVisibility(mlir::SymbolTable::Visibility::Private);
786 fir::factory::setInternalLinkage(funcOp);
787 builder.createBlock(&funcOp.getRegion(), funcOp.getRegion().end(), argsTy,
788 {loc, loc});
789 builder.setInsertionPointToStart(&funcOp.getRegion().back());
790 // generate body
791 fir::FortranVariableFlagsAttr attrs;
792 if (varAttrs != fir::FortranVariableFlagsEnum::None)
793 attrs = fir::FortranVariableFlagsAttr::get(builder.getContext(), varAttrs);
794 llvm::SmallVector<mlir::Value> typeparams;
795 if (typeInfo.getCharLength().has_value()) {
796 mlir::Value charLen = builder.createIntegerConstant(
797 loc, builder.getCharacterLengthType(), *typeInfo.getCharLength());
798 typeparams.push_back(charLen);
799 }
800 mlir::Value shape;
801 if (!typeInfo.isBox() && !typeInfo.getShape().empty()) {
802 llvm::SmallVector<mlir::Value> extents;
803 for (auto extent : typeInfo.getShape())
804 extents.push_back(
805 builder.createIntegerConstant(loc, builder.getIndexType(), extent));
806 shape = builder.create<fir::ShapeOp>(loc, extents);
807 }
808 auto declDst = builder.create<hlfir::DeclareOp>(
809 loc, funcOp.getArgument(0), copyFuncName + "_dst", shape, typeparams,
810 /*dummy_scope=*/nullptr, attrs);
811 auto declSrc = builder.create<hlfir::DeclareOp>(
812 loc, funcOp.getArgument(1), copyFuncName + "_src", shape, typeparams,
813 /*dummy_scope=*/nullptr, attrs);
814 converter.copyVar(loc, declDst.getBase(), declSrc.getBase(), varAttrs);
815 builder.create<mlir::func::ReturnOp>(loc);
816 return funcOp;
817}
818
819bool ClauseProcessor::processCopyprivate(
820 mlir::Location currentLocation,
821 mlir::omp::CopyprivateClauseOps &result) const {
822 auto addCopyPrivateVar = [&](semantics::Symbol *sym) {
823 mlir::Value symVal = converter.getSymbolAddress(*sym);
824 auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>();
825 if (!declOp)
826 fir::emitFatalError(currentLocation,
827 "COPYPRIVATE is supported only in HLFIR mode");
828 symVal = declOp.getBase();
829 mlir::Type symType = symVal.getType();
830 fir::FortranVariableFlagsEnum attrs =
831 declOp.getFortranAttrs().has_value()
832 ? *declOp.getFortranAttrs()
833 : fir::FortranVariableFlagsEnum::None;
834 mlir::Value cpVar = symVal;
835
836 // CopyPrivate variables must be passed by reference. However, in the case
837 // of assumed shapes/vla the type is not a !fir.ref, but a !fir.box.
838 // In these cases to retrieve the appropriate !fir.ref<!fir.box<...>> to
839 // access the data we need we must perform an alloca and then store to it
840 // and retrieve the data from the new alloca.
841 if (mlir::isa<fir::BaseBoxType>(symType)) {
842 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
843 auto alloca = builder.create<fir::AllocaOp>(currentLocation, symType);
844 builder.create<fir::StoreOp>(currentLocation, symVal, alloca);
845 cpVar = alloca;
846 }
847
848 result.copyprivateVars.push_back(cpVar);
849 mlir::func::FuncOp funcOp =
850 createCopyFunc(currentLocation, converter, cpVar.getType(), attrs);
851 result.copyprivateSyms.push_back(mlir::SymbolRefAttr::get(funcOp));
852 };
853
854 bool hasCopyPrivate = findRepeatableClause<clause::Copyprivate>(
855 [&](const clause::Copyprivate &clause, const parser::CharBlock &) {
856 for (const Object &object : clause.v) {
857 semantics::Symbol *sym = object.sym();
858 if (const auto *commonDetails =
859 sym->detailsIf<semantics::CommonBlockDetails>()) {
860 for (const auto &mem : commonDetails->objects())
861 addCopyPrivateVar(&*mem);
862 break;
863 }
864 addCopyPrivateVar(sym);
865 }
866 });
867
868 return hasCopyPrivate;
869}
870
871template <typename T>
872static bool isVectorSubscript(const evaluate::Expr<T> &expr) {
873 if (std::optional<evaluate::DataRef> dataRef{evaluate::ExtractDataRef(expr)})
874 if (const auto *arrayRef = std::get_if<evaluate::ArrayRef>(&dataRef->u))
875 for (const evaluate::Subscript &subscript : arrayRef->subscript())
876 if (std::holds_alternative<evaluate::IndirectSubscriptIntegerExpr>(
877 subscript.u))
878 if (subscript.Rank() > 0)
879 return true;
880 return false;
881}
882
883bool ClauseProcessor::processDefaultMap(lower::StatementContext &stmtCtx,
884 DefaultMapsTy &result) const {
885 auto process = [&](const omp::clause::Defaultmap &clause,
886 const parser::CharBlock &) {
887 using Defmap = omp::clause::Defaultmap;
888 clause::Defaultmap::VariableCategory variableCategory =
889 Defmap::VariableCategory::All;
890 // Variable Category is optional, if not specified defaults to all.
891 // Multiples of the same category are illegal as are any other
892 // defaultmaps being specified when a user specified all is in place,
893 // however, this should be handled earlier during semantics.
894 if (auto varCat =
895 std::get<std::optional<Defmap::VariableCategory>>(clause.t))
896 variableCategory = varCat.value();
897 auto behaviour = std::get<Defmap::ImplicitBehavior>(clause.t);
898 result[variableCategory] = behaviour;
899 };
900 return findRepeatableClause<omp::clause::Defaultmap>(process);
901}
902
/// Lower the DEPEND clauses of the current directive.
///
/// For every list item this appends one dependence-kind attribute to
/// `result.dependKinds` and one address to `result.dependVars`.
/// The following hit a TODO (fatal, not yet implemented):
/// - non-task dependences (SINK/SOURCE),
/// - iterator modifiers,
/// - substring list items.
///
/// \param symMap  Symbol map used when lowering array sections and structure
///                components to HLFIR.
/// \param stmtCtx Statement context for those expression lowerings.
/// \param result  Receives the dependence kinds and variables.
/// \returns true if at least one DEPEND clause was found.
bool ClauseProcessor::processDepend(lower::SymMap &symMap,
                                    lower::StatementContext &stmtCtx,
                                    mlir::omp::DependClauseOps &result) const {
  auto process = [&](const omp::clause::Depend &clause,
                     const parser::CharBlock &) {
    using Depend = omp::clause::Depend;
    if (!std::holds_alternative<Depend::TaskDep>(clause.u)) {
      TODO(converter.getCurrentLocation(),
           "DEPEND clause with SINK or SOURCE is not supported yet");
    }
    auto &taskDep = std::get<Depend::TaskDep>(clause.u);
    auto depType = std::get<clause::DependenceType>(taskDep.t);
    auto &objects = std::get<omp::ObjectList>(taskDep.t);
    fir::FirOpBuilder &builder = converter.getFirOpBuilder();

    if (std::get<std::optional<omp::clause::Iterator>>(taskDep.t)) {
      TODO(converter.getCurrentLocation(),
           "Support for iterator modifiers is not implemented yet");
    }
    // All list items of one clause share its dependence type, so push one
    // copy of the kind attribute per object, keeping dependKinds and
    // dependVars index-aligned.
    mlir::omp::ClauseTaskDependAttr dependTypeOperand =
        genDependKindAttr(converter, depType);
    result.dependKinds.append(objects.size(), dependTypeOperand);

    for (const omp::Object &object : objects) {
      assert(object.ref() && "Expecting designator");
      mlir::Value dependVar;

      if (evaluate::ExtractSubstring(*object.ref())) {
        TODO(converter.getCurrentLocation(),
             "substring not supported for task depend");
      } else if (evaluate::IsArrayElement(*object.ref())) {
        // Array Section
        SomeExpr expr = *object.ref();

        if (isVectorSubscript(expr)) {
          // OpenMP needs the address of the first indexed element (required by
          // the standard to be the lowest index) to identify the dependency. We
          // don't need an accurate length for the array section because the
          // OpenMP standard forbids overlapping array sections.
          dependVar = genVectorSubscriptedDesignatorFirstElementAddress(
              converter.getCurrentLocation(), converter, expr, symMap, stmtCtx);
        } else {
          // Ordinary array section e.g. A(1:512:2)
          hlfir::EntityWithAttributes entity = convertExprToHLFIR(
              converter.getCurrentLocation(), converter, expr, symMap, stmtCtx);
          dependVar = entity.getBase();
        }
      } else if (evaluate::isStructureComponent(*object.ref())) {
        // Derived-type component: lower the designator to HLFIR and use its
        // base value.
        SomeExpr expr = *object.ref();
        hlfir::EntityWithAttributes entity = convertExprToHLFIR(
            converter.getCurrentLocation(), converter, expr, symMap, stmtCtx);
        dependVar = entity.getBase();
      } else {
        // Whole variable: use the address bound to its symbol.
        semantics::Symbol *sym = object.sym();
        dependVar = converter.getSymbolAddress(*sym);
      }

      // If we pass a mutable box e.g. !fir.ref<!fir.box<!fir.heap<...>>> then
      // the runtime will use the address of the box not the address of the
      // data. Flang generates a lot of memcpys between different box
      // allocations so this is not a reliable way to identify the dependency.
      if (auto ref = mlir::dyn_cast<fir::ReferenceType>(dependVar.getType()))
        if (fir::isa_box_type(ref.getElementType()))
          dependVar = builder.create<fir::LoadOp>(
              converter.getCurrentLocation(), dependVar);

      // The openmp dialect doesn't know what to do with boxes (and it would
      // break layering to teach it about them). The dependency variable can be
      // a box because it was an array section or because the original symbol
      // was mapped to a box.
      // Getting the address of the box data is okay because all the runtime
      // ultimately cares about is the base address of the array.
      if (fir::isa_box_type(dependVar.getType()))
        dependVar = builder.create<fir::BoxAddrOp>(
            converter.getCurrentLocation(), dependVar);

      result.dependVars.push_back(dependVar);
    }
  };

  return findRepeatableClause<omp::clause::Depend>(process);
}
985
986bool ClauseProcessor::processGrainsize(
987 lower::StatementContext &stmtCtx,
988 mlir::omp::GrainsizeClauseOps &result) const {
989 using Grainsize = omp::clause::Grainsize;
990 if (auto *clause = findUniqueClause<Grainsize>()) {
991 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
992 mlir::MLIRContext *context = firOpBuilder.getContext();
993 const auto &modifier =
994 std::get<std::optional<Grainsize::Prescriptiveness>>(clause->t);
995 if (modifier && *modifier == Grainsize::Prescriptiveness::Strict) {
996 result.grainsizeMod = mlir::omp::ClauseGrainsizeTypeAttr::get(
997 context, mlir::omp::ClauseGrainsizeType::Strict);
998 }
999 const auto &grainsizeExpr = std::get<omp::SomeExpr>(clause->t);
1000 result.grainsize =
1001 fir::getBase(converter.genExprValue(grainsizeExpr, stmtCtx));
1002 return true;
1003 }
1004 return false;
1005}
1006
1007bool ClauseProcessor::processHasDeviceAddr(
1008 lower::StatementContext &stmtCtx, mlir::omp::HasDeviceAddrClauseOps &result,
1009 llvm::SmallVectorImpl<const semantics::Symbol *> &hasDeviceSyms) const {
1010 // For HAS_DEVICE_ADDR objects, implicitly map the top-level entities.
1011 // Their address (or the whole descriptor, if the entity had one) will be
1012 // passed to the target region.
1013 std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;
1014 bool clauseFound = findRepeatableClause<omp::clause::HasDeviceAddr>(
1015 [&](const omp::clause::HasDeviceAddr &clause,
1016 const parser::CharBlock &source) {
1017 mlir::Location location = converter.genLocation(source);
1018 llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
1019 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
1020 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
1021 omp::ObjectList baseObjects;
1022 llvm::transform(clause.v, std::back_inserter(baseObjects),
1023 [&](const omp::Object &object) {
1024 if (auto maybeBase = getBaseObject(object, semaCtx))
1025 return *maybeBase;
1026 return object;
1027 });
1028 processMapObjects(stmtCtx, location, baseObjects, mapTypeBits,
1029 parentMemberIndices, result.hasDeviceAddrVars,
1030 hasDeviceSyms);
1031 });
1032
1033 insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
1034 result.hasDeviceAddrVars, hasDeviceSyms);
1035 return clauseFound;
1036}
1037
1038bool ClauseProcessor::processIf(
1039 omp::clause::If::DirectiveNameModifier directiveName,
1040 mlir::omp::IfClauseOps &result) const {
1041 bool found = false;
1042 findRepeatableClause<omp::clause::If>([&](const omp::clause::If &clause,
1043 const parser::CharBlock &source) {
1044 mlir::Location clauseLocation = converter.genLocation(source);
1045 mlir::Value operand =
1046 getIfClauseOperand(converter, clause, directiveName, clauseLocation);
1047 // Assume that, at most, a single 'if' clause will be applicable to the
1048 // given directive.
1049 if (operand) {
1050 result.ifExpr = operand;
1051 found = true;
1052 }
1053 });
1054 return found;
1055}
1056bool ClauseProcessor::processInReduction(
1057 mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result,
1058 llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const {
1059 return findRepeatableClause<omp::clause::InReduction>(
1060 [&](const omp::clause::InReduction &clause, const parser::CharBlock &) {
1061 llvm::SmallVector<mlir::Value> inReductionVars;
1062 llvm::SmallVector<bool> inReduceVarByRef;
1063 llvm::SmallVector<mlir::Attribute> inReductionDeclSymbols;
1064 llvm::SmallVector<const semantics::Symbol *> inReductionSyms;
1065 ReductionProcessor rp;
1066 rp.processReductionArguments<omp::clause::InReduction>(
1067 currentLocation, converter, clause, inReductionVars,
1068 inReduceVarByRef, inReductionDeclSymbols, inReductionSyms);
1069
1070 // Copy local lists into the output.
1071 llvm::copy(inReductionVars, std::back_inserter(result.inReductionVars));
1072 llvm::copy(inReduceVarByRef,
1073 std::back_inserter(result.inReductionByref));
1074 llvm::copy(inReductionDeclSymbols,
1075 std::back_inserter(result.inReductionSyms));
1076 llvm::copy(inReductionSyms, std::back_inserter(outReductionSyms));
1077 });
1078}
1079
1080bool ClauseProcessor::processIsDevicePtr(
1081 mlir::omp::IsDevicePtrClauseOps &result,
1082 llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const {
1083 return findRepeatableClause<omp::clause::IsDevicePtr>(
1084 [&](const omp::clause::IsDevicePtr &devPtrClause,
1085 const parser::CharBlock &) {
1086 addUseDeviceClause(converter, devPtrClause.v, result.isDevicePtrVars,
1087 isDeviceSyms);
1088 });
1089}
1090
1091bool ClauseProcessor::processLinear(mlir::omp::LinearClauseOps &result) const {
1092 lower::StatementContext stmtCtx;
1093 return findRepeatableClause<
1094 omp::clause::Linear>([&](const omp::clause::Linear &clause,
1095 const parser::CharBlock &) {
1096 auto &objects = std::get<omp::ObjectList>(clause.t);
1097 for (const omp::Object &object : objects) {
1098 semantics::Symbol *sym = object.sym();
1099 const mlir::Value variable = converter.getSymbolAddress(*sym);
1100 result.linearVars.push_back(variable);
1101 }
1102 if (objects.size()) {
1103 if (auto &mod =
1104 std::get<std::optional<omp::clause::Linear::StepComplexModifier>>(
1105 clause.t)) {
1106 mlir::Value operand =
1107 fir::getBase(converter.genExprValue(toEvExpr(*mod), stmtCtx));
1108 result.linearStepVars.append(objects.size(), operand);
1109 } else if (std::get<std::optional<omp::clause::Linear::LinearModifier>>(
1110 clause.t)) {
1111 mlir::Location currentLocation = converter.getCurrentLocation();
1112 TODO(currentLocation, "Linear modifiers not yet implemented");
1113 } else {
1114 // If nothing is present, add the default step of 1.
1115 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
1116 mlir::Location currentLocation = converter.getCurrentLocation();
1117 mlir::Value operand = firOpBuilder.createIntegerConstant(
1118 currentLocation, firOpBuilder.getI32Type(), 1);
1119 result.linearStepVars.append(objects.size(), operand);
1120 }
1121 }
1122 });
1123}
1124
1125bool ClauseProcessor::processLink(
1126 llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
1127 return findRepeatableClause<omp::clause::Link>(
1128 [&](const omp::clause::Link &clause, const parser::CharBlock &) {
1129 // Case: declare target link(var1, var2)...
1130 gatherFuncAndVarSyms(
1131 clause.v, mlir::omp::DeclareTargetCaptureClause::link, result);
1132 });
1133}
1134
/// Create an mlir::omp::MapInfoOp for every object in \p objects.
///
/// A map op for a derived-type component (an object whose symbol is owned by
/// a derived type) is attached to its parent through \p parentMemberIndices.
/// Every other map op is appended to \p mapVars / \p mapSyms.
///
/// \param stmtCtx             Statement context used while gathering bounds
///                            and intermediate parent maps.
/// \param clauseLocation      Location of the originating clause.
/// \param objects             The list items to map.
/// \param mapTypeBits         Offload mapping flags stored on every map op.
/// \param parentMemberIndices Per-parent member bookkeeping; an entry is
///                            created for the parent of each component object.
/// \param mapVars             Receives map ops of non-component objects.
/// \param mapSyms             Receives the symbols matching \p mapVars.
/// \param mapperIdNameRef     Mapper selection:
///                            - empty: no mapper;
///                            - "default": the default mapper derived from the
///                              first object only;
///                            - "__implicit_mapper": each object's default
///                              mapper, when such a symbol exists in the
///                              module;
///                            - anything else: an explicit mapper symbol name,
///                              which must exist in the module (asserted).
void ClauseProcessor::processMapObjects(
    lower::StatementContext &stmtCtx, mlir::Location clauseLocation,
    const omp::ObjectList &objects,
    llvm::omp::OpenMPOffloadMappingFlags mapTypeBits,
    std::map<Object, OmpMapParentAndMemberData> &parentMemberIndices,
    llvm::SmallVectorImpl<mlir::Value> &mapVars,
    llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms,
    llvm::StringRef mapperIdNameRef) const {
  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

  // Compute the default-mapper name for `object` into `mapperIdName`.
  //
  // The name is `<TypeName>` + OmpDefaultMapperName, where TypeName comes from:
  // - the owning derived type, when the symbol is a component;
  // - otherwise the symbol's own derived type.
  // The name is mangled when the symbol is found in the current scope.
  // `mapperIdName` is left untouched when there is no derived type, or when
  // the builder is currently inside an omp.declare_mapper op.
  // NOTE(review): for a component, the *containing* type's mapper is chosen,
  // not the component's own type -- confirm this is intended.
  auto getDefaultMapperID = [&](const omp::Object &object,
                                std::string &mapperIdName) {
    if (!mlir::isa<mlir::omp::DeclareMapperOp>(
            firOpBuilder.getRegion().getParentOp())) {
      const semantics::DerivedTypeSpec *typeSpec = nullptr;

      if (object.sym()->owner().IsDerivedType())
        typeSpec = object.sym()->owner().derivedTypeSpec();
      else if (object.sym()->GetType() &&
               object.sym()->GetType()->category() ==
                   semantics::DeclTypeSpec::TypeDerived)
        typeSpec = &object.sym()->GetType()->derivedTypeSpec();

      if (typeSpec) {
        mapperIdName =
            typeSpec->name().ToString() + llvm::omp::OmpDefaultMapperName;
        if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName))
          mapperIdName = converter.mangleName(mapperIdName, sym->owner());
      }
    }
  };

  // Create the mapper symbol from its name, if specified.
  // The explicit/"default" case is resolved once for the whole list, using
  // the first object; "__implicit_mapper" is resolved per object below.
  mlir::FlatSymbolRefAttr mapperId;
  if (!mapperIdNameRef.empty() && !objects.empty() &&
      mapperIdNameRef != "__implicit_mapper") {
    std::string mapperIdName = mapperIdNameRef.str();
    const omp::Object &object = objects.front();
    if (mapperIdNameRef == "default")
      getDefaultMapperID(object, mapperIdName);
    assert(converter.getModuleOp().lookupSymbol(mapperIdName) &&
           "mapper not found");
    mapperId =
        mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(), mapperIdName);
  }

  for (const omp::Object &object : objects) {
    llvm::SmallVector<mlir::Value> bounds;
    std::stringstream asFortran;
    std::optional<omp::Object> parentObj;

    // Address of the object plus any bounds; `asFortran` receives a textual
    // name for the object that is used for the map op's name and location.
    fir::factory::AddrAndBoundsInfo info =
        lower::gatherDataOperandAddrAndBounds<mlir::omp::MapBoundsOp,
                                              mlir::omp::MapBoundsType>(
            converter, firOpBuilder, semaCtx, stmtCtx, *object.sym(),
            object.ref(), clauseLocation, asFortran, bounds,
            treatIndexAsSection);

    mlir::Value baseOp = info.rawInput;
    if (object.sym()->owner().IsDerivedType()) {
      // Component of a derived type: its map is recorded under the first
      // object returned by gatherObjectsOf instead of being a top-level
      // operand.
      omp::ObjectList objectList = gatherObjectsOf(object, semaCtx);
      assert(!objectList.empty() &&
             "could not find parent objects of derived type member");
      parentObj = objectList[0];
      parentMemberIndices.emplace(parentObj.value(),
                                  OmpMapParentAndMemberData{});

      // With an allocatable/pointer member or parent, the base address is
      // replaced by the result of createParentSymAndGenIntermediateMaps.
      if (isMemberOrParentAllocatableOrPointer(object, semaCtx)) {
        llvm::SmallVector<int64_t> indices;
        generateMemberPlacementIndices(object, indices, semaCtx);
        baseOp = createParentSymAndGenIntermediateMaps(
            clauseLocation, converter, semaCtx, stmtCtx, objectList, indices,
            parentMemberIndices[parentObj.value()], asFortran.str(),
            mapTypeBits);
      }
    }

    // Implicit mapper: use this object's default mapper only if such a symbol
    // exists in the module; otherwise fall back to no mapper.
    if (mapperIdNameRef == "__implicit_mapper") {
      std::string mapperIdName;
      getDefaultMapperID(object, mapperIdName);
      mapperId = converter.getModuleOp().lookupSymbol(mapperIdName)
                     ? mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
                                                    mapperIdName)
                     : mlir::FlatSymbolRefAttr();
    }

    // Explicit map captures are captured ByRef by default,
    // optimisation passes may alter this to ByCopy or other capture
    // types to optimise
    auto location = mlir::NameLoc::get(
        mlir::StringAttr::get(firOpBuilder.getContext(), asFortran.str()),
        baseOp.getLoc());
    mlir::omp::MapInfoOp mapOp = createMapInfoOp(
        firOpBuilder, location, baseOp,
        /*varPtrPtr=*/mlir::Value{}, asFortran.str(), bounds,
        /*members=*/{}, /*membersIndex=*/mlir::ArrayAttr{},
        static_cast<
            std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
            mapTypeBits),
        mlir::omp::VariableCaptureKind::ByRef, baseOp.getType(),
        /*partialMap=*/false, mapperId);

    // Components are attached to their parent; everything else becomes a
    // top-level map operand.
    if (parentObj.has_value()) {
      parentMemberIndices[parentObj.value()].addChildIndexAndMapToParent(
          object, mapOp, semaCtx);
    } else {
      mapVars.push_back(mapOp);
      mapSyms.push_back(object.sym());
    }
  }
}
1246
1247bool ClauseProcessor::processMap(
1248 mlir::Location currentLocation, lower::StatementContext &stmtCtx,
1249 mlir::omp::MapClauseOps &result,
1250 llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms) const {
1251 // We always require tracking of symbols, even if the caller does not,
1252 // so we create an optionally used local set of symbols when the mapSyms
1253 // argument is not present.
1254 llvm::SmallVector<const semantics::Symbol *> localMapSyms;
1255 llvm::SmallVectorImpl<const semantics::Symbol *> *ptrMapSyms =
1256 mapSyms ? mapSyms : &localMapSyms;
1257 std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;
1258
1259 auto process = [&](const omp::clause::Map &clause,
1260 const parser::CharBlock &source) {
1261 using Map = omp::clause::Map;
1262 mlir::Location clauseLocation = converter.genLocation(source);
1263 const auto &[mapType, typeMods, mappers, iterator, objects] = clause.t;
1264 llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
1265 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
1266 std::string mapperIdName = "__implicit_mapper";
1267 // If the map type is specified, then process it else Tofrom is the
1268 // default.
1269 Map::MapType type = mapType.value_or(Map::MapType::Tofrom);
1270 switch (type) {
1271 case Map::MapType::To:
1272 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
1273 break;
1274 case Map::MapType::From:
1275 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1276 break;
1277 case Map::MapType::Tofrom:
1278 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
1279 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1280 break;
1281 case Map::MapType::Alloc:
1282 case Map::MapType::Release:
1283 // alloc and release is the default map_type for the Target Data
1284 // Ops, i.e. if no bits for map_type is supplied then alloc/release
1285 // is implicitly assumed based on the target directive. Default
1286 // value for Target Data and Enter Data is alloc and for Exit Data
1287 // it is release.
1288 break;
1289 case Map::MapType::Delete:
1290 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
1291 }
1292
1293 if (typeMods) {
1294 // TODO: Still requires "self" modifier, an OpenMP 6.0+ feature
1295 if (llvm::is_contained(*typeMods, Map::MapTypeModifier::Always))
1296 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
1297 if (llvm::is_contained(*typeMods, Map::MapTypeModifier::Present))
1298 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
1299 if (llvm::is_contained(*typeMods, Map::MapTypeModifier::Close))
1300 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
1301 if (llvm::is_contained(*typeMods, Map::MapTypeModifier::OmpxHold))
1302 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;
1303 }
1304
1305 if (iterator) {
1306 TODO(currentLocation,
1307 "Support for iterator modifiers is not implemented yet");
1308 }
1309 if (mappers) {
1310 assert(mappers->size() == 1 && "more than one mapper");
1311 mapperIdName = mappers->front().v.id().symbol->name().ToString();
1312 if (mapperIdName != "default")
1313 mapperIdName = converter.mangleName(
1314 mapperIdName, mappers->front().v.id().symbol->owner());
1315 }
1316
1317 processMapObjects(stmtCtx, clauseLocation,
1318 std::get<omp::ObjectList>(clause.t), mapTypeBits,
1319 parentMemberIndices, result.mapVars, *ptrMapSyms,
1320 mapperIdName);
1321 };
1322
1323 bool clauseFound = findRepeatableClause<omp::clause::Map>(process);
1324 insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
1325 result.mapVars, *ptrMapSyms);
1326
1327 return clauseFound;
1328}
1329
1330bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx,
1331 mlir::omp::MapClauseOps &result) {
1332 std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;
1333 llvm::SmallVector<const semantics::Symbol *> mapSymbols;
1334
1335 auto callbackFn = [&](const auto &clause, const parser::CharBlock &source) {
1336 mlir::Location clauseLocation = converter.genLocation(source);
1337 const auto &[expectation, mapper, iterator, objects] = clause.t;
1338
1339 // TODO Support motion modifiers: mapper, iterator.
1340 if (mapper) {
1341 TODO(clauseLocation, "Mapper modifier is not supported yet");
1342 } else if (iterator) {
1343 TODO(clauseLocation, "Iterator modifier is not supported yet");
1344 }
1345
1346 llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
1347 std::is_same_v<llvm::remove_cvref_t<decltype(clause)>, omp::clause::To>
1348 ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO
1349 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1350 if (expectation && *expectation == omp::clause::To::Expectation::Present)
1351 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
1352 processMapObjects(stmtCtx, clauseLocation, objects, mapTypeBits,
1353 parentMemberIndices, result.mapVars, mapSymbols);
1354 };
1355
1356 bool clauseFound = findRepeatableClause<omp::clause::To>(callbackFn);
1357 clauseFound =
1358 findRepeatableClause<omp::clause::From>(callbackFn) || clauseFound;
1359
1360 insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
1361 result.mapVars, mapSymbols);
1362
1363 return clauseFound;
1364}
1365
1366bool ClauseProcessor::processNontemporal(
1367 mlir::omp::NontemporalClauseOps &result) const {
1368 return findRepeatableClause<omp::clause::Nontemporal>(
1369 [&](const omp::clause::Nontemporal &clause, const parser::CharBlock &) {
1370 for (const Object &object : clause.v) {
1371 semantics::Symbol *sym = object.sym();
1372 mlir::Value symVal = converter.getSymbolAddress(*sym);
1373 result.nontemporalVars.push_back(symVal);
1374 }
1375 });
1376}
1377
1378bool ClauseProcessor::processReduction(
1379 mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result,
1380 llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const {
1381 return findRepeatableClause<omp::clause::Reduction>(
1382 [&](const omp::clause::Reduction &clause, const parser::CharBlock &) {
1383 llvm::SmallVector<mlir::Value> reductionVars;
1384 llvm::SmallVector<bool> reduceVarByRef;
1385 llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
1386 llvm::SmallVector<const semantics::Symbol *> reductionSyms;
1387 ReductionProcessor rp;
1388 rp.processReductionArguments<omp::clause::Reduction>(
1389 currentLocation, converter, clause, reductionVars, reduceVarByRef,
1390 reductionDeclSymbols, reductionSyms, &result.reductionMod);
1391 // Copy local lists into the output.
1392 llvm::copy(reductionVars, std::back_inserter(result.reductionVars));
1393 llvm::copy(reduceVarByRef, std::back_inserter(result.reductionByref));
1394 llvm::copy(reductionDeclSymbols,
1395 std::back_inserter(result.reductionSyms));
1396 llvm::copy(reductionSyms, std::back_inserter(outReductionSyms));
1397 });
1398}
1399
1400bool ClauseProcessor::processTaskReduction(
1401 mlir::Location currentLocation, mlir::omp::TaskReductionClauseOps &result,
1402 llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const {
1403 return findRepeatableClause<omp::clause::TaskReduction>(
1404 [&](const omp::clause::TaskReduction &clause, const parser::CharBlock &) {
1405 llvm::SmallVector<mlir::Value> taskReductionVars;
1406 llvm::SmallVector<bool> TaskReduceVarByRef;
1407 llvm::SmallVector<mlir::Attribute> TaskReductionDeclSymbols;
1408 llvm::SmallVector<const semantics::Symbol *> TaskReductionSyms;
1409 ReductionProcessor rp;
1410 rp.processReductionArguments<omp::clause::TaskReduction>(
1411 currentLocation, converter, clause, taskReductionVars,
1412 TaskReduceVarByRef, TaskReductionDeclSymbols, TaskReductionSyms);
1413 // Copy local lists into the output.
1414 llvm::copy(taskReductionVars,
1415 std::back_inserter(result.taskReductionVars));
1416 llvm::copy(TaskReduceVarByRef,
1417 std::back_inserter(result.taskReductionByref));
1418 llvm::copy(TaskReductionDeclSymbols,
1419 std::back_inserter(result.taskReductionSyms));
1420 llvm::copy(TaskReductionSyms, std::back_inserter(outReductionSyms));
1421 });
1422}
1423
1424bool ClauseProcessor::processTo(
1425 llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
1426 return findRepeatableClause<omp::clause::To>(
1427 [&](const omp::clause::To &clause, const parser::CharBlock &) {
1428 // Case: declare target to(func, var1, var2)...
1429 gatherFuncAndVarSyms(std::get<ObjectList>(clause.t),
1430 mlir::omp::DeclareTargetCaptureClause::to, result);
1431 });
1432}
1433
1434bool ClauseProcessor::processEnter(
1435 llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
1436 return findRepeatableClause<omp::clause::Enter>(
1437 [&](const omp::clause::Enter &clause, const parser::CharBlock &) {
1438 // Case: declare target enter(func, var1, var2)...
1439 gatherFuncAndVarSyms(
1440 clause.v, mlir::omp::DeclareTargetCaptureClause::enter, result);
1441 });
1442}
1443
1444bool ClauseProcessor::processUseDeviceAddr(
1445 lower::StatementContext &stmtCtx, mlir::omp::UseDeviceAddrClauseOps &result,
1446 llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const {
1447 std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;
1448 bool clauseFound = findRepeatableClause<omp::clause::UseDeviceAddr>(
1449 [&](const omp::clause::UseDeviceAddr &clause,
1450 const parser::CharBlock &source) {
1451 mlir::Location location = converter.genLocation(source);
1452 llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
1453 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
1454 processMapObjects(stmtCtx, location, clause.v, mapTypeBits,
1455 parentMemberIndices, result.useDeviceAddrVars,
1456 useDeviceSyms);
1457 });
1458
1459 insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
1460 result.useDeviceAddrVars, useDeviceSyms);
1461 return clauseFound;
1462}
1463
1464bool ClauseProcessor::processUseDevicePtr(
1465 lower::StatementContext &stmtCtx, mlir::omp::UseDevicePtrClauseOps &result,
1466 llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const {
1467 std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;
1468
1469 bool clauseFound = findRepeatableClause<omp::clause::UseDevicePtr>(
1470 [&](const omp::clause::UseDevicePtr &clause,
1471 const parser::CharBlock &source) {
1472 mlir::Location location = converter.genLocation(source);
1473 llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
1474 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
1475 processMapObjects(stmtCtx, location, clause.v, mapTypeBits,
1476 parentMemberIndices, result.useDevicePtrVars,
1477 useDeviceSyms);
1478 });
1479
1480 insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
1481 result.useDevicePtrVars, useDeviceSyms);
1482 return clauseFound;
1483}
1484
1485} // namespace omp
1486} // namespace lower
1487} // namespace Fortran
1488

source code of flang/lib/Lower/OpenMP/ClauseProcessor.cpp