//===-- flang/unittests/Runtime/Matmul.cpp-----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "flang/Runtime/matmul.h"
#include "gtest/gtest.h"
#include "tools.h"
#include "flang/Runtime/allocatable.h"
#include "flang/Runtime/cpp-type.h"
#include "flang/Runtime/descriptor.h"
#include "flang/Runtime/type-code.h"
#include <cstdint>
#include <cstring>
#include <vector>

using namespace Fortran::runtime;
using Fortran::common::TypeCategory;

20 | TEST(Matmul, Basic) { |
21 | // X 0 2 4 Y 6 9 V -1 -2 |
22 | // 1 3 5 7 10 |
23 | // 8 11 |
24 | auto x{MakeArray<TypeCategory::Integer, 4>( |
25 | std::vector<int>{2, 3}, std::vector<std::int32_t>{0, 1, 2, 3, 4, 5})}; |
26 | auto y{MakeArray<TypeCategory::Integer, 2>( |
27 | std::vector<int>{3, 2}, std::vector<std::int16_t>{6, 7, 8, 9, 10, 11})}; |
28 | auto v{MakeArray<TypeCategory::Integer, 8>( |
29 | std::vector<int>{2}, std::vector<std::int64_t>{-1, -2})}; |
30 | |
31 | // X2 0 2 4 Y2 -1 -1 |
32 | // 1 3 5 6 9 |
33 | // -1 -1 -1 7 10 |
34 | // 8 11 |
35 | auto x2{MakeArray<TypeCategory::Integer, 4>(std::vector<int>{3, 3}, |
36 | std::vector<std::int32_t>{0, 1, -1, 2, 3, -1, 4, 5})}; |
37 | auto y2{MakeArray<TypeCategory::Integer, 2>(std::vector<int>{4, 2}, |
38 | std::vector<std::int16_t>{-1, 6, 7, 8, -1, 9, 10, 11})}; |
39 | |
40 | StaticDescriptor<2, true> statDesc; |
41 | Descriptor &result{statDesc.descriptor()}; |
42 | |
43 | RTNAME(Matmul)(result, *x, *y, __FILE__, __LINE__); |
44 | ASSERT_EQ(result.rank(), 2); |
45 | EXPECT_EQ(result.GetDimension(0).LowerBound(), 1); |
46 | EXPECT_EQ(result.GetDimension(0).Extent(), 2); |
47 | EXPECT_EQ(result.GetDimension(1).LowerBound(), 1); |
48 | EXPECT_EQ(result.GetDimension(1).Extent(), 2); |
49 | ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4})); |
50 | EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46); |
51 | EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67); |
52 | EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64); |
53 | EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94); |
54 | |
55 | std::memset( |
56 | s: result.raw().base_addr, c: 0, n: result.Elements() * result.ElementBytes()); |
57 | result.GetDimension(0).SetLowerBound(0); |
58 | result.GetDimension(1).SetLowerBound(2); |
59 | RTNAME(MatmulDirect)(result, *x, *y, __FILE__, __LINE__); |
60 | EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46); |
61 | EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67); |
62 | EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64); |
63 | EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94); |
64 | result.Destroy(); |
65 | |
66 | RTNAME(Matmul)(result, *v, *x, __FILE__, __LINE__); |
67 | ASSERT_EQ(result.rank(), 1); |
68 | EXPECT_EQ(result.GetDimension(0).LowerBound(), 1); |
69 | EXPECT_EQ(result.GetDimension(0).Extent(), 3); |
70 | ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8})); |
71 | EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -2); |
72 | EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -8); |
73 | EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -14); |
74 | result.Destroy(); |
75 | |
76 | RTNAME(Matmul)(result, *y, *v, __FILE__, __LINE__); |
77 | ASSERT_EQ(result.rank(), 1); |
78 | EXPECT_EQ(result.GetDimension(0).LowerBound(), 1); |
79 | EXPECT_EQ(result.GetDimension(0).Extent(), 3); |
80 | ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8})); |
81 | EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -24); |
82 | EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -27); |
83 | EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -30); |
84 | result.Destroy(); |
85 | |
86 | // Test non-contiguous sections. |
87 | static constexpr int sectionRank{2}; |
88 | StaticDescriptor<sectionRank> sectionStaticDescriptorX2; |
89 | Descriptor §ionX2{sectionStaticDescriptorX2.descriptor()}; |
90 | sectionX2.Establish(x2->type(), x2->ElementBytes(), |
91 | /*p=*/nullptr, /*rank=*/sectionRank); |
92 | static const SubscriptValue lowersX2[]{1, 1}, uppersX2[]{2, 3}; |
93 | // Section of X2: |
94 | // +--------+ |
95 | // | 0 2 4| |
96 | // | 1 3 5| |
97 | // +--------+ |
98 | // -1 -1 -1 |
99 | const auto errorX2{CFI_section( |
100 | §ionX2.raw(), &x2->raw(), lowersX2, uppersX2, /*strides=*/nullptr)}; |
101 | ASSERT_EQ(errorX2, 0) << "CFI_section failed for X2: " << errorX2; |
102 | |
103 | StaticDescriptor<sectionRank> sectionStaticDescriptorY2; |
104 | Descriptor §ionY2{sectionStaticDescriptorY2.descriptor()}; |
105 | sectionY2.Establish(y2->type(), y2->ElementBytes(), |
106 | /*p=*/nullptr, /*rank=*/sectionRank); |
107 | static const SubscriptValue lowersY2[]{2, 1}; |
108 | // Section of Y2: |
109 | // -1 -1 |
110 | // +-----+ |
111 | // | 6 9| |
112 | // | 7 10| |
113 | // | 8 11| |
114 | // +-----+ |
115 | const auto errorY2{CFI_section(§ionY2.raw(), &y2->raw(), lowersY2, |
116 | /*uppers=*/nullptr, /*strides=*/nullptr)}; |
117 | ASSERT_EQ(errorY2, 0) << "CFI_section failed for Y2: " << errorY2; |
118 | |
119 | RTNAME(Matmul)(result, sectionX2, *y, __FILE__, __LINE__); |
120 | ASSERT_EQ(result.rank(), 2); |
121 | EXPECT_EQ(result.GetDimension(0).LowerBound(), 1); |
122 | EXPECT_EQ(result.GetDimension(0).Extent(), 2); |
123 | EXPECT_EQ(result.GetDimension(1).LowerBound(), 1); |
124 | EXPECT_EQ(result.GetDimension(1).Extent(), 2); |
125 | ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4})); |
126 | EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46); |
127 | EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67); |
128 | EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64); |
129 | EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94); |
130 | result.Destroy(); |
131 | |
132 | RTNAME(Matmul)(result, *x, sectionY2, __FILE__, __LINE__); |
133 | ASSERT_EQ(result.rank(), 2); |
134 | EXPECT_EQ(result.GetDimension(0).LowerBound(), 1); |
135 | EXPECT_EQ(result.GetDimension(0).Extent(), 2); |
136 | EXPECT_EQ(result.GetDimension(1).LowerBound(), 1); |
137 | EXPECT_EQ(result.GetDimension(1).Extent(), 2); |
138 | ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4})); |
139 | EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46); |
140 | EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67); |
141 | EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64); |
142 | EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94); |
143 | result.Destroy(); |
144 | |
145 | RTNAME(Matmul)(result, sectionX2, sectionY2, __FILE__, __LINE__); |
146 | ASSERT_EQ(result.rank(), 2); |
147 | EXPECT_EQ(result.GetDimension(0).LowerBound(), 1); |
148 | EXPECT_EQ(result.GetDimension(0).Extent(), 2); |
149 | EXPECT_EQ(result.GetDimension(1).LowerBound(), 1); |
150 | EXPECT_EQ(result.GetDimension(1).Extent(), 2); |
151 | ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4})); |
152 | EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46); |
153 | EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67); |
154 | EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64); |
155 | EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94); |
156 | result.Destroy(); |
157 | |
158 | RTNAME(Matmul)(result, *v, sectionX2, __FILE__, __LINE__); |
159 | ASSERT_EQ(result.rank(), 1); |
160 | EXPECT_EQ(result.GetDimension(0).LowerBound(), 1); |
161 | EXPECT_EQ(result.GetDimension(0).Extent(), 3); |
162 | ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8})); |
163 | EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -2); |
164 | EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -8); |
165 | EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -14); |
166 | result.Destroy(); |
167 | |
168 | RTNAME(Matmul)(result, sectionY2, *v, __FILE__, __LINE__); |
169 | ASSERT_EQ(result.rank(), 1); |
170 | EXPECT_EQ(result.GetDimension(0).LowerBound(), 1); |
171 | EXPECT_EQ(result.GetDimension(0).Extent(), 3); |
172 | ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8})); |
173 | EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -24); |
174 | EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -27); |
175 | EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -30); |
176 | result.Destroy(); |
177 | |
178 | // X F F T Y F T |
179 | // F T T F T |
180 | // F F |
181 | auto xLog{MakeArray<TypeCategory::Logical, 1>(std::vector<int>{2, 3}, |
182 | std::vector<std::uint8_t>{false, false, false, true, true, false})}; |
183 | auto yLog{MakeArray<TypeCategory::Logical, 2>(std::vector<int>{3, 2}, |
184 | std::vector<std::uint16_t>{false, false, false, true, true, false})}; |
185 | RTNAME(Matmul)(result, *xLog, *yLog, __FILE__, __LINE__); |
186 | ASSERT_EQ(result.rank(), 2); |
187 | EXPECT_EQ(result.GetDimension(0).LowerBound(), 1); |
188 | EXPECT_EQ(result.GetDimension(0).Extent(), 2); |
189 | EXPECT_EQ(result.GetDimension(1).LowerBound(), 1); |
190 | EXPECT_EQ(result.GetDimension(1).Extent(), 2); |
191 | ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Logical, 2})); |
192 | EXPECT_FALSE( |
193 | static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(0))); |
194 | EXPECT_FALSE( |
195 | static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(1))); |
196 | EXPECT_FALSE( |
197 | static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(2))); |
198 | EXPECT_TRUE( |
199 | static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(3))); |
200 | } |
201 | |