1//===--- UseConstraintsCheck.cpp - clang-tidy -----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "UseConstraintsCheck.h"
10#include "clang/AST/ASTContext.h"
11#include "clang/ASTMatchers/ASTMatchFinder.h"
12#include "clang/Lex/Lexer.h"
13
14#include "../utils/LexerUtils.h"
15
16#include <optional>
17#include <utility>
18
19using namespace clang::ast_matchers;
20
21namespace clang::tidy::modernize {
22
23struct EnableIfData {
24 TemplateSpecializationTypeLoc Loc;
25 TypeLoc Outer;
26};
27
28namespace {
29AST_MATCHER(FunctionDecl, hasOtherDeclarations) {
30 auto It = Node.redecls_begin();
31 auto EndIt = Node.redecls_end();
32
33 if (It == EndIt)
34 return false;
35
36 ++It;
37 return It != EndIt;
38}
39} // namespace
40
41void UseConstraintsCheck::registerMatchers(MatchFinder *Finder) {
42 Finder->addMatcher(
43 NodeMatch: functionTemplateDecl(
44 has(functionDecl(unless(hasOtherDeclarations()), isDefinition(),
45 hasReturnTypeLoc(ReturnMatcher: typeLoc().bind(ID: "return")))
46 .bind(ID: "function")))
47 .bind(ID: "functionTemplate"),
48 Action: this);
49}
50
51static std::optional<TemplateSpecializationTypeLoc>
52matchEnableIfSpecializationImplTypename(TypeLoc TheType) {
53 if (const auto Dep = TheType.getAs<DependentNameTypeLoc>()) {
54 const IdentifierInfo *Identifier = Dep.getTypePtr()->getIdentifier();
55 if (!Identifier || Identifier->getName() != "type" ||
56 Dep.getTypePtr()->getKeyword() != ElaboratedTypeKeyword::Typename) {
57 return std::nullopt;
58 }
59 TheType = Dep.getQualifierLoc().getTypeLoc();
60 }
61
62 if (const auto SpecializationLoc =
63 TheType.getAs<TemplateSpecializationTypeLoc>()) {
64
65 const auto *Specialization =
66 dyn_cast<TemplateSpecializationType>(SpecializationLoc.getTypePtr());
67 if (!Specialization)
68 return std::nullopt;
69
70 const TemplateDecl *TD =
71 Specialization->getTemplateName().getAsTemplateDecl();
72 if (!TD || TD->getName() != "enable_if")
73 return std::nullopt;
74
75 int NumArgs = SpecializationLoc.getNumArgs();
76 if (NumArgs != 1 && NumArgs != 2)
77 return std::nullopt;
78
79 return SpecializationLoc;
80 }
81 return std::nullopt;
82}
83
84static std::optional<TemplateSpecializationTypeLoc>
85matchEnableIfSpecializationImplTrait(TypeLoc TheType) {
86 if (const auto Elaborated = TheType.getAs<ElaboratedTypeLoc>())
87 TheType = Elaborated.getNamedTypeLoc();
88
89 if (const auto SpecializationLoc =
90 TheType.getAs<TemplateSpecializationTypeLoc>()) {
91
92 const auto *Specialization =
93 dyn_cast<TemplateSpecializationType>(SpecializationLoc.getTypePtr());
94 if (!Specialization)
95 return std::nullopt;
96
97 const TemplateDecl *TD =
98 Specialization->getTemplateName().getAsTemplateDecl();
99 if (!TD || TD->getName() != "enable_if_t")
100 return std::nullopt;
101
102 if (!Specialization->isTypeAlias())
103 return std::nullopt;
104
105 if (const auto *AliasedType =
106 dyn_cast<DependentNameType>(Specialization->getAliasedType())) {
107 if (AliasedType->getIdentifier()->getName() != "type" ||
108 AliasedType->getKeyword() != ElaboratedTypeKeyword::Typename) {
109 return std::nullopt;
110 }
111 } else {
112 return std::nullopt;
113 }
114 int NumArgs = SpecializationLoc.getNumArgs();
115 if (NumArgs != 1 && NumArgs != 2)
116 return std::nullopt;
117
118 return SpecializationLoc;
119 }
120 return std::nullopt;
121}
122
123static std::optional<TemplateSpecializationTypeLoc>
124matchEnableIfSpecializationImpl(TypeLoc TheType) {
125 if (auto EnableIf = matchEnableIfSpecializationImplTypename(TheType))
126 return EnableIf;
127 return matchEnableIfSpecializationImplTrait(TheType);
128}
129
130static std::optional<EnableIfData>
131matchEnableIfSpecialization(TypeLoc TheType) {
132 if (const auto Pointer = TheType.getAs<PointerTypeLoc>())
133 TheType = Pointer.getPointeeLoc();
134 else if (const auto Reference = TheType.getAs<ReferenceTypeLoc>())
135 TheType = Reference.getPointeeLoc();
136 if (const auto Qualified = TheType.getAs<QualifiedTypeLoc>())
137 TheType = Qualified.getUnqualifiedLoc();
138
139 if (auto EnableIf = matchEnableIfSpecializationImpl(TheType))
140 return EnableIfData{std::move(*EnableIf), TheType};
141 return std::nullopt;
142}
143
144static std::pair<std::optional<EnableIfData>, const Decl *>
145matchTrailingTemplateParam(const FunctionTemplateDecl *FunctionTemplate) {
146 // For non-type trailing param, match very specifically
147 // 'template <..., enable_if_type<Condition, Type> = Default>' where
148 // enable_if_type is 'enable_if' or 'enable_if_t'. E.g., 'template <typename
149 // T, enable_if_t<is_same_v<T, bool>, int*> = nullptr>
150 //
151 // Otherwise, match a trailing default type arg.
152 // E.g., 'template <typename T, typename = enable_if_t<is_same_v<T, bool>>>'
153
154 const TemplateParameterList *TemplateParams =
155 FunctionTemplate->getTemplateParameters();
156 if (TemplateParams->size() == 0)
157 return {};
158
159 const NamedDecl *LastParam =
160 TemplateParams->getParam(Idx: TemplateParams->size() - 1);
161 if (const auto *LastTemplateParam =
162 dyn_cast<NonTypeTemplateParmDecl>(LastParam)) {
163
164 if (!LastTemplateParam->hasDefaultArgument() ||
165 !LastTemplateParam->getName().empty())
166 return {};
167
168 return {matchEnableIfSpecialization(
169 LastTemplateParam->getTypeSourceInfo()->getTypeLoc()),
170 LastTemplateParam};
171 }
172 if (const auto *LastTemplateParam =
173 dyn_cast<TemplateTypeParmDecl>(LastParam)) {
174 if (LastTemplateParam->hasDefaultArgument() &&
175 LastTemplateParam->getIdentifier() == nullptr) {
176 return {matchEnableIfSpecialization(
177 LastTemplateParam->getDefaultArgumentInfo()->getTypeLoc()),
178 LastTemplateParam};
179 }
180 }
181 return {};
182}
183
184template <typename T>
185static SourceLocation getRAngleFileLoc(const SourceManager &SM,
186 const T &Element) {
187 // getFileLoc handles the case where the RAngle loc is part of a synthesized
188 // '>>', which ends up allocating a 'scratch space' buffer in the source
189 // manager.
190 return SM.getFileLoc(Loc: Element.getRAngleLoc());
191}
192
193static SourceRange
194getConditionRange(ASTContext &Context,
195 const TemplateSpecializationTypeLoc &EnableIf) {
196 // TemplateArgumentLoc's SourceRange End is the location of the last token
197 // (per UnqualifiedId docs). E.g., in `enable_if<AAA && BBB>`, the End
198 // location will be the first 'B' in 'BBB'.
199 const LangOptions &LangOpts = Context.getLangOpts();
200 const SourceManager &SM = Context.getSourceManager();
201 if (EnableIf.getNumArgs() > 1) {
202 TemplateArgumentLoc NextArg = EnableIf.getArgLoc(i: 1);
203 return {EnableIf.getLAngleLoc().getLocWithOffset(Offset: 1),
204 utils::lexer::findPreviousTokenKind(
205 Start: NextArg.getSourceRange().getBegin(), SM, LangOpts, TK: tok::comma)};
206 }
207
208 return {EnableIf.getLAngleLoc().getLocWithOffset(Offset: 1),
209 getRAngleFileLoc(SM, Element: EnableIf)};
210}
211
212static SourceRange getTypeRange(ASTContext &Context,
213 const TemplateSpecializationTypeLoc &EnableIf) {
214 TemplateArgumentLoc Arg = EnableIf.getArgLoc(i: 1);
215 const LangOptions &LangOpts = Context.getLangOpts();
216 const SourceManager &SM = Context.getSourceManager();
217 return {utils::lexer::findPreviousTokenKind(Start: Arg.getSourceRange().getBegin(),
218 SM, LangOpts, TK: tok::comma)
219 .getLocWithOffset(Offset: 1),
220 getRAngleFileLoc(SM, Element: EnableIf)};
221}
222
223// Returns the original source text of the second argument of a call to
224// enable_if_t. E.g., in enable_if_t<Condition, TheType>, this function
225// returns 'TheType'.
226static std::optional<StringRef>
227getTypeText(ASTContext &Context,
228 const TemplateSpecializationTypeLoc &EnableIf) {
229 if (EnableIf.getNumArgs() > 1) {
230 const LangOptions &LangOpts = Context.getLangOpts();
231 const SourceManager &SM = Context.getSourceManager();
232 bool Invalid = false;
233 StringRef Text = Lexer::getSourceText(Range: CharSourceRange::getCharRange(
234 R: getTypeRange(Context, EnableIf)),
235 SM, LangOpts, Invalid: &Invalid)
236 .trim();
237 if (Invalid)
238 return std::nullopt;
239
240 return Text;
241 }
242
243 return "void";
244}
245
246static std::optional<SourceLocation>
247findInsertionForConstraint(const FunctionDecl *Function, ASTContext &Context) {
248 SourceManager &SM = Context.getSourceManager();
249 const LangOptions &LangOpts = Context.getLangOpts();
250
251 if (const auto *Constructor = dyn_cast<CXXConstructorDecl>(Val: Function)) {
252 for (const CXXCtorInitializer *Init : Constructor->inits()) {
253 if (Init->getSourceOrder() == 0)
254 return utils::lexer::findPreviousTokenKind(Start: Init->getSourceLocation(),
255 SM, LangOpts, TK: tok::colon);
256 }
257 if (Constructor->init_begin() != Constructor->init_end())
258 return std::nullopt;
259 }
260 if (Function->isDeleted()) {
261 SourceLocation FunctionEnd = Function->getSourceRange().getEnd();
262 return utils::lexer::findNextAnyTokenKind(Start: FunctionEnd, SM, LangOpts,
263 TK: tok::equal, TKs: tok::equal);
264 }
265 const Stmt *Body = Function->getBody();
266 if (!Body)
267 return std::nullopt;
268
269 return Body->getBeginLoc();
270}
271
272bool isPrimaryExpression(const Expr *Expression) {
273 // This function is an incomplete approximation of checking whether
274 // an Expr is a primary expression. In particular, if this function
275 // returns true, the expression is a primary expression. The converse
276 // is not necessarily true.
277
278 if (const auto *Cast = dyn_cast<ImplicitCastExpr>(Val: Expression))
279 Expression = Cast->getSubExprAsWritten();
280 if (isa<ParenExpr, DependentScopeDeclRefExpr>(Val: Expression))
281 return true;
282
283 return false;
284}
285
286// Return the original source text of an enable_if_t condition, i.e., the
287// first template argument). For example, in
288// 'enable_if_t<FirstCondition || SecondCondition, AType>', the text
289// the text 'FirstCondition || SecondCondition' is returned.
290static std::optional<std::string> getConditionText(const Expr *ConditionExpr,
291 SourceRange ConditionRange,
292 ASTContext &Context) {
293 SourceManager &SM = Context.getSourceManager();
294 const LangOptions &LangOpts = Context.getLangOpts();
295
296 SourceLocation PrevTokenLoc = ConditionRange.getEnd();
297 if (PrevTokenLoc.isInvalid())
298 return std::nullopt;
299
300 const bool SkipComments = false;
301 Token PrevToken;
302 std::tie(args&: PrevToken, args&: PrevTokenLoc) = utils::lexer::getPreviousTokenAndStart(
303 Location: PrevTokenLoc, SM, LangOpts, SkipComments);
304 bool EndsWithDoubleSlash =
305 PrevToken.is(K: tok::comment) &&
306 Lexer::getSourceText(Range: CharSourceRange::getCharRange(
307 B: PrevTokenLoc, E: PrevTokenLoc.getLocWithOffset(Offset: 2)),
308 SM, LangOpts) == "//";
309
310 bool Invalid = false;
311 llvm::StringRef ConditionText = Lexer::getSourceText(
312 Range: CharSourceRange::getCharRange(R: ConditionRange), SM, LangOpts, Invalid: &Invalid);
313 if (Invalid)
314 return std::nullopt;
315
316 auto AddParens = [&](llvm::StringRef Text) -> std::string {
317 if (isPrimaryExpression(Expression: ConditionExpr))
318 return Text.str();
319 return "(" + Text.str() + ")";
320 };
321
322 if (EndsWithDoubleSlash)
323 return AddParens(ConditionText);
324 return AddParens(ConditionText.trim());
325}
326
327// Handle functions that return enable_if_t, e.g.,
328// template <...>
329// enable_if_t<Condition, ReturnType> function();
330//
331// Return a vector of FixItHints if the code can be replaced with
332// a C++20 requires clause. In the example above, returns FixItHints
333// to result in
334// template <...>
335// ReturnType function() requires Condition {}
336static std::vector<FixItHint> handleReturnType(const FunctionDecl *Function,
337 const TypeLoc &ReturnType,
338 const EnableIfData &EnableIf,
339 ASTContext &Context) {
340 TemplateArgumentLoc EnableCondition = EnableIf.Loc.getArgLoc(0);
341
342 SourceRange ConditionRange = getConditionRange(Context, EnableIf.Loc);
343
344 std::optional<std::string> ConditionText = getConditionText(
345 ConditionExpr: EnableCondition.getSourceExpression(), ConditionRange, Context);
346 if (!ConditionText)
347 return {};
348
349 std::optional<StringRef> TypeText = getTypeText(Context, EnableIf.Loc);
350 if (!TypeText)
351 return {};
352
353 SmallVector<const Expr *, 3> ExistingConstraints;
354 Function->getAssociatedConstraints(AC&: ExistingConstraints);
355 if (!ExistingConstraints.empty()) {
356 // FIXME - Support adding new constraints to existing ones. Do we need to
357 // consider subsumption?
358 return {};
359 }
360
361 std::optional<SourceLocation> ConstraintInsertionLoc =
362 findInsertionForConstraint(Function, Context);
363 if (!ConstraintInsertionLoc)
364 return {};
365
366 std::vector<FixItHint> FixIts;
367 FixIts.push_back(x: FixItHint::CreateReplacement(
368 RemoveRange: CharSourceRange::getTokenRange(R: EnableIf.Outer.getSourceRange()),
369 Code: *TypeText));
370 FixIts.push_back(x: FixItHint::CreateInsertion(
371 InsertionLoc: *ConstraintInsertionLoc, Code: "requires " + *ConditionText + " "));
372 return FixIts;
373}
374
375// Handle enable_if_t in a trailing template parameter, e.g.,
376// template <..., enable_if_t<Condition, Type> = Type{}>
377// ReturnType function();
378//
379// Return a vector of FixItHints if the code can be replaced with
380// a C++20 requires clause. In the example above, returns FixItHints
381// to result in
382// template <...>
383// ReturnType function() requires Condition {}
384static std::vector<FixItHint>
385handleTrailingTemplateType(const FunctionTemplateDecl *FunctionTemplate,
386 const FunctionDecl *Function,
387 const Decl *LastTemplateParam,
388 const EnableIfData &EnableIf, ASTContext &Context) {
389 SourceManager &SM = Context.getSourceManager();
390 const LangOptions &LangOpts = Context.getLangOpts();
391
392 TemplateArgumentLoc EnableCondition = EnableIf.Loc.getArgLoc(0);
393
394 SourceRange ConditionRange = getConditionRange(Context, EnableIf.Loc);
395
396 std::optional<std::string> ConditionText = getConditionText(
397 ConditionExpr: EnableCondition.getSourceExpression(), ConditionRange, Context);
398 if (!ConditionText)
399 return {};
400
401 SmallVector<const Expr *, 3> ExistingConstraints;
402 Function->getAssociatedConstraints(AC&: ExistingConstraints);
403 if (!ExistingConstraints.empty()) {
404 // FIXME - Support adding new constraints to existing ones. Do we need to
405 // consider subsumption?
406 return {};
407 }
408
409 SourceRange RemovalRange;
410 const TemplateParameterList *TemplateParams =
411 FunctionTemplate->getTemplateParameters();
412 if (!TemplateParams || TemplateParams->size() == 0)
413 return {};
414
415 if (TemplateParams->size() == 1) {
416 RemovalRange =
417 SourceRange(TemplateParams->getTemplateLoc(),
418 getRAngleFileLoc(SM, Element: *TemplateParams).getLocWithOffset(Offset: 1));
419 } else {
420 RemovalRange =
421 SourceRange(utils::lexer::findPreviousTokenKind(
422 Start: LastTemplateParam->getSourceRange().getBegin(), SM,
423 LangOpts, TK: tok::comma),
424 getRAngleFileLoc(SM, Element: *TemplateParams));
425 }
426
427 std::optional<SourceLocation> ConstraintInsertionLoc =
428 findInsertionForConstraint(Function, Context);
429 if (!ConstraintInsertionLoc)
430 return {};
431
432 std::vector<FixItHint> FixIts;
433 FixIts.push_back(
434 x: FixItHint::CreateRemoval(RemoveRange: CharSourceRange::getCharRange(R: RemovalRange)));
435 FixIts.push_back(x: FixItHint::CreateInsertion(
436 InsertionLoc: *ConstraintInsertionLoc, Code: "requires " + *ConditionText + " "));
437 return FixIts;
438}
439
440void UseConstraintsCheck::check(const MatchFinder::MatchResult &Result) {
441 const auto *FunctionTemplate =
442 Result.Nodes.getNodeAs<FunctionTemplateDecl>(ID: "functionTemplate");
443 const auto *Function = Result.Nodes.getNodeAs<FunctionDecl>(ID: "function");
444 const auto *ReturnType = Result.Nodes.getNodeAs<TypeLoc>(ID: "return");
445 if (!FunctionTemplate || !Function || !ReturnType)
446 return;
447
448 // Check for
449 //
450 // Case 1. Return type of function
451 //
452 // template <...>
453 // enable_if_t<Condition, ReturnType>::type function() {}
454 //
455 // Case 2. Trailing template parameter
456 //
457 // template <..., enable_if_t<Condition, Type> = Type{}>
458 // ReturnType function() {}
459 //
460 // or
461 //
462 // template <..., typename = enable_if_t<Condition, void>>
463 // ReturnType function() {}
464 //
465
466 // Case 1. Return type of function
467 if (auto EnableIf = matchEnableIfSpecialization(*ReturnType)) {
468 diag(Loc: ReturnType->getBeginLoc(),
469 Description: "use C++20 requires constraints instead of enable_if")
470 << handleReturnType(Function, *ReturnType, *EnableIf, *Result.Context);
471 return;
472 }
473
474 // Case 2. Trailing template parameter
475 if (auto [EnableIf, LastTemplateParam] =
476 matchTrailingTemplateParam(FunctionTemplate);
477 EnableIf && LastTemplateParam) {
478 diag(LastTemplateParam->getSourceRange().getBegin(),
479 "use C++20 requires constraints instead of enable_if")
480 << handleTrailingTemplateType(FunctionTemplate, Function,
481 LastTemplateParam, *EnableIf,
482 *Result.Context);
483 return;
484 }
485}
486
487} // namespace clang::tidy::modernize
488

source code of clang-tools-extra/clang-tidy/modernize/UseConstraintsCheck.cpp