1 | //=======- PtrTypesSemantics.cpp ---------------------------------*- C++ -*-==// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "PtrTypesSemantics.h" |
10 | #include "ASTUtils.h" |
11 | #include "clang/AST/CXXInheritance.h" |
12 | #include "clang/AST/Decl.h" |
13 | #include "clang/AST/DeclCXX.h" |
14 | #include "clang/AST/ExprCXX.h" |
15 | #include "clang/AST/StmtVisitor.h" |
16 | #include <optional> |
17 | |
18 | using namespace clang; |
19 | |
20 | namespace { |
21 | |
22 | bool hasPublicMethodInBaseClass(const CXXRecordDecl *R, |
23 | const char *NameToMatch) { |
24 | assert(R); |
25 | assert(R->hasDefinition()); |
26 | |
27 | for (const CXXMethodDecl *MD : R->methods()) { |
28 | const auto MethodName = safeGetName(ASTNode: MD); |
29 | if (MethodName == NameToMatch && MD->getAccess() == AS_public) |
30 | return true; |
31 | } |
32 | return false; |
33 | } |
34 | |
35 | } // namespace |
36 | |
37 | namespace clang { |
38 | |
39 | std::optional<const clang::CXXRecordDecl *> |
40 | hasPublicMethodInBase(const CXXBaseSpecifier *Base, const char *NameToMatch) { |
41 | assert(Base); |
42 | |
43 | const Type *T = Base->getType().getTypePtrOrNull(); |
44 | if (!T) |
45 | return std::nullopt; |
46 | |
47 | const CXXRecordDecl *R = T->getAsCXXRecordDecl(); |
48 | if (!R) |
49 | return std::nullopt; |
50 | if (!R->hasDefinition()) |
51 | return std::nullopt; |
52 | |
53 | return hasPublicMethodInBaseClass(R, NameToMatch) ? R : nullptr; |
54 | } |
55 | |
56 | std::optional<bool> isRefCountable(const CXXRecordDecl* R) |
57 | { |
58 | assert(R); |
59 | |
60 | R = R->getDefinition(); |
61 | if (!R) |
62 | return std::nullopt; |
63 | |
64 | bool hasRef = hasPublicMethodInBaseClass(R, NameToMatch: "ref" ); |
65 | bool hasDeref = hasPublicMethodInBaseClass(R, NameToMatch: "deref" ); |
66 | if (hasRef && hasDeref) |
67 | return true; |
68 | |
69 | CXXBasePaths Paths; |
70 | Paths.setOrigin(const_cast<CXXRecordDecl *>(R)); |
71 | |
72 | bool AnyInconclusiveBase = false; |
73 | const auto hasPublicRefInBase = |
74 | [&AnyInconclusiveBase](const CXXBaseSpecifier *Base, CXXBasePath &) { |
75 | auto hasRefInBase = clang::hasPublicMethodInBase(Base, NameToMatch: "ref" ); |
76 | if (!hasRefInBase) { |
77 | AnyInconclusiveBase = true; |
78 | return false; |
79 | } |
80 | return (*hasRefInBase) != nullptr; |
81 | }; |
82 | |
83 | hasRef = hasRef || R->lookupInBases(BaseMatches: hasPublicRefInBase, Paths, |
84 | /*LookupInDependent =*/true); |
85 | if (AnyInconclusiveBase) |
86 | return std::nullopt; |
87 | |
88 | Paths.clear(); |
89 | const auto hasPublicDerefInBase = |
90 | [&AnyInconclusiveBase](const CXXBaseSpecifier *Base, CXXBasePath &) { |
91 | auto hasDerefInBase = clang::hasPublicMethodInBase(Base, NameToMatch: "deref" ); |
92 | if (!hasDerefInBase) { |
93 | AnyInconclusiveBase = true; |
94 | return false; |
95 | } |
96 | return (*hasDerefInBase) != nullptr; |
97 | }; |
98 | hasDeref = hasDeref || R->lookupInBases(BaseMatches: hasPublicDerefInBase, Paths, |
99 | /*LookupInDependent =*/true); |
100 | if (AnyInconclusiveBase) |
101 | return std::nullopt; |
102 | |
103 | return hasRef && hasDeref; |
104 | } |
105 | |
/// Returns true if \p Name is one of the recognized Ref/RefPtr class-template
/// names (including their "AllowingPartiallyDestroyed" variants). The
/// comparison is exact and case-sensitive.
bool isRefType(const std::string &Name) {
  static constexpr const char *RefTypeNames[] = {
      "Ref", "RefAllowingPartiallyDestroyed", "RefPtr",
      "RefPtrAllowingPartiallyDestroyed"};
  for (const char *Candidate : RefTypeNames) {
    if (Name == Candidate)
      return true;
  }
  return false;
}
110 | |
111 | bool isCtorOfRefCounted(const clang::FunctionDecl *F) { |
112 | assert(F); |
113 | const std::string &FunctionName = safeGetName(ASTNode: F); |
114 | |
115 | return isRefType(Name: FunctionName) || FunctionName == "makeRef" || |
116 | FunctionName == "makeRefPtr" || FunctionName == "UniqueRef" || |
117 | FunctionName == "makeUniqueRef" || |
118 | FunctionName == "makeUniqueRefWithoutFastMallocCheck" |
119 | |
120 | || FunctionName == "String" || FunctionName == "AtomString" || |
121 | FunctionName == "UniqueString" |
122 | // FIXME: Implement as attribute. |
123 | || FunctionName == "Identifier" ; |
124 | } |
125 | |
126 | bool isReturnValueRefCounted(const clang::FunctionDecl *F) { |
127 | assert(F); |
128 | QualType type = F->getReturnType(); |
129 | while (!type.isNull()) { |
130 | if (auto *elaboratedT = type->getAs<ElaboratedType>()) { |
131 | type = elaboratedT->desugar(); |
132 | continue; |
133 | } |
134 | if (auto *specialT = type->getAs<TemplateSpecializationType>()) { |
135 | if (auto *decl = specialT->getTemplateName().getAsTemplateDecl()) { |
136 | auto name = decl->getNameAsString(); |
137 | return isRefType(name); |
138 | } |
139 | return false; |
140 | } |
141 | return false; |
142 | } |
143 | return false; |
144 | } |
145 | |
146 | std::optional<bool> isUncounted(const CXXRecordDecl* Class) |
147 | { |
148 | // Keep isRefCounted first as it's cheaper. |
149 | if (isRefCounted(Class)) |
150 | return false; |
151 | |
152 | std::optional<bool> IsRefCountable = isRefCountable(R: Class); |
153 | if (!IsRefCountable) |
154 | return std::nullopt; |
155 | |
156 | return (*IsRefCountable); |
157 | } |
158 | |
159 | std::optional<bool> isUncountedPtr(const Type* T) |
160 | { |
161 | assert(T); |
162 | |
163 | if (T->isPointerType() || T->isReferenceType()) { |
164 | if (auto *CXXRD = T->getPointeeCXXRecordDecl()) { |
165 | return isUncounted(Class: CXXRD); |
166 | } |
167 | } |
168 | return false; |
169 | } |
170 | |
171 | std::optional<bool> isGetterOfRefCounted(const CXXMethodDecl* M) |
172 | { |
173 | assert(M); |
174 | |
175 | if (isa<CXXMethodDecl>(Val: M)) { |
176 | const CXXRecordDecl *calleeMethodsClass = M->getParent(); |
177 | auto className = safeGetName(ASTNode: calleeMethodsClass); |
178 | auto method = safeGetName(ASTNode: M); |
179 | |
180 | if ((isRefType(Name: className) && (method == "get" || method == "ptr" )) || |
181 | ((className == "String" || className == "AtomString" || |
182 | className == "AtomStringImpl" || className == "UniqueString" || |
183 | className == "UniqueStringImpl" || className == "Identifier" ) && |
184 | method == "impl" )) |
185 | return true; |
186 | |
187 | // Ref<T> -> T conversion |
188 | // FIXME: Currently allowing any Ref<T> -> whatever cast. |
189 | if (isRefType(Name: className)) { |
190 | if (auto *maybeRefToRawOperator = dyn_cast<CXXConversionDecl>(Val: M)) { |
191 | if (auto *targetConversionType = |
192 | maybeRefToRawOperator->getConversionType().getTypePtrOrNull()) { |
193 | return isUncountedPtr(T: targetConversionType); |
194 | } |
195 | } |
196 | } |
197 | } |
198 | return false; |
199 | } |
200 | |
201 | bool isRefCounted(const CXXRecordDecl *R) { |
202 | assert(R); |
203 | if (auto *TmplR = R->getTemplateInstantiationPattern()) { |
204 | // FIXME: String/AtomString/UniqueString |
205 | const auto &ClassName = safeGetName(ASTNode: TmplR); |
206 | return isRefType(Name: ClassName); |
207 | } |
208 | return false; |
209 | } |
210 | |
211 | bool isPtrConversion(const FunctionDecl *F) { |
212 | assert(F); |
213 | if (isCtorOfRefCounted(F)) |
214 | return true; |
215 | |
216 | // FIXME: check # of params == 1 |
217 | const auto FunctionName = safeGetName(ASTNode: F); |
218 | if (FunctionName == "getPtr" || FunctionName == "WeakPtr" || |
219 | FunctionName == "dynamicDowncast" || FunctionName == "downcast" || |
220 | FunctionName == "checkedDowncast" || |
221 | FunctionName == "uncheckedDowncast" || FunctionName == "bitwise_cast" ) |
222 | return true; |
223 | |
224 | return false; |
225 | } |
226 | |
227 | bool isSingleton(const FunctionDecl *F) { |
228 | assert(F); |
229 | // FIXME: check # of params == 1 |
230 | if (auto *MethodDecl = dyn_cast<CXXMethodDecl>(Val: F)) { |
231 | if (!MethodDecl->isStatic()) |
232 | return false; |
233 | } |
234 | const auto &Name = safeGetName(ASTNode: F); |
235 | std::string SingletonStr = "singleton" ; |
236 | auto index = Name.find(str: SingletonStr); |
237 | return index != std::string::npos && |
238 | index == Name.size() - SingletonStr.size(); |
239 | } |
240 | |
241 | // We only care about statements so let's use the simple |
242 | // (non-recursive) visitor. |
243 | class TrivialFunctionAnalysisVisitor |
244 | : public ConstStmtVisitor<TrivialFunctionAnalysisVisitor, bool> { |
245 | |
246 | // Returns false if at least one child is non-trivial. |
247 | bool VisitChildren(const Stmt *S) { |
248 | for (const Stmt *Child : S->children()) { |
249 | if (Child && !Visit(Child)) |
250 | return false; |
251 | } |
252 | |
253 | return true; |
254 | } |
255 | |
256 | template <typename CheckFunction> |
257 | bool WithCachedResult(const Stmt *S, CheckFunction Function) { |
258 | // If the statement isn't in the cache, conservatively assume that |
259 | // it's not trivial until analysis completes. Insert false to the cache |
260 | // first to avoid infinite recursion. |
261 | auto [It, IsNew] = Cache.insert(KV: std::make_pair(x&: S, y: false)); |
262 | if (!IsNew) |
263 | return It->second; |
264 | bool Result = Function(); |
265 | Cache[S] = Result; |
266 | return Result; |
267 | } |
268 | |
269 | public: |
270 | using CacheTy = TrivialFunctionAnalysis::CacheTy; |
271 | |
272 | TrivialFunctionAnalysisVisitor(CacheTy &Cache) : Cache(Cache) {} |
273 | |
274 | bool VisitStmt(const Stmt *S) { |
275 | // All statements are non-trivial unless overriden later. |
276 | // Don't even recurse into children by default. |
277 | return false; |
278 | } |
279 | |
280 | bool VisitCompoundStmt(const CompoundStmt *CS) { |
281 | // A compound statement is allowed as long each individual sub-statement |
282 | // is trivial. |
283 | return WithCachedResult(CS, [&]() { return VisitChildren(CS); }); |
284 | } |
285 | |
286 | bool VisitReturnStmt(const ReturnStmt *RS) { |
287 | // A return statement is allowed as long as the return value is trivial. |
288 | if (auto *RV = RS->getRetValue()) |
289 | return Visit(RV); |
290 | return true; |
291 | } |
292 | |
293 | bool VisitDeclStmt(const DeclStmt *DS) { return VisitChildren(S: DS); } |
294 | bool VisitDoStmt(const DoStmt *DS) { return VisitChildren(S: DS); } |
295 | bool VisitIfStmt(const IfStmt *IS) { |
296 | return WithCachedResult(IS, [&]() { return VisitChildren(IS); }); |
297 | } |
298 | bool VisitForStmt(const ForStmt *FS) { |
299 | return WithCachedResult(S: FS, Function: [&]() { return VisitChildren(S: FS); }); |
300 | } |
301 | bool VisitCXXForRangeStmt(const CXXForRangeStmt *FS) { |
302 | return WithCachedResult(S: FS, Function: [&]() { return VisitChildren(S: FS); }); |
303 | } |
304 | bool VisitWhileStmt(const WhileStmt *WS) { |
305 | return WithCachedResult(WS, [&]() { return VisitChildren(WS); }); |
306 | } |
307 | bool VisitSwitchStmt(const SwitchStmt *SS) { return VisitChildren(SS); } |
308 | bool VisitCaseStmt(const CaseStmt *CS) { return VisitChildren(CS); } |
309 | bool VisitDefaultStmt(const DefaultStmt *DS) { return VisitChildren(S: DS); } |
310 | |
311 | bool VisitUnaryOperator(const UnaryOperator *UO) { |
312 | // Operator '*' and '!' are allowed as long as the operand is trivial. |
313 | auto op = UO->getOpcode(); |
314 | if (op == UO_Deref || op == UO_AddrOf || op == UO_LNot) |
315 | return Visit(UO->getSubExpr()); |
316 | |
317 | if (UO->isIncrementOp() || UO->isDecrementOp()) { |
318 | // Allow increment or decrement of a POD type. |
319 | if (auto *RefExpr = dyn_cast<DeclRefExpr>(Val: UO->getSubExpr())) { |
320 | if (auto *Decl = dyn_cast<VarDecl>(Val: RefExpr->getDecl())) |
321 | return Decl->isLocalVarDeclOrParm() && |
322 | Decl->getType().isPODType(Decl->getASTContext()); |
323 | } |
324 | } |
325 | // Other operators are non-trivial. |
326 | return false; |
327 | } |
328 | |
329 | bool VisitBinaryOperator(const BinaryOperator *BO) { |
330 | // Binary operators are trivial if their operands are trivial. |
331 | return Visit(BO->getLHS()) && Visit(BO->getRHS()); |
332 | } |
333 | |
334 | bool VisitConditionalOperator(const ConditionalOperator *CO) { |
335 | // Ternary operators are trivial if their conditions & values are trivial. |
336 | return VisitChildren(CO); |
337 | } |
338 | |
339 | bool VisitAtomicExpr(const AtomicExpr *E) { return VisitChildren(E); } |
340 | |
341 | bool VisitStaticAssertDecl(const StaticAssertDecl *SAD) { |
342 | // Any static_assert is considered trivial. |
343 | return true; |
344 | } |
345 | |
346 | bool VisitCallExpr(const CallExpr *CE) { |
347 | if (!checkArguments(CE)) |
348 | return false; |
349 | |
350 | auto *Callee = CE->getDirectCallee(); |
351 | if (!Callee) |
352 | return false; |
353 | const auto &Name = safeGetName(ASTNode: Callee); |
354 | |
355 | if (Name == "WTFCrashWithInfo" || Name == "WTFBreakpointTrap" || |
356 | Name == "WTFReportAssertionFailure" || |
357 | Name == "compilerFenceForCrash" || Name == "__builtin_unreachable" ) |
358 | return true; |
359 | |
360 | return TrivialFunctionAnalysis::isTrivialImpl(Callee, Cache); |
361 | } |
362 | |
363 | bool VisitPredefinedExpr(const PredefinedExpr *E) { |
364 | // A predefined identifier such as "func" is considered trivial. |
365 | return true; |
366 | } |
367 | |
368 | bool VisitCXXMemberCallExpr(const CXXMemberCallExpr *MCE) { |
369 | if (!checkArguments(MCE)) |
370 | return false; |
371 | |
372 | bool TrivialThis = Visit(MCE->getImplicitObjectArgument()); |
373 | if (!TrivialThis) |
374 | return false; |
375 | |
376 | auto *Callee = MCE->getMethodDecl(); |
377 | if (!Callee) |
378 | return false; |
379 | |
380 | std::optional<bool> IsGetterOfRefCounted = isGetterOfRefCounted(M: Callee); |
381 | if (IsGetterOfRefCounted && *IsGetterOfRefCounted) |
382 | return true; |
383 | |
384 | // Recursively descend into the callee to confirm that it's trivial as well. |
385 | return TrivialFunctionAnalysis::isTrivialImpl(Callee, Cache); |
386 | } |
387 | |
388 | bool VisitCXXDefaultArgExpr(const CXXDefaultArgExpr *E) { |
389 | if (auto *Expr = E->getExpr()) { |
390 | if (!Visit(Expr)) |
391 | return false; |
392 | } |
393 | return true; |
394 | } |
395 | |
396 | bool checkArguments(const CallExpr *CE) { |
397 | for (const Expr *Arg : CE->arguments()) { |
398 | if (Arg && !Visit(Arg)) |
399 | return false; |
400 | } |
401 | return true; |
402 | } |
403 | |
404 | bool VisitCXXConstructExpr(const CXXConstructExpr *CE) { |
405 | for (const Expr *Arg : CE->arguments()) { |
406 | if (Arg && !Visit(Arg)) |
407 | return false; |
408 | } |
409 | |
410 | // Recursively descend into the callee to confirm that it's trivial. |
411 | return TrivialFunctionAnalysis::isTrivialImpl(CE->getConstructor(), Cache); |
412 | } |
413 | |
414 | bool VisitImplicitCastExpr(const ImplicitCastExpr *ICE) { |
415 | return Visit(ICE->getSubExpr()); |
416 | } |
417 | |
418 | bool VisitExplicitCastExpr(const ExplicitCastExpr *ECE) { |
419 | return Visit(ECE->getSubExpr()); |
420 | } |
421 | |
422 | bool VisitMaterializeTemporaryExpr(const MaterializeTemporaryExpr *VMT) { |
423 | return Visit(VMT->getSubExpr()); |
424 | } |
425 | |
426 | bool VisitExprWithCleanups(const ExprWithCleanups *EWC) { |
427 | return Visit(EWC->getSubExpr()); |
428 | } |
429 | |
430 | bool VisitParenExpr(const ParenExpr *PE) { return Visit(PE->getSubExpr()); } |
431 | |
432 | bool VisitInitListExpr(const InitListExpr *ILE) { |
433 | for (const Expr *Child : ILE->inits()) { |
434 | if (Child && !Visit(Child)) |
435 | return false; |
436 | } |
437 | return true; |
438 | } |
439 | |
440 | bool VisitMemberExpr(const MemberExpr *ME) { |
441 | // Field access is allowed but the base pointer may itself be non-trivial. |
442 | return Visit(ME->getBase()); |
443 | } |
444 | |
445 | bool VisitCXXThisExpr(const CXXThisExpr *CTE) { |
446 | // The expression 'this' is always trivial, be it explicit or implicit. |
447 | return true; |
448 | } |
449 | |
450 | bool VisitCXXNullPtrLiteralExpr(const CXXNullPtrLiteralExpr *E) { |
451 | // nullptr is trivial. |
452 | return true; |
453 | } |
454 | |
455 | bool VisitDeclRefExpr(const DeclRefExpr *DRE) { |
456 | // The use of a variable is trivial. |
457 | return true; |
458 | } |
459 | |
460 | // Constant literal expressions are always trivial |
461 | bool VisitIntegerLiteral(const IntegerLiteral *E) { return true; } |
462 | bool VisitFloatingLiteral(const FloatingLiteral *E) { return true; } |
463 | bool VisitFixedPointLiteral(const FixedPointLiteral *E) { return true; } |
464 | bool VisitCharacterLiteral(const CharacterLiteral *E) { return true; } |
465 | bool VisitStringLiteral(const StringLiteral *E) { return true; } |
466 | |
467 | bool VisitConstantExpr(const ConstantExpr *CE) { |
468 | // Constant expressions are trivial. |
469 | return true; |
470 | } |
471 | |
472 | private: |
473 | CacheTy &Cache; |
474 | }; |
475 | |
476 | bool TrivialFunctionAnalysis::isTrivialImpl( |
477 | const Decl *D, TrivialFunctionAnalysis::CacheTy &Cache) { |
478 | // If the function isn't in the cache, conservatively assume that |
479 | // it's not trivial until analysis completes. This makes every recursive |
480 | // function non-trivial. This also guarantees that each function |
481 | // will be scanned at most once. |
482 | auto [It, IsNew] = Cache.insert(KV: std::make_pair(x&: D, y: false)); |
483 | if (!IsNew) |
484 | return It->second; |
485 | |
486 | const Stmt *Body = D->getBody(); |
487 | if (!Body) |
488 | return false; |
489 | |
490 | TrivialFunctionAnalysisVisitor V(Cache); |
491 | bool Result = V.Visit(Body); |
492 | if (Result) |
493 | Cache[D] = true; |
494 | |
495 | return Result; |
496 | } |
497 | |
498 | bool TrivialFunctionAnalysis::isTrivialImpl( |
499 | const Stmt *S, TrivialFunctionAnalysis::CacheTy &Cache) { |
500 | // If the statement isn't in the cache, conservatively assume that |
501 | // it's not trivial until analysis completes. Unlike a function case, |
502 | // we don't insert an entry into the cache until Visit returns |
503 | // since Visit* functions themselves make use of the cache. |
504 | |
505 | TrivialFunctionAnalysisVisitor V(Cache); |
506 | bool Result = V.Visit(S); |
507 | assert(Cache.contains(S) && "Top-level statement not properly cached!" ); |
508 | return Result; |
509 | } |
510 | |
511 | } // namespace clang |
512 | |