1//===- unittest/AST/MatchVerifier.h - AST unit test support ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Provides MatchVerifier, a base class to implement gtest matchers that
10// verify things that can be matched on the AST.
11//
12// Also implements matchers based on MatchVerifier:
13// LocationVerifier and RangeVerifier to verify whether a matched node has
14// the expected source location or source range.
15//
16//===----------------------------------------------------------------------===//
17
18#ifndef LLVM_CLANG_UNITTESTS_AST_MATCHVERIFIER_H
19#define LLVM_CLANG_UNITTESTS_AST_MATCHVERIFIER_H
20
21#include "clang/AST/ASTContext.h"
22#include "clang/ASTMatchers/ASTMatchFinder.h"
23#include "clang/ASTMatchers/ASTMatchers.h"
24#include "clang/Testing/CommandLineArgs.h"
25#include "clang/Tooling/Tooling.h"
26#include "gtest/gtest.h"
27
28namespace clang {
29namespace ast_matchers {
30
31/// \brief Base class for verifying some property of nodes found by a matcher.
32template <typename NodeType>
33class MatchVerifier : public MatchFinder::MatchCallback {
34public:
35 template <typename MatcherType>
36 testing::AssertionResult match(const std::string &Code,
37 const MatcherType &AMatcher) {
38 std::vector<std::string> Args;
39 return match(Code, AMatcher, Args, Lang_CXX03);
40 }
41
42 template <typename MatcherType>
43 testing::AssertionResult match(const std::string &Code,
44 const MatcherType &AMatcher, TestLanguage L) {
45 std::vector<std::string> Args;
46 return match(Code, AMatcher, Args, L);
47 }
48
49 template <typename MatcherType>
50 testing::AssertionResult
51 match(const std::string &Code, const MatcherType &AMatcher,
52 std::vector<std::string> &Args, TestLanguage L);
53
54 template <typename MatcherType>
55 testing::AssertionResult match(const Decl *D, const MatcherType &AMatcher);
56
57protected:
58 void run(const MatchFinder::MatchResult &Result) override;
59 virtual void verify(const MatchFinder::MatchResult &Result,
60 const NodeType &Node) {}
61
62 void setFailure(const Twine &Result) {
63 Verified = false;
64 VerifyResult = Result.str();
65 }
66
67 void setSuccess() {
68 Verified = true;
69 }
70
71private:
72 bool Verified;
73 std::string VerifyResult;
74};
75
76/// \brief Runs a matcher over some code, and returns the result of the
77/// verifier for the matched node.
78template <typename NodeType>
79template <typename MatcherType>
80testing::AssertionResult
81MatchVerifier<NodeType>::match(const std::string &Code,
82 const MatcherType &AMatcher,
83 std::vector<std::string> &Args, TestLanguage L) {
84 MatchFinder Finder;
85 Finder.addMatcher(AMatcher.bind(""), this);
86 std::unique_ptr<tooling::FrontendActionFactory> Factory(
87 tooling::newFrontendActionFactory(ConsumerFactory: &Finder));
88
89 StringRef FileName;
90 switch (L) {
91 case Lang_C89:
92 Args.push_back(x: "-std=c89");
93 FileName = "input.c";
94 break;
95 case Lang_C99:
96 Args.push_back(x: "-std=c99");
97 FileName = "input.c";
98 break;
99 case Lang_CXX03:
100 Args.push_back(x: "-std=c++03");
101 FileName = "input.cc";
102 break;
103 case Lang_CXX11:
104 Args.push_back(x: "-std=c++11");
105 FileName = "input.cc";
106 break;
107 case Lang_CXX14:
108 Args.push_back(x: "-std=c++14");
109 FileName = "input.cc";
110 break;
111 case Lang_CXX17:
112 Args.push_back(x: "-std=c++17");
113 FileName = "input.cc";
114 break;
115 case Lang_CXX20:
116 Args.push_back(x: "-std=c++20");
117 FileName = "input.cc";
118 break;
119 case Lang_CXX23:
120 Args.push_back(x: "-std=c++23");
121 FileName = "input.cc";
122 break;
123 case Lang_OpenCL:
124 Args.push_back(x: "-cl-no-stdinc");
125 FileName = "input.cl";
126 break;
127 case Lang_OBJC:
128 Args.push_back(x: "-fobjc-nonfragile-abi");
129 FileName = "input.m";
130 break;
131 case Lang_OBJCXX:
132 FileName = "input.mm";
133 break;
134 }
135
136 // Default to failure in case callback is never called
137 setFailure("Could not find match");
138 if (!tooling::runToolOnCodeWithArgs(ToolAction: Factory->create(), Code, Args, FileName))
139 return testing::AssertionFailure() << "Parsing error";
140 if (!Verified)
141 return testing::AssertionFailure() << VerifyResult;
142 return testing::AssertionSuccess();
143}
144
145/// \brief Runs a matcher over some AST, and returns the result of the
146/// verifier for the matched node.
147template <typename NodeType> template <typename MatcherType>
148testing::AssertionResult MatchVerifier<NodeType>::match(
149 const Decl *D, const MatcherType &AMatcher) {
150 MatchFinder Finder;
151 Finder.addMatcher(AMatcher.bind(""), this);
152
153 setFailure("Could not find match");
154 Finder.match(Node: *D, Context&: D->getASTContext());
155
156 if (!Verified)
157 return testing::AssertionFailure() << VerifyResult;
158 return testing::AssertionSuccess();
159}
160
161template <typename NodeType>
162void MatchVerifier<NodeType>::run(const MatchFinder::MatchResult &Result) {
163 const NodeType *Node = Result.Nodes.getNodeAs<NodeType>("");
164 if (!Node) {
165 setFailure("Matched node has wrong type");
166 } else {
167 // Callback has been called, default to success.
168 setSuccess();
169 verify(Result, Node: *Node);
170 }
171}
172
173template <>
174inline void
175MatchVerifier<DynTypedNode>::run(const MatchFinder::MatchResult &Result) {
176 BoundNodes::IDToNodeMap M = Result.Nodes.getMap();
177 BoundNodes::IDToNodeMap::const_iterator I = M.find(x: "");
178 if (I == M.end()) {
179 setFailure("Node was not bound");
180 } else {
181 // Callback has been called, default to success.
182 setSuccess();
183 verify(Result, Node: I->second);
184 }
185}
186
187/// \brief Verify whether a node has the correct source location.
188///
189/// By default, Node.getSourceLocation() is checked. This can be changed
190/// by overriding getLocation().
191template <typename NodeType>
192class LocationVerifier : public MatchVerifier<NodeType> {
193public:
194 void expectLocation(unsigned Line, unsigned Column) {
195 ExpectLine = Line;
196 ExpectColumn = Column;
197 }
198
199protected:
200 void verify(const MatchFinder::MatchResult &Result,
201 const NodeType &Node) override {
202 SourceLocation Loc = getLocation(Node);
203 unsigned Line = Result.SourceManager->getSpellingLineNumber(Loc);
204 unsigned Column = Result.SourceManager->getSpellingColumnNumber(Loc);
205 if (Line != ExpectLine || Column != ExpectColumn) {
206 std::string MsgStr;
207 llvm::raw_string_ostream Msg(MsgStr);
208 Msg << "Expected location <" << ExpectLine << ":" << ExpectColumn
209 << ">, found <";
210 Loc.print(OS&: Msg, SM: *Result.SourceManager);
211 Msg << '>';
212 this->setFailure(Msg.str());
213 }
214 }
215
216 virtual SourceLocation getLocation(const NodeType &Node) {
217 return Node.getLocation();
218 }
219
220private:
221 unsigned ExpectLine, ExpectColumn;
222};
223
224/// \brief Verify whether a node has the correct source range.
225///
226/// By default, Node.getSourceRange() is checked. This can be changed
227/// by overriding getRange().
228template <typename NodeType>
229class RangeVerifier : public MatchVerifier<NodeType> {
230public:
231 void expectRange(unsigned BeginLine, unsigned BeginColumn,
232 unsigned EndLine, unsigned EndColumn) {
233 ExpectBeginLine = BeginLine;
234 ExpectBeginColumn = BeginColumn;
235 ExpectEndLine = EndLine;
236 ExpectEndColumn = EndColumn;
237 }
238
239protected:
240 void verify(const MatchFinder::MatchResult &Result,
241 const NodeType &Node) override {
242 SourceRange R = getRange(Node);
243 SourceLocation Begin = R.getBegin();
244 SourceLocation End = R.getEnd();
245 unsigned BeginLine = Result.SourceManager->getSpellingLineNumber(Loc: Begin);
246 unsigned BeginColumn = Result.SourceManager->getSpellingColumnNumber(Loc: Begin);
247 unsigned EndLine = Result.SourceManager->getSpellingLineNumber(Loc: End);
248 unsigned EndColumn = Result.SourceManager->getSpellingColumnNumber(Loc: End);
249 if (BeginLine != ExpectBeginLine || BeginColumn != ExpectBeginColumn ||
250 EndLine != ExpectEndLine || EndColumn != ExpectEndColumn) {
251 std::string MsgStr;
252 llvm::raw_string_ostream Msg(MsgStr);
253 Msg << "Expected range <" << ExpectBeginLine << ":" << ExpectBeginColumn
254 << '-' << ExpectEndLine << ":" << ExpectEndColumn << ">, found <";
255 Begin.print(OS&: Msg, SM: *Result.SourceManager);
256 Msg << '-';
257 End.print(OS&: Msg, SM: *Result.SourceManager);
258 Msg << '>';
259 this->setFailure(Msg.str());
260 }
261 }
262
263 virtual SourceRange getRange(const NodeType &Node) {
264 return Node.getSourceRange();
265 }
266
267private:
268 unsigned ExpectBeginLine, ExpectBeginColumn, ExpectEndLine, ExpectEndColumn;
269};
270
271/// \brief Verify whether a node's dump contains a given substring.
272class DumpVerifier : public MatchVerifier<DynTypedNode> {
273public:
274 void expectSubstring(const std::string &Str) {
275 ExpectSubstring = Str;
276 }
277
278protected:
279 void verify(const MatchFinder::MatchResult &Result,
280 const DynTypedNode &Node) override {
281 std::string DumpStr;
282 llvm::raw_string_ostream Dump(DumpStr);
283 Node.dump(OS&: Dump, Context: *Result.Context);
284
285 if (Dump.str().find(str: ExpectSubstring) == std::string::npos) {
286 std::string MsgStr;
287 llvm::raw_string_ostream Msg(MsgStr);
288 Msg << "Expected dump substring <" << ExpectSubstring << ">, found <"
289 << Dump.str() << '>';
290 this->setFailure(Msg.str());
291 }
292 }
293
294private:
295 std::string ExpectSubstring;
296};
297
298/// \brief Verify whether a node's pretty print matches a given string.
299class PrintVerifier : public MatchVerifier<DynTypedNode> {
300public:
301 void expectString(const std::string &Str) {
302 ExpectString = Str;
303 }
304
305protected:
306 void verify(const MatchFinder::MatchResult &Result,
307 const DynTypedNode &Node) override {
308 std::string PrintStr;
309 llvm::raw_string_ostream Print(PrintStr);
310 Node.print(OS&: Print, PP: Result.Context->getPrintingPolicy());
311
312 if (Print.str() != ExpectString) {
313 std::string MsgStr;
314 llvm::raw_string_ostream Msg(MsgStr);
315 Msg << "Expected pretty print <" << ExpectString << ">, found <"
316 << Print.str() << '>';
317 this->setFailure(Msg.str());
318 }
319 }
320
321private:
322 std::string ExpectString;
323};
324
325} // end namespace ast_matchers
326} // end namespace clang
327
328#endif
329

// Source: clang/unittests/AST/MatchVerifier.h