1 | //===-- runtime/dot-product.cpp -------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "float.h" |
10 | #include "terminator.h" |
11 | #include "tools.h" |
12 | #include "flang/Common/float128.h" |
13 | #include "flang/Runtime/cpp-type.h" |
14 | #include "flang/Runtime/descriptor.h" |
15 | #include "flang/Runtime/reduction.h" |
16 | #include <cfloat> |
17 | #include <cinttypes> |
18 | |
19 | namespace Fortran::runtime { |
20 | |
21 | // Beware: DOT_PRODUCT of COMPLEX data uses the complex conjugate of the first |
22 | // argument; MATMUL does not. |
23 | |
24 | // Suppress the warnings about calling __host__-only std::complex operators, |
25 | // defined in C++ STD header files, from __device__ code. |
26 | RT_DIAG_PUSH |
27 | RT_DIAG_DISABLE_CALL_HOST_FROM_DEVICE_WARN |
28 | |
29 | // General accumulator for any type and stride; this is not used for |
30 | // contiguous numeric vectors. |
31 | template <TypeCategory RCAT, int RKIND, typename XT, typename YT> |
32 | class Accumulator { |
33 | public: |
34 | using Result = AccumulationType<RCAT, RKIND>; |
35 | RT_API_ATTRS Accumulator(const Descriptor &x, const Descriptor &y) |
36 | : x_{x}, y_{y} {} |
37 | RT_API_ATTRS void AccumulateIndexed(SubscriptValue xAt, SubscriptValue yAt) { |
38 | if constexpr (RCAT == TypeCategory::Logical) { |
39 | sum_ = sum_ || |
40 | (IsLogicalElementTrue(x_, &xAt) && IsLogicalElementTrue(y_, &yAt)); |
41 | } else { |
42 | const XT &xElement{*x_.Element<XT>(&xAt)}; |
43 | const YT &yElement{*y_.Element<YT>(&yAt)}; |
44 | if constexpr (RCAT == TypeCategory::Complex) { |
45 | sum_ += std::conj(static_cast<Result>(xElement)) * |
46 | static_cast<Result>(yElement); |
47 | } else { |
48 | sum_ += static_cast<Result>(xElement) * static_cast<Result>(yElement); |
49 | } |
50 | } |
51 | } |
52 | RT_API_ATTRS Result GetResult() const { return sum_; } |
53 | |
54 | private: |
55 | const Descriptor &x_, &y_; |
56 | Result sum_{}; |
57 | }; |
58 | |
59 | template <TypeCategory RCAT, int RKIND, typename XT, typename YT> |
60 | static inline RT_API_ATTRS CppTypeFor<RCAT, RKIND> DoDotProduct( |
61 | const Descriptor &x, const Descriptor &y, Terminator &terminator) { |
62 | using Result = CppTypeFor<RCAT, RKIND>; |
63 | RUNTIME_CHECK(terminator, x.rank() == 1 && y.rank() == 1); |
64 | SubscriptValue n{x.GetDimension(0).Extent()}; |
65 | if (SubscriptValue yN{y.GetDimension(0).Extent()}; yN != n) { |
66 | terminator.Crash( |
67 | "DOT_PRODUCT: SIZE(VECTOR_A) is %jd but SIZE(VECTOR_B) is %jd" , |
68 | static_cast<std::intmax_t>(n), static_cast<std::intmax_t>(yN)); |
69 | } |
70 | if constexpr (RCAT != TypeCategory::Logical) { |
71 | if (x.GetDimension(0).ByteStride() == sizeof(XT) && |
72 | y.GetDimension(0).ByteStride() == sizeof(YT)) { |
73 | // Contiguous numeric vectors |
74 | if constexpr (std::is_same_v<XT, YT>) { |
75 | // Contiguous homogeneous numeric vectors |
76 | if constexpr (std::is_same_v<XT, float>) { |
77 | // TODO: call BLAS-1 SDOT or SDSDOT |
78 | } else if constexpr (std::is_same_v<XT, double>) { |
79 | // TODO: call BLAS-1 DDOT |
80 | } else if constexpr (std::is_same_v<XT, std::complex<float>>) { |
81 | // TODO: call BLAS-1 CDOTC |
82 | } else if constexpr (std::is_same_v<XT, std::complex<double>>) { |
83 | // TODO: call BLAS-1 ZDOTC |
84 | } |
85 | } |
86 | XT *xp{x.OffsetElement<XT>(0)}; |
87 | YT *yp{y.OffsetElement<YT>(0)}; |
88 | using AccumType = AccumulationType<RCAT, RKIND>; |
89 | AccumType accum{}; |
90 | if constexpr (RCAT == TypeCategory::Complex) { |
91 | for (SubscriptValue j{0}; j < n; ++j) { |
92 | // std::conj() may instantiate its argument twice, |
93 | // so xp has to be incremented separately. |
94 | // This is a workaround for an alleged bug in clang, |
95 | // that shows up as: |
96 | // warning: multiple unsequenced modifications to 'xp' |
97 | accum += std::conj(static_cast<AccumType>(*xp)) * |
98 | static_cast<AccumType>(*yp++); |
99 | xp++; |
100 | } |
101 | } else { |
102 | for (SubscriptValue j{0}; j < n; ++j) { |
103 | accum += |
104 | static_cast<AccumType>(*xp++) * static_cast<AccumType>(*yp++); |
105 | } |
106 | } |
107 | return static_cast<Result>(accum); |
108 | } |
109 | } |
110 | // Non-contiguous, heterogeneous, & LOGICAL cases |
111 | SubscriptValue xAt{x.GetDimension(0).LowerBound()}; |
112 | SubscriptValue yAt{y.GetDimension(0).LowerBound()}; |
113 | Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y}; |
114 | for (SubscriptValue j{0}; j < n; ++j) { |
115 | accumulator.AccumulateIndexed(xAt++, yAt++); |
116 | } |
117 | return static_cast<Result>(accumulator.GetResult()); |
118 | } |
119 | |
120 | RT_DIAG_POP |
121 | |
122 | template <TypeCategory RCAT, int RKIND> struct DotProduct { |
123 | using Result = CppTypeFor<RCAT, RKIND>; |
124 | template <TypeCategory XCAT, int XKIND> struct DP1 { |
125 | template <TypeCategory YCAT, int YKIND> struct DP2 { |
126 | RT_API_ATTRS Result operator()(const Descriptor &x, const Descriptor &y, |
127 | Terminator &terminator) const { |
128 | if constexpr (constexpr auto resultType{ |
129 | GetResultType(XCAT, XKIND, YCAT, YKIND)}) { |
130 | if constexpr (resultType->first == RCAT && |
131 | (resultType->second <= RKIND || RCAT == TypeCategory::Logical)) { |
132 | return DoDotProduct<RCAT, RKIND, CppTypeFor<XCAT, XKIND>, |
133 | CppTypeFor<YCAT, YKIND>>(x, y, terminator); |
134 | } |
135 | } |
136 | terminator.Crash( |
137 | "DOT_PRODUCT(%d(%d)): bad operand types (%d(%d), %d(%d))" , |
138 | static_cast<int>(RCAT), RKIND, static_cast<int>(XCAT), XKIND, |
139 | static_cast<int>(YCAT), YKIND); |
140 | } |
141 | }; |
142 | RT_API_ATTRS Result operator()(const Descriptor &x, const Descriptor &y, |
143 | Terminator &terminator, TypeCategory yCat, int yKind) const { |
144 | return ApplyType<DP2, Result>(yCat, yKind, terminator, x, y, terminator); |
145 | } |
146 | }; |
147 | RT_API_ATTRS Result operator()(const Descriptor &x, const Descriptor &y, |
148 | const char *source, int line) const { |
149 | Terminator terminator{source, line}; |
150 | if (RCAT != TypeCategory::Logical && x.type() == y.type()) { |
151 | // No conversions needed, operands and result have same known type |
152 | return typename DP1<RCAT, RKIND>::template DP2<RCAT, RKIND>{}( |
153 | x, y, terminator); |
154 | } else { |
155 | auto xCatKind{x.type().GetCategoryAndKind()}; |
156 | auto yCatKind{y.type().GetCategoryAndKind()}; |
157 | RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value()); |
158 | return ApplyType<DP1, Result>(xCatKind->first, xCatKind->second, |
159 | terminator, x, y, terminator, yCatKind->first, yCatKind->second); |
160 | } |
161 | } |
162 | }; |
163 | |
164 | extern "C" { |
165 | RT_EXT_API_GROUP_BEGIN |
166 | |
167 | CppTypeFor<TypeCategory::Integer, 1> RTDEF(DotProductInteger1)( |
168 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
169 | return DotProduct<TypeCategory::Integer, 1>{}(x, y, source, line); |
170 | } |
171 | CppTypeFor<TypeCategory::Integer, 2> RTDEF(DotProductInteger2)( |
172 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
173 | return DotProduct<TypeCategory::Integer, 2>{}(x, y, source, line); |
174 | } |
175 | CppTypeFor<TypeCategory::Integer, 4> RTDEF(DotProductInteger4)( |
176 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
177 | return DotProduct<TypeCategory::Integer, 4>{}(x, y, source, line); |
178 | } |
179 | CppTypeFor<TypeCategory::Integer, 8> RTDEF(DotProductInteger8)( |
180 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
181 | return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line); |
182 | } |
183 | #ifdef __SIZEOF_INT128__ |
184 | CppTypeFor<TypeCategory::Integer, 16> RTDEF(DotProductInteger16)( |
185 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
186 | return DotProduct<TypeCategory::Integer, 16>{}(x, y, source, line); |
187 | } |
188 | #endif |
189 | |
190 | // TODO: REAL/COMPLEX(2 & 3) |
191 | // Intermediate results and operations are at least 64 bits |
192 | CppTypeFor<TypeCategory::Real, 4> RTDEF(DotProductReal4)( |
193 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
194 | return DotProduct<TypeCategory::Real, 4>{}(x, y, source, line); |
195 | } |
196 | CppTypeFor<TypeCategory::Real, 8> RTDEF(DotProductReal8)( |
197 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
198 | return DotProduct<TypeCategory::Real, 8>{}(x, y, source, line); |
199 | } |
200 | #if LDBL_MANT_DIG == 64 |
201 | CppTypeFor<TypeCategory::Real, 10> RTDEF(DotProductReal10)( |
202 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
203 | return DotProduct<TypeCategory::Real, 10>{}(x, y, source, line); |
204 | } |
205 | #endif |
206 | #if LDBL_MANT_DIG == 113 || HAS_FLOAT128 |
207 | CppTypeFor<TypeCategory::Real, 16> RTDEF(DotProductReal16)( |
208 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
209 | return DotProduct<TypeCategory::Real, 16>{}(x, y, source, line); |
210 | } |
211 | #endif |
212 | |
213 | void RTDEF(CppDotProductComplex4)(CppTypeFor<TypeCategory::Complex, 4> &result, |
214 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
215 | result = DotProduct<TypeCategory::Complex, 4>{}(x, y, source, line); |
216 | } |
217 | void RTDEF(CppDotProductComplex8)(CppTypeFor<TypeCategory::Complex, 8> &result, |
218 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
219 | result = DotProduct<TypeCategory::Complex, 8>{}(x, y, source, line); |
220 | } |
221 | #if LDBL_MANT_DIG == 64 |
222 | void RTDEF(CppDotProductComplex10)( |
223 | CppTypeFor<TypeCategory::Complex, 10> &result, const Descriptor &x, |
224 | const Descriptor &y, const char *source, int line) { |
225 | result = DotProduct<TypeCategory::Complex, 10>{}(x, y, source, line); |
226 | } |
227 | #endif |
228 | #if LDBL_MANT_DIG == 113 || HAS_FLOAT128 |
229 | void RTDEF(CppDotProductComplex16)( |
230 | CppTypeFor<TypeCategory::Complex, 16> &result, const Descriptor &x, |
231 | const Descriptor &y, const char *source, int line) { |
232 | result = DotProduct<TypeCategory::Complex, 16>{}(x, y, source, line); |
233 | } |
234 | #endif |
235 | |
236 | bool RTDEF(DotProductLogical)( |
237 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
238 | return DotProduct<TypeCategory::Logical, 1>{}(x, y, source, line); |
239 | } |
240 | |
241 | RT_EXT_API_GROUP_END |
242 | } // extern "C" |
243 | } // namespace Fortran::runtime |
244 | |