1//===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Instrumentation-based profile-guided optimization
10//
11//===----------------------------------------------------------------------===//
12
13#include "CodeGenPGO.h"
14#include "CodeGenFunction.h"
15#include "CoverageMappingGen.h"
16#include "clang/AST/RecursiveASTVisitor.h"
17#include "clang/AST/StmtVisitor.h"
18#include "llvm/IR/Intrinsics.h"
19#include "llvm/IR/MDBuilder.h"
20#include "llvm/Support/CommandLine.h"
21#include "llvm/Support/Endian.h"
22#include "llvm/Support/FileSystem.h"
23#include "llvm/Support/MD5.h"
24#include <optional>
25
26namespace llvm {
27extern cl::opt<bool> EnableSingleByteCoverage;
28} // namespace llvm
29
30static llvm::cl::opt<bool>
31 EnableValueProfiling("enable-value-profiling",
32 llvm::cl::desc("Enable value profiling"),
33 llvm::cl::Hidden, llvm::cl::init(Val: false));
34
35extern llvm::cl::opt<bool> SystemHeadersCoverage;
36
37using namespace clang;
38using namespace CodeGen;
39
40void CodeGenPGO::setFuncName(StringRef Name,
41 llvm::GlobalValue::LinkageTypes Linkage) {
42 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
43 FuncName = llvm::getPGOFuncName(
44 RawFuncName: Name, Linkage, FileName: CGM.getCodeGenOpts().MainFileName,
45 Version: PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
46
47 // If we're generating a profile, create a variable for the name.
48 if (CGM.getCodeGenOpts().hasProfileClangInstr())
49 FuncNameVar = llvm::createPGOFuncNameVar(M&: CGM.getModule(), Linkage, PGOFuncName: FuncName);
50}
51
52void CodeGenPGO::setFuncName(llvm::Function *Fn) {
53 setFuncName(Name: Fn->getName(), Linkage: Fn->getLinkage());
54 // Create PGOFuncName meta data.
55 llvm::createPGOFuncNameMetadata(F&: *Fn, PGOFuncName: FuncName);
56}
57
/// Versions of the PGO hash algorithm, listed oldest first.
enum PGOHashVersion : unsigned {
  PGO_HASH_V1 = 0,
  PGO_HASH_V2 = 1,
  PGO_HASH_V3 = 2,

  // Alias for the newest version; update it whenever a version is added.
  PGO_HASH_LATEST = PGO_HASH_V3
};
67
68namespace {
69/// Stable hasher for PGO region counters.
70///
71/// PGOHash produces a stable hash of a given function's control flow.
72///
73/// Changing the output of this hash will invalidate all previously generated
74/// profiles -- i.e., don't do it.
75///
76/// \note When this hash does eventually change (years?), we still need to
77/// support old hashes. We'll need to pull in the version number from the
78/// profile data format and use the matching hash function.
79class PGOHash {
80 uint64_t Working;
81 unsigned Count;
82 PGOHashVersion HashVersion;
83 llvm::MD5 MD5;
84
85 static const int NumBitsPerType = 6;
86 static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
87 static const unsigned TooBig = 1u << NumBitsPerType;
88
89public:
90 /// Hash values for AST nodes.
91 ///
92 /// Distinct values for AST nodes that have region counters attached.
93 ///
94 /// These values must be stable. All new members must be added at the end,
95 /// and no members should be removed. Changing the enumeration value for an
96 /// AST node will affect the hash of every function that contains that node.
97 enum HashType : unsigned char {
98 None = 0,
99 LabelStmt = 1,
100 WhileStmt,
101 DoStmt,
102 ForStmt,
103 CXXForRangeStmt,
104 ObjCForCollectionStmt,
105 SwitchStmt,
106 CaseStmt,
107 DefaultStmt,
108 IfStmt,
109 CXXTryStmt,
110 CXXCatchStmt,
111 ConditionalOperator,
112 BinaryOperatorLAnd,
113 BinaryOperatorLOr,
114 BinaryConditionalOperator,
115 // The preceding values are available with PGO_HASH_V1.
116
117 EndOfScope,
118 IfThenBranch,
119 IfElseBranch,
120 GotoStmt,
121 IndirectGotoStmt,
122 BreakStmt,
123 ContinueStmt,
124 ReturnStmt,
125 ThrowExpr,
126 UnaryOperatorLNot,
127 BinaryOperatorLT,
128 BinaryOperatorGT,
129 BinaryOperatorLE,
130 BinaryOperatorGE,
131 BinaryOperatorEQ,
132 BinaryOperatorNE,
133 // The preceding values are available since PGO_HASH_V2.
134
135 // Keep this last. It's for the static assert that follows.
136 LastHashType
137 };
138 static_assert(LastHashType <= TooBig, "Too many types in HashType");
139
140 PGOHash(PGOHashVersion HashVersion)
141 : Working(0), Count(0), HashVersion(HashVersion) {}
142 void combine(HashType Type);
143 uint64_t finalize();
144 PGOHashVersion getHashVersion() const { return HashVersion; }
145};
146const int PGOHash::NumBitsPerType;
147const unsigned PGOHash::NumTypesPerWord;
148const unsigned PGOHash::TooBig;
149
150/// Get the PGO hash version used in the given indexed profile.
151static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
152 CodeGenModule &CGM) {
153 if (PGOReader->getVersion() <= 4)
154 return PGO_HASH_V1;
155 if (PGOReader->getVersion() <= 5)
156 return PGO_HASH_V2;
157 return PGO_HASH_V3;
158}
159
160/// A RecursiveASTVisitor that fills a map of statements to PGO counters.
161struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
162 using Base = RecursiveASTVisitor<MapRegionCounters>;
163
164 /// The next counter value to assign.
165 unsigned NextCounter;
166 /// The function hash.
167 PGOHash Hash;
168 /// The map of statements to counters.
169 llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
170 /// The next bitmap byte index to assign.
171 unsigned NextMCDCBitmapIdx;
172 /// The state of MC/DC Coverage in this function.
173 MCDC::State &MCDCState;
174 /// Maximum number of supported MC/DC conditions in a boolean expression.
175 unsigned MCDCMaxCond;
176 /// The profile version.
177 uint64_t ProfileVersion;
178 /// Diagnostics Engine used to report warnings.
179 DiagnosticsEngine &Diag;
180
181 MapRegionCounters(PGOHashVersion HashVersion, uint64_t ProfileVersion,
182 llvm::DenseMap<const Stmt *, unsigned> &CounterMap,
183 MCDC::State &MCDCState, unsigned MCDCMaxCond,
184 DiagnosticsEngine &Diag)
185 : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap),
186 NextMCDCBitmapIdx(0), MCDCState(MCDCState), MCDCMaxCond(MCDCMaxCond),
187 ProfileVersion(ProfileVersion), Diag(Diag) {}
188
189 // Blocks and lambdas are handled as separate functions, so we need not
190 // traverse them in the parent context.
191 bool TraverseBlockExpr(BlockExpr *BE) { return true; }
192 bool TraverseLambdaExpr(LambdaExpr *LE) {
193 // Traverse the captures, but not the body.
194 for (auto C : zip(t: LE->captures(), u: LE->capture_inits()))
195 TraverseLambdaCapture(LE, C: &std::get<0>(t&: C), Init: std::get<1>(t&: C));
196 return true;
197 }
198 bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
199
200 bool VisitDecl(const Decl *D) {
201 switch (D->getKind()) {
202 default:
203 break;
204 case Decl::Function:
205 case Decl::CXXMethod:
206 case Decl::CXXConstructor:
207 case Decl::CXXDestructor:
208 case Decl::CXXConversion:
209 case Decl::ObjCMethod:
210 case Decl::Block:
211 case Decl::Captured:
212 CounterMap[D->getBody()] = NextCounter++;
213 break;
214 }
215 return true;
216 }
217
218 /// If \p S gets a fresh counter, update the counter mappings. Return the
219 /// V1 hash of \p S.
220 PGOHash::HashType updateCounterMappings(Stmt *S) {
221 auto Type = getHashType(HashVersion: PGO_HASH_V1, S);
222 if (Type != PGOHash::None)
223 CounterMap[S] = NextCounter++;
224 return Type;
225 }
226
227 /// The following stacks are used with dataTraverseStmtPre() and
228 /// dataTraverseStmtPost() to track the depth of nested logical operators in a
229 /// boolean expression in a function. The ultimate purpose is to keep track
230 /// of the number of leaf-level conditions in the boolean expression so that a
231 /// profile bitmap can be allocated based on that number.
232 ///
233 /// The stacks are also used to find error cases and notify the user. A
234 /// standard logical operator nest for a boolean expression could be in a form
235 /// similar to this: "x = a && b && c && (d || f)"
236 unsigned NumCond = 0;
237 bool SplitNestedLogicalOp = false;
238 SmallVector<const Stmt *, 16> NonLogOpStack;
239 SmallVector<const BinaryOperator *, 16> LogOpStack;
240
241 // Hook: dataTraverseStmtPre() is invoked prior to visiting an AST Stmt node.
242 bool dataTraverseStmtPre(Stmt *S) {
243 /// If MC/DC is not enabled, MCDCMaxCond will be set to 0. Do nothing.
244 if (MCDCMaxCond == 0)
245 return true;
246
247 /// At the top of the logical operator nest, reset the number of conditions,
248 /// also forget previously seen split nesting cases.
249 if (LogOpStack.empty()) {
250 NumCond = 0;
251 SplitNestedLogicalOp = false;
252 }
253
254 if (const Expr *E = dyn_cast<Expr>(Val: S)) {
255 const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(Val: E->IgnoreParens());
256 if (BinOp && BinOp->isLogicalOp()) {
257 /// Check for "split-nested" logical operators. This happens when a new
258 /// boolean expression logical-op nest is encountered within an existing
259 /// boolean expression, separated by a non-logical operator. For
260 /// example, in "x = (a && b && c && foo(d && f))", the "d && f" case
261 /// starts a new boolean expression that is separated from the other
262 /// conditions by the operator foo(). Split-nested cases are not
263 /// supported by MC/DC.
264 SplitNestedLogicalOp = SplitNestedLogicalOp || !NonLogOpStack.empty();
265
266 LogOpStack.push_back(Elt: BinOp);
267 return true;
268 }
269 }
270
271 /// Keep track of non-logical operators. These are OK as long as we don't
272 /// encounter a new logical operator after seeing one.
273 if (!LogOpStack.empty())
274 NonLogOpStack.push_back(Elt: S);
275
276 return true;
277 }
278
279 // Hook: dataTraverseStmtPost() is invoked by the AST visitor after visiting
280 // an AST Stmt node. MC/DC will use it to to signal when the top of a
281 // logical operation (boolean expression) nest is encountered.
282 bool dataTraverseStmtPost(Stmt *S) {
283 /// If MC/DC is not enabled, MCDCMaxCond will be set to 0. Do nothing.
284 if (MCDCMaxCond == 0)
285 return true;
286
287 if (const Expr *E = dyn_cast<Expr>(Val: S)) {
288 const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(Val: E->IgnoreParens());
289 if (BinOp && BinOp->isLogicalOp()) {
290 assert(LogOpStack.back() == BinOp);
291 LogOpStack.pop_back();
292
293 /// At the top of logical operator nest:
294 if (LogOpStack.empty()) {
295 /// Was the "split-nested" logical operator case encountered?
296 if (SplitNestedLogicalOp) {
297 unsigned DiagID = Diag.getCustomDiagID(
298 L: DiagnosticsEngine::Warning,
299 FormatString: "unsupported MC/DC boolean expression; "
300 "contains an operation with a nested boolean expression. "
301 "Expression will not be covered");
302 Diag.Report(Loc: S->getBeginLoc(), DiagID);
303 return true;
304 }
305
306 /// Was the maximum number of conditions encountered?
307 if (NumCond > MCDCMaxCond) {
308 unsigned DiagID = Diag.getCustomDiagID(
309 L: DiagnosticsEngine::Warning,
310 FormatString: "unsupported MC/DC boolean expression; "
311 "number of conditions (%0) exceeds max (%1). "
312 "Expression will not be covered");
313 Diag.Report(Loc: S->getBeginLoc(), DiagID) << NumCond << MCDCMaxCond;
314 return true;
315 }
316
317 // Otherwise, allocate the number of bytes required for the bitmap
318 // based on the number of conditions. Must be at least 1-byte long.
319 MCDCState.DecisionByStmt[BinOp].BitmapIdx = NextMCDCBitmapIdx;
320 unsigned SizeInBits = std::max<unsigned>(a: 1L << NumCond, CHAR_BIT);
321 NextMCDCBitmapIdx += SizeInBits / CHAR_BIT;
322 }
323 return true;
324 }
325 }
326
327 if (!LogOpStack.empty())
328 NonLogOpStack.pop_back();
329
330 return true;
331 }
332
333 /// The RHS of all logical operators gets a fresh counter in order to count
334 /// how many times the RHS evaluates to true or false, depending on the
335 /// semantics of the operator. This is only valid for ">= v7" of the profile
336 /// version so that we facilitate backward compatibility. In addition, in
337 /// order to use MC/DC, count the number of total LHS and RHS conditions.
338 bool VisitBinaryOperator(BinaryOperator *S) {
339 if (S->isLogicalOp()) {
340 if (CodeGenFunction::isInstrumentedCondition(C: S->getLHS()))
341 NumCond++;
342
343 if (CodeGenFunction::isInstrumentedCondition(C: S->getRHS())) {
344 if (ProfileVersion >= llvm::IndexedInstrProf::Version7)
345 CounterMap[S->getRHS()] = NextCounter++;
346
347 NumCond++;
348 }
349 }
350 return Base::VisitBinaryOperator(S);
351 }
352
353 bool VisitConditionalOperator(ConditionalOperator *S) {
354 if (llvm::EnableSingleByteCoverage && S->getTrueExpr())
355 CounterMap[S->getTrueExpr()] = NextCounter++;
356 if (llvm::EnableSingleByteCoverage && S->getFalseExpr())
357 CounterMap[S->getFalseExpr()] = NextCounter++;
358 return Base::VisitConditionalOperator(S);
359 }
360
361 /// Include \p S in the function hash.
362 bool VisitStmt(Stmt *S) {
363 auto Type = updateCounterMappings(S);
364 if (Hash.getHashVersion() != PGO_HASH_V1)
365 Type = getHashType(HashVersion: Hash.getHashVersion(), S);
366 if (Type != PGOHash::None)
367 Hash.combine(Type);
368 return true;
369 }
370
371 bool TraverseIfStmt(IfStmt *If) {
372 // If we used the V1 hash, use the default traversal.
373 if (Hash.getHashVersion() == PGO_HASH_V1)
374 return Base::TraverseIfStmt(If);
375
376 // When single byte coverage mode is enabled, add a counter to then and
377 // else.
378 bool NoSingleByteCoverage = !llvm::EnableSingleByteCoverage;
379 for (Stmt *CS : If->children()) {
380 if (!CS || NoSingleByteCoverage)
381 continue;
382 if (CS == If->getThen())
383 CounterMap[If->getThen()] = NextCounter++;
384 else if (CS == If->getElse())
385 CounterMap[If->getElse()] = NextCounter++;
386 }
387
388 // Otherwise, keep track of which branch we're in while traversing.
389 VisitStmt(If);
390
391 for (Stmt *CS : If->children()) {
392 if (!CS)
393 continue;
394 if (CS == If->getThen())
395 Hash.combine(Type: PGOHash::IfThenBranch);
396 else if (CS == If->getElse())
397 Hash.combine(Type: PGOHash::IfElseBranch);
398 TraverseStmt(S: CS);
399 }
400 Hash.combine(Type: PGOHash::EndOfScope);
401 return true;
402 }
403
404 bool TraverseWhileStmt(WhileStmt *While) {
405 // When single byte coverage mode is enabled, add a counter to condition and
406 // body.
407 bool NoSingleByteCoverage = !llvm::EnableSingleByteCoverage;
408 for (Stmt *CS : While->children()) {
409 if (!CS || NoSingleByteCoverage)
410 continue;
411 if (CS == While->getCond())
412 CounterMap[While->getCond()] = NextCounter++;
413 else if (CS == While->getBody())
414 CounterMap[While->getBody()] = NextCounter++;
415 }
416
417 Base::TraverseWhileStmt(While);
418 if (Hash.getHashVersion() != PGO_HASH_V1)
419 Hash.combine(Type: PGOHash::EndOfScope);
420 return true;
421 }
422
423 bool TraverseDoStmt(DoStmt *Do) {
424 // When single byte coverage mode is enabled, add a counter to condition and
425 // body.
426 bool NoSingleByteCoverage = !llvm::EnableSingleByteCoverage;
427 for (Stmt *CS : Do->children()) {
428 if (!CS || NoSingleByteCoverage)
429 continue;
430 if (CS == Do->getCond())
431 CounterMap[Do->getCond()] = NextCounter++;
432 else if (CS == Do->getBody())
433 CounterMap[Do->getBody()] = NextCounter++;
434 }
435
436 Base::TraverseDoStmt(Do);
437 if (Hash.getHashVersion() != PGO_HASH_V1)
438 Hash.combine(Type: PGOHash::EndOfScope);
439 return true;
440 }
441
442 bool TraverseForStmt(ForStmt *For) {
443 // When single byte coverage mode is enabled, add a counter to condition,
444 // increment and body.
445 bool NoSingleByteCoverage = !llvm::EnableSingleByteCoverage;
446 for (Stmt *CS : For->children()) {
447 if (!CS || NoSingleByteCoverage)
448 continue;
449 if (CS == For->getCond())
450 CounterMap[For->getCond()] = NextCounter++;
451 else if (CS == For->getInc())
452 CounterMap[For->getInc()] = NextCounter++;
453 else if (CS == For->getBody())
454 CounterMap[For->getBody()] = NextCounter++;
455 }
456
457 Base::TraverseForStmt(For);
458 if (Hash.getHashVersion() != PGO_HASH_V1)
459 Hash.combine(Type: PGOHash::EndOfScope);
460 return true;
461 }
462
463 bool TraverseCXXForRangeStmt(CXXForRangeStmt *ForRange) {
464 // When single byte coverage mode is enabled, add a counter to body.
465 bool NoSingleByteCoverage = !llvm::EnableSingleByteCoverage;
466 for (Stmt *CS : ForRange->children()) {
467 if (!CS || NoSingleByteCoverage)
468 continue;
469 if (CS == ForRange->getBody())
470 CounterMap[ForRange->getBody()] = NextCounter++;
471 }
472
473 Base::TraverseCXXForRangeStmt(ForRange);
474 if (Hash.getHashVersion() != PGO_HASH_V1)
475 Hash.combine(Type: PGOHash::EndOfScope);
476 return true;
477 }
478
479// If the statement type \p N is nestable, and its nesting impacts profile
480// stability, define a custom traversal which tracks the end of the statement
481// in the hash (provided we're not using the V1 hash).
482#define DEFINE_NESTABLE_TRAVERSAL(N) \
483 bool Traverse##N(N *S) { \
484 Base::Traverse##N(S); \
485 if (Hash.getHashVersion() != PGO_HASH_V1) \
486 Hash.combine(PGOHash::EndOfScope); \
487 return true; \
488 }
489
490 DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
491 DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
492 DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)
493
494 /// Get version \p HashVersion of the PGO hash for \p S.
495 PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
496 switch (S->getStmtClass()) {
497 default:
498 break;
499 case Stmt::LabelStmtClass:
500 return PGOHash::LabelStmt;
501 case Stmt::WhileStmtClass:
502 return PGOHash::WhileStmt;
503 case Stmt::DoStmtClass:
504 return PGOHash::DoStmt;
505 case Stmt::ForStmtClass:
506 return PGOHash::ForStmt;
507 case Stmt::CXXForRangeStmtClass:
508 return PGOHash::CXXForRangeStmt;
509 case Stmt::ObjCForCollectionStmtClass:
510 return PGOHash::ObjCForCollectionStmt;
511 case Stmt::SwitchStmtClass:
512 return PGOHash::SwitchStmt;
513 case Stmt::CaseStmtClass:
514 return PGOHash::CaseStmt;
515 case Stmt::DefaultStmtClass:
516 return PGOHash::DefaultStmt;
517 case Stmt::IfStmtClass:
518 return PGOHash::IfStmt;
519 case Stmt::CXXTryStmtClass:
520 return PGOHash::CXXTryStmt;
521 case Stmt::CXXCatchStmtClass:
522 return PGOHash::CXXCatchStmt;
523 case Stmt::ConditionalOperatorClass:
524 return PGOHash::ConditionalOperator;
525 case Stmt::BinaryConditionalOperatorClass:
526 return PGOHash::BinaryConditionalOperator;
527 case Stmt::BinaryOperatorClass: {
528 const BinaryOperator *BO = cast<BinaryOperator>(Val: S);
529 if (BO->getOpcode() == BO_LAnd)
530 return PGOHash::BinaryOperatorLAnd;
531 if (BO->getOpcode() == BO_LOr)
532 return PGOHash::BinaryOperatorLOr;
533 if (HashVersion >= PGO_HASH_V2) {
534 switch (BO->getOpcode()) {
535 default:
536 break;
537 case BO_LT:
538 return PGOHash::BinaryOperatorLT;
539 case BO_GT:
540 return PGOHash::BinaryOperatorGT;
541 case BO_LE:
542 return PGOHash::BinaryOperatorLE;
543 case BO_GE:
544 return PGOHash::BinaryOperatorGE;
545 case BO_EQ:
546 return PGOHash::BinaryOperatorEQ;
547 case BO_NE:
548 return PGOHash::BinaryOperatorNE;
549 }
550 }
551 break;
552 }
553 }
554
555 if (HashVersion >= PGO_HASH_V2) {
556 switch (S->getStmtClass()) {
557 default:
558 break;
559 case Stmt::GotoStmtClass:
560 return PGOHash::GotoStmt;
561 case Stmt::IndirectGotoStmtClass:
562 return PGOHash::IndirectGotoStmt;
563 case Stmt::BreakStmtClass:
564 return PGOHash::BreakStmt;
565 case Stmt::ContinueStmtClass:
566 return PGOHash::ContinueStmt;
567 case Stmt::ReturnStmtClass:
568 return PGOHash::ReturnStmt;
569 case Stmt::CXXThrowExprClass:
570 return PGOHash::ThrowExpr;
571 case Stmt::UnaryOperatorClass: {
572 const UnaryOperator *UO = cast<UnaryOperator>(Val: S);
573 if (UO->getOpcode() == UO_LNot)
574 return PGOHash::UnaryOperatorLNot;
575 break;
576 }
577 }
578 }
579
580 return PGOHash::None;
581 }
582};
583
584/// A StmtVisitor that propagates the raw counts through the AST and
585/// records the count at statements where the value may change.
586struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
587 /// PGO state.
588 CodeGenPGO &PGO;
589
590 /// A flag that is set when the current count should be recorded on the
591 /// next statement, such as at the exit of a loop.
592 bool RecordNextStmtCount;
593
594 /// The count at the current location in the traversal.
595 uint64_t CurrentCount;
596
597 /// The map of statements to count values.
598 llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
599
600 /// BreakContinueStack - Keep counts of breaks and continues inside loops.
601 struct BreakContinue {
602 uint64_t BreakCount = 0;
603 uint64_t ContinueCount = 0;
604 BreakContinue() = default;
605 };
606 SmallVector<BreakContinue, 8> BreakContinueStack;
607
608 ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
609 CodeGenPGO &PGO)
610 : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
611
612 void RecordStmtCount(const Stmt *S) {
613 if (RecordNextStmtCount) {
614 CountMap[S] = CurrentCount;
615 RecordNextStmtCount = false;
616 }
617 }
618
619 /// Set and return the current count.
620 uint64_t setCount(uint64_t Count) {
621 CurrentCount = Count;
622 return Count;
623 }
624
625 void VisitStmt(const Stmt *S) {
626 RecordStmtCount(S);
627 for (const Stmt *Child : S->children())
628 if (Child)
629 this->Visit(Child);
630 }
631
632 void VisitFunctionDecl(const FunctionDecl *D) {
633 // Counter tracks entry to the function body.
634 uint64_t BodyCount = setCount(PGO.getRegionCount(S: D->getBody()));
635 CountMap[D->getBody()] = BodyCount;
636 Visit(D->getBody());
637 }
638
639 // Skip lambda expressions. We visit these as FunctionDecls when we're
640 // generating them and aren't interested in the body when generating a
641 // parent context.
642 void VisitLambdaExpr(const LambdaExpr *LE) {}
643
644 void VisitCapturedDecl(const CapturedDecl *D) {
645 // Counter tracks entry to the capture body.
646 uint64_t BodyCount = setCount(PGO.getRegionCount(S: D->getBody()));
647 CountMap[D->getBody()] = BodyCount;
648 Visit(D->getBody());
649 }
650
651 void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
652 // Counter tracks entry to the method body.
653 uint64_t BodyCount = setCount(PGO.getRegionCount(S: D->getBody()));
654 CountMap[D->getBody()] = BodyCount;
655 Visit(D->getBody());
656 }
657
658 void VisitBlockDecl(const BlockDecl *D) {
659 // Counter tracks entry to the block body.
660 uint64_t BodyCount = setCount(PGO.getRegionCount(S: D->getBody()));
661 CountMap[D->getBody()] = BodyCount;
662 Visit(D->getBody());
663 }
664
665 void VisitReturnStmt(const ReturnStmt *S) {
666 RecordStmtCount(S);
667 if (S->getRetValue())
668 Visit(S->getRetValue());
669 CurrentCount = 0;
670 RecordNextStmtCount = true;
671 }
672
673 void VisitCXXThrowExpr(const CXXThrowExpr *E) {
674 RecordStmtCount(E);
675 if (E->getSubExpr())
676 Visit(E->getSubExpr());
677 CurrentCount = 0;
678 RecordNextStmtCount = true;
679 }
680
681 void VisitGotoStmt(const GotoStmt *S) {
682 RecordStmtCount(S);
683 CurrentCount = 0;
684 RecordNextStmtCount = true;
685 }
686
687 void VisitLabelStmt(const LabelStmt *S) {
688 RecordNextStmtCount = false;
689 // Counter tracks the block following the label.
690 uint64_t BlockCount = setCount(PGO.getRegionCount(S));
691 CountMap[S] = BlockCount;
692 Visit(S->getSubStmt());
693 }
694
695 void VisitBreakStmt(const BreakStmt *S) {
696 RecordStmtCount(S);
697 assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
698 BreakContinueStack.back().BreakCount += CurrentCount;
699 CurrentCount = 0;
700 RecordNextStmtCount = true;
701 }
702
703 void VisitContinueStmt(const ContinueStmt *S) {
704 RecordStmtCount(S);
705 assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
706 BreakContinueStack.back().ContinueCount += CurrentCount;
707 CurrentCount = 0;
708 RecordNextStmtCount = true;
709 }
710
711 void VisitWhileStmt(const WhileStmt *S) {
712 RecordStmtCount(S);
713 uint64_t ParentCount = CurrentCount;
714
715 BreakContinueStack.push_back(Elt: BreakContinue());
716 // Visit the body region first so the break/continue adjustments can be
717 // included when visiting the condition.
718 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
719 CountMap[S->getBody()] = CurrentCount;
720 Visit(S->getBody());
721 uint64_t BackedgeCount = CurrentCount;
722
723 // ...then go back and propagate counts through the condition. The count
724 // at the start of the condition is the sum of the incoming edges,
725 // the backedge from the end of the loop body, and the edges from
726 // continue statements.
727 BreakContinue BC = BreakContinueStack.pop_back_val();
728 uint64_t CondCount =
729 setCount(ParentCount + BackedgeCount + BC.ContinueCount);
730 CountMap[S->getCond()] = CondCount;
731 Visit(S->getCond());
732 setCount(BC.BreakCount + CondCount - BodyCount);
733 RecordNextStmtCount = true;
734 }
735
736 void VisitDoStmt(const DoStmt *S) {
737 RecordStmtCount(S);
738 uint64_t LoopCount = PGO.getRegionCount(S);
739
740 BreakContinueStack.push_back(Elt: BreakContinue());
741 // The count doesn't include the fallthrough from the parent scope. Add it.
742 uint64_t BodyCount = setCount(LoopCount + CurrentCount);
743 CountMap[S->getBody()] = BodyCount;
744 Visit(S->getBody());
745 uint64_t BackedgeCount = CurrentCount;
746
747 BreakContinue BC = BreakContinueStack.pop_back_val();
748 // The count at the start of the condition is equal to the count at the
749 // end of the body, plus any continues.
750 uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
751 CountMap[S->getCond()] = CondCount;
752 Visit(S->getCond());
753 setCount(BC.BreakCount + CondCount - LoopCount);
754 RecordNextStmtCount = true;
755 }
756
757 void VisitForStmt(const ForStmt *S) {
758 RecordStmtCount(S);
759 if (S->getInit())
760 Visit(S->getInit());
761
762 uint64_t ParentCount = CurrentCount;
763
764 BreakContinueStack.push_back(Elt: BreakContinue());
765 // Visit the body region first. (This is basically the same as a while
766 // loop; see further comments in VisitWhileStmt.)
767 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
768 CountMap[S->getBody()] = BodyCount;
769 Visit(S->getBody());
770 uint64_t BackedgeCount = CurrentCount;
771 BreakContinue BC = BreakContinueStack.pop_back_val();
772
773 // The increment is essentially part of the body but it needs to include
774 // the count for all the continue statements.
775 if (S->getInc()) {
776 uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
777 CountMap[S->getInc()] = IncCount;
778 Visit(S->getInc());
779 }
780
781 // ...then go back and propagate counts through the condition.
782 uint64_t CondCount =
783 setCount(ParentCount + BackedgeCount + BC.ContinueCount);
784 if (S->getCond()) {
785 CountMap[S->getCond()] = CondCount;
786 Visit(S->getCond());
787 }
788 setCount(BC.BreakCount + CondCount - BodyCount);
789 RecordNextStmtCount = true;
790 }
791
792 void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
793 RecordStmtCount(S);
794 if (S->getInit())
795 Visit(S->getInit());
796 Visit(S->getLoopVarStmt());
797 Visit(S->getRangeStmt());
798 Visit(S->getBeginStmt());
799 Visit(S->getEndStmt());
800
801 uint64_t ParentCount = CurrentCount;
802 BreakContinueStack.push_back(Elt: BreakContinue());
803 // Visit the body region first. (This is basically the same as a while
804 // loop; see further comments in VisitWhileStmt.)
805 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
806 CountMap[S->getBody()] = BodyCount;
807 Visit(S->getBody());
808 uint64_t BackedgeCount = CurrentCount;
809 BreakContinue BC = BreakContinueStack.pop_back_val();
810
811 // The increment is essentially part of the body but it needs to include
812 // the count for all the continue statements.
813 uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
814 CountMap[S->getInc()] = IncCount;
815 Visit(S->getInc());
816
817 // ...then go back and propagate counts through the condition.
818 uint64_t CondCount =
819 setCount(ParentCount + BackedgeCount + BC.ContinueCount);
820 CountMap[S->getCond()] = CondCount;
821 Visit(S->getCond());
822 setCount(BC.BreakCount + CondCount - BodyCount);
823 RecordNextStmtCount = true;
824 }
825
826 void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
827 RecordStmtCount(S);
828 Visit(S->getElement());
829 uint64_t ParentCount = CurrentCount;
830 BreakContinueStack.push_back(Elt: BreakContinue());
831 // Counter tracks the body of the loop.
832 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
833 CountMap[S->getBody()] = BodyCount;
834 Visit(S->getBody());
835 uint64_t BackedgeCount = CurrentCount;
836 BreakContinue BC = BreakContinueStack.pop_back_val();
837
838 setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
839 BodyCount);
840 RecordNextStmtCount = true;
841 }
842
843 void VisitSwitchStmt(const SwitchStmt *S) {
844 RecordStmtCount(S);
845 if (S->getInit())
846 Visit(S->getInit());
847 Visit(S->getCond());
848 CurrentCount = 0;
849 BreakContinueStack.push_back(Elt: BreakContinue());
850 Visit(S->getBody());
851 // If the switch is inside a loop, add the continue counts.
852 BreakContinue BC = BreakContinueStack.pop_back_val();
853 if (!BreakContinueStack.empty())
854 BreakContinueStack.back().ContinueCount += BC.ContinueCount;
855 // Counter tracks the exit block of the switch.
856 setCount(PGO.getRegionCount(S));
857 RecordNextStmtCount = true;
858 }
859
860 void VisitSwitchCase(const SwitchCase *S) {
861 RecordNextStmtCount = false;
862 // Counter for this particular case. This counts only jumps from the
863 // switch header and does not include fallthrough from the case before
864 // this one.
865 uint64_t CaseCount = PGO.getRegionCount(S);
866 setCount(CurrentCount + CaseCount);
867 // We need the count without fallthrough in the mapping, so it's more useful
868 // for branch probabilities.
869 CountMap[S] = CaseCount;
870 RecordNextStmtCount = true;
871 Visit(S->getSubStmt());
872 }
873
874 void VisitIfStmt(const IfStmt *S) {
875 RecordStmtCount(S);
876
877 if (S->isConsteval()) {
878 const Stmt *Stm = S->isNegatedConsteval() ? S->getThen() : S->getElse();
879 if (Stm)
880 Visit(Stm);
881 return;
882 }
883
884 uint64_t ParentCount = CurrentCount;
885 if (S->getInit())
886 Visit(S->getInit());
887 Visit(S->getCond());
888
889 // Counter tracks the "then" part of an if statement. The count for
890 // the "else" part, if it exists, will be calculated from this counter.
891 uint64_t ThenCount = setCount(PGO.getRegionCount(S));
892 CountMap[S->getThen()] = ThenCount;
893 Visit(S->getThen());
894 uint64_t OutCount = CurrentCount;
895
896 uint64_t ElseCount = ParentCount - ThenCount;
897 if (S->getElse()) {
898 setCount(ElseCount);
899 CountMap[S->getElse()] = ElseCount;
900 Visit(S->getElse());
901 OutCount += CurrentCount;
902 } else
903 OutCount += ElseCount;
904 setCount(OutCount);
905 RecordNextStmtCount = true;
906 }
907
908 void VisitCXXTryStmt(const CXXTryStmt *S) {
909 RecordStmtCount(S);
910 Visit(S->getTryBlock());
911 for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
912 Visit(S->getHandler(i: I));
913 // Counter tracks the continuation block of the try statement.
914 setCount(PGO.getRegionCount(S));
915 RecordNextStmtCount = true;
916 }
917
918 void VisitCXXCatchStmt(const CXXCatchStmt *S) {
919 RecordNextStmtCount = false;
920 // Counter tracks the catch statement's handler block.
921 uint64_t CatchCount = setCount(PGO.getRegionCount(S));
922 CountMap[S] = CatchCount;
923 Visit(S->getHandlerBlock());
924 }
925
926 void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
927 RecordStmtCount(E);
928 uint64_t ParentCount = CurrentCount;
929 Visit(E->getCond());
930
931 // Counter tracks the "true" part of a conditional operator. The
932 // count in the "false" part will be calculated from this counter.
933 uint64_t TrueCount = setCount(PGO.getRegionCount(E));
934 CountMap[E->getTrueExpr()] = TrueCount;
935 Visit(E->getTrueExpr());
936 uint64_t OutCount = CurrentCount;
937
938 uint64_t FalseCount = setCount(ParentCount - TrueCount);
939 CountMap[E->getFalseExpr()] = FalseCount;
940 Visit(E->getFalseExpr());
941 OutCount += CurrentCount;
942
943 setCount(OutCount);
944 RecordNextStmtCount = true;
945 }
946
947 void VisitBinLAnd(const BinaryOperator *E) {
948 RecordStmtCount(E);
949 uint64_t ParentCount = CurrentCount;
950 Visit(E->getLHS());
951 // Counter tracks the right hand side of a logical and operator.
952 uint64_t RHSCount = setCount(PGO.getRegionCount(E));
953 CountMap[E->getRHS()] = RHSCount;
954 Visit(E->getRHS());
955 setCount(ParentCount + RHSCount - CurrentCount);
956 RecordNextStmtCount = true;
957 }
958
959 void VisitBinLOr(const BinaryOperator *E) {
960 RecordStmtCount(E);
961 uint64_t ParentCount = CurrentCount;
962 Visit(E->getLHS());
963 // Counter tracks the right hand side of a logical or operator.
964 uint64_t RHSCount = setCount(PGO.getRegionCount(E));
965 CountMap[E->getRHS()] = RHSCount;
966 Visit(E->getRHS());
967 setCount(ParentCount + RHSCount - CurrentCount);
968 RecordNextStmtCount = true;
969 }
970};
971} // end anonymous namespace
972
973void PGOHash::combine(HashType Type) {
974 // Check that we never combine 0 and only have six bits.
975 assert(Type && "Hash is invalid: unexpected type 0");
976 assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
977
978 // Pass through MD5 if enough work has built up.
979 if (Count && Count % NumTypesPerWord == 0) {
980 using namespace llvm::support;
981 uint64_t Swapped =
982 endian::byte_swap<uint64_t, llvm::endianness::little>(value: Working);
983 MD5.update(Data: llvm::ArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
984 Working = 0;
985 }
986
987 // Accumulate the current type.
988 ++Count;
989 Working = Working << NumBitsPerType | Type;
990}
991
992uint64_t PGOHash::finalize() {
993 // Use Working as the hash directly if we never used MD5.
994 if (Count <= NumTypesPerWord)
995 // No need to byte swap here, since none of the math was endian-dependent.
996 // This number will be byte-swapped as required on endianness transitions,
997 // so we will see the same value on the other side.
998 return Working;
999
1000 // Check for remaining work in Working.
1001 if (Working) {
1002 // Keep the buggy behavior from v1 and v2 for backward-compatibility. This
1003 // is buggy because it converts a uint64_t into an array of uint8_t.
1004 if (HashVersion < PGO_HASH_V3) {
1005 MD5.update(Data: {(uint8_t)Working});
1006 } else {
1007 using namespace llvm::support;
1008 uint64_t Swapped =
1009 endian::byte_swap<uint64_t, llvm::endianness::little>(value: Working);
1010 MD5.update(Data: llvm::ArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
1011 }
1012 }
1013
1014 // Finalize the MD5 and return the hash.
1015 llvm::MD5::MD5Result Result;
1016 MD5.final(Result);
1017 return Result.low();
1018}
1019
1020void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
1021 const Decl *D = GD.getDecl();
1022 if (!D->hasBody())
1023 return;
1024
1025 // Skip CUDA/HIP kernel launch stub functions.
1026 if (CGM.getLangOpts().CUDA && !CGM.getLangOpts().CUDAIsDevice &&
1027 D->hasAttr<CUDAGlobalAttr>())
1028 return;
1029
1030 bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
1031 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
1032 if (!InstrumentRegions && !PGOReader)
1033 return;
1034 if (D->isImplicit())
1035 return;
1036 // Constructors and destructors may be represented by several functions in IR.
1037 // If so, instrument only base variant, others are implemented by delegation
1038 // to the base one, it would be counted twice otherwise.
1039 if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
1040 if (const auto *CCD = dyn_cast<CXXConstructorDecl>(Val: D))
1041 if (GD.getCtorType() != Ctor_Base &&
1042 CodeGenFunction::IsConstructorDelegationValid(Ctor: CCD))
1043 return;
1044 }
1045 if (isa<CXXDestructorDecl>(Val: D) && GD.getDtorType() != Dtor_Base)
1046 return;
1047
1048 CGM.ClearUnusedCoverageMapping(D);
1049 if (Fn->hasFnAttribute(llvm::Attribute::NoProfile))
1050 return;
1051 if (Fn->hasFnAttribute(llvm::Attribute::SkipProfile))
1052 return;
1053
1054 setFuncName(Fn);
1055
1056 mapRegionCounters(D);
1057 if (CGM.getCodeGenOpts().CoverageMapping)
1058 emitCounterRegionMapping(D);
1059 if (PGOReader) {
1060 SourceManager &SM = CGM.getContext().getSourceManager();
1061 loadRegionCounts(PGOReader, IsInMainFile: SM.isInMainFile(Loc: D->getLocation()));
1062 computeRegionCounts(D);
1063 applyFunctionAttributes(PGOReader, Fn);
1064 }
1065}
1066
1067void CodeGenPGO::mapRegionCounters(const Decl *D) {
1068 // Use the latest hash version when inserting instrumentation, but use the
1069 // version in the indexed profile if we're reading PGO data.
1070 PGOHashVersion HashVersion = PGO_HASH_LATEST;
1071 uint64_t ProfileVersion = llvm::IndexedInstrProf::Version;
1072 if (auto *PGOReader = CGM.getPGOReader()) {
1073 HashVersion = getPGOHashVersion(PGOReader, CGM);
1074 ProfileVersion = PGOReader->getVersion();
1075 }
1076
1077 // If MC/DC is enabled, set the MaxConditions to a preset value. Otherwise,
1078 // set it to zero. This value impacts the number of conditions accepted in a
1079 // given boolean expression, which impacts the size of the bitmap used to
1080 // track test vector execution for that boolean expression. Because the
1081 // bitmap scales exponentially (2^n) based on the number of conditions seen,
1082 // the maximum value is hard-coded at 6 conditions, which is more than enough
1083 // for most embedded applications. Setting a maximum value prevents the
1084 // bitmap footprint from growing too large without the user's knowledge. In
1085 // the future, this value could be adjusted with a command-line option.
1086 unsigned MCDCMaxConditions = (CGM.getCodeGenOpts().MCDCCoverage) ? 6 : 0;
1087
1088 RegionCounterMap.reset(p: new llvm::DenseMap<const Stmt *, unsigned>);
1089 RegionMCDCState.reset(p: new MCDC::State);
1090 MapRegionCounters Walker(HashVersion, ProfileVersion, *RegionCounterMap,
1091 *RegionMCDCState, MCDCMaxConditions, CGM.getDiags());
1092 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(Val: D))
1093 Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
1094 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(Val: D))
1095 Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
1096 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(Val: D))
1097 Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
1098 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(Val: D))
1099 Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
1100 assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
1101 NumRegionCounters = Walker.NextCounter;
1102 RegionMCDCState->BitmapBytes = Walker.NextMCDCBitmapIdx;
1103 FunctionHash = Walker.Hash.finalize();
1104}
1105
1106bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
1107 if (!D->getBody())
1108 return true;
1109
1110 // Skip host-only functions in the CUDA device compilation and device-only
1111 // functions in the host compilation. Just roughly filter them out based on
1112 // the function attributes. If there are effectively host-only or device-only
1113 // ones, their coverage mapping may still be generated.
1114 if (CGM.getLangOpts().CUDA &&
1115 ((CGM.getLangOpts().CUDAIsDevice && !D->hasAttr<CUDADeviceAttr>() &&
1116 !D->hasAttr<CUDAGlobalAttr>()) ||
1117 (!CGM.getLangOpts().CUDAIsDevice &&
1118 (D->hasAttr<CUDAGlobalAttr>() ||
1119 (!D->hasAttr<CUDAHostAttr>() && D->hasAttr<CUDADeviceAttr>())))))
1120 return true;
1121
1122 // Don't map the functions in system headers.
1123 const auto &SM = CGM.getContext().getSourceManager();
1124 auto Loc = D->getBody()->getBeginLoc();
1125 return !SystemHeadersCoverage && SM.isInSystemHeader(Loc);
1126}
1127
1128void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
1129 if (skipRegionMappingForDecl(D))
1130 return;
1131
1132 std::string CoverageMapping;
1133 llvm::raw_string_ostream OS(CoverageMapping);
1134 RegionMCDCState->BranchByStmt.clear();
1135 CoverageMappingGen MappingGen(
1136 *CGM.getCoverageMapping(), CGM.getContext().getSourceManager(),
1137 CGM.getLangOpts(), RegionCounterMap.get(), RegionMCDCState.get());
1138 MappingGen.emitCounterMapping(D, OS);
1139 OS.flush();
1140
1141 if (CoverageMapping.empty())
1142 return;
1143
1144 CGM.getCoverageMapping()->addFunctionMappingRecord(
1145 FunctionName: FuncNameVar, FunctionNameValue: FuncName, FunctionHash, CoverageMapping);
1146}
1147
1148void
1149CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
1150 llvm::GlobalValue::LinkageTypes Linkage) {
1151 if (skipRegionMappingForDecl(D))
1152 return;
1153
1154 std::string CoverageMapping;
1155 llvm::raw_string_ostream OS(CoverageMapping);
1156 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
1157 CGM.getContext().getSourceManager(),
1158 CGM.getLangOpts());
1159 MappingGen.emitEmptyMapping(D, OS);
1160 OS.flush();
1161
1162 if (CoverageMapping.empty())
1163 return;
1164
1165 setFuncName(Name, Linkage);
1166 CGM.getCoverageMapping()->addFunctionMappingRecord(
1167 FunctionName: FuncNameVar, FunctionNameValue: FuncName, FunctionHash, CoverageMapping, IsUsed: false);
1168}
1169
1170void CodeGenPGO::computeRegionCounts(const Decl *D) {
1171 StmtCountMap.reset(p: new llvm::DenseMap<const Stmt *, uint64_t>);
1172 ComputeRegionCounts Walker(*StmtCountMap, *this);
1173 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(Val: D))
1174 Walker.VisitFunctionDecl(D: FD);
1175 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(Val: D))
1176 Walker.VisitObjCMethodDecl(D: MD);
1177 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(Val: D))
1178 Walker.VisitBlockDecl(D: BD);
1179 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(Val: D))
1180 Walker.VisitCapturedDecl(D: const_cast<CapturedDecl *>(CD));
1181}
1182
1183void
1184CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
1185 llvm::Function *Fn) {
1186 if (!haveRegionCounts())
1187 return;
1188
1189 uint64_t FunctionCount = getRegionCount(S: nullptr);
1190 Fn->setEntryCount(Count: FunctionCount);
1191}
1192
1193void CodeGenPGO::emitCounterSetOrIncrement(CGBuilderTy &Builder, const Stmt *S,
1194 llvm::Value *StepV) {
1195 if (!RegionCounterMap || !Builder.GetInsertBlock())
1196 return;
1197
1198 unsigned Counter = (*RegionCounterMap)[S];
1199
1200 llvm::Value *Args[] = {FuncNameVar,
1201 Builder.getInt64(C: FunctionHash),
1202 Builder.getInt32(C: NumRegionCounters),
1203 Builder.getInt32(C: Counter), StepV};
1204
1205 if (llvm::EnableSingleByteCoverage)
1206 Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_cover),
1207 ArrayRef(Args, 4));
1208 else {
1209 if (!StepV)
1210 Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
1211 ArrayRef(Args, 4));
1212 else
1213 Builder.CreateCall(
1214 CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
1215 ArrayRef(Args));
1216 }
1217}
1218
1219bool CodeGenPGO::canEmitMCDCCoverage(const CGBuilderTy &Builder) {
1220 return (CGM.getCodeGenOpts().hasProfileClangInstr() &&
1221 CGM.getCodeGenOpts().MCDCCoverage && Builder.GetInsertBlock());
1222}
1223
1224void CodeGenPGO::emitMCDCParameters(CGBuilderTy &Builder) {
1225 if (!canEmitMCDCCoverage(Builder) || !RegionMCDCState)
1226 return;
1227
1228 auto *I8PtrTy = llvm::PointerType::getUnqual(C&: CGM.getLLVMContext());
1229
1230 // Emit intrinsic representing MCDC bitmap parameters at function entry.
1231 // This is used by the instrumentation pass, but it isn't actually lowered to
1232 // anything.
1233 llvm::Value *Args[3] = {llvm::ConstantExpr::getBitCast(C: FuncNameVar, Ty: I8PtrTy),
1234 Builder.getInt64(C: FunctionHash),
1235 Builder.getInt32(C: RegionMCDCState->BitmapBytes)};
1236 Builder.CreateCall(
1237 CGM.getIntrinsic(llvm::Intrinsic::instrprof_mcdc_parameters), Args);
1238}
1239
1240void CodeGenPGO::emitMCDCTestVectorBitmapUpdate(CGBuilderTy &Builder,
1241 const Expr *S,
1242 Address MCDCCondBitmapAddr,
1243 CodeGenFunction &CGF) {
1244 if (!canEmitMCDCCoverage(Builder) || !RegionMCDCState)
1245 return;
1246
1247 S = S->IgnoreParens();
1248
1249 auto DecisionStateIter = RegionMCDCState->DecisionByStmt.find(S);
1250 if (DecisionStateIter == RegionMCDCState->DecisionByStmt.end())
1251 return;
1252
1253 // Extract the offset of the global bitmap associated with this expression.
1254 unsigned MCDCTestVectorBitmapOffset = DecisionStateIter->second.BitmapIdx;
1255 auto *I8PtrTy = llvm::PointerType::getUnqual(C&: CGM.getLLVMContext());
1256
1257 // Emit intrinsic responsible for updating the global bitmap corresponding to
1258 // a boolean expression. The index being set is based on the value loaded
1259 // from a pointer to a dedicated temporary value on the stack that is itself
1260 // updated via emitMCDCCondBitmapReset() and emitMCDCCondBitmapUpdate(). The
1261 // index represents an executed test vector.
1262 llvm::Value *Args[5] = {llvm::ConstantExpr::getBitCast(C: FuncNameVar, Ty: I8PtrTy),
1263 Builder.getInt64(C: FunctionHash),
1264 Builder.getInt32(C: RegionMCDCState->BitmapBytes),
1265 Builder.getInt32(C: MCDCTestVectorBitmapOffset),
1266 MCDCCondBitmapAddr.emitRawPointer(CGF)};
1267 Builder.CreateCall(
1268 CGM.getIntrinsic(llvm::Intrinsic::instrprof_mcdc_tvbitmap_update), Args);
1269}
1270
1271void CodeGenPGO::emitMCDCCondBitmapReset(CGBuilderTy &Builder, const Expr *S,
1272 Address MCDCCondBitmapAddr) {
1273 if (!canEmitMCDCCoverage(Builder) || !RegionMCDCState)
1274 return;
1275
1276 S = S->IgnoreParens();
1277
1278 if (!RegionMCDCState->DecisionByStmt.contains(S))
1279 return;
1280
1281 // Emit intrinsic that resets a dedicated temporary value on the stack to 0.
1282 Builder.CreateStore(Val: Builder.getInt32(C: 0), Addr: MCDCCondBitmapAddr);
1283}
1284
1285void CodeGenPGO::emitMCDCCondBitmapUpdate(CGBuilderTy &Builder, const Expr *S,
1286 Address MCDCCondBitmapAddr,
1287 llvm::Value *Val,
1288 CodeGenFunction &CGF) {
1289 if (!canEmitMCDCCoverage(Builder) || !RegionMCDCState)
1290 return;
1291
1292 // Even though, for simplicity, parentheses and unary logical-NOT operators
1293 // are considered part of their underlying condition for both MC/DC and
1294 // branch coverage, the condition IDs themselves are assigned and tracked
1295 // using the underlying condition itself. This is done solely for
1296 // consistency since parentheses and logical-NOTs are ignored when checking
1297 // whether the condition is actually an instrumentable condition. This can
1298 // also make debugging a bit easier.
1299 S = CodeGenFunction::stripCond(C: S);
1300
1301 auto BranchStateIter = RegionMCDCState->BranchByStmt.find(S);
1302 if (BranchStateIter == RegionMCDCState->BranchByStmt.end())
1303 return;
1304
1305 // Extract the ID of the condition we are setting in the bitmap.
1306 const auto &Branch = BranchStateIter->second;
1307 assert(Branch.ID >= 0 && "Condition has no ID!");
1308
1309 auto *I8PtrTy = llvm::PointerType::getUnqual(C&: CGM.getLLVMContext());
1310
1311 // Emit intrinsic that updates a dedicated temporary value on the stack after
1312 // a condition is evaluated. After the set of conditions has been updated,
1313 // the resulting value is used to update the boolean expression's bitmap.
1314 llvm::Value *Args[5] = {llvm::ConstantExpr::getBitCast(C: FuncNameVar, Ty: I8PtrTy),
1315 Builder.getInt64(C: FunctionHash),
1316 Builder.getInt32(C: Branch.ID),
1317 MCDCCondBitmapAddr.emitRawPointer(CGF), Val};
1318 Builder.CreateCall(
1319 CGM.getIntrinsic(llvm::Intrinsic::instrprof_mcdc_condbitmap_update),
1320 Args);
1321}
1322
1323void CodeGenPGO::setValueProfilingFlag(llvm::Module &M) {
1324 if (CGM.getCodeGenOpts().hasProfileClangInstr())
1325 M.addModuleFlag(Behavior: llvm::Module::Warning, Key: "EnableValueProfiling",
1326 Val: uint32_t(EnableValueProfiling));
1327}
1328
1329void CodeGenPGO::setProfileVersion(llvm::Module &M) {
1330 if (CGM.getCodeGenOpts().hasProfileClangInstr() &&
1331 llvm::EnableSingleByteCoverage) {
1332 const StringRef VarName(INSTR_PROF_QUOTE(INSTR_PROF_RAW_VERSION_VAR));
1333 llvm::Type *IntTy64 = llvm::Type::getInt64Ty(C&: M.getContext());
1334 uint64_t ProfileVersion =
1335 (INSTR_PROF_RAW_VERSION | VARIANT_MASK_BYTE_COVERAGE);
1336
1337 auto IRLevelVersionVariable = new llvm::GlobalVariable(
1338 M, IntTy64, true, llvm::GlobalValue::WeakAnyLinkage,
1339 llvm::Constant::getIntegerValue(Ty: IntTy64,
1340 V: llvm::APInt(64, ProfileVersion)),
1341 VarName);
1342
1343 IRLevelVersionVariable->setVisibility(llvm::GlobalValue::DefaultVisibility);
1344 llvm::Triple TT(M.getTargetTriple());
1345 if (TT.supportsCOMDAT()) {
1346 IRLevelVersionVariable->setLinkage(llvm::GlobalValue::ExternalLinkage);
1347 IRLevelVersionVariable->setComdat(M.getOrInsertComdat(Name: VarName));
1348 }
1349 IRLevelVersionVariable->setDSOLocal(true);
1350 }
1351}
1352
1353// This method either inserts a call to the profile run-time during
1354// instrumentation or puts profile data into metadata for PGO use.
1355void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
1356 llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {
1357
1358 if (!EnableValueProfiling)
1359 return;
1360
1361 if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
1362 return;
1363
1364 if (isa<llvm::Constant>(Val: ValuePtr))
1365 return;
1366
1367 bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
1368 if (InstrumentValueSites && RegionCounterMap) {
1369 auto BuilderInsertPoint = Builder.saveIP();
1370 Builder.SetInsertPoint(ValueSite);
1371 llvm::Value *Args[5] = {
1372 FuncNameVar,
1373 Builder.getInt64(C: FunctionHash),
1374 Builder.CreatePtrToInt(V: ValuePtr, DestTy: Builder.getInt64Ty()),
1375 Builder.getInt32(C: ValueKind),
1376 Builder.getInt32(C: NumValueSites[ValueKind]++)
1377 };
1378 Builder.CreateCall(
1379 CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
1380 Builder.restoreIP(IP: BuilderInsertPoint);
1381 return;
1382 }
1383
1384 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
1385 if (PGOReader && haveRegionCounts()) {
1386 // We record the top most called three functions at each call site.
1387 // Profile metadata contains "VP" string identifying this metadata
1388 // as value profiling data, then a uint32_t value for the value profiling
1389 // kind, a uint64_t value for the total number of times the call is
1390 // executed, followed by the function hash and execution count (uint64_t)
1391 // pairs for each function.
1392 if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
1393 return;
1394
1395 llvm::annotateValueSite(M&: CGM.getModule(), Inst&: *ValueSite, InstrProfR: *ProfRecord,
1396 ValueKind: (llvm::InstrProfValueKind)ValueKind,
1397 SiteIndx: NumValueSites[ValueKind]);
1398
1399 NumValueSites[ValueKind]++;
1400 }
1401}
1402
1403void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
1404 bool IsInMainFile) {
1405 CGM.getPGOStats().addVisited(MainFile: IsInMainFile);
1406 RegionCounts.clear();
1407 llvm::Expected<llvm::InstrProfRecord> RecordExpected =
1408 PGOReader->getInstrProfRecord(FuncName, FuncHash: FunctionHash);
1409 if (auto E = RecordExpected.takeError()) {
1410 auto IPE = std::get<0>(in: llvm::InstrProfError::take(E: std::move(E)));
1411 if (IPE == llvm::instrprof_error::unknown_function)
1412 CGM.getPGOStats().addMissing(MainFile: IsInMainFile);
1413 else if (IPE == llvm::instrprof_error::hash_mismatch)
1414 CGM.getPGOStats().addMismatched(MainFile: IsInMainFile);
1415 else if (IPE == llvm::instrprof_error::malformed)
1416 // TODO: Consider a more specific warning for this case.
1417 CGM.getPGOStats().addMismatched(MainFile: IsInMainFile);
1418 return;
1419 }
1420 ProfRecord =
1421 std::make_unique<llvm::InstrProfRecord>(args: std::move(RecordExpected.get()));
1422 RegionCounts = ProfRecord->Counts;
1423}
1424
/// Calculate what to divide by to scale weights.
///
/// Given the maximum weight, calculate a divisor that will scale all the
/// weights to strictly less than UINT32_MAX. A maximum already below
/// UINT32_MAX needs no scaling, so the divisor is 1.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  if (MaxWeight < UINT32_MAX)
    return 1;
  return MaxWeight / UINT32_MAX + 1;
}
1432
/// Scale an individual branch weight (and add 1).
///
/// Scale a 64-bit weight down to 32-bits using \c Scale.
///
/// According to Laplace's Rule of Succession, it is better to compute the
/// weight based on the count plus 1, so universally add 1 to the value.
///
/// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
/// greater than \c Weight.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  const uint64_t Result = Weight / Scale + 1;
  assert(Result <= UINT32_MAX && "overflow 32-bits");
  return static_cast<uint32_t>(Result);
}
1448
1449llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
1450 uint64_t FalseCount) const {
1451 // Check for empty weights.
1452 if (!TrueCount && !FalseCount)
1453 return nullptr;
1454
1455 // Calculate how to scale down to 32-bits.
1456 uint64_t Scale = calculateWeightScale(MaxWeight: std::max(a: TrueCount, b: FalseCount));
1457
1458 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1459 return MDHelper.createBranchWeights(TrueWeight: scaleBranchWeight(Weight: TrueCount, Scale),
1460 FalseWeight: scaleBranchWeight(Weight: FalseCount, Scale));
1461}
1462
1463llvm::MDNode *
1464CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) const {
1465 // We need at least two elements to create meaningful weights.
1466 if (Weights.size() < 2)
1467 return nullptr;
1468
1469 // Check for empty weights.
1470 uint64_t MaxWeight = *std::max_element(first: Weights.begin(), last: Weights.end());
1471 if (MaxWeight == 0)
1472 return nullptr;
1473
1474 // Calculate how to scale down to 32-bits.
1475 uint64_t Scale = calculateWeightScale(MaxWeight);
1476
1477 SmallVector<uint32_t, 16> ScaledWeights;
1478 ScaledWeights.reserve(N: Weights.size());
1479 for (uint64_t W : Weights)
1480 ScaledWeights.push_back(Elt: scaleBranchWeight(Weight: W, Scale));
1481
1482 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1483 return MDHelper.createBranchWeights(Weights: ScaledWeights);
1484}
1485
1486llvm::MDNode *
1487CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
1488 uint64_t LoopCount) const {
1489 if (!PGO.haveRegionCounts())
1490 return nullptr;
1491 std::optional<uint64_t> CondCount = PGO.getStmtCount(S: Cond);
1492 if (!CondCount || *CondCount == 0)
1493 return nullptr;
1494 return createProfileWeights(TrueCount: LoopCount,
1495 FalseCount: std::max(a: *CondCount, b: LoopCount) - LoopCount);
1496}
1497

source code of clang/lib/CodeGen/CodeGenPGO.cpp