1 | //===-- Lower/OpenMP/ClauseProcessor.h --------------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | #ifndef FORTRAN_LOWER_CLAUASEPROCESSOR_H |
13 | #define FORTRAN_LOWER_CLAUASEPROCESSOR_H |
14 | |
15 | #include "Clauses.h" |
16 | #include "DirectivesCommon.h" |
17 | #include "ReductionProcessor.h" |
18 | #include "Utils.h" |
19 | #include "flang/Lower/AbstractConverter.h" |
20 | #include "flang/Lower/Bridge.h" |
21 | #include "flang/Optimizer/Builder/Todo.h" |
22 | #include "flang/Parser/dump-parse-tree.h" |
23 | #include "flang/Parser/parse-tree.h" |
24 | #include "mlir/Dialect/OpenMP/OpenMPDialect.h" |
25 | |
26 | namespace fir { |
27 | class FirOpBuilder; |
28 | } // namespace fir |
29 | |
30 | namespace Fortran { |
31 | namespace lower { |
32 | namespace omp { |
33 | |
34 | /// Class that handles the processing of OpenMP clauses. |
35 | /// |
36 | /// Its `process<ClauseName>()` methods perform MLIR code generation for their |
37 | /// corresponding clause if it is present in the clause list. Otherwise, they |
38 | /// will return `false` to signal that the clause was not found. |
39 | /// |
40 | /// The intended use of this class is to move clause processing outside of |
41 | /// construct processing, since the same clauses can appear attached to |
42 | /// different constructs and constructs can be combined, so that code |
43 | /// duplication is minimized. |
44 | /// |
45 | /// Each construct-lowering function only calls the `process<ClauseName>()` |
46 | /// methods that relate to clauses that can impact the lowering of that |
47 | /// construct. |
48 | class ClauseProcessor { |
49 | public: |
50 | ClauseProcessor(Fortran::lower::AbstractConverter &converter, |
51 | Fortran::semantics::SemanticsContext &semaCtx, |
52 | const List<Clause> &clauses) |
53 | : converter(converter), semaCtx(semaCtx), clauses(clauses) {} |
54 | |
55 | // 'Unique' clauses: They can appear at most once in the clause list. |
56 | bool processCollapse( |
57 | mlir::Location currentLocation, Fortran::lower::pft::Evaluation &eval, |
58 | mlir::omp::CollapseClauseOps &result, |
59 | llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv) const; |
60 | bool processDefault() const; |
61 | bool processDevice(Fortran::lower::StatementContext &stmtCtx, |
62 | mlir::omp::DeviceClauseOps &result) const; |
63 | bool processDeviceType(mlir::omp::DeviceTypeClauseOps &result) const; |
64 | bool processFinal(Fortran::lower::StatementContext &stmtCtx, |
65 | mlir::omp::FinalClauseOps &result) const; |
66 | bool |
67 | processHasDeviceAddr(mlir::omp::HasDeviceAddrClauseOps &result, |
68 | llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes, |
69 | llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs, |
70 | llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> |
71 | &isDeviceSymbols) const; |
72 | bool processHint(mlir::omp::HintClauseOps &result) const; |
73 | bool processMergeable(mlir::omp::MergeableClauseOps &result) const; |
74 | bool processNowait(mlir::omp::NowaitClauseOps &result) const; |
75 | bool processNumTeams(Fortran::lower::StatementContext &stmtCtx, |
76 | mlir::omp::NumTeamsClauseOps &result) const; |
77 | bool processNumThreads(Fortran::lower::StatementContext &stmtCtx, |
78 | mlir::omp::NumThreadsClauseOps &result) const; |
79 | bool processOrdered(mlir::omp::OrderedClauseOps &result) const; |
80 | bool processPriority(Fortran::lower::StatementContext &stmtCtx, |
81 | mlir::omp::PriorityClauseOps &result) const; |
82 | bool processProcBind(mlir::omp::ProcBindClauseOps &result) const; |
83 | bool processSafelen(mlir::omp::SafelenClauseOps &result) const; |
84 | bool processSchedule(Fortran::lower::StatementContext &stmtCtx, |
85 | mlir::omp::ScheduleClauseOps &result) const; |
86 | bool processSimdlen(mlir::omp::SimdlenClauseOps &result) const; |
87 | bool processThreadLimit(Fortran::lower::StatementContext &stmtCtx, |
88 | mlir::omp::ThreadLimitClauseOps &result) const; |
89 | bool processUntied(mlir::omp::UntiedClauseOps &result) const; |
90 | |
91 | // 'Repeatable' clauses: They can appear multiple times in the clause list. |
92 | bool processAllocate(mlir::omp::AllocateClauseOps &result) const; |
93 | bool processCopyin() const; |
94 | bool processCopyprivate(mlir::Location currentLocation, |
95 | mlir::omp::CopyprivateClauseOps &result) const; |
96 | bool processDepend(mlir::omp::DependClauseOps &result) const; |
97 | bool |
98 | processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const; |
99 | bool processIf(omp::clause::If::DirectiveNameModifier directiveName, |
100 | mlir::omp::IfClauseOps &result) const; |
101 | bool |
102 | processIsDevicePtr(mlir::omp::IsDevicePtrClauseOps &result, |
103 | llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes, |
104 | llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs, |
105 | llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> |
106 | &isDeviceSymbols) const; |
107 | bool |
108 | processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const; |
109 | |
110 | // This method is used to process a map clause. |
111 | // The optional parameters - mapSymTypes, mapSymLocs & mapSyms are used to |
112 | // store the original type, location and Fortran symbol for the map operands. |
113 | // They may be used later on to create the block_arguments for some of the |
114 | // target directives that require it. |
115 | bool processMap( |
116 | mlir::Location currentLocation, Fortran::lower::StatementContext &stmtCtx, |
117 | mlir::omp::MapClauseOps &result, |
118 | llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSyms = |
119 | nullptr, |
120 | llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr, |
121 | llvm::SmallVectorImpl<mlir::Type> *mapSymTypes = nullptr) const; |
122 | bool processReduction( |
123 | mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result, |
124 | llvm::SmallVectorImpl<mlir::Type> *reductionTypes = nullptr, |
125 | llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *reductionSyms = |
126 | nullptr) const; |
127 | bool processSectionsReduction(mlir::Location currentLocation, |
128 | mlir::omp::ReductionClauseOps &result) const; |
129 | bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const; |
130 | bool |
131 | processUseDeviceAddr(mlir::omp::UseDeviceClauseOps &result, |
132 | llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes, |
133 | llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs, |
134 | llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> |
135 | &useDeviceSyms) const; |
136 | bool |
137 | processUseDevicePtr(mlir::omp::UseDeviceClauseOps &result, |
138 | llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes, |
139 | llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs, |
140 | llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> |
141 | &useDeviceSyms) const; |
142 | |
143 | template <typename T> |
144 | bool processMotionClauses(Fortran::lower::StatementContext &stmtCtx, |
145 | mlir::omp::MapClauseOps &result); |
146 | |
147 | // Call this method for these clauses that should be supported but are not |
148 | // implemented yet. It triggers a compilation error if any of the given |
149 | // clauses is found. |
150 | template <typename... Ts> |
151 | void processTODO(mlir::Location currentLocation, |
152 | llvm::omp::Directive directive) const; |
153 | |
154 | private: |
155 | using ClauseIterator = List<Clause>::const_iterator; |
156 | |
157 | /// Utility to find a clause within a range in the clause list. |
158 | template <typename T> |
159 | static ClauseIterator findClause(ClauseIterator begin, ClauseIterator end); |
160 | |
161 | /// Return the first instance of the given clause found in the clause list or |
162 | /// `nullptr` if not present. If more than one instance is expected, use |
163 | /// `findRepeatableClause` instead. |
164 | template <typename T> |
165 | const T * |
166 | findUniqueClause(const Fortran::parser::CharBlock **source = nullptr) const; |
167 | |
168 | /// Call `callbackFn` for each occurrence of the given clause. Return `true` |
169 | /// if at least one instance was found. |
170 | template <typename T> |
171 | bool findRepeatableClause( |
172 | std::function<void(const T &, const Fortran::parser::CharBlock &source)> |
173 | callbackFn) const; |
174 | |
175 | /// Set the `result` to a new `mlir::UnitAttr` if the clause is present. |
176 | template <typename T> |
177 | bool markClauseOccurrence(mlir::UnitAttr &result) const; |
178 | |
179 | Fortran::lower::AbstractConverter &converter; |
180 | Fortran::semantics::SemanticsContext &semaCtx; |
181 | List<Clause> clauses; |
182 | }; |
183 | |
184 | template <typename T> |
185 | bool ClauseProcessor::processMotionClauses( |
186 | Fortran::lower::StatementContext &stmtCtx, |
187 | mlir::omp::MapClauseOps &result) { |
188 | return findRepeatableClause<T>( |
189 | [&](const T &clause, const Fortran::parser::CharBlock &source) { |
190 | mlir::Location clauseLocation = converter.genLocation(source); |
191 | fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); |
192 | |
193 | static_assert(std::is_same_v<T, omp::clause::To> || |
194 | std::is_same_v<T, omp::clause::From>); |
195 | |
196 | // TODO Support motion modifiers: present, mapper, iterator. |
197 | constexpr llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = |
198 | std::is_same_v<T, omp::clause::To> |
199 | ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
200 | : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; |
201 | |
202 | auto &objects = std::get<ObjectList>(clause.t); |
203 | for (const omp::Object &object : objects) { |
204 | llvm::SmallVector<mlir::Value> bounds; |
205 | std::stringstream asFortran; |
206 | Fortran::lower::AddrAndBoundsInfo info = |
207 | Fortran::lower::gatherDataOperandAddrAndBounds< |
208 | mlir::omp::MapBoundsOp, mlir::omp::MapBoundsType>( |
209 | converter, firOpBuilder, semaCtx, stmtCtx, *object.id(), |
210 | object.ref(), clauseLocation, asFortran, bounds, |
211 | treatIndexAsSection); |
212 | |
213 | auto origSymbol = converter.getSymbolAddress(*object.id()); |
214 | mlir::Value symAddr = info.addr; |
215 | if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType())) |
216 | symAddr = origSymbol; |
217 | |
218 | // Explicit map captures are captured ByRef by default, |
219 | // optimisation passes may alter this to ByCopy or other capture |
220 | // types to optimise |
221 | mlir::Value mapOp = createMapInfoOp( |
222 | firOpBuilder, clauseLocation, symAddr, mlir::Value{}, |
223 | asFortran.str(), bounds, {}, |
224 | static_cast< |
225 | std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>( |
226 | mapTypeBits), |
227 | mlir::omp::VariableCaptureKind::ByRef, symAddr.getType()); |
228 | |
229 | result.mapVars.push_back(mapOp); |
230 | } |
231 | }); |
232 | } |
233 | |
234 | template <typename... Ts> |
235 | void ClauseProcessor::processTODO(mlir::Location currentLocation, |
236 | llvm::omp::Directive directive) const { |
237 | auto checkUnhandledClause = [&](llvm::omp::Clause id, const auto *x) { |
238 | if (!x) |
239 | return; |
240 | TODO(currentLocation, |
241 | "Unhandled clause " + llvm::omp::getOpenMPClauseName(id).upper() + |
242 | " in " + llvm::omp::getOpenMPDirectiveName(directive).upper() + |
243 | " construct" ); |
244 | }; |
245 | |
246 | for (ClauseIterator it = clauses.begin(); it != clauses.end(); ++it) |
247 | (checkUnhandledClause(it->id, std::get_if<Ts>(&it->u)), ...); |
248 | } |
249 | |
250 | template <typename T> |
251 | ClauseProcessor::ClauseIterator |
252 | ClauseProcessor::findClause(ClauseIterator begin, ClauseIterator end) { |
253 | for (ClauseIterator it = begin; it != end; ++it) { |
254 | if (std::get_if<T>(&it->u)) |
255 | return it; |
256 | } |
257 | |
258 | return end; |
259 | } |
260 | |
261 | template <typename T> |
262 | const T *ClauseProcessor::findUniqueClause( |
263 | const Fortran::parser::CharBlock **source) const { |
264 | ClauseIterator it = findClause<T>(clauses.begin(), clauses.end()); |
265 | if (it != clauses.end()) { |
266 | if (source) |
267 | *source = &it->source; |
268 | return &std::get<T>(it->u); |
269 | } |
270 | return nullptr; |
271 | } |
272 | |
273 | template <typename T> |
274 | bool ClauseProcessor::findRepeatableClause( |
275 | std::function<void(const T &, const Fortran::parser::CharBlock &source)> |
276 | callbackFn) const { |
277 | bool found = false; |
278 | ClauseIterator nextIt, endIt = clauses.end(); |
279 | for (ClauseIterator it = clauses.begin(); it != endIt; it = nextIt) { |
280 | nextIt = findClause<T>(it, endIt); |
281 | |
282 | if (nextIt != endIt) { |
283 | callbackFn(std::get<T>(nextIt->u), nextIt->source); |
284 | found = true; |
285 | ++nextIt; |
286 | } |
287 | } |
288 | return found; |
289 | } |
290 | |
291 | template <typename T> |
292 | bool ClauseProcessor::markClauseOccurrence(mlir::UnitAttr &result) const { |
293 | if (findUniqueClause<T>()) { |
294 | result = converter.getFirOpBuilder().getUnitAttr(); |
295 | return true; |
296 | } |
297 | return false; |
298 | } |
299 | |
300 | } // namespace omp |
301 | } // namespace lower |
302 | } // namespace Fortran |
303 | |
304 | #endif // FORTRAN_LOWER_CLAUASEPROCESSOR_H |
305 | |