1//===-- lib/Evaluate/fold-implementation.h --------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef FORTRAN_EVALUATE_FOLD_IMPLEMENTATION_H_
10#define FORTRAN_EVALUATE_FOLD_IMPLEMENTATION_H_
11
12#include "character.h"
13#include "host.h"
14#include "int-power.h"
15#include "flang/Common/indirection.h"
16#include "flang/Common/template.h"
17#include "flang/Common/unwrap.h"
18#include "flang/Evaluate/characteristics.h"
19#include "flang/Evaluate/common.h"
20#include "flang/Evaluate/constant.h"
21#include "flang/Evaluate/expression.h"
22#include "flang/Evaluate/fold.h"
23#include "flang/Evaluate/formatting.h"
24#include "flang/Evaluate/intrinsics-library.h"
25#include "flang/Evaluate/intrinsics.h"
26#include "flang/Evaluate/shape.h"
27#include "flang/Evaluate/tools.h"
28#include "flang/Evaluate/traverse.h"
29#include "flang/Evaluate/type.h"
30#include "flang/Parser/message.h"
31#include "flang/Semantics/scope.h"
32#include "flang/Semantics/symbol.h"
33#include "flang/Semantics/tools.h"
34#include <algorithm>
35#include <cmath>
36#include <complex>
37#include <cstdio>
38#include <optional>
39#include <type_traits>
40#include <variant>
41
42// Some environments, viz. glibc 2.17 and *BSD, allow the macro HUGE
43// to leak out of <math.h>.
44#undef HUGE
45
46namespace Fortran::evaluate {
47
48// Utilities
49template <typename T> class Folder {
50public:
51 explicit Folder(FoldingContext &c) : context_{c} {}
52 std::optional<Constant<T>> GetNamedConstant(const Symbol &);
53 std::optional<Constant<T>> ApplySubscripts(const Constant<T> &array,
54 const std::vector<Constant<SubscriptInteger>> &subscripts);
55 std::optional<Constant<T>> ApplyComponent(Constant<SomeDerived> &&,
56 const Symbol &component,
57 const std::vector<Constant<SubscriptInteger>> * = nullptr);
58 std::optional<Constant<T>> GetConstantComponent(
59 Component &, const std::vector<Constant<SubscriptInteger>> * = nullptr);
60 std::optional<Constant<T>> Folding(ArrayRef &);
61 std::optional<Constant<T>> Folding(DataRef &);
62 Expr<T> Folding(Designator<T> &&);
63 Constant<T> *Folding(std::optional<ActualArgument> &);
64
65 Expr<T> CSHIFT(FunctionRef<T> &&);
66 Expr<T> EOSHIFT(FunctionRef<T> &&);
67 Expr<T> MERGE(FunctionRef<T> &&);
68 Expr<T> PACK(FunctionRef<T> &&);
69 Expr<T> RESHAPE(FunctionRef<T> &&);
70 Expr<T> SPREAD(FunctionRef<T> &&);
71 Expr<T> TRANSPOSE(FunctionRef<T> &&);
72 Expr<T> UNPACK(FunctionRef<T> &&);
73
74 Expr<T> TRANSFER(FunctionRef<T> &&);
75
76private:
77 FoldingContext &context_;
78};
79
80std::optional<Constant<SubscriptInteger>> GetConstantSubscript(
81 FoldingContext &, Subscript &, const NamedEntity &, int dim);
82
83// Helper to use host runtime on scalars for folding.
84template <typename TR, typename... TA>
85std::optional<std::function<Scalar<TR>(FoldingContext &, Scalar<TA>...)>>
86GetHostRuntimeWrapper(const std::string &name) {
87 std::vector<DynamicType> argTypes{TA{}.GetType()...};
88 if (auto hostWrapper{GetHostRuntimeWrapper(name, TR{}.GetType(), argTypes)}) {
89 return [hostWrapper](
90 FoldingContext &context, Scalar<TA>... args) -> Scalar<TR> {
91 std::vector<Expr<SomeType>> genericArgs{
92 AsGenericExpr(Constant<TA>{args})...};
93 return GetScalarConstantValue<TR>(
94 (*hostWrapper)(context, std::move(genericArgs)))
95 .value();
96 };
97 }
98 return std::nullopt;
99}
100
101// FoldOperation() rewrites expression tree nodes.
102// If there is any possibility that the rewritten node will
103// not have the same representation type, the result of
104// FoldOperation() will be packaged in an Expr<> of the same
105// specific type.
106
107// no-op base case
108template <typename A>
109common::IfNoLvalue<Expr<ResultType<A>>, A> FoldOperation(
110 FoldingContext &, A &&x) {
111 static_assert(!std::is_same_v<A, Expr<ResultType<A>>>,
112 "call Fold() instead for Expr<>");
113 return Expr<ResultType<A>>{std::move(x)};
114}
115
116Component FoldOperation(FoldingContext &, Component &&);
117NamedEntity FoldOperation(FoldingContext &, NamedEntity &&);
118Triplet FoldOperation(FoldingContext &, Triplet &&);
119Subscript FoldOperation(FoldingContext &, Subscript &&);
120ArrayRef FoldOperation(FoldingContext &, ArrayRef &&);
121CoarrayRef FoldOperation(FoldingContext &, CoarrayRef &&);
122DataRef FoldOperation(FoldingContext &, DataRef &&);
123Substring FoldOperation(FoldingContext &, Substring &&);
124ComplexPart FoldOperation(FoldingContext &, ComplexPart &&);
125template <typename T>
126Expr<T> FoldOperation(FoldingContext &, FunctionRef<T> &&);
127template <typename T>
128Expr<T> FoldOperation(FoldingContext &context, Designator<T> &&designator) {
129 return Folder<T>{context}.Folding(std::move(designator));
130}
131Expr<TypeParamInquiry::Result> FoldOperation(
132 FoldingContext &, TypeParamInquiry &&);
133Expr<ImpliedDoIndex::Result> FoldOperation(
134 FoldingContext &context, ImpliedDoIndex &&);
135template <typename T>
136Expr<T> FoldOperation(FoldingContext &, ArrayConstructor<T> &&);
137Expr<SomeDerived> FoldOperation(FoldingContext &, StructureConstructor &&);
138
139template <typename T>
140std::optional<Constant<T>> Folder<T>::GetNamedConstant(const Symbol &symbol0) {
141 const Symbol &symbol{ResolveAssociations(symbol0)};
142 if (IsNamedConstant(symbol)) {
143 if (const auto *object{
144 symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
145 if (const auto *constant{UnwrapConstantValue<T>(object->init())}) {
146 return *constant;
147 }
148 }
149 }
150 return std::nullopt;
151}
152
153template <typename T>
154std::optional<Constant<T>> Folder<T>::Folding(ArrayRef &aRef) {
155 std::vector<Constant<SubscriptInteger>> subscripts;
156 int dim{0};
157 for (Subscript &ss : aRef.subscript()) {
158 if (auto constant{GetConstantSubscript(context_, ss, aRef.base(), dim++)}) {
159 subscripts.emplace_back(std::move(*constant));
160 } else {
161 return std::nullopt;
162 }
163 }
164 if (Component * component{aRef.base().UnwrapComponent()}) {
165 return GetConstantComponent(*component, &subscripts);
166 } else if (std::optional<Constant<T>> array{
167 GetNamedConstant(aRef.base().GetLastSymbol())}) {
168 return ApplySubscripts(*array, subscripts);
169 } else {
170 return std::nullopt;
171 }
172}
173
174template <typename T>
175std::optional<Constant<T>> Folder<T>::Folding(DataRef &ref) {
176 return common::visit(
177 common::visitors{
178 [this](SymbolRef &sym) { return GetNamedConstant(*sym); },
179 [this](Component &comp) {
180 comp = FoldOperation(context_, std::move(comp));
181 return GetConstantComponent(comp);
182 },
183 [this](ArrayRef &aRef) {
184 aRef = FoldOperation(context_, std::move(aRef));
185 return Folding(aRef);
186 },
187 [](CoarrayRef &) { return std::optional<Constant<T>>{}; },
188 },
189 ref.u);
190}
191
192// TODO: This would be more natural as a member function of Constant<T>.
193template <typename T>
194std::optional<Constant<T>> Folder<T>::ApplySubscripts(const Constant<T> &array,
195 const std::vector<Constant<SubscriptInteger>> &subscripts) {
196 const auto &shape{array.shape()};
197 const auto &lbounds{array.lbounds()};
198 int rank{GetRank(shape)};
199 CHECK(rank == static_cast<int>(subscripts.size()));
200 std::size_t elements{1};
201 ConstantSubscripts resultShape;
202 ConstantSubscripts ssLB;
203 for (const auto &ss : subscripts) {
204 if (ss.Rank() == 1) {
205 resultShape.push_back(static_cast<ConstantSubscript>(ss.size()));
206 elements *= ss.size();
207 ssLB.push_back(ss.lbounds().front());
208 } else if (ss.Rank() > 1) {
209 return std::nullopt; // error recovery
210 }
211 }
212 ConstantSubscripts ssAt(rank, 0), at(rank, 0), tmp(1, 0);
213 std::vector<Scalar<T>> values;
214 while (elements-- > 0) {
215 bool increment{true};
216 int k{0};
217 for (int j{0}; j < rank; ++j) {
218 if (subscripts[j].Rank() == 0) {
219 at[j] = subscripts[j].GetScalarValue().value().ToInt64();
220 } else {
221 CHECK(k < GetRank(resultShape));
222 tmp[0] = ssLB.at(k) + ssAt.at(k);
223 at[j] = subscripts[j].At(tmp).ToInt64();
224 if (increment) {
225 if (++ssAt[k] == resultShape[k]) {
226 ssAt[k] = 0;
227 } else {
228 increment = false;
229 }
230 }
231 ++k;
232 }
233 if (at[j] < lbounds[j] || at[j] >= lbounds[j] + shape[j]) {
234 context_.messages().Say(
235 "Subscript value (%jd) is out of range on dimension %d in reference to a constant array value"_err_en_US,
236 at[j], j + 1);
237 return std::nullopt;
238 }
239 }
240 values.emplace_back(array.At(at));
241 CHECK(!increment || elements == 0);
242 CHECK(k == GetRank(resultShape));
243 }
244 if constexpr (T::category == TypeCategory::Character) {
245 return Constant<T>{array.LEN(), std::move(values), std::move(resultShape)};
246 } else if constexpr (std::is_same_v<T, SomeDerived>) {
247 return Constant<T>{array.result().derivedTypeSpec(), std::move(values),
248 std::move(resultShape)};
249 } else {
250 return Constant<T>{std::move(values), std::move(resultShape)};
251 }
252}
253
254template <typename T>
255std::optional<Constant<T>> Folder<T>::ApplyComponent(
256 Constant<SomeDerived> &&structures, const Symbol &component,
257 const std::vector<Constant<SubscriptInteger>> *subscripts) {
258 if (auto scalar{structures.GetScalarValue()}) {
259 if (std::optional<Expr<SomeType>> expr{scalar->Find(component)}) {
260 if (const Constant<T> *value{UnwrapConstantValue<T>(*expr)}) {
261 if (subscripts) {
262 return ApplySubscripts(*value, *subscripts);
263 } else {
264 return *value;
265 }
266 }
267 }
268 } else {
269 // A(:)%scalar_component & A(:)%array_component(subscripts)
270 std::unique_ptr<ArrayConstructor<T>> array;
271 if (structures.empty()) {
272 return std::nullopt;
273 }
274 ConstantSubscripts at{structures.lbounds()};
275 do {
276 StructureConstructor scalar{structures.At(at)};
277 if (std::optional<Expr<SomeType>> expr{scalar.Find(component)}) {
278 if (const Constant<T> *value{UnwrapConstantValue<T>(expr.value())}) {
279 if (!array.get()) {
280 // This technique ensures that character length or derived type
281 // information is propagated to the array constructor.
282 auto *typedExpr{UnwrapExpr<Expr<T>>(expr.value())};
283 CHECK(typedExpr);
284 array = std::make_unique<ArrayConstructor<T>>(*typedExpr);
285 }
286 if (subscripts) {
287 if (auto element{ApplySubscripts(*value, *subscripts)}) {
288 CHECK(element->Rank() == 0);
289 array->Push(Expr<T>{std::move(*element)});
290 } else {
291 return std::nullopt;
292 }
293 } else {
294 CHECK(value->Rank() == 0);
295 array->Push(Expr<T>{*value});
296 }
297 } else {
298 return std::nullopt;
299 }
300 }
301 } while (structures.IncrementSubscripts(at));
302 // Fold the ArrayConstructor<> into a Constant<>.
303 CHECK(array);
304 Expr<T> result{Fold(context_, Expr<T>{std::move(*array)})};
305 if (auto *constant{UnwrapConstantValue<T>(result)}) {
306 return constant->Reshape(common::Clone(structures.shape()));
307 }
308 }
309 return std::nullopt;
310}
311
312template <typename T>
313std::optional<Constant<T>> Folder<T>::GetConstantComponent(Component &component,
314 const std::vector<Constant<SubscriptInteger>> *subscripts) {
315 if (std::optional<Constant<SomeDerived>> structures{common::visit(
316 common::visitors{
317 [&](const Symbol &symbol) {
318 return Folder<SomeDerived>{context_}.GetNamedConstant(symbol);
319 },
320 [&](ArrayRef &aRef) {
321 return Folder<SomeDerived>{context_}.Folding(aRef);
322 },
323 [&](Component &base) {
324 return Folder<SomeDerived>{context_}.GetConstantComponent(base);
325 },
326 [&](CoarrayRef &) {
327 return std::optional<Constant<SomeDerived>>{};
328 },
329 },
330 component.base().u)}) {
331 return ApplyComponent(
332 std::move(*structures), component.GetLastSymbol(), subscripts);
333 } else {
334 return std::nullopt;
335 }
336}
337
338template <typename T> Expr<T> Folder<T>::Folding(Designator<T> &&designator) {
339 if constexpr (T::category == TypeCategory::Character) {
340 if (auto *substring{common::Unwrap<Substring>(designator.u)}) {
341 if (std::optional<Expr<SomeCharacter>> folded{
342 substring->Fold(context_)}) {
343 if (const auto *specific{std::get_if<Expr<T>>(&folded->u)}) {
344 return std::move(*specific);
345 }
346 }
347 // We used to fold zero-length substrings into zero-length
348 // constants here, but that led to problems in variable
349 // definition contexts.
350 }
351 } else if constexpr (T::category == TypeCategory::Real) {
352 if (auto *zPart{std::get_if<ComplexPart>(&designator.u)}) {
353 *zPart = FoldOperation(context_, std::move(*zPart));
354 using ComplexT = Type<TypeCategory::Complex, T::kind>;
355 if (auto zConst{Folder<ComplexT>{context_}.Folding(zPart->complex())}) {
356 return Fold(context_,
357 Expr<T>{ComplexComponent<T::kind>{
358 zPart->part() == ComplexPart::Part::IM,
359 Expr<ComplexT>{std::move(*zConst)}}});
360 } else {
361 return Expr<T>{Designator<T>{std::move(*zPart)}};
362 }
363 }
364 }
365 return common::visit(
366 common::visitors{
367 [&](SymbolRef &&symbol) {
368 if (auto constant{GetNamedConstant(*symbol)}) {
369 return Expr<T>{std::move(*constant)};
370 }
371 return Expr<T>{std::move(designator)};
372 },
373 [&](ArrayRef &&aRef) {
374 aRef = FoldOperation(context_, std::move(aRef));
375 if (auto c{Folding(aRef)}) {
376 return Expr<T>{std::move(*c)};
377 } else {
378 return Expr<T>{Designator<T>{std::move(aRef)}};
379 }
380 },
381 [&](Component &&component) {
382 component = FoldOperation(context_, std::move(component));
383 if (auto c{GetConstantComponent(component)}) {
384 return Expr<T>{std::move(*c)};
385 } else {
386 return Expr<T>{Designator<T>{std::move(component)}};
387 }
388 },
389 [&](auto &&x) {
390 return Expr<T>{
391 Designator<T>{FoldOperation(context_, std::move(x))}};
392 },
393 },
394 std::move(designator.u));
395}
396
397// Apply type conversion and re-folding if necessary.
398// This is where BOZ arguments are converted.
399template <typename T>
400Constant<T> *Folder<T>::Folding(std::optional<ActualArgument> &arg) {
401 if (auto *expr{UnwrapExpr<Expr<SomeType>>(arg)}) {
402 if constexpr (T::category != TypeCategory::Derived) {
403 if (!UnwrapExpr<Expr<T>>(*expr)) {
404 if (auto converted{ConvertToType(T::GetType(), std::move(*expr))}) {
405 *expr = Fold(context_, std::move(*converted));
406 }
407 }
408 }
409 return UnwrapConstantValue<T>(*expr);
410 }
411 return nullptr;
412}
413
414template <typename... A, std::size_t... I>
415std::optional<std::tuple<const Constant<A> *...>> GetConstantArgumentsHelper(
416 FoldingContext &context, ActualArguments &arguments,
417 std::index_sequence<I...>) {
418 static_assert(sizeof...(A) > 0);
419 std::tuple<const Constant<A> *...> args{
420 Folder<A>{context}.Folding(arguments.at(I))...};
421 if ((... && (std::get<I>(args)))) {
422 return args;
423 } else {
424 return std::nullopt;
425 }
426}
427
428template <typename... A>
429std::optional<std::tuple<const Constant<A> *...>> GetConstantArguments(
430 FoldingContext &context, ActualArguments &args) {
431 return GetConstantArgumentsHelper<A...>(
432 context, args, std::index_sequence_for<A...>{});
433}
434
435template <typename... A, std::size_t... I>
436std::optional<std::tuple<Scalar<A>...>> GetScalarConstantArgumentsHelper(
437 FoldingContext &context, ActualArguments &args, std::index_sequence<I...>) {
438 if (auto constArgs{GetConstantArguments<A...>(context, args)}) {
439 return std::tuple<Scalar<A>...>{
440 std::get<I>(*constArgs)->GetScalarValue().value()...};
441 } else {
442 return std::nullopt;
443 }
444}
445
446template <typename... A>
447std::optional<std::tuple<Scalar<A>...>> GetScalarConstantArguments(
448 FoldingContext &context, ActualArguments &args) {
449 return GetScalarConstantArgumentsHelper<A...>(
450 context, args, std::index_sequence_for<A...>{});
451}
452
453// helpers to fold intrinsic function references
454// Define callable types used in a common utility that
455// takes care of array and cast/conversion aspects for elemental intrinsics
456
457template <typename TR, typename... TArgs>
458using ScalarFunc = std::function<Scalar<TR>(const Scalar<TArgs> &...)>;
459template <typename TR, typename... TArgs>
460using ScalarFuncWithContext =
461 std::function<Scalar<TR>(FoldingContext &, const Scalar<TArgs> &...)>;
462
463template <template <typename, typename...> typename WrapperType, typename TR,
464 typename... TA, std::size_t... I>
465Expr<TR> FoldElementalIntrinsicHelper(FoldingContext &context,
466 FunctionRef<TR> &&funcRef, WrapperType<TR, TA...> func,
467 std::index_sequence<I...>) {
468 if (std::optional<std::tuple<const Constant<TA> *...>> args{
469 GetConstantArguments<TA...>(context, funcRef.arguments())}) {
470 // Compute the shape of the result based on shapes of arguments
471 ConstantSubscripts shape;
472 int rank{0};
473 const ConstantSubscripts *shapes[]{&std::get<I>(*args)->shape()...};
474 const int ranks[]{std::get<I>(*args)->Rank()...};
475 for (unsigned int i{0}; i < sizeof...(TA); ++i) {
476 if (ranks[i] > 0) {
477 if (rank == 0) {
478 rank = ranks[i];
479 shape = *shapes[i];
480 } else {
481 if (shape != *shapes[i]) {
482 // TODO: Rank compatibility was already checked but it seems to be
483 // the first place where the actual shapes are checked to be the
484 // same. Shouldn't this be checked elsewhere so that this is also
485 // checked for non constexpr call to elemental intrinsics function?
486 context.messages().Say(
487 "Arguments in elemental intrinsic function are not conformable"_err_en_US);
488 return Expr<TR>{std::move(funcRef)};
489 }
490 }
491 }
492 }
493 CHECK(rank == GetRank(shape));
494 // Compute all the scalar values of the results
495 std::vector<Scalar<TR>> results;
496 std::optional<uint64_t> n{TotalElementCount(shape)};
497 if (!n) {
498 context.messages().Say(
499 "Too many elements in elemental intrinsic function result"_err_en_US);
500 return Expr<TR>{std::move(funcRef)};
501 }
502 if (*n > 0) {
503 ConstantBounds bounds{shape};
504 ConstantSubscripts resultIndex(rank, 1);
505 ConstantSubscripts argIndex[]{std::get<I>(*args)->lbounds()...};
506 do {
507 if constexpr (std::is_same_v<WrapperType<TR, TA...>,
508 ScalarFuncWithContext<TR, TA...>>) {
509 results.emplace_back(
510 func(context, std::get<I>(*args)->At(argIndex[I])...));
511 } else if constexpr (std::is_same_v<WrapperType<TR, TA...>,
512 ScalarFunc<TR, TA...>>) {
513 results.emplace_back(func(std::get<I>(*args)->At(argIndex[I])...));
514 }
515 (std::get<I>(*args)->IncrementSubscripts(argIndex[I]), ...);
516 } while (bounds.IncrementSubscripts(resultIndex));
517 }
518 // Build and return constant result
519 if constexpr (TR::category == TypeCategory::Character) {
520 auto len{static_cast<ConstantSubscript>(
521 results.empty() ? 0 : results[0].length())};
522 return Expr<TR>{Constant<TR>{len, std::move(results), std::move(shape)}};
523 } else if constexpr (TR::category == TypeCategory::Derived) {
524 if (!results.empty()) {
525 return Expr<TR>{rank == 0
526 ? Constant<TR>{results.front()}
527 : Constant<TR>{results.front().derivedTypeSpec(),
528 std::move(results), std::move(shape)}};
529 }
530 } else {
531 return Expr<TR>{Constant<TR>{std::move(results), std::move(shape)}};
532 }
533 }
534 return Expr<TR>{std::move(funcRef)};
535}
536
537template <typename TR, typename... TA>
538Expr<TR> FoldElementalIntrinsic(FoldingContext &context,
539 FunctionRef<TR> &&funcRef, ScalarFunc<TR, TA...> func) {
540 return FoldElementalIntrinsicHelper<ScalarFunc, TR, TA...>(
541 context, std::move(funcRef), func, std::index_sequence_for<TA...>{});
542}
543template <typename TR, typename... TA>
544Expr<TR> FoldElementalIntrinsic(FoldingContext &context,
545 FunctionRef<TR> &&funcRef, ScalarFuncWithContext<TR, TA...> func) {
546 return FoldElementalIntrinsicHelper<ScalarFuncWithContext, TR, TA...>(
547 context, std::move(funcRef), func, std::index_sequence_for<TA...>{});
548}
549
550std::optional<std::int64_t> GetInt64ArgOr(
551 const std::optional<ActualArgument> &, std::int64_t defaultValue);
552
553template <typename A, typename B>
554std::optional<std::vector<A>> GetIntegerVector(const B &x) {
555 static_assert(std::is_integral_v<A>);
556 if (const auto *someInteger{UnwrapExpr<Expr<SomeInteger>>(x)}) {
557 return common::visit(
558 [](const auto &typedExpr) -> std::optional<std::vector<A>> {
559 using T = ResultType<decltype(typedExpr)>;
560 if (const auto *constant{UnwrapConstantValue<T>(typedExpr)}) {
561 if (constant->Rank() == 1) {
562 std::vector<A> result;
563 for (const auto &value : constant->values()) {
564 result.push_back(static_cast<A>(value.ToInt64()));
565 }
566 return result;
567 }
568 }
569 return std::nullopt;
570 },
571 someInteger->u);
572 }
573 return std::nullopt;
574}
575
576// Transform an intrinsic function reference that contains user errors
577// into an intrinsic with the same characteristic but the "invalid" name.
578// This to prevent generating warnings over and over if the expression
579// gets re-folded.
580template <typename T> Expr<T> MakeInvalidIntrinsic(FunctionRef<T> &&funcRef) {
581 SpecificIntrinsic invalid{std::get<SpecificIntrinsic>(funcRef.proc().u)};
582 invalid.name = IntrinsicProcTable::InvalidName;
583 return Expr<T>{FunctionRef<T>{ProcedureDesignator{std::move(invalid)},
584 ActualArguments{std::move(funcRef.arguments())}}};
585}
586
587template <typename T> Expr<T> Folder<T>::CSHIFT(FunctionRef<T> &&funcRef) {
588 auto args{funcRef.arguments()};
589 CHECK(args.size() == 3);
590 const auto *array{UnwrapConstantValue<T>(args[0])};
591 const auto *shiftExpr{UnwrapExpr<Expr<SomeInteger>>(args[1])};
592 auto dim{GetInt64ArgOr(args[2], 1)};
593 if (!array || !shiftExpr || !dim) {
594 return Expr<T>{std::move(funcRef)};
595 }
596 auto convertedShift{Fold(context_,
597 ConvertToType<SubscriptInteger>(Expr<SomeInteger>{*shiftExpr}))};
598 const auto *shift{UnwrapConstantValue<SubscriptInteger>(convertedShift)};
599 if (!shift) {
600 return Expr<T>{std::move(funcRef)};
601 }
602 // Arguments are constant
603 if (*dim < 1 || *dim > array->Rank()) {
604 context_.messages().Say("Invalid 'dim=' argument (%jd) in CSHIFT"_err_en_US,
605 static_cast<std::intmax_t>(*dim));
606 } else if (shift->Rank() > 0 && shift->Rank() != array->Rank() - 1) {
607 // message already emitted from intrinsic look-up
608 } else {
609 int rank{array->Rank()};
610 int zbDim{static_cast<int>(*dim) - 1};
611 bool ok{true};
612 if (shift->Rank() > 0) {
613 int k{0};
614 for (int j{0}; j < rank; ++j) {
615 if (j != zbDim) {
616 if (array->shape()[j] != shift->shape()[k]) {
617 context_.messages().Say(
618 "Invalid 'shift=' argument in CSHIFT: extent on dimension %d is %jd but must be %jd"_err_en_US,
619 k + 1, static_cast<std::intmax_t>(shift->shape()[k]),
620 static_cast<std::intmax_t>(array->shape()[j]));
621 ok = false;
622 }
623 ++k;
624 }
625 }
626 }
627 if (ok) {
628 std::vector<Scalar<T>> resultElements;
629 ConstantSubscripts arrayLB{array->lbounds()};
630 ConstantSubscripts arrayAt{arrayLB};
631 ConstantSubscript &dimIndex{arrayAt[zbDim]};
632 ConstantSubscript dimLB{dimIndex}; // initial value
633 ConstantSubscript dimExtent{array->shape()[zbDim]};
634 ConstantSubscripts shiftLB{shift->lbounds()};
635 for (auto n{GetSize(array->shape())}; n > 0; --n) {
636 ConstantSubscript origDimIndex{dimIndex};
637 ConstantSubscripts shiftAt;
638 if (shift->Rank() > 0) {
639 int k{0};
640 for (int j{0}; j < rank; ++j) {
641 if (j != zbDim) {
642 shiftAt.emplace_back(shiftLB[k++] + arrayAt[j] - arrayLB[j]);
643 }
644 }
645 }
646 ConstantSubscript shiftCount{shift->At(shiftAt).ToInt64()};
647 dimIndex = dimLB + ((dimIndex - dimLB + shiftCount) % dimExtent);
648 if (dimIndex < dimLB) {
649 dimIndex += dimExtent;
650 } else if (dimIndex >= dimLB + dimExtent) {
651 dimIndex -= dimExtent;
652 }
653 resultElements.push_back(array->At(arrayAt));
654 dimIndex = origDimIndex;
655 array->IncrementSubscripts(arrayAt);
656 }
657 return Expr<T>{PackageConstant<T>(
658 std::move(resultElements), *array, array->shape())};
659 }
660 }
661 // Invalid, prevent re-folding
662 return MakeInvalidIntrinsic(std::move(funcRef));
663}
664
665template <typename T> Expr<T> Folder<T>::EOSHIFT(FunctionRef<T> &&funcRef) {
666 auto args{funcRef.arguments()};
667 CHECK(args.size() == 4);
668 const auto *array{UnwrapConstantValue<T>(args[0])};
669 const auto *shiftExpr{UnwrapExpr<Expr<SomeInteger>>(args[1])};
670 auto dim{GetInt64ArgOr(args[3], 1)};
671 if (!array || !shiftExpr || !dim) {
672 return Expr<T>{std::move(funcRef)};
673 }
674 // Apply type conversions to the shift= and boundary= arguments.
675 auto convertedShift{Fold(context_,
676 ConvertToType<SubscriptInteger>(Expr<SomeInteger>{*shiftExpr}))};
677 const auto *shift{UnwrapConstantValue<SubscriptInteger>(convertedShift)};
678 if (!shift) {
679 return Expr<T>{std::move(funcRef)};
680 }
681 const Constant<T> *boundary{nullptr};
682 std::optional<Expr<SomeType>> convertedBoundary;
683 if (const auto *boundaryExpr{UnwrapExpr<Expr<SomeType>>(args[2])}) {
684 convertedBoundary = Fold(context_,
685 ConvertToType(array->GetType(), Expr<SomeType>{*boundaryExpr}));
686 boundary = UnwrapExpr<Constant<T>>(convertedBoundary);
687 if (!boundary) {
688 return Expr<T>{std::move(funcRef)};
689 }
690 }
691 // Arguments are constant
692 if (*dim < 1 || *dim > array->Rank()) {
693 context_.messages().Say(
694 "Invalid 'dim=' argument (%jd) in EOSHIFT"_err_en_US,
695 static_cast<std::intmax_t>(*dim));
696 } else if (shift->Rank() > 0 && shift->Rank() != array->Rank() - 1) {
697 // message already emitted from intrinsic look-up
698 } else if (boundary && boundary->Rank() > 0 &&
699 boundary->Rank() != array->Rank() - 1) {
700 // ditto
701 } else {
702 int rank{array->Rank()};
703 int zbDim{static_cast<int>(*dim) - 1};
704 bool ok{true};
705 if (shift->Rank() > 0) {
706 int k{0};
707 for (int j{0}; j < rank; ++j) {
708 if (j != zbDim) {
709 if (array->shape()[j] != shift->shape()[k]) {
710 context_.messages().Say(
711 "Invalid 'shift=' argument in EOSHIFT: extent on dimension %d is %jd but must be %jd"_err_en_US,
712 k + 1, static_cast<std::intmax_t>(shift->shape()[k]),
713 static_cast<std::intmax_t>(array->shape()[j]));
714 ok = false;
715 }
716 ++k;
717 }
718 }
719 }
720 if (boundary && boundary->Rank() > 0) {
721 int k{0};
722 for (int j{0}; j < rank; ++j) {
723 if (j != zbDim) {
724 if (array->shape()[j] != boundary->shape()[k]) {
725 context_.messages().Say(
726 "Invalid 'boundary=' argument in EOSHIFT: extent on dimension %d is %jd but must be %jd"_err_en_US,
727 k + 1, static_cast<std::intmax_t>(boundary->shape()[k]),
728 static_cast<std::intmax_t>(array->shape()[j]));
729 ok = false;
730 }
731 ++k;
732 }
733 }
734 }
735 if (ok) {
736 std::vector<Scalar<T>> resultElements;
737 ConstantSubscripts arrayLB{array->lbounds()};
738 ConstantSubscripts arrayAt{arrayLB};
739 ConstantSubscript &dimIndex{arrayAt[zbDim]};
740 ConstantSubscript dimLB{dimIndex}; // initial value
741 ConstantSubscript dimExtent{array->shape()[zbDim]};
742 ConstantSubscripts shiftLB{shift->lbounds()};
743 ConstantSubscripts boundaryLB;
744 if (boundary) {
745 boundaryLB = boundary->lbounds();
746 }
747 for (auto n{GetSize(array->shape())}; n > 0; --n) {
748 ConstantSubscript origDimIndex{dimIndex};
749 ConstantSubscripts shiftAt;
750 if (shift->Rank() > 0) {
751 int k{0};
752 for (int j{0}; j < rank; ++j) {
753 if (j != zbDim) {
754 shiftAt.emplace_back(shiftLB[k++] + arrayAt[j] - arrayLB[j]);
755 }
756 }
757 }
758 ConstantSubscript shiftCount{shift->At(shiftAt).ToInt64()};
759 dimIndex += shiftCount;
760 if (dimIndex >= dimLB && dimIndex < dimLB + dimExtent) {
761 resultElements.push_back(array->At(arrayAt));
762 } else if (boundary) {
763 ConstantSubscripts boundaryAt;
764 if (boundary->Rank() > 0) {
765 for (int j{0}; j < rank; ++j) {
766 int k{0};
767 if (j != zbDim) {
768 boundaryAt.emplace_back(
769 boundaryLB[k++] + arrayAt[j] - arrayLB[j]);
770 }
771 }
772 }
773 resultElements.push_back(boundary->At(boundaryAt));
774 } else if constexpr (T::category == TypeCategory::Integer ||
775 T::category == TypeCategory::Real ||
776 T::category == TypeCategory::Complex ||
777 T::category == TypeCategory::Logical) {
778 resultElements.emplace_back();
779 } else if constexpr (T::category == TypeCategory::Character) {
780 auto len{static_cast<std::size_t>(array->LEN())};
781 typename Scalar<T>::value_type space{' '};
782 resultElements.emplace_back(len, space);
783 } else {
784 DIE("no derived type boundary");
785 }
786 dimIndex = origDimIndex;
787 array->IncrementSubscripts(arrayAt);
788 }
789 return Expr<T>{PackageConstant<T>(
790 std::move(resultElements), *array, array->shape())};
791 }
792 }
793 // Invalid, prevent re-folding
794 return MakeInvalidIntrinsic(std::move(funcRef));
795}
796
797template <typename T> Expr<T> Folder<T>::MERGE(FunctionRef<T> &&funcRef) {
798 return FoldElementalIntrinsic<T, T, T, LogicalResult>(context_,
799 std::move(funcRef),
800 ScalarFunc<T, T, T, LogicalResult>(
801 [](const Scalar<T> &ifTrue, const Scalar<T> &ifFalse,
802 const Scalar<LogicalResult> &predicate) -> Scalar<T> {
803 return predicate.IsTrue() ? ifTrue : ifFalse;
804 }));
805}
806
807template <typename T> Expr<T> Folder<T>::PACK(FunctionRef<T> &&funcRef) {
808 auto args{funcRef.arguments()};
809 CHECK(args.size() == 3);
810 const auto *array{UnwrapConstantValue<T>(args[0])};
811 const auto *vector{UnwrapConstantValue<T>(args[2])};
812 auto convertedMask{Fold(context_,
813 ConvertToType<LogicalResult>(
814 Expr<SomeLogical>{DEREF(UnwrapExpr<Expr<SomeLogical>>(args[1]))}))};
815 const auto *mask{UnwrapConstantValue<LogicalResult>(convertedMask)};
816 if (!array || !mask || (args[2] && !vector)) {
817 return Expr<T>{std::move(funcRef)};
818 }
819 // Arguments are constant.
820 ConstantSubscript arrayElements{GetSize(array->shape())};
821 ConstantSubscript truths{0};
822 ConstantSubscripts maskAt{mask->lbounds()};
823 if (mask->Rank() == 0) {
824 if (mask->At(maskAt).IsTrue()) {
825 truths = arrayElements;
826 }
827 } else if (array->shape() != mask->shape()) {
828 // Error already emitted from intrinsic processing
829 return MakeInvalidIntrinsic(std::move(funcRef));
830 } else {
831 for (ConstantSubscript j{0}; j < arrayElements;
832 ++j, mask->IncrementSubscripts(maskAt)) {
833 if (mask->At(maskAt).IsTrue()) {
834 ++truths;
835 }
836 }
837 }
838 std::vector<Scalar<T>> resultElements;
839 ConstantSubscripts arrayAt{array->lbounds()};
840 ConstantSubscript resultSize{truths};
841 if (vector) {
842 resultSize = vector->shape().at(0);
843 if (resultSize < truths) {
844 context_.messages().Say(
845 "Invalid 'vector=' argument in PACK: the 'mask=' argument has %jd true elements, but the vector has only %jd elements"_err_en_US,
846 static_cast<std::intmax_t>(truths),
847 static_cast<std::intmax_t>(resultSize));
848 return MakeInvalidIntrinsic(std::move(funcRef));
849 }
850 }
851 for (ConstantSubscript j{0}; j < truths;) {
852 if (mask->At(maskAt).IsTrue()) {
853 resultElements.push_back(array->At(arrayAt));
854 ++j;
855 }
856 array->IncrementSubscripts(arrayAt);
857 mask->IncrementSubscripts(maskAt);
858 }
859 if (vector) {
860 ConstantSubscripts vectorAt{vector->lbounds()};
861 vectorAt.at(0) += truths;
862 for (ConstantSubscript j{truths}; j < resultSize; ++j) {
863 resultElements.push_back(vector->At(vectorAt));
864 ++vectorAt[0];
865 }
866 }
867 return Expr<T>{PackageConstant<T>(std::move(resultElements), *array,
868 ConstantSubscripts{static_cast<ConstantSubscript>(resultSize)})};
869}
870
871template <typename T> Expr<T> Folder<T>::RESHAPE(FunctionRef<T> &&funcRef) {
872 auto args{funcRef.arguments()};
873 CHECK(args.size() == 4);
874 const auto *source{UnwrapConstantValue<T>(args[0])};
875 const auto *pad{UnwrapConstantValue<T>(args[2])};
876 std::optional<std::vector<ConstantSubscript>> shape{
877 GetIntegerVector<ConstantSubscript>(args[1])};
878 std::optional<std::vector<int>> order{GetIntegerVector<int>(args[3])};
879 if (!source || !shape || (args[2] && !pad) || (args[3] && !order)) {
880 return Expr<T>{std::move(funcRef)}; // Non-constant arguments
881 } else if (shape.value().size() > common::maxRank) {
882 context_.messages().Say(
883 "Size of 'shape=' argument must not be greater than %d"_err_en_US,
884 common::maxRank);
885 } else if (HasNegativeExtent(shape.value())) {
886 context_.messages().Say(
887 "'shape=' argument must not have a negative extent"_err_en_US);
888 } else {
889 std::optional<uint64_t> optResultElement{TotalElementCount(shape.value())};
890 if (!optResultElement) {
891 context_.messages().Say(
892 "'shape=' argument has too many elements"_err_en_US);
893 } else {
894 int rank{GetRank(shape.value())};
895 uint64_t resultElements{*optResultElement};
896 std::optional<std::vector<int>> dimOrder;
897 if (order) {
898 dimOrder = ValidateDimensionOrder(rank, *order);
899 }
900 std::vector<int> *dimOrderPtr{dimOrder ? &dimOrder.value() : nullptr};
901 if (order && !dimOrder) {
902 context_.messages().Say(
903 "Invalid 'order=' argument in RESHAPE"_err_en_US);
904 } else if (resultElements > source->size() && (!pad || pad->empty())) {
905 context_.messages().Say(
906 "Too few elements in 'source=' argument and 'pad=' "
907 "argument is not present or has null size"_err_en_US);
908 } else {
909 Constant<T> result{!source->empty() || !pad
910 ? source->Reshape(std::move(shape.value()))
911 : pad->Reshape(std::move(shape.value()))};
912 ConstantSubscripts subscripts{result.lbounds()};
913 auto copied{result.CopyFrom(*source,
914 std::min(a: static_cast<uint64_t>(source->size()), b: resultElements),
915 subscripts, dimOrderPtr)};
916 if (copied < resultElements) {
917 CHECK(pad);
918 copied += result.CopyFrom(
919 *pad, resultElements - copied, subscripts, dimOrderPtr);
920 }
921 CHECK(copied == resultElements);
922 return Expr<T>{std::move(result)};
923 }
924 }
925 }
926 // Invalid, prevent re-folding
927 return MakeInvalidIntrinsic(std::move(funcRef));
928}
929
930template <typename T> Expr<T> Folder<T>::SPREAD(FunctionRef<T> &&funcRef) {
931 auto args{funcRef.arguments()};
932 CHECK(args.size() == 3);
933 const Constant<T> *source{UnwrapConstantValue<T>(args[0])};
934 auto dim{ToInt64(args[1])};
935 auto ncopies{ToInt64(args[2])};
936 if (!source || !dim) {
937 return Expr<T>{std::move(funcRef)};
938 }
939 int sourceRank{source->Rank()};
940 if (sourceRank >= common::maxRank) {
941 context_.messages().Say(
942 "SOURCE= argument to SPREAD has rank %d but must have rank less than %d"_err_en_US,
943 sourceRank, common::maxRank);
944 } else if (*dim < 1 || *dim > sourceRank + 1) {
945 context_.messages().Say(
946 "DIM=%d argument to SPREAD must be between 1 and %d"_err_en_US, *dim,
947 sourceRank + 1);
948 } else if (!ncopies) {
949 return Expr<T>{std::move(funcRef)};
950 } else {
951 if (*ncopies < 0) {
952 ncopies = 0;
953 }
954 // TODO: Consider moving this implementation (after the user error
955 // checks), along with other transformational intrinsics, into
956 // constant.h (or a new header) so that the transformationals
957 // are available for all Constant<>s without needing to be packaged
958 // as references to intrinsic functions for folding.
959 ConstantSubscripts shape{source->shape()};
960 shape.insert(shape.begin() + *dim - 1, *ncopies);
961 Constant<T> spread{source->Reshape(std::move(shape))};
962 std::optional<uint64_t> n{TotalElementCount(spread.shape())};
963 if (!n) {
964 context_.messages().Say("Too many elements in SPREAD result"_err_en_US);
965 } else {
966 std::vector<int> dimOrder;
967 for (int j{0}; j < sourceRank; ++j) {
968 dimOrder.push_back(j < *dim - 1 ? j : j + 1);
969 }
970 dimOrder.push_back(*dim - 1);
971 ConstantSubscripts at{spread.lbounds()}; // all 1
972 spread.CopyFrom(*source, *n, at, &dimOrder);
973 return Expr<T>{std::move(spread)};
974 }
975 }
976 // Invalid, prevent re-folding
977 return MakeInvalidIntrinsic(std::move(funcRef));
978}
979
980template <typename T> Expr<T> Folder<T>::TRANSPOSE(FunctionRef<T> &&funcRef) {
981 auto args{funcRef.arguments()};
982 CHECK(args.size() == 1);
983 const auto *matrix{UnwrapConstantValue<T>(args[0])};
984 if (!matrix) {
985 return Expr<T>{std::move(funcRef)};
986 }
987 // Argument is constant. Traverse its elements in transposed order.
988 std::vector<Scalar<T>> resultElements;
989 ConstantSubscripts at(2);
990 for (ConstantSubscript j{0}; j < matrix->shape()[0]; ++j) {
991 at[0] = matrix->lbounds()[0] + j;
992 for (ConstantSubscript k{0}; k < matrix->shape()[1]; ++k) {
993 at[1] = matrix->lbounds()[1] + k;
994 resultElements.push_back(matrix->At(at));
995 }
996 }
997 at = matrix->shape();
998 std::swap(at[0], at[1]);
999 return Expr<T>{PackageConstant<T>(std::move(resultElements), *matrix, at)};
1000}
1001
1002template <typename T> Expr<T> Folder<T>::UNPACK(FunctionRef<T> &&funcRef) {
1003 auto args{funcRef.arguments()};
1004 CHECK(args.size() == 3);
1005 const auto *vector{UnwrapConstantValue<T>(args[0])};
1006 auto convertedMask{Fold(context_,
1007 ConvertToType<LogicalResult>(
1008 Expr<SomeLogical>{DEREF(UnwrapExpr<Expr<SomeLogical>>(args[1]))}))};
1009 const auto *mask{UnwrapConstantValue<LogicalResult>(convertedMask)};
1010 const auto *field{UnwrapConstantValue<T>(args[2])};
1011 if (!vector || !mask || !field) {
1012 return Expr<T>{std::move(funcRef)};
1013 }
1014 // Arguments are constant.
1015 if (field->Rank() > 0 && field->shape() != mask->shape()) {
1016 // Error already emitted from intrinsic processing
1017 return MakeInvalidIntrinsic(std::move(funcRef));
1018 }
1019 ConstantSubscript maskElements{GetSize(mask->shape())};
1020 ConstantSubscript truths{0};
1021 ConstantSubscripts maskAt{mask->lbounds()};
1022 for (ConstantSubscript j{0}; j < maskElements;
1023 ++j, mask->IncrementSubscripts(maskAt)) {
1024 if (mask->At(maskAt).IsTrue()) {
1025 ++truths;
1026 }
1027 }
1028 if (truths > GetSize(vector->shape())) {
1029 context_.messages().Say(
1030 "Invalid 'vector=' argument in UNPACK: the 'mask=' argument has %jd true elements, but the vector has only %jd elements"_err_en_US,
1031 static_cast<std::intmax_t>(truths),
1032 static_cast<std::intmax_t>(GetSize(vector->shape())));
1033 return MakeInvalidIntrinsic(std::move(funcRef));
1034 }
1035 std::vector<Scalar<T>> resultElements;
1036 ConstantSubscripts vectorAt{vector->lbounds()};
1037 ConstantSubscripts fieldAt{field->lbounds()};
1038 for (ConstantSubscript j{0}; j < maskElements; ++j) {
1039 if (mask->At(maskAt).IsTrue()) {
1040 resultElements.push_back(vector->At(vectorAt));
1041 vector->IncrementSubscripts(vectorAt);
1042 } else {
1043 resultElements.push_back(field->At(fieldAt));
1044 }
1045 mask->IncrementSubscripts(maskAt);
1046 field->IncrementSubscripts(fieldAt);
1047 }
1048 return Expr<T>{
1049 PackageConstant<T>(std::move(resultElements), *vector, mask->shape())};
1050}
1051
1052std::optional<Expr<SomeType>> FoldTransfer(
1053 FoldingContext &, const ActualArguments &);
1054
1055template <typename T> Expr<T> Folder<T>::TRANSFER(FunctionRef<T> &&funcRef) {
1056 if (auto folded{FoldTransfer(context_, funcRef.arguments())}) {
1057 return DEREF(UnwrapExpr<Expr<T>>(*folded));
1058 } else {
1059 return Expr<T>{std::move(funcRef)};
1060 }
1061}
1062
1063template <typename T>
1064Expr<T> FoldMINorMAX(
1065 FoldingContext &context, FunctionRef<T> &&funcRef, Ordering order) {
1066 static_assert(T::category == TypeCategory::Integer ||
1067 T::category == TypeCategory::Real ||
1068 T::category == TypeCategory::Character);
1069 std::vector<Constant<T> *> constantArgs;
1070 // Call Folding on all arguments, even if some are not constant,
1071 // to make operand promotion explicit.
1072 for (auto &arg : funcRef.arguments()) {
1073 if (auto *cst{Folder<T>{context}.Folding(arg)}) {
1074 constantArgs.push_back(cst);
1075 }
1076 }
1077 if (constantArgs.size() != funcRef.arguments().size()) {
1078 return Expr<T>(std::move(funcRef));
1079 }
1080 CHECK(!constantArgs.empty());
1081 Expr<T> result{std::move(*constantArgs[0])};
1082 for (std::size_t i{1}; i < constantArgs.size(); ++i) {
1083 Extremum<T> extremum{order, result, Expr<T>{std::move(*constantArgs[i])}};
1084 result = FoldOperation(context, std::move(extremum));
1085 }
1086 return result;
1087}
1088
1089// For AMAX0, AMIN0, AMAX1, AMIN1, DMAX1, DMIN1, MAX0, MIN0, MAX1, and MIN1
1090// a special care has to be taken to insert the conversion on the result
1091// of the MIN/MAX. This is made slightly more complex by the extension
1092// supported by f18 that arguments may have different kinds. This implies
1093// that the created MIN/MAX result type cannot be deduced from the standard but
1094// has to be deduced from the arguments.
1095// e.g. AMAX0(int8, int4) is rewritten to REAL(MAX(int8, INT(int4, 8)))).
1096template <typename T>
1097Expr<T> RewriteSpecificMINorMAX(
1098 FoldingContext &context, FunctionRef<T> &&funcRef) {
1099 ActualArguments &args{funcRef.arguments()};
1100 auto &intrinsic{DEREF(std::get_if<SpecificIntrinsic>(&funcRef.proc().u))};
1101 // Rewrite MAX1(args) to INT(MAX(args)) and fold. Same logic for MIN1.
1102 // Find result type for max/min based on the arguments.
1103 std::optional<DynamicType> resultType;
1104 ActualArgument *resultTypeArg{nullptr};
1105 for (auto j{args.size()}; j-- > 0;) {
1106 if (args[j]) {
1107 DynamicType type{args[j]->GetType().value()};
1108 // Handle mixed real/integer arguments: all the previous arguments were
1109 // integers and this one is real. The type of the MAX/MIN result will
1110 // be the one of the real argument.
1111 if (!resultType ||
1112 (type.category() == resultType->category() &&
1113 type.kind() > resultType->kind()) ||
1114 resultType->category() == TypeCategory::Integer) {
1115 resultType = type;
1116 resultTypeArg = &*args[j];
1117 }
1118 }
1119 }
1120 if (!resultType) { // error recovery
1121 return Expr<T>{std::move(funcRef)};
1122 }
1123 intrinsic.name =
1124 intrinsic.name.find("max") != std::string::npos ? "max"s : "min"s;
1125 intrinsic.characteristics.value().functionResult.value().SetType(*resultType);
1126 auto insertConversion{[&](const auto &x) -> Expr<T> {
1127 using TR = ResultType<decltype(x)>;
1128 FunctionRef<TR> maxRef{
1129 ProcedureDesignator{funcRef.proc()}, ActualArguments{args}};
1130 return Fold(context, ConvertToType<T>(AsCategoryExpr(std::move(maxRef))));
1131 }};
1132 if (auto *sx{UnwrapExpr<Expr<SomeReal>>(*resultTypeArg)}) {
1133 return common::visit(insertConversion, sx->u);
1134 } else if (auto *sx{UnwrapExpr<Expr<SomeInteger>>(*resultTypeArg)}) {
1135 return common::visit(insertConversion, sx->u);
1136 } else {
1137 return Expr<T>{std::move(funcRef)}; // error recovery
1138 }
1139}
1140
1141// FoldIntrinsicFunction()
1142template <int KIND>
1143Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
1144 FoldingContext &context, FunctionRef<Type<TypeCategory::Integer, KIND>> &&);
1145template <int KIND>
1146Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
1147 FoldingContext &context, FunctionRef<Type<TypeCategory::Real, KIND>> &&);
1148template <int KIND>
1149Expr<Type<TypeCategory::Complex, KIND>> FoldIntrinsicFunction(
1150 FoldingContext &context, FunctionRef<Type<TypeCategory::Complex, KIND>> &&);
1151template <int KIND>
1152Expr<Type<TypeCategory::Logical, KIND>> FoldIntrinsicFunction(
1153 FoldingContext &context, FunctionRef<Type<TypeCategory::Logical, KIND>> &&);
1154
1155template <typename T>
1156Expr<T> FoldOperation(FoldingContext &context, FunctionRef<T> &&funcRef) {
1157 ActualArguments &args{funcRef.arguments()};
1158 const auto *intrinsic{std::get_if<SpecificIntrinsic>(&funcRef.proc().u)};
1159 if (!intrinsic || intrinsic->name != "kind") {
1160 // Don't fold the argument to KIND(); it might be a TypeParamInquiry
1161 // with a forced result type that doesn't match the parameter.
1162 for (std::optional<ActualArgument> &arg : args) {
1163 if (auto *expr{UnwrapExpr<Expr<SomeType>>(arg)}) {
1164 *expr = Fold(context, std::move(*expr));
1165 }
1166 }
1167 }
1168 if (intrinsic) {
1169 const std::string name{intrinsic->name};
1170 if (name == "cshift") {
1171 return Folder<T>{context}.CSHIFT(std::move(funcRef));
1172 } else if (name == "eoshift") {
1173 return Folder<T>{context}.EOSHIFT(std::move(funcRef));
1174 } else if (name == "merge") {
1175 return Folder<T>{context}.MERGE(std::move(funcRef));
1176 } else if (name == "pack") {
1177 return Folder<T>{context}.PACK(std::move(funcRef));
1178 } else if (name == "reshape") {
1179 return Folder<T>{context}.RESHAPE(std::move(funcRef));
1180 } else if (name == "spread") {
1181 return Folder<T>{context}.SPREAD(std::move(funcRef));
1182 } else if (name == "transfer") {
1183 return Folder<T>{context}.TRANSFER(std::move(funcRef));
1184 } else if (name == "transpose") {
1185 return Folder<T>{context}.TRANSPOSE(std::move(funcRef));
1186 } else if (name == "unpack") {
1187 return Folder<T>{context}.UNPACK(std::move(funcRef));
1188 }
1189 // TODO: extends_type_of, same_type_as
1190 if constexpr (!std::is_same_v<T, SomeDerived>) {
1191 return FoldIntrinsicFunction(context, std::move(funcRef));
1192 }
1193 }
1194 return Expr<T>{std::move(funcRef)};
1195}
1196
// Overload of FoldOperation() for a reference to an implied DO index;
// only declared here.
Expr<ImpliedDoIndex::Result> FoldOperation(FoldingContext &, ImpliedDoIndex &&);
1198
1199// Array constructor folding
1200template <typename T> class ArrayConstructorFolder {
1201public:
1202 explicit ArrayConstructorFolder(FoldingContext &c) : context_{c} {}
1203
1204 Expr<T> FoldArray(ArrayConstructor<T> &&array) {
1205 // Calls FoldArray(const ArrayConstructorValues<T> &) below
1206 if (FoldArray(array)) {
1207 auto n{static_cast<ConstantSubscript>(elements_.size())};
1208 if constexpr (std::is_same_v<T, SomeDerived>) {
1209 return Expr<T>{Constant<T>{array.GetType().GetDerivedTypeSpec(),
1210 std::move(elements_), ConstantSubscripts{n}}};
1211 } else if constexpr (T::category == TypeCategory::Character) {
1212 if (const auto *len{array.LEN()}) {
1213 auto length{Fold(context_, common::Clone(*len))};
1214 if (std::optional<ConstantSubscript> lengthValue{ToInt64(length)}) {
1215 return Expr<T>{Constant<T>{
1216 *lengthValue, std::move(elements_), ConstantSubscripts{n}}};
1217 }
1218 }
1219 } else {
1220 return Expr<T>{
1221 Constant<T>{std::move(elements_), ConstantSubscripts{n}}};
1222 }
1223 }
1224 return Expr<T>{std::move(array)};
1225 }
1226
1227private:
1228 bool FoldArray(const Expr<T> &expr) {
1229 Expr<T> folded{Fold(context_, common::Clone(expr))};
1230 if (const auto *c{UnwrapConstantValue<T>(folded)}) {
1231 // Copy elements in Fortran array element order
1232 if (!c->empty()) {
1233 ConstantSubscripts index{c->lbounds()};
1234 do {
1235 elements_.emplace_back(c->At(index));
1236 } while (c->IncrementSubscripts(index));
1237 }
1238 return true;
1239 } else {
1240 return false;
1241 }
1242 }
1243 bool FoldArray(const common::CopyableIndirection<Expr<T>> &expr) {
1244 return FoldArray(expr.value());
1245 }
1246 bool FoldArray(const ImpliedDo<T> &iDo) {
1247 Expr<SubscriptInteger> lower{
1248 Fold(context_, Expr<SubscriptInteger>{iDo.lower()})};
1249 Expr<SubscriptInteger> upper{
1250 Fold(context_, Expr<SubscriptInteger>{iDo.upper()})};
1251 Expr<SubscriptInteger> stride{
1252 Fold(context_, Expr<SubscriptInteger>{iDo.stride()})};
1253 std::optional<ConstantSubscript> start{ToInt64(lower)}, end{ToInt64(upper)},
1254 step{ToInt64(stride)};
1255 if (start && end && step && *step != 0) {
1256 bool result{true};
1257 ConstantSubscript &j{context_.StartImpliedDo(iDo.name(), *start)};
1258 if (*step > 0) {
1259 for (; j <= *end; j += *step) {
1260 result &= FoldArray(iDo.values());
1261 }
1262 } else {
1263 for (; j >= *end; j += *step) {
1264 result &= FoldArray(iDo.values());
1265 }
1266 }
1267 context_.EndImpliedDo(iDo.name());
1268 return result;
1269 } else {
1270 return false;
1271 }
1272 }
1273 bool FoldArray(const ArrayConstructorValue<T> &x) {
1274 return common::visit([&](const auto &y) { return FoldArray(y); }, x.u);
1275 }
1276 bool FoldArray(const ArrayConstructorValues<T> &xs) {
1277 for (const auto &x : xs) {
1278 if (!FoldArray(x)) {
1279 return false;
1280 }
1281 }
1282 return true;
1283 }
1284
1285 FoldingContext &context_;
1286 std::vector<Scalar<T>> elements_;
1287};
1288
1289template <typename T>
1290Expr<T> FoldOperation(FoldingContext &context, ArrayConstructor<T> &&array) {
1291 return ArrayConstructorFolder<T>{context}.FoldArray(std::move(array));
1292}
1293
1294// Array operation elemental application: When all operands to an operation
1295// are constant arrays, array constructors without any implied DO loops,
1296// &/or expanded scalars, pull the operation "into" the array result by
1297// applying it in an elementwise fashion. For example, [A,1]+[B,2]
1298// is rewritten into [A+B,1+2] and then partially folded to [A+B,3].
1299
1300// If possible, restructures an array expression into an array constructor
1301// that comprises a "flat" ArrayConstructorValues with no implied DO loops.
1302template <typename T>
1303bool ArrayConstructorIsFlat(const ArrayConstructorValues<T> &values) {
1304 for (const ArrayConstructorValue<T> &x : values) {
1305 if (!std::holds_alternative<Expr<T>>(x.u)) {
1306 return false;
1307 }
1308 }
1309 return true;
1310}
1311
1312template <typename T>
1313std::optional<Expr<T>> AsFlatArrayConstructor(const Expr<T> &expr) {
1314 if (const auto *c{UnwrapConstantValue<T>(expr)}) {
1315 ArrayConstructor<T> result{expr};
1316 if (!c->empty()) {
1317 ConstantSubscripts at{c->lbounds()};
1318 do {
1319 result.Push(Expr<T>{Constant<T>{c->At(at)}});
1320 } while (c->IncrementSubscripts(at));
1321 }
1322 return std::make_optional<Expr<T>>(std::move(result));
1323 } else if (const auto *a{UnwrapExpr<ArrayConstructor<T>>(expr)}) {
1324 if (ArrayConstructorIsFlat(*a)) {
1325 return std::make_optional<Expr<T>>(expr);
1326 }
1327 } else if (const auto *p{UnwrapExpr<Parentheses<T>>(expr)}) {
1328 return AsFlatArrayConstructor(Expr<T>{p->left()});
1329 }
1330 return std::nullopt;
1331}
1332
1333template <TypeCategory CAT>
1334std::enable_if_t<CAT != TypeCategory::Derived,
1335 std::optional<Expr<SomeKind<CAT>>>>
1336AsFlatArrayConstructor(const Expr<SomeKind<CAT>> &expr) {
1337 return common::visit(
1338 [&](const auto &kindExpr) -> std::optional<Expr<SomeKind<CAT>>> {
1339 if (auto flattened{AsFlatArrayConstructor(kindExpr)}) {
1340 return Expr<SomeKind<CAT>>{std::move(*flattened)};
1341 } else {
1342 return std::nullopt;
1343 }
1344 },
1345 expr.u);
1346}
1347
1348// FromArrayConstructor is a subroutine for MapOperation() below.
1349// Given a flat ArrayConstructor<T> and a shape, it wraps the array
1350// into an Expr<T>, folds it, and returns the resulting wrapped
1351// array constructor or constant array value.
1352template <typename T>
1353std::optional<Expr<T>> FromArrayConstructor(
1354 FoldingContext &context, ArrayConstructor<T> &&values, const Shape &shape) {
1355 if (auto constShape{AsConstantExtents(context, shape)}) {
1356 Expr<T> result{Fold(context, Expr<T>{std::move(values)})};
1357 if (auto *constant{UnwrapConstantValue<T>(result)}) {
1358 // Elements and shape are both constant.
1359 return Expr<T>{constant->Reshape(std::move(*constShape))};
1360 }
1361 if (constShape->size() == 1) {
1362 if (auto elements{GetShape(context, result)}) {
1363 if (auto constElements{AsConstantExtents(context, *elements)}) {
1364 if (constElements->size() == 1 &&
1365 constElements->at(0) == constShape->at(0)) {
1366 // Elements are not constant, but array constructor has
1367 // the right known shape and can be simply returned as is.
1368 return std::move(result);
1369 }
1370 }
1371 }
1372 }
1373 }
1374 return std::nullopt;
1375}
1376
1377// MapOperation is a utility for various specializations of ApplyElementwise()
1378// that follow. Given one or two flat ArrayConstructor<OPERAND> (wrapped in an
1379// Expr<OPERAND>) for some specific operand type(s), apply a given function f
1380// to each of their corresponding elements to produce a flat
1381// ArrayConstructor<RESULT> (wrapped in an Expr<RESULT>).
1382// Preserves shape.
1383
1384// Unary case
1385template <typename RESULT, typename OPERAND>
1386std::optional<Expr<RESULT>> MapOperation(FoldingContext &context,
1387 std::function<Expr<RESULT>(Expr<OPERAND> &&)> &&f, const Shape &shape,
1388 [[maybe_unused]] std::optional<Expr<SubscriptInteger>> &&length,
1389 Expr<OPERAND> &&values) {
1390 ArrayConstructor<RESULT> result{values};
1391 if constexpr (common::HasMember<OPERAND, AllIntrinsicCategoryTypes>) {
1392 common::visit(
1393 [&](auto &&kindExpr) {
1394 using kindType = ResultType<decltype(kindExpr)>;
1395 auto &aConst{std::get<ArrayConstructor<kindType>>(kindExpr.u)};
1396 for (auto &acValue : aConst) {
1397 auto &scalar{std::get<Expr<kindType>>(acValue.u)};
1398 result.Push(Fold(context, f(Expr<OPERAND>{std::move(scalar)})));
1399 }
1400 },
1401 std::move(values.u));
1402 } else {
1403 auto &aConst{std::get<ArrayConstructor<OPERAND>>(values.u)};
1404 for (auto &acValue : aConst) {
1405 auto &scalar{std::get<Expr<OPERAND>>(acValue.u)};
1406 result.Push(Fold(context, f(std::move(scalar))));
1407 }
1408 }
1409 if constexpr (RESULT::category == TypeCategory::Character) {
1410 if (length) {
1411 result.set_LEN(std::move(*length));
1412 }
1413 }
1414 return FromArrayConstructor(context, std::move(result), shape);
1415}
1416
1417template <typename RESULT, typename A>
1418ArrayConstructor<RESULT> ArrayConstructorFromMold(
1419 const A &prototype, std::optional<Expr<SubscriptInteger>> &&length) {
1420 ArrayConstructor<RESULT> result{prototype};
1421 if constexpr (RESULT::category == TypeCategory::Character) {
1422 if (length) {
1423 result.set_LEN(std::move(*length));
1424 }
1425 }
1426 return result;
1427}
1428
1429template <typename LEFT, typename RIGHT>
1430bool ShapesMatch(FoldingContext &context,
1431 const ArrayConstructor<LEFT> &leftArrConst,
1432 const ArrayConstructor<RIGHT> &rightArrConst) {
1433 auto rightIter{rightArrConst.begin()};
1434 for (auto &leftValue : leftArrConst) {
1435 CHECK(rightIter != rightArrConst.end());
1436 auto &leftExpr{std::get<Expr<LEFT>>(leftValue.u)};
1437 auto &rightExpr{std::get<Expr<RIGHT>>(rightIter->u)};
1438 if (leftExpr.Rank() != rightExpr.Rank()) {
1439 return false;
1440 }
1441 std::optional<Shape> leftShape{GetShape(context, leftExpr)};
1442 std::optional<Shape> rightShape{GetShape(context, rightExpr)};
1443 if (!leftShape || !rightShape || *leftShape != *rightShape) {
1444 return false;
1445 }
1446 ++rightIter;
1447 }
1448 return true;
1449}
1450
1451// array * array case
1452template <typename RESULT, typename LEFT, typename RIGHT>
1453auto MapOperation(FoldingContext &context,
1454 std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)> &&f,
1455 const Shape &shape, std::optional<Expr<SubscriptInteger>> &&length,
1456 Expr<LEFT> &&leftValues, Expr<RIGHT> &&rightValues)
1457 -> std::optional<Expr<RESULT>> {
1458 auto result{ArrayConstructorFromMold<RESULT>(leftValues, std::move(length))};
1459 auto &leftArrConst{std::get<ArrayConstructor<LEFT>>(leftValues.u)};
1460 if constexpr (common::HasMember<RIGHT, AllIntrinsicCategoryTypes>) {
1461 bool mapped{common::visit(
1462 [&](auto &&kindExpr) -> bool {
1463 using kindType = ResultType<decltype(kindExpr)>;
1464
1465 auto &rightArrConst{std::get<ArrayConstructor<kindType>>(kindExpr.u)};
1466 if (!ShapesMatch(context, leftArrConst, rightArrConst)) {
1467 return false;
1468 }
1469 auto rightIter{rightArrConst.begin()};
1470 for (auto &leftValue : leftArrConst) {
1471 CHECK(rightIter != rightArrConst.end());
1472 auto &leftScalar{std::get<Expr<LEFT>>(leftValue.u)};
1473 auto &rightScalar{std::get<Expr<kindType>>(rightIter->u)};
1474 result.Push(Fold(context,
1475 f(std::move(leftScalar), Expr<RIGHT>{std::move(rightScalar)})));
1476 ++rightIter;
1477 }
1478 return true;
1479 },
1480 std::move(rightValues.u))};
1481 if (!mapped) {
1482 return std::nullopt;
1483 }
1484 } else {
1485 auto &rightArrConst{std::get<ArrayConstructor<RIGHT>>(rightValues.u)};
1486 if (!ShapesMatch(context, leftArrConst, rightArrConst)) {
1487 return std::nullopt;
1488 }
1489 auto rightIter{rightArrConst.begin()};
1490 for (auto &leftValue : leftArrConst) {
1491 CHECK(rightIter != rightArrConst.end());
1492 auto &leftScalar{std::get<Expr<LEFT>>(leftValue.u)};
1493 auto &rightScalar{std::get<Expr<RIGHT>>(rightIter->u)};
1494 result.Push(
1495 Fold(context, f(std::move(leftScalar), std::move(rightScalar))));
1496 ++rightIter;
1497 }
1498 }
1499 return FromArrayConstructor(context, std::move(result), shape);
1500}
1501
1502// array * scalar case
1503template <typename RESULT, typename LEFT, typename RIGHT>
1504auto MapOperation(FoldingContext &context,
1505 std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)> &&f,
1506 const Shape &shape, std::optional<Expr<SubscriptInteger>> &&length,
1507 Expr<LEFT> &&leftValues, const Expr<RIGHT> &rightScalar)
1508 -> std::optional<Expr<RESULT>> {
1509 auto result{ArrayConstructorFromMold<RESULT>(leftValues, std::move(length))};
1510 auto &leftArrConst{std::get<ArrayConstructor<LEFT>>(leftValues.u)};
1511 for (auto &leftValue : leftArrConst) {
1512 auto &leftScalar{std::get<Expr<LEFT>>(leftValue.u)};
1513 result.Push(
1514 Fold(context, f(std::move(leftScalar), Expr<RIGHT>{rightScalar})));
1515 }
1516 return FromArrayConstructor(context, std::move(result), shape);
1517}
1518
1519// scalar * array case
1520template <typename RESULT, typename LEFT, typename RIGHT>
1521auto MapOperation(FoldingContext &context,
1522 std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)> &&f,
1523 const Shape &shape, std::optional<Expr<SubscriptInteger>> &&length,
1524 const Expr<LEFT> &leftScalar, Expr<RIGHT> &&rightValues)
1525 -> std::optional<Expr<RESULT>> {
1526 auto result{ArrayConstructorFromMold<RESULT>(leftScalar, std::move(length))};
1527 if constexpr (common::HasMember<RIGHT, AllIntrinsicCategoryTypes>) {
1528 common::visit(
1529 [&](auto &&kindExpr) {
1530 using kindType = ResultType<decltype(kindExpr)>;
1531 auto &rightArrConst{std::get<ArrayConstructor<kindType>>(kindExpr.u)};
1532 for (auto &rightValue : rightArrConst) {
1533 auto &rightScalar{std::get<Expr<kindType>>(rightValue.u)};
1534 result.Push(Fold(context,
1535 f(Expr<LEFT>{leftScalar},
1536 Expr<RIGHT>{std::move(rightScalar)})));
1537 }
1538 },
1539 std::move(rightValues.u));
1540 } else {
1541 auto &rightArrConst{std::get<ArrayConstructor<RIGHT>>(rightValues.u)};
1542 for (auto &rightValue : rightArrConst) {
1543 auto &rightScalar{std::get<Expr<RIGHT>>(rightValue.u)};
1544 result.Push(
1545 Fold(context, f(Expr<LEFT>{leftScalar}, std::move(rightScalar))));
1546 }
1547 }
1548 return FromArrayConstructor(context, std::move(result), shape);
1549}
1550
1551template <typename DERIVED, typename RESULT, typename... OPD>
1552std::optional<Expr<SubscriptInteger>> ComputeResultLength(
1553 Operation<DERIVED, RESULT, OPD...> &operation) {
1554 if constexpr (RESULT::category == TypeCategory::Character) {
1555 return Expr<RESULT>{operation.derived()}.LEN();
1556 }
1557 return std::nullopt;
1558}
1559
1560// ApplyElementwise() recursively folds the operand expression(s) of an
1561// operation, then attempts to apply the operation to the (corresponding)
1562// scalar element(s) of those operands. Returns std::nullopt for scalars
1563// or unlinearizable operands.
1564template <typename DERIVED, typename RESULT, typename OPERAND>
1565auto ApplyElementwise(FoldingContext &context,
1566 Operation<DERIVED, RESULT, OPERAND> &operation,
1567 std::function<Expr<RESULT>(Expr<OPERAND> &&)> &&f)
1568 -> std::optional<Expr<RESULT>> {
1569 auto &expr{operation.left()};
1570 expr = Fold(context, std::move(expr));
1571 if (expr.Rank() > 0) {
1572 if (std::optional<Shape> shape{GetShape(context, expr)}) {
1573 if (auto values{AsFlatArrayConstructor(expr)}) {
1574 return MapOperation(context, std::move(f), *shape,
1575 ComputeResultLength(operation), std::move(*values));
1576 }
1577 }
1578 }
1579 return std::nullopt;
1580}
1581
1582template <typename DERIVED, typename RESULT, typename OPERAND>
1583auto ApplyElementwise(
1584 FoldingContext &context, Operation<DERIVED, RESULT, OPERAND> &operation)
1585 -> std::optional<Expr<RESULT>> {
1586 return ApplyElementwise(context, operation,
1587 std::function<Expr<RESULT>(Expr<OPERAND> &&)>{
1588 [](Expr<OPERAND> &&operand) {
1589 return Expr<RESULT>{DERIVED{std::move(operand)}};
1590 }});
1591}
1592
1593template <typename DERIVED, typename RESULT, typename LEFT, typename RIGHT>
1594auto ApplyElementwise(FoldingContext &context,
1595 Operation<DERIVED, RESULT, LEFT, RIGHT> &operation,
1596 std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)> &&f)
1597 -> std::optional<Expr<RESULT>> {
1598 auto resultLength{ComputeResultLength(operation)};
1599 auto &leftExpr{operation.left()};
1600 leftExpr = Fold(context, std::move(leftExpr));
1601 auto &rightExpr{operation.right()};
1602 rightExpr = Fold(context, std::move(rightExpr));
1603 if (leftExpr.Rank() > 0) {
1604 if (std::optional<Shape> leftShape{GetShape(context, leftExpr)}) {
1605 if (auto left{AsFlatArrayConstructor(leftExpr)}) {
1606 if (rightExpr.Rank() > 0) {
1607 if (std::optional<Shape> rightShape{GetShape(context, rightExpr)}) {
1608 if (auto right{AsFlatArrayConstructor(rightExpr)}) {
1609 if (CheckConformance(context.messages(), *leftShape, *rightShape,
1610 CheckConformanceFlags::EitherScalarExpandable)
1611 .value_or(false /*fail if not known now to conform*/)) {
1612 return MapOperation(context, std::move(f), *leftShape,
1613 std::move(resultLength), std::move(*left),
1614 std::move(*right));
1615 } else {
1616 return std::nullopt;
1617 }
1618 return MapOperation(context, std::move(f), *leftShape,
1619 std::move(resultLength), std::move(*left), std::move(*right));
1620 }
1621 }
1622 } else if (IsExpandableScalar(rightExpr, context, *leftShape)) {
1623 return MapOperation(context, std::move(f), *leftShape,
1624 std::move(resultLength), std::move(*left), rightExpr);
1625 }
1626 }
1627 }
1628 } else if (rightExpr.Rank() > 0) {
1629 if (std::optional<Shape> rightShape{GetShape(context, rightExpr)}) {
1630 if (IsExpandableScalar(leftExpr, context, *rightShape)) {
1631 if (auto right{AsFlatArrayConstructor(rightExpr)}) {
1632 return MapOperation(context, std::move(f), *rightShape,
1633 std::move(resultLength), leftExpr, std::move(*right));
1634 }
1635 }
1636 }
1637 }
1638 return std::nullopt;
1639}
1640
1641template <typename DERIVED, typename RESULT, typename LEFT, typename RIGHT>
1642auto ApplyElementwise(
1643 FoldingContext &context, Operation<DERIVED, RESULT, LEFT, RIGHT> &operation)
1644 -> std::optional<Expr<RESULT>> {
1645 return ApplyElementwise(context, operation,
1646 std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)>{
1647 [](Expr<LEFT> &&left, Expr<RIGHT> &&right) {
1648 return Expr<RESULT>{DERIVED{std::move(left), std::move(right)}};
1649 }});
1650}
1651
1652// Unary operations
1653
1654template <typename TO, typename FROM>
1655common::IfNoLvalue<std::optional<TO>, FROM> ConvertString(FROM &&s) {
1656 if constexpr (std::is_same_v<TO, FROM>) {
1657 return std::make_optional<TO>(std::move(s));
1658 } else {
1659 // Fortran character conversion is well defined between distinct kinds
1660 // only when the actual characters are valid 7-bit ASCII.
1661 TO str;
1662 for (auto iter{s.cbegin()}; iter != s.cend(); ++iter) {
1663 if (static_cast<std::uint64_t>(*iter) > 127) {
1664 return std::nullopt;
1665 }
1666 str.push_back(*iter);
1667 }
1668 return std::make_optional<TO>(std::move(str));
1669 }
1670}
1671
1672template <typename TO, TypeCategory FROMCAT>
1673Expr<TO> FoldOperation(
1674 FoldingContext &context, Convert<TO, FROMCAT> &&convert) {
1675 if (auto array{ApplyElementwise(context, convert)}) {
1676 return *array;
1677 }
1678 struct {
1679 FoldingContext &context;
1680 Convert<TO, FROMCAT> &convert;
1681 } msvcWorkaround{context, convert};
1682 return common::visit(
1683 [&msvcWorkaround](auto &kindExpr) -> Expr<TO> {
1684 using Operand = ResultType<decltype(kindExpr)>;
1685 // This variable is a workaround for msvc which emits an error when
1686 // using the FROMCAT template parameter below.
1687 TypeCategory constexpr FromCat{FROMCAT};
1688 static_assert(FromCat == Operand::category);
1689 auto &convert{msvcWorkaround.convert};
1690 if (auto value{GetScalarConstantValue<Operand>(kindExpr)}) {
1691 FoldingContext &ctx{msvcWorkaround.context};
1692 if constexpr (TO::category == TypeCategory::Integer) {
1693 if constexpr (FromCat == TypeCategory::Integer) {
1694 auto converted{Scalar<TO>::ConvertSigned(*value)};
1695 if (converted.overflow) {
1696 ctx.messages().Say(
1697 "INTEGER(%d) to INTEGER(%d) conversion overflowed"_warn_en_US,
1698 Operand::kind, TO::kind);
1699 }
1700 return ScalarConstantToExpr(std::move(converted.value));
1701 } else if constexpr (FromCat == TypeCategory::Real) {
1702 auto converted{value->template ToInteger<Scalar<TO>>()};
1703 if (converted.flags.test(RealFlag::InvalidArgument)) {
1704 ctx.messages().Say(
1705 "REAL(%d) to INTEGER(%d) conversion: invalid argument"_warn_en_US,
1706 Operand::kind, TO::kind);
1707 } else if (converted.flags.test(RealFlag::Overflow)) {
1708 ctx.messages().Say(
1709 "REAL(%d) to INTEGER(%d) conversion overflowed"_warn_en_US,
1710 Operand::kind, TO::kind);
1711 }
1712 return ScalarConstantToExpr(std::move(converted.value));
1713 }
1714 } else if constexpr (TO::category == TypeCategory::Real) {
1715 if constexpr (FromCat == TypeCategory::Integer) {
1716 auto converted{Scalar<TO>::FromInteger(*value)};
1717 if (!converted.flags.empty()) {
1718 char buffer[64];
1719 std::snprintf(buffer, sizeof buffer,
1720 "INTEGER(%d) to REAL(%d) conversion", Operand::kind,
1721 TO::kind);
1722 RealFlagWarnings(ctx, converted.flags, buffer);
1723 }
1724 return ScalarConstantToExpr(std::move(converted.value));
1725 } else if constexpr (FromCat == TypeCategory::Real) {
1726 auto converted{Scalar<TO>::Convert(*value)};
1727 char buffer[64];
1728 if (!converted.flags.empty()) {
1729 std::snprintf(buffer, sizeof buffer,
1730 "REAL(%d) to REAL(%d) conversion", Operand::kind, TO::kind);
1731 RealFlagWarnings(ctx, converted.flags, buffer);
1732 }
1733 if (ctx.targetCharacteristics().areSubnormalsFlushedToZero()) {
1734 converted.value = converted.value.FlushSubnormalToZero();
1735 }
1736 return ScalarConstantToExpr(std::move(converted.value));
1737 }
1738 } else if constexpr (TO::category == TypeCategory::Complex) {
1739 if constexpr (FromCat == TypeCategory::Complex) {
1740 return FoldOperation(ctx,
1741 ComplexConstructor<TO::kind>{
1742 AsExpr(Convert<typename TO::Part>{AsCategoryExpr(
1743 Constant<typename Operand::Part>{value->REAL()})}),
1744 AsExpr(Convert<typename TO::Part>{AsCategoryExpr(
1745 Constant<typename Operand::Part>{value->AIMAG()})})});
1746 }
1747 } else if constexpr (TO::category == TypeCategory::Character &&
1748 FromCat == TypeCategory::Character) {
1749 if (auto converted{ConvertString<Scalar<TO>>(std::move(*value))}) {
1750 return ScalarConstantToExpr(std::move(*converted));
1751 }
1752 } else if constexpr (TO::category == TypeCategory::Logical &&
1753 FromCat == TypeCategory::Logical) {
1754 return Expr<TO>{value->IsTrue()};
1755 }
1756 } else if constexpr (TO::category == FromCat &&
1757 FromCat != TypeCategory::Character) {
1758 // Conversion of non-constant in same type category
1759 if constexpr (std::is_same_v<Operand, TO>) {
1760 return std::move(kindExpr); // remove needless conversion
1761 } else if constexpr (TO::category == TypeCategory::Logical ||
1762 TO::category == TypeCategory::Integer) {
1763 if (auto *innerConv{
1764 std::get_if<Convert<Operand, TO::category>>(&kindExpr.u)}) {
1765 // Conversion of conversion of same category & kind
1766 if (auto *x{std::get_if<Expr<TO>>(&innerConv->left().u)}) {
1767 if constexpr (TO::category == TypeCategory::Logical ||
1768 TO::kind <= Operand::kind) {
1769 return std::move(*x); // no-op Logical or Integer
1770 // widening/narrowing conversion pair
1771 } else if constexpr (std::is_same_v<TO,
1772 DescriptorInquiry::Result>) {
1773 if (std::holds_alternative<DescriptorInquiry>(x->u) ||
1774 std::holds_alternative<TypeParamInquiry>(x->u)) {
1775 // int(int(size(...),kind=k),kind=8) -> size(...)
1776 return std::move(*x);
1777 }
1778 }
1779 }
1780 }
1781 }
1782 }
1783 return Expr<TO>{std::move(convert)};
1784 },
1785 convert.left().u);
1786}
1787
1788template <typename T>
1789Expr<T> FoldOperation(FoldingContext &context, Parentheses<T> &&x) {
1790 auto &operand{x.left()};
1791 operand = Fold(context, std::move(operand));
1792 if (auto value{GetScalarConstantValue<T>(operand)}) {
1793 // Preserve parentheses, even around constants.
1794 return Expr<T>{Parentheses<T>{Expr<T>{Constant<T>{*value}}}};
1795 } else if (std::holds_alternative<Parentheses<T>>(operand.u)) {
1796 // ((x)) -> (x)
1797 return std::move(operand);
1798 } else {
1799 return Expr<T>{Parentheses<T>{std::move(operand)}};
1800 }
1801}
1802
1803template <typename T>
1804Expr<T> FoldOperation(FoldingContext &context, Negate<T> &&x) {
1805 if (auto array{ApplyElementwise(context, x)}) {
1806 return *array;
1807 }
1808 auto &operand{x.left()};
1809 if (auto *nn{std::get_if<Negate<T>>(&x.left().u)}) {
1810 // -(-x) -> (x)
1811 if (IsVariable(nn->left())) {
1812 return FoldOperation(context, Parentheses<T>{std::move(nn->left())});
1813 } else {
1814 return std::move(nn->left());
1815 }
1816 } else if (auto value{GetScalarConstantValue<T>(operand)}) {
1817 if constexpr (T::category == TypeCategory::Integer) {
1818 auto negated{value->Negate()};
1819 if (negated.overflow) {
1820 context.messages().Say(
1821 "INTEGER(%d) negation overflowed"_warn_en_US, T::kind);
1822 }
1823 return Expr<T>{Constant<T>{std::move(negated.value)}};
1824 } else {
1825 // REAL & COMPLEX negation: no exceptions possible
1826 return Expr<T>{Constant<T>{value->Negate()}};
1827 }
1828 }
1829 return Expr<T>{std::move(x)};
1830}
1831
1832// Binary (dyadic) operations
1833
1834template <typename LEFT, typename RIGHT>
1835std::optional<std::pair<Scalar<LEFT>, Scalar<RIGHT>>> OperandsAreConstants(
1836 const Expr<LEFT> &x, const Expr<RIGHT> &y) {
1837 if (auto xvalue{GetScalarConstantValue<LEFT>(x)}) {
1838 if (auto yvalue{GetScalarConstantValue<RIGHT>(y)}) {
1839 return {std::make_pair(*xvalue, *yvalue)};
1840 }
1841 }
1842 return std::nullopt;
1843}
1844
1845template <typename DERIVED, typename RESULT, typename LEFT, typename RIGHT>
1846std::optional<std::pair<Scalar<LEFT>, Scalar<RIGHT>>> OperandsAreConstants(
1847 const Operation<DERIVED, RESULT, LEFT, RIGHT> &operation) {
1848 return OperandsAreConstants(operation.left(), operation.right());
1849}
1850
1851template <typename T>
1852Expr<T> FoldOperation(FoldingContext &context, Add<T> &&x) {
1853 if (auto array{ApplyElementwise(context, x)}) {
1854 return *array;
1855 }
1856 if (auto folded{OperandsAreConstants(x)}) {
1857 if constexpr (T::category == TypeCategory::Integer) {
1858 auto sum{folded->first.AddSigned(folded->second)};
1859 if (sum.overflow) {
1860 context.messages().Say(
1861 "INTEGER(%d) addition overflowed"_warn_en_US, T::kind);
1862 }
1863 return Expr<T>{Constant<T>{sum.value}};
1864 } else {
1865 auto sum{folded->first.Add(
1866 folded->second, context.targetCharacteristics().roundingMode())};
1867 RealFlagWarnings(context, sum.flags, "addition");
1868 if (context.targetCharacteristics().areSubnormalsFlushedToZero()) {
1869 sum.value = sum.value.FlushSubnormalToZero();
1870 }
1871 return Expr<T>{Constant<T>{sum.value}};
1872 }
1873 }
1874 return Expr<T>{std::move(x)};
1875}
1876
1877template <typename T>
1878Expr<T> FoldOperation(FoldingContext &context, Subtract<T> &&x) {
1879 if (auto array{ApplyElementwise(context, x)}) {
1880 return *array;
1881 }
1882 if (auto folded{OperandsAreConstants(x)}) {
1883 if constexpr (T::category == TypeCategory::Integer) {
1884 auto difference{folded->first.SubtractSigned(folded->second)};
1885 if (difference.overflow) {
1886 context.messages().Say(
1887 "INTEGER(%d) subtraction overflowed"_warn_en_US, T::kind);
1888 }
1889 return Expr<T>{Constant<T>{difference.value}};
1890 } else {
1891 auto difference{folded->first.Subtract(
1892 folded->second, context.targetCharacteristics().roundingMode())};
1893 RealFlagWarnings(context, difference.flags, "subtraction");
1894 if (context.targetCharacteristics().areSubnormalsFlushedToZero()) {
1895 difference.value = difference.value.FlushSubnormalToZero();
1896 }
1897 return Expr<T>{Constant<T>{difference.value}};
1898 }
1899 }
1900 return Expr<T>{std::move(x)};
1901}
1902
1903template <typename T>
1904Expr<T> FoldOperation(FoldingContext &context, Multiply<T> &&x) {
1905 if (auto array{ApplyElementwise(context, x)}) {
1906 return *array;
1907 }
1908 if (auto folded{OperandsAreConstants(x)}) {
1909 if constexpr (T::category == TypeCategory::Integer) {
1910 auto product{folded->first.MultiplySigned(folded->second)};
1911 if (product.SignedMultiplicationOverflowed()) {
1912 context.messages().Say(
1913 "INTEGER(%d) multiplication overflowed"_warn_en_US, T::kind);
1914 }
1915 return Expr<T>{Constant<T>{product.lower}};
1916 } else {
1917 auto product{folded->first.Multiply(
1918 folded->second, context.targetCharacteristics().roundingMode())};
1919 RealFlagWarnings(context, product.flags, "multiplication");
1920 if (context.targetCharacteristics().areSubnormalsFlushedToZero()) {
1921 product.value = product.value.FlushSubnormalToZero();
1922 }
1923 return Expr<T>{Constant<T>{product.value}};
1924 }
1925 } else if constexpr (T::category == TypeCategory::Integer) {
1926 if (auto c{GetScalarConstantValue<T>(x.right())}) {
1927 x.right() = std::move(x.left());
1928 x.left() = Expr<T>{std::move(*c)};
1929 }
1930 if (auto c{GetScalarConstantValue<T>(x.left())}) {
1931 if (c->IsZero() && x.right().Rank() == 0) {
1932 return std::move(x.left());
1933 } else if (c->CompareSigned(Scalar<T>{1}) == Ordering::Equal) {
1934 if (IsVariable(x.right())) {
1935 return FoldOperation(context, Parentheses<T>{std::move(x.right())});
1936 } else {
1937 return std::move(x.right());
1938 }
1939 } else if (c->CompareSigned(Scalar<T>{-1}) == Ordering::Equal) {
1940 return FoldOperation(context, Negate<T>{std::move(x.right())});
1941 }
1942 }
1943 }
1944 return Expr<T>{std::move(x)};
1945}
1946
1947template <typename T>
1948Expr<T> FoldOperation(FoldingContext &context, Divide<T> &&x) {
1949 if (auto array{ApplyElementwise(context, x)}) {
1950 return *array;
1951 }
1952 if (auto folded{OperandsAreConstants(x)}) {
1953 if constexpr (T::category == TypeCategory::Integer) {
1954 auto quotAndRem{folded->first.DivideSigned(folded->second)};
1955 if (quotAndRem.divisionByZero) {
1956 context.messages().Say(
1957 "INTEGER(%d) division by zero"_warn_en_US, T::kind);
1958 return Expr<T>{std::move(x)};
1959 }
1960 if (quotAndRem.overflow) {
1961 context.messages().Say(
1962 "INTEGER(%d) division overflowed"_warn_en_US, T::kind);
1963 }
1964 return Expr<T>{Constant<T>{quotAndRem.quotient}};
1965 } else {
1966 auto quotient{folded->first.Divide(
1967 folded->second, context.targetCharacteristics().roundingMode())};
1968 // Don't warn about -1./0., 0./0., or 1./0. from a module file
1969 // they are interpreted as canonical Fortran representations of -Inf,
1970 // NaN, and Inf respectively.
1971 bool isCanonicalNaNOrInf{false};
1972 if constexpr (T::category == TypeCategory::Real) {
1973 if (folded->second.IsZero() && context.moduleFileName().has_value()) {
1974 using IntType = typename T::Scalar::Word;
1975 auto intNumerator{folded->first.template ToInteger<IntType>()};
1976 isCanonicalNaNOrInf = intNumerator.flags == RealFlags{} &&
1977 intNumerator.value >= IntType{-1} &&
1978 intNumerator.value <= IntType{1};
1979 }
1980 }
1981 if (!isCanonicalNaNOrInf) {
1982 RealFlagWarnings(context, quotient.flags, "division");
1983 }
1984 if (context.targetCharacteristics().areSubnormalsFlushedToZero()) {
1985 quotient.value = quotient.value.FlushSubnormalToZero();
1986 }
1987 return Expr<T>{Constant<T>{quotient.value}};
1988 }
1989 }
1990 return Expr<T>{std::move(x)};
1991}
1992
1993template <typename T>
1994Expr<T> FoldOperation(FoldingContext &context, Power<T> &&x) {
1995 if (auto array{ApplyElementwise(context, x)}) {
1996 return *array;
1997 }
1998 if (auto folded{OperandsAreConstants(x)}) {
1999 if constexpr (T::category == TypeCategory::Integer) {
2000 auto power{folded->first.Power(folded->second)};
2001 if (power.divisionByZero) {
2002 context.messages().Say(
2003 "INTEGER(%d) zero to negative power"_warn_en_US, T::kind);
2004 } else if (power.overflow) {
2005 context.messages().Say(
2006 "INTEGER(%d) power overflowed"_warn_en_US, T::kind);
2007 } else if (power.zeroToZero) {
2008 context.messages().Say(
2009 "INTEGER(%d) 0**0 is not defined"_warn_en_US, T::kind);
2010 }
2011 return Expr<T>{Constant<T>{power.power}};
2012 } else {
2013 if (auto callable{GetHostRuntimeWrapper<T, T, T>("pow")}) {
2014 return Expr<T>{
2015 Constant<T>{(*callable)(context, folded->first, folded->second)}};
2016 } else {
2017 context.messages().Say(
2018 "Power for %s cannot be folded on host"_warn_en_US,
2019 T{}.AsFortran());
2020 }
2021 }
2022 }
2023 return Expr<T>{std::move(x)};
2024}
2025
2026template <typename T>
2027Expr<T> FoldOperation(FoldingContext &context, RealToIntPower<T> &&x) {
2028 if (auto array{ApplyElementwise(context, x)}) {
2029 return *array;
2030 }
2031 return common::visit(
2032 [&](auto &y) -> Expr<T> {
2033 if (auto folded{OperandsAreConstants(x.left(), y)}) {
2034 auto power{evaluate::IntPower(folded->first, folded->second)};
2035 RealFlagWarnings(context, power.flags, "power with INTEGER exponent");
2036 if (context.targetCharacteristics().areSubnormalsFlushedToZero()) {
2037 power.value = power.value.FlushSubnormalToZero();
2038 }
2039 return Expr<T>{Constant<T>{power.value}};
2040 } else {
2041 return Expr<T>{std::move(x)};
2042 }
2043 },
2044 x.right().u);
2045}
2046
2047template <typename T>
2048Expr<T> FoldOperation(FoldingContext &context, Extremum<T> &&x) {
2049 if (auto array{ApplyElementwise(context, x,
2050 std::function<Expr<T>(Expr<T> &&, Expr<T> &&)>{[=](Expr<T> &&l,
2051 Expr<T> &&r) {
2052 return Expr<T>{Extremum<T>{x.ordering, std::move(l), std::move(r)}};
2053 }})}) {
2054 return *array;
2055 }
2056 if (auto folded{OperandsAreConstants(x)}) {
2057 if constexpr (T::category == TypeCategory::Integer) {
2058 if (folded->first.CompareSigned(folded->second) == x.ordering) {
2059 return Expr<T>{Constant<T>{folded->first}};
2060 }
2061 } else if constexpr (T::category == TypeCategory::Real) {
2062 if (folded->first.IsNotANumber() ||
2063 (folded->first.Compare(folded->second) == Relation::Less) ==
2064 (x.ordering == Ordering::Less)) {
2065 return Expr<T>{Constant<T>{folded->first}};
2066 }
2067 } else {
2068 static_assert(T::category == TypeCategory::Character);
2069 // Result of MIN and MAX on character has the length of
2070 // the longest argument.
2071 auto maxLen{std::max(folded->first.length(), folded->second.length())};
2072 bool isFirst{x.ordering == Compare(folded->first, folded->second)};
2073 auto res{isFirst ? std::move(folded->first) : std::move(folded->second)};
2074 res = res.length() == maxLen
2075 ? std::move(res)
2076 : CharacterUtils<T::kind>::Resize(res, maxLen);
2077 return Expr<T>{Constant<T>{std::move(res)}};
2078 }
2079 return Expr<T>{Constant<T>{folded->second}};
2080 }
2081 return Expr<T>{std::move(x)};
2082}
2083
2084template <int KIND>
2085Expr<Type<TypeCategory::Real, KIND>> ToReal(
2086 FoldingContext &context, Expr<SomeType> &&expr) {
2087 using Result = Type<TypeCategory::Real, KIND>;
2088 std::optional<Expr<Result>> result;
2089 common::visit(
2090 [&](auto &&x) {
2091 using From = std::decay_t<decltype(x)>;
2092 if constexpr (std::is_same_v<From, BOZLiteralConstant>) {
2093 // Move the bits without any integer->real conversion
2094 From original{x};
2095 result = ConvertToType<Result>(std::move(x));
2096 const auto *constant{UnwrapExpr<Constant<Result>>(*result)};
2097 CHECK(constant);
2098 Scalar<Result> real{constant->GetScalarValue().value()};
2099 From converted{From::ConvertUnsigned(real.RawBits()).value};
2100 if (original != converted) { // C1601
2101 context.messages().Say(
2102 "Nonzero bits truncated from BOZ literal constant in REAL intrinsic"_warn_en_US);
2103 }
2104 } else if constexpr (IsNumericCategoryExpr<From>()) {
2105 result = Fold(context, ConvertToType<Result>(std::move(x)));
2106 } else {
2107 common::die("ToReal: bad argument expression");
2108 }
2109 },
2110 std::move(expr.u));
2111 return result.value();
2112}
2113
2114// REAL(z) and AIMAG(z)
2115template <int KIND>
2116Expr<Type<TypeCategory::Real, KIND>> FoldOperation(
2117 FoldingContext &context, ComplexComponent<KIND> &&x) {
2118 using Operand = Type<TypeCategory::Complex, KIND>;
2119 using Result = Type<TypeCategory::Real, KIND>;
2120 if (auto array{ApplyElementwise(context, x,
2121 std::function<Expr<Result>(Expr<Operand> &&)>{
2122 [=](Expr<Operand> &&operand) {
2123 return Expr<Result>{ComplexComponent<KIND>{
2124 x.isImaginaryPart, std::move(operand)}};
2125 }})}) {
2126 return *array;
2127 }
2128 auto &operand{x.left()};
2129 if (auto value{GetScalarConstantValue<Operand>(operand)}) {
2130 if (x.isImaginaryPart) {
2131 return Expr<Result>{Constant<Result>{value->AIMAG()}};
2132 } else {
2133 return Expr<Result>{Constant<Result>{value->REAL()}};
2134 }
2135 }
2136 return Expr<Result>{std::move(x)};
2137}
2138
2139template <typename T>
2140Expr<T> ExpressionBase<T>::Rewrite(FoldingContext &context, Expr<T> &&expr) {
2141 return common::visit(
2142 [&](auto &&x) -> Expr<T> {
2143 if constexpr (IsSpecificIntrinsicType<T>) {
2144 return FoldOperation(context, std::move(x));
2145 } else if constexpr (std::is_same_v<T, SomeDerived>) {
2146 return FoldOperation(context, std::move(x));
2147 } else if constexpr (common::HasMember<decltype(x),
2148 TypelessExpression>) {
2149 return std::move(expr);
2150 } else {
2151 return Expr<T>{Fold(context, std::move(x))};
2152 }
2153 },
2154 std::move(expr.u));
2155}
2156
2157FOR_EACH_TYPE_AND_KIND(extern template class ExpressionBase, )
2158} // namespace Fortran::evaluate
2159#endif // FORTRAN_EVALUATE_FOLD_IMPLEMENTATION_H_
2160

source code of flang/lib/Evaluate/fold-implementation.h