1//===- InlineCostTest.cpp - test for InlineCost ---------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "llvm/Analysis/InlineCost.h"
10#include "llvm/Analysis/AssumptionCache.h"
11#include "llvm/Analysis/InlineModelFeatureMaps.h"
12#include "llvm/Analysis/TargetTransformInfo.h"
13#include "llvm/AsmParser/Parser.h"
14#include "llvm/IR/InstIterator.h"
15#include "llvm/IR/Instructions.h"
16#include "llvm/IR/LLVMContext.h"
17#include "llvm/IR/Module.h"
18#include "llvm/Support/SourceMgr.h"
19#include "gtest/gtest.h"
20
21namespace {
22
23using namespace llvm;
24
25CallBase *getCallInFunction(Function *F) {
26 for (auto &I : instructions(F)) {
27 if (auto *CB = dyn_cast<llvm::CallBase>(Val: &I))
28 return CB;
29 }
30 return nullptr;
31}
32
33std::optional<InlineCostFeatures> getInliningCostFeaturesForCall(CallBase &CB) {
34 ModuleAnalysisManager MAM;
35 FunctionAnalysisManager FAM;
36 FAM.registerPass(PassBuilder: [&] { return TargetIRAnalysis(); });
37 FAM.registerPass(PassBuilder: [&] { return ModuleAnalysisManagerFunctionProxy(MAM); });
38 FAM.registerPass(PassBuilder: [&] { return AssumptionAnalysis(); });
39 MAM.registerPass(PassBuilder: [&] { return FunctionAnalysisManagerModuleProxy(FAM); });
40
41 MAM.registerPass(PassBuilder: [&] { return PassInstrumentationAnalysis(); });
42 FAM.registerPass(PassBuilder: [&] { return PassInstrumentationAnalysis(); });
43
44 ModulePassManager MPM;
45 MPM.run(IR&: *CB.getModule(), AM&: MAM);
46
47 auto GetAssumptionCache = [&](Function &F) -> AssumptionCache & {
48 return FAM.getResult<AssumptionAnalysis>(IR&: F);
49 };
50 auto &TIR = FAM.getResult<TargetIRAnalysis>(IR&: *CB.getFunction());
51
52 return getInliningCostFeatures(Call&: CB, CalleeTTI&: TIR, GetAssumptionCache);
53}
54
55// Tests that we can retrieve the CostFeatures without an error
56TEST(InlineCostTest, CostFeatures) {
57 const auto *const IR = R"IR(
58define i32 @f(i32) {
59 ret i32 4
60}
61
62define i32 @g(i32) {
63 %2 = call i32 @f(i32 0)
64 ret i32 %2
65}
66)IR";
67
68 LLVMContext C;
69 SMDiagnostic Err;
70 std::unique_ptr<Module> M = parseAssemblyString(AsmString: IR, Err, Context&: C);
71 ASSERT_TRUE(M);
72
73 auto *G = M->getFunction(Name: "g");
74 ASSERT_TRUE(G);
75
76 // find the call to f in g
77 CallBase *CB = getCallInFunction(F: G);
78 ASSERT_TRUE(CB);
79
80 const auto Features = getInliningCostFeaturesForCall(CB&: *CB);
81
82 // Check that the optional is not empty
83 ASSERT_TRUE(Features);
84}
85
86// Tests the calculated SROA cost
87TEST(InlineCostTest, SROACost) {
88 using namespace llvm;
89
90 const auto *const IR = R"IR(
91define void @f_savings(ptr %var) {
92 %load = load i32, ptr %var
93 %inc = add i32 %load, 1
94 store i32 %inc, ptr %var
95 ret void
96}
97
98define void @g_savings(i32) {
99 %var = alloca i32
100 call void @f_savings(ptr %var)
101 ret void
102}
103
104define void @f_losses(ptr %var) {
105 %load = load i32, ptr %var
106 %inc = add i32 %load, 1
107 store i32 %inc, ptr %var
108 call void @prevent_sroa(ptr %var)
109 ret void
110}
111
112define void @g_losses(i32) {
113 %var = alloca i32
114 call void @f_losses(ptr %var)
115 ret void
116}
117
118declare void @prevent_sroa(ptr)
119)IR";
120
121 LLVMContext C;
122 SMDiagnostic Err;
123 std::unique_ptr<Module> M = parseAssemblyString(AsmString: IR, Err, Context&: C);
124 ASSERT_TRUE(M);
125
126 const int DefaultInstCost = 5;
127 const int DefaultAllocaCost = 0;
128
129 const char *GName[] = {"g_savings", "g_losses", nullptr};
130 const int Savings[] = {2 * DefaultInstCost + DefaultAllocaCost, 0};
131 const int Losses[] = {0, 2 * DefaultInstCost + DefaultAllocaCost};
132
133 for (unsigned i = 0; GName[i]; ++i) {
134 auto *G = M->getFunction(Name: GName[i]);
135 ASSERT_TRUE(G);
136
137 // find the call to f in g
138 CallBase *CB = getCallInFunction(F: G);
139 ASSERT_TRUE(CB);
140
141 const auto Features = getInliningCostFeaturesForCall(CB&: *CB);
142 ASSERT_TRUE(Features);
143
144 // Check the predicted SROA cost
145 auto GetFeature = [&](InlineCostFeatureIndex I) {
146 return (*Features)[static_cast<size_t>(I)];
147 };
148 ASSERT_EQ(GetFeature(InlineCostFeatureIndex::sroa_savings), Savings[i]);
149 ASSERT_EQ(GetFeature(InlineCostFeatureIndex::sroa_losses), Losses[i]);
150 }
151}
152
153} // namespace
154

source code of llvm/unittests/Analysis/InlineCostTest.cpp