1 | //===-- IterationSpace.cpp ------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "flang/Lower/IterationSpace.h" |
14 | #include "flang/Evaluate/expression.h" |
15 | #include "flang/Lower/AbstractConverter.h" |
16 | #include "flang/Lower/Support/Utils.h" |
17 | #include "llvm/Support/Debug.h" |
18 | #include <optional> |
19 | |
20 | #define DEBUG_TYPE "flang-lower-iteration-space" |
21 | |
22 | unsigned Fortran::lower::getHashValue( |
23 | const Fortran::lower::ExplicitIterSpace::ArrayBases &x) { |
24 | return std::visit( |
25 | [&](const auto *p) { return HashEvaluateExpr::getHashValue(*p); }, x); |
26 | } |
27 | |
28 | bool Fortran::lower::isEqual( |
29 | const Fortran::lower::ExplicitIterSpace::ArrayBases &x, |
30 | const Fortran::lower::ExplicitIterSpace::ArrayBases &y) { |
31 | return std::visit( |
32 | Fortran::common::visitors{ |
33 | // Fortran::semantics::Symbol * are the exception here. These pointers |
34 | // have identity; if two Symbol * values are the same (different) then |
35 | // they are the same (different) logical symbol. |
36 | [&](Fortran::lower::FrontEndSymbol p, |
37 | Fortran::lower::FrontEndSymbol q) { return p == q; }, |
38 | [&](const auto *p, const auto *q) { |
39 | if constexpr (std::is_same_v<decltype(p), decltype(q)>) { |
40 | LLVM_DEBUG(llvm::dbgs() |
41 | << "is equal: " << p << ' ' << q << ' ' |
42 | << IsEqualEvaluateExpr::isEqual(*p, *q) << '\n'); |
43 | return IsEqualEvaluateExpr::isEqual(*p, *q); |
44 | } else { |
45 | // Different subtree types are never equal. |
46 | return false; |
47 | } |
48 | }}, |
49 | x, y); |
50 | } |
51 | |
52 | namespace { |
53 | |
54 | /// This class can recover the base array in an expression that contains |
55 | /// explicit iteration space symbols. Most of the class can be ignored as it is |
56 | /// boilerplate Fortran::evaluate::Expr traversal. |
57 | class ArrayBaseFinder { |
58 | public: |
59 | using RT = bool; |
60 | |
61 | ArrayBaseFinder(llvm::ArrayRef<Fortran::lower::FrontEndSymbol> syms) |
62 | : controlVars(syms.begin(), syms.end()) {} |
63 | |
64 | template <typename T> |
65 | void operator()(const T &x) { |
66 | (void)find(x); |
67 | } |
68 | |
69 | /// Get the list of bases. |
70 | llvm::ArrayRef<Fortran::lower::ExplicitIterSpace::ArrayBases> |
71 | getBases() const { |
72 | LLVM_DEBUG(llvm::dbgs() |
73 | << "number of array bases found: " << bases.size() << '\n'); |
74 | return bases; |
75 | } |
76 | |
77 | private: |
78 | // First, the cases that are of interest. |
79 | RT find(const Fortran::semantics::Symbol &symbol) { |
80 | if (symbol.Rank() > 0) { |
81 | bases.push_back(&symbol); |
82 | return true; |
83 | } |
84 | return {}; |
85 | } |
86 | RT find(const Fortran::evaluate::Component &x) { |
87 | auto found = find(x.base()); |
88 | if (!found && x.base().Rank() == 0 && x.Rank() > 0) { |
89 | bases.push_back(&x); |
90 | return true; |
91 | } |
92 | return found; |
93 | } |
94 | RT find(const Fortran::evaluate::ArrayRef &x) { |
95 | for (const auto &sub : x.subscript()) |
96 | (void)find(sub); |
97 | if (x.base().IsSymbol()) { |
98 | if (x.Rank() > 0 || intersection(x.subscript())) { |
99 | bases.push_back(&x); |
100 | return true; |
101 | } |
102 | return {}; |
103 | } |
104 | auto found = find(x.base()); |
105 | if (!found && ((x.base().Rank() == 0 && x.Rank() > 0) || |
106 | intersection(x.subscript()))) { |
107 | bases.push_back(&x); |
108 | return true; |
109 | } |
110 | return found; |
111 | } |
112 | RT find(const Fortran::evaluate::Triplet &x) { |
113 | if (const auto *lower = x.GetLower()) |
114 | (void)find(*lower); |
115 | if (const auto *upper = x.GetUpper()) |
116 | (void)find(*upper); |
117 | return find(x.GetStride()); |
118 | } |
119 | RT find(const Fortran::evaluate::IndirectSubscriptIntegerExpr &x) { |
120 | return find(x.value()); |
121 | } |
122 | RT find(const Fortran::evaluate::Subscript &x) { return find(x.u); } |
123 | RT find(const Fortran::evaluate::DataRef &x) { return find(x.u); } |
124 | RT find(const Fortran::evaluate::CoarrayRef &x) { |
125 | assert(false && "coarray reference" ); |
126 | return {}; |
127 | } |
128 | |
129 | template <typename A> |
130 | bool intersection(const A &subscripts) { |
131 | return Fortran::lower::symbolsIntersectSubscripts(controlVars, subscripts); |
132 | } |
133 | |
134 | // The rest is traversal boilerplate and can be ignored. |
135 | RT find(const Fortran::evaluate::Substring &x) { return find(x.parent()); } |
136 | template <typename A> |
137 | RT find(const Fortran::semantics::SymbolRef x) { |
138 | return find(*x); |
139 | } |
140 | RT find(const Fortran::evaluate::NamedEntity &x) { |
141 | if (x.IsSymbol()) |
142 | return find(x.GetFirstSymbol()); |
143 | return find(x.GetComponent()); |
144 | } |
145 | |
146 | template <typename A, bool C> |
147 | RT find(const Fortran::common::Indirection<A, C> &x) { |
148 | return find(x.value()); |
149 | } |
150 | template <typename A> |
151 | RT find(const std::unique_ptr<A> &x) { |
152 | return find(x.get()); |
153 | } |
154 | template <typename A> |
155 | RT find(const std::shared_ptr<A> &x) { |
156 | return find(x.get()); |
157 | } |
158 | template <typename A> |
159 | RT find(const A *x) { |
160 | if (x) |
161 | return find(*x); |
162 | return {}; |
163 | } |
164 | template <typename A> |
165 | RT find(const std::optional<A> &x) { |
166 | if (x) |
167 | return find(*x); |
168 | return {}; |
169 | } |
170 | template <typename... A> |
171 | RT find(const std::variant<A...> &u) { |
172 | return std::visit([&](const auto &v) { return find(v); }, u); |
173 | } |
174 | template <typename A> |
175 | RT find(const std::vector<A> &x) { |
176 | for (auto &v : x) |
177 | (void)find(v); |
178 | return {}; |
179 | } |
180 | RT find(const Fortran::evaluate::BOZLiteralConstant &) { return {}; } |
181 | RT find(const Fortran::evaluate::NullPointer &) { return {}; } |
182 | template <typename T> |
183 | RT find(const Fortran::evaluate::Constant<T> &x) { |
184 | return {}; |
185 | } |
186 | RT find(const Fortran::evaluate::StaticDataObject &) { return {}; } |
187 | RT find(const Fortran::evaluate::ImpliedDoIndex &) { return {}; } |
188 | RT find(const Fortran::evaluate::BaseObject &x) { |
189 | (void)find(x.u); |
190 | return {}; |
191 | } |
192 | RT find(const Fortran::evaluate::TypeParamInquiry &) { return {}; } |
193 | RT find(const Fortran::evaluate::ComplexPart &x) { return {}; } |
194 | template <typename T> |
195 | RT find(const Fortran::evaluate::Designator<T> &x) { |
196 | return find(x.u); |
197 | } |
198 | template <typename T> |
199 | RT find(const Fortran::evaluate::Variable<T> &x) { |
200 | return find(x.u); |
201 | } |
202 | RT find(const Fortran::evaluate::DescriptorInquiry &) { return {}; } |
203 | RT find(const Fortran::evaluate::SpecificIntrinsic &) { return {}; } |
204 | RT find(const Fortran::evaluate::ProcedureDesignator &x) { return {}; } |
205 | RT find(const Fortran::evaluate::ProcedureRef &x) { |
206 | (void)find(x.proc()); |
207 | if (x.IsElemental()) |
208 | (void)find(x.arguments()); |
209 | return {}; |
210 | } |
211 | RT find(const Fortran::evaluate::ActualArgument &x) { |
212 | if (const auto *sym = x.GetAssumedTypeDummy()) |
213 | (void)find(*sym); |
214 | else |
215 | (void)find(x.UnwrapExpr()); |
216 | return {}; |
217 | } |
218 | template <typename T> |
219 | RT find(const Fortran::evaluate::FunctionRef<T> &x) { |
220 | (void)find(static_cast<const Fortran::evaluate::ProcedureRef &>(x)); |
221 | return {}; |
222 | } |
223 | template <typename T> |
224 | RT find(const Fortran::evaluate::ArrayConstructorValue<T> &) { |
225 | return {}; |
226 | } |
227 | template <typename T> |
228 | RT find(const Fortran::evaluate::ArrayConstructorValues<T> &) { |
229 | return {}; |
230 | } |
231 | template <typename T> |
232 | RT find(const Fortran::evaluate::ImpliedDo<T> &) { |
233 | return {}; |
234 | } |
235 | RT find(const Fortran::semantics::ParamValue &) { return {}; } |
236 | RT find(const Fortran::semantics::DerivedTypeSpec &) { return {}; } |
237 | RT find(const Fortran::evaluate::StructureConstructor &) { return {}; } |
238 | template <typename D, typename R, typename O> |
239 | RT find(const Fortran::evaluate::Operation<D, R, O> &op) { |
240 | (void)find(op.left()); |
241 | return false; |
242 | } |
243 | template <typename D, typename R, typename LO, typename RO> |
244 | RT find(const Fortran::evaluate::Operation<D, R, LO, RO> &op) { |
245 | (void)find(op.left()); |
246 | (void)find(op.right()); |
247 | return false; |
248 | } |
249 | RT find(const Fortran::evaluate::Relational<Fortran::evaluate::SomeType> &x) { |
250 | (void)find(x.u); |
251 | return {}; |
252 | } |
253 | template <typename T> |
254 | RT find(const Fortran::evaluate::Expr<T> &x) { |
255 | (void)find(x.u); |
256 | return {}; |
257 | } |
258 | |
259 | llvm::SmallVector<Fortran::lower::ExplicitIterSpace::ArrayBases> bases; |
260 | llvm::SmallVector<Fortran::lower::FrontEndSymbol> controlVars; |
261 | }; |
262 | |
263 | } // namespace |
264 | |
265 | void Fortran::lower::ExplicitIterSpace::leave() { |
266 | ccLoopNest.pop_back(); |
267 | --forallContextOpen; |
268 | conditionalCleanup(); |
269 | } |
270 | |
271 | void Fortran::lower::ExplicitIterSpace::addSymbol( |
272 | Fortran::lower::FrontEndSymbol sym) { |
273 | assert(!symbolStack.empty()); |
274 | symbolStack.back().push_back(sym); |
275 | } |
276 | |
277 | void Fortran::lower::ExplicitIterSpace::exprBase(Fortran::lower::FrontEndExpr x, |
278 | bool lhs) { |
279 | ArrayBaseFinder finder(collectAllSymbols()); |
280 | finder(*x); |
281 | llvm::ArrayRef<Fortran::lower::ExplicitIterSpace::ArrayBases> bases = |
282 | finder.getBases(); |
283 | if (rhsBases.empty()) |
284 | endAssign(); |
285 | if (lhs) { |
286 | if (bases.empty()) { |
287 | lhsBases.push_back(std::nullopt); |
288 | return; |
289 | } |
290 | assert(bases.size() >= 1 && "must detect an array reference on lhs" ); |
291 | if (bases.size() > 1) |
292 | rhsBases.back().append(bases.begin(), bases.end() - 1); |
293 | lhsBases.push_back(bases.back()); |
294 | return; |
295 | } |
296 | rhsBases.back().append(bases.begin(), bases.end()); |
297 | } |
298 | |
299 | void Fortran::lower::ExplicitIterSpace::endAssign() { rhsBases.emplace_back(); } |
300 | |
301 | void Fortran::lower::ExplicitIterSpace::pushLevel() { |
302 | symbolStack.push_back(llvm::SmallVector<Fortran::lower::FrontEndSymbol>{}); |
303 | } |
304 | |
305 | void Fortran::lower::ExplicitIterSpace::popLevel() { symbolStack.pop_back(); } |
306 | |
307 | void Fortran::lower::ExplicitIterSpace::conditionalCleanup() { |
308 | if (forallContextOpen == 0) { |
309 | // Exiting the outermost FORALL context. |
310 | // Cleanup any residual mask buffers. |
311 | outermostContext().finalizeAndReset(); |
312 | // Clear and reset all the cached information. |
313 | symbolStack.clear(); |
314 | lhsBases.clear(); |
315 | rhsBases.clear(); |
316 | loadBindings.clear(); |
317 | ccLoopNest.clear(); |
318 | innerArgs.clear(); |
319 | outerLoop = std::nullopt; |
320 | clearLoops(); |
321 | counter = 0; |
322 | } |
323 | } |
324 | |
325 | std::optional<size_t> |
326 | Fortran::lower::ExplicitIterSpace::findArgPosition(fir::ArrayLoadOp load) { |
327 | if (lhsBases[counter]) { |
328 | auto ld = loadBindings.find(*lhsBases[counter]); |
329 | std::optional<size_t> optPos; |
330 | if (ld != loadBindings.end() && ld->second == load) |
331 | optPos = static_cast<size_t>(0u); |
332 | assert(optPos.has_value() && "load does not correspond to lhs" ); |
333 | return optPos; |
334 | } |
335 | return std::nullopt; |
336 | } |
337 | |
338 | llvm::SmallVector<Fortran::lower::FrontEndSymbol> |
339 | Fortran::lower::ExplicitIterSpace::collectAllSymbols() { |
340 | llvm::SmallVector<Fortran::lower::FrontEndSymbol> result; |
341 | for (llvm::SmallVector<FrontEndSymbol> vec : symbolStack) |
342 | result.append(vec.begin(), vec.end()); |
343 | return result; |
344 | } |
345 | |
346 | llvm::raw_ostream & |
347 | Fortran::lower::operator<<(llvm::raw_ostream &s, |
348 | const Fortran::lower::ImplicitIterSpace &e) { |
349 | for (const llvm::SmallVector< |
350 | Fortran::lower::ImplicitIterSpace::FrontEndMaskExpr> &xs : |
351 | e.getMasks()) { |
352 | s << "{ " ; |
353 | for (const Fortran::lower::ImplicitIterSpace::FrontEndMaskExpr &x : xs) |
354 | x->AsFortran(s << '(') << "), " ; |
355 | s << "}\n" ; |
356 | } |
357 | return s; |
358 | } |
359 | |
360 | llvm::raw_ostream & |
361 | Fortran::lower::operator<<(llvm::raw_ostream &s, |
362 | const Fortran::lower::ExplicitIterSpace &e) { |
363 | auto dump = [&](const auto &u) { |
364 | std::visit(Fortran::common::visitors{ |
365 | [&](const Fortran::semantics::Symbol *y) { |
366 | s << " " << *y << '\n'; |
367 | }, |
368 | [&](const Fortran::evaluate::ArrayRef *y) { |
369 | s << " " ; |
370 | if (y->base().IsSymbol()) |
371 | s << y->base().GetFirstSymbol(); |
372 | else |
373 | s << y->base().GetComponent().GetLastSymbol(); |
374 | s << '\n'; |
375 | }, |
376 | [&](const Fortran::evaluate::Component *y) { |
377 | s << " " << y->GetLastSymbol() << '\n'; |
378 | }}, |
379 | u); |
380 | }; |
381 | s << "LHS bases:\n" ; |
382 | for (const std::optional<Fortran::lower::ExplicitIterSpace::ArrayBases> &u : |
383 | e.lhsBases) |
384 | if (u) |
385 | dump(*u); |
386 | s << "RHS bases:\n" ; |
387 | for (const llvm::SmallVector<Fortran::lower::ExplicitIterSpace::ArrayBases> |
388 | &bases : e.rhsBases) { |
389 | for (const Fortran::lower::ExplicitIterSpace::ArrayBases &u : bases) |
390 | dump(u); |
391 | s << '\n'; |
392 | } |
393 | return s; |
394 | } |
395 | |
396 | void Fortran::lower::ImplicitIterSpace::dump() const { |
397 | llvm::errs() << *this << '\n'; |
398 | } |
399 | |
400 | void Fortran::lower::ExplicitIterSpace::dump() const { |
401 | llvm::errs() << *this << '\n'; |
402 | } |
403 | |