1//===--- ExtractFunction.cpp -------------------------------------*- C++-*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Extracts statements to a new function and replaces the statements with a
10// call to the new function.
11// Before:
12// void f(int a) {
13// [[if(a < 5)
14// a = 5;]]
15// }
16// After:
17// void extracted(int &a) {
18// if(a < 5)
19// a = 5;
20// }
21// void f(int a) {
22// extracted(a);
23// }
24//
25// - Only extract statements
26// - Extracts from non-templated free functions only.
27// - Parameters are const only if the declaration was const
28// - Always passed by l-value reference
29// - Void return type
30// - Cannot extract declarations that will be needed in the original function
31// after extraction.
32// - Checks for broken control flow (break/continue without loop/switch)
33//
34// 1. ExtractFunction is the tweak subclass
35// - Prepare does basic analysis of the selection and is therefore fast.
36// Successful prepare doesn't always mean we can apply the tweak.
37// - Apply does a more detailed analysis and can be slower. In case of
38// failure, we let the user know that we are unable to perform extraction.
// 2. ExtractionZone stores information about the range being extracted and the
// enclosing function.
41// 3. NewFunction stores properties of the extracted function and provides
42// methods for rendering it.
43// 4. CapturedZoneInfo uses a RecursiveASTVisitor to capture information about
44// the extraction like declarations, existing return statements, etc.
45// 5. getExtractedFunction is responsible for analyzing the CapturedZoneInfo and
46// creating a NewFunction.
47//===----------------------------------------------------------------------===//
48
49#include "AST.h"
50#include "FindTarget.h"
51#include "ParsedAST.h"
52#include "Selection.h"
53#include "SourceCode.h"
54#include "refactor/Tweak.h"
55#include "support/Logger.h"
56#include "clang/AST/ASTContext.h"
57#include "clang/AST/Decl.h"
58#include "clang/AST/DeclBase.h"
59#include "clang/AST/NestedNameSpecifier.h"
60#include "clang/AST/RecursiveASTVisitor.h"
61#include "clang/AST/Stmt.h"
62#include "clang/Basic/LangOptions.h"
63#include "clang/Basic/SourceLocation.h"
64#include "clang/Basic/SourceManager.h"
65#include "clang/Tooling/Core/Replacement.h"
66#include "clang/Tooling/Refactoring/Extract/SourceExtraction.h"
67#include "llvm/ADT/STLExtras.h"
68#include "llvm/ADT/SmallSet.h"
69#include "llvm/ADT/SmallVector.h"
70#include "llvm/ADT/StringRef.h"
71#include "llvm/Support/Casting.h"
72#include "llvm/Support/Error.h"
73#include "llvm/Support/raw_os_ostream.h"
74#include <optional>
75
76namespace clang {
77namespace clangd {
78namespace {
79
80using Node = SelectionTree::Node;
81
// ExtractionZone is the part of code that is being extracted.
// EnclosingFunction is the function/method inside which the zone lies.
// We split the file into 4 parts relative to extraction zone. Declarations and
// references found while walking the enclosing function are tagged with one of
// these values.
enum class ZoneRelative {
  Before,     // Before Zone and inside EnclosingFunction.
  Inside,     // Inside Zone.
  After,      // After Zone and inside EnclosingFunction.
  OutsideFunc // Outside EnclosingFunction.
};
91
// The syntactic form in which the extracted function is rendered.
enum FunctionDeclKind {
  // A definition (signature plus body) rendered with an unqualified name.
  InlineDefinition,
  // A declaration only: the signature followed by ';'.
  ForwardDeclaration,
  // A definition whose name is prefixed with the enclosing function's
  // qualifier (when it has one); the `static` specifier is not repeated.
  OutOfLineDefinition
};
97
98// A RootStmt is a statement that's fully selected including all it's children
99// and it's parent is unselected.
100// Check if a node is a root statement.
101bool isRootStmt(const Node *N) {
102 if (!N->ASTNode.get<Stmt>())
103 return false;
104 // Root statement cannot be partially selected.
105 if (N->Selected == SelectionTree::Partial)
106 return false;
107 // Only DeclStmt can be an unselected RootStmt since VarDecls claim the entire
108 // selection range in selectionTree.
109 if (N->Selected == SelectionTree::Unselected && !N->ASTNode.get<DeclStmt>())
110 return false;
111 return true;
112}
113
114// Returns the (unselected) parent of all RootStmts given the commonAncestor.
115// Returns null if:
116// 1. any node is partially selected
117// 2. If all completely selected nodes don't have the same common parent
118// 3. Any child of Parent isn't a RootStmt.
119// Returns null if any child is not a RootStmt.
120// We only support extraction of RootStmts since it allows us to extract without
121// having to change the selection range. Also, this means that any scope that
122// begins in selection range, ends in selection range and any scope that begins
123// outside the selection range, ends outside as well.
124const Node *getParentOfRootStmts(const Node *CommonAnc) {
125 if (!CommonAnc)
126 return nullptr;
127 const Node *Parent = nullptr;
128 switch (CommonAnc->Selected) {
129 case SelectionTree::Selection::Unselected:
130 // Typically a block, with the { and } unselected, could also be ForStmt etc
131 // Ensure all Children are RootStmts.
132 Parent = CommonAnc;
133 break;
134 case SelectionTree::Selection::Partial:
135 // Only a fully-selected single statement can be selected.
136 return nullptr;
137 case SelectionTree::Selection::Complete:
138 // If the Common Ancestor is completely selected, then it's a root statement
139 // and its parent will be unselected.
140 Parent = CommonAnc->Parent;
141 // If parent is a DeclStmt, even though it's unselected, we consider it a
142 // root statement and return its parent. This is done because the VarDecls
143 // claim the entire selection range of the Declaration and DeclStmt is
144 // always unselected.
145 if (Parent->ASTNode.get<DeclStmt>())
146 Parent = Parent->Parent;
147 break;
148 }
149 // Ensure all Children are RootStmts.
150 return llvm::all_of(Range: Parent->Children, P: isRootStmt) ? Parent : nullptr;
151}
152
153// The ExtractionZone class forms a view of the code wrt Zone.
154struct ExtractionZone {
155 // Parent of RootStatements being extracted.
156 const Node *Parent = nullptr;
157 // The half-open file range of the code being extracted.
158 SourceRange ZoneRange;
159 // The function inside which our zone resides.
160 const FunctionDecl *EnclosingFunction = nullptr;
161 // The half-open file range of the enclosing function.
162 SourceRange EnclosingFuncRange;
163 // Set of statements that form the ExtractionZone.
164 llvm::DenseSet<const Stmt *> RootStmts;
165
166 SourceLocation getInsertionPoint() const {
167 return EnclosingFuncRange.getBegin();
168 }
169 bool isRootStmt(const Stmt *S) const;
170 // The last root statement is important to decide where we need to insert a
171 // semicolon after the extraction.
172 const Node *getLastRootStmt() const { return Parent->Children.back(); }
173
174 // Checks if declarations inside extraction zone are accessed afterwards.
175 //
176 // This performs a partial AST traversal proportional to the size of the
177 // enclosing function, so it is possibly expensive.
178 bool requiresHoisting(const SourceManager &SM,
179 const HeuristicResolver *Resolver) const {
180 // First find all the declarations that happened inside extraction zone.
181 llvm::SmallSet<const Decl *, 1> DeclsInExtZone;
182 for (auto *RootStmt : RootStmts) {
183 findExplicitReferences(
184 S: RootStmt,
185 Out: [&DeclsInExtZone](const ReferenceLoc &Loc) {
186 if (!Loc.IsDecl)
187 return;
188 DeclsInExtZone.insert(Loc.Targets.front());
189 },
190 Resolver);
191 }
192 // Early exit without performing expensive traversal below.
193 if (DeclsInExtZone.empty())
194 return false;
195 // Then make sure they are not used outside the zone.
196 for (const auto *S : EnclosingFunction->getBody()->children()) {
197 if (SM.isBeforeInTranslationUnit(LHS: S->getSourceRange().getEnd(),
198 RHS: ZoneRange.getEnd()))
199 continue;
200 bool HasPostUse = false;
201 findExplicitReferences(
202 S,
203 Out: [&](const ReferenceLoc &Loc) {
204 if (HasPostUse ||
205 SM.isBeforeInTranslationUnit(LHS: Loc.NameLoc, RHS: ZoneRange.getEnd()))
206 return;
207 HasPostUse = llvm::any_of(Range: Loc.Targets,
208 P: [&DeclsInExtZone](const Decl *Target) {
209 return DeclsInExtZone.contains(Ptr: Target);
210 });
211 },
212 Resolver);
213 if (HasPostUse)
214 return true;
215 }
216 return false;
217 }
218};
219
220// Whether the code in the extraction zone is guaranteed to return, assuming
221// no broken control flow (unbound break/continue).
222// This is a very naive check (does it end with a return stmt).
223// Doing some rudimentary control flow analysis would cover more cases.
224bool alwaysReturns(const ExtractionZone &EZ) {
225 const Stmt *Last = EZ.getLastRootStmt()->ASTNode.get<Stmt>();
226 // Unwrap enclosing (unconditional) compound statement.
227 while (const auto *CS = llvm::dyn_cast<CompoundStmt>(Last)) {
228 if (CS->body_empty())
229 return false;
230 Last = CS->body_back();
231 }
232 return llvm::isa<ReturnStmt>(Val: Last);
233}
234
235bool ExtractionZone::isRootStmt(const Stmt *S) const {
236 return RootStmts.contains(V: S);
237}
238
239// Finds the function in which the zone lies.
240const FunctionDecl *findEnclosingFunction(const Node *CommonAnc) {
241 // Walk up the SelectionTree until we find a function Decl
242 for (const Node *CurNode = CommonAnc; CurNode; CurNode = CurNode->Parent) {
243 // Don't extract from lambdas
244 if (CurNode->ASTNode.get<LambdaExpr>())
245 return nullptr;
246 if (const FunctionDecl *Func = CurNode->ASTNode.get<FunctionDecl>()) {
247 // FIXME: Support extraction from templated functions.
248 if (Func->isTemplated())
249 return nullptr;
250 if (!Func->getBody())
251 return nullptr;
252 for (const auto *S : Func->getBody()->children()) {
253 // During apply phase, we perform semantic analysis (e.g. figure out
254 // what variables requires hoisting). We cannot perform those when the
255 // body has invalid statements, so fail up front.
256 if (!S)
257 return nullptr;
258 }
259 return Func;
260 }
261 }
262 return nullptr;
263}
264
265// Zone Range is the union of SourceRanges of all child Nodes in Parent since
266// all child Nodes are RootStmts
267std::optional<SourceRange> findZoneRange(const Node *Parent,
268 const SourceManager &SM,
269 const LangOptions &LangOpts) {
270 SourceRange SR;
271 if (auto BeginFileRange = toHalfOpenFileRange(
272 SM, LangOpts, Parent->Children.front()->ASTNode.getSourceRange()))
273 SR.setBegin(BeginFileRange->getBegin());
274 else
275 return std::nullopt;
276 if (auto EndFileRange = toHalfOpenFileRange(
277 SM, LangOpts, Parent->Children.back()->ASTNode.getSourceRange()))
278 SR.setEnd(EndFileRange->getEnd());
279 else
280 return std::nullopt;
281 return SR;
282}
283
284// Compute the range spanned by the enclosing function.
285// FIXME: check if EnclosingFunction has any attributes as the AST doesn't
286// always store the source range of the attributes and thus we end up extracting
287// between the attributes and the EnclosingFunction.
288std::optional<SourceRange>
289computeEnclosingFuncRange(const FunctionDecl *EnclosingFunction,
290 const SourceManager &SM,
291 const LangOptions &LangOpts) {
292 return toHalfOpenFileRange(Mgr: SM, LangOpts, R: EnclosingFunction->getSourceRange());
293}
294
295// returns true if Child can be a single RootStmt being extracted from
296// EnclosingFunc.
297bool validSingleChild(const Node *Child, const FunctionDecl *EnclosingFunc) {
298 // Don't extract expressions.
299 // FIXME: We should extract expressions that are "statements" i.e. not
300 // subexpressions
301 if (Child->ASTNode.get<Expr>())
302 return false;
303 // Extracting the body of EnclosingFunc would remove it's definition.
304 assert(EnclosingFunc->hasBody() &&
305 "We should always be extracting from a function body.");
306 if (Child->ASTNode.get<Stmt>() == EnclosingFunc->getBody())
307 return false;
308 return true;
309}
310
311// FIXME: Check we're not extracting from the initializer/condition of a control
312// flow structure.
313std::optional<ExtractionZone> findExtractionZone(const Node *CommonAnc,
314 const SourceManager &SM,
315 const LangOptions &LangOpts) {
316 ExtractionZone ExtZone;
317 ExtZone.Parent = getParentOfRootStmts(CommonAnc);
318 if (!ExtZone.Parent || ExtZone.Parent->Children.empty())
319 return std::nullopt;
320 ExtZone.EnclosingFunction = findEnclosingFunction(CommonAnc: ExtZone.Parent);
321 if (!ExtZone.EnclosingFunction)
322 return std::nullopt;
323 // When there is a single RootStmt, we must check if it's valid for
324 // extraction.
325 if (ExtZone.Parent->Children.size() == 1 &&
326 !validSingleChild(Child: ExtZone.getLastRootStmt(), EnclosingFunc: ExtZone.EnclosingFunction))
327 return std::nullopt;
328 if (auto FuncRange =
329 computeEnclosingFuncRange(EnclosingFunction: ExtZone.EnclosingFunction, SM, LangOpts))
330 ExtZone.EnclosingFuncRange = *FuncRange;
331 if (auto ZoneRange = findZoneRange(Parent: ExtZone.Parent, SM, LangOpts))
332 ExtZone.ZoneRange = *ZoneRange;
333 if (ExtZone.EnclosingFuncRange.isInvalid() || ExtZone.ZoneRange.isInvalid())
334 return std::nullopt;
335
336 for (const Node *Child : ExtZone.Parent->Children)
337 ExtZone.RootStmts.insert(Child->ASTNode.get<Stmt>());
338
339 return ExtZone;
340}
341
// Stores information about the extracted function and provides methods for
// rendering it.
struct NewFunction {
  // A single parameter of the extracted function.
  struct Parameter {
    std::string Name;
    QualType TypeInfo;
    // Whether the parameter is rendered as an l-value reference.
    bool PassByReference;
    unsigned OrderPriority; // Lower value parameters are preferred first.
    // Renders "<type> <name>" or "<type> &<name>", printing the type relative
    // to Context.
    std::string render(const DeclContext *Context) const;
    // Orders parameters by OrderPriority.
    bool operator<(const Parameter &Other) const {
      return OrderPriority < Other.OrderPriority;
    }
  };
  // Name of the extracted function.
  std::string Name = "extracted";
  QualType ReturnType;
  std::vector<Parameter> Parameters;
  // Range of source text that becomes the body and is replaced by the call.
  SourceRange BodyRange;
  // Location at which the definition is inserted.
  SourceLocation DefinitionPoint;
  // Location at which a forward declaration is inserted, if one is needed.
  std::optional<SourceLocation> ForwardDeclarationPoint;
  const CXXRecordDecl *EnclosingClass = nullptr;
  // Qualifier printed in front of the name for out-of-line definitions.
  const NestedNameSpecifier *DefinitionQualifier = nullptr;
  const DeclContext *SemanticDC = nullptr;
  const DeclContext *SyntacticDC = nullptr;
  const DeclContext *ForwardDeclarationSyntacticDC = nullptr;
  // If true, the rendered call is prefixed with "return ".
  bool CallerReturnsValue = false;
  bool Static = false;
  ConstexprSpecKind Constexpr = ConstexprSpecKind::Unspecified;
  bool Const = false;

  // Decides whether the extracted function body and the function call need a
  // semicolon after extraction.
  tooling::ExtractionSemicolonPolicy SemicolonPolicy;
  const LangOptions *LangOpts;
  NewFunction(tooling::ExtractionSemicolonPolicy SemicolonPolicy,
              const LangOptions *LangOpts)
      : SemicolonPolicy(SemicolonPolicy), LangOpts(LangOpts) {}
  // Render the call for this function.
  std::string renderCall() const;
  // Render the declaration for this function; K selects between a forward
  // declaration and a definition that includes the body.
  std::string renderDeclaration(FunctionDeclKind K,
                                const DeclContext &SemanticDC,
                                const DeclContext &SyntacticDC,
                                const SourceManager &SM) const;

private:
  // Comma-separated parameter list with types, for the declaration.
  std::string
  renderParametersForDeclaration(const DeclContext &Enclosing) const;
  // Comma-separated parameter names, for the call.
  std::string renderParametersForCall() const;
  // Leading specifiers (static/constexpr/consteval), each followed by a space.
  std::string renderSpecifiers(FunctionDeclKind K) const;
  // Trailing qualifiers (" const") placed after the parameter list.
  std::string renderQualifiers() const;
  // Function name, qualified for out-of-line definitions.
  std::string renderDeclarationName(FunctionDeclKind K) const;
  // Generate the function body.
  std::string getFuncBody(const SourceManager &SM) const;
};
396
397std::string NewFunction::renderParametersForDeclaration(
398 const DeclContext &Enclosing) const {
399 std::string Result;
400 bool NeedCommaBefore = false;
401 for (const Parameter &P : Parameters) {
402 if (NeedCommaBefore)
403 Result += ", ";
404 NeedCommaBefore = true;
405 Result += P.render(Context: &Enclosing);
406 }
407 return Result;
408}
409
410std::string NewFunction::renderParametersForCall() const {
411 std::string Result;
412 bool NeedCommaBefore = false;
413 for (const Parameter &P : Parameters) {
414 if (NeedCommaBefore)
415 Result += ", ";
416 NeedCommaBefore = true;
417 Result += P.Name;
418 }
419 return Result;
420}
421
422std::string NewFunction::renderSpecifiers(FunctionDeclKind K) const {
423 std::string Attributes;
424
425 if (Static && K != FunctionDeclKind::OutOfLineDefinition) {
426 Attributes += "static ";
427 }
428
429 switch (Constexpr) {
430 case ConstexprSpecKind::Unspecified:
431 case ConstexprSpecKind::Constinit:
432 break;
433 case ConstexprSpecKind::Constexpr:
434 Attributes += "constexpr ";
435 break;
436 case ConstexprSpecKind::Consteval:
437 Attributes += "consteval ";
438 break;
439 }
440
441 return Attributes;
442}
443
444std::string NewFunction::renderQualifiers() const {
445 std::string Attributes;
446
447 if (Const) {
448 Attributes += " const";
449 }
450
451 return Attributes;
452}
453
454std::string NewFunction::renderDeclarationName(FunctionDeclKind K) const {
455 if (DefinitionQualifier == nullptr || K != OutOfLineDefinition) {
456 return Name;
457 }
458
459 std::string QualifierName;
460 llvm::raw_string_ostream Oss(QualifierName);
461 DefinitionQualifier->print(OS&: Oss, Policy: *LangOpts);
462 return llvm::formatv(Fmt: "{0}{1}", Vals&: QualifierName, Vals: Name);
463}
464
465std::string NewFunction::renderCall() const {
466 return std::string(
467 llvm::formatv(Fmt: "{0}{1}({2}){3}", Vals: CallerReturnsValue ? "return " : "", Vals: Name,
468 Vals: renderParametersForCall(),
469 Vals: (SemicolonPolicy.isNeededInOriginalFunction() ? ";" : "")));
470}
471
472std::string NewFunction::renderDeclaration(FunctionDeclKind K,
473 const DeclContext &SemanticDC,
474 const DeclContext &SyntacticDC,
475 const SourceManager &SM) const {
476 std::string Declaration = std::string(llvm::formatv(
477 "{0}{1} {2}({3}){4}", renderSpecifiers(K),
478 printType(ReturnType, SyntacticDC), renderDeclarationName(K),
479 renderParametersForDeclaration(SemanticDC), renderQualifiers()));
480
481 switch (K) {
482 case ForwardDeclaration:
483 return std::string(llvm::formatv(Fmt: "{0};\n", Vals&: Declaration));
484 case OutOfLineDefinition:
485 case InlineDefinition:
486 return std::string(
487 llvm::formatv(Fmt: "{0} {\n{1}\n}\n", Vals&: Declaration, Vals: getFuncBody(SM)));
488 break;
489 }
490 llvm_unreachable("Unsupported FunctionDeclKind enum");
491}
492
493std::string NewFunction::getFuncBody(const SourceManager &SM) const {
494 // FIXME: Generate tooling::Replacements instead of std::string to
495 // - hoist decls
496 // - add return statement
497 // - Add semicolon
498 return toSourceCode(SM, R: BodyRange).str() +
499 (SemicolonPolicy.isNeededInExtractedFunction() ? ";" : "");
500}
501
502std::string NewFunction::Parameter::render(const DeclContext *Context) const {
503 return printType(TypeInfo, *Context) + (PassByReference ? " &" : " ") + Name;
504}
505
506// Stores captured information about Extraction Zone.
507struct CapturedZoneInfo {
508 struct DeclInformation {
509 const Decl *TheDecl;
510 ZoneRelative DeclaredIn;
511 // index of the declaration or first reference.
512 unsigned DeclIndex;
513 bool IsReferencedInZone = false;
514 bool IsReferencedInPostZone = false;
515 // FIXME: Capture mutation information
516 DeclInformation(const Decl *TheDecl, ZoneRelative DeclaredIn,
517 unsigned DeclIndex)
518 : TheDecl(TheDecl), DeclaredIn(DeclaredIn), DeclIndex(DeclIndex){};
519 // Marks the occurence of a reference for this declaration
520 void markOccurence(ZoneRelative ReferenceLoc);
521 };
522 // Maps Decls to their DeclInfo
523 llvm::DenseMap<const Decl *, DeclInformation> DeclInfoMap;
524 bool HasReturnStmt = false; // Are there any return statements in the zone?
525 bool AlwaysReturns = false; // Does the zone always return?
526 // Control flow is broken if we are extracting a break/continue without a
527 // corresponding parent loop/switch
528 bool BrokenControlFlow = false;
529 // FIXME: capture TypeAliasDecl and UsingDirectiveDecl
530 // FIXME: Capture type information as well.
531 DeclInformation *createDeclInfo(const Decl *D, ZoneRelative RelativeLoc);
532 DeclInformation *getDeclInfoFor(const Decl *D);
533};
534
535CapturedZoneInfo::DeclInformation *
536CapturedZoneInfo::createDeclInfo(const Decl *D, ZoneRelative RelativeLoc) {
537 // The new Decl's index is the size of the map so far.
538 auto InsertionResult = DeclInfoMap.insert(
539 KV: {D, DeclInformation(D, RelativeLoc, DeclInfoMap.size())});
540 // Return the newly created DeclInfo
541 return &InsertionResult.first->second;
542}
543
544CapturedZoneInfo::DeclInformation *
545CapturedZoneInfo::getDeclInfoFor(const Decl *D) {
546 // If the Decl doesn't exist, we
547 auto Iter = DeclInfoMap.find(Val: D);
548 if (Iter == DeclInfoMap.end())
549 return nullptr;
550 return &Iter->second;
551}
552
553void CapturedZoneInfo::DeclInformation::markOccurence(
554 ZoneRelative ReferenceLoc) {
555 switch (ReferenceLoc) {
556 case ZoneRelative::Inside:
557 IsReferencedInZone = true;
558 break;
559 case ZoneRelative::After:
560 IsReferencedInPostZone = true;
561 break;
562 default:
563 break;
564 }
565}
566
567bool isLoop(const Stmt *S) {
568 return isa<ForStmt>(Val: S) || isa<DoStmt>(Val: S) || isa<WhileStmt>(Val: S) ||
569 isa<CXXForRangeStmt>(Val: S);
570}
571
572// Captures information from Extraction Zone
573CapturedZoneInfo captureZoneInfo(const ExtractionZone &ExtZone) {
574 // We use the ASTVisitor instead of using the selection tree since we need to
575 // find references in the PostZone as well.
576 // FIXME: Check which statements we don't allow to extract.
577 class ExtractionZoneVisitor
578 : public clang::RecursiveASTVisitor<ExtractionZoneVisitor> {
579 public:
580 ExtractionZoneVisitor(const ExtractionZone &ExtZone) : ExtZone(ExtZone) {
581 TraverseDecl(const_cast<FunctionDecl *>(ExtZone.EnclosingFunction));
582 }
583
584 bool TraverseStmt(Stmt *S) {
585 if (!S)
586 return true;
587 bool IsRootStmt = ExtZone.isRootStmt(S: const_cast<const Stmt *>(S));
588 // If we are starting traversal of a RootStmt, we are somewhere inside
589 // ExtractionZone
590 if (IsRootStmt)
591 CurrentLocation = ZoneRelative::Inside;
592 addToLoopSwitchCounters(S, Increment: 1);
593 // Traverse using base class's TraverseStmt
594 RecursiveASTVisitor::TraverseStmt(S);
595 addToLoopSwitchCounters(S, Increment: -1);
596 // We set the current location as after since next stmt will either be a
597 // RootStmt (handled at the beginning) or after extractionZone
598 if (IsRootStmt)
599 CurrentLocation = ZoneRelative::After;
600 return true;
601 }
602
603 // Add Increment to CurNumberOf{Loops,Switch} if statement is
604 // {Loop,Switch} and inside Extraction Zone.
605 void addToLoopSwitchCounters(Stmt *S, int Increment) {
606 if (CurrentLocation != ZoneRelative::Inside)
607 return;
608 if (isLoop(S))
609 CurNumberOfNestedLoops += Increment;
610 else if (isa<SwitchStmt>(Val: S))
611 CurNumberOfSwitch += Increment;
612 }
613
614 bool VisitDecl(Decl *D) {
615 Info.createDeclInfo(D, RelativeLoc: CurrentLocation);
616 return true;
617 }
618
619 bool VisitDeclRefExpr(DeclRefExpr *DRE) {
620 // Find the corresponding Decl and mark it's occurrence.
621 const Decl *D = DRE->getDecl();
622 auto *DeclInfo = Info.getDeclInfoFor(D);
623 // If no Decl was found, the Decl must be outside the enclosingFunc.
624 if (!DeclInfo)
625 DeclInfo = Info.createDeclInfo(D, RelativeLoc: ZoneRelative::OutsideFunc);
626 DeclInfo->markOccurence(CurrentLocation);
627 // FIXME: check if reference mutates the Decl being referred.
628 return true;
629 }
630
631 bool VisitReturnStmt(ReturnStmt *Return) {
632 if (CurrentLocation == ZoneRelative::Inside)
633 Info.HasReturnStmt = true;
634 return true;
635 }
636
637 bool VisitBreakStmt(BreakStmt *Break) {
638 // Control flow is broken if break statement is selected without any
639 // parent loop or switch statement.
640 if (CurrentLocation == ZoneRelative::Inside &&
641 !(CurNumberOfNestedLoops || CurNumberOfSwitch))
642 Info.BrokenControlFlow = true;
643 return true;
644 }
645
646 bool VisitContinueStmt(ContinueStmt *Continue) {
647 // Control flow is broken if Continue statement is selected without any
648 // parent loop
649 if (CurrentLocation == ZoneRelative::Inside && !CurNumberOfNestedLoops)
650 Info.BrokenControlFlow = true;
651 return true;
652 }
653 CapturedZoneInfo Info;
654 const ExtractionZone &ExtZone;
655 ZoneRelative CurrentLocation = ZoneRelative::Before;
656 // Number of {loop,switch} statements that are currently in the traversal
657 // stack inside Extraction Zone. Used to check for broken control flow.
658 unsigned CurNumberOfNestedLoops = 0;
659 unsigned CurNumberOfSwitch = 0;
660 };
661 ExtractionZoneVisitor Visitor(ExtZone);
662 CapturedZoneInfo Result = std::move(Visitor.Info);
663 Result.AlwaysReturns = alwaysReturns(EZ: ExtZone);
664 return Result;
665}
666
667// Adds parameters to ExtractedFunc.
668// Returns true if able to find the parameters successfully and no hoisting
669// needed.
670// FIXME: Check if the declaration has a local/anonymous type
671bool createParameters(NewFunction &ExtractedFunc,
672 const CapturedZoneInfo &CapturedInfo) {
673 for (const auto &KeyVal : CapturedInfo.DeclInfoMap) {
674 const auto &DeclInfo = KeyVal.second;
675 // If a Decl was Declared in zone and referenced in post zone, it
676 // needs to be hoisted (we bail out in that case).
677 // FIXME: Support Decl Hoisting.
678 if (DeclInfo.DeclaredIn == ZoneRelative::Inside &&
679 DeclInfo.IsReferencedInPostZone)
680 return false;
681 if (!DeclInfo.IsReferencedInZone)
682 continue; // no need to pass as parameter, not referenced
683 if (DeclInfo.DeclaredIn == ZoneRelative::Inside ||
684 DeclInfo.DeclaredIn == ZoneRelative::OutsideFunc)
685 continue; // no need to pass as parameter, still accessible.
686 // Parameter specific checks.
687 const ValueDecl *VD = dyn_cast_or_null<ValueDecl>(Val: DeclInfo.TheDecl);
688 // Can't parameterise if the Decl isn't a ValueDecl or is a FunctionDecl
689 // (this includes the case of recursive call to EnclosingFunc in Zone).
690 if (!VD || isa<FunctionDecl>(Val: DeclInfo.TheDecl))
691 return false;
692 // Parameter qualifiers are same as the Decl's qualifiers.
693 QualType TypeInfo = VD->getType().getNonReferenceType();
694 // FIXME: Need better qualifier checks: check mutated status for
695 // Decl(e.g. was it assigned, passed as nonconst argument, etc)
696 // FIXME: check if parameter will be a non l-value reference.
697 // FIXME: We don't want to always pass variables of types like int,
698 // pointers, etc by reference.
699 bool IsPassedByReference = true;
700 // We use the index of declaration as the ordering priority for parameters.
701 ExtractedFunc.Parameters.push_back({std::string(VD->getName()), TypeInfo,
702 IsPassedByReference,
703 DeclInfo.DeclIndex});
704 }
705 llvm::sort(C&: ExtractedFunc.Parameters);
706 return true;
707}
708
709// Clangd uses open ranges while ExtractionSemicolonPolicy (in Clang Tooling)
710// uses closed ranges. Generates the semicolon policy for the extraction and
711// extends the ZoneRange if necessary.
712tooling::ExtractionSemicolonPolicy
713getSemicolonPolicy(ExtractionZone &ExtZone, const SourceManager &SM,
714 const LangOptions &LangOpts) {
715 // Get closed ZoneRange.
716 SourceRange FuncBodyRange = {ExtZone.ZoneRange.getBegin(),
717 ExtZone.ZoneRange.getEnd().getLocWithOffset(Offset: -1)};
718 auto SemicolonPolicy = tooling::ExtractionSemicolonPolicy::compute(
719 S: ExtZone.getLastRootStmt()->ASTNode.get<Stmt>(), ExtractedRange&: FuncBodyRange, SM,
720 LangOpts);
721 // Update ZoneRange.
722 ExtZone.ZoneRange.setEnd(FuncBodyRange.getEnd().getLocWithOffset(Offset: 1));
723 return SemicolonPolicy;
724}
725
726// Generate return type for ExtractedFunc. Return false if unable to do so.
727bool generateReturnProperties(NewFunction &ExtractedFunc,
728 const FunctionDecl &EnclosingFunc,
729 const CapturedZoneInfo &CapturedInfo) {
730 // If the selected code always returns, we preserve those return statements.
731 // The return type should be the same as the enclosing function.
732 // (Others are possible if there are conversions, but this seems clearest).
733 if (CapturedInfo.HasReturnStmt) {
734 // If the return is conditional, neither replacing the code with
735 // `extracted()` nor `return extracted()` is correct.
736 if (!CapturedInfo.AlwaysReturns)
737 return false;
738 QualType Ret = EnclosingFunc.getReturnType();
739 // Once we support members, it'd be nice to support e.g. extracting a method
740 // of Foo<T> that returns T. But it's not clear when that's safe.
741 if (Ret->isDependentType())
742 return false;
743 ExtractedFunc.ReturnType = Ret;
744 return true;
745 }
746 // FIXME: Generate new return statement if needed.
747 ExtractedFunc.ReturnType = EnclosingFunc.getParentASTContext().VoidTy;
748 return true;
749}
750
751void captureMethodInfo(NewFunction &ExtractedFunc,
752 const CXXMethodDecl *Method) {
753 ExtractedFunc.Static = Method->isStatic();
754 ExtractedFunc.Const = Method->isConst();
755 ExtractedFunc.EnclosingClass = Method->getParent();
756}
757
758// FIXME: add support for adding other function return types besides void.
759// FIXME: assign the value returned by non void extracted function.
760llvm::Expected<NewFunction> getExtractedFunction(ExtractionZone &ExtZone,
761 const SourceManager &SM,
762 const LangOptions &LangOpts) {
763 CapturedZoneInfo CapturedInfo = captureZoneInfo(ExtZone);
764 // Bail out if any break of continue exists
765 if (CapturedInfo.BrokenControlFlow)
766 return error(Fmt: "Cannot extract break/continue without corresponding "
767 "loop/switch statement.");
768 NewFunction ExtractedFunc(getSemicolonPolicy(ExtZone, SM, LangOpts),
769 &LangOpts);
770
771 ExtractedFunc.SyntacticDC =
772 ExtZone.EnclosingFunction->getLexicalDeclContext();
773 ExtractedFunc.SemanticDC = ExtZone.EnclosingFunction->getDeclContext();
774 ExtractedFunc.DefinitionQualifier = ExtZone.EnclosingFunction->getQualifier();
775 ExtractedFunc.Constexpr = ExtZone.EnclosingFunction->getConstexprKind();
776
777 if (const auto *Method =
778 llvm::dyn_cast<CXXMethodDecl>(Val: ExtZone.EnclosingFunction))
779 captureMethodInfo(ExtractedFunc, Method);
780
781 if (ExtZone.EnclosingFunction->isOutOfLine()) {
782 // FIXME: Put the extracted method in a private section if it's a class or
783 // maybe in an anonymous namespace
784 const auto *FirstOriginalDecl =
785 ExtZone.EnclosingFunction->getCanonicalDecl();
786 auto DeclPos =
787 toHalfOpenFileRange(Mgr: SM, LangOpts, R: FirstOriginalDecl->getSourceRange());
788 if (!DeclPos)
789 return error(Fmt: "Declaration is inside a macro");
790 ExtractedFunc.ForwardDeclarationPoint = DeclPos->getBegin();
791 ExtractedFunc.ForwardDeclarationSyntacticDC = ExtractedFunc.SemanticDC;
792 }
793
794 ExtractedFunc.BodyRange = ExtZone.ZoneRange;
795 ExtractedFunc.DefinitionPoint = ExtZone.getInsertionPoint();
796
797 ExtractedFunc.CallerReturnsValue = CapturedInfo.AlwaysReturns;
798 if (!createParameters(ExtractedFunc, CapturedInfo) ||
799 !generateReturnProperties(ExtractedFunc, EnclosingFunc: *ExtZone.EnclosingFunction,
800 CapturedInfo))
801 return error(Fmt: "Too complex to extract.");
802 return ExtractedFunc;
803}
804
// The tweak subclass: extracts the selected statements to a new function.
class ExtractFunction : public Tweak {
public:
  const char *id() const final;
  // Analyzes the selection and decides whether the tweak is offered.
  bool prepare(const Selection &Inputs) override;
  // Performs the extraction, or returns an error if it cannot be done.
  Expected<Effect> apply(const Selection &Inputs) override;
  // Title under which the tweak is shown.
  std::string title() const override { return "Extract to function"; }
  llvm::StringLiteral kind() const override {
    return CodeAction::REFACTOR_KIND;
  }

private:
  // The zone being extracted.
  // NOTE(review): presumably filled in by prepare() and consumed by apply();
  // confirm against their definitions.
  ExtractionZone ExtZone;
};
818
// NOTE(review): presumably registers ExtractFunction with the tweak framework
// and provides its id(); the macro is defined outside this file - confirm.
REGISTER_TWEAK(ExtractFunction)
820tooling::Replacement replaceWithFuncCall(const NewFunction &ExtractedFunc,
821 const SourceManager &SM,
822 const LangOptions &LangOpts) {
823 std::string FuncCall = ExtractedFunc.renderCall();
824 return tooling::Replacement(
825 SM, CharSourceRange(ExtractedFunc.BodyRange, false), FuncCall, LangOpts);
826}
827
828tooling::Replacement createFunctionDefinition(const NewFunction &ExtractedFunc,
829 const SourceManager &SM) {
830 FunctionDeclKind DeclKind = InlineDefinition;
831 if (ExtractedFunc.ForwardDeclarationPoint)
832 DeclKind = OutOfLineDefinition;
833 std::string FunctionDef = ExtractedFunc.renderDeclaration(
834 K: DeclKind, SemanticDC: *ExtractedFunc.SemanticDC, SyntacticDC: *ExtractedFunc.SyntacticDC, SM);
835
836 return tooling::Replacement(SM, ExtractedFunc.DefinitionPoint, 0,
837 FunctionDef);
838}
839
840tooling::Replacement createForwardDeclaration(const NewFunction &ExtractedFunc,
841 const SourceManager &SM) {
842 std::string FunctionDecl = ExtractedFunc.renderDeclaration(
843 K: ForwardDeclaration, SemanticDC: *ExtractedFunc.SemanticDC,
844 SyntacticDC: *ExtractedFunc.ForwardDeclarationSyntacticDC, SM);
845 SourceLocation DeclPoint = *ExtractedFunc.ForwardDeclarationPoint;
846
847 return tooling::Replacement(SM, DeclPoint, 0, FunctionDecl);
848}
849
850// Returns true if ExtZone contains any ReturnStmts.
851bool hasReturnStmt(const ExtractionZone &ExtZone) {
852 class ReturnStmtVisitor
853 : public clang::RecursiveASTVisitor<ReturnStmtVisitor> {
854 public:
855 bool VisitReturnStmt(ReturnStmt *Return) {
856 Found = true;
857 return false; // We found the answer, abort the scan.
858 }
859 bool Found = false;
860 };
861
862 ReturnStmtVisitor V;
863 for (const Stmt *RootStmt : ExtZone.RootStmts) {
864 V.TraverseStmt(S: const_cast<Stmt *>(RootStmt));
865 if (V.Found)
866 break;
867 }
868 return V.Found;
869}
870
871bool ExtractFunction::prepare(const Selection &Inputs) {
872 const LangOptions &LangOpts = Inputs.AST->getLangOpts();
873 if (!LangOpts.CPlusPlus)
874 return false;
875 const Node *CommonAnc = Inputs.ASTSelection.commonAncestor();
876 const SourceManager &SM = Inputs.AST->getSourceManager();
877 auto MaybeExtZone = findExtractionZone(CommonAnc, SM, LangOpts);
878 if (!MaybeExtZone ||
879 (hasReturnStmt(ExtZone: *MaybeExtZone) && !alwaysReturns(EZ: *MaybeExtZone)))
880 return false;
881
882 // FIXME: Get rid of this check once we support hoisting.
883 if (MaybeExtZone->requiresHoisting(SM, Resolver: Inputs.AST->getHeuristicResolver()))
884 return false;
885
886 ExtZone = std::move(*MaybeExtZone);
887 return true;
888}
889
890Expected<Tweak::Effect> ExtractFunction::apply(const Selection &Inputs) {
891 const SourceManager &SM = Inputs.AST->getSourceManager();
892 const LangOptions &LangOpts = Inputs.AST->getLangOpts();
893 auto ExtractedFunc = getExtractedFunction(ExtZone, SM, LangOpts);
894 // FIXME: Add more types of errors.
895 if (!ExtractedFunc)
896 return ExtractedFunc.takeError();
897 tooling::Replacements Edit;
898 if (auto Err = Edit.add(R: createFunctionDefinition(ExtractedFunc: *ExtractedFunc, SM)))
899 return std::move(Err);
900 if (auto Err = Edit.add(R: replaceWithFuncCall(ExtractedFunc: *ExtractedFunc, SM, LangOpts)))
901 return std::move(Err);
902
903 if (auto FwdLoc = ExtractedFunc->ForwardDeclarationPoint) {
904 // If the fwd-declaration goes in the same file, merge into Replacements.
905 // Otherwise it needs to be a separate file edit.
906 if (SM.isWrittenInSameFile(Loc1: ExtractedFunc->DefinitionPoint, Loc2: *FwdLoc)) {
907 if (auto Err = Edit.add(R: createForwardDeclaration(ExtractedFunc: *ExtractedFunc, SM)))
908 return std::move(Err);
909 } else {
910 auto MultiFileEffect = Effect::mainFileEdit(SM, Replacements: std::move(Edit));
911 if (!MultiFileEffect)
912 return MultiFileEffect.takeError();
913
914 tooling::Replacements OtherEdit(
915 createForwardDeclaration(ExtractedFunc: *ExtractedFunc, SM));
916 if (auto PathAndEdit = Tweak::Effect::fileEdit(SM, FID: SM.getFileID(SpellingLoc: *FwdLoc),
917 Replacements: OtherEdit))
918 MultiFileEffect->ApplyEdits.try_emplace(Key: PathAndEdit->first,
919 Args&: PathAndEdit->second);
920 else
921 return PathAndEdit.takeError();
922 return MultiFileEffect;
923 }
924 }
925 return Effect::mainFileEdit(SM, Replacements: std::move(Edit));
926}
927
928} // namespace
929} // namespace clangd
930} // namespace clang
931

source code of clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp