1 | //===--- unittests/Tooling/RecursiveASTVisitorTests/CallbacksCommon.h -----===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "TestVisitor.h" |
10 | |
11 | using namespace clang; |
12 | |
13 | namespace { |
14 | |
/// Strongly-typed flag passed to RecordingVisitorBase to select whether the
/// traversal runs in post-order; its underlying value is a plain bool.
enum class ShouldTraversePostOrder : bool { No = false, Yes = true };
19 | |
20 | /// Base class for tests for RecursiveASTVisitor tests that validate the |
21 | /// sequence of calls to user-defined callbacks like Traverse*(), WalkUp*(), |
22 | /// Visit*(). |
23 | template <typename Derived> |
24 | class RecordingVisitorBase : public TestVisitor<Derived> { |
25 | ShouldTraversePostOrder ShouldTraversePostOrderValue; |
26 | |
27 | public: |
28 | RecordingVisitorBase(ShouldTraversePostOrder ShouldTraversePostOrderValue) |
29 | : ShouldTraversePostOrderValue(ShouldTraversePostOrderValue) {} |
30 | |
31 | bool shouldTraversePostOrder() const { |
32 | return static_cast<bool>(ShouldTraversePostOrderValue); |
33 | } |
34 | |
35 | // Callbacks received during traversal. |
36 | std::string CallbackLog; |
37 | unsigned CallbackLogIndent = 0; |
38 | |
39 | std::string stmtToString(Stmt *S) { |
40 | StringRef ClassName = S->getStmtClassName(); |
41 | if (IntegerLiteral *IL = dyn_cast<IntegerLiteral>(Val: S)) { |
42 | return (ClassName + "(" + toString(IL->getValue(), 10, false) + ")" ).str(); |
43 | } |
44 | if (UnaryOperator *UO = dyn_cast<UnaryOperator>(Val: S)) { |
45 | return (ClassName + "(" + UnaryOperator::getOpcodeStr(Op: UO->getOpcode()) + |
46 | ")" ) |
47 | .str(); |
48 | } |
49 | if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Val: S)) { |
50 | return (ClassName + "(" + BinaryOperator::getOpcodeStr(BO->getOpcode()) + |
51 | ")" ) |
52 | .str(); |
53 | } |
54 | if (CallExpr *CE = dyn_cast<CallExpr>(Val: S)) { |
55 | if (FunctionDecl *Callee = CE->getDirectCallee()) { |
56 | if (Callee->getIdentifier()) { |
57 | return (ClassName + "(" + Callee->getName() + ")" ).str(); |
58 | } |
59 | } |
60 | } |
61 | if (DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(Val: S)) { |
62 | if (NamedDecl *ND = DRE->getFoundDecl()) { |
63 | if (ND->getIdentifier()) { |
64 | return (ClassName + "(" + ND->getName() + ")" ).str(); |
65 | } |
66 | } |
67 | } |
68 | return ClassName.str(); |
69 | } |
70 | |
71 | /// Record the fact that the user-defined callback member function |
72 | /// \p CallbackName was called with the argument \p S. Then, record the |
73 | /// effects of calling the default implementation \p CallDefaultFn. |
74 | template <typename CallDefault> |
75 | void recordCallback(StringRef CallbackName, Stmt *S, |
76 | CallDefault CallDefaultFn) { |
77 | for (unsigned i = 0; i != CallbackLogIndent; ++i) { |
78 | CallbackLog += " " ; |
79 | } |
80 | CallbackLog += (CallbackName + " " + stmtToString(S) + "\n" ).str(); |
81 | ++CallbackLogIndent; |
82 | CallDefaultFn(); |
83 | --CallbackLogIndent; |
84 | } |
85 | }; |
86 | |
87 | template <typename VisitorTy> |
88 | ::testing::AssertionResult visitorCallbackLogEqual(VisitorTy Visitor, |
89 | StringRef Code, |
90 | StringRef ExpectedLog) { |
91 | Visitor.runOver(Code); |
92 | // EXPECT_EQ shows the diff between the two strings if they are different. |
93 | EXPECT_EQ(ExpectedLog.trim().str(), |
94 | StringRef(Visitor.CallbackLog).trim().str()); |
95 | if (ExpectedLog.trim() != StringRef(Visitor.CallbackLog).trim()) { |
96 | return ::testing::AssertionFailure(); |
97 | } |
98 | return ::testing::AssertionSuccess(); |
99 | } |
100 | |
101 | } // namespace |
102 | |