1//===-- runtime/product.cpp -----------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9// Implements PRODUCT for all required operand types and shapes.
10
#include "reduction-templates.h"
#include "flang/Common/float128.h"
#include "flang/Runtime/reduction.h"
#include <cfloat>
#include <cinttypes>
#include <complex>
#include <type_traits>
17
18namespace Fortran::runtime {
19template <typename INTERMEDIATE> class NonComplexProductAccumulator {
20public:
21 explicit RT_API_ATTRS NonComplexProductAccumulator(const Descriptor &array)
22 : array_{array} {}
23 RT_API_ATTRS void Reinitialize() { product_ = 1; }
24 template <typename A>
25 RT_API_ATTRS void GetResult(A *p, int /*zeroBasedDim*/ = -1) const {
26 *p = static_cast<A>(product_);
27 }
28 template <typename A>
29 RT_API_ATTRS bool AccumulateAt(const SubscriptValue at[]) {
30 product_ *= *array_.Element<A>(at);
31 return product_ != 0;
32 }
33
34private:
35 const Descriptor &array_;
36 INTERMEDIATE product_{1};
37};
38
39// Suppress the warnings about calling __host__-only std::complex operators,
40// defined in C++ STD header files, from __device__ code.
41RT_DIAG_PUSH
42RT_DIAG_DISABLE_CALL_HOST_FROM_DEVICE_WARN
43
44template <typename PART> class ComplexProductAccumulator {
45public:
46 explicit RT_API_ATTRS ComplexProductAccumulator(const Descriptor &array)
47 : array_{array} {}
48 RT_API_ATTRS void Reinitialize() { product_ = std::complex<PART>{1, 0}; }
49 template <typename A>
50 RT_API_ATTRS void GetResult(A *p, int /*zeroBasedDim*/ = -1) const {
51 using ResultPart = typename A::value_type;
52 *p = {static_cast<ResultPart>(product_.real()),
53 static_cast<ResultPart>(product_.imag())};
54 }
55 template <typename A>
56 RT_API_ATTRS bool AccumulateAt(const SubscriptValue at[]) {
57 product_ *= *array_.Element<A>(at);
58 return true;
59 }
60
61private:
62 const Descriptor &array_;
63 std::complex<PART> product_{1, 0};
64};
65
66RT_DIAG_POP
67
68extern "C" {
69RT_EXT_API_GROUP_BEGIN
70
71CppTypeFor<TypeCategory::Integer, 1> RTDEF(ProductInteger1)(const Descriptor &x,
72 const char *source, int line, int dim, const Descriptor *mask) {
73 return GetTotalReduction<TypeCategory::Integer, 1>(x, source, line, dim, mask,
74 NonComplexProductAccumulator<CppTypeFor<TypeCategory::Integer, 4>>{x},
75 "PRODUCT");
76}
77CppTypeFor<TypeCategory::Integer, 2> RTDEF(ProductInteger2)(const Descriptor &x,
78 const char *source, int line, int dim, const Descriptor *mask) {
79 return GetTotalReduction<TypeCategory::Integer, 2>(x, source, line, dim, mask,
80 NonComplexProductAccumulator<CppTypeFor<TypeCategory::Integer, 4>>{x},
81 "PRODUCT");
82}
83CppTypeFor<TypeCategory::Integer, 4> RTDEF(ProductInteger4)(const Descriptor &x,
84 const char *source, int line, int dim, const Descriptor *mask) {
85 return GetTotalReduction<TypeCategory::Integer, 4>(x, source, line, dim, mask,
86 NonComplexProductAccumulator<CppTypeFor<TypeCategory::Integer, 4>>{x},
87 "PRODUCT");
88}
89CppTypeFor<TypeCategory::Integer, 8> RTDEF(ProductInteger8)(const Descriptor &x,
90 const char *source, int line, int dim, const Descriptor *mask) {
91 return GetTotalReduction<TypeCategory::Integer, 8>(x, source, line, dim, mask,
92 NonComplexProductAccumulator<CppTypeFor<TypeCategory::Integer, 8>>{x},
93 "PRODUCT");
94}
95#ifdef __SIZEOF_INT128__
96CppTypeFor<TypeCategory::Integer, 16> RTDEF(ProductInteger16)(
97 const Descriptor &x, const char *source, int line, int dim,
98 const Descriptor *mask) {
99 return GetTotalReduction<TypeCategory::Integer, 16>(x, source, line, dim,
100 mask,
101 NonComplexProductAccumulator<CppTypeFor<TypeCategory::Integer, 16>>{x},
102 "PRODUCT");
103}
104#endif
105
106// TODO: real/complex(2 & 3)
107CppTypeFor<TypeCategory::Real, 4> RTDEF(ProductReal4)(const Descriptor &x,
108 const char *source, int line, int dim, const Descriptor *mask) {
109 return GetTotalReduction<TypeCategory::Real, 4>(x, source, line, dim, mask,
110 NonComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 8>>{x},
111 "PRODUCT");
112}
113CppTypeFor<TypeCategory::Real, 8> RTDEF(ProductReal8)(const Descriptor &x,
114 const char *source, int line, int dim, const Descriptor *mask) {
115 return GetTotalReduction<TypeCategory::Real, 8>(x, source, line, dim, mask,
116 NonComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 8>>{x},
117 "PRODUCT");
118}
#if LDBL_MANT_DIG == 64
// PRODUCT of a REAL(10) array; only built when long double has a 64-bit
// mantissa.
CppTypeFor<TypeCategory::Real, 10> RTDEF(ProductReal10)(const Descriptor &x,
    const char *source, int line, int dim, const Descriptor *mask) {
  using Accumulator =
      NonComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 10>>;
  return GetTotalReduction<TypeCategory::Real, 10>(
      x, source, line, dim, mask, Accumulator{x}, "PRODUCT");
}
#endif
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
// PRODUCT of a REAL(16) array; only built when a 113-bit-mantissa long double
// or a float128 type is available.
CppTypeFor<TypeCategory::Real, 16> RTDEF(ProductReal16)(const Descriptor &x,
    const char *source, int line, int dim, const Descriptor *mask) {
  using Accumulator =
      NonComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 16>>;
  return GetTotalReduction<TypeCategory::Real, 16>(
      x, source, line, dim, mask, Accumulator{x}, "PRODUCT");
}
#endif
135
136void RTDEF(CppProductComplex4)(CppTypeFor<TypeCategory::Complex, 4> &result,
137 const Descriptor &x, const char *source, int line, int dim,
138 const Descriptor *mask) {
139 result = GetTotalReduction<TypeCategory::Complex, 4>(x, source, line, dim,
140 mask, ComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 8>>{x},
141 "PRODUCT");
142}
143void RTDEF(CppProductComplex8)(CppTypeFor<TypeCategory::Complex, 8> &result,
144 const Descriptor &x, const char *source, int line, int dim,
145 const Descriptor *mask) {
146 result = GetTotalReduction<TypeCategory::Complex, 8>(x, source, line, dim,
147 mask, ComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 8>>{x},
148 "PRODUCT");
149}
#if LDBL_MANT_DIG == 64
// PRODUCT of a COMPLEX(10) array, written to `result`; only built when long
// double has a 64-bit mantissa.
void RTDEF(CppProductComplex10)(CppTypeFor<TypeCategory::Complex, 10> &result,
    const Descriptor &x, const char *source, int line, int dim,
    const Descriptor *mask) {
  using Accumulator =
      ComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 10>>;
  result = GetTotalReduction<TypeCategory::Complex, 10>(
      x, source, line, dim, mask, Accumulator{x}, "PRODUCT");
}
#endif
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
// PRODUCT of a COMPLEX(16) array, written to `result`; only built when a
// 113-bit-mantissa long double or a float128 type is available.
void RTDEF(CppProductComplex16)(CppTypeFor<TypeCategory::Complex, 16> &result,
    const Descriptor &x, const char *source, int line, int dim,
    const Descriptor *mask) {
  using Accumulator =
      ComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 16>>;
  result = GetTotalReduction<TypeCategory::Complex, 16>(
      x, source, line, dim, mask, Accumulator{x}, "PRODUCT");
}
#endif
168
169void RTDEF(ProductDim)(Descriptor &result, const Descriptor &x, int dim,
170 const char *source, int line, const Descriptor *mask) {
171 TypedPartialNumericReduction<NonComplexProductAccumulator,
172 NonComplexProductAccumulator, ComplexProductAccumulator>(
173 result, x, dim, source, line, mask, "PRODUCT");
174}
175
176RT_EXT_API_GROUP_END
177} // extern "C"
178} // namespace Fortran::runtime
179

source code of flang/runtime/product.cpp