1 | //===-- runtime/matmul-transpose.cpp --------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | // Implements a fused matmul-transpose operation |
10 | // |
11 | // There are two main entry points; one establishes a descriptor for the |
12 | // result and allocates it, and the other expects a result descriptor that |
13 | // points to existing storage. |
14 | // |
15 | // This implementation must handle all combinations of numeric types and |
16 | // kinds (100 - 165 cases depending on the target), plus all combinations |
17 | // of logical kinds (16). A single template undergoes many instantiations |
18 | // to cover all of the valid possibilities. |
19 | // |
20 | // The usefulness of this optimization should be reviewed once Matmul is swapped |
21 | // to use the faster BLAS routines. |
22 | |
23 | #include "flang/Runtime/matmul-transpose.h" |
24 | #include "terminator.h" |
25 | #include "tools.h" |
26 | #include "flang/Common/optional.h" |
27 | #include "flang/Runtime/c-or-cpp.h" |
28 | #include "flang/Runtime/cpp-type.h" |
29 | #include "flang/Runtime/descriptor.h" |
30 | #include <cstring> |
31 | |
32 | namespace { |
33 | using namespace Fortran::runtime; |
34 | |
35 | // Suppress the warnings about calling __host__-only std::complex operators, |
36 | // defined in C++ STD header files, from __device__ code. |
37 | RT_DIAG_PUSH |
38 | RT_DIAG_DISABLE_CALL_HOST_FROM_DEVICE_WARN |
39 | |
40 | // Contiguous numeric TRANSPOSE(matrix)*matrix multiplication |
41 | // TRANSPOSE(matrix(n, rows)) * matrix(n,cols) -> |
42 | // matrix(rows, n) * matrix(n,cols) -> matrix(rows,cols) |
43 | // The transpose is implemented by swapping the indices of accesses into the LHS |
44 | // |
45 | // Straightforward algorithm: |
46 | // DO 1 I = 1, NROWS |
47 | // DO 1 J = 1, NCOLS |
48 | // RES(I,J) = 0 |
49 | // DO 1 K = 1, N |
50 | // 1 RES(I,J) = RES(I,J) + X(K,I)*Y(K,J) |
51 | // |
52 | // With loop distribution and transposition to avoid the inner sum |
53 | // reduction and to avoid non-unit strides: |
54 | // DO 1 I = 1, NROWS |
55 | // DO 1 J = 1, NCOLS |
56 | // 1 RES(I,J) = 0 |
57 | // DO 2 J = 1, NCOLS |
58 | // DO 2 I = 1, NROWS |
59 | // DO 2 K = 1, N |
60 | // 2 RES(I,J) = RES(I,J) + X(K,I)*Y(K,J) ! loop-invariant last term |
61 | template <TypeCategory RCAT, int RKIND, typename XT, typename YT, |
62 | bool X_HAS_STRIDED_COLUMNS, bool Y_HAS_STRIDED_COLUMNS> |
63 | inline static RT_API_ATTRS void MatrixTransposedTimesMatrix( |
64 | CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows, |
65 | SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y, |
66 | SubscriptValue n, std::size_t xColumnByteStride = 0, |
67 | std::size_t yColumnByteStride = 0) { |
68 | using ResultType = CppTypeFor<RCAT, RKIND>; |
69 | |
70 | std::memset(product, 0, rows * cols * sizeof *product); |
71 | for (SubscriptValue j{0}; j < cols; ++j) { |
72 | for (SubscriptValue i{0}; i < rows; ++i) { |
73 | for (SubscriptValue k{0}; k < n; ++k) { |
74 | ResultType x_ki; |
75 | if constexpr (!X_HAS_STRIDED_COLUMNS) { |
76 | x_ki = static_cast<ResultType>(x[i * n + k]); |
77 | } else { |
78 | x_ki = static_cast<ResultType>(reinterpret_cast<const XT *>( |
79 | reinterpret_cast<const char *>(x) + i * xColumnByteStride)[k]); |
80 | } |
81 | ResultType y_kj; |
82 | if constexpr (!Y_HAS_STRIDED_COLUMNS) { |
83 | y_kj = static_cast<ResultType>(y[j * n + k]); |
84 | } else { |
85 | y_kj = static_cast<ResultType>(reinterpret_cast<const YT *>( |
86 | reinterpret_cast<const char *>(y) + j * yColumnByteStride)[k]); |
87 | } |
88 | product[j * rows + i] += x_ki * y_kj; |
89 | } |
90 | } |
91 | } |
92 | } |
93 | |
94 | RT_DIAG_POP |
95 | |
96 | template <TypeCategory RCAT, int RKIND, typename XT, typename YT> |
97 | inline static RT_API_ATTRS void MatrixTransposedTimesMatrixHelper( |
98 | CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows, |
99 | SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y, |
100 | SubscriptValue n, Fortran::common::optional<std::size_t> xColumnByteStride, |
101 | Fortran::common::optional<std::size_t> yColumnByteStride) { |
102 | if (!xColumnByteStride) { |
103 | if (!yColumnByteStride) { |
104 | MatrixTransposedTimesMatrix<RCAT, RKIND, XT, YT, false, false>( |
105 | product, rows, cols, x, y, n); |
106 | } else { |
107 | MatrixTransposedTimesMatrix<RCAT, RKIND, XT, YT, false, true>( |
108 | product, rows, cols, x, y, n, 0, *yColumnByteStride); |
109 | } |
110 | } else { |
111 | if (!yColumnByteStride) { |
112 | MatrixTransposedTimesMatrix<RCAT, RKIND, XT, YT, true, false>( |
113 | product, rows, cols, x, y, n, *xColumnByteStride); |
114 | } else { |
115 | MatrixTransposedTimesMatrix<RCAT, RKIND, XT, YT, true, true>( |
116 | product, rows, cols, x, y, n, *xColumnByteStride, *yColumnByteStride); |
117 | } |
118 | } |
119 | } |
120 | |
121 | RT_DIAG_PUSH |
122 | RT_DIAG_DISABLE_CALL_HOST_FROM_DEVICE_WARN |
123 | |
124 | // Contiguous numeric matrix*vector multiplication |
125 | // matrix(rows,n) * column vector(n) -> column vector(rows) |
126 | // Straightforward algorithm: |
127 | // DO 1 I = 1, NROWS |
128 | // RES(I) = 0 |
129 | // DO 1 K = 1, N |
130 | // 1 RES(I) = RES(I) + X(K,I)*Y(K) |
131 | // With loop distribution and transposition to avoid the inner |
132 | // sum reduction and to avoid non-unit strides: |
133 | // DO 1 I = 1, NROWS |
134 | // 1 RES(I) = 0 |
135 | // DO 2 I = 1, NROWS |
136 | // DO 2 K = 1, N |
137 | // 2 RES(I) = RES(I) + X(K,I)*Y(K) |
138 | template <TypeCategory RCAT, int RKIND, typename XT, typename YT, |
139 | bool X_HAS_STRIDED_COLUMNS> |
140 | inline static RT_API_ATTRS void MatrixTransposedTimesVector( |
141 | CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows, |
142 | SubscriptValue n, const XT *RESTRICT x, const YT *RESTRICT y, |
143 | std::size_t xColumnByteStride = 0) { |
144 | using ResultType = CppTypeFor<RCAT, RKIND>; |
145 | std::memset(product, 0, rows * sizeof *product); |
146 | for (SubscriptValue i{0}; i < rows; ++i) { |
147 | for (SubscriptValue k{0}; k < n; ++k) { |
148 | ResultType x_ki; |
149 | if constexpr (!X_HAS_STRIDED_COLUMNS) { |
150 | x_ki = static_cast<ResultType>(x[i * n + k]); |
151 | } else { |
152 | x_ki = static_cast<ResultType>(reinterpret_cast<const XT *>( |
153 | reinterpret_cast<const char *>(x) + i * xColumnByteStride)[k]); |
154 | } |
155 | ResultType y_k = static_cast<ResultType>(y[k]); |
156 | product[i] += x_ki * y_k; |
157 | } |
158 | } |
159 | } |
160 | |
161 | RT_DIAG_POP |
162 | |
163 | template <TypeCategory RCAT, int RKIND, typename XT, typename YT> |
164 | inline static RT_API_ATTRS void MatrixTransposedTimesVectorHelper( |
165 | CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows, |
166 | SubscriptValue n, const XT *RESTRICT x, const YT *RESTRICT y, |
167 | Fortran::common::optional<std::size_t> xColumnByteStride) { |
168 | if (!xColumnByteStride) { |
169 | MatrixTransposedTimesVector<RCAT, RKIND, XT, YT, false>( |
170 | product, rows, n, x, y); |
171 | } else { |
172 | MatrixTransposedTimesVector<RCAT, RKIND, XT, YT, true>( |
173 | product, rows, n, x, y, *xColumnByteStride); |
174 | } |
175 | } |
176 | |
177 | RT_DIAG_PUSH |
178 | RT_DIAG_DISABLE_CALL_HOST_FROM_DEVICE_WARN |
179 | |
180 | // Implements an instance of MATMUL for given argument types. |
181 | template <bool IS_ALLOCATING, TypeCategory RCAT, int RKIND, typename XT, |
182 | typename YT> |
183 | inline static RT_API_ATTRS void DoMatmulTranspose( |
184 | std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor> &result, |
185 | const Descriptor &x, const Descriptor &y, Terminator &terminator) { |
186 | int xRank{x.rank()}; |
187 | int yRank{y.rank()}; |
188 | int resRank{xRank + yRank - 2}; |
189 | if (xRank * yRank != 2 * resRank) { |
190 | terminator.Crash( |
191 | "MATMUL-TRANSPOSE: bad argument ranks (%d * %d)" , xRank, yRank); |
192 | } |
193 | SubscriptValue extent[2]{x.GetDimension(1).Extent(), |
194 | resRank == 2 ? y.GetDimension(1).Extent() : 0}; |
195 | if constexpr (IS_ALLOCATING) { |
196 | result.Establish( |
197 | RCAT, RKIND, nullptr, resRank, extent, CFI_attribute_allocatable); |
198 | for (int j{0}; j < resRank; ++j) { |
199 | result.GetDimension(j).SetBounds(1, extent[j]); |
200 | } |
201 | if (int stat{result.Allocate()}) { |
202 | terminator.Crash( |
203 | "MATMUL-TRANSPOSE: could not allocate memory for result; STAT=%d" , |
204 | stat); |
205 | } |
206 | } else { |
207 | RUNTIME_CHECK(terminator, resRank == result.rank()); |
208 | RUNTIME_CHECK( |
209 | terminator, result.ElementBytes() == static_cast<std::size_t>(RKIND)); |
210 | RUNTIME_CHECK(terminator, result.GetDimension(0).Extent() == extent[0]); |
211 | RUNTIME_CHECK(terminator, |
212 | resRank == 1 || result.GetDimension(1).Extent() == extent[1]); |
213 | } |
214 | SubscriptValue n{x.GetDimension(0).Extent()}; |
215 | if (n != y.GetDimension(0).Extent()) { |
216 | terminator.Crash( |
217 | "MATMUL-TRANSPOSE: unacceptable operand shapes (%jdx%jd, %jdx%jd)" , |
218 | static_cast<std::intmax_t>(x.GetDimension(0).Extent()), |
219 | static_cast<std::intmax_t>(x.GetDimension(1).Extent()), |
220 | static_cast<std::intmax_t>(y.GetDimension(0).Extent()), |
221 | static_cast<std::intmax_t>(y.GetDimension(1).Extent())); |
222 | } |
223 | using WriteResult = |
224 | CppTypeFor<RCAT == TypeCategory::Logical ? TypeCategory::Integer : RCAT, |
225 | RKIND>; |
226 | const SubscriptValue rows{extent[0]}; |
227 | const SubscriptValue cols{extent[1]}; |
228 | if constexpr (RCAT != TypeCategory::Logical) { |
229 | if (x.IsContiguous(1) && y.IsContiguous(1) && |
230 | (IS_ALLOCATING || result.IsContiguous())) { |
231 | // Contiguous numeric matrices (maybe with columns |
232 | // separated by a stride). |
233 | Fortran::common::optional<std::size_t> xColumnByteStride; |
234 | if (!x.IsContiguous()) { |
235 | // X's columns are strided. |
236 | SubscriptValue xAt[2]{}; |
237 | x.GetLowerBounds(xAt); |
238 | xAt[1]++; |
239 | xColumnByteStride = x.SubscriptsToByteOffset(xAt); |
240 | } |
241 | Fortran::common::optional<std::size_t> yColumnByteStride; |
242 | if (!y.IsContiguous()) { |
243 | // Y's columns are strided. |
244 | SubscriptValue yAt[2]{}; |
245 | y.GetLowerBounds(yAt); |
246 | yAt[1]++; |
247 | yColumnByteStride = y.SubscriptsToByteOffset(yAt); |
248 | } |
249 | if (resRank == 2) { // M*M -> M |
250 | // TODO: use BLAS-3 GEMM for supported types. |
251 | MatrixTransposedTimesMatrixHelper<RCAT, RKIND, XT, YT>( |
252 | result.template OffsetElement<WriteResult>(), rows, cols, |
253 | x.OffsetElement<XT>(), y.OffsetElement<YT>(), n, xColumnByteStride, |
254 | yColumnByteStride); |
255 | return; |
256 | } |
257 | if (xRank == 2) { // M*V -> V |
258 | // TODO: use BLAS-2 GEMM for supported types. |
259 | MatrixTransposedTimesVectorHelper<RCAT, RKIND, XT, YT>( |
260 | result.template OffsetElement<WriteResult>(), rows, n, |
261 | x.OffsetElement<XT>(), y.OffsetElement<YT>(), xColumnByteStride); |
262 | return; |
263 | } |
264 | // else V*M -> V (not allowed because TRANSPOSE() is only defined for rank |
265 | // 1 matrices |
266 | terminator.Crash( |
267 | "MATMUL-TRANSPOSE: unacceptable operand shapes (%jdx%jd, %jdx%jd)" , |
268 | static_cast<std::intmax_t>(x.GetDimension(0).Extent()), |
269 | static_cast<std::intmax_t>(n), |
270 | static_cast<std::intmax_t>(y.GetDimension(0).Extent()), |
271 | static_cast<std::intmax_t>(y.GetDimension(1).Extent())); |
272 | return; |
273 | } |
274 | } |
275 | // General algorithms for LOGICAL and noncontiguity |
276 | SubscriptValue xLB[2], yLB[2], resLB[2]; |
277 | x.GetLowerBounds(xLB); |
278 | y.GetLowerBounds(yLB); |
279 | result.GetLowerBounds(resLB); |
280 | using ResultType = CppTypeFor<RCAT, RKIND>; |
281 | if (resRank == 2) { // M*M -> M |
282 | for (SubscriptValue i{0}; i < rows; ++i) { |
283 | for (SubscriptValue j{0}; j < cols; ++j) { |
284 | ResultType res_ij; |
285 | if constexpr (RCAT == TypeCategory::Logical) { |
286 | res_ij = false; |
287 | } else { |
288 | res_ij = 0; |
289 | } |
290 | |
291 | for (SubscriptValue k{0}; k < n; ++k) { |
292 | SubscriptValue xAt[2]{k + xLB[0], i + xLB[1]}; |
293 | SubscriptValue yAt[2]{k + yLB[0], j + yLB[1]}; |
294 | if constexpr (RCAT == TypeCategory::Logical) { |
295 | ResultType x_ki = IsLogicalElementTrue(x, xAt); |
296 | ResultType y_kj = IsLogicalElementTrue(y, yAt); |
297 | res_ij = res_ij || (x_ki && y_kj); |
298 | } else { |
299 | ResultType x_ki = static_cast<ResultType>(*x.Element<XT>(xAt)); |
300 | ResultType y_kj = static_cast<ResultType>(*y.Element<YT>(yAt)); |
301 | res_ij += x_ki * y_kj; |
302 | } |
303 | } |
304 | SubscriptValue resAt[2]{i + resLB[0], j + resLB[1]}; |
305 | *result.template Element<WriteResult>(resAt) = res_ij; |
306 | } |
307 | } |
308 | } else if (xRank == 2) { // M*V -> V |
309 | for (SubscriptValue i{0}; i < rows; ++i) { |
310 | ResultType res_i; |
311 | if constexpr (RCAT == TypeCategory::Logical) { |
312 | res_i = false; |
313 | } else { |
314 | res_i = 0; |
315 | } |
316 | |
317 | for (SubscriptValue k{0}; k < n; ++k) { |
318 | SubscriptValue xAt[2]{k + xLB[0], i + xLB[1]}; |
319 | SubscriptValue yAt[1]{k + yLB[0]}; |
320 | if constexpr (RCAT == TypeCategory::Logical) { |
321 | ResultType x_ki = IsLogicalElementTrue(x, xAt); |
322 | ResultType y_k = IsLogicalElementTrue(y, yAt); |
323 | res_i = res_i || (x_ki && y_k); |
324 | } else { |
325 | ResultType x_ki = static_cast<ResultType>(*x.Element<XT>(xAt)); |
326 | ResultType y_k = static_cast<ResultType>(*y.Element<YT>(yAt)); |
327 | res_i += x_ki * y_k; |
328 | } |
329 | } |
330 | SubscriptValue resAt[1]{i + resLB[0]}; |
331 | *result.template Element<WriteResult>(resAt) = res_i; |
332 | } |
333 | } else { // V*M -> V |
334 | // TRANSPOSE(V) not allowed by fortran standard |
335 | terminator.Crash( |
336 | "MATMUL-TRANSPOSE: unacceptable operand shapes (%jdx%jd, %jdx%jd)" , |
337 | static_cast<std::intmax_t>(x.GetDimension(0).Extent()), |
338 | static_cast<std::intmax_t>(n), |
339 | static_cast<std::intmax_t>(y.GetDimension(0).Extent()), |
340 | static_cast<std::intmax_t>(y.GetDimension(1).Extent())); |
341 | } |
342 | } |
343 | |
344 | RT_DIAG_POP |
345 | |
346 | // Maps the dynamic type information from the arguments' descriptors |
347 | // to the right instantiation of DoMatmul() for valid combinations of |
348 | // types. |
349 | template <bool IS_ALLOCATING> struct MatmulTranspose { |
350 | using ResultDescriptor = |
351 | std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor>; |
352 | template <TypeCategory XCAT, int XKIND> struct MM1 { |
353 | template <TypeCategory YCAT, int YKIND> struct MM2 { |
354 | RT_API_ATTRS void operator()(ResultDescriptor &result, |
355 | const Descriptor &x, const Descriptor &y, |
356 | Terminator &terminator) const { |
357 | if constexpr (constexpr auto resultType{ |
358 | GetResultType(XCAT, XKIND, YCAT, YKIND)}) { |
359 | if constexpr (Fortran::common::IsNumericTypeCategory( |
360 | resultType->first) || |
361 | resultType->first == TypeCategory::Logical) { |
362 | return DoMatmulTranspose<IS_ALLOCATING, resultType->first, |
363 | resultType->second, CppTypeFor<XCAT, XKIND>, |
364 | CppTypeFor<YCAT, YKIND>>(result, x, y, terminator); |
365 | } |
366 | } |
367 | terminator.Crash("MATMUL-TRANSPOSE: bad operand types (%d(%d), %d(%d))" , |
368 | static_cast<int>(XCAT), XKIND, static_cast<int>(YCAT), YKIND); |
369 | } |
370 | }; |
371 | RT_API_ATTRS void operator()(ResultDescriptor &result, const Descriptor &x, |
372 | const Descriptor &y, Terminator &terminator, TypeCategory yCat, |
373 | int yKind) const { |
374 | ApplyType<MM2, void>(yCat, yKind, terminator, result, x, y, terminator); |
375 | } |
376 | }; |
377 | RT_API_ATTRS void operator()(ResultDescriptor &result, const Descriptor &x, |
378 | const Descriptor &y, const char *sourceFile, int line) const { |
379 | Terminator terminator{sourceFile, line}; |
380 | auto xCatKind{x.type().GetCategoryAndKind()}; |
381 | auto yCatKind{y.type().GetCategoryAndKind()}; |
382 | RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value()); |
383 | ApplyType<MM1, void>(xCatKind->first, xCatKind->second, terminator, result, |
384 | x, y, terminator, yCatKind->first, yCatKind->second); |
385 | } |
386 | }; |
387 | } // namespace |
388 | |
389 | namespace Fortran::runtime { |
390 | extern "C" { |
391 | RT_EXT_API_GROUP_BEGIN |
392 | |
393 | void RTDEF(MatmulTranspose)(Descriptor &result, const Descriptor &x, |
394 | const Descriptor &y, const char *sourceFile, int line) { |
395 | MatmulTranspose<true>{}(result, x, y, sourceFile, line); |
396 | } |
397 | void RTDEF(MatmulTransposeDirect)(const Descriptor &result, const Descriptor &x, |
398 | const Descriptor &y, const char *sourceFile, int line) { |
399 | MatmulTranspose<false>{}(result, x, y, sourceFile, line); |
400 | } |
401 | |
402 | RT_EXT_API_GROUP_END |
403 | } // extern "C" |
404 | } // namespace Fortran::runtime |
405 | |