1 | //===--- InsertionPoint.cpp - Where should we add new code? ---------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "refactor/InsertionPoint.h" |
10 | #include "support/Logger.h" |
11 | #include "clang/AST/ASTContext.h" |
12 | #include "clang/AST/DeclCXX.h" |
13 | #include "clang/AST/DeclObjC.h" |
14 | #include "clang/AST/DeclTemplate.h" |
15 | #include "clang/Basic/SourceManager.h" |
16 | #include <optional> |
17 | |
18 | namespace clang { |
19 | namespace clangd { |
20 | namespace { |
21 | |
22 | // Choose the decl to insert before, according to an anchor. |
23 | // Nullptr means insert at end of DC. |
24 | // std::nullopt means no valid place to insert. |
25 | std::optional<const Decl *> insertionDecl(const DeclContext &DC, |
26 | const Anchor &A) { |
27 | bool LastMatched = false; |
28 | bool ReturnNext = false; |
29 | for (const auto *D : DC.decls()) { |
30 | if (D->isImplicit()) |
31 | continue; |
32 | if (ReturnNext) |
33 | return D; |
34 | |
35 | const Decl *NonTemplate = D; |
36 | if (auto *TD = llvm::dyn_cast<TemplateDecl>(Val: D)) |
37 | NonTemplate = TD->getTemplatedDecl(); |
38 | bool Matches = A.Match(NonTemplate); |
39 | dlog(" {0} {1} {2}" , Matches, D->getDeclKindName(), D); |
40 | |
41 | switch (A.Direction) { |
42 | case Anchor::Above: |
43 | if (Matches && !LastMatched) { |
44 | // Special case: if "above" matches an access specifier, we actually |
45 | // want to insert below it! |
46 | if (llvm::isa<AccessSpecDecl>(Val: D)) { |
47 | ReturnNext = true; |
48 | continue; |
49 | } |
50 | return D; |
51 | } |
52 | break; |
53 | case Anchor::Below: |
54 | if (LastMatched && !Matches) |
55 | return D; |
56 | break; |
57 | } |
58 | |
59 | LastMatched = Matches; |
60 | } |
61 | if (ReturnNext || (LastMatched && A.Direction == Anchor::Below)) |
62 | return nullptr; |
63 | return std::nullopt; |
64 | } |
65 | |
66 | SourceLocation beginLoc(const Decl &D) { |
67 | auto Loc = D.getBeginLoc(); |
68 | if (RawComment * = D.getASTContext().getRawCommentForDeclNoCache(D: &D)) { |
69 | auto = Comment->getBeginLoc(); |
70 | if (CommentLoc.isValid() && Loc.isValid() && |
71 | D.getASTContext().getSourceManager().isBeforeInTranslationUnit( |
72 | LHS: CommentLoc, RHS: Loc)) |
73 | Loc = CommentLoc; |
74 | } |
75 | return Loc; |
76 | } |
77 | |
78 | bool any(const Decl *D) { return true; } |
79 | |
80 | SourceLocation endLoc(const DeclContext &DC) { |
81 | const Decl *D = llvm::cast<Decl>(Val: &DC); |
82 | if (auto *OCD = llvm::dyn_cast<ObjCContainerDecl>(Val: D)) |
83 | return OCD->getAtEndRange().getBegin(); |
84 | return D->getEndLoc(); |
85 | } |
86 | |
87 | AccessSpecifier getAccessAtEnd(const CXXRecordDecl &C) { |
88 | AccessSpecifier Spec = |
89 | (C.getTagKind() == TagTypeKind::Class ? AS_private : AS_public); |
90 | for (const auto *D : C.decls()) |
91 | if (const auto *ASD = llvm::dyn_cast<AccessSpecDecl>(D)) |
92 | Spec = ASD->getAccess(); |
93 | return Spec; |
94 | } |
95 | |
96 | } // namespace |
97 | |
98 | SourceLocation insertionPoint(const DeclContext &DC, |
99 | llvm::ArrayRef<Anchor> Anchors) { |
100 | dlog("Looking for insertion point in {0}" , DC.getDeclKindName()); |
101 | for (const auto &A : Anchors) { |
102 | dlog(" anchor ({0})" , A.Direction == Anchor::Above ? "above" : "below" ); |
103 | if (auto D = insertionDecl(DC, A)) { |
104 | dlog(" anchor matched before {0}" , *D); |
105 | return *D ? beginLoc(D: **D) : endLoc(DC); |
106 | } |
107 | } |
108 | dlog("no anchor matched" ); |
109 | return SourceLocation(); |
110 | } |
111 | |
112 | llvm::Expected<tooling::Replacement> |
113 | insertDecl(llvm::StringRef Code, const DeclContext &DC, |
114 | llvm::ArrayRef<Anchor> Anchors) { |
115 | auto Loc = insertionPoint(DC, Anchors); |
116 | // Fallback: insert at the end. |
117 | if (Loc.isInvalid()) |
118 | Loc = endLoc(DC); |
119 | const auto &SM = DC.getParentASTContext().getSourceManager(); |
120 | if (!SM.isWrittenInSameFile(Loc1: Loc, Loc2: cast<Decl>(Val: DC).getLocation())) |
121 | return error(Fmt: "{0} body in wrong file: {1}" , Vals: DC.getDeclKindName(), |
122 | Vals: Loc.printToString(SM)); |
123 | return tooling::Replacement(SM, Loc, 0, Code); |
124 | } |
125 | |
126 | SourceLocation insertionPoint(const CXXRecordDecl &InClass, |
127 | std::vector<Anchor> Anchors, |
128 | AccessSpecifier Protection) { |
129 | for (auto &A : Anchors) |
130 | A.Match = [Inner(std::move(A.Match)), Protection](const Decl *D) { |
131 | return D->getAccess() == Protection && Inner(D); |
132 | }; |
133 | return insertionPoint(InClass, Anchors); |
134 | } |
135 | |
136 | llvm::Expected<tooling::Replacement> insertDecl(llvm::StringRef Code, |
137 | const CXXRecordDecl &InClass, |
138 | std::vector<Anchor> Anchors, |
139 | AccessSpecifier Protection) { |
140 | // Fallback: insert at the bottom of the relevant access section. |
141 | Anchors.push_back(x: {.Match: any, .Direction: Anchor::Below}); |
142 | auto Loc = insertionPoint(InClass, Anchors: std::move(Anchors), Protection); |
143 | std::string CodeBuffer; |
144 | auto &SM = InClass.getASTContext().getSourceManager(); |
145 | // Fallback: insert at the end of the class. Check if protection matches! |
146 | if (Loc.isInvalid()) { |
147 | Loc = InClass.getBraceRange().getEnd(); |
148 | if (Protection != getAccessAtEnd(C: InClass)) { |
149 | CodeBuffer = (getAccessSpelling(AS: Protection) + ":\n" + Code).str(); |
150 | Code = CodeBuffer; |
151 | } |
152 | } |
153 | if (!SM.isWrittenInSameFile(Loc, InClass.getLocation())) |
154 | return error("Class body in wrong file: {0}" , Loc.printToString(SM: SM)); |
155 | return tooling::Replacement(SM, Loc, 0, Code); |
156 | } |
157 | |
158 | } // namespace clangd |
159 | } // namespace clang |
160 | |