1//=======- CaptureTrackingTest.cpp - Unit test for the Capture Tracking ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "llvm/Analysis/CaptureTracking.h"
10#include "llvm/AsmParser/Parser.h"
11#include "llvm/IR/Dominators.h"
12#include "llvm/IR/Instructions.h"
13#include "llvm/IR/LLVMContext.h"
14#include "llvm/IR/Module.h"
15#include "llvm/Support/SourceMgr.h"
16#include "gtest/gtest.h"
17
18using namespace llvm;
19
20TEST(CaptureTracking, MaxUsesToExplore) {
21 StringRef Assembly = R"(
22 ; Function Attrs: nounwind ssp uwtable
23 declare void @doesnt_capture(i8* nocapture, i8* nocapture, i8* nocapture,
24 i8* nocapture, i8* nocapture)
25
26 ; %arg has 5 uses
27 define void @test_few_uses(i8* %arg) {
28 call void @doesnt_capture(i8* %arg, i8* %arg, i8* %arg, i8* %arg, i8* %arg)
29 ret void
30 }
31
32 ; %arg has 50 uses
33 define void @test_many_uses(i8* %arg) {
34 call void @doesnt_capture(i8* %arg, i8* %arg, i8* %arg, i8* %arg, i8* %arg)
35 call void @doesnt_capture(i8* %arg, i8* %arg, i8* %arg, i8* %arg, i8* %arg)
36 call void @doesnt_capture(i8* %arg, i8* %arg, i8* %arg, i8* %arg, i8* %arg)
37 call void @doesnt_capture(i8* %arg, i8* %arg, i8* %arg, i8* %arg, i8* %arg)
38 call void @doesnt_capture(i8* %arg, i8* %arg, i8* %arg, i8* %arg, i8* %arg)
39 call void @doesnt_capture(i8* %arg, i8* %arg, i8* %arg, i8* %arg, i8* %arg)
40 call void @doesnt_capture(i8* %arg, i8* %arg, i8* %arg, i8* %arg, i8* %arg)
41 call void @doesnt_capture(i8* %arg, i8* %arg, i8* %arg, i8* %arg, i8* %arg)
42 call void @doesnt_capture(i8* %arg, i8* %arg, i8* %arg, i8* %arg, i8* %arg)
43 call void @doesnt_capture(i8* %arg, i8* %arg, i8* %arg, i8* %arg, i8* %arg)
44 ret void
45 }
46 )";
47
48 LLVMContext Context;
49 SMDiagnostic Error;
50 auto M = parseAssemblyString(AsmString: Assembly, Err&: Error, Context);
51 ASSERT_TRUE(M) << "Bad assembly?";
52
53 auto Test = [&M](const char *FName, unsigned FalseMaxUsesLimit,
54 unsigned TrueMaxUsesLimit) {
55 Function *F = M->getFunction(Name: FName);
56 ASSERT_NE(F, nullptr);
57 Value *Arg = &*F->arg_begin();
58 ASSERT_NE(Arg, nullptr);
59 ASSERT_FALSE(PointerMayBeCaptured(Arg, true, true, FalseMaxUsesLimit));
60 ASSERT_TRUE(PointerMayBeCaptured(Arg, true, true, TrueMaxUsesLimit));
61
62 BasicBlock *EntryBB = &F->getEntryBlock();
63 DominatorTree DT(*F);
64
65 Instruction *Ret = EntryBB->getTerminator();
66 ASSERT_TRUE(isa<ReturnInst>(Ret));
67 ASSERT_FALSE(PointerMayBeCapturedBefore(Arg, true, true, Ret, &DT, false,
68 FalseMaxUsesLimit));
69 ASSERT_TRUE(PointerMayBeCapturedBefore(Arg, true, true, Ret, &DT, false,
70 TrueMaxUsesLimit));
71 };
72
73 Test("test_few_uses", 6, 4);
74 Test("test_many_uses", 50, 30);
75}
76
77struct CollectingCaptureTracker : public CaptureTracker {
78 SmallVector<const Use *, 4> Captures;
79 void tooManyUses() override { }
80 bool captured(const Use *U) override {
81 Captures.push_back(Elt: U);
82 return false;
83 }
84};
85
86TEST(CaptureTracking, MultipleUsesInSameInstruction) {
87 StringRef Assembly = R"(
88 declare void @call(i8*, i8*, i8*)
89
90 define void @test(i8* %arg, i8** %ptr) {
91 call void @call(i8* %arg, i8* nocapture %arg, i8* %arg) [ "bundle"(i8* %arg) ]
92 cmpxchg i8** %ptr, i8* %arg, i8* %arg acq_rel monotonic
93 icmp eq i8* %arg, %arg
94 ret void
95 }
96 )";
97
98 LLVMContext Context;
99 SMDiagnostic Error;
100 auto M = parseAssemblyString(AsmString: Assembly, Err&: Error, Context);
101 ASSERT_TRUE(M) << "Bad assembly?";
102
103 Function *F = M->getFunction(Name: "test");
104 Value *Arg = &*F->arg_begin();
105 BasicBlock *BB = &F->getEntryBlock();
106 Instruction *Call = &*BB->begin();
107 Instruction *CmpXChg = Call->getNextNode();
108 Instruction *ICmp = CmpXChg->getNextNode();
109
110 CollectingCaptureTracker CT;
111 PointerMayBeCaptured(V: Arg, Tracker: &CT);
112 EXPECT_EQ(7u, CT.Captures.size());
113 // Call arg 1
114 EXPECT_EQ(Call, CT.Captures[0]->getUser());
115 EXPECT_EQ(0u, CT.Captures[0]->getOperandNo());
116 // Call arg 3
117 EXPECT_EQ(Call, CT.Captures[1]->getUser());
118 EXPECT_EQ(2u, CT.Captures[1]->getOperandNo());
119 // Operand bundle arg
120 EXPECT_EQ(Call, CT.Captures[2]->getUser());
121 EXPECT_EQ(3u, CT.Captures[2]->getOperandNo());
122 // Cmpxchg compare operand
123 EXPECT_EQ(CmpXChg, CT.Captures[3]->getUser());
124 EXPECT_EQ(1u, CT.Captures[3]->getOperandNo());
125 // Cmpxchg new value operand
126 EXPECT_EQ(CmpXChg, CT.Captures[4]->getUser());
127 EXPECT_EQ(2u, CT.Captures[4]->getOperandNo());
128 // ICmp first operand
129 EXPECT_EQ(ICmp, CT.Captures[5]->getUser());
130 EXPECT_EQ(0u, CT.Captures[5]->getOperandNo());
131 // ICmp second operand
132 EXPECT_EQ(ICmp, CT.Captures[6]->getUser());
133 EXPECT_EQ(1u, CT.Captures[6]->getOperandNo());
134}
135

// Source: llvm/unittests/Analysis/CaptureTrackingTest.cpp