1 | //===-- lib/Evaluate/fold-integer.cpp -------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "fold-implementation.h" |
10 | #include "fold-matmul.h" |
11 | #include "fold-reduction.h" |
12 | #include "flang/Evaluate/check-expression.h" |
13 | |
14 | namespace Fortran::evaluate { |
15 | |
16 | // Given a collection of ConstantSubscripts values, package them as a Constant. |
17 | // Return scalar value if asScalar == true and shape-dim array otherwise. |
18 | template <typename T> |
19 | Expr<T> PackageConstantBounds( |
20 | const ConstantSubscripts &&bounds, bool asScalar = false) { |
21 | if (asScalar) { |
22 | return Expr<T>{Constant<T>{bounds.at(0)}}; |
23 | } else { |
24 | // As rank-dim array |
25 | const int rank{GetRank(bounds)}; |
26 | std::vector<Scalar<T>> packed(rank); |
27 | std::transform(bounds.begin(), bounds.end(), packed.begin(), |
28 | [](ConstantSubscript x) { return Scalar<T>(x); }); |
29 | return Expr<T>{Constant<T>{std::move(packed), ConstantSubscripts{rank}}}; |
30 | } |
31 | } |
32 | |
33 | // If a DIM= argument to LBOUND(), UBOUND(), or SIZE() exists and has a valid |
34 | // constant value, return in "dimVal" that value, less 1 (to make it suitable |
35 | // for use as a C++ vector<> index). Also check for erroneous constant values |
36 | // and returns false on error. |
37 | static bool CheckDimArg(const std::optional<ActualArgument> &dimArg, |
38 | const Expr<SomeType> &array, parser::ContextualMessages &messages, |
39 | bool isLBound, std::optional<int> &dimVal) { |
40 | dimVal.reset(); |
41 | if (int rank{array.Rank()}; rank > 0 || IsAssumedRank(array)) { |
42 | auto named{ExtractNamedEntity(array)}; |
43 | if (auto dim64{ToInt64(dimArg)}) { |
44 | if (*dim64 < 1) { |
45 | messages.Say("DIM=%jd dimension must be positive"_err_en_US , *dim64); |
46 | return false; |
47 | } else if (!IsAssumedRank(array) && *dim64 > rank) { |
48 | messages.Say( |
49 | "DIM=%jd dimension is out of range for rank-%d array"_err_en_US , |
50 | *dim64, rank); |
51 | return false; |
52 | } else if (!isLBound && named && |
53 | semantics::IsAssumedSizeArray(named->GetLastSymbol()) && |
54 | *dim64 == rank) { |
55 | messages.Say( |
56 | "DIM=%jd dimension is out of range for rank-%d assumed-size array"_err_en_US , |
57 | *dim64, rank); |
58 | return false; |
59 | } else if (IsAssumedRank(array)) { |
60 | if (*dim64 > common::maxRank) { |
61 | messages.Say( |
62 | "DIM=%jd dimension is too large for any array (maximum rank %d)"_err_en_US , |
63 | *dim64, common::maxRank); |
64 | return false; |
65 | } |
66 | } else { |
67 | dimVal = static_cast<int>(*dim64 - 1); // 1-based to 0-based |
68 | } |
69 | } |
70 | } |
71 | return true; |
72 | } |
73 | |
74 | // Class to retrieve the constant bound of an expression which is an |
75 | // array that devolves to a type of Constant<T> |
76 | class GetConstantArrayBoundHelper { |
77 | public: |
78 | template <typename T> |
79 | static Expr<T> GetLbound( |
80 | const Expr<SomeType> &array, std::optional<int> dim) { |
81 | return PackageConstantBounds<T>( |
82 | GetConstantArrayBoundHelper(dim, /*getLbound=*/true).Get(array), |
83 | dim.has_value()); |
84 | } |
85 | |
86 | template <typename T> |
87 | static Expr<T> GetUbound( |
88 | const Expr<SomeType> &array, std::optional<int> dim) { |
89 | return PackageConstantBounds<T>( |
90 | GetConstantArrayBoundHelper(dim, /*getLbound=*/false).Get(array), |
91 | dim.has_value()); |
92 | } |
93 | |
94 | private: |
95 | GetConstantArrayBoundHelper( |
96 | std::optional<ConstantSubscript> dim, bool getLbound) |
97 | : dim_{dim}, getLbound_{getLbound} {} |
98 | |
99 | template <typename T> ConstantSubscripts Get(const T &) { |
100 | // The method is needed for template expansion, but we should never get |
101 | // here in practice. |
102 | CHECK(false); |
103 | return {0}; |
104 | } |
105 | |
106 | template <typename T> ConstantSubscripts Get(const Constant<T> &x) { |
107 | if (getLbound_) { |
108 | // Return the lower bound |
109 | if (dim_) { |
110 | return {x.lbounds().at(*dim_)}; |
111 | } else { |
112 | return x.lbounds(); |
113 | } |
114 | } else { |
115 | // Return the upper bound |
116 | if (arrayFromParenthesesExpr) { |
117 | // Underlying array comes from (x) expression - return shapes |
118 | if (dim_) { |
119 | return {x.shape().at(*dim_)}; |
120 | } else { |
121 | return x.shape(); |
122 | } |
123 | } else { |
124 | return x.ComputeUbounds(dim_); |
125 | } |
126 | } |
127 | } |
128 | |
129 | template <typename T> ConstantSubscripts Get(const Parentheses<T> &x) { |
130 | // Case of temp variable inside parentheses - return [1, ... 1] for lower |
131 | // bounds and shape for upper bounds |
132 | if (getLbound_) { |
133 | return ConstantSubscripts(x.Rank(), ConstantSubscript{1}); |
134 | } else { |
135 | // Indicate that underlying array comes from parentheses expression. |
136 | // Continue to unwrap expression until we hit a constant |
137 | arrayFromParenthesesExpr = true; |
138 | return Get(x.left()); |
139 | } |
140 | } |
141 | |
142 | template <typename T> ConstantSubscripts Get(const Expr<T> &x) { |
143 | // recurse through Expr<T>'a until we hit a constant |
144 | return common::visit([&](const auto &inner) { return Get(inner); }, |
145 | // [&](const auto &) { return 0; }, |
146 | x.u); |
147 | } |
148 | |
149 | const std::optional<ConstantSubscript> dim_; |
150 | const bool getLbound_; |
151 | bool arrayFromParenthesesExpr{false}; |
152 | }; |
153 | |
154 | template <int KIND> |
155 | Expr<Type<TypeCategory::Integer, KIND>> LBOUND(FoldingContext &context, |
156 | FunctionRef<Type<TypeCategory::Integer, KIND>> &&funcRef) { |
157 | using T = Type<TypeCategory::Integer, KIND>; |
158 | ActualArguments &args{funcRef.arguments()}; |
159 | if (const auto *array{UnwrapExpr<Expr<SomeType>>(args[0])}) { |
160 | std::optional<int> dim; |
161 | if (funcRef.Rank() == 0) { |
162 | // Optional DIM= argument is present: result is scalar. |
163 | if (!CheckDimArg(args[1], *array, context.messages(), true, dim)) { |
164 | return MakeInvalidIntrinsic<T>(std::move(funcRef)); |
165 | } else if (!dim) { |
166 | // DIM= is present but not constant, or error |
167 | return Expr<T>{std::move(funcRef)}; |
168 | } |
169 | } |
170 | if (IsAssumedRank(*array)) { |
171 | // Would like to return 1 if DIM=.. is present, but that would be |
172 | // hiding a runtime error if the DIM= were too large (including |
173 | // the case of an assumed-rank argument that's scalar). |
174 | } else if (int rank{array->Rank()}; rank > 0) { |
175 | bool lowerBoundsAreOne{true}; |
176 | if (auto named{ExtractNamedEntity(*array)}) { |
177 | const Symbol &symbol{named->GetLastSymbol()}; |
178 | if (symbol.Rank() == rank) { |
179 | lowerBoundsAreOne = false; |
180 | if (dim) { |
181 | if (auto lb{GetLBOUND(context, *named, *dim)}) { |
182 | return Fold(context, ConvertToType<T>(std::move(*lb))); |
183 | } |
184 | } else if (auto extents{ |
185 | AsExtentArrayExpr(GetLBOUNDs(context, *named))}) { |
186 | return Fold(context, |
187 | ConvertToType<T>(Expr<ExtentType>{std::move(*extents)})); |
188 | } |
189 | } else { |
190 | lowerBoundsAreOne = symbol.Rank() == 0; // LBOUND(array%component) |
191 | } |
192 | } |
193 | if (IsActuallyConstant(*array)) { |
194 | return GetConstantArrayBoundHelper::GetLbound<T>(*array, dim); |
195 | } |
196 | if (lowerBoundsAreOne) { |
197 | ConstantSubscripts ones(rank, ConstantSubscript{1}); |
198 | return PackageConstantBounds<T>(std::move(ones), dim.has_value()); |
199 | } |
200 | } |
201 | } |
202 | return Expr<T>{std::move(funcRef)}; |
203 | } |
204 | |
205 | template <int KIND> |
206 | Expr<Type<TypeCategory::Integer, KIND>> UBOUND(FoldingContext &context, |
207 | FunctionRef<Type<TypeCategory::Integer, KIND>> &&funcRef) { |
208 | using T = Type<TypeCategory::Integer, KIND>; |
209 | ActualArguments &args{funcRef.arguments()}; |
210 | if (auto *array{UnwrapExpr<Expr<SomeType>>(args[0])}) { |
211 | std::optional<int> dim; |
212 | if (funcRef.Rank() == 0) { |
213 | // Optional DIM= argument is present: result is scalar. |
214 | if (!CheckDimArg(args[1], *array, context.messages(), false, dim)) { |
215 | return MakeInvalidIntrinsic<T>(std::move(funcRef)); |
216 | } else if (!dim) { |
217 | // DIM= is present but not constant, or error |
218 | return Expr<T>{std::move(funcRef)}; |
219 | } |
220 | } |
221 | if (IsAssumedRank(*array)) { |
222 | } else if (int rank{array->Rank()}; rank > 0) { |
223 | bool takeBoundsFromShape{true}; |
224 | if (auto named{ExtractNamedEntity(*array)}) { |
225 | const Symbol &symbol{named->GetLastSymbol()}; |
226 | if (symbol.Rank() == rank) { |
227 | takeBoundsFromShape = false; |
228 | if (dim) { |
229 | if (auto ub{GetUBOUND(context, *named, *dim)}) { |
230 | return Fold(context, ConvertToType<T>(std::move(*ub))); |
231 | } |
232 | } else { |
233 | Shape ubounds{GetUBOUNDs(context, *named)}; |
234 | if (semantics::IsAssumedSizeArray(symbol)) { |
235 | CHECK(!ubounds.back()); |
236 | ubounds.back() = ExtentExpr{-1}; |
237 | } |
238 | if (auto extents{AsExtentArrayExpr(ubounds)}) { |
239 | return Fold(context, |
240 | ConvertToType<T>(Expr<ExtentType>{std::move(*extents)})); |
241 | } |
242 | } |
243 | } else { |
244 | takeBoundsFromShape = symbol.Rank() == 0; // UBOUND(array%component) |
245 | } |
246 | } |
247 | if (IsActuallyConstant(*array)) { |
248 | return GetConstantArrayBoundHelper::GetUbound<T>(*array, dim); |
249 | } |
250 | if (takeBoundsFromShape) { |
251 | if (auto shape{GetContextFreeShape(context, *array)}) { |
252 | if (dim) { |
253 | if (auto &dimSize{shape->at(*dim)}) { |
254 | return Fold(context, |
255 | ConvertToType<T>(Expr<ExtentType>{std::move(*dimSize)})); |
256 | } |
257 | } else if (auto shapeExpr{AsExtentArrayExpr(*shape)}) { |
258 | return Fold(context, ConvertToType<T>(std::move(*shapeExpr))); |
259 | } |
260 | } |
261 | } |
262 | } |
263 | } |
264 | return Expr<T>{std::move(funcRef)}; |
265 | } |
266 | |
267 | // COUNT() |
268 | template <typename T, int MASK_KIND> class CountAccumulator { |
269 | using MaskT = Type<TypeCategory::Logical, MASK_KIND>; |
270 | |
271 | public: |
272 | CountAccumulator(const Constant<MaskT> &mask) : mask_{mask} {} |
273 | void operator()( |
274 | Scalar<T> &element, const ConstantSubscripts &at, bool /*first*/) { |
275 | if (mask_.At(at).IsTrue()) { |
276 | auto incremented{element.AddSigned(Scalar<T>{1})}; |
277 | overflow_ |= incremented.overflow; |
278 | element = incremented.value; |
279 | } |
280 | } |
281 | bool overflow() const { return overflow_; } |
282 | void Done(Scalar<T> &) const {} |
283 | |
284 | private: |
285 | const Constant<MaskT> &mask_; |
286 | bool overflow_{false}; |
287 | }; |
288 | |
289 | template <typename T, int maskKind> |
290 | static Expr<T> FoldCount(FoldingContext &context, FunctionRef<T> &&ref) { |
291 | using KindLogical = Type<TypeCategory::Logical, maskKind>; |
292 | static_assert(T::category == TypeCategory::Integer); |
293 | std::optional<int> dim; |
294 | if (std::optional<ArrayAndMask<KindLogical>> arrayAndMask{ |
295 | ProcessReductionArgs<KindLogical>( |
296 | context, ref.arguments(), dim, /*ARRAY=*/0, /*DIM=*/1)}) { |
297 | CountAccumulator<T, maskKind> accumulator{arrayAndMask->array}; |
298 | Constant<T> result{DoReduction<T>(arrayAndMask->array, arrayAndMask->mask, |
299 | dim, Scalar<T>{}, accumulator)}; |
300 | if (accumulator.overflow()) { |
301 | context.messages().Say( |
302 | "Result of intrinsic function COUNT overflows its result type"_warn_en_US ); |
303 | } |
304 | return Expr<T>{std::move(result)}; |
305 | } |
306 | return Expr<T>{std::move(ref)}; |
307 | } |
308 | |
// FINDLOC(), MAXLOC(), & MINLOC(): selects which location intrinsic is being
// folded.
enum class WhichLocation {
  Findloc, // FINDLOC()
  Maxloc, // MAXLOC()
  Minloc, // MINLOC()
};
311 | template <WhichLocation WHICH> class LocationHelper { |
312 | public: |
313 | LocationHelper( |
314 | DynamicType &&type, ActualArguments &arg, FoldingContext &context) |
315 | : type_{type}, arg_{arg}, context_{context} {} |
316 | using Result = std::optional<Constant<SubscriptInteger>>; |
317 | using Types = std::conditional_t<WHICH == WhichLocation::Findloc, |
318 | AllIntrinsicTypes, RelationalTypes>; |
319 | |
320 | template <typename T> Result Test() const { |
321 | if (T::category != type_.category() || T::kind != type_.kind()) { |
322 | return std::nullopt; |
323 | } |
324 | CHECK(arg_.size() == (WHICH == WhichLocation::Findloc ? 6 : 5)); |
325 | Folder<T> folder{context_}; |
326 | Constant<T> *array{folder.Folding(arg_[0])}; |
327 | if (!array) { |
328 | return std::nullopt; |
329 | } |
330 | std::optional<Constant<T>> value; |
331 | if constexpr (WHICH == WhichLocation::Findloc) { |
332 | if (const Constant<T> *p{folder.Folding(arg_[1])}) { |
333 | value.emplace(*p); |
334 | } else { |
335 | return std::nullopt; |
336 | } |
337 | } |
338 | std::optional<int> dim; |
339 | Constant<LogicalResult> *mask{ |
340 | GetReductionMASK(arg_[maskArg], array->shape(), context_)}; |
341 | if ((!mask && arg_[maskArg]) || |
342 | !CheckReductionDIM(dim, context_, arg_, dimArg, array->Rank())) { |
343 | return std::nullopt; |
344 | } |
345 | bool back{false}; |
346 | if (arg_[backArg]) { |
347 | const auto *backConst{ |
348 | Folder<LogicalResult>{context_}.Folding(arg_[backArg])}; |
349 | if (backConst) { |
350 | back = backConst->GetScalarValue().value().IsTrue(); |
351 | } else { |
352 | return std::nullopt; |
353 | } |
354 | } |
355 | const RelationalOperator relation{WHICH == WhichLocation::Findloc |
356 | ? RelationalOperator::EQ |
357 | : WHICH == WhichLocation::Maxloc |
358 | ? (back ? RelationalOperator::GE : RelationalOperator::GT) |
359 | : back ? RelationalOperator::LE |
360 | : RelationalOperator::LT}; |
361 | // Use lower bounds of 1 exclusively. |
362 | array->SetLowerBoundsToOne(); |
363 | ConstantSubscripts at{array->lbounds()}, maskAt, resultIndices, resultShape; |
364 | if (mask) { |
365 | if (auto scalarMask{mask->GetScalarValue()}) { |
366 | // Convert into array in case of scalar MASK= (for |
367 | // MAXLOC/MINLOC/FINDLOC mask should be conformable) |
368 | ConstantSubscript n{GetSize(array->shape())}; |
369 | std::vector<Scalar<LogicalResult>> mask_elements( |
370 | n, Scalar<LogicalResult>{scalarMask.value()}); |
371 | *mask = Constant<LogicalResult>{ |
372 | std::move(mask_elements), ConstantSubscripts{array->shape()}}; |
373 | } |
374 | mask->SetLowerBoundsToOne(); |
375 | maskAt = mask->lbounds(); |
376 | } |
377 | if (dim) { // DIM= |
378 | if (*dim < 1 || *dim > array->Rank()) { |
379 | context_.messages().Say("DIM=%d is out of range"_err_en_US , *dim); |
380 | return std::nullopt; |
381 | } |
382 | int zbDim{*dim - 1}; |
383 | resultShape = array->shape(); |
384 | resultShape.erase( |
385 | resultShape.begin() + zbDim); // scalar if array is vector |
386 | ConstantSubscript dimLength{array->shape()[zbDim]}; |
387 | ConstantSubscript n{GetSize(resultShape)}; |
388 | for (ConstantSubscript j{0}; j < n; ++j) { |
389 | ConstantSubscript hit{0}; |
390 | if constexpr (WHICH == WhichLocation::Maxloc || |
391 | WHICH == WhichLocation::Minloc) { |
392 | value.reset(); |
393 | } |
394 | for (ConstantSubscript k{0}; k < dimLength; |
395 | ++k, ++at[zbDim], mask && ++maskAt[zbDim]) { |
396 | if ((!mask || mask->At(maskAt).IsTrue()) && |
397 | IsHit(array->At(at), value, relation, back)) { |
398 | hit = at[zbDim]; |
399 | if constexpr (WHICH == WhichLocation::Findloc) { |
400 | if (!back) { |
401 | break; |
402 | } |
403 | } |
404 | } |
405 | } |
406 | resultIndices.emplace_back(hit); |
407 | at[zbDim] = std::max<ConstantSubscript>(dimLength, 1); |
408 | array->IncrementSubscripts(at); |
409 | at[zbDim] = 1; |
410 | if (mask) { |
411 | maskAt[zbDim] = mask->lbounds()[zbDim] + |
412 | std::max<ConstantSubscript>(dimLength, 1) - 1; |
413 | mask->IncrementSubscripts(maskAt); |
414 | maskAt[zbDim] = mask->lbounds()[zbDim]; |
415 | } |
416 | } |
417 | } else { // no DIM= |
418 | resultShape = ConstantSubscripts{array->Rank()}; // always a vector |
419 | ConstantSubscript n{GetSize(array->shape())}; |
420 | resultIndices = ConstantSubscripts(array->Rank(), 0); |
421 | for (ConstantSubscript j{0}; j < n; ++j, array->IncrementSubscripts(at), |
422 | mask && mask->IncrementSubscripts(maskAt)) { |
423 | if ((!mask || mask->At(maskAt).IsTrue()) && |
424 | IsHit(array->At(at), value, relation, back)) { |
425 | resultIndices = at; |
426 | if constexpr (WHICH == WhichLocation::Findloc) { |
427 | if (!back) { |
428 | break; |
429 | } |
430 | } |
431 | } |
432 | } |
433 | } |
434 | std::vector<Scalar<SubscriptInteger>> resultElements; |
435 | for (ConstantSubscript j : resultIndices) { |
436 | resultElements.emplace_back(j); |
437 | } |
438 | return Constant<SubscriptInteger>{ |
439 | std::move(resultElements), std::move(resultShape)}; |
440 | } |
441 | |
442 | private: |
443 | template <typename T> |
444 | bool IsHit(typename Constant<T>::Element element, |
445 | std::optional<Constant<T>> &value, |
446 | [[maybe_unused]] RelationalOperator relation, |
447 | [[maybe_unused]] bool back) const { |
448 | std::optional<Expr<LogicalResult>> cmp; |
449 | bool result{true}; |
450 | if (value) { |
451 | if constexpr (T::category == TypeCategory::Logical) { |
452 | // array(at) .EQV. value? |
453 | static_assert(WHICH == WhichLocation::Findloc); |
454 | cmp.emplace(ConvertToType<LogicalResult>( |
455 | Expr<T>{LogicalOperation<T::kind>{LogicalOperator::Eqv, |
456 | Expr<T>{Constant<T>{element}}, Expr<T>{Constant<T>{*value}}}})); |
457 | } else { // compare array(at) to value |
458 | if constexpr (T::category == TypeCategory::Real && |
459 | (WHICH == WhichLocation::Maxloc || |
460 | WHICH == WhichLocation::Minloc)) { |
461 | if (value && value->GetScalarValue().value().IsNotANumber() && |
462 | (back || !element.IsNotANumber())) { |
463 | // Replace NaN |
464 | cmp.emplace(Constant<LogicalResult>{Scalar<LogicalResult>{true}}); |
465 | } |
466 | } |
467 | if (!cmp) { |
468 | cmp.emplace(PackageRelation(relation, Expr<T>{Constant<T>{element}}, |
469 | Expr<T>{Constant<T>{*value}})); |
470 | } |
471 | } |
472 | Expr<LogicalResult> folded{Fold(context_, std::move(*cmp))}; |
473 | result = GetScalarConstantValue<LogicalResult>(folded).value().IsTrue(); |
474 | } else { |
475 | // first unmasked element for MAXLOC/MINLOC - always take it |
476 | } |
477 | if constexpr (WHICH == WhichLocation::Maxloc || |
478 | WHICH == WhichLocation::Minloc) { |
479 | if (result) { |
480 | value.emplace(std::move(element)); |
481 | } |
482 | } |
483 | return result; |
484 | } |
485 | |
486 | static constexpr int dimArg{WHICH == WhichLocation::Findloc ? 2 : 1}; |
487 | static constexpr int maskArg{dimArg + 1}; |
488 | static constexpr int backArg{maskArg + 2}; |
489 | |
490 | DynamicType type_; |
491 | ActualArguments &arg_; |
492 | FoldingContext &context_; |
493 | }; |
494 | |
495 | template <WhichLocation which> |
496 | static std::optional<Constant<SubscriptInteger>> FoldLocationCall( |
497 | ActualArguments &arg, FoldingContext &context) { |
498 | if (arg[0]) { |
499 | if (auto type{arg[0]->GetType()}) { |
500 | if constexpr (which == WhichLocation::Findloc) { |
501 | // Both ARRAY and VALUE are susceptible to conversion to a common |
502 | // comparison type. |
503 | if (arg[1]) { |
504 | if (auto valType{arg[1]->GetType()}) { |
505 | if (auto compareType{ComparisonType(*type, *valType)}) { |
506 | type = compareType; |
507 | } |
508 | } |
509 | } |
510 | } |
511 | return common::SearchTypes( |
512 | LocationHelper<which>{std::move(*type), arg, context}); |
513 | } |
514 | } |
515 | return std::nullopt; |
516 | } |
517 | |
518 | template <WhichLocation which, typename T> |
519 | static Expr<T> FoldLocation(FoldingContext &context, FunctionRef<T> &&ref) { |
520 | static_assert(T::category == TypeCategory::Integer); |
521 | if (std::optional<Constant<SubscriptInteger>> found{ |
522 | FoldLocationCall<which>(ref.arguments(), context)}) { |
523 | return Expr<T>{Fold( |
524 | context, ConvertToType<T>(Expr<SubscriptInteger>{std::move(*found)}))}; |
525 | } else { |
526 | return Expr<T>{std::move(ref)}; |
527 | } |
528 | } |
529 | |
530 | // for IALL, IANY, & IPARITY |
531 | template <typename T> |
532 | static Expr<T> FoldBitReduction(FoldingContext &context, FunctionRef<T> &&ref, |
533 | Scalar<T> (Scalar<T>::*operation)(const Scalar<T> &) const, |
534 | Scalar<T> identity) { |
535 | static_assert(T::category == TypeCategory::Integer); |
536 | std::optional<int> dim; |
537 | if (std::optional<ArrayAndMask<T>> arrayAndMask{ |
538 | ProcessReductionArgs<T>(context, ref.arguments(), dim, |
539 | /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) { |
540 | OperationAccumulator<T> accumulator{arrayAndMask->array, operation}; |
541 | return Expr<T>{DoReduction<T>( |
542 | arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)}; |
543 | } |
544 | return Expr<T>{std::move(ref)}; |
545 | } |
546 | |
547 | template <int KIND> |
548 | Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction( |
549 | FoldingContext &context, |
550 | FunctionRef<Type<TypeCategory::Integer, KIND>> &&funcRef) { |
551 | using T = Type<TypeCategory::Integer, KIND>; |
552 | using Int4 = Type<TypeCategory::Integer, 4>; |
553 | ActualArguments &args{funcRef.arguments()}; |
554 | auto *intrinsic{std::get_if<SpecificIntrinsic>(&funcRef.proc().u)}; |
555 | CHECK(intrinsic); |
556 | std::string name{intrinsic->name}; |
557 | auto FromInt64{[&name, &context](std::int64_t n) { |
558 | Scalar<T> result{n}; |
559 | if (result.ToInt64() != n) { |
560 | context.messages().Say( |
561 | "Result of intrinsic function '%s' (%jd) overflows its result type"_warn_en_US , |
562 | name, std::intmax_t{n}); |
563 | } |
564 | return result; |
565 | }}; |
566 | if (name == "abs" ) { // incl. babs, iiabs, jiaabs, & kiabs |
567 | return FoldElementalIntrinsic<T, T>(context, std::move(funcRef), |
568 | ScalarFunc<T, T>([&context](const Scalar<T> &i) -> Scalar<T> { |
569 | typename Scalar<T>::ValueWithOverflow j{i.ABS()}; |
570 | if (j.overflow) { |
571 | context.messages().Say( |
572 | "abs(integer(kind=%d)) folding overflowed"_warn_en_US , KIND); |
573 | } |
574 | return j.value; |
575 | })); |
576 | } else if (name == "bit_size" ) { |
577 | return Expr<T>{Scalar<T>::bits}; |
578 | } else if (name == "ceiling" || name == "floor" || name == "nint" ) { |
579 | if (const auto *cx{UnwrapExpr<Expr<SomeReal>>(args[0])}) { |
580 | // NINT rounds ties away from zero, not to even |
581 | common::RoundingMode mode{name == "ceiling" ? common::RoundingMode::Up |
582 | : name == "floor" ? common::RoundingMode::Down |
583 | : common::RoundingMode::TiesAwayFromZero}; |
584 | return common::visit( |
585 | [&](const auto &kx) { |
586 | using TR = ResultType<decltype(kx)>; |
587 | return FoldElementalIntrinsic<T, TR>(context, std::move(funcRef), |
588 | ScalarFunc<T, TR>([&](const Scalar<TR> &x) { |
589 | auto y{x.template ToInteger<Scalar<T>>(mode)}; |
590 | if (y.flags.test(RealFlag::Overflow)) { |
591 | context.messages().Say( |
592 | "%s intrinsic folding overflow"_warn_en_US , name); |
593 | } |
594 | return y.value; |
595 | })); |
596 | }, |
597 | cx->u); |
598 | } |
599 | } else if (name == "count" ) { |
600 | int maskKind = args[0]->GetType()->kind(); |
601 | switch (maskKind) { |
602 | SWITCH_COVERS_ALL_CASES |
603 | case 1: |
604 | return FoldCount<T, 1>(context, std::move(funcRef)); |
605 | case 2: |
606 | return FoldCount<T, 2>(context, std::move(funcRef)); |
607 | case 4: |
608 | return FoldCount<T, 4>(context, std::move(funcRef)); |
609 | case 8: |
610 | return FoldCount<T, 8>(context, std::move(funcRef)); |
611 | } |
612 | } else if (name == "digits" ) { |
613 | if (const auto *cx{UnwrapExpr<Expr<SomeInteger>>(args[0])}) { |
614 | return Expr<T>{common::visit( |
615 | [](const auto &kx) { |
616 | return Scalar<ResultType<decltype(kx)>>::DIGITS; |
617 | }, |
618 | cx->u)}; |
619 | } else if (const auto *cx{UnwrapExpr<Expr<SomeReal>>(args[0])}) { |
620 | return Expr<T>{common::visit( |
621 | [](const auto &kx) { |
622 | return Scalar<ResultType<decltype(kx)>>::DIGITS; |
623 | }, |
624 | cx->u)}; |
625 | } else if (const auto *cx{UnwrapExpr<Expr<SomeComplex>>(args[0])}) { |
626 | return Expr<T>{common::visit( |
627 | [](const auto &kx) { |
628 | return Scalar<typename ResultType<decltype(kx)>::Part>::DIGITS; |
629 | }, |
630 | cx->u)}; |
631 | } |
632 | } else if (name == "dim" ) { |
633 | return FoldElementalIntrinsic<T, T, T>(context, std::move(funcRef), |
634 | ScalarFunc<T, T, T>([&context](const Scalar<T> &x, |
635 | const Scalar<T> &y) -> Scalar<T> { |
636 | auto result{x.DIM(y)}; |
637 | if (result.overflow) { |
638 | context.messages().Say("DIM intrinsic folding overflow"_warn_en_US ); |
639 | } |
640 | return result.value; |
641 | })); |
642 | } else if (name == "dot_product" ) { |
643 | return FoldDotProduct<T>(context, std::move(funcRef)); |
644 | } else if (name == "dshiftl" || name == "dshiftr" ) { |
645 | const auto fptr{ |
646 | name == "dshiftl" ? &Scalar<T>::DSHIFTL : &Scalar<T>::DSHIFTR}; |
647 | // Third argument can be of any kind. However, it must be smaller or equal |
648 | // than BIT_SIZE. It can be converted to Int4 to simplify. |
649 | if (const auto *argCon{Folder<T>(context).Folding(args[0])}; |
650 | argCon && argCon->empty()) { |
651 | } else if (const auto *shiftCon{Folder<Int4>(context).Folding(args[2])}) { |
652 | for (const auto &scalar : shiftCon->values()) { |
653 | std::int64_t shiftVal{scalar.ToInt64()}; |
654 | if (shiftVal < 0) { |
655 | context.messages().Say("SHIFT=%jd count for %s is negative"_err_en_US , |
656 | std::intmax_t{shiftVal}, name); |
657 | break; |
658 | } else if (shiftVal > T::Scalar::bits) { |
659 | context.messages().Say( |
660 | "SHIFT=%jd count for %s is greater than %d"_err_en_US , |
661 | std::intmax_t{shiftVal}, name, T::Scalar::bits); |
662 | break; |
663 | } |
664 | } |
665 | } |
666 | return FoldElementalIntrinsic<T, T, T, Int4>(context, std::move(funcRef), |
667 | ScalarFunc<T, T, T, Int4>( |
668 | [&fptr](const Scalar<T> &i, const Scalar<T> &j, |
669 | const Scalar<Int4> &shift) -> Scalar<T> { |
670 | return std::invoke(fptr, i, j, static_cast<int>(shift.ToInt64())); |
671 | })); |
672 | } else if (name == "exponent" ) { |
673 | if (auto *sx{UnwrapExpr<Expr<SomeReal>>(args[0])}) { |
674 | return common::visit( |
675 | [&funcRef, &context](const auto &x) -> Expr<T> { |
676 | using TR = typename std::decay_t<decltype(x)>::Result; |
677 | return FoldElementalIntrinsic<T, TR>(context, std::move(funcRef), |
678 | &Scalar<TR>::template EXPONENT<Scalar<T>>); |
679 | }, |
680 | sx->u); |
681 | } else { |
682 | DIE("exponent argument must be real" ); |
683 | } |
684 | } else if (name == "findloc" ) { |
685 | return FoldLocation<WhichLocation::Findloc, T>(context, std::move(funcRef)); |
686 | } else if (name == "huge" ) { |
687 | return Expr<T>{Scalar<T>::HUGE()}; |
688 | } else if (name == "iachar" || name == "ichar" ) { |
689 | auto *someChar{UnwrapExpr<Expr<SomeCharacter>>(args[0])}; |
690 | CHECK(someChar); |
691 | if (auto len{ToInt64(someChar->LEN())}) { |
692 | if (len.value() < 1) { |
693 | context.messages().Say( |
694 | "Character in intrinsic function %s must have length one"_err_en_US , |
695 | name); |
696 | } else if (len.value() > 1 && |
697 | context.languageFeatures().ShouldWarn( |
698 | common::UsageWarning::Portability)) { |
699 | // Do not die, this was not checked before |
700 | context.messages().Say( |
701 | "Character in intrinsic function %s should have length one"_port_en_US , |
702 | name); |
703 | } else { |
704 | return common::visit( |
705 | [&funcRef, &context, &FromInt64](const auto &str) -> Expr<T> { |
706 | using Char = typename std::decay_t<decltype(str)>::Result; |
707 | (void)FromInt64; |
708 | return FoldElementalIntrinsic<T, Char>(context, |
709 | std::move(funcRef), |
710 | ScalarFunc<T, Char>( |
711 | #ifndef _MSC_VER |
712 | [&FromInt64](const Scalar<Char> &c) { |
713 | return FromInt64(CharacterUtils<Char::kind>::ICHAR( |
714 | CharacterUtils<Char::kind>::Resize(c, 1))); |
715 | })); |
716 | #else // _MSC_VER |
717 | // MSVC 14 get confused by the original code above and |
718 | // ends up emitting an error about passing a std::string |
719 | // to the std::u16string instantiation of |
720 | // CharacterUtils<2>::ICHAR(). Can't find a work-around, |
721 | // so remove the FromInt64 error checking lambda that |
722 | // seems to have caused the proble. |
723 | [](const Scalar<Char> &c) { |
724 | return CharacterUtils<Char::kind>::ICHAR( |
725 | CharacterUtils<Char::kind>::Resize(c, 1)); |
726 | })); |
727 | #endif // _MSC_VER |
728 | }, |
729 | someChar->u); |
730 | } |
731 | } |
732 | } else if (name == "iand" || name == "ior" || name == "ieor" ) { |
733 | auto fptr{&Scalar<T>::IAND}; |
734 | if (name == "iand" ) { // done in fptr declaration |
735 | } else if (name == "ior" ) { |
736 | fptr = &Scalar<T>::IOR; |
737 | } else if (name == "ieor" ) { |
738 | fptr = &Scalar<T>::IEOR; |
739 | } else { |
740 | common::die("missing case to fold intrinsic function %s" , name.c_str()); |
741 | } |
742 | return FoldElementalIntrinsic<T, T, T>( |
743 | context, std::move(funcRef), ScalarFunc<T, T, T>(fptr)); |
744 | } else if (name == "iall" ) { |
745 | return FoldBitReduction( |
746 | context, std::move(funcRef), &Scalar<T>::IAND, Scalar<T>{}.NOT()); |
747 | } else if (name == "iany" ) { |
748 | return FoldBitReduction( |
749 | context, std::move(funcRef), &Scalar<T>::IOR, Scalar<T>{}); |
750 | } else if (name == "ibclr" || name == "ibset" ) { |
751 | // Second argument can be of any kind. However, it must be smaller |
752 | // than BIT_SIZE. It can be converted to Int4 to simplify. |
753 | auto fptr{&Scalar<T>::IBCLR}; |
754 | if (name == "ibclr" ) { // done in fptr definition |
755 | } else if (name == "ibset" ) { |
756 | fptr = &Scalar<T>::IBSET; |
757 | } else { |
758 | common::die("missing case to fold intrinsic function %s" , name.c_str()); |
759 | } |
760 | if (const auto *argCon{Folder<T>(context).Folding(args[0])}; |
761 | argCon && argCon->empty()) { |
762 | } else if (const auto *posCon{Folder<Int4>(context).Folding(args[1])}) { |
763 | for (const auto &scalar : posCon->values()) { |
764 | std::int64_t posVal{scalar.ToInt64()}; |
765 | if (posVal < 0) { |
766 | context.messages().Say( |
767 | "bit position for %s (%jd) is negative"_err_en_US , name, |
768 | std::intmax_t{posVal}); |
769 | break; |
770 | } else if (posVal >= T::Scalar::bits) { |
771 | context.messages().Say( |
772 | "bit position for %s (%jd) is not less than %d"_err_en_US , name, |
773 | std::intmax_t{posVal}, T::Scalar::bits); |
774 | break; |
775 | } |
776 | } |
777 | } |
778 | return FoldElementalIntrinsic<T, T, Int4>(context, std::move(funcRef), |
779 | ScalarFunc<T, T, Int4>( |
780 | [&](const Scalar<T> &i, const Scalar<Int4> &pos) -> Scalar<T> { |
781 | return std::invoke(fptr, i, static_cast<int>(pos.ToInt64())); |
782 | })); |
783 | } else if (name == "ibits" ) { |
784 | const auto *posCon{Folder<Int4>(context).Folding(args[1])}; |
785 | const auto *lenCon{Folder<Int4>(context).Folding(args[2])}; |
786 | if (const auto *argCon{Folder<T>(context).Folding(args[0])}; |
787 | argCon && argCon->empty()) { |
788 | } else { |
789 | std::size_t posCt{posCon ? posCon->size() : 0}; |
790 | std::size_t lenCt{lenCon ? lenCon->size() : 0}; |
791 | std::size_t n{std::max(posCt, lenCt)}; |
792 | for (std::size_t j{0}; j < n; ++j) { |
793 | int posVal{j < posCt || posCt == 1 |
794 | ? static_cast<int>(posCon->values()[j % posCt].ToInt64()) |
795 | : 0}; |
796 | int lenVal{j < lenCt || lenCt == 1 |
797 | ? static_cast<int>(lenCon->values()[j % lenCt].ToInt64()) |
798 | : 0}; |
799 | if (posVal < 0) { |
800 | context.messages().Say( |
801 | "bit position for IBITS(POS=%jd) is negative"_err_en_US , |
802 | std::intmax_t{posVal}); |
803 | break; |
804 | } else if (lenVal < 0) { |
805 | context.messages().Say( |
806 | "bit length for IBITS(LEN=%jd) is negative"_err_en_US , |
807 | std::intmax_t{lenVal}); |
808 | break; |
809 | } else if (posVal + lenVal > T::Scalar::bits) { |
810 | context.messages().Say( |
811 | "IBITS() must have POS+LEN (>=%jd) no greater than %d"_err_en_US , |
812 | std::intmax_t{posVal + lenVal}, T::Scalar::bits); |
813 | break; |
814 | } |
815 | } |
816 | } |
817 | return FoldElementalIntrinsic<T, T, Int4, Int4>(context, std::move(funcRef), |
818 | ScalarFunc<T, T, Int4, Int4>( |
819 | [&](const Scalar<T> &i, const Scalar<Int4> &pos, |
820 | const Scalar<Int4> &len) -> Scalar<T> { |
821 | return i.IBITS(static_cast<int>(pos.ToInt64()), |
822 | static_cast<int>(len.ToInt64())); |
823 | })); |
824 | } else if (name == "index" || name == "scan" || name == "verify" ) { |
825 | if (auto *charExpr{UnwrapExpr<Expr<SomeCharacter>>(args[0])}) { |
826 | return common::visit( |
827 | [&](const auto &kch) -> Expr<T> { |
828 | using TC = typename std::decay_t<decltype(kch)>::Result; |
829 | if (UnwrapExpr<Expr<SomeLogical>>(args[2])) { // BACK= |
830 | return FoldElementalIntrinsic<T, TC, TC, LogicalResult>(context, |
831 | std::move(funcRef), |
832 | ScalarFunc<T, TC, TC, LogicalResult>{ |
833 | [&name, &FromInt64](const Scalar<TC> &str, |
834 | const Scalar<TC> &other, |
835 | const Scalar<LogicalResult> &back) { |
836 | return FromInt64(name == "index" |
837 | ? CharacterUtils<TC::kind>::INDEX( |
838 | str, other, back.IsTrue()) |
839 | : name == "scan" |
840 | ? CharacterUtils<TC::kind>::SCAN( |
841 | str, other, back.IsTrue()) |
842 | : CharacterUtils<TC::kind>::VERIFY( |
843 | str, other, back.IsTrue())); |
844 | }}); |
845 | } else { |
846 | return FoldElementalIntrinsic<T, TC, TC>(context, |
847 | std::move(funcRef), |
848 | ScalarFunc<T, TC, TC>{ |
849 | [&name, &FromInt64]( |
850 | const Scalar<TC> &str, const Scalar<TC> &other) { |
851 | return FromInt64(name == "index" |
852 | ? CharacterUtils<TC::kind>::INDEX(str, other) |
853 | : name == "scan" |
854 | ? CharacterUtils<TC::kind>::SCAN(str, other) |
855 | : CharacterUtils<TC::kind>::VERIFY(str, other)); |
856 | }}); |
857 | } |
858 | }, |
859 | charExpr->u); |
860 | } else { |
861 | DIE("first argument must be CHARACTER" ); |
862 | } |
863 | } else if (name == "int" ) { |
864 | if (auto *expr{UnwrapExpr<Expr<SomeType>>(args[0])}) { |
865 | return common::visit( |
866 | [&](auto &&x) -> Expr<T> { |
867 | using From = std::decay_t<decltype(x)>; |
868 | if constexpr (std::is_same_v<From, BOZLiteralConstant> || |
869 | IsNumericCategoryExpr<From>()) { |
870 | return Fold(context, ConvertToType<T>(std::move(x))); |
871 | } |
872 | DIE("int() argument type not valid" ); |
873 | }, |
874 | std::move(expr->u)); |
875 | } |
876 | } else if (name == "int_ptr_kind" ) { |
877 | return Expr<T>{8}; |
878 | } else if (name == "kind" ) { |
879 | // FoldOperation(FunctionRef &&) in fold-implementation.h will not |
880 | // have folded the argument; in the case of TypeParamInquiry, |
881 | // try to get the type of the parameter itself. |
882 | if (const auto *expr{args[0] ? args[0]->UnwrapExpr() : nullptr}) { |
883 | if (const auto *inquiry{UnwrapExpr<TypeParamInquiry>(*expr)}) { |
884 | if (const auto *typeSpec{inquiry->parameter().GetType()}) { |
885 | if (const auto *intrinType{typeSpec->AsIntrinsic()}) { |
886 | if (auto k{ToInt64(Fold( |
887 | context, Expr<SubscriptInteger>{intrinType->kind()}))}) { |
888 | return Expr<T>{*k}; |
889 | } |
890 | } |
891 | } |
892 | } else if (auto dyType{expr->GetType()}) { |
893 | return Expr<T>{dyType->kind()}; |
894 | } |
895 | } |
896 | } else if (name == "iparity" ) { |
897 | return FoldBitReduction( |
898 | context, std::move(funcRef), &Scalar<T>::IEOR, Scalar<T>{}); |
899 | } else if (name == "ishft" || name == "ishftc" ) { |
900 | const auto *argCon{Folder<T>(context).Folding(args[0])}; |
901 | const auto *shiftCon{Folder<Int4>(context).Folding(args[1])}; |
902 | const auto *shiftVals{shiftCon ? &shiftCon->values() : nullptr}; |
903 | const auto *sizeCon{ |
904 | args.size() == 3 ? Folder<Int4>(context).Folding(args[2]) : nullptr}; |
905 | const auto *sizeVals{sizeCon ? &sizeCon->values() : nullptr}; |
906 | if ((argCon && argCon->empty()) || !shiftVals || shiftVals->empty() || |
907 | (sizeVals && sizeVals->empty())) { |
908 | // size= and shift= values don't need to be checked |
909 | } else { |
910 | for (const auto &scalar : *shiftVals) { |
911 | std::int64_t shiftVal{scalar.ToInt64()}; |
912 | if (shiftVal < -T::Scalar::bits) { |
913 | context.messages().Say( |
914 | "SHIFT=%jd count for %s is less than %d"_err_en_US , |
915 | std::intmax_t{shiftVal}, name, -T::Scalar::bits); |
916 | break; |
917 | } else if (shiftVal > T::Scalar::bits) { |
918 | context.messages().Say( |
919 | "SHIFT=%jd count for %s is greater than %d"_err_en_US , |
920 | std::intmax_t{shiftVal}, name, T::Scalar::bits); |
921 | break; |
922 | } |
923 | } |
924 | if (sizeVals) { |
925 | for (const auto &scalar : *sizeVals) { |
926 | std::int64_t sizeVal{scalar.ToInt64()}; |
927 | if (sizeVal <= 0) { |
928 | context.messages().Say( |
929 | "SIZE=%jd count for ishftc is not positive"_err_en_US , |
930 | std::intmax_t{sizeVal}, name); |
931 | break; |
932 | } else if (sizeVal > T::Scalar::bits) { |
933 | context.messages().Say( |
934 | "SIZE=%jd count for ishftc is greater than %d"_err_en_US , |
935 | std::intmax_t{sizeVal}, T::Scalar::bits); |
936 | break; |
937 | } |
938 | } |
939 | if (shiftVals->size() == 1 || sizeVals->size() == 1 || |
940 | shiftVals->size() == sizeVals->size()) { |
941 | auto iters{std::max(shiftVals->size(), sizeVals->size())}; |
942 | for (std::size_t j{0}; j < iters; ++j) { |
943 | auto shiftVal{static_cast<int>( |
944 | (*shiftVals)[j % shiftVals->size()].ToInt64())}; |
945 | auto sizeVal{ |
946 | static_cast<int>((*sizeVals)[j % sizeVals->size()].ToInt64())}; |
947 | if (sizeVal > 0 && std::abs(shiftVal) > sizeVal) { |
948 | context.messages().Say( |
949 | "SHIFT=%jd count for ishftc is greater in magnitude than SIZE=%jd"_err_en_US , |
950 | std::intmax_t{shiftVal}, std::intmax_t{sizeVal}); |
951 | break; |
952 | } |
953 | } |
954 | } |
955 | } |
956 | } |
957 | if (name == "ishft" ) { |
958 | return FoldElementalIntrinsic<T, T, Int4>(context, std::move(funcRef), |
959 | ScalarFunc<T, T, Int4>( |
960 | [&](const Scalar<T> &i, const Scalar<Int4> &shift) -> Scalar<T> { |
961 | return i.ISHFT(static_cast<int>(shift.ToInt64())); |
962 | })); |
963 | } else if (!args.at(2)) { // ISHFTC(no SIZE=) |
964 | return FoldElementalIntrinsic<T, T, Int4>(context, std::move(funcRef), |
965 | ScalarFunc<T, T, Int4>( |
966 | [&](const Scalar<T> &i, const Scalar<Int4> &shift) -> Scalar<T> { |
967 | return i.ISHFTC(static_cast<int>(shift.ToInt64())); |
968 | })); |
969 | } else { // ISHFTC(with SIZE=) |
970 | return FoldElementalIntrinsic<T, T, Int4, Int4>(context, |
971 | std::move(funcRef), |
972 | ScalarFunc<T, T, Int4, Int4>( |
973 | [&](const Scalar<T> &i, const Scalar<Int4> &shift, |
974 | const Scalar<Int4> &size) -> Scalar<T> { |
975 | auto shiftVal{static_cast<int>(shift.ToInt64())}; |
976 | auto sizeVal{static_cast<int>(size.ToInt64())}; |
977 | return i.ISHFTC(shiftVal, sizeVal); |
978 | })); |
979 | } |
980 | } else if (name == "izext" || name == "jzext" ) { |
981 | if (args.size() == 1) { |
982 | if (auto *expr{UnwrapExpr<Expr<SomeInteger>>(args[0])}) { |
983 | // Rewrite to IAND(INT(n,k),255_k) for k=KIND(T) |
984 | intrinsic->name = "iand" ; |
985 | auto converted{ConvertToType<T>(std::move(*expr))}; |
986 | *expr = Fold(context, Expr<SomeInteger>{std::move(converted)}); |
987 | args.emplace_back(AsGenericExpr(Expr<T>{Scalar<T>{255}})); |
988 | return FoldIntrinsicFunction(context, std::move(funcRef)); |
989 | } |
990 | } |
991 | } else if (name == "lbound" ) { |
992 | return LBOUND(context, std::move(funcRef)); |
993 | } else if (name == "leadz" || name == "trailz" || name == "poppar" || |
994 | name == "popcnt" ) { |
995 | if (auto *sn{UnwrapExpr<Expr<SomeInteger>>(args[0])}) { |
996 | return common::visit( |
997 | [&funcRef, &context, &name](const auto &n) -> Expr<T> { |
998 | using TI = typename std::decay_t<decltype(n)>::Result; |
999 | if (name == "poppar" ) { |
1000 | return FoldElementalIntrinsic<T, TI>(context, std::move(funcRef), |
1001 | ScalarFunc<T, TI>([](const Scalar<TI> &i) -> Scalar<T> { |
1002 | return Scalar<T>{i.POPPAR() ? 1 : 0}; |
1003 | })); |
1004 | } |
1005 | auto fptr{&Scalar<TI>::LEADZ}; |
1006 | if (name == "leadz" ) { // done in fptr definition |
1007 | } else if (name == "trailz" ) { |
1008 | fptr = &Scalar<TI>::TRAILZ; |
1009 | } else if (name == "popcnt" ) { |
1010 | fptr = &Scalar<TI>::POPCNT; |
1011 | } else { |
1012 | common::die( |
1013 | "missing case to fold intrinsic function %s" , name.c_str()); |
1014 | } |
1015 | return FoldElementalIntrinsic<T, TI>(context, std::move(funcRef), |
1016 | // `i` should be declared as `const Scalar<TI>&`. |
1017 | // We declare it as `auto` to workaround an msvc bug: |
1018 | // https://developercommunity.visualstudio.com/t/Regression:-nested-closure-assumes-wrong/10130223 |
1019 | ScalarFunc<T, TI>([&fptr](const auto &i) -> Scalar<T> { |
1020 | return Scalar<T>{std::invoke(fptr, i)}; |
1021 | })); |
1022 | }, |
1023 | sn->u); |
1024 | } else { |
1025 | DIE("leadz argument must be integer" ); |
1026 | } |
1027 | } else if (name == "len" ) { |
1028 | if (auto *charExpr{UnwrapExpr<Expr<SomeCharacter>>(args[0])}) { |
1029 | return common::visit( |
1030 | [&](auto &kx) { |
1031 | if (auto len{kx.LEN()}) { |
1032 | if (IsScopeInvariantExpr(*len)) { |
1033 | return Fold(context, ConvertToType<T>(*std::move(len))); |
1034 | } else { |
1035 | return Expr<T>{std::move(funcRef)}; |
1036 | } |
1037 | } else { |
1038 | return Expr<T>{std::move(funcRef)}; |
1039 | } |
1040 | }, |
1041 | charExpr->u); |
1042 | } else { |
1043 | DIE("len() argument must be of character type" ); |
1044 | } |
1045 | } else if (name == "len_trim" ) { |
1046 | if (auto *charExpr{UnwrapExpr<Expr<SomeCharacter>>(args[0])}) { |
1047 | return common::visit( |
1048 | [&](const auto &kch) -> Expr<T> { |
1049 | using TC = typename std::decay_t<decltype(kch)>::Result; |
1050 | return FoldElementalIntrinsic<T, TC>(context, std::move(funcRef), |
1051 | ScalarFunc<T, TC>{[&FromInt64](const Scalar<TC> &str) { |
1052 | return FromInt64(CharacterUtils<TC::kind>::LEN_TRIM(str)); |
1053 | }}); |
1054 | }, |
1055 | charExpr->u); |
1056 | } else { |
1057 | DIE("len_trim() argument must be of character type" ); |
1058 | } |
1059 | } else if (name == "maskl" || name == "maskr" ) { |
1060 | // Argument can be of any kind but value has to be smaller than BIT_SIZE. |
1061 | // It can be safely converted to Int4 to simplify. |
1062 | const auto fptr{name == "maskl" ? &Scalar<T>::MASKL : &Scalar<T>::MASKR}; |
1063 | return FoldElementalIntrinsic<T, Int4>(context, std::move(funcRef), |
1064 | ScalarFunc<T, Int4>([&fptr](const Scalar<Int4> &places) -> Scalar<T> { |
1065 | return fptr(static_cast<int>(places.ToInt64())); |
1066 | })); |
1067 | } else if (name == "matmul" ) { |
1068 | return FoldMatmul(context, std::move(funcRef)); |
1069 | } else if (name == "max" ) { |
1070 | return FoldMINorMAX(context, std::move(funcRef), Ordering::Greater); |
1071 | } else if (name == "max0" || name == "max1" ) { |
1072 | return RewriteSpecificMINorMAX(context, std::move(funcRef)); |
1073 | } else if (name == "maxexponent" ) { |
1074 | if (auto *sx{UnwrapExpr<Expr<SomeReal>>(args[0])}) { |
1075 | return common::visit( |
1076 | [](const auto &x) { |
1077 | using TR = typename std::decay_t<decltype(x)>::Result; |
1078 | return Expr<T>{Scalar<TR>::MAXEXPONENT}; |
1079 | }, |
1080 | sx->u); |
1081 | } |
1082 | } else if (name == "maxloc" ) { |
1083 | return FoldLocation<WhichLocation::Maxloc, T>(context, std::move(funcRef)); |
1084 | } else if (name == "maxval" ) { |
1085 | return FoldMaxvalMinval<T>(context, std::move(funcRef), |
1086 | RelationalOperator::GT, T::Scalar::Least()); |
1087 | } else if (name == "merge_bits" ) { |
1088 | return FoldElementalIntrinsic<T, T, T, T>( |
1089 | context, std::move(funcRef), &Scalar<T>::MERGE_BITS); |
1090 | } else if (name == "min" ) { |
1091 | return FoldMINorMAX(context, std::move(funcRef), Ordering::Less); |
1092 | } else if (name == "min0" || name == "min1" ) { |
1093 | return RewriteSpecificMINorMAX(context, std::move(funcRef)); |
1094 | } else if (name == "minexponent" ) { |
1095 | if (auto *sx{UnwrapExpr<Expr<SomeReal>>(args[0])}) { |
1096 | return common::visit( |
1097 | [](const auto &x) { |
1098 | using TR = typename std::decay_t<decltype(x)>::Result; |
1099 | return Expr<T>{Scalar<TR>::MINEXPONENT}; |
1100 | }, |
1101 | sx->u); |
1102 | } |
1103 | } else if (name == "minloc" ) { |
1104 | return FoldLocation<WhichLocation::Minloc, T>(context, std::move(funcRef)); |
1105 | } else if (name == "minval" ) { |
1106 | return FoldMaxvalMinval<T>( |
1107 | context, std::move(funcRef), RelationalOperator::LT, T::Scalar::HUGE()); |
1108 | } else if (name == "mod" ) { |
1109 | return FoldElementalIntrinsic<T, T, T>(context, std::move(funcRef), |
1110 | ScalarFuncWithContext<T, T, T>( |
1111 | [](FoldingContext &context, const Scalar<T> &x, |
1112 | const Scalar<T> &y) -> Scalar<T> { |
1113 | auto quotRem{x.DivideSigned(y)}; |
1114 | if (quotRem.divisionByZero) { |
1115 | context.messages().Say("mod() by zero"_warn_en_US ); |
1116 | } else if (quotRem.overflow) { |
1117 | context.messages().Say("mod() folding overflowed"_warn_en_US ); |
1118 | } |
1119 | return quotRem.remainder; |
1120 | })); |
1121 | } else if (name == "modulo" ) { |
1122 | return FoldElementalIntrinsic<T, T, T>(context, std::move(funcRef), |
1123 | ScalarFuncWithContext<T, T, T>([](FoldingContext &context, |
1124 | const Scalar<T> &x, |
1125 | const Scalar<T> &y) -> Scalar<T> { |
1126 | auto result{x.MODULO(y)}; |
1127 | if (result.overflow) { |
1128 | context.messages().Say("modulo() folding overflowed"_warn_en_US ); |
1129 | } |
1130 | return result.value; |
1131 | })); |
1132 | } else if (name == "not" ) { |
1133 | return FoldElementalIntrinsic<T, T>( |
1134 | context, std::move(funcRef), &Scalar<T>::NOT); |
1135 | } else if (name == "precision" ) { |
1136 | if (const auto *cx{UnwrapExpr<Expr<SomeReal>>(args[0])}) { |
1137 | return Expr<T>{common::visit( |
1138 | [](const auto &kx) { |
1139 | return Scalar<ResultType<decltype(kx)>>::PRECISION; |
1140 | }, |
1141 | cx->u)}; |
1142 | } else if (const auto *cx{UnwrapExpr<Expr<SomeComplex>>(args[0])}) { |
1143 | return Expr<T>{common::visit( |
1144 | [](const auto &kx) { |
1145 | return Scalar<typename ResultType<decltype(kx)>::Part>::PRECISION; |
1146 | }, |
1147 | cx->u)}; |
1148 | } |
1149 | } else if (name == "product" ) { |
1150 | return FoldProduct<T>(context, std::move(funcRef), Scalar<T>{1}); |
1151 | } else if (name == "radix" ) { |
1152 | return Expr<T>{2}; |
1153 | } else if (name == "range" ) { |
1154 | if (const auto *cx{UnwrapExpr<Expr<SomeInteger>>(args[0])}) { |
1155 | return Expr<T>{common::visit( |
1156 | [](const auto &kx) { |
1157 | return Scalar<ResultType<decltype(kx)>>::RANGE; |
1158 | }, |
1159 | cx->u)}; |
1160 | } else if (const auto *cx{UnwrapExpr<Expr<SomeReal>>(args[0])}) { |
1161 | return Expr<T>{common::visit( |
1162 | [](const auto &kx) { |
1163 | return Scalar<ResultType<decltype(kx)>>::RANGE; |
1164 | }, |
1165 | cx->u)}; |
1166 | } else if (const auto *cx{UnwrapExpr<Expr<SomeComplex>>(args[0])}) { |
1167 | return Expr<T>{common::visit( |
1168 | [](const auto &kx) { |
1169 | return Scalar<typename ResultType<decltype(kx)>::Part>::RANGE; |
1170 | }, |
1171 | cx->u)}; |
1172 | } |
1173 | } else if (name == "rank" ) { |
1174 | if (const auto *array{UnwrapExpr<Expr<SomeType>>(args[0])}) { |
1175 | if (auto named{ExtractNamedEntity(*array)}) { |
1176 | const Symbol &symbol{named->GetLastSymbol()}; |
1177 | if (IsAssumedRank(symbol)) { |
1178 | // DescriptorInquiry can only be placed in expression of kind |
1179 | // DescriptorInquiry::Result::kind. |
1180 | return ConvertToType<T>(Expr< |
1181 | Type<TypeCategory::Integer, DescriptorInquiry::Result::kind>>{ |
1182 | DescriptorInquiry{*named, DescriptorInquiry::Field::Rank}}); |
1183 | } |
1184 | } |
1185 | return Expr<T>{args[0].value().Rank()}; |
1186 | } |
1187 | return Expr<T>{args[0].value().Rank()}; |
1188 | } else if (name == "selected_char_kind" ) { |
1189 | if (const auto *chCon{UnwrapExpr<Constant<TypeOf<std::string>>>(args[0])}) { |
1190 | if (std::optional<std::string> value{chCon->GetScalarValue()}) { |
1191 | int defaultKind{ |
1192 | context.defaults().GetDefaultKind(TypeCategory::Character)}; |
1193 | return Expr<T>{SelectedCharKind(*value, defaultKind)}; |
1194 | } |
1195 | } |
1196 | } else if (name == "selected_int_kind" ) { |
1197 | if (auto p{ToInt64(args[0])}) { |
1198 | return Expr<T>{context.targetCharacteristics().SelectedIntKind(*p)}; |
1199 | } |
1200 | } else if (name == "selected_logical_kind" ) { |
1201 | if (auto p{ToInt64(args[0])}) { |
1202 | return Expr<T>{context.targetCharacteristics().SelectedLogicalKind(*p)}; |
1203 | } |
1204 | } else if (name == "selected_real_kind" || |
1205 | name == "__builtin_ieee_selected_real_kind" ) { |
1206 | if (auto p{GetInt64ArgOr(args[0], 0)}) { |
1207 | if (auto r{GetInt64ArgOr(args[1], 0)}) { |
1208 | if (auto radix{GetInt64ArgOr(args[2], 2)}) { |
1209 | return Expr<T>{ |
1210 | context.targetCharacteristics().SelectedRealKind(*p, *r, *radix)}; |
1211 | } |
1212 | } |
1213 | } |
1214 | } else if (name == "shape" ) { |
1215 | if (auto shape{GetContextFreeShape(context, args[0])}) { |
1216 | if (auto shapeExpr{AsExtentArrayExpr(*shape)}) { |
1217 | return Fold(context, ConvertToType<T>(std::move(*shapeExpr))); |
1218 | } |
1219 | } |
1220 | } else if (name == "shifta" || name == "shiftr" || name == "shiftl" ) { |
1221 | // Second argument can be of any kind. However, it must be smaller or |
1222 | // equal than BIT_SIZE. It can be converted to Int4 to simplify. |
1223 | auto fptr{&Scalar<T>::SHIFTA}; |
1224 | if (name == "shifta" ) { // done in fptr definition |
1225 | } else if (name == "shiftr" ) { |
1226 | fptr = &Scalar<T>::SHIFTR; |
1227 | } else if (name == "shiftl" ) { |
1228 | fptr = &Scalar<T>::SHIFTL; |
1229 | } else { |
1230 | common::die("missing case to fold intrinsic function %s" , name.c_str()); |
1231 | } |
1232 | if (const auto *argCon{Folder<T>(context).Folding(args[0])}; |
1233 | argCon && argCon->empty()) { |
1234 | } else if (const auto *shiftCon{Folder<Int4>(context).Folding(args[1])}) { |
1235 | for (const auto &scalar : shiftCon->values()) { |
1236 | std::int64_t shiftVal{scalar.ToInt64()}; |
1237 | if (shiftVal < 0) { |
1238 | context.messages().Say("SHIFT=%jd count for %s is negative"_err_en_US , |
1239 | std::intmax_t{shiftVal}, name, -T::Scalar::bits); |
1240 | break; |
1241 | } else if (shiftVal > T::Scalar::bits) { |
1242 | context.messages().Say( |
1243 | "SHIFT=%jd count for %s is greater than %d"_err_en_US , |
1244 | std::intmax_t{shiftVal}, name, T::Scalar::bits); |
1245 | break; |
1246 | } |
1247 | } |
1248 | } |
1249 | return FoldElementalIntrinsic<T, T, Int4>(context, std::move(funcRef), |
1250 | ScalarFunc<T, T, Int4>( |
1251 | [&](const Scalar<T> &i, const Scalar<Int4> &shift) -> Scalar<T> { |
1252 | return std::invoke(fptr, i, static_cast<int>(shift.ToInt64())); |
1253 | })); |
1254 | } else if (name == "sign" ) { |
1255 | return FoldElementalIntrinsic<T, T, T>(context, std::move(funcRef), |
1256 | ScalarFunc<T, T, T>([&context](const Scalar<T> &j, |
1257 | const Scalar<T> &k) -> Scalar<T> { |
1258 | typename Scalar<T>::ValueWithOverflow result{j.SIGN(k)}; |
1259 | if (result.overflow) { |
1260 | context.messages().Say( |
1261 | "sign(integer(kind=%d)) folding overflowed"_warn_en_US , KIND); |
1262 | } |
1263 | return result.value; |
1264 | })); |
1265 | } else if (name == "size" ) { |
1266 | if (auto shape{GetContextFreeShape(context, args[0])}) { |
1267 | if (args[1]) { // DIM= is present, get one extent |
1268 | std::optional<int> dim; |
1269 | if (const auto *array{args[0].value().UnwrapExpr()}; array && |
1270 | !CheckDimArg(args[1], *array, context.messages(), false, dim)) { |
1271 | return MakeInvalidIntrinsic<T>(std::move(funcRef)); |
1272 | } else if (dim) { |
1273 | if (auto &extent{shape->at(*dim)}) { |
1274 | return Fold(context, ConvertToType<T>(std::move(*extent))); |
1275 | } |
1276 | } |
1277 | } else if (auto extents{common::AllElementsPresent(std::move(*shape))}) { |
1278 | // DIM= is absent; compute PRODUCT(SHAPE()) |
1279 | ExtentExpr product{1}; |
1280 | for (auto &&extent : std::move(*extents)) { |
1281 | product = std::move(product) * std::move(extent); |
1282 | } |
1283 | return Expr<T>{ConvertToType<T>(Fold(context, std::move(product)))}; |
1284 | } |
1285 | } |
1286 | } else if (name == "sizeof" ) { // in bytes; extension |
1287 | if (auto info{ |
1288 | characteristics::TypeAndShape::Characterize(args[0], context)}) { |
1289 | if (auto bytes{info->MeasureSizeInBytes(context)}) { |
1290 | return Expr<T>{Fold(context, ConvertToType<T>(std::move(*bytes)))}; |
1291 | } |
1292 | } |
1293 | } else if (name == "storage_size" ) { // in bits |
1294 | if (auto info{ |
1295 | characteristics::TypeAndShape::Characterize(args[0], context)}) { |
1296 | if (auto bytes{info->MeasureElementSizeInBytes(context, true)}) { |
1297 | return Expr<T>{ |
1298 | Fold(context, Expr<T>{8} * ConvertToType<T>(std::move(*bytes)))}; |
1299 | } |
1300 | } |
1301 | } else if (name == "sum" ) { |
1302 | return FoldSum<T>(context, std::move(funcRef)); |
1303 | } else if (name == "ubound" ) { |
1304 | return UBOUND(context, std::move(funcRef)); |
1305 | } else if (name == "__builtin_numeric_storage_size" ) { |
1306 | if (!context.moduleFileName()) { |
1307 | // Don't fold this reference until it appears in the module file |
1308 | // for ISO_FORTRAN_ENV -- the value depends on the compiler options |
1309 | // that might be in force. |
1310 | } else { |
1311 | auto intBytes{ |
1312 | context.targetCharacteristics().GetByteSize(TypeCategory::Integer, |
1313 | context.defaults().GetDefaultKind(TypeCategory::Integer))}; |
1314 | auto realBytes{ |
1315 | context.targetCharacteristics().GetByteSize(TypeCategory::Real, |
1316 | context.defaults().GetDefaultKind(TypeCategory::Real))}; |
1317 | if (intBytes != realBytes) { |
1318 | context.messages().Say(*context.moduleFileName(), |
1319 | "NUMERIC_STORAGE_SIZE from ISO_FORTRAN_ENV is not well-defined when default INTEGER and REAL are not consistent due to compiler options"_warn_en_US ); |
1320 | } |
1321 | return Expr<T>{8 * std::min(intBytes, realBytes)}; |
1322 | } |
1323 | } |
1324 | return Expr<T>{std::move(funcRef)}; |
1325 | } |
1326 | |
1327 | // Substitutes a bare type parameter reference with its value if it has one now |
1328 | // in an instantiation. Bare LEN type parameters are substituted only when |
1329 | // the known value is constant. |
1330 | Expr<TypeParamInquiry::Result> FoldOperation( |
1331 | FoldingContext &context, TypeParamInquiry &&inquiry) { |
1332 | std::optional<NamedEntity> base{inquiry.base()}; |
1333 | parser::CharBlock parameterName{inquiry.parameter().name()}; |
1334 | if (base) { |
1335 | // Handling "designator%typeParam". Get the value of the type parameter |
1336 | // from the instantiation of the base |
1337 | if (const semantics::DeclTypeSpec * |
1338 | declType{base->GetLastSymbol().GetType()}) { |
1339 | if (const semantics::ParamValue * |
1340 | paramValue{ |
1341 | declType->derivedTypeSpec().FindParameter(parameterName)}) { |
1342 | const semantics::MaybeIntExpr ¶mExpr{paramValue->GetExplicit()}; |
1343 | if (paramExpr && IsConstantExpr(*paramExpr)) { |
1344 | Expr<SomeInteger> intExpr{*paramExpr}; |
1345 | return Fold(context, |
1346 | ConvertToType<TypeParamInquiry::Result>(std::move(intExpr))); |
1347 | } |
1348 | } |
1349 | } |
1350 | } else { |
1351 | // A "bare" type parameter: replace with its value, if that's now known |
1352 | // in a current derived type instantiation. |
1353 | if (const auto *pdt{context.pdtInstance()}) { |
1354 | auto restorer{context.WithoutPDTInstance()}; // don't loop |
1355 | bool isLen{false}; |
1356 | if (const semantics::Scope * scope{pdt->scope()}) { |
1357 | auto iter{scope->find(parameterName)}; |
1358 | if (iter != scope->end()) { |
1359 | const Symbol &symbol{*iter->second}; |
1360 | const auto *details{symbol.detailsIf<semantics::TypeParamDetails>()}; |
1361 | if (details) { |
1362 | isLen = details->attr() == common::TypeParamAttr::Len; |
1363 | const semantics::MaybeIntExpr &initExpr{details->init()}; |
1364 | if (initExpr && IsConstantExpr(*initExpr) && |
1365 | (!isLen || ToInt64(*initExpr))) { |
1366 | Expr<SomeInteger> expr{*initExpr}; |
1367 | return Fold(context, |
1368 | ConvertToType<TypeParamInquiry::Result>(std::move(expr))); |
1369 | } |
1370 | } |
1371 | } |
1372 | } |
1373 | if (const auto *value{pdt->FindParameter(parameterName)}) { |
1374 | if (value->isExplicit()) { |
1375 | auto folded{Fold(context, |
1376 | AsExpr(ConvertToType<TypeParamInquiry::Result>( |
1377 | Expr<SomeInteger>{value->GetExplicit().value()})))}; |
1378 | if (!isLen || ToInt64(folded)) { |
1379 | return folded; |
1380 | } |
1381 | } |
1382 | } |
1383 | } |
1384 | } |
1385 | return AsExpr(std::move(inquiry)); |
1386 | } |
1387 | |
1388 | std::optional<std::int64_t> ToInt64(const Expr<SomeInteger> &expr) { |
1389 | return common::visit( |
1390 | [](const auto &kindExpr) { return ToInt64(kindExpr); }, expr.u); |
1391 | } |
1392 | |
1393 | std::optional<std::int64_t> ToInt64(const Expr<SomeType> &expr) { |
1394 | return ToInt64(UnwrapExpr<Expr<SomeInteger>>(expr)); |
1395 | } |
1396 | |
1397 | std::optional<std::int64_t> ToInt64(const ActualArgument &arg) { |
1398 | return ToInt64(arg.UnwrapExpr()); |
1399 | } |
1400 | |
// Explicit instantiation definitions of ExpressionBase for every INTEGER kind
// (expanded by FOR_EACH_INTEGER_KIND) and for SomeInteger.  This causes their
// member definitions to be instantiated in this translation unit.
// MSVC warning C4661 ("no suitable definition provided for explicit template
// instantiation request") is silenced for these instantiations.
#ifdef _MSC_VER // disable bogus warning about missing definitions
#pragma warning(disable : 4661)
#endif
FOR_EACH_INTEGER_KIND(template class ExpressionBase, )
template class ExpressionBase<SomeInteger>;
1406 | } // namespace Fortran::evaluate |
1407 | |