1 | //===--- UseConstraintsCheck.cpp - clang-tidy -----------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "UseConstraintsCheck.h" |
10 | #include "clang/AST/ASTContext.h" |
11 | #include "clang/ASTMatchers/ASTMatchFinder.h" |
12 | #include "clang/Lex/Lexer.h" |
13 | |
14 | #include "../utils/LexerUtils.h" |
15 | |
16 | #include <optional> |
17 | #include <utility> |
18 | |
19 | using namespace clang::ast_matchers; |
20 | |
21 | namespace clang::tidy::modernize { |
22 | |
23 | struct EnableIfData { |
24 | TemplateSpecializationTypeLoc Loc; |
25 | TypeLoc Outer; |
26 | }; |
27 | |
28 | namespace { |
29 | AST_MATCHER(FunctionDecl, hasOtherDeclarations) { |
30 | auto It = Node.redecls_begin(); |
31 | auto EndIt = Node.redecls_end(); |
32 | |
33 | if (It == EndIt) |
34 | return false; |
35 | |
36 | ++It; |
37 | return It != EndIt; |
38 | } |
39 | } // namespace |
40 | |
41 | void UseConstraintsCheck::registerMatchers(MatchFinder *Finder) { |
42 | Finder->addMatcher( |
43 | NodeMatch: functionTemplateDecl( |
44 | has(functionDecl(unless(hasOtherDeclarations()), isDefinition(), |
45 | hasReturnTypeLoc(ReturnMatcher: typeLoc().bind(ID: "return" ))) |
46 | .bind(ID: "function" ))) |
47 | .bind(ID: "functionTemplate" ), |
48 | Action: this); |
49 | } |
50 | |
51 | static std::optional<TemplateSpecializationTypeLoc> |
52 | matchEnableIfSpecializationImplTypename(TypeLoc TheType) { |
53 | if (const auto Dep = TheType.getAs<DependentNameTypeLoc>()) { |
54 | const IdentifierInfo *Identifier = Dep.getTypePtr()->getIdentifier(); |
55 | if (!Identifier || Identifier->getName() != "type" || |
56 | Dep.getTypePtr()->getKeyword() != ElaboratedTypeKeyword::Typename) { |
57 | return std::nullopt; |
58 | } |
59 | TheType = Dep.getQualifierLoc().getTypeLoc(); |
60 | } |
61 | |
62 | if (const auto SpecializationLoc = |
63 | TheType.getAs<TemplateSpecializationTypeLoc>()) { |
64 | |
65 | const auto *Specialization = |
66 | dyn_cast<TemplateSpecializationType>(SpecializationLoc.getTypePtr()); |
67 | if (!Specialization) |
68 | return std::nullopt; |
69 | |
70 | const TemplateDecl *TD = |
71 | Specialization->getTemplateName().getAsTemplateDecl(); |
72 | if (!TD || TD->getName() != "enable_if" ) |
73 | return std::nullopt; |
74 | |
75 | int NumArgs = SpecializationLoc.getNumArgs(); |
76 | if (NumArgs != 1 && NumArgs != 2) |
77 | return std::nullopt; |
78 | |
79 | return SpecializationLoc; |
80 | } |
81 | return std::nullopt; |
82 | } |
83 | |
84 | static std::optional<TemplateSpecializationTypeLoc> |
85 | matchEnableIfSpecializationImplTrait(TypeLoc TheType) { |
86 | if (const auto Elaborated = TheType.getAs<ElaboratedTypeLoc>()) |
87 | TheType = Elaborated.getNamedTypeLoc(); |
88 | |
89 | if (const auto SpecializationLoc = |
90 | TheType.getAs<TemplateSpecializationTypeLoc>()) { |
91 | |
92 | const auto *Specialization = |
93 | dyn_cast<TemplateSpecializationType>(SpecializationLoc.getTypePtr()); |
94 | if (!Specialization) |
95 | return std::nullopt; |
96 | |
97 | const TemplateDecl *TD = |
98 | Specialization->getTemplateName().getAsTemplateDecl(); |
99 | if (!TD || TD->getName() != "enable_if_t" ) |
100 | return std::nullopt; |
101 | |
102 | if (!Specialization->isTypeAlias()) |
103 | return std::nullopt; |
104 | |
105 | if (const auto *AliasedType = |
106 | dyn_cast<DependentNameType>(Specialization->getAliasedType())) { |
107 | if (AliasedType->getIdentifier()->getName() != "type" || |
108 | AliasedType->getKeyword() != ElaboratedTypeKeyword::Typename) { |
109 | return std::nullopt; |
110 | } |
111 | } else { |
112 | return std::nullopt; |
113 | } |
114 | int NumArgs = SpecializationLoc.getNumArgs(); |
115 | if (NumArgs != 1 && NumArgs != 2) |
116 | return std::nullopt; |
117 | |
118 | return SpecializationLoc; |
119 | } |
120 | return std::nullopt; |
121 | } |
122 | |
123 | static std::optional<TemplateSpecializationTypeLoc> |
124 | matchEnableIfSpecializationImpl(TypeLoc TheType) { |
125 | if (auto EnableIf = matchEnableIfSpecializationImplTypename(TheType)) |
126 | return EnableIf; |
127 | return matchEnableIfSpecializationImplTrait(TheType); |
128 | } |
129 | |
130 | static std::optional<EnableIfData> |
131 | matchEnableIfSpecialization(TypeLoc TheType) { |
132 | if (const auto Pointer = TheType.getAs<PointerTypeLoc>()) |
133 | TheType = Pointer.getPointeeLoc(); |
134 | else if (const auto Reference = TheType.getAs<ReferenceTypeLoc>()) |
135 | TheType = Reference.getPointeeLoc(); |
136 | if (const auto Qualified = TheType.getAs<QualifiedTypeLoc>()) |
137 | TheType = Qualified.getUnqualifiedLoc(); |
138 | |
139 | if (auto EnableIf = matchEnableIfSpecializationImpl(TheType)) |
140 | return EnableIfData{std::move(*EnableIf), TheType}; |
141 | return std::nullopt; |
142 | } |
143 | |
144 | static std::pair<std::optional<EnableIfData>, const Decl *> |
145 | matchTrailingTemplateParam(const FunctionTemplateDecl *FunctionTemplate) { |
146 | // For non-type trailing param, match very specifically |
147 | // 'template <..., enable_if_type<Condition, Type> = Default>' where |
148 | // enable_if_type is 'enable_if' or 'enable_if_t'. E.g., 'template <typename |
149 | // T, enable_if_t<is_same_v<T, bool>, int*> = nullptr> |
150 | // |
151 | // Otherwise, match a trailing default type arg. |
152 | // E.g., 'template <typename T, typename = enable_if_t<is_same_v<T, bool>>>' |
153 | |
154 | const TemplateParameterList *TemplateParams = |
155 | FunctionTemplate->getTemplateParameters(); |
156 | if (TemplateParams->size() == 0) |
157 | return {}; |
158 | |
159 | const NamedDecl *LastParam = |
160 | TemplateParams->getParam(Idx: TemplateParams->size() - 1); |
161 | if (const auto *LastTemplateParam = |
162 | dyn_cast<NonTypeTemplateParmDecl>(LastParam)) { |
163 | |
164 | if (!LastTemplateParam->hasDefaultArgument() || |
165 | !LastTemplateParam->getName().empty()) |
166 | return {}; |
167 | |
168 | return {matchEnableIfSpecialization( |
169 | LastTemplateParam->getTypeSourceInfo()->getTypeLoc()), |
170 | LastTemplateParam}; |
171 | } |
172 | if (const auto *LastTemplateParam = |
173 | dyn_cast<TemplateTypeParmDecl>(LastParam)) { |
174 | if (LastTemplateParam->hasDefaultArgument() && |
175 | LastTemplateParam->getIdentifier() == nullptr) { |
176 | return {matchEnableIfSpecialization( |
177 | LastTemplateParam->getDefaultArgumentInfo()->getTypeLoc()), |
178 | LastTemplateParam}; |
179 | } |
180 | } |
181 | return {}; |
182 | } |
183 | |
184 | template <typename T> |
185 | static SourceLocation getRAngleFileLoc(const SourceManager &SM, |
186 | const T &Element) { |
187 | // getFileLoc handles the case where the RAngle loc is part of a synthesized |
188 | // '>>', which ends up allocating a 'scratch space' buffer in the source |
189 | // manager. |
190 | return SM.getFileLoc(Loc: Element.getRAngleLoc()); |
191 | } |
192 | |
193 | static SourceRange |
194 | getConditionRange(ASTContext &Context, |
195 | const TemplateSpecializationTypeLoc &EnableIf) { |
196 | // TemplateArgumentLoc's SourceRange End is the location of the last token |
197 | // (per UnqualifiedId docs). E.g., in `enable_if<AAA && BBB>`, the End |
198 | // location will be the first 'B' in 'BBB'. |
199 | const LangOptions &LangOpts = Context.getLangOpts(); |
200 | const SourceManager &SM = Context.getSourceManager(); |
201 | if (EnableIf.getNumArgs() > 1) { |
202 | TemplateArgumentLoc NextArg = EnableIf.getArgLoc(i: 1); |
203 | return {EnableIf.getLAngleLoc().getLocWithOffset(Offset: 1), |
204 | utils::lexer::findPreviousTokenKind( |
205 | Start: NextArg.getSourceRange().getBegin(), SM, LangOpts, TK: tok::comma)}; |
206 | } |
207 | |
208 | return {EnableIf.getLAngleLoc().getLocWithOffset(Offset: 1), |
209 | getRAngleFileLoc(SM, Element: EnableIf)}; |
210 | } |
211 | |
212 | static SourceRange getTypeRange(ASTContext &Context, |
213 | const TemplateSpecializationTypeLoc &EnableIf) { |
214 | TemplateArgumentLoc Arg = EnableIf.getArgLoc(i: 1); |
215 | const LangOptions &LangOpts = Context.getLangOpts(); |
216 | const SourceManager &SM = Context.getSourceManager(); |
217 | return {utils::lexer::findPreviousTokenKind(Start: Arg.getSourceRange().getBegin(), |
218 | SM, LangOpts, TK: tok::comma) |
219 | .getLocWithOffset(Offset: 1), |
220 | getRAngleFileLoc(SM, Element: EnableIf)}; |
221 | } |
222 | |
223 | // Returns the original source text of the second argument of a call to |
224 | // enable_if_t. E.g., in enable_if_t<Condition, TheType>, this function |
225 | // returns 'TheType'. |
226 | static std::optional<StringRef> |
227 | getTypeText(ASTContext &Context, |
228 | const TemplateSpecializationTypeLoc &EnableIf) { |
229 | if (EnableIf.getNumArgs() > 1) { |
230 | const LangOptions &LangOpts = Context.getLangOpts(); |
231 | const SourceManager &SM = Context.getSourceManager(); |
232 | bool Invalid = false; |
233 | StringRef Text = Lexer::getSourceText(Range: CharSourceRange::getCharRange( |
234 | R: getTypeRange(Context, EnableIf)), |
235 | SM, LangOpts, Invalid: &Invalid) |
236 | .trim(); |
237 | if (Invalid) |
238 | return std::nullopt; |
239 | |
240 | return Text; |
241 | } |
242 | |
243 | return "void" ; |
244 | } |
245 | |
246 | static std::optional<SourceLocation> |
247 | findInsertionForConstraint(const FunctionDecl *Function, ASTContext &Context) { |
248 | SourceManager &SM = Context.getSourceManager(); |
249 | const LangOptions &LangOpts = Context.getLangOpts(); |
250 | |
251 | if (const auto *Constructor = dyn_cast<CXXConstructorDecl>(Val: Function)) { |
252 | for (const CXXCtorInitializer *Init : Constructor->inits()) { |
253 | if (Init->getSourceOrder() == 0) |
254 | return utils::lexer::findPreviousTokenKind(Start: Init->getSourceLocation(), |
255 | SM, LangOpts, TK: tok::colon); |
256 | } |
257 | if (Constructor->init_begin() != Constructor->init_end()) |
258 | return std::nullopt; |
259 | } |
260 | if (Function->isDeleted()) { |
261 | SourceLocation FunctionEnd = Function->getSourceRange().getEnd(); |
262 | return utils::lexer::findNextAnyTokenKind(Start: FunctionEnd, SM, LangOpts, |
263 | TK: tok::equal, TKs: tok::equal); |
264 | } |
265 | const Stmt *Body = Function->getBody(); |
266 | if (!Body) |
267 | return std::nullopt; |
268 | |
269 | return Body->getBeginLoc(); |
270 | } |
271 | |
272 | bool isPrimaryExpression(const Expr *Expression) { |
273 | // This function is an incomplete approximation of checking whether |
274 | // an Expr is a primary expression. In particular, if this function |
275 | // returns true, the expression is a primary expression. The converse |
276 | // is not necessarily true. |
277 | |
278 | if (const auto *Cast = dyn_cast<ImplicitCastExpr>(Val: Expression)) |
279 | Expression = Cast->getSubExprAsWritten(); |
280 | if (isa<ParenExpr, DependentScopeDeclRefExpr>(Val: Expression)) |
281 | return true; |
282 | |
283 | return false; |
284 | } |
285 | |
286 | // Return the original source text of an enable_if_t condition, i.e., the |
287 | // first template argument). For example, in |
288 | // 'enable_if_t<FirstCondition || SecondCondition, AType>', the text |
289 | // the text 'FirstCondition || SecondCondition' is returned. |
290 | static std::optional<std::string> getConditionText(const Expr *ConditionExpr, |
291 | SourceRange ConditionRange, |
292 | ASTContext &Context) { |
293 | SourceManager &SM = Context.getSourceManager(); |
294 | const LangOptions &LangOpts = Context.getLangOpts(); |
295 | |
296 | SourceLocation PrevTokenLoc = ConditionRange.getEnd(); |
297 | if (PrevTokenLoc.isInvalid()) |
298 | return std::nullopt; |
299 | |
300 | const bool = false; |
301 | Token PrevToken; |
302 | std::tie(args&: PrevToken, args&: PrevTokenLoc) = utils::lexer::getPreviousTokenAndStart( |
303 | Location: PrevTokenLoc, SM, LangOpts, SkipComments); |
304 | bool EndsWithDoubleSlash = |
305 | PrevToken.is(K: tok::comment) && |
306 | Lexer::getSourceText(Range: CharSourceRange::getCharRange( |
307 | B: PrevTokenLoc, E: PrevTokenLoc.getLocWithOffset(Offset: 2)), |
308 | SM, LangOpts) == "//" ; |
309 | |
310 | bool Invalid = false; |
311 | llvm::StringRef ConditionText = Lexer::getSourceText( |
312 | Range: CharSourceRange::getCharRange(R: ConditionRange), SM, LangOpts, Invalid: &Invalid); |
313 | if (Invalid) |
314 | return std::nullopt; |
315 | |
316 | auto AddParens = [&](llvm::StringRef Text) -> std::string { |
317 | if (isPrimaryExpression(Expression: ConditionExpr)) |
318 | return Text.str(); |
319 | return "(" + Text.str() + ")" ; |
320 | }; |
321 | |
322 | if (EndsWithDoubleSlash) |
323 | return AddParens(ConditionText); |
324 | return AddParens(ConditionText.trim()); |
325 | } |
326 | |
327 | // Handle functions that return enable_if_t, e.g., |
328 | // template <...> |
329 | // enable_if_t<Condition, ReturnType> function(); |
330 | // |
331 | // Return a vector of FixItHints if the code can be replaced with |
332 | // a C++20 requires clause. In the example above, returns FixItHints |
333 | // to result in |
334 | // template <...> |
335 | // ReturnType function() requires Condition {} |
336 | static std::vector<FixItHint> handleReturnType(const FunctionDecl *Function, |
337 | const TypeLoc &ReturnType, |
338 | const EnableIfData &EnableIf, |
339 | ASTContext &Context) { |
340 | TemplateArgumentLoc EnableCondition = EnableIf.Loc.getArgLoc(0); |
341 | |
342 | SourceRange ConditionRange = getConditionRange(Context, EnableIf.Loc); |
343 | |
344 | std::optional<std::string> ConditionText = getConditionText( |
345 | ConditionExpr: EnableCondition.getSourceExpression(), ConditionRange, Context); |
346 | if (!ConditionText) |
347 | return {}; |
348 | |
349 | std::optional<StringRef> TypeText = getTypeText(Context, EnableIf.Loc); |
350 | if (!TypeText) |
351 | return {}; |
352 | |
353 | SmallVector<const Expr *, 3> ExistingConstraints; |
354 | Function->getAssociatedConstraints(AC&: ExistingConstraints); |
355 | if (!ExistingConstraints.empty()) { |
356 | // FIXME - Support adding new constraints to existing ones. Do we need to |
357 | // consider subsumption? |
358 | return {}; |
359 | } |
360 | |
361 | std::optional<SourceLocation> ConstraintInsertionLoc = |
362 | findInsertionForConstraint(Function, Context); |
363 | if (!ConstraintInsertionLoc) |
364 | return {}; |
365 | |
366 | std::vector<FixItHint> FixIts; |
367 | FixIts.push_back(x: FixItHint::CreateReplacement( |
368 | RemoveRange: CharSourceRange::getTokenRange(R: EnableIf.Outer.getSourceRange()), |
369 | Code: *TypeText)); |
370 | FixIts.push_back(x: FixItHint::CreateInsertion( |
371 | InsertionLoc: *ConstraintInsertionLoc, Code: "requires " + *ConditionText + " " )); |
372 | return FixIts; |
373 | } |
374 | |
375 | // Handle enable_if_t in a trailing template parameter, e.g., |
376 | // template <..., enable_if_t<Condition, Type> = Type{}> |
377 | // ReturnType function(); |
378 | // |
379 | // Return a vector of FixItHints if the code can be replaced with |
380 | // a C++20 requires clause. In the example above, returns FixItHints |
381 | // to result in |
382 | // template <...> |
383 | // ReturnType function() requires Condition {} |
384 | static std::vector<FixItHint> |
385 | handleTrailingTemplateType(const FunctionTemplateDecl *FunctionTemplate, |
386 | const FunctionDecl *Function, |
387 | const Decl *LastTemplateParam, |
388 | const EnableIfData &EnableIf, ASTContext &Context) { |
389 | SourceManager &SM = Context.getSourceManager(); |
390 | const LangOptions &LangOpts = Context.getLangOpts(); |
391 | |
392 | TemplateArgumentLoc EnableCondition = EnableIf.Loc.getArgLoc(0); |
393 | |
394 | SourceRange ConditionRange = getConditionRange(Context, EnableIf.Loc); |
395 | |
396 | std::optional<std::string> ConditionText = getConditionText( |
397 | ConditionExpr: EnableCondition.getSourceExpression(), ConditionRange, Context); |
398 | if (!ConditionText) |
399 | return {}; |
400 | |
401 | SmallVector<const Expr *, 3> ExistingConstraints; |
402 | Function->getAssociatedConstraints(AC&: ExistingConstraints); |
403 | if (!ExistingConstraints.empty()) { |
404 | // FIXME - Support adding new constraints to existing ones. Do we need to |
405 | // consider subsumption? |
406 | return {}; |
407 | } |
408 | |
409 | SourceRange RemovalRange; |
410 | const TemplateParameterList *TemplateParams = |
411 | FunctionTemplate->getTemplateParameters(); |
412 | if (!TemplateParams || TemplateParams->size() == 0) |
413 | return {}; |
414 | |
415 | if (TemplateParams->size() == 1) { |
416 | RemovalRange = |
417 | SourceRange(TemplateParams->getTemplateLoc(), |
418 | getRAngleFileLoc(SM, Element: *TemplateParams).getLocWithOffset(Offset: 1)); |
419 | } else { |
420 | RemovalRange = |
421 | SourceRange(utils::lexer::findPreviousTokenKind( |
422 | Start: LastTemplateParam->getSourceRange().getBegin(), SM, |
423 | LangOpts, TK: tok::comma), |
424 | getRAngleFileLoc(SM, Element: *TemplateParams)); |
425 | } |
426 | |
427 | std::optional<SourceLocation> ConstraintInsertionLoc = |
428 | findInsertionForConstraint(Function, Context); |
429 | if (!ConstraintInsertionLoc) |
430 | return {}; |
431 | |
432 | std::vector<FixItHint> FixIts; |
433 | FixIts.push_back( |
434 | x: FixItHint::CreateRemoval(RemoveRange: CharSourceRange::getCharRange(R: RemovalRange))); |
435 | FixIts.push_back(x: FixItHint::CreateInsertion( |
436 | InsertionLoc: *ConstraintInsertionLoc, Code: "requires " + *ConditionText + " " )); |
437 | return FixIts; |
438 | } |
439 | |
440 | void UseConstraintsCheck::check(const MatchFinder::MatchResult &Result) { |
441 | const auto *FunctionTemplate = |
442 | Result.Nodes.getNodeAs<FunctionTemplateDecl>(ID: "functionTemplate" ); |
443 | const auto *Function = Result.Nodes.getNodeAs<FunctionDecl>(ID: "function" ); |
444 | const auto *ReturnType = Result.Nodes.getNodeAs<TypeLoc>(ID: "return" ); |
445 | if (!FunctionTemplate || !Function || !ReturnType) |
446 | return; |
447 | |
448 | // Check for |
449 | // |
450 | // Case 1. Return type of function |
451 | // |
452 | // template <...> |
453 | // enable_if_t<Condition, ReturnType>::type function() {} |
454 | // |
455 | // Case 2. Trailing template parameter |
456 | // |
457 | // template <..., enable_if_t<Condition, Type> = Type{}> |
458 | // ReturnType function() {} |
459 | // |
460 | // or |
461 | // |
462 | // template <..., typename = enable_if_t<Condition, void>> |
463 | // ReturnType function() {} |
464 | // |
465 | |
466 | // Case 1. Return type of function |
467 | if (auto EnableIf = matchEnableIfSpecialization(*ReturnType)) { |
468 | diag(Loc: ReturnType->getBeginLoc(), |
469 | Description: "use C++20 requires constraints instead of enable_if" ) |
470 | << handleReturnType(Function, *ReturnType, *EnableIf, *Result.Context); |
471 | return; |
472 | } |
473 | |
474 | // Case 2. Trailing template parameter |
475 | if (auto [EnableIf, LastTemplateParam] = |
476 | matchTrailingTemplateParam(FunctionTemplate); |
477 | EnableIf && LastTemplateParam) { |
478 | diag(LastTemplateParam->getSourceRange().getBegin(), |
479 | "use C++20 requires constraints instead of enable_if" ) |
480 | << handleTrailingTemplateType(FunctionTemplate, Function, |
481 | LastTemplateParam, *EnableIf, |
482 | *Result.Context); |
483 | return; |
484 | } |
485 | } |
486 | |
487 | } // namespace clang::tidy::modernize |
488 | |