1 | //===- unittests/AST/ASTPrint.h ------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// Helpers to simplify testing of printing of AST constructs provided in the
// form of the source code.
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "clang/AST/ASTContext.h" |
15 | #include "clang/ASTMatchers/ASTMatchFinder.h" |
16 | #include "clang/Tooling/Tooling.h" |
17 | #include "llvm/ADT/SmallString.h" |
18 | #include "gtest/gtest.h" |
19 | |
20 | namespace clang { |
21 | |
22 | using PrintingPolicyAdjuster = llvm::function_ref<void(PrintingPolicy &Policy)>; |
23 | |
24 | template <typename NodeType> |
25 | using NodePrinter = |
26 | std::function<void(llvm::raw_ostream &Out, const ASTContext *Context, |
27 | const NodeType *Node, |
28 | PrintingPolicyAdjuster PolicyAdjuster)>; |
29 | |
/// Predicate deciding whether a matched node should be taken into account.
template <class NodeType>
using NodeFilter = std::function<bool(const NodeType *)>;
32 | |
33 | template <typename NodeType> |
34 | class PrintMatch : public ast_matchers::MatchFinder::MatchCallback { |
35 | using PrinterT = NodePrinter<NodeType>; |
36 | using FilterT = NodeFilter<NodeType>; |
37 | |
38 | SmallString<1024> Printed; |
39 | unsigned NumFoundNodes; |
40 | PrinterT Printer; |
41 | FilterT Filter; |
42 | PrintingPolicyAdjuster PolicyAdjuster; |
43 | |
44 | public: |
45 | PrintMatch(PrinterT Printer, PrintingPolicyAdjuster PolicyAdjuster, |
46 | FilterT Filter) |
47 | : NumFoundNodes(0), Printer(std::move(Printer)), |
48 | Filter(std::move(Filter)), PolicyAdjuster(PolicyAdjuster) {} |
49 | |
50 | void run(const ast_matchers::MatchFinder::MatchResult &Result) override { |
51 | const NodeType *N = Result.Nodes.getNodeAs<NodeType>("id" ); |
52 | if (!N || !Filter(N)) |
53 | return; |
54 | NumFoundNodes++; |
55 | if (NumFoundNodes > 1) |
56 | return; |
57 | |
58 | llvm::raw_svector_ostream Out(Printed); |
59 | Printer(Out, Result.Context, N, PolicyAdjuster); |
60 | } |
61 | |
62 | StringRef getPrinted() const { return Printed; } |
63 | |
64 | unsigned getNumFoundNodes() const { return NumFoundNodes; } |
65 | }; |
66 | |
/// Default filter used by PrintedNodeMatches: accepts every node.
template <class NodeType>
bool NoNodeFilter(const NodeType * /*Node*/) {
  constexpr bool AcceptEverything = true;
  return AcceptEverything;
}
70 | |
71 | template <typename NodeType, typename Matcher> |
72 | ::testing::AssertionResult |
73 | PrintedNodeMatches(StringRef Code, const std::vector<std::string> &Args, |
74 | const Matcher &NodeMatch, StringRef ExpectedPrinted, |
75 | StringRef FileName, NodePrinter<NodeType> Printer, |
76 | PrintingPolicyAdjuster PolicyAdjuster = nullptr, |
77 | bool AllowError = false, |
78 | // Would like to use a lambda for the default value, but that |
79 | // trips gcc 7 up. |
80 | NodeFilter<NodeType> Filter = &NoNodeFilter<NodeType>) { |
81 | |
82 | PrintMatch<NodeType> Callback(Printer, PolicyAdjuster, Filter); |
83 | ast_matchers::MatchFinder Finder; |
84 | Finder.addMatcher(NodeMatch, &Callback); |
85 | std::unique_ptr<tooling::FrontendActionFactory> Factory( |
86 | tooling::newFrontendActionFactory(ConsumerFactory: &Finder)); |
87 | |
88 | bool ToolResult; |
89 | if (FileName.empty()) { |
90 | ToolResult = tooling::runToolOnCodeWithArgs(ToolAction: Factory->create(), Code, Args); |
91 | } else { |
92 | ToolResult = |
93 | tooling::runToolOnCodeWithArgs(ToolAction: Factory->create(), Code, Args, FileName); |
94 | } |
95 | if (!ToolResult && !AllowError) |
96 | return testing::AssertionFailure() |
97 | << "Parsing error in \"" << Code.str() << "\"" ; |
98 | |
99 | if (Callback.getNumFoundNodes() == 0) |
100 | return testing::AssertionFailure() << "Matcher didn't find any nodes" ; |
101 | |
102 | if (Callback.getNumFoundNodes() > 1) |
103 | return testing::AssertionFailure() |
104 | << "Matcher should match only one node (found " |
105 | << Callback.getNumFoundNodes() << ")" ; |
106 | |
107 | if (Callback.getPrinted() != ExpectedPrinted) |
108 | return ::testing::AssertionFailure() |
109 | << "Expected \"" << ExpectedPrinted.str() << "\", got \"" |
110 | << Callback.getPrinted().str() << "\"" ; |
111 | |
112 | return ::testing::AssertionSuccess(); |
113 | } |
114 | |
115 | } // namespace clang |
116 | |