1//===-- IterationSpace.cpp ------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10//
11//===----------------------------------------------------------------------===//
12
13#include "flang/Lower/IterationSpace.h"
14#include "flang/Evaluate/expression.h"
15#include "flang/Lower/AbstractConverter.h"
16#include "flang/Lower/Support/Utils.h"
17#include "llvm/Support/Debug.h"
18#include <optional>
19
20#define DEBUG_TYPE "flang-lower-iteration-space"
21
22namespace {
23
24/// This class can recover the base array in an expression that contains
25/// explicit iteration space symbols. Most of the class can be ignored as it is
26/// boilerplate Fortran::evaluate::Expr traversal.
27class ArrayBaseFinder {
28public:
29 using RT = bool;
30
31 ArrayBaseFinder(llvm::ArrayRef<Fortran::lower::FrontEndSymbol> syms)
32 : controlVars(syms) {}
33
34 template <typename T>
35 void operator()(const T &x) {
36 (void)find(x);
37 }
38
39 /// Get the list of bases.
40 llvm::ArrayRef<Fortran::lower::ExplicitIterSpace::ArrayBases>
41 getBases() const {
42 LLVM_DEBUG(llvm::dbgs()
43 << "number of array bases found: " << bases.size() << '\n');
44 return bases;
45 }
46
47private:
48 // First, the cases that are of interest.
49 RT find(const Fortran::semantics::Symbol &symbol) {
50 if (symbol.Rank() > 0) {
51 bases.push_back(&symbol);
52 return true;
53 }
54 return {};
55 }
56 RT find(const Fortran::evaluate::Component &x) {
57 auto found = find(x.base());
58 if (!found && x.base().Rank() == 0 && x.Rank() > 0) {
59 bases.push_back(&x);
60 return true;
61 }
62 return found;
63 }
64 RT find(const Fortran::evaluate::ArrayRef &x) {
65 for (const auto &sub : x.subscript())
66 (void)find(sub);
67 if (x.base().IsSymbol()) {
68 if (x.Rank() > 0 || intersection(x.subscript())) {
69 bases.push_back(&x);
70 return true;
71 }
72 return {};
73 }
74 auto found = find(x.base());
75 if (!found && ((x.base().Rank() == 0 && x.Rank() > 0) ||
76 intersection(x.subscript()))) {
77 bases.push_back(&x);
78 return true;
79 }
80 return found;
81 }
82 RT find(const Fortran::evaluate::Triplet &x) {
83 if (const auto *lower = x.GetLower())
84 (void)find(*lower);
85 if (const auto *upper = x.GetUpper())
86 (void)find(*upper);
87 return find(x.GetStride());
88 }
89 RT find(const Fortran::evaluate::IndirectSubscriptIntegerExpr &x) {
90 return find(x.value());
91 }
92 RT find(const Fortran::evaluate::Subscript &x) { return find(x.u); }
93 RT find(const Fortran::evaluate::DataRef &x) { return find(x.u); }
94 RT find(const Fortran::evaluate::CoarrayRef &x) {
95 assert(false && "coarray reference");
96 return {};
97 }
98
99 template <typename A>
100 bool intersection(const A &subscripts) {
101 return Fortran::lower::symbolsIntersectSubscripts(controlVars, subscripts);
102 }
103
104 // The rest is traversal boilerplate and can be ignored.
105 RT find(const Fortran::evaluate::Substring &x) { return find(x.parent()); }
106 template <typename A>
107 RT find(const Fortran::semantics::SymbolRef x) {
108 return find(*x);
109 }
110 RT find(const Fortran::evaluate::NamedEntity &x) {
111 if (x.IsSymbol())
112 return find(x.GetFirstSymbol());
113 return find(x.GetComponent());
114 }
115
116 template <typename A, bool C>
117 RT find(const Fortran::common::Indirection<A, C> &x) {
118 return find(x.value());
119 }
120 template <typename A>
121 RT find(const std::unique_ptr<A> &x) {
122 return find(x.get());
123 }
124 template <typename A>
125 RT find(const std::shared_ptr<A> &x) {
126 return find(x.get());
127 }
128 template <typename A>
129 RT find(const A *x) {
130 if (x)
131 return find(*x);
132 return {};
133 }
134 template <typename A>
135 RT find(const std::optional<A> &x) {
136 if (x)
137 return find(*x);
138 return {};
139 }
140 template <typename... A>
141 RT find(const std::variant<A...> &u) {
142 return Fortran::common::visit([&](const auto &v) { return find(v); }, u);
143 }
144 template <typename A>
145 RT find(const std::vector<A> &x) {
146 for (auto &v : x)
147 (void)find(v);
148 return {};
149 }
150 RT find(const Fortran::evaluate::BOZLiteralConstant &) { return {}; }
151 RT find(const Fortran::evaluate::NullPointer &) { return {}; }
152 template <typename T>
153 RT find(const Fortran::evaluate::Constant<T> &x) {
154 return {};
155 }
156 RT find(const Fortran::evaluate::StaticDataObject &) { return {}; }
157 RT find(const Fortran::evaluate::ImpliedDoIndex &) { return {}; }
158 RT find(const Fortran::evaluate::BaseObject &x) {
159 (void)find(x.u);
160 return {};
161 }
162 RT find(const Fortran::evaluate::TypeParamInquiry &) { return {}; }
163 RT find(const Fortran::evaluate::ComplexPart &x) { return {}; }
164 template <typename T>
165 RT find(const Fortran::evaluate::Designator<T> &x) {
166 return find(x.u);
167 }
168 RT find(const Fortran::evaluate::DescriptorInquiry &) { return {}; }
169 RT find(const Fortran::evaluate::SpecificIntrinsic &) { return {}; }
170 RT find(const Fortran::evaluate::ProcedureDesignator &x) { return {}; }
171 RT find(const Fortran::evaluate::ProcedureRef &x) {
172 (void)find(x.proc());
173 if (x.IsElemental())
174 (void)find(x.arguments());
175 return {};
176 }
177 RT find(const Fortran::evaluate::ActualArgument &x) {
178 if (const auto *sym = x.GetAssumedTypeDummy())
179 (void)find(*sym);
180 else
181 (void)find(x.UnwrapExpr());
182 return {};
183 }
184 template <typename T>
185 RT find(const Fortran::evaluate::FunctionRef<T> &x) {
186 (void)find(static_cast<const Fortran::evaluate::ProcedureRef &>(x));
187 return {};
188 }
189 template <typename T>
190 RT find(const Fortran::evaluate::ArrayConstructorValue<T> &) {
191 return {};
192 }
193 template <typename T>
194 RT find(const Fortran::evaluate::ArrayConstructorValues<T> &) {
195 return {};
196 }
197 template <typename T>
198 RT find(const Fortran::evaluate::ImpliedDo<T> &) {
199 return {};
200 }
201 RT find(const Fortran::semantics::ParamValue &) { return {}; }
202 RT find(const Fortran::semantics::DerivedTypeSpec &) { return {}; }
203 RT find(const Fortran::evaluate::StructureConstructor &) { return {}; }
204 template <typename D, typename R, typename O>
205 RT find(const Fortran::evaluate::Operation<D, R, O> &op) {
206 (void)find(op.left());
207 return false;
208 }
209 template <typename D, typename R, typename LO, typename RO>
210 RT find(const Fortran::evaluate::Operation<D, R, LO, RO> &op) {
211 (void)find(op.left());
212 (void)find(op.right());
213 return false;
214 }
215 RT find(const Fortran::evaluate::Relational<Fortran::evaluate::SomeType> &x) {
216 (void)find(x.u);
217 return {};
218 }
219 template <typename T>
220 RT find(const Fortran::evaluate::Expr<T> &x) {
221 (void)find(x.u);
222 return {};
223 }
224
225 llvm::SmallVector<Fortran::lower::ExplicitIterSpace::ArrayBases> bases;
226 llvm::SmallVector<Fortran::lower::FrontEndSymbol> controlVars;
227};
228
229} // namespace
230
231void Fortran::lower::ExplicitIterSpace::leave() {
232 ccLoopNest.pop_back();
233 --forallContextOpen;
234 conditionalCleanup();
235}
236
237void Fortran::lower::ExplicitIterSpace::addSymbol(
238 Fortran::lower::FrontEndSymbol sym) {
239 assert(!symbolStack.empty());
240 symbolStack.back().push_back(sym);
241}
242
243void Fortran::lower::ExplicitIterSpace::exprBase(Fortran::lower::FrontEndExpr x,
244 bool lhs) {
245 ArrayBaseFinder finder(collectAllSymbols());
246 finder(*x);
247 llvm::ArrayRef<Fortran::lower::ExplicitIterSpace::ArrayBases> bases =
248 finder.getBases();
249 if (rhsBases.empty())
250 endAssign();
251 if (lhs) {
252 if (bases.empty()) {
253 lhsBases.push_back(std::nullopt);
254 return;
255 }
256 assert(bases.size() >= 1 && "must detect an array reference on lhs");
257 if (bases.size() > 1)
258 rhsBases.back().append(bases.begin(), bases.end() - 1);
259 lhsBases.push_back(bases.back());
260 return;
261 }
262 rhsBases.back().append(bases.begin(), bases.end());
263}
264
265void Fortran::lower::ExplicitIterSpace::endAssign() { rhsBases.emplace_back(); }
266
267void Fortran::lower::ExplicitIterSpace::pushLevel() {
268 symbolStack.push_back(llvm::SmallVector<Fortran::lower::FrontEndSymbol>{});
269}
270
271void Fortran::lower::ExplicitIterSpace::popLevel() { symbolStack.pop_back(); }
272
273void Fortran::lower::ExplicitIterSpace::conditionalCleanup() {
274 if (forallContextOpen == 0) {
275 // Exiting the outermost FORALL context.
276 // Cleanup any residual mask buffers.
277 outermostContext().finalizeAndReset();
278 // Clear and reset all the cached information.
279 symbolStack.clear();
280 lhsBases.clear();
281 rhsBases.clear();
282 loadBindings.clear();
283 ccLoopNest.clear();
284 innerArgs.clear();
285 outerLoop = std::nullopt;
286 clearLoops();
287 counter = 0;
288 }
289}
290
291std::optional<size_t>
292Fortran::lower::ExplicitIterSpace::findArgPosition(fir::ArrayLoadOp load) {
293 if (lhsBases[counter]) {
294 auto ld = loadBindings.find(*lhsBases[counter]);
295 std::optional<size_t> optPos;
296 if (ld != loadBindings.end() && ld->second == load)
297 optPos = static_cast<size_t>(0u);
298 assert(optPos.has_value() && "load does not correspond to lhs");
299 return optPos;
300 }
301 return std::nullopt;
302}
303
304llvm::SmallVector<Fortran::lower::FrontEndSymbol>
305Fortran::lower::ExplicitIterSpace::collectAllSymbols() {
306 llvm::SmallVector<Fortran::lower::FrontEndSymbol> result;
307 for (llvm::SmallVector<FrontEndSymbol> vec : symbolStack)
308 result.append(vec.begin(), vec.end());
309 return result;
310}
311
312llvm::raw_ostream &
313Fortran::lower::operator<<(llvm::raw_ostream &s,
314 const Fortran::lower::ImplicitIterSpace &e) {
315 for (const llvm::SmallVector<
316 Fortran::lower::ImplicitIterSpace::FrontEndMaskExpr> &xs :
317 e.getMasks()) {
318 s << "{ ";
319 for (const Fortran::lower::ImplicitIterSpace::FrontEndMaskExpr &x : xs)
320 x->AsFortran(s << '(') << "), ";
321 s << "}\n";
322 }
323 return s;
324}
325
326llvm::raw_ostream &
327Fortran::lower::operator<<(llvm::raw_ostream &s,
328 const Fortran::lower::ExplicitIterSpace &e) {
329 auto dump = [&](const auto &u) {
330 Fortran::common::visit(
331 Fortran::common::visitors{
332 [&](const Fortran::semantics::Symbol *y) {
333 s << " " << *y << '\n';
334 },
335 [&](const Fortran::evaluate::ArrayRef *y) {
336 s << " ";
337 if (y->base().IsSymbol())
338 s << y->base().GetFirstSymbol();
339 else
340 s << y->base().GetComponent().GetLastSymbol();
341 s << '\n';
342 },
343 [&](const Fortran::evaluate::Component *y) {
344 s << " " << y->GetLastSymbol() << '\n';
345 }},
346 u);
347 };
348 s << "LHS bases:\n";
349 for (const std::optional<Fortran::lower::ExplicitIterSpace::ArrayBases> &u :
350 e.lhsBases)
351 if (u)
352 dump(*u);
353 s << "RHS bases:\n";
354 for (const llvm::SmallVector<Fortran::lower::ExplicitIterSpace::ArrayBases>
355 &bases : e.rhsBases) {
356 for (const Fortran::lower::ExplicitIterSpace::ArrayBases &u : bases)
357 dump(u);
358 s << '\n';
359 }
360 return s;
361}
362
363void Fortran::lower::ImplicitIterSpace::dump() const {
364 llvm::errs() << *this << '\n';
365}
366
367void Fortran::lower::ExplicitIterSpace::dump() const {
368 llvm::errs() << *this << '\n';
369}
370

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of flang/lib/Lower/IterationSpace.cpp