1//===- bolt/Passes/IndirectCallPromotion.cpp ------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the IndirectCallPromotion class.
10//
11//===----------------------------------------------------------------------===//
12
13#include "bolt/Passes/IndirectCallPromotion.h"
14#include "bolt/Core/BinaryFunctionCallGraph.h"
15#include "bolt/Passes/DataflowInfoManager.h"
16#include "bolt/Passes/Inliner.h"
17#include "llvm/ADT/STLExtras.h"
18#include "llvm/Support/CommandLine.h"
19#include <iterator>
20
21#define DEBUG_TYPE "ICP"
22#define DEBUG_VERBOSE(Level, X) \
23 if (opts::Verbosity >= (Level)) { \
24 X; \
25 }
26
27using namespace llvm;
28using namespace bolt;
29
30namespace opts {
31
32extern cl::OptionCategory BoltOptCategory;
33
34extern cl::opt<IndirectCallPromotionType> ICP;
35extern cl::opt<unsigned> Verbosity;
36extern cl::opt<unsigned> ExecutionCountThreshold;
37
38static cl::opt<unsigned> ICPJTRemainingPercentThreshold(
39 "icp-jt-remaining-percent-threshold",
40 cl::desc("The percentage threshold against remaining unpromoted indirect "
41 "call count for the promotion for jump tables"),
42 cl::init(Val: 30), cl::ZeroOrMore, cl::Hidden, cl::cat(BoltOptCategory));
43
44static cl::opt<unsigned> ICPJTTotalPercentThreshold(
45 "icp-jt-total-percent-threshold",
46 cl::desc(
47 "The percentage threshold against total count for the promotion for "
48 "jump tables"),
49 cl::init(Val: 5), cl::Hidden, cl::cat(BoltOptCategory));
50
51static cl::opt<unsigned> ICPCallsRemainingPercentThreshold(
52 "icp-calls-remaining-percent-threshold",
53 cl::desc("The percentage threshold against remaining unpromoted indirect "
54 "call count for the promotion for calls"),
55 cl::init(Val: 50), cl::Hidden, cl::cat(BoltOptCategory));
56
57static cl::opt<unsigned> ICPCallsTotalPercentThreshold(
58 "icp-calls-total-percent-threshold",
59 cl::desc(
60 "The percentage threshold against total count for the promotion for "
61 "calls"),
62 cl::init(Val: 30), cl::Hidden, cl::cat(BoltOptCategory));
63
64static cl::opt<unsigned> ICPMispredictThreshold(
65 "indirect-call-promotion-mispredict-threshold",
66 cl::desc("misprediction threshold for skipping ICP on an "
67 "indirect call"),
68 cl::init(Val: 0), cl::cat(BoltOptCategory));
69
70static cl::alias ICPMispredictThresholdAlias(
71 "icp-mp-threshold",
72 cl::desc("alias for --indirect-call-promotion-mispredict-threshold"),
73 cl::aliasopt(ICPMispredictThreshold));
74
75static cl::opt<bool> ICPUseMispredicts(
76 "indirect-call-promotion-use-mispredicts",
77 cl::desc("use misprediction frequency for determining whether or not ICP "
78 "should be applied at a callsite. The "
79 "-indirect-call-promotion-mispredict-threshold value will be used "
80 "by this heuristic"),
81 cl::cat(BoltOptCategory));
82
83static cl::alias ICPUseMispredictsAlias(
84 "icp-use-mp",
85 cl::desc("alias for --indirect-call-promotion-use-mispredicts"),
86 cl::aliasopt(ICPUseMispredicts));
87
88static cl::opt<unsigned>
89 ICPTopN("indirect-call-promotion-topn",
90 cl::desc("limit number of targets to consider when doing indirect "
91 "call promotion. 0 = no limit"),
92 cl::init(Val: 3), cl::cat(BoltOptCategory));
93
94static cl::alias
95 ICPTopNAlias("icp-topn",
96 cl::desc("alias for --indirect-call-promotion-topn"),
97 cl::aliasopt(ICPTopN));
98
99static cl::opt<unsigned> ICPCallsTopN(
100 "indirect-call-promotion-calls-topn",
101 cl::desc("limit number of targets to consider when doing indirect "
102 "call promotion on calls. 0 = no limit"),
103 cl::init(Val: 0), cl::cat(BoltOptCategory));
104
105static cl::alias ICPCallsTopNAlias(
106 "icp-calls-topn",
107 cl::desc("alias for --indirect-call-promotion-calls-topn"),
108 cl::aliasopt(ICPCallsTopN));
109
110static cl::opt<unsigned> ICPJumpTablesTopN(
111 "indirect-call-promotion-jump-tables-topn",
112 cl::desc("limit number of targets to consider when doing indirect "
113 "call promotion on jump tables. 0 = no limit"),
114 cl::init(Val: 0), cl::cat(BoltOptCategory));
115
116static cl::alias ICPJumpTablesTopNAlias(
117 "icp-jt-topn",
118 cl::desc("alias for --indirect-call-promotion-jump-tables-topn"),
119 cl::aliasopt(ICPJumpTablesTopN));
120
121static cl::opt<bool> EliminateLoads(
122 "icp-eliminate-loads",
123 cl::desc("enable load elimination using memory profiling data when "
124 "performing ICP"),
125 cl::init(Val: true), cl::cat(BoltOptCategory));
126
127static cl::opt<unsigned> ICPTopCallsites(
128 "icp-top-callsites",
129 cl::desc("optimize hottest calls until at least this percentage of all "
130 "indirect calls frequency is covered. 0 = all callsites"),
131 cl::init(Val: 99), cl::Hidden, cl::cat(BoltOptCategory));
132
133static cl::list<std::string>
134 ICPFuncsList("icp-funcs", cl::CommaSeparated,
135 cl::desc("list of functions to enable ICP for"),
136 cl::value_desc("func1,func2,func3,..."), cl::Hidden,
137 cl::cat(BoltOptCategory));
138
139static cl::opt<bool>
140 ICPOldCodeSequence("icp-old-code-sequence",
141 cl::desc("use old code sequence for promoted calls"),
142 cl::Hidden, cl::cat(BoltOptCategory));
143
144static cl::opt<bool> ICPJumpTablesByTarget(
145 "icp-jump-tables-targets",
146 cl::desc(
147 "for jump tables, optimize indirect jmp targets instead of indices"),
148 cl::Hidden, cl::cat(BoltOptCategory));
149
150static cl::alias
151 ICPJumpTablesByTargetAlias("icp-jt-targets",
152 cl::desc("alias for --icp-jump-tables-targets"),
153 cl::aliasopt(ICPJumpTablesByTarget));
154
155static cl::opt<bool> ICPPeelForInline(
156 "icp-inline", cl::desc("only promote call targets eligible for inlining"),
157 cl::Hidden, cl::cat(BoltOptCategory));
158
159} // namespace opts
160
161#ifndef NDEBUG
162static bool verifyProfile(std::map<uint64_t, BinaryFunction> &BFs) {
163 bool IsValid = true;
164 for (auto &BFI : BFs) {
165 BinaryFunction &BF = BFI.second;
166 if (!BF.isSimple())
167 continue;
168 for (const BinaryBasicBlock &BB : BF) {
169 auto BI = BB.branch_info_begin();
170 for (BinaryBasicBlock *SuccBB : BB.successors()) {
171 if (BI->Count != BinaryBasicBlock::COUNT_NO_PROFILE && BI->Count > 0) {
172 if (BB.getKnownExecutionCount() == 0 ||
173 SuccBB->getKnownExecutionCount() == 0) {
174 BF.getBinaryContext().errs()
175 << "BOLT-WARNING: profile verification failed after ICP for "
176 "function "
177 << BF << '\n';
178 IsValid = false;
179 }
180 }
181 ++BI;
182 }
183 }
184 }
185 return IsValid;
186}
187#endif
188
189namespace llvm {
190namespace bolt {
191
// Build a Callsite record from a single indirect-call profile entry.
// The source location is the containing function's symbol; the target is
// initially recorded by its raw profile offset.
IndirectCallPromotion::Callsite::Callsite(BinaryFunction &BF,
                                          const IndirectCallProfile &ICP)
    : From(BF.getSymbol()), To(ICP.Offset), Mispreds(ICP.Mispreds),
      Branches(ICP.Count) {
  // When the profile resolved a symbol for the target, prefer it and clear
  // the address field, which is meaningless alongside a symbolic target.
  if (ICP.Symbol) {
    To.Sym = ICP.Symbol;
    To.Addr = 0;
  }
}
201
202void IndirectCallPromotion::printDecision(
203 llvm::raw_ostream &OS,
204 std::vector<IndirectCallPromotion::Callsite> &Targets, unsigned N) const {
205 uint64_t TotalCount = 0;
206 uint64_t TotalMispreds = 0;
207 for (const Callsite &S : Targets) {
208 TotalCount += S.Branches;
209 TotalMispreds += S.Mispreds;
210 }
211 if (!TotalCount)
212 TotalCount = 1;
213 if (!TotalMispreds)
214 TotalMispreds = 1;
215
216 OS << "BOLT-INFO: ICP decision for call site with " << Targets.size()
217 << " targets, Count = " << TotalCount << ", Mispreds = " << TotalMispreds
218 << "\n";
219
220 size_t I = 0;
221 for (const Callsite &S : Targets) {
222 OS << "Count = " << S.Branches << ", "
223 << format(Fmt: "%.1f", Vals: (100.0 * S.Branches) / TotalCount) << ", "
224 << "Mispreds = " << S.Mispreds << ", "
225 << format(Fmt: "%.1f", Vals: (100.0 * S.Mispreds) / TotalMispreds);
226 if (I < N)
227 OS << " * to be optimized *";
228 if (!S.JTIndices.empty()) {
229 OS << " Indices:";
230 for (const uint64_t Idx : S.JTIndices)
231 OS << " " << Idx;
232 }
233 OS << "\n";
234 I += S.JTIndices.empty() ? 1 : S.JTIndices.size();
235 }
236}
237
238// Get list of targets for a given call sorted by most frequently
239// called first.
240std::vector<IndirectCallPromotion::Callsite>
241IndirectCallPromotion::getCallTargets(BinaryBasicBlock &BB,
242 const MCInst &Inst) const {
243 BinaryFunction &BF = *BB.getFunction();
244 const BinaryContext &BC = BF.getBinaryContext();
245 std::vector<Callsite> Targets;
246
247 if (const JumpTable *JT = BF.getJumpTable(Inst)) {
248 // Don't support PIC jump tables for now
249 if (!opts::ICPJumpTablesByTarget && JT->Type == JumpTable::JTT_PIC)
250 return Targets;
251 const Location From(BF.getSymbol());
252 const std::pair<size_t, size_t> Range =
253 JT->getEntriesForAddress(Addr: BC.MIB->getJumpTable(Inst));
254 assert(JT->Counts.empty() || JT->Counts.size() >= Range.second);
255 JumpTable::JumpInfo DefaultJI;
256 const JumpTable::JumpInfo *JI =
257 JT->Counts.empty() ? &DefaultJI : &JT->Counts[Range.first];
258 const size_t JIAdj = JT->Counts.empty() ? 0 : 1;
259 assert(JT->Type == JumpTable::JTT_PIC ||
260 JT->EntrySize == BC.AsmInfo->getCodePointerSize());
261 for (size_t I = Range.first; I < Range.second; ++I, JI += JIAdj) {
262 MCSymbol *Entry = JT->Entries[I];
263 const BinaryBasicBlock *ToBB = BF.getBasicBlockForLabel(Label: Entry);
264 assert(ToBB || Entry == BF.getFunctionEndLabel() ||
265 Entry == BF.getFunctionEndLabel(FragmentNum::cold()));
266 if (Entry == BF.getFunctionEndLabel() ||
267 Entry == BF.getFunctionEndLabel(Fragment: FragmentNum::cold()))
268 continue;
269 const Location To(Entry);
270 const BinaryBasicBlock::BinaryBranchInfo &BI = BB.getBranchInfo(Succ: *ToBB);
271 Targets.emplace_back(args: From, args: To, args: BI.MispredictedCount, args: BI.Count,
272 args: I - Range.first);
273 }
274
275 // Sort by symbol then addr.
276 llvm::sort(C&: Targets, Comp: [](const Callsite &A, const Callsite &B) {
277 if (A.To.Sym && B.To.Sym)
278 return A.To.Sym < B.To.Sym;
279 else if (A.To.Sym && !B.To.Sym)
280 return true;
281 else if (!A.To.Sym && B.To.Sym)
282 return false;
283 else
284 return A.To.Addr < B.To.Addr;
285 });
286
287 // Targets may contain multiple entries to the same target, but using
288 // different indices. Their profile will report the same number of branches
289 // for different indices if the target is the same. That's because we don't
290 // profile the index value, but only the target via LBR.
291 auto First = Targets.begin();
292 auto Last = Targets.end();
293 auto Result = First;
294 while (++First != Last) {
295 Callsite &A = *Result;
296 const Callsite &B = *First;
297 if (A.To.Sym && B.To.Sym && A.To.Sym == B.To.Sym)
298 A.JTIndices.insert(position: A.JTIndices.end(), first: B.JTIndices.begin(),
299 last: B.JTIndices.end());
300 else
301 *(++Result) = *First;
302 }
303 ++Result;
304
305 LLVM_DEBUG(if (Targets.end() - Result > 0) {
306 dbgs() << "BOLT-INFO: ICP: " << (Targets.end() - Result)
307 << " duplicate targets removed\n";
308 });
309
310 Targets.erase(first: Result, last: Targets.end());
311 } else {
312 // Don't try to optimize PC relative indirect calls.
313 if (Inst.getOperand(i: 0).isReg() &&
314 Inst.getOperand(i: 0).getReg() == BC.MRI->getProgramCounter())
315 return Targets;
316
317 const auto ICSP = BC.MIB->tryGetAnnotationAs<IndirectCallSiteProfile>(
318 Inst, Name: "CallProfile");
319 if (ICSP) {
320 for (const IndirectCallProfile &CSP : ICSP.get()) {
321 Callsite Site(BF, CSP);
322 if (Site.isValid())
323 Targets.emplace_back(args: std::move(Site));
324 }
325 }
326 }
327
328 // Sort by target count, number of indices in case of jump table, and
329 // mispredicts. We prioritize targets with high count, small number of indices
330 // and high mispredicts. Break ties by selecting targets with lower addresses.
331 llvm::stable_sort(Range&: Targets, C: [](const Callsite &A, const Callsite &B) {
332 if (A.Branches != B.Branches)
333 return A.Branches > B.Branches;
334 if (A.JTIndices.size() != B.JTIndices.size())
335 return A.JTIndices.size() < B.JTIndices.size();
336 if (A.Mispreds != B.Mispreds)
337 return A.Mispreds > B.Mispreds;
338 return A.To.Addr < B.To.Addr;
339 });
340
341 // Remove non-symbol targets
342 llvm::erase_if(C&: Targets, P: [](const Callsite &CS) { return !CS.To.Sym; });
343
344 LLVM_DEBUG(if (BF.getJumpTable(Inst)) {
345 uint64_t TotalCount = 0;
346 uint64_t TotalMispreds = 0;
347 for (const Callsite &S : Targets) {
348 TotalCount += S.Branches;
349 TotalMispreds += S.Mispreds;
350 }
351 if (!TotalCount)
352 TotalCount = 1;
353 if (!TotalMispreds)
354 TotalMispreds = 1;
355
356 dbgs() << "BOLT-INFO: ICP: jump table size = " << Targets.size()
357 << ", Count = " << TotalCount << ", Mispreds = " << TotalMispreds
358 << "\n";
359
360 size_t I = 0;
361 for (const Callsite &S : Targets) {
362 dbgs() << "Count[" << I << "] = " << S.Branches << ", "
363 << format("%.1f", (100.0 * S.Branches) / TotalCount) << ", "
364 << "Mispreds[" << I << "] = " << S.Mispreds << ", "
365 << format("%.1f", (100.0 * S.Mispreds) / TotalMispreds) << "\n";
366 ++I;
367 }
368 });
369
370 return Targets;
371}
372
373IndirectCallPromotion::JumpTableInfoType
374IndirectCallPromotion::maybeGetHotJumpTableTargets(BinaryBasicBlock &BB,
375 MCInst &CallInst,
376 MCInst *&TargetFetchInst,
377 const JumpTable *JT) const {
378 assert(JT && "Can't get jump table addrs for non-jump tables.");
379
380 BinaryFunction &Function = *BB.getFunction();
381 BinaryContext &BC = Function.getBinaryContext();
382
383 if (!Function.hasMemoryProfile() || !opts::EliminateLoads)
384 return JumpTableInfoType();
385
386 JumpTableInfoType HotTargets;
387 MCInst *MemLocInstr;
388 MCInst *PCRelBaseOut;
389 MCInst *FixedEntryLoadInstr;
390 unsigned BaseReg, IndexReg;
391 int64_t DispValue;
392 const MCExpr *DispExpr;
393 MutableArrayRef<MCInst> Insts(&BB.front(), &CallInst);
394 const IndirectBranchType Type = BC.MIB->analyzeIndirectBranch(
395 Instruction&: CallInst, Begin: Insts.begin(), End: Insts.end(), PtrSize: BC.AsmInfo->getCodePointerSize(),
396 MemLocInstr, BaseRegNum&: BaseReg, IndexRegNum&: IndexReg, DispValue, DispExpr, PCRelBaseOut,
397 FixedEntryLoadInst&: FixedEntryLoadInstr);
398
399 assert(MemLocInstr && "There should always be a load for jump tables");
400 if (!MemLocInstr)
401 return JumpTableInfoType();
402
403 LLVM_DEBUG({
404 dbgs() << "BOLT-INFO: ICP attempting to find memory profiling data for "
405 << "jump table in " << Function << " at @ "
406 << (&CallInst - &BB.front()) << "\n"
407 << "BOLT-INFO: ICP target fetch instructions:\n";
408 BC.printInstruction(dbgs(), *MemLocInstr, 0, &Function);
409 if (MemLocInstr != &CallInst)
410 BC.printInstruction(dbgs(), CallInst, 0, &Function);
411 });
412
413 DEBUG_VERBOSE(1, {
414 dbgs() << "Jmp info: Type = " << (unsigned)Type << ", "
415 << "BaseReg = " << BC.MRI->getName(BaseReg) << ", "
416 << "IndexReg = " << BC.MRI->getName(IndexReg) << ", "
417 << "DispValue = " << Twine::utohexstr(DispValue) << ", "
418 << "DispExpr = " << DispExpr << ", "
419 << "MemLocInstr = ";
420 BC.printInstruction(dbgs(), *MemLocInstr, 0, &Function);
421 dbgs() << "\n";
422 });
423
424 ++TotalIndexBasedCandidates;
425
426 auto ErrorOrMemAccessProfile =
427 BC.MIB->tryGetAnnotationAs<MemoryAccessProfile>(Inst&: *MemLocInstr,
428 Name: "MemoryAccessProfile");
429 if (!ErrorOrMemAccessProfile) {
430 DEBUG_VERBOSE(1, dbgs()
431 << "BOLT-INFO: ICP no memory profiling data found\n");
432 return JumpTableInfoType();
433 }
434 MemoryAccessProfile &MemAccessProfile = ErrorOrMemAccessProfile.get();
435
436 uint64_t ArrayStart;
437 if (DispExpr) {
438 ErrorOr<uint64_t> DispValueOrError =
439 BC.getSymbolValue(Symbol: *BC.MIB->getTargetSymbol(Expr: DispExpr));
440 assert(DispValueOrError && "global symbol needs a value");
441 ArrayStart = *DispValueOrError;
442 } else {
443 ArrayStart = static_cast<uint64_t>(DispValue);
444 }
445
446 if (BaseReg == BC.MRI->getProgramCounter())
447 ArrayStart += Function.getAddress() + MemAccessProfile.NextInstrOffset;
448
449 // This is a map of [symbol] -> [count, index] and is used to combine indices
450 // into the jump table since there may be multiple addresses that all have the
451 // same entry.
452 std::map<MCSymbol *, std::pair<uint64_t, uint64_t>> HotTargetMap;
453 const std::pair<size_t, size_t> Range = JT->getEntriesForAddress(Addr: ArrayStart);
454
455 for (const AddressAccess &AccessInfo : MemAccessProfile.AddressAccessInfo) {
456 size_t Index;
457 // Mem data occasionally includes nullprs, ignore them.
458 if (!AccessInfo.MemoryObject && !AccessInfo.Offset)
459 continue;
460
461 if (AccessInfo.Offset % JT->EntrySize != 0) // ignore bogus data
462 return JumpTableInfoType();
463
464 if (AccessInfo.MemoryObject) {
465 // Deal with bad/stale data
466 if (!AccessInfo.MemoryObject->getName().starts_with(
467 Prefix: "JUMP_TABLE/" + Function.getOneName().str()))
468 return JumpTableInfoType();
469 Index =
470 (AccessInfo.Offset - (ArrayStart - JT->getAddress())) / JT->EntrySize;
471 } else {
472 Index = (AccessInfo.Offset - ArrayStart) / JT->EntrySize;
473 }
474
475 // If Index is out of range it probably means the memory profiling data is
476 // wrong for this instruction, bail out.
477 if (Index >= Range.second) {
478 LLVM_DEBUG(dbgs() << "BOLT-INFO: Index out of range of " << Range.first
479 << ", " << Range.second << "\n");
480 return JumpTableInfoType();
481 }
482
483 // Make sure the hot index points at a legal label corresponding to a BB,
484 // e.g. not the end of function (unreachable) label.
485 if (!Function.getBasicBlockForLabel(Label: JT->Entries[Index + Range.first])) {
486 LLVM_DEBUG({
487 dbgs() << "BOLT-INFO: hot index " << Index << " pointing at bogus "
488 << "label " << JT->Entries[Index + Range.first]->getName()
489 << " in jump table:\n";
490 JT->print(dbgs());
491 dbgs() << "HotTargetMap:\n";
492 for (std::pair<MCSymbol *const, std::pair<uint64_t, uint64_t>> &HT :
493 HotTargetMap)
494 dbgs() << "BOLT-INFO: " << HT.first->getName()
495 << " = (count=" << HT.second.first
496 << ", index=" << HT.second.second << ")\n";
497 });
498 return JumpTableInfoType();
499 }
500
501 std::pair<uint64_t, uint64_t> &HotTarget =
502 HotTargetMap[JT->Entries[Index + Range.first]];
503 HotTarget.first += AccessInfo.Count;
504 HotTarget.second = Index;
505 }
506
507 llvm::copy(Range: llvm::make_second_range(c&: HotTargetMap),
508 Out: std::back_inserter(x&: HotTargets));
509
510 // Sort with highest counts first.
511 llvm::sort(C: reverse(C&: HotTargets));
512
513 LLVM_DEBUG({
514 dbgs() << "BOLT-INFO: ICP jump table hot targets:\n";
515 for (const std::pair<uint64_t, uint64_t> &Target : HotTargets)
516 dbgs() << "BOLT-INFO: Idx = " << Target.second << ", "
517 << "Count = " << Target.first << "\n";
518 });
519
520 BC.MIB->getOrCreateAnnotationAs<uint16_t>(Inst&: CallInst, Name: "JTIndexReg") = IndexReg;
521
522 TargetFetchInst = MemLocInstr;
523
524 return HotTargets;
525}
526
527IndirectCallPromotion::SymTargetsType
528IndirectCallPromotion::findCallTargetSymbols(std::vector<Callsite> &Targets,
529 size_t &N, BinaryBasicBlock &BB,
530 MCInst &CallInst,
531 MCInst *&TargetFetchInst) const {
532 const BinaryContext &BC = BB.getFunction()->getBinaryContext();
533 const JumpTable *JT = BB.getFunction()->getJumpTable(Inst: CallInst);
534 SymTargetsType SymTargets;
535
536 if (!JT) {
537 for (size_t I = 0; I < N; ++I) {
538 assert(Targets[I].To.Sym && "All ICP targets must be to known symbols");
539 assert(Targets[I].JTIndices.empty() &&
540 "Can't have jump table indices for non-jump tables");
541 SymTargets.emplace_back(args&: Targets[I].To.Sym, args: 0);
542 }
543 return SymTargets;
544 }
545
546 // Use memory profile to select hot targets.
547 JumpTableInfoType HotTargets =
548 maybeGetHotJumpTableTargets(BB, CallInst, TargetFetchInst, JT);
549
550 auto findTargetsIndex = [&](uint64_t JTIndex) {
551 for (size_t I = 0; I < Targets.size(); ++I)
552 if (llvm::is_contained(Range&: Targets[I].JTIndices, Element: JTIndex))
553 return I;
554 LLVM_DEBUG(dbgs() << "BOLT-ERROR: Unable to find target index for hot jump "
555 << " table entry in " << *BB.getFunction() << "\n");
556 llvm_unreachable("Hot indices must be referred to by at least one "
557 "callsite");
558 };
559
560 if (!HotTargets.empty()) {
561 if (opts::Verbosity >= 1)
562 for (size_t I = 0; I < HotTargets.size(); ++I)
563 BC.outs() << "BOLT-INFO: HotTarget[" << I << "] = ("
564 << HotTargets[I].first << ", " << HotTargets[I].second
565 << ")\n";
566
567 // Recompute hottest targets, now discriminating which index is hot
568 // NOTE: This is a tradeoff. On one hand, we get index information. On the
569 // other hand, info coming from the memory profile is much less accurate
570 // than LBRs. So we may actually end up working with more coarse
571 // profile granularity in exchange for information about indices.
572 std::vector<Callsite> NewTargets;
573 std::map<const MCSymbol *, uint32_t> IndicesPerTarget;
574 uint64_t TotalMemAccesses = 0;
575 for (size_t I = 0; I < HotTargets.size(); ++I) {
576 const uint64_t TargetIndex = findTargetsIndex(HotTargets[I].second);
577 ++IndicesPerTarget[Targets[TargetIndex].To.Sym];
578 TotalMemAccesses += HotTargets[I].first;
579 }
580 uint64_t RemainingMemAccesses = TotalMemAccesses;
581 const size_t TopN =
582 opts::ICPJumpTablesTopN ? opts::ICPJumpTablesTopN : opts::ICPTopN;
583 size_t I = 0;
584 for (; I < HotTargets.size(); ++I) {
585 const uint64_t MemAccesses = HotTargets[I].first;
586 if (100 * MemAccesses <
587 TotalMemAccesses * opts::ICPJTTotalPercentThreshold)
588 break;
589 if (100 * MemAccesses <
590 RemainingMemAccesses * opts::ICPJTRemainingPercentThreshold)
591 break;
592 if (TopN && I >= TopN)
593 break;
594 RemainingMemAccesses -= MemAccesses;
595
596 const uint64_t JTIndex = HotTargets[I].second;
597 Callsite &Target = Targets[findTargetsIndex(JTIndex)];
598
599 NewTargets.push_back(x: Target);
600 std::vector<uint64_t>({JTIndex}).swap(x&: NewTargets.back().JTIndices);
601 llvm::erase(C&: Target.JTIndices, V: JTIndex);
602
603 // Keep fixCFG counts sane if more indices use this same target later
604 assert(IndicesPerTarget[Target.To.Sym] > 0 && "wrong map");
605 NewTargets.back().Branches =
606 Target.Branches / IndicesPerTarget[Target.To.Sym];
607 NewTargets.back().Mispreds =
608 Target.Mispreds / IndicesPerTarget[Target.To.Sym];
609 assert(Target.Branches >= NewTargets.back().Branches);
610 assert(Target.Mispreds >= NewTargets.back().Mispreds);
611 Target.Branches -= NewTargets.back().Branches;
612 Target.Mispreds -= NewTargets.back().Mispreds;
613 }
614 llvm::copy(Range&: Targets, Out: std::back_inserter(x&: NewTargets));
615 std::swap(x&: NewTargets, y&: Targets);
616 N = I;
617
618 if (N == 0 && opts::Verbosity >= 1) {
619 BC.outs() << "BOLT-INFO: ICP failed in " << *BB.getFunction() << " in "
620 << BB.getName() << ": failed to meet thresholds after memory "
621 << "profile data was loaded.\n";
622 return SymTargets;
623 }
624 }
625
626 for (size_t I = 0, TgtIdx = 0; I < N; ++TgtIdx) {
627 Callsite &Target = Targets[TgtIdx];
628 assert(Target.To.Sym && "All ICP targets must be to known symbols");
629 assert(!Target.JTIndices.empty() && "Jump tables must have indices");
630 for (uint64_t Idx : Target.JTIndices) {
631 SymTargets.emplace_back(args&: Target.To.Sym, args&: Idx);
632 ++I;
633 }
634 }
635
636 return SymTargets;
637}
638
639IndirectCallPromotion::MethodInfoType IndirectCallPromotion::maybeGetVtableSyms(
640 BinaryBasicBlock &BB, MCInst &Inst,
641 const SymTargetsType &SymTargets) const {
642 BinaryFunction &Function = *BB.getFunction();
643 BinaryContext &BC = Function.getBinaryContext();
644 std::vector<std::pair<MCSymbol *, uint64_t>> VtableSyms;
645 std::vector<MCInst *> MethodFetchInsns;
646 unsigned VtableReg, MethodReg;
647 uint64_t MethodOffset;
648
649 assert(!Function.getJumpTable(Inst) &&
650 "Can't get vtable addrs for jump tables.");
651
652 if (!Function.hasMemoryProfile() || !opts::EliminateLoads)
653 return MethodInfoType();
654
655 MutableArrayRef<MCInst> Insts(&BB.front(), &Inst + 1);
656 if (!BC.MIB->analyzeVirtualMethodCall(Begin: Insts.begin(), End: Insts.end(),
657 MethodFetchInsns, VtableRegNum&: VtableReg, BaseRegNum&: MethodReg,
658 MethodOffset)) {
659 DEBUG_VERBOSE(
660 1, dbgs() << "BOLT-INFO: ICP unable to analyze method call in "
661 << Function << " at @ " << (&Inst - &BB.front()) << "\n");
662 return MethodInfoType();
663 }
664
665 ++TotalMethodLoadEliminationCandidates;
666
667 DEBUG_VERBOSE(1, {
668 dbgs() << "BOLT-INFO: ICP found virtual method call in " << Function
669 << " at @ " << (&Inst - &BB.front()) << "\n";
670 dbgs() << "BOLT-INFO: ICP method fetch instructions:\n";
671 for (MCInst *Inst : MethodFetchInsns)
672 BC.printInstruction(dbgs(), *Inst, 0, &Function);
673
674 if (MethodFetchInsns.back() != &Inst)
675 BC.printInstruction(dbgs(), Inst, 0, &Function);
676 });
677
678 // Try to get value profiling data for the method load instruction.
679 auto ErrorOrMemAccessProfile =
680 BC.MIB->tryGetAnnotationAs<MemoryAccessProfile>(Inst&: *MethodFetchInsns.back(),
681 Name: "MemoryAccessProfile");
682 if (!ErrorOrMemAccessProfile) {
683 DEBUG_VERBOSE(1, dbgs()
684 << "BOLT-INFO: ICP no memory profiling data found\n");
685 return MethodInfoType();
686 }
687 MemoryAccessProfile &MemAccessProfile = ErrorOrMemAccessProfile.get();
688
689 // Find the vtable that each method belongs to.
690 std::map<const MCSymbol *, uint64_t> MethodToVtable;
691
692 for (const AddressAccess &AccessInfo : MemAccessProfile.AddressAccessInfo) {
693 uint64_t Address = AccessInfo.Offset;
694 if (AccessInfo.MemoryObject)
695 Address += AccessInfo.MemoryObject->getAddress();
696
697 // Ignore bogus data.
698 if (!Address)
699 continue;
700
701 const uint64_t VtableBase = Address - MethodOffset;
702
703 DEBUG_VERBOSE(1, dbgs() << "BOLT-INFO: ICP vtable = "
704 << Twine::utohexstr(VtableBase) << "+"
705 << MethodOffset << "/" << AccessInfo.Count << "\n");
706
707 if (ErrorOr<uint64_t> MethodAddr = BC.getPointerAtAddress(Address)) {
708 BinaryData *MethodBD = BC.getBinaryDataAtAddress(Address: MethodAddr.get());
709 if (!MethodBD) // skip unknown methods
710 continue;
711 MCSymbol *MethodSym = MethodBD->getSymbol();
712 MethodToVtable[MethodSym] = VtableBase;
713 DEBUG_VERBOSE(1, {
714 const BinaryFunction *Method = BC.getFunctionForSymbol(MethodSym);
715 dbgs() << "BOLT-INFO: ICP found method = "
716 << Twine::utohexstr(MethodAddr.get()) << "/"
717 << (Method ? Method->getPrintName() : "") << "\n";
718 });
719 }
720 }
721
722 // Find the vtable for each target symbol.
723 for (size_t I = 0; I < SymTargets.size(); ++I) {
724 auto Itr = MethodToVtable.find(x: SymTargets[I].first);
725 if (Itr != MethodToVtable.end()) {
726 if (BinaryData *BD = BC.getBinaryDataContainingAddress(Address: Itr->second)) {
727 const uint64_t Addend = Itr->second - BD->getAddress();
728 VtableSyms.emplace_back(args: BD->getSymbol(), args: Addend);
729 continue;
730 }
731 }
732 // Give up if we can't find the vtable for a method.
733 DEBUG_VERBOSE(1, dbgs() << "BOLT-INFO: ICP can't find vtable for "
734 << SymTargets[I].first->getName() << "\n");
735 return MethodInfoType();
736 }
737
738 // Make sure the vtable reg is not clobbered by the argument passing code
739 if (VtableReg != MethodReg) {
740 for (MCInst *CurInst = MethodFetchInsns.front(); CurInst < &Inst;
741 ++CurInst) {
742 const MCInstrDesc &InstrInfo = BC.MII->get(Opcode: CurInst->getOpcode());
743 if (InstrInfo.hasDefOfPhysReg(MI: *CurInst, Reg: VtableReg, RI: *BC.MRI))
744 return MethodInfoType();
745 }
746 }
747
748 return MethodInfoType(VtableSyms, MethodFetchInsns);
749}
750
751std::vector<std::unique_ptr<BinaryBasicBlock>>
752IndirectCallPromotion::rewriteCall(
753 BinaryBasicBlock &IndCallBlock, const MCInst &CallInst,
754 MCPlusBuilder::BlocksVectorTy &&ICPcode,
755 const std::vector<MCInst *> &MethodFetchInsns) const {
756 BinaryFunction &Function = *IndCallBlock.getFunction();
757 MCPlusBuilder *MIB = Function.getBinaryContext().MIB.get();
758
759 // Create new basic blocks with correct code in each one first.
760 std::vector<std::unique_ptr<BinaryBasicBlock>> NewBBs;
761 const bool IsTailCallOrJT =
762 (MIB->isTailCall(Inst: CallInst) || Function.getJumpTable(Inst: CallInst));
763
764 // If we are tracking the indirect call/jump address, propagate the address to
765 // the ICP code.
766 const std::optional<uint32_t> IndirectInstrOffset = MIB->getOffset(Inst: CallInst);
767 if (IndirectInstrOffset) {
768 for (auto &[Symbol, Instructions] : ICPcode)
769 for (MCInst &Inst : Instructions)
770 MIB->setOffset(Inst, Offset: *IndirectInstrOffset);
771 }
772
773 // Move instructions from the tail of the original call block
774 // to the merge block.
775
776 // Remember any pseudo instructions following a tail call. These
777 // must be preserved and moved to the original block.
778 InstructionListType TailInsts;
779 const MCInst *TailInst = &CallInst;
780 if (IsTailCallOrJT)
781 while (TailInst + 1 < &(*IndCallBlock.end()) &&
782 MIB->isPseudo(Inst: *(TailInst + 1)))
783 TailInsts.push_back(x: *++TailInst);
784
785 InstructionListType MovedInst = IndCallBlock.splitInstructions(Inst: &CallInst);
786 // Link new BBs to the original input offset of the indirect call site or its
787 // containing BB, so we can map samples recorded in new BBs back to the
788 // original BB seen in the input binary (if using BAT).
789 const uint32_t OrigOffset = IndirectInstrOffset
790 ? *IndirectInstrOffset
791 : IndCallBlock.getInputOffset();
792
793 IndCallBlock.eraseInstructions(Begin: MethodFetchInsns.begin(),
794 End: MethodFetchInsns.end());
795 if (IndCallBlock.empty() ||
796 (!MethodFetchInsns.empty() && MethodFetchInsns.back() == &CallInst))
797 IndCallBlock.addInstructions(Begin: ICPcode.front().second.begin(),
798 End: ICPcode.front().second.end());
799 else
800 IndCallBlock.replaceInstruction(II: std::prev(x: IndCallBlock.end()),
801 Replacement: ICPcode.front().second);
802 IndCallBlock.addInstructions(Begin: TailInsts.begin(), End: TailInsts.end());
803
804 for (auto Itr = ICPcode.begin() + 1; Itr != ICPcode.end(); ++Itr) {
805 MCSymbol *&Sym = Itr->first;
806 InstructionListType &Insts = Itr->second;
807 assert(Sym);
808 std::unique_ptr<BinaryBasicBlock> TBB = Function.createBasicBlock(Label: Sym);
809 TBB->setOffset(OrigOffset);
810 for (MCInst &Inst : Insts) // sanitize new instructions.
811 if (MIB->isCall(Inst))
812 MIB->removeAnnotation(Inst, Name: "CallProfile");
813 TBB->addInstructions(Begin: Insts.begin(), End: Insts.end());
814 NewBBs.emplace_back(args: std::move(TBB));
815 }
816
817 // Move tail of instructions from after the original call to
818 // the merge block.
819 if (!IsTailCallOrJT)
820 NewBBs.back()->addInstructions(Begin: MovedInst.begin(), End: MovedInst.end());
821
822 return NewBBs;
823}
824
825BinaryBasicBlock *
826IndirectCallPromotion::fixCFG(BinaryBasicBlock &IndCallBlock,
827 const bool IsTailCall, const bool IsJumpTable,
828 IndirectCallPromotion::BasicBlocksVector &&NewBBs,
829 const std::vector<Callsite> &Targets) const {
830 BinaryFunction &Function = *IndCallBlock.getFunction();
831 using BinaryBranchInfo = BinaryBasicBlock::BinaryBranchInfo;
832 BinaryBasicBlock *MergeBlock = nullptr;
833
834 // Scale indirect call counts to the execution count of the original
835 // basic block containing the indirect call.
836 uint64_t TotalCount = IndCallBlock.getKnownExecutionCount();
837 uint64_t TotalIndirectBranches = 0;
838 for (const Callsite &Target : Targets)
839 TotalIndirectBranches += Target.Branches;
840 if (TotalIndirectBranches == 0)
841 TotalIndirectBranches = 1;
842 BinaryBasicBlock::BranchInfoType BBI;
843 BinaryBasicBlock::BranchInfoType ScaledBBI;
844 for (const Callsite &Target : Targets) {
845 const size_t NumEntries =
846 std::max(a: static_cast<std::size_t>(1UL), b: Target.JTIndices.size());
847 for (size_t I = 0; I < NumEntries; ++I) {
848 BBI.push_back(
849 Elt: BinaryBranchInfo{.Count: (Target.Branches + NumEntries - 1) / NumEntries,
850 .MispredictedCount: (Target.Mispreds + NumEntries - 1) / NumEntries});
851 ScaledBBI.push_back(
852 Elt: BinaryBranchInfo{.Count: uint64_t(TotalCount * Target.Branches /
853 (NumEntries * TotalIndirectBranches)),
854 .MispredictedCount: uint64_t(TotalCount * Target.Mispreds /
855 (NumEntries * TotalIndirectBranches))});
856 }
857 }
858
859 if (IsJumpTable) {
860 BinaryBasicBlock *NewIndCallBlock = NewBBs.back().get();
861 IndCallBlock.moveAllSuccessorsTo(New: NewIndCallBlock);
862
863 std::vector<MCSymbol *> SymTargets;
864 for (const Callsite &Target : Targets) {
865 const size_t NumEntries =
866 std::max(a: static_cast<std::size_t>(1UL), b: Target.JTIndices.size());
867 for (size_t I = 0; I < NumEntries; ++I)
868 SymTargets.push_back(x: Target.To.Sym);
869 }
870 assert(SymTargets.size() > NewBBs.size() - 1 &&
871 "There must be a target symbol associated with each new BB.");
872
873 for (uint64_t I = 0; I < NewBBs.size(); ++I) {
874 BinaryBasicBlock *SourceBB = I ? NewBBs[I - 1].get() : &IndCallBlock;
875 SourceBB->setExecutionCount(TotalCount);
876
877 BinaryBasicBlock *TargetBB =
878 Function.getBasicBlockForLabel(Label: SymTargets[I]);
879 SourceBB->addSuccessor(Succ: TargetBB, BI: ScaledBBI[I]); // taken
880
881 TotalCount -= ScaledBBI[I].Count;
882 SourceBB->addSuccessor(Succ: NewBBs[I].get(), Count: TotalCount); // fall-through
883
884 // Update branch info for the indirect jump.
885 BinaryBasicBlock::BinaryBranchInfo &BranchInfo =
886 NewIndCallBlock->getBranchInfo(Succ: *TargetBB);
887 if (BranchInfo.Count > BBI[I].Count)
888 BranchInfo.Count -= BBI[I].Count;
889 else
890 BranchInfo.Count = 0;
891
892 if (BranchInfo.MispredictedCount > BBI[I].MispredictedCount)
893 BranchInfo.MispredictedCount -= BBI[I].MispredictedCount;
894 else
895 BranchInfo.MispredictedCount = 0;
896 }
897 } else {
898 assert(NewBBs.size() >= 2);
899 assert(NewBBs.size() % 2 == 1 || IndCallBlock.succ_empty());
900 assert(NewBBs.size() % 2 == 1 || IsTailCall);
901
902 auto ScaledBI = ScaledBBI.begin();
903 auto updateCurrentBranchInfo = [&] {
904 assert(ScaledBI != ScaledBBI.end());
905 TotalCount -= ScaledBI->Count;
906 ++ScaledBI;
907 };
908
909 if (!IsTailCall) {
910 MergeBlock = NewBBs.back().get();
911 IndCallBlock.moveAllSuccessorsTo(New: MergeBlock);
912 }
913
914 // Fix up successors and execution counts.
915 updateCurrentBranchInfo();
916 IndCallBlock.addSuccessor(Succ: NewBBs[1].get(), Count: TotalCount);
917 IndCallBlock.addSuccessor(Succ: NewBBs[0].get(), BI: ScaledBBI[0]);
918
919 const size_t Adj = IsTailCall ? 1 : 2;
920 for (size_t I = 0; I < NewBBs.size() - Adj; ++I) {
921 assert(TotalCount <= IndCallBlock.getExecutionCount() ||
922 TotalCount <= uint64_t(TotalIndirectBranches));
923 uint64_t ExecCount = ScaledBBI[(I + 1) / 2].Count;
924 if (I % 2 == 0) {
925 if (MergeBlock)
926 NewBBs[I]->addSuccessor(Succ: MergeBlock, Count: ScaledBBI[(I + 1) / 2].Count);
927 } else {
928 assert(I + 2 < NewBBs.size());
929 updateCurrentBranchInfo();
930 NewBBs[I]->addSuccessor(Succ: NewBBs[I + 2].get(), Count: TotalCount);
931 NewBBs[I]->addSuccessor(Succ: NewBBs[I + 1].get(), BI: ScaledBBI[(I + 1) / 2]);
932 ExecCount += TotalCount;
933 }
934 NewBBs[I]->setExecutionCount(ExecCount);
935 }
936
937 if (MergeBlock) {
938 // Arrange for the MergeBlock to be the fallthrough for the first
939 // promoted call block.
940 std::unique_ptr<BinaryBasicBlock> MBPtr;
941 std::swap(x&: MBPtr, y&: NewBBs.back());
942 NewBBs.pop_back();
943 NewBBs.emplace(position: NewBBs.begin() + 1, args: std::move(MBPtr));
944 // TODO: is COUNT_FALLTHROUGH_EDGE the right thing here?
945 NewBBs.back()->addSuccessor(Succ: MergeBlock, Count: TotalCount); // uncond branch
946 }
947 }
948
949 // Update the execution count.
950 NewBBs.back()->setExecutionCount(TotalCount);
951
952 // Update BB and BB layout.
953 Function.insertBasicBlocks(Start: &IndCallBlock, NewBBs: std::move(NewBBs));
954 assert(Function.validateCFG());
955
956 return MergeBlock;
957}
958
// Decide whether this indirect callsite is worth promoting and, if so, how
// many of the hottest targets to promote. Returns 0 to skip the callsite.
// As a side effect, accumulates the pass statistics TotalNumFrequentJmps /
// TotalNumFrequentCalls for the targets that passed the frequency checks.
size_t IndirectCallPromotion::canPromoteCallsite(
    const BinaryBasicBlock &BB, const MCInst &Inst,
    const std::vector<Callsite> &Targets, uint64_t NumCalls) {
  BinaryFunction *BF = BB.getFunction();
  const BinaryContext &BC = BF->getBinaryContext();

  // Cold blocks are never worth the extra compare-and-branch code.
  if (BB.getKnownExecutionCount() < opts::ExecutionCountThreshold)
    return 0;

  const bool IsJumpTable = BF->getJumpTable(Inst);

  // Fold the first N targets' branch counts into the pass statistics.
  auto computeStats = [&](size_t N) {
    for (size_t I = 0; I < N; ++I)
      if (IsJumpTable)
        TotalNumFrequentJmps += Targets[I].Branches;
      else
        TotalNumFrequentCalls += Targets[I].Branches;
  };

  // If we have no targets (or no calls), skip this callsite.
  if (Targets.empty() || !NumCalls) {
    if (opts::Verbosity >= 1) {
      const ptrdiff_t InstIdx = &Inst - &(*BB.begin());
      BC.outs() << "BOLT-INFO: ICP failed in " << *BF << " @ " << InstIdx
                << " in " << BB.getName() << ", calls = " << NumCalls
                << ", targets empty or NumCalls == 0.\n";
    }
    return 0;
  }

  // The kind-specific TopN options override the generic one when set.
  size_t TopN = opts::ICPTopN;
  if (IsJumpTable)
    TopN = opts::ICPJumpTablesTopN ? opts::ICPJumpTablesTopN : TopN;
  else
    TopN = opts::ICPCallsTopN ? opts::ICPCallsTopN : TopN;

  const size_t TrialN = TopN ? std::min(TopN, Targets.size()) : Targets.size();

  // In top-callsites mode only callsites pre-annotated by runOnFunctions()
  // are eligible.
  if (opts::ICPTopCallsites && !BC.MIB->hasAnnotation(Inst, "DoICP"))
    return 0;

  // Pick the top N targets.
  uint64_t TotalMispredictsTopN = 0;
  size_t N = 0;

  if (opts::ICPUseMispredicts &&
      (!IsJumpTable || opts::ICPJumpTablesByTarget)) {
    // Count total number of mispredictions for (at most) the top N targets.
    // We may choose a smaller N (TrialN vs. N) if the frequency threshold
    // is exceeded by fewer targets.
    double Threshold = double(opts::ICPMispredictThreshold);
    for (size_t I = 0; I < TrialN && Threshold > 0; ++I, ++N) {
      Threshold -= (100.0 * Targets[I].Mispreds) / NumCalls;
      TotalMispredictsTopN += Targets[I].Mispreds;
    }
    computeStats(N);

    // Compute the misprediction frequency of the top N call targets. If this
    // frequency is greater than the threshold, we should try ICP on this
    // callsite.
    const double TopNFrequency = (100.0 * TotalMispredictsTopN) / NumCalls;
    if (TopNFrequency == 0 || TopNFrequency < opts::ICPMispredictThreshold) {
      if (opts::Verbosity >= 1) {
        const ptrdiff_t InstIdx = &Inst - &(*BB.begin());
        BC.outs() << "BOLT-INFO: ICP failed in " << *BF << " @ " << InstIdx
                  << " in " << BB.getName() << ", calls = " << NumCalls
                  << ", top N mis. frequency " << format("%.1f", TopNFrequency)
                  << "% < " << opts::ICPMispredictThreshold << "%\n";
      }
      return 0;
    }
  } else {
    size_t MaxTargets = 0;

    // Count total number of calls for (at most) the top N targets.
    // We may choose a smaller N (TrialN vs. N) if the frequency threshold
    // is exceeded by fewer targets.
    const unsigned TotalThreshold = IsJumpTable
                                        ? opts::ICPJTTotalPercentThreshold
                                        : opts::ICPCallsTotalPercentThreshold;
    const unsigned RemainingThreshold =
        IsJumpTable ? opts::ICPJTRemainingPercentThreshold
                    : opts::ICPCallsRemainingPercentThreshold;
    uint64_t NumRemainingCalls = NumCalls;
    // Targets are sorted hottest-first; stop at the first one that fails
    // either percentage threshold, or when JT index expansion would push us
    // past TrialN.
    for (size_t I = 0; I < TrialN; ++I, ++MaxTargets) {
      if (100 * Targets[I].Branches < NumCalls * TotalThreshold)
        break;
      if (100 * Targets[I].Branches < NumRemainingCalls * RemainingThreshold)
        break;
      if (N + (Targets[I].JTIndices.empty() ? 1 : Targets[I].JTIndices.size()) >
          TrialN)
        break;
      TotalMispredictsTopN += Targets[I].Mispreds;
      NumRemainingCalls -= Targets[I].Branches;
      // Each jump table index counts as its own promoted target.
      N += Targets[I].JTIndices.empty() ? 1 : Targets[I].JTIndices.size();
    }
    computeStats(MaxTargets);

    // Don't check misprediction frequency for jump tables -- we don't really
    // care as long as we are saving loads from the jump table.
    if (!IsJumpTable || opts::ICPJumpTablesByTarget) {
      // Compute the misprediction frequency of the top N call targets. If
      // this frequency is less than the threshold, we should skip ICP at
      // this callsite.
      const double TopNMispredictFrequency =
          (100.0 * TotalMispredictsTopN) / NumCalls;

      if (TopNMispredictFrequency < opts::ICPMispredictThreshold) {
        if (opts::Verbosity >= 1) {
          const ptrdiff_t InstIdx = &Inst - &(*BB.begin());
          BC.outs() << "BOLT-INFO: ICP failed in " << *BF << " @ " << InstIdx
                    << " in " << BB.getName() << ", calls = " << NumCalls
                    << ", top N mispredict frequency "
                    << format("%.1f", TopNMispredictFrequency) << "% < "
                    << opts::ICPMispredictThreshold << "%\n";
        }
        return 0;
      }
    }
  }

  // Filter by inline-ability of target functions, stop at first target that
  // can't be inlined.
  if (!IsJumpTable && opts::ICPPeelForInline) {
    for (size_t I = 0; I < N; ++I) {
      const MCSymbol *TargetSym = Targets[I].To.Sym;
      const BinaryFunction *TargetBF = BC.getFunctionForSymbol(TargetSym);
      if (!TargetBF || !BinaryFunctionPass::shouldOptimize(*TargetBF) ||
          getInliningInfo(*TargetBF).Type == InliningType::INL_NONE) {
        N = I;
        break;
      }
    }
  }

  // Filter functions that can have ICP applied (for debugging)
  if (!opts::ICPFuncsList.empty()) {
    for (std::string &Name : opts::ICPFuncsList)
      if (BF->hasName(Name))
        return N;
    return 0;
  }

  return N;
}
1104
1105void IndirectCallPromotion::printCallsiteInfo(
1106 const BinaryBasicBlock &BB, const MCInst &Inst,
1107 const std::vector<Callsite> &Targets, const size_t N,
1108 uint64_t NumCalls) const {
1109 BinaryContext &BC = BB.getFunction()->getBinaryContext();
1110 const bool IsTailCall = BC.MIB->isTailCall(Inst);
1111 const bool IsJumpTable = BB.getFunction()->getJumpTable(Inst);
1112 const ptrdiff_t InstIdx = &Inst - &(*BB.begin());
1113
1114 BC.outs() << "BOLT-INFO: ICP candidate branch info: " << *BB.getFunction()
1115 << " @ " << InstIdx << " in " << BB.getName()
1116 << " -> calls = " << NumCalls
1117 << (IsTailCall ? " (tail)" : (IsJumpTable ? " (jump table)" : ""))
1118 << "\n";
1119 for (size_t I = 0; I < N; I++) {
1120 const double Frequency = 100.0 * Targets[I].Branches / NumCalls;
1121 const double MisFrequency = 100.0 * Targets[I].Mispreds / NumCalls;
1122 BC.outs() << "BOLT-INFO: ";
1123 if (Targets[I].To.Sym)
1124 BC.outs() << Targets[I].To.Sym->getName();
1125 else
1126 BC.outs() << Targets[I].To.Addr;
1127 BC.outs() << ", calls = " << Targets[I].Branches
1128 << ", mispreds = " << Targets[I].Mispreds
1129 << ", taken freq = " << format(Fmt: "%.1f", Vals: Frequency) << "%"
1130 << ", mis. freq = " << format(Fmt: "%.1f", Vals: MisFrequency) << "%";
1131 bool First = true;
1132 for (uint64_t JTIndex : Targets[I].JTIndices) {
1133 BC.outs() << (First ? ", indices = " : ", ") << JTIndex;
1134 First = false;
1135 }
1136 BC.outs() << "\n";
1137 }
1138
1139 LLVM_DEBUG({
1140 dbgs() << "BOLT-INFO: ICP original call instruction:";
1141 BC.printInstruction(dbgs(), Inst, Targets[0].From.Addr, nullptr, true);
1142 });
1143}
1144
// Pass driver. Walks every eligible function/block/instruction, selects
// indirect callsites (calls and/or jump tables, depending on opts::ICP),
// checks each for profitability via canPromoteCallsite(), rewrites the
// callsite with rewriteCall()/fixCFG(), and finally prints aggregate
// statistics for the whole binary.
Error IndirectCallPromotion::runOnFunctions(BinaryContext &BC) {
  if (opts::ICP == ICP_NONE)
    return Error::success();

  auto &BFs = BC.getBinaryFunctions();

  const bool OptimizeCalls = (opts::ICP == ICP_CALLS || opts::ICP == ICP_ALL);
  const bool OptimizeJumpTables =
      (opts::ICP == ICP_JUMP_TABLES || opts::ICP == ICP_ALL);

  // Jump table promotion clobbers registers, so it needs register liveness
  // information built on top of the call graph.
  std::unique_ptr<RegAnalysis> RA;
  std::unique_ptr<BinaryFunctionCallGraph> CG;
  if (OptimizeJumpTables) {
    CG.reset(new BinaryFunctionCallGraph(buildCallGraph(BC)));
    RA.reset(new RegAnalysis(BC, &BFs, &*CG));
  }

  // If icp-top-callsites is enabled, compute the total number of indirect
  // calls and then optimize the hottest callsites that contribute to that
  // total.
  SetVector<BinaryFunction *> Functions;
  if (opts::ICPTopCallsites == 0) {
    // Default mode: consider every function.
    for (auto &KV : BFs)
      Functions.insert(&KV.second);
  } else {
    using IndirectCallsite = std::tuple<uint64_t, MCInst *, BinaryFunction *>;
    std::vector<IndirectCallsite> IndirectCalls;
    size_t TotalIndirectCalls = 0;

    // Find all the indirect callsites.
    for (auto &BFIt : BFs) {
      BinaryFunction &Function = BFIt.second;

      if (!shouldOptimize(Function))
        continue;

      const bool HasLayout = !Function.getLayout().block_empty();

      for (BinaryBasicBlock &BB : Function) {
        // Cold (split-off) blocks are not worth promoting.
        if (HasLayout && Function.isSplit() && BB.isCold())
          continue;

        for (MCInst &Inst : BB) {
          const bool IsJumpTable = Function.getJumpTable(Inst);
          const bool HasIndirectCallProfile =
              BC.MIB->hasAnnotation(Inst, "CallProfile");
          const bool IsDirectCall =
              (BC.MIB->isCall(Inst) && BC.MIB->getTargetSymbol(Inst, 0));

          if (!IsDirectCall &&
              ((HasIndirectCallProfile && !IsJumpTable && OptimizeCalls) ||
               (IsJumpTable && OptimizeJumpTables))) {
            uint64_t NumCalls = 0;
            for (const Callsite &BInfo : getCallTargets(BB, Inst))
              NumCalls += BInfo.Branches;
            IndirectCalls.push_back(
                std::make_tuple(NumCalls, &Inst, &Function));
            TotalIndirectCalls += NumCalls;
          }
        }
      }
    }

    // Sort callsites by execution count.
    // (Sorting the reversed range ascending yields descending order.)
    llvm::sort(reverse(IndirectCalls));

    // Find callsites that contribute to the top "opts::ICPTopCallsites"%
    // number of calls.
    const float TopPerc = opts::ICPTopCallsites / 100.0f;
    int64_t MaxCalls = TotalIndirectCalls * TopPerc;
    uint64_t LastFreq = std::numeric_limits<uint64_t>::max();
    size_t Num = 0;
    for (const IndirectCallsite &IC : IndirectCalls) {
      const uint64_t CurFreq = std::get<0>(IC);
      // Once we decide to stop, include at least all branches that share the
      // same frequency of the last one to avoid non-deterministic behavior
      // (e.g. turning on/off ICP depending on the order of functions)
      if (MaxCalls <= 0 && CurFreq != LastFreq)
        break;
      MaxCalls -= CurFreq;
      LastFreq = CurFreq;
      // Mark the callsite; canPromoteCallsite() checks for this annotation.
      BC.MIB->addAnnotation(*std::get<1>(IC), "DoICP", true);
      Functions.insert(std::get<2>(IC));
      ++Num;
    }
    BC.outs() << "BOLT-INFO: ICP Total indirect calls = " << TotalIndirectCalls
              << ", " << Num << " callsites cover " << opts::ICPTopCallsites
              << "% of all indirect calls\n";
  }

  for (BinaryFunction *FuncPtr : Functions) {
    BinaryFunction &Function = *FuncPtr;

    if (!shouldOptimize(Function))
      continue;

    const bool HasLayout = !Function.getLayout().block_empty();

    // Total number of indirect calls issued from the current Function.
    // (a fraction of TotalIndirectCalls)
    uint64_t FuncTotalIndirectCalls = 0;
    uint64_t FuncTotalIndirectJmps = 0;

    std::vector<BinaryBasicBlock *> BBs;
    for (BinaryBasicBlock &BB : Function) {
      // Skip indirect calls in cold blocks.
      if (!HasLayout || !Function.isSplit() || !BB.isCold())
        BBs.push_back(&BB);
    }
    if (BBs.empty())
      continue;

    DataflowInfoManager Info(Function, RA.get(), nullptr);
    // Worklist loop: fixCFG() may append a MergeBlock below, so new blocks
    // created during promotion get scanned too.
    while (!BBs.empty()) {
      BinaryBasicBlock *BB = BBs.back();
      BBs.pop_back();

      // Index-based iteration: promotion replaces instructions inside BB.
      for (unsigned Idx = 0; Idx < BB->size(); ++Idx) {
        MCInst &Inst = BB->getInstructionAtIndex(Idx);
        const ptrdiff_t InstIdx = &Inst - &(*BB->begin());
        const bool IsTailCall = BC.MIB->isTailCall(Inst);
        const bool HasIndirectCallProfile =
            BC.MIB->hasAnnotation(Inst, "CallProfile");
        const bool IsJumpTable = Function.getJumpTable(Inst);

        if (BC.MIB->isCall(Inst))
          TotalCalls += BB->getKnownExecutionCount();

        if (IsJumpTable && !OptimizeJumpTables)
          continue;

        if (!IsJumpTable && (!HasIndirectCallProfile || !OptimizeCalls))
          continue;

        // Ignore direct calls.
        if (BC.MIB->isCall(Inst) && BC.MIB->getTargetSymbol(Inst, 0))
          continue;

        assert((BC.MIB->isCall(Inst) || BC.MIB->isIndirectBranch(Inst)) &&
               "expected a call or an indirect jump instruction");

        if (IsJumpTable)
          ++TotalJumpTableCallsites;
        else
          ++TotalIndirectCallsites;

        std::vector<Callsite> Targets = getCallTargets(*BB, Inst);

        // Compute the total number of calls from this particular callsite.
        uint64_t NumCalls = 0;
        for (const Callsite &BInfo : Targets)
          NumCalls += BInfo.Branches;
        if (!IsJumpTable)
          FuncTotalIndirectCalls += NumCalls;
        else
          FuncTotalIndirectJmps += NumCalls;

        // If FLAGS regs is alive after this jmp site, do not try
        // promoting because we will clobber FLAGS.
        if (IsJumpTable) {
          ErrorOr<const BitVector &> State =
              Info.getLivenessAnalysis().getStateBefore(Inst);
          if (!State || (State && (*State)[BC.MIB->getFlagsReg()])) {
            if (opts::Verbosity >= 1)
              BC.outs() << "BOLT-INFO: ICP failed in " << Function << " @ "
                        << InstIdx << " in " << BB->getName()
                        << ", calls = " << NumCalls
                        << (State ? ", cannot clobber flags reg.\n"
                                  : ", no liveness data available.\n");
            continue;
          }
        }

        // Should this callsite be optimized?  Return the number of targets
        // to use when promoting this call.  A value of zero means to skip
        // this callsite.
        size_t N = canPromoteCallsite(*BB, Inst, Targets, NumCalls);

        // If it is a jump table and it failed to meet our initial threshold,
        // proceed to findCallTargetSymbols -- it may reevaluate N if
        // memory profile is present
        if (!N && !IsJumpTable)
          continue;

        if (opts::Verbosity >= 1)
          printCallsiteInfo(*BB, Inst, Targets, N, NumCalls);

        // Find MCSymbols or absolute addresses for each call target.
        MCInst *TargetFetchInst = nullptr;
        const SymTargetsType SymTargets =
            findCallTargetSymbols(Targets, N, *BB, Inst, TargetFetchInst);

        // findCallTargetSymbols may have changed N if mem profile is available
        // for jump tables
        if (!N)
          continue;

        LLVM_DEBUG(printDecision(dbgs(), Targets, N));

        // If we can't resolve any of the target symbols, punt on this callsite.
        // TODO: can this ever happen?
        if (SymTargets.size() < N) {
          const size_t LastTarget = SymTargets.size();
          if (opts::Verbosity >= 1)
            BC.outs() << "BOLT-INFO: ICP failed in " << Function << " @ "
                      << InstIdx << " in " << BB->getName()
                      << ", calls = " << NumCalls
                      << ", ICP failed to find target symbol for "
                      << Targets[LastTarget].To.Sym->getName() << "\n";
          continue;
        }

        MethodInfoType MethodInfo;

        if (!IsJumpTable) {
          // Try to eliminate the vtable load for virtual method calls.
          MethodInfo = maybeGetVtableSyms(*BB, Inst, SymTargets);
          TotalMethodLoadsEliminated += MethodInfo.first.empty() ? 0 : 1;
          LLVM_DEBUG(dbgs()
                     << "BOLT-INFO: ICP "
                     << (!MethodInfo.first.empty() ? "found" : "did not find")
                     << " vtables for all methods.\n");
        } else if (TargetFetchInst) {
          // Hot-index-based jump table promotion.
          ++TotalIndexBasedJumps;
          MethodInfo.second.push_back(TargetFetchInst);
        }

        // Generate new promoted call code for this callsite.
        MCPlusBuilder::BlocksVectorTy ICPcode =
            (IsJumpTable && !opts::ICPJumpTablesByTarget)
                ? BC.MIB->jumpTablePromotion(Inst, SymTargets,
                                             MethodInfo.second, BC.Ctx.get())
                : BC.MIB->indirectCallPromotion(
                      Inst, SymTargets, MethodInfo.first, MethodInfo.second,
                      opts::ICPOldCodeSequence, BC.Ctx.get());

        if (ICPcode.empty()) {
          if (opts::Verbosity >= 1)
            BC.outs() << "BOLT-INFO: ICP failed in " << Function << " @ "
                      << InstIdx << " in " << BB->getName()
                      << ", calls = " << NumCalls
                      << ", unable to generate promoted call code.\n";
          continue;
        }

        LLVM_DEBUG({
          uint64_t Offset = Targets[0].From.Addr;
          dbgs() << "BOLT-INFO: ICP indirect call code:\n";
          for (const auto &entry : ICPcode) {
            const MCSymbol *const &Sym = entry.first;
            const InstructionListType &Insts = entry.second;
            if (Sym)
              dbgs() << Sym->getName() << ":\n";
            Offset = BC.printInstructions(dbgs(), Insts.begin(), Insts.end(),
                                          Offset);
          }
          dbgs() << "---------------------------------------------------\n";
        });

        // Rewrite the CFG with the newly generated ICP code.
        std::vector<std::unique_ptr<BinaryBasicBlock>> NewBBs =
            rewriteCall(*BB, Inst, std::move(ICPcode), MethodInfo.second);

        // Fix the CFG after inserting the new basic blocks.
        BinaryBasicBlock *MergeBlock =
            fixCFG(*BB, IsTailCall, IsJumpTable, std::move(NewBBs), Targets);

        // Since the tail of the original block was split off and it may contain
        // additional indirect calls, we must add the merge block to the set of
        // blocks to process.
        if (MergeBlock)
          BBs.push_back(MergeBlock);

        if (opts::Verbosity >= 1)
          BC.outs() << "BOLT-INFO: ICP succeeded in " << Function << " @ "
                    << InstIdx << " in " << BB->getName()
                    << " -> calls = " << NumCalls << "\n";

        if (IsJumpTable)
          ++TotalOptimizedJumpTableCallsites;
        else
          ++TotalOptimizedIndirectCallsites;

        Modified.insert(&Function);
      }
    }
    TotalIndirectCalls += FuncTotalIndirectCalls;
    TotalIndirectJmps += FuncTotalIndirectJmps;
  }

  // Final whole-binary statistics. std::max(..., 1) guards the percentage
  // denominators against division by zero.
  BC.outs()
      << "BOLT-INFO: ICP total indirect callsites with profile = "
      << TotalIndirectCallsites << "\n"
      << "BOLT-INFO: ICP total jump table callsites = "
      << TotalJumpTableCallsites << "\n"
      << "BOLT-INFO: ICP total number of calls = " << TotalCalls << "\n"
      << "BOLT-INFO: ICP percentage of calls that are indirect = "
      << format("%.1f", (100.0 * TotalIndirectCalls) / TotalCalls) << "%\n"
      << "BOLT-INFO: ICP percentage of indirect calls that can be "
         "optimized = "
      << format("%.1f", (100.0 * TotalNumFrequentCalls) /
                            std::max<size_t>(TotalIndirectCalls, 1))
      << "%\n"
      << "BOLT-INFO: ICP percentage of indirect callsites that are "
         "optimized = "
      << format("%.1f", (100.0 * TotalOptimizedIndirectCallsites) /
                            std::max<uint64_t>(TotalIndirectCallsites, 1))
      << "%\n"
      << "BOLT-INFO: ICP number of method load elimination candidates = "
      << TotalMethodLoadEliminationCandidates << "\n"
      << "BOLT-INFO: ICP percentage of method calls candidates that have "
         "loads eliminated = "
      << format("%.1f",
                (100.0 * TotalMethodLoadsEliminated) /
                    std::max<uint64_t>(TotalMethodLoadEliminationCandidates, 1))
      << "%\n"
      << "BOLT-INFO: ICP percentage of indirect branches that are "
         "optimized = "
      << format("%.1f", (100.0 * TotalNumFrequentJmps) /
                            std::max<uint64_t>(TotalIndirectJmps, 1))
      << "%\n"
      << "BOLT-INFO: ICP percentage of jump table callsites that are "
      << "optimized = "
      << format("%.1f", (100.0 * TotalOptimizedJumpTableCallsites) /
                            std::max<uint64_t>(TotalJumpTableCallsites, 1))
      << "%\n"
      << "BOLT-INFO: ICP number of jump table callsites that can use hot "
      << "indices = " << TotalIndexBasedCandidates << "\n"
      << "BOLT-INFO: ICP percentage of jump table callsites that use hot "
         "indices = "
      << format("%.1f", (100.0 * TotalIndexBasedJumps) /
                            std::max<uint64_t>(TotalIndexBasedCandidates, 1))
      << "%\n";

#ifndef NDEBUG
  verifyProfile(BFs);
#endif
  return Error::success();
}
1483
1484} // namespace bolt
1485} // namespace llvm
1486

// -- Code-browser page footer accidentally captured with this file --
// Provided by KDAB
// Privacy Policy
// Learn to use CMake with our Intro Training
// Find out more
//
// source code of bolt/lib/Passes/IndirectCallPromotion.cpp