1//===-- IterationSpace.cpp ------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10//
11//===----------------------------------------------------------------------===//
12
13#include "flang/Lower/IterationSpace.h"
14#include "flang/Evaluate/expression.h"
15#include "flang/Lower/AbstractConverter.h"
16#include "flang/Lower/Support/Utils.h"
17#include "llvm/Support/Debug.h"
18#include <optional>
19
20#define DEBUG_TYPE "flang-lower-iteration-space"
21
22unsigned Fortran::lower::getHashValue(
23 const Fortran::lower::ExplicitIterSpace::ArrayBases &x) {
24 return std::visit(
25 [&](const auto *p) { return HashEvaluateExpr::getHashValue(*p); }, x);
26}
27
28bool Fortran::lower::isEqual(
29 const Fortran::lower::ExplicitIterSpace::ArrayBases &x,
30 const Fortran::lower::ExplicitIterSpace::ArrayBases &y) {
31 return std::visit(
32 Fortran::common::visitors{
33 // Fortran::semantics::Symbol * are the exception here. These pointers
34 // have identity; if two Symbol * values are the same (different) then
35 // they are the same (different) logical symbol.
36 [&](Fortran::lower::FrontEndSymbol p,
37 Fortran::lower::FrontEndSymbol q) { return p == q; },
38 [&](const auto *p, const auto *q) {
39 if constexpr (std::is_same_v<decltype(p), decltype(q)>) {
40 LLVM_DEBUG(llvm::dbgs()
41 << "is equal: " << p << ' ' << q << ' '
42 << IsEqualEvaluateExpr::isEqual(*p, *q) << '\n');
43 return IsEqualEvaluateExpr::isEqual(*p, *q);
44 } else {
45 // Different subtree types are never equal.
46 return false;
47 }
48 }},
49 x, y);
50}
51
52namespace {
53
54/// This class can recover the base array in an expression that contains
55/// explicit iteration space symbols. Most of the class can be ignored as it is
56/// boilerplate Fortran::evaluate::Expr traversal.
57class ArrayBaseFinder {
58public:
59 using RT = bool;
60
61 ArrayBaseFinder(llvm::ArrayRef<Fortran::lower::FrontEndSymbol> syms)
62 : controlVars(syms.begin(), syms.end()) {}
63
64 template <typename T>
65 void operator()(const T &x) {
66 (void)find(x);
67 }
68
69 /// Get the list of bases.
70 llvm::ArrayRef<Fortran::lower::ExplicitIterSpace::ArrayBases>
71 getBases() const {
72 LLVM_DEBUG(llvm::dbgs()
73 << "number of array bases found: " << bases.size() << '\n');
74 return bases;
75 }
76
77private:
78 // First, the cases that are of interest.
79 RT find(const Fortran::semantics::Symbol &symbol) {
80 if (symbol.Rank() > 0) {
81 bases.push_back(&symbol);
82 return true;
83 }
84 return {};
85 }
86 RT find(const Fortran::evaluate::Component &x) {
87 auto found = find(x.base());
88 if (!found && x.base().Rank() == 0 && x.Rank() > 0) {
89 bases.push_back(&x);
90 return true;
91 }
92 return found;
93 }
94 RT find(const Fortran::evaluate::ArrayRef &x) {
95 for (const auto &sub : x.subscript())
96 (void)find(sub);
97 if (x.base().IsSymbol()) {
98 if (x.Rank() > 0 || intersection(x.subscript())) {
99 bases.push_back(&x);
100 return true;
101 }
102 return {};
103 }
104 auto found = find(x.base());
105 if (!found && ((x.base().Rank() == 0 && x.Rank() > 0) ||
106 intersection(x.subscript()))) {
107 bases.push_back(&x);
108 return true;
109 }
110 return found;
111 }
112 RT find(const Fortran::evaluate::Triplet &x) {
113 if (const auto *lower = x.GetLower())
114 (void)find(*lower);
115 if (const auto *upper = x.GetUpper())
116 (void)find(*upper);
117 return find(x.GetStride());
118 }
119 RT find(const Fortran::evaluate::IndirectSubscriptIntegerExpr &x) {
120 return find(x.value());
121 }
122 RT find(const Fortran::evaluate::Subscript &x) { return find(x.u); }
123 RT find(const Fortran::evaluate::DataRef &x) { return find(x.u); }
124 RT find(const Fortran::evaluate::CoarrayRef &x) {
125 assert(false && "coarray reference");
126 return {};
127 }
128
129 template <typename A>
130 bool intersection(const A &subscripts) {
131 return Fortran::lower::symbolsIntersectSubscripts(controlVars, subscripts);
132 }
133
134 // The rest is traversal boilerplate and can be ignored.
135 RT find(const Fortran::evaluate::Substring &x) { return find(x.parent()); }
136 template <typename A>
137 RT find(const Fortran::semantics::SymbolRef x) {
138 return find(*x);
139 }
140 RT find(const Fortran::evaluate::NamedEntity &x) {
141 if (x.IsSymbol())
142 return find(x.GetFirstSymbol());
143 return find(x.GetComponent());
144 }
145
146 template <typename A, bool C>
147 RT find(const Fortran::common::Indirection<A, C> &x) {
148 return find(x.value());
149 }
150 template <typename A>
151 RT find(const std::unique_ptr<A> &x) {
152 return find(x.get());
153 }
154 template <typename A>
155 RT find(const std::shared_ptr<A> &x) {
156 return find(x.get());
157 }
158 template <typename A>
159 RT find(const A *x) {
160 if (x)
161 return find(*x);
162 return {};
163 }
164 template <typename A>
165 RT find(const std::optional<A> &x) {
166 if (x)
167 return find(*x);
168 return {};
169 }
170 template <typename... A>
171 RT find(const std::variant<A...> &u) {
172 return std::visit([&](const auto &v) { return find(v); }, u);
173 }
174 template <typename A>
175 RT find(const std::vector<A> &x) {
176 for (auto &v : x)
177 (void)find(v);
178 return {};
179 }
180 RT find(const Fortran::evaluate::BOZLiteralConstant &) { return {}; }
181 RT find(const Fortran::evaluate::NullPointer &) { return {}; }
182 template <typename T>
183 RT find(const Fortran::evaluate::Constant<T> &x) {
184 return {};
185 }
186 RT find(const Fortran::evaluate::StaticDataObject &) { return {}; }
187 RT find(const Fortran::evaluate::ImpliedDoIndex &) { return {}; }
188 RT find(const Fortran::evaluate::BaseObject &x) {
189 (void)find(x.u);
190 return {};
191 }
192 RT find(const Fortran::evaluate::TypeParamInquiry &) { return {}; }
193 RT find(const Fortran::evaluate::ComplexPart &x) { return {}; }
194 template <typename T>
195 RT find(const Fortran::evaluate::Designator<T> &x) {
196 return find(x.u);
197 }
198 template <typename T>
199 RT find(const Fortran::evaluate::Variable<T> &x) {
200 return find(x.u);
201 }
202 RT find(const Fortran::evaluate::DescriptorInquiry &) { return {}; }
203 RT find(const Fortran::evaluate::SpecificIntrinsic &) { return {}; }
204 RT find(const Fortran::evaluate::ProcedureDesignator &x) { return {}; }
205 RT find(const Fortran::evaluate::ProcedureRef &x) {
206 (void)find(x.proc());
207 if (x.IsElemental())
208 (void)find(x.arguments());
209 return {};
210 }
211 RT find(const Fortran::evaluate::ActualArgument &x) {
212 if (const auto *sym = x.GetAssumedTypeDummy())
213 (void)find(*sym);
214 else
215 (void)find(x.UnwrapExpr());
216 return {};
217 }
218 template <typename T>
219 RT find(const Fortran::evaluate::FunctionRef<T> &x) {
220 (void)find(static_cast<const Fortran::evaluate::ProcedureRef &>(x));
221 return {};
222 }
223 template <typename T>
224 RT find(const Fortran::evaluate::ArrayConstructorValue<T> &) {
225 return {};
226 }
227 template <typename T>
228 RT find(const Fortran::evaluate::ArrayConstructorValues<T> &) {
229 return {};
230 }
231 template <typename T>
232 RT find(const Fortran::evaluate::ImpliedDo<T> &) {
233 return {};
234 }
235 RT find(const Fortran::semantics::ParamValue &) { return {}; }
236 RT find(const Fortran::semantics::DerivedTypeSpec &) { return {}; }
237 RT find(const Fortran::evaluate::StructureConstructor &) { return {}; }
238 template <typename D, typename R, typename O>
239 RT find(const Fortran::evaluate::Operation<D, R, O> &op) {
240 (void)find(op.left());
241 return false;
242 }
243 template <typename D, typename R, typename LO, typename RO>
244 RT find(const Fortran::evaluate::Operation<D, R, LO, RO> &op) {
245 (void)find(op.left());
246 (void)find(op.right());
247 return false;
248 }
249 RT find(const Fortran::evaluate::Relational<Fortran::evaluate::SomeType> &x) {
250 (void)find(x.u);
251 return {};
252 }
253 template <typename T>
254 RT find(const Fortran::evaluate::Expr<T> &x) {
255 (void)find(x.u);
256 return {};
257 }
258
259 llvm::SmallVector<Fortran::lower::ExplicitIterSpace::ArrayBases> bases;
260 llvm::SmallVector<Fortran::lower::FrontEndSymbol> controlVars;
261};
262
263} // namespace
264
265void Fortran::lower::ExplicitIterSpace::leave() {
266 ccLoopNest.pop_back();
267 --forallContextOpen;
268 conditionalCleanup();
269}
270
271void Fortran::lower::ExplicitIterSpace::addSymbol(
272 Fortran::lower::FrontEndSymbol sym) {
273 assert(!symbolStack.empty());
274 symbolStack.back().push_back(sym);
275}
276
277void Fortran::lower::ExplicitIterSpace::exprBase(Fortran::lower::FrontEndExpr x,
278 bool lhs) {
279 ArrayBaseFinder finder(collectAllSymbols());
280 finder(*x);
281 llvm::ArrayRef<Fortran::lower::ExplicitIterSpace::ArrayBases> bases =
282 finder.getBases();
283 if (rhsBases.empty())
284 endAssign();
285 if (lhs) {
286 if (bases.empty()) {
287 lhsBases.push_back(std::nullopt);
288 return;
289 }
290 assert(bases.size() >= 1 && "must detect an array reference on lhs");
291 if (bases.size() > 1)
292 rhsBases.back().append(bases.begin(), bases.end() - 1);
293 lhsBases.push_back(bases.back());
294 return;
295 }
296 rhsBases.back().append(bases.begin(), bases.end());
297}
298
299void Fortran::lower::ExplicitIterSpace::endAssign() { rhsBases.emplace_back(); }
300
301void Fortran::lower::ExplicitIterSpace::pushLevel() {
302 symbolStack.push_back(llvm::SmallVector<Fortran::lower::FrontEndSymbol>{});
303}
304
305void Fortran::lower::ExplicitIterSpace::popLevel() { symbolStack.pop_back(); }
306
307void Fortran::lower::ExplicitIterSpace::conditionalCleanup() {
308 if (forallContextOpen == 0) {
309 // Exiting the outermost FORALL context.
310 // Cleanup any residual mask buffers.
311 outermostContext().finalizeAndReset();
312 // Clear and reset all the cached information.
313 symbolStack.clear();
314 lhsBases.clear();
315 rhsBases.clear();
316 loadBindings.clear();
317 ccLoopNest.clear();
318 innerArgs.clear();
319 outerLoop = std::nullopt;
320 clearLoops();
321 counter = 0;
322 }
323}
324
325std::optional<size_t>
326Fortran::lower::ExplicitIterSpace::findArgPosition(fir::ArrayLoadOp load) {
327 if (lhsBases[counter]) {
328 auto ld = loadBindings.find(*lhsBases[counter]);
329 std::optional<size_t> optPos;
330 if (ld != loadBindings.end() && ld->second == load)
331 optPos = static_cast<size_t>(0u);
332 assert(optPos.has_value() && "load does not correspond to lhs");
333 return optPos;
334 }
335 return std::nullopt;
336}
337
338llvm::SmallVector<Fortran::lower::FrontEndSymbol>
339Fortran::lower::ExplicitIterSpace::collectAllSymbols() {
340 llvm::SmallVector<Fortran::lower::FrontEndSymbol> result;
341 for (llvm::SmallVector<FrontEndSymbol> vec : symbolStack)
342 result.append(vec.begin(), vec.end());
343 return result;
344}
345
346llvm::raw_ostream &
347Fortran::lower::operator<<(llvm::raw_ostream &s,
348 const Fortran::lower::ImplicitIterSpace &e) {
349 for (const llvm::SmallVector<
350 Fortran::lower::ImplicitIterSpace::FrontEndMaskExpr> &xs :
351 e.getMasks()) {
352 s << "{ ";
353 for (const Fortran::lower::ImplicitIterSpace::FrontEndMaskExpr &x : xs)
354 x->AsFortran(s << '(') << "), ";
355 s << "}\n";
356 }
357 return s;
358}
359
360llvm::raw_ostream &
361Fortran::lower::operator<<(llvm::raw_ostream &s,
362 const Fortran::lower::ExplicitIterSpace &e) {
363 auto dump = [&](const auto &u) {
364 std::visit(Fortran::common::visitors{
365 [&](const Fortran::semantics::Symbol *y) {
366 s << " " << *y << '\n';
367 },
368 [&](const Fortran::evaluate::ArrayRef *y) {
369 s << " ";
370 if (y->base().IsSymbol())
371 s << y->base().GetFirstSymbol();
372 else
373 s << y->base().GetComponent().GetLastSymbol();
374 s << '\n';
375 },
376 [&](const Fortran::evaluate::Component *y) {
377 s << " " << y->GetLastSymbol() << '\n';
378 }},
379 u);
380 };
381 s << "LHS bases:\n";
382 for (const std::optional<Fortran::lower::ExplicitIterSpace::ArrayBases> &u :
383 e.lhsBases)
384 if (u)
385 dump(*u);
386 s << "RHS bases:\n";
387 for (const llvm::SmallVector<Fortran::lower::ExplicitIterSpace::ArrayBases>
388 &bases : e.rhsBases) {
389 for (const Fortran::lower::ExplicitIterSpace::ArrayBases &u : bases)
390 dump(u);
391 s << '\n';
392 }
393 return s;
394}
395
396void Fortran::lower::ImplicitIterSpace::dump() const {
397 llvm::errs() << *this << '\n';
398}
399
400void Fortran::lower::ExplicitIterSpace::dump() const {
401 llvm::errs() << *this << '\n';
402}
403

source code of flang/lib/Lower/IterationSpace.cpp