| 1 | //===-- lib/runtime/dot-product.cpp -----------------------------*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "float.h" |
| 10 | #include "flang-rt/runtime/descriptor.h" |
| 11 | #include "flang-rt/runtime/terminator.h" |
| 12 | #include "flang-rt/runtime/tools.h" |
| 13 | #include "flang/Common/float128.h" |
| 14 | #include "flang/Runtime/cpp-type.h" |
| 15 | #include "flang/Runtime/reduction.h" |
| 16 | #include <cfloat> |
| 17 | #include <cinttypes> |
| 18 | |
| 19 | namespace Fortran::runtime { |
| 20 | |
| 21 | // Beware: DOT_PRODUCT of COMPLEX data uses the complex conjugate of the first |
| 22 | // argument; MATMUL does not. |
| 23 | |
| 24 | // General accumulator for any type and stride; this is not used for |
| 25 | // contiguous numeric vectors. |
| 26 | template <TypeCategory RCAT, int RKIND, typename XT, typename YT> |
| 27 | class Accumulator { |
| 28 | public: |
| 29 | using Result = AccumulationType<RCAT, RKIND>; |
| 30 | RT_API_ATTRS Accumulator(const Descriptor &x, const Descriptor &y) |
| 31 | : x_{x}, y_{y} {} |
| 32 | RT_API_ATTRS void AccumulateIndexed(SubscriptValue xAt, SubscriptValue yAt) { |
| 33 | if constexpr (RCAT == TypeCategory::Logical) { |
| 34 | sum_ = sum_ || |
| 35 | (IsLogicalElementTrue(x_, &xAt) && IsLogicalElementTrue(y_, &yAt)); |
| 36 | } else { |
| 37 | const XT &xElement{*x_.Element<XT>(&xAt)}; |
| 38 | const YT &yElement{*y_.Element<YT>(&yAt)}; |
| 39 | if constexpr (RCAT == TypeCategory::Complex) { |
| 40 | sum_ += rtcmplx::conj(static_cast<Result>(xElement)) * |
| 41 | static_cast<Result>(yElement); |
| 42 | } else { |
| 43 | sum_ += static_cast<Result>(xElement) * static_cast<Result>(yElement); |
| 44 | } |
| 45 | } |
| 46 | } |
| 47 | RT_API_ATTRS Result GetResult() const { return sum_; } |
| 48 | |
| 49 | private: |
| 50 | const Descriptor &x_, &y_; |
| 51 | Result sum_{}; |
| 52 | }; |
| 53 | |
| 54 | template <TypeCategory RCAT, int RKIND, typename XT, typename YT> |
| 55 | static inline RT_API_ATTRS CppTypeFor<RCAT, RKIND> DoDotProduct( |
| 56 | const Descriptor &x, const Descriptor &y, Terminator &terminator) { |
| 57 | using Result = CppTypeFor<RCAT, RKIND>; |
| 58 | RUNTIME_CHECK(terminator, x.rank() == 1 && y.rank() == 1); |
| 59 | SubscriptValue n{x.GetDimension(0).Extent()}; |
| 60 | if (SubscriptValue yN{y.GetDimension(0).Extent()}; yN != n) { |
| 61 | terminator.Crash( |
| 62 | "DOT_PRODUCT: SIZE(VECTOR_A) is %jd but SIZE(VECTOR_B) is %jd" , |
| 63 | static_cast<std::intmax_t>(n), static_cast<std::intmax_t>(yN)); |
| 64 | } |
| 65 | if constexpr (RCAT != TypeCategory::Logical) { |
| 66 | if (x.GetDimension(0).ByteStride() == sizeof(XT) && |
| 67 | y.GetDimension(0).ByteStride() == sizeof(YT)) { |
| 68 | // Contiguous numeric vectors |
| 69 | if constexpr (std::is_same_v<XT, YT>) { |
| 70 | // Contiguous homogeneous numeric vectors |
| 71 | if constexpr (std::is_same_v<XT, float>) { |
| 72 | // TODO: call BLAS-1 SDOT or SDSDOT |
| 73 | } else if constexpr (std::is_same_v<XT, double>) { |
| 74 | // TODO: call BLAS-1 DDOT |
| 75 | } else if constexpr (std::is_same_v<XT, rtcmplx::complex<float>>) { |
| 76 | // TODO: call BLAS-1 CDOTC |
| 77 | } else if constexpr (std::is_same_v<XT, rtcmplx::complex<double>>) { |
| 78 | // TODO: call BLAS-1 ZDOTC |
| 79 | } |
| 80 | } |
| 81 | XT *xp{x.OffsetElement<XT>(0)}; |
| 82 | YT *yp{y.OffsetElement<YT>(0)}; |
| 83 | using AccumType = AccumulationType<RCAT, RKIND>; |
| 84 | AccumType accum{}; |
| 85 | if constexpr (RCAT == TypeCategory::Complex) { |
| 86 | for (SubscriptValue j{0}; j < n; ++j) { |
| 87 | // conj() may instantiate its argument twice, |
| 88 | // so xp has to be incremented separately. |
| 89 | // This is a workaround for an alleged bug in clang, |
| 90 | // that shows up as: |
| 91 | // warning: multiple unsequenced modifications to 'xp' |
| 92 | accum += rtcmplx::conj(static_cast<AccumType>(*xp)) * |
| 93 | static_cast<AccumType>(*yp++); |
| 94 | xp++; |
| 95 | } |
| 96 | } else { |
| 97 | for (SubscriptValue j{0}; j < n; ++j) { |
| 98 | accum += |
| 99 | static_cast<AccumType>(*xp++) * static_cast<AccumType>(*yp++); |
| 100 | } |
| 101 | } |
| 102 | return static_cast<Result>(accum); |
| 103 | } |
| 104 | } |
| 105 | // Non-contiguous, heterogeneous, & LOGICAL cases |
| 106 | SubscriptValue xAt{x.GetDimension(0).LowerBound()}; |
| 107 | SubscriptValue yAt{y.GetDimension(0).LowerBound()}; |
| 108 | Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y}; |
| 109 | for (SubscriptValue j{0}; j < n; ++j) { |
| 110 | accumulator.AccumulateIndexed(xAt++, yAt++); |
| 111 | } |
| 112 | return static_cast<Result>(accumulator.GetResult()); |
| 113 | } |
| 114 | |
| 115 | template <TypeCategory RCAT, int RKIND> struct DotProduct { |
| 116 | using Result = CppTypeFor<RCAT, RKIND>; |
| 117 | template <TypeCategory XCAT, int XKIND> struct DP1 { |
| 118 | template <TypeCategory YCAT, int YKIND> struct DP2 { |
| 119 | RT_API_ATTRS Result operator()(const Descriptor &x, const Descriptor &y, |
| 120 | Terminator &terminator) const { |
| 121 | if constexpr (constexpr auto resultType{ |
| 122 | GetResultType(XCAT, XKIND, YCAT, YKIND)}) { |
| 123 | if constexpr (resultType->first == RCAT && |
| 124 | (resultType->second <= RKIND || RCAT == TypeCategory::Logical)) { |
| 125 | return DoDotProduct<RCAT, RKIND, CppTypeFor<XCAT, XKIND>, |
| 126 | CppTypeFor<YCAT, YKIND>>(x, y, terminator); |
| 127 | } |
| 128 | } |
| 129 | terminator.Crash( |
| 130 | "DOT_PRODUCT(%d(%d)): bad operand types (%d(%d), %d(%d))" , |
| 131 | static_cast<int>(RCAT), RKIND, static_cast<int>(XCAT), XKIND, |
| 132 | static_cast<int>(YCAT), YKIND); |
| 133 | } |
| 134 | }; |
| 135 | RT_API_ATTRS Result operator()(const Descriptor &x, const Descriptor &y, |
| 136 | Terminator &terminator, TypeCategory yCat, int yKind) const { |
| 137 | return ApplyType<DP2, Result>(yCat, yKind, terminator, x, y, terminator); |
| 138 | } |
| 139 | }; |
| 140 | RT_API_ATTRS Result operator()(const Descriptor &x, const Descriptor &y, |
| 141 | const char *source, int line) const { |
| 142 | Terminator terminator{source, line}; |
| 143 | if (RCAT != TypeCategory::Logical && x.type() == y.type()) { |
| 144 | // No conversions needed, operands and result have same known type |
| 145 | return typename DP1<RCAT, RKIND>::template DP2<RCAT, RKIND>{}( |
| 146 | x, y, terminator); |
| 147 | } else { |
| 148 | auto xCatKind{x.type().GetCategoryAndKind()}; |
| 149 | auto yCatKind{y.type().GetCategoryAndKind()}; |
| 150 | RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value()); |
| 151 | return ApplyType<DP1, Result>(xCatKind->first, xCatKind->second, |
| 152 | terminator, x, y, terminator, yCatKind->first, yCatKind->second); |
| 153 | } |
| 154 | } |
| 155 | }; |
| 156 | |
| 157 | extern "C" { |
| 158 | RT_EXT_API_GROUP_BEGIN |
| 159 | |
| 160 | CppTypeFor<TypeCategory::Integer, 1> RTDEF(DotProductInteger1)( |
| 161 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
| 162 | return DotProduct<TypeCategory::Integer, 1>{}(x, y, source, line); |
| 163 | } |
| 164 | CppTypeFor<TypeCategory::Integer, 2> RTDEF(DotProductInteger2)( |
| 165 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
| 166 | return DotProduct<TypeCategory::Integer, 2>{}(x, y, source, line); |
| 167 | } |
| 168 | CppTypeFor<TypeCategory::Integer, 4> RTDEF(DotProductInteger4)( |
| 169 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
| 170 | return DotProduct<TypeCategory::Integer, 4>{}(x, y, source, line); |
| 171 | } |
| 172 | CppTypeFor<TypeCategory::Integer, 8> RTDEF(DotProductInteger8)( |
| 173 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
| 174 | return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line); |
| 175 | } |
| 176 | #ifdef __SIZEOF_INT128__ |
| 177 | CppTypeFor<TypeCategory::Integer, 16> RTDEF(DotProductInteger16)( |
| 178 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
| 179 | return DotProduct<TypeCategory::Integer, 16>{}(x, y, source, line); |
| 180 | } |
| 181 | #endif |
| 182 | |
| 183 | CppTypeFor<TypeCategory::Unsigned, 1> RTDEF(DotProductUnsigned1)( |
| 184 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
| 185 | return DotProduct<TypeCategory::Unsigned, 1>{}(x, y, source, line); |
| 186 | } |
| 187 | CppTypeFor<TypeCategory::Unsigned, 2> RTDEF(DotProductUnsigned2)( |
| 188 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
| 189 | return DotProduct<TypeCategory::Unsigned, 2>{}(x, y, source, line); |
| 190 | } |
| 191 | CppTypeFor<TypeCategory::Unsigned, 4> RTDEF(DotProductUnsigned4)( |
| 192 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
| 193 | return DotProduct<TypeCategory::Unsigned, 4>{}(x, y, source, line); |
| 194 | } |
| 195 | CppTypeFor<TypeCategory::Unsigned, 8> RTDEF(DotProductUnsigned8)( |
| 196 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
| 197 | return DotProduct<TypeCategory::Unsigned, 8>{}(x, y, source, line); |
| 198 | } |
| 199 | #ifdef __SIZEOF_INT128__ |
| 200 | CppTypeFor<TypeCategory::Unsigned, 16> RTDEF(DotProductUnsigned16)( |
| 201 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
| 202 | return DotProduct<TypeCategory::Unsigned, 16>{}(x, y, source, line); |
| 203 | } |
| 204 | #endif |
| 205 | |
| 206 | // TODO: REAL/COMPLEX(2 & 3) |
| 207 | // Intermediate results and operations are at least 64 bits |
| 208 | CppTypeFor<TypeCategory::Real, 4> RTDEF(DotProductReal4)( |
| 209 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
| 210 | return DotProduct<TypeCategory::Real, 4>{}(x, y, source, line); |
| 211 | } |
| 212 | CppTypeFor<TypeCategory::Real, 8> RTDEF(DotProductReal8)( |
| 213 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
| 214 | return DotProduct<TypeCategory::Real, 8>{}(x, y, source, line); |
| 215 | } |
| 216 | #if HAS_FLOAT80 |
| 217 | CppTypeFor<TypeCategory::Real, 10> RTDEF(DotProductReal10)( |
| 218 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
| 219 | return DotProduct<TypeCategory::Real, 10>{}(x, y, source, line); |
| 220 | } |
| 221 | #endif |
| 222 | #if HAS_LDBL128 || HAS_FLOAT128 |
| 223 | CppTypeFor<TypeCategory::Real, 16> RTDEF(DotProductReal16)( |
| 224 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
| 225 | return DotProduct<TypeCategory::Real, 16>{}(x, y, source, line); |
| 226 | } |
| 227 | #endif |
| 228 | |
| 229 | void RTDEF(CppDotProductComplex4)(CppTypeFor<TypeCategory::Complex, 4> &result, |
| 230 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
| 231 | result = DotProduct<TypeCategory::Complex, 4>{}(x, y, source, line); |
| 232 | } |
| 233 | void RTDEF(CppDotProductComplex8)(CppTypeFor<TypeCategory::Complex, 8> &result, |
| 234 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
| 235 | result = DotProduct<TypeCategory::Complex, 8>{}(x, y, source, line); |
| 236 | } |
| 237 | #if HAS_FLOAT80 |
| 238 | void RTDEF(CppDotProductComplex10)( |
| 239 | CppTypeFor<TypeCategory::Complex, 10> &result, const Descriptor &x, |
| 240 | const Descriptor &y, const char *source, int line) { |
| 241 | result = DotProduct<TypeCategory::Complex, 10>{}(x, y, source, line); |
| 242 | } |
| 243 | #endif |
| 244 | #if HAS_LDBL128 || HAS_FLOAT128 |
| 245 | void RTDEF(CppDotProductComplex16)( |
| 246 | CppTypeFor<TypeCategory::Complex, 16> &result, const Descriptor &x, |
| 247 | const Descriptor &y, const char *source, int line) { |
| 248 | result = DotProduct<TypeCategory::Complex, 16>{}(x, y, source, line); |
| 249 | } |
| 250 | #endif |
| 251 | |
| 252 | bool RTDEF(DotProductLogical)( |
| 253 | const Descriptor &x, const Descriptor &y, const char *source, int line) { |
| 254 | return DotProduct<TypeCategory::Logical, 1>{}(x, y, source, line); |
| 255 | } |
| 256 | |
| 257 | RT_EXT_API_GROUP_END |
| 258 | } // extern "C" |
| 259 | } // namespace Fortran::runtime |
| 260 | |