1 | //===- MergerTest.cpp - Tests for the sparsifier's merger -----------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/SparseTensor/Utils/Merger.h" |
10 | #include "llvm/Support/Compiler.h" |
11 | #include "gmock/gmock.h" |
12 | #include "gtest/gtest.h" |
13 | |
14 | #include <memory> |
15 | |
16 | using namespace mlir; |
17 | using namespace mlir::sparse_tensor; |
18 | |
19 | namespace { |
20 | |
///
/// Defines macros to iterate binary and the combination of binary operations.
///

// Applies DO(<builder-prefix>, <TensorExp kind>) to every binary operation
// exercised by these tests.
#define FOREVERY_BINOP(DO)                                                     \
  DO(mulf, TensorExp::Kind::kMulF)                                             \
  DO(mulc, TensorExp::Kind::kMulC)                                             \
  DO(muli, TensorExp::Kind::kMulI)                                             \
  DO(addf, TensorExp::Kind::kAddF)                                             \
  DO(addc, TensorExp::Kind::kAddC)                                             \
  DO(addi, TensorExp::Kind::kAddI)                                             \
  DO(subf, TensorExp::Kind::kSubF)                                             \
  DO(subc, TensorExp::Kind::kSubC)                                             \
  DO(subi, TensorExp::Kind::kSubI)                                             \
  DO(andi, TensorExp::Kind::kAndI)                                             \
  DO(xori, TensorExp::Kind::kXorI)                                             \
  DO(ori, TensorExp::Kind::kOrI)                                               \
  DO(cmpf, TensorExp::Kind::kCmpF)                                             \
  DO(cmpi, TensorExp::Kind::kCmpI)

// Iterates the common disjunctive (union-like) binary operations, forwarding
// EXTRA as the second argument of TEST.
#define FOREVERY_COMMON_DISJ_BINOP_EXTRA(TEST, EXTRA)                          \
  TEST(addf, EXTRA)                                                            \
  TEST(addc, EXTRA)                                                            \
  TEST(addi, EXTRA)                                                            \
  TEST(xori, EXTRA)                                                            \
  TEST(ori, EXTRA)

// Iterates the common conjunctive (intersection-like) binary operations,
// forwarding EXTRA as the second argument of TEST.
#define FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, EXTRA)                          \
  TEST(mulf, EXTRA)                                                            \
  TEST(mulc, EXTRA)                                                            \
  TEST(muli, EXTRA)                                                            \
  TEST(andi, EXTRA)

#define FOREVERY_COMMON_DISJ_BINOP(TEST)                                       \
  FOREVERY_COMMON_DISJ_BINOP_EXTRA(TEST, "")

#define FOREVERY_COMMON_CONJ_BINOP(TEST)                                       \
  FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, "")

// Iterates every (conjunctive, disjunctive) pair of common binary operations.
#define FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(TEST)                          \
  FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, addf)                                 \
  FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, addc)                                 \
  FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, addi)                                 \
  FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, xori)                                 \
  FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, ori)

// Iterates every (conjunctive, conjunctive) pair of common binary operations.
#define FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(TEST)                          \
  FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, mulf)                                 \
  FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, mulc)                                 \
  FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, muli)                                 \
  FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, andi)

// Iterates every (disjunctive, disjunctive) pair of common binary operations.
#define FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(TEST)                          \
  FOREVERY_COMMON_DISJ_BINOP_EXTRA(TEST, addf)                                 \
  FOREVERY_COMMON_DISJ_BINOP_EXTRA(TEST, addc)                                 \
  FOREVERY_COMMON_DISJ_BINOP_EXTRA(TEST, addi)                                 \
  FOREVERY_COMMON_DISJ_BINOP_EXTRA(TEST, ori)                                  \
  FOREVERY_COMMON_DISJ_BINOP_EXTRA(TEST, xori)
79 | |
80 | /// |
81 | /// Helper classes/functions for testing Merger. |
82 | /// |
83 | |
/// Simple recursive data structure used to match expressions in `Merger`,
/// which uses const references into the short-lived data structures.
struct Match {
  // Pair of child patterns for a binary (or unary) operation node. Holds
  // const references, so the referenced Matches must outlive this node.
  struct Children {
    Children(const Match &e0, const Match &e1) : e0(e0), e1(e1) {}
    const Match &e0;
    const Match &e1;
  };

  // Matches the synthetic-zero expression.
  Match() : kind(TensorExp::Kind::kSynZero) {}
  // Matches a tensor leaf with the given tensor id.
  Match(TensorId tid) : kind(TensorExp::Kind::kTensor), tid(tid) {}
  // Matches an operation node with the given kind and children; asserts the
  // kind is an operation kind (i.e., at least kMulF, past the leaf kinds).
  Match(TensorExp::Kind kind, const Match &e0, const Match &e1)
      : kind(kind), children(e0, e1) {
    assert(kind >= TensorExp::Kind::kMulF);
  }

  TensorExp::Kind kind;
  // Exactly one union member is active, selected by `kind` (tensor leaf vs.
  // operation node); kSynZero uses neither.
  union {
    TensorId tid;
    Children children;
  };
};
106 | |
///
/// Readable Match builder functions.
/// These should be preferred over the actual constructors.
///

// Builds a leaf pattern matching the given tensor id.
static Match tensorMatch(TensorId tid) { return Match(tid); }
// Builds a pattern matching the synthetic-zero expression.
static Match synZeroMatch() { return Match(); }

// Generates one `<op>Match(e0, e1)` builder per binary op in FOREVERY_BINOP.
#define IMPL_BINOP_PATTERN(OP, KIND)                                           \
  LLVM_ATTRIBUTE_UNUSED static Match OP##Match(const Match &e0,                \
                                               const Match &e1) {              \
    return Match(KIND, e0, e1);                                                \
  }
FOREVERY_BINOP(IMPL_BINOP_PATTERN)
#undef IMPL_BINOP_PATTERN
122 | |
123 | // Parameterize LevelFormat to test both Dense and Batch LevelFormat. |
124 | class MergerTestBase : public ::testing::TestWithParam<LevelFormat> { |
125 | protected: |
126 | MergerTestBase(unsigned numTensors, unsigned numLoops) |
127 | : merger(numTensors, numLoops, /*maxRank=*/numLoops) { |
128 | tensors.reserve(N: numTensors); |
129 | for (unsigned t = 0; t < numTensors; t++) |
130 | tensors.push_back(Elt: merger.addTensorExp(t: tid(t))); |
131 | } |
132 | |
133 | /// |
134 | /// Expression construction helpers. |
135 | /// |
136 | |
137 | TensorId tid(unsigned t) const { return merger.makeTensorId(t); } |
138 | LoopId lid(unsigned i) const { return merger.makeLoopId(i); } |
139 | ExprId tensor(unsigned t) const { |
140 | assert(t < tensors.size()); |
141 | return tensors[t]; |
142 | } |
143 | |
144 | #define IMPL_BINOP_EXPR(OP, KIND) \ |
145 | LLVM_ATTRIBUTE_UNUSED ExprId OP##Expr(ExprId e0, ExprId e1) { \ |
146 | return merger.addExp(KIND, e0, e1); \ |
147 | } |
148 | FOREVERY_BINOP(IMPL_BINOP_EXPR) |
149 | #undef IMPL_BINOP_EXPR |
150 | |
151 | /// |
152 | /// Comparison helpers. |
153 | /// |
154 | |
155 | /// Returns true if any lattice point with an expression matching |
156 | /// the given `pattern` and bits matching the given `bits` is present |
157 | /// in the `[lo, lo+n)` slice of the lattice set `s`. This is useful |
158 | /// for testing partial ordering constraints between lattice points. |
159 | /// We generally know how contiguous groups of lattice points should |
160 | /// be ordered with respect to other groups, but there is no required |
161 | /// ordering within groups. If `simple` is true, then compare the |
162 | /// `lat.simple` field instead to test the result after optimization. |
163 | bool latPointWithinRange(LatSetId s, unsigned lo, unsigned n, |
164 | const Match &pattern, const BitVector &bits, |
165 | bool simple) { |
166 | for (unsigned k = lo, hi = lo + n; k < hi; ++k) { |
167 | if (compareExpression(e: merger.lat(p: merger.set(s)[k]).exp, pattern) && |
168 | compareBits(s, k, bits, simple)) |
169 | return true; |
170 | } |
171 | return false; |
172 | } |
173 | |
174 | /// Wrapper over latPointWithinRange for readability of tests. |
175 | void expectLatPointWithinRange(LatSetId s, unsigned lo, unsigned n, |
176 | const Match &pattern, const BitVector &bits, |
177 | bool simple = false) { |
178 | EXPECT_TRUE(latPointWithinRange(s, lo, n, pattern, bits, simple)); |
179 | } |
180 | |
181 | /// Wrapper over expectLatPointWithinRange for a single lat point. |
182 | void expectLatPoint(LatSetId s, unsigned lo, const Match &pattern, |
183 | const BitVector &bits, bool simple = false) { |
184 | EXPECT_TRUE(latPointWithinRange(s, lo, 1, pattern, bits, simple)); |
185 | } |
186 | |
187 | /// Converts a vector of (loop, tensor) pairs to a bitvector with the |
188 | /// corresponding bits set. |
189 | BitVector loopsToBits(const std::vector<std::pair<LoopId, TensorId>> &loops) { |
190 | BitVector testBits = BitVector(merger.getNumTensors(), false); |
191 | for (auto [loop, tensor] : loops) |
192 | testBits.set(merger.makeTensorLoopId(t: tensor, i: loop)); |
193 | return testBits; |
194 | } |
195 | |
196 | /// Returns true if the bits of the `k`th point in set `s` matches |
197 | /// the given `bits`. If `simple` is true, then compares the `lat.simple` |
198 | /// field instead, to test the result after optimization |
199 | bool compareBits(LatSetId s, unsigned k, const BitVector &bits, bool simple) { |
200 | const auto &point = merger.lat(p: merger.set(s)[k]); |
201 | return (simple ? point.simple : point.bits) == bits; |
202 | } |
203 | |
204 | /// Check that there are n lattice points in set s. |
205 | void expectNumLatPoints(LatSetId s, unsigned n) { |
206 | EXPECT_THAT(merger.set(s).size(), n); |
207 | } |
208 | |
209 | /// Compares expressions for equality. Equality is defined recursively as: |
210 | /// - Operations are equal if they have the same kind and children. |
211 | /// - Leaf tensors are equal if they refer to the same tensor. |
212 | bool compareExpression(ExprId e, const Match &pattern) { |
213 | const auto &tensorExp = merger.exp(e); |
214 | if (tensorExp.kind != pattern.kind) |
215 | return false; |
216 | switch (tensorExp.kind) { |
217 | // Leaf. |
218 | case TensorExp::Kind::kTensor: |
219 | return tensorExp.tensor == pattern.tid; |
220 | case TensorExp::Kind::kSynZero: |
221 | // Already checked kind equivalence @L233 |
222 | return true; |
223 | case TensorExp::Kind::kInvariant: |
224 | llvm_unreachable("invariant not handled yet" ); |
225 | case TensorExp::Kind::kLoopVar: |
226 | llvm_unreachable("loop-variables not handled yet" ); |
227 | // Unary operations. |
228 | case TensorExp::Kind::kAbsF: |
229 | case TensorExp::Kind::kAbsC: |
230 | case TensorExp::Kind::kAbsI: |
231 | case TensorExp::Kind::kCeilF: |
232 | case TensorExp::Kind::kFloorF: |
233 | case TensorExp::Kind::kSqrtF: |
234 | case TensorExp::Kind::kSqrtC: |
235 | case TensorExp::Kind::kExpm1F: |
236 | case TensorExp::Kind::kExpm1C: |
237 | case TensorExp::Kind::kLog1pF: |
238 | case TensorExp::Kind::kLog1pC: |
239 | case TensorExp::Kind::kSinF: |
240 | case TensorExp::Kind::kSinC: |
241 | case TensorExp::Kind::kTanhF: |
242 | case TensorExp::Kind::kTanhC: |
243 | case TensorExp::Kind::kNegF: |
244 | case TensorExp::Kind::kNegC: |
245 | case TensorExp::Kind::kNegI: |
246 | case TensorExp::Kind::kTruncF: |
247 | case TensorExp::Kind::kExtF: |
248 | case TensorExp::Kind::kCastFS: |
249 | case TensorExp::Kind::kCastFU: |
250 | case TensorExp::Kind::kCastSF: |
251 | case TensorExp::Kind::kCastUF: |
252 | case TensorExp::Kind::kCastS: |
253 | case TensorExp::Kind::kCastU: |
254 | case TensorExp::Kind::kCastIdx: |
255 | case TensorExp::Kind::kTruncI: |
256 | case TensorExp::Kind::kCIm: |
257 | case TensorExp::Kind::kCRe: |
258 | case TensorExp::Kind::kBitCast: |
259 | case TensorExp::Kind::kSelect: |
260 | case TensorExp::Kind::kBinaryBranch: |
261 | case TensorExp::Kind::kUnary: |
262 | return compareExpression(e: tensorExp.children.e0, pattern: pattern.children.e0); |
263 | // Binary operations. |
264 | case TensorExp::Kind::kMulF: |
265 | case TensorExp::Kind::kMulC: |
266 | case TensorExp::Kind::kMulI: |
267 | case TensorExp::Kind::kDivF: |
268 | case TensorExp::Kind::kDivC: |
269 | case TensorExp::Kind::kDivS: |
270 | case TensorExp::Kind::kDivU: |
271 | case TensorExp::Kind::kAddF: |
272 | case TensorExp::Kind::kAddC: |
273 | case TensorExp::Kind::kAddI: |
274 | case TensorExp::Kind::kSubF: |
275 | case TensorExp::Kind::kSubC: |
276 | case TensorExp::Kind::kSubI: |
277 | case TensorExp::Kind::kAndI: |
278 | case TensorExp::Kind::kOrI: |
279 | case TensorExp::Kind::kXorI: |
280 | case TensorExp::Kind::kCmpF: |
281 | case TensorExp::Kind::kCmpI: |
282 | case TensorExp::Kind::kShrS: |
283 | case TensorExp::Kind::kShrU: |
284 | case TensorExp::Kind::kShlI: |
285 | case TensorExp::Kind::kBinary: |
286 | case TensorExp::Kind::kReduce: |
287 | return compareExpression(e: tensorExp.children.e0, pattern: pattern.children.e0) && |
288 | compareExpression(e: tensorExp.children.e1, pattern: pattern.children.e1); |
289 | case TensorExp::Kind::kDenseOp: { |
290 | bool eq = compareExpression(e: tensorExp.children.e0, pattern: pattern.children.e0); |
291 | if (eq && tensorExp.children.e1 != sparse_tensor::detail::kInvalidId) |
292 | return compareExpression(e: tensorExp.children.e1, pattern: pattern.children.e1); |
293 | return eq; |
294 | } |
295 | } |
296 | llvm_unreachable("unexpected kind" ); |
297 | } |
298 | |
299 | // This field is public for convenience. |
300 | Merger merger; |
301 | |
302 | private: |
303 | // This field is private to prevent mutation after the ctor. |
304 | SmallVector<ExprId> tensors; |
305 | }; |
306 | |
307 | /// |
308 | /// Tests with all sparse inputs. |
309 | /// |
310 | |
311 | /// Three tensors (two inputs, one output); and a single loop. |
312 | class MergerTest3T1L : public MergerTestBase { |
313 | protected: |
314 | MergerTest3T1L() : MergerTestBase(3, 1) { |
315 | EXPECT_TRUE(merger.getOutTensorID() == tid(2)); |
316 | // Tensor 0: sparse input vector. |
317 | merger.setLevelAndType(t: tid(t: 0), i: lid(i: 0), lvl: 0, lt: LevelFormat::Compressed); |
318 | // Tensor 1: sparse input vector. |
319 | merger.setLevelAndType(t: tid(t: 1), i: lid(i: 0), lvl: 0, lt: LevelFormat::Compressed); |
320 | // Tensor 2: dense output vector. |
321 | merger.setLevelAndType(t: tid(t: 2), i: lid(i: 0), lvl: 0, lt: GetParam()); |
322 | } |
323 | }; |
324 | |
325 | INSTANTIATE_TEST_SUITE_P(Test3T1L, MergerTest3T1L, |
326 | ::testing::Values(LevelFormat::Dense, |
327 | LevelFormat::Batch)); |
328 | |
329 | /// Four tensors (three inputs, one output); and a single loop. |
330 | class MergerTest4T1L : public MergerTestBase { |
331 | protected: |
332 | MergerTest4T1L() : MergerTestBase(4, 1) { |
333 | EXPECT_TRUE(merger.getOutTensorID() == tid(3)); |
334 | // Tensor 0: sparse input vector. |
335 | merger.setLevelAndType(t: tid(t: 0), i: lid(i: 0), lvl: 0, lt: LevelFormat::Compressed); |
336 | // Tensor 1: sparse input vector. |
337 | merger.setLevelAndType(t: tid(t: 1), i: lid(i: 0), lvl: 0, lt: LevelFormat::Compressed); |
338 | // Tensor 2: sparse input vector |
339 | merger.setLevelAndType(t: tid(t: 2), i: lid(i: 0), lvl: 0, lt: LevelFormat::Compressed); |
340 | // Tensor 3: dense output vector |
341 | merger.setLevelAndType(t: tid(t: 3), i: lid(i: 0), lvl: 0, lt: GetParam()); |
342 | } |
343 | }; |
344 | |
345 | INSTANTIATE_TEST_SUITE_P(Test4T1L, MergerTest4T1L, |
346 | ::testing::Values(LevelFormat::Dense, |
347 | LevelFormat::Batch)); |
348 | |
349 | /// |
350 | /// Tests with both sparse and dense input. |
351 | /// |
352 | |
353 | /// Three tensors (two inputs, one output); and a single loop. |
354 | class MergerTest3T1LD : public MergerTestBase { |
355 | protected: |
356 | MergerTest3T1LD() : MergerTestBase(3, 1) { |
357 | EXPECT_TRUE(merger.getOutTensorID() == tid(2)); |
358 | // Tensor 0: sparse input vector. |
359 | merger.setLevelAndType(t: tid(t: 0), i: lid(i: 0), lvl: 0, lt: LevelFormat::Compressed); |
360 | // Tensor 1: dense input vector. |
361 | merger.setLevelAndType(t: tid(t: 1), i: lid(i: 0), lvl: 0, lt: GetParam()); |
362 | // Tensor 2: dense output vector. |
363 | merger.setLevelAndType(t: tid(t: 2), i: lid(i: 0), lvl: 0, lt: GetParam()); |
364 | } |
365 | }; |
366 | |
367 | INSTANTIATE_TEST_SUITE_P(Test3T1LD, MergerTest3T1LD, |
368 | ::testing::Values(LevelFormat::Dense, |
369 | LevelFormat::Batch)); |
370 | |
371 | /// |
372 | /// Tests with both undef and dense input. |
373 | /// |
374 | |
375 | /// Three tensors (three inputs, one output); and a single loop. |
376 | class MergerTest4T1LU : public MergerTestBase { |
377 | protected: |
378 | MergerTest4T1LU() : MergerTestBase(4, 1) { |
379 | EXPECT_TRUE(merger.getOutTensorID() == tid(3)); |
380 | // Tensor 0: undef input vector. |
381 | merger.setLevelAndType(t: tid(t: 0), i: lid(i: 0), lvl: 0, lt: LevelFormat::Undef); |
382 | // Tensor 1: dense input vector. |
383 | merger.setLevelAndType(t: tid(t: 1), i: lid(i: 0), lvl: 0, lt: GetParam()); |
384 | // Tensor 2: undef input vector. |
385 | merger.setLevelAndType(t: tid(t: 2), i: lid(i: 0), lvl: 0, lt: LevelFormat::Undef); |
386 | // Tensor 3: dense output vector. |
387 | merger.setLevelAndType(t: tid(t: 3), i: lid(i: 0), lvl: 0, lt: GetParam()); |
388 | } |
389 | }; |
390 | |
391 | INSTANTIATE_TEST_SUITE_P(Test4T1LU, MergerTest4T1LU, |
392 | ::testing::Values(LevelFormat::Dense, |
393 | LevelFormat::Batch)); |
394 | |
395 | /// |
396 | /// Tests with operation on sparse output. |
397 | /// |
398 | |
399 | /// Three tensors (two inputs, one output, one synthetic); and a single loop. |
400 | class MergerTest3T1LSo : public MergerTestBase { |
401 | protected: |
402 | MergerTest3T1LSo() : MergerTestBase(3, 1) { |
403 | EXPECT_TRUE(merger.getOutTensorID() == tid(2)); |
404 | EXPECT_TRUE(merger.getSynTensorID() == tid(3)); |
405 | merger.setHasSparseOut(true); |
406 | // Tensor 0: undef input vector. |
407 | merger.setLevelAndType(t: tid(t: 0), i: lid(i: 0), lvl: 0, lt: LevelFormat::Undef); |
408 | // Tensor 1: undef input vector. |
409 | merger.setLevelAndType(t: tid(t: 1), i: lid(i: 0), lvl: 0, lt: LevelFormat::Undef); |
410 | // Tensor 2: sparse output vector. |
411 | merger.setLevelAndType(t: tid(t: 2), i: lid(i: 0), lvl: 0, lt: LevelFormat::Compressed); |
412 | } |
413 | }; |
414 | |
415 | // This testsuite does not use any dense-like format, just one of {Dense, Batch} |
416 | // is enough. |
417 | INSTANTIATE_TEST_SUITE_P(Test3T1LSo, MergerTest3T1LSo, |
418 | ::testing::Values(LevelFormat::Dense)); |
419 | |
420 | } // namespace |
421 | |
/// Vector multiplication (conjunction) of 3 vectors, i.e.;
///   a(i) = b(i) * c(i) * d(i)
/// which should form the single lattice point
/// {
///   lat( i_00_U i_01_D i_02_U / (tensor_0 * tensor_1 * tensor2) )
/// }
/// After optimization, the dense dimension should be kept, despite appearing
/// in the middle:
/// {
///   lat( i_01_D / (tensor_0 * tensor_1 * tensor2) )
/// }
#define IMPL_MERGER_TEST_CONJ_CONJ_UNDEF(CONJ1, CONJ2)                         \
  TEST_P(MergerTest4T1LU, vector_##CONJ1##_##CONJ2) {                          \
    const auto em = CONJ1##Expr(tensor(0), tensor(1));                         \
    const auto e = CONJ2##Expr(em, tensor(2));                                 \
    const auto l0 = lid(0);                                                    \
    const auto t0 = tid(0);                                                    \
    const auto t1 = tid(1);                                                    \
    const auto t2 = tid(2);                                                    \
    const Match &p0 = tensorMatch(t0);                                         \
    const Match &p1 = tensorMatch(t1);                                         \
    const Match &p2 = tensorMatch(t2);                                         \
    auto s = merger.buildLattices(e, l0);                                      \
    expectNumLatPoints(s, 1);                                                  \
    expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2),               \
                   loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
    s = merger.optimizeSet(s);                                                 \
    expectNumLatPoints(s, 1);                                                  \
    expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2),               \
                   loopsToBits({{l0, t1}}), true);                             \
  }
FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_UNDEF)
#undef IMPL_MERGER_TEST_CONJ_CONJ_UNDEF
455 | |
/// Vector multiplication (conjunction) of 2 vectors, i.e.;
///   o(i) = b(i) * c(i) * o(i)
/// which should form the single lattice point (note how a synthetic tensor
/// i_03_U is created for the sparse output)
/// {
///   lat( i_00_U i_01_U i_03_U / (tensor_0 * tensor_1 * output_tensor_2) )
/// }
/// after optimization, the synthetic tensor should be preserved.
/// {
///   lat( i_03_U / (tensor_0 * tensor_1 * output_tensor2) )
/// }
#define IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT(CONJ1, CONJ2)                    \
  TEST_P(MergerTest3T1LSo, vector_##CONJ1##_##CONJ2) {                         \
    const auto em = CONJ1##Expr(tensor(0), tensor(1));                         \
    const auto e = CONJ2##Expr(em, tensor(2));                                 \
    const auto l0 = lid(0);                                                    \
    const auto t0 = tid(0);                                                    \
    const auto t1 = tid(1);                                                    \
    const auto t2 = tid(2);                                                    \
    const auto t3 = tid(3);                                                    \
    const Match &p0 = tensorMatch(t0);                                         \
    const Match &p1 = tensorMatch(t1);                                         \
    const Match &p2 = tensorMatch(t2);                                         \
    auto s = merger.buildLattices(e, l0);                                      \
    expectNumLatPoints(s, 1);                                                  \
    expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2),               \
                   loopsToBits({{l0, t0}, {l0, t1}, {l0, t3}}));               \
    s = merger.optimizeSet(s);                                                 \
    expectNumLatPoints(s, 1);                                                  \
    expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2),               \
                   loopsToBits({{l0, t3}}), true);                             \
  }
FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT)
#undef IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT
490 | |
/// Vector addition (disjunction) of 2 vectors. i.e.;
///   a(i) = b(i) + c(i)
/// which should form the 3 lattice points
/// {
///   lat( i_00 i_01 / (tensor_0 + tensor_1) )
///   lat( i_00 / tensor_0 )
///   lat( i_01 / tensor_1 )
/// }
/// and after optimization, the lattice points do not change (as there is no
/// duplicated point and all input vectors are sparse vectors).
/// {
///   lat( i_00 i_01 / (tensor_0 + tensor_1) )
///   lat( i_00 / tensor_0 )
///   lat( i_01 / tensor_1 )
/// }
#define IMPL_MERGER_TEST_DISJ(OP, UNUSED)                                      \
  TEST_P(MergerTest3T1L, vector_##OP) {                                        \
    const auto e = OP##Expr(tensor(0), tensor(1));                             \
    const auto l0 = lid(0);                                                    \
    const auto t0 = tid(0);                                                    \
    const auto t1 = tid(1);                                                    \
    const Match &p0 = tensorMatch(t0);                                         \
    const Match &p1 = tensorMatch(t1);                                         \
    auto s = merger.buildLattices(e, l0);                                      \
                                                                               \
    expectNumLatPoints(s, 3);                                                  \
    expectLatPoint(s, 0, OP##Match(p0, p1),                                    \
                   loopsToBits({{l0, t0}, {l0, t1}}));                         \
    expectLatPointWithinRange(s, 1, 2, p0, loopsToBits({{l0, t0}}));           \
    expectLatPointWithinRange(s, 1, 2, p1, loopsToBits({{l0, t1}}));           \
                                                                               \
    s = merger.optimizeSet(s);                                                 \
    expectNumLatPoints(s, 3);                                                  \
    expectLatPoint(s, 0, OP##Match(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}), \
                   true);                                                      \
    expectLatPointWithinRange(s, 1, 2, p0, loopsToBits({{l0, t0}}), true);     \
    expectLatPointWithinRange(s, 1, 2, p1, loopsToBits({{l0, t1}}), true);     \
  }
FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_DISJ)
#undef IMPL_MERGER_TEST_DISJ
531 | |
/// Vector multiplication (conjunction) of 2 vectors, i.e.;
///   a(i) = b(i) * c(i)
/// which should form the single lattice point (unchanged by optimization,
/// since both inputs are sparse)
/// {
///   lat( i_00 i_01 / (tensor_0 * tensor_1) )
/// }
#define IMPL_MERGER_TEST_CONJ(OP, UNUSED)                                      \
  TEST_P(MergerTest3T1L, vector_##OP) {                                        \
    const auto e = OP##Expr(tensor(0), tensor(1));                             \
    const auto l0 = lid(0);                                                    \
    const auto t0 = tid(0);                                                    \
    const auto t1 = tid(1);                                                    \
    const Match &p0 = tensorMatch(t0);                                         \
    const Match &p1 = tensorMatch(t1);                                         \
    auto s = merger.buildLattices(e, l0);                                      \
                                                                               \
    expectNumLatPoints(s, 1);                                                  \
    expectLatPoint(s, 0, OP##Match(p0, p1),                                    \
                   loopsToBits({{l0, t0}, {l0, t1}}));                         \
                                                                               \
    s = merger.optimizeSet(s);                                                 \
    expectNumLatPoints(s, 1);                                                  \
    expectLatPoint(s, 0, OP##Match(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}), \
                   true);                                                      \
  }
FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_CONJ)
#undef IMPL_MERGER_TEST_CONJ
559 | |
/// Vector multiplication (conjunction) then addition (disjunction), i.e.;
///   a(i) = b(i) * c(i) + d(i);
/// which should form
/// {
///   lat( i_00 i_01 i_02 / (tensor_0 * tensor_1) + tensor_2 )
///   lat( i_00 i_01 / tensor_0 * tensor_1 )
///   lat( i_02 / tensor_2 )
/// }
#define IMPL_MERGER_TEST_CONJ_DISJ(CONJ, DISJ)                                 \
  TEST_P(MergerTest4T1L, vector_##CONJ##_##DISJ) {                             \
    const auto em = CONJ##Expr(tensor(0), tensor(1));                          \
    const auto e = DISJ##Expr(em, tensor(2));                                  \
    const auto l0 = lid(0);                                                    \
    const auto t0 = tid(0);                                                    \
    const auto t1 = tid(1);                                                    \
    const auto t2 = tid(2);                                                    \
    const Match &p0 = tensorMatch(t0);                                         \
    const Match &p1 = tensorMatch(t1);                                         \
    const Match &p2 = tensorMatch(t2);                                         \
    auto s = merger.buildLattices(e, l0);                                      \
                                                                               \
    expectNumLatPoints(s, 3);                                                  \
    expectLatPoint(s, 0, DISJ##Match(CONJ##Match(p0, p1), p2),                 \
                   loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
    expectLatPointWithinRange(s, 1, 2, CONJ##Match(p0, p1),                    \
                              loopsToBits({{l0, t0}, {l0, t1}}));              \
    expectLatPointWithinRange(s, 1, 2, p2, loopsToBits({{l0, t2}}));           \
                                                                               \
    s = merger.optimizeSet(s);                                                 \
    expectNumLatPoints(s, 3);                                                  \
    expectLatPoint(s, 0, DISJ##Match(CONJ##Match(p0, p1), p2),                 \
                   loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
    expectLatPointWithinRange(s, 1, 2, CONJ##Match(p0, p1),                    \
                              loopsToBits({{l0, t0}, {l0, t1}}));              \
    expectLatPointWithinRange(s, 1, 2, p2, loopsToBits({{l0, t2}}));           \
  }
FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(IMPL_MERGER_TEST_CONJ_DISJ)
#undef IMPL_MERGER_TEST_CONJ_DISJ
598 | |
/// Vector addition (disjunction) then addition (disjunction), i.e.;
///   a(i) = b(i) + c(i) + d(i)
/// which should form
/// {
///   lat( i_00 i_01 i_02 / (tensor_0 + tensor_1) + tensor_2 )
///   lat( i_02 i_01 / tensor_2 + tensor_1 )
///   lat( i_02 i_00 / tensor_2 + tensor_0 )
///   lat( i_01 i_00 / tensor_1 + tensor_0 )
///   lat( i_02 / tensor_2 )
///   lat( i_01 / tensor_1 )
///   lat( i_00 / tensor_0 )
/// }
/// Optimization does not change the set (all inputs are sparse).
#define IMPL_MERGER_TEST_DISJ_DISJ(DISJ1, DISJ2)                               \
  TEST_P(MergerTest4T1L, Vector_##DISJ1##_##DISJ2) {                           \
    const auto em = DISJ1##Expr(tensor(0), tensor(1));                         \
    const auto e = DISJ2##Expr(em, tensor(2));                                 \
    const auto l0 = lid(0);                                                    \
    const auto t0 = tid(0);                                                    \
    const auto t1 = tid(1);                                                    \
    const auto t2 = tid(2);                                                    \
    const Match &p0 = tensorMatch(t0);                                         \
    const Match &p1 = tensorMatch(t1);                                         \
    const Match &p2 = tensorMatch(t2);                                         \
    auto s = merger.buildLattices(e, l0);                                      \
                                                                               \
    expectNumLatPoints(s, 7);                                                  \
    expectLatPoint(s, 0, DISJ2##Match(DISJ1##Match(p0, p1), p2),               \
                   loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
    expectLatPointWithinRange(s, 1, 6, DISJ2##Match(p1, p2),                   \
                              loopsToBits({{l0, t1}, {l0, t2}}));              \
    expectLatPointWithinRange(s, 1, 6, DISJ2##Match(p0, p2),                   \
                              loopsToBits({{l0, t0}, {l0, t2}}));              \
    expectLatPointWithinRange(s, 1, 6, DISJ1##Match(p0, p1),                   \
                              loopsToBits({{l0, t0}, {l0, t1}}));              \
    expectLatPointWithinRange(s, 1, 6, p2, loopsToBits({{l0, t2}}));           \
    expectLatPointWithinRange(s, 1, 6, p1, loopsToBits({{l0, t1}}));           \
    expectLatPointWithinRange(s, 1, 6, p0, loopsToBits({{l0, t0}}));           \
                                                                               \
    s = merger.optimizeSet(s);                                                 \
    expectNumLatPoints(s, 7);                                                  \
    expectLatPoint(s, 0, DISJ2##Match(DISJ1##Match(p0, p1), p2),               \
                   loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
    expectLatPointWithinRange(s, 1, 6, DISJ2##Match(p1, p2),                   \
                              loopsToBits({{l0, t1}, {l0, t2}}));              \
    expectLatPointWithinRange(s, 1, 6, DISJ2##Match(p0, p2),                   \
                              loopsToBits({{l0, t0}, {l0, t2}}));              \
    expectLatPointWithinRange(s, 1, 6, DISJ1##Match(p0, p1),                   \
                              loopsToBits({{l0, t0}, {l0, t1}}));              \
    expectLatPointWithinRange(s, 1, 6, p2, loopsToBits({{l0, t2}}));           \
    expectLatPointWithinRange(s, 1, 6, p1, loopsToBits({{l0, t1}}));           \
    expectLatPointWithinRange(s, 1, 6, p0, loopsToBits({{l0, t0}}));           \
  }
FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(IMPL_MERGER_TEST_DISJ_DISJ)
#undef IMPL_MERGER_TEST_DISJ_DISJ
653 | |
/// Vector multiplication (conjunction) then multiplication (conjunction), i.e.;
///   a(i) = b(i) * c(i) * d(i);
/// which should form (unchanged by optimization, since all inputs are sparse)
/// {
///   lat( i_00 i_01 i_02 / tensor_0 * tensor_1 * tensor_2 )
/// }
#define IMPL_MERGER_TEST_CONJ_CONJ(CONJ1, CONJ2)                               \
  TEST_P(MergerTest4T1L, vector_##CONJ1##_##CONJ2) {                           \
    const auto em = CONJ1##Expr(tensor(0), tensor(1));                         \
    const auto e = CONJ2##Expr(em, tensor(2));                                 \
    const auto l0 = lid(0);                                                    \
    const auto t0 = tid(0);                                                    \
    const auto t1 = tid(1);                                                    \
    const auto t2 = tid(2);                                                    \
    const Match &p0 = tensorMatch(t0);                                         \
    const Match &p1 = tensorMatch(t1);                                         \
    const Match &p2 = tensorMatch(t2);                                         \
    auto s = merger.buildLattices(e, l0);                                      \
    expectNumLatPoints(s, 1);                                                  \
    expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2),               \
                   loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
    s = merger.optimizeSet(s);                                                 \
    expectNumLatPoints(s, 1);                                                  \
    expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2),               \
                   loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}), true);         \
  }
FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ)
#undef IMPL_MERGER_TEST_CONJ_CONJ
682 | |
683 | /// Vector addition (disjunction) of 2 vectors, i.e.; |
684 | /// a(i) = b(i) + c(i) |
685 | /// which should form the 3 lattice points |
686 | /// { |
687 | /// lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ) |
688 | /// lat( i_00 / sparse_tensor_0 ) |
689 | /// lat( i_01 / dense_tensor_1 ) |
690 | /// } |
691 | /// which should be optimized to |
692 | /// { |
693 | /// lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ) (not singleton) |
694 | /// lat( i_01 / dense_tensor_0 ) (no sparse dimension) |
695 | /// } |
696 | /// |
697 | /// lat( i_00 / sparse_tensor_0 ) should be opted out as it only has dense diff |
698 | /// with lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ). |
699 | #define IMPL_MERGER_TEST_OPTIMIZED_DISJ(OP, UNUSED) \ |
700 | TEST_P(MergerTest3T1LD, vector_opted_##OP) { \ |
701 | const auto e = OP##Expr(tensor(0), tensor(1)); \ |
702 | const auto l0 = lid(0); \ |
703 | const auto t0 = tid(0); \ |
704 | const auto t1 = tid(1); \ |
705 | const Match &p0 = tensorMatch(t0); \ |
706 | const Match &p1 = tensorMatch(t1); \ |
707 | auto s = merger.buildLattices(e, l0); \ |
708 | \ |
709 | expectNumLatPoints(s, 3); \ |
710 | expectLatPoint(s, 0, OP##Match(p0, p1), \ |
711 | loopsToBits({{l0, t0}, {l0, t1}})); \ |
712 | expectLatPointWithinRange(s, 1, 2, p0, loopsToBits({{l0, t0}})); \ |
713 | expectLatPointWithinRange(s, 1, 2, p1, loopsToBits({{l0, t1}})); \ |
714 | \ |
715 | s = merger.optimizeSet(s); \ |
716 | expectNumLatPoints(s, 2); \ |
717 | expectLatPoint(s, 0, OP##Match(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}), \ |
718 | true); \ |
719 | expectLatPoint(s, 1, p1, loopsToBits({{l0, t1}}), true); \ |
720 | } |
721 | FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_DISJ) |
722 | #undef IMPL_MERGER_TEST_OPTIMIZED_CONJ |
723 | |
/// Vector multiplication (conjunction) of 2 vectors, i.e.:
///   a(i) = b(i) * c(i)
/// which should form the single lattice point
/// {
///   lat( i_00 i_01 / (sparse_tensor_0 * dense_tensor_1) )
/// }
/// it should be optimized to
/// {
///   lat( i_00 / (sparse_tensor_0 * dense_tensor_1) )
/// }
/// since i_01 is a dense dimension.
#define IMPL_MERGER_TEST_OPTIMIZED_CONJ(OP, UNUSED)                            \
  TEST_P(MergerTest3T1LD, vector_opted_##OP) {                                 \
    const auto e = OP##Expr(tensor(0), tensor(1));                             \
    const auto l0 = lid(0);                                                    \
    const auto t0 = tid(0);                                                    \
    const auto t1 = tid(1);                                                    \
    const Match &p0 = tensorMatch(t0);                                         \
    const Match &p1 = tensorMatch(t1);                                         \
    auto s = merger.buildLattices(e, l0);                                      \
                                                                               \
    expectNumLatPoints(s, 1);                                                  \
    expectLatPoint(s, 0, OP##Match(p0, p1),                                    \
                   loopsToBits({{l0, t0}, {l0, t1}}));                         \
                                                                               \
    s = merger.optimizeSet(s);                                                 \
    expectNumLatPoints(s, 1);                                                  \
    expectLatPoint(s, 0, OP##Match(p0, p1), loopsToBits({{l0, t0}}), true);    \
  }
FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_CONJ)
#undef IMPL_MERGER_TEST_OPTIMIZED_CONJ
755 | |
756 | /// Vector element-wise comparison (disjunction) of 2 vectors. i.e.; |
757 | /// a(i) = b(i) + c(i) |
758 | /// which should form the 3 lattice points |
759 | /// { |
760 | /// lat( i_00 i_01 / (tensor_0 cmp tensor_1) ) |
761 | /// lat( i_00 / tensor_0 cmp 0 ) |
762 | /// lat( i_01 / 0 cmp tensor_1 ) |
763 | /// } |
764 | /// and after optimization, the lattice points do not change (as there is no |
765 | /// duplicated point and all input vectors are sparse vector). |
766 | /// { |
767 | /// lat( i_00 i_01 / (tensor_0 cmp tensor_1) ) |
768 | /// lat( i_00 / tensor_0 cmp 0 ) |
769 | /// lat( i_01 / 0 cmp tensor_1 ) |
770 | /// } |
771 | TEST_P(MergerTest3T1L, vector_cmp) { |
772 | const auto e = cmpiExpr(e0: tensor(t: 0), e1: tensor(t: 1)); |
773 | const auto l0 = lid(i: 0); |
774 | const auto t0 = tid(t: 0); |
775 | const auto t1 = tid(t: 1); |
776 | const Match &zero = synZeroMatch(); |
777 | const Match &p0 = tensorMatch(tid: t0); |
778 | const Match &p1 = tensorMatch(tid: t1); |
779 | auto s = merger.buildLattices(e, i: l0); |
780 | expectLatPoint(s, lo: 0, pattern: cmpiMatch(e0: p0, e1: p1), bits: loopsToBits(loops: {{l0, t0}, {l0, t1}})); |
781 | expectLatPointWithinRange(s, lo: 1, n: 2, pattern: cmpiMatch(e0: p0, e1: zero), |
782 | bits: loopsToBits(loops: {{l0, t0}})); |
783 | expectLatPointWithinRange(s, lo: 1, n: 2, pattern: cmpiMatch(e0: zero, e1: p1), |
784 | bits: loopsToBits(loops: {{l0, t1}})); |
785 | s = merger.optimizeSet(s); |
786 | expectLatPoint(s, lo: 0, pattern: cmpiMatch(e0: p0, e1: p1), bits: loopsToBits(loops: {{l0, t0}, {l0, t1}})); |
787 | expectLatPointWithinRange(s, lo: 1, n: 2, pattern: cmpiMatch(e0: p0, e1: zero), |
788 | bits: loopsToBits(loops: {{l0, t0}})); |
789 | expectLatPointWithinRange(s, lo: 1, n: 2, pattern: cmpiMatch(e0: zero, e1: p1), |
790 | bits: loopsToBits(loops: {{l0, t1}})); |
791 | } |
792 | |
793 | /// Vector element-wise comparsion (disjunction) of 2 vectors, i.e.; |
794 | /// a(i) = b(i) cmp c(i) |
795 | /// which should form the 3 lattice points |
796 | /// { |
797 | /// lat( i_00 i_01 / (sparse_tensor_0 cmp dense_tensor_1) ) |
798 | /// lat( i_00 / sparse_tensor_0 cmp 0) |
799 | /// lat( i_01 / 0 cmp dense_tensor_1 ) |
800 | /// } |
801 | /// which should be optimized to |
802 | /// { |
803 | /// lat( i_00 i_01 / (sparse_tensor_0 cmp dense_tensor_1) ) (not singleton) |
804 | /// lat( i_01 / 0 cmp dense_tensor_0 ) () |
805 | /// } |
806 | /// |
807 | /// lat( i_00 / sparse_tensor_0 ) should be opted out as it only has dense diff |
808 | /// with lat( i_00 i_01 / (sparse_tensor_0 cmp dense_tensor_1) ). |
809 | TEST_P(MergerTest3T1LD, vector_cmp) { |
810 | const auto e = cmpiExpr(e0: tensor(t: 0), e1: tensor(t: 1)); |
811 | const auto l0 = lid(i: 0); |
812 | const auto t0 = tid(t: 0); |
813 | const auto t1 = tid(t: 1); |
814 | const Match &zero = synZeroMatch(); |
815 | const Match &p0 = tensorMatch(tid: t0); |
816 | const Match &p1 = tensorMatch(tid: t1); |
817 | auto s = merger.buildLattices(e, i: l0); |
818 | expectLatPoint(s, lo: 0, pattern: cmpiMatch(e0: p0, e1: p1), bits: loopsToBits(loops: {{l0, t0}, {l0, t1}})); |
819 | expectLatPointWithinRange(s, lo: 1, n: 2, pattern: cmpiMatch(e0: p0, e1: zero), |
820 | bits: loopsToBits(loops: {{l0, t0}})); |
821 | expectLatPointWithinRange(s, lo: 1, n: 2, pattern: cmpiMatch(e0: zero, e1: p1), |
822 | bits: loopsToBits(loops: {{l0, t1}})); |
823 | s = merger.optimizeSet(s); |
824 | expectLatPoint(s, lo: 0, pattern: cmpiMatch(e0: p0, e1: p1), bits: loopsToBits(loops: {{l0, t0}, {l0, t1}})); |
825 | expectLatPointWithinRange(s, lo: 1, n: 2, pattern: cmpiMatch(e0: zero, e1: p1), |
826 | bits: loopsToBits(loops: {{l0, t1}})); |
827 | } |
828 | |