1 | //===- InlineCostTest.cpp - test for InlineCost ---------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "llvm/Analysis/InlineCost.h" |
10 | #include "llvm/Analysis/AssumptionCache.h" |
11 | #include "llvm/Analysis/InlineModelFeatureMaps.h" |
12 | #include "llvm/Analysis/TargetTransformInfo.h" |
13 | #include "llvm/AsmParser/Parser.h" |
14 | #include "llvm/IR/InstIterator.h" |
15 | #include "llvm/IR/Instructions.h" |
16 | #include "llvm/IR/LLVMContext.h" |
17 | #include "llvm/IR/Module.h" |
18 | #include "llvm/Support/SourceMgr.h" |
19 | #include "gtest/gtest.h" |
20 | |
21 | namespace { |
22 | |
23 | using namespace llvm; |
24 | |
25 | CallBase *getCallInFunction(Function *F) { |
26 | for (auto &I : instructions(F)) { |
27 | if (auto *CB = dyn_cast<llvm::CallBase>(Val: &I)) |
28 | return CB; |
29 | } |
30 | return nullptr; |
31 | } |
32 | |
33 | std::optional<InlineCostFeatures> getInliningCostFeaturesForCall(CallBase &CB) { |
34 | ModuleAnalysisManager MAM; |
35 | FunctionAnalysisManager FAM; |
36 | FAM.registerPass(PassBuilder: [&] { return TargetIRAnalysis(); }); |
37 | FAM.registerPass(PassBuilder: [&] { return ModuleAnalysisManagerFunctionProxy(MAM); }); |
38 | FAM.registerPass(PassBuilder: [&] { return AssumptionAnalysis(); }); |
39 | MAM.registerPass(PassBuilder: [&] { return FunctionAnalysisManagerModuleProxy(FAM); }); |
40 | |
41 | MAM.registerPass(PassBuilder: [&] { return PassInstrumentationAnalysis(); }); |
42 | FAM.registerPass(PassBuilder: [&] { return PassInstrumentationAnalysis(); }); |
43 | |
44 | ModulePassManager MPM; |
45 | MPM.run(IR&: *CB.getModule(), AM&: MAM); |
46 | |
47 | auto GetAssumptionCache = [&](Function &F) -> AssumptionCache & { |
48 | return FAM.getResult<AssumptionAnalysis>(IR&: F); |
49 | }; |
50 | auto &TIR = FAM.getResult<TargetIRAnalysis>(IR&: *CB.getFunction()); |
51 | |
52 | return getInliningCostFeatures(Call&: CB, CalleeTTI&: TIR, GetAssumptionCache); |
53 | } |
54 | |
55 | // Tests that we can retrieve the CostFeatures without an error |
56 | TEST(InlineCostTest, CostFeatures) { |
57 | const auto *const IR = R"IR( |
58 | define i32 @f(i32) { |
59 | ret i32 4 |
60 | } |
61 | |
62 | define i32 @g(i32) { |
63 | %2 = call i32 @f(i32 0) |
64 | ret i32 %2 |
65 | } |
66 | )IR" ; |
67 | |
68 | LLVMContext C; |
69 | SMDiagnostic Err; |
70 | std::unique_ptr<Module> M = parseAssemblyString(AsmString: IR, Err, Context&: C); |
71 | ASSERT_TRUE(M); |
72 | |
73 | auto *G = M->getFunction(Name: "g" ); |
74 | ASSERT_TRUE(G); |
75 | |
76 | // find the call to f in g |
77 | CallBase *CB = getCallInFunction(F: G); |
78 | ASSERT_TRUE(CB); |
79 | |
80 | const auto Features = getInliningCostFeaturesForCall(CB&: *CB); |
81 | |
82 | // Check that the optional is not empty |
83 | ASSERT_TRUE(Features); |
84 | } |
85 | |
86 | // Tests the calculated SROA cost |
87 | TEST(InlineCostTest, SROACost) { |
88 | using namespace llvm; |
89 | |
90 | const auto *const IR = R"IR( |
91 | define void @f_savings(ptr %var) { |
92 | %load = load i32, ptr %var |
93 | %inc = add i32 %load, 1 |
94 | store i32 %inc, ptr %var |
95 | ret void |
96 | } |
97 | |
98 | define void @g_savings(i32) { |
99 | %var = alloca i32 |
100 | call void @f_savings(ptr %var) |
101 | ret void |
102 | } |
103 | |
104 | define void @f_losses(ptr %var) { |
105 | %load = load i32, ptr %var |
106 | %inc = add i32 %load, 1 |
107 | store i32 %inc, ptr %var |
108 | call void @prevent_sroa(ptr %var) |
109 | ret void |
110 | } |
111 | |
112 | define void @g_losses(i32) { |
113 | %var = alloca i32 |
114 | call void @f_losses(ptr %var) |
115 | ret void |
116 | } |
117 | |
118 | declare void @prevent_sroa(ptr) |
119 | )IR" ; |
120 | |
121 | LLVMContext C; |
122 | SMDiagnostic Err; |
123 | std::unique_ptr<Module> M = parseAssemblyString(AsmString: IR, Err, Context&: C); |
124 | ASSERT_TRUE(M); |
125 | |
126 | const int DefaultInstCost = 5; |
127 | const int DefaultAllocaCost = 0; |
128 | |
129 | const char *GName[] = {"g_savings" , "g_losses" , nullptr}; |
130 | const int Savings[] = {2 * DefaultInstCost + DefaultAllocaCost, 0}; |
131 | const int Losses[] = {0, 2 * DefaultInstCost + DefaultAllocaCost}; |
132 | |
133 | for (unsigned i = 0; GName[i]; ++i) { |
134 | auto *G = M->getFunction(Name: GName[i]); |
135 | ASSERT_TRUE(G); |
136 | |
137 | // find the call to f in g |
138 | CallBase *CB = getCallInFunction(F: G); |
139 | ASSERT_TRUE(CB); |
140 | |
141 | const auto Features = getInliningCostFeaturesForCall(CB&: *CB); |
142 | ASSERT_TRUE(Features); |
143 | |
144 | // Check the predicted SROA cost |
145 | auto GetFeature = [&](InlineCostFeatureIndex I) { |
146 | return (*Features)[static_cast<size_t>(I)]; |
147 | }; |
148 | ASSERT_EQ(GetFeature(InlineCostFeatureIndex::sroa_savings), Savings[i]); |
149 | ASSERT_EQ(GetFeature(InlineCostFeatureIndex::sroa_losses), Losses[i]); |
150 | } |
151 | } |
152 | |
153 | } // namespace |
154 | |