1 | //===-- Lower/Support/Utils.cpp -- utilities --------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
#include "flang/Lower/Support/Utils.h"

#include "flang/Common/indirection.h"
#include "flang/Lower/AbstractConverter.h"
#include "flang/Lower/ConvertVariable.h"
#include "flang/Lower/IterationSpace.h"
#include "flang/Lower/Support/PrivateReductionUtils.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "flang/Semantics/tools.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include <cstdint>
#include <iterator>
#include <optional>
#include <type_traits>
28 | |
29 | namespace Fortran::lower { |
30 | // Fortran::evaluate::Expr are functional values organized like an AST. A |
31 | // Fortran::evaluate::Expr is meant to be moved and cloned. Using the front end |
32 | // tools can often cause copies and extra wrapper classes to be added to any |
33 | // Fortran::evaluate::Expr. These values should not be assumed or relied upon to |
34 | // have an *object* identity. They are deeply recursive, irregular structures |
35 | // built from a large number of classes which do not use inheritance and |
36 | // necessitate a large volume of boilerplate code as a result. |
37 | // |
// By contrast, LLVM data structures make ubiquitous assumptions about an
// object's identity via pointers to the object. An object's location in memory
// is thus very often an identifying relation.
41 | |
42 | // This class defines a hash computation of a Fortran::evaluate::Expr tree value |
43 | // so it can be used with llvm::DenseMap. The Fortran::evaluate::Expr need not |
44 | // have the same address. |
45 | class HashEvaluateExpr { |
46 | public: |
47 | // A Se::Symbol is the only part of an Fortran::evaluate::Expr with an |
48 | // identity property. |
49 | static unsigned getHashValue(const Fortran::semantics::Symbol &x) { |
50 | return static_cast<unsigned>(reinterpret_cast<std::intptr_t>(&x)); |
51 | } |
52 | template <typename A, bool COPY> |
53 | static unsigned getHashValue(const Fortran::common::Indirection<A, COPY> &x) { |
54 | return getHashValue(x.value()); |
55 | } |
56 | template <typename A> |
57 | static unsigned getHashValue(const std::optional<A> &x) { |
58 | if (x.has_value()) |
59 | return getHashValue(x.value()); |
60 | return 0u; |
61 | } |
62 | static unsigned getHashValue(const Fortran::evaluate::Subscript &x) { |
63 | return Fortran::common::visit( |
64 | [&](const auto &v) { return getHashValue(v); }, x.u); |
65 | } |
66 | static unsigned getHashValue(const Fortran::evaluate::Triplet &x) { |
67 | return getHashValue(x.lower()) - getHashValue(x.upper()) * 5u - |
68 | getHashValue(x.stride()) * 11u; |
69 | } |
70 | static unsigned getHashValue(const Fortran::evaluate::Component &x) { |
71 | return getHashValue(x.base()) * 83u - getHashValue(x.GetLastSymbol()); |
72 | } |
73 | static unsigned getHashValue(const Fortran::evaluate::ArrayRef &x) { |
74 | unsigned subs = 1u; |
75 | for (const Fortran::evaluate::Subscript &v : x.subscript()) |
76 | subs -= getHashValue(v); |
77 | return getHashValue(x.base()) * 89u - subs; |
78 | } |
79 | static unsigned getHashValue(const Fortran::evaluate::CoarrayRef &x) { |
80 | unsigned cosubs = 3u; |
81 | for (const Fortran::evaluate::Expr<Fortran::evaluate::SubscriptInteger> &v : |
82 | x.cosubscript()) |
83 | cosubs -= getHashValue(v); |
84 | return getHashValue(x.base()) * 97u - cosubs + getHashValue(x.stat()) + |
85 | 257u + getHashValue(x.team()); |
86 | } |
87 | static unsigned getHashValue(const Fortran::evaluate::NamedEntity &x) { |
88 | if (x.IsSymbol()) |
89 | return getHashValue(x.GetFirstSymbol()) * 11u; |
90 | return getHashValue(x.GetComponent()) * 13u; |
91 | } |
92 | static unsigned getHashValue(const Fortran::evaluate::DataRef &x) { |
93 | return Fortran::common::visit( |
94 | [&](const auto &v) { return getHashValue(v); }, x.u); |
95 | } |
96 | static unsigned getHashValue(const Fortran::evaluate::ComplexPart &x) { |
97 | return getHashValue(x.complex()) - static_cast<unsigned>(x.part()); |
98 | } |
99 | template <Fortran::common::TypeCategory TC1, int KIND, |
100 | Fortran::common::TypeCategory TC2> |
101 | static unsigned getHashValue( |
102 | const Fortran::evaluate::Convert<Fortran::evaluate::Type<TC1, KIND>, TC2> |
103 | &x) { |
104 | return getHashValue(x.left()) - (static_cast<unsigned>(TC1) + 2u) - |
105 | (static_cast<unsigned>(KIND) + 5u); |
106 | } |
107 | template <int KIND> |
108 | static unsigned |
109 | getHashValue(const Fortran::evaluate::ComplexComponent<KIND> &x) { |
110 | return getHashValue(x.left()) - |
111 | (static_cast<unsigned>(x.isImaginaryPart) + 1u) * 3u; |
112 | } |
113 | template <typename T> |
114 | static unsigned getHashValue(const Fortran::evaluate::Parentheses<T> &x) { |
115 | return getHashValue(x.left()) * 17u; |
116 | } |
117 | template <Fortran::common::TypeCategory TC, int KIND> |
118 | static unsigned getHashValue( |
119 | const Fortran::evaluate::Negate<Fortran::evaluate::Type<TC, KIND>> &x) { |
120 | return getHashValue(x.left()) - (static_cast<unsigned>(TC) + 5u) - |
121 | (static_cast<unsigned>(KIND) + 7u); |
122 | } |
123 | template <Fortran::common::TypeCategory TC, int KIND> |
124 | static unsigned getHashValue( |
125 | const Fortran::evaluate::Add<Fortran::evaluate::Type<TC, KIND>> &x) { |
126 | return (getHashValue(x.left()) + getHashValue(x.right())) * 23u + |
127 | static_cast<unsigned>(TC) + static_cast<unsigned>(KIND); |
128 | } |
129 | template <Fortran::common::TypeCategory TC, int KIND> |
130 | static unsigned getHashValue( |
131 | const Fortran::evaluate::Subtract<Fortran::evaluate::Type<TC, KIND>> &x) { |
132 | return (getHashValue(x.left()) - getHashValue(x.right())) * 19u + |
133 | static_cast<unsigned>(TC) + static_cast<unsigned>(KIND); |
134 | } |
135 | template <Fortran::common::TypeCategory TC, int KIND> |
136 | static unsigned getHashValue( |
137 | const Fortran::evaluate::Multiply<Fortran::evaluate::Type<TC, KIND>> &x) { |
138 | return (getHashValue(x.left()) + getHashValue(x.right())) * 29u + |
139 | static_cast<unsigned>(TC) + static_cast<unsigned>(KIND); |
140 | } |
141 | template <Fortran::common::TypeCategory TC, int KIND> |
142 | static unsigned getHashValue( |
143 | const Fortran::evaluate::Divide<Fortran::evaluate::Type<TC, KIND>> &x) { |
144 | return (getHashValue(x.left()) - getHashValue(x.right())) * 31u + |
145 | static_cast<unsigned>(TC) + static_cast<unsigned>(KIND); |
146 | } |
147 | template <Fortran::common::TypeCategory TC, int KIND> |
148 | static unsigned getHashValue( |
149 | const Fortran::evaluate::Power<Fortran::evaluate::Type<TC, KIND>> &x) { |
150 | return (getHashValue(x.left()) - getHashValue(x.right())) * 37u + |
151 | static_cast<unsigned>(TC) + static_cast<unsigned>(KIND); |
152 | } |
153 | template <Fortran::common::TypeCategory TC, int KIND> |
154 | static unsigned getHashValue( |
155 | const Fortran::evaluate::Extremum<Fortran::evaluate::Type<TC, KIND>> &x) { |
156 | return (getHashValue(x.left()) + getHashValue(x.right())) * 41u + |
157 | static_cast<unsigned>(TC) + static_cast<unsigned>(KIND) + |
158 | static_cast<unsigned>(x.ordering) * 7u; |
159 | } |
160 | template <Fortran::common::TypeCategory TC, int KIND> |
161 | static unsigned getHashValue( |
162 | const Fortran::evaluate::RealToIntPower<Fortran::evaluate::Type<TC, KIND>> |
163 | &x) { |
164 | return (getHashValue(x.left()) - getHashValue(x.right())) * 43u + |
165 | static_cast<unsigned>(TC) + static_cast<unsigned>(KIND); |
166 | } |
167 | template <int KIND> |
168 | static unsigned |
169 | getHashValue(const Fortran::evaluate::ComplexConstructor<KIND> &x) { |
170 | return (getHashValue(x.left()) - getHashValue(x.right())) * 47u + |
171 | static_cast<unsigned>(KIND); |
172 | } |
173 | template <int KIND> |
174 | static unsigned getHashValue(const Fortran::evaluate::Concat<KIND> &x) { |
175 | return (getHashValue(x.left()) - getHashValue(x.right())) * 53u + |
176 | static_cast<unsigned>(KIND); |
177 | } |
178 | template <int KIND> |
179 | static unsigned getHashValue(const Fortran::evaluate::SetLength<KIND> &x) { |
180 | return (getHashValue(x.left()) - getHashValue(x.right())) * 59u + |
181 | static_cast<unsigned>(KIND); |
182 | } |
183 | static unsigned getHashValue(const Fortran::semantics::SymbolRef &sym) { |
184 | return getHashValue(sym.get()); |
185 | } |
186 | static unsigned getHashValue(const Fortran::evaluate::Substring &x) { |
187 | return 61u * |
188 | Fortran::common::visit( |
189 | [&](const auto &p) { return getHashValue(p); }, x.parent()) - |
190 | getHashValue(x.lower()) - (getHashValue(x.lower()) + 1u); |
191 | } |
192 | static unsigned |
193 | getHashValue(const Fortran::evaluate::StaticDataObject::Pointer &x) { |
194 | return llvm::hash_value(x->name()); |
195 | } |
196 | static unsigned getHashValue(const Fortran::evaluate::SpecificIntrinsic &x) { |
197 | return llvm::hash_value(x.name); |
198 | } |
199 | template <typename A> |
200 | static unsigned getHashValue(const Fortran::evaluate::Constant<A> &x) { |
201 | // FIXME: Should hash the content. |
202 | return 103u; |
203 | } |
204 | static unsigned getHashValue(const Fortran::evaluate::ActualArgument &x) { |
205 | if (const Fortran::evaluate::Symbol *sym = x.GetAssumedTypeDummy()) |
206 | return getHashValue(*sym); |
207 | return getHashValue(*x.UnwrapExpr()); |
208 | } |
209 | static unsigned |
210 | getHashValue(const Fortran::evaluate::ProcedureDesignator &x) { |
211 | return Fortran::common::visit( |
212 | [&](const auto &v) { return getHashValue(v); }, x.u); |
213 | } |
214 | static unsigned getHashValue(const Fortran::evaluate::ProcedureRef &x) { |
215 | unsigned args = 13u; |
216 | for (const std::optional<Fortran::evaluate::ActualArgument> &v : |
217 | x.arguments()) |
218 | args -= getHashValue(v); |
219 | return getHashValue(x.proc()) * 101u - args; |
220 | } |
221 | template <typename A> |
222 | static unsigned |
223 | getHashValue(const Fortran::evaluate::ArrayConstructor<A> &x) { |
224 | // FIXME: hash the contents. |
225 | return 127u; |
226 | } |
227 | static unsigned getHashValue(const Fortran::evaluate::ImpliedDoIndex &x) { |
228 | return llvm::hash_value(toStringRef(x.name).str()) * 131u; |
229 | } |
230 | static unsigned getHashValue(const Fortran::evaluate::TypeParamInquiry &x) { |
231 | return getHashValue(x.base()) * 137u - getHashValue(x.parameter()) * 3u; |
232 | } |
233 | static unsigned getHashValue(const Fortran::evaluate::DescriptorInquiry &x) { |
234 | return getHashValue(x.base()) * 139u - |
235 | static_cast<unsigned>(x.field()) * 13u + |
236 | static_cast<unsigned>(x.dimension()); |
237 | } |
238 | static unsigned |
239 | getHashValue(const Fortran::evaluate::StructureConstructor &x) { |
240 | // FIXME: hash the contents. |
241 | return 149u; |
242 | } |
243 | template <int KIND> |
244 | static unsigned getHashValue(const Fortran::evaluate::Not<KIND> &x) { |
245 | return getHashValue(x.left()) * 61u + static_cast<unsigned>(KIND); |
246 | } |
247 | template <int KIND> |
248 | static unsigned |
249 | getHashValue(const Fortran::evaluate::LogicalOperation<KIND> &x) { |
250 | unsigned result = getHashValue(x.left()) + getHashValue(x.right()); |
251 | return result * 67u + static_cast<unsigned>(x.logicalOperator) * 5u; |
252 | } |
253 | template <Fortran::common::TypeCategory TC, int KIND> |
254 | static unsigned getHashValue( |
255 | const Fortran::evaluate::Relational<Fortran::evaluate::Type<TC, KIND>> |
256 | &x) { |
257 | return (getHashValue(x.left()) + getHashValue(x.right())) * 71u + |
258 | static_cast<unsigned>(TC) + static_cast<unsigned>(KIND) + |
259 | static_cast<unsigned>(x.opr) * 11u; |
260 | } |
261 | template <typename A> |
262 | static unsigned getHashValue(const Fortran::evaluate::Expr<A> &x) { |
263 | return Fortran::common::visit( |
264 | [&](const auto &v) { return getHashValue(v); }, x.u); |
265 | } |
266 | static unsigned getHashValue( |
267 | const Fortran::evaluate::Relational<Fortran::evaluate::SomeType> &x) { |
268 | return Fortran::common::visit( |
269 | [&](const auto &v) { return getHashValue(v); }, x.u); |
270 | } |
271 | template <typename A> |
272 | static unsigned getHashValue(const Fortran::evaluate::Designator<A> &x) { |
273 | return Fortran::common::visit( |
274 | [&](const auto &v) { return getHashValue(v); }, x.u); |
275 | } |
276 | template <int BITS> |
277 | static unsigned |
278 | getHashValue(const Fortran::evaluate::value::Integer<BITS> &x) { |
279 | return static_cast<unsigned>(x.ToSInt()); |
280 | } |
281 | static unsigned getHashValue(const Fortran::evaluate::NullPointer &x) { |
282 | return ~179u; |
283 | } |
284 | }; |
285 | |
286 | // Define the is equals test for using Fortran::evaluate::Expr values with |
287 | // llvm::DenseMap. |
288 | class IsEqualEvaluateExpr { |
289 | public: |
290 | // A Se::Symbol is the only part of an Fortran::evaluate::Expr with an |
291 | // identity property. |
292 | static bool isEqual(const Fortran::semantics::Symbol &x, |
293 | const Fortran::semantics::Symbol &y) { |
294 | return isEqual(&x, &y); |
295 | } |
296 | static bool isEqual(const Fortran::semantics::Symbol *x, |
297 | const Fortran::semantics::Symbol *y) { |
298 | return x == y; |
299 | } |
300 | template <typename A, bool COPY> |
301 | static bool isEqual(const Fortran::common::Indirection<A, COPY> &x, |
302 | const Fortran::common::Indirection<A, COPY> &y) { |
303 | return isEqual(x.value(), y.value()); |
304 | } |
305 | template <typename A> |
306 | static bool isEqual(const std::optional<A> &x, const std::optional<A> &y) { |
307 | if (x.has_value() && y.has_value()) |
308 | return isEqual(x.value(), y.value()); |
309 | return !x.has_value() && !y.has_value(); |
310 | } |
311 | template <typename A> |
312 | static bool isEqual(const std::vector<A> &x, const std::vector<A> &y) { |
313 | if (x.size() != y.size()) |
314 | return false; |
315 | const std::size_t size = x.size(); |
316 | for (std::remove_const_t<decltype(size)> i = 0; i < size; ++i) |
317 | if (!isEqual(x[i], y[i])) |
318 | return false; |
319 | return true; |
320 | } |
321 | static bool isEqual(const Fortran::evaluate::Subscript &x, |
322 | const Fortran::evaluate::Subscript &y) { |
323 | return Fortran::common::visit( |
324 | [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u); |
325 | } |
326 | static bool isEqual(const Fortran::evaluate::Triplet &x, |
327 | const Fortran::evaluate::Triplet &y) { |
328 | return isEqual(x.lower(), y.lower()) && isEqual(x.upper(), y.upper()) && |
329 | isEqual(x.stride(), y.stride()); |
330 | } |
331 | static bool isEqual(const Fortran::evaluate::Component &x, |
332 | const Fortran::evaluate::Component &y) { |
333 | return isEqual(x.base(), y.base()) && |
334 | isEqual(x.GetLastSymbol(), y.GetLastSymbol()); |
335 | } |
336 | static bool isEqual(const Fortran::evaluate::ArrayRef &x, |
337 | const Fortran::evaluate::ArrayRef &y) { |
338 | return isEqual(x.base(), y.base()) && isEqual(x.subscript(), y.subscript()); |
339 | } |
340 | static bool isEqual(const Fortran::evaluate::CoarrayRef &x, |
341 | const Fortran::evaluate::CoarrayRef &y) { |
342 | return isEqual(x.base(), y.base()) && |
343 | isEqual(x.cosubscript(), y.cosubscript()) && |
344 | isEqual(x.stat(), y.stat()) && isEqual(x.team(), y.team()); |
345 | } |
346 | static bool isEqual(const Fortran::evaluate::NamedEntity &x, |
347 | const Fortran::evaluate::NamedEntity &y) { |
348 | if (x.IsSymbol() && y.IsSymbol()) |
349 | return isEqual(x.GetFirstSymbol(), y.GetFirstSymbol()); |
350 | return !x.IsSymbol() && !y.IsSymbol() && |
351 | isEqual(x.GetComponent(), y.GetComponent()); |
352 | } |
353 | static bool isEqual(const Fortran::evaluate::DataRef &x, |
354 | const Fortran::evaluate::DataRef &y) { |
355 | return Fortran::common::visit( |
356 | [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u); |
357 | } |
358 | static bool isEqual(const Fortran::evaluate::ComplexPart &x, |
359 | const Fortran::evaluate::ComplexPart &y) { |
360 | return isEqual(x.complex(), y.complex()) && x.part() == y.part(); |
361 | } |
362 | template <typename A, Fortran::common::TypeCategory TC2> |
363 | static bool isEqual(const Fortran::evaluate::Convert<A, TC2> &x, |
364 | const Fortran::evaluate::Convert<A, TC2> &y) { |
365 | return isEqual(x.left(), y.left()); |
366 | } |
367 | template <int KIND> |
368 | static bool isEqual(const Fortran::evaluate::ComplexComponent<KIND> &x, |
369 | const Fortran::evaluate::ComplexComponent<KIND> &y) { |
370 | return isEqual(x.left(), y.left()) && |
371 | x.isImaginaryPart == y.isImaginaryPart; |
372 | } |
373 | template <typename T> |
374 | static bool isEqual(const Fortran::evaluate::Parentheses<T> &x, |
375 | const Fortran::evaluate::Parentheses<T> &y) { |
376 | return isEqual(x.left(), y.left()); |
377 | } |
378 | template <typename A> |
379 | static bool isEqual(const Fortran::evaluate::Negate<A> &x, |
380 | const Fortran::evaluate::Negate<A> &y) { |
381 | return isEqual(x.left(), y.left()); |
382 | } |
383 | template <typename A> |
384 | static bool isBinaryEqual(const A &x, const A &y) { |
385 | return isEqual(x.left(), y.left()) && isEqual(x.right(), y.right()); |
386 | } |
387 | template <typename A> |
388 | static bool isEqual(const Fortran::evaluate::Add<A> &x, |
389 | const Fortran::evaluate::Add<A> &y) { |
390 | return isBinaryEqual(x, y); |
391 | } |
392 | template <typename A> |
393 | static bool isEqual(const Fortran::evaluate::Subtract<A> &x, |
394 | const Fortran::evaluate::Subtract<A> &y) { |
395 | return isBinaryEqual(x, y); |
396 | } |
397 | template <typename A> |
398 | static bool isEqual(const Fortran::evaluate::Multiply<A> &x, |
399 | const Fortran::evaluate::Multiply<A> &y) { |
400 | return isBinaryEqual(x, y); |
401 | } |
402 | template <typename A> |
403 | static bool isEqual(const Fortran::evaluate::Divide<A> &x, |
404 | const Fortran::evaluate::Divide<A> &y) { |
405 | return isBinaryEqual(x, y); |
406 | } |
407 | template <typename A> |
408 | static bool isEqual(const Fortran::evaluate::Power<A> &x, |
409 | const Fortran::evaluate::Power<A> &y) { |
410 | return isBinaryEqual(x, y); |
411 | } |
412 | template <typename A> |
413 | static bool isEqual(const Fortran::evaluate::Extremum<A> &x, |
414 | const Fortran::evaluate::Extremum<A> &y) { |
415 | return isBinaryEqual(x, y); |
416 | } |
417 | template <typename A> |
418 | static bool isEqual(const Fortran::evaluate::RealToIntPower<A> &x, |
419 | const Fortran::evaluate::RealToIntPower<A> &y) { |
420 | return isBinaryEqual(x, y); |
421 | } |
422 | template <int KIND> |
423 | static bool isEqual(const Fortran::evaluate::ComplexConstructor<KIND> &x, |
424 | const Fortran::evaluate::ComplexConstructor<KIND> &y) { |
425 | return isBinaryEqual(x, y); |
426 | } |
427 | template <int KIND> |
428 | static bool isEqual(const Fortran::evaluate::Concat<KIND> &x, |
429 | const Fortran::evaluate::Concat<KIND> &y) { |
430 | return isBinaryEqual(x, y); |
431 | } |
432 | template <int KIND> |
433 | static bool isEqual(const Fortran::evaluate::SetLength<KIND> &x, |
434 | const Fortran::evaluate::SetLength<KIND> &y) { |
435 | return isBinaryEqual(x, y); |
436 | } |
437 | static bool isEqual(const Fortran::semantics::SymbolRef &x, |
438 | const Fortran::semantics::SymbolRef &y) { |
439 | return isEqual(x.get(), y.get()); |
440 | } |
441 | static bool isEqual(const Fortran::evaluate::Substring &x, |
442 | const Fortran::evaluate::Substring &y) { |
443 | return Fortran::common::visit( |
444 | [&](const auto &p, const auto &q) { return isEqual(p, q); }, |
445 | x.parent(), y.parent()) && |
446 | isEqual(x.lower(), y.lower()) && isEqual(x.upper(), y.upper()); |
447 | } |
448 | static bool isEqual(const Fortran::evaluate::StaticDataObject::Pointer &x, |
449 | const Fortran::evaluate::StaticDataObject::Pointer &y) { |
450 | return x->name() == y->name(); |
451 | } |
452 | static bool isEqual(const Fortran::evaluate::SpecificIntrinsic &x, |
453 | const Fortran::evaluate::SpecificIntrinsic &y) { |
454 | return x.name == y.name; |
455 | } |
456 | template <typename A> |
457 | static bool isEqual(const Fortran::evaluate::Constant<A> &x, |
458 | const Fortran::evaluate::Constant<A> &y) { |
459 | return x == y; |
460 | } |
461 | static bool isEqual(const Fortran::evaluate::ActualArgument &x, |
462 | const Fortran::evaluate::ActualArgument &y) { |
463 | if (const Fortran::evaluate::Symbol *xs = x.GetAssumedTypeDummy()) { |
464 | if (const Fortran::evaluate::Symbol *ys = y.GetAssumedTypeDummy()) |
465 | return isEqual(*xs, *ys); |
466 | return false; |
467 | } |
468 | return !y.GetAssumedTypeDummy() && |
469 | isEqual(*x.UnwrapExpr(), *y.UnwrapExpr()); |
470 | } |
471 | static bool isEqual(const Fortran::evaluate::ProcedureDesignator &x, |
472 | const Fortran::evaluate::ProcedureDesignator &y) { |
473 | return Fortran::common::visit( |
474 | [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u); |
475 | } |
476 | static bool isEqual(const Fortran::evaluate::ProcedureRef &x, |
477 | const Fortran::evaluate::ProcedureRef &y) { |
478 | return isEqual(x.proc(), y.proc()) && isEqual(x.arguments(), y.arguments()); |
479 | } |
480 | template <typename A> |
481 | static bool isEqual(const Fortran::evaluate::ImpliedDo<A> &x, |
482 | const Fortran::evaluate::ImpliedDo<A> &y) { |
483 | return isEqual(x.values(), y.values()) && isEqual(x.lower(), y.lower()) && |
484 | isEqual(x.upper(), y.upper()) && isEqual(x.stride(), y.stride()); |
485 | } |
486 | template <typename A> |
487 | static bool isEqual(const Fortran::evaluate::ArrayConstructorValues<A> &x, |
488 | const Fortran::evaluate::ArrayConstructorValues<A> &y) { |
489 | using Expr = Fortran::evaluate::Expr<A>; |
490 | using ImpliedDo = Fortran::evaluate::ImpliedDo<A>; |
491 | for (const auto &[xValue, yValue] : llvm::zip(x, y)) { |
492 | bool checkElement = Fortran::common::visit( |
493 | common::visitors{ |
494 | [&](const Expr &v, const Expr &w) { return isEqual(v, w); }, |
495 | [&](const ImpliedDo &v, const ImpliedDo &w) { |
496 | return isEqual(v, w); |
497 | }, |
498 | [&](const Expr &, const ImpliedDo &) { return false; }, |
499 | [&](const ImpliedDo &, const Expr &) { return false; }, |
500 | }, |
501 | xValue.u, yValue.u); |
502 | if (!checkElement) { |
503 | return false; |
504 | } |
505 | } |
506 | return true; |
507 | } |
508 | static bool isEqual(const Fortran::evaluate::SubscriptInteger &x, |
509 | const Fortran::evaluate::SubscriptInteger &y) { |
510 | return x == y; |
511 | } |
512 | template <typename A> |
513 | static bool isEqual(const Fortran::evaluate::ArrayConstructor<A> &x, |
514 | const Fortran::evaluate::ArrayConstructor<A> &y) { |
515 | bool checkCharacterType = true; |
516 | if constexpr (A::category == Fortran::common::TypeCategory::Character) { |
517 | checkCharacterType = isEqual(*x.LEN(), *y.LEN()); |
518 | } |
519 | using Base = Fortran::evaluate::ArrayConstructorValues<A>; |
520 | return isEqual((Base)x, (Base)y) && |
521 | (x.GetType() == y.GetType() && checkCharacterType); |
522 | } |
523 | static bool isEqual(const Fortran::evaluate::ImpliedDoIndex &x, |
524 | const Fortran::evaluate::ImpliedDoIndex &y) { |
525 | return toStringRef(x.name) == toStringRef(y.name); |
526 | } |
527 | static bool isEqual(const Fortran::evaluate::TypeParamInquiry &x, |
528 | const Fortran::evaluate::TypeParamInquiry &y) { |
529 | return isEqual(x.base(), y.base()) && isEqual(x.parameter(), y.parameter()); |
530 | } |
531 | static bool isEqual(const Fortran::evaluate::DescriptorInquiry &x, |
532 | const Fortran::evaluate::DescriptorInquiry &y) { |
533 | return isEqual(x.base(), y.base()) && x.field() == y.field() && |
534 | x.dimension() == y.dimension(); |
535 | } |
536 | static bool isEqual(const Fortran::evaluate::StructureConstructor &x, |
537 | const Fortran::evaluate::StructureConstructor &y) { |
538 | const auto &xValues = x.values(); |
539 | const auto &yValues = y.values(); |
540 | if (xValues.size() != yValues.size()) |
541 | return false; |
542 | if (x.derivedTypeSpec() != y.derivedTypeSpec()) |
543 | return false; |
544 | for (const auto &[xSymbol, xValue] : xValues) { |
545 | auto yIt = yValues.find(xSymbol); |
546 | // This should probably never happen, since the derived type |
547 | // should be the same. |
548 | if (yIt == yValues.end()) |
549 | return false; |
550 | if (!isEqual(xValue, yIt->second)) |
551 | return false; |
552 | } |
553 | return true; |
554 | } |
555 | template <int KIND> |
556 | static bool isEqual(const Fortran::evaluate::Not<KIND> &x, |
557 | const Fortran::evaluate::Not<KIND> &y) { |
558 | return isEqual(x.left(), y.left()); |
559 | } |
560 | template <int KIND> |
561 | static bool isEqual(const Fortran::evaluate::LogicalOperation<KIND> &x, |
562 | const Fortran::evaluate::LogicalOperation<KIND> &y) { |
563 | return isEqual(x.left(), y.left()) && isEqual(x.right(), y.right()); |
564 | } |
565 | template <typename A> |
566 | static bool isEqual(const Fortran::evaluate::Relational<A> &x, |
567 | const Fortran::evaluate::Relational<A> &y) { |
568 | return isEqual(x.left(), y.left()) && isEqual(x.right(), y.right()); |
569 | } |
570 | template <typename A> |
571 | static bool isEqual(const Fortran::evaluate::Expr<A> &x, |
572 | const Fortran::evaluate::Expr<A> &y) { |
573 | return Fortran::common::visit( |
574 | [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u); |
575 | } |
576 | static bool |
577 | isEqual(const Fortran::evaluate::Relational<Fortran::evaluate::SomeType> &x, |
578 | const Fortran::evaluate::Relational<Fortran::evaluate::SomeType> &y) { |
579 | return Fortran::common::visit( |
580 | [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u); |
581 | } |
582 | template <typename A> |
583 | static bool isEqual(const Fortran::evaluate::Designator<A> &x, |
584 | const Fortran::evaluate::Designator<A> &y) { |
585 | return Fortran::common::visit( |
586 | [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u); |
587 | } |
588 | template <int BITS> |
589 | static bool isEqual(const Fortran::evaluate::value::Integer<BITS> &x, |
590 | const Fortran::evaluate::value::Integer<BITS> &y) { |
591 | return x == y; |
592 | } |
593 | static bool isEqual(const Fortran::evaluate::NullPointer &x, |
594 | const Fortran::evaluate::NullPointer &y) { |
595 | return true; |
596 | } |
597 | template <typename A, typename B, |
598 | std::enable_if_t<!std::is_same_v<A, B>, bool> = true> |
599 | static bool isEqual(const A &, const B &) { |
600 | return false; |
601 | } |
602 | }; |
603 | |
604 | unsigned getHashValue(const Fortran::lower::SomeExpr *x) { |
605 | return HashEvaluateExpr::getHashValue(*x); |
606 | } |
607 | |
608 | unsigned getHashValue(const Fortran::lower::ExplicitIterSpace::ArrayBases &x) { |
609 | return Fortran::common::visit( |
610 | [&](const auto *p) { return HashEvaluateExpr::getHashValue(*p); }, x); |
611 | } |
612 | |
613 | bool isEqual(const Fortran::lower::SomeExpr *x, |
614 | const Fortran::lower::SomeExpr *y) { |
615 | const auto *empty = |
616 | llvm::DenseMapInfo<const Fortran::lower::SomeExpr *>::getEmptyKey(); |
617 | const auto *tombstone = |
618 | llvm::DenseMapInfo<const Fortran::lower::SomeExpr *>::getTombstoneKey(); |
619 | if (x == empty || y == empty || x == tombstone || y == tombstone) |
620 | return x == y; |
621 | return x == y || IsEqualEvaluateExpr::isEqual(*x, *y); |
622 | } |
623 | |
624 | bool isEqual(const Fortran::lower::ExplicitIterSpace::ArrayBases &x, |
625 | const Fortran::lower::ExplicitIterSpace::ArrayBases &y) { |
626 | return Fortran::common::visit( |
627 | Fortran::common::visitors{ |
628 | // Fortran::semantics::Symbol * are the exception here. These pointers |
629 | // have identity; if two Symbol * values are the same (different) then |
630 | // they are the same (different) logical symbol. |
631 | [&](Fortran::lower::FrontEndSymbol p, |
632 | Fortran::lower::FrontEndSymbol q) { return p == q; }, |
633 | [&](const auto *p, const auto *q) { |
634 | if constexpr (std::is_same_v<decltype(p), decltype(q)>) { |
635 | return IsEqualEvaluateExpr::isEqual(*p, *q); |
636 | } else { |
637 | // Different subtree types are never equal. |
638 | return false; |
639 | } |
640 | }}, |
641 | x, y); |
642 | } |
643 | |
644 | void copyFirstPrivateSymbol(lower::AbstractConverter &converter, |
645 | const semantics::Symbol *sym, |
646 | mlir::OpBuilder::InsertPoint *copyAssignIP) { |
647 | if (sym->test(semantics::Symbol::Flag::OmpFirstPrivate) || |
648 | sym->test(semantics::Symbol::Flag::LocalityLocalInit)) |
649 | converter.copyHostAssociateVar(*sym, copyAssignIP); |
650 | } |
651 | |
652 | template <typename OpType, typename OperandsStructType> |
653 | void privatizeSymbol( |
654 | lower::AbstractConverter &converter, fir::FirOpBuilder &firOpBuilder, |
655 | lower::SymMap &symTable, |
656 | llvm::SetVector<const semantics::Symbol *> &allPrivatizedSymbols, |
657 | llvm::SmallSet<const semantics::Symbol *, 16> &mightHaveReadHostSym, |
658 | const semantics::Symbol *symToPrivatize, OperandsStructType *clauseOps) { |
659 | constexpr bool isDoConcurrent = |
660 | std::is_same_v<OpType, fir::LocalitySpecifierOp>; |
661 | mlir::OpBuilder::InsertPoint dcIP; |
662 | |
663 | if (isDoConcurrent) { |
664 | dcIP = firOpBuilder.saveInsertionPoint(); |
665 | firOpBuilder.setInsertionPoint( |
666 | firOpBuilder.getRegion().getParentOfType<fir::DoConcurrentOp>()); |
667 | } |
668 | |
669 | const semantics::Symbol *sym = |
670 | isDoConcurrent ? &symToPrivatize->GetUltimate() : symToPrivatize; |
671 | const lower::SymbolBox hsb = isDoConcurrent |
672 | ? converter.shallowLookupSymbol(*sym) |
673 | : converter.lookupOneLevelUpSymbol(*sym); |
674 | assert(hsb && "Host symbol box not found" ); |
675 | |
676 | mlir::Location symLoc = hsb.getAddr().getLoc(); |
677 | std::string privatizerName = sym->name().ToString() + ".privatizer" ; |
678 | bool emitCopyRegion = |
679 | symToPrivatize->test(semantics::Symbol::Flag::OmpFirstPrivate) || |
680 | symToPrivatize->test(semantics::Symbol::Flag::LocalityLocalInit); |
681 | |
682 | mlir::Value privVal = hsb.getAddr(); |
683 | mlir::Type allocType = privVal.getType(); |
684 | if (!mlir::isa<fir::PointerType>(privVal.getType())) |
685 | allocType = fir::unwrapRefType(privVal.getType()); |
686 | |
687 | if (auto poly = mlir::dyn_cast<fir::ClassType>(allocType)) { |
688 | if (!mlir::isa<fir::PointerType>(poly.getEleTy()) && emitCopyRegion) |
689 | TODO(symLoc, "create polymorphic host associated copy" ); |
690 | } |
691 | |
692 | // fir.array<> cannot be converted to any single llvm type and fir helpers |
693 | // are not available in openmp to llvmir translation so we cannot generate |
694 | // an alloca for a fir.array type there. Get around this by boxing all |
695 | // arrays. |
696 | if (mlir::isa<fir::SequenceType>(allocType)) { |
697 | hlfir::Entity entity{hsb.getAddr()}; |
698 | entity = genVariableBox(symLoc, firOpBuilder, entity); |
699 | privVal = entity.getBase(); |
700 | allocType = privVal.getType(); |
701 | } |
702 | |
703 | if (mlir::isa<fir::BaseBoxType>(privVal.getType())) { |
704 | // Boxes should be passed by reference into nested regions: |
705 | auto oldIP = firOpBuilder.saveInsertionPoint(); |
706 | firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock()); |
707 | auto alloca = firOpBuilder.create<fir::AllocaOp>(symLoc, privVal.getType()); |
708 | firOpBuilder.restoreInsertionPoint(oldIP); |
709 | firOpBuilder.create<fir::StoreOp>(symLoc, privVal, alloca); |
710 | privVal = alloca; |
711 | } |
712 | |
713 | mlir::Type argType = privVal.getType(); |
714 | |
715 | OpType privatizerOp = [&]() { |
716 | auto moduleOp = firOpBuilder.getModule(); |
717 | auto uniquePrivatizerName = fir::getTypeAsString( |
718 | allocType, converter.getKindMap(), |
719 | converter.mangleName(*sym) + |
720 | (emitCopyRegion ? "_firstprivate" : "_private" )); |
721 | |
722 | if (auto existingPrivatizer = |
723 | moduleOp.lookupSymbol<OpType>(uniquePrivatizerName)) |
724 | return existingPrivatizer; |
725 | |
726 | mlir::OpBuilder::InsertionGuard guard(firOpBuilder); |
727 | firOpBuilder.setInsertionPointToStart(moduleOp.getBody()); |
728 | OpType result; |
729 | |
730 | if constexpr (std::is_same_v<OpType, mlir::omp::PrivateClauseOp>) { |
731 | result = firOpBuilder.create<OpType>( |
732 | symLoc, uniquePrivatizerName, allocType, |
733 | emitCopyRegion ? mlir::omp::DataSharingClauseType::FirstPrivate |
734 | : mlir::omp::DataSharingClauseType::Private); |
735 | } else { |
736 | result = firOpBuilder.create<OpType>( |
737 | symLoc, uniquePrivatizerName, allocType, |
738 | emitCopyRegion ? fir::LocalitySpecifierType::LocalInit |
739 | : fir::LocalitySpecifierType::Local); |
740 | } |
741 | |
742 | fir::ExtendedValue symExV = converter.getSymbolExtendedValue(*sym); |
743 | lower::SymMapScope outerScope(symTable); |
744 | |
745 | // Populate the `init` region. |
746 | // We need to initialize in the following cases: |
747 | // 1. The allocation was for a derived type which requires initialization |
748 | // (this can be skipped if it will be initialized anyway by the copy |
749 | // region, unless the derived type has allocatable components) |
750 | // 2. The allocation was for any kind of box |
751 | // 3. The allocation was for a boxed character |
752 | const bool needsInitialization = |
753 | (Fortran::lower::hasDefaultInitialization(sym->GetUltimate()) && |
754 | (!emitCopyRegion || hlfir::mayHaveAllocatableComponent(allocType))) || |
755 | mlir::isa<fir::BaseBoxType>(allocType) || |
756 | mlir::isa<fir::BoxCharType>(allocType); |
757 | if (needsInitialization) { |
758 | lower::SymbolBox hsb = converter.lookupOneLevelUpSymbol( |
759 | isDoConcurrent ? symToPrivatize->GetUltimate() : *symToPrivatize); |
760 | |
761 | assert(hsb && "Host symbol box not found" ); |
762 | hlfir::Entity entity{hsb.getAddr()}; |
763 | bool cannotHaveNonDefaultLowerBounds = |
764 | !entity.mayHaveNonDefaultLowerBounds(); |
765 | |
766 | mlir::Region &initRegion = result.getInitRegion(); |
767 | mlir::Location symLoc = hsb.getAddr().getLoc(); |
768 | mlir::Block *initBlock = firOpBuilder.createBlock( |
769 | &initRegion, /*insertPt=*/{}, {argType, argType}, {symLoc, symLoc}); |
770 | |
771 | bool emitCopyRegion = |
772 | symToPrivatize->test(semantics::Symbol::Flag::OmpFirstPrivate) || |
773 | symToPrivatize->test( |
774 | Fortran::semantics::Symbol::Flag::LocalityLocalInit); |
775 | |
776 | populateByRefInitAndCleanupRegions( |
777 | converter, symLoc, argType, /*scalarInitValue=*/nullptr, initBlock, |
778 | result.getInitPrivateArg(), result.getInitMoldArg(), |
779 | result.getDeallocRegion(), |
780 | emitCopyRegion ? DeclOperationKind::FirstPrivateOrLocalInit |
781 | : DeclOperationKind::PrivateOrLocal, |
782 | symToPrivatize, cannotHaveNonDefaultLowerBounds, isDoConcurrent); |
783 | // TODO: currently there are false positives from dead uses of the mold |
784 | // arg |
785 | if (result.initReadsFromMold()) |
786 | mightHaveReadHostSym.insert(symToPrivatize); |
787 | } |
788 | |
789 | // Populate the `copy` region if this is a `firstprivate`. |
790 | if (emitCopyRegion) { |
791 | mlir::Region ©Region = result.getCopyRegion(); |
792 | // First block argument corresponding to the original/host value while |
793 | // second block argument corresponding to the privatized value. |
794 | mlir::Block *copyEntryBlock = firOpBuilder.createBlock( |
795 | ©Region, /*insertPt=*/{}, {argType, argType}, {symLoc, symLoc}); |
796 | firOpBuilder.setInsertionPointToEnd(copyEntryBlock); |
797 | |
798 | auto addSymbol = [&](unsigned argIdx, const semantics::Symbol *symToMap, |
799 | bool force = false) { |
800 | symExV.match( |
801 | [&](const fir::MutableBoxValue &box) { |
802 | symTable.addSymbol( |
803 | *symToMap, |
804 | fir::substBase(box, copyRegion.getArgument(argIdx)), force); |
805 | }, |
806 | [&](const auto &box) { |
807 | symTable.addSymbol(*symToMap, copyRegion.getArgument(argIdx), |
808 | force); |
809 | }); |
810 | }; |
811 | |
812 | addSymbol(0, sym, true); |
813 | lower::SymMapScope innerScope(symTable); |
814 | addSymbol(1, symToPrivatize); |
815 | |
816 | auto ip = firOpBuilder.saveInsertionPoint(); |
817 | copyFirstPrivateSymbol(converter, symToPrivatize, &ip); |
818 | |
819 | if constexpr (std::is_same_v<OpType, mlir::omp::PrivateClauseOp>) { |
820 | firOpBuilder.create<mlir::omp::YieldOp>( |
821 | hsb.getAddr().getLoc(), |
822 | symTable.shallowLookupSymbol(*symToPrivatize).getAddr()); |
823 | } else { |
824 | firOpBuilder.create<fir::YieldOp>( |
825 | hsb.getAddr().getLoc(), |
826 | symTable.shallowLookupSymbol(*symToPrivatize).getAddr()); |
827 | } |
828 | } |
829 | |
830 | return result; |
831 | }(); |
832 | |
833 | if (clauseOps) { |
834 | clauseOps->privateSyms.push_back(mlir::SymbolRefAttr::get(privatizerOp)); |
835 | clauseOps->privateVars.push_back(privVal); |
836 | } |
837 | |
838 | if (isDoConcurrent) |
839 | allPrivatizedSymbols.insert(symToPrivatize); |
840 | |
841 | if (isDoConcurrent) |
842 | firOpBuilder.restoreInsertionPoint(dcIP); |
843 | } |
844 | |
845 | template void |
846 | privatizeSymbol<mlir::omp::PrivateClauseOp, mlir::omp::PrivateClauseOps>( |
847 | lower::AbstractConverter &converter, fir::FirOpBuilder &firOpBuilder, |
848 | lower::SymMap &symTable, |
849 | llvm::SetVector<const semantics::Symbol *> &allPrivatizedSymbols, |
850 | llvm::SmallSet<const semantics::Symbol *, 16> &mightHaveReadHostSym, |
851 | const semantics::Symbol *symToPrivatize, |
852 | mlir::omp::PrivateClauseOps *clauseOps); |
853 | |
854 | template void |
855 | privatizeSymbol<fir::LocalitySpecifierOp, fir::LocalitySpecifierOperands>( |
856 | lower::AbstractConverter &converter, fir::FirOpBuilder &firOpBuilder, |
857 | lower::SymMap &symTable, |
858 | llvm::SetVector<const semantics::Symbol *> &allPrivatizedSymbols, |
859 | llvm::SmallSet<const semantics::Symbol *, 16> &mightHaveReadHostSym, |
860 | const semantics::Symbol *symToPrivatize, |
861 | fir::LocalitySpecifierOperands *clauseOps); |
862 | |
863 | } // end namespace Fortran::lower |
864 | |