1//===-- SPIRVPreLegalizer.cpp - prepare IR for legalization -----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// The pass prepares IR for legalization: it assigns SPIR-V types to registers
// and removes intrinsics which held these types during IR translation.
11// Also it processes constants and registers them in GR to avoid duplication.
12//
13//===----------------------------------------------------------------------===//
14
15#include "SPIRV.h"
16#include "SPIRVSubtarget.h"
17#include "SPIRVUtils.h"
18#include "llvm/ADT/PostOrderIterator.h"
19#include "llvm/Analysis/OptimizationRemarkEmitter.h"
20#include "llvm/IR/Attributes.h"
21#include "llvm/IR/Constants.h"
22#include "llvm/IR/DebugInfoMetadata.h"
23#include "llvm/IR/IntrinsicsSPIRV.h"
24#include "llvm/Target/TargetIntrinsicInfo.h"
25
26#define DEBUG_TYPE "spirv-prelegalizer"
27
28using namespace llvm;
29
30namespace {
31class SPIRVPreLegalizer : public MachineFunctionPass {
32public:
33 static char ID;
34 SPIRVPreLegalizer() : MachineFunctionPass(ID) {
35 initializeSPIRVPreLegalizerPass(*PassRegistry::getPassRegistry());
36 }
37 bool runOnMachineFunction(MachineFunction &MF) override;
38};
39} // namespace
40
41static void addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR) {
42 MachineRegisterInfo &MRI = MF.getRegInfo();
43 DenseMap<MachineInstr *, Register> RegsAlreadyAddedToDT;
44 SmallVector<MachineInstr *, 10> ToErase, ToEraseComposites;
45 for (MachineBasicBlock &MBB : MF) {
46 for (MachineInstr &MI : MBB) {
47 if (!isSpvIntrinsic(MI, Intrinsic::spv_track_constant))
48 continue;
49 ToErase.push_back(Elt: &MI);
50 auto *Const =
51 cast<Constant>(Val: cast<ConstantAsMetadata>(
52 Val: MI.getOperand(i: 3).getMetadata()->getOperand(I: 0))
53 ->getValue());
54 if (auto *GV = dyn_cast<GlobalValue>(Val: Const)) {
55 Register Reg = GR->find(C: GV, MF: &MF);
56 if (!Reg.isValid())
57 GR->add(C: GV, MF: &MF, R: MI.getOperand(i: 2).getReg());
58 else
59 RegsAlreadyAddedToDT[&MI] = Reg;
60 } else {
61 Register Reg = GR->find(C: Const, MF: &MF);
62 if (!Reg.isValid()) {
63 if (auto *ConstVec = dyn_cast<ConstantDataVector>(Val: Const)) {
64 auto *BuildVec = MRI.getVRegDef(Reg: MI.getOperand(i: 2).getReg());
65 assert(BuildVec &&
66 BuildVec->getOpcode() == TargetOpcode::G_BUILD_VECTOR);
67 for (unsigned i = 0; i < ConstVec->getNumElements(); ++i) {
68 // Ensure that OpConstantComposite reuses a constant when it's
69 // already created and available in the same machine function.
70 Constant *ElemConst = ConstVec->getElementAsConstant(i);
71 Register ElemReg = GR->find(C: ElemConst, MF: &MF);
72 if (!ElemReg.isValid())
73 GR->add(C: ElemConst, MF: &MF, R: BuildVec->getOperand(i: 1 + i).getReg());
74 else
75 BuildVec->getOperand(i: 1 + i).setReg(ElemReg);
76 }
77 }
78 GR->add(C: Const, MF: &MF, R: MI.getOperand(i: 2).getReg());
79 } else {
80 RegsAlreadyAddedToDT[&MI] = Reg;
81 // This MI is unused and will be removed. If the MI uses
82 // const_composite, it will be unused and should be removed too.
83 assert(MI.getOperand(2).isReg() && "Reg operand is expected");
84 MachineInstr *SrcMI = MRI.getVRegDef(Reg: MI.getOperand(i: 2).getReg());
85 if (SrcMI && isSpvIntrinsic(*SrcMI, Intrinsic::spv_const_composite))
86 ToEraseComposites.push_back(Elt: SrcMI);
87 }
88 }
89 }
90 }
91 for (MachineInstr *MI : ToErase) {
92 Register Reg = MI->getOperand(i: 2).getReg();
93 if (RegsAlreadyAddedToDT.contains(Val: MI))
94 Reg = RegsAlreadyAddedToDT[MI];
95 auto *RC = MRI.getRegClassOrNull(Reg: MI->getOperand(i: 0).getReg());
96 if (!MRI.getRegClassOrNull(Reg) && RC)
97 MRI.setRegClass(Reg, RC);
98 MRI.replaceRegWith(FromReg: MI->getOperand(i: 0).getReg(), ToReg: Reg);
99 MI->eraseFromParent();
100 }
101 for (MachineInstr *MI : ToEraseComposites)
102 MI->eraseFromParent();
103}
104
105static void foldConstantsIntoIntrinsics(MachineFunction &MF) {
106 SmallVector<MachineInstr *, 10> ToErase;
107 MachineRegisterInfo &MRI = MF.getRegInfo();
108 const unsigned AssignNameOperandShift = 2;
109 for (MachineBasicBlock &MBB : MF) {
110 for (MachineInstr &MI : MBB) {
111 if (!isSpvIntrinsic(MI, Intrinsic::spv_assign_name))
112 continue;
113 unsigned NumOp = MI.getNumExplicitDefs() + AssignNameOperandShift;
114 while (MI.getOperand(i: NumOp).isReg()) {
115 MachineOperand &MOp = MI.getOperand(i: NumOp);
116 MachineInstr *ConstMI = MRI.getVRegDef(Reg: MOp.getReg());
117 assert(ConstMI->getOpcode() == TargetOpcode::G_CONSTANT);
118 MI.removeOperand(OpNo: NumOp);
119 MI.addOperand(Op: MachineOperand::CreateImm(
120 Val: ConstMI->getOperand(i: 1).getCImm()->getZExtValue()));
121 if (MRI.use_empty(RegNo: ConstMI->getOperand(i: 0).getReg()))
122 ToErase.push_back(Elt: ConstMI);
123 }
124 }
125 }
126 for (MachineInstr *MI : ToErase)
127 MI->eraseFromParent();
128}
129
130static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
131 MachineIRBuilder MIB) {
132 // Get access to information about available extensions
133 const SPIRVSubtarget *ST =
134 static_cast<const SPIRVSubtarget *>(&MIB.getMF().getSubtarget());
135 SmallVector<MachineInstr *, 10> ToErase;
136 for (MachineBasicBlock &MBB : MF) {
137 for (MachineInstr &MI : MBB) {
138 if (!isSpvIntrinsic(MI, Intrinsic::spv_bitcast) &&
139 !isSpvIntrinsic(MI, Intrinsic::spv_ptrcast))
140 continue;
141 assert(MI.getOperand(2).isReg());
142 MIB.setInsertPt(MBB&: *MI.getParent(), II: MI);
143 ToErase.push_back(Elt: &MI);
144 if (isSpvIntrinsic(MI, Intrinsic::spv_bitcast)) {
145 MIB.buildBitcast(Dst: MI.getOperand(i: 0).getReg(), Src: MI.getOperand(i: 2).getReg());
146 continue;
147 }
148 Register Def = MI.getOperand(i: 0).getReg();
149 Register Source = MI.getOperand(i: 2).getReg();
150 SPIRVType *BaseTy = GR->getOrCreateSPIRVType(
151 getMDOperandAsType(N: MI.getOperand(i: 3).getMetadata(), I: 0), MIB);
152 SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
153 BaseType: BaseTy, I&: MI, TII: *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
154 SClass: addressSpaceToStorageClass(AddrSpace: MI.getOperand(i: 4).getImm(), STI: *ST));
155
156 // If the bitcast would be redundant, replace all uses with the source
157 // register.
158 if (GR->getSPIRVTypeForVReg(VReg: Source) == AssignedPtrType) {
159 MIB.getMRI()->replaceRegWith(FromReg: Def, ToReg: Source);
160 } else {
161 GR->assignSPIRVTypeToVReg(Type: AssignedPtrType, VReg: Def, MF);
162 MIB.buildBitcast(Dst: Def, Src: Source);
163 }
164 }
165 }
166 for (MachineInstr *MI : ToErase)
167 MI->eraseFromParent();
168}
169
170// Translating GV, IRTranslator sometimes generates following IR:
171// %1 = G_GLOBAL_VALUE
172// %2 = COPY %1
173// %3 = G_ADDRSPACE_CAST %2
174//
175// or
176//
177// %1 = G_ZEXT %2
178// G_MEMCPY ... %2 ...
179//
180// New registers have no SPIRVType and no register class info.
181//
182// Set SPIRVType for GV, propagate it from GV to other instructions,
183// also set register classes.
184static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
185 MachineRegisterInfo &MRI,
186 MachineIRBuilder &MIB) {
187 SPIRVType *SpirvTy = nullptr;
188 assert(MI && "Machine instr is expected");
189 if (MI->getOperand(i: 0).isReg()) {
190 Register Reg = MI->getOperand(i: 0).getReg();
191 SpirvTy = GR->getSPIRVTypeForVReg(VReg: Reg);
192 if (!SpirvTy) {
193 switch (MI->getOpcode()) {
194 case TargetOpcode::G_CONSTANT: {
195 MIB.setInsertPt(MBB&: *MI->getParent(), II: MI);
196 Type *Ty = MI->getOperand(i: 1).getCImm()->getType();
197 SpirvTy = GR->getOrCreateSPIRVType(Ty, MIB);
198 break;
199 }
200 case TargetOpcode::G_GLOBAL_VALUE: {
201 MIB.setInsertPt(MBB&: *MI->getParent(), II: MI);
202 const GlobalValue *Global = MI->getOperand(i: 1).getGlobal();
203 Type *ElementTy = GR->getDeducedGlobalValueType(Global);
204 auto *Ty = TypedPointerType::get(ElementType: ElementTy,
205 AddressSpace: Global->getType()->getAddressSpace());
206 SpirvTy = GR->getOrCreateSPIRVType(Ty, MIB);
207 break;
208 }
209 case TargetOpcode::G_ZEXT: {
210 if (MI->getOperand(i: 1).isReg()) {
211 if (MachineInstr *DefInstr =
212 MRI.getVRegDef(Reg: MI->getOperand(i: 1).getReg())) {
213 if (SPIRVType *Def = propagateSPIRVType(MI: DefInstr, GR, MRI, MIB)) {
214 unsigned CurrentBW = GR->getScalarOrVectorBitWidth(Type: Def);
215 unsigned ExpectedBW =
216 std::max(a: MRI.getType(Reg).getScalarSizeInBits(), b: CurrentBW);
217 unsigned NumElements = GR->getScalarOrVectorComponentCount(Type: Def);
218 SpirvTy = GR->getOrCreateSPIRVIntegerType(BitWidth: ExpectedBW, MIRBuilder&: MIB);
219 if (NumElements > 1)
220 SpirvTy =
221 GR->getOrCreateSPIRVVectorType(BaseType: SpirvTy, NumElements, MIRBuilder&: MIB);
222 }
223 }
224 }
225 break;
226 }
227 case TargetOpcode::G_PTRTOINT:
228 SpirvTy = GR->getOrCreateSPIRVIntegerType(
229 BitWidth: MRI.getType(Reg).getScalarSizeInBits(), MIRBuilder&: MIB);
230 break;
231 case TargetOpcode::G_TRUNC:
232 case TargetOpcode::G_ADDRSPACE_CAST:
233 case TargetOpcode::G_PTR_ADD:
234 case TargetOpcode::COPY: {
235 MachineOperand &Op = MI->getOperand(i: 1);
236 MachineInstr *Def = Op.isReg() ? MRI.getVRegDef(Reg: Op.getReg()) : nullptr;
237 if (Def)
238 SpirvTy = propagateSPIRVType(MI: Def, GR, MRI, MIB);
239 break;
240 }
241 default:
242 break;
243 }
244 if (SpirvTy)
245 GR->assignSPIRVTypeToVReg(Type: SpirvTy, VReg: Reg, MF&: MIB.getMF());
246 if (!MRI.getRegClassOrNull(Reg))
247 MRI.setRegClass(Reg, &SPIRV::IDRegClass);
248 }
249 }
250 return SpirvTy;
251}
252
253static std::pair<Register, unsigned>
254createNewIdReg(SPIRVType *SpvType, Register SrcReg, MachineRegisterInfo &MRI,
255 const SPIRVGlobalRegistry &GR) {
256 if (!SpvType)
257 SpvType = GR.getSPIRVTypeForVReg(VReg: SrcReg);
258 assert(SpvType && "VReg is expected to have SPIRV type");
259 LLT SrcLLT = MRI.getType(Reg: SrcReg);
260 LLT NewT = LLT::scalar(SizeInBits: 32);
261 bool IsFloat = SpvType->getOpcode() == SPIRV::OpTypeFloat;
262 bool IsVectorFloat =
263 SpvType->getOpcode() == SPIRV::OpTypeVector &&
264 GR.getSPIRVTypeForVReg(SpvType->getOperand(1).getReg())->getOpcode() ==
265 SPIRV::OpTypeFloat;
266 IsFloat |= IsVectorFloat;
267 auto GetIdOp = IsFloat ? SPIRV::GET_fID : SPIRV::GET_ID;
268 auto DstClass = IsFloat ? &SPIRV::fIDRegClass : &SPIRV::IDRegClass;
269 if (SrcLLT.isPointer()) {
270 unsigned PtrSz = GR.getPointerSize();
271 NewT = LLT::pointer(AddressSpace: 0, SizeInBits: PtrSz);
272 bool IsVec = SrcLLT.isVector();
273 if (IsVec)
274 NewT = LLT::fixed_vector(NumElements: 2, ScalarTy: NewT);
275 if (PtrSz == 64) {
276 if (IsVec) {
277 GetIdOp = SPIRV::GET_vpID64;
278 DstClass = &SPIRV::vpID64RegClass;
279 } else {
280 GetIdOp = SPIRV::GET_pID64;
281 DstClass = &SPIRV::pID64RegClass;
282 }
283 } else {
284 if (IsVec) {
285 GetIdOp = SPIRV::GET_vpID32;
286 DstClass = &SPIRV::vpID32RegClass;
287 } else {
288 GetIdOp = SPIRV::GET_pID32;
289 DstClass = &SPIRV::pID32RegClass;
290 }
291 }
292 } else if (SrcLLT.isVector()) {
293 NewT = LLT::fixed_vector(NumElements: 2, ScalarTy: NewT);
294 if (IsFloat) {
295 GetIdOp = SPIRV::GET_vfID;
296 DstClass = &SPIRV::vfIDRegClass;
297 } else {
298 GetIdOp = SPIRV::GET_vID;
299 DstClass = &SPIRV::vIDRegClass;
300 }
301 }
302 Register IdReg = MRI.createGenericVirtualRegister(Ty: NewT);
303 MRI.setRegClass(Reg: IdReg, RC: DstClass);
304 return {IdReg, GetIdOp};
305}
306
// Insert ASSIGN_TYPE instruction between Reg and its definition, set NewReg as
308// a dst of the definition, assign SPIRVType to both registers. If SpirvTy is
309// provided, use it as SPIRVType in ASSIGN_TYPE, otherwise create it from Ty.
310// It's used also in SPIRVBuiltins.cpp.
311// TODO: maybe move to SPIRVUtils.
312namespace llvm {
313Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
314 SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB,
315 MachineRegisterInfo &MRI) {
316 MachineInstr *Def = MRI.getVRegDef(Reg);
317 assert((Ty || SpirvTy) && "Either LLVM or SPIRV type is expected.");
318 MIB.setInsertPt(MBB&: *Def->getParent(),
319 II: (Def->getNextNode() ? Def->getNextNode()->getIterator()
320 : Def->getParent()->end()));
321 SpirvTy = SpirvTy ? SpirvTy : GR->getOrCreateSPIRVType(Ty, MIB);
322 Register NewReg = MRI.createGenericVirtualRegister(Ty: MRI.getType(Reg));
323 if (auto *RC = MRI.getRegClassOrNull(Reg)) {
324 MRI.setRegClass(Reg: NewReg, RC);
325 } else {
326 MRI.setRegClass(NewReg, &SPIRV::IDRegClass);
327 MRI.setRegClass(Reg, &SPIRV::IDRegClass);
328 }
329 GR->assignSPIRVTypeToVReg(Type: SpirvTy, VReg: Reg, MF&: MIB.getMF());
330 // This is to make it convenient for Legalizer to get the SPIRVType
331 // when processing the actual MI (i.e. not pseudo one).
332 GR->assignSPIRVTypeToVReg(Type: SpirvTy, VReg: NewReg, MF&: MIB.getMF());
333 // Copy MIFlags from Def to ASSIGN_TYPE instruction. It's required to keep
334 // the flags after instruction selection.
335 const uint32_t Flags = Def->getFlags();
336 MIB.buildInstr(SPIRV::ASSIGN_TYPE)
337 .addDef(Reg)
338 .addUse(NewReg)
339 .addUse(GR->getSPIRVTypeID(SpirvTy))
340 .setMIFlags(Flags);
341 Def->getOperand(i: 0).setReg(NewReg);
342 return NewReg;
343}
344
345void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
346 MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR) {
347 assert(MI.getNumDefs() > 0 && MRI.hasOneUse(MI.getOperand(0).getReg()));
348 MachineInstr &AssignTypeInst =
349 *(MRI.use_instr_begin(RegNo: MI.getOperand(i: 0).getReg()));
350 auto NewReg =
351 createNewIdReg(SpvType: nullptr, SrcReg: MI.getOperand(i: 0).getReg(), MRI, GR: *GR).first;
352 AssignTypeInst.getOperand(i: 1).setReg(NewReg);
353 MI.getOperand(i: 0).setReg(NewReg);
354 MIB.setInsertPt(MBB&: *MI.getParent(),
355 II: (MI.getNextNode() ? MI.getNextNode()->getIterator()
356 : MI.getParent()->end()));
357 for (auto &Op : MI.operands()) {
358 if (!Op.isReg() || Op.isDef())
359 continue;
360 auto IdOpInfo = createNewIdReg(SpvType: nullptr, SrcReg: Op.getReg(), MRI, GR: *GR);
361 MIB.buildInstr(Opcode: IdOpInfo.second).addDef(RegNo: IdOpInfo.first).addUse(RegNo: Op.getReg());
362 Op.setReg(IdOpInfo.first);
363 }
364}
365} // namespace llvm
366
367static void generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
368 MachineIRBuilder MIB) {
369 // Get access to information about available extensions
370 const SPIRVSubtarget *ST =
371 static_cast<const SPIRVSubtarget *>(&MIB.getMF().getSubtarget());
372
373 MachineRegisterInfo &MRI = MF.getRegInfo();
374 SmallVector<MachineInstr *, 10> ToErase;
375
376 for (MachineBasicBlock *MBB : post_order(G: &MF)) {
377 if (MBB->empty())
378 continue;
379
380 bool ReachedBegin = false;
381 for (auto MII = std::prev(x: MBB->end()), Begin = MBB->begin();
382 !ReachedBegin;) {
383 MachineInstr &MI = *MII;
384
385 if (isSpvIntrinsic(MI, Intrinsic::spv_assign_ptr_type)) {
386 Register Reg = MI.getOperand(i: 1).getReg();
387 MIB.setInsertPt(MBB&: *MI.getParent(), II: MI.getIterator());
388 SPIRVType *BaseTy = GR->getOrCreateSPIRVType(
389 getMDOperandAsType(N: MI.getOperand(i: 2).getMetadata(), I: 0), MIB);
390 SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
391 BaseType: BaseTy, I&: MI, TII: *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
392 SClass: addressSpaceToStorageClass(AddrSpace: MI.getOperand(i: 3).getImm(), STI: *ST));
393 MachineInstr *Def = MRI.getVRegDef(Reg);
394 assert(Def && "Expecting an instruction that defines the register");
395 // G_GLOBAL_VALUE already has type info.
396 if (Def->getOpcode() != TargetOpcode::G_GLOBAL_VALUE)
397 insertAssignInstr(Reg, Ty: nullptr, SpirvTy: AssignedPtrType, GR, MIB,
398 MRI&: MF.getRegInfo());
399 ToErase.push_back(Elt: &MI);
400 } else if (isSpvIntrinsic(MI, Intrinsic::spv_assign_type)) {
401 Register Reg = MI.getOperand(i: 1).getReg();
402 Type *Ty = getMDOperandAsType(N: MI.getOperand(i: 2).getMetadata(), I: 0);
403 MachineInstr *Def = MRI.getVRegDef(Reg);
404 assert(Def && "Expecting an instruction that defines the register");
405 // G_GLOBAL_VALUE already has type info.
406 if (Def->getOpcode() != TargetOpcode::G_GLOBAL_VALUE)
407 insertAssignInstr(Reg, Ty, SpirvTy: nullptr, GR, MIB, MRI&: MF.getRegInfo());
408 ToErase.push_back(Elt: &MI);
409 } else if (MI.getOpcode() == TargetOpcode::G_CONSTANT ||
410 MI.getOpcode() == TargetOpcode::G_FCONSTANT ||
411 MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR) {
412 // %rc = G_CONSTANT ty Val
413 // ===>
414 // %cty = OpType* ty
415 // %rctmp = G_CONSTANT ty Val
416 // %rc = ASSIGN_TYPE %rctmp, %cty
417 Register Reg = MI.getOperand(i: 0).getReg();
418 if (MRI.hasOneUse(RegNo: Reg)) {
419 MachineInstr &UseMI = *MRI.use_instr_begin(RegNo: Reg);
420 if (isSpvIntrinsic(UseMI, Intrinsic::spv_assign_type) ||
421 isSpvIntrinsic(UseMI, Intrinsic::spv_assign_name))
422 continue;
423 }
424 Type *Ty = nullptr;
425 if (MI.getOpcode() == TargetOpcode::G_CONSTANT)
426 Ty = MI.getOperand(i: 1).getCImm()->getType();
427 else if (MI.getOpcode() == TargetOpcode::G_FCONSTANT)
428 Ty = MI.getOperand(i: 1).getFPImm()->getType();
429 else {
430 assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR);
431 Type *ElemTy = nullptr;
432 MachineInstr *ElemMI = MRI.getVRegDef(Reg: MI.getOperand(i: 1).getReg());
433 assert(ElemMI);
434
435 if (ElemMI->getOpcode() == TargetOpcode::G_CONSTANT)
436 ElemTy = ElemMI->getOperand(i: 1).getCImm()->getType();
437 else if (ElemMI->getOpcode() == TargetOpcode::G_FCONSTANT)
438 ElemTy = ElemMI->getOperand(i: 1).getFPImm()->getType();
439 else
440 llvm_unreachable("Unexpected opcode");
441 unsigned NumElts =
442 MI.getNumExplicitOperands() - MI.getNumExplicitDefs();
443 Ty = VectorType::get(ElementType: ElemTy, NumElements: NumElts, Scalable: false);
444 }
445 insertAssignInstr(Reg, Ty, SpirvTy: nullptr, GR, MIB, MRI);
446 } else if (MI.getOpcode() == TargetOpcode::G_TRUNC ||
447 MI.getOpcode() == TargetOpcode::G_ZEXT ||
448 MI.getOpcode() == TargetOpcode::G_PTRTOINT ||
449 MI.getOpcode() == TargetOpcode::G_GLOBAL_VALUE ||
450 MI.getOpcode() == TargetOpcode::COPY ||
451 MI.getOpcode() == TargetOpcode::G_ADDRSPACE_CAST) {
452 propagateSPIRVType(MI: &MI, GR, MRI, MIB);
453 }
454
455 if (MII == Begin)
456 ReachedBegin = true;
457 else
458 --MII;
459 }
460 }
461 for (MachineInstr *MI : ToErase)
462 MI->eraseFromParent();
463}
464
465// Defined in SPIRVLegalizerInfo.cpp.
466extern bool isTypeFoldingSupported(unsigned Opcode);
467
468static void processInstrsWithTypeFolding(MachineFunction &MF,
469 SPIRVGlobalRegistry *GR,
470 MachineIRBuilder MIB) {
471 MachineRegisterInfo &MRI = MF.getRegInfo();
472 for (MachineBasicBlock &MBB : MF) {
473 for (MachineInstr &MI : MBB) {
474 if (isTypeFoldingSupported(Opcode: MI.getOpcode()))
475 processInstr(MI, MIB, MRI, GR);
476 }
477 }
478
479 for (MachineBasicBlock &MBB : MF) {
480 for (MachineInstr &MI : MBB) {
481 // We need to rewrite dst types for ASSIGN_TYPE instrs to be able
482 // to perform tblgen'erated selection and we can't do that on Legalizer
483 // as it operates on gMIR only.
484 if (MI.getOpcode() != SPIRV::ASSIGN_TYPE)
485 continue;
486 Register SrcReg = MI.getOperand(i: 1).getReg();
487 unsigned Opcode = MRI.getVRegDef(Reg: SrcReg)->getOpcode();
488 if (!isTypeFoldingSupported(Opcode))
489 continue;
490 Register DstReg = MI.getOperand(i: 0).getReg();
491 bool IsDstPtr = MRI.getType(Reg: DstReg).isPointer();
492 bool isDstVec = MRI.getType(Reg: DstReg).isVector();
493 if (IsDstPtr || isDstVec)
494 MRI.setRegClass(DstReg, &SPIRV::IDRegClass);
495 // Don't need to reset type of register holding constant and used in
496 // G_ADDRSPACE_CAST, since it breaks legalizer.
497 if (Opcode == TargetOpcode::G_CONSTANT && MRI.hasOneUse(RegNo: DstReg)) {
498 MachineInstr &UseMI = *MRI.use_instr_begin(RegNo: DstReg);
499 if (UseMI.getOpcode() == TargetOpcode::G_ADDRSPACE_CAST)
500 continue;
501 }
502 MRI.setType(VReg: DstReg, Ty: IsDstPtr ? LLT::pointer(AddressSpace: 0, SizeInBits: GR->getPointerSize())
503 : LLT::scalar(SizeInBits: 32));
504 }
505 }
506}
507
508// Find basic blocks of the switch and replace registers in spv_switch() by its
509// MBB equivalent.
510static void processSwitches(MachineFunction &MF, SPIRVGlobalRegistry *GR,
511 MachineIRBuilder MIB) {
512 DenseMap<const BasicBlock *, MachineBasicBlock *> BB2MBB;
513 SmallVector<std::pair<MachineInstr *, SmallVector<MachineInstr *, 8>>>
514 Switches;
515 for (MachineBasicBlock &MBB : MF) {
516 MachineRegisterInfo &MRI = MF.getRegInfo();
517 BB2MBB[MBB.getBasicBlock()] = &MBB;
518 for (MachineInstr &MI : MBB) {
519 if (!isSpvIntrinsic(MI, Intrinsic::spv_switch))
520 continue;
521 // Calls to spv_switch intrinsics representing IR switches.
522 SmallVector<MachineInstr *, 8> NewOps;
523 for (unsigned i = 2; i < MI.getNumOperands(); ++i) {
524 Register Reg = MI.getOperand(i).getReg();
525 if (i % 2 == 1) {
526 MachineInstr *ConstInstr = getDefInstrMaybeConstant(ConstReg&: Reg, MRI: &MRI);
527 NewOps.push_back(Elt: ConstInstr);
528 } else {
529 MachineInstr *BuildMBB = MRI.getVRegDef(Reg);
530 assert(BuildMBB &&
531 BuildMBB->getOpcode() == TargetOpcode::G_BLOCK_ADDR &&
532 BuildMBB->getOperand(1).isBlockAddress() &&
533 BuildMBB->getOperand(1).getBlockAddress());
534 NewOps.push_back(Elt: BuildMBB);
535 }
536 }
537 Switches.push_back(Elt: std::make_pair(x: &MI, y&: NewOps));
538 }
539 }
540
541 SmallPtrSet<MachineInstr *, 8> ToEraseMI;
542 for (auto &SwIt : Switches) {
543 MachineInstr &MI = *SwIt.first;
544 SmallVector<MachineInstr *, 8> &Ins = SwIt.second;
545 SmallVector<MachineOperand, 8> NewOps;
546 for (unsigned i = 0; i < Ins.size(); ++i) {
547 if (Ins[i]->getOpcode() == TargetOpcode::G_BLOCK_ADDR) {
548 BasicBlock *CaseBB =
549 Ins[i]->getOperand(i: 1).getBlockAddress()->getBasicBlock();
550 auto It = BB2MBB.find(Val: CaseBB);
551 if (It == BB2MBB.end())
552 report_fatal_error(reason: "cannot find a machine basic block by a basic "
553 "block in a switch statement");
554 NewOps.push_back(Elt: MachineOperand::CreateMBB(MBB: It->second));
555 MI.getParent()->addSuccessor(Succ: It->second);
556 ToEraseMI.insert(Ptr: Ins[i]);
557 } else {
558 NewOps.push_back(
559 Elt: MachineOperand::CreateCImm(CI: Ins[i]->getOperand(i: 1).getCImm()));
560 }
561 }
562 for (unsigned i = MI.getNumOperands() - 1; i > 1; --i)
563 MI.removeOperand(OpNo: i);
564 for (auto &MO : NewOps)
565 MI.addOperand(Op: MO);
566 if (MachineInstr *Next = MI.getNextNode()) {
567 if (isSpvIntrinsic(*Next, Intrinsic::spv_track_constant)) {
568 ToEraseMI.insert(Ptr: Next);
569 Next = MI.getNextNode();
570 }
571 if (Next && Next->getOpcode() == TargetOpcode::G_BRINDIRECT)
572 ToEraseMI.insert(Ptr: Next);
573 }
574 }
575 for (MachineInstr *BlockAddrI : ToEraseMI)
576 BlockAddrI->eraseFromParent();
577}
578
579static bool isImplicitFallthrough(MachineBasicBlock &MBB) {
580 if (MBB.empty())
581 return true;
582
583 // Branching SPIR-V intrinsics are not detected by this generic method.
584 // Thus, we can only trust negative result.
585 if (!MBB.canFallThrough())
586 return false;
587
588 // Otherwise, we must manually check if we have a SPIR-V intrinsic which
589 // prevent an implicit fallthrough.
590 for (MachineBasicBlock::reverse_iterator It = MBB.rbegin(), E = MBB.rend();
591 It != E; ++It) {
592 if (isSpvIntrinsic(*It, Intrinsic::spv_switch))
593 return false;
594 }
595 return true;
596}
597
598static void removeImplicitFallthroughs(MachineFunction &MF,
599 MachineIRBuilder MIB) {
600 // It is valid for MachineBasicBlocks to not finish with a branch instruction.
601 // In such cases, they will simply fallthrough their immediate successor.
602 for (MachineBasicBlock &MBB : MF) {
603 if (!isImplicitFallthrough(MBB))
604 continue;
605
606 assert(std::distance(MBB.successors().begin(), MBB.successors().end()) ==
607 1);
608 MIB.setInsertPt(MBB, II: MBB.end());
609 MIB.buildBr(Dest&: **MBB.successors().begin());
610 }
611}
612
613bool SPIRVPreLegalizer::runOnMachineFunction(MachineFunction &MF) {
614 // Initialize the type registry.
615 const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>();
616 SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
617 GR->setCurrentFunc(MF);
618 MachineIRBuilder MIB(MF);
619 addConstantsToTrack(MF, GR);
620 foldConstantsIntoIntrinsics(MF);
621 insertBitcasts(MF, GR, MIB);
622 generateAssignInstrs(MF, GR, MIB);
623 processSwitches(MF, GR, MIB);
624 processInstrsWithTypeFolding(MF, GR, MIB);
625 removeImplicitFallthroughs(MF, MIB);
626
627 return true;
628}
629
630INITIALIZE_PASS(SPIRVPreLegalizer, DEBUG_TYPE, "SPIRV pre legalizer", false,
631 false)
632
633char SPIRVPreLegalizer::ID = 0;
634
635FunctionPass *llvm::createSPIRVPreLegalizerPass() {
636 return new SPIRVPreLegalizer();
637}
638

source code of llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp