1//===- unittest/AST/RecursiveASTVisitorTest.cpp ---------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "clang/AST/RecursiveASTVisitor.h"
10#include "clang/AST/ASTConsumer.h"
11#include "clang/AST/ASTContext.h"
12#include "clang/AST/Attr.h"
13#include "clang/AST/Decl.h"
14#include "clang/AST/TypeLoc.h"
15#include "clang/Frontend/FrontendAction.h"
16#include "clang/Tooling/Tooling.h"
17#include "llvm/ADT/FunctionExtras.h"
18#include "llvm/ADT/STLExtras.h"
19#include "gmock/gmock.h"
20#include "gtest/gtest.h"
21#include <cassert>
22
23using namespace clang;
24using ::testing::ElementsAre;
25
26namespace {
27class ProcessASTAction : public clang::ASTFrontendAction {
28public:
29 ProcessASTAction(llvm::unique_function<void(clang::ASTContext &)> Process)
30 : Process(std::move(Process)) {
31 assert(this->Process);
32 }
33
34 std::unique_ptr<ASTConsumer> CreateASTConsumer(CompilerInstance &CI,
35 StringRef InFile) {
36 class Consumer : public ASTConsumer {
37 public:
38 Consumer(llvm::function_ref<void(ASTContext &CTx)> Process)
39 : Process(Process) {}
40
41 void HandleTranslationUnit(ASTContext &Ctx) override { Process(Ctx); }
42
43 private:
44 llvm::function_ref<void(ASTContext &CTx)> Process;
45 };
46
47 return std::make_unique<Consumer>(args&: Process);
48 }
49
50private:
51 llvm::unique_function<void(clang::ASTContext &)> Process;
52};
53
54enum class VisitEvent {
55 StartTraverseFunction,
56 EndTraverseFunction,
57 StartTraverseAttr,
58 EndTraverseAttr,
59 StartTraverseEnum,
60 EndTraverseEnum,
61 StartTraverseTypedefType,
62 EndTraverseTypedefType,
63 StartTraverseObjCInterface,
64 EndTraverseObjCInterface,
65 StartTraverseObjCProtocol,
66 EndTraverseObjCProtocol,
67 StartTraverseObjCProtocolLoc,
68 EndTraverseObjCProtocolLoc,
69};
70
71class CollectInterestingEvents
72 : public RecursiveASTVisitor<CollectInterestingEvents> {
73public:
74 bool TraverseFunctionDecl(FunctionDecl *D) {
75 Events.push_back(x: VisitEvent::StartTraverseFunction);
76 bool Ret = RecursiveASTVisitor::TraverseFunctionDecl(D);
77 Events.push_back(x: VisitEvent::EndTraverseFunction);
78
79 return Ret;
80 }
81
82 bool TraverseAttr(Attr *A) {
83 Events.push_back(x: VisitEvent::StartTraverseAttr);
84 bool Ret = RecursiveASTVisitor::TraverseAttr(At: A);
85 Events.push_back(x: VisitEvent::EndTraverseAttr);
86
87 return Ret;
88 }
89
90 bool TraverseEnumDecl(EnumDecl *D) {
91 Events.push_back(x: VisitEvent::StartTraverseEnum);
92 bool Ret = RecursiveASTVisitor::TraverseEnumDecl(D);
93 Events.push_back(x: VisitEvent::EndTraverseEnum);
94
95 return Ret;
96 }
97
98 bool TraverseTypedefTypeLoc(TypedefTypeLoc TL) {
99 Events.push_back(x: VisitEvent::StartTraverseTypedefType);
100 bool Ret = RecursiveASTVisitor::TraverseTypedefTypeLoc(TL);
101 Events.push_back(x: VisitEvent::EndTraverseTypedefType);
102
103 return Ret;
104 }
105
106 bool TraverseObjCInterfaceDecl(ObjCInterfaceDecl *ID) {
107 Events.push_back(x: VisitEvent::StartTraverseObjCInterface);
108 bool Ret = RecursiveASTVisitor::TraverseObjCInterfaceDecl(ID);
109 Events.push_back(x: VisitEvent::EndTraverseObjCInterface);
110
111 return Ret;
112 }
113
114 bool TraverseObjCProtocolDecl(ObjCProtocolDecl *PD) {
115 Events.push_back(x: VisitEvent::StartTraverseObjCProtocol);
116 bool Ret = RecursiveASTVisitor::TraverseObjCProtocolDecl(PD);
117 Events.push_back(x: VisitEvent::EndTraverseObjCProtocol);
118
119 return Ret;
120 }
121
122 bool TraverseObjCProtocolLoc(ObjCProtocolLoc ProtocolLoc) {
123 Events.push_back(x: VisitEvent::StartTraverseObjCProtocolLoc);
124 bool Ret = RecursiveASTVisitor::TraverseObjCProtocolLoc(ProtocolLoc);
125 Events.push_back(x: VisitEvent::EndTraverseObjCProtocolLoc);
126
127 return Ret;
128 }
129
130 std::vector<VisitEvent> takeEvents() && { return std::move(Events); }
131
132private:
133 std::vector<VisitEvent> Events;
134};
135
136std::vector<VisitEvent> collectEvents(llvm::StringRef Code,
137 const Twine &FileName = "input.cc") {
138 CollectInterestingEvents Visitor;
139 clang::tooling::runToolOnCode(
140 ToolAction: std::make_unique<ProcessASTAction>(
141 args: [&](clang::ASTContext &Ctx) { Visitor.TraverseAST(AST&: Ctx); }),
142 Code, FileName);
143 return std::move(Visitor).takeEvents();
144}
145} // namespace
146
147TEST(RecursiveASTVisitorTest, AttributesInsideDecls) {
148 /// Check attributes are traversed inside TraverseFunctionDecl.
149 llvm::StringRef Code = R"cpp(
150__attribute__((annotate("something"))) int foo() { return 10; }
151 )cpp";
152
153 EXPECT_THAT(collectEvents(Code),
154 ElementsAre(VisitEvent::StartTraverseFunction,
155 VisitEvent::StartTraverseAttr,
156 VisitEvent::EndTraverseAttr,
157 VisitEvent::EndTraverseFunction));
158}
159
160TEST(RecursiveASTVisitorTest, EnumDeclWithBase) {
161 // Check enum and its integer base is visited.
162 llvm::StringRef Code = R"cpp(
163 typedef int Foo;
164 enum Bar : Foo;
165 )cpp";
166
167 EXPECT_THAT(collectEvents(Code),
168 ElementsAre(VisitEvent::StartTraverseEnum,
169 VisitEvent::StartTraverseTypedefType,
170 VisitEvent::EndTraverseTypedefType,
171 VisitEvent::EndTraverseEnum));
172}
173
174TEST(RecursiveASTVisitorTest, InterfaceDeclWithProtocols) {
175 // Check interface and its protocols are visited.
176 llvm::StringRef Code = R"cpp(
177 @protocol Foo
178 @end
179 @protocol Bar
180 @end
181
182 @interface SomeObject <Foo, Bar>
183 @end
184 )cpp";
185
186 EXPECT_THAT(collectEvents(Code, "input.m"),
187 ElementsAre(VisitEvent::StartTraverseObjCProtocol,
188 VisitEvent::EndTraverseObjCProtocol,
189 VisitEvent::StartTraverseObjCProtocol,
190 VisitEvent::EndTraverseObjCProtocol,
191 VisitEvent::StartTraverseObjCInterface,
192 VisitEvent::StartTraverseObjCProtocolLoc,
193 VisitEvent::EndTraverseObjCProtocolLoc,
194 VisitEvent::StartTraverseObjCProtocolLoc,
195 VisitEvent::EndTraverseObjCProtocolLoc,
196 VisitEvent::EndTraverseObjCInterface));
197}
198

// Source code of clang/unittests/AST/RecursiveASTVisitorTest.cpp