| 1 | //===-- lib/runtime/matmul.cpp ----------------------------------*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | // Implements all forms of MATMUL (Fortran 2018 16.9.124) |
| 10 | // |
| 11 | // There are two main entry points; one establishes a descriptor for the |
| 12 | // result and allocates it, and the other expects a result descriptor that |
| 13 | // points to existing storage. |
| 14 | // |
| 15 | // This implementation must handle all combinations of numeric types and |
| 16 | // kinds (100 - 165 cases depending on the target), plus all combinations |
| 17 | // of logical kinds (16). A single template undergoes many instantiations |
| 18 | // to cover all of the valid possibilities. |
| 19 | // |
| 20 | // Places where BLAS routines could be called are marked as TODO items. |
| 21 | |
| 22 | #include "flang/Runtime/matmul.h" |
| 23 | #include "flang-rt/runtime/descriptor.h" |
| 24 | #include "flang-rt/runtime/terminator.h" |
| 25 | #include "flang-rt/runtime/tools.h" |
| 26 | #include "flang/Common/optional.h" |
| 27 | #include "flang/Runtime/c-or-cpp.h" |
| 28 | #include "flang/Runtime/cpp-type.h" |
| 29 | #include <cstring> |
| 30 | |
| 31 | namespace { |
| 32 | using namespace Fortran::runtime; |
| 33 | |
| 34 | // General accumulator for any type and stride; this is not used for |
| 35 | // contiguous numeric cases. |
| 36 | template <TypeCategory RCAT, int RKIND, typename XT, typename YT> |
| 37 | class Accumulator { |
| 38 | public: |
| 39 | using Result = AccumulationType<RCAT, RKIND>; |
| 40 | RT_API_ATTRS Accumulator(const Descriptor &x, const Descriptor &y) |
| 41 | : x_{x}, y_{y} {} |
| 42 | RT_API_ATTRS void Accumulate( |
| 43 | const SubscriptValue xAt[], const SubscriptValue yAt[]) { |
| 44 | if constexpr (RCAT == TypeCategory::Logical) { |
| 45 | sum_ = sum_ || |
| 46 | (IsLogicalElementTrue(x_, xAt) && IsLogicalElementTrue(y_, yAt)); |
| 47 | } else { |
| 48 | sum_ += static_cast<Result>(*x_.Element<XT>(xAt)) * |
| 49 | static_cast<Result>(*y_.Element<YT>(yAt)); |
| 50 | } |
| 51 | } |
| 52 | RT_API_ATTRS Result GetResult() const { return sum_; } |
| 53 | |
| 54 | private: |
| 55 | const Descriptor &x_, &y_; |
| 56 | Result sum_{}; |
| 57 | }; |
| 58 | |
| 59 | // Contiguous numeric matrix*matrix multiplication |
| 60 | // matrix(rows,n) * matrix(n,cols) -> matrix(rows,cols) |
| 61 | // Straightforward algorithm: |
| 62 | // DO 1 I = 1, NROWS |
| 63 | // DO 1 J = 1, NCOLS |
| 64 | // RES(I,J) = 0 |
| 65 | // DO 1 K = 1, N |
| 66 | // 1 RES(I,J) = RES(I,J) + X(I,K)*Y(K,J) |
| 67 | // With loop distribution and transposition to avoid the inner sum |
| 68 | // reduction and to avoid non-unit strides: |
| 69 | // DO 1 I = 1, NROWS |
| 70 | // DO 1 J = 1, NCOLS |
| 71 | // 1 RES(I,J) = 0 |
| 72 | // DO 2 K = 1, N |
| 73 | // DO 2 J = 1, NCOLS |
| 74 | // DO 2 I = 1, NROWS |
| 75 | // 2 RES(I,J) = RES(I,J) + X(I,K)*Y(K,J) ! loop-invariant last term |
| 76 | template <TypeCategory RCAT, int RKIND, typename XT, typename YT, |
| 77 | bool X_HAS_STRIDED_COLUMNS, bool Y_HAS_STRIDED_COLUMNS> |
| 78 | inline RT_API_ATTRS void MatrixTimesMatrix( |
| 79 | CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows, |
| 80 | SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y, |
| 81 | SubscriptValue n, std::size_t xColumnByteStride = 0, |
| 82 | std::size_t yColumnByteStride = 0) { |
| 83 | using ResultType = CppTypeFor<RCAT, RKIND>; |
| 84 | std::memset(product, 0, rows * cols * sizeof *product); |
| 85 | const XT *RESTRICT xp0{x}; |
| 86 | for (SubscriptValue k{0}; k < n; ++k) { |
| 87 | ResultType *RESTRICT p{product}; |
| 88 | for (SubscriptValue j{0}; j < cols; ++j) { |
| 89 | const XT *RESTRICT xp{xp0}; |
| 90 | ResultType yv; |
| 91 | if constexpr (!Y_HAS_STRIDED_COLUMNS) { |
| 92 | yv = static_cast<ResultType>(y[k + j * n]); |
| 93 | } else { |
| 94 | yv = static_cast<ResultType>(reinterpret_cast<const YT *>( |
| 95 | reinterpret_cast<const char *>(y) + j * yColumnByteStride)[k]); |
| 96 | } |
| 97 | for (SubscriptValue i{0}; i < rows; ++i) { |
| 98 | *p++ += static_cast<ResultType>(*xp++) * yv; |
| 99 | } |
| 100 | } |
| 101 | if constexpr (!X_HAS_STRIDED_COLUMNS) { |
| 102 | xp0 += rows; |
| 103 | } else { |
| 104 | xp0 = reinterpret_cast<const XT *>( |
| 105 | reinterpret_cast<const char *>(xp0) + xColumnByteStride); |
| 106 | } |
| 107 | } |
| 108 | } |
| 109 | |
| 110 | template <TypeCategory RCAT, int RKIND, typename XT, typename YT> |
| 111 | inline RT_API_ATTRS void MatrixTimesMatrixHelper( |
| 112 | CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows, |
| 113 | SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y, |
| 114 | SubscriptValue n, Fortran::common::optional<std::size_t> xColumnByteStride, |
| 115 | Fortran::common::optional<std::size_t> yColumnByteStride) { |
| 116 | if (!xColumnByteStride) { |
| 117 | if (!yColumnByteStride) { |
| 118 | MatrixTimesMatrix<RCAT, RKIND, XT, YT, false, false>( |
| 119 | product, rows, cols, x, y, n); |
| 120 | } else { |
| 121 | MatrixTimesMatrix<RCAT, RKIND, XT, YT, false, true>( |
| 122 | product, rows, cols, x, y, n, 0, *yColumnByteStride); |
| 123 | } |
| 124 | } else { |
| 125 | if (!yColumnByteStride) { |
| 126 | MatrixTimesMatrix<RCAT, RKIND, XT, YT, true, false>( |
| 127 | product, rows, cols, x, y, n, *xColumnByteStride); |
| 128 | } else { |
| 129 | MatrixTimesMatrix<RCAT, RKIND, XT, YT, true, true>( |
| 130 | product, rows, cols, x, y, n, *xColumnByteStride, *yColumnByteStride); |
| 131 | } |
| 132 | } |
| 133 | } |
| 134 | |
| 135 | // Contiguous numeric matrix*vector multiplication |
| 136 | // matrix(rows,n) * column vector(n) -> column vector(rows) |
| 137 | // Straightforward algorithm: |
| 138 | // DO 1 J = 1, NROWS |
| 139 | // RES(J) = 0 |
| 140 | // DO 1 K = 1, N |
| 141 | // 1 RES(J) = RES(J) + X(J,K)*Y(K) |
| 142 | // With loop distribution and transposition to avoid the inner |
| 143 | // sum reduction and to avoid non-unit strides: |
| 144 | // DO 1 J = 1, NROWS |
| 145 | // 1 RES(J) = 0 |
| 146 | // DO 2 K = 1, N |
| 147 | // DO 2 J = 1, NROWS |
| 148 | // 2 RES(J) = RES(J) + X(J,K)*Y(K) |
| 149 | template <TypeCategory RCAT, int RKIND, typename XT, typename YT, |
| 150 | bool X_HAS_STRIDED_COLUMNS> |
| 151 | inline RT_API_ATTRS void MatrixTimesVector( |
| 152 | CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows, |
| 153 | SubscriptValue n, const XT *RESTRICT x, const YT *RESTRICT y, |
| 154 | std::size_t xColumnByteStride = 0) { |
| 155 | using ResultType = CppTypeFor<RCAT, RKIND>; |
| 156 | std::memset(product, 0, rows * sizeof *product); |
| 157 | [[maybe_unused]] const XT *RESTRICT xp0{x}; |
| 158 | for (SubscriptValue k{0}; k < n; ++k) { |
| 159 | ResultType *RESTRICT p{product}; |
| 160 | auto yv{static_cast<ResultType>(*y++)}; |
| 161 | for (SubscriptValue j{0}; j < rows; ++j) { |
| 162 | *p++ += static_cast<ResultType>(*x++) * yv; |
| 163 | } |
| 164 | if constexpr (X_HAS_STRIDED_COLUMNS) { |
| 165 | xp0 = reinterpret_cast<const XT *>( |
| 166 | reinterpret_cast<const char *>(xp0) + xColumnByteStride); |
| 167 | x = xp0; |
| 168 | } |
| 169 | } |
| 170 | } |
| 171 | |
| 172 | template <TypeCategory RCAT, int RKIND, typename XT, typename YT> |
| 173 | inline RT_API_ATTRS void MatrixTimesVectorHelper( |
| 174 | CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows, |
| 175 | SubscriptValue n, const XT *RESTRICT x, const YT *RESTRICT y, |
| 176 | Fortran::common::optional<std::size_t> xColumnByteStride) { |
| 177 | if (!xColumnByteStride) { |
| 178 | MatrixTimesVector<RCAT, RKIND, XT, YT, false>(product, rows, n, x, y); |
| 179 | } else { |
| 180 | MatrixTimesVector<RCAT, RKIND, XT, YT, true>( |
| 181 | product, rows, n, x, y, *xColumnByteStride); |
| 182 | } |
| 183 | } |
| 184 | |
| 185 | // Contiguous numeric vector*matrix multiplication |
| 186 | // row vector(n) * matrix(n,cols) -> row vector(cols) |
| 187 | // Straightforward algorithm: |
| 188 | // DO 1 J = 1, NCOLS |
| 189 | // RES(J) = 0 |
| 190 | // DO 1 K = 1, N |
| 191 | // 1 RES(J) = RES(J) + X(K)*Y(K,J) |
| 192 | // With loop distribution and transposition to avoid the inner |
| 193 | // sum reduction and one non-unit stride (the other remains): |
| 194 | // DO 1 J = 1, NCOLS |
| 195 | // 1 RES(J) = 0 |
| 196 | // DO 2 K = 1, N |
| 197 | // DO 2 J = 1, NCOLS |
| 198 | // 2 RES(J) = RES(J) + X(K)*Y(K,J) |
| 199 | template <TypeCategory RCAT, int RKIND, typename XT, typename YT, |
| 200 | bool Y_HAS_STRIDED_COLUMNS> |
| 201 | inline RT_API_ATTRS void VectorTimesMatrix( |
| 202 | CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue n, |
| 203 | SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y, |
| 204 | std::size_t yColumnByteStride = 0) { |
| 205 | using ResultType = CppTypeFor<RCAT, RKIND>; |
| 206 | std::memset(product, 0, cols * sizeof *product); |
| 207 | for (SubscriptValue k{0}; k < n; ++k) { |
| 208 | ResultType *RESTRICT p{product}; |
| 209 | auto xv{static_cast<ResultType>(*x++)}; |
| 210 | const YT *RESTRICT yp{&y[k]}; |
| 211 | for (SubscriptValue j{0}; j < cols; ++j) { |
| 212 | *p++ += xv * static_cast<ResultType>(*yp); |
| 213 | if constexpr (!Y_HAS_STRIDED_COLUMNS) { |
| 214 | yp += n; |
| 215 | } else { |
| 216 | yp = reinterpret_cast<const YT *>( |
| 217 | reinterpret_cast<const char *>(yp) + yColumnByteStride); |
| 218 | } |
| 219 | } |
| 220 | } |
| 221 | } |
| 222 | |
| 223 | template <TypeCategory RCAT, int RKIND, typename XT, typename YT, |
| 224 | bool SPARSE_COLUMNS = false> |
| 225 | inline RT_API_ATTRS void VectorTimesMatrixHelper( |
| 226 | CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue n, |
| 227 | SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y, |
| 228 | Fortran::common::optional<std::size_t> yColumnByteStride) { |
| 229 | if (!yColumnByteStride) { |
| 230 | VectorTimesMatrix<RCAT, RKIND, XT, YT, false>(product, n, cols, x, y); |
| 231 | } else { |
| 232 | VectorTimesMatrix<RCAT, RKIND, XT, YT, true>( |
| 233 | product, n, cols, x, y, *yColumnByteStride); |
| 234 | } |
| 235 | } |
| 236 | |
| 237 | // Implements an instance of MATMUL for given argument types. |
| 238 | template <bool IS_ALLOCATING, TypeCategory RCAT, int RKIND, typename XT, |
| 239 | typename YT> |
| 240 | static inline RT_API_ATTRS void DoMatmul( |
| 241 | std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor> &result, |
| 242 | const Descriptor &x, const Descriptor &y, Terminator &terminator) { |
| 243 | int xRank{x.rank()}; |
| 244 | int yRank{y.rank()}; |
| 245 | int resRank{xRank + yRank - 2}; |
| 246 | if (xRank * yRank != 2 * resRank) { |
| 247 | terminator.Crash("MATMUL: bad argument ranks (%d * %d)" , xRank, yRank); |
| 248 | } |
| 249 | SubscriptValue extent[2]{ |
| 250 | xRank == 2 ? x.GetDimension(0).Extent() : y.GetDimension(1).Extent(), |
| 251 | resRank == 2 ? y.GetDimension(1).Extent() : 0}; |
| 252 | if constexpr (IS_ALLOCATING) { |
| 253 | result.Establish( |
| 254 | RCAT, RKIND, nullptr, resRank, extent, CFI_attribute_allocatable); |
| 255 | for (int j{0}; j < resRank; ++j) { |
| 256 | result.GetDimension(j).SetBounds(1, extent[j]); |
| 257 | } |
| 258 | if (int stat{result.Allocate(kNoAsyncObject)}) { |
| 259 | terminator.Crash( |
| 260 | "MATMUL: could not allocate memory for result; STAT=%d" , stat); |
| 261 | } |
| 262 | } else { |
| 263 | RUNTIME_CHECK(terminator, resRank == result.rank()); |
| 264 | RUNTIME_CHECK( |
| 265 | terminator, result.ElementBytes() == static_cast<std::size_t>(RKIND)); |
| 266 | RUNTIME_CHECK(terminator, result.GetDimension(0).Extent() == extent[0]); |
| 267 | RUNTIME_CHECK(terminator, |
| 268 | resRank == 1 || result.GetDimension(1).Extent() == extent[1]); |
| 269 | } |
| 270 | SubscriptValue n{x.GetDimension(xRank - 1).Extent()}; |
| 271 | if (n != y.GetDimension(0).Extent()) { |
| 272 | // At this point, we know that there's a shape error. There are three |
| 273 | // possibilities, x is rank 1, y is rank 1, or both are rank 2. |
| 274 | if (xRank == 1) { |
| 275 | terminator.Crash("MATMUL: unacceptable operand shapes (%jd, %jdx%jd)" , |
| 276 | static_cast<std::intmax_t>(n), |
| 277 | static_cast<std::intmax_t>(y.GetDimension(0).Extent()), |
| 278 | static_cast<std::intmax_t>(y.GetDimension(1).Extent())); |
| 279 | } else if (yRank == 1) { |
| 280 | terminator.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jd)" , |
| 281 | static_cast<std::intmax_t>(x.GetDimension(0).Extent()), |
| 282 | static_cast<std::intmax_t>(n), |
| 283 | static_cast<std::intmax_t>(y.GetDimension(0).Extent())); |
| 284 | } else { |
| 285 | terminator.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jdx%jd)" , |
| 286 | static_cast<std::intmax_t>(x.GetDimension(0).Extent()), |
| 287 | static_cast<std::intmax_t>(n), |
| 288 | static_cast<std::intmax_t>(y.GetDimension(0).Extent()), |
| 289 | static_cast<std::intmax_t>(y.GetDimension(1).Extent())); |
| 290 | } |
| 291 | } |
| 292 | using WriteResult = |
| 293 | CppTypeFor<RCAT == TypeCategory::Logical ? TypeCategory::Integer : RCAT, |
| 294 | RKIND>; |
| 295 | if constexpr (RCAT != TypeCategory::Logical) { |
| 296 | if (x.IsContiguous(1) && y.IsContiguous(1) && |
| 297 | (IS_ALLOCATING || result.IsContiguous())) { |
| 298 | // Contiguous numeric matrices (maybe with columns |
| 299 | // separated by a stride). |
| 300 | Fortran::common::optional<std::size_t> xColumnByteStride; |
| 301 | if (!x.IsContiguous()) { |
| 302 | // X's columns are strided. |
| 303 | SubscriptValue xAt[2]{}; |
| 304 | x.GetLowerBounds(xAt); |
| 305 | xAt[1]++; |
| 306 | xColumnByteStride = x.SubscriptsToByteOffset(xAt); |
| 307 | } |
| 308 | Fortran::common::optional<std::size_t> yColumnByteStride; |
| 309 | if (!y.IsContiguous()) { |
| 310 | // Y's columns are strided. |
| 311 | SubscriptValue yAt[2]{}; |
| 312 | y.GetLowerBounds(yAt); |
| 313 | yAt[1]++; |
| 314 | yColumnByteStride = y.SubscriptsToByteOffset(yAt); |
| 315 | } |
| 316 | // Note that BLAS GEMM can be used for the strided |
| 317 | // columns by setting proper leading dimension size. |
| 318 | // This implies that the column stride is divisible |
| 319 | // by the element size, which is usually true. |
| 320 | if (resRank == 2) { // M*M -> M |
| 321 | if (std::is_same_v<XT, YT>) { |
| 322 | if constexpr (std::is_same_v<XT, float>) { |
| 323 | // TODO: call BLAS-3 SGEMM |
| 324 | // TODO: try using CUTLASS for device. |
| 325 | } else if constexpr (std::is_same_v<XT, double>) { |
| 326 | // TODO: call BLAS-3 DGEMM |
| 327 | } else if constexpr (std::is_same_v<XT, rtcmplx::complex<float>>) { |
| 328 | // TODO: call BLAS-3 CGEMM |
| 329 | } else if constexpr (std::is_same_v<XT, rtcmplx::complex<double>>) { |
| 330 | // TODO: call BLAS-3 ZGEMM |
| 331 | } |
| 332 | } |
| 333 | MatrixTimesMatrixHelper<RCAT, RKIND, XT, YT>( |
| 334 | result.template OffsetElement<WriteResult>(), extent[0], extent[1], |
| 335 | x.OffsetElement<XT>(), y.OffsetElement<YT>(), n, xColumnByteStride, |
| 336 | yColumnByteStride); |
| 337 | return; |
| 338 | } else if (xRank == 2) { // M*V -> V |
| 339 | if (std::is_same_v<XT, YT>) { |
| 340 | if constexpr (std::is_same_v<XT, float>) { |
| 341 | // TODO: call BLAS-2 SGEMV(x,y) |
| 342 | } else if constexpr (std::is_same_v<XT, double>) { |
| 343 | // TODO: call BLAS-2 DGEMV(x,y) |
| 344 | } else if constexpr (std::is_same_v<XT, rtcmplx::complex<float>>) { |
| 345 | // TODO: call BLAS-2 CGEMV(x,y) |
| 346 | } else if constexpr (std::is_same_v<XT, rtcmplx::complex<double>>) { |
| 347 | // TODO: call BLAS-2 ZGEMV(x,y) |
| 348 | } |
| 349 | } |
| 350 | MatrixTimesVectorHelper<RCAT, RKIND, XT, YT>( |
| 351 | result.template OffsetElement<WriteResult>(), extent[0], n, |
| 352 | x.OffsetElement<XT>(), y.OffsetElement<YT>(), xColumnByteStride); |
| 353 | return; |
| 354 | } else { // V*M -> V |
| 355 | if (std::is_same_v<XT, YT>) { |
| 356 | if constexpr (std::is_same_v<XT, float>) { |
| 357 | // TODO: call BLAS-2 SGEMV(y,x) |
| 358 | } else if constexpr (std::is_same_v<XT, double>) { |
| 359 | // TODO: call BLAS-2 DGEMV(y,x) |
| 360 | } else if constexpr (std::is_same_v<XT, rtcmplx::complex<float>>) { |
| 361 | // TODO: call BLAS-2 CGEMV(y,x) |
| 362 | } else if constexpr (std::is_same_v<XT, rtcmplx::complex<double>>) { |
| 363 | // TODO: call BLAS-2 ZGEMV(y,x) |
| 364 | } |
| 365 | } |
| 366 | VectorTimesMatrixHelper<RCAT, RKIND, XT, YT>( |
| 367 | result.template OffsetElement<WriteResult>(), n, extent[0], |
| 368 | x.OffsetElement<XT>(), y.OffsetElement<YT>(), yColumnByteStride); |
| 369 | return; |
| 370 | } |
| 371 | } |
| 372 | } |
| 373 | // General algorithms for LOGICAL and noncontiguity |
| 374 | SubscriptValue xAt[2], yAt[2], resAt[2]; |
| 375 | x.GetLowerBounds(xAt); |
| 376 | y.GetLowerBounds(yAt); |
| 377 | result.GetLowerBounds(resAt); |
| 378 | if (resRank == 2) { // M*M -> M |
| 379 | SubscriptValue x1{xAt[1]}, y0{yAt[0]}, y1{yAt[1]}, res1{resAt[1]}; |
| 380 | for (SubscriptValue i{0}; i < extent[0]; ++i) { |
| 381 | for (SubscriptValue j{0}; j < extent[1]; ++j) { |
| 382 | Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y}; |
| 383 | yAt[1] = y1 + j; |
| 384 | for (SubscriptValue k{0}; k < n; ++k) { |
| 385 | xAt[1] = x1 + k; |
| 386 | yAt[0] = y0 + k; |
| 387 | accumulator.Accumulate(xAt, yAt); |
| 388 | } |
| 389 | resAt[1] = res1 + j; |
| 390 | *result.template Element<WriteResult>(resAt) = accumulator.GetResult(); |
| 391 | } |
| 392 | ++resAt[0]; |
| 393 | ++xAt[0]; |
| 394 | } |
| 395 | } else if (xRank == 2) { // M*V -> V |
| 396 | SubscriptValue x1{xAt[1]}, y0{yAt[0]}; |
| 397 | for (SubscriptValue j{0}; j < extent[0]; ++j) { |
| 398 | Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y}; |
| 399 | for (SubscriptValue k{0}; k < n; ++k) { |
| 400 | xAt[1] = x1 + k; |
| 401 | yAt[0] = y0 + k; |
| 402 | accumulator.Accumulate(xAt, yAt); |
| 403 | } |
| 404 | *result.template Element<WriteResult>(resAt) = accumulator.GetResult(); |
| 405 | ++resAt[0]; |
| 406 | ++xAt[0]; |
| 407 | } |
| 408 | } else { // V*M -> V |
| 409 | SubscriptValue x0{xAt[0]}, y0{yAt[0]}; |
| 410 | for (SubscriptValue j{0}; j < extent[0]; ++j) { |
| 411 | Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y}; |
| 412 | for (SubscriptValue k{0}; k < n; ++k) { |
| 413 | xAt[0] = x0 + k; |
| 414 | yAt[0] = y0 + k; |
| 415 | accumulator.Accumulate(xAt, yAt); |
| 416 | } |
| 417 | *result.template Element<WriteResult>(resAt) = accumulator.GetResult(); |
| 418 | ++resAt[0]; |
| 419 | ++yAt[1]; |
| 420 | } |
| 421 | } |
| 422 | } |
| 423 | |
| 424 | template <bool IS_ALLOCATING, TypeCategory XCAT, int XKIND, TypeCategory YCAT, |
| 425 | int YKIND> |
| 426 | struct MatmulHelper { |
| 427 | using ResultTy = Fortran::common::optional<std::pair<TypeCategory, int>>; |
| 428 | using ResultDescriptor = |
| 429 | std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor>; |
| 430 | RT_API_ATTRS void operator()(ResultDescriptor &result, const Descriptor &x, |
| 431 | const Descriptor &y, const char *sourceFile, int line) const { |
| 432 | Terminator terminator{sourceFile, line}; |
| 433 | auto xCatKind{x.type().GetCategoryAndKind()}; |
| 434 | auto yCatKind{y.type().GetCategoryAndKind()}; |
| 435 | RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value()); |
| 436 | RUNTIME_CHECK(terminator, |
| 437 | (xCatKind->first == XCAT && yCatKind->first == YCAT) || |
| 438 | (XCAT == TypeCategory::Integer && YCAT == TypeCategory::Integer && |
| 439 | ((xCatKind->first == TypeCategory::Integer || |
| 440 | xCatKind->first == TypeCategory::Unsigned) && |
| 441 | (yCatKind->first == TypeCategory::Integer || |
| 442 | yCatKind->first == TypeCategory::Unsigned)))); |
| 443 | if constexpr (constexpr ResultTy resultType{ |
| 444 | GetResultType(XCAT, XKIND, YCAT, YKIND)}) { |
| 445 | return DoMatmul<IS_ALLOCATING, resultType->first, resultType->second, |
| 446 | CppTypeFor<XCAT, XKIND>, CppTypeFor<YCAT, YKIND>>( |
| 447 | result, x, y, terminator); |
| 448 | } |
| 449 | terminator.Crash("MATMUL: bad operand types (%d(%d), %d(%d))" , |
| 450 | static_cast<int>(XCAT), XKIND, static_cast<int>(YCAT), YKIND); |
| 451 | } |
| 452 | }; |
| 453 | } // namespace |
| 454 | |
| 455 | namespace Fortran::runtime { |
| 456 | extern "C" { |
| 457 | RT_EXT_API_GROUP_BEGIN |
| 458 | |
| 459 | #define MATMUL_INSTANCE(XCAT, XKIND, YCAT, YKIND) \ |
| 460 | void RTDEF(Matmul##XCAT##XKIND##YCAT##YKIND)(Descriptor & result, \ |
| 461 | const Descriptor &x, const Descriptor &y, const char *sourceFile, \ |
| 462 | int line) { \ |
| 463 | MatmulHelper<true, TypeCategory::XCAT, XKIND, TypeCategory::YCAT, \ |
| 464 | YKIND>{}(result, x, y, sourceFile, line); \ |
| 465 | } |
| 466 | |
| 467 | #define MATMUL_DIRECT_INSTANCE(XCAT, XKIND, YCAT, YKIND) \ |
| 468 | void RTDEF(MatmulDirect##XCAT##XKIND##YCAT##YKIND)(Descriptor & result, \ |
| 469 | const Descriptor &x, const Descriptor &y, const char *sourceFile, \ |
| 470 | int line) { \ |
| 471 | MatmulHelper<false, TypeCategory::XCAT, XKIND, TypeCategory::YCAT, \ |
| 472 | YKIND>{}(result, x, y, sourceFile, line); \ |
| 473 | } |
| 474 | |
| 475 | #define MATMUL_FORCE_ALL_TYPES 0 |
| 476 | |
| 477 | #include "flang/Runtime/matmul-instances.inc" |
| 478 | |
| 479 | RT_EXT_API_GROUP_END |
| 480 | } // extern "C" |
| 481 | } // namespace Fortran::runtime |
| 482 | |