1//===- TLSVariableHoist.cpp -------- Remove Redundant TLS Loads ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This pass identifies and eliminates redundant TLS loads when the related
// option is set. For an example, please refer to the comment at the head of
// TLSVariableHoist.h.
11//
12//===----------------------------------------------------------------------===//
13
14#include "llvm/ADT/SmallVector.h"
15#include "llvm/IR/BasicBlock.h"
16#include "llvm/IR/Dominators.h"
17#include "llvm/IR/Function.h"
18#include "llvm/IR/InstrTypes.h"
19#include "llvm/IR/Instruction.h"
20#include "llvm/IR/Instructions.h"
21#include "llvm/IR/IntrinsicInst.h"
22#include "llvm/IR/Module.h"
23#include "llvm/IR/Value.h"
24#include "llvm/InitializePasses.h"
25#include "llvm/Pass.h"
26#include "llvm/Support/Casting.h"
27#include "llvm/Support/Debug.h"
28#include "llvm/Support/raw_ostream.h"
29#include "llvm/Transforms/Scalar.h"
30#include "llvm/Transforms/Scalar/TLSVariableHoist.h"
31#include <algorithm>
32#include <cassert>
33#include <cstdint>
34#include <iterator>
35#include <utility>
36
37using namespace llvm;
38using namespace tlshoist;
39
40#define DEBUG_TYPE "tlshoist"
41
42static cl::opt<bool> TLSLoadHoist(
43 "tls-load-hoist", cl::init(Val: false), cl::Hidden,
44 cl::desc("hoist the TLS loads in PIC model to eliminate redundant "
45 "TLS address calculation."));
46
47namespace {
48
49/// The TLS Variable hoist pass.
50class TLSVariableHoistLegacyPass : public FunctionPass {
51public:
52 static char ID; // Pass identification, replacement for typeid
53
54 TLSVariableHoistLegacyPass() : FunctionPass(ID) {
55 initializeTLSVariableHoistLegacyPassPass(*PassRegistry::getPassRegistry());
56 }
57
58 bool runOnFunction(Function &Fn) override;
59
60 StringRef getPassName() const override { return "TLS Variable Hoist"; }
61
62 void getAnalysisUsage(AnalysisUsage &AU) const override {
63 AU.setPreservesCFG();
64 AU.addRequired<DominatorTreeWrapperPass>();
65 AU.addRequired<LoopInfoWrapperPass>();
66 }
67
68private:
69 TLSVariableHoistPass Impl;
70};
71
72} // end anonymous namespace
73
74char TLSVariableHoistLegacyPass::ID = 0;
75
76INITIALIZE_PASS_BEGIN(TLSVariableHoistLegacyPass, "tlshoist",
77 "TLS Variable Hoist", false, false)
78INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
79INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
80INITIALIZE_PASS_END(TLSVariableHoistLegacyPass, "tlshoist",
81 "TLS Variable Hoist", false, false)
82
83FunctionPass *llvm::createTLSVariableHoistPass() {
84 return new TLSVariableHoistLegacyPass();
85}
86
87/// Perform the TLS Variable Hoist optimization for the given function.
88bool TLSVariableHoistLegacyPass::runOnFunction(Function &Fn) {
89 if (skipFunction(F: Fn))
90 return false;
91
92 LLVM_DEBUG(dbgs() << "********** Begin TLS Variable Hoist **********\n");
93 LLVM_DEBUG(dbgs() << "********** Function: " << Fn.getName() << '\n');
94
95 bool MadeChange =
96 Impl.runImpl(F&: Fn, DT&: getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
97 LI&: getAnalysis<LoopInfoWrapperPass>().getLoopInfo());
98
99 if (MadeChange) {
100 LLVM_DEBUG(dbgs() << "********** Function after TLS Variable Hoist: "
101 << Fn.getName() << '\n');
102 LLVM_DEBUG(dbgs() << Fn);
103 }
104 LLVM_DEBUG(dbgs() << "********** End TLS Variable Hoist **********\n");
105
106 return MadeChange;
107}
108
109void TLSVariableHoistPass::collectTLSCandidate(Instruction *Inst) {
110 // Skip all cast instructions. They are visited indirectly later on.
111 if (Inst->isCast())
112 return;
113
114 // Scan all operands.
115 for (unsigned Idx = 0, E = Inst->getNumOperands(); Idx != E; ++Idx) {
116 auto *GV = dyn_cast<GlobalVariable>(Val: Inst->getOperand(i: Idx));
117 if (!GV || !GV->isThreadLocal())
118 continue;
119
120 // Add Candidate to TLSCandMap (GV --> Candidate).
121 TLSCandMap[GV].addUser(Inst, Idx);
122 }
123}
124
125void TLSVariableHoistPass::collectTLSCandidates(Function &Fn) {
126 // First, quickly check if there is TLS Variable.
127 Module *M = Fn.getParent();
128
129 bool HasTLS = llvm::any_of(
130 Range: M->globals(), P: [](GlobalVariable &GV) { return GV.isThreadLocal(); });
131
132 // If non, directly return.
133 if (!HasTLS)
134 return;
135
136 TLSCandMap.clear();
137
138 // Then, collect TLS Variable info.
139 for (BasicBlock &BB : Fn) {
140 // Ignore unreachable basic blocks.
141 if (!DT->isReachableFromEntry(A: &BB))
142 continue;
143
144 for (Instruction &Inst : BB)
145 collectTLSCandidate(Inst: &Inst);
146 }
147}
148
149static bool oneUseOutsideLoop(tlshoist::TLSCandidate &Cand, LoopInfo *LI) {
150 if (Cand.Users.size() != 1)
151 return false;
152
153 BasicBlock *BB = Cand.Users[0].Inst->getParent();
154 if (LI->getLoopFor(BB))
155 return false;
156
157 return true;
158}
159
160Instruction *TLSVariableHoistPass::getNearestLoopDomInst(BasicBlock *BB,
161 Loop *L) {
162 assert(L && "Unexcepted Loop status!");
163
164 // Get the outermost loop.
165 while (Loop *Parent = L->getParentLoop())
166 L = Parent;
167
168 BasicBlock *PreHeader = L->getLoopPreheader();
169
170 // There is unique predecessor outside the loop.
171 if (PreHeader)
172 return PreHeader->getTerminator();
173
174 BasicBlock *Header = L->getHeader();
175 BasicBlock *Dom = Header;
176 for (BasicBlock *PredBB : predecessors(BB: Header))
177 Dom = DT->findNearestCommonDominator(A: Dom, B: PredBB);
178
179 assert(Dom && "Not find dominator BB!");
180 Instruction *Term = Dom->getTerminator();
181
182 return Term;
183}
184
185Instruction *TLSVariableHoistPass::getDomInst(Instruction *I1,
186 Instruction *I2) {
187 if (!I1)
188 return I2;
189 return DT->findNearestCommonDominator(I1, I2);
190}
191
192BasicBlock::iterator TLSVariableHoistPass::findInsertPos(Function &Fn,
193 GlobalVariable *GV,
194 BasicBlock *&PosBB) {
195 tlshoist::TLSCandidate &Cand = TLSCandMap[GV];
196
197 // We should hoist the TLS use out of loop, so choose its nearest instruction
198 // which dominate the loop and the outside loops (if exist).
199 Instruction *LastPos = nullptr;
200 for (auto &User : Cand.Users) {
201 BasicBlock *BB = User.Inst->getParent();
202 Instruction *Pos = User.Inst;
203 if (Loop *L = LI->getLoopFor(BB)) {
204 Pos = getNearestLoopDomInst(BB, L);
205 assert(Pos && "Not find insert position out of loop!");
206 }
207 Pos = getDomInst(I1: LastPos, I2: Pos);
208 LastPos = Pos;
209 }
210
211 assert(LastPos && "Unexpected insert position!");
212 BasicBlock *Parent = LastPos->getParent();
213 PosBB = Parent;
214 return LastPos->getIterator();
215}
216
217// Generate a bitcast (no type change) to replace the uses of TLS Candidate.
218Instruction *TLSVariableHoistPass::genBitCastInst(Function &Fn,
219 GlobalVariable *GV) {
220 BasicBlock *PosBB = &Fn.getEntryBlock();
221 BasicBlock::iterator Iter = findInsertPos(Fn, GV, PosBB);
222 Type *Ty = GV->getType();
223 auto *CastInst = new BitCastInst(GV, Ty, "tls_bitcast");
224 CastInst->insertInto(ParentBB: PosBB, It: Iter);
225 return CastInst;
226}
227
228bool TLSVariableHoistPass::tryReplaceTLSCandidate(Function &Fn,
229 GlobalVariable *GV) {
230
231 tlshoist::TLSCandidate &Cand = TLSCandMap[GV];
232
233 // If only used 1 time and not in loops, we no need to replace it.
234 if (oneUseOutsideLoop(Cand, LI))
235 return false;
236
237 // Generate a bitcast (no type change)
238 auto *CastInst = genBitCastInst(Fn, GV);
239
240 // to replace the uses of TLS Candidate
241 for (auto &User : Cand.Users)
242 User.Inst->setOperand(i: User.OpndIdx, Val: CastInst);
243
244 return true;
245}
246
247bool TLSVariableHoistPass::tryReplaceTLSCandidates(Function &Fn) {
248 if (TLSCandMap.empty())
249 return false;
250
251 bool Replaced = false;
252 for (auto &GV2Cand : TLSCandMap) {
253 GlobalVariable *GV = GV2Cand.first;
254 Replaced |= tryReplaceTLSCandidate(Fn, GV);
255 }
256
257 return Replaced;
258}
259
260/// Optimize expensive TLS variables in the given function.
261bool TLSVariableHoistPass::runImpl(Function &Fn, DominatorTree &DT,
262 LoopInfo &LI) {
263 if (Fn.hasOptNone())
264 return false;
265
266 if (!TLSLoadHoist && !Fn.getAttributes().hasFnAttr(Kind: "tls-load-hoist"))
267 return false;
268
269 this->LI = &LI;
270 this->DT = &DT;
271 assert(this->LI && this->DT && "Unexcepted requirement!");
272
273 // Collect all TLS variable candidates.
274 collectTLSCandidates(Fn);
275
276 bool MadeChange = tryReplaceTLSCandidates(Fn);
277
278 return MadeChange;
279}
280
281PreservedAnalyses TLSVariableHoistPass::run(Function &F,
282 FunctionAnalysisManager &AM) {
283
284 auto &LI = AM.getResult<LoopAnalysis>(IR&: F);
285 auto &DT = AM.getResult<DominatorTreeAnalysis>(IR&: F);
286
287 if (!runImpl(Fn&: F, DT, LI))
288 return PreservedAnalyses::all();
289
290 PreservedAnalyses PA;
291 PA.preserveSet<CFGAnalyses>();
292 return PA;
293}
294

source code of llvm/lib/Transforms/Scalar/TLSVariableHoist.cpp