1 | //===-- runtime/product.cpp -----------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | // Implements PRODUCT for all required operand types and shapes. |
10 | |
#include "reduction-templates.h"
#include "flang/Common/float128.h"
#include "flang/Runtime/reduction.h"
#include <cfloat>
#include <cinttypes>
#include <complex>
#include <limits>
17 | |
18 | namespace Fortran::runtime { |
19 | template <typename INTERMEDIATE> class NonComplexProductAccumulator { |
20 | public: |
21 | explicit RT_API_ATTRS NonComplexProductAccumulator(const Descriptor &array) |
22 | : array_{array} {} |
23 | RT_API_ATTRS void Reinitialize() { product_ = 1; } |
24 | template <typename A> |
25 | RT_API_ATTRS void GetResult(A *p, int /*zeroBasedDim*/ = -1) const { |
26 | *p = static_cast<A>(product_); |
27 | } |
28 | template <typename A> |
29 | RT_API_ATTRS bool AccumulateAt(const SubscriptValue at[]) { |
30 | product_ *= *array_.Element<A>(at); |
31 | return product_ != 0; |
32 | } |
33 | |
34 | private: |
35 | const Descriptor &array_; |
36 | INTERMEDIATE product_{1}; |
37 | }; |
38 | |
39 | // Suppress the warnings about calling __host__-only std::complex operators, |
40 | // defined in C++ STD header files, from __device__ code. |
41 | RT_DIAG_PUSH |
42 | RT_DIAG_DISABLE_CALL_HOST_FROM_DEVICE_WARN |
43 | |
44 | template <typename PART> class ComplexProductAccumulator { |
45 | public: |
46 | explicit RT_API_ATTRS ComplexProductAccumulator(const Descriptor &array) |
47 | : array_{array} {} |
48 | RT_API_ATTRS void Reinitialize() { product_ = std::complex<PART>{1, 0}; } |
49 | template <typename A> |
50 | RT_API_ATTRS void GetResult(A *p, int /*zeroBasedDim*/ = -1) const { |
51 | using ResultPart = typename A::value_type; |
52 | *p = {static_cast<ResultPart>(product_.real()), |
53 | static_cast<ResultPart>(product_.imag())}; |
54 | } |
55 | template <typename A> |
56 | RT_API_ATTRS bool AccumulateAt(const SubscriptValue at[]) { |
57 | product_ *= *array_.Element<A>(at); |
58 | return true; |
59 | } |
60 | |
61 | private: |
62 | const Descriptor &array_; |
63 | std::complex<PART> product_{1, 0}; |
64 | }; |
65 | |
66 | RT_DIAG_POP |
67 | |
68 | extern "C" { |
69 | RT_EXT_API_GROUP_BEGIN |
70 | |
71 | CppTypeFor<TypeCategory::Integer, 1> RTDEF(ProductInteger1)(const Descriptor &x, |
72 | const char *source, int line, int dim, const Descriptor *mask) { |
73 | return GetTotalReduction<TypeCategory::Integer, 1>(x, source, line, dim, mask, |
74 | NonComplexProductAccumulator<CppTypeFor<TypeCategory::Integer, 4>>{x}, |
75 | "PRODUCT" ); |
76 | } |
77 | CppTypeFor<TypeCategory::Integer, 2> RTDEF(ProductInteger2)(const Descriptor &x, |
78 | const char *source, int line, int dim, const Descriptor *mask) { |
79 | return GetTotalReduction<TypeCategory::Integer, 2>(x, source, line, dim, mask, |
80 | NonComplexProductAccumulator<CppTypeFor<TypeCategory::Integer, 4>>{x}, |
81 | "PRODUCT" ); |
82 | } |
83 | CppTypeFor<TypeCategory::Integer, 4> RTDEF(ProductInteger4)(const Descriptor &x, |
84 | const char *source, int line, int dim, const Descriptor *mask) { |
85 | return GetTotalReduction<TypeCategory::Integer, 4>(x, source, line, dim, mask, |
86 | NonComplexProductAccumulator<CppTypeFor<TypeCategory::Integer, 4>>{x}, |
87 | "PRODUCT" ); |
88 | } |
89 | CppTypeFor<TypeCategory::Integer, 8> RTDEF(ProductInteger8)(const Descriptor &x, |
90 | const char *source, int line, int dim, const Descriptor *mask) { |
91 | return GetTotalReduction<TypeCategory::Integer, 8>(x, source, line, dim, mask, |
92 | NonComplexProductAccumulator<CppTypeFor<TypeCategory::Integer, 8>>{x}, |
93 | "PRODUCT" ); |
94 | } |
95 | #ifdef __SIZEOF_INT128__ |
96 | CppTypeFor<TypeCategory::Integer, 16> RTDEF(ProductInteger16)( |
97 | const Descriptor &x, const char *source, int line, int dim, |
98 | const Descriptor *mask) { |
99 | return GetTotalReduction<TypeCategory::Integer, 16>(x, source, line, dim, |
100 | mask, |
101 | NonComplexProductAccumulator<CppTypeFor<TypeCategory::Integer, 16>>{x}, |
102 | "PRODUCT" ); |
103 | } |
104 | #endif |
105 | |
// TODO: PRODUCT entry points for REAL and COMPLEX kinds 2 and 3 are not yet
// implemented.
107 | CppTypeFor<TypeCategory::Real, 4> RTDEF(ProductReal4)(const Descriptor &x, |
108 | const char *source, int line, int dim, const Descriptor *mask) { |
109 | return GetTotalReduction<TypeCategory::Real, 4>(x, source, line, dim, mask, |
110 | NonComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 8>>{x}, |
111 | "PRODUCT" ); |
112 | } |
113 | CppTypeFor<TypeCategory::Real, 8> RTDEF(ProductReal8)(const Descriptor &x, |
114 | const char *source, int line, int dim, const Descriptor *mask) { |
115 | return GetTotalReduction<TypeCategory::Real, 8>(x, source, line, dim, mask, |
116 | NonComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 8>>{x}, |
117 | "PRODUCT" ); |
118 | } |
#if LDBL_MANT_DIG == 64
// PRODUCT of a REAL(10) array (whole-array form), accumulated in REAL(10);
// only built when long double has a 64-bit mantissa.
CppTypeFor<TypeCategory::Real, 10> RTDEF(ProductReal10)(const Descriptor &x,
    const char *source, int line, int dim, const Descriptor *mask) {
  using Accumulator =
      NonComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 10>>;
  return GetTotalReduction<TypeCategory::Real, 10>(
      x, source, line, dim, mask, Accumulator{x}, "PRODUCT");
}
#endif
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
// PRODUCT of a REAL(16) array (whole-array form), accumulated in REAL(16);
// only built when a 113-bit-mantissa type is available.
CppTypeFor<TypeCategory::Real, 16> RTDEF(ProductReal16)(const Descriptor &x,
    const char *source, int line, int dim, const Descriptor *mask) {
  using Accumulator =
      NonComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 16>>;
  return GetTotalReduction<TypeCategory::Real, 16>(
      x, source, line, dim, mask, Accumulator{x}, "PRODUCT");
}
#endif
135 | |
136 | void RTDEF(CppProductComplex4)(CppTypeFor<TypeCategory::Complex, 4> &result, |
137 | const Descriptor &x, const char *source, int line, int dim, |
138 | const Descriptor *mask) { |
139 | result = GetTotalReduction<TypeCategory::Complex, 4>(x, source, line, dim, |
140 | mask, ComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 8>>{x}, |
141 | "PRODUCT" ); |
142 | } |
143 | void RTDEF(CppProductComplex8)(CppTypeFor<TypeCategory::Complex, 8> &result, |
144 | const Descriptor &x, const char *source, int line, int dim, |
145 | const Descriptor *mask) { |
146 | result = GetTotalReduction<TypeCategory::Complex, 8>(x, source, line, dim, |
147 | mask, ComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 8>>{x}, |
148 | "PRODUCT" ); |
149 | } |
#if LDBL_MANT_DIG == 64
// PRODUCT of a COMPLEX(10) array (whole-array form), returned through
// `result`; the running product uses REAL(10) parts.
void RTDEF(CppProductComplex10)(CppTypeFor<TypeCategory::Complex, 10> &result,
    const Descriptor &x, const char *source, int line, int dim,
    const Descriptor *mask) {
  using Accumulator =
      ComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 10>>;
  result = GetTotalReduction<TypeCategory::Complex, 10>(
      x, source, line, dim, mask, Accumulator{x}, "PRODUCT");
}
#endif
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
// PRODUCT of a COMPLEX(16) array (whole-array form), returned through
// `result`; the running product uses REAL(16) parts.
void RTDEF(CppProductComplex16)(CppTypeFor<TypeCategory::Complex, 16> &result,
    const Descriptor &x, const char *source, int line, int dim,
    const Descriptor *mask) {
  using Accumulator =
      ComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 16>>;
  result = GetTotalReduction<TypeCategory::Complex, 16>(
      x, source, line, dim, mask, Accumulator{x}, "PRODUCT");
}
#endif
168 | |
169 | void RTDEF(ProductDim)(Descriptor &result, const Descriptor &x, int dim, |
170 | const char *source, int line, const Descriptor *mask) { |
171 | TypedPartialNumericReduction<NonComplexProductAccumulator, |
172 | NonComplexProductAccumulator, ComplexProductAccumulator>( |
173 | result, x, dim, source, line, mask, "PRODUCT" ); |
174 | } |
175 | |
176 | RT_EXT_API_GROUP_END |
177 | } // extern "C" |
178 | } // namespace Fortran::runtime |
179 | |