1 | //===- llvm/unittest/IR/LegacyPassManager.cpp - Legacy PassManager tests --===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This unit test exercises the legacy pass manager infrastructure. We use the |
10 | // old names as well to ensure that the source-level compatibility is preserved |
11 | // where possible. |
12 | // |
13 | //===----------------------------------------------------------------------===// |
14 | |
15 | #include "llvm/IR/LegacyPassManager.h" |
16 | #include "llvm/Analysis/CallGraph.h" |
17 | #include "llvm/Analysis/CallGraphSCCPass.h" |
18 | #include "llvm/Analysis/LoopInfo.h" |
19 | #include "llvm/Analysis/LoopPass.h" |
20 | #include "llvm/AsmParser/Parser.h" |
21 | #include "llvm/IR/AbstractCallSite.h" |
22 | #include "llvm/IR/BasicBlock.h" |
23 | #include "llvm/IR/CallingConv.h" |
24 | #include "llvm/IR/DataLayout.h" |
25 | #include "llvm/IR/DerivedTypes.h" |
26 | #include "llvm/IR/Function.h" |
27 | #include "llvm/IR/GlobalVariable.h" |
28 | #include "llvm/IR/Instructions.h" |
29 | #include "llvm/IR/LLVMContext.h" |
30 | #include "llvm/IR/Module.h" |
31 | #include "llvm/IR/OptBisect.h" |
32 | #include "llvm/InitializePasses.h" |
33 | #include "llvm/Support/MathExtras.h" |
34 | #include "llvm/Support/SourceMgr.h" |
35 | #include "llvm/Support/raw_ostream.h" |
36 | #include "llvm/Transforms/Utils/CallGraphUpdater.h" |
37 | #include "gtest/gtest.h" |
38 | |
39 | using namespace llvm; |
40 | |
41 | namespace llvm { |
42 | void initializeModuleNDMPass(PassRegistry&); |
43 | void initializeFPassPass(PassRegistry&); |
44 | void initializeCGPassPass(PassRegistry&); |
45 | void initializeLPassPass(PassRegistry&); |
46 | |
47 | namespace { |
48 | // ND = no deps |
49 | // NM = no modifications |
50 | struct ModuleNDNM: public ModulePass { |
51 | public: |
52 | static char run; |
53 | static char ID; |
54 | ModuleNDNM() : ModulePass(ID) { } |
55 | bool runOnModule(Module &M) override { |
56 | run++; |
57 | return false; |
58 | } |
59 | void getAnalysisUsage(AnalysisUsage &AU) const override { |
60 | AU.setPreservesAll(); |
61 | } |
62 | }; |
63 | char ModuleNDNM::ID=0; |
64 | char ModuleNDNM::run=0; |
65 | |
66 | struct ModuleNDM : public ModulePass { |
67 | public: |
68 | static char run; |
69 | static char ID; |
70 | ModuleNDM() : ModulePass(ID) {} |
71 | bool runOnModule(Module &M) override { |
72 | run++; |
73 | return true; |
74 | } |
75 | }; |
76 | char ModuleNDM::ID=0; |
77 | char ModuleNDM::run=0; |
78 | |
79 | struct ModuleNDM2 : public ModulePass { |
80 | public: |
81 | static char run; |
82 | static char ID; |
83 | ModuleNDM2() : ModulePass(ID) {} |
84 | bool runOnModule(Module &M) override { |
85 | run++; |
86 | return true; |
87 | } |
88 | }; |
89 | char ModuleNDM2::ID=0; |
90 | char ModuleNDM2::run=0; |
91 | |
92 | struct ModuleDNM : public ModulePass { |
93 | public: |
94 | static char run; |
95 | static char ID; |
96 | ModuleDNM() : ModulePass(ID) { |
97 | initializeModuleNDMPass(*PassRegistry::getPassRegistry()); |
98 | } |
99 | bool runOnModule(Module &M) override { |
100 | run++; |
101 | return false; |
102 | } |
103 | void getAnalysisUsage(AnalysisUsage &AU) const override { |
104 | AU.addRequired<ModuleNDM>(); |
105 | AU.setPreservesAll(); |
106 | } |
107 | }; |
108 | char ModuleDNM::ID=0; |
109 | char ModuleDNM::run=0; |
110 | |
111 | template<typename P> |
112 | struct PassTestBase : public P { |
113 | protected: |
114 | static int runc; |
115 | static bool initialized; |
116 | static bool finalized; |
117 | int allocated; |
118 | void run() { |
119 | EXPECT_TRUE(initialized); |
120 | EXPECT_FALSE(finalized); |
121 | EXPECT_EQ(0, allocated); |
122 | allocated++; |
123 | runc++; |
124 | } |
125 | public: |
126 | static char ID; |
127 | static void finishedOK(int run) { |
128 | EXPECT_GT(runc, 0); |
129 | EXPECT_TRUE(initialized); |
130 | EXPECT_TRUE(finalized); |
131 | EXPECT_EQ(run, runc); |
132 | } |
133 | PassTestBase() : P(ID), allocated(0) { |
134 | initialized = false; |
135 | finalized = false; |
136 | runc = 0; |
137 | } |
138 | |
139 | void releaseMemory() override { |
140 | EXPECT_GT(runc, 0); |
141 | EXPECT_GT(allocated, 0); |
142 | allocated--; |
143 | } |
144 | }; |
145 | template<typename P> char PassTestBase<P>::ID; |
146 | template<typename P> int PassTestBase<P>::runc; |
147 | template<typename P> bool PassTestBase<P>::initialized; |
148 | template<typename P> bool PassTestBase<P>::finalized; |
149 | |
150 | template<typename T, typename P> |
151 | struct PassTest : public PassTestBase<P> { |
152 | public: |
153 | #ifndef _MSC_VER // MSVC complains that Pass is not base class. |
154 | using llvm::Pass::doInitialization; |
155 | using llvm::Pass::doFinalization; |
156 | #endif |
157 | bool doInitialization(T &t) override { |
158 | EXPECT_FALSE(PassTestBase<P>::initialized); |
159 | PassTestBase<P>::initialized = true; |
160 | return false; |
161 | } |
162 | bool doFinalization(T &t) override { |
163 | EXPECT_FALSE(PassTestBase<P>::finalized); |
164 | PassTestBase<P>::finalized = true; |
165 | EXPECT_EQ(0, PassTestBase<P>::allocated); |
166 | return false; |
167 | } |
168 | }; |
169 | |
170 | struct CGPass : public PassTest<CallGraph, CallGraphSCCPass> { |
171 | public: |
172 | CGPass() { |
173 | initializeCGPassPass(*PassRegistry::getPassRegistry()); |
174 | } |
175 | bool runOnSCC(CallGraphSCC &SCMM) override { |
176 | run(); |
177 | return false; |
178 | } |
179 | }; |
180 | |
181 | struct FPass : public PassTest<Module, FunctionPass> { |
182 | public: |
183 | bool runOnFunction(Function &F) override { |
184 | // FIXME: PR4112 |
185 | // EXPECT_TRUE(getAnalysisIfAvailable<DataLayout>()); |
186 | run(); |
187 | return false; |
188 | } |
189 | }; |
190 | |
191 | struct LPass : public PassTestBase<LoopPass> { |
192 | private: |
193 | static int initcount; |
194 | static int fincount; |
195 | public: |
196 | LPass() { |
197 | initializeLPassPass(*PassRegistry::getPassRegistry()); |
198 | initcount = 0; fincount=0; |
199 | EXPECT_FALSE(initialized); |
200 | } |
201 | static void finishedOK(int run, int finalized) { |
202 | PassTestBase<LoopPass>::finishedOK(run); |
203 | EXPECT_EQ(run, initcount); |
204 | EXPECT_EQ(finalized, fincount); |
205 | } |
206 | using llvm::Pass::doInitialization; |
207 | using llvm::Pass::doFinalization; |
208 | bool doInitialization(Loop* L, LPPassManager &LPM) override { |
209 | initialized = true; |
210 | initcount++; |
211 | return false; |
212 | } |
213 | bool runOnLoop(Loop *L, LPPassManager &LPM) override { |
214 | run(); |
215 | return false; |
216 | } |
217 | bool doFinalization() override { |
218 | fincount++; |
219 | finalized = true; |
220 | return false; |
221 | } |
222 | }; |
223 | int LPass::initcount=0; |
224 | int LPass::fincount=0; |
225 | |
226 | struct OnTheFlyTest: public ModulePass { |
227 | public: |
228 | static char ID; |
229 | OnTheFlyTest() : ModulePass(ID) { |
230 | initializeFPassPass(*PassRegistry::getPassRegistry()); |
231 | } |
232 | bool runOnModule(Module &M) override { |
233 | for (Module::iterator I=M.begin(),E=M.end(); I != E; ++I) { |
234 | Function &F = *I; |
235 | { |
236 | SCOPED_TRACE("Running on the fly function pass" ); |
237 | getAnalysis<FPass>(F); |
238 | } |
239 | } |
240 | return false; |
241 | } |
242 | void getAnalysisUsage(AnalysisUsage &AU) const override { |
243 | AU.addRequired<FPass>(); |
244 | } |
245 | }; |
246 | char OnTheFlyTest::ID=0; |
247 | |
248 | TEST(PassManager, RunOnce) { |
249 | LLVMContext Context; |
250 | Module M("test-once" , Context); |
251 | struct ModuleNDNM *mNDNM = new ModuleNDNM(); |
252 | struct ModuleDNM *mDNM = new ModuleDNM(); |
253 | struct ModuleNDM *mNDM = new ModuleNDM(); |
254 | struct ModuleNDM2 *mNDM2 = new ModuleNDM2(); |
255 | |
256 | mNDM->run = mNDNM->run = mDNM->run = mNDM2->run = 0; |
257 | |
258 | legacy::PassManager Passes; |
259 | Passes.add(P: mNDM2); |
260 | Passes.add(P: mNDM); |
261 | Passes.add(P: mNDNM); |
262 | Passes.add(P: mDNM); |
263 | |
264 | Passes.run(M); |
265 | // each pass must be run exactly once, since nothing invalidates them |
266 | EXPECT_EQ(1, mNDM->run); |
267 | EXPECT_EQ(1, mNDNM->run); |
268 | EXPECT_EQ(1, mDNM->run); |
269 | EXPECT_EQ(1, mNDM2->run); |
270 | } |
271 | |
272 | TEST(PassManager, ReRun) { |
273 | LLVMContext Context; |
274 | Module M("test-rerun" , Context); |
275 | struct ModuleNDNM *mNDNM = new ModuleNDNM(); |
276 | struct ModuleDNM *mDNM = new ModuleDNM(); |
277 | struct ModuleNDM *mNDM = new ModuleNDM(); |
278 | struct ModuleNDM2 *mNDM2 = new ModuleNDM2(); |
279 | |
280 | mNDM->run = mNDNM->run = mDNM->run = mNDM2->run = 0; |
281 | |
282 | legacy::PassManager Passes; |
283 | Passes.add(P: mNDM); |
284 | Passes.add(P: mNDNM); |
285 | Passes.add(P: mNDM2);// invalidates mNDM needed by mDNM |
286 | Passes.add(P: mDNM); |
287 | |
288 | Passes.run(M); |
289 | // Some passes must be rerun because a pass that modified the |
290 | // module/function was run in between |
291 | EXPECT_EQ(2, mNDM->run); |
292 | EXPECT_EQ(1, mNDNM->run); |
293 | EXPECT_EQ(1, mNDM2->run); |
294 | EXPECT_EQ(1, mDNM->run); |
295 | } |
296 | |
297 | Module *makeLLVMModule(LLVMContext &Context); |
298 | |
299 | template<typename T> |
300 | void MemoryTestHelper(int run) { |
301 | LLVMContext Context; |
302 | std::unique_ptr<Module> M(makeLLVMModule(Context)); |
303 | T *P = new T(); |
304 | legacy::PassManager Passes; |
305 | Passes.add(P); |
306 | Passes.run(M&: *M); |
307 | T::finishedOK(run); |
308 | } |
309 | |
310 | template<typename T> |
311 | void MemoryTestHelper(int run, int N) { |
312 | LLVMContext Context; |
313 | Module *M = makeLLVMModule(Context); |
314 | T *P = new T(); |
315 | legacy::PassManager Passes; |
316 | Passes.add(P); |
317 | Passes.run(M&: *M); |
318 | T::finishedOK(run, N); |
319 | delete M; |
320 | } |
321 | |
322 | TEST(PassManager, Memory) { |
323 | // SCC#1: test1->test2->test3->test1 |
324 | // SCC#2: test4 |
325 | // SCC#3: indirect call node |
326 | { |
327 | SCOPED_TRACE("Callgraph pass" ); |
328 | MemoryTestHelper<CGPass>(run: 3); |
329 | } |
330 | |
331 | { |
332 | SCOPED_TRACE("Function pass" ); |
333 | MemoryTestHelper<FPass>(run: 4);// 4 functions |
334 | } |
335 | |
336 | { |
337 | SCOPED_TRACE("Loop pass" ); |
338 | MemoryTestHelper<LPass>(run: 2, N: 1); //2 loops, 1 function |
339 | } |
340 | |
341 | } |
342 | |
343 | TEST(PassManager, MemoryOnTheFly) { |
344 | LLVMContext Context; |
345 | Module *M = makeLLVMModule(Context); |
346 | { |
347 | SCOPED_TRACE("Running OnTheFlyTest" ); |
348 | struct OnTheFlyTest *O = new OnTheFlyTest(); |
349 | legacy::PassManager Passes; |
350 | Passes.add(P: O); |
351 | Passes.run(M&: *M); |
352 | |
353 | FPass::finishedOK(run: 4); |
354 | } |
355 | delete M; |
356 | } |
357 | |
358 | // Skips or runs optional passes. |
359 | struct CustomOptPassGate : public OptPassGate { |
360 | bool Skip; |
361 | CustomOptPassGate(bool Skip) : Skip(Skip) { } |
362 | bool shouldRunPass(const StringRef PassName, StringRef IRDescription) override { |
363 | return !Skip; |
364 | } |
365 | bool isEnabled() const override { return true; } |
366 | }; |
367 | |
368 | // Optional module pass. |
369 | struct ModuleOpt: public ModulePass { |
370 | char run = 0; |
371 | static char ID; |
372 | ModuleOpt() : ModulePass(ID) { } |
373 | bool runOnModule(Module &M) override { |
374 | if (!skipModule(M)) |
375 | run++; |
376 | return false; |
377 | } |
378 | }; |
379 | char ModuleOpt::ID=0; |
380 | |
381 | TEST(PassManager, CustomOptPassGate) { |
382 | LLVMContext Context0; |
383 | LLVMContext Context1; |
384 | LLVMContext Context2; |
385 | CustomOptPassGate SkipOptionalPasses(true); |
386 | CustomOptPassGate RunOptionalPasses(false); |
387 | |
388 | Module M0("custom-opt-bisect" , Context0); |
389 | Module M1("custom-opt-bisect" , Context1); |
390 | Module M2("custom-opt-bisect2" , Context2); |
391 | struct ModuleOpt *mOpt0 = new ModuleOpt(); |
392 | struct ModuleOpt *mOpt1 = new ModuleOpt(); |
393 | struct ModuleOpt *mOpt2 = new ModuleOpt(); |
394 | |
395 | mOpt0->run = mOpt1->run = mOpt2->run = 0; |
396 | |
397 | legacy::PassManager Passes0; |
398 | legacy::PassManager Passes1; |
399 | legacy::PassManager Passes2; |
400 | |
401 | Passes0.add(P: mOpt0); |
402 | Passes1.add(P: mOpt1); |
403 | Passes2.add(P: mOpt2); |
404 | |
405 | Context1.setOptPassGate(SkipOptionalPasses); |
406 | Context2.setOptPassGate(RunOptionalPasses); |
407 | |
408 | Passes0.run(M&: M0); |
409 | Passes1.run(M&: M1); |
410 | Passes2.run(M&: M2); |
411 | |
412 | // By default optional passes are run. |
413 | EXPECT_EQ(1, mOpt0->run); |
414 | |
415 | // The first context skips optional passes. |
416 | EXPECT_EQ(0, mOpt1->run); |
417 | |
418 | // The second context runs optional passes. |
419 | EXPECT_EQ(1, mOpt2->run); |
420 | } |
421 | |
422 | Module *makeLLVMModule(LLVMContext &Context) { |
423 | // Module Construction |
424 | Module *mod = new Module("test-mem" , Context); |
425 | mod->setDataLayout("e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-" |
426 | "i64:64:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-" |
427 | "a:0:64-s:64:64-f80:128:128" ); |
428 | mod->setTargetTriple("x86_64-unknown-linux-gnu" ); |
429 | |
430 | // Type Definitions |
431 | std::vector<Type*>FuncTy_0_args; |
432 | FunctionType *FuncTy_0 = FunctionType::get( |
433 | /*Result=*/IntegerType::get(C&: Context, NumBits: 32), |
434 | /*Params=*/FuncTy_0_args, |
435 | /*isVarArg=*/false); |
436 | |
437 | std::vector<Type*>FuncTy_2_args; |
438 | FuncTy_2_args.push_back(x: IntegerType::get(C&: Context, NumBits: 1)); |
439 | FunctionType *FuncTy_2 = FunctionType::get( |
440 | /*Result=*/Type::getVoidTy(C&: Context), |
441 | /*Params=*/FuncTy_2_args, |
442 | /*isVarArg=*/false); |
443 | |
444 | // Function Declarations |
445 | |
446 | Function* func_test1 = Function::Create( |
447 | /*Type=*/Ty: FuncTy_0, |
448 | /*Linkage=*/GlobalValue::ExternalLinkage, |
449 | /*Name=*/N: "test1" , M: mod); |
450 | func_test1->setCallingConv(CallingConv::C); |
451 | AttributeList func_test1_PAL; |
452 | func_test1->setAttributes(func_test1_PAL); |
453 | |
454 | Function* func_test2 = Function::Create( |
455 | /*Type=*/Ty: FuncTy_0, |
456 | /*Linkage=*/GlobalValue::ExternalLinkage, |
457 | /*Name=*/N: "test2" , M: mod); |
458 | func_test2->setCallingConv(CallingConv::C); |
459 | AttributeList func_test2_PAL; |
460 | func_test2->setAttributes(func_test2_PAL); |
461 | |
462 | Function* func_test3 = Function::Create( |
463 | /*Type=*/Ty: FuncTy_0, |
464 | /*Linkage=*/GlobalValue::InternalLinkage, |
465 | /*Name=*/N: "test3" , M: mod); |
466 | func_test3->setCallingConv(CallingConv::C); |
467 | AttributeList func_test3_PAL; |
468 | func_test3->setAttributes(func_test3_PAL); |
469 | |
470 | Function* func_test4 = Function::Create( |
471 | /*Type=*/Ty: FuncTy_2, |
472 | /*Linkage=*/GlobalValue::ExternalLinkage, |
473 | /*Name=*/N: "test4" , M: mod); |
474 | func_test4->setCallingConv(CallingConv::C); |
475 | AttributeList func_test4_PAL; |
476 | func_test4->setAttributes(func_test4_PAL); |
477 | |
478 | // Global Variable Declarations |
479 | |
480 | |
481 | // Constant Definitions |
482 | |
483 | // Global Variable Definitions |
484 | |
485 | // Function Definitions |
486 | |
487 | // Function: test1 (func_test1) |
488 | { |
489 | |
490 | BasicBlock *label_entry = |
491 | BasicBlock::Create(Context, Name: "entry" , Parent: func_test1, InsertBefore: nullptr); |
492 | |
493 | // Block entry (label_entry) |
494 | CallInst* int32_3 = CallInst::Create(Func: func_test2, NameStr: "" , InsertAtEnd: label_entry); |
495 | int32_3->setCallingConv(CallingConv::C); |
496 | int32_3->setTailCall(false); |
497 | AttributeList int32_3_PAL; |
498 | int32_3->setAttributes(int32_3_PAL); |
499 | |
500 | ReturnInst::Create(C&: Context, retVal: int32_3, InsertAtEnd: label_entry); |
501 | } |
502 | |
503 | // Function: test2 (func_test2) |
504 | { |
505 | |
506 | BasicBlock *label_entry_5 = |
507 | BasicBlock::Create(Context, Name: "entry" , Parent: func_test2, InsertBefore: nullptr); |
508 | |
509 | // Block entry (label_entry_5) |
510 | CallInst* int32_6 = CallInst::Create(Func: func_test3, NameStr: "" , InsertAtEnd: label_entry_5); |
511 | int32_6->setCallingConv(CallingConv::C); |
512 | int32_6->setTailCall(false); |
513 | AttributeList int32_6_PAL; |
514 | int32_6->setAttributes(int32_6_PAL); |
515 | |
516 | ReturnInst::Create(C&: Context, retVal: int32_6, InsertAtEnd: label_entry_5); |
517 | } |
518 | |
519 | // Function: test3 (func_test3) |
520 | { |
521 | |
522 | BasicBlock *label_entry_8 = |
523 | BasicBlock::Create(Context, Name: "entry" , Parent: func_test3, InsertBefore: nullptr); |
524 | |
525 | // Block entry (label_entry_8) |
526 | CallInst* int32_9 = CallInst::Create(Func: func_test1, NameStr: "" , InsertAtEnd: label_entry_8); |
527 | int32_9->setCallingConv(CallingConv::C); |
528 | int32_9->setTailCall(false); |
529 | AttributeList int32_9_PAL; |
530 | int32_9->setAttributes(int32_9_PAL); |
531 | |
532 | ReturnInst::Create(C&: Context, retVal: int32_9, InsertAtEnd: label_entry_8); |
533 | } |
534 | |
535 | // Function: test4 (func_test4) |
536 | { |
537 | Function::arg_iterator args = func_test4->arg_begin(); |
538 | Value *int1_f = &*args++; |
539 | int1_f->setName("f" ); |
540 | |
541 | BasicBlock *label_entry_11 = |
542 | BasicBlock::Create(Context, Name: "entry" , Parent: func_test4, InsertBefore: nullptr); |
543 | BasicBlock *label_bb = |
544 | BasicBlock::Create(Context, Name: "bb" , Parent: func_test4, InsertBefore: nullptr); |
545 | BasicBlock *label_bb1 = |
546 | BasicBlock::Create(Context, Name: "bb1" , Parent: func_test4, InsertBefore: nullptr); |
547 | BasicBlock *label_return = |
548 | BasicBlock::Create(Context, Name: "return" , Parent: func_test4, InsertBefore: nullptr); |
549 | |
550 | // Block entry (label_entry_11) |
551 | auto *AI = new AllocaInst(func_test3->getType(), 0, "func3ptr" , |
552 | label_entry_11); |
553 | new StoreInst(func_test3, AI, label_entry_11); |
554 | BranchInst::Create(IfTrue: label_bb, InsertAtEnd: label_entry_11); |
555 | |
556 | // Block bb (label_bb) |
557 | BranchInst::Create(IfTrue: label_bb, IfFalse: label_bb1, Cond: int1_f, InsertAtEnd: label_bb); |
558 | |
559 | // Block bb1 (label_bb1) |
560 | BranchInst::Create(IfTrue: label_bb1, IfFalse: label_return, Cond: int1_f, InsertAtEnd: label_bb1); |
561 | |
562 | // Block return (label_return) |
563 | ReturnInst::Create(C&: Context, InsertAtEnd: label_return); |
564 | } |
565 | return mod; |
566 | } |
567 | |
568 | /// Split a simple function which contains only a call and a return into two |
569 | /// such that the first calls the second and the second whoever was called |
570 | /// initially. |
571 | Function *splitSimpleFunction(Function &F) { |
572 | LLVMContext &Context = F.getContext(); |
573 | Function *SF = Function::Create(Ty: F.getFunctionType(), Linkage: F.getLinkage(), |
574 | N: F.getName() + "b" , M: F.getParent()); |
575 | F.setName(F.getName() + "a" ); |
576 | BasicBlock *Entry = BasicBlock::Create(Context, Name: "entry" , Parent: SF, InsertBefore: nullptr); |
577 | CallInst &CI = cast<CallInst>(Val&: F.getEntryBlock().front()); |
578 | CI.clone()->insertBefore(InsertPos: ReturnInst::Create(C&: Context, InsertAtEnd: Entry)); |
579 | CI.setCalledFunction(SF); |
580 | return SF; |
581 | } |
582 | |
583 | struct CGModifierPass : public CGPass { |
584 | unsigned NumSCCs = 0; |
585 | unsigned NumFns = 0; |
586 | unsigned NumFnDecls = 0; |
587 | unsigned SetupWorked = 0; |
588 | unsigned NumExtCalledBefore = 0; |
589 | unsigned NumExtCalledAfter = 0; |
590 | |
591 | CallGraphUpdater CGU; |
592 | |
593 | bool runOnSCC(CallGraphSCC &SCMM) override { |
594 | ++NumSCCs; |
595 | for (CallGraphNode *N : SCMM) { |
596 | if (N->getFunction()){ |
597 | ++NumFns; |
598 | NumFnDecls += N->getFunction()->isDeclaration(); |
599 | } |
600 | } |
601 | CGPass::run(); |
602 | |
603 | CallGraph &CG = const_cast<CallGraph &>(SCMM.getCallGraph()); |
604 | CallGraphNode *ExtCallingNode = CG.getExternalCallingNode(); |
605 | NumExtCalledBefore = ExtCallingNode->size(); |
606 | |
607 | if (SCMM.size() <= 1) |
608 | return false; |
609 | |
610 | CallGraphNode *N = *(SCMM.begin()); |
611 | Function *F = N->getFunction(); |
612 | Module *M = F->getParent(); |
613 | Function *Test1F = M->getFunction(Name: "test1" ); |
614 | Function *Test2aF = M->getFunction(Name: "test2a" ); |
615 | Function *Test2bF = M->getFunction(Name: "test2b" ); |
616 | Function *Test3F = M->getFunction(Name: "test3" ); |
617 | |
618 | auto InSCC = [&](Function *Fn) { |
619 | return llvm::any_of(Range&: SCMM, P: [Fn](CallGraphNode *CGN) { |
620 | return CGN->getFunction() == Fn; |
621 | }); |
622 | }; |
623 | |
624 | if (!Test1F || !Test2aF || !Test2bF || !Test3F || !InSCC(Test1F) || |
625 | !InSCC(Test2aF) || !InSCC(Test2bF) || !InSCC(Test3F)) |
626 | return false; |
627 | |
628 | CallInst *CI = dyn_cast<CallInst>(Val: &Test1F->getEntryBlock().front()); |
629 | if (!CI || CI->getCalledFunction() != Test2aF) |
630 | return false; |
631 | |
632 | SetupWorked += 1; |
633 | |
634 | // Create a replica of test3 and just move the blocks there. |
635 | Function *Test3FRepl = Function::Create( |
636 | /*Type=*/Ty: Test3F->getFunctionType(), |
637 | /*Linkage=*/GlobalValue::InternalLinkage, |
638 | /*Name=*/N: "test3repl" , M: Test3F->getParent()); |
639 | while (!Test3F->empty()) { |
640 | BasicBlock &BB = Test3F->front(); |
641 | BB.removeFromParent(); |
642 | BB.insertInto(Parent: Test3FRepl); |
643 | } |
644 | |
645 | CGU.initialize(CG, SCC&: SCMM); |
646 | |
647 | // Replace test3 with the replica. This is legal as it is actually |
648 | // internal and the "capturing use" is not really capturing anything. |
649 | CGU.replaceFunctionWith(OldFn&: *Test3F, NewFn&: *Test3FRepl); |
650 | Test3F->replaceAllUsesWith(V: Test3FRepl); |
651 | |
652 | // Rewrite the call in test1 to point to the replica of 3 not test2. |
653 | CI->setCalledFunction(Test3FRepl); |
654 | |
655 | // Delete test2a and test2b and reanalyze 1 as we changed calls inside. |
656 | CGU.removeFunction(Fn&: *Test2aF); |
657 | CGU.removeFunction(Fn&: *Test2bF); |
658 | CGU.reanalyzeFunction(Fn&: *Test1F); |
659 | |
660 | return true; |
661 | } |
662 | |
663 | bool doFinalization(CallGraph &CG) override { |
664 | CGU.finalize(); |
665 | // We removed test2 and replaced the internal test3. |
666 | NumExtCalledAfter = CG.getExternalCallingNode()->size(); |
667 | return true; |
668 | } |
669 | }; |
670 | |
671 | TEST(PassManager, CallGraphUpdater0) { |
672 | // SCC#1: test1->test2a->test2b->test3->test1 |
673 | // SCC#2: test4 |
674 | // SCC#3: test3 (the empty function declaration as we replaced it with |
675 | // test3repl when we visited SCC#1) |
676 | // SCC#4: test2a->test2b (the empty function declarations as we deleted |
677 | // these functions when we visited SCC#1) |
678 | // SCC#5: indirect call node |
679 | |
680 | LLVMContext Context; |
681 | std::unique_ptr<Module> M(makeLLVMModule(Context)); |
682 | ASSERT_EQ(M->getFunctionList().size(), 4U); |
683 | Function *F = M->getFunction(Name: "test2" ); |
684 | Function *SF = splitSimpleFunction(F&: *F); |
685 | CallInst::Create(Func: F, NameStr: "" , InsertBefore: &*SF->getEntryBlock().getFirstInsertionPt()); |
686 | ASSERT_EQ(M->getFunctionList().size(), 5U); |
687 | CGModifierPass *P = new CGModifierPass(); |
688 | legacy::PassManager Passes; |
689 | Passes.add(P); |
690 | Passes.run(M&: *M); |
691 | ASSERT_EQ(P->SetupWorked, 1U); |
692 | ASSERT_EQ(P->NumSCCs, 4U); |
693 | ASSERT_EQ(P->NumFns, 6U); |
694 | ASSERT_EQ(P->NumFnDecls, 1U); |
695 | ASSERT_EQ(M->getFunctionList().size(), 3U); |
696 | ASSERT_EQ(P->NumExtCalledBefore, /* test1, 2a, 2b, 3, 4 */ 5U); |
697 | ASSERT_EQ(P->NumExtCalledAfter, /* test1, 3repl, 4 */ 3U); |
698 | } |
699 | |
700 | // Test for call graph SCC pass that replaces all callback call instructions |
701 | // with clones and updates CallGraph by calling CallGraph::replaceCallEdge() |
702 | // method. Test is expected to complete successfully after running pass on |
703 | // all SCCs in the test module. |
704 | struct CallbackCallsModifierPass : public CGPass { |
705 | bool runOnSCC(CallGraphSCC &SCC) override { |
706 | CGPass::run(); |
707 | |
708 | CallGraph &CG = const_cast<CallGraph &>(SCC.getCallGraph()); |
709 | |
710 | bool Changed = false; |
711 | for (CallGraphNode *CGN : SCC) { |
712 | Function *F = CGN->getFunction(); |
713 | if (!F || F->isDeclaration()) |
714 | continue; |
715 | |
716 | SmallVector<CallBase *, 4u> Calls; |
717 | for (Use &U : F->uses()) { |
718 | AbstractCallSite ACS(&U); |
719 | if (!ACS || !ACS.isCallbackCall() || !ACS.isCallee(U: &U)) |
720 | continue; |
721 | Calls.push_back(Elt: cast<CallBase>(Val: ACS.getInstruction())); |
722 | } |
723 | if (Calls.empty()) |
724 | continue; |
725 | |
726 | for (CallBase *OldCB : Calls) { |
727 | CallGraphNode *CallerCGN = CG[OldCB->getParent()->getParent()]; |
728 | assert(any_of(*CallerCGN, |
729 | [CGN](const CallGraphNode::CallRecord &CallRecord) { |
730 | return CallRecord.second == CGN; |
731 | }) && |
732 | "function is not a callee" ); |
733 | |
734 | CallBase *NewCB = cast<CallBase>(Val: OldCB->clone()); |
735 | |
736 | NewCB->insertBefore(InsertPos: OldCB); |
737 | NewCB->takeName(V: OldCB); |
738 | |
739 | CallerCGN->replaceCallEdge(Call&: *OldCB, NewCall&: *NewCB, NewNode: CG[F]); |
740 | |
741 | OldCB->replaceAllUsesWith(V: NewCB); |
742 | OldCB->eraseFromParent(); |
743 | } |
744 | Changed = true; |
745 | } |
746 | return Changed; |
747 | } |
748 | }; |
749 | |
750 | TEST(PassManager, CallbackCallsModifier0) { |
751 | LLVMContext Context; |
752 | |
753 | const char *IR = "define void @foo() {\n" |
754 | " call void @broker(void (i8*)* @callback0, i8* null)\n" |
755 | " call void @broker(void (i8*)* @callback1, i8* null)\n" |
756 | " ret void\n" |
757 | "}\n" |
758 | "\n" |
759 | "declare !callback !0 void @broker(void (i8*)*, i8*)\n" |
760 | "\n" |
761 | "define internal void @callback0(i8* %arg) {\n" |
762 | " ret void\n" |
763 | "}\n" |
764 | "\n" |
765 | "define internal void @callback1(i8* %arg) {\n" |
766 | " ret void\n" |
767 | "}\n" |
768 | "\n" |
769 | "!0 = !{!1}\n" |
770 | "!1 = !{i64 0, i64 1, i1 false}" ; |
771 | |
772 | SMDiagnostic Err; |
773 | std::unique_ptr<Module> M = parseAssemblyString(AsmString: IR, Err, Context); |
774 | if (!M) |
775 | Err.print(ProgName: "LegacyPassManagerTest" , S&: errs()); |
776 | |
777 | CallbackCallsModifierPass *P = new CallbackCallsModifierPass(); |
778 | legacy::PassManager Passes; |
779 | Passes.add(P); |
780 | Passes.run(M&: *M); |
781 | } |
782 | } |
783 | } |
784 | |
785 | INITIALIZE_PASS(ModuleNDM, "mndm" , "mndm" , false, false) |
786 | INITIALIZE_PASS_BEGIN(CGPass, "cgp" ,"cgp" , false, false) |
787 | INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) |
788 | INITIALIZE_PASS_END(CGPass, "cgp" ,"cgp" , false, false) |
789 | INITIALIZE_PASS(FPass, "fp" ,"fp" , false, false) |
790 | INITIALIZE_PASS_BEGIN(LPass, "lp" ,"lp" , false, false) |
791 | INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) |
792 | INITIALIZE_PASS_END(LPass, "lp" ,"lp" , false, false) |
793 | |