1//===-- Lower/OpenMP/ClauseProcessor.h --------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10//
11//===----------------------------------------------------------------------===//
12#ifndef FORTRAN_LOWER_CLAUSEPROCESSOR_H
13#define FORTRAN_LOWER_CLAUSEPROCESSOR_H
14
#include "ClauseFinder.h"
#include "Clauses.h"
#include "ReductionProcessor.h"
#include "Utils.h"
#include "flang/Lower/AbstractConverter.h"
#include "flang/Lower/Bridge.h"
#include "flang/Lower/DirectivesCommon.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Parser/dump-parse-tree.h"
#include "flang/Parser/parse-tree.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"

#include <functional>
#include <map>
#include <utility>
#include <variant>
26
// Forward declaration of the FIR operation builder class.
namespace fir { class FirOpBuilder; } // namespace fir
30
31namespace Fortran {
32namespace lower {
33namespace omp {
34
35// Container type for tracking user specified Defaultmaps for a target region
36using DefaultMapsTy = std::map<clause::Defaultmap::VariableCategory,
37 clause::Defaultmap::ImplicitBehavior>;
38
39/// Class that handles the processing of OpenMP clauses.
40///
41/// Its `process<ClauseName>()` methods perform MLIR code generation for their
42/// corresponding clause if it is present in the clause list. Otherwise, they
43/// will return `false` to signal that the clause was not found.
44///
45/// The intended use of this class is to move clause processing outside of
46/// construct processing, since the same clauses can appear attached to
47/// different constructs and constructs can be combined, so that code
48/// duplication is minimized.
49///
50/// Each construct-lowering function only calls the `process<ClauseName>()`
51/// methods that relate to clauses that can impact the lowering of that
52/// construct.
53class ClauseProcessor {
54public:
55 ClauseProcessor(lower::AbstractConverter &converter,
56 semantics::SemanticsContext &semaCtx,
57 const List<Clause> &clauses)
58 : converter(converter), semaCtx(semaCtx), clauses(clauses) {}
59
60 // 'Unique' clauses: They can appear at most once in the clause list.
61 bool processBare(mlir::omp::BareClauseOps &result) const;
62 bool processBind(mlir::omp::BindClauseOps &result) const;
63 bool processCancelDirectiveName(
64 mlir::omp::CancelDirectiveNameClauseOps &result) const;
65 bool
66 processCollapse(mlir::Location currentLocation, lower::pft::Evaluation &eval,
67 mlir::omp::LoopRelatedClauseOps &result,
68 llvm::SmallVectorImpl<const semantics::Symbol *> &iv) const;
69 bool processDevice(lower::StatementContext &stmtCtx,
70 mlir::omp::DeviceClauseOps &result) const;
71 bool processDeviceType(mlir::omp::DeviceTypeClauseOps &result) const;
72 bool processDistSchedule(lower::StatementContext &stmtCtx,
73 mlir::omp::DistScheduleClauseOps &result) const;
74 bool processExclusive(mlir::Location currentLocation,
75 mlir::omp::ExclusiveClauseOps &result) const;
76 bool processFilter(lower::StatementContext &stmtCtx,
77 mlir::omp::FilterClauseOps &result) const;
78 bool processFinal(lower::StatementContext &stmtCtx,
79 mlir::omp::FinalClauseOps &result) const;
80 bool processGrainsize(lower::StatementContext &stmtCtx,
81 mlir::omp::GrainsizeClauseOps &result) const;
82 bool processHasDeviceAddr(
83 lower::StatementContext &stmtCtx,
84 mlir::omp::HasDeviceAddrClauseOps &result,
85 llvm::SmallVectorImpl<const semantics::Symbol *> &hasDeviceSyms) const;
86 bool processHint(mlir::omp::HintClauseOps &result) const;
87 bool processInclusive(mlir::Location currentLocation,
88 mlir::omp::InclusiveClauseOps &result) const;
89 bool processMergeable(mlir::omp::MergeableClauseOps &result) const;
90 bool processNowait(mlir::omp::NowaitClauseOps &result) const;
91 bool processNumTasks(lower::StatementContext &stmtCtx,
92 mlir::omp::NumTasksClauseOps &result) const;
93 bool processNumTeams(lower::StatementContext &stmtCtx,
94 mlir::omp::NumTeamsClauseOps &result) const;
95 bool processNumThreads(lower::StatementContext &stmtCtx,
96 mlir::omp::NumThreadsClauseOps &result) const;
97 bool processOrder(mlir::omp::OrderClauseOps &result) const;
98 bool processOrdered(mlir::omp::OrderedClauseOps &result) const;
99 bool processPriority(lower::StatementContext &stmtCtx,
100 mlir::omp::PriorityClauseOps &result) const;
101 bool processProcBind(mlir::omp::ProcBindClauseOps &result) const;
102 bool processSafelen(mlir::omp::SafelenClauseOps &result) const;
103 bool processSchedule(lower::StatementContext &stmtCtx,
104 mlir::omp::ScheduleClauseOps &result) const;
105 bool processSimdlen(mlir::omp::SimdlenClauseOps &result) const;
106 bool processThreadLimit(lower::StatementContext &stmtCtx,
107 mlir::omp::ThreadLimitClauseOps &result) const;
108 bool processUntied(mlir::omp::UntiedClauseOps &result) const;
109
110 bool processDetach(mlir::omp::DetachClauseOps &result) const;
111 // 'Repeatable' clauses: They can appear multiple times in the clause list.
112 bool processAligned(mlir::omp::AlignedClauseOps &result) const;
113 bool processAllocate(mlir::omp::AllocateClauseOps &result) const;
114 bool processCopyin() const;
115 bool processCopyprivate(mlir::Location currentLocation,
116 mlir::omp::CopyprivateClauseOps &result) const;
117 bool processDefaultMap(lower::StatementContext &stmtCtx,
118 DefaultMapsTy &result) const;
119 bool processDepend(lower::SymMap &symMap, lower::StatementContext &stmtCtx,
120 mlir::omp::DependClauseOps &result) const;
121 bool
122 processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
123 bool processIf(omp::clause::If::DirectiveNameModifier directiveName,
124 mlir::omp::IfClauseOps &result) const;
125 bool processInReduction(
126 mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result,
127 llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const;
128 bool processIsDevicePtr(
129 mlir::omp::IsDevicePtrClauseOps &result,
130 llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
131 bool processLinear(mlir::omp::LinearClauseOps &result) const;
132 bool
133 processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
134
135 // This method is used to process a map clause.
136 // The optional parameter mapSyms is used to store the original Fortran symbol
137 // for the map operands. It may be used later on to create the block_arguments
138 // for some of the directives that require it.
139 bool processMap(mlir::Location currentLocation,
140 lower::StatementContext &stmtCtx,
141 mlir::omp::MapClauseOps &result,
142 llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms =
143 nullptr) const;
144 bool processMotionClauses(lower::StatementContext &stmtCtx,
145 mlir::omp::MapClauseOps &result);
146 bool processNontemporal(mlir::omp::NontemporalClauseOps &result) const;
147 bool processReduction(
148 mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result,
149 llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) const;
150 bool processTaskReduction(
151 mlir::Location currentLocation, mlir::omp::TaskReductionClauseOps &result,
152 llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const;
153 bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
154 bool processUseDeviceAddr(
155 lower::StatementContext &stmtCtx,
156 mlir::omp::UseDeviceAddrClauseOps &result,
157 llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const;
158 bool processUseDevicePtr(
159 lower::StatementContext &stmtCtx,
160 mlir::omp::UseDevicePtrClauseOps &result,
161 llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const;
162
163 // Call this method for these clauses that should be supported but are not
164 // implemented yet. It triggers a compilation error if any of the given
165 // clauses is found.
166 template <typename... Ts>
167 void processTODO(mlir::Location currentLocation,
168 llvm::omp::Directive directive) const;
169
170private:
171 using ClauseIterator = List<Clause>::const_iterator;
172
173 /// Return the first instance of the given clause found in the clause list or
174 /// `nullptr` if not present. If more than one instance is expected, use
175 /// `findRepeatableClause` instead.
176 template <typename T>
177 const T *findUniqueClause(const parser::CharBlock **source = nullptr) const;
178
179 /// Call `callbackFn` for each occurrence of the given clause. Return `true`
180 /// if at least one instance was found.
181 template <typename T>
182 bool findRepeatableClause(
183 std::function<void(const T &, const parser::CharBlock &source)>
184 callbackFn) const;
185
186 /// Set the `result` to a new `mlir::UnitAttr` if the clause is present.
187 template <typename T>
188 bool markClauseOccurrence(mlir::UnitAttr &result) const;
189
190 void processMapObjects(
191 lower::StatementContext &stmtCtx, mlir::Location clauseLocation,
192 const omp::ObjectList &objects,
193 llvm::omp::OpenMPOffloadMappingFlags mapTypeBits,
194 std::map<Object, OmpMapParentAndMemberData> &parentMemberIndices,
195 llvm::SmallVectorImpl<mlir::Value> &mapVars,
196 llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms,
197 llvm::StringRef mapperIdNameRef = "") const;
198
199 lower::AbstractConverter &converter;
200 semantics::SemanticsContext &semaCtx;
201 List<Clause> clauses;
202};
203
204template <typename... Ts>
205void ClauseProcessor::processTODO(mlir::Location currentLocation,
206 llvm::omp::Directive directive) const {
207 auto checkUnhandledClause = [&](llvm::omp::Clause id, const auto *x) {
208 if (!x)
209 return;
210 unsigned version = semaCtx.langOptions().OpenMPVersion;
211 TODO(currentLocation,
212 "Unhandled clause " + llvm::omp::getOpenMPClauseName(id).upper() +
213 " in " +
214 llvm::omp::getOpenMPDirectiveName(directive, version).upper() +
215 " construct");
216 };
217
218 for (ClauseIterator it = clauses.begin(); it != clauses.end(); ++it)
219 (checkUnhandledClause(it->id, std::get_if<Ts>(&it->u)), ...);
220}
221
222template <typename T>
223const T *
224ClauseProcessor::findUniqueClause(const parser::CharBlock **source) const {
225 return ClauseFinder::findUniqueClause<T>(clauses, source);
226}
227
228template <typename T>
229bool ClauseProcessor::findRepeatableClause(
230 std::function<void(const T &, const parser::CharBlock &source)> callbackFn)
231 const {
232 return ClauseFinder::findRepeatableClause<T>(clauses, callbackFn);
233}
234
235template <typename T>
236bool ClauseProcessor::markClauseOccurrence(mlir::UnitAttr &result) const {
237 if (findUniqueClause<T>()) {
238 result = converter.getFirOpBuilder().getUnitAttr();
239 return true;
240 }
241 return false;
242}
243
244} // namespace omp
245} // namespace lower
246} // namespace Fortran
247
248#endif // FORTRAN_LOWER_CLAUSEPROCESSOR_H
249

source code of flang/lib/Lower/OpenMP/ClauseProcessor.h