| 1 | //===- ComplexExprTest.cpp -- ComplexExpr unit tests ----------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "flang/Optimizer/Builder/Complex.h" |
| 10 | #include "gtest/gtest.h" |
| 11 | #include "flang/Optimizer/Builder/FIRBuilder.h" |
| 12 | #include "flang/Optimizer/Dialect/Support/KindMapping.h" |
| 13 | #include "flang/Optimizer/Support/InitFIR.h" |
| 14 | |
| 15 | struct ComplexTest : public testing::Test { |
| 16 | public: |
| 17 | void SetUp() override { |
| 18 | fir::support::loadDialects(context); |
| 19 | |
| 20 | mlir::OpBuilder builder(&context); |
| 21 | auto loc = builder.getUnknownLoc(); |
| 22 | |
| 23 | // Set up a Module with a dummy function operation inside. |
| 24 | // Set the insertion point in the function entry block. |
| 25 | moduleOp = builder.create<mlir::ModuleOp>(loc); |
| 26 | builder.setInsertionPointToStart(moduleOp->getBody()); |
| 27 | mlir::func::FuncOp func = builder.create<mlir::func::FuncOp>( |
| 28 | loc, "func1" , builder.getFunctionType(std::nullopt, std::nullopt)); |
| 29 | auto *entryBlock = func.addEntryBlock(); |
| 30 | builder.setInsertionPointToStart(entryBlock); |
| 31 | |
| 32 | kindMap = std::make_unique<fir::KindMapping>(&context); |
| 33 | firBuilder = std::make_unique<fir::FirOpBuilder>(builder, *kindMap); |
| 34 | helper = std::make_unique<fir::factory::Complex>(*firBuilder, loc); |
| 35 | |
| 36 | // Init commonly used types |
| 37 | realTy1 = mlir::Float32Type::get(&context); |
| 38 | complexTy1 = mlir::ComplexType::get(realTy1); |
| 39 | integerTy1 = mlir::IntegerType::get(&context, 32); |
| 40 | |
| 41 | // Create commonly used reals |
| 42 | rOne = firBuilder->createRealConstant(loc, realTy1, 1u); |
| 43 | rTwo = firBuilder->createRealConstant(loc, realTy1, 2u); |
| 44 | rThree = firBuilder->createRealConstant(loc, realTy1, 3u); |
| 45 | rFour = firBuilder->createRealConstant(loc, realTy1, 4u); |
| 46 | } |
| 47 | |
| 48 | mlir::MLIRContext context; |
| 49 | mlir::OwningOpRef<mlir::ModuleOp> moduleOp; |
| 50 | std::unique_ptr<fir::KindMapping> kindMap; |
| 51 | std::unique_ptr<fir::FirOpBuilder> firBuilder; |
| 52 | std::unique_ptr<fir::factory::Complex> helper; |
| 53 | |
| 54 | // Commonly used real/complex/integer types |
| 55 | mlir::FloatType realTy1; |
| 56 | mlir::ComplexType complexTy1; |
| 57 | mlir::IntegerType integerTy1; |
| 58 | |
| 59 | // Commonly used real numbers |
| 60 | mlir::Value rOne; |
| 61 | mlir::Value rTwo; |
| 62 | mlir::Value rThree; |
| 63 | mlir::Value rFour; |
| 64 | }; |
| 65 | |
| 66 | TEST_F(ComplexTest, verifyTypes) { |
| 67 | mlir::Value cVal1 = helper->createComplex(complexTy1, rOne, rTwo); |
| 68 | EXPECT_TRUE(fir::isa_complex(cVal1.getType())); |
| 69 | EXPECT_TRUE(fir::isa_real(helper->getComplexPartType(cVal1))); |
| 70 | |
| 71 | mlir::Value real1 = helper->extractComplexPart(cVal1, /*isImagPart=*/false); |
| 72 | mlir::Value imag1 = helper->extractComplexPart(cVal1, /*isImagPart=*/true); |
| 73 | EXPECT_EQ(realTy1, real1.getType()); |
| 74 | EXPECT_EQ(realTy1, imag1.getType()); |
| 75 | |
| 76 | mlir::Value cVal3 = |
| 77 | helper->insertComplexPart(cVal1, rThree, /*isImagPart=*/false); |
| 78 | mlir::Value cVal4 = |
| 79 | helper->insertComplexPart(cVal3, rFour, /*isImagPart=*/true); |
| 80 | EXPECT_TRUE(fir::isa_complex(cVal4.getType())); |
| 81 | EXPECT_TRUE(fir::isa_real(helper->getComplexPartType(cVal4))); |
| 82 | } |
| 83 | |
| 84 | TEST_F(ComplexTest, verifyConvertWithSemantics) { |
| 85 | auto loc = firBuilder->getUnknownLoc(); |
| 86 | rOne = firBuilder->createRealConstant(loc, realTy1, 1u); |
| 87 | // Convert real to complex |
| 88 | mlir::Value v1 = firBuilder->convertWithSemantics(loc, complexTy1, rOne); |
| 89 | EXPECT_TRUE(fir::isa_complex(v1.getType())); |
| 90 | |
| 91 | // Convert complex to integer |
| 92 | mlir::Value v2 = firBuilder->convertWithSemantics(loc, integerTy1, v1); |
| 93 | EXPECT_TRUE(mlir::isa<mlir::IntegerType>(v2.getType())); |
| 94 | EXPECT_TRUE(mlir::dyn_cast<fir::ConvertOp>(v2.getDefiningOp())); |
| 95 | } |
| 96 | |