1 | //===-- lib/Evaluate/fold-matmul.h ----------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef FORTRAN_EVALUATE_FOLD_MATMUL_H_ |
10 | #define FORTRAN_EVALUATE_FOLD_MATMUL_H_ |
11 | |
12 | #include "fold-implementation.h" |
13 | |
14 | namespace Fortran::evaluate { |
15 | |
16 | template <typename T> |
17 | static Expr<T> FoldMatmul(FoldingContext &context, FunctionRef<T> &&funcRef) { |
18 | using Element = typename Constant<T>::Element; |
19 | auto args{funcRef.arguments()}; |
20 | CHECK(args.size() == 2); |
21 | Folder<T> folder{context}; |
22 | Constant<T> *ma{folder.Folding(args[0])}; |
23 | Constant<T> *mb{folder.Folding(args[1])}; |
24 | if (!ma || !mb) { |
25 | return Expr<T>{std::move(funcRef)}; |
26 | } |
27 | CHECK(ma->Rank() >= 1 && ma->Rank() <= 2 && mb->Rank() >= 1 && |
28 | mb->Rank() <= 2 && (ma->Rank() == 2 || mb->Rank() == 2)); |
29 | ConstantSubscript commonExtent{ma->shape().back()}; |
30 | if (mb->shape().front() != commonExtent) { |
31 | context.messages().Say( |
32 | "Arguments to MATMUL have distinct extents %zd and %zd on their last and first dimensions"_err_en_US , |
33 | commonExtent, mb->shape().front()); |
34 | return MakeInvalidIntrinsic(std::move(funcRef)); |
35 | } |
36 | ConstantSubscript rows{ma->Rank() == 1 ? 1 : ma->shape()[0]}; |
37 | ConstantSubscript columns{mb->Rank() == 1 ? 1 : mb->shape()[1]}; |
38 | std::vector<Element> elements; |
39 | elements.reserve(rows * columns); |
40 | bool overflow{false}; |
41 | [[maybe_unused]] const auto &rounding{ |
42 | context.targetCharacteristics().roundingMode()}; |
43 | // result(j,k) = SUM(A(j,:) * B(:,k)) |
44 | for (ConstantSubscript ci{0}; ci < columns; ++ci) { |
45 | for (ConstantSubscript ri{0}; ri < rows; ++ri) { |
46 | ConstantSubscripts aAt{ma->lbounds()}; |
47 | if (ma->Rank() == 2) { |
48 | aAt[0] += ri; |
49 | } |
50 | ConstantSubscripts bAt{mb->lbounds()}; |
51 | if (mb->Rank() == 2) { |
52 | bAt[1] += ci; |
53 | } |
54 | Element sum{}; |
55 | [[maybe_unused]] Element correction{}; |
56 | for (ConstantSubscript j{0}; j < commonExtent; ++j) { |
57 | Element aElt{ma->At(aAt)}; |
58 | Element bElt{mb->At(bAt)}; |
59 | if constexpr (T::category == TypeCategory::Real || |
60 | T::category == TypeCategory::Complex) { |
61 | // Kahan summation |
62 | auto product{aElt.Multiply(bElt, rounding)}; |
63 | overflow |= product.flags.test(RealFlag::Overflow); |
64 | auto next{correction.Add(product.value, rounding)}; |
65 | overflow |= next.flags.test(RealFlag::Overflow); |
66 | auto added{sum.Add(next.value, rounding)}; |
67 | overflow |= added.flags.test(RealFlag::Overflow); |
68 | correction = added.value.Subtract(sum, rounding) |
69 | .value.Subtract(next.value, rounding) |
70 | .value; |
71 | sum = std::move(added.value); |
72 | } else if constexpr (T::category == TypeCategory::Integer) { |
73 | auto product{aElt.MultiplySigned(bElt)}; |
74 | overflow |= product.SignedMultiplicationOverflowed(); |
75 | auto added{sum.AddSigned(product.lower)}; |
76 | overflow |= added.overflow; |
77 | sum = std::move(added.value); |
78 | } else { |
79 | static_assert(T::category == TypeCategory::Logical); |
80 | sum = sum.OR(aElt.AND(bElt)); |
81 | } |
82 | ++aAt.back(); |
83 | ++bAt.front(); |
84 | } |
85 | elements.push_back(sum); |
86 | } |
87 | } |
88 | if (overflow) { |
89 | context.messages().Say( |
90 | "MATMUL of %s data overflowed during computation"_warn_en_US , |
91 | T::AsFortran()); |
92 | } |
93 | ConstantSubscripts shape; |
94 | if (ma->Rank() == 2) { |
95 | shape.push_back(rows); |
96 | } |
97 | if (mb->Rank() == 2) { |
98 | shape.push_back(columns); |
99 | } |
100 | return Expr<T>{Constant<T>{std::move(elements), std::move(shape)}}; |
101 | } |
102 | } // namespace Fortran::evaluate |
103 | #endif // FORTRAN_EVALUATE_FOLD_MATMUL_H_ |
104 | |