1 | //===-- lib/runtime/dot-product.cpp -----------------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "float.h" |
10 | #include "flang-rt/runtime/descriptor.h" |
11 | #include "flang-rt/runtime/terminator.h" |
12 | #include "flang-rt/runtime/tools.h" |
13 | #include "flang/Common/float128.h" |
14 | #include "flang/Runtime/cpp-type.h" |
15 | #include "flang/Runtime/reduction.h" |
16 | #include <cfloat> |
17 | #include <cinttypes> |
18 | |
19 | namespace Fortran::runtime { |
20 | |
21 | // Beware: DOT_PRODUCT of COMPLEX data uses the complex conjugate of the first |
22 | // argument; MATMUL does not. |
23 | |
24 | // General accumulator for any type and stride; this is not used for |
25 | // contiguous numeric vectors. |
26 | template <TypeCategory RCAT, int RKIND, typename XT, typename YT> |
27 | class Accumulator { |
28 | public: |
29 | using Result = AccumulationType<RCAT, RKIND>; |
30 | RT_API_ATTRS Accumulator(const Descriptor &x, const Descriptor &y) |
31 | : x_{x}, y_{y} {} |
32 | RT_API_ATTRS void AccumulateIndexed(SubscriptValue xAt, SubscriptValue yAt) { |
33 | if constexpr (RCAT == TypeCategory::Logical) { |
34 | sum_ = sum_ || |
35 | (IsLogicalElementTrue(x_, &xAt) && IsLogicalElementTrue(y_, &yAt)); |
36 | } else { |
37 | const XT &xElement{*x_.Element<XT>(&xAt)}; |
38 | const YT &yElement{*y_.Element<YT>(&yAt)}; |
39 | if constexpr (RCAT == TypeCategory::Complex) { |
40 | sum_ += rtcmplx::conj(static_cast<Result>(xElement)) * |
41 | static_cast<Result>(yElement); |
42 | } else { |
43 | sum_ += static_cast<Result>(xElement) * static_cast<Result>(yElement); |
44 | } |
45 | } |
46 | } |
47 | RT_API_ATTRS Result GetResult() const { return sum_; } |
48 | |
49 | private: |
50 | const Descriptor &x_, &y_; |
51 | Result sum_{}; |
52 | }; |
53 | |
54 | template <TypeCategory RCAT, int RKIND, typename XT, typename YT> |
55 | static inline RT_API_ATTRS CppTypeFor<RCAT, RKIND> DoDotProduct( |
56 | const Descriptor &x, const Descriptor &y, Terminator &terminator) { |
57 | using Result = CppTypeFor<RCAT, RKIND>; |
58 | RUNTIME_CHECK(terminator, x.rank() == 1 && y.rank() == 1); |
59 | SubscriptValue n{x.GetDimension(0).Extent()}; |
60 | if (SubscriptValue yN{y.GetDimension(0).Extent()}; yN != n) { |
61 | terminator.Crash( |
62 | "DOT_PRODUCT: SIZE(VECTOR_A) is %jd but SIZE(VECTOR_B) is %jd" , |
63 | static_cast<std::intmax_t>(n), static_cast<std::intmax_t>(yN)); |
64 | } |
65 | if constexpr (RCAT != TypeCategory::Logical) { |
66 | if (x.GetDimension(0).ByteStride() == sizeof(XT) && |
67 | y.GetDimension(0).ByteStride() == sizeof(YT)) { |
68 | // Contiguous numeric vectors |
69 | if constexpr (std::is_same_v<XT, YT>) { |
70 | // Contiguous homogeneous numeric vectors |
71 | if constexpr (std::is_same_v<XT, float>) { |
72 | // TODO: call BLAS-1 SDOT or SDSDOT |
73 | } else if constexpr (std::is_same_v<XT, double>) { |
74 | // TODO: call BLAS-1 DDOT |
75 | } else if constexpr (std::is_same_v<XT, rtcmplx::complex<float>>) { |
76 | // TODO: call BLAS-1 CDOTC |
77 | } else if constexpr (std::is_same_v<XT, rtcmplx::complex<double>>) { |
78 | // TODO: call BLAS-1 ZDOTC |
79 | } |
80 | } |
81 | XT *xp{x.OffsetElement<XT>(0)}; |
82 | YT *yp{y.OffsetElement<YT>(0)}; |
83 | using AccumType = AccumulationType<RCAT, RKIND>; |
84 | AccumType accum{}; |
85 | if constexpr (RCAT == TypeCategory::Complex) { |
86 | for (SubscriptValue j{0}; j < n; ++j) { |
87 | // conj() may instantiate its argument twice, |
88 | // so xp has to be incremented separately. |
89 | // This is a workaround for an alleged bug in clang, |
90 | // that shows up as: |
91 | // warning: multiple unsequenced modifications to 'xp' |
92 | accum += rtcmplx::conj(static_cast<AccumType>(*xp)) * |
93 | static_cast<AccumType>(*yp++); |
94 | xp++; |
95 | } |
96 | } else { |
97 | for (SubscriptValue j{0}; j < n; ++j) { |
98 | accum += |
99 | static_cast<AccumType>(*xp++) * static_cast<AccumType>(*yp++); |
100 | } |
101 | } |
102 | return static_cast<Result>(accum); |
103 | } |
104 | } |
105 | // Non-contiguous, heterogeneous, & LOGICAL cases |
106 | SubscriptValue xAt{x.GetDimension(0).LowerBound()}; |
107 | SubscriptValue yAt{y.GetDimension(0).LowerBound()}; |
108 | Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y}; |
109 | for (SubscriptValue j{0}; j < n; ++j) { |
110 | accumulator.AccumulateIndexed(xAt++, yAt++); |
111 | } |
112 | return static_cast<Result>(accumulator.GetResult()); |
113 | } |
114 | |
115 | template <TypeCategory RCAT, int RKIND> struct DotProduct { |
116 | using Result = CppTypeFor<RCAT, RKIND>; |
117 | template <TypeCategory XCAT, int XKIND> struct DP1 { |
118 | template <TypeCategory YCAT, int YKIND> struct DP2 { |
119 | RT_API_ATTRS Result operator()(const Descriptor &x, const Descriptor &y, |
120 | Terminator &terminator) const { |
121 | if constexpr (constexpr auto resultType{ |
122 | GetResultType(XCAT, XKIND, YCAT, YKIND)}) { |
123 | if constexpr (resultType->first == RCAT && |
124 | (resultType->second <= RKIND || RCAT == TypeCategory::Logical)) { |
125 | return DoDotProduct<RCAT, RKIND, CppTypeFor<XCAT, XKIND>, |
126 | CppTypeFor<YCAT, YKIND>>(x, y, terminator); |
127 | } |
128 | } |
129 | terminator.Crash( |
130 | "DOT_PRODUCT(%d(%d)): bad operand types (%d(%d), %d(%d))" , |
131 | static_cast<int>(RCAT), RKIND, static_cast<int>(XCAT), XKIND, |
132 | static_cast<int>(YCAT), YKIND); |
133 | } |
134 | }; |
135 | RT_API_ATTRS Result operator()(const Descriptor &x, const Descriptor &y, |
136 | Terminator &terminator, TypeCategory yCat, int yKind) const { |
137 | return ApplyType<DP2, Result>(yCat, yKind, terminator, x, y, terminator); |
138 | } |
139 | }; |
140 | RT_API_ATTRS Result operator()(const Descriptor &x, const Descriptor &y, |
141 | const char *source, int line) const { |
142 | Terminator terminator{source, line}; |
143 | if (RCAT != TypeCategory::Logical && x.type() == y.type()) { |
144 | // No conversions needed, operands and result have same known type |
145 | return typename DP1<RCAT, RKIND>::template DP2<RCAT, RKIND>{}( |
146 | x, y, terminator); |
147 | } else { |
148 | auto xCatKind{x.type().GetCategoryAndKind()}; |
149 | auto yCatKind{y.type().GetCategoryAndKind()}; |
150 | RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value()); |
151 | return ApplyType<DP1, Result>(xCatKind->first, xCatKind->second, |
152 | terminator, x, y, terminator, yCatKind->first, yCatKind->second); |
153 | } |
154 | } |
155 | }; |
156 | |
157 | extern "C" { |
158 | RT_EXT_API_GROUP_BEGIN |
159 | |
160 | CppTypeFor<TypeCategory::Integer, 1> RTDEF(DotProductInteger1)( |
161 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
162 | return DotProduct<TypeCategory::Integer, 1>{}(x, y, source, line); |
163 | } |
164 | CppTypeFor<TypeCategory::Integer, 2> RTDEF(DotProductInteger2)( |
165 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
166 | return DotProduct<TypeCategory::Integer, 2>{}(x, y, source, line); |
167 | } |
168 | CppTypeFor<TypeCategory::Integer, 4> RTDEF(DotProductInteger4)( |
169 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
170 | return DotProduct<TypeCategory::Integer, 4>{}(x, y, source, line); |
171 | } |
172 | CppTypeFor<TypeCategory::Integer, 8> RTDEF(DotProductInteger8)( |
173 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
174 | return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line); |
175 | } |
176 | #ifdef __SIZEOF_INT128__ |
177 | CppTypeFor<TypeCategory::Integer, 16> RTDEF(DotProductInteger16)( |
178 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
179 | return DotProduct<TypeCategory::Integer, 16>{}(x, y, source, line); |
180 | } |
181 | #endif |
182 | |
183 | CppTypeFor<TypeCategory::Unsigned, 1> RTDEF(DotProductUnsigned1)( |
184 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
185 | return DotProduct<TypeCategory::Unsigned, 1>{}(x, y, source, line); |
186 | } |
187 | CppTypeFor<TypeCategory::Unsigned, 2> RTDEF(DotProductUnsigned2)( |
188 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
189 | return DotProduct<TypeCategory::Unsigned, 2>{}(x, y, source, line); |
190 | } |
191 | CppTypeFor<TypeCategory::Unsigned, 4> RTDEF(DotProductUnsigned4)( |
192 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
193 | return DotProduct<TypeCategory::Unsigned, 4>{}(x, y, source, line); |
194 | } |
195 | CppTypeFor<TypeCategory::Unsigned, 8> RTDEF(DotProductUnsigned8)( |
196 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
197 | return DotProduct<TypeCategory::Unsigned, 8>{}(x, y, source, line); |
198 | } |
199 | #ifdef __SIZEOF_INT128__ |
200 | CppTypeFor<TypeCategory::Unsigned, 16> RTDEF(DotProductUnsigned16)( |
201 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
202 | return DotProduct<TypeCategory::Unsigned, 16>{}(x, y, source, line); |
203 | } |
204 | #endif |
205 | |
206 | // TODO: REAL/COMPLEX(2 & 3) |
207 | // Intermediate results and operations are at least 64 bits |
208 | CppTypeFor<TypeCategory::Real, 4> RTDEF(DotProductReal4)( |
209 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
210 | return DotProduct<TypeCategory::Real, 4>{}(x, y, source, line); |
211 | } |
212 | CppTypeFor<TypeCategory::Real, 8> RTDEF(DotProductReal8)( |
213 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
214 | return DotProduct<TypeCategory::Real, 8>{}(x, y, source, line); |
215 | } |
216 | #if HAS_FLOAT80 |
217 | CppTypeFor<TypeCategory::Real, 10> RTDEF(DotProductReal10)( |
218 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
219 | return DotProduct<TypeCategory::Real, 10>{}(x, y, source, line); |
220 | } |
221 | #endif |
222 | #if HAS_LDBL128 || HAS_FLOAT128 |
223 | CppTypeFor<TypeCategory::Real, 16> RTDEF(DotProductReal16)( |
224 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
225 | return DotProduct<TypeCategory::Real, 16>{}(x, y, source, line); |
226 | } |
227 | #endif |
228 | |
229 | void RTDEF(CppDotProductComplex4)(CppTypeFor<TypeCategory::Complex, 4> &result, |
230 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
231 | result = DotProduct<TypeCategory::Complex, 4>{}(x, y, source, line); |
232 | } |
233 | void RTDEF(CppDotProductComplex8)(CppTypeFor<TypeCategory::Complex, 8> &result, |
234 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
235 | result = DotProduct<TypeCategory::Complex, 8>{}(x, y, source, line); |
236 | } |
237 | #if HAS_FLOAT80 |
238 | void RTDEF(CppDotProductComplex10)( |
239 | CppTypeFor<TypeCategory::Complex, 10> &result, const Descriptor &x, |
240 | const Descriptor &y, const char *source, int line) { |
241 | result = DotProduct<TypeCategory::Complex, 10>{}(x, y, source, line); |
242 | } |
243 | #endif |
244 | #if HAS_LDBL128 || HAS_FLOAT128 |
245 | void RTDEF(CppDotProductComplex16)( |
246 | CppTypeFor<TypeCategory::Complex, 16> &result, const Descriptor &x, |
247 | const Descriptor &y, const char *source, int line) { |
248 | result = DotProduct<TypeCategory::Complex, 16>{}(x, y, source, line); |
249 | } |
250 | #endif |
251 | |
252 | bool RTDEF(DotProductLogical)( |
253 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
254 | return DotProduct<TypeCategory::Logical, 1>{}(x, y, source, line); |
255 | } |
256 | |
257 | RT_EXT_API_GROUP_END |
258 | } // extern "C" |
259 | } // namespace Fortran::runtime |
260 | |