1//===- DoLoopHelper.cpp -- DoLoopHelper unit tests ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "flang/Optimizer/Builder/DoLoopHelper.h"
10#include "gtest/gtest.h"
11#include "flang/Optimizer/Dialect/Support/KindMapping.h"
12#include "flang/Optimizer/Support/InitFIR.h"
13#include <string>
14
15using namespace mlir;
16
17struct DoLoopHelperTest : public testing::Test {
18public:
19 void SetUp() {
20 kindMap = std::make_unique<fir::KindMapping>(&context);
21 mlir::OpBuilder builder(&context);
22 firBuilder = new fir::FirOpBuilder(builder, *kindMap);
23 fir::support::loadDialects(context);
24 }
25 void TearDown() { delete firBuilder; }
26
27 fir::FirOpBuilder &getBuilder() { return *firBuilder; }
28
29 mlir::MLIRContext context;
30 std::unique_ptr<fir::KindMapping> kindMap;
31 fir::FirOpBuilder *firBuilder;
32};
33
34void checkConstantValue(const mlir::Value &value, int64_t v) {
35 EXPECT_TRUE(mlir::isa<mlir::arith::ConstantOp>(value.getDefiningOp()));
36 auto cstOp = dyn_cast<mlir::arith::ConstantOp>(value.getDefiningOp());
37 auto valueAttr = cstOp.getValue().dyn_cast_or_null<IntegerAttr>();
38 EXPECT_EQ(v, valueAttr.getInt());
39}
40
41TEST_F(DoLoopHelperTest, createLoopWithCountTest) {
42 auto firBuilder = getBuilder();
43 fir::factory::DoLoopHelper helper(firBuilder, firBuilder.getUnknownLoc());
44
45 auto c10 = firBuilder.createIntegerConstant(
46 firBuilder.getUnknownLoc(), firBuilder.getIndexType(), 10);
47 auto loop =
48 helper.createLoop(c10, [&](fir::FirOpBuilder &, mlir::Value index) {});
49 checkConstantValue(loop.getLowerBound(), 0);
50 EXPECT_TRUE(mlir::isa<arith::SubIOp>(loop.getUpperBound().getDefiningOp()));
51 auto subOp = dyn_cast<arith::SubIOp>(loop.getUpperBound().getDefiningOp());
52 EXPECT_EQ(c10, subOp.getLhs());
53 checkConstantValue(subOp.getRhs(), 1);
54 checkConstantValue(loop.getStep(), 1);
55}
56
57TEST_F(DoLoopHelperTest, createLoopWithLowerAndUpperBound) {
58 auto firBuilder = getBuilder();
59 fir::factory::DoLoopHelper helper(firBuilder, firBuilder.getUnknownLoc());
60
61 auto lb = firBuilder.createIntegerConstant(
62 firBuilder.getUnknownLoc(), firBuilder.getIndexType(), 1);
63 auto ub = firBuilder.createIntegerConstant(
64 firBuilder.getUnknownLoc(), firBuilder.getIndexType(), 20);
65 auto loop =
66 helper.createLoop(lb, ub, [&](fir::FirOpBuilder &, mlir::Value index) {});
67 checkConstantValue(loop.getLowerBound(), 1);
68 checkConstantValue(loop.getUpperBound(), 20);
69 checkConstantValue(loop.getStep(), 1);
70}
71
72TEST_F(DoLoopHelperTest, createLoopWithStep) {
73 auto firBuilder = getBuilder();
74 fir::factory::DoLoopHelper helper(firBuilder, firBuilder.getUnknownLoc());
75
76 auto lb = firBuilder.createIntegerConstant(
77 firBuilder.getUnknownLoc(), firBuilder.getIndexType(), 1);
78 auto ub = firBuilder.createIntegerConstant(
79 firBuilder.getUnknownLoc(), firBuilder.getIndexType(), 20);
80 auto step = firBuilder.createIntegerConstant(
81 firBuilder.getUnknownLoc(), firBuilder.getIndexType(), 2);
82 auto loop = helper.createLoop(
83 lb, ub, step, [&](fir::FirOpBuilder &, mlir::Value index) {});
84 checkConstantValue(loop.getLowerBound(), 1);
85 checkConstantValue(loop.getUpperBound(), 20);
86 checkConstantValue(loop.getStep(), 2);
87}
88

source code of flang/unittests/Optimizer/Builder/DoLoopHelperTest.cpp