1 | //===- LoopLikeSCFOpsTest.cpp - SCF LoopLikeOpInterface Tests -------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
10 | #include "mlir/Dialect/Arith/IR/Arith.h" |
11 | #include "mlir/Dialect/SCF/IR/SCF.h" |
12 | #include "mlir/Dialect/SCF/Utils/Utils.h" |
13 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
14 | #include "mlir/IR/Diagnostics.h" |
15 | #include "mlir/IR/MLIRContext.h" |
16 | #include "mlir/IR/OwningOpRef.h" |
17 | #include "mlir/IR/PatternMatch.h" |
18 | #include "mlir/Interfaces/LoopLikeInterface.h" |
19 | #include "gtest/gtest.h" |
20 | |
21 | using namespace mlir; |
22 | using namespace mlir::scf; |
23 | |
24 | //===----------------------------------------------------------------------===// |
25 | // Test Fixture |
26 | //===----------------------------------------------------------------------===// |
27 | |
28 | class SCFLoopLikeTest : public ::testing::Test { |
29 | protected: |
30 | SCFLoopLikeTest() : b(&context), loc(UnknownLoc::get(&context)) { |
31 | context.loadDialect<affine::AffineDialect, arith::ArithDialect, |
32 | scf::SCFDialect>(); |
33 | } |
34 | |
35 | void checkUnidimensional(LoopLikeOpInterface loopLikeOp) { |
36 | std::optional<OpFoldResult> maybeSingleLb = |
37 | loopLikeOp.getSingleLowerBound(); |
38 | EXPECT_TRUE(maybeSingleLb.has_value()); |
39 | std::optional<OpFoldResult> maybeSingleUb = |
40 | loopLikeOp.getSingleUpperBound(); |
41 | EXPECT_TRUE(maybeSingleUb.has_value()); |
42 | std::optional<OpFoldResult> maybeSingleStep = loopLikeOp.getSingleStep(); |
43 | EXPECT_TRUE(maybeSingleStep.has_value()); |
44 | std::optional<OpFoldResult> maybeSingleIndVar = |
45 | loopLikeOp.getSingleInductionVar(); |
46 | EXPECT_TRUE(maybeSingleIndVar.has_value()); |
47 | |
48 | std::optional<SmallVector<OpFoldResult>> maybeLb = |
49 | loopLikeOp.getLoopLowerBounds(); |
50 | ASSERT_TRUE(maybeLb.has_value()); |
51 | EXPECT_EQ((*maybeLb).size(), 1u); |
52 | std::optional<SmallVector<OpFoldResult>> maybeUb = |
53 | loopLikeOp.getLoopUpperBounds(); |
54 | ASSERT_TRUE(maybeUb.has_value()); |
55 | EXPECT_EQ((*maybeUb).size(), 1u); |
56 | std::optional<SmallVector<OpFoldResult>> maybeStep = |
57 | loopLikeOp.getLoopSteps(); |
58 | ASSERT_TRUE(maybeStep.has_value()); |
59 | EXPECT_EQ((*maybeStep).size(), 1u); |
60 | std::optional<SmallVector<Value>> maybeInductionVars = |
61 | loopLikeOp.getLoopInductionVars(); |
62 | ASSERT_TRUE(maybeInductionVars.has_value()); |
63 | EXPECT_EQ((*maybeInductionVars).size(), 1u); |
64 | } |
65 | |
66 | void checkMultidimensional(LoopLikeOpInterface loopLikeOp) { |
67 | std::optional<OpFoldResult> maybeSingleLb = |
68 | loopLikeOp.getSingleLowerBound(); |
69 | EXPECT_FALSE(maybeSingleLb.has_value()); |
70 | std::optional<OpFoldResult> maybeSingleUb = |
71 | loopLikeOp.getSingleUpperBound(); |
72 | EXPECT_FALSE(maybeSingleUb.has_value()); |
73 | std::optional<OpFoldResult> maybeSingleStep = loopLikeOp.getSingleStep(); |
74 | EXPECT_FALSE(maybeSingleStep.has_value()); |
75 | std::optional<OpFoldResult> maybeSingleIndVar = |
76 | loopLikeOp.getSingleInductionVar(); |
77 | EXPECT_FALSE(maybeSingleIndVar.has_value()); |
78 | |
79 | std::optional<SmallVector<OpFoldResult>> maybeLb = |
80 | loopLikeOp.getLoopLowerBounds(); |
81 | ASSERT_TRUE(maybeLb.has_value()); |
82 | EXPECT_EQ((*maybeLb).size(), 2u); |
83 | std::optional<SmallVector<OpFoldResult>> maybeUb = |
84 | loopLikeOp.getLoopUpperBounds(); |
85 | ASSERT_TRUE(maybeUb.has_value()); |
86 | EXPECT_EQ((*maybeUb).size(), 2u); |
87 | std::optional<SmallVector<OpFoldResult>> maybeStep = |
88 | loopLikeOp.getLoopSteps(); |
89 | ASSERT_TRUE(maybeStep.has_value()); |
90 | EXPECT_EQ((*maybeStep).size(), 2u); |
91 | std::optional<SmallVector<Value>> maybeInductionVars = |
92 | loopLikeOp.getLoopInductionVars(); |
93 | ASSERT_TRUE(maybeInductionVars.has_value()); |
94 | EXPECT_EQ((*maybeInductionVars).size(), 2u); |
95 | } |
96 | |
97 | void checkNormalized(LoopLikeOpInterface loopLikeOp) { |
98 | std::optional<SmallVector<OpFoldResult>> maybeLb = |
99 | loopLikeOp.getLoopLowerBounds(); |
100 | ASSERT_TRUE(maybeLb.has_value()); |
101 | std::optional<SmallVector<OpFoldResult>> maybeStep = |
102 | loopLikeOp.getLoopSteps(); |
103 | ASSERT_TRUE(maybeStep.has_value()); |
104 | |
105 | auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) { |
106 | return llvm::all_of(Range&: results, P: [&](OpFoldResult ofr) { |
107 | auto intValue = getConstantIntValue(ofr); |
108 | return intValue.has_value() && intValue == val; |
109 | }); |
110 | }; |
111 | EXPECT_TRUE(allEqual(*maybeLb, 0)); |
112 | EXPECT_TRUE(allEqual(*maybeStep, 1)); |
113 | } |
114 | |
115 | MLIRContext context; |
116 | OpBuilder b; |
117 | Location loc; |
118 | }; |
119 | |
120 | TEST_F(SCFLoopLikeTest, queryUnidimensionalLooplikes) { |
121 | OwningOpRef<arith::ConstantIndexOp> lb = |
122 | b.create<arith::ConstantIndexOp>(location: loc, args: 0); |
123 | OwningOpRef<arith::ConstantIndexOp> ub = |
124 | b.create<arith::ConstantIndexOp>(location: loc, args: 10); |
125 | OwningOpRef<arith::ConstantIndexOp> step = |
126 | b.create<arith::ConstantIndexOp>(location: loc, args: 2); |
127 | |
128 | OwningOpRef<scf::ForOp> forOp = |
129 | b.create<scf::ForOp>(loc, lb.get(), ub.get(), step.get()); |
130 | checkUnidimensional(forOp.get()); |
131 | |
132 | OwningOpRef<scf::ForallOp> forallOp = b.create<scf::ForallOp>( |
133 | loc, ArrayRef<OpFoldResult>(lb->getResult()), |
134 | ArrayRef<OpFoldResult>(ub->getResult()), |
135 | ArrayRef<OpFoldResult>(step->getResult()), ValueRange(), std::nullopt); |
136 | checkUnidimensional(forallOp.get()); |
137 | |
138 | OwningOpRef<scf::ParallelOp> parallelOp = b.create<scf::ParallelOp>( |
139 | loc, ValueRange(lb->getResult()), ValueRange(ub->getResult()), |
140 | ValueRange(step->getResult()), ValueRange()); |
141 | checkUnidimensional(parallelOp.get()); |
142 | } |
143 | |
144 | TEST_F(SCFLoopLikeTest, queryMultidimensionalLooplikes) { |
145 | OwningOpRef<arith::ConstantIndexOp> lb = |
146 | b.create<arith::ConstantIndexOp>(location: loc, args: 0); |
147 | OwningOpRef<arith::ConstantIndexOp> ub = |
148 | b.create<arith::ConstantIndexOp>(location: loc, args: 10); |
149 | OwningOpRef<arith::ConstantIndexOp> step = |
150 | b.create<arith::ConstantIndexOp>(location: loc, args: 2); |
151 | |
152 | OwningOpRef<scf::ForallOp> forallOp = b.create<scf::ForallOp>( |
153 | loc, ArrayRef<OpFoldResult>({lb->getResult(), lb->getResult()}), |
154 | ArrayRef<OpFoldResult>({ub->getResult(), ub->getResult()}), |
155 | ArrayRef<OpFoldResult>({step->getResult(), step->getResult()}), |
156 | ValueRange(), std::nullopt); |
157 | checkMultidimensional(forallOp.get()); |
158 | |
159 | OwningOpRef<scf::ParallelOp> parallelOp = b.create<scf::ParallelOp>( |
160 | loc, ValueRange({lb->getResult(), lb->getResult()}), |
161 | ValueRange({ub->getResult(), ub->getResult()}), |
162 | ValueRange({step->getResult(), step->getResult()}), ValueRange()); |
163 | checkMultidimensional(parallelOp.get()); |
164 | } |
165 | |
166 | TEST_F(SCFLoopLikeTest, testForallNormalize) { |
167 | OwningOpRef<arith::ConstantIndexOp> lb = |
168 | b.create<arith::ConstantIndexOp>(location: loc, args: 1); |
169 | OwningOpRef<arith::ConstantIndexOp> ub = |
170 | b.create<arith::ConstantIndexOp>(location: loc, args: 10); |
171 | OwningOpRef<arith::ConstantIndexOp> step = |
172 | b.create<arith::ConstantIndexOp>(location: loc, args: 3); |
173 | |
174 | scf::ForallOp forallOp = b.create<scf::ForallOp>( |
175 | loc, ArrayRef<OpFoldResult>({lb->getResult(), lb->getResult()}), |
176 | ArrayRef<OpFoldResult>({ub->getResult(), ub->getResult()}), |
177 | ArrayRef<OpFoldResult>({step->getResult(), step->getResult()}), |
178 | ValueRange(), std::nullopt); |
179 | // Create a user of the induction variable. Bitcast is chosen for simplicity |
180 | // since it is unary. |
181 | b.setInsertionPointToStart(forallOp.getBody()); |
182 | b.create<arith::BitcastOp>(UnknownLoc::get(&context), b.getF64Type(), |
183 | forallOp.getInductionVar(0)); |
184 | IRRewriter rewriter(b); |
185 | FailureOr<scf::ForallOp> maybeNormalizedForallOp = |
186 | normalizeForallOp(rewriter, forallOp); |
187 | EXPECT_TRUE(succeeded(maybeNormalizedForallOp)); |
188 | OwningOpRef<scf::ForallOp> normalizedForallOp(*maybeNormalizedForallOp); |
189 | checkNormalized(normalizedForallOp.get()); |
190 | |
191 | // Check that the IV user has been updated to use the denormalized variable. |
192 | Block *body = normalizedForallOp->getBody(); |
193 | auto bitcastOps = body->getOps<arith::BitcastOp>(); |
194 | ASSERT_EQ(std::distance(bitcastOps.begin(), bitcastOps.end()), 1); |
195 | arith::BitcastOp ivUser = *bitcastOps.begin(); |
196 | ASSERT_NE(ivUser.getIn(), normalizedForallOp->getInductionVar(0)); |
197 | } |
198 | |