1//===- MergerTest.cpp - Tests for the sparsifier's merger -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
10#include "llvm/Support/Compiler.h"
11#include "gmock/gmock.h"
12#include "gtest/gtest.h"
13
14#include <memory>
15
16using namespace mlir;
17using namespace mlir::sparse_tensor;
18
19namespace {
20
21///
22/// Defines macros to iterate binary and the combination of binary operations.
23///
24
25#define FOREVERY_BINOP(DO) \
26 DO(mulf, TensorExp::Kind::kMulF) \
27 DO(mulc, TensorExp::Kind::kMulC) \
28 DO(muli, TensorExp::Kind::kMulI) \
29 DO(addf, TensorExp::Kind::kAddF) \
30 DO(addc, TensorExp::Kind::kAddC) \
31 DO(addi, TensorExp::Kind::kAddI) \
32 DO(subf, TensorExp::Kind::kSubF) \
33 DO(subc, TensorExp::Kind::kSubC) \
34 DO(subi, TensorExp::Kind::kSubI) \
35 DO(andi, TensorExp::Kind::kAndI) \
36 DO(xori, TensorExp::Kind::kXorI) \
37 DO(ori, TensorExp::Kind::kOrI) \
38 DO(cmpf, TensorExp::Kind::kCmpF) \
39 DO(cmpi, TensorExp::Kind::kCmpI)
40
41#define FOREVERY_COMMON_DISJ_BINOP_EXTRA(TEST, EXTRA) \
42 TEST(addf, EXTRA) \
43 TEST(addc, EXTRA) \
44 TEST(addi, EXTRA) \
45 TEST(xori, EXTRA) \
46 TEST(ori, EXTRA)
47
48#define FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, EXTRA) \
49 TEST(mulf, EXTRA) \
50 TEST(mulc, EXTRA) \
51 TEST(muli, EXTRA) \
52 TEST(andi, EXTRA)
53
54#define FOREVERY_COMMON_DISJ_BINOP(TEST) \
55 FOREVERY_COMMON_DISJ_BINOP_EXTRA(TEST, "")
56
57#define FOREVERY_COMMON_CONJ_BINOP(TEST) \
58 FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, "")
59
60#define FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(TEST) \
61 FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, addf) \
62 FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, addc) \
63 FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, addi) \
64 FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, xori) \
65 FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, ori)
66
67#define FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(TEST) \
68 FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, mulf) \
69 FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, mulc) \
70 FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, muli) \
71 FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, andi)
72
73#define FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(TEST) \
74 FOREVERY_COMMON_DISJ_BINOP_EXTRA(TEST, addf) \
75 FOREVERY_COMMON_DISJ_BINOP_EXTRA(TEST, addc) \
76 FOREVERY_COMMON_DISJ_BINOP_EXTRA(TEST, addi) \
77 FOREVERY_COMMON_DISJ_BINOP_EXTRA(TEST, ori) \
78 FOREVERY_COMMON_DISJ_BINOP_EXTRA(TEST, xori)
79
80///
81/// Helper classes/functions for testing Merger.
82///
83
/// Simple recursive data structure used to match expressions in `Merger`,
/// which uses const references into the short-lived data structures.
struct Match {
  /// Pair of sub-patterns for a binary operation. Holds const references,
  /// so the referenced `Match` objects must outlive this one.
  struct Children {
    Children(const Match &e0, const Match &e1) : e0(e0), e1(e1) {}
    const Match &e0;
    const Match &e1;
  };

  /// Matches a synthetic zero.
  Match() : kind(TensorExp::Kind::kSynZero) {}
  /// Matches a leaf tensor with the given id.
  Match(TensorId tid) : kind(TensorExp::Kind::kTensor), tid(tid) {}
  /// Matches a binary operation over two sub-patterns; `kind` must be an
  /// operation kind (asserted below), not a leaf kind.
  Match(TensorExp::Kind kind, const Match &e0, const Match &e1)
      : kind(kind), children(e0, e1) {
    assert(kind >= TensorExp::Kind::kMulF);
  }

  TensorExp::Kind kind;
  // Discriminated by `kind`: leaf tensors use `tid`, operations use
  // `children`, and `kSynZero` uses neither member.
  union {
    TensorId tid;
    Children children;
  };
};
106
107///
108/// Readable Match builder functions.
109/// These should be preferred over the actual constructors.
110///
111
112static Match tensorMatch(TensorId tid) { return Match(tid); }
113static Match synZeroMatch() { return Match(); }
114
/// For every binary operation, defines a readable `<op>Match(e0, e1)` builder
/// that constructs the corresponding binary `Match` pattern.
#define IMPL_BINOP_PATTERN(OP, KIND) \
  LLVM_ATTRIBUTE_UNUSED static Match OP##Match(const Match &e0, \
                                               const Match &e1) { \
    return Match(KIND, e0, e1); \
  }
FOREVERY_BINOP(IMPL_BINOP_PATTERN)
#undef IMPL_BINOP_PATTERN
122
123// Parameterize LevelFormat to test both Dense and Batch LevelFormat.
124class MergerTestBase : public ::testing::TestWithParam<LevelFormat> {
125protected:
126 MergerTestBase(unsigned numTensors, unsigned numLoops)
127 : merger(numTensors, numLoops, /*maxRank=*/numLoops) {
128 tensors.reserve(N: numTensors);
129 for (unsigned t = 0; t < numTensors; t++)
130 tensors.push_back(Elt: merger.addTensorExp(t: tid(t)));
131 }
132
133 ///
134 /// Expression construction helpers.
135 ///
136
137 TensorId tid(unsigned t) const { return merger.makeTensorId(t); }
138 LoopId lid(unsigned i) const { return merger.makeLoopId(i); }
139 ExprId tensor(unsigned t) const {
140 assert(t < tensors.size());
141 return tensors[t];
142 }
143
144#define IMPL_BINOP_EXPR(OP, KIND) \
145 LLVM_ATTRIBUTE_UNUSED ExprId OP##Expr(ExprId e0, ExprId e1) { \
146 return merger.addExp(KIND, e0, e1); \
147 }
148 FOREVERY_BINOP(IMPL_BINOP_EXPR)
149#undef IMPL_BINOP_EXPR
150
151 ///
152 /// Comparison helpers.
153 ///
154
155 /// Returns true if any lattice point with an expression matching
156 /// the given `pattern` and bits matching the given `bits` is present
157 /// in the `[lo, lo+n)` slice of the lattice set `s`. This is useful
158 /// for testing partial ordering constraints between lattice points.
159 /// We generally know how contiguous groups of lattice points should
160 /// be ordered with respect to other groups, but there is no required
161 /// ordering within groups. If `simple` is true, then compare the
162 /// `lat.simple` field instead to test the result after optimization.
163 bool latPointWithinRange(LatSetId s, unsigned lo, unsigned n,
164 const Match &pattern, const BitVector &bits,
165 bool simple) {
166 for (unsigned k = lo, hi = lo + n; k < hi; ++k) {
167 if (compareExpression(e: merger.lat(p: merger.set(s)[k]).exp, pattern) &&
168 compareBits(s, k, bits, simple))
169 return true;
170 }
171 return false;
172 }
173
174 /// Wrapper over latPointWithinRange for readability of tests.
175 void expectLatPointWithinRange(LatSetId s, unsigned lo, unsigned n,
176 const Match &pattern, const BitVector &bits,
177 bool simple = false) {
178 EXPECT_TRUE(latPointWithinRange(s, lo, n, pattern, bits, simple));
179 }
180
181 /// Wrapper over expectLatPointWithinRange for a single lat point.
182 void expectLatPoint(LatSetId s, unsigned lo, const Match &pattern,
183 const BitVector &bits, bool simple = false) {
184 EXPECT_TRUE(latPointWithinRange(s, lo, 1, pattern, bits, simple));
185 }
186
187 /// Converts a vector of (loop, tensor) pairs to a bitvector with the
188 /// corresponding bits set.
189 BitVector loopsToBits(const std::vector<std::pair<LoopId, TensorId>> &loops) {
190 BitVector testBits = BitVector(merger.getNumTensors(), false);
191 for (auto [loop, tensor] : loops)
192 testBits.set(merger.makeTensorLoopId(t: tensor, i: loop));
193 return testBits;
194 }
195
196 /// Returns true if the bits of the `k`th point in set `s` matches
197 /// the given `bits`. If `simple` is true, then compares the `lat.simple`
198 /// field instead, to test the result after optimization
199 bool compareBits(LatSetId s, unsigned k, const BitVector &bits, bool simple) {
200 const auto &point = merger.lat(p: merger.set(s)[k]);
201 return (simple ? point.simple : point.bits) == bits;
202 }
203
204 /// Check that there are n lattice points in set s.
205 void expectNumLatPoints(LatSetId s, unsigned n) {
206 EXPECT_THAT(merger.set(s).size(), n);
207 }
208
209 /// Compares expressions for equality. Equality is defined recursively as:
210 /// - Operations are equal if they have the same kind and children.
211 /// - Leaf tensors are equal if they refer to the same tensor.
212 bool compareExpression(ExprId e, const Match &pattern) {
213 const auto &tensorExp = merger.exp(e);
214 if (tensorExp.kind != pattern.kind)
215 return false;
216 switch (tensorExp.kind) {
217 // Leaf.
218 case TensorExp::Kind::kTensor:
219 return tensorExp.tensor == pattern.tid;
220 case TensorExp::Kind::kSynZero:
221 // Already checked kind equivalence @L233
222 return true;
223 case TensorExp::Kind::kInvariant:
224 llvm_unreachable("invariant not handled yet");
225 case TensorExp::Kind::kLoopVar:
226 llvm_unreachable("loop-variables not handled yet");
227 // Unary operations.
228 case TensorExp::Kind::kAbsF:
229 case TensorExp::Kind::kAbsC:
230 case TensorExp::Kind::kAbsI:
231 case TensorExp::Kind::kCeilF:
232 case TensorExp::Kind::kFloorF:
233 case TensorExp::Kind::kSqrtF:
234 case TensorExp::Kind::kSqrtC:
235 case TensorExp::Kind::kExpm1F:
236 case TensorExp::Kind::kExpm1C:
237 case TensorExp::Kind::kLog1pF:
238 case TensorExp::Kind::kLog1pC:
239 case TensorExp::Kind::kRelu:
240 case TensorExp::Kind::kSinF:
241 case TensorExp::Kind::kSinC:
242 case TensorExp::Kind::kTanhF:
243 case TensorExp::Kind::kTanhC:
244 case TensorExp::Kind::kNegF:
245 case TensorExp::Kind::kNegC:
246 case TensorExp::Kind::kNegI:
247 case TensorExp::Kind::kTruncF:
248 case TensorExp::Kind::kExtF:
249 case TensorExp::Kind::kCastFS:
250 case TensorExp::Kind::kCastFU:
251 case TensorExp::Kind::kCastSF:
252 case TensorExp::Kind::kCastUF:
253 case TensorExp::Kind::kCastS:
254 case TensorExp::Kind::kCastU:
255 case TensorExp::Kind::kCastIdx:
256 case TensorExp::Kind::kTruncI:
257 case TensorExp::Kind::kCIm:
258 case TensorExp::Kind::kCRe:
259 case TensorExp::Kind::kBitCast:
260 case TensorExp::Kind::kSelect:
261 case TensorExp::Kind::kBinaryBranch:
262 case TensorExp::Kind::kUnary:
263 return compareExpression(e: tensorExp.children.e0, pattern: pattern.children.e0);
264 // Binary operations.
265 case TensorExp::Kind::kMulF:
266 case TensorExp::Kind::kMulC:
267 case TensorExp::Kind::kMulI:
268 case TensorExp::Kind::kDivF:
269 case TensorExp::Kind::kDivC:
270 case TensorExp::Kind::kDivS:
271 case TensorExp::Kind::kDivU:
272 case TensorExp::Kind::kAddF:
273 case TensorExp::Kind::kAddC:
274 case TensorExp::Kind::kAddI:
275 case TensorExp::Kind::kSubF:
276 case TensorExp::Kind::kSubC:
277 case TensorExp::Kind::kSubI:
278 case TensorExp::Kind::kAndI:
279 case TensorExp::Kind::kOrI:
280 case TensorExp::Kind::kXorI:
281 case TensorExp::Kind::kCmpF:
282 case TensorExp::Kind::kCmpI:
283 case TensorExp::Kind::kShrS:
284 case TensorExp::Kind::kShrU:
285 case TensorExp::Kind::kShlI:
286 case TensorExp::Kind::kBinary:
287 case TensorExp::Kind::kReduce:
288 return compareExpression(e: tensorExp.children.e0, pattern: pattern.children.e0) &&
289 compareExpression(e: tensorExp.children.e1, pattern: pattern.children.e1);
290 case TensorExp::Kind::kDenseOp: {
291 bool eq = compareExpression(e: tensorExp.children.e0, pattern: pattern.children.e0);
292 if (eq && tensorExp.children.e1 != sparse_tensor::detail::kInvalidId)
293 return compareExpression(e: tensorExp.children.e1, pattern: pattern.children.e1);
294 return eq;
295 }
296 }
297 llvm_unreachable("unexpected kind");
298 }
299
300 // This field is public for convenience.
301 Merger merger;
302
303private:
304 // This field is private to prevent mutation after the ctor.
305 SmallVector<ExprId> tensors;
306};
307
308///
309/// Tests with all sparse inputs.
310///
311
312/// Three tensors (two inputs, one output); and a single loop.
313class MergerTest3T1L : public MergerTestBase {
314protected:
315 MergerTest3T1L() : MergerTestBase(3, 1) {
316 EXPECT_TRUE(merger.getOutTensorID() == tid(2));
317 // Tensor 0: sparse input vector.
318 merger.setLevelAndType(t: tid(t: 0), i: lid(i: 0), lvl: 0, lt: LevelFormat::Compressed);
319 // Tensor 1: sparse input vector.
320 merger.setLevelAndType(t: tid(t: 1), i: lid(i: 0), lvl: 0, lt: LevelFormat::Compressed);
321 // Tensor 2: dense output vector.
322 merger.setLevelAndType(t: tid(t: 2), i: lid(i: 0), lvl: 0, lt: GetParam());
323 }
324};
325
326INSTANTIATE_TEST_SUITE_P(Test3T1L, MergerTest3T1L,
327 ::testing::Values(LevelFormat::Dense,
328 LevelFormat::Batch));
329
330/// Four tensors (three inputs, one output); and a single loop.
331class MergerTest4T1L : public MergerTestBase {
332protected:
333 MergerTest4T1L() : MergerTestBase(4, 1) {
334 EXPECT_TRUE(merger.getOutTensorID() == tid(3));
335 // Tensor 0: sparse input vector.
336 merger.setLevelAndType(t: tid(t: 0), i: lid(i: 0), lvl: 0, lt: LevelFormat::Compressed);
337 // Tensor 1: sparse input vector.
338 merger.setLevelAndType(t: tid(t: 1), i: lid(i: 0), lvl: 0, lt: LevelFormat::Compressed);
339 // Tensor 2: sparse input vector
340 merger.setLevelAndType(t: tid(t: 2), i: lid(i: 0), lvl: 0, lt: LevelFormat::Compressed);
341 // Tensor 3: dense output vector
342 merger.setLevelAndType(t: tid(t: 3), i: lid(i: 0), lvl: 0, lt: GetParam());
343 }
344};
345
346INSTANTIATE_TEST_SUITE_P(Test4T1L, MergerTest4T1L,
347 ::testing::Values(LevelFormat::Dense,
348 LevelFormat::Batch));
349
350///
351/// Tests with both sparse and dense input.
352///
353
354/// Three tensors (two inputs, one output); and a single loop.
355class MergerTest3T1LD : public MergerTestBase {
356protected:
357 MergerTest3T1LD() : MergerTestBase(3, 1) {
358 EXPECT_TRUE(merger.getOutTensorID() == tid(2));
359 // Tensor 0: sparse input vector.
360 merger.setLevelAndType(t: tid(t: 0), i: lid(i: 0), lvl: 0, lt: LevelFormat::Compressed);
361 // Tensor 1: dense input vector.
362 merger.setLevelAndType(t: tid(t: 1), i: lid(i: 0), lvl: 0, lt: GetParam());
363 // Tensor 2: dense output vector.
364 merger.setLevelAndType(t: tid(t: 2), i: lid(i: 0), lvl: 0, lt: GetParam());
365 }
366};
367
368INSTANTIATE_TEST_SUITE_P(Test3T1LD, MergerTest3T1LD,
369 ::testing::Values(LevelFormat::Dense,
370 LevelFormat::Batch));
371
372///
373/// Tests with both undef and dense input.
374///
375
376/// Three tensors (three inputs, one output); and a single loop.
377class MergerTest4T1LU : public MergerTestBase {
378protected:
379 MergerTest4T1LU() : MergerTestBase(4, 1) {
380 EXPECT_TRUE(merger.getOutTensorID() == tid(3));
381 // Tensor 0: undef input vector.
382 merger.setLevelAndType(t: tid(t: 0), i: lid(i: 0), lvl: 0, lt: LevelFormat::Undef);
383 // Tensor 1: dense input vector.
384 merger.setLevelAndType(t: tid(t: 1), i: lid(i: 0), lvl: 0, lt: GetParam());
385 // Tensor 2: undef input vector.
386 merger.setLevelAndType(t: tid(t: 2), i: lid(i: 0), lvl: 0, lt: LevelFormat::Undef);
387 // Tensor 3: dense output vector.
388 merger.setLevelAndType(t: tid(t: 3), i: lid(i: 0), lvl: 0, lt: GetParam());
389 }
390};
391
392INSTANTIATE_TEST_SUITE_P(Test4T1LU, MergerTest4T1LU,
393 ::testing::Values(LevelFormat::Dense,
394 LevelFormat::Batch));
395
396///
397/// Tests with operation on sparse output.
398///
399
400/// Three tensors (two inputs, one output, one synthetic); and a single loop.
401class MergerTest3T1LSo : public MergerTestBase {
402protected:
403 MergerTest3T1LSo() : MergerTestBase(3, 1) {
404 EXPECT_TRUE(merger.getOutTensorID() == tid(2));
405 EXPECT_TRUE(merger.getSynTensorID() == tid(3));
406 merger.setHasSparseOut(true);
407 // Tensor 0: undef input vector.
408 merger.setLevelAndType(t: tid(t: 0), i: lid(i: 0), lvl: 0, lt: LevelFormat::Undef);
409 // Tensor 1: undef input vector.
410 merger.setLevelAndType(t: tid(t: 1), i: lid(i: 0), lvl: 0, lt: LevelFormat::Undef);
411 // Tensor 2: sparse output vector.
412 merger.setLevelAndType(t: tid(t: 2), i: lid(i: 0), lvl: 0, lt: LevelFormat::Compressed);
413 }
414};
415
416// This testsuite does not use any dense-like format, just one of {Dense, Batch}
417// is enough.
418INSTANTIATE_TEST_SUITE_P(Test3T1LSo, MergerTest3T1LSo,
419 ::testing::Values(LevelFormat::Dense));
420
421} // namespace
422
/// Vector multiplication (conjunction) of 3 vectors, i.e.:
///   a(i) = b(i) * c(i) * d(i)
/// which should form the single lattice point
/// {
///   lat( i_00_U i_01_D i_02_U / (tensor_0 * tensor_1 * tensor_2) )
/// }
/// after optimization, the dense dimension should be kept, even though it
/// appears in the middle
/// {
///   lat( i_01_D / (tensor_0 * tensor_1 * tensor_2) )
/// }
#define IMPL_MERGER_TEST_CONJ_CONJ_UNDEF(CONJ1, CONJ2) \
  TEST_P(MergerTest4T1LU, vector_##CONJ1##_##CONJ2) { \
    const auto em = CONJ1##Expr(tensor(0), tensor(1)); \
    const auto e = CONJ2##Expr(em, tensor(2)); \
    const auto l0 = lid(0); \
    const auto t0 = tid(0); \
    const auto t1 = tid(1); \
    const auto t2 = tid(2); \
    const Match &p0 = tensorMatch(t0); \
    const Match &p1 = tensorMatch(t1); \
    const Match &p2 = tensorMatch(t2); \
    auto s = merger.buildLattices(e, l0); \
    expectNumLatPoints(s, 1); \
    expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2), \
                   loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
    s = merger.optimizeSet(s); \
    expectNumLatPoints(s, 1); \
    expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2), \
                   loopsToBits({{l0, t1}}), true); \
  }
FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_UNDEF)
#undef IMPL_MERGER_TEST_CONJ_CONJ_UNDEF
456
/// Vector multiplication (conjunction) of 2 vectors, i.e.:
///   o(i) = b(i) * c(i) * o(i)
/// which should form the single lattice point (note how a synthetic tensor
/// i_03_U is created for the sparse output)
/// {
///   lat( i_00_U i_01_U i_03_U / (tensor_0 * tensor_1 * output_tensor_2) )
/// }
/// after optimization, the synthetic tensor should be preserved.
/// {
///   lat( i_03_U / (tensor_0 * tensor_1 * output_tensor_2) )
/// }
#define IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT(CONJ1, CONJ2) \
  TEST_P(MergerTest3T1LSo, vector_##CONJ1##_##CONJ2) { \
    const auto em = CONJ1##Expr(tensor(0), tensor(1)); \
    const auto e = CONJ2##Expr(em, tensor(2)); \
    const auto l0 = lid(0); \
    const auto t0 = tid(0); \
    const auto t1 = tid(1); \
    const auto t2 = tid(2); \
    const auto t3 = tid(3); \
    const Match &p0 = tensorMatch(t0); \
    const Match &p1 = tensorMatch(t1); \
    const Match &p2 = tensorMatch(t2); \
    auto s = merger.buildLattices(e, l0); \
    expectNumLatPoints(s, 1); \
    expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2), \
                   loopsToBits({{l0, t0}, {l0, t1}, {l0, t3}})); \
    s = merger.optimizeSet(s); \
    expectNumLatPoints(s, 1); \
    expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2), \
                   loopsToBits({{l0, t3}}), true); \
  }
FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT)
#undef IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT
491
/// Vector addition (disjunction) of 2 vectors, i.e.:
///   a(i) = b(i) + c(i)
/// which should form the 3 lattice points
/// {
///   lat( i_00 i_01 / (tensor_0 + tensor_1) )
///   lat( i_00 / tensor_0 )
///   lat( i_01 / tensor_1 )
/// }
/// and after optimization, the lattice points do not change (as there is no
/// duplicated point and all input vectors are sparse vectors).
/// {
///   lat( i_00 i_01 / (tensor_0 + tensor_1) )
///   lat( i_00 / tensor_0 )
///   lat( i_01 / tensor_1 )
/// }
#define IMPL_MERGER_TEST_DISJ(OP, UNUSED) \
  TEST_P(MergerTest3T1L, vector_##OP) { \
    const auto e = OP##Expr(tensor(0), tensor(1)); \
    const auto l0 = lid(0); \
    const auto t0 = tid(0); \
    const auto t1 = tid(1); \
    const Match &p0 = tensorMatch(t0); \
    const Match &p1 = tensorMatch(t1); \
    auto s = merger.buildLattices(e, l0); \
    \
    expectNumLatPoints(s, 3); \
    expectLatPoint(s, 0, OP##Match(p0, p1), \
                   loopsToBits({{l0, t0}, {l0, t1}})); \
    expectLatPointWithinRange(s, 1, 2, p0, loopsToBits({{l0, t0}})); \
    expectLatPointWithinRange(s, 1, 2, p1, loopsToBits({{l0, t1}})); \
    \
    s = merger.optimizeSet(s); \
    expectNumLatPoints(s, 3); \
    expectLatPoint(s, 0, OP##Match(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}), \
                   true); \
    expectLatPointWithinRange(s, 1, 2, p0, loopsToBits({{l0, t0}}), true); \
    expectLatPointWithinRange(s, 1, 2, p1, loopsToBits({{l0, t1}}), true); \
  }
FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_DISJ)
#undef IMPL_MERGER_TEST_DISJ
532
/// Vector multiplication (conjunction) of 2 vectors, i.e.:
///   a(i) = b(i) * c(i)
/// which should form the single lattice point
/// {
///   lat( i_00 i_01 / (tensor_0 * tensor_1) )
/// }
#define IMPL_MERGER_TEST_CONJ(OP, UNUSED) \
  TEST_P(MergerTest3T1L, vector_##OP) { \
    const auto e = OP##Expr(tensor(0), tensor(1)); \
    const auto l0 = lid(0); \
    const auto t0 = tid(0); \
    const auto t1 = tid(1); \
    const Match &p0 = tensorMatch(t0); \
    const Match &p1 = tensorMatch(t1); \
    auto s = merger.buildLattices(e, l0); \
    \
    expectNumLatPoints(s, 1); \
    expectLatPoint(s, 0, OP##Match(p0, p1), \
                   loopsToBits({{l0, t0}, {l0, t1}})); \
    \
    s = merger.optimizeSet(s); \
    expectNumLatPoints(s, 1); \
    expectLatPoint(s, 0, OP##Match(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}), \
                   true); \
  }
FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_CONJ)
#undef IMPL_MERGER_TEST_CONJ
560
/// Vector multiplication (conjunction) then addition (disjunction), i.e.:
///   a(i) = b(i) * c(i) + d(i)
/// which should form
/// {
///   lat( i_00 i_01 i_02 / (tensor_0 * tensor_1) + tensor_2 )
///   lat( i_00 i_01 / tensor_0 * tensor_1 )
///   lat( i_02 / tensor_2 )
/// }
#define IMPL_MERGER_TEST_CONJ_DISJ(CONJ, DISJ) \
  TEST_P(MergerTest4T1L, vector_##CONJ##_##DISJ) { \
    const auto em = CONJ##Expr(tensor(0), tensor(1)); \
    const auto e = DISJ##Expr(em, tensor(2)); \
    const auto l0 = lid(0); \
    const auto t0 = tid(0); \
    const auto t1 = tid(1); \
    const auto t2 = tid(2); \
    const Match &p0 = tensorMatch(t0); \
    const Match &p1 = tensorMatch(t1); \
    const Match &p2 = tensorMatch(t2); \
    auto s = merger.buildLattices(e, l0); \
    \
    expectNumLatPoints(s, 3); \
    expectLatPoint(s, 0, DISJ##Match(CONJ##Match(p0, p1), p2), \
                   loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
    expectLatPointWithinRange(s, 1, 2, CONJ##Match(p0, p1), \
                              loopsToBits({{l0, t0}, {l0, t1}})); \
    expectLatPointWithinRange(s, 1, 2, p2, loopsToBits({{l0, t2}})); \
    \
    s = merger.optimizeSet(s); \
    expectNumLatPoints(s, 3); \
    expectLatPoint(s, 0, DISJ##Match(CONJ##Match(p0, p1), p2), \
                   loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
    expectLatPointWithinRange(s, 1, 2, CONJ##Match(p0, p1), \
                              loopsToBits({{l0, t0}, {l0, t1}})); \
    expectLatPointWithinRange(s, 1, 2, p2, loopsToBits({{l0, t2}})); \
  }
FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(IMPL_MERGER_TEST_CONJ_DISJ)
#undef IMPL_MERGER_TEST_CONJ_DISJ
599
/// Vector addition (disjunction) then addition (disjunction), i.e.:
///   a(i) = b(i) + c(i) + d(i)
/// which should form
/// {
///   lat( i_00 i_01 i_02 / (tensor_0 + tensor_1) + tensor_2 )
///   lat( i_02 i_01 / tensor_2 + tensor_1 )
///   lat( i_02 i_00 / tensor_2 + tensor_0 )
///   lat( i_01 i_00 / tensor_1 + tensor_0 )
///   lat( i_02 / tensor_2 )
///   lat( i_01 / tensor_1 )
///   lat( i_00 / tensor_0 )
/// }
#define IMPL_MERGER_TEST_DISJ_DISJ(DISJ1, DISJ2) \
  TEST_P(MergerTest4T1L, Vector_##DISJ1##_##DISJ2) { \
    const auto em = DISJ1##Expr(tensor(0), tensor(1)); \
    const auto e = DISJ2##Expr(em, tensor(2)); \
    const auto l0 = lid(0); \
    const auto t0 = tid(0); \
    const auto t1 = tid(1); \
    const auto t2 = tid(2); \
    const Match &p0 = tensorMatch(t0); \
    const Match &p1 = tensorMatch(t1); \
    const Match &p2 = tensorMatch(t2); \
    auto s = merger.buildLattices(e, l0); \
    \
    expectNumLatPoints(s, 7); \
    expectLatPoint(s, 0, DISJ2##Match(DISJ1##Match(p0, p1), p2), \
                   loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
    expectLatPointWithinRange(s, 1, 6, DISJ2##Match(p1, p2), \
                              loopsToBits({{l0, t1}, {l0, t2}})); \
    expectLatPointWithinRange(s, 1, 6, DISJ2##Match(p0, p2), \
                              loopsToBits({{l0, t0}, {l0, t2}})); \
    expectLatPointWithinRange(s, 1, 6, DISJ1##Match(p0, p1), \
                              loopsToBits({{l0, t0}, {l0, t1}})); \
    expectLatPointWithinRange(s, 1, 6, p2, loopsToBits({{l0, t2}})); \
    expectLatPointWithinRange(s, 1, 6, p1, loopsToBits({{l0, t1}})); \
    expectLatPointWithinRange(s, 1, 6, p0, loopsToBits({{l0, t0}})); \
    \
    s = merger.optimizeSet(s); \
    expectNumLatPoints(s, 7); \
    expectLatPoint(s, 0, DISJ2##Match(DISJ1##Match(p0, p1), p2), \
                   loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
    expectLatPointWithinRange(s, 1, 6, DISJ2##Match(p1, p2), \
                              loopsToBits({{l0, t1}, {l0, t2}})); \
    expectLatPointWithinRange(s, 1, 6, DISJ2##Match(p0, p2), \
                              loopsToBits({{l0, t0}, {l0, t2}})); \
    expectLatPointWithinRange(s, 1, 6, DISJ1##Match(p0, p1), \
                              loopsToBits({{l0, t0}, {l0, t1}})); \
    expectLatPointWithinRange(s, 1, 6, p2, loopsToBits({{l0, t2}})); \
    expectLatPointWithinRange(s, 1, 6, p1, loopsToBits({{l0, t1}})); \
    expectLatPointWithinRange(s, 1, 6, p0, loopsToBits({{l0, t0}})); \
  }
FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(IMPL_MERGER_TEST_DISJ_DISJ)
#undef IMPL_MERGER_TEST_DISJ_DISJ
654
/// Vector multiplication (conjunction) then multiplication (conjunction), i.e.:
///   a(i) = b(i) * c(i) * d(i)
/// which should form
/// {
///   lat( i_00 i_01 i_02 / tensor_0 * tensor_1 * tensor_2 )
/// }
#define IMPL_MERGER_TEST_CONJ_CONJ(CONJ1, CONJ2) \
  TEST_P(MergerTest4T1L, vector_##CONJ1##_##CONJ2) { \
    const auto em = CONJ1##Expr(tensor(0), tensor(1)); \
    const auto e = CONJ2##Expr(em, tensor(2)); \
    const auto l0 = lid(0); \
    const auto t0 = tid(0); \
    const auto t1 = tid(1); \
    const auto t2 = tid(2); \
    const Match &p0 = tensorMatch(t0); \
    const Match &p1 = tensorMatch(t1); \
    const Match &p2 = tensorMatch(t2); \
    auto s = merger.buildLattices(e, l0); \
    expectNumLatPoints(s, 1); \
    expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2), \
                   loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
    s = merger.optimizeSet(s); \
    expectNumLatPoints(s, 1); \
    expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2), \
                   loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}), true); \
  }
FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ)
#undef IMPL_MERGER_TEST_CONJ_CONJ
683
684/// Vector addition (disjunction) of 2 vectors, i.e.;
685/// a(i) = b(i) + c(i)
686/// which should form the 3 lattice points
687/// {
688/// lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) )
689/// lat( i_00 / sparse_tensor_0 )
690/// lat( i_01 / dense_tensor_1 )
691/// }
692/// which should be optimized to
693/// {
694/// lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ) (not singleton)
695/// lat( i_01 / dense_tensor_0 ) (no sparse dimension)
696/// }
697///
698/// lat( i_00 / sparse_tensor_0 ) should be opted out as it only has dense diff
699/// with lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ).
700#define IMPL_MERGER_TEST_OPTIMIZED_DISJ(OP, UNUSED) \
701 TEST_P(MergerTest3T1LD, vector_opted_##OP) { \
702 const auto e = OP##Expr(tensor(0), tensor(1)); \
703 const auto l0 = lid(0); \
704 const auto t0 = tid(0); \
705 const auto t1 = tid(1); \
706 const Match &p0 = tensorMatch(t0); \
707 const Match &p1 = tensorMatch(t1); \
708 auto s = merger.buildLattices(e, l0); \
709 \
710 expectNumLatPoints(s, 3); \
711 expectLatPoint(s, 0, OP##Match(p0, p1), \
712 loopsToBits({{l0, t0}, {l0, t1}})); \
713 expectLatPointWithinRange(s, 1, 2, p0, loopsToBits({{l0, t0}})); \
714 expectLatPointWithinRange(s, 1, 2, p1, loopsToBits({{l0, t1}})); \
715 \
716 s = merger.optimizeSet(s); \
717 expectNumLatPoints(s, 2); \
718 expectLatPoint(s, 0, OP##Match(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}), \
719 true); \
720 expectLatPoint(s, 1, p1, loopsToBits({{l0, t1}}), true); \
721 }
722FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_DISJ)
723#undef IMPL_MERGER_TEST_OPTIMIZED_CONJ
724
/// Vector multiplication (conjunction) of 2 vectors, i.e.:
///   a(i) = b(i) * c(i)
/// which should form the single lattice point
/// {
///   lat( i_00 i_01 / (sparse_tensor_0 * dense_tensor_1) )
/// }
/// it should be optimized to
/// {
///   lat( i_00 / (sparse_tensor_0 * dense_tensor_1) )
/// }
/// since i_01 is a dense dimension.
#define IMPL_MERGER_TEST_OPTIMIZED_CONJ(OP, UNUSED) \
  TEST_P(MergerTest3T1LD, vector_opted_##OP) { \
    const auto e = OP##Expr(tensor(0), tensor(1)); \
    const auto l0 = lid(0); \
    const auto t0 = tid(0); \
    const auto t1 = tid(1); \
    const Match &p0 = tensorMatch(t0); \
    const Match &p1 = tensorMatch(t1); \
    auto s = merger.buildLattices(e, l0); \
    \
    expectNumLatPoints(s, 1); \
    expectLatPoint(s, 0, OP##Match(p0, p1), \
                   loopsToBits({{l0, t0}, {l0, t1}})); \
    \
    s = merger.optimizeSet(s); \
    expectNumLatPoints(s, 1); \
    expectLatPoint(s, 0, OP##Match(p0, p1), loopsToBits({{l0, t0}}), true); \
  }
FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_CONJ)
#undef IMPL_MERGER_TEST_OPTIMIZED_CONJ
756
757/// Vector element-wise comparison (disjunction) of 2 vectors. i.e.;
758/// a(i) = b(i) + c(i)
759/// which should form the 3 lattice points
760/// {
761/// lat( i_00 i_01 / (tensor_0 cmp tensor_1) )
762/// lat( i_00 / tensor_0 cmp 0 )
763/// lat( i_01 / 0 cmp tensor_1 )
764/// }
765/// and after optimization, the lattice points do not change (as there is no
766/// duplicated point and all input vectors are sparse vector).
767/// {
768/// lat( i_00 i_01 / (tensor_0 cmp tensor_1) )
769/// lat( i_00 / tensor_0 cmp 0 )
770/// lat( i_01 / 0 cmp tensor_1 )
771/// }
772TEST_P(MergerTest3T1L, vector_cmp) {
773 const auto e = cmpiExpr(e0: tensor(t: 0), e1: tensor(t: 1));
774 const auto l0 = lid(i: 0);
775 const auto t0 = tid(t: 0);
776 const auto t1 = tid(t: 1);
777 const Match &zero = synZeroMatch();
778 const Match &p0 = tensorMatch(tid: t0);
779 const Match &p1 = tensorMatch(tid: t1);
780 auto s = merger.buildLattices(e, i: l0);
781 expectLatPoint(s, lo: 0, pattern: cmpiMatch(e0: p0, e1: p1), bits: loopsToBits(loops: {{l0, t0}, {l0, t1}}));
782 expectLatPointWithinRange(s, lo: 1, n: 2, pattern: cmpiMatch(e0: p0, e1: zero),
783 bits: loopsToBits(loops: {{l0, t0}}));
784 expectLatPointWithinRange(s, lo: 1, n: 2, pattern: cmpiMatch(e0: zero, e1: p1),
785 bits: loopsToBits(loops: {{l0, t1}}));
786 s = merger.optimizeSet(s);
787 expectLatPoint(s, lo: 0, pattern: cmpiMatch(e0: p0, e1: p1), bits: loopsToBits(loops: {{l0, t0}, {l0, t1}}));
788 expectLatPointWithinRange(s, lo: 1, n: 2, pattern: cmpiMatch(e0: p0, e1: zero),
789 bits: loopsToBits(loops: {{l0, t0}}));
790 expectLatPointWithinRange(s, lo: 1, n: 2, pattern: cmpiMatch(e0: zero, e1: p1),
791 bits: loopsToBits(loops: {{l0, t1}}));
792}
793
794/// Vector element-wise comparsion (disjunction) of 2 vectors, i.e.;
795/// a(i) = b(i) cmp c(i)
796/// which should form the 3 lattice points
797/// {
798/// lat( i_00 i_01 / (sparse_tensor_0 cmp dense_tensor_1) )
799/// lat( i_00 / sparse_tensor_0 cmp 0)
800/// lat( i_01 / 0 cmp dense_tensor_1 )
801/// }
802/// which should be optimized to
803/// {
804/// lat( i_00 i_01 / (sparse_tensor_0 cmp dense_tensor_1) ) (not singleton)
805/// lat( i_01 / 0 cmp dense_tensor_0 ) ()
806/// }
807///
808/// lat( i_00 / sparse_tensor_0 ) should be opted out as it only has dense diff
809/// with lat( i_00 i_01 / (sparse_tensor_0 cmp dense_tensor_1) ).
810TEST_P(MergerTest3T1LD, vector_cmp) {
811 const auto e = cmpiExpr(e0: tensor(t: 0), e1: tensor(t: 1));
812 const auto l0 = lid(i: 0);
813 const auto t0 = tid(t: 0);
814 const auto t1 = tid(t: 1);
815 const Match &zero = synZeroMatch();
816 const Match &p0 = tensorMatch(tid: t0);
817 const Match &p1 = tensorMatch(tid: t1);
818 auto s = merger.buildLattices(e, i: l0);
819 expectLatPoint(s, lo: 0, pattern: cmpiMatch(e0: p0, e1: p1), bits: loopsToBits(loops: {{l0, t0}, {l0, t1}}));
820 expectLatPointWithinRange(s, lo: 1, n: 2, pattern: cmpiMatch(e0: p0, e1: zero),
821 bits: loopsToBits(loops: {{l0, t0}}));
822 expectLatPointWithinRange(s, lo: 1, n: 2, pattern: cmpiMatch(e0: zero, e1: p1),
823 bits: loopsToBits(loops: {{l0, t1}}));
824 s = merger.optimizeSet(s);
825 expectLatPoint(s, lo: 0, pattern: cmpiMatch(e0: p0, e1: p1), bits: loopsToBits(loops: {{l0, t0}, {l0, t1}}));
826 expectLatPointWithinRange(s, lo: 1, n: 2, pattern: cmpiMatch(e0: zero, e1: p1),
827 bits: loopsToBits(loops: {{l0, t1}}));
828}
829
