1 | //===--- TestVisitor.h ------------------------------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | /// |
9 | /// \file |
10 | /// \brief Defines utility templates for RecursiveASTVisitor related tests. |
11 | /// |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #ifndef LLVM_CLANG_UNITTESTS_TOOLING_TESTVISITOR_H |
15 | #define LLVM_CLANG_UNITTESTS_TOOLING_TESTVISITOR_H |
16 | |
17 | #include "clang/AST/ASTConsumer.h" |
18 | #include "clang/AST/ASTContext.h" |
19 | #include "clang/AST/RecursiveASTVisitor.h" |
20 | #include "clang/Frontend/CompilerInstance.h" |
21 | #include "clang/Frontend/FrontendAction.h" |
22 | #include "clang/Tooling/Tooling.h" |
23 | #include "gtest/gtest.h" |
24 | #include <vector> |
25 | |
26 | namespace clang { |
27 | |
/// \brief Base class for simple RecursiveASTVisitor based tests.
///
/// This is a drop-in replacement for RecursiveASTVisitor itself, with the
/// additional capability of running it over a snippet of code.
///
/// Visits template instantiations and implicit code by default.
34 | template <typename T> |
35 | class TestVisitor : public RecursiveASTVisitor<T> { |
36 | public: |
37 | TestVisitor() { } |
38 | |
39 | virtual ~TestVisitor() { } |
40 | |
41 | enum Language { |
42 | Lang_C, |
43 | Lang_CXX98, |
44 | Lang_CXX11, |
45 | Lang_CXX14, |
46 | Lang_CXX17, |
47 | Lang_CXX2a, |
48 | Lang_OBJC, |
49 | Lang_OBJCXX11, |
50 | Lang_CXX = Lang_CXX98 |
51 | }; |
52 | |
53 | /// \brief Runs the current AST visitor over the given code. |
54 | bool runOver(StringRef Code, Language L = Lang_CXX) { |
55 | std::vector<std::string> Args; |
56 | switch (L) { |
57 | case Lang_C: |
58 | Args.push_back(x: "-x" ); |
59 | Args.push_back(x: "c" ); |
60 | break; |
61 | case Lang_CXX98: Args.push_back(x: "-std=c++98" ); break; |
62 | case Lang_CXX11: Args.push_back(x: "-std=c++11" ); break; |
63 | case Lang_CXX14: Args.push_back(x: "-std=c++14" ); break; |
64 | case Lang_CXX17: Args.push_back(x: "-std=c++17" ); break; |
65 | case Lang_CXX2a: Args.push_back(x: "-std=c++2a" ); break; |
66 | case Lang_OBJC: |
67 | Args.push_back(x: "-ObjC" ); |
68 | Args.push_back(x: "-fobjc-runtime=macosx-10.12.0" ); |
69 | break; |
70 | case Lang_OBJCXX11: |
71 | Args.push_back(x: "-ObjC++" ); |
72 | Args.push_back(x: "-std=c++11" ); |
73 | Args.push_back(x: "-fblocks" ); |
74 | break; |
75 | } |
76 | return tooling::runToolOnCodeWithArgs(CreateTestAction(), Code, Args); |
77 | } |
78 | |
79 | bool shouldVisitTemplateInstantiations() const { |
80 | return true; |
81 | } |
82 | |
83 | bool shouldVisitImplicitCode() const { |
84 | return true; |
85 | } |
86 | |
87 | protected: |
88 | virtual std::unique_ptr<ASTFrontendAction> CreateTestAction() { |
89 | return std::make_unique<TestAction>(this); |
90 | } |
91 | |
92 | class FindConsumer : public ASTConsumer { |
93 | public: |
94 | FindConsumer(TestVisitor *Visitor) : Visitor(Visitor) {} |
95 | |
96 | void HandleTranslationUnit(clang::ASTContext &Context) override { |
97 | Visitor->Context = &Context; |
98 | Visitor->TraverseDecl(Context.getTranslationUnitDecl()); |
99 | } |
100 | |
101 | private: |
102 | TestVisitor *Visitor; |
103 | }; |
104 | |
105 | class TestAction : public ASTFrontendAction { |
106 | public: |
107 | TestAction(TestVisitor *Visitor) : Visitor(Visitor) {} |
108 | |
109 | std::unique_ptr<clang::ASTConsumer> |
110 | CreateASTConsumer(CompilerInstance &, llvm::StringRef dummy) override { |
111 | /// TestConsumer will be deleted by the framework calling us. |
112 | return std::make_unique<FindConsumer>(Visitor); |
113 | } |
114 | |
115 | protected: |
116 | TestVisitor *Visitor; |
117 | }; |
118 | |
119 | ASTContext *Context; |
120 | }; |
121 | |
/// \brief A RecursiveASTVisitor to check that certain matches are (or are
/// not) observed during visitation.
///
/// This is a RecursiveASTVisitor for testing the RecursiveASTVisitor itself,
/// and allows simple creation of test visitors running matches on only a small
/// subset of the Visit* methods.
128 | template <typename T, template <typename> class Visitor = TestVisitor> |
129 | class ExpectedLocationVisitor : public Visitor<T> { |
130 | public: |
131 | /// \brief Expect 'Match' *not* to occur at the given 'Line' and 'Column'. |
132 | /// |
133 | /// Any number of matches can be disallowed. |
134 | void DisallowMatch(Twine Match, unsigned Line, unsigned Column) { |
135 | DisallowedMatches.push_back(MatchCandidate(Match, Line, Column)); |
136 | } |
137 | |
138 | /// \brief Expect 'Match' to occur at the given 'Line' and 'Column'. |
139 | /// |
140 | /// Any number of expected matches can be set by calling this repeatedly. |
141 | /// Each is expected to be matched 'Times' number of times. (This is useful in |
142 | /// cases in which different AST nodes can match at the same source code |
143 | /// location.) |
144 | void ExpectMatch(Twine Match, unsigned Line, unsigned Column, |
145 | unsigned Times = 1) { |
146 | ExpectedMatches.push_back(ExpectedMatch(Match, Line, Column, Times)); |
147 | } |
148 | |
149 | /// \brief Checks that all expected matches have been found. |
150 | ~ExpectedLocationVisitor() override { |
151 | for (typename std::vector<ExpectedMatch>::const_iterator |
152 | It = ExpectedMatches.begin(), End = ExpectedMatches.end(); |
153 | It != End; ++It) { |
154 | It->ExpectFound(); |
155 | } |
156 | } |
157 | |
158 | protected: |
159 | /// \brief Checks an actual match against expected and disallowed matches. |
160 | /// |
161 | /// Implementations are required to call this with appropriate values |
162 | /// for 'Name' during visitation. |
163 | void Match(StringRef Name, SourceLocation Location) { |
164 | const FullSourceLoc FullLocation = this->Context->getFullLoc(Location); |
165 | |
166 | for (typename std::vector<MatchCandidate>::const_iterator |
167 | It = DisallowedMatches.begin(), End = DisallowedMatches.end(); |
168 | It != End; ++It) { |
169 | EXPECT_FALSE(It->Matches(Name, FullLocation)) |
170 | << "Matched disallowed " << *It; |
171 | } |
172 | |
173 | for (typename std::vector<ExpectedMatch>::iterator |
174 | It = ExpectedMatches.begin(), End = ExpectedMatches.end(); |
175 | It != End; ++It) { |
176 | It->UpdateFor(Name, FullLocation, this->Context->getSourceManager()); |
177 | } |
178 | } |
179 | |
180 | private: |
181 | struct MatchCandidate { |
182 | std::string ExpectedName; |
183 | unsigned LineNumber; |
184 | unsigned ColumnNumber; |
185 | |
186 | MatchCandidate(Twine Name, unsigned LineNumber, unsigned ColumnNumber) |
187 | : ExpectedName(Name.str()), LineNumber(LineNumber), |
188 | ColumnNumber(ColumnNumber) { |
189 | } |
190 | |
191 | bool Matches(StringRef Name, FullSourceLoc const &Location) const { |
192 | return MatchesName(Name) && MatchesLocation(Location); |
193 | } |
194 | |
195 | bool PartiallyMatches(StringRef Name, FullSourceLoc const &Location) const { |
196 | return MatchesName(Name) || MatchesLocation(Location); |
197 | } |
198 | |
199 | bool MatchesName(StringRef Name) const { |
200 | return Name == ExpectedName; |
201 | } |
202 | |
203 | bool MatchesLocation(FullSourceLoc const &Location) const { |
204 | return Location.isValid() && |
205 | Location.getSpellingLineNumber() == LineNumber && |
206 | Location.getSpellingColumnNumber() == ColumnNumber; |
207 | } |
208 | |
209 | friend std::ostream &operator<<(std::ostream &Stream, |
210 | MatchCandidate const &Match) { |
211 | return Stream << Match.ExpectedName |
212 | << " at " << Match.LineNumber << ":" << Match.ColumnNumber; |
213 | } |
214 | }; |
215 | |
216 | struct ExpectedMatch { |
217 | ExpectedMatch(Twine Name, unsigned LineNumber, unsigned ColumnNumber, |
218 | unsigned Times) |
219 | : Candidate(Name, LineNumber, ColumnNumber), TimesExpected(Times), |
220 | TimesSeen(0) {} |
221 | |
222 | void UpdateFor(StringRef Name, FullSourceLoc Location, SourceManager &SM) { |
223 | if (Candidate.Matches(Name, Location)) { |
224 | EXPECT_LT(TimesSeen, TimesExpected); |
225 | ++TimesSeen; |
226 | } else if (TimesSeen < TimesExpected && |
227 | Candidate.PartiallyMatches(Name, Location)) { |
228 | llvm::raw_string_ostream Stream(PartialMatches); |
229 | Stream << ", partial match: \"" << Name << "\" at " ; |
230 | Location.print(OS&: Stream, SM); |
231 | } |
232 | } |
233 | |
234 | void ExpectFound() const { |
235 | EXPECT_EQ(TimesExpected, TimesSeen) |
236 | << "Expected \"" << Candidate.ExpectedName |
237 | << "\" at " << Candidate.LineNumber |
238 | << ":" << Candidate.ColumnNumber << PartialMatches; |
239 | } |
240 | |
241 | MatchCandidate Candidate; |
242 | std::string PartialMatches; |
243 | unsigned TimesExpected; |
244 | unsigned TimesSeen; |
245 | }; |
246 | |
247 | std::vector<MatchCandidate> DisallowedMatches; |
248 | std::vector<ExpectedMatch> ExpectedMatches; |
249 | }; |
250 | } |
251 | |
252 | #endif |
253 | |