| 1 | //===- AffineExprTest.cpp - unit tests for affine expression API ----------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include <cstdint> |
| 10 | #include <limits> |
| 11 | |
| 12 | #include "mlir/IR/AffineExpr.h" |
| 13 | #include "mlir/IR/Builders.h" |
| 14 | #include "gtest/gtest.h" |
| 15 | |
| 16 | using namespace mlir; |
| 17 | |
| 18 | static std::string toString(AffineExpr expr) { |
| 19 | std::string s; |
| 20 | llvm::raw_string_ostream ss(s); |
| 21 | ss << expr; |
| 22 | return s; |
| 23 | } |
| 24 | |
| 25 | // Test creating AffineExprs using the overloaded binary operators. |
| 26 | TEST(AffineExprTest, constructFromBinaryOperators) { |
| 27 | MLIRContext ctx; |
| 28 | OpBuilder b(&ctx); |
| 29 | |
| 30 | auto d0 = b.getAffineDimExpr(position: 0); |
| 31 | auto d1 = b.getAffineDimExpr(position: 1); |
| 32 | |
| 33 | auto sum = d0 + d1; |
| 34 | auto difference = d0 - d1; |
| 35 | auto product = d0 * d1; |
| 36 | auto remainder = d0 % d1; |
| 37 | |
| 38 | ASSERT_EQ(sum.getKind(), AffineExprKind::Add); |
| 39 | ASSERT_EQ(difference.getKind(), AffineExprKind::Add); |
| 40 | ASSERT_EQ(product.getKind(), AffineExprKind::Mul); |
| 41 | ASSERT_EQ(remainder.getKind(), AffineExprKind::Mod); |
| 42 | } |
| 43 | |
| 44 | TEST(AffineExprTest, constantFolding) { |
| 45 | MLIRContext ctx; |
| 46 | OpBuilder b(&ctx); |
| 47 | auto cn1 = b.getAffineConstantExpr(constant: -1); |
| 48 | auto c0 = b.getAffineConstantExpr(constant: 0); |
| 49 | auto c1 = b.getAffineConstantExpr(constant: 1); |
| 50 | auto c2 = b.getAffineConstantExpr(constant: 2); |
| 51 | auto c3 = b.getAffineConstantExpr(constant: 3); |
| 52 | auto c6 = b.getAffineConstantExpr(constant: 6); |
| 53 | auto cmax = b.getAffineConstantExpr(constant: std::numeric_limits<int64_t>::max()); |
| 54 | auto cmin = b.getAffineConstantExpr(constant: std::numeric_limits<int64_t>::min()); |
| 55 | |
| 56 | ASSERT_EQ(getAffineBinaryOpExpr(AffineExprKind::Add, c1, c2), c3); |
| 57 | ASSERT_EQ(getAffineBinaryOpExpr(AffineExprKind::Mul, c2, c3), c6); |
| 58 | ASSERT_EQ(getAffineBinaryOpExpr(AffineExprKind::FloorDiv, c3, c2), c1); |
| 59 | ASSERT_EQ(getAffineBinaryOpExpr(AffineExprKind::CeilDiv, c3, c2), c2); |
| 60 | |
| 61 | // Test division by zero: |
| 62 | auto c3ceildivc0 = getAffineBinaryOpExpr(kind: AffineExprKind::CeilDiv, lhs: c3, rhs: c0); |
| 63 | ASSERT_EQ(c3ceildivc0.getKind(), AffineExprKind::CeilDiv); |
| 64 | |
| 65 | auto c3floordivc0 = getAffineBinaryOpExpr(kind: AffineExprKind::FloorDiv, lhs: c3, rhs: c0); |
| 66 | ASSERT_EQ(c3floordivc0.getKind(), AffineExprKind::FloorDiv); |
| 67 | |
| 68 | auto c3modc0 = getAffineBinaryOpExpr(kind: AffineExprKind::Mod, lhs: c3, rhs: c0); |
| 69 | ASSERT_EQ(c3modc0.getKind(), AffineExprKind::Mod); |
| 70 | |
| 71 | // Test overflow: |
| 72 | auto cmaxplusc1 = getAffineBinaryOpExpr(kind: AffineExprKind::Add, lhs: cmax, rhs: c1); |
| 73 | ASSERT_EQ(cmaxplusc1.getKind(), AffineExprKind::Add); |
| 74 | |
| 75 | auto cmaxtimesc2 = getAffineBinaryOpExpr(kind: AffineExprKind::Mul, lhs: cmax, rhs: c2); |
| 76 | ASSERT_EQ(cmaxtimesc2.getKind(), AffineExprKind::Mul); |
| 77 | |
| 78 | auto cminceildivcn1 = |
| 79 | getAffineBinaryOpExpr(kind: AffineExprKind::CeilDiv, lhs: cmin, rhs: cn1); |
| 80 | ASSERT_EQ(cminceildivcn1.getKind(), AffineExprKind::CeilDiv); |
| 81 | |
| 82 | auto cminfloordivcn1 = |
| 83 | getAffineBinaryOpExpr(kind: AffineExprKind::FloorDiv, lhs: cmin, rhs: cn1); |
| 84 | ASSERT_EQ(cminfloordivcn1.getKind(), AffineExprKind::FloorDiv); |
| 85 | } |
| 86 | |
| 87 | TEST(AffineExprTest, divisionSimplification) { |
| 88 | MLIRContext ctx; |
| 89 | OpBuilder b(&ctx); |
| 90 | auto cn6 = b.getAffineConstantExpr(constant: -6); |
| 91 | auto c6 = b.getAffineConstantExpr(constant: 6); |
| 92 | auto d0 = b.getAffineDimExpr(position: 0); |
| 93 | auto d1 = b.getAffineDimExpr(position: 1); |
| 94 | |
| 95 | ASSERT_EQ(c6.floorDiv(-1), cn6); |
| 96 | ASSERT_EQ((d0 * 6).floorDiv(2), d0 * 3); |
| 97 | ASSERT_EQ((d0 * 6).floorDiv(4).getKind(), AffineExprKind::FloorDiv); |
| 98 | ASSERT_EQ((d0 * 6).floorDiv(-2), d0 * -3); |
| 99 | ASSERT_EQ((d0 * 6 + d1).floorDiv(2), d0 * 3 + d1.floorDiv(2)); |
| 100 | ASSERT_EQ((d0 * 6 + d1).floorDiv(-2), d0 * -3 + d1.floorDiv(-2)); |
| 101 | ASSERT_EQ((d0 * 6 + d1).floorDiv(4).getKind(), AffineExprKind::FloorDiv); |
| 102 | |
| 103 | ASSERT_EQ(c6.ceilDiv(-1), cn6); |
| 104 | ASSERT_EQ((d0 * 6).ceilDiv(2), d0 * 3); |
| 105 | ASSERT_EQ((d0 * 6).ceilDiv(4).getKind(), AffineExprKind::CeilDiv); |
| 106 | ASSERT_EQ((d0 * 6).ceilDiv(-2), d0 * -3); |
| 107 | } |
| 108 | |
| 109 | TEST(AffineExprTest, modSimplificationRegression) { |
| 110 | MLIRContext ctx; |
| 111 | OpBuilder b(&ctx); |
| 112 | auto d0 = b.getAffineDimExpr(position: 0); |
| 113 | auto sum = d0 + d0.floorDiv(v: 3).floorDiv(v: -3); |
| 114 | ASSERT_EQ(sum.getKind(), AffineExprKind::Add); |
| 115 | } |
| 116 | |
| 117 | TEST(AffineExprTest, divisorOfNegativeFloorDiv) { |
| 118 | MLIRContext ctx; |
| 119 | OpBuilder b(&ctx); |
| 120 | ASSERT_EQ(b.getAffineDimExpr(0).floorDiv(-1).getLargestKnownDivisor(), 1); |
| 121 | } |
| 122 | |
| 123 | TEST(AffineExprTest, d0PlusD0FloorDivNeg2) { |
| 124 | // Regression test for a bug where this was rewritten to d0 mod -2. We do not |
| 125 | // support a negative RHS for mod in LowerAffinePass. |
| 126 | MLIRContext ctx; |
| 127 | OpBuilder b(&ctx); |
| 128 | auto d0 = b.getAffineDimExpr(position: 0); |
| 129 | auto sum = d0 + d0.floorDiv(v: -2) * 2; |
| 130 | ASSERT_EQ(toString(sum), "d0 + (d0 floordiv -2) * 2" ); |
| 131 | } |
| 132 | |
| 133 | TEST(AffineExprTest, simpleAffineExprFlattenerRegression) { |
| 134 | |
| 135 | // Regression test for a bug where mod simplification was not handled |
| 136 | // properly when `lhs % rhs` was happened to have the property that `lhs |
| 137 | // floordiv rhs = lhs`. |
| 138 | MLIRContext ctx; |
| 139 | OpBuilder b(&ctx); |
| 140 | |
| 141 | auto d0 = b.getAffineDimExpr(position: 0); |
| 142 | |
| 143 | // Manually replace variables by constants to avoid constant folding. |
| 144 | AffineExpr expr = (d0 - (d0 + 2)).floorDiv(v: 8) % 8; |
| 145 | AffineExpr result = mlir::simplifyAffineExpr(expr, numDims: 1, numSymbols: 0); |
| 146 | |
| 147 | ASSERT_TRUE(isa<AffineConstantExpr>(result)); |
| 148 | ASSERT_EQ(cast<AffineConstantExpr>(result).getValue(), 7); |
| 149 | } |
| 150 | |