1 | //===-- runtime/matmul.cpp ------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | // Implements all forms of MATMUL (Fortran 2018 16.9.124) |
10 | // |
11 | // There are two main entry points; one establishes a descriptor for the |
12 | // result and allocates it, and the other expects a result descriptor that |
13 | // points to existing storage. |
14 | // |
15 | // This implementation must handle all combinations of numeric types and |
16 | // kinds (100 - 165 cases depending on the target), plus all combinations |
17 | // of logical kinds (16). A single template undergoes many instantiations |
18 | // to cover all of the valid possibilities. |
19 | // |
20 | // Places where BLAS routines could be called are marked as TODO items. |
21 | |
22 | #include "flang/Runtime/matmul.h" |
23 | #include "terminator.h" |
24 | #include "tools.h" |
25 | #include "flang/Runtime/c-or-cpp.h" |
26 | #include "flang/Runtime/cpp-type.h" |
27 | #include "flang/Runtime/descriptor.h" |
28 | #include <cstring> |
29 | |
30 | namespace Fortran::runtime { |
31 | |
32 | // Suppress the warnings about calling __host__-only std::complex operators, |
33 | // defined in C++ STD header files, from __device__ code. |
34 | RT_DIAG_PUSH |
35 | RT_DIAG_DISABLE_CALL_HOST_FROM_DEVICE_WARN |
36 | |
37 | // General accumulator for any type and stride; this is not used for |
38 | // contiguous numeric cases. |
39 | template <TypeCategory RCAT, int RKIND, typename XT, typename YT> |
40 | class Accumulator { |
41 | public: |
42 | using Result = AccumulationType<RCAT, RKIND>; |
43 | RT_API_ATTRS Accumulator(const Descriptor &x, const Descriptor &y) |
44 | : x_{x}, y_{y} {} |
45 | RT_API_ATTRS void Accumulate( |
46 | const SubscriptValue xAt[], const SubscriptValue yAt[]) { |
47 | if constexpr (RCAT == TypeCategory::Logical) { |
48 | sum_ = sum_ || |
49 | (IsLogicalElementTrue(x_, xAt) && IsLogicalElementTrue(y_, yAt)); |
50 | } else { |
51 | sum_ += static_cast<Result>(*x_.Element<XT>(xAt)) * |
52 | static_cast<Result>(*y_.Element<YT>(yAt)); |
53 | } |
54 | } |
55 | RT_API_ATTRS Result GetResult() const { return sum_; } |
56 | |
57 | private: |
58 | const Descriptor &x_, &y_; |
59 | Result sum_{}; |
60 | }; |
61 | |
62 | // Contiguous numeric matrix*matrix multiplication |
63 | // matrix(rows,n) * matrix(n,cols) -> matrix(rows,cols) |
64 | // Straightforward algorithm: |
65 | // DO 1 I = 1, NROWS |
66 | // DO 1 J = 1, NCOLS |
67 | // RES(I,J) = 0 |
68 | // DO 1 K = 1, N |
69 | // 1 RES(I,J) = RES(I,J) + X(I,K)*Y(K,J) |
70 | // With loop distribution and transposition to avoid the inner sum |
71 | // reduction and to avoid non-unit strides: |
72 | // DO 1 I = 1, NROWS |
73 | // DO 1 J = 1, NCOLS |
74 | // 1 RES(I,J) = 0 |
75 | // DO 2 K = 1, N |
76 | // DO 2 J = 1, NCOLS |
77 | // DO 2 I = 1, NROWS |
78 | // 2 RES(I,J) = RES(I,J) + X(I,K)*Y(K,J) ! loop-invariant last term |
// Computes product(rows,cols) = x(rows,n) * y(n,cols) for contiguous
// numeric operands.  When a *_HAS_STRIDED_COLUMNS flag is set, that
// operand's columns are internally contiguous but separated by a fixed
// byte stride (e.g. a section of a larger array); the corresponding
// *ColumnByteStride argument is meaningful only in that case.
template <TypeCategory RCAT, int RKIND, typename XT, typename YT,
    bool X_HAS_STRIDED_COLUMNS, bool Y_HAS_STRIDED_COLUMNS>
inline RT_API_ATTRS void MatrixTimesMatrix(
    CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows,
    SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y,
    SubscriptValue n, std::size_t xColumnByteStride = 0,
    std::size_t yColumnByteStride = 0) {
  using ResultType = CppTypeFor<RCAT, RKIND>;
  // Zero-fill the result, then accumulate one update per k (outer loop),
  // per the loop transposition described above.
  std::memset(product, 0, rows * cols * sizeof *product);
  const XT *RESTRICT xp0{x}; // start of X's k'th column
  for (SubscriptValue k{0}; k < n; ++k) {
    ResultType *RESTRICT p{product};
    for (SubscriptValue j{0}; j < cols; ++j) {
      const XT *RESTRICT xp{xp0};
      // Y(k,j) is loop-invariant for the inner i loop.
      ResultType yv;
      if constexpr (!Y_HAS_STRIDED_COLUMNS) {
        yv = static_cast<ResultType>(y[k + j * n]); // column-major Y
      } else {
        // Y's columns are spaced by a byte stride rather than n elements.
        yv = static_cast<ResultType>(reinterpret_cast<const YT *>(
            reinterpret_cast<const char *>(y) + j * yColumnByteStride)[k]);
      }
      // Unit-stride update of the j'th result column:
      //   RES(i,j) += X(i,k) * Y(k,j)
      for (SubscriptValue i{0}; i < rows; ++i) {
        *p++ += static_cast<ResultType>(*xp++) * yv;
      }
    }
    // Advance to the start of X's next column.
    if constexpr (!X_HAS_STRIDED_COLUMNS) {
      xp0 += rows;
    } else {
      xp0 = reinterpret_cast<const XT *>(
          reinterpret_cast<const char *>(xp0) + xColumnByteStride);
    }
  }
}
112 | |
113 | RT_DIAG_POP |
114 | |
115 | template <TypeCategory RCAT, int RKIND, typename XT, typename YT> |
116 | inline RT_API_ATTRS void MatrixTimesMatrixHelper( |
117 | CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows, |
118 | SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y, |
119 | SubscriptValue n, std::optional<std::size_t> xColumnByteStride, |
120 | std::optional<std::size_t> yColumnByteStride) { |
121 | if (!xColumnByteStride) { |
122 | if (!yColumnByteStride) { |
123 | MatrixTimesMatrix<RCAT, RKIND, XT, YT, false, false>( |
124 | product, rows, cols, x, y, n); |
125 | } else { |
126 | MatrixTimesMatrix<RCAT, RKIND, XT, YT, false, true>( |
127 | product, rows, cols, x, y, n, 0, *yColumnByteStride); |
128 | } |
129 | } else { |
130 | if (!yColumnByteStride) { |
131 | MatrixTimesMatrix<RCAT, RKIND, XT, YT, true, false>( |
132 | product, rows, cols, x, y, n, *xColumnByteStride); |
133 | } else { |
134 | MatrixTimesMatrix<RCAT, RKIND, XT, YT, true, true>( |
135 | product, rows, cols, x, y, n, *xColumnByteStride, *yColumnByteStride); |
136 | } |
137 | } |
138 | } |
139 | |
140 | RT_DIAG_PUSH |
141 | RT_DIAG_DISABLE_CALL_HOST_FROM_DEVICE_WARN |
142 | |
143 | // Contiguous numeric matrix*vector multiplication |
144 | // matrix(rows,n) * column vector(n) -> column vector(rows) |
145 | // Straightforward algorithm: |
146 | // DO 1 J = 1, NROWS |
147 | // RES(J) = 0 |
148 | // DO 1 K = 1, N |
149 | // 1 RES(J) = RES(J) + X(J,K)*Y(K) |
150 | // With loop distribution and transposition to avoid the inner |
151 | // sum reduction and to avoid non-unit strides: |
152 | // DO 1 J = 1, NROWS |
153 | // 1 RES(J) = 0 |
154 | // DO 2 K = 1, N |
155 | // DO 2 J = 1, NROWS |
156 | // 2 RES(J) = RES(J) + X(J,K)*Y(K) |
// Computes product(rows) = x(rows,n) * y(n) for contiguous numeric
// operands, using the transposed loop order described above.  With
// X_HAS_STRIDED_COLUMNS, X's columns are separated by xColumnByteStride
// bytes instead of being end-to-end.
template <TypeCategory RCAT, int RKIND, typename XT, typename YT,
    bool X_HAS_STRIDED_COLUMNS>
inline RT_API_ATTRS void MatrixTimesVector(
    CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows,
    SubscriptValue n, const XT *RESTRICT x, const YT *RESTRICT y,
    std::size_t xColumnByteStride = 0) {
  using ResultType = CppTypeFor<RCAT, RKIND>;
  std::memset(product, 0, rows * sizeof *product);
  // xp0 tracks the start of X's current column; it is only needed when
  // the columns are strided (otherwise x itself walks contiguously).
  [[maybe_unused]] const XT *RESTRICT xp0{x};
  for (SubscriptValue k{0}; k < n; ++k) {
    ResultType *RESTRICT p{product};
    auto yv{static_cast<ResultType>(*y++)}; // loop-invariant Y(k)
    // Unit-stride update: RES(j) += X(j,k) * Y(k)
    for (SubscriptValue j{0}; j < rows; ++j) {
      *p++ += static_cast<ResultType>(*x++) * yv;
    }
    if constexpr (X_HAS_STRIDED_COLUMNS) {
      // Skip the gap between X's columns.
      xp0 = reinterpret_cast<const XT *>(
          reinterpret_cast<const char *>(xp0) + xColumnByteStride);
      x = xp0;
    }
  }
}
179 | |
180 | RT_DIAG_POP |
181 | |
182 | template <TypeCategory RCAT, int RKIND, typename XT, typename YT> |
183 | inline RT_API_ATTRS void MatrixTimesVectorHelper( |
184 | CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows, |
185 | SubscriptValue n, const XT *RESTRICT x, const YT *RESTRICT y, |
186 | std::optional<std::size_t> xColumnByteStride) { |
187 | if (!xColumnByteStride) { |
188 | MatrixTimesVector<RCAT, RKIND, XT, YT, false>(product, rows, n, x, y); |
189 | } else { |
190 | MatrixTimesVector<RCAT, RKIND, XT, YT, true>( |
191 | product, rows, n, x, y, *xColumnByteStride); |
192 | } |
193 | } |
194 | |
195 | RT_DIAG_PUSH |
196 | RT_DIAG_DISABLE_CALL_HOST_FROM_DEVICE_WARN |
197 | |
198 | // Contiguous numeric vector*matrix multiplication |
199 | // row vector(n) * matrix(n,cols) -> row vector(cols) |
200 | // Straightforward algorithm: |
201 | // DO 1 J = 1, NCOLS |
202 | // RES(J) = 0 |
203 | // DO 1 K = 1, N |
204 | // 1 RES(J) = RES(J) + X(K)*Y(K,J) |
205 | // With loop distribution and transposition to avoid the inner |
206 | // sum reduction and one non-unit stride (the other remains): |
207 | // DO 1 J = 1, NCOLS |
208 | // 1 RES(J) = 0 |
209 | // DO 2 K = 1, N |
210 | // DO 2 J = 1, NCOLS |
211 | // 2 RES(J) = RES(J) + X(K)*Y(K,J) |
// Computes product(cols) = x(n) * y(n,cols) for contiguous numeric
// operands.  Y is walked along a row, so one non-unit stride remains in
// the inner loop (see the notes above).
template <TypeCategory RCAT, int RKIND, typename XT, typename YT,
    bool Y_HAS_STRIDED_COLUMNS>
inline RT_API_ATTRS void VectorTimesMatrix(
    CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue n,
    SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y,
    std::size_t yColumnByteStride = 0) {
  using ResultType = CppTypeFor<RCAT, RKIND>;
  std::memset(product, 0, cols * sizeof *product);
  for (SubscriptValue k{0}; k < n; ++k) {
    ResultType *RESTRICT p{product};
    auto xv{static_cast<ResultType>(*x++)}; // loop-invariant X(k)
    const YT *RESTRICT yp{&y[k]}; // Y(k,1)
    for (SubscriptValue j{0}; j < cols; ++j) {
      *p++ += xv * static_cast<ResultType>(*yp); // RES(j) += X(k)*Y(k,j)
      // Step to Y(k,j+1).
      if constexpr (!Y_HAS_STRIDED_COLUMNS) {
        yp += n; // contiguous column-major layout
      } else {
        // Columns are separated by a fixed byte stride.
        yp = reinterpret_cast<const YT *>(
            reinterpret_cast<const char *>(yp) + yColumnByteStride);
      }
    }
  }
}
235 | |
236 | RT_DIAG_POP |
237 | |
238 | template <TypeCategory RCAT, int RKIND, typename XT, typename YT, |
239 | bool SPARSE_COLUMNS = false> |
240 | inline RT_API_ATTRS void VectorTimesMatrixHelper( |
241 | CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue n, |
242 | SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y, |
243 | std::optional<std::size_t> yColumnByteStride) { |
244 | if (!yColumnByteStride) { |
245 | VectorTimesMatrix<RCAT, RKIND, XT, YT, false>(product, n, cols, x, y); |
246 | } else { |
247 | VectorTimesMatrix<RCAT, RKIND, XT, YT, true>( |
248 | product, n, cols, x, y, *yColumnByteStride); |
249 | } |
250 | } |
251 | |
252 | RT_DIAG_PUSH |
253 | RT_DIAG_DISABLE_CALL_HOST_FROM_DEVICE_WARN |
254 | |
// Implements an instance of MATMUL for given argument types.
// When IS_ALLOCATING, this routine establishes and allocates `result`;
// otherwise `result` must already describe conforming storage.
template <bool IS_ALLOCATING, TypeCategory RCAT, int RKIND, typename XT,
    typename YT>
static inline RT_API_ATTRS void DoMatmul(
    std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor> &result,
    const Descriptor &x, const Descriptor &y, Terminator &terminator) {
  int xRank{x.rank()};
  int yRank{y.rank()};
  int resRank{xRank + yRank - 2};
  // The only valid rank pairs are 2*2->2, 2*1->1, and 1*2->1; this
  // identity holds exactly for those combinations of ranks 1 and 2.
  if (xRank * yRank != 2 * resRank) {
    terminator.Crash("MATMUL: bad argument ranks (%d * %d)", xRank, yRank);
  }
  // extent[0] is the result's first (or only) extent; extent[1] is
  // meaningful only in the matrix*matrix case.
  SubscriptValue extent[2]{
      xRank == 2 ? x.GetDimension(0).Extent() : y.GetDimension(1).Extent(),
      resRank == 2 ? y.GetDimension(1).Extent() : 0};
  if constexpr (IS_ALLOCATING) {
    result.Establish(
        RCAT, RKIND, nullptr, resRank, extent, CFI_attribute_allocatable);
    for (int j{0}; j < resRank; ++j) {
      result.GetDimension(j).SetBounds(1, extent[j]);
    }
    if (int stat{result.Allocate()}) {
      terminator.Crash(
          "MATMUL: could not allocate memory for result; STAT=%d", stat);
    }
  } else {
    // Validate the caller-supplied result descriptor.
    RUNTIME_CHECK(terminator, resRank == result.rank());
    // NOTE(review): for COMPLEX results ElementBytes() would be 2*RKIND,
    // which this check would reject -- confirm against callers' usage.
    RUNTIME_CHECK(
        terminator, result.ElementBytes() == static_cast<std::size_t>(RKIND));
    RUNTIME_CHECK(terminator, result.GetDimension(0).Extent() == extent[0]);
    RUNTIME_CHECK(terminator,
        resRank == 1 || result.GetDimension(1).Extent() == extent[1]);
  }
  // n is the summed dimension: X's last extent must match Y's first.
  SubscriptValue n{x.GetDimension(xRank - 1).Extent()};
  if (n != y.GetDimension(0).Extent()) {
    terminator.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jdx%jd)",
        static_cast<std::intmax_t>(x.GetDimension(0).Extent()),
        static_cast<std::intmax_t>(n),
        static_cast<std::intmax_t>(y.GetDimension(0).Extent()),
        static_cast<std::intmax_t>(y.GetDimension(1).Extent()));
  }
  // LOGICAL results are accumulated and stored through the matching
  // integer type of the same kind.
  using WriteResult =
      CppTypeFor<RCAT == TypeCategory::Logical ? TypeCategory::Integer : RCAT,
          RKIND>;
  if constexpr (RCAT != TypeCategory::Logical) {
    // Fast paths require both operands (and, for the non-allocating
    // form, the result) to be contiguous at least in the leading
    // dimension.
    if (x.IsContiguous(1) && y.IsContiguous(1) &&
        (IS_ALLOCATING || result.IsContiguous())) {
      // Contiguous numeric matrices (maybe with columns
      // separated by a stride).
      std::optional<std::size_t> xColumnByteStride;
      if (!x.IsContiguous()) {
        // X's columns are strided.
        SubscriptValue xAt[2]{};
        x.GetLowerBounds(xAt);
        xAt[1]++;
        // Byte distance between consecutive columns of X.
        xColumnByteStride = x.SubscriptsToByteOffset(xAt);
      }
      std::optional<std::size_t> yColumnByteStride;
      if (!y.IsContiguous()) {
        // Y's columns are strided.
        SubscriptValue yAt[2]{};
        y.GetLowerBounds(yAt);
        yAt[1]++;
        // Byte distance between consecutive columns of Y.
        yColumnByteStride = y.SubscriptsToByteOffset(yAt);
      }
      // Note that BLAS GEMM can be used for the strided
      // columns by setting proper leading dimension size.
      // This implies that the column stride is divisible
      // by the element size, which is usually true.
      if (resRank == 2) { // M*M -> M
        if (std::is_same_v<XT, YT>) {
          if constexpr (std::is_same_v<XT, float>) {
            // TODO: call BLAS-3 SGEMM
            // TODO: try using CUTLASS for device.
          } else if constexpr (std::is_same_v<XT, double>) {
            // TODO: call BLAS-3 DGEMM
          } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
            // TODO: call BLAS-3 CGEMM
          } else if constexpr (std::is_same_v<XT, std::complex<double>>) {
            // TODO: call BLAS-3 ZGEMM
          }
        }
        MatrixTimesMatrixHelper<RCAT, RKIND, XT, YT>(
            result.template OffsetElement<WriteResult>(), extent[0], extent[1],
            x.OffsetElement<XT>(), y.OffsetElement<YT>(), n, xColumnByteStride,
            yColumnByteStride);
        return;
      } else if (xRank == 2) { // M*V -> V
        if (std::is_same_v<XT, YT>) {
          if constexpr (std::is_same_v<XT, float>) {
            // TODO: call BLAS-2 SGEMV(x,y)
          } else if constexpr (std::is_same_v<XT, double>) {
            // TODO: call BLAS-2 DGEMV(x,y)
          } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
            // TODO: call BLAS-2 CGEMV(x,y)
          } else if constexpr (std::is_same_v<XT, std::complex<double>>) {
            // TODO: call BLAS-2 ZGEMV(x,y)
          }
        }
        MatrixTimesVectorHelper<RCAT, RKIND, XT, YT>(
            result.template OffsetElement<WriteResult>(), extent[0], n,
            x.OffsetElement<XT>(), y.OffsetElement<YT>(), xColumnByteStride);
        return;
      } else { // V*M -> V
        if (std::is_same_v<XT, YT>) {
          if constexpr (std::is_same_v<XT, float>) {
            // TODO: call BLAS-2 SGEMV(y,x)
          } else if constexpr (std::is_same_v<XT, double>) {
            // TODO: call BLAS-2 DGEMV(y,x)
          } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
            // TODO: call BLAS-2 CGEMV(y,x)
          } else if constexpr (std::is_same_v<XT, std::complex<double>>) {
            // TODO: call BLAS-2 ZGEMV(y,x)
          }
        }
        VectorTimesMatrixHelper<RCAT, RKIND, XT, YT>(
            result.template OffsetElement<WriteResult>(), n, extent[0],
            x.OffsetElement<XT>(), y.OffsetElement<YT>(), yColumnByteStride);
        return;
      }
    }
  }
  // General algorithms for LOGICAL and noncontiguity
  // Subscript vectors walk the operands and result element by element.
  SubscriptValue xAt[2], yAt[2], resAt[2];
  x.GetLowerBounds(xAt);
  y.GetLowerBounds(yAt);
  result.GetLowerBounds(resAt);
  if (resRank == 2) { // M*M -> M
    // Cache the lower bounds that the k/j loops repeatedly offset from.
    SubscriptValue x1{xAt[1]}, y0{yAt[0]}, y1{yAt[1]}, res1{resAt[1]};
    for (SubscriptValue i{0}; i < extent[0]; ++i) {
      for (SubscriptValue j{0}; j < extent[1]; ++j) {
        Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
        yAt[1] = y1 + j;
        // RES(i,j) = sum over k of X(i,k)*Y(k,j)
        for (SubscriptValue k{0}; k < n; ++k) {
          xAt[1] = x1 + k;
          yAt[0] = y0 + k;
          accumulator.Accumulate(xAt, yAt);
        }
        resAt[1] = res1 + j;
        *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
      }
      ++resAt[0];
      ++xAt[0];
    }
  } else if (xRank == 2) { // M*V -> V
    SubscriptValue x1{xAt[1]}, y0{yAt[0]};
    for (SubscriptValue j{0}; j < extent[0]; ++j) {
      Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
      // RES(j) = sum over k of X(j,k)*Y(k)
      for (SubscriptValue k{0}; k < n; ++k) {
        xAt[1] = x1 + k;
        yAt[0] = y0 + k;
        accumulator.Accumulate(xAt, yAt);
      }
      *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
      ++resAt[0];
      ++xAt[0];
    }
  } else { // V*M -> V
    SubscriptValue x0{xAt[0]}, y0{yAt[0]};
    for (SubscriptValue j{0}; j < extent[0]; ++j) {
      Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
      // RES(j) = sum over k of X(k)*Y(k,j)
      for (SubscriptValue k{0}; k < n; ++k) {
        xAt[0] = x0 + k;
        yAt[0] = y0 + k;
        accumulator.Accumulate(xAt, yAt);
      }
      *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
      ++resAt[0];
      ++yAt[1]; // advance to Y's next column
    }
  }
}
427 | |
428 | RT_DIAG_POP |
429 | |
// Maps the dynamic type information from the arguments' descriptors
// to the right instantiation of DoMatmul() for valid combinations of
// types.
template <bool IS_ALLOCATING> struct Matmul {
  using ResultDescriptor =
      std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor>;
  // MM1 fixes X's category/kind; its nested MM2 fixes Y's.
  template <TypeCategory XCAT, int XKIND> struct MM1 {
    template <TypeCategory YCAT, int YKIND> struct MM2 {
      // Innermost dispatch: both operand types are compile-time known.
      RT_API_ATTRS void operator()(ResultDescriptor &result,
          const Descriptor &x, const Descriptor &y,
          Terminator &terminator) const {
        // GetResultType() yields the result's category/kind for this
        // operand pair, or nothing for an invalid combination.
        if constexpr (constexpr auto resultType{
                          GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
          if constexpr (common::IsNumericTypeCategory(resultType->first) ||
              resultType->first == TypeCategory::Logical) {
            return DoMatmul<IS_ALLOCATING, resultType->first,
                resultType->second, CppTypeFor<XCAT, XKIND>,
                CppTypeFor<YCAT, YKIND>>(result, x, y, terminator);
          }
        }
        // Reached only for invalid/unsupported operand type pairs.
        terminator.Crash("MATMUL: bad operand types (%d(%d), %d(%d))",
            static_cast<int>(XCAT), XKIND, static_cast<int>(YCAT), YKIND);
      }
    };
    // Resolves Y's dynamic type and forwards to MM2.
    RT_API_ATTRS void operator()(ResultDescriptor &result, const Descriptor &x,
        const Descriptor &y, Terminator &terminator, TypeCategory yCat,
        int yKind) const {
      ApplyType<MM2, void>(yCat, yKind, terminator, result, x, y, terminator);
    }
  };
  // Entry point: resolves X's dynamic type, then chains through MM1/MM2.
  RT_API_ATTRS void operator()(ResultDescriptor &result, const Descriptor &x,
      const Descriptor &y, const char *sourceFile, int line) const {
    Terminator terminator{sourceFile, line};
    auto xCatKind{x.type().GetCategoryAndKind()};
    auto yCatKind{y.type().GetCategoryAndKind()};
    RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
    ApplyType<MM1, void>(xCatKind->first, xCatKind->second, terminator, result,
        x, y, terminator, yCatKind->first, yCatKind->second);
  }
};
470 | |
471 | extern "C" { |
472 | RT_EXT_API_GROUP_BEGIN |
473 | |
474 | void RTDEF(Matmul)(Descriptor &result, const Descriptor &x, const Descriptor &y, |
475 | const char *sourceFile, int line) { |
476 | Matmul<true>{}(result, x, y, sourceFile, line); |
477 | } |
478 | void RTDEF(MatmulDirect)(const Descriptor &result, const Descriptor &x, |
479 | const Descriptor &y, const char *sourceFile, int line) { |
480 | Matmul<false>{}(result, x, y, sourceFile, line); |
481 | } |
482 | |
483 | RT_EXT_API_GROUP_END |
484 | } // extern "C" |
485 | } // namespace Fortran::runtime |
486 | |