1 | //===-- IterationSpace.cpp ------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "flang/Lower/IterationSpace.h" |
14 | #include "flang/Evaluate/expression.h" |
15 | #include "flang/Lower/AbstractConverter.h" |
16 | #include "flang/Lower/Support/Utils.h" |
17 | #include "llvm/Support/Debug.h" |
18 | #include <optional> |
19 | |
20 | #define DEBUG_TYPE "flang-lower-iteration-space" |
21 | |
22 | namespace { |
23 | |
24 | /// This class can recover the base array in an expression that contains |
25 | /// explicit iteration space symbols. Most of the class can be ignored as it is |
26 | /// boilerplate Fortran::evaluate::Expr traversal. |
27 | class ArrayBaseFinder { |
28 | public: |
29 | using RT = bool; |
30 | |
31 | ArrayBaseFinder(llvm::ArrayRef<Fortran::lower::FrontEndSymbol> syms) |
32 | : controlVars(syms) {} |
33 | |
34 | template <typename T> |
35 | void operator()(const T &x) { |
36 | (void)find(x); |
37 | } |
38 | |
39 | /// Get the list of bases. |
40 | llvm::ArrayRef<Fortran::lower::ExplicitIterSpace::ArrayBases> |
41 | getBases() const { |
42 | LLVM_DEBUG(llvm::dbgs() |
43 | << "number of array bases found: " << bases.size() << '\n'); |
44 | return bases; |
45 | } |
46 | |
47 | private: |
48 | // First, the cases that are of interest. |
49 | RT find(const Fortran::semantics::Symbol &symbol) { |
50 | if (symbol.Rank() > 0) { |
51 | bases.push_back(&symbol); |
52 | return true; |
53 | } |
54 | return {}; |
55 | } |
56 | RT find(const Fortran::evaluate::Component &x) { |
57 | auto found = find(x.base()); |
58 | if (!found && x.base().Rank() == 0 && x.Rank() > 0) { |
59 | bases.push_back(&x); |
60 | return true; |
61 | } |
62 | return found; |
63 | } |
64 | RT find(const Fortran::evaluate::ArrayRef &x) { |
65 | for (const auto &sub : x.subscript()) |
66 | (void)find(sub); |
67 | if (x.base().IsSymbol()) { |
68 | if (x.Rank() > 0 || intersection(x.subscript())) { |
69 | bases.push_back(&x); |
70 | return true; |
71 | } |
72 | return {}; |
73 | } |
74 | auto found = find(x.base()); |
75 | if (!found && ((x.base().Rank() == 0 && x.Rank() > 0) || |
76 | intersection(x.subscript()))) { |
77 | bases.push_back(&x); |
78 | return true; |
79 | } |
80 | return found; |
81 | } |
82 | RT find(const Fortran::evaluate::Triplet &x) { |
83 | if (const auto *lower = x.GetLower()) |
84 | (void)find(*lower); |
85 | if (const auto *upper = x.GetUpper()) |
86 | (void)find(*upper); |
87 | return find(x.GetStride()); |
88 | } |
89 | RT find(const Fortran::evaluate::IndirectSubscriptIntegerExpr &x) { |
90 | return find(x.value()); |
91 | } |
92 | RT find(const Fortran::evaluate::Subscript &x) { return find(x.u); } |
93 | RT find(const Fortran::evaluate::DataRef &x) { return find(x.u); } |
94 | RT find(const Fortran::evaluate::CoarrayRef &x) { |
95 | assert(false && "coarray reference" ); |
96 | return {}; |
97 | } |
98 | |
99 | template <typename A> |
100 | bool intersection(const A &subscripts) { |
101 | return Fortran::lower::symbolsIntersectSubscripts(controlVars, subscripts); |
102 | } |
103 | |
104 | // The rest is traversal boilerplate and can be ignored. |
105 | RT find(const Fortran::evaluate::Substring &x) { return find(x.parent()); } |
106 | template <typename A> |
107 | RT find(const Fortran::semantics::SymbolRef x) { |
108 | return find(*x); |
109 | } |
110 | RT find(const Fortran::evaluate::NamedEntity &x) { |
111 | if (x.IsSymbol()) |
112 | return find(x.GetFirstSymbol()); |
113 | return find(x.GetComponent()); |
114 | } |
115 | |
116 | template <typename A, bool C> |
117 | RT find(const Fortran::common::Indirection<A, C> &x) { |
118 | return find(x.value()); |
119 | } |
120 | template <typename A> |
121 | RT find(const std::unique_ptr<A> &x) { |
122 | return find(x.get()); |
123 | } |
124 | template <typename A> |
125 | RT find(const std::shared_ptr<A> &x) { |
126 | return find(x.get()); |
127 | } |
128 | template <typename A> |
129 | RT find(const A *x) { |
130 | if (x) |
131 | return find(*x); |
132 | return {}; |
133 | } |
134 | template <typename A> |
135 | RT find(const std::optional<A> &x) { |
136 | if (x) |
137 | return find(*x); |
138 | return {}; |
139 | } |
140 | template <typename... A> |
141 | RT find(const std::variant<A...> &u) { |
142 | return Fortran::common::visit([&](const auto &v) { return find(v); }, u); |
143 | } |
144 | template <typename A> |
145 | RT find(const std::vector<A> &x) { |
146 | for (auto &v : x) |
147 | (void)find(v); |
148 | return {}; |
149 | } |
150 | RT find(const Fortran::evaluate::BOZLiteralConstant &) { return {}; } |
151 | RT find(const Fortran::evaluate::NullPointer &) { return {}; } |
152 | template <typename T> |
153 | RT find(const Fortran::evaluate::Constant<T> &x) { |
154 | return {}; |
155 | } |
156 | RT find(const Fortran::evaluate::StaticDataObject &) { return {}; } |
157 | RT find(const Fortran::evaluate::ImpliedDoIndex &) { return {}; } |
158 | RT find(const Fortran::evaluate::BaseObject &x) { |
159 | (void)find(x.u); |
160 | return {}; |
161 | } |
162 | RT find(const Fortran::evaluate::TypeParamInquiry &) { return {}; } |
163 | RT find(const Fortran::evaluate::ComplexPart &x) { return {}; } |
164 | template <typename T> |
165 | RT find(const Fortran::evaluate::Designator<T> &x) { |
166 | return find(x.u); |
167 | } |
168 | RT find(const Fortran::evaluate::DescriptorInquiry &) { return {}; } |
169 | RT find(const Fortran::evaluate::SpecificIntrinsic &) { return {}; } |
170 | RT find(const Fortran::evaluate::ProcedureDesignator &x) { return {}; } |
171 | RT find(const Fortran::evaluate::ProcedureRef &x) { |
172 | (void)find(x.proc()); |
173 | if (x.IsElemental()) |
174 | (void)find(x.arguments()); |
175 | return {}; |
176 | } |
177 | RT find(const Fortran::evaluate::ActualArgument &x) { |
178 | if (const auto *sym = x.GetAssumedTypeDummy()) |
179 | (void)find(*sym); |
180 | else |
181 | (void)find(x.UnwrapExpr()); |
182 | return {}; |
183 | } |
184 | template <typename T> |
185 | RT find(const Fortran::evaluate::FunctionRef<T> &x) { |
186 | (void)find(static_cast<const Fortran::evaluate::ProcedureRef &>(x)); |
187 | return {}; |
188 | } |
189 | template <typename T> |
190 | RT find(const Fortran::evaluate::ArrayConstructorValue<T> &) { |
191 | return {}; |
192 | } |
193 | template <typename T> |
194 | RT find(const Fortran::evaluate::ArrayConstructorValues<T> &) { |
195 | return {}; |
196 | } |
197 | template <typename T> |
198 | RT find(const Fortran::evaluate::ImpliedDo<T> &) { |
199 | return {}; |
200 | } |
201 | RT find(const Fortran::semantics::ParamValue &) { return {}; } |
202 | RT find(const Fortran::semantics::DerivedTypeSpec &) { return {}; } |
203 | RT find(const Fortran::evaluate::StructureConstructor &) { return {}; } |
204 | template <typename D, typename R, typename O> |
205 | RT find(const Fortran::evaluate::Operation<D, R, O> &op) { |
206 | (void)find(op.left()); |
207 | return false; |
208 | } |
209 | template <typename D, typename R, typename LO, typename RO> |
210 | RT find(const Fortran::evaluate::Operation<D, R, LO, RO> &op) { |
211 | (void)find(op.left()); |
212 | (void)find(op.right()); |
213 | return false; |
214 | } |
215 | RT find(const Fortran::evaluate::Relational<Fortran::evaluate::SomeType> &x) { |
216 | (void)find(x.u); |
217 | return {}; |
218 | } |
219 | template <typename T> |
220 | RT find(const Fortran::evaluate::Expr<T> &x) { |
221 | (void)find(x.u); |
222 | return {}; |
223 | } |
224 | |
225 | llvm::SmallVector<Fortran::lower::ExplicitIterSpace::ArrayBases> bases; |
226 | llvm::SmallVector<Fortran::lower::FrontEndSymbol> controlVars; |
227 | }; |
228 | |
229 | } // namespace |
230 | |
231 | void Fortran::lower::ExplicitIterSpace::leave() { |
232 | ccLoopNest.pop_back(); |
233 | --forallContextOpen; |
234 | conditionalCleanup(); |
235 | } |
236 | |
237 | void Fortran::lower::ExplicitIterSpace::addSymbol( |
238 | Fortran::lower::FrontEndSymbol sym) { |
239 | assert(!symbolStack.empty()); |
240 | symbolStack.back().push_back(sym); |
241 | } |
242 | |
243 | void Fortran::lower::ExplicitIterSpace::exprBase(Fortran::lower::FrontEndExpr x, |
244 | bool lhs) { |
245 | ArrayBaseFinder finder(collectAllSymbols()); |
246 | finder(*x); |
247 | llvm::ArrayRef<Fortran::lower::ExplicitIterSpace::ArrayBases> bases = |
248 | finder.getBases(); |
249 | if (rhsBases.empty()) |
250 | endAssign(); |
251 | if (lhs) { |
252 | if (bases.empty()) { |
253 | lhsBases.push_back(std::nullopt); |
254 | return; |
255 | } |
256 | assert(bases.size() >= 1 && "must detect an array reference on lhs" ); |
257 | if (bases.size() > 1) |
258 | rhsBases.back().append(bases.begin(), bases.end() - 1); |
259 | lhsBases.push_back(bases.back()); |
260 | return; |
261 | } |
262 | rhsBases.back().append(bases.begin(), bases.end()); |
263 | } |
264 | |
265 | void Fortran::lower::ExplicitIterSpace::endAssign() { rhsBases.emplace_back(); } |
266 | |
267 | void Fortran::lower::ExplicitIterSpace::pushLevel() { |
268 | symbolStack.push_back(llvm::SmallVector<Fortran::lower::FrontEndSymbol>{}); |
269 | } |
270 | |
271 | void Fortran::lower::ExplicitIterSpace::popLevel() { symbolStack.pop_back(); } |
272 | |
273 | void Fortran::lower::ExplicitIterSpace::conditionalCleanup() { |
274 | if (forallContextOpen == 0) { |
275 | // Exiting the outermost FORALL context. |
276 | // Cleanup any residual mask buffers. |
277 | outermostContext().finalizeAndReset(); |
278 | // Clear and reset all the cached information. |
279 | symbolStack.clear(); |
280 | lhsBases.clear(); |
281 | rhsBases.clear(); |
282 | loadBindings.clear(); |
283 | ccLoopNest.clear(); |
284 | innerArgs.clear(); |
285 | outerLoop = std::nullopt; |
286 | clearLoops(); |
287 | counter = 0; |
288 | } |
289 | } |
290 | |
291 | std::optional<size_t> |
292 | Fortran::lower::ExplicitIterSpace::findArgPosition(fir::ArrayLoadOp load) { |
293 | if (lhsBases[counter]) { |
294 | auto ld = loadBindings.find(*lhsBases[counter]); |
295 | std::optional<size_t> optPos; |
296 | if (ld != loadBindings.end() && ld->second == load) |
297 | optPos = static_cast<size_t>(0u); |
298 | assert(optPos.has_value() && "load does not correspond to lhs" ); |
299 | return optPos; |
300 | } |
301 | return std::nullopt; |
302 | } |
303 | |
304 | llvm::SmallVector<Fortran::lower::FrontEndSymbol> |
305 | Fortran::lower::ExplicitIterSpace::collectAllSymbols() { |
306 | llvm::SmallVector<Fortran::lower::FrontEndSymbol> result; |
307 | for (llvm::SmallVector<FrontEndSymbol> vec : symbolStack) |
308 | result.append(vec.begin(), vec.end()); |
309 | return result; |
310 | } |
311 | |
312 | llvm::raw_ostream & |
313 | Fortran::lower::operator<<(llvm::raw_ostream &s, |
314 | const Fortran::lower::ImplicitIterSpace &e) { |
315 | for (const llvm::SmallVector< |
316 | Fortran::lower::ImplicitIterSpace::FrontEndMaskExpr> &xs : |
317 | e.getMasks()) { |
318 | s << "{ " ; |
319 | for (const Fortran::lower::ImplicitIterSpace::FrontEndMaskExpr &x : xs) |
320 | x->AsFortran(s << '(') << "), " ; |
321 | s << "}\n" ; |
322 | } |
323 | return s; |
324 | } |
325 | |
326 | llvm::raw_ostream & |
327 | Fortran::lower::operator<<(llvm::raw_ostream &s, |
328 | const Fortran::lower::ExplicitIterSpace &e) { |
329 | auto dump = [&](const auto &u) { |
330 | Fortran::common::visit( |
331 | Fortran::common::visitors{ |
332 | [&](const Fortran::semantics::Symbol *y) { |
333 | s << " " << *y << '\n'; |
334 | }, |
335 | [&](const Fortran::evaluate::ArrayRef *y) { |
336 | s << " " ; |
337 | if (y->base().IsSymbol()) |
338 | s << y->base().GetFirstSymbol(); |
339 | else |
340 | s << y->base().GetComponent().GetLastSymbol(); |
341 | s << '\n'; |
342 | }, |
343 | [&](const Fortran::evaluate::Component *y) { |
344 | s << " " << y->GetLastSymbol() << '\n'; |
345 | }}, |
346 | u); |
347 | }; |
348 | s << "LHS bases:\n" ; |
349 | for (const std::optional<Fortran::lower::ExplicitIterSpace::ArrayBases> &u : |
350 | e.lhsBases) |
351 | if (u) |
352 | dump(*u); |
353 | s << "RHS bases:\n" ; |
354 | for (const llvm::SmallVector<Fortran::lower::ExplicitIterSpace::ArrayBases> |
355 | &bases : e.rhsBases) { |
356 | for (const Fortran::lower::ExplicitIterSpace::ArrayBases &u : bases) |
357 | dump(u); |
358 | s << '\n'; |
359 | } |
360 | return s; |
361 | } |
362 | |
363 | void Fortran::lower::ImplicitIterSpace::dump() const { |
364 | llvm::errs() << *this << '\n'; |
365 | } |
366 | |
367 | void Fortran::lower::ExplicitIterSpace::dump() const { |
368 | llvm::errs() << *this << '\n'; |
369 | } |
370 | |