| 1 | //===--- InsertionPoint.cpp - Where should we add new code? ---------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "refactor/InsertionPoint.h" |
| 10 | #include "support/Logger.h" |
| 11 | #include "clang/AST/ASTContext.h" |
| 12 | #include "clang/AST/DeclCXX.h" |
| 13 | #include "clang/AST/DeclObjC.h" |
| 14 | #include "clang/AST/DeclTemplate.h" |
| 15 | #include "clang/Basic/SourceManager.h" |
| 16 | #include <optional> |
| 17 | |
| 18 | namespace clang { |
| 19 | namespace clangd { |
| 20 | namespace { |
| 21 | |
| 22 | // Choose the decl to insert before, according to an anchor. |
| 23 | // Nullptr means insert at end of DC. |
| 24 | // std::nullopt means no valid place to insert. |
| 25 | std::optional<const Decl *> insertionDecl(const DeclContext &DC, |
| 26 | const Anchor &A) { |
| 27 | bool LastMatched = false; |
| 28 | bool ReturnNext = false; |
| 29 | for (const auto *D : DC.decls()) { |
| 30 | if (D->isImplicit()) |
| 31 | continue; |
| 32 | if (ReturnNext) |
| 33 | return D; |
| 34 | |
| 35 | const Decl *NonTemplate = D; |
| 36 | if (auto *TD = llvm::dyn_cast<TemplateDecl>(Val: D)) |
| 37 | NonTemplate = TD->getTemplatedDecl(); |
| 38 | bool Matches = A.Match(NonTemplate); |
| 39 | dlog(" {0} {1} {2}" , Matches, D->getDeclKindName(), D); |
| 40 | |
| 41 | switch (A.Direction) { |
| 42 | case Anchor::Above: |
| 43 | if (Matches && !LastMatched) { |
| 44 | // Special case: if "above" matches an access specifier, we actually |
| 45 | // want to insert below it! |
| 46 | if (llvm::isa<AccessSpecDecl>(Val: D)) { |
| 47 | ReturnNext = true; |
| 48 | continue; |
| 49 | } |
| 50 | return D; |
| 51 | } |
| 52 | break; |
| 53 | case Anchor::Below: |
| 54 | if (LastMatched && !Matches) |
| 55 | return D; |
| 56 | break; |
| 57 | } |
| 58 | |
| 59 | LastMatched = Matches; |
| 60 | } |
| 61 | if (ReturnNext || (LastMatched && A.Direction == Anchor::Below)) |
| 62 | return nullptr; |
| 63 | return std::nullopt; |
| 64 | } |
| 65 | |
| 66 | SourceLocation beginLoc(const Decl &D) { |
| 67 | auto Loc = D.getBeginLoc(); |
| 68 | if (RawComment * = D.getASTContext().getRawCommentForDeclNoCache(D: &D)) { |
| 69 | auto = Comment->getBeginLoc(); |
| 70 | if (CommentLoc.isValid() && Loc.isValid() && |
| 71 | D.getASTContext().getSourceManager().isBeforeInTranslationUnit( |
| 72 | LHS: CommentLoc, RHS: Loc)) |
| 73 | Loc = CommentLoc; |
| 74 | } |
| 75 | return Loc; |
| 76 | } |
| 77 | |
| 78 | bool any(const Decl *D) { return true; } |
| 79 | |
| 80 | SourceLocation endLoc(const DeclContext &DC) { |
| 81 | const Decl *D = llvm::cast<Decl>(Val: &DC); |
| 82 | if (auto *OCD = llvm::dyn_cast<ObjCContainerDecl>(Val: D)) |
| 83 | return OCD->getAtEndRange().getBegin(); |
| 84 | return D->getEndLoc(); |
| 85 | } |
| 86 | |
| 87 | AccessSpecifier getAccessAtEnd(const CXXRecordDecl &C) { |
| 88 | AccessSpecifier Spec = |
| 89 | (C.getTagKind() == TagTypeKind::Class ? AS_private : AS_public); |
| 90 | for (const auto *D : C.decls()) |
| 91 | if (const auto *ASD = llvm::dyn_cast<AccessSpecDecl>(D)) |
| 92 | Spec = ASD->getAccess(); |
| 93 | return Spec; |
| 94 | } |
| 95 | |
| 96 | } // namespace |
| 97 | |
| 98 | SourceLocation insertionPoint(const DeclContext &DC, |
| 99 | llvm::ArrayRef<Anchor> Anchors) { |
| 100 | dlog("Looking for insertion point in {0}" , DC.getDeclKindName()); |
| 101 | for (const auto &A : Anchors) { |
| 102 | dlog(" anchor ({0})" , A.Direction == Anchor::Above ? "above" : "below" ); |
| 103 | if (auto D = insertionDecl(DC, A)) { |
| 104 | dlog(" anchor matched before {0}" , *D); |
| 105 | return *D ? beginLoc(D: **D) : endLoc(DC); |
| 106 | } |
| 107 | } |
| 108 | dlog("no anchor matched" ); |
| 109 | return SourceLocation(); |
| 110 | } |
| 111 | |
| 112 | llvm::Expected<tooling::Replacement> |
| 113 | insertDecl(llvm::StringRef Code, const DeclContext &DC, |
| 114 | llvm::ArrayRef<Anchor> Anchors) { |
| 115 | auto Loc = insertionPoint(DC, Anchors); |
| 116 | // Fallback: insert at the end. |
| 117 | if (Loc.isInvalid()) |
| 118 | Loc = endLoc(DC); |
| 119 | const auto &SM = DC.getParentASTContext().getSourceManager(); |
| 120 | if (!SM.isWrittenInSameFile(Loc1: Loc, Loc2: cast<Decl>(Val: DC).getLocation())) |
| 121 | return error(Fmt: "{0} body in wrong file: {1}" , Vals: DC.getDeclKindName(), |
| 122 | Vals: Loc.printToString(SM)); |
| 123 | return tooling::Replacement(SM, Loc, 0, Code); |
| 124 | } |
| 125 | |
| 126 | SourceLocation insertionPoint(const CXXRecordDecl &InClass, |
| 127 | std::vector<Anchor> Anchors, |
| 128 | AccessSpecifier Protection) { |
| 129 | for (auto &A : Anchors) |
| 130 | A.Match = [Inner(std::move(A.Match)), Protection](const Decl *D) { |
| 131 | return D->getAccess() == Protection && Inner(D); |
| 132 | }; |
| 133 | return insertionPoint(InClass, Anchors); |
| 134 | } |
| 135 | |
| 136 | llvm::Expected<tooling::Replacement> insertDecl(llvm::StringRef Code, |
| 137 | const CXXRecordDecl &InClass, |
| 138 | std::vector<Anchor> Anchors, |
| 139 | AccessSpecifier Protection) { |
| 140 | // Fallback: insert at the bottom of the relevant access section. |
| 141 | Anchors.push_back(x: {.Match: any, .Direction: Anchor::Below}); |
| 142 | auto Loc = insertionPoint(InClass, Anchors: std::move(Anchors), Protection); |
| 143 | std::string CodeBuffer; |
| 144 | auto &SM = InClass.getASTContext().getSourceManager(); |
| 145 | // Fallback: insert at the end of the class. Check if protection matches! |
| 146 | if (Loc.isInvalid()) { |
| 147 | Loc = InClass.getBraceRange().getEnd(); |
| 148 | if (Protection != getAccessAtEnd(C: InClass)) { |
| 149 | CodeBuffer = (getAccessSpelling(AS: Protection) + ":\n" + Code).str(); |
| 150 | Code = CodeBuffer; |
| 151 | } |
| 152 | } |
| 153 | if (!SM.isWrittenInSameFile(Loc, InClass.getLocation())) |
| 154 | return error("Class body in wrong file: {0}" , Loc.printToString(SM: SM)); |
| 155 | return tooling::Replacement(SM, Loc, 0, Code); |
| 156 | } |
| 157 | |
| 158 | } // namespace clangd |
| 159 | } // namespace clang |
| 160 | |