1 | //===-- lib/Semantics/rewrite-directives.cpp ------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
#include "rewrite-directives.h"
#include "flang/Parser/parse-tree-visitor.h"
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/semantics.h"
#include "flang/Semantics/symbol.h"
#include "llvm/Frontend/OpenMP/OMP.h"
#include <algorithm>
#include <list>
16 | |
17 | namespace Fortran::semantics { |
18 | |
19 | using namespace parser::literals; |
20 | |
21 | class DirectiveRewriteMutator { |
22 | public: |
23 | explicit DirectiveRewriteMutator(SemanticsContext &context) |
24 | : context_{context} {} |
25 | |
26 | // Default action for a parse tree node is to visit children. |
27 | template <typename T> bool Pre(T &) { return true; } |
28 | template <typename T> void Post(T &) {} |
29 | |
30 | protected: |
31 | SemanticsContext &context_; |
32 | }; |
33 | |
// Rewrite atomic constructs that do not specify a memory ordering so that they
// carry an explicit one, thereby honoring the `atomic_default_mem_order` clause
// of the REQUIRES directive.
37 | class OmpRewriteMutator : public DirectiveRewriteMutator { |
38 | public: |
39 | explicit OmpRewriteMutator(SemanticsContext &context) |
40 | : DirectiveRewriteMutator(context) {} |
41 | |
42 | template <typename T> bool Pre(T &) { return true; } |
43 | template <typename T> void Post(T &) {} |
44 | |
45 | bool Pre(parser::OpenMPAtomicConstruct &); |
46 | bool Pre(parser::OpenMPRequiresConstruct &); |
47 | |
48 | private: |
49 | bool atomicDirectiveDefaultOrderFound_{false}; |
50 | }; |
51 | |
52 | bool OmpRewriteMutator::Pre(parser::OpenMPAtomicConstruct &x) { |
53 | // Find top-level parent of the operation. |
54 | Symbol *topLevelParent{common::visit( |
55 | [&](auto &atomic) { |
56 | Symbol *symbol{nullptr}; |
57 | Scope *scope{ |
58 | &context_.FindScope(std::get<parser::Verbatim>(atomic.t).source)}; |
59 | do { |
60 | if (Symbol * parent{scope->symbol()}) { |
61 | symbol = parent; |
62 | } |
63 | scope = &scope->parent(); |
64 | } while (!scope->IsGlobal()); |
65 | |
66 | assert(symbol && |
67 | "Atomic construct must be within a scope associated with a symbol" ); |
68 | return symbol; |
69 | }, |
70 | x.u)}; |
71 | |
72 | // Get the `atomic_default_mem_order` clause from the top-level parent. |
73 | std::optional<common::OmpAtomicDefaultMemOrderType> defaultMemOrder; |
74 | common::visit( |
75 | [&](auto &details) { |
76 | if constexpr (std::is_convertible_v<decltype(&details), |
77 | WithOmpDeclarative *>) { |
78 | if (details.has_ompAtomicDefaultMemOrder()) { |
79 | defaultMemOrder = *details.ompAtomicDefaultMemOrder(); |
80 | } |
81 | } |
82 | }, |
83 | topLevelParent->details()); |
84 | |
85 | if (!defaultMemOrder) { |
86 | return false; |
87 | } |
88 | |
89 | auto findMemOrderClause = |
90 | [](const std::list<parser::OmpAtomicClause> &clauses) { |
91 | return std::find_if( |
92 | clauses.begin(), clauses.end(), [](const auto &clause) { |
93 | return std::get_if<parser::OmpMemoryOrderClause>( |
94 | &clause.u); |
95 | }) != clauses.end(); |
96 | }; |
97 | |
98 | // Get the clause list to which the new memory order clause must be added, |
99 | // only if there are no other memory order clauses present for this atomic |
100 | // directive. |
101 | std::list<parser::OmpAtomicClause> *clauseList = common::visit( |
102 | common::visitors{[&](parser::OmpAtomic &atomicConstruct) { |
103 | // OmpAtomic only has a single list of clauses. |
104 | auto &clauses{std::get<parser::OmpAtomicClauseList>( |
105 | atomicConstruct.t)}; |
106 | return !findMemOrderClause(clauses.v) ? &clauses.v |
107 | : nullptr; |
108 | }, |
109 | [&](auto &atomicConstruct) { |
110 | // All other atomic constructs have two lists of clauses. |
111 | auto &clausesLhs{std::get<0>(atomicConstruct.t)}; |
112 | auto &clausesRhs{std::get<2>(atomicConstruct.t)}; |
113 | return !findMemOrderClause(clausesLhs.v) && |
114 | !findMemOrderClause(clausesRhs.v) |
115 | ? &clausesRhs.v |
116 | : nullptr; |
117 | }}, |
118 | x.u); |
119 | |
120 | // Add a memory order clause to the atomic directive. |
121 | if (clauseList) { |
122 | atomicDirectiveDefaultOrderFound_ = true; |
123 | switch (*defaultMemOrder) { |
124 | case common::OmpAtomicDefaultMemOrderType::AcqRel: |
125 | clauseList->emplace_back<parser::OmpMemoryOrderClause>(common::visit( |
126 | common::visitors{[](parser::OmpAtomicRead &) -> parser::OmpClause { |
127 | return parser::OmpClause::Acquire{}; |
128 | }, |
129 | [](parser::OmpAtomicCapture &) -> parser::OmpClause { |
130 | return parser::OmpClause::AcqRel{}; |
131 | }, |
132 | [](auto &) -> parser::OmpClause { |
133 | // parser::{OmpAtomic, OmpAtomicUpdate, OmpAtomicWrite} |
134 | return parser::OmpClause::Release{}; |
135 | }}, |
136 | x.u)); |
137 | break; |
138 | case common::OmpAtomicDefaultMemOrderType::Relaxed: |
139 | clauseList->emplace_back<parser::OmpMemoryOrderClause>( |
140 | parser::OmpClause{parser::OmpClause::Relaxed{}}); |
141 | break; |
142 | case common::OmpAtomicDefaultMemOrderType::SeqCst: |
143 | clauseList->emplace_back<parser::OmpMemoryOrderClause>( |
144 | parser::OmpClause{parser::OmpClause::SeqCst{}}); |
145 | break; |
146 | } |
147 | } |
148 | |
149 | return false; |
150 | } |
151 | |
152 | bool OmpRewriteMutator::Pre(parser::OpenMPRequiresConstruct &x) { |
153 | for (parser::OmpClause &clause : std::get<parser::OmpClauseList>(x.t).v) { |
154 | if (std::holds_alternative<parser::OmpClause::AtomicDefaultMemOrder>( |
155 | clause.u) && |
156 | atomicDirectiveDefaultOrderFound_) { |
157 | context_.Say(clause.source, |
158 | "REQUIRES directive with '%s' clause found lexically after atomic " |
159 | "operation without a memory order clause"_err_en_US , |
160 | parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName( |
161 | llvm::omp::OMPC_atomic_default_mem_order) |
162 | .str())); |
163 | } |
164 | } |
165 | return false; |
166 | } |
167 | |
168 | bool RewriteOmpParts(SemanticsContext &context, parser::Program &program) { |
169 | if (!context.IsEnabled(common::LanguageFeature::OpenMP)) { |
170 | return true; |
171 | } |
172 | OmpRewriteMutator ompMutator{context}; |
173 | parser::Walk(program, ompMutator); |
174 | return !context.AnyFatalError(); |
175 | } |
176 | |
177 | } // namespace Fortran::semantics |
178 | |