1//===-- Lower/OpenMP/ClauseProcessor.h --------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10//
11//===----------------------------------------------------------------------===//
12#ifndef FORTRAN_LOWER_CLAUASEPROCESSOR_H
13#define FORTRAN_LOWER_CLAUASEPROCESSOR_H
14
15#include "Clauses.h"
16#include "DirectivesCommon.h"
17#include "ReductionProcessor.h"
18#include "Utils.h"
19#include "flang/Lower/AbstractConverter.h"
20#include "flang/Lower/Bridge.h"
21#include "flang/Optimizer/Builder/Todo.h"
22#include "flang/Parser/dump-parse-tree.h"
23#include "flang/Parser/parse-tree.h"
24#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
25
26namespace fir {
27class FirOpBuilder;
28} // namespace fir
29
30namespace Fortran {
31namespace lower {
32namespace omp {
33
34/// Class that handles the processing of OpenMP clauses.
35///
36/// Its `process<ClauseName>()` methods perform MLIR code generation for their
37/// corresponding clause if it is present in the clause list. Otherwise, they
38/// will return `false` to signal that the clause was not found.
39///
40/// The intended use of this class is to move clause processing outside of
41/// construct processing, since the same clauses can appear attached to
42/// different constructs and constructs can be combined, so that code
43/// duplication is minimized.
44///
45/// Each construct-lowering function only calls the `process<ClauseName>()`
46/// methods that relate to clauses that can impact the lowering of that
47/// construct.
48class ClauseProcessor {
49public:
50 ClauseProcessor(Fortran::lower::AbstractConverter &converter,
51 Fortran::semantics::SemanticsContext &semaCtx,
52 const List<Clause> &clauses)
53 : converter(converter), semaCtx(semaCtx), clauses(clauses) {}
54
55 // 'Unique' clauses: They can appear at most once in the clause list.
56 bool processCollapse(
57 mlir::Location currentLocation, Fortran::lower::pft::Evaluation &eval,
58 mlir::omp::CollapseClauseOps &result,
59 llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv) const;
60 bool processDefault() const;
61 bool processDevice(Fortran::lower::StatementContext &stmtCtx,
62 mlir::omp::DeviceClauseOps &result) const;
63 bool processDeviceType(mlir::omp::DeviceTypeClauseOps &result) const;
64 bool processFinal(Fortran::lower::StatementContext &stmtCtx,
65 mlir::omp::FinalClauseOps &result) const;
66 bool
67 processHasDeviceAddr(mlir::omp::HasDeviceAddrClauseOps &result,
68 llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
69 llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
70 llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
71 &isDeviceSymbols) const;
72 bool processHint(mlir::omp::HintClauseOps &result) const;
73 bool processMergeable(mlir::omp::MergeableClauseOps &result) const;
74 bool processNowait(mlir::omp::NowaitClauseOps &result) const;
75 bool processNumTeams(Fortran::lower::StatementContext &stmtCtx,
76 mlir::omp::NumTeamsClauseOps &result) const;
77 bool processNumThreads(Fortran::lower::StatementContext &stmtCtx,
78 mlir::omp::NumThreadsClauseOps &result) const;
79 bool processOrdered(mlir::omp::OrderedClauseOps &result) const;
80 bool processPriority(Fortran::lower::StatementContext &stmtCtx,
81 mlir::omp::PriorityClauseOps &result) const;
82 bool processProcBind(mlir::omp::ProcBindClauseOps &result) const;
83 bool processSafelen(mlir::omp::SafelenClauseOps &result) const;
84 bool processSchedule(Fortran::lower::StatementContext &stmtCtx,
85 mlir::omp::ScheduleClauseOps &result) const;
86 bool processSimdlen(mlir::omp::SimdlenClauseOps &result) const;
87 bool processThreadLimit(Fortran::lower::StatementContext &stmtCtx,
88 mlir::omp::ThreadLimitClauseOps &result) const;
89 bool processUntied(mlir::omp::UntiedClauseOps &result) const;
90
91 // 'Repeatable' clauses: They can appear multiple times in the clause list.
92 bool processAllocate(mlir::omp::AllocateClauseOps &result) const;
93 bool processCopyin() const;
94 bool processCopyprivate(mlir::Location currentLocation,
95 mlir::omp::CopyprivateClauseOps &result) const;
96 bool processDepend(mlir::omp::DependClauseOps &result) const;
97 bool
98 processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
99 bool processIf(omp::clause::If::DirectiveNameModifier directiveName,
100 mlir::omp::IfClauseOps &result) const;
101 bool
102 processIsDevicePtr(mlir::omp::IsDevicePtrClauseOps &result,
103 llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
104 llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
105 llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
106 &isDeviceSymbols) const;
107 bool
108 processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
109
110 // This method is used to process a map clause.
111 // The optional parameters - mapSymTypes, mapSymLocs & mapSyms are used to
112 // store the original type, location and Fortran symbol for the map operands.
113 // They may be used later on to create the block_arguments for some of the
114 // target directives that require it.
115 bool processMap(
116 mlir::Location currentLocation, Fortran::lower::StatementContext &stmtCtx,
117 mlir::omp::MapClauseOps &result,
118 llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSyms =
119 nullptr,
120 llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr,
121 llvm::SmallVectorImpl<mlir::Type> *mapSymTypes = nullptr) const;
122 bool processReduction(
123 mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result,
124 llvm::SmallVectorImpl<mlir::Type> *reductionTypes = nullptr,
125 llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *reductionSyms =
126 nullptr) const;
127 bool processSectionsReduction(mlir::Location currentLocation,
128 mlir::omp::ReductionClauseOps &result) const;
129 bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
130 bool
131 processUseDeviceAddr(mlir::omp::UseDeviceClauseOps &result,
132 llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
133 llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
134 llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
135 &useDeviceSyms) const;
136 bool
137 processUseDevicePtr(mlir::omp::UseDeviceClauseOps &result,
138 llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
139 llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
140 llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
141 &useDeviceSyms) const;
142
143 template <typename T>
144 bool processMotionClauses(Fortran::lower::StatementContext &stmtCtx,
145 mlir::omp::MapClauseOps &result);
146
147 // Call this method for these clauses that should be supported but are not
148 // implemented yet. It triggers a compilation error if any of the given
149 // clauses is found.
150 template <typename... Ts>
151 void processTODO(mlir::Location currentLocation,
152 llvm::omp::Directive directive) const;
153
154private:
155 using ClauseIterator = List<Clause>::const_iterator;
156
157 /// Utility to find a clause within a range in the clause list.
158 template <typename T>
159 static ClauseIterator findClause(ClauseIterator begin, ClauseIterator end);
160
161 /// Return the first instance of the given clause found in the clause list or
162 /// `nullptr` if not present. If more than one instance is expected, use
163 /// `findRepeatableClause` instead.
164 template <typename T>
165 const T *
166 findUniqueClause(const Fortran::parser::CharBlock **source = nullptr) const;
167
168 /// Call `callbackFn` for each occurrence of the given clause. Return `true`
169 /// if at least one instance was found.
170 template <typename T>
171 bool findRepeatableClause(
172 std::function<void(const T &, const Fortran::parser::CharBlock &source)>
173 callbackFn) const;
174
175 /// Set the `result` to a new `mlir::UnitAttr` if the clause is present.
176 template <typename T>
177 bool markClauseOccurrence(mlir::UnitAttr &result) const;
178
179 Fortran::lower::AbstractConverter &converter;
180 Fortran::semantics::SemanticsContext &semaCtx;
181 List<Clause> clauses;
182};
183
184template <typename T>
185bool ClauseProcessor::processMotionClauses(
186 Fortran::lower::StatementContext &stmtCtx,
187 mlir::omp::MapClauseOps &result) {
188 return findRepeatableClause<T>(
189 [&](const T &clause, const Fortran::parser::CharBlock &source) {
190 mlir::Location clauseLocation = converter.genLocation(source);
191 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
192
193 static_assert(std::is_same_v<T, omp::clause::To> ||
194 std::is_same_v<T, omp::clause::From>);
195
196 // TODO Support motion modifiers: present, mapper, iterator.
197 constexpr llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
198 std::is_same_v<T, omp::clause::To>
199 ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO
200 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
201
202 auto &objects = std::get<ObjectList>(clause.t);
203 for (const omp::Object &object : objects) {
204 llvm::SmallVector<mlir::Value> bounds;
205 std::stringstream asFortran;
206 Fortran::lower::AddrAndBoundsInfo info =
207 Fortran::lower::gatherDataOperandAddrAndBounds<
208 mlir::omp::MapBoundsOp, mlir::omp::MapBoundsType>(
209 converter, firOpBuilder, semaCtx, stmtCtx, *object.id(),
210 object.ref(), clauseLocation, asFortran, bounds,
211 treatIndexAsSection);
212
213 auto origSymbol = converter.getSymbolAddress(*object.id());
214 mlir::Value symAddr = info.addr;
215 if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType()))
216 symAddr = origSymbol;
217
218 // Explicit map captures are captured ByRef by default,
219 // optimisation passes may alter this to ByCopy or other capture
220 // types to optimise
221 mlir::Value mapOp = createMapInfoOp(
222 firOpBuilder, clauseLocation, symAddr, mlir::Value{},
223 asFortran.str(), bounds, {},
224 static_cast<
225 std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
226 mapTypeBits),
227 mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
228
229 result.mapVars.push_back(mapOp);
230 }
231 });
232}
233
234template <typename... Ts>
235void ClauseProcessor::processTODO(mlir::Location currentLocation,
236 llvm::omp::Directive directive) const {
237 auto checkUnhandledClause = [&](llvm::omp::Clause id, const auto *x) {
238 if (!x)
239 return;
240 TODO(currentLocation,
241 "Unhandled clause " + llvm::omp::getOpenMPClauseName(id).upper() +
242 " in " + llvm::omp::getOpenMPDirectiveName(directive).upper() +
243 " construct");
244 };
245
246 for (ClauseIterator it = clauses.begin(); it != clauses.end(); ++it)
247 (checkUnhandledClause(it->id, std::get_if<Ts>(&it->u)), ...);
248}
249
250template <typename T>
251ClauseProcessor::ClauseIterator
252ClauseProcessor::findClause(ClauseIterator begin, ClauseIterator end) {
253 for (ClauseIterator it = begin; it != end; ++it) {
254 if (std::get_if<T>(&it->u))
255 return it;
256 }
257
258 return end;
259}
260
261template <typename T>
262const T *ClauseProcessor::findUniqueClause(
263 const Fortran::parser::CharBlock **source) const {
264 ClauseIterator it = findClause<T>(clauses.begin(), clauses.end());
265 if (it != clauses.end()) {
266 if (source)
267 *source = &it->source;
268 return &std::get<T>(it->u);
269 }
270 return nullptr;
271}
272
273template <typename T>
274bool ClauseProcessor::findRepeatableClause(
275 std::function<void(const T &, const Fortran::parser::CharBlock &source)>
276 callbackFn) const {
277 bool found = false;
278 ClauseIterator nextIt, endIt = clauses.end();
279 for (ClauseIterator it = clauses.begin(); it != endIt; it = nextIt) {
280 nextIt = findClause<T>(it, endIt);
281
282 if (nextIt != endIt) {
283 callbackFn(std::get<T>(nextIt->u), nextIt->source);
284 found = true;
285 ++nextIt;
286 }
287 }
288 return found;
289}
290
291template <typename T>
292bool ClauseProcessor::markClauseOccurrence(mlir::UnitAttr &result) const {
293 if (findUniqueClause<T>()) {
294 result = converter.getFirOpBuilder().getUnitAttr();
295 return true;
296 }
297 return false;
298}
299
300} // namespace omp
301} // namespace lower
302} // namespace Fortran
303
304#endif // FORTRAN_LOWER_CLAUASEPROCESSOR_H
305

// Source: flang/lib/Lower/OpenMP/ClauseProcessor.h