1//===- PWMAFunctionTest.cpp - Tests for PWMAFunction ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains tests for PWMAFunction.
10//
11//===----------------------------------------------------------------------===//
12
13#include "Parser.h"
14
15#include "mlir/Analysis/Presburger/PWMAFunction.h"
16#include "mlir/Analysis/Presburger/PresburgerRelation.h"
17#include "mlir/IR/MLIRContext.h"
18
19#include <gmock/gmock.h>
20#include <gtest/gtest.h>
21
22using namespace mlir;
23using namespace presburger;
24
25using testing::ElementsAre;
26
27TEST(PWAFunctionTest, isEqual) {
28 // The output expressions are different but it doesn't matter because they are
29 // equal in this domain.
30 PWMAFunction idAtZeros =
31 parsePWMAF(pieces: {{"(x, y) : (y == 0)", "(x, y) -> (x, y)"},
32 {"(x, y) : (y - 1 >= 0, x == 0)", "(x, y) -> (x, y)"},
33 {"(x, y) : (-y - 1 >= 0, x == 0)", "(x, y) -> (x, y)"}});
34 PWMAFunction idAtZeros2 =
35 parsePWMAF(pieces: {{"(x, y) : (y == 0)", "(x, y) -> (x, 20*y)"},
36 {"(x, y) : (y - 1 >= 0, x == 0)", "(x, y) -> (30*x, y)"},
37 {"(x, y) : (-y - 1 > =0, x == 0)", "(x, y) -> (30*x, y)"}});
38 EXPECT_TRUE(idAtZeros.isEqual(idAtZeros2));
39
40 PWMAFunction notIdAtZeros = parsePWMAF(pieces: {
41 {"(x, y) : (y == 0)", "(x, y) -> (x, y)"},
42 {"(x, y) : (y - 1 >= 0, x == 0)", "(x, y) -> (x, 2*y)"},
43 {"(x, y) : (-y - 1 >= 0, x == 0)", "(x, y) -> (x, 2*y)"},
44 });
45 EXPECT_FALSE(idAtZeros.isEqual(notIdAtZeros));
46
47 // These match at their intersection but one has a bigger domain.
48 PWMAFunction idNoNegNegQuadrant =
49 parsePWMAF(pieces: {{"(x, y) : (x >= 0)", "(x, y) -> (x, y)"},
50 {"(x, y) : (-x - 1 >= 0, y >= 0)", "(x, y) -> (x, y)"}});
51 PWMAFunction idOnlyPosX = parsePWMAF(pieces: {
52 {"(x, y) : (x >= 0)", "(x, y) -> (x, y)"},
53 });
54 EXPECT_FALSE(idNoNegNegQuadrant.isEqual(idOnlyPosX));
55
56 // Different representations of the same domain.
57 PWMAFunction sumPlusOne = parsePWMAF(pieces: {
58 {"(x, y) : (x >= 0)", "(x, y) -> (x + y + 1)"},
59 {"(x, y) : (-x - 1 >= 0, -y - 1 >= 0)", "(x, y) -> (x + y + 1)"},
60 {"(x, y) : (-x - 1 >= 0, y >= 0)", "(x, y) -> (x + y + 1)"},
61 });
62 PWMAFunction sumPlusOne2 = parsePWMAF(pieces: {
63 {"(x, y) : ()", "(x, y) -> (x + y + 1)"},
64 });
65 EXPECT_TRUE(sumPlusOne.isEqual(sumPlusOne2));
66
67 // Functions with zero input dimensions.
68 PWMAFunction noInputs1 = parsePWMAF(pieces: {
69 {"() : ()", "() -> (1)"},
70 });
71 PWMAFunction noInputs2 = parsePWMAF(pieces: {
72 {"() : ()", "() -> (2)"},
73 });
74 EXPECT_TRUE(noInputs1.isEqual(noInputs1));
75 EXPECT_FALSE(noInputs1.isEqual(noInputs2));
76
77 // Mismatched dimensionalities.
78 EXPECT_FALSE(noInputs1.isEqual(sumPlusOne));
79 EXPECT_FALSE(idOnlyPosX.isEqual(sumPlusOne));
80
81 // Divisions.
82 // Domain is only multiples of 6; x = 6k for some k.
83 // x + 4(x/2) + 4(x/3) == 26k.
84 PWMAFunction mul2AndMul3 = parsePWMAF(pieces: {
85 {"(x) : (x - 2*(x floordiv 2) == 0, x - 3*(x floordiv 3) == 0)",
86 "(x) -> (x + 4 * (x floordiv 2) + 4 * (x floordiv 3))"},
87 });
88 PWMAFunction mul6 = parsePWMAF(pieces: {
89 {"(x) : (x - 6*(x floordiv 6) == 0)", "(x) -> (26 * (x floordiv 6))"},
90 });
91 EXPECT_TRUE(mul2AndMul3.isEqual(mul6));
92
93 PWMAFunction mul6diff = parsePWMAF(pieces: {
94 {"(x) : (x - 5*(x floordiv 5) == 0)", "(x) -> (52 * (x floordiv 6))"},
95 });
96 EXPECT_FALSE(mul2AndMul3.isEqual(mul6diff));
97
98 PWMAFunction mul5 = parsePWMAF(pieces: {
99 {"(x) : (x - 5*(x floordiv 5) == 0)", "(x) -> (26 * (x floordiv 5))"},
100 });
101 EXPECT_FALSE(mul2AndMul3.isEqual(mul5));
102}
103
104TEST(PWMAFunction, valueAt) {
105 PWMAFunction nonNegPWMAF = parsePWMAF(
106 pieces: {{"(x, y) : (x >= 0)", "(x, y) -> (x + 2*y + 3, 3*x + 4*y + 5)"},
107 {"(x, y) : (y >= 0, -x - 1 >= 0)",
108 "(x, y) -> (-x + 2*y + 3, -3*x + 4*y + 5)"}});
109 EXPECT_THAT(*nonNegPWMAF.valueAt({2, 3}), ElementsAre(11, 23));
110 EXPECT_THAT(*nonNegPWMAF.valueAt({-2, 3}), ElementsAre(11, 23));
111 EXPECT_THAT(*nonNegPWMAF.valueAt({2, -3}), ElementsAre(-1, -1));
112 EXPECT_FALSE(nonNegPWMAF.valueAt({-2, -3}).has_value());
113
114 PWMAFunction divPWMAF = parsePWMAF(
115 pieces: {{"(x, y) : (x >= 0, x - 2*(x floordiv 2) == 0)",
116 "(x, y) -> (2*y + (x floordiv 2) + 3, 4*y + 3*(x floordiv 2) + 5)"},
117 {"(x, y) : (y >= 0, -x - 1 >= 0)",
118 "(x, y) -> (-x + 2*y + 3, -3*x + 4*y + 5)"}});
119 EXPECT_THAT(*divPWMAF.valueAt({4, 3}), ElementsAre(11, 23));
120 EXPECT_THAT(*divPWMAF.valueAt({4, -3}), ElementsAre(-1, -1));
121 EXPECT_FALSE(divPWMAF.valueAt({3, 3}).has_value());
122 EXPECT_FALSE(divPWMAF.valueAt({3, -3}).has_value());
123
124 EXPECT_THAT(*divPWMAF.valueAt({-2, 3}), ElementsAre(11, 23));
125 EXPECT_FALSE(divPWMAF.valueAt({-2, -3}).has_value());
126}
127
128TEST(PWMAFunction, removeIdRangeRegressionTest) {
129 PWMAFunction pwmafA = parsePWMAF(pieces: {
130 {"(x, y) : (x == 0, y == 0, x - 2*(x floordiv 2) == 0, y - 2*(y floordiv "
131 "2) == 0)",
132 "(x, y) -> (0, 0)"},
133 });
134 PWMAFunction pwmafB = parsePWMAF(pieces: {
135 {"(x, y) : (x - 11*y == 0, 11*x - y == 0, x - 2*(x floordiv 2) == 0, "
136 "y - 2*(y floordiv 2) == 0)",
137 "(x, y) -> (0, 0)"},
138 });
139 EXPECT_TRUE(pwmafA.isEqual(pwmafB));
140}
141
142TEST(PWMAFunction, eliminateRedundantLocalIdRegressionTest) {
143 PWMAFunction pwmafA = parsePWMAF(pieces: {
144 {"(x, y) : (x - 2*(x floordiv 2) == 0, x - 2*y == 0)", "(x, y) -> (y)"},
145 });
146 PWMAFunction pwmafB = parsePWMAF(pieces: {
147 {"(x, y) : (x - 2*(x floordiv 2) == 0, x - 2*y == 0)",
148 "(x, y) -> (x - y)"},
149 });
150 EXPECT_TRUE(pwmafA.isEqual(pwmafB));
151}
152
153TEST(PWMAFunction, unionLexMaxSimple) {
154 // func2 is better than func1, but func2's domain is empty.
155 {
156 PWMAFunction func1 = parsePWMAF(pieces: {
157 {"(x) : ()", "(x) -> (1)"},
158 });
159
160 PWMAFunction func2 = parsePWMAF(pieces: {
161 {"(x) : (1 == 0)", "(x) -> (2)"},
162 });
163
164 EXPECT_TRUE(func1.unionLexMax(func2).isEqual(func1));
165 EXPECT_TRUE(func2.unionLexMax(func1).isEqual(func1));
166 }
167
168 // func2 is better than func1 on a subset of func1.
169 {
170 PWMAFunction func1 = parsePWMAF(pieces: {
171 {"(x) : ()", "(x) -> (1)"},
172 });
173
174 PWMAFunction func2 = parsePWMAF(pieces: {
175 {"(x) : (x >= 0, 10 - x >= 0)", "(x) -> (2)"},
176 });
177
178 PWMAFunction result = parsePWMAF(pieces: {
179 {"(x) : (-1 - x >= 0)", "(x) -> (1)"},
180 {"(x) : (x >= 0, 10 - x >= 0)", "(x) -> (2)"},
181 {"(x) : (x - 11 >= 0)", "(x) -> (1)"},
182 });
183
184 EXPECT_TRUE(func1.unionLexMax(func2).isEqual(result));
185 EXPECT_TRUE(func2.unionLexMax(func1).isEqual(result));
186 }
187
188 // func1 and func2 are defined over the whole domain with different outputs.
189 {
190 PWMAFunction func1 = parsePWMAF(pieces: {
191 {"(x) : ()", "(x) -> (x)"},
192 });
193
194 PWMAFunction func2 = parsePWMAF(pieces: {
195 {"(x) : ()", "(x) -> (-x)"},
196 });
197
198 PWMAFunction result = parsePWMAF(pieces: {
199 {"(x) : (x >= 0)", "(x) -> (x)"},
200 {"(x) : (-1 - x >= 0)", "(x) -> (-x)"},
201 });
202
203 EXPECT_TRUE(func1.unionLexMax(func2).isEqual(result));
204 EXPECT_TRUE(func2.unionLexMax(func1).isEqual(result));
205 }
206
207 // func1 and func2 have disjoint domains.
208 {
209 PWMAFunction func1 = parsePWMAF(pieces: {
210 {"(x) : (x >= 0, 10 - x >= 0)", "(x) -> (1)"},
211 {"(x) : (x - 71 >= 0, 80 - x >= 0)", "(x) -> (1)"},
212 });
213
214 PWMAFunction func2 = parsePWMAF(pieces: {
215 {"(x) : (x - 20 >= 0, 41 - x >= 0)", "(x) -> (2)"},
216 {"(x) : (x - 101 >= 0, 120 - x >= 0)", "(x) -> (2)"},
217 });
218
219 PWMAFunction result = parsePWMAF(pieces: {
220 {"(x) : (x >= 0, 10 - x >= 0)", "(x) -> (1)"},
221 {"(x) : (x - 71 >= 0, 80 - x >= 0)", "(x) -> (1)"},
222 {"(x) : (x - 20 >= 0, 41 - x >= 0)", "(x) -> (2)"},
223 {"(x) : (x - 101 >= 0, 120 - x >= 0)", "(x) -> (2)"},
224 });
225
226 EXPECT_TRUE(func1.unionLexMin(func2).isEqual(result));
227 EXPECT_TRUE(func2.unionLexMin(func1).isEqual(result));
228 }
229}
230
231TEST(PWMAFunction, unionLexMinSimple) {
232 // func2 is better than func1, but func2's domain is empty.
233 {
234 PWMAFunction func1 = parsePWMAF(pieces: {
235 {"(x) : ()", "(x) -> (-1)"},
236 });
237
238 PWMAFunction func2 = parsePWMAF(pieces: {
239 {"(x) : (1 == 0)", "(x) -> (-2)"},
240 });
241
242 EXPECT_TRUE(func1.unionLexMin(func2).isEqual(func1));
243 EXPECT_TRUE(func2.unionLexMin(func1).isEqual(func1));
244 }
245
246 // func2 is better than func1 on a subset of func1.
247 {
248 PWMAFunction func1 = parsePWMAF(pieces: {
249 {"(x) : ()", "(x) -> (-1)"},
250 });
251
252 PWMAFunction func2 = parsePWMAF(pieces: {
253 {"(x) : (x >= 0, 10 - x >= 0)", "(x) -> (-2)"},
254 });
255
256 PWMAFunction result = parsePWMAF(pieces: {
257 {"(x) : (-1 - x >= 0)", "(x) -> (-1)"},
258 {"(x) : (x >= 0, 10 - x >= 0)", "(x) -> (-2)"},
259 {"(x) : (x - 11 >= 0)", "(x) -> (-1)"},
260 });
261
262 EXPECT_TRUE(func1.unionLexMin(func2).isEqual(result));
263 EXPECT_TRUE(func2.unionLexMin(func1).isEqual(result));
264 }
265
266 // func1 and func2 are defined over the whole domain with different outputs.
267 {
268 PWMAFunction func1 = parsePWMAF(pieces: {
269 {"(x) : ()", "(x) -> (-x)"},
270 });
271
272 PWMAFunction func2 = parsePWMAF(pieces: {
273 {"(x) : ()", "(x) -> (x)"},
274 });
275
276 PWMAFunction result = parsePWMAF(pieces: {
277 {"(x) : (x >= 0)", "(x) -> (-x)"},
278 {"(x) : (-1 - x >= 0)", "(x) -> (x)"},
279 });
280
281 EXPECT_TRUE(func1.unionLexMin(func2).isEqual(result));
282 EXPECT_TRUE(func2.unionLexMin(func1).isEqual(result));
283 }
284}
285
286TEST(PWMAFunction, unionLexMaxComplex) {
287 // Union of function containing 4 different pieces of output.
288 //
289 // x >= 21 --> func1 (func2 not defined)
290 // x <= 0 --> func2 (func1 not defined)
291 // 10 <= x <= 20, y > 0 --> func1 (x + y > x - y for y > 0)
292 // 10 <= x <= 20, y <= 0 --> func2 (x + y <= x - y for y <= 0)
293 {
294 PWMAFunction func1 = parsePWMAF(pieces: {
295 {"(x, y) : (x >= 10)", "(x, y) -> (x + y)"},
296 });
297
298 PWMAFunction func2 = parsePWMAF(pieces: {
299 {"(x, y) : (x <= 20)", "(x, y) -> (x - y)"},
300 });
301
302 PWMAFunction result = parsePWMAF(pieces: {
303 {"(x, y) : (x >= 10, x <= 20, y >= 1)", "(x, y) -> (x + y)"},
304 {"(x, y) : (x >= 21)", "(x, y) -> (x + y)"},
305 {"(x, y) : (x <= 9)", "(x, y) -> (x - y)"},
306 {"(x, y) : (x >= 10, x <= 20, y <= 0)", "(x, y) -> (x - y)"},
307 });
308
309 EXPECT_TRUE(func1.unionLexMax(func2).isEqual(result));
310 }
311
312 // Functions with more than one output, with contribution from both functions.
313 //
314 // If y >= 1, func1 is better because in the first output,
315 // x + y (func1) > x (func2), when y >= 1
316 //
317 // If y == 0, the first output is same for both functions, so we look at the
318 // second output. -2x + 4 (func1) > 2x - 2 (func2) when 0 <= x <= 1, so we
319 // take func1 for this domain and func2 for the remaining.
320 {
321 PWMAFunction func1 = parsePWMAF(pieces: {
322 {"(x, y) : (x >= 0, y >= 0)", "(x, y) -> (x + y, -2*x + 4)"},
323 });
324
325 PWMAFunction func2 = parsePWMAF(pieces: {
326 {"(x, y) : (x >= 0, y >= 0)", "(x, y) -> (x, 2*x - 2)"},
327 });
328
329 PWMAFunction result = parsePWMAF(pieces: {
330 {"(x, y) : (x >= 0, y >= 1)", "(x, y) -> (x + y, -2*x + 4)"},
331 {"(x, y) : (x >= 0, x <= 1, y == 0)", "(x, y) -> (x + y, -2*x + 4)"},
332 {"(x, y) : (x >= 2, y == 0)", "(x, y) -> (x, 2*x - 2)"},
333 });
334
335 EXPECT_TRUE(func1.unionLexMax(func2).isEqual(result));
336 EXPECT_TRUE(func2.unionLexMax(func1).isEqual(result));
337 }
338
339 // Function with three boolean variables `a, b, c` used to control which
340 // output will be taken lexicographically.
341 //
342 // a == 1 --> Take func2
343 // a == 0, b == 1 --> Take func1
344 // a == 0, b == 0, c == 1 --> Take func2
345 {
346 PWMAFunction func1 = parsePWMAF(pieces: {
347 {"(a, b, c) : (a >= 0, 1 - a >= 0, b >= 0, 1 - b >= 0, c "
348 ">= 0, 1 - c >= 0)",
349 "(a, b, c) -> (0, b, 0)"},
350 });
351
352 PWMAFunction func2 = parsePWMAF(pieces: {
353 {"(a, b, c) : (a >= 0, 1 - a >= 0, b >= 0, 1 - b >= 0, c >= 0, 1 - "
354 "c >= 0)",
355 "(a, b, c) -> (a, 0, c)"},
356 });
357
358 PWMAFunction result = parsePWMAF(pieces: {
359 {"(a, b, c) : (a - 1 == 0, b >= 0, 1 - b >= 0, c >= 0, 1 - c >= 0)",
360 "(a, b, c) -> (a, 0, c)"},
361 {"(a, b, c) : (a == 0, b - 1 == 0, c >= 0, 1 - c >= 0)",
362 "(a, b, c) -> (0, b, 0)"},
363 {"(a, b, c) : (a == 0, b == 0, c >= 0, 1 - c >= 0)",
364 "(a, b, c) -> (a, 0, c)"},
365 });
366
367 EXPECT_TRUE(func1.unionLexMax(func2).isEqual(result));
368 EXPECT_TRUE(func2.unionLexMax(func1).isEqual(result));
369 }
370}
371
372TEST(PWMAFunction, unionLexMinComplex) {
373 // Regression test checking if lexicographic tiebreak produces disjoint
374 // domains.
375 //
376 // If x == 1, func1 is better since in the first output,
377 // -x (func1) is < 0 (func2) when x == 1.
378 //
379 // If x == 0, func1 and func2 both have the same first output. So we take a
380 // look at the second output. func2 is better since in the second output,
381 // y - 1 (func2) is < y (func1).
382 PWMAFunction func1 = parsePWMAF(pieces: {
383 {"(x, y) : (x >= 0, x <= 1, y >= 0, y <= 1)", "(x, y) -> (-x, y)"},
384 });
385
386 PWMAFunction func2 = parsePWMAF(pieces: {
387 {"(x, y) : (x >= 0, x <= 1, y >= 0, y <= 1)", "(x, y) -> (0, y - 1)"},
388 });
389
390 PWMAFunction result = parsePWMAF(pieces: {
391 {"(x, y) : (x == 1, y >= 0, y <= 1)", "(x, y) -> (-x, y)"},
392 {"(x, y) : (x == 0, y >= 0, y <= 1)", "(x, y) -> (0, y - 1)"},
393 });
394
395 EXPECT_TRUE(func1.unionLexMin(func2).isEqual(result));
396 EXPECT_TRUE(func2.unionLexMin(func1).isEqual(result));
397}
398
399TEST(PWMAFunction, unionLexMinWithDivs) {
400 {
401 PWMAFunction func1 = parsePWMAF(pieces: {
402 {"(x, y) : (x mod 5 == 0)", "(x, y) -> (x, 1)"},
403 });
404
405 PWMAFunction func2 = parsePWMAF(pieces: {
406 {"(x, y) : (x mod 7 == 0)", "(x, y) -> (x + y, 2)"},
407 });
408
409 PWMAFunction result = parsePWMAF(pieces: {
410 {"(x, y) : (x mod 5 == 0, x mod 7 >= 1)", "(x, y) -> (x, 1)"},
411 {"(x, y) : (x mod 7 == 0, x mod 5 >= 1)", "(x, y) -> (x + y, 2)"},
412 {"(x, y) : (x mod 5 == 0, x mod 7 == 0, y >= 0)", "(x, y) -> (x, 1)"},
413 {"(x, y) : (x mod 7 == 0, x mod 5 == 0, y <= -1)",
414 "(x, y) -> (x + y, 2)"},
415 });
416
417 EXPECT_TRUE(func1.unionLexMin(func2).isEqual(result));
418 }
419
420 {
421 PWMAFunction func1 = parsePWMAF(pieces: {
422 {"(x) : (x >= 0, x <= 1000)", "(x) -> (x floordiv 16)"},
423 });
424
425 PWMAFunction func2 = parsePWMAF(pieces: {
426 {"(x) : (x >= 0, x <= 1000)", "(x) -> ((x + 10) floordiv 17)"},
427 });
428
429 PWMAFunction result = parsePWMAF(pieces: {
430 {"(x) : (x >= 0, x <= 1000, x floordiv 16 <= (x + 10) floordiv 17)",
431 "(x) -> (x floordiv 16)"},
432 {"(x) : (x >= 0, x <= 1000, x floordiv 16 >= (x + 10) floordiv 17 + 1)",
433 "(x) -> ((x + 10) floordiv 17)"},
434 });
435
436 EXPECT_TRUE(func1.unionLexMin(func2).isEqual(result));
437 }
438}
439

// Source: mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp