1//===- bolt/Passes/ValidateMemRefs.cpp ------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "bolt/Passes/ValidateMemRefs.h"
10#include "bolt/Core/ParallelUtilities.h"
11
12#define DEBUG_TYPE "bolt-memrefs"
13
14namespace opts {
15extern llvm::cl::opt<llvm::bolt::JumpTableSupportLevel> JumpTables;
16}
17
18namespace llvm::bolt {
19
20std::atomic<std::uint64_t> ValidateMemRefs::ReplacedReferences{0};
21
22bool ValidateMemRefs::checkAndFixJTReference(BinaryFunction &BF, MCInst &Inst,
23 uint32_t OperandNum,
24 const MCSymbol *Sym,
25 uint64_t Offset) {
26 BinaryContext &BC = BF.getBinaryContext();
27 auto L = BC.scopeLock();
28 BinaryData *BD = BC.getBinaryDataByName(Name: Sym->getName());
29 if (!BD)
30 return false;
31
32 const uint64_t TargetAddress = BD->getAddress() + Offset;
33 JumpTable *JT = BC.getJumpTableContainingAddress(Address: TargetAddress);
34 if (!JT)
35 return false;
36
37 const bool IsLegitAccess = llvm::is_contained(Range&: JT->Parents, Element: &BF);
38 if (IsLegitAccess)
39 return true;
40
41 // Accessing a jump table in another function. This is not a
42 // legitimate jump table access, we need to replace the reference to
43 // the jump table label with a regular rodata reference. Get a
44 // non-JT reference by fetching the symbol 1 byte before the JT
45 // label.
46 MCSymbol *NewSym = BC.getOrCreateGlobalSymbol(Address: TargetAddress - 1, Prefix: "DATAat");
47 BC.MIB->setOperandToSymbolRef(Inst, OpNum: OperandNum, Symbol: NewSym, Addend: 1, Ctx: &*BC.Ctx, RelType: 0);
48 LLVM_DEBUG(dbgs() << "BOLT-DEBUG: replaced reference @" << BF.getPrintName()
49 << " from " << BD->getName() << " to " << NewSym->getName()
50 << " + 1\n");
51 ++ReplacedReferences;
52 return true;
53}
54
55void ValidateMemRefs::runOnFunction(BinaryFunction &BF) {
56 MCPlusBuilder *MIB = BF.getBinaryContext().MIB.get();
57
58 for (BinaryBasicBlock &BB : BF) {
59 for (MCInst &Inst : BB) {
60 for (int I = 0, E = MCPlus::getNumPrimeOperands(Inst); I != E; ++I) {
61 const MCOperand &Operand = Inst.getOperand(i: I);
62 if (!Operand.isExpr())
63 continue;
64
65 const auto [Sym, Offset] = MIB->getTargetSymbolInfo(Expr: Operand.getExpr());
66 if (!Sym)
67 continue;
68
69 checkAndFixJTReference(BF, Inst, OperandNum: I, Sym, Offset);
70 }
71 }
72 }
73}
74
75Error ValidateMemRefs::runOnFunctions(BinaryContext &BC) {
76 if (!BC.isX86())
77 return Error::success();
78
79 // Skip validation if not moving JT
80 if (opts::JumpTables == JTS_NONE || opts::JumpTables == JTS_BASIC)
81 return Error::success();
82
83 ParallelUtilities::WorkFuncWithAllocTy ProcessFunction =
84 [&](BinaryFunction &BF, MCPlusBuilder::AllocatorIdTy AllocId) {
85 runOnFunction(BF);
86 };
87 ParallelUtilities::PredicateTy SkipPredicate = [&](const BinaryFunction &BF) {
88 return !BF.hasCFG();
89 };
90 LLVM_DEBUG(dbgs() << "BOLT-DEBUG: starting memrefs validation pass\n");
91 ParallelUtilities::runOnEachFunctionWithUniqueAllocId(
92 BC, SchedPolicy: ParallelUtilities::SchedulingPolicy::SP_INST_LINEAR, WorkFunction: ProcessFunction,
93 SkipPredicate, LogName: "validate-mem-refs", /*ForceSequential=*/true);
94 LLVM_DEBUG(dbgs() << "BOLT-DEBUG: memrefs validation is concluded\n");
95
96 if (!ReplacedReferences)
97 return Error::success();
98
99 BC.outs() << "BOLT-INFO: validate-mem-refs updated " << ReplacedReferences
100 << " object references\n";
101 return Error::success();
102}
103
104} // namespace llvm::bolt
105

// Source code of bolt/lib/Passes/ValidateMemRefs.cpp