//===- RuntimeCallTestBase.h -- Base for runtime call generation tests ----===//
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #ifndef FORTRAN_OPTIMIZER_BUILDER_RUNTIME_RUNTIMECALLTESTBASE_H |
| 10 | #define FORTRAN_OPTIMIZER_BUILDER_RUNTIME_RUNTIMECALLTESTBASE_H |
| 11 | |
| 12 | #include "gtest/gtest.h" |
| 13 | #include "flang/Optimizer/Builder/FIRBuilder.h" |
| 14 | #include "flang/Optimizer/Dialect/Support/KindMapping.h" |
| 15 | #include "flang/Optimizer/Support/InitFIR.h" |
| 16 | |
| 17 | struct RuntimeCallTest : public testing::Test { |
| 18 | public: |
| 19 | void SetUp() override { |
| 20 | fir::support::loadDialects(context); |
| 21 | |
| 22 | mlir::OpBuilder builder(&context); |
| 23 | auto loc = builder.getUnknownLoc(); |
| 24 | |
| 25 | // Set up a Module with a dummy function operation inside. |
| 26 | // Set the insertion point in the function entry block. |
| 27 | moduleOp = builder.create<mlir::ModuleOp>(loc); |
| 28 | builder.setInsertionPointToStart(moduleOp->getBody()); |
| 29 | mlir::func::FuncOp func = |
| 30 | builder.create<mlir::func::FuncOp>(loc, "runtime_unit_tests_func" , |
| 31 | builder.getFunctionType(std::nullopt, std::nullopt)); |
| 32 | auto *entryBlock = func.addEntryBlock(); |
| 33 | builder.setInsertionPointToStart(entryBlock); |
| 34 | |
| 35 | kindMap = std::make_unique<fir::KindMapping>(&context); |
| 36 | firBuilder = std::make_unique<fir::FirOpBuilder>(builder, *kindMap); |
| 37 | |
| 38 | i1Ty = firBuilder->getI1Type(); |
| 39 | i8Ty = firBuilder->getI8Type(); |
| 40 | i16Ty = firBuilder->getIntegerType(16); |
| 41 | i32Ty = firBuilder->getI32Type(); |
| 42 | i64Ty = firBuilder->getI64Type(); |
| 43 | i128Ty = firBuilder->getIntegerType(128); |
| 44 | |
| 45 | f32Ty = firBuilder->getF32Type(); |
| 46 | f64Ty = firBuilder->getF64Type(); |
| 47 | f80Ty = firBuilder->getF80Type(); |
| 48 | f128Ty = firBuilder->getF128Type(); |
| 49 | |
| 50 | c4Ty = mlir::ComplexType::get(f32Ty); |
| 51 | c8Ty = mlir::ComplexType::get(f64Ty); |
| 52 | c10Ty = mlir::ComplexType::get(f80Ty); |
| 53 | c16Ty = mlir::ComplexType::get(f128Ty); |
| 54 | |
| 55 | seqTy10 = fir::SequenceType::get(fir::SequenceType::Shape(1, 10), i32Ty); |
| 56 | boxTy = fir::BoxType::get(mlir::NoneType::get(firBuilder->getContext())); |
| 57 | |
| 58 | char1Ty = fir::CharacterType::getSingleton(builder.getContext(), 1); |
| 59 | char2Ty = fir::CharacterType::getSingleton(builder.getContext(), 2); |
| 60 | char4Ty = fir::CharacterType::getSingleton(builder.getContext(), 4); |
| 61 | |
| 62 | logical1Ty = fir::LogicalType::get(builder.getContext(), 1); |
| 63 | logical2Ty = fir::LogicalType::get(builder.getContext(), 2); |
| 64 | logical4Ty = fir::LogicalType::get(builder.getContext(), 4); |
| 65 | logical8Ty = fir::LogicalType::get(builder.getContext(), 8); |
| 66 | } |
| 67 | |
| 68 | mlir::MLIRContext context; |
| 69 | mlir::OwningOpRef<mlir::ModuleOp> moduleOp; |
| 70 | std::unique_ptr<fir::KindMapping> kindMap; |
| 71 | std::unique_ptr<fir::FirOpBuilder> firBuilder; |
| 72 | |
| 73 | // Commonly used types |
| 74 | mlir::Type i1Ty; |
| 75 | mlir::Type i8Ty; |
| 76 | mlir::Type i16Ty; |
| 77 | mlir::Type i32Ty; |
| 78 | mlir::Type i64Ty; |
| 79 | mlir::Type i128Ty; |
| 80 | mlir::Type f32Ty; |
| 81 | mlir::Type f64Ty; |
| 82 | mlir::Type f80Ty; |
| 83 | mlir::Type f128Ty; |
| 84 | mlir::Type c4Ty; |
| 85 | mlir::Type c8Ty; |
| 86 | mlir::Type c10Ty; |
| 87 | mlir::Type c16Ty; |
| 88 | mlir::Type seqTy10; |
| 89 | mlir::Type boxTy; |
| 90 | mlir::Type char1Ty; |
| 91 | mlir::Type char2Ty; |
| 92 | mlir::Type char4Ty; |
| 93 | mlir::Type logical1Ty; |
| 94 | mlir::Type logical2Ty; |
| 95 | mlir::Type logical4Ty; |
| 96 | mlir::Type logical8Ty; |
| 97 | }; |
| 98 | |
| 99 | /// Check that the \p op is a `fir::CallOp` operation and its name matches |
| 100 | /// \p fctName and the number of arguments is equal to \p nbArgs. |
| 101 | /// Most runtime calls have two additional location arguments added. These are |
| 102 | /// added in this check when \p addLocArgs is true. |
| 103 | static inline void checkCallOp(mlir::Operation *op, llvm::StringRef fctName, |
| 104 | unsigned nbArgs, bool addLocArgs = true) { |
| 105 | EXPECT_TRUE(mlir::isa<fir::CallOp>(*op)); |
| 106 | auto callOp = mlir::dyn_cast<fir::CallOp>(*op); |
| 107 | EXPECT_TRUE(callOp.getCallee().has_value()); |
| 108 | mlir::SymbolRefAttr callee = *callOp.getCallee(); |
| 109 | EXPECT_EQ(fctName, callee.getRootReference().getValue()); |
| 110 | // sourceFile and sourceLine are added arguments. |
| 111 | if (addLocArgs) |
| 112 | nbArgs += 2; |
| 113 | EXPECT_EQ(nbArgs, callOp.getArgs().size()); |
| 114 | } |
| 115 | |
| 116 | /// Check the call operation from the \p result value. In some cases the |
| 117 | /// value is directly used in the call and sometimes there is an indirection |
| 118 | /// through a `fir.convert` operation. Once the `fir.call` operation is |
| 119 | /// retrieved the check is made by `checkCallOp`. |
| 120 | /// |
| 121 | /// Directly used in `fir.call`. |
| 122 | /// ``` |
| 123 | /// %result = arith.constant 1 : i32 |
| 124 | /// %0 = fir.call @foo(%result) : (i32) -> i1 |
| 125 | /// ``` |
| 126 | /// |
| 127 | /// Value used in `fir.call` through `fir.convert` indirection. |
| 128 | /// ``` |
| 129 | /// %result = arith.constant 1 : i32 |
| 130 | /// %arg = fir.convert %result : (i32) -> i16 |
| 131 | /// %0 = fir.call @foo(%arg) : (i16) -> i1 |
| 132 | /// ``` |
| 133 | static inline void checkCallOpFromResultBox(mlir::Value result, |
| 134 | llvm::StringRef fctName, unsigned nbArgs, bool addLocArgs = true) { |
| 135 | EXPECT_TRUE(result.hasOneUse()); |
| 136 | const auto &u = result.user_begin(); |
| 137 | if (mlir::isa<fir::CallOp>(*u)) |
| 138 | return checkCallOp(*u, fctName, nbArgs, addLocArgs); |
| 139 | auto convOp = mlir::dyn_cast<fir::ConvertOp>(*u); |
| 140 | EXPECT_NE(nullptr, convOp); |
| 141 | checkCallOpFromResultBox(convOp.getResult(), fctName, nbArgs, addLocArgs); |
| 142 | } |
| 143 | |
| 144 | /// Check the operations in \p block for a `fir::CallOp` operation where the |
| 145 | /// function being called shares its function name with \p fctName and the |
| 146 | /// number of arguments is equal to \p nbArgs. Note that this check only cares |
| 147 | /// if the operation exists, and not the order in when the operation is called. |
| 148 | /// This results in exiting the test as soon as the first correct instance of |
| 149 | /// `fir::CallOp` is found). |
| 150 | static inline void checkBlockForCallOp( |
| 151 | mlir::Block *block, llvm::StringRef fctName, unsigned nbArgs) { |
| 152 | assert(block && "mlir::Block given is a nullptr" ); |
| 153 | for (auto &op : block->getOperations()) { |
| 154 | if (auto callOp = mlir::dyn_cast<fir::CallOp>(op)) { |
| 155 | if (fctName == callOp.getCallee()->getRootReference().getValue()) { |
| 156 | EXPECT_EQ(nbArgs, callOp.getArgs().size()); |
| 157 | return; |
| 158 | } |
| 159 | } |
| 160 | } |
| 161 | FAIL() << "No calls to " << fctName << " were found!" ; |
| 162 | } |
| 163 | |
| 164 | #endif // FORTRAN_OPTIMIZER_BUILDER_RUNTIME_RUNTIMECALLTESTBASE_H |
| 165 | |