//===- StructuredOpsUtilsTest.cpp - StructuredOpsUtils unit tests ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/MLIRContext.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"

using namespace mlir;
using testing::Not;
using testing::Truly;

19namespace {
20
21TEST(isRowMajorMatmul, Simple) {
22 MLIRContext context;
23
24 AffineExpr m, n, k;
25 bindDims(ctx: &context, exprs&: m, exprs&: n, exprs&: k);
26 auto mapA = AffineMapAttr::get(AffineMap::get(dimCount: 3, symbolCount: 0, results: {m, k}, context: &context));
27 auto mapB = AffineMapAttr::get(AffineMap::get(dimCount: 3, symbolCount: 0, results: {k, n}, context: &context));
28 auto mapC = AffineMapAttr::get(AffineMap::get(dimCount: 3, symbolCount: 0, results: {m, n}, context: &context));
29 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
30
31 EXPECT_THAT(maps, Truly(isRowMajorMatmul));
32}
33
34TEST(isRowMajorMatmul, BindingShifted) {
35 MLIRContext context;
36
37 AffineExpr m, n, k;
38 bindDims(ctx: &context, exprs&: k, exprs&: m, exprs&: n); // bind in different order
39 auto mapA = AffineMapAttr::get(AffineMap::get(dimCount: 3, symbolCount: 0, results: {m, k}, context: &context));
40 auto mapB = AffineMapAttr::get(AffineMap::get(dimCount: 3, symbolCount: 0, results: {k, n}, context: &context));
41 auto mapC = AffineMapAttr::get(AffineMap::get(dimCount: 3, symbolCount: 0, results: {m, n}, context: &context));
42 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
43
44 EXPECT_THAT(maps, Truly(isRowMajorMatmul));
45}
46
47TEST(isRowMajorMatmul, BindingSwapped) {
48 MLIRContext context;
49
50 AffineExpr m, n, k;
51 bindDims(ctx: &context, exprs&: k, exprs&: n, exprs&: m); // bind in different order
52 auto mapA = AffineMapAttr::get(AffineMap::get(dimCount: 3, symbolCount: 0, results: {m, k}, context: &context));
53 auto mapB = AffineMapAttr::get(AffineMap::get(dimCount: 3, symbolCount: 0, results: {k, n}, context: &context));
54 auto mapC = AffineMapAttr::get(AffineMap::get(dimCount: 3, symbolCount: 0, results: {m, n}, context: &context));
55 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
56
57 EXPECT_THAT(maps, Truly(isRowMajorMatmul));
58}
59
60TEST(isRowMajorMatmul, ColumnMajor) {
61 MLIRContext context;
62
63 AffineExpr m, n, k;
64 bindDims(ctx: &context, exprs&: m, exprs&: n, exprs&: k);
65 auto mapA = AffineMapAttr::get(AffineMap::get(dimCount: 3, symbolCount: 0, results: {k, n}, context: &context));
66 auto mapB = AffineMapAttr::get(AffineMap::get(dimCount: 3, symbolCount: 0, results: {m, k}, context: &context));
67 auto mapC = AffineMapAttr::get(AffineMap::get(dimCount: 3, symbolCount: 0, results: {m, n}, context: &context));
68 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
69
70 EXPECT_THAT(maps, Not(Truly(isRowMajorMatmul)));
71}
72
73TEST(isRowMajorMatmul, FirstInputSwapped) {
74 MLIRContext context;
75
76 AffineExpr m, n, k;
77 bindDims(ctx: &context, exprs&: m, exprs&: n, exprs&: k);
78 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, m}, &context));
79 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
80 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
81 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
82
83 EXPECT_THAT(maps, Not(Truly(isRowMajorMatmul)));
84}
85
86TEST(isRowMajorMatmul, TooFewMaps) {
87 MLIRContext context;
88
89 AffineExpr m, n, k;
90 bindDims(ctx: &context, exprs&: m, exprs&: n, exprs&: k);
91 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
92 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
93 auto maps = ArrayAttr::get(&context, {mapA, mapB});
94
95 EXPECT_THAT(maps, Not(Truly(isRowMajorMatmul)));
96}
97
98TEST(isRowMajorMatmul, TooManyMaps) {
99 MLIRContext context;
100
101 AffineExpr m, n, k;
102 bindDims(ctx: &context, exprs&: m, exprs&: n, exprs&: k);
103 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
104 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
105 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
106 auto mapD = AffineMapAttr::get(AffineMap::get(3, 0, {k, m}, &context));
107
108 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC, mapD});
109
110 EXPECT_THAT(maps, Not(Truly(isRowMajorMatmul)));
111}
112
113TEST(isRowMajorMatmul, TooFewOutputs) {
114 MLIRContext context;
115
116 AffineExpr m, n, k;
117 bindDims(ctx: &context, exprs&: m, exprs&: n, exprs&: k);
118 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m}, &context));
119 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
120 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
121 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
122
123 EXPECT_THAT(maps, Not(Truly(isRowMajorMatmul)));
124}
125
126TEST(isColumnMajorMatmul, Simple) {
127 MLIRContext context;
128
129 AffineExpr m, n, k;
130 bindDims(ctx: &context, exprs&: m, exprs&: n, exprs&: k);
131 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
132 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
133 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
134 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
135
136 EXPECT_THAT(maps, Truly(isColumnMajorMatmul));
137}
138
139TEST(isColumnMajorMatmul, BindingShifted) {
140 MLIRContext context;
141
142 AffineExpr m, n, k;
143 bindDims(ctx: &context, exprs&: k, exprs&: m, exprs&: n); // bind in different order
144 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
145 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
146 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
147 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
148
149 EXPECT_THAT(maps, Truly(isColumnMajorMatmul));
150}
151
152TEST(isColumnMajorMatmul, BindingSwapped) {
153 MLIRContext context;
154
155 AffineExpr m, n, k;
156 bindDims(ctx: &context, exprs&: k, exprs&: n, exprs&: m); // bind in different order
157 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
158 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
159 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
160 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
161
162 EXPECT_THAT(maps, Truly(isColumnMajorMatmul));
163}
164
165TEST(isColumnMajorMatmul, RowMajor) {
166 MLIRContext context;
167
168 AffineExpr m, n, k;
169 bindDims(ctx: &context, exprs&: m, exprs&: n, exprs&: k);
170 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
171 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
172 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
173 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
174
175 EXPECT_THAT(maps, Not(Truly(isColumnMajorMatmul)));
176}
177
178TEST(isColumnMajorMatmul, FirstInputSwapped) {
179 MLIRContext context;
180
181 AffineExpr m, n, k;
182 bindDims(ctx: &context, exprs&: m, exprs&: n, exprs&: k);
183 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {n, k}, &context));
184 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
185 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
186 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
187
188 EXPECT_THAT(maps, Not(Truly(isColumnMajorMatmul)));
189}
190
191TEST(isRowMajorBatchMatmul, Simple) {
192 MLIRContext context;
193
194 AffineExpr batch, m, n, k;
195 bindDims(ctx: &context, exprs&: batch, exprs&: m, exprs&: n, exprs&: k);
196 auto mapA = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, k}, &context));
197 auto mapB = AffineMapAttr::get(AffineMap::get(4, 0, {batch, k, n}, &context));
198 auto mapC = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, n}, &context));
199 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
200
201 EXPECT_THAT(maps, Truly(isRowMajorBatchMatmul));
202}
203
204TEST(isRowMajorBatchMatmul, BindingShifted) {
205 MLIRContext context;
206
207 AffineExpr batch, m, n, k;
208 bindDims(ctx: &context, exprs&: k, exprs&: batch, exprs&: m, exprs&: n); // bind in different order
209 auto mapA = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, k}, &context));
210 auto mapB = AffineMapAttr::get(AffineMap::get(4, 0, {batch, k, n}, &context));
211 auto mapC = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, n}, &context));
212 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
213
214 EXPECT_THAT(maps, Truly(isRowMajorBatchMatmul));
215}
216
217TEST(isRowMajorBatchMatmul, BindingSwapped) {
218 MLIRContext context;
219
220 AffineExpr batch, m, n, k;
221 bindDims(ctx: &context, exprs&: batch, exprs&: k, exprs&: n, exprs&: m); // bind in different order
222 auto mapA = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, k}, &context));
223 auto mapB = AffineMapAttr::get(AffineMap::get(4, 0, {batch, k, n}, &context));
224 auto mapC = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, n}, &context));
225 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
226
227 EXPECT_THAT(maps, Truly(isRowMajorBatchMatmul));
228}
229
230TEST(isRowMajorBatchMatmul, FirstInputSwapped) {
231 MLIRContext context;
232
233 AffineExpr batch, m, n, k;
234 bindDims(ctx: &context, exprs&: batch, exprs&: m, exprs&: n, exprs&: k);
235 auto mapA = AffineMapAttr::get(AffineMap::get(4, 0, {batch, k, m}, &context));
236 auto mapB = AffineMapAttr::get(AffineMap::get(4, 0, {batch, k, n}, &context));
237 auto mapC = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, n}, &context));
238 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
239
240 EXPECT_THAT(maps, Not(Truly(isRowMajorBatchMatmul)));
241}
242
243TEST(isVecmat, Simple) {
244 MLIRContext context;
245
246 AffineExpr k, n;
247 bindDims(ctx: &context, exprs&: k, exprs&: n);
248 auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {k}, &context));
249 auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k, n}, &context));
250 auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context));
251 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
252
253 EXPECT_THAT(maps, Truly(isVecmat));
254}
255
256TEST(isVecmat, BindingSwapped) {
257 MLIRContext context;
258
259 AffineExpr k, n;
260 bindDims(ctx: &context, exprs&: n, exprs&: k); // bind in different order
261 auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {k}, &context));
262 auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k, n}, &context));
263 auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context));
264 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
265
266 EXPECT_THAT(maps, Truly(isVecmat));
267}
268
269TEST(isVecmat, WrongDimOrderMatrix) {
270 MLIRContext context;
271
272 AffineExpr k, n;
273 bindDims(ctx: &context, exprs&: k, exprs&: n);
274 auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {k}, &context));
275 auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {n, k}, &context));
276 auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context));
277 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
278
279 EXPECT_THAT(maps, Not(Truly(isVecmat)));
280}
281
282TEST(isMatvec, Simple) {
283 MLIRContext context;
284
285 AffineExpr k, n;
286 bindDims(ctx: &context, exprs&: k, exprs&: n);
287 auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {n, k}, &context));
288 auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k}, &context));
289 auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context));
290 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
291
292 EXPECT_THAT(maps, Truly(isMatvec));
293}
294
295TEST(isMatvec, BindingSwapped) {
296 MLIRContext context;
297
298 AffineExpr k, n;
299 bindDims(ctx: &context, exprs&: n, exprs&: k); // bind in different order
300 auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {n, k}, &context));
301 auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k}, &context));
302 auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context));
303 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
304
305 EXPECT_THAT(maps, Truly(isMatvec));
306}
307
308TEST(isMatvec, WrongDimOrderMatrix) {
309 MLIRContext context;
310
311 AffineExpr k, n;
312 bindDims(ctx: &context, exprs&: k, exprs&: n);
313 auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {k, n}, &context));
314 auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k}, &context));
315 auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context));
316 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
317
318 EXPECT_THAT(maps, Not(Truly(isMatvec)));
319}
320
321TEST(isBatchMatvec, Simple) {
322 MLIRContext context;
323
324 AffineExpr batch, k, n;
325 bindDims(ctx: &context, exprs&: batch, exprs&: k, exprs&: n);
326 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n, k}, &context));
327 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k}, &context));
328 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n}, &context));
329 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
330
331 EXPECT_THAT(maps, Truly(isBatchMatvec));
332}
333
334TEST(isBatchMatvec, BindingSwapped) {
335 MLIRContext context;
336
337 AffineExpr batch, k, n;
338 bindDims(ctx: &context, exprs&: batch, exprs&: n, exprs&: k); // bind in different order
339 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n, k}, &context));
340 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k}, &context));
341 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n}, &context));
342 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
343
344 EXPECT_THAT(maps, Truly(isBatchMatvec));
345}
346
347TEST(isBatchMatvec, Matmul) {
348 MLIRContext context;
349
350 AffineExpr m, n, k;
351 bindDims(ctx: &context, exprs&: m, exprs&: n, exprs&: k);
352 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
353 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
354 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
355 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
356
357 EXPECT_THAT(maps, Not(Truly(isBatchMatvec)));
358}
359
360TEST(isBatchMatvec, WrongDimOrderMatrix) {
361 MLIRContext context;
362
363 AffineExpr batch, k, n;
364 bindDims(ctx: &context, exprs&: batch, exprs&: k, exprs&: n);
365 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k, n}, &context));
366 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k}, &context));
367 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n}, &context));
368 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
369
370 EXPECT_THAT(maps, Not(Truly(isBatchMatvec)));
371}
372
373TEST(isBatchVecmat, Simple) {
374 MLIRContext context;
375
376 AffineExpr batch, k, n;
377 bindDims(ctx: &context, exprs&: batch, exprs&: k, exprs&: n);
378 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k}, &context));
379 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k, n}, &context));
380 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n}, &context));
381 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
382
383 EXPECT_THAT(maps, Truly(isBatchVecmat));
384}
385
386TEST(isBatchVecmat, BindingSwapped) {
387 MLIRContext context;
388
389 AffineExpr batch, k, n;
390 bindDims(ctx: &context, exprs&: batch, exprs&: n, exprs&: k); // bind in different order
391 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k}, &context));
392 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k, n}, &context));
393 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n}, &context));
394 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
395
396 EXPECT_THAT(maps, Truly(isBatchVecmat));
397}
398
399TEST(isBatchVecmat, Matmul) {
400 MLIRContext context;
401
402 AffineExpr m, n, k;
403 bindDims(ctx: &context, exprs&: m, exprs&: n, exprs&: k);
404 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
405 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
406 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
407 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
408
409 EXPECT_THAT(maps, Not(Truly(isBatchVecmat)));
410}
411
412TEST(isBatchVecmat, WrongDimOrderMatrix) {
413 MLIRContext context;
414
415 AffineExpr batch, k, n;
416 bindDims(ctx: &context, exprs&: batch, exprs&: k, exprs&: n);
417 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k}, &context));
418 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n, k}, &context));
419 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n}, &context));
420 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
421
422 EXPECT_THAT(maps, Not(Truly(isBatchVecmat)));
423}
424
425} // namespace

// Original path: mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp