1 | //===- ComplexExprTest.cpp -- ComplexExpr unit tests ----------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "flang/Optimizer/Builder/Complex.h" |
10 | #include "gtest/gtest.h" |
11 | #include "flang/Optimizer/Builder/FIRBuilder.h" |
12 | #include "flang/Optimizer/Dialect/Support/KindMapping.h" |
13 | #include "flang/Optimizer/Support/InitFIR.h" |
14 | |
15 | struct ComplexTest : public testing::Test { |
16 | public: |
17 | void SetUp() override { |
18 | fir::support::loadDialects(context); |
19 | |
20 | mlir::OpBuilder builder(&context); |
21 | auto loc = builder.getUnknownLoc(); |
22 | |
23 | // Set up a Module with a dummy function operation inside. |
24 | // Set the insertion point in the function entry block. |
25 | moduleOp = builder.create<mlir::ModuleOp>(loc); |
26 | builder.setInsertionPointToStart(moduleOp->getBody()); |
27 | mlir::func::FuncOp func = builder.create<mlir::func::FuncOp>( |
28 | loc, "func1" , builder.getFunctionType(std::nullopt, std::nullopt)); |
29 | auto *entryBlock = func.addEntryBlock(); |
30 | builder.setInsertionPointToStart(entryBlock); |
31 | |
32 | kindMap = std::make_unique<fir::KindMapping>(&context); |
33 | firBuilder = std::make_unique<fir::FirOpBuilder>(builder, *kindMap); |
34 | helper = std::make_unique<fir::factory::Complex>(*firBuilder, loc); |
35 | |
36 | // Init commonly used types |
37 | realTy1 = mlir::Float32Type::get(&context); |
38 | complexTy1 = mlir::ComplexType::get(realTy1); |
39 | integerTy1 = mlir::IntegerType::get(&context, 32); |
40 | |
41 | // Create commonly used reals |
42 | rOne = firBuilder->createRealConstant(loc, realTy1, 1u); |
43 | rTwo = firBuilder->createRealConstant(loc, realTy1, 2u); |
44 | rThree = firBuilder->createRealConstant(loc, realTy1, 3u); |
45 | rFour = firBuilder->createRealConstant(loc, realTy1, 4u); |
46 | } |
47 | |
48 | mlir::MLIRContext context; |
49 | mlir::OwningOpRef<mlir::ModuleOp> moduleOp; |
50 | std::unique_ptr<fir::KindMapping> kindMap; |
51 | std::unique_ptr<fir::FirOpBuilder> firBuilder; |
52 | std::unique_ptr<fir::factory::Complex> helper; |
53 | |
54 | // Commonly used real/complex/integer types |
55 | mlir::FloatType realTy1; |
56 | mlir::ComplexType complexTy1; |
57 | mlir::IntegerType integerTy1; |
58 | |
59 | // Commonly used real numbers |
60 | mlir::Value rOne; |
61 | mlir::Value rTwo; |
62 | mlir::Value rThree; |
63 | mlir::Value rFour; |
64 | }; |
65 | |
66 | TEST_F(ComplexTest, verifyTypes) { |
67 | mlir::Value cVal1 = helper->createComplex(complexTy1, rOne, rTwo); |
68 | EXPECT_TRUE(fir::isa_complex(cVal1.getType())); |
69 | EXPECT_TRUE(fir::isa_real(helper->getComplexPartType(cVal1))); |
70 | |
71 | mlir::Value real1 = helper->extractComplexPart(cVal1, /*isImagPart=*/false); |
72 | mlir::Value imag1 = helper->extractComplexPart(cVal1, /*isImagPart=*/true); |
73 | EXPECT_EQ(realTy1, real1.getType()); |
74 | EXPECT_EQ(realTy1, imag1.getType()); |
75 | |
76 | mlir::Value cVal3 = |
77 | helper->insertComplexPart(cVal1, rThree, /*isImagPart=*/false); |
78 | mlir::Value cVal4 = |
79 | helper->insertComplexPart(cVal3, rFour, /*isImagPart=*/true); |
80 | EXPECT_TRUE(fir::isa_complex(cVal4.getType())); |
81 | EXPECT_TRUE(fir::isa_real(helper->getComplexPartType(cVal4))); |
82 | } |
83 | |
84 | TEST_F(ComplexTest, verifyConvertWithSemantics) { |
85 | auto loc = firBuilder->getUnknownLoc(); |
86 | rOne = firBuilder->createRealConstant(loc, realTy1, 1u); |
87 | // Convert real to complex |
88 | mlir::Value v1 = firBuilder->convertWithSemantics(loc, complexTy1, rOne); |
89 | EXPECT_TRUE(fir::isa_complex(v1.getType())); |
90 | |
91 | // Convert complex to integer |
92 | mlir::Value v2 = firBuilder->convertWithSemantics(loc, integerTy1, v1); |
93 | EXPECT_TRUE(mlir::isa<mlir::IntegerType>(v2.getType())); |
94 | EXPECT_TRUE(mlir::dyn_cast<fir::ConvertOp>(v2.getDefiningOp())); |
95 | } |
96 | |