//===--- UseStdNumbersCheck.cpp - clang-tidy ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 | |
9 | #include "UseStdNumbersCheck.h" |
10 | #include "../ClangTidyDiagnosticConsumer.h" |
11 | #include "clang/AST/ASTContext.h" |
12 | #include "clang/AST/Decl.h" |
13 | #include "clang/AST/Expr.h" |
14 | #include "clang/AST/Stmt.h" |
15 | #include "clang/AST/Type.h" |
16 | #include "clang/ASTMatchers/ASTMatchFinder.h" |
17 | #include "clang/ASTMatchers/ASTMatchers.h" |
18 | #include "clang/ASTMatchers/ASTMatchersInternal.h" |
19 | #include "clang/ASTMatchers/ASTMatchersMacros.h" |
20 | #include "clang/Basic/Diagnostic.h" |
21 | #include "clang/Basic/LLVM.h" |
22 | #include "clang/Basic/LangOptions.h" |
23 | #include "clang/Basic/SourceLocation.h" |
24 | #include "clang/Basic/SourceManager.h" |
25 | #include "clang/Lex/Lexer.h" |
26 | #include "llvm/ADT/STLExtras.h" |
27 | #include "llvm/ADT/SmallVector.h" |
28 | #include "llvm/ADT/StringRef.h" |
29 | #include "llvm/Support/FormatVariadic.h" |
30 | #include "llvm/Support/MathExtras.h" |
31 | #include <array> |
32 | #include <cmath> |
33 | #include <cstdint> |
34 | #include <cstdlib> |
35 | #include <initializer_list> |
36 | #include <string> |
37 | #include <tuple> |
38 | #include <utility> |
39 | |
40 | namespace { |
41 | using namespace clang::ast_matchers; |
42 | using clang::ast_matchers::internal::Matcher; |
43 | using llvm::StringRef; |
44 | |
45 | AST_MATCHER_P2(clang::FloatingLiteral, near, double, Value, double, |
46 | DiffThreshold) { |
47 | return std::abs(x: Node.getValueAsApproximateDouble() - Value) < DiffThreshold; |
48 | } |
49 | |
50 | AST_MATCHER_P(clang::QualType, hasCanonicalTypeUnqualified, |
51 | Matcher<clang::QualType>, InnerMatcher) { |
52 | return !Node.isNull() && |
53 | InnerMatcher.matches(Node: Node->getCanonicalTypeUnqualified(), Finder, |
54 | Builder); |
55 | } |
56 | |
57 | AST_MATCHER(clang::QualType, isArithmetic) { |
58 | return !Node.isNull() && Node->isArithmeticType(); |
59 | } |
60 | AST_MATCHER(clang::QualType, isFloating) { |
61 | return !Node.isNull() && Node->isFloatingType(); |
62 | } |
63 | |
64 | AST_MATCHER_P(clang::Expr, anyOfExhaustive, std::vector<Matcher<clang::Stmt>>, |
65 | Exprs) { |
66 | bool FoundMatch = false; |
67 | for (const auto &InnerMatcher : Exprs) { |
68 | clang::ast_matchers::internal::BoundNodesTreeBuilder Result = *Builder; |
69 | if (InnerMatcher.matches(Node, Finder, &Result)) { |
70 | *Builder = std::move(Result); |
71 | FoundMatch = true; |
72 | } |
73 | } |
74 | return FoundMatch; |
75 | } |
76 | |
77 | // Using this struct to store the 'DiffThreshold' config value to create the |
78 | // matchers without the need to pass 'DiffThreshold' into every matcher. |
79 | // 'DiffThreshold' is needed in the 'near' matcher, which is used for matching |
80 | // the literal of every constant and for formulas' subexpressions that look at |
81 | // literals. |
82 | struct MatchBuilder { |
83 | auto |
84 | ignoreParenAndArithmeticCasting(const Matcher<clang::Expr> Matcher) const { |
85 | return expr(hasType(InnerMatcher: qualType(isArithmetic())), ignoringParenCasts(InnerMatcher: Matcher)); |
86 | } |
87 | |
88 | auto ignoreParenAndFloatingCasting(const Matcher<clang::Expr> Matcher) const { |
89 | return expr(hasType(InnerMatcher: qualType(isFloating())), ignoringParenCasts(InnerMatcher: Matcher)); |
90 | } |
91 | |
92 | auto matchMathCall(const StringRef FunctionName, |
93 | const Matcher<clang::Expr> ArgumentMatcher) const { |
94 | return expr(ignoreParenAndFloatingCasting( |
95 | Matcher: callExpr(callee(InnerMatcher: functionDecl(hasName(Name: FunctionName), |
96 | hasParameter(N: 0, InnerMatcher: hasType(InnerMatcher: isArithmetic())))), |
97 | hasArgument(N: 0, InnerMatcher: ArgumentMatcher)))); |
98 | } |
99 | |
100 | auto matchSqrt(const Matcher<clang::Expr> ArgumentMatcher) const { |
101 | return matchMathCall(FunctionName: "sqrt" , ArgumentMatcher); |
102 | } |
103 | |
104 | // Used for top-level matchers (i.e. the match that replaces Val with its |
105 | // constant). |
106 | // |
107 | // E.g. The matcher of `std::numbers::pi` uses this matcher to look for |
108 | // floatLiterals that have the value of pi. |
109 | // |
110 | // If the match is for a top-level match, we only care about the literal. |
111 | auto matchFloatLiteralNear(const StringRef Constant, const double Val) const { |
112 | return expr(ignoreParenAndFloatingCasting( |
113 | Matcher: floatLiteral(near(Value: Val, DiffThreshold)).bind(ID: Constant))); |
114 | } |
115 | |
116 | // Used for non-top-level matchers (i.e. matchers that are used as inner |
117 | // matchers for top-level matchers). |
118 | // |
119 | // E.g.: The matcher of `std::numbers::log2e` uses this matcher to check if |
120 | // `e` of `log2(e)` is declared constant and initialized with the value for |
121 | // eulers number. |
122 | // |
123 | // Here, we do care about literals and about DeclRefExprs to variable |
124 | // declarations that are constant and initialized with `Val`. This allows |
125 | // top-level matchers to see through declared constants for their inner |
126 | // matches like the `std::numbers::log2e` matcher. |
127 | auto matchFloatValueNear(const double Val) const { |
128 | const auto Float = floatLiteral(near(Value: Val, DiffThreshold)); |
129 | |
130 | const auto Dref = declRefExpr( |
131 | to(InnerMatcher: varDecl(hasType(InnerMatcher: qualType(isConstQualified(), isFloating())), |
132 | hasInitializer(InnerMatcher: ignoreParenAndFloatingCasting(Matcher: Float))))); |
133 | return expr(ignoreParenAndFloatingCasting(Matcher: anyOf(Float, Dref))); |
134 | } |
135 | |
136 | auto matchValue(const int64_t ValInt) const { |
137 | const auto Int = |
138 | expr(ignoreParenAndArithmeticCasting(Matcher: integerLiteral(equals(Value: ValInt)))); |
139 | const auto Float = expr(ignoreParenAndFloatingCasting( |
140 | Matcher: matchFloatValueNear(Val: static_cast<double>(ValInt)))); |
141 | const auto Dref = declRefExpr(to(InnerMatcher: varDecl( |
142 | hasType(InnerMatcher: qualType(isConstQualified(), isArithmetic())), |
143 | hasInitializer(InnerMatcher: expr(anyOf(ignoringImplicit(InnerMatcher: Int), |
144 | ignoreParenAndFloatingCasting(Matcher: Float))))))); |
145 | return expr(anyOf(Int, Float, Dref)); |
146 | } |
147 | |
148 | auto match1Div(const Matcher<clang::Expr> Match) const { |
149 | return binaryOperator(hasOperatorName(Name: "/" ), hasLHS(InnerMatcher: matchValue(ValInt: 1)), |
150 | hasRHS(InnerMatcher: Match)); |
151 | } |
152 | |
153 | auto matchEuler() const { |
154 | return expr(anyOf(matchFloatValueNear(Val: llvm::numbers::e), |
155 | matchMathCall(FunctionName: "exp" , ArgumentMatcher: matchValue(ValInt: 1)))); |
156 | } |
157 | auto matchEulerTopLevel() const { |
158 | return expr(anyOf(matchFloatLiteralNear(Constant: "e_literal" , Val: llvm::numbers::e), |
159 | matchMathCall(FunctionName: "exp" , ArgumentMatcher: matchValue(ValInt: 1)).bind(ID: "e_pattern" ))) |
160 | .bind(ID: "e" ); |
161 | } |
162 | |
163 | auto matchLog2Euler() const { |
164 | return expr( |
165 | anyOf( |
166 | matchFloatLiteralNear(Constant: "log2e_literal" , Val: llvm::numbers::log2e), |
167 | matchMathCall(FunctionName: "log2" , ArgumentMatcher: matchEuler()).bind(ID: "log2e_pattern" ))) |
168 | .bind(ID: "log2e" ); |
169 | } |
170 | |
171 | auto matchLog10Euler() const { |
172 | return expr( |
173 | anyOf( |
174 | matchFloatLiteralNear(Constant: "log10e_literal" , |
175 | Val: llvm::numbers::log10e), |
176 | matchMathCall(FunctionName: "log10" , ArgumentMatcher: matchEuler()).bind(ID: "log10e_pattern" ))) |
177 | .bind(ID: "log10e" ); |
178 | } |
179 | |
180 | auto matchPi() const { return matchFloatValueNear(Val: llvm::numbers::pi); } |
181 | auto matchPiTopLevel() const { |
182 | return matchFloatLiteralNear(Constant: "pi_literal" , Val: llvm::numbers::pi).bind(ID: "pi" ); |
183 | } |
184 | |
185 | auto matchEgamma() const { |
186 | return matchFloatLiteralNear(Constant: "egamma_literal" , Val: llvm::numbers::egamma) |
187 | .bind(ID: "egamma" ); |
188 | } |
189 | |
190 | auto matchInvPi() const { |
191 | return expr(anyOf(matchFloatLiteralNear(Constant: "inv_pi_literal" , |
192 | Val: llvm::numbers::inv_pi), |
193 | match1Div(Match: matchPi()).bind(ID: "inv_pi_pattern" ))) |
194 | .bind(ID: "inv_pi" ); |
195 | } |
196 | |
197 | auto matchInvSqrtPi() const { |
198 | return expr(anyOf( |
199 | matchFloatLiteralNear(Constant: "inv_sqrtpi_literal" , |
200 | Val: llvm::numbers::inv_sqrtpi), |
201 | match1Div(Match: matchSqrt(ArgumentMatcher: matchPi())).bind(ID: "inv_sqrtpi_pattern" ))) |
202 | .bind(ID: "inv_sqrtpi" ); |
203 | } |
204 | |
205 | auto matchLn2() const { |
206 | return expr(anyOf(matchFloatLiteralNear(Constant: "ln2_literal" , Val: llvm::numbers::ln2), |
207 | matchMathCall(FunctionName: "log" , ArgumentMatcher: matchValue(ValInt: 2)).bind(ID: "ln2_pattern" ))) |
208 | .bind(ID: "ln2" ); |
209 | } |
210 | |
211 | auto machterLn10() const { |
212 | return expr( |
213 | anyOf(matchFloatLiteralNear(Constant: "ln10_literal" , Val: llvm::numbers::ln10), |
214 | matchMathCall(FunctionName: "log" , ArgumentMatcher: matchValue(ValInt: 10)).bind(ID: "ln10_pattern" ))) |
215 | .bind(ID: "ln10" ); |
216 | } |
217 | |
218 | auto matchSqrt2() const { |
219 | return expr(anyOf(matchFloatLiteralNear(Constant: "sqrt2_literal" , |
220 | Val: llvm::numbers::sqrt2), |
221 | matchSqrt(ArgumentMatcher: matchValue(ValInt: 2)).bind(ID: "sqrt2_pattern" ))) |
222 | .bind(ID: "sqrt2" ); |
223 | } |
224 | |
225 | auto matchSqrt3() const { |
226 | return expr(anyOf(matchFloatLiteralNear(Constant: "sqrt3_literal" , |
227 | Val: llvm::numbers::sqrt3), |
228 | matchSqrt(ArgumentMatcher: matchValue(ValInt: 3)).bind(ID: "sqrt3_pattern" ))) |
229 | .bind(ID: "sqrt3" ); |
230 | } |
231 | |
232 | auto matchInvSqrt3() const { |
233 | return expr(anyOf(matchFloatLiteralNear(Constant: "inv_sqrt3_literal" , |
234 | Val: llvm::numbers::inv_sqrt3), |
235 | match1Div(Match: matchSqrt(ArgumentMatcher: matchValue(ValInt: 3))) |
236 | .bind(ID: "inv_sqrt3_pattern" ))) |
237 | .bind(ID: "inv_sqrt3" ); |
238 | } |
239 | |
240 | auto matchPhi() const { |
241 | const auto PhiFormula = binaryOperator( |
242 | hasOperatorName(Name: "/" ), |
243 | hasLHS(InnerMatcher: binaryOperator( |
244 | hasOperatorName(Name: "+" ), hasEitherOperand(InnerMatcher: matchValue(ValInt: 1)), |
245 | hasEitherOperand(InnerMatcher: matchMathCall(FunctionName: "sqrt" , ArgumentMatcher: matchValue(ValInt: 5))))), |
246 | hasRHS(InnerMatcher: matchValue(ValInt: 2))); |
247 | return expr(anyOf(PhiFormula.bind(ID: "phi_pattern" ), |
248 | matchFloatLiteralNear(Constant: "phi_literal" , Val: llvm::numbers::phi))) |
249 | .bind(ID: "phi" ); |
250 | } |
251 | |
252 | double DiffThreshold; |
253 | }; |
254 | |
255 | std::string getCode(const StringRef Constant, const bool IsFloat, |
256 | const bool IsLongDouble) { |
257 | if (IsFloat) { |
258 | return ("std::numbers::" + Constant + "_v<float>" ).str(); |
259 | } |
260 | if (IsLongDouble) { |
261 | return ("std::numbers::" + Constant + "_v<long double>" ).str(); |
262 | } |
263 | return ("std::numbers::" + Constant).str(); |
264 | } |
265 | |
266 | bool isRangeOfCompleteMacro(const clang::SourceRange &Range, |
267 | const clang::SourceManager &SM, |
268 | const clang::LangOptions &LO) { |
269 | if (!Range.getBegin().isMacroID()) { |
270 | return false; |
271 | } |
272 | if (!clang::Lexer::isAtStartOfMacroExpansion(loc: Range.getBegin(), SM, LangOpts: LO)) { |
273 | return false; |
274 | } |
275 | |
276 | if (!Range.getEnd().isMacroID()) { |
277 | return false; |
278 | } |
279 | |
280 | if (!clang::Lexer::isAtEndOfMacroExpansion(loc: Range.getEnd(), SM, LangOpts: LO)) { |
281 | return false; |
282 | } |
283 | |
284 | return true; |
285 | } |
286 | |
287 | } // namespace |
288 | |
289 | namespace clang::tidy::modernize { |
290 | UseStdNumbersCheck::UseStdNumbersCheck(const StringRef Name, |
291 | ClangTidyContext *const Context) |
292 | : ClangTidyCheck(Name, Context), |
293 | IncludeInserter(Options.getLocalOrGlobal(LocalName: "IncludeStyle" , |
294 | Default: utils::IncludeSorter::IS_LLVM), |
295 | areDiagsSelfContained()), |
296 | DiffThresholdString{Options.get(LocalName: "DiffThreshold" , Default: "0.001" )} { |
297 | if (DiffThresholdString.getAsDouble(Result&: DiffThreshold)) { |
298 | configurationDiag( |
299 | Description: "Invalid DiffThreshold config value: '%0', expected a double" ) |
300 | << DiffThresholdString; |
301 | DiffThreshold = 0.001; |
302 | } |
303 | } |
304 | |
305 | void UseStdNumbersCheck::registerMatchers(MatchFinder *const Finder) { |
306 | const auto Matches = MatchBuilder{.DiffThreshold: DiffThreshold}; |
307 | std::vector<Matcher<clang::Stmt>> ConstantMatchers = { |
308 | Matches.matchLog2Euler(), Matches.matchLog10Euler(), |
309 | Matches.matchEulerTopLevel(), Matches.matchEgamma(), |
310 | Matches.matchInvSqrtPi(), Matches.matchInvPi(), |
311 | Matches.matchPiTopLevel(), Matches.matchLn2(), |
312 | Matches.machterLn10(), Matches.matchSqrt2(), |
313 | Matches.matchInvSqrt3(), Matches.matchSqrt3(), |
314 | Matches.matchPhi(), |
315 | }; |
316 | |
317 | Finder->addMatcher( |
318 | NodeMatch: expr( |
319 | anyOfExhaustive(Exprs: std::move(ConstantMatchers)), |
320 | unless(hasParent(explicitCastExpr(hasDestinationType(InnerMatcher: isFloating())))), |
321 | hasType(InnerMatcher: qualType(hasCanonicalTypeUnqualified( |
322 | InnerMatcher: anyOf(qualType(asString(Name: "float" )).bind(ID: "float" ), |
323 | qualType(asString(Name: "double" )), |
324 | qualType(asString(Name: "long double" )).bind(ID: "long double" )))))), |
325 | Action: this); |
326 | } |
327 | |
328 | void UseStdNumbersCheck::check(const MatchFinder::MatchResult &Result) { |
329 | /* |
330 | List of all math constants in the `<numbers>` header |
331 | + e |
332 | + log2e |
333 | + log10e |
334 | + pi |
335 | + inv_pi |
336 | + inv_sqrtpi |
337 | + ln2 |
338 | + ln10 |
339 | + sqrt2 |
340 | + sqrt3 |
341 | + inv_sqrt3 |
342 | + egamma |
343 | + phi |
344 | */ |
345 | |
346 | // The ordering determines what constants are looked at first. |
347 | // E.g. look at 'inv_sqrt3' before 'sqrt3' to be able to replace the larger |
348 | // expression |
349 | constexpr auto Constants = std::array<std::pair<StringRef, double>, 13>{ |
350 | std::pair{StringRef{"log2e" }, llvm::numbers::log2e}, |
351 | std::pair{StringRef{"log10e" }, llvm::numbers::log10e}, |
352 | std::pair{StringRef{"e" }, llvm::numbers::e}, |
353 | std::pair{StringRef{"egamma" }, llvm::numbers::egamma}, |
354 | std::pair{StringRef{"inv_sqrtpi" }, llvm::numbers::inv_sqrtpi}, |
355 | std::pair{StringRef{"inv_pi" }, llvm::numbers::inv_pi}, |
356 | std::pair{StringRef{"pi" }, llvm::numbers::pi}, |
357 | std::pair{StringRef{"ln2" }, llvm::numbers::ln2}, |
358 | std::pair{StringRef{"ln10" }, llvm::numbers::ln10}, |
359 | std::pair{StringRef{"sqrt2" }, llvm::numbers::sqrt2}, |
360 | std::pair{StringRef{"inv_sqrt3" }, llvm::numbers::inv_sqrt3}, |
361 | std::pair{StringRef{"sqrt3" }, llvm::numbers::sqrt3}, |
362 | std::pair{StringRef{"phi" }, llvm::numbers::phi}, |
363 | }; |
364 | |
365 | auto MatchedLiterals = |
366 | llvm::SmallVector<std::tuple<std::string, double, const Expr *>>{}; |
367 | |
368 | const auto &SM = *Result.SourceManager; |
369 | const auto &LO = Result.Context->getLangOpts(); |
370 | |
371 | const auto IsFloat = Result.Nodes.getNodeAs<QualType>(ID: "float" ) != nullptr; |
372 | const auto IsLongDouble = |
373 | Result.Nodes.getNodeAs<QualType>(ID: "long double" ) != nullptr; |
374 | |
375 | for (const auto &[ConstantName, ConstantValue] : Constants) { |
376 | const auto *const Match = Result.Nodes.getNodeAs<Expr>(ID: ConstantName); |
377 | if (Match == nullptr) { |
378 | continue; |
379 | } |
380 | |
381 | const auto Range = Match->getSourceRange(); |
382 | |
383 | const auto IsMacro = Range.getBegin().isMacroID(); |
384 | |
385 | // We do not want to emit a diagnostic when we are matching a macro, but the |
386 | // match inside of the macro does not cover the whole macro. |
387 | if (IsMacro && !isRangeOfCompleteMacro(Range, SM, LO)) { |
388 | continue; |
389 | } |
390 | |
391 | if (const auto PatternBindString = (ConstantName + "_pattern" ).str(); |
392 | Result.Nodes.getNodeAs<Expr>(ID: PatternBindString) != nullptr) { |
393 | const auto Code = getCode(Constant: ConstantName, IsFloat, IsLongDouble); |
394 | diag(Range.getBegin(), "prefer '%0' to this %select{formula|macro}1" ) |
395 | << Code << IsMacro << FixItHint::CreateReplacement(Range, Code); |
396 | return; |
397 | } |
398 | |
399 | const auto LiteralBindString = (ConstantName + "_literal" ).str(); |
400 | if (const auto *const Literal = |
401 | Result.Nodes.getNodeAs<FloatingLiteral>(ID: LiteralBindString)) { |
402 | MatchedLiterals.emplace_back( |
403 | Args: ConstantName, |
404 | Args: std::abs(x: Literal->getValueAsApproximateDouble() - ConstantValue), |
405 | Args: Match); |
406 | } |
407 | } |
408 | |
409 | // We may have had no matches with literals, but a match with a pattern that |
410 | // was a part of a macro which was therefore skipped. |
411 | if (MatchedLiterals.empty()) { |
412 | return; |
413 | } |
414 | |
415 | llvm::sort(C&: MatchedLiterals, Comp: [](const auto &LHS, const auto &RHS) { |
416 | return std::get<1>(LHS) < std::get<1>(RHS); |
417 | }); |
418 | |
419 | const auto &[Constant, Diff, Node] = MatchedLiterals.front(); |
420 | |
421 | const auto Range = Node->getSourceRange(); |
422 | const auto IsMacro = Range.getBegin().isMacroID(); |
423 | |
424 | // We do not want to emit a diagnostic when we are matching a macro, but the |
425 | // match inside of the macro does not cover the whole macro. |
426 | if (IsMacro && !isRangeOfCompleteMacro(Range, SM, LO)) { |
427 | return; |
428 | } |
429 | |
430 | const auto Code = getCode(Constant, IsFloat, IsLongDouble); |
431 | diag(Range.getBegin(), |
432 | "prefer '%0' to this %select{literal|macro}1, differs by '%2'" ) |
433 | << Code << IsMacro << llvm::formatv(Fmt: "{0:e2}" , Vals: Diff).str() |
434 | << FixItHint::CreateReplacement(Range, Code) |
435 | << IncludeInserter.createIncludeInsertion( |
436 | FileID: Result.SourceManager->getFileID(Range.getBegin()), Header: "<numbers>" ); |
437 | } |
438 | |
439 | void UseStdNumbersCheck::registerPPCallbacks( |
440 | const SourceManager &SM, Preprocessor *const PP, |
441 | Preprocessor *const ModuleExpanderPP) { |
442 | IncludeInserter.registerPreprocessor(PP); |
443 | } |
444 | |
445 | void UseStdNumbersCheck::storeOptions(ClangTidyOptions::OptionMap &Opts) { |
446 | Options.store(Options&: Opts, LocalName: "IncludeStyle" , Value: IncludeInserter.getStyle()); |
447 | Options.store(Options&: Opts, LocalName: "DiffThreshold" , Value: DiffThresholdString); |
448 | } |
449 | } // namespace clang::tidy::modernize |
450 | |