1 | //===-- lib/Evaluate/fold-reduction.h -------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef FORTRAN_EVALUATE_FOLD_REDUCTION_H_ |
10 | #define FORTRAN_EVALUATE_FOLD_REDUCTION_H_ |
11 | |
12 | #include "fold-implementation.h" |
13 | |
14 | namespace Fortran::evaluate { |
15 | |
16 | // DOT_PRODUCT |
17 | template <typename T> |
18 | static Expr<T> FoldDotProduct( |
19 | FoldingContext &context, FunctionRef<T> &&funcRef) { |
20 | using Element = typename Constant<T>::Element; |
21 | auto args{funcRef.arguments()}; |
22 | CHECK(args.size() == 2); |
23 | Folder<T> folder{context}; |
24 | Constant<T> *va{folder.Folding(args[0])}; |
25 | Constant<T> *vb{folder.Folding(args[1])}; |
26 | if (va && vb) { |
27 | CHECK(va->Rank() == 1 && vb->Rank() == 1); |
28 | if (va->size() != vb->size()) { |
29 | context.messages().Say( |
30 | "Vector arguments to DOT_PRODUCT have distinct extents %zd and %zd"_err_en_US , |
31 | va->size(), vb->size()); |
32 | return MakeInvalidIntrinsic(std::move(funcRef)); |
33 | } |
34 | Element sum{}; |
35 | bool overflow{false}; |
36 | if constexpr (T::category == TypeCategory::Complex) { |
37 | std::vector<Element> conjugates; |
38 | for (const Element &x : va->values()) { |
39 | conjugates.emplace_back(x.CONJG()); |
40 | } |
41 | Constant<T> conjgA{ |
42 | std::move(conjugates), ConstantSubscripts{va->shape()}}; |
43 | Expr<T> products{Fold( |
44 | context, Expr<T>{std::move(conjgA)} * Expr<T>{Constant<T>{*vb}})}; |
45 | Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))}; |
46 | Element correction{}; // Use Kahan summation for greater precision. |
47 | const auto &rounding{context.targetCharacteristics().roundingMode()}; |
48 | for (const Element &x : cProducts.values()) { |
49 | auto next{correction.Add(x, rounding)}; |
50 | overflow |= next.flags.test(RealFlag::Overflow); |
51 | auto added{sum.Add(next.value, rounding)}; |
52 | overflow |= added.flags.test(RealFlag::Overflow); |
53 | correction = added.value.Subtract(sum, rounding) |
54 | .value.Subtract(next.value, rounding) |
55 | .value; |
56 | sum = std::move(added.value); |
57 | } |
58 | } else if constexpr (T::category == TypeCategory::Logical) { |
59 | Expr<T> conjunctions{Fold(context, |
60 | Expr<T>{LogicalOperation<T::kind>{LogicalOperator::And, |
61 | Expr<T>{Constant<T>{*va}}, Expr<T>{Constant<T>{*vb}}}})}; |
62 | Constant<T> &cConjunctions{DEREF(UnwrapConstantValue<T>(conjunctions))}; |
63 | for (const Element &x : cConjunctions.values()) { |
64 | if (x.IsTrue()) { |
65 | sum = Element{true}; |
66 | break; |
67 | } |
68 | } |
69 | } else if constexpr (T::category == TypeCategory::Integer) { |
70 | Expr<T> products{ |
71 | Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})}; |
72 | Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))}; |
73 | for (const Element &x : cProducts.values()) { |
74 | auto next{sum.AddSigned(x)}; |
75 | overflow |= next.overflow; |
76 | sum = std::move(next.value); |
77 | } |
78 | } else { |
79 | static_assert(T::category == TypeCategory::Real); |
80 | Expr<T> products{ |
81 | Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})}; |
82 | Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))}; |
83 | Element correction{}; // Use Kahan summation for greater precision. |
84 | const auto &rounding{context.targetCharacteristics().roundingMode()}; |
85 | for (const Element &x : cProducts.values()) { |
86 | auto next{correction.Add(x, rounding)}; |
87 | overflow |= next.flags.test(RealFlag::Overflow); |
88 | auto added{sum.Add(next.value, rounding)}; |
89 | overflow |= added.flags.test(RealFlag::Overflow); |
90 | correction = added.value.Subtract(sum, rounding) |
91 | .value.Subtract(next.value, rounding) |
92 | .value; |
93 | sum = std::move(added.value); |
94 | } |
95 | } |
96 | if (overflow) { |
97 | context.messages().Say( |
98 | "DOT_PRODUCT of %s data overflowed during computation"_warn_en_US , |
99 | T::AsFortran()); |
100 | } |
101 | return Expr<T>{Constant<T>{std::move(sum)}}; |
102 | } |
103 | return Expr<T>{std::move(funcRef)}; |
104 | } |
105 | |
106 | // Fold and validate a DIM= argument. Returns false on error. |
107 | bool CheckReductionDIM(std::optional<int> &dim, FoldingContext &, |
108 | ActualArguments &, std::optional<int> dimIndex, int rank); |
109 | |
110 | // Fold and validate a MASK= argument. Return null on error, absent MASK=, or |
111 | // non-constant MASK=. |
112 | Constant<LogicalResult> *GetReductionMASK( |
113 | std::optional<ActualArgument> &maskArg, const ConstantSubscripts &shape, |
114 | FoldingContext &); |
115 | |
// Result of ProcessReductionArgs(), the common preprocessing for folding
// reduction transformational intrinsic functions.  That function extracts
// and checks any DIM= and MASK= arguments.  It does not apply MASK= to the
// data; instead it returns a copy of the folded ARRAY= together with a
// LOGICAL mask of the same shape.  The mask is taken from a constant MASK=
// (a scalar MASK= is broadcast), or is all .TRUE. when MASK= is absent.
// When a result is present, the intrinsic call can be folded.
template <typename T> struct ArrayAndMask {
  Constant<T> array; // folded ARRAY= data
  Constant<LogicalResult> mask; // same shape as 'array'; .TRUE. = included
};
125 | template <typename T> |
126 | static std::optional<ArrayAndMask<T>> ProcessReductionArgs( |
127 | FoldingContext &context, ActualArguments &arg, std::optional<int> &dim, |
128 | int arrayIndex, std::optional<int> dimIndex = std::nullopt, |
129 | std::optional<int> maskIndex = std::nullopt) { |
130 | if (arg.empty()) { |
131 | return std::nullopt; |
132 | } |
133 | Constant<T> *folded{Folder<T>{context}.Folding(arg[arrayIndex])}; |
134 | if (!folded || folded->Rank() < 1) { |
135 | return std::nullopt; |
136 | } |
137 | if (!CheckReductionDIM(dim, context, arg, dimIndex, folded->Rank())) { |
138 | return std::nullopt; |
139 | } |
140 | std::size_t n{folded->size()}; |
141 | std::vector<Scalar<LogicalResult>> maskElement; |
142 | if (maskIndex && static_cast<std::size_t>(*maskIndex) < arg.size() && |
143 | arg[*maskIndex]) { |
144 | if (const Constant<LogicalResult> *origMask{ |
145 | GetReductionMASK(arg[*maskIndex], folded->shape(), context)}) { |
146 | if (auto scalarMask{origMask->GetScalarValue()}) { |
147 | maskElement = |
148 | std::vector<Scalar<LogicalResult>>(n, scalarMask->IsTrue()); |
149 | } else { |
150 | maskElement = origMask->values(); |
151 | } |
152 | } else { |
153 | return std::nullopt; |
154 | } |
155 | } else { |
156 | maskElement = std::vector<Scalar<LogicalResult>>(n, true); |
157 | } |
158 | return ArrayAndMask<T>{Constant<T>(*folded), |
159 | Constant<LogicalResult>{ |
160 | std::move(maskElement), ConstantSubscripts{folded->shape()}}}; |
161 | } |
162 | |
// Generalized reduction to an array of one dimension fewer (w/ DIM=)
// or to a scalar (w/o DIM=). The ACCUMULATOR type must define
// operator()(Scalar<T> &, const ConstantSubscripts &, bool first)
// and Done(Scalar<T> &).
//
// Every result element starts as a copy of 'identity'.  For each element of
// 'array' whose corresponding 'mask' element is .TRUE., the accumulator is
// called with the result element, that array element's subscripts, and a
// flag that is true only for the first unmasked element contributing to
// that result element.  Done() is then called once per result element,
// including result elements that received no contributions.
// 'dim', when present, is a 1-based dimension number.  'mask' must have the
// same extent as 'array' along DIM= (CHECKed); it is walked with its own
// subscripts, starting from its own lower bounds.
// For CHARACTER results, the length is taken from 'identity'.
template <typename T, typename ACCUMULATOR, typename ARRAY>
static Constant<T> DoReduction(const Constant<ARRAY> &array,
    const Constant<LogicalResult> &mask, std::optional<int> &dim,
    const Scalar<T> &identity, ACCUMULATOR &accumulator) {
  ConstantSubscripts at{array.lbounds()};
  ConstantSubscripts maskAt{mask.lbounds()};
  std::vector<typename Constant<T>::Element> elements;
  ConstantSubscripts resultShape; // empty -> scalar
  if (dim) { // DIM= is present, so result is an array
    resultShape = array.shape();
    resultShape.erase(resultShape.begin() + (*dim - 1));
    ConstantSubscript dimExtent{array.shape().at(*dim - 1)};
    CHECK(dimExtent == mask.shape().at(*dim - 1));
    // References into the subscript vectors: stepping dimAt/maskDimAt walks
    // along DIM= while the remaining subscripts select the result element.
    ConstantSubscript &dimAt{at[*dim - 1]};
    ConstantSubscript dimLbound{dimAt};
    ConstantSubscript &maskDimAt{maskAt[*dim - 1]};
    ConstantSubscript maskDimLbound{maskDimAt};
    // One iteration per result element.
    for (auto n{GetSize(resultShape)}; n-- > 0;
         array.IncrementSubscripts(at), mask.IncrementSubscripts(maskAt)) {
      elements.push_back(identity);
      if (dimExtent > 0) {
        dimAt = dimLbound;
        maskDimAt = maskDimLbound;
        bool firstUnmasked{true};
        for (ConstantSubscript j{0}; j < dimExtent; ++j, ++dimAt, ++maskDimAt) {
          if (mask.At(maskAt).IsTrue()) {
            accumulator(elements.back(), at, firstUnmasked);
            firstUnmasked = false;
          }
        }
        // The loop leaves the DIM= subscripts one past their last values;
        // step back so they end at the last valid position.  Presumably
        // IncrementSubscripts() then wraps them and advances the other
        // dimensions to the next result element -- verify against
        // IncrementSubscripts().
        --dimAt, --maskDimAt;
      }
      accumulator.Done(elements.back());
    }
  } else { // no DIM=, result is scalar
    elements.push_back(identity);
    bool firstUnmasked{true};
    for (auto n{array.size()}; n-- > 0;
         array.IncrementSubscripts(at), mask.IncrementSubscripts(maskAt)) {
      if (mask.At(maskAt).IsTrue()) {
        accumulator(elements.back(), at, firstUnmasked);
        firstUnmasked = false;
      }
    }
    accumulator.Done(elements.back());
  }
  if constexpr (T::category == TypeCategory::Character) {
    return {static_cast<ConstantSubscript>(identity.size()),
        std::move(elements), std::move(resultShape)};
  } else {
    return {std::move(elements), std::move(resultShape)};
  }
}
220 | |
221 | // MAXVAL & MINVAL |
222 | template <typename T, bool ABS = false> class MaxvalMinvalAccumulator { |
223 | public: |
224 | MaxvalMinvalAccumulator( |
225 | RelationalOperator opr, FoldingContext &context, const Constant<T> &array) |
226 | : opr_{opr}, context_{context}, array_{array} {}; |
227 | void operator()(Scalar<T> &element, const ConstantSubscripts &at, |
228 | [[maybe_unused]] bool firstUnmasked) const { |
229 | auto aAt{array_.At(at)}; |
230 | if constexpr (ABS) { |
231 | aAt = aAt.ABS(); |
232 | } |
233 | if constexpr (T::category == TypeCategory::Real) { |
234 | if (firstUnmasked || element.IsNotANumber()) { |
235 | // Return NaN if and only if all unmasked elements are NaNs and |
236 | // at least one unmasked element is visible. |
237 | element = aAt; |
238 | return; |
239 | } |
240 | } |
241 | Expr<LogicalResult> test{PackageRelation( |
242 | opr_, Expr<T>{Constant<T>{aAt}}, Expr<T>{Constant<T>{element}})}; |
243 | auto folded{GetScalarConstantValue<LogicalResult>( |
244 | test.Rewrite(context_, std::move(test)))}; |
245 | CHECK(folded.has_value()); |
246 | if (folded->IsTrue()) { |
247 | element = aAt; |
248 | } |
249 | } |
250 | void Done(Scalar<T> &) const {} |
251 | |
252 | private: |
253 | RelationalOperator opr_; |
254 | FoldingContext &context_; |
255 | const Constant<T> &array_; |
256 | }; |
257 | |
258 | template <typename T> |
259 | static Expr<T> FoldMaxvalMinval(FoldingContext &context, FunctionRef<T> &&ref, |
260 | RelationalOperator opr, const Scalar<T> &identity) { |
261 | static_assert(T::category == TypeCategory::Integer || |
262 | T::category == TypeCategory::Real || |
263 | T::category == TypeCategory::Character); |
264 | std::optional<int> dim; |
265 | if (std::optional<ArrayAndMask<T>> arrayAndMask{ |
266 | ProcessReductionArgs<T>(context, ref.arguments(), dim, |
267 | /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) { |
268 | MaxvalMinvalAccumulator accumulator{opr, context, arrayAndMask->array}; |
269 | return Expr<T>{DoReduction<T>( |
270 | arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)}; |
271 | } |
272 | return Expr<T>{std::move(ref)}; |
273 | } |
274 | |
275 | // PRODUCT |
276 | template <typename T> class ProductAccumulator { |
277 | public: |
278 | ProductAccumulator(const Constant<T> &array) : array_{array} {} |
279 | void operator()( |
280 | Scalar<T> &element, const ConstantSubscripts &at, bool /*first*/) { |
281 | if constexpr (T::category == TypeCategory::Integer) { |
282 | auto prod{element.MultiplySigned(array_.At(at))}; |
283 | overflow_ |= prod.SignedMultiplicationOverflowed(); |
284 | element = prod.lower; |
285 | } else { // Real & Complex |
286 | auto prod{element.Multiply(array_.At(at))}; |
287 | overflow_ |= prod.flags.test(RealFlag::Overflow); |
288 | element = prod.value; |
289 | } |
290 | } |
291 | bool overflow() const { return overflow_; } |
292 | void Done(Scalar<T> &) const {} |
293 | |
294 | private: |
295 | const Constant<T> &array_; |
296 | bool overflow_{false}; |
297 | }; |
298 | |
299 | template <typename T> |
300 | static Expr<T> FoldProduct( |
301 | FoldingContext &context, FunctionRef<T> &&ref, Scalar<T> identity) { |
302 | static_assert(T::category == TypeCategory::Integer || |
303 | T::category == TypeCategory::Real || |
304 | T::category == TypeCategory::Complex); |
305 | std::optional<int> dim; |
306 | if (std::optional<ArrayAndMask<T>> arrayAndMask{ |
307 | ProcessReductionArgs<T>(context, ref.arguments(), dim, |
308 | /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) { |
309 | ProductAccumulator accumulator{arrayAndMask->array}; |
310 | auto result{Expr<T>{DoReduction<T>( |
311 | arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)}}; |
312 | if (accumulator.overflow()) { |
313 | context.messages().Say( |
314 | "PRODUCT() of %s data overflowed"_warn_en_US , T::AsFortran()); |
315 | } |
316 | return result; |
317 | } |
318 | return Expr<T>{std::move(ref)}; |
319 | } |
320 | |
321 | // SUM |
322 | template <typename T> class SumAccumulator { |
323 | using Element = typename Constant<T>::Element; |
324 | |
325 | public: |
326 | SumAccumulator(const Constant<T> &array, Rounding rounding) |
327 | : array_{array}, rounding_{rounding} {} |
328 | void operator()( |
329 | Element &element, const ConstantSubscripts &at, bool /*first*/) { |
330 | if constexpr (T::category == TypeCategory::Integer) { |
331 | auto sum{element.AddSigned(array_.At(at))}; |
332 | overflow_ |= sum.overflow; |
333 | element = sum.value; |
334 | } else { // Real & Complex: use Kahan summation |
335 | auto next{array_.At(at).Add(correction_, rounding_)}; |
336 | overflow_ |= next.flags.test(RealFlag::Overflow); |
337 | auto sum{element.Add(next.value, rounding_)}; |
338 | overflow_ |= sum.flags.test(RealFlag::Overflow); |
339 | // correction = (sum - element) - next; algebraically zero |
340 | correction_ = sum.value.Subtract(element, rounding_) |
341 | .value.Subtract(next.value, rounding_) |
342 | .value; |
343 | element = sum.value; |
344 | } |
345 | } |
346 | bool overflow() const { return overflow_; } |
347 | void Done([[maybe_unused]] Element &element) { |
348 | if constexpr (T::category != TypeCategory::Integer) { |
349 | auto corrected{element.Add(correction_, rounding_)}; |
350 | overflow_ |= corrected.flags.test(RealFlag::Overflow); |
351 | correction_ = Scalar<T>{}; |
352 | element = corrected.value; |
353 | } |
354 | } |
355 | |
356 | private: |
357 | const Constant<T> &array_; |
358 | Rounding rounding_; |
359 | bool overflow_{false}; |
360 | Element correction_{}; |
361 | }; |
362 | |
363 | template <typename T> |
364 | static Expr<T> FoldSum(FoldingContext &context, FunctionRef<T> &&ref) { |
365 | static_assert(T::category == TypeCategory::Integer || |
366 | T::category == TypeCategory::Real || |
367 | T::category == TypeCategory::Complex); |
368 | using Element = typename Constant<T>::Element; |
369 | std::optional<int> dim; |
370 | Element identity{}; |
371 | if (std::optional<ArrayAndMask<T>> arrayAndMask{ |
372 | ProcessReductionArgs<T>(context, ref.arguments(), dim, |
373 | /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) { |
374 | SumAccumulator accumulator{ |
375 | arrayAndMask->array, context.targetCharacteristics().roundingMode()}; |
376 | auto result{Expr<T>{DoReduction<T>( |
377 | arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)}}; |
378 | if (accumulator.overflow()) { |
379 | context.messages().Say( |
380 | "SUM() of %s data overflowed"_warn_en_US , T::AsFortran()); |
381 | } |
382 | return result; |
383 | } |
384 | return Expr<T>{std::move(ref)}; |
385 | } |
386 | |
387 | // Utility for IALL, IANY, IPARITY, ALL, ANY, & PARITY |
388 | template <typename T> class OperationAccumulator { |
389 | public: |
390 | OperationAccumulator(const Constant<T> &array, |
391 | Scalar<T> (Scalar<T>::*operation)(const Scalar<T> &) const) |
392 | : array_{array}, operation_{operation} {} |
393 | void operator()( |
394 | Scalar<T> &element, const ConstantSubscripts &at, bool /*first*/) { |
395 | element = (element.*operation_)(array_.At(at)); |
396 | } |
397 | void Done(Scalar<T> &) const {} |
398 | |
399 | private: |
400 | const Constant<T> &array_; |
401 | Scalar<T> (Scalar<T>::*operation_)(const Scalar<T> &) const; |
402 | }; |
403 | |
404 | } // namespace Fortran::evaluate |
405 | #endif // FORTRAN_EVALUATE_FOLD_REDUCTION_H_ |
406 | |