1//===-- X86PreTileConfig.cpp - Tile Register Pre-configure-----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
/// \file Pass to pre-configure the shapes of AMX registers.
/// AMX registers need to be configured before use. The shape of an AMX
/// register is encoded in the 1st and 2nd machine operands of the AMX pseudo
/// instructions.
///
/// The instruction ldtilecfg is used to configure the shapes. It must be
/// reachable from the definitions of all variable shapes. ldtilecfg will be
/// inserted more than once if we cannot find a single point that dominates all
/// AMX instructions.
///
/// The configure register is caller-saved according to the ABI. We need to
/// insert ldtilecfg again after a call instruction if the callee clobbers any
/// AMX registers.
///
/// This pass calculates all the points where ldtilecfg needs to be inserted
/// and inserts it there. It reports an error if the reachability conditions
/// aren't met.
23//
24//===----------------------------------------------------------------------===//
25
26#include "X86.h"
27#include "X86InstrBuilder.h"
28#include "X86MachineFunctionInfo.h"
29#include "X86RegisterInfo.h"
30#include "X86Subtarget.h"
31#include "llvm/ADT/SmallSet.h"
32#include "llvm/CodeGen/MachineFunctionPass.h"
33#include "llvm/CodeGen/MachineInstr.h"
34#include "llvm/CodeGen/MachineLoopInfo.h"
35#include "llvm/CodeGen/MachineModuleInfo.h"
36#include "llvm/CodeGen/MachineRegisterInfo.h"
37#include "llvm/CodeGen/Passes.h"
38#include "llvm/CodeGen/TargetInstrInfo.h"
39#include "llvm/CodeGen/TargetRegisterInfo.h"
40#include "llvm/InitializePasses.h"
41
42using namespace llvm;
43
44#define DEBUG_TYPE "tile-pre-config"
45
46static void emitErrorMsg(MachineFunction &MF) {
47 LLVMContext &Context = MF.getMMI().getModule()->getContext();
48 Context.emitError(
49 ErrorStr: MF.getName() +
50 ": Failed to config tile register, please define the shape earlier");
51}
52
53namespace {
54
55struct MIRef {
56 MachineInstr *MI = nullptr;
57 MachineBasicBlock *MBB = nullptr;
58 // A virtual position for instruction that will be inserted after MI.
59 size_t Pos = 0;
60 MIRef() = default;
61 MIRef(MachineBasicBlock *MBB) : MBB(MBB) {
62 for (auto I = MBB->begin(), E = MBB->end(); I != E && I->isPHI();
63 ++I, ++Pos)
64 MI = &*I;
65 }
66 MIRef(MachineInstr *MI)
67 : MI(MI), MBB(MI->getParent()),
68 Pos(std::distance(first: MBB->instr_begin(), last: ++MI->getIterator())) {}
69 MIRef(MachineInstr *MI, MachineBasicBlock *MBB)
70 : MI(MI), MBB(MBB),
71 Pos(std::distance(first: MBB->instr_begin(), last: ++MI->getIterator())) {}
72 MIRef(MachineInstr *MI, MachineBasicBlock *MBB, size_t Pos)
73 : MI(MI), MBB(MBB), Pos(Pos) {}
74 operator bool() const { return MBB != nullptr; }
75 bool operator==(const MIRef &RHS) const {
76 return MI == RHS.MI && MBB == RHS.MBB;
77 }
78 bool operator!=(const MIRef &RHS) const { return !(*this == RHS); }
79 bool operator<(const MIRef &RHS) const {
80 // Comparison between different BBs happens when inserting a MIRef into set.
81 // So we compare MBB first to make the insertion happy.
82 return MBB < RHS.MBB || (MBB == RHS.MBB && Pos < RHS.Pos);
83 }
84 bool operator>(const MIRef &RHS) const {
85 // Comparison between different BBs happens when inserting a MIRef into set.
86 // So we compare MBB first to make the insertion happy.
87 return MBB > RHS.MBB || (MBB == RHS.MBB && Pos > RHS.Pos);
88 }
89};
90
91struct BBInfo {
92 MIRef FirstAMX;
93 MIRef LastCall;
94 bool HasAMXRegLiveIn = false;
95 bool TileCfgForbidden = false;
96 bool NeedTileCfgLiveIn = false;
97};
98
99class X86PreTileConfig : public MachineFunctionPass {
100 MachineRegisterInfo *MRI = nullptr;
101 const MachineLoopInfo *MLI = nullptr;
102 SmallSet<MachineInstr *, 8> DefVisited;
103 DenseMap<MachineBasicBlock *, BBInfo> BBVisitedInfo;
104 DenseMap<MachineBasicBlock *, SmallVector<MIRef, 8>> ShapeBBs;
105
106 /// Check if the callee will clobber AMX registers.
107 bool isDestructiveCall(MachineInstr &MI, BitVector UsableRegs) {
108 auto Iter = llvm::find_if(
109 Range: MI.operands(), P: [](MachineOperand &MO) { return MO.isRegMask(); });
110 if (Iter == MI.operands_end())
111 return false;
112 UsableRegs.clearBitsInMask(Mask: Iter->getRegMask());
113 return !UsableRegs.none();
114 }
115
116 /// Check if MI is AMX pseudo instruction.
117 bool isAMXInstruction(MachineInstr &MI) {
118 if (MI.isPHI() || MI.isDebugInstr() || MI.getNumOperands() < 3)
119 return false;
120 MachineOperand &MO = MI.getOperand(i: 0);
121 // We can simply check if it is AMX instruction by its def.
122 // But we should exclude old API which uses physical registers.
123 if (MO.isReg() && MO.getReg().isVirtual() &&
124 MRI->getRegClass(Reg: MO.getReg())->getID() == X86::TILERegClassID) {
125 collectShapeInfo(MI);
126 return true;
127 }
128 // PTILESTOREDV is the only exception that doesn't def a AMX register.
129 return MI.getOpcode() == X86::PTILESTOREDV;
130 }
131
132 /// Check if it is an edge from loop bottom to loop head.
133 bool isLoopBackEdge(MachineBasicBlock *Header, MachineBasicBlock *Bottom) {
134 if (!MLI->isLoopHeader(BB: Header))
135 return false;
136 auto *ML = MLI->getLoopFor(BB: Header);
137 if (ML->contains(BB: Bottom) && ML->isLoopLatch(BB: Bottom))
138 return true;
139
140 return false;
141 }
142
143 /// Collect the shape def information for later use.
144 void collectShapeInfo(MachineInstr &MI);
145
146 /// Try to hoist shapes definded below AMX instructions.
147 bool hoistShapesInBB(MachineBasicBlock *MBB, SmallVectorImpl<MIRef> &Shapes) {
148 MIRef &FirstAMX = BBVisitedInfo[MBB].FirstAMX;
149 auto FirstShapeBelowAMX = llvm::lower_bound(Range&: Shapes, Value&: FirstAMX);
150 auto InsertPoint = FirstAMX.MI->getIterator();
151 for (auto I = FirstShapeBelowAMX, E = Shapes.end(); I != E; ++I) {
152 // Do not hoist instructions that access memory.
153 if (I->MI->mayLoadOrStore())
154 return false;
155 for (auto &MO : I->MI->operands()) {
156 if (MO.isDef())
157 continue;
158 // Do not hoist instructions if the sources' def under AMX instruction.
159 // TODO: We can handle isMoveImmediate MI here.
160 if (MO.isReg() && MIRef(MRI->getVRegDef(Reg: MO.getReg())) > FirstAMX)
161 return false;
162 // TODO: Maybe need more checks here.
163 }
164 MBB->insert(I: InsertPoint, M: I->MI->removeFromParent());
165 }
166 // We only need to mark the last shape in the BB now.
167 Shapes.clear();
168 Shapes.push_back(Elt: MIRef(&*--InsertPoint, MBB));
169 return true;
170 }
171
172public:
173 X86PreTileConfig() : MachineFunctionPass(ID) {}
174
175 /// Return the pass name.
176 StringRef getPassName() const override {
177 return "Tile Register Pre-configure";
178 }
179
180 /// X86PreTileConfig analysis usage.
181 void getAnalysisUsage(AnalysisUsage &AU) const override {
182 AU.setPreservesAll();
183 AU.addRequired<MachineLoopInfo>();
184 MachineFunctionPass::getAnalysisUsage(AU);
185 }
186
187 /// Clear MF related structures.
188 void releaseMemory() override {
189 ShapeBBs.clear();
190 DefVisited.clear();
191 BBVisitedInfo.clear();
192 }
193
194 /// Perform ldtilecfg instructions inserting.
195 bool runOnMachineFunction(MachineFunction &MF) override;
196
197 static char ID;
198};
199
200} // end anonymous namespace
201
202char X86PreTileConfig::ID = 0;
203
204INITIALIZE_PASS_BEGIN(X86PreTileConfig, "tilepreconfig",
205 "Tile Register Pre-configure", false, false)
206INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo)
207INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig",
208 "Tile Register Pre-configure", false, false)
209
210void X86PreTileConfig::collectShapeInfo(MachineInstr &MI) {
211 auto RecordShape = [&](MachineInstr *MI, MachineBasicBlock *MBB) {
212 MIRef MIR(MI, MBB);
213 auto I = llvm::lower_bound(Range&: ShapeBBs[MBB], Value&: MIR);
214 if (I == ShapeBBs[MBB].end() || *I != MIR)
215 ShapeBBs[MBB].insert(I, Elt: MIR);
216 };
217
218 SmallVector<Register, 8> WorkList(
219 {MI.getOperand(i: 1).getReg(), MI.getOperand(i: 2).getReg()});
220 while (!WorkList.empty()) {
221 Register R = WorkList.pop_back_val();
222 MachineInstr *DefMI = MRI->getVRegDef(Reg: R);
223 assert(DefMI && "R must has one define instruction");
224 MachineBasicBlock *DefMBB = DefMI->getParent();
225 if (DefMI->isMoveImmediate() || !DefVisited.insert(Ptr: DefMI).second)
226 continue;
227 if (DefMI->isPHI()) {
228 for (unsigned I = 1; I < DefMI->getNumOperands(); I += 2)
229 if (isLoopBackEdge(Header: DefMBB, Bottom: DefMI->getOperand(i: I + 1).getMBB()))
230 RecordShape(DefMI, DefMBB); // In this case, PHI is also a shape def.
231 else
232 WorkList.push_back(Elt: DefMI->getOperand(i: I).getReg());
233 } else {
234 RecordShape(DefMI, DefMBB);
235 }
236 }
237}
238
239bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {
240 const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>();
241 const TargetInstrInfo *TII = ST.getInstrInfo();
242 const TargetRegisterInfo *TRI = ST.getRegisterInfo();
243 const TargetRegisterClass *RC = TRI->getRegClass(X86::i: TILERegClassID);
244 X86MachineFunctionInfo *X86FI = MF.getInfo<X86MachineFunctionInfo>();
245
246 BitVector AMXRegs(TRI->getNumRegs());
247 for (unsigned I = 0; I < RC->getNumRegs(); I++)
248 AMXRegs.set(X86::TMM0 + I);
249
250 // Iterate MF to collect information.
251 MRI = &MF.getRegInfo();
252 MLI = &getAnalysis<MachineLoopInfo>();
253 SmallSet<MIRef, 8> CfgNeedInsert;
254 SmallVector<MachineBasicBlock *, 8> CfgLiveInBBs;
255 for (auto &MBB : MF) {
256 size_t Pos = 0;
257 for (auto &MI : MBB) {
258 ++Pos;
259 if (isAMXInstruction(MI)) {
260 // If there's call before the AMX, we need to reload tile config.
261 if (BBVisitedInfo[&MBB].LastCall)
262 CfgNeedInsert.insert(V: BBVisitedInfo[&MBB].LastCall);
263 else // Otherwise, we need tile config to live in this BB.
264 BBVisitedInfo[&MBB].NeedTileCfgLiveIn = true;
265 // Always record the first AMX in case there's shape def after it.
266 if (!BBVisitedInfo[&MBB].FirstAMX)
267 BBVisitedInfo[&MBB].FirstAMX = MIRef(&MI, &MBB, Pos);
268 } else if (MI.isCall() && isDestructiveCall(MI, UsableRegs: AMXRegs)) {
269 // Record the call only if the callee clobbers all AMX registers.
270 BBVisitedInfo[&MBB].LastCall = MIRef(&MI, &MBB, Pos);
271 }
272 }
273 if (BBVisitedInfo[&MBB].NeedTileCfgLiveIn) {
274 if (&MBB == &MF.front())
275 CfgNeedInsert.insert(V: MIRef(&MBB));
276 else
277 CfgLiveInBBs.push_back(Elt: &MBB);
278 }
279 if (BBVisitedInfo[&MBB].FirstAMX || BBVisitedInfo[&MBB].HasAMXRegLiveIn)
280 for (auto *Succ : MBB.successors())
281 if (!isLoopBackEdge(Header: Succ, Bottom: &MBB))
282 BBVisitedInfo[Succ].HasAMXRegLiveIn = true;
283 }
284
285 // Update NeedTileCfgLiveIn for predecessors.
286 while (!CfgLiveInBBs.empty()) {
287 MachineBasicBlock *MBB = CfgLiveInBBs.pop_back_val();
288 for (auto *Pred : MBB->predecessors()) {
289 if (BBVisitedInfo[Pred].LastCall) {
290 CfgNeedInsert.insert(V: BBVisitedInfo[Pred].LastCall);
291 } else if (!BBVisitedInfo[Pred].NeedTileCfgLiveIn) {
292 BBVisitedInfo[Pred].NeedTileCfgLiveIn = true;
293 if (Pred == &MF.front())
294 CfgNeedInsert.insert(V: MIRef(Pred));
295 else
296 CfgLiveInBBs.push_back(Elt: Pred);
297 }
298 }
299 }
300
301 // There's no AMX instruction if we didn't find a tile config live in point.
302 if (CfgNeedInsert.empty())
303 return false;
304 X86FI->setHasVirtualTileReg(true);
305
306 // Avoid to insert ldtilecfg before any shape defs.
307 SmallVector<MachineBasicBlock *, 8> WorkList;
308 for (auto &I : ShapeBBs) {
309 // TODO: We can hoist shapes across BBs here.
310 if (BBVisitedInfo[I.first].HasAMXRegLiveIn) {
311 // We are not able to config tile registers since the shape to config
312 // is not defined yet. Emit error message and continue. The function
313 // would not config tile registers.
314 emitErrorMsg(MF);
315 return false;
316 }
317 if (BBVisitedInfo[I.first].FirstAMX &&
318 BBVisitedInfo[I.first].FirstAMX < I.second.back() &&
319 !hoistShapesInBB(MBB: I.first, Shapes&: I.second)) {
320 emitErrorMsg(MF);
321 return false;
322 }
323 WorkList.push_back(Elt: I.first);
324 }
325 while (!WorkList.empty()) {
326 MachineBasicBlock *MBB = WorkList.pop_back_val();
327 for (auto *Pred : MBB->predecessors()) {
328 if (!BBVisitedInfo[Pred].TileCfgForbidden && !isLoopBackEdge(Header: MBB, Bottom: Pred)) {
329 BBVisitedInfo[Pred].TileCfgForbidden = true;
330 WorkList.push_back(Elt: Pred);
331 }
332 }
333 }
334
335 DebugLoc DL;
336 SmallSet<MIRef, 8> VisitedOrInserted;
337 int SS = MF.getFrameInfo().CreateStackObject(
338 Size: ST.getTileConfigSize(), Alignment: ST.getTileConfigAlignment(), isSpillSlot: false);
339
340 // Try to insert for the tile config live in points.
341 for (const auto &I : CfgNeedInsert) {
342 SmallSet<MIRef, 8> InsertPoints;
343 SmallVector<MIRef, 8> WorkList({I});
344 while (!WorkList.empty()) {
345 MIRef I = WorkList.pop_back_val();
346 if (!VisitedOrInserted.count(V: I)) {
347 if (!BBVisitedInfo[I.MBB].TileCfgForbidden) {
348 // If the BB is all shapes reachable, stop sink and try to insert.
349 InsertPoints.insert(V: I);
350 } else {
351 // Avoid the BB to be multi visited.
352 VisitedOrInserted.insert(V: I);
353 // Sink the inserting point along the chain with NeedTileCfgLiveIn =
354 // true when MBB isn't all shapes reachable.
355 for (auto *Succ : I.MBB->successors())
356 if (BBVisitedInfo[Succ].NeedTileCfgLiveIn)
357 WorkList.push_back(Elt: MIRef(Succ));
358 }
359 }
360 }
361
362 // A given point might be forked due to shape conditions are not met.
363 for (MIRef I : InsertPoints) {
364 // Make sure we insert ldtilecfg after the last shape def in MBB.
365 if (ShapeBBs.count(Val: I.MBB) && I < ShapeBBs[I.MBB].back())
366 I = ShapeBBs[I.MBB].back();
367 // There're chances the MBB is sunk more than once. Record it to avoid
368 // multi insert.
369 if (VisitedOrInserted.insert(V: I).second) {
370 auto II = I.MI ? I.MI->getIterator() : I.MBB->instr_begin();
371 addFrameReference(BuildMI(*I.MBB, ++II, DL, TII->get(X86::Opcode: PLDTILECFGV)),
372 SS);
373 }
374 }
375 }
376
377 // Zero stack slot.
378 MachineBasicBlock &MBB = MF.front();
379 MachineInstr *MI = &*MBB.begin();
380 if (ST.hasAVX512()) {
381 Register Zmm = MRI->createVirtualRegister(&X86::VR512RegClass);
382 BuildMI(MBB, MI, DL, TII->get(X86::Opcode: AVX512_512_SET0), Zmm);
383 addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::Opcode: VMOVUPSZmr)), SS)
384 .addReg(Zmm);
385 } else if (ST.hasAVX2()) {
386 Register Ymm = MRI->createVirtualRegister(&X86::VR256RegClass);
387 BuildMI(MBB, MI, DL, TII->get(X86::Opcode: AVX_SET0), Ymm);
388 addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::Opcode: VMOVUPSYmr)), SS)
389 .addReg(Ymm);
390 addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::Opcode: VMOVUPSYmr)), SS, 32)
391 .addReg(Ymm);
392 } else {
393 assert(ST.hasSSE2() && "AMX should assume SSE2 enabled");
394 unsigned StoreOpc = ST.hasAVX() ? X86::VMOVUPSmr : X86::MOVUPSmr;
395 Register Xmm = MRI->createVirtualRegister(&X86::VR128RegClass);
396 BuildMI(MBB, MI, DL, TII->get(X86::Opcode: V_SET0), Xmm);
397 addFrameReference(MIB: BuildMI(BB&: MBB, I: MI, MIMD: DL, MCID: TII->get(Opcode: StoreOpc)), FI: SS).addReg(RegNo: Xmm);
398 addFrameReference(MIB: BuildMI(BB&: MBB, I: MI, MIMD: DL, MCID: TII->get(Opcode: StoreOpc)), FI: SS, Offset: 16)
399 .addReg(RegNo: Xmm);
400 addFrameReference(MIB: BuildMI(BB&: MBB, I: MI, MIMD: DL, MCID: TII->get(Opcode: StoreOpc)), FI: SS, Offset: 32)
401 .addReg(RegNo: Xmm);
402 addFrameReference(MIB: BuildMI(BB&: MBB, I: MI, MIMD: DL, MCID: TII->get(Opcode: StoreOpc)), FI: SS, Offset: 48)
403 .addReg(RegNo: Xmm);
404 }
405 // Fill in the palette first.
406 addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::Opcode: MOV8mi)), SS).addImm(1);
407
408 return true;
409}
410
411FunctionPass *llvm::createX86PreTileConfigPass() {
412 return new X86PreTileConfig();
413}
414

// Source: llvm/lib/Target/X86/X86PreTileConfig.cpp