1//===-- lib/Evaluate/fold-reduction.h -------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef FORTRAN_EVALUATE_FOLD_REDUCTION_H_
10#define FORTRAN_EVALUATE_FOLD_REDUCTION_H_
11
12#include "fold-implementation.h"
13
14namespace Fortran::evaluate {
15
16// DOT_PRODUCT
17template <typename T>
18static Expr<T> FoldDotProduct(
19 FoldingContext &context, FunctionRef<T> &&funcRef) {
20 using Element = typename Constant<T>::Element;
21 auto args{funcRef.arguments()};
22 CHECK(args.size() == 2);
23 Folder<T> folder{context};
24 Constant<T> *va{folder.Folding(args[0])};
25 Constant<T> *vb{folder.Folding(args[1])};
26 if (va && vb) {
27 CHECK(va->Rank() == 1 && vb->Rank() == 1);
28 if (va->size() != vb->size()) {
29 context.messages().Say(
30 "Vector arguments to DOT_PRODUCT have distinct extents %zd and %zd"_err_en_US,
31 va->size(), vb->size());
32 return MakeInvalidIntrinsic(std::move(funcRef));
33 }
34 Element sum{};
35 bool overflow{false};
36 if constexpr (T::category == TypeCategory::Complex) {
37 std::vector<Element> conjugates;
38 for (const Element &x : va->values()) {
39 conjugates.emplace_back(x.CONJG());
40 }
41 Constant<T> conjgA{
42 std::move(conjugates), ConstantSubscripts{va->shape()}};
43 Expr<T> products{Fold(
44 context, Expr<T>{std::move(conjgA)} * Expr<T>{Constant<T>{*vb}})};
45 Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
46 Element correction{}; // Use Kahan summation for greater precision.
47 const auto &rounding{context.targetCharacteristics().roundingMode()};
48 for (const Element &x : cProducts.values()) {
49 auto next{correction.Add(x, rounding)};
50 overflow |= next.flags.test(RealFlag::Overflow);
51 auto added{sum.Add(next.value, rounding)};
52 overflow |= added.flags.test(RealFlag::Overflow);
53 correction = added.value.Subtract(sum, rounding)
54 .value.Subtract(next.value, rounding)
55 .value;
56 sum = std::move(added.value);
57 }
58 } else if constexpr (T::category == TypeCategory::Logical) {
59 Expr<T> conjunctions{Fold(context,
60 Expr<T>{LogicalOperation<T::kind>{LogicalOperator::And,
61 Expr<T>{Constant<T>{*va}}, Expr<T>{Constant<T>{*vb}}}})};
62 Constant<T> &cConjunctions{DEREF(UnwrapConstantValue<T>(conjunctions))};
63 for (const Element &x : cConjunctions.values()) {
64 if (x.IsTrue()) {
65 sum = Element{true};
66 break;
67 }
68 }
69 } else if constexpr (T::category == TypeCategory::Integer) {
70 Expr<T> products{
71 Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})};
72 Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
73 for (const Element &x : cProducts.values()) {
74 auto next{sum.AddSigned(x)};
75 overflow |= next.overflow;
76 sum = std::move(next.value);
77 }
78 } else {
79 static_assert(T::category == TypeCategory::Real);
80 Expr<T> products{
81 Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})};
82 Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
83 Element correction{}; // Use Kahan summation for greater precision.
84 const auto &rounding{context.targetCharacteristics().roundingMode()};
85 for (const Element &x : cProducts.values()) {
86 auto next{correction.Add(x, rounding)};
87 overflow |= next.flags.test(RealFlag::Overflow);
88 auto added{sum.Add(next.value, rounding)};
89 overflow |= added.flags.test(RealFlag::Overflow);
90 correction = added.value.Subtract(sum, rounding)
91 .value.Subtract(next.value, rounding)
92 .value;
93 sum = std::move(added.value);
94 }
95 }
96 if (overflow) {
97 context.messages().Say(
98 "DOT_PRODUCT of %s data overflowed during computation"_warn_en_US,
99 T::AsFortran());
100 }
101 return Expr<T>{Constant<T>{std::move(sum)}};
102 }
103 return Expr<T>{std::move(funcRef)};
104}
105
106// Fold and validate a DIM= argument. Returns false on error.
107bool CheckReductionDIM(std::optional<int> &dim, FoldingContext &,
108 ActualArguments &, std::optional<int> dimIndex, int rank);
109
110// Fold and validate a MASK= argument. Return null on error, absent MASK=, or
111// non-constant MASK=.
112Constant<LogicalResult> *GetReductionMASK(
113 std::optional<ActualArgument> &maskArg, const ConstantSubscripts &shape,
114 FoldingContext &);
115
116// Common preprocessing for reduction transformational intrinsic function
117// folding. If the intrinsic can have DIM= &/or MASK= arguments, extract
118// and check them. If a MASK= is present, apply it to the array data and
119// substitute replacement values for elements corresponding to .FALSE. in
120// the mask. If the result is present, the intrinsic call can be folded.
121template <typename T> struct ArrayAndMask {
122 Constant<T> array;
123 Constant<LogicalResult> mask;
124};
125template <typename T>
126static std::optional<ArrayAndMask<T>> ProcessReductionArgs(
127 FoldingContext &context, ActualArguments &arg, std::optional<int> &dim,
128 int arrayIndex, std::optional<int> dimIndex = std::nullopt,
129 std::optional<int> maskIndex = std::nullopt) {
130 if (arg.empty()) {
131 return std::nullopt;
132 }
133 Constant<T> *folded{Folder<T>{context}.Folding(arg[arrayIndex])};
134 if (!folded || folded->Rank() < 1) {
135 return std::nullopt;
136 }
137 if (!CheckReductionDIM(dim, context, arg, dimIndex, folded->Rank())) {
138 return std::nullopt;
139 }
140 std::size_t n{folded->size()};
141 std::vector<Scalar<LogicalResult>> maskElement;
142 if (maskIndex && static_cast<std::size_t>(*maskIndex) < arg.size() &&
143 arg[*maskIndex]) {
144 if (const Constant<LogicalResult> *origMask{
145 GetReductionMASK(arg[*maskIndex], folded->shape(), context)}) {
146 if (auto scalarMask{origMask->GetScalarValue()}) {
147 maskElement =
148 std::vector<Scalar<LogicalResult>>(n, scalarMask->IsTrue());
149 } else {
150 maskElement = origMask->values();
151 }
152 } else {
153 return std::nullopt;
154 }
155 } else {
156 maskElement = std::vector<Scalar<LogicalResult>>(n, true);
157 }
158 return ArrayAndMask<T>{Constant<T>(*folded),
159 Constant<LogicalResult>{
160 std::move(maskElement), ConstantSubscripts{folded->shape()}}};
161}
162
163// Generalized reduction to an array of one dimension fewer (w/ DIM=)
164// or to a scalar (w/o DIM=). The ACCUMULATOR type must define
165// operator()(Scalar<T> &, const ConstantSubscripts &, bool first)
166// and Done(Scalar<T> &).
167template <typename T, typename ACCUMULATOR, typename ARRAY>
168static Constant<T> DoReduction(const Constant<ARRAY> &array,
169 const Constant<LogicalResult> &mask, std::optional<int> &dim,
170 const Scalar<T> &identity, ACCUMULATOR &accumulator) {
171 ConstantSubscripts at{array.lbounds()};
172 ConstantSubscripts maskAt{mask.lbounds()};
173 std::vector<typename Constant<T>::Element> elements;
174 ConstantSubscripts resultShape; // empty -> scalar
175 if (dim) { // DIM= is present, so result is an array
176 resultShape = array.shape();
177 resultShape.erase(resultShape.begin() + (*dim - 1));
178 ConstantSubscript dimExtent{array.shape().at(*dim - 1)};
179 CHECK(dimExtent == mask.shape().at(*dim - 1));
180 ConstantSubscript &dimAt{at[*dim - 1]};
181 ConstantSubscript dimLbound{dimAt};
182 ConstantSubscript &maskDimAt{maskAt[*dim - 1]};
183 ConstantSubscript maskDimLbound{maskDimAt};
184 for (auto n{GetSize(resultShape)}; n-- > 0;
185 array.IncrementSubscripts(at), mask.IncrementSubscripts(maskAt)) {
186 elements.push_back(identity);
187 if (dimExtent > 0) {
188 dimAt = dimLbound;
189 maskDimAt = maskDimLbound;
190 bool firstUnmasked{true};
191 for (ConstantSubscript j{0}; j < dimExtent; ++j, ++dimAt, ++maskDimAt) {
192 if (mask.At(maskAt).IsTrue()) {
193 accumulator(elements.back(), at, firstUnmasked);
194 firstUnmasked = false;
195 }
196 }
197 --dimAt, --maskDimAt;
198 }
199 accumulator.Done(elements.back());
200 }
201 } else { // no DIM=, result is scalar
202 elements.push_back(identity);
203 bool firstUnmasked{true};
204 for (auto n{array.size()}; n-- > 0;
205 array.IncrementSubscripts(at), mask.IncrementSubscripts(maskAt)) {
206 if (mask.At(maskAt).IsTrue()) {
207 accumulator(elements.back(), at, firstUnmasked);
208 firstUnmasked = false;
209 }
210 }
211 accumulator.Done(elements.back());
212 }
213 if constexpr (T::category == TypeCategory::Character) {
214 return {static_cast<ConstantSubscript>(identity.size()),
215 std::move(elements), std::move(resultShape)};
216 } else {
217 return {std::move(elements), std::move(resultShape)};
218 }
219}
220
221// MAXVAL & MINVAL
222template <typename T, bool ABS = false> class MaxvalMinvalAccumulator {
223public:
224 MaxvalMinvalAccumulator(
225 RelationalOperator opr, FoldingContext &context, const Constant<T> &array)
226 : opr_{opr}, context_{context}, array_{array} {};
227 void operator()(Scalar<T> &element, const ConstantSubscripts &at,
228 [[maybe_unused]] bool firstUnmasked) const {
229 auto aAt{array_.At(at)};
230 if constexpr (ABS) {
231 aAt = aAt.ABS();
232 }
233 if constexpr (T::category == TypeCategory::Real) {
234 if (firstUnmasked || element.IsNotANumber()) {
235 // Return NaN if and only if all unmasked elements are NaNs and
236 // at least one unmasked element is visible.
237 element = aAt;
238 return;
239 }
240 }
241 Expr<LogicalResult> test{PackageRelation(
242 opr_, Expr<T>{Constant<T>{aAt}}, Expr<T>{Constant<T>{element}})};
243 auto folded{GetScalarConstantValue<LogicalResult>(
244 test.Rewrite(context_, std::move(test)))};
245 CHECK(folded.has_value());
246 if (folded->IsTrue()) {
247 element = aAt;
248 }
249 }
250 void Done(Scalar<T> &) const {}
251
252private:
253 RelationalOperator opr_;
254 FoldingContext &context_;
255 const Constant<T> &array_;
256};
257
258template <typename T>
259static Expr<T> FoldMaxvalMinval(FoldingContext &context, FunctionRef<T> &&ref,
260 RelationalOperator opr, const Scalar<T> &identity) {
261 static_assert(T::category == TypeCategory::Integer ||
262 T::category == TypeCategory::Real ||
263 T::category == TypeCategory::Character);
264 std::optional<int> dim;
265 if (std::optional<ArrayAndMask<T>> arrayAndMask{
266 ProcessReductionArgs<T>(context, ref.arguments(), dim,
267 /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
268 MaxvalMinvalAccumulator accumulator{opr, context, arrayAndMask->array};
269 return Expr<T>{DoReduction<T>(
270 arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)};
271 }
272 return Expr<T>{std::move(ref)};
273}
274
275// PRODUCT
276template <typename T> class ProductAccumulator {
277public:
278 ProductAccumulator(const Constant<T> &array) : array_{array} {}
279 void operator()(
280 Scalar<T> &element, const ConstantSubscripts &at, bool /*first*/) {
281 if constexpr (T::category == TypeCategory::Integer) {
282 auto prod{element.MultiplySigned(array_.At(at))};
283 overflow_ |= prod.SignedMultiplicationOverflowed();
284 element = prod.lower;
285 } else { // Real & Complex
286 auto prod{element.Multiply(array_.At(at))};
287 overflow_ |= prod.flags.test(RealFlag::Overflow);
288 element = prod.value;
289 }
290 }
291 bool overflow() const { return overflow_; }
292 void Done(Scalar<T> &) const {}
293
294private:
295 const Constant<T> &array_;
296 bool overflow_{false};
297};
298
299template <typename T>
300static Expr<T> FoldProduct(
301 FoldingContext &context, FunctionRef<T> &&ref, Scalar<T> identity) {
302 static_assert(T::category == TypeCategory::Integer ||
303 T::category == TypeCategory::Real ||
304 T::category == TypeCategory::Complex);
305 std::optional<int> dim;
306 if (std::optional<ArrayAndMask<T>> arrayAndMask{
307 ProcessReductionArgs<T>(context, ref.arguments(), dim,
308 /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
309 ProductAccumulator accumulator{arrayAndMask->array};
310 auto result{Expr<T>{DoReduction<T>(
311 arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)}};
312 if (accumulator.overflow()) {
313 context.messages().Say(
314 "PRODUCT() of %s data overflowed"_warn_en_US, T::AsFortran());
315 }
316 return result;
317 }
318 return Expr<T>{std::move(ref)};
319}
320
321// SUM
322template <typename T> class SumAccumulator {
323 using Element = typename Constant<T>::Element;
324
325public:
326 SumAccumulator(const Constant<T> &array, Rounding rounding)
327 : array_{array}, rounding_{rounding} {}
328 void operator()(
329 Element &element, const ConstantSubscripts &at, bool /*first*/) {
330 if constexpr (T::category == TypeCategory::Integer) {
331 auto sum{element.AddSigned(array_.At(at))};
332 overflow_ |= sum.overflow;
333 element = sum.value;
334 } else { // Real & Complex: use Kahan summation
335 auto next{array_.At(at).Add(correction_, rounding_)};
336 overflow_ |= next.flags.test(RealFlag::Overflow);
337 auto sum{element.Add(next.value, rounding_)};
338 overflow_ |= sum.flags.test(RealFlag::Overflow);
339 // correction = (sum - element) - next; algebraically zero
340 correction_ = sum.value.Subtract(element, rounding_)
341 .value.Subtract(next.value, rounding_)
342 .value;
343 element = sum.value;
344 }
345 }
346 bool overflow() const { return overflow_; }
347 void Done([[maybe_unused]] Element &element) {
348 if constexpr (T::category != TypeCategory::Integer) {
349 auto corrected{element.Add(correction_, rounding_)};
350 overflow_ |= corrected.flags.test(RealFlag::Overflow);
351 correction_ = Scalar<T>{};
352 element = corrected.value;
353 }
354 }
355
356private:
357 const Constant<T> &array_;
358 Rounding rounding_;
359 bool overflow_{false};
360 Element correction_{};
361};
362
363template <typename T>
364static Expr<T> FoldSum(FoldingContext &context, FunctionRef<T> &&ref) {
365 static_assert(T::category == TypeCategory::Integer ||
366 T::category == TypeCategory::Real ||
367 T::category == TypeCategory::Complex);
368 using Element = typename Constant<T>::Element;
369 std::optional<int> dim;
370 Element identity{};
371 if (std::optional<ArrayAndMask<T>> arrayAndMask{
372 ProcessReductionArgs<T>(context, ref.arguments(), dim,
373 /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
374 SumAccumulator accumulator{
375 arrayAndMask->array, context.targetCharacteristics().roundingMode()};
376 auto result{Expr<T>{DoReduction<T>(
377 arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)}};
378 if (accumulator.overflow()) {
379 context.messages().Say(
380 "SUM() of %s data overflowed"_warn_en_US, T::AsFortran());
381 }
382 return result;
383 }
384 return Expr<T>{std::move(ref)};
385}
386
387// Utility for IALL, IANY, IPARITY, ALL, ANY, & PARITY
388template <typename T> class OperationAccumulator {
389public:
390 OperationAccumulator(const Constant<T> &array,
391 Scalar<T> (Scalar<T>::*operation)(const Scalar<T> &) const)
392 : array_{array}, operation_{operation} {}
393 void operator()(
394 Scalar<T> &element, const ConstantSubscripts &at, bool /*first*/) {
395 element = (element.*operation_)(array_.At(at));
396 }
397 void Done(Scalar<T> &) const {}
398
399private:
400 const Constant<T> &array_;
401 Scalar<T> (Scalar<T>::*operation_)(const Scalar<T> &) const;
402};
403
404} // namespace Fortran::evaluate
405#endif // FORTRAN_EVALUATE_FOLD_REDUCTION_H_
406

source code of flang/lib/Evaluate/fold-reduction.h