//===--- MinMaxUseInitializerListCheck.cpp - clang-tidy -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 | |
#include "MinMaxUseInitializerListCheck.h"
#include "../utils/ASTUtils.h"
#include "../utils/LexerUtils.h"
#include "clang/ASTMatchers/ASTMatchFinder.h"
#include "clang/Frontend/CompilerInstance.h"
#include "clang/Lex/Lexer.h"
#include <cstdint>
#include <string>
#include <utility>

using namespace clang;
17 | |
18 | namespace { |
19 | |
20 | struct FindArgsResult { |
21 | const Expr *First; |
22 | const Expr *Last; |
23 | const Expr *Compare; |
24 | SmallVector<const clang::Expr *, 2> Args; |
25 | }; |
26 | |
27 | } // anonymous namespace |
28 | |
29 | using namespace clang::ast_matchers; |
30 | |
31 | namespace clang::tidy::modernize { |
32 | |
33 | static FindArgsResult findArgs(const CallExpr *Call) { |
34 | FindArgsResult Result; |
35 | Result.First = nullptr; |
36 | Result.Last = nullptr; |
37 | Result.Compare = nullptr; |
38 | |
39 | // check if the function has initializer list argument |
40 | if (Call->getNumArgs() < 3) { |
41 | auto ArgIterator = Call->arguments().begin(); |
42 | |
43 | const auto *InitListExpr = |
44 | dyn_cast<CXXStdInitializerListExpr>(Val: *ArgIterator); |
45 | const auto *InitList = |
46 | InitListExpr != nullptr |
47 | ? dyn_cast<clang::InitListExpr>( |
48 | Val: InitListExpr->getSubExpr()->IgnoreImplicit()) |
49 | : nullptr; |
50 | |
51 | if (InitList) { |
52 | Result.Args.append(in_start: InitList->inits().begin(), in_end: InitList->inits().end()); |
53 | Result.First = *ArgIterator; |
54 | Result.Last = *ArgIterator; |
55 | |
56 | // check if there is a comparison argument |
57 | std::advance(i&: ArgIterator, n: 1); |
58 | if (ArgIterator != Call->arguments().end()) |
59 | Result.Compare = *ArgIterator; |
60 | |
61 | return Result; |
62 | } |
63 | Result.Args = SmallVector<const Expr *>(Call->arguments()); |
64 | } else { |
65 | // if it has 3 arguments then the last will be the comparison |
66 | Result.Compare = *(std::next(x: Call->arguments().begin(), n: 2)); |
67 | Result.Args = SmallVector<const Expr *>(llvm::drop_end(RangeOrContainer: Call->arguments())); |
68 | } |
69 | Result.First = Result.Args.front(); |
70 | Result.Last = Result.Args.back(); |
71 | |
72 | return Result; |
73 | } |
74 | |
75 | // Returns `true` as `first` only if a nested call to `std::min` or |
76 | // `std::max` was found. Checking if `FixItHint`s were generated is not enough, |
77 | // as the explicit casts that the check introduces may be generated without a |
78 | // nested `std::min` or `std::max` call. |
79 | static std::pair<bool, SmallVector<FixItHint>> |
80 | generateReplacements(const MatchFinder::MatchResult &Match, |
81 | const CallExpr *TopCall, const FindArgsResult &Result, |
82 | const bool IgnoreNonTrivialTypes, |
83 | const std::uint64_t IgnoreTrivialTypesOfSizeAbove) { |
84 | SmallVector<FixItHint> FixItHints; |
85 | const SourceManager &SourceMngr = *Match.SourceManager; |
86 | const LangOptions &LanguageOpts = Match.Context->getLangOpts(); |
87 | |
88 | const QualType ResultType = TopCall->getDirectCallee() |
89 | ->getReturnType() |
90 | .getCanonicalType() |
91 | .getNonReferenceType() |
92 | .getUnqualifiedType(); |
93 | |
94 | // check if the type is trivial |
95 | const bool IsResultTypeTrivial = ResultType.isTrivialType(Context: *Match.Context); |
96 | |
97 | if ((!IsResultTypeTrivial && IgnoreNonTrivialTypes)) |
98 | return {false, FixItHints}; |
99 | |
100 | if (IsResultTypeTrivial && |
101 | static_cast<std::uint64_t>( |
102 | Match.Context->getTypeSizeInChars(T: ResultType).getQuantity()) > |
103 | IgnoreTrivialTypesOfSizeAbove) |
104 | return {false, FixItHints}; |
105 | |
106 | bool FoundNestedCall = false; |
107 | |
108 | for (const Expr *Arg : Result.Args) { |
109 | const auto *InnerCall = dyn_cast<CallExpr>(Val: Arg->IgnoreParenImpCasts()); |
110 | |
111 | // If the argument is not a nested call |
112 | if (!InnerCall) { |
113 | // check if typecast is required |
114 | const QualType ArgType = Arg->IgnoreParenImpCasts() |
115 | ->getType() |
116 | .getCanonicalType() |
117 | .getUnqualifiedType(); |
118 | |
119 | if (ArgType == ResultType) |
120 | continue; |
121 | |
122 | const StringRef ArgText = Lexer::getSourceText( |
123 | Range: CharSourceRange::getTokenRange(Arg->getSourceRange()), SM: SourceMngr, |
124 | LangOpts: LanguageOpts); |
125 | |
126 | const auto Replacement = Twine("static_cast<" ) |
127 | .concat(Suffix: ResultType.getAsString(Policy: LanguageOpts)) |
128 | .concat(Suffix: ">(" ) |
129 | .concat(Suffix: ArgText) |
130 | .concat(Suffix: ")" ) |
131 | .str(); |
132 | |
133 | FixItHints.push_back( |
134 | FixItHint::CreateReplacement(Arg->getSourceRange(), Replacement)); |
135 | continue; |
136 | } |
137 | |
138 | // if the nested call is not the same as the top call |
139 | if (InnerCall->getDirectCallee()->getQualifiedNameAsString() != |
140 | TopCall->getDirectCallee()->getQualifiedNameAsString()) |
141 | continue; |
142 | |
143 | const FindArgsResult InnerResult = findArgs(Call: InnerCall); |
144 | |
145 | // if the nested call doesn't have arguments skip it |
146 | if (!InnerResult.First || !InnerResult.Last) |
147 | continue; |
148 | |
149 | // if the nested call doesn't have the same compare function |
150 | if ((Result.Compare || InnerResult.Compare) && |
151 | !utils::areStatementsIdentical(Result.Compare, InnerResult.Compare, |
152 | *Match.Context)) |
153 | continue; |
154 | |
155 | // We have found a nested call |
156 | FoundNestedCall = true; |
157 | |
158 | // remove the function call |
159 | FixItHints.push_back( |
160 | FixItHint::CreateRemoval(InnerCall->getCallee()->getSourceRange())); |
161 | |
162 | // remove the parentheses |
163 | const auto LParen = utils::lexer::findNextTokenSkippingComments( |
164 | Start: InnerCall->getCallee()->getEndLoc(), SM: SourceMngr, LangOpts: LanguageOpts); |
165 | if (LParen.has_value() && LParen->is(tok::l_paren)) |
166 | FixItHints.push_back( |
167 | Elt: FixItHint::CreateRemoval(RemoveRange: SourceRange(LParen->getLocation()))); |
168 | FixItHints.push_back( |
169 | Elt: FixItHint::CreateRemoval(RemoveRange: SourceRange(InnerCall->getRParenLoc()))); |
170 | |
171 | // if the inner call has an initializer list arg |
172 | if (InnerResult.First == InnerResult.Last) { |
173 | // remove the initializer list braces |
174 | FixItHints.push_back(FixItHint::CreateRemoval( |
175 | CharSourceRange::getTokenRange(InnerResult.First->getBeginLoc()))); |
176 | FixItHints.push_back(FixItHint::CreateRemoval( |
177 | CharSourceRange::getTokenRange(InnerResult.First->getEndLoc()))); |
178 | } |
179 | |
180 | const auto [_, InnerReplacements] = generateReplacements( |
181 | Match, TopCall: InnerCall, Result: InnerResult, IgnoreNonTrivialTypes, |
182 | IgnoreTrivialTypesOfSizeAbove); |
183 | |
184 | FixItHints.append(RHS: InnerReplacements); |
185 | |
186 | if (InnerResult.Compare) { |
187 | // find the comma after the value arguments |
188 | const auto Comma = utils::lexer::findNextTokenSkippingComments( |
189 | Start: InnerResult.Last->getEndLoc(), SM: SourceMngr, LangOpts: LanguageOpts); |
190 | |
191 | // remove the comma and the comparison |
192 | if (Comma.has_value() && Comma->is(tok::comma)) |
193 | FixItHints.push_back( |
194 | Elt: FixItHint::CreateRemoval(RemoveRange: SourceRange(Comma->getLocation()))); |
195 | |
196 | FixItHints.push_back( |
197 | FixItHint::CreateRemoval(InnerResult.Compare->getSourceRange())); |
198 | } |
199 | } |
200 | |
201 | return {FoundNestedCall, FixItHints}; |
202 | } |
203 | |
204 | MinMaxUseInitializerListCheck::MinMaxUseInitializerListCheck( |
205 | StringRef Name, ClangTidyContext *Context) |
206 | : ClangTidyCheck(Name, Context), |
207 | IgnoreNonTrivialTypes(Options.get(LocalName: "IgnoreNonTrivialTypes" , Default: true)), |
208 | IgnoreTrivialTypesOfSizeAbove( |
209 | Options.get(LocalName: "IgnoreTrivialTypesOfSizeAbove" , Default: 32L)), |
210 | Inserter(Options.getLocalOrGlobal(LocalName: "IncludeStyle" , |
211 | Default: utils::IncludeSorter::IS_LLVM), |
212 | areDiagsSelfContained()) {} |
213 | |
214 | void MinMaxUseInitializerListCheck::storeOptions( |
215 | ClangTidyOptions::OptionMap &Opts) { |
216 | Options.store(Options&: Opts, LocalName: "IgnoreNonTrivialTypes" , Value: IgnoreNonTrivialTypes); |
217 | Options.store(Options&: Opts, LocalName: "IgnoreTrivialTypesOfSizeAbove" , |
218 | Value: IgnoreTrivialTypesOfSizeAbove); |
219 | Options.store(Options&: Opts, LocalName: "IncludeStyle" , Value: Inserter.getStyle()); |
220 | } |
221 | |
222 | void MinMaxUseInitializerListCheck::registerMatchers(MatchFinder *Finder) { |
223 | auto CreateMatcher = [](const StringRef FunctionName) { |
224 | auto FuncDecl = functionDecl(hasName(Name: FunctionName)); |
225 | auto Expression = callExpr(callee(InnerMatcher: FuncDecl)); |
226 | |
227 | return callExpr(callee(InnerMatcher: FuncDecl), |
228 | anyOf(hasArgument(N: 0, InnerMatcher: Expression), |
229 | hasArgument(N: 1, InnerMatcher: Expression), |
230 | hasArgument(N: 0, InnerMatcher: cxxStdInitializerListExpr())), |
231 | unless(hasParent(Expression))) |
232 | .bind(ID: "topCall" ); |
233 | }; |
234 | |
235 | Finder->addMatcher(NodeMatch: CreateMatcher("::std::max" ), Action: this); |
236 | Finder->addMatcher(NodeMatch: CreateMatcher("::std::min" ), Action: this); |
237 | } |
238 | |
239 | void MinMaxUseInitializerListCheck::registerPPCallbacks( |
240 | const SourceManager &SM, Preprocessor *PP, Preprocessor *ModuleExpanderPP) { |
241 | Inserter.registerPreprocessor(PP); |
242 | } |
243 | |
244 | void MinMaxUseInitializerListCheck::check( |
245 | const MatchFinder::MatchResult &Match) { |
246 | |
247 | const auto *TopCall = Match.Nodes.getNodeAs<CallExpr>(ID: "topCall" ); |
248 | |
249 | const FindArgsResult Result = findArgs(Call: TopCall); |
250 | const auto [FoundNestedCall, Replacements] = |
251 | generateReplacements(Match, TopCall, Result, IgnoreNonTrivialTypes, |
252 | IgnoreTrivialTypesOfSizeAbove); |
253 | |
254 | if (!FoundNestedCall) |
255 | return; |
256 | |
257 | const DiagnosticBuilder Diagnostic = |
258 | diag(Loc: TopCall->getBeginLoc(), |
259 | Description: "do not use nested 'std::%0' calls, use an initializer list instead" ) |
260 | << TopCall->getDirectCallee()->getName() |
261 | << Inserter.createIncludeInsertion( |
262 | FileID: Match.SourceManager->getFileID(SpellingLoc: TopCall->getBeginLoc()), |
263 | Header: "<algorithm>" ); |
264 | |
265 | // if the top call doesn't have an initializer list argument |
266 | if (Result.First != Result.Last) { |
267 | // add { and } insertions |
268 | Diagnostic << FixItHint::CreateInsertion(InsertionLoc: Result.First->getBeginLoc(), Code: "{" ); |
269 | |
270 | Diagnostic << FixItHint::CreateInsertion( |
271 | InsertionLoc: Lexer::getLocForEndOfToken(Loc: Result.Last->getEndLoc(), Offset: 0, |
272 | SM: *Match.SourceManager, |
273 | LangOpts: Match.Context->getLangOpts()), |
274 | Code: "}" ); |
275 | } |
276 | |
277 | Diagnostic << Replacements; |
278 | } |
279 | |
280 | } // namespace clang::tidy::modernize |
281 | |