//===--- TestVisitor.h ------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// \brief Defines utility templates for RecursiveASTVisitor related tests.
///
//===----------------------------------------------------------------------===//
13
14#ifndef LLVM_CLANG_UNITTESTS_TOOLING_TESTVISITOR_H
15#define LLVM_CLANG_UNITTESTS_TOOLING_TESTVISITOR_H
16
17#include "clang/AST/ASTConsumer.h"
18#include "clang/AST/ASTContext.h"
19#include "clang/AST/RecursiveASTVisitor.h"
20#include "clang/Frontend/CompilerInstance.h"
21#include "clang/Frontend/FrontendAction.h"
22#include "clang/Tooling/Tooling.h"
23#include "gtest/gtest.h"
24#include <vector>
25
26namespace clang {
27
28/// \brief Base class for simple RecursiveASTVisitor based tests.
29///
30/// This is a drop-in replacement for RecursiveASTVisitor itself, with the
31/// additional capability of running it over a snippet of code.
32///
33/// Visits template instantiations and implicit code by default.
34template <typename T>
35class TestVisitor : public RecursiveASTVisitor<T> {
36public:
37 TestVisitor() { }
38
39 virtual ~TestVisitor() { }
40
41 enum Language {
42 Lang_C,
43 Lang_CXX98,
44 Lang_CXX11,
45 Lang_CXX14,
46 Lang_CXX17,
47 Lang_CXX2a,
48 Lang_OBJC,
49 Lang_OBJCXX11,
50 Lang_CXX = Lang_CXX98
51 };
52
53 /// \brief Runs the current AST visitor over the given code.
54 bool runOver(StringRef Code, Language L = Lang_CXX) {
55 std::vector<std::string> Args;
56 switch (L) {
57 case Lang_C:
58 Args.push_back(x: "-x");
59 Args.push_back(x: "c");
60 break;
61 case Lang_CXX98: Args.push_back(x: "-std=c++98"); break;
62 case Lang_CXX11: Args.push_back(x: "-std=c++11"); break;
63 case Lang_CXX14: Args.push_back(x: "-std=c++14"); break;
64 case Lang_CXX17: Args.push_back(x: "-std=c++17"); break;
65 case Lang_CXX2a: Args.push_back(x: "-std=c++2a"); break;
66 case Lang_OBJC:
67 Args.push_back(x: "-ObjC");
68 Args.push_back(x: "-fobjc-runtime=macosx-10.12.0");
69 break;
70 case Lang_OBJCXX11:
71 Args.push_back(x: "-ObjC++");
72 Args.push_back(x: "-std=c++11");
73 Args.push_back(x: "-fblocks");
74 break;
75 }
76 return tooling::runToolOnCodeWithArgs(CreateTestAction(), Code, Args);
77 }
78
79 bool shouldVisitTemplateInstantiations() const {
80 return true;
81 }
82
83 bool shouldVisitImplicitCode() const {
84 return true;
85 }
86
87protected:
88 virtual std::unique_ptr<ASTFrontendAction> CreateTestAction() {
89 return std::make_unique<TestAction>(this);
90 }
91
92 class FindConsumer : public ASTConsumer {
93 public:
94 FindConsumer(TestVisitor *Visitor) : Visitor(Visitor) {}
95
96 void HandleTranslationUnit(clang::ASTContext &Context) override {
97 Visitor->Context = &Context;
98 Visitor->TraverseDecl(Context.getTranslationUnitDecl());
99 }
100
101 private:
102 TestVisitor *Visitor;
103 };
104
105 class TestAction : public ASTFrontendAction {
106 public:
107 TestAction(TestVisitor *Visitor) : Visitor(Visitor) {}
108
109 std::unique_ptr<clang::ASTConsumer>
110 CreateASTConsumer(CompilerInstance &, llvm::StringRef dummy) override {
111 /// TestConsumer will be deleted by the framework calling us.
112 return std::make_unique<FindConsumer>(Visitor);
113 }
114
115 protected:
116 TestVisitor *Visitor;
117 };
118
119 ASTContext *Context;
120};
121
122/// \brief A RecursiveASTVisitor to check that certain matches are (or are
123/// not) observed during visitation.
124///
125/// This is a RecursiveASTVisitor for testing the RecursiveASTVisitor itself,
126/// and allows simple creation of test visitors running matches on only a small
127/// subset of the Visit* methods.
128template <typename T, template <typename> class Visitor = TestVisitor>
129class ExpectedLocationVisitor : public Visitor<T> {
130public:
131 /// \brief Expect 'Match' *not* to occur at the given 'Line' and 'Column'.
132 ///
133 /// Any number of matches can be disallowed.
134 void DisallowMatch(Twine Match, unsigned Line, unsigned Column) {
135 DisallowedMatches.push_back(MatchCandidate(Match, Line, Column));
136 }
137
138 /// \brief Expect 'Match' to occur at the given 'Line' and 'Column'.
139 ///
140 /// Any number of expected matches can be set by calling this repeatedly.
141 /// Each is expected to be matched 'Times' number of times. (This is useful in
142 /// cases in which different AST nodes can match at the same source code
143 /// location.)
144 void ExpectMatch(Twine Match, unsigned Line, unsigned Column,
145 unsigned Times = 1) {
146 ExpectedMatches.push_back(ExpectedMatch(Match, Line, Column, Times));
147 }
148
149 /// \brief Checks that all expected matches have been found.
150 ~ExpectedLocationVisitor() override {
151 for (typename std::vector<ExpectedMatch>::const_iterator
152 It = ExpectedMatches.begin(), End = ExpectedMatches.end();
153 It != End; ++It) {
154 It->ExpectFound();
155 }
156 }
157
158protected:
159 /// \brief Checks an actual match against expected and disallowed matches.
160 ///
161 /// Implementations are required to call this with appropriate values
162 /// for 'Name' during visitation.
163 void Match(StringRef Name, SourceLocation Location) {
164 const FullSourceLoc FullLocation = this->Context->getFullLoc(Location);
165
166 for (typename std::vector<MatchCandidate>::const_iterator
167 It = DisallowedMatches.begin(), End = DisallowedMatches.end();
168 It != End; ++It) {
169 EXPECT_FALSE(It->Matches(Name, FullLocation))
170 << "Matched disallowed " << *It;
171 }
172
173 for (typename std::vector<ExpectedMatch>::iterator
174 It = ExpectedMatches.begin(), End = ExpectedMatches.end();
175 It != End; ++It) {
176 It->UpdateFor(Name, FullLocation, this->Context->getSourceManager());
177 }
178 }
179
180 private:
181 struct MatchCandidate {
182 std::string ExpectedName;
183 unsigned LineNumber;
184 unsigned ColumnNumber;
185
186 MatchCandidate(Twine Name, unsigned LineNumber, unsigned ColumnNumber)
187 : ExpectedName(Name.str()), LineNumber(LineNumber),
188 ColumnNumber(ColumnNumber) {
189 }
190
191 bool Matches(StringRef Name, FullSourceLoc const &Location) const {
192 return MatchesName(Name) && MatchesLocation(Location);
193 }
194
195 bool PartiallyMatches(StringRef Name, FullSourceLoc const &Location) const {
196 return MatchesName(Name) || MatchesLocation(Location);
197 }
198
199 bool MatchesName(StringRef Name) const {
200 return Name == ExpectedName;
201 }
202
203 bool MatchesLocation(FullSourceLoc const &Location) const {
204 return Location.isValid() &&
205 Location.getSpellingLineNumber() == LineNumber &&
206 Location.getSpellingColumnNumber() == ColumnNumber;
207 }
208
209 friend std::ostream &operator<<(std::ostream &Stream,
210 MatchCandidate const &Match) {
211 return Stream << Match.ExpectedName
212 << " at " << Match.LineNumber << ":" << Match.ColumnNumber;
213 }
214 };
215
216 struct ExpectedMatch {
217 ExpectedMatch(Twine Name, unsigned LineNumber, unsigned ColumnNumber,
218 unsigned Times)
219 : Candidate(Name, LineNumber, ColumnNumber), TimesExpected(Times),
220 TimesSeen(0) {}
221
222 void UpdateFor(StringRef Name, FullSourceLoc Location, SourceManager &SM) {
223 if (Candidate.Matches(Name, Location)) {
224 EXPECT_LT(TimesSeen, TimesExpected);
225 ++TimesSeen;
226 } else if (TimesSeen < TimesExpected &&
227 Candidate.PartiallyMatches(Name, Location)) {
228 llvm::raw_string_ostream Stream(PartialMatches);
229 Stream << ", partial match: \"" << Name << "\" at ";
230 Location.print(OS&: Stream, SM);
231 }
232 }
233
234 void ExpectFound() const {
235 EXPECT_EQ(TimesExpected, TimesSeen)
236 << "Expected \"" << Candidate.ExpectedName
237 << "\" at " << Candidate.LineNumber
238 << ":" << Candidate.ColumnNumber << PartialMatches;
239 }
240
241 MatchCandidate Candidate;
242 std::string PartialMatches;
243 unsigned TimesExpected;
244 unsigned TimesSeen;
245 };
246
247 std::vector<MatchCandidate> DisallowedMatches;
248 std::vector<ExpectedMatch> ExpectedMatches;
249};
250}
251
252#endif
253

// source code of clang/unittests/Tooling/TestVisitor.h