//===- RuntimeCallTestBase.h -- Base for runtime call generation tests ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef FORTRAN_OPTIMIZER_BUILDER_RUNTIME_RUNTIMECALLTESTBASE_H
10#define FORTRAN_OPTIMIZER_BUILDER_RUNTIME_RUNTIMECALLTESTBASE_H
11
12#include "gtest/gtest.h"
13#include "flang/Optimizer/Builder/FIRBuilder.h"
14#include "flang/Optimizer/Dialect/Support/KindMapping.h"
15#include "flang/Optimizer/Support/InitFIR.h"
16
17struct RuntimeCallTest : public testing::Test {
18public:
19 void SetUp() override {
20 fir::support::loadDialects(context);
21
22 mlir::OpBuilder builder(&context);
23 auto loc = builder.getUnknownLoc();
24
25 // Set up a Module with a dummy function operation inside.
26 // Set the insertion point in the function entry block.
27 moduleOp = builder.create<mlir::ModuleOp>(loc);
28 builder.setInsertionPointToStart(moduleOp->getBody());
29 mlir::func::FuncOp func =
30 builder.create<mlir::func::FuncOp>(loc, "runtime_unit_tests_func",
31 builder.getFunctionType(std::nullopt, std::nullopt));
32 auto *entryBlock = func.addEntryBlock();
33 builder.setInsertionPointToStart(entryBlock);
34
35 kindMap = std::make_unique<fir::KindMapping>(&context);
36 firBuilder = std::make_unique<fir::FirOpBuilder>(builder, *kindMap);
37
38 i1Ty = firBuilder->getI1Type();
39 i8Ty = firBuilder->getI8Type();
40 i16Ty = firBuilder->getIntegerType(16);
41 i32Ty = firBuilder->getI32Type();
42 i64Ty = firBuilder->getI64Type();
43 i128Ty = firBuilder->getIntegerType(128);
44
45 f32Ty = firBuilder->getF32Type();
46 f64Ty = firBuilder->getF64Type();
47 f80Ty = firBuilder->getF80Type();
48 f128Ty = firBuilder->getF128Type();
49
50 c4Ty = mlir::ComplexType::get(f32Ty);
51 c8Ty = mlir::ComplexType::get(f64Ty);
52 c10Ty = mlir::ComplexType::get(f80Ty);
53 c16Ty = mlir::ComplexType::get(f128Ty);
54
55 seqTy10 = fir::SequenceType::get(fir::SequenceType::Shape(1, 10), i32Ty);
56 boxTy = fir::BoxType::get(mlir::NoneType::get(firBuilder->getContext()));
57
58 char1Ty = fir::CharacterType::getSingleton(builder.getContext(), 1);
59 char2Ty = fir::CharacterType::getSingleton(builder.getContext(), 2);
60 char4Ty = fir::CharacterType::getSingleton(builder.getContext(), 4);
61
62 logical1Ty = fir::LogicalType::get(builder.getContext(), 1);
63 logical2Ty = fir::LogicalType::get(builder.getContext(), 2);
64 logical4Ty = fir::LogicalType::get(builder.getContext(), 4);
65 logical8Ty = fir::LogicalType::get(builder.getContext(), 8);
66 }
67
68 mlir::MLIRContext context;
69 mlir::OwningOpRef<mlir::ModuleOp> moduleOp;
70 std::unique_ptr<fir::KindMapping> kindMap;
71 std::unique_ptr<fir::FirOpBuilder> firBuilder;
72
73 // Commonly used types
74 mlir::Type i1Ty;
75 mlir::Type i8Ty;
76 mlir::Type i16Ty;
77 mlir::Type i32Ty;
78 mlir::Type i64Ty;
79 mlir::Type i128Ty;
80 mlir::Type f32Ty;
81 mlir::Type f64Ty;
82 mlir::Type f80Ty;
83 mlir::Type f128Ty;
84 mlir::Type c4Ty;
85 mlir::Type c8Ty;
86 mlir::Type c10Ty;
87 mlir::Type c16Ty;
88 mlir::Type seqTy10;
89 mlir::Type boxTy;
90 mlir::Type char1Ty;
91 mlir::Type char2Ty;
92 mlir::Type char4Ty;
93 mlir::Type logical1Ty;
94 mlir::Type logical2Ty;
95 mlir::Type logical4Ty;
96 mlir::Type logical8Ty;
97};
98
99/// Check that the \p op is a `fir::CallOp` operation and its name matches
100/// \p fctName and the number of arguments is equal to \p nbArgs.
101/// Most runtime calls have two additional location arguments added. These are
102/// added in this check when \p addLocArgs is true.
103static inline void checkCallOp(mlir::Operation *op, llvm::StringRef fctName,
104 unsigned nbArgs, bool addLocArgs = true) {
105 EXPECT_TRUE(mlir::isa<fir::CallOp>(*op));
106 auto callOp = mlir::dyn_cast<fir::CallOp>(*op);
107 EXPECT_TRUE(callOp.getCallee().has_value());
108 mlir::SymbolRefAttr callee = *callOp.getCallee();
109 EXPECT_EQ(fctName, callee.getRootReference().getValue());
110 // sourceFile and sourceLine are added arguments.
111 if (addLocArgs)
112 nbArgs += 2;
113 EXPECT_EQ(nbArgs, callOp.getArgs().size());
114}
115
116/// Check the call operation from the \p result value. In some cases the
117/// value is directly used in the call and sometimes there is an indirection
118/// through a `fir.convert` operation. Once the `fir.call` operation is
119/// retrieved the check is made by `checkCallOp`.
120///
121/// Directly used in `fir.call`.
122/// ```
123/// %result = arith.constant 1 : i32
124/// %0 = fir.call @foo(%result) : (i32) -> i1
125/// ```
126///
127/// Value used in `fir.call` through `fir.convert` indirection.
128/// ```
129/// %result = arith.constant 1 : i32
130/// %arg = fir.convert %result : (i32) -> i16
131/// %0 = fir.call @foo(%arg) : (i16) -> i1
132/// ```
133static inline void checkCallOpFromResultBox(mlir::Value result,
134 llvm::StringRef fctName, unsigned nbArgs, bool addLocArgs = true) {
135 EXPECT_TRUE(result.hasOneUse());
136 const auto &u = result.user_begin();
137 if (mlir::isa<fir::CallOp>(*u))
138 return checkCallOp(*u, fctName, nbArgs, addLocArgs);
139 auto convOp = mlir::dyn_cast<fir::ConvertOp>(*u);
140 EXPECT_NE(nullptr, convOp);
141 checkCallOpFromResultBox(convOp.getResult(), fctName, nbArgs, addLocArgs);
142}
143
144/// Check the operations in \p block for a `fir::CallOp` operation where the
145/// function being called shares its function name with \p fctName and the
146/// number of arguments is equal to \p nbArgs. Note that this check only cares
147/// if the operation exists, and not the order in when the operation is called.
148/// This results in exiting the test as soon as the first correct instance of
149/// `fir::CallOp` is found).
150static inline void checkBlockForCallOp(
151 mlir::Block *block, llvm::StringRef fctName, unsigned nbArgs) {
152 assert(block && "mlir::Block given is a nullptr");
153 for (auto &op : block->getOperations()) {
154 if (auto callOp = mlir::dyn_cast<fir::CallOp>(op)) {
155 if (fctName == callOp.getCallee()->getRootReference().getValue()) {
156 EXPECT_EQ(nbArgs, callOp.getArgs().size());
157 return;
158 }
159 }
160 }
161 FAIL() << "No calls to " << fctName << " were found!";
162}
163
164#endif // FORTRAN_OPTIMIZER_BUILDER_RUNTIME_RUNTIMECALLTESTBASE_H
165

source code of flang/unittests/Optimizer/Builder/Runtime/RuntimeCallTestBase.h