1//===- LoopLikeSCFOpsTest.cpp - SCF LoopLikeOpInterface Tests -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Arith/IR/Arith.h"
10#include "mlir/Dialect/SCF/IR/SCF.h"
11#include "mlir/IR/Diagnostics.h"
12#include "mlir/IR/MLIRContext.h"
13#include "mlir/IR/OwningOpRef.h"
14#include "gtest/gtest.h"
15
16using namespace mlir;
17using namespace mlir::scf;
18
19//===----------------------------------------------------------------------===//
20// Test Fixture
21//===----------------------------------------------------------------------===//
22
23class SCFLoopLikeTest : public ::testing::Test {
24protected:
25 SCFLoopLikeTest() : b(&context), loc(UnknownLoc::get(&context)) {
26 context.loadDialect<arith::ArithDialect, scf::SCFDialect>();
27 }
28
29 void checkUnidimensional(LoopLikeOpInterface loopLikeOp) {
30 std::optional<OpFoldResult> maybeLb = loopLikeOp.getSingleLowerBound();
31 EXPECT_TRUE(maybeLb.has_value());
32 std::optional<OpFoldResult> maybeUb = loopLikeOp.getSingleUpperBound();
33 EXPECT_TRUE(maybeUb.has_value());
34 std::optional<OpFoldResult> maybeStep = loopLikeOp.getSingleStep();
35 EXPECT_TRUE(maybeStep.has_value());
36 std::optional<OpFoldResult> maybeIndVar =
37 loopLikeOp.getSingleInductionVar();
38 EXPECT_TRUE(maybeIndVar.has_value());
39 }
40
41 void checkMultidimensional(LoopLikeOpInterface loopLikeOp) {
42 std::optional<OpFoldResult> maybeLb = loopLikeOp.getSingleLowerBound();
43 EXPECT_FALSE(maybeLb.has_value());
44 std::optional<OpFoldResult> maybeUb = loopLikeOp.getSingleUpperBound();
45 EXPECT_FALSE(maybeUb.has_value());
46 std::optional<OpFoldResult> maybeStep = loopLikeOp.getSingleStep();
47 EXPECT_FALSE(maybeStep.has_value());
48 std::optional<OpFoldResult> maybeIndVar =
49 loopLikeOp.getSingleInductionVar();
50 EXPECT_FALSE(maybeIndVar.has_value());
51 }
52
53 MLIRContext context;
54 OpBuilder b;
55 Location loc;
56};
57
58TEST_F(SCFLoopLikeTest, queryUnidimensionalLooplikes) {
59 OwningOpRef<arith::ConstantIndexOp> lb =
60 b.create<arith::ConstantIndexOp>(location: loc, args: 0);
61 OwningOpRef<arith::ConstantIndexOp> ub =
62 b.create<arith::ConstantIndexOp>(location: loc, args: 10);
63 OwningOpRef<arith::ConstantIndexOp> step =
64 b.create<arith::ConstantIndexOp>(location: loc, args: 2);
65
66 OwningOpRef<scf::ForOp> forOp =
67 b.create<scf::ForOp>(loc, lb.get(), ub.get(), step.get());
68 checkUnidimensional(forOp.get());
69
70 OwningOpRef<scf::ForallOp> forallOp = b.create<scf::ForallOp>(
71 loc, ArrayRef<OpFoldResult>(lb->getResult()),
72 ArrayRef<OpFoldResult>(ub->getResult()),
73 ArrayRef<OpFoldResult>(step->getResult()), ValueRange(), std::nullopt);
74 checkUnidimensional(forallOp.get());
75
76 OwningOpRef<scf::ParallelOp> parallelOp = b.create<scf::ParallelOp>(
77 loc, ValueRange(lb->getResult()), ValueRange(ub->getResult()),
78 ValueRange(step->getResult()), ValueRange());
79 checkUnidimensional(parallelOp.get());
80}
81
82TEST_F(SCFLoopLikeTest, queryMultidimensionalLooplikes) {
83 OwningOpRef<arith::ConstantIndexOp> lb =
84 b.create<arith::ConstantIndexOp>(location: loc, args: 0);
85 OwningOpRef<arith::ConstantIndexOp> ub =
86 b.create<arith::ConstantIndexOp>(location: loc, args: 10);
87 OwningOpRef<arith::ConstantIndexOp> step =
88 b.create<arith::ConstantIndexOp>(location: loc, args: 2);
89
90 OwningOpRef<scf::ForallOp> forallOp = b.create<scf::ForallOp>(
91 loc, ArrayRef<OpFoldResult>({lb->getResult(), lb->getResult()}),
92 ArrayRef<OpFoldResult>({ub->getResult(), ub->getResult()}),
93 ArrayRef<OpFoldResult>({step->getResult(), step->getResult()}),
94 ValueRange(), std::nullopt);
95 checkMultidimensional(forallOp.get());
96
97 OwningOpRef<scf::ParallelOp> parallelOp = b.create<scf::ParallelOp>(
98 loc, ValueRange({lb->getResult(), lb->getResult()}),
99 ValueRange({ub->getResult(), ub->getResult()}),
100 ValueRange({step->getResult(), step->getResult()}), ValueRange());
101 checkMultidimensional(parallelOp.get());
102}
103

source code of mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp