1//===-- lib/runtime/dot-product.cpp -----------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "float.h"
10#include "flang-rt/runtime/descriptor.h"
11#include "flang-rt/runtime/terminator.h"
12#include "flang-rt/runtime/tools.h"
13#include "flang/Common/float128.h"
14#include "flang/Runtime/cpp-type.h"
15#include "flang/Runtime/reduction.h"
16#include <cfloat>
17#include <cinttypes>
18
19namespace Fortran::runtime {
20
21// Beware: DOT_PRODUCT of COMPLEX data uses the complex conjugate of the first
22// argument; MATMUL does not.
23
24// General accumulator for any type and stride; this is not used for
25// contiguous numeric vectors.
26template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
27class Accumulator {
28public:
29 using Result = AccumulationType<RCAT, RKIND>;
30 RT_API_ATTRS Accumulator(const Descriptor &x, const Descriptor &y)
31 : x_{x}, y_{y} {}
32 RT_API_ATTRS void AccumulateIndexed(SubscriptValue xAt, SubscriptValue yAt) {
33 if constexpr (RCAT == TypeCategory::Logical) {
34 sum_ = sum_ ||
35 (IsLogicalElementTrue(x_, &xAt) && IsLogicalElementTrue(y_, &yAt));
36 } else {
37 const XT &xElement{*x_.Element<XT>(&xAt)};
38 const YT &yElement{*y_.Element<YT>(&yAt)};
39 if constexpr (RCAT == TypeCategory::Complex) {
40 sum_ += rtcmplx::conj(static_cast<Result>(xElement)) *
41 static_cast<Result>(yElement);
42 } else {
43 sum_ += static_cast<Result>(xElement) * static_cast<Result>(yElement);
44 }
45 }
46 }
47 RT_API_ATTRS Result GetResult() const { return sum_; }
48
49private:
50 const Descriptor &x_, &y_;
51 Result sum_{};
52};
53
54template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
55static inline RT_API_ATTRS CppTypeFor<RCAT, RKIND> DoDotProduct(
56 const Descriptor &x, const Descriptor &y, Terminator &terminator) {
57 using Result = CppTypeFor<RCAT, RKIND>;
58 RUNTIME_CHECK(terminator, x.rank() == 1 && y.rank() == 1);
59 SubscriptValue n{x.GetDimension(0).Extent()};
60 if (SubscriptValue yN{y.GetDimension(0).Extent()}; yN != n) {
61 terminator.Crash(
62 "DOT_PRODUCT: SIZE(VECTOR_A) is %jd but SIZE(VECTOR_B) is %jd",
63 static_cast<std::intmax_t>(n), static_cast<std::intmax_t>(yN));
64 }
65 if constexpr (RCAT != TypeCategory::Logical) {
66 if (x.GetDimension(0).ByteStride() == sizeof(XT) &&
67 y.GetDimension(0).ByteStride() == sizeof(YT)) {
68 // Contiguous numeric vectors
69 if constexpr (std::is_same_v<XT, YT>) {
70 // Contiguous homogeneous numeric vectors
71 if constexpr (std::is_same_v<XT, float>) {
72 // TODO: call BLAS-1 SDOT or SDSDOT
73 } else if constexpr (std::is_same_v<XT, double>) {
74 // TODO: call BLAS-1 DDOT
75 } else if constexpr (std::is_same_v<XT, rtcmplx::complex<float>>) {
76 // TODO: call BLAS-1 CDOTC
77 } else if constexpr (std::is_same_v<XT, rtcmplx::complex<double>>) {
78 // TODO: call BLAS-1 ZDOTC
79 }
80 }
81 XT *xp{x.OffsetElement<XT>(0)};
82 YT *yp{y.OffsetElement<YT>(0)};
83 using AccumType = AccumulationType<RCAT, RKIND>;
84 AccumType accum{};
85 if constexpr (RCAT == TypeCategory::Complex) {
86 for (SubscriptValue j{0}; j < n; ++j) {
87 // conj() may instantiate its argument twice,
88 // so xp has to be incremented separately.
89 // This is a workaround for an alleged bug in clang,
90 // that shows up as:
91 // warning: multiple unsequenced modifications to 'xp'
92 accum += rtcmplx::conj(static_cast<AccumType>(*xp)) *
93 static_cast<AccumType>(*yp++);
94 xp++;
95 }
96 } else {
97 for (SubscriptValue j{0}; j < n; ++j) {
98 accum +=
99 static_cast<AccumType>(*xp++) * static_cast<AccumType>(*yp++);
100 }
101 }
102 return static_cast<Result>(accum);
103 }
104 }
105 // Non-contiguous, heterogeneous, & LOGICAL cases
106 SubscriptValue xAt{x.GetDimension(0).LowerBound()};
107 SubscriptValue yAt{y.GetDimension(0).LowerBound()};
108 Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
109 for (SubscriptValue j{0}; j < n; ++j) {
110 accumulator.AccumulateIndexed(xAt++, yAt++);
111 }
112 return static_cast<Result>(accumulator.GetResult());
113}
114
115template <TypeCategory RCAT, int RKIND> struct DotProduct {
116 using Result = CppTypeFor<RCAT, RKIND>;
117 template <TypeCategory XCAT, int XKIND> struct DP1 {
118 template <TypeCategory YCAT, int YKIND> struct DP2 {
119 RT_API_ATTRS Result operator()(const Descriptor &x, const Descriptor &y,
120 Terminator &terminator) const {
121 if constexpr (constexpr auto resultType{
122 GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
123 if constexpr (resultType->first == RCAT &&
124 (resultType->second <= RKIND || RCAT == TypeCategory::Logical)) {
125 return DoDotProduct<RCAT, RKIND, CppTypeFor<XCAT, XKIND>,
126 CppTypeFor<YCAT, YKIND>>(x, y, terminator);
127 }
128 }
129 terminator.Crash(
130 "DOT_PRODUCT(%d(%d)): bad operand types (%d(%d), %d(%d))",
131 static_cast<int>(RCAT), RKIND, static_cast<int>(XCAT), XKIND,
132 static_cast<int>(YCAT), YKIND);
133 }
134 };
135 RT_API_ATTRS Result operator()(const Descriptor &x, const Descriptor &y,
136 Terminator &terminator, TypeCategory yCat, int yKind) const {
137 return ApplyType<DP2, Result>(yCat, yKind, terminator, x, y, terminator);
138 }
139 };
140 RT_API_ATTRS Result operator()(const Descriptor &x, const Descriptor &y,
141 const char *source, int line) const {
142 Terminator terminator{source, line};
143 if (RCAT != TypeCategory::Logical && x.type() == y.type()) {
144 // No conversions needed, operands and result have same known type
145 return typename DP1<RCAT, RKIND>::template DP2<RCAT, RKIND>{}(
146 x, y, terminator);
147 } else {
148 auto xCatKind{x.type().GetCategoryAndKind()};
149 auto yCatKind{y.type().GetCategoryAndKind()};
150 RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
151 return ApplyType<DP1, Result>(xCatKind->first, xCatKind->second,
152 terminator, x, y, terminator, yCatKind->first, yCatKind->second);
153 }
154 }
155};
156
157extern "C" {
158RT_EXT_API_GROUP_BEGIN
159
160CppTypeFor<TypeCategory::Integer, 1> RTDEF(DotProductInteger1)(
161 const Descriptor &x, const Descriptor &y, const char *source, int line) {
162 return DotProduct<TypeCategory::Integer, 1>{}(x, y, source, line);
163}
164CppTypeFor<TypeCategory::Integer, 2> RTDEF(DotProductInteger2)(
165 const Descriptor &x, const Descriptor &y, const char *source, int line) {
166 return DotProduct<TypeCategory::Integer, 2>{}(x, y, source, line);
167}
168CppTypeFor<TypeCategory::Integer, 4> RTDEF(DotProductInteger4)(
169 const Descriptor &x, const Descriptor &y, const char *source, int line) {
170 return DotProduct<TypeCategory::Integer, 4>{}(x, y, source, line);
171}
172CppTypeFor<TypeCategory::Integer, 8> RTDEF(DotProductInteger8)(
173 const Descriptor &x, const Descriptor &y, const char *source, int line) {
174 return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line);
175}
176#ifdef __SIZEOF_INT128__
177CppTypeFor<TypeCategory::Integer, 16> RTDEF(DotProductInteger16)(
178 const Descriptor &x, const Descriptor &y, const char *source, int line) {
179 return DotProduct<TypeCategory::Integer, 16>{}(x, y, source, line);
180}
181#endif
182
183CppTypeFor<TypeCategory::Unsigned, 1> RTDEF(DotProductUnsigned1)(
184 const Descriptor &x, const Descriptor &y, const char *source, int line) {
185 return DotProduct<TypeCategory::Unsigned, 1>{}(x, y, source, line);
186}
187CppTypeFor<TypeCategory::Unsigned, 2> RTDEF(DotProductUnsigned2)(
188 const Descriptor &x, const Descriptor &y, const char *source, int line) {
189 return DotProduct<TypeCategory::Unsigned, 2>{}(x, y, source, line);
190}
191CppTypeFor<TypeCategory::Unsigned, 4> RTDEF(DotProductUnsigned4)(
192 const Descriptor &x, const Descriptor &y, const char *source, int line) {
193 return DotProduct<TypeCategory::Unsigned, 4>{}(x, y, source, line);
194}
195CppTypeFor<TypeCategory::Unsigned, 8> RTDEF(DotProductUnsigned8)(
196 const Descriptor &x, const Descriptor &y, const char *source, int line) {
197 return DotProduct<TypeCategory::Unsigned, 8>{}(x, y, source, line);
198}
199#ifdef __SIZEOF_INT128__
200CppTypeFor<TypeCategory::Unsigned, 16> RTDEF(DotProductUnsigned16)(
201 const Descriptor &x, const Descriptor &y, const char *source, int line) {
202 return DotProduct<TypeCategory::Unsigned, 16>{}(x, y, source, line);
203}
204#endif
205
206// TODO: REAL/COMPLEX(2 & 3)
207// Intermediate results and operations are at least 64 bits
208CppTypeFor<TypeCategory::Real, 4> RTDEF(DotProductReal4)(
209 const Descriptor &x, const Descriptor &y, const char *source, int line) {
210 return DotProduct<TypeCategory::Real, 4>{}(x, y, source, line);
211}
212CppTypeFor<TypeCategory::Real, 8> RTDEF(DotProductReal8)(
213 const Descriptor &x, const Descriptor &y, const char *source, int line) {
214 return DotProduct<TypeCategory::Real, 8>{}(x, y, source, line);
215}
216#if HAS_FLOAT80
217CppTypeFor<TypeCategory::Real, 10> RTDEF(DotProductReal10)(
218 const Descriptor &x, const Descriptor &y, const char *source, int line) {
219 return DotProduct<TypeCategory::Real, 10>{}(x, y, source, line);
220}
221#endif
222#if HAS_LDBL128 || HAS_FLOAT128
223CppTypeFor<TypeCategory::Real, 16> RTDEF(DotProductReal16)(
224 const Descriptor &x, const Descriptor &y, const char *source, int line) {
225 return DotProduct<TypeCategory::Real, 16>{}(x, y, source, line);
226}
227#endif
228
229void RTDEF(CppDotProductComplex4)(CppTypeFor<TypeCategory::Complex, 4> &result,
230 const Descriptor &x, const Descriptor &y, const char *source, int line) {
231 result = DotProduct<TypeCategory::Complex, 4>{}(x, y, source, line);
232}
233void RTDEF(CppDotProductComplex8)(CppTypeFor<TypeCategory::Complex, 8> &result,
234 const Descriptor &x, const Descriptor &y, const char *source, int line) {
235 result = DotProduct<TypeCategory::Complex, 8>{}(x, y, source, line);
236}
237#if HAS_FLOAT80
238void RTDEF(CppDotProductComplex10)(
239 CppTypeFor<TypeCategory::Complex, 10> &result, const Descriptor &x,
240 const Descriptor &y, const char *source, int line) {
241 result = DotProduct<TypeCategory::Complex, 10>{}(x, y, source, line);
242}
243#endif
244#if HAS_LDBL128 || HAS_FLOAT128
245void RTDEF(CppDotProductComplex16)(
246 CppTypeFor<TypeCategory::Complex, 16> &result, const Descriptor &x,
247 const Descriptor &y, const char *source, int line) {
248 result = DotProduct<TypeCategory::Complex, 16>{}(x, y, source, line);
249}
250#endif
251
252bool RTDEF(DotProductLogical)(
253 const Descriptor &x, const Descriptor &y, const char *source, int line) {
254 return DotProduct<TypeCategory::Logical, 1>{}(x, y, source, line);
255}
256
257RT_EXT_API_GROUP_END
258} // extern "C"
259} // namespace Fortran::runtime
260

source code of flang-rt/lib/runtime/dot-product.cpp