1//===- ReshapeOpsUtilsTest.cpp - ReshapeOpsUtils unit tests ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
10#include "mlir/IR/BuiltinTypeInterfaces.h"
11#include "llvm/ADT/STLExtras.h"
12#include "gtest/gtest.h"
13#include <optional>
14
15using namespace mlir;
16
17/// Helper to make constructing
18/// `std::optional<SmallVector<ReassociationIndices>>` more readable.
19static std::optional<SmallVector<ReassociationIndices>>
20makeOptionalIndices(std::initializer_list<ReassociationIndices> list) {
21 return std::optional<SmallVector<ReassociationIndices>>(list);
22}
23
24TEST(ReassociationIndicesForCollapse, ScalarTest) {
25 EXPECT_EQ(getReassociationIndicesForCollapse({1}, {}),
26 makeOptionalIndices({}));
27 EXPECT_EQ(getReassociationIndicesForCollapse({1, 1}, {}),
28 makeOptionalIndices({}));
29 EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic}, {}),
30 makeOptionalIndices({}));
31 EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic,
32 ShapedType::kDynamic, 1,
33 ShapedType::kDynamic},
34 {}),
35 makeOptionalIndices({}));
36}
37
38TEST(ReassociationIndicesForCollapse, ScalarTestFailure) {
39 EXPECT_EQ(getReassociationIndicesForCollapse({}, {}), std::nullopt);
40 EXPECT_EQ(getReassociationIndicesForCollapse({}, {1}), std::nullopt);
41 EXPECT_EQ(getReassociationIndicesForCollapse({2}, {}), std::nullopt);
42 EXPECT_EQ(
43 getReassociationIndicesForCollapse({1, 2, ShapedType::kDynamic, 1}, {}),
44 std::nullopt);
45}
46
47TEST(ReassociationIndicesForCollapse, StaticTest) {
48 EXPECT_EQ(getReassociationIndicesForCollapse({10, 20}, {200}),
49 makeOptionalIndices({{0, 1}}));
50 EXPECT_EQ(getReassociationIndicesForCollapse({10, 20, 30}, {10, 600}),
51 makeOptionalIndices({{0}, {1, 2}}));
52 EXPECT_EQ(getReassociationIndicesForCollapse({10, 20, 30}, {200, 30}),
53 makeOptionalIndices({{0, 1}, {2}}));
54}
55
56TEST(ReassociationIndicesForCollapse, StaticTestFailure) {
57 // No-op reassociation
58 EXPECT_EQ(getReassociationIndicesForCollapse({10, 20}, {10, 20}),
59 std::nullopt);
60 // Invalid static reassociations
61 EXPECT_EQ(getReassociationIndicesForCollapse({10, 20}, {10}), std::nullopt);
62 EXPECT_EQ(getReassociationIndicesForCollapse({10, 20, 30}, {200, 300}),
63 std::nullopt);
64 // Non-collapsing (expanding) reassociation
65 EXPECT_EQ(getReassociationIndicesForCollapse({10, 20, 30}, {1, 10, 20, 30}),
66 std::nullopt);
67}
68
69TEST(ReassociationIndicesForCollapse, StaticTestUnitDims) {
70 EXPECT_EQ(getReassociationIndicesForCollapse({10, 1}, {10}),
71 makeOptionalIndices({{0, 1}}));
72 EXPECT_EQ(getReassociationIndicesForCollapse({1, 20, 30}, {600}),
73 makeOptionalIndices({{0, 1, 2}}));
74 EXPECT_EQ(getReassociationIndicesForCollapse({1, 1, 1}, {1}),
75 makeOptionalIndices({{0, 1, 2}}));
76 EXPECT_EQ(getReassociationIndicesForCollapse({1, 1, 1, 1}, {1, 1, 1}),
77 makeOptionalIndices({{0}, {1}, {2, 3}}));
78}
79
80TEST(ReassociationIndicesForCollapse, DynamicTest) {
81 EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 1},
82 {ShapedType::kDynamic}),
83 makeOptionalIndices({{0, 1}}));
84 EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 1, 1},
85 {ShapedType::kDynamic}),
86 makeOptionalIndices({{0, 1, 2}}));
87 EXPECT_EQ(getReassociationIndicesForCollapse(
88 {1, ShapedType::kDynamic, 1, ShapedType::kDynamic, 1},
89 {ShapedType::kDynamic, ShapedType::kDynamic}),
90 makeOptionalIndices({{0, 1}, {2, 3, 4}}));
91 EXPECT_EQ(
92 getReassociationIndicesForCollapse(
93 {ShapedType::kDynamic, ShapedType::kDynamic}, {ShapedType::kDynamic}),
94 makeOptionalIndices({{0, 1}}));
95 EXPECT_EQ(getReassociationIndicesForCollapse(
96 {1, ShapedType::kDynamic, ShapedType::kDynamic},
97 {1, ShapedType::kDynamic}),
98 makeOptionalIndices({{0}, {1, 2}}));
99
100 EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 10},
101 {ShapedType::kDynamic}),
102 makeOptionalIndices({{0, 1}}));
103 EXPECT_EQ(getReassociationIndicesForCollapse(
104 {1, ShapedType::kDynamic, ShapedType::kDynamic},
105 {ShapedType::kDynamic}),
106 makeOptionalIndices({{0, 1, 2}}));
107 EXPECT_EQ(getReassociationIndicesForCollapse({10, ShapedType::kDynamic},
108 {ShapedType::kDynamic}),
109 makeOptionalIndices({{0, 1}}));
110 EXPECT_EQ(getReassociationIndicesForCollapse(
111 {ShapedType::kDynamic, 1, 2, ShapedType::kDynamic, 10},
112 {ShapedType::kDynamic, 10}),
113 makeOptionalIndices({{0, 1, 2, 3}, {4}}));
114 EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 10, 20},
115 {ShapedType::kDynamic, 20}),
116 makeOptionalIndices({{0, 1}, {2}}));
117 EXPECT_EQ(getReassociationIndicesForCollapse({10, ShapedType::kDynamic, 20},
118 {ShapedType::kDynamic, 20}),
119 makeOptionalIndices({{0, 1}, {2}}));
120 EXPECT_EQ(getReassociationIndicesForCollapse(
121 {ShapedType::kDynamic, 3, 2, 5, 2}, {ShapedType::kDynamic, 20}),
122 makeOptionalIndices({{0, 1}, {2, 3, 4}}));
123 EXPECT_EQ(getReassociationIndicesForCollapse(
124 {10, ShapedType::kDynamic, 20, ShapedType::kDynamic, 1},
125 {ShapedType::kDynamic, 20, ShapedType::kDynamic}),
126 makeOptionalIndices({{0, 1}, {2}, {3, 4}}));
127 EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic, 1},
128 {ShapedType::kDynamic}),
129 makeOptionalIndices({{0, 1, 2}}));
130 EXPECT_EQ(getReassociationIndicesForCollapse(
131 {ShapedType::kDynamic, ShapedType::kDynamic, 1},
132 {ShapedType::kDynamic, ShapedType::kDynamic}),
133 makeOptionalIndices({{0}, {1, 2}}));
134 EXPECT_EQ(getReassociationIndicesForCollapse(
135 {1, ShapedType::kDynamic, ShapedType::kDynamic},
136 {ShapedType::kDynamic, ShapedType::kDynamic}),
137 makeOptionalIndices({{0, 1}, {2}}));
138 EXPECT_EQ(getReassociationIndicesForCollapse(
139 {ShapedType::kDynamic, 1, ShapedType::kDynamic},
140 {ShapedType::kDynamic, ShapedType::kDynamic}),
141 makeOptionalIndices({{0}, {1, 2}}));
142}
143
144TEST(ReassociationIndicesForCollapse, DynamicTestFailure) {
145 EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 10, 20},
146 {ShapedType::kDynamic, 10}),
147 std::nullopt);
148 EXPECT_EQ(getReassociationIndicesForCollapse(
149 {ShapedType::kDynamic, 10, ShapedType::kDynamic},
150 {ShapedType::kDynamic, ShapedType::kDynamic}),
151 std::nullopt);
152 EXPECT_EQ(getReassociationIndicesForCollapse(
153 {20, ShapedType::kDynamic, 10, ShapedType::kDynamic},
154 {ShapedType::kDynamic, ShapedType::kDynamic}),
155 std::nullopt);
156 EXPECT_EQ(getReassociationIndicesForCollapse(
157 {ShapedType::kDynamic, 5, 3, 2, 2}, {ShapedType::kDynamic, 20}),
158 std::nullopt);
159 EXPECT_EQ(
160 getReassociationIndicesForCollapse(
161 {ShapedType::kDynamic, ShapedType::kDynamic, ShapedType::kDynamic},
162 {ShapedType::kDynamic, ShapedType::kDynamic}),
163 std::nullopt);
164 EXPECT_EQ(getReassociationIndicesForCollapse(
165 {ShapedType::kDynamic, ShapedType::kDynamic, 10, 1,
166 ShapedType::kDynamic},
167 {ShapedType::kDynamic, ShapedType::kDynamic}),
168 std::nullopt);
169 EXPECT_EQ(getReassociationIndicesForCollapse(
170 {ShapedType::kDynamic, 10, 10, 10, ShapedType::kDynamic},
171 {ShapedType::kDynamic, 10, ShapedType::kDynamic}),
172 std::nullopt);
173 EXPECT_EQ(getReassociationIndicesForCollapse(
174 {ShapedType::kDynamic, 10, 10, 10, ShapedType::kDynamic},
175 {ShapedType::kDynamic, 2, 2, ShapedType::kDynamic}),
176 std::nullopt);
177 EXPECT_EQ(getReassociationIndicesForCollapse(
178 {ShapedType::kDynamic, 3, 4, 3, ShapedType::kDynamic},
179 {ShapedType::kDynamic, 12, ShapedType::kDynamic}),
180 std::nullopt);
181 EXPECT_EQ(getReassociationIndicesForCollapse(
182 {ShapedType::kDynamic, 8, 4, 2, 16, ShapedType::kDynamic},
183 {ShapedType::kDynamic, 32, ShapedType::kDynamic}),
184 std::nullopt);
185
186 //===----------------------------------------------------------------------===//
187 // TODO: Reassociation for the following examples can be computed, but isn't
188 // supported by `getReassociationIndicesForCollapse`.
189 //===----------------------------------------------------------------------===//
190
191 // TODO: Fails because there's no backtracking when some source dimensions
192 // remain unmatched at either edge.
193 EXPECT_EQ(getReassociationIndicesForCollapse(
194 {ShapedType::kDynamic, 10, ShapedType::kDynamic, 10},
195 {ShapedType::kDynamic, 10}),
196 std::nullopt);
197 EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic, 2, 2},
198 {1, ShapedType::kDynamic, 2}),
199 std::nullopt);
200 EXPECT_EQ(getReassociationIndicesForCollapse({2, 2, ShapedType::kDynamic, 1},
201 {2, ShapedType::kDynamic}),
202 std::nullopt);
203}
204

source code of mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp