| 1 | //===-- Lower/OpenMP/ClauseProcessor.h --------------------------*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | #ifndef FORTRAN_LOWER_CLAUSEPROCESSOR_H |
| 13 | #define FORTRAN_LOWER_CLAUSEPROCESSOR_H |
| 14 | |
| 15 | #include "ClauseFinder.h" |
| 16 | #include "Clauses.h" |
| 17 | #include "ReductionProcessor.h" |
| 18 | #include "Utils.h" |
| 19 | #include "flang/Lower/AbstractConverter.h" |
| 20 | #include "flang/Lower/Bridge.h" |
| 21 | #include "flang/Lower/DirectivesCommon.h" |
| 22 | #include "flang/Optimizer/Builder/Todo.h" |
| 23 | #include "flang/Parser/dump-parse-tree.h" |
| 24 | #include "flang/Parser/parse-tree.h" |
| 25 | #include "mlir/Dialect/OpenMP/OpenMPDialect.h" |
| 26 | |
namespace fir {
// Forward-declared FIR operation builder.
class FirOpBuilder;
} // namespace fir
| 30 | |
| 31 | namespace Fortran { |
| 32 | namespace lower { |
| 33 | namespace omp { |
| 34 | |
| 35 | // Container type for tracking user specified Defaultmaps for a target region |
| 36 | using DefaultMapsTy = std::map<clause::Defaultmap::VariableCategory, |
| 37 | clause::Defaultmap::ImplicitBehavior>; |
| 38 | |
| 39 | /// Class that handles the processing of OpenMP clauses. |
| 40 | /// |
| 41 | /// Its `process<ClauseName>()` methods perform MLIR code generation for their |
| 42 | /// corresponding clause if it is present in the clause list. Otherwise, they |
| 43 | /// will return `false` to signal that the clause was not found. |
| 44 | /// |
| 45 | /// The intended use of this class is to move clause processing outside of |
| 46 | /// construct processing, since the same clauses can appear attached to |
| 47 | /// different constructs and constructs can be combined, so that code |
| 48 | /// duplication is minimized. |
| 49 | /// |
| 50 | /// Each construct-lowering function only calls the `process<ClauseName>()` |
| 51 | /// methods that relate to clauses that can impact the lowering of that |
| 52 | /// construct. |
| 53 | class ClauseProcessor { |
| 54 | public: |
| 55 | ClauseProcessor(lower::AbstractConverter &converter, |
| 56 | semantics::SemanticsContext &semaCtx, |
| 57 | const List<Clause> &clauses) |
| 58 | : converter(converter), semaCtx(semaCtx), clauses(clauses) {} |
| 59 | |
| 60 | // 'Unique' clauses: They can appear at most once in the clause list. |
| 61 | bool processBare(mlir::omp::BareClauseOps &result) const; |
| 62 | bool processBind(mlir::omp::BindClauseOps &result) const; |
| 63 | bool processCancelDirectiveName( |
| 64 | mlir::omp::CancelDirectiveNameClauseOps &result) const; |
| 65 | bool |
| 66 | processCollapse(mlir::Location currentLocation, lower::pft::Evaluation &eval, |
| 67 | mlir::omp::LoopRelatedClauseOps &result, |
| 68 | llvm::SmallVectorImpl<const semantics::Symbol *> &iv) const; |
| 69 | bool processDevice(lower::StatementContext &stmtCtx, |
| 70 | mlir::omp::DeviceClauseOps &result) const; |
| 71 | bool processDeviceType(mlir::omp::DeviceTypeClauseOps &result) const; |
| 72 | bool processDistSchedule(lower::StatementContext &stmtCtx, |
| 73 | mlir::omp::DistScheduleClauseOps &result) const; |
| 74 | bool processExclusive(mlir::Location currentLocation, |
| 75 | mlir::omp::ExclusiveClauseOps &result) const; |
| 76 | bool processFilter(lower::StatementContext &stmtCtx, |
| 77 | mlir::omp::FilterClauseOps &result) const; |
| 78 | bool processFinal(lower::StatementContext &stmtCtx, |
| 79 | mlir::omp::FinalClauseOps &result) const; |
| 80 | bool processGrainsize(lower::StatementContext &stmtCtx, |
| 81 | mlir::omp::GrainsizeClauseOps &result) const; |
| 82 | bool processHasDeviceAddr( |
| 83 | lower::StatementContext &stmtCtx, |
| 84 | mlir::omp::HasDeviceAddrClauseOps &result, |
| 85 | llvm::SmallVectorImpl<const semantics::Symbol *> &hasDeviceSyms) const; |
| 86 | bool processHint(mlir::omp::HintClauseOps &result) const; |
| 87 | bool processInclusive(mlir::Location currentLocation, |
| 88 | mlir::omp::InclusiveClauseOps &result) const; |
| 89 | bool processMergeable(mlir::omp::MergeableClauseOps &result) const; |
| 90 | bool processNowait(mlir::omp::NowaitClauseOps &result) const; |
| 91 | bool processNumTasks(lower::StatementContext &stmtCtx, |
| 92 | mlir::omp::NumTasksClauseOps &result) const; |
| 93 | bool processNumTeams(lower::StatementContext &stmtCtx, |
| 94 | mlir::omp::NumTeamsClauseOps &result) const; |
| 95 | bool processNumThreads(lower::StatementContext &stmtCtx, |
| 96 | mlir::omp::NumThreadsClauseOps &result) const; |
| 97 | bool processOrder(mlir::omp::OrderClauseOps &result) const; |
| 98 | bool processOrdered(mlir::omp::OrderedClauseOps &result) const; |
| 99 | bool processPriority(lower::StatementContext &stmtCtx, |
| 100 | mlir::omp::PriorityClauseOps &result) const; |
| 101 | bool processProcBind(mlir::omp::ProcBindClauseOps &result) const; |
| 102 | bool processSafelen(mlir::omp::SafelenClauseOps &result) const; |
| 103 | bool processSchedule(lower::StatementContext &stmtCtx, |
| 104 | mlir::omp::ScheduleClauseOps &result) const; |
| 105 | bool processSimdlen(mlir::omp::SimdlenClauseOps &result) const; |
| 106 | bool processThreadLimit(lower::StatementContext &stmtCtx, |
| 107 | mlir::omp::ThreadLimitClauseOps &result) const; |
| 108 | bool processUntied(mlir::omp::UntiedClauseOps &result) const; |
| 109 | |
| 110 | bool processDetach(mlir::omp::DetachClauseOps &result) const; |
| 111 | // 'Repeatable' clauses: They can appear multiple times in the clause list. |
| 112 | bool processAligned(mlir::omp::AlignedClauseOps &result) const; |
| 113 | bool processAllocate(mlir::omp::AllocateClauseOps &result) const; |
| 114 | bool processCopyin() const; |
| 115 | bool processCopyprivate(mlir::Location currentLocation, |
| 116 | mlir::omp::CopyprivateClauseOps &result) const; |
| 117 | bool processDefaultMap(lower::StatementContext &stmtCtx, |
| 118 | DefaultMapsTy &result) const; |
| 119 | bool processDepend(lower::SymMap &symMap, lower::StatementContext &stmtCtx, |
| 120 | mlir::omp::DependClauseOps &result) const; |
| 121 | bool |
| 122 | processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const; |
| 123 | bool processIf(omp::clause::If::DirectiveNameModifier directiveName, |
| 124 | mlir::omp::IfClauseOps &result) const; |
| 125 | bool processInReduction( |
| 126 | mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result, |
| 127 | llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const; |
| 128 | bool processIsDevicePtr( |
| 129 | mlir::omp::IsDevicePtrClauseOps &result, |
| 130 | llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const; |
| 131 | bool processLinear(mlir::omp::LinearClauseOps &result) const; |
| 132 | bool |
| 133 | processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const; |
| 134 | |
| 135 | // This method is used to process a map clause. |
| 136 | // The optional parameter mapSyms is used to store the original Fortran symbol |
| 137 | // for the map operands. It may be used later on to create the block_arguments |
| 138 | // for some of the directives that require it. |
| 139 | bool processMap(mlir::Location currentLocation, |
| 140 | lower::StatementContext &stmtCtx, |
| 141 | mlir::omp::MapClauseOps &result, |
| 142 | llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms = |
| 143 | nullptr) const; |
| 144 | bool processMotionClauses(lower::StatementContext &stmtCtx, |
| 145 | mlir::omp::MapClauseOps &result); |
| 146 | bool processNontemporal(mlir::omp::NontemporalClauseOps &result) const; |
| 147 | bool processReduction( |
| 148 | mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result, |
| 149 | llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) const; |
| 150 | bool processTaskReduction( |
| 151 | mlir::Location currentLocation, mlir::omp::TaskReductionClauseOps &result, |
| 152 | llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const; |
| 153 | bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const; |
| 154 | bool processUseDeviceAddr( |
| 155 | lower::StatementContext &stmtCtx, |
| 156 | mlir::omp::UseDeviceAddrClauseOps &result, |
| 157 | llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const; |
| 158 | bool processUseDevicePtr( |
| 159 | lower::StatementContext &stmtCtx, |
| 160 | mlir::omp::UseDevicePtrClauseOps &result, |
| 161 | llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const; |
| 162 | |
| 163 | // Call this method for these clauses that should be supported but are not |
| 164 | // implemented yet. It triggers a compilation error if any of the given |
| 165 | // clauses is found. |
| 166 | template <typename... Ts> |
| 167 | void processTODO(mlir::Location currentLocation, |
| 168 | llvm::omp::Directive directive) const; |
| 169 | |
| 170 | private: |
| 171 | using ClauseIterator = List<Clause>::const_iterator; |
| 172 | |
| 173 | /// Return the first instance of the given clause found in the clause list or |
| 174 | /// `nullptr` if not present. If more than one instance is expected, use |
| 175 | /// `findRepeatableClause` instead. |
| 176 | template <typename T> |
| 177 | const T *findUniqueClause(const parser::CharBlock **source = nullptr) const; |
| 178 | |
| 179 | /// Call `callbackFn` for each occurrence of the given clause. Return `true` |
| 180 | /// if at least one instance was found. |
| 181 | template <typename T> |
| 182 | bool findRepeatableClause( |
| 183 | std::function<void(const T &, const parser::CharBlock &source)> |
| 184 | callbackFn) const; |
| 185 | |
| 186 | /// Set the `result` to a new `mlir::UnitAttr` if the clause is present. |
| 187 | template <typename T> |
| 188 | bool markClauseOccurrence(mlir::UnitAttr &result) const; |
| 189 | |
| 190 | void processMapObjects( |
| 191 | lower::StatementContext &stmtCtx, mlir::Location clauseLocation, |
| 192 | const omp::ObjectList &objects, |
| 193 | llvm::omp::OpenMPOffloadMappingFlags mapTypeBits, |
| 194 | std::map<Object, OmpMapParentAndMemberData> &parentMemberIndices, |
| 195 | llvm::SmallVectorImpl<mlir::Value> &mapVars, |
| 196 | llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms, |
| 197 | llvm::StringRef mapperIdNameRef = "" ) const; |
| 198 | |
| 199 | lower::AbstractConverter &converter; |
| 200 | semantics::SemanticsContext &semaCtx; |
| 201 | List<Clause> clauses; |
| 202 | }; |
| 203 | |
| 204 | template <typename... Ts> |
| 205 | void ClauseProcessor::processTODO(mlir::Location currentLocation, |
| 206 | llvm::omp::Directive directive) const { |
| 207 | auto checkUnhandledClause = [&](llvm::omp::Clause id, const auto *x) { |
| 208 | if (!x) |
| 209 | return; |
| 210 | unsigned version = semaCtx.langOptions().OpenMPVersion; |
| 211 | TODO(currentLocation, |
| 212 | "Unhandled clause " + llvm::omp::getOpenMPClauseName(id).upper() + |
| 213 | " in " + |
| 214 | llvm::omp::getOpenMPDirectiveName(directive, version).upper() + |
| 215 | " construct" ); |
| 216 | }; |
| 217 | |
| 218 | for (ClauseIterator it = clauses.begin(); it != clauses.end(); ++it) |
| 219 | (checkUnhandledClause(it->id, std::get_if<Ts>(&it->u)), ...); |
| 220 | } |
| 221 | |
| 222 | template <typename T> |
| 223 | const T * |
| 224 | ClauseProcessor::findUniqueClause(const parser::CharBlock **source) const { |
| 225 | return ClauseFinder::findUniqueClause<T>(clauses, source); |
| 226 | } |
| 227 | |
| 228 | template <typename T> |
| 229 | bool ClauseProcessor::findRepeatableClause( |
| 230 | std::function<void(const T &, const parser::CharBlock &source)> callbackFn) |
| 231 | const { |
| 232 | return ClauseFinder::findRepeatableClause<T>(clauses, callbackFn); |
| 233 | } |
| 234 | |
| 235 | template <typename T> |
| 236 | bool ClauseProcessor::markClauseOccurrence(mlir::UnitAttr &result) const { |
| 237 | if (findUniqueClause<T>()) { |
| 238 | result = converter.getFirOpBuilder().getUnitAttr(); |
| 239 | return true; |
| 240 | } |
| 241 | return false; |
| 242 | } |
| 243 | |
| 244 | } // namespace omp |
| 245 | } // namespace lower |
| 246 | } // namespace Fortran |
| 247 | |
| 248 | #endif // FORTRAN_LOWER_CLAUSEPROCESSOR_H |
| 249 | |