1 | //===-- X86PreTileConfig.cpp - Tile Register Pre-configure-----------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
/// \file Pass to pre-configure the shapes of AMX registers.
/// AMX registers need to be configured before use. The shapes of AMX registers
/// are encoded in the 1st and 2nd machine operands of AMX pseudo instructions.
///
/// The instruction ldtilecfg is used to configure the shapes. It must be
/// reachable from all variable shape definitions, i.e. placed after them.
/// ldtilecfg will be inserted more than once if we cannot find a dominating
/// point for all AMX instructions.
///
/// The tile configuration is caller-saved according to the ABI. We need to
/// insert ldtilecfg again after a call instruction if the callee clobbers any
/// AMX registers.
///
/// This pass calculates all points where ldtilecfg needs to be inserted and
/// inserts it there. It reports an error if the reachability conditions
/// aren't met.
23 | // |
24 | //===----------------------------------------------------------------------===// |
25 | |
26 | #include "X86.h" |
27 | #include "X86InstrBuilder.h" |
28 | #include "X86MachineFunctionInfo.h" |
29 | #include "X86RegisterInfo.h" |
30 | #include "X86Subtarget.h" |
31 | #include "llvm/ADT/SmallSet.h" |
32 | #include "llvm/CodeGen/MachineFunctionPass.h" |
33 | #include "llvm/CodeGen/MachineInstr.h" |
34 | #include "llvm/CodeGen/MachineLoopInfo.h" |
35 | #include "llvm/CodeGen/MachineModuleInfo.h" |
36 | #include "llvm/CodeGen/MachineRegisterInfo.h" |
37 | #include "llvm/CodeGen/Passes.h" |
38 | #include "llvm/CodeGen/TargetInstrInfo.h" |
39 | #include "llvm/CodeGen/TargetRegisterInfo.h" |
40 | #include "llvm/InitializePasses.h" |
41 | |
42 | using namespace llvm; |
43 | |
44 | #define DEBUG_TYPE "tile-pre-config" |
45 | |
46 | static void emitErrorMsg(MachineFunction &MF) { |
47 | LLVMContext &Context = MF.getMMI().getModule()->getContext(); |
48 | Context.emitError( |
49 | ErrorStr: MF.getName() + |
50 | ": Failed to config tile register, please define the shape earlier" ); |
51 | } |
52 | |
53 | namespace { |
54 | |
55 | struct MIRef { |
56 | MachineInstr *MI = nullptr; |
57 | MachineBasicBlock *MBB = nullptr; |
58 | // A virtual position for instruction that will be inserted after MI. |
59 | size_t Pos = 0; |
60 | MIRef() = default; |
61 | MIRef(MachineBasicBlock *MBB) : MBB(MBB) { |
62 | for (auto I = MBB->begin(), E = MBB->end(); I != E && I->isPHI(); |
63 | ++I, ++Pos) |
64 | MI = &*I; |
65 | } |
66 | MIRef(MachineInstr *MI) |
67 | : MI(MI), MBB(MI->getParent()), |
68 | Pos(std::distance(first: MBB->instr_begin(), last: ++MI->getIterator())) {} |
69 | MIRef(MachineInstr *MI, MachineBasicBlock *MBB) |
70 | : MI(MI), MBB(MBB), |
71 | Pos(std::distance(first: MBB->instr_begin(), last: ++MI->getIterator())) {} |
72 | MIRef(MachineInstr *MI, MachineBasicBlock *MBB, size_t Pos) |
73 | : MI(MI), MBB(MBB), Pos(Pos) {} |
74 | operator bool() const { return MBB != nullptr; } |
75 | bool operator==(const MIRef &RHS) const { |
76 | return MI == RHS.MI && MBB == RHS.MBB; |
77 | } |
78 | bool operator!=(const MIRef &RHS) const { return !(*this == RHS); } |
79 | bool operator<(const MIRef &RHS) const { |
80 | // Comparison between different BBs happens when inserting a MIRef into set. |
81 | // So we compare MBB first to make the insertion happy. |
82 | return MBB < RHS.MBB || (MBB == RHS.MBB && Pos < RHS.Pos); |
83 | } |
84 | bool operator>(const MIRef &RHS) const { |
85 | // Comparison between different BBs happens when inserting a MIRef into set. |
86 | // So we compare MBB first to make the insertion happy. |
87 | return MBB > RHS.MBB || (MBB == RHS.MBB && Pos > RHS.Pos); |
88 | } |
89 | }; |
90 | |
91 | struct BBInfo { |
92 | MIRef FirstAMX; |
93 | MIRef LastCall; |
94 | bool HasAMXRegLiveIn = false; |
95 | bool TileCfgForbidden = false; |
96 | bool NeedTileCfgLiveIn = false; |
97 | }; |
98 | |
99 | class X86PreTileConfig : public MachineFunctionPass { |
100 | MachineRegisterInfo *MRI = nullptr; |
101 | const MachineLoopInfo *MLI = nullptr; |
102 | SmallSet<MachineInstr *, 8> DefVisited; |
103 | DenseMap<MachineBasicBlock *, BBInfo> BBVisitedInfo; |
104 | DenseMap<MachineBasicBlock *, SmallVector<MIRef, 8>> ShapeBBs; |
105 | |
106 | /// Check if the callee will clobber AMX registers. |
107 | bool isDestructiveCall(MachineInstr &MI, BitVector UsableRegs) { |
108 | auto Iter = llvm::find_if( |
109 | Range: MI.operands(), P: [](MachineOperand &MO) { return MO.isRegMask(); }); |
110 | if (Iter == MI.operands_end()) |
111 | return false; |
112 | UsableRegs.clearBitsInMask(Mask: Iter->getRegMask()); |
113 | return !UsableRegs.none(); |
114 | } |
115 | |
116 | /// Check if MI is AMX pseudo instruction. |
117 | bool isAMXInstruction(MachineInstr &MI) { |
118 | if (MI.isPHI() || MI.isDebugInstr() || MI.getNumOperands() < 3) |
119 | return false; |
120 | MachineOperand &MO = MI.getOperand(i: 0); |
121 | // We can simply check if it is AMX instruction by its def. |
122 | // But we should exclude old API which uses physical registers. |
123 | if (MO.isReg() && MO.getReg().isVirtual() && |
124 | MRI->getRegClass(Reg: MO.getReg())->getID() == X86::TILERegClassID) { |
125 | collectShapeInfo(MI); |
126 | return true; |
127 | } |
128 | // PTILESTOREDV is the only exception that doesn't def a AMX register. |
129 | return MI.getOpcode() == X86::PTILESTOREDV; |
130 | } |
131 | |
132 | /// Check if it is an edge from loop bottom to loop head. |
133 | bool isLoopBackEdge(MachineBasicBlock *, MachineBasicBlock *Bottom) { |
134 | if (!MLI->isLoopHeader(BB: Header)) |
135 | return false; |
136 | auto *ML = MLI->getLoopFor(BB: Header); |
137 | if (ML->contains(BB: Bottom) && ML->isLoopLatch(BB: Bottom)) |
138 | return true; |
139 | |
140 | return false; |
141 | } |
142 | |
143 | /// Collect the shape def information for later use. |
144 | void collectShapeInfo(MachineInstr &MI); |
145 | |
146 | /// Try to hoist shapes definded below AMX instructions. |
147 | bool hoistShapesInBB(MachineBasicBlock *MBB, SmallVectorImpl<MIRef> &Shapes) { |
148 | MIRef &FirstAMX = BBVisitedInfo[MBB].FirstAMX; |
149 | auto FirstShapeBelowAMX = llvm::lower_bound(Range&: Shapes, Value&: FirstAMX); |
150 | auto InsertPoint = FirstAMX.MI->getIterator(); |
151 | for (auto I = FirstShapeBelowAMX, E = Shapes.end(); I != E; ++I) { |
152 | // Do not hoist instructions that access memory. |
153 | if (I->MI->mayLoadOrStore()) |
154 | return false; |
155 | for (auto &MO : I->MI->operands()) { |
156 | if (MO.isDef()) |
157 | continue; |
158 | // Do not hoist instructions if the sources' def under AMX instruction. |
159 | // TODO: We can handle isMoveImmediate MI here. |
160 | if (MO.isReg() && MIRef(MRI->getVRegDef(Reg: MO.getReg())) > FirstAMX) |
161 | return false; |
162 | // TODO: Maybe need more checks here. |
163 | } |
164 | MBB->insert(I: InsertPoint, M: I->MI->removeFromParent()); |
165 | } |
166 | // We only need to mark the last shape in the BB now. |
167 | Shapes.clear(); |
168 | Shapes.push_back(Elt: MIRef(&*--InsertPoint, MBB)); |
169 | return true; |
170 | } |
171 | |
172 | public: |
173 | X86PreTileConfig() : MachineFunctionPass(ID) {} |
174 | |
175 | /// Return the pass name. |
176 | StringRef getPassName() const override { |
177 | return "Tile Register Pre-configure" ; |
178 | } |
179 | |
180 | /// X86PreTileConfig analysis usage. |
181 | void getAnalysisUsage(AnalysisUsage &AU) const override { |
182 | AU.setPreservesAll(); |
183 | AU.addRequired<MachineLoopInfo>(); |
184 | MachineFunctionPass::getAnalysisUsage(AU); |
185 | } |
186 | |
187 | /// Clear MF related structures. |
188 | void releaseMemory() override { |
189 | ShapeBBs.clear(); |
190 | DefVisited.clear(); |
191 | BBVisitedInfo.clear(); |
192 | } |
193 | |
194 | /// Perform ldtilecfg instructions inserting. |
195 | bool runOnMachineFunction(MachineFunction &MF) override; |
196 | |
197 | static char ID; |
198 | }; |
199 | |
200 | } // end anonymous namespace |
201 | |
// Pass identifier; passed to MachineFunctionPass by the constructor.
char X86PreTileConfig::ID = 0;

// Pass registration boilerplate for "tilepreconfig". The declared dependency
// mirrors the addRequired<MachineLoopInfo>() in getAnalysisUsage().
INITIALIZE_PASS_BEGIN(X86PreTileConfig, "tilepreconfig",
                      "Tile Register Pre-configure", false, false)
INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo)
INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig",
                    "Tile Register Pre-configure", false, false)
209 | |
210 | void X86PreTileConfig::collectShapeInfo(MachineInstr &MI) { |
211 | auto RecordShape = [&](MachineInstr *MI, MachineBasicBlock *MBB) { |
212 | MIRef MIR(MI, MBB); |
213 | auto I = llvm::lower_bound(Range&: ShapeBBs[MBB], Value&: MIR); |
214 | if (I == ShapeBBs[MBB].end() || *I != MIR) |
215 | ShapeBBs[MBB].insert(I, Elt: MIR); |
216 | }; |
217 | |
218 | SmallVector<Register, 8> WorkList( |
219 | {MI.getOperand(i: 1).getReg(), MI.getOperand(i: 2).getReg()}); |
220 | while (!WorkList.empty()) { |
221 | Register R = WorkList.pop_back_val(); |
222 | MachineInstr *DefMI = MRI->getVRegDef(Reg: R); |
223 | assert(DefMI && "R must has one define instruction" ); |
224 | MachineBasicBlock *DefMBB = DefMI->getParent(); |
225 | if (DefMI->isMoveImmediate() || !DefVisited.insert(Ptr: DefMI).second) |
226 | continue; |
227 | if (DefMI->isPHI()) { |
228 | for (unsigned I = 1; I < DefMI->getNumOperands(); I += 2) |
229 | if (isLoopBackEdge(Header: DefMBB, Bottom: DefMI->getOperand(i: I + 1).getMBB())) |
230 | RecordShape(DefMI, DefMBB); // In this case, PHI is also a shape def. |
231 | else |
232 | WorkList.push_back(Elt: DefMI->getOperand(i: I).getReg()); |
233 | } else { |
234 | RecordShape(DefMI, DefMBB); |
235 | } |
236 | } |
237 | } |
238 | |
239 | bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) { |
240 | const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>(); |
241 | const TargetInstrInfo *TII = ST.getInstrInfo(); |
242 | const TargetRegisterInfo *TRI = ST.getRegisterInfo(); |
243 | const TargetRegisterClass *RC = TRI->getRegClass(X86::i: TILERegClassID); |
244 | X86MachineFunctionInfo *X86FI = MF.getInfo<X86MachineFunctionInfo>(); |
245 | |
246 | BitVector AMXRegs(TRI->getNumRegs()); |
247 | for (unsigned I = 0; I < RC->getNumRegs(); I++) |
248 | AMXRegs.set(X86::TMM0 + I); |
249 | |
250 | // Iterate MF to collect information. |
251 | MRI = &MF.getRegInfo(); |
252 | MLI = &getAnalysis<MachineLoopInfo>(); |
253 | SmallSet<MIRef, 8> CfgNeedInsert; |
254 | SmallVector<MachineBasicBlock *, 8> CfgLiveInBBs; |
255 | for (auto &MBB : MF) { |
256 | size_t Pos = 0; |
257 | for (auto &MI : MBB) { |
258 | ++Pos; |
259 | if (isAMXInstruction(MI)) { |
260 | // If there's call before the AMX, we need to reload tile config. |
261 | if (BBVisitedInfo[&MBB].LastCall) |
262 | CfgNeedInsert.insert(V: BBVisitedInfo[&MBB].LastCall); |
263 | else // Otherwise, we need tile config to live in this BB. |
264 | BBVisitedInfo[&MBB].NeedTileCfgLiveIn = true; |
265 | // Always record the first AMX in case there's shape def after it. |
266 | if (!BBVisitedInfo[&MBB].FirstAMX) |
267 | BBVisitedInfo[&MBB].FirstAMX = MIRef(&MI, &MBB, Pos); |
268 | } else if (MI.isCall() && isDestructiveCall(MI, UsableRegs: AMXRegs)) { |
269 | // Record the call only if the callee clobbers all AMX registers. |
270 | BBVisitedInfo[&MBB].LastCall = MIRef(&MI, &MBB, Pos); |
271 | } |
272 | } |
273 | if (BBVisitedInfo[&MBB].NeedTileCfgLiveIn) { |
274 | if (&MBB == &MF.front()) |
275 | CfgNeedInsert.insert(V: MIRef(&MBB)); |
276 | else |
277 | CfgLiveInBBs.push_back(Elt: &MBB); |
278 | } |
279 | if (BBVisitedInfo[&MBB].FirstAMX || BBVisitedInfo[&MBB].HasAMXRegLiveIn) |
280 | for (auto *Succ : MBB.successors()) |
281 | if (!isLoopBackEdge(Header: Succ, Bottom: &MBB)) |
282 | BBVisitedInfo[Succ].HasAMXRegLiveIn = true; |
283 | } |
284 | |
285 | // Update NeedTileCfgLiveIn for predecessors. |
286 | while (!CfgLiveInBBs.empty()) { |
287 | MachineBasicBlock *MBB = CfgLiveInBBs.pop_back_val(); |
288 | for (auto *Pred : MBB->predecessors()) { |
289 | if (BBVisitedInfo[Pred].LastCall) { |
290 | CfgNeedInsert.insert(V: BBVisitedInfo[Pred].LastCall); |
291 | } else if (!BBVisitedInfo[Pred].NeedTileCfgLiveIn) { |
292 | BBVisitedInfo[Pred].NeedTileCfgLiveIn = true; |
293 | if (Pred == &MF.front()) |
294 | CfgNeedInsert.insert(V: MIRef(Pred)); |
295 | else |
296 | CfgLiveInBBs.push_back(Elt: Pred); |
297 | } |
298 | } |
299 | } |
300 | |
301 | // There's no AMX instruction if we didn't find a tile config live in point. |
302 | if (CfgNeedInsert.empty()) |
303 | return false; |
304 | X86FI->setHasVirtualTileReg(true); |
305 | |
306 | // Avoid to insert ldtilecfg before any shape defs. |
307 | SmallVector<MachineBasicBlock *, 8> WorkList; |
308 | for (auto &I : ShapeBBs) { |
309 | // TODO: We can hoist shapes across BBs here. |
310 | if (BBVisitedInfo[I.first].HasAMXRegLiveIn) { |
311 | // We are not able to config tile registers since the shape to config |
312 | // is not defined yet. Emit error message and continue. The function |
313 | // would not config tile registers. |
314 | emitErrorMsg(MF); |
315 | return false; |
316 | } |
317 | if (BBVisitedInfo[I.first].FirstAMX && |
318 | BBVisitedInfo[I.first].FirstAMX < I.second.back() && |
319 | !hoistShapesInBB(MBB: I.first, Shapes&: I.second)) { |
320 | emitErrorMsg(MF); |
321 | return false; |
322 | } |
323 | WorkList.push_back(Elt: I.first); |
324 | } |
325 | while (!WorkList.empty()) { |
326 | MachineBasicBlock *MBB = WorkList.pop_back_val(); |
327 | for (auto *Pred : MBB->predecessors()) { |
328 | if (!BBVisitedInfo[Pred].TileCfgForbidden && !isLoopBackEdge(Header: MBB, Bottom: Pred)) { |
329 | BBVisitedInfo[Pred].TileCfgForbidden = true; |
330 | WorkList.push_back(Elt: Pred); |
331 | } |
332 | } |
333 | } |
334 | |
335 | DebugLoc DL; |
336 | SmallSet<MIRef, 8> VisitedOrInserted; |
337 | int SS = MF.getFrameInfo().CreateStackObject( |
338 | Size: ST.getTileConfigSize(), Alignment: ST.getTileConfigAlignment(), isSpillSlot: false); |
339 | |
340 | // Try to insert for the tile config live in points. |
341 | for (const auto &I : CfgNeedInsert) { |
342 | SmallSet<MIRef, 8> InsertPoints; |
343 | SmallVector<MIRef, 8> WorkList({I}); |
344 | while (!WorkList.empty()) { |
345 | MIRef I = WorkList.pop_back_val(); |
346 | if (!VisitedOrInserted.count(V: I)) { |
347 | if (!BBVisitedInfo[I.MBB].TileCfgForbidden) { |
348 | // If the BB is all shapes reachable, stop sink and try to insert. |
349 | InsertPoints.insert(V: I); |
350 | } else { |
351 | // Avoid the BB to be multi visited. |
352 | VisitedOrInserted.insert(V: I); |
353 | // Sink the inserting point along the chain with NeedTileCfgLiveIn = |
354 | // true when MBB isn't all shapes reachable. |
355 | for (auto *Succ : I.MBB->successors()) |
356 | if (BBVisitedInfo[Succ].NeedTileCfgLiveIn) |
357 | WorkList.push_back(Elt: MIRef(Succ)); |
358 | } |
359 | } |
360 | } |
361 | |
362 | // A given point might be forked due to shape conditions are not met. |
363 | for (MIRef I : InsertPoints) { |
364 | // Make sure we insert ldtilecfg after the last shape def in MBB. |
365 | if (ShapeBBs.count(Val: I.MBB) && I < ShapeBBs[I.MBB].back()) |
366 | I = ShapeBBs[I.MBB].back(); |
367 | // There're chances the MBB is sunk more than once. Record it to avoid |
368 | // multi insert. |
369 | if (VisitedOrInserted.insert(V: I).second) { |
370 | auto II = I.MI ? I.MI->getIterator() : I.MBB->instr_begin(); |
371 | addFrameReference(BuildMI(*I.MBB, ++II, DL, TII->get(X86::Opcode: PLDTILECFGV)), |
372 | SS); |
373 | } |
374 | } |
375 | } |
376 | |
377 | // Zero stack slot. |
378 | MachineBasicBlock &MBB = MF.front(); |
379 | MachineInstr *MI = &*MBB.begin(); |
380 | if (ST.hasAVX512()) { |
381 | Register Zmm = MRI->createVirtualRegister(&X86::VR512RegClass); |
382 | BuildMI(MBB, MI, DL, TII->get(X86::Opcode: AVX512_512_SET0), Zmm); |
383 | addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::Opcode: VMOVUPSZmr)), SS) |
384 | .addReg(Zmm); |
385 | } else if (ST.hasAVX2()) { |
386 | Register Ymm = MRI->createVirtualRegister(&X86::VR256RegClass); |
387 | BuildMI(MBB, MI, DL, TII->get(X86::Opcode: AVX_SET0), Ymm); |
388 | addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::Opcode: VMOVUPSYmr)), SS) |
389 | .addReg(Ymm); |
390 | addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::Opcode: VMOVUPSYmr)), SS, 32) |
391 | .addReg(Ymm); |
392 | } else { |
393 | assert(ST.hasSSE2() && "AMX should assume SSE2 enabled" ); |
394 | unsigned StoreOpc = ST.hasAVX() ? X86::VMOVUPSmr : X86::MOVUPSmr; |
395 | Register Xmm = MRI->createVirtualRegister(&X86::VR128RegClass); |
396 | BuildMI(MBB, MI, DL, TII->get(X86::Opcode: V_SET0), Xmm); |
397 | addFrameReference(MIB: BuildMI(BB&: MBB, I: MI, MIMD: DL, MCID: TII->get(Opcode: StoreOpc)), FI: SS).addReg(RegNo: Xmm); |
398 | addFrameReference(MIB: BuildMI(BB&: MBB, I: MI, MIMD: DL, MCID: TII->get(Opcode: StoreOpc)), FI: SS, Offset: 16) |
399 | .addReg(RegNo: Xmm); |
400 | addFrameReference(MIB: BuildMI(BB&: MBB, I: MI, MIMD: DL, MCID: TII->get(Opcode: StoreOpc)), FI: SS, Offset: 32) |
401 | .addReg(RegNo: Xmm); |
402 | addFrameReference(MIB: BuildMI(BB&: MBB, I: MI, MIMD: DL, MCID: TII->get(Opcode: StoreOpc)), FI: SS, Offset: 48) |
403 | .addReg(RegNo: Xmm); |
404 | } |
405 | // Fill in the palette first. |
406 | addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::Opcode: MOV8mi)), SS).addImm(1); |
407 | |
408 | return true; |
409 | } |
410 | |
411 | FunctionPass *llvm::createX86PreTileConfigPass() { |
412 | return new X86PreTileConfig(); |
413 | } |
414 | |