1//===-- runtime/dot-product.cpp -------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "float.h"
10#include "terminator.h"
11#include "tools.h"
12#include "flang/Common/float128.h"
13#include "flang/Runtime/cpp-type.h"
14#include "flang/Runtime/descriptor.h"
15#include "flang/Runtime/reduction.h"
16#include <cfloat>
17#include <cinttypes>
18
19namespace Fortran::runtime {
20
21// Beware: DOT_PRODUCT of COMPLEX data uses the complex conjugate of the first
22// argument; MATMUL does not.
23
// Suppress warnings about calling __host__-only std::complex operators, which
// are defined in the C++ standard library headers, from __device__ code.
26RT_DIAG_PUSH
27RT_DIAG_DISABLE_CALL_HOST_FROM_DEVICE_WARN
28
29// General accumulator for any type and stride; this is not used for
30// contiguous numeric vectors.
31template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
32class Accumulator {
33public:
34 using Result = AccumulationType<RCAT, RKIND>;
35 RT_API_ATTRS Accumulator(const Descriptor &x, const Descriptor &y)
36 : x_{x}, y_{y} {}
37 RT_API_ATTRS void AccumulateIndexed(SubscriptValue xAt, SubscriptValue yAt) {
38 if constexpr (RCAT == TypeCategory::Logical) {
39 sum_ = sum_ ||
40 (IsLogicalElementTrue(x_, &xAt) && IsLogicalElementTrue(y_, &yAt));
41 } else {
42 const XT &xElement{*x_.Element<XT>(&xAt)};
43 const YT &yElement{*y_.Element<YT>(&yAt)};
44 if constexpr (RCAT == TypeCategory::Complex) {
45 sum_ += std::conj(static_cast<Result>(xElement)) *
46 static_cast<Result>(yElement);
47 } else {
48 sum_ += static_cast<Result>(xElement) * static_cast<Result>(yElement);
49 }
50 }
51 }
52 RT_API_ATTRS Result GetResult() const { return sum_; }
53
54private:
55 const Descriptor &x_, &y_;
56 Result sum_{};
57};
58
59template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
60static inline RT_API_ATTRS CppTypeFor<RCAT, RKIND> DoDotProduct(
61 const Descriptor &x, const Descriptor &y, Terminator &terminator) {
62 using Result = CppTypeFor<RCAT, RKIND>;
63 RUNTIME_CHECK(terminator, x.rank() == 1 && y.rank() == 1);
64 SubscriptValue n{x.GetDimension(0).Extent()};
65 if (SubscriptValue yN{y.GetDimension(0).Extent()}; yN != n) {
66 terminator.Crash(
67 "DOT_PRODUCT: SIZE(VECTOR_A) is %jd but SIZE(VECTOR_B) is %jd",
68 static_cast<std::intmax_t>(n), static_cast<std::intmax_t>(yN));
69 }
70 if constexpr (RCAT != TypeCategory::Logical) {
71 if (x.GetDimension(0).ByteStride() == sizeof(XT) &&
72 y.GetDimension(0).ByteStride() == sizeof(YT)) {
73 // Contiguous numeric vectors
74 if constexpr (std::is_same_v<XT, YT>) {
75 // Contiguous homogeneous numeric vectors
76 if constexpr (std::is_same_v<XT, float>) {
77 // TODO: call BLAS-1 SDOT or SDSDOT
78 } else if constexpr (std::is_same_v<XT, double>) {
79 // TODO: call BLAS-1 DDOT
80 } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
81 // TODO: call BLAS-1 CDOTC
82 } else if constexpr (std::is_same_v<XT, std::complex<double>>) {
83 // TODO: call BLAS-1 ZDOTC
84 }
85 }
86 XT *xp{x.OffsetElement<XT>(0)};
87 YT *yp{y.OffsetElement<YT>(0)};
88 using AccumType = AccumulationType<RCAT, RKIND>;
89 AccumType accum{};
90 if constexpr (RCAT == TypeCategory::Complex) {
91 for (SubscriptValue j{0}; j < n; ++j) {
92 // std::conj() may instantiate its argument twice,
93 // so xp has to be incremented separately.
94 // This is a workaround for an alleged bug in clang,
95 // that shows up as:
96 // warning: multiple unsequenced modifications to 'xp'
97 accum += std::conj(static_cast<AccumType>(*xp)) *
98 static_cast<AccumType>(*yp++);
99 xp++;
100 }
101 } else {
102 for (SubscriptValue j{0}; j < n; ++j) {
103 accum +=
104 static_cast<AccumType>(*xp++) * static_cast<AccumType>(*yp++);
105 }
106 }
107 return static_cast<Result>(accum);
108 }
109 }
110 // Non-contiguous, heterogeneous, & LOGICAL cases
111 SubscriptValue xAt{x.GetDimension(0).LowerBound()};
112 SubscriptValue yAt{y.GetDimension(0).LowerBound()};
113 Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
114 for (SubscriptValue j{0}; j < n; ++j) {
115 accumulator.AccumulateIndexed(xAt++, yAt++);
116 }
117 return static_cast<Result>(accumulator.GetResult());
118}
119
120RT_DIAG_POP
121
122template <TypeCategory RCAT, int RKIND> struct DotProduct {
123 using Result = CppTypeFor<RCAT, RKIND>;
124 template <TypeCategory XCAT, int XKIND> struct DP1 {
125 template <TypeCategory YCAT, int YKIND> struct DP2 {
126 RT_API_ATTRS Result operator()(const Descriptor &x, const Descriptor &y,
127 Terminator &terminator) const {
128 if constexpr (constexpr auto resultType{
129 GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
130 if constexpr (resultType->first == RCAT &&
131 (resultType->second <= RKIND || RCAT == TypeCategory::Logical)) {
132 return DoDotProduct<RCAT, RKIND, CppTypeFor<XCAT, XKIND>,
133 CppTypeFor<YCAT, YKIND>>(x, y, terminator);
134 }
135 }
136 terminator.Crash(
137 "DOT_PRODUCT(%d(%d)): bad operand types (%d(%d), %d(%d))",
138 static_cast<int>(RCAT), RKIND, static_cast<int>(XCAT), XKIND,
139 static_cast<int>(YCAT), YKIND);
140 }
141 };
142 RT_API_ATTRS Result operator()(const Descriptor &x, const Descriptor &y,
143 Terminator &terminator, TypeCategory yCat, int yKind) const {
144 return ApplyType<DP2, Result>(yCat, yKind, terminator, x, y, terminator);
145 }
146 };
147 RT_API_ATTRS Result operator()(const Descriptor &x, const Descriptor &y,
148 const char *source, int line) const {
149 Terminator terminator{source, line};
150 if (RCAT != TypeCategory::Logical && x.type() == y.type()) {
151 // No conversions needed, operands and result have same known type
152 return typename DP1<RCAT, RKIND>::template DP2<RCAT, RKIND>{}(
153 x, y, terminator);
154 } else {
155 auto xCatKind{x.type().GetCategoryAndKind()};
156 auto yCatKind{y.type().GetCategoryAndKind()};
157 RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
158 return ApplyType<DP1, Result>(xCatKind->first, xCatKind->second,
159 terminator, x, y, terminator, yCatKind->first, yCatKind->second);
160 }
161 }
162};
163
164extern "C" {
165RT_EXT_API_GROUP_BEGIN
166
167CppTypeFor<TypeCategory::Integer, 1> RTDEF(DotProductInteger1)(
168 const Descriptor &x, const Descriptor &y, const char *source, int line) {
169 return DotProduct<TypeCategory::Integer, 1>{}(x, y, source, line);
170}
171CppTypeFor<TypeCategory::Integer, 2> RTDEF(DotProductInteger2)(
172 const Descriptor &x, const Descriptor &y, const char *source, int line) {
173 return DotProduct<TypeCategory::Integer, 2>{}(x, y, source, line);
174}
175CppTypeFor<TypeCategory::Integer, 4> RTDEF(DotProductInteger4)(
176 const Descriptor &x, const Descriptor &y, const char *source, int line) {
177 return DotProduct<TypeCategory::Integer, 4>{}(x, y, source, line);
178}
179CppTypeFor<TypeCategory::Integer, 8> RTDEF(DotProductInteger8)(
180 const Descriptor &x, const Descriptor &y, const char *source, int line) {
181 return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line);
182}
183#ifdef __SIZEOF_INT128__
184CppTypeFor<TypeCategory::Integer, 16> RTDEF(DotProductInteger16)(
185 const Descriptor &x, const Descriptor &y, const char *source, int line) {
186 return DotProduct<TypeCategory::Integer, 16>{}(x, y, source, line);
187}
188#endif
189
// TODO: REAL/COMPLEX kinds 2 and 3.
// Intermediate results and operations use at least 64 bits of precision.
192CppTypeFor<TypeCategory::Real, 4> RTDEF(DotProductReal4)(
193 const Descriptor &x, const Descriptor &y, const char *source, int line) {
194 return DotProduct<TypeCategory::Real, 4>{}(x, y, source, line);
195}
196CppTypeFor<TypeCategory::Real, 8> RTDEF(DotProductReal8)(
197 const Descriptor &x, const Descriptor &y, const char *source, int line) {
198 return DotProduct<TypeCategory::Real, 8>{}(x, y, source, line);
199}
200#if LDBL_MANT_DIG == 64
201CppTypeFor<TypeCategory::Real, 10> RTDEF(DotProductReal10)(
202 const Descriptor &x, const Descriptor &y, const char *source, int line) {
203 return DotProduct<TypeCategory::Real, 10>{}(x, y, source, line);
204}
205#endif
206#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
207CppTypeFor<TypeCategory::Real, 16> RTDEF(DotProductReal16)(
208 const Descriptor &x, const Descriptor &y, const char *source, int line) {
209 return DotProduct<TypeCategory::Real, 16>{}(x, y, source, line);
210}
211#endif
212
213void RTDEF(CppDotProductComplex4)(CppTypeFor<TypeCategory::Complex, 4> &result,
214 const Descriptor &x, const Descriptor &y, const char *source, int line) {
215 result = DotProduct<TypeCategory::Complex, 4>{}(x, y, source, line);
216}
217void RTDEF(CppDotProductComplex8)(CppTypeFor<TypeCategory::Complex, 8> &result,
218 const Descriptor &x, const Descriptor &y, const char *source, int line) {
219 result = DotProduct<TypeCategory::Complex, 8>{}(x, y, source, line);
220}
221#if LDBL_MANT_DIG == 64
222void RTDEF(CppDotProductComplex10)(
223 CppTypeFor<TypeCategory::Complex, 10> &result, const Descriptor &x,
224 const Descriptor &y, const char *source, int line) {
225 result = DotProduct<TypeCategory::Complex, 10>{}(x, y, source, line);
226}
227#endif
228#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
229void RTDEF(CppDotProductComplex16)(
230 CppTypeFor<TypeCategory::Complex, 16> &result, const Descriptor &x,
231 const Descriptor &y, const char *source, int line) {
232 result = DotProduct<TypeCategory::Complex, 16>{}(x, y, source, line);
233}
234#endif
235
236bool RTDEF(DotProductLogical)(
237 const Descriptor &x, const Descriptor &y, const char *source, int line) {
238 return DotProduct<TypeCategory::Logical, 1>{}(x, y, source, line);
239}
240
241RT_EXT_API_GROUP_END
242} // extern "C"
243} // namespace Fortran::runtime
244

source code of flang/runtime/dot-product.cpp