1 | //===-- Lower/OpenMP/ClauseProcessor.h --------------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | #ifndef FORTRAN_LOWER_CLAUSEPROCESSOR_H |
13 | #define FORTRAN_LOWER_CLAUSEPROCESSOR_H |
14 | |
#include "ClauseFinder.h"
#include "Clauses.h"
#include "ReductionProcessor.h"
#include "Utils.h"
#include "flang/Lower/AbstractConverter.h"
#include "flang/Lower/Bridge.h"
#include "flang/Lower/DirectivesCommon.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Parser/dump-parse-tree.h"
#include "flang/Parser/parse-tree.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"

#include <utility>
26 | |
// Forward declaration of the FIR operation builder class; its full
// definition is provided by another header.
namespace fir {
class FirOpBuilder;
} // namespace fir
30 | |
31 | namespace Fortran { |
32 | namespace lower { |
33 | namespace omp { |
34 | |
35 | // Container type for tracking user specified Defaultmaps for a target region |
36 | using DefaultMapsTy = std::map<clause::Defaultmap::VariableCategory, |
37 | clause::Defaultmap::ImplicitBehavior>; |
38 | |
39 | /// Class that handles the processing of OpenMP clauses. |
40 | /// |
41 | /// Its `process<ClauseName>()` methods perform MLIR code generation for their |
42 | /// corresponding clause if it is present in the clause list. Otherwise, they |
43 | /// will return `false` to signal that the clause was not found. |
44 | /// |
45 | /// The intended use of this class is to move clause processing outside of |
46 | /// construct processing, since the same clauses can appear attached to |
47 | /// different constructs and constructs can be combined, so that code |
48 | /// duplication is minimized. |
49 | /// |
50 | /// Each construct-lowering function only calls the `process<ClauseName>()` |
51 | /// methods that relate to clauses that can impact the lowering of that |
52 | /// construct. |
53 | class ClauseProcessor { |
54 | public: |
55 | ClauseProcessor(lower::AbstractConverter &converter, |
56 | semantics::SemanticsContext &semaCtx, |
57 | const List<Clause> &clauses) |
58 | : converter(converter), semaCtx(semaCtx), clauses(clauses) {} |
59 | |
60 | // 'Unique' clauses: They can appear at most once in the clause list. |
61 | bool processBare(mlir::omp::BareClauseOps &result) const; |
62 | bool processBind(mlir::omp::BindClauseOps &result) const; |
63 | bool processCancelDirectiveName( |
64 | mlir::omp::CancelDirectiveNameClauseOps &result) const; |
65 | bool |
66 | processCollapse(mlir::Location currentLocation, lower::pft::Evaluation &eval, |
67 | mlir::omp::LoopRelatedClauseOps &result, |
68 | llvm::SmallVectorImpl<const semantics::Symbol *> &iv) const; |
69 | bool processDevice(lower::StatementContext &stmtCtx, |
70 | mlir::omp::DeviceClauseOps &result) const; |
71 | bool processDeviceType(mlir::omp::DeviceTypeClauseOps &result) const; |
72 | bool processDistSchedule(lower::StatementContext &stmtCtx, |
73 | mlir::omp::DistScheduleClauseOps &result) const; |
74 | bool processExclusive(mlir::Location currentLocation, |
75 | mlir::omp::ExclusiveClauseOps &result) const; |
76 | bool processFilter(lower::StatementContext &stmtCtx, |
77 | mlir::omp::FilterClauseOps &result) const; |
78 | bool processFinal(lower::StatementContext &stmtCtx, |
79 | mlir::omp::FinalClauseOps &result) const; |
80 | bool processGrainsize(lower::StatementContext &stmtCtx, |
81 | mlir::omp::GrainsizeClauseOps &result) const; |
82 | bool processHasDeviceAddr( |
83 | lower::StatementContext &stmtCtx, |
84 | mlir::omp::HasDeviceAddrClauseOps &result, |
85 | llvm::SmallVectorImpl<const semantics::Symbol *> &hasDeviceSyms) const; |
86 | bool processHint(mlir::omp::HintClauseOps &result) const; |
87 | bool processInclusive(mlir::Location currentLocation, |
88 | mlir::omp::InclusiveClauseOps &result) const; |
89 | bool processMergeable(mlir::omp::MergeableClauseOps &result) const; |
90 | bool processNowait(mlir::omp::NowaitClauseOps &result) const; |
91 | bool processNumTasks(lower::StatementContext &stmtCtx, |
92 | mlir::omp::NumTasksClauseOps &result) const; |
93 | bool processNumTeams(lower::StatementContext &stmtCtx, |
94 | mlir::omp::NumTeamsClauseOps &result) const; |
95 | bool processNumThreads(lower::StatementContext &stmtCtx, |
96 | mlir::omp::NumThreadsClauseOps &result) const; |
97 | bool processOrder(mlir::omp::OrderClauseOps &result) const; |
98 | bool processOrdered(mlir::omp::OrderedClauseOps &result) const; |
99 | bool processPriority(lower::StatementContext &stmtCtx, |
100 | mlir::omp::PriorityClauseOps &result) const; |
101 | bool processProcBind(mlir::omp::ProcBindClauseOps &result) const; |
102 | bool processSafelen(mlir::omp::SafelenClauseOps &result) const; |
103 | bool processSchedule(lower::StatementContext &stmtCtx, |
104 | mlir::omp::ScheduleClauseOps &result) const; |
105 | bool processSimdlen(mlir::omp::SimdlenClauseOps &result) const; |
106 | bool processThreadLimit(lower::StatementContext &stmtCtx, |
107 | mlir::omp::ThreadLimitClauseOps &result) const; |
108 | bool processUntied(mlir::omp::UntiedClauseOps &result) const; |
109 | |
110 | bool processDetach(mlir::omp::DetachClauseOps &result) const; |
111 | // 'Repeatable' clauses: They can appear multiple times in the clause list. |
112 | bool processAligned(mlir::omp::AlignedClauseOps &result) const; |
113 | bool processAllocate(mlir::omp::AllocateClauseOps &result) const; |
114 | bool processCopyin() const; |
115 | bool processCopyprivate(mlir::Location currentLocation, |
116 | mlir::omp::CopyprivateClauseOps &result) const; |
117 | bool processDefaultMap(lower::StatementContext &stmtCtx, |
118 | DefaultMapsTy &result) const; |
119 | bool processDepend(lower::SymMap &symMap, lower::StatementContext &stmtCtx, |
120 | mlir::omp::DependClauseOps &result) const; |
121 | bool |
122 | processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const; |
123 | bool processIf(omp::clause::If::DirectiveNameModifier directiveName, |
124 | mlir::omp::IfClauseOps &result) const; |
125 | bool processInReduction( |
126 | mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result, |
127 | llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const; |
128 | bool processIsDevicePtr( |
129 | mlir::omp::IsDevicePtrClauseOps &result, |
130 | llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const; |
131 | bool processLinear(mlir::omp::LinearClauseOps &result) const; |
132 | bool |
133 | processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const; |
134 | |
135 | // This method is used to process a map clause. |
136 | // The optional parameter mapSyms is used to store the original Fortran symbol |
137 | // for the map operands. It may be used later on to create the block_arguments |
138 | // for some of the directives that require it. |
139 | bool processMap(mlir::Location currentLocation, |
140 | lower::StatementContext &stmtCtx, |
141 | mlir::omp::MapClauseOps &result, |
142 | llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms = |
143 | nullptr) const; |
144 | bool processMotionClauses(lower::StatementContext &stmtCtx, |
145 | mlir::omp::MapClauseOps &result); |
146 | bool processNontemporal(mlir::omp::NontemporalClauseOps &result) const; |
147 | bool processReduction( |
148 | mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result, |
149 | llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) const; |
150 | bool processTaskReduction( |
151 | mlir::Location currentLocation, mlir::omp::TaskReductionClauseOps &result, |
152 | llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const; |
153 | bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const; |
154 | bool processUseDeviceAddr( |
155 | lower::StatementContext &stmtCtx, |
156 | mlir::omp::UseDeviceAddrClauseOps &result, |
157 | llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const; |
158 | bool processUseDevicePtr( |
159 | lower::StatementContext &stmtCtx, |
160 | mlir::omp::UseDevicePtrClauseOps &result, |
161 | llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const; |
162 | |
163 | // Call this method for these clauses that should be supported but are not |
164 | // implemented yet. It triggers a compilation error if any of the given |
165 | // clauses is found. |
166 | template <typename... Ts> |
167 | void processTODO(mlir::Location currentLocation, |
168 | llvm::omp::Directive directive) const; |
169 | |
170 | private: |
171 | using ClauseIterator = List<Clause>::const_iterator; |
172 | |
173 | /// Return the first instance of the given clause found in the clause list or |
174 | /// `nullptr` if not present. If more than one instance is expected, use |
175 | /// `findRepeatableClause` instead. |
176 | template <typename T> |
177 | const T *findUniqueClause(const parser::CharBlock **source = nullptr) const; |
178 | |
179 | /// Call `callbackFn` for each occurrence of the given clause. Return `true` |
180 | /// if at least one instance was found. |
181 | template <typename T> |
182 | bool findRepeatableClause( |
183 | std::function<void(const T &, const parser::CharBlock &source)> |
184 | callbackFn) const; |
185 | |
186 | /// Set the `result` to a new `mlir::UnitAttr` if the clause is present. |
187 | template <typename T> |
188 | bool markClauseOccurrence(mlir::UnitAttr &result) const; |
189 | |
190 | void processMapObjects( |
191 | lower::StatementContext &stmtCtx, mlir::Location clauseLocation, |
192 | const omp::ObjectList &objects, |
193 | llvm::omp::OpenMPOffloadMappingFlags mapTypeBits, |
194 | std::map<Object, OmpMapParentAndMemberData> &parentMemberIndices, |
195 | llvm::SmallVectorImpl<mlir::Value> &mapVars, |
196 | llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms, |
197 | llvm::StringRef mapperIdNameRef = "" ) const; |
198 | |
199 | lower::AbstractConverter &converter; |
200 | semantics::SemanticsContext &semaCtx; |
201 | List<Clause> clauses; |
202 | }; |
203 | |
204 | template <typename... Ts> |
205 | void ClauseProcessor::processTODO(mlir::Location currentLocation, |
206 | llvm::omp::Directive directive) const { |
207 | auto checkUnhandledClause = [&](llvm::omp::Clause id, const auto *x) { |
208 | if (!x) |
209 | return; |
210 | unsigned version = semaCtx.langOptions().OpenMPVersion; |
211 | TODO(currentLocation, |
212 | "Unhandled clause " + llvm::omp::getOpenMPClauseName(id).upper() + |
213 | " in " + |
214 | llvm::omp::getOpenMPDirectiveName(directive, version).upper() + |
215 | " construct" ); |
216 | }; |
217 | |
218 | for (ClauseIterator it = clauses.begin(); it != clauses.end(); ++it) |
219 | (checkUnhandledClause(it->id, std::get_if<Ts>(&it->u)), ...); |
220 | } |
221 | |
222 | template <typename T> |
223 | const T * |
224 | ClauseProcessor::findUniqueClause(const parser::CharBlock **source) const { |
225 | return ClauseFinder::findUniqueClause<T>(clauses, source); |
226 | } |
227 | |
228 | template <typename T> |
229 | bool ClauseProcessor::findRepeatableClause( |
230 | std::function<void(const T &, const parser::CharBlock &source)> callbackFn) |
231 | const { |
232 | return ClauseFinder::findRepeatableClause<T>(clauses, callbackFn); |
233 | } |
234 | |
235 | template <typename T> |
236 | bool ClauseProcessor::markClauseOccurrence(mlir::UnitAttr &result) const { |
237 | if (findUniqueClause<T>()) { |
238 | result = converter.getFirOpBuilder().getUnitAttr(); |
239 | return true; |
240 | } |
241 | return false; |
242 | } |
243 | |
244 | } // namespace omp |
245 | } // namespace lower |
246 | } // namespace Fortran |
247 | |
248 | #endif // FORTRAN_LOWER_CLAUSEPROCESSOR_H |
249 | |