1 | //===- ComplexExprTest.cpp -- ComplexExpr unit tests ----------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "flang/Optimizer/Builder/Complex.h" |
10 | #include "gtest/gtest.h" |
11 | #include "flang/Optimizer/Builder/FIRBuilder.h" |
12 | #include "flang/Optimizer/Dialect/Support/KindMapping.h" |
13 | #include "flang/Optimizer/Support/InitFIR.h" |
14 | |
15 | struct ComplexTest : public testing::Test { |
16 | public: |
17 | void SetUp() override { |
18 | fir::support::loadDialects(context); |
19 | |
20 | mlir::OpBuilder builder(&context); |
21 | auto loc = builder.getUnknownLoc(); |
22 | |
23 | // Set up a Module with a dummy function operation inside. |
24 | // Set the insertion point in the function entry block. |
25 | mlir::ModuleOp mod = builder.create<mlir::ModuleOp>(loc); |
26 | mlir::func::FuncOp func = mlir::func::FuncOp::create( |
27 | loc, "func1" , builder.getFunctionType(std::nullopt, std::nullopt)); |
28 | auto *entryBlock = func.addEntryBlock(); |
29 | mod.push_back(mod); |
30 | builder.setInsertionPointToStart(entryBlock); |
31 | |
32 | kindMap = std::make_unique<fir::KindMapping>(&context); |
33 | firBuilder = std::make_unique<fir::FirOpBuilder>(mod, *kindMap); |
34 | helper = std::make_unique<fir::factory::Complex>(*firBuilder, loc); |
35 | |
36 | // Init commonly used types |
37 | realTy1 = mlir::FloatType::getF32(&context); |
38 | complexTy1 = fir::ComplexType::get(&context, 4); |
39 | integerTy1 = mlir::IntegerType::get(&context, 32); |
40 | |
41 | // Create commonly used reals |
42 | rOne = firBuilder->createRealConstant(loc, realTy1, 1u); |
43 | rTwo = firBuilder->createRealConstant(loc, realTy1, 2u); |
44 | rThree = firBuilder->createRealConstant(loc, realTy1, 3u); |
45 | rFour = firBuilder->createRealConstant(loc, realTy1, 4u); |
46 | } |
47 | |
48 | mlir::MLIRContext context; |
49 | std::unique_ptr<fir::KindMapping> kindMap; |
50 | std::unique_ptr<fir::FirOpBuilder> firBuilder; |
51 | std::unique_ptr<fir::factory::Complex> helper; |
52 | |
53 | // Commonly used real/complex/integer types |
54 | mlir::FloatType realTy1; |
55 | fir::ComplexType complexTy1; |
56 | mlir::IntegerType integerTy1; |
57 | |
58 | // Commonly used real numbers |
59 | mlir::Value rOne; |
60 | mlir::Value rTwo; |
61 | mlir::Value rThree; |
62 | mlir::Value rFour; |
63 | }; |
64 | |
65 | TEST_F(ComplexTest, verifyTypes) { |
66 | mlir::Value cVal1 = helper->createComplex(complexTy1, rOne, rTwo); |
67 | mlir::Value cVal2 = helper->createComplex(4, rOne, rTwo); |
68 | EXPECT_TRUE(fir::isa_complex(cVal1.getType())); |
69 | EXPECT_TRUE(fir::isa_complex(cVal2.getType())); |
70 | EXPECT_TRUE(fir::isa_real(helper->getComplexPartType(cVal1))); |
71 | EXPECT_TRUE(fir::isa_real(helper->getComplexPartType(cVal2))); |
72 | |
73 | mlir::Value real1 = helper->extractComplexPart(cVal1, /*isImagPart=*/false); |
74 | mlir::Value imag1 = helper->extractComplexPart(cVal1, /*isImagPart=*/true); |
75 | mlir::Value real2 = helper->extractComplexPart(cVal2, /*isImagPart=*/false); |
76 | mlir::Value imag2 = helper->extractComplexPart(cVal2, /*isImagPart=*/true); |
77 | EXPECT_EQ(realTy1, real1.getType()); |
78 | EXPECT_EQ(realTy1, imag1.getType()); |
79 | EXPECT_EQ(realTy1, real2.getType()); |
80 | EXPECT_EQ(realTy1, imag2.getType()); |
81 | |
82 | mlir::Value cVal3 = |
83 | helper->insertComplexPart(cVal1, rThree, /*isImagPart=*/false); |
84 | mlir::Value cVal4 = |
85 | helper->insertComplexPart(cVal3, rFour, /*isImagPart=*/true); |
86 | EXPECT_TRUE(fir::isa_complex(cVal4.getType())); |
87 | EXPECT_TRUE(fir::isa_real(helper->getComplexPartType(cVal4))); |
88 | } |
89 | |
90 | TEST_F(ComplexTest, verifyConvertWithSemantics) { |
91 | auto loc = firBuilder->getUnknownLoc(); |
92 | rOne = firBuilder->createRealConstant(loc, realTy1, 1u); |
93 | // Convert real to complex |
94 | mlir::Value v1 = firBuilder->convertWithSemantics(loc, complexTy1, rOne); |
95 | EXPECT_TRUE(fir::isa_complex(v1.getType())); |
96 | |
97 | // Convert complex to integer |
98 | mlir::Value v2 = firBuilder->convertWithSemantics(loc, integerTy1, v1); |
99 | EXPECT_TRUE(v2.getType().isa<mlir::IntegerType>()); |
100 | EXPECT_TRUE(mlir::dyn_cast<fir::ConvertOp>(v2.getDefiningOp())); |
101 | } |
102 | |