1 | //===- unittest/AST/RecursiveASTVisitorTest.cpp ---------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "clang/AST/RecursiveASTVisitor.h" |
10 | #include "clang/AST/ASTConsumer.h" |
11 | #include "clang/AST/ASTContext.h" |
12 | #include "clang/AST/Attr.h" |
13 | #include "clang/AST/Decl.h" |
14 | #include "clang/AST/TypeLoc.h" |
15 | #include "clang/Frontend/FrontendAction.h" |
16 | #include "clang/Tooling/Tooling.h" |
17 | #include "llvm/ADT/FunctionExtras.h" |
18 | #include "llvm/ADT/STLExtras.h" |
19 | #include "gmock/gmock.h" |
20 | #include "gtest/gtest.h" |
21 | #include <cassert> |
22 | |
23 | using namespace clang; |
24 | using ::testing::ElementsAre; |
25 | |
26 | namespace { |
27 | class ProcessASTAction : public clang::ASTFrontendAction { |
28 | public: |
29 | ProcessASTAction(llvm::unique_function<void(clang::ASTContext &)> Process) |
30 | : Process(std::move(Process)) { |
31 | assert(this->Process); |
32 | } |
33 | |
34 | std::unique_ptr<ASTConsumer> CreateASTConsumer(CompilerInstance &CI, |
35 | StringRef InFile) { |
36 | class Consumer : public ASTConsumer { |
37 | public: |
38 | Consumer(llvm::function_ref<void(ASTContext &CTx)> Process) |
39 | : Process(Process) {} |
40 | |
41 | void HandleTranslationUnit(ASTContext &Ctx) override { Process(Ctx); } |
42 | |
43 | private: |
44 | llvm::function_ref<void(ASTContext &CTx)> Process; |
45 | }; |
46 | |
47 | return std::make_unique<Consumer>(args&: Process); |
48 | } |
49 | |
50 | private: |
51 | llvm::unique_function<void(clang::ASTContext &)> Process; |
52 | }; |
53 | |
/// Milestones recorded by the traversal-collecting visitor.
///
/// Every traversal of interest contributes a Start/End pair, so a recorded
/// sequence shows both the order of traversals and how they nest.
enum class VisitEvent : int {
  StartTraverseFunction, EndTraverseFunction,           // FunctionDecl
  StartTraverseAttr, EndTraverseAttr,                   // Attr
  StartTraverseEnum, EndTraverseEnum,                   // EnumDecl
  StartTraverseTypedefType, EndTraverseTypedefType,     // TypedefTypeLoc
  StartTraverseObjCInterface, EndTraverseObjCInterface, // ObjCInterfaceDecl
  StartTraverseObjCProtocol, EndTraverseObjCProtocol,   // ObjCProtocolDecl
  StartTraverseObjCProtocolLoc, EndTraverseObjCProtocolLoc, // ObjCProtocolLoc
};
70 | |
71 | class CollectInterestingEvents |
72 | : public RecursiveASTVisitor<CollectInterestingEvents> { |
73 | public: |
74 | bool TraverseFunctionDecl(FunctionDecl *D) { |
75 | Events.push_back(x: VisitEvent::StartTraverseFunction); |
76 | bool Ret = RecursiveASTVisitor::TraverseFunctionDecl(D); |
77 | Events.push_back(x: VisitEvent::EndTraverseFunction); |
78 | |
79 | return Ret; |
80 | } |
81 | |
82 | bool TraverseAttr(Attr *A) { |
83 | Events.push_back(x: VisitEvent::StartTraverseAttr); |
84 | bool Ret = RecursiveASTVisitor::TraverseAttr(At: A); |
85 | Events.push_back(x: VisitEvent::EndTraverseAttr); |
86 | |
87 | return Ret; |
88 | } |
89 | |
90 | bool TraverseEnumDecl(EnumDecl *D) { |
91 | Events.push_back(x: VisitEvent::StartTraverseEnum); |
92 | bool Ret = RecursiveASTVisitor::TraverseEnumDecl(D); |
93 | Events.push_back(x: VisitEvent::EndTraverseEnum); |
94 | |
95 | return Ret; |
96 | } |
97 | |
98 | bool TraverseTypedefTypeLoc(TypedefTypeLoc TL) { |
99 | Events.push_back(x: VisitEvent::StartTraverseTypedefType); |
100 | bool Ret = RecursiveASTVisitor::TraverseTypedefTypeLoc(TL); |
101 | Events.push_back(x: VisitEvent::EndTraverseTypedefType); |
102 | |
103 | return Ret; |
104 | } |
105 | |
106 | bool TraverseObjCInterfaceDecl(ObjCInterfaceDecl *ID) { |
107 | Events.push_back(x: VisitEvent::StartTraverseObjCInterface); |
108 | bool Ret = RecursiveASTVisitor::TraverseObjCInterfaceDecl(ID); |
109 | Events.push_back(x: VisitEvent::EndTraverseObjCInterface); |
110 | |
111 | return Ret; |
112 | } |
113 | |
114 | bool TraverseObjCProtocolDecl(ObjCProtocolDecl *PD) { |
115 | Events.push_back(x: VisitEvent::StartTraverseObjCProtocol); |
116 | bool Ret = RecursiveASTVisitor::TraverseObjCProtocolDecl(PD); |
117 | Events.push_back(x: VisitEvent::EndTraverseObjCProtocol); |
118 | |
119 | return Ret; |
120 | } |
121 | |
122 | bool TraverseObjCProtocolLoc(ObjCProtocolLoc ProtocolLoc) { |
123 | Events.push_back(x: VisitEvent::StartTraverseObjCProtocolLoc); |
124 | bool Ret = RecursiveASTVisitor::TraverseObjCProtocolLoc(ProtocolLoc); |
125 | Events.push_back(x: VisitEvent::EndTraverseObjCProtocolLoc); |
126 | |
127 | return Ret; |
128 | } |
129 | |
130 | std::vector<VisitEvent> takeEvents() && { return std::move(Events); } |
131 | |
132 | private: |
133 | std::vector<VisitEvent> Events; |
134 | }; |
135 | |
136 | std::vector<VisitEvent> collectEvents(llvm::StringRef Code, |
137 | const Twine &FileName = "input.cc" ) { |
138 | CollectInterestingEvents Visitor; |
139 | clang::tooling::runToolOnCode( |
140 | ToolAction: std::make_unique<ProcessASTAction>( |
141 | args: [&](clang::ASTContext &Ctx) { Visitor.TraverseAST(AST&: Ctx); }), |
142 | Code, FileName); |
143 | return std::move(Visitor).takeEvents(); |
144 | } |
145 | } // namespace |
146 | |
147 | TEST(RecursiveASTVisitorTest, AttributesInsideDecls) { |
148 | /// Check attributes are traversed inside TraverseFunctionDecl. |
149 | llvm::StringRef Code = R"cpp( |
150 | __attribute__((annotate("something"))) int foo() { return 10; } |
151 | )cpp" ; |
152 | |
153 | EXPECT_THAT(collectEvents(Code), |
154 | ElementsAre(VisitEvent::StartTraverseFunction, |
155 | VisitEvent::StartTraverseAttr, |
156 | VisitEvent::EndTraverseAttr, |
157 | VisitEvent::EndTraverseFunction)); |
158 | } |
159 | |
160 | TEST(RecursiveASTVisitorTest, EnumDeclWithBase) { |
161 | // Check enum and its integer base is visited. |
162 | llvm::StringRef Code = R"cpp( |
163 | typedef int Foo; |
164 | enum Bar : Foo; |
165 | )cpp" ; |
166 | |
167 | EXPECT_THAT(collectEvents(Code), |
168 | ElementsAre(VisitEvent::StartTraverseEnum, |
169 | VisitEvent::StartTraverseTypedefType, |
170 | VisitEvent::EndTraverseTypedefType, |
171 | VisitEvent::EndTraverseEnum)); |
172 | } |
173 | |
174 | TEST(RecursiveASTVisitorTest, InterfaceDeclWithProtocols) { |
175 | // Check interface and its protocols are visited. |
176 | llvm::StringRef Code = R"cpp( |
177 | @protocol Foo |
178 | @end |
179 | @protocol Bar |
180 | @end |
181 | |
182 | @interface SomeObject <Foo, Bar> |
183 | @end |
184 | )cpp" ; |
185 | |
186 | EXPECT_THAT(collectEvents(Code, "input.m" ), |
187 | ElementsAre(VisitEvent::StartTraverseObjCProtocol, |
188 | VisitEvent::EndTraverseObjCProtocol, |
189 | VisitEvent::StartTraverseObjCProtocol, |
190 | VisitEvent::EndTraverseObjCProtocol, |
191 | VisitEvent::StartTraverseObjCInterface, |
192 | VisitEvent::StartTraverseObjCProtocolLoc, |
193 | VisitEvent::EndTraverseObjCProtocolLoc, |
194 | VisitEvent::StartTraverseObjCProtocolLoc, |
195 | VisitEvent::EndTraverseObjCProtocolLoc, |
196 | VisitEvent::EndTraverseObjCInterface)); |
197 | } |
198 | |