1 | //===-- runtime/matmul.cpp ------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | // Implements all forms of MATMUL (Fortran 2018 16.9.124) |
10 | // |
11 | // There are two main entry points; one establishes a descriptor for the |
12 | // result and allocates it, and the other expects a result descriptor that |
13 | // points to existing storage. |
14 | // |
15 | // This implementation must handle all combinations of numeric types and |
16 | // kinds (100 - 165 cases depending on the target), plus all combinations |
17 | // of logical kinds (16). A single template undergoes many instantiations |
18 | // to cover all of the valid possibilities. |
19 | // |
20 | // Places where BLAS routines could be called are marked as TODO items. |
21 | |
22 | #include "flang/Runtime/matmul.h" |
23 | #include "terminator.h" |
24 | #include "tools.h" |
25 | #include "flang/Common/optional.h" |
26 | #include "flang/Runtime/c-or-cpp.h" |
27 | #include "flang/Runtime/cpp-type.h" |
28 | #include "flang/Runtime/descriptor.h" |
29 | #include <cstring> |
30 | |
31 | namespace Fortran::runtime { |
32 | |
33 | // Suppress the warnings about calling __host__-only std::complex operators, |
34 | // defined in C++ STD header files, from __device__ code. |
35 | RT_DIAG_PUSH |
36 | RT_DIAG_DISABLE_CALL_HOST_FROM_DEVICE_WARN |
37 | |
38 | // General accumulator for any type and stride; this is not used for |
39 | // contiguous numeric cases. |
40 | template <TypeCategory RCAT, int RKIND, typename XT, typename YT> |
41 | class Accumulator { |
42 | public: |
43 | using Result = AccumulationType<RCAT, RKIND>; |
44 | RT_API_ATTRS Accumulator(const Descriptor &x, const Descriptor &y) |
45 | : x_{x}, y_{y} {} |
46 | RT_API_ATTRS void Accumulate( |
47 | const SubscriptValue xAt[], const SubscriptValue yAt[]) { |
48 | if constexpr (RCAT == TypeCategory::Logical) { |
49 | sum_ = sum_ || |
50 | (IsLogicalElementTrue(x_, xAt) && IsLogicalElementTrue(y_, yAt)); |
51 | } else { |
52 | sum_ += static_cast<Result>(*x_.Element<XT>(xAt)) * |
53 | static_cast<Result>(*y_.Element<YT>(yAt)); |
54 | } |
55 | } |
56 | RT_API_ATTRS Result GetResult() const { return sum_; } |
57 | |
58 | private: |
59 | const Descriptor &x_, &y_; |
60 | Result sum_{}; |
61 | }; |
62 | |
// Contiguous numeric matrix*matrix multiplication
//   matrix(rows,n) * matrix(n,cols) -> matrix(rows,cols)
// Straightforward algorithm:
//   DO 1 I = 1, NROWS
//    DO 1 J = 1, NCOLS
//     RES(I,J) = 0
//     DO 1 K = 1, N
//   1  RES(I,J) = RES(I,J) + X(I,K)*Y(K,J)
// With loop distribution and transposition to avoid the inner sum
// reduction and to avoid non-unit strides:
//   DO 1 I = 1, NROWS
//    DO 1 J = 1, NCOLS
//   1 RES(I,J) = 0
//   DO 2 K = 1, N
//    DO 2 J = 1, NCOLS
//     DO 2 I = 1, NROWS
//   2  RES(I,J) = RES(I,J) + X(I,K)*Y(K,J) ! loop-invariant last term
//
// When X_HAS_STRIDED_COLUMNS / Y_HAS_STRIDED_COLUMNS is set, the elements
// within each column of that operand are contiguous, but consecutive
// columns are xColumnByteStride / yColumnByteStride bytes apart; otherwise
// the operand is fully contiguous in column-major order.
template <TypeCategory RCAT, int RKIND, typename XT, typename YT,
    bool X_HAS_STRIDED_COLUMNS, bool Y_HAS_STRIDED_COLUMNS>
inline RT_API_ATTRS void MatrixTimesMatrix(
    CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows,
    SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y,
    SubscriptValue n, std::size_t xColumnByteStride = 0,
    std::size_t yColumnByteStride = 0) {
  using ResultType = CppTypeFor<RCAT, RKIND>;
  // RES = 0 (all-zero bytes represent numeric zero for the supported types).
  std::memset(product, 0, rows * cols * sizeof *product);
  const XT *RESTRICT xp0{x}; // start of column X(:,K)
  for (SubscriptValue k{0}; k < n; ++k) {
    ResultType *RESTRICT p{product}; // walks RES column-major
    for (SubscriptValue j{0}; j < cols; ++j) {
      const XT *RESTRICT xp{xp0};
      ResultType yv; // Y(K,J), invariant in the innermost loop
      if constexpr (!Y_HAS_STRIDED_COLUMNS) {
        yv = static_cast<ResultType>(y[k + j * n]);
      } else {
        yv = static_cast<ResultType>(reinterpret_cast<const YT *>(
            reinterpret_cast<const char *>(y) + j * yColumnByteStride)[k]);
      }
      for (SubscriptValue i{0}; i < rows; ++i) {
        *p++ += static_cast<ResultType>(*xp++) * yv;
      }
    }
    // Advance to the start of column X(:,K+1).
    if constexpr (!X_HAS_STRIDED_COLUMNS) {
      xp0 += rows;
    } else {
      xp0 = reinterpret_cast<const XT *>(
          reinterpret_cast<const char *>(xp0) + xColumnByteStride);
    }
  }
}
113 | |
114 | RT_DIAG_POP |
115 | |
116 | template <TypeCategory RCAT, int RKIND, typename XT, typename YT> |
117 | inline RT_API_ATTRS void MatrixTimesMatrixHelper( |
118 | CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows, |
119 | SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y, |
120 | SubscriptValue n, Fortran::common::optional<std::size_t> xColumnByteStride, |
121 | Fortran::common::optional<std::size_t> yColumnByteStride) { |
122 | if (!xColumnByteStride) { |
123 | if (!yColumnByteStride) { |
124 | MatrixTimesMatrix<RCAT, RKIND, XT, YT, false, false>( |
125 | product, rows, cols, x, y, n); |
126 | } else { |
127 | MatrixTimesMatrix<RCAT, RKIND, XT, YT, false, true>( |
128 | product, rows, cols, x, y, n, 0, *yColumnByteStride); |
129 | } |
130 | } else { |
131 | if (!yColumnByteStride) { |
132 | MatrixTimesMatrix<RCAT, RKIND, XT, YT, true, false>( |
133 | product, rows, cols, x, y, n, *xColumnByteStride); |
134 | } else { |
135 | MatrixTimesMatrix<RCAT, RKIND, XT, YT, true, true>( |
136 | product, rows, cols, x, y, n, *xColumnByteStride, *yColumnByteStride); |
137 | } |
138 | } |
139 | } |
140 | |
141 | RT_DIAG_PUSH |
142 | RT_DIAG_DISABLE_CALL_HOST_FROM_DEVICE_WARN |
143 | |
// Contiguous numeric matrix*vector multiplication
//   matrix(rows,n) * column vector(n) -> column vector(rows)
// Straightforward algorithm:
//   DO 1 J = 1, NROWS
//    RES(J) = 0
//    DO 1 K = 1, N
//   1 RES(J) = RES(J) + X(J,K)*Y(K)
// With loop distribution and transposition to avoid the inner
// sum reduction and to avoid non-unit strides:
//   DO 1 J = 1, NROWS
//   1 RES(J) = 0
//   DO 2 K = 1, N
//    DO 2 J = 1, NROWS
//   2 RES(J) = RES(J) + X(J,K)*Y(K)
//
// When X_HAS_STRIDED_COLUMNS is set, each column of X is contiguous but
// consecutive columns are xColumnByteStride bytes apart.
template <TypeCategory RCAT, int RKIND, typename XT, typename YT,
    bool X_HAS_STRIDED_COLUMNS>
inline RT_API_ATTRS void MatrixTimesVector(
    CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows,
    SubscriptValue n, const XT *RESTRICT x, const YT *RESTRICT y,
    std::size_t xColumnByteStride = 0) {
  using ResultType = CppTypeFor<RCAT, RKIND>;
  std::memset(product, 0, rows * sizeof *product); // RES = 0
  // Start of the current column X(:,K); only needed for strided columns.
  [[maybe_unused]] const XT *RESTRICT xp0{x};
  for (SubscriptValue k{0}; k < n; ++k) {
    ResultType *RESTRICT p{product};
    auto yv{static_cast<ResultType>(*y++)}; // Y(K), invariant below
    for (SubscriptValue j{0}; j < rows; ++j) {
      *p++ += static_cast<ResultType>(*x++) * yv;
    }
    if constexpr (X_HAS_STRIDED_COLUMNS) {
      // Jump over the inter-column gap; in the dense case the inner loop
      // has already advanced x to the start of the next column.
      xp0 = reinterpret_cast<const XT *>(
          reinterpret_cast<const char *>(xp0) + xColumnByteStride);
      x = xp0;
    }
  }
}
180 | |
181 | RT_DIAG_POP |
182 | |
183 | template <TypeCategory RCAT, int RKIND, typename XT, typename YT> |
184 | inline RT_API_ATTRS void MatrixTimesVectorHelper( |
185 | CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows, |
186 | SubscriptValue n, const XT *RESTRICT x, const YT *RESTRICT y, |
187 | Fortran::common::optional<std::size_t> xColumnByteStride) { |
188 | if (!xColumnByteStride) { |
189 | MatrixTimesVector<RCAT, RKIND, XT, YT, false>(product, rows, n, x, y); |
190 | } else { |
191 | MatrixTimesVector<RCAT, RKIND, XT, YT, true>( |
192 | product, rows, n, x, y, *xColumnByteStride); |
193 | } |
194 | } |
195 | |
196 | RT_DIAG_PUSH |
197 | RT_DIAG_DISABLE_CALL_HOST_FROM_DEVICE_WARN |
198 | |
// Contiguous numeric vector*matrix multiplication
//   row vector(n) * matrix(n,cols) -> row vector(cols)
// Straightforward algorithm:
//   DO 1 J = 1, NCOLS
//    RES(J) = 0
//    DO 1 K = 1, N
//   1 RES(J) = RES(J) + X(K)*Y(K,J)
// With loop distribution and transposition to avoid the inner
// sum reduction and one non-unit stride (the other remains):
//   DO 1 J = 1, NCOLS
//   1 RES(J) = 0
//   DO 2 K = 1, N
//    DO 2 J = 1, NCOLS
//   2 RES(J) = RES(J) + X(K)*Y(K,J)
//
// When Y_HAS_STRIDED_COLUMNS is set, each column of Y is contiguous but
// consecutive columns are yColumnByteStride bytes apart.
template <TypeCategory RCAT, int RKIND, typename XT, typename YT,
    bool Y_HAS_STRIDED_COLUMNS>
inline RT_API_ATTRS void VectorTimesMatrix(
    CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue n,
    SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y,
    std::size_t yColumnByteStride = 0) {
  using ResultType = CppTypeFor<RCAT, RKIND>;
  std::memset(product, 0, cols * sizeof *product); // RES = 0
  for (SubscriptValue k{0}; k < n; ++k) {
    ResultType *RESTRICT p{product};
    auto xv{static_cast<ResultType>(*x++)}; // X(K), invariant below
    const YT *RESTRICT yp{&y[k]}; // Y(K,1); advances one column per step
    for (SubscriptValue j{0}; j < cols; ++j) {
      *p++ += xv * static_cast<ResultType>(*yp);
      // Step yp to Y(K,J+1): the non-unit stride noted in the header.
      if constexpr (!Y_HAS_STRIDED_COLUMNS) {
        yp += n;
      } else {
        yp = reinterpret_cast<const YT *>(
            reinterpret_cast<const char *>(yp) + yColumnByteStride);
      }
    }
  }
}
236 | |
237 | RT_DIAG_POP |
238 | |
239 | template <TypeCategory RCAT, int RKIND, typename XT, typename YT, |
240 | bool SPARSE_COLUMNS = false> |
241 | inline RT_API_ATTRS void VectorTimesMatrixHelper( |
242 | CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue n, |
243 | SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y, |
244 | Fortran::common::optional<std::size_t> yColumnByteStride) { |
245 | if (!yColumnByteStride) { |
246 | VectorTimesMatrix<RCAT, RKIND, XT, YT, false>(product, n, cols, x, y); |
247 | } else { |
248 | VectorTimesMatrix<RCAT, RKIND, XT, YT, true>( |
249 | product, n, cols, x, y, *yColumnByteStride); |
250 | } |
251 | } |
252 | |
253 | RT_DIAG_PUSH |
254 | RT_DIAG_DISABLE_CALL_HOST_FROM_DEVICE_WARN |
255 | |
// Implements an instance of MATMUL for given argument types.
// When IS_ALLOCATING, the result descriptor is established and allocated
// here; otherwise the caller supplies a result descriptor that is validated
// against the operands before being written through.
template <bool IS_ALLOCATING, TypeCategory RCAT, int RKIND, typename XT,
    typename YT>
static inline RT_API_ATTRS void DoMatmul(
    std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor> &result,
    const Descriptor &x, const Descriptor &y, Terminator &terminator) {
  int xRank{x.rank()};
  int yRank{y.rank()};
  int resRank{xRank + yRank - 2};
  // MATMUL admits only the rank combinations 2*2->2, 2*1->1, and 1*2->1;
  // the identity below holds for exactly those three.
  if (xRank * yRank != 2 * resRank) {
    terminator.Crash("MATMUL: bad argument ranks (%d * %d)", xRank, yRank);
  }
  // extent[0] is the result's first (or only) extent; extent[1] matters
  // only for a rank-2 result.
  SubscriptValue extent[2]{
      xRank == 2 ? x.GetDimension(0).Extent() : y.GetDimension(1).Extent(),
      resRank == 2 ? y.GetDimension(1).Extent() : 0};
  if constexpr (IS_ALLOCATING) {
    // Establish and allocate the result with lower bounds of 1.
    result.Establish(
        RCAT, RKIND, nullptr, resRank, extent, CFI_attribute_allocatable);
    for (int j{0}; j < resRank; ++j) {
      result.GetDimension(j).SetBounds(1, extent[j]);
    }
    if (int stat{result.Allocate()}) {
      terminator.Crash(
          "MATMUL: could not allocate memory for result; STAT=%d", stat);
    }
  } else {
    // Validate the caller-provided result descriptor.
    RUNTIME_CHECK(terminator, resRank == result.rank());
    // NOTE(review): this compares ElementBytes() against the result KIND,
    // which holds for INTEGER/REAL/LOGICAL results but not for COMPLEX
    // (where ElementBytes() == 2*KIND) — confirm whether the non-allocating
    // entry point is ever used with a COMPLEX result.
    RUNTIME_CHECK(
        terminator, result.ElementBytes() == static_cast<std::size_t>(RKIND));
    RUNTIME_CHECK(terminator, result.GetDimension(0).Extent() == extent[0]);
    RUNTIME_CHECK(terminator,
        resRank == 1 || result.GetDimension(1).Extent() == extent[1]);
  }
  // Conformance check: X's last extent must equal Y's first extent.
  SubscriptValue n{x.GetDimension(xRank - 1).Extent()};
  if (n != y.GetDimension(0).Extent()) {
    terminator.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jdx%jd)",
        static_cast<std::intmax_t>(x.GetDimension(0).Extent()),
        static_cast<std::intmax_t>(n),
        static_cast<std::intmax_t>(y.GetDimension(0).Extent()),
        static_cast<std::intmax_t>(y.GetDimension(1).Extent()));
  }
  // LOGICAL results are stored through the same-width integer type.
  using WriteResult =
      CppTypeFor<RCAT == TypeCategory::Logical ? TypeCategory::Integer : RCAT,
          RKIND>;
  if constexpr (RCAT != TypeCategory::Logical) {
    // Fast paths: numeric data whose elements are contiguous within each
    // column (whole-array contiguity not required).
    if (x.IsContiguous(1) && y.IsContiguous(1) &&
        (IS_ALLOCATING || result.IsContiguous())) {
      // Contiguous numeric matrices (maybe with columns
      // separated by a stride).
      Fortran::common::optional<std::size_t> xColumnByteStride;
      if (!x.IsContiguous()) {
        // X's columns are strided: the byte distance between consecutive
        // columns is the offset of element (lb0, lb1+1).
        SubscriptValue xAt[2]{};
        x.GetLowerBounds(xAt);
        xAt[1]++;
        xColumnByteStride = x.SubscriptsToByteOffset(xAt);
      }
      Fortran::common::optional<std::size_t> yColumnByteStride;
      if (!y.IsContiguous()) {
        // Y's columns are strided.
        SubscriptValue yAt[2]{};
        y.GetLowerBounds(yAt);
        yAt[1]++;
        yColumnByteStride = y.SubscriptsToByteOffset(yAt);
      }
      // Note that BLAS GEMM can be used for the strided
      // columns by setting proper leading dimension size.
      // This implies that the column stride is divisible
      // by the element size, which is usually true.
      if (resRank == 2) { // M*M -> M
        if (std::is_same_v<XT, YT>) {
          if constexpr (std::is_same_v<XT, float>) {
            // TODO: call BLAS-3 SGEMM
            // TODO: try using CUTLASS for device.
          } else if constexpr (std::is_same_v<XT, double>) {
            // TODO: call BLAS-3 DGEMM
          } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
            // TODO: call BLAS-3 CGEMM
          } else if constexpr (std::is_same_v<XT, std::complex<double>>) {
            // TODO: call BLAS-3 ZGEMM
          }
        }
        MatrixTimesMatrixHelper<RCAT, RKIND, XT, YT>(
            result.template OffsetElement<WriteResult>(), extent[0], extent[1],
            x.OffsetElement<XT>(), y.OffsetElement<YT>(), n, xColumnByteStride,
            yColumnByteStride);
        return;
      } else if (xRank == 2) { // M*V -> V
        if (std::is_same_v<XT, YT>) {
          if constexpr (std::is_same_v<XT, float>) {
            // TODO: call BLAS-2 SGEMV(x,y)
          } else if constexpr (std::is_same_v<XT, double>) {
            // TODO: call BLAS-2 DGEMV(x,y)
          } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
            // TODO: call BLAS-2 CGEMV(x,y)
          } else if constexpr (std::is_same_v<XT, std::complex<double>>) {
            // TODO: call BLAS-2 ZGEMV(x,y)
          }
        }
        MatrixTimesVectorHelper<RCAT, RKIND, XT, YT>(
            result.template OffsetElement<WriteResult>(), extent[0], n,
            x.OffsetElement<XT>(), y.OffsetElement<YT>(), xColumnByteStride);
        return;
      } else { // V*M -> V
        if (std::is_same_v<XT, YT>) {
          if constexpr (std::is_same_v<XT, float>) {
            // TODO: call BLAS-2 SGEMV(y,x)
          } else if constexpr (std::is_same_v<XT, double>) {
            // TODO: call BLAS-2 DGEMV(y,x)
          } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
            // TODO: call BLAS-2 CGEMV(y,x)
          } else if constexpr (std::is_same_v<XT, std::complex<double>>) {
            // TODO: call BLAS-2 ZGEMV(y,x)
          }
        }
        VectorTimesMatrixHelper<RCAT, RKIND, XT, YT>(
            result.template OffsetElement<WriteResult>(), n, extent[0],
            x.OffsetElement<XT>(), y.OffsetElement<YT>(), yColumnByteStride);
        return;
      }
    }
  }
  // General algorithms for LOGICAL and noncontiguity
  // xAt/yAt/resAt hold full subscript vectors, initialized to the lower
  // bounds and adjusted per iteration below.
  SubscriptValue xAt[2], yAt[2], resAt[2];
  x.GetLowerBounds(xAt);
  y.GetLowerBounds(yAt);
  result.GetLowerBounds(resAt);
  if (resRank == 2) { // M*M -> M
    // Cache the lower bounds that the loops below overwrite.
    SubscriptValue x1{xAt[1]}, y0{yAt[0]}, y1{yAt[1]}, res1{resAt[1]};
    for (SubscriptValue i{0}; i < extent[0]; ++i) {
      for (SubscriptValue j{0}; j < extent[1]; ++j) {
        Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
        yAt[1] = y1 + j;
        for (SubscriptValue k{0}; k < n; ++k) {
          xAt[1] = x1 + k;
          yAt[0] = y0 + k;
          accumulator.Accumulate(xAt, yAt);
        }
        resAt[1] = res1 + j;
        *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
      }
      ++resAt[0];
      ++xAt[0];
    }
  } else if (xRank == 2) { // M*V -> V
    SubscriptValue x1{xAt[1]}, y0{yAt[0]};
    for (SubscriptValue j{0}; j < extent[0]; ++j) {
      Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
      for (SubscriptValue k{0}; k < n; ++k) {
        xAt[1] = x1 + k;
        yAt[0] = y0 + k;
        accumulator.Accumulate(xAt, yAt);
      }
      *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
      ++resAt[0];
      ++xAt[0];
    }
  } else { // V*M -> V
    SubscriptValue x0{xAt[0]}, y0{yAt[0]};
    for (SubscriptValue j{0}; j < extent[0]; ++j) {
      Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
      for (SubscriptValue k{0}; k < n; ++k) {
        xAt[0] = x0 + k;
        yAt[0] = y0 + k;
        accumulator.Accumulate(xAt, yAt);
      }
      *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
      ++resAt[0];
      ++yAt[1]; // advance along Y's second dimension for the next column
    }
  }
}
428 | |
429 | RT_DIAG_POP |
430 | |
// Maps the dynamic type information from the arguments' descriptors
// to the right instantiation of DoMatmul() for valid combinations of
// types.  The dispatch is two-level: the outer operator() resolves X's
// category/kind via ApplyType into MM1, whose operator() resolves Y's
// category/kind via ApplyType into MM2, which finally calls DoMatmul().
template <bool IS_ALLOCATING> struct Matmul {
  using ResultDescriptor =
      std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor>;
  template <TypeCategory XCAT, int XKIND> struct MM1 {
    template <TypeCategory YCAT, int YKIND> struct MM2 {
      RT_API_ATTRS void operator()(ResultDescriptor &result,
          const Descriptor &x, const Descriptor &y,
          Terminator &terminator) const {
        // GetResultType() yields the result category & kind of XT op YT,
        // or no value for invalid pairings; both checks are resolved at
        // compile time for each instantiation.
        if constexpr (constexpr auto resultType{
                          GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
          if constexpr (common::IsNumericTypeCategory(resultType->first) ||
              resultType->first == TypeCategory::Logical) {
            return DoMatmul<IS_ALLOCATING, resultType->first,
                resultType->second, CppTypeFor<XCAT, XKIND>,
                CppTypeFor<YCAT, YKIND>>(result, x, y, terminator);
          }
        }
        // Reached only for instantiations with no valid result type.
        terminator.Crash("MATMUL: bad operand types (%d(%d), %d(%d))",
            static_cast<int>(XCAT), XKIND, static_cast<int>(YCAT), YKIND);
      }
    };
    RT_API_ATTRS void operator()(ResultDescriptor &result, const Descriptor &x,
        const Descriptor &y, Terminator &terminator, TypeCategory yCat,
        int yKind) const {
      // Dispatch on Y's dynamic type.
      ApplyType<MM2, void>(yCat, yKind, terminator, result, x, y, terminator);
    }
  };
  RT_API_ATTRS void operator()(ResultDescriptor &result, const Descriptor &x,
      const Descriptor &y, const char *sourceFile, int line) const {
    Terminator terminator{sourceFile, line};
    auto xCatKind{x.type().GetCategoryAndKind()};
    auto yCatKind{y.type().GetCategoryAndKind()};
    RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
    // Dispatch on X's dynamic type; MM1 then dispatches on Y's.
    ApplyType<MM1, void>(xCatKind->first, xCatKind->second, terminator, result,
        x, y, terminator, yCatKind->first, yCatKind->second);
  }
};
471 | |
472 | extern "C" { |
473 | RT_EXT_API_GROUP_BEGIN |
474 | |
// MATMUL entry point that establishes and allocates the result descriptor
// before computing into it.
void RTDEF(Matmul)(Descriptor &result, const Descriptor &x, const Descriptor &y,
    const char *sourceFile, int line) {
  Matmul<true>{}(result, x, y, sourceFile, line);
}
// MATMUL entry point that writes into caller-provided storage; the result
// descriptor's rank, extents, and element size are checked in DoMatmul().
void RTDEF(MatmulDirect)(const Descriptor &result, const Descriptor &x,
    const Descriptor &y, const char *sourceFile, int line) {
  Matmul<false>{}(result, x, y, sourceFile, line);
}
483 | |
484 | RT_EXT_API_GROUP_END |
485 | } // extern "C" |
486 | } // namespace Fortran::runtime |
487 | |