1//===- LoopLikeSCFOpsTest.cpp - SCF LoopLikeOpInterface Tests -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Affine/IR/AffineOps.h"
10#include "mlir/Dialect/Arith/IR/Arith.h"
11#include "mlir/Dialect/SCF/IR/SCF.h"
12#include "mlir/Dialect/SCF/Utils/Utils.h"
13#include "mlir/Dialect/Utils/StaticValueUtils.h"
14#include "mlir/IR/Diagnostics.h"
15#include "mlir/IR/MLIRContext.h"
16#include "mlir/IR/OwningOpRef.h"
17#include "mlir/IR/PatternMatch.h"
18#include "mlir/Interfaces/LoopLikeInterface.h"
19#include "gtest/gtest.h"
20
21using namespace mlir;
22using namespace mlir::scf;
23
24//===----------------------------------------------------------------------===//
25// Test Fixture
26//===----------------------------------------------------------------------===//
27
28class SCFLoopLikeTest : public ::testing::Test {
29protected:
30 SCFLoopLikeTest() : b(&context), loc(UnknownLoc::get(&context)) {
31 context.loadDialect<affine::AffineDialect, arith::ArithDialect,
32 scf::SCFDialect>();
33 }
34
35 void checkUnidimensional(LoopLikeOpInterface loopLikeOp) {
36 std::optional<OpFoldResult> maybeSingleLb =
37 loopLikeOp.getSingleLowerBound();
38 EXPECT_TRUE(maybeSingleLb.has_value());
39 std::optional<OpFoldResult> maybeSingleUb =
40 loopLikeOp.getSingleUpperBound();
41 EXPECT_TRUE(maybeSingleUb.has_value());
42 std::optional<OpFoldResult> maybeSingleStep = loopLikeOp.getSingleStep();
43 EXPECT_TRUE(maybeSingleStep.has_value());
44 std::optional<OpFoldResult> maybeSingleIndVar =
45 loopLikeOp.getSingleInductionVar();
46 EXPECT_TRUE(maybeSingleIndVar.has_value());
47
48 std::optional<SmallVector<OpFoldResult>> maybeLb =
49 loopLikeOp.getLoopLowerBounds();
50 ASSERT_TRUE(maybeLb.has_value());
51 EXPECT_EQ((*maybeLb).size(), 1u);
52 std::optional<SmallVector<OpFoldResult>> maybeUb =
53 loopLikeOp.getLoopUpperBounds();
54 ASSERT_TRUE(maybeUb.has_value());
55 EXPECT_EQ((*maybeUb).size(), 1u);
56 std::optional<SmallVector<OpFoldResult>> maybeStep =
57 loopLikeOp.getLoopSteps();
58 ASSERT_TRUE(maybeStep.has_value());
59 EXPECT_EQ((*maybeStep).size(), 1u);
60 std::optional<SmallVector<Value>> maybeInductionVars =
61 loopLikeOp.getLoopInductionVars();
62 ASSERT_TRUE(maybeInductionVars.has_value());
63 EXPECT_EQ((*maybeInductionVars).size(), 1u);
64 }
65
66 void checkMultidimensional(LoopLikeOpInterface loopLikeOp) {
67 std::optional<OpFoldResult> maybeSingleLb =
68 loopLikeOp.getSingleLowerBound();
69 EXPECT_FALSE(maybeSingleLb.has_value());
70 std::optional<OpFoldResult> maybeSingleUb =
71 loopLikeOp.getSingleUpperBound();
72 EXPECT_FALSE(maybeSingleUb.has_value());
73 std::optional<OpFoldResult> maybeSingleStep = loopLikeOp.getSingleStep();
74 EXPECT_FALSE(maybeSingleStep.has_value());
75 std::optional<OpFoldResult> maybeSingleIndVar =
76 loopLikeOp.getSingleInductionVar();
77 EXPECT_FALSE(maybeSingleIndVar.has_value());
78
79 std::optional<SmallVector<OpFoldResult>> maybeLb =
80 loopLikeOp.getLoopLowerBounds();
81 ASSERT_TRUE(maybeLb.has_value());
82 EXPECT_EQ((*maybeLb).size(), 2u);
83 std::optional<SmallVector<OpFoldResult>> maybeUb =
84 loopLikeOp.getLoopUpperBounds();
85 ASSERT_TRUE(maybeUb.has_value());
86 EXPECT_EQ((*maybeUb).size(), 2u);
87 std::optional<SmallVector<OpFoldResult>> maybeStep =
88 loopLikeOp.getLoopSteps();
89 ASSERT_TRUE(maybeStep.has_value());
90 EXPECT_EQ((*maybeStep).size(), 2u);
91 std::optional<SmallVector<Value>> maybeInductionVars =
92 loopLikeOp.getLoopInductionVars();
93 ASSERT_TRUE(maybeInductionVars.has_value());
94 EXPECT_EQ((*maybeInductionVars).size(), 2u);
95 }
96
97 void checkNormalized(LoopLikeOpInterface loopLikeOp) {
98 std::optional<SmallVector<OpFoldResult>> maybeLb =
99 loopLikeOp.getLoopLowerBounds();
100 ASSERT_TRUE(maybeLb.has_value());
101 std::optional<SmallVector<OpFoldResult>> maybeStep =
102 loopLikeOp.getLoopSteps();
103 ASSERT_TRUE(maybeStep.has_value());
104
105 auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
106 return llvm::all_of(Range&: results, P: [&](OpFoldResult ofr) {
107 auto intValue = getConstantIntValue(ofr);
108 return intValue.has_value() && intValue == val;
109 });
110 };
111 EXPECT_TRUE(allEqual(*maybeLb, 0));
112 EXPECT_TRUE(allEqual(*maybeStep, 1));
113 }
114
115 MLIRContext context;
116 OpBuilder b;
117 Location loc;
118};
119
120TEST_F(SCFLoopLikeTest, queryUnidimensionalLooplikes) {
121 OwningOpRef<arith::ConstantIndexOp> lb =
122 b.create<arith::ConstantIndexOp>(location: loc, args: 0);
123 OwningOpRef<arith::ConstantIndexOp> ub =
124 b.create<arith::ConstantIndexOp>(location: loc, args: 10);
125 OwningOpRef<arith::ConstantIndexOp> step =
126 b.create<arith::ConstantIndexOp>(location: loc, args: 2);
127
128 OwningOpRef<scf::ForOp> forOp =
129 b.create<scf::ForOp>(loc, lb.get(), ub.get(), step.get());
130 checkUnidimensional(forOp.get());
131
132 OwningOpRef<scf::ForallOp> forallOp = b.create<scf::ForallOp>(
133 loc, ArrayRef<OpFoldResult>(lb->getResult()),
134 ArrayRef<OpFoldResult>(ub->getResult()),
135 ArrayRef<OpFoldResult>(step->getResult()), ValueRange(), std::nullopt);
136 checkUnidimensional(forallOp.get());
137
138 OwningOpRef<scf::ParallelOp> parallelOp = b.create<scf::ParallelOp>(
139 loc, ValueRange(lb->getResult()), ValueRange(ub->getResult()),
140 ValueRange(step->getResult()), ValueRange());
141 checkUnidimensional(parallelOp.get());
142}
143
144TEST_F(SCFLoopLikeTest, queryMultidimensionalLooplikes) {
145 OwningOpRef<arith::ConstantIndexOp> lb =
146 b.create<arith::ConstantIndexOp>(location: loc, args: 0);
147 OwningOpRef<arith::ConstantIndexOp> ub =
148 b.create<arith::ConstantIndexOp>(location: loc, args: 10);
149 OwningOpRef<arith::ConstantIndexOp> step =
150 b.create<arith::ConstantIndexOp>(location: loc, args: 2);
151
152 OwningOpRef<scf::ForallOp> forallOp = b.create<scf::ForallOp>(
153 loc, ArrayRef<OpFoldResult>({lb->getResult(), lb->getResult()}),
154 ArrayRef<OpFoldResult>({ub->getResult(), ub->getResult()}),
155 ArrayRef<OpFoldResult>({step->getResult(), step->getResult()}),
156 ValueRange(), std::nullopt);
157 checkMultidimensional(forallOp.get());
158
159 OwningOpRef<scf::ParallelOp> parallelOp = b.create<scf::ParallelOp>(
160 loc, ValueRange({lb->getResult(), lb->getResult()}),
161 ValueRange({ub->getResult(), ub->getResult()}),
162 ValueRange({step->getResult(), step->getResult()}), ValueRange());
163 checkMultidimensional(parallelOp.get());
164}
165
166TEST_F(SCFLoopLikeTest, testForallNormalize) {
167 OwningOpRef<arith::ConstantIndexOp> lb =
168 b.create<arith::ConstantIndexOp>(location: loc, args: 1);
169 OwningOpRef<arith::ConstantIndexOp> ub =
170 b.create<arith::ConstantIndexOp>(location: loc, args: 10);
171 OwningOpRef<arith::ConstantIndexOp> step =
172 b.create<arith::ConstantIndexOp>(location: loc, args: 3);
173
174 scf::ForallOp forallOp = b.create<scf::ForallOp>(
175 loc, ArrayRef<OpFoldResult>({lb->getResult(), lb->getResult()}),
176 ArrayRef<OpFoldResult>({ub->getResult(), ub->getResult()}),
177 ArrayRef<OpFoldResult>({step->getResult(), step->getResult()}),
178 ValueRange(), std::nullopt);
179 // Create a user of the induction variable. Bitcast is chosen for simplicity
180 // since it is unary.
181 b.setInsertionPointToStart(forallOp.getBody());
182 b.create<arith::BitcastOp>(UnknownLoc::get(&context), b.getF64Type(),
183 forallOp.getInductionVar(0));
184 IRRewriter rewriter(b);
185 FailureOr<scf::ForallOp> maybeNormalizedForallOp =
186 normalizeForallOp(rewriter, forallOp);
187 EXPECT_TRUE(succeeded(maybeNormalizedForallOp));
188 OwningOpRef<scf::ForallOp> normalizedForallOp(*maybeNormalizedForallOp);
189 checkNormalized(normalizedForallOp.get());
190
191 // Check that the IV user has been updated to use the denormalized variable.
192 Block *body = normalizedForallOp->getBody();
193 auto bitcastOps = body->getOps<arith::BitcastOp>();
194 ASSERT_EQ(std::distance(bitcastOps.begin(), bitcastOps.end()), 1);
195 arith::BitcastOp ivUser = *bitcastOps.begin();
196 ASSERT_NE(ivUser.getIn(), normalizedForallOp->getInductionVar(0));
197}
198

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp