1 | //===-- lib/Evaluate/fold-reduction.h -------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef FORTRAN_EVALUATE_FOLD_REDUCTION_H_ |
10 | #define FORTRAN_EVALUATE_FOLD_REDUCTION_H_ |
11 | |
12 | #include "fold-implementation.h" |
13 | |
14 | namespace Fortran::evaluate { |
15 | |
16 | // DOT_PRODUCT |
17 | template <typename T> |
18 | static Expr<T> FoldDotProduct( |
19 | FoldingContext &context, FunctionRef<T> &&funcRef) { |
20 | using Element = typename Constant<T>::Element; |
21 | auto args{funcRef.arguments()}; |
22 | CHECK(args.size() == 2); |
23 | Folder<T> folder{context}; |
24 | Constant<T> *va{folder.Folding(args[0])}; |
25 | Constant<T> *vb{folder.Folding(args[1])}; |
26 | if (va && vb) { |
27 | CHECK(va->Rank() == 1 && vb->Rank() == 1); |
28 | if (va->size() != vb->size()) { |
29 | context.messages().Say( |
30 | "Vector arguments to DOT_PRODUCT have distinct extents %zd and %zd"_err_en_US , |
31 | va->size(), vb->size()); |
32 | return MakeInvalidIntrinsic(std::move(funcRef)); |
33 | } |
34 | Element sum{}; |
35 | bool overflow{false}; |
36 | if constexpr (T::category == TypeCategory::Complex) { |
37 | std::vector<Element> conjugates; |
38 | for (const Element &x : va->values()) { |
39 | conjugates.emplace_back(x.CONJG()); |
40 | } |
41 | Constant<T> conjgA{ |
42 | std::move(conjugates), ConstantSubscripts{va->shape()}}; |
43 | Expr<T> products{Fold( |
44 | context, Expr<T>{std::move(conjgA)} * Expr<T>{Constant<T>{*vb}})}; |
45 | Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))}; |
46 | [[maybe_unused]] Element correction{}; |
47 | const auto &rounding{context.targetCharacteristics().roundingMode()}; |
48 | for (const Element &x : cProducts.values()) { |
49 | if constexpr (useKahanSummation) { |
50 | auto next{x.Subtract(correction, rounding)}; |
51 | overflow |= next.flags.test(RealFlag::Overflow); |
52 | auto added{sum.Add(next.value, rounding)}; |
53 | overflow |= added.flags.test(RealFlag::Overflow); |
54 | correction = added.value.Subtract(sum, rounding) |
55 | .value.Subtract(next.value, rounding) |
56 | .value; |
57 | sum = std::move(added.value); |
58 | } else { |
59 | auto added{sum.Add(x, rounding)}; |
60 | overflow |= added.flags.test(RealFlag::Overflow); |
61 | sum = std::move(added.value); |
62 | } |
63 | } |
64 | } else if constexpr (T::category == TypeCategory::Logical) { |
65 | Expr<T> conjunctions{Fold(context, |
66 | Expr<T>{LogicalOperation<T::kind>{LogicalOperator::And, |
67 | Expr<T>{Constant<T>{*va}}, Expr<T>{Constant<T>{*vb}}}})}; |
68 | Constant<T> &cConjunctions{DEREF(UnwrapConstantValue<T>(conjunctions))}; |
69 | for (const Element &x : cConjunctions.values()) { |
70 | if (x.IsTrue()) { |
71 | sum = Element{true}; |
72 | break; |
73 | } |
74 | } |
75 | } else if constexpr (T::category == TypeCategory::Integer) { |
76 | Expr<T> products{ |
77 | Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})}; |
78 | Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))}; |
79 | for (const Element &x : cProducts.values()) { |
80 | auto next{sum.AddSigned(x)}; |
81 | overflow |= next.overflow; |
82 | sum = std::move(next.value); |
83 | } |
84 | } else if constexpr (T::category == TypeCategory::Unsigned) { |
85 | Expr<T> products{ |
86 | Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})}; |
87 | Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))}; |
88 | for (const Element &x : cProducts.values()) { |
89 | sum = sum.AddUnsigned(x).value; |
90 | } |
91 | } else { |
92 | static_assert(T::category == TypeCategory::Real); |
93 | Expr<T> products{ |
94 | Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})}; |
95 | Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))}; |
96 | [[maybe_unused]] Element correction{}; |
97 | const auto &rounding{context.targetCharacteristics().roundingMode()}; |
98 | for (const Element &x : cProducts.values()) { |
99 | if constexpr (useKahanSummation) { |
100 | auto next{x.Subtract(correction, rounding)}; |
101 | overflow |= next.flags.test(RealFlag::Overflow); |
102 | auto added{sum.Add(next.value, rounding)}; |
103 | overflow |= added.flags.test(RealFlag::Overflow); |
104 | correction = added.value.Subtract(sum, rounding) |
105 | .value.Subtract(next.value, rounding) |
106 | .value; |
107 | sum = std::move(added.value); |
108 | } else { |
109 | auto added{sum.Add(x, rounding)}; |
110 | overflow |= added.flags.test(RealFlag::Overflow); |
111 | sum = std::move(added.value); |
112 | } |
113 | } |
114 | } |
115 | if (overflow && |
116 | context.languageFeatures().ShouldWarn( |
117 | common::UsageWarning::FoldingException)) { |
118 | context.messages().Say(common::UsageWarning::FoldingException, |
119 | "DOT_PRODUCT of %s data overflowed during computation"_warn_en_US , |
120 | T::AsFortran()); |
121 | } |
122 | return Expr<T>{Constant<T>{std::move(sum)}}; |
123 | } |
124 | return Expr<T>{std::move(funcRef)}; |
125 | } |
126 | |
127 | // Fold and validate a DIM= argument. Returns false on error. |
128 | bool CheckReductionDIM(std::optional<int> &dim, FoldingContext &, |
129 | ActualArguments &, std::optional<int> dimIndex, int rank); |
130 | |
131 | // Fold and validate a MASK= argument. Return null on error, absent MASK=, or |
132 | // non-constant MASK=. |
133 | Constant<LogicalResult> *GetReductionMASK( |
134 | std::optional<ActualArgument> &maskArg, const ConstantSubscripts &shape, |
135 | FoldingContext &); |
136 | |
137 | // Common preprocessing for reduction transformational intrinsic function |
138 | // folding. If the intrinsic can have DIM= &/or MASK= arguments, extract |
139 | // and check them. If a MASK= is present, apply it to the array data and |
140 | // substitute replacement values for elements corresponding to .FALSE. in |
141 | // the mask. If the result is present, the intrinsic call can be folded. |
142 | template <typename T> struct ArrayAndMask { |
143 | Constant<T> array; |
144 | Constant<LogicalResult> mask; |
145 | }; |
146 | template <typename T> |
147 | static std::optional<ArrayAndMask<T>> ProcessReductionArgs( |
148 | FoldingContext &context, ActualArguments &arg, std::optional<int> &dim, |
149 | int arrayIndex, std::optional<int> dimIndex = std::nullopt, |
150 | std::optional<int> maskIndex = std::nullopt) { |
151 | if (arg.empty()) { |
152 | return std::nullopt; |
153 | } |
154 | Constant<T> *folded{Folder<T>{context}.Folding(arg[arrayIndex])}; |
155 | if (!folded || folded->Rank() < 1) { |
156 | return std::nullopt; |
157 | } |
158 | if (!CheckReductionDIM(dim, context, arg, dimIndex, folded->Rank())) { |
159 | return std::nullopt; |
160 | } |
161 | std::size_t n{folded->size()}; |
162 | std::vector<Scalar<LogicalResult>> maskElement; |
163 | if (maskIndex && static_cast<std::size_t>(*maskIndex) < arg.size() && |
164 | arg[*maskIndex]) { |
165 | if (const Constant<LogicalResult> *origMask{ |
166 | GetReductionMASK(arg[*maskIndex], folded->shape(), context)}) { |
167 | if (auto scalarMask{origMask->GetScalarValue()}) { |
168 | maskElement = |
169 | std::vector<Scalar<LogicalResult>>(n, scalarMask->IsTrue()); |
170 | } else { |
171 | maskElement = origMask->values(); |
172 | } |
173 | } else { |
174 | return std::nullopt; |
175 | } |
176 | } else { |
177 | maskElement = std::vector<Scalar<LogicalResult>>(n, true); |
178 | } |
179 | return ArrayAndMask<T>{Constant<T>(*folded), |
180 | Constant<LogicalResult>{ |
181 | std::move(maskElement), ConstantSubscripts{folded->shape()}}}; |
182 | } |
183 | |
184 | // Generalized reduction to an array of one dimension fewer (w/ DIM=) |
185 | // or to a scalar (w/o DIM=). The ACCUMULATOR type must define |
186 | // operator()(Scalar<T> &, const ConstantSubscripts &, bool first) |
187 | // and Done(Scalar<T> &). |
188 | template <typename T, typename ACCUMULATOR, typename ARRAY> |
189 | static Constant<T> DoReduction(const Constant<ARRAY> &array, |
190 | const Constant<LogicalResult> &mask, std::optional<int> &dim, |
191 | const Scalar<T> &identity, ACCUMULATOR &accumulator) { |
192 | ConstantSubscripts at{array.lbounds()}; |
193 | ConstantSubscripts maskAt{mask.lbounds()}; |
194 | std::vector<typename Constant<T>::Element> elements; |
195 | ConstantSubscripts resultShape; // empty -> scalar |
196 | if (dim) { // DIM= is present, so result is an array |
197 | resultShape = array.shape(); |
198 | resultShape.erase(resultShape.begin() + (*dim - 1)); |
199 | ConstantSubscript dimExtent{array.shape().at(*dim - 1)}; |
200 | CHECK(dimExtent == mask.shape().at(*dim - 1)); |
201 | ConstantSubscript &dimAt{at[*dim - 1]}; |
202 | ConstantSubscript dimLbound{dimAt}; |
203 | ConstantSubscript &maskDimAt{maskAt[*dim - 1]}; |
204 | ConstantSubscript maskDimLbound{maskDimAt}; |
205 | for (auto n{GetSize(resultShape)}; n-- > 0; |
206 | array.IncrementSubscripts(at), mask.IncrementSubscripts(maskAt)) { |
207 | elements.push_back(identity); |
208 | if (dimExtent > 0) { |
209 | dimAt = dimLbound; |
210 | maskDimAt = maskDimLbound; |
211 | bool firstUnmasked{true}; |
212 | for (ConstantSubscript j{0}; j < dimExtent; ++j, ++dimAt, ++maskDimAt) { |
213 | if (mask.At(maskAt).IsTrue()) { |
214 | accumulator(elements.back(), at, firstUnmasked); |
215 | firstUnmasked = false; |
216 | } |
217 | } |
218 | --dimAt, --maskDimAt; |
219 | } |
220 | accumulator.Done(elements.back()); |
221 | } |
222 | } else { // no DIM=, result is scalar |
223 | elements.push_back(identity); |
224 | bool firstUnmasked{true}; |
225 | for (auto n{array.size()}; n-- > 0; |
226 | array.IncrementSubscripts(at), mask.IncrementSubscripts(maskAt)) { |
227 | if (mask.At(maskAt).IsTrue()) { |
228 | accumulator(elements.back(), at, firstUnmasked); |
229 | firstUnmasked = false; |
230 | } |
231 | } |
232 | accumulator.Done(elements.back()); |
233 | } |
234 | if constexpr (T::category == TypeCategory::Character) { |
235 | return {static_cast<ConstantSubscript>(identity.size()), |
236 | std::move(elements), std::move(resultShape)}; |
237 | } else { |
238 | return {std::move(elements), std::move(resultShape)}; |
239 | } |
240 | } |
241 | |
242 | // MAXVAL & MINVAL |
243 | template <typename T, bool ABS = false> class MaxvalMinvalAccumulator { |
244 | public: |
245 | MaxvalMinvalAccumulator( |
246 | RelationalOperator opr, FoldingContext &context, const Constant<T> &array) |
247 | : opr_{opr}, context_{context}, array_{array} {}; |
248 | void operator()(Scalar<T> &element, const ConstantSubscripts &at, |
249 | [[maybe_unused]] bool firstUnmasked) const { |
250 | auto aAt{array_.At(at)}; |
251 | if constexpr (ABS) { |
252 | aAt = aAt.ABS(); |
253 | } |
254 | if constexpr (T::category == TypeCategory::Real) { |
255 | if (firstUnmasked || element.IsNotANumber()) { |
256 | // Return NaN if and only if all unmasked elements are NaNs and |
257 | // at least one unmasked element is visible. |
258 | element = aAt; |
259 | return; |
260 | } |
261 | } |
262 | Expr<LogicalResult> test{PackageRelation( |
263 | opr_, Expr<T>{Constant<T>{aAt}}, Expr<T>{Constant<T>{element}})}; |
264 | auto folded{GetScalarConstantValue<LogicalResult>( |
265 | test.Rewrite(context_, std::move(test)))}; |
266 | CHECK(folded.has_value()); |
267 | if (folded->IsTrue()) { |
268 | element = aAt; |
269 | } |
270 | } |
271 | void Done(Scalar<T> &) const {} |
272 | |
273 | private: |
274 | RelationalOperator opr_; |
275 | FoldingContext &context_; |
276 | const Constant<T> &array_; |
277 | }; |
278 | |
279 | template <typename T> |
280 | static Expr<T> FoldMaxvalMinval(FoldingContext &context, FunctionRef<T> &&ref, |
281 | RelationalOperator opr, const Scalar<T> &identity) { |
282 | static_assert(T::category == TypeCategory::Integer || |
283 | T::category == TypeCategory::Unsigned || |
284 | T::category == TypeCategory::Real || |
285 | T::category == TypeCategory::Character); |
286 | std::optional<int> dim; |
287 | if (std::optional<ArrayAndMask<T>> arrayAndMask{ |
288 | ProcessReductionArgs<T>(context, ref.arguments(), dim, |
289 | /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) { |
290 | MaxvalMinvalAccumulator<T> accumulator{opr, context, arrayAndMask->array}; |
291 | return Expr<T>{DoReduction<T>( |
292 | arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)}; |
293 | } |
294 | return Expr<T>{std::move(ref)}; |
295 | } |
296 | |
297 | // PRODUCT |
298 | template <typename T> class ProductAccumulator { |
299 | public: |
300 | ProductAccumulator(const Constant<T> &array) : array_{array} {} |
301 | void operator()( |
302 | Scalar<T> &element, const ConstantSubscripts &at, bool /*first*/) { |
303 | if constexpr (T::category == TypeCategory::Integer) { |
304 | auto prod{element.MultiplySigned(array_.At(at))}; |
305 | overflow_ |= prod.SignedMultiplicationOverflowed(); |
306 | element = prod.lower; |
307 | } else if constexpr (T::category == TypeCategory::Unsigned) { |
308 | element = element.MultiplyUnsigned(array_.At(at)).lower; |
309 | } else { // Real & Complex |
310 | auto prod{element.Multiply(array_.At(at))}; |
311 | overflow_ |= prod.flags.test(RealFlag::Overflow); |
312 | element = prod.value; |
313 | } |
314 | } |
315 | bool overflow() const { return overflow_; } |
316 | void Done(Scalar<T> &) const {} |
317 | |
318 | private: |
319 | const Constant<T> &array_; |
320 | bool overflow_{false}; |
321 | }; |
322 | |
323 | template <typename T> |
324 | static Expr<T> FoldProduct( |
325 | FoldingContext &context, FunctionRef<T> &&ref, Scalar<T> identity) { |
326 | static_assert(T::category == TypeCategory::Integer || |
327 | T::category == TypeCategory::Unsigned || |
328 | T::category == TypeCategory::Real || |
329 | T::category == TypeCategory::Complex); |
330 | std::optional<int> dim; |
331 | if (std::optional<ArrayAndMask<T>> arrayAndMask{ |
332 | ProcessReductionArgs<T>(context, ref.arguments(), dim, |
333 | /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) { |
334 | ProductAccumulator accumulator{arrayAndMask->array}; |
335 | auto result{Expr<T>{DoReduction<T>( |
336 | arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)}}; |
337 | if (accumulator.overflow() && |
338 | context.languageFeatures().ShouldWarn( |
339 | common::UsageWarning::FoldingException)) { |
340 | context.messages().Say(common::UsageWarning::FoldingException, |
341 | "PRODUCT() of %s data overflowed"_warn_en_US , T::AsFortran()); |
342 | } |
343 | return result; |
344 | } |
345 | return Expr<T>{std::move(ref)}; |
346 | } |
347 | |
348 | // SUM |
349 | template <typename T> class SumAccumulator { |
350 | using Element = typename Constant<T>::Element; |
351 | |
352 | public: |
353 | SumAccumulator(const Constant<T> &array, Rounding rounding) |
354 | : array_{array}, rounding_{rounding} {} |
355 | void operator()( |
356 | Element &element, const ConstantSubscripts &at, bool /*first*/) { |
357 | if constexpr (T::category == TypeCategory::Integer) { |
358 | auto sum{element.AddSigned(array_.At(at))}; |
359 | overflow_ |= sum.overflow; |
360 | element = sum.value; |
361 | } else if constexpr (T::category == TypeCategory::Unsigned) { |
362 | element = element.AddUnsigned(array_.At(at)).value; |
363 | } else { // Real & Complex: use Kahan summation |
364 | auto next{array_.At(at).Subtract(correction_, rounding_)}; |
365 | overflow_ |= next.flags.test(RealFlag::Overflow); |
366 | auto sum{element.Add(next.value, rounding_)}; |
367 | overflow_ |= sum.flags.test(RealFlag::Overflow); |
368 | // correction = (sum - element) - next; algebraically zero |
369 | correction_ = sum.value.Subtract(element, rounding_) |
370 | .value.Subtract(next.value, rounding_) |
371 | .value; |
372 | element = sum.value; |
373 | } |
374 | } |
375 | bool overflow() const { return overflow_; } |
376 | void Done([[maybe_unused]] Element &element) { |
377 | if constexpr (T::category != TypeCategory::Integer && |
378 | T::category != TypeCategory::Unsigned) { |
379 | auto corrected{element.Add(correction_, rounding_)}; |
380 | overflow_ |= corrected.flags.test(RealFlag::Overflow); |
381 | correction_ = Scalar<T>{}; |
382 | element = corrected.value; |
383 | } |
384 | } |
385 | |
386 | private: |
387 | const Constant<T> &array_; |
388 | Rounding rounding_; |
389 | bool overflow_{false}; |
390 | Element correction_{}; |
391 | }; |
392 | |
393 | template <typename T> |
394 | static Expr<T> FoldSum(FoldingContext &context, FunctionRef<T> &&ref) { |
395 | static_assert(T::category == TypeCategory::Integer || |
396 | T::category == TypeCategory::Unsigned || |
397 | T::category == TypeCategory::Real || |
398 | T::category == TypeCategory::Complex); |
399 | using Element = typename Constant<T>::Element; |
400 | std::optional<int> dim; |
401 | Element identity{}; |
402 | if (std::optional<ArrayAndMask<T>> arrayAndMask{ |
403 | ProcessReductionArgs<T>(context, ref.arguments(), dim, |
404 | /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) { |
405 | SumAccumulator accumulator{ |
406 | arrayAndMask->array, context.targetCharacteristics().roundingMode()}; |
407 | auto result{Expr<T>{DoReduction<T>( |
408 | arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)}}; |
409 | if (accumulator.overflow() && |
410 | context.languageFeatures().ShouldWarn( |
411 | common::UsageWarning::FoldingException)) { |
412 | context.messages().Say(common::UsageWarning::FoldingException, |
413 | "SUM() of %s data overflowed"_warn_en_US , T::AsFortran()); |
414 | } |
415 | return result; |
416 | } |
417 | return Expr<T>{std::move(ref)}; |
418 | } |
419 | |
420 | // Utility for IALL, IANY, IPARITY, ALL, ANY, & PARITY |
421 | template <typename T> class OperationAccumulator { |
422 | public: |
423 | OperationAccumulator(const Constant<T> &array, |
424 | Scalar<T> (Scalar<T>::*operation)(const Scalar<T> &) const) |
425 | : array_{array}, operation_{operation} {} |
426 | void operator()( |
427 | Scalar<T> &element, const ConstantSubscripts &at, bool /*first*/) { |
428 | element = (element.*operation_)(array_.At(at)); |
429 | } |
430 | void Done(Scalar<T> &) const {} |
431 | |
432 | private: |
433 | const Constant<T> &array_; |
434 | Scalar<T> (Scalar<T>::*operation_)(const Scalar<T> &) const; |
435 | }; |
436 | |
437 | } // namespace Fortran::evaluate |
438 | #endif // FORTRAN_EVALUATE_FOLD_REDUCTION_H_ |
439 | |