//===- AffineExprTest.cpp - unit tests for affine expression API ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "gtest/gtest.h"

#include <cstdint>
#include <limits>

using namespace mlir;
17
18static std::string toString(AffineExpr expr) {
19 std::string s;
20 llvm::raw_string_ostream ss(s);
21 ss << expr;
22 return s;
23}
24
25// Test creating AffineExprs using the overloaded binary operators.
26TEST(AffineExprTest, constructFromBinaryOperators) {
27 MLIRContext ctx;
28 OpBuilder b(&ctx);
29
30 auto d0 = b.getAffineDimExpr(position: 0);
31 auto d1 = b.getAffineDimExpr(position: 1);
32
33 auto sum = d0 + d1;
34 auto difference = d0 - d1;
35 auto product = d0 * d1;
36 auto remainder = d0 % d1;
37
38 ASSERT_EQ(sum.getKind(), AffineExprKind::Add);
39 ASSERT_EQ(difference.getKind(), AffineExprKind::Add);
40 ASSERT_EQ(product.getKind(), AffineExprKind::Mul);
41 ASSERT_EQ(remainder.getKind(), AffineExprKind::Mod);
42}
43
44TEST(AffineExprTest, constantFolding) {
45 MLIRContext ctx;
46 OpBuilder b(&ctx);
47 auto cn1 = b.getAffineConstantExpr(constant: -1);
48 auto c0 = b.getAffineConstantExpr(constant: 0);
49 auto c1 = b.getAffineConstantExpr(constant: 1);
50 auto c2 = b.getAffineConstantExpr(constant: 2);
51 auto c3 = b.getAffineConstantExpr(constant: 3);
52 auto c6 = b.getAffineConstantExpr(constant: 6);
53 auto cmax = b.getAffineConstantExpr(constant: std::numeric_limits<int64_t>::max());
54 auto cmin = b.getAffineConstantExpr(constant: std::numeric_limits<int64_t>::min());
55
56 ASSERT_EQ(getAffineBinaryOpExpr(AffineExprKind::Add, c1, c2), c3);
57 ASSERT_EQ(getAffineBinaryOpExpr(AffineExprKind::Mul, c2, c3), c6);
58 ASSERT_EQ(getAffineBinaryOpExpr(AffineExprKind::FloorDiv, c3, c2), c1);
59 ASSERT_EQ(getAffineBinaryOpExpr(AffineExprKind::CeilDiv, c3, c2), c2);
60
61 // Test division by zero:
62 auto c3ceildivc0 = getAffineBinaryOpExpr(kind: AffineExprKind::CeilDiv, lhs: c3, rhs: c0);
63 ASSERT_EQ(c3ceildivc0.getKind(), AffineExprKind::CeilDiv);
64
65 auto c3floordivc0 = getAffineBinaryOpExpr(kind: AffineExprKind::FloorDiv, lhs: c3, rhs: c0);
66 ASSERT_EQ(c3floordivc0.getKind(), AffineExprKind::FloorDiv);
67
68 auto c3modc0 = getAffineBinaryOpExpr(kind: AffineExprKind::Mod, lhs: c3, rhs: c0);
69 ASSERT_EQ(c3modc0.getKind(), AffineExprKind::Mod);
70
71 // Test overflow:
72 auto cmaxplusc1 = getAffineBinaryOpExpr(kind: AffineExprKind::Add, lhs: cmax, rhs: c1);
73 ASSERT_EQ(cmaxplusc1.getKind(), AffineExprKind::Add);
74
75 auto cmaxtimesc2 = getAffineBinaryOpExpr(kind: AffineExprKind::Mul, lhs: cmax, rhs: c2);
76 ASSERT_EQ(cmaxtimesc2.getKind(), AffineExprKind::Mul);
77
78 auto cminceildivcn1 =
79 getAffineBinaryOpExpr(kind: AffineExprKind::CeilDiv, lhs: cmin, rhs: cn1);
80 ASSERT_EQ(cminceildivcn1.getKind(), AffineExprKind::CeilDiv);
81
82 auto cminfloordivcn1 =
83 getAffineBinaryOpExpr(kind: AffineExprKind::FloorDiv, lhs: cmin, rhs: cn1);
84 ASSERT_EQ(cminfloordivcn1.getKind(), AffineExprKind::FloorDiv);
85}
86
87TEST(AffineExprTest, commutative) {
88 MLIRContext ctx;
89 OpBuilder b(&ctx);
90 auto c2 = b.getAffineConstantExpr(constant: 1);
91 auto d0 = b.getAffineDimExpr(position: 0);
92 auto d1 = b.getAffineDimExpr(position: 1);
93 auto s0 = b.getAffineSymbolExpr(position: 0);
94 auto s1 = b.getAffineSymbolExpr(position: 1);
95
96 ASSERT_EQ(d0 * d1, d1 * d0);
97 ASSERT_EQ(s0 + s1, s1 + s0);
98 ASSERT_EQ(s0 * c2, c2 * s0);
99}
100
101TEST(AffineExprTest, divisionSimplification) {
102 MLIRContext ctx;
103 OpBuilder b(&ctx);
104 auto cn6 = b.getAffineConstantExpr(constant: -6);
105 auto c6 = b.getAffineConstantExpr(constant: 6);
106 auto d0 = b.getAffineDimExpr(position: 0);
107 auto d1 = b.getAffineDimExpr(position: 1);
108
109 ASSERT_EQ(c6.floorDiv(-1), cn6);
110 ASSERT_EQ((d0 * 6).floorDiv(2), d0 * 3);
111 ASSERT_EQ((d0 * 6).floorDiv(4).getKind(), AffineExprKind::FloorDiv);
112 ASSERT_EQ((d0 * 6).floorDiv(-2), d0 * -3);
113 ASSERT_EQ((d0 * 6 + d1).floorDiv(2), d0 * 3 + d1.floorDiv(2));
114 ASSERT_EQ((d0 * 6 + d1).floorDiv(-2), d0 * -3 + d1.floorDiv(-2));
115 ASSERT_EQ((d0 * 6 + d1).floorDiv(4).getKind(), AffineExprKind::FloorDiv);
116
117 ASSERT_EQ(c6.ceilDiv(-1), cn6);
118 ASSERT_EQ((d0 * 6).ceilDiv(2), d0 * 3);
119 ASSERT_EQ((d0 * 6).ceilDiv(4).getKind(), AffineExprKind::CeilDiv);
120 ASSERT_EQ((d0 * 6).ceilDiv(-2), d0 * -3);
121}
122
123TEST(AffineExprTest, modSimplificationRegression) {
124 MLIRContext ctx;
125 OpBuilder b(&ctx);
126 auto d0 = b.getAffineDimExpr(position: 0);
127 auto sum = d0 + d0.floorDiv(v: 3).floorDiv(v: -3);
128 ASSERT_EQ(sum.getKind(), AffineExprKind::Add);
129}
130
131TEST(AffineExprTest, divisorOfNegativeFloorDiv) {
132 MLIRContext ctx;
133 OpBuilder b(&ctx);
134 ASSERT_EQ(b.getAffineDimExpr(0).floorDiv(-1).getLargestKnownDivisor(), 1);
135}
136
137TEST(AffineExprTest, d0PlusD0FloorDivNeg2) {
138 // Regression test for a bug where this was rewritten to d0 mod -2. We do not
139 // support a negative RHS for mod in LowerAffinePass.
140 MLIRContext ctx;
141 OpBuilder b(&ctx);
142 auto d0 = b.getAffineDimExpr(position: 0);
143 auto sum = d0 + d0.floorDiv(v: -2) * 2;
144 ASSERT_EQ(toString(sum), "d0 + (d0 floordiv -2) * 2");
145}
146
147TEST(AffineExprTest, simpleAffineExprFlattenerRegression) {
148
149 // Regression test for a bug where mod simplification was not handled
150 // properly when `lhs % rhs` was happened to have the property that `lhs
151 // floordiv rhs = lhs`.
152 MLIRContext ctx;
153 OpBuilder b(&ctx);
154
155 auto d0 = b.getAffineDimExpr(position: 0);
156
157 // Manually replace variables by constants to avoid constant folding.
158 AffineExpr expr = (d0 - (d0 + 2)).floorDiv(v: 8) % 8;
159 AffineExpr result = mlir::simplifyAffineExpr(expr, numDims: 1, numSymbols: 0);
160
161 ASSERT_TRUE(isa<AffineConstantExpr>(result));
162 ASSERT_EQ(cast<AffineConstantExpr>(result).getValue(), 7);
163}
164
165TEST(AffineExprTest, simplifyCommutative) {
166 MLIRContext ctx;
167 OpBuilder b(&ctx);
168 auto s0 = b.getAffineSymbolExpr(position: 0);
169 auto s1 = b.getAffineSymbolExpr(position: 1);
170
171 ASSERT_EQ(s0 * s1 - s1 * s0 + 1, 1);
172}
173

// Source code of mlir/unittests/IR/AffineExprTest.cpp