1 | //===- AffineExprTest.cpp - unit tests for affine expression API ----------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include <cstdint> |
10 | #include <limits> |
11 | |
12 | #include "mlir/IR/AffineExpr.h" |
13 | #include "mlir/IR/Builders.h" |
14 | #include "gtest/gtest.h" |
15 | |
16 | using namespace mlir; |
17 | |
18 | static std::string toString(AffineExpr expr) { |
19 | std::string s; |
20 | llvm::raw_string_ostream ss(s); |
21 | ss << expr; |
22 | return s; |
23 | } |
24 | |
25 | // Test creating AffineExprs using the overloaded binary operators. |
26 | TEST(AffineExprTest, constructFromBinaryOperators) { |
27 | MLIRContext ctx; |
28 | OpBuilder b(&ctx); |
29 | |
30 | auto d0 = b.getAffineDimExpr(position: 0); |
31 | auto d1 = b.getAffineDimExpr(position: 1); |
32 | |
33 | auto sum = d0 + d1; |
34 | auto difference = d0 - d1; |
35 | auto product = d0 * d1; |
36 | auto remainder = d0 % d1; |
37 | |
38 | ASSERT_EQ(sum.getKind(), AffineExprKind::Add); |
39 | ASSERT_EQ(difference.getKind(), AffineExprKind::Add); |
40 | ASSERT_EQ(product.getKind(), AffineExprKind::Mul); |
41 | ASSERT_EQ(remainder.getKind(), AffineExprKind::Mod); |
42 | } |
43 | |
44 | TEST(AffineExprTest, constantFolding) { |
45 | MLIRContext ctx; |
46 | OpBuilder b(&ctx); |
47 | auto cn1 = b.getAffineConstantExpr(constant: -1); |
48 | auto c0 = b.getAffineConstantExpr(constant: 0); |
49 | auto c1 = b.getAffineConstantExpr(constant: 1); |
50 | auto c2 = b.getAffineConstantExpr(constant: 2); |
51 | auto c3 = b.getAffineConstantExpr(constant: 3); |
52 | auto c6 = b.getAffineConstantExpr(constant: 6); |
53 | auto cmax = b.getAffineConstantExpr(constant: std::numeric_limits<int64_t>::max()); |
54 | auto cmin = b.getAffineConstantExpr(constant: std::numeric_limits<int64_t>::min()); |
55 | |
56 | ASSERT_EQ(getAffineBinaryOpExpr(AffineExprKind::Add, c1, c2), c3); |
57 | ASSERT_EQ(getAffineBinaryOpExpr(AffineExprKind::Mul, c2, c3), c6); |
58 | ASSERT_EQ(getAffineBinaryOpExpr(AffineExprKind::FloorDiv, c3, c2), c1); |
59 | ASSERT_EQ(getAffineBinaryOpExpr(AffineExprKind::CeilDiv, c3, c2), c2); |
60 | |
61 | // Test division by zero: |
62 | auto c3ceildivc0 = getAffineBinaryOpExpr(kind: AffineExprKind::CeilDiv, lhs: c3, rhs: c0); |
63 | ASSERT_EQ(c3ceildivc0.getKind(), AffineExprKind::CeilDiv); |
64 | |
65 | auto c3floordivc0 = getAffineBinaryOpExpr(kind: AffineExprKind::FloorDiv, lhs: c3, rhs: c0); |
66 | ASSERT_EQ(c3floordivc0.getKind(), AffineExprKind::FloorDiv); |
67 | |
68 | auto c3modc0 = getAffineBinaryOpExpr(kind: AffineExprKind::Mod, lhs: c3, rhs: c0); |
69 | ASSERT_EQ(c3modc0.getKind(), AffineExprKind::Mod); |
70 | |
71 | // Test overflow: |
72 | auto cmaxplusc1 = getAffineBinaryOpExpr(kind: AffineExprKind::Add, lhs: cmax, rhs: c1); |
73 | ASSERT_EQ(cmaxplusc1.getKind(), AffineExprKind::Add); |
74 | |
75 | auto cmaxtimesc2 = getAffineBinaryOpExpr(kind: AffineExprKind::Mul, lhs: cmax, rhs: c2); |
76 | ASSERT_EQ(cmaxtimesc2.getKind(), AffineExprKind::Mul); |
77 | |
78 | auto cminceildivcn1 = |
79 | getAffineBinaryOpExpr(kind: AffineExprKind::CeilDiv, lhs: cmin, rhs: cn1); |
80 | ASSERT_EQ(cminceildivcn1.getKind(), AffineExprKind::CeilDiv); |
81 | |
82 | auto cminfloordivcn1 = |
83 | getAffineBinaryOpExpr(kind: AffineExprKind::FloorDiv, lhs: cmin, rhs: cn1); |
84 | ASSERT_EQ(cminfloordivcn1.getKind(), AffineExprKind::FloorDiv); |
85 | } |
86 | |
87 | TEST(AffineExprTest, divisionSimplification) { |
88 | MLIRContext ctx; |
89 | OpBuilder b(&ctx); |
90 | auto cn6 = b.getAffineConstantExpr(constant: -6); |
91 | auto c6 = b.getAffineConstantExpr(constant: 6); |
92 | auto d0 = b.getAffineDimExpr(position: 0); |
93 | auto d1 = b.getAffineDimExpr(position: 1); |
94 | |
95 | ASSERT_EQ(c6.floorDiv(-1), cn6); |
96 | ASSERT_EQ((d0 * 6).floorDiv(2), d0 * 3); |
97 | ASSERT_EQ((d0 * 6).floorDiv(4).getKind(), AffineExprKind::FloorDiv); |
98 | ASSERT_EQ((d0 * 6).floorDiv(-2), d0 * -3); |
99 | ASSERT_EQ((d0 * 6 + d1).floorDiv(2), d0 * 3 + d1.floorDiv(2)); |
100 | ASSERT_EQ((d0 * 6 + d1).floorDiv(-2), d0 * -3 + d1.floorDiv(-2)); |
101 | ASSERT_EQ((d0 * 6 + d1).floorDiv(4).getKind(), AffineExprKind::FloorDiv); |
102 | |
103 | ASSERT_EQ(c6.ceilDiv(-1), cn6); |
104 | ASSERT_EQ((d0 * 6).ceilDiv(2), d0 * 3); |
105 | ASSERT_EQ((d0 * 6).ceilDiv(4).getKind(), AffineExprKind::CeilDiv); |
106 | ASSERT_EQ((d0 * 6).ceilDiv(-2), d0 * -3); |
107 | } |
108 | |
109 | TEST(AffineExprTest, modSimplificationRegression) { |
110 | MLIRContext ctx; |
111 | OpBuilder b(&ctx); |
112 | auto d0 = b.getAffineDimExpr(position: 0); |
113 | auto sum = d0 + d0.floorDiv(v: 3).floorDiv(v: -3); |
114 | ASSERT_EQ(sum.getKind(), AffineExprKind::Add); |
115 | } |
116 | |
117 | TEST(AffineExprTest, divisorOfNegativeFloorDiv) { |
118 | MLIRContext ctx; |
119 | OpBuilder b(&ctx); |
120 | ASSERT_EQ(b.getAffineDimExpr(0).floorDiv(-1).getLargestKnownDivisor(), 1); |
121 | } |
122 | |
123 | TEST(AffineExprTest, d0PlusD0FloorDivNeg2) { |
124 | // Regression test for a bug where this was rewritten to d0 mod -2. We do not |
125 | // support a negative RHS for mod in LowerAffinePass. |
126 | MLIRContext ctx; |
127 | OpBuilder b(&ctx); |
128 | auto d0 = b.getAffineDimExpr(position: 0); |
129 | auto sum = d0 + d0.floorDiv(v: -2) * 2; |
130 | ASSERT_EQ(toString(sum), "d0 + (d0 floordiv -2) * 2" ); |
131 | } |
132 | |
133 | TEST(AffineExprTest, simpleAffineExprFlattenerRegression) { |
134 | |
135 | // Regression test for a bug where mod simplification was not handled |
136 | // properly when `lhs % rhs` was happened to have the property that `lhs |
137 | // floordiv rhs = lhs`. |
138 | MLIRContext ctx; |
139 | OpBuilder b(&ctx); |
140 | |
141 | auto d0 = b.getAffineDimExpr(position: 0); |
142 | |
143 | // Manually replace variables by constants to avoid constant folding. |
144 | AffineExpr expr = (d0 - (d0 + 2)).floorDiv(v: 8) % 8; |
145 | AffineExpr result = mlir::simplifyAffineExpr(expr, numDims: 1, numSymbols: 0); |
146 | |
147 | ASSERT_TRUE(isa<AffineConstantExpr>(result)); |
148 | ASSERT_EQ(cast<AffineConstantExpr>(result).getValue(), 7); |
149 | } |
150 | |