1//===-- flang/unittests/Runtime/MatmulTranspose.cpp -------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "gtest/gtest.h"
10#include "tools.h"
11#include "flang/Runtime/allocatable.h"
12#include "flang/Runtime/cpp-type.h"
13#include "flang/Runtime/descriptor.h"
14#include "flang/Runtime/matmul-transpose.h"
15#include "flang/Runtime/type-code.h"
16
17using namespace Fortran::runtime;
18using Fortran::common::TypeCategory;
19
20TEST(MatmulTranspose, Basic) {
21 // X 0 1 Y 6 9 Z 6 7 8 M 0 0 1 1 V -1 -2
22 // 2 3 7 10 9 10 11 0 1 0 1
23 // 4 5 8 11
24
25 auto x{MakeArray<TypeCategory::Integer, 4>(
26 std::vector<int>{3, 2}, std::vector<std::int32_t>{0, 2, 4, 1, 3, 5})};
27 auto y{MakeArray<TypeCategory::Integer, 2>(
28 std::vector<int>{3, 2}, std::vector<std::int16_t>{6, 7, 8, 9, 10, 11})};
29 auto z{MakeArray<TypeCategory::Integer, 2>(
30 std::vector<int>{2, 3}, std::vector<std::int16_t>{6, 9, 7, 10, 8, 11})};
31 auto m{MakeArray<TypeCategory::Integer, 2>(std::vector<int>{2, 4},
32 std::vector<std::int16_t>{0, 0, 0, 1, 1, 0, 1, 1})};
33 auto v{MakeArray<TypeCategory::Integer, 8>(
34 std::vector<int>{2}, std::vector<std::int64_t>{-1, -2})};
35 // X2 0 1 Y2 -1 -1 Z2 6 7 8
36 // 2 3 6 9 9 10 11
37 // 4 5 7 10 -1 -1 -1
38 // -1 -1 8 11
39 auto x2{MakeArray<TypeCategory::Integer, 4>(std::vector<int>{4, 2},
40 std::vector<std::int32_t>{0, 2, 4, -1, 1, 3, 5, -1})};
41 auto y2{MakeArray<TypeCategory::Integer, 2>(std::vector<int>{4, 2},
42 std::vector<std::int16_t>{-1, 6, 7, 8, -1, 9, 10, 11})};
43 auto z2{MakeArray<TypeCategory::Integer, 2>(std::vector<int>{3, 3},
44 std::vector<std::int16_t>{6, 9, -1, 7, 10, -1, 8, 11, -1})};
45
46 StaticDescriptor<2, true> statDesc;
47 Descriptor &result{statDesc.descriptor()};
48
49 RTNAME(MatmulTranspose)(result, *x, *y, __FILE__, __LINE__);
50 ASSERT_EQ(result.rank(), 2);
51 EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
52 EXPECT_EQ(result.GetDimension(0).Extent(), 2);
53 EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
54 EXPECT_EQ(result.GetDimension(1).Extent(), 2);
55 ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
56 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
57 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
58 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
59 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
60
61 std::memset(
62 s: result.raw().base_addr, c: 0, n: result.Elements() * result.ElementBytes());
63 result.GetDimension(0).SetLowerBound(0);
64 result.GetDimension(1).SetLowerBound(2);
65 RTNAME(MatmulTransposeDirect)(result, *x, *y, __FILE__, __LINE__);
66 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
67 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
68 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
69 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
70 result.Destroy();
71
72 RTNAME(MatmulTranspose)(result, *z, *v, __FILE__, __LINE__);
73 ASSERT_EQ(result.rank(), 1);
74 EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
75 EXPECT_EQ(result.GetDimension(0).Extent(), 3);
76 ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8}));
77 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -24);
78 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -27);
79 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -30);
80 result.Destroy();
81
82 RTNAME(MatmulTranspose)(result, *m, *z, __FILE__, __LINE__);
83 ASSERT_EQ(result.rank(), 2);
84 ASSERT_EQ(result.GetDimension(0).LowerBound(), 1);
85 ASSERT_EQ(result.GetDimension(0).UpperBound(), 4);
86 ASSERT_EQ(result.GetDimension(1).LowerBound(), 1);
87 ASSERT_EQ(result.GetDimension(1).UpperBound(), 3);
88 ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 2}));
89 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(0), 0);
90 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(1), 9);
91 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(2), 6);
92 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(3), 15);
93 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(4), 0);
94 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(5), 10);
95 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(6), 7);
96 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(7), 17);
97 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(8), 0);
98 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(9), 11);
99 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(10), 8);
100 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(11), 19);
101 result.Destroy();
102
103 // Test non-contiguous sections.
104 static constexpr int sectionRank{2};
105 StaticDescriptor<sectionRank> sectionStaticDescriptorX2;
106 Descriptor &sectionX2{sectionStaticDescriptorX2.descriptor()};
107 sectionX2.Establish(x2->type(), x2->ElementBytes(),
108 /*p=*/nullptr, /*rank=*/sectionRank);
109 static const SubscriptValue lowersX2[]{1, 1}, uppersX2[]{3, 2};
110 // Section of X2:
111 // +-----+
112 // | 0 1|
113 // | 2 3|
114 // | 4 5|
115 // +-----+
116 // -1 -1
117 const auto errorX2{CFI_section(
118 &sectionX2.raw(), &x2->raw(), lowersX2, uppersX2, /*strides=*/nullptr)};
119 ASSERT_EQ(errorX2, 0) << "CFI_section failed for X2: " << errorX2;
120
121 StaticDescriptor<sectionRank> sectionStaticDescriptorY2;
122 Descriptor &sectionY2{sectionStaticDescriptorY2.descriptor()};
123 sectionY2.Establish(y2->type(), y2->ElementBytes(),
124 /*p=*/nullptr, /*rank=*/sectionRank);
125 static const SubscriptValue lowersY2[]{2, 1};
126 // Section of Y2:
127 // -1 -1
128 // +-----+
129 // | 6 0|
130 // | 7 10|
131 // | 8 11|
132 // +-----+
133 const auto errorY2{CFI_section(&sectionY2.raw(), &y2->raw(), lowersY2,
134 /*uppers=*/nullptr, /*strides=*/nullptr)};
135 ASSERT_EQ(errorY2, 0) << "CFI_section failed for Y2: " << errorY2;
136
137 StaticDescriptor<sectionRank> sectionStaticDescriptorZ2;
138 Descriptor &sectionZ2{sectionStaticDescriptorZ2.descriptor()};
139 sectionZ2.Establish(z2->type(), z2->ElementBytes(),
140 /*p=*/nullptr, /*rank=*/sectionRank);
141 static const SubscriptValue lowersZ2[]{1, 1}, uppersZ2[]{2, 3};
142 // Section of Z2:
143 // +--------+
144 // | 6 7 8|
145 // | 9 10 11|
146 // +--------+
147 // -1 -1 -1
148 const auto errorZ2{CFI_section(
149 &sectionZ2.raw(), &z2->raw(), lowersZ2, uppersZ2, /*strides=*/nullptr)};
150 ASSERT_EQ(errorZ2, 0) << "CFI_section failed for Z2: " << errorZ2;
151
152 RTNAME(MatmulTranspose)(result, sectionX2, *y, __FILE__, __LINE__);
153 ASSERT_EQ(result.rank(), 2);
154 EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
155 EXPECT_EQ(result.GetDimension(0).Extent(), 2);
156 EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
157 EXPECT_EQ(result.GetDimension(1).Extent(), 2);
158 ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
159 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
160 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
161 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
162 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
163 result.Destroy();
164
165 RTNAME(MatmulTranspose)(result, *x, sectionY2, __FILE__, __LINE__);
166 ASSERT_EQ(result.rank(), 2);
167 EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
168 EXPECT_EQ(result.GetDimension(0).Extent(), 2);
169 EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
170 EXPECT_EQ(result.GetDimension(1).Extent(), 2);
171 ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
172 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
173 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
174 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
175 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
176 result.Destroy();
177
178 RTNAME(MatmulTranspose)(result, sectionX2, sectionY2, __FILE__, __LINE__);
179 ASSERT_EQ(result.rank(), 2);
180 EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
181 EXPECT_EQ(result.GetDimension(0).Extent(), 2);
182 EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
183 EXPECT_EQ(result.GetDimension(1).Extent(), 2);
184 ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
185 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
186 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
187 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
188 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
189 result.Destroy();
190
191 RTNAME(MatmulTranspose)(result, sectionZ2, *v, __FILE__, __LINE__);
192 ASSERT_EQ(result.rank(), 1);
193 EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
194 EXPECT_EQ(result.GetDimension(0).Extent(), 3);
195 ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8}));
196 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -24);
197 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -27);
198 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -30);
199 result.Destroy();
200
201 // X F F Y F T V T F T
202 // T F F T
203 // T T F F
204 auto xLog{MakeArray<TypeCategory::Logical, 1>(std::vector<int>{3, 2},
205 std::vector<std::uint8_t>{false, true, true, false, false, true})};
206 auto yLog{MakeArray<TypeCategory::Logical, 2>(std::vector<int>{3, 2},
207 std::vector<std::uint16_t>{false, false, false, true, true, false})};
208 auto vLog{MakeArray<TypeCategory::Logical, 1>(
209 std::vector<int>{3}, std::vector<std::uint8_t>{true, false, true})};
210 RTNAME(MatmulTranspose)(result, *xLog, *yLog, __FILE__, __LINE__);
211 ASSERT_EQ(result.rank(), 2);
212 EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
213 EXPECT_EQ(result.GetDimension(0).Extent(), 2);
214 EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
215 EXPECT_EQ(result.GetDimension(1).Extent(), 2);
216 ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Logical, 2}));
217 EXPECT_FALSE(
218 static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(0)));
219 EXPECT_FALSE(
220 static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(1)));
221 EXPECT_TRUE(
222 static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(2)));
223 EXPECT_FALSE(
224 static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(3)));
225
226 RTNAME(MatmulTranspose)(result, *yLog, *vLog, __FILE__, __LINE__);
227 ASSERT_EQ(result.rank(), 1);
228 EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
229 EXPECT_EQ(result.GetDimension(0).Extent(), 2);
230 ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Logical, 2}));
231 EXPECT_FALSE(
232 static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(0)));
233 EXPECT_TRUE(
234 static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(1)));
235}
236

// Source: flang/unittests/Runtime/MatmulTranspose.cpp